|
import cv2 |
|
import numpy as np |
|
import torch |
|
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD |
|
from timm.data.transforms import RandomResizedCropAndInterpolation |
|
from torchvision import transforms |
|
import urllib |
|
from tqdm import tqdm |
|
from cpm_live.tokenizers import CPMBeeTokenizer |
|
from torch.utils.data import default_collate |
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
from typing_extensions import TypedDict |
|
from numpy.typing import NDArray |
|
import importlib.machinery |
|
import importlib.util |
|
import types |
|
import random |
|
|
|
|
|
# Recursive input type used by convert_data_to_id: either a string leaf or a
# dict mapping string keys to nested inputs (hence the forward reference).
CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
|
|
|
|
|
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
    """Collate the tensors stored under ``key`` in ``orig_items`` into one batch tensor.

    ``item[key]`` is either a tensor or a list of tensors; lists are flattened
    so that every tensor becomes its own batch row.

    * 1-D entries are concatenated along dim 0 (no padding).
    * 2-D entries are expected to be ``(1, seq)``. If every sequence already
      has the target length they are concatenated as-is; otherwise row 0 of
      each is copied into a ``(batch, target_len)`` tensor.
    * 3-D entries are expected to be ``(1, seq, feat)``; row 0 of each is
      copied into a ``(batch, target_len, feat)`` tensor.

    ``target_len`` is the larger of ``max_length`` (if given) and the longest
    sequence.

    Args:
        orig_items: list of dict-like samples.
        key: which entry of each sample to collate.
        max_length: minimum padded length; ``None`` means "longest sequence".
        padding_value: value written into padded positions.
        padding_side: ``"left"`` pads before the data, anything else after.

    Returns:
        The collated ``torch.Tensor``.
    """
    if isinstance(orig_items[0][key], list):
        assert isinstance(orig_items[0][key][0], torch.Tensor)
        items = [{key: tr} for it in orig_items for tr in it[key]]
    else:
        assert isinstance(orig_items[0][key], torch.Tensor)
        items = orig_items

    batch_size = len(items)
    shape = items[0][key].shape
    dim = len(shape)
    assert 1 <= dim <= 3
    dtype = items[0][key].dtype

    if dim == 1:
        return torch.cat([item[key] for item in items], dim=0)

    # The padded (sequence) axis is the last one for 2-D entries but the
    # second-to-last for 3-D entries, whose last axis is the feature size.
    seq_axis = -1 if dim == 2 else -2
    seq_lengths = [item[key].shape[seq_axis] for item in items]
    if max_length is None:
        max_length = 0
    max_length = max(max_length, max(seq_lengths))

    if dim == 2:
        if max_length == min(seq_lengths):
            return torch.cat([item[key] for item in items], dim=0)
        # `zeros + value` (rather than torch.full) keeps the original dtype
        # promotion behaviour for callers.
        tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
    else:
        tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value

    for i, item in enumerate(items):
        row = item[key][0]
        n = len(row)
        if padding_side == "left":
            # Start at max_length - n rather than -n: `-0:` would select the
            # whole row and break on empty sequences.
            tensor[i, max_length - n:] = row
        else:
            tensor[i, :n] = row

    return tensor
