Spaces:
Running
on
A10G
Running
on
A10G
''' | |
Copyright (c) Alibaba, Inc. and its affiliates. | |
''' | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from functools import partial | |
from ldm.modules.diffusionmodules.util import conv_nd, linear | |
def get_clip_token_for_string(tokenizer, string):
    """Return the single CLIP token id that `string` is tokenized to.

    The tokenizer is invoked with truncation and padding to 77 positions and
    asked for PyTorch tensors. The string is accepted only when, after
    subtracting 49407 from every id, exactly two entries are non-zero
    (presumably the start token plus one real token, with 49407 acting as
    the end/padding id -- confirm against the tokenizer in use).

    Args:
        tokenizer: callable returning a mapping with an "input_ids" tensor.
        string: text expected to map to exactly one token.

    Returns:
        The 0-d tensor found at row 0, column 1 of "input_ids".

    Raises:
        AssertionError: if the non-padding count is not exactly two.
    """
    encoded = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"]
    n_non_pad = torch.count_nonzero(token_ids - 49407)
    assert n_non_pad == 2, f"String '{string}' maps to more than a single token. Please use another string"
    return token_ids[0, 1]
def get_bert_token_for_string(tokenizer, string):
    """Return the single BERT token id that `string` is tokenized to.

    Args:
        tokenizer: callable mapping a string to a 2-D tensor of token ids.
        string: text expected to map to exactly one token.

    Returns:
        The 0-d tensor found at row 0, column 1 of the tokenizer output.

    Raises:
        AssertionError: if the output does not contain exactly three
            non-zero ids (presumably two special tokens around one real
            token, with 0 as padding -- confirm against the tokenizer).
    """
    token_ids = tokenizer(string)
    n_nonzero = torch.count_nonzero(token_ids)
    assert n_nonzero == 3, f"String '{string}' maps to more than a single token. Please use another string"
    return token_ids[0, 1]
def get_clip_vision_emb(encoder, processor, img):
    """Embed images with a vision encoder and return its `image_embeds`.

    The input is tiled three times along dim 1 and multiplied by 255 before
    being handed to `processor` (assumes `img` is a single-channel batch
    with values in [0, 1] -- TODO confirm with callers).

    Args:
        encoder: callable taking the processor output as keyword arguments
            and returning an object with an `image_embeds` attribute.
        processor: callable accepting `images=` and `return_tensors=` and
            returning a mapping that contains 'pixel_values'.
        img: 4-D image tensor.

    Returns:
        `encoder(...).image_embeds`.
    """
    rgb = img.repeat(1, 3, 1, 1) * 255
    model_inputs = processor(images=rgb, return_tensors="pt")
    # The processor output is moved to the device of the input image.
    pixel_values = model_inputs['pixel_values']
    model_inputs['pixel_values'] = pixel_values.to(img.device)
    return encoder(**model_inputs).image_embeds
def get_recog_emb(encoder, img_list):
    """Run a recognizer over a list of images and return its second output.

    Each image is tiled three times along dim 1, multiplied by 255 and
    stripped of its leading dimension before being passed on.

    Args:
        encoder: object exposing `predictor.eval()` and
            `pred_imglist(images, show_debug=...)`, the latter returning a
            pair.
        img_list: iterable of 4-D image tensors (assumed to have a leading
            dimension of 1 -- only index 0 is kept).

    Returns:
        The second element of the pair returned by `encoder.pred_imglist`.
    """
    prepared = []
    for img in img_list:
        rgb = img.repeat(1, 3, 1, 1) * 255
        prepared.append(rgb[0])
    # The predictor is switched to eval mode on every call.
    encoder.predictor.eval()
    _first, neck = encoder.pred_imglist(prepared, show_debug=False)
    return neck
def pad_H(x):
    """Pad the height of a 4-D tensor so that it equals the width.

    The missing rows (width - height) are split between top and bottom,
    with the extra row going to the bottom when the difference is odd. The
    width is left unchanged.

    Args:
        x: tensor whose shape unpacks into four dimensions, the last two
            being height and width.

    Returns:
        The result of `F.pad` with (left, right, top, bottom) padding of
        (0, 0, top, bottom).
    """
    _, _, height, width = x.shape
    gap = width - height
    top = gap // 2
    bottom = gap - top
    return F.pad(x, (0, 0, top, bottom))
