# atoms_detection/dl_detection_with_gmm.py
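"""GMM-based refinement of deep-learning atom detections.

Extends DLDetection: connected components of the thresholded prediction map
larger than MAX_SINGLE_ATOM_AREA pixels are split into up to MAX_ATOMS_PER_AREA
atoms by sampling points from the prediction map and fitting Gaussian mixtures,
scored with covariance, cluster-count and inter-centre distance penalties.
"""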
import math
from typing import List, Tuple

import numpy as np
from scipy.ndimage import label
from sklearn.mixture import GaussianMixture

from atoms_detection.dl_detection import DLDetection
from utils.constants import ModelArgs


class DLGMMdetection(DLDetection):
MAX_SINGLE_ATOM_AREA = 200
MAX_ATOMS_PER_AREA = 3
    COVARIANCE_TYPE = "full"

    def __init__(
self,
model_name: ModelArgs,
ckpt_filename: str,
dataset_csv: str,
threshold: float,
detections_path: str,
inference_cache_path: str,
covariance_penalisation: float = 0.03,
n_clusters_penalisation: float = 0.33,
distance_penalisation: float = 0.11,
n_samples_per_gmm: int = 600,
):
        super().__init__(
model_name,
ckpt_filename,
dataset_csv,
threshold,
detections_path,
inference_cache_path,
)
self.covariance_penalisation = covariance_penalisation
self.n_clusters_penalisation = n_clusters_penalisation
self.distance_penalisation = distance_penalisation
self.n_samples_per_gmm = n_samples_per_gmm

    def pred_map_to_atoms(
self, pred_map: np.ndarray
) -> Tuple[List[Tuple[int, int]], List[float]]:
pred_mask = pred_map > self.threshold
labeled_array, num_features = label(pred_mask)
        # Cache the prediction map so bbox_to_center_coords can crop regions from it
        self.current_pred_map = pred_map
        # Convert each labelled connected component into atom centre coordinates
center_coords_list = []
likelihood_list = []
        for label_idx in range(1, num_features + 1):
            label_mask = np.where(labeled_array == label_idx)
            likelihood = np.max(pred_map[label_mask])
            # Bounding box as (x_min, x_max, y_min, y_max); label_mask is (rows, cols)
            atom_bbox = (
                label_mask[1].min(),
                label_mask[1].max(),
                label_mask[0].min(),
                label_mask[0].max(),
            )
center_coord = self.bbox_to_center_coords(atom_bbox)
            center_coords_list += center_coord
            # One likelihood entry per returned centre: small regions yield a single
            # centre, large regions may be split into several atoms by the GMM
            pixel_area = (atom_bbox[1] - atom_bbox[0]) * (atom_bbox[3] - atom_bbox[2])
            if pixel_area < self.MAX_SINGLE_ATOM_AREA:
                likelihood_list.append(likelihood)
            else:
                likelihood_list.extend([likelihood] * len(center_coord))
        self.current_pred_map = None
        print(f"number of atoms {len(center_coords_list)}")
        return center_coords_list, likelihood_list

    def bbox_to_center_coords(
self, bbox: Tuple[int, int, int, int]
) -> List[Tuple[int, int]]:
pixel_area = (bbox[1] - bbox[0]) * (bbox[3] - bbox[2])
if pixel_area < self.MAX_SINGLE_ATOM_AREA:
return super().bbox_to_center_coords(bbox)
        else:
            # Large region: fit a GMM on the corresponding prediction-map crop and
            # shift the local centres back to global (x, y) image coordinates
            pmap = self.get_current_prediction_map_region(bbox)
            local_atom_center_list = self.run_gmm_pipeline(pmap)
            atom_center_list = [
                (x + bbox[0], y + bbox[2]) for x, y in local_atom_center_list
            ]
            return atom_center_list

    def sample_img_hist(self, img_region: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        # Treat the prediction-map region as an (unnormalised) 2D histogram and draw
        # n_samples_per_gmm pixel coordinates with probability proportional to intensity
        cdf = np.cumsum(img_region.ravel())
        cdf = cdf / cdf[-1]
        values = np.random.rand(self.n_samples_per_gmm)
        value_bins = np.searchsorted(cdf, values)
        # ravel() is row-major, so unravel with the region's own (rows, cols) shape
        y_idx, x_idx = np.unravel_index(value_bins, img_region.shape)
        return x_idx, y_idx

    def run_gmm_pipeline(self, prediction_map: np.ndarray) -> List[Tuple[int, int]]:
        retries = 2
        new_x, new_y = self.sample_img_hist(prediction_map)
        best_gmm, best_score = None, -np.inf
        obs = np.array((new_x, new_y)).T
        for k in range(1, self.MAX_ATOMS_PER_AREA + 1):
            for _ in range(retries):
                gmm = GaussianMixture(
                    n_components=k, covariance_type=self.COVARIANCE_TYPE
                )
                gmm.fit(obs)
                # Mean log-likelihood per sample, penalised by covariance spread
                # and by the number of clusters
                log_like = gmm.score(obs)
                covar = np.linalg.norm(gmm.covariances_)
                if k == 1:
                    score = (
                        log_like
                        - covar * self.covariance_penalisation
                        - k * self.n_clusters_penalisation
                    )
                    print(k, score)
                else:
                    # Pairwise distances between cluster means; centres closer than
                    # 12 px are penalised quadratically to discourage over-splitting
                    distances = [
                        math.dist(p1, p2)
                        for i, p1 in enumerate(gmm.means_[:-1])
                        for p2 in gmm.means_[i + 1 :]
                    ]
                    dist_penalisation = sum(max(12 - d, 0) ** 2 for d in distances)
                    score = (
                        log_like
                        - covar * self.covariance_penalisation
                        - k * self.n_clusters_penalisation
                        - dist_penalisation * self.distance_penalisation
                    )
                    print(
                        k,
                        score,
                        log_like,
                        covar * self.covariance_penalisation,
                        k * self.n_clusters_penalisation,
                        dist_penalisation * self.distance_penalisation,
                    )
                if score > best_score:
                    best_gmm, best_score = gmm, score
        # GMM means are already (x, y) local pixel coordinates after sample_img_hist
        return [(x, y) for x, y in best_gmm.means_.tolist()]
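

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The model name,
    # checkpoint, CSV and path arguments below are illustrative placeholders;
    # the ModelArgs member and the prediction-map file are assumptions, and the
    # heavy lifting (model loading, inference) happens in the DLDetection base class.
    detector = DLGMMdetection(
        model_name=ModelArgs.BASICCNN,    # hypothetical ModelArgs member
        ckpt_filename="checkpoint.ckpt",  # placeholder checkpoint path
        dataset_csv="dataset.csv",        # placeholder dataset CSV
        threshold=0.8,
        detections_path="detections",
        inference_cache_path="inference_cache",
    )
    # Given a per-pixel prediction map (e.g. cached from a previous inference run),
    # extract atom centres and their likelihoods.
    pred_map = np.load("pred_map.npy")    # placeholder prediction map
    centers, likelihoods = detector.pred_map_to_atoms(pred_map)
    print(centers, likelihoods)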