|
|
|
|
|
class CPMBeeCollater:
    """
    Collate CPM-Bee samples into a padded batch; counterpart of cpm-live's
    ``_MixedDatasetBatchPacker``.

    The native torch DataLoader is not well suited to being adapted for
    in-context learning. The original packer also performs a ``best_fit`` step
    to maximise the fraction of useful tokens; that is not supported here
    either.

    TODO(@wangchongyi): rewrite the DataLoader or BatchPacker.
    """

    def __init__(self, tokenizer: CPMBeeTokenizer, max_len):
        """
        Args:
            tokenizer: provides ``vocab_size``, ``bos_id`` and ``eos_id``.
            max_len: padded length of every per-token output array.
        """
        self.tokenizer = tokenizer
        self._max_length = max_len
        # Per-sample entries padded on the right with 0 by ``pad``.
        self.pad_keys = ['input_ids', 'input_id_subs', 'context', 'segment_ids', 'segment_rel_offset',
                         'segment_rel', 'sample_ids', 'num_segments']

    def __call__(self, batch):
        """Build a batch dict from a list of per-sample dicts.

        Each sample must carry every key in ``self.pad_keys`` plus a
        ``raw_data`` list; ``pixel_values`` and ``image_bound`` are collated
        when the first sample has them. Row 0 of ``input_ids``,
        ``input_id_subs`` and ``context`` is read as the token sequence.

        Returns:
            dict with the padded ``pad_keys`` tensors (``context`` converted
            to a boolean mask) plus ``length``, ``span``, ``target``,
            ``ext_table_ids``, ``ext_table_sub`` and ``raw_data``.
        """
        batch_size = len(batch)

        # -100 marks positions that carry no prediction target.
        tgt = np.full((batch_size, self._max_length), -100, dtype=np.int32)
        # There is no best_fit packing yet, so every span id is 0 (this
        # matches the cpm-bee implementation).
        span = np.zeros((batch_size, self._max_length), dtype=np.int32)
        length = np.zeros((batch_size,), dtype=np.int32)

        batch_ext_table_map: Dict[Tuple[int, int], int] = {}
        batch_ext_table_ids: List[int] = []
        batch_ext_table_sub: List[int] = []
        raw_data_list: List[Any] = []

        for i, sample in enumerate(batch):
            # Materialise row 0 as Python ints once. Besides avoiding per-token
            # tensor indexing, this makes the (id, sub-id) de-duplication below
            # work: torch tensors hash by identity, so tensor-valued keys never
            # matched an existing entry.
            input_ids = sample['input_ids'][0].tolist()
            input_id_subs = sample['input_id_subs'][0].tolist()
            context = sample['context'][0].tolist()
            instance_length = len(input_ids)
            length[i] = instance_length
            raw_data_list.extend(sample['raw_data'])

            for j in range(instance_length):
                idx, idx_sub = input_ids[j], input_id_subs[j]
                tgt_idx = idx
                if idx_sub > 0:
                    # Tokens with a non-zero sub-id go into the batch-level
                    # extension table; their target id lies past the vocabulary.
                    ext_key = (idx, idx_sub)
                    if ext_key not in batch_ext_table_map:
                        batch_ext_table_map[ext_key] = len(batch_ext_table_map)
                        batch_ext_table_ids.append(idx)
                        batch_ext_table_sub.append(idx_sub)
                    tgt_idx = batch_ext_table_map[ext_key] + self.tokenizer.vocab_size
                # Position j-1 predicts token j when j-1 is not context; a <bos>
                # (start of the next segment) is predicted as <eos> instead.
                if j > 1 and context[j - 1] == 0:
                    if idx != self.tokenizer.bos_id:
                        tgt[i, j - 1] = tgt_idx
                    else:
                        tgt[i, j - 1] = self.tokenizer.eos_id
            # A trailing non-context position predicts <eos>.
            if instance_length > 0 and context[instance_length - 1] == 0:
                tgt[i, instance_length - 1] = self.tokenizer.eos_id

        if len(batch_ext_table_map) == 0:
            # Placeholder entry so the extension table is never empty.
            batch_ext_table_ids.append(0)
            batch_ext_table_sub.append(1)

        # Image tensors.
        if 'pixel_values' in batch[0]:
            data = {'pixel_values': default_collate([i['pixel_values'] for i in batch])}
        else:
            data = {}

        # Token ranges of image segments.
        if 'image_bound' in batch[0]:
            data['image_bound'] = default_collate([i['image_bound'] for i in batch])

        # CPM-Bee model inputs.
        for key in self.pad_keys:
            data[key] = pad(batch, key, max_length=self._max_length, padding_value=0, padding_side='right')

        data['context'] = data['context'] > 0
        data['length'] = torch.from_numpy(length)
        data['span'] = torch.from_numpy(span)
        data['target'] = torch.from_numpy(tgt)
        # Explicit int32 (the dtype of the other token-id arrays) so the result
        # no longer depends on whether the entries were tensors or ints.
        data['ext_table_ids'] = torch.from_numpy(np.array(batch_ext_table_ids, dtype=np.int32))
        data['ext_table_sub'] = torch.from_numpy(np.array(batch_ext_table_sub, dtype=np.int32))
        data['raw_data'] = raw_data_list

        return data
