"""Gradio demo that classifies a retina scan into a retinopathy level."""

import glob
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import AutoImageProcessor, ResNetForImageClassification

from labelmap import DR_LABELMAP


class App:
    """Web UI wrapping a pretrained ResNet image classifier.

    The constructor loads the image processor and the model from a local
    checkpoint directory and builds the Gradio Blocks UI; ``launch`` serves it.
    """

    def __init__(self) -> None:
        """Load the checkpoint, pick the compute device and build the UI."""
        ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
        path = f"release_ckpts/{ckpt_name}/inference/"
        self.image_processor = AutoImageProcessor.from_pretrained(path)
        self.model = ResNetForImageClassification.from_pretrained(path)

        # Actually place the model on the GPU when one is available, so that
        # the "Running on ..." text shown in the UI matches where inference
        # really happens.
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            self.model.to("cuda")
        device = "GPU" if use_cuda else "CPU"

        example_lists = self._load_example_lists()

        css = ".output-image, .input-image, .image-preview {height: 600px !important}"

        with gr.Blocks(css=css) as ui:
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        predict_btn = gr.Button("Predict", size="lg")
                    with gr.Row():
                        gr.Markdown(f"Running on {device}")
                with gr.Column(scale=4):
                    output = gr.Label(
                        num_top_classes=len(DR_LABELMAP),
                        label="Retinopathy level prediction")
                with gr.Column(scale=4):
                    gr.Markdown("![](https://media.githubusercontent.com/media/Obs01ete/retinopathy/master/media/logo1.png)")
            with gr.Row():
                with gr.Column(scale=9, min_width=100):
                    image = gr.Image(label="Retina scan")
                with gr.Column(scale=1, min_width=150):
                    # Iterate over the class ids that were actually found on
                    # disk (in ascending order) instead of assuming they form
                    # a contiguous 0..n-1 range.
                    for cls_id in sorted(example_lists):
                        label = DR_LABELMAP[cls_id]
                        with gr.Tab(f"{cls_id} : {label}"):
                            gr.Examples(
                                example_lists[cls_id],
                                inputs=[image],
                                outputs=[output],
                                fn=self.predict,
                                examples_per_page=10,
                                run_on_click=True)

            predict_btn.click(
                fn=self.predict,
                inputs=image,
                outputs=output,
                api_name="predict")

        self.ui = ui

    def launch(self) -> None:
        """Start serving the UI with request queueing enabled.

        NOTE: ``share=True`` asks Gradio to create a publicly reachable link.
        """
        self.ui.queue().launch(share=True)

    def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
        """Classify ``image`` and return per-class probabilities.

        Args:
            image: The image delivered by the ``gr.Image`` component, or
                ``None`` when no image has been provided.

        Returns:
            A mapping from ``"<class id> - <class label>"`` to probability,
            or an empty dict when ``image`` is ``None``.
        """
        if image is None:
            return {}

        cls_name, prob, probs = self._infer(image)
        message = f"Predicted class={cls_name}, prob={prob:.3f}"
        print(message)

        probs_dict = {f"{i} - {DR_LABELMAP[i]}": float(v)
                      for i, v in enumerate(probs)}
        return probs_dict

    def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """Run the model on a single image.

        Args:
            image_chw: Image array forwarded unchanged to the image processor.
                NOTE(review): despite the ``_chw`` suffix, ``predict`` passes
                whatever ``gr.Image`` produces (presumably height x width x
                channels) -- confirm the expected layout.

        Returns:
            A tuple ``(class name, probability of that class, probabilities
            of all classes as a NumPy array)``.
        """
        assert isinstance(self.model, ResNetForImageClassification)

        inputs = self.image_processor(image_chw, return_tensors="pt")
        # Inputs must live on the same device as the model weights.
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            output = self.model(**inputs)

        logits_batch = output.logits
        # Exactly one image is processed, so a (1, num_classes) batch of
        # logits is expected.
        assert len(logits_batch.shape) == 2
        assert logits_batch.shape[0] == 1
        logits = logits_batch[0]

        probs = torch.softmax(logits, dim=-1)
        predicted_label = int(probs.argmax(-1).item())
        prob = probs[predicted_label].item()
        cls_name = self.model.config.id2label[predicted_label]

        # ``Tensor.numpy()`` only works for CPU tensors.
        return cls_name, prob, probs.cpu().numpy()

    @staticmethod
    def _load_example_lists() -> Dict[int, List[str]]:
        """Collect example image paths grouped by class id.

        Images are looked up as ``demo_data/train/<class id>/<name>.jpeg``;
        the name of the directory containing each image is parsed as the
        integer class id. Paths whose directory name is not an integer are
        reported and skipped.

        Returns:
            A plain dict mapping class id to its example paths. Paths are
            sorted so the order of examples does not depend on the file
            system.
        """
        # ``recursive=True`` is not passed, so ``**`` behaves like ``*`` and
        # matches exactly one directory level.
        example_flat_list = sorted(glob.glob("demo_data/train/**/*.jpeg"))

        grouped: Dict[int, List[str]] = defaultdict(list)
        for path in example_flat_list:
            subdir = os.path.basename(os.path.dirname(path))
            try:
                cls_id = int(subdir)
            except ValueError:
                print(f"Cannot parse path {path}")
                continue
            grouped[cls_id].append(path)

        # Hand back a regular dict so that looking up a missing class id
        # cannot silently create an empty entry.
        return dict(grouped)


def main():
    """Build the application and serve it."""
    app = App()
    app.launch()


if __name__ == "__main__":
    main()