Spaces:
Runtime error
Runtime error
File size: 3,489 Bytes
cc5b602 6f619d7 d381360 6386510 51a7d9e 3eed0af 6386510 51a7d9e e6367a7 423ddc8 51a7d9e 6386510 bd34f0b 423ddc8 bd34f0b 51a7d9e 423ddc8 d381360 3eed0af 423ddc8 3eed0af d381360 4ed884e 1d4c579 4ed884e e59867b 423ddc8 3eed0af 423ddc8 e59867b 423ddc8 3eed0af 6386510 51a7d9e 36f75a7 51a7d9e 1d4c579 51a7d9e 4ed884e 51a7d9e b64165b 51a7d9e 3fb77c6 51a7d9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import os
import time
import spaces
import torch
import gradio as gr
from threading import Thread
# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# HTML heading rendered at the top of the page.
TITLE = "<h1><center>Mistral-lab</center></h1>"

# Markup passed to gr.Chatbot as its placeholder.
PLACEHOLDER = (
    "\n"
    "<center>\n"
    "<p>Chat with Mistral AI LLM.</p>\n"
    "</center>\n"
)
from huggingface_hub import snapshot_download
from pathlib import Path
# Local directory that receives the model files (~/mistral_models/8B-Instruct).
mistral_models_path = Path.home().joinpath('mistral_models', '8B-Instruct')
mistral_models_path.mkdir(parents=True, exist_ok=True)

# Download only the parameter, weight, and tokenizer files that the
# tokenizer/model loading code reads from this folder.
snapshot_download(
    repo_id="mistralai/Ministral-8B-Instruct-2410",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
    # Use the token read from the environment at start-up explicitly
    # (None falls back to huggingface_hub's own token lookup). Needed if the
    # repository requires authentication.
    token=HF_TOKEN,
)
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Device the model weights are loaded onto; forwarded to Transformer.from_folder
# (previously this variable was assigned but never used).
# NOTE(review): "cpu" is the suggested alternative, but CPU inference with
# mistral_inference is not exercised here — confirm before switching.
device = "cuda"
tokenizer = MistralTokenizer.from_file(str(mistral_models_path / "tekken.json"))
model = Transformer.from_folder(mistral_models_path, device=device)
def _history_to_turns(history):
    """Flatten Gradio chat history into a list of ``(role, text)`` tuples.

    Two history layouts are accepted:

    * "tuples": ``[user, assistant]`` pairs. A pair is kept only when both
      sides are non-empty strings. A ``None``/empty reply would otherwise be
      sent as an assistant message without content, and non-string entries
      (such as file uploads) cannot be sent to the model as text. Dropping
      the whole pair keeps user and assistant turns paired.
    * "messages": dicts with ``"role"`` and ``"content"`` keys. Only
      ``"user"``/``"assistant"`` entries with non-empty string content are
      kept.

    ``None`` or an empty history yields an empty list.
    """
    turns = []
    for entry in history or ():
        if isinstance(entry, dict):
            role = entry.get("role")
            content = entry.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                turns.append((role, content))
            continue
        prompt, answer = entry
        if isinstance(prompt, str) and prompt and isinstance(answer, str) and answer:
            turns.append(("user", prompt))
            turns.append(("assistant", answer))
    return turns


def _build_conversation(message, history):
    """Return the Mistral message list for *history* followed by *message*.

    Consecutive turns with the same role can appear once unusable entries
    are filtered out of the history. They are merged, joined by a blank
    line, so that user and assistant turns alternate. The new *message* is
    always the final user turn.
    """
    merged = []
    for role, text in _history_to_turns(history) + [("user", message)]:
        if merged and merged[-1][0] == role:
            merged[-1] = (role, f"{merged[-1][1]}\n\n{text}")
        else:
            merged.append((role, text))
    return [
        UserMessage(content=text) if role == "user" else AssistantMessage(content=text)
        for role, text in merged
    ]


@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    temperature: float = 0.3,
    max_new_tokens: int = 1024,
):
    """Generate the model's reply to *message* given the prior *history*.

    Despite its name, this function does not stream. It runs ``generate``
    once and returns the whole decoded reply.

    Args:
        message: The new user message.
        history: Prior turns as supplied by gr.ChatInterface, in either the
            "tuples" format (``[user, assistant]`` pairs) or the "messages"
            format (dicts with ``"role"`` and ``"content"``). Entries that
            cannot be sent as text are skipped.
        temperature: Sampling temperature forwarded to ``generate``.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded reply text for the single request in the batch.
    """
    print(f'message: {message}')
    print(f'history: {history}')
    conversation = _build_conversation(message, history)
    completion_request = ChatCompletionRequest(messages=conversation)
    tokens = tokenizer.encode_chat_completion(completion_request).tokens
    out_tokens, _ = generate(
        [tokens],  # a batch containing just this one prompt
        model,
        # Coerce defensively in case the UI sliders deliver other numeric types.
        max_tokens=int(max_new_tokens),
        temperature=float(temperature),
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
# Chat window handed to gr.ChatInterface below.
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

with gr.Blocks(theme="ocean") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",
    )

    # Collapsed accordion that ChatInterface uses to hold the extra inputs.
    parameters_accordion = gr.Accordion(
        label="⚙️ Parameters",
        open=False,
        render=False,
    )
    # Extra inputs, passed to stream_chat after (message, history) in this
    # order: temperature, then max_new_tokens.
    temperature_slider = gr.Slider(
        label="Temperature",
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.3,
        render=False,
    )
    max_tokens_slider = gr.Slider(
        label="Max new tokens",
        minimum=128,
        maximum=8192,
        step=1,
        value=1024,
        render=False,
    )

    # Each example is a one-element list holding the message text.
    example_prompts = [
        ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
        ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
        ["Tell me a random fun fact about the Roman Empire."],
        ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_tokens_slider],
        examples=example_prompts,
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()
|