1024m committed on
Commit
152ee01
1 Parent(s): 13cce59

Upload preprocessing_molmo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. preprocessing_molmo.py +189 -0
preprocessing_molmo.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processor class for Molmo.
3
+ """
4
+
5
+ from typing import Optional
6
+
7
+ import PIL
8
+ from PIL import ImageOps
9
+ from PIL.Image import Image
10
+
11
+ try:
12
+ from typing import Unpack
13
+ except ImportError:
14
+ from typing_extensions import Unpack
15
+
16
+ import numpy as np
17
+ import torch
18
+
19
+ from transformers.image_utils import ImageInput
20
+ from transformers.processing_utils import (
21
+ TextKwargs,
22
+ ProcessingKwargs,
23
+ ProcessorMixin,
24
+ )
25
+
26
+ from transformers.tokenization_utils_base import TextInput
27
+ from transformers.utils import logging
28
+
29
+ from transformers import AutoTokenizer
30
+ from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
# Marker strings for image content in the token stream.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"

# Every special token whose id is looked up by get_special_token_ids.
EXTRA_TOKENS = (
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_COL_TOKEN,
    IMAGE_PROMPT,
)
43
+
44
+
45
def get_special_token_ids(tokenizer):
    """Return a dict mapping each entry of ``EXTRA_TOKENS`` to its token id.

    All extra tokens are joined into one string and encoded in a single call
    with ``add_special_tokens=False``. The encoding must yield exactly one id
    per extra token; otherwise the assertion below fails.
    """
    joined_tokens = "".join(EXTRA_TOKENS)
    token_ids = tokenizer.encode(joined_tokens, add_special_tokens=False)
    # One id per token: anything else means a token was split or merged.
    assert len(token_ids) == len(EXTRA_TOKENS)
    return dict(zip(EXTRA_TOKENS, token_ids))
49
+
50
+
51
class MolmoTextKwargs(TextKwargs, total=False):
    """Text keyword arguments for Molmo: ``TextKwargs`` plus the optional keys below."""

    # Read by MolmoProcessor.process to build the prompt.
    message_format: Optional[str]
    always_start_with_space: Optional[bool]
    # Forwarded by MolmoProcessor.process to the image processor.
    sequence_length: Optional[int]
    # Declared here but not read anywhere in this file.
    style: Optional[str]
    system_prompt: Optional[str]
57
+
58
+
59
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema for ``MolmoProcessor`` with per-group defaults."""

    text_kwargs: MolmoTextKwargs
    images_kwargs: MolmoImagesKwargs

    # Default values, grouped by the kwargs group they belong to.
    # MolmoProcessor.process passes this class to ``self._merge_kwargs``.
    _defaults = {
        # This group is splatted into image_processor.multimodal_preprocess.
        "images_kwargs": {
            "max_crops": 12,
            "overlap_margins": [4, 4],
            "base_image_input_size": [336, 336],
            "image_token_length_w": 12,
            "image_token_length_h": 12,
            "image_patch_size": 14,
            "image_padding_mask": True,
        },
        # Prompt-formatting and sequence-length options.
        "text_kwargs": {
            "style": "long_caption",
            "system_prompt": "none",
            "message_format": "role",
            "always_start_with_space": True,
            "sequence_length": 1536,
            "padding": False,
        },
    }
