f15hb0wn committed on
Commit
4532822
1 Parent(s): e71f223

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+ import gradio as gr
4
+
5
# Model identifier handed to the tokenizer and model loaders below.
model_id = "witfoo/witq-1.0"

# float16 suits Tesla T4 / V100; use bfloat16 on Ampere or newer GPUs.
dtype = torch.float16

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="auto")

# Fixed text that opens every system message; the chosen instruction is
# appended to it when a request is built.
preamble = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
15
+
16
+
17
+
18
def input_tokens(instruction, prompt):
    """Build the chat-template model input for one instruction/prompt pair.

    The system message is the module-level ``preamble`` followed by a space
    and ``instruction``; the user message is ``prompt``.

    Args:
        instruction: Task instruction text. ``None`` or an empty string
            (for example when no dropdown choice has been made) falls back
            to the bare preamble instead of raising ``TypeError`` during
            string concatenation.
        prompt: User-supplied input text. ``None`` is treated as ``""``.

    Returns:
        The output of ``tokenizer.apply_chat_template`` (requested as
        PyTorch tensors, with the generation prompt appended), moved to
        ``model.device``.
    """
    # Guard against a missing instruction: concatenating None to a str fails.
    system_content = f"{preamble} {instruction}" if instruction else preamble
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": prompt if prompt is not None else ""},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    return inputs
29
+
30
+
31
+
32
def generate_response(instruction, input_text):
    """Generate the model's reply to ``input_text`` under ``instruction``.

    Args:
        instruction: Task instruction forwarded to ``input_tokens``.
        input_text: User input forwarded to ``input_tokens``.

    Returns:
        The newly generated text only: the prompt tokens are sliced off
        and special tokens are skipped while decoding.
    """
    input_ids = input_tokens(instruction, input_text)

    # Stop on the tokenizer's EOS token or the "<|eot_id|>" token. Drop ids
    # that are missing (None), that resolve to the unknown token, or that are
    # duplicates, so generate() never receives an invalid terminator.
    terminators = []
    if tokenizer.eos_token_id is not None:
        terminators.append(tokenizer.eos_token_id)
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if (
        eot_id is not None
        and eot_id != tokenizer.unk_token_id
        and eot_id not in terminators
    ):
        terminators.append(eot_id)

    generation_kwargs = {
        # A single unpadded prompt attends to every position; passing the
        # mask explicitly avoids generate() having to infer it.
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    }
    if terminators:
        generation_kwargs["eos_token_id"] = terminators

    # Set the pad token explicitly so generate() does not have to pick one.
    pad_id = tokenizer.pad_token_id
    if pad_id is None and terminators:
        pad_id = terminators[0]
    if pad_id is not None:
        generation_kwargs["pad_token_id"] = pad_id

    outputs = model.generate(input_ids, **generation_kwargs)

    # Keep only the tokens produced after the prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
51
+
52
def chatbot(instructions, input_text):
    """Return the generated reply for the given instruction and input text.

    Thin wrapper around ``generate_response``; it is the function passed to
    the Gradio interface as ``fn``.
    """
    return generate_response(instructions, input_text)
55
+
56
# Instruction strings offered as choices in the UI dropdown.
trained_instructions = [
    "Answer this question",
    "Create a JSON artifact from the message",
    "Identify this syslog message",
    "Explain this syslog message",
]

# Two inputs (instruction dropdown + free-text box) feed `chatbot`; its
# return value is shown in a single response textbox.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Dropdown(choices=trained_instructions, label="Instruction"),
        gr.Textbox(
            lines=2,
            placeholder="Enter your input here...",
            label="Input Text",
        ),
    ],
    outputs=gr.Textbox(label="Response"),
    title="WitQ Chatbot",
)

# Render the interface inside a Blocks container and start the server.
with gr.Blocks() as app:
    iface.render()

app.launch()