File size: 4,036 Bytes
b2ffc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Tuple, List

import os
from hashlib import sha1

import numpy as np
from PIL import Image
from scipy.ndimage import label

from utils.constants import Split
from utils.paths import PREDS_PATH
from atoms_detection.dataset import ImageDataset


class Detection:
    """Base class for threshold-based atom detection on images.

    Subclasses supply :meth:`image_to_pred_map`. This class takes care of
    caching prediction maps on disk, converting a prediction map into atom
    centre coordinates, and writing one CSV of detections per image.
    """

    def __init__(self, dataset_csv: str, threshold: float, detections_path: str, inference_cache_path: str):
        """Set up the dataset, the threshold and the output directories.

        Args:
            dataset_csv: Passed unchanged to ``ImageDataset``.
            threshold: Prediction-map values strictly greater than this are
                treated as belonging to an atom.
            detections_path: Directory that receives the per-image CSV files.
            inference_cache_path: Directory that stores cached prediction
                maps as ``.npy`` files.
        """
        self.image_dataset = ImageDataset(dataset_csv)
        self.threshold = threshold
        self.detections_path = detections_path
        self.inference_cache_path = inference_cache_path
        # Basename of the image most recently sent through the model; only
        # updated by detect_atoms() on a cache miss.
        self.currently_processing = None
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.detections_path, exist_ok=True)
        os.makedirs(self.inference_cache_path, exist_ok=True)

    def image_to_pred_map(self, img: np.ndarray) -> np.ndarray:
        """Turn an image into a prediction map; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def pred_map_to_atoms(self, pred_map: np.ndarray) -> Tuple[List[Tuple[int, int]], List[float]]:
        """Extract atom centres and likelihoods from a prediction map.

        The map is thresholded with ``self.threshold`` and split into
        connected components with ``scipy.ndimage.label``. Each component
        yields one atom.

        Args:
            pred_map: Prediction map; indexed as (row, column) here.

        Returns:
            A pair ``(center_coords_list, likelihood_list)`` of equal length.
            Each centre is ``(x, y)`` with x taken from axis 1 and y from
            axis 0; each likelihood is the maximum map value inside the
            component.
        """
        pred_mask = pred_map > self.threshold
        labeled_array, num_features = label(pred_mask)

        center_coords_list = []
        likelihood_list = []
        # Label 0 is the background, so component labels run 1..num_features.
        for label_idx in range(1, num_features + 1):
            label_mask = np.where(labeled_array == label_idx)
            likelihood_list.append(np.max(pred_map[label_mask]))
            # Bounding box as (x_min, x_max, y_min, y_max): axis 1 is x, axis 0 is y.
            atom_bbox = (
                label_mask[1].min(),
                label_mask[1].max(),
                label_mask[0].min(),
                label_mask[0].max(),
            )
            center_coords_list.append(self.bbox_to_center_coords(atom_bbox))
        return center_coords_list, likelihood_list

    def detect_atoms(self, img_filename: str) -> Tuple[List[Tuple[int, int]], List[float]]:
        """Detect atoms in one image, reusing a cached prediction map if present.

        Args:
            img_filename: Path of the image to process.

        Returns:
            The ``(center_coords_list, likelihood_list)`` pair produced by
            :meth:`pred_map_to_atoms`.
        """
        # NOTE: the cache key is derived from the path string only, so the same
        # image under a different path is recomputed, and a modified image
        # under the same path reuses the old prediction map.
        img_hash = self.cache_image_identifier(img_filename)
        prediction_cache = os.path.join(self.inference_cache_path, f"{img_hash}.npy")
        if os.path.exists(prediction_cache):
            pred_map = np.load(prediction_cache)
        else:
            self.currently_processing = os.path.split(img_filename)[-1]
            img = self.open_image(img_filename)
            pred_map = self.image_to_pred_map(img)
            np.save(prediction_cache, pred_map)
        return self.pred_map_to_atoms(pred_map)

    def cache_image_identifier(self, img_filename):
        """Return the SHA-1 hex digest of the filename string, used as cache key."""
        return sha1(img_filename.encode()).hexdigest()

    @staticmethod
    def bbox_to_center_coords(bbox: Tuple[int, int, int, int]) -> Tuple[int, int]:
        """Return the integer centre of a bounding box.

        Args:
            bbox: ``(x_min, x_max, y_min, y_max)``.

        Returns:
            ``(x_center, y_center)``, each computed with floor division.
        """
        x_center = (bbox[0] + bbox[1]) // 2
        y_center = (bbox[2] + bbox[3]) // 2
        return x_center, y_center

    @staticmethod
    def open_image(img_filename: str):
        """Load an image file into a ``float32`` NumPy array."""
        # Image.open leaves the file open until the image is closed; the
        # with-block releases it once astype() has copied the pixel data.
        with Image.open(img_filename) as img:
            return np.asarray(img).astype(np.float32)

    def run_single(self, image_path: str) -> Tuple[List[Tuple[int, int]], List[float]]:
        """Detect atoms in one image and write them to a CSV file.

        The CSV is written to ``<detections_path>/<image name without
        extension>.csv`` with the header ``Filename,x,y,Likelihood`` and one
        row per detected atom.

        Args:
            image_path: Path of the image to process.

        Returns:
            The ``(center_coords_list, likelihood_list)`` pair for the image.
        """
        image_filename = os.path.basename(image_path)
        print(f"Running detection on {image_filename}")
        center_coords_list, likelihood_list = self.detect_atoms(image_path)

        img_name = os.path.splitext(image_filename)[0]
        detection_csv = os.path.join(self.detections_path, f"{img_name}.csv")
        with open(detection_csv, "w") as _csv:
            _csv.write("Filename,x,y,Likelihood\n")
            for (x, y), likelihood in zip(center_coords_list, likelihood_list):
                _csv.write(f"{image_filename},{x},{y},{likelihood}\n")
        return center_coords_list, likelihood_list

    def run(self):
        """Run :meth:`run_single` on every item of the dataset's TEST split."""
        os.makedirs(self.detections_path, exist_ok=True)

        for image_path in self.image_dataset.iterate_data(Split.TEST):
            # Must go through self: a bare run_single() is an undefined name here.
            self.run_single(image_path)