class EncodeNet(nn.Module):
    """Small convolutional encoder producing one vector per input image.

    The network is an input convolution to 16 channels, four stride-2
    convolutions that each double the channel count, an output convolution
    to `out_channels`, and a global average pool. SiLU follows every
    convolution. Layers are built through `conv_nd(2, ...)`.
    """

    def __init__(self, in_channels, out_channels):
        """Build the layers.

        Args:
            in_channels: channel count of the input images.
            out_channels: length of the vector returned per image.
        """
        super().__init__()
        base_width = 16
        n_down = 4  # number of stride-2 (downsampling) convolutions
        # Input width of every downsampling stage: 16, 32, 64, 128.
        stage_widths = [base_width * 2 ** k for k in range(n_down)]

        self.conv1 = conv_nd(2, in_channels, base_width, 3, padding=1)
        self.conv_list = nn.ModuleList(
            [conv_nd(2, w, w * 2, 3, padding=1, stride=2) for w in stage_widths]
        )
        self.conv2 = conv_nd(2, stage_widths[-1] * 2, out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        """Encode `x` and return a tensor flattened to (x.size(0), -1)."""
        feats = self.act(self.conv1(x))
        for down in self.conv_list:
            feats = self.act(down(feats))
        feats = self.act(self.conv2(feats))
        pooled = self.avgpool(feats)
        return pooled.view(pooled.size(0), -1)
class EmbeddingManager(nn.Module):
    """Swap placeholder-token embeddings for embeddings of text-line images.

    `encode_text` encodes every text line of a batch into one vector and
    caches the result in `self.text_embs_all`; `forward` then writes those
    vectors into the positions of `embedded_text` whose token equals the
    placeholder token.

    For `emb_type == 'ocr'` the recognizer is read from `self.recog`, which
    is not set by this class and must be assigned by the owner before
    `encode_text` is called.
    """

    def __init__(
            self,
            embedder,
            valid=True,
            glyph_channels=20,
            position_channels=1,
            placeholder_string='*',
            add_pos=False,
            emb_type='ocr',
            **kwargs
    ):
        """Build the encoders that match `embedder` and `emb_type`.

        Args:
            embedder: text encoder. If it has a `tokenizer` attribute it is
                treated as Stable Diffusion's CLIP encoder (token dim 768),
                otherwise as LDM's BERT encoder with a `tknz_fn` attribute
                (token dim 1280).
            valid: accepted but not used by this class.
            glyph_channels: input channels of the 'conv' glyph encoder.
            position_channels: input channels of the position encoder.
            placeholder_string: string whose single token marks where line
                embeddings are inserted.
            add_pos: if True, add a position embedding to every line
                embedding.
            emb_type: 'ocr', 'vit' or 'conv'.
            **kwargs: ignored.
        """
        super().__init__()
        # Defined on every path: encode_text tests it against None, and a
        # missing attribute on an nn.Module raises AttributeError instead.
        self.get_recog_emb = None
        if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
            token_dim = 768
            if hasattr(embedder, 'vit'):
                assert emb_type == 'vit'
                self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
        else:  # using LDM's BERT encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type

        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == 'ocr':
            # Flattened recognizer feature (40*64 values) -> token dimension.
            self.proj = linear(40*64, token_dim)
        if emb_type == 'conv':
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

        self.placeholder_token = get_token_for_string(placeholder_string)

    def encode_text(self, text_info):
        """Encode all text lines of a batch and cache them per sample.

        Args:
            text_info: mapping with
                'n_lines': number of text lines for every sample,
                'gly_line': per-line glyph tensors, indexed [line][sample],
                'positions': per-line position tensors, indexed
                    [line][sample]; only read when `add_pos` is True.

        Side effects:
            Sets `self.text_embs_all` to a list with one entry per sample,
            each a list of `n_lines` tensors with a leading dimension of 1.

        Raises:
            ValueError: if `emb_type` is not 'ocr', 'vit' or 'conv' and
                there is at least one line to encode.
        """
        if self.get_recog_emb is None and self.emb_type == 'ocr':
            self.get_recog_emb = partial(get_recog_emb, self.recog)

        line_counts = [int(n) for n in text_info['n_lines']]

        # Flatten (sample, line) pairs sample-major so the cached embeddings
        # can be split back per sample below.
        gline_list = []
        pos_list = []
        for i, n_lines in enumerate(line_counts):  # sample index in a batch
            for j in range(n_lines):  # line
                gline_list.append(text_info['gly_line'][j][i:i+1])
                if self.add_pos:
                    pos_list.append(text_info['positions'][j][i:i+1])

        enc_glyph = None
        if gline_list:
            if self.emb_type == 'ocr':
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == 'vit':
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == 'conv':
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            else:
                raise ValueError(f"Unsupported emb_type: {self.emb_type!r}")
            if self.add_pos:
                # NOTE: the position encoder must see the position maps; it
                # was previously fed the glyph images, leaving pos_list
                # unused.
                enc_pos = self.position_encoder(torch.cat(pos_list, dim=0))
                enc_glyph = enc_glyph + enc_pos

        self.text_embs_all = []
        n_idx = 0
        for n_lines in line_counts:  # sample index in a batch
            self.text_embs_all.append(
                [enc_glyph[k:k+1] for k in range(n_idx, n_idx + n_lines)]
            )
            n_idx += n_lines

    def forward(
            self,
            tokenized_text,
            embedded_text,
    ):
        """Write cached line embeddings over placeholder-token embeddings.

        Args:
            tokenized_text: token ids, one row per sample.
            embedded_text: token embeddings, modified in place.

        Returns:
            `embedded_text`, with placeholder positions of every sample
            replaced by that sample's line embeddings in order. If a sample
            has fewer placeholders than cached lines, the surplus lines are
            dropped. Processing stops at the first sample with placeholders
            for which no embeddings were cached.
        """
        b, device = tokenized_text.shape[0], tokenized_text.device
        placeholder = self.placeholder_token.to(device)
        for i in range(b):
            idx = tokenized_text[i] == placeholder
            n_placeholder = int(idx.sum())
            if n_placeholder == 0:
                continue
            if i >= len(self.text_embs_all):
                print('truncation for log images...')
                break
            text_emb = torch.cat(self.text_embs_all[i], dim=0)
            if n_placeholder != len(text_emb):
                print('truncation for long caption...')
            embedded_text[i][idx] = text_emb[:n_placeholder]
        return embedded_text

    def embedding_parameters(self):
        """Return an iterator over all parameters of this module."""
        return self.parameters()