|
|
|
|
|
class _DictTree(TypedDict): |
|
value: str |
|
children: List["_DictTree"] |
|
depth: int |
|
segment_id: int |
|
need_predict: bool |
|
is_image: bool |
|
|
|
|
|
class _PrevExtTableStates(TypedDict): |
|
ext_table: Dict[int, str] |
|
token_id_table: Dict[str, Dict[int, int]] |
|
|
|
|
|
class _TransformFuncDict(TypedDict): |
|
loader: importlib.machinery.SourceFileLoader |
|
module: types.ModuleType |
|
last_m: float |
|
|
|
|
|
# Signature of a data-transform callable: (input, int, random.Random) -> input.
# NOTE(review): the meaning of the int argument is not visible here
# (presumably a sample or step index) — confirm against callers.
_TransformFunction = Callable[[CPMBeeInputType, int, random.Random], CPMBeeInputType]
|
|
|
|
|
# Field schema of a packed CPM-Bee batch: NumPy arrays plus task names and the
# raw samples they came from.
CPMBeeBatch = TypedDict("CPMBeeBatch", {
    "inputs": NDArray[np.int32],
    "inputs_sub": NDArray[np.int32],
    "length": NDArray[np.int32],
    "context": NDArray[np.bool_],
    "sample_ids": NDArray[np.int32],
    "num_segments": NDArray[np.int32],
    "segment_ids": NDArray[np.int32],
    "segment_rel_offset": NDArray[np.int32],
    "segment_rel": NDArray[np.int32],
    "spans": NDArray[np.int32],
    "target": NDArray[np.int32],
    "ext_ids": NDArray[np.int32],
    "ext_sub": NDArray[np.int32],
    "task_ids": NDArray[np.int32],
    "task_names": List[str],
    "raw_data": List[Any],
})
|
|
|
|
|
def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
    """Map a (levels up, levels down) tree distance to a relative-position bucket.

    The pair is flattened as ``n_up * max_depth + n_down``. A flattened value
    of 0 is returned as-is; every other value is shifted up by one.
    """
    bucket = n_up * max_depth + n_down
    return bucket if bucket == 0 else bucket + 1
