import re

import fire
import gradio as gr
import torch
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer


def make_prompt(references: str = "", consult: str = "") -> str:
    # Build the ChatLaw prompt: an optional references block, then the
    # consultation, then a "Response:" header for the model to complete.
    prompt = "" if references == "" else f"References:\n{references}\n"
    prompt += f"Consult:\n{consult}\nResponse:\n"
    return prompt


def main(
    model: str = "JessyTsu1/ChatLaw-13B",
):
    tokenizer = LlamaTokenizer.from_pretrained(model)
    model = LlamaForCausalLM.from_pretrained(
        model,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token
    model.eval()

    def evaluate(
        references,
        consult,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=128,
        **kwargs,
    ):
        prompt = make_prompt(references, consult)
        inputs = tokenizer(prompt, return_tensors="pt")
        # Move every tensor (input_ids and attention_mask) to the model's
        # device; moving input_ids alone leaves the attention mask on CPU
        # and generate() fails with a device mismatch.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                **inputs,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
                repetition_penalty=1.2,
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        # Greedily capture everything after the "Response:" header; the
        # original non-greedy "([\s\S]+?)" matched only a single character.
        if search_result := re.search(r"Response\s*:\s*([\s\S]+)", output):
            return search_result.group(1)
        return "Error! Maybe response is over length."

    gr.Interface(
        fn=evaluate,
        inputs=[
            gr.components.Textbox(
                lines=4,
                label="References",
                placeholder="Enter your reference material",
            ),
            gr.components.Textbox(
                lines=2,
                label="Consult",
                placeholder='Enter your consultation; prefixing the question with "详细分析:" (detailed analysis:) gives better results.',
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.7, label="Temperature"
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.75, label="Top p"
            ),
            gr.components.Slider(
                minimum=0, maximum=100, step=1, value=40, label="Top k"
            ),
            gr.components.Slider(
                minimum=1, maximum=4, step=1, value=1, label="Beams"
            ),
            gr.components.Slider(
                minimum=1, maximum=1024, step=1, value=1024, label="Max tokens"
            ),
        ],
        outputs=[
            # gr.inputs is the deprecated *input* namespace; the output field
            # should be a gr.components.Textbox as well.
            gr.components.Textbox(
                lines=8,
                label="Response",
            )
        ],
        title="ChatLaw Academic Demo",
        description="",
    ).queue().launch(server_name="0.0.0.0", server_port=1234)


if __name__ == "__main__":
    fire.Fire(main)
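
# A minimal usage sketch (the file name chatlaw_demo.py is an assumption, not
# taken from the source). fire.Fire(main) exposes main()'s keyword arguments
# as CLI flags, so the checkpoint can be overridden at launch time; the Gradio
# app then listens on http://0.0.0.0:1234 as configured above.
#
#   python chatlaw_demo.py --model JessyTsu1/ChatLaw-13B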