atom-detection / atoms_detection /dl_contrastive_pipeline.py
Romain Graux
Initial commit with ml code and webapp
b2ffc9b
import argparse
import os
from typing import List, Optional

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from PIL import Image

from atoms_detection.create_crop_dataset import create_contrastive_crops_dataset
from atoms_detection.dl_detection import DLDetection
from atoms_detection.dl_detection_with_gmm import DLGMMdetection
from atoms_detection.evaluation import Evaluation
from atoms_detection.training_model import train_model
from utils.constants import ModelArgs, Split
from utils.paths import (
    CROPS_PATH,
    CROPS_DATASET,
    MODELS_PATH,
    LOGS_PATH,
    DETECTION_PATH,
    PREDS_PATH,
    PRED_GT_VIS_PATH,
)
from visualizations.prediction_gt_images import get_gt_coords
from visualizations.utils import plot_gt_pred_on_img
def _save_prediction_visualisations(detection, evaluation, detections_path, vis_folder):
    """Save one ground-truth/prediction overlay PNG per test image.

    Args:
        detection: Detection object that has been run; its ``image_dataset``
            is iterated over the ``Split.TEST`` image paths.
        evaluation: Evaluation object that has been run; its
            ``coordinates_dataset`` is passed to ``get_gt_coords``.
        detections_path: Folder containing one ``<image stem>.csv`` per image
            with ``x`` and ``y`` columns holding the predicted coordinates.
        vis_folder: Existing folder that receives the ``<image stem>.png``
            files.
    """
    gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset)
    for image_path in detection.image_dataset.iterate_data(Split.TEST):
        img_name = os.path.basename(image_path)
        image_stem = os.path.splitext(img_name)[0]
        gt_coords = gt_coords_dict[img_name]

        df_predicted = pd.read_csv(os.path.join(detections_path, image_stem + ".csv"))
        pred_coords = list(zip(df_predicted["x"], df_predicted["y"]))

        # Close the image file as soon as the pixels are copied into an array.
        with Image.open(image_path) as img:
            img_arr = np.array(img).astype(np.float32)

        # Min-max normalise; a constant image has no range to scale by, so it
        # is mapped to zeros instead of dividing by zero.
        lowest = img_arr.min()
        value_range = img_arr.max() - lowest
        if value_range > 0:
            img_normed = (img_arr - lowest) / value_range
        else:
            img_normed = np.zeros_like(img_arr)

        plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
        vis_path = os.path.join(vis_folder, f"{image_stem}.png")
        plt.savefig(vis_path, bbox_inches="tight", pad_inches=0.0, transparent=True)
        plt.close()


def dl_full_pipeline(
    extension_name: str,
    architecture: ModelArgs,
    coords_csv: str,
    thresholds_list: List[float],
    force_create_dataset: bool = False,
    force_evaluation: bool = False,
    show_sampling_image: bool = False,
    train: bool = False,
    visualise: bool = False,
    upsample: bool = False,
    upsample_neg_amount: float = 0,
    clip_max: float = 1,
    negative_dist: float = 1.1,
    run_gmm_for_multimers: Optional[bool] = None,
) -> None:
    """Run the contrastive-crops pipeline: dataset, training, detection, evaluation.

    The crops dataset is (re)created when forced or when its CSV is missing,
    the model is trained when requested or when its checkpoint is missing, and
    then detection, evaluation and optional visualisation are run once per
    threshold.

    Args:
        extension_name: Experiment suffix used in every derived folder and
            file name.
        architecture: Model architecture passed to training and detection.
        coords_csv: Coordinates CSV used to build crops, detect and evaluate.
        thresholds_list: Detection thresholds to process, one run each. An
            empty list or ``None`` skips detection and evaluation.
        force_create_dataset: Recreate the crops dataset even if it exists.
        force_evaluation: Re-run detection and evaluation even if their
            outputs exist.
        show_sampling_image: Forwarded as ``show_sampling_result`` to the
            crops dataset creation.
        train: Train the model even if a checkpoint already exists.
        visualise: Save ground-truth/prediction overlays; this also forces
            detection and evaluation to run.
        upsample: Forwarded as ``pos_data_upsampling``.
        upsample_neg_amount: Forwarded as ``neg_upsample_multiplier``.
        clip_max: Currently unused by this function.
        negative_dist: Forwarded as ``contrastive_distance_multiplier``.
        run_gmm_for_multimers: Use ``DLGMMdetection`` instead of
            ``DLDetection``. When left as ``None`` the value is taken from a
            module-level ``args`` namespace if one exists, otherwise False.
    """
    # Create crops data
    crops_folder = CROPS_PATH + f"_{extension_name}"
    crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv")
    print(os.path.exists(crops_dataset))
    if force_create_dataset or not os.path.exists(crops_dataset):
        print("Creating crops dataset...")
        create_contrastive_crops_dataset(
            crops_folder,
            coords_csv,
            crops_dataset,
            show_sampling_result=show_sampling_image,
            pos_data_upsampling=upsample,
            neg_upsample_multiplier=upsample_neg_amount,
            contrastive_distance_multiplier=negative_dist,
        )  # , clip=clip_max

    # training DL model
    ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt")
    if train or not os.path.exists(ckpt_filename):
        print("Training DL crops model...")
        train_model(architecture, crops_dataset, crops_folder, ckpt_filename)

    if not thresholds_list:
        # The CLI yields None when no threshold option is given; iterating it
        # would raise a TypeError after the dataset/training work is done.
        print("No thresholds given; skipping detection and evaluation.")
        return

    if run_gmm_for_multimers is None:
        # Backward compatibility: this function used to read the flag from the
        # script's global ``args``, which raised NameError when the module was
        # imported. Keep honouring that global when it exists.
        cli_args = globals().get("args")
        run_gmm_for_multimers = bool(
            getattr(cli_args, "run_gmm_for_multimers", False)
        )
    detection_pipeline = DLGMMdetection if run_gmm_for_multimers else DLDetection

    # The inference cache path does not depend on the threshold.
    inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}")

    for threshold in thresholds_list:
        detections_path = os.path.join(
            DETECTION_PATH,
            f"dl_detection_{extension_name}",
            f"dl_detection_{extension_name}_{threshold}",
        )
        if force_evaluation or visualise or not os.path.exists(detections_path):
            print(f"Detecting atoms on test data with threshold={threshold}...")
            detection = detection_pipeline(
                model_name=architecture,
                ckpt_filename=ckpt_filename,
                dataset_csv=coords_csv,
                threshold=threshold,
                detections_path=detections_path,
                inference_cache_path=inference_cache_path,
            )
            detection.run()

        logging_filename = os.path.join(
            LOGS_PATH,
            f"dl_evaluation_{extension_name}",
            f"dl_evaluation_{extension_name}_{threshold}.csv",
        )
        if force_evaluation or visualise or not os.path.exists(logging_filename):
            evaluation = Evaluation(
                coords_csv=coords_csv,
                predictions_path=detections_path,
                logging_filename=logging_filename,
            )
            evaluation.run()

        if visualise:
            # ``visualise`` forces both branches above, so ``detection`` and
            # ``evaluation`` are always bound for this threshold here.
            vis_folder = os.path.join(
                PRED_GT_VIS_PATH,
                f"dl_detection_{extension_name}",
                f"dl_detection_{extension_name}_{threshold}",
            )
            os.makedirs(vis_folder, exist_ok=True)
            _save_prediction_visualisations(
                detection, evaluation, detections_path, vis_folder
            )