|
|
|
|
|
def convert_data_to_id(
    tokenizer: CPMBeeTokenizer,
    data: Any,
    prev_ext_states: Optional[_PrevExtTableStates] = None,
    shuffle_answer: bool = True,
    max_depth: int = 8
):
    """Flatten a nested CPM-Bee input into token ids plus per-token metadata.

    ``data`` is a string or a dict with string keys and nested values. Every
    dict key and every string leaf becomes one segment. Segments are numbered
    in pre-order, with the synthetic ``<root>`` as segment 0. Leaves under the
    top-level ``"<ans>"`` key are the segments to predict. Leaves under the
    top-level ``"image"`` key are image segments.

    Args:
        tokenizer: provides ``encode``, ``encoder``, ``bos_id`` and ``eos_id``.
        data: the (possibly nested) input.
        prev_ext_states: extension-table state returned by an earlier call.
            When given, its tables are reused, and ``token_id_table`` is
            updated in place.
        shuffle_answer: randomly permute (global NumPy RNG) the key order of
            dicts under ``"<ans>"``.
        max_depth: depth differences are clipped to ``max_depth - 1`` before
            being bucketed by ``rel_to_bucket``.

    Returns:
        Tuple ``(ids, id_subs, context, segs, segment_rel, num_segments,
        curr_ext_table_states, image_bound)``:

        * ``ids`` / ``id_subs``: int32 token ids and sub-ids. Extended tokens
          are replaced by the id of their generic name (``<name>`` or
          ``<unk>``). The sub-id numbers distinct tokens of one name in
          first-seen order, starting at 0.
        * ``context``: int8, 1 for tokens of non-predicted segments.
        * ``segs``: int32 segment index of each token.
        * ``segment_rel``: flattened ``num_segments x num_segments`` int32
          matrix of bucketed tree distances between segments.
        * ``num_segments``: number of segments including ``<root>``.
        * ``curr_ext_table_states``: state to pass back as ``prev_ext_states``.
        * ``image_bound``: int32 ``(begin, end)`` token ranges of image
          segments, excluding their ``<bos>`` and ``<eos>``.
    """
    root: _DictTree = {
        "value": "<root>",
        "children": [],
        "depth": 0,
        "segment_id": 0,
        "need_predict": False,
        "is_image": False
    }

    # Flattened list of all tree nodes; a node's index is its segment_id.
    segments = [root]

    def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
        # Build the nodes for `data` at `depth` and append them to `segments`
        # (pre-order: a key is registered before its children).
        if isinstance(data, dict):
            ret_list: List[_DictTree] = []
            curr_items = list(data.items())
            if need_predict and shuffle_answer:
                access_idx = np.arange(len(curr_items))
                np.random.shuffle(access_idx)
                curr_items = [curr_items[idx] for idx in access_idx]
            for k, v in curr_items:
                child_info: _DictTree = {
                    "value": k,
                    "children": [],
                    "depth": depth,
                    "segment_id": len(segments),
                    "need_predict": False,  # keys are never predicted; only leaves are
                    "is_image": False,
                }
                segments.append(child_info)
                child_info["children"] = _build_dict_tree(
                    v, depth + 1,
                    # Everything below the top-level "<ans>" key is to be predicted.
                    need_predict=need_predict or (depth == 1 and k == "<ans>"),
                    # Everything below the top-level "image" key is an image segment.
                    is_image=is_image or (depth == 1 and k == "image")
                )

                ret_list.append(child_info)
            return ret_list
        else:
            assert isinstance(data, str), "Invalid data {}".format(data)
            ret: _DictTree = {
                "value": data,
                "children": [],
                "depth": depth,
                "segment_id": len(segments),
                "need_predict": need_predict,
                "is_image": is_image,
            }
            segments.append(ret)
            return [ret]

    root["children"] = _build_dict_tree(data, 1, False, False)

    num_segments = len(segments)
    # Row-major flattened (num_segments, num_segments) matrix. The diagonal is
    # never written and stays 0.
    segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)

    def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
        # Return (segment_id, depth) for every node in this subtree. Each
        # child's subtree is paired with the node itself and with the subtrees
        # of earlier children, so distances are measured through `node`
        # (their lowest common ancestor).
        ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
        for child in node["children"]:
            sub = _build_segment_rel(child)
            for seg_id_1, depth_1 in sub:
                for seg_id_2, depth_2 in ret:
                    # Levels from seg_id_1 up to `node`, and from `node` down to seg_id_2.
                    n_up = min(depth_1 - node["depth"], max_depth - 1)
                    n_down = min(depth_2 - node["depth"], max_depth - 1)
                    segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
                        n_up, n_down, max_depth=max_depth
                    )
                    segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
                        n_down, n_up, max_depth=max_depth
                    )
            ret.extend(sub)
        return ret

    _build_segment_rel(root)

    input_ids: List[int] = []
    input_id_subs: List[int] = []
    segment_bound: List[Tuple[int, int]] = []
    image_bound: List[Tuple[int, int]] = []

    ext_table: Dict[int, str] = {}
    token_id_table: Dict[str, Dict[int, int]] = {}

    if prev_ext_states is not None:
        ext_table = prev_ext_states["ext_table"]
        token_id_table = prev_ext_states["token_id_table"]

    for seg in segments:
        tokens, ext_table = tokenizer.encode(seg["value"], ext_table)

        token_id_subs = []
        reid_token_ids = []
        for idx in tokens:
            if idx in ext_table:
                # Extended token: map it to a generic name.
                token = ext_table[idx]
                if token.startswith("<") and token.endswith(">"):
                    # "<name_suffix>" / "<name>" -> "<name>"
                    if "_" in token:
                        token_name = token[1:-1].split("_", maxsplit=1)[0]
                    else:
                        token_name = token[1:-1]
                    token_name = "<{}>".format(token_name)
                else:
                    token_name = "<unk>"

                # Sub-id: first-seen index of this id among tokens of the same name.
                if token_name not in token_id_table:
                    token_id_table[token_name] = {}
                if idx not in token_id_table[token_name]:
                    token_id_table[token_name][idx] = len(token_id_table[token_name])
                if token_name not in tokenizer.encoder:
                    raise ValueError("Invalid token {}".format(token))
                reid_token_ids.append(tokenizer.encoder[token_name])
                token_id_subs.append(token_id_table[token_name][idx])
            else:
                reid_token_ids.append(idx)
                token_id_subs.append(0)
        tokens = [tokenizer.bos_id] + reid_token_ids
        token_id_subs = [0] + token_id_subs
        if not seg["need_predict"]:
            tokens = tokens + [tokenizer.eos_id]
            token_id_subs = token_id_subs + [0]
        else:
            # Predicted segments get no trailing <eos> here.
            pass
        begin = len(input_ids)
        input_ids.extend(tokens)
        input_id_subs.extend(token_id_subs)
        end = len(input_ids)
        segment_bound.append((begin, end))

    ids = np.array(input_ids, dtype=np.int32)
    id_subs = np.array(input_id_subs, dtype=np.int32)
    segs = np.zeros((ids.shape[0],), dtype=np.int32)
    context = np.zeros((ids.shape[0],), dtype=np.int8)
    for i, (begin, end) in enumerate(segment_bound):
        if not segments[i]["need_predict"]:
            context[begin:end] = 1
        if segments[i]["is_image"]:
            # Skip the segment's <bos> and <eos>.
            image_bound.append((begin+1, end-1))
        segs[begin:end] = i

    curr_ext_table_states: _PrevExtTableStates = {
        "ext_table": ext_table,
        "token_id_table": token_id_table,
    }
    # NOTE(review): with no image segments this array has shape (0,), not (0, 2).
    image_bound = np.array(image_bound, dtype=np.int32)
    return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
