File size: 6,061 Bytes
de7836d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from ldm.modules.diffusionmodules.util import conv_nd, linear


def get_clip_token_for_string(tokenizer, string):
    """Look up the single token id that `string` maps to under `tokenizer`.

    The tokenizer is called with truncation and padding to 77 entries and is
    expected to return a mapping whose "input_ids" entry is a tensor. Exactly
    two entries must differ from 49407 (presumably the end-of-text / padding
    id - confirm against the tokenizer in use); otherwise the assertion fails.

    Returns:
        The 0-dim tensor at row 0, column 1 of the input ids.
    """
    encoded = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    # Count the entries that are not equal to 49407.
    n_other = torch.count_nonzero(input_ids - 49407)
    assert n_other == 2, f"String '{string}' maps to more than a single token. Please use another string"
    return input_ids[0, 1]


def get_bert_token_for_string(tokenizer, string):
    """Look up the single token id that `string` maps to under `tokenizer`.

    `tokenizer(string)` must return a 2-D tensor with exactly three non-zero
    entries (presumably start marker, the token, end marker - confirm against
    the tokenizer in use); otherwise the assertion fails.

    Returns:
        The 0-dim tensor at row 0, column 1 of the tokenizer output.
    """
    ids = tokenizer(string)
    n_nonzero = torch.count_nonzero(ids)
    assert n_nonzero == 3, f"String '{string}' maps to more than a single token. Please use another string"
    return ids[0, 1]


def get_clip_vision_emb(encoder, processor, img):
    """Embed `img` with a vision encoder and return its `image_embeds`.

    The image is tiled three times along dim 1 and scaled by 255 before being
    handed to `processor` (assumes `img` is a single-channel batch in [0, 1]
    - TODO confirm with callers). The processor's 'pixel_values' are moved to
    the device of `img` before the encoder is called.
    """
    rgb = img.repeat(1, 3, 1, 1) * 255
    batch = processor(images=rgb, return_tensors="pt")
    # Move the processed pixels onto the same device as the input image.
    batch['pixel_values'] = batch['pixel_values'].to(img.device)
    return encoder(**batch).image_embeds


def get_recog_emb(encoder, img_list):
    """Run the recognizer over a list of images and return its second output.

    Each image is tiled three times along dim 1, scaled by 255, and its first
    batch entry is taken. `encoder.predictor` is switched to eval mode before
    `encoder.pred_imglist` is called; the first returned value is discarded
    and the second (named `preds_neck` in the original) is returned.
    """
    prepared = []
    for img in img_list:
        rgb = img.repeat(1, 3, 1, 1) * 255
        prepared.append(rgb[0])
    encoder.predictor.eval()
    outputs = encoder.pred_imglist(prepared, show_debug=False)
    _, neck_feats = outputs
    return neck_feats


def pad_H(x):
    """Pad the height (dim 2) of a 4-D tensor so it equals the width (dim 3).

    The padding amount `W - H` is split between top and bottom, with the
    bottom receiving the extra row when the difference is odd. The width is
    left untouched.
    """
    _, _, height, width = x.shape
    deficit = width - height
    top = deficit // 2
    bottom = deficit - top
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(x, (0, 0, top, bottom))


