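# Gradio demo that samples text completions from the treadon/promt-fungineer-355M model.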
import gradio as gr
# import peft
import transformers
import os
device = "cpu"
is_peft = False
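# is_peft toggles the (currently disabled) PEFT adapter loading path commented out below.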
model_id = "treadon/promt-fungineer-355M"
# if is_peft:
#     config = peft.PeftConfig.from_pretrained(model_id)
#     model = transformers.AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, low_cpu_mem_usage=True)
#     tokenizer = transformers.AutoTokenizer.from_pretrained(config.base_model_name_or_path)
#     model = peft.PeftModel.from_pretrained(model, model_id)
# else:
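# Read the access token from the hub_token environment variable; falling back to True
# makes transformers use the locally cached login token instead.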
auth_token = os.environ.get("hub_token") or True
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, use_auth_token=auth_token)
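# The fine-tuned model keeps GPT-2's vocabulary, so the stock GPT-2 tokenizer is loaded.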
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

def generate_text(prompt):
    # The model expects prompts prefixed with "BRF:", so add the marker if it is missing.
    if not prompt.startswith("BRF:"):
        prompt = "BRF: " + prompt

    model.eval()

    # Soft sample: draw four candidate completions in a single generate call.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    samples = []
    try:
        outputs = model.generate(**inputs, max_length=256, do_sample=True, top_k=100, top_p=0.95, temperature=0.85, num_return_sequences=4, pad_token_id=tokenizer.eos_token_id)
        for output in outputs:
            samples.append(tokenizer.decode(output, skip_special_tokens=True))
    except Exception as e:
        print(e)

    # Pad the result so each of the four Gradio output textboxes receives a value
    # even if generation fails partway through.
    while len(samples) < 4:
        samples.append("")
    return samples
iface = gr.Interface(fn=generate_text, inputs="text", outputs=["text", "text", "text", "text"])
iface.launch()