Upload preprocessing_molmo.py with huggingface_hub
Browse files- preprocessing_molmo.py +189 -0
preprocessing_molmo.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Processor class for Molmo.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import PIL
|
8 |
+
from PIL import ImageOps
|
9 |
+
from PIL.Image import Image
|
10 |
+
|
11 |
+
try:
|
12 |
+
from typing import Unpack
|
13 |
+
except ImportError:
|
14 |
+
from typing_extensions import Unpack
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from transformers.image_utils import ImageInput
|
20 |
+
from transformers.processing_utils import (
|
21 |
+
TextKwargs,
|
22 |
+
ProcessingKwargs,
|
23 |
+
ProcessorMixin,
|
24 |
+
)
|
25 |
+
|
26 |
+
from transformers.tokenization_utils_base import TextInput
|
27 |
+
from transformers.utils import logging
|
28 |
+
|
29 |
+
from transformers import AutoTokenizer
|
30 |
+
from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor
|
31 |
+
|
32 |
+
|
33 |
+
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
# Special token strings used to mark up images inside the token stream.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"

# All extra tokens, in the order they are encoded by ``get_special_token_ids``.
EXTRA_TOKENS = (
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_COL_TOKEN,
    IMAGE_PROMPT,
)
|
43 |
+
|
44 |
+
|
45 |
+
def get_special_token_ids(tokenizer):
    """Map every entry of ``EXTRA_TOKENS`` to its id in ``tokenizer``.

    The extra tokens are concatenated into one string and encoded in a single
    call, so each of them must encode to exactly one id.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and returning a sequence of ids.

    Returns:
        A dict from token string to token id, in ``EXTRA_TOKENS`` order.

    Raises:
        ValueError: If the tokenizer does not produce exactly one id per
            extra token.
    """
    ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``
    # and ``zip`` would then silently truncate to a wrong mapping.
    if len(ids) != len(EXTRA_TOKENS):
        raise ValueError(
            f"Expected {len(EXTRA_TOKENS)} ids for the special tokens "
            f"{EXTRA_TOKENS}, but the tokenizer produced {len(ids)}"
        )
    return dict(zip(EXTRA_TOKENS, ids))
|
49 |
+
|
50 |
+
|
51 |
+
class MolmoTextKwargs(TextKwargs, total=False):
    """Text-side keyword arguments for the Molmo processor.

    Extends ``TextKwargs`` with Molmo-specific, all-optional keys.
    """

    # NOTE(review): ``style`` and ``system_prompt`` are declared here but are
    # not read anywhere in this module -- confirm where they are consumed.
    style: Optional[str]
    system_prompt: Optional[str]
    # "none"/None leaves the prompt untouched; "role" wraps it as
    # "User: ... Assistant:" (see MolmoProcessor.get_tokens_input).
    message_format: Optional[str]
    # When truthy, a single space is prepended to the prompt before encoding.
    always_start_with_space: Optional[bool]
    # Forwarded to the image processor's ``multimodal_preprocess``.
    sequence_length: Optional[int]
