File size: 3,348 Bytes
ab163d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f722806
ab163d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f722806
ab163d2
f722806
ab163d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import numpy as np
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, Lambda, Resize, NormalizeIntensity, GaussianSmooth, ScaleIntensity, AsDiscrete, KeepLargestConnectedComponent, Invert, Rotate90, SaveImage, Transform
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import UNet

class RgbaToGrayscale(Transform):
    """Reduce a channel-first RGBA, RGB or single-channel image to one channel.

    A trailing singleton dimension (e.g. ``(C, H, W, 1)``) is dropped first.
    RGB(A) inputs are collapsed with the weights (0.2989, 0.5870, 0.1140)
    applied to the first three channels; the alpha channel of RGBA input is
    ignored. Single-channel input is returned unchanged. For RGB(A) input the
    result has shape ``(1, H, W)``.
    """

    # Per-channel weights for R, G and B (close to the ITU-R BT.601 luma
    # coefficients).
    _RGB_WEIGHTS = (0.2989, 0.5870, 0.1140)

    def __call__(self, x):
        """Return the single-channel version of ``x``.

        Raises:
            ValueError: if ``x`` is not 3-D after dropping a trailing
                singleton dimension, or has a channel count other than 1, 3, 4.
        """
        # squeeze(-1) only removes the last dim when it has size 1, leaving a
        # (C, H, W) layout; otherwise it is a no-op.
        x = x.squeeze(-1)
        if x.ndim != 3:
            raise ValueError(f"Input tensor must be 3D. Shape: {x.shape}")

        num_channels = x.shape[0]
        if num_channels == 1:  # Already grayscale
            return x
        if num_channels not in (3, 4):
            raise ValueError(f"Unsupported channel number: {x.shape[0]}")

        # einsum requires both operands to share a dtype: promote integer
        # images (e.g. uint8) to float32 and build the weights in x's dtype,
        # instead of always float32 (which failed for float64/float16 input).
        if not torch.is_floating_point(x):
            x = x.float()
        weights = torch.tensor(self._RGB_WEIGHTS, dtype=x.dtype, device=x.device)

        # Weighted sum over the R, G, B channels (alpha, if present, is
        # dropped by the [:3] slice); unsqueeze(0) restores the channel axis.
        return torch.einsum("chw,c->hw", x[:3], weights).unsqueeze(0)

    def inverse(self, x):
        """Return ``x`` unchanged: the channel reduction cannot be undone."""
        return x
    
# 2-D UNet taking 1 input channel (grayscale) and producing 4 output channels.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=[64, 128, 256, 512],
    strides=[2, 2, 2],
    num_res_units=3
)

checkpoint_path = 'segmentation_model.pt'
# SECURITY NOTE(review): torch.load unpickles the file, which can execute
# arbitrary code, so only load checkpoints from a trusted source. Consider
# torch.load(..., weights_only=True) if this checkpoint holds only tensors and
# plain containers and the installed torch version supports that argument.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Verify that the architecture matches the checkpoint before loading. This uses
# an explicit raise rather than `assert`, which is skipped under `python -O`,
# and reports which keys differ.
_model_keys = model.state_dict().keys()
_checkpoint_keys = checkpoint['network'].keys()
if _model_keys != _checkpoint_keys:
    raise RuntimeError(
        "Model and checkpoint keys do not match: "
        f"missing from checkpoint={sorted(_model_keys - _checkpoint_keys)}, "
        f"unexpected in checkpoint={sorted(_checkpoint_keys - _model_keys)}"
    )

model.load_state_dict(checkpoint['network'])
# Evaluation mode: layers that behave differently in training (e.g. dropout,
# normalization) switch to inference behaviour.
model.eval()

# Preprocessing pipeline: image file path -> single-channel (1, 768, 768)
# tensor with intensities rescaled to [-1, 1].
# NOTE: post_transforms calls Invert(pre_transforms), so the exact transforms
# and their order here matter for that step too; do not reorder casually.
pre_transforms = Compose([
    LoadImage(image_only=True),  # load from path; image_only=True -> image only, no separate header
    EnsureChannelFirst(),  # put the channel axis first
    RgbaToGrayscale(),  # RGBA/RGB -> 1 grayscale channel (single-channel input passes through)
    Resize(spatial_size=(768, 768)),  # fixed spatial size fed to the sliding-window inferer
    # NOTE(review): after RgbaToGrayscale + Resize the last dim is 768, so this
    # squeeze(-1) appears to be a no-op; kept in case the input image has an
    # extra unwanted dimension.
    Lambda(func=lambda x: x.squeeze(-1)),
    NormalizeIntensity(),
    # sigma=0.1 is a very small kernel; presumably kept to match the training
    # preprocessing (TODO confirm).
    GaussianSmooth(sigma=0.1),
    ScaleIntensity(minv=-1, maxv=1)  # rescale intensities to the range [-1, 1]
])



# Post-processing pipeline for the 4-channel network output (one channel per
# class): argmax -> one-hot, keep the largest connected component, argmax back
# to a single label map, then invert the preprocessing transforms.
post_transforms = Compose([
    AsDiscrete(argmax=True, to_onehot=4),  # per-pixel argmax, re-encoded as 4-class one-hot
    # NOTE(review): default arguments; presumably applied to each one-hot
    # channel (including background) — confirm against MONAI defaults.
    KeepLargestConnectedComponent(),
    AsDiscrete(argmax=True),  # one-hot -> single-channel label map
    # Undo pre_transforms (e.g. the Resize to 768x768). This presumably relies
    # on the tensor still carrying the record of the applied preprocessing
    # operations — verify if the inferer output loses that metadata.
    Invert(pre_transforms),
    # Optional: uncomment to write the segmentation to disk.
    #SaveImage(output_dir='./', output_postfix='seg', output_ext='.nii', resample=False)
])



def load_and_segment_image(input_image_path, device):
    """Segment the image at ``input_image_path`` with the module-level UNet.

    Args:
        input_image_path: Path of the image file, passed to ``pre_transforms``.
        device: Torch device (e.g. ``"cpu"`` or ``"cuda"``) used for inference.
            The module-level ``model`` is moved to this device in place.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` holding the post-processed label
        map, rotated with ``Rotate90(spatial_axes=(0, 1), k=3)`` and with all
        singleton dimensions removed.
    """
    # Bind the moved model to a new local name. Writing `model = model.to(...)`
    # made `model` a local variable for the whole function, so the read on the
    # right-hand side raised UnboundLocalError on every call.
    # nn.Module.to() moves the shared model in place and returns it.
    net = model.to(device)

    # pre_transforms yields a channel-first tensor; add a batch dimension.
    image_tensor = pre_transforms(input_image_path).unsqueeze(0).to(device)

    # Run the network over overlapping 512x512 windows of the preprocessed image.
    inferer = SlidingWindowInferer(roi_size=(512, 512), sw_batch_size=16, overlap=0.75)
    with torch.no_grad():
        outputs = inferer(image_tensor, net)

    # Drop the batch dimension before the per-image post-processing.
    processed_outputs = post_transforms(outputs.squeeze(0))

    # Rotate by 3 x 90 degrees in the plane of spatial axes (0, 1); presumably
    # this corrects the orientation introduced by image loading (TODO confirm).
    rotate = Rotate90(spatial_axes=(0, 1), k=3)
    processed_outputs = rotate(processed_outputs).to('cpu')

    return processed_outputs.squeeze().detach().numpy().astype(np.uint8)