raghavdw's picture
updated app.py generate response function
4b9d0f3 verified
raw
history blame
1.64 kB
import gradio
import torch
from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer
# Load the fine-tuned GPT-2 tokenizer and model from the Hugging Face Hub
# (or the local cache if already downloaded).
MODEL_ID = "raghavdw/finedtuned_gpt2_medQA_model"
loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# AutoModelForCausalLM is the documented replacement for the deprecated
# AutoModelWithLMHead; for a GPT-2 checkpoint both resolve to the same
# language-modeling-head model class.
loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# Function for response generation
def generate_query_response(prompt: str, max_length: int = 200) -> str:
    """Generate a text continuation of *prompt* with the fine-tuned model.

    Uses the module-level ``loaded_model`` and ``loaded_tokenizer``. Before
    generating, the model is moved to CUDA when available, otherwise to CPU.

    Args:
        prompt: User text to condition generation on. An empty, ``None`` or
            whitespace-only prompt returns ``""`` without calling the model,
            because an empty token sequence cannot be passed to ``generate``.
        max_length: Forwarded to ``model.generate``. In transformers this
            caps the total sequence length, prompt tokens included.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    # Guard before touching the model: encoding "" yields a zero-length
    # input_ids tensor, which fails inside generate().
    if not prompt or not prompt.strip():
        return ""

    model = loaded_model
    tokenizer = loaded_tokenizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Calling the tokenizer directly returns both input_ids and a matching
    # attention_mask, so the mask does not need to be built by hand.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        # GPT-2 style tokenizers usually define no pad token; reusing EOS
        # presumably avoids generate()'s missing-pad-token warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Gradio elements
# Input box where the user types the prompt.
in_prompt = gradio.Textbox(label="Enter your prompt")
# Generation length limit applied to every request made through the UI.
in_max_length = 200
# Output box that displays the generated response.
out_response = gradio.Textbox(label="Generated Response")


def _generate_for_ui(prompt):
    """Run generate_query_response on *prompt* using ``in_max_length``.

    The interface exposes only a prompt box, so the configured length limit
    is applied here instead of relying on the function's own default.
    """
    return generate_query_response(prompt, max_length=in_max_length)


# Gradio interface wiring the prompt box to the generator.
iface = gradio.Interface(
    fn=_generate_for_ui,
    inputs=[in_prompt],
    outputs=out_response,
    title="Medical Summary",
    description="using fine-tune medQA gpt-2 model",
)
# Launch the app; share=True asks Gradio to also create a public share link.
iface.launch(share=True)