from typing import Tuple, List
import os
from hashlib import sha1

import numpy as np
from PIL import Image
from scipy.ndimage import label, find_objects

from utils.constants import Split
from utils.paths import PREDS_PATH
from atoms_detection.dataset import ImageDataset


class Detection:
    """Base class for atom detectors that run over an ``ImageDataset``.

    Subclasses implement :meth:`image_to_pred_map`, which turns an image into a
    per-pixel prediction map. This class thresholds that map into connected
    regions (one atom per region), caches prediction maps on disk as ``.npy``
    files, and writes detections to one CSV file per image.
    """

    def __init__(self, dataset_csv: str, threshold: float, detections_path: str, inference_cache_path: str):
        """
        Args:
            dataset_csv: Path of the dataset CSV, passed to ``ImageDataset``.
            threshold: Prediction-map values strictly greater than this are
                treated as atom pixels.
            detections_path: Directory that receives the per-image detection CSVs.
            inference_cache_path: Directory that receives cached prediction maps.
        """
        self.image_dataset = ImageDataset(dataset_csv)
        self.threshold = threshold
        self.detections_path = detections_path
        self.inference_cache_path = inference_cache_path
        # Basename of the image whose prediction map is currently being
        # computed; only updated on cache misses in detect_atoms().
        self.currently_processing = None

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.detections_path, exist_ok=True)
        os.makedirs(self.inference_cache_path, exist_ok=True)

    def image_to_pred_map(self, img: np.ndarray) -> np.ndarray:
        """Return the prediction map for ``img``; must be implemented by subclasses."""
        raise NotImplementedError

    def pred_map_to_atoms(self, pred_map: np.ndarray) -> Tuple[List[Tuple[int, int]], List[float]]:
        """Convert a prediction map into atom centres and likelihoods.

        Pixels with ``pred_map > self.threshold`` are grouped into connected
        components with ``scipy.ndimage.label`` (default connectivity). Each
        component is reported as one atom.

        Returns:
            ``(center_coords_list, likelihood_list)``, two parallel lists.
            Each centre is the integer midpoint ``(x, y)`` of the component's
            bounding box, where x is the column and y is the row. Each
            likelihood is the maximum of ``pred_map`` over the component's pixels.
        """
        pred_mask = pred_map > self.threshold
        labeled_array, _ = label(pred_mask)
        # find_objects returns every label's bounding box as a tuple of slices
        # in a single pass. Each component is then examined only inside its
        # own box, instead of scanning the whole image once per label.
        bounding_slices = find_objects(labeled_array)

        center_coords_list = []
        likelihood_list = []
        for label_idx, bbox_slices in enumerate(bounding_slices, start=1):
            if bbox_slices is None:
                # Label value absent from the array; label() numbers components
                # consecutively, so this is purely defensive.
                continue
            region_mask = labeled_array[bbox_slices] == label_idx
            likelihood = np.max(pred_map[bbox_slices][region_mask])
            likelihood_list.append(likelihood)

            row_slice, col_slice = bbox_slices[0], bbox_slices[1]
            # Slice stops are exclusive; convert to the inclusive
            # (x_min, x_max, y_min, y_max) layout bbox_to_center_coords expects.
            atom_bbox = (col_slice.start, col_slice.stop - 1, row_slice.start, row_slice.stop - 1)
            center_coords_list.append(self.bbox_to_center_coords(atom_bbox))

        return center_coords_list, likelihood_list

    def detect_atoms(self, img_filename: str) -> Tuple[List[Tuple[int, int]], List[float]]:
        """Detect atoms in an image, reusing a cached prediction map if one exists.

        On a cache miss, the image is loaded and passed through
        :meth:`image_to_pred_map`. The resulting map is saved under
        ``inference_cache_path`` before being converted to atoms.
        """
        img_hash = self.cache_image_identifier(img_filename)
        prediction_cache = os.path.join(self.inference_cache_path, f"{img_hash}.npy")
        if os.path.exists(prediction_cache):
            pred_map = np.load(prediction_cache)
        else:
            self.currently_processing = os.path.split(img_filename)[-1]
            img = self.open_image(img_filename)
            pred_map = self.image_to_pred_map(img)
            np.save(prediction_cache, pred_map)
        return self.pred_map_to_atoms(pred_map)

    def cache_image_identifier(self, img_filename):
        """Return the cache key for an image: the SHA-1 hex digest of its path string.

        NOTE: the key depends only on the path string, not on the file contents
        or the detector. Replacing the image at the same path, or changing the
        model, will therefore reuse a stale cached prediction map.
        """
        return sha1(img_filename.encode()).hexdigest()

    @staticmethod
    def bbox_to_center_coords(bbox: Tuple[int, int, int, int]) -> Tuple[int, int]:
        """Return the integer midpoint ``(x, y)`` of an ``(x_min, x_max, y_min, y_max)`` box."""
        x_center = (bbox[0] + bbox[1]) // 2
        y_center = (bbox[2] + bbox[3]) // 2
        return x_center, y_center

    @staticmethod
    def open_image(img_filename: str):
        """Load an image file as a float32 NumPy array."""
        # The context manager guarantees the underlying file handle is released.
        with Image.open(img_filename) as img:
            np_img = np.asarray(img).astype(np.float32)
        return np_img

    def run_single(self, image_path: str):
        """Detect atoms in one image and write them to ``<detections_path>/<stem>.csv``.

        The CSV has the columns ``Filename,x,y,Likelihood``, with one row per atom.

        Returns:
            The ``(center_coords_list, likelihood_list)`` pair from :meth:`detect_atoms`.
        """
        print(f"Running detection on {os.path.basename(image_path)}")
        center_coords_list, likelihood_list = self.detect_atoms(image_path)

        image_filename = os.path.basename(image_path)
        img_name = os.path.splitext(image_filename)[0]
        detection_csv = os.path.join(self.detections_path, f"{img_name}.csv")
        with open(detection_csv, "w") as _csv:
            _csv.write("Filename,x,y,Likelihood\n")
            for (x, y), likelihood in zip(center_coords_list, likelihood_list):
                _csv.write(f"{image_filename},{x},{y},{likelihood}\n")
        return center_coords_list, likelihood_list

    def run(self):
        """Run :meth:`run_single` on every item yielded for the TEST split."""
        os.makedirs(self.detections_path, exist_ok=True)
        for image_path in self.image_dataset.iterate_data(Split.TEST):
            # run_single is an instance method; the bare name raised NameError.
            self.run_single(image_path)