mag2mag / app.py
fpramunno's picture
Update app.py
3475ae1 verified
raw
history blame contribute delete
No virus
4.16 kB
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
from astropy.io import fits
# Project-local network definition and diffusion sampler
from modules import PaletteModelV2
from diffusion import Diffusion_cond
# HTML header shown at the top of the app (passed to gr.Markdown when the
# interface is built): the MAG2MAG title linking to the GitHub repository,
# followed by a teaser image hosted in that repository.
DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">
<h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>
<img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">
</div>'''
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Instantiate the network and restore its weights from the local checkpoint.
# The constructor arguments must stay compatible with the tensor shapes stored
# in the checkpoint, otherwise load_state_dict (strict by default) raises.
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    device=device,
    true_img_size=64,
).to(device)
# map_location places the stored tensors on the selected device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
# The app only runs inference: put the network in evaluation mode.
model.eval()

# Diffusion helper whose .sample() produces predictions at 256x256.
diffusion = Diffusion_cond(img_size=256, device=device)
from torchvision import transforms
# Define a custom transform to clamp data
class ClampTransform(object):
def __init__(self, min_value=-250, max_value=250):
self.min_value = min_value
self.max_value = max_value
def __call__(self, tensor):
return torch.clamp(tensor, self.min_value, self.max_value)
def _build_hmi_transform(*pre_resize_steps):
    """Return the shared HMI preprocessing pipeline.

    Steps, in order: ``ToTensor``, any ``pre_resize_steps``, resize to
    256x256, an unconditional vertical flip (``p=1.0``), and normalisation
    with mean 0.5 / std 0.5 (i.e. ``x -> 2x - 1``).
    Each call builds fresh transform instances.
    """
    steps = [transforms.ToTensor()]
    steps.extend(pre_resize_steps)
    steps.append(transforms.Resize((256, 256)))
    steps.append(transforms.RandomVerticalFlip(p=1.0))
    steps.append(transforms.Normalize(mean=(0.5,), std=(0.5,)))
    return transforms.Compose(steps)


# JPEG 2000 inputs: for 8-bit images ToTensor scales to [0, 1], so the final
# Normalize maps them to [-1, 1].
transform_hmi_jp2 = _build_hmi_transform()

# FITS inputs: clamp raw values to [-250, 250] right after tensor conversion.
# NOTE(review): ToTensor does not rescale float arrays, so after Normalize
# these values span roughly [-501, 499] instead of [-1, 1]; confirm this
# matches the preprocessing the model was trained with.
transform_hmi_fits = _build_hmi_transform(ClampTransform(-250, 250))
def _read_seed_tensor(path):
    """Load a ``.jp2`` or ``.fits`` file as a ``(1, 1, 256, 256)`` tensor on ``device``.

    Raises:
        gr.Error: if the extension is not supported, or a FITS file holds
            no HDU with data. Gradio shows the message to the user.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == '.jp2':
        # The context manager releases the file handle Image.open keeps open.
        with Image.open(path) as jp2_img:
            tensor = transform_hmi_jp2(jp2_img)
    elif ext == '.fits':
        with fits.open(path) as hdul:
            # Use the primary HDU when it has data (as before); otherwise fall
            # back to the first extension that does.
            data = next((hdu.data for hdu in hdul if hdu.data is not None), None)
            if data is None:
                raise gr.Error(f'No image data found in {os.path.basename(path)}')
            # FITS arrays are usually big-endian, which torch.from_numpy (used
            # by ToTensor) rejects. Converting to native float32 fixes that and
            # matches the float32 tensors the JP2 path produces. The conversion
            # happens while the file is still open (data may be memory-mapped).
            # NOTE(review): NaNs in the data would propagate through the clamp
            # and into the model input; confirm whether they need replacing.
            data = np.asarray(data, dtype=np.float32)
            tensor = transform_hmi_fits(data)
    else:
        raise gr.Error(f'Format {ext} not supported')
    return tensor.reshape(1, 1, 256, 256).to(device)


def generate_image(seed_image):
    """Run the diffusion model on an uploaded magnetogram file.

    Args:
        seed_image: value from the ``gr.File`` component. This is a file path
            string, an object exposing the path as ``.name`` (older Gradio
            versions), or ``None`` when nothing was uploaded.

    Returns:
        ``(input_image, predicted_image)`` as PIL images. The input is the
        preprocessed tensor mapped back to 8-bit grayscale. Both images are
        flipped top-to-bottom to undo the vertical flip applied during
        preprocessing.

    Raises:
        gr.Error: when no file was uploaded, or the file cannot be read.
            Gradio shows the message instead of failing with a traceback.
    """
    if seed_image is None:
        raise gr.Error('Please upload a .jp2 or .fits file first')
    path = getattr(seed_image, 'name', seed_image)
    input_tensor = _read_seed_tensor(path)

    generated_image = diffusion.sample(model, y=input_tensor, labels=None, n=1)

    # Map the model-space input from [-1, 1] to [0, 1], then to uint8 pixels.
    inp_img = (input_tensor.clamp(-1, 1) + 1) / 2
    inp_img = (inp_img * 255).type(torch.uint8)
    inp = Image.fromarray(np.squeeze(inp_img.cpu().numpy()))
    inp = inp.transpose(Image.FLIP_TOP_BOTTOM)

    # First (and only, n=1) sample, rearranged to H x W x C and squeezed to 2-D.
    # NOTE(review): assumes diffusion.sample returns uint8 pixel values;
    # Image.fromarray chooses the image mode from the array dtype.
    img = generated_image[0].reshape(1, 256, 256).permute(1, 2, 0)
    v = Image.fromarray(np.squeeze(img.cpu().numpy()))
    v = v.transpose(Image.FLIP_TOP_BOTTOM)
    return inp, v
# Create the Gradio interface. It has an upload slot, the preprocessed input,
# the model prediction, and Clear / Generate buttons.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_image = gr.File(label='Input Image')
        output_image1 = gr.Image(label='Input LoS Magnetogram', type='pil', interactive=False)
        output_image2 = gr.Image(label='Predicted LoS Magnetogram in 24 hours', type='pil', interactive=False)
    # Buttons get their own Row so they sit beneath the upload/image row.
    with gr.Row():
        clear_button = gr.Button('Clear')
        process_button = gr.Button('Generate')

    # Generate: run the model on the uploaded file and show both images.
    process_button.click(
        fn=generate_image,
        inputs=input_image,
        outputs=[output_image1, output_image2],
    )
    # Clear: reset the upload and both image panels, so a stale prediction
    # is not left on screen next to an empty input.
    clear_button.click(
        fn=lambda: (None, None, None),
        inputs=None,
        outputs=[input_image, output_image1, output_image2],
    )
demo.launch()