deepflash2 / app.py
matjesg's picture
Update app.py
807aa27
raw
history blame
No virus
1.23 kB
import numpy as np
import gradio as gr
import onnxruntime as ort
from matplotlib import pyplot as plt
from huggingface_hub import hf_hub_download
# Download the ensemble ONNX model from the Hugging Face Hub. The returned
# value is later passed to create_model_for_provider as the model path
# (presumably the local path of the cached file — see hf_hub_download docs).
_MODEL_REPO_ID = "matjesg/cFOS_in_HC"
_MODEL_FILENAME = "ensemble.onnx"
model = hf_hub_download(repo_id=_MODEL_REPO_ID, filename=_MODEL_FILENAME)
def create_model_for_provider(model_path, provider="CPUExecutionProvider"):
    """Create an ONNX Runtime inference session bound to one execution provider.

    Args:
        model_path: Path to the ``.onnx`` file; converted with ``str()`` before
            being handed to ONNX Runtime, so path-like objects are accepted.
        provider: Name of the single execution provider to register
            (defaults to ``"CPUExecutionProvider"``).

    Returns:
        The configured ``ort.InferenceSession``.
    """
    session_options = ort.SessionOptions()
    # Restrict ONNX Runtime to one intra-op thread.
    # NOTE(review): presumably chosen for predictable use of shared hosting
    # resources — confirm.
    session_options.intra_op_num_threads = 1
    session_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    inference_session = ort.InferenceSession(
        str(model_path), session_options, providers=[provider]
    )
    # Per the onnxruntime API this should stop the session from silently
    # falling back to a different provider — verify against ORT docs.
    inference_session.disable_fallback()
    return inference_session
# Module-wide inference session (default CPU provider) used by `inference`.
ort_session = create_model_for_provider(model_path=model)
def inference(img, session=None):
    """Run the ONNX ensemble on a single image and return the scaled output.

    Args:
        img: Image as a NumPy array (or array-like). A 3-D ``(H, W, C)`` array
            is reduced to its first channel; a 2-D grayscale ``(H, W)`` array
            gets a trailing channel axis, so both become ``(H, W, 1)``. Values
            are divided by 255, which maps 8-bit input (0-255) to [0, 1].
            ``None`` (e.g. no image submitted in the UI) yields ``None``.
        session: Optional inference session exposing ``get_inputs()`` and
            ``run()``. Defaults to the module-level ``ort_session``.

    Returns:
        The first model output multiplied by 255, or ``None`` if ``img`` is
        ``None``.
    """
    if img is None:
        return None
    if session is None:
        session = ort_session
    img = np.asarray(img)
    if img.ndim == 2:
        # Without this, img[..., :1] on an (H, W) array would slice the
        # width down to one column instead of selecting a channel.
        img = img[..., np.newaxis]
    # Keep only the first channel (shape (..., 1)) and rescale by 1/255.
    img = img[..., :1] / 255
    # NOTE(review): whether the model also expects a batch axis is not
    # visible here — TODO confirm against ensemble.onnx's input signature.
    ort_inputs = {session.get_inputs()[0].name: img.astype(np.float32)}
    ort_outs = session.run(None, ort_inputs)
    return ort_outs[0] * 255
# Gradio UI: a single NumPy image input, a single image output, and one
# bundled example file.
title = "deepflash2"
description = "deepflash2 is a deep-learning pipeline for segmentation of ambiguous microscopic images."
examples = [["1599.tif"]]

demo = gr.Interface(
    fn=inference,
    inputs=gr.inputs.Image(type="numpy"),
    outputs=gr.outputs.Image(),
    title=title,
    description=description,
    examples=examples,
)
demo.launch(share=True)