# run VAE + GMM assignment
import argparse
import os
import re
import warnings
from typing import List

import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.mixture import GaussianMixture
# import rasterio  # only needed by get_crops_from_folder (reads crops from GeoTIFF files)

from atoms_detection.create_crop_dataset import create_crop
from atoms_detection.vae_utilities.vae_model import rVAE
from atoms_detection.vae_utilities.vae_svi_train import init_dataloader, SVItrainer
from atoms_detection.image_preprocessing import dl_prepro_image

""" | |
Code sourced from: | |
https://colab.research.google.com/github/ziatdinovmax/notebooks_for_medium/blob/main/pyroVAE_MNIST_medium.ipynb | |
""" | |
numbers = re.compile(r'(\d+)')


def numericalSort(value):
    """Sort key that orders strings by their embedded integers (e.g. 'crop_2' before 'crop_10')."""
    parts = numbers.split(value)
    parts[1::2] = map(int, parts[1::2])
    return parts


warnings.filterwarnings("ignore", module="torchvision.datasets")

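# The prediction CSV read below is expected to contain at least the columns
# Filename, x, y and Likelihood (one row per detected atom); the corresponding
# image is loaded from data/tif_data/<Filename>. Illustrative example row:
#   Filename,x,y,Likelihood
#   sample_image.tif,123,456,0.97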
def get_crops_from_prediction_csvs(pred_crop_file):
    data = pd.read_csv(pred_crop_file)
    xx = data['x'].values
    yy = data['y'].values
    coords = zip(xx, yy)
    img_file = data['Filename'][0]
    likelihood = data['Likelihood'].values
    # Load and preprocess the full image, then cut a crop around each predicted position
    img_path = os.path.join('data/tif_data', img_file)
    img = Image.open(img_path)
    np_img = np.asarray(img).astype(np.float64)
    np_img = dl_prepro_image(np_img)
    img = Image.fromarray(np_img)
    crops = list()
    coords_list = []
    for x, y in coords:
        coords_list.append([x, y])
        new_crop = create_crop(img, x, y)
        crops.append(new_crop)
    coords_array = np.array(coords_list)
    return crops, coords_array, likelihood, img_file

def classify_crop_species(args):
    # crop_list = get_crops_from_folder(crops_source_folder='./Ni')
    crop_list, crop_coords, likelihood, img_filename = get_crops_from_prediction_csvs(args.pred_crop_file)

    # Convert each crop (a PIL Image) to a float32 NumPy array
    processed_images = np.array([np.asarray(crop, dtype=np.float32) for crop in crop_list])

    # Encode the crops into latent space with the rVAE
    rvae = rVAE(in_dim=(21, 21), latent_dim=args.latent_dim, coord=args.coord, seed=args.seed)
    train_data = torch.from_numpy(processed_images).float()
    train_loader = init_dataloader(train_data, batch_size=args.batchsize)
    latent_crop_tensor = train_vae(rvae, train_data, train_loader, args)

    # Cluster the latent representations, one GMM component per expected species
    gmm = GaussianMixture(n_components=args.n_species, reg_covar=args.GMMcovar, random_state=args.seed).fit(
        latent_crop_tensor)
    preds = gmm.predict(latent_crop_tensor)
    print(preds)
    pred_proba = gmm.predict_proba(latent_crop_tensor)
    # Confidence of the assigned cluster for each crop
    pred_proba = np.array([pred_proba[i, pred] for i, pred in enumerate(preds)])

    # To order clusters, use e.g. the signal-to-noise ratio or the median (across crops)
    # of some intensity statistic; here the median of the per-crop mean intensities is used.
    cluster_median_values = list()
    for k in range(args.n_species):
        relevant_crops = processed_images[preds == k]
        img_means = [crop.mean() for crop in relevant_crops]
        # Alternative statistic: mean of each crop's top-5% brightest pixels
        # img_means = [crop[crop >= np.percentile(crop, 95)].mean() for crop in relevant_crops]
        cluster_median_values.append(np.median(np.array(img_means)))
    sorted_clusters = sorted([(mval, c_id) for c_id, mval in enumerate(cluster_median_values)])

    # Write one row per detected atom, grouped by cluster in order of median intensity
    with open(f"data/detection_data/Multimetallic_{img_filename}.csv", "a") as f:
        f.write("Filename,x,y,Likelihood,cluster,cluster_confidence\n")
        for _, c_id in sorted_clusters:
            mask = preds == c_id
            relevant_crops_coords = crop_coords[mask]
            relevant_crops_likelihood = likelihood[mask]
            relevant_crops_confidence = pred_proba[mask]
            for coords, l, c in zip(relevant_crops_coords, relevant_crops_likelihood, relevant_crops_confidence):
                x, y = coords
                f.write(f"{img_filename},{x},{y},{l},{c_id},{c}\n")

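# classify_crop_species writes one row per detected atom; an illustrative output row:
#   Filename,x,y,Likelihood,cluster,cluster_confidence
#   sample_image.tif,123,456,0.97,0,0.95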
def train_vae(rvae, train_data, train_loader, args):
    # Initialize the SVI trainer and run the requested number of epochs
    trainer = SVItrainer(rvae)
    for e in range(args.epochs):
        trainer.step(train_loader, scale_factor=args.scale_factor)
        trainer.print_statistics()
    # Use the latent means as the per-crop representation for clustering
    z_mean, z_sd = rvae.encode(train_data)
    latent_crop_tensor = z_mean
    return latent_crop_tensor

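# Alternative input path (currently unused; see the commented-out call in
# classify_crop_species): read pre-cut crops stored as GeoTIFF files in a folder
# via rasterio, sorted by the integer embedded in each file name.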
def get_crops_from_folder(crops_source_folder) -> List[np.ndarray]:
    import rasterio  # imported here so the CSV-based pipeline does not require rasterio

    ffiles = []
    for dirname, dirnames, filenames in os.walk(crops_source_folder):
        for filename in sorted(filenames, key=numericalSort):
            ffiles.append(filename)
    all_img = []
    for filename in ffiles:
        src_path = os.path.join(crops_source_folder, filename)
        img = rasterio.open(src_path)
        crop = np.reshape(img.read([1]), (21, 21))
        all_img.append(np.array(crop))
    return all_img

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'pred_crop_file',
        type=str,
        help="Path to the CSV of predicted crop locations (e.g. data/detection_data/X/Y.csv)"
    )
    parser.add_argument(
        "-latent_dim",
        type=int,
        default=50,
        help="Dimensionality of the VAE latent space"
    )
    parser.add_argument(
        "-seed",
        type=int,
        default=444,
        help="Random seed"
    )
    parser.add_argument(
        "-coord",
        type=int,
        default=3,
        help="Equivariances to use: 0: none, 1: rotational, 2: translational, 3: rotational and translational"
    )
    parser.add_argument(
        "-batchsize",
        type=int,
        default=100,
        help="Batch size for the VAE model"
    )
    parser.add_argument(
        "-epochs",
        type=int,
        default=20,
        help="Number of training epochs for the VAE"
    )
    parser.add_argument(
        "-scale_factor",
        type=int,
        default=3,
        help="Scale factor passed to the SVI training step"
    )
    parser.add_argument(
        "-n_species",
        type=int,
        default=2,
        help="Number of chemical species expected in the sample"
    )
    parser.add_argument(
        "-GMMcovar",
        type=float,
        default=0.0001,
        help="Covariance regularization (reg_covar) for the GMM clustering"
    )
    return parser.parse_args()

if __name__ == "__main__":
    args = get_args()
    print(args)
    classify_crop_species(args)
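# Example invocation (script name and CSV path are illustrative only):
#   python run_vae_gmm_assignment.py data/detection_data/experiment/predictions.csv -n_species 2 -epochs 20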