fffiloni commited on
Commit
aafe80a
1 Parent(s): a4f8a15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -8,6 +8,14 @@ from PIL import Image
8
  from sam2.build_sam import build_sam2
9
  from sam2.sam2_image_predictor import SAM2ImagePredictor
10
 
 
 
 
 
 
 
 
 
11
  def show_mask(mask, ax, random_color=False, borders = True):
12
  if random_color:
13
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
 
8
  from sam2.build_sam import build_sam2
9
  from sam2.sam2_image_predictor import SAM2ImagePredictor
10
 
11
+ # use bfloat16 for the entire notebook
12
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
13
+
14
+ if torch.cuda.get_device_properties(0).major >= 8:
15
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
16
+ torch.backends.cuda.matmul.allow_tf32 = True
17
+ torch.backends.cudnn.allow_tf32 = True
18
+
19
  def show_mask(mask, ax, random_color=False, borders = True):
20
  if random_color:
21
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)