# run VAE + GMM assignement import argparse from typing import List import numpy as np # import rasterio import torch import warnings import os import re import pandas as pd from PIL import Image from sklearn.mixture import GaussianMixture from atoms_detection.create_crop_dataset import create_crop from atoms_detection.vae_utilities.vae_model import rVAE from atoms_detection.vae_utilities.vae_svi_train import init_dataloader, SVItrainer from atoms_detection.image_preprocessing import dl_prepro_image """ Code sourced from: https://colab.research.google.com/github/ziatdinovmax/notebooks_for_medium/blob/main/pyroVAE_MNIST_medium.ipynb """ numbers = re.compile(r'(\d+)') def numericalSort(value): parts = numbers.split(value) parts[1::2] = map(int, parts[1::2]) return parts warnings.filterwarnings("ignore", module="torchvision.datasets") def get_crops_from_prediction_csvs(pred_crop_file): data = pd.read_csv(pred_crop_file) xx = data['x'].values yy = data['y'].values coords = zip(xx,yy) img_file = data['Filename'][0] likelihood = data['Likelihood'].values img_path = os.path.join('data/tif_data', img_file) img = Image.open(img_path) np_img = np.asarray(img).astype(np.float64) np_img = dl_prepro_image(np_img) img = Image.fromarray(np_img) crops = list() coords_list = [] for x, y in coords: coords_list.append([x,y]) new_crop = create_crop(img, x, y) crops.append(new_crop) print(coords_list[0]) print(np_img[0]) coords_array = np.array(coords_list) return crops, coords_array, likelihood, img_file def classify_crop_species(args): # crop_list = get_crops_from_folder(crops_source_folder='./Ni') crop_list, crop_coords, likelihood, img_filename = get_crops_from_prediction_csvs(args.pred_crop_file) crop_tensor = np.array(crop_list) # Assuming crop_tensor is a list or array of Image objects processed_images = [] for image in crop_tensor: # Convert the Image to a NumPy array image_array = np.array(image) # Append the processed image array to the list processed_images.append(image_array) # Convert the processed images list to a NumPy array processed_images = np.array(processed_images) # Convert the processed_images array to float32 processed_images = processed_images.astype(np.float32) #print(processed_images.shape) rvae = rVAE(in_dim=(21, 21), latent_dim=args.latent_dim, coord=args.coord, seed=args.seed) train_data = torch.from_numpy(processed_images).float() # train_data = torch.from_numpy(crop_tensor).float() train_loader = init_dataloader(train_data, batch_size=args.batchsize) latent_crop_tensor = train_vae(rvae, train_data, train_loader, args) gmm = GaussianMixture(n_components=args.n_species, reg_covar=args.GMMcovar, random_state=args.seed).fit( latent_crop_tensor) preds = gmm.predict(latent_crop_tensor) print(preds) pred_proba = gmm.predict_proba(latent_crop_tensor) pred_proba = [pred_proba[i, pred] for i, pred in enumerate(preds)] # To order clusters, signal-to-noise ratio OR median (across crops) of some intensity quality (eg mean top-5% int) cluster_median_values = list() for k in range(args.n_species): print(k) relevant_crops = processed_images[preds == k] crop_95_percentile = np.percentile(relevant_crops, q=95, axis=0) img_means = [] for crop, q in zip(relevant_crops, crop_95_percentile): if (crop >= q).any(): print(crop.mean()) img_means.append(crop.mean()) #img_means.append(crop.mean(axis=0, where=crop >= q)) cluster_median_value = np.median(np.array(img_means)) cluster_median_values.append(cluster_median_value) sorted_clusters = sorted([(mval, c_id) for c_id, mval in enumerate(cluster_median_values)]) with open(f"data/detection_data/Multimetallic_{img_filename}.csv", "a") as f: f.write("Filename,x,y,Likelihood,cluster,cluster_confidence\n") for _, c_id in sorted_clusters: c_idd = np.array([c_id]) pred_proba = np.array(pred_proba) relevant_crops_coords = crop_coords[preds == c_idd] relevant_crops_likelihood = likelihood[preds == c_idd] relevant_crops_confidence = pred_proba[preds == c_idd] #print(relevant_crops_confidence) for coords, l, c in zip(relevant_crops_coords, relevant_crops_likelihood, relevant_crops_confidence): x, y = coords f.write(f"{img_filename},{x},{y},{l},{c_id},{c}\n") def train_vae(rvae, train_data, train_loader, args): # Initialize SVI trainer trainer = SVItrainer(rvae) for e in range(args.epochs): trainer.step(train_loader, scale_factor=args.scale_factor) trainer.print_statistics() z_mean, z_sd = rvae.encode(train_data) latent_crop_tensor = z_mean return latent_crop_tensor def get_crops_from_folder(crops_source_folder) -> List[np.ndarray]: ffiles = [] files = [] for dirname, dirnames, filenames in os.walk(crops_source_folder): # print path to all subdirectories first. for subdirname in dirnames: files.append(os.path.join(dirname, subdirname)) # print path to all filenames. for filename in filenames: files.append(os.path.join(dirname, filename)) for filename in sorted((filenames), key=numericalSort): ffiles.append(os.path.join(filename)) crops = ffiles # print(len(crops)) path_crops = './Ni/' all_img = [] for i in range(0, len(crops)): src_path = path_crops + crops[i] img = rasterio.open(src_path) test = np.reshape(img.read([1]), (21, 21)) all_img.append(np.array(test)) return all_img def get_args(): parser = argparse.ArgumentParser() parser.add_argument( 'pred_crop_file', type=str, help="Path to the CSV of predicted crop locations (eg in data/detection_data/X/Y.csv)" ) parser.add_argument( "-latent_dim", type=int, default=50, help="Experiment extension name" ) parser.add_argument( "-seed", type=int, default=444, help="Random seed" ) parser.add_argument( "-coord", type=int, default=3, help="Amount of equivariances, 0: None,1: Rotational, 2: Translational, 3:Rotational and Translational" ) parser.add_argument( "-batchsize", type=int, default=100, help="Batch size for the VAE model" ) parser.add_argument( "-epochs", type=int, default=20, help="Number of training epochs for the VAE" ) parser.add_argument( "-scale_factor", type=int, default=3, help="Number of training epochs for the VAE" ) parser.add_argument( "-n_species", type=int, default=2, help="Number of chemical species expected in the sample." ) parser.add_argument( "-GMMcovar", type=float, default=0.0001, help="Regcovar for the training of the GMM clustering algorithm." ) return parser.parse_args() if __name__ == "__main__": args = get_args() print(args) classify_crop_species(args)