danielsapit commited on
Commit
10907b9
1 Parent(s): bff44e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -74
app.py CHANGED
@@ -17,7 +17,7 @@ for model_path in ['fbcnn_gray.pth','fbcnn_color.pth']:
17
  r = requests.get(url, allow_redirects=True)
18
  open(model_path, 'wb').write(r.content)
19
 
20
- def inference(input_img, is_gray, input_quality, enable_zoom, zoom, x_shift, y_shift, state):
21
 
22
  if is_gray:
23
  n_channels = 1 # set 1 for grayscale image, set 3 for color image
@@ -46,59 +46,57 @@ def inference(input_img, is_gray, input_quality, enable_zoom, zoom, x_shift, y_s
46
  # ----------------------------------------
47
  # load model
48
  # ----------------------------------------
49
- if (not enable_zoom) or (state[1] is None):
50
- model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
51
- model.load_state_dict(torch.load(model_path), strict=True)
52
- model.eval()
53
- for k, v in model.named_parameters():
54
- v.requires_grad = False
55
- model = model.to(device)
56
-
57
- test_results = OrderedDict()
58
- test_results['psnr'] = []
59
- test_results['ssim'] = []
60
- test_results['psnrb'] = []
61
-
62
- # ------------------------------------
63
- # (1) img_L
64
- # ------------------------------------
65
-
66
- if n_channels == 1:
67
- open_cv_image = Image.fromarray(input_img)
68
- open_cv_image = ImageOps.grayscale(open_cv_image)
69
- open_cv_image = np.array(open_cv_image) # PIL to open cv image
70
- img = np.expand_dims(open_cv_image, axis=2) # HxWx1
71
- elif n_channels == 3:
72
- open_cv_image = np.array(input_img) # PIL to open cv image
73
- if open_cv_image.ndim == 2:
74
- open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB) # GGG
75
- else:
76
- open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB) # RGB
77
-
78
- img_L = util.uint2tensor4(open_cv_image)
79
- img_L = img_L.to(device)
80
-
81
- # ------------------------------------
82
- # (2) img_E
83
- # ------------------------------------
84
-
85
- img_E,QF = model(img_L)
86
- QF = 1- QF
87
- img_E = util.tensor2single(img_E)
88
- img_E = util.single2uint(img_E)
89
-
90
- qf_input = torch.tensor([[1-input_quality/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-input_quality/100]])
91
- img_E,QF = model(img_L, qf_input)
92
- QF = 1- QF
93
- img_E = util.tensor2single(img_E)
94
- img_E = util.single2uint(img_E)
95
-
96
- if img_E.ndim == 3:
97
- img_E = img_E[:, :, [2, 1, 0]]
98
-
99
- print("--inference finished")
100
- if (state[1] is not None) and enable_zoom:
101
- img_E = state[1]
102
  out_img = Image.fromarray(img_E)
103
  out_img_w, out_img_h = out_img.size # output image size
104
  zoom = zoom/100
