import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
from PIL import Image
# Load our PyTorch model:
model = smp.Unet(
    encoder_name="resnet34",   # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=None,      # no ImageNet init needed; weights come from the checkpoint below
    in_channels=3,             # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=10,                # model output channels (number of classes in your dataset)
)
model.load_state_dict(torch.load('Floodnet_model_e5.pt', map_location=torch.device('cpu')))
model.eval()
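# Note: 'Floodnet_model_e5.pt' is assumed to be a state_dict saved from this exact
# architecture (resnet34 U-Net, 3 input channels, 10 classes); load_state_dict will
# raise a size-mismatch error if the checkpoint was trained with a different configuration.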
# Handle input: preprocess the uploaded image, run the model, and return the predicted mask.
def predict_segmentation(image: Image.Image):
    image = image.convert("RGB").resize((256, 256))
    input_data = np.asarray(image, dtype=np.float32)
    # NOTE: apply here whatever normalization was used during training, if any
    # Convert the HWC array to the NCHW float tensor the model expects
    input_tensor = torch.from_numpy(input_data).permute(2, 0, 1).unsqueeze(0)
    # Get the prediction from the model
    with torch.no_grad():
        output = model(input_tensor)
    # Collapse the per-class logits (1, 10, 256, 256) into a 2D mask of class indices
    output_mask = output.argmax(dim=1).squeeze()
    # Convert the mask of class indices (0-9) to an Image object
    output_image = Image.fromarray(output_mask.numpy().astype(np.uint8))
    return output_image
image_input = gr.Image(type="pil")
image_output = gr.Image(type="pil")
iface = gr.Interface(predict_segmentation, image_input, image_output)
iface.launch()
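# A minimal requirements.txt sketch for this Space, assuming only the imports above
# (version pins omitted; match them to the versions the checkpoint was exported with):
#   gradio
#   numpy
#   pillow
#   segmentation-models-pytorch
#   torch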