Update app.py
Browse files
app.py
CHANGED
@@ -17,7 +17,7 @@ model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, true_img_
|
|
17 |
ckpt = torch.load('ema_ckpt_cond.pt')
|
18 |
model.load_state_dict(ckpt)
|
19 |
|
20 |
-
diffusion = Diffusion_cond(
|
21 |
model.eval()
|
22 |
|
23 |
transform_hmi = transforms.Compose([
|
@@ -30,8 +30,13 @@ transform_hmi = transforms.Compose([
|
|
30 |
def generate_image(seed_image):
|
31 |
seed_image_tensor = transform_hmi(Image.open(seed_image)).reshape(1, 1, 256, 256).to(device)
|
32 |
generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
|
33 |
-
generated_image_pil = transforms.ToPILImage()(generated_image.squeeze().cpu())
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
# Create Gradio interface
|
37 |
iface = gr.Interface(
|
|
|
17 |
ckpt = torch.load('ema_ckpt_cond.pt')
|
18 |
model.load_state_dict(ckpt)
|
19 |
|
20 |
+
diffusion = Diffusion_cond(img_size=256, device=device)
|
21 |
model.eval()
|
22 |
|
23 |
transform_hmi = transforms.Compose([
|
|
|
30 |
def generate_image(seed_image):
    """Generate one image from the conditional diffusion model, conditioned
    on an input (HMI) seed image, and return it as a PIL image.

    Parameters
    ----------
    seed_image : str or file-like
        Anything accepted by ``PIL.Image.open`` — the conditioning image.

    Returns
    -------
    PIL.Image.Image
        The sampled image, flipped top-to-bottom before being returned.
    """
    # Preprocess: load, apply the module-level transform pipeline, and shape
    # into a (batch=1, channels=1, 256, 256) tensor on the target device.
    seed_image_tensor = transform_hmi(Image.open(seed_image)).reshape(1, 1, 256, 256).to(device)

    # Sample a single image conditioned on the seed tensor; no class label.
    generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)

    # Convert the first (only) sample from CHW to HWC, move to host memory,
    # and drop the singleton channel axis so PIL sees a 2-D array.
    # NOTE(review): assumes diffusion.sample returns a dtype PIL.Image.fromarray
    # can ingest directly (e.g. uint8 in [0, 255]) — confirm against
    # Diffusion_cond.sample before relying on this.
    img = np.squeeze(generated_image[0].permute(1, 2, 0).cpu().numpy())

    # Flip vertically to match the expected display orientation.
    v = Image.fromarray(img)
    return v.transpose(Image.FLIP_TOP_BOTTOM)
|
40 |
|
41 |
# Create Gradio interface
|
42 |
iface = gr.Interface(
|