OriLib commited on
Commit
83045c9
1 Parent(s): 33c41df

Update MyPipe.py

Browse files
Files changed (1) hide show
  1. MyPipe.py +5 -3
MyPipe.py CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
9
  class RMBGPipe(Pipeline):
10
  def __init__(self,**kwargs):
11
  Pipeline.__init__(self,**kwargs)
12
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  self.model.to(self.device)
14
  self.model.eval()
15
 
@@ -39,6 +39,7 @@ class RMBGPipe(Pipeline):
39
  result = self.model(inputs.pop("image"))
40
  inputs["result"] = result
41
  return inputs
 
42
  def postprocess(self,inputs,return_mask:bool=False ):
43
  result = inputs.pop("result")
44
  orig_im_size = inputs.pop("orig_im_size")
@@ -48,7 +49,7 @@ class RMBGPipe(Pipeline):
48
  if return_mask ==True :
49
  return pil_im
50
  no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
51
- orig_image = Image.open(im_path)
52
  no_bg_image.paste(orig_image, mask=pil_im)
53
  return no_bg_image
54
 
@@ -59,10 +60,11 @@ class RMBGPipe(Pipeline):
59
  im = im[:, :, np.newaxis]
60
  # orig_im_size=im.shape[0:2]
61
  im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
62
- im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
63
  image = torch.divide(im_tensor,255.0)
64
  image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
65
  return image
 
66
  def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
67
  result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
68
  ma = torch.max(result)
 
9
  class RMBGPipe(Pipeline):
10
  def __init__(self,**kwargs):
11
  Pipeline.__init__(self,**kwargs)
12
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
  self.model.to(self.device)
14
  self.model.eval()
15
 
 
39
  result = self.model(inputs.pop("image"))
40
  inputs["result"] = result
41
  return inputs
42
+
43
  def postprocess(self,inputs,return_mask:bool=False ):
44
  result = inputs.pop("result")
45
  orig_im_size = inputs.pop("orig_im_size")
 
49
  if return_mask ==True :
50
  return pil_im
51
  no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
52
+ orig_image = Image.fromarray(io.imread(im_path))
53
  no_bg_image.paste(orig_image, mask=pil_im)
54
  return no_bg_image
55
 
 
60
  im = im[:, :, np.newaxis]
61
  # orig_im_size=im.shape[0:2]
62
  im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
63
+ im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
64
  image = torch.divide(im_tensor,255.0)
65
  image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
66
  return image
67
+
68
  def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
69
  result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
70
  ma = torch.max(result)