import argparse import os from enum import Enum import numpy as np import pandas as pd from PIL import Image import matplotlib.pyplot as plt from atoms_detection.dataset import CoordinatesDataset from utils.constants import Split, Catalyst, Method from utils.paths import DETECTION_LOGS, IMG_PATH, PRED_GT_VIS_PATH, PT_DATASET, FE_DATASET, DETECTION_PATH from visualizations.utils import plot_gt_pred_on_img def main(args): catalyst = args.catalyst method = args.method if not os.path.exists(PRED_GT_VIS_PATH): os.makedirs(PRED_GT_VIS_PATH) if catalyst == Catalyst.Pt: coordinates_dataset = CoordinatesDataset(PT_DATASET) if method == Method.DL: detection_path = "data/detection_data/dl_detection_sac_cnn/dl_detection_sac_cnn_0.89" elif method == Method.CV: detection_path = os.path.join(DETECTION_PATH, "cv_detection_trial_0.18") elif method == Method.TEM: detection_path = os.path.join(DETECTION_PATH, "tem_imagenet_pt", "tem_imagenet_pt_denoise-bg_Gen1GaussianMask") else: raise NotImplementedError elif catalyst == Catalyst.Fe: coordinates_dataset = CoordinatesDataset(FE_DATASET) if method == Method.DL: detection_path = os.path.join(DETECTION_PATH, f"dl_fe_detection_trial", f"dl_fe_detection_trial_0.97") elif method == Method.CV: detection_path = os.path.join(DETECTION_PATH, "cv_fe_detection_trial", "cv_fe_detection_trial_0.21") elif method == Method.TEM: detection_path = os.path.join(DETECTION_PATH, "tem_imagenet_fe", "tem_imagenet_fe_denoise-bg_Gen1GaussianMask") else: raise NotImplementedError else: raise NotImplementedError gt_coords_dict = get_gt_coords(coordinates_dataset) for name_file in os.listdir(detection_path): image_name = os.path.splitext(name_file)[0] + ".tif" print(image_name) if image_name not in gt_coords_dict: continue filepath = os.path.join(detection_path, name_file) image_filename = os.path.join(IMG_PATH, image_name) img = Image.open(image_filename) gt_coords = gt_coords_dict[image_name] df_predicted = pd.read_csv(filepath) pred_coords = [(row['x'], row['y']) for _, row in df_predicted.iterrows()] img_arr = np.array(img).astype(np.float32) img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min()) plot_gt_pred_on_img(img_normed, gt_coords, pred_coords) vis_folder = os.path.join(PRED_GT_VIS_PATH, f"{catalyst}-Catalyst_{method}-Method") if not os.path.exists(vis_folder): os.makedirs(vis_folder) clean_image_name = os.path.splitext(image_name)[0] vis_path = os.path.join(vis_folder, f'{clean_image_name}.png') plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True) plt.close() def get_gt_coords(coordinates_dataset): gt_coords_dict = {} for image_path, coordinates_path in coordinates_dataset.iterate_data(Split.TEST): # orig . image_name = os.path.splitext(os.path.basename(image_path))[0] + ".tif" image_name = os.path.basename(image_path) gt_coords = coordinates_dataset.load_coordinates(coordinates_path) gt_coords_dict[image_name] = gt_coords return gt_coords_dict def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "catalyst", type=Catalyst, choices=Catalyst, help="Select data by catalyst" ) parser.add_argument( "method", type=Method, choices=Method, help="Select method" ) return parser.parse_args() if __name__ == "__main__": args = get_args() main(args)