fpramunno committed on
Commit
0f1af34
1 Parent(s): 36b7a16

Upload 3 files

Files changed (3)
  1. app.py +75 -0
  2. diffusion.py +139 -0
  3. modules.py +243 -0
app.py ADDED
@@ -0,0 +1,75 @@
+ import gradio as gr
+ from pathlib import Path
+ import os
+ from PIL import Image
+ import torch
+ import torchvision.transforms as transforms
+ import requests
+
+ # Function to download the model checkpoint from Google Drive
+ def download_file_from_google_drive(id, destination):
+     URL = "https://drive.google.com/uc?export=download"
+     session = requests.Session()
+     response = session.get(URL, params={'id': id}, stream=True)
+     token = get_confirm_token(response)
+
+     if token:
+         params = {'id': id, 'confirm': token}
+         response = session.get(URL, params=params, stream=True)
+
+     save_response_content(response, destination)
+
+ def get_confirm_token(response):
+     for key, value in response.cookies.items():
+         if key.startswith('download_warning'):
+             return value
+     return None
+
+ def save_response_content(response, destination):
+     CHUNK_SIZE = 32768
+     with open(destination, "wb") as f:
+         for chunk in response.iter_content(CHUNK_SIZE):
+             if chunk:  # filter out keep-alive chunks
+                 f.write(chunk)
+
+ # Google Drive file ID of the pretrained EMA checkpoint
+ file_id = '1WJ33nys02XpPDsMO5uIZFiLqTuAT_iuV'
+ destination = 'ema_ckpt_cond.pt'
+ download_file_from_google_drive(file_id, destination)
+
+ # Model, diffusion sampler and preprocessing
+ from modules import PaletteModelV2
+ from diffusion import Diffusion_cond
+
+ device = 'cuda'
+
+ model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, true_img_size=64).to(device)
+ ckpt = torch.load(destination, map_location=device)
+ model.load_state_dict(ckpt)
+
+ diffusion = Diffusion_cond(noise_steps=1000, img_size=256, device=device)
+ model.eval()
+
+ transform_hmi = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((256, 256)),
+     transforms.RandomVerticalFlip(p=1.0),
+     transforms.Normalize(mean=(0.5,), std=(0.5,))
+ ])
+
+ def generate_image(seed_image):
+     seed_image_tensor = transform_hmi(Image.open(seed_image)).reshape(1, 1, 256, 256).to(device)
+     generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
+     generated_image_pil = transforms.ToPILImage()(generated_image.squeeze().cpu())
+     return generated_image_pil
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=generate_image,
+     inputs="file",
+     outputs="image",
+     title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
+     description="Upload a LoS magnetogram to predict how it will look 24 hours later."
+ )
+
+ iface.launch()
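
The demo is normally driven through the Gradio UI above, but the same pipeline can also be exercised from a plain script; a minimal sketch, assuming the CUDA setup used by app.py and a locally available grayscale magnetogram saved as hmi_example.png (hypothetical file name):

if __name__ == "__main__":
    # one 24-hour forecast from a single input magnetogram
    forecast = generate_image("hmi_example.png")
    forecast.save("forecast_24h.png")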
diffusion.py ADDED
@@ -0,0 +1,139 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Tue Apr 25 14:45:59 2023
+
+ @author: pio-r
+ """
+
+ import torch
+ from tqdm import tqdm
+ import torch.nn as nn
+ import logging
+ import numpy as np
+
+ logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")
+
+
+ class Diffusion_cond:
+     def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, img_channel=1, device="cuda"):
+         self.noise_steps = noise_steps  # number of diffusion timesteps
+         self.beta_start = beta_start
+         self.beta_end = beta_end
+         self.img_channel = img_channel
+         self.img_size = img_size
+         self.device = device
+
+         self.beta = self.prepare_noise_schedule().to(device)
+         self.alpha = 1. - self.beta
+         self.alphas_prev = torch.cat([torch.tensor([1.0]).to(device), self.alpha[:-1]], dim=0)
+         self.alpha_hat = torch.cumprod(self.alpha, dim=0)
+         self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(device), self.alpha_hat[:-1]], dim=0)
+         # self.alphas_cumprod_prev = torch.from_numpy(np.append(1, self.alpha_hat[:-1].cpu().numpy())).to(device)
+
+     def prepare_noise_schedule(self):
+         return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)  # linear variance schedule as proposed by Ho et al. 2020
+
+     def noise_images(self, x, t):
+         sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
+         sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
+         Ɛ = torch.randn_like(x)
+         return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ  # forward noising equation from Ho et al. 2020
+
+     def sample_timesteps(self, n):
+         return torch.randint(low=1, high=self.noise_steps, size=(n,))
+
+     def sample(self, model, n, y, labels, cfg_scale=3, eta=1, sampling_mode='ddpm'):
+         logging.info(f"Sampling {n} new images....")
+         model.eval()  # evaluation mode
+         with torch.no_grad():  # Algorithm 2 from DDPM
+             x = torch.randn((n, self.img_channel, self.img_size, self.img_size)).to(self.device)
+             for i in tqdm(reversed(range(1, self.noise_steps)), position=0):  # reverse loop from T to 1
+                 t = (torch.ones(n) * i).long().to(self.device)  # timestep tensor of length n
+                 predicted_noise = model(x, y, labels, t)
+                 if cfg_scale > 0:
+                     uncond_predicted_noise = model(x, y, None, t)
+                     predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)  # classifier-free guidance
+
+                 alpha = self.alpha[t][:, None, None, None]
+                 alpha_hat = self.alpha_hat[t][:, None, None, None]  # cumulative product of alphas up to t
+                 alpha_prev = self.alphas_cumprod_prev[t][:, None, None, None]
+                 beta = self.beta[t][:, None, None, None]
+                 # sampling adapted from Stable Diffusion
+                 sigma = (
+                     eta
+                     * torch.sqrt((1 - alpha_prev) / (1 - alpha_hat)
+                     * (1 - alpha_hat / alpha_prev))
+                 )
+                 if i > 1:
+                     noise = torch.randn_like(x)
+                 else:
+                     noise = torch.zeros_like(x)
+                 # pred_x0 = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise)
+                 pred_x0 = (x - torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha_hat)
+                 if sampling_mode == 'ddpm':
+                     x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
+                 elif sampling_mode == 'ddim':
+                     noise = torch.randn_like(x)
+                     nonzero_mask = (
+                         (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+                     )
+                     x = (
+                         torch.sqrt(alpha_prev) * pred_x0 +
+                         torch.sqrt(1 - alpha_prev - sigma ** 2) * predicted_noise +
+                         nonzero_mask * sigma * noise
+                     )
+                 else:
+                     print('The sampler {} is not implemented'.format(sampling_mode))
+                     break
+         model.train()  # back to training mode
+         # x = (x.clamp(-1, 1) + 1) / 2  # clamp to [-1, 1], then shift and rescale to [0, 1]
+         # x = (x * 255).type(torch.uint8)  # bring into the valid pixel range
+         return x
+
+ mse = nn.MSELoss()
+
+ def psnr(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
+     r"""Compute the PSNR between two images.
+
+     PSNR is the Peak Signal-to-Noise Ratio, which is based on the mean squared error.
+     Given an m x n image, the PSNR is:
+
+     .. math::
+
+         \text{PSNR} = 10 \log_{10} \bigg(\frac{\text{MAX}_I^2}{\text{MSE}(I,T)}\bigg)
+
+     where
+
+     .. math::
+
+         \text{MSE}(I,T) = \frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2
+
+     and :math:`\text{MAX}_I` is the maximum possible input value
+     (e.g. for floating point images :math:`\text{MAX}_I=1`).
+
+     Args:
+         input: the input image with arbitrary shape :math:`(*)`.
+         target: the target image with arbitrary shape :math:`(*)`.
+         max_val: the maximum value in the input tensor.
+
+     Return:
+         the computed loss as a scalar.
+
+     Examples:
+         >>> ones = torch.ones(1)
+         >>> psnr(ones, 1.2 * ones, 2.)  # 10 * log(4/((1.2-1)**2)) / log(10)
+         tensor(20.0000)
+
+     Reference:
+         https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition
+     """
+     if not isinstance(input, torch.Tensor):
+         raise TypeError(f"Expected torch.Tensor but got {type(input)}.")
+
+     if not isinstance(target, torch.Tensor):
+         raise TypeError(f"Expected torch.Tensor but got {type(target)}.")
+
+     if input.shape != target.shape:
+         raise TypeError(f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}")
+
+     return 10.0 * torch.log10(max_val**2 / mse(input, target))
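
The forward (noising) utilities and the PSNR metric above can be checked independently of the trained U-Net; a minimal CPU sketch, assuming random tensors in place of real magnetograms:

import torch
from diffusion import Diffusion_cond, psnr

diff = Diffusion_cond(noise_steps=1000, img_size=64, device="cpu")
x0 = torch.rand(4, 1, 64, 64) * 2 - 1   # fake batch scaled to [-1, 1]
t = diff.sample_timesteps(4)            # random timesteps in [1, noise_steps)
x_t, eps = diff.noise_images(x0, t)     # q(x_t | x_0) sample and the injected noise
print(psnr(x_t, x0, max_val=2.0))       # PSNR between the noised and clean batch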
modules.py ADDED
@@ -0,0 +1,243 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Tue Apr 25 14:28:21 2023
+
+ @author: pio-r
+ """
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.checkpoint import checkpoint
+
+ class EMA:
+     def __init__(self, beta):
+         super().__init__()
+         self.beta = beta
+         self.step = 0
+
+     def update_model_average(self, ma_model, current_model):
+         for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+             old_weight, up_weight = ma_params.data, current_params.data
+             ma_params.data = self.update_average(old_weight, up_weight)
+
+     def update_average(self, old, new):
+         if old is None:
+             return new
+         return old * self.beta + (1 - self.beta) * new
+
+     def step_ema(self, ema_model, model, step_start_ema=2000):
+         if self.step < step_start_ema:
+             self.reset_parameters(ema_model, model)
+             self.step += 1
+             return
+         self.update_model_average(ema_model, model)
+         self.step += 1
+
+     def reset_parameters(self, ema_model, model):
+         ema_model.load_state_dict(model.state_dict())
+
+ class SelfAttention(nn.Module):
+     """
+     Pre layer norm -> multi-head attention -> skip connection -> feed-forward
+     block (layer norm -> linear -> GELU -> linear) with a second skip connection.
+     """
+     def __init__(self, channels, size):
+         super(SelfAttention, self).__init__()
+         self.channels = channels
+         self.size = size
+         self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
+         self.ln = nn.LayerNorm([channels])
+         self.ff_self = nn.Sequential(
+             nn.LayerNorm([channels]),
+             nn.Linear(channels, channels),
+             nn.GELU(),
+             nn.Linear(channels, channels),
+         )
+
+     def forward(self, x):
+         x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
+         x_ln = self.ln(x)
+         attention_value, _ = self.mha(x_ln, x_ln, x_ln)
+         attention_value = attention_value + x
+         attention_value = self.ff_self(attention_value) + attention_value
+         return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
+
+
+ class DoubleConv(nn.Module):
+     """
+     Standard convolution block: Conv2d -> GroupNorm -> GELU -> Conv2d -> GroupNorm.
+     A residual connection can be added by passing residual=True.
+     """
+     def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
+         super().__init__()
+         self.residual = residual
+         if not mid_channels:
+             mid_channels = out_channels
+         self.double_conv = nn.Sequential(
+             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+             nn.GroupNorm(1, mid_channels),
+             nn.GELU(),
+             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.GroupNorm(1, out_channels),
+         )
+
+     def forward(self, x):
+         if self.residual:
+             return F.gelu(x + self.double_conv(x))
+         else:
+             return self.double_conv(x)
+
+
+ class Down(nn.Module):
+     """
+     MaxPool halves the spatial size -> 2x DoubleConv -> time-embedding projection.
+     """
+     def __init__(self, in_channels, out_channels, emb_dim=256):
+         super().__init__()
+         self.maxpool_conv = nn.Sequential(
+             nn.MaxPool2d(2),
+             DoubleConv(in_channels, in_channels, residual=True),
+             DoubleConv(in_channels, out_channels),
+         )
+
+         self.emb_layer = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(  # linear projection to bring the time embedding to the proper dimension
+                 emb_dim,
+                 out_channels
+             ),
+         )
+
+     def forward(self, x, t):
+         x = self.maxpool_conv(x)
+         emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])  # broadcast over the spatial dimensions
+         return x + emb
+
+
+ class Up(nn.Module):
+     """
+     Upsample, concatenate the skip connection from the encoder, then
+     2x DoubleConv and a time-embedding projection.
+     """
+     def __init__(self, in_channels, out_channels, emb_dim=256):
+         super().__init__()
+
+         self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+         self.conv = nn.Sequential(
+             DoubleConv(in_channels, in_channels, residual=True),
+             DoubleConv(in_channels, out_channels, in_channels // 2),
+         )
+
+         self.emb_layer = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(
+                 emb_dim,
+                 out_channels
+             ),
+         )
+
+     def forward(self, x, skip_x, t):
+         x = self.up(x)
+         x = torch.cat([skip_x, x], dim=1)
+         x = self.conv(x)
+         emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+         return x + emb
+
+ class PaletteModelV2(nn.Module):
+     def __init__(self, c_in=1, c_out=1, image_size=64, time_dim=256, device='cuda', latent=False, true_img_size=64, num_classes=None):
+         super(PaletteModelV2, self).__init__()
+
+         # Encoder
+         self.true_img_size = true_img_size
+         self.image_size = image_size
+         self.time_dim = time_dim
+         self.device = device
+         self.inc = DoubleConv(c_in, self.image_size)  # wrapper around two conv layers
+         self.down1 = Down(self.image_size, self.image_size*2)  # input and output channels
+         # self.sa1 = SelfAttention(self.image_size*2, int(self.true_img_size/2))  # 1st arg is channel dim, 2nd is current image resolution
+         self.down2 = Down(self.image_size*2, self.image_size*4)
+         # self.sa2 = SelfAttention(self.image_size*4, int(self.true_img_size/4))
+         self.down3 = Down(self.image_size*4, self.image_size*4)
+         # self.sa3 = SelfAttention(self.image_size*4, int(self.true_img_size/8))
+
+         # Bottleneck
+         self.bot1 = DoubleConv(self.image_size*4, self.image_size*8)
+         self.bot2 = DoubleConv(self.image_size*8, self.image_size*8)
+         self.bot3 = DoubleConv(self.image_size*8, self.image_size*4)
+
+         # Decoder: reverse of the encoder
+         self.up1 = Up(self.image_size*8, self.image_size*2)
+         # self.sa4 = SelfAttention(self.image_size*2, int(self.true_img_size/4))
+         self.up2 = Up(self.image_size*4, self.image_size)
+         # self.sa5 = SelfAttention(self.image_size, int(self.true_img_size/2))
+         self.up3 = Up(self.image_size*2, self.image_size)
+         # self.sa6 = SelfAttention(self.image_size, self.true_img_size)
+         self.outc = nn.Conv2d(self.image_size, c_out, kernel_size=1)  # project back to the output channel dimension
+
+         if num_classes is not None:
+             self.label_emb = nn.Embedding(num_classes, time_dim)
+
+         if latent:
+             self.latent = nn.Sequential(
+                 nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
+                 nn.LeakyReLU(0.2),
+                 nn.MaxPool2d(kernel_size=2, stride=2),
+                 nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
+                 nn.LeakyReLU(0.2),
+                 nn.MaxPool2d(kernel_size=2, stride=2),
+                 nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
+                 nn.LeakyReLU(0.2),
+                 nn.MaxPool2d(kernel_size=2, stride=2),
+                 nn.Flatten(),
+                 nn.Linear(64 * 8 * 8, 256)).to(device)
+
+     def pos_encoding(self, t, channels):
+         """
+         Sinusoidal positional encoding of the timesteps. t is a tensor holding
+         the integer timestep values; channels is the embedding dimension.
+         """
+         inv_freq = 1.0 / (
+             10000
+             ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
+         )
+         pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
+         pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
+         pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
+         return pos_enc
+
+     def forward(self, x, y, lab, t):
+         # Encode the timesteps to the chosen embedding dimension
+         t = t.unsqueeze(-1).type(torch.float)
+         t = self.pos_encoding(t, self.time_dim)
+
+         if lab is not None:
+             t += self.label_emb(lab)
+
+         # t += self.latent(y)
+
+         # Concatenate the source image and the conditioning image
+         x = torch.cat([x, y], dim=1)
+
+         x1 = self.inc(x)
+         x2 = self.down1(x1, t)
+         # x2 = self.sa1(x2)
+         x3 = self.down2(x2, t)
+         # x3 = self.sa2(x3)
+         x4 = self.down3(x3, t)
+         # x4 = self.sa3(x4)
+
+         x4 = self.bot1(x4)
+         x4 = self.bot2(x4)
+         x4 = self.bot3(x4)
+
+         x = self.up1(x4, x3, t)  # the upsampling blocks take the skip connections from the encoder
+         # x = self.sa4(x)
+         x = self.up2(x, x2, t)
+         # x = self.sa5(x)
+         x = self.up3(x, x1, t)
+         # x = self.sa6(x)
+         output = self.outc(x)
+
+         return output
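
As a quick sanity check, the conditional U-Net can be instantiated on CPU with random tensors to verify the expected output shape; a minimal sketch, noting that image_size is the base channel width (64 by default) rather than the spatial resolution:

import torch
from modules import PaletteModelV2

model = PaletteModelV2(c_in=2, c_out=1, image_size=64, true_img_size=64, num_classes=5, device="cpu")
x = torch.randn(2, 1, 64, 64)     # noisy image x_t
y = torch.randn(2, 1, 64, 64)     # conditioning magnetogram
lab = torch.tensor([0, 3])        # class labels in [0, num_classes)
t = torch.randint(1, 1000, (2,))  # integer timesteps
out = model(x, y, lab, t)
print(out.shape)                  # expected: torch.Size([2, 1, 64, 64])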