# Spaces: Running on A10G
import os | |
import numpy as np | |
import cv2 | |
import random | |
import math | |
from PIL import Image, ImageDraw, ImageFont | |
from torch.utils.data import Dataset, DataLoader | |
from dataset_util import load, show_bbox_on_image | |
# Connector phrases; one is picked at random and appended to the original
# caption, right before the per-line placeholder descriptions.
phrase_list = [
    ", content and position of the texts are ",
    ", textual material depicted in the image are ",
    ", texts that says ",
    ", captions shown in the snapshot are ",
    ", with the words of ",
    ", that reads ",
    ", the written materials on the picture: ",
    ", these texts are written on it: ",
    ", captions are ",
    ", content of the text in the graphic is ",
]
def insert_spaces(string, nSpace):
    """Return ``string`` with ``nSpace`` spaces inserted between its characters.

    No spaces are added before the first or after the last character.

    Args:
        string: Text whose characters should be spread apart.
        nSpace: Number of spaces to put between neighbouring characters.
            Zero or a negative value leaves ``string`` unchanged.

    Returns:
        The spaced-out string, or ``string`` itself when ``nSpace <= 0``.
    """
    # A negative count used to slice characters off the end of the result;
    # treat it the same as "no spacing requested".
    if nSpace <= 0:
        return string
    return (" " * nSpace).join(string)
def draw_glyph(font, text):
    """Render ``text`` as one centred line on a fixed 512x80 binary canvas.

    The font is first measured at a reference size, then rescaled so that the
    text fills at most 90% of the canvas width and height.

    Args:
        font: A PIL FreeType font; ``font_variant`` and ``getbbox`` are
            called on it.
        text: The string to render.

    Returns:
        A float64 numpy array of shape ``(80, 512, 1)``; pixels covered by
        the text are 1.0, everything else is 0.0.
    """
    g_size = 50
    W, H = (512, 80)
    new_font = font.font_variant(size=g_size)
    img = Image.new(mode='1', size=(W, H), color=0)
    draw = ImageDraw.Draw(img)

    # Measure at the reference size; clamp to avoid dividing by ~0 for
    # empty or whitespace-only text.
    left, top, right, bottom = new_font.getbbox(text)
    text_width = max(right - left, 5)
    text_height = max(bottom - top, 5)
    ratio = min(W*0.9/text_width, H*0.9/text_height)
    # A font size of 0 is rejected by Pillow, so never go below 1.
    new_font = font.font_variant(size=max(int(g_size*ratio), 1))

    # ``getsize`` / ``getoffset`` were removed in Pillow 10.  ``getbbox``
    # provides the same numbers: the legacy width was ``right - left``, the
    # legacy height included the top offset (i.e. ``bottom``), and the legacy
    # vertical offset was ``top``.
    left, top, right, bottom = new_font.getbbox(text)
    text_width = right - left
    text_height = bottom
    offset_y = top

    x = (img.width - text_width) // 2
    y = (img.height - text_height) // 2 - offset_y//2
    draw.text((x, y), text, font=new_font, fill='white')
    img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
    return img
