DachengZhang commited on
Commit
2ad8f46
1 Parent(s): 1c2e9a9

add generation utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +55 -0
generation_utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from queue import Queue
3
+
4
+ # build chat input prompt
5
+ def build_chat_input(tokenizer, messages: List[dict]):
6
+ # chat format:
7
+ # single-turn: <s>Human: Hello!\n\nAssistant: </s>
8
+ # multi-turn: <s>Human: Hello!\n\nAssistant: </s>Hi!</s>Human: How are you?\n\nAssistant: </s>I'm fine</s>
9
+
10
+ prompt = "<s>"
11
+ for msg in messages:
12
+ role = msg["role"]
13
+ message = msg["content"]
14
+ if message is None :
15
+ continue
16
+ if role == "user":
17
+ prompt += "Human: " + message + "\n\nAssistant: </s>"
18
+ if role == "assistant":
19
+ prompt += message + "</s>"
20
+
21
+ input_tokens = tokenizer.encode(prompt)
22
+ return input_tokens
23
+
24
+
25
+ class TextIterStreamer:
26
+ def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
27
+ self.tokenizer = tokenizer
28
+ self.skip_prompt = skip_prompt
29
+ self.skip_special_tokens = skip_special_tokens
30
+ self.tokens = []
31
+ self.text_queue = Queue()
32
+ self.next_tokens_are_prompt = True
33
+
34
+ def put(self, value):
35
+ if self.skip_prompt and self.next_tokens_are_prompt:
36
+ self.next_tokens_are_prompt = False
37
+ else:
38
+ if len(value.shape) > 1:
39
+ value = value[0]
40
+ self.tokens.extend(value.tolist())
41
+ self.text_queue.put(
42
+ self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
43
+
44
+ def end(self):
45
+ self.text_queue.put(None)
46
+
47
+ def __iter__(self):
48
+ return self
49
+
50
+ def __next__(self):
51
+ value = self.text_queue.get()
52
+ if value is None:
53
+ raise StopIteration()
54
+ else:
55
+ return value