class EncodeNet(nn.Module):
    """Small convolutional encoder that reduces an image to one vector.

    Architecture, in order: a 3x3 stem convolution to 16 channels, four
    stride-2 3x3 convolutions that each double the channel count, a 3x3
    convolution to `out_channels`, and global average pooling. SiLU follows
    every convolution. The output has shape (batch, out_channels).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        base_width = 16
        n_down = 4  # number of stride-2 (downsampling) convolutions

        self.conv1 = conv_nd(2, in_channels, base_width, 3, padding=1)
        # Stage k maps base_width * 2**k channels to base_width * 2**(k + 1).
        self.conv_list = nn.ModuleList([
            conv_nd(2, base_width * 2 ** k, base_width * 2 ** (k + 1), 3, padding=1, stride=2)
            for k in range(n_down)
        ])
        self.conv2 = conv_nd(2, base_width * 2 ** n_down, out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        """Encode `x` into a (batch, out_channels) feature tensor."""
        feat = self.act(self.conv1(x))
        for down in self.conv_list:
            feat = self.act(down(feat))
        feat = self.act(self.conv2(feat))
        pooled = self.avgpool(feat)
        # Collapse the trailing 1x1 spatial dims left by the pooling.
        return pooled.view(pooled.size(0), -1)


class EmbeddingManager(nn.Module):
    """Replace placeholder-token embeddings with per-line glyph embeddings.

    Usage is two-step: `encode_text` turns the glyph line images of a batch
    into embeddings and caches them on `self.text_embs_all`; `forward` then
    writes those cached embeddings into the positions of `embedded_text`
    whose token id equals the placeholder token.

    Args:
        embedder: Text encoder wrapper. If it exposes `tokenizer` it is
            treated as a CLIP encoder (token dim 768); otherwise `tknz_fn`
            is used and it is treated as a BERT encoder (token dim 1280).
            If it also exposes `vit` (CLIP branch only), `emb_type` must be
            'vit' and `embedder.processor` is used for preprocessing.
        valid: Accepted for config compatibility; not used in this class.
        glyph_channels: Input channels of the glyph encoder ('conv' mode).
        position_channels: Input channels of the position encoder.
        placeholder_string: String whose single token id marks the slots to
            be overwritten in `forward`.
        add_pos: If True, add an encoding of the position maps to each glyph
            embedding.
        emb_type: One of 'ocr', 'vit' or 'conv'.
        **kwargs: Ignored.

    Note:
        In 'ocr' mode the recognizer is not built here; `self.recog` must be
        assigned by the owner of this module before `encode_text` is called.
    """

    def __init__(
            self,
            embedder,
            valid=True,
            glyph_channels=20,
            position_channels=1,
            placeholder_string='*',
            add_pos=False,
            emb_type='ocr',
            **kwargs
    ):
        super().__init__()
        # Define the lazily-built OCR hook up front so that encode_text can
        # test it on every embedder branch (it used to be set only for CLIP,
        # which made the BERT path fail with AttributeError).
        self.get_recog_emb = None
        if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
            token_dim = 768
            if hasattr(embedder, 'vit'):
                assert emb_type == 'vit'
                self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
        else:  # using LDM's BERT encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type

        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == 'ocr':
            # Projects the flattened recognizer features (40*64 values per
            # line) to the token dimension.
            self.proj = linear(40*64, token_dim)
        if emb_type == 'conv':
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

        # Kept as a plain attribute (not a buffer) so the state_dict layout
        # is unchanged; it is moved to the right device in forward().
        self.placeholder_token = get_token_for_string(placeholder_string)

    def encode_text(self, text_info):
        """Encode every glyph line of the batch and cache the result.

        Args:
            text_info: Mapping with
                'n_lines'   - per-sample number of text lines,
                'gly_line'  - per-line batched glyph images, indexed
                              [line][sample],
                'positions' - per-line batched position maps, indexed the
                              same way (read only when `add_pos` is True).

        Side effects:
            Sets `self.text_embs_all` to a list (one entry per sample) of
            lists (one entry per line) of embeddings with a leading dim of 1.

        Raises:
            AttributeError: in 'ocr' mode when `self.recog` was never set.
            ValueError: when `emb_type` is not 'ocr', 'vit' or 'conv' and
                there is at least one line to encode.
        """
        if self.emb_type == 'ocr' and self.get_recog_emb is None:
            recog = getattr(self, 'recog', None)
            if recog is None:
                raise AttributeError(
                    "EmbeddingManager.recog must be assigned before calling "
                    "encode_text() with emb_type='ocr'"
                )
            self.get_recog_emb = partial(get_recog_emb, recog)

        n_lines_per_sample = [int(n) for n in text_info['n_lines']]

        # Flatten (sample, line) pairs in sample-major order; the second loop
        # below consumes the encoded rows in the same order.
        gline_list = []
        pos_list = []
        for i, n_lines in enumerate(n_lines_per_sample):  # sample index in a batch
            for j in range(n_lines):  # line
                gline_list.append(text_info['gly_line'][j][i:i+1])
                if self.add_pos:
                    pos_list.append(text_info['positions'][j][i:i+1])

        enc_glyph = None
        if gline_list:
            if self.emb_type == 'ocr':
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == 'vit':
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == 'conv':
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            else:
                raise ValueError(f"Unsupported emb_type: {self.emb_type!r}")
            if self.add_pos:
                # Encode the position maps. The previous code passed the
                # glyph lines here, leaving pos_list unused.
                enc_pos = self.position_encoder(torch.cat(pos_list, dim=0))
                enc_glyph = enc_glyph + enc_pos

        # Regroup the flat rows of enc_glyph back into per-sample lists.
        self.text_embs_all = []
        n_idx = 0
        for n_lines in n_lines_per_sample:  # sample index in a batch
            self.text_embs_all.append(
                [enc_glyph[k:k+1] for k in range(n_idx, n_idx + n_lines)]
            )
            n_idx += n_lines

    def forward(
            self,
            tokenized_text,
            embedded_text,
    ):
        """Write the cached glyph embeddings into the placeholder slots.

        Args:
            tokenized_text: Token ids; the first dim is the batch.
            embedded_text: Token embeddings aligned with `tokenized_text`;
                modified in place.

        Returns:
            `embedded_text`, with the placeholder positions of each sample
            overwritten by that sample's line embeddings in order. If a
            sample has fewer placeholders than cached lines, the surplus
            lines are dropped. Processing stops at the first sample that has
            placeholders but no cached embeddings.
        """
        b, device = tokenized_text.shape[0], tokenized_text.device
        # Move the placeholder id once instead of on every iteration.
        placeholder = self.placeholder_token.to(device)
        for i in range(b):
            idx = tokenized_text[i] == placeholder
            # One tensor reduction instead of repeated Python-level sum().
            n_slots = int(idx.sum())
            if n_slots == 0:
                continue
            if i >= len(self.text_embs_all):
                print('truncation for log images...')
                break
            text_emb = torch.cat(self.text_embs_all[i], dim=0)
            if n_slots != len(text_emb):
                print('truncation for long caption...')
            embedded_text[i][idx] = text_emb[:n_slots]
        return embedded_text

    def embedding_parameters(self):
        """Return an iterator over all parameters of this module."""
        return self.parameters()