{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "import transformers\n", "import torch\n", "\n", "model = \"meta-llama/Llama-2-7b-chat-hf\" # meta-llama/Llama-2-7b-chat-hf\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline\n", "\n", "llama_pipeline = pipeline(\n", " \"text-generation\", # LLM task\n", " model=model,\n", " torch_dtype=torch.float16,\n", " device_map=\"auto\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "SYSTEM_PROMPT = \"\"\"[INST] <>\n", "You are a helpful bot. Your answers are clear and concise.\n", "<>\n", "\n", "\"\"\"\n", "\n", "# Formatting function for message and history\n", "def format_message(message: str, history: list, memory_limit: int = 3) -> str:\n", " \"\"\"\n", " Formats the message and history for the Llama model.\n", "\n", " Parameters:\n", " message (str): Current message to send.\n", " history (list): Past conversation history.\n", " memory_limit (int): Limit on how many past interactions to consider.\n", "\n", " Returns:\n", " str: Formatted message string\n", " \"\"\"\n", " # always keep len(history) <= memory_limit\n", " if len(history) > memory_limit:\n", " history = history[-memory_limit:]\n", "\n", " if len(history) == 0:\n", " return SYSTEM_PROMPT + f\"{message} [/INST]\"\n", "\n", " formatted_message = SYSTEM_PROMPT + f\"{history[0][0]} [/INST] {history[0][1]} \"\n", "\n", " # Handle conversation history\n", " for user_msg, model_answer in history[1:]:\n", " formatted_message += f\"[INST] {user_msg} [/INST] {model_answer} \"\n", "\n", " # Handle the current message\n", " formatted_message += f\"[INST] {message} [/INST]\"\n", "\n", " return formatted_message" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generate a response from the Llama model\n", "def get_llama_response(message: str, history: list) -> str:\n", " \"\"\"\n", " Generates a conversational response from the Llama model.\n", "\n", " Parameters:\n", " message (str): User's input message.\n", " history (list): Past conversation history.\n", "\n", " Returns:\n", " str: Generated response from the Llama model.\n", " \"\"\"\n", " query = format_message(message, history)\n", " response = \"\"\n", "\n", " sequences = llama_pipeline(\n", " query,\n", " do_sample=True,\n", " top_k=10,\n", " num_return_sequences=1,\n", " eos_token_id=tokenizer.eos_token_id,\n", " max_length=1024,\n", " )\n", "\n", " generated_text = sequences[0]['generated_text']\n", " response = generated_text[len(query):] # Remove the prompt from the output\n", "\n", " print(\"Chatbot:\", response.strip())\n", " return response.strip()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "\n", "gr.ChatInterface(get_llama_response).launch()\n" ] } ], "metadata": { "kernelspec": { "display_name": "itam", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 2 }