Spaces:
Runtime error
Runtime error
File size: 2,958 Bytes
abf42db c0bdecb abf42db c0bdecb abf42db db19ba6 abf42db d45c2c9 864097c 4a2be52 35d2e2e abf42db 9fd65cd abf42db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
# Imports
import gradio as gr
import transformers
import torch
import os
from transformers import pipeline, AutoTokenizer
from huggingface_hub import login
# Hugging Face access token, read from the Space secret 'mentalhealth_llama_chat'.
HF_TOKEN = os.getenv('mentalhealth_llama_chat')
# Only log in when the secret is actually set: huggingface_hub's login() with
# no token falls back to an interactive prompt, which a headless Space cannot
# answer.
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("Warning: secret 'mentalhealth_llama_chat' is not set; "
          "continuing without logging in to the Hugging Face Hub.")

# Model name in Hugging Face docs
model = 'klyang/MentaLLaMA-chat-13B'
# `token=` replaces the deprecated `use_auth_token=`; when HF_TOKEN is None the
# hub client falls back to any locally stored credentials.
tokenizer = AutoTokenizer.from_pretrained(model, token=HF_TOKEN)
llama_pipeline = pipeline(
    "text-generation",  # LLM task
    model=model,
    tokenizer=tokenizer,  # reuse the tokenizer loaded above instead of loading it again
    torch_dtype=torch.float16,
    device_map="auto",
)
# Opening of a Llama-2 chat prompt: BOS + [INST] + the <<SYS>> block.
# format_message() appends the first user turn directly after this string, so
# it ends with "\n<</SYS>>\n\n" as in the Llama-2 chat template.
SYSTEM_PROMPT = """<s>[INST] <<SYS>>
You are Mentra, a mental health assistant. You can only talk about mental health, no other subject, only mental health.
You are to provide individuals with mental health support. Do not talk about any other subject, focus only on mental health!
Secondly, do not engage the user in topics like Medically Assisted Dying, Suicide, Murder, Self-harm, Islamophobia, Politics, and other topics of this controversial nature.
Thirdly, keep your responses short, but kind and thoughtful.
<</SYS>>

"""
# Formatting helpers: turn Gradio chat history + the new message into a Llama-2 prompt.
def _history_to_pairs(history: list) -> list:
    """
    Normalises Gradio chat history into a list of [user, assistant] pairs.
    Parameters:
    history (list): Either the "tuples" format ([user, assistant] pairs) or
        the "messages" format (dicts with "role" and "content" keys).
    Returns:
    list: [user_text, assistant_text] pairs; missing/None parts become "".
    """
    pairs = []
    for item in history:
        if isinstance(item, dict):
            role = item.get("role")
            content = item.get("content")
            text = "" if content is None else str(content)
            if role == "user":
                pairs.append([text, ""])
            elif role == "assistant":
                # An assistant message with no preceding user turn still needs a slot.
                if not pairs:
                    pairs.append(["", ""])
                previous = pairs[-1][1]
                # Consecutive assistant messages are merged into one answer.
                pairs[-1][1] = f"{previous}\n{text}" if previous else text
        else:
            user_msg, model_answer = item
            pairs.append([
                "" if user_msg is None else user_msg,
                "" if model_answer is None else model_answer,
            ])
    return pairs


def format_message(message: str, history: list, memory_limit: int = 20) -> str:
    """
    Formats the message and history for the Llama model.
    Parameters:
    message (str): Current message to send.
    history (list): Past conversation history, either as [user, assistant]
        pairs or as Gradio "messages" dicts with "role"/"content" keys.
    memory_limit (int): Maximum number of past exchanges to include; 0 or
        less means no history at all.
    Returns:
    str: Formatted message string
    """
    pairs = _history_to_pairs(history)
    # Keep only the most recent `memory_limit` exchanges. A bare
    # pairs[-memory_limit:] would keep *everything* for memory_limit == 0
    # (since -0 == 0), so clamp explicitly.
    if memory_limit <= 0:
        pairs = []
    elif len(pairs) > memory_limit:
        pairs = pairs[-memory_limit:]
    if not pairs:
        return SYSTEM_PROMPT + f"{message} [/INST]"
    # SYSTEM_PROMPT already opens the first [INST] block, so the first
    # exchange is appended to it without its own "<s>[INST]" prefix.
    (first_user, first_answer), *rest = pairs
    parts = [SYSTEM_PROMPT + f"{first_user} [/INST] {first_answer} </s>"]
    parts.extend(
        f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"
        for user_msg, model_answer in rest
    )
    # Current message opens a new turn and is left for the model to answer.
    parts.append(f"<s>[INST] {message} [/INST]")
    return "".join(parts)
# Generate a response from the Llama model
def get_llama_response(message: str, history: list, max_new_tokens: int = 512) -> str:
    """
    Generates a conversational response from the Llama model.
    Parameters:
    message (str): User's input message.
    history (list): Past conversation history, as passed by gr.ChatInterface.
    max_new_tokens (int): Upper bound on the number of tokens generated for
        the reply; the prompt does not count against it.
    Returns:
    str: Generated response from the Llama model, with surrounding
        whitespace stripped.
    """
    query = format_message(message, history)
    sequences = llama_pipeline(
        query,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        # The old max_length=1024 counted the prompt as well, so a long
        # history could leave no room for (or exceed) the reply.
        max_new_tokens=max_new_tokens,
        # Ask the pipeline for the continuation only, instead of slicing the
        # prompt off the full text by character count.
        return_full_text=False,
    )
    response = sequences[0]['generated_text'].strip()
    print("Chatbot:", response)
    return response
# Wrap the response function in Gradio's chat UI and start the server.
# Per the Gradio docs, debug=True makes launch() block the main thread.
demo = gr.ChatInterface(get_llama_response)
demo.launch(debug=True)
|