|
import torch |
|
import numpy as np |
|
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, Lambda, Resize, NormalizeIntensity, GaussianSmooth, ScaleIntensity, AsDiscrete, KeepLargestConnectedComponent, Invert, Rotate90, SaveImage, Transform |
|
from monai.inferers import SlidingWindowInferer |
|
from monai.networks.nets import UNet |
|
|
|
class RgbaToGrayscale(Transform):
    """Reduce a channel-first RGB/RGBA image to a single grayscale channel.

    The input is expected to be a tensor laid out as ``(C, W, H)``; a trailing
    axis of length 1 (e.g. ``(C, W, H, 1)``) is tolerated and removed first.

    * 4 channels: the first three are treated as RGB, the fourth is discarded.
    * 3 channels: treated as RGB.
    * 1 channel: returned unchanged.

    Raises:
        ValueError: if the (squeezed) input is not 3D, or if the channel
            count is anything other than 1, 3 or 4.
    """

    # Weights applied to the R, G and B channels (commonly used luma
    # coefficients). Kept in one place so both colour layouts share them.
    _RGB_WEIGHTS = (0.2989, 0.5870, 0.1140)

    def __call__(self, x):
        # Drop a trailing singleton axis so (C, W, H, 1) inputs are accepted.
        x = x.squeeze(-1)

        if x.ndim != 3:
            raise ValueError(f"Input tensor must be 3D. Shape: {x.shape}")

        channels = x.shape[0]
        if channels == 1:
            # Already single-channel: nothing to convert.
            return x
        if channels not in (3, 4):
            raise ValueError(f"Unsupported channel number: {channels}")

        # For RGBA this slices away the alpha channel; for RGB it is a no-op.
        colour = x[:3, :, :]

        # einsum needs both operands in the same dtype. Integer images are
        # promoted to float32 (integer weights would all truncate to zero);
        # floating inputs keep their own precision.
        if not torch.is_floating_point(colour):
            colour = colour.to(torch.float32)
        rgb_weights = torch.tensor(
            self._RGB_WEIGHTS, dtype=colour.dtype, device=colour.device
        )

        # Weighted sum over the channel axis, then restore a leading channel
        # axis of length 1 so the result is still channel-first.
        return torch.einsum('cwh,c->wh', colour, rgb_weights).unsqueeze(0)

    def inverse(self, x):
        """Return ``x`` unchanged; the colour information is not recovered."""
        return x
|
|
|
# 2D U-Net taking a single-channel image and producing 4 output channels.
# The architecture has to match the checkpoint loaded below, which is
# verified by the explicit key comparison and by load_state_dict itself.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=[64, 128, 256, 512],
    strides=[2, 2, 2],
    num_res_units=3,
)

checkpoint_path = 'segmentation_model.pt'

# NOTE(security): torch.load may unpickle arbitrary objects. Only load
# checkpoints from a trusted source (consider weights_only=True where the
# installed PyTorch version and the checkpoint contents allow it).
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Validate with an explicit raise rather than `assert`, which is removed when
# Python runs with -O and would silently skip the check.
_expected_keys = model.state_dict().keys()
_found_keys = checkpoint['network'].keys()
if _expected_keys != _found_keys:
    raise RuntimeError(
        "Model and checkpoint keys do not match "
        f"(missing from checkpoint: {sorted(_expected_keys - _found_keys)}, "
        f"unexpected in checkpoint: {sorted(_found_keys - _expected_keys)})"
    )

model.load_state_dict(checkpoint['network'])

# Inference only: switch the network to evaluation mode.
model.eval()
|
|
|
|
|
def _squeeze_last_axis(x):
    """Remove the final axis of ``x`` when its length is 1."""
    return x.squeeze(-1)


# Preprocessing applied to an image path before inference, in order.
_PRE_STEPS = [
    # Read the file and put the channel axis first.
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    # Collapse colour images to a single channel.
    RgbaToGrayscale(),
    # Bring every image to a fixed 768 x 768 spatial size.
    Resize(spatial_size=(768, 768)),
    Lambda(func=_squeeze_last_axis),
    # Intensity conditioning: normalise, lightly smooth, then rescale
    # with minv=-1 / maxv=1.
    NormalizeIntensity(),
    GaussianSmooth(sigma=0.1),
    ScaleIntensity(minv=-1, maxv=1),
]

pre_transforms = Compose(_PRE_STEPS)
|
|
|
|
|
|
|
|
|
# Postprocessing applied to the raw network output, in order.
_POST_STEPS = [
    # Pick the winning class per pixel and re-encode it as 4 one-hot channels.
    AsDiscrete(argmax=True, to_onehot=4),
    # Remove everything but the largest connected component.
    KeepLargestConnectedComponent(),
    # Collapse the one-hot channels back into a single label map.
    AsDiscrete(argmax=True),
    # Undo the invertible preprocessing steps recorded by pre_transforms.
    Invert(pre_transforms),
]

post_transforms = Compose(_POST_STEPS)
|
|
|
|
|
|
|
def load_and_segment_image(input_image_path, device):
    """Segment one image file and return the label map as a NumPy array.

    The image is run through ``pre_transforms``, passed through the
    module-level ``model`` with sliding-window inference, cleaned up by
    ``post_transforms``, rotated by three quarter-turns and converted to a
    ``uint8`` NumPy array on the CPU.

    Args:
        input_image_path: path of the image to segment; handed directly to
            ``pre_transforms``.
        device: torch device (or device string) used for inference. The
            module-level ``model`` is moved to this device as a side effect.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` holding the predicted labels,
        with all singleton axes squeezed out.
    """
    # Bind the moved network to a *different* local name. Writing
    # `model = model.to(device)` would make `model` local to this function
    # and raise UnboundLocalError before the global could ever be read.
    # Module.to() moves the parameters in place and returns the same module.
    net = model.to(device)

    image_tensor = pre_transforms(input_image_path)
    # Add the batch axis expected by the inferer and move to the target device.
    image_tensor = image_tensor.unsqueeze(0).to(device)

    inferer = SlidingWindowInferer(roi_size=(512, 512), sw_batch_size=16, overlap=0.75)
    with torch.no_grad():
        outputs = inferer(image_tensor, net)

    # Drop the batch axis again before the per-image post-processing.
    outputs = outputs.squeeze(0)
    processed_outputs = post_transforms(outputs)

    # Three 90-degree rotations in the plane of the two spatial axes.
    rotate = Rotate90(spatial_axes=(0, 1), k=3)
    processed_outputs = rotate(processed_outputs).to('cpu')

    output_array = processed_outputs.squeeze().detach().numpy().astype(np.uint8)
    return output_array