atom-detection / atoms_detection /multimetallic_analysis.py
Romain Graux
Initial commit with ml code and webapp
b2ffc9b
raw
history blame
7.47 kB
# run VAE + GMM assignment
import argparse
from typing import List
import numpy as np
# import rasterio
import torch
import warnings
import os
import re
import pandas as pd
from PIL import Image
from sklearn.mixture import GaussianMixture
from atoms_detection.create_crop_dataset import create_crop
from atoms_detection.vae_utilities.vae_model import rVAE
from atoms_detection.vae_utilities.vae_svi_train import init_dataloader, SVItrainer
from atoms_detection.image_preprocessing import dl_prepro_image
"""
Code sourced from:
https://colab.research.google.com/github/ziatdinovmax/notebooks_for_medium/blob/main/pyroVAE_MNIST_medium.ipynb
"""
# Pattern capturing runs of digits so re.split keeps the numeric chunks.
numbers = re.compile(r'(\d+)')


def numericalSort(value):
    """Split *value* into alternating text/number chunks for natural sorting.

    ``"img10.png"`` -> ``['img', 10, '.png']`` — suitable as a ``sorted`` key
    so ``img2`` orders before ``img10``.
    """
    chunks = numbers.split(value)
    # Odd positions are the captured digit groups; convert those to ints.
    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(chunks)]
# Suppress warnings emitted by torchvision.datasets during dataloader setup.
warnings.filterwarnings("ignore", module="torchvision.datasets")
def get_crops_from_prediction_csvs(pred_crop_file):
    """Load predicted atom locations from a CSV and cut a crop around each.

    Args:
        pred_crop_file: CSV with at least ``Filename``, ``x``, ``y`` and
            ``Likelihood`` columns. All rows are assumed to refer to the same
            image — only the first ``Filename`` entry is used.

    Returns:
        Tuple ``(crops, coords, likelihood, img_file)`` where ``crops`` is a
        list of crop images, ``coords`` an (N, 2) array of ``[x, y]`` pairs,
        ``likelihood`` the per-detection likelihood array and ``img_file``
        the source image filename.
    """
    data = pd.read_csv(pred_crop_file)
    xx = data['x'].values
    yy = data['y'].values
    likelihood = data['Likelihood'].values
    img_file = data['Filename'][0]

    img_path = os.path.join('data/tif_data', img_file)
    # Context manager closes the underlying file handle (the original leaked it).
    with Image.open(img_path) as raw_img:
        np_img = np.asarray(raw_img).astype(np.float64)
    np_img = dl_prepro_image(np_img)
    img = Image.fromarray(np_img)

    crops = []
    coords_list = []
    for x, y in zip(xx, yy):
        coords_list.append([x, y])
        crops.append(create_crop(img, x, y))

    return crops, np.array(coords_list), likelihood, img_file
def classify_crop_species(args):
    """Cluster detected atom crops into chemical species and write a CSV.

    Pipeline: encode each 21x21 crop with a VAE, fit a GMM with
    ``args.n_species`` components on the latent vectors, order clusters by
    the median mean-intensity of their crops, then append per-atom cluster
    assignments (with confidences) to
    ``data/detection_data/Multimetallic_<image>.csv``.
    """
    # crop_list = get_crops_from_folder(crops_source_folder='./Ni')
    crop_list, crop_coords, likelihood, img_filename = get_crops_from_prediction_csvs(args.pred_crop_file)

    # Stack crops into a single float32 (N, 21, 21) array for the VAE
    # (the original converted each crop twice; once is enough).
    processed_images = np.array([np.array(image) for image in crop_list], dtype=np.float32)

    rvae = rVAE(in_dim=(21, 21), latent_dim=args.latent_dim, coord=args.coord, seed=args.seed)
    train_data = torch.from_numpy(processed_images).float()
    train_loader = init_dataloader(train_data, batch_size=args.batchsize)
    latent_crop_tensor = train_vae(rvae, train_data, train_loader, args)

    gmm = GaussianMixture(
        n_components=args.n_species, reg_covar=args.GMMcovar, random_state=args.seed
    ).fit(latent_crop_tensor)
    preds = gmm.predict(latent_crop_tensor)
    # Confidence of each crop in its own assigned cluster.
    pred_proba = gmm.predict_proba(latent_crop_tensor)
    pred_proba = np.array([pred_proba[i, pred] for i, pred in enumerate(preds)])

    # Order clusters by the median (across crops) mean intensity — a proxy
    # for signal strength of the species.
    cluster_median_values = []
    for k in range(args.n_species):
        relevant_crops = processed_images[preds == k]
        # NOTE(review): this is the per-pixel 95th percentile across crops,
        # zipped element-wise against individual crops below — confirm this
        # filtering is the intended "mean top-5% intensity" criterion.
        crop_95_percentile = np.percentile(relevant_crops, q=95, axis=0)
        img_means = []
        for crop, q in zip(relevant_crops, crop_95_percentile):
            if (crop >= q).any():
                img_means.append(crop.mean())
        cluster_median_values.append(np.median(np.array(img_means)))
    sorted_clusters = sorted((mval, c_id) for c_id, mval in enumerate(cluster_median_values))

    out_path = f"data/detection_data/Multimetallic_{img_filename}.csv"
    # Append results but only write the header for a fresh/empty file, so
    # repeated runs don't interleave duplicate header rows.
    write_header = not os.path.exists(out_path) or os.path.getsize(out_path) == 0
    with open(out_path, "a") as f:
        if write_header:
            f.write("Filename,x,y,Likelihood,cluster,cluster_confidence\n")
        for _, c_id in sorted_clusters:
            mask = preds == c_id
            for (x, y), l, c in zip(crop_coords[mask], likelihood[mask], pred_proba[mask]):
                f.write(f"{img_filename},{x},{y},{l},{c_id},{c}\n")
