File size: 4,163 Bytes
0f1af34
 
 
 
 
 
 
be9674f
13afb1c
0f1af34
 
 
 
 
ff34208
 
 
 
 
05493b2
 
 
0f1af34
a5a784e
4eeda6c
0f1af34
 
0210cac
0f1af34
 
8576ba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f1af34
8576ba9
0f1af34
 
 
 
 
 
ff34208
 
 
9f24831
 
ff34208
 
 
8576ba9
78f112e
 
ff34208
78f112e
ff34208
3475ae1
3d25cfa
ff34208
 
 
3d25cfa
eb240c8
0210cac
 
 
 
ff34208
0f1af34
 
cf7ac47
 
 
 
 
 
 
 
 
 
 
0f1af34
cf7ac47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
from astropy.io import fits

# Project-local modules: the conditional model and the diffusion sampler
from modules import PaletteModelV2
from diffusion import Diffusion_cond

# HTML header for the demo page: a centred title that links to the MAG2MAG
# GitHub repository, followed by a teaser image loaded from that repository.
# It is rendered through gr.Markdown when the interface is built.
DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">
    <h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>
    <img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">
</div>'''

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the conditional network and restore its weights from the checkpoint
# file that is expected to sit next to this script.
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    device=device,
    true_img_size=64,
).to(device)
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be placed here (consider weights_only=True where the installed
# PyTorch supports it).
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
model.eval()  # the demo only runs inference

# Sampler used by generate_image to draw the prediction from the model.
diffusion = Diffusion_cond(img_size=256, device=device)

from torchvision import transforms

# Define a custom transform to clamp data
class ClampTransform:
    """Callable transform that limits tensor values to a closed interval.

    Args:
        min_value: Lower bound; smaller entries are raised to this value.
        max_value: Upper bound; larger entries are lowered to this value.
    """

    def __init__(self, min_value=-250, max_value=250):
        self.min_value, self.max_value = min_value, max_value

    def __call__(self, tensor):
        """Return ``tensor`` with every entry clamped to the configured bounds."""
        clipped = torch.clamp(tensor, min=self.min_value, max=self.max_value)
        return clipped

def _build_hmi_transform(*leading_steps):
    """Compose ``leading_steps`` followed by the resize/flip/normalise tail.

    Both input formats share the same tail: resize to 256x256, an
    unconditional vertical flip (p=1.0), then normalisation with
    mean 0.5 and std 0.5.
    """
    return transforms.Compose([
        *leading_steps,
        transforms.Resize((256, 256)),
        transforms.RandomVerticalFlip(p=1.0),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])


# JPEG 2000 input: convert to a tensor, then apply the shared tail.
transform_hmi_jp2 = _build_hmi_transform(transforms.ToTensor())

# FITS input: convert to a tensor and clamp to [-250, 250] before the shared tail.
# NOTE(review): values clamped to [-250, 250] are normalised with mean 0.5 and
# std 0.5, which does not map that range onto [-1, 1] -- confirm this matches
# the preprocessing used when the model was trained.
transform_hmi_fits = _build_hmi_transform(
    transforms.ToTensor(),
    ClampTransform(-250, 250),
)

def generate_image(seed_image):
    """Run the diffusion model on an uploaded magnetogram.

    Args:
        seed_image: Path of the uploaded file. The extension selects the
            loader: ``.jp2`` is opened with PIL, ``.fits`` with astropy
            (primary HDU only). ``None`` means nothing was uploaded.

    Returns:
        A ``(input, prediction)`` tuple of PIL images: the preprocessed input
        as an 8-bit image, and the image sampled from the model.

    Raises:
        gr.Error: If no file was uploaded, the extension is neither ``.jp2``
            nor ``.fits``, or the FITS primary HDU holds no data. Gradio
            shows the message to the user instead of failing silently.
    """
    if seed_image is None:
        raise gr.Error('Please upload a .jp2 or .fits file first')

    file_ext = os.path.splitext(seed_image)[1].lower()

    if file_ext == '.jp2':
        # The context manager closes the file handle once the pixels are read.
        with Image.open(seed_image) as input_img:
            input_tensor = transform_hmi_jp2(input_img)
    elif file_ext == '.fits':
        with fits.open(seed_image) as hdul:
            data = hdul[0].data
            if data is None:
                raise gr.Error('The primary HDU of the FITS file contains no image data')
            # FITS arrays can be big-endian and/or not float32; torch cannot
            # wrap non-native byte order, so cast to native float32 first.
            # Must happen inside the ``with`` because the data may be read
            # lazily from the open file.
            data = np.asarray(data, dtype=np.float32)
            input_tensor = transform_hmi_fits(data)
    else:
        # Stop here: continuing would reference an input tensor that was
        # never created.
        raise gr.Error(f'Format {file_ext} not supported')

    input_img_pil = input_tensor.reshape(1, 1, 256, 256).to(device)

    generated_image = diffusion.sample(model, y=input_img_pil, labels=None, n=1)

    # Bring the normalised input from [-1, 1] to [0, 1], then to 8-bit grey
    # levels so it can be displayed.
    inp_img = (input_img_pil.clamp(-1, 1) + 1) / 2
    inp_img = (inp_img * 255).type(torch.uint8)
    inp_img = np.squeeze(inp_img.cpu().numpy())
    inp = Image.fromarray(inp_img)
    # Flip back: the preprocessing pipeline flips every image vertically.
    inp = inp.transpose(Image.FLIP_TOP_BOTTOM)

    # NOTE(review): assumes diffusion.sample already returns values in a dtype
    # PIL can display (e.g. uint8) -- confirm against the sampler.
    img = generated_image[0].reshape(1, 256, 256).permute(1, 2, 0)  # channel-first -> H x W x C
    img = np.squeeze(img.cpu().numpy())
    v = Image.fromarray(img)
    v = v.transpose(Image.FLIP_TOP_BOTTOM)

    return inp, v

# Assemble the demo page: header, a row with the upload widget and the two
# result images, and a second row with the action buttons.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        input_image = gr.File(label='Input Image')
        output_image1 = gr.Image(
            label='Input LoS Magnetogram',
            type='pil',
            interactive=False,
        )
        output_image2 = gr.Image(
            label='Predicted LoS Magnetogram in 24 hours',
            type='pil',
            interactive=False,
        )

    # The buttons sit in their own row, placed after the image row.
    with gr.Row():
        clear_button = gr.Button('Clear')
        process_button = gr.Button('Generate')

    # "Generate" feeds the uploaded file to generate_image and shows both results.
    process_button.click(
        generate_image,
        inputs=input_image,
        outputs=[output_image1, output_image2],
    )

    # "Clear" resets the upload widget by returning None for it.
    clear_button.click(lambda: None, inputs=None, outputs=input_image)

demo.launch()