import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, TextStreamer
from threading import Thread
import gradio as gr
from peft import PeftModel

# Base model and LoRA adapter identifiers
model_name_or_path = "sarvamai/OpenHathi-7B-Hi-v0.1-Base"
peft_model_id = "shuvom/OpenHathi-7B-FT-v0.1_SI"

# Load the base model in 4-bit and the tokenizer shipped with the fine-tuned adapter
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_4bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# tokenizer.chat_template = chat_template

# make embedding resizing configurable?
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
model = PeftModel.from_pretrained(model, peft_model_id)
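# Note (assumption): on recent transformers releases the `load_in_4bit` argument used above is
# deprecated in favour of passing a BitsAndBytesConfig. A minimal sketch of that variant, kept
# commented out so the call above remains the one actually in use:
#
# from transformers import BitsAndBytesConfig
# quant_config = BitsAndBytesConfig(load_in_4bit=True)
# model = AutoModelForCausalLM.from_pretrained(
#     model_name_or_path, quantization_config=quant_config, device_map="auto"
# )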

class ChatCompletion:
    def __init__(self, model, tokenizer, system_prompt=None):
        self.model = model
        self.tokenizer = tokenizer
        # Two streamers: one iterable (for Gradio streaming), one that prints tokens to stdout
        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        self.print_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        # set the model in inference mode
        self.model.eval()
        self.system_prompt = system_prompt
    def get_completion(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
        # Sampling breaks down at temperature 0, so clamp to a small positive value
        if temperature < 1e-2:
            temperature = 1e-2
        messages = []
        if message_history is not None:
            messages.extend(message_history)
        elif system_prompt or self.system_prompt:
            system_prompt = system_prompt or self.system_prompt
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )
        # TextStreamer prints tokens to stdout as they are generated; decode the full output afterwards
        outputs = self.model.generate(**inputs, streamer=self.print_streamer, **generation_kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def get_chat_completion(self, message, history):
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        for user_message, assistant_message in history:
            messages.append({"role": "user", "content": user_message})
            messages.append({"role": "assistant", "content": assistant_message})
        messages.append({"role": "user", "content": message})
        chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.model.device)
        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        generation_kwargs = dict(
            inputs,
            streamer=self.streamer,
            max_new_tokens=2048,
            temperature=0.2,
            top_p=0.95,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=1.2,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        # Yield the growing transcript as new tokens arrive so Gradio can stream it to the UI
        generated_text = ""
        for new_text in self.streamer:
            generated_text += new_text.replace(self.tokenizer.eos_token, "")
            yield generated_text
        thread.join()
        return generated_text

    def get_completion_without_streaming(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
        # Sampling breaks down at temperature 0, so clamp to a small positive value
        if temperature < 1e-2:
            temperature = 1e-2
        messages = []
        if message_history is not None:
            messages.extend(message_history)
        elif system_prompt or self.system_prompt:
            system_prompt = system_prompt or self.system_prompt
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.1,
        )
        outputs = self.model.generate(**inputs, **generation_kwargs)
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text

# Wire the streaming chat handler into a Gradio ChatInterface
text_generator = ChatCompletion(model, tokenizer, system_prompt="You are a native Hindi speaker who can converse at expert level in both Hindi and colloquial Hinglish.")
gr.ChatInterface(text_generator.get_chat_completion).queue().launch(debug=True)
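# Usage sketch (hypothetical, not executed by the Space): once the model and adapter are loaded,
# the non-streaming helper can also be called directly, e.g. from a notebook:
#
#     reply = text_generator.get_completion_without_streaming("नमस्ते, आप कैसे हैं?")
#     print(reply)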