# llama2/app.py
# Import necessary libraries
import streamlit as st
from transformers import AutoTokenizer, LlamaForCausalLM
import torch
import os
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
HUGGINGFACE_API_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
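# The .env file is expected to define the token, e.g. (placeholder value):
#   HUGGINGFACE_API_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx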
# Check if Hugging Face API token exists
if not HUGGINGFACE_API_TOKEN:
    st.error("Hugging Face API token not found. Please set it in your environment variables.")
    st.stop()
# Title of the app
st.title("Optimized LLaMA 2 Chatbot")
# Load the LLaMA model and tokenizer from Hugging Face
@st.cache_resource
def load_model_and_tokenizer():
"""Load the tokenizer and model."""
tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Llama-2-7b-hf", # Correct model identifier
use_auth_token=HUGGINGFACE_API_TOKEN
)
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", # Correct model identifier
use_auth_token=HUGGINGFACE_API_TOKEN,
torch_dtype=torch.float16,
device_map="auto",
low_cpu_mem_usage=True
)
return tokenizer, model
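# Note: @st.cache_resource keeps the loaded model in memory across Streamlit
# reruns, so the multi-gigabyte fp16 weights are loaded only once per process.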
# Function to generate text based on a prompt
def generate_text(prompt, tokenizer, model, max_length=256, temperature=0.6):
"""Generate text based on a prompt."""
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Generate text with optimizations
with torch.no_grad():
generate_ids = model.generate(
inputs.input_ids,
max_length=max_length,
temperature=temperature,
do_sample=True,
top_k=30,
top_p=0.85,
repetition_penalty=1.1,
)
return tokenizer.decode(generate_ids[0], skip_special_tokens=True)
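# Example call (hypothetical prompt; the decoded output includes the prompt
# followed by the model's continuation):
#   generate_text("Explain attention in one sentence.", tokenizer, model,
#                 max_length=128, temperature=0.7)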
# Input field for user prompt
user_input = st.text_input("Enter your prompt:", "Hello, how are you?")
# Load model and tokenizer
tokenizer, model = load_model_and_tokenizer()
# Generate response when user enters a prompt
if st.button("Generate Response"):
    with st.spinner("Generating response..."):
        response = generate_text(user_input, tokenizer, model)
    st.write(f"**Response:** {response}")
# Optional settings in sidebar
st.sidebar.header("Settings")
max_length = st.sidebar.slider("Max Length", min_value=50, max_value=512, value=256)
temperature = st.sidebar.slider("Temperature", min_value=0.1, max_value=1.0, value=0.6)
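# Note: Streamlit reruns this script top-to-bottom on every interaction, so
# these sliders are defined by the time the sidebar button below is evaluated;
# the main "Generate Response" button above uses the function defaults instead.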
# Regenerate response with updated settings
if st.sidebar.button("Regenerate"):
    with st.spinner("Generating response..."):
        response = generate_text(user_input, tokenizer, model, max_length, temperature)
    st.write(f"**Response:** {response}")