File size: 3,355 Bytes
0f1af34 be9674f 0f1af34 ff34208 05493b2 0f1af34 a5a784e 4eeda6c 0f1af34 0210cac 0f1af34 ff34208 0f1af34 ff34208 eb240c8 0210cac ff34208 0f1af34 cf7ac47 0f1af34 cf7ac47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
# Preprocessing
from modules import PaletteModelV2
from diffusion import Diffusion_cond
# HTML banner passed to gr.Markdown at the top of the page: a centered
# "MAG2MAG" heading that links to the project's GitHub repository, followed by
# a teaser image loaded from that repository.
DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">
<h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>
<img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">
</div>'''
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the network and move it onto the selected device.
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    device=device,
    true_img_size=64,
)
model = model.to(device)

# Restore the weights stored in 'ema_ckpt_cond.pt', remapping them onto `device`.
# NOTE(review): torch.load unpickles the file, so only ship a checkpoint that
# comes from a trusted source.
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)

# Sampler object used by generate_image to produce the prediction.
diffusion = Diffusion_cond(img_size=256, device=device)

# The app only runs inference, so put the model in evaluation mode.
model.eval()
# Preprocessing applied to every uploaded magnetogram before it reaches the
# model. Each step is named so the pipeline below reads as a plain list.
_hmi_to_tensor = transforms.ToTensor()
_hmi_resize = transforms.Resize((256, 256))
# p=1.0 means the vertical flip is applied on every call, not at random.
_hmi_flip = transforms.RandomVerticalFlip(p=1.0)
# (x - 0.5) / 0.5 on the single channel.
_hmi_normalize = transforms.Normalize(mean=(0.5,), std=(0.5,))

transform_hmi = transforms.Compose([
    _hmi_to_tensor,
    _hmi_resize,
    _hmi_flip,
    _hmi_normalize,
])
def generate_image(seed_image):
    """Generate the predicted magnetogram for an uploaded seed file.

    Args:
        seed_image: The value delivered by the ``gr.File`` input. Accepted
            forms are a path (``str`` / ``os.PathLike``) or an object exposing
            the path as ``.name`` (a temp-file wrapper). The file must have a
            ``.jp2`` or ``.fits`` extension (case-insensitive).

    Returns:
        A tuple ``(inp, v)`` of PIL images: the preprocessed input and the
        image sampled from the diffusion model, both flipped top-to-bottom.

    Raises:
        ValueError: If no file was supplied, the path cannot be determined,
            the extension is unsupported, or the FITS primary HDU has no data.
    """
    if seed_image is None:
        raise ValueError("No input file provided; upload a .jp2 or .fits file.")

    # Do not use getattr(..., 'name') on path objects: pathlib's `.name` is
    # only the final path component, not the full path.
    if isinstance(seed_image, (str, os.PathLike)):
        seed_path = os.fspath(seed_image)
    else:
        seed_path = getattr(seed_image, 'name', None)
        if seed_path is None:
            raise ValueError("Could not determine the path of the uploaded file.")

    # The original code stored the tensor under a different name than the one
    # used below, so every call ended in a NameError.
    seed_image_tensor = _load_seed_tensor(seed_path)

    # Inference only: no autograd graph is needed while sampling.
    with torch.no_grad():
        generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)

    inp = _to_flipped_image(seed_image_tensor)
    v = _to_flipped_image(generated_image[0])
    return inp, v


def _load_seed_tensor(path):
    """Read a .jp2 or .fits file into a (1, 1, 256, 256) tensor on `device`."""
    extension = os.path.splitext(path)[1].lower()

    if extension == '.jp2':
        # Context manager closes the underlying file once it has been decoded.
        with Image.open(path) as input_img:
            tensor = transform_hmi(input_img)
    elif extension == '.fits':
        # NOTE(review): `fits` is not imported anywhere in the visible file;
        # presumably `from astropy.io import fits` is intended - add it to the
        # imports, otherwise this branch raises NameError.
        with fits.open(path) as hdul:
            data = hdul[0].data
            if data is None:
                raise ValueError("The primary HDU of the FITS file contains no image data.")
            # FITS arrays are commonly big-endian, which torch.from_numpy
            # rejects. Copy into a native float32 array while the file is
            # still open.
            data = np.array(data, dtype=np.float32)
        tensor = transform_hmi(data)
    else:
        raise ValueError(
            f"Unsupported file type {extension!r}; expected a .jp2 or .fits file."
        )

    return tensor.reshape(1, 1, 256, 256).to(device)


def _to_flipped_image(tensor):
    """Convert a 256x256 single-channel tensor to a PIL image, flipped top-to-bottom."""
    array = tensor.detach().reshape(256, 256).cpu().numpy()
    # For the input this flip undoes the vertical flip applied by transform_hmi;
    # the sampled image is presumably in the same orientation - TODO confirm.
    # NOTE(review): a float array produces a mode 'F' image; confirm the Gradio
    # image output can display it.
    return Image.fromarray(array).transpose(Image.FLIP_TOP_BOTTOM)
# Assemble the Gradio page: banner, one row with the upload widget and the two
# result images, one row with the buttons, then the event wiring.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        input_image = gr.File(label='Input Image')
        output_image1 = gr.Image(
            label='Input LoS Magnetogram',
            type='pil',
            interactive=False,
        )
        output_image2 = gr.Image(
            label='Predicted LoS Magnetogram in 24 hours',
            type='pil',
            interactive=False,
        )

    # The buttons get their own Row.
    # NOTE(review): the original indentation was lost, and the original comment
    # spoke of a Row nested inside the main Row; confirm whether this Row
    # belongs inside the Row above.
    with gr.Row():
        clear_button = gr.Button('Clear')
        process_button = gr.Button('Generate')

    # 'Generate' runs the model on the uploaded file and fills both images.
    process_button.click(
        fn=generate_image,
        inputs=input_image,
        outputs=[output_image1, output_image2],
    )

    # 'Clear' resets only the upload widget by sending it None.
    clear_button.click(
        fn=lambda: None,
        inputs=None,
        outputs=input_image,
    )

demo.launch()
|