Spaces:
Paused
Paused
from Prompter import Prompter | |
from Callback import Stream, Iteratorize | |
import os | |
import sys | |
import gradio as gr | |
import torch | |
import transformers | |
from peft import PeftModel | |
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer | |
import pandas as pd | |
import numpy as np | |
if torch.cuda.is_available(): | |
device = "cuda" | |
else: | |
device = "cpu" | |
try: | |
if torch.backends.mps.is_available(): | |
device = "mps" | |
except: # noqa: E722 | |
pass | |
base_model = "openthaigpt/openthaigpt-1.0.0-beta-7b-chat-ckpt-hf" | |
load_8bit = True | |
# lora_weights = "PLatonG/openthaigpt-1.0.0-beta-7b-expert-recommendations" | |
lora_weights = "PLatonG/openthaigpt-1.0.0-beta-7b-expert-recommendations" | |
prompter = Prompter("alpaca") | |
tokenizer = LlamaTokenizer.from_pretrained(base_model) | |
model = LlamaForCausalLM.from_pretrained( | |
base_model, | |
load_in_8bit=load_8bit, | |
torch_dtype=torch.float16, | |
device_map="auto", | |
offload_folder = "./offload" | |
) | |
model = PeftModel.from_pretrained( | |
model, | |
lora_weights, | |
torch_dtype=torch.float16, | |
offload_folder = "./offload" | |
) | |
# unwind broken decapoda-research config | |
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk | |
model.config.bos_token_id = 1 | |
model.config.eos_token_id = 2 | |
if not load_8bit: | |
model.half() # seems to fix bugs for some users. | |
model.eval() | |
if torch.__version__ >= "2" and sys.platform != "win32": | |
model = torch.compile(model) | |
def evaluate( | |
instruction, | |
input=None, | |
stream_output=False, | |
**kwargs, | |
): | |
temperature=0.5 | |
top_p=0.75 | |
top_k=40 | |
num_beams=4 | |
max_new_tokens=380 | |
prompt = prompter.generate_prompt(instruction, input) | |
inputs = tokenizer(prompt, return_tensors="pt") | |
input_ids = inputs["input_ids"].to(device) | |
generation_config = GenerationConfig( | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
num_beams=num_beams, | |
**kwargs, | |
) | |
generate_params = { | |
"input_ids": input_ids, | |
"generation_config": generation_config, | |
"return_dict_in_generate": True, | |
"output_scores": True, | |
"max_new_tokens": max_new_tokens, | |
} | |
if stream_output: | |
# Stream the reply 1 token at a time. | |
# This is based on the trick of using 'stopping_criteria' to create an iterator, | |
# from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243. | |
def generate_with_callback(callback=None, **kwargs): | |
kwargs.setdefault( | |
"stopping_criteria", transformers.StoppingCriteriaList() | |
) | |
kwargs["stopping_criteria"].append( | |
Stream(callback_func=callback) | |
) | |
with torch.no_grad(): | |
model.generate(**kwargs) | |
def generate_with_streaming(**kwargs): | |
return Iteratorize( | |
generate_with_callback, kwargs, callback=None | |
) | |
with generate_with_streaming(**generate_params) as generator: | |
for output in generator: | |
# new_tokens = len(output) - len(input_ids[0]) | |
decoded_output = tokenizer.decode(output) | |
if output[-1] in [tokenizer.eos_token_id]: | |
break | |
yield prompter.get_response(decoded_output) | |
return # early return for stream_output | |
# Without streaming | |
with torch.no_grad(): | |
generation_output = model.generate( | |
input_ids=input_ids, | |
generation_config=generation_config, | |
return_dict_in_generate=True, | |
output_scores=True, | |
max_new_tokens=max_new_tokens, | |
) | |
s = generation_output.sequences[0] | |
output = tokenizer.decode(s) | |
yield prompter.get_response(output) | |
# From SMOTE with 4 neightbor | |
fourNSMOTE = pd.read_csv("FILTER_GREATERTHANTHREE_FROM_SHEETS_SMOTE_train.csv") | |
with gr.Blocks() as demo: | |
birth_year = gr.components.Number(minimum = 2536, maximum = 2557, value= 2545, | |
label="ปีเกิด", | |
info="ต่ำสุด : 2536 สูงสุด : 2557") | |
nationality_name = gr.components.Dropdown(choices=fourNSMOTE.NATIONALITY_NAME.unique().tolist(), | |
label="สัญชาติ", | |
value = fourNSMOTE.NATIONALITY_NAME.unique().tolist()[0]) | |
religion_name = gr.components.Dropdown(choices=fourNSMOTE.RELIGION_NAME.unique().tolist(), | |
label="ศาสนา", | |
value = fourNSMOTE.RELIGION_NAME.unique().tolist()[0]) | |
sex = gr.components.Dropdown(choices=fourNSMOTE.JVN_SEX.unique().tolist(), | |
label="เพศ", | |
value = fourNSMOTE.JVN_SEX.unique().tolist()[0]) | |
inform_status = gr.components.Dropdown(choices=fourNSMOTE.INFORM_STATUS_TXT.unique().tolist(), | |
label="เหตุที่นำมาสู่การดำเนินคดี", | |
value = fourNSMOTE.INFORM_STATUS_TXT.unique().tolist()[0]) | |
age = gr.components.Number(minimum = 10, maximum = 19, value= 17, | |
label="อายุตอนกระทำผิด", | |
info="ต่ำสุด : 10 ปี สูงสุด : 19") | |
offense_name = gr.components.Dropdown(choices=fourNSMOTE.OFFENSE_NAME.unique().tolist(), | |
label="คดีที่กระทำผิด", | |
value = fourNSMOTE.OFFENSE_NAME.unique().tolist()[0]) | |
ref_value = fourNSMOTE.OFFENSE_NAME.unique().tolist()[0] | |
allegation_name = gr.components.Dropdown(choices=fourNSMOTE.ALLEGATION_NAME.unique().tolist(), label="ชื่อของข้อกล่าวหา", | |
value = fourNSMOTE.query("OFFENSE_NAME == @ref_value")["ALLEGATION_NAME"].unique().tolist()[0]) | |
allegation_desc = gr.components.Dropdown(choices=fourNSMOTE.ALLEGATION_DESC.unique().tolist(), label="รายละเอียดของข้อกล่าวหา", | |
value = fourNSMOTE.query("OFFENSE_NAME == @ref_value")["ALLEGATION_DESC"].unique().tolist()[0]) | |
def update_dropDown(value): | |
query_state = fourNSMOTE.query("OFFENSE_NAME == @value") | |
allegation_name = gr.components.Dropdown(choices=query_state["ALLEGATION_NAME"].unique().tolist()) | |
allegation_desc = gr.components.Dropdown(choices=query_state["ALLEGATION_DESC"].unique().tolist()) | |
return allegation_name, allegation_desc | |
offense_name.change(fn=update_dropDown, inputs=offense_name, outputs=[allegation_name, allegation_desc]) | |
rn1 = gr.components.Radio(choices=["ถูก", "ผิด"], | |
label="ปรากฎลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่", | |
value="ถูก") | |
rn2 = gr.components.Radio(choices=["ถูก", "ผิด"], | |
label="ปรากฎประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย", | |
value = "ถูก") | |
rn3 = gr.components.Radio(choices=["ถูก", "ผิด"], | |
label="ปรากฎประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว", | |
value = "ถูก") | |
education = gr.components.Dropdown(choices=fourNSMOTE.RN3_14_HIS_EDU_FLAG.unique().tolist(), | |
label="สถาณะการศึกษา", | |
value = fourNSMOTE.RN3_14_HIS_EDU_FLAG.unique().tolist()[0]) | |
occupation = gr.components.Dropdown(choices=fourNSMOTE.RN3_19_OCCUPATION_STATUS.unique().tolist(), | |
label="สถาณะการประกอบอาชีพ", | |
value = fourNSMOTE.RN3_19_OCCUPATION_STATUS.unique().tolist()[0]) | |
province = gr.components.Dropdown(choices=fourNSMOTE.PROVINCE_NAME.unique().tolist(), | |
label="จังหวัดที่กระทำผิด", | |
value = fourNSMOTE.PROVINCE_NAME.unique().tolist()[0]) | |
def generate_input(birth_year, nationality_name, religion_name, sex, | |
inform_status, age, offense_name, allegation_name, | |
allegation_desc, rn1, rn2, rn3, education, occupation, province): | |
birth_year = f"เกิดเมื่อปี พ.ศ. {int(birth_year)}" | |
if int(age) >= 10 or int(age) <=15: | |
age = f"มีอายุอยู่ในช่วง 10 ถึง 15 ปี" | |
elif int(age) >=16 or int(age) <= 20: | |
age = f"มีอายุอยู่ในช่วง 16 ถึง 20 ปี" | |
elif int(age) >=21 or int(age) <= 25: | |
age = f"มีอายุอยู่ในช่วง 21 ถึง 25 ปี" | |
elif int(age) >=26: | |
age = f"มีอายุอยู่ในช่วง 26 ปีขึ้นไป" | |
if rn1 == "ถูก": | |
rn1 = "มีลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่" | |
else: | |
rn1 = "ไม่มีลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่" | |
if rn2 == "ถูก": | |
rn2 = "มีประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย" | |
else: | |
rn2 = "ไม่มีประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย" | |
if rn3 == "ถูก": | |
rn3 = "มีประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว" | |
else: | |
rn3 = "ไม่มีประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว" | |
instruciton = "จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้" | |
input = f"{birth_year} {nationality_name} {religion_name} {sex} {inform_status} {age} {offense_name} {allegation_name} {allegation_desc} {rn1} {rn2} {rn3} {education} {occupation} {province}" | |
return input | |
def generate_output(instruction, input): | |
return input | |
def generate_input2(*values): | |
return "คำสั่ง : จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้ " + " ".join(str(value) for value in values) | |
instruction = gr.Textbox(label = "คำสั่ง", value="จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้", visible=True, interactive=False) | |
input_compo = gr.Textbox(label = "ข้อมูลเข้า (input)") | |
btn1 = gr.Button("GENERATE INPUT") | |
outputModel = gr.Textbox(label= "ผลลัพธ์ (output)") | |
stream_output = gr.components.Checkbox(label="Stream output") | |
# show input text format for user | |
btn1.click(fn=generate_input, inputs=[birth_year, nationality_name, religion_name, sex, | |
inform_status, age, offense_name, allegation_name, | |
allegation_desc, rn1, rn2, rn3, education, occupation, province], | |
outputs=input_compo) | |
btn2 = gr.Button("GENERATE OUTPUT") | |
btn2.click(fn=evaluate, inputs=[instruction, input_compo, stream_output], outputs=outputModel) | |
# outputChatInterface = gr.ChatInterface(fn=evaluate) | |
# input text format for model | |
# btn.click(fn=generate_text_test2, inputs = [birth_year, nationality_name, religion_name, sex, | |
# inform_status, age, offense_name, allegation_name, | |
# allegation_desc, rn1, rn2, rn3, education, occupation, province], | |
# outputs = input_compo) | |
demo.launch(debug=True, share=True) |