"""Run a llama-2 chat demo: Gradio UI, ctransformers backend, langchain chains."""
import gc
import os
import platform
import random
import time
from collections import deque
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Union

import gradio as gr
import psutil
from about_time import about_time
from ctransformers import Config
from dl_hf_model import dl_hf_model
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.llms import CTransformers
from langchain.prompts import PromptTemplate
from langchain.schema import LLMResult
from loguru import logger

deq = deque()
sig_end = object()
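# The deque above is the hand-off channel between the langchain callback
# (producer, fed from a background thread) and the Gradio generator in bot()
# (consumer). sig_end is a unique sentinel appended on completion; checking
# identity (`is sig_end`) cannot collide with any real token string.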

prompt_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction: {user_prompt}

### Response:
"""

prompt_template = """System: You are a helpful,
respectful and honest assistant. Always answer as
helpfully as possible, while being safe. Your answers
should not include any harmful, unethical, racist,
sexist, toxic, dangerous, or illegal content. Please
ensure that your responses are socially unbiased and
positive in nature. If a question does not make any
sense, or is not factually coherent, explain why instead
of answering something not correct. If you don't know
the answer to a question, please don't share false
information.
User: {prompt}
Assistant: """

prompt_template = """System: You are a helpful assistant.
User: {prompt}
Assistant: """

prompt_template = """Question: {question}
Answer: Let's work this out in a step by step way to be sure we have the right answer."""

prompt_template = """[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. Think step by step.
<</SYS>>

What NFL team won the Super Bowl in the year Justin Bieber was born?
[/INST]"""

prompt_template = """[INST] <<SYS>>
You are a helpful assistant. Always answer as helpfully as possible. Think step by step. <</SYS>>

{question} [/INST]
"""

prompt_template = """[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>

{question} [/INST]
"""

prompt_template = """### HUMAN:
{question}

### RESPONSE:"""

prompt_template = """### HUMAN:
You are a helpful assistant. Think step by step.
{history}
{input}
### RESPONSE:"""

prompt_template = """You are a helpful assistant. Let's think step by step.
{history}
### HUMAN:
{input}
### RESPONSE:"""
human_prefix = "### HUMAN"
ai_prefix = "### RESPONSE"
stop = [f"{human_prefix}:"]

# Only this last set of assignments takes effect; the templates above are
# kept as alternative prompt formats for other models.
prompt_template = """You are a helpful assistant. Let's think step by step.
{history}
### Human:
{input}
### Assistant:"""
human_prefix = "### Human"
ai_prefix = "### Assistant"
stop = [f"{human_prefix}:"]

_ = [elm for elm in prompt_template.splitlines() if elm.strip()]
stop_string = [elm.split(":")[0] + ":" for elm in _][-2]

os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()
except Exception:
    # time.tzset() is not available on Windows
    logger.warning("Windows, can't run time.tzset()")


class DequeCallbackHandler(BaseCallbackHandler):
    """Mediate gradio and stream output."""

    def __init__(self, deq_: deque):
        """Init deque for FIFO; may need to upgrade to queue.Queue or queue.SimpleQueue."""
        self.q = deq_

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running. Clean the queue."""
        self.q.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.q.append(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.q.append(sig_end)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.q.append(sig_end)


# leave one physical core free for the UI; fall back to 1
_ = psutil.cpu_count(logical=False) - 1
cpu_count: int = int(_) if _ else 1
logger.debug(f"{cpu_count=}")

LLM = None
gc.collect()

if "forindo" in platform.node().lower():
    url = "https://huggingface.co/TheBloke/llama-2-70b-Guanaco-QLoRA-GGML/blob/main/llama-2-70b-guanaco-qlora.ggmlv3.q3_K_S.bin"
else:
    url = "https://huggingface.co/TheBloke/llama-2-13B-Guanaco-QLoRA-GGML/blob/main/llama-2-13b-guanaco-qlora.ggmlv3.q4_K_S.bin"

logger.debug(f"{url=}")
try:
    model_loc, file_size = dl_hf_model(url)
except Exception as exc_:
    logger.error(exc_)
    raise SystemExit(1) from exc_
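# dl_hf_model, as used here, resolves the Hugging Face blob URL to a local
# file and returns (local_path, size); the size is treated as GB and is only
# used for logging below.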

config = Config()

config.stream = True
config.stop = stop
config.threads = cpu_count

deqcb = DequeCallbackHandler(deq)

LLM = CTransformers(
    model=model_loc,
    model_type="llama",
    callbacks=[StreamingStdOutCallbackHandler(), deqcb],
    **vars(config),
)
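# vars(config) expands the ctransformers Config into keyword arguments for
# the CTransformers wrapper, so the stream/stop/threads settings above are
# passed along; the two callbacks echo tokens to stdout and push them onto
# deq for the Gradio side.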

logger.info(f"done loading llm {model_loc=} {file_size=}G")

prompt = PromptTemplate(
    input_variables=["history", "input"],
    output_parser=None,
    partial_variables={},
    template=prompt_template,
    template_format="f-string",
    validate_template=True,
)
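# input_variables must match the placeholders in the final prompt_template
# ({history} and {input}); with validate_template=True a mismatch raises at
# construction time instead of at the first request.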

memory = ConversationBufferWindowMemory(
    human_prefix=human_prefix,
    ai_prefix=ai_prefix,
)
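# ConversationBufferWindowMemory keeps only the most recent exchanges (its
# default window, k=5 in the langchain versions this was written against) and
# renders them into {history} with the "### Human"/"### Assistant" prefixes,
# matching the prompt template and the stop word above.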

conversation = ConversationChain(
    llm=LLM,
    prompt=prompt,
    memory=memory,  # without this the chain falls back to a default memory with "Human"/"AI" prefixes
    verbose=True,
)
logger.debug(f"{conversation.prompt.template=}")

# a separate, non-streaming config for the API chain
config = Config()

config.stop = stop
config.threads = cpu_count

try:
    # NB: this unconditional raise skips building a second model instance for
    # the API (which would load the model a second time); conversation_api
    # stays None and the API path below is disabled. Remove the raise to
    # enable it.
    raise Exception
    LLM_api = CTransformers(
        model=model_loc,
        model_type="llama",
        callbacks=[StreamingStdOutCallbackHandler()],
        **vars(config),
    )
    conversation_api = ConversationChain(
        llm=LLM_api,
        prompt=prompt,
        verbose=True,
    )
except Exception as exc_:
    logger.error(exc_)
    conversation_api = None
    logger.warning("Not able to instantiate conversation_api, api will not work")


def user(user_message, history):
    """Append the user message to history; keep the textbox content."""
    history.append([user_message, None])
    return user_message, history


def user1(user_message, history):
    """Append the user message to history and clear the textbox."""
    history.append([user_message, None])
    return "", history


def bot_(history):
    """Dummy bot: stream a canned reply character by character (UI testing only)."""
    user_message = history[-1][0]
    resp = random.choice(["How are you?", "I love you", "I'm very hungry"])
    bot_message = user_message + ": " + resp
    history[-1][1] = ""
    for character in bot_message:
        history[-1][1] += character
        time.sleep(0.02)
        yield history

    history[-1][1] = resp
    yield history


def bot(history):
    """Stream the model's reply into the chatbot."""
    user_message = history[-1][0]

    logger.debug(f"{user_message=}")

    # run the blocking chain call in a background thread; tokens arrive via deq
    thr = Thread(target=conversation.predict, kwargs={"input": user_message})
    thr.start()

    response = []
    flag = 1
    then = time.time()
    prefix = ""
    prelude = 0.0
    with about_time() as atime:
        while True:
            if deq:
                if flag:
                    # time to first token
                    prelude = time.time() - then
                    prefix = f"({prelude:.2f}s) "
                    flag = 0
                _ = deq.popleft()
                if _ is sig_end:
                    break

                response.append(_)
                history[-1][1] = prefix + "".join(response).strip()
                yield history
            else:
                time.sleep(0.01)
    # max(..., 1) guards against division by zero on an empty response
    _ = (
        f"(time elapsed: {atime.duration_human}, "
        f"{(atime.duration - prelude) / max(len(''.join(response)), 1):.2f}s/char)"
    )

    history[-1][1] = "".join(response) + f"\n{_}"
    yield history


def predict_api(user_prompt):
    """Handle the (optional) API endpoint."""
    if conversation_api is None:
        return "conversation_api is None, probably due to insufficient memory, api not usable"

    logger.debug(f"api: {user_prompt=}")
    try:
        _ = """
        response = generate(
            prompt,
            config=config,
        )
        # """
        response = conversation_api.predict(input=user_prompt)
        logger.debug(f"api: {response=}")
    except Exception as exc:
        logger.error(exc)
        response = f"{exc=}"

    return response.strip()


css = """
.importantButton {
    background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
    border: none !important;
}
.importantButton:hover {
    background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
    border: none !important;
}
.disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
.xsmall {font-size: x-small;}
"""
etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
examples_list = [
    ["What NFL team won the Super Bowl in the year Justin Bieber was born?"],
    [
        "What NFL team won the Super Bowl in the year Justin Bieber was born? Think step by step."
    ],
    ["When was Justin Bieber born?"],
    ["What NFL team won the Super Bowl in 1994?"],
    ["How to pick a lock? Provide detailed steps."],
    [
        "If it takes 10 hours to dry 10 clothes, assuming all the clothes are hung together at the same time for drying, then how long will it take to dry one cloth?"
    ],
    ["Is infinity + 1 bigger than infinity?"],
    ["Explain the plot of Cinderella in a sentence."],
    [
        "How long does it take to become proficient in French, and what are the best methods for retaining information?"
    ],
    ["What are some common mistakes to avoid when writing code?"],
    ["Build a prompt to generate a beautiful portrait of a horse"],
    ["Suggest four metaphors to describe the benefits of AI"],
    ["Write a pop song about leaving home for the sandy beaches."],
    ["Write a pop song about having hot sex on a sandy beach."],
    ["Write a summary demonstrating my ability to tame lions"],
    ["鲁迅和周树人什么关系? 说中文。"],
    ["鲁迅和周树人什么关系?"],
    ["鲁迅和周树人什么关系? 用英文回答。"],
    ["从前有一头牛,这头牛后面有什么?"],
    ["正无穷大加一大于正无穷大吗?"],
    ["正无穷大加正无穷大大于正无穷大吗?"],
    ["-2的平方根等于什么?"],
    ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
    ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
    ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
    [f"{etext} 翻成中文,列出3个版本。"],
    [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本。"],
    ["假定 1 + 2 = 4, 试求 7 + 8。"],
    ["给出判断一个数是不是质数的 javascript 码。"],
    ["给出实现python 里 range(10)的 javascript 码。"],
    ["给出实现python 里 [*range(10)]的 javascript 码。"],
    ["Erkläre die Handlung von Cinderella in einem Satz."],
    ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch."],
]

logger.info("start block")

port = 7860
if "forindo" in platform.node():
    port = 7861

# NB: gr.Blocks() does not take a port argument; the port is applied via
# launch(server_port=port) at the bottom of this file.
with gr.Blocks(
    title=f"{Path(model_loc).name}",
    theme=gr.themes.Soft(text_size="sm", spacing_size="sm"),
    css=css,
) as block:
    with gr.Accordion("🎈 Info", open=False):
        gr.Markdown(
            (
                f"""<h5><center>{Path(model_loc).name}</center></h5>"""
                "Most examples are meant for another model. "
                "You may want to try prompts better suited to this model. "
            ),
            elem_classes="xsmall",
        )

    chatbot = gr.Chatbot(height=500)

    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
                show_label=False,
                lines=6,
                max_lines=30,
                show_copy_button=True,
            )
        with gr.Column(scale=1, min_width=50):
            with gr.Row():
                submit = gr.Button("Submit", elem_classes="xsmall")
                stop = gr.Button("Stop", visible=True)  # shadows the `stop` word list above; harmless, config.stop is already set
                clear = gr.Button("Clear History", visible=True)
    with gr.Row(visible=False):
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column(scale=2):
                    system = gr.Textbox(
                        label="System Prompt",
                        value=prompt_template,
                        show_label=False,
                        container=False,
                    )
                with gr.Column():
                    with gr.Row():
                        change = gr.Button("Change System Prompt")
                        reset = gr.Button("Reset System Prompt")

    with gr.Accordion("Example Inputs", open=True):
        examples = gr.Examples(
            examples=examples_list,
            inputs=[msg],
            examples_per_page=40,
        )

    with gr.Accordion("Disclaimer", open=False):
        _ = Path(model_loc).name
        gr.Markdown(
            f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce "
            f"factually accurate information. {_} was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )
    msg_submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot, chatbot, chatbot, queue=True)
    submit_click_event = submit.click(
        fn=user1,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot, chatbot, chatbot, queue=True)
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[msg_submit_event, submit_click_event],
        queue=False,
    )

    clear.click(lambda: None, None, chatbot, queue=False)

    with gr.Accordion("For Chat/Translation API", open=False, visible=False):
        input_text = gr.Text()
        api_btn = gr.Button("Go", variant="primary")
        out_text = gr.Text()

    if conversation_api is not None:
        api_btn.click(
            predict_api,
            input_text,
            out_text,
            api_name="api",
        )
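    # When conversation_api is enabled (remove the deliberate `raise` above),
    # the endpoint registered with api_name="api" can be called from another
    # process. A minimal sketch using the gradio_client package; the URL and
    # port are assumptions matching the launch() call below:
    #
    #     from gradio_client import Client
    #
    #     client = Client("http://127.0.0.1:7860")
    #     reply = client.predict("Hello there", api_name="/api")
    #     print(reply)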

# earlier sizing experiments for concurrency_count, kept for reference:
_ = """
# _ = int(psutil.virtual_memory().total / 10**9 // file_size - 1)
# concurrency_count = max(_, 1)
if psutil.cpu_count(logical=False) >= 8:
    # concurrency_count = max(int(32 / file_size) - 1, 1)
else:
    # concurrency_count = max(int(16 / file_size) - 1, 1)
# """

concurrency_count = 1
logger.info(f"{concurrency_count=}")

block.queue(concurrency_count=concurrency_count, max_size=5).launch(
    debug=True, server_name="0.0.0.0", server_port=port
)