wenge-research committed
Commit: ae584ce
Parent(s): 940e2d8
Update README.md

README.md CHANGED
````diff
@@ -19,6 +19,7 @@ tags:
 
 ```python
 from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+import torch
 
 yayi_7b_path = "wenge-research/yayi-7b"
 tokenizer = AutoTokenizer.from_pretrained(yayi_7b_path)
@@ -26,7 +27,7 @@ model = AutoModelForCausalLM.from_pretrained(yayi_7b_path, device_map="auto", to
 
 prompt = "你好"
 formatted_prompt = f"<|System|>:\nA chat between a human and an AI assistant named YaYi.\nYaYi is a helpful and harmless language model developed by Beijing Wenge Technology Co.,Ltd.\n\n<|Human|>:\n{prompt}\n\n<|YaYi|>:"
-inputs = tokenizer
+inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
 
 generation_config = GenerationConfig(
     do_sample=True,
@@ -36,12 +37,14 @@ generation_config = GenerationConfig(
     no_repeat_ngram_size=0
 )
 response = model.generate(**inputs, generation_config=generation_config)
-print(tokenizer.decode(
+print(tokenizer.decode(response[0]))
 ```
 
 Note: the special token `<|End|>` was added as the end-of-sequence marker during training. If generation with the code above does not stop automatically, you can define a `KeywordsStoppingCriteria` class and pass an instance of it to the `model.generate()` function.
 
 ```python
+from transformers import StoppingCriteria, StoppingCriteriaList
+
 class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords_ids:list):
         self.keywords = keywords_ids
@@ -54,11 +57,10 @@ class KeywordsStoppingCriteria(StoppingCriteria):
 
 ```python
 stop_criteria = KeywordsStoppingCriteria([tokenizer.encode(w)[0] for w in ["<|End|>"]])
-
-
+response = model.generate(**inputs, generation_config=generation_config, stopping_criteria=StoppingCriteriaList([stop_criteria]))
+print(tokenizer.decode(response[0]))
 ```
 
-
 ## Related Agreements
 
 ### Limitations
````
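The hunks above show only the first lines of `KeywordsStoppingCriteria`; its `__call__` method lies outside the changed lines and is not visible in the diff. A minimal sketch of a complete class, assuming the elided method simply stops generation once the most recently produced token id is one of the keyword ids (for example the id of `<|End|>`), could look like this:

```python
import torch
from transformers import StoppingCriteria


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Assumed body (not shown in the diff): stop as soon as the last
        # generated token id matches one of the keyword ids, e.g. <|End|>.
        return input_ids[0][-1].item() in self.keywords
```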
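For reference, the updated snippets can be assembled into a single script. The sketch below reuses the `KeywordsStoppingCriteria` class from the previous block; the `torch_dtype` argument (cut off in the hunk header) and the sampling values elided from the `GenerationConfig` hunk are illustrative placeholders, not the README's actual settings:

```python
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          GenerationConfig, StoppingCriteriaList)

yayi_7b_path = "wenge-research/yayi-7b"
tokenizer = AutoTokenizer.from_pretrained(yayi_7b_path)
# torch_dtype is an assumption; the full argument list is truncated in the diff.
model = AutoModelForCausalLM.from_pretrained(
    yayi_7b_path, device_map="auto", torch_dtype=torch.bfloat16
)

prompt = "你好"
formatted_prompt = (
    "<|System|>:\nA chat between a human and an AI assistant named YaYi.\n"
    "YaYi is a helpful and harmless language model developed by "
    "Beijing Wenge Technology Co.,Ltd.\n\n"
    f"<|Human|>:\n{prompt}\n\n<|YaYi|>:"
)
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

# Placeholder sampling settings; the README's own values are not visible in the diff.
generation_config = GenerationConfig(
    do_sample=True,
    max_new_tokens=256,
    temperature=0.7,
    no_repeat_ngram_size=0,
)

# Stop when <|End|> is produced, using the KeywordsStoppingCriteria sketched above.
stop_criteria = KeywordsStoppingCriteria([tokenizer.encode(w)[0] for w in ["<|End|>"]])
response = model.generate(
    **inputs,
    generation_config=generation_config,
    stopping_criteria=StoppingCriteriaList([stop_criteria]),
)
print(tokenizer.decode(response[0]))
```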
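Separately from what the README shows, a standard `transformers` alternative to the keyword-based criteria is to pass the id of `<|End|>` as `eos_token_id`, assuming the tokenizer maps `<|End|>` to a single id, as the README's own `tokenizer.encode(w)[0]` usage suggests:

```python
# Alternative sketch (not from the README): treat <|End|> as the EOS token.
end_token_id = tokenizer.encode("<|End|>")[0]
response = model.generate(
    **inputs,
    generation_config=generation_config,
    eos_token_id=end_token_id,
)
print(tokenizer.decode(response[0], skip_special_tokens=True))
```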