import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from astropy.io import fits
from PIL import Image

from modules import PaletteModelV2
from diffusion import Diffusion_cond

DESCRIPTION = '''

MAG2MAG

teaser
'''

# Check for GPU availability, else use CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256,
                       device=device, true_img_size=64).to(device)
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
model.eval()

diffusion = Diffusion_cond(img_size=256, device=device)


# Preprocessing: custom transform that clamps pixel values to a fixed range
class ClampTransform(object):
    def __init__(self, min_value=-250, max_value=250):
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, tensor):
        return torch.clamp(tensor, self.min_value, self.max_value)


# JP2 inputs arrive as PIL images; FITS inputs arrive as raw arrays and are
# clamped to [-250, 250] before normalization.
transform_hmi_jp2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

transform_hmi_fits = transforms.Compose([
    transforms.ToTensor(),
    ClampTransform(-250, 250),
    transforms.Resize((256, 256)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])


def generate_image(seed_image):
    _, file_ext = os.path.splitext(seed_image)
    if file_ext.lower() == '.jp2':
        input_img = Image.open(seed_image)
        input_img_pil = transform_hmi_jp2(input_img).reshape(1, 1, 256, 256).to(device)
    elif file_ext.lower() == '.fits':
        with fits.open(seed_image) as hdul:
            data = hdul[0].data
        input_img_pil = transform_hmi_fits(data).reshape(1, 1, 256, 256).to(device)
    else:
        # Fail early instead of falling through with an undefined input tensor
        raise gr.Error(f'Format {file_ext.lower()} not supported; please upload a .jp2 or .fits file')

    generated_image = diffusion.sample(model, y=input_img_pil, labels=None, n=1)

    # Clamp the normalized input to [-1, 1], shift/scale back to [0, 1],
    # then convert to the valid uint8 pixel range [0, 255]
    inp_img = (input_img_pil.clamp(-1, 1) + 1) / 2
    inp_img = (inp_img * 255).type(torch.uint8)
    inp_img = np.squeeze(inp_img.cpu().numpy())
    inp = Image.fromarray(inp_img)
    inp = inp.transpose(Image.FLIP_TOP_BOTTOM)  # undo the vertical flip applied in the transform

    # Permute the sample to height x width x channels before converting to PIL
    img = generated_image[0].reshape(1, 256, 256).permute(1, 2, 0)
    img = np.squeeze(img.cpu().numpy())
    v = Image.fromarray(img)
    v = v.transpose(Image.FLIP_TOP_BOTTOM)

    return inp, v


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_image = gr.File(label='Input Image')
        output_image1 = gr.Image(label='Input LoS Magnetogram', type='pil', interactive=False)
        output_image2 = gr.Image(label='Predicted LoS Magnetogram in 24 hours', type='pil', interactive=False)

    # Buttons are placed in a nested Row so they align directly under the images
    with gr.Row():
        clear_button = gr.Button('Clear')
        process_button = gr.Button('Generate')

    # Bind the Generate button to the sampling function
    process_button.click(
        fn=generate_image,
        inputs=input_image,
        outputs=[output_image1, output_image2]
    )

    # The Clear button resets the input file
    clear_button.click(
        fn=lambda: None,  # clears the input
        inputs=None,
        outputs=input_image
    )

demo.launch()
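
# Diffusion sampling can be slow per request, so enabling Gradio's built-in
# request queue before launch is often worthwhile. A minimal sketch using the
# standard gr.Blocks.queue() API; the max_size value is an illustrative
# assumption, not taken from the original app:
#
#     demo.queue(max_size=8)
#     demo.launch()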