EmaadKhwaja committed
Commit: 5d2263b
Parent(s): 86d2765
file upload
- .DS_Store +0 -0
- app.py +1 -5
- celle/__init__.py +4 -0
- celle/attention.py +253 -0
- celle/celle.py +1060 -0
- celle/reversible.py +36 -0
- celle/transformer.py +213 -0
- celle/utils.py +193 -0
- celle/vae.py +112 -0
- celle_main.py +619 -0
- celle_taming_main.py +695 -0
- dataloader.py +321 -0
- prediction.py +267 -0
- requirements.txt +160 -0
- taming/lr_scheduler.py +34 -0
- taming/models/cond_transformer.py +349 -0
- taming/models/dummy_cond_stage.py +22 -0
- taming/models/vqgan.py +649 -0
- taming/modules/autoencoder/lpips/vgg.pth +3 -0
- taming/modules/diffusionmodules/model.py +776 -0
- taming/modules/discriminator/model.py +67 -0
- taming/modules/losses/__init__.py +2 -0
- taming/modules/losses/lpips.py +123 -0
- taming/modules/losses/segmentation.py +22 -0
- taming/modules/losses/vqperceptual.py +182 -0
- taming/modules/misc/coord.py +31 -0
- taming/modules/transformer/mingpt.py +415 -0
- taming/modules/transformer/permuter.py +248 -0
- taming/modules/util.py +130 -0
- taming/modules/vqvae/quantize.py +445 -0
- taming/util.py +157 -0
.DS_Store
ADDED
Binary file (6.15 kB)
app.py
CHANGED
@@ -1,7 +1,3 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+demo = gr.load("HuangLab/CELL-E_2_HPA_Finetuned_480", src="models")
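A usage note on the change above (not part of the commit): gr.load builds the demo directly from the hosted HuangLab/CELL-E_2_HPA_Finetuned_480 model repository, which is why the placeholder greet interface is removed. A minimal sketch of running it, assuming a recent Gradio release; on a managed Space the demo object defined in app.py is served automatically:

import gradio as gr

demo = gr.load("HuangLab/CELL-E_2_HPA_Finetuned_480", src="models")
demo.launch()  # only needed when running the script locally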
celle/__init__.py
ADDED
@@ -0,0 +1,4 @@
from celle.celle import CELLE
from celle.vae import VQGanVAE

__version__ = "2.0.0"
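For reference (not part of the commit), this makes the two model classes importable from the package root:

from celle import CELLE, VQGanVAE

import celle
print(celle.__version__)  # "2.0.0"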
celle/attention.py
ADDED
@@ -0,0 +1,253 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

from rotary_embedding_torch import apply_rotary_emb
from celle.utils import exists, default, max_neg_value


# helpers
def stable_softmax(t, dim=-1, alpha=32**2):
    t = t / alpha
    t = t - torch.amax(t, dim=dim, keepdim=True).detach()
    return (t * alpha).softmax(dim=dim)


def apply_pos_emb(pos_emb, qkv):
    n = qkv[0].shape[-2]
    pos_emb = pos_emb[..., :n, :]
    return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))


# classes
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        seq_len,
        causal=False,
        heads=8,
        dim_head=64,
        dropout=0.0,
        stable=False,
        static_mask=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head**-0.5
        self.stable = stable
        self.causal = causal
        self.register_buffer("static_mask", static_mask, persistent=False)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        self.save_attn = nn.Identity()

    def forward(self, x, context_mask=None, rotary_pos_emb=None):
        # x: [batch_size, seq_len, dim]
        b, n, _, h = *x.shape, self.heads
        device = x.device

        softmax = torch.softmax if not self.stable else stable_softmax

        # qkv: 3 tensors of shape [batch_size, seq_len, inner_dim]
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        # q,k,v: [batch_size, heads, seq_len, dim_head]
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb[..., :, :], (q, k, v))

        q *= self.scale

        # dots: [batch_size, heads, seq_len_i, seq_len_j]
        dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
        mask_value = max_neg_value(dots)

        if exists(context_mask):
            # context_mask: [batch_size, 1, 1, seq_len_j]
            context_mask = rearrange(context_mask, "b j -> b 1 1 j")
            context_mask = F.pad(context_mask, (1, 0), value=True)

            mask_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill(~context_mask, mask_value)

        if self.causal:
            i, j = dots.shape[-2:]
            context_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
            dots.masked_fill_(context_mask, mask_value)

        if exists(self.static_mask):
            dots.masked_fill_(~self.static_mask[:n, :n], mask_value)

        # attn: [batch_size, heads, seq_len_i, seq_len_j]
        attn = softmax(dots, dim=-1)
        attn = self.save_attn(attn)

        # out: [batch_size, heads, seq_len_i, dim_head]
        out = torch.einsum("b h n j, b h j d -> b h n d", attn, v)

        # out: [batch_size, seq_len_i, (heads*dim_head)]
        out = rearrange(out, "b h n d -> b n (h d)")

        # out: [batch_size, seq_len_i, dim]
        out = self.to_out(out)

        return out


# sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation


class SparseConvCausalAttention(nn.Module):
    def __init__(
        self,
        dim,
        seq_len,
        image_size=32,
        kernel_size=5,
        dilation=1,
        heads=8,
        dim_head=64,
        dropout=0.0,
        stable=False,
        **kwargs,
    ):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel size must be odd"

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head**-0.5
        self.image_size = image_size
        self.kernel_size = kernel_size
        self.dilation = dilation

        self.stable = stable

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, mask=None, rotary_pos_emb=None):
        b, n, _, h, img_size, kernel_size, dilation, seq_len, device = (
            *x.shape,
            self.heads,
            self.image_size,
            self.kernel_size,
            self.dilation,
            self.seq_len,
            x.device,
        )
        softmax = torch.softmax if not self.stable else stable_softmax

        img_seq_len = img_size**2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len, device=device).bool())

        x = F.pad(x, (0, 0, 0, padding), value=0)
        mask = mask[:, :text_len]

        # derive query / keys / values

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        q *= self.scale

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
            lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)
        )

        # text attention

        dots_text = einsum("b i d, b j d -> b i j", q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim=-1)
        out_text = einsum("b i j, b j d -> b i d", attn_text, v_text)

        # image attention

        effective_kernel_size = (kernel_size - 1) * dilation + 1
        padding = effective_kernel_size // 2

        k_img, v_img = map(
            lambda t: rearrange(t, "b (h w) c -> b c h w", h=img_size), (k_img, v_img)
        )
        k_img, v_img = map(
            lambda t: F.unfold(t, kernel_size, padding=padding, dilation=dilation),
            (k_img, v_img),
        )
        k_img, v_img = map(
            lambda t: rearrange(t, "b (d j) i -> b i j d", j=kernel_size**2),
            (k_img, v_img),
        )

        # let image attend to all of text

        dots_image = einsum("b i d, b i j d -> b i j", q_img, k_img)
        dots_image_to_text = einsum("b i d, b j d -> b i j", q_img, k_text)

        # calculate causal attention for local convolution

        i, j = dots_image.shape[-2:]
        img_seq = torch.arange(img_seq_len, device=device)
        k_img_indices = rearrange(img_seq.float(), "(h w) -> () () h w", h=img_size)
        k_img_indices = F.pad(
            k_img_indices, (padding,) * 4, value=img_seq_len
        )  # padding set to be max, so it is never attended to
        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation=dilation)
        k_img_indices = rearrange(k_img_indices, "b j i -> b i j")

        # mask image attention

        q_img_indices = rearrange(img_seq, "i -> () i ()")
        causal_mask = q_img_indices < k_img_indices

        # concat text mask with image causal mask

        causal_mask = repeat(causal_mask, "() i j -> b i j", b=b * h)
        mask = repeat(mask, "b j -> (b h) i j", i=i, h=h)
        mask = torch.cat((~mask, causal_mask), dim=-1)

        # image can attend to all of text

        dots = torch.cat((dots_image_to_text, dots_image), dim=-1)
        dots.masked_fill_(mask, mask_value)

        attn = softmax(dots, dim=-1)

        # aggregate

        attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum("b i j, b i j d -> b i d", attn_image, v_img)
        out_image_to_text = einsum("b i j, b j d -> b i d", attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim=1)

        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)

        out = self.to_out(out)

        return out[:, :n]
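An illustrative usage sketch for the file above (dimensions are made up, not taken from the commit). It also checks the key property of stable_softmax: scaling the logits down by alpha, subtracting the detached row maximum, and rescaling leaves the softmax value unchanged while keeping intermediate magnitudes bounded.

import torch
from celle.attention import Attention, stable_softmax

# stable_softmax is numerically equivalent to a plain softmax
t = torch.randn(2, 8) * 1e4
assert torch.allclose(stable_softmax(t), torch.softmax(t, dim=-1), atol=1e-6)

# causal self-attention over a toy sequence
attn = Attention(dim=512, seq_len=256, heads=8, dim_head=64, causal=True)
x = torch.randn(2, 256, 512)   # [batch, seq_len, dim]
out = attn(x)                  # -> [2, 256, 512]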
celle/celle.py
ADDED
@@ -0,0 +1,1060 @@
# Import necessary packages and modules
from math import floor, ceil
import torch
from torch import nn
import torch.nn.functional as F
from axial_positional_embedding import AxialPositionalEmbedding
from einops import rearrange
from tqdm import tqdm  # used by sample() for the optional progress bar
from celle.utils import (
    exists,
    always,
    eval_decorator,
    gumbel_sample,
    top_k,
    gamma_func,
    DivideMax,
)

# Import additional modules from within the codebase
from celle.transformer import Transformer


def generate_mask(gamma_func, batch_size, length, device):
    # Get the number of `True` values in the mask for each batch element
    num_true_values = floor(gamma_func(torch.rand(1)) * length)

    # Generate a random sample of indices to set to `True` in the mask
    # The number of indices in the sample is determined by `num_true_values`
    indices = (
        torch.rand((batch_size, length), device=device)
        .topk(num_true_values, dim=1)
        .indices
    )

    # Create a binary mask tensor with `True` values at the sampled indices
    mask = torch.zeros((batch_size, length), dtype=torch.bool, device=device)
    mask.scatter_(dim=1, index=indices, value=True)

    return mask

def match_batch_size(text, condition, image, batch_size):
    """
    This function ensures all inputs to the sample function have the same batch size.
    """
    if text.shape[0] != batch_size:
        text = text.repeat(batch_size, 1)

    if condition.shape[0] != batch_size:
        condition = condition.repeat(batch_size, 1)

    if image.shape[0] != batch_size:
        image = image.repeat(batch_size, 1)

    return text, condition, image


def calc_unmask_probs(timestep, timesteps, gamma_func):
    if timestep == 1 or timesteps == 1:
        unmask_prob = 1
    else:
        unmask_prob = 1 - gamma_func(timestep)
    return unmask_prob


def calculate_logits(
    input_tokens, input_mask, logits_function, filter_thres, temperature
):
    logits, _, _ = logits_function(input_tokens, input_mask, return_encoding=False)
    filtered_logits = top_k(logits, thres=filter_thres)
    sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

    return logits, sample


def unmask_tokens(
    input_tokens,
    input_mask,
    num_masked_tokens,
    logits,
    sample,
    timestep,
    timesteps,
    gamma,
    filter_func=None,
    pad_token=None,
    mask_token=None,
    force_aas=True,
):
    sample = sample.masked_fill(~input_mask.unsqueeze(-1), -torch.inf)
    if filter_func:
        sample = filter_func(
            input_tokens, sample, force_aas, pad_token=pad_token, mask_token=mask_token
        )
    selected_token_probs, selected_tokens = torch.max(sample, dim=-1)

    unmask_prob = calc_unmask_probs(timestep, timesteps, gamma)
    num_tokens_to_unmask = max(1, ceil(unmask_prob * num_masked_tokens))

    _, top_k_indices = torch.topk(selected_token_probs, num_tokens_to_unmask, dim=-1)

    sample_mask = torch.zeros(
        input_tokens.shape, dtype=torch.bool, device=input_tokens.device
    )
    sample_mask.scatter_(dim=1, index=top_k_indices, value=True)

    unmasked_tokens = torch.where(sample_mask, selected_tokens, input_tokens)
    full_logits = torch.where(
        sample_mask.unsqueeze(-1), logits, torch.zeros_like(logits)
    )
    return unmasked_tokens, full_logits

def suppress_invalid_text_tokens(
    text,
    logits,
    start_token=None,
    end_token=None,
    pad_token=None,
    mask_token=None,
    force_aas=False,
):
    # Find the indices of start_token and end_token in tensor text along axis=1
    idx_start = (text == start_token).nonzero(as_tuple=True)[1]
    idx_end = (text == end_token).nonzero(as_tuple=True)[1]

    # For every position other than the index corresponding to the start index, set the values on the start index of dimension=2 to -torch.inf
    if idx_start.nelement() != start_token:
        try:
            mask = idx_start.unsqueeze(1) != torch.arange(
                logits.size(1), device=text.device
            )
            indices = torch.where(mask)
            logits[indices[0], indices[1], start_token] = -torch.inf
        except:
            pass

    # else:
    #     idx_start = torch.zeros(text.size(0), dtype=torch.long)

    # Similarly, for every position other than the index corresponding to the end index, set the values on the end index of dimension=2 to -torch.inf
    if idx_end.nelement() != 0:
        try:
            mask = idx_end.unsqueeze(1) != torch.arange(
                logits.size(1), device=text.device
            )
            indices = torch.where(mask)
            logits[indices[0], indices[1], end_token] = -torch.inf
        except:
            pass

    # else:
    #     idx_end = torch.full((text.size(0),), text.size(1) - 1, dtype=torch.long)

    if pad_token:
        if idx_start.nelement() != 0 and idx_end.nelement() != 0:
            try:
                # For every position between the indices of start_token and end_token, set the values for 1st index of dimension=2 equal to -torch.inf. Any value outside of that range should be set to torch.inf.
                mask = (
                    torch.arange(logits.size(1), device=text.device)
                    >= idx_start.unsqueeze(1)
                ) & (
                    torch.arange(logits.size(1), device=text.device)
                    <= idx_end.unsqueeze(1)
                )

                indices = torch.where(mask)
                logits[indices[0], indices[1], pad_token] = -torch.inf

                indices = torch.where(~mask)
                logits[indices[0], indices[1], pad_token] = torch.inf

            except:
                pass

        elif idx_start.nelement() != 0:
            try:
                mask = torch.arange(
                    logits.size(1), device=text.device
                ) < idx_start.unsqueeze(1)
                indices = torch.where(mask)
                logits[indices[0], indices[1], pad_token] = torch.inf
            except:
                pass

        elif idx_end.nelement() != 0:
            try:
                mask = torch.arange(
                    logits.size(1), device=text.device
                ) > idx_end.unsqueeze(1)
                indices = torch.where(mask)
                logits[indices[0], indices[1], pad_token] = torch.inf
            except:
                pass

    if force_aas:
        if pad_token:
            logits[:, :, pad_token] = -torch.inf
        logits[:, :, 3] = -torch.inf
        logits[:, :, 29:] = -torch.inf

    if mask_token:
        logits[:, :, mask_token] = -torch.inf

    return logits


def detokenize_text(text_embedding, sequence):
    if text_embedding == "esm1b" or text_embedding == "esm2":
        from esm import Alphabet

        alphabet = (
            Alphabet.from_architecture("ESM-1b").get_batch_converter().alphabet.all_toks
        )
    else:
        raise NameError("Detokenization only available for ESM models")

    output_seqs = []

    for batch in sequence:
        converted_seq = [alphabet[idx] for idx in batch]
        converted_seq = "".join(converted_seq)
        output_seqs.append(converted_seq)

    return output_seqs

class ImageEmbedding(nn.Module):
    def __init__(self, num_tokens, dim):
        super(ImageEmbedding, self).__init__()
        self.image_embedding = nn.Embedding(num_tokens, dim)

    def forward(self, image):
        return self.image_embedding(image)


class ModelExtender(nn.Module):
    def __init__(self, vocab, out_features, fixed_embedding=False):
        super(ModelExtender, self).__init__()

        # Initialize the model according to the given vocabulary
        self.vocab = vocab

        if vocab == "esm1b":
            from esm import pretrained

            self.model, _ = pretrained.esm1b_t33_650M_UR50S()
            self.in_features = 1280
        elif vocab == "esm2":
            from esm import pretrained

            if out_features == 320:
                self.model, _ = pretrained.esm2_t6_8M_UR50D()
            elif out_features == 480:
                self.model, _ = pretrained.esm2_t12_35M_UR50D()
            elif out_features == 640:
                self.model, _ = pretrained.esm2_t30_150M_UR50D()
            elif out_features == 1280:
                self.model, _ = pretrained.esm2_t33_650M_UR50D()
            elif out_features == 2560:
                self.model, _ = pretrained.esm2_t36_3B_UR50D()
            else:
                self.model, _ = pretrained.esm2_t33_650M_UR50D()
            self.in_features = self.model.embed_dim

        # Set the number of output features and initialize the scaling layer
        self.out_features = out_features
        self.scale_layer = nn.Linear(self.in_features, self.out_features)

        # Determine whether to freeze the model's parameters
        self.fixed_embedding = fixed_embedding
        if self.fixed_embedding:
            self.model = self.model.eval()

    def forward(self, x, **kwargs):
        # If the model's parameters are fixed, use torch.no_grad()
        if self.fixed_embedding:
            with torch.no_grad():
                if self.vocab == "esm1b" or self.vocab == "esm2":
                    # Reduce sequence length dimension, get top layer representation tensor
                    x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
                        "representations"
                    ][self.model.num_layers]
                    # Tensor shape: (batch_size, hidden_size)
                else:
                    # Get top layer representation tensor
                    x = self.model(x, **kwargs)[0]
                    # Tensor shape: (batch_size, sequence_length, hidden_size)
        else:
            if self.vocab == "esm1b" or self.vocab == "esm2":
                # Reduce sequence length dimension, get top layer representation tensor
                x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
                    "representations"
                ][self.model.num_layers]
                # Tensor shape: (batch_size, hidden_size)
            else:
                # Get top layer representation tensor
                x = self.model(x, **kwargs)[0]
                # Tensor shape: (batch_size, sequence_length, hidden_size)

        # Scale the representation tensor if necessary
        if self.out_features != self.in_features:
            x = self.scale_layer(x)
            # Tensor shape: (batch_size, out_features)

        return x

class CELLE(nn.Module):
    def __init__(
        self,
        *,
        dim,
        vae,  # The VAE model used to encode/decode images
        condition_vae=None,  # An optional VAE model used to condition the image generation
        num_images=2,  # Number of images to generate
        num_text_tokens=30,  # Number of tokens in the text vocabulary
        text_seq_len=1000,  # Maximum length of input text sequence
        depth=16,  # Number of layers in the transformer model
        heads=16,  # Number of attention heads
        dim_head=64,  # Dimensionality of each attention head
        attn_dropout=0.1,  # Dropout rate for attention weights
        ff_dropout=0.1,  # Dropout rate for feedforward layers
        attn_types=None,  # Types of attention to use in the transformer
        causal=False,  # Whether to use causal attention
        loss_cond_weight=1,  # Weight of conditioning loss
        loss_img_weight=1,  # Weight of image generation loss
        stable=False,  # Whether to use divide-by-max normalization in the transformer
        rotary_emb=True,  # Whether to use rotary positional embeddings
        text_embedding="esm2",  # Text embedding to use (esm1b, esm2)
        fixed_embedding=True,  # Whether to fix the text embedding or learn it
        sampling_mode="cosine",  # Sampling mode for the VAE
        linear_project=False,  # Whether to project embeddings linearly
        **kwargs,
    ):
        super().__init__()

        # Set the stable flag
        self.stable = stable

        # If the stable flag is set, initialize the DivideMax layer for normalization
        if stable:
            self.norm_by_max = DivideMax(dim=-1)

        ### Initializing text parameters ###

        # Initialize the text and fixed embeddings
        self.text_embedding = text_embedding
        self.fixed_embedding = fixed_embedding

        # Offset logits index and calculate cross entropy loss
        self.num_text_tokens = num_text_tokens
        self.linear_project = linear_project

        # Add <BOS> and <EOS> tokens to the beginning and end of text sequences
        if text_embedding.lower() in ("esm1b", "esm2"):
            self.text_seq_len = text_seq_len + 2
        else:
            self.text_seq_len = text_seq_len

        # Initialize embeddings for <SEP> token
        self.sep_emb = nn.Embedding(1, dim)

        # Initialize positional embeddings for text sequences and <SEP> token
        self.text_pos_emb = (
            nn.Embedding(self.text_seq_len + 1, dim) if not rotary_emb else always(0)
        )  # +1 for <SEP>

        ### ###

        self.num_images = num_images

        ### Initializing condition parameters ###

        # Initialize the number of condition tokens, condition sequence length, and condition embedding
        if exists(condition_vae):
            condition_size = condition_vae.image_size
            num_condition_tokens = condition_vae.num_tokens
            self.num_condition_tokens = num_condition_tokens
            condition_fmap_size = condition_vae.image_size // (
                2**condition_vae.num_layers
            )
            condition_seq_len = condition_fmap_size**2

            # Initialize ImageEmbedding for condition embedding
            self.condition_emb = ImageEmbedding(num_condition_tokens + 1, dim)

            # Initialize positional embeddings for condition embedding
            self.condition_pos_emb = (
                AxialPositionalEmbedding(
                    dim, axial_shape=(condition_fmap_size, condition_fmap_size)
                )
                if not rotary_emb
                else always(0)
            )

        else:
            condition_fmap_size = 0
            condition_seq_len = 0
            num_condition_tokens = 0

        ### ####

        ### Initializing image parameters ###

        # Initialize the image size, image token size, and sequence length
        self.image_size = vae.image_size
        num_image_tokens = vae.num_tokens
        image_fmap_size = vae.image_size // (2**vae.num_layers)
        image_seq_len = image_fmap_size**2
        self.image_seq_len = image_seq_len
        self.num_image_tokens = num_image_tokens

        # Initialize ImageEmbedding and positional embeddings for image embedding
        self.image_emb = ImageEmbedding(num_image_tokens + 1, dim)  # +1 for <IM_MASK>

        self.image_pos_emb = (
            AxialPositionalEmbedding(
                dim, axial_shape=(image_fmap_size, image_fmap_size)
            )
            if not rotary_emb
            else always(0)
        )

        # Set total sequence length and total tokens
        self.num_condition_tokens = num_condition_tokens
        self.condition_seq_len = condition_seq_len
        # Text Length + <SEP> + Condition Tokens + Image Tokens
        seq_len = self.text_seq_len + 1 + self.condition_seq_len + self.image_seq_len
        total_tokens = (
            num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens + 1
        )
        self.total_tokens = total_tokens
        self.total_seq_len = seq_len

        # Set the VAE and condition VAE for the model
        self.vae = vae.eval()
        self.condition_vae = condition_vae.eval()

        ### ###

        ### Setting discrete ids ###
        # Initialize text embedding based on the given text_embedding parameter
        if text_embedding == "esm1b" or text_embedding == "esm2":
            self.text_mask_token = 32
            self.pad_token = 1
            self.text_emb = ModelExtender(text_embedding, dim, fixed_embedding)
        else:
            raise ValueError("Only ESM models are supported.")

        # Set token indices for text, condition, and image sequences
        self.sep_token = num_text_tokens
        self.cond_mask_token = num_condition_tokens
        self.image_mask_token = num_image_tokens

        # Create indices for sequence and logits dimensions
        self.seq_range = torch.arange(seq_len)
        self.logits_range = torch.arange(total_tokens)

        # Reshape sequence and logits indices
        self.seq_range = rearrange(self.seq_range, "n -> () n ()")
        self.logits_range = rearrange(self.logits_range, "d -> () () d")

        # Create a mask to exclude invalid token positions from the model output
        # e.g. no image tokens where sequence tokens should be
        logits_mask = (
            # Mask text tokens beyond text_seq_len and invalid logits_range
            (
                (self.seq_range < self.text_seq_len)
                & (self.logits_range < num_text_tokens)
                & (self.logits_range != self.text_mask_token)
            )
            |
            # Mask [SEP] token after text
            (
                (self.seq_range == self.text_seq_len)
                & (self.logits_range == num_text_tokens)
            )
            |
            # Mask condition tokens beyond text_seq_len+1 ([SEP]) and invalid logits_range
            (
                (self.seq_range >= self.text_seq_len + 1)
                & (self.seq_range < self.text_seq_len + 1 + condition_seq_len)
                & (self.logits_range >= num_text_tokens + 1)
                & (self.logits_range < num_text_tokens + 1 + num_condition_tokens)
            )
            |
            # Mask image tokens beyond num_text_tokens+num_condition_tokens+1
            (
                (self.seq_range >= self.text_seq_len + 1 + condition_seq_len)
                & (self.logits_range >= num_text_tokens + 1 + num_condition_tokens + 1)
                & (
                    self.logits_range
                    < num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens
                )
            )
        )

        # Invert the mask
        logits_mask = ~logits_mask

        # Register the buffer with the logits_mask
        self.register_buffer("logits_mask", logits_mask, persistent=False)

        ### ###

        # Initialize the Transformer model with given parameters
        self.transformer = Transformer(
            dim=dim,
            causal=causal,
            seq_len=seq_len,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            image_fmap_size=image_fmap_size + condition_fmap_size,
            num_images=num_images,
            stable=stable,
            rotary_emb=rotary_emb,
        )

        # Initialize the linear layers for converting transformer output to logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        # Set instance variables for weights and critic
        self.loss_img_weight = loss_img_weight
        self.loss_cond_weight = loss_cond_weight
        self.gamma = gamma_func(sampling_mode)

    def embed_and_transform(self, inputs, masks, return_encoding=False):
        text, condition, image = inputs
        device = text.device
        text_mask, _, image_mask = masks

        text_labels = text.clone()
        text = torch.where(
            text_mask, self.text_mask_token * torch.ones_like(text, device=device), text
        )

        tokens = self.text_emb(text)

        # Add SEP token

        sep_token_emb = self.sep_emb(
            torch.zeros((tokens.shape[0], 1), dtype=torch.long, device=device)
        )
        tokens = torch.cat((tokens, sep_token_emb), dim=1)
        tokens += self.text_pos_emb(torch.arange(text.shape[1] + 1, device=device))

        with torch.no_grad():
            if self.linear_project:
                b = condition.shape[0]
                condition, _, [_, _, condition_labels] = self.condition_vae.encode(
                    condition
                )
                condition_labels = rearrange(condition_labels, "(b n) -> b n", b=b)

            else:
                condition_labels = condition
                if condition.dtype == torch.float:
                    condition_labels = self.condition_vae.get_codebook_indices(
                        condition
                    )
                condition = condition_labels.clone()

        condition_emb = self.condition_emb(condition)
        condition_emb += self.condition_pos_emb(condition_emb)
        tokens = torch.cat((tokens, condition_emb), dim=1)

        with torch.no_grad():
            if self.linear_project:
                b = image.shape[0]
                image, _, [_, _, image_labels] = self.vae.encode(image)
                image_labels = rearrange(image_labels, "(b n) -> b n", b=b)

            else:
                image_labels = image
                if image.dtype == torch.float:
                    image_labels = self.vae.get_codebook_indices(image)
                image = torch.where(
                    image_mask,
                    self.image_mask_token
                    * torch.ones_like(image_labels, device=device),
                    image_labels,
                )

        image_emb = self.image_emb(image)

        image_emb += self.image_pos_emb(image_emb)
        tokens = torch.cat((tokens, image_emb), dim=1)

        if self.stable:
            alpha = 0.1
            tokens = tokens * alpha + tokens.detach() * (1 - alpha)

        out = self.transformer(tokens)

        if self.stable:
            out = self.norm_by_max(out)

        logits = self.to_logits(out)

        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(self.logits_mask, max_neg_value)

        if return_encoding:
            return logits, out, [text_labels, condition_labels, image_labels]
        else:
            return logits, None, [text_labels, condition_labels, image_labels]

    def forward(
        self,
        text,
        condition=None,
        image=None,
        return_loss=False,
        return_encoding=False,
    ):
        batch_size, device = text.shape[0], text.device

        # Check that image is supplied when training
        assert exists(image), "when training, image must be supplied"

        # Check that image dimensions match the expected dimensions
        assert tuple(image.shape[1:]) == (
            self.vae.channels,
            self.image_size,
            self.image_size,
        ), f"invalid image of dimensions {image.shape} passed in during training"

        # Generate masks for text, condition, and image

        # text_mask = generate_mask(self.gamma, batch_size, self.text_seq_len, device)

        text_mask = generate_mask(
            gamma_func("scaled-cosine"), batch_size, self.text_seq_len, device
        )

        image_mask = generate_mask(self.gamma, batch_size, self.image_seq_len, device)

        # Embed and transform inputs
        logits, _, labels = self.embed_and_transform(
            [text, condition, image],
            [text_mask, None, image_mask],
            return_encoding,
        )

        # If not returning loss, return the logits
        if not return_loss:
            return logits

        # Separate labels
        text, condition, image = labels

        # Add SEP token to end of text label
        sep_token = torch.tensor(self.sep_token, device=device).repeat(
            text.shape[0], 1
        )
        labels = torch.cat([text, sep_token], dim=1)

        # If condition exists and condition vae is defined, add the condition to the labels
        if exists(condition) and exists(self.condition_vae):
            offsetted_condition = condition + self.num_text_tokens + 1
            labels = torch.cat((labels, offsetted_condition), dim=1)

        # Add image to the labels
        offsetted_image = (
            image + self.num_text_tokens + 1 + self.num_condition_tokens + 1
        )
        labels = torch.cat((labels, offsetted_image), dim=1)

        # Rearrange logits for cross-entropy loss calculation
        # Logits size: (batch_size, vocab_size, total_seq_len)
        # Labels size: (batch_size, total_seq_len)
        logits = rearrange(logits, "b n c -> b c n")

        # Calculate cross-entropy loss for text and image
        loss_text = F.cross_entropy(
            logits[:, :, : self.text_seq_len],
            labels[:, : self.text_seq_len],
            reduction="none",
        )[text_mask].mean()

        loss_img = F.cross_entropy(
            logits[:, :, self.text_seq_len + 1 + self.condition_seq_len :],
            labels[:, self.text_seq_len + 1 + self.condition_seq_len :],
            reduction="none",
        )[image_mask].mean()

        # Calculate total loss
        loss = (loss_text + self.loss_img_weight * loss_img) / (
            self.loss_img_weight + 1
        )

        loss_dict = {
            "loss_text": loss_text,
            # "loss_cond": loss_cond,
            "loss_img": loss_img,
            "loss": torch.nan_to_num(loss, 0.0, 0.0, 0.0),
        }

        return loss, loss_dict, None

    def create_tensors(self, text, condition, image):
        """
        This function creates tensors for text, condition, and image when they are not provided as inputs to the sample function.
        """
        device = next(
            filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
            None,
        ).device

        if not isinstance(text, torch.Tensor):
            text = (
                torch.ones(1, self.text_seq_len, device=device, dtype=torch.long)
                * self.text_mask_token
            )

        if not isinstance(condition, torch.Tensor):
            condition = (
                torch.ones(1, self.condition_seq_len, device=device, dtype=torch.long)
                * self.cond_mask_token
            )
        else:
            with torch.no_grad():
                condition = self.condition_vae.get_codebook_indices(condition)

        if not isinstance(image, torch.Tensor):
            image = (
                torch.ones(1, self.image_seq_len, device=device, dtype=torch.long)
                * self.image_mask_token
            )
        else:
            with torch.no_grad():
                image = self.vae.get_codebook_indices(image)

        return text, condition, image

    @torch.no_grad()
    @eval_decorator
    def sample(
        self,
        text=None,
        condition=None,
        image=None,
        temperature=1.0,
        filter_thres=0.9,
        progress=False,
        timesteps=1,
        force_aas=True,
    ):
        # ensure timesteps is a positive integer
        assert int(timesteps) > 0
        # set model and VAEs to evaluation mode
        self.eval()
        vae = self.vae.eval()
        if progress == True:
            progress = tqdm
        else:
            progress = lambda x: x

        # ensure that at least one of text, condition, or image is supplied
        assert (
            isinstance(text, torch.Tensor)
            or isinstance(condition, torch.Tensor)
            or isinstance(image, torch.Tensor)
        ), "some data must be supplied"

        # convert text, condition, and image to tensors if they aren't already
        text, condition, image = self.create_tensors(text, condition, image)

        # determine the maximum batch size of the input tensors
        batch_size = max(text.shape[0], condition.shape[0], image.shape[0])

        # match the batch sizes of text, condition, and image
        text, condition, image = match_batch_size(text, condition, image, batch_size)

        # determine the device of the tensors
        device = next(
            filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
            None,
        ).device

        assert text.shape[0] == condition.shape[0] == image.shape[0]

        # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device

        # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
        full_text_logits = torch.zeros(
            batch_size, self.text_seq_len, self.num_text_tokens
        ).to(device)

        # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
        full_text_logits = full_text_logits.scatter_(
            dim=-1, index=text.unsqueeze(-1), value=1
        )
        # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
        full_image_logits = torch.zeros(
            batch_size, self.image_seq_len, self.num_image_tokens + 1
        ).to(device)

        # Remove the last token from each image sequence by setting full_image_logits to its first num_image_tokens elements
        full_image_logits = full_image_logits.scatter_(
            dim=-1, index=image.unsqueeze(-1), value=1
        )

        # cut off mask token
        full_image_logits = full_image_logits[:, :, : self.num_image_tokens]

        count = 0

        for timestep in progress(torch.linspace(0, 1, timesteps)):
            # Create masks for the text, condition, and image tensors
            text_mask = text == self.text_mask_token
            cond_mask = condition == self.cond_mask_token
            image_mask = image == self.image_mask_token

            # Calculate logits and samples using the calculate_logits function
            logits, sample = calculate_logits(
                [text, condition, image],
                [text_mask, cond_mask, image_mask],
                self.embed_and_transform,
                filter_thres,
                temperature,
            )

            # Calculate the number of masked tokens in the text and image tensors
            num_masked_text_tokens = torch.sum(text_mask, dim=1)[0]
            num_masked_image_tokens = torch.sum(image_mask, dim=1)[0]

            # If there are masked text tokens, unmask them using unmask_tokens and fill the full text logits tensor with -inf for unmasked tokens
            if num_masked_text_tokens.any() > 0:
                text, full_text_logits = unmask_tokens(
                    text,
                    text_mask,
                    num_masked_text_tokens,
                    logits[:, : self.text_seq_len, : self.num_text_tokens],
                    sample[:, : self.text_seq_len, : self.num_text_tokens],
                    timestep,
                    timesteps,
                    self.gamma,
                    suppress_invalid_text_tokens,
                    self.pad_token,
                    self.text_mask_token,
                    force_aas=force_aas,
                )
                full_text_logits = full_text_logits.masked_fill(
                    ~text_mask.unsqueeze(-1), -torch.inf
                )

            # If there are masked image tokens, unmask them using unmask_tokens and fill the full image logits tensor with -inf for unmasked tokens
            if num_masked_image_tokens > 0:
                image, full_image_logits = unmask_tokens(
                    image,
                    image_mask,
                    num_masked_image_tokens,
                    logits[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
                    sample[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
                    timestep,
                    timesteps,
                    self.gamma,
                )
                full_text_logits = full_text_logits.masked_fill(
                    ~text_mask.unsqueeze(-1), -torch.inf
                )

        # Generate heatmap
        with torch.no_grad():
            # Normalize full image logits tensor
            full_image_logits /= torch.max(
                torch.abs(full_image_logits), dim=-1, keepdim=True
            ).values

            # Apply quantize embedding to full image logits tensor
            full_image_logits = torch.matmul(
                full_image_logits, self.vae.model.quantize.embedding.weight
            )

            # Rearrange full image logits tensor
            h = int(self.image_seq_len**0.5)
            full_image_logits = rearrange(
                full_image_logits, "b (h w) c -> b c h w", h=h
            )

            # Decode full image logits tensor
            full_image_logits = self.vae.model.decode(full_image_logits)

            # Add clipping to full image logits tensor
            max_val = torch.max(full_image_logits.view(batch_size, -1), dim=-1)[0]
            min_val = torch.min(full_image_logits.view(batch_size, -1), dim=-1)[0]
            full_image_logits += torch.clip(1 - max_val, 0, float("inf")).view(
                batch_size, 1, 1, 1
            )
            full_image_logits += torch.clip(0 - min_val, float("-inf"), 0).view(
                batch_size, 1, 1, 1
            )

            # Clip full image logits tensor values to the range [0, 1]
            full_image_logits = torch.clip(full_image_logits, 0, 1)

        # Return text tensor, detokenized text tensor, full text logits tensor,
        # binary image tensor, and full image logits tensor
        return (
            text,
            detokenize_text(self.text_embedding, text),
            full_text_logits,
            1.0 * (vae.decode(image) > 0.5),
            full_image_logits,
        )

    @torch.no_grad()
    @eval_decorator
    def sample_text(
        self,
        text=False,
        condition=False,
        image=False,
        temperature=1.0,
        filter_thres=0.9,
        progress=False,
        n_unmask=1,
        place_amino=True,
        force_aas=False,
    ):
        # set model and VAEs to evaluation mode
        self.eval()

        # ensure that at least one of text, condition, or image is supplied
        assert (
            isinstance(text, torch.Tensor)
            or isinstance(condition, torch.Tensor)
            or isinstance(image, torch.Tensor)
        ), "some data must be supplied"

        # convert text, condition, and image to tensors if they aren't already
        text, condition, image = self.create_tensors(text, condition, image)

        # determine the maximum batch size of the input tensors
        batch_size = max(text.shape[0], condition.shape[0], image.shape[0])

        # match the batch sizes of text, condition, and image
        text, condition, image = match_batch_size(text, condition, image, batch_size)

        # determine the device of the tensors
        device = next(
            filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
            None,
        ).device

        assert text.shape[0] == condition.shape[0] == image.shape[0]

        # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device

        # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
        full_text_logits = torch.zeros(
            batch_size, self.text_seq_len, self.num_text_tokens
        ).to(device)

        # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
        full_text_logits = full_text_logits.scatter_(
            dim=-1, index=text.unsqueeze(-1), value=1
        )

        text_mask = text == self.text_mask_token
        cond_mask = condition == self.cond_mask_token
        image_mask = image == self.image_mask_token

        mask_indices = text_mask.nonzero()
        non_mask_indices = (~text_mask).nonzero()

        # figure out the center of the amino acids to determine generation direction
        central_protein_index = torch.tensor(
            [
                torch.median(
                    non_mask_indices[torch.where(non_mask_indices[:, 0] == idx)][:, -1]
                )
                for idx in range(batch_size)
            ]
        )

        count = 1

        run_mask = text_mask
        if progress:
            pbar = progress(total=torch.sum(run_mask).item())
        while torch.sum(run_mask) > 0:
            logits, sample = calculate_logits(
                [text, condition, image],
                [text_mask, cond_mask, image_mask],
                self.embed_and_transform,
                filter_thres,
                temperature,
            )

            # sub_sample: [batch_size, text_seq_len, num_text_tokens]
            sub_sample = sample[:, : self.text_seq_len, : self.num_text_tokens]
            sub_sample = sub_sample.masked_fill(~text_mask.unsqueeze(-1), -torch.inf)
            sub_sample = suppress_invalid_text_tokens(
                text, sub_sample, 0, 2, self.pad_token, self.text_mask_token, force_aas
            )
            # calculate % to unmasked
            # get most likely token and probability for each position

            for idx in range(batch_size):
                selected_mask_indices = mask_indices[
                    torch.where(mask_indices[:, 0] == idx)
                ][:, -1]

                # Generate to the left
                if selected_mask_indices[-count] < central_protein_index[idx]:
                    unmask_index = selected_mask_indices[-count]
                    left_sample = max(0, (unmask_index + 1) - n_unmask)
                    right_sample = min(unmask_index + 1, self.text_seq_len - 1)
                    central_protein_index[idx] = max(
                        0, central_protein_index[idx] - 0.5 * n_unmask
                    )

                # Generate to the right
                elif selected_mask_indices[count - 1] > central_protein_index[idx]:
                    unmask_index = selected_mask_indices[count - 1]
                    left_sample = max(0, unmask_index)
                    right_sample = min(unmask_index + n_unmask, self.text_seq_len - 1)
                    central_protein_index[idx] = min(
                        central_protein_index[idx] + 0.5 * n_unmask,
                        self.text_seq_len - 1,
                    )

                # save logits for relevant position
                full_text_logits[
                    idx, left_sample:right_sample, : self.text_seq_len - 1
                ] = logits[idx, left_sample:right_sample, : self.num_text_tokens]

                run_mask[idx, left_sample:right_sample] = False

                # you may want to resample the amino acids or calculate marginal probs
                # if so, set place_amino to false
                if place_amino:
                    text[idx, left_sample:right_sample] = torch.where(
                        text[idx, left_sample:right_sample] == self.text_mask_token,
                        sub_sample[
                            idx, left_sample:right_sample, : self.num_text_tokens
                        ].argmax(dim=-1),
                        text[idx, left_sample:right_sample],
                    )

            text_mask = run_mask

            count += n_unmask

            if progress:
                pbar.update(n_unmask)
        if progress:
            pbar.close()

        return (
            text,
            detokenize_text(self.text_embedding, text),
            full_text_logits,
        )
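Two illustrative sketches for the file above (not part of the commit, and assuming the repo's requirements are installed). First, generate_mask: gamma_func lives in celle.utils and is not shown in this commit, so a cosine schedule gamma(r) = cos(pi*r/2) is assumed here purely for illustration. Second, the iterative unmasking schedule used during sampling: at each timestep t in linspace(0, 1, T), roughly (1 - gamma(t)) of the still-masked positions are revealed, with a floor of one token per step.

import math
import torch
from celle.celle import generate_mask, calc_unmask_probs

cosine_gamma = lambda r: math.cos(float(r) * math.pi / 2)  # assumed schedule

# 1) random masks: the same gamma-determined count of True entries in each batch row
mask = generate_mask(cosine_gamma, batch_size=2, length=64, device="cpu")
print(mask.shape, mask.sum(dim=1))  # torch.Size([2, 64]); identical counts per row

# 2) tokens revealed per step, out of 100 still-masked positions
for timestep in torch.linspace(0, 1, 4):
    prob = calc_unmask_probs(timestep, timesteps=4, gamma_func=cosine_gamma)
    print(float(timestep), max(1, math.ceil(prob * 100)))
# approximately: 0.00 -> 1, 0.33 -> 14, 0.67 -> 50, 1.00 -> 100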
celle/reversible.py
ADDED
@@ -0,0 +1,36 @@
import torch.nn as nn

# for routing arguments into the functions of the reversible layer
def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(
            zip(routed_args, router[key])
        ):
            new_f_args, new_g_args = map(
                lambda route: ({key: val} if route else {}), routes
            )
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args

class SequentialSequence(nn.Module):
    def __init__(self, layers, args_route={}, layer_dropout=0.0):
        super().__init__()
        assert all(
            len(route) == len(layers) for route in args_route.values()
        ), "each argument route map must have the same depth as the number of sequential layers"
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        for (f, g), (f_args, g_args) in layers_and_args:
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x
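A minimal usage sketch of SequentialSequence and route_args (illustrative only; the linear layers and sizes below are stand-ins, not modules from this commit):

import torch
import torch.nn as nn
from celle.reversible import SequentialSequence

dim = 8
# Two (f, g) pairs standing in for (attention, feed-forward) blocks.
layers = nn.ModuleList(
    [nn.ModuleList([nn.Linear(dim, dim), nn.Linear(dim, dim)]) for _ in range(2)]
)

# Route the keyword argument "mask" only to the f branch of each pair.
seq = SequentialSequence(layers, args_route={"mask": ((True, False),) * 2})

x = torch.randn(1, 4, dim)
out = seq(x)          # each pair adds f(x) and then g(x) as residuals
print(out.shape)      # torch.Size([1, 4, 8])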
celle/transformer.py
ADDED
@@ -0,0 +1,213 @@
from functools import partial

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

from celle.reversible import SequentialSequence
from celle.attention import Attention

from rotary_embedding_torch import RotaryEmbedding, broadcat
from celle.utils import exists, default, cast_tuple

# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


# layer norm
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)


# feed forward


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, dropout=0.0, mult=4.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x):
        return self.net(x)


# main transformer class
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal=True,
        heads=8,
        dim_head=64,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        image_fmap_size=None,
        num_images=None,
        stable=False,
        rotary_emb=True,
    ):
        super().__init__()
        layers = nn.ModuleList([])

        self.seq_len = seq_len
        self.image_fmap_size = image_fmap_size

        for ind in range(depth):

            attn_class = partial(Attention, stable=stable)

            attn = attn_class(
                dim,
                causal=causal,
                seq_len=seq_len,
                heads=heads,
                dim_head=dim_head,
                dropout=attn_dropout,
            )

            ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)

            layers.append(
                nn.ModuleList(
                    [
                        LayerScale(
                            dim, ind + 1, PreNorm(dim, attn)
                        ),
                        LayerScale(
                            dim, ind + 1, PreNorm(dim, ff)
                        ),
                    ]
                )
            )

        # pairs arguments with attention layer
        route_attn = ((True, False),) * depth
        attn_route_map = {
            "mask": route_attn,
            "rotary_pos_emb": route_attn,
        }

        self.layers = SequentialSequence(layers, args_route=attn_route_map)

        # generate positional embeddings for rotary

        pos_emb = None
        if rotary_emb:
            rot_dim = dim_head // 3
            img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images

            text_len = seq_len - img_seq_len + 1

            text_pos_emb = RotaryEmbedding(dim=rot_dim)

            img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for="pixel")

            text_freqs = text_pos_emb(torch.arange(text_len))

            img_to_text_freqs = text_pos_emb(
                torch.full((img_seq_len,), 8192)
            )  # image is given a position far away from text

            text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0)

            img_freqs_axial = img_axial_pos_emb(
                torch.linspace(-1, 1, steps=image_fmap_size)
            )

            if num_images > 1:
                split_img_freqs_axial = torch.split(
                    img_freqs_axial, image_fmap_size // num_images, dim=0
                )

                split_img_freqs = [
                    broadcat(
                        (
                            rearrange(img_freqs_axial_per_image, "i d -> i () d"),
                            rearrange(img_freqs_axial_per_image, "j d -> () j d"),
                        ),
                        dim=-1,
                    )
                    for img_freqs_axial_per_image in split_img_freqs_axial
                ]

                split_img_freqs = [
                    rearrange(img_freqs_per_image, "h w d -> (h w) d")
                    for img_freqs_per_image in split_img_freqs
                ]

                # concat per image-image_freqs

                img_freqs = torch.cat(split_img_freqs, dim=0)

            elif num_images == 1:
                img_freqs = broadcat(
                    (
                        rearrange(img_freqs_axial, "i d -> i () d"),
                        rearrange(img_freqs_axial, "j d -> () j d"),
                    ),
                    dim=-1,
                )

                img_freqs = rearrange(img_freqs, "h w d -> (h w) d")

            else:
                assert False, "num_images must be int greater than 0"
            self.img_axial_pos_emb = img_axial_pos_emb
            self.text_pos_emb = text_pos_emb

            text_axial_freqs = img_axial_pos_emb(
                torch.full((text_len,), -10.0)
            )  # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1]

            text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim=-1)

            img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0)

            pos_emb = torch.cat((text_freqs, img_freqs), dim=-1)

            pos_emb = rearrange(pos_emb, "n d -> () n d")

        self.register_buffer("pos_emb", pos_emb)

    def forward(self, x, **kwargs):
        return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs)
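A rough construction sketch for the Transformer above. All sizes here are illustrative assumptions rather than the settings of the released checkpoints, and it presumes the Attention module accepts the keyword arguments routed to it by the class above:

import torch
from celle.transformer import Transformer

image_fmap_size = 32   # assumed combined axial feature-map size across images
num_images = 2
img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images  # 512 image tokens
seq_len = 521          # chosen so that text_len = seq_len - img_seq_len + 1 = 10

model = Transformer(
    dim=64,
    depth=2,
    seq_len=seq_len,
    heads=4,
    dim_head=32,
    image_fmap_size=image_fmap_size,
    num_images=num_images,
    rotary_emb=True,
)

x = torch.randn(1, seq_len, 64)
out = model(x)  # (1, seq_len, 64)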
celle/utils.py
ADDED
@@ -0,0 +1,193 @@
import torch
from math import pi
from PIL import Image               # used by process_image below
from torchvision import transforms  # used by process_image below

# Define helper functions
def exists(val):
    """Check if a variable exists"""
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    """If a value exists, return it; otherwise, return a default value"""
    return val if exists(val) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def cast_tuple(val, depth=1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth


def is_empty(t):
    """Check if a tensor is empty"""
    # Return True if the number of elements in the tensor is zero, else False
    return t.nelement() == 0


def masked_mean(t, mask, dim=1):
    """
    Compute the mean of a tensor, masked by a given mask

    Args:
        t (torch.Tensor): input tensor of shape (batch_size, seq_len, hidden_dim)
        mask (torch.Tensor): mask tensor of shape (batch_size, seq_len)
        dim (int): dimension along which to compute the mean (default=1)

    Returns:
        torch.Tensor: masked mean tensor of shape (batch_size, hidden_dim)
    """
    t = t.masked_fill(~mask[:, :, None], 0.0)
    return t.sum(dim=1) / mask.sum(dim=1)[..., None]


def set_requires_grad(model, value):
    """
    Set whether or not the model's parameters require gradients

    Args:
        model (torch.nn.Module): the PyTorch model to modify
        value (bool): whether or not to require gradients
    """
    for param in model.parameters():
        param.requires_grad = value


def eval_decorator(fn):
    """
    Decorator function to evaluate a given function

    Args:
        fn (callable): function to evaluate

    Returns:
        callable: the decorated function
    """

    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out

    return inner


def log(t, eps=1e-20):
    """
    Compute the natural logarithm of a tensor

    Args:
        t (torch.Tensor): input tensor
        eps (float): small value to add to prevent taking the log of 0 (default=1e-20)

    Returns:
        torch.Tensor: the natural logarithm of the input tensor
    """
    return torch.log(t + eps)


def gumbel_noise(t):
    """
    Generate Gumbel noise

    Args:
        t (torch.Tensor): input tensor

    Returns:
        torch.Tensor: a tensor of Gumbel noise with the same shape as the input tensor
    """
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=0.9, dim=-1):
    """
    Sample from a Gumbel-softmax distribution

    Args:
        t (torch.Tensor): input tensor of shape (batch_size, num_classes)
        temperature (float): temperature for the Gumbel-softmax distribution (default=0.9)
        dim (int): dimension along which to sample (default=-1)

    Returns:
        torch.Tensor: a tensor of samples from the Gumbel-softmax distribution with the same shape as the input tensor
    """
    return (t / max(temperature, 1e-10)) + gumbel_noise(t)


def top_k(logits, thres=0.5):
    """
    Return a tensor where all but the top k values are set to negative infinity

    Args:
        logits (torch.Tensor): input tensor of shape (batch_size, num_classes)
        thres (float): threshold for the top k values (default=0.5)

    Returns:
        torch.Tensor: a tensor with the same shape as the input tensor, where all but the top k values are set to negative infinity
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(-1, ind, val)
    return probs


def gamma_func(mode="cosine", scale=0.15):
    """Return a function that takes a single input r and returns a value based on the selected mode"""

    # Define a different function based on the selected mode
    if mode == "linear":
        return lambda r: 1 - r
    elif mode == "cosine":
        return lambda r: torch.cos(r * pi / 2)
    elif mode == "square":
        return lambda r: 1 - r**2
    elif mode == "cubic":
        return lambda r: 1 - r**3
    elif mode == "scaled-cosine":
        return lambda r: scale * (torch.cos(r * pi / 2))
    else:
        # Raise an error if the selected mode is not implemented
        raise NotImplementedError


class always:
    """Helper class to always return a given value"""

    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        return self.val


class DivideMax(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim=self.dim, keepdim=True).detach()
        return x / maxes


def process_image(image_path):
    image = Image.open(image_path)
    transform = transforms.Compose([
        transforms.RandomCrop(256),
        transforms.ToTensor()
    ])
    image_tensor = transform(image)
    if image_tensor.shape[0] > 1:
        image_tensor = torch.mean(image_tensor, dim=0, keepdim=True)
    return image_tensor.unsqueeze(0)
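A short sketch of how the sampling helpers above compose (arbitrary toy values; note that gumbel_sample returns perturbed logits, so the caller still takes an argmax):

import torch
from celle.utils import top_k, gumbel_sample, gamma_func

logits = torch.randn(2, 5, 30)            # (batch, positions, vocab)
filtered = top_k(logits, thres=0.5)       # keep the top half of the logits per position
tokens = gumbel_sample(filtered, temperature=0.9).argmax(dim=-1)

gamma = gamma_func("cosine")              # masking schedule: fraction masked at progress r
print(tokens.shape, gamma(torch.tensor(0.25)))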
celle/vae.py
ADDED
@@ -0,0 +1,112 @@
from math import sqrt, log
from omegaconf import OmegaConf
import importlib

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange

# helpers methods


def load_model(path):
    with open(path, "rb") as f:
        return torch.load(f, map_location=torch.device("cpu"))


def map_pixels(x, eps=0.1):
    return (1 - 2 * eps) * x + eps


def unmap_pixels(x, eps=0.1):
    return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)


def make_contiguous(module):
    with torch.no_grad():
        for param in module.parameters():
            param.set_(param.contiguous())


# VQGAN from Taming Transformers paper
# https://arxiv.org/abs/2012.09841


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


class VQGanVAE(nn.Module):
    def __init__(self, vqgan_model_path=None, vqgan_config_path=None, channels=1):
        super().__init__()

        assert vqgan_config_path is not None

        model_path = vqgan_model_path
        config_path = vqgan_config_path

        config = OmegaConf.load(config_path)

        model = instantiate_from_config(config["model"])

        if vqgan_model_path:

            state = torch.load(model_path, map_location="cpu")["state_dict"]
            model.load_state_dict(state, strict=True)

            print(f"Loaded VQGAN from {model_path} and {config_path}")

        self.model = model

        # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
        f = (
            config.model.params.ddconfig.resolution
            / config.model.params.ddconfig.attn_resolutions[0]
        )
        self.num_layers = int(log(f) / log(2))
        self.image_size = config.model.params.ddconfig.resolution
        self.num_tokens = config.model.params.n_embed
        # self.is_gumbel = isinstance(self.model, GumbelVQ)
        self.is_gumbel = False
        self.channels = config.model.params.ddconfig.in_channels

    def encode(self, img):
        return self.model.encode(img)

    def get_codebook_indices(self, img):
        b = img.shape[0]
        # img = (2 * img) - 1
        _, _, [_, _, indices] = self.encode(img)
        if self.is_gumbel:
            return rearrange(indices, "b h w -> b (h w)", b=b)
        return rearrange(indices, "(b n) -> b n", b=b)

    def decode(self, img_seq):
        b, n = img_seq.shape
        one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float()
        z = (
            one_hot_indices @ self.model.quantize.embed.weight
            if self.is_gumbel
            else (one_hot_indices @ self.model.quantize.embedding.weight)
        )

        z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n)))
        img = self.model.decode(z)

        # img = (img.clamp(-1.0, 1.0) + 1) * 0.5
        return img

    def forward(self, img, optimizer_idx=1):
        return self.model.training_step(img, optimizer_idx=optimizer_idx)
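An illustrative round trip through the wrapper; the paths are placeholders for a taming-transformers config and checkpoint, not files included in this commit:

import torch
from celle.vae import VQGanVAE

vae = VQGanVAE(
    vqgan_model_path="path/to/vqgan.ckpt",          # placeholder
    vqgan_config_path="path/to/vqgan_config.yaml",  # placeholder
)

img = torch.randn(1, vae.channels, vae.image_size, vae.image_size)
tokens = vae.get_codebook_indices(img)  # (1, num_image_tokens)
recon = vae.decode(tokens)              # decoded back to image space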
celle_main.py
ADDED
@@ -0,0 +1,619 @@
import os
import numpy as np

import torch
import torch.random
from torch.optim import AdamW
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer

from dataloader import CellLoader
from celle import VQGanVAE, CELLE
from omegaconf import OmegaConf
import argparse, os, sys, datetime, glob

from celle.celle import gumbel_sample, top_k

torch.random.manual_seed(42)
np.random.seed(42)

from celle_taming_main import (
    instantiate_from_config,
    nondefault_trainer_args,
    get_parser,
)


class CellDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_csv,
        dataset,
        sequence_mode="standard",
        vocab="bert",
        crop_size=256,
        resize=600,
        batch_size=1,
        threshold="median",
        text_seq_len=1000,
        num_workers=1,
        **kwargs,
    ):
        super().__init__()

        self.data_csv = data_csv
        self.dataset = dataset
        self.protein_sequence_length = 0
        self.image_folders = []
        self.crop_size = crop_size
        self.resize = resize
        self.batch_size = batch_size
        self.sequence_mode = sequence_mode
        self.threshold = threshold
        self.text_seq_len = int(text_seq_len)
        self.vocab = vocab
        self.num_workers = num_workers if num_workers is not None else batch_size * 2

    def setup(self, stage=None):
        # called on every GPU
        self.cell_dataset_train = CellLoader(
            data_csv=self.data_csv,
            dataset=self.dataset,
            crop_size=self.crop_size,
            resize=self.resize,
            split_key="train",
            crop_method="random",
            sequence_mode=self.sequence_mode,
            vocab=self.vocab,
            text_seq_len=self.text_seq_len,
            threshold=self.threshold,
        )

        self.cell_dataset_val = CellLoader(
            data_csv=self.data_csv,
            dataset=self.dataset,
            crop_size=self.crop_size,
            resize=self.resize,
            crop_method="center",
            split_key="val",
            sequence_mode=self.sequence_mode,
            vocab=self.vocab,
            text_seq_len=self.text_seq_len,
            threshold=self.threshold,
        )

    def prepare_data(self):

        pass

    def train_dataloader(self):
        return DataLoader(
            self.cell_dataset_train,
            num_workers=self.num_workers,
            shuffle=True,
            batch_size=self.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self.cell_dataset_val,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )

    # def test_dataloader(self):
    #     transforms = ...
    #     return DataLoader(self.test, batch_size=64)


class CELLE_trainer(pl.LightningModule):
    def __init__(
        self,
        vqgan_model_path,
        vqgan_config_path,
        ckpt_path=None,
        image_key="threshold",
        condition_model_path=None,
        condition_config_path=None,
        num_images=2,
        dim=2,
        num_text_tokens=30,
        text_seq_len=1000,
        depth=16,
        heads=16,
        dim_head=64,
        attn_dropout=0.1,
        ff_dropout=0.1,
        attn_types="full",
        loss_img_weight=7,
        stable=False,
        rotary_emb=True,
        text_embedding="bert",
        fixed_embedding=True,
        loss_cond_weight=1,
        learning_rate=3e-4,
        monitor="val_loss",
    ):
        super().__init__()

        vae = VQGanVAE(
            vqgan_model_path=vqgan_model_path, vqgan_config_path=vqgan_config_path
        )

        self.image_key = image_key

        if condition_config_path:
            condition_vae = VQGanVAE(
                vqgan_model_path=condition_model_path,
                vqgan_config_path=condition_config_path,
            )
        else:
            condition_vae = None

        self.celle = CELLE(
            dim=dim,
            vae=vae,  # automatically infer (1) image sequence length and (2) number of image tokens
            condition_vae=condition_vae,
            num_images=num_images,
            num_text_tokens=num_text_tokens,  # vocab size for text
            text_seq_len=text_seq_len,  # text sequence length
            depth=depth,  # should aim to be 64
            heads=heads,  # attention heads
            dim_head=dim_head,  # attention head dimension
            attn_dropout=attn_dropout,  # attention dropout
            ff_dropout=ff_dropout,  # feedforward dropout
            loss_img_weight=loss_img_weight,
            stable=stable,
            rotary_emb=rotary_emb,
            text_embedding=text_embedding,
            fixed_embedding=fixed_embedding,
            loss_cond_weight=loss_cond_weight,
        )

        self.learning_rate = learning_rate
        self.num_text_tokens = num_text_tokens
        self.num_images = num_images

        if monitor is not None:
            self.monitor = monitor

        ignore_keys = []

        if condition_model_path:
            ignore_keys.append("celle.condition_vae")

        if vqgan_model_path:
            ignore_keys.append("celle.vae")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        ckpt = sd.copy()
        for k in sd.keys():
            for ik in ignore_keys:
                if k.startswith(ik):
                    # print("Deleting key {} from state_dict.".format(k))
                    del ckpt[k]
        self.load_state_dict(ckpt, strict=True)
        print(f"Restored from {path}")

    def forward(self, text, condition, target, return_loss=True):

        return self.celle(
            text=text, condition=condition, image=target, return_loss=return_loss
        )

    def get_input(self, batch):
        text = batch["sequence"].squeeze(1)
        condition = batch["nucleus"]
        target = batch[self.image_key]

        return text, condition, target

    def get_image_from_logits(self, logits, temperature=0.9):

        filtered_logits = top_k(logits, thres=0.5)
        sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

        self.celle.vae.eval()
        out = self.celle.vae.decode(
            sample[:, self.celle.text_seq_len + self.celle.condition_seq_len :]
            - (self.celle.num_text_tokens + self.celle.num_condition_tokens)
        )

        return out

    def get_loss(self, text, condition, target):

        loss_dict = {}

        loss, loss_dict, logits = self(text, condition, target, return_loss=True)

        return loss, loss_dict

    def total_loss(
        self,
        loss,
        loss_dict,
        mode="train",
    ):

        loss_dict = {f"{mode}/{key}": value for key, value in loss_dict.items()}

        for key, value in loss_dict.items():
            self.log(
                key,
                value,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
                sync_dist=True,
            )

        return loss

    def training_step(self, batch, batch_idx):

        text, condition, target = self.get_input(batch)
        loss, log_dict = self.get_loss(text, condition, target)

        loss = self.total_loss(loss, log_dict, mode="train")

        return loss

    def validation_step(self, batch, batch_idx):

        with torch.no_grad():

            text, condition, target = self.get_input(batch)
            loss, log_dict = self.get_loss(text, condition, target)

            loss = self.total_loss(loss, log_dict, mode="val")

        return loss

    def configure_optimizers(self):

        optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))

        return optimizer

    def scale_image(self, image):

        for tensor in image:
            if torch.min(tensor) < 0:
                tensor += -torch.min(tensor)
            else:
                tensor -= torch.min(tensor)

            tensor /= torch.max(tensor)

        return image

    @torch.no_grad()
    def log_images(self, batch, **kwargs):

        log = {}

        text, condition, target = self.get_input(batch)
        text = text.squeeze(1).to(self.device)
        condition = condition.to(self.device)

        out = self.celle.generate_images(text=text, condition=condition)

        log["condition"] = self.scale_image(condition)
        log["output"] = self.scale_image(out)
        if self.image_key == "threshold":
            log["threshold"] = self.scale_image(target)
            log["target"] = self.scale_image(batch["target"])
        else:
            log["target"] = self.scale_image(target)

        return log


# from https://github.com/CompVis/taming-transformers/blob/master/celle_main.py

if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: celle_main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python celle_main.py`
    # (in particular `celle_main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both."
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            idx = len(paths) - paths[::-1].index("logs") + 1
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[_tmp.index("logs") + 1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join("logs", nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        # trainer_config["distributed_backend"] = "ddp"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if not "gpus" in trainer_config:
            del trainer_config["distributed_backend"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        # model = instantiate_from_config(config.model)
        model = instantiate_from_config(config.model)
        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        # NOTE wandb < 0.10.0 interferes with shutdown
        # wandb >= 0.10.0 seems to fix it but still interferes with pudb
        # debugging (wrongly sized pudb ui)
        # thus prefer testtube for now
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                },
            },
            "testtube": {
                # "target": "pytorch_lightning.loggers.TestTubeLogger",
                "target": "pytorch_lightning.loggers.TensorBoardLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs["testtube"]
        # logger_cfg = lightning_config.logger or OmegaConf.create()
        try:
            logger_cfg = lightning_config.logger
        except:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "checkpoint_callback": {
                "target": "pytorch_lightning.callbacks.ModelCheckpoint",
                "params": {
                    "dirpath": ckptdir,
                    "filename": "{epoch:06}",
                    "verbose": True,
                    "save_last": True,
                },
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["checkpoint_callback"]["params"][
                "monitor"
            ] = model.monitor
            default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3
        try:
            modelckpt_cfg = lightning_config.modelcheckpoint
        except:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        # trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "celle_taming_main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                },
            },
            # "image_logger": {
            #     "target": "celle_taming_main.ImageLogger",
            #     "params": {
            #         "batch_frequency": 0,
            #         "max_images": 0,
            #         "clamp": False,
            #         "increase_log_steps": False,
            #     },
            # },
            # "learning_rate_logger": {
            #     "target": "celle_taming_main.LearningRateMonitor",
            #     "params": {
            #         "logging_interval": "step",
            #         # "log_momentum": True
            #     },
            # },
        }
        try:
            callbacks_cfg = lightning_config.callbacks
        except:
            callbacks_cfg = OmegaConf.create()
        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
        trainer_kwargs["callbacks"] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]

        trainer = Trainer.from_argparse_args(
            trainer_opt, **trainer_kwargs, profiler="simple"
        )

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.setup()
        data.prepare_data()

        # configure learning rate
        bs, lr = config.data.params.batch_size, config.model.learning_rate

        if not cpu:
            ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
        else:
            ngpu = 1
        try:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        except:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        model.learning_rate = accumulate_grad_batches * ngpu * bs * lr

        print(
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (lr)".format(
                model.learning_rate, accumulate_grad_batches, ngpu, bs, lr
            )
        )

        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                # model = torch.compile(model, mode="reduce_overhead")
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
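To make the config comment block in __main__ above concrete, a hypothetical minimal config expressed with OmegaConf; every path and value here is an illustrative placeholder, not one shipped in this commit:

from omegaconf import OmegaConf

# Hypothetical minimal config mirroring the structure documented in celle_main.py.
config = OmegaConf.create(
    {
        "model": {
            "learning_rate": 3e-4,
            "target": "celle_main.CELLE_trainer",
            "params": {
                "vqgan_model_path": "path/to/image_vqgan.ckpt",    # placeholder
                "vqgan_config_path": "path/to/image_vqgan.yaml",   # placeholder
                "image_key": "threshold",
                "text_seq_len": 1000,
            },
        },
        "data": {
            "target": "celle_main.CellDataModule",
            "params": {"data_csv": "data.csv", "dataset": "HPA", "batch_size": 1},  # placeholders
        },
    }
)
OmegaConf.save(config, "configs/celle_example.yaml")
# then, for example: python celle_main.py -b configs/celle_example.yaml -t True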
celle_taming_main.py
ADDED
@@ -0,0 +1,695 @@
import argparse, os, sys, datetime, glob, importlib
from omegaconf import OmegaConf
import numpy as np
from PIL import Image
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from dataloader import CellLoader
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p", "--project", help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )

    return parser


def nondefault_trainer_args(opt):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))


def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        data_csv,
        dataset,
        crop_size=256,
        resize=600,
        batch_size=1,
        sequence_mode="latent",
        vocab="bert",
        text_seq_len=0,
        num_workers=1,
        threshold=False,
        train=True,
        validation=True,
        test=None,
        wrap=False,
        **kwargs,
    ):
        super().__init__()
        self.data_csv = data_csv
        self.dataset = dataset
        self.image_folders = []
        self.crop_size = crop_size
        self.resize = resize
        self.batch_size = batch_size
        self.sequence_mode = sequence_mode
        self.threshold = threshold
        self.text_seq_len = int(text_seq_len)
        self.vocab = vocab
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = self._val_dataloader
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = self._test_dataloader
        self.wrap = wrap

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        # called on every GPU
        self.cell_dataset_train = CellLoader(
            data_csv=self.data_csv,
            dataset=self.dataset,
            crop_size=self.crop_size,
            split_key="train",
            crop_method="random",
            sequence_mode=None,
            vocab=self.vocab,
            text_seq_len=self.text_seq_len,
            threshold=self.threshold,
        )

        self.cell_dataset_val = CellLoader(
            data_csv=self.data_csv,
            dataset=self.dataset,
            crop_size=self.crop_size,
            split_key="val",
            crop_method="center",
            sequence_mode=None,
            vocab=self.vocab,
            text_seq_len=self.text_seq_len,
            threshold=self.threshold,
        )

    def _train_dataloader(self):
        return DataLoader(
            self.cell_dataset_train,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True,
            batch_size=self.batch_size,
        )

    def _val_dataloader(self):
        return DataLoader(
            self.cell_dataset_val,
            num_workers=self.num_workers,
            pin_memory=True,
            batch_size=self.batch_size,
        )

    # def _test_dataloader(self):
    #     return DataLoader(self.datasets["test"], batch_size=self.batch_size,
    #                       num_workers=self.num_workers)


class SetupCallback(Callback):
    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(
                self.config,
                os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
            )

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({"lightning": self.lightning_config}),
                os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
            )

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(
        self, batch_frequency, max_images, clamp=True, increase_log_steps=True
    ):
        super().__init__()
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.WandbLogger: self._wandb,
            # pl.loggers.TestTubeLogger: self._testtube,
            pl.loggers.TensorBoardLogger: self._testtube,
        }
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp

    @rank_zero_only
    def _wandb(self, pl_module, images, batch_idx, split):
        raise ValueError("No way wandb")
        grids = dict()
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grids[f"{split}/{k}"] = wandb.Image(grid)
        pl_module.logger.experiment.log(grids)

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            images[k] -= torch.min(images[k])
            images[k] /= torch.max(images[k])
            grid = torchvision.utils.make_grid(images[k])
            # grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid, global_step=pl_module.global_step
            )

    @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            images[k] -= torch.min(images[k])
            images[k] /= torch.max(images[k])
            grid = torchvision.utils.make_grid(images[k], nrow=4)

            # grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k, global_step, current_epoch, batch_idx
            )
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        if (
            self.check_frequency(batch_idx)
            and hasattr(pl_module, "log_images")  # batch_idx % self.batch_freq == 0
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
            )

            logger_log_images = self.logger_log_images.get(
                logger, lambda *args, **kwargs: None
            )
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, batch_idx):
        if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
            try:
                self.log_steps.pop(0)
            except IndexError:
                pass
            return True
        return False

    # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
    # def on_train_batch_end(self, *args, **kwargs):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        self.log_img(pl_module, batch, batch_idx, split="val")


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.
|
400 |
+
|
401 |
+
# model:
|
402 |
+
# base_learning_rate: float
|
403 |
+
# target: path to lightning module
|
404 |
+
# params:
|
405 |
+
# key: value
|
406 |
+
# data:
|
407 |
+
# target: main.DataModuleFromConfig
|
408 |
+
# params:
|
409 |
+
# batch_size: int
|
410 |
+
# wrap: bool
|
411 |
+
# train:
|
412 |
+
# target: path to train dataset
|
413 |
+
# params:
|
414 |
+
# key: value
|
415 |
+
# validation:
|
416 |
+
# target: path to validation dataset
|
417 |
+
# params:
|
418 |
+
# key: value
|
419 |
+
# test:
|
420 |
+
# target: path to test dataset
|
421 |
+
# params:
|
422 |
+
# key: value
|
423 |
+
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
424 |
+
# trainer:
|
425 |
+
# additional arguments to trainer
|
426 |
+
# logger:
|
427 |
+
# logger to instantiate
|
428 |
+
# modelcheckpoint:
|
429 |
+
# modelcheckpoint to instantiate
|
430 |
+
# callbacks:
|
431 |
+
# callback1:
|
432 |
+
# target: importpath
|
433 |
+
# params:
|
434 |
+
# key: value
|
435 |
+
|
436 |
+
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
437 |
+
|
438 |
+
# add cwd for convenience and to make classes in this file available when
|
439 |
+
# running as `python main.py`
|
440 |
+
# (in particular `main.DataModuleFromConfig`)
|
441 |
+
sys.path.append(os.getcwd())
|
442 |
+
|
443 |
+
parser = get_parser()
|
444 |
+
parser = Trainer.add_argparse_args(parser)
|
445 |
+
|
446 |
+
opt, unknown = parser.parse_known_args()
|
447 |
+
if opt.name and opt.resume:
|
448 |
+
raise ValueError(
|
449 |
+
"-n/--name and -r/--resume cannot be specified both."
|
450 |
+
"If you want to resume training in a new log folder, "
|
451 |
+
"use -n/--name in combination with --resume_from_checkpoint"
|
452 |
+
)
|
453 |
+
if opt.resume:
|
454 |
+
if not os.path.exists(opt.resume):
|
455 |
+
raise ValueError("Cannot find {}".format(opt.resume))
|
456 |
+
if os.path.isfile(opt.resume):
|
457 |
+
paths = opt.resume.split("/")
|
458 |
+
idx = len(paths) - paths[::-1].index("logs") + 1
|
459 |
+
logdir = "/".join(paths[:idx])
|
460 |
+
ckpt = opt.resume
|
461 |
+
else:
|
462 |
+
assert os.path.isdir(opt.resume), opt.resume
|
463 |
+
logdir = opt.resume.rstrip("/")
|
464 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
465 |
+
|
466 |
+
opt.resume_from_checkpoint = ckpt
|
467 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
468 |
+
opt.base = base_configs + opt.base
|
469 |
+
_tmp = logdir.split("/")
|
470 |
+
nowname = _tmp[_tmp.index("logs") + 1]
|
471 |
+
else:
|
472 |
+
if opt.name:
|
473 |
+
name = "_" + opt.name
|
474 |
+
elif opt.base:
|
475 |
+
cfg_fname = os.path.split(opt.base[0])[-1]
|
476 |
+
cfg_name = os.path.splitext(cfg_fname)[0]
|
477 |
+
name = "_" + cfg_name
|
478 |
+
else:
|
479 |
+
name = ""
|
480 |
+
nowname = now + name + opt.postfix
|
481 |
+
logdir = os.path.join("logs", nowname)
|
482 |
+
|
483 |
+
ckptdir = os.path.join(logdir, "checkpoints")
|
484 |
+
cfgdir = os.path.join(logdir, "configs")
|
485 |
+
seed_everything(opt.seed)
|
486 |
+
|
487 |
+
try:
|
488 |
+
# init and save configs
|
489 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
490 |
+
cli = OmegaConf.from_dotlist(unknown)
|
491 |
+
config = OmegaConf.merge(*configs, cli)
|
492 |
+
lightning_config = config.pop("lightning", OmegaConf.create())
|
493 |
+
# merge trainer cli with config
|
494 |
+
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
495 |
+
# default to ddp
|
496 |
+
trainer_config["distributed_backend"] = "ddp"
|
497 |
+
trainer_config["replace_sampler_ddp"] = False
|
498 |
+
trainer_config["strategy"] = "ddp"
|
499 |
+
trainer_config["persistent_workers"] = True
|
500 |
+
for k in nondefault_trainer_args(opt):
|
501 |
+
trainer_config[k] = getattr(opt, k)
|
502 |
+
if not "gpus" in trainer_config:
|
503 |
+
del trainer_config["distributed_backend"]
|
504 |
+
cpu = True
|
505 |
+
else:
|
506 |
+
gpuinfo = trainer_config["gpus"]
|
507 |
+
print(f"Running on GPUs {gpuinfo}")
|
508 |
+
cpu = False
|
509 |
+
trainer_opt = argparse.Namespace(**trainer_config)
|
510 |
+
lightning_config.trainer = trainer_config
|
511 |
+
|
512 |
+
# model
|
513 |
+
model = instantiate_from_config(config.model)
|
514 |
+
# trainer and callbacks
|
515 |
+
trainer_kwargs = dict()
|
516 |
+
|
517 |
+
# default logger configs
|
518 |
+
# NOTE wandb < 0.10.0 interferes with shutdown
|
519 |
+
# wandb >= 0.10.0 seems to fix it but still interferes with pudb
|
520 |
+
# debugging (wrongly sized pudb ui)
|
521 |
+
# thus prefer testtube for now
|
522 |
+
default_logger_cfgs = {
|
523 |
+
"wandb": {
|
524 |
+
"target": "pytorch_lightning.loggers.WandbLogger",
|
525 |
+
"params": {
|
526 |
+
"name": nowname,
|
527 |
+
"save_dir": logdir,
|
528 |
+
"offline": opt.debug,
|
529 |
+
"id": nowname,
|
530 |
+
},
|
531 |
+
},
|
532 |
+
"testtube": {
|
533 |
+
# "target": "pytorch_lightning.loggers.TestTubeLogger",
|
534 |
+
"target": "pytorch_lightning.loggers.TensorBoardLogger",
|
535 |
+
"params": {
|
536 |
+
"name": "testtube",
|
537 |
+
"save_dir": logdir,
|
538 |
+
},
|
539 |
+
},
|
540 |
+
}
|
541 |
+
default_logger_cfg = default_logger_cfgs["testtube"]
|
542 |
+
try:
|
543 |
+
logger_cfg = lightning_config.logger
|
544 |
+
except:
|
545 |
+
logger_cfg = OmegaConf.create()
|
546 |
+
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
547 |
+
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
548 |
+
|
549 |
+
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
550 |
+
# specify which metric is used to determine best models
|
551 |
+
default_modelckpt_cfg = {
|
552 |
+
"checkpoint_callback": {
|
553 |
+
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
554 |
+
"params": {
|
555 |
+
"dirpath": ckptdir,
|
556 |
+
"filename": "{epoch:06}",
|
557 |
+
"verbose": True,
|
558 |
+
"save_last": True,
|
559 |
+
},
|
560 |
+
}
|
561 |
+
}
|
562 |
+
if hasattr(model, "monitor"):
|
563 |
+
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
564 |
+
default_modelckpt_cfg["checkpoint_callback"]["params"][
|
565 |
+
"monitor"
|
566 |
+
] = model.monitor
|
567 |
+
default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3
|
568 |
+
try:
|
569 |
+
modelckpt_cfg = lightning_config.modelcheckpoint
|
570 |
+
except:
|
571 |
+
modelckpt_cfg = OmegaConf.create()
|
572 |
+
|
573 |
+
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
574 |
+
# trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
575 |
+
|
576 |
+
# loaded_model_callbacks = instantiate_from_config(modelckpt_cfg)
|
577 |
+
|
578 |
+
# add callback which sets up log directory
|
579 |
+
default_callbacks_cfg = {
|
580 |
+
"setup_callback": {
|
581 |
+
"target": "celle_taming_main.SetupCallback",
|
582 |
+
"params": {
|
583 |
+
"resume": opt.resume,
|
584 |
+
"now": now,
|
585 |
+
"logdir": logdir,
|
586 |
+
"ckptdir": ckptdir,
|
587 |
+
"cfgdir": cfgdir,
|
588 |
+
"config": config,
|
589 |
+
"lightning_config": lightning_config,
|
590 |
+
},
|
591 |
+
},
|
592 |
+
"image_logger": {
|
593 |
+
"target": "celle_taming_main.ImageLogger",
|
594 |
+
"params": {
|
595 |
+
"batch_frequency": 2000,
|
596 |
+
"max_images": 10,
|
597 |
+
"clamp": True,
|
598 |
+
"increase_log_steps": False,
|
599 |
+
},
|
600 |
+
},
|
601 |
+
"learning_rate_logger": {
|
602 |
+
"target": "celle_taming_main.LearningRateMonitor",
|
603 |
+
"params": {
|
604 |
+
"logging_interval": "step",
|
605 |
+
# "log_momentum": True
|
606 |
+
},
|
607 |
+
},
|
608 |
+
}
|
609 |
+
try:
|
610 |
+
callbacks_cfg = lightning_config.callbacks
|
611 |
+
except:
|
612 |
+
callbacks_cfg = OmegaConf.create()
|
613 |
+
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
614 |
+
callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
|
615 |
+
trainer_kwargs["callbacks"] = [
|
616 |
+
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
|
617 |
+
]
|
618 |
+
# loaded_callbacks = [
|
619 |
+
# instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
|
620 |
+
# ]
|
621 |
+
|
622 |
+
# trainer_kwargs["callbacks"] = loaded_callbacks.append(loaded_model_callbacks)
|
623 |
+
|
624 |
+
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
625 |
+
|
626 |
+
# data
|
627 |
+
data = instantiate_from_config(config.data)
|
628 |
+
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
629 |
+
# calling these ourselves should not be necessary but it is.
|
630 |
+
# lightning still takes care of proper multiprocessing though
|
631 |
+
data.prepare_data()
|
632 |
+
data.setup()
|
633 |
+
|
634 |
+
# configure learning rate
|
635 |
+
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
636 |
+
if not cpu:
|
637 |
+
ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
|
638 |
+
else:
|
639 |
+
ngpu = 1
|
640 |
+
try:
|
641 |
+
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
642 |
+
except:
|
643 |
+
accumulate_grad_batches = 1
|
644 |
+
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
645 |
+
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
646 |
+
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
647 |
+
print(
|
648 |
+
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
649 |
+
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
|
650 |
+
)
|
651 |
+
)
|
652 |
+
|
653 |
+
# allow checkpointing via USR1
|
654 |
+
def melk(*args, **kwargs):
|
655 |
+
# run all checkpoint hooks
|
656 |
+
if trainer.global_rank == 0:
|
657 |
+
print("Summoning checkpoint.")
|
658 |
+
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
659 |
+
trainer.save_checkpoint(ckpt_path)
|
660 |
+
|
661 |
+
def divein(*args, **kwargs):
|
662 |
+
if trainer.global_rank == 0:
|
663 |
+
import pudb
|
664 |
+
|
665 |
+
pudb.set_trace()
|
666 |
+
|
667 |
+
import signal
|
668 |
+
|
669 |
+
signal.signal(signal.SIGUSR1, melk)
|
670 |
+
signal.signal(signal.SIGUSR2, divein)
|
671 |
+
# model = torch.compile(model)
|
672 |
+
# run
|
673 |
+
if opt.train:
|
674 |
+
try:
|
675 |
+
torch.compile(trainer.fit(model, data))
|
676 |
+
except Exception:
|
677 |
+
melk()
|
678 |
+
raise
|
679 |
+
if not opt.no_test and not trainer.interrupted:
|
680 |
+
trainer.test(model, data)
|
681 |
+
except Exception:
|
682 |
+
if opt.debug and trainer.global_rank == 0:
|
683 |
+
try:
|
684 |
+
import pudb as debugger
|
685 |
+
except ImportError:
|
686 |
+
import pdb as debugger
|
687 |
+
debugger.post_mortem()
|
688 |
+
raise
|
689 |
+
finally:
|
690 |
+
# move newly created debug project to debug_runs
|
691 |
+
if opt.debug and not opt.resume and trainer.global_rank == 0:
|
692 |
+
dst, name = os.path.split(logdir)
|
693 |
+
dst = os.path.join(dst, "debug_runs", name)
|
694 |
+
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
695 |
+
os.rename(logdir, dst)
|
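For reference, the effective learning rate computed above is the configured base rate scaled by batch size, GPU count, and gradient accumulation. A minimal sketch of that rule; the numbers below are hypothetical, not taken from any shipped config:

# Hypothetical illustration of the lr-scaling rule used above; not part of the repo.
accumulate_grad_batches = 1   # assumed
ngpu = 1                      # assumed
batch_size = 8                # assumed
base_lr = 4.5e-6              # assumed
learning_rate = accumulate_grad_batches * ngpu * batch_size * base_lr
print(f"effective lr = {learning_rate:.2e}")  # 3.60e-05 for these assumed values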
dataloader.py
ADDED
@@ -0,0 +1,321 @@
import os
import numpy as np
from PIL import Image, ImageSequence
import json
import pandas as pd

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF


def simple_conversion(seq):
    """Create 26-dim embedding"""
    chars = [
        "-",
        "M",
        "R",
        "H",
        "K",
        "D",
        "E",
        "S",
        "T",
        "N",
        "Q",
        "C",
        "U",
        "G",
        "P",
        "A",
        "V",
        "I",
        "F",
        "Y",
        "W",
        "L",
        "O",
        "X",
        "Z",
        "B",
        "J",
    ]

    nums = range(len(chars))

    seqs_x = np.zeros(len(seq))

    for idx, char in enumerate(seq):
        lui = chars.index(char)
        seqs_x[idx] = nums[lui]

    return torch.tensor([seqs_x]).long()


def replace_outliers(image, percentile=0.0001):
    lower_bound, upper_bound = torch.quantile(image, percentile), torch.quantile(
        image, 1 - percentile
    )
    mask = (image <= upper_bound) & (image >= lower_bound)

    valid_pixels = image[mask]

    image[~mask] = torch.clip(image[~mask], min(valid_pixels), max(valid_pixels))

    return image


class CellLoader(Dataset):
    """imports mined opencell images with protein sequence"""

    def __init__(
        self,
        data_csv=None,
        dataset=None,
        split_key=None,
        resize=600,
        crop_size=600,
        crop_method="random",
        sequence_mode="simple",
        vocab="bert",
        threshold="median",
        text_seq_len=0,
        pad_mode="random",
    ):
        self.data_csv = data_csv
        self.dataset = dataset
        self.image_folders = []
        self.crop_method = crop_method
        self.resize = resize
        self.crop_size = crop_size
        self.sequence_mode = sequence_mode
        self.threshold = threshold
        self.text_seq_len = int(text_seq_len)
        self.vocab = vocab
        self.pad_mode = pad_mode

        if self.sequence_mode == "embedding" or self.sequence_mode == "onehot":
            if self.vocab == "esm1b" or self.vocab == "esm2":
                from esm import Alphabet

                self.tokenizer = Alphabet.from_architecture(
                    "ESM-1b"
                ).get_batch_converter()
                self.text_seq_len += 2

        if data_csv:
            data = pd.read_csv(data_csv)

            self.parent_path = os.path.dirname(data_csv).split(data_csv)[0]

            if split_key == "train":
                self.data = data[data["split"] == "train"]
            elif split_key == "val":
                self.data = data[data["split"] == "val"]
            else:
                self.data = data

            self.data = self.data.reset_index(drop=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(
        self,
        idx,
        get_sequence=True,
        get_images=True,
    ):
        if get_sequence and self.text_seq_len > 0:
            protein_vector = self.get_protein_vector(idx)
        else:
            protein_vector = torch.zeros((1, 1))

        if get_images:
            nucleus, target, threshold = self.get_images(idx, self.dataset)
        else:
            nucleus, target, threshold = torch.zeros((3, 1))

        data_dict = {
            "nucleus": nucleus.float(),
            "target": target.float(),
            "threshold": threshold.float(),
            "sequence": protein_vector.long(),
        }

        return data_dict

    def get_protein_vector(self, idx):
        if "protein_sequence" not in self.data.columns:
            metadata = self.retrieve_metadata(idx)
            protein_sequence = metadata["sequence"]
        else:
            protein_sequence = self.data.iloc[idx]["protein_sequence"]

        protein_vector = self.tokenize_sequence(protein_sequence)

        return protein_vector

    def get_images(self, idx, dataset):
        if dataset == "HPA":
            nucleus = Image.open(
                os.path.join(
                    self.parent_path, self.data.iloc[idx]["nucleus_image_path"]
                )
            )

            target = Image.open(
                os.path.join(self.parent_path, self.data.iloc[idx]["target_image_path"])
            )

            nucleus = TF.to_tensor(nucleus)[0]
            target = TF.to_tensor(target)[0]

            image = torch.stack([nucleus, target], axis=0)

            normalize = (0.0655, 0.0650), (0.1732, 0.1208)

        elif dataset == "OpenCell":
            image = Image.open(
                os.path.join(self.parent_path, self.data.iloc[idx]["image_path"])
            )
            nucleus, target = [page.copy() for page in ImageSequence.Iterator(image)]

            nucleus = replace_outliers(torch.divide(TF.to_tensor(nucleus), 65536))[0]
            target = replace_outliers(torch.divide(TF.to_tensor(target), 65536))[0]

            image = torch.stack([nucleus, target], axis=0)

            normalize = (
                (0.0272, 0.0244),
                (0.0486, 0.0671),
            )

        # from https://discuss.pytorch.org/t/how-to-apply-same-transform-on-a-pair-of-picture/14914

        t_forms = [transforms.Resize(self.resize, antialias=None)]

        if self.crop_method == "random":
            t_forms.append(transforms.RandomCrop(self.crop_size))
            t_forms.append(transforms.RandomHorizontalFlip(p=0.5))
            t_forms.append(transforms.RandomVerticalFlip(p=0.5))

        elif self.crop_method == "center":
            t_forms.append(transforms.CenterCrop(self.crop_size))

        t_forms.append(transforms.Normalize(normalize[0], normalize[1]))

        image = transforms.Compose(t_forms)(image)

        nucleus, target = image

        nucleus /= torch.abs(nucleus).max()
        target -= target.min()
        target /= target.max()

        nucleus = nucleus.unsqueeze(0)
        target = target.unsqueeze(0)

        threshold = target

        if self.threshold == "mean":
            threshold = 1.0 * (threshold > (torch.mean(threshold)))

        elif self.threshold == "median":
            threshold = 1.0 * (threshold > (torch.median(threshold)))

        elif self.threshold == "1090_IQR":
            p10 = torch.quantile(threshold, 0.1, None)
            p90 = torch.quantile(threshold, 0.9, None)
            threshold = torch.clip(threshold, p10, p90)

        nucleus = torch.nan_to_num(nucleus, 0.0, 1.0, 0.0)
        target = torch.nan_to_num(target, 0.0, 1.0, 0.0)
        threshold = torch.nan_to_num(threshold, 0.0, 1.0, 0.0)

        return nucleus, target, threshold

    def retrieve_metadata(self, idx):
        with open(
            os.path.join(self.parent_path, self.data.iloc[idx]["metadata_path"])
        ) as f:
            metadata = json.load(f)

        return metadata

    def tokenize_sequence(self, protein_sequence):
        pad_token = 0

        if self.sequence_mode == "simple":
            protein_vector = simple_conversion(protein_sequence)

        elif self.sequence_mode == "center":
            # pad/center the string to the configured sequence length
            protein_sequence = protein_sequence.center(self.text_seq_len, "-")
            protein_vector = simple_conversion(protein_sequence)

        elif self.sequence_mode == "alternating":
            protein_sequence = protein_sequence.center(self.text_seq_len, "-")
            protein_sequence = protein_sequence[::18]
            protein_sequence = protein_sequence.center(
                int(self.text_seq_len / 18) + 1, "-"
            )
            protein_vector = simple_conversion(protein_sequence)

        elif self.sequence_mode == "embedding":
            if self.vocab == "esm1b" or self.vocab == "esm2":
                pad_token = 1
                protein_vector = self.tokenizer([("", protein_sequence)])[-1]

        if protein_vector.shape[-1] < self.text_seq_len:
            diff = self.text_seq_len - protein_vector.shape[-1]

            if self.pad_mode == "end":
                protein_vector = torch.nn.functional.pad(
                    protein_vector, (0, diff), "constant", pad_token
                )
            elif self.pad_mode == "random":
                split = diff - np.random.randint(0, diff + 1)

                protein_vector = torch.cat(
                    [torch.ones(1, split) * 0, protein_vector], dim=1
                )

                protein_vector = torch.nn.functional.pad(
                    protein_vector, (0, diff - split), "constant", pad_token
                )

        elif protein_vector.shape[-1] > self.text_seq_len:
            start_int = np.random.randint(
                0, protein_vector.shape[-1] - self.text_seq_len
            )

            protein_vector = protein_vector[
                :, start_int : start_int + self.text_seq_len
            ]

        return protein_vector.long()
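A minimal usage sketch for the CellLoader above; the "data.csv" path and its HPA-style columns are assumptions for illustration, not files shipped with this commit:

# Hypothetical example; "data.csv" and its columns are assumed, not part of the repo.
from torch.utils.data import DataLoader
from dataloader import CellLoader

train_set = CellLoader(
    data_csv="data.csv",        # assumed CSV with "split" and image-path columns
    dataset="HPA",
    split_key="train",
    crop_size=256,
    sequence_mode="embedding",
    vocab="esm2",
    text_seq_len=1000,
)
loader = DataLoader(train_set, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["nucleus"].shape, batch["target"].shape, batch["sequence"].shape)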
prediction.py
ADDED
@@ -0,0 +1,267 @@
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
os.chdir('..')
|
5 |
+
from dataloader import CellLoader
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
from celle_main import instantiate_from_config
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
from celle.utils import process_image
|
10 |
+
|
11 |
+
def run_model(mode, sequence,
|
12 |
+
nucleus_image_path,
|
13 |
+
protein_image_path,
|
14 |
+
model_ckpt_path,
|
15 |
+
model_config_path,
|
16 |
+
device):
|
17 |
+
if mode == "image":
|
18 |
+
run_image_prediction(
|
19 |
+
sequence,
|
20 |
+
nucleus_image_path,
|
21 |
+
protein_image_path,
|
22 |
+
model_ckpt_path,
|
23 |
+
model_config_path,
|
24 |
+
device
|
25 |
+
)
|
26 |
+
elif mode == "sequence":
|
27 |
+
run_sequence_prediction(
|
28 |
+
sequence,
|
29 |
+
nucleus_image_path,
|
30 |
+
protein_image_path,
|
31 |
+
model_ckpt_path,
|
32 |
+
model_config_path,
|
33 |
+
device
|
34 |
+
)
|
35 |
+
|
36 |
+
def run_sequence_prediction(
|
37 |
+
sequence_input,
|
38 |
+
nucleus_image_path,
|
39 |
+
protein_image_path,
|
40 |
+
model_ckpt_path,
|
41 |
+
model_config_path,
|
42 |
+
device
|
43 |
+
):
|
44 |
+
"""
|
45 |
+
Run Celle model with provided inputs and display results.
|
46 |
+
|
47 |
+
:param sequence: Path to sequence file
|
48 |
+
:param nucleus_image_path: Path to nucleus image
|
49 |
+
:param protein_image_path: Path to protein image (optional)
|
50 |
+
:param model_ckpt_path: Path to model checkpoint
|
51 |
+
:param model_config_path: Path to model config
|
52 |
+
"""
|
53 |
+
|
54 |
+
# Instantiate dataset object
|
55 |
+
dataset = CellLoader(
|
56 |
+
sequence_mode="embedding",
|
57 |
+
vocab="esm2",
|
58 |
+
split_key="val",
|
59 |
+
crop_method="center",
|
60 |
+
resize=600,
|
61 |
+
crop_size=256,
|
62 |
+
text_seq_len=1000,
|
63 |
+
pad_mode="end",
|
64 |
+
threshold="median",
|
65 |
+
)
|
66 |
+
|
67 |
+
# Check if sequence is provided and valid
|
68 |
+
if len(sequence_input) == 0:
|
69 |
+
raise ValueError("Sequence must be provided.")
|
70 |
+
|
71 |
+
if "<mask>" not in sequence_input:
|
72 |
+
print("Warning: Sequence does not contain any masked positions to predict.")
|
73 |
+
|
74 |
+
# Convert SEQUENCE to sequence using dataset.tokenize_sequence()
|
75 |
+
sequence = dataset.tokenize_sequence(sequence_input)
|
76 |
+
|
77 |
+
# Check if nucleus image path is provided and valid
|
78 |
+
if not os.path.exists(nucleus_image_path):
|
79 |
+
# Use default nucleus image from dataset and print warning
|
80 |
+
nucleus_image_path = 'images/nucleus.jpg'
|
81 |
+
print(
|
82 |
+
"Warning: No nucleus image provided. Using default nucleus image from dataset."
|
83 |
+
)
|
84 |
+
else:
|
85 |
+
# Load nucleus image from provided path
|
86 |
+
nucleus_image = process_image(nucleus_image_path)
|
87 |
+
|
88 |
+
# Check if protein image path is provided and valid
|
89 |
+
if not os.path.exists(protein_image_path):
|
90 |
+
# Use default nucleus image from dataset and print warning
|
91 |
+
protein_image_path = 'images/protein.jpg'
|
92 |
+
print(
|
93 |
+
"Warning: No nucleus image provided. Using default protein image from dataset."
|
94 |
+
)
|
95 |
+
else:
|
96 |
+
# Load protein image from provided path
|
97 |
+
protein_image = process_image(protein_image_path)
|
98 |
+
protein_image = (protein_image > torch.median(protein_image,dim=0))*1.0
|
99 |
+
|
100 |
+
# Load model config and set ckpt_path if not provided in config
|
101 |
+
config = OmegaConf.load(model_config_path)
|
102 |
+
if config["model"]["params"]["ckpt_path"] is None:
|
103 |
+
config["model"]["params"]["ckpt_path"] = model_ckpt_path
|
104 |
+
|
105 |
+
# Set condition_model_path and vqgan_model_path to None
|
106 |
+
config["model"]["params"]["condition_model_path"] = None
|
107 |
+
config["model"]["params"]["vqgan_model_path"] = None
|
108 |
+
|
109 |
+
# Instantiate model from config and move to device
|
110 |
+
model = instantiate_from_config(config).to(device)
|
111 |
+
|
112 |
+
# Sample from model using provided sequence and nucleus image
|
113 |
+
_, predicted_sequence, _ = model.celle.sample_text(
|
114 |
+
text=sequence,
|
115 |
+
condition=nucleus_image,
|
116 |
+
image=protein_image,
|
117 |
+
force_aas=True,
|
118 |
+
timesteps=1,
|
119 |
+
temperature=1,
|
120 |
+
progress=True,
|
121 |
+
)
|
122 |
+
|
123 |
+
formatted_predicted_sequence = ""
|
124 |
+
|
125 |
+
for i in range(min(len(predicted_sequence), len(sequence))):
|
126 |
+
if predicted_sequence[i] != sequence[i]:
|
127 |
+
formatted_predicted_sequence += f"**{predicted_sequence[i]}**"
|
128 |
+
else:
|
129 |
+
formatted_predicted_sequence += predicted_sequence[i]
|
130 |
+
|
131 |
+
if len(predicted_sequence) > len(sequence):
|
132 |
+
formatted_predicted_sequence += f"**{predicted_sequence[len(sequence):]}**"
|
133 |
+
|
134 |
+
print("predicted_sequence:", formatted_predicted_sequence)
|
135 |
+
|
136 |
+
|
137 |
+
def run_image_prediction(
|
138 |
+
sequence_input,
|
139 |
+
nucleus_image_path,
|
140 |
+
protein_image_path,
|
141 |
+
model_ckpt_path,
|
142 |
+
model_config_path,
|
143 |
+
device
|
144 |
+
):
|
145 |
+
"""
|
146 |
+
Run Celle model with provided inputs and display results.
|
147 |
+
|
148 |
+
:param sequence: Path to sequence file
|
149 |
+
:param nucleus_image_path: Path to nucleus image
|
150 |
+
:param protein_image_path: Path to protein image (optional)
|
151 |
+
:param model_ckpt_path: Path to model checkpoint
|
152 |
+
:param model_config_path: Path to model config
|
153 |
+
"""
|
154 |
+
# Instantiate dataset object
|
155 |
+
dataset = CellLoader(
|
156 |
+
sequence_mode="embedding",
|
157 |
+
vocab="esm2",
|
158 |
+
split_key="val",
|
159 |
+
crop_method="center",
|
160 |
+
resize=600,
|
161 |
+
crop_size=256,
|
162 |
+
text_seq_len=1000,
|
163 |
+
pad_mode="end",
|
164 |
+
threshold="median",
|
165 |
+
)
|
166 |
+
|
167 |
+
# Check if sequence is provided and valid
|
168 |
+
if len(sequence_input) == 0:
|
169 |
+
sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
|
170 |
+
# Use default sequence for GFP and print warning
|
171 |
+
print("Warning: No sequence provided. Using default sequence for GFP.")
|
172 |
+
|
173 |
+
# Convert SEQUENCE to sequence using dataset.tokenize_sequence()
|
174 |
+
sequence = dataset.tokenize_sequence(sequence_input)
|
175 |
+
|
176 |
+
# Check if nucleus image path is provided and valid
|
177 |
+
if not os.path.exists(nucleus_image_path):
|
178 |
+
# Use default nucleus image from dataset and print warning
|
179 |
+
nucleus_image = dataset[0]["nucleus"]
|
180 |
+
print(
|
181 |
+
"Warning: No nucleus image provided. Using default nucleus image from dataset."
|
182 |
+
)
|
183 |
+
else:
|
184 |
+
# Load nucleus image from provided path
|
185 |
+
nucleus_image = process_image(nucleus_image_path)
|
186 |
+
|
187 |
+
# Load model config and set ckpt_path if not provided in config
|
188 |
+
config = OmegaConf.load(model_config_path)
|
189 |
+
if config["model"]["params"]["ckpt_path"] is None:
|
190 |
+
config["model"]["params"]["ckpt_path"] = model_ckpt_path
|
191 |
+
|
192 |
+
# Set condition_model_path and vqgan_model_path to None
|
193 |
+
config["model"]["params"]["condition_model_path"] = None
|
194 |
+
config["model"]["params"]["vqgan_model_path"] = None
|
195 |
+
|
196 |
+
# Instantiate model from config and move to device
|
197 |
+
model = instantiate_from_config(config).to(device)
|
198 |
+
|
199 |
+
# Sample from model using provided sequence and nucleus image
|
200 |
+
_, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
|
201 |
+
text=sequence,
|
202 |
+
condition=nucleus_image,
|
203 |
+
timesteps=1,
|
204 |
+
temperature=1,
|
205 |
+
progress=True,
|
206 |
+
)
|
207 |
+
|
208 |
+
# Move predicted_threshold and predicted_heatmap to CPU and select first element of batch
|
209 |
+
predicted_threshold = predicted_threshold.cpu()[0, 0]
|
210 |
+
predicted_heatmap = predicted_heatmap.cpu()[0, 0]
|
211 |
+
|
212 |
+
# Create 3 or 4 panel plot depending on whether protein image path is provided
|
213 |
+
fig, axs = plt.subplots(1, 3 if protein_image_path is None else 4)
|
214 |
+
axs[0].imshow(nucleus_image)
|
215 |
+
axs[0].set_title("Nucleus Input")
|
216 |
+
axs[1].imshow(predicted_threshold)
|
217 |
+
axs[1].set_title("Predicted Threshold")
|
218 |
+
if protein_image_path is not None:
|
219 |
+
protein_image = process_image(protein_image_path)
|
220 |
+
axs[2].imshow(protein_image)
|
221 |
+
axs[2].set_title("Protein Image")
|
222 |
+
axs[-1].imshow(predicted_heatmap)
|
223 |
+
axs[-1].set_title("Predicted Heatmap")
|
224 |
+
plt.show()
|
225 |
+
|
226 |
+
|
227 |
+
if __name__ == "__main__":
|
228 |
+
# Parse command line arguments for input parameters
|
229 |
+
parser = argparse.ArgumentParser(
|
230 |
+
description="Run Celle model with provided inputs."
|
231 |
+
)
|
232 |
+
parser.add_argument("--mode", type=str, default="", help="Sequence or Image")
|
233 |
+
parser.add_argument(
|
234 |
+
"--sequence", type=str, default="", help="Path to sequence file"
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--nucleus_image_path",
|
238 |
+
type=str,
|
239 |
+
default="images/nucleus.jpg",
|
240 |
+
help="Path to nucleus image",
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--protein_image_path",
|
244 |
+
type=str,
|
245 |
+
default=None,
|
246 |
+
help="Path to protein image (optional)",
|
247 |
+
)
|
248 |
+
parser.add_argument(
|
249 |
+
"--model_ckpt_path", type=str, required=True, help="Path to model checkpoint"
|
250 |
+
)
|
251 |
+
parser.add_argument(
|
252 |
+
"--model_config_path", type=str, required=True, help="Path to model config"
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--device", type=str, default="cpu", required=True, help="device"
|
256 |
+
)
|
257 |
+
args = parser.parse_args()
|
258 |
+
|
259 |
+
run_model(
|
260 |
+
args.mode,
|
261 |
+
args.sequence,
|
262 |
+
args.nucleus_image_path,
|
263 |
+
args.protein_image_path,
|
264 |
+
args.model_ckpt_path,
|
265 |
+
args.model_config_path,
|
266 |
+
args.device
|
267 |
+
)
|
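A hedged example of calling the script above programmatically; the checkpoint and config paths are placeholders, not files guaranteed to exist in this commit:

# Hypothetical invocation; the ckpt/config paths below are assumed placeholders.
from prediction import run_model

run_model(
    mode="image",
    sequence="",                                    # empty -> falls back to the GFP default
    nucleus_image_path="images/nucleus.jpg",
    protein_image_path=None,
    model_ckpt_path="checkpoints/model.ckpt",       # assumed path
    model_config_path="configs/model_config.yaml",  # assumed path
    device="cpu",
)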
requirements.txt
ADDED
@@ -0,0 +1,160 @@
aiohttp==3.8.4
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-timeout==4.0.2
attrs==23.1.0
axial-positional-embedding==0.2.1
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.0.0
blessed==1.20.0
certifi==2023.5.7
cffi==1.15.1
charset-normalizer==3.1.0
click==8.1.3
cmake==3.26.3
comm==0.1.3
contourpy==1.0.7
croniter==1.3.14
cycler==0.11.0
dateutils==0.6.12
debugpy==1.6.7
decorator==5.1.1
deepdiff==6.3.0
defusedxml==0.7.1
einops==0.6.1
executing==1.2.0
fair-esm==2.0.0
fastapi==0.88.0
fastjsonschema==2.16.3
filelock==3.12.0
fonttools==4.39.4
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.5.0
h11==0.14.0
idna==3.4
inquirer==3.1.3
ipykernel==6.23.0
ipython==8.13.2
ipython-genutils==0.2.0
ipywidgets==8.0.6
isoduration==20.11.0
itsdangerous==2.1.2
jedi==0.18.2
Jinja2==3.1.2
jsonpointer==2.3
jsonschema==4.17.3
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.6.3
jupyter_client==8.2.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_terminals==0.4.4
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.7
kiwisolver==1.4.4
lightning==2.0.2
lightning-cloud==0.5.34
lightning-utilities==0.8.0
lit==16.0.3
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==2.0.5
mpmath==1.3.0
multidict==6.0.4
nbclassic==1.0.0
nbclient==0.7.4
nbconvert==7.4.0
nbformat==5.8.0
nest-asyncio==1.5.6
networkx==3.1
notebook==6.5.4
notebook_shim==0.2.3
numpy==1.24.3
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
omegaconf==2.3.0
ordered-set==4.1.0
packaging==23.1
pandas==2.0.1
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.5.0
platformdirs==3.5.1
prometheus-client==0.16.0
prompt-toolkit==3.0.38
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
pydantic==1.10.7
Pygments==2.15.1
PyJWT==2.7.0
pyparsing==3.0.9
pyrsistent==0.19.3
python-dateutil==2.8.2
python-editor==1.0.4
python-json-logger==2.0.7
python-multipart==0.0.6
pytorch-lightning==1.9.0
pytz==2023.3
PyYAML==6.0
pyzmq==25.0.2
qtconsole==5.4.3
QtPy==2.3.1
readchar==4.0.5
requests==2.30.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.3.5
rotary-embedding-torch==0.2.3
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
stack-data==0.6.2
starlette==0.22.0
starsessions==1.3.0
sympy==1.12
terminado==0.17.1
tinycss2==1.2.1
torch==2.0.0
torchmetrics==0.11.4
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.15.1%2Bcpu-cp311-cp311-linux_x86_64.whl
tornado==6.3.1
tqdm==4.65.0
traitlets==5.9.0
triton==2.0.0
typing_extensions==4.5.0
tzdata==2023.3
uri-template==1.2.0
urllib3==2.0.2
uvicorn==0.22.0
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.5.1
websockets==11.0.3
widgetsnbextension==4.0.7
yarl==1.9.2
taming/lr_scheduler.py
ADDED
@@ -0,0 +1,34 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                    1 + np.cos(t * np.pi))
            self.last_lr = lr
            return lr

    def __call__(self, n):
        return self.schedule(n)
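The scheduler above returns a multiplier per step (hence the note about a base_lr of 1.0), so it can be wired into torch's LambdaLR. A small sketch under assumed hyperparameters:

# Hypothetical wiring; the optimizer, warm-up length, and decay horizon are assumptions.
import torch
from taming.lr_scheduler import LambdaWarmUpCosineScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1.0)  # base lr of 1.0, as the docstring suggests
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=1e-6, lr_max=5e-4, lr_start=1e-6, max_decay_steps=100000
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, optimizer.param_groups[0]["lr"])  # warms up linearly, then cosine-decays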
taming/models/cond_transformer.py
ADDED
@@ -0,0 +1,349 @@
1 |
+
import os, math
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
|
6 |
+
from main import instantiate_from_config
|
7 |
+
from taming.modules.util import SOSProvider
|
8 |
+
|
9 |
+
|
10 |
+
def disabled_train(self, mode=True):
|
11 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
12 |
+
does not change anymore."""
|
13 |
+
return self
|
14 |
+
|
15 |
+
|
16 |
+
class Net2NetTransformer(pl.LightningModule):
|
17 |
+
def __init__(self,
|
18 |
+
transformer_config,
|
19 |
+
first_stage_config,
|
20 |
+
cond_stage_config,
|
21 |
+
permuter_config=None,
|
22 |
+
ckpt_path=None,
|
23 |
+
ignore_keys=[],
|
24 |
+
first_stage_key="image",
|
25 |
+
cond_stage_key="depth",
|
26 |
+
downsample_cond_size=-1,
|
27 |
+
pkeep=1.0,
|
28 |
+
sos_token=0,
|
29 |
+
unconditional=False,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
self.be_unconditional = unconditional
|
33 |
+
self.sos_token = sos_token
|
34 |
+
self.first_stage_key = first_stage_key
|
35 |
+
self.cond_stage_key = cond_stage_key
|
36 |
+
self.init_first_stage_from_ckpt(first_stage_config)
|
37 |
+
self.init_cond_stage_from_ckpt(cond_stage_config)
|
38 |
+
if permuter_config is None:
|
39 |
+
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
|
40 |
+
self.permuter = instantiate_from_config(config=permuter_config)
|
41 |
+
self.transformer = instantiate_from_config(config=transformer_config)
|
42 |
+
|
43 |
+
if ckpt_path is not None:
|
44 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
45 |
+
self.downsample_cond_size = downsample_cond_size
|
46 |
+
self.pkeep = pkeep
|
47 |
+
|
48 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
49 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
50 |
+
for k in sd.keys():
|
51 |
+
for ik in ignore_keys:
|
52 |
+
if k.startswith(ik):
|
53 |
+
self.print("Deleting key {} from state_dict.".format(k))
|
54 |
+
del sd[k]
|
55 |
+
self.load_state_dict(sd, strict=False)
|
56 |
+
print(f"Restored from {path}")
|
57 |
+
|
58 |
+
def init_first_stage_from_ckpt(self, config):
|
59 |
+
model = instantiate_from_config(config)
|
60 |
+
model = model.eval()
|
61 |
+
model.train = disabled_train
|
62 |
+
self.first_stage_model = model
|
63 |
+
|
64 |
+
def init_cond_stage_from_ckpt(self, config):
|
65 |
+
if config == "__is_first_stage__":
|
66 |
+
print("Using first stage also as cond stage.")
|
67 |
+
self.cond_stage_model = self.first_stage_model
|
68 |
+
elif config == "__is_unconditional__" or self.be_unconditional:
|
69 |
+
print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
|
70 |
+
f"Prepending {self.sos_token} as a sos token.")
|
71 |
+
self.be_unconditional = True
|
72 |
+
self.cond_stage_key = self.first_stage_key
|
73 |
+
self.cond_stage_model = SOSProvider(self.sos_token)
|
74 |
+
else:
|
75 |
+
model = instantiate_from_config(config)
|
76 |
+
model = model.eval()
|
77 |
+
model.train = disabled_train
|
78 |
+
self.cond_stage_model = model
|
79 |
+
|
80 |
+
def forward(self, x, c):
|
81 |
+
# one step to produce the logits
|
82 |
+
# x = target
|
83 |
+
# c = nucleus
|
84 |
+
_, z_indices = self.encode_to_z(x)
|
85 |
+
_, c_indices = self.encode_to_c(c)
|
86 |
+
|
87 |
+
if self.training and self.pkeep < 1.0:
|
88 |
+
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
|
89 |
+
device=z_indices.device))
|
90 |
+
mask = mask.round().to(dtype=torch.int64)
|
91 |
+
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
|
92 |
+
a_indices = mask*z_indices+(1-mask)*r_indices
|
93 |
+
else:
|
94 |
+
a_indices = z_indices
|
95 |
+
|
96 |
+
cz_indices = torch.cat((c_indices, a_indices), dim=1)
|
97 |
+
|
98 |
+
# target includes all sequence elements (no need to handle first one
|
99 |
+
# differently because we are conditioning)
|
100 |
+
target = z_indices
|
101 |
+
# make the prediction
|
102 |
+
logits, _ = self.transformer(cz_indices[:, :-1])
|
103 |
+
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
|
104 |
+
logits = logits[:, c_indices.shape[1]-1:]
|
105 |
+
|
106 |
+
return logits, target
|
107 |
+
|
108 |
+
def top_k_logits(self, logits, k):
|
109 |
+
v, ix = torch.topk(logits, k)
|
110 |
+
out = logits.clone()
|
111 |
+
out[out < v[..., [-1]]] = -float('Inf')
|
112 |
+
return out
|
113 |
+
|
114 |
+
@torch.no_grad()
|
115 |
+
def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
|
116 |
+
callback=lambda k: None):
|
117 |
+
x = torch.cat((c,x),dim=1)
|
118 |
+
block_size = self.transformer.get_block_size()
|
119 |
+
assert not self.transformer.training
|
120 |
+
if self.pkeep <= 0.0:
|
121 |
+
# one pass suffices since input is pure noise anyway
|
122 |
+
assert len(x.shape)==2
|
123 |
+
noise_shape = (x.shape[0], steps-1)
|
124 |
+
#noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
|
125 |
+
noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
|
126 |
+
x = torch.cat((x,noise),dim=1)
|
127 |
+
logits, _ = self.transformer(x)
|
128 |
+
# take all logits for now and scale by temp
|
129 |
+
logits = logits / temperature
|
130 |
+
# optionally crop probabilities to only the top k options
|
131 |
+
if top_k is not None:
|
132 |
+
logits = self.top_k_logits(logits, top_k)
|
133 |
+
# apply softmax to convert to probabilities
|
134 |
+
probs = F.softmax(logits, dim=-1)
|
135 |
+
# sample from the distribution or take the most likely
|
136 |
+
if sample:
|
137 |
+
shape = probs.shape
|
138 |
+
probs = probs.reshape(shape[0]*shape[1],shape[2])
|
139 |
+
ix = torch.multinomial(probs, num_samples=1)
|
140 |
+
probs = probs.reshape(shape[0],shape[1],shape[2])
|
141 |
+
ix = ix.reshape(shape[0],shape[1])
|
142 |
+
else:
|
143 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
144 |
+
# cut off conditioning
|
145 |
+
x = ix[:, c.shape[1]-1:]
|
146 |
+
else:
|
147 |
+
for k in range(steps):
|
148 |
+
callback(k)
|
149 |
+
assert x.size(1) <= block_size # make sure model can see conditioning
|
150 |
+
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
151 |
+
logits, _ = self.transformer(x_cond)
|
152 |
+
# pluck the logits at the final step and scale by temperature
|
153 |
+
logits = logits[:, -1, :] / temperature
|
154 |
+
# optionally crop probabilities to only the top k options
|
155 |
+
if top_k is not None:
|
156 |
+
logits = self.top_k_logits(logits, top_k)
|
157 |
+
# apply softmax to convert to probabilities
|
158 |
+
probs = F.softmax(logits, dim=-1)
|
159 |
+
# sample from the distribution or take the most likely
|
160 |
+
if sample:
|
161 |
+
ix = torch.multinomial(probs, num_samples=1)
|
162 |
+
else:
|
163 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
164 |
+
# append to the sequence and continue
|
165 |
+
x = torch.cat((x, ix), dim=1)
|
166 |
+
# cut off conditioning
|
167 |
+
x = x[:, c.shape[1]:]
|
168 |
+
return x
|
169 |
+
|
170 |
+
@torch.no_grad()
|
171 |
+
def encode_to_z(self, x):
|
172 |
+
quant_z, _, info = self.first_stage_model.encode(x)
|
173 |
+
indices = info[2].view(quant_z.shape[0], -1)
|
174 |
+
indices = self.permuter(indices)
|
175 |
+
return quant_z, indices
|
176 |
+
|
177 |
+
@torch.no_grad()
|
178 |
+
def encode_to_c(self, c):
|
179 |
+
if self.downsample_cond_size > -1:
|
180 |
+
c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
|
181 |
+
|
182 |
+
#quant_c, _, info = self.cond_stage_model.encode(x)
|
183 |
+
#indices = info[2].view(quant_c.shape[0], -1)
|
184 |
+
#indices = self.permuter(indices)
|
185 |
+
quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
|
186 |
+
if len(indices.shape) != 2:
|
187 |
+
indices = indices.view(c.shape[0], -1)
|
188 |
+
return quant_c, indices
|
189 |
+
|
190 |
+
@torch.no_grad()
|
191 |
+
def decode_to_img(self, index, zshape):
|
192 |
+
index = self.permuter(index, reverse=True)
|
193 |
+
bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
|
194 |
+
quant_z = self.first_stage_model.quantize.get_codebook_entry(
|
195 |
+
            index.reshape(-1), shape=bhwc)
        x = self.first_stage_model.decode(quant_z)
        return x

    @torch.no_grad()
    def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
        log = dict()

        N = 4
        if lr_interface:
            x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
        else:
            x, c = self.get_xc(batch, N)
        x = x.to(device=self.device)
        c = c.to(device=self.device)

        quant_z, z_indices = self.encode_to_z(x)
        quant_c, c_indices = self.encode_to_c(c)

        # create a "half" sample
        z_start_indices = z_indices[:, :z_indices.shape[1] // 2]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1] - z_start_indices.shape[1],
                                   temperature=temperature if temperature is not None else 1.0,
                                   sample=True,
                                   top_k=top_k if top_k is not None else 100,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample = self.decode_to_img(index_sample, quant_z.shape)

        # sample
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   temperature=temperature if temperature is not None else 1.0,
                                   sample=True,
                                   top_k=top_k if top_k is not None else 100,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)

        # deterministic sample
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   sample=False,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample_det = self.decode_to_img(index_sample, quant_z.shape)

        # reconstruction
        x_rec = self.decode_to_img(z_indices, quant_z.shape)

        log["inputs"] = x
        log["reconstructions"] = x_rec

        # only decode and log the conditioning when it is not already an image-like input
        # (the original condition chained "!=" with "or", which is always true)
        if self.cond_stage_key not in ("image", "nucleus", "target"):
            cond_rec = self.cond_stage_model.decode(quant_c)
            if self.cond_stage_key == "segmentation":
                # get image from segmentation mask
                num_classes = cond_rec.shape[1]

                c = torch.argmax(c, dim=1, keepdim=True)
                c = F.one_hot(c, num_classes=num_classes)
                c = c.squeeze(1).permute(0, 3, 1, 2).float()
                c = self.cond_stage_model.to_rgb(c)

                cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
                cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
                cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
                cond_rec = self.cond_stage_model.to_rgb(cond_rec)
            log["conditioning_rec"] = cond_rec
            log["conditioning"] = c

        log["samples_half"] = x_sample
        log["samples_nopix"] = x_sample_nopix
        log["samples_det"] = x_sample_det
        return log

    def get_input(self, key, batch):
        x = batch[key]
        if len(x.shape) == 3:
            x = x[..., None]
        # if len(x.shape) == 4:
        #     x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        if x.dtype == torch.double:
            x = x.float()
        return x

    def get_xc(self, batch, N=None):
        x = self.get_input(self.first_stage_key, batch)
        c = self.get_input(self.cond_stage_key, batch)
        if N is not None:
            x = x[:N]
            c = c[:N]
        return x, c

    def shared_step(self, batch):
        x, c = self.get_xc(batch)
        logits, target = self(x, c)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Following minGPT:
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """
        # separate out all parameters into those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
        return optimizer
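The configure_optimizers docstring above describes the minGPT recipe of splitting parameters into a weight-decay group and a no-weight-decay group before building AdamW. As a minimal standalone sketch of that recipe (TinyNet, build_optimizer, and every hyperparameter below are hypothetical and exist only for illustration, not in this repository):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # toy model: one embedding (never decayed) and one linear layer
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 8)
        self.fc = nn.Linear(8, 2)

def build_optimizer(model, lr=3e-4, weight_decay=0.01):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        # biases and embedding weights are excluded from weight decay,
        # mirroring the whitelist/blacklist logic above
        if name.endswith("bias") or "emb" in name:
            no_decay.append(param)
        else:
            decay.append(param)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95))

optimizer = build_optimizer(TinyNet())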
taming/models/dummy_cond_stage.py
ADDED
@@ -0,0 +1,22 @@
from torch import Tensor


class DummyCondStage:
    def __init__(self, conditional_key):
        self.conditional_key = conditional_key
        self.train = None

    def eval(self):
        return self

    @staticmethod
    def encode(c: Tensor):
        return c, None, (None, None, c)

    @staticmethod
    def decode(c: Tensor):
        return c

    @staticmethod
    def to_rgb(c: Tensor):
        return c
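DummyCondStage is a pass-through stand-in for a real conditioning model: encode and decode are both the identity. A quick illustrative check (the key name "nucleus" is only an example):

import torch
from taming.models.dummy_cond_stage import DummyCondStage

stage = DummyCondStage(conditional_key="nucleus")
c = torch.randn(2, 1, 64, 64)
c_enc, _, (_, _, c_idx) = stage.encode(c)      # pass-through "encoding"
assert torch.equal(stage.decode(c_enc), c)     # decode is the identity as well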
taming/models/vqgan.py
ADDED
@@ -0,0 +1,649 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
|
5 |
+
from celle_taming_main import instantiate_from_config
|
6 |
+
|
7 |
+
from taming.modules.diffusionmodules.model import Encoder, Decoder
|
8 |
+
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
9 |
+
from taming.modules.vqvae.quantize import GumbelQuantize
|
10 |
+
from taming.modules.vqvae.quantize import EMAVectorQuantizer
|
11 |
+
|
12 |
+
|
13 |
+
class VQModel(pl.LightningModule):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
ddconfig,
|
17 |
+
lossconfig,
|
18 |
+
n_embed,
|
19 |
+
embed_dim,
|
20 |
+
ckpt_path=None,
|
21 |
+
ignore_keys=[],
|
22 |
+
image_key="image",
|
23 |
+
colorize_nlabels=None,
|
24 |
+
monitor=None,
|
25 |
+
remap=None,
|
26 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
27 |
+
):
|
28 |
+
super().__init__()
|
29 |
+
self.image_key = image_key
|
30 |
+
self.encoder = Encoder(**ddconfig)
|
31 |
+
self.decoder = Decoder(**ddconfig)
|
32 |
+
self.loss = instantiate_from_config(lossconfig)
|
33 |
+
self.quantize = VectorQuantizer(
|
34 |
+
n_embed,
|
35 |
+
embed_dim,
|
36 |
+
beta=0.25,
|
37 |
+
remap=remap,
|
38 |
+
sane_index_shape=sane_index_shape,
|
39 |
+
)
|
40 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
41 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
42 |
+
if ckpt_path is not None:
|
43 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
44 |
+
self.image_key = image_key
|
45 |
+
if colorize_nlabels is not None:
|
46 |
+
assert type(colorize_nlabels) == int
|
47 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
48 |
+
if monitor is not None:
|
49 |
+
self.monitor = monitor
|
50 |
+
|
51 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
52 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
53 |
+
keys = list(sd.keys())
|
54 |
+
for k in keys:
|
55 |
+
for ik in ignore_keys:
|
56 |
+
if k.startswith(ik):
|
57 |
+
print("Deleting key {} from state_dict.".format(k))
|
58 |
+
del sd[k]
|
59 |
+
self.load_state_dict(sd, strict=False)
|
60 |
+
print(f"Restored from {path}")
|
61 |
+
|
62 |
+
def encode(self, x):
|
63 |
+
h = self.encoder(x)
|
64 |
+
h = self.quant_conv(h)
|
65 |
+
quant, emb_loss, info = self.quantize(h)
|
66 |
+
return quant, emb_loss, info
|
67 |
+
|
68 |
+
def decode(self, quant):
|
69 |
+
quant = self.post_quant_conv(quant)
|
70 |
+
dec = self.decoder(quant)
|
71 |
+
return dec
|
72 |
+
|
73 |
+
def decode_code(self, code_b):
|
74 |
+
quant_b = self.quantize.embed_code(code_b)
|
75 |
+
dec = self.decode(quant_b)
|
76 |
+
return dec
|
77 |
+
|
78 |
+
def forward(self, input):
|
79 |
+
quant, diff, _ = self.encode(input)
|
80 |
+
dec = self.decode(quant)
|
81 |
+
return dec, diff
|
82 |
+
|
83 |
+
def get_input(self, batch, k):
|
84 |
+
|
85 |
+
if k == "mixed":
|
86 |
+
keys = ["nucleus", "target"]
|
87 |
+
index = torch.randint(low=0, high=2, size=(1,), dtype=int).item()
|
88 |
+
k = keys[index]
|
89 |
+
|
90 |
+
x = batch[k]
|
91 |
+
if len(x.shape) == 3:
|
92 |
+
x = x[..., None]
|
93 |
+
|
94 |
+
# x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
95 |
+
return x
|
96 |
+
|
97 |
+
def training_step(self, batch, batch_idx=None, optimizer_idx=0):
|
98 |
+
|
99 |
+
if type(batch) == dict:
|
100 |
+
|
101 |
+
x = self.get_input(batch, self.image_key)
|
102 |
+
|
103 |
+
else:
|
104 |
+
x = batch
|
105 |
+
|
106 |
+
xrec, qloss = self(
|
107 |
+
x,
|
108 |
+
)
|
109 |
+
|
110 |
+
if optimizer_idx == 0:
|
111 |
+
# autoencode
|
112 |
+
aeloss, log_dict_ae = self.loss(
|
113 |
+
qloss,
|
114 |
+
x,
|
115 |
+
xrec,
|
116 |
+
optimizer_idx,
|
117 |
+
self.global_step,
|
118 |
+
last_layer=self.get_last_layer(),
|
119 |
+
split="train",
|
120 |
+
)
|
121 |
+
|
122 |
+
self.log(
|
123 |
+
"train/aeloss",
|
124 |
+
aeloss,
|
125 |
+
prog_bar=True,
|
126 |
+
logger=True,
|
127 |
+
on_step=True,
|
128 |
+
on_epoch=True,
|
129 |
+
sync_dist=True,
|
130 |
+
)
|
131 |
+
self.log_dict(
|
132 |
+
log_dict_ae,
|
133 |
+
prog_bar=False,
|
134 |
+
logger=True,
|
135 |
+
on_step=True,
|
136 |
+
on_epoch=True,
|
137 |
+
sync_dist=True,
|
138 |
+
)
|
139 |
+
return aeloss
|
140 |
+
|
141 |
+
if optimizer_idx == 1:
|
142 |
+
# discriminator
|
143 |
+
discloss, log_dict_disc = self.loss(
|
144 |
+
qloss,
|
145 |
+
x,
|
146 |
+
xrec,
|
147 |
+
optimizer_idx,
|
148 |
+
self.global_step,
|
149 |
+
last_layer=self.get_last_layer(),
|
150 |
+
split="train",
|
151 |
+
)
|
152 |
+
self.log(
|
153 |
+
"train/discloss",
|
154 |
+
discloss,
|
155 |
+
prog_bar=True,
|
156 |
+
logger=True,
|
157 |
+
on_step=True,
|
158 |
+
on_epoch=True,
|
159 |
+
sync_dist=True,
|
160 |
+
)
|
161 |
+
self.log_dict(
|
162 |
+
log_dict_disc,
|
163 |
+
prog_bar=False,
|
164 |
+
logger=True,
|
165 |
+
on_step=True,
|
166 |
+
on_epoch=True,
|
167 |
+
sync_dist=True,
|
168 |
+
)
|
169 |
+
return discloss
|
170 |
+
|
171 |
+
def validation_step(self, batch, batch_idx):
|
172 |
+
|
173 |
+
if type(batch) == dict:
|
174 |
+
|
175 |
+
x = self.get_input(batch, self.image_key)
|
176 |
+
|
177 |
+
else:
|
178 |
+
x = batch
|
179 |
+
|
180 |
+
xrec, qloss = self(x)
|
181 |
+
aeloss, log_dict_ae = self.loss(
|
182 |
+
qloss,
|
183 |
+
x,
|
184 |
+
xrec,
|
185 |
+
0,
|
186 |
+
self.global_step,
|
187 |
+
last_layer=self.get_last_layer(),
|
188 |
+
split="val",
|
189 |
+
)
|
190 |
+
|
191 |
+
discloss, log_dict_disc = self.loss(
|
192 |
+
qloss,
|
193 |
+
x,
|
194 |
+
xrec,
|
195 |
+
1,
|
196 |
+
self.global_step,
|
197 |
+
last_layer=self.get_last_layer(),
|
198 |
+
split="val",
|
199 |
+
)
|
200 |
+
# rec_loss = log_dict_ae["val/rec_loss"]
|
201 |
+
# self.log(
|
202 |
+
# "val/rec_loss",
|
203 |
+
# rec_loss,
|
204 |
+
# prog_bar=True,
|
205 |
+
# logger=True,
|
206 |
+
# on_step=True,
|
207 |
+
# on_epoch=True,
|
208 |
+
# sync_dist=True,
|
209 |
+
# )
|
210 |
+
# self.log(
|
211 |
+
# "val/aeloss",
|
212 |
+
# aeloss,
|
213 |
+
# prog_bar=True,
|
214 |
+
# logger=True,
|
215 |
+
# on_step=True,
|
216 |
+
# on_epoch=True,
|
217 |
+
# sync_dist=True,
|
218 |
+
# )
|
219 |
+
|
220 |
+
for key, value in log_dict_disc.items():
|
221 |
+
if key in log_dict_ae:
|
222 |
+
log_dict_ae[key].extend(value)
|
223 |
+
else:
|
224 |
+
log_dict_ae[key] = value
|
225 |
+
|
226 |
+
self.log_dict(log_dict_ae, sync_dist=True)
|
227 |
+
return self.log_dict
|
228 |
+
|
229 |
+
def configure_optimizers(self):
|
230 |
+
lr = self.learning_rate
|
231 |
+
opt_ae = torch.optim.Adam(
|
232 |
+
list(self.encoder.parameters())
|
233 |
+
+ list(self.decoder.parameters())
|
234 |
+
+ list(self.quantize.parameters())
|
235 |
+
+ list(self.quant_conv.parameters())
|
236 |
+
+ list(self.post_quant_conv.parameters()),
|
237 |
+
lr=lr,
|
238 |
+
betas=(0.5, 0.9),
|
239 |
+
)
|
240 |
+
opt_disc = torch.optim.Adam(
|
241 |
+
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
|
242 |
+
)
|
243 |
+
return [opt_ae, opt_disc], []
|
244 |
+
|
245 |
+
def get_last_layer(self):
|
246 |
+
return self.decoder.conv_out.weight
|
247 |
+
|
248 |
+
def log_images(self, batch, **kwargs):
|
249 |
+
log = dict()
|
250 |
+
x = self.get_input(batch, self.image_key)
|
251 |
+
x = x.to(self.device)
|
252 |
+
xrec, _ = self(x)
|
253 |
+
if x.shape[1] > 3:
|
254 |
+
# colorize with random projection
|
255 |
+
assert xrec.shape[1] > 3
|
256 |
+
x = self.to_rgb(x)
|
257 |
+
xrec = self.to_rgb(xrec)
|
258 |
+
log["inputs"] = x
|
259 |
+
log["reconstructions"] = xrec
|
260 |
+
return log
|
261 |
+
|
262 |
+
def to_rgb(self, x):
|
263 |
+
assert self.image_key == "segmentation"
|
264 |
+
if not hasattr(self, "colorize"):
|
265 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
266 |
+
x = F.conv2d(x, weight=self.colorize)
|
267 |
+
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
268 |
+
return x
|
269 |
+
|
270 |
+
|
271 |
+
class VQSegmentationModel(VQModel):
|
272 |
+
def __init__(self, n_labels, *args, **kwargs):
|
273 |
+
super().__init__(*args, **kwargs)
|
274 |
+
self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
|
275 |
+
|
276 |
+
def configure_optimizers(self):
|
277 |
+
lr = self.learning_rate
|
278 |
+
opt_ae = torch.optim.Adam(
|
279 |
+
list(self.encoder.parameters())
|
280 |
+
+ list(self.decoder.parameters())
|
281 |
+
+ list(self.quantize.parameters())
|
282 |
+
+ list(self.quant_conv.parameters())
|
283 |
+
+ list(self.post_quant_conv.parameters()),
|
284 |
+
lr=lr,
|
285 |
+
betas=(0.5, 0.9),
|
286 |
+
)
|
287 |
+
return opt_ae
|
288 |
+
|
289 |
+
def training_step(self, batch, batch_idx):
|
290 |
+
x = self.get_input(batch, self.image_key)
|
291 |
+
xrec, qloss = self(x)
|
292 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
|
293 |
+
self.log_dict(
|
294 |
+
log_dict_ae,
|
295 |
+
prog_bar=False,
|
296 |
+
logger=True,
|
297 |
+
on_step=True,
|
298 |
+
on_epoch=True,
|
299 |
+
sync_dist=True,
|
300 |
+
)
|
301 |
+
return aeloss
|
302 |
+
|
303 |
+
def validation_step(self, batch, batch_idx):
|
304 |
+
x = self.get_input(batch, self.image_key)
|
305 |
+
xrec, qloss = self(x)
|
306 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
|
307 |
+
self.log_dict(
|
308 |
+
log_dict_ae,
|
309 |
+
prog_bar=False,
|
310 |
+
logger=True,
|
311 |
+
on_step=True,
|
312 |
+
on_epoch=True,
|
313 |
+
sync_dist=True,
|
314 |
+
)
|
315 |
+
total_loss = log_dict_ae["val/total_loss"]
|
316 |
+
self.log(
|
317 |
+
"val/total_loss",
|
318 |
+
total_loss,
|
319 |
+
prog_bar=True,
|
320 |
+
logger=True,
|
321 |
+
on_step=True,
|
322 |
+
on_epoch=True,
|
323 |
+
sync_dist=True,
|
324 |
+
)
|
325 |
+
return aeloss
|
326 |
+
|
327 |
+
@torch.no_grad()
|
328 |
+
def log_images(self, batch, **kwargs):
|
329 |
+
log = dict()
|
330 |
+
x = self.get_input(batch, self.image_key)
|
331 |
+
x = x.to(self.device)
|
332 |
+
xrec, _ = self(x)
|
333 |
+
if x.shape[1] > 3:
|
334 |
+
# colorize with random projection
|
335 |
+
assert xrec.shape[1] > 3
|
336 |
+
# convert logits to indices
|
337 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
338 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
339 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
340 |
+
x = self.to_rgb(x)
|
341 |
+
xrec = self.to_rgb(xrec)
|
342 |
+
log["inputs"] = x
|
343 |
+
log["reconstructions"] = xrec
|
344 |
+
return log
|
345 |
+
|
346 |
+
|
347 |
+
class VQNoDiscModel(VQModel):
|
348 |
+
def __init__(
|
349 |
+
self,
|
350 |
+
ddconfig,
|
351 |
+
lossconfig,
|
352 |
+
n_embed,
|
353 |
+
embed_dim,
|
354 |
+
ckpt_path=None,
|
355 |
+
ignore_keys=[],
|
356 |
+
image_key="image",
|
357 |
+
colorize_nlabels=None,
|
358 |
+
):
|
359 |
+
super().__init__(
|
360 |
+
ddconfig=ddconfig,
|
361 |
+
lossconfig=lossconfig,
|
362 |
+
n_embed=n_embed,
|
363 |
+
embed_dim=embed_dim,
|
364 |
+
ckpt_path=ckpt_path,
|
365 |
+
ignore_keys=ignore_keys,
|
366 |
+
image_key=image_key,
|
367 |
+
colorize_nlabels=colorize_nlabels,
|
368 |
+
)
|
369 |
+
|
370 |
+
def training_step(self, batch, batch_idx):
|
371 |
+
x = self.get_input(batch, self.image_key)
|
372 |
+
xrec, qloss = self(x)
|
373 |
+
# autoencode
|
374 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
|
375 |
+
output = pl.TrainResult(minimize=aeloss)
|
376 |
+
output.log(
|
377 |
+
"train/aeloss",
|
378 |
+
aeloss,
|
379 |
+
prog_bar=True,
|
380 |
+
logger=True,
|
381 |
+
on_step=True,
|
382 |
+
on_epoch=True,
|
383 |
+
)
|
384 |
+
output.log_dict(
|
385 |
+
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
386 |
+
)
|
387 |
+
return output
|
388 |
+
|
389 |
+
def validation_step(self, batch, batch_idx):
|
390 |
+
x = self.get_input(batch, self.image_key)
|
391 |
+
xrec, qloss = self(x)
|
392 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
|
393 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
394 |
+
output = pl.EvalResult(checkpoint_on=rec_loss)
|
395 |
+
output.log(
|
396 |
+
"val/rec_loss",
|
397 |
+
rec_loss,
|
398 |
+
prog_bar=True,
|
399 |
+
logger=True,
|
400 |
+
on_step=True,
|
401 |
+
on_epoch=True,
|
402 |
+
)
|
403 |
+
output.log(
|
404 |
+
"val/aeloss",
|
405 |
+
aeloss,
|
406 |
+
prog_bar=True,
|
407 |
+
logger=True,
|
408 |
+
on_step=True,
|
409 |
+
on_epoch=True,
|
410 |
+
)
|
411 |
+
output.log_dict(log_dict_ae)
|
412 |
+
|
413 |
+
return output
|
414 |
+
|
415 |
+
def configure_optimizers(self):
|
416 |
+
optimizer = torch.optim.Adam(
|
417 |
+
list(self.encoder.parameters())
|
418 |
+
+ list(self.decoder.parameters())
|
419 |
+
+ list(self.quantize.parameters())
|
420 |
+
+ list(self.quant_conv.parameters())
|
421 |
+
+ list(self.post_quant_conv.parameters()),
|
422 |
+
lr=self.learning_rate,
|
423 |
+
betas=(0.5, 0.9),
|
424 |
+
)
|
425 |
+
return optimizer
|
426 |
+
|
427 |
+
|
428 |
+
class GumbelVQ(VQModel):
|
429 |
+
def __init__(
|
430 |
+
self,
|
431 |
+
ddconfig,
|
432 |
+
lossconfig,
|
433 |
+
n_embed,
|
434 |
+
embed_dim,
|
435 |
+
temperature_scheduler_config,
|
436 |
+
ckpt_path=None,
|
437 |
+
ignore_keys=[],
|
438 |
+
image_key="image",
|
439 |
+
colorize_nlabels=None,
|
440 |
+
monitor=None,
|
441 |
+
kl_weight=1e-8,
|
442 |
+
remap=None,
|
443 |
+
):
|
444 |
+
|
445 |
+
z_channels = ddconfig["z_channels"]
|
446 |
+
super().__init__(
|
447 |
+
ddconfig,
|
448 |
+
lossconfig,
|
449 |
+
n_embed,
|
450 |
+
embed_dim,
|
451 |
+
ckpt_path=None,
|
452 |
+
ignore_keys=ignore_keys,
|
453 |
+
image_key=image_key,
|
454 |
+
colorize_nlabels=colorize_nlabels,
|
455 |
+
monitor=monitor,
|
456 |
+
)
|
457 |
+
|
458 |
+
self.loss.n_classes = n_embed
|
459 |
+
self.vocab_size = n_embed
|
460 |
+
|
461 |
+
self.quantize = GumbelQuantize(
|
462 |
+
z_channels,
|
463 |
+
embed_dim,
|
464 |
+
n_embed=n_embed,
|
465 |
+
kl_weight=kl_weight,
|
466 |
+
temp_init=1.0,
|
467 |
+
remap=remap,
|
468 |
+
)
|
469 |
+
|
470 |
+
self.temperature_scheduler = instantiate_from_config(
|
471 |
+
temperature_scheduler_config
|
472 |
+
) # annealing of temp
|
473 |
+
|
474 |
+
if ckpt_path is not None:
|
475 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
476 |
+
|
477 |
+
def temperature_scheduling(self):
|
478 |
+
self.quantize.temperature = self.temperature_scheduler(self.global_step)
|
479 |
+
|
480 |
+
def encode_to_prequant(self, x):
|
481 |
+
h = self.encoder(x)
|
482 |
+
h = self.quant_conv(h)
|
483 |
+
return h
|
484 |
+
|
485 |
+
def decode_code(self, code_b):
|
486 |
+
raise NotImplementedError
|
487 |
+
|
488 |
+
def training_step(self, batch, batch_idx=None, optimizer_idx=0):
|
489 |
+
self.temperature_scheduling()
|
490 |
+
x = self.get_input(batch, self.image_key)
|
491 |
+
xrec, qloss = self(x)
|
492 |
+
|
493 |
+
if optimizer_idx == 0:
|
494 |
+
# autoencode
|
495 |
+
aeloss, log_dict_ae = self.loss(
|
496 |
+
qloss,
|
497 |
+
x,
|
498 |
+
xrec,
|
499 |
+
optimizer_idx,
|
500 |
+
self.global_step,
|
501 |
+
last_layer=self.get_last_layer(),
|
502 |
+
split="train",
|
503 |
+
)
|
504 |
+
|
505 |
+
self.log_dict(
|
506 |
+
log_dict_ae,
|
507 |
+
prog_bar=False,
|
508 |
+
logger=True,
|
509 |
+
on_step=True,
|
510 |
+
on_epoch=True,
|
511 |
+
sync_dist=True,
|
512 |
+
)
|
513 |
+
self.log(
|
514 |
+
"temperature",
|
515 |
+
self.quantize.temperature,
|
516 |
+
prog_bar=False,
|
517 |
+
logger=True,
|
518 |
+
on_step=True,
|
519 |
+
on_epoch=True,
|
520 |
+
sync_dist=True,
|
521 |
+
)
|
522 |
+
return aeloss
|
523 |
+
|
524 |
+
if optimizer_idx == 1:
|
525 |
+
# discriminator
|
526 |
+
discloss, log_dict_disc = self.loss(
|
527 |
+
qloss,
|
528 |
+
x,
|
529 |
+
xrec,
|
530 |
+
optimizer_idx,
|
531 |
+
self.global_step,
|
532 |
+
last_layer=self.get_last_layer(),
|
533 |
+
split="train",
|
534 |
+
)
|
535 |
+
self.log_dict(
|
536 |
+
log_dict_disc,
|
537 |
+
prog_bar=False,
|
538 |
+
logger=True,
|
539 |
+
on_step=True,
|
540 |
+
on_epoch=True,
|
541 |
+
sync_dist=True,
|
542 |
+
)
|
543 |
+
return discloss
|
544 |
+
|
545 |
+
def validation_step(self, batch, batch_idx):
|
546 |
+
x = self.get_input(batch, self.image_key)
|
547 |
+
xrec, qloss = self(x)
|
548 |
+
aeloss, log_dict_ae = self.loss(
|
549 |
+
qloss,
|
550 |
+
x,
|
551 |
+
xrec,
|
552 |
+
0,
|
553 |
+
self.global_step,
|
554 |
+
last_layer=self.get_last_layer(),
|
555 |
+
split="val",
|
556 |
+
)
|
557 |
+
|
558 |
+
discloss, log_dict_disc = self.loss(
|
559 |
+
qloss,
|
560 |
+
x,
|
561 |
+
xrec,
|
562 |
+
1,
|
563 |
+
self.global_step,
|
564 |
+
last_layer=self.get_last_layer(),
|
565 |
+
split="val",
|
566 |
+
)
|
567 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
568 |
+
self.log(
|
569 |
+
"val/rec_loss",
|
570 |
+
rec_loss,
|
571 |
+
prog_bar=True,
|
572 |
+
logger=True,
|
573 |
+
on_step=False,
|
574 |
+
on_epoch=True,
|
575 |
+
sync_dist=True,
|
576 |
+
)
|
577 |
+
self.log(
|
578 |
+
"val/aeloss",
|
579 |
+
aeloss,
|
580 |
+
prog_bar=True,
|
581 |
+
logger=True,
|
582 |
+
on_step=False,
|
583 |
+
on_epoch=True,
|
584 |
+
sync_dist=True,
|
585 |
+
)
|
586 |
+
self.log_dict(log_dict_ae, sync_dist=True)
|
587 |
+
self.log_dict(log_dict_disc, sync_dist=True)
|
588 |
+
return self.log_dict
|
589 |
+
|
590 |
+
def log_images(self, batch, **kwargs):
|
591 |
+
log = dict()
|
592 |
+
x = self.get_input(batch, self.image_key)
|
593 |
+
x = x.to(self.device)
|
594 |
+
# encode
|
595 |
+
h = self.encoder(x)
|
596 |
+
h = self.quant_conv(h)
|
597 |
+
quant, _, _ = self.quantize(h)
|
598 |
+
# decode
|
599 |
+
x_rec = self.decode(quant)
|
600 |
+
log["inputs"] = x
|
601 |
+
log["reconstructions"] = x_rec
|
602 |
+
return log
|
603 |
+
|
604 |
+
|
605 |
+
class EMAVQ(VQModel):
|
606 |
+
def __init__(
|
607 |
+
self,
|
608 |
+
ddconfig,
|
609 |
+
lossconfig,
|
610 |
+
n_embed,
|
611 |
+
embed_dim,
|
612 |
+
ckpt_path=None,
|
613 |
+
ignore_keys=[],
|
614 |
+
image_key="image",
|
615 |
+
colorize_nlabels=None,
|
616 |
+
monitor=None,
|
617 |
+
remap=None,
|
618 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
619 |
+
):
|
620 |
+
super().__init__(
|
621 |
+
ddconfig,
|
622 |
+
lossconfig,
|
623 |
+
n_embed,
|
624 |
+
embed_dim,
|
625 |
+
ckpt_path=None,
|
626 |
+
ignore_keys=ignore_keys,
|
627 |
+
image_key=image_key,
|
628 |
+
colorize_nlabels=colorize_nlabels,
|
629 |
+
monitor=monitor,
|
630 |
+
)
|
631 |
+
self.quantize = EMAVectorQuantizer(
|
632 |
+
n_embed=n_embed, embedding_dim=embed_dim, beta=0.25, remap=remap
|
633 |
+
)
|
634 |
+
|
635 |
+
def configure_optimizers(self):
|
636 |
+
lr = self.learning_rate
|
637 |
+
# Remove self.quantize from parameter list since it is updated via EMA
|
638 |
+
opt_ae = torch.optim.Adam(
|
639 |
+
list(self.encoder.parameters())
|
640 |
+
+ list(self.decoder.parameters())
|
641 |
+
+ list(self.quant_conv.parameters())
|
642 |
+
+ list(self.post_quant_conv.parameters()),
|
643 |
+
lr=lr,
|
644 |
+
betas=(0.5, 0.9),
|
645 |
+
)
|
646 |
+
opt_disc = torch.optim.Adam(
|
647 |
+
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
|
648 |
+
)
|
649 |
+
return [opt_ae, opt_disc], []
|
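To make the VQModel data flow above concrete, here is a rough round-trip sketch. All configuration values are assumptions chosen only so the tensor shapes line up; the project's real settings live in its training configs, and passing lossconfig as a plain dict presumes that instantiate_from_config accepts a dictionary with a "target" key.

import torch
from taming.models.vqgan import VQModel

ddconfig = dict(ch=32, out_ch=1, ch_mult=(1, 2, 4), num_res_blocks=1,
                attn_resolutions=[16], dropout=0.0, resamp_with_conv=True,
                in_channels=1, resolution=64, z_channels=16, double_z=False)
lossconfig = {"target": "taming.modules.losses.DummyLoss"}  # placeholder loss for illustration

model = VQModel(ddconfig, lossconfig, n_embed=512, embed_dim=16).eval()

x = torch.randn(2, 1, 64, 64)                  # batch of single-channel crops
with torch.no_grad():
    quant, emb_loss, info = model.encode(x)    # encoder -> 1x1 conv -> vector quantizer
    x_rec = model.decode(quant)                # post-quant conv -> decoder
print(x_rec.shape)                             # torch.Size([2, 1, 64, 64])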
taming/modules/autoencoder/lpips/vgg.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
size 7289
taming/modules/diffusionmodules/model.py
ADDED
@@ -0,0 +1,776 @@
1 |
+
# pytorch_diffusion + derived encoder decoder
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
9 |
+
"""
|
10 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
11 |
+
From Fairseq.
|
12 |
+
Build sinusoidal embeddings.
|
13 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
14 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
15 |
+
"""
|
16 |
+
assert len(timesteps.shape) == 1
|
17 |
+
|
18 |
+
half_dim = embedding_dim // 2
|
19 |
+
emb = math.log(10000) / (half_dim - 1)
|
20 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
21 |
+
emb = emb.to(device=timesteps.device)
|
22 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
23 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
24 |
+
if embedding_dim % 2 == 1: # zero pad
|
25 |
+
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
26 |
+
return emb
|
27 |
+
|
28 |
+
|
29 |
+
def nonlinearity(x):
|
30 |
+
# swish
|
31 |
+
return x*torch.sigmoid(x)
|
32 |
+
|
33 |
+
|
34 |
+
def Normalize(in_channels):
|
35 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
36 |
+
|
37 |
+
|
38 |
+
class Upsample(nn.Module):
|
39 |
+
def __init__(self, in_channels, with_conv):
|
40 |
+
super().__init__()
|
41 |
+
self.with_conv = with_conv
|
42 |
+
if self.with_conv:
|
43 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
44 |
+
in_channels,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=1,
|
47 |
+
padding=1)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
51 |
+
if self.with_conv:
|
52 |
+
x = self.conv(x)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class Downsample(nn.Module):
|
57 |
+
def __init__(self, in_channels, with_conv):
|
58 |
+
super().__init__()
|
59 |
+
self.with_conv = with_conv
|
60 |
+
if self.with_conv:
|
61 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
62 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
63 |
+
in_channels,
|
64 |
+
kernel_size=3,
|
65 |
+
stride=2,
|
66 |
+
padding=0)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
if self.with_conv:
|
70 |
+
pad = (0,1,0,1)
|
71 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
72 |
+
x = self.conv(x)
|
73 |
+
else:
|
74 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
class ResnetBlock(nn.Module):
|
79 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
80 |
+
dropout, temb_channels=512):
|
81 |
+
super().__init__()
|
82 |
+
self.in_channels = in_channels
|
83 |
+
out_channels = in_channels if out_channels is None else out_channels
|
84 |
+
self.out_channels = out_channels
|
85 |
+
self.use_conv_shortcut = conv_shortcut
|
86 |
+
|
87 |
+
self.norm1 = Normalize(in_channels)
|
88 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
89 |
+
out_channels,
|
90 |
+
kernel_size=3,
|
91 |
+
stride=1,
|
92 |
+
padding=1)
|
93 |
+
if temb_channels > 0:
|
94 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
95 |
+
out_channels)
|
96 |
+
self.norm2 = Normalize(out_channels)
|
97 |
+
self.dropout = torch.nn.Dropout(dropout)
|
98 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
99 |
+
out_channels,
|
100 |
+
kernel_size=3,
|
101 |
+
stride=1,
|
102 |
+
padding=1)
|
103 |
+
if self.in_channels != self.out_channels:
|
104 |
+
if self.use_conv_shortcut:
|
105 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
106 |
+
out_channels,
|
107 |
+
kernel_size=3,
|
108 |
+
stride=1,
|
109 |
+
padding=1)
|
110 |
+
else:
|
111 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
112 |
+
out_channels,
|
113 |
+
kernel_size=1,
|
114 |
+
stride=1,
|
115 |
+
padding=0)
|
116 |
+
|
117 |
+
def forward(self, x, temb):
|
118 |
+
h = x
|
119 |
+
h = self.norm1(h)
|
120 |
+
h = nonlinearity(h)
|
121 |
+
h = self.conv1(h)
|
122 |
+
|
123 |
+
if temb is not None:
|
124 |
+
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
125 |
+
|
126 |
+
h = self.norm2(h)
|
127 |
+
h = nonlinearity(h)
|
128 |
+
h = self.dropout(h)
|
129 |
+
h = self.conv2(h)
|
130 |
+
|
131 |
+
if self.in_channels != self.out_channels:
|
132 |
+
if self.use_conv_shortcut:
|
133 |
+
x = self.conv_shortcut(x)
|
134 |
+
else:
|
135 |
+
x = self.nin_shortcut(x)
|
136 |
+
|
137 |
+
return x+h
|
138 |
+
|
139 |
+
|
140 |
+
class AttnBlock(nn.Module):
|
141 |
+
def __init__(self, in_channels):
|
142 |
+
super().__init__()
|
143 |
+
self.in_channels = in_channels
|
144 |
+
|
145 |
+
self.norm = Normalize(in_channels)
|
146 |
+
self.q = torch.nn.Conv2d(in_channels,
|
147 |
+
in_channels,
|
148 |
+
kernel_size=1,
|
149 |
+
stride=1,
|
150 |
+
padding=0)
|
151 |
+
self.k = torch.nn.Conv2d(in_channels,
|
152 |
+
in_channels,
|
153 |
+
kernel_size=1,
|
154 |
+
stride=1,
|
155 |
+
padding=0)
|
156 |
+
self.v = torch.nn.Conv2d(in_channels,
|
157 |
+
in_channels,
|
158 |
+
kernel_size=1,
|
159 |
+
stride=1,
|
160 |
+
padding=0)
|
161 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
162 |
+
in_channels,
|
163 |
+
kernel_size=1,
|
164 |
+
stride=1,
|
165 |
+
padding=0)
|
166 |
+
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
h_ = x
|
170 |
+
h_ = self.norm(h_)
|
171 |
+
q = self.q(h_)
|
172 |
+
k = self.k(h_)
|
173 |
+
v = self.v(h_)
|
174 |
+
|
175 |
+
# compute attention
|
176 |
+
b,c,h,w = q.shape
|
177 |
+
q = q.reshape(b,c,h*w)
|
178 |
+
q = q.permute(0,2,1) # b,hw,c
|
179 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
180 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
181 |
+
w_ = w_ * (int(c)**(-0.5))
|
182 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
183 |
+
|
184 |
+
# attend to values
|
185 |
+
v = v.reshape(b,c,h*w)
|
186 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
187 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
188 |
+
h_ = h_.reshape(b,c,h,w)
|
189 |
+
|
190 |
+
h_ = self.proj_out(h_)
|
191 |
+
|
192 |
+
return x+h_
|
193 |
+
|
194 |
+
|
195 |
+
class Model(nn.Module):
|
196 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
197 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
198 |
+
resolution, use_timestep=True):
|
199 |
+
super().__init__()
|
200 |
+
self.ch = ch
|
201 |
+
self.temb_ch = self.ch*4
|
202 |
+
self.num_resolutions = len(ch_mult)
|
203 |
+
self.num_res_blocks = num_res_blocks
|
204 |
+
self.resolution = resolution
|
205 |
+
self.in_channels = in_channels
|
206 |
+
|
207 |
+
self.use_timestep = use_timestep
|
208 |
+
if self.use_timestep:
|
209 |
+
# timestep embedding
|
210 |
+
self.temb = nn.Module()
|
211 |
+
self.temb.dense = nn.ModuleList([
|
212 |
+
torch.nn.Linear(self.ch,
|
213 |
+
self.temb_ch),
|
214 |
+
torch.nn.Linear(self.temb_ch,
|
215 |
+
self.temb_ch),
|
216 |
+
])
|
217 |
+
|
218 |
+
# downsampling
|
219 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
220 |
+
self.ch,
|
221 |
+
kernel_size=3,
|
222 |
+
stride=1,
|
223 |
+
padding=1)
|
224 |
+
|
225 |
+
curr_res = resolution
|
226 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
227 |
+
self.down = nn.ModuleList()
|
228 |
+
for i_level in range(self.num_resolutions):
|
229 |
+
block = nn.ModuleList()
|
230 |
+
attn = nn.ModuleList()
|
231 |
+
block_in = ch*in_ch_mult[i_level]
|
232 |
+
block_out = ch*ch_mult[i_level]
|
233 |
+
for i_block in range(self.num_res_blocks):
|
234 |
+
block.append(ResnetBlock(in_channels=block_in,
|
235 |
+
out_channels=block_out,
|
236 |
+
temb_channels=self.temb_ch,
|
237 |
+
dropout=dropout))
|
238 |
+
block_in = block_out
|
239 |
+
if curr_res in attn_resolutions:
|
240 |
+
attn.append(AttnBlock(block_in))
|
241 |
+
down = nn.Module()
|
242 |
+
down.block = block
|
243 |
+
down.attn = attn
|
244 |
+
if i_level != self.num_resolutions-1:
|
245 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
246 |
+
curr_res = curr_res // 2
|
247 |
+
self.down.append(down)
|
248 |
+
|
249 |
+
# middle
|
250 |
+
self.mid = nn.Module()
|
251 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
252 |
+
out_channels=block_in,
|
253 |
+
temb_channels=self.temb_ch,
|
254 |
+
dropout=dropout)
|
255 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
256 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
257 |
+
out_channels=block_in,
|
258 |
+
temb_channels=self.temb_ch,
|
259 |
+
dropout=dropout)
|
260 |
+
|
261 |
+
# upsampling
|
262 |
+
self.up = nn.ModuleList()
|
263 |
+
for i_level in reversed(range(self.num_resolutions)):
|
264 |
+
block = nn.ModuleList()
|
265 |
+
attn = nn.ModuleList()
|
266 |
+
block_out = ch*ch_mult[i_level]
|
267 |
+
skip_in = ch*ch_mult[i_level]
|
268 |
+
for i_block in range(self.num_res_blocks+1):
|
269 |
+
if i_block == self.num_res_blocks:
|
270 |
+
skip_in = ch*in_ch_mult[i_level]
|
271 |
+
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
272 |
+
out_channels=block_out,
|
273 |
+
temb_channels=self.temb_ch,
|
274 |
+
dropout=dropout))
|
275 |
+
block_in = block_out
|
276 |
+
if curr_res in attn_resolutions:
|
277 |
+
attn.append(AttnBlock(block_in))
|
278 |
+
up = nn.Module()
|
279 |
+
up.block = block
|
280 |
+
up.attn = attn
|
281 |
+
if i_level != 0:
|
282 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
283 |
+
curr_res = curr_res * 2
|
284 |
+
self.up.insert(0, up) # prepend to get consistent order
|
285 |
+
|
286 |
+
# end
|
287 |
+
self.norm_out = Normalize(block_in)
|
288 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
289 |
+
out_ch,
|
290 |
+
kernel_size=3,
|
291 |
+
stride=1,
|
292 |
+
padding=1)
|
293 |
+
|
294 |
+
|
295 |
+
def forward(self, x, t=None):
|
296 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
297 |
+
|
298 |
+
if self.use_timestep:
|
299 |
+
# timestep embedding
|
300 |
+
assert t is not None
|
301 |
+
temb = get_timestep_embedding(t, self.ch)
|
302 |
+
temb = self.temb.dense[0](temb)
|
303 |
+
temb = nonlinearity(temb)
|
304 |
+
temb = self.temb.dense[1](temb)
|
305 |
+
else:
|
306 |
+
temb = None
|
307 |
+
|
308 |
+
# downsampling
|
309 |
+
hs = [self.conv_in(x)]
|
310 |
+
for i_level in range(self.num_resolutions):
|
311 |
+
for i_block in range(self.num_res_blocks):
|
312 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
313 |
+
if len(self.down[i_level].attn) > 0:
|
314 |
+
h = self.down[i_level].attn[i_block](h)
|
315 |
+
hs.append(h)
|
316 |
+
if i_level != self.num_resolutions-1:
|
317 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
318 |
+
|
319 |
+
# middle
|
320 |
+
h = hs[-1]
|
321 |
+
h = self.mid.block_1(h, temb)
|
322 |
+
h = self.mid.attn_1(h)
|
323 |
+
h = self.mid.block_2(h, temb)
|
324 |
+
|
325 |
+
# upsampling
|
326 |
+
for i_level in reversed(range(self.num_resolutions)):
|
327 |
+
for i_block in range(self.num_res_blocks+1):
|
328 |
+
h = self.up[i_level].block[i_block](
|
329 |
+
torch.cat([h, hs.pop()], dim=1), temb)
|
330 |
+
if len(self.up[i_level].attn) > 0:
|
331 |
+
h = self.up[i_level].attn[i_block](h)
|
332 |
+
if i_level != 0:
|
333 |
+
h = self.up[i_level].upsample(h)
|
334 |
+
|
335 |
+
# end
|
336 |
+
h = self.norm_out(h)
|
337 |
+
h = nonlinearity(h)
|
338 |
+
h = self.conv_out(h)
|
339 |
+
return h
|
340 |
+
|
341 |
+
|
342 |
+
class Encoder(nn.Module):
|
343 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
344 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
345 |
+
resolution, z_channels, double_z=True, **ignore_kwargs):
|
346 |
+
super().__init__()
|
347 |
+
self.ch = ch
|
348 |
+
self.temb_ch = 0
|
349 |
+
self.num_resolutions = len(ch_mult)
|
350 |
+
self.num_res_blocks = num_res_blocks
|
351 |
+
self.resolution = resolution
|
352 |
+
self.in_channels = in_channels
|
353 |
+
|
354 |
+
# downsampling
|
355 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
356 |
+
self.ch,
|
357 |
+
kernel_size=3,
|
358 |
+
stride=1,
|
359 |
+
padding=1)
|
360 |
+
|
361 |
+
curr_res = resolution
|
362 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
363 |
+
self.down = nn.ModuleList()
|
364 |
+
for i_level in range(self.num_resolutions):
|
365 |
+
block = nn.ModuleList()
|
366 |
+
attn = nn.ModuleList()
|
367 |
+
block_in = ch*in_ch_mult[i_level]
|
368 |
+
block_out = ch*ch_mult[i_level]
|
369 |
+
for i_block in range(self.num_res_blocks):
|
370 |
+
block.append(ResnetBlock(in_channels=block_in,
|
371 |
+
out_channels=block_out,
|
372 |
+
temb_channels=self.temb_ch,
|
373 |
+
dropout=dropout))
|
374 |
+
block_in = block_out
|
375 |
+
if curr_res in attn_resolutions:
|
376 |
+
attn.append(AttnBlock(block_in))
|
377 |
+
down = nn.Module()
|
378 |
+
down.block = block
|
379 |
+
down.attn = attn
|
380 |
+
if i_level != self.num_resolutions-1:
|
381 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
382 |
+
curr_res = curr_res // 2
|
383 |
+
self.down.append(down)
|
384 |
+
|
385 |
+
# middle
|
386 |
+
self.mid = nn.Module()
|
387 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
388 |
+
out_channels=block_in,
|
389 |
+
temb_channels=self.temb_ch,
|
390 |
+
dropout=dropout)
|
391 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
392 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
393 |
+
out_channels=block_in,
|
394 |
+
temb_channels=self.temb_ch,
|
395 |
+
dropout=dropout)
|
396 |
+
|
397 |
+
# end
|
398 |
+
self.norm_out = Normalize(block_in)
|
399 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
400 |
+
2*z_channels if double_z else z_channels,
|
401 |
+
kernel_size=3,
|
402 |
+
stride=1,
|
403 |
+
padding=1)
|
404 |
+
|
405 |
+
|
406 |
+
def forward(self, x):
|
407 |
+
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
408 |
+
|
409 |
+
# timestep embedding
|
410 |
+
temb = None
|
411 |
+
|
412 |
+
# downsampling
|
413 |
+
hs = [self.conv_in(x)]
|
414 |
+
for i_level in range(self.num_resolutions):
|
415 |
+
for i_block in range(self.num_res_blocks):
|
416 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
417 |
+
if len(self.down[i_level].attn) > 0:
|
418 |
+
h = self.down[i_level].attn[i_block](h)
|
419 |
+
hs.append(h)
|
420 |
+
if i_level != self.num_resolutions-1:
|
421 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
422 |
+
|
423 |
+
# middle
|
424 |
+
h = hs[-1]
|
425 |
+
h = self.mid.block_1(h, temb)
|
426 |
+
h = self.mid.attn_1(h)
|
427 |
+
h = self.mid.block_2(h, temb)
|
428 |
+
|
429 |
+
# end
|
430 |
+
h = self.norm_out(h)
|
431 |
+
h = nonlinearity(h)
|
432 |
+
h = self.conv_out(h)
|
433 |
+
return h
|
434 |
+
|
435 |
+
|
436 |
+
class Decoder(nn.Module):
|
437 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
438 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
439 |
+
resolution, z_channels, give_pre_end=False, **ignorekwargs):
|
440 |
+
super().__init__()
|
441 |
+
self.ch = ch
|
442 |
+
self.temb_ch = 0
|
443 |
+
self.num_resolutions = len(ch_mult)
|
444 |
+
self.num_res_blocks = num_res_blocks
|
445 |
+
self.resolution = resolution
|
446 |
+
self.in_channels = in_channels
|
447 |
+
self.give_pre_end = give_pre_end
|
448 |
+
|
449 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
450 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
451 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
452 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
453 |
+
self.z_shape = (1,z_channels,curr_res,curr_res)
|
454 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
455 |
+
self.z_shape, np.prod(self.z_shape)))
|
456 |
+
|
457 |
+
# z to block_in
|
458 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
459 |
+
block_in,
|
460 |
+
kernel_size=3,
|
461 |
+
stride=1,
|
462 |
+
padding=1)
|
463 |
+
|
464 |
+
# middle
|
465 |
+
self.mid = nn.Module()
|
466 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
467 |
+
out_channels=block_in,
|
468 |
+
temb_channels=self.temb_ch,
|
469 |
+
dropout=dropout)
|
470 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
471 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
472 |
+
out_channels=block_in,
|
473 |
+
temb_channels=self.temb_ch,
|
474 |
+
dropout=dropout)
|
475 |
+
|
476 |
+
# upsampling
|
477 |
+
self.up = nn.ModuleList()
|
478 |
+
for i_level in reversed(range(self.num_resolutions)):
|
479 |
+
block = nn.ModuleList()
|
480 |
+
attn = nn.ModuleList()
|
481 |
+
block_out = ch*ch_mult[i_level]
|
482 |
+
for i_block in range(self.num_res_blocks+1):
|
483 |
+
block.append(ResnetBlock(in_channels=block_in,
|
484 |
+
out_channels=block_out,
|
485 |
+
temb_channels=self.temb_ch,
|
486 |
+
dropout=dropout))
|
487 |
+
block_in = block_out
|
488 |
+
if curr_res in attn_resolutions:
|
489 |
+
attn.append(AttnBlock(block_in))
|
490 |
+
up = nn.Module()
|
491 |
+
up.block = block
|
492 |
+
up.attn = attn
|
493 |
+
if i_level != 0:
|
494 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
495 |
+
curr_res = curr_res * 2
|
496 |
+
self.up.insert(0, up) # prepend to get consistent order
|
497 |
+
|
498 |
+
# end
|
499 |
+
self.norm_out = Normalize(block_in)
|
500 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
501 |
+
out_ch,
|
502 |
+
kernel_size=3,
|
503 |
+
stride=1,
|
504 |
+
padding=1)
|
505 |
+
|
506 |
+
def forward(self, z):
|
507 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
508 |
+
self.last_z_shape = z.shape
|
509 |
+
|
510 |
+
# timestep embedding
|
511 |
+
temb = None
|
512 |
+
|
513 |
+
# z to block_in
|
514 |
+
h = self.conv_in(z)
|
515 |
+
|
516 |
+
# middle
|
517 |
+
h = self.mid.block_1(h, temb)
|
518 |
+
h = self.mid.attn_1(h)
|
519 |
+
h = self.mid.block_2(h, temb)
|
520 |
+
|
521 |
+
# upsampling
|
522 |
+
for i_level in reversed(range(self.num_resolutions)):
|
523 |
+
for i_block in range(self.num_res_blocks+1):
|
524 |
+
h = self.up[i_level].block[i_block](h, temb)
|
525 |
+
if len(self.up[i_level].attn) > 0:
|
526 |
+
h = self.up[i_level].attn[i_block](h)
|
527 |
+
if i_level != 0:
|
528 |
+
h = self.up[i_level].upsample(h)
|
529 |
+
|
530 |
+
# end
|
531 |
+
if self.give_pre_end:
|
532 |
+
return h
|
533 |
+
|
534 |
+
h = self.norm_out(h)
|
535 |
+
h = nonlinearity(h)
|
536 |
+
h = self.conv_out(h)
|
537 |
+
return h
|
538 |
+
|
539 |
+
|
540 |
+
class VUNet(nn.Module):
|
541 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
542 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
543 |
+
in_channels, c_channels,
|
544 |
+
resolution, z_channels, use_timestep=False, **ignore_kwargs):
|
545 |
+
super().__init__()
|
546 |
+
self.ch = ch
|
547 |
+
self.temb_ch = self.ch*4
|
548 |
+
self.num_resolutions = len(ch_mult)
|
549 |
+
self.num_res_blocks = num_res_blocks
|
550 |
+
self.resolution = resolution
|
551 |
+
|
552 |
+
self.use_timestep = use_timestep
|
553 |
+
if self.use_timestep:
|
554 |
+
# timestep embedding
|
555 |
+
self.temb = nn.Module()
|
556 |
+
self.temb.dense = nn.ModuleList([
|
557 |
+
torch.nn.Linear(self.ch,
|
558 |
+
self.temb_ch),
|
559 |
+
torch.nn.Linear(self.temb_ch,
|
560 |
+
self.temb_ch),
|
561 |
+
])
|
562 |
+
|
563 |
+
# downsampling
|
564 |
+
self.conv_in = torch.nn.Conv2d(c_channels,
|
565 |
+
self.ch,
|
566 |
+
kernel_size=3,
|
567 |
+
stride=1,
|
568 |
+
padding=1)
|
569 |
+
|
570 |
+
curr_res = resolution
|
571 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
572 |
+
self.down = nn.ModuleList()
|
573 |
+
for i_level in range(self.num_resolutions):
|
574 |
+
block = nn.ModuleList()
|
575 |
+
attn = nn.ModuleList()
|
576 |
+
block_in = ch*in_ch_mult[i_level]
|
577 |
+
block_out = ch*ch_mult[i_level]
|
578 |
+
for i_block in range(self.num_res_blocks):
|
579 |
+
block.append(ResnetBlock(in_channels=block_in,
|
580 |
+
out_channels=block_out,
|
581 |
+
temb_channels=self.temb_ch,
|
582 |
+
dropout=dropout))
|
583 |
+
block_in = block_out
|
584 |
+
if curr_res in attn_resolutions:
|
585 |
+
attn.append(AttnBlock(block_in))
|
586 |
+
down = nn.Module()
|
587 |
+
down.block = block
|
588 |
+
down.attn = attn
|
589 |
+
if i_level != self.num_resolutions-1:
|
590 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
591 |
+
curr_res = curr_res // 2
|
592 |
+
self.down.append(down)
|
593 |
+
|
594 |
+
self.z_in = torch.nn.Conv2d(z_channels,
|
595 |
+
block_in,
|
596 |
+
kernel_size=1,
|
597 |
+
stride=1,
|
598 |
+
padding=0)
|
599 |
+
# middle
|
600 |
+
self.mid = nn.Module()
|
601 |
+
self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
|
602 |
+
out_channels=block_in,
|
603 |
+
temb_channels=self.temb_ch,
|
604 |
+
dropout=dropout)
|
605 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
606 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
607 |
+
out_channels=block_in,
|
608 |
+
temb_channels=self.temb_ch,
|
609 |
+
dropout=dropout)
|
610 |
+
|
611 |
+
# upsampling
|
612 |
+
self.up = nn.ModuleList()
|
613 |
+
for i_level in reversed(range(self.num_resolutions)):
|
614 |
+
block = nn.ModuleList()
|
615 |
+
attn = nn.ModuleList()
|
616 |
+
block_out = ch*ch_mult[i_level]
|
617 |
+
skip_in = ch*ch_mult[i_level]
|
618 |
+
for i_block in range(self.num_res_blocks+1):
|
619 |
+
if i_block == self.num_res_blocks:
|
620 |
+
skip_in = ch*in_ch_mult[i_level]
|
621 |
+
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
622 |
+
out_channels=block_out,
|
623 |
+
temb_channels=self.temb_ch,
|
624 |
+
dropout=dropout))
|
625 |
+
block_in = block_out
|
626 |
+
if curr_res in attn_resolutions:
|
627 |
+
attn.append(AttnBlock(block_in))
|
628 |
+
up = nn.Module()
|
629 |
+
up.block = block
|
630 |
+
up.attn = attn
|
631 |
+
if i_level != 0:
|
632 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
633 |
+
curr_res = curr_res * 2
|
634 |
+
self.up.insert(0, up) # prepend to get consistent order
|
635 |
+
|
636 |
+
# end
|
637 |
+
self.norm_out = Normalize(block_in)
|
638 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
639 |
+
out_ch,
|
640 |
+
kernel_size=3,
|
641 |
+
stride=1,
|
642 |
+
padding=1)
|
643 |
+
|
644 |
+
|
645 |
+
def forward(self, x, z):
|
646 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
647 |
+
|
648 |
+
if self.use_timestep:
|
649 |
+
# timestep embedding
|
650 |
+
assert t is not None
|
651 |
+
temb = get_timestep_embedding(t, self.ch)
|
652 |
+
temb = self.temb.dense[0](temb)
|
653 |
+
temb = nonlinearity(temb)
|
654 |
+
temb = self.temb.dense[1](temb)
|
655 |
+
else:
|
656 |
+
temb = None
|
657 |
+
|
658 |
+
# downsampling
|
659 |
+
hs = [self.conv_in(x)]
|
660 |
+
for i_level in range(self.num_resolutions):
|
661 |
+
for i_block in range(self.num_res_blocks):
|
662 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
663 |
+
if len(self.down[i_level].attn) > 0:
|
664 |
+
h = self.down[i_level].attn[i_block](h)
|
665 |
+
hs.append(h)
|
666 |
+
if i_level != self.num_resolutions-1:
|
667 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
668 |
+
|
669 |
+
# middle
|
670 |
+
h = hs[-1]
|
671 |
+
z = self.z_in(z)
|
672 |
+
h = torch.cat((h,z),dim=1)
|
673 |
+
h = self.mid.block_1(h, temb)
|
674 |
+
h = self.mid.attn_1(h)
|
675 |
+
h = self.mid.block_2(h, temb)
|
676 |
+
|
677 |
+
# upsampling
|
678 |
+
for i_level in reversed(range(self.num_resolutions)):
|
679 |
+
for i_block in range(self.num_res_blocks+1):
|
680 |
+
h = self.up[i_level].block[i_block](
|
681 |
+
torch.cat([h, hs.pop()], dim=1), temb)
|
682 |
+
if len(self.up[i_level].attn) > 0:
|
683 |
+
h = self.up[i_level].attn[i_block](h)
|
684 |
+
if i_level != 0:
|
685 |
+
h = self.up[i_level].upsample(h)
|
686 |
+
|
687 |
+
# end
|
688 |
+
h = self.norm_out(h)
|
689 |
+
h = nonlinearity(h)
|
690 |
+
h = self.conv_out(h)
|
691 |
+
return h
|
692 |
+
|
693 |
+
|
694 |
+
class SimpleDecoder(nn.Module):
|
695 |
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
696 |
+
super().__init__()
|
697 |
+
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
698 |
+
ResnetBlock(in_channels=in_channels,
|
699 |
+
out_channels=2 * in_channels,
|
700 |
+
temb_channels=0, dropout=0.0),
|
701 |
+
ResnetBlock(in_channels=2 * in_channels,
|
702 |
+
out_channels=4 * in_channels,
|
703 |
+
temb_channels=0, dropout=0.0),
|
704 |
+
ResnetBlock(in_channels=4 * in_channels,
|
705 |
+
out_channels=2 * in_channels,
|
706 |
+
temb_channels=0, dropout=0.0),
|
707 |
+
nn.Conv2d(2*in_channels, in_channels, 1),
|
708 |
+
Upsample(in_channels, with_conv=True)])
|
709 |
+
# end
|
710 |
+
self.norm_out = Normalize(in_channels)
|
711 |
+
self.conv_out = torch.nn.Conv2d(in_channels,
|
712 |
+
out_channels,
|
713 |
+
kernel_size=3,
|
714 |
+
stride=1,
|
715 |
+
padding=1)
|
716 |
+
|
717 |
+
def forward(self, x):
|
718 |
+
for i, layer in enumerate(self.model):
|
719 |
+
if i in [1,2,3]:
|
720 |
+
x = layer(x, None)
|
721 |
+
else:
|
722 |
+
x = layer(x)
|
723 |
+
|
724 |
+
h = self.norm_out(x)
|
725 |
+
h = nonlinearity(h)
|
726 |
+
x = self.conv_out(h)
|
727 |
+
return x
|
728 |
+
|
729 |
+
|
730 |
+
class UpsampleDecoder(nn.Module):
|
731 |
+
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
732 |
+
ch_mult=(2,2), dropout=0.0):
|
733 |
+
super().__init__()
|
734 |
+
# upsampling
|
735 |
+
self.temb_ch = 0
|
736 |
+
self.num_resolutions = len(ch_mult)
|
737 |
+
self.num_res_blocks = num_res_blocks
|
738 |
+
block_in = in_channels
|
739 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
740 |
+
self.res_blocks = nn.ModuleList()
|
741 |
+
self.upsample_blocks = nn.ModuleList()
|
742 |
+
for i_level in range(self.num_resolutions):
|
743 |
+
res_block = []
|
744 |
+
block_out = ch * ch_mult[i_level]
|
745 |
+
for i_block in range(self.num_res_blocks + 1):
|
746 |
+
res_block.append(ResnetBlock(in_channels=block_in,
|
747 |
+
out_channels=block_out,
|
748 |
+
temb_channels=self.temb_ch,
|
749 |
+
dropout=dropout))
|
750 |
+
block_in = block_out
|
751 |
+
self.res_blocks.append(nn.ModuleList(res_block))
|
752 |
+
if i_level != self.num_resolutions - 1:
|
753 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
754 |
+
curr_res = curr_res * 2
|
755 |
+
|
756 |
+
# end
|
757 |
+
self.norm_out = Normalize(block_in)
|
758 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
759 |
+
out_channels,
|
760 |
+
kernel_size=3,
|
761 |
+
stride=1,
|
762 |
+
padding=1)
|
763 |
+
|
764 |
+
def forward(self, x):
|
765 |
+
# upsampling
|
766 |
+
h = x
|
767 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
768 |
+
for i_block in range(self.num_res_blocks + 1):
|
769 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
770 |
+
if i_level != self.num_resolutions - 1:
|
771 |
+
h = self.upsample_blocks[k](h)
|
772 |
+
h = self.norm_out(h)
|
773 |
+
h = nonlinearity(h)
|
774 |
+
h = self.conv_out(h)
|
775 |
+
return h
|
776 |
+
|
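The get_timestep_embedding helper near the top of this file builds the standard sinusoidal embeddings described in its docstring. A quick sanity check of its behaviour (the timestep and dimension values are arbitrary):

import torch
from taming.modules.diffusionmodules.model import get_timestep_embedding

t = torch.arange(4)                             # four timesteps: 0, 1, 2, 3
emb = get_timestep_embedding(t, embedding_dim=8)
print(emb.shape)                                # torch.Size([4, 8]): [sin | cos] halves concatenated
print(emb[0])                                   # timestep 0 -> sines are 0, cosines are 1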
taming/modules/discriminator/model.py
ADDED
@@ -0,0 +1,67 @@
1 |
+
import functools
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
from taming.modules.util import ActNorm
|
6 |
+
|
7 |
+
|
8 |
+
def weights_init(m):
|
9 |
+
classname = m.__class__.__name__
|
10 |
+
if classname.find('Conv') != -1:
|
11 |
+
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
12 |
+
elif classname.find('BatchNorm') != -1:
|
13 |
+
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
14 |
+
nn.init.constant_(m.bias.data, 0)
|
15 |
+
|
16 |
+
|
17 |
+
class NLayerDiscriminator(nn.Module):
|
18 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
19 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
20 |
+
"""
|
21 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
22 |
+
"""Construct a PatchGAN discriminator
|
23 |
+
Parameters:
|
24 |
+
input_nc (int) -- the number of channels in input images
|
25 |
+
ndf (int) -- the number of filters in the last conv layer
|
26 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
27 |
+
norm_layer -- normalization layer
|
28 |
+
"""
|
29 |
+
super(NLayerDiscriminator, self).__init__()
|
30 |
+
if not use_actnorm:
|
31 |
+
norm_layer = nn.BatchNorm2d
|
32 |
+
else:
|
33 |
+
norm_layer = ActNorm
|
34 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
35 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
36 |
+
else:
|
37 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
38 |
+
|
39 |
+
kw = 4
|
40 |
+
padw = 1
|
41 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
42 |
+
nf_mult = 1
|
43 |
+
nf_mult_prev = 1
|
44 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
45 |
+
nf_mult_prev = nf_mult
|
46 |
+
nf_mult = min(2 ** n, 8)
|
47 |
+
sequence += [
|
48 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
49 |
+
norm_layer(ndf * nf_mult),
|
50 |
+
nn.LeakyReLU(0.2, True)
|
51 |
+
]
|
52 |
+
|
53 |
+
nf_mult_prev = nf_mult
|
54 |
+
nf_mult = min(2 ** n_layers, 8)
|
55 |
+
sequence += [
|
56 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
57 |
+
norm_layer(ndf * nf_mult),
|
58 |
+
nn.LeakyReLU(0.2, True)
|
59 |
+
]
|
60 |
+
|
61 |
+
sequence += [
|
62 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
63 |
+
self.main = nn.Sequential(*sequence)
|
64 |
+
|
65 |
+
def forward(self, input):
|
66 |
+
"""Standard forward."""
|
67 |
+
return self.main(input)
|
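
The PatchGAN discriminator above returns a grid of per-patch real/fake logits rather than a single scalar per image. A minimal sketch of how it might be exercised, assuming the module path added in this commit; the tensor sizes are purely illustrative:

import torch
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

# default config: 3-channel input, 3 strided conv blocks, BatchNorm
disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)

images = torch.randn(2, 3, 256, 256)   # stand-in batch of (real or generated) images
logits = disc(images)                  # (2, 1, H', W') patch-wise prediction map
print(logits.shape)
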
taming/modules/losses/__init__.py
ADDED
@@ -0,0 +1,2 @@
from taming.modules.losses.vqperceptual import DummyLoss
taming/modules/losses/lpips.py
ADDED
@@ -0,0 +1,123 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple

from taming.util import get_ckpt_path


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
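
LPIPS scores perceptual similarity by comparing VGG16 feature activations of two images; the constructor loads torchvision's VGG16 weights plus the linear-layer checkpoint shipped in this commit under taming/modules/autoencoder/lpips. A minimal sketch, assuming those weights are available and inputs scaled roughly to [-1, 1]:

import torch
from taming.modules.losses.lpips import LPIPS

perceptual = LPIPS().eval()              # frozen, pretrained perceptual metric

x = torch.rand(1, 3, 256, 256) * 2 - 1   # illustrative image pair in [-1, 1]
y = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = perceptual(x, y)              # shape (1, 1, 1, 1); larger = more different
print(dist.view(-1))
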
taming/modules/losses/segmentation.py
ADDED
@@ -0,0 +1,22 @@
import torch.nn as nn
import torch.nn.functional as F


class BCELoss(nn.Module):
    def forward(self, prediction, target):
        loss = F.binary_cross_entropy_with_logits(prediction, target)
        return loss, {}


class BCELossWithQuant(nn.Module):
    def __init__(self, codebook_weight=1.):
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = bce_loss + self.codebook_weight * qloss
        return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
                      "{}/bce_loss".format(split): bce_loss.detach().mean(),
                      "{}/quant_loss".format(split): qloss.detach().mean()
                      }
taming/modules/losses/vqperceptual.py
ADDED
@@ -0,0 +1,182 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from taming.modules.losses.lpips import LPIPS
|
6 |
+
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
7 |
+
|
8 |
+
|
9 |
+
class DummyLoss(nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
|
14 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
15 |
+
if global_step < threshold:
|
16 |
+
weight = value
|
17 |
+
return weight
|
18 |
+
|
19 |
+
|
20 |
+
def hinge_d_loss(logits_real, logits_fake):
|
21 |
+
loss_real = torch.mean(F.relu(1.0 - logits_real))
|
22 |
+
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
|
23 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
24 |
+
return d_loss
|
25 |
+
|
26 |
+
|
27 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
28 |
+
d_loss = 0.5 * (
|
29 |
+
torch.mean(torch.nn.functional.softplus(-logits_real))
|
30 |
+
+ torch.mean(torch.nn.functional.softplus(logits_fake))
|
31 |
+
)
|
32 |
+
return d_loss
|
33 |
+
|
34 |
+
|
35 |
+
class VQLPIPSWithDiscriminator(nn.Module):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
disc_start,
|
39 |
+
codebook_weight=1.0,
|
40 |
+
pixelloss_weight=1.0,
|
41 |
+
disc_num_layers=3,
|
42 |
+
disc_in_channels=3,
|
43 |
+
disc_factor=1.0,
|
44 |
+
disc_weight=1.0,
|
45 |
+
perceptual_weight=1.0,
|
46 |
+
use_actnorm=False,
|
47 |
+
disc_conditional=False,
|
48 |
+
disc_ndf=64,
|
49 |
+
disc_loss="hinge",
|
50 |
+
):
|
51 |
+
super().__init__()
|
52 |
+
assert disc_loss in ["hinge", "vanilla"]
|
53 |
+
self.codebook_weight = codebook_weight
|
54 |
+
self.pixel_weight = pixelloss_weight
|
55 |
+
self.perceptual_loss = LPIPS().eval()
|
56 |
+
self.perceptual_weight = perceptual_weight
|
57 |
+
|
58 |
+
self.discriminator = NLayerDiscriminator(
|
59 |
+
input_nc=disc_in_channels,
|
60 |
+
n_layers=disc_num_layers,
|
61 |
+
use_actnorm=use_actnorm,
|
62 |
+
ndf=disc_ndf,
|
63 |
+
).apply(weights_init)
|
64 |
+
self.discriminator_iter_start = disc_start
|
65 |
+
if disc_loss == "hinge":
|
66 |
+
self.disc_loss = hinge_d_loss
|
67 |
+
elif disc_loss == "vanilla":
|
68 |
+
self.disc_loss = vanilla_d_loss
|
69 |
+
else:
|
70 |
+
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
|
71 |
+
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
|
72 |
+
self.disc_factor = disc_factor
|
73 |
+
self.discriminator_weight = disc_weight
|
74 |
+
self.disc_conditional = disc_conditional
|
75 |
+
|
76 |
+
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
77 |
+
if last_layer is not None:
|
78 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
79 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
80 |
+
else:
|
81 |
+
nll_grads = torch.autograd.grad(
|
82 |
+
nll_loss, self.last_layer[0], retain_graph=True
|
83 |
+
)[0]
|
84 |
+
g_grads = torch.autograd.grad(
|
85 |
+
g_loss, self.last_layer[0], retain_graph=True
|
86 |
+
)[0]
|
87 |
+
|
88 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
89 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
90 |
+
d_weight = d_weight * self.discriminator_weight
|
91 |
+
return d_weight
|
92 |
+
|
93 |
+
def forward(
|
94 |
+
self,
|
95 |
+
codebook_loss,
|
96 |
+
inputs,
|
97 |
+
reconstructions,
|
98 |
+
optimizer_idx,
|
99 |
+
global_step,
|
100 |
+
last_layer=None,
|
101 |
+
cond=None,
|
102 |
+
split="train",
|
103 |
+
):
|
104 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
105 |
+
if self.perceptual_weight > 0:
|
106 |
+
p_loss = self.perceptual_loss(
|
107 |
+
inputs.contiguous(), reconstructions.contiguous()
|
108 |
+
)
|
109 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
110 |
+
else:
|
111 |
+
p_loss = torch.tensor([0.0])
|
112 |
+
|
113 |
+
nll_loss = rec_loss
|
114 |
+
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
115 |
+
nll_loss = torch.mean(nll_loss)
|
116 |
+
|
117 |
+
# now the GAN part
|
118 |
+
if optimizer_idx == 0:
|
119 |
+
# generator update
|
120 |
+
if cond is None:
|
121 |
+
assert not self.disc_conditional
|
122 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
123 |
+
else:
|
124 |
+
assert self.disc_conditional
|
125 |
+
logits_fake = self.discriminator(
|
126 |
+
torch.cat((reconstructions.contiguous(), cond), dim=1)
|
127 |
+
)
|
128 |
+
g_loss = -torch.mean(logits_fake)
|
129 |
+
|
130 |
+
try:
|
131 |
+
d_weight = self.calculate_adaptive_weight(
|
132 |
+
nll_loss, g_loss, last_layer=last_layer
|
133 |
+
)
|
134 |
+
except RuntimeError:
|
135 |
+
assert not self.training
|
136 |
+
d_weight = torch.tensor(0.0)
|
137 |
+
|
138 |
+
disc_factor = adopt_weight(
|
139 |
+
self.disc_factor, global_step, threshold=self.discriminator_iter_start
|
140 |
+
)
|
141 |
+
loss = (
|
142 |
+
nll_loss
|
143 |
+
+ d_weight * disc_factor * g_loss
|
144 |
+
+ self.codebook_weight * codebook_loss.mean()
|
145 |
+
)
|
146 |
+
|
147 |
+
log = {
|
148 |
+
"{}/total_loss".format(split): loss.clone().detach().mean(),
|
149 |
+
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
|
150 |
+
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
151 |
+
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
152 |
+
"{}/p_loss".format(split): p_loss.detach().mean(),
|
153 |
+
"{}/d_weight".format(split): d_weight.detach(),
|
154 |
+
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
155 |
+
"{}/g_loss".format(split): g_loss.detach().mean(),
|
156 |
+
}
|
157 |
+
return loss, log
|
158 |
+
|
159 |
+
if optimizer_idx == 1:
|
160 |
+
# second pass for discriminator update
|
161 |
+
if cond is None:
|
162 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
163 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
164 |
+
else:
|
165 |
+
logits_real = self.discriminator(
|
166 |
+
torch.cat((inputs.contiguous().detach(), cond), dim=1)
|
167 |
+
)
|
168 |
+
logits_fake = self.discriminator(
|
169 |
+
torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
|
170 |
+
)
|
171 |
+
|
172 |
+
disc_factor = adopt_weight(
|
173 |
+
self.disc_factor, global_step, threshold=self.discriminator_iter_start
|
174 |
+
)
|
175 |
+
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
176 |
+
|
177 |
+
log = {
|
178 |
+
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
179 |
+
"{}/logits_real".format(split): logits_real.detach().mean(),
|
180 |
+
"{}/logits_fake".format(split): logits_fake.detach().mean(),
|
181 |
+
}
|
182 |
+
return d_loss, log
|
taming/modules/misc/coord.py
ADDED
@@ -0,0 +1,31 @@
import torch

class CoordStage(object):
    def __init__(self, n_embed, down_factor):
        self.n_embed = n_embed
        self.down_factor = down_factor

    def eval(self):
        return self

    def encode(self, c):
        """fake vqmodel interface"""
        assert 0.0 <= c.min() and c.max() <= 1.0
        b, ch, h, w = c.shape
        assert ch == 1

        c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
                                            mode="area")
        c = c.clamp(0.0, 1.0)
        c = self.n_embed * c
        c_quant = c.round()
        c_ind = c_quant.to(dtype=torch.long)

        info = None, None, c_ind
        return c_quant, None, info

    def decode(self, c):
        c = c / self.n_embed
        c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
                                            mode="nearest")
        return c
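
CoordStage mimics the encode/decode interface of a VQ model for a coordinate channel, using plain down/upsampling and rounding instead of a learned codebook. A small round-trip sketch with made-up sizes:

import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)

coords = torch.rand(1, 1, 256, 256)               # single-channel map, values in [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(coords)  # (1, 1, 16, 16) quantized grid + long indices
approx = stage.decode(c_quant)                    # back to (1, 1, 256, 256), piecewise constant
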
taming/modules/transformer/mingpt.py
ADDED
@@ -0,0 +1,415 @@
1 |
+
"""
|
2 |
+
taken from: https://github.com/karpathy/minGPT/
|
3 |
+
GPT model:
|
4 |
+
- the initial stem consists of a combination of token encoding and a positional encoding
|
5 |
+
- the meat of it is a uniform sequence of Transformer blocks
|
6 |
+
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
|
7 |
+
- all blocks feed into a central residual pathway similar to resnets
|
8 |
+
- the final decoder is a linear projection into a vanilla Softmax classifier
|
9 |
+
"""
|
10 |
+
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.nn import functional as F
|
17 |
+
from transformers import top_k_top_p_filtering
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class GPTConfig:
|
23 |
+
""" base GPT config, params common to all GPT versions """
|
24 |
+
embd_pdrop = 0.1
|
25 |
+
resid_pdrop = 0.1
|
26 |
+
attn_pdrop = 0.1
|
27 |
+
|
28 |
+
def __init__(self, vocab_size, block_size, **kwargs):
|
29 |
+
self.vocab_size = vocab_size
|
30 |
+
self.block_size = block_size
|
31 |
+
for k,v in kwargs.items():
|
32 |
+
setattr(self, k, v)
|
33 |
+
|
34 |
+
|
35 |
+
class GPT1Config(GPTConfig):
|
36 |
+
""" GPT-1 like network roughly 125M params """
|
37 |
+
n_layer = 12
|
38 |
+
n_head = 12
|
39 |
+
n_embd = 768
|
40 |
+
|
41 |
+
|
42 |
+
class CausalSelfAttention(nn.Module):
|
43 |
+
"""
|
44 |
+
A vanilla multi-head masked self-attention layer with a projection at the end.
|
45 |
+
It is possible to use torch.nn.MultiheadAttention here but I am including an
|
46 |
+
explicit implementation here to show that there is nothing too scary here.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, config):
|
50 |
+
super().__init__()
|
51 |
+
assert config.n_embd % config.n_head == 0
|
52 |
+
# key, query, value projections for all heads
|
53 |
+
self.key = nn.Linear(config.n_embd, config.n_embd)
|
54 |
+
self.query = nn.Linear(config.n_embd, config.n_embd)
|
55 |
+
self.value = nn.Linear(config.n_embd, config.n_embd)
|
56 |
+
# regularization
|
57 |
+
self.attn_drop = nn.Dropout(config.attn_pdrop)
|
58 |
+
self.resid_drop = nn.Dropout(config.resid_pdrop)
|
59 |
+
# output projection
|
60 |
+
self.proj = nn.Linear(config.n_embd, config.n_embd)
|
61 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
62 |
+
mask = torch.tril(torch.ones(config.block_size,
|
63 |
+
config.block_size))
|
64 |
+
if hasattr(config, "n_unmasked"):
|
65 |
+
mask[:config.n_unmasked, :config.n_unmasked] = 1
|
66 |
+
self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
|
67 |
+
self.n_head = config.n_head
|
68 |
+
|
69 |
+
def forward(self, x, layer_past=None):
|
70 |
+
B, T, C = x.size()
|
71 |
+
|
72 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
73 |
+
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
74 |
+
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
75 |
+
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
76 |
+
|
77 |
+
present = torch.stack((k, v))
|
78 |
+
if layer_past is not None:
|
79 |
+
past_key, past_value = layer_past
|
80 |
+
k = torch.cat((past_key, k), dim=-2)
|
81 |
+
v = torch.cat((past_value, v), dim=-2)
|
82 |
+
|
83 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
84 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
85 |
+
if layer_past is None:
|
86 |
+
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
|
87 |
+
|
88 |
+
att = F.softmax(att, dim=-1)
|
89 |
+
att = self.attn_drop(att)
|
90 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
91 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
92 |
+
|
93 |
+
# output projection
|
94 |
+
y = self.resid_drop(self.proj(y))
|
95 |
+
return y, present # TODO: check that this does not break anything
|
96 |
+
|
97 |
+
|
98 |
+
class Block(nn.Module):
|
99 |
+
""" an unassuming Transformer block """
|
100 |
+
def __init__(self, config):
|
101 |
+
super().__init__()
|
102 |
+
self.ln1 = nn.LayerNorm(config.n_embd)
|
103 |
+
self.ln2 = nn.LayerNorm(config.n_embd)
|
104 |
+
self.attn = CausalSelfAttention(config)
|
105 |
+
self.mlp = nn.Sequential(
|
106 |
+
nn.Linear(config.n_embd, 4 * config.n_embd),
|
107 |
+
nn.GELU(), # nice
|
108 |
+
nn.Linear(4 * config.n_embd, config.n_embd),
|
109 |
+
nn.Dropout(config.resid_pdrop),
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(self, x, layer_past=None, return_present=False):
|
113 |
+
# TODO: check that training still works
|
114 |
+
if return_present: assert not self.training
|
115 |
+
# layer past: tuple of length two with B, nh, T, hs
|
116 |
+
attn, present = self.attn(self.ln1(x), layer_past=layer_past)
|
117 |
+
|
118 |
+
x = x + attn
|
119 |
+
x = x + self.mlp(self.ln2(x))
|
120 |
+
if layer_past is not None or return_present:
|
121 |
+
return x, present
|
122 |
+
return x
|
123 |
+
|
124 |
+
|
125 |
+
class GPT(nn.Module):
|
126 |
+
""" the full GPT language model, with a context size of block_size """
|
127 |
+
def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
|
128 |
+
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
|
129 |
+
super().__init__()
|
130 |
+
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
|
131 |
+
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
|
132 |
+
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
|
133 |
+
n_unmasked=n_unmasked)
|
134 |
+
# input embedding stem
|
135 |
+
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
|
136 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
|
137 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
138 |
+
# transformer
|
139 |
+
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
|
140 |
+
# decoder head
|
141 |
+
self.ln_f = nn.LayerNorm(config.n_embd)
|
142 |
+
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
143 |
+
self.block_size = config.block_size
|
144 |
+
self.apply(self._init_weights)
|
145 |
+
self.config = config
|
146 |
+
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
147 |
+
|
148 |
+
def get_block_size(self):
|
149 |
+
return self.block_size
|
150 |
+
|
151 |
+
def _init_weights(self, module):
|
152 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
153 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
154 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
155 |
+
module.bias.data.zero_()
|
156 |
+
elif isinstance(module, nn.LayerNorm):
|
157 |
+
module.bias.data.zero_()
|
158 |
+
module.weight.data.fill_(1.0)
|
159 |
+
|
160 |
+
def forward(self, idx, embeddings=None, targets=None):
|
161 |
+
# forward the GPT model
|
162 |
+
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
163 |
+
|
164 |
+
if embeddings is not None: # prepend explicit embeddings
|
165 |
+
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
166 |
+
|
167 |
+
t = token_embeddings.shape[1]
|
168 |
+
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
169 |
+
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
|
170 |
+
x = self.drop(token_embeddings + position_embeddings)
|
171 |
+
x = self.blocks(x)
|
172 |
+
x = self.ln_f(x)
|
173 |
+
logits = self.head(x)
|
174 |
+
|
175 |
+
# if we are given some desired targets also calculate the loss
|
176 |
+
loss = None
|
177 |
+
if targets is not None:
|
178 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
179 |
+
|
180 |
+
return logits, loss
|
181 |
+
|
182 |
+
def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
|
183 |
+
# inference only
|
184 |
+
assert not self.training
|
185 |
+
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
186 |
+
if embeddings is not None: # prepend explicit embeddings
|
187 |
+
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
188 |
+
|
189 |
+
if past is not None:
|
190 |
+
assert past_length is not None
|
191 |
+
past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
|
192 |
+
past_shape = list(past.shape)
|
193 |
+
expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
|
194 |
+
assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
|
195 |
+
position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
|
196 |
+
else:
|
197 |
+
position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
|
198 |
+
|
199 |
+
x = self.drop(token_embeddings + position_embeddings)
|
200 |
+
presents = [] # accumulate over layers
|
201 |
+
for i, block in enumerate(self.blocks):
|
202 |
+
x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
|
203 |
+
presents.append(present)
|
204 |
+
|
205 |
+
x = self.ln_f(x)
|
206 |
+
logits = self.head(x)
|
207 |
+
# if we are given some desired targets also calculate the loss
|
208 |
+
loss = None
|
209 |
+
if targets is not None:
|
210 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
211 |
+
|
212 |
+
return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
|
213 |
+
|
214 |
+
|
215 |
+
class DummyGPT(nn.Module):
|
216 |
+
# for debugging
|
217 |
+
def __init__(self, add_value=1):
|
218 |
+
super().__init__()
|
219 |
+
self.add_value = add_value
|
220 |
+
|
221 |
+
def forward(self, idx):
|
222 |
+
return idx + self.add_value, None
|
223 |
+
|
224 |
+
|
225 |
+
class CodeGPT(nn.Module):
|
226 |
+
"""Takes in semi-embeddings"""
|
227 |
+
def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
|
228 |
+
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
|
229 |
+
super().__init__()
|
230 |
+
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
|
231 |
+
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
|
232 |
+
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
|
233 |
+
n_unmasked=n_unmasked)
|
234 |
+
# input embedding stem
|
235 |
+
self.tok_emb = nn.Linear(in_channels, config.n_embd)
|
236 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
|
237 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
238 |
+
# transformer
|
239 |
+
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
|
240 |
+
# decoder head
|
241 |
+
self.ln_f = nn.LayerNorm(config.n_embd)
|
242 |
+
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
243 |
+
self.block_size = config.block_size
|
244 |
+
self.apply(self._init_weights)
|
245 |
+
self.config = config
|
246 |
+
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
247 |
+
|
248 |
+
def get_block_size(self):
|
249 |
+
return self.block_size
|
250 |
+
|
251 |
+
def _init_weights(self, module):
|
252 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
253 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
254 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
255 |
+
module.bias.data.zero_()
|
256 |
+
elif isinstance(module, nn.LayerNorm):
|
257 |
+
module.bias.data.zero_()
|
258 |
+
module.weight.data.fill_(1.0)
|
259 |
+
|
260 |
+
def forward(self, idx, embeddings=None, targets=None):
|
261 |
+
# forward the GPT model
|
262 |
+
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
263 |
+
|
264 |
+
if embeddings is not None: # prepend explicit embeddings
|
265 |
+
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
266 |
+
|
267 |
+
t = token_embeddings.shape[1]
|
268 |
+
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
269 |
+
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
|
270 |
+
x = self.drop(token_embeddings + position_embeddings)
|
271 |
+
x = self.blocks(x)
|
272 |
+
x = self.ln_f(x)
|
273 |
+
logits = self.head(x)
|
274 |
+
|
275 |
+
# if we are given some desired targets also calculate the loss
|
276 |
+
loss = None
|
277 |
+
if targets is not None:
|
278 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
279 |
+
|
280 |
+
return logits, loss
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
#### sampling utils
|
285 |
+
|
286 |
+
def top_k_logits(logits, k):
|
287 |
+
v, ix = torch.topk(logits, k)
|
288 |
+
out = logits.clone()
|
289 |
+
out[out < v[:, [-1]]] = -float('Inf')
|
290 |
+
return out
|
291 |
+
|
292 |
+
@torch.no_grad()
|
293 |
+
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
|
294 |
+
"""
|
295 |
+
take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
|
296 |
+
the sequence, feeding the predictions back into the model each time. Clearly the sampling
|
297 |
+
has quadratic complexity unlike an RNN that is only linear, and has a finite context window
|
298 |
+
of block_size, unlike an RNN that has an infinite context window.
|
299 |
+
"""
|
300 |
+
block_size = model.get_block_size()
|
301 |
+
model.eval()
|
302 |
+
for k in range(steps):
|
303 |
+
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
304 |
+
logits, _ = model(x_cond)
|
305 |
+
# pluck the logits at the final step and scale by temperature
|
306 |
+
logits = logits[:, -1, :] / temperature
|
307 |
+
# optionally crop probabilities to only the top k options
|
308 |
+
if top_k is not None:
|
309 |
+
logits = top_k_logits(logits, top_k)
|
310 |
+
# apply softmax to convert to probabilities
|
311 |
+
probs = F.softmax(logits, dim=-1)
|
312 |
+
# sample from the distribution or take the most likely
|
313 |
+
if sample:
|
314 |
+
ix = torch.multinomial(probs, num_samples=1)
|
315 |
+
else:
|
316 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
317 |
+
# append to the sequence and continue
|
318 |
+
x = torch.cat((x, ix), dim=1)
|
319 |
+
|
320 |
+
return x
|
321 |
+
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
|
325 |
+
top_k=None, top_p=None, callback=None):
|
326 |
+
# x is conditioning
|
327 |
+
sample = x
|
328 |
+
cond_len = x.shape[1]
|
329 |
+
past = None
|
330 |
+
for n in range(steps):
|
331 |
+
if callback is not None:
|
332 |
+
callback(n)
|
333 |
+
logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
|
334 |
+
if past is None:
|
335 |
+
past = [present]
|
336 |
+
else:
|
337 |
+
past.append(present)
|
338 |
+
logits = logits[:, -1, :] / temperature
|
339 |
+
if top_k is not None:
|
340 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
341 |
+
|
342 |
+
probs = F.softmax(logits, dim=-1)
|
343 |
+
if not sample_logits:
|
344 |
+
_, x = torch.topk(probs, k=1, dim=-1)
|
345 |
+
else:
|
346 |
+
x = torch.multinomial(probs, num_samples=1)
|
347 |
+
# append to the sequence and continue
|
348 |
+
sample = torch.cat((sample, x), dim=1)
|
349 |
+
del past
|
350 |
+
sample = sample[:, cond_len:] # cut conditioning off
|
351 |
+
return sample
|
352 |
+
|
353 |
+
|
354 |
+
#### clustering utils
|
355 |
+
|
356 |
+
class KMeans(nn.Module):
|
357 |
+
def __init__(self, ncluster=512, nc=3, niter=10):
|
358 |
+
super().__init__()
|
359 |
+
self.ncluster = ncluster
|
360 |
+
self.nc = nc
|
361 |
+
self.niter = niter
|
362 |
+
self.shape = (3,32,32)
|
363 |
+
self.register_buffer("C", torch.zeros(self.ncluster,nc))
|
364 |
+
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
365 |
+
|
366 |
+
def is_initialized(self):
|
367 |
+
return self.initialized.item() == 1
|
368 |
+
|
369 |
+
@torch.no_grad()
|
370 |
+
def initialize(self, x):
|
371 |
+
N, D = x.shape
|
372 |
+
assert D == self.nc, D
|
373 |
+
c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
|
374 |
+
for i in range(self.niter):
|
375 |
+
# assign all pixels to the closest codebook element
|
376 |
+
a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
|
377 |
+
# move each codebook element to be the mean of the pixels that assigned to it
|
378 |
+
c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
|
379 |
+
# re-assign any poorly positioned codebook elements
|
380 |
+
nanix = torch.any(torch.isnan(c), dim=1)
|
381 |
+
ndead = nanix.sum().item()
|
382 |
+
print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
|
383 |
+
c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
|
384 |
+
|
385 |
+
self.C.copy_(c)
|
386 |
+
self.initialized.fill_(1)
|
387 |
+
|
388 |
+
|
389 |
+
def forward(self, x, reverse=False, shape=None):
|
390 |
+
if not reverse:
|
391 |
+
# flatten
|
392 |
+
bs,c,h,w = x.shape
|
393 |
+
assert c == self.nc
|
394 |
+
x = x.reshape(bs,c,h*w,1)
|
395 |
+
C = self.C.permute(1,0)
|
396 |
+
C = C.reshape(1,c,1,self.ncluster)
|
397 |
+
a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
|
398 |
+
return a
|
399 |
+
else:
|
400 |
+
# flatten
|
401 |
+
bs, HW = x.shape
|
402 |
+
"""
|
403 |
+
c = self.C.reshape( 1, self.nc, 1, self.ncluster)
|
404 |
+
c = c[bs*[0],:,:,:]
|
405 |
+
c = c[:,:,HW*[0],:]
|
406 |
+
x = x.reshape(bs, 1, HW, 1)
|
407 |
+
x = x[:,3*[0],:,:]
|
408 |
+
x = torch.gather(c, dim=3, index=x)
|
409 |
+
"""
|
410 |
+
x = self.C[x]
|
411 |
+
x = x.permute(0,2,1)
|
412 |
+
shape = shape if shape is not None else self.shape
|
413 |
+
x = x.reshape(bs, *shape)
|
414 |
+
|
415 |
+
return x
|
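
The sample helper above implements the usual autoregressive loop: crop the context to block_size, take the logits at the last position, optionally top-k filter, then sample or take the argmax. A toy sketch of driving it with a freshly initialised GPT (vocabulary and sizes are made up for illustration; the module also imports top_k_top_p_filtering from transformers at import time, so a version that still provides it is assumed):

import torch
from taming.modules.transformer.mingpt import GPT, sample

model = GPT(vocab_size=512, block_size=64, n_layer=2, n_head=4, n_embd=128)

context = torch.randint(0, 512, (1, 4))    # (batch, conditioning tokens)
tokens = sample(model, context, steps=16, temperature=1.0, sample=True, top_k=50)
print(tokens.shape)                         # (1, 4 + 16): conditioning plus generated tokens
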
taming/modules/transformer/permuter.py
ADDED
@@ -0,0 +1,248 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
class AbstractPermuter(nn.Module):
|
7 |
+
def __init__(self, *args, **kwargs):
|
8 |
+
super().__init__()
|
9 |
+
def forward(self, x, reverse=False):
|
10 |
+
raise NotImplementedError
|
11 |
+
|
12 |
+
|
13 |
+
class Identity(AbstractPermuter):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
def forward(self, x, reverse=False):
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class Subsample(AbstractPermuter):
|
22 |
+
def __init__(self, H, W):
|
23 |
+
super().__init__()
|
24 |
+
C = 1
|
25 |
+
indices = np.arange(H*W).reshape(C,H,W)
|
26 |
+
while min(H, W) > 1:
|
27 |
+
indices = indices.reshape(C,H//2,2,W//2,2)
|
28 |
+
indices = indices.transpose(0,2,4,1,3)
|
29 |
+
indices = indices.reshape(C*4,H//2, W//2)
|
30 |
+
H = H//2
|
31 |
+
W = W//2
|
32 |
+
C = C*4
|
33 |
+
assert H == W == 1
|
34 |
+
idx = torch.tensor(indices.ravel())
|
35 |
+
self.register_buffer('forward_shuffle_idx',
|
36 |
+
nn.Parameter(idx, requires_grad=False))
|
37 |
+
self.register_buffer('backward_shuffle_idx',
|
38 |
+
nn.Parameter(torch.argsort(idx), requires_grad=False))
|
39 |
+
|
40 |
+
def forward(self, x, reverse=False):
|
41 |
+
if not reverse:
|
42 |
+
return x[:, self.forward_shuffle_idx]
|
43 |
+
else:
|
44 |
+
return x[:, self.backward_shuffle_idx]
|
45 |
+
|
46 |
+
|
47 |
+
def mortonify(i, j):
|
48 |
+
"""(i,j) index to linear morton code"""
|
49 |
+
i = np.uint64(i)
|
50 |
+
j = np.uint64(j)
|
51 |
+
|
52 |
+
z = np.uint(0)
|
53 |
+
|
54 |
+
for pos in range(32):
|
55 |
+
z = (z |
|
56 |
+
((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
|
57 |
+
((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
|
58 |
+
)
|
59 |
+
return z
|
60 |
+
|
61 |
+
|
62 |
+
class ZCurve(AbstractPermuter):
|
63 |
+
def __init__(self, H, W):
|
64 |
+
super().__init__()
|
65 |
+
reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
|
66 |
+
idx = np.argsort(reverseidx)
|
67 |
+
idx = torch.tensor(idx)
|
68 |
+
reverseidx = torch.tensor(reverseidx)
|
69 |
+
self.register_buffer('forward_shuffle_idx',
|
70 |
+
idx)
|
71 |
+
self.register_buffer('backward_shuffle_idx',
|
72 |
+
reverseidx)
|
73 |
+
|
74 |
+
def forward(self, x, reverse=False):
|
75 |
+
if not reverse:
|
76 |
+
return x[:, self.forward_shuffle_idx]
|
77 |
+
else:
|
78 |
+
return x[:, self.backward_shuffle_idx]
|
79 |
+
|
80 |
+
|
81 |
+
class SpiralOut(AbstractPermuter):
|
82 |
+
def __init__(self, H, W):
|
83 |
+
super().__init__()
|
84 |
+
assert H == W
|
85 |
+
size = W
|
86 |
+
indices = np.arange(size*size).reshape(size,size)
|
87 |
+
|
88 |
+
i0 = size//2
|
89 |
+
j0 = size//2-1
|
90 |
+
|
91 |
+
i = i0
|
92 |
+
j = j0
|
93 |
+
|
94 |
+
idx = [indices[i0, j0]]
|
95 |
+
step_mult = 0
|
96 |
+
for c in range(1, size//2+1):
|
97 |
+
step_mult += 1
|
98 |
+
# steps left
|
99 |
+
for k in range(step_mult):
|
100 |
+
i = i - 1
|
101 |
+
j = j
|
102 |
+
idx.append(indices[i, j])
|
103 |
+
|
104 |
+
# step down
|
105 |
+
for k in range(step_mult):
|
106 |
+
i = i
|
107 |
+
j = j + 1
|
108 |
+
idx.append(indices[i, j])
|
109 |
+
|
110 |
+
step_mult += 1
|
111 |
+
if c < size//2:
|
112 |
+
# step right
|
113 |
+
for k in range(step_mult):
|
114 |
+
i = i + 1
|
115 |
+
j = j
|
116 |
+
idx.append(indices[i, j])
|
117 |
+
|
118 |
+
# step up
|
119 |
+
for k in range(step_mult):
|
120 |
+
i = i
|
121 |
+
j = j - 1
|
122 |
+
idx.append(indices[i, j])
|
123 |
+
else:
|
124 |
+
# end reached
|
125 |
+
for k in range(step_mult-1):
|
126 |
+
i = i + 1
|
127 |
+
idx.append(indices[i, j])
|
128 |
+
|
129 |
+
assert len(idx) == size*size
|
130 |
+
idx = torch.tensor(idx)
|
131 |
+
self.register_buffer('forward_shuffle_idx', idx)
|
132 |
+
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
133 |
+
|
134 |
+
def forward(self, x, reverse=False):
|
135 |
+
if not reverse:
|
136 |
+
return x[:, self.forward_shuffle_idx]
|
137 |
+
else:
|
138 |
+
return x[:, self.backward_shuffle_idx]
|
139 |
+
|
140 |
+
|
141 |
+
class SpiralIn(AbstractPermuter):
|
142 |
+
def __init__(self, H, W):
|
143 |
+
super().__init__()
|
144 |
+
assert H == W
|
145 |
+
size = W
|
146 |
+
indices = np.arange(size*size).reshape(size,size)
|
147 |
+
|
148 |
+
i0 = size//2
|
149 |
+
j0 = size//2-1
|
150 |
+
|
151 |
+
i = i0
|
152 |
+
j = j0
|
153 |
+
|
154 |
+
idx = [indices[i0, j0]]
|
155 |
+
step_mult = 0
|
156 |
+
for c in range(1, size//2+1):
|
157 |
+
step_mult += 1
|
158 |
+
# steps left
|
159 |
+
for k in range(step_mult):
|
160 |
+
i = i - 1
|
161 |
+
j = j
|
162 |
+
idx.append(indices[i, j])
|
163 |
+
|
164 |
+
# step down
|
165 |
+
for k in range(step_mult):
|
166 |
+
i = i
|
167 |
+
j = j + 1
|
168 |
+
idx.append(indices[i, j])
|
169 |
+
|
170 |
+
step_mult += 1
|
171 |
+
if c < size//2:
|
172 |
+
# step right
|
173 |
+
for k in range(step_mult):
|
174 |
+
i = i + 1
|
175 |
+
j = j
|
176 |
+
idx.append(indices[i, j])
|
177 |
+
|
178 |
+
# step up
|
179 |
+
for k in range(step_mult):
|
180 |
+
i = i
|
181 |
+
j = j - 1
|
182 |
+
idx.append(indices[i, j])
|
183 |
+
else:
|
184 |
+
# end reached
|
185 |
+
for k in range(step_mult-1):
|
186 |
+
i = i + 1
|
187 |
+
idx.append(indices[i, j])
|
188 |
+
|
189 |
+
assert len(idx) == size*size
|
190 |
+
idx = idx[::-1]
|
191 |
+
idx = torch.tensor(idx)
|
192 |
+
self.register_buffer('forward_shuffle_idx', idx)
|
193 |
+
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
194 |
+
|
195 |
+
def forward(self, x, reverse=False):
|
196 |
+
if not reverse:
|
197 |
+
return x[:, self.forward_shuffle_idx]
|
198 |
+
else:
|
199 |
+
return x[:, self.backward_shuffle_idx]
|
200 |
+
|
201 |
+
|
202 |
+
class Random(nn.Module):
|
203 |
+
def __init__(self, H, W):
|
204 |
+
super().__init__()
|
205 |
+
indices = np.random.RandomState(1).permutation(H*W)
|
206 |
+
idx = torch.tensor(indices.ravel())
|
207 |
+
self.register_buffer('forward_shuffle_idx', idx)
|
208 |
+
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
209 |
+
|
210 |
+
def forward(self, x, reverse=False):
|
211 |
+
if not reverse:
|
212 |
+
return x[:, self.forward_shuffle_idx]
|
213 |
+
else:
|
214 |
+
return x[:, self.backward_shuffle_idx]
|
215 |
+
|
216 |
+
|
217 |
+
class AlternateParsing(AbstractPermuter):
|
218 |
+
def __init__(self, H, W):
|
219 |
+
super().__init__()
|
220 |
+
indices = np.arange(W*H).reshape(H,W)
|
221 |
+
for i in range(1, H, 2):
|
222 |
+
indices[i, :] = indices[i, ::-1]
|
223 |
+
idx = indices.flatten()
|
224 |
+
assert len(idx) == H*W
|
225 |
+
idx = torch.tensor(idx)
|
226 |
+
self.register_buffer('forward_shuffle_idx', idx)
|
227 |
+
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
228 |
+
|
229 |
+
def forward(self, x, reverse=False):
|
230 |
+
if not reverse:
|
231 |
+
return x[:, self.forward_shuffle_idx]
|
232 |
+
else:
|
233 |
+
return x[:, self.backward_shuffle_idx]
|
234 |
+
|
235 |
+
|
236 |
+
if __name__ == "__main__":
|
237 |
+
p0 = AlternateParsing(16, 16)
|
238 |
+
print(p0.forward_shuffle_idx)
|
239 |
+
print(p0.backward_shuffle_idx)
|
240 |
+
|
241 |
+
x = torch.randint(0, 768, size=(11, 256))
|
242 |
+
y = p0(x)
|
243 |
+
xre = p0(y, reverse=True)
|
244 |
+
assert torch.equal(x, xre)
|
245 |
+
|
246 |
+
p1 = SpiralOut(2, 2)
|
247 |
+
print(p1.forward_shuffle_idx)
|
248 |
+
print(p1.backward_shuffle_idx)
|
taming/modules/util.py
ADDED
@@ -0,0 +1,130 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
def count_params(model):
|
6 |
+
total_params = sum(p.numel() for p in model.parameters())
|
7 |
+
return total_params
|
8 |
+
|
9 |
+
|
10 |
+
class ActNorm(nn.Module):
|
11 |
+
def __init__(self, num_features, logdet=False, affine=True,
|
12 |
+
allow_reverse_init=False):
|
13 |
+
assert affine
|
14 |
+
super().__init__()
|
15 |
+
self.logdet = logdet
|
16 |
+
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
17 |
+
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
18 |
+
self.allow_reverse_init = allow_reverse_init
|
19 |
+
|
20 |
+
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
21 |
+
|
22 |
+
def initialize(self, input):
|
23 |
+
with torch.no_grad():
|
24 |
+
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
25 |
+
mean = (
|
26 |
+
flatten.mean(1)
|
27 |
+
.unsqueeze(1)
|
28 |
+
.unsqueeze(2)
|
29 |
+
.unsqueeze(3)
|
30 |
+
.permute(1, 0, 2, 3)
|
31 |
+
)
|
32 |
+
std = (
|
33 |
+
flatten.std(1)
|
34 |
+
.unsqueeze(1)
|
35 |
+
.unsqueeze(2)
|
36 |
+
.unsqueeze(3)
|
37 |
+
.permute(1, 0, 2, 3)
|
38 |
+
)
|
39 |
+
|
40 |
+
self.loc.data.copy_(-mean)
|
41 |
+
self.scale.data.copy_(1 / (std + 1e-6))
|
42 |
+
|
43 |
+
def forward(self, input, reverse=False):
|
44 |
+
if reverse:
|
45 |
+
return self.reverse(input)
|
46 |
+
if len(input.shape) == 2:
|
47 |
+
input = input[:,:,None,None]
|
48 |
+
squeeze = True
|
49 |
+
else:
|
50 |
+
squeeze = False
|
51 |
+
|
52 |
+
_, _, height, width = input.shape
|
53 |
+
|
54 |
+
if self.training and self.initialized.item() == 0:
|
55 |
+
self.initialize(input)
|
56 |
+
self.initialized.fill_(1)
|
57 |
+
|
58 |
+
h = self.scale * (input + self.loc)
|
59 |
+
|
60 |
+
if squeeze:
|
61 |
+
h = h.squeeze(-1).squeeze(-1)
|
62 |
+
|
63 |
+
if self.logdet:
|
64 |
+
log_abs = torch.log(torch.abs(self.scale))
|
65 |
+
logdet = height*width*torch.sum(log_abs)
|
66 |
+
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
67 |
+
return h, logdet
|
68 |
+
|
69 |
+
return h
|
70 |
+
|
71 |
+
def reverse(self, output):
|
72 |
+
if self.training and self.initialized.item() == 0:
|
73 |
+
if not self.allow_reverse_init:
|
74 |
+
raise RuntimeError(
|
75 |
+
"Initializing ActNorm in reverse direction is "
|
76 |
+
"disabled by default. Use allow_reverse_init=True to enable."
|
77 |
+
)
|
78 |
+
else:
|
79 |
+
self.initialize(output)
|
80 |
+
self.initialized.fill_(1)
|
81 |
+
|
82 |
+
if len(output.shape) == 2:
|
83 |
+
output = output[:,:,None,None]
|
84 |
+
squeeze = True
|
85 |
+
else:
|
86 |
+
squeeze = False
|
87 |
+
|
88 |
+
h = output / self.scale - self.loc
|
89 |
+
|
90 |
+
if squeeze:
|
91 |
+
h = h.squeeze(-1).squeeze(-1)
|
92 |
+
return h
|
93 |
+
|
94 |
+
|
95 |
+
class AbstractEncoder(nn.Module):
|
96 |
+
def __init__(self):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
def encode(self, *args, **kwargs):
|
100 |
+
raise NotImplementedError
|
101 |
+
|
102 |
+
|
103 |
+
class Labelator(AbstractEncoder):
|
104 |
+
"""Net2Net Interface for Class-Conditional Model"""
|
105 |
+
def __init__(self, n_classes, quantize_interface=True):
|
106 |
+
super().__init__()
|
107 |
+
self.n_classes = n_classes
|
108 |
+
self.quantize_interface = quantize_interface
|
109 |
+
|
110 |
+
def encode(self, c):
|
111 |
+
c = c[:,None]
|
112 |
+
if self.quantize_interface:
|
113 |
+
return c, None, [None, None, c.long()]
|
114 |
+
return c
|
115 |
+
|
116 |
+
|
117 |
+
class SOSProvider(AbstractEncoder):
|
118 |
+
# for unconditional training
|
119 |
+
def __init__(self, sos_token, quantize_interface=True):
|
120 |
+
super().__init__()
|
121 |
+
self.sos_token = sos_token
|
122 |
+
self.quantize_interface = quantize_interface
|
123 |
+
|
124 |
+
def encode(self, x):
|
125 |
+
# get batch size from data and replicate sos_token
|
126 |
+
c = torch.ones(x.shape[0], 1)*self.sos_token
|
127 |
+
c = c.long().to(x.device)
|
128 |
+
if self.quantize_interface:
|
129 |
+
return c, None, [None, None, c]
|
130 |
+
return c
|
taming/modules/vqvae/quantize.py
ADDED
@@ -0,0 +1,445 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from torch import einsum
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
|
9 |
+
class VectorQuantizer(nn.Module):
|
10 |
+
"""
|
11 |
+
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
12 |
+
____________________________________________
|
13 |
+
Discretization bottleneck part of the VQ-VAE.
|
14 |
+
Inputs:
|
15 |
+
- n_e : number of embeddings
|
16 |
+
- e_dim : dimension of embedding
|
17 |
+
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
18 |
+
_____________________________________________
|
19 |
+
"""
|
20 |
+
|
21 |
+
# NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
|
22 |
+
# a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
|
23 |
+
# used wherever VectorQuantizer has been used before and is additionally
|
24 |
+
# more efficient.
|
25 |
+
def __init__(self, n_e, e_dim, beta):
|
26 |
+
super(VectorQuantizer, self).__init__()
|
27 |
+
self.n_e = n_e
|
28 |
+
self.e_dim = e_dim
|
29 |
+
self.beta = beta
|
30 |
+
|
31 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
32 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
33 |
+
|
34 |
+
def forward(self, z):
|
35 |
+
"""
|
36 |
+
Inputs the output of the encoder network z and maps it to a discrete
|
37 |
+
one-hot vector that is the index of the closest embedding vector e_j
|
38 |
+
z (continuous) -> z_q (discrete)
|
39 |
+
z.shape = (batch, channel, height, width)
|
40 |
+
quantization pipeline:
|
41 |
+
1. get encoder input (B,C,H,W)
|
42 |
+
2. flatten input to (B*H*W,C)
|
43 |
+
"""
|
44 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
45 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
46 |
+
z_flattened = z.view(-1, self.e_dim)
|
47 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
48 |
+
|
49 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
50 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
51 |
+
torch.matmul(z_flattened, self.embedding.weight.t())
|
52 |
+
|
53 |
+
## could possibly replace this here
|
54 |
+
# #\start...
|
55 |
+
# find closest encodings
|
56 |
+
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
57 |
+
|
58 |
+
min_encodings = torch.zeros(
|
59 |
+
min_encoding_indices.shape[0], self.n_e).to(z)
|
60 |
+
min_encodings.scatter_(1, min_encoding_indices, 1)
|
61 |
+
|
62 |
+
# dtype min encodings: torch.float32
|
63 |
+
# min_encodings shape: torch.Size([2048, 512])
|
64 |
+
# min_encoding_indices.shape: torch.Size([2048, 1])
|
65 |
+
|
66 |
+
# get quantized latent vectors
|
67 |
+
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
68 |
+
#.........\end
|
69 |
+
|
70 |
+
# with:
|
71 |
+
# .........\start
|
72 |
+
#min_encoding_indices = torch.argmin(d, dim=1)
|
73 |
+
#z_q = self.embedding(min_encoding_indices)
|
74 |
+
# ......\end......... (TODO)
|
75 |
+
|
76 |
+
# compute loss for embedding
|
77 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
78 |
+
torch.mean((z_q - z.detach()) ** 2)
|
79 |
+
|
80 |
+
# preserve gradients
|
81 |
+
z_q = z + (z_q - z).detach()
|
82 |
+
|
83 |
+
# perplexity
|
84 |
+
e_mean = torch.mean(min_encodings, dim=0)
|
85 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
86 |
+
|
87 |
+
# reshape back to match original input shape
|
88 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
89 |
+
|
90 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
91 |
+
|
92 |
+
def get_codebook_entry(self, indices, shape):
|
93 |
+
# shape specifying (batch, height, width, channel)
|
94 |
+
# TODO: check for more easy handling with nn.Embedding
|
95 |
+
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
|
96 |
+
min_encodings.scatter_(1, indices[:,None], 1)
|
97 |
+
|
98 |
+
# get quantized latent vectors
|
99 |
+
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
100 |
+
|
101 |
+
if shape is not None:
|
102 |
+
z_q = z_q.view(shape)
|
103 |
+
|
104 |
+
# reshape back to match original input shape
|
105 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
106 |
+
|
107 |
+
return z_q
|
108 |
+
|
109 |
+
|
class GumbelQuantize(nn.Module):
    """
    credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
    Gumbel Softmax trick quantizer
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
    https://arxiv.org/abs/1611.01144
    """
    def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
                 remap=None, unknown_index="random"):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.n_embed = n_embed

        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight

        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)

        self.use_vqinterface = use_vqinterface

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]: # extra token
            inds[inds>=self.used.shape[0]] = 0 # simply set to zero
        back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, return_logits=False):
        # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        logits = self.proj(z)
        if self.remap is not None:
            # continue only with used logits
            full_zeros = torch.zeros_like(logits)
            logits = logits[:,self.used,...]

        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        if self.remap is not None:
            # go back to all entries but unused set to zero
            full_zeros[:,self.used,...] = soft_one_hot
            soft_one_hot = full_zeros
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()

        ind = soft_one_hot.argmax(dim=1)
        if self.remap is not None:
            ind = self.remap_to_used(ind)
        if self.use_vqinterface:
            if return_logits:
                return z_q, diff, (None, None, ind), logits
            return z_q, diff, (None, None, ind)
        return z_q, diff, ind

    def get_codebook_entry(self, indices, shape):
        b, h, w, c = shape
        assert b*h*w == indices.shape[0]
        indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
        if self.remap is not None:
            indices = self.unmap_to_all(indices)
        one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
        z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
        return z_q

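A small sketch of the Gumbel-Softmax quantizer above: codebook logits are relaxed with gumbel_softmax so gradients flow through the (straight-through) one-hot selection. The constructor arguments and shapes here are illustrative, not values used elsewhere in this repository.

# Illustrative usage only; num_hiddens/embedding_dim/n_embed and shapes are made up.
import torch

gq = GumbelQuantize(num_hiddens=128, embedding_dim=64, n_embed=512)
gq.train()                                        # straight-through selection during training
h = torch.randn(2, 128, 16, 16)                   # pre-quantization feature map
z_q, kl_loss, (_, _, ind) = gq(h)
print(z_q.shape)                                  # torch.Size([2, 64, 16, 16])
print(ind.shape)                                  # torch.Size([2, 16, 16]); codebook indices
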
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]: # extra token
            inds[inds>=self.used.shape[0]] = 0 # simply set to zero
        back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits==False, "Only for interface compatible with Gumbel"
        assert return_logits==False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0],-1) # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1) # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

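VectorQuantizer2 is described above as a drop-in replacement for VectorQuantizer; a brief interface sketch follows, with legacy=False selecting the corrected commitment-loss weighting mentioned in the NOTE. All values are illustrative.

# Illustrative usage only; codebook size, latent dim, and shapes are made up.
import torch

vq = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25, legacy=False,
                      sane_index_shape=True)
z = torch.randn(4, 256, 32, 32)
z_q, loss, (_, _, indices) = vq(z)
print(z_q.shape)      # torch.Size([4, 256, 32, 32])
print(indices.shape)  # torch.Size([4, 32, 32]) because sane_index_shape=True
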
class EmbeddingEMA(nn.Module):
    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
        super().__init__()
        self.decay = decay
        self.eps = eps
        weight = torch.randn(num_tokens, codebook_dim)
        self.weight = nn.Parameter(weight, requires_grad = False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
        self.update = True

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        # normalize embedding average with smoothed cluster size
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.weight.data.copy_(embed_normalized)

class EMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 remap=None, unknown_index="random"):
        super().__init__()
        # EmbeddingEMA expects (num_tokens, codebook_dim); map the constructor
        # arguments onto those names (the original lines referenced undefined
        # `codebook_dim` / `num_tokens` and would raise a NameError).
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

383 |
+
def remap_to_used(self, inds):
|
384 |
+
ishape = inds.shape
|
385 |
+
assert len(ishape)>1
|
386 |
+
inds = inds.reshape(ishape[0],-1)
|
387 |
+
used = self.used.to(inds)
|
388 |
+
match = (inds[:,:,None]==used[None,None,...]).long()
|
389 |
+
new = match.argmax(-1)
|
390 |
+
unknown = match.sum(2)<1
|
391 |
+
if self.unknown_index == "random":
|
392 |
+
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
|
393 |
+
else:
|
394 |
+
new[unknown] = self.unknown_index
|
395 |
+
return new.reshape(ishape)
|
396 |
+
|
397 |
+
def unmap_to_all(self, inds):
|
398 |
+
ishape = inds.shape
|
399 |
+
assert len(ishape)>1
|
400 |
+
inds = inds.reshape(ishape[0],-1)
|
401 |
+
used = self.used.to(inds)
|
402 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
403 |
+
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
|
404 |
+
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
|
405 |
+
return back.reshape(ishape)
|
406 |
+
|
407 |
+
def forward(self, z):
|
408 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
409 |
+
#z, 'b c h w -> b h w c'
|
410 |
+
z = rearrange(z, 'b c h w -> b h w c')
|
411 |
+
z_flattened = z.reshape(-1, self.codebook_dim)
|
412 |
+
|
413 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
414 |
+
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
|
415 |
+
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
|
416 |
+
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
|
417 |
+
|
418 |
+
|
419 |
+
encoding_indices = torch.argmin(d, dim=1)
|
420 |
+
|
421 |
+
z_q = self.embedding(encoding_indices).view(z.shape)
|
422 |
+
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
423 |
+
avg_probs = torch.mean(encodings, dim=0)
|
424 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
425 |
+
|
426 |
+
if self.training and self.embedding.update:
|
427 |
+
#EMA cluster size
|
428 |
+
encodings_sum = encodings.sum(0)
|
429 |
+
self.embedding.cluster_size_ema_update(encodings_sum)
|
430 |
+
#EMA embedding average
|
431 |
+
embed_sum = encodings.transpose(0,1) @ z_flattened
|
432 |
+
self.embedding.embed_avg_ema_update(embed_sum)
|
433 |
+
#normalize embed_avg and update weight
|
434 |
+
self.embedding.weight_update(self.num_tokens)
|
435 |
+
|
436 |
+
# compute loss for embedding
|
437 |
+
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
438 |
+
|
439 |
+
# preserve gradients
|
440 |
+
z_q = z + (z_q - z).detach()
|
441 |
+
|
442 |
+
# reshape back to match original input shape
|
443 |
+
#z_q, 'b h w c -> b c h w'
|
444 |
+
z_q = rearrange(z_q, 'b h w c -> b c h w')
|
445 |
+
return z_q, loss, (perplexity, encodings, encoding_indices)
|
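The EMA codebook update used by EmbeddingEMA / EMAVectorQuantizer, written out on toy tensors as a sketch of the update rule (decay of per-code counts and summed latents, then Laplace-smoothed normalization); none of these values come from the repository.

# Sketch only: toy sizes, standalone tensors instead of the module's buffers.
import torch

decay, eps = 0.99, 1e-5
num_tokens, dim = 4, 2
weight = torch.randn(num_tokens, dim)            # current codebook
cluster_size = torch.zeros(num_tokens)           # EMA of per-code usage counts
embed_avg = weight.clone()                       # EMA of per-code summed latents

z_flat = torch.randn(8, dim)                     # flattened encoder outputs
idx = torch.randint(0, num_tokens, (8,))         # nearest-code assignments
one_hot = torch.nn.functional.one_hot(idx, num_tokens).float()

# cluster_size_ema_update / embed_avg_ema_update
cluster_size = decay * cluster_size + (1 - decay) * one_hot.sum(0)
embed_avg = decay * embed_avg + (1 - decay) * (one_hot.t() @ z_flat)

# weight_update: Laplace-smooth the counts, then renormalize the averages
n = cluster_size.sum()
smoothed = (cluster_size + eps) / (n + num_tokens * eps) * n
weight = embed_avg / smoothed.unsqueeze(1)       # new codebook entries
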
taming/util.py
ADDED
@@ -0,0 +1,157 @@
import os, hashlib
import requests
from tqdm import tqdm

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path

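A short sketch of how these helpers fit together: get_ckpt_path downloads the named checkpoint into a cache directory on first use and can verify its md5. The cache directory below is only an example path.

# Illustrative only; the root directory is an assumed cache location.
ckpt = get_ckpt_path("vgg_lpips", root="taming/modules/autoencoder/lpips", check=True)
print(ckpt)   # .../vgg.pth, downloaded and md5-checked on first use
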
class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf
    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")

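A slightly fuller illustration of retrieve() than the __main__ block above: nested keys are addressed with the splitval delimiter, and passing default suppresses the KeyNotFoundError for missing paths. The config dictionary below is made up for the example.

# Illustrative only; the nested config and key names are assumptions.
cfg = {"model": {"params": {"embed_dim": 256}}}
print(retrieve(cfg, "model/params/embed_dim"))             # 256
print(retrieve(cfg, "model/params/n_layers", default=12))  # 12 (key missing, default returned)
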