{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"import transformers\n",
"import torch\n",
"\n",
"model = \"meta-llama/Llama-2-7b-chat-hf\" # meta-llama/Llama-2-7b-chat-hf\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=True)"
]
},
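{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional quick check that the tokenizer loaded correctly (a minimal sketch; the sample\n",
"# sentence is arbitrary and only illustrates encoding and decoding round-trips).\n",
"encoded = tokenizer(\"Hello, Llama!\")\n",
"print(encoded[\"input_ids\"])\n",
"print(tokenizer.decode(encoded[\"input_ids\"]))"
]
},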
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"llama_pipeline = pipeline(\n",
" \"text-generation\", # LLM task\n",
" model=model,\n",
" torch_dtype=torch.float16,\n",
" device_map=\"auto\",\n",
")"
]
},
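{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: a single raw generation call to confirm the pipeline works end to end\n",
"# (a minimal sketch; the prompt text is arbitrary and max_new_tokens is just an illustrative cap).\n",
"test_output = llama_pipeline(\"The key idea behind the attention mechanism is\", max_new_tokens=40)\n",
"print(test_output[0][\"generated_text\"])"
]
},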
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"SYSTEM_PROMPT = \"\"\"[INST] <>\n",
"You are a helpful bot. Your answers are clear and concise.\n",
"<>\n",
"\n",
"\"\"\"\n",
"\n",
"# Formatting function for message and history\n",
"def format_message(message: str, history: list, memory_limit: int = 3) -> str:\n",
" \"\"\"\n",
" Formats the message and history for the Llama model.\n",
"\n",
" Parameters:\n",
" message (str): Current message to send.\n",
" history (list): Past conversation history.\n",
" memory_limit (int): Limit on how many past interactions to consider.\n",
"\n",
" Returns:\n",
" str: Formatted message string\n",
" \"\"\"\n",
" # always keep len(history) <= memory_limit\n",
" if len(history) > memory_limit:\n",
" history = history[-memory_limit:]\n",
"\n",
" if len(history) == 0:\n",
" return SYSTEM_PROMPT + f\"{message} [/INST]\"\n",
"\n",
" formatted_message = SYSTEM_PROMPT + f\"{history[0][0]} [/INST] {history[0][1]} \"\n",
"\n",
" # Handle conversation history\n",
" for user_msg, model_answer in history[1:]:\n",
" formatted_message += f\"[INST] {user_msg} [/INST] {model_answer} \"\n",
"\n",
" # Handle the current message\n",
" formatted_message += f\"[INST] {message} [/INST]\"\n",
"\n",
" return formatted_message"
]
},
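{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check the prompt layout produced by format_message (a minimal sketch; the sample\n",
"# exchange below is hypothetical and only illustrates the [INST] ... [/INST] structure).\n",
"sample_history = [\n",
"    (\"Hi, who are you?\", \"I'm a helpful bot.\"),\n",
"    (\"Can you keep answers short?\", \"Yes, I'll keep them concise.\"),\n",
"]\n",
"\n",
"print(format_message(\"What is the capital of France?\", sample_history))"
]
},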
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate a response from the Llama model\n",
"def get_llama_response(message: str, history: list) -> str:\n",
" \"\"\"\n",
" Generates a conversational response from the Llama model.\n",
"\n",
" Parameters:\n",
" message (str): User's input message.\n",
" history (list): Past conversation history.\n",
"\n",
" Returns:\n",
" str: Generated response from the Llama model.\n",
" \"\"\"\n",
" query = format_message(message, history)\n",
" response = \"\"\n",
"\n",
" sequences = llama_pipeline(\n",
" query,\n",
" do_sample=True,\n",
" top_k=10,\n",
" num_return_sequences=1,\n",
" eos_token_id=tokenizer.eos_token_id,\n",
" max_length=1024,\n",
" )\n",
"\n",
" generated_text = sequences[0]['generated_text']\n",
" response = generated_text[len(query):] # Remove the prompt from the output\n",
"\n",
" print(\"Chatbot:\", response.strip())\n",
" return response.strip()"
]
},
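{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One-off smoke test before wiring the chatbot into a UI (assumes the pipeline above loaded\n",
"# successfully; the question is an arbitrary example and can be replaced with anything).\n",
"get_llama_response(\"Explain in one sentence what a tokenizer does.\", [])"
]
},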
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr\n",
"\n",
"gr.ChatInterface(get_llama_response).launch()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "itam",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}