81
+
82
+
83
class MolmoProcessor(ProcessorMixin):
    """Processor that pairs a Molmo image processor with a tokenizer.

    ``process`` turns a text prompt and optional images into the tensors
    produced by ``image_processor.multimodal_preprocess``, with a leading
    BOS (or EOS, when the tokenizer has no BOS) token prepended.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
        """Store the two components via ``ProcessorMixin``.

        Extra ``**kwargs`` are accepted but not used.
        """
        super().__init__(image_processor, tokenizer)
        # Filled on first access of ``special_token_ids``.
        self._special_tokens = None

    @property
    def special_token_ids(self):
        """Dict of special-token string to token id, computed once and cached."""
        if self._special_tokens is None:
            self._special_tokens = get_special_token_ids(self.tokenizer)
        return self._special_tokens

    def get_tokens_input(self, prompt, message_format, always_start_with_space):
        """Format ``prompt`` and encode it without automatic special tokens.

        Args:
            prompt: The user text. It is concatenated with string literals,
                so it must be a ``str`` for the ``"role"`` format.
            message_format: ``"role"`` wraps the prompt as
                ``"User: <prompt> Assistant:"``; ``"none"`` or ``None``
                leaves it unchanged.
            always_start_with_space: When truthy, a single space is prepended
                to the (possibly wrapped) prompt before encoding.

        Returns:
            The result of ``self.tokenizer.encode(prompt, add_special_tokens=False)``.

        Raises:
            NotImplementedError: For any other ``message_format`` value.
        """
        if message_format == "none" or message_format is None:
            pass
        elif message_format == "role":
            prompt = "User: " + prompt + " Assistant:"
        else:
            raise NotImplementedError(f"Message format {message_format} not implemented")

        if always_start_with_space:
            prompt = " " + prompt

        return self.tokenizer.encode(prompt, add_special_tokens=False)

    def process(
        self,
        text: TextInput = None,
        images: ImageInput = None,
        **kwargs: Unpack[MolmoProcessorKwargs],
    ):
        """Build model inputs from a prompt and optional image(s).

        Args:
            text: Prompt text, formatted and tokenized by ``get_tokens_input``.
            images: ``None``, a single image, or a list/tuple of images. PIL
                images are EXIF-transposed and converted to RGB; any other
                object must expose ``shape``/``astype`` and have shape
                ``(H, W, 3)`` (checked by an assertion), and is cast to
                ``uint8``.
            **kwargs: Overrides merged with ``MolmoProcessorKwargs`` defaults.

        Returns:
            The mapping returned by ``image_processor.multimodal_preprocess``
            with ``"input_ids"`` left-padded by one BOS/EOS id,
            ``"image_input_idx"`` (when present) shifted accordingly, and
            every value converted with ``torch.from_numpy``.
        """
        output_kwargs = self._merge_kwargs(
            MolmoProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        text_kwargs = output_kwargs["text_kwargs"]

        tokens = self.get_tokens_input(
            text,
            text_kwargs["message_format"],
            text_kwargs["always_start_with_space"],
        )

        if images is not None:
            if not isinstance(images, (list, tuple)):
                images = [images]
            image_arrays = []
            for image in images:
                if isinstance(image, Image):
                    # Handle images with EXIF orientation tags, which PIL will ignore by default
                    # https://github.com/python-pillow/Pillow/issues/4703
                    # Transpose before converting so the orientation tag is read
                    # from the image as loaded, and keep the transposed result.
                    image = ImageOps.exif_transpose(image)
                    image = image.convert("RGB")
                    image_arrays.append(np.array(image))
                else:
                    assert len(image.shape) == 3 and image.shape[-1] == 3
                    image_arrays.append(image.astype(np.uint8))
            images = image_arrays
            # For now only support inserting images at the start
            image_idx = [-1] * len(images)
        else:
            image_idx = None

        special_ids = self.special_token_ids
        out = self.image_processor.multimodal_preprocess(
            images=images,
            image_idx=image_idx,
            tokens=np.asarray(tokens).astype(np.int32),
            sequence_length=text_kwargs["sequence_length"],
            image_patch_token_id=special_ids[DEFAULT_IMAGE_PATCH_TOKEN],
            image_col_token_id=special_ids[DEFAULT_IM_COL_TOKEN],
            image_start_token_id=special_ids[DEFAULT_IM_START_TOKEN],
            image_end_token_id=special_ids[DEFAULT_IM_END_TOKEN],
            **output_kwargs["images_kwargs"],
        )

        # Prepend BOS
        # qwen2 and olmo do not have a BOS, and instead use EOS as a generic separator token.
        # Compare against None explicitly: a BOS id of 0 is valid and must not
        # fall through to EOS.
        bos = self.tokenizer.bos_token_id
        if bos is None:
            bos = self.tokenizer.eos_token_id
        out["input_ids"] = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
        if "image_input_idx" in out:
            # Shift patch mapping up by one since we added BOS; negative
            # entries are left untouched.
            image_input_idx = out["image_input_idx"]
            out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)

        # Values are replaced in place; the key set does not change.
        for k, v in out.items():
            out[k] = torch.from_numpy(v)

        return out
187
+
188
+
189
+ MolmoProcessor.register_for_auto_class()