import re

import streamlit as st
from transformers import LlamaTokenizer, LlamaForCausalLM

# Cache the model and tokenizer so they are loaded only once per session,
# not on every Streamlit rerun.
@st.cache_resource
def load_model_and_tokenizer():
    tokenizer = LlamaTokenizer.from_pretrained('klyang/MentaLLaMA-chat-7B')
    model = LlamaForCausalLM.from_pretrained('klyang/MentaLLaMA-chat-7B', device_map='auto')
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()
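# Note (assumption, not in the original app): on a memory-constrained GPU, a
# half-precision load is a common alternative for a 7B model. Sketch only:
#   import torch
#   model = LlamaForCausalLM.from_pretrained(
#       'klyang/MentaLLaMA-chat-7B', device_map='auto', torch_dtype=torch.float16)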
# Build the prompt, generate a response, and return only the model's answer
# (the prompt tokens are stripped from the generated sequence).
def predict(input_text):
    input_data = [f"Consider this post: {input_text}. Question: Does the poster suffer from depression?"]
    inputs = tokenizer(input_data, return_tensors="pt", padding=True)
    # Move inputs to whatever device device_map='auto' placed the model on,
    # instead of assuming CUDA is available.
    input_ids = inputs.input_ids.to(model.device)
    # max_length counts prompt tokens plus newly generated tokens.
    generate_ids = model.generate(input_ids=input_ids, max_length=512)
    # Drop the echoed prompt tokens; keep only the generated continuation.
    trunc_ids = generate_ids[0][len(input_ids[0]):]
    response = tokenizer.decode(trunc_ids, skip_special_tokens=True, spaces_between_special_tokens=False)
    return response
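# Example of calling predict() directly, e.g. for quick testing outside the
# Streamlit UI (hypothetical input; the exact output wording varies by run):
#   predict("I haven't slept properly in weeks and nothing feels worth doing.")
#   -> free-text reasoning followed by a yes/no verdict sentence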
# Streamlit app
st.title("Depression Detection App")
st.write("Provide a post or message to determine if it suggests the user may be suffering from depression.")

input_text = st.text_area("Enter a post/input text")

if st.button("Predict"):
    with st.spinner("Analyzing..."):
        output_text = predict(input_text)
    # Split the response into sentences on '.'; the split leaves a trailing
    # empty string, so the final full sentence is the second-to-last element.
    sentences = output_text.strip().split('.')

    # The last full sentence carries the model's verdict.
    verdict = sentences[-2].strip()

    # The two sentences preceding the verdict are taken as the reasoning.
    reason = sentences[-4:-2]

    # Strip a leading "Reasoning:" label if the model emitted one.
    reason = [re.sub(r'reasoning:\s*', '', line, flags=re.IGNORECASE).strip() for line in reason]

    # Display the verdict and the reasoning.
    st.write("Verdict:")
    st.write(verdict)
    st.write("Reason:")
    st.write(". ".join(reason))