# Imports
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline
from huggingface_hub import login

# Read the Hugging Face access token from the environment (e.g. a Space secret)
HF_TOKEN = os.getenv('mentalhealth_llama_chat')
login(HF_TOKEN)

# Model repo ID on the Hugging Face Hub
model = 'klyang/MentaLLaMA-chat-13B'
tokenizer = AutoTokenizer.from_pretrained(model, token=True)
llama_pipeline = pipeline(
    "text-generation",  # LLM task
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

# System prompt wrapped in the Llama 2 chat template tags (<<SYS>> ... <</SYS>>)
SYSTEM_PROMPT = """<s>[INST] <<SYS>>
You are Mentra, a friendly, empathetic mental health chatbot who listens and tries to understand the speaker’s perspective.
Do not talk about anything else; focus only on mental health. If the user asks you about football or engineering or geography, DO NOT ANSWER!
You do not use harmful, hurtful, rude, or crude language.
<</SYS>>
"""


# Formatting function for message and history
def format_message(message: str, history: list, memory_limit: int = 20) -> str:
    """
    Formats the message and history for the Llama model.

    Parameters:
        message (str): Current message to send.
        history (list): Past conversation history as (user, assistant) pairs.
        memory_limit (int): Limit on how many past interactions to consider.

    Returns:
        str: Formatted message string.
    """
    # Always keep len(history) <= memory_limit
    if len(history) > memory_limit:
        history = history[-memory_limit:]

    if len(history) == 0:
        return SYSTEM_PROMPT + f"{message} [/INST]"

    # The first exchange shares its [INST] block with the system prompt
    formatted_message = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>"

    # Handle the rest of the conversation history, one <s>...</s> turn per exchange
    for user_msg, model_answer in history[1:]:
        formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"

    # Handle the current message
    formatted_message += f"<s>[INST] {message} [/INST]"

    return formatted_message


# Generate a response from the Llama model
def get_llama_response(message: str, history: list) -> str:
    """
    Generates a conversational response from the Llama model.

    Parameters:
        message (str): User's input message.
        history (list): Past conversation history.

    Returns:
        str: Generated response from the Llama model.
    """
    query = format_message(message, history)

    sequences = llama_pipeline(
        query,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=1024,
    )

    generated_text = sequences[0]['generated_text']
    response = generated_text[len(query):]  # Remove the prompt from the output

    print("Chatbot:", response.strip())
    return response.strip()


gr.ChatInterface(get_llama_response).launch(debug=True)