@@ -107,46 +105,37 @@ def inference(input_img, is_gray, input_quality, enable_zoom, zoom, x_shift, y_s
107
  zoom_w, zoom_h = out_img_w*zoom, out_img_h*zoom
108
  zoom_left, zoom_right = int((out_img_w - zoom_w)*x_shift), int(zoom_w + (out_img_w - zoom_w)*x_shift)
109
  zoom_top, zoom_bottom = int((out_img_h - zoom_h)*y_shift), int(zoom_h + (out_img_h - zoom_h)*y_shift)
110
- if (state[0] is None) or not enable_zoom:
111
- in_img = Image.fromarray(input_img)
112
- state[0] = input_img
113
- else:
114
- in_img = Image.fromarray(state[0])
115
  in_img = in_img.crop((zoom_left, zoom_top, zoom_right, zoom_bottom))
116
  in_img = in_img.resize((int(zoom_w/zoom), int(zoom_h/zoom)), Image.NEAREST)
117
  out_img = out_img.crop((zoom_left, zoom_top, zoom_right, zoom_bottom))
118
  out_img = out_img.resize((int(zoom_w/zoom), int(zoom_h/zoom)), Image.NEAREST)
119
 
120
- return img_E, in_img, out_img, [state[0],img_E]
121
 
122
  gr.Interface(
123
  fn = inference,
124
  inputs = [gr.inputs.Image(label="Input Image"),
125
  gr.inputs.Checkbox(label="Grayscale (Check this if your image is grayscale)"),
126
  gr.inputs.Slider(minimum=1, maximum=100, step=1, label="Intensity (Higher = stronger JPEG artifact removal)"),
127
- gr.inputs.Checkbox(default=False, label="Edit Zoom preview (This is optional. "
128
- "After the image result is loaded, check this to edit zoom parameters "
129
- "so that the input image will not be processed when the submit button is pressed.)"),
130
  gr.inputs.Slider(minimum=10, maximum=100, step=1, default=50, label="Zoom Image "
131
  "(Use this to see the image quality up close. "
132
  "100 = original size)"),
133
  gr.inputs.Slider(minimum=0, maximum=100, step=1, label="Zoom preview horizontal shift "
134
  "(Increase to shift to the right)"),
135
  gr.inputs.Slider(minimum=0, maximum=100, step=1, label="Zoom preview vertical shift "
136
- "(Increase to shift downwards)"),
137
- gr.inputs.State(default=[None,None], label="\t")
138
  ],
139
  outputs = [gr.outputs.Image(label="Result"),
140
  gr.outputs.Image(label="Before:"),
141
- gr.outputs.Image(label="After:"),
142
- "state"],
143
- examples = [["doraemon.jpg",False,60,False,42,50,50],
144
- ["tomandjerry.jpg",False,60,False,40,57,44],
145
- ["somepanda.jpg",True,100,False,30,8,24],
146
- ["cemetry.jpg",False,70,False,20,76,62],
147
- ["michelangelo_david.jpg",True,30,False,12,53,27],
148
- ["elon_musk.jpg",False,45,False,15,33,30],
149
- ["text.jpg",True,70,False,50,11,29]],
150
  title = "JPEG Artifacts Removal [FBCNN]",
151
  description = "Gradio Demo for JPEG Artifacts Removal. To use it, simply upload your image, "
152
  "or click one of the examples to load them. Check out the paper and the original GitHub repo at the link below. "
 
17
  r = requests.get(url, allow_redirects=True)
18
  open(model_path, 'wb').write(r.content)
19
 
20
+ def inference(input_img, is_gray, input_quality, zoom, x_shift, y_shift):
21
 
22
  if is_gray:
23
  n_channels = 1 # set 1 for grayscale image, set 3 for color image
 
46
  # ----------------------------------------
47
  # load model
48
  # ----------------------------------------
49
+ model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
50
+ model.load_state_dict(torch.load(model_path), strict=True)
51
+ model.eval()
52
+ for k, v in model.named_parameters():
53
+ v.requires_grad = False
54
+ model = model.to(device)
55
+
56
+ test_results = OrderedDict()
57
+ test_results['psnr'] = []
58
+ test_results['ssim'] = []
59
+ test_results['psnrb'] = []
60
+
61
+ # ------------------------------------
62
+ # (1) img_L
63
+ # ------------------------------------
64
+
65
+ if n_channels == 1:
66
+ open_cv_image = Image.fromarray(input_img)
67
+ open_cv_image = ImageOps.grayscale(open_cv_image)
68
+ open_cv_image = np.array(open_cv_image) # PIL to open cv image
69
+ img = np.expand_dims(open_cv_image, axis=2) # HxWx1
70
+ elif n_channels == 3:
71
+ open_cv_image = np.array(input_img) # PIL to open cv image
72
+ if open_cv_image.ndim == 2:
73
+ open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB) # GGG
74
+ else:
75
+ open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB) # RGB
76
+
77
+ img_L = util.uint2tensor4(open_cv_image)
78
+ img_L = img_L.to(device)
79
+
80
+ # ------------------------------------
81
+ # (2) img_E
82
+ # ------------------------------------
83
+
84
+ img_E,QF = model(img_L)
85
+ QF = 1- QF
86
+ img_E = util.tensor2single(img_E)
87
+ img_E = util.single2uint(img_E)
88
+
89
+ qf_input = torch.tensor([[1-input_quality/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-input_quality/100]])
90
+ img_E,QF = model(img_L, qf_input)
91
+ QF = 1- QF
92
+ img_E = util.tensor2single(img_E)
93
+ img_E = util.single2uint(img_E)
94
+
95
+ if img_E.ndim == 3:
96
+ img_E = img_E[:, :, [2, 1, 0]]
97
+
98
+ print("--inference finished")
99
+
 
 
100
  out_img = Image.fromarray(img_E)
