Update app.py
Browse files
app.py
CHANGED
@@ -28,8 +28,27 @@ model.load_state_dict(ckpt)
|
|
28 |
diffusion = Diffusion_cond(img_size=256, device=device)
|
29 |
model.eval()
|
30 |
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
transforms.ToTensor(),
|
|
|
33 |
transforms.Resize((256, 256)),
|
34 |
transforms.RandomVerticalFlip(p=1.0),
|
35 |
transforms.Normalize(mean=(0.5,), std=(0.5,))
|
@@ -39,12 +58,12 @@ def generate_image(seed_image):
|
|
39 |
_, file_ext = os.path.splitext(seed_image)
|
40 |
|
41 |
if file_ext.lower() == '.jp2':
|
42 |
-
input_img = Image.
|
43 |
input_img_pil = transform_hmi(input_img).reshape(1, 1, 256, 256).to(device)
|
44 |
elif file_ext.lower() == '.fits':
|
45 |
with fits.open(seed_image) as hdul:
|
46 |
data = hdul[0].data
|
47 |
-
input_img_pil =
|
48 |
else:
|
49 |
print(f'Format {file_ext.lower()} not supported')
|
50 |
|
|
|
28 |
diffusion = Diffusion_cond(img_size=256, device=device)
|
29 |
model.eval()
|
30 |
|
31 |
+
from torchvision import transforms
|
32 |
+
|
33 |
+
# Define a custom transform to clamp data
class ClampTransform:
    """Callable transform that limits tensor values to a closed interval.

    Intended for use inside a ``transforms.Compose`` pipeline: instances
    are called with a tensor and return the clamped result.

    Args:
        min_value: lower bound of the allowed range (default -250).
        max_value: upper bound of the allowed range (default 250).
    """

    def __init__(self, min_value=-250, max_value=250):
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, tensor):
        # Elementwise clamp into [min_value, max_value]; equivalent to
        # torch.clamp(tensor, min_value, max_value).
        return tensor.clamp(self.min_value, self.max_value)
41 |
+
|
42 |
+
# Preprocessing pipeline for .jp2 (JPEG2000) magnetogram inputs:
#   PIL image -> tensor scaled to [0, 1] (ToTensor on a PIL image)
#   -> resized to 256x256 -> deterministic vertical flip (p=1.0, so it
#   always flips) -> Normalize(mean=0.5, std=0.5) maps values to [-1, 1].
transform_hmi_jp2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
|
48 |
+
|
49 |
+
transform_hmi_fits = transforms.Compose([
|
50 |
transforms.ToTensor(),
|
51 |
+
ClampTransform(-250, 250),
|
52 |
transforms.Resize((256, 256)),
|
53 |
transforms.RandomVerticalFlip(p=1.0),
|
54 |
transforms.Normalize(mean=(0.5,), std=(0.5,))
|
|
|
58 |
_, file_ext = os.path.splitext(seed_image)
|
59 |
|
60 |
if file_ext.lower() == '.jp2':
    # JPEG2000: load via PIL, then apply the jp2 preprocessing pipeline.
    # Fixes two bugs: `Image.transform_hmi_jp2(seed_image)` raised
    # AttributeError (PIL.Image has no such function), and the old name
    # `transform_hmi` no longer exists after the pipeline split into
    # transform_hmi_jp2 / transform_hmi_fits.
    input_img = Image.open(seed_image)
    input_img_pil = transform_hmi_jp2(input_img).reshape(1, 1, 256, 256).to(device)
elif file_ext.lower() == '.fits':
    # FITS: read the primary HDU's data array and apply the fits pipeline
    # (which additionally clamps values to [-250, 250]).
    with fits.open(seed_image) as hdul:
        data = hdul[0].data
        input_img_pil = transform_hmi_fits(data).reshape(1, 1, 256, 256).to(device)
else:
    # NOTE(review): falls through with input_img_pil unset — later use of
    # it will raise NameError; consider raising ValueError here instead.
    print(f'Format {file_ext.lower()} not supported')
|
69 |
|