FantasticGNU commited on
Commit
9db944c
1 Parent(s): 7b352a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -20
app.py CHANGED
@@ -28,10 +28,7 @@ delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'
28
  model.load_state_dict(delta_ckpt, strict=False)
29
  delta_ckpt = torch.load(args['anomalygpt_ckpt_path'], map_location=torch.device('cpu'))
30
  model.load_state_dict(delta_ckpt, strict=False)
31
- model = model.eval().to(torch.bfloat16)
32
-
33
-
34
- output = None
35
 
36
  """Override Chatbot.postprocess"""
37
  def postprocess(self, y):
@@ -155,28 +152,19 @@ def predict(
155
  eroded_image = cv2.erode(image, kernel, iterations=1)
156
  cv2.imwrite('output.png', eroded_image)
157
 
158
- global output
159
  output = PILImage.open('output.png').convert('L')
160
 
161
 
162
- return chatbot, history, modality_cache
163
 
164
 
165
- def get_image():
166
- global output
167
- return output if output else "ffffff.png"
168
-
169
 
170
  def reset_user_input():
171
  return gr.update(value='')
172
 
173
- def reset_dialog():
174
- return [], []
175
 
176
  def reset_state():
177
- global output
178
- output = None
179
- return None, None, [], [], []
180
 
181
 
182
 
@@ -188,7 +176,7 @@ with gr.Blocks() as demo:
188
  with gr.Row(scale=3):
189
  image_path = gr.Image(type="filepath", label="Query Image", value=None)
190
  with gr.Row(scale=3):
191
- normal_img_path = gr.Image(type="filepath", label="Normal Image", value=None)
192
  with gr.Row():
193
  max_length = gr.Slider(0, 512, value=512, step=1.0, label="Maximum length", interactive=True)
194
  with gr.Row():
@@ -200,10 +188,10 @@ with gr.Blocks() as demo:
200
  with gr.Column(scale=3):
201
  with gr.Row():
202
  with gr.Column(scale=6):
203
- chatbot = gr.Chatbot().style(height=415)
204
  with gr.Column(scale=4):
205
  # gr.Image(output)
206
- image_output = gr.Image(value=get_image, label="Localization Output", every=1.0, shape=[224,224])
207
  with gr.Row():
208
  user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
209
  with gr.Row():
@@ -229,18 +217,21 @@ with gr.Blocks() as demo:
229
  ], [
230
  chatbot,
231
  history,
232
- modality_cache
 
233
  ],
234
  show_progress=True
235
  )
236
 
237
  submitBtn.click(reset_user_input, [], [user_input])
238
  emptyBtn.click(reset_state, outputs=[
 
239
  image_path,
240
  normal_img_path,
241
  chatbot,
242
  history,
243
- modality_cache
 
244
  ], show_progress=True)
245
 
246
 
 
28
  model.load_state_dict(delta_ckpt, strict=False)
29
  delta_ckpt = torch.load(args['anomalygpt_ckpt_path'], map_location=torch.device('cpu'))
30
  model.load_state_dict(delta_ckpt, strict=False)
31
+ model = model.eval().to(torch.bfloat16)#.half()#.cuda()
 
 
 
32
 
33
  """Override Chatbot.postprocess"""
34
  def postprocess(self, y):
 
152
  eroded_image = cv2.erode(image, kernel, iterations=1)
153
  cv2.imwrite('output.png', eroded_image)
154
 
 
155
  output = PILImage.open('output.png').convert('L')
156
 
157
 
158
+ return chatbot, history, modality_cache, output
159
 
160
 
 
 
 
 
161
 
162
  def reset_user_input():
163
  return gr.update(value='')
164
 
 
 
165
 
166
  def reset_state():
167
+ return gr.update(value=''), None, None, [], [], [], PILImage.open('ffffff.png')
 
 
168
 
169
 
170
 
 
176
  with gr.Row(scale=3):
177
  image_path = gr.Image(type="filepath", label="Query Image", value=None)
178
  with gr.Row(scale=3):
179
+ normal_img_path = gr.Image(type="filepath", label="Normal Image (optional)", value=None)
180
  with gr.Row():
181
  max_length = gr.Slider(0, 512, value=512, step=1.0, label="Maximum length", interactive=True)
182
  with gr.Row():
 
188
  with gr.Column(scale=3):
189
  with gr.Row():
190
  with gr.Column(scale=6):
191
+ chatbot = gr.Chatbot().style(height=430)
192
  with gr.Column(scale=4):
193
  # gr.Image(output)
194
+ image_output = gr.Image(interactive=False, label="Localization Output", every=1.0, shape=[224,224], type='pil',value=PILImage.open('ffffff.png'))
195
  with gr.Row():
196
  user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
197
  with gr.Row():
 
217
  ], [
218
  chatbot,
219
  history,
220
+ modality_cache,
221
+ image_output
222
  ],
223
  show_progress=True
224
  )
225
 
226
  submitBtn.click(reset_user_input, [], [user_input])
227
  emptyBtn.click(reset_state, outputs=[
228
+ user_input,
229
  image_path,
230
  normal_img_path,
231
  chatbot,
232
  history,
233
+ modality_cache,
234
+ image_output
235
  ], show_progress=True)
236
 
237