Spaces:
Sleeping
Sleeping
from typing import Tuple, List | |
import os | |
from hashlib import sha1 | |
import numpy as np | |
from PIL import Image | |
from scipy.ndimage import label | |
from utils.constants import Split | |
from utils.paths import PREDS_PATH | |
from atoms_detection.dataset import ImageDataset | |
class Detection: | |
def __init__(self, dataset_csv: str, threshold: float, detections_path: str, inference_cache_path: str): | |
self.image_dataset = ImageDataset(dataset_csv) | |
self.threshold = threshold | |
self.detections_path = detections_path | |
self.inference_cache_path = inference_cache_path | |
self.currently_processing = None | |
if not os.path.exists(self.detections_path): | |
os.makedirs(self.detections_path) | |
if not os.path.exists(self.inference_cache_path): | |
os.makedirs(self.inference_cache_path) | |
def image_to_pred_map(self, img: np.ndarray) -> np.ndarray: | |
raise NotImplementedError | |
def pred_map_to_atoms(self, pred_map: np.ndarray) -> Tuple[List[Tuple[int, int]], List[float]]: | |
pred_mask = pred_map > self.threshold | |
labeled_array, num_features = label(pred_mask) | |
# Convert labelled_array to indexes | |
center_coords_list = [] | |
likelihood_list = [] | |
for label_idx in range(num_features+1): | |
if label_idx == 0: | |
continue | |
label_mask = np.where(labeled_array == label_idx) | |
likelihood = np.max(pred_map[label_mask]) | |
likelihood_list.append(likelihood) | |
# label_size = len(label_mask[0]) | |
# print(f"\t\tAtom {label_idx}: {label_size}") | |
atom_bbox = (label_mask[1].min(), label_mask[1].max(), label_mask[0].min(), label_mask[0].max()) | |
center_coord = self.bbox_to_center_coords(atom_bbox) | |
center_coords_list.append(center_coord) | |
return center_coords_list, likelihood_list | |
def detect_atoms(self, img_filename: str) -> Tuple[List[Tuple[int, int]], List[float]]: | |
img_hash = self.cache_image_identifier(img_filename) | |
prediciton_cache = os.path.join(self.inference_cache_path, f"{img_hash}.npy") | |
if not os.path.exists(prediciton_cache): | |
self.currently_processing = os.path.split(img_filename)[-1] | |
img = self.open_image(img_filename) | |
pred_map = self.image_to_pred_map(img) | |
np.save(prediciton_cache, pred_map) | |
else: | |
pred_map = np.load(prediciton_cache) | |
center_coords_list, likelihood_list = self.pred_map_to_atoms(pred_map) | |
return center_coords_list, likelihood_list | |
def cache_image_identifier(self, img_filename): | |
return sha1(img_filename.encode()).hexdigest() | |
def bbox_to_center_coords(bbox: Tuple[int, int, int, int]) -> Tuple[int, int]: | |
x_center = (bbox[0] + bbox[1]) // 2 | |
y_center = (bbox[2] + bbox[3]) // 2 | |
return x_center, y_center | |
def open_image(img_filename: str): | |
img = Image.open(img_filename) | |
np_img = np.asarray(img).astype(np.float32) | |
return np_img | |
def run_single(self, image_path: str): | |
print(f"Running detection on {os.path.basename(image_path)}") | |
center_coords_list, likelihood_list = self.detect_atoms(image_path) | |
image_filename = os.path.basename(image_path) | |
img_name = os.path.splitext(image_filename)[0] | |
detection_csv = os.path.join(self.detections_path, f"{img_name}.csv") | |
with open(detection_csv, "w") as _csv: | |
_csv.write("Filename,x,y,Likelihood\n") | |
for (x, y), likelihood in zip(center_coords_list, likelihood_list): | |
_csv.write(f"{image_filename},{x},{y},{likelihood}\n") | |
return center_coords_list, likelihood_list | |
def run(self): | |
if not os.path.exists(self.detections_path): | |
os.makedirs(self.detections_path) | |
for image_path in self.image_dataset.iterate_data(Split.TEST): | |
run_single(image_path) | |