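"""LLaMA 2 Chatbot.

A minimal Streamlit app that serves meta-llama/Llama-2-7b-hf. Run it with
`streamlit run app.py`; this assumes `streamlit`, `transformers`, `torch`,
`accelerate`, and `python-dotenv` are installed and that access to the gated
model has been granted on Hugging Face.
"""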
# Import necessary libraries
import streamlit as st
from transformers import AutoTokenizer, LlamaForCausalLM
import torch
import os
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
HUGGINGFACE_API_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
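# The token may come from the environment or from a local `.env` file read by
# load_dotenv(), e.g. a line such as HUGGINGFACE_API_TOKEN=hf_xxx (placeholder).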
# Check if Hugging Face API token exists
if not HUGGINGFACE_API_TOKEN:
    st.error("Hugging Face API token not found. Please set it in your environment variables.")
    st.stop()
# Title of the app
st.title("LLaMA 2 Chatbot")
# Load the LLaMA model and tokenizer from Hugging Face
@st.cache_resource  # cache so the weights are loaded once and reused across Streamlit reruns
def load_model_and_tokenizer():
    """Load the tokenizer and model."""
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        token=HUGGINGFACE_API_TOKEN,  # `token` replaces the deprecated `use_auth_token`
    )
    # Load the model in half precision, placed automatically across available devices
    model = LlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        token=HUGGINGFACE_API_TOKEN,
        torch_dtype=torch.float16,
        device_map="auto",  # requires the `accelerate` package
    )
    return tokenizer, model
# Function to generate text based on a prompt
def generate_text(prompt, tokenizer, model, max_length=512, temperature=0.7):
    """Generate text based on a prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate text by sampling with top-k/top-p filtering and a repetition penalty
    with torch.no_grad():
        generate_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            temperature=temperature,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )
    return tokenizer.decode(generate_ids[0], skip_special_tokens=True)
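# Note: `max_length` caps the total sequence (prompt + completion); to cap only
# the completion length, `max_new_tokens` is the usual alternative in
# `model.generate`.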
# Sidebar settings, read before generating so they apply to both buttons
st.sidebar.header("Settings")
max_length = st.sidebar.slider("Max Length", min_value=100, max_value=1024, value=512)
temperature = st.sidebar.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7)
# Input field for user prompt
user_input = st.text_input("Enter your prompt:", "Hello, how are you?")
# Load model and tokenizer
tokenizer, model = load_model_and_tokenizer()
# Generate response when user enters a prompt
if st.button("Generate Response"):
    with st.spinner("Generating response..."):
        response = generate_text(user_input, tokenizer, model, max_length, temperature)
    st.write(f"**Response:** {response}")
# Regenerate response with the current settings
if st.sidebar.button("Regenerate"):
    with st.spinner("Generating response..."):
        response = generate_text(user_input, tokenizer, model, max_length, temperature)
    st.write(f"**Response:** {response}")