def draw_glyph2(font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True):
    """Render ``text`` inside the minimum-area rectangle of ``polygon``.

    The text is drawn on a transparent layer at the rectangle centre, rotated
    to the rectangle's angle and pasted onto a black canvas.  Near-axis-aligned
    boxes that are taller than wide are rendered vertically, one character
    below the other.

    Args:
        font: A PIL FreeType font (``font_variant`` / ``getbbox`` are used).
        text: The string to render.
        polygon: Point set passed to ``cv2.minAreaRect`` after being
            multiplied by ``scale`` (presumably an ``(N, 2)`` numpy array --
            TODO confirm against callers).
        vertAng: Tolerance in degrees for treating the box as axis-aligned
            when deciding on vertical layout.
        scale: Multiplier applied to the polygon and to the canvas size.
        width: Unscaled canvas width in pixels.
        height: Unscaled canvas height in pixels.
        add_space: When True, horizontal text that is shorter than its box is
            padded with inter-character spaces to fill the box.

    Returns:
        A float64 numpy array of shape ``(height*scale, width*scale, 1)``
        holding the binarised glyph image (0.0 / 1.0).
    """
    enlarge_polygon = polygon*scale
    rect = cv2.minAreaRect(enlarge_polygon)
    box = cv2.boxPoints(rect)
    # ``np.int0`` was an alias of ``np.intp`` and was removed in NumPy 2.0.
    box = box.astype(np.intp)
    w, h = rect[1]
    angle = rect[2]
    if angle < -45:
        angle += 90
    angle = -angle
    if w < h:
        angle += 90

    long_side = max(w, h)
    short_side = min(w, h)

    vert = False
    box_w = 0  # only meaningful (and only read) when ``vert`` ends up True
    if (abs(angle) % 90 < vertAng or abs(90-abs(angle) % 90) % 90 < vertAng):
        box_w = max(box[:, 0]) - min(box[:, 0])
        box_h = max(box[:, 1]) - min(box[:, 1])
        if box_h >= box_w:
            vert = True
            angle = 0

    img = np.zeros((height*scale, width*scale, 3), np.uint8)
    img = Image.fromarray(img)

    # infer font size
    image4ratio = Image.new("RGB", img.size, "white")
    draw = ImageDraw.Draw(image4ratio)
    _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
    # Guard against a zero-height bbox (e.g. empty text) to avoid a
    # ZeroDivisionError; normal text is unaffected.
    text_w = short_side * (_tw / max(_th, 1))
    if text_w <= long_side:
        # add space
        if len(text) > 1 and not vert and add_space:
            # Find the first spacing that overflows the box, then step back
            # by one so the padded text still fits.
            for i in range(1, 100):
                text_space = insert_spaces(text, i)
                _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
                if short_side * (_tw2 / max(_th2, 1)) > long_side:
                    break
            text = insert_spaces(text, i-1)
        font_size = short_side*0.80
    else:
        shrink = 0.75 if vert else 0.85
        font_size = short_side / (text_w/long_side) * shrink
    # Pillow rejects a font size of 0, which tiny boxes would produce.
    new_font = font.font_variant(size=max(int(font_size), 1))

    left, top, right, bottom = new_font.getbbox(text)
    text_width = right - left
    text_height = bottom - top

    layer = Image.new('RGBA', img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(layer)
    if not vert:
        draw.text(
            (rect[0][0]-text_width//2, rect[0][1]-text_height//2-top),
            text, font=new_font, fill=(255, 255, 255, 255))
    else:
        # Stack the characters top-down, horizontally centred in the box.
        x_s = min(box[:, 0]) + box_w//2 - text_height//2
        y_s = min(box[:, 1])
        for c in text:
            draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
            _, _t, _, _b = new_font.getbbox(c)
            y_s += _b

    rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))

    # ``expand=1`` may enlarge the layer; re-centre it on the canvas.
    x_offset = int((img.width - rotated_layer.width) / 2)
    y_offset = int((img.height - rotated_layer.height) / 2)
    img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
    img = np.expand_dims(np.array(img.convert('1')), axis=2).astype(np.float64)
    return img
