wenge-research committed on
Commit 00be6c9
1 Parent(s): ae584ce

Update README.md

Files changed (1):
  1. README.md +4 -20
README.md CHANGED
@@ -29,7 +29,10 @@ prompt = "你好"
 formatted_prompt = f"<|System|>:\nA chat between a human and an AI assistant named YaYi.\nYaYi is a helpful and harmless language model developed by Beijing Wenge Technology Co.,Ltd.\n\n<|Human|>:\n{prompt}\n\n<|YaYi|>:"
 inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
 
+eos_token_id = tokenizer("<|End|>").input_ids[0]
 generation_config = GenerationConfig(
+    eos_token_id=eos_token_id,
+    pad_token_id=eos_token_id,
     do_sample=True,
     max_new_tokens=100,
     temperature=0.3,
@@ -40,26 +43,7 @@ response = model.generate(**inputs, generation_config=generation_config)
 print(tokenizer.decode(response[0]))
 ```
 
-Note: the special token `<|End|>` was added as an end-of-sequence marker during training. If generation with the code above does not stop automatically, you can define a `KeywordsStoppingCriteria` class and pass an instance to `model.generate()`.
-
-```python
-from transformers import StoppingCriteria, StoppingCriteriaList
-
-class KeywordsStoppingCriteria(StoppingCriteria):
-    def __init__(self, keywords_ids: list):
-        self.keywords = keywords_ids
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        if input_ids[0][-1] in self.keywords:
-            return True
-        return False
-```
-
-```python
-stop_criteria = KeywordsStoppingCriteria([tokenizer.encode(w)[0] for w in ["<|End|>"]])
-response = model.generate(**inputs, generation_config=generation_config, stopping_criteria=StoppingCriteriaList([stop_criteria]))
-print(tokenizer.decode(response[0]))
-```
+Note: the special token `<|End|>` was added as an end-of-sequence marker during training, so the code above sets `eos_token_id` in `GenerationConfig` to that marker's token id.
 
 ## Related Agreements
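The diff replaces the removed `KeywordsStoppingCriteria` callback with the `eos_token_id` setting in `GenerationConfig`. Both approaches reduce to the same check: stop as soon as the last emitted token is the end marker's id. The sketch below illustrates that shared logic in plain Python; the token ids and the `fake_generate` loop are made up for demonstration and are not the real transformers API or the YaYi tokenizer.

```python
# Illustrative sketch: both the removed KeywordsStoppingCriteria and the
# new eos_token_id setting stop generation once the most recent token
# matches the end marker's id. All ids here are hypothetical.

END_TOKEN_ID = 7  # hypothetical id of the special token <|End|>

def should_stop(generated_ids, stop_ids):
    """Equivalent of KeywordsStoppingCriteria.__call__: stop when the
    most recently generated token is one of the stop ids."""
    return generated_ids[-1] in stop_ids

def fake_generate(next_tokens, stop_ids, max_new_tokens=100):
    """Toy generation loop: append tokens until a stop id appears or the
    budget is exhausted, mirroring model.generate() with eos_token_id."""
    generated = []
    for tok in next_tokens[:max_new_tokens]:
        generated.append(tok)
        if should_stop(generated, stop_ids):
            break
    return generated

# Generation halts at the first occurrence of the end-marker id;
# the trailing tokens 9 and 11 are never emitted.
print(fake_generate([3, 5, 7, 9, 11], {END_TOKEN_ID}))  # → [3, 5, 7]
```

Moving the stop condition into `GenerationConfig` removes the need for the extra `StoppingCriteriaList` plumbing at every `generate()` call, which is why the commit nets +4 -20 lines.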