justin-zk committed on
Commit
17b7b8e
1 Parent(s): c2dade9
Files changed (3)
  1. app.py +553 -0
  2. requirements.txt +7 -0
  3. show.py +28 -0
app.py ADDED
@@ -0,0 +1,553 @@
+ # --------------------------------------------------------
+ # PersonalizeSAM -- Personalize Segment Anything Model with One Shot
+ # Licensed under The MIT License [see LICENSE for details]
+ # --------------------------------------------------------
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import gradio as gr
+ import numpy as np
+ from torch.nn import functional as F
+
+ from show import *
+ from per_segment_anything import sam_model_registry, SamPredictor
+
+
+ class ImageMask(gr.components.Image):
+     """
+     Sets: source="upload", tool="sketch"
+     """
+
+     is_template = True
+
+     def __init__(self, **kwargs):
+         super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
+
+     def preprocess(self, x):
+         return super().preprocess(x)
+
+
+ class Mask_Weights(nn.Module):
+     """Learnable weights for fusing SAM's three mask scales (the first scale gets 1 - the sum of these two)."""
+     def __init__(self):
+         super().__init__()
+         self.weights = nn.Parameter(torch.ones(2, 1, requires_grad=True) / 3)
+
+
+ def point_selection(mask_sim, topk=1):
+     # Top-1 point selection (coordinates are returned as (x, y))
+     w, h = mask_sim.shape
+     topk_xy = mask_sim.flatten(0).topk(topk)[1]
+     topk_x = (topk_xy // h).unsqueeze(0)
+     topk_y = (topk_xy - topk_x * h)
+     topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0)
+     topk_label = np.array([1] * topk)
+     topk_xy = topk_xy.cpu().numpy()
+
+     # Top-last point selection
+     last_xy = mask_sim.flatten(0).topk(topk, largest=False)[1]
+     last_x = (last_xy // h).unsqueeze(0)
+     last_y = (last_xy - last_x * h)
+     last_xy = torch.cat((last_y, last_x), dim=0).permute(1, 0)
+     last_label = np.array([0] * topk)
+     last_xy = last_xy.cpu().numpy()
+
+     return topk_xy, topk_label, last_xy, last_label
+
+
+ def calculate_dice_loss(inputs, targets, num_masks=1):
+     """
+     Compute the DICE loss, similar to generalized IOU for masks
+     Args:
+         inputs: A float tensor of arbitrary shape.
+                 The predictions for each example.
+         targets: A float tensor with the same shape as inputs. Stores the binary
+                  classification label for each element in inputs
+                  (0 for the negative class and 1 for the positive class).
+     """
+     inputs = inputs.sigmoid()
+     inputs = inputs.flatten(1)
+     numerator = 2 * (inputs * targets).sum(-1)
+     denominator = inputs.sum(-1) + targets.sum(-1)
+     loss = 1 - (numerator + 1) / (denominator + 1)
+     return loss.sum() / num_masks
+
+
+ def calculate_sigmoid_focal_loss(inputs, targets, num_masks=1, alpha: float = 0.25, gamma: float = 2):
+     """
+     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+     Args:
+         inputs: A float tensor of arbitrary shape.
+                 The predictions for each example.
+         targets: A float tensor with the same shape as inputs. Stores the binary
+                  classification label for each element in inputs
+                  (0 for the negative class and 1 for the positive class).
+         alpha: (optional) Weighting factor in range (0,1) to balance
+                positive vs negative examples. Default = 0.25.
+         gamma: Exponent of the modulating factor (1 - p_t) to
+                balance easy vs hard examples.
+     Returns:
+         Loss tensor
+     """
+     prob = inputs.sigmoid()
+     ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+     p_t = prob * targets + (1 - prob) * (1 - targets)
+     loss = ce_loss * ((1 - p_t) ** gamma)
+
+     if alpha >= 0:
+         alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+         loss = alpha_t * loss
+
+     return loss.mean(1).sum() / num_masks
+
+
+ def inference(ic_image, ic_mask, image1, image2):
+     # in context image and mask
+     ic_image = np.array(ic_image.convert("RGB"))
+     ic_mask = np.array(ic_mask.convert("RGB"))
+
+     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
+     sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+     # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
+     predictor = SamPredictor(sam)
+
+     # Image features encoding
+     ref_mask = predictor.set_image(ic_image, ic_mask)
+     ref_feat = predictor.features.squeeze().permute(1, 2, 0)
+
+     ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
+     ref_mask = ref_mask.squeeze()[0]
+
+     # Target feature extraction
+     print("======> Obtain Location Prior")
+     target_feat = ref_feat[ref_mask > 0]
+     target_embedding = target_feat.mean(0).unsqueeze(0)
+     target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
+     target_embedding = target_embedding.unsqueeze(0)
+
+     output_image = []
+
+     for test_image in [image1, image2]:
+         print("======> Testing Image")
+         test_image = np.array(test_image.convert("RGB"))
+
+         # Image feature encoding
+         predictor.set_image(test_image)
+         test_feat = predictor.features.squeeze()
+
+         # Cosine similarity
+         C, h, w = test_feat.shape
+         test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
+         test_feat = test_feat.reshape(C, h * w)
+         sim = target_feat @ test_feat
+
+         sim = sim.reshape(1, 1, h, w)
+         sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
+         sim = predictor.model.postprocess_masks(
+             sim,
+             input_size=predictor.input_size,
+             original_size=predictor.original_size).squeeze()
+
+         # Positive-negative location prior
+         topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
+         topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
+         topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)
+
+         # Obtain the target guidance for cross-attention layers
+         sim = (sim - sim.mean()) / torch.std(sim)
+         sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
+         attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)
+
+         # First-step prediction
+         masks, scores, logits, _ = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             multimask_output=False,
+             attn_sim=attn_sim,  # Target-guided Attention
+             target_embedding=target_embedding  # Target-semantic Prompting
+         )
+         best_idx = 0
+
+         # Cascaded Post-refinement-1
+         masks, scores, logits, _ = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             mask_input=logits[best_idx: best_idx + 1, :, :],
+             multimask_output=True)
+         best_idx = np.argmax(scores)
+
+         # Cascaded Post-refinement-2
+         y, x = np.nonzero(masks[best_idx])
+         x_min = x.min()
+         x_max = x.max()
+         y_min = y.min()
+         y_max = y.max()
+         input_box = np.array([x_min, y_min, x_max, y_max])
+         masks, scores, logits, _ = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             box=input_box[None, :],
+             mask_input=logits[best_idx: best_idx + 1, :, :],
+             multimask_output=True)
+         best_idx = np.argmax(scores)
+
+         final_mask = masks[best_idx]
+         mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
+         mask_colors[final_mask, :] = np.array([[128, 0, 0]])
+         output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))
+
+     return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
+
+
+ def inference_scribble(image, image1, image2):
+     # in context image and mask
+     ic_image = image["image"]
+     ic_mask = image["mask"]
+     ic_image = np.array(ic_image.convert("RGB"))
+     ic_mask = np.array(ic_mask.convert("RGB"))
+
+     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
+     sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+     # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
+     predictor = SamPredictor(sam)
+
+     # Image features encoding
+     ref_mask = predictor.set_image(ic_image, ic_mask)
+     ref_feat = predictor.features.squeeze().permute(1, 2, 0)
+
+     ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
+     ref_mask = ref_mask.squeeze()[0]
+
+     # Target feature extraction
+     print("======> Obtain Location Prior")
+     target_feat = ref_feat[ref_mask > 0]
+     target_embedding = target_feat.mean(0).unsqueeze(0)
+     target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
+     target_embedding = target_embedding.unsqueeze(0)
+
+     output_image = []
+
+     for test_image in [image1, image2]:
+         print("======> Testing Image")
+         test_image = np.array(test_image.convert("RGB"))
+
+         # Image feature encoding
+         predictor.set_image(test_image)
+         test_feat = predictor.features.squeeze()
+
+         # Cosine similarity
+         C, h, w = test_feat.shape
+         test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
+         test_feat = test_feat.reshape(C, h * w)
+         sim = target_feat @ test_feat
+
+         sim = sim.reshape(1, 1, h, w)
+         sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
+         sim = predictor.model.postprocess_masks(
+             sim,
+             input_size=predictor.input_size,
+             original_size=predictor.original_size).squeeze()
+
+         # Positive-negative location prior
+         topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
+         topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
+         topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)
+
+         # Obtain the target guidance for cross-attention layers
+         sim = (sim - sim.mean()) / torch.std(sim)
+         sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
+         attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)
+
+         # First-step prediction
+         masks, scores, logits, _ = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             multimask_output=False,
+             attn_sim=attn_sim,  # Target-guided Attention
+             target_embedding=target_embedding  # Target-semantic Prompting
+         )
+         best_idx = 0
+
+         # Cascaded Post-refinement-1
+         masks, scores, logits, _ = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             mask_input=logits[best_idx: best_idx + 1, :, :],
+             multimask_output=True)
+         best_idx = np.argmax(scores)
+
+         # Cascaded Post-refinement-2
+         y, x = np.nonzero(masks[best_idx])
+         x_min = x.min()
+         x_max = x.max()
+         y_min = y.min()
+         y_max = y.max()
+         input_box = np.array([x_min, y_min, x_max, y_max])
+         masks, scores, logits, _ = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             box=input_box[None, :],
+             mask_input=logits[best_idx: best_idx + 1, :, :],
+             multimask_output=True)
+         best_idx = np.argmax(scores)
+
+         final_mask = masks[best_idx]
+         mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
+         mask_colors[final_mask, :] = np.array([[128, 0, 0]])
+         output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))
+
+     return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
+
+
+ def inference_finetune(ic_image, ic_mask, image1, image2):
+     # in context image and mask
+     ic_image = np.array(ic_image.convert("RGB"))
+     ic_mask = np.array(ic_mask.convert("RGB"))
+
+     gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
+     gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
+     # gt_mask = gt_mask.float().unsqueeze(0).flatten(1)
+
+     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
+     sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+     # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
+     for name, param in sam.named_parameters():
+         param.requires_grad = False
+     predictor = SamPredictor(sam)
+
+     print("======> Obtain Self Location Prior")
+     # Image features encoding
+     ref_mask = predictor.set_image(ic_image, ic_mask)
+     ref_feat = predictor.features.squeeze().permute(1, 2, 0)
+
+     ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
+     ref_mask = ref_mask.squeeze()[0]
+
+     # Target feature extraction
+     target_feat = ref_feat[ref_mask > 0]
+     target_feat_mean = target_feat.mean(0)
+     target_feat_max = torch.max(target_feat, dim=0)[0]
+     target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0)
+
+     # Cosine similarity
+     h, w, C = ref_feat.shape
+     target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)
+     ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
+     ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
+     sim = target_feat @ ref_feat
+
+     sim = sim.reshape(1, 1, h, w)
+     sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
+     sim = predictor.model.postprocess_masks(
+         sim,
+         input_size=predictor.input_size,
+         original_size=predictor.original_size).squeeze()
+
+     # Positive location prior
+     topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
+
+     print('======> Start Training')
+     # Learnable mask weights
+     mask_weights = Mask_Weights().cuda()
+     # mask_weights = Mask_Weights()
+     mask_weights.train()
+     train_epoch = 1000
+     optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-3, eps=1e-4)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_epoch)
+
+     for train_idx in range(train_epoch):
+         # Run the decoder
+         masks, scores, logits, logits_high = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             multimask_output=True)
+         logits_high = logits_high.flatten(1)
+
+         # Weighted sum three-scale masks
+         weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
+         logits_high = logits_high * weights
+         logits_high = logits_high.sum(0).unsqueeze(0)
+
+         dice_loss = calculate_dice_loss(logits_high, gt_mask)
+         focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask)
+         loss = dice_loss + focal_loss
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         scheduler.step()
+
+         if train_idx % 10 == 0:
+             print('Train Epoch: {:} / {:}'.format(train_idx, train_epoch))
+             current_lr = scheduler.get_last_lr()[0]
+             print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item()))
+
+
+     mask_weights.eval()
+     weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
+     weights_np = weights.detach().cpu().numpy()
+     print('======> Mask weights:\n', weights_np)
+
+     print('======> Start Testing')
+     output_image = []
+
+     for test_image in [image1, image2]:
+         test_image = np.array(test_image.convert("RGB"))
+
+         # Image feature encoding
+         predictor.set_image(test_image)
+         test_feat = predictor.features.squeeze()
+
+         # Cosine similarity
+         C, h, w = test_feat.shape
+         test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
+         test_feat = test_feat.reshape(C, h * w)
+         sim = target_feat @ test_feat
+
+         sim = sim.reshape(1, 1, h, w)
+         sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
+         sim = predictor.model.postprocess_masks(
+             sim,
+             input_size=predictor.input_size,
+             original_size=predictor.original_size).squeeze()
+
+         # Positive location prior
+         topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
+
+         # First-step prediction
+         masks, scores, logits, logits_high = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             multimask_output=True)
+
+         # Weighted sum three-scale masks
+         logits_high = logits_high * weights.unsqueeze(-1)
+         logit_high = logits_high.sum(0)
+         mask = (logit_high > 0).detach().cpu().numpy()
+
+         logits = logits * weights_np[..., None]
+         logit = logits.sum(0)
+
+         # Cascaded Post-refinement-1
+         y, x = np.nonzero(mask)
+         x_min = x.min()
+         x_max = x.max()
+         y_min = y.min()
+         y_max = y.max()
+         input_box = np.array([x_min, y_min, x_max, y_max])
+         masks, scores, logits, _ = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             box=input_box[None, :],
+             mask_input=logit[None, :, :],
+             multimask_output=True)
+         best_idx = np.argmax(scores)
+
+         # Cascaded Post-refinement-2
+         y, x = np.nonzero(masks[best_idx])
+         x_min = x.min()
+         x_max = x.max()
+         y_min = y.min()
+         y_max = y.max()
+         input_box = np.array([x_min, y_min, x_max, y_max])
+         masks, scores, logits, _ = predictor.predict(
+             point_coords=topk_xy,
+             point_labels=topk_label,
+             box=input_box[None, :],
+             mask_input=logits[best_idx: best_idx + 1, :, :],
+             multimask_output=True)
+         best_idx = np.argmax(scores)
+
+         final_mask = masks[best_idx]
+         mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
+         mask_colors[final_mask, :] = np.array([[128, 0, 0]])
+         output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))
+
+     return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
+
+
+ description = """
+ <div style="text-align: center; font-weight: bold;">
+ <span style="font-size: 18px" id="paper-info">
+ [<a href="https://github.com/ZrrSkywalker/Personalize-SAM" target="_blank">GitHub</a>]
+ [<a href="https://arxiv.org/pdf/2305.03048.pdf" target="_blank">Paper</a>]
+ </span>
+ </div>
+ """
+
+ main = gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Image(type="pil", label="in context image"),
+         gr.Image(type="pil", label="in context mask"),
+         gr.Image(type="pil", label="test image1"),
+         gr.Image(type="pil", label="test image2"),
+     ],
+     outputs=[
+         gr.outputs.Image(type="pil", label="output image1"),
+         gr.outputs.Image(type="pil", label="output image2"),
+     ],
+     allow_flagging="never",
+     title="Personalize Segment Anything Model with 1 Shot",
+     description=description,
+     examples=[
+         ["./examples/cat_00.jpg", "./examples/cat_00.png", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
+         ["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_00.png", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
+         ["./examples/duck_toy_00.jpg", "./examples/duck_toy_00.png", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
+     ]
+ )
+
+ main_scribble = gr.Interface(
+     fn=inference_scribble,
+     inputs=[
+         gr.ImageMask(label="[Stroke] Draw on Image", type="pil"),
+         gr.Image(type="pil", label="test image1"),
+         gr.Image(type="pil", label="test image2"),
+     ],
+     outputs=[
+         gr.outputs.Image(type="pil", label="output image1"),
+         gr.outputs.Image(type="pil", label="output image2"),
+     ],
+     allow_flagging="never",
+     title="Personalize Segment Anything Model with 1 Shot",
+     description=description,
+     examples=[
+         ["./examples/cat_00.jpg", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
+         ["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
+         ["./examples/duck_toy_00.jpg", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
+     ]
+ )
+
+ main_finetune = gr.Interface(
+     fn=inference_finetune,
+     inputs=[
+         gr.Image(type="pil", label="in context image"),
+         gr.Image(type="pil", label="in context mask"),
+         gr.Image(type="pil", label="test image1"),
+         gr.Image(type="pil", label="test image2"),
+     ],
+     outputs=[
+         gr.components.Image(type="pil", label="output image1"),
+         gr.components.Image(type="pil", label="output image2"),
+     ],
+     allow_flagging="never",
+     title="Personalize Segment Anything Model with 1 Shot",
+     description=description,
+     examples=[
+         ["./examples/cat_00.jpg", "./examples/cat_00.png", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
+         ["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_00.png", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
+         ["./examples/duck_toy_00.jpg", "./examples/duck_toy_00.png", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
+     ]
+ )
+
+
+ demo = gr.Blocks()
+ with demo:
+     gr.TabbedInterface(
+         [main, main_scribble, main_finetune],
+         ["Personalize-SAM", "Personalize-SAM-Scribble", "Personalize-SAM-F"],
+     )
+
+ demo.launch(share=True)
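Note: app.py loads the SAM ViT-H checkpoint from ./sam_vit_h_4b8939.pth and moves the model to CUDA, so the Space needs that checkpoint on disk and a GPU. A minimal pre-flight check along these lines (a sketch for local use, not part of the committed file) can catch a missing download before Gradio starts:

    import os
    import torch

    SAM_CKPT = "sam_vit_h_4b8939.pth"  # checkpoint filename expected by app.py
    if not os.path.isfile(SAM_CKPT):
        raise FileNotFoundError(f"Download the SAM ViT-H weights to ./{SAM_CKPT} before launching app.py.")
    if not torch.cuda.is_available():
        raise RuntimeError("app.py calls .cuda() on the SAM model, so a CUDA-capable GPU is required.")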
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ matplotlib
+ tqdm
+ numpy
+ opencv-python
+ torch
+ gradio
+ Pillow
show.py ADDED
@@ -0,0 +1,28 @@
+ import numpy as np
+ import torch
+ import matplotlib.pyplot as plt
+ import cv2
+
+
+
+ def show_mask(mask, ax, random_color=False):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         color = np.array([30/255, 144/255, 255/255, 0.4])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+
+ def show_points(coords, labels, ax, marker_size=375):
+     pos_points = coords[labels==1]
+     neg_points = coords[labels==0]
+     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
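The show.py helpers draw onto a matplotlib axis and are not called directly in app.py (it only does `from show import *`); a minimal usage sketch, assuming a binary mask, (x, y) point coordinates with 1/0 labels, and an xyxy box, might look like:

    import numpy as np
    import matplotlib.pyplot as plt
    from show import show_mask, show_points, show_box

    image = np.zeros((256, 256, 3), dtype=np.uint8)   # placeholder test image
    mask = np.zeros((256, 256), dtype=bool)
    mask[64:192, 64:192] = True                       # toy segmentation mask
    coords = np.array([[128, 128], [10, 10]])         # (x, y) prompt points
    labels = np.array([1, 0])                         # 1 = positive, 0 = negative
    box = np.array([64, 64, 192, 192])                # x_min, y_min, x_max, y_max

    fig, ax = plt.subplots()
    ax.imshow(image)
    show_mask(mask, ax)
    show_points(coords, labels, ax)
    show_box(box, ax)
    plt.show()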