fffiloni commited on
Commit
aab2a94
1 Parent(s): baba962

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -3
app.py CHANGED
@@ -1,5 +1,73 @@
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- gr.Interface.load(
4
- "spaces/jjeamin/ArcaneStyleTransfer", inputs=[gr.Image(source="webcam", type="filepath", label="Input"), gr.Radio(['True','False'], type="value", default='True', label='face align')], title="Remove your webcam background!"
5
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system("pip freeze")
3
+
4
+ import torch
5
+ import PIL
6
  import gradio as gr
7
+ import torch
8
+ from utils import align_face
9
+ from torchvision import transforms
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
+
14
+ image_size = 512
15
+ transform_size = 1024
16
+
17
+ means = [0.5, 0.5, 0.5]
18
+ stds = [0.5, 0.5, 0.5]
19
+
20
+ img_transforms = transforms.Compose([
21
+ transforms.ToTensor(),
22
+ transforms.Normalize(means, stds)])
23
+
24
+ model_path = hf_hub_download(repo_id="jjeamin/ArcaneStyleTransfer", filename="pytorch_model.bin")
25
+
26
+ if 'cuda' in device:
27
+ style_transfer = torch.jit.load(model_path).eval().cuda().half()
28
+ t_stds = torch.tensor(stds).cuda().half()[:,None,None]
29
+ t_means = torch.tensor(means).cuda().half()[:,None,None]
30
+ else:
31
+ style_transfer = torch.jit.load(model_path).eval().cpu()
32
+ t_stds = torch.tensor(stds).cpu()[:,None,None]
33
+ t_means = torch.tensor(means).cpu()[:,None,None]
34
+
35
+ def tensor2im(var):
36
+ return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
37
+
38
+ def proc_pil_img(input_image):
39
+ if 'cuda' in device:
40
+ transformed_image = img_transforms(input_image)[None,...].cuda().half()
41
+ else:
42
+ transformed_image = img_transforms(input_image)[None,...].cpu()
43
+
44
+ with torch.no_grad():
45
+ result_image = style_transfer(transformed_image)[0]
46
+ output_image = tensor2im(result_image)
47
+ output_image = output_image.detach().cpu().numpy().astype('uint8')
48
+ output_image = PIL.Image.fromarray(output_image)
49
+ return output_image
50
 
51
+ def process(im, is_align):
52
+ im = PIL.ImageOps.exif_transpose(im)
53
+
54
+ if is_align == 'True':
55
+ im = align_face(im, output_size=image_size, transform_size=transform_size)
56
+ else:
57
+ pass
58
+
59
+ res = proc_pil_img(im)
60
+
61
+ return res
62
+
63
+ gr.Interface(
64
+ process,
65
+ inputs=[gr.inputs.Image(source="webcam",type="pil", label="Input", shape=(image_size, image_size)), gr.inputs.Radio(['True','False'], type="value", default='True', label='face align')],
66
+ outputs=gr.outputs.Image(type="pil", label="Output"),
67
+ title="Arcane Style Transfer",
68
+ description="Gradio demo for Arcane Style Transfer",
69
+ article = "<p style='text-align: center'><a href='https://github.com/jjeamin/anime_style_transfer_pytorch' target='_blank'>Github Repo by jjeamin</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=jjeamin_arcane_st' alt='visitor badge'></center></p>",
70
+ enable_queue=True,
71
+ allow_flagging=False,
72
+ allow_screenshot=False
73
+ ).launch(enable_queue=True)