def get_args():
    """Parse the command-line arguments of the pipeline script.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``. The threshold
        values are stored under the attribute ``t__thresholds`` (``None`` when
        the option is not given); ``-c`` and ``-nd`` are stored as ``c`` and
        ``nd``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("extension_name", type=str, help="Experiment extension name")
    parser.add_argument(
        "architecture", type=ModelArgs, choices=ModelArgs, help="Architecture name"
    )
    parser.add_argument(
        "coords_csv", type=str, help="Coordinates CSV file to use as input"
    )
    # "-t" and "--thresholds" used to be written without a comma between them,
    # which Python concatenated into the single flag "-t--thresholds". They are
    # now two real flags; ``dest`` keeps the attribute name that the old flag
    # produced so existing readers of ``args.t__thresholds`` keep working.
    parser.add_argument(
        "-t",
        "--thresholds",
        dest="t__thresholds",
        metavar="THRESHOLD",
        nargs="+",
        type=float,
        help="Threshold value",
    )
    # Hidden alias so command lines using the accidental spelling still parse.
    parser.add_argument(
        "-t--thresholds",
        dest="t__thresholds",
        nargs="+",
        type=float,
        help=argparse.SUPPRESS,
    )
    parser.add_argument(
        "-c", type=float, default=1, help="Clipping quantile (0..1]. CURRENTLY USELESS!"
    )
    parser.add_argument(
        "-nd", type=float, default=1.1, help="Negative contrastive crop distance"
    )
    parser.add_argument("--force_create_dataset", action="store_true")
    parser.add_argument("--force_evaluation", action="store_true")
    parser.add_argument("--show_sampling_result", action="store_true")
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--visualise", action="store_true")
    parser.add_argument("--upsample", action="store_true")
    parser.add_argument(
        "--run_gmm_for_multimers",
        action="store_true",
        help="If selected, a postprocessing will be run to split large atoms (possible multimers) with a GMM",
    )
    parser.add_argument(
        "--upsample_neg",
        type=float,
        default=0,
        help="Upsample amount for negative crops during training",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the CLI, echo the parsed namespace, then run
    # the pipeline. The namespace must stay bound to the module-level name
    # ``args`` because dl_full_pipeline looks that name up for the
    # --run_gmm_for_multimers flag.
    args = get_args()
    print(args)
    dl_full_pipeline(
        extension_name=args.extension_name,
        architecture=args.architecture,
        coords_csv=args.coords_csv,
        # Threshold values are stored under ``t__thresholds`` in the namespace.
        thresholds_list=args.t__thresholds,
        force_create_dataset=args.force_create_dataset,
        force_evaluation=args.force_evaluation,
        show_sampling_image=args.show_sampling_result,
        train=args.train,
        visualise=args.visualise,
        upsample=args.upsample,
        upsample_neg_amount=args.upsample_neg,
        clip_max=args.c,
        negative_dist=args.nd,
    )