def get_caption_pos(ori_caption, pos_idxs, prob=1.0, place_holder='*'):
    """Append a randomly phrased description of the text lines to a caption.

    A random connector from the module-level ``phrase_list`` is appended to
    ``ori_caption``, followed by one ``place_holder`` per entry of
    ``pos_idxs``.  With probability ``prob`` a placeholder whose index is a
    known grid position (0..8) is additionally decorated with a random
    location phrase such as ``" located at top left"``.

    Args:
        ori_caption: The original image caption.
        pos_idxs: One position index per text line; 0..8 address a 3x3 grid
            (row-major, 0 = top left), any other value (e.g. -1) means the
            position is unknown.
        prob: Probability of describing the position of each line.
        place_holder: Token that stands in for each text line.

    Returns:
        The extended caption, terminated by ``'.'``.
    """
    idx2pos = {
        0: " top left",
        1: " top",
        2: " top right",
        3: " left",
        4: random.choice([" middle", " center"]),
        5: " right",
        6: " bottom left",
        7: " bottom",
        8: " bottom right"
    }
    new_caption = ori_caption + random.choice(phrase_list)
    entries = []
    for pos_idx in pos_idxs:
        # ``random.random()`` is drawn first for every line so the random
        # stream does not depend on the position values.
        # Membership (instead of ``> 0``) keeps index 0 (" top left")
        # reachable and avoids a KeyError for indices outside the grid.
        if random.random() < prob and pos_idx in idx2pos:
            entries.append(
                place_holder
                + random.choice([' located', ' placed', ' positioned', ''])
                + random.choice([' at', ' in', ' on'])
                + idx2pos[pos_idx]
            )
        else:
            entries.append(place_holder + ' ')
    new_caption += ', '.join(entries) + '.'
    return new_caption
def generate_random_rectangles(w, h, box_num):
    """Generate ``box_num`` randomly placed, randomly rotated rectangles.

    Each rectangle has its un-rotated top-left corner drawn uniformly from
    ``[0, w] x [0, h]``, a width in ``[16, 256]``, a height in ``[16, 96]``
    and is rotated about its own centre by an angle in ``[-45, 45]`` degrees.

    Args:
        w: Width of the area in which top-left corners are sampled.
        h: Height of the area in which top-left corners are sampled.
        box_num: Number of rectangles to generate.

    Returns:
        A list of ``box_num`` rectangles, each a 4-tuple of ``(x, y)`` integer
        corner points in the order top-left, top-right, bottom-right,
        bottom-left (before rotation).
    """
    rectangles = []
    for _ in range(box_num):
        x = random.randint(0, w)
        y = random.randint(0, h)
        # Use dedicated names for the box size: re-using ``w``/``h`` would
        # shrink the sampling area for every rectangle after the first.
        box_w = random.randint(16, 256)
        box_h = random.randint(16, 96)
        angle = random.randint(-45, 45)
        corners = (
            (x, y),
            (x + box_w, y),
            (x + box_w, y + box_h),
            (x, y + box_h),
        )
        center = ((x + x + box_w) / 2, (y + y + box_h) / 2)
        rectangles.append(tuple(rotate_point(p, center, angle) for p in corners))
    return rectangles
def rotate_point(point, center, angle):
    """Rotate ``point`` about ``center`` by ``angle`` degrees.

    Args:
        point: ``(x, y)`` coordinates of the point to rotate.
        center: ``(x, y)`` coordinates of the rotation centre.
        angle: Rotation angle in degrees.

    Returns:
        The rotated point as a tuple of two ints (truncated toward zero).
    """
    theta = math.radians(angle)
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    # Work relative to the centre, rotate, then translate back.
    dx = point[0] - center[0]
    dy = point[1] - center[1]
    rot_x = dx * cos_t - dy * sin_t + center[0]
    rot_y = dx * sin_t + dy * cos_t + center[1]
    return int(rot_x), int(rot_y)