|
57 |
+
|
58 |
+
|
59 |
+
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema and default values for ``MolmoProcessor``.

    ``MolmoProcessor.process`` passes this class to ``_merge_kwargs``; the
    resulting ``images_kwargs`` are forwarded to the image processor and the
    ``text_kwargs`` drive prompt formatting and the sequence length.
    """

    text_kwargs: MolmoTextKwargs
    images_kwargs: MolmoImagesKwargs

    # Default values, grouped per modality.
    _defaults = {
        "images_kwargs": {
            "max_crops": 12,
            "overlap_margins": [4, 4],
            "base_image_input_size": [336, 336],
            "image_token_length_w": 12,
            "image_token_length_h": 12,
            "image_patch_size": 14,
            "image_padding_mask": True,
        },
        "text_kwargs": {
            "style": "long_caption",
            "system_prompt": "none",
            "message_format": "role",
            "always_start_with_space": True,
            "sequence_length": 1536,
            "padding": False,
        },
    }
|
81 |
+
|
82 |
+
|
83 |
+
class MolmoProcessor(ProcessorMixin):
    """Processor pairing a Molmo image processor with a tokenizer.

    ``process`` tokenizes a text prompt, converts any images to RGB arrays,
    delegates to ``image_processor.multimodal_preprocess``, prepends a BOS
    token and returns the result with its values converted to torch tensors.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
        """Store the image processor and tokenizer via ``ProcessorMixin``.

        Args:
            image_processor: Image processor exposing ``multimodal_preprocess``.
            tokenizer: Tokenizer used to encode prompts and special tokens.
            **kwargs: Accepted for call compatibility; not forwarded to the
                base class.
        """
        super().__init__(image_processor, tokenizer)
        # Filled lazily by the ``special_token_ids`` property.
        self._special_tokens = None

    @property
    def special_token_ids(self):
        """Mapping from special token string to id, computed on first access."""
        if self._special_tokens is None:
            self._special_tokens = get_special_token_ids(self.tokenizer)
        return self._special_tokens

    def get_tokens_input(self, prompt, message_format, always_start_with_space):
        """Format ``prompt`` and encode it without tokenizer special tokens.

        Args:
            prompt: The user text.
            message_format: ``"none"``/``None`` to use the prompt as-is, or
                ``"role"`` to wrap it as ``"User: <prompt> Assistant:"``.
            always_start_with_space: If truthy, prepend a single space.

        Returns:
            The token ids returned by ``tokenizer.encode``.

        Raises:
            NotImplementedError: For any other ``message_format``.
        """
        if message_format == "role":
            prompt = "User: " + prompt + " Assistant:"
        elif message_format is not None and message_format != "none":
            raise NotImplementedError(f"Message format {message_format} not implemented")

        if always_start_with_space:
            prompt = " " + prompt

        return self.tokenizer.encode(prompt, add_special_tokens=False)

    @staticmethod
    def _image_to_array(image):
        """Convert one PIL image or array-like image to an RGB numpy array."""
        if isinstance(image, Image):
            # PIL ignores the EXIF orientation tag by default, so apply it
            # explicitly and use the transposed result for the pixel data.
            # https://github.com/python-pillow/Pillow/issues/4703
            image = ImageOps.exif_transpose(image)
            return np.array(image.convert("RGB"))
        if len(image.shape) != 3 or image.shape[-1] != 3:
            raise ValueError(
                "Expected a 3-dimensional image array with 3 channels in the "
                f"last dimension, got shape {tuple(image.shape)}"
            )
        return image.astype(np.uint8)

    def process(
        self,
        text: TextInput = None,
        images: ImageInput = None,
        **kwargs: Unpack[MolmoProcessorKwargs],
    ):
        """Build model inputs from a text prompt and optional images.

        Args:
            text: Prompt text, formatted according to the ``message_format``
                and ``always_start_with_space`` text kwargs.
            images: A single image or a list/tuple of images. PIL images are
                EXIF-transposed and converted to RGB; other inputs must be
                arrays with three dimensions and 3 channels last.
            **kwargs: Overrides merged with ``MolmoProcessorKwargs`` defaults.

        Returns:
            The mapping produced by ``image_processor.multimodal_preprocess``
            with every value converted by ``torch.from_numpy``. ``input_ids``
            has one leading BOS token and, when present, the non-negative
            entries of ``image_input_idx`` are shifted by one to match.

        Raises:
            NotImplementedError: If the message format is not supported.
            ValueError: If a non-PIL image does not have shape ``(*, *, 3)``.
        """
        output_kwargs = self._merge_kwargs(
            MolmoProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        text_kwargs = output_kwargs["text_kwargs"]

        tokens = self.get_tokens_input(
            text,
            text_kwargs["message_format"],
            text_kwargs["always_start_with_space"],
        )

        if images is not None:
            if not isinstance(images, (list, tuple)):
                images = [images]
            images = [self._image_to_array(image) for image in images]
            # For now only support inserting images at the start
            image_idx = [-1] * len(images)
        else:
            image_idx = None

        special_ids = self.special_token_ids
        out = self.image_processor.multimodal_preprocess(
            images=images,
            image_idx=image_idx,
            tokens=np.asarray(tokens).astype(np.int32),
            sequence_length=text_kwargs["sequence_length"],
            image_patch_token_id=special_ids[DEFAULT_IMAGE_PATCH_TOKEN],
            image_col_token_id=special_ids[DEFAULT_IM_COL_TOKEN],
            image_start_token_id=special_ids[DEFAULT_IM_START_TOKEN],
            image_end_token_id=special_ids[DEFAULT_IM_END_TOKEN],
            **output_kwargs["images_kwargs"]
        )

        # Prepend BOS
        # qwen2 and olmo do not have a BOS, and instead use EOS as a generic separator token.
        # Compare against None explicitly so that a legitimate BOS id of 0 is kept.
        bos = self.tokenizer.bos_token_id
        if bos is None:
            bos = self.tokenizer.eos_token_id
        out["input_ids"] = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
        if "image_input_idx" in out:
            # Shift patch mapping up by one since we added BOS; negative
            # entries are left unchanged.
            image_input_idx = out["image_input_idx"]
            out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)

        # Only values are replaced, so updating while iterating is safe.
        for key, value in out.items():
            out[key] = torch.from_numpy(value)

        return out
|
187 |
+
|
188 |
+
|
189 |
+
# Register MolmoProcessor with the transformers auto-class machinery at import time.
# NOTE(review): presumably this lets the auto processor class resolve this custom
# processor when loading from the Hub -- confirm against the ProcessorMixin docs.
MolmoProcessor.register_for_auto_class()
|