kimmeoungjun committed
Commit 277c222
1 Parent(s): abd383d

Update app.py

Files changed (1)
1. app.py +15 -38
app.py CHANGED
@@ -1,50 +1,27 @@
  import torch
  import gradio as gr

- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from peft import PeftModel, PeftConfig
+ from transformers import AutoModelForCausalLM, AutoTokenizer

  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
  peft_model_id = "kimmeoungjun/qlora-koalpaca"
  config = PeftConfig.from_pretrained(peft_model_id)
- model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
+ model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
  model = PeftModel.from_pretrained(model, peft_model_id).to(device)
  tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

- def my_split(s, seps):
-     res = [s]
-     for sep in seps:
-         s, res = res, []
-         for seq in s:
-             res += seq.split(sep)
-     return res
-
- def chat_base(input):
-     p = input
-     input_ids = tokenizer(p, return_tensors="pt").input_ids.to(device)
-     gen_tokens = model.generate(input_ids, do_sample=True, early_stopping=True, eos_token_id=2,)
-     gen_text = tokenizer.batch_decode(gen_tokens)[0]
-     # print(gen_text)
-     result = gen_text[len(p):]
-     # print(">", result)
-     result = my_split(result, [']', '\n'])[1]
-     # print(">>", result)
-     # print(">>>", result)
-     return result
-
- def chat(message):
-     history = gr.get_state() or []
-     print(history)
-     response = chat_base(message)
-     history.append((message, response))
-     gr.set_state(history)
-     html = "<div class='chatbot'>"
-     for user_msg, resp_msg in history:
-         html += f"<div class='user_msg'>{user_msg}</div>"
-         html += f"<div class='resp_msg'>{resp_msg}</div>"
-     html += "</div>"
-     return response
-
- iface = gr.Interface(chat_base, gr.inputs.Textbox(label="물어보세요"), "text", allow_screenshot=False, allow_flagging=False)
- iface.launch()
+ def generate(q):
+     inputs = tokenizer(f"### 질문: {q}\n\n### 답변:", return_tensors='pt', return_token_type_ids=False)
+     outputs = model.generate(
+         **{k: v.to(device) for k, v in inputs.items()},
+         max_new_tokens=256,
+         do_sample=True,
+         eos_token_id=2,
+     )
+     result = tokenizer.decode(outputs[0])
+     answer_idx = result.find("### 답변:")
+     answer = result[answer_idx + 7:].strip()
+     return answer
+
+ gr.Interface(generate, "text", "text").launch(share=True)
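
For reference, a quick model-free way to sanity-check the answer-extraction step introduced in this commit: the hard-coded + 7 offset in generate() equals the length of the "### 답변:" marker (7 characters), so slicing there drops the marker and keeps only the answer text. The decoded string below is a made-up example, not real model output.

# Standalone sketch of the slicing logic from generate(); the decoded text is illustrative only.
decoded = "### 질문: test question\n\n### 답변: test answer"
marker = "### 답변:"
assert len(marker) == 7  # matches the hard-coded +7 offset used in the commit
answer = decoded[decoded.find(marker) + len(marker):].strip()
print(answer)  # -> test answer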