class T3DataSet(Dataset):
    """Dataset of images with text-line annotations, read from JSON files.

    Each JSON file is read through ``load`` and is expected to provide
    ``data_root`` and ``data_list``.  Every entry of ``data_list`` supplies
    ``img_name`` and optionally ``caption``, ``wm_score`` and
    ``annotations``; each annotation carries ``polygon``, ``text``,
    ``language`` and optionally ``valid`` and ``pos``.

    ``__getitem__`` returns a dict with the normalised image, an extended
    caption, per-line glyph renderings, per-line position masks, a mask of
    regions to ignore, the combined hint mask and a masked image.
    """

    def __init__(
            self,
            json_path,
            max_lines=5,
            max_chars=20,
            place_holder='*',
            font_path='./font/Arial_Unicode.ttf',
            caption_pos_prob=1.0,
            mask_pos_prob=1.0,
            mask_img_prob=0.5,
            for_show=False,
            using_dlc=False,
            glyph_scale=1,
            percent=1.0,
            debug=False,
            wm_thresh=1.0,
            ):
        """Load all annotation files and store the sampling configuration.

        Args:
            json_path: A JSON path or a list of JSON paths.
            max_lines: Maximum number of text lines kept per image.
            max_chars: Maximum number of characters kept per text line.
            place_holder: Token representing a text line in the caption.
            font_path: TrueType font used for glyph rendering.
            caption_pos_prob: Probability of describing a line's position
                in the caption.
            mask_pos_prob: Probability of filling a line's position mask.
            mask_img_prob: Probability of producing a masked image instead
                of an all-zero one.
            for_show: When True, items keep ``polygons`` / ``img_name`` and
                are returned without padding.
            using_dlc: When True, ``/data/vdb`` in ``data_root`` is rewritten
                to ``/mnt/data``.
            glyph_scale: Scale factor of the rendered glyph images.
            percent: Fraction of each JSON file's images to load.
            debug: When True, ``__getitem__`` ignores the requested index
                and pops indices from a fixed list of 100 items.
            wm_thresh: Images whose ``wm_score`` exceeds this are skipped.
        """
        assert isinstance(json_path, (str, list))
        if isinstance(json_path, str):
            json_path = [json_path]
        data_list = []
        self.using_dlc = using_dlc
        self.max_lines = max_lines
        self.max_chars = max_chars
        self.place_holder = place_holder
        self.font = ImageFont.truetype(font_path, size=60)
        # NOTE: the attribute name is misspelled ("porb"); it is kept as-is
        # because external code may already read or set it.
        self.caption_pos_porb = caption_pos_prob
        self.mask_pos_prob = mask_pos_prob
        self.mask_img_prob = mask_img_prob
        self.for_show = for_show
        self.glyph_scale = glyph_scale
        self.wm_thresh = wm_thresh
        for jp in json_path:
            data_list += self.load_data(jp, percent)
        self.data_list = data_list
        print(f'All dataset loaded, imgs={len(self.data_list)}')
        self.debug = debug
        if self.debug:
            self.tmp_items = [i for i in range(100)]

    def load_data(self, json_path, percent):
        """Read one annotation file and return its per-image info dicts.

        Args:
            json_path: Path handed to ``load``.
            percent: Fraction of the file's ``data_list`` to keep; loading
                stops once ``len(data_list) * percent`` images are collected.

        Returns:
            A list of dicts with ``img_path`` and ``caption`` and, when the
            entry has ``annotations``, also ``polygons``,
            ``invalid_polygons``, ``texts``, ``language`` and ``pos``.
        """
        content = load(json_path)
        d = []
        count = 0
        wm_skip = 0
        max_img = len(content['data_list']) * percent
        for gt in content['data_list']:
            # ``>=`` (not ``>``): stop as soon as the quota is reached so
            # that ``percent < 1`` does not load one image too many.
            if len(d) >= max_img:
                break
            if 'wm_score' in gt and gt['wm_score'] > self.wm_thresh:  # wm_score > thresh will be skiped as an img with watermark
                wm_skip += 1
                continue
            data_root = content['data_root']
            if self.using_dlc:
                data_root = data_root.replace('/data/vdb', '/mnt/data', 1)
            img_path = os.path.join(data_root, gt['img_name'])
            info = {}
            info['img_path'] = img_path
            info['caption'] = gt['caption'] if 'caption' in gt else ''
            # The placeholder is reserved for text lines, so strip it from
            # the raw caption.
            if self.place_holder in info['caption']:
                count += 1
                info['caption'] = info['caption'].replace(self.place_holder, " ")
            if 'annotations' in gt:
                polygons = []
                invalid_polygons = []
                texts = []
                languages = []
                pos = []
                for annotation in gt['annotations']:
                    if len(annotation['polygon']) == 0:
                        continue
                    if 'valid' in annotation and annotation['valid'] is False:
                        invalid_polygons.append(annotation['polygon'])
                        continue
                    polygons.append(annotation['polygon'])
                    texts.append(annotation['text'])
                    languages.append(annotation['language'])
                    # Always append so ``pos`` stays index-aligned with
                    # ``texts``; -1 marks an unknown position.
                    pos.append(annotation.get('pos', -1))
                info['polygons'] = [np.array(i) for i in polygons]
                info['invalid_polygons'] = [np.array(i) for i in invalid_polygons]
                info['texts'] = texts
                info['language'] = languages
                info['pos'] = pos
            d.append(info)
        print(f'{json_path} loaded, imgs={len(d)}, wm_skip={wm_skip}')
        if count > 0:
            print(f"Found {count} image's caption contain placeholder: {self.place_holder}, change to ' '...")
        return d

    def __getitem__(self, item):
        """Build the training sample for index ``item``.

        Returns:
            A dict with keys ``img``, ``caption``, ``glyphs``, ``gly_line``,
            ``positions``, ``texts``, ``language``, ``inv_mask``, ``hint``
            and ``masked_img``.  With ``for_show`` the dict additionally has
            ``img_name`` (and ``polygons`` when the image has text) and is
            not padded; otherwise it has ``n_lines`` and the per-line lists
            are padded to ``max_lines`` entries.
        """
        item_dict = {}
        if self.debug:  # sample fixed items
            item = self.tmp_items.pop()
            print(f'item = {item}')
        cur_item = self.data_list[item]
        # img
        target = np.array(Image.open(cur_item['img_path']).convert('RGB'))
        if target.shape[0] != 512 or target.shape[1] != 512:
            target = cv2.resize(target, (512, 512))
        # Map pixel values from [0, 255] to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0
        item_dict['img'] = target
        # caption
        item_dict['caption'] = cur_item['caption']
        item_dict['glyphs'] = []
        item_dict['gly_line'] = []
        item_dict['positions'] = []
        item_dict['texts'] = []
        item_dict['language'] = []
        item_dict['inv_mask'] = []
        texts = cur_item.get('texts', [])
        unsel_idxs = []
        if len(texts) > 0:
            idxs = list(range(len(texts)))
            if len(texts) > self.max_lines:
                # Too many lines: keep a random subset, ignore the rest.
                sel_idxs = random.sample(idxs, self.max_lines)
                selected = set(sel_idxs)
                unsel_idxs = [i for i in idxs if i not in selected]
            else:
                sel_idxs = idxs
            if len(cur_item['pos']) > 0:
                pos_idxs = [cur_item['pos'][i] for i in sel_idxs]
            else:
                pos_idxs = [-1 for _ in sel_idxs]
            item_dict['caption'] = get_caption_pos(item_dict['caption'], pos_idxs, self.caption_pos_porb, self.place_holder)
            item_dict['polygons'] = [cur_item['polygons'][i] for i in sel_idxs]
            item_dict['texts'] = [cur_item['texts'][i][:self.max_chars] for i in sel_idxs]
            item_dict['language'] = [cur_item['language'][i] for i in sel_idxs]
            # glyphs
            for idx, text in enumerate(item_dict['texts']):
                gly_line = draw_glyph(self.font, text)
                glyphs = draw_glyph2(self.font, text, item_dict['polygons'][idx], scale=self.glyph_scale)
                item_dict['glyphs'] += [glyphs]
                item_dict['gly_line'] += [gly_line]
            # mask_pos
            for polygon in item_dict['polygons']:
                item_dict['positions'] += [self.draw_pos(polygon, self.mask_pos_prob)]
        # inv_mask
        # Copy the stored list: extending it in place would permanently add
        # this call's unselected polygons to the dataset entry, so the
        # ignore mask would keep growing every time the item is fetched.
        invalid_polygons = list(cur_item.get('invalid_polygons', []))
        if len(texts) > 0:
            invalid_polygons += [cur_item['polygons'][i] for i in unsel_idxs]
        item_dict['inv_mask'] = self.draw_inv_mask(invalid_polygons)
        item_dict['hint'] = self.get_hint(item_dict['positions'])
        if random.random() < self.mask_img_prob:
            # randomly generate 0~3 masks
            box_num = random.randint(0, 3)
            boxes = generate_random_rectangles(512, 512, box_num)
            boxes = np.array(boxes)
            pos_list = item_dict['positions'].copy()
            for i in range(box_num):
                pos_list += [self.draw_pos(boxes[i], self.mask_pos_prob)]
            mask = self.get_hint(pos_list)
            masked_img = target*(1-mask)
        else:
            masked_img = np.zeros_like(target)
        item_dict['masked_img'] = masked_img

        if self.for_show:
            item_dict['img_name'] = os.path.split(cur_item['img_path'])[-1]
            return item_dict
        if len(texts) > 0:
            del item_dict['polygons']
        # padding
        n_lines = min(len(texts), self.max_lines)
        item_dict['n_lines'] = n_lines
        n_pad = self.max_lines - n_lines
        if n_pad > 0:
            item_dict['glyphs'] += [np.zeros((512*self.glyph_scale, 512*self.glyph_scale, 1))] * n_pad
            item_dict['gly_line'] += [np.zeros((80, 512, 1))] * n_pad
            item_dict['positions'] += [np.zeros((512, 512, 1))] * n_pad
            item_dict['texts'] += [' '] * n_pad
            item_dict['language'] += [' '] * n_pad

        return item_dict

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.data_list)

    def draw_inv_mask(self, polygons):
        """Return a ``(512, 512, 1)`` mask that is 1.0 inside any polygon.

        Args:
            polygons: Iterable of numpy arrays; each is reshaped to
                ``(-1, 1, 2)`` and filled.
        """
        img = np.zeros((512, 512))
        for p in polygons:
            pts = p.reshape((-1, 1, 2))
            cv2.fillPoly(img, [pts], color=255)
        img = img[..., None]
        return img/255.

    def draw_pos(self, ploygon, prob=1.0):
        """Return a ``(512, 512, 1)`` position mask for one polygon.

        With probability ``prob`` the polygon is filled and then randomly
        left untouched, dilated or eroded with a 3x3 kernel; otherwise the
        mask stays empty.

        Args:
            ploygon: Numpy array of polygon points (name kept as spelled in
                the public signature).
            prob: Probability of drawing the polygon at all.

        Returns:
            A numpy array of shape ``(512, 512, 1)`` with values in [0, 1].
        """
        img = np.zeros((512, 512))
        rect = cv2.minAreaRect(ploygon)
        w, h = rect[1]
        # Boxes with a side under 20 px are never eroded.
        small = w < 20 or h < 20
        if random.random() < prob:
            pts = ploygon.reshape((-1, 1, 2))
            cv2.fillPoly(img, [pts], color=255)
            # 10% dilate / 10% erode / 5% dilatex2 5% erodex2
            # NOTE: for a small box a draw in [0.8, 0.9) skips the erode
            # branch and falls through to the double-dilate branch below.
            random_value = random.random()
            kernel = np.ones((3, 3), dtype=np.uint8)
            if random_value < 0.7:
                pass
            elif random_value < 0.8:
                img = cv2.dilate(img.astype(np.uint8), kernel, iterations=1)
            elif random_value < 0.9 and not small:
                img = cv2.erode(img.astype(np.uint8), kernel, iterations=1)
            elif random_value < 0.95:
                img = cv2.dilate(img.astype(np.uint8), kernel, iterations=2)
            elif random_value < 1.0 and not small:
                img = cv2.erode(img.astype(np.uint8), kernel, iterations=2)
        img = img[..., None]
        return img/255.

    def get_hint(self, positions):
        """Merge position masks into one mask clipped to [0, 1].

        Args:
            positions: List of equally shaped mask arrays.

        Returns:
            The element-wise sum of ``positions`` clipped to [0, 1], or an
            all-zero ``(512, 512, 1)`` array when ``positions`` is empty.
        """
        if len(positions) == 0:
            return np.zeros((512, 512, 1))
        return np.sum(positions, axis=0).clip(0, 1)
