Update MyPipe.py
Browse files
MyPipe.py
CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
|
|
9 |
class RMBGPipe(Pipeline):
|
10 |
def __init__(self,**kwargs):
|
11 |
Pipeline.__init__(self,**kwargs)
|
12 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
self.model.to(self.device)
|
14 |
self.model.eval()
|
15 |
|
@@ -39,6 +39,7 @@ class RMBGPipe(Pipeline):
|
|
39 |
result = self.model(inputs.pop("image"))
|
40 |
inputs["result"] = result
|
41 |
return inputs
|
|
|
42 |
def postprocess(self,inputs,return_mask:bool=False ):
|
43 |
result = inputs.pop("result")
|
44 |
orig_im_size = inputs.pop("orig_im_size")
|
@@ -48,7 +49,7 @@ class RMBGPipe(Pipeline):
|
|
48 |
if return_mask ==True :
|
49 |
return pil_im
|
50 |
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
51 |
-
orig_image = Image.
|
52 |
no_bg_image.paste(orig_image, mask=pil_im)
|
53 |
return no_bg_image
|
54 |
|
@@ -59,10 +60,11 @@ class RMBGPipe(Pipeline):
|
|
59 |
im = im[:, :, np.newaxis]
|
60 |
# orig_im_size=im.shape[0:2]
|
61 |
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
62 |
-
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
|
63 |
image = torch.divide(im_tensor,255.0)
|
64 |
image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
65 |
return image
|
|
|
66 |
def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
|
67 |
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
|
68 |
ma = torch.max(result)
|
|
|
9 |
class RMBGPipe(Pipeline):
|
10 |
def __init__(self,**kwargs):
|
11 |
Pipeline.__init__(self,**kwargs)
|
12 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
13 |
self.model.to(self.device)
|
14 |
self.model.eval()
|
15 |
|
|
|
39 |
result = self.model(inputs.pop("image"))
|
40 |
inputs["result"] = result
|
41 |
return inputs
|
42 |
+
|
43 |
def postprocess(self,inputs,return_mask:bool=False ):
|
44 |
result = inputs.pop("result")
|
45 |
orig_im_size = inputs.pop("orig_im_size")
|
|
|
49 |
if return_mask ==True :
|
50 |
return pil_im
|
51 |
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
52 |
+
orig_image = Image.fromarray(io.imread(im_path))
|
53 |
no_bg_image.paste(orig_image, mask=pil_im)
|
54 |
return no_bg_image
|
55 |
|
|
|
60 |
im = im[:, :, np.newaxis]
|
61 |
# orig_im_size=im.shape[0:2]
|
62 |
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
63 |
+
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
|
64 |
image = torch.divide(im_tensor,255.0)
|
65 |
image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
66 |
return image
|
67 |
+
|
68 |
def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
|
69 |
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
|
70 |
ma = torch.max(result)
|