Kirill Gelvan committed
Commit 26ff0b0
1 Parent(s): 67f50e8

add inference code to readme

Files changed (1): README.md (+73, -0)
README.md CHANGED
@@ -3,5 +3,78 @@
language: ru
tags:
- conversational
---

### Description

### Inference

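Before running the chat loop, the tokenizer and model need to be loaded. A minimal setup sketch, assuming `torch` and `transformers` are installed; the checkpoint id below is a placeholder to be replaced with this repo's model id:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder: substitute this repo's model id.
checkpoint = "model-namespace/model-name"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.eval()  # inference only
```

The interactive loop below then alternates between appending a human turn to the chat history and asking the model to generate a reply.
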
```python

def get_length_param(text: str, tokenizer) -> str:
    """Map a text to its length-bucket flag: '1' (<=15 tokens), '2' (<=50), '3' (<=256), '-' (longer)."""
    tokens_count = len(tokenizer.encode(text))
    if tokens_count <= 15:
        len_param = '1'
    elif tokens_count <= 50:
        len_param = '2'
    elif tokens_count <= 256:
        len_param = '3'
    else:
        len_param = '-'
    return len_param


def get_user_param(text: dict, machine_name_in_chat: str) -> str:
    """Map a chat message dict to its speaker flag: '1' for the model, '0' for the human."""
    if text['from'] == machine_name_in_chat:
        return '1'  # machine
    else:
        return '0'  # human


chat_history_ids = torch.zeros((1, 0), dtype=torch.long)  # empty history; long dtype matches the tokenizer's output

while True:

    next_who = input("Who's phrase?\t")  # "H"/"Human" or "G"/"GPT"

    # In case of Human
    if next_who == "H" or next_who == "Human":
        input_user = input("===> Human: ")

        # encode the new user input with its control prefix and return a PyTorch tensor
        new_user_input_ids = tokenizer.encode(f"|0|{get_length_param(input_user, tokenizer)}|" + input_user + tokenizer.eos_token, return_tensors="pt")
        # append the new user input tokens to the chat history
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    if next_who == "G" or next_who == "GPT":

        next_len = input("Phrase len? 1/2/3/-\t")  # expected response length bucket
        # encode the model's control prefix and return a PyTorch tensor
        new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
        # append the control prefix tokens to the chat history
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

        # print(tokenizer.decode(chat_history_ids[-1]))  # uncomment to see the full model input

        # save the prompt length so only the newly generated tokens are printed
        input_len = chat_history_ids.shape[-1]
        # generate a response; the parameters are explained at hf.co/blog/how-to-generate
        chat_history_ids = model.generate(
            chat_history_ids,
            num_return_sequences=1,  # increase for more variants, but then index the output with [i]
            max_length=512,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=0.6,  # lower values are closer to greedy; pass do_sample=False for fully greedy decoding
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        # pretty-print only the bot's newly generated tokens
        print(f"===> GPT-3: {tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)}")
```
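
Each turn in the chat history carries a `|speaker|length|` control prefix: the speaker flag is `0` for the human and `1` for the model (the same mapping `get_user_param` produces, presumably when converting chat logs into training examples, since the loop itself hardcodes the flags), and the length flag is the bucket returned by `get_length_param` (`1` for up to 15 tokens, `2` for up to 50, `3` for up to 256, `-` otherwise). For example, with the length bucket `1` a short human turn like "Привет!" is encoded as `|0|1|Привет!` followed by `tokenizer.eos_token`, and the prompt `|1|2|` then asks the model for a medium-length reply.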