rajakr's picture
Update app.py
2cafa29 verified
import transformers
import streamlit as st
import re
from transformers import LlamaTokenizer, LlamaForCausalLM
# Cache the model and tokenizer loading to ensure they are only loaded once
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
tokenizer = LlamaTokenizer.from_pretrained('klyang/MentaLLaMA-chat-7B')
model = LlamaForCausalLM.from_pretrained('klyang/MentaLLaMA-chat-7B', device_map='auto')
return tokenizer, model
tokenizer, model = load_model_and_tokenizer()
# Define the prediction function
def predict(input_text):
input_data = [f"Consider this post: {input_text}. Question: Does the poster suffer from depression?"]
inputs = tokenizer(input_data, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to('cuda')
generate_ids = model.generate(input_ids=input_ids, max_length=512)
trunc_ids = generate_ids[0][len(input_ids[0]):]
response = tokenizer.decode(trunc_ids, skip_special_tokens=True, spaces_between_special_tokens=False)
return response
# Streamlit app
st.title("Depression Detection App")
st.write("Provide a post or message to determine if it suggests the user may be suffering from depression.")
input_text = st.text_area("Enter a post/input text")
if st.button("Predict"):
with st.spinner("Analyzing..."):
output_text = predict(input_text)
# Splitting the text into sentences
sentences = output_text.strip().split('.')
# Extracting the last sentence for the verdict
verdict = sentences[-2].strip()
# Extracting the last 3rd to last 2nd sentences for the reason
reason = sentences[-4:-2]
# Remove "Reasoning:" if it appears in the last two lines
reason = [re.sub(r'reasoning:\s*', '', line, flags=re.IGNORECASE).strip() for line in reason]
# Displaying the verdict and reason
st.write("Verdict:")
st.write(verdict)
st.write("Reason:")
st.write(". ".join(reason))