# Import necessary libraries
import streamlit as st
from transformers import AutoTokenizer, LlamaForCausalLM
import torch
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
HUGGINGFACE_API_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
# Check if the Hugging Face API token exists
if not HUGGINGFACE_API_TOKEN:
    st.error("Hugging Face API token not found. Please set it in your environment variables.")
    st.stop()
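# Note: meta-llama/Llama-2-7b-hf is a gated model, so the token must belong to
# an account that has accepted Meta's license on the Hugging Face model page.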

# Title of the app
st.title("Optimized LLaMA 2 Chatbot")

# Load the LLaMA 2 model and tokenizer from Hugging Face, cached so they are
# loaded once per process rather than on every Streamlit rerun
@st.cache_resource
def load_model_and_tokenizer():
    """Load the tokenizer and model."""
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        token=HUGGINGFACE_API_TOKEN,  # `use_auth_token` is deprecated in recent transformers
    )
    model = LlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        token=HUGGINGFACE_API_TOKEN,
        torch_dtype=torch.float16,   # half precision to reduce memory use
        device_map="auto",           # place weights on available GPU(s)/CPU
        low_cpu_mem_usage=True,
    )
    return tokenizer, model
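# Optional: if GPU memory is tight, the model can instead be loaded in 4-bit.
# This is a sketch, not part of the original app; it assumes the `bitsandbytes`
# package is installed and a CUDA GPU is available.
#
# from transformers import BitsAndBytesConfig
# model = LlamaForCausalLM.from_pretrained(
#     "meta-llama/Llama-2-7b-hf",
#     token=HUGGINGFACE_API_TOKEN,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     device_map="auto",
# )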

# Generate text from a prompt
def generate_text(prompt, tokenizer, model, max_new_tokens=256, temperature=0.6):
    """Generate a completion for `prompt` using sampling."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate without tracking gradients; pass the attention mask along with
    # the input ids, and use max_new_tokens so the prompt length does not eat
    # into the output budget (max_length counts prompt tokens too)
    with torch.no_grad():
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            top_k=30,
            top_p=0.85,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,  # LLaMA defines no pad token
        )
    return tokenizer.decode(generate_ids[0], skip_special_tokens=True)
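# For reproducible output, greedy decoding can be used instead of sampling.
# A sketch, not part of the original app:
#
# generate_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)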

# Sidebar settings, defined before the buttons so both use the current values
st.sidebar.header("Settings")
max_new_tokens = st.sidebar.slider("Max New Tokens", min_value=50, max_value=512, value=256)
temperature = st.sidebar.slider("Temperature", min_value=0.1, max_value=1.0, value=0.6)

# Input field for the user prompt
user_input = st.text_input("Enter your prompt:", "Hello, how are you?")

# Load the (cached) model and tokenizer
tokenizer, model = load_model_and_tokenizer()

# Generate a response when the user clicks the button
if st.button("Generate Response"):
    with st.spinner("Generating response..."):
        response = generate_text(user_input, tokenizer, model, max_new_tokens, temperature)
    st.write(f"**Response:** {response}")

# Regenerate to draw a new sample with the current settings
if st.sidebar.button("Regenerate"):
    with st.spinner("Generating response..."):
        response = generate_text(user_input, tokenizer, model, max_new_tokens, temperature)
    st.write(f"**Response:** {response}")
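# To run the app locally (assuming this file is saved as app.py):
#   streamlit run app.py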