|
|
|
|
|
|
|
def identity_func(img):
    '''No-op augmentation: hand back ``img`` itself (not a copy).'''
    return img
|
|
|
|
|
def autocontrast_func(img, cutoff=0):
    '''
    Stretch every channel so its intensity range spans 0..255.

    Mirrors PIL.ImageOps.autocontrast. ``cutoff`` is the percentage of pixels
    ignored at each end of a channel's histogram before stretching. A channel
    whose (trimmed) range is empty is passed through an identity table.
    '''
    n_bins = 256

    def tune_channel(ch):
        cut = cutoff * ch.size // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            above_low = np.argwhere(np.cumsum(hist) > cut)
            low = above_low[0] if above_low.shape[0] else 0
            above_high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 - above_high[0] if above_high.shape[0] else n_bins - 1
        if high <= low:
            lut = np.arange(n_bins)
        else:
            gain = (n_bins - 1) / (high - low)
            # `- low * gain` (not `-low`): `low` may be a uint8 scalar.
            lut = np.arange(n_bins) * gain - low * gain
            lut[lut < 0] = 0
            lut[lut > n_bins - 1] = n_bins - 1
        lut = lut.clip(0, 255).astype(np.uint8)
        return lut[ch]

    return cv2.merge([tune_channel(ch) for ch in cv2.split(img)])
