"""Gradio chat demo for TransGPT.

Loads a causal language model (optionally with a LoRA adapter applied on
top), then serves a simple multi-turn chat UI built with Gradio Blocks.
"""
import argparse
import os

import gradio as gr
import mdtex2html
from gradio.themes.utils import colors, fonts, sizes
import torch
from peft import PeftModel
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    BloomForCausalLM,
    BloomTokenizerFast,
    LlamaTokenizer,
    LlamaForCausalLM,
    GenerationConfig,
)

# Maps the --model_type CLI value to its (model class, tokenizer class) pair.
MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}


class OpenGVLab(gr.themes.base.Base):
    """Gradio theme: blue/sky/gray hues, Noto Sans text, light page background."""

    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override the page background on top of the base theme values.
        super().set(
            body_background_fill="*neutral_50",
        )


gvlabtheme = OpenGVLab(
    primary_hue=colors.blue,
    secondary_hue=colors.sky,
    neutral_hue=colors.gray,
    spacing_size=sizes.spacing_md,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_md,
)


def main() -> None:
    """Parse CLI arguments, load the model and launch the Gradio chat UI.

    Command-line options:
        --model_type: key into ``MODEL_CLASSES`` (default ``"llama"``).
        --base_model: path or name of the base model checkpoint.
        --lora_model: optional LoRA adapter; empty string means base model only.
        --tokenizer_path: tokenizer location; defaults to ``--base_model``.
        --gpus: value written to ``CUDA_VISIBLE_DEVICES``.
        --only_cpu: clear ``CUDA_VISIBLE_DEVICES`` so inference runs on CPU.
        --resize_emb: resize the model embeddings to the tokenizer length.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', default="llama", type=str)
    parser.add_argument('--base_model', default=r"/data/wangpeng/JiaotongGPT-main/merged-sft-no-1ep", type=str)
    parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
    parser.add_argument('--tokenizer_path', default=None, type=str)
    parser.add_argument('--gpus', default="0", type=str)
    parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    args = parser.parse_args()

    if args.only_cpu:
        args.gpus = ""
    # Must be set before the first CUDA query below so it takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    def postprocess(self, y):
        """Convert each (message, response) pair in ``y`` to HTML in place."""
        if y is None:
            return []
        for idx, (message, response) in enumerate(y):
            y[idx] = (
                None if message is None else mdtex2html.convert(message),
                None if response is None else mdtex2html.convert(response),
            )
        return y

    # Replace the Chatbot's postprocess so markdown/TeX is rendered via mdtex2html.
    gr.Chatbot.postprocess = postprocess

    load_type = torch.float16
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')

    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    base_model = model_class.from_pretrained(
        args.base_model,
        load_in_8bit=False,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
        device_map='auto',
        trust_remote_code=True,
    )

    if args.resize_emb:
        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        print(f"Vocab of the base model: {model_vocab_size}")
        print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            print("Resize model embeddings to fit tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)

    if args.lora_model:
        model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
        print("loaded lora model")
    else:
        model = base_model

    if device == torch.device('cpu'):
        # The weights were loaded as float16; cast to float32 when running on CPU.
        model.float()

    model.eval()

    def reset_user_input():
        """Clear the input textbox."""
        return gr.update(value='')

    def reset_state():
        """Return empty values for the chatbot display and the history state."""
        return [], []

    def generate_prompt(instruction):
        """Wrap ``instruction`` in the instruction/response prompt template."""
        # NOTE(review): the missing space after "transportation." is kept as
        # found; presumably it matches the prompt used in training - confirm
        # before "fixing" it.
        return (
            "You are TransGPT, a specialist in the field of transportation."
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n### Response: "
        )

    def predict(
            user_message,
            chatbot,
            history,
            max_new_tokens=128,
            top_p=0.75,
            temperature=0.1,
            top_k=40,
            num_beams=4,
            repetition_penalty=1.0,
            max_memory=256,
            **kwargs,
    ):
        """Generate a reply to ``user_message`` and update the chat state.

        Args:
            user_message: text of the current user turn.
            chatbot: list of (message, response) pairs shown in the UI.
            history: list of earlier (message, response) pairs, or a falsy value.
            max_new_tokens: forwarded to ``model.generate``.
            top_p, temperature, top_k, num_beams: forwarded to ``GenerationConfig``.
            repetition_penalty: cast to float and forwarded to ``model.generate``.
            max_memory: maximum number of trailing characters of the combined
                dialogue text that is kept for the prompt.
            **kwargs: extra ``GenerationConfig`` fields.

        Returns:
            The updated ``(chatbot, history)`` pair.
        """
        chatbot.append((user_message, ""))
        history = history or []

        instruction = user_message
        if history:
            past_turns = "".join(
                "### Instruction:\n" + turn[0] + "\n\n" + "### Response: " + turn[1] + "\n\n"
                for turn in history
            )
            instruction = past_turns + "### Instruction:\n" + user_message
            # generate_prompt() adds the first instruction header itself.
            instruction = instruction[len("### Instruction:\n"):]
        if len(instruction) > max_memory:
            # Keep only the most recent characters of the dialogue.
            instruction = instruction[-max_memory:]

        prompt = generate_prompt(instruction)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        # NOTE(review): do_sample is never set here, so with num_beams=4 the
        # temperature/top_p/top_k values presumably have no effect - confirm
        # whether sampling was intended.
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=False,
                max_new_tokens=max_new_tokens,
                repetition_penalty=float(repetition_penalty),
            )
        sequence = generation_output.sequences[0]
        output = tokenizer.decode(sequence, skip_special_tokens=True)
        # The decoded text includes the prompt; keep what follows the last marker.
        output = output.split("### Response:")[-1].strip()
        history.append((user_message, output))
        chatbot[-1] = (user_message, output)
        return chatbot, history

    title = "\n\nWelcome to TransGPT!"

    # `visibility: hidden` hides the footer (`none` is not a valid value for it).
    page_css = (
        "#chatbot {overflow:auto; height:500px;} "
        "#InputVideo {overflow:visible; height:320px;} "
        "footer {visibility: hidden}"
    )
    with gr.Blocks(title="DUOMO TransGPT!", theme=gvlabtheme, css=page_css) as demo:
        gr.Markdown(title)
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                        container=False)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("Clear History")
                max_length = gr.Slider(
                    0, 4096, value=128, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
                temperature = gr.Slider(
                    0, 1, value=0.7, step=0.01, label="Temperature", interactive=True)

        history = gr.State([])  # (message, bot_message)

        submitBtn.click(predict, [user_input, chatbot, history, max_length, top_p, temperature],
                        [chatbot, history], show_progress=True)
        submitBtn.click(reset_user_input, [], [user_input])
        emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

    # NOTE(review): share=True together with server_name='0.0.0.0' makes the
    # demo reachable from outside this machine - confirm this is intended.
    demo.queue().launch(share=True, inbrowser=True, server_name='0.0.0.0', server_port=8080)


if __name__ == '__main__':
    main()