101
  out_img_w, out_img_h = out_img.size # output image size
102
  zoom = zoom/100
 
105
  zoom_w, zoom_h = out_img_w*zoom, out_img_h*zoom
106
  zoom_left, zoom_right = int((out_img_w - zoom_w)*x_shift), int(zoom_w + (out_img_w - zoom_w)*x_shift)
107
  zoom_top, zoom_bottom = int((out_img_h - zoom_h)*y_shift), int(zoom_h + (out_img_h - zoom_h)*y_shift)
108
+ in_img = Image.fromarray(input_img)
 
 
 
 
109
  in_img = in_img.crop((zoom_left, zoom_top, zoom_right, zoom_bottom))
110
  in_img = in_img.resize((int(zoom_w/zoom), int(zoom_h/zoom)), Image.NEAREST)
111
  out_img = out_img.crop((zoom_left, zoom_top, zoom_right, zoom_bottom))
112
  out_img = out_img.resize((int(zoom_w/zoom), int(zoom_h/zoom)), Image.NEAREST)
113
 
114
+ return img_E, in_img, out_img
115
 
116
  gr.Interface(
117
  fn = inference,
118
  inputs = [gr.inputs.Image(label="Input Image"),
119
  gr.inputs.Checkbox(label="Grayscale (Check this if your image is grayscale)"),
120
  gr.inputs.Slider(minimum=1, maximum=100, step=1, label="Intensity (Higher = stronger JPEG artifact removal)"),
 
 
 
121
  gr.inputs.Slider(minimum=10, maximum=100, step=1, default=50, label="Zoom Image "
122
  "(Use this to see the image quality up close. "
123
  "100 = original size)"),
124
  gr.inputs.Slider(minimum=0, maximum=100, step=1, label="Zoom preview horizontal shift "
125
  "(Increase to shift to the right)"),
126
  gr.inputs.Slider(minimum=0, maximum=100, step=1, label="Zoom preview vertical shift "
127
+ "(Increase to shift downwards)")
 
128
  ],
129
  outputs = [gr.outputs.Image(label="Result"),
130
  gr.outputs.Image(label="Before:"),
131
+ gr.outputs.Image(label="After:")],
132
+ examples = [["doraemon.jpg",False,60,42,50,50],
133
+ ["tomandjerry.jpg",False,60,40,57,44],
134
+ ["somepanda.jpg",True,100,30,8,24],
135
+ ["cemetry.jpg",False,70,20,76,62],
136
+ ["michelangelo_david.jpg",True,30,12,53,27],
137
+ ["elon_musk.jpg",False,45,15,33,30],
138
+ ["text.jpg",True,70,50,11,29]],
 
139
  title = "JPEG Artifacts Removal [FBCNN]",
140
  description = "Gradio Demo for JPEG Artifacts Removal. To use it, simply upload your image, "
141
  "or click one of the examples to load them. Check out the paper and the original GitHub repo at the link below. "