|
|
|
|
|
def equalize_func(img):
    '''
    Histogram-equalize every channel, reproducing PIL.ImageOps.equalize
    (whose lookup table differs from cv2.equalizeHist).

    A channel whose step size (sum of the non-empty histogram bins except the
    last, floor-divided by 255) is 0 is returned unchanged.
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        used_bins = hist[hist != 0].reshape(-1)
        step = np.sum(used_bins[:-1]) // (n_bins - 1)
        if step == 0:
            return ch
        shifted = np.empty_like(hist)
        shifted[0] = step // 2
        shifted[1:] = hist[:-1]
        lut = (np.cumsum(shifted) // step).clip(0, 255).astype(np.uint8)
        return lut[ch]

    return cv2.merge(list(map(tune_channel, cv2.split(img))))
|
|
|
|
|
def rotate_func(img, degree, fill=(0, 0, 0)):
    '''
    Rotate ``img`` about its centre by ``degree`` degrees (like PIL: degrees,
    not radians). The output keeps the input size, and uncovered pixels are
    set to ``fill``.
    '''
    height, width = img.shape[0], img.shape[1]
    rotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
    return cv2.warpAffine(img, rotation, (width, height), borderValue=fill)
|
|
|
|
|
def solarize_func(img, thresh=128):
    '''
    Invert every value at or above ``thresh`` (v -> 255 - v) and keep the
    rest; mirrors PIL.ImageOps.solarize.
    '''
    levels = np.arange(256)
    lut = np.where(levels < thresh, levels, 255 - levels)
    lut = lut.clip(0, 255).astype(np.uint8)
    return lut[img]
|
|
|
|
|
def color_func(img, factor):
    '''
    Blend each pixel with its luma, ``luma + factor * (pixel - luma)``
    (0 -> luma in every channel, 1 -> unchanged). Mirrors
    PIL.ImageEnhance.Color.

    The blend is folded into one 3x3 colour matrix, so the whole operation is
    a single matmul.

    NOTE(review): the luma weights (0.114, 0.587, 0.299) are applied to
    channels 0, 1, 2, which is BT.601 in BGR order. RGB input (e.g.
    np.array of a PIL image) gets its red and blue weights swapped.
    '''
    # Entry (i, j) of the final matrix is (delta_ij - w_i) * factor + w_i.
    centred = np.float32([
        [0.886, -0.114, -0.114],
        [-0.587, 0.413, -0.587],
        [-0.299, -0.299, 0.701],
    ])
    luma = np.float32([[0.114], [0.587], [0.299]])
    colour_matrix = centred * factor + luma
    blended = np.matmul(img, colour_matrix)
    return blended.clip(0, 255).astype(np.uint8)
|
|
|
|
|
def contrast_func(img, factor):
    """
    Scale each value's distance from the image's mean luma by ``factor``
    (0 -> constant mean, 1 -> unchanged). Mirrors PIL.ImageEnhance.Contrast.

    NOTE(review): the luma weights (0.114, 0.587, 0.299) are applied to
    channels 0, 1, 2, which is BT.601 in BGR order. RGB input gets its red and
    blue weights swapped.
    """
    luma_weights = np.array([0.114, 0.587, 0.299])
    channel_means = np.mean(img, axis=(0, 1))
    mean = np.sum(channel_means * luma_weights)
    levels = np.arange(256)
    lut = ((levels - mean) * factor + mean).clip(0, 255).astype(np.uint8)
    return lut[img]
|
|
|
|
|
def brightness_func(img, factor):
    '''
    Multiply every value by ``factor``, then clip to 0..255 and truncate.
    Mirrors PIL.ImageEnhance.Brightness.
    '''
    scaled = np.arange(256, dtype=np.float32) * factor
    lut = scaled.clip(0, 255).astype(np.uint8)
    return lut[img]
|
|
|
|
|
def sharpness_func(img, factor):
    '''
    Blend ``img`` with a smoothed copy of itself (the PIL.ImageEnhance.Sharpness
    scheme).

    factor 0 returns the smoothed image, 1 returns ``img`` itself, and values
    above 1 extrapolate away from the smoothed image. For factors other than
    0 and 1 only the interior is blended; the one-pixel border keeps the input
    values, which is where this differs from PIL. Interior pixels match.
    '''
    # 3x3 smoothing kernel: centre weight 5, neighbours 1, normalised by 13.
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        return degenerate
    if factor == 1.0:
        return img
    out = img.astype(np.float32)
    # `...` keeps any channel axis, so 2-D (single-channel) images work too.
    inner = degenerate.astype(np.float32)[1:-1, 1:-1, ...]
    out[1:-1, 1:-1, ...] = inner + factor * (out[1:-1, 1:-1, ...] - inner)
    # Clip before casting: extrapolation (factor > 1) can leave 0..255, and an
    # out-of-range float -> uint8 cast wraps instead of saturating.
    return out.clip(0, 255).astype(np.uint8)
|
|
|
|
|
def shear_x_func(img, factor, fill=(0, 0, 0)):
    '''
    Shear ``img`` horizontally with the affine matrix
    [[1, factor, 0], [0, 1, 0]] (bilinear interpolation, same output size).
    Uncovered pixels are set to ``fill``.
    '''
    rows, cols = img.shape[0], img.shape[1]
    shear = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(img, shear, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
|
|
|
|
|
def translate_x_func(img, offset, fill=(0, 0, 0)):
    '''
    Shift the content of ``img`` by ``-offset`` pixels along x, matching
    PIL.Image.transform. The output keeps the input size; uncovered pixels
    are set to ``fill``.
    '''
    rows, cols = img.shape[0], img.shape[1]
    shift = np.float32([[1, 0, -offset], [0, 1, 0]])
    warped = cv2.warpAffine(img, shift, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
|
|
|
|
|
def translate_y_func(img, offset, fill=(0, 0, 0)):
    '''
    Shift the content of ``img`` by ``-offset`` pixels along y, matching
    PIL.Image.transform. The output keeps the input size; uncovered pixels
    are set to ``fill``.
    '''
    rows, cols = img.shape[0], img.shape[1]
    shift = np.float32([[1, 0, 0], [0, 1, -offset]])
    warped = cv2.warpAffine(img, shift, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
|
|
|
|
|
def posterize_func(img, bits):
    '''
    Keep only the ``bits`` most significant bits of every value (0 <= bits <= 8);
    same output as PIL.ImageOps.posterize.
    '''
    # Mask the shifted value to 8 bits before building the uint8: for any
    # bits < 8 it exceeds 255, which NumPy >= 2 rejects with OverflowError and
    # older NumPy only wraps with a DeprecationWarning.
    mask = np.uint8((255 << (8 - bits)) & 0xFF)
    return np.bitwise_and(img, mask)
|
|
|
|
|
def shear_y_func(img, factor, fill=(0, 0, 0)):
    '''
    Shear ``img`` vertically with the affine matrix
    [[1, 0, 0], [factor, 1, 0]] (bilinear interpolation, same output size).
    Uncovered pixels are set to ``fill``.
    '''
    rows, cols = img.shape[0], img.shape[1]
    shear = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(img, shear, (cols, rows), borderValue=fill, flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
|
|
|
|
|
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    '''
    Return a copy of ``img`` in which a square patch of side ``pad_size // 2 * 2``
    (clipped at the image edges) is overwritten with ``replace``. The patch is
    centred at a pixel drawn from the global NumPy RNG.
    '''
    fill = np.array(replace, dtype=np.uint8)
    height, width = img.shape[0], img.shape[1]
    frac_row, frac_col = np.random.random(2)
    half = pad_size // 2
    centre_row, centre_col = int(frac_row * height), int(frac_col * width)
    top, bottom = max(centre_row - half, 0), min(centre_row + half, height)
    left, right = max(centre_col - half, 0), min(centre_col + half, width)
    patched = img.copy()
    patched[top:bottom, left:right, :] = fill
    return patched
|
|
|
|
|
|
|
def enhance_level_to_args(MAX_LEVEL):
    """Return a mapper from a level to a 1-tuple enhancement factor.

    Levels 0..MAX_LEVEL map linearly onto factors 0.1..1.9.
    """
    def level_to_args(level):
        factor = level / MAX_LEVEL * 1.8 + 0.1
        return (factor,)
    return level_to_args
|
|
|
|
|
def shear_level_to_args(MAX_LEVEL, replace_value):
    """Return a mapper from a level to ``(shear_factor, replace_value)``.

    The magnitude is ``level / MAX_LEVEL * 0.3``. It is negated when a
    global-RNG draw exceeds 0.5.
    """
    def level_to_args(level):
        magnitude = level / MAX_LEVEL * 0.3
        negate = np.random.random() > 0.5
        return (-magnitude if negate else magnitude, replace_value)
    return level_to_args
|
|
|
|
|
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Return a mapper from a level to ``(pixel_offset, replace_value)``.

    The magnitude is ``level / MAX_LEVEL * translate_const``. It is negated
    when a global-RNG draw exceeds 0.5.
    """
    def level_to_args(level):
        magnitude = level / MAX_LEVEL * float(translate_const)
        negate = np.random.random() > 0.5
        return (-magnitude if negate else magnitude, replace_value)
    return level_to_args
|
|
|
|
|
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Return a mapper from a level to ``(patch_size, replace_value)``.

    The patch size is ``int(level / MAX_LEVEL * cutout_const)``.
    """
    def level_to_args(level):
        size = int(level / MAX_LEVEL * cutout_const)
        return (size, replace_value)
    return level_to_args
|
|
|
|
|
def solarize_level_to_args(MAX_LEVEL):
    """Return a mapper from a level to a 1-tuple threshold in 0..256."""
    def level_to_args(level):
        threshold = int(level / MAX_LEVEL * 256)
        return (threshold,)
    return level_to_args
|
|
|
|
|
def none_level_to_args(level):
    """Level mapper for parameterless ops: ignore ``level``, supply no arguments."""
    return tuple()
|
|
|
|
|
def posterize_level_to_args(MAX_LEVEL):
    """Return a mapper from a level to a 1-tuple bit count in 0..4."""
    def level_to_args(level):
        kept_bits = int(level / MAX_LEVEL * 4)
        return (kept_bits,)
    return level_to_args
|
|
|
|
|
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """Return a mapper from a level to ``(degrees, replace_value)``.

    The magnitude is ``level / MAX_LEVEL * 30`` degrees. It is negated when a
    global-RNG draw is below 0.5.
    """
    def level_to_args(level):
        degrees = level / MAX_LEVEL * 30
        negate = np.random.random() < 0.5
        return (-degrees if negate else degrees, replace_value)
    return level_to_args
|
|
|
|
|
# Augmentation name -> image operation. RandomAugment calls each op as
# func_dict[name](img, *arg_dict[name](level)).
func_dict = dict(
    Identity=identity_func,
    AutoContrast=autocontrast_func,
    Equalize=equalize_func,
    Rotate=rotate_func,
    Solarize=solarize_func,
    Color=color_func,
    Contrast=contrast_func,
    Brightness=brightness_func,
    Sharpness=sharpness_func,
    ShearX=shear_x_func,
    TranslateX=translate_x_func,
    TranslateY=translate_y_func,
    Posterize=posterize_func,
    ShearY=shear_y_func,
)
|
|
|
# Maximum translation in pixels, the top of the magnitude scale used by every
# level_to_args factory, and the fill colour for geometric ops.
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)

# Augmentation name -> function turning a magnitude (level) into the extra
# positional arguments of the same-named image op. Key order matters: it is
# the default sampling order of RandomAugment.
arg_dict = dict(
    Identity=none_level_to_args,
    AutoContrast=none_level_to_args,
    Equalize=none_level_to_args,
    Rotate=rotate_level_to_args(MAX_LEVEL, replace_value),
    Solarize=solarize_level_to_args(MAX_LEVEL),
    Color=enhance_level_to_args(MAX_LEVEL),
    Contrast=enhance_level_to_args(MAX_LEVEL),
    Brightness=enhance_level_to_args(MAX_LEVEL),
    Sharpness=enhance_level_to_args(MAX_LEVEL),
    ShearX=shear_level_to_args(MAX_LEVEL, replace_value),
    TranslateX=translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    TranslateY=translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    Posterize=posterize_level_to_args(MAX_LEVEL),
    ShearY=shear_level_to_args(MAX_LEVEL, replace_value),
)
|
|
|
|
|
class RandomAugment:
    """RandAugment-style augmentation driven by ``func_dict`` / ``arg_dict``.

    Each call draws ``N`` operation names uniformly with replacement from
    ``augs``. Each drawn op is applied with probability 0.5 at magnitude
    ``M``: ``arg_dict[name](M)`` supplies its extra arguments and
    ``func_dict[name]`` performs it.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        """
        Args:
            N: number of operations drawn per call.
            M: magnitude passed to every operation's level-to-args mapper.
            isPIL: convert the input with ``np.array`` first. The result is
                returned as a NumPy array either way.
            augs: operation names to draw from. ``None`` or empty selects
                every key of ``arg_dict``. The sequence is copied.
        """
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # Copy (and avoid a shared mutable default) so later changes to the
        # caller's list cannot alter sampling.
        self.augs = list(augs) if augs else list(arg_dict.keys())

    def get_random_ops(self):
        """Return ``N`` ``(name, apply_probability, magnitude)`` triples."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        for name, prob, level in self.get_random_ops():
            # Skip this op with probability 1 - prob.
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
|
|
|
|
|
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic'):
    """Build the image preprocessing pipeline.

    Training: random resized crop (scale 0.5-1.0), horizontal flip and,
    optionally, ``RandomAugment``. Evaluation: resize to
    ``input_size`` x ``input_size``. Both pipelines end with ``ToTensor`` and
    Inception mean/std normalisation.

    Args:
        is_train: build the training (augmenting) pipeline.
        randaug: append RandomAugment to the training pipeline.
        input_size: output side length in pixels.
        interpolation: interpolation mode name, e.g. ``'bicubic'``. timm's
            crop takes the name itself (it converts strings internally and
            does not accept torchvision ``InterpolationMode`` values);
            torchvision's ``Resize`` gets the matching ``InterpolationMode``.
    """
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
    ]
    if is_train:
        t = [
            RandomResizedCropAndInterpolation(
                input_size, scale=(0.5, 1.0), interpolation=interpolation),
            transforms.RandomHorizontalFlip(),
        ]
        if randaug:
            t.append(
                RandomAugment(
                    2, 7, isPIL=True,
                    augs=[
                        'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
                    ]))
        return transforms.Compose(t + normalize)

    return transforms.Compose([
        transforms.Resize((input_size, input_size),
                          interpolation=transforms.InterpolationMode(interpolation)),
        *normalize,
    ])
|
|
|
|
|
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    """Download ``url`` into ``filename``, streaming ``chunk_size``-byte chunks.

    A tqdm progress bar counts the bytes written. Its total is the response's
    ``length`` when the response exposes that attribute.
    """
    # `import urllib` alone does not guarantee that the `request` submodule
    # is loaded, so import it explicitly.
    import urllib.request

    request = urllib.request.Request(url, headers={"User-Agent": "vissl"})
    with open(filename, "wb") as fh, urllib.request.urlopen(request) as response:
        with tqdm(total=getattr(response, "length", None)) as pbar:
            # read() returns b"" at end of stream, which ends the iteration
            # (a str sentinel would never match the bytes it returns).
            for chunk in iter(lambda: response.read(chunk_size), b""):
                fh.write(chunk)
                # Count the bytes actually read; the last chunk is usually short.
                pbar.update(len(chunk))
|
|