EmaadKhwaja commited on
Commit
5d2263b
1 Parent(s): 86d2765

file upload

.DS_Store ADDED
Binary file (6.15 kB).
 
app.py CHANGED
@@ -1,7 +1,3 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
-
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
1
  import gradio as gr
2
 
3
+ demo = gr.load("HuangLab/CELL-E_2_HPA_Finetuned_480", src="models")
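The new app.py delegates the UI entirely to gr.load, which builds the demo from the hosted model on the Hub. A minimal sketch of the resulting entry point, assuming a recent Gradio release (Spaces serves `demo` automatically; the launch() call is only needed for a local run and is not part of this commit):

import gradio as gr

# Build the demo from the hosted model card on the Hub.
demo = gr.load("HuangLab/CELL-E_2_HPA_Finetuned_480", src="models")

if __name__ == "__main__":
    # Local runs only; Hugging Face Spaces launches `demo` on its own.
    demo.launch()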
celle/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from celle.celle import CELLE
2
+ from celle.vae import VQGanVAE
3
+
4
+ __version__ = "2.0.0"
celle/attention.py ADDED
@@ -0,0 +1,253 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+
6
+ from rotary_embedding_torch import apply_rotary_emb
7
+ from celle.utils import exists, default, max_neg_value
8
+
9
+
10
+ # helpers
11
+ def stable_softmax(t, dim=-1, alpha=32**2):
12
+ t = t / alpha
13
+ t = t - torch.amax(t, dim=dim, keepdim=True).detach()
14
+ return (t * alpha).softmax(dim=dim)
15
+
16
+
17
+ def apply_pos_emb(pos_emb, qkv):
18
+ n = qkv[0].shape[-2]
19
+ pos_emb = pos_emb[..., :n, :]
20
+ return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
21
+
22
+
23
+ # classes
24
+ class Attention(nn.Module):
25
+ def __init__(
26
+ self,
27
+ dim,
28
+ seq_len,
29
+ causal=False,
30
+ heads=8,
31
+ dim_head=64,
32
+ dropout=0.0,
33
+ stable=False,
34
+ static_mask=None,
35
+ ):
36
+ super().__init__()
37
+ inner_dim = dim_head * heads
38
+ self.heads = heads
39
+ self.seq_len = seq_len
40
+ self.scale = dim_head**-0.5
41
+ self.stable = stable
42
+ self.causal = causal
43
+ self.register_buffer("static_mask", static_mask, persistent=False)
44
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
45
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
46
+ self.save_attn = nn.Identity()
47
+
48
+ def forward(self, x, context_mask=None, rotary_pos_emb=None):
49
+ # x: [batch_size, seq_len, dim]
50
+ b, n, _, h = *x.shape, self.heads
51
+ device = x.device
52
+
53
+ softmax = torch.softmax if not self.stable else stable_softmax
54
+
55
+ # qkv: 3 tensors of shape [batch_size, seq_len, inner_dim]
56
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
57
+
58
+ # q,k,v: [batch_size, heads, seq_len, dim_head]
59
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
60
+
61
+ if exists(rotary_pos_emb):
62
+ q, k, v = apply_pos_emb(rotary_pos_emb[..., :, :], (q, k, v))
63
+
64
+ q *= self.scale
65
+
66
+ # dots: [batch_size, heads, seq_len_i, seq_len_j]
67
+ dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
68
+ mask_value = max_neg_value(dots)
69
+
70
+ if exists(context_mask):
71
+ # context_mask: [batch_size, 1, 1, seq_len_j]
72
+ context_mask = rearrange(context_mask, "b j -> b 1 1 j")
73
+ context_mask = F.pad(context_mask, (1, 0), value=True)
74
+
75
+ mask_value = -torch.finfo(dots.dtype).max
76
+ dots = dots.masked_fill(~context_mask, mask_value)
77
+
78
+ if self.causal:
79
+ i, j = dots.shape[-2:]
80
+ context_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
81
+ dots.masked_fill_(context_mask, mask_value)
82
+
83
+ if exists(self.static_mask):
84
+ dots.masked_fill_(~self.static_mask[:n, :n], mask_value)
85
+
86
+ # attn: [batch_size, heads, seq_len_i, seq_len_j]
87
+ attn = softmax(dots, dim=-1)
88
+ attn = self.save_attn(attn)
89
+
90
+ # out: [batch_size, heads, seq_len_i, dim_head]
91
+ out = torch.einsum("b h n j, b h j d -> b h n d", attn, v)
92
+
93
+ # out: [batch_size, seq_len_i, (heads*dim_head)]
94
+ out = rearrange(out, "b h n d -> b n (h d)")
95
+
96
+ # out: [batch_size, seq_len_i, dim]
97
+ out = self.to_out(out)
98
+
99
+ return out
100
+
101
+
102
+ # sparse attention with a convolutional pattern (as described in the DALL-E blog post); customizable kernel size and dilation
103
+
104
+
105
+ class SparseConvCausalAttention(nn.Module):
106
+ def __init__(
107
+ self,
108
+ dim,
109
+ seq_len,
110
+ image_size=32,
111
+ kernel_size=5,
112
+ dilation=1,
113
+ heads=8,
114
+ dim_head=64,
115
+ dropout=0.0,
116
+ stable=False,
117
+ **kwargs,
118
+ ):
119
+ super().__init__()
120
+ assert kernel_size % 2 == 1, "kernel size must be odd"
121
+
122
+ inner_dim = dim_head * heads
123
+ self.seq_len = seq_len
124
+ self.heads = heads
125
+ self.scale = dim_head**-0.5
126
+ self.image_size = image_size
127
+ self.kernel_size = kernel_size
128
+ self.dilation = dilation
129
+
130
+ self.stable = stable
131
+
132
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
133
+
134
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
135
+
136
+ def forward(self, x, mask=None, rotary_pos_emb=None):
137
+ b, n, _, h, img_size, kernel_size, dilation, seq_len, device = (
138
+ *x.shape,
139
+ self.heads,
140
+ self.image_size,
141
+ self.kernel_size,
142
+ self.dilation,
143
+ self.seq_len,
144
+ x.device,
145
+ )
146
+ softmax = torch.softmax if not self.stable else stable_softmax
147
+
148
+ img_seq_len = img_size**2
149
+ text_len = seq_len + 1 - img_seq_len
150
+
151
+ # padding
152
+
153
+ padding = seq_len - n + 1
154
+ mask = default(mask, lambda: torch.ones(b, text_len, device=device).bool())
155
+
156
+ x = F.pad(x, (0, 0, 0, padding), value=0)
157
+ mask = mask[:, :text_len]
158
+
159
+ # derive query / keys / values
160
+
161
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
162
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), qkv)
163
+
164
+ if exists(rotary_pos_emb):
165
+ q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
166
+
167
+ q *= self.scale
168
+
169
+ ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
170
+ lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)
171
+ )
172
+
173
+ # text attention
174
+
175
+ dots_text = einsum("b i d, b j d -> b i j", q_text, k_text)
176
+ mask_value = max_neg_value(dots_text)
177
+
178
+ i, j = dots_text.shape[-2:]
179
+ text_causal_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
180
+ dots_text.masked_fill_(text_causal_mask, mask_value)
181
+
182
+ attn_text = softmax(dots_text, dim=-1)
183
+ out_text = einsum("b i j, b j d -> b i d", attn_text, v_text)
184
+
185
+ # image attention
186
+
187
+ effective_kernel_size = (kernel_size - 1) * dilation + 1
188
+ padding = effective_kernel_size // 2
189
+
190
+ k_img, v_img = map(
191
+ lambda t: rearrange(t, "b (h w) c -> b c h w", h=img_size), (k_img, v_img)
192
+ )
193
+ k_img, v_img = map(
194
+ lambda t: F.unfold(t, kernel_size, padding=padding, dilation=dilation),
195
+ (k_img, v_img),
196
+ )
197
+ k_img, v_img = map(
198
+ lambda t: rearrange(t, "b (d j) i -> b i j d", j=kernel_size**2),
199
+ (k_img, v_img),
200
+ )
201
+
202
+ # let image attend to all of text
203
+
204
+ dots_image = einsum("b i d, b i j d -> b i j", q_img, k_img)
205
+ dots_image_to_text = einsum("b i d, b j d -> b i j", q_img, k_text)
206
+
207
+ # calculate causal attention for local convolution
208
+
209
+ i, j = dots_image.shape[-2:]
210
+ img_seq = torch.arange(img_seq_len, device=device)
211
+ k_img_indices = rearrange(img_seq.float(), "(h w) -> () () h w", h=img_size)
212
+ k_img_indices = F.pad(
213
+ k_img_indices, (padding,) * 4, value=img_seq_len
214
+ ) # padding set to be max, so it is never attended to
215
+ k_img_indices = F.unfold(k_img_indices, kernel_size, dilation=dilation)
216
+ k_img_indices = rearrange(k_img_indices, "b j i -> b i j")
217
+
218
+ # mask image attention
219
+
220
+ q_img_indices = rearrange(img_seq, "i -> () i ()")
221
+ causal_mask = q_img_indices < k_img_indices
222
+
223
+ # concat text mask with image causal mask
224
+
225
+ causal_mask = repeat(causal_mask, "() i j -> b i j", b=b * h)
226
+ mask = repeat(mask, "b j -> (b h) i j", i=i, h=h)
227
+ mask = torch.cat((~mask, causal_mask), dim=-1)
228
+
229
+ # image can attend to all of text
230
+
231
+ dots = torch.cat((dots_image_to_text, dots_image), dim=-1)
232
+ dots.masked_fill_(mask, mask_value)
233
+
234
+ attn = softmax(dots, dim=-1)
235
+
236
+ # aggregate
237
+
238
+ attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]
239
+
240
+ out_image_to_image = einsum("b i j, b i j d -> b i d", attn_image, v_img)
241
+ out_image_to_text = einsum("b i j, b j d -> b i d", attn_image_to_text, v_text)
242
+
243
+ out_image = out_image_to_image + out_image_to_text
244
+
245
+ # combine attended values for both text and image
246
+
247
+ out = torch.cat((out_text, out_image), dim=1)
248
+
249
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
250
+
251
+ out = self.to_out(out)
252
+
253
+ return out[:, :n]
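The shape comments inside Attention.forward spell out the contract: [batch_size, seq_len, dim] in, the same shape out. A small sketch exercising the module standalone, with toy sizes chosen purely for illustration and assuming the repository's dependencies (einops, rotary-embedding-torch) are installed:

import torch
from celle.attention import Attention

# Toy configuration for a shape check only; not the model's real settings.
attn = Attention(dim=64, seq_len=16, heads=4, dim_head=16, causal=True)
x = torch.randn(2, 16, 64)   # [batch_size, seq_len, dim]
out = attn(x)                # context_mask / rotary_pos_emb are optional
print(out.shape)             # torch.Size([2, 16, 64])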
celle/celle.py ADDED
@@ -0,0 +1,1060 @@
1
+ # Import necessary packages and modules
2
+ from math import floor, ceil
+ from tqdm import tqdm  # used by sample() when progress=True
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from axial_positional_embedding import AxialPositionalEmbedding
7
+ from einops import rearrange
8
+ from celle.utils import (
9
+ exists,
10
+ always,
11
+ eval_decorator,
12
+ gumbel_sample,
13
+ top_k,
14
+ gamma_func,
15
+ DivideMax,
16
+ )
17
+
18
+ # Import additional modules from within the codebase
19
+ from celle.transformer import Transformer
20
+
21
+
22
+ def generate_mask(gamma_func, batch_size, length, device):
23
+ # Get the number of `True` values in the mask (a single count shared by every batch element)
24
+ num_true_values = floor(gamma_func(torch.rand(1)) * length)
25
+
26
+ # Generate a random sample of indices to set to `True` in the mask
27
+ # The number of indices in the sample is determined by `num_true_values`
28
+ indices = (
29
+ torch.rand((batch_size, length), device=device)
30
+ .topk(num_true_values, dim=1)
31
+ .indices
32
+ )
33
+
34
+ # Create a binary mask tensor with `True` values at the sampled indices
35
+ mask = torch.zeros((batch_size, length), dtype=torch.bool, device=device)
36
+ mask.scatter_(dim=1, index=indices, value=True)
37
+
38
+ return mask
39
+
40
+
41
+ def match_batch_size(text, condition, image, batch_size):
42
+ """
43
+ This function ensures all inputs to the sample function have the same batch size.
44
+ """
45
+ if text.shape[0] != batch_size:
46
+ text = text.repeat(batch_size, 1)
47
+
48
+ if condition.shape[0] != batch_size:
49
+ condition = condition.repeat(batch_size, 1)
50
+
51
+ if image.shape[0] != batch_size:
52
+ image = image.repeat(batch_size, 1)
53
+
54
+ return text, condition, image
55
+
56
+
57
+ def calc_unmask_probs(timestep, timesteps, gamma_func):
58
+ if timestep == 1 or timesteps == 1:
59
+ unmask_prob = 1
60
+ else:
61
+ unmask_prob = 1 - gamma_func(timestep)
62
+ return unmask_prob
63
+
64
+
65
+ def calculate_logits(
66
+ input_tokens, input_mask, logits_function, filter_thres, temperature
67
+ ):
68
+ logits, _, _ = logits_function(input_tokens, input_mask, return_encoding=False)
69
+ filtered_logits = top_k(logits, thres=filter_thres)
70
+ sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)
71
+
72
+ return logits, sample
73
+
74
+
75
+ def unmask_tokens(
76
+ input_tokens,
77
+ input_mask,
78
+ num_masked_tokens,
79
+ logits,
80
+ sample,
81
+ timestep,
82
+ timesteps,
83
+ gamma,
84
+ filter_func=None,
85
+ pad_token=None,
86
+ mask_token=None,
87
+ force_aas=True,
88
+ ):
89
+ sample = sample.masked_fill(~input_mask.unsqueeze(-1), -torch.inf)
90
+ if filter_func:
91
+ sample = filter_func(
92
+ input_tokens, sample, force_aas, pad_token=pad_token, mask_token=mask_token
93
+ )
94
+ selected_token_probs, selected_tokens = torch.max(sample, dim=-1)
95
+
96
+ unmask_prob = calc_unmask_probs(timestep, timesteps, gamma)
97
+ num_tokens_to_unmask = max(1, ceil(unmask_prob * num_masked_tokens))
98
+
99
+ _, top_k_indices = torch.topk(selected_token_probs, num_tokens_to_unmask, dim=-1)
100
+
101
+ sample_mask = torch.zeros(
102
+ input_tokens.shape, dtype=torch.bool, device=input_tokens.device
103
+ )
104
+ sample_mask.scatter_(dim=1, index=top_k_indices, value=True)
105
+
106
+ unmasked_tokens = torch.where(sample_mask, selected_tokens, input_tokens)
107
+ full_logits = torch.where(
108
+ sample_mask.unsqueeze(-1), logits, torch.zeros_like(logits)
109
+ )
110
+ return unmasked_tokens, full_logits
111
+
112
+
113
+ def suppress_invalid_text_tokens(
114
+ text,
115
+ logits,
116
+ start_token=None,
117
+ end_token=None,
118
+ pad_token=None,
119
+ mask_token=None,
120
+ force_aas=False,
121
+ ):
122
+ # Find the indices of start_token and end_token in tensor text along axis=1
123
+ idx_start = (text == start_token).nonzero(as_tuple=True)[1]
124
+ idx_end = (text == end_token).nonzero(as_tuple=True)[1]
125
+
126
+ # For every position other than the index corresponding to the start index, set the values on the start index of dimension=2 to -torch.inf
127
+ if idx_start.nelement() != 0:
128
+ try:
129
+ mask = idx_start.unsqueeze(1) != torch.arange(
130
+ logits.size(1), device=text.device
131
+ )
132
+ indices = torch.where(mask)
133
+ logits[indices[0], indices[1], start_token] = -torch.inf
134
+ except:
135
+ pass
136
+
137
+ # else:
138
+ # idx_start = torch.zeros(text.size(0), dtype=torch.long)
139
+
140
+ # Similarly, for every position other than the index corresponding to the end index, set the values on the end index of dimension=2 to -torch.inf
141
+ if idx_end.nelement() != 0:
142
+ try:
143
+ mask = idx_end.unsqueeze(1) != torch.arange(
144
+ logits.size(1), device=text.device
145
+ )
146
+ indices = torch.where(mask)
147
+ logits[indices[0], indices[1], end_token] = -torch.inf
148
+ except:
149
+ pass
150
+
151
+ # else:
152
+ # idx_end = torch.full((text.size(0),), text.size(1) - 1, dtype=torch.long)
153
+
154
+ if pad_token:
155
+ if idx_start.nelement() != 0 and idx_end.nelement() != 0:
156
+ try:
157
+ # For every position between the indices of start_token and end_token, set the values for 1st index of dimension=2 equal to -torch.inf. Any value outside of that range should be set to torch.inf.
158
+ mask = (
159
+ torch.arange(logits.size(1), device=text.device)
160
+ >= idx_start.unsqueeze(1)
161
+ ) & (
162
+ torch.arange(logits.size(1), device=text.device)
163
+ <= idx_end.unsqueeze(1)
164
+ )
165
+
166
+ indices = torch.where(mask)
167
+ logits[indices[0], indices[1], pad_token] = -torch.inf
168
+
169
+ indices = torch.where(~mask)
170
+ logits[indices[0], indices[1], pad_token] = torch.inf
171
+
172
+ except:
173
+ pass
174
+
175
+ elif idx_start.nelement() != 0:
176
+ try:
177
+ mask = torch.arange(
178
+ logits.size(1), device=text.device
179
+ ) < idx_start.unsqueeze(1)
180
+ indices = torch.where(mask)
+ logits[indices[0], indices[1], pad_token] = torch.inf
181
+ except:
182
+ pass
183
+
184
+ elif idx_end.nelement() != 0:
185
+ try:
186
+ mask = torch.arange(
187
+ logits.size(1), device=text.device
188
+ ) > idx_end.unsqueeze(1)
189
+ indices = torch.where(mask)
+ logits[indices[0], indices[1], pad_token] = torch.inf
190
+ except:
191
+ pass
192
+
193
+ if force_aas:
194
+ if pad_token:
195
+ logits[:, :, pad_token] = -torch.inf
196
+ logits[:, :, 3] = -torch.inf
197
+ logits[:, :, 29:] = -torch.inf
198
+
199
+ if mask_token:
200
+ logits[:, :, mask_token] = -torch.inf
201
+
202
+ return logits
203
+
204
+
205
+ def detokenize_text(text_embedding, sequence):
206
+ if text_embedding == "esm1b" or text_embedding == "esm2":
207
+ from esm import Alphabet
208
+
209
+ alphabet = (
210
+ Alphabet.from_architecture("ESM-1b").get_batch_converter().alphabet.all_toks
211
+ )
212
+ else:
213
+ raise NameError("Detokenization is only available for ESM models")
214
+
215
+ output_seqs = []
216
+
217
+ for batch in sequence:
218
+ converted_seq = [alphabet[idx] for idx in batch]
219
+ converted_seq = "".join(converted_seq)
220
+ output_seqs.append(converted_seq)
221
+
222
+ return output_seqs
223
+
224
+ class ImageEmbedding(nn.Module):
225
+ def __init__(self, num_tokens, dim):
226
+ super(ImageEmbedding, self).__init__()
227
+ self.image_embedding = nn.Embedding(num_tokens, dim)
228
+
229
+ def forward(self, image):
230
+ return self.image_embedding(image)
231
+
232
+
233
+ class ModelExtender(nn.Module):
234
+ def __init__(self, vocab, out_features, fixed_embedding=False):
235
+ super(ModelExtender, self).__init__()
236
+
237
+ # Initialize the model according to the given vocabulary
238
+ self.vocab = vocab
239
+
240
+ if vocab == "esm1b":
241
+ from esm import pretrained
242
+
243
+ self.model, _ = pretrained.esm1b_t33_650M_UR50S()
244
+ self.in_features = 1280
245
+ elif vocab == "esm2":
246
+ from esm import pretrained
247
+
248
+ if out_features == 320:
249
+ self.model, _ = pretrained.esm2_t6_8M_UR50D()
250
+ elif out_features == 480:
251
+ self.model, _ = pretrained.esm2_t12_35M_UR50D()
252
+ elif out_features == 640:
253
+ self.model, _ = pretrained.esm2_t30_150M_UR50D()
254
+ elif out_features == 1280:
255
+ self.model, _ = pretrained.esm2_t33_650M_UR50D()
256
+ elif out_features == 2560:
257
+ self.model, _ = pretrained.esm2_t36_3B_UR50D()
258
+ else:
259
+ self.model, _ = pretrained.esm2_t33_650M_UR50D()
260
+ self.in_features = self.model.embed_dim
261
+
262
+ # Set the number of output features and initialize the scaling layer
263
+ self.out_features = out_features
264
+ self.scale_layer = nn.Linear(self.in_features, self.out_features)
265
+
266
+ # Determine whether to freeze the model's parameters
267
+ self.fixed_embedding = fixed_embedding
268
+ if self.fixed_embedding:
269
+ self.model = self.model.eval()
270
+
271
+ def forward(self, x, **kwargs):
272
+ # If the model's parameters are fixed, use torch.no_grad()
273
+ if self.fixed_embedding:
274
+ with torch.no_grad():
275
+ if self.vocab == "esm1b" or self.vocab == "esm2":
276
+ # Reduce sequence length dimension, get top layer representation tensor
277
+ x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
278
+ "representations"
279
+ ][self.model.num_layers]
280
+ # Tensor shape: (batch_size, hidden_size)
281
+ else:
282
+ # Get top layer representation tensor
283
+ x = self.model(x, **kwargs)[0]
284
+ # Tensor shape: (batch_size, sequence_length, hidden_size)
285
+ else:
286
+ if self.vocab == "esm1b" or self.vocab == "esm2":
287
+ # Reduce sequence length dimension, get top layer representation tensor
288
+ x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
289
+ "representations"
290
+ ][self.model.num_layers]
291
+ # Tensor shape: (batch_size, hidden_size)
292
+ else:
293
+ # Get top layer representation tensor
294
+ x = self.model(x, **kwargs)[0]
295
+ # Tensor shape: (batch_size, sequence_length, hidden_size)
296
+
297
+ # Scale the representation tensor if necessary
298
+ if self.out_features != self.in_features:
299
+ x = self.scale_layer(x)
300
+ # Tensor shape: (batch_size, out_features)
301
+
302
+ return x
303
+
304
+ class CELLE(nn.Module):
305
+ def __init__(
306
+ self,
307
+ *,
308
+ dim,
309
+ vae, # The VAE model used to encode/decode images
310
+ condition_vae=None, # An optional VAE model used to condition the image generation
311
+ num_images=2, # Number of images to generate
312
+ num_text_tokens=30, # Number of tokens in the text vocabulary
313
+ text_seq_len=1000, # Maximum length of input text sequence
314
+ depth=16, # Number of layers in the transformer model
315
+ heads=16, # Number of attention heads
316
+ dim_head=64, # Dimensionality of each attention head
317
+ attn_dropout=0.1, # Dropout rate for attention weights
318
+ ff_dropout=0.1, # Dropout rate for feedforward layers
319
+ attn_types=None, # Types of attention to use in the transformer
320
+ causal=False, # Whether to use causal attention
321
+ loss_cond_weight=1, # Weight of conditioning loss
322
+ loss_img_weight=1, # Weight of image generation loss
323
+ stable=False, # Whether to use divide-by-max normalization in the transformer
324
+ rotary_emb=True, # Whether to use rotary positional embeddings
325
+ text_embedding="esm2", # Text embedding to use (esm1b, esm2)
326
+ fixed_embedding=True, # Whether to fix the text embedding or learn it
327
+ sampling_mode="cosine", # Sampling mode for the VAE
328
+ linear_project=False, # Whether to project embeddings linearly
329
+ **kwargs,
330
+ ):
331
+ super().__init__()
332
+
333
+ # Set the stable flag
334
+ self.stable = stable
335
+
336
+ # If the stable flag is set, initialize the DivideMax layer for normalization
337
+ if stable:
338
+ self.norm_by_max = DivideMax(dim=-1)
339
+
340
+ ### Initializing text parameters ###
341
+
342
+ # Initialize the text and fixed embeddings
343
+ self.text_embedding = text_embedding
344
+ self.fixed_embedding = fixed_embedding
345
+
346
+ # Offset logits index and calculate cross entropy loss
347
+ self.num_text_tokens = num_text_tokens
348
+ self.linear_project = linear_project
349
+
350
+ # Add <BOS> and <EOS> tokens to the beginning and end of text sequences
351
+ if text_embedding.lower() in ("esm1b", "esm2"):
352
+ self.text_seq_len = text_seq_len + 2
353
+ else:
354
+ self.text_seq_len = text_seq_len
355
+
356
+ # Initialize embeddings for <SEP> token
357
+ self.sep_emb = nn.Embedding(1, dim)
358
+
359
+ # Initialize positional embeddings for text sequences and <SEP> token
360
+ self.text_pos_emb = (
361
+ nn.Embedding(self.text_seq_len + 1, dim) if not rotary_emb else always(0)
362
+ ) # +1 for <SEP>
363
+
364
+ ### ###
365
+
366
+ self.num_images = num_images
367
+
368
+ ### Initializing condition parameters ###
369
+
370
+ # Initialize the number of condition tokens, condition sequence length, and condition embedding
371
+ if exists(condition_vae):
372
+ condition_size = condition_vae.image_size
373
+ num_condition_tokens = condition_vae.num_tokens
374
+ self.num_condition_tokens = num_condition_tokens
375
+ condition_fmap_size = condition_vae.image_size // (
376
+ 2**condition_vae.num_layers
377
+ )
378
+ condition_seq_len = condition_fmap_size**2
379
+
380
+ # Initialize ImageEmbedding for condition embedding
381
+ self.condition_emb = ImageEmbedding(num_condition_tokens + 1, dim)
382
+
383
+ # Initialize positional embeddings for condition embedding
384
+ self.condition_pos_emb = (
385
+ AxialPositionalEmbedding(
386
+ dim, axial_shape=(condition_fmap_size, condition_fmap_size)
387
+ )
388
+ if not rotary_emb
389
+ else always(0)
390
+ )
391
+
392
+ else:
393
+ condition_fmap_size = 0
394
+ condition_seq_len = 0
395
+ num_condition_tokens = 0
396
+
397
+ ### ####
398
+
399
+ ### Initializing image parameters ###
400
+
401
+ # Initialize the image size, image token size, and sequence length
402
+ self.image_size = vae.image_size
403
+ num_image_tokens = vae.num_tokens
404
+ image_fmap_size = vae.image_size // (2**vae.num_layers)
405
+ image_seq_len = image_fmap_size**2
406
+ self.image_seq_len = image_seq_len
407
+ self.num_image_tokens = num_image_tokens
408
+
409
+ # Initialize ImageEmbedding and positional embeddings for image embedding
410
+ self.image_emb = ImageEmbedding(num_image_tokens + 1, dim) # +1 for <IM_MASK>
411
+
412
+ self.image_pos_emb = (
413
+ AxialPositionalEmbedding(
414
+ dim, axial_shape=(image_fmap_size, image_fmap_size)
415
+ )
416
+ if not rotary_emb
417
+ else always(0)
418
+ )
419
+
420
+ # Set total sequence length and total tokens
421
+ self.num_condition_tokens = num_condition_tokens
422
+ self.condition_seq_len = condition_seq_len
423
+ # Text Length + <SEP> + Condition Tokens + Image Tokens
424
+ seq_len = self.text_seq_len + 1 + self.condition_seq_len + self.image_seq_len
425
+ total_tokens = (
426
+ num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens + 1
427
+ )
428
+ self.total_tokens = total_tokens
429
+ self.total_seq_len = seq_len
430
+
431
+ # Set the VAE and condition VAE for the model
432
+ self.vae = vae.eval()
433
+ self.condition_vae = condition_vae.eval() if exists(condition_vae) else None
434
+
435
+ ### ###
436
+
437
+ ### Setting discrete ids ###
438
+ # Initialize text embedding based on the given text_embedding parameter
439
+ if text_embedding == "esm1b" or text_embedding == "esm2":
440
+ self.text_mask_token = 32
441
+ self.pad_token = 1
442
+ self.text_emb = ModelExtender(text_embedding, dim, fixed_embedding)
443
+ else:
444
+ raise ValueError("Only ESM models are supported.")
445
+
446
+ # Set token indices for text, condition, and image sequences
447
+ self.sep_token = num_text_tokens
448
+ self.cond_mask_token = num_condition_tokens
449
+ self.image_mask_token = num_image_tokens
450
+
451
+ # Create indices for sequence and logits dimensions
452
+ self.seq_range = torch.arange(seq_len)
453
+ self.logits_range = torch.arange(total_tokens)
454
+
455
+ # Reshape sequence and logits indices
456
+ self.seq_range = rearrange(self.seq_range, "n -> () n ()")
457
+ self.logits_range = rearrange(self.logits_range, "d -> () () d")
458
+
459
+ # Create a mask to exclude invalid token positions from the model output
460
+ # e.g. no image tokens where sequence tokens should be
461
+ logits_mask = (
462
+ # Mask text tokens beyond text_seq_len and invalid logits_range
463
+ (
464
+ (self.seq_range < self.text_seq_len)
465
+ & (self.logits_range < num_text_tokens)
466
+ & (self.logits_range != self.text_mask_token)
467
+ )
468
+ |
469
+ # Mask [SEP] token after text
470
+ (
471
+ (self.seq_range == self.text_seq_len)
472
+ & (self.logits_range == num_text_tokens)
473
+ )
474
+ |
475
+ # Mask condition tokens beyond text_seq_len+1 ([SEP]) and invalid logits_range
476
+ (
477
+ (self.seq_range >= self.text_seq_len + 1)
478
+ & (self.seq_range < self.text_seq_len + 1 + condition_seq_len)
479
+ & (self.logits_range >= num_text_tokens + 1)
480
+ & (self.logits_range < num_text_tokens + 1 + num_condition_tokens)
481
+ )
482
+ |
483
+ # Mask image tokens beyond num_text_tokens+num_condition_tokens+1
484
+ (
485
+ (self.seq_range >= self.text_seq_len + 1 + condition_seq_len)
486
+ & (self.logits_range >= num_text_tokens + 1 + num_condition_tokens + 1)
487
+ & (
488
+ self.logits_range
489
+ < num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens
490
+ )
491
+ )
492
+ )
493
+
494
+ # Invert the mask
495
+ logits_mask = ~logits_mask
496
+
497
+ # Register the buffer with the logits_mask
498
+ self.register_buffer("logits_mask", logits_mask, persistent=False)
499
+
500
+ ### ###
501
+
502
+ # Initialize the Transformer model with given parameters
503
+ self.transformer = Transformer(
504
+ dim=dim,
505
+ causal=causal,
506
+ seq_len=seq_len,
507
+ depth=depth,
508
+ heads=heads,
509
+ dim_head=dim_head,
510
+ attn_dropout=attn_dropout,
511
+ ff_dropout=ff_dropout,
512
+ image_fmap_size=image_fmap_size + condition_fmap_size,
513
+ num_images=num_images,
514
+ stable=stable,
515
+ rotary_emb=rotary_emb,
516
+ )
517
+
518
+ # Initialize the linear layers for converting transformer output to logits
519
+ self.to_logits = nn.Sequential(
520
+ nn.LayerNorm(dim),
521
+ nn.Linear(dim, self.total_tokens),
522
+ )
523
+
524
+ # Set instance variables for weights and critic
525
+ self.loss_img_weight = loss_img_weight
526
+ self.loss_cond_weight = loss_cond_weight
527
+ self.gamma = gamma_func(sampling_mode)
528
+
529
+ def embed_and_transform(self, inputs, masks, return_encoding=False):
530
+ text, condition, image = inputs
531
+ device = text.device
532
+ text_mask, _, image_mask = masks
533
+
534
+ text_labels = text.clone()
535
+ text = torch.where(
536
+ text_mask, self.text_mask_token * torch.ones_like(text, device=device), text
537
+ )
538
+
539
+ tokens = self.text_emb(text)
540
+
541
+ # Add SEP token
542
+
543
+ sep_token_emb = self.sep_emb(
544
+ torch.zeros((tokens.shape[0], 1), dtype=torch.long, device=device)
545
+ )
546
+ tokens = torch.cat((tokens, sep_token_emb), dim=1)
547
+ tokens += self.text_pos_emb(torch.arange(text.shape[1] + 1, device=device))
548
+
549
+ with torch.no_grad():
550
+ if self.linear_project:
551
+ b = condition.shape[0]
552
+ condition, _, [_, _, condition_labels] = self.condition_vae.encode(
553
+ condition
554
+ )
555
+ condition_labels = rearrange(condition_labels, "(b n) -> b n", b=b)
556
+
557
+ else:
558
+ condition_labels = condition
559
+ if condition.dtype == torch.float:
560
+ condition_labels = self.condition_vae.get_codebook_indices(
561
+ condition
562
+ )
563
+ condition = condition_labels.clone()
564
+
565
+ condition_emb = self.condition_emb(condition)
566
+ condition_emb += self.condition_pos_emb(condition_emb)
567
+ tokens = torch.cat((tokens, condition_emb), dim=1)
568
+
569
+ with torch.no_grad():
570
+ if self.linear_project:
571
+ b = image.shape[0]
572
+ image, _, [_, _, image_labels] = self.vae.encode(image)
573
+ image_labels = rearrange(image_labels, "(b n) -> b n", b=b)
574
+
575
+ else:
576
+ image_labels = image
577
+ if image.dtype == torch.float:
578
+ image_labels = self.vae.get_codebook_indices(image)
579
+ image = torch.where(
580
+ image_mask,
581
+ self.image_mask_token
582
+ * torch.ones_like(image_labels, device=device),
583
+ image_labels,
584
+ )
585
+
586
+ image_emb = self.image_emb(image)
587
+
588
+ image_emb += self.image_pos_emb(image_emb)
589
+ tokens = torch.cat((tokens, image_emb), dim=1)
590
+
591
+ if self.stable:
592
+ alpha = 0.1
593
+ tokens = tokens * alpha + tokens.detach() * (1 - alpha)
594
+
595
+ out = self.transformer(tokens)
596
+
597
+ if self.stable:
598
+ out = self.norm_by_max(out)
599
+
600
+ logits = self.to_logits(out)
601
+
602
+ max_neg_value = -torch.finfo(logits.dtype).max
603
+ logits.masked_fill_(self.logits_mask, max_neg_value)
604
+
605
+ if return_encoding:
606
+ return logits, out, [text_labels, condition_labels, image_labels]
607
+ else:
608
+ return logits, None, [text_labels, condition_labels, image_labels]
609
+
610
+ def forward(
611
+ self,
612
+ text,
613
+ condition=None,
614
+ image=None,
615
+ return_loss=False,
616
+ return_encoding=False,
617
+ ):
618
+ batch_size, device = text.shape[0], text.device
619
+
620
+ # Check that image is supplied when training
621
+ assert exists(image), "when training, image must be supplied"
622
+
623
+ # Check that image dimensions match the expected dimensions
624
+ assert tuple(image.shape[1:]) == (
625
+ self.vae.channels,
626
+ self.image_size,
627
+ self.image_size,
628
+ ), f"invalid image of dimensions {image.shape} passed in during training"
629
+
630
+ # Generate masks for text, condition, and image
631
+
632
+ # text_mask = generate_mask(self.gamma, batch_size, self.text_seq_len, device)
633
+
634
+ text_mask = generate_mask(
635
+ gamma_func("scaled-cosine"), batch_size, self.text_seq_len, device
636
+ )
637
+
638
+ image_mask = generate_mask(self.gamma, batch_size, self.image_seq_len, device)
639
+
640
+ # Embed and transform inputs
641
+ logits, _, labels = self.embed_and_transform(
642
+ [text, condition, image],
643
+ [text_mask, None, image_mask],
644
+ return_encoding,
646
+ )
647
+
648
+ # If not returning loss, return the logits
649
+ if not return_loss:
650
+ return logits
651
+
652
+ # Separate labels
653
+ text, condition, image = labels
654
+
655
+ # Add SEP token to end of text label
656
+ sep_token = torch.tensor(self.sep_token, device=device).repeat(
657
+ text.shape[0], 1
658
+ )
659
+ labels = torch.cat([text, sep_token], dim=1)
660
+
661
+ # If condition exists and condition vae is defined, add the condition to the labels
662
+ if exists(condition) and exists(self.condition_vae):
663
+ offsetted_condition = condition + self.num_text_tokens + 1
664
+ labels = torch.cat((labels, offsetted_condition), dim=1)
665
+
666
+ # Add image to the labels
667
+ offsetted_image = (
668
+ image + self.num_text_tokens + 1 + self.num_condition_tokens + 1
669
+ )
670
+ labels = torch.cat((labels, offsetted_image), dim=1)
671
+
672
+ # Rearrange logits for cross-entropy loss calculation
673
+ # Logits size: (batch_size, vocab_size, total_seq_len)
674
+ # Labels size: (batch_size, total_seq_len)
675
+ logits = rearrange(logits, "b n c -> b c n")
676
+
677
+ # Calculate cross-entropy loss for text and image
678
+ loss_text = F.cross_entropy(
679
+ logits[:, :, : self.text_seq_len],
680
+ labels[:, : self.text_seq_len],
681
+ reduction="none",
682
+ )[text_mask].mean()
683
+
684
+ loss_img = F.cross_entropy(
685
+ logits[:, :, self.text_seq_len + 1 + self.condition_seq_len :],
686
+ labels[:, self.text_seq_len + 1 + self.condition_seq_len :],
687
+ reduction="none",
688
+ )[image_mask].mean()
689
+
690
+ # Calculate total loss
691
+ loss = (loss_text + self.loss_img_weight * loss_img) / (
692
+ self.loss_img_weight + 1
693
+ )
694
+
695
+ loss_dict = {
696
+ "loss_text": loss_text,
697
+ # "loss_cond": loss_cond,
698
+ "loss_img": loss_img,
699
+ "loss": torch.nan_to_num(loss, 0.0, 0.0, 0.0),
700
+ }
701
+
702
+ return loss, loss_dict, None
703
+
704
+ def create_tensors(self, text, condition, image):
705
+ """
706
+ This function creates tensors for text, condition, and image when they are not provided as inputs to the sample function.
707
+ """
708
+ device = next(
709
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
710
+ None,
711
+ ).device
712
+
713
+ if not isinstance(text, torch.Tensor):
714
+ text = (
715
+ torch.ones(1, self.text_seq_len, device=device, dtype=torch.long)
716
+ * self.text_mask_token
717
+ )
718
+
719
+ if not isinstance(condition, torch.Tensor):
720
+ condition = (
721
+ torch.ones(1, self.condition_seq_len, device=device, dtype=torch.long)
722
+ * self.cond_mask_token
723
+ )
724
+ else:
725
+ with torch.no_grad():
726
+ condition = self.condition_vae.get_codebook_indices(condition)
727
+
728
+ if not isinstance(image, torch.Tensor):
729
+ image = (
730
+ torch.ones(1, self.image_seq_len, device=device, dtype=torch.long)
731
+ * self.image_mask_token
732
+ )
733
+ else:
734
+ with torch.no_grad():
735
+ image = self.vae.get_codebook_indices(image)
736
+
737
+ return text, condition, image
738
+
739
+ @torch.no_grad()
740
+ @eval_decorator
741
+ def sample(
742
+ self,
743
+ text=None,
744
+ condition=None,
745
+ image=None,
746
+ temperature=1.0,
747
+ filter_thres=0.9,
748
+ progress=False,
749
+ timesteps=1,
750
+ force_aas=True,
751
+ ):
752
+ # ensure timesteps is a positive integer
753
+ assert int(timesteps) > 0
754
+ # set model and VAEs to evaluation mode
755
+ self.eval()
756
+ vae = self.vae.eval()
757
+ if progress == True:
758
+ progress = tqdm
759
+ else:
760
+ progress = lambda x: x
761
+
762
+
763
+ # ensure that at least one of text, condition, or image is supplied
764
+ assert (
765
+ isinstance(text, torch.Tensor)
766
+ or isinstance(condition, torch.Tensor)
767
+ or isinstance(image, torch.Tensor)
768
+ ), "some data must be supplied"
769
+
770
+ # convert text, condition, and image to tensors if they aren't already
771
+ text, condition, image = self.create_tensors(text, condition, image)
772
+
773
+ # determine the maximum batch size of the input tensors
774
+ batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
775
+
776
+ # match the batch sizes of text, condition, and image
777
+ text, condition, image = match_batch_size(text, condition, image, batch_size)
778
+
779
+ # determine the device of the tensors
780
+ device = next(
781
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
782
+ None,
783
+ ).device
784
+
785
+ assert text.shape[0] == condition.shape[0] == image.shape[0]
786
+
787
+ # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device
788
+
789
+ # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
790
+ full_text_logits = torch.zeros(
791
+ batch_size, self.text_seq_len, self.num_text_tokens
792
+ ).to(device)
793
+
794
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
795
+ full_text_logits = full_text_logits.scatter_(
796
+ dim=-1, index=text.unsqueeze(-1), value=1
797
+ )
798
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
799
+ full_image_logits = torch.zeros(
800
+ batch_size, self.image_seq_len, self.num_image_tokens + 1
801
+ ).to(device)
802
+
803
+ # Remove the last token from each image sequence by setting full_image_logits to its first num_image_tokens elements
804
+ full_image_logits = full_image_logits.scatter_(
805
+ dim=-1, index=image.unsqueeze(-1), value=1
806
+ )
807
+
808
+ # cut off mask token
809
+ full_image_logits = full_image_logits[:, :, : self.num_image_tokens]
810
+
811
+ count = 0
812
+
813
+ for timestep in progress(torch.linspace(0, 1, timesteps)):
814
+ # Create masks for the text, condition, and image tensors
815
+ text_mask = text == self.text_mask_token
816
+ cond_mask = condition == self.cond_mask_token
817
+ image_mask = image == self.image_mask_token
818
+
819
+ # Calculate logits and samples using the calculate_logits function
820
+ logits, sample = calculate_logits(
821
+ [text, condition, image],
822
+ [text_mask, cond_mask, image_mask],
823
+ self.embed_and_transform,
824
+ filter_thres,
825
+ temperature,
826
+ )
827
+
828
+ # Calculate the number of masked tokens in the text and image tensors
829
+ num_masked_text_tokens = torch.sum(text_mask, dim=1)[0]
830
+ num_masked_image_tokens = torch.sum(image_mask, dim=1)[0]
831
+
832
+ # If there are masked text tokens, unmask them using unmask_tokens and fill the full text logits tensor with -inf for unmasked tokens
833
+ if num_masked_text_tokens.any() > 0:
834
+ text, full_text_logits = unmask_tokens(
835
+ text,
836
+ text_mask,
837
+ num_masked_text_tokens,
838
+ logits[:, : self.text_seq_len, : self.num_text_tokens],
839
+ sample[:, : self.text_seq_len, : self.num_text_tokens],
840
+ timestep,
841
+ timesteps,
842
+ self.gamma,
843
+ suppress_invalid_text_tokens,
844
+ self.pad_token,
845
+ self.text_mask_token,
846
+ force_aas=force_aas,
847
+ )
848
+ full_text_logits = full_text_logits.masked_fill(
849
+ ~text_mask.unsqueeze(-1), -torch.inf
850
+ )
851
+
852
+ # If there are masked image tokens, unmask them using unmask_tokens and fill the full image logits tensor with -inf for unmasked tokens
853
+ if num_masked_image_tokens > 0:
854
+ image, full_image_logits = unmask_tokens(
855
+ image,
856
+ image_mask,
857
+ num_masked_image_tokens,
858
+ logits[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
859
+ sample[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
860
+ timestep,
861
+ timesteps,
862
+ self.gamma,
863
+ )
864
+ full_text_logits = full_text_logits.masked_fill(
865
+ ~text_mask.unsqueeze(-1), -torch.inf
866
+ )
867
+
868
+ # Generate heatmap
869
+ with torch.no_grad():
870
+ # Normalize full image logits tensor
871
+ full_image_logits /= torch.max(
872
+ torch.abs(full_image_logits), dim=-1, keepdim=True
873
+ ).values
874
+
875
+ # Apply quantize embedding to full image logits tensor
876
+ full_image_logits = torch.matmul(
877
+ full_image_logits, self.vae.model.quantize.embedding.weight
878
+ )
879
+
880
+ # Rearrange full image logits tensor
881
+ h = int(self.image_seq_len**0.5)
882
+ full_image_logits = rearrange(
883
+ full_image_logits, "b (h w) c -> b c h w", h=h
884
+ )
885
+
886
+ # Decode full image logits tensor
887
+ full_image_logits = self.vae.model.decode(full_image_logits)
888
+
889
+ # Add clipping to full image logits tensor
890
+ max_val = torch.max(full_image_logits.view(batch_size, -1), dim=-1)[0]
891
+ min_val = torch.min(full_image_logits.view(batch_size, -1), dim=-1)[0]
892
+ full_image_logits += torch.clip(1 - max_val, 0, float("inf")).view(
893
+ batch_size, 1, 1, 1
894
+ )
895
+ full_image_logits += torch.clip(0 - min_val, float("-inf"), 0).view(
896
+ batch_size, 1, 1, 1
897
+ )
898
+
899
+ # Clip full image logits tensor values to the range [0, 1]
900
+ full_image_logits = torch.clip(full_image_logits, 0, 1)
901
+
902
+ # Return text tensor, detokenized text tensor, full text logits tensor,
903
+ # binary image tensor, and full image logits tensor
904
+ return (
905
+ text,
906
+ detokenize_text(self.text_embedding, text),
907
+ full_text_logits,
908
+ 1.0 * (vae.decode(image) > 0.5),
909
+ full_image_logits,
910
+ )
911
+
912
+ @torch.no_grad()
913
+ @eval_decorator
914
+ def sample_text(
915
+ self,
916
+ text=False,
917
+ condition=False,
918
+ image=False,
919
+ temperature=1.0,
920
+ filter_thres=0.9,
921
+ progress=False,
922
+ n_unmask=1,
923
+ place_amino=True,
924
+ force_aas=False,
925
+ ):
926
+ # set model and VAEs to evaluation mode
927
+ self.eval()
928
+
929
+ # ensure that at least one of text, condition, or image is supplied
930
+ assert (
931
+ isinstance(text, torch.Tensor)
932
+ or isinstance(condition, torch.Tensor)
933
+ or isinstance(image, torch.Tensor)
934
+ ), "some data must be supplied"
935
+
936
+ # convert text, condition, and image to tensors if they aren't already
937
+ text, condition, image = self.create_tensors(text, condition, image)
938
+
939
+ # determine the maximum batch size of the input tensors
940
+ batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
941
+
942
+ # match the batch sizes of text, condition, and image
943
+ text, condition, image = match_batch_size(text, condition, image, batch_size)
944
+
945
+ # determine the device of the tensors
946
+ device = next(
947
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
948
+ None,
949
+ ).device
950
+
951
+ assert text.shape[0] == condition.shape[0] == image.shape[0]
952
+
953
+ # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device
954
+
955
+ # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
956
+ full_text_logits = torch.zeros(
957
+ batch_size, self.text_seq_len, self.num_text_tokens
958
+ ).to(device)
959
+
960
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
961
+ full_text_logits = full_text_logits.scatter_(
962
+ dim=-1, index=text.unsqueeze(-1), value=1
963
+ )
964
+
965
+ text_mask = text == self.text_mask_token
966
+ cond_mask = condition == self.cond_mask_token
967
+ image_mask = image == self.image_mask_token
968
+
969
+ mask_indices = text_mask.nonzero()
970
+ non_mask_indices = (~text_mask).nonzero()
971
+
972
+ # figure out the center of the amino acids to determine generation direction
973
+ central_protein_index = torch.tensor(
974
+ [
975
+ torch.median(
976
+ non_mask_indices[torch.where(non_mask_indices[:, 0] == idx)][:, -1]
977
+ )
978
+ for idx in range(batch_size)
979
+ ]
980
+ )
981
+
982
+ count = 1
983
+
984
+ run_mask = text_mask
985
+ if progress:
986
+ pbar = progress(total=torch.sum(run_mask).item())
987
+ while torch.sum(run_mask) > 0:
988
+ logits, sample = calculate_logits(
989
+ [text, condition, image],
990
+ [text_mask, cond_mask, image_mask],
991
+ self.embed_and_transform,
992
+ filter_thres,
993
+ temperature,
994
+ )
995
+
996
+ # sub_sample: [batch_size, text_seq_len, num_text_tokens]
997
+ sub_sample = sample[:, : self.text_seq_len, : self.num_text_tokens]
998
+ sub_sample = sub_sample.masked_fill(~text_mask.unsqueeze(-1), -torch.inf)
999
+ sub_sample = suppress_invalid_text_tokens(
1000
+ text, sub_sample, 0, 2, self.pad_token, self.text_mask_token, force_aas
1001
+ )
1002
+ # calculate % to unmasked
1003
+ # get most likely token and probability for each position
1004
+
1005
+ for idx in range(batch_size):
1006
+ selected_mask_indices = mask_indices[
1007
+ torch.where(mask_indices[:, 0] == idx)
1008
+ ][:, -1]
1009
+
1010
+ # Generate to the left
1011
+ if selected_mask_indices[-count] < central_protein_index[idx]:
1012
+ unmask_index = selected_mask_indices[-count]
1013
+ left_sample = max(0, (unmask_index + 1) - n_unmask)
1014
+ right_sample = min(unmask_index + 1, self.text_seq_len - 1)
1015
+ central_protein_index[idx] = max(
1016
+ 0, central_protein_index[idx] - 0.5 * n_unmask
1017
+ )
1018
+
1019
+ # Generate to the right
1020
+ elif selected_mask_indices[count - 1] > central_protein_index[idx]:
1021
+ unmask_index = selected_mask_indices[count - 1]
1022
+ left_sample = max(0, unmask_index)
1023
+ right_sample = min(unmask_index + n_unmask, self.text_seq_len - 1)
1024
+ central_protein_index[idx] = min(
1025
+ central_protein_index[idx] + 0.5 * n_unmask,
1026
+ self.text_seq_len - 1,
1027
+ )
1028
+
1029
+ # save logits for relevant position
1030
+ full_text_logits[
1031
+ idx, left_sample:right_sample, : self.text_seq_len - 1
1032
+ ] = logits[idx, left_sample:right_sample, : self.num_text_tokens]
1033
+
1034
+ run_mask[idx, left_sample:right_sample] = False
1035
+
1036
+ # you may want to resample the amino acids or calculate marginal probs
1037
+ # if so, set place_amino to false
1038
+ if place_amino:
1039
+ text[idx, left_sample:right_sample] = torch.where(
1040
+ text[idx, left_sample:right_sample] == self.text_mask_token,
1041
+ sub_sample[
1042
+ idx, left_sample:right_sample, : self.num_text_tokens
1043
+ ].argmax(dim=-1),
1044
+ text[idx, left_sample:right_sample],
1045
+ )
1046
+
1047
+ text_mask = run_mask
1048
+
1049
+ count += n_unmask
1050
+
1051
+ if progress:
1052
+ pbar.update(n_unmask)
1053
+ if progress:
1054
+ pbar.close()
1055
+
1056
+ return (
1057
+ text,
1058
+ detokenize_text(self.text_embedding, text),
1059
+ full_text_logits,
1060
+ )
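Both training and sampling in CELLE revolve around the masking schedule: generate_mask draws one gamma-scheduled count of positions per call and scatters True at random indices, and sample/sample_text then progressively unmask those positions. A self-contained sketch of the mask-generation step with toy sizes (illustrative only, and assuming the package's remaining modules, e.g. celle/vae.py, and its dependencies are present so that importing celle resolves):

import torch
from celle.utils import gamma_func
from celle.celle import generate_mask

gamma = gamma_func("cosine")   # gamma(r) = cos(r * pi / 2), so r near 0 masks almost everything
mask = generate_mask(gamma, batch_size=2, length=10, device="cpu")

# One count is drawn per call and shared across the batch; the positions differ per row.
print(mask.shape)              # torch.Size([2, 10])
print(mask.sum(dim=1))         # the same number of True entries in each row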
celle/reversible.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch.nn as nn
2
+
3
+ # for routing arguments into the functions of the reversible layer
4
+ def route_args(router, args, depth):
5
+ routed_args = [(dict(), dict()) for _ in range(depth)]
6
+ matched_keys = [key for key in args.keys() if key in router]
7
+
8
+ for key in matched_keys:
9
+ val = args[key]
10
+ for depth, ((f_args, g_args), routes) in enumerate(
11
+ zip(routed_args, router[key])
12
+ ):
13
+ new_f_args, new_g_args = map(
14
+ lambda route: ({key: val} if route else {}), routes
15
+ )
16
+ routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
17
+ return routed_args
18
+
19
+ class SequentialSequence(nn.Module):
20
+ def __init__(self, layers, args_route={}, layer_dropout=0.0):
21
+ super().__init__()
22
+ assert all(
23
+ len(route) == len(layers) for route in args_route.values()
24
+ ), "each argument route map must have the same depth as the number of sequential layers"
25
+ self.layers = layers
26
+ self.args_route = args_route
27
+ self.layer_dropout = layer_dropout
28
+
29
+ def forward(self, x, **kwargs):
30
+ args = route_args(self.args_route, kwargs, len(self.layers))
31
+ layers_and_args = list(zip(self.layers, args))
32
+
33
+ for (f, g), (f_args, g_args) in layers_and_args:
34
+ x = x + f(x, **f_args)
35
+ x = x + g(x, **g_args)
36
+ return x
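route_args decides, per layer, which keyword arguments reach the first (attention) branch versus the second (feed-forward) branch, and SequentialSequence applies both branches residually. A toy sketch of the routing; the Branch module and the `note` argument are made up for illustration, while the real route map lives in transformer.py:

import torch
import torch.nn as nn
from celle.reversible import SequentialSequence

class Branch(nn.Module):
    """Stand-in for an attention/feed-forward block that accepts a routed kwarg."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
    def forward(self, x, note=None):   # `note` only arrives if the route says so
        return self.proj(x)

layers = nn.ModuleList(
    [nn.ModuleList([Branch(), Branch()]) for _ in range(2)]
)
# (True, False) per layer: the kwarg goes to the first branch only,
# mirroring attn_route_map = {"mask": ..., "rotary_pos_emb": ...} in transformer.py.
net = SequentialSequence(layers, args_route={"note": ((True, False),) * 2})

out = net(torch.randn(1, 4, 8), note="routed to the first branch of each layer")
print(out.shape)                        # torch.Size([1, 4, 8])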
celle/transformer.py ADDED
@@ -0,0 +1,213 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+ from celle.reversible import SequentialSequence
9
+ from celle.attention import Attention
10
+
11
+ from rotary_embedding_torch import RotaryEmbedding, broadcat
12
+ from celle.utils import exists, default, cast_tuple
13
+
14
+ # https://arxiv.org/abs/2103.17239
15
+ class LayerScale(nn.Module):
16
+ def __init__(self, dim, depth, fn):
17
+ super().__init__()
18
+ if depth <= 18:
19
+ init_eps = 0.1
20
+ elif depth > 18 and depth <= 24:
21
+ init_eps = 1e-5
22
+ else:
23
+ init_eps = 1e-6
24
+
25
+ scale = torch.zeros(1, 1, dim).fill_(init_eps)
26
+ self.scale = nn.Parameter(scale)
27
+ self.fn = fn
28
+
29
+ def forward(self, x, **kwargs):
30
+ return self.fn(x, **kwargs) * self.scale
31
+
32
+
33
+ # layer norm
34
+ class PreNorm(nn.Module):
35
+ def __init__(self, dim, fn):
36
+ super().__init__()
37
+ self.norm = nn.LayerNorm(dim)
38
+ self.norm_out = nn.Identity()
39
+ self.fn = fn
40
+
41
+ def forward(self, x, **kwargs):
42
+ x = self.norm(x)
43
+ x = self.fn(x, **kwargs)
44
+ return self.norm_out(x)
45
+
46
+
47
+ # feed forward
48
+
49
+
50
+ class GEGLU(nn.Module):
51
+ def forward(self, x):
52
+ x, gates = x.chunk(2, dim=-1)
53
+ return x * F.gelu(gates)
54
+
55
+
56
+ class FeedForward(nn.Module):
57
+ def __init__(self, dim, dropout=0.0, mult=4.0):
58
+ super().__init__()
59
+ self.net = nn.Sequential(
60
+ nn.Linear(dim, dim * mult * 2),
61
+ GEGLU(),
62
+ nn.Dropout(dropout),
63
+ nn.Linear(dim * mult, dim),
64
+ )
65
+
66
+ def forward(self, x):
67
+ return self.net(x)
68
+
69
+
70
+ # main transformer class
71
+ class Transformer(nn.Module):
72
+ def __init__(
73
+ self,
74
+ *,
75
+ dim,
76
+ depth,
77
+ seq_len,
78
+ causal=True,
79
+ heads=8,
80
+ dim_head=64,
81
+ ff_mult=4,
82
+ attn_dropout=0.0,
83
+ ff_dropout=0.0,
84
+ image_fmap_size=None,
85
+ num_images=None,
86
+ stable=False,
87
+ rotary_emb=True,
88
+ ):
89
+ super().__init__()
90
+ layers = nn.ModuleList([])
91
+
92
+ self.seq_len = seq_len
93
+ self.image_fmap_size = image_fmap_size
94
+
95
+ for ind in range(depth):
96
+
97
+ attn_class = partial(Attention, stable=stable)
98
+
99
+ attn = attn_class(
100
+ dim,
101
+ causal=causal,
102
+ seq_len=seq_len,
103
+ heads=heads,
104
+ dim_head=dim_head,
105
+ dropout=attn_dropout,
106
+ )
107
+
108
+ ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
109
+
110
+ layers.append(
111
+ nn.ModuleList(
112
+ [
113
+ LayerScale(
114
+ dim, ind + 1, PreNorm(dim, attn)
115
+ ),
116
+ LayerScale(
117
+ dim, ind + 1, PreNorm(dim, ff)
118
+ ),
119
+ ]
120
+ )
121
+ )
122
+
123
+ # pairs arguments with attention layer
124
+ route_attn = ((True, False),) * depth
125
+ attn_route_map = {
126
+ "mask": route_attn,
127
+ "rotary_pos_emb": route_attn,
128
+ }
129
+
130
+ self.layers = SequentialSequence(layers, args_route=attn_route_map)
131
+
132
+ # generate positional embeddings for rotary
133
+
134
+ pos_emb = None
135
+ if rotary_emb:
136
+ rot_dim = dim_head // 3
137
+ img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images
138
+
139
+ text_len = seq_len - img_seq_len + 1
140
+
141
+ text_pos_emb = RotaryEmbedding(dim=rot_dim)
142
+
143
+ img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for="pixel")
144
+
145
+ text_freqs = text_pos_emb(torch.arange(text_len))
146
+
147
+ img_to_text_freqs = text_pos_emb(
148
+ torch.full((img_seq_len,), 8192)
149
+ ) # image is given a position far away from text
150
+
151
+ text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0)
152
+
153
+ img_freqs_axial = img_axial_pos_emb(
154
+ torch.linspace(-1, 1, steps=image_fmap_size)
155
+ )
156
+
157
+ if num_images > 1:
158
+ split_img_freqs_axial = torch.split(
159
+ img_freqs_axial, image_fmap_size // num_images, dim=0
160
+ )
161
+
162
+ split_img_freqs = [
163
+ broadcat(
164
+ (
165
+ rearrange(img_freqs_axial_per_image, "i d -> i () d"),
166
+ rearrange(img_freqs_axial_per_image, "j d -> () j d"),
167
+ ),
168
+ dim=-1,
169
+ )
170
+ for img_freqs_axial_per_image in split_img_freqs_axial
171
+ ]
172
+
173
+ split_img_freqs = [
174
+ rearrange(img_freqs_per_image, "h w d -> (h w) d")
175
+ for img_freqs_per_image in split_img_freqs
176
+ ]
177
+
178
+ # concat per image-image_freqs
179
+
180
+ img_freqs = torch.cat(split_img_freqs, dim=0)
181
+
182
+ elif num_images == 1:
183
+ img_freqs = broadcat(
184
+ (
185
+ rearrange(img_freqs_axial, "i d -> i () d"),
186
+ rearrange(img_freqs_axial, "j d -> () j d"),
187
+ ),
188
+ dim=-1,
189
+ )
190
+
191
+ img_freqs = rearrange(img_freqs, "h w d -> (h w) d")
192
+
193
+ else:
194
+ assert False, "num_images must be int greater than 0"
195
+ self.img_axial_pos_emb = img_axial_pos_emb
196
+ self.text_pos_emb = text_pos_emb
197
+
198
+ text_axial_freqs = img_axial_pos_emb(
199
+ torch.full((text_len,), -10.0)
200
+ ) # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1]
201
+
202
+ text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim=-1)
203
+
204
+ img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0)
205
+
206
+ pos_emb = torch.cat((text_freqs, img_freqs), dim=-1)
207
+
208
+ pos_emb = rearrange(pos_emb, "n d -> () n d")
209
+
210
+ self.register_buffer("pos_emb", pos_emb)
211
+
212
+ def forward(self, x, **kwargs):
213
+ return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs)
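The rotary setup ties the Transformer's sizes together: img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images and text_len = seq_len - img_seq_len + 1, so seq_len must leave room for at least one non-image position. A toy instantiation with made-up sizes (not the released model's configuration) to illustrate the shape contract, assuming the pinned rotary-embedding-torch dependency is installed:

import torch
from celle.transformer import Transformer

# 4x4 image token grid (16 tokens) plus 11 non-image positions -> seq_len = 26.
model = Transformer(
    dim=64,
    depth=2,
    seq_len=26,
    heads=4,
    dim_head=16,
    image_fmap_size=4,
    num_images=1,
)
x = torch.randn(2, 26, 64)
print(model(x).shape)   # torch.Size([2, 26, 64])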
celle/utils.py ADDED
@@ -0,0 +1,193 @@
1
+ import torch
2
+ from math import pi
3
+
4
+ # Define helper functions
5
+ def exists(val):
6
+ """Check if a variable exists"""
7
+ return val is not None
8
+
9
+
10
+ def uniq(arr):
11
+ return {el: True for el in arr}.keys()
12
+
13
+
14
+ def default(val, d):
15
+ """If a value exists, return it; otherwise, return a default value"""
16
+ return val if exists(val) else d
17
+
18
+
19
+ def max_neg_value(t):
20
+ return -torch.finfo(t.dtype).max
21
+
22
+
23
+ def cast_tuple(val, depth=1):
24
+ if isinstance(val, list):
25
+ val = tuple(val)
26
+ return val if isinstance(val, tuple) else (val,) * depth
27
+
28
+
29
+ def is_empty(t):
30
+ """Check if a tensor is empty"""
31
+ # Return True if the number of elements in the tensor is zero, else False
32
+ return t.nelement() == 0
33
+
34
+
35
+ def masked_mean(t, mask, dim=1):
36
+ """
37
+ Compute the mean of a tensor, masked by a given mask
38
+
39
+ Args:
40
+ t (torch.Tensor): input tensor of shape (batch_size, seq_len, hidden_dim)
41
+ mask (torch.Tensor): mask tensor of shape (batch_size, seq_len)
42
+ dim (int): dimension along which to compute the mean (default=1)
43
+
44
+ Returns:
45
+ torch.Tensor: masked mean tensor of shape (batch_size, hidden_dim)
46
+ """
47
+ t = t.masked_fill(~mask[:, :, None], 0.0)
48
+ return t.sum(dim=1) / mask.sum(dim=1)[..., None]
49
+
50
+
51
+ def set_requires_grad(model, value):
52
+ """
53
+ Set whether or not the model's parameters require gradients
54
+
55
+ Args:
56
+ model (torch.nn.Module): the PyTorch model to modify
57
+ value (bool): whether or not to require gradients
58
+ """
59
+ for param in model.parameters():
60
+ param.requires_grad = value
61
+
62
+
63
+ def eval_decorator(fn):
64
+ """
65
+ Decorator function to evaluate a given function
66
+
67
+ Args:
68
+ fn (callable): function to evaluate
69
+
70
+ Returns:
71
+ callable: the decorated function
72
+ """
73
+
74
+ def inner(model, *args, **kwargs):
75
+ was_training = model.training
76
+ model.eval()
77
+ out = fn(model, *args, **kwargs)
78
+ model.train(was_training)
79
+ return out
80
+
81
+ return inner
82
+
83
+
84
+ def log(t, eps=1e-20):
85
+ """
86
+ Compute the natural logarithm of a tensor
87
+
88
+ Args:
89
+ t (torch.Tensor): input tensor
90
+ eps (float): small value to add to prevent taking the log of 0 (default=1e-20)
91
+
92
+ Returns:
93
+ torch.Tensor: the natural logarithm of the input tensor
94
+ """
95
+ return torch.log(t + eps)
96
+
97
+
98
+ def gumbel_noise(t):
99
+ """
100
+ Generate Gumbel noise
101
+
102
+ Args:
103
+ t (torch.Tensor): input tensor
104
+
105
+ Returns:
106
+ torch.Tensor: a tensor of Gumbel noise with the same shape as the input tensor
107
+ """
108
+ noise = torch.zeros_like(t).uniform_(0, 1)
109
+ return -log(-log(noise))
110
+
111
+
112
+ def gumbel_sample(t, temperature=0.9, dim=-1):
113
+ """
114
+ Perturb logits with temperature-scaled Gumbel noise (Gumbel-max trick)
115
+
116
+ Args:
117
+ t (torch.Tensor): input tensor of shape (batch_size, num_classes)
118
+ temperature (float): temperature for the Gumbel-softmax distribution (default=0.9)
119
+ dim (int): dimension along which to sample (default=-1)
120
+
121
+ Returns:
122
+ torch.Tensor: noisy logits with the same shape as the input; taking argmax along `dim` yields the sampled indices
123
+ """
124
+ return (t / max(temperature, 1e-10)) + gumbel_noise(t)
125
+
126
+
127
+ def top_k(logits, thres=0.5):
128
+ """
129
+ Return a tensor where all but the top k values are set to negative infinity
130
+
131
+ Args:
132
+ logits (torch.Tensor): input tensor of shape (batch_size, num_classes)
133
+ thres (float): threshold for the top k values (default=0.5)
134
+
135
+ Returns:
136
+ torch.Tensor: a tensor with the same shape as the input tensor, where all but the top k values are set to negative infinity
137
+ """
138
+ num_logits = logits.shape[-1]
139
+ k = max(int((1 - thres) * num_logits), 1)
140
+ val, ind = torch.topk(logits, k)
141
+ probs = torch.full_like(logits, float("-inf"))
142
+ probs.scatter_(-1, ind, val)
143
+ return probs
144
+
145
+
146
+ def gamma_func(mode="cosine", scale=0.15):
147
+ """Return a function that takes a single input r and returns a value based on the selected mode"""
148
+
149
+ # Define a different function based on the selected mode
150
+ if mode == "linear":
151
+ return lambda r: 1 - r
152
+ elif mode == "cosine":
153
+ return lambda r: torch.cos(r * pi / 2)
154
+ elif mode == "square":
155
+ return lambda r: 1 - r**2
156
+ elif mode == "cubic":
157
+ return lambda r: 1 - r**3
158
+ elif mode == "scaled-cosine":
159
+ return lambda r: scale * (torch.cos(r * pi / 2))
160
+ else:
161
+ # Raise an error if the selected mode is not implemented
162
+ raise NotImplementedError
163
+
164
+
165
+ class always:
166
+ """Helper class to always return a given value"""
167
+
168
+ def __init__(self, val):
169
+ self.val = val
170
+
171
+ def __call__(self, x, *args, **kwargs):
172
+ return self.val
173
+
174
+
175
+ class DivideMax(torch.nn.Module):
176
+ def __init__(self, dim):
177
+ super().__init__()
178
+ self.dim = dim
179
+
180
+ def forward(self, x):
181
+ maxes = x.amax(dim=self.dim, keepdim=True).detach()
182
+ return x / maxes
183
+
184
+ def process_image(image_path):
185
+ image = Image.open(image_path)
186
+ transform = transforms.Compose([
187
+ transforms.RandomCrop(256),
188
+ transforms.ToTensor()
189
+ ])
190
+ image_tensor = transform(image)
191
+ if image_tensor.shape[0] > 1:
192
+ image_tensor = torch.mean(image_tensor, dim=0, keepdim=True)
193
+ return image_tensor.unsqueeze(0)
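A hypothetical usage sketch for the sampling helpers above (the logits tensor and its size are assumptions, not values from the repo): top_k masks the weaker logits, gumbel_sample adds temperature-scaled noise, and an argmax picks the token.

    import torch
    from celle.utils import top_k, gumbel_sample

    logits = torch.randn(2, 8192)                 # e.g. one score per codebook entry, batch of 2
    filtered = top_k(logits, thres=0.5)           # all but the top half become -inf
    noisy = gumbel_sample(filtered, temperature=0.9)
    token_ids = noisy.argmax(dim=-1)              # (2,) sampled token indices
    print(token_ids.shape)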
celle/vae.py ADDED
@@ -0,0 +1,112 @@
1
+ from math import sqrt, log
2
+ from omegaconf import OmegaConf
3
+ import importlib
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange
10
+
11
+ # helpers methods
12
+
13
+
14
+ def load_model(path):
15
+ with open(path, "rb") as f:
16
+ return torch.load(f, map_location=torch.device("cpu"))
17
+
18
+
19
+ def map_pixels(x, eps=0.1):
20
+ return (1 - 2 * eps) * x + eps
21
+
22
+
23
+ def unmap_pixels(x, eps=0.1):
24
+ return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)
25
+
26
+
27
+ def make_contiguous(module):
28
+ with torch.no_grad():
29
+ for param in module.parameters():
30
+ param.set_(param.contiguous())
31
+
32
+
33
+ # VQGAN from Taming Transformers paper
34
+ # https://arxiv.org/abs/2012.09841
35
+
36
+
37
+ def get_obj_from_str(string, reload=False):
38
+ module, cls = string.rsplit(".", 1)
39
+ if reload:
40
+ module_imp = importlib.import_module(module)
41
+ importlib.reload(module_imp)
42
+ return getattr(importlib.import_module(module, package=None), cls)
43
+
44
+
45
+ def instantiate_from_config(config):
46
+ if not "target" in config:
47
+ raise KeyError("Expected key `target` to instantiate.")
48
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
49
+
50
+
51
+ class VQGanVAE(nn.Module):
52
+ def __init__(self, vqgan_model_path=None, vqgan_config_path=None, channels=1):
53
+ super().__init__()
54
+
55
+ assert vqgan_config_path is not None
56
+
57
+ model_path = vqgan_model_path
58
+ config_path = vqgan_config_path
59
+
60
+ config = OmegaConf.load(config_path)
61
+
62
+ model = instantiate_from_config(config["model"])
63
+
64
+ if vqgan_model_path:
65
+
66
+ state = torch.load(model_path, map_location="cpu")["state_dict"]
67
+ model.load_state_dict(state, strict=True)
68
+
69
+ print(f"Loaded VQGAN from {model_path} and {config_path}")
70
+
71
+ self.model = model
72
+
73
+ # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
74
+ f = (
75
+ config.model.params.ddconfig.resolution
76
+ / config.model.params.ddconfig.attn_resolutions[0]
77
+ )
78
+ self.num_layers = int(log(f) / log(2))
79
+ self.image_size = config.model.params.ddconfig.resolution
80
+ self.num_tokens = config.model.params.n_embed
81
+ # self.is_gumbel = isinstance(self.model, GumbelVQ)
82
+ self.is_gumbel = False
83
+ self.channels = config.model.params.ddconfig.in_channels
84
+
85
+ def encode(self, img):
86
+ return self.model.encode(img)
87
+
88
+ def get_codebook_indices(self, img):
89
+ b = img.shape[0]
90
+ # img = (2 * img) - 1
91
+ _, _, [_, _, indices] = self.encode(img)
92
+ if self.is_gumbel:
93
+ return rearrange(indices, "b h w -> b (h w)", b=b)
94
+ return rearrange(indices, "(b n) -> b n", b=b)
95
+
96
+ def decode(self, img_seq):
97
+ b, n = img_seq.shape
98
+ one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float()
99
+ z = (
100
+ one_hot_indices @ self.model.quantize.embed.weight
101
+ if self.is_gumbel
102
+ else (one_hot_indices @ self.model.quantize.embedding.weight)
103
+ )
104
+
105
+ z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n)))
106
+ img = self.model.decode(z)
107
+
108
+ # img = (img.clamp(-1.0, 1.0) + 1) * 0.5
109
+ return img
110
+
111
+ def forward(self, img, optimizer_idx=1):
112
+ return self.model.training_step(img, optimizer_idx=optimizer_idx)
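A hedged round-trip sketch for the wrapper above; the config and checkpoint paths are placeholders, and the exact shapes depend on the taming-transformers model the config points to.

    import torch
    from celle.vae import VQGanVAE

    vae = VQGanVAE(
        vqgan_model_path="checkpoints/vqgan.ckpt",   # placeholder path
        vqgan_config_path="configs/vqgan.yaml",      # placeholder path
    )
    img = torch.rand(1, vae.channels, vae.image_size, vae.image_size)
    indices = vae.get_codebook_indices(img)          # discrete codes, shape (1, num_patches)
    recon = vae.decode(indices)                      # back to an image tensor (1, C, H, W)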
celle_main.py ADDED
@@ -0,0 +1,619 @@
1
+ import os
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.random
6
+ from torch.optim import AdamW
7
+ from torch.utils.data import DataLoader
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning import seed_everything
10
+ from pytorch_lightning.trainer import Trainer
11
+
12
+ from dataloader import CellLoader
13
+ from celle import VQGanVAE, CELLE
14
+ from omegaconf import OmegaConf
15
+ import argparse, os, sys, datetime, glob
16
+
17
+ from celle.celle import gumbel_sample, top_k
18
+
19
+ torch.random.manual_seed(42)
20
+ np.random.seed(42)
21
+
22
+ from celle_taming_main import (
23
+ instantiate_from_config,
24
+ nondefault_trainer_args,
25
+ get_parser,
26
+ )
27
+
28
+
29
+ class CellDataModule(pl.LightningDataModule):
30
+ def __init__(
31
+ self,
32
+ data_csv,
33
+ dataset,
34
+ sequence_mode="standard",
35
+ vocab="bert",
36
+ crop_size=256,
37
+ resize=600,
38
+ batch_size=1,
39
+ threshold="median",
40
+ text_seq_len=1000,
41
+ num_workers=1,
42
+ **kwargs,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.data_csv = data_csv
47
+ self.dataset = dataset
48
+ self.protein_sequence_length = 0
49
+ self.image_folders = []
50
+ self.crop_size = crop_size
51
+ self.resize = resize
52
+ self.batch_size = batch_size
53
+ self.sequence_mode = sequence_mode
54
+ self.threshold = threshold
55
+ self.text_seq_len = int(text_seq_len)
56
+ self.vocab = vocab
57
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
58
+
59
+ def setup(self, stage=None):
60
+ # called on every GPU
61
+ self.cell_dataset_train = CellLoader(
62
+ data_csv=self.data_csv,
63
+ dataset=self.dataset,
64
+ crop_size=self.crop_size,
65
+ resize=self.resize,
66
+ split_key="train",
67
+ crop_method="random",
68
+ sequence_mode=self.sequence_mode,
69
+ vocab=self.vocab,
70
+ text_seq_len=self.text_seq_len,
71
+ threshold=self.threshold,
72
+ )
73
+
74
+ self.cell_dataset_val = CellLoader(
75
+ data_csv=self.data_csv,
76
+ dataset=self.dataset,
77
+ crop_size=self.crop_size,
78
+ resize=self.resize,
79
+ crop_method="center",
80
+ split_key="val",
81
+ sequence_mode=self.sequence_mode,
82
+ vocab=self.vocab,
83
+ text_seq_len=self.text_seq_len,
84
+ threshold=self.threshold,
85
+ )
86
+
87
+ def prepare_data(self):
88
+
89
+ pass
90
+
91
+ def train_dataloader(self):
92
+ return DataLoader(
93
+ self.cell_dataset_train,
94
+ num_workers=self.num_workers,
95
+ shuffle=True,
96
+ batch_size=self.batch_size,
97
+ )
98
+
99
+ def val_dataloader(self):
100
+ return DataLoader(
101
+ self.cell_dataset_val,
102
+ num_workers=self.num_workers,
103
+ batch_size=self.batch_size,
104
+ )
105
+
106
+ # def test_dataloader(self):
107
+ # transforms = ...
108
+ # return DataLoader(self.test, batch_size=64)
109
+
110
+
111
+ class CELLE_trainer(pl.LightningModule):
112
+ def __init__(
113
+ self,
114
+ vqgan_model_path,
115
+ vqgan_config_path,
116
+ ckpt_path=None,
117
+ image_key="threshold",
118
+ condition_model_path=None,
119
+ condition_config_path=None,
120
+ num_images=2,
121
+ dim=2,
122
+ num_text_tokens=30,
123
+ text_seq_len=1000,
124
+ depth=16,
125
+ heads=16,
126
+ dim_head=64,
127
+ attn_dropout=0.1,
128
+ ff_dropout=0.1,
129
+ attn_types="full",
130
+ loss_img_weight=7,
131
+ stable=False,
132
+ rotary_emb=True,
133
+ text_embedding="bert",
134
+ fixed_embedding=True,
135
+ loss_cond_weight=1,
136
+ learning_rate=3e-4,
137
+ monitor="val_loss",
138
+ ):
139
+ super().__init__()
140
+
141
+ vae = VQGanVAE(
142
+ vqgan_model_path=vqgan_model_path, vqgan_config_path=vqgan_config_path
143
+ )
144
+
145
+ self.image_key = image_key
146
+
147
+ if condition_config_path:
148
+ condition_vae = VQGanVAE(
149
+ vqgan_model_path=condition_model_path,
150
+ vqgan_config_path=condition_config_path,
151
+ )
152
+ else:
153
+ condition_vae = None
154
+
155
+ self.celle = CELLE(
156
+ dim=dim,
157
+ vae=vae, # automatically infer (1) image sequence length and (2) number of image tokens
158
+ condition_vae=condition_vae,
159
+ num_images=num_images,
160
+ num_text_tokens=num_text_tokens, # vocab size for text
161
+ text_seq_len=text_seq_len, # text sequence length
162
+ depth=depth, # should aim to be 64
163
+ heads=heads, # attention heads
164
+ dim_head=dim_head, # attention head dimension
165
+ attn_dropout=attn_dropout, # attention dropout
166
+ ff_dropout=ff_dropout, # feedforward dropout
167
+ loss_img_weight=loss_img_weight,
168
+ stable=stable,
169
+ rotary_emb=rotary_emb,
170
+ text_embedding=text_embedding,
171
+ fixed_embedding=fixed_embedding,
172
+ loss_cond_weight=loss_cond_weight,
173
+ )
174
+
175
+ self.learning_rate = learning_rate
176
+ self.num_text_tokens = num_text_tokens
177
+ self.num_images = num_images
178
+
179
+ if monitor is not None:
180
+ self.monitor = monitor
181
+
182
+ ignore_keys = []
183
+
184
+ if condition_model_path:
185
+ ignore_keys.append("celle.condition_vae")
186
+
187
+ if vqgan_model_path:
188
+ ignore_keys.append("celle.vae")
189
+
190
+ if ckpt_path is not None:
191
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
192
+
193
+ def init_from_ckpt(self, path, ignore_keys=list()):
194
+ sd = torch.load(path, map_location="cpu")["state_dict"]
195
+ ckpt = sd.copy()
196
+ for k in sd.keys():
197
+ for ik in ignore_keys:
198
+ if k.startswith(ik):
199
+ # print("Deleting key {} from state_dict.".format(k))
200
+ del ckpt[k]
201
+ self.load_state_dict(ckpt, strict=True)
202
+ print(f"Restored from {path}")
203
+
204
+ def forward(self, text, condition, target, return_loss=True):
205
+
206
+ return self.celle(
207
+ text=text, condition=condition, image=target, return_loss=return_loss
208
+ )
209
+
210
+ def get_input(self, batch):
211
+ text = batch["sequence"].squeeze(1)
212
+ condition = batch["nucleus"]
213
+ target = batch[self.image_key]
214
+
215
+ return text, condition, target
216
+
217
+ def get_image_from_logits(self, logits, temperature=0.9):
218
+
219
+ filtered_logits = top_k(logits, thres=0.5)
220
+ sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)
221
+
222
+ self.celle.vae.eval()
223
+ out = self.celle.vae.decode(
224
+ sample[:, self.celle.text_seq_len + self.celle.condition_seq_len :]
225
+ - (self.celle.num_text_tokens + self.celle.num_condition_tokens)
226
+ )
227
+
228
+ return out
229
+
230
+ def get_loss(self, text, condition, target):
231
+
232
+ loss_dict = {}
233
+
234
+ loss, loss_dict, logits = self(text, condition, target, return_loss=True)
235
+
236
+ return loss, loss_dict
237
+
238
+ def total_loss(
239
+ self,
240
+ loss,
241
+ loss_dict,
242
+ mode="train",
243
+ ):
244
+
245
+ loss_dict = {f"{mode}/{key}": value for key, value in loss_dict.items()}
246
+
247
+ for key, value in loss_dict.items():
248
+ self.log(
249
+ key,
250
+ value,
251
+ prog_bar=True,
252
+ logger=True,
253
+ on_step=True,
254
+ on_epoch=True,
255
+ sync_dist=True,
256
+ )
257
+
258
+ return loss
259
+
260
+ def training_step(self, batch, batch_idx):
261
+
262
+ text, condition, target = self.get_input(batch)
263
+ loss, log_dict = self.get_loss(text, condition, target)
264
+
265
+ loss = self.total_loss(loss, log_dict, mode="train")
266
+
267
+ return loss
268
+
269
+ def validation_step(self, batch, batch_idx):
270
+
271
+ with torch.no_grad():
272
+
273
+ text, condition, target = self.get_input(batch)
274
+ loss, log_dict = self.get_loss(text, condition, target)
275
+
276
+ loss = self.total_loss(loss, log_dict, mode="val")
277
+
278
+ return loss
279
+
280
+ def configure_optimizers(self):
281
+
282
+ optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
283
+
284
+ return optimizer
285
+
286
+ def scale_image(self, image):
287
+
288
+ for tensor in image:
289
+ if torch.min(tensor) < 0:
290
+ tensor += -torch.min(tensor)
291
+ else:
292
+ tensor -= torch.min(tensor)
293
+
294
+ tensor /= torch.max(tensor)
295
+
296
+ return image
297
+
298
+ @torch.no_grad()
299
+ def log_images(self, batch, **kwargs):
300
+
301
+ log = []
302
+
303
+ text, condition, target = self.get_input(batch)
304
+ text = text.squeeze(1).to(self.device)
305
+ condition = condition.to(self.device)
306
+
307
+ out = self.celle.generate_images(text=text, condition=condition)
308
+
309
+ log["condition"] = self.scale_image(condition)
310
+ log["output"] = self.scale_image(out)
311
+ if self.image_key == "threshold":
312
+ log["threshold"] = self.scale_image(target)
313
+ log["target"] = self.scale_image(batch["target"])
314
+ else:
315
+ log["target"] = self.scale_image(target)
316
+
317
+ return log
318
+
319
+
320
+ # from https://github.com/CompVis/taming-transformers/blob/master/celle_main.py
321
+
322
+ if __name__ == "__main__":
323
+ # custom parser to specify config files, train, test and debug mode,
324
+ # postfix, resume.
325
+ # `--key value` arguments are interpreted as arguments to the trainer.
326
+ # `nested.key=value` arguments are interpreted as config parameters.
327
+ # configs are merged from left-to-right followed by command line parameters.
328
+
329
+ # model:
330
+ # learning_rate: float
331
+ # target: path to lightning module
332
+ # params:
333
+ # key: value
334
+ # data:
335
+ # target: celle_main.DataModuleFromConfig
336
+ # params:
337
+ # batch_size: int
338
+ # wrap: bool
339
+ # train:
340
+ # target: path to train dataset
341
+ # params:
342
+ # key: value
343
+ # validation:
344
+ # target: path to validation dataset
345
+ # params:
346
+ # key: value
347
+ # test:
348
+ # target: path to test dataset
349
+ # params:
350
+ # key: value
351
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
352
+ # trainer:
353
+ # additional arguments to trainer
354
+ # logger:
355
+ # logger to instantiate
356
+ # modelcheckpoint:
357
+ # modelcheckpoint to instantiate
358
+ # callbacks:
359
+ # callback1:
360
+ # target: importpath
361
+ # params:
362
+ # key: value
363
+
364
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
365
+
366
+ # add cwd for convenience and to make classes in this file available when
367
+ # running as `python celle_main.py`
368
+ # (in particular `celle_main.DataModuleFromConfig`)
369
+ sys.path.append(os.getcwd())
370
+
371
+ parser = get_parser()
372
+ parser = Trainer.add_argparse_args(parser)
373
+
374
+ opt, unknown = parser.parse_known_args()
375
+ if opt.name and opt.resume:
376
+ raise ValueError(
377
+ "-n/--name and -r/--resume cannot be specified both."
378
+ "If you want to resume training in a new log folder, "
379
+ "use -n/--name in combination with --resume_from_checkpoint"
380
+ )
381
+ if opt.resume:
382
+ if not os.path.exists(opt.resume):
383
+ raise ValueError("Cannot find {}".format(opt.resume))
384
+ if os.path.isfile(opt.resume):
385
+ paths = opt.resume.split("/")
386
+ idx = len(paths) - paths[::-1].index("logs") + 1
387
+ logdir = "/".join(paths[:idx])
388
+ ckpt = opt.resume
389
+ else:
390
+ assert os.path.isdir(opt.resume), opt.resume
391
+ logdir = opt.resume.rstrip("/")
392
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
393
+
394
+ opt.resume_from_checkpoint = ckpt
395
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
396
+ opt.base = base_configs + opt.base
397
+ _tmp = logdir.split("/")
398
+ nowname = _tmp[_tmp.index("logs") + 1]
399
+ else:
400
+ if opt.name:
401
+ name = "_" + opt.name
402
+ elif opt.base:
403
+ cfg_fname = os.path.split(opt.base[0])[-1]
404
+ cfg_name = os.path.splitext(cfg_fname)[0]
405
+ name = "_" + cfg_name
406
+ else:
407
+ name = ""
408
+ nowname = now + name + opt.postfix
409
+ logdir = os.path.join("logs", nowname)
410
+
411
+ ckptdir = os.path.join(logdir, "checkpoints")
412
+ cfgdir = os.path.join(logdir, "configs")
413
+ seed_everything(opt.seed)
414
+
415
+ try:
416
+ # init and save configs
417
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
418
+ cli = OmegaConf.from_dotlist(unknown)
419
+ config = OmegaConf.merge(*configs, cli)
420
+ lightning_config = config.pop("lightning", OmegaConf.create())
421
+ # merge trainer cli with config
422
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
423
+ # default to ddp
424
+ # trainer_config["distributed_backend"] = "ddp"
425
+ for k in nondefault_trainer_args(opt):
426
+ trainer_config[k] = getattr(opt, k)
427
+ if not "gpus" in trainer_config:
428
+ trainer_config.pop("distributed_backend", None)
429
+ cpu = True
430
+ else:
431
+ gpuinfo = trainer_config["gpus"]
432
+ print(f"Running on GPUs {gpuinfo}")
433
+ cpu = False
434
+ trainer_opt = argparse.Namespace(**trainer_config)
435
+ lightning_config.trainer = trainer_config
436
+
437
+ # model
438
+ # model = instantiate_from_config(config.model)
439
+ model = instantiate_from_config(config.model)
440
+ # trainer and callbacks
441
+ trainer_kwargs = dict()
442
+
443
+ # default logger configs
444
+ # NOTE wandb < 0.10.0 interferes with shutdown
445
+ # wandb >= 0.10.0 seems to fix it but still interferes with pudb
446
+ # debugging (wrongly sized pudb ui)
447
+ # thus prefer testtube for now
448
+ default_logger_cfgs = {
449
+ "wandb": {
450
+ "target": "pytorch_lightning.loggers.WandbLogger",
451
+ "params": {
452
+ "name": nowname,
453
+ "save_dir": logdir,
454
+ "offline": opt.debug,
455
+ "id": nowname,
456
+ },
457
+ },
458
+ "testtube": {
459
+ # "target": "pytorch_lightning.loggers.TestTubeLogger",
460
+ "target": "pytorch_lightning.loggers.TensorBoardLogger",
461
+ "params": {
462
+ "name": "testtube",
463
+ "save_dir": logdir,
464
+ },
465
+ },
466
+ }
467
+ default_logger_cfg = default_logger_cfgs["testtube"]
468
+ # logger_cfg = lightning_config.logger or OmegaConf.create()
469
+ try:
470
+ logger_cfg = lightning_config.logger
471
+ except Exception:
472
+ logger_cfg = OmegaConf.create()
473
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
474
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
475
+
476
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
477
+ # specify which metric is used to determine best models
478
+ default_modelckpt_cfg = {
479
+ "checkpoint_callback": {
480
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
481
+ "params": {
482
+ "dirpath": ckptdir,
483
+ "filename": "{epoch:06}",
484
+ "verbose": True,
485
+ "save_last": True,
486
+ },
487
+ }
488
+ }
489
+ if hasattr(model, "monitor"):
490
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
491
+ default_modelckpt_cfg["checkpoint_callback"]["params"][
492
+ "monitor"
493
+ ] = model.monitor
494
+ default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3
495
+ try:
496
+ modelckpt_cfg = lightning_config.modelcheckpoint
497
+ except Exception:
498
+ modelckpt_cfg = OmegaConf.create()
499
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
500
+ # trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
501
+
502
+ # add callback which sets up log directory
503
+ default_callbacks_cfg = {
504
+ "setup_callback": {
505
+ "target": "celle_taming_main.SetupCallback",
506
+ "params": {
507
+ "resume": opt.resume,
508
+ "now": now,
509
+ "logdir": logdir,
510
+ "ckptdir": ckptdir,
511
+ "cfgdir": cfgdir,
512
+ "config": config,
513
+ "lightning_config": lightning_config,
514
+ },
515
+ },
516
+ # "image_logger": {
517
+ # "target": "celle_taming_main.ImageLogger",
518
+ # "params": {
519
+ # "batch_frequency": 0,
520
+ # "max_images": 0,
521
+ # "clamp": False,
522
+ # "increase_log_steps": False,
523
+ # },
524
+ # },
525
+ # "learning_rate_logger": {
526
+ # "target": "celle_taming_main.LearningRateMonitor",
527
+ # "params": {
528
+ # "logging_interval": "step",
529
+ # # "log_momentum": True
530
+ # },
531
+ # },
532
+ }
533
+ try:
534
+ callbacks_cfg = lightning_config.callbacks
535
+ except Exception:
536
+ callbacks_cfg = OmegaConf.create()
537
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
538
+ callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
539
+ trainer_kwargs["callbacks"] = [
540
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
541
+ ]
542
+
543
+ trainer = Trainer.from_argparse_args(
544
+ trainer_opt, **trainer_kwargs, profiler="simple"
545
+ )
546
+
547
+ # data
548
+ data = instantiate_from_config(config.data)
549
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
550
+ # calling these ourselves should not be necessary but it is.
551
+ # lightning still takes care of proper multiprocessing though
552
+ data.setup()
553
+ data.prepare_data()
554
+
555
+ # configure learning rate
556
+ bs, lr = config.data.params.batch_size, config.model.learning_rate
557
+
558
+ if not cpu:
559
+ ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
560
+ else:
561
+ ngpu = 1
562
+ try:
563
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
564
+ except Exception:
565
+ accumulate_grad_batches = 1
566
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
567
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
568
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * lr
569
+
570
+ print(
571
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (lr)".format(
572
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, lr
573
+ )
574
+ )
575
+
576
+ # allow checkpointing via USR1
577
+ def melk(*args, **kwargs):
578
+ # run all checkpoint hooks
579
+ if trainer.global_rank == 0:
580
+ print("Summoning checkpoint.")
581
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
582
+ trainer.save_checkpoint(ckpt_path)
583
+
584
+ def divein(*args, **kwargs):
585
+ if trainer.global_rank == 0:
586
+ import pudb
587
+
588
+ pudb.set_trace()
589
+
590
+ import signal
591
+
592
+ signal.signal(signal.SIGUSR1, melk)
593
+ signal.signal(signal.SIGUSR2, divein)
594
+
595
+ # run
596
+ if opt.train:
597
+ try:
598
+ # model = torch.compile(model, mode="reduce_overhead")
599
+ # torch.compile applied to trainer.fit's return value is a no-op; just run the fit
+ trainer.fit(model, data)
600
+ except Exception:
601
+ melk()
602
+ raise
603
+ if not opt.no_test and not trainer.interrupted:
604
+ trainer.test(model, data)
605
+ except Exception:
606
+ if opt.debug and trainer.global_rank == 0:
607
+ try:
608
+ import pudb as debugger
609
+ except ImportError:
610
+ import pdb as debugger
611
+ debugger.post_mortem()
612
+ raise
613
+ finally:
614
+ # move newly created debug project to debug_runs
615
+ if opt.debug and not opt.resume and trainer.global_rank == 0:
616
+ dst, name = os.path.split(logdir)
617
+ dst = os.path.join(dst, "debug_runs", name)
618
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
619
+ os.rename(logdir, dst)
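For reference, a sketch of the config layout described in the commented schema of celle_main.py's __main__ block, built here with OmegaConf in Python; every path and value is a placeholder, and real runs pass equivalent YAML files via -b/--base.

    from omegaconf import OmegaConf

    config = OmegaConf.create({
        "model": {
            "learning_rate": 3e-4,                              # placeholder value
            "target": "celle_main.CELLE_trainer",
            "params": {
                "vqgan_config_path": "configs/vqgan.yaml",      # placeholder path
                "vqgan_model_path": "checkpoints/vqgan.ckpt",   # placeholder path
            },
        },
        "data": {
            "target": "celle_main.CellDataModule",
            "params": {"data_csv": "data/data.csv", "dataset": "OpenCell", "batch_size": 1},
        },
    })
    print(OmegaConf.to_yaml(config))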
celle_taming_main.py ADDED
@@ -0,0 +1,695 @@
1
+ import argparse, os, sys, datetime, glob, importlib
2
+ from omegaconf import OmegaConf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ import torchvision
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from dataloader import CellLoader
9
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning import seed_everything
12
+ from pytorch_lightning.trainer import Trainer
13
+ from pytorch_lightning.callbacks import Callback
14
+ from pytorch_lightning.utilities import rank_zero_only
15
+
16
+
17
+ def get_obj_from_str(string, reload=False):
18
+ module, cls = string.rsplit(".", 1)
19
+ if reload:
20
+ module_imp = importlib.import_module(module)
21
+ importlib.reload(module_imp)
22
+ return getattr(importlib.import_module(module, package=None), cls)
23
+
24
+
25
+ def get_parser(**parser_kwargs):
26
+ def str2bool(v):
27
+ if isinstance(v, bool):
28
+ return v
29
+ if v.lower() in ("yes", "true", "t", "y", "1"):
30
+ return True
31
+ elif v.lower() in ("no", "false", "f", "n", "0"):
32
+ return False
33
+ else:
34
+ raise argparse.ArgumentTypeError("Boolean value expected.")
35
+
36
+ parser = argparse.ArgumentParser(**parser_kwargs)
37
+ parser.add_argument(
38
+ "-n",
39
+ "--name",
40
+ type=str,
41
+ const=True,
42
+ default="",
43
+ nargs="?",
44
+ help="postfix for logdir",
45
+ )
46
+ parser.add_argument(
47
+ "-r",
48
+ "--resume",
49
+ type=str,
50
+ const=True,
51
+ default="",
52
+ nargs="?",
53
+ help="resume from logdir or checkpoint in logdir",
54
+ )
55
+ parser.add_argument(
56
+ "-b",
57
+ "--base",
58
+ nargs="*",
59
+ metavar="base_config.yaml",
60
+ help="paths to base configs. Loaded from left-to-right. "
61
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
62
+ default=list(),
63
+ )
64
+ parser.add_argument(
65
+ "-t",
66
+ "--train",
67
+ type=str2bool,
68
+ const=True,
69
+ default=False,
70
+ nargs="?",
71
+ help="train",
72
+ )
73
+ parser.add_argument(
74
+ "--no-test",
75
+ type=str2bool,
76
+ const=True,
77
+ default=False,
78
+ nargs="?",
79
+ help="disable test",
80
+ )
81
+ parser.add_argument(
82
+ "-p", "--project", help="name of new or path to existing project"
83
+ )
84
+ parser.add_argument(
85
+ "-d",
86
+ "--debug",
87
+ type=str2bool,
88
+ nargs="?",
89
+ const=True,
90
+ default=False,
91
+ help="enable post-mortem debugging",
92
+ )
93
+ parser.add_argument(
94
+ "-s",
95
+ "--seed",
96
+ type=int,
97
+ default=42,
98
+ help="seed for seed_everything",
99
+ )
100
+ parser.add_argument(
101
+ "-f",
102
+ "--postfix",
103
+ type=str,
104
+ default="",
105
+ help="post-postfix for default name",
106
+ )
107
+
108
+ return parser
109
+
110
+
111
+ def nondefault_trainer_args(opt):
112
+ parser = argparse.ArgumentParser()
113
+ parser = Trainer.add_argparse_args(parser)
114
+ args = parser.parse_args([])
115
+ return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
116
+
117
+
118
+ def instantiate_from_config(config):
119
+ if not "target" in config:
120
+ raise KeyError("Expected key `target` to instantiate.")
121
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
122
+
123
+
124
+ class WrappedDataset(Dataset):
125
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
126
+
127
+ def __init__(self, dataset):
128
+ self.data = dataset
129
+
130
+ def __len__(self):
131
+ return len(self.data)
132
+
133
+ def __getitem__(self, idx):
134
+ return self.data[idx]
135
+
136
+
137
+ class DataModuleFromConfig(pl.LightningDataModule):
138
+ def __init__(
139
+ self,
140
+ data_csv,
141
+ dataset,
142
+ crop_size=256,
143
+ resize=600,
144
+ batch_size=1,
145
+ sequence_mode="latent",
146
+ vocab="bert",
147
+ text_seq_len=0,
148
+ num_workers=1,
149
+ threshold=False,
150
+ train=True,
151
+ validation=True,
152
+ test=None,
153
+ wrap=False,
154
+ **kwargs,
155
+ ):
156
+ super().__init__()
157
+ self.data_csv = data_csv
158
+ self.dataset = dataset
159
+ self.image_folders = []
160
+ self.crop_size = crop_size
161
+ self.resize = resize
162
+ self.batch_size = batch_size
163
+ self.sequence_mode = sequence_mode
164
+ self.threshold = threshold
165
+ self.text_seq_len = int(text_seq_len)
166
+ self.vocab = vocab
167
+ self.dataset_configs = dict()
168
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
169
+ if train is not None:
170
+ self.dataset_configs["train"] = train
171
+ self.train_dataloader = self._train_dataloader
172
+ if validation is not None:
173
+ self.dataset_configs["validation"] = validation
174
+ self.val_dataloader = self._val_dataloader
175
+ if test is not None:
176
+ self.dataset_configs["test"] = test
177
+ self.test_dataloader = self._test_dataloader
178
+ self.wrap = wrap
179
+
180
+ def prepare_data(self):
181
+ pass
182
+
183
+ def setup(self, stage=None):
184
+ # called on every GPU
185
+ self.cell_dataset_train = CellLoader(
186
+ data_csv=self.data_csv,
187
+ dataset=self.dataset,
188
+ crop_size=self.crop_size,
189
+ split_key="train",
190
+ crop_method="random",
191
+ sequence_mode=None,
192
+ vocab=self.vocab,
193
+ text_seq_len=self.text_seq_len,
194
+ threshold=self.threshold,
195
+ )
196
+
197
+ self.cell_dataset_val = CellLoader(
198
+ data_csv=self.data_csv,
199
+ dataset=self.dataset,
200
+ crop_size=self.crop_size,
201
+ split_key="val",
202
+ crop_method="center",
203
+ sequence_mode=None,
204
+ vocab=self.vocab,
205
+ text_seq_len=self.text_seq_len,
206
+ threshold=self.threshold,
207
+ )
208
+
209
+ def _train_dataloader(self):
210
+ return DataLoader(
211
+ self.cell_dataset_train,
212
+ num_workers=self.num_workers,
213
+ pin_memory=True,
214
+ shuffle=True,
215
+ batch_size=self.batch_size,
216
+ )
217
+
218
+ def _val_dataloader(self):
219
+ return DataLoader(
220
+ self.cell_dataset_val,
221
+ num_workers=self.num_workers,
222
+ pin_memory=True,
223
+ batch_size=self.batch_size,
224
+ )
225
+
226
+ # def _test_dataloader(self):
227
+ # return DataLoader(self.datasets["test"], batch_size=self.batch_size,
228
+ # num_workers=self.num_workers)
229
+
230
+
231
+ class SetupCallback(Callback):
232
+ def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
233
+ super().__init__()
234
+ self.resume = resume
235
+ self.now = now
236
+ self.logdir = logdir
237
+ self.ckptdir = ckptdir
238
+ self.cfgdir = cfgdir
239
+ self.config = config
240
+ self.lightning_config = lightning_config
241
+
242
+ def on_fit_start(self, trainer, pl_module):
243
+ if trainer.global_rank == 0:
244
+ # Create logdirs and save configs
245
+ os.makedirs(self.logdir, exist_ok=True)
246
+ os.makedirs(self.ckptdir, exist_ok=True)
247
+ os.makedirs(self.cfgdir, exist_ok=True)
248
+
249
+ print("Project config")
250
+ print(OmegaConf.to_yaml(self.config))
251
+ OmegaConf.save(
252
+ self.config,
253
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
254
+ )
255
+
256
+ print("Lightning config")
257
+ print(OmegaConf.to_yaml(self.lightning_config))
258
+ OmegaConf.save(
259
+ OmegaConf.create({"lightning": self.lightning_config}),
260
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
261
+ )
262
+
263
+ else:
264
+ # ModelCheckpoint callback created log directory --- remove it
265
+ if not self.resume and os.path.exists(self.logdir):
266
+ dst, name = os.path.split(self.logdir)
267
+ dst = os.path.join(dst, "child_runs", name)
268
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
269
+ try:
270
+ os.rename(self.logdir, dst)
271
+ except FileNotFoundError:
272
+ pass
273
+
274
+
275
+ class ImageLogger(Callback):
276
+ def __init__(
277
+ self, batch_frequency, max_images, clamp=True, increase_log_steps=True
278
+ ):
279
+ super().__init__()
280
+ self.batch_freq = batch_frequency
281
+ self.max_images = max_images
282
+ self.logger_log_images = {
283
+ pl.loggers.WandbLogger: self._wandb,
284
+ # pl.loggers.TestTubeLogger: self._testtube,
285
+ pl.loggers.TensorBoardLogger: self._testtube,
286
+ }
287
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
288
+ if not increase_log_steps:
289
+ self.log_steps = [self.batch_freq]
290
+ self.clamp = clamp
291
+
292
+ @rank_zero_only
293
+ def _wandb(self, pl_module, images, batch_idx, split):
294
+ raise ValueError("No way wandb")
295
+ grids = dict()
296
+ for k in images:
297
+ grid = torchvision.utils.make_grid(images[k])
298
+ grids[f"{split}/{k}"] = wandb.Image(grid)
299
+ pl_module.logger.experiment.log(grids)
300
+
301
+ @rank_zero_only
302
+ def _testtube(self, pl_module, images, batch_idx, split):
303
+ for k in images:
304
+ images[k] -= torch.min(images[k])
305
+ images[k] /= torch.max(images[k])
306
+ grid = torchvision.utils.make_grid(images[k])
307
+ # grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
308
+
309
+ tag = f"{split}/{k}"
310
+ pl_module.logger.experiment.add_image(
311
+ tag, grid, global_step=pl_module.global_step
312
+ )
313
+
314
+ @rank_zero_only
315
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
316
+ root = os.path.join(save_dir, "images", split)
317
+ for k in images:
318
+ images[k] -= torch.min(images[k])
319
+ images[k] /= torch.max(images[k])
320
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
321
+
322
+ # grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
323
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
324
+ grid = grid.numpy()
325
+ grid = (grid * 255).astype(np.uint8)
326
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
327
+ k, global_step, current_epoch, batch_idx
328
+ )
329
+ path = os.path.join(root, filename)
330
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
331
+ Image.fromarray(grid).save(path)
332
+
333
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
334
+ if (
335
+ self.check_frequency(batch_idx)
336
+ and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
337
+ and callable(pl_module.log_images)
338
+ and self.max_images > 0
339
+ ):
340
+ logger = type(pl_module.logger)
341
+
342
+ is_train = pl_module.training
343
+ if is_train:
344
+ pl_module.eval()
345
+
346
+ with torch.no_grad():
347
+ images = pl_module.log_images(batch, split=split)
348
+
349
+ for k in images:
350
+ N = min(images[k].shape[0], self.max_images)
351
+ images[k] = images[k][:N]
352
+ if isinstance(images[k], torch.Tensor):
353
+ images[k] = images[k].detach().cpu()
354
+ if self.clamp:
355
+ images[k] = torch.clamp(images[k], -1.0, 1.0)
356
+
357
+ self.log_local(
358
+ pl_module.logger.save_dir,
359
+ split,
360
+ images,
361
+ pl_module.global_step,
362
+ pl_module.current_epoch,
363
+ batch_idx,
364
+ )
365
+
366
+ logger_log_images = self.logger_log_images.get(
367
+ logger, lambda *args, **kwargs: None
368
+ )
369
+ logger_log_images(pl_module, images, pl_module.global_step, split)
370
+
371
+ if is_train:
372
+ pl_module.train()
373
+
374
+ def check_frequency(self, batch_idx):
375
+ if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
376
+ try:
377
+ self.log_steps.pop(0)
378
+ except IndexError:
379
+ pass
380
+ return True
381
+ return False
382
+
383
+ # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
384
+ # def on_train_batch_end(self, *args, **kwargs):
385
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
386
+ self.log_img(pl_module, batch, batch_idx, split="train")
387
+
388
+ def on_validation_batch_end(
389
+ self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
390
+ ):
391
+ self.log_img(pl_module, batch, batch_idx, split="val")
392
+
393
+
394
+ if __name__ == "__main__":
395
+ # custom parser to specify config files, train, test and debug mode,
396
+ # postfix, resume.
397
+ # `--key value` arguments are interpreted as arguments to the trainer.
398
+ # `nested.key=value` arguments are interpreted as config parameters.
399
+ # configs are merged from left-to-right followed by command line parameters.
400
+
401
+ # model:
402
+ # base_learning_rate: float
403
+ # target: path to lightning module
404
+ # params:
405
+ # key: value
406
+ # data:
407
+ # target: main.DataModuleFromConfig
408
+ # params:
409
+ # batch_size: int
410
+ # wrap: bool
411
+ # train:
412
+ # target: path to train dataset
413
+ # params:
414
+ # key: value
415
+ # validation:
416
+ # target: path to validation dataset
417
+ # params:
418
+ # key: value
419
+ # test:
420
+ # target: path to test dataset
421
+ # params:
422
+ # key: value
423
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
424
+ # trainer:
425
+ # additional arguments to trainer
426
+ # logger:
427
+ # logger to instantiate
428
+ # modelcheckpoint:
429
+ # modelcheckpoint to instantiate
430
+ # callbacks:
431
+ # callback1:
432
+ # target: importpath
433
+ # params:
434
+ # key: value
435
+
436
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
437
+
438
+ # add cwd for convenience and to make classes in this file available when
439
+ # running as `python main.py`
440
+ # (in particular `main.DataModuleFromConfig`)
441
+ sys.path.append(os.getcwd())
442
+
443
+ parser = get_parser()
444
+ parser = Trainer.add_argparse_args(parser)
445
+
446
+ opt, unknown = parser.parse_known_args()
447
+ if opt.name and opt.resume:
448
+ raise ValueError(
449
+ "-n/--name and -r/--resume cannot be specified both."
450
+ "If you want to resume training in a new log folder, "
451
+ "use -n/--name in combination with --resume_from_checkpoint"
452
+ )
453
+ if opt.resume:
454
+ if not os.path.exists(opt.resume):
455
+ raise ValueError("Cannot find {}".format(opt.resume))
456
+ if os.path.isfile(opt.resume):
457
+ paths = opt.resume.split("/")
458
+ idx = len(paths) - paths[::-1].index("logs") + 1
459
+ logdir = "/".join(paths[:idx])
460
+ ckpt = opt.resume
461
+ else:
462
+ assert os.path.isdir(opt.resume), opt.resume
463
+ logdir = opt.resume.rstrip("/")
464
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
465
+
466
+ opt.resume_from_checkpoint = ckpt
467
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
468
+ opt.base = base_configs + opt.base
469
+ _tmp = logdir.split("/")
470
+ nowname = _tmp[_tmp.index("logs") + 1]
471
+ else:
472
+ if opt.name:
473
+ name = "_" + opt.name
474
+ elif opt.base:
475
+ cfg_fname = os.path.split(opt.base[0])[-1]
476
+ cfg_name = os.path.splitext(cfg_fname)[0]
477
+ name = "_" + cfg_name
478
+ else:
479
+ name = ""
480
+ nowname = now + name + opt.postfix
481
+ logdir = os.path.join("logs", nowname)
482
+
483
+ ckptdir = os.path.join(logdir, "checkpoints")
484
+ cfgdir = os.path.join(logdir, "configs")
485
+ seed_everything(opt.seed)
486
+
487
+ try:
488
+ # init and save configs
489
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
490
+ cli = OmegaConf.from_dotlist(unknown)
491
+ config = OmegaConf.merge(*configs, cli)
492
+ lightning_config = config.pop("lightning", OmegaConf.create())
493
+ # merge trainer cli with config
494
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
495
+ # default to ddp
496
+ trainer_config["distributed_backend"] = "ddp"
497
+ trainer_config["replace_sampler_ddp"] = False
498
+ trainer_config["strategy"] = "ddp"
499
+ trainer_config["persistent_workers"] = True
500
+ for k in nondefault_trainer_args(opt):
501
+ trainer_config[k] = getattr(opt, k)
502
+ if not "gpus" in trainer_config:
503
+ del trainer_config["distributed_backend"]
504
+ cpu = True
505
+ else:
506
+ gpuinfo = trainer_config["gpus"]
507
+ print(f"Running on GPUs {gpuinfo}")
508
+ cpu = False
509
+ trainer_opt = argparse.Namespace(**trainer_config)
510
+ lightning_config.trainer = trainer_config
511
+
512
+ # model
513
+ model = instantiate_from_config(config.model)
514
+ # trainer and callbacks
515
+ trainer_kwargs = dict()
516
+
517
+ # default logger configs
518
+ # NOTE wandb < 0.10.0 interferes with shutdown
519
+ # wandb >= 0.10.0 seems to fix it but still interferes with pudb
520
+ # debugging (wrongly sized pudb ui)
521
+ # thus prefer testtube for now
522
+ default_logger_cfgs = {
523
+ "wandb": {
524
+ "target": "pytorch_lightning.loggers.WandbLogger",
525
+ "params": {
526
+ "name": nowname,
527
+ "save_dir": logdir,
528
+ "offline": opt.debug,
529
+ "id": nowname,
530
+ },
531
+ },
532
+ "testtube": {
533
+ # "target": "pytorch_lightning.loggers.TestTubeLogger",
534
+ "target": "pytorch_lightning.loggers.TensorBoardLogger",
535
+ "params": {
536
+ "name": "testtube",
537
+ "save_dir": logdir,
538
+ },
539
+ },
540
+ }
541
+ default_logger_cfg = default_logger_cfgs["testtube"]
542
+ try:
543
+ logger_cfg = lightning_config.logger
544
+ except Exception:
545
+ logger_cfg = OmegaConf.create()
546
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
547
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
548
+
549
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
550
+ # specify which metric is used to determine best models
551
+ default_modelckpt_cfg = {
552
+ "checkpoint_callback": {
553
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
554
+ "params": {
555
+ "dirpath": ckptdir,
556
+ "filename": "{epoch:06}",
557
+ "verbose": True,
558
+ "save_last": True,
559
+ },
560
+ }
561
+ }
562
+ if hasattr(model, "monitor"):
563
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
564
+ default_modelckpt_cfg["checkpoint_callback"]["params"][
565
+ "monitor"
566
+ ] = model.monitor
567
+ default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3
568
+ try:
569
+ modelckpt_cfg = lightning_config.modelcheckpoint
570
+ except Exception:
571
+ modelckpt_cfg = OmegaConf.create()
572
+
573
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
574
+ # trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
575
+
576
+ # loaded_model_callbacks = instantiate_from_config(modelckpt_cfg)
577
+
578
+ # add callback which sets up log directory
579
+ default_callbacks_cfg = {
580
+ "setup_callback": {
581
+ "target": "celle_taming_main.SetupCallback",
582
+ "params": {
583
+ "resume": opt.resume,
584
+ "now": now,
585
+ "logdir": logdir,
586
+ "ckptdir": ckptdir,
587
+ "cfgdir": cfgdir,
588
+ "config": config,
589
+ "lightning_config": lightning_config,
590
+ },
591
+ },
592
+ "image_logger": {
593
+ "target": "celle_taming_main.ImageLogger",
594
+ "params": {
595
+ "batch_frequency": 2000,
596
+ "max_images": 10,
597
+ "clamp": True,
598
+ "increase_log_steps": False,
599
+ },
600
+ },
601
+ "learning_rate_logger": {
602
+ "target": "celle_taming_main.LearningRateMonitor",
603
+ "params": {
604
+ "logging_interval": "step",
605
+ # "log_momentum": True
606
+ },
607
+ },
608
+ }
609
+ try:
610
+ callbacks_cfg = lightning_config.callbacks
611
+ except Exception:
612
+ callbacks_cfg = OmegaConf.create()
613
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
614
+ callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
615
+ trainer_kwargs["callbacks"] = [
616
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
617
+ ]
618
+ # loaded_callbacks = [
619
+ # instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
620
+ # ]
621
+
622
+ # trainer_kwargs["callbacks"] = loaded_callbacks.append(loaded_model_callbacks)
623
+
624
+ trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
625
+
626
+ # data
627
+ data = instantiate_from_config(config.data)
628
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
629
+ # calling these ourselves should not be necessary but it is.
630
+ # lightning still takes care of proper multiprocessing though
631
+ data.prepare_data()
632
+ data.setup()
633
+
634
+ # configure learning rate
635
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
636
+ if not cpu:
637
+ ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
638
+ else:
639
+ ngpu = 1
640
+ try:
641
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
642
+ except Exception:
643
+ accumulate_grad_batches = 1
644
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
645
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
646
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
647
+ print(
648
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
649
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
650
+ )
651
+ )
652
+
653
+ # allow checkpointing via USR1
654
+ def melk(*args, **kwargs):
655
+ # run all checkpoint hooks
656
+ if trainer.global_rank == 0:
657
+ print("Summoning checkpoint.")
658
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
659
+ trainer.save_checkpoint(ckpt_path)
660
+
661
+ def divein(*args, **kwargs):
662
+ if trainer.global_rank == 0:
663
+ import pudb
664
+
665
+ pudb.set_trace()
666
+
667
+ import signal
668
+
669
+ signal.signal(signal.SIGUSR1, melk)
670
+ signal.signal(signal.SIGUSR2, divein)
671
+ # model = torch.compile(model)
672
+ # run
673
+ if opt.train:
674
+ try:
675
+ # torch.compile applied to trainer.fit's return value is a no-op; just run the fit
+ trainer.fit(model, data)
676
+ except Exception:
677
+ melk()
678
+ raise
679
+ if not opt.no_test and not trainer.interrupted:
680
+ trainer.test(model, data)
681
+ except Exception:
682
+ if opt.debug and trainer.global_rank == 0:
683
+ try:
684
+ import pudb as debugger
685
+ except ImportError:
686
+ import pdb as debugger
687
+ debugger.post_mortem()
688
+ raise
689
+ finally:
690
+ # move newly created debug project to debug_runs
691
+ if opt.debug and not opt.resume and trainer.global_rank == 0:
692
+ dst, name = os.path.split(logdir)
693
+ dst = os.path.join(dst, "debug_runs", name)
694
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
695
+ os.rename(logdir, dst)
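Both training scripts build their loggers, callbacks, model, and data module through instantiate_from_config. A small sketch of that mechanism (the logger parameters below are assumptions):

    from celle_taming_main import instantiate_from_config

    logger_cfg = {
        "target": "pytorch_lightning.loggers.TensorBoardLogger",
        "params": {"name": "testtube", "save_dir": "logs/example"},   # placeholder save_dir
    }
    logger = instantiate_from_config(logger_cfg)
    # equivalent to: TensorBoardLogger(name="testtube", save_dir="logs/example")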
dataloader.py ADDED
@@ -0,0 +1,321 @@
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image, ImageSequence
4
+ import json
5
+ import pandas as pd
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from torchvision import transforms
10
+ import torchvision.transforms.functional as TF
11
+
12
+
13
+ def simple_conversion(seq):
14
+ """Create 26-dim embedding"""
15
+ chars = [
16
+ "-",
17
+ "M",
18
+ "R",
19
+ "H",
20
+ "K",
21
+ "D",
22
+ "E",
23
+ "S",
24
+ "T",
25
+ "N",
26
+ "Q",
27
+ "C",
28
+ "U",
29
+ "G",
30
+ "P",
31
+ "A",
32
+ "V",
33
+ "I",
34
+ "F",
35
+ "Y",
36
+ "W",
37
+ "L",
38
+ "O",
39
+ "X",
40
+ "Z",
41
+ "B",
42
+ "J",
43
+ ]
44
+
45
+ nums = range(len(chars))
46
+
47
+ seqs_x = np.zeros(len(seq))
48
+
49
+ for idx, char in enumerate(seq):
50
+
51
+ lui = chars.index(char)
52
+
53
+ seqs_x[idx] = nums[lui]
54
+
55
+ return torch.tensor([seqs_x]).long()
56
+
57
+
58
+ def replace_outliers(image, percentile=0.0001):
59
+
60
+ lower_bound, upper_bound = torch.quantile(image, percentile), torch.quantile(
61
+ image, 1 - percentile
62
+ )
63
+ mask = (image <= upper_bound) & (image >= lower_bound)
64
+
65
+ valid_pixels = image[mask]
66
+
67
+ image[~mask] = torch.clip(image[~mask], min(valid_pixels), max(valid_pixels))
68
+
69
+ return image
70
+
71
+
72
+ class CellLoader(Dataset):
73
+ """imports mined opencell images with protein sequence"""
74
+
75
+ def __init__(
76
+ self,
77
+ data_csv=None,
78
+ dataset=None,
79
+ split_key=None,
80
+ resize=600,
81
+ crop_size=600,
82
+ crop_method="random",
83
+ sequence_mode="simple",
84
+ vocab="bert",
85
+ threshold="median",
86
+ text_seq_len=0,
87
+ pad_mode="random",
88
+ ):
89
+ self.data_csv = data_csv
90
+ self.dataset = dataset
91
+ self.image_folders = []
92
+ self.crop_method = crop_method
93
+ self.resize = resize
94
+ self.crop_size = crop_size
95
+ self.sequence_mode = sequence_mode
96
+ self.threshold = threshold
97
+ self.text_seq_len = int(text_seq_len)
98
+ self.vocab = vocab
99
+ self.pad_mode = pad_mode
100
+
101
+ if self.sequence_mode == "embedding" or self.sequence_mode == "onehot":
102
+
103
+
104
+ if self.vocab == "esm1b" or self.vocab == "esm2":
105
+ from esm import Alphabet
106
+
107
+ self.tokenizer = Alphabet.from_architecture(
108
+ "ESM-1b"
109
+ ).get_batch_converter()
110
+ self.text_seq_len += 2
111
+
112
+ if data_csv:
113
+
114
+ data = pd.read_csv(data_csv)
115
+
116
+ self.parent_path = os.path.dirname(data_csv).split(data_csv)[0]
117
+
118
+ if split_key == "train":
119
+ self.data = data[data["split"] == "train"]
120
+ elif split_key == "val":
121
+ self.data = data[data["split"] == "val"]
122
+ else:
123
+ self.data = data
124
+
125
+ self.data = self.data.reset_index(drop=True)
126
+
127
+
128
+
129
+ def __len__(self):
130
+ return len(self.data)
131
+
132
+ def __getitem__(
133
+ self,
134
+ idx,
135
+ get_sequence=True,
136
+ get_images=True,
137
+ ):
138
+ if get_sequence and self.text_seq_len > 0:
139
+
140
+ protein_vector = self.get_protein_vector(idx)
141
+
142
+ else:
143
+ protein_vector = torch.zeros((1, 1))
144
+
145
+ if get_images:
146
+
147
+ nucleus, target, threshold = self.get_images(idx, self.dataset)
148
+ else:
149
+ nucleus, target, threshold = torch.zeros((3, 1))
150
+
151
+ data_dict = {
152
+ "nucleus": nucleus.float(),
153
+ "target": target.float(),
154
+ "threshold": threshold.float(),
155
+ "sequence": protein_vector.long(),
156
+ }
157
+
158
+ return data_dict
159
+
160
+ def get_protein_vector(self, idx):
161
+
162
+ if "protein_sequence" not in self.data.columns:
163
+
164
+ metadata = self.retrieve_metadata(idx)
165
+ protein_sequence = metadata["sequence"]
166
+ else:
167
+ protein_sequence = self.data.iloc[idx]["protein_sequence"]
168
+
169
+ protein_vector = self.tokenize_sequence(protein_sequence)
170
+
171
+ return protein_vector
172
+
173
+ def get_images(self, idx, dataset):
174
+
175
+ if dataset == "HPA":
176
+
177
+ nucleus = Image.open(
178
+ os.path.join(
179
+ self.parent_path, self.data.iloc[idx]["nucleus_image_path"]
180
+ )
181
+ )
182
+
183
+ target = Image.open(
184
+ os.path.join(self.parent_path, self.data.iloc[idx]["target_image_path"])
185
+ )
186
+
187
+ nucleus = TF.to_tensor(nucleus)[0]
188
+ target = TF.to_tensor(target)[0]
189
+
190
+ image = torch.stack([nucleus, target], axis=0)
191
+
192
+ normalize = (0.0655, 0.0650), (0.1732, 0.1208)
193
+
194
+ elif dataset == "OpenCell":
195
+ image = Image.open(
196
+ os.path.join(self.parent_path, self.data.iloc[idx]["image_path"])
197
+ )
198
+ nucleus, target = [page.copy() for page in ImageSequence.Iterator(image)]
199
+
200
+ nucleus = replace_outliers(torch.divide(TF.to_tensor(nucleus), 65536))[0]
201
+ target = replace_outliers(torch.divide(TF.to_tensor(target), 65536))[0]
202
+
203
+ image = torch.stack([nucleus, target], axis=0)
204
+
205
+ normalize = (
206
+ (0.0272, 0.0244),
207
+ (0.0486, 0.0671),
208
+ )
209
+
210
+ # # from https://discuss.pytorch.org/t/how-to-apply-same-transform-on-a-pair-of-picture/14914
211
+
212
+ t_forms = [transforms.Resize(self.resize, antialias=None)]
213
+
214
+ if self.crop_method == "random":
215
+
216
+ t_forms.append(transforms.RandomCrop(self.crop_size))
217
+ t_forms.append(transforms.RandomHorizontalFlip(p=0.5))
218
+ t_forms.append(transforms.RandomVerticalFlip(p=0.5))
219
+
220
+ elif self.crop_method == "center":
221
+
222
+ t_forms.append(transforms.CenterCrop(self.crop_size))
223
+
224
+ t_forms.append(transforms.Normalize(normalize[0], normalize[1]))
225
+
226
+ image = transforms.Compose(t_forms)(image)
227
+
228
+ nucleus, target = image
229
+
230
+ nucleus /= torch.abs(nucleus).max()
231
+ target -= target.min()
232
+ target /= target.max()
233
+
234
+ nucleus = nucleus.unsqueeze(0)
235
+ target = target.unsqueeze(0)
236
+
237
+ threshold = target
238
+
239
+ if self.threshold == "mean":
240
+
241
+ threshold = 1.0 * (threshold > (torch.mean(threshold)))
242
+
243
+ elif self.threshold == "median":
244
+
245
+ threshold = 1.0 * (threshold > (torch.median(threshold)))
246
+
247
+ elif self.threshold == "1090_IQR":
248
+
249
+ p10 = torch.quantile(threshold, 0.1, None)
250
+ p90 = torch.quantile(threshold, 0.9, None)
251
+ threshold = torch.clip(threshold, p10, p90)
252
+
253
+ nucleus = torch.nan_to_num(nucleus, 0.0, 1.0, 0.0)
254
+ target = torch.nan_to_num(target, 0.0, 1.0, 0.0)
255
+ threshold = torch.nan_to_num(threshold, 0.0, 1.0, 0.0)
256
+
257
+ return nucleus, target, threshold
258
+
259
+ def retrieve_metadata(self, idx):
260
+ with open(
261
+ os.path.join(self.parent_path, self.data.iloc[idx]["metadata_path"])
262
+ ) as f:
263
+ metadata = json.load(f)
264
+
265
+ return metadata
266
+
267
+ def tokenize_sequence(self, protein_sequence):
268
+
269
+ pad_token = 0
270
+
271
+ if self.sequence_mode == "simple":
272
+ protein_vector = simple_conversion(protein_sequence)
273
+
274
+ elif self.sequence_mode == "center":
275
+ protein_sequence = protein_sequence.center(self.text_seq_len, "-")
276
+ protein_vector = simple_conversion(protein_sequence)
277
+
278
+ elif self.sequence_mode == "alternating":
279
+ protein_sequence = protein_sequence.center(self.text_seq_len, "-")
280
+ protein_sequence = protein_sequence[::18]
281
+ protein_sequence = protein_sequence.center(
282
+ int(self.text_seq_len / 18) + 1, "-"
283
+ )
284
+ protein_vector = simple_conversion(protein_sequence)
285
+
286
+
287
+ elif self.sequence_mode == "embedding":
288
+
289
+ if self.vocab == "esm1b" or self.vocab == "esm2":
290
+ pad_token = 1
291
+ protein_vector = self.tokenizer([("", protein_sequence)])[-1]
292
+
293
+ if protein_vector.shape[-1] < self.text_seq_len:
294
+
295
+ diff = self.text_seq_len - protein_vector.shape[-1]
296
+
297
+ if self.pad_mode == "end":
298
+ protein_vector = torch.nn.functional.pad(
299
+ protein_vector, (0, diff), "constant", pad_token
300
+ )
301
+ elif self.pad_mode == "random":
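+ # pad_mode "random": place a random share of the padding before the sequence
+ # and the remainder after it, so token positions vary between samples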
302
+ split = diff - np.random.randint(0, diff + 1)
303
+
304
+ protein_vector = torch.cat(
305
+ [torch.ones(1, split) * 0, protein_vector], dim=1
306
+ )
307
+
308
+ protein_vector = torch.nn.functional.pad(
309
+ protein_vector, (0, diff - split), "constant", pad_token
310
+ )
311
+
312
+ elif protein_vector.shape[-1] > self.text_seq_len:
313
+ start_int = np.random.randint(
314
+ 0, protein_vector.shape[-1] - self.text_seq_len
315
+ )
316
+
317
+ protein_vector = protein_vector[
318
+ :, start_int : start_int + self.text_seq_len
319
+ ]
320
+
321
+ return protein_vector.long()
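
For context, a minimal usage sketch of the CellLoader dataset defined above. The CSV path and argument values are placeholders; the keyword arguments mirror those used by prediction.py below.

    from dataloader import CellLoader

    loader = CellLoader(
        data_csv="data/HPA/data.csv",   # placeholder CSV with split and image-path columns
        dataset="HPA",
        split_key="train",
        crop_method="random",
        crop_size=256,
        sequence_mode="embedding",
        vocab="esm2",
        text_seq_len=1000,
    )
    item = loader[0]
    # item["nucleus"], item["target"], item["threshold"] are image tensors;
    # item["sequence"] is the tokenized protein sequence.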
prediction.py ADDED
@@ -0,0 +1,267 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ os.chdir('..')
5
+ from dataloader import CellLoader
6
+ from matplotlib import pyplot as plt
7
+ from celle_main import instantiate_from_config
8
+ from omegaconf import OmegaConf
9
+ from celle.utils import process_image
10
+
11
+ def run_model(mode, sequence,
12
+ nucleus_image_path,
13
+ protein_image_path,
14
+ model_ckpt_path,
15
+ model_config_path,
16
+ device):
17
+ if mode == "image":
18
+ run_image_prediction(
19
+ sequence,
20
+ nucleus_image_path,
21
+ protein_image_path,
22
+ model_ckpt_path,
23
+ model_config_path,
24
+ device
25
+ )
26
+ elif mode == "sequence":
27
+ run_sequence_prediction(
28
+ sequence,
29
+ nucleus_image_path,
30
+ protein_image_path,
31
+ model_ckpt_path,
32
+ model_config_path,
33
+ device
34
+ )
35
+
36
+ def run_sequence_prediction(
37
+ sequence_input,
38
+ nucleus_image_path,
39
+ protein_image_path,
40
+ model_ckpt_path,
41
+ model_config_path,
42
+ device
43
+ ):
44
+ """
45
+ Run Celle model with provided inputs and display results.
46
+
47
+ :param sequence_input: Amino acid sequence string (may contain mask tokens to predict)
48
+ :param nucleus_image_path: Path to nucleus image
49
+ :param protein_image_path: Path to protein image (optional)
50
+ :param model_ckpt_path: Path to model checkpoint
51
+ :param model_config_path: Path to model config
52
+ """
53
+
54
+ # Instantiate dataset object
55
+ dataset = CellLoader(
56
+ sequence_mode="embedding",
57
+ vocab="esm2",
58
+ split_key="val",
59
+ crop_method="center",
60
+ resize=600,
61
+ crop_size=256,
62
+ text_seq_len=1000,
63
+ pad_mode="end",
64
+ threshold="median",
65
+ )
66
+
67
+ # Check if sequence is provided and valid
68
+ if len(sequence_input) == 0:
69
+ raise ValueError("Sequence must be provided.")
70
+
71
+ if "<mask>" not in sequence_input:
72
+ print("Warning: Sequence does not contain any masked positions to predict.")
73
+
74
+ # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
75
+ sequence = dataset.tokenize_sequence(sequence_input)
76
+
77
+ # Check if nucleus image path is provided and valid
78
+ if not os.path.exists(nucleus_image_path):
79
+ # Use default nucleus image from dataset and print warning
80
+ nucleus_image_path = 'images/nucleus.jpg'
81
+ print(
82
+ "Warning: No nucleus image provided. Using default nucleus image from dataset."
83
+ )
+ nucleus_image = process_image(nucleus_image_path)
84
+ else:
85
+ # Load nucleus image from provided path
86
+ nucleus_image = process_image(nucleus_image_path)
87
+
88
+ # Check if protein image path is provided and valid
89
+ if protein_image_path is None or not os.path.exists(protein_image_path):
90
+ # Use default nucleus image from dataset and print warning
91
+ protein_image_path = 'images/protein.jpg'
92
+ print(
93
+ "Warning: No protein image provided. Using default protein image from dataset."
94
+ )
+ protein_image = process_image(protein_image_path)
95
+ else:
96
+ # Load protein image from provided path
97
+ protein_image = process_image(protein_image_path)
98
+ protein_image = (protein_image > torch.median(protein_image)) * 1.0
99
+
100
+ # Load model config and set ckpt_path if not provided in config
101
+ config = OmegaConf.load(model_config_path)
102
+ if config["model"]["params"]["ckpt_path"] is None:
103
+ config["model"]["params"]["ckpt_path"] = model_ckpt_path
104
+
105
+ # Set condition_model_path and vqgan_model_path to None
106
+ config["model"]["params"]["condition_model_path"] = None
107
+ config["model"]["params"]["vqgan_model_path"] = None
108
+
109
+ # Instantiate model from config and move to device
110
+ model = instantiate_from_config(config).to(device)
111
+
112
+ # Sample from model using provided sequence and nucleus image
113
+ _, predicted_sequence, _ = model.celle.sample_text(
114
+ text=sequence,
115
+ condition=nucleus_image,
116
+ image=protein_image,
117
+ force_aas=True,
118
+ timesteps=1,
119
+ temperature=1,
120
+ progress=True,
121
+ )
122
+
123
+ formatted_predicted_sequence = ""
124
+
125
+ for i in range(min(len(predicted_sequence), len(sequence))):
126
+ if predicted_sequence[i] != sequence[i]:
127
+ formatted_predicted_sequence += f"**{predicted_sequence[i]}**"
128
+ else:
129
+ formatted_predicted_sequence += predicted_sequence[i]
130
+
131
+ if len(predicted_sequence) > len(sequence):
132
+ formatted_predicted_sequence += f"**{predicted_sequence[len(sequence):]}**"
133
+
134
+ print("predicted_sequence:", formatted_predicted_sequence)
135
+
136
+
137
+ def run_image_prediction(
138
+ sequence_input,
139
+ nucleus_image_path,
140
+ protein_image_path,
141
+ model_ckpt_path,
142
+ model_config_path,
143
+ device
144
+ ):
145
+ """
146
+ Run Celle model with provided inputs and display results.
147
+
148
+ :param sequence_input: Amino acid sequence string
149
+ :param nucleus_image_path: Path to nucleus image
150
+ :param protein_image_path: Path to protein image (optional)
151
+ :param model_ckpt_path: Path to model checkpoint
152
+ :param model_config_path: Path to model config
153
+ """
154
+ # Instantiate dataset object
155
+ dataset = CellLoader(
156
+ sequence_mode="embedding",
157
+ vocab="esm2",
158
+ split_key="val",
159
+ crop_method="center",
160
+ resize=600,
161
+ crop_size=256,
162
+ text_seq_len=1000,
163
+ pad_mode="end",
164
+ threshold="median",
165
+ )
166
+
167
+ # Check if sequence is provided and valid
168
+ if len(sequence_input) == 0:
169
+ sequence_input = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
170
+ # Use default sequence for GFP and print warning
171
+ print("Warning: No sequence provided. Using default sequence for GFP.")
172
+
173
+ # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
174
+ sequence = dataset.tokenize_sequence(sequence_input)
175
+
176
+ # Check if nucleus image path is provided and valid
177
+ if not os.path.exists(nucleus_image_path):
178
+ # Use default nucleus image from dataset and print warning
179
+ nucleus_image = dataset[0]["nucleus"]
180
+ print(
181
+ "Warning: No nucleus image provided. Using default nucleus image from dataset."
182
+ )
183
+ else:
184
+ # Load nucleus image from provided path
185
+ nucleus_image = process_image(nucleus_image_path)
186
+
187
+ # Load model config and set ckpt_path if not provided in config
188
+ config = OmegaConf.load(model_config_path)
189
+ if config["model"]["params"]["ckpt_path"] is None:
190
+ config["model"]["params"]["ckpt_path"] = model_ckpt_path
191
+
192
+ # Set condition_model_path and vqgan_model_path to None
193
+ config["model"]["params"]["condition_model_path"] = None
194
+ config["model"]["params"]["vqgan_model_path"] = None
195
+
196
+ # Instantiate model from config and move to device
197
+ model = instantiate_from_config(config).to(device)
198
+
199
+ # Sample from model using provided sequence and nucleus image
200
+ _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
201
+ text=sequence,
202
+ condition=nucleus_image,
203
+ timesteps=1,
204
+ temperature=1,
205
+ progress=True,
206
+ )
207
+
208
+ # Move predicted_threshold and predicted_heatmap to CPU and select first element of batch
209
+ predicted_threshold = predicted_threshold.cpu()[0, 0]
210
+ predicted_heatmap = predicted_heatmap.cpu()[0, 0]
211
+
212
+ # Create 3 or 4 panel plot depending on whether protein image path is provided
213
+ fig, axs = plt.subplots(1, 3 if protein_image_path is None else 4)
214
+ axs[0].imshow(nucleus_image)
215
+ axs[0].set_title("Nucleus Input")
216
+ axs[1].imshow(predicted_threshold)
217
+ axs[1].set_title("Predicted Threshold")
218
+ if protein_image_path is not None:
219
+ protein_image = process_image(protein_image_path)
220
+ axs[2].imshow(protein_image)
221
+ axs[2].set_title("Protein Image")
222
+ axs[-1].imshow(predicted_heatmap)
223
+ axs[-1].set_title("Predicted Heatmap")
224
+ plt.show()
225
+
226
+
227
+ if __name__ == "__main__":
228
+ # Parse command line arguments for input parameters
229
+ parser = argparse.ArgumentParser(
230
+ description="Run Celle model with provided inputs."
231
+ )
232
+ parser.add_argument("--mode", type=str, default="", help="Sequence or Image")
233
+ parser.add_argument(
234
+ "--sequence", type=str, default="", help="Amino acid sequence"
235
+ )
236
+ parser.add_argument(
237
+ "--nucleus_image_path",
238
+ type=str,
239
+ default="images/nucleus.jpg",
240
+ help="Path to nucleus image",
241
+ )
242
+ parser.add_argument(
243
+ "--protein_image_path",
244
+ type=str,
245
+ default=None,
246
+ help="Path to protein image (optional)",
247
+ )
248
+ parser.add_argument(
249
+ "--model_ckpt_path", type=str, required=True, help="Path to model checkpoint"
250
+ )
251
+ parser.add_argument(
252
+ "--model_config_path", type=str, required=True, help="Path to model config"
253
+ )
254
+ parser.add_argument(
255
+ "--device", type=str, default="cpu", help="Device to run on, e.g. cpu or cuda"
256
+ )
257
+ args = parser.parse_args()
258
+
259
+ run_model(
260
+ args.mode,
261
+ args.sequence,
262
+ args.nucleus_image_path,
263
+ args.protein_image_path,
264
+ args.model_ckpt_path,
265
+ args.model_config_path,
266
+ args.device
267
+ )
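
A minimal sketch of calling the image-prediction path above directly from Python; the checkpoint and config paths are placeholders.

    from prediction import run_image_prediction

    run_image_prediction(
        sequence_input="",                         # empty string falls back to the GFP default above
        nucleus_image_path="images/nucleus.jpg",
        protein_image_path=None,
        model_ckpt_path="checkpoints/model.ckpt",  # placeholder
        model_config_path="configs/celle.yaml",    # placeholder
        device="cpu",
    )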
requirements.txt ADDED
@@ -0,0 +1,160 @@
1
+ aiohttp==3.8.4
2
+ aiosignal==1.3.1
3
+ antlr4-python3-runtime==4.9.3
4
+ anyio==3.6.2
5
+ argon2-cffi==21.3.0
6
+ argon2-cffi-bindings==21.2.0
7
+ arrow==1.2.3
8
+ asttokens==2.2.1
9
+ async-timeout==4.0.2
10
+ attrs==23.1.0
11
+ axial-positional-embedding==0.2.1
12
+ backcall==0.2.0
13
+ beautifulsoup4==4.12.2
14
+ bleach==6.0.0
15
+ blessed==1.20.0
16
+ certifi==2023.5.7
17
+ cffi==1.15.1
18
+ charset-normalizer==3.1.0
19
+ click==8.1.3
20
+ cmake==3.26.3
21
+ comm==0.1.3
22
+ contourpy==1.0.7
23
+ croniter==1.3.14
24
+ cycler==0.11.0
25
+ dateutils==0.6.12
26
+ debugpy==1.6.7
27
+ decorator==5.1.1
28
+ deepdiff==6.3.0
29
+ defusedxml==0.7.1
30
+ einops==0.6.1
31
+ executing==1.2.0
32
+ fair-esm==2.0.0
33
+ fastapi==0.88.0
34
+ fastjsonschema==2.16.3
35
+ filelock==3.12.0
36
+ fonttools==4.39.4
37
+ fqdn==1.5.1
38
+ frozenlist==1.3.3
39
+ fsspec==2023.5.0
40
+ h11==0.14.0
41
+ idna==3.4
42
+ inquirer==3.1.3
43
+ ipykernel==6.23.0
44
+ ipython==8.13.2
45
+ ipython-genutils==0.2.0
46
+ ipywidgets==8.0.6
47
+ isoduration==20.11.0
48
+ itsdangerous==2.1.2
49
+ jedi==0.18.2
50
+ Jinja2==3.1.2
51
+ jsonpointer==2.3
52
+ jsonschema==4.17.3
53
+ jupyter==1.0.0
54
+ jupyter-console==6.6.3
55
+ jupyter-events==0.6.3
56
+ jupyter_client==8.2.0
57
+ jupyter_core==5.3.0
58
+ jupyter_server==2.5.0
59
+ jupyter_server_terminals==0.4.4
60
+ jupyterlab-pygments==0.2.2
61
+ jupyterlab-widgets==3.0.7
62
+ kiwisolver==1.4.4
63
+ lightning==2.0.2
64
+ lightning-cloud==0.5.34
65
+ lightning-utilities==0.8.0
66
+ lit==16.0.3
67
+ markdown-it-py==2.2.0
68
+ MarkupSafe==2.1.2
69
+ matplotlib==3.7.1
70
+ matplotlib-inline==0.1.6
71
+ mdurl==0.1.2
72
+ mistune==2.0.5
73
+ mpmath==1.3.0
74
+ multidict==6.0.4
75
+ nbclassic==1.0.0
76
+ nbclient==0.7.4
77
+ nbconvert==7.4.0
78
+ nbformat==5.8.0
79
+ nest-asyncio==1.5.6
80
+ networkx==3.1
81
+ notebook==6.5.4
82
+ notebook_shim==0.2.3
83
+ numpy==1.24.3
84
+ nvidia-cublas-cu11==11.10.3.66
85
+ nvidia-cuda-cupti-cu11==11.7.101
86
+ nvidia-cuda-nvrtc-cu11==11.7.99
87
+ nvidia-cuda-runtime-cu11==11.7.99
88
+ nvidia-cudnn-cu11==8.5.0.96
89
+ nvidia-cufft-cu11==10.9.0.58
90
+ nvidia-curand-cu11==10.2.10.91
91
+ nvidia-cusolver-cu11==11.4.0.1
92
+ nvidia-cusparse-cu11==11.7.4.91
93
+ nvidia-nccl-cu11==2.14.3
94
+ nvidia-nvtx-cu11==11.7.91
95
+ omegaconf==2.3.0
96
+ ordered-set==4.1.0
97
+ packaging==23.1
98
+ pandas==2.0.1
99
+ pandocfilters==1.5.0
100
+ parso==0.8.3
101
+ pexpect==4.8.0
102
+ pickleshare==0.7.5
103
+ Pillow==9.5.0
104
+ platformdirs==3.5.1
105
+ prometheus-client==0.16.0
106
+ prompt-toolkit==3.0.38
107
+ psutil==5.9.5
108
+ ptyprocess==0.7.0
109
+ pure-eval==0.2.2
110
+ pycparser==2.21
111
+ pydantic==1.10.7
112
+ Pygments==2.15.1
113
+ PyJWT==2.7.0
114
+ pyparsing==3.0.9
115
+ pyrsistent==0.19.3
116
+ python-dateutil==2.8.2
117
+ python-editor==1.0.4
118
+ python-json-logger==2.0.7
119
+ python-multipart==0.0.6
120
+ pytorch-lightning==1.9.0
121
+ pytz==2023.3
122
+ PyYAML==6.0
123
+ pyzmq==25.0.2
124
+ qtconsole==5.4.3
125
+ QtPy==2.3.1
126
+ readchar==4.0.5
127
+ requests==2.30.0
128
+ rfc3339-validator==0.1.4
129
+ rfc3986-validator==0.1.1
130
+ rich==13.3.5
131
+ rotary-embedding-torch==0.2.3
132
+ Send2Trash==1.8.2
133
+ six==1.16.0
134
+ sniffio==1.3.0
135
+ soupsieve==2.4.1
136
+ stack-data==0.6.2
137
+ starlette==0.22.0
138
+ starsessions==1.3.0
139
+ sympy==1.12
140
+ terminado==0.17.1
141
+ tinycss2==1.2.1
142
+ torch==2.0.0
143
+ torchmetrics==0.11.4
144
+ torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.15.1%2Bcpu-cp311-cp311-linux_x86_64.whl
145
+ tornado==6.3.1
146
+ tqdm==4.65.0
147
+ traitlets==5.9.0
148
+ triton==2.0.0
149
+ typing_extensions==4.5.0
150
+ tzdata==2023.3
151
+ uri-template==1.2.0
152
+ urllib3==2.0.2
153
+ uvicorn==0.22.0
154
+ wcwidth==0.2.6
155
+ webcolors==1.13
156
+ webencodings==0.5.1
157
+ websocket-client==1.5.1
158
+ websockets==11.0.3
159
+ widgetsnbextension==4.0.7
160
+ yarl==1.9.2
taming/lr_scheduler.py ADDED
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n):
33
+ return self.schedule(n)
34
+
taming/models/cond_transformer.py ADDED
@@ -0,0 +1,349 @@
1
+ import os, math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import pytorch_lightning as pl
5
+
6
+ from main import instantiate_from_config
7
+ from taming.modules.util import SOSProvider
8
+
9
+
10
+ def disabled_train(self, mode=True):
11
+ """Overwrite model.train with this function to make sure train/eval mode
12
+ does not change anymore."""
13
+ return self
14
+
15
+
16
+ class Net2NetTransformer(pl.LightningModule):
17
+ def __init__(self,
18
+ transformer_config,
19
+ first_stage_config,
20
+ cond_stage_config,
21
+ permuter_config=None,
22
+ ckpt_path=None,
23
+ ignore_keys=[],
24
+ first_stage_key="image",
25
+ cond_stage_key="depth",
26
+ downsample_cond_size=-1,
27
+ pkeep=1.0,
28
+ sos_token=0,
29
+ unconditional=False,
30
+ ):
31
+ super().__init__()
32
+ self.be_unconditional = unconditional
33
+ self.sos_token = sos_token
34
+ self.first_stage_key = first_stage_key
35
+ self.cond_stage_key = cond_stage_key
36
+ self.init_first_stage_from_ckpt(first_stage_config)
37
+ self.init_cond_stage_from_ckpt(cond_stage_config)
38
+ if permuter_config is None:
39
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
40
+ self.permuter = instantiate_from_config(config=permuter_config)
41
+ self.transformer = instantiate_from_config(config=transformer_config)
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+ self.downsample_cond_size = downsample_cond_size
46
+ self.pkeep = pkeep
47
+
48
+ def init_from_ckpt(self, path, ignore_keys=list()):
49
+ sd = torch.load(path, map_location="cpu")["state_dict"]
50
+ for k in list(sd.keys()):
51
+ for ik in ignore_keys:
52
+ if k.startswith(ik):
53
+ self.print("Deleting key {} from state_dict.".format(k))
54
+ del sd[k]
55
+ self.load_state_dict(sd, strict=False)
56
+ print(f"Restored from {path}")
57
+
58
+ def init_first_stage_from_ckpt(self, config):
59
+ model = instantiate_from_config(config)
60
+ model = model.eval()
61
+ model.train = disabled_train
62
+ self.first_stage_model = model
63
+
64
+ def init_cond_stage_from_ckpt(self, config):
65
+ if config == "__is_first_stage__":
66
+ print("Using first stage also as cond stage.")
67
+ self.cond_stage_model = self.first_stage_model
68
+ elif config == "__is_unconditional__" or self.be_unconditional:
69
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
70
+ f"Prepending {self.sos_token} as a sos token.")
71
+ self.be_unconditional = True
72
+ self.cond_stage_key = self.first_stage_key
73
+ self.cond_stage_model = SOSProvider(self.sos_token)
74
+ else:
75
+ model = instantiate_from_config(config)
76
+ model = model.eval()
77
+ model.train = disabled_train
78
+ self.cond_stage_model = model
79
+
80
+ def forward(self, x, c):
81
+ # one step to produce the logits
82
+ # x = target
83
+ # c = nucleus
84
+ _, z_indices = self.encode_to_z(x)
85
+ _, c_indices = self.encode_to_c(c)
86
+
87
+ if self.training and self.pkeep < 1.0:
88
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
89
+ device=z_indices.device))
90
+ mask = mask.round().to(dtype=torch.int64)
91
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
92
+ a_indices = mask*z_indices+(1-mask)*r_indices
93
+ else:
94
+ a_indices = z_indices
95
+
96
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
97
+
98
+ # target includes all sequence elements (no need to handle first one
99
+ # differently because we are conditioning)
100
+ target = z_indices
101
+ # make the prediction
102
+ logits, _ = self.transformer(cz_indices[:, :-1])
103
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
104
+ logits = logits[:, c_indices.shape[1]-1:]
105
+
106
+ return logits, target
107
+
108
+ def top_k_logits(self, logits, k):
109
+ v, ix = torch.topk(logits, k)
110
+ out = logits.clone()
111
+ out[out < v[..., [-1]]] = -float('Inf')
112
+ return out
113
+
114
+ @torch.no_grad()
115
+ def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
116
+ callback=lambda k: None):
117
+ x = torch.cat((c,x),dim=1)
118
+ block_size = self.transformer.get_block_size()
119
+ assert not self.transformer.training
120
+ if self.pkeep <= 0.0:
121
+ # one pass suffices since input is pure noise anyway
122
+ assert len(x.shape)==2
123
+ noise_shape = (x.shape[0], steps-1)
124
+ #noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
125
+ noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
126
+ x = torch.cat((x,noise),dim=1)
127
+ logits, _ = self.transformer(x)
128
+ # take all logits for now and scale by temp
129
+ logits = logits / temperature
130
+ # optionally crop probabilities to only the top k options
131
+ if top_k is not None:
132
+ logits = self.top_k_logits(logits, top_k)
133
+ # apply softmax to convert to probabilities
134
+ probs = F.softmax(logits, dim=-1)
135
+ # sample from the distribution or take the most likely
136
+ if sample:
137
+ shape = probs.shape
138
+ probs = probs.reshape(shape[0]*shape[1],shape[2])
139
+ ix = torch.multinomial(probs, num_samples=1)
140
+ probs = probs.reshape(shape[0],shape[1],shape[2])
141
+ ix = ix.reshape(shape[0],shape[1])
142
+ else:
143
+ _, ix = torch.topk(probs, k=1, dim=-1)
144
+ # cut off conditioning
145
+ x = ix[:, c.shape[1]-1:]
146
+ else:
147
+ for k in range(steps):
148
+ callback(k)
149
+ assert x.size(1) <= block_size # make sure model can see conditioning
150
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
151
+ logits, _ = self.transformer(x_cond)
152
+ # pluck the logits at the final step and scale by temperature
153
+ logits = logits[:, -1, :] / temperature
154
+ # optionally crop probabilities to only the top k options
155
+ if top_k is not None:
156
+ logits = self.top_k_logits(logits, top_k)
157
+ # apply softmax to convert to probabilities
158
+ probs = F.softmax(logits, dim=-1)
159
+ # sample from the distribution or take the most likely
160
+ if sample:
161
+ ix = torch.multinomial(probs, num_samples=1)
162
+ else:
163
+ _, ix = torch.topk(probs, k=1, dim=-1)
164
+ # append to the sequence and continue
165
+ x = torch.cat((x, ix), dim=1)
166
+ # cut off conditioning
167
+ x = x[:, c.shape[1]:]
168
+ return x
169
+
170
+ @torch.no_grad()
171
+ def encode_to_z(self, x):
172
+ quant_z, _, info = self.first_stage_model.encode(x)
173
+ indices = info[2].view(quant_z.shape[0], -1)
174
+ indices = self.permuter(indices)
175
+ return quant_z, indices
176
+
177
+ @torch.no_grad()
178
+ def encode_to_c(self, c):
179
+ if self.downsample_cond_size > -1:
180
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
181
+
182
+ #quant_c, _, info = self.cond_stage_model.encode(x)
183
+ #indices = info[2].view(quant_c.shape[0], -1)
184
+ #indices = self.permuter(indices)
185
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
186
+ if len(indices.shape) != 2:
187
+ indices = indices.view(c.shape[0], -1)
188
+ return quant_c, indices
189
+
190
+ @torch.no_grad()
191
+ def decode_to_img(self, index, zshape):
192
+ index = self.permuter(index, reverse=True)
193
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
194
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
195
+ index.reshape(-1), shape=bhwc)
196
+ x = self.first_stage_model.decode(quant_z)
197
+ return x
198
+
199
+ @torch.no_grad()
200
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
201
+ log = dict()
202
+
203
+ N = 4
204
+ if lr_interface:
205
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
206
+ else:
207
+ x, c = self.get_xc(batch, N)
208
+ x = x.to(device=self.device)
209
+ c = c.to(device=self.device)
210
+
211
+ quant_z, z_indices = self.encode_to_z(x)
212
+ quant_c, c_indices = self.encode_to_c(c)
213
+
214
+ # create a "half" sample
215
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
216
+ index_sample = self.sample(z_start_indices, c_indices,
217
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
218
+ temperature=temperature if temperature is not None else 1.0,
219
+ sample=True,
220
+ top_k=top_k if top_k is not None else 100,
221
+ callback=callback if callback is not None else lambda k: None)
222
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
223
+
224
+ # sample
225
+ z_start_indices = z_indices[:, :0]
226
+ index_sample = self.sample(z_start_indices, c_indices,
227
+ steps=z_indices.shape[1],
228
+ temperature=temperature if temperature is not None else 1.0,
229
+ sample=True,
230
+ top_k=top_k if top_k is not None else 100,
231
+ callback=callback if callback is not None else lambda k: None)
232
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
233
+
234
+ # det sample
235
+ z_start_indices = z_indices[:, :0]
236
+ index_sample = self.sample(z_start_indices, c_indices,
237
+ steps=z_indices.shape[1],
238
+ sample=False,
239
+ callback=callback if callback is not None else lambda k: None)
240
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
241
+
242
+ # reconstruction
243
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
244
+
245
+ log["inputs"] = x
246
+ log["reconstructions"] = x_rec
247
+
248
+ if self.cond_stage_key != "image" and self.cond_stage_key != "nucleus" and self.cond_stage_key != "target":
249
+ cond_rec = self.cond_stage_model.decode(quant_c)
250
+ if self.cond_stage_key == "segmentation":
251
+ # get image from segmentation mask
252
+ num_classes = cond_rec.shape[1]
253
+
254
+ c = torch.argmax(c, dim=1, keepdim=True)
255
+ c = F.one_hot(c, num_classes=num_classes)
256
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
257
+ c = self.cond_stage_model.to_rgb(c)
258
+
259
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
260
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
261
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
262
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
263
+ log["conditioning_rec"] = cond_rec
264
+ log["conditioning"] = c
265
+
266
+ log["samples_half"] = x_sample
267
+ log["samples_nopix"] = x_sample_nopix
268
+ log["samples_det"] = x_sample_det
269
+ return log
270
+
271
+ def get_input(self, key, batch):
272
+ x = batch[key]
273
+ if len(x.shape) == 3:
274
+ x = x[..., None]
275
+ #if len(x.shape) == 4:
276
+ # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
277
+ if x.dtype == torch.double:
278
+ x = x.float()
279
+ return x
280
+
281
+ def get_xc(self, batch, N=None):
282
+ x = self.get_input(self.first_stage_key, batch)
283
+ c = self.get_input(self.cond_stage_key, batch)
284
+ if N is not None:
285
+ x = x[:N]
286
+ c = c[:N]
287
+ return x, c
288
+
289
+ def shared_step(self, batch):
290
+ x, c = self.get_xc(batch)
291
+ logits, target = self(x, c)
292
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
293
+ return loss
294
+
295
+ def training_step(self, batch, batch_idx):
296
+ loss = self.shared_step(batch)
297
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
298
+ return loss
299
+
300
+ def validation_step(self, batch, batch_idx):
301
+ loss = self.shared_step(batch)
302
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
303
+ return loss
304
+
305
+ def configure_optimizers(self):
306
+ """
307
+ Following minGPT:
308
+ This long function is unfortunately doing something very simple and is being very defensive:
309
+ We are separating out all parameters of the model into two buckets: those that will experience
310
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
311
+ We are then returning the PyTorch optimizer object.
312
+ """
313
+ # separate out all parameters to those that will and won't experience regularizing weight decay
314
+ decay = set()
315
+ no_decay = set()
316
+ whitelist_weight_modules = (torch.nn.Linear, )
317
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
318
+ for mn, m in self.transformer.named_modules():
319
+ for pn, p in m.named_parameters():
320
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
321
+
322
+ if pn.endswith('bias'):
323
+ # all biases will not be decayed
324
+ no_decay.add(fpn)
325
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
326
+ # weights of whitelist modules will be weight decayed
327
+ decay.add(fpn)
328
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
329
+ # weights of blacklist modules will NOT be weight decayed
330
+ no_decay.add(fpn)
331
+
332
+ # special case the position embedding parameter in the root GPT module as not decayed
333
+ no_decay.add('pos_emb')
334
+
335
+ # validate that we considered every parameter
336
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
337
+ inter_params = decay & no_decay
338
+ union_params = decay | no_decay
339
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
340
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
341
+ % (str(param_dict.keys() - union_params), )
342
+
343
+ # create the pytorch optimizer object
344
+ optim_groups = [
345
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
346
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
347
+ ]
348
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
349
+ return optimizer
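
For reference, a standalone sketch of the per-step decoding used in Net2NetTransformer.sample above: scale logits by temperature, optionally keep only the top-k options, convert to probabilities, then either sample or take the argmax. The random logits are placeholders.

    import torch
    import torch.nn.functional as F

    def sample_next_index(logits, temperature=1.0, top_k=100, sample=True):
        # logits: [batch, vocab_size] for the final position
        logits = logits / temperature
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits = logits.masked_fill(v[..., [-1]] > logits, -float("inf"))
        probs = F.softmax(logits, dim=-1)
        if sample:
            return torch.multinomial(probs, num_samples=1)   # stochastic decoding
        return torch.topk(probs, k=1, dim=-1).indices        # greedy decoding

    ix = sample_next_index(torch.randn(2, 1024))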
taming/models/dummy_cond_stage.py ADDED
@@ -0,0 +1,22 @@
1
+ from torch import Tensor
2
+
3
+
4
+ class DummyCondStage:
5
+ def __init__(self, conditional_key):
6
+ self.conditional_key = conditional_key
7
+ self.train = None
8
+
9
+ def eval(self):
10
+ return self
11
+
12
+ @staticmethod
13
+ def encode(c: Tensor):
14
+ return c, None, (None, None, c)
15
+
16
+ @staticmethod
17
+ def decode(c: Tensor):
18
+ return c
19
+
20
+ @staticmethod
21
+ def to_rgb(c: Tensor):
22
+ return c
taming/models/vqgan.py ADDED
@@ -0,0 +1,649 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import pytorch_lightning as pl
4
+
5
+ from celle_taming_main import instantiate_from_config
6
+
7
+ from taming.modules.diffusionmodules.model import Encoder, Decoder
8
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
+ from taming.modules.vqvae.quantize import GumbelQuantize
10
+ from taming.modules.vqvae.quantize import EMAVectorQuantizer
11
+
12
+
13
+ class VQModel(pl.LightningModule):
14
+ def __init__(
15
+ self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ remap=None,
26
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
27
+ ):
28
+ super().__init__()
29
+ self.image_key = image_key
30
+ self.encoder = Encoder(**ddconfig)
31
+ self.decoder = Decoder(**ddconfig)
32
+ self.loss = instantiate_from_config(lossconfig)
33
+ self.quantize = VectorQuantizer(
34
+ n_embed,
35
+ embed_dim,
36
+ beta=0.25,
37
+ remap=remap,
38
+ sane_index_shape=sane_index_shape,
39
+ )
40
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
41
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
42
+ if ckpt_path is not None:
43
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
44
+ self.image_key = image_key
45
+ if colorize_nlabels is not None:
46
+ assert type(colorize_nlabels) == int
47
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
48
+ if monitor is not None:
49
+ self.monitor = monitor
50
+
51
+ def init_from_ckpt(self, path, ignore_keys=list()):
52
+ sd = torch.load(path, map_location="cpu")["state_dict"]
53
+ keys = list(sd.keys())
54
+ for k in keys:
55
+ for ik in ignore_keys:
56
+ if k.startswith(ik):
57
+ print("Deleting key {} from state_dict.".format(k))
58
+ del sd[k]
59
+ self.load_state_dict(sd, strict=False)
60
+ print(f"Restored from {path}")
61
+
62
+ def encode(self, x):
63
+ h = self.encoder(x)
64
+ h = self.quant_conv(h)
65
+ quant, emb_loss, info = self.quantize(h)
66
+ return quant, emb_loss, info
67
+
68
+ def decode(self, quant):
69
+ quant = self.post_quant_conv(quant)
70
+ dec = self.decoder(quant)
71
+ return dec
72
+
73
+ def decode_code(self, code_b):
74
+ quant_b = self.quantize.embed_code(code_b)
75
+ dec = self.decode(quant_b)
76
+ return dec
77
+
78
+ def forward(self, input):
79
+ quant, diff, _ = self.encode(input)
80
+ dec = self.decode(quant)
81
+ return dec, diff
82
+
83
+ def get_input(self, batch, k):
84
+
85
+ if k == "mixed":
86
+ keys = ["nucleus", "target"]
87
+ index = torch.randint(low=0, high=2, size=(1,), dtype=int).item()
88
+ k = keys[index]
89
+
90
+ x = batch[k]
91
+ if len(x.shape) == 3:
92
+ x = x[..., None]
93
+
94
+ # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
95
+ return x
96
+
97
+ def training_step(self, batch, batch_idx=None, optimizer_idx=0):
98
+
99
+ if type(batch) == dict:
100
+
101
+ x = self.get_input(batch, self.image_key)
102
+
103
+ else:
104
+ x = batch
105
+
106
+ xrec, qloss = self(
107
+ x,
108
+ )
109
+
110
+ if optimizer_idx == 0:
111
+ # autoencode
112
+ aeloss, log_dict_ae = self.loss(
113
+ qloss,
114
+ x,
115
+ xrec,
116
+ optimizer_idx,
117
+ self.global_step,
118
+ last_layer=self.get_last_layer(),
119
+ split="train",
120
+ )
121
+
122
+ self.log(
123
+ "train/aeloss",
124
+ aeloss,
125
+ prog_bar=True,
126
+ logger=True,
127
+ on_step=True,
128
+ on_epoch=True,
129
+ sync_dist=True,
130
+ )
131
+ self.log_dict(
132
+ log_dict_ae,
133
+ prog_bar=False,
134
+ logger=True,
135
+ on_step=True,
136
+ on_epoch=True,
137
+ sync_dist=True,
138
+ )
139
+ return aeloss
140
+
141
+ if optimizer_idx == 1:
142
+ # discriminator
143
+ discloss, log_dict_disc = self.loss(
144
+ qloss,
145
+ x,
146
+ xrec,
147
+ optimizer_idx,
148
+ self.global_step,
149
+ last_layer=self.get_last_layer(),
150
+ split="train",
151
+ )
152
+ self.log(
153
+ "train/discloss",
154
+ discloss,
155
+ prog_bar=True,
156
+ logger=True,
157
+ on_step=True,
158
+ on_epoch=True,
159
+ sync_dist=True,
160
+ )
161
+ self.log_dict(
162
+ log_dict_disc,
163
+ prog_bar=False,
164
+ logger=True,
165
+ on_step=True,
166
+ on_epoch=True,
167
+ sync_dist=True,
168
+ )
169
+ return discloss
170
+
171
+ def validation_step(self, batch, batch_idx):
172
+
173
+ if type(batch) == dict:
174
+
175
+ x = self.get_input(batch, self.image_key)
176
+
177
+ else:
178
+ x = batch
179
+
180
+ xrec, qloss = self(x)
181
+ aeloss, log_dict_ae = self.loss(
182
+ qloss,
183
+ x,
184
+ xrec,
185
+ 0,
186
+ self.global_step,
187
+ last_layer=self.get_last_layer(),
188
+ split="val",
189
+ )
190
+
191
+ discloss, log_dict_disc = self.loss(
192
+ qloss,
193
+ x,
194
+ xrec,
195
+ 1,
196
+ self.global_step,
197
+ last_layer=self.get_last_layer(),
198
+ split="val",
199
+ )
200
+ # rec_loss = log_dict_ae["val/rec_loss"]
201
+ # self.log(
202
+ # "val/rec_loss",
203
+ # rec_loss,
204
+ # prog_bar=True,
205
+ # logger=True,
206
+ # on_step=True,
207
+ # on_epoch=True,
208
+ # sync_dist=True,
209
+ # )
210
+ # self.log(
211
+ # "val/aeloss",
212
+ # aeloss,
213
+ # prog_bar=True,
214
+ # logger=True,
215
+ # on_step=True,
216
+ # on_epoch=True,
217
+ # sync_dist=True,
218
+ # )
219
+
220
+ for key, value in log_dict_disc.items():
221
+ if key in log_dict_ae:
222
+ log_dict_ae[key].extend(value)
223
+ else:
224
+ log_dict_ae[key] = value
225
+
226
+ self.log_dict(log_dict_ae, sync_dist=True)
227
+ return self.log_dict
228
+
229
+ def configure_optimizers(self):
230
+ lr = self.learning_rate
231
+ opt_ae = torch.optim.Adam(
232
+ list(self.encoder.parameters())
233
+ + list(self.decoder.parameters())
234
+ + list(self.quantize.parameters())
235
+ + list(self.quant_conv.parameters())
236
+ + list(self.post_quant_conv.parameters()),
237
+ lr=lr,
238
+ betas=(0.5, 0.9),
239
+ )
240
+ opt_disc = torch.optim.Adam(
241
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
242
+ )
243
+ return [opt_ae, opt_disc], []
244
+
245
+ def get_last_layer(self):
246
+ return self.decoder.conv_out.weight
247
+
248
+ def log_images(self, batch, **kwargs):
249
+ log = dict()
250
+ x = self.get_input(batch, self.image_key)
251
+ x = x.to(self.device)
252
+ xrec, _ = self(x)
253
+ if x.shape[1] > 3:
254
+ # colorize with random projection
255
+ assert xrec.shape[1] > 3
256
+ x = self.to_rgb(x)
257
+ xrec = self.to_rgb(xrec)
258
+ log["inputs"] = x
259
+ log["reconstructions"] = xrec
260
+ return log
261
+
262
+ def to_rgb(self, x):
263
+ assert self.image_key == "segmentation"
264
+ if not hasattr(self, "colorize"):
265
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
266
+ x = F.conv2d(x, weight=self.colorize)
267
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
268
+ return x
269
+
270
+
271
+ class VQSegmentationModel(VQModel):
272
+ def __init__(self, n_labels, *args, **kwargs):
273
+ super().__init__(*args, **kwargs)
274
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
275
+
276
+ def configure_optimizers(self):
277
+ lr = self.learning_rate
278
+ opt_ae = torch.optim.Adam(
279
+ list(self.encoder.parameters())
280
+ + list(self.decoder.parameters())
281
+ + list(self.quantize.parameters())
282
+ + list(self.quant_conv.parameters())
283
+ + list(self.post_quant_conv.parameters()),
284
+ lr=lr,
285
+ betas=(0.5, 0.9),
286
+ )
287
+ return opt_ae
288
+
289
+ def training_step(self, batch, batch_idx):
290
+ x = self.get_input(batch, self.image_key)
291
+ xrec, qloss = self(x)
292
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
293
+ self.log_dict(
294
+ log_dict_ae,
295
+ prog_bar=False,
296
+ logger=True,
297
+ on_step=True,
298
+ on_epoch=True,
299
+ sync_dist=True,
300
+ )
301
+ return aeloss
302
+
303
+ def validation_step(self, batch, batch_idx):
304
+ x = self.get_input(batch, self.image_key)
305
+ xrec, qloss = self(x)
306
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
307
+ self.log_dict(
308
+ log_dict_ae,
309
+ prog_bar=False,
310
+ logger=True,
311
+ on_step=True,
312
+ on_epoch=True,
313
+ sync_dist=True,
314
+ )
315
+ total_loss = log_dict_ae["val/total_loss"]
316
+ self.log(
317
+ "val/total_loss",
318
+ total_loss,
319
+ prog_bar=True,
320
+ logger=True,
321
+ on_step=True,
322
+ on_epoch=True,
323
+ sync_dist=True,
324
+ )
325
+ return aeloss
326
+
327
+ @torch.no_grad()
328
+ def log_images(self, batch, **kwargs):
329
+ log = dict()
330
+ x = self.get_input(batch, self.image_key)
331
+ x = x.to(self.device)
332
+ xrec, _ = self(x)
333
+ if x.shape[1] > 3:
334
+ # colorize with random projection
335
+ assert xrec.shape[1] > 3
336
+ # convert logits to indices
337
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
338
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
339
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
340
+ x = self.to_rgb(x)
341
+ xrec = self.to_rgb(xrec)
342
+ log["inputs"] = x
343
+ log["reconstructions"] = xrec
344
+ return log
345
+
346
+
347
+ class VQNoDiscModel(VQModel):
348
+ def __init__(
349
+ self,
350
+ ddconfig,
351
+ lossconfig,
352
+ n_embed,
353
+ embed_dim,
354
+ ckpt_path=None,
355
+ ignore_keys=[],
356
+ image_key="image",
357
+ colorize_nlabels=None,
358
+ ):
359
+ super().__init__(
360
+ ddconfig=ddconfig,
361
+ lossconfig=lossconfig,
362
+ n_embed=n_embed,
363
+ embed_dim=embed_dim,
364
+ ckpt_path=ckpt_path,
365
+ ignore_keys=ignore_keys,
366
+ image_key=image_key,
367
+ colorize_nlabels=colorize_nlabels,
368
+ )
369
+
370
+ def training_step(self, batch, batch_idx):
371
+ x = self.get_input(batch, self.image_key)
372
+ xrec, qloss = self(x)
373
+ # autoencode
374
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
375
+ output = pl.TrainResult(minimize=aeloss)
376
+ output.log(
377
+ "train/aeloss",
378
+ aeloss,
379
+ prog_bar=True,
380
+ logger=True,
381
+ on_step=True,
382
+ on_epoch=True,
383
+ )
384
+ output.log_dict(
385
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
386
+ )
387
+ return output
388
+
389
+ def validation_step(self, batch, batch_idx):
390
+ x = self.get_input(batch, self.image_key)
391
+ xrec, qloss = self(x)
392
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
393
+ rec_loss = log_dict_ae["val/rec_loss"]
394
+ output = pl.EvalResult(checkpoint_on=rec_loss)
395
+ output.log(
396
+ "val/rec_loss",
397
+ rec_loss,
398
+ prog_bar=True,
399
+ logger=True,
400
+ on_step=True,
401
+ on_epoch=True,
402
+ )
403
+ output.log(
404
+ "val/aeloss",
405
+ aeloss,
406
+ prog_bar=True,
407
+ logger=True,
408
+ on_step=True,
409
+ on_epoch=True,
410
+ )
411
+ output.log_dict(log_dict_ae)
412
+
413
+ return output
414
+
415
+ def configure_optimizers(self):
416
+ optimizer = torch.optim.Adam(
417
+ list(self.encoder.parameters())
418
+ + list(self.decoder.parameters())
419
+ + list(self.quantize.parameters())
420
+ + list(self.quant_conv.parameters())
421
+ + list(self.post_quant_conv.parameters()),
422
+ lr=self.learning_rate,
423
+ betas=(0.5, 0.9),
424
+ )
425
+ return optimizer
426
+
427
+
428
+ class GumbelVQ(VQModel):
429
+ def __init__(
430
+ self,
431
+ ddconfig,
432
+ lossconfig,
433
+ n_embed,
434
+ embed_dim,
435
+ temperature_scheduler_config,
436
+ ckpt_path=None,
437
+ ignore_keys=[],
438
+ image_key="image",
439
+ colorize_nlabels=None,
440
+ monitor=None,
441
+ kl_weight=1e-8,
442
+ remap=None,
443
+ ):
444
+
445
+ z_channels = ddconfig["z_channels"]
446
+ super().__init__(
447
+ ddconfig,
448
+ lossconfig,
449
+ n_embed,
450
+ embed_dim,
451
+ ckpt_path=None,
452
+ ignore_keys=ignore_keys,
453
+ image_key=image_key,
454
+ colorize_nlabels=colorize_nlabels,
455
+ monitor=monitor,
456
+ )
457
+
458
+ self.loss.n_classes = n_embed
459
+ self.vocab_size = n_embed
460
+
461
+ self.quantize = GumbelQuantize(
462
+ z_channels,
463
+ embed_dim,
464
+ n_embed=n_embed,
465
+ kl_weight=kl_weight,
466
+ temp_init=1.0,
467
+ remap=remap,
468
+ )
469
+
470
+ self.temperature_scheduler = instantiate_from_config(
471
+ temperature_scheduler_config
472
+ ) # annealing of temp
473
+
474
+ if ckpt_path is not None:
475
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
476
+
477
+ def temperature_scheduling(self):
478
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
479
+
480
+ def encode_to_prequant(self, x):
481
+ h = self.encoder(x)
482
+ h = self.quant_conv(h)
483
+ return h
484
+
485
+ def decode_code(self, code_b):
486
+ raise NotImplementedError
487
+
488
+ def training_step(self, batch, batch_idx=None, optimizer_idx=0):
489
+ self.temperature_scheduling()
490
+ x = self.get_input(batch, self.image_key)
491
+ xrec, qloss = self(x)
492
+
493
+ if optimizer_idx == 0:
494
+ # autoencode
495
+ aeloss, log_dict_ae = self.loss(
496
+ qloss,
497
+ x,
498
+ xrec,
499
+ optimizer_idx,
500
+ self.global_step,
501
+ last_layer=self.get_last_layer(),
502
+ split="train",
503
+ )
504
+
505
+ self.log_dict(
506
+ log_dict_ae,
507
+ prog_bar=False,
508
+ logger=True,
509
+ on_step=True,
510
+ on_epoch=True,
511
+ sync_dist=True,
512
+ )
513
+ self.log(
514
+ "temperature",
515
+ self.quantize.temperature,
516
+ prog_bar=False,
517
+ logger=True,
518
+ on_step=True,
519
+ on_epoch=True,
520
+ sync_dist=True,
521
+ )
522
+ return aeloss
523
+
524
+ if optimizer_idx == 1:
525
+ # discriminator
526
+ discloss, log_dict_disc = self.loss(
527
+ qloss,
528
+ x,
529
+ xrec,
530
+ optimizer_idx,
531
+ self.global_step,
532
+ last_layer=self.get_last_layer(),
533
+ split="train",
534
+ )
535
+ self.log_dict(
536
+ log_dict_disc,
537
+ prog_bar=False,
538
+ logger=True,
539
+ on_step=True,
540
+ on_epoch=True,
541
+ sync_dist=True,
542
+ )
543
+ return discloss
544
+
545
+ def validation_step(self, batch, batch_idx):
546
+ x = self.get_input(batch, self.image_key)
547
+ xrec, qloss = self(x)
548
+ aeloss, log_dict_ae = self.loss(
549
+ qloss,
550
+ x,
551
+ xrec,
552
+ 0,
553
+ self.global_step,
554
+ last_layer=self.get_last_layer(),
555
+ split="val",
556
+ )
557
+
558
+ discloss, log_dict_disc = self.loss(
559
+ qloss,
560
+ x,
561
+ xrec,
562
+ 1,
563
+ self.global_step,
564
+ last_layer=self.get_last_layer(),
565
+ split="val",
566
+ )
567
+ rec_loss = log_dict_ae["val/rec_loss"]
568
+ self.log(
569
+ "val/rec_loss",
570
+ rec_loss,
571
+ prog_bar=True,
572
+ logger=True,
573
+ on_step=False,
574
+ on_epoch=True,
575
+ sync_dist=True,
576
+ )
577
+ self.log(
578
+ "val/aeloss",
579
+ aeloss,
580
+ prog_bar=True,
581
+ logger=True,
582
+ on_step=False,
583
+ on_epoch=True,
584
+ sync_dist=True,
585
+ )
586
+ self.log_dict(log_dict_ae, sync_dist=True)
587
+ self.log_dict(log_dict_disc, sync_dist=True)
588
+ return self.log_dict
589
+
590
+ def log_images(self, batch, **kwargs):
591
+ log = dict()
592
+ x = self.get_input(batch, self.image_key)
593
+ x = x.to(self.device)
594
+ # encode
595
+ h = self.encoder(x)
596
+ h = self.quant_conv(h)
597
+ quant, _, _ = self.quantize(h)
598
+ # decode
599
+ x_rec = self.decode(quant)
600
+ log["inputs"] = x
601
+ log["reconstructions"] = x_rec
602
+ return log
603
+
604
+
605
+ class EMAVQ(VQModel):
606
+ def __init__(
607
+ self,
608
+ ddconfig,
609
+ lossconfig,
610
+ n_embed,
611
+ embed_dim,
612
+ ckpt_path=None,
613
+ ignore_keys=[],
614
+ image_key="image",
615
+ colorize_nlabels=None,
616
+ monitor=None,
617
+ remap=None,
618
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
619
+ ):
620
+ super().__init__(
621
+ ddconfig,
622
+ lossconfig,
623
+ n_embed,
624
+ embed_dim,
625
+ ckpt_path=None,
626
+ ignore_keys=ignore_keys,
627
+ image_key=image_key,
628
+ colorize_nlabels=colorize_nlabels,
629
+ monitor=monitor,
630
+ )
631
+ self.quantize = EMAVectorQuantizer(
632
+ n_embed=n_embed, embedding_dim=embed_dim, beta=0.25, remap=remap
633
+ )
634
+
635
+ def configure_optimizers(self):
636
+ lr = self.learning_rate
637
+ # Remove self.quantize from parameter list since it is updated via EMA
638
+ opt_ae = torch.optim.Adam(
639
+ list(self.encoder.parameters())
640
+ + list(self.decoder.parameters())
641
+ + list(self.quant_conv.parameters())
642
+ + list(self.post_quant_conv.parameters()),
643
+ lr=lr,
644
+ betas=(0.5, 0.9),
645
+ )
646
+ opt_disc = torch.optim.Adam(
647
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
648
+ )
649
+ return [opt_ae, opt_disc], []
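
For context, a minimal sketch of the encode/decode round trip exposed by VQModel above; vqgan stands for an already-configured instance and x for an image batch, both placeholders.

    # assumption: vqgan is a configured VQModel and x is a [B, C, H, W] float tensor
    quant, emb_loss, info = vqgan.encode(x)   # quantized latents, codebook loss, quantizer info
    indices = info[2]                         # discrete code indices, as used by encode_to_z above
    xrec = vqgan.decode(quant)                # reconstruction with the same shape as x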
taming/modules/autoencoder/lpips/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
taming/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,776 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+
8
+ def get_timestep_embedding(timesteps, embedding_dim):
9
+ """
10
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
11
+ From Fairseq.
12
+ Build sinusoidal embeddings.
13
+ This matches the implementation in tensor2tensor, but differs slightly
14
+ from the description in Section 3.5 of "Attention Is All You Need".
15
+ """
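+ # In other words, with h = embedding_dim // 2 this computes frequencies
+ # w_i = 10000 ** (-i / (h - 1)) and returns, for each timestep t,
+ # [sin(t*w_0), ..., sin(t*w_{h-1}), cos(t*w_0), ..., cos(t*w_{h-1})],
+ # zero-padded with one extra column when embedding_dim is odd.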
16
+ assert len(timesteps.shape) == 1
17
+
18
+ half_dim = embedding_dim // 2
19
+ emb = math.log(10000) / (half_dim - 1)
20
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
21
+ emb = emb.to(device=timesteps.device)
22
+ emb = timesteps.float()[:, None] * emb[None, :]
23
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
24
+ if embedding_dim % 2 == 1: # zero pad
25
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
26
+ return emb
27
+
28
+
29
+ def nonlinearity(x):
30
+ # swish
31
+ return x*torch.sigmoid(x)
32
+
33
+
34
+ def Normalize(in_channels):
35
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36
+
37
+
38
+ class Upsample(nn.Module):
39
+ def __init__(self, in_channels, with_conv):
40
+ super().__init__()
41
+ self.with_conv = with_conv
42
+ if self.with_conv:
43
+ self.conv = torch.nn.Conv2d(in_channels,
44
+ in_channels,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1)
48
+
49
+ def forward(self, x):
50
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ if self.with_conv:
52
+ x = self.conv(x)
53
+ return x
54
+
55
+
56
+ class Downsample(nn.Module):
57
+ def __init__(self, in_channels, with_conv):
58
+ super().__init__()
59
+ self.with_conv = with_conv
60
+ if self.with_conv:
61
+ # no asymmetric padding in torch conv, must do it ourselves
62
+ self.conv = torch.nn.Conv2d(in_channels,
63
+ in_channels,
64
+ kernel_size=3,
65
+ stride=2,
66
+ padding=0)
67
+
68
+ def forward(self, x):
69
+ if self.with_conv:
70
+ pad = (0,1,0,1)
71
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
+ x = self.conv(x)
73
+ else:
74
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
+ return x
76
+
77
+
78
+ class ResnetBlock(nn.Module):
79
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
80
+ dropout, temb_channels=512):
81
+ super().__init__()
82
+ self.in_channels = in_channels
83
+ out_channels = in_channels if out_channels is None else out_channels
84
+ self.out_channels = out_channels
85
+ self.use_conv_shortcut = conv_shortcut
86
+
87
+ self.norm1 = Normalize(in_channels)
88
+ self.conv1 = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1)
93
+ if temb_channels > 0:
94
+ self.temb_proj = torch.nn.Linear(temb_channels,
95
+ out_channels)
96
+ self.norm2 = Normalize(out_channels)
97
+ self.dropout = torch.nn.Dropout(dropout)
98
+ self.conv2 = torch.nn.Conv2d(out_channels,
99
+ out_channels,
100
+ kernel_size=3,
101
+ stride=1,
102
+ padding=1)
103
+ if self.in_channels != self.out_channels:
104
+ if self.use_conv_shortcut:
105
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
106
+ out_channels,
107
+ kernel_size=3,
108
+ stride=1,
109
+ padding=1)
110
+ else:
111
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
112
+ out_channels,
113
+ kernel_size=1,
114
+ stride=1,
115
+ padding=0)
116
+
117
+ def forward(self, x, temb):
118
+ h = x
119
+ h = self.norm1(h)
120
+ h = nonlinearity(h)
121
+ h = self.conv1(h)
122
+
123
+ if temb is not None:
124
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
125
+
126
+ h = self.norm2(h)
127
+ h = nonlinearity(h)
128
+ h = self.dropout(h)
129
+ h = self.conv2(h)
130
+
131
+ if self.in_channels != self.out_channels:
132
+ if self.use_conv_shortcut:
133
+ x = self.conv_shortcut(x)
134
+ else:
135
+ x = self.nin_shortcut(x)
136
+
137
+ return x+h
138
+
139
+
140
+ class AttnBlock(nn.Module):
141
+ def __init__(self, in_channels):
142
+ super().__init__()
143
+ self.in_channels = in_channels
144
+
145
+ self.norm = Normalize(in_channels)
146
+ self.q = torch.nn.Conv2d(in_channels,
147
+ in_channels,
148
+ kernel_size=1,
149
+ stride=1,
150
+ padding=0)
151
+ self.k = torch.nn.Conv2d(in_channels,
152
+ in_channels,
153
+ kernel_size=1,
154
+ stride=1,
155
+ padding=0)
156
+ self.v = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.proj_out = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b,c,h,w = q.shape
177
+ q = q.reshape(b,c,h*w)
178
+ q = q.permute(0,2,1) # b,hw,c
179
+ k = k.reshape(b,c,h*w) # b,c,hw
180
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
+ w_ = w_ * (int(c)**(-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = v.reshape(b,c,h*w)
186
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
187
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
+ h_ = h_.reshape(b,c,h,w)
189
+
190
+ h_ = self.proj_out(h_)
191
+
192
+ return x+h_
193
+
194
+
195
+ class Model(nn.Module):
196
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
197
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
198
+ resolution, use_timestep=True):
199
+ super().__init__()
200
+ self.ch = ch
201
+ self.temb_ch = self.ch*4
202
+ self.num_resolutions = len(ch_mult)
203
+ self.num_res_blocks = num_res_blocks
204
+ self.resolution = resolution
205
+ self.in_channels = in_channels
206
+
207
+ self.use_timestep = use_timestep
208
+ if self.use_timestep:
209
+ # timestep embedding
210
+ self.temb = nn.Module()
211
+ self.temb.dense = nn.ModuleList([
212
+ torch.nn.Linear(self.ch,
213
+ self.temb_ch),
214
+ torch.nn.Linear(self.temb_ch,
215
+ self.temb_ch),
216
+ ])
217
+
218
+ # downsampling
219
+ self.conv_in = torch.nn.Conv2d(in_channels,
220
+ self.ch,
221
+ kernel_size=3,
222
+ stride=1,
223
+ padding=1)
224
+
225
+ curr_res = resolution
226
+ in_ch_mult = (1,)+tuple(ch_mult)
227
+ self.down = nn.ModuleList()
228
+ for i_level in range(self.num_resolutions):
229
+ block = nn.ModuleList()
230
+ attn = nn.ModuleList()
231
+ block_in = ch*in_ch_mult[i_level]
232
+ block_out = ch*ch_mult[i_level]
233
+ for i_block in range(self.num_res_blocks):
234
+ block.append(ResnetBlock(in_channels=block_in,
235
+ out_channels=block_out,
236
+ temb_channels=self.temb_ch,
237
+ dropout=dropout))
238
+ block_in = block_out
239
+ if curr_res in attn_resolutions:
240
+ attn.append(AttnBlock(block_in))
241
+ down = nn.Module()
242
+ down.block = block
243
+ down.attn = attn
244
+ if i_level != self.num_resolutions-1:
245
+ down.downsample = Downsample(block_in, resamp_with_conv)
246
+ curr_res = curr_res // 2
247
+ self.down.append(down)
248
+
249
+ # middle
250
+ self.mid = nn.Module()
251
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
252
+ out_channels=block_in,
253
+ temb_channels=self.temb_ch,
254
+ dropout=dropout)
255
+ self.mid.attn_1 = AttnBlock(block_in)
256
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
257
+ out_channels=block_in,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout)
260
+
261
+ # upsampling
262
+ self.up = nn.ModuleList()
263
+ for i_level in reversed(range(self.num_resolutions)):
264
+ block = nn.ModuleList()
265
+ attn = nn.ModuleList()
266
+ block_out = ch*ch_mult[i_level]
267
+ skip_in = ch*ch_mult[i_level]
268
+ for i_block in range(self.num_res_blocks+1):
269
+ if i_block == self.num_res_blocks:
270
+ skip_in = ch*in_ch_mult[i_level]
271
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
272
+ out_channels=block_out,
273
+ temb_channels=self.temb_ch,
274
+ dropout=dropout))
275
+ block_in = block_out
276
+ if curr_res in attn_resolutions:
277
+ attn.append(AttnBlock(block_in))
278
+ up = nn.Module()
279
+ up.block = block
280
+ up.attn = attn
281
+ if i_level != 0:
282
+ up.upsample = Upsample(block_in, resamp_with_conv)
283
+ curr_res = curr_res * 2
284
+ self.up.insert(0, up) # prepend to get consistent order
285
+
286
+ # end
287
+ self.norm_out = Normalize(block_in)
288
+ self.conv_out = torch.nn.Conv2d(block_in,
289
+ out_ch,
290
+ kernel_size=3,
291
+ stride=1,
292
+ padding=1)
293
+
294
+
295
+ def forward(self, x, t=None):
296
+ #assert x.shape[2] == x.shape[3] == self.resolution
297
+
298
+ if self.use_timestep:
299
+ # timestep embedding
300
+ assert t is not None
301
+ temb = get_timestep_embedding(t, self.ch)
302
+ temb = self.temb.dense[0](temb)
303
+ temb = nonlinearity(temb)
304
+ temb = self.temb.dense[1](temb)
305
+ else:
306
+ temb = None
307
+
308
+ # downsampling
309
+ hs = [self.conv_in(x)]
310
+ for i_level in range(self.num_resolutions):
311
+ for i_block in range(self.num_res_blocks):
312
+ h = self.down[i_level].block[i_block](hs[-1], temb)
313
+ if len(self.down[i_level].attn) > 0:
314
+ h = self.down[i_level].attn[i_block](h)
315
+ hs.append(h)
316
+ if i_level != self.num_resolutions-1:
317
+ hs.append(self.down[i_level].downsample(hs[-1]))
318
+
319
+ # middle
320
+ h = hs[-1]
321
+ h = self.mid.block_1(h, temb)
322
+ h = self.mid.attn_1(h)
323
+ h = self.mid.block_2(h, temb)
324
+
325
+ # upsampling
326
+ for i_level in reversed(range(self.num_resolutions)):
327
+ for i_block in range(self.num_res_blocks+1):
328
+ h = self.up[i_level].block[i_block](
329
+ torch.cat([h, hs.pop()], dim=1), temb)
330
+ if len(self.up[i_level].attn) > 0:
331
+ h = self.up[i_level].attn[i_block](h)
332
+ if i_level != 0:
333
+ h = self.up[i_level].upsample(h)
334
+
335
+ # end
336
+ h = self.norm_out(h)
337
+ h = nonlinearity(h)
338
+ h = self.conv_out(h)
339
+ return h
340
+
341
+
342
+ class Encoder(nn.Module):
343
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
344
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
345
+ resolution, z_channels, double_z=True, **ignore_kwargs):
346
+ super().__init__()
347
+ self.ch = ch
348
+ self.temb_ch = 0
349
+ self.num_resolutions = len(ch_mult)
350
+ self.num_res_blocks = num_res_blocks
351
+ self.resolution = resolution
352
+ self.in_channels = in_channels
353
+
354
+ # downsampling
355
+ self.conv_in = torch.nn.Conv2d(in_channels,
356
+ self.ch,
357
+ kernel_size=3,
358
+ stride=1,
359
+ padding=1)
360
+
361
+ curr_res = resolution
362
+ in_ch_mult = (1,)+tuple(ch_mult)
363
+ self.down = nn.ModuleList()
364
+ for i_level in range(self.num_resolutions):
365
+ block = nn.ModuleList()
366
+ attn = nn.ModuleList()
367
+ block_in = ch*in_ch_mult[i_level]
368
+ block_out = ch*ch_mult[i_level]
369
+ for i_block in range(self.num_res_blocks):
370
+ block.append(ResnetBlock(in_channels=block_in,
371
+ out_channels=block_out,
372
+ temb_channels=self.temb_ch,
373
+ dropout=dropout))
374
+ block_in = block_out
375
+ if curr_res in attn_resolutions:
376
+ attn.append(AttnBlock(block_in))
377
+ down = nn.Module()
378
+ down.block = block
379
+ down.attn = attn
380
+ if i_level != self.num_resolutions-1:
381
+ down.downsample = Downsample(block_in, resamp_with_conv)
382
+ curr_res = curr_res // 2
383
+ self.down.append(down)
384
+
385
+ # middle
386
+ self.mid = nn.Module()
387
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
388
+ out_channels=block_in,
389
+ temb_channels=self.temb_ch,
390
+ dropout=dropout)
391
+ self.mid.attn_1 = AttnBlock(block_in)
392
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
393
+ out_channels=block_in,
394
+ temb_channels=self.temb_ch,
395
+ dropout=dropout)
396
+
397
+ # end
398
+ self.norm_out = Normalize(block_in)
399
+ self.conv_out = torch.nn.Conv2d(block_in,
400
+ 2*z_channels if double_z else z_channels,
401
+ kernel_size=3,
402
+ stride=1,
403
+ padding=1)
404
+
405
+
406
+ def forward(self, x):
407
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
408
+
409
+ # timestep embedding
410
+ temb = None
411
+
412
+ # downsampling
413
+ hs = [self.conv_in(x)]
414
+ for i_level in range(self.num_resolutions):
415
+ for i_block in range(self.num_res_blocks):
416
+ h = self.down[i_level].block[i_block](hs[-1], temb)
417
+ if len(self.down[i_level].attn) > 0:
418
+ h = self.down[i_level].attn[i_block](h)
419
+ hs.append(h)
420
+ if i_level != self.num_resolutions-1:
421
+ hs.append(self.down[i_level].downsample(hs[-1]))
422
+
423
+ # middle
424
+ h = hs[-1]
425
+ h = self.mid.block_1(h, temb)
426
+ h = self.mid.attn_1(h)
427
+ h = self.mid.block_2(h, temb)
428
+
429
+ # end
430
+ h = self.norm_out(h)
431
+ h = nonlinearity(h)
432
+ h = self.conv_out(h)
433
+ return h
434
+
435
+
436
+ class Decoder(nn.Module):
437
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
438
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
439
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
440
+ super().__init__()
441
+ self.ch = ch
442
+ self.temb_ch = 0
443
+ self.num_resolutions = len(ch_mult)
444
+ self.num_res_blocks = num_res_blocks
445
+ self.resolution = resolution
446
+ self.in_channels = in_channels
447
+ self.give_pre_end = give_pre_end
448
+
449
+ # compute in_ch_mult, block_in and curr_res at lowest res
450
+ in_ch_mult = (1,)+tuple(ch_mult)
451
+ block_in = ch*ch_mult[self.num_resolutions-1]
452
+ curr_res = resolution // 2**(self.num_resolutions-1)
453
+ self.z_shape = (1,z_channels,curr_res,curr_res)
454
+ print("Working with z of shape {} = {} dimensions.".format(
455
+ self.z_shape, np.prod(self.z_shape)))
456
+
457
+ # z to block_in
458
+ self.conv_in = torch.nn.Conv2d(z_channels,
459
+ block_in,
460
+ kernel_size=3,
461
+ stride=1,
462
+ padding=1)
463
+
464
+ # middle
465
+ self.mid = nn.Module()
466
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
467
+ out_channels=block_in,
468
+ temb_channels=self.temb_ch,
469
+ dropout=dropout)
470
+ self.mid.attn_1 = AttnBlock(block_in)
471
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
472
+ out_channels=block_in,
473
+ temb_channels=self.temb_ch,
474
+ dropout=dropout)
475
+
476
+ # upsampling
477
+ self.up = nn.ModuleList()
478
+ for i_level in reversed(range(self.num_resolutions)):
479
+ block = nn.ModuleList()
480
+ attn = nn.ModuleList()
481
+ block_out = ch*ch_mult[i_level]
482
+ for i_block in range(self.num_res_blocks+1):
483
+ block.append(ResnetBlock(in_channels=block_in,
484
+ out_channels=block_out,
485
+ temb_channels=self.temb_ch,
486
+ dropout=dropout))
487
+ block_in = block_out
488
+ if curr_res in attn_resolutions:
489
+ attn.append(AttnBlock(block_in))
490
+ up = nn.Module()
491
+ up.block = block
492
+ up.attn = attn
493
+ if i_level != 0:
494
+ up.upsample = Upsample(block_in, resamp_with_conv)
495
+ curr_res = curr_res * 2
496
+ self.up.insert(0, up) # prepend to get consistent order
497
+
498
+ # end
499
+ self.norm_out = Normalize(block_in)
500
+ self.conv_out = torch.nn.Conv2d(block_in,
501
+ out_ch,
502
+ kernel_size=3,
503
+ stride=1,
504
+ padding=1)
505
+
506
+ def forward(self, z):
507
+ #assert z.shape[1:] == self.z_shape[1:]
508
+ self.last_z_shape = z.shape
509
+
510
+ # timestep embedding
511
+ temb = None
512
+
513
+ # z to block_in
514
+ h = self.conv_in(z)
515
+
516
+ # middle
517
+ h = self.mid.block_1(h, temb)
518
+ h = self.mid.attn_1(h)
519
+ h = self.mid.block_2(h, temb)
520
+
521
+ # upsampling
522
+ for i_level in reversed(range(self.num_resolutions)):
523
+ for i_block in range(self.num_res_blocks+1):
524
+ h = self.up[i_level].block[i_block](h, temb)
525
+ if len(self.up[i_level].attn) > 0:
526
+ h = self.up[i_level].attn[i_block](h)
527
+ if i_level != 0:
528
+ h = self.up[i_level].upsample(h)
529
+
530
+ # end
531
+ if self.give_pre_end:
532
+ return h
533
+
534
+ h = self.norm_out(h)
535
+ h = nonlinearity(h)
536
+ h = self.conv_out(h)
537
+ return h
538
+
539
+
540
+ class VUNet(nn.Module):
541
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
542
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
543
+ in_channels, c_channels,
544
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
545
+ super().__init__()
546
+ self.ch = ch
547
+ self.temb_ch = self.ch*4
548
+ self.num_resolutions = len(ch_mult)
549
+ self.num_res_blocks = num_res_blocks
550
+ self.resolution = resolution
551
+
552
+ self.use_timestep = use_timestep
553
+ if self.use_timestep:
554
+ # timestep embedding
555
+ self.temb = nn.Module()
556
+ self.temb.dense = nn.ModuleList([
557
+ torch.nn.Linear(self.ch,
558
+ self.temb_ch),
559
+ torch.nn.Linear(self.temb_ch,
560
+ self.temb_ch),
561
+ ])
562
+
563
+ # downsampling
564
+ self.conv_in = torch.nn.Conv2d(c_channels,
565
+ self.ch,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1)
569
+
570
+ curr_res = resolution
571
+ in_ch_mult = (1,)+tuple(ch_mult)
572
+ self.down = nn.ModuleList()
573
+ for i_level in range(self.num_resolutions):
574
+ block = nn.ModuleList()
575
+ attn = nn.ModuleList()
576
+ block_in = ch*in_ch_mult[i_level]
577
+ block_out = ch*ch_mult[i_level]
578
+ for i_block in range(self.num_res_blocks):
579
+ block.append(ResnetBlock(in_channels=block_in,
580
+ out_channels=block_out,
581
+ temb_channels=self.temb_ch,
582
+ dropout=dropout))
583
+ block_in = block_out
584
+ if curr_res in attn_resolutions:
585
+ attn.append(AttnBlock(block_in))
586
+ down = nn.Module()
587
+ down.block = block
588
+ down.attn = attn
589
+ if i_level != self.num_resolutions-1:
590
+ down.downsample = Downsample(block_in, resamp_with_conv)
591
+ curr_res = curr_res // 2
592
+ self.down.append(down)
593
+
594
+ self.z_in = torch.nn.Conv2d(z_channels,
595
+ block_in,
596
+ kernel_size=1,
597
+ stride=1,
598
+ padding=0)
599
+ # middle
600
+ self.mid = nn.Module()
601
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
602
+ out_channels=block_in,
603
+ temb_channels=self.temb_ch,
604
+ dropout=dropout)
605
+ self.mid.attn_1 = AttnBlock(block_in)
606
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
607
+ out_channels=block_in,
608
+ temb_channels=self.temb_ch,
609
+ dropout=dropout)
610
+
611
+ # upsampling
612
+ self.up = nn.ModuleList()
613
+ for i_level in reversed(range(self.num_resolutions)):
614
+ block = nn.ModuleList()
615
+ attn = nn.ModuleList()
616
+ block_out = ch*ch_mult[i_level]
617
+ skip_in = ch*ch_mult[i_level]
618
+ for i_block in range(self.num_res_blocks+1):
619
+ if i_block == self.num_res_blocks:
620
+ skip_in = ch*in_ch_mult[i_level]
621
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
622
+ out_channels=block_out,
623
+ temb_channels=self.temb_ch,
624
+ dropout=dropout))
625
+ block_in = block_out
626
+ if curr_res in attn_resolutions:
627
+ attn.append(AttnBlock(block_in))
628
+ up = nn.Module()
629
+ up.block = block
630
+ up.attn = attn
631
+ if i_level != 0:
632
+ up.upsample = Upsample(block_in, resamp_with_conv)
633
+ curr_res = curr_res * 2
634
+ self.up.insert(0, up) # prepend to get consistent order
635
+
636
+ # end
637
+ self.norm_out = Normalize(block_in)
638
+ self.conv_out = torch.nn.Conv2d(block_in,
639
+ out_ch,
640
+ kernel_size=3,
641
+ stride=1,
642
+ padding=1)
643
+
644
+
645
+ def forward(self, x, z):
646
+ #assert x.shape[2] == x.shape[3] == self.resolution
647
+
648
+ if self.use_timestep:
649
+ # timestep embedding
650
+ assert t is not None
651
+ temb = get_timestep_embedding(t, self.ch)
652
+ temb = self.temb.dense[0](temb)
653
+ temb = nonlinearity(temb)
654
+ temb = self.temb.dense[1](temb)
655
+ else:
656
+ temb = None
657
+
658
+ # downsampling
659
+ hs = [self.conv_in(x)]
660
+ for i_level in range(self.num_resolutions):
661
+ for i_block in range(self.num_res_blocks):
662
+ h = self.down[i_level].block[i_block](hs[-1], temb)
663
+ if len(self.down[i_level].attn) > 0:
664
+ h = self.down[i_level].attn[i_block](h)
665
+ hs.append(h)
666
+ if i_level != self.num_resolutions-1:
667
+ hs.append(self.down[i_level].downsample(hs[-1]))
668
+
669
+ # middle
670
+ h = hs[-1]
671
+ z = self.z_in(z)
672
+ h = torch.cat((h,z),dim=1)
673
+ h = self.mid.block_1(h, temb)
674
+ h = self.mid.attn_1(h)
675
+ h = self.mid.block_2(h, temb)
676
+
677
+ # upsampling
678
+ for i_level in reversed(range(self.num_resolutions)):
679
+ for i_block in range(self.num_res_blocks+1):
680
+ h = self.up[i_level].block[i_block](
681
+ torch.cat([h, hs.pop()], dim=1), temb)
682
+ if len(self.up[i_level].attn) > 0:
683
+ h = self.up[i_level].attn[i_block](h)
684
+ if i_level != 0:
685
+ h = self.up[i_level].upsample(h)
686
+
687
+ # end
688
+ h = self.norm_out(h)
689
+ h = nonlinearity(h)
690
+ h = self.conv_out(h)
691
+ return h
692
+
693
+
694
+ class SimpleDecoder(nn.Module):
695
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
696
+ super().__init__()
697
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
698
+ ResnetBlock(in_channels=in_channels,
699
+ out_channels=2 * in_channels,
700
+ temb_channels=0, dropout=0.0),
701
+ ResnetBlock(in_channels=2 * in_channels,
702
+ out_channels=4 * in_channels,
703
+ temb_channels=0, dropout=0.0),
704
+ ResnetBlock(in_channels=4 * in_channels,
705
+ out_channels=2 * in_channels,
706
+ temb_channels=0, dropout=0.0),
707
+ nn.Conv2d(2*in_channels, in_channels, 1),
708
+ Upsample(in_channels, with_conv=True)])
709
+ # end
710
+ self.norm_out = Normalize(in_channels)
711
+ self.conv_out = torch.nn.Conv2d(in_channels,
712
+ out_channels,
713
+ kernel_size=3,
714
+ stride=1,
715
+ padding=1)
716
+
717
+ def forward(self, x):
718
+ for i, layer in enumerate(self.model):
719
+ if i in [1,2,3]:
720
+ x = layer(x, None)
721
+ else:
722
+ x = layer(x)
723
+
724
+ h = self.norm_out(x)
725
+ h = nonlinearity(h)
726
+ x = self.conv_out(h)
727
+ return x
728
+
729
+
730
+ class UpsampleDecoder(nn.Module):
731
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
732
+ ch_mult=(2,2), dropout=0.0):
733
+ super().__init__()
734
+ # upsampling
735
+ self.temb_ch = 0
736
+ self.num_resolutions = len(ch_mult)
737
+ self.num_res_blocks = num_res_blocks
738
+ block_in = in_channels
739
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
740
+ self.res_blocks = nn.ModuleList()
741
+ self.upsample_blocks = nn.ModuleList()
742
+ for i_level in range(self.num_resolutions):
743
+ res_block = []
744
+ block_out = ch * ch_mult[i_level]
745
+ for i_block in range(self.num_res_blocks + 1):
746
+ res_block.append(ResnetBlock(in_channels=block_in,
747
+ out_channels=block_out,
748
+ temb_channels=self.temb_ch,
749
+ dropout=dropout))
750
+ block_in = block_out
751
+ self.res_blocks.append(nn.ModuleList(res_block))
752
+ if i_level != self.num_resolutions - 1:
753
+ self.upsample_blocks.append(Upsample(block_in, True))
754
+ curr_res = curr_res * 2
755
+
756
+ # end
757
+ self.norm_out = Normalize(block_in)
758
+ self.conv_out = torch.nn.Conv2d(block_in,
759
+ out_channels,
760
+ kernel_size=3,
761
+ stride=1,
762
+ padding=1)
763
+
764
+ def forward(self, x):
765
+ # upsampling
766
+ h = x
767
+ for k, i_level in enumerate(range(self.num_resolutions)):
768
+ for i_block in range(self.num_res_blocks + 1):
769
+ h = self.res_blocks[i_level][i_block](h, None)
770
+ if i_level != self.num_resolutions - 1:
771
+ h = self.upsample_blocks[k](h)
772
+ h = self.norm_out(h)
773
+ h = nonlinearity(h)
774
+ h = self.conv_out(h)
775
+ return h
776
+
taming/modules/discriminator/model.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+ from taming.modules.util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find('Conv') != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find('BatchNorm') != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22
+ """Construct a PatchGAN discriminator
23
+ Parameters:
24
+ input_nc (int) -- the number of channels in input images
25
+ ndf (int) -- the number of filters in the last conv layer
26
+ n_layers (int) -- the number of conv layers in the discriminator
27
+ norm_layer -- normalization layer
28
+ """
29
+ super(NLayerDiscriminator, self).__init__()
30
+ if not use_actnorm:
31
+ norm_layer = nn.BatchNorm2d
32
+ else:
33
+ norm_layer = ActNorm
34
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
+ use_bias = norm_layer.func != nn.BatchNorm2d
36
+ else:
37
+ use_bias = norm_layer != nn.BatchNorm2d
38
+
39
+ kw = 4
40
+ padw = 1
41
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
+ nf_mult = 1
43
+ nf_mult_prev = 1
44
+ for n in range(1, n_layers): # gradually increase the number of filters
45
+ nf_mult_prev = nf_mult
46
+ nf_mult = min(2 ** n, 8)
47
+ sequence += [
48
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
+ norm_layer(ndf * nf_mult),
50
+ nn.LeakyReLU(0.2, True)
51
+ ]
52
+
53
+ nf_mult_prev = nf_mult
54
+ nf_mult = min(2 ** n_layers, 8)
55
+ sequence += [
56
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
+ norm_layer(ndf * nf_mult),
58
+ nn.LeakyReLU(0.2, True)
59
+ ]
60
+
61
+ sequence += [
62
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
+ self.main = nn.Sequential(*sequence)
64
+
65
+ def forward(self, input):
66
+ """Standard forward."""
67
+ return self.main(input)
taming/modules/losses/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from taming.modules.losses.vqperceptual import DummyLoss
2
+
taming/modules/losses/lpips.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+ from collections import namedtuple
7
+
8
+ from taming.util import get_ckpt_path
9
+
10
+
11
+ class LPIPS(nn.Module):
12
+ # Learned perceptual metric
13
+ def __init__(self, use_dropout=True):
14
+ super().__init__()
15
+ self.scaling_layer = ScalingLayer()
16
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
17
+ self.net = vgg16(pretrained=True, requires_grad=False)
18
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23
+ self.load_from_pretrained()
24
+ for param in self.parameters():
25
+ param.requires_grad = False
26
+
27
+ def load_from_pretrained(self, name="vgg_lpips"):
28
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
31
+
32
+ @classmethod
33
+ def from_pretrained(cls, name="vgg_lpips"):
34
+ if name != "vgg_lpips":
35
+ raise NotImplementedError
36
+ model = cls()
37
+ ckpt = get_ckpt_path(name)
38
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39
+ return model
40
+
41
+ def forward(self, input, target):
42
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
44
+ feats0, feats1, diffs = {}, {}, {}
45
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46
+ for kk in range(len(self.chns)):
47
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49
+
50
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51
+ val = res[0]
52
+ for l in range(1, len(self.chns)):
53
+ val += res[l]
54
+ return val
55
+
56
+
57
+ class ScalingLayer(nn.Module):
58
+ def __init__(self):
59
+ super(ScalingLayer, self).__init__()
60
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62
+
63
+ def forward(self, inp):
64
+ return (inp - self.shift) / self.scale
65
+
66
+
67
+ class NetLinLayer(nn.Module):
68
+ """ A single linear layer which does a 1x1 conv """
69
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
70
+ super(NetLinLayer, self).__init__()
71
+ layers = [nn.Dropout(), ] if (use_dropout) else []
72
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73
+ self.model = nn.Sequential(*layers)
74
+
75
+
76
+ class vgg16(torch.nn.Module):
77
+ def __init__(self, requires_grad=False, pretrained=True):
78
+ super(vgg16, self).__init__()
79
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80
+ self.slice1 = torch.nn.Sequential()
81
+ self.slice2 = torch.nn.Sequential()
82
+ self.slice3 = torch.nn.Sequential()
83
+ self.slice4 = torch.nn.Sequential()
84
+ self.slice5 = torch.nn.Sequential()
85
+ self.N_slices = 5
86
+ for x in range(4):
87
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
88
+ for x in range(4, 9):
89
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
90
+ for x in range(9, 16):
91
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
92
+ for x in range(16, 23):
93
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
94
+ for x in range(23, 30):
95
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
96
+ if not requires_grad:
97
+ for param in self.parameters():
98
+ param.requires_grad = False
99
+
100
+ def forward(self, X):
101
+ h = self.slice1(X)
102
+ h_relu1_2 = h
103
+ h = self.slice2(h)
104
+ h_relu2_2 = h
105
+ h = self.slice3(h)
106
+ h_relu3_3 = h
107
+ h = self.slice4(h)
108
+ h_relu4_3 = h
109
+ h = self.slice5(h)
110
+ h_relu5_3 = h
111
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113
+ return out
114
+
115
+
116
+ def normalize_tensor(x,eps=1e-10):
117
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118
+ return x/(norm_factor+eps)
119
+
120
+
121
+ def spatial_average(x, keepdim=True):
122
+ return x.mean([2,3],keepdim=keepdim)
123
+
taming/modules/losses/segmentation.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class BCELoss(nn.Module):
6
+ def forward(self, prediction, target):
7
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
8
+ return loss, {}
9
+
10
+
11
+ class BCELossWithQuant(nn.Module):
12
+ def __init__(self, codebook_weight=1.):
13
+ super().__init__()
14
+ self.codebook_weight = codebook_weight
15
+
16
+ def forward(self, qloss, target, prediction, split):
17
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18
+ loss = bce_loss + self.codebook_weight*qloss
19
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
21
+ "{}/quant_loss".format(split): qloss.detach().mean()
22
+ }
taming/modules/losses/vqperceptual.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from taming.modules.losses.lpips import LPIPS
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+
8
+
9
+ class DummyLoss(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+
14
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
15
+ if global_step < threshold:
16
+ weight = value
17
+ return weight
18
+
19
+
20
+ def hinge_d_loss(logits_real, logits_fake):
21
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
22
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
23
+ d_loss = 0.5 * (loss_real + loss_fake)
24
+ return d_loss
25
+
26
+
27
+ def vanilla_d_loss(logits_real, logits_fake):
28
+ d_loss = 0.5 * (
29
+ torch.mean(torch.nn.functional.softplus(-logits_real))
30
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
31
+ )
32
+ return d_loss
33
+
34
+
35
+ class VQLPIPSWithDiscriminator(nn.Module):
36
+ def __init__(
37
+ self,
38
+ disc_start,
39
+ codebook_weight=1.0,
40
+ pixelloss_weight=1.0,
41
+ disc_num_layers=3,
42
+ disc_in_channels=3,
43
+ disc_factor=1.0,
44
+ disc_weight=1.0,
45
+ perceptual_weight=1.0,
46
+ use_actnorm=False,
47
+ disc_conditional=False,
48
+ disc_ndf=64,
49
+ disc_loss="hinge",
50
+ ):
51
+ super().__init__()
52
+ assert disc_loss in ["hinge", "vanilla"]
53
+ self.codebook_weight = codebook_weight
54
+ self.pixel_weight = pixelloss_weight
55
+ self.perceptual_loss = LPIPS().eval()
56
+ self.perceptual_weight = perceptual_weight
57
+
58
+ self.discriminator = NLayerDiscriminator(
59
+ input_nc=disc_in_channels,
60
+ n_layers=disc_num_layers,
61
+ use_actnorm=use_actnorm,
62
+ ndf=disc_ndf,
63
+ ).apply(weights_init)
64
+ self.discriminator_iter_start = disc_start
65
+ if disc_loss == "hinge":
66
+ self.disc_loss = hinge_d_loss
67
+ elif disc_loss == "vanilla":
68
+ self.disc_loss = vanilla_d_loss
69
+ else:
70
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
71
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
72
+ self.disc_factor = disc_factor
73
+ self.discriminator_weight = disc_weight
74
+ self.disc_conditional = disc_conditional
75
+
76
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
77
+ if last_layer is not None:
78
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
79
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
80
+ else:
81
+ nll_grads = torch.autograd.grad(
82
+ nll_loss, self.last_layer[0], retain_graph=True
83
+ )[0]
84
+ g_grads = torch.autograd.grad(
85
+ g_loss, self.last_layer[0], retain_graph=True
86
+ )[0]
87
+
88
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
89
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
90
+ d_weight = d_weight * self.discriminator_weight
91
+ return d_weight
92
+
93
+ def forward(
94
+ self,
95
+ codebook_loss,
96
+ inputs,
97
+ reconstructions,
98
+ optimizer_idx,
99
+ global_step,
100
+ last_layer=None,
101
+ cond=None,
102
+ split="train",
103
+ ):
104
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
105
+ if self.perceptual_weight > 0:
106
+ p_loss = self.perceptual_loss(
107
+ inputs.contiguous(), reconstructions.contiguous()
108
+ )
109
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
110
+ else:
111
+ p_loss = torch.tensor([0.0])
112
+
113
+ nll_loss = rec_loss
114
+ # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
115
+ nll_loss = torch.mean(nll_loss)
116
+
117
+ # now the GAN part
118
+ if optimizer_idx == 0:
119
+ # generator update
120
+ if cond is None:
121
+ assert not self.disc_conditional
122
+ logits_fake = self.discriminator(reconstructions.contiguous())
123
+ else:
124
+ assert self.disc_conditional
125
+ logits_fake = self.discriminator(
126
+ torch.cat((reconstructions.contiguous(), cond), dim=1)
127
+ )
128
+ g_loss = -torch.mean(logits_fake)
129
+
130
+ try:
131
+ d_weight = self.calculate_adaptive_weight(
132
+ nll_loss, g_loss, last_layer=last_layer
133
+ )
134
+ except RuntimeError:
135
+ assert not self.training
136
+ d_weight = torch.tensor(0.0)
137
+
138
+ disc_factor = adopt_weight(
139
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
140
+ )
141
+ loss = (
142
+ nll_loss
143
+ + d_weight * disc_factor * g_loss
144
+ + self.codebook_weight * codebook_loss.mean()
145
+ )
146
+
147
+ log = {
148
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
149
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
150
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
151
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
152
+ "{}/p_loss".format(split): p_loss.detach().mean(),
153
+ "{}/d_weight".format(split): d_weight.detach(),
154
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
155
+ "{}/g_loss".format(split): g_loss.detach().mean(),
156
+ }
157
+ return loss, log
158
+
159
+ if optimizer_idx == 1:
160
+ # second pass for discriminator update
161
+ if cond is None:
162
+ logits_real = self.discriminator(inputs.contiguous().detach())
163
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
164
+ else:
165
+ logits_real = self.discriminator(
166
+ torch.cat((inputs.contiguous().detach(), cond), dim=1)
167
+ )
168
+ logits_fake = self.discriminator(
169
+ torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
170
+ )
171
+
172
+ disc_factor = adopt_weight(
173
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
174
+ )
175
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
176
+
177
+ log = {
178
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
179
+ "{}/logits_real".format(split): logits_real.detach().mean(),
180
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
181
+ }
182
+ return d_loss, log
taming/modules/misc/coord.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class CoordStage(object):
4
+ def __init__(self, n_embed, down_factor):
5
+ self.n_embed = n_embed
6
+ self.down_factor = down_factor
7
+
8
+ def eval(self):
9
+ return self
10
+
11
+ def encode(self, c):
12
+ """fake vqmodel interface"""
13
+ assert 0.0 <= c.min() and c.max() <= 1.0
14
+ b,ch,h,w = c.shape
15
+ assert ch == 1
16
+
17
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18
+ mode="area")
19
+ c = c.clamp(0.0, 1.0)
20
+ c = self.n_embed*c
21
+ c_quant = c.round()
22
+ c_ind = c_quant.to(dtype=torch.long)
23
+
24
+ info = None, None, c_ind
25
+ return c_quant, None, info
26
+
27
+ def decode(self, c):
28
+ c = c/self.n_embed
29
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30
+ mode="nearest")
31
+ return c
taming/modules/transformer/mingpt.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ taken from: https://github.com/karpathy/minGPT/
3
+ GPT model:
4
+ - the initial stem consists of a combination of token encoding and a positional encoding
5
+ - the meat of it is a uniform sequence of Transformer blocks
6
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
7
+ - all blocks feed into a central residual pathway similar to resnets
8
+ - the final decoder is a linear projection into a vanilla Softmax classifier
9
+ """
10
+
11
+ import math
12
+ import logging
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+ from transformers import top_k_top_p_filtering
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class GPTConfig:
23
+ """ base GPT config, params common to all GPT versions """
24
+ embd_pdrop = 0.1
25
+ resid_pdrop = 0.1
26
+ attn_pdrop = 0.1
27
+
28
+ def __init__(self, vocab_size, block_size, **kwargs):
29
+ self.vocab_size = vocab_size
30
+ self.block_size = block_size
31
+ for k,v in kwargs.items():
32
+ setattr(self, k, v)
33
+
34
+
35
+ class GPT1Config(GPTConfig):
36
+ """ GPT-1 like network roughly 125M params """
37
+ n_layer = 12
38
+ n_head = 12
39
+ n_embd = 768
40
+
41
+
42
+ class CausalSelfAttention(nn.Module):
43
+ """
44
+ A vanilla multi-head masked self-attention layer with a projection at the end.
45
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
46
+ explicit implementation here to show that there is nothing too scary here.
47
+ """
48
+
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ assert config.n_embd % config.n_head == 0
52
+ # key, query, value projections for all heads
53
+ self.key = nn.Linear(config.n_embd, config.n_embd)
54
+ self.query = nn.Linear(config.n_embd, config.n_embd)
55
+ self.value = nn.Linear(config.n_embd, config.n_embd)
56
+ # regularization
57
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
58
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
59
+ # output projection
60
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
61
+ # causal mask to ensure that attention is only applied to the left in the input sequence
62
+ mask = torch.tril(torch.ones(config.block_size,
63
+ config.block_size))
64
+ if hasattr(config, "n_unmasked"):
65
+ mask[:config.n_unmasked, :config.n_unmasked] = 1
66
+ self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
67
+ self.n_head = config.n_head
68
+
69
+ def forward(self, x, layer_past=None):
70
+ B, T, C = x.size()
71
+
72
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
73
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
74
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
75
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
76
+
77
+ present = torch.stack((k, v))
78
+ if layer_past is not None:
79
+ past_key, past_value = layer_past
80
+ k = torch.cat((past_key, k), dim=-2)
81
+ v = torch.cat((past_value, v), dim=-2)
82
+
83
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
84
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
85
+ if layer_past is None:
86
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
87
+
88
+ att = F.softmax(att, dim=-1)
89
+ att = self.attn_drop(att)
90
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
91
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
92
+
93
+ # output projection
94
+ y = self.resid_drop(self.proj(y))
95
+ return y, present # TODO: check that this does not break anything
96
+
97
+
98
+ class Block(nn.Module):
99
+ """ an unassuming Transformer block """
100
+ def __init__(self, config):
101
+ super().__init__()
102
+ self.ln1 = nn.LayerNorm(config.n_embd)
103
+ self.ln2 = nn.LayerNorm(config.n_embd)
104
+ self.attn = CausalSelfAttention(config)
105
+ self.mlp = nn.Sequential(
106
+ nn.Linear(config.n_embd, 4 * config.n_embd),
107
+ nn.GELU(), # nice
108
+ nn.Linear(4 * config.n_embd, config.n_embd),
109
+ nn.Dropout(config.resid_pdrop),
110
+ )
111
+
112
+ def forward(self, x, layer_past=None, return_present=False):
113
+ # TODO: check that training still works
114
+ if return_present: assert not self.training
115
+ # layer past: tuple of length two with B, nh, T, hs
116
+ attn, present = self.attn(self.ln1(x), layer_past=layer_past)
117
+
118
+ x = x + attn
119
+ x = x + self.mlp(self.ln2(x))
120
+ if layer_past is not None or return_present:
121
+ return x, present
122
+ return x
123
+
124
+
125
+ class GPT(nn.Module):
126
+ """ the full GPT language model, with a context size of block_size """
127
+ def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
128
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
129
+ super().__init__()
130
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
131
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
132
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
133
+ n_unmasked=n_unmasked)
134
+ # input embedding stem
135
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
136
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
137
+ self.drop = nn.Dropout(config.embd_pdrop)
138
+ # transformer
139
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
140
+ # decoder head
141
+ self.ln_f = nn.LayerNorm(config.n_embd)
142
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
143
+ self.block_size = config.block_size
144
+ self.apply(self._init_weights)
145
+ self.config = config
146
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
147
+
148
+ def get_block_size(self):
149
+ return self.block_size
150
+
151
+ def _init_weights(self, module):
152
+ if isinstance(module, (nn.Linear, nn.Embedding)):
153
+ module.weight.data.normal_(mean=0.0, std=0.02)
154
+ if isinstance(module, nn.Linear) and module.bias is not None:
155
+ module.bias.data.zero_()
156
+ elif isinstance(module, nn.LayerNorm):
157
+ module.bias.data.zero_()
158
+ module.weight.data.fill_(1.0)
159
+
160
+ def forward(self, idx, embeddings=None, targets=None):
161
+ # forward the GPT model
162
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
163
+
164
+ if embeddings is not None: # prepend explicit embeddings
165
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
166
+
167
+ t = token_embeddings.shape[1]
168
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
169
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
170
+ x = self.drop(token_embeddings + position_embeddings)
171
+ x = self.blocks(x)
172
+ x = self.ln_f(x)
173
+ logits = self.head(x)
174
+
175
+ # if we are given some desired targets also calculate the loss
176
+ loss = None
177
+ if targets is not None:
178
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
179
+
180
+ return logits, loss
181
+
182
+ def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
183
+ # inference only
184
+ assert not self.training
185
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
186
+ if embeddings is not None: # prepend explicit embeddings
187
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
188
+
189
+ if past is not None:
190
+ assert past_length is not None
191
+ past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
192
+ past_shape = list(past.shape)
193
+ expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
194
+ assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
195
+ position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
196
+ else:
197
+ position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
198
+
199
+ x = self.drop(token_embeddings + position_embeddings)
200
+ presents = [] # accumulate over layers
201
+ for i, block in enumerate(self.blocks):
202
+ x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
203
+ presents.append(present)
204
+
205
+ x = self.ln_f(x)
206
+ logits = self.head(x)
207
+ # if we are given some desired targets also calculate the loss
208
+ loss = None
209
+ if targets is not None:
210
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
211
+
212
+ return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
213
+
214
+
215
+ class DummyGPT(nn.Module):
216
+ # for debugging
217
+ def __init__(self, add_value=1):
218
+ super().__init__()
219
+ self.add_value = add_value
220
+
221
+ def forward(self, idx):
222
+ return idx + self.add_value, None
223
+
224
+
225
+ class CodeGPT(nn.Module):
226
+ """Takes in semi-embeddings"""
227
+ def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
228
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
229
+ super().__init__()
230
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
231
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
232
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
233
+ n_unmasked=n_unmasked)
234
+ # input embedding stem
235
+ self.tok_emb = nn.Linear(in_channels, config.n_embd)
236
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
237
+ self.drop = nn.Dropout(config.embd_pdrop)
238
+ # transformer
239
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
240
+ # decoder head
241
+ self.ln_f = nn.LayerNorm(config.n_embd)
242
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
243
+ self.block_size = config.block_size
244
+ self.apply(self._init_weights)
245
+ self.config = config
246
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
247
+
248
+ def get_block_size(self):
249
+ return self.block_size
250
+
251
+ def _init_weights(self, module):
252
+ if isinstance(module, (nn.Linear, nn.Embedding)):
253
+ module.weight.data.normal_(mean=0.0, std=0.02)
254
+ if isinstance(module, nn.Linear) and module.bias is not None:
255
+ module.bias.data.zero_()
256
+ elif isinstance(module, nn.LayerNorm):
257
+ module.bias.data.zero_()
258
+ module.weight.data.fill_(1.0)
259
+
260
+ def forward(self, idx, embeddings=None, targets=None):
261
+ # forward the GPT model
262
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
263
+
264
+ if embeddings is not None: # prepend explicit embeddings
265
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
266
+
267
+ t = token_embeddings.shape[1]
268
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
269
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
270
+ x = self.drop(token_embeddings + position_embeddings)
271
+ x = self.blocks(x)
272
+ x = self.taming_cinln_f(x)
273
+ logits = self.head(x)
274
+
275
+ # if we are given some desired targets also calculate the loss
276
+ loss = None
277
+ if targets is not None:
278
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
279
+
280
+ return logits, loss
281
+
282
+
283
+
284
+ #### sampling utils
285
+
286
+ def top_k_logits(logits, k):
287
+ v, ix = torch.topk(logits, k)
288
+ out = logits.clone()
289
+ out[out < v[:, [-1]]] = -float('Inf')
290
+ return out
291
+
292
+ @torch.no_grad()
293
+ def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
294
+ """
295
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
296
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
297
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
298
+ of block_size, unlike an RNN that has an infinite context window.
299
+ """
300
+ block_size = model.get_block_size()
301
+ model.eval()
302
+ for k in range(steps):
303
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
304
+ logits, _ = model(x_cond)
305
+ # pluck the logits at the final step and scale by temperature
306
+ logits = logits[:, -1, :] / temperature
307
+ # optionally crop probabilities to only the top k options
308
+ if top_k is not None:
309
+ logits = top_k_logits(logits, top_k)
310
+ # apply softmax to convert to probabilities
311
+ probs = F.softmax(logits, dim=-1)
312
+ # sample from the distribution or take the most likely
313
+ if sample:
314
+ ix = torch.multinomial(probs, num_samples=1)
315
+ else:
316
+ _, ix = torch.topk(probs, k=1, dim=-1)
317
+ # append to the sequence and continue
318
+ x = torch.cat((x, ix), dim=1)
319
+
320
+ return x
321
+
322
+
323
+ @torch.no_grad()
324
+ def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
325
+ top_k=None, top_p=None, callback=None):
326
+ # x is conditioning
327
+ sample = x
328
+ cond_len = x.shape[1]
329
+ past = None
330
+ for n in range(steps):
331
+ if callback is not None:
332
+ callback(n)
333
+ logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
334
+ if past is None:
335
+ past = [present]
336
+ else:
337
+ past.append(present)
338
+ logits = logits[:, -1, :] / temperature
339
+ if top_k is not None:
340
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
341
+
342
+ probs = F.softmax(logits, dim=-1)
343
+ if not sample_logits:
344
+ _, x = torch.topk(probs, k=1, dim=-1)
345
+ else:
346
+ x = torch.multinomial(probs, num_samples=1)
347
+ # append to the sequence and continue
348
+ sample = torch.cat((sample, x), dim=1)
349
+ del past
350
+ sample = sample[:, cond_len:] # cut conditioning off
351
+ return sample
352
+
353
+
354
+ #### clustering utils
355
+
356
+ class KMeans(nn.Module):
357
+ def __init__(self, ncluster=512, nc=3, niter=10):
358
+ super().__init__()
359
+ self.ncluster = ncluster
360
+ self.nc = nc
361
+ self.niter = niter
362
+ self.shape = (3,32,32)
363
+ self.register_buffer("C", torch.zeros(self.ncluster,nc))
364
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
365
+
366
+ def is_initialized(self):
367
+ return self.initialized.item() == 1
368
+
369
+ @torch.no_grad()
370
+ def initialize(self, x):
371
+ N, D = x.shape
372
+ assert D == self.nc, D
373
+ c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
374
+ for i in range(self.niter):
375
+ # assign all pixels to the closest codebook element
376
+ a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
377
+ # move each codebook element to be the mean of the pixels that assigned to it
378
+ c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
379
+ # re-assign any poorly positioned codebook elements
380
+ nanix = torch.any(torch.isnan(c), dim=1)
381
+ ndead = nanix.sum().item()
382
+ print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
383
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
384
+
385
+ self.C.copy_(c)
386
+ self.initialized.fill_(1)
387
+
388
+
389
+ def forward(self, x, reverse=False, shape=None):
390
+ if not reverse:
391
+ # flatten
392
+ bs,c,h,w = x.shape
393
+ assert c == self.nc
394
+ x = x.reshape(bs,c,h*w,1)
395
+ C = self.C.permute(1,0)
396
+ C = C.reshape(1,c,1,self.ncluster)
397
+ a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
398
+ return a
399
+ else:
400
+ # flatten
401
+ bs, HW = x.shape
402
+ """
403
+ c = self.C.reshape( 1, self.nc, 1, self.ncluster)
404
+ c = c[bs*[0],:,:,:]
405
+ c = c[:,:,HW*[0],:]
406
+ x = x.reshape(bs, 1, HW, 1)
407
+ x = x[:,3*[0],:,:]
408
+ x = torch.gather(c, dim=3, index=x)
409
+ """
410
+ x = self.C[x]
411
+ x = x.permute(0,2,1)
412
+ shape = shape if shape is not None else self.shape
413
+ x = x.reshape(bs, *shape)
414
+
415
+ return x
taming/modules/transformer/permuter.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+
+ class AbstractPermuter(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+     def forward(self, x, reverse=False):
+         raise NotImplementedError
+
+
+ class Identity(AbstractPermuter):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, reverse=False):
+         return x
+
+
+ class Subsample(AbstractPermuter):
+     def __init__(self, H, W):
+         super().__init__()
+         C = 1
+         indices = np.arange(H*W).reshape(C,H,W)
+         while min(H, W) > 1:
+             indices = indices.reshape(C,H//2,2,W//2,2)
+             indices = indices.transpose(0,2,4,1,3)
+             indices = indices.reshape(C*4,H//2, W//2)
+             H = H//2
+             W = W//2
+             C = C*4
+         assert H == W == 1
+         idx = torch.tensor(indices.ravel())
+         self.register_buffer('forward_shuffle_idx',
+                              nn.Parameter(idx, requires_grad=False))
+         self.register_buffer('backward_shuffle_idx',
+                              nn.Parameter(torch.argsort(idx), requires_grad=False))
+
+     def forward(self, x, reverse=False):
+         if not reverse:
+             return x[:, self.forward_shuffle_idx]
+         else:
+             return x[:, self.backward_shuffle_idx]
+
+
+ def mortonify(i, j):
+     """(i,j) index to linear morton code"""
+     i = np.uint64(i)
+     j = np.uint64(j)
+
+     z = np.uint(0)
+
+     for pos in range(32):
+         z = (z |
+              ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
+              ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
+              )
+     return z
+
+
+ class ZCurve(AbstractPermuter):
+     def __init__(self, H, W):
+         super().__init__()
+         reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
+         idx = np.argsort(reverseidx)
+         idx = torch.tensor(idx)
+         reverseidx = torch.tensor(reverseidx)
+         self.register_buffer('forward_shuffle_idx',
+                              idx)
+         self.register_buffer('backward_shuffle_idx',
+                              reverseidx)
+
+     def forward(self, x, reverse=False):
+         if not reverse:
+             return x[:, self.forward_shuffle_idx]
+         else:
+             return x[:, self.backward_shuffle_idx]
+
+
+ class SpiralOut(AbstractPermuter):
+     def __init__(self, H, W):
+         super().__init__()
+         assert H == W
+         size = W
+         indices = np.arange(size*size).reshape(size,size)
+
+         i0 = size//2
+         j0 = size//2-1
+
+         i = i0
+         j = j0
+
+         idx = [indices[i0, j0]]
+         step_mult = 0
+         for c in range(1, size//2+1):
+             step_mult += 1
+             # steps left
+             for k in range(step_mult):
+                 i = i - 1
+                 j = j
+                 idx.append(indices[i, j])
+
+             # step down
+             for k in range(step_mult):
+                 i = i
+                 j = j + 1
+                 idx.append(indices[i, j])
+
+             step_mult += 1
+             if c < size//2:
+                 # step right
+                 for k in range(step_mult):
+                     i = i + 1
+                     j = j
+                     idx.append(indices[i, j])
+
+                 # step up
+                 for k in range(step_mult):
+                     i = i
+                     j = j - 1
+                     idx.append(indices[i, j])
+             else:
+                 # end reached
+                 for k in range(step_mult-1):
+                     i = i + 1
+                     idx.append(indices[i, j])
+
+         assert len(idx) == size*size
+         idx = torch.tensor(idx)
+         self.register_buffer('forward_shuffle_idx', idx)
+         self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+     def forward(self, x, reverse=False):
+         if not reverse:
+             return x[:, self.forward_shuffle_idx]
+         else:
+             return x[:, self.backward_shuffle_idx]
+
+
+ class SpiralIn(AbstractPermuter):
+     def __init__(self, H, W):
+         super().__init__()
+         assert H == W
+         size = W
+         indices = np.arange(size*size).reshape(size,size)
+
+         i0 = size//2
+         j0 = size//2-1
+
+         i = i0
+         j = j0
+
+         idx = [indices[i0, j0]]
+         step_mult = 0
+         for c in range(1, size//2+1):
+             step_mult += 1
+             # steps left
+             for k in range(step_mult):
+                 i = i - 1
+                 j = j
+                 idx.append(indices[i, j])
+
+             # step down
+             for k in range(step_mult):
+                 i = i
+                 j = j + 1
+                 idx.append(indices[i, j])
+
+             step_mult += 1
+             if c < size//2:
+                 # step right
+                 for k in range(step_mult):
+                     i = i + 1
+                     j = j
+                     idx.append(indices[i, j])
+
+                 # step up
+                 for k in range(step_mult):
+                     i = i
+                     j = j - 1
+                     idx.append(indices[i, j])
+             else:
+                 # end reached
+                 for k in range(step_mult-1):
+                     i = i + 1
+                     idx.append(indices[i, j])
+
+         assert len(idx) == size*size
+         idx = idx[::-1]
+         idx = torch.tensor(idx)
+         self.register_buffer('forward_shuffle_idx', idx)
+         self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+     def forward(self, x, reverse=False):
+         if not reverse:
+             return x[:, self.forward_shuffle_idx]
+         else:
+             return x[:, self.backward_shuffle_idx]
+
+
+ class Random(nn.Module):
+     def __init__(self, H, W):
+         super().__init__()
+         indices = np.random.RandomState(1).permutation(H*W)
+         idx = torch.tensor(indices.ravel())
+         self.register_buffer('forward_shuffle_idx', idx)
+         self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+     def forward(self, x, reverse=False):
+         if not reverse:
+             return x[:, self.forward_shuffle_idx]
+         else:
+             return x[:, self.backward_shuffle_idx]
+
+
+ class AlternateParsing(AbstractPermuter):
+     def __init__(self, H, W):
+         super().__init__()
+         indices = np.arange(W*H).reshape(H,W)
+         for i in range(1, H, 2):
+             indices[i, :] = indices[i, ::-1]
+         idx = indices.flatten()
+         assert len(idx) == H*W
+         idx = torch.tensor(idx)
+         self.register_buffer('forward_shuffle_idx', idx)
+         self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+     def forward(self, x, reverse=False):
+         if not reverse:
+             return x[:, self.forward_shuffle_idx]
+         else:
+             return x[:, self.backward_shuffle_idx]
+
+
+ if __name__ == "__main__":
+     p0 = AlternateParsing(16, 16)
+     print(p0.forward_shuffle_idx)
+     print(p0.backward_shuffle_idx)
+
+     x = torch.randint(0, 768, size=(11, 256))
+     y = p0(x)
+     xre = p0(y, reverse=True)
+     assert torch.equal(x, xre)
+
+     p1 = SpiralOut(2, 2)
+     print(p1.forward_shuffle_idx)
+     print(p1.backward_shuffle_idx)
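The permuters above all share one contract: they reorder a flattened H×W grid of codebook indices so a transformer can consume the tokens in a different scan order, and they undo that reordering with reverse=True. A minimal round-trip sketch, assuming the module is importable as taming.modules.transformer.permuter (the upstream taming-transformers layout) and using an arbitrary 8×8 grid:

import torch
from taming.modules.transformer.permuter import ZCurve

permuter = ZCurve(H=8, W=8)                 # z-curve scan over an 8x8 token grid
tokens = torch.randint(0, 1024, (4, 64))    # batch of flattened codebook indices
scanned = permuter(tokens)                  # raster order -> z-curve order
restored = permuter(scanned, reverse=True)  # back to raster order
assert torch.equal(tokens, restored)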
taming/modules/util.py ADDED
@@ -0,0 +1,130 @@
+ import torch
+ import torch.nn as nn
+
+
+ def count_params(model):
+     total_params = sum(p.numel() for p in model.parameters())
+     return total_params
+
+
+ class ActNorm(nn.Module):
+     def __init__(self, num_features, logdet=False, affine=True,
+                  allow_reverse_init=False):
+         assert affine
+         super().__init__()
+         self.logdet = logdet
+         self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+         self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+         self.allow_reverse_init = allow_reverse_init
+
+         self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+     def initialize(self, input):
+         with torch.no_grad():
+             flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+             mean = (
+                 flatten.mean(1)
+                 .unsqueeze(1)
+                 .unsqueeze(2)
+                 .unsqueeze(3)
+                 .permute(1, 0, 2, 3)
+             )
+             std = (
+                 flatten.std(1)
+                 .unsqueeze(1)
+                 .unsqueeze(2)
+                 .unsqueeze(3)
+                 .permute(1, 0, 2, 3)
+             )
+
+             self.loc.data.copy_(-mean)
+             self.scale.data.copy_(1 / (std + 1e-6))
+
+     def forward(self, input, reverse=False):
+         if reverse:
+             return self.reverse(input)
+         if len(input.shape) == 2:
+             input = input[:,:,None,None]
+             squeeze = True
+         else:
+             squeeze = False
+
+         _, _, height, width = input.shape
+
+         if self.training and self.initialized.item() == 0:
+             self.initialize(input)
+             self.initialized.fill_(1)
+
+         h = self.scale * (input + self.loc)
+
+         if squeeze:
+             h = h.squeeze(-1).squeeze(-1)
+
+         if self.logdet:
+             log_abs = torch.log(torch.abs(self.scale))
+             logdet = height*width*torch.sum(log_abs)
+             logdet = logdet * torch.ones(input.shape[0]).to(input)
+             return h, logdet
+
+         return h
+
+     def reverse(self, output):
+         if self.training and self.initialized.item() == 0:
+             if not self.allow_reverse_init:
+                 raise RuntimeError(
+                     "Initializing ActNorm in reverse direction is "
+                     "disabled by default. Use allow_reverse_init=True to enable."
+                 )
+             else:
+                 self.initialize(output)
+                 self.initialized.fill_(1)
+
+         if len(output.shape) == 2:
+             output = output[:,:,None,None]
+             squeeze = True
+         else:
+             squeeze = False
+
+         h = output / self.scale - self.loc
+
+         if squeeze:
+             h = h.squeeze(-1).squeeze(-1)
+         return h
+
+
+ class AbstractEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def encode(self, *args, **kwargs):
+         raise NotImplementedError
+
+
+ class Labelator(AbstractEncoder):
+     """Net2Net Interface for Class-Conditional Model"""
+     def __init__(self, n_classes, quantize_interface=True):
+         super().__init__()
+         self.n_classes = n_classes
+         self.quantize_interface = quantize_interface
+
+     def encode(self, c):
+         c = c[:,None]
+         if self.quantize_interface:
+             return c, None, [None, None, c.long()]
+         return c
+
+
+ class SOSProvider(AbstractEncoder):
+     # for unconditional training
+     def __init__(self, sos_token, quantize_interface=True):
+         super().__init__()
+         self.sos_token = sos_token
+         self.quantize_interface = quantize_interface
+
+     def encode(self, x):
+         # get batch size from data and replicate sos_token
+         c = torch.ones(x.shape[0], 1)*self.sos_token
+         c = c.long().to(x.device)
+         if self.quantize_interface:
+             return c, None, [None, None, c]
+         return c
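For orientation, ActNorm is a data-dependent per-channel affine layer: the first training batch initializes loc and scale from that batch's statistics, and with logdet=True the forward pass also returns the log-determinant term used by flow-style objectives. A small usage sketch, assuming the file is importable as taming.modules.util and using illustrative shapes:

import torch
from taming.modules.util import ActNorm

layer = ActNorm(num_features=16, logdet=True)
x = torch.randn(8, 16, 32, 32)        # (batch, channels, height, width)
layer.train()
h, logdet = layer(x)                  # first call initializes loc/scale from x
x_rec = layer(h, reverse=True)        # inverts the affine map
print(h.shape, logdet.shape, x_rec.shape)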
taming/modules/vqvae/quantize.py ADDED
@@ -0,0 +1,445 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from torch import einsum
+ from einops import rearrange
+
+
+ class VectorQuantizer(nn.Module):
+     """
+     see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+     ____________________________________________
+     Discretization bottleneck part of the VQ-VAE.
+     Inputs:
+     - n_e : number of embeddings
+     - e_dim : dimension of embedding
+     - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+     _____________________________________________
+     """
+
+     # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+     # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+     # used wherever VectorQuantizer has been used before and is additionally
+     # more efficient.
+     def __init__(self, n_e, e_dim, beta):
+         super(VectorQuantizer, self).__init__()
+         self.n_e = n_e
+         self.e_dim = e_dim
+         self.beta = beta
+
+         self.embedding = nn.Embedding(self.n_e, self.e_dim)
+         self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+     def forward(self, z):
+         """
+         Inputs the output of the encoder network z and maps it to a discrete
+         one-hot vector that is the index of the closest embedding vector e_j
+         z (continuous) -> z_q (discrete)
+         z.shape = (batch, channel, height, width)
+         quantization pipeline:
+             1. get encoder input (B,C,H,W)
+             2. flatten input to (B*H*W,C)
+         """
+         # reshape z -> (batch, height, width, channel) and flatten
+         z = z.permute(0, 2, 3, 1).contiguous()
+         z_flattened = z.view(-1, self.e_dim)
+         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+             torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+             torch.matmul(z_flattened, self.embedding.weight.t())
+
+         ## could possible replace this here
+         # #\start...
+         # find closest encodings
+         min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+         min_encodings = torch.zeros(
+             min_encoding_indices.shape[0], self.n_e).to(z)
+         min_encodings.scatter_(1, min_encoding_indices, 1)
+
+         # dtype min encodings: torch.float32
+         # min_encodings shape: torch.Size([2048, 512])
+         # min_encoding_indices.shape: torch.Size([2048, 1])
+
+         # get quantized latent vectors
+         z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+         #.........\end
+
+         # with:
+         # .........\start
+         #min_encoding_indices = torch.argmin(d, dim=1)
+         #z_q = self.embedding(min_encoding_indices)
+         # ......\end......... (TODO)
+
+         # compute loss for embedding
+         loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+             torch.mean((z_q - z.detach()) ** 2)
+
+         # preserve gradients
+         z_q = z + (z_q - z).detach()
+
+         # perplexity
+         e_mean = torch.mean(min_encodings, dim=0)
+         perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+         # reshape back to match original input shape
+         z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+     def get_codebook_entry(self, indices, shape):
+         # shape specifying (batch, height, width, channel)
+         # TODO: check for more easy handling with nn.Embedding
+         min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+         min_encodings.scatter_(1, indices[:,None], 1)
+
+         # get quantized latent vectors
+         z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+         if shape is not None:
+             z_q = z_q.view(shape)
+
+             # reshape back to match original input shape
+             z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         return z_q
+
+
+ class GumbelQuantize(nn.Module):
+     """
+     credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+     Gumbel Softmax trick quantizer
+     Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+     https://arxiv.org/abs/1611.01144
+     """
+     def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+                  kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+                  remap=None, unknown_index="random"):
+         super().__init__()
+
+         self.embedding_dim = embedding_dim
+         self.n_embed = n_embed
+
+         self.straight_through = straight_through
+         self.temperature = temp_init
+         self.kl_weight = kl_weight
+
+         self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+         self.embed = nn.Embedding(n_embed, embedding_dim)
+
+         self.use_vqinterface = use_vqinterface
+
+         self.remap = remap
+         if self.remap is not None:
+             self.register_buffer("used", torch.tensor(np.load(self.remap)))
+             self.re_embed = self.used.shape[0]
+             self.unknown_index = unknown_index # "random" or "extra" or integer
+             if self.unknown_index == "extra":
+                 self.unknown_index = self.re_embed
+                 self.re_embed = self.re_embed+1
+             print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+                   f"Using {self.unknown_index} for unknown indices.")
+         else:
+             self.re_embed = n_embed
+
+     def remap_to_used(self, inds):
+         ishape = inds.shape
+         assert len(ishape)>1
+         inds = inds.reshape(ishape[0],-1)
+         used = self.used.to(inds)
+         match = (inds[:,:,None]==used[None,None,...]).long()
+         new = match.argmax(-1)
+         unknown = match.sum(2)<1
+         if self.unknown_index == "random":
+             new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+         else:
+             new[unknown] = self.unknown_index
+         return new.reshape(ishape)
+
+     def unmap_to_all(self, inds):
+         ishape = inds.shape
+         assert len(ishape)>1
+         inds = inds.reshape(ishape[0],-1)
+         used = self.used.to(inds)
+         if self.re_embed > self.used.shape[0]: # extra token
+             inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+         back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+         return back.reshape(ishape)
+
+     def forward(self, z, temp=None, return_logits=False):
+         # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+         hard = self.straight_through if self.training else True
+         temp = self.temperature if temp is None else temp
+
+         logits = self.proj(z)
+         if self.remap is not None:
+             # continue only with used logits
+             full_zeros = torch.zeros_like(logits)
+             logits = logits[:,self.used,...]
+
+         soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+         if self.remap is not None:
+             # go back to all entries but unused set to zero
+             full_zeros[:,self.used,...] = soft_one_hot
+             soft_one_hot = full_zeros
+         z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+         # + kl divergence to the prior loss
+         qy = F.softmax(logits, dim=1)
+         diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+         ind = soft_one_hot.argmax(dim=1)
+         if self.remap is not None:
+             ind = self.remap_to_used(ind)
+         if self.use_vqinterface:
+             if return_logits:
+                 return z_q, diff, (None, None, ind), logits
+             return z_q, diff, (None, None, ind)
+         return z_q, diff, ind
+
+     def get_codebook_entry(self, indices, shape):
+         b, h, w, c = shape
+         assert b*h*w == indices.shape[0]
+         indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+         if self.remap is not None:
+             indices = self.unmap_to_all(indices)
+         one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+         z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+         return z_q
+
+
+ class VectorQuantizer2(nn.Module):
+     """
+     Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+     avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+     """
+     # NOTE: due to a bug the beta term was applied to the wrong term. for
+     # backwards compatibility we use the buggy version by default, but you can
+     # specify legacy=False to fix it.
+     def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+                  sane_index_shape=False, legacy=True):
+         super().__init__()
+         self.n_e = n_e
+         self.e_dim = e_dim
+         self.beta = beta
+         self.legacy = legacy
+
+         self.embedding = nn.Embedding(self.n_e, self.e_dim)
+         self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+         self.remap = remap
+         if self.remap is not None:
+             self.register_buffer("used", torch.tensor(np.load(self.remap)))
+             self.re_embed = self.used.shape[0]
+             self.unknown_index = unknown_index # "random" or "extra" or integer
+             if self.unknown_index == "extra":
+                 self.unknown_index = self.re_embed
+                 self.re_embed = self.re_embed+1
+             print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+                   f"Using {self.unknown_index} for unknown indices.")
+         else:
+             self.re_embed = n_e
+
+         self.sane_index_shape = sane_index_shape
+
+     def remap_to_used(self, inds):
+         ishape = inds.shape
+         assert len(ishape)>1
+         inds = inds.reshape(ishape[0],-1)
+         used = self.used.to(inds)
+         match = (inds[:,:,None]==used[None,None,...]).long()
+         new = match.argmax(-1)
+         unknown = match.sum(2)<1
+         if self.unknown_index == "random":
+             new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+         else:
+             new[unknown] = self.unknown_index
+         return new.reshape(ishape)
+
+     def unmap_to_all(self, inds):
+         ishape = inds.shape
+         assert len(ishape)>1
+         inds = inds.reshape(ishape[0],-1)
+         used = self.used.to(inds)
+         if self.re_embed > self.used.shape[0]: # extra token
+             inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+         back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+         return back.reshape(ishape)
+
+     def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+         assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
+         assert rescale_logits==False, "Only for interface compatible with Gumbel"
+         assert return_logits==False, "Only for interface compatible with Gumbel"
+         # reshape z -> (batch, height, width, channel) and flatten
+         z = rearrange(z, 'b c h w -> b h w c').contiguous()
+         z_flattened = z.view(-1, self.e_dim)
+         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+             torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+             torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+         min_encoding_indices = torch.argmin(d, dim=1)
+         z_q = self.embedding(min_encoding_indices).view(z.shape)
+         perplexity = None
+         min_encodings = None
+
+         # compute loss for embedding
+         if not self.legacy:
+             loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+                    torch.mean((z_q - z.detach()) ** 2)
+         else:
+             loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+                    torch.mean((z_q - z.detach()) ** 2)
+
+         # preserve gradients
+         z_q = z + (z_q - z).detach()
+
+         # reshape back to match original input shape
+         z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+         if self.remap is not None:
+             min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
+             min_encoding_indices = self.remap_to_used(min_encoding_indices)
+             min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
+
+         if self.sane_index_shape:
+             min_encoding_indices = min_encoding_indices.reshape(
+                 z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+         return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+     def get_codebook_entry(self, indices, shape):
+         # shape specifying (batch, height, width, channel)
+         if self.remap is not None:
+             indices = indices.reshape(shape[0],-1) # add batch axis
+             indices = self.unmap_to_all(indices)
+             indices = indices.reshape(-1) # flatten again
+
+         # get quantized latent vectors
+         z_q = self.embedding(indices)
+
+         if shape is not None:
+             z_q = z_q.view(shape)
+             # reshape back to match original input shape
+             z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         return z_q
+
+ class EmbeddingEMA(nn.Module):
+     def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+         super().__init__()
+         self.decay = decay
+         self.eps = eps
+         weight = torch.randn(num_tokens, codebook_dim)
+         self.weight = nn.Parameter(weight, requires_grad = False)
+         self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
+         self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
+         self.update = True
+
+     def forward(self, embed_id):
+         return F.embedding(embed_id, self.weight)
+
+     def cluster_size_ema_update(self, new_cluster_size):
+         self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+     def embed_avg_ema_update(self, new_embed_avg):
+         self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+     def weight_update(self, num_tokens):
+         n = self.cluster_size.sum()
+         smoothed_cluster_size = (
+             (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+         )
+         #normalize embedding average with smoothed cluster size
+         embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+         self.weight.data.copy_(embed_normalized)
+
+
+ class EMAVectorQuantizer(nn.Module):
+     def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+                  remap=None, unknown_index="random"):
+         super().__init__()
+         self.codebook_dim = embedding_dim
+         self.num_tokens = n_embed
+         self.beta = beta
+         self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+         self.remap = remap
+         if self.remap is not None:
+             self.register_buffer("used", torch.tensor(np.load(self.remap)))
+             self.re_embed = self.used.shape[0]
+             self.unknown_index = unknown_index # "random" or "extra" or integer
+             if self.unknown_index == "extra":
+                 self.unknown_index = self.re_embed
+                 self.re_embed = self.re_embed+1
+             print(f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
+                   f"Using {self.unknown_index} for unknown indices.")
+         else:
+             self.re_embed = n_embed
+
+     def remap_to_used(self, inds):
+         ishape = inds.shape
+         assert len(ishape)>1
+         inds = inds.reshape(ishape[0],-1)
+         used = self.used.to(inds)
+         match = (inds[:,:,None]==used[None,None,...]).long()
+         new = match.argmax(-1)
+         unknown = match.sum(2)<1
+         if self.unknown_index == "random":
+             new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+         else:
+             new[unknown] = self.unknown_index
+         return new.reshape(ishape)
+
+     def unmap_to_all(self, inds):
+         ishape = inds.shape
+         assert len(ishape)>1
+         inds = inds.reshape(ishape[0],-1)
+         used = self.used.to(inds)
+         if self.re_embed > self.used.shape[0]: # extra token
+             inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+         back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+         return back.reshape(ishape)
+
+     def forward(self, z):
+         # reshape z -> (batch, height, width, channel) and flatten
+         #z, 'b c h w -> b h w c'
+         z = rearrange(z, 'b c h w -> b h w c')
+         z_flattened = z.reshape(-1, self.codebook_dim)
+
+         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+         d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+             self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+             torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
+
+
+         encoding_indices = torch.argmin(d, dim=1)
+
+         z_q = self.embedding(encoding_indices).view(z.shape)
+         encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+         avg_probs = torch.mean(encodings, dim=0)
+         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+         if self.training and self.embedding.update:
+             #EMA cluster size
+             encodings_sum = encodings.sum(0)
+             self.embedding.cluster_size_ema_update(encodings_sum)
+             #EMA embedding average
+             embed_sum = encodings.transpose(0,1) @ z_flattened
+             self.embedding.embed_avg_ema_update(embed_sum)
+             #normalize embed_avg and update weight
+             self.embedding.weight_update(self.num_tokens)
+
+         # compute loss for embedding
+         loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+         # preserve gradients
+         z_q = z + (z_q - z).detach()
+
+         # reshape back to match original input shape
+         #z_q, 'b h w c -> b c h w'
+         z_q = rearrange(z_q, 'b h w c -> b c h w')
+         return z_q, loss, (perplexity, encodings, encoding_indices)
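As a quick orientation, VectorQuantizer2 snaps each spatial feature vector to its nearest codebook entry and returns the straight-through quantized tensor, the commitment loss, and the chosen indices. A minimal sketch, assuming the module is importable as taming.modules.vqvae.quantize; the codebook size and feature shape below are arbitrary:

import torch
from taming.modules.vqvae.quantize import VectorQuantizer2

quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25, legacy=False, sane_index_shape=True)
z = torch.randn(2, 256, 16, 16)       # encoder output (batch, e_dim, H, W)
z_q, loss, (perplexity, min_encodings, indices) = quantizer(z)
print(z_q.shape)                      # torch.Size([2, 256, 16, 16])
print(indices.shape)                  # torch.Size([2, 16, 16]) because sane_index_shape=True
print(loss.item())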
taming/util.py ADDED
@@ -0,0 +1,157 @@
+ import os, hashlib
+ import requests
+ from tqdm import tqdm
+
+ URL_MAP = {
+     "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+ }
+
+ CKPT_MAP = {
+     "vgg_lpips": "vgg.pth"
+ }
+
+ MD5_MAP = {
+     "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+ }
+
+
+ def download(url, local_path, chunk_size=1024):
+     os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+     with requests.get(url, stream=True) as r:
+         total_size = int(r.headers.get("content-length", 0))
+         with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+             with open(local_path, "wb") as f:
+                 for data in r.iter_content(chunk_size=chunk_size):
+                     if data:
+                         f.write(data)
+                         pbar.update(chunk_size)
+
+
+ def md5_hash(path):
+     with open(path, "rb") as f:
+         content = f.read()
+     return hashlib.md5(content).hexdigest()
+
+
+ def get_ckpt_path(name, root, check=False):
+     assert name in URL_MAP
+     path = os.path.join(root, CKPT_MAP[name])
+     if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+         print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+         download(URL_MAP[name], path)
+         md5 = md5_hash(path)
+         assert md5 == MD5_MAP[name], md5
+     return path
+
+
+ class KeyNotFoundError(Exception):
+     def __init__(self, cause, keys=None, visited=None):
+         self.cause = cause
+         self.keys = keys
+         self.visited = visited
+         messages = list()
+         if keys is not None:
+             messages.append("Key not found: {}".format(keys))
+         if visited is not None:
+             messages.append("Visited: {}".format(visited))
+         messages.append("Cause:\n{}".format(cause))
+         message = "\n".join(messages)
+         super().__init__(message)
+
+
+ def retrieve(
+     list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+ ):
+     """Given a nested list or dict return the desired value at key expanding
+     callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+     is done in-place.
+
+     Parameters
+     ----------
+         list_or_dict : list or dict
+             Possibly nested list or dictionary.
+         key : str
+             key/to/value, path like string describing all keys necessary to
+             consider to get to the desired value. List indices can also be
+             passed here.
+         splitval : str
+             String that defines the delimiter between keys of the
+             different depth levels in `key`.
+         default : obj
+             Value returned if :attr:`key` is not found.
+         expand : bool
+             Whether to expand callable nodes on the path or not.
+
+     Returns
+     -------
+         The desired value or if :attr:`default` is not ``None`` and the
+         :attr:`key` is not found returns ``default``.
+
+     Raises
+     ------
+         Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+         ``None``.
+     """
+
+     keys = key.split(splitval)
+
+     success = True
+     try:
+         visited = []
+         parent = None
+         last_key = None
+         for key in keys:
+             if callable(list_or_dict):
+                 if not expand:
+                     raise KeyNotFoundError(
+                         ValueError(
+                             "Trying to get past callable node with expand=False."
+                         ),
+                         keys=keys,
+                         visited=visited,
+                     )
+                 list_or_dict = list_or_dict()
+                 parent[last_key] = list_or_dict
+
+             last_key = key
+             parent = list_or_dict
+
+             try:
+                 if isinstance(list_or_dict, dict):
+                     list_or_dict = list_or_dict[key]
+                 else:
+                     list_or_dict = list_or_dict[int(key)]
+             except (KeyError, IndexError, ValueError) as e:
+                 raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+             visited += [key]
+         # final expansion of retrieved value
+         if expand and callable(list_or_dict):
+             list_or_dict = list_or_dict()
+             parent[last_key] = list_or_dict
+     except KeyNotFoundError as e:
+         if default is None:
+             raise e
+         else:
+             list_or_dict = default
+             success = False
+
+     if not pass_success:
+         return list_or_dict
+     else:
+         return list_or_dict, success
+
+
+ if __name__ == "__main__":
+     config = {"keya": "a",
+               "keyb": "b",
+               "keyc":
+                   {"cc1": 1,
+                    "cc2": 2,
+                    }
+               }
+     from omegaconf import OmegaConf
+     config = OmegaConf.create(config)
+     print(config)
+     retrieve(config, "keya")
+
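The __main__ check above only looks up a top-level key; nested values are addressed with slash-separated paths. A small sketch of that usage, assuming the package is importable as taming and using a made-up config:

from taming.util import retrieve

config = {"model": {"params": {"embed_dim": 256}}, "layers": [4, 8, 12]}
print(retrieve(config, "model/params/embed_dim"))    # 256
print(retrieve(config, "layers/1"))                  # list indices work too -> 8
print(retrieve(config, "model/missing", default=0))  # missing key falls back to default -> 0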