wenge-research committed
Commit 00be6c9
1 Parent(s): ae584ce
Update README.md
README.md CHANGED
@@ -29,7 +29,10 @@ prompt = "你好"
 formatted_prompt = f"<|System|>:\nA chat between a human and an AI assistant named YaYi.\nYaYi is a helpful and harmless language model developed by Beijing Wenge Technology Co.,Ltd.\n\n<|Human|>:\n{prompt}\n\n<|YaYi|>:"
 inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
 
+eos_token_id = tokenizer("<|End|>").input_ids[0]
 generation_config = GenerationConfig(
+    eos_token_id=eos_token_id,
+    pad_token_id=eos_token_id,
     do_sample=True,
     max_new_tokens=100,
     temperature=0.3,
@@ -40,26 +43,7 @@ response = model.generate(**inputs, generation_config=generation_config)
 print(tokenizer.decode(response[0]))
 ```
 
-Note: the special token `<|End|>` was added during model training.
-
-```python
-from transformers import StoppingCriteria, StoppingCriteriaList
-
-class KeywordsStoppingCriteria(StoppingCriteria):
-    def __init__(self, keywords_ids:list):
-        self.keywords = keywords_ids
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        if input_ids[0][-1] in self.keywords:
-            return True
-        return False
-```
-
-```python
-stop_criteria = KeywordsStoppingCriteria([tokenizer.encode(w)[0] for w in ["<|End|>"]])
-response = model.generate(**inputs, generation_config=generation_config, stopping_criteria=StoppingCriteriaList([stop_criteria]))
-print(tokenizer.decode(response[0]))
-```
+Note: the special token `<|End|>` was added as the end-of-sequence marker during model training, so the code above sets `eos_token_id` in `GenerationConfig` to the token id of this marker.
 
 ## Related agreements
 
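For readers who want to try the updated flow end to end, here is a minimal sketch of the post-commit usage. It assumes the model and tokenizer are loaded with `AutoModelForCausalLM` and `AutoTokenizer` (the loading code is outside this diff), uses `wenge-research/yayi-7b` as a placeholder repository id, and includes only the generation parameters visible in the hunks above; the actual README may set additional ones.

```python
# A sketch of the usage after this commit, not the verbatim README.
# Assumptions: the repository id below is a placeholder, and the model/tokenizer
# are loaded with the standard Auto* classes (loading is not shown in the diff).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_path = "wenge-research/yayi-7b"  # hypothetical id; substitute the real checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)

prompt = "你好"
formatted_prompt = f"<|System|>:\nA chat between a human and an AI assistant named YaYi.\nYaYi is a helpful and harmless language model developed by Beijing Wenge Technology Co.,Ltd.\n\n<|Human|>:\n{prompt}\n\n<|YaYi|>:"
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

# The commit swaps the custom KeywordsStoppingCriteria for the built-in
# eos_token_id mechanism: generation stops once <|End|> is emitted.
eos_token_id = tokenizer("<|End|>").input_ids[0]
generation_config = GenerationConfig(
    eos_token_id=eos_token_id,
    pad_token_id=eos_token_id,
    do_sample=True,
    max_new_tokens=100,
    temperature=0.3,
    # the README may pass further parameters not visible in this hunk
)
response = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(response[0]))
```

Compared with the removed `KeywordsStoppingCriteria`, passing `eos_token_id` lets `generate` stop on `<|End|>` natively, and reusing the same id as `pad_token_id` avoids the pad-token warning `generate` typically prints when none is configured.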