if __name__ == '__main__':
    # Run this script to show details of your dataset, such as ocr
    # annotations, glyphs, prompts, etc.  Results for the first
    # ``show_count`` items are written to ``show_imgs_dir``.
    from tqdm import tqdm
    from matplotlib import pyplot as plt
    import shutil

    show_imgs_dir = 'show_results'
    show_count = 50
    # Start from an empty output directory on every run.
    if os.path.exists(show_imgs_dir):
        shutil.rmtree(show_imgs_dir)
    os.makedirs(show_imgs_dir)
    plt.rcParams['axes.unicode_minus'] = False
    json_paths = [
        '/path/of/your/dataset/data1.json',
        '/path/of/your/dataset/data2.json',
        # ...
    ]
    dataset = T3DataSet(json_paths, for_show=True, max_lines=20, glyph_scale=2, mask_img_prob=1.0, caption_pos_prob=0.0)
    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    pbar = tqdm(total=show_count)
    for i, data in enumerate(train_loader):
        if i == show_count:
            break
        # Undo the [-1, 1] normalisation; flip RGB -> BGR for cv2.imwrite.
        img = ((data['img'][0].numpy() + 1.0) / 2.0 * 255).astype(np.uint8)
        masked_img = ((data['masked_img'][0].numpy() + 1.0) / 2.0 * 255)[..., ::-1].astype(np.uint8)
        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_masked.jpg'), masked_img)
        if 'texts' in data and len(data['texts']) > 0:
            texts = [x[0] for x in data['texts']]
            img = show_bbox_on_image(Image.fromarray(img), data['polygons'], texts)
        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}.jpg'), np.array(img)[..., ::-1])
        # Captions may contain non-ASCII text, so fix the encoding instead
        # of relying on the platform default.
        with open(os.path.join(show_imgs_dir, f'plots_{i}.txt'), 'w', encoding='utf-8') as fout:
            fout.writelines([data['caption'][0]])
        all_glyphs = []
        for k, glyphs in enumerate(data['glyphs']):
            glyph_img = glyphs[0].numpy().astype(np.int32)*255
            cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_glyph_{k}.jpg'), glyph_img)
            all_glyphs += [glyph_img]
        # Summing an empty list yields a scalar that cannot be written as an
        # image, so skip the overview for images without text.
        if all_glyphs:
            cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_allglyphs.jpg'), np.sum(all_glyphs, axis=0))
        for k, gly_line in enumerate(data['gly_line']):
            cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_gly_line_{k}.jpg'), gly_line[0].numpy().astype(np.int32)*255)
        for k, position in enumerate(data['positions']):
            if position is not None:
                cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_pos_{k}.jpg'), position[0].numpy().astype(np.int32)*255)
        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_hint.jpg'), data['hint'][0].numpy().astype(np.int32)*255)
        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_inv_mask.jpg'), np.array(img)[..., ::-1]*(1-data['inv_mask'][0].numpy().astype(np.int32)))
        pbar.update(1)
    pbar.close()