# Orca-2-13B / .ipynb_checkpoints / app-checkpoint.py
# ari9dam: "adding app file" (commit 11250e9), 5.88 kB
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import torch
import transformers
from transformers import TextIteratorStreamer
# Generation limits surfaced in the UI. The prompt-length cap can be
# overridden through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Load the checkpoint once, at import time; device_map="auto" delegates
# weight placement across the available devices to transformers.
# NOTE(review): no torch_dtype is passed, so weights load in the library's
# default dtype -- confirm that fits the target hardware for a 13B model.
model_id = "microsoft/Orca-2-13b"
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
# The slow (non-"fast") tokenizer implementation is requested explicitly.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=False)

# Default system prompt; pre-filled in the editable "System prompt" box.
system_message = (
    "You are Orca, an AI language model created by Microsoft. "
    "You are a cautious assistant. You carefully follow instructions. "
    "You are helpful and harmless and you follow ethical guidelines "
    "and promote positive behavior."
)
# NOTE(review): user_message appears unused in the visible code.
user_message = (
    "How can you determine if a restaurant is popular among locals or mainly "
    "attracts tourists, and why might this information be useful?"
)
# Markdown rendered above the chat UI.
DESCRIPTION = """
# Orca-2 13B
This Space demonstrates model [Orca-2-13B](https://huggingface.co/microsoft/Orca-2-13B) by Microsoft, a Llama 2 derivative model with 13B parameters fine-tuned for single-turn instructions. This Space runs the model with the `transformers` library. If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://ui.endpoints.huggingface.co/).
The system message is set to be the cautious system message:
You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.
Feel free to modify it in the additional input section. The demo uses greedy decoding.
🔎 For more details about the Orca family of models, take a look [at our blog post](https://msft.it/6042iGtzK).
🔨 Looking for lighter versions of Orca-2? 🐇 Check out the [7B Chat model](https://huggingface.co/spaces/huggingface-projects/Orca-2-7b). Note: Orca 2 is licensed under the [Microsoft Research License](LICENSE). Llama 2 is licensed under the [LLAMA 2 Community License](https://ai.meta.com/llama/license/).
"""
# Function to combine system message and user
def to_prompt(conversations):
text = ""
for message in conversations:
if message['role']!="assistant":
text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
else:
text += f"<|im_start|>{message['role']}\n{message['content']}{tokenizer.eos_token}\n"
prompt = text + "<|im_start|>assistant\n"
inputs = tokenizer(prompt, return_tensors='pt').input_ids
return inputs
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
) -> Iterator[str]:
    """Stream the assistant's reply to ``message``, yielding the text so far.

    The conversation is built as:

    * a system turn (``system_prompt`` stripped, or empty when falsy);
    * every prior ``(user, assistant)`` pair in ``chat_history``;
    * the new ``message``.

    It is tokenized with ``to_prompt``, and ``model.generate`` runs on a worker
    thread while this generator reads decoded text from a
    ``TextIteratorStreamer``.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``(user, assistant)`` pairs.
        system_prompt: System message; falsy values produce an empty system turn.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Accepted for interface compatibility but ignored --
            decoding is greedy (``do_sample=False``).
        top_p: Likewise accepted but ignored.

    Yields:
        The cumulative reply text after each streamed chunk.
    """
    system_text = system_prompt.strip() if system_prompt else ""
    conversation = [{"role": "system", "content": system_text}]
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    input_ids = to_prompt(conversation)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens. This drops the beginning of the
        # prompt -- including the system turn -- when the history is too long.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt=True: only newly generated text is streamed back.
    # NOTE(review): timeout=10.0 presumably makes iteration raise if no chunk
    # arrives within 10 s -- confirm against the transformers docs.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,  # greedy decoding
    }
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer stops yielding once generation signals its end, so this
    # join should return promptly; it ensures the worker is reaped.
    worker.join()
# Chat UI driven by ``generate``. The extra inputs are listed in the same
# order as generate's trailing parameters: system_prompt, then max_new_tokens.
chat_interface = gr.ChatInterface(
    fn=generate,
    stop_btn=None,
    examples=[
        ["How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?"],
        ["The eighth-grade class held a bake-off. Kelsie made two times more cookies than Josh. Josh made one-fourth the number of cookies that Suzanne made. If Suzanne made 36 cookies, how many did Kelsie make?"],
        ["Read the following web search snippets carefully and then answer the question below:\nWashington state remains near the top of the list for the most expensive average. According to the AAA, the current average price for a gallon of gas in Washington state is $5.01.\nToday's average price of gas in the U.S. is $3.82 per gallon, unchanged from yesterday, down $0.01 from last week and down $0.02 from last month.\n\nAnswer the following question:\n\nHow does the gas price in Washington compare to the national average? and what is the exact difference?"],
        ["The ages of New Havens residents are 25.4% under the age of 18, 16.4% from 18 to 24, 31.2% from 25 to 44, 16.7% from 45 to 64, and 10.2% who were 65 years of age or older. The median age is 29 years, which is significantly lower than the national average. There are 91.8 males per 100 females. For every 100 females age 18 and over, there are 87.6 males.\n\nWhich gender group is larger: females or males?"],
    ],
    additional_inputs=[
        # Editable system prompt, pre-filled with the default system_message.
        gr.Textbox(label="System prompt", lines=6, value=system_message),
        # Integer-stepped cap on generated tokens.
        gr.Slider(
            label="Max new tokens",
            value=DEFAULT_MAX_NEW_TOKENS,
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
        ),
    ],
)
# Page layout: the Markdown description on top, then the chat UI built above.
# css="style.css" presumably points at a stylesheet shipped with the Space.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    # Bound the request queue to 20 entries, then start the server.
    demo.queue(max_size=20)
    demo.launch()