import torch from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer import datetime import os from threading import Event, Thread from uuid import uuid4 import gradio as gr import requests model_name = "decapoda-research/llama-13b-hf" adapters_name = 'timdettmers/guanaco-13b' print(f"Starting to load the model {model_name} into memory") model = AutoModelForCausalLM.from_pretrained( model_name, load_in_4bit=True, torch_dtype=torch.bfloat16, device_map={"": 0} ) model = PeftModel.from_pretrained(model, adapters_name) tokenizer = LlamaTokenizer.from_pretrained(model_name) tokenizer.bos_token_id = 1 stop_token_ids = [0] max_new_tokens = 2048 start_message = """A chat between a human user and a kind AI. The assistant gives helpful, cordial, and polite answers to the user's questions.""" class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: for stop_id in stop_token_ids: if input_ids[0][-1] == stop_id: return True return False def convert_history_to_text(history): text = start_message + "".join( [ "".join( [ f"### Human: {item[0]}\n", f"### Assistant: {item[1]}\n", ] ) for item in history[:-1] ] ) text += "".join( [ "".join( [ f"### Human: {history[-1][0]}\n", f"### Assistant: {history[-1][1]}\n", ] ) ] ) return text def log_conversation(conversation_id, history, messages, generate_kwargs): logging_url = os.getenv("LOGGING_URL", None) if logging_url is None: return timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") data = { "conversation_id": conversation_id, "timestamp": timestamp, "history": history, "messages": messages, "generate_kwargs": generate_kwargs, } try: requests.post(logging_url, json=data) except requests.exceptions.RequestException as e: print(f"Error logging conversation: {e}") def user(message, history): # Append the user's message to the conversation history return "", history + [[message, ""]] def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id): print(f"history: {history}") # Initialize a StopOnTokens object stop = StopOnTokens() # Construct the input message string for the model by concatenating the current system message and conversation history messages = convert_history_to_text(history) # Tokenize the messages string input_ids = tokenizer(messages, return_tensors="pt").input_ids input_ids = input_ids.to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( input_ids=input_ids, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=temperature > 0.0, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, streamer=streamer, stopping_criteria=StoppingCriteriaList([stop]), ) stream_complete = Event() def generate_and_signal_complete(): model.generate(**generate_kwargs) stream_complete.set() def log_after_stream_complete(): stream_complete.wait() log_conversation( conversation_id, history, messages, { "top_k": top_k, "top_p": top_p, "temperature": temperature, "repetition_penalty": repetition_penalty, }, ) t1 = Thread(target=generate_and_signal_complete) t1.start() t2 = Thread(target=log_after_stream_complete) t2.start() # Initialize an empty string to store the generated text partial_text = "" for new_text in streamer: partial_text += new_text history[-1][1] = partial_text yield history def get_uuid(): return str(uuid4()) with gr.Blocks( theme=gr.themes.Soft(), css=".disclaimer {font-variant-caps: all-small-caps;}", ) as demo: conversation_id = gr.State(get_uuid) gr.Markdown( """