pyx9913
commited on
Commit
•
aa60bbf
1
Parent(s):
4f1e38f
feat: 🎸 add chat model code
Browse files- README.md +124 -69
- README_en.md +162 -0
- beit3.py +108 -0
- config.json +27 -0
- configuration_viscpmchatbee.py +133 -0
- feature_extraction_viscpmchatbee.py +17 -0
- generation_config.json +12 -0
- modeling_cpmbee.py +0 -0
- preprocessor_config.json +10 -0
- processing_viscpmchatbee.py +428 -0
- tokenization_viscpmchatbee.py +1007 -0
- tokenizer_config.json +10 -0
- utils.py +730 -0
- vocab.txt +0 -0
README.md
CHANGED
@@ -1,58 +1,45 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
- en
|
4 |
-
- zh
|
5 |
-
---
|
6 |
-
<div align="center">
|
7 |
-
|
8 |
-
**VisCPM**
|
9 |
-
|
10 |
-
**Chinese-English bilingual multi-modal large model series based on CPM (Chinese Pretrained Models) basic model**
|
11 |
|
12 |
<p align="center">
|
13 |
-
|
14 |
-
|
|
|
15 |
</p>
|
16 |
|
17 |
-
|
18 |
|
19 |
-
`VisCPM
|
20 |
-
|
21 |
-
- **👐 Open-source Usage**: VisCPM is free to be used for personal and research purposes. By open-sourcing the VisCPM model family, we hope to promote the development of the open-source community of large multimodal models and related research.
|
22 |
-
- **🌟 Image and text generation coverage**: VisCPM models provide relatively comprehensive support for image and text multimodal capabilities, covering both multimodal conversation (image-to-text generation) capabilities and text-to-image generation capabilities.
|
23 |
-
- **💫 Excellent bilingual performance**: Thanks to the excellent bilingual capability of the base language model CPM-Bee, VisCPM achieves outstanding results in both bilingual multimodal conversation and text-to-image generation.
|
24 |
|
25 |
## VisCPM-Chat
|
26 |
-
`VisCPM-Chat
|
27 |
|
28 |
-
*
|
29 |
|
30 |
-
*
|
31 |
|
32 |
-
|
33 |
|
34 |
<table>
|
35 |
<tr>
|
36 |
-
<td align="center" rowspan="2" colspan="2"
|
37 |
-
<td align="center"
|
38 |
-
<td align="center" colspan="4"
|
39 |
-
<td align="center" colspan="4">Chinese</td>
|
40 |
</tr>
|
41 |
<tr>
|
42 |
-
<td align="center"
|
43 |
-
<td align="center"
|
44 |
-
<td align="center"
|
45 |
-
<td align="center"
|
46 |
-
<td align="center"
|
47 |
-
<td align="center"
|
48 |
-
<td align="center"
|
49 |
-
<td align="center"
|
50 |
</tr>
|
51 |
<tr>
|
52 |
-
<td align="center" rowspan="3"
|
53 |
<td align="center">MiniGPT4</td>
|
54 |
-
<td align="center">
|
55 |
-
<td align="center">65.0</td>
|
56 |
<td align="center">67.3</td>
|
57 |
<td align="center">76.6</td>
|
58 |
<td align="center">69.7</td>
|
@@ -63,9 +50,8 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
|
|
63 |
</tr>
|
64 |
<tr>
|
65 |
<td align="center">InstructBLIP</td>
|
66 |
-
<td align="center">Vicuna-13B</td>
|
67 |
<td align="center">81.9</td>
|
68 |
-
<td align="center">68
|
69 |
<td align="center">91.2</td>
|
70 |
<td align="center">80.5</td>
|
71 |
<td align="center">-</td>
|
@@ -75,20 +61,18 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
|
|
75 |
</tr>
|
76 |
<tr>
|
77 |
<td align="center">LLaVA</td>
|
78 |
-
<td align="center">
|
79 |
-
<td align="center"
|
80 |
-
<td align="center"
|
81 |
-
<td align="center"
|
82 |
-
<td align="center"><b>85.6</b></td>
|
83 |
<td align="center">-</td>
|
84 |
<td align="center">-</td>
|
85 |
<td align="center">-</td>
|
86 |
<td align="center">-</td>
|
87 |
</tr>
|
88 |
<tr>
|
89 |
-
<td align="center" rowspan="
|
90 |
<td align="center">mPLUG-Owl </td>
|
91 |
-
<td align="center">LLaMA-7B</td>
|
92 |
<td align="center">64.6</td>
|
93 |
<td align="center">47.7</td>
|
94 |
<td align="center">80.1</td>
|
@@ -96,61 +80,132 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
|
|
96 |
<td align="center">76.3</td>
|
97 |
<td align="center">61.2</td>
|
98 |
<td align="center">77.8</td>
|
99 |
-
<td align="center">72
|
100 |
</tr>
|
101 |
<tr>
|
102 |
<td align="center">VisualGLM</td>
|
103 |
-
<td align="center">ChatGLM-6B</td>
|
104 |
<td align="center">62.4</td>
|
105 |
-
<td align="center">63
|
106 |
<td align="center">80.6</td>
|
107 |
<td align="center">68.7</td>
|
108 |
<td align="center">76.6</td>
|
109 |
-
<td align="center"
|
110 |
<td align="center">83.6</td>
|
111 |
<td align="center">82.7</td>
|
112 |
</tr>
|
113 |
<tr>
|
114 |
-
<td align="center">Ziya
|
115 |
-
<td align="center">Ziya-LLaMA-13B-v1</td>
|
116 |
<td align="center">82.7</td>
|
117 |
<td align="center">69.9</td>
|
118 |
<td align="center">92.1</td>
|
119 |
<td align="center">81.7</td>
|
120 |
-
<td align="center">85
|
121 |
<td align="center">74.7</td>
|
122 |
<td align="center">82.4</td>
|
123 |
<td align="center">80.8</td>
|
124 |
</tr>
|
125 |
<tr>
|
126 |
-
<td align="center">VisCPM-Chat
|
127 |
-
<td align="center">CPMBee-10B</td>
|
128 |
<td align="center">83.3</td>
|
129 |
<td align="center">68.9</td>
|
130 |
<td align="center">90.5</td>
|
131 |
<td align="center">81.1</td>
|
132 |
-
<td align="center"
|
133 |
<td align="center">76.1</td>
|
134 |
<td align="center">89.2</td>
|
135 |
<td align="center">86.3</td>
|
136 |
</tr>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
<tr>
|
138 |
-
<td align="center">
|
139 |
-
<td align="center">
|
140 |
-
<td align="center">
|
141 |
-
<td align="center">
|
142 |
-
|
143 |
-
|
144 |
-
<td align="center">
|
145 |
-
<td align="center">
|
146 |
-
<td align="center"
|
147 |
-
<td align="center"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
</tr>
|
149 |
</table>
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
|
|
|
|
|
153 |
|
154 |
-
|
|
|
|
|
155 |
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VisCPM
|
2 |
+
简体中文 | [English](README_en.md)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
<p align="center">
|
5 |
+
<p align="left">
|
6 |
+
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
|
7 |
+
<a href=""><img src="https://img.shields.io/badge/python-3.8+-aff.svg"></a>
|
8 |
</p>
|
9 |
|
10 |
+
`VisCPM` is a family of open-source large multimodal models, which support multimodal conversational capabilities (`VisCPM-Chat` model) and text-to-image generation capabilities (`VisCPM-Paint` model) in both Chinese and English, achieving state-of-the-art peformance among Chinese open-source multimodal models. `VisCPM` is trained based on the large language model [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) with 10B parameters, fusing visual encoder (Q-Former) and visual decoder (Diffusion-UNet) to support visual inputs and outputs. Thanks to the good bilingual capability of CPM-Bee, `VisCPM` can be pre-trained with English multimodal data only and well generalize to achieve promising Chinese multimodal capabilities.
|
11 |
|
12 |
+
`VisCPM`是一个开源的多模态大模型系列,支持中英双语的多模态对话能力(`VisCPM-Chat`模型)和文到图生成能力(`VisCPM-Paint`模型),在中文多模态开源模型中达到最佳水平。`VisCPM`基于百亿参数量语言大模型[CPM-Bee](https://github.com/OpenBMB/CPM-Bee)(10B)训练,融合视觉编码器(`Q-Former`)和视觉解码器(`Diffusion-UNet`)以支持视觉信号的输入和输出。得益于`CPM-Bee`底座优秀的双语能力,`VisCPM`可以仅通过英文多模态数据预训练,泛化实现优秀的中文多模态能力。
|
|
|
|
|
|
|
|
|
13 |
|
14 |
## VisCPM-Chat
|
15 |
+
`VisCPM-Chat`支持面向图像进行中英双语多模态对话。该模型使用`Q-Former`作为视觉编码器,使用CPM-Bee(10B)作为语言交互基底模型,并通过语言建模训练目标融合视觉和语言模型。模型训练包括预训练和指令精调两阶段:
|
16 |
|
17 |
+
* 预训练:我们使用约100M高质量英文图文对数据对`VisCPM-Chat`进行了预训练,数据包括CC3M、CC12M、COCO、Visual Genome、Laion等。在预训练阶段,语言模型参数保持固定,仅更新`Q-Former`部分参数,以支持大规模视觉-语言表示的高效对齐。
|
18 |
|
19 |
+
* 指令精调:我们采用[LLaVA-150K](https://llava-vl.github.io/)英文指令精调数据,并混合相应翻译后的中文数据对模型进行指令精调,以对齐模型多模态基础能力和用户使用意图。在指令精调阶段,我们更新全部模型参数,以提升指令精调数据的利用效率。有趣的是,我们发现即使仅采用英文指令数据进行指令精调,模型也可以理解中文问题,但仅能用英文回答。这表明模型的多语言多模态能力已经得到良好的泛化。在指令精调阶段进一步加入少量中文翻译数据,可以将模型回复语言和用户问题语言对齐。
|
20 |
|
21 |
+
我们在LLaVA英文测试集和翻译的中文测试集对模型进行了评测,该评测基准考察模型在开放域对话、图像细节描述、复杂推理方面的表现,并使用GPT-4进行打分。可以观察到,`VisCPM-Chat`在中文多模态能力方面取得了最佳的平均性能,在通用域对话和复杂推理表现出色,同时也表现出了不错的英文多模态能力。
|
22 |
|
23 |
<table>
|
24 |
<tr>
|
25 |
+
<td align="center" rowspan="2" colspan="2">模型</td>
|
26 |
+
<td align="center" colspan="4">英文</td>
|
27 |
+
<td align="center" colspan="4">中文</td>
|
|
|
28 |
</tr>
|
29 |
<tr>
|
30 |
+
<td align="center">多模态对话</td>
|
31 |
+
<td align="center">细节描述</td>
|
32 |
+
<td align="center">复杂推理</td>
|
33 |
+
<td align="center">平均</td>
|
34 |
+
<td align="center">多模态对话</td>
|
35 |
+
<td align="center">细节描述</td>
|
36 |
+
<td align="center">复杂推理</td>
|
37 |
+
<td align="center">平均</td>
|
38 |
</tr>
|
39 |
<tr>
|
40 |
+
<td align="center" rowspan="3">英文模型</td>
|
41 |
<td align="center">MiniGPT4</td>
|
42 |
+
<td align="center">65</td>
|
|
|
43 |
<td align="center">67.3</td>
|
44 |
<td align="center">76.6</td>
|
45 |
<td align="center">69.7</td>
|
|
|
50 |
</tr>
|
51 |
<tr>
|
52 |
<td align="center">InstructBLIP</td>
|
|
|
53 |
<td align="center">81.9</td>
|
54 |
+
<td align="center">68</td>
|
55 |
<td align="center">91.2</td>
|
56 |
<td align="center">80.5</td>
|
57 |
<td align="center">-</td>
|
|
|
61 |
</tr>
|
62 |
<tr>
|
63 |
<td align="center">LLaVA</td>
|
64 |
+
<td align="center">89.5</td>
|
65 |
+
<td align="center">70.4</td>
|
66 |
+
<td align="center">96.2</td>
|
67 |
+
<td align="center">85.6</td>
|
|
|
68 |
<td align="center">-</td>
|
69 |
<td align="center">-</td>
|
70 |
<td align="center">-</td>
|
71 |
<td align="center">-</td>
|
72 |
</tr>
|
73 |
<tr>
|
74 |
+
<td align="center" rowspan="4">中英双语</td>
|
75 |
<td align="center">mPLUG-Owl </td>
|
|
|
76 |
<td align="center">64.6</td>
|
77 |
<td align="center">47.7</td>
|
78 |
<td align="center">80.1</td>
|
|
|
80 |
<td align="center">76.3</td>
|
81 |
<td align="center">61.2</td>
|
82 |
<td align="center">77.8</td>
|
83 |
+
<td align="center">72</td>
|
84 |
</tr>
|
85 |
<tr>
|
86 |
<td align="center">VisualGLM</td>
|
|
|
87 |
<td align="center">62.4</td>
|
88 |
+
<td align="center">63</td>
|
89 |
<td align="center">80.6</td>
|
90 |
<td align="center">68.7</td>
|
91 |
<td align="center">76.6</td>
|
92 |
+
<td align="center">87.8</td>
|
93 |
<td align="center">83.6</td>
|
94 |
<td align="center">82.7</td>
|
95 |
</tr>
|
96 |
<tr>
|
97 |
+
<td align="center">Ziya (LLaMA 13B)</td>
|
|
|
98 |
<td align="center">82.7</td>
|
99 |
<td align="center">69.9</td>
|
100 |
<td align="center">92.1</td>
|
101 |
<td align="center">81.7</td>
|
102 |
+
<td align="center">85</td>
|
103 |
<td align="center">74.7</td>
|
104 |
<td align="center">82.4</td>
|
105 |
<td align="center">80.8</td>
|
106 |
</tr>
|
107 |
<tr>
|
108 |
+
<td align="center">VisCPM-Chat</td>
|
|
|
109 |
<td align="center">83.3</td>
|
110 |
<td align="center">68.9</td>
|
111 |
<td align="center">90.5</td>
|
112 |
<td align="center">81.1</td>
|
113 |
+
<td align="center">92.7</td>
|
114 |
<td align="center">76.1</td>
|
115 |
<td align="center">89.2</td>
|
116 |
<td align="center">86.3</td>
|
117 |
</tr>
|
118 |
+
</table>
|
119 |
+
|
120 |
+
## VisCPM-Paint
|
121 |
+
`VisCPM-Paint`支持中英双语的文到图生成。该模型使用CPM-Bee(10B)作为文本编码器,使用`UNet`作为图像解码器,并通过扩散模型训练目标融合语言和视觉模型。在训练过程中,语言模型参数始终保持固定。我们使用[Stable Diffusion 2.1](https://github.com/Stability-AI/stablediffusion)的UNet参数初始化视觉解码器,并通过逐步解冻其中关键的桥接参数将其与语言模型融合:首先训练文本表示映射到视觉模型的线性层,然后进一步解冻`UNet`的交叉注意力层。该模型在[LAION 2B](https://laion.ai/)英文图文对数据上进行了训练。
|
122 |
+
|
123 |
+
与`VisCPM-Chat`类似,我们发现得益于CPM-Bee的双语能力,`VisCPM-Paint`可以仅通过英文图文对训练,泛化实现良好的中文文到图生成能力,达到中文开源模型的最佳效果。通过进一步加入20M清洗后的原生中文图文对数据,以及120M翻译到中文的图文对数据,模型的中文文到图生成能力可以获得进一步提升。我们在MSCOCO上采样了3万张图片,计算了FID(Fréchet Inception Distance)和Clip Score,前者用于评估生成图片的质量,后面用于评估生成的图片与输入的匹配程度。
|
124 |
+
|
125 |
+
<table>
|
126 |
+
<tr>
|
127 |
+
<td align="center" rowspan="2">模型</td>
|
128 |
+
<td align="center" colspan="2">英文</td>
|
129 |
+
<td align="center" colspan="2">中文</td>
|
130 |
+
</tr>
|
131 |
<tr>
|
132 |
+
<td align="center">FID↓</td>
|
133 |
+
<td align="center">CLIP Score↑</td>
|
134 |
+
<td align="center">FID↓</td>
|
135 |
+
<td align="center">CLIP Score↑</td>
|
136 |
+
</tr>
|
137 |
+
<tr>
|
138 |
+
<td align="center">AltDiffusion</td>
|
139 |
+
<td align="center">17.16</td>
|
140 |
+
<td align="center">25.24</td>
|
141 |
+
<td align="center">16.09</td>
|
142 |
+
<td align="center">24.05</td>
|
143 |
+
</tr>
|
144 |
+
<tr>
|
145 |
+
<td align="center">TaiyiDiffusion</td>
|
146 |
+
<td align="center">-</td>
|
147 |
+
<td align="center">-</td>
|
148 |
+
<td align="center">15.58</td>
|
149 |
+
<td align="center">22.69</td>
|
150 |
+
</tr>
|
151 |
+
<tr>
|
152 |
+
<td align="center">Stable Diffusion</td>
|
153 |
+
<td align="center">9.08</td>
|
154 |
+
<td align="center">26.22</td>
|
155 |
+
<td align="center">-</td>
|
156 |
+
<td align="center">-</td>
|
157 |
+
</tr>
|
158 |
+
<tr>
|
159 |
+
<td align="center">VisCPM-Paint-en</td>
|
160 |
+
<td align="center">9.51</td>
|
161 |
+
<td align="center">25.35</td>
|
162 |
+
<td align="center">10.86</td>
|
163 |
+
<td align="center">23.38</td>
|
164 |
+
</tr>
|
165 |
+
<tr>
|
166 |
+
<td align="center">VisCPM-Paint-zh</td>
|
167 |
+
<td align="center">9.98</td>
|
168 |
+
<td align="center">25.04</td>
|
169 |
+
<td align="center">9.65</td>
|
170 |
+
<td align="center">24.17</td>
|
171 |
</tr>
|
172 |
</table>
|
173 |
|
174 |
+
# 安装
|
175 |
+
|
176 |
+
```Shell
|
177 |
+
conda create -n viscpm python=3.10 -y
|
178 |
+
conda activate viscpm
|
179 |
+
pip install setuptools
|
180 |
+
pip install diffusers jieba matplotlib numpy opencv_python
|
181 |
+
pip install pandas Pillow psutil pydantic scipy
|
182 |
+
pip install torch==1.13.1 torchscale==0.2.0 torchvision==0.14.1 timm
|
183 |
+
pip install transformers==4.28.0
|
184 |
+
pip install tqdm typing_extensions
|
185 |
+
pip install git+https://github.com/thunlp/OpenDelta.git
|
186 |
+
pip install git+https://github.com/OpenBMB/CPM-Bee.git#egg=cpm-live&subdirectory=src
|
187 |
+
```
|
188 |
+
|
189 |
+
VisCPM需要单卡40GB以上的GPU运行,我们会在尽快更新更加节省显存的推理方式。
|
190 |
+
|
191 |
+
## 使用
|
192 |
|
193 |
+
```python
|
194 |
+
>>> from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
|
195 |
+
>>> from PIL import Image
|
196 |
|
197 |
+
>>> tokenizer = AutoTokenizer.from_pretrained('viscpm', trust_remote_code=True)
|
198 |
+
>>> processor = AutoImageProcessor.from_pretrained('viscpm', trust_remote_code=True)
|
199 |
+
>>> model = AutoModel.from_pretrained('viscpm', trust_remote_code=True).to('cuda')
|
200 |
|
201 |
+
>>> data = [{
|
202 |
+
>>> 'context': '',
|
203 |
+
>>> 'question': 'describe this image in detail.',
|
204 |
+
>>> 'image': tokenizer.unk_token * model.query_num,
|
205 |
+
>>> '<ans>': ''
|
206 |
+
>>> }]
|
207 |
+
>>> image = Image.open('case.jpg')
|
208 |
+
>>> result = model.generate(data, tokenizer, processor, image)
|
209 |
+
>>> print(result[0]['<ans>'])
|
210 |
+
这幅图片显示了一群热气球在天空中飞行。这些热气球漂浮在不同的地方,包括山脉、城市和乡村地区。
|
211 |
+
```
|
README_en.md
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VisCPM
|
2 |
+
[简体中文](README.md) | English
|
3 |
+
|
4 |
+
<p align="center">
|
5 |
+
<p align="left">
|
6 |
+
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
|
7 |
+
<a href=""><img src="https://img.shields.io/badge/python-3.8+-aff.svg"></a>
|
8 |
+
</p>
|
9 |
+
|
10 |
+
`VisCPM` is a family of open-source large multimodal models, which support multimodal conversational capabilities (`VisCPM-Chat` model) and text-to-image generation capabilities (`VisCPM-Paint` model) in both Chinese and English, achieving state-of-the-art peformance among Chinese open-source multimodal models. `VisCPM` is trained based on the large language model [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) with 10B parameters, fusing visual encoder (`Q-Former`) and visual decoder (`Diffusion-UNet`) to support visual inputs and outputs. Thanks to the good bilingual capability of `CPM-Bee`, `VisCPM` can be pre-trained with English multimodal data only and well generalize to achieve promising Chinese multimodal capabilities.
|
11 |
+
|
12 |
+
## VisCPM-Chat
|
13 |
+
`VisCPM-Chat` supports bilingual multimodal conversations involving images in both Chinese and English. The model utilizes `Q-Former` as the visual encoder and CPM-Bee (10B) as the base LLM. It combines visual and language models through language modeling training objectives. The model training consists of two stages: pretraining and instruction fine-tuning.
|
14 |
+
|
15 |
+
* Pretrain: `VisCPM-Chat` was pretrained using approximately 100 million high-quality English multimodal data pairs. The data sources include CC3M, CC12M, COCO, Visual Genome, Laion, and others. In this stage, the language model parameters remain fixed, and only the parameters of the `Q-Former` are updated to enable efficient alignment of large-scale visual-language representations.
|
16 |
+
|
17 |
+
* Instruction fine-tuning: We utilized the [LLaVA-150K](https://llava-vl.github.io/) dataset, which consists of English multimodal instruction-following dataset. We mixed this data with corresponding translated Chinese data to fine-tune the model and align its multimodal capabilities with user intents. In this phase, we updated all model parameters to improve the utilization efficiency of the instruction fine-tuning data. Interestingly, we observed that even when using only English instruction data for fine-tuning, the model can comprehend Chinese questions but can only respond in English. This indicates that the model has achieved good generalization in terms of its multilingual and multimodal capabilities. By incorporating a small amount of translated Chinese data during the instruction fine-tuning phase, we can align the model's response language with the user's question language.
|
18 |
+
|
19 |
+
We evaluated the model on the LLaVA English test set and the translated Chinese test set. The evaluation benchmark examined the model's performance in open-domain conversations, image detail descriptions, and complex reasoning tasks, using GPT-4 for scoring. It is evident that `VisCPM-Chat` achieved the best average performance in Chinese multimodal capabilities, excelling in general-domain conversations and complex reasoning. It also demonstrated commendable English multimodal abilities.
|
20 |
+
|
21 |
+
<table>
|
22 |
+
<tr>
|
23 |
+
<td align="center" rowspan="2" colspan="2">Model</td>
|
24 |
+
<td align="center" colspan="4">English</td>
|
25 |
+
<td align="center" colspan="4">Chinese</td>
|
26 |
+
</tr>
|
27 |
+
<tr>
|
28 |
+
<td align="center">Conversation</td>
|
29 |
+
<td align="center">Detailed Description</td>
|
30 |
+
<td align="center">Complex Reasoning</td>
|
31 |
+
<td align="center">All</td>
|
32 |
+
<td align="center">Conversation</td>
|
33 |
+
<td align="center">Detailed Description</td>
|
34 |
+
<td align="center">Complex Reasoning</td>
|
35 |
+
<td align="center">All</td>
|
36 |
+
</tr>
|
37 |
+
<tr>
|
38 |
+
<td align="center" rowspan="3">English Model</td>
|
39 |
+
<td align="center">MiniGPT4</td>
|
40 |
+
<td align="center">65</td>
|
41 |
+
<td align="center">67.3</td>
|
42 |
+
<td align="center">76.6</td>
|
43 |
+
<td align="center">69.7</td>
|
44 |
+
<td align="center">-</td>
|
45 |
+
<td align="center">-</td>
|
46 |
+
<td align="center">-</td>
|
47 |
+
<td align="center">-</td>
|
48 |
+
</tr>
|
49 |
+
<tr>
|
50 |
+
<td align="center">InstructBLIP</td>
|
51 |
+
<td align="center">81.9</td>
|
52 |
+
<td align="center">68</td>
|
53 |
+
<td align="center">91.2</td>
|
54 |
+
<td align="center">80.5</td>
|
55 |
+
<td align="center">-</td>
|
56 |
+
<td align="center">-</td>
|
57 |
+
<td align="center">-</td>
|
58 |
+
<td align="center">-</td>
|
59 |
+
</tr>
|
60 |
+
<tr>
|
61 |
+
<td align="center">LLaVA</td>
|
62 |
+
<td align="center">89.5</td>
|
63 |
+
<td align="center">70.4</td>
|
64 |
+
<td align="center">96.2</td>
|
65 |
+
<td align="center">85.6</td>
|
66 |
+
<td align="center">-</td>
|
67 |
+
<td align="center">-</td>
|
68 |
+
<td align="center">-</td>
|
69 |
+
<td align="center">-</td>
|
70 |
+
</tr>
|
71 |
+
<tr>
|
72 |
+
<td align="center" rowspan="4">En-Zh Bilingual Model</td>
|
73 |
+
<td align="center">mPLUG-Owl </td>
|
74 |
+
<td align="center">64.6</td>
|
75 |
+
<td align="center">47.7</td>
|
76 |
+
<td align="center">80.1</td>
|
77 |
+
<td align="center">64.2</td>
|
78 |
+
<td align="center">76.3</td>
|
79 |
+
<td align="center">61.2</td>
|
80 |
+
<td align="center">77.8</td>
|
81 |
+
<td align="center">72</td>
|
82 |
+
</tr>
|
83 |
+
<tr>
|
84 |
+
<td align="center">VisualGLM</td>
|
85 |
+
<td align="center">62.4</td>
|
86 |
+
<td align="center">63</td>
|
87 |
+
<td align="center">80.6</td>
|
88 |
+
<td align="center">68.7</td>
|
89 |
+
<td align="center">76.6</td>
|
90 |
+
<td align="center">87.8</td>
|
91 |
+
<td align="center">83.6</td>
|
92 |
+
<td align="center">82.7</td>
|
93 |
+
</tr>
|
94 |
+
<tr>
|
95 |
+
<td align="center">Ziya (LLaMA 13B)</td>
|
96 |
+
<td align="center">82.7</td>
|
97 |
+
<td align="center">69.9</td>
|
98 |
+
<td align="center">92.1</td>
|
99 |
+
<td align="center">81.7</td>
|
100 |
+
<td align="center">85</td>
|
101 |
+
<td align="center">74.7</td>
|
102 |
+
<td align="center">82.4</td>
|
103 |
+
<td align="center">80.8</td>
|
104 |
+
</tr>
|
105 |
+
<tr>
|
106 |
+
<td align="center">VisCPM-Chat</td>
|
107 |
+
<td align="center">83.3</td>
|
108 |
+
<td align="center">68.9</td>
|
109 |
+
<td align="center">90.5</td>
|
110 |
+
<td align="center">81.1</td>
|
111 |
+
<td align="center">92.7</td>
|
112 |
+
<td align="center">76.1</td>
|
113 |
+
<td align="center">89.2</td>
|
114 |
+
<td align="center">86.3</td>
|
115 |
+
</tr>
|
116 |
+
</table>
|
117 |
+
|
118 |
+
# Install
|
119 |
+
|
120 |
+
1. Clone this repository and navigate to source folder
|
121 |
+
```bash
|
122 |
+
git clone <github repo URL>
|
123 |
+
cd viscpm
|
124 |
+
```
|
125 |
+
|
126 |
+
2. Install Package
|
127 |
+
```Shell
|
128 |
+
conda create -n viscpm python=3.10 -y
|
129 |
+
conda activate viscpm
|
130 |
+
pip install setuptools
|
131 |
+
pip install diffusers jieba matplotlib numpy opencv_python
|
132 |
+
pip install pandas Pillow psutil pydantic scipy
|
133 |
+
pip install torch==1.13.1 torchscale==0.2.0 torchvision==0.14.1 timm
|
134 |
+
pip install transformers==4.28.0
|
135 |
+
pip install tqdm typing_extensions
|
136 |
+
pip install git+https://github.com/thunlp/OpenDelta.git
|
137 |
+
pip install git+https://github.com/OpenBMB/CPM-Bee.git#egg=cpm-live&subdirectory=src
|
138 |
+
```
|
139 |
+
|
140 |
+
`VisCPM` require GPUs with more than 40GB memory. We will soon update more memory-friendly inference methods.
|
141 |
+
|
142 |
+
## How to use
|
143 |
+
|
144 |
+
```python
|
145 |
+
>>> from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
|
146 |
+
>>> from PIL import Image
|
147 |
+
|
148 |
+
>>> tokenizer = AutoTokenizer.from_pretrained('viscpm', trust_remote_code=True)
|
149 |
+
>>> processor = AutoImageProcessor.from_pretrained('viscpm', trust_remote_code=True)
|
150 |
+
>>> model = AutoModel.from_pretrained('viscpm', trust_remote_code=True).to('cuda')
|
151 |
+
|
152 |
+
>>> data = [{
|
153 |
+
>>> 'context': '',
|
154 |
+
>>> 'question': 'describe this image in detail.',
|
155 |
+
>>> 'image': tokenizer.unk_token * model.query_num,
|
156 |
+
>>> '<ans>': ''
|
157 |
+
>>> }]
|
158 |
+
>>> image = Image.open('case.jpg')
|
159 |
+
>>> result = model.generate(data, tokenizer, processor, image)
|
160 |
+
>>> print(result[0]['<ans>'])
|
161 |
+
这幅图片显示了一群热气球在天空中飞行。这些热气球漂浮在不同的地方,包括山脉、城市和乡村地区。
|
162 |
+
```
|
beit3.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
|
4 |
+
# Copyright (c) 2023 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# --------------------------------------------------------'
|
7 |
+
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
|
12 |
+
from timm.models.registry import register_model
|
13 |
+
|
14 |
+
from torchscale.model.BEiT3 import BEiT3
|
15 |
+
from torchscale.architecture.config import EncoderConfig
|
16 |
+
|
17 |
+
|
18 |
+
def trunc_normal_(tensor, mean=0., std=1.):
|
19 |
+
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
|
20 |
+
|
21 |
+
|
22 |
+
def _get_base_config(
|
23 |
+
img_size=224, patch_size=16, drop_path_rate=0,
|
24 |
+
checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
|
25 |
+
):
|
26 |
+
return EncoderConfig(
|
27 |
+
img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
|
28 |
+
layernorm_embedding=False, normalize_output=True, no_output_layer=True,
|
29 |
+
drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12,
|
30 |
+
encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12,
|
31 |
+
checkpoint_activations=checkpoint_activations,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def _get_large_config(
|
36 |
+
img_size=224, patch_size=16, drop_path_rate=0,
|
37 |
+
checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
|
38 |
+
):
|
39 |
+
return EncoderConfig(
|
40 |
+
img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
|
41 |
+
layernorm_embedding=False, normalize_output=True, no_output_layer=True,
|
42 |
+
drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16,
|
43 |
+
encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24,
|
44 |
+
checkpoint_activations=checkpoint_activations,
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
class BEiT3Wrapper(nn.Module):
|
49 |
+
def __init__(self, args, **kwargs):
|
50 |
+
super().__init__()
|
51 |
+
self.args = args
|
52 |
+
self.beit3 = BEiT3(args)
|
53 |
+
self.apply(self._init_weights)
|
54 |
+
self.mim_head = nn.Linear(1024, 8192)
|
55 |
+
self.num_img_patches = self.beit3.vision_embed.num_position_embeddings()
|
56 |
+
self.hidden_size = args.encoder_embed_dim
|
57 |
+
|
58 |
+
def fix_init_weight(self):
|
59 |
+
def rescale(param, layer_id):
|
60 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
61 |
+
|
62 |
+
for layer_id, layer in enumerate(self.blocks):
|
63 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
64 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
65 |
+
|
66 |
+
def get_num_layers(self):
|
67 |
+
return self.beit3.encoder.num_layers
|
68 |
+
|
69 |
+
@torch.jit.ignore
|
70 |
+
def no_weight_decay(self):
|
71 |
+
return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'}
|
72 |
+
|
73 |
+
def _init_weights(self, m):
|
74 |
+
if isinstance(m, nn.Linear):
|
75 |
+
trunc_normal_(m.weight, std=.02)
|
76 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
77 |
+
nn.init.constant_(m.bias, 0)
|
78 |
+
elif isinstance(m, nn.LayerNorm):
|
79 |
+
nn.init.constant_(m.bias, 0)
|
80 |
+
nn.init.constant_(m.weight, 1.0)
|
81 |
+
|
82 |
+
def forward(self, pixel_values, query_embed=None):
|
83 |
+
B = pixel_values.size(0)
|
84 |
+
dtype = self.beit3.vision_embed.proj.weight.dtype
|
85 |
+
pixel_values = pixel_values.to(dtype)
|
86 |
+
token_embeddings = self.beit3.vision_embed(pixel_values)
|
87 |
+
multiway_split_position = -1
|
88 |
+
if query_embed is not None:
|
89 |
+
query_embed = torch.stack([query_embed] * B)
|
90 |
+
multiway_split_position = token_embeddings.size(1)
|
91 |
+
token_embeddings = torch.cat([token_embeddings, query_embed], dim=1)
|
92 |
+
|
93 |
+
outputs = self.beit3.encoder(
|
94 |
+
src_tokens=None,
|
95 |
+
token_embeddings=token_embeddings,
|
96 |
+
multiway_split_position=multiway_split_position
|
97 |
+
)
|
98 |
+
vision_hidden_states = outputs["encoder_out"]
|
99 |
+
if query_embed is not None:
|
100 |
+
vision_hidden_states = vision_hidden_states[:, self.num_img_patches:]
|
101 |
+
return vision_hidden_states
|
102 |
+
|
103 |
+
|
104 |
+
@register_model
|
105 |
+
def beit3_large_patch16_224(pretrained=False, **kwargs):
|
106 |
+
args = _get_large_config(img_size=224, **kwargs)
|
107 |
+
model = BEiT3Wrapper(args, **kwargs)
|
108 |
+
return model
|
config.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"_name_or_path": "openbmb/viscpmchat-bee-10b",
|
4 |
+
"architectures": [
|
5 |
+
"VisCpmBeeForCausalLM"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_viscpmchatbee.VisCpmChatBeeConfig",
|
9 |
+
"AutoModel": "modeling_cpmbee.VisCpmBeeForCausalLM",
|
10 |
+
"AutoModelForCausalLM": "modeling_cpmbee.VisCpmBeeForCausalLM"
|
11 |
+
},
|
12 |
+
"vocab_size": 86583,
|
13 |
+
"hidden_size": 4096,
|
14 |
+
"dim_ff" : 10240,
|
15 |
+
"num_hidden_layers" : 48,
|
16 |
+
"num_attention_heads": 32,
|
17 |
+
"dim_head" : 128,
|
18 |
+
"dropout_p" : 0.0,
|
19 |
+
"position_bias_num_buckets" : 256,
|
20 |
+
"position_bias_num_segment_buckets": 256,
|
21 |
+
"position_bias_max_distance" : 2048,
|
22 |
+
"vision_dim": 1024,
|
23 |
+
"query_num": 64,
|
24 |
+
"eps" : 1e-6,
|
25 |
+
"half" : true,
|
26 |
+
"model_type": "viscpmchatbee"
|
27 |
+
}
|
configuration_viscpmchatbee.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" CpmBee model configuration"""
|
16 |
+
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
from transformers.configuration_utils import PretrainedConfig
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
CPMBEE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
26 |
+
"openbmb/viscpmchat-bee-10b": "https://huggingface.co/openbmb/VisCPM-Chat/resolve/main/config.json",
|
27 |
+
# See all VisCpmBee models at https://huggingface.co/models?filter=viscpmbee
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
class VisCpmChatBeeConfig(PretrainedConfig):
|
32 |
+
r"""
|
33 |
+
This is the configuration class to store the configuration of a [`CpmBeeModel`]. It is used to instbeeiate an
|
34 |
+
CPMBee model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
35 |
+
with the defaults will yield a similar configuration to that of the CPMBee
|
36 |
+
[openbmb/cpm-bee-10b](https://huggingface.co/openbmb/cpm-bee-10b) architecture.
|
37 |
+
|
38 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
39 |
+
documentation from [`PretrainedConfig`] for more information.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
vocab_size (`int`, *optional*, defaults to 30720):
|
43 |
+
Vocabulary size of the CPMBee model. Defines the number of different tokens that can be represented by the
|
44 |
+
`input` passed when calling [`CpmBeeModel`].
|
45 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
46 |
+
Dimension of the encoder layers.
|
47 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
48 |
+
Number of attention heads in the Transformer encoder.
|
49 |
+
dim_head (`int`, *optional*, defaults to 128):
|
50 |
+
Dimension of attention heads for each attention layer in the Transformer encoder.
|
51 |
+
dim_ff (`int`, *optional*, defaults to 10240):
|
52 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
53 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
54 |
+
Number of layers of the Transformer encoder.
|
55 |
+
dropout_p (`float`, *optional*, defaults to 0.1):
|
56 |
+
The dropout probabilitiy for all fully connected layers in the embeddings, encoder.
|
57 |
+
position_bias_num_buckets (`int`, *optional*, defaults to 512):
|
58 |
+
The number of position_bias buckets.
|
59 |
+
position_bias_num_segment_buckets (`int`, *optional*, defaults to 32):
|
60 |
+
The number of segment buckets.
|
61 |
+
position_bias_max_distance (`int`, *optional*, defaults to 2048):
|
62 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
63 |
+
just in case (e.g., 512 or 1024 or 2048).
|
64 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
65 |
+
The epsilon used by the layer normalization layers.
|
66 |
+
init_std (`float`, *optional*, defaults to 1.0):
|
67 |
+
Initialize parameters with std = init_std.
|
68 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
69 |
+
Whether to use cache.
|
70 |
+
distance_scale (`float` or `int`, *optional*, defaults to 16):
|
71 |
+
Scale the rotary embedding.
|
72 |
+
mask_modules (`list` or `tuple`, *optional*, defaults to None):
|
73 |
+
Decides which feedforward block or attention block is pruned.
|
74 |
+
half (`bool`, *optional*, defaults to `False`):
|
75 |
+
Decides the model parameters are half-precision or not.
|
76 |
+
|
77 |
+
Example:
|
78 |
+
|
79 |
+
```python
|
80 |
+
>>> from transformers import CpmBeeModel, CpmBeeConfig
|
81 |
+
|
82 |
+
>>> # Initializing a CPMBee cpm-bee-10b style configuration
|
83 |
+
>>> configuration = CpmBeeConfig()
|
84 |
+
|
85 |
+
>>> # Initializing a model from the cpm-bee-10b style configuration
|
86 |
+
>>> model = CpmBeeModel(configuration)
|
87 |
+
|
88 |
+
>>> # Accessing the model configuration
|
89 |
+
>>> configuration = model.config
|
90 |
+
```"""
|
91 |
+
model_type = "viscpmchatbee"
|
92 |
+
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
vocab_size: int = 30720,
|
96 |
+
hidden_size: int = 4096,
|
97 |
+
num_attention_heads: int = 64,
|
98 |
+
dim_head: int = 64,
|
99 |
+
dim_ff: int = 10240,
|
100 |
+
num_hidden_layers: int = 32,
|
101 |
+
dropout_p: int = 0.0,
|
102 |
+
position_bias_num_buckets: int = 256,
|
103 |
+
position_bias_num_segment_buckets: int = 32,
|
104 |
+
position_bias_max_distance: int = 2048,
|
105 |
+
eps: int = 1e-6,
|
106 |
+
init_std: float = 1.0,
|
107 |
+
use_cache: bool = True,
|
108 |
+
distance_scale: Union[int, float] = 16,
|
109 |
+
mask_modules: Optional[Union[List, Tuple]] = None,
|
110 |
+
half: bool = False,
|
111 |
+
vision_dim: int = 1024,
|
112 |
+
query_num: int = 64,
|
113 |
+
**kwargs,
|
114 |
+
):
|
115 |
+
super().__init__(**kwargs)
|
116 |
+
self.position_bias_num_segment_buckets = position_bias_num_segment_buckets
|
117 |
+
self.hidden_size = hidden_size
|
118 |
+
self.num_attention_heads = num_attention_heads
|
119 |
+
self.dim_head = dim_head
|
120 |
+
self.dim_ff = dim_ff
|
121 |
+
self.num_hidden_layers = num_hidden_layers
|
122 |
+
self.position_bias_num_buckets = position_bias_num_buckets
|
123 |
+
self.position_bias_max_distance = position_bias_max_distance
|
124 |
+
self.dropout_p = dropout_p
|
125 |
+
self.eps = eps
|
126 |
+
self.use_cache = use_cache
|
127 |
+
self.vocab_size = vocab_size
|
128 |
+
self.init_std = init_std
|
129 |
+
self.distance_scale = distance_scale
|
130 |
+
self.half = half
|
131 |
+
self.mask_modules = mask_modules
|
132 |
+
self.vision_dim = vision_dim
|
133 |
+
self.query_num = query_num
|
feature_extraction_viscpmchatbee.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
from transformers.utils import logging
|
4 |
+
from processing_viscpmchatbee import VisCpmChatBeeImageProcessor
|
5 |
+
|
6 |
+
|
7 |
+
logger = logging.get_logger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
class VisCpmChatBeeFeatureExtractor(VisCpmChatBeeImageProcessor):
|
11 |
+
def __init__(self, *args, **kwargs) -> None:
|
12 |
+
warnings.warn(
|
13 |
+
"The class VisCpmBeeFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
|
14 |
+
" use CLIPImageProcessor instead.",
|
15 |
+
FutureWarning,
|
16 |
+
)
|
17 |
+
super().__init__(*args, **kwargs)
|
generation_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"num_beams": 3,
|
3 |
+
"num_beam_groups": 1,
|
4 |
+
"do_sample": false,
|
5 |
+
"is_constraint_gen_mode": false,
|
6 |
+
"is_contrastive_search_gen_mode": false,
|
7 |
+
"pad_token_id": 0,
|
8 |
+
"eos_token_id": 7,
|
9 |
+
"bos_token_id": 6,
|
10 |
+
"max_new_tokens": 100,
|
11 |
+
"vocab_size": 86583
|
12 |
+
}
|
modeling_cpmbee.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
preprocessor_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"image_processor_type": "VisCpmChatBeeImageProcessor",
|
3 |
+
"is_train": false,
|
4 |
+
"randaug": false,
|
5 |
+
"input_size": 224,
|
6 |
+
"interpolation": "bicubic",
|
7 |
+
"auto_map": {
|
8 |
+
"AutoImageProcessor": "processing_viscpmchatbee.VisCpmChatBeeImageProcessor"
|
9 |
+
}
|
10 |
+
}
|
processing_viscpmchatbee.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
5 |
+
from timm.data.transforms import RandomResizedCropAndInterpolation
|
6 |
+
from torchvision import transforms
|
7 |
+
import urllib
|
8 |
+
from tqdm import tqdm
|
9 |
+
from cpm_live.tokenizers import CPMBeeTokenizer
|
10 |
+
from torch.utils.data import default_collate
|
11 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
12 |
+
from typing_extensions import TypedDict
|
13 |
+
from numpy.typing import NDArray
|
14 |
+
import importlib.machinery
|
15 |
+
import importlib.util
|
16 |
+
import types
|
17 |
+
import random
|
18 |
+
from transformers.image_utils import make_list_of_images
|
19 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
20 |
+
from transformers import TensorType
|
21 |
+
import json
|
22 |
+
|
23 |
+
|
24 |
+
# aug functions
|
25 |
+
def identity_func(img):
|
26 |
+
return img
|
27 |
+
|
28 |
+
|
29 |
+
def autocontrast_func(img, cutoff=0):
|
30 |
+
'''
|
31 |
+
same output as PIL.ImageOps.autocontrast
|
32 |
+
'''
|
33 |
+
n_bins = 256
|
34 |
+
|
35 |
+
def tune_channel(ch):
|
36 |
+
n = ch.size
|
37 |
+
cut = cutoff * n // 100
|
38 |
+
if cut == 0:
|
39 |
+
high, low = ch.max(), ch.min()
|
40 |
+
else:
|
41 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
42 |
+
low = np.argwhere(np.cumsum(hist) > cut)
|
43 |
+
low = 0 if low.shape[0] == 0 else low[0]
|
44 |
+
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
|
45 |
+
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
|
46 |
+
if high <= low:
|
47 |
+
table = np.arange(n_bins)
|
48 |
+
else:
|
49 |
+
scale = (n_bins - 1) / (high - low)
|
50 |
+
table = np.arange(n_bins) * scale - low * scale
|
51 |
+
table[table < 0] = 0
|
52 |
+
table[table > n_bins - 1] = n_bins - 1
|
53 |
+
table = table.clip(0, 255).astype(np.uint8)
|
54 |
+
return table[ch]
|
55 |
+
|
56 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
57 |
+
out = cv2.merge(channels)
|
58 |
+
return out
|
59 |
+
|
60 |
+
|
61 |
+
def equalize_func(img):
|
62 |
+
'''
|
63 |
+
same output as PIL.ImageOps.equalize
|
64 |
+
PIL's implementation is different from cv2.equalize
|
65 |
+
'''
|
66 |
+
n_bins = 256
|
67 |
+
|
68 |
+
def tune_channel(ch):
|
69 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
70 |
+
non_zero_hist = hist[hist != 0].reshape(-1)
|
71 |
+
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
|
72 |
+
if step == 0:
|
73 |
+
return ch
|
74 |
+
n = np.empty_like(hist)
|
75 |
+
n[0] = step // 2
|
76 |
+
n[1:] = hist[:-1]
|
77 |
+
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
|
78 |
+
return table[ch]
|
79 |
+
|
80 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
81 |
+
out = cv2.merge(channels)
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
|
86 |
+
'''
|
87 |
+
like PIL, rotate by degree, not radians
|
88 |
+
'''
|
89 |
+
H, W = img.shape[0], img.shape[1]
|
90 |
+
center = W / 2, H / 2
|
91 |
+
M = cv2.getRotationMatrix2D(center, degree, 1)
|
92 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
|
93 |
+
return out
|
94 |
+
|
95 |
+
|
96 |
+
def solarize_func(img, thresh=128):
|
97 |
+
'''
|
98 |
+
same output as PIL.ImageOps.posterize
|
99 |
+
'''
|
100 |
+
table = np.array([el if el < thresh else 255 - el for el in range(256)])
|
101 |
+
table = table.clip(0, 255).astype(np.uint8)
|
102 |
+
out = table[img]
|
103 |
+
return out
|
104 |
+
|
105 |
+
|
106 |
+
def color_func(img, factor):
|
107 |
+
'''
|
108 |
+
same output as PIL.ImageEnhance.Color
|
109 |
+
'''
|
110 |
+
# implementation according to PIL definition, quite slow
|
111 |
+
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
|
112 |
+
# out = blend(degenerate, img, factor)
|
113 |
+
# M = (
|
114 |
+
# np.eye(3) * factor
|
115 |
+
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
|
116 |
+
# )[np.newaxis, np.newaxis, :]
|
117 |
+
M = (
|
118 |
+
np.float32([
|
119 |
+
[0.886, -0.114, -0.114],
|
120 |
+
[-0.587, 0.413, -0.587],
|
121 |
+
[-0.299, -0.299, 0.701]]) * factor
|
122 |
+
+ np.float32([[0.114], [0.587], [0.299]])
|
123 |
+
)
|
124 |
+
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
|
125 |
+
return out
|
126 |
+
|
127 |
+
|
128 |
+
def contrast_func(img, factor):
|
129 |
+
"""
|
130 |
+
same output as PIL.ImageEnhance.Contrast
|
131 |
+
"""
|
132 |
+
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
|
133 |
+
table = np.array([(
|
134 |
+
el - mean) * factor + mean
|
135 |
+
for el in range(256)
|
136 |
+
]).clip(0, 255).astype(np.uint8)
|
137 |
+
out = table[img]
|
138 |
+
return out
|
139 |
+
|
140 |
+
|
141 |
+
def brightness_func(img, factor):
|
142 |
+
'''
|
143 |
+
same output as PIL.ImageEnhance.Contrast
|
144 |
+
'''
|
145 |
+
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
|
146 |
+
out = table[img]
|
147 |
+
return out
|
148 |
+
|
149 |
+
|
150 |
+
def sharpness_func(img, factor):
|
151 |
+
'''
|
152 |
+
The differences the this result and PIL are all on the 4 boundaries, the center
|
153 |
+
areas are same
|
154 |
+
'''
|
155 |
+
kernel = np.ones((3, 3), dtype=np.float32)
|
156 |
+
kernel[1][1] = 5
|
157 |
+
kernel /= 13
|
158 |
+
degenerate = cv2.filter2D(img, -1, kernel)
|
159 |
+
if factor == 0.0:
|
160 |
+
out = degenerate
|
161 |
+
elif factor == 1.0:
|
162 |
+
out = img
|
163 |
+
else:
|
164 |
+
out = img.astype(np.float32)
|
165 |
+
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
|
166 |
+
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
|
167 |
+
out = out.astype(np.uint8)
|
168 |
+
return out
|
169 |
+
|
170 |
+
|
171 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
|
172 |
+
H, W = img.shape[0], img.shape[1]
|
173 |
+
M = np.float32([[1, factor, 0], [0, 1, 0]])
|
174 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
175 |
+
return out
|
176 |
+
|
177 |
+
|
178 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
|
179 |
+
'''
|
180 |
+
same output as PIL.Image.transform
|
181 |
+
'''
|
182 |
+
H, W = img.shape[0], img.shape[1]
|
183 |
+
M = np.float32([[1, 0, -offset], [0, 1, 0]])
|
184 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
185 |
+
return out
|
186 |
+
|
187 |
+
|
188 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
|
189 |
+
'''
|
190 |
+
same output as PIL.Image.transform
|
191 |
+
'''
|
192 |
+
H, W = img.shape[0], img.shape[1]
|
193 |
+
M = np.float32([[1, 0, 0], [0, 1, -offset]])
|
194 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
195 |
+
return out
|
196 |
+
|
197 |
+
|
198 |
+
def posterize_func(img, bits):
|
199 |
+
'''
|
200 |
+
same output as PIL.ImageOps.posterize
|
201 |
+
'''
|
202 |
+
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
|
203 |
+
return out
|
204 |
+
|
205 |
+
|
206 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
|
207 |
+
H, W = img.shape[0], img.shape[1]
|
208 |
+
M = np.float32([[1, 0, 0], [factor, 1, 0]])
|
209 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
210 |
+
return out
|
211 |
+
|
212 |
+
|
213 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
|
214 |
+
replace = np.array(replace, dtype=np.uint8)
|
215 |
+
H, W = img.shape[0], img.shape[1]
|
216 |
+
rh, rw = np.random.random(2)
|
217 |
+
pad_size = pad_size // 2
|
218 |
+
ch, cw = int(rh * H), int(rw * W)
|
219 |
+
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
|
220 |
+
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
|
221 |
+
out = img.copy()
|
222 |
+
out[x1:x2, y1:y2, :] = replace
|
223 |
+
return out
|
224 |
+
|
225 |
+
|
226 |
+
# level to args
|
227 |
+
def enhance_level_to_args(MAX_LEVEL):
|
228 |
+
def level_to_args(level):
|
229 |
+
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
|
230 |
+
return level_to_args
|
231 |
+
|
232 |
+
|
233 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
|
234 |
+
def level_to_args(level):
|
235 |
+
level = (level / MAX_LEVEL) * 0.3
|
236 |
+
if np.random.random() > 0.5:
|
237 |
+
level = -level
|
238 |
+
return (level, replace_value)
|
239 |
+
|
240 |
+
return level_to_args
|
241 |
+
|
242 |
+
|
243 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
|
244 |
+
def level_to_args(level):
|
245 |
+
level = (level / MAX_LEVEL) * float(translate_const)
|
246 |
+
if np.random.random() > 0.5:
|
247 |
+
level = -level
|
248 |
+
return (level, replace_value)
|
249 |
+
|
250 |
+
return level_to_args
|
251 |
+
|
252 |
+
|
253 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
|
254 |
+
def level_to_args(level):
|
255 |
+
level = int((level / MAX_LEVEL) * cutout_const)
|
256 |
+
return (level, replace_value)
|
257 |
+
|
258 |
+
return level_to_args
|
259 |
+
|
260 |
+
|
261 |
+
def solarize_level_to_args(MAX_LEVEL):
|
262 |
+
def level_to_args(level):
|
263 |
+
level = int((level / MAX_LEVEL) * 256)
|
264 |
+
return (level, )
|
265 |
+
return level_to_args
|
266 |
+
|
267 |
+
|
268 |
+
def none_level_to_args(level):
|
269 |
+
return ()
|
270 |
+
|
271 |
+
|
272 |
+
def posterize_level_to_args(MAX_LEVEL):
|
273 |
+
def level_to_args(level):
|
274 |
+
level = int((level / MAX_LEVEL) * 4)
|
275 |
+
return (level, )
|
276 |
+
return level_to_args
|
277 |
+
|
278 |
+
|
279 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
|
280 |
+
def level_to_args(level):
|
281 |
+
level = (level / MAX_LEVEL) * 30
|
282 |
+
if np.random.random() < 0.5:
|
283 |
+
level = -level
|
284 |
+
return (level, replace_value)
|
285 |
+
|
286 |
+
return level_to_args
|
287 |
+
|
288 |
+
|
289 |
+
func_dict = {
|
290 |
+
'Identity': identity_func,
|
291 |
+
'AutoContrast': autocontrast_func,
|
292 |
+
'Equalize': equalize_func,
|
293 |
+
'Rotate': rotate_func,
|
294 |
+
'Solarize': solarize_func,
|
295 |
+
'Color': color_func,
|
296 |
+
'Contrast': contrast_func,
|
297 |
+
'Brightness': brightness_func,
|
298 |
+
'Sharpness': sharpness_func,
|
299 |
+
'ShearX': shear_x_func,
|
300 |
+
'TranslateX': translate_x_func,
|
301 |
+
'TranslateY': translate_y_func,
|
302 |
+
'Posterize': posterize_func,
|
303 |
+
'ShearY': shear_y_func,
|
304 |
+
}
|
305 |
+
|
306 |
+
translate_const = 10
|
307 |
+
MAX_LEVEL = 10
|
308 |
+
replace_value = (128, 128, 128)
|
309 |
+
arg_dict = {
|
310 |
+
'Identity': none_level_to_args,
|
311 |
+
'AutoContrast': none_level_to_args,
|
312 |
+
'Equalize': none_level_to_args,
|
313 |
+
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
|
314 |
+
'Solarize': solarize_level_to_args(MAX_LEVEL),
|
315 |
+
'Color': enhance_level_to_args(MAX_LEVEL),
|
316 |
+
'Contrast': enhance_level_to_args(MAX_LEVEL),
|
317 |
+
'Brightness': enhance_level_to_args(MAX_LEVEL),
|
318 |
+
'Sharpness': enhance_level_to_args(MAX_LEVEL),
|
319 |
+
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
|
320 |
+
'TranslateX': translate_level_to_args(
|
321 |
+
translate_const, MAX_LEVEL, replace_value
|
322 |
+
),
|
323 |
+
'TranslateY': translate_level_to_args(
|
324 |
+
translate_const, MAX_LEVEL, replace_value
|
325 |
+
),
|
326 |
+
'Posterize': posterize_level_to_args(MAX_LEVEL),
|
327 |
+
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
|
328 |
+
}
|
329 |
+
|
330 |
+
|
331 |
+
class RandomAugment(object):
|
332 |
+
|
333 |
+
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
|
334 |
+
self.N = N
|
335 |
+
self.M = M
|
336 |
+
self.isPIL = isPIL
|
337 |
+
if augs:
|
338 |
+
self.augs = augs
|
339 |
+
else:
|
340 |
+
self.augs = list(arg_dict.keys())
|
341 |
+
|
342 |
+
def get_random_ops(self):
|
343 |
+
sampled_ops = np.random.choice(self.augs, self.N)
|
344 |
+
return [(op, 0.5, self.M) for op in sampled_ops]
|
345 |
+
|
346 |
+
def __call__(self, img):
|
347 |
+
if self.isPIL:
|
348 |
+
img = np.array(img)
|
349 |
+
ops = self.get_random_ops()
|
350 |
+
for name, prob, level in ops:
|
351 |
+
if np.random.random() > prob:
|
352 |
+
continue
|
353 |
+
args = arg_dict[name](level)
|
354 |
+
img = func_dict[name](img, *args)
|
355 |
+
return img
|
356 |
+
|
357 |
+
|
358 |
+
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic'):
|
359 |
+
if is_train:
|
360 |
+
t = [
|
361 |
+
RandomResizedCropAndInterpolation(
|
362 |
+
input_size, scale=(0.5, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
|
363 |
+
transforms.RandomHorizontalFlip(),
|
364 |
+
]
|
365 |
+
if randaug:
|
366 |
+
t.append(
|
367 |
+
RandomAugment(
|
368 |
+
2, 7, isPIL=True,
|
369 |
+
augs=[
|
370 |
+
'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
|
371 |
+
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
|
372 |
+
]))
|
373 |
+
t += [
|
374 |
+
transforms.ToTensor(),
|
375 |
+
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
376 |
+
]
|
377 |
+
t = transforms.Compose(t)
|
378 |
+
else:
|
379 |
+
t = transforms.Compose([
|
380 |
+
transforms.Resize((input_size, input_size),
|
381 |
+
interpolation=transforms.InterpolationMode.BICUBIC),
|
382 |
+
transforms.ToTensor(),
|
383 |
+
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
|
384 |
+
])
|
385 |
+
|
386 |
+
return t
|
387 |
+
|
388 |
+
|
389 |
+
class VisCpmChatBeeImageProcessor(BaseImageProcessor):
|
390 |
+
def __init__(self, is_train, randaug=True, input_size=224, interpolation='bicubic', **kwargs):
|
391 |
+
super().__init__(**kwargs)
|
392 |
+
self.is_train = is_train
|
393 |
+
self.randaug = randaug
|
394 |
+
self.input_size = input_size
|
395 |
+
self.interpolation = interpolation
|
396 |
+
self._transform = build_transform(is_train, randaug=randaug, input_size=input_size, interpolation=interpolation)
|
397 |
+
|
398 |
+
def preprocess(self, images, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs) -> BatchFeature:
|
399 |
+
images = make_list_of_images(images)
|
400 |
+
images = [self._transform(image) for image in images]
|
401 |
+
images = torch.tensor([image.numpy() for image in images])
|
402 |
+
|
403 |
+
data = {"pixel_values": images}
|
404 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
405 |
+
|
406 |
+
def to_json_string(self) -> str:
|
407 |
+
"""
|
408 |
+
Serializes this instance to a JSON string.
|
409 |
+
|
410 |
+
Returns:
|
411 |
+
`str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
|
412 |
+
"""
|
413 |
+
dictionary = self.to_dict()
|
414 |
+
|
415 |
+
for key, value in dictionary.items():
|
416 |
+
if isinstance(value, np.ndarray):
|
417 |
+
dictionary[key] = value.tolist()
|
418 |
+
|
419 |
+
# make sure private name "_processor_class" is correctly
|
420 |
+
# saved as "processor_class"
|
421 |
+
_processor_class = dictionary.pop("_processor_class", None)
|
422 |
+
if _processor_class is not None:
|
423 |
+
dictionary["processor_class"] = _processor_class
|
424 |
+
_transform = dictionary.pop("_transform", None)
|
425 |
+
if _transform is not None:
|
426 |
+
dictionary["_transform"] = str(type(_transform))
|
427 |
+
|
428 |
+
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
|
tokenization_viscpmchatbee.py
ADDED
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes for CpmBee."""
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from numpy.typing import NDArray
|
22 |
+
from typing_extensions import TypedDict
|
23 |
+
|
24 |
+
from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
|
25 |
+
from transformers.tokenization_utils_base import AddedToken, BatchEncoding, TextInput, TruncationStrategy
|
26 |
+
from transformers.utils import logging
|
27 |
+
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
32 |
+
|
33 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
34 |
+
"vocab_file": {
|
35 |
+
"openbmb/viscpmchat-bee-10b": "https://huggingface.co/openbmb/VisCPM-Chat/blob/main/vocab.txt",
|
36 |
+
},
|
37 |
+
}
|
38 |
+
|
39 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
40 |
+
"openbmb/viscpmchat-bee-10b": 4096,
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
class _PrevExtTableStates(TypedDict):
|
45 |
+
ext_table: Dict[int, str]
|
46 |
+
token_id_table: Dict[str, Dict[int, int]]
|
47 |
+
|
48 |
+
|
49 |
+
CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
|
50 |
+
|
51 |
+
|
52 |
+
def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
|
53 |
+
ret = n_up * max_depth + n_down
|
54 |
+
if ret == 0:
|
55 |
+
return ret
|
56 |
+
else:
|
57 |
+
# bucket 1 is reserved for incontext samples
|
58 |
+
return ret + 1
|
59 |
+
|
60 |
+
|
61 |
+
class _DictTree(TypedDict):
|
62 |
+
value: str
|
63 |
+
children: List["_DictTree"]
|
64 |
+
depth: int
|
65 |
+
segment_id: int
|
66 |
+
need_predict: bool
|
67 |
+
is_image: bool
|
68 |
+
|
69 |
+
|
70 |
+
class VisCpmChatBeeTokenizer(PreTrainedTokenizer):
|
71 |
+
"""
|
72 |
+
Construct a CPMBee tokenizer.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
vocab_file (`str`):
|
76 |
+
Path to the vocabulary file.
|
77 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
78 |
+
The beginning of sequence token.
|
79 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
80 |
+
The end of sequence token.
|
81 |
+
line_token (`str`, *optional*, defaults to `"\n"`):
|
82 |
+
The line token.
|
83 |
+
space_token (`str`, *optional*, defaults to `" "`):
|
84 |
+
The space token.
|
85 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
86 |
+
The unknown token.
|
87 |
+
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
88 |
+
The mask token.
|
89 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
90 |
+
The token used for padding.
|
91 |
+
padding_side (`str`, *optional*, defaults to `"left"`):
|
92 |
+
The padding side. CPM-Bee will use left padding by default.
|
93 |
+
"""
|
94 |
+
|
95 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
96 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
97 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
98 |
+
model_input_names: List[str] = [
|
99 |
+
"input_ids",
|
100 |
+
"attention_mask",
|
101 |
+
"input_id_sub",
|
102 |
+
"position",
|
103 |
+
"context",
|
104 |
+
"sample_ids",
|
105 |
+
"num_segments",
|
106 |
+
"segment",
|
107 |
+
"segment_rel_offset",
|
108 |
+
"segment_rel",
|
109 |
+
]
|
110 |
+
add_prefix_space = False
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
vocab_file,
|
115 |
+
bos_token="<s>",
|
116 |
+
eos_token="</s>",
|
117 |
+
line_token="\n",
|
118 |
+
space_token=" ",
|
119 |
+
unk_token="<unk>",
|
120 |
+
mask_token="<mask>",
|
121 |
+
pad_token="<pad>",
|
122 |
+
padding_side="left",
|
123 |
+
**kwargs,
|
124 |
+
):
|
125 |
+
super().__init__(
|
126 |
+
bos_token=bos_token,
|
127 |
+
eos_token=eos_token,
|
128 |
+
line_token=line_token,
|
129 |
+
space_token=space_token,
|
130 |
+
unk_token=unk_token,
|
131 |
+
mask_token=mask_token,
|
132 |
+
pad_token=pad_token,
|
133 |
+
padding_side=padding_side,
|
134 |
+
**kwargs,
|
135 |
+
)
|
136 |
+
|
137 |
+
self.encoder: Dict[str, int] = {}
|
138 |
+
|
139 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
140 |
+
for token in reader.readlines():
|
141 |
+
token = token.rstrip("\n")
|
142 |
+
if len(token) == 0:
|
143 |
+
continue
|
144 |
+
self.encoder[token] = len(self.encoder)
|
145 |
+
|
146 |
+
self.encoder[" "] = self.encoder["</_>"]
|
147 |
+
self.encoder["\n"] = self.encoder["</n>"]
|
148 |
+
del self.encoder["</_>"]
|
149 |
+
del self.encoder["</n>"]
|
150 |
+
|
151 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
152 |
+
|
153 |
+
self._max_word_len = max([len(x) for x in self.encoder.keys()])
|
154 |
+
self.cpmbee_special_tokens = {k: v for k, v in self.encoder.items() if k.startswith("<") and k.endswith(">")}
|
155 |
+
|
156 |
+
self.ext_table: Dict[int, str] = {}
|
157 |
+
self.ext_table_rev: Dict[str, int] = {}
|
158 |
+
|
159 |
+
self.token_id_table: Dict[str, Dict[int, int]] = {}
|
160 |
+
self.ext_special_tokens = []
|
161 |
+
|
162 |
+
self.ext_args_for_model = [
|
163 |
+
"input_id_subs",
|
164 |
+
"input_pos",
|
165 |
+
"context",
|
166 |
+
"segment_ids",
|
167 |
+
"segment_rel_offset",
|
168 |
+
"segment_rel",
|
169 |
+
"sample_ids",
|
170 |
+
"num_segments",
|
171 |
+
"predict_segments",
|
172 |
+
"answer_placeholders",
|
173 |
+
"ext_table",
|
174 |
+
"token_id_table",
|
175 |
+
"image_bound"
|
176 |
+
]
|
177 |
+
|
178 |
+
@property
|
179 |
+
def bod_token_id(self):
|
180 |
+
return self.encoder[self.bod_token]
|
181 |
+
|
182 |
+
@property
|
183 |
+
def eod_token_id(self):
|
184 |
+
return self.encoder[self.eod_token]
|
185 |
+
|
186 |
+
@property
|
187 |
+
def newline_id(self):
|
188 |
+
return self.encoder[self.line_token]
|
189 |
+
|
190 |
+
@property
|
191 |
+
def vocab_size(self) -> int:
|
192 |
+
return len(self.encoder)
|
193 |
+
|
194 |
+
def __len__(self):
|
195 |
+
"""
|
196 |
+
Size of the full vocabulary with the added tokens.
|
197 |
+
"""
|
198 |
+
return self.vocab_size + len(self.added_tokens_encoder)
|
199 |
+
|
200 |
+
def get_vocab(self):
|
201 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
202 |
+
|
203 |
+
def get_piece(self, text: str) -> str:
|
204 |
+
"""
|
205 |
+
Match with maximum length.
|
206 |
+
"""
|
207 |
+
len_text = len(text)
|
208 |
+
for i in range(len(text)):
|
209 |
+
sub = text[: len_text - i]
|
210 |
+
if (sub in self.encoder) or (sub in self.added_tokens_encoder):
|
211 |
+
return sub
|
212 |
+
return text[0]
|
213 |
+
|
214 |
+
def tokenize(self, text: TextInput, **kwargs) -> List[str]:
|
215 |
+
r"""
|
216 |
+
Override the `tokenize` to meet the needs of CPMBee:
|
217 |
+
1. Mark the special token with `<` and `>`. The `<>` will be ignored.
|
218 |
+
2. Split sentences by the marked special tokens.
|
219 |
+
3. Record the marked special token by `ext_table` and `ext_table_rev`.
|
220 |
+
4. Tokenize the sentence without special tokens.
|
221 |
+
"""
|
222 |
+
for_cpmbee = kwargs.get("for_cpmbee", False)
|
223 |
+
all_special_tokens_extended = {
|
224 |
+
str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
|
225 |
+
}
|
226 |
+
|
227 |
+
sentence_split = [""]
|
228 |
+
is_special_token = False
|
229 |
+
for i, c in enumerate(text):
|
230 |
+
if is_special_token:
|
231 |
+
if c == "<":
|
232 |
+
tail = sentence_split.pop(-1)
|
233 |
+
sentence_split[-1] += tail
|
234 |
+
sentence_split.append(c)
|
235 |
+
elif c == ">":
|
236 |
+
# end of special token
|
237 |
+
sentence_split[-1] += c
|
238 |
+
if sentence_split[-1] == "<>":
|
239 |
+
continue
|
240 |
+
is_special_token = False
|
241 |
+
sentence_split.append("")
|
242 |
+
else:
|
243 |
+
sentence_split[-1] += c
|
244 |
+
else:
|
245 |
+
if c == "<":
|
246 |
+
is_special_token = True
|
247 |
+
sentence_split.append(c)
|
248 |
+
else:
|
249 |
+
sentence_split[-1] += c
|
250 |
+
if is_special_token:
|
251 |
+
tail = sentence_split.pop(-1)
|
252 |
+
sentence_split[-1] += tail
|
253 |
+
|
254 |
+
output_tokens = []
|
255 |
+
for i, part in enumerate(sentence_split):
|
256 |
+
if (i & 1) == 1:
|
257 |
+
# special token
|
258 |
+
output_tokens.append(part)
|
259 |
+
if for_cpmbee and (part not in self.encoder) and (part not in self.ext_table_rev):
|
260 |
+
self.ext_table_rev[part] = len(self.ext_table_rev) + self.vocab_size
|
261 |
+
self.ext_table[self.ext_table_rev[part]] = part
|
262 |
+
else:
|
263 |
+
output_tokens.extend(self._tokenize(part, for_cpmbee=for_cpmbee))
|
264 |
+
|
265 |
+
# drop spaces
|
266 |
+
for i, token in enumerate(output_tokens):
|
267 |
+
if token in self.added_tokens_encoder:
|
268 |
+
token = all_special_tokens_extended.get(token, None)
|
269 |
+
left = output_tokens[i - 1] if i > 0 else None
|
270 |
+
right = output_tokens[i + 1] if i < len(output_tokens) - 1 else None
|
271 |
+
if isinstance(token, AddedToken):
|
272 |
+
if token.rstrip and right:
|
273 |
+
# A bit counter-intuitive but we strip the left of the string
|
274 |
+
# since tok_extended.rstrip means the special token is eating all white spaces on its right
|
275 |
+
output_tokens[i + 1] = right.lstrip()
|
276 |
+
# Strip white spaces on the left
|
277 |
+
if token.lstrip and left:
|
278 |
+
output_tokens[i - 1] = left.rstrip() # Opposite here
|
279 |
+
else:
|
280 |
+
if right:
|
281 |
+
output_tokens[i + 1] = right.lstrip()
|
282 |
+
if left:
|
283 |
+
output_tokens[i - 1] = left.rstrip()
|
284 |
+
|
285 |
+
skipped_tokens = []
|
286 |
+
for token in output_tokens:
|
287 |
+
if not token:
|
288 |
+
continue
|
289 |
+
else:
|
290 |
+
skipped_tokens.append(token)
|
291 |
+
|
292 |
+
return skipped_tokens
|
293 |
+
|
294 |
+
def _tokenize(self, text, **kwargs):
|
295 |
+
"""
|
296 |
+
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
297 |
+
vocabulary.
|
298 |
+
|
299 |
+
Do NOT take care of added tokens. Record the unk tokens and special tokens in `ext_table` and `ext_table_rev`.
|
300 |
+
"""
|
301 |
+
for_cpmbee = kwargs.get("for_cpmbee", False)
|
302 |
+
output_tokens = []
|
303 |
+
|
304 |
+
part_st = 0
|
305 |
+
last_unk = None
|
306 |
+
while part_st < len(text):
|
307 |
+
piece = self.get_piece(text[part_st:])
|
308 |
+
if piece in self.encoder or self.added_tokens_encoder:
|
309 |
+
if last_unk is None:
|
310 |
+
output_tokens.append(piece)
|
311 |
+
else:
|
312 |
+
if for_cpmbee and (last_unk not in self.ext_table_rev):
|
313 |
+
self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
|
314 |
+
self.ext_table[self.ext_table_rev[last_unk]] = last_unk
|
315 |
+
output_tokens.append(last_unk)
|
316 |
+
output_tokens.append(piece)
|
317 |
+
last_unk = None
|
318 |
+
else:
|
319 |
+
if last_unk is None:
|
320 |
+
last_unk = piece
|
321 |
+
else:
|
322 |
+
last_unk += piece
|
323 |
+
part_st += len(piece)
|
324 |
+
if last_unk is not None:
|
325 |
+
# part end with UNK
|
326 |
+
if for_cpmbee and (last_unk not in self.ext_table_rev):
|
327 |
+
self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
|
328 |
+
self.ext_table[self.ext_table_rev[last_unk]] = last_unk
|
329 |
+
output_tokens.append(last_unk)
|
330 |
+
|
331 |
+
return output_tokens
|
332 |
+
|
333 |
+
def check(self, token):
|
334 |
+
return token in self.encoder
|
335 |
+
|
336 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
337 |
+
return "".join(tokens)
|
338 |
+
|
339 |
+
def _convert_token_to_id(self, token: str):
|
340 |
+
"""Converts a token (str) in an id using the vocab and ext_table."""
|
341 |
+
if token in self.encoder:
|
342 |
+
return self.encoder.get(token)
|
343 |
+
elif token in self.ext_table_rev:
|
344 |
+
return self.ext_table_rev[token]
|
345 |
+
elif token in self.added_tokens_encoder:
|
346 |
+
return self.added_tokens_encoder[token]
|
347 |
+
else:
|
348 |
+
return self.unk_token_id
|
349 |
+
|
350 |
+
def _convert_id_to_token(self, index):
|
351 |
+
"""Converts an index (integer) in a token (str) using the vocab and ext_table."""
|
352 |
+
if index in self.ext_table:
|
353 |
+
return self.ext_table[index]
|
354 |
+
elif index in self.added_tokens_decoder:
|
355 |
+
return self.added_tokens_decoder[index]
|
356 |
+
else:
|
357 |
+
if index >= 0:
|
358 |
+
return self.decoder[index]
|
359 |
+
|
360 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
361 |
+
if os.path.isdir(save_directory):
|
362 |
+
vocab_file = os.path.join(
|
363 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
364 |
+
)
|
365 |
+
else:
|
366 |
+
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
367 |
+
index = 0
|
368 |
+
self.encoder["</n>"] = self.encoder["\n"]
|
369 |
+
del self.encoder["\n"]
|
370 |
+
self.encoder["</_>"] = self.encoder[" "]
|
371 |
+
del self.encoder[" "]
|
372 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
373 |
+
for token, token_index in sorted(self.encoder.items(), key=lambda x: x[1]):
|
374 |
+
if index != token_index:
|
375 |
+
logger.warning(
|
376 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
377 |
+
" Please check that the vocabulary is not corrupted!"
|
378 |
+
)
|
379 |
+
index = token_index
|
380 |
+
writer.write(token + "\n")
|
381 |
+
index += 1
|
382 |
+
return (vocab_file,)
|
383 |
+
|
384 |
+
def __call__(self, text, *args, **kwargs):
|
385 |
+
r"""
|
386 |
+
CPMBee `call` method will use `_tokenize_cpmbee` when the input type is dict.
|
387 |
+
"""
|
388 |
+
if isinstance(text, dict):
|
389 |
+
return self._batch_tokenize_cpmbee([text], *args, **kwargs)
|
390 |
+
elif isinstance(text, (list, tuple)):
|
391 |
+
if isinstance(text[0], dict):
|
392 |
+
return self._batch_tokenize_cpmbee(text, *args, **kwargs)
|
393 |
+
else:
|
394 |
+
return super().__call__(text, *args, **kwargs)
|
395 |
+
else:
|
396 |
+
return super().__call__(text, *args, **kwargs)
|
397 |
+
|
398 |
+
# 分词
|
399 |
+
def _tokenize_cpmbee(self, data: TextInput, *args, **kwargs) -> List[str]:
|
400 |
+
"""
|
401 |
+
A tokenize method to process dict data. Exclusive for CPMBee.
|
402 |
+
"""
|
403 |
+
if isinstance(data, str):
|
404 |
+
data = json.loads(data)
|
405 |
+
if not isinstance(data, Dict):
|
406 |
+
raise TypeError(
|
407 |
+
"CpmBeeTokenizer input data should be dict or str in dict format, but got {}".format(type(data))
|
408 |
+
)
|
409 |
+
|
410 |
+
# 1. prepare answer placeholder
|
411 |
+
answer_placeholders = []
|
412 |
+
|
413 |
+
def _put_placeholder(data: Any, path: List[str] = []):
|
414 |
+
if isinstance(data, dict):
|
415 |
+
ret = {}
|
416 |
+
for k, v in data.items():
|
417 |
+
ret[k] = _put_placeholder(v, path + [k])
|
418 |
+
return ret
|
419 |
+
else:
|
420 |
+
answer_placeholders.append(path)
|
421 |
+
return "<ans_{}>".format(len(answer_placeholders))
|
422 |
+
|
423 |
+
data["<ans>"] = _put_placeholder(data["<ans>"])
|
424 |
+
|
425 |
+
(
|
426 |
+
input_ids,
|
427 |
+
input_id_subs,
|
428 |
+
context,
|
429 |
+
segment_ids,
|
430 |
+
segment_rel,
|
431 |
+
n_segments,
|
432 |
+
table_states,
|
433 |
+
image_bound
|
434 |
+
) = self.convert_data_to_id(data, shuffle_answer=False, max_depth=8)
|
435 |
+
|
436 |
+
# <ans> mapping from sub to id
|
437 |
+
sub_ans_map: Dict[int, int] = {}
|
438 |
+
for fake_id, token_sub in table_states["token_id_table"]["<ans>"].items():
|
439 |
+
token = table_states["ext_table"][fake_id]
|
440 |
+
if token.startswith("<ans_") and token.endswith(">"):
|
441 |
+
ans_id = int(token[5:-1])
|
442 |
+
sub_ans_map[token_sub] = ans_id
|
443 |
+
|
444 |
+
tmp_input_ids = []
|
445 |
+
tmp_input_sub = []
|
446 |
+
tmp_input_seg = []
|
447 |
+
|
448 |
+
# get predict segments
|
449 |
+
predict_segments: List[Tuple[int, int]] = []
|
450 |
+
for i in range(input_ids.shape[0]):
|
451 |
+
if context[i] == 0:
|
452 |
+
if input_ids[i] == self.encoder["<ans>"]:
|
453 |
+
# is ans
|
454 |
+
# (segment_id, ans_id)
|
455 |
+
predict_segments.append((segment_ids[i], sub_ans_map[input_id_subs[i]]))
|
456 |
+
else:
|
457 |
+
tmp_input_ids.append(input_ids[i])
|
458 |
+
tmp_input_sub.append(input_id_subs[i])
|
459 |
+
tmp_input_seg.append(segment_ids[i])
|
460 |
+
|
461 |
+
if len(predict_segments) == 0:
|
462 |
+
raise ValueError("No answer to predict")
|
463 |
+
|
464 |
+
input_ids = np.array(tmp_input_ids, dtype=np.int32) # all context
|
465 |
+
input_id_subs = np.array(tmp_input_sub, dtype=np.int32) # [0, 0, 0, 0, 1, 0, 0, 2, 0, ...]
|
466 |
+
context = np.full_like(tmp_input_ids, 1, dtype=np.int8) # [1, 1, 1, ...]
|
467 |
+
segment_ids = np.array(tmp_input_seg, dtype=np.int32) # [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, ...]
|
468 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, 0, ...]
|
469 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, ...]
|
470 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32) # [n_seg, n_seg, n_seg, ...]
|
471 |
+
input_pos = np.arange(input_ids.shape[0], dtype=np.int32) # [0, 1, 2, 3, 4, ...]
|
472 |
+
image_bound = np.array(image_bound)
|
473 |
+
|
474 |
+
return (
|
475 |
+
self.prepare_for_model(
|
476 |
+
input_ids.tolist(),
|
477 |
+
input_id_subs=input_id_subs.tolist(),
|
478 |
+
input_pos=input_pos.tolist(),
|
479 |
+
context=context.tolist(),
|
480 |
+
segment_ids=segment_ids.tolist(),
|
481 |
+
segment_rel_offset=segment_rel_offset.tolist(),
|
482 |
+
segment_rel=segment_rel.tolist(),
|
483 |
+
sample_ids=sample_ids.tolist(),
|
484 |
+
num_segments=num_segments.tolist(),
|
485 |
+
image_bound=image_bound,
|
486 |
+
**kwargs,
|
487 |
+
),
|
488 |
+
predict_segments,
|
489 |
+
answer_placeholders,
|
490 |
+
table_states["ext_table"],
|
491 |
+
table_states["token_id_table"],
|
492 |
+
)
|
493 |
+
|
494 |
+
def _batch_tokenize_cpmbee(self, data_lst, *args, **kwargs):
|
495 |
+
"""
|
496 |
+
Batched _token_cpmbee.
|
497 |
+
"""
|
498 |
+
device = kwargs.get("device", "cpu")
|
499 |
+
return_tensors = kwargs.get("return_tensors", None)
|
500 |
+
batch_outputs = {}
|
501 |
+
segment_rel_pack = []
|
502 |
+
other_info = []
|
503 |
+
|
504 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
505 |
+
batch_ext_table_ids: List[int] = []
|
506 |
+
batch_ext_table_sub: List[int] = []
|
507 |
+
|
508 |
+
for data in data_lst:
|
509 |
+
self.ext_table = {}
|
510 |
+
self.ext_table_rev = {}
|
511 |
+
self.token_id_table = {}
|
512 |
+
(outputs, predict_segments, answer_placeholders, ext_table, token_id_table) = self._tokenize_cpmbee(
|
513 |
+
data,
|
514 |
+
truncation=None,
|
515 |
+
padding=PaddingStrategy.DO_NOT_PAD.value,
|
516 |
+
max_length=None,
|
517 |
+
pad_to_multiple_of=None,
|
518 |
+
return_attention_mask=False,
|
519 |
+
return_tensors=None,
|
520 |
+
)
|
521 |
+
rev_ext_table = {}
|
522 |
+
for token, mp in token_id_table.items():
|
523 |
+
if token == "<ans>":
|
524 |
+
continue
|
525 |
+
token_id = self.encoder[token]
|
526 |
+
for fake_id, token_sub in mp.items():
|
527 |
+
if token_sub > 0:
|
528 |
+
if (token_id, token_sub) not in batch_ext_table_map:
|
529 |
+
batch_ext_table_map[(token_id, token_sub)] = len(batch_ext_table_ids) + self.vocab_size
|
530 |
+
batch_ext_table_ids.append(token_id)
|
531 |
+
batch_ext_table_sub.append(token_sub)
|
532 |
+
rev_ext_table[batch_ext_table_map[(token_id, token_sub)]] = ext_table[fake_id]
|
533 |
+
else:
|
534 |
+
rev_ext_table[token_id] = ext_table[fake_id]
|
535 |
+
|
536 |
+
segment_rel_pack.append(np.array(outputs.pop("segment_rel")))
|
537 |
+
other_info.append(
|
538 |
+
{
|
539 |
+
"predict_segments": predict_segments,
|
540 |
+
"answer_placeholders": answer_placeholders,
|
541 |
+
"ext_table": rev_ext_table,
|
542 |
+
}
|
543 |
+
)
|
544 |
+
|
545 |
+
for key, value in outputs.items():
|
546 |
+
if key not in batch_outputs:
|
547 |
+
batch_outputs[key] = []
|
548 |
+
batch_outputs[key].append(value)
|
549 |
+
|
550 |
+
max_length = max([len(item) for item in batch_outputs[self.model_input_names[0]]])
|
551 |
+
batch_size = len(batch_outputs[self.model_input_names[0]])
|
552 |
+
for i in range(batch_size):
|
553 |
+
inputs = {k: v[i] for k, v in batch_outputs.items()}
|
554 |
+
|
555 |
+
for k, v in inputs.items():
|
556 |
+
required_input = v
|
557 |
+
|
558 |
+
needs_to_be_padded = len(required_input) != max_length and k != 'image_bound'
|
559 |
+
|
560 |
+
if needs_to_be_padded:
|
561 |
+
difference = max_length - len(required_input)
|
562 |
+
batch_outputs[k][i] = [self.pad_token_id] * difference + required_input
|
563 |
+
|
564 |
+
max_num_rels = 0
|
565 |
+
for rel in segment_rel_pack:
|
566 |
+
max_num_rels = max(max_num_rels, rel.shape[0])
|
567 |
+
padded_rels = np.zeros((len(segment_rel_pack), max_num_rels), dtype=np.int32)
|
568 |
+
for i, rel in enumerate(segment_rel_pack):
|
569 |
+
padded_rels[i, : rel.shape[0]] = rel
|
570 |
+
batch_outputs["segment_rel"] = padded_rels
|
571 |
+
batch_outputs["batch_ext_table_ids"] = np.array(batch_ext_table_ids, dtype=np.int32)
|
572 |
+
batch_outputs["batch_ext_table_sub"] = np.array(batch_ext_table_sub, dtype=np.int32)
|
573 |
+
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
574 |
+
if return_tensors == "pt":
|
575 |
+
batch_outputs = batch_outputs.to(device=device)
|
576 |
+
batch_outputs["other_info"] = other_info
|
577 |
+
|
578 |
+
return batch_outputs
|
579 |
+
|
580 |
+
def convert_data_to_id(
|
581 |
+
self,
|
582 |
+
data: Any,
|
583 |
+
prev_ext_states: Optional[_PrevExtTableStates] = None,
|
584 |
+
shuffle_answer: bool = True,
|
585 |
+
max_depth: int = 8,
|
586 |
+
):
|
587 |
+
"""
|
588 |
+
Parse a dict to data ids. Exclusive for CPMBee. It will
|
589 |
+
1. parse the dict to segments and get segment_rel, which for calculating of position_bias.
|
590 |
+
2. tokenize every segment.
|
591 |
+
"""
|
592 |
+
root: _DictTree = {
|
593 |
+
"value": "<root>",
|
594 |
+
"children": [],
|
595 |
+
"depth": 0,
|
596 |
+
"segment_id": 0,
|
597 |
+
"need_predict": False,
|
598 |
+
"is_image": False
|
599 |
+
}
|
600 |
+
|
601 |
+
segments = [root]
|
602 |
+
|
603 |
+
def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
|
604 |
+
if isinstance(data, dict):
|
605 |
+
ret_list: List[_DictTree] = []
|
606 |
+
curr_items = list(data.items())
|
607 |
+
if need_predict and shuffle_answer:
|
608 |
+
access_idx = np.arange(len(curr_items))
|
609 |
+
np.random.shuffle(access_idx)
|
610 |
+
curr_items = [curr_items[idx] for idx in access_idx]
|
611 |
+
for k, v in curr_items:
|
612 |
+
child_info: _DictTree = {
|
613 |
+
"value": k,
|
614 |
+
"children": [],
|
615 |
+
"depth": depth,
|
616 |
+
"segment_id": len(segments),
|
617 |
+
"need_predict": False, # only leaves are contexts
|
618 |
+
"is_image": False,
|
619 |
+
}
|
620 |
+
segments.append(child_info)
|
621 |
+
child_info["children"] = _build_dict_tree(
|
622 |
+
v, depth + 1,
|
623 |
+
need_predict=need_predict or (depth == 1 and k == "<ans>"),
|
624 |
+
is_image=is_image or (depth == 1 and k == "image")
|
625 |
+
) # elements in <root>.<ans>
|
626 |
+
|
627 |
+
ret_list.append(child_info)
|
628 |
+
return ret_list
|
629 |
+
else:
|
630 |
+
assert isinstance(data, str), "Invalid data {}".format(data)
|
631 |
+
ret: _DictTree = {
|
632 |
+
"value": data,
|
633 |
+
"children": [],
|
634 |
+
"depth": depth,
|
635 |
+
"segment_id": len(segments),
|
636 |
+
"need_predict": need_predict,
|
637 |
+
"is_image": is_image,
|
638 |
+
}
|
639 |
+
segments.append(ret)
|
640 |
+
return [ret]
|
641 |
+
|
642 |
+
root["children"] = _build_dict_tree(data, 1, False, False)
|
643 |
+
|
644 |
+
num_segments = len(segments)
|
645 |
+
segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)
|
646 |
+
|
647 |
+
def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
|
648 |
+
ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
|
649 |
+
for child in node["children"]:
|
650 |
+
sub = _build_segment_rel(child)
|
651 |
+
for seg_id_1, depth_1 in sub:
|
652 |
+
for seg_id_2, depth_2 in ret:
|
653 |
+
n_up = min(depth_1 - node["depth"], max_depth - 1)
|
654 |
+
n_down = min(depth_2 - node["depth"], max_depth - 1)
|
655 |
+
segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
|
656 |
+
n_up, n_down, max_depth=max_depth
|
657 |
+
)
|
658 |
+
segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
|
659 |
+
n_down, n_up, max_depth=max_depth
|
660 |
+
)
|
661 |
+
ret.extend(sub)
|
662 |
+
return ret
|
663 |
+
|
664 |
+
_build_segment_rel(root)
|
665 |
+
|
666 |
+
input_ids: List[int] = []
|
667 |
+
input_id_subs: List[int] = []
|
668 |
+
segment_bound: List[Tuple[int, int]] = []
|
669 |
+
image_bound: List[Tuple[int, int]] = []
|
670 |
+
|
671 |
+
|
672 |
+
if prev_ext_states is not None:
|
673 |
+
self.ext_table = prev_ext_states["ext_table"]
|
674 |
+
self.token_id_table = prev_ext_states["token_id_table"]
|
675 |
+
|
676 |
+
for seg in segments:
|
677 |
+
# tokenize
|
678 |
+
tokens = self.convert_tokens_to_ids(self.tokenize(seg["value"], for_cpmbee=True))
|
679 |
+
|
680 |
+
token_id_subs = []
|
681 |
+
reid_token_ids = []
|
682 |
+
for idx in tokens:
|
683 |
+
if idx in self.ext_table:
|
684 |
+
# unk or special token
|
685 |
+
token = self.ext_table[idx]
|
686 |
+
if token.startswith("<") and token.endswith(">"):
|
687 |
+
# special token
|
688 |
+
if "_" in token:
|
689 |
+
token_name = token[1:-1].split("_", maxsplit=1)[0]
|
690 |
+
else:
|
691 |
+
token_name = token[1:-1]
|
692 |
+
token_name = "<{}>".format(token_name)
|
693 |
+
else:
|
694 |
+
token_name = "<unk>"
|
695 |
+
|
696 |
+
if token_name not in self.token_id_table:
|
697 |
+
self.token_id_table[token_name] = {}
|
698 |
+
if idx not in self.token_id_table[token_name]:
|
699 |
+
self.token_id_table[token_name][idx] = len(self.token_id_table[token_name])
|
700 |
+
if token_name not in self.encoder:
|
701 |
+
raise ValueError("Invalid token {}".format(token))
|
702 |
+
reid_token_ids.append(self.encoder[token_name])
|
703 |
+
token_id_subs.append(self.token_id_table[token_name][idx])
|
704 |
+
else:
|
705 |
+
reid_token_ids.append(idx)
|
706 |
+
token_id_subs.append(0)
|
707 |
+
tokens = [self.bos_token_id] + reid_token_ids
|
708 |
+
token_id_subs = [0] + token_id_subs
|
709 |
+
# eos_id 表示 no need_predict
|
710 |
+
if not seg["need_predict"]: # eos
|
711 |
+
tokens = tokens + [self.eos_token_id]
|
712 |
+
token_id_subs = token_id_subs + [0]
|
713 |
+
else:
|
714 |
+
# no eos
|
715 |
+
pass
|
716 |
+
begin = len(input_ids)
|
717 |
+
input_ids.extend(tokens)
|
718 |
+
input_id_subs.extend(token_id_subs)
|
719 |
+
end = len(input_ids)
|
720 |
+
segment_bound.append((begin, end))
|
721 |
+
|
722 |
+
ids = np.array(input_ids, dtype=np.int32)
|
723 |
+
id_subs = np.array(input_id_subs, dtype=np.int32)
|
724 |
+
segs = np.zeros((ids.shape[0],), dtype=np.int32) # 按segment_bound对seg编号
|
725 |
+
context = np.zeros((ids.shape[0],), dtype=np.int8)
|
726 |
+
for i, (begin, end) in enumerate(segment_bound):
|
727 |
+
if not segments[i]["need_predict"]:
|
728 |
+
context[begin:end] = 1
|
729 |
+
if segments[i]["is_image"]:
|
730 |
+
image_bound.append((begin + 1, end - 1))
|
731 |
+
segs[begin:end] = i
|
732 |
+
|
733 |
+
curr_ext_table_states: _PrevExtTableStates = {
|
734 |
+
"ext_table": self.ext_table,
|
735 |
+
"token_id_table": self.token_id_table,
|
736 |
+
}
|
737 |
+
image_bound = np.array(image_bound, dtype=np.int32)
|
738 |
+
return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
|
739 |
+
|
740 |
+
def prepare_for_model(
|
741 |
+
self,
|
742 |
+
ids: List[int],
|
743 |
+
pair_ids: Optional[List[int]] = None,
|
744 |
+
add_special_tokens: bool = True,
|
745 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
746 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
747 |
+
max_length: Optional[int] = None,
|
748 |
+
stride: int = 0,
|
749 |
+
pad_to_multiple_of: Optional[int] = None,
|
750 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
751 |
+
return_token_type_ids: Optional[bool] = None,
|
752 |
+
return_attention_mask: Optional[bool] = None,
|
753 |
+
return_overflowing_tokens: bool = False,
|
754 |
+
return_special_tokens_mask: bool = False,
|
755 |
+
return_length: bool = False,
|
756 |
+
verbose: bool = True,
|
757 |
+
prepend_batch_axis: bool = False,
|
758 |
+
**kwargs,
|
759 |
+
) -> BatchEncoding:
|
760 |
+
"""
|
761 |
+
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
|
762 |
+
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
|
763 |
+
manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids*
|
764 |
+
different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
|
765 |
+
overflowing tokens. Such a combination of arguments will raise an error.
|
766 |
+
|
767 |
+
Args:
|
768 |
+
ids (`List[int]`):
|
769 |
+
Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
|
770 |
+
`convert_tokens_to_ids` methods.
|
771 |
+
pair_ids (`List[int]`, *optional*):
|
772 |
+
Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
|
773 |
+
and `convert_tokens_to_ids` methods.
|
774 |
+
"""
|
775 |
+
|
776 |
+
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
777 |
+
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
778 |
+
padding=padding,
|
779 |
+
truncation=truncation,
|
780 |
+
max_length=max_length,
|
781 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
782 |
+
verbose=verbose,
|
783 |
+
**kwargs,
|
784 |
+
)
|
785 |
+
|
786 |
+
pair = bool(pair_ids is not None)
|
787 |
+
len_ids = len(ids)
|
788 |
+
len_pair_ids = len(pair_ids) if pair else 0
|
789 |
+
|
790 |
+
if return_token_type_ids and not add_special_tokens:
|
791 |
+
raise ValueError(
|
792 |
+
"Asking to return token_type_ids while setting add_special_tokens to False "
|
793 |
+
"results in an undefined behavior. Please set add_special_tokens to True or "
|
794 |
+
"set return_token_type_ids to None."
|
795 |
+
)
|
796 |
+
|
797 |
+
if (
|
798 |
+
return_overflowing_tokens
|
799 |
+
and truncation_strategy == TruncationStrategy.LONGEST_FIRST
|
800 |
+
and pair_ids is not None
|
801 |
+
):
|
802 |
+
raise ValueError(
|
803 |
+
"Not possible to return overflowing tokens for pair of sequences with the "
|
804 |
+
"`longest_first`. Please select another truncation strategy than `longest_first`, "
|
805 |
+
"for instance `only_second` or `only_first`."
|
806 |
+
)
|
807 |
+
|
808 |
+
# Load from model defaults
|
809 |
+
if return_token_type_ids is None:
|
810 |
+
return_token_type_ids = "token_type_ids" in self.model_input_names
|
811 |
+
if return_attention_mask is None:
|
812 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
813 |
+
|
814 |
+
encoded_inputs = {}
|
815 |
+
|
816 |
+
# Compute the total size of the returned encodings
|
817 |
+
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
|
818 |
+
|
819 |
+
# Truncation: Handle max sequence length
|
820 |
+
overflowing_tokens = []
|
821 |
+
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
|
822 |
+
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
|
823 |
+
ids,
|
824 |
+
pair_ids=pair_ids,
|
825 |
+
num_tokens_to_remove=total_len - max_length,
|
826 |
+
truncation_strategy=truncation_strategy,
|
827 |
+
stride=stride,
|
828 |
+
)
|
829 |
+
|
830 |
+
if return_overflowing_tokens:
|
831 |
+
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
832 |
+
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
833 |
+
|
834 |
+
# Add special tokens
|
835 |
+
if add_special_tokens:
|
836 |
+
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
837 |
+
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
838 |
+
else:
|
839 |
+
sequence = ids + pair_ids if pair else ids
|
840 |
+
token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
|
841 |
+
|
842 |
+
# Build output dictionary
|
843 |
+
encoded_inputs["input_ids"] = sequence
|
844 |
+
if return_token_type_ids:
|
845 |
+
encoded_inputs["token_type_ids"] = token_type_ids
|
846 |
+
if return_special_tokens_mask:
|
847 |
+
if add_special_tokens:
|
848 |
+
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
849 |
+
else:
|
850 |
+
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
|
851 |
+
|
852 |
+
# Check lengths
|
853 |
+
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
|
854 |
+
|
855 |
+
# Padding
|
856 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
|
857 |
+
encoded_inputs = self.pad(
|
858 |
+
encoded_inputs,
|
859 |
+
max_length=max_length,
|
860 |
+
padding=padding_strategy.value,
|
861 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
862 |
+
return_attention_mask=return_attention_mask,
|
863 |
+
)
|
864 |
+
|
865 |
+
if return_length:
|
866 |
+
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
|
867 |
+
|
868 |
+
# for CPMBee, encode all the model arguments
|
869 |
+
for arg in self.ext_args_for_model:
|
870 |
+
v = kwargs.get(arg, None)
|
871 |
+
if v is not None:
|
872 |
+
encoded_inputs[arg] = v
|
873 |
+
|
874 |
+
batch_outputs = BatchEncoding(
|
875 |
+
encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
|
876 |
+
)
|
877 |
+
|
878 |
+
return batch_outputs
|
879 |
+
|
880 |
+
def prepare_for_finetune(
|
881 |
+
self,
|
882 |
+
data_list: List[Dict],
|
883 |
+
max_length: int = 2048
|
884 |
+
):
|
885 |
+
_inputs: List[NDArray[np.int32]] = []
|
886 |
+
_inputs_sub: List[NDArray[np.int32]] = []
|
887 |
+
_context: List[NDArray[np.int8]] = []
|
888 |
+
_sample_ids: List[NDArray[np.int32]] = []
|
889 |
+
_segments: List[NDArray[np.int32]] = []
|
890 |
+
_num_segments: List[NDArray[np.int32]] = []
|
891 |
+
_segment_rel_offset: List[NDArray[np.int32]] = []
|
892 |
+
_segment_rel: List[NDArray[np.int32]] = []
|
893 |
+
_spans: List[List[int]] = []
|
894 |
+
_raw_data: List[List[Any]] = []
|
895 |
+
|
896 |
+
raw_data = {}
|
897 |
+
for data in data_list:
|
898 |
+
(
|
899 |
+
input_ids,
|
900 |
+
input_id_subs,
|
901 |
+
context,
|
902 |
+
segment_ids,
|
903 |
+
segment_rel,
|
904 |
+
n_segments,
|
905 |
+
_
|
906 |
+
) = self.convert_data_to_id(data)
|
907 |
+
|
908 |
+
input_ids = input_ids[: max_length]
|
909 |
+
context = context[: max_length]
|
910 |
+
segment_ids = segment_ids[: max_length]
|
911 |
+
raw_data["input"] = data
|
912 |
+
raw_data["samples"] = []
|
913 |
+
|
914 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
915 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
|
916 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
|
917 |
+
|
918 |
+
_inputs.append(input_ids)
|
919 |
+
_inputs_sub.append(input_id_subs)
|
920 |
+
_context.append(context)
|
921 |
+
_sample_ids.append(sample_ids)
|
922 |
+
_segments.append(segment_ids)
|
923 |
+
_num_segments.append(num_segments)
|
924 |
+
_segment_rel_offset.append(segment_rel_offset)
|
925 |
+
_segment_rel.append(segment_rel)
|
926 |
+
_spans.append([input_ids.shape[0]])
|
927 |
+
_raw_data.append([raw_data])
|
928 |
+
|
929 |
+
batch_size = len(_inputs)
|
930 |
+
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
931 |
+
inputs_sub = np.zeros((batch_size, max_length), dtype=np.int32)
|
932 |
+
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
933 |
+
sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
934 |
+
segments = np.zeros((batch_size, max_length), dtype=np.int32)
|
935 |
+
num_segments = np.zeros((batch_size, max_length), dtype=np.int32)
|
936 |
+
segment_rel_offset = np.zeros((batch_size, max_length), dtype=np.int32)
|
937 |
+
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
938 |
+
|
939 |
+
max_rel = 0
|
940 |
+
for i in range(batch_size):
|
941 |
+
max_rel = max(max_rel, _segment_rel[i].shape[0])
|
942 |
+
segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)
|
943 |
+
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
944 |
+
length = np.zeros((batch_size,), dtype=np.int32)
|
945 |
+
|
946 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
947 |
+
batch_ext_table_ids: List[int] = []
|
948 |
+
batch_ext_table_sub: List[int] = []
|
949 |
+
raw_data_list: List[Any] = []
|
950 |
+
|
951 |
+
for i in range(batch_size):
|
952 |
+
instance_length = _inputs[i].shape[0]
|
953 |
+
rel_size = _segment_rel[i].shape[0]
|
954 |
+
inputs[i, :instance_length] = _inputs[i]
|
955 |
+
inputs_sub[i, :instance_length] = _inputs_sub[i]
|
956 |
+
context[i, :instance_length] = _context[i]
|
957 |
+
sample_ids[i, :instance_length] = _sample_ids[i]
|
958 |
+
segments[i, :instance_length] = _segments[i]
|
959 |
+
num_segments[i, :instance_length] = _num_segments[i]
|
960 |
+
segment_rel_offset[i, :instance_length] = _segment_rel_offset[i]
|
961 |
+
segment_rel[i, :rel_size] = _segment_rel[i]
|
962 |
+
|
963 |
+
span_begin = 0
|
964 |
+
for span_id, span_end in enumerate(_spans[i]):
|
965 |
+
spans[i, span_begin:span_end] = span_id
|
966 |
+
span_begin = span_end
|
967 |
+
length[i] = instance_length
|
968 |
+
raw_data_list.extend(_raw_data[i])
|
969 |
+
|
970 |
+
for j in range(instance_length):
|
971 |
+
idx, idx_sub = _inputs[i][j], _inputs_sub[i][j]
|
972 |
+
tgt_idx = idx
|
973 |
+
if idx_sub > 0:
|
974 |
+
# need to be in ext table
|
975 |
+
if (idx, idx_sub) not in batch_ext_table_map:
|
976 |
+
batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
|
977 |
+
batch_ext_table_ids.append(idx)
|
978 |
+
batch_ext_table_sub.append(idx_sub)
|
979 |
+
tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.vocab_size
|
980 |
+
if j > 1 and context[i, j - 1] == 0:
|
981 |
+
if idx != self.bos_token_id:
|
982 |
+
tgt[i, j - 1] = tgt_idx
|
983 |
+
else:
|
984 |
+
tgt[i, j - 1] = self.eos_token_id
|
985 |
+
if context[i, instance_length - 1] == 0:
|
986 |
+
tgt[i, instance_length - 1] = self.eos_token_id
|
987 |
+
|
988 |
+
if len(batch_ext_table_map) == 0:
|
989 |
+
# placeholder
|
990 |
+
batch_ext_table_ids.append(0)
|
991 |
+
batch_ext_table_sub.append(1)
|
992 |
+
|
993 |
+
return BatchEncoding({
|
994 |
+
"input_ids": inputs,
|
995 |
+
"input_id_sub": inputs_sub,
|
996 |
+
"length": length,
|
997 |
+
"context": context > 0,
|
998 |
+
"sample_ids": sample_ids,
|
999 |
+
"num_segments": num_segments,
|
1000 |
+
"segment": segments,
|
1001 |
+
"segment_rel_offset": segment_rel_offset,
|
1002 |
+
"segment_rel": segment_rel,
|
1003 |
+
"span": spans,
|
1004 |
+
"labels": tgt,
|
1005 |
+
"ext_table_ids": np.array(batch_ext_table_ids, dtype=np.int32),
|
1006 |
+
"ext_table_sub": np.array(batch_ext_table_sub, dtype=np.int32)
|
1007 |
+
}, tensor_type="pt")
|
tokenizer_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name_or_path": "openbmb/viscpmchat-bee-10b",
|
3 |
+
"tokenizer_class": "VisCpmChatBeeTokenizer",
|
4 |
+
"auto_map": {
|
5 |
+
"AutoTokenizer": [
|
6 |
+
"tokenization_viscpmchatbee.VisCpmChatBeeTokenizer",
|
7 |
+
null
|
8 |
+
]
|
9 |
+
}
|
10 |
+
}
|
utils.py
ADDED
@@ -0,0 +1,730 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
5 |
+
from timm.data.transforms import RandomResizedCropAndInterpolation
|
6 |
+
from torchvision import transforms
|
7 |
+
import urllib
|
8 |
+
from tqdm import tqdm
|
9 |
+
from cpm_live.tokenizers import CPMBeeTokenizer
|
10 |
+
from torch.utils.data import default_collate
|
11 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
12 |
+
from typing_extensions import TypedDict
|
13 |
+
from numpy.typing import NDArray
|
14 |
+
import importlib.machinery
|
15 |
+
import importlib.util
|
16 |
+
import types
|
17 |
+
import random
|
18 |
+
|
19 |
+
|
20 |
+
CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
|
21 |
+
|
22 |
+
|
23 |
+
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
|
24 |
+
items = []
|
25 |
+
if isinstance(orig_items[0][key], list):
|
26 |
+
assert isinstance(orig_items[0][key][0], torch.Tensor)
|
27 |
+
for it in orig_items:
|
28 |
+
for tr in it[key]:
|
29 |
+
items.append({key: tr})
|
30 |
+
else:
|
31 |
+
assert isinstance(orig_items[0][key], torch.Tensor)
|
32 |
+
items = orig_items
|
33 |
+
|
34 |
+
batch_size = len(items)
|
35 |
+
shape = items[0][key].shape
|
36 |
+
dim = len(shape)
|
37 |
+
assert dim <= 3
|
38 |
+
if max_length is None:
|
39 |
+
max_length = 0
|
40 |
+
max_length = max(max_length, max(item[key].shape[-1] for item in items))
|
41 |
+
min_length = min(item[key].shape[-1] for item in items)
|
42 |
+
dtype = items[0][key].dtype
|
43 |
+
|
44 |
+
if dim == 1:
|
45 |
+
return torch.cat([item[key] for item in items], dim=0)
|
46 |
+
elif dim == 2:
|
47 |
+
if max_length == min_length:
|
48 |
+
return torch.cat([item[key] for item in items], dim=0)
|
49 |
+
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
50 |
+
else:
|
51 |
+
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
52 |
+
|
53 |
+
for i, item in enumerate(items):
|
54 |
+
if dim == 2:
|
55 |
+
if padding_side == "left":
|
56 |
+
tensor[i, -len(item[key][0]):] = item[key][0].clone()
|
57 |
+
else:
|
58 |
+
tensor[i, : len(item[key][0])] = item[key][0].clone()
|
59 |
+
elif dim == 3:
|
60 |
+
if padding_side == "left":
|
61 |
+
tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
|
62 |
+
else:
|
63 |
+
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
64 |
+
|
65 |
+
return tensor
|
66 |
+
|
67 |
+
|
68 |
+
class CPMBeeCollater:
|
69 |
+
"""
|
70 |
+
针对 cpmbee 输入数据 collate, 对应 cpm-live 的 _MixedDatasetBatchPacker
|
71 |
+
目前利用 torch 的原生 Dataloader 不太适合改造 in-context-learning
|
72 |
+
并且原来实现为了最大化提高有效 token 比比例, 会有一个 best_fit 操作, 这个目前也不支持
|
73 |
+
todo: @wangchongyi 重写一下 Dataloader or BatchPacker
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, tokenizer: CPMBeeTokenizer, max_len):
|
77 |
+
self.tokenizer = tokenizer
|
78 |
+
self._max_length = max_len
|
79 |
+
self.pad_keys = ['input_ids', 'input_id_subs', 'context', 'segment_ids', 'segment_rel_offset',
|
80 |
+
'segment_rel', 'sample_ids', 'num_segments']
|
81 |
+
|
82 |
+
def __call__(self, batch):
|
83 |
+
batch_size = len(batch)
|
84 |
+
|
85 |
+
tgt = np.full((batch_size, self._max_length), -100, dtype=np.int32)
|
86 |
+
# 目前没有 best_fit, span 为全 0
|
87 |
+
span = np.zeros((batch_size, self._max_length), dtype=np.int32)
|
88 |
+
length = np.zeros((batch_size,), dtype=np.int32)
|
89 |
+
|
90 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
91 |
+
batch_ext_table_ids: List[int] = []
|
92 |
+
batch_ext_table_sub: List[int] = []
|
93 |
+
raw_data_list: List[Any] = []
|
94 |
+
|
95 |
+
for i in range(batch_size):
|
96 |
+
instance_length = batch[i]['input_ids'][0].shape[0]
|
97 |
+
length[i] = instance_length
|
98 |
+
raw_data_list.extend(batch[i]['raw_data'])
|
99 |
+
|
100 |
+
for j in range(instance_length):
|
101 |
+
idx, idx_sub = batch[i]['input_ids'][0, j], batch[i]['input_id_subs'][0, j]
|
102 |
+
tgt_idx = idx
|
103 |
+
if idx_sub > 0:
|
104 |
+
# need to be in ext table
|
105 |
+
if (idx, idx_sub) not in batch_ext_table_map:
|
106 |
+
batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
|
107 |
+
batch_ext_table_ids.append(idx)
|
108 |
+
batch_ext_table_sub.append(idx_sub)
|
109 |
+
tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.tokenizer.vocab_size
|
110 |
+
if j > 1 and batch[i]['context'][0, j - 1] == 0:
|
111 |
+
if idx != self.tokenizer.bos_id:
|
112 |
+
tgt[i, j - 1] = tgt_idx
|
113 |
+
else:
|
114 |
+
tgt[i, j - 1] = self.tokenizer.eos_id
|
115 |
+
if batch[i]['context'][0, instance_length - 1] == 0:
|
116 |
+
tgt[i, instance_length - 1] = self.tokenizer.eos_id
|
117 |
+
|
118 |
+
if len(batch_ext_table_map) == 0:
|
119 |
+
# placeholder
|
120 |
+
batch_ext_table_ids.append(0)
|
121 |
+
batch_ext_table_sub.append(1)
|
122 |
+
|
123 |
+
# image
|
124 |
+
if 'pixel_values' in batch[0]:
|
125 |
+
data = {'pixel_values': default_collate([i['pixel_values'] for i in batch])}
|
126 |
+
else:
|
127 |
+
data = {}
|
128 |
+
|
129 |
+
# image_bound
|
130 |
+
if 'image_bound' in batch[0]:
|
131 |
+
data['image_bound'] = default_collate([i['image_bound'] for i in batch])
|
132 |
+
|
133 |
+
# bee inp
|
134 |
+
for key in self.pad_keys:
|
135 |
+
data[key] = pad(batch, key, max_length=self._max_length, padding_value=0, padding_side='right')
|
136 |
+
|
137 |
+
data['context'] = data['context'] > 0
|
138 |
+
data['length'] = torch.from_numpy(length)
|
139 |
+
data['span'] = torch.from_numpy(span)
|
140 |
+
data['target'] = torch.from_numpy(tgt)
|
141 |
+
data['ext_table_ids'] = torch.from_numpy(np.array(batch_ext_table_ids))
|
142 |
+
data['ext_table_sub'] = torch.from_numpy(np.array(batch_ext_table_sub))
|
143 |
+
data['raw_data'] = raw_data_list
|
144 |
+
|
145 |
+
return data
|
146 |
+
|
147 |
+
|
148 |
+
class _DictTree(TypedDict):
|
149 |
+
value: str
|
150 |
+
children: List["_DictTree"]
|
151 |
+
depth: int
|
152 |
+
segment_id: int
|
153 |
+
need_predict: bool
|
154 |
+
is_image: bool
|
155 |
+
|
156 |
+
|
157 |
+
class _PrevExtTableStates(TypedDict):
|
158 |
+
ext_table: Dict[int, str]
|
159 |
+
token_id_table: Dict[str, Dict[int, int]]
|
160 |
+
|
161 |
+
|
162 |
+
class _TransformFuncDict(TypedDict):
|
163 |
+
loader: importlib.machinery.SourceFileLoader
|
164 |
+
module: types.ModuleType
|
165 |
+
last_m: float
|
166 |
+
|
167 |
+
|
168 |
+
_TransformFunction = Callable[[CPMBeeInputType, int, random.Random], CPMBeeInputType]
|
169 |
+
|
170 |
+
|
171 |
+
class CPMBeeBatch(TypedDict):
|
172 |
+
inputs: NDArray[np.int32]
|
173 |
+
inputs_sub: NDArray[np.int32]
|
174 |
+
length: NDArray[np.int32]
|
175 |
+
context: NDArray[np.bool_]
|
176 |
+
sample_ids: NDArray[np.int32]
|
177 |
+
num_segments: NDArray[np.int32]
|
178 |
+
segment_ids: NDArray[np.int32]
|
179 |
+
segment_rel_offset: NDArray[np.int32]
|
180 |
+
segment_rel: NDArray[np.int32]
|
181 |
+
spans: NDArray[np.int32]
|
182 |
+
target: NDArray[np.int32]
|
183 |
+
ext_ids: NDArray[np.int32]
|
184 |
+
ext_sub: NDArray[np.int32]
|
185 |
+
task_ids: NDArray[np.int32]
|
186 |
+
task_names: List[str]
|
187 |
+
raw_data: List[Any]
|
188 |
+
|
189 |
+
|
190 |
+
def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
|
191 |
+
ret = n_up * max_depth + n_down
|
192 |
+
if ret == 0:
|
193 |
+
return ret
|
194 |
+
else:
|
195 |
+
# bucket 1 is reserved for incontext samples
|
196 |
+
return ret + 1
|
197 |
+
|
198 |
+
|
199 |
+
def convert_data_to_id(
|
200 |
+
tokenizer: CPMBeeTokenizer,
|
201 |
+
data: Any,
|
202 |
+
prev_ext_states: Optional[_PrevExtTableStates] = None,
|
203 |
+
shuffle_answer: bool = True,
|
204 |
+
max_depth: int = 8
|
205 |
+
):
|
206 |
+
root: _DictTree = {
|
207 |
+
"value": "<root>",
|
208 |
+
"children": [],
|
209 |
+
"depth": 0,
|
210 |
+
"segment_id": 0,
|
211 |
+
"need_predict": False,
|
212 |
+
"is_image": False
|
213 |
+
}
|
214 |
+
|
215 |
+
segments = [root]
|
216 |
+
|
217 |
+
def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
|
218 |
+
if isinstance(data, dict):
|
219 |
+
ret_list: List[_DictTree] = []
|
220 |
+
curr_items = list(data.items())
|
221 |
+
if need_predict and shuffle_answer:
|
222 |
+
access_idx = np.arange(len(curr_items))
|
223 |
+
np.random.shuffle(access_idx)
|
224 |
+
curr_items = [curr_items[idx] for idx in access_idx]
|
225 |
+
for k, v in curr_items:
|
226 |
+
child_info: _DictTree = {
|
227 |
+
"value": k,
|
228 |
+
"children": [],
|
229 |
+
"depth": depth,
|
230 |
+
"segment_id": len(segments),
|
231 |
+
"need_predict": False, # only leaves are contexts
|
232 |
+
"is_image": False,
|
233 |
+
}
|
234 |
+
segments.append(child_info)
|
235 |
+
child_info["children"] = _build_dict_tree(
|
236 |
+
v, depth + 1,
|
237 |
+
need_predict=need_predict or (depth == 1 and k == "<ans>"),
|
238 |
+
is_image=is_image or (depth == 1 and k == "image")
|
239 |
+
) # elements in <root>.<ans>
|
240 |
+
|
241 |
+
ret_list.append(child_info)
|
242 |
+
return ret_list
|
243 |
+
else:
|
244 |
+
assert isinstance(data, str), "Invalid data {}".format(data)
|
245 |
+
ret: _DictTree = {
|
246 |
+
"value": data,
|
247 |
+
"children": [],
|
248 |
+
"depth": depth,
|
249 |
+
"segment_id": len(segments),
|
250 |
+
"need_predict": need_predict,
|
251 |
+
"is_image": is_image,
|
252 |
+
}
|
253 |
+
segments.append(ret)
|
254 |
+
return [ret]
|
255 |
+
|
256 |
+
root["children"] = _build_dict_tree(data, 1, False, False)
|
257 |
+
|
258 |
+
num_segments = len(segments)
|
259 |
+
segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)
|
260 |
+
|
261 |
+
def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
|
262 |
+
ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
|
263 |
+
for child in node["children"]:
|
264 |
+
sub = _build_segment_rel(child)
|
265 |
+
for seg_id_1, depth_1 in sub:
|
266 |
+
for seg_id_2, depth_2 in ret:
|
267 |
+
n_up = min(depth_1 - node["depth"], max_depth - 1)
|
268 |
+
n_down = min(depth_2 - node["depth"], max_depth - 1)
|
269 |
+
segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
|
270 |
+
n_up, n_down, max_depth=max_depth
|
271 |
+
)
|
272 |
+
segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
|
273 |
+
n_down, n_up, max_depth=max_depth
|
274 |
+
)
|
275 |
+
ret.extend(sub)
|
276 |
+
return ret
|
277 |
+
|
278 |
+
_build_segment_rel(root)
|
279 |
+
|
280 |
+
input_ids: List[int] = []
|
281 |
+
input_id_subs: List[int] = []
|
282 |
+
segment_bound: List[Tuple[int, int]] = []
|
283 |
+
image_bound: List[Tuple[int, int]] = []
|
284 |
+
|
285 |
+
ext_table: Dict[int, str] = {}
|
286 |
+
token_id_table: Dict[str, Dict[int, int]] = {}
|
287 |
+
|
288 |
+
if prev_ext_states is not None:
|
289 |
+
ext_table = prev_ext_states["ext_table"]
|
290 |
+
token_id_table = prev_ext_states["token_id_table"]
|
291 |
+
|
292 |
+
for seg in segments:
|
293 |
+
tokens, ext_table = tokenizer.encode(seg["value"], ext_table)
|
294 |
+
|
295 |
+
token_id_subs = []
|
296 |
+
reid_token_ids = []
|
297 |
+
for idx in tokens:
|
298 |
+
if idx in ext_table:
|
299 |
+
# unk or special token
|
300 |
+
token = ext_table[idx]
|
301 |
+
if token.startswith("<") and token.endswith(">"):
|
302 |
+
# special token
|
303 |
+
if "_" in token:
|
304 |
+
token_name = token[1:-1].split("_", maxsplit=1)[0]
|
305 |
+
else:
|
306 |
+
token_name = token[1:-1]
|
307 |
+
token_name = "<{}>".format(token_name)
|
308 |
+
else:
|
309 |
+
token_name = "<unk>"
|
310 |
+
|
311 |
+
if token_name not in token_id_table:
|
312 |
+
token_id_table[token_name] = {}
|
313 |
+
if idx not in token_id_table[token_name]:
|
314 |
+
token_id_table[token_name][idx] = len(token_id_table[token_name])
|
315 |
+
if token_name not in tokenizer.encoder:
|
316 |
+
raise ValueError("Invalid token {}".format(token))
|
317 |
+
reid_token_ids.append(tokenizer.encoder[token_name])
|
318 |
+
token_id_subs.append(token_id_table[token_name][idx])
|
319 |
+
else:
|
320 |
+
reid_token_ids.append(idx)
|
321 |
+
token_id_subs.append(0)
|
322 |
+
tokens = [tokenizer.bos_id] + reid_token_ids
|
323 |
+
token_id_subs = [0] + token_id_subs
|
324 |
+
if not seg["need_predict"]:
|
325 |
+
tokens = tokens + [tokenizer.eos_id]
|
326 |
+
token_id_subs = token_id_subs + [0]
|
327 |
+
else:
|
328 |
+
# no eos
|
329 |
+
pass
|
330 |
+
begin = len(input_ids)
|
331 |
+
input_ids.extend(tokens)
|
332 |
+
input_id_subs.extend(token_id_subs)
|
333 |
+
end = len(input_ids)
|
334 |
+
segment_bound.append((begin, end))
|
335 |
+
|
336 |
+
ids = np.array(input_ids, dtype=np.int32)
|
337 |
+
id_subs = np.array(input_id_subs, dtype=np.int32)
|
338 |
+
segs = np.zeros((ids.shape[0],), dtype=np.int32)
|
339 |
+
context = np.zeros((ids.shape[0],), dtype=np.int8)
|
340 |
+
for i, (begin, end) in enumerate(segment_bound):
|
341 |
+
if not segments[i]["need_predict"]:
|
342 |
+
context[begin:end] = 1
|
343 |
+
if segments[i]["is_image"]:
|
344 |
+
image_bound.append((begin+1, end-1))
|
345 |
+
segs[begin:end] = i
|
346 |
+
|
347 |
+
curr_ext_table_states: _PrevExtTableStates = {
|
348 |
+
"ext_table": ext_table,
|
349 |
+
"token_id_table": token_id_table,
|
350 |
+
}
|
351 |
+
image_bound = np.array(image_bound, dtype=np.int32)
|
352 |
+
return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
|
353 |
+
|
354 |
+
|
355 |
+
# aug functions
|
356 |
+
def identity_func(img):
|
357 |
+
return img
|
358 |
+
|
359 |
+
|
360 |
+
def autocontrast_func(img, cutoff=0):
|
361 |
+
'''
|
362 |
+
same output as PIL.ImageOps.autocontrast
|
363 |
+
'''
|
364 |
+
n_bins = 256
|
365 |
+
|
366 |
+
def tune_channel(ch):
|
367 |
+
n = ch.size
|
368 |
+
cut = cutoff * n // 100
|
369 |
+
if cut == 0:
|
370 |
+
high, low = ch.max(), ch.min()
|
371 |
+
else:
|
372 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
373 |
+
low = np.argwhere(np.cumsum(hist) > cut)
|
374 |
+
low = 0 if low.shape[0] == 0 else low[0]
|
375 |
+
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
|
376 |
+
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
|
377 |
+
if high <= low:
|
378 |
+
table = np.arange(n_bins)
|
379 |
+
else:
|
380 |
+
scale = (n_bins - 1) / (high - low)
|
381 |
+
table = np.arange(n_bins) * scale - low * scale
|
382 |
+
table[table < 0] = 0
|
383 |
+
table[table > n_bins - 1] = n_bins - 1
|
384 |
+
table = table.clip(0, 255).astype(np.uint8)
|
385 |
+
return table[ch]
|
386 |
+
|
387 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
388 |
+
out = cv2.merge(channels)
|
389 |
+
return out
|
390 |
+
|
391 |
+
|
392 |
+
def equalize_func(img):
|
393 |
+
'''
|
394 |
+
same output as PIL.ImageOps.equalize
|
395 |
+
PIL's implementation is different from cv2.equalize
|
396 |
+
'''
|
397 |
+
n_bins = 256
|
398 |
+
|
399 |
+
def tune_channel(ch):
|
400 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
401 |
+
non_zero_hist = hist[hist != 0].reshape(-1)
|
402 |
+
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
|
403 |
+
if step == 0:
|
404 |
+
return ch
|
405 |
+
n = np.empty_like(hist)
|
406 |
+
n[0] = step // 2
|
407 |
+
n[1:] = hist[:-1]
|
408 |
+
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
|
409 |
+
return table[ch]
|
410 |
+
|
411 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
412 |
+
out = cv2.merge(channels)
|
413 |
+
return out
|
414 |
+
|
415 |
+
|
416 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
|
417 |
+
'''
|
418 |
+
like PIL, rotate by degree, not radians
|
419 |
+
'''
|
420 |
+
H, W = img.shape[0], img.shape[1]
|
421 |
+
center = W / 2, H / 2
|
422 |
+
M = cv2.getRotationMatrix2D(center, degree, 1)
|
423 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
|
424 |
+
return out
|
425 |
+
|
426 |
+
|
427 |
+
def solarize_func(img, thresh=128):
|
428 |
+
'''
|
429 |
+
same output as PIL.ImageOps.posterize
|
430 |
+
'''
|
431 |
+
table = np.array([el if el < thresh else 255 - el for el in range(256)])
|
432 |
+
table = table.clip(0, 255).astype(np.uint8)
|
433 |
+
out = table[img]
|
434 |
+
return out
|
435 |
+
|
436 |
+
|
437 |
+
def color_func(img, factor):
|
438 |
+
'''
|
439 |
+
same output as PIL.ImageEnhance.Color
|
440 |
+
'''
|
441 |
+
# implementation according to PIL definition, quite slow
|
442 |
+
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
|
443 |
+
# out = blend(degenerate, img, factor)
|
444 |
+
# M = (
|
445 |
+
# np.eye(3) * factor
|
446 |
+
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
|
447 |
+
# )[np.newaxis, np.newaxis, :]
|
448 |
+
M = (
|
449 |
+
np.float32([
|
450 |
+
[0.886, -0.114, -0.114],
|
451 |
+
[-0.587, 0.413, -0.587],
|
452 |
+
[-0.299, -0.299, 0.701]]) * factor
|
453 |
+
+ np.float32([[0.114], [0.587], [0.299]])
|
454 |
+
)
|
455 |
+
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
|
456 |
+
return out
|
457 |
+
|
458 |
+
|
459 |
+
def contrast_func(img, factor):
|
460 |
+
"""
|
461 |
+
same output as PIL.ImageEnhance.Contrast
|
462 |
+
"""
|
463 |
+
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
|
464 |
+
table = np.array([(
|
465 |
+
el - mean) * factor + mean
|
466 |
+
for el in range(256)
|
467 |
+
]).clip(0, 255).astype(np.uint8)
|
468 |
+
out = table[img]
|
469 |
+
return out
|
470 |
+
|
471 |
+
|
472 |
+
def brightness_func(img, factor):
|
473 |
+
'''
|
474 |
+
same output as PIL.ImageEnhance.Contrast
|
475 |
+
'''
|
476 |
+
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
|
477 |
+
out = table[img]
|
478 |
+
return out
|
479 |
+
|
480 |
+
|
481 |
+
def sharpness_func(img, factor):
|
482 |
+
'''
|
483 |
+
The differences the this result and PIL are all on the 4 boundaries, the center
|
484 |
+
areas are same
|
485 |
+
'''
|
486 |
+
kernel = np.ones((3, 3), dtype=np.float32)
|
487 |
+
kernel[1][1] = 5
|
488 |
+
kernel /= 13
|
489 |
+
degenerate = cv2.filter2D(img, -1, kernel)
|
490 |
+
if factor == 0.0:
|
491 |
+
out = degenerate
|
492 |
+
elif factor == 1.0:
|
493 |
+
out = img
|
494 |
+
else:
|
495 |
+
out = img.astype(np.float32)
|
496 |
+
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
|
497 |
+
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
|
498 |
+
out = out.astype(np.uint8)
|
499 |
+
return out
|
500 |
+
|
501 |
+
|
502 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
|
503 |
+
H, W = img.shape[0], img.shape[1]
|
504 |
+
M = np.float32([[1, factor, 0], [0, 1, 0]])
|
505 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
506 |
+
return out
|
507 |
+
|
508 |
+
|
509 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
|
510 |
+
'''
|
511 |
+
same output as PIL.Image.transform
|
512 |
+
'''
|
513 |
+
H, W = img.shape[0], img.shape[1]
|
514 |
+
M = np.float32([[1, 0, -offset], [0, 1, 0]])
|
515 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
516 |
+
return out
|
517 |
+
|
518 |
+
|
519 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
|
520 |
+
'''
|
521 |
+
same output as PIL.Image.transform
|
522 |
+
'''
|
523 |
+
H, W = img.shape[0], img.shape[1]
|
524 |
+
M = np.float32([[1, 0, 0], [0, 1, -offset]])
|
525 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
526 |
+
return out
|
527 |
+
|
528 |
+
|
529 |
+
def posterize_func(img, bits):
|
530 |
+
'''
|
531 |
+
same output as PIL.ImageOps.posterize
|
532 |
+
'''
|
533 |
+
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
|
534 |
+
return out
|
535 |
+
|
536 |
+
|
537 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
|
538 |
+
H, W = img.shape[0], img.shape[1]
|
539 |
+
M = np.float32([[1, 0, 0], [factor, 1, 0]])
|
540 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
541 |
+
return out
|
542 |
+
|
543 |
+
|
544 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
|
545 |
+
replace = np.array(replace, dtype=np.uint8)
|
546 |
+
H, W = img.shape[0], img.shape[1]
|
547 |
+
rh, rw = np.random.random(2)
|
548 |
+
pad_size = pad_size // 2
|
549 |
+
ch, cw = int(rh * H), int(rw * W)
|
550 |
+
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
|
551 |
+
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
|
552 |
+
out = img.copy()
|
553 |
+
out[x1:x2, y1:y2, :] = replace
|
554 |
+
return out
|
555 |
+
|
556 |
+
|
557 |
+
# level to args
|
558 |
+
def enhance_level_to_args(MAX_LEVEL):
|
559 |
+
def level_to_args(level):
|
560 |
+
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
|
561 |
+
return level_to_args
|
562 |
+
|
563 |
+
|
564 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
|
565 |
+
def level_to_args(level):
|
566 |
+
level = (level / MAX_LEVEL) * 0.3
|
567 |
+
if np.random.random() > 0.5:
|
568 |
+
level = -level
|
569 |
+
return (level, replace_value)
|
570 |
+
|
571 |
+
return level_to_args
|
572 |
+
|
573 |
+
|
574 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
|
575 |
+
def level_to_args(level):
|
576 |
+
level = (level / MAX_LEVEL) * float(translate_const)
|
577 |
+
if np.random.random() > 0.5:
|
578 |
+
level = -level
|
579 |
+
return (level, replace_value)
|
580 |
+
|
581 |
+
return level_to_args
|
582 |
+
|
583 |
+
|
584 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
|
585 |
+
def level_to_args(level):
|
586 |
+
level = int((level / MAX_LEVEL) * cutout_const)
|
587 |
+
return (level, replace_value)
|
588 |
+
|
589 |
+
return level_to_args
|
590 |
+
|
591 |
+
|
592 |
+
def solarize_level_to_args(MAX_LEVEL):
|
593 |
+
def level_to_args(level):
|
594 |
+
level = int((level / MAX_LEVEL) * 256)
|
595 |
+
return (level, )
|
596 |
+
return level_to_args
|
597 |
+
|
598 |
+
|
599 |
+
def none_level_to_args(level):
|
600 |
+
return ()
|
601 |
+
|
602 |
+
|
603 |
+
def posterize_level_to_args(MAX_LEVEL):
|
604 |
+
def level_to_args(level):
|
605 |
+
level = int((level / MAX_LEVEL) * 4)
|
606 |
+
return (level, )
|
607 |
+
return level_to_args
|
608 |
+
|
609 |
+
|
610 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
|
611 |
+
def level_to_args(level):
|
612 |
+
level = (level / MAX_LEVEL) * 30
|
613 |
+
if np.random.random() < 0.5:
|
614 |
+
level = -level
|
615 |
+
return (level, replace_value)
|
616 |
+
|
617 |
+
return level_to_args
|
618 |
+
|
619 |
+
|
620 |
+
func_dict = {
|
621 |
+
'Identity': identity_func,
|
622 |
+
'AutoContrast': autocontrast_func,
|
623 |
+
'Equalize': equalize_func,
|
624 |
+
'Rotate': rotate_func,
|
625 |
+
'Solarize': solarize_func,
|
626 |
+
'Color': color_func,
|
627 |
+
'Contrast': contrast_func,
|
628 |
+
'Brightness': brightness_func,
|
629 |
+
'Sharpness': sharpness_func,
|
630 |
+
'ShearX': shear_x_func,
|
631 |
+
'TranslateX': translate_x_func,
|
632 |
+
'TranslateY': translate_y_func,
|
633 |
+
'Posterize': posterize_func,
|
634 |
+
'ShearY': shear_y_func,
|
635 |
+
}
|
636 |
+
|
637 |
+
translate_const = 10
|
638 |
+
MAX_LEVEL = 10
|
639 |
+
replace_value = (128, 128, 128)
|
640 |
+
arg_dict = {
|
641 |
+
'Identity': none_level_to_args,
|
642 |
+
'AutoContrast': none_level_to_args,
|
643 |
+
'Equalize': none_level_to_args,
|
644 |
+
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
|
645 |
+
'Solarize': solarize_level_to_args(MAX_LEVEL),
|
646 |
+
'Color': enhance_level_to_args(MAX_LEVEL),
|
647 |
+
'Contrast': enhance_level_to_args(MAX_LEVEL),
|
648 |
+
'Brightness': enhance_level_to_args(MAX_LEVEL),
|
649 |
+
'Sharpness': enhance_level_to_args(MAX_LEVEL),
|
650 |
+
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
|
651 |
+
'TranslateX': translate_level_to_args(
|
652 |
+
translate_const, MAX_LEVEL, replace_value
|
653 |
+
),
|
654 |
+
'TranslateY': translate_level_to_args(
|
655 |
+
translate_const, MAX_LEVEL, replace_value
|
656 |
+
),
|
657 |
+
'Posterize': posterize_level_to_args(MAX_LEVEL),
|
658 |
+
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
|
659 |
+
}
|
660 |
+
|
661 |
+
|
662 |
+
class RandomAugment(object):
|
663 |
+
|
664 |
+
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
|
665 |
+
self.N = N
|
666 |
+
self.M = M
|
667 |
+
self.isPIL = isPIL
|
668 |
+
if augs:
|
669 |
+
self.augs = augs
|
670 |
+
else:
|
671 |
+
self.augs = list(arg_dict.keys())
|
672 |
+
|
673 |
+
def get_random_ops(self):
|
674 |
+
sampled_ops = np.random.choice(self.augs, self.N)
|
675 |
+
return [(op, 0.5, self.M) for op in sampled_ops]
|
676 |
+
|
677 |
+
def __call__(self, img):
|
678 |
+
if self.isPIL:
|
679 |
+
img = np.array(img)
|
680 |
+
ops = self.get_random_ops()
|
681 |
+
for name, prob, level in ops:
|
682 |
+
if np.random.random() > prob:
|
683 |
+
continue
|
684 |
+
args = arg_dict[name](level)
|
685 |
+
img = func_dict[name](img, *args)
|
686 |
+
return img
|
687 |
+
|
688 |
+
|
689 |
+
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic'):
|
690 |
+
if is_train:
|
691 |
+
t = [
|
692 |
+
RandomResizedCropAndInterpolation(
|
693 |
+
input_size, scale=(0.5, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
|
694 |
+
transforms.RandomHorizontalFlip(),
|
695 |
+
]
|
696 |
+
if randaug:
|
697 |
+
t.append(
|
698 |
+
RandomAugment(
|
699 |
+
2, 7, isPIL=True,
|
700 |
+
augs=[
|
701 |
+
'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
|
702 |
+
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
|
703 |
+
]))
|
704 |
+
t += [
|
705 |
+
transforms.ToTensor(),
|
706 |
+
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
707 |
+
]
|
708 |
+
t = transforms.Compose(t)
|
709 |
+
else:
|
710 |
+
t = transforms.Compose([
|
711 |
+
transforms.Resize((input_size, input_size),
|
712 |
+
interpolation=transforms.InterpolationMode.BICUBIC),
|
713 |
+
transforms.ToTensor(),
|
714 |
+
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
|
715 |
+
])
|
716 |
+
|
717 |
+
return t
|
718 |
+
|
719 |
+
|
720 |
+
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
|
721 |
+
with open(filename, "wb") as fh:
|
722 |
+
with urllib.request.urlopen(
|
723 |
+
urllib.request.Request(url, headers={"User-Agent": "vissl"})
|
724 |
+
) as response:
|
725 |
+
with tqdm(total=response.length) as pbar:
|
726 |
+
for chunk in iter(lambda: response.read(chunk_size), ""):
|
727 |
+
if not chunk:
|
728 |
+
break
|
729 |
+
pbar.update(chunk_size)
|
730 |
+
fh.write(chunk)
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|