def train_vae(rvae, train_data, train_loader, args):
    """Train *rvae* with SVI for ``args.epochs`` epochs and return the
    latent means of *train_data* (the encoder's z-mean output)."""
    trainer = SVItrainer(rvae)
    for _epoch in range(args.epochs):
        trainer.step(train_loader, scale_factor=args.scale_factor)
        trainer.print_statistics()
    z_mean, _z_sd = rvae.encode(train_data)
    return z_mean
def get_crops_from_folder(crops_source_folder) -> List[np.ndarray]:
    """Read every 21x21 crop image found under *crops_source_folder*.

    Files within each directory are visited in natural numeric order
    (see ``numericalSort``). Returns a list of (21, 21) arrays.

    NOTE(review): this helper relies on ``rasterio``, whose import is
    commented out at the top of the file — re-enable it before calling.
    """
    crop_paths = []
    for dirname, _dirnames, filenames in os.walk(crops_source_folder):
        for filename in sorted(filenames, key=numericalSort):
            # Keep the full path; the original read from a hard-coded
            # './Ni/' prefix and ignored its own parameter.
            crop_paths.append(os.path.join(dirname, filename))

    all_img = []
    for src_path in crop_paths:
        img = rasterio.open(src_path)
        crop = np.reshape(img.read([1]), (21, 21))
        all_img.append(np.array(crop))
    return all_img
def get_args():
    """Parse command-line options for the multimetallic analysis pipeline."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'pred_crop_file',
        type=str,
        help="Path to the CSV of predicted crop locations (eg in data/detection_data/X/Y.csv)"
    )
    parser.add_argument(
        "-latent_dim",
        type=int,
        default=50,
        # Fixed: help text was a copy-paste leftover ("Experiment extension name").
        help="Dimensionality of the VAE latent space"
    )
    parser.add_argument(
        "-seed",
        type=int,
        default=444,
        help="Random seed"
    )
    parser.add_argument(
        "-coord",
        type=int,
        default=3,
        help="Amount of equivariances, 0: None,1: Rotational, 2: Translational, 3:Rotational and Translational"
    )
    parser.add_argument(
        "-batchsize",
        type=int,
        default=100,
        help="Batch size for the VAE model"
    )
    parser.add_argument(
        "-epochs",
        type=int,
        default=20,
        help="Number of training epochs for the VAE"
    )
    parser.add_argument(
        "-scale_factor",
        type=int,
        default=3,
        # Fixed: help text was copy-pasted from -epochs.
        help="Scale factor applied to the SVI loss at each training step"
    )
    parser.add_argument(
        "-n_species",
        type=int,
        default=2,
        help="Number of chemical species expected in the sample."
    )
    parser.add_argument(
        "-GMMcovar",
        type=float,
        default=0.0001,
        help="Regcovar for the training of the GMM clustering algorithm."
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Entry point: parse CLI options, echo them, then run the VAE + GMM
    # species-assignment pipeline on the given prediction CSV.
    args = get_args()
    print(args)
    classify_crop_species(args)