diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fb1022f945709054fe074e2866505f1ab30bf376 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +# *.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +data +models diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000000000000000000000000000000000..bca14f339e49f92fd16f3b819382171de9aea05c --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + } +} diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bfcbd61d18be86cd8f2c0a735ff335d6633a2f8e --- /dev/null +++ b/README.md @@ -0,0 +1,11 @@ +# Atoms Detection + +This contains the code and webapp for the publication *[Quantitative Description of Metal Center Organization and Interactions in Single-Atom Catalysts](https://doi.org/10.1002/adma.202307991)*. + +## Reference + +To cite this work, please use the following: + +``` +K. Rossi, A. Ruiz-Ferrando, D. F. Akl, V. G. Abalos, J. Heras-Domingo, R. Graux, X. Hai, J. Lu, D. Garcia-Gasulla, N. López, J. Pérez-Ramírez, S. Mitchell, Quantitative Description of Metal Center Organization and Interactions in Single-Atom Catalysts. Adv. Mater. 2024, 36, 2307991. https://doi.org/10.1002/adma.202307991 +``` diff --git a/app/app.py b/app/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d751c52b40ebbe04a5c4ada39e09b556fc3b6e90 --- /dev/null +++ b/app/app.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@author : Romain Graux +@date : 2023 April 25, 14:39:03 +@last modified : 2023 September 20, 15:35:23 +""" + +# TODO : add the training of the vae +# TODO : add the description of the settings + +import sys + +import numpy as np +from PIL import Image, ImageDraw +import gradio as gr +from tiff_utils import extract_physical_metadata +from dl_inference import inference_fn +from knn import knn, segment_image, bokeh_plot_knn, color_palette + +import tempfile +import shutil +import json +from zipfile import ZipFile +from datetime import datetime + +from collections import namedtuple + +block_state_entry = namedtuple( + "block_state", ["results", "knn_results", "physical_metadata"] +) + +if ".." 
not in sys.path: + sys.path.append("..") + +from utils.constants import ModelArgs + + +def inf(img, n_species, threshold, architecture): + # Get the coordinates of the atoms + img, results = inference_fn(architecture, img, threshold, n_species=n_species) + draw = ImageDraw.Draw(img) + for (k, v), color in zip(results["species"].items(), color_palette): + color = "#" + "".join([f"{int(255 * x):02x}" for x in color]) + draw.text((5, 5 + 15 * k), f"species {k}", fill=color) + for x, y in v["coords"]: + draw.ellipse( + [x - 5, y - 5, x + 5, y + 5], + outline=color, + width=2, + ) + return img, results + + +def batch_fn(files, n_species, threshold, architecture, block_state): + block_state = {} + if not files: + raise ValueError("No files were uploaded") + + gallery = [] + for file in files: + error_physical_metadata = None + try: + physical_metadata = extract_physical_metadata(file.name) + if physical_metadata.unit != "nm": + raise ValueError(f"Unit of {file.name} is not nm, cannot process it") + except Exception as e: + error_physical_metadata = e + physical_metadata = None + + original_file_name = file.name.split("/")[-1] + img, results = inf(file.name, n_species, threshold, architecture) + mask = segment_image(file.name) + gallery.append((img, original_file_name)) + + if physical_metadata is not None: + factor = 1.0 - np.mean(mask) + scale = physical_metadata.pixel_width + edge = physical_metadata.pixel_width * physical_metadata.width + knn_results = { + k: knn(results["species"][k]["coords"], scale, factor, edge) + for k in results["species"] + } + else: + knn_results = None + + block_state[original_file_name] = block_state_entry( + results, knn_results, physical_metadata + ) + + knn_args = [ + ( + original_file_name, + { + k: block_state[original_file_name].knn_results[k]["distances"] + for k in block_state[original_file_name].knn_results + }, + ) + for original_file_name in block_state + if block_state[original_file_name].knn_results is not None + ] + if len(knn_args) > 0: + bokeh_plot = gr.update( + value=bokeh_plot_knn(knn_args, with_cumulative=True), visible=True + ) + else: + bokeh_plot = gr.update(visible=False) + return ( + gallery, + block_state, + gr.update(visible=True), + bokeh_plot, + gr.HTML.update( + value=f"
{error_physical_metadata}
", + visible=bool(error_physical_metadata), + ), + ) + + +class NumpyEncoder(json.JSONEncoder): + """Special json encoder for numpy types""" + + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + +def batch_export_files(gallery, block_state): + # Return images, coords as csv and a zip containing everything + files = [] + tmpdir = tempfile.mkdtemp() + with ZipFile( + f"{tmpdir}/all_results_{datetime.now().isoformat()}.zip", "w" + ) as zipObj: + # Add all metatada + for data_dict, original_file_name in gallery: + file_name = original_file_name.split(".")[0] + + # Save the image + pred_map_path = f"{tmpdir}/pred_map_{file_name}.png" + file_path = data_dict["name"] + shutil.copy(file_path, pred_map_path) + zipObj.write(pred_map_path, arcname=f"{file_name}/pred_map.png") + files.append(pred_map_path) + + # Save the coords + results = block_state[original_file_name].results + coords_path = f"{tmpdir}/coords_{file_name}.csv" + with open(coords_path, "w") as f: + f.write("x,y,likelihood,specie,confidence\n") + for k, v in results["species"].items(): + for (x, y), likelihood, confidence in zip( + v["coords"], v["likelihood"], v["confidence"] + ): + f.write(f"{x},{y},{likelihood},{k},{confidence}\n") + zipObj.write(coords_path, arcname=f"{file_name}/coords.csv") + files.append(coords_path) + + # Save the knn results + if block_state[original_file_name].knn_results is not None: + knn_results = block_state[original_file_name].knn_results + knn_path = f"{tmpdir}/knn_results_{file_name}.json" + with open(knn_path, "w") as f: + json.dump(knn_results, f, cls=NumpyEncoder) + zipObj.write(knn_path, arcname=f"{file_name}/knn_results.json") + files.append(knn_path) + + # Save the physical metadata + if block_state[original_file_name].physical_metadata is not None: + physical_metadata = block_state[original_file_name].physical_metadata + metadata_path = f"{tmpdir}/physical_metadata_{file_name}.json" + with open(metadata_path, "w") as f: + json.dump(physical_metadata._asdict(), f, cls=NumpyEncoder) + zipObj.write( + metadata_path, arcname=f"{file_name}/physical_metadata.json" + ) + files.append(metadata_path) + + files.append(zipObj.filename) + return gr.update(value=files, visible=True) + + +CSS = """ + .header { + display: flex; + justify-content: center; + align-items: center; + padding: var(--block-padding); + border-radius: var(--block-radius); + background: var(--button-secondary-background-hover); + } + + img { + width: 150px; + margin-right: 40px; + } + + .title { + text-align: left; + } + + h1 { + font-size: 36px; + margin-bottom: 10px; + } + + p { + font-size: 18px; + } + + input { + width: 70px; + } + + @media (max-width: 600px) { + h1 { + font-size: 24px; + } + + p { + font-size: 14px; + } + } + +""" + + +with gr.Blocks(css=CSS) as block: + block_state = gr.State({}) + gr.HTML( + """ +
+            <div class="header">
+                NCCR Catalysis
+                <div class="title">
+                    <h1>Atom Detection</h1>
+                    <p>Quantitative description of metal center organization in single-atom catalysts</p>
+                </div>
+            </div>
+ """ + ) + with gr.Row(): + with gr.Column(): + with gr.Row(): + n_species = gr.Number( + label="Number of species", + min=1, + max=10, + value=1, + step=1, + precision=0, + visible=True, + ) + threshold = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.8, + label="Threshold", + visible=True, + ) + architecture = gr.Dropdown( + label="Architecture", + choices=[ + ModelArgs.BASICCNN, + # ModelArgs.RESNET18, + ], + value=ModelArgs.BASICCNN, + visible=False, + ) + files = gr.Files( + label="Images", + file_types=[".tif", ".tiff"], + type="file", + interactive=True, + ) + button = gr.Button(value="Run") + with gr.Column(): + with gr.Tab("Masked prediction") as masked_tab: + masked_prediction_gallery = gr.Gallery( + label="Masked predictions" + ).style(columns=3) + with gr.Tab("Nearest neighbors") as nn_tab: + bokeh_plot = gr.Plot(show_label=False) + error_html = gr.HTML(visible=False) + export_btn = gr.Button(value="Export files", visible=False) + exported_files = gr.File( + label="Exported files", + file_count="multiple", + type="file", + interactive=False, + visible=False, + ) + button.click( + batch_fn, + inputs=[files, n_species, threshold, architecture, block_state], + outputs=[ + masked_prediction_gallery, + block_state, + export_btn, + bokeh_plot, + error_html, + ], + ) + export_btn.click( + batch_export_files, [masked_prediction_gallery, block_state], [exported_files] + ) + with gr.Accordion(label="How to ✨", open=True): + gr.HTML( + """ +
+            <div>
+                <p>
+                    1. Select one or multiple microscopy images as .tiff files 📷🔬<br>
+                    2. Upload individual or multiple .tif images for processing 📤🔢<br>
+                    3. Export the output files. The generated zip archive will contain:<br>
+                    &emsp;• An image with overlaid atomic positions 🌟🔍<br>
+                    &emsp;• A table of atomic positions (in px) along with their probability 📊💎<br>
+                    &emsp;• Physical metadata of the respective images 📄🔍<br>
+                    &emsp;• JSON-formatted plot data 📊📝
+                </p>
+                <p>
+                    <b>Note</b>
+                </p>
+            </div>
+ """ + ) + with gr.Accordion(label="Disclaimer and License", open=False): + gr.HTML( + """ +
+            <div>
+                <h2>Disclaimer</h2>
+                <p>
+                    NCCR licenses the Atom Detection Web-App utilisation “as is” with no express or implied warranty of any kind. NCCR specifically disclaims all express or implied warranties to the fullest extent allowed by applicable law, including without limitation all implied warranties of merchantability, title or fitness for any particular purpose or non-infringement. No oral or written information or advice given by the authors shall create or form the basis of any warranty of any kind.
+                </p>
+                <h2>License</h2>
+                <p>
+                    Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+                </p>
+                <p>
+                    The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+                </p>
+                <p>
+                    The software is provided “as is”, without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose and noninfringement. In no event shall the authors or copyright holders be liable for any claim, damages or other liability, whether in an action of contract, tort or otherwise, arising from, out of or in connection with the software or the use or other dealings in the software.
+                </p>
+            </div>
+ """ + ) + gr.HTML( + """ +
+            <p>
+                To reference the use of this web app in a publication, please refer to the Atom Detection web app and the development described in this publication: K. Rossi et al. Adv. Mater. 2023, doi:10.1002/adma.202307991.
+            </p>
+ """ + ) + + +block.launch( + share=False, + show_error=True, + server_name="0.0.0.0", + server_port=9003, + enable_queue=True, +) diff --git a/app/assets/ETH_Zurich_Logo_black.svg b/app/assets/ETH_Zurich_Logo_black.svg new file mode 100644 index 0000000000000000000000000000000000000000..637fd85c7ef5514a7d5d3373073aa45e2a686bd4 --- /dev/null +++ b/app/assets/ETH_Zurich_Logo_black.svg @@ -0,0 +1,58 @@ + + + + + + image/svg+xml + + ethz_logo_white + + + + + + + ethz_logo_white + + diff --git a/app/assets/logo-ace.png b/app/assets/logo-ace.png new file mode 100644 index 0000000000000000000000000000000000000000..8826f3c1741ceda9dd53d2b5f22f55b7331cd269 Binary files /dev/null and b/app/assets/logo-ace.png differ diff --git a/app/backup_tiff_utils.py b/app/backup_tiff_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ca60fa349fdc8c308dfa856c3f1109bb76443927 --- /dev/null +++ b/app/backup_tiff_utils.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@author : Romain Graux +@date : 2023 April 25, 11:59:06 +@last modified : 2023 June 20, 15:04:37 +""" + +import re +import imageio +import numpy as np +from collections import namedtuple +from typing import Protocol + +physical_metadata = namedtuple("physical_metadata", ["width", "height", "pixel_width", "pixel_height", "unit"]) + + +class ImageMetadataExtractor(Protocol): + @classmethod + def __call__(cls, image_path:str, strict:bool=True) -> physical_metadata: + ... + +class TIFFMetadataExtractor(ImageMetadataExtractor): + @classmethod + def __call__(cls, image_path:str, strict:bool=True) -> physical_metadata: + """ + Extracts the physical metadata of an image (only tiff for now) + """ + with open(image_path, "rb") as f: + data = f.read() + reader = imageio.get_reader(data, format=".tif") + metadata = reader.get_meta_data() + + if strict and not metadata['is_imagej']: + for key, value in metadata.items(): + if key.startswith("is_") and value == True: # Force bool to be True, because it can also pass the condition while being an random object + raise ValueError(f"The image is not TIFF image, but it seems to be a {key[3:]} image") + raise ValueError("The image is not in TIFF format") + h, w = reader.get_next_data().shape + ipw, iph, _ = metadata['resolution'] + result = re.search(r"unit=(.+)", metadata['description']) + if strict and not result: + raise ValueError(f"No scale unit found in the image description : {metadata['description']}") + unit = result and result.group(1) + return physical_metadata(w, h, 1. / ipw, 1. / iph, unit) + +def extract_physical_metadata(image_path : str, strict:bool=True) -> physical_metadata: + if image_path.endswith(".tif"): + return TIFFMetadataExtractor(image_path, strict) + +def tiff_to_png(image, inplace=True): + img = image if inplace else image.copy() + if np.array(img.getdata()).max() <= 1: + img = img.point(lambda p: p * 255) + return img.convert("RGB") diff --git a/app/dl_inference.py b/app/dl_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c00ef95cfca480b4aa498a9ed745f59a2c8715f6 --- /dev/null +++ b/app/dl_inference.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@author : Romain Graux +@date : 2023 March 17, 10:56:06 +@last modified : 2023 July 18, 10:25:32 +""" + +# Naive import of atomdetection, maybe should make a package out of it +from functools import lru_cache +import sys + +if ".." 
not in sys.path: + sys.path.append("..") + +import os +import torch +import numpy as np +from PIL import Image +from PIL.Image import Image as PILImage +from typing import Union +from utils.constants import ModelArgs +from utils.paths import MODELS_PATH, DATASET_PATH +from atoms_detection.dl_detection import DLDetection +from atoms_detection.evaluation import Evaluation +from tiff_utils import tiff_to_png + +LOGS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs") +VOID_DS = os.path.join(DATASET_PATH, "void.csv") +DET_PATH = os.path.join(LOGS_PATH, "detections") +INF_PATH = os.path.join(LOGS_PATH, "inference_cache") + +from atoms_detection.create_crop_dataset import create_crop +from atoms_detection.vae_svi_train import SVItrainer, init_dataloader +from atoms_detection.vae_model import rVAE +from sklearn.mixture import GaussianMixture + + +@lru_cache(maxsize=100) +def get_vae_model( + in_dim: tuple = (21, 21), + latent_dim: int = 50, + coord: int = 3, + seed: int = 42, +): + return rVAE(in_dim=in_dim, latent_dim=latent_dim, coord=coord, seed=seed) + + +def multimers_classification( + img, + coords, + likelihood, + n_species, + latent_dim: int = 50, + coord: int = 3, + reg_covar: float = 0.0001, + seed: int = 42, + epochs: int = 20, + scale_factor: float = 3.0, + batch_size: int = 100, +): + def get_crops(img, coords): + """Get crops from image and coords""" + crops = np.array( + [np.array(create_crop(Image.fromarray(img), x, y)) for x, y in coords] + ) # TODO : can be optimized if computationally heavy (multithreading) + return crops + + # Get crops to train VAE on + crops = get_crops(img, coords) + # Initialize VAE + rvae = rVAE(in_dim=(21, 21), latent_dim=latent_dim, coord=coord, seed=seed) + + # Train VAE to reconstruct crops + torch_crops = torch.tensor(crops).float() + train_loader = init_dataloader(torch_crops, batch_size=batch_size) + trainer = SVItrainer(rvae) + for e in range(epochs): + trainer.step(train_loader, scale_factor=scale_factor) + trainer.print_statistics() + + # Extract latent space (only mean) from VAE + z_mean, _ = rvae.encode(torch_crops) + + # Cluster latent space with GMM + gmm = GaussianMixture( + n_components=n_species, reg_covar=reg_covar, random_state=seed + ) + preds = gmm.fit_predict(z_mean) + pred_proba = gmm.predict_proba(z_mean) + pred_proba = np.array([pred_proba[i, pred] for i, pred in enumerate(preds)]) + + # To order clusters, signal-to-noise ratio OR median (across crops) of some intensity quality (eg mean top-5% int) + cluster_median_values = list() + for k in range(n_species): + relevant_crops = crops[preds == k] + crop_95_percentile = np.percentile(relevant_crops, q=95, axis=0) + img_means = [] + for crop, q in zip(relevant_crops, crop_95_percentile): + if (crop >= q).any(): + img_means.append(crop.mean()) + cluster_median_value = np.median(np.array(img_means)) + cluster_median_values.append(cluster_median_value) + # Sort clusters by median value + sorted_clusters = sorted( + [(mval, c_id) for c_id, mval in enumerate(cluster_median_values)] + ) + + # Return results in a dict with cluster id as key + results = {} + for _, c_id in sorted_clusters: + c_idd = np.array([c_id]) + results[c_id] = { + "coords": coords[preds == c_idd], + "likelihood": likelihood[preds == c_idd], + "confidence": pred_proba[preds == c_idd], + } + return results + + +def inference_fn( + architecture: ModelArgs, + image: Union[str, PILImage], + threshold: float, + n_species: int, +): + if architecture != ModelArgs.BASICCNN: + raise ValueError(f"Architecture 
{architecture} not supported yet") + ckpt_filename = os.path.join( + MODELS_PATH, + { + ModelArgs.BASICCNN: "model_C_NT_CLIP.ckpt", + # ModelArgs.BASICCNN: "model_replicate20.ckpt", + # ModelArgs.RESNET18 "inference_resnet.ckpt", + }[architecture], + ) + detection = DLDetection( + model_name=architecture, + ckpt_filename=ckpt_filename, + dataset_csv=VOID_DS, + threshold=threshold, + detections_path=DET_PATH, + inference_cache_path=INF_PATH, + batch_size=512, + ) + # Force the image to be in float32 because otherwise it will output wrong results (probably due to the median filter) + if type(image) == str: + image = Image.open(image) + img = np.asarray(image, dtype=np.float32) + # if img.max() <= 1: + # raise ValueError("Gradio seems to preprocess badly the tiff images. Did you adapt the preprocessing function as mentionned in the app.py file comments?") + prepro_img, _, pred_map = detection.image_to_pred_map(img, return_intermediate=True) + center_coords_list, likelihood_list = (np.array(x) for x in detection.pred_map_to_atoms(pred_map)) + results = ( + multimers_classification( + img=prepro_img, + coords=center_coords_list, + likelihood=likelihood_list, + n_species=n_species, + ) + if n_species > 1 + else { + 0: { + + "coords": center_coords_list, + "likelihood": likelihood_list, + "confidence": np.ones(len(center_coords_list)), + } + } + ) + for k, v in results.items(): + results[k]["atoms_bbs"] = [ + Evaluation.center_coords_to_bbox(center_coords) + for center_coords in v["coords"] + ] + return tiff_to_png(image), { + "image": tiff_to_png(image), + "pred_map": pred_map, + "species": results, + } + + +if __name__ == "__main__": + from utils.paths import IMG_PATH + + img_path = os.path.join(IMG_PATH, "091_HAADF_15nm_Sample_PtNC_21Oct20.tif") + _ = inference_fn(ModelArgs.BASICCNN, Image.open(img_path), 0.8) diff --git a/app/knn.py b/app/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..0fda181bbf02c0a49a3814ba4d5b26ba3a59f29e --- /dev/null +++ b/app/knn.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@author : Romain Graux +@date : 2023 May 16, 16:18:43 +@last modified : 2023 August 07, 11:54:19 +""" + +from typing import List, Tuple + +from PIL import Image +from collections import defaultdict +from tempfile import mktemp +import matplotlib +import numpy as np +import os +import seaborn as sns +from logger import logger + + +matplotlib.use("agg") +import matplotlib.pyplot as plt +from scipy.stats import rayleigh +from sklearn.neighbors import NearestNeighbors + + +def segment_image(filename, alpha=5): + # Get a random image png file + filename = filename.replace(" ", "\ ") + png_img = mktemp(suffix=".png") + segmented_img = mktemp(suffix=".png") + logger.debug(f"Segmenting image {filename}...") + logger.debug(f"Saving image to {png_img}...") + logger.debug(f"Saving segmented image to {segmented_img}...") + try: + # Segment with image magic in the terminal + ret = os.system(f"convert {filename} {png_img}") + if ret != 0: + raise RuntimeError(f"PNG conversion failed with return code {ret}") + ret = os.system( + f"convert {png_img} -alpha on -fill none -fuzz {alpha}% -draw 'color 0,0 replace' {segmented_img}" + ) + if ret != 0: + raise RuntimeError(f"Segmentation failed with return code {ret}") + # Load the image + img = Image.open(segmented_img) + # Get mask from image + mask = np.array(img) == 0 + finally: + # Delete the temporary files + if os.path.exists(png_img): + os.remove(png_img) + if os.path.exists(segmented_img): + 
os.remove(segmented_img) + return mask + + +condition = lambda x: x < 0.23 + + +def knn(coords: List[Tuple[int, int]], scale: float, factor: float, edge: float): + coords = np.array(coords) # B, 2 + x, y = coords.T * scale + + print(f"edge: {edge}, scale: {scale}, factor: {factor}, edge*scale: {edge*scale}") + # edge = edge * scale + + data = np.c_[x, y] + + neighbors = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(data) + distances = neighbors.kneighbors(data)[0][:, 1] + + # loc, scale = rayleigh.fit(distances, floc=0) + # r_KNN = scale * np.sqrt(np.pi / 2.0) + + lamda_RNN = len(x) / (edge * edge * factor) + r_RNN = 1 / (2 * np.sqrt(lamda_RNN)) + + n_samples = len(distances) + n_multimers = sum(condition(x) for x in distances) + percentage_multimers = 100.0 * n_multimers / n_samples + density = n_samples / (factor * edge**2) + min_dist = min(distances) + μ_dist = np.mean(distances) + + return { + "n_samples": { + "description": "Number of atoms detected (units = #atoms)", + "value": n_samples, + }, + "number of atoms in multimers": { + "description": "Number of atoms detected to belong to a multimer (units = #atoms)", + "value": n_multimers, + }, + "share of multimers": { + "description": "Percentage of atoms that belong to a multimer (units = %)", + "value": percentage_multimers, + }, + "density": { + "description": "Number of atoms / area in the micrograph (units = #atoms/nm²)", + "value": density, + }, + "min_dist": { + "description": "Lowest first nearest neighbour distance detected (units = nm)", + "value": min_dist, + }, + "": { + "description": "Mean first nearest neighbour distance (units = nm)", + "value": μ_dist, + }, + "r_RNN": { + "description": "First neighbour distance expected from a purely random distribution (units = nm)", + "value": r_RNN, + }, + "distances": distances, + } + + +from bokeh.plotting import figure +from bokeh.models import ColumnDataSource, HoverTool +from bokeh.plotting import figure +from scipy.stats import gaussian_kde +from collections import defaultdict + +color_palette = sns.color_palette("Set3")[2:] + + +def bokeh_plot_knn(distances, with_cumulative=False): + """ + Plot the KNN distances for the different images with the possibility to zoom in and out and toggle the lines + """ + p = figure(title="K=1 NN distances", background_fill_color="#fafafa") + p.xaxis.axis_label = "Distances (nm)" + p.yaxis.axis_label = "Density" + p.x_range.start = 0 + + if with_cumulative: + cum_dists = defaultdict(list) + for _, dists in distances: + for specie, dist in dists.items(): + cum_dists[specie].extend(dist) + cum_dists = {specie: np.array(dist) for specie, dist in cum_dists.items()} + distances.append(("Cumulative", cum_dists)) + + plot_dict = defaultdict(dict) + base_colors = color_palette + for (image_name, species_distances), base_color in zip(distances, base_colors): + palette = ( + sns.light_palette( + base_color, n_colors=len(species_distances) + 1, reverse=True + )[1::-1] + if len(species_distances) > 1 + else [base_color] + ) + colors = [ + f"#{int(255*r):02x}{int(255*g):02x}{int(255*b):02x}" for r, g, b in palette + ] + for (specie, dists), color in zip(species_distances.items(), colors): + kde = gaussian_kde(dists) + # Reduce smoothing + kde.set_bandwidth(bw_method=0.1) + x = np.linspace(-0.5, 1.2 * dists.max(), 100) + source = ColumnDataSource( + dict( + x=x, + y=kde(x), + species=[specie] * len(x), + p_below=[np.mean(dists < 0.22)] * len(x), + mean=[np.mean(dists)] * len(x), + filename=[image_name] * len(x), + ) + ) + 
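+            # Keep handles to both glyphs (the KDE line and the shaded area under it)
+            # per image and species: they share one ColumnDataSource, and because they
+            # share a legend_label, the "hide" click policy below toggles them together.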
plot_dict[image_name][specie] = [ + p.line( + line_width=2, + alpha=0.8, + legend_label=image_name, + line_color=color, + source=source, + ), + p.varea( + y1="y", + y2=0, + alpha=0.3, + legend_label=image_name, + source=source, + fill_color=color, + ), + ] + p.add_tools( + HoverTool( + show_arrow=False, + line_policy="next", + tooltips=[ + ("First NN distances < 0.22nm", "@p_below{0.00%}"), + ("", "@mean{0.00} nm"), + ("species", "@species"), + ("filename", "@filename"), + ], + ) + ) + p.legend.click_policy = "hide" + return p diff --git a/app/logger.py b/app/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..0a40df8878f1fd66bd813e4eaf174986a0795efb --- /dev/null +++ b/app/logger.py @@ -0,0 +1,14 @@ +import logging + +name = 'atomdetection-app' +logger = logging.getLogger(name) +logger.setLevel(logging.DEBUG) +formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + +# Create a file handler and set its level and formatter +file_handler = logging.FileHandler(f'{name}.log') +file_handler.setLevel(logging.DEBUG) +file_handler.setFormatter(formatter) + +# Add the file handler to the logger +logger.addHandler(file_handler) \ No newline at end of file diff --git a/app/tiff_utils.py b/app/tiff_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a91211f6750745324bd9f9ddb485a9383b5ff69 --- /dev/null +++ b/app/tiff_utils.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@author : Romain Graux +@date : 2023 April 25, 11:59:06 +@last modified : 2023 September 19, 11:18:36 +""" + +import re +import imageio +import numpy as np +from collections import namedtuple + +physical_metadata = namedtuple("physical_metadata", ["width", "height", "pixel_width", "pixel_height", "unit"]) + +def extract_physical_metadata(image_path : str, strict:bool=True) -> physical_metadata: + """ + Extracts the physical metadata of an image (only tiff for now) + """ + with open(image_path, "rb") as f: + data = f.read() + reader = imageio.get_reader(data, format=".tif") + metadata = reader.get_meta_data() + + if strict and not metadata['is_imagej']: + for key, value in metadata.items(): + if key.startswith("is_") and value == True: # Force bool to be True, because it can also pass the condition while being an random object + raise ValueError(f"The image is not TIFF image, but it seems to be a {key[3:]} image") + raise ValueError("Impossible to extract metadata from the image (ImageJ)") + h, w = reader.get_next_data().shape + ipw, iph, _ = metadata['resolution'] + result = re.search(r"unit=(.+)", metadata['description']) + if strict and not result: + raise ValueError(f"No scale unit found in the image description : {metadata['description']}") + unit = result and result.group(1) + return physical_metadata(w, h, 1. / ipw, 1. 
/ iph, unit) + +def tiff_to_png(image, inplace=True): + img = image if inplace else image.copy() + if np.array(img.getdata()).max() <= 1: + img = img.point(lambda p: p * 255) + return img.convert("RGB") diff --git a/atoms_detection/README.md b/atoms_detection/README.md new file mode 100644 index 0000000000000000000000000000000000000000..463e8ad3c499264dda4cae6bea64ef69dda18563 --- /dev/null +++ b/atoms_detection/README.md @@ -0,0 +1,12 @@ +--- +title: Sss +emoji: 🚀 +colorFrom: yellow +colorTo: red +sdk: gradio +sdk_version: 3.24.1 +app_file: app.py +pinned: false +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/atoms_detection/__init__.py b/atoms_detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/atoms_detection/create_crop_dataset.py b/atoms_detection/create_crop_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..05f320f4f62b7fd2311add2ae6bef2244bb1e4cd --- /dev/null +++ b/atoms_detection/create_crop_dataset.py @@ -0,0 +1,408 @@ +import os + +import numpy as np +import pandas as pd +from PIL import Image + +from atoms_detection.image_preprocessing import dl_prepro_image +from atoms_detection.dataset import CoordinatesDataset +from utils.paths import CROPS_PATH, CROPS_DATASET, PT_DATASET +from utils.constants import Split, CropsColumns +import matplotlib.pyplot as plt # I don't know why tf but it doesn't work if not here + +np.random.seed(777) + +window_size = (21, 21) +halfx_window = ((window_size[0] - 1) // 2) +halfy_window = ((window_size[1] - 1) // 2) + + +def get_gaussian_kernel(size=21, mean=0, sigma=0.2): + # Initializing value of x-axis and y-axis + # in the range -1 to 1 + x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size)) + dst = np.sqrt(x * x + y * y) + + # Calculating Gaussian array + kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2))) + return kernel + + +def generate_support_img(coordinates, window_size): + support_img = np.zeros((512, 512)) + kernel = get_gaussian_kernel(size=window_size[0]) + halfx_window = ((window_size[0] - 1) // 2) + halfy_window = ((window_size[1] - 1) // 2) + for x, y in coordinates: + x_range = (x - halfx_window, x + halfx_window + 1) + y_range = (y - halfy_window, y + halfy_window + 1) + + x_diff = [0, 0] + y_diff = [0, 0] + if x_range[0] < 0: + x_diff[0] = 0 - x_range[0] + if x_range[1] > 512: + x_diff[1] = x_range[1] - 512 + if y_range[0] < 0: + y_diff[0] = 0 - y_range[0] + if y_range[1] > 512: + y_diff[1] = y_range[1] - 512 + + real_kernel = kernel[x_diff[0]:window_size[0] - x_diff[1], y_diff[0]:window_size[1] - y_diff[1]] + real_x_crop = (x_range[0] + x_diff[0], x_range[1] - x_diff[1]) + real_y_crop = (y_range[0] + y_diff[0], y_range[1] - y_diff[1]) + + support_img[real_x_crop[0]:real_x_crop[1], real_y_crop[0]:real_y_crop[1]] += real_kernel + + support_img = support_img.T + return support_img + + +def open_image(img_filename): + img = Image.open(img_filename) + np_img = np.asarray(img).astype(np.float32) + np_img = dl_prepro_image(np_img) + img = Image.fromarray(np_img) + return img + + +def create_crop(img: Image, x_center: int, y_center: int): + crop_coords = ( + x_center - halfx_window, + y_center - halfy_window, + x_center + halfx_window + 1, + y_center + halfy_window + 1 + ) + crop = img.crop(crop_coords) + return crop + + +def create_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str): + if not 
os.path.exists(crops_folder): + os.makedirs(crops_folder) + + crop_name_list = [] + orig_name_list = [] + x_list = [] + y_list = [] + label_list = [] + + n_positives = 0 + label = 1 + dataset = CoordinatesDataset(coords_csv) + print('Creating positive crops...') + for data_filename, label_filename in dataset.iterate_data(Split.TRAIN): + if label_filename is None: + continue + + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + + img = open_image(data_filename) + coordinates = dataset.load_coordinates(label_filename) + + for x_center, y_center in coordinates: + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + + n_positives += 1 + + label = 0 + no_train_images = dataset.split_length(Split.TRAIN) + neg_crops_per_image = [n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0) for x in range(no_train_images)] + print('Creating negative crops...') + for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), neg_crops_per_image): + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + img = open_image(data_filename) + + if label_filename: + coordinates = dataset.load_coordinates(label_filename) + support_map = generate_support_img(coordinates, window_size) + else: + support_map = None + + for _ in range(no_neg_crops): + x_rand = np.random.randint(0, 512) + y_rand = np.random.randint(0, 512) + + if support_map is not None: + while support_map[x_rand, y_rand] != 0: + x_rand = np.random.randint(0, 512) + y_rand = np.random.randint(0, 512) + + x_center, y_center = x_rand, y_rand + + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + + df_data = { + CropsColumns.FILENAME: crop_name_list, + CropsColumns.ORIGINAL: orig_name_list, + CropsColumns.X: x_list, + CropsColumns.Y: y_list, + CropsColumns.LABEL: label_list + } + df = pd.DataFrame(df_data, columns=[ + CropsColumns.FILENAME, + CropsColumns.ORIGINAL, + CropsColumns.X, + CropsColumns.Y, + CropsColumns.LABEL + ]) + + df_pos = df[df.Label == 1] + df_neg = df[df.Label == 0] + + pos_len = len(df_pos) + neg_len = len(df_neg) + + pos_train, pos_val, pos_test = np.split(df_pos.sample(frac=1), [int(0.8*pos_len), int(0.9*pos_len)]) + neg_train, neg_val, neg_test = np.split(df_neg.sample(frac=1), [int(0.8*neg_len), int(0.9*neg_len)]) + pos_train[CropsColumns.SPLIT] = Split.TRAIN + pos_val[CropsColumns.SPLIT] = Split.VAL + pos_test[CropsColumns.SPLIT] = Split.TEST + neg_train[CropsColumns.SPLIT] = Split.TRAIN + neg_val[CropsColumns.SPLIT] = Split.VAL + neg_test[CropsColumns.SPLIT] = Split.TEST + + df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val, pos_test, neg_test), axis=0) + df_with_splits.to_csv(crops_dataset, header=True, index=False) + + +def create_contrastive_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str, + show_sampling_result: bool = False, contrastive_samples_percent: float = 0.25, + 
contrastive_distance_multiplier: float = 1.1, pos_data_upsampling: bool = False, + pos_upsample_dist: int = 3, neg_upsample_multiplier: float = 0): + global plt # don't ask why. + if not os.path.exists(crops_folder): + os.makedirs(crops_folder) + + crop_name_list = [] + orig_name_list = [] + x_list = [] + y_list = [] + label_list = [] + + n_positives = 0 + label = 1 + dataset = CoordinatesDataset(coords_csv) + print('Creating positive crops...') + firstx, firsty = True, True + for data_filename, label_filename in dataset.iterate_data(Split.TRAIN): + if label_filename is None: + continue + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + + img = open_image(data_filename) + coordinates = dataset.load_coordinates(label_filename) + + for x_center, y_center in coordinates: + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + if firstx: + firstx = False + crop_save(crop, "pos.png") + print('saved') + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + if pos_data_upsampling: + x_rand, y_rand = None, None + while x_rand is None: + rand_angle = np.random.uniform(0, 2 * np.pi) + x_rand = round(pos_upsample_dist * np.cos(rand_angle)) + x_center + y_rand = round(pos_upsample_dist * np.sin(rand_angle)) + y_center + out_of_bounds = x_rand >= img.size[0] or y_rand >= img.size[1] or \ + x_rand < 0 or y_rand < 0 + if out_of_bounds != 0: + x_rand, y_rand = None, None + + crop = create_crop(img, x_rand, y_rand) + crop_name = "{}_{}_{}.tif".format(img_name, x_rand, y_rand) + crop.save(os.path.join(crops_folder, crop_name)) + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + + if firsty: + firsty = False + crop_save(crop, "pos_jit.png") + + n_positives += 1 + + label = 0 + no_train_images = dataset.split_length(Split.TRAIN) + contrastive_sampling_distance = (window_size[0] * contrastive_distance_multiplier) // 2 + neg_crops_per_image = [round((n_positives // no_train_images) * (1+neg_upsample_multiplier)) + (1 if x < n_positives % no_train_images else 0) for x in + range(no_train_images)] + neg_non_constrastive_crops_per_image, neg_contrastive_crops_per_image = \ + list(zip(*[(n_crops - round(contrastive_samples_percent * n_crops), + round(contrastive_samples_percent * n_crops)) + for n_crops in neg_crops_per_image])) + firstx, firsty = True, True + # neg_non_constrastive_crops_per_image, neg_contrastive_crops_per_image = 30*[0], 30*[44] + print(contrastive_sampling_distance) + print('Creating contrastive negative crops...') + for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), + neg_contrastive_crops_per_image): + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + img = open_image(data_filename) + + if label_filename: + coordinates = dataset.load_coordinates(label_filename) + support_map = generate_support_img(coordinates, window_size) + else: + support_map = None + + for idx in np.random.choice(len(coordinates), no_neg_crops): + atom_rand = coordinates[idx] + x_center, y_center = atom_rand + x_rand, y_rand = None, None + if support_map is not None: + retries=0 + while x_rand is None and retries < 50: # Extremely 
unlikely: sample impossible + retries += 1 + rand_angle = np.random.uniform(0, 2 * np.pi) + x_rand = round(contrastive_sampling_distance * np.cos(rand_angle)) + x_center + y_rand = round(contrastive_sampling_distance * np.sin(rand_angle)) + y_center + out_of_bounds = x_rand >= img.size[0] or y_rand >= img.size[1] or \ + x_rand<0 or y_rand<0 + if out_of_bounds or support_map[x_rand, y_rand] != 0: + x_rand, y_rand = None, None + + x_center, y_center = x_rand, y_rand + + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + if firsty: + firsty = False + crop_save(crop, "neg_con.png") + + print('Creating non-contrastive negative crops...') + for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), + neg_non_constrastive_crops_per_image): + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + img = open_image(data_filename) + + if label_filename: + coordinates = dataset.load_coordinates(label_filename) + support_map = generate_support_img(coordinates, window_size) + else: + support_map = None + + for _ in range(no_neg_crops): + x_rand = np.random.randint(0, 512) + y_rand = np.random.randint(0, 512) + + if support_map is not None: + while support_map[x_rand, y_rand] != 0: + x_rand = np.random.randint(0, 512) + y_rand = np.random.randint(0, 512) + + x_center, y_center = x_rand, y_rand + + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + if firstx: + firstx = False + crop_save(crop, "neg_ncon.png") + + if show_sampling_result: + # Only works for single img data. 
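+        # Labels: 1 = positive (atom-centred) crops, 0 = negative crops. Overlaying
+        # both on the last processed image gives a quick visual sanity check of the
+        # contrastive sampling.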
+ positives = [(x, y) for x,y,l in zip(x_list, y_list, label_list) if l==1] + negatives = [(x, y) for x,y,l in zip(x_list, y_list, label_list) if l==0] + from matplotlib import pyplot as plt + plt.imshow(img) + plt.scatter(*zip(*positives)) + plt.scatter(*zip(*negatives)) + plt.show() + + + + + df_data = { + CropsColumns.FILENAME: crop_name_list, + CropsColumns.ORIGINAL: orig_name_list, + CropsColumns.X: x_list, + CropsColumns.Y: y_list, + CropsColumns.LABEL: label_list + } + df = pd.DataFrame(df_data, columns=[ + CropsColumns.FILENAME, + CropsColumns.ORIGINAL, + CropsColumns.X, + CropsColumns.Y, + CropsColumns.LABEL + ]) + + df_pos = df[df.Label == 1] + df_neg = df[df.Label == 0] + + pos_len = len(df_pos) + neg_len = len(df_neg) + + pos_train, pos_val = np.split(df_pos.sample(frac=1), [int(0.9 * pos_len)]) + neg_train, neg_val = np.split(df_neg.sample(frac=1), [int(0.9 * neg_len)]) + pos_train[CropsColumns.SPLIT] = Split.TRAIN + pos_val[CropsColumns.SPLIT] = Split.VAL + neg_train[CropsColumns.SPLIT] = Split.TRAIN + neg_val[CropsColumns.SPLIT] = Split.VAL + print("Final size for train(P vs N):", len(pos_train), len(neg_train)) + print("Final size for val (P vs N):", len(pos_val), len(neg_val)) + df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val), axis=0) + df_with_splits.to_csv(crops_dataset, header=True, index=False) + + +def crop_save(crop, im_name): + crop = np.array(crop) + crop = (crop + crop.min()) * 500 + crop = Image.fromarray(crop) + crop = crop.convert("L") + crop.save(im_name, 'png') + + +if __name__ == "__main__": + create_crops_dataset(CROPS_PATH, PT_DATASET, CROPS_DATASET) diff --git a/atoms_detection/create_crop_dataset_1024.py b/atoms_detection/create_crop_dataset_1024.py new file mode 100644 index 0000000000000000000000000000000000000000..768d87b4e1c611c7adfc1423f8c00e0262b38a6a --- /dev/null +++ b/atoms_detection/create_crop_dataset_1024.py @@ -0,0 +1,197 @@ +import os + +import numpy as np +import pandas as pd +from PIL import Image + +from atoms_detection.image_preprocessing import dl_prepro_image +from atoms_detection.dataset import CoordinatesDataset +from utils.paths import CROPS_PATH, CROPS_DATASET, PT_DATASET +from utils.constants import Split, CropsColumns + + +np.random.seed(777) + +window_size = (21, 21) +halfx_window = ((window_size[0] - 1) // 2) +halfy_window = ((window_size[1] - 1) // 2) + + +def get_gaussian_kernel(size=33, mean=0, sigma=0.2): + # Initializing value of x-axis and y-axis + # in the range -1 to 1 + x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size)) + dst = np.sqrt(x * x + y * y) + + # Calculating Gaussian array + kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2))) + return kernel + + +def generate_support_img(coordinates, window_size): + support_img = np.zeros((1024, 1024)) + kernel = get_gaussian_kernel(size=window_size[0]) + halfx_window = ((window_size[0] - 1) // 2) + halfy_window = ((window_size[1] - 1) // 2) + for x, y in coordinates: + x_range = (x - halfx_window, x + halfx_window + 1) + y_range = (y - halfy_window, y + halfy_window + 1) + + x_diff = [0, 0] + y_diff = [0, 0] + if x_range[0] < 0: + x_diff[0] = 0 - x_range[0] + if x_range[1] > 1024: + x_diff[1] = x_range[1] - 1024 + if y_range[0] < 0: + y_diff[0] = 0 - y_range[0] + if y_range[1] > 1024: + y_diff[1] = y_range[1] - 1024 + + x_diff = tuple(int(item) for item in x_diff) + y_diff = tuple(int(item) for item in y_diff) + + real_kernel = kernel[x_diff[0]:window_size[0] - x_diff[1], y_diff[0]:window_size[1] - y_diff[1]] + + real_x_crop = 
(x_range[0] + x_diff[0], x_range[1] - x_diff[1]) + real_y_crop = (y_range[0] + y_diff[0], y_range[1] - y_diff[1]) + + real_x_crop = tuple(int(item) for item in real_x_crop) + real_y_crop = tuple(int(item) for item in real_y_crop) + + support_img[real_x_crop[0]:real_x_crop[1], real_y_crop[0]:real_y_crop[1]] += real_kernel + + support_img = support_img.T + return support_img + + +def open_image(img_filename): + img = Image.open(img_filename) + np_img = np.asarray(img).astype(np.float32) + np_img = dl_prepro_image(np_img) + img = Image.fromarray(np_img) + return img + + +def create_crop(img: Image, x_center: int, y_center: int): + crop_coords = ( + x_center - halfx_window, + y_center - halfy_window, + x_center + halfx_window + 1, + y_center + halfy_window + 1 + ) + crop = img.crop(crop_coords) + return crop + + +def create_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str): + if not os.path.exists(crops_folder): + os.makedirs(crops_folder) + + crop_name_list = [] + orig_name_list = [] + x_list = [] + y_list = [] + label_list = [] + + n_positives = 0 + label = 1 + dataset = CoordinatesDataset(coords_csv) + print('Creating positive crops...') + for data_filename, label_filename in dataset.iterate_data(Split.TRAIN): + if label_filename is None: + continue + + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + + img = open_image(data_filename) + coordinates = dataset.load_coordinates(label_filename) + + for x_center, y_center in coordinates: + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + + n_positives += 1 + + label = 0 + no_train_images = dataset.split_length(Split.TRAIN) + neg_crops_per_image = [n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0) for x in range(no_train_images)] + print('Creating negative crops...') + for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), neg_crops_per_image): + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + img = open_image(data_filename) + + if label_filename: + coordinates = dataset.load_coordinates(label_filename) + support_map = generate_support_img(coordinates, window_size) + else: + support_map = None + + for _ in range(no_neg_crops): + x_rand = np.random.randint(0, 1024) + y_rand = np.random.randint(0, 1024) + + if support_map is not None: + while support_map[x_rand, y_rand] != 0: + x_rand = np.random.randint(0, 1024) + y_rand = np.random.randint(0, 1024) + + x_center, y_center = x_rand, y_rand + + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + + df_data = { + CropsColumns.FILENAME: crop_name_list, + CropsColumns.ORIGINAL: orig_name_list, + CropsColumns.X: x_list, + CropsColumns.Y: y_list, + CropsColumns.LABEL: label_list + } + df = pd.DataFrame(df_data, columns=[ + CropsColumns.FILENAME, + CropsColumns.ORIGINAL, + CropsColumns.X, + CropsColumns.Y, + CropsColumns.LABEL + ]) + + df_pos = 
df[df.Label == 1] + df_neg = df[df.Label == 0] + + pos_len = len(df_pos) + neg_len = len(df_neg) + + pos_train, pos_val, pos_test = np.split(df_pos.sample(frac=1), [int(0.8*pos_len), int(0.9*pos_len)]) + neg_train, neg_val, neg_test = np.split(df_neg.sample(frac=1), [int(0.8*neg_len), int(0.9*neg_len)]) + pos_train[CropsColumns.SPLIT] = Split.TRAIN + pos_val[CropsColumns.SPLIT] = Split.VAL + pos_test[CropsColumns.SPLIT] = Split.TEST + neg_train[CropsColumns.SPLIT] = Split.TRAIN + neg_val[CropsColumns.SPLIT] = Split.VAL + neg_test[CropsColumns.SPLIT] = Split.TEST + + df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val, pos_test, neg_test), axis=0) + df_with_splits.to_csv(crops_dataset, header=True, index=False) + + +if __name__ == "__main__": + create_crops_dataset(CROPS_PATH, PT_DATASET, CROPS_DATASET) diff --git a/atoms_detection/create_crop_dataset_2048.py b/atoms_detection/create_crop_dataset_2048.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5741faaeba7eca15bd3e665892c1753ff42b48 --- /dev/null +++ b/atoms_detection/create_crop_dataset_2048.py @@ -0,0 +1,197 @@ +import os + +import numpy as np +import pandas as pd +from PIL import Image + +from atoms_detection.image_preprocessing import dl_prepro_image +from atoms_detection.dataset import CoordinatesDataset +from utils.paths import CROPS_PATH, CROPS_DATASET, PT_DATASET +from utils.constants import Split, CropsColumns + + +np.random.seed(777) + +window_size = (21, 21) +halfx_window = ((window_size[0] - 1) // 2) +halfy_window = ((window_size[1] - 1) // 2) + + +def get_gaussian_kernel(size=33, mean=0, sigma=0.2): + # Initializing value of x-axis and y-axis + # in the range -1 to 1 + x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size)) + dst = np.sqrt(x * x + y * y) + + # Calculating Gaussian array + kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2))) + return kernel + + +def generate_support_img(coordinates, window_size): + support_img = np.zeros((2048, 2048)) + kernel = get_gaussian_kernel(size=window_size[0]) + halfx_window = ((window_size[0] - 1) // 2) + halfy_window = ((window_size[1] - 1) // 2) + for x, y in coordinates: + x_range = (x - halfx_window, x + halfx_window + 1) + y_range = (y - halfy_window, y + halfy_window + 1) + + x_diff = [0, 0] + y_diff = [0, 0] + if x_range[0] < 0: + x_diff[0] = 0 - x_range[0] + if x_range[1] > 2048: + x_diff[1] = x_range[1] - 2048 + if y_range[0] < 0: + y_diff[0] = 0 - y_range[0] + if y_range[1] > 2048: + y_diff[1] = y_range[1] - 2048 + + x_diff = tuple(int(item) for item in x_diff) + y_diff = tuple(int(item) for item in y_diff) + + real_kernel = kernel[x_diff[0]:window_size[0] - x_diff[1], y_diff[0]:window_size[1] - y_diff[1]] + + real_x_crop = (x_range[0] + x_diff[0], x_range[1] - x_diff[1]) + real_y_crop = (y_range[0] + y_diff[0], y_range[1] - y_diff[1]) + + real_x_crop = tuple(int(item) for item in real_x_crop) + real_y_crop = tuple(int(item) for item in real_y_crop) + + support_img[real_x_crop[0]:real_x_crop[1], real_y_crop[0]:real_y_crop[1]] += real_kernel + + support_img = support_img.T + return support_img + + +def open_image(img_filename): + img = Image.open(img_filename) + np_img = np.asarray(img).astype(np.float32) + np_img = dl_prepro_image(np_img) + img = Image.fromarray(np_img) + return img + + +def create_crop(img: Image, x_center: int, y_center: int): + crop_coords = ( + x_center - halfx_window, + y_center - halfy_window, + x_center + halfx_window + 1, + y_center + halfy_window + 1 + ) + crop = img.crop(crop_coords) 
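+    # PIL zero-pads regions that fall outside the image, so atoms near the border
+    # still yield full 21x21 crops.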
+ return crop + + +def create_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str): + if not os.path.exists(crops_folder): + os.makedirs(crops_folder) + + crop_name_list = [] + orig_name_list = [] + x_list = [] + y_list = [] + label_list = [] + + n_positives = 0 + label = 1 + dataset = CoordinatesDataset(coords_csv) + print('Creating positive crops...') + for data_filename, label_filename in dataset.iterate_data(Split.TRAIN): + if label_filename is None: + continue + + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + + img = open_image(data_filename) + coordinates = dataset.load_coordinates(label_filename) + + for x_center, y_center in coordinates: + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + + n_positives += 1 + + label = 0 + no_train_images = dataset.split_length(Split.TRAIN) + neg_crops_per_image = [n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0) for x in range(no_train_images)] + print('Creating negative crops...') + for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), neg_crops_per_image): + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + img = open_image(data_filename) + + if label_filename: + coordinates = dataset.load_coordinates(label_filename) + support_map = generate_support_img(coordinates, window_size) + else: + support_map = None + + for _ in range(no_neg_crops): + x_rand = np.random.randint(0, 2048) + y_rand = np.random.randint(0, 2048) + + if support_map is not None: + while support_map[x_rand, y_rand] != 0: + x_rand = np.random.randint(0, 2048) + y_rand = np.random.randint(0, 2048) + + x_center, y_center = x_rand, y_rand + + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + + df_data = { + CropsColumns.FILENAME: crop_name_list, + CropsColumns.ORIGINAL: orig_name_list, + CropsColumns.X: x_list, + CropsColumns.Y: y_list, + CropsColumns.LABEL: label_list + } + df = pd.DataFrame(df_data, columns=[ + CropsColumns.FILENAME, + CropsColumns.ORIGINAL, + CropsColumns.X, + CropsColumns.Y, + CropsColumns.LABEL + ]) + + df_pos = df[df.Label == 1] + df_neg = df[df.Label == 0] + + pos_len = len(df_pos) + neg_len = len(df_neg) + + pos_train, pos_val, pos_test = np.split(df_pos.sample(frac=1), [int(0.8*pos_len), int(0.9*pos_len)]) + neg_train, neg_val, neg_test = np.split(df_neg.sample(frac=1), [int(0.8*neg_len), int(0.9*neg_len)]) + pos_train[CropsColumns.SPLIT] = Split.TRAIN + pos_val[CropsColumns.SPLIT] = Split.VAL + pos_test[CropsColumns.SPLIT] = Split.TEST + neg_train[CropsColumns.SPLIT] = Split.TRAIN + neg_val[CropsColumns.SPLIT] = Split.VAL + neg_test[CropsColumns.SPLIT] = Split.TEST + + df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val, pos_test, neg_test), axis=0) + df_with_splits.to_csv(crops_dataset, header=True, index=False) + + +if __name__ == "__main__": + create_crops_dataset(CROPS_PATH, 
PT_DATASET, CROPS_DATASET) diff --git a/atoms_detection/create_crop_dataset_512.py b/atoms_detection/create_crop_dataset_512.py new file mode 100644 index 0000000000000000000000000000000000000000..fbef5d927952a9035a518bde83fa5d10e1efb7f4 --- /dev/null +++ b/atoms_detection/create_crop_dataset_512.py @@ -0,0 +1,190 @@ +import os + +import numpy as np +import pandas as pd +from PIL import Image + +from atoms_detection.image_preprocessing import dl_prepro_image +from atoms_detection.dataset import CoordinatesDataset +from utils.paths import CROPS_PATH, CROPS_DATASET, PT_DATASET +from utils.constants import Split, CropsColumns + + +np.random.seed(777) + +window_size = (21, 21) +halfx_window = ((window_size[0] - 1) // 2) +halfy_window = ((window_size[1] - 1) // 2) + + +def get_gaussian_kernel(size=21, mean=0, sigma=0.2): + # Initializing value of x-axis and y-axis + # in the range -1 to 1 + x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size)) + dst = np.sqrt(x * x + y * y) + + # Calculating Gaussian array + kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2))) + return kernel + + +def generate_support_img(coordinates, window_size): + support_img = np.zeros((512, 512)) + kernel = get_gaussian_kernel(size=window_size[0]) + halfx_window = ((window_size[0] - 1) // 2) + halfy_window = ((window_size[1] - 1) // 2) + for x, y in coordinates: + x_range = (x - halfx_window, x + halfx_window + 1) + y_range = (y - halfy_window, y + halfy_window + 1) + + x_diff = [0, 0] + y_diff = [0, 0] + if x_range[0] < 0: + x_diff[0] = 0 - x_range[0] + if x_range[1] > 512: + x_diff[1] = x_range[1] - 512 + if y_range[0] < 0: + y_diff[0] = 0 - y_range[0] + if y_range[1] > 512: + y_diff[1] = y_range[1] - 512 + + real_kernel = kernel[x_diff[0]:window_size[0] - x_diff[1], y_diff[0]:window_size[1] - y_diff[1]] + real_x_crop = (x_range[0] + x_diff[0], x_range[1] - x_diff[1]) + real_y_crop = (y_range[0] + y_diff[0], y_range[1] - y_diff[1]) + + support_img[real_x_crop[0]:real_x_crop[1], real_y_crop[0]:real_y_crop[1]] += real_kernel + + support_img = support_img.T + return support_img + + +def open_image(img_filename): + img = Image.open(img_filename) + np_img = np.asarray(img).astype(np.float32) + np_img = dl_prepro_image(np_img) + img = Image.fromarray(np_img) + return img + + +def create_crop(img: Image, x_center: int, y_center: int): + crop_coords = ( + x_center - halfx_window, + y_center - halfy_window, + x_center + halfx_window + 1, + y_center + halfy_window + 1 + ) + crop = img.crop(crop_coords) + return crop + + +def create_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str): + if not os.path.exists(crops_folder): + os.makedirs(crops_folder) + + crop_name_list = [] + orig_name_list = [] + x_list = [] + y_list = [] + label_list = [] + + n_positives = 0 + label = 1 + dataset = CoordinatesDataset(coords_csv) + print('Creating positive crops...') + for data_filename, label_filename in dataset.iterate_data(Split.TRAIN): + if label_filename is None: + continue + + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + + img = open_image(data_filename) + coordinates = dataset.load_coordinates(label_filename) + + for x_center, y_center in coordinates: + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + 
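# Bookkeeping for the crops CSV: each row stores the crop filename + # ("{img_name}_{x}_{y}.tif"), its source image, the centre coordinates and the + # positive/negative label, so ground truth is recoverable from the CSV alone. +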
y_list.append(y_center) + label_list.append(label) + + n_positives += 1 + + label = 0 + no_train_images = dataset.split_length(Split.TRAIN) + neg_crops_per_image = [n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0) for x in range(no_train_images)] + print('Creating negative crops...') + for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), neg_crops_per_image): + print(data_filename) + orig_img_name = os.path.basename(data_filename) + img_name = os.path.splitext(orig_img_name)[0] + img = open_image(data_filename) + + if label_filename: + coordinates = dataset.load_coordinates(label_filename) + support_map = generate_support_img(coordinates, window_size) + else: + support_map = None + + for _ in range(no_neg_crops): + x_rand = np.random.randint(0, 512) + y_rand = np.random.randint(0, 512) + + if support_map is not None: + while support_map[x_rand, y_rand] != 0: + x_rand = np.random.randint(0, 512) + y_rand = np.random.randint(0, 512) + + x_center, y_center = x_rand, y_rand + + crop = create_crop(img, x_center, y_center) + crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center) + crop.save(os.path.join(crops_folder, crop_name)) + + crop_name_list.append(crop_name) + orig_name_list.append(orig_img_name) + x_list.append(x_center) + y_list.append(y_center) + label_list.append(label) + + df_data = { + CropsColumns.FILENAME: crop_name_list, + CropsColumns.ORIGINAL: orig_name_list, + CropsColumns.X: x_list, + CropsColumns.Y: y_list, + CropsColumns.LABEL: label_list + } + df = pd.DataFrame(df_data, columns=[ + CropsColumns.FILENAME, + CropsColumns.ORIGINAL, + CropsColumns.X, + CropsColumns.Y, + CropsColumns.LABEL + ]) + + df_pos = df[df.Label == 1] + df_neg = df[df.Label == 0] + + pos_len = len(df_pos) + neg_len = len(df_neg) + + pos_train, pos_val, pos_test = np.split(df_pos.sample(frac=1), [int(0.8*pos_len), int(0.9*pos_len)]) + neg_train, neg_val, neg_test = np.split(df_neg.sample(frac=1), [int(0.8*neg_len), int(0.9*neg_len)]) + pos_train[CropsColumns.SPLIT] = Split.TRAIN + pos_val[CropsColumns.SPLIT] = Split.VAL + pos_test[CropsColumns.SPLIT] = Split.TEST + neg_train[CropsColumns.SPLIT] = Split.TRAIN + neg_val[CropsColumns.SPLIT] = Split.VAL + neg_test[CropsColumns.SPLIT] = Split.TEST + + df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val, pos_test, neg_test), axis=0) + df_with_splits.to_csv(crops_dataset, header=True, index=False) + + +if __name__ == "__main__": + create_crops_dataset(CROPS_PATH, PT_DATASET, CROPS_DATASET) diff --git a/atoms_detection/cv_detection.py b/atoms_detection/cv_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..dd7fc276b51cf1a1a09e0dfdb27a041719b80086 --- /dev/null +++ b/atoms_detection/cv_detection.py @@ -0,0 +1,31 @@ +import cv2 +import numpy as np + +from atoms_detection.image_preprocessing import dl_prepro_image +from atoms_detection.detection import Detection + + +class CVDetection(Detection): + + @staticmethod + def get_gaussian_kernel(size=21, mean=0, sigma=0.22, offset=0.0): + # Initializing value of x-axis and y-axis + # in the range -1 to 1 + x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size)) + dst = np.sqrt(x * x + y * y) + # Calculating Gaussian array + kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2))) - offset + return kernel + + def filter_image(self, img_arr: np.ndarray, **kwargs): + gauss_kernel = self.get_gaussian_kernel(**kwargs) + max_kernel_value = gauss_kernel.flatten().sum() + filtered_img = 
cv2.filter2D(img_arr, -1, gauss_kernel) + filtered_img /= max_kernel_value + return filtered_img + + def image_to_pred_map(self, img: np.ndarray, img_filename=None) -> np.ndarray: + prepro_img = dl_prepro_image(img) + filtered_img = self.filter_image(prepro_img) + filtered_img = filtered_img.transpose() + return filtered_img diff --git a/atoms_detection/cv_fe_detection_evaluation.py b/atoms_detection/cv_fe_detection_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..b2382972c370fb6545e114f3982f12753625e568 --- /dev/null +++ b/atoms_detection/cv_fe_detection_evaluation.py @@ -0,0 +1,37 @@ +import os + +from atoms_detection.cv_detection import CVDetection +from atoms_detection.evaluation import Evaluation +from utils.paths import CROPS_PATH, CROPS_DATASET, MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH, DATASET_PATH +from utils.constants import ModelArgs + + +extension_name = "trial" +threshold = 0.21 +architecture = ModelArgs.BASICCNN +ckpt_filename = os.path.join(MODELS_PATH, "basic_replicate.ckpt") +dataset_csv = os.path.join(DATASET_PATH, "Fe_dataset.csv") + + +inference_cache_path = os.path.join(PREDS_PATH, f"cv_fe_detection_{extension_name}") + +for threshold in [0.1, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25]: + detections_path = os.path.join(DETECTION_PATH, f"cv_fe_detection_{extension_name}", + f"cv_fe_detection_{extension_name}_{threshold}") + print(f"Detecting atoms on test data with threshold={threshold}...") + detection = CVDetection( + dataset_csv=dataset_csv, + threshold=threshold, + detections_path=detections_path, + inference_cache_path=inference_cache_path + ) + detection.run() + + logging_filename = os.path.join(LOGS_PATH, f"cv_fe_evaluation_{extension_name}", + f"cv_fe_evaluation_{extension_name}_{threshold}.csv") + evaluation = Evaluation( + coords_csv=dataset_csv, + predictions_path=detections_path, + logging_filename=logging_filename + ) + evaluation.run() diff --git a/atoms_detection/cv_full_pipeline.py b/atoms_detection/cv_full_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d1cd744d5a60c5436de6cc11a08a41ce892dd1a1 --- /dev/null +++ b/atoms_detection/cv_full_pipeline.py @@ -0,0 +1,71 @@ +from typing import List + +import argparse +import os + +from atoms_detection.cv_detection import CVDetection +from atoms_detection.evaluation import Evaluation +from utils.paths import LOGS_PATH, DETECTION_PATH, PREDS_PATH + + +def cv_full_pipeline( + extension_name: str, + coords_csv: str, + thresholds_list: List[float], + force: bool = False +): + + # DL Detection & Evaluation + for threshold in thresholds_list: + inference_cache_path = os.path.join(PREDS_PATH, f"cv_detection_{extension_name}") + detections_path = os.path.join(DETECTION_PATH, f"cv_detection_{extension_name}_{threshold}") + if force or not os.path.exists(detections_path): + print(f"Detecting atoms on test data with threshold={threshold}...") + detection = CVDetection( + dataset_csv=coords_csv, + threshold=threshold, + detections_path=detections_path, + inference_cache_path=inference_cache_path + ) + detection.run() + + logging_filename = os.path.join(LOGS_PATH, f"cv_detection_{extension_name}_{threshold}.csv") + if force or not os.path.exists(logging_filename): + evaluation = Evaluation( + coords_csv=coords_csv, + predictions_path=detections_path, + logging_filename=logging_filename + ) + evaluation.run() + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "extension_name", + type=str, + 
help="Experiment extension name" + ) + parser.add_argument( + "coords_csv", + type=str, + help="Coordinates CSV file to use as input" + ) + parser.add_argument( + "-t" + "--thresholds", + nargs="+", + type=float, + help="Coordinates CSV file to use as input" + ) + parser.add_argument( + "--force", + action="store_true" + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + print(args) + cv_full_pipeline(args.extension_name, args.coords_csv, args.t__thresholds, args.force) diff --git a/atoms_detection/dataset.py b/atoms_detection/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b3f6e11eebd8bf341952675e4089ac3b7149b4 --- /dev/null +++ b/atoms_detection/dataset.py @@ -0,0 +1,265 @@ +from typing import List, Tuple +import os +import glob + +import numpy as np +import pandas as pd +from PIL import Image +from scipy.ndimage.filters import gaussian_filter, median_filter, rank_filter +from torch.utils.data import Dataset +from torchvision import transforms + +from utils.constants import Split, Columns, CropsColumns, ProbsColumns +from utils.paths import CROPS_DATASET, CROPS_PATH, COORDS_PATH, IMG_PATH, PROBS_DATASET, PROBS_PATH, HAADF_DATASET, PT_DATASET + + +class ImageClassificationDataset(Dataset): + + def __init__(self, image_paths, image_labels, include_filename=False): + self.image_paths = image_paths + self.image_labels = image_labels + self.include_filename = include_filename + self.transform = transforms.Compose([ + transforms.ToTensor() + # transforms.Normalize(mean=[0.5], std=[0.5]) + ]) + + def get_n_labels(self): + return len(set(self.image_labels)) + + def __len__(self): + return len(self.image_paths) + + @staticmethod + def load_image(img_filename): + img = Image.open(img_filename) + np_img = np.asarray(img).astype(np.float32) + np_bg = median_filter(np_img, size=(40, 40)) + np_clean = np_img - np_bg + np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min()) + return np_normed + + def __getitem__(self, idx): + img_path = self.image_paths[idx] + image = self.load_image(img_path) + image = self.transform(image) + label = self.image_labels[idx] + + if self.include_filename: + return image, label, os.path.basename(img_path) + else: + return image, label + + @staticmethod + def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]: + raise NotImplementedError + + @classmethod + def train_dataset(cls, **kwargs): + filenames, labels = cls.get_filenames_labels(Split.TRAIN) + return cls(filenames, labels, **kwargs) + + @classmethod + def val_dataset(cls, **kwargs): + filenames, labels = cls.get_filenames_labels(Split.VAL) + return cls(filenames, labels, **kwargs) + + @classmethod + def test_dataset(cls, **kwargs): + filenames, labels = cls.get_filenames_labels(Split.TEST) + return cls(filenames, labels, **kwargs) + + +class HaadfDataset(ImageClassificationDataset): + @staticmethod + def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]: + df = pd.read_csv(HAADF_DATASET) + split_df = df[df[Columns.SPLIT] == split] + filenames = (IMG_PATH + os.sep + split_df[Columns.FILENAME]).to_list() + labels = (split_df[Columns.LABEL]).to_list() + return filenames, labels + + +class ImageDataset: + FILENAME_COL = "Filename" + SPLIT_COL = "Split" + RULER_UNITS = "Ruler Units" + + def __init__(self, dataset_csv: str): + self.df = pd.read_csv(dataset_csv) + + def iterate_data(self, split: Split): + df = self.df[self.df[self.SPLIT_COL] == split] + for idx, row in df.iterrows(): + image_filename = 
os.path.join(IMG_PATH, row[self.FILENAME_COL]) + yield image_filename + + def get_ruler_units_by_img_name(self, name): + print(name) + return self.df[self.df[self.FILENAME_COL] == name][self.RULER_UNITS].values[0] + + + +class CoordinatesDataset: + FILENAME_COL = "Filename" + COORDS_COL = "Coords" + SPLIT_COL = "Split" + + def __init__(self, coord_image_csv: str): + self.df = pd.read_csv(coord_image_csv) + + def iterate_data(self, split: Split): + df = self.df[self.df[self.SPLIT_COL] == split] + for idx, row in df.iterrows(): + image_filename = os.path.join(IMG_PATH, row[self.FILENAME_COL]) + if isinstance(row[self.COORDS_COL], str): + coords_filename = os.path.join(COORDS_PATH, row[self.COORDS_COL]) + else: + coords_filename = None + yield image_filename, coords_filename + + @staticmethod + def load_coordinates(label_filename: str) -> List[Tuple[int, int]]: + atom_coordinates = pd.read_csv(label_filename) + return list(zip(atom_coordinates['X'], atom_coordinates['Y'])) + + def split_length(self, split: Split): + df = self.df[self.df[self.SPLIT_COL] == split] + return len(df) + + +class HaadfCoordinates(CoordinatesDataset): + def __init__(self): + super().__init__(coord_image_csv=PT_DATASET) + + +class CropsDataset(ImageClassificationDataset): + @staticmethod + def get_filenames_labels(split: Split): + df = pd.read_csv(CROPS_DATASET) + split_df = df[df[CropsColumns.SPLIT] == split] + filenames = (CROPS_PATH + os.sep + split_df[CropsColumns.FILENAME]).to_list() + labels = (split_df[CropsColumns.LABEL]).to_list() + return filenames, labels + + +class CropsCustomDataset(ImageClassificationDataset): + + @staticmethod + def get_filenames_labels(split: Split, crops_dataset: str, crops_path: str): + df = pd.read_csv(crops_dataset) + split_df = df[df[CropsColumns.SPLIT] == split] + filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list() + labels = (split_df[CropsColumns.LABEL]).to_list() + return filenames, labels + + +class ProbsDataset(ImageClassificationDataset): + @staticmethod + def get_filenames_labels(split: Split): + df = pd.read_csv(PROBS_DATASET) + split_df = df[df[ProbsColumns.SPLIT] == split] + filenames = (PROBS_PATH + os.sep + split_df[ProbsColumns.FILENAME]).to_list() + labels = (split_df[ProbsColumns.LABEL]).to_list() + return filenames, labels + + +class SlidingCropDataset(Dataset): + + def __init__(self, tif_filename, include_coords=True): + self.filename = tif_filename + self.include_coords = include_coords + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5], std=[0.5]) + ]) + + self.n_labels = 2 + self.step_size = 2 + self.window_size = (21, 21) + self.loaded_crops = [] + self.loaded_coords = [] + self.load_crops() + + def sliding_window(self, image): + # slide a window across the image + for x in range(0, image.shape[0] - self.window_size[0], self.step_size): + for y in range(0, image.shape[1] - self.window_size[1], self.step_size): + # yield the current window + center_x = x + ((self.window_size[0] - 1) // 2) + center_y = y + ((self.window_size[1] - 1) // 2) + yield center_x, center_y, image[x:x + self.window_size[0], y:y + self.window_size[1]] + + @staticmethod + def load_image(img_filename): + img = Image.open(img_filename) + np_img = np.asarray(img).astype(np.float32) + np_bg = median_filter(np_img, size=(40, 40)) + np_clean = np_img - np_bg + np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min()) + return np_normed + + def load_crops(self): + img = self.load_image(self.filename) + 
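# All windows are materialised eagerly: with step_size=2 and a 21x21 window, a + # 512x512 image (for example) yields about 246*246 = 60k crops, trading memory + # for fast random access in __getitem__. +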
for x_center, y_center, img_crop in self.sliding_window(img): + self.loaded_crops.append(img_crop) + self.loaded_coords.append((x_center, y_center)) + + def get_n_labels(self): + return self.n_labels + + def __len__(self): + return len(self.loaded_crops) + + def __getitem__(self, idx): + crop = self.loaded_crops[idx] + x, y = self.loaded_coords[idx] + crop = self.transform(crop) + + return crop, x, y + + +def get_image_path_without_coords(split: str or None = None): + coords_prefix_set = set() + for coords_name in os.listdir(COORDS_PATH): + coord_prefix = coords_name.split('_')[0] + coords_prefix_set.add(coord_prefix) + + all_prefixes_set = set() + for tif_name in os.listdir(IMG_PATH): + coord_prefix = tif_name.split('_')[0] + all_prefixes_set.add(coord_prefix) + + if split == Split.TRAIN: + missing_prefixes = coords_prefix_set + elif split == Split.TEST: + missing_prefixes = all_prefixes_set - coords_prefix_set + elif split is None: + missing_prefixes = all_prefixes_set + else: + raise ValueError + tif_filenames_list = [] + labels_list = [] + for prefix in missing_prefixes: + filename_matches = glob.glob(os.path.join(IMG_PATH, f'{prefix}_HAADF*NC*')) + if len(filename_matches) == 0: + continue + pos_filenames = [filename for filename in filename_matches if '_PtNC' in filename] + neg_filenames = [filename for filename in filename_matches if '_NC' in filename] + + if len(pos_filenames) > 0: + pos_filename = sorted(pos_filenames)[-1] + tif_filenames_list.append(pos_filename) + labels_list.append(1) + if len(neg_filenames) > 0: + neg_filename = sorted(neg_filenames)[-1] + tif_filenames_list.append(neg_filename) + labels_list.append(0) + + return tif_filenames_list, labels_list + + +if __name__ == "__main__": + filenames_list = get_image_path_without_coords() + filename = filenames_list[0] + dataset = SlidingCropDataset(filename) diff --git a/atoms_detection/detection.py b/atoms_detection/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..4e83d2d75a29e7608834225ba25ab7181208f1ec --- /dev/null +++ b/atoms_detection/detection.py @@ -0,0 +1,96 @@ +from typing import Tuple, List + +import os +from hashlib import sha1 + +import numpy as np +from PIL import Image +from scipy.ndimage import label + +from utils.constants import Split +from utils.paths import PREDS_PATH +from atoms_detection.dataset import ImageDataset + + +class Detection: + def __init__(self, dataset_csv: str, threshold: float, detections_path: str, inference_cache_path: str): + self.image_dataset = ImageDataset(dataset_csv) + self.threshold = threshold + self.detections_path = detections_path + self.inference_cache_path = inference_cache_path + self.currently_processing = None + if not os.path.exists(self.detections_path): + os.makedirs(self.detections_path) + if not os.path.exists(self.inference_cache_path): + os.makedirs(self.inference_cache_path) + + def image_to_pred_map(self, img: np.ndarray) -> np.ndarray: + raise NotImplementedError + + def pred_map_to_atoms(self, pred_map: np.ndarray) -> Tuple[List[Tuple[int, int]], List[float]]: + pred_mask = pred_map > self.threshold + labeled_array, num_features = label(pred_mask) + + # Convert labelled_array to indexes + center_coords_list = [] + likelihood_list = [] + for label_idx in range(num_features+1): + if label_idx == 0: + continue + label_mask = np.where(labeled_array == label_idx) + likelihood = np.max(pred_map[label_mask]) + likelihood_list.append(likelihood) + # label_size = len(label_mask[0]) + # print(f"\t\tAtom {label_idx}: {label_size}") + 
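# Each connected component of the thresholded prediction map becomes one + # detected atom: its likelihood is the blob's maximum probability and its + # position the centre of the blob's bounding box; label 0 is background. +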
atom_bbox = (label_mask[1].min(), label_mask[1].max(), label_mask[0].min(), label_mask[0].max()) + center_coord = self.bbox_to_center_coords(atom_bbox) + center_coords_list.append(center_coord) + return center_coords_list, likelihood_list + + def detect_atoms(self, img_filename: str) -> Tuple[List[Tuple[int, int]], List[float]]: + img_hash = self.cache_image_identifier(img_filename) + prediction_cache = os.path.join(self.inference_cache_path, f"{img_hash}.npy") + if not os.path.exists(prediction_cache): + self.currently_processing = os.path.split(img_filename)[-1] + img = self.open_image(img_filename) + pred_map = self.image_to_pred_map(img) + np.save(prediction_cache, pred_map) + else: + pred_map = np.load(prediction_cache) + center_coords_list, likelihood_list = self.pred_map_to_atoms(pred_map) + return center_coords_list, likelihood_list + + def cache_image_identifier(self, img_filename): + return sha1(img_filename.encode()).hexdigest() + + @staticmethod + def bbox_to_center_coords(bbox: Tuple[int, int, int, int]) -> Tuple[int, int]: + x_center = (bbox[0] + bbox[1]) // 2 + y_center = (bbox[2] + bbox[3]) // 2 + return x_center, y_center + + @staticmethod + def open_image(img_filename: str): + img = Image.open(img_filename) + np_img = np.asarray(img).astype(np.float32) + return np_img + + def run_single(self, image_path: str): + print(f"Running detection on {os.path.basename(image_path)}") + center_coords_list, likelihood_list = self.detect_atoms(image_path) + + image_filename = os.path.basename(image_path) + img_name = os.path.splitext(image_filename)[0] + detection_csv = os.path.join(self.detections_path, f"{img_name}.csv") + with open(detection_csv, "w") as _csv: + _csv.write("Filename,x,y,Likelihood\n") + for (x, y), likelihood in zip(center_coords_list, likelihood_list): + _csv.write(f"{image_filename},{x},{y},{likelihood}\n") + return center_coords_list, likelihood_list + + def run(self): + if not os.path.exists(self.detections_path): + os.makedirs(self.detections_path) + + for image_path in self.image_dataset.iterate_data(Split.TEST): + self.run_single(image_path) diff --git a/atoms_detection/dl_contrastive_pipeline.py b/atoms_detection/dl_contrastive_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..2230c48ac6cddcdcd05dd579e613120a5f856b7a --- /dev/null +++ b/atoms_detection/dl_contrastive_pipeline.py @@ -0,0 +1,198 @@ +from typing import List + +import argparse +import os + +from atoms_detection.create_crop_dataset import create_contrastive_crops_dataset +from atoms_detection.dl_detection import DLDetection +from atoms_detection.dl_detection_with_gmm import DLGMMdetection +from atoms_detection.evaluation import Evaluation +from atoms_detection.training_model import train_model +from utils.paths import ( + CROPS_PATH, + CROPS_DATASET, + MODELS_PATH, + LOGS_PATH, + DETECTION_PATH, + PREDS_PATH, + PRED_GT_VIS_PATH, +) +from utils.constants import ModelArgs, Split +from matplotlib import pyplot as plt +import pandas as pd +from PIL import Image +import numpy as np + +from visualizations.prediction_gt_images import get_gt_coords +from visualizations.utils import plot_gt_pred_on_img + + +def dl_full_pipeline( + extension_name: str, + architecture: ModelArgs, + coords_csv: str, + thresholds_list: List[float], + force_create_dataset: bool = False, + force_evaluation: bool = False, + show_sampling_image: bool = False, + train: bool = False, + visualise: bool = False, + upsample: bool = False, + upsample_neg_amount: float = 0, + clip_max: float = 1, +
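# NOTE: clip_max is accepted but not forwarded to the dataset builder below + # (its clip argument is commented out), matching the "-c" help text. +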
negative_dist: float = 1.1, +): + # Create crops data + crops_folder = CROPS_PATH + f"_{extension_name}" + crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv") + print(os.path.exists(crops_dataset)) + if force_create_dataset or not os.path.exists(crops_dataset): + print("Creating crops dataset...") + create_contrastive_crops_dataset( + crops_folder, + coords_csv, + crops_dataset, + show_sampling_result=show_sampling_image, + pos_data_upsampling=upsample, + neg_upsample_multiplier=upsample_neg_amount, + contrastive_distance_multiplier=negative_dist, + ) # , clip=clip_max + # training DL model + ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt") + if train or not os.path.exists(ckpt_filename): + print("Training DL crops model...") + train_model(architecture, crops_dataset, crops_folder, ckpt_filename) + + for threshold in thresholds_list: + inference_cache_path = os.path.join( + PREDS_PATH, f"dl_detection_{extension_name}" + ) + detections_path = os.path.join( + DETECTION_PATH, + f"dl_detection_{extension_name}", + f"dl_detection_{extension_name}_{threshold}", + ) + if force_evaluation or visualise or not os.path.exists(detections_path): + print(f"Detecting atoms on test data with threshold={threshold}...") + if args.run_gmm_for_multimers: + detection_pipeline = DLGMMdetection + else: + detection_pipeline = DLDetection + + detection = detection_pipeline( + model_name=architecture, + ckpt_filename=ckpt_filename, + dataset_csv=coords_csv, + threshold=threshold, + detections_path=detections_path, + inference_cache_path=inference_cache_path, + ) + detection.run() + + logging_filename = os.path.join( + LOGS_PATH, + f"dl_evaluation_{extension_name}", + f"dl_evaluation_{extension_name}_{threshold}.csv", + ) + if force_evaluation or visualise or not os.path.exists(logging_filename): + evaluation = Evaluation( + coords_csv=coords_csv, + predictions_path=detections_path, + logging_filename=logging_filename, + ) + evaluation.run() + if visualise: + vis_folder = os.path.join( + PRED_GT_VIS_PATH, f"dl_detection_{extension_name}" + ) + if not os.path.exists(vis_folder): + os.makedirs(vis_folder) + + vis_folder = os.path.join( + vis_folder, f"dl_detection_{extension_name}_{threshold}" + ) + if not os.path.exists(vis_folder): + os.makedirs(vis_folder) + is_evaluation = True + if is_evaluation: + gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset) + + for image_path in detection.image_dataset.iterate_data(Split.TEST): + img_name = os.path.split(image_path)[-1] + gt_coords = gt_coords_dict[img_name] if is_evaluation else None + pred_df_path = os.path.join( + detections_path, os.path.splitext(img_name)[0] + ".csv" + ) + df_predicted = pd.read_csv(pred_df_path) + pred_coords = [ + (row["x"], row["y"]) for _, row in df_predicted.iterrows() + ] + img = Image.open(image_path) + img_arr = np.array(img).astype(np.float32) + img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min()) + + plot_gt_pred_on_img(img_normed, gt_coords, pred_coords) + clean_image_name = os.path.splitext(img_name)[0] + vis_path = os.path.join(vis_folder, f"{clean_image_name}.png") + plt.savefig( + vis_path, bbox_inches="tight", pad_inches=0.0, transparent=True + ) + plt.close() + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("extension_name", type=str, help="Experiment extension name") + parser.add_argument( + "architecture", type=ModelArgs, choices=ModelArgs, help="Architecture name" + ) + parser.add_argument( + "coords_csv", type=str, 
help="Coordinates CSV file to use as input" + ) + parser.add_argument( + "-t" "--thresholds", nargs="+", type=float, help="Threshold value" + ) + parser.add_argument( + "-c", type=float, default=1, help="Clipping quantile (0..1]. CURRENTLY USELESS!" + ) + parser.add_argument( + "-nd", type=float, default=1.1, help="Negative contrastive crop distance" + ) + parser.add_argument("--force_create_dataset", action="store_true") + parser.add_argument("--force_evaluation", action="store_true") + parser.add_argument("--show_sampling_result", action="store_true") + parser.add_argument("--train", action="store_true") + parser.add_argument("--visualise", action="store_true") + parser.add_argument("--upsample", action="store_true") + parser.add_argument( + "--run_gmm_for_multimers", + action="store_true", + help="If selected, a postprocessing will be run to split large atoms (possible multimers) with a GMM", + ) + parser.add_argument( + "--upsample_neg", + type=float, + default=0, + help="Upsample amount for negative crops during training", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + print(args) + dl_full_pipeline( + args.extension_name, + args.architecture, + args.coords_csv, + args.t__thresholds, + args.force_create_dataset, + args.force_evaluation, + args.show_sampling_result, + args.train, + args.visualise, + args.upsample, + args.upsample_neg, + args.c, + args.nd, + ) diff --git a/atoms_detection/dl_detection.py b/atoms_detection/dl_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..01df7b79766923d1af7a5c17f2666b8c5167a606 --- /dev/null +++ b/atoms_detection/dl_detection.py @@ -0,0 +1,118 @@ +import os +from typing import Tuple, List + +import torch +import numpy as np +import torch.nn +import torch.nn.functional + +from atoms_detection.detection import Detection +from atoms_detection.training_model import model_pipeline +from atoms_detection.image_preprocessing import dl_prepro_image +from utils.constants import ModelArgs +from utils.paths import PREDS_PATH + + +class DLDetection(Detection): + def __init__(self, + model_name: ModelArgs, + ckpt_filename: str, + dataset_csv: str, + threshold: float, + detections_path: str, + inference_cache_path: str, + batch_size: int = 64, + ): + self.model_name = model_name + self.ckpt_filename = ckpt_filename + self.device = self.get_torch_device() + self.batch_size = batch_size + + self.stride = 1 + self.padding = 10 + self.window_size = (21, 21) + + super().__init__(dataset_csv, threshold, detections_path, inference_cache_path) + + @staticmethod + def get_torch_device(): + if torch.backends.mps.is_available(): + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + return device + + def sliding_window(self, image: np.ndarray, padding: int = 0) -> Tuple[int, int, np.ndarray]: + # slide a window across the image + x_to_center = self.window_size[0] // 2 - 1 if self.window_size[0] % 2 == 0 else self.window_size[0] // 2 + y_to_center = self.window_size[1] // 2 - 1 if self.window_size[1] % 2 == 0 else self.window_size[1] // 2 + + for y in range(0, image.shape[0] - self.window_size[1]+1, self.stride): + for x in range(0, image.shape[1] - self.window_size[0]+1, self.stride): + # yield the current window + center_x = x + x_to_center + center_y = y + y_to_center + yield center_x-padding, center_y-padding, image[y:y + self.window_size[1], x:x + self.window_size[0]] + + def batch_sliding_window(self, image: 
np.ndarray, padding: int = 0) -> Tuple[List[int], List[int], List[np.ndarray]]: + x_idx_list = [] + y_idx_list = [] + images_list = [] + count = 0 + for _x, _y, _img in self.sliding_window(image, padding=padding): + x_idx_list.append(_x) + y_idx_list.append(_y) + images_list.append(_img) + count += 1 + if count == self.batch_size: + yield x_idx_list, y_idx_list, images_list + x_idx_list = [] + y_idx_list = [] + images_list = [] + count = 0 + if count != 0: + yield x_idx_list, y_idx_list, images_list + + def padding_image(self, img: np.ndarray) -> np.ndarray: + image_padded = np.zeros((img.shape[0] + self.padding*2, img.shape[1] + self.padding*2)) + image_padded[self.padding:-self.padding, self.padding:-self.padding] = img + return image_padded + + def load_model(self) -> torch.nn.Module: + checkpoint = torch.load(self.ckpt_filename, map_location=self.device) + + model = model_pipeline[self.model_name](num_classes=2).to(self.device) + model.load_state_dict(checkpoint['state_dict']) + model.eval() + return model + + def images_to_torch_input(self, images_list: List[np.ndarray]) -> torch.Tensor: + expanded_img = np.expand_dims(images_list, axis=1) + input_tensor = torch.from_numpy(expanded_img).float() + input_tensor = input_tensor.to(self.device) + return input_tensor + + def get_prediction_map(self, padded_image: np.ndarray) -> np.ndarray: + _shape = padded_image.shape + pred_map = np.zeros((_shape[0] - self.padding*2, _shape[1] - self.padding*2)) + model = self.load_model() + for x_idxs, y_idxs, image_crops in self.batch_sliding_window(padded_image, padding=self.padding): + torch_input = self.images_to_torch_input(image_crops) + output = model(torch_input) + pred_prob = torch.nn.functional.softmax(output, 1) + pred_prob = pred_prob.detach().cpu().numpy()[:, 1] + pred_map[np.array(y_idxs), np.array(x_idxs)] = pred_prob + return pred_map + + def image_to_pred_map(self, img: np.ndarray, return_intermediate: bool = False) -> np.ndarray: + preprocessed_img = dl_prepro_image(img) + print(f"preprocessed_img.shape: {preprocessed_img.shape}, μ: {np.mean(preprocessed_img)}, σ: {np.std(preprocessed_img)}") + padded_image = self.padding_image(preprocessed_img) + print(f"padded_image.shape: {padded_image.shape}, μ: {np.mean(padded_image)}, σ: {np.std(padded_image)}") + pred_map = self.get_prediction_map(padded_image) + print(f"pred_map.shape: {pred_map.shape}, μ: {np.mean(pred_map)}, σ: {np.std(pred_map)}") + if return_intermediate: + return preprocessed_img, padded_image, pred_map + return pred_map diff --git a/atoms_detection/dl_detection_evaluation.py b/atoms_detection/dl_detection_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..efc9d41fc62b50254e4cd0e58ce4dfeb9c7197d8 --- /dev/null +++ b/atoms_detection/dl_detection_evaluation.py @@ -0,0 +1,130 @@ +import argparse +import os +import random +import numpy as np +from matplotlib import pyplot as plt + +from PIL import Image +import pandas as pd + +from atoms_detection.dl_detection import DLDetection +from atoms_detection.dl_detection_scaled import DLScaled +from atoms_detection.evaluation import Evaluation +from utils.paths import MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH, FE_DATASET, PRED_GT_VIS_PATH +from utils.constants import ModelArgs, Split +from visualizations.prediction_gt_images import plot_gt_pred_on_img, get_gt_coords + + +def detection_pipeline(args): + extension_name = args.extension_name + print(f"Storing at {extension_name}") + architecture = ModelArgs.BASICCNN +
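# The architecture and checkpoint are hard-wired here; model_sac_cnn.ckpt is + # assumed to have been produced beforehand by train_model + # (atoms_detection/training_model.py). +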
ckpt_filename = os.path.join(MODELS_PATH, "model_sac_cnn.ckpt") + + inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}") + + testing_thresholds = [0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99] + testing_thresholds = [0.8, 0.85, 0.9, 0.95] + for threshold in testing_thresholds: + detections_path = os.path.join(DETECTION_PATH, f"dl_detection_{extension_name}", + f"dl_detection_{extension_name}_{threshold}") + print(f"Detecting atoms on test data with threshold={threshold}...") + if args.experimental_rescale: + print("Using experimental ruler rescaling") + detection = DLScaled( + model_name=architecture, + ckpt_filename=ckpt_filename, + dataset_csv=args.dataset, + threshold=threshold, + detections_path=detections_path, + inference_cache_path=inference_cache_path + ) + else: + detection = DLDetection( + model_name=architecture, + ckpt_filename=ckpt_filename, + dataset_csv=args.dataset, + threshold=threshold, + detections_path=detections_path, + inference_cache_path=inference_cache_path + ) + detection.run() + if args.eval: + logging_filename = os.path.join(LOGS_PATH, f"dl_detection_{extension_name}", + f"dl_detection_{extension_name}_{threshold}.csv") + evaluation = Evaluation( + coords_csv=args.dataset, + predictions_path=detections_path, + logging_filename=logging_filename + ) + evaluation.run() + if args.visualise: + + vis_folder = os.path.join(PRED_GT_VIS_PATH, f"dl_detection_{extension_name}") + if not os.path.exists(vis_folder): + os.makedirs(vis_folder) + + vis_folder = os.path.join(vis_folder, f"dl_detection_{extension_name}_{threshold}") + if not os.path.exists(vis_folder): + os.makedirs(vis_folder) + + if args.eval: + gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset) + + for image_path in detection.image_dataset.iterate_data(Split.TEST): + print(image_path) + img_name = os.path.split(image_path)[-1] + gt_coords = gt_coords_dict[img_name] if args.eval else None + pred_df_path = os.path.join(detections_path, os.path.splitext(img_name)[0]+'.csv') + df_predicted = pd.read_csv(pred_df_path) + pred_coords = [(row['x'], row['y']) for _, row in df_predicted.iterrows()] + img = Image.open(image_path) + img_arr = np.array(img).astype(np.float32) + img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min()) + + plot_gt_pred_on_img(img_normed, gt_coords, pred_coords) + clean_image_name = os.path.splitext(img_name)[0] + vis_path = os.path.join(vis_folder, f'{clean_image_name}.png') + plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True) + plt.close() + + print(f"Experiment {extension_name} completed") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "extension_name", + type=str, + help="Experiment extension name" + ) + parser.add_argument( + "dataset", + type=str, + help="Dataset file upon which to do inference" + ) + parser.add_argument( + "--eval", + action='store_true', + help="Whether to perform evaluation after inference", + default=False + ) + parser.add_argument( + "--visualise", + action='store_true', + help="Whether to store inference results as visual png images", + default=False + ) + parser.add_argument( + "--experimental_rescale", + action='store_true', + help="Whether to rescale inputs based on the ruler of the image as preprocess", + default=False + ) + parser.add_argument('--feature', ) + return parser.parse_args() + + +if __name__=='__main__': + args = get_args() + detection_pipeline(args) diff --git a/atoms_detection/dl_detection_scaled.py 
b/atoms_detection/dl_detection_scaled.py new file mode 100644 index 0000000000000000000000000000000000000000..15964f87113772112c7e57fd965ae5b69af9b66a --- /dev/null +++ b/atoms_detection/dl_detection_scaled.py @@ -0,0 +1,42 @@ +import os +from hashlib import sha1 +from typing import Tuple, List + +from PIL import Image + +from atoms_detection.dl_detection import DLDetection +from atoms_detection.image_preprocessing import dl_prepro_image +from utils.constants import ModelArgs +import numpy as np + +class DLScaled(DLDetection): + # Should take into account for the resize: + # Ruler of the image (pixels x nm) + # Covalent radius + # beam size/voltage/exposure? (can create larger distortions) (Metadata should be in dm3 files, if it can be read) + def __init__(self, + model_name: ModelArgs, + ckpt_filename: str, + dataset_csv: str, + threshold: float, + detections_path: str, + inference_cache_path: str): + super().__init__(model_name, ckpt_filename, dataset_csv, threshold, detections_path, inference_cache_path) + + def image_to_pred_map(self, img: np.ndarray) -> np.ndarray: + ruler_units = self.image_dataset.get_ruler_units_by_img_name(self.currently_processing) + preprocessed_img, scale_factor = dl_prepro_image(img, ruler_units=ruler_units) + padded_image = self.padding_image(preprocessed_img) + pred_map = self.get_prediction_map(padded_image) + + new_dimensions = img.shape[0], img.shape[1] + pred_map = np.array(Image.fromarray(pred_map).resize(new_dimensions)) + return pred_map + + def cache_image_identifier(self, img_filename): + # Distinct cache key so scaled predictions never collide with unscaled ones. + return sha1((img_filename + 'scaled').encode()).hexdigest() diff --git a/atoms_detection/dl_detection_with_gmm.py b/atoms_detection/dl_detection_with_gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..6761e67b88fbe2c67ae1a23135133fb94debab4d --- /dev/null +++ b/atoms_detection/dl_detection_with_gmm.py @@ -0,0 +1,151 @@ +from typing import Tuple, List + +from atoms_detection.dl_detection import DLDetection +from utils.constants import ModelArgs +from sklearn.mixture import GaussianMixture +from scipy.ndimage import label +import math +import numpy as np + + +class DLGMMdetection(DLDetection): + MAX_SINGLE_ATOM_AREA = 200 + MAX_ATOMS_PER_AREA = 3 + COVARIANCE_TYPE = "full" + + def __init__( + self, + model_name: ModelArgs, + ckpt_filename: str, + dataset_csv: str, + threshold: float, + detections_path: str, + inference_cache_path: str, + covariance_penalisation: float = 0.03, + n_clusters_penalisation: float = 0.33, + distance_penalisation: float = 0.11, + n_samples_per_gmm: int = 600, + ): + super(DLGMMdetection, self).__init__( + model_name, + ckpt_filename, + dataset_csv, + threshold, + detections_path, + inference_cache_path, + ) + self.covariance_penalisation = covariance_penalisation + self.n_clusters_penalisation = n_clusters_penalisation + self.distance_penalisation = distance_penalisation + self.n_samples_per_gmm = n_samples_per_gmm + + def pred_map_to_atoms( + self, pred_map: np.ndarray + ) -> Tuple[List[Tuple[int, int]], List[float]]: + pred_mask = pred_map > self.threshold + labeled_array, num_features = label(pred_mask) + self.current_pred_map = pred_map + + # Convert labelled_array to indexes + center_coords_list = [] + likelihood_list = [] + for label_idx in range(num_features + 1): + if label_idx == 0: + continue + label_mask = np.where(labeled_array == label_idx) + likelihood = np.max(pred_map[label_mask]) + # label_size = len(label_mask[0]) + # print(f"\t\tAtom {label_idx}: {label_size}") + atom_bbox = ( +
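# bbox layout is (x_min, x_max, y_min, y_max): np.where puts row (y) indices + # in label_mask[0] and column (x) indices in label_mask[1]. +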
label_mask[1].min(), + label_mask[1].max(), + label_mask[0].min(), + label_mask[0].max(), + ) + center_coord = self.bbox_to_center_coords(atom_bbox) + center_coords_list += center_coord + pixel_area = (atom_bbox[1] - atom_bbox[0]) * (atom_bbox[3] - atom_bbox[2]) + if pixel_area < self.MAX_SINGLE_ATOM_AREA: + likelihood_list.append(likelihood) + else: + for i in range(0, len(center_coord)): + likelihood_list.append(likelihood) + self.current_pred_map = None + print(f"number of atoms {len(center_coords_list)}") + return center_coords_list, likelihood_list + + def bbox_to_center_coords( + self, bbox: Tuple[int, int, int, int] + ) -> List[Tuple[int, int]]: + pixel_area = (bbox[1] - bbox[0]) * (bbox[3] - bbox[2]) + if pixel_area < self.MAX_SINGLE_ATOM_AREA: + # Wrap the single centre in a list so callers can always extend with it. + return [super().bbox_to_center_coords(bbox)] + else: + pmap = self.get_current_prediction_map_region(bbox) + local_atom_center_list = self.run_gmm_pipeline(pmap) + atom_center_list = [ + (x + bbox[0], y + bbox[2]) for x, y in local_atom_center_list + ] + return atom_center_list + + def get_current_prediction_map_region(self, bbox: Tuple[int, int, int, int]) -> np.ndarray: + # Assumed helper (missing from the original file): crop the prediction map + # cached by pred_map_to_atoms to the (x_min, x_max, y_min, y_max) region. + return self.current_pred_map[bbox[2]:bbox[3] + 1, bbox[0]:bbox[1] + 1] + + def sample_img_hist(self, img_region): + x_bin_midpoints = list(range(img_region.shape[1])) + y_bin_midpoints = list(range(img_region.shape[0])) + # noinspection PyUnresolvedReferences + cdf = np.cumsum(img_region.ravel()) + cdf = cdf / cdf[-1] + values = np.random.rand(self.n_samples_per_gmm) + # noinspection PyUnresolvedReferences + value_bins = np.searchsorted(cdf, values) + x_idx, y_idx = np.unravel_index( + value_bins, (len(x_bin_midpoints), len(y_bin_midpoints)) + ) + random_from_cdf = np.column_stack((x_idx, y_idx)) + new_x, new_y = random_from_cdf.T + return new_x, new_y + + def run_gmm_pipeline(self, prediction_map: np.ndarray) -> List[Tuple[int, int]]: + retries = 2 + new_x, new_y = self.sample_img_hist(prediction_map) + best_gmm, best_score = None, -np.inf + obs = np.array((new_x, new_y)).T + for k in range(1, self.MAX_ATOMS_PER_AREA + 1): + for i in range(retries): + gmm = GaussianMixture( + n_components=k, covariance_type=self.COVARIANCE_TYPE + ) + gmm.fit(obs) + logLike = gmm.score(obs) + covar = np.linalg.norm(gmm.covariances_) + if k == 1: + score = ( + logLike - covar * self.covariance_penalisation - k * self.n_clusters_penalisation + ) + print(k, score) + else: + distances = [ + math.dist(p1, p2) + for i, p1 in enumerate(gmm.means_[:-1]) + for p2 in gmm.means_[i + 1 :] + ] + dist_penalisation = sum([max(12 - d, 0) ** 2 for d in distances]) + score = ( + logLike - covar * self.covariance_penalisation - k * self.n_clusters_penalisation - dist_penalisation * self.distance_penalisation + ) + print( + k, + score, + logLike, + covar * self.covariance_penalisation, + k * self.n_clusters_penalisation, + dist_penalisation * self.distance_penalisation, + ) + if score > best_score: + best_gmm, best_score = gmm, score + # print(best_gmm.means_) + return [(x, y) for y, x in best_gmm.means_.tolist()] diff --git a/atoms_detection/dl_full_pipeline.py b/atoms_detection/dl_full_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a0fa99f48e0f0c946f8d061f13aa7dda0fe558e9 --- /dev/null +++ b/atoms_detection/dl_full_pipeline.py @@ -0,0 +1,96 @@ +from typing import List + +import argparse +import os + +from atoms_detection.create_crop_dataset import create_crops_dataset +from atoms_detection.dl_detection import DLDetection +from atoms_detection.evaluation import Evaluation +from atoms_detection.training_model import train_model +from utils.paths import CROPS_PATH, CROPS_DATASET, MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH
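+# Pipeline stages: (1) cut labelled 21x21 crops listed in the coordinate CSV, +# (2) train the crop classifier, (3) run sliding-window detection on the test +# split at each threshold, (4) score detections against the ground truth. +# Illustrative invocation (ModelArgs values live in utils/constants.py): +# python -m atoms_detection.dl_full_pipeline my_run <architecture> data/coords.csv -t 0.8 0.9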
+from utils.constants import ModelArgs + + +def dl_full_pipeline( + extension_name: str, + architecture: ModelArgs, + coords_csv: str, + thresholds_list: List[float], + force: bool = False +): + # Create crops data + crops_folder = CROPS_PATH + f"_{extension_name}" + crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv") + if force or not os.path.exists(crops_dataset): + print("Creating crops dataset...") + create_crops_dataset(crops_folder, coords_csv, crops_dataset) + + # training DL model + ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt") + if force or not os.path.exists(ckpt_filename): + print("Training DL crops model...") + train_model(architecture, crops_dataset, crops_folder, ckpt_filename) + + # DL Detection & Evaluation + for threshold in thresholds_list: + inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}") + detections_path = os.path.join(DETECTION_PATH, f"dl_detection_{extension_name}", f"dl_detection_{extension_name}_{threshold}") + if force or not os.path.exists(detections_path): + print(f"Detecting atoms on test data with threshold={threshold}...") + detection = DLDetection( + model_name=architecture, + ckpt_filename=ckpt_filename, + dataset_csv=coords_csv, + threshold=threshold, + detections_path=detections_path, + inference_cache_path=inference_cache_path + ) + detection.run() + + logging_filename = os.path.join(LOGS_PATH, f"dl_evaluation_{extension_name}", f"dl_evaluation_{extension_name}_{threshold}.csv") + if force or not os.path.exists(logging_filename): + evaluation = Evaluation( + coords_csv=coords_csv, + predictions_path=detections_path, + logging_filename=logging_filename + ) + evaluation.run() + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "extension_name", + type=str, + help="Experiment extension name" + ) + parser.add_argument( + "architecture", + type=ModelArgs, + choices=ModelArgs, + help="Architecture name" + ) + parser.add_argument( + "coords_csv", + type=str, + help="Coordinates CSV file to use as input" + ) + parser.add_argument( + "-t", + "--thresholds", + nargs="+", + type=float, + help="Detection threshold values to evaluate" + ) + parser.add_argument( + "--force", + action="store_true" + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + print(args) + dl_full_pipeline(args.extension_name, args.architecture, args.coords_csv, args.thresholds, args.force) diff --git a/atoms_detection/evaluation.py b/atoms_detection/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..d7732fbe998eb848e059b037ac334ed3352df0f3 --- /dev/null +++ b/atoms_detection/evaluation.py @@ -0,0 +1,254 @@ +from typing import Optional, Tuple, List + +import os + +import numpy as np +import scipy.optimize +from PIL import Image +import pandas as pd +from matplotlib import pyplot as plt +from matplotlib import patches + +from utils.constants import Split +from atoms_detection.dataset import CoordinatesDataset + + +def bbox_iou(bb1, bb2): + assert bb1[0] <= bb1[1] + assert bb1[2] <= bb1[3] + assert bb2[0] <= bb2[1] + assert bb2[2] <= bb2[3] + + # determine the coordinates of the intersection rectangle + x_left = max(bb1[0], bb2[0]) + y_top = max(bb1[2], bb2[2]) + x_right = min(bb1[1], bb2[1]) + y_bottom = min(bb1[3], bb2[3]) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + # The intersection of two axis-aligned bounding boxes is always an + # axis-aligned bounding box + intersection_area =
(x_right - x_left) * (y_bottom - y_top) + + # compute the area of both AABBs + bb1_area = (bb1[1] - bb1[0]) * (bb1[3] - bb1[2]) + bb2_area = (bb2[1] - bb2[0]) * (bb2[3] - bb2[2]) + + # compute the intersection over union by taking the intersection + # area and dividing it by the sum of prediction + ground-truth + # areas - the interesection area + iou = intersection_area / float(bb1_area + bb2_area - intersection_area) + assert iou >= 0.0 + assert iou <= 1.0 + return iou + + +def match_bboxes(iou_matrix, IOU_THRESH=0.5): + ''' + Given sets of true and predicted bounding-boxes, + determine the best possible match. + + Returns + ------- + (idxs_true, idxs_pred, ious, labels) + idxs_true, idxs_pred : indices into gt and pred for matches + ious : corresponding IOU value of each match + labels: vector of 0/1 values for the list of detections + ''' + n_true, n_pred = iou_matrix.shape + MIN_IOU = 0.0 + MAX_DIST = 1.0 + + if n_pred > n_true: + # there are more predictions than ground-truth - add dummy rows + diff = n_pred - n_true + iou_matrix = np.concatenate((iou_matrix, + np.full((diff, n_pred), MIN_IOU)), + axis=0) + + if n_true > n_pred: + # more ground-truth than predictions - add dummy columns + diff = n_true - n_pred + iou_matrix = np.concatenate((iou_matrix, + np.full((n_true, diff), MIN_IOU)), + axis=1) + + # call the Hungarian matching + idxs_true, idxs_pred = scipy.optimize.linear_sum_assignment(1 - iou_matrix) + + if (not idxs_true.size) or (not idxs_pred.size): + ious = np.array([]) + else: + ious = iou_matrix[idxs_true, idxs_pred] + + # remove dummy assignments + sel_pred = idxs_pred < n_pred + idx_pred_actual = idxs_pred[sel_pred] + idx_gt_actual = idxs_true[sel_pred] + ious_actual = iou_matrix[idx_gt_actual, idx_pred_actual] + sel_valid = (ious_actual > IOU_THRESH) + label = sel_valid.astype(int) + + return idx_gt_actual[sel_valid], idx_pred_actual[sel_valid], ious_actual[sel_valid], label + + +class Evaluation: + def __init__(self, coords_csv: str, predictions_path: str, logging_filename: str): + self.coordinates_dataset = CoordinatesDataset(coords_csv) + self.predictions_path = predictions_path + self.logging_filename = logging_filename + if not os.path.exists(os.path.dirname(self.logging_filename)): + os.makedirs(os.path.dirname(self.logging_filename)) + self.logs_df = pd.DataFrame(columns=["Filename", "Precision", "Recall", "F1Score"]) + self.threshold = 0.5 + + def get_predictions_dict(self, image_filename: str) -> List[Tuple[int, int]]: + img_name = os.path.splitext(os.path.basename(image_filename))[0] + preds_csv = os.path.join(self.predictions_path, f"{img_name}.csv") + df = pd.read_csv(preds_csv) + pred_coords_list = [] + for idx, row in df.iterrows(): + pred_coords_list.append((row["x"], row["y"])) + return pred_coords_list + + @staticmethod + def center_coords_to_bbox(gt_coord: Tuple[int, int]) -> Tuple[int, int, int, int]: + box_rwidth, box_rheight = 10, 10 + gt_bbox = ( + gt_coord[0] - box_rwidth, + gt_coord[0] + box_rwidth + 1, + gt_coord[1] - box_rheight, + gt_coord[1] + box_rheight + 1 + ) + return gt_bbox + + def eval_matches( + self, + gt_bboxes_dict: List[Tuple[int, int, int, int]], + atoms_bbox_dict: List[Tuple[int, int, int, int]], + iou_threshold: float = 0.5 + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + iou_matrix = np.zeros((len(gt_bboxes_dict), len(atoms_bbox_dict))).astype(np.float32) + + for gt_idx, gt_bbox in enumerate(gt_bboxes_dict): + for atom_idx, atom_bbox in enumerate(atoms_bbox_dict): + iou = bbox_iou(gt_bbox, atom_bbox) + 
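# bbox_iou is symmetric and bounded in [0, 1]; e.g. two 21x21 boxes offset by + # 5 px in x and y overlap in a 16x16 region: IoU = 256/(441+441-256) ~ 0.41. + # The dense matrix then feeds the Hungarian solver in match_bboxes, so each + # ground-truth atom matches at most one prediction (IoU <= 0.5 pairs dropped). +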
iou_matrix[gt_idx, atom_idx] = iou + idxs_true, idxs_pred, ious, labels = match_bboxes(iou_matrix, IOU_THRESH=iou_threshold) + return idxs_true, idxs_pred, ious, labels + + @staticmethod + def eval_metrics(n_matches: int, n_gt: int, n_pred: int) -> Tuple[float, float]: + precision = n_matches / n_pred if n_pred > 0 else 0.0 + if n_gt == 0: + raise RuntimeError("No ground truth atoms???") + recall = n_matches / n_gt + return precision, recall + + def atom_coords_to_bboxes(self, atoms_coords_dict: List[Tuple[int, int]]) -> List[Tuple[int, int, int, int]]: + atom_bboxes_dict = [] + for atom_center in atoms_coords_dict: + atom_fixed_bbox = self.center_coords_to_bbox(atom_center) + atom_bboxes_dict.append(atom_fixed_bbox) + return atom_bboxes_dict + + def gt_coord_to_bboxes(self, gt_coordinates_dict: List[Tuple[int, int]]) -> List[Tuple[int, int, int, int]]: + gt_bboxes_list = [] + for gt_coord in gt_coordinates_dict: + gt_bbox = self.center_coords_to_bbox(gt_coord) + gt_bboxes_list.append(gt_bbox) + return gt_bboxes_list + + @staticmethod + def open_image(img_filename: str): + img = Image.open(img_filename) + np_img = np.asarray(img).astype(np.float32) + return np_img + + def run(self, plot=False): + for image_path, coordinates_path in self.coordinates_dataset.iterate_data(Split.TEST): + img = self.open_image(image_path) + + center_coords_dict = self.get_predictions_dict(image_path) + atoms_bboxes_dict = self.atom_coords_to_bboxes(center_coords_dict) + + gt_coordinates = self.coordinates_dataset.load_coordinates(coordinates_path) + gt_bboxes_dict = self.gt_coord_to_bboxes(gt_coordinates) + + # VISUALILZE gt & pred bboxes! + if plot: + plt.figure(figsize=(20, 20)) + ax = plt.gca() + ax.imshow(img) + for gt_bbox in gt_bboxes_dict: + xy = (gt_bbox[0], gt_bbox[2]) + width = gt_bbox[1] - gt_bbox[0] + height = gt_bbox[3] - gt_bbox[2] + rect = patches.Rectangle(xy, width, height, linewidth=3, edgecolor='r', facecolor='none') + ax.add_patch(rect) + for atom_bbox in atoms_bboxes_dict: + xy = (atom_bbox[0], atom_bbox[2]) + width = atom_bbox[1] - atom_bbox[0] + height = atom_bbox[3] - atom_bbox[2] + rect = patches.Rectangle(xy, width, height, linewidth=2, edgecolor='g', facecolor='none') + ax.add_patch(rect) + plt.tight_layout() + plt.show() + + idxs_true, idxs_pred, ious, labels = self.eval_matches(gt_bboxes_dict, atoms_bboxes_dict) + precision, recall = self.eval_metrics(n_matches=len(idxs_pred), n_gt=len(gt_coordinates), n_pred=len(atoms_bboxes_dict)) + f1_score = 2*(precision*recall)/(precision+recall) if precision+recall > 0 else 0 + if self.logging_filename: + # self.logs_df = self.logs_df.append({ + # "Filename": os.path.basename(image_path), + # "Precision": precision, + # "Recall": recall, + # "F1Score": f1_score + # }, ignore_index=True) + # Change the old append method to the new concat method to avoid the warning + self.logs_df = pd.concat([self.logs_df, pd.DataFrame({ + "Filename": os.path.basename(image_path), + "Precision": precision, + "Recall": recall, + "F1Score": f1_score + }, index=[0])], ignore_index=True) + format_args = (os.path.basename(image_path), f1_score, precision, recall) + print("{}: F1Score: {}, Precision: {}, Recall: {}".format(*format_args)) + + if self.logging_filename: + mean_precision = self.logs_df["Precision"].mean() + mean_recall = self.logs_df["Recall"].mean() + mean_f1_score = self.logs_df["F1Score"].mean() + std_precision = self.logs_df["Precision"].std() + std_recall = self.logs_df["Recall"].std() + std_f1_score = self.logs_df["F1Score"].std() + 
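# Aggregate Mean and Std rows are appended after the per-image scores, so one + # CSV holds both the breakdown and the summary metrics. +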
print(f"F1Score: {mean_f1_score}, Precision: {mean_precision}, Recall: {mean_recall}") + # self.logs_df = self.logs_df.append({ + # "Filename": "Mean", + # "Precision": mean_precision, + # "Recall": mean_recall, + # "F1Score": mean_f1_score + # }, ignore_index=True) + # Change the old append method to the new concat method to avoid the warning + self.logs_df = pd.concat([self.logs_df, pd.DataFrame({ + "Filename": "Mean", + "Precision": mean_precision, + "Recall": mean_recall, + "F1Score": mean_f1_score + }, index=[0])], ignore_index=True) + # self.logs_df = self.logs_df.append({ + # "Filename": "Std", + # "Precision": std_precision, + # "Recall": std_recall, + # "F1Score": std_f1_score + # }, ignore_index=True) + # Change the old append method to the new concat method to avoid the warning + self.logs_df = pd.concat([self.logs_df, pd.DataFrame({ + "Filename": "Std", + "Precision": std_precision, + "Recall": std_recall, + "F1Score": std_f1_score + }, index=[0])], ignore_index=True) + self.logs_df.to_csv(self.logging_filename, index=False, float_format='%.4f') diff --git a/atoms_detection/fast_filters.cpp b/atoms_detection/fast_filters.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4f078e01e3d9cc674297ed98953457652870a73 --- /dev/null +++ b/atoms_detection/fast_filters.cpp @@ -0,0 +1,83 @@ +/* + @author : Romain Graux + @date : 2023 April 03, 15:30:30 + @last modified : 2023 April 03, 18:34:08 + */ + +#include +#include + +using namespace std; + +extern "C" { + void median_filter(float* data, int width, int height, int window_size, float* out); + void reflect_borders(float *data, int width, int height, int span, float *out); +} + + +void reflect_borders(float *data, int width, int height, int span, float *out){ + int out_width = width + 2*span; + int out_height = height + 2*span; + // First copy the same data but with a border of span pixels + for(int i=0; i np.ndarray: + return np_img[:, :, 0] + + +def dl_prepro_image(np_img: np.ndarray, ruler_units=None, clip=1): + # np_bg = gaussian_filter(np_img, sigma=20) + if len(np_img.shape) == 3: + np_img = preprocess_jpg(np_img) + scale_factor = None + if ruler_units is not None: + try: + ruler_size = get_ruler_size(np_img) + np_img, scale_factor = rescale_img_to_target_pxls_nm( + np_img, ruler_size, ruler_units + ) + except Exception: + pass + + print("WARNING, MANUAL CLIP USAGE") + clip = 0.999 + max_allowed = np.quantile(np_img, q=clip) + np_img = np.clip(np_img, a_min=0, a_max=max_allowed) + try: + np_bg = median_filter_parallel(np_img, 40, splits=4) + except Exception as e: + print(e) + print("Median filter failed, using slower scipy version") + np_bg = median_filter(np_img, 40) + np_clean = np_img - np_bg + np_clean[np_clean < 0] = 0 + np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min()) + # np_normed = (np_img - np_img.min()) / (np_img.max() - np_img.min()) + from matplotlib import pyplot as plt + + if scale_factor is not None: + return np_normed, scale_factor + return np_normed + + +def cv_prepro_image(img: np.ndarray) -> np.ndarray: + bg_img = gaussian_filter(img, sigma=10) + clean_img = img - bg_img + normed_img = (clean_img - clean_img.min()) / (clean_img.max() - clean_img.min()) + return normed_img + + +def get_ruler_size(img: np.ndarray) -> int: + ruler_start_location_percent = 0.0625 # empirically located here in samples + ruler_start_coords = int( + img.shape[0] * (1 - ruler_start_location_percent) - 1 + ), int(img.shape[1] * ruler_start_location_percent) + if img[ruler_start_coords] 
!= img.max(): + print("Ruler start position verification failed, skipping rescaling") + raise Exception + else: + ruler_iterator = ruler_start_coords + while img[ruler_iterator] == img[ruler_start_coords]: + ruler_iterator = ruler_iterator[0], ruler_iterator[1] + 1 + return ruler_iterator[1] - ruler_start_coords[1] + + +def rescale_img_to_target_pxls_nm( + img: np.ndarray, ruler_pixels: int, ruler_units: int, atom_prior=None +): + target_scale = ( + 512 / 15 + ) # original images were 512x512 and labelled 15nm across, 34 pixels per nano + pixels_per_nanometer = ruler_pixels / ruler_units # current pixels per nano + scaling_factor = target_scale / pixels_per_nanometer + new_dimensions = int(img.shape[0] * scaling_factor), int( + img.shape[1] * scaling_factor + ) + if atom_prior is None: + return np.array(Image.fromarray(img).resize(new_dimensions)), scaling_factor + else: + raise NotImplementedError diff --git a/atoms_detection/model.py b/atoms_detection/model.py new file mode 100644 index 0000000000000000000000000000000000000000..68da591e7bd99cb0917593ae482de5db5009b47b --- /dev/null +++ b/atoms_detection/model.py @@ -0,0 +1,108 @@ +import torch +from torch import nn + + +class BasicCNN(nn.Module): + + def __init__(self, num_classes): + super().__init__() + + self.features = nn.Sequential( + nn.Conv2d(1, 32, kernel_size=3, stride=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, kernel_size=3, stride=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True) + ) + self.adaptive = nn.AdaptiveAvgPool2d((3, 3)) + + self.fc_layers = nn.Sequential( + nn.Linear(3 * 3 * 128, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 128), + nn.ReLU(inplace=True) + ) + self.fc3 = nn.Linear(128, num_classes) + + self._initialize_weights() + + # Defining the forward pass + def forward(self, x): + x = self.features(x) + x = self.adaptive(x) + x = torch.flatten(x, 1) + x = self.fc_layers(x) + x = self.fc3(x) + return x + + def _initialize_weights(self): + for layer in self.features: + if isinstance(layer, nn.Conv2d): + nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(layer.bias, 0) + elif isinstance(layer, nn.BatchNorm2d): + nn.init.constant_(layer.weight, 1) + nn.init.constant_(layer.bias, 0) + for layer in self.fc_layers: + if isinstance(layer, nn.Linear): + nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(layer.bias, 0) + nn.init.normal_(self.fc3.weight, 0, 0.01) + + +class HeatCNN(nn.Module): + + def __init__(self, num_classes): + super().__init__() + + self.features = nn.Sequential( + nn.Conv2d(1, 32, kernel_size=3, stride=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, kernel_size=3, stride=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True) + ) + self.adaptive = nn.AdaptiveAvgPool2d((3, 3)) + + self.fc_layers = nn.Sequential( + nn.Linear(3 * 3 * 128, 64), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(64, 64), + nn.ReLU(inplace=True), + nn.Dropout(), + ) + self.fc3 = nn.Linear(64, num_classes) + + self._initialize_weights() + + # Defining the forward pass + def forward(self, x): + x = self.features(x) + x = self.adaptive(x) + x = torch.flatten(x, 1) + x = self.fc_layers(x) + x = self.fc3(x) + return x + + def _initialize_weights(self): + for layer in self.features: + if isinstance(layer, 
nn.Conv2d): + nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(layer.bias, 0) + elif isinstance(layer, nn.BatchNorm2d): + nn.init.constant_(layer.weight, 1) + nn.init.constant_(layer.bias, 0) + for layer in self.fc_layers: + if isinstance(layer, nn.Linear): + nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(layer.bias, 0) + nn.init.normal_(self.fc3.weight, 0, 0.01) diff --git a/atoms_detection/multimetallic_analysis.py b/atoms_detection/multimetallic_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..557ca277775258b21979062303624589255fb33d --- /dev/null +++ b/atoms_detection/multimetallic_analysis.py @@ -0,0 +1,229 @@ +# run VAE + GMM assignement +import argparse +from typing import List + +import numpy as np +# import rasterio +import torch +import warnings +import os +import re +import pandas as pd +from PIL import Image + +from sklearn.mixture import GaussianMixture + +from atoms_detection.create_crop_dataset import create_crop +from atoms_detection.vae_utilities.vae_model import rVAE +from atoms_detection.vae_utilities.vae_svi_train import init_dataloader, SVItrainer +from atoms_detection.image_preprocessing import dl_prepro_image + +""" +Code sourced from: +https://colab.research.google.com/github/ziatdinovmax/notebooks_for_medium/blob/main/pyroVAE_MNIST_medium.ipynb + +""" + +numbers = re.compile(r'(\d+)') + +def numericalSort(value): + parts = numbers.split(value) + parts[1::2] = map(int, parts[1::2]) + return parts + +warnings.filterwarnings("ignore", module="torchvision.datasets") + + +def get_crops_from_prediction_csvs(pred_crop_file): + data = pd.read_csv(pred_crop_file) + xx = data['x'].values + yy = data['y'].values + coords = zip(xx,yy) + + img_file = data['Filename'][0] + likelihood = data['Likelihood'].values + img_path = os.path.join('data/tif_data', img_file) + + img = Image.open(img_path) + np_img = np.asarray(img).astype(np.float64) + np_img = dl_prepro_image(np_img) + img = Image.fromarray(np_img) + + crops = list() + coords_list = [] + for x, y in coords: + coords_list.append([x,y]) + new_crop = create_crop(img, x, y) + crops.append(new_crop) + + print(coords_list[0]) + print(np_img[0]) + + coords_array = np.array(coords_list) + + return crops, coords_array, likelihood, img_file + + +def classify_crop_species(args): + # crop_list = get_crops_from_folder(crops_source_folder='./Ni') + crop_list, crop_coords, likelihood, img_filename = get_crops_from_prediction_csvs(args.pred_crop_file) + crop_tensor = np.array(crop_list) + + # Assuming crop_tensor is a list or array of Image objects + processed_images = [] + for image in crop_tensor: + # Convert the Image to a NumPy array + image_array = np.array(image) + # Append the processed image array to the list + processed_images.append(image_array) + # Convert the processed images list to a NumPy array + processed_images = np.array(processed_images) + # Convert the processed_images array to float32 + processed_images = processed_images.astype(np.float32) + + #print(processed_images.shape) + + rvae = rVAE(in_dim=(21, 21), latent_dim=args.latent_dim, coord=args.coord, seed=args.seed) + + train_data = torch.from_numpy(processed_images).float() + # train_data = torch.from_numpy(crop_tensor).float() + train_loader = init_dataloader(train_data, batch_size=args.batchsize) + latent_crop_tensor = train_vae(rvae, train_data, train_loader, args) + + gmm = GaussianMixture(n_components=args.n_species, 
reg_covar=args.GMMcovar, random_state=args.seed).fit( + latent_crop_tensor) + preds = gmm.predict(latent_crop_tensor) + print(preds) + pred_proba = gmm.predict_proba(latent_crop_tensor) + pred_proba = [pred_proba[i, pred] for i, pred in enumerate(preds)] + + # To order clusters, signal-to-noise ratio OR median (across crops) of some intensity quality (eg mean top-5% int) + cluster_median_values = list() + for k in range(args.n_species): + print(k) + relevant_crops = processed_images[preds == k] + crop_95_percentile = np.percentile(relevant_crops, q=95, axis=0) + img_means = [] + for crop, q in zip(relevant_crops, crop_95_percentile): + if (crop >= q).any(): + print(crop.mean()) + img_means.append(crop.mean()) + #img_means.append(crop.mean(axis=0, where=crop >= q)) + cluster_median_value = np.median(np.array(img_means)) + cluster_median_values.append(cluster_median_value) + sorted_clusters = sorted([(mval, c_id) for c_id, mval in enumerate(cluster_median_values)]) + + with open(f"data/detection_data/Multimetallic_{img_filename}.csv", "a") as f: + f.write("Filename,x,y,Likelihood,cluster,cluster_confidence\n") + for _, c_id in sorted_clusters: + c_idd = np.array([c_id]) + pred_proba = np.array(pred_proba) + relevant_crops_coords = crop_coords[preds == c_idd] + relevant_crops_likelihood = likelihood[preds == c_idd] + relevant_crops_confidence = pred_proba[preds == c_idd] + #print(relevant_crops_confidence) + for coords, l, c in zip(relevant_crops_coords, relevant_crops_likelihood, relevant_crops_confidence): + x, y = coords + f.write(f"{img_filename},{x},{y},{l},{c_id},{c}\n") + + + +def train_vae(rvae, train_data, train_loader, args): + # Initialize SVI trainer + trainer = SVItrainer(rvae) + for e in range(args.epochs): + trainer.step(train_loader, scale_factor=args.scale_factor) + trainer.print_statistics() + z_mean, z_sd = rvae.encode(train_data) + latent_crop_tensor = z_mean + return latent_crop_tensor + + +def get_crops_from_folder(crops_source_folder) -> List[np.ndarray]: + ffiles = [] + files = [] + for dirname, dirnames, filenames in os.walk(crops_source_folder): + # print path to all subdirectories first. + for subdirname in dirnames: + files.append(os.path.join(dirname, subdirname)) + + # print path to all filenames. 
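+ # For reference, numericalSort('img12.tif') yields ['img', 12, '.tif'], so 'img2.tif' + # sorts before 'img10.tif' in the sorted(...) loop below (natural numeric order).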
+ for filename in filenames: + files.append(os.path.join(dirname, filename)) + + for filename in sorted(filenames, key=numericalSort): + ffiles.append(filename) + crops = ffiles + path_crops = './Ni/' + all_img = [] + for i in range(0, len(crops)): + src_path = path_crops + crops[i] + img = rasterio.open(src_path) # NOTE: requires rasterio (its import is commented out at the top of this file) + test = np.reshape(img.read([1]), (21, 21)) + all_img.append(np.array(test)) + return all_img + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + 'pred_crop_file', + type=str, + help="Path to the CSV of predicted crop locations (e.g. in data/detection_data/X/Y.csv)" + ) + parser.add_argument( + "-latent_dim", + type=int, + default=50, + help="Dimensionality of the VAE latent space" + ) + parser.add_argument( + "-seed", + type=int, + default=444, + help="Random seed" + ) + parser.add_argument( + "-coord", + type=int, + default=3, + help="Amount of equivariances, 0: None, 1: Rotational, 2: Translational, 3: Rotational and Translational" + ) + parser.add_argument( + "-batchsize", + type=int, + default=100, + help="Batch size for the VAE model" + ) + parser.add_argument( + "-epochs", + type=int, + default=20, + help="Number of training epochs for the VAE" + ) + parser.add_argument( + "-scale_factor", + type=int, + default=3, + help="KL scale factor (beta) applied during VAE training" + ) + parser.add_argument( + "-n_species", + type=int, + default=2, + help="Number of chemical species expected in the sample." + ) + parser.add_argument( + "-GMMcovar", + type=float, + default=0.0001, + help="Regularization covariance (reg_covar) for the GMM clustering." + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + print(args) + classify_crop_species(args) diff --git a/atoms_detection/testing_model.py b/atoms_detection/testing_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4090ac62470eb9ef09a9bcd01d0261c4a9c23c6f --- /dev/null +++ b/atoms_detection/testing_model.py @@ -0,0 +1,53 @@ +import os + +import torch +from sklearn.metrics import confusion_matrix, f1_score, accuracy_score +from torch.utils.data import DataLoader +import matplotlib.pyplot as plt + +from atoms_detection.training_model import model_pipeline, get_args +from atoms_detection.dataset import CropsDataset +from atoms_detection.training import test_epoch +from utils.cf_matrix import make_confusion_matrix +from utils.paths import MODELS_PATH, CM_VIS_PATH + + +def main(args): + # Run on Apple-silicon MPS when available, otherwise fall back to CPU + use_mps = torch.backends.mps.is_available() + device = torch.device("mps" if use_mps else "cpu") + + test_dataset = CropsDataset.test_dataset() + test_dataloader = DataLoader(test_dataset, batch_size=64) + + ckpt_filename = os.path.join(MODELS_PATH, f'{args.experiment_name}.ckpt') + checkpoint = torch.load(ckpt_filename, map_location=device) + + model = model_pipeline[args.model](num_classes=test_dataset.get_n_labels()).to(device) + model.load_state_dict(checkpoint['state_dict']) + + if torch.cuda.device_count() > 1: + print("Using {} GPUs!".format(torch.cuda.device_count())) + model = torch.nn.DataParallel(model) + + loss_function = torch.nn.CrossEntropyLoss(reduction='mean').to(device) + + y_true, y_pred = test_epoch(test_dataloader, model, loss_function, device) + + cm = confusion_matrix(y_true, y_pred) + labels = ["True Neg", "False Pos", "False Neg", "True Pos"] + make_confusion_matrix(cm, group_names=labels, cbar_range=(0, 110)) + if not os.path.exists(CM_VIS_PATH): + os.makedirs(CM_VIS_PATH)
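+ # Note: sklearn's binary confusion_matrix is laid out [[TN, FP], [FN, TP]], which is why + # the group_names above follow that order; e.g. confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1]) + # returns [[1, 1], [0, 2]].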
+ plt.savefig(os.path.join(CM_VIS_PATH, f"cm_{args.experiment_name}.jpg")) + f1 = f1_score(y_true, y_pred) + acc = accuracy_score(y_true, y_pred) + with open(os.path.join(CM_VIS_PATH, f"metrics_{args.experiment_name}.txt"), 'w') as _log: + _log.write(f"F1_score: {f1}\nACCURACY: {acc}\n") + print(f"F1_score: {f1}") + print(f"ACCURACY: {acc}") + + +if __name__ == "__main__": + main(get_args()) diff --git a/atoms_detection/training.py b/atoms_detection/training.py new file mode 100644 index 0000000000000000000000000000000000000000..dcdaf5cc67bd587ca6dbf647ed36fe3b1e91bdbd --- /dev/null +++ b/atoms_detection/training.py @@ -0,0 +1,145 @@ +import time + +import numpy as np +import torch +from torch.nn import functional as F + + +def train_epoch(train_loader, model, loss_function, optimizer, device, epoch): + model.train() + + correct = 0 + total = 0 + losses = 0 + t0 = time.time() + for idx, (batch_images, batch_labels) in enumerate(train_loader): + # Loading tensors in the used device + step_images, step_labels = batch_images.to(device), batch_labels.to(device) + + # zero the parameter gradients + optimizer.zero_grad() + + step_output = model(step_images) + loss = loss_function(step_output, step_labels) + loss.backward() + optimizer.step() + + step_total = step_labels.size(0) + step_loss = loss.item() + losses += step_loss*step_total + total += step_total + + step_preds = torch.max(step_output.data, 1)[1] + step_correct = (step_preds == step_labels).sum().item() + correct += step_correct + + train_loss = losses / total + train_acc = correct / total + format_args = (epoch, train_acc, train_loss, time.time() - t0) + print('EPOCH {} :: train accuracy: {:.4f} - train loss: {:.4f} at {:.1f}s'.format(*format_args)) + + +def val_epoch(val_loader, model, loss_function, device, epoch): + model.eval() + + y_true = [] + y_pred = [] + + correct = 0 + total = 0 + losses = 0 + t0 = time.time() + with torch.no_grad(): + for batch_images, batch_labels in val_loader: + # Loading tensors in the used device + step_images, step_labels = batch_images.to(device), batch_labels.to(device) + + step_output = model(step_images) + loss = loss_function(step_output, step_labels) + + step_total = step_labels.size(0) + step_loss = loss.item() + losses += step_loss*step_total + total += step_total + + step_preds = torch.max(step_output.data, 1)[1] + y_pred.append(step_preds.cpu().detach().numpy()) + y_true.append(step_labels.cpu().detach().numpy()) + step_correct = (step_preds == step_labels).sum().item() + correct += step_correct + + val_loss = losses / total + val_acc = correct / total + format_args = (epoch, val_acc, val_loss, time.time() - t0) + print('EPOCH {} :: val accuracy: {:.4f} - val loss: {:.4f} at {:.1f}s'.format(*format_args)) + + y_pred = np.concatenate(y_pred, axis=0) + y_true = np.concatenate(y_true, axis=0) + return y_true, y_pred + + +def test_epoch(test_loader, model, loss_function, device): + model.eval() + + correct = 0 + total = 0 + losses = 0 + all_true = [] + all_pred = [] + t0 = time.time() + with torch.no_grad(): + for batch_images, batch_labels in test_loader: + # Loading tensors in the used device + step_images, step_labels = batch_images.to(device), batch_labels.to(device) + + step_output = model(step_images) + loss = loss_function(step_output, step_labels) + + step_total = step_labels.size(0) + step_loss = loss.item() + losses += step_loss*step_total + total += step_total + + step_preds = torch.max(step_output.data, 1)[1] + step_correct = (step_preds == step_labels).sum().item() + correct += 
step_correct + + all_true.append(step_labels.cpu().numpy()) + all_pred.append(step_preds.cpu().numpy()) + + val_loss = losses / total + val_acc = correct / total + format_args = (val_acc, val_loss, time.time() - t0) + print('EPOCH :: test accuracy: {:.4f} - test loss: {:.4f} at {:.1f}s'.format(*format_args)) + + all_pred = np.concatenate(all_pred, axis=0) + all_true = np.concatenate(all_true, axis=0) + return all_true, all_pred + + +def detection_epoch(detection_loader, model, device): + model.eval() + + pred_probs = [] + coords_x = [] + coords_y = [] + t0 = time.time() + with torch.no_grad(): + for batch_images, batch_x, batch_y in detection_loader: + # Loading tensors in the used device + step_images = batch_images.to(device) + step_output = model(step_images) + step_pred_probs = F.softmax(step_output, 1) + + step_pred_probs = step_pred_probs.cpu().numpy() + step_x = batch_x.numpy() + step_y = batch_y.numpy() + + coords_x.append(step_x) + coords_y.append(step_y) + pred_probs.append(step_pred_probs) + + return_pred_probs = np.concatenate(pred_probs, axis=0) + return_coords_x = np.concatenate(coords_x, axis=0) + return_coords_y = np.concatenate(coords_y, axis=0) + return return_pred_probs, return_coords_x, return_coords_y diff --git a/atoms_detection/training_model.py b/atoms_detection/training_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dc821826556276cd72bbb847d0c3cf1ea211ca06 --- /dev/null +++ b/atoms_detection/training_model.py @@ -0,0 +1,111 @@ +import os +import argparse +import random + +import torch +import numpy as np +import pandas as pd +from torch.utils.data import DataLoader +from torchvision.models import resnet18 + +from utils.paths import MODELS_PATH, CROPS_PATH, CROPS_DATASET +from utils.constants import ModelArgs, Split, CropsColumns +from atoms_detection.training import train_epoch, val_epoch +from atoms_detection.dataset import ImageClassificationDataset +from atoms_detection.model import BasicCNN + + +torch.manual_seed(777) +random.seed(777) +np.random.seed(777) + + +def get_basic_cnn(*args, **kwargs): + model = BasicCNN(*args, **kwargs) + return model + + +def get_resnet(*args, **kwargs): + model = resnet18(*args, **kwargs) + # Adapt the stem to single-channel (grayscale) input; a 2D conv is required here + model.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False) + return model + + +model_pipeline = { + ModelArgs.BASICCNN: get_basic_cnn, + ModelArgs.RESNET18: get_resnet +} + +epochs_pipeline = { + ModelArgs.BASICCNN: 12, + ModelArgs.RESNET18: 3 +} + + +def train_model(model_arg: ModelArgs, crops_dataset: str, crops_path: str, ckpt_filename: str): + + class CropsDataset(ImageClassificationDataset): + @staticmethod + def get_filenames_labels(split: Split): + df = pd.read_csv(crops_dataset) + split_df = df[df[CropsColumns.SPLIT] == split] + filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list() + labels = (split_df[CropsColumns.LABEL]).to_list() + return filenames, labels + + + # Run on Apple-silicon MPS when available, otherwise fall back to CPU + use_mps = torch.backends.mps.is_available() + device = torch.device("mps" if use_mps else "cpu") + + train_dataset = CropsDataset.train_dataset() + val_dataset = CropsDataset.val_dataset() + train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True) + val_dataloader = DataLoader(val_dataset, batch_size=64) + model = model_pipeline[model_arg](num_classes=train_dataset.get_n_labels()).to(device) + + if torch.cuda.device_count() > 1: + print("Using {} GPUs!".format(torch.cuda.device_count())) + model = torch.nn.DataParallel(model) + +
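# Note: DataParallel only engages when more than one CUDA device is visible; on the MPS/CPU + # path selected above this branch is a no-op. +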
loss_function = torch.nn.CrossEntropyLoss(reduction='mean').to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=True) + + epoch = 0 + for epoch in range(epochs_pipeline[model_arg]): + train_epoch(train_dataloader, model, loss_function, optimizer, device, epoch) + val_epoch(val_dataloader, model, loss_function, device, epoch) + + if not os.path.exists(MODELS_PATH): + os.makedirs(MODELS_PATH) + + state = { + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch + } + torch.save(state, ckpt_filename) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "experiment_name", + type=str, + help="Experiment name" + ) + parser.add_argument( + "model", + type=ModelArgs, + help="model architecture", + choices=list(ModelArgs) + ) + return parser.parse_args() + + +if __name__ == "__main__": + extension_name = "replicate" + ckpt_filename = os.path.join(MODELS_PATH, "basic_replicate2.ckpt") + crops_folder = CROPS_PATH + f"_{extension_name}" + train_model(ModelArgs.BASICCNN, CROPS_DATASET, CROPS_PATH, ckpt_filename) diff --git a/atoms_detection/vae_image_utils.py b/atoms_detection/vae_image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d156eb2a10d3255d730fa2974beff08fc3793445 --- /dev/null +++ b/atoms_detection/vae_image_utils.py @@ -0,0 +1,53 @@ +# @title Load functions for working with image coordinates and labels +# @title Load utility functions for data loading and preprocessing + +from typing import Tuple, Union + +import torch + +import warnings +warnings.filterwarnings("ignore", module="torchvision.datasets") + + +def to_onehot(idx: torch.Tensor, n: int) -> torch.Tensor: + """ + One-hot encoding of a label + """ + if torch.max(idx).item() >= n: + raise AssertionError( + "Labelling must start from 0 and " + "maximum label value must be less than total number of classes") + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if idx.dim() == 1: + idx = idx.unsqueeze(1) + onehot = torch.zeros(idx.size(0), n, device=device) + return onehot.scatter_(1, idx.to(device), 1) + + +def grid2xy(X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor: + X = torch.cat((X1[None], X2[None]), 0) + d0, d1 = X.shape[0], X.shape[1] * X.shape[2] + X = X.reshape(d0, d1).T + return X + + +def imcoordgrid(im_dim: Tuple) -> torch.Tensor: + xx = torch.linspace(-1, 1, im_dim[0]) + yy = torch.linspace(1, -1, im_dim[1]) + x0, x1 = torch.meshgrid(xx, yy) + return grid2xy(x0, x1) + + +def transform_coordinates(coord: torch.Tensor, + phi: Union[torch.Tensor, float] = 0, + coord_dx: Union[torch.Tensor, float] = 0, + ) -> torch.Tensor: + + if torch.sum(phi) == 0: + phi = coord.new_zeros(coord.shape[0]) + rotmat_r1 = torch.stack([torch.cos(phi), torch.sin(phi)], 1) + rotmat_r2 = torch.stack([-torch.sin(phi), torch.cos(phi)], 1) + rotmat = torch.stack([rotmat_r1, rotmat_r2], axis=1) + coord = torch.bmm(coord, rotmat) + + return coord + coord_dx diff --git a/atoms_detection/vae_model.py b/atoms_detection/vae_model.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf328c33cd399f5627c8520a4835a3481518a46 --- /dev/null +++ b/atoms_detection/vae_model.py @@ -0,0 +1,345 @@ + +import numpy as np + + +import torch +import torch.nn as nn +from torch import tensor as tt + +from typing import Optional, Tuple, Type + +import pyro +import pyro.distributions as dist + +import warnings + +from atoms_detection.vae_image_utils import imcoordgrid, to_onehot, transform_coordinates + 
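# A quick illustration of the imported helpers (illustrative values): to_onehot(torch.tensor([0, 2]), 3) + # returns [[1., 0., 0.], [0., 0., 1.]], and imcoordgrid((h, w)) builds the (h*w, 2) grid of + # pixel coordinates in [-1, 1] that the spatial decoder consumes.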
+warnings.filterwarnings("ignore", module="torchvision.datasets") + +# VAE model set-up +# @title Load neural networks for VAE { form-width: "25%" } + + +def set_deterministic_mode(seed: int) -> None: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def make_fc_layers(in_dim: int, + hidden_dim: int = 128, + num_layers: int = 2, + activation: str = "tanh" + ) -> Type[nn.Module]: + """ + Generates a module with stacked fully-connected (aka dense) layers + """ + activations = {"tanh": nn.Tanh, "lrelu": nn.LeakyReLU, "softplus": nn.Softplus} + fc_layers = [] + for i in range(num_layers): + hidden_dim_ = in_dim if i == 0 else hidden_dim + fc_layers.extend( + [nn.Linear(hidden_dim_, hidden_dim), activations[activation]()]) + fc_layers = nn.Sequential(*fc_layers) + return fc_layers + + +class fcEncoderNet(nn.Module): + """ + Simple fully-connected inference (encoder) network + """ + def __init__(self, + in_dim: Tuple[int,int], + latent_dim: int = 2, + hidden_dim: int = 128, + num_layers: int = 2, + activation: str = 'tanh', + softplus_out: bool = False + ) -> None: + """ + Initializes module parameters + """ + super(fcEncoderNet, self).__init__() + if len(in_dim) not in [1, 2, 3]: + raise ValueError("in_dim must be (h, w), (h, w, c), or (h*w*c,)") + self.in_dim = torch.prod(tt(in_dim)).item() + + self.fc_layers = make_fc_layers( + self.in_dim, hidden_dim, num_layers, activation) + self.fc11 = nn.Linear(hidden_dim, latent_dim) + self.fc12 = nn.Linear(hidden_dim, latent_dim) + self.activation_out = nn.Softplus() if softplus_out else lambda x: x + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Forward pass + """ + x = x.view(-1, self.in_dim) + x = self.fc_layers(x) + mu = self.fc11(x) + log_sigma = self.activation_out(self.fc12(x)) + return mu, log_sigma + + +class fcDecoderNet(nn.Module): + """ + Standard decoder for VAE + """ + def __init__(self, + out_dim: Tuple[int], + latent_dim: int, + hidden_dim: int = 128, + num_layers: int = 2, + activation: str = 'tanh', + sigmoid_out: str = True, + ) -> None: + super(fcDecoderNet, self).__init__() + if len(out_dim) not in [1, 2, 3]: + raise ValueError("in_dim must be (h, w), (h, w, c), or (h*w*c,)") + self.reshape = out_dim + out_dim = torch.prod(tt(out_dim)).item() + + self.fc_layers = make_fc_layers( + latent_dim, hidden_dim, num_layers, activation) + self.out = nn.Linear(hidden_dim, out_dim) + self.activation_out = nn.Sigmoid() if sigmoid_out else lambda x: x + + def forward(self, z: torch.Tensor) -> torch.Tensor: + x = self.fc_layers(z) + x = self.activation_out(self.out(x)) + return x.view(-1, *self.reshape) + + +class rDecoderNet(nn.Module): + """ + Spatial generator (decoder) network with fully-connected layers + """ + def __init__(self, + out_dim: Tuple[int], + latent_dim: int, + hidden_dim: int = 128, + num_layers: int = 2, + activation: str = 'tanh', + sigmoid_out: str = True + ) -> None: + """ + Initializes module parameters + """ + super(rDecoderNet, self).__init__() + if len(out_dim) not in [1, 2, 3]: + raise ValueError("in_dim must be (h, w), (h, w, c), or (h*w*c,)") + self.reshape = out_dim + out_dim = torch.prod(tt(out_dim)).item() + + self.coord_latent = coord_latent(latent_dim, hidden_dim) + self.fc_layers = make_fc_layers( + hidden_dim, hidden_dim, num_layers, activation) + self.out = nn.Linear(hidden_dim, 1) # need to generalize to multi-channel (c > 1) + 
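# Each output pixel is predicted from its (transformed) coordinate embedding combined with + # the latent code, so this head emits one intensity per grid point rather than a full image. +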
self.activation_out = nn.Sigmoid() if sigmoid_out else lambda x: x + + def forward(self, x_coord: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + """ + Forward pass + """ + x = self.coord_latent(x_coord, z) + x = self.fc_layers(x) + x = self.activation_out(self.out(x)) + return x.view(-1, *self.reshape) + + +class coord_latent(nn.Module): + """ + The "spatial" part of the rVAE's decoder that allows for translational + and rotational invariance (based on https://arxiv.org/abs/1909.11663) + """ + def __init__(self, + latent_dim: int, + out_dim: int, + activation_out: bool = True) -> None: + """ + Iniitalizes modules parameters + """ + super(coord_latent, self).__init__() + self.fc_coord = nn.Linear(2, out_dim) + self.fc_latent = nn.Linear(latent_dim, out_dim, bias=False) + self.activation = nn.Tanh() if activation_out else None + + def forward(self, + x_coord: torch.Tensor, + z: torch.Tensor) -> torch.Tensor: + """ + Forward pass + """ + batch_dim, n = x_coord.size()[:2] + x_coord = x_coord.reshape(batch_dim * n, -1) + h_x = self.fc_coord(x_coord) + h_x = h_x.reshape(batch_dim, n, -1) + h_z = self.fc_latent(z) + h = h_x.add(h_z.unsqueeze(1)) + h = h.reshape(batch_dim * n, -1) + if self.activation is not None: + h = self.activation(h) + return h + + +class rVAE(nn.Module): + """ + Variational autoencoder with rotational and/or transaltional invariance + """ + def __init__(self, + in_dim: Tuple[int, int], + latent_dim: int = 2, + coord: int = 3, + num_classes: int = 0, + hidden_dim_e: int = 128, + hidden_dim_d: int = 128, + num_layers_e: int = 2, + num_layers_d: int = 2, + activation: str = "tanh", + softplus_sd: bool = True, + sigmoid_out: bool = True, + seed: int = 1, + **kwargs + ) -> None: + """ + Initializes rVAE's modules and parameters + """ + super(rVAE, self).__init__() + pyro.clear_param_store() + set_deterministic_mode(seed) + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.encoder_net = fcEncoderNet( + in_dim, latent_dim+coord, hidden_dim_e, + num_layers_e, activation, softplus_sd) + if coord not in [0, 1, 2, 3]: + raise ValueError("'coord' argument must be 0, 1, 2 or 3") + dnet = rDecoderNet if coord in [1, 2, 3] else fcDecoderNet + self.decoder_net = dnet( + in_dim, latent_dim+num_classes, hidden_dim_d, + num_layers_d, activation, sigmoid_out) + self.z_dim = latent_dim + coord + self.coord = coord + self.num_classes = num_classes + self.grid = imcoordgrid(in_dim).to(self.device) + self.dx_prior = tt(kwargs.get("dx_prior", 0.1)).to(self.device) + self.to(self.device) + + def model(self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + **kwargs: float) -> torch.Tensor: + """ + Defines the model p(x|z)p(z) + """ + # register PyTorch module `decoder_net` with Pyro + pyro.module("decoder_net", self.decoder_net) + # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl) + beta = kwargs.get("scale_factor", 1.) 
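+ # 'scale_factor' plays the role of beta in a beta-VAE: the poutine.scale blocks here and in + # guide() rescale the latent (KL) part of the ELBO by this factor.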
+ reshape_ = torch.prod(tt(x.shape[1:])).item() + with pyro.plate("data", x.shape[0]): + # setup hyperparameters for prior p(z) + z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim))) + z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim))) + # sample from prior (value will be sampled by guide when computing the ELBO) + with pyro.poutine.scale(scale=beta): + z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) + if self.coord > 0: # rotationally- and/or translationally-invariant mode + # Split latent variable into parts for rotation + # and/or translation and image content + phi, dx, z = self.split_latent(z) + if torch.sum(dx) != 0: + dx = (dx * self.dx_prior).unsqueeze(1) + # transform coordinate grid + grid = self.grid.expand(x.shape[0], *self.grid.shape) + x_coord_prime = transform_coordinates(grid, phi, dx) + # Add class label (if any) + if y is not None: + y = to_onehot(y, self.num_classes) + z = torch.cat([z, y], dim=-1) + # decode the latent code z together with the transformed coordinates (if any) + dec_args = (x_coord_prime, z) if self.coord else (z,) + loc_img = self.decoder_net(*dec_args) + # score against actual images ("binary cross-entropy loss") + pyro.sample( + "obs", dist.Bernoulli(loc_img.view(-1, reshape_), validate_args=False).to_event(1), + obs=x.view(-1, reshape_)) + + def guide(self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + **kwargs: float) -> torch.Tensor: + """ + Defines the guide q(z|x) + """ + # register PyTorch module `encoder_net` with Pyro + pyro.module("encoder_net", self.encoder_net) + # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl) + beta = kwargs.get("scale_factor", 1.) + with pyro.plate("data", x.shape[0]): + # use the encoder to get the parameters used to define q(z|x) + z_loc, z_scale = self.encoder_net(x) + # sample the latent code z + with pyro.poutine.scale(scale=beta): + pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) + + def split_latent(self, z: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Split latent variable into parts for rotation + and/or translation and image content + """ + phi, dx = tt(0), tt(0) + # rotation + translation + if self.coord == 3: + phi = z[:, 0] # encoded angle + dx = z[:, 1:3] # translation + z = z[:, 3:] # image content + # translation only + elif self.coord == 2: + dx = z[:, :2] + z = z[:, 2:] + # rotation only + elif self.coord == 1: + phi = z[:, 0] + z = z[:, 1:] + return phi, dx, z + + def _encode(self, x_new: torch.Tensor, **kwargs: int) -> torch.Tensor: + """ + Encodes data using a trained inference (encoder) network + in a batch-by-batch fashion + """ + def inference() -> np.ndarray: + with torch.no_grad(): + encoded = self.encoder_net(x_i) + encoded = torch.cat(encoded, -1).cpu() + return encoded + + x_new = x_new.to(self.device) + num_batches = kwargs.get("num_batches", 10) + batch_size = len(x_new) // num_batches + z_encoded = [] + for i in range(num_batches): + x_i = x_new[i*batch_size:(i+1)*batch_size] + z_encoded_i = inference() + z_encoded.append(z_encoded_i) + x_i = x_new[(i+1)*batch_size:] + if len(x_i) > 0: + z_encoded_i = inference() + z_encoded.append(z_encoded_i) + return torch.cat(z_encoded) + + def encode(self, x_new: torch.Tensor, **kwargs: int) -> torch.Tensor: + """ + Encodes data using a trained inference (encoder) network + (this is basically a wrapper for self._encode) + """ + if isinstance(x_new, torch.utils.data.DataLoader): + # pull the raw tensor out of a DataLoader input + x_new = x_new.dataset.tensors[0] + z = self._encode(x_new, **kwargs) + z_loc = z[:, :self.z_dim] +
z_scale = z[:, self.z_dim:] + return z_loc, z_scale diff --git a/atoms_detection/vae_svi_train.py b/atoms_detection/vae_svi_train.py new file mode 100644 index 0000000000000000000000000000000000000000..67142e785158bfd5aafce725eccc566ec2864084 --- /dev/null +++ b/atoms_detection/vae_svi_train.py @@ -0,0 +1,121 @@ + +from typing import Optional, Type + +import torch +import torch.nn as nn + +import pyro +import pyro.infer as infer +import pyro.optim as optim + +import warnings + +#from vae_model import set_deterministic_mode as set_deterministic_mode +from atoms_detection.vae_model import set_deterministic_mode as set_deterministic_mode + +warnings.filterwarnings("ignore", module="torchvision.datasets") + + +class SVItrainer: + """ + Stochastic variational inference (SVI) trainer for + unsupervised and class-conditioned variational models + """ + def __init__(self, + model: Type[nn.Module], + optimizer: Type[optim.PyroOptim] = None, + loss: Type[infer.ELBO] = None, + seed: int = 1 + ) -> None: + """ + Initializes the trainer's parameters + """ + pyro.clear_param_store() + set_deterministic_mode(seed) + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if optimizer is None: + optimizer = optim.Adam({"lr": 1.0e-3}) + if loss is None: + loss = infer.Trace_ELBO() + self.svi = infer.SVI(model.model, model.guide, optimizer, loss=loss) + self.loss_history = {"training_loss": [], "test_loss": []} + self.current_epoch = 0 + + def train(self, + train_loader: Type[torch.utils.data.DataLoader], + **kwargs: float) -> float: + """ + Trains a single epoch + """ + # initialize loss accumulator + epoch_loss = 0. + # do a training epoch over each mini-batch returned by the data loader + for data in train_loader: + if len(data) == 1: # VAE mode + x = data[0] + loss = self.svi.step(x.to(self.device), **kwargs) + else: # VED or cVAE mode + x, y = data + loss = self.svi.step( + x.to(self.device), y.to(self.device), **kwargs) + # do ELBO gradient and accumulate loss + epoch_loss += loss + + return epoch_loss / len(train_loader.dataset) + + def evaluate(self, + test_loader: Type[torch.utils.data.DataLoader], + **kwargs: float) -> float: + """ + Evaluates current models state on a single epoch + """ + # initialize loss accumulator + test_loss = 0. 
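+ # Unlike train(), this evaluation pass must not take gradient steps; each batch only + # contributes its loss to the running total.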
+ # compute the loss over the entire test set + with torch.no_grad(): + for data in test_loader: + if len(data) == 1: # VAE mode + x = data[0] + loss = self.svi.evaluate_loss(x.to(self.device), **kwargs) + else: # VED or cVAE mode + x, y = data + loss = self.svi.evaluate_loss( + x.to(self.device), y.to(self.device), **kwargs) + test_loss += loss + + return test_loss / len(test_loader.dataset) + + def step(self, + train_loader: Type[torch.utils.data.DataLoader], + test_loader: Optional[Type[torch.utils.data.DataLoader]] = None, + **kwargs: float) -> None: + """ + Single training and (optionally) evaluation step + """ + self.loss_history["training_loss"].append(self.train(train_loader, **kwargs)) + if test_loader is not None: + self.loss_history["test_loss"].append(self.evaluate(test_loader, **kwargs)) + self.current_epoch += 1 + + def print_statistics(self) -> None: + """ + Prints training and test (if any) losses for the current epoch + """ + e = self.current_epoch + if len(self.loss_history["test_loss"]) > 0: + template = 'Epoch: {} Training loss: {:.4f}, Test loss: {:.4f}' + print(template.format(e, self.loss_history["training_loss"][-1], + self.loss_history["test_loss"][-1])) + else: + template = 'Epoch: {} Training loss: {:.4f}' + print(template.format(e, self.loss_history["training_loss"][-1])) + + +def init_dataloader(*args: torch.Tensor, **kwargs: int + ) -> Type[torch.utils.data.DataLoader]: + + batch_size = kwargs.get("batch_size", 100) + tensor_set = torch.utils.data.dataset.TensorDataset(*args) + data_loader = torch.utils.data.DataLoader( + dataset=tensor_set, batch_size=batch_size, shuffle=True) + return data_loader diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2d335df07bdf7f19a2dc61e211fa9be7965a57cd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +gradio==3.24.1 +matplotlib==3.7.1 +networkx==3.0 +numpy==1.23.5 +opencv_contrib_python==4.7.0.72 +pandas==1.5.3 +Pillow==9.5.0 +pyro-ppl  # used by atoms_detection/vae_model.py; version intentionally left unpinned +scikit_learn==1.2.2 +scipy==1.10.1 +seaborn==0.12.2 +torch==1.13.1 +torchvision==0.14.1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd32599ea52386f09299dbb87b93771ef16e4b3 --- /dev/null +++ b/setup.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@author : Romain Graux +@date : 2023 April 06, 17:33:28 +@last modified : 2023 April 24, 15:59:09 +""" + +import os +import logging +from distutils.core import setup, Extension + +logging.basicConfig(level=logging.INFO) + +os.environ["CC"] = "g++" + +fast_filters_module = Extension( + "fast_filters", + sources=["atoms_detection/fast_filters.cpp"], +) + +setup( + name="atoms_detection", + version="0.0.1a0", + description="", + ext_modules=[fast_filters_module], +) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/cf_matrix.py b/utils/cf_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..3d85cb5e2f53ea00e58645b83c7425cea269e83d --- /dev/null +++ b/utils/cf_matrix.py @@ -0,0 +1,111 @@ +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns + + +def make_confusion_matrix(cf, + group_names=None, + categories='auto', + count=True, + percent=True, + cbar=True, + cbar_range=(None, None), + xyticks=True, + xyplotlabels=True, + sum_stats=True, + figsize=None, + cmap='Blues', + title=None): + ''' + This function will make
a pretty plot of an sklearn Confusion Matrix cm using a Seaborn heatmap visualization. + + Arguments + --------- + cf: confusion matrix to be passed in + + group_names: List of strings that represent the labels row by row to be shown in each square. + + categories: List of strings containing the categories to be displayed on the x,y axis. Default is 'auto' + + count: If True, show the raw number in the confusion matrix. Default is True. + + normalize: If True, show the proportions for each category. Default is True. + + cbar: If True, show the color bar. The cbar values are based off the values in the confusion matrix. + Default is True. + + xyticks: If True, show x and y ticks. Default is True. + + xyplotlabels: If True, show 'True Label' and 'Predicted Label' on the figure. Default is True. + + sum_stats: If True, display summary statistics below the figure. Default is True. + + figsize: Tuple representing the figure size. Default will be the matplotlib rcParams value. + + cmap: Colormap of the values displayed from matplotlib.pyplot.cm. Default is 'Blues' + See http://matplotlib.org/examples/color/colormaps_reference.html + + title: Title for the heatmap. Default is None. + + ''' + + # CODE TO GENERATE TEXT INSIDE EACH SQUARE + blanks = ['' for i in range(cf.size)] + + if group_names and len(group_names) == cf.size: + group_labels = ["{}\n".format(value) for value in group_names] + else: + group_labels = blanks + + if count: + group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()] + else: + group_counts = blanks + + if percent: + group_percentages = ["{0:.2%}".format(value) for value in cf.flatten() / np.sum(cf)] + else: + group_percentages = blanks + + box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)] + box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1]) + + # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS + if sum_stats: + # Accuracy is sum of diagonal divided by total observations + accuracy = np.trace(cf) / float(np.sum(cf)) + + # if it is a binary confusion matrix, show some more stats + if len(cf) == 2: + # Metrics for Binary Confusion Matrices + precision = cf[1, 1] / sum(cf[:, 1]) + recall = cf[1, 1] / sum(cf[1, :]) + f1_score = 2 * precision * recall / (precision + recall) + stats_text = "\n\nAccuracy={:0.3f}\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format( + accuracy, precision, recall, f1_score) + else: + stats_text = "\n\nAccuracy={:0.3f}".format(accuracy) + else: + stats_text = "" + + # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS + if figsize == None: + # Get default figure size if not set + figsize = plt.rcParams.get('figure.figsize') + + if xyticks == False: + # Do not show categories if xyticks is False + categories = False + + # MAKE THE HEATMAP VISUALIZATION + plt.figure(figsize=figsize) + sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, vmin=cbar_range[0], vmax=cbar_range[1], xticklabels=categories, yticklabels=categories) + + if xyplotlabels: + plt.ylabel('True label') + plt.xlabel('Predicted label' + stats_text) + else: + plt.xlabel(stats_text) + + if title: + plt.title(title) diff --git a/utils/constants.py b/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..f8510d1c0e6749edf8f8bfa82d324dfed1c8370a --- /dev/null +++ b/utils/constants.py @@ -0,0 +1,61 @@ +from enum import Enum + + +class Catalyst(Enum): + Pt = 'Pt' + Fe = 'Fe' + + def __str__(self): + return str(self.value) + + +class 
Method(Enum): + DL = 'DL' + CV = 'CV' + TEM = 'TEMImageNet' + + def __str__(self): + return str(self.value) + + +class Split: + TRAIN = 'train' + VAL = 'val' + TEST = 'test' + + +class Columns: + FILENAME = 'Filename' + LABEL = 'Label' + SPLIT = 'Split' + + +class CropsColumns: + FILENAME = 'Filename' + ORIGINAL = 'Original' + X = 'X' + Y = 'Y' + LABEL = 'Label' + SPLIT = 'Split' + + +class BoxColumns: + FILENAME = 'Filename' + X1 = 'X1' + X2 = 'X2' + Y1 = 'Y1' + Y2 = 'Y2' + LABEL = 'Label' + SPLIT = 'Split' + + +class ProbsColumns: + FILENAME = 'Filename' + ORIGINAL = 'Original' + LABEL = 'Label' + SPLIT = 'Split' + + +class ModelArgs(str, Enum): + BASICCNN = 'basic' + RESNET18 = 'resnet18' diff --git a/utils/crops_visualization.py b/utils/crops_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..39e99e9e10e0fc1fdcbecc3f31b0a0f2153cd9c9 --- /dev/null +++ b/utils/crops_visualization.py @@ -0,0 +1,26 @@ +import os + +import numpy as np +import pandas as pd +from PIL import Image +from matplotlib import pyplot as plt + +from utils.constants import CropsColumns +from utils.paths import CROPS_DATASET, CROPS_PATH, CROPS_VIS_PATH + + +if not os.path.exists(CROPS_VIS_PATH): + os.makedirs(CROPS_VIS_PATH) + + +dataset_df = pd.read_csv(CROPS_DATASET) +for tif_name in dataset_df[CropsColumns.FILENAME]: + tif_filename = os.path.join(CROPS_PATH, tif_name) + img = Image.open(tif_filename) + img = np.array(img).astype(np.float32) + img = (img - img.min()) / img.max() + plt.tight_layout() + plt.imshow(img) + vis_name = "{}.jpg".format(os.path.splitext(tif_name)[0]) + vis_filename = os.path.join(CROPS_VIS_PATH, vis_name) + plt.savefig(vis_filename) diff --git a/utils/paths.py b/utils/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..08472ecdb3c9b7cd8a32188c638ba112aa4c08dd --- /dev/null +++ b/utils/paths.py @@ -0,0 +1,42 @@ +import os +from glob import glob + +PROJECT_PATH = os.path.abspath(os.path.join(__file__, *(os.path.pardir for _ in range(2)))) + +MINIO_KEYS = os.path.join(PROJECT_PATH, 'minio.json') +LOGS_PATH = os.path.join(PROJECT_PATH, 'logs') +DETECTION_LOGS = os.path.join(LOGS_PATH, 'detection_coords') +PRED_MAP_TABLE_LOGS = os.path.join(LOGS_PATH, 'pred_map_to_table') + +DATA_PATH = os.path.join(PROJECT_PATH, 'data') +IMG_PATH = os.path.join(DATA_PATH, 'tif_data') +COORDS_PATH = os.path.join(DATA_PATH, 'label_coordinates') +CROPS_PATH = os.path.join(DATA_PATH, 'atom_crops_data') +PROBS_PATH = os.path.join(DATA_PATH, 'probs_data') +BOX_PATH = os.path.join(DATA_PATH, 'box_data') +PREDS_PATH = os.path.join(DATA_PATH, 'prediction_cache') +DETECTION_PATH = os.path.join(DATA_PATH, 'detection_data') + +DATASET_PATH = os.path.join(PROJECT_PATH, 'dataset') +CROPS_DATASET = os.path.join(DATASET_PATH, 'atom_crops.csv') +PROBS_DATASET = os.path.join(DATASET_PATH, 'probs_dataset.csv') +BF_DATASET = os.path.join(DATASET_PATH, 'BF_dataset.csv') +HAADF_DATASET = os.path.join(DATASET_PATH, 'HAADF_dataset.csv') +PT_DATASET = os.path.join(DATASET_PATH, 'Pt_dataset.csv') +FE_DATASET = os.path.join(DATASET_PATH, 'Fe_dataset.csv') +BOX_DATASET = os.path.join(DATASET_PATH, 'box_dataset.csv') + +MODELS_PATH = os.path.join(PROJECT_PATH, 'models') + +DATA_VIS_PATH = os.path.join(PROJECT_PATH, 'data_vis') +CROPS_VIS_PATH = os.path.join(DATA_VIS_PATH, 'crops') +CM_VIS_PATH = os.path.join(DATA_VIS_PATH, 'cm_vis') +ORIG_VIS_PATH = os.path.join(DATA_VIS_PATH, 'orig') +PREPRO_VIS_PATH = os.path.join(DATA_VIS_PATH, 'preprocessed') +LABEL_VIS_PATH = 
os.path.join(DATA_VIS_PATH, 'label') +PRED_VIS_PATH = os.path.join(DATA_VIS_PATH, 'predictions') +PRED_GT_VIS_PATH = os.path.join(DATA_VIS_PATH, 'predictions_gt') +LANDS_VIS_PATH = os.path.join(DATA_VIS_PATH, 'landscapes') +ACTIVATIONS_VIS_PATH = os.path.join(DATA_VIS_PATH, 'activations') + +LIB_PATH = glob(f"{os.path.join(PROJECT_PATH, 'build')}/lib*")[0] diff --git a/visualizations/__init__.py b/visualizations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/visualizations/crop_images.py b/visualizations/crop_images.py new file mode 100644 index 0000000000000000000000000000000000000000..13ceba3ed8b5a5f3952015cf196928543f45b62d --- /dev/null +++ b/visualizations/crop_images.py @@ -0,0 +1,21 @@ +import os + +import numpy as np +import pandas as pd +from PIL import Image +import matplotlib.pyplot as plt + +from utils.paths import CROPS_VIS_PATH + +df = pd.read_csv("dataset/atom_crops_replicate.csv") +for crop_name in df['Filename']: + crop_filename = os.path.join("data/atom_crops_data_sac_cnn", crop_name) + crop = Image.open(crop_filename) + crop_arr = np.array(crop).astype(np.float32) + plt.figure() + plt.axis('off') + plt.imshow(crop_arr) + vis_path = os.path.join(CROPS_VIS_PATH, '{}.png'.format(os.path.splitext(crop_name)[0])) + plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0) + plt.close() + diff --git a/visualizations/dl_intermediate_layers_visualization.py b/visualizations/dl_intermediate_layers_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..eb62856c98219569714459517e880a805d2098e9 --- /dev/null +++ b/visualizations/dl_intermediate_layers_visualization.py @@ -0,0 +1,188 @@ +import os +from typing import List, Tuple, Optional, Dict + +import argparse + +from PIL import Image +import numpy as np +import torch +import torch.nn.functional +from matplotlib import pyplot as plt + +from atoms_detection.dataset import CoordinatesDataset +from atoms_detection.image_preprocessing import dl_prepro_image +from atoms_detection.model import BasicCNN +from utils.constants import ModelArgs, Split +from utils.paths import ACTIVATIONS_VIS_PATH + + +class ConvLayerVisualizer: + CONV_0 = 'Conv0' + CONV_3 = 'Conv3' + CONV_6 = 'Conv6' + + def __init__(self, model_name: ModelArgs, ckpt_filename: str): + self.model_name = model_name + self.ckpt_filename = ckpt_filename + self.device = self.get_torch_device() + self.batch_size = 64 + + self.stride = 1 + self.padding = 10 + self.window_size = (21, 21) + + @staticmethod + def get_torch_device(): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + return device + + def sliding_window(self, image: np.ndarray) -> Tuple[int, int, np.ndarray]: + # slide a window across the image + x_to_center = self.window_size[0] // 2 - 1 if self.window_size[0] % 2 == 0 else self.window_size[0] // 2 + y_to_center = self.window_size[1] // 2 - 1 if self.window_size[1] % 2 == 0 else self.window_size[1] // 2 + + for y in range(0, image.shape[0] - self.window_size[1]+1, self.stride): + for x in range(0, image.shape[1] - self.window_size[0]+1, self.stride): + # yield the current window + center_x = x + x_to_center + center_y = y + y_to_center + yield center_x, center_y, image[y:y + self.window_size[1], x:x + self.window_size[0]] + + def padding_image(self, img: np.ndarray) -> np.ndarray: + image_padded = np.zeros((img.shape[0] + self.padding*2, img.shape[1] + self.padding*2)) + 
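# e.g. with padding=10 a 512x512 image becomes a 532x532 zero canvas with the original + # image centred in it, so the sliding window can also cover border atoms. +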
image_padded[self.padding:-self.padding, self.padding:-self.padding] = img + return image_padded + + def images_to_torch_input(self, image: np.ndarray) -> torch.Tensor: + expanded_img = np.expand_dims(image, axis=(0, 1)) + input_tensor = torch.from_numpy(expanded_img).float() + input_tensor = input_tensor.to(self.device) + return input_tensor + + def load_model(self) -> BasicCNN: + checkpoint = torch.load(self.ckpt_filename, map_location=self.device) + model = BasicCNN(num_classes=2).to(self.device) + model.load_state_dict(checkpoint['state_dict']) + model.eval() + return model + + @staticmethod + def center_to_slice(x_center: int, y_center: int, width: int, height: int) -> Tuple[slice, slice]: + x_to_center = width // 2 - 1 if width % 2 == 0 else width // 2 + y_to_center = height // 2 - 1 if height % 2 == 0 else height // 2 + x = x_center - x_to_center + y = y_center - y_to_center + return slice(x, x + width), slice(y, y + height) + + def get_prediction_map(self, padded_image: np.ndarray) -> Dict[str, np.ndarray]: + _shape = padded_image.shape + convs_activations_dict = { + self.CONV_0: (np.zeros(_shape), np.zeros(_shape)), + self.CONV_3: (np.zeros(_shape), np.zeros(_shape)), + self.CONV_6: (np.zeros(_shape), np.zeros(_shape)) + } + model = self.load_model() + for x, y, image_crop in self.sliding_window(padded_image): + torch_input = self.images_to_torch_input(image_crop) + conv_outputs = self.get_conv_activations(torch_input, model) + for conv_layer_key, activations_blob in conv_outputs.items(): + activation_map = self.sum_channels(activations_blob) + h, w = activation_map.shape + x_slice, y_slice = self.center_to_slice(x, y, w, h) + convs_activations_dict[conv_layer_key][0][y_slice, x_slice] += 1 + convs_activations_dict[conv_layer_key][1][y_slice, x_slice] += activation_map + + activations_dict = {} + for conv_layer_key, (counting_map, output_map) in convs_activations_dict.items(): + zero_rows = np.sum(counting_map, axis=1) + zero_cols = np.sum(counting_map, axis=0) + + output_map = np.delete(output_map, np.where(zero_rows == 0), axis=0) + clean_output_map = np.delete(output_map, np.where(zero_cols == 0), axis=1) + counting_map = np.delete(counting_map, np.where(zero_rows == 0), axis=0) + clean_counting_map = np.delete(counting_map, np.where(zero_cols == 0), axis=1) + + activations_dict[conv_layer_key] = clean_output_map / clean_counting_map + + return activations_dict + + def get_conv_activations(self, input_image: torch.Tensor, model: BasicCNN) -> Dict[str, np.ndarray]: + conv_activations = {} + activations = input_image + for i, layer in enumerate(model.features): + activations = layer(activations) + if i == 0: + conv_activations[self.CONV_0] = activations.squeeze(0).detach().cpu().numpy() + elif i == 3: + conv_activations[self.CONV_3] = activations.squeeze(0).detach().cpu().numpy() + elif i == 6: + conv_activations[self.CONV_6] = activations.squeeze(0).detach().cpu().numpy() + + return conv_activations + + @staticmethod + def sum_channels(activations: np.ndarray): + aggregated_activations = np.sum(activations, axis=0) + return aggregated_activations + + def image_to_pred_map(self, img: np.ndarray) -> Dict[str, np.ndarray]: + preprocessed_img = dl_prepro_image(img) + padded_image = self.padding_image(preprocessed_img) + activations_dict = self.get_prediction_map(padded_image) + return activations_dict + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "architecture", + type=ModelArgs, + choices=ModelArgs, + help="Architecture name" + ) + 
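# Typical invocation (illustrative paths; positionals defined here and below): + # python visualizations/dl_intermediate_layers_visualization.py basic models/model.ckpt dataset/Pt_dataset.csv +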
parser.add_argument( + "ckpt_filename", + type=str, + help="Path to model checkpoint" + ) + parser.add_argument( + "coords_csv", + type=str, + help="Coordinates CSV file to use as input" + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + print(args) + + conv_visualizer = ConvLayerVisualizer( + model_name=args.architecture, + ckpt_filename=args.ckpt_filename + ) + + coordinates_dataset = CoordinatesDataset(args.coords_csv) + for image_path, coordinates_path in coordinates_dataset.iterate_data(Split.TEST): + img = Image.open(image_path) + np_img = np.array(img) + activations_dict = conv_visualizer.image_to_pred_map(np_img) + + img_name = os.path.splitext(os.path.basename(image_path))[0] + + output_folder = os.path.join(ACTIVATIONS_VIS_PATH, f"{img_name}") + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + for conv_layer_key, activation_map in activations_dict.items(): + fig = plt.figure() + plt.title(f"{conv_layer_key} -- {img_name}") + plt.imshow(activation_map) + + output_path = os.path.join(output_folder, f"{conv_layer_key}_{img_name}.png") + plt.savefig(output_path, bbox_inches='tight') + plt.close(fig) + + + diff --git a/visualizations/label_images.py b/visualizations/label_images.py new file mode 100644 index 0000000000000000000000000000000000000000..463aaea5e08c8f44bef3aae37d22e0578fb487c2 --- /dev/null +++ b/visualizations/label_images.py @@ -0,0 +1,32 @@ +import os + +import numpy as np +import pandas as pd +from PIL import Image +import matplotlib.pyplot as plt + +from atoms_detection.dataset import CoordinatesDataset +from utils.paths import PT_DATASET, IMG_PATH, COORDS_PATH, LABEL_VIS_PATH + +df = pd.read_csv(PT_DATASET) +for idx, row in df.iterrows(): +# row = list(df.iterrows())[0][1] + image_name = row['Image'] + coords_name = row['Coords'] + image_filename = os.path.join(IMG_PATH, image_name) + coords_filename = os.path.join(COORDS_PATH, coords_name) + img = Image.open(image_filename) + atom_coordinates = pd.read_csv(coords_filename) + x, y = atom_coordinates['X'], atom_coordinates['Y'] + # coords = CoordinatesDataset.load_coordinates(coords_filename) + + img_arr = np.array(img).astype(np.float32) + img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min()) + plt.figure(figsize=(3, 3)) + plt.axis('off') + plt.imshow(img_normed) + plt.scatter(x, y, s=80, linewidths=1.5, c='#FFDB1A', marker='+') + vis_path = os.path.join(LABEL_VIS_PATH, '{}.png'.format(os.path.splitext(image_name)[0])) + # plt.show() + plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True) + plt.close() diff --git a/visualizations/performance_by_threshold.py b/visualizations/performance_by_threshold.py new file mode 100644 index 0000000000000000000000000000000000000000..11495bd03b33a59b8b09cc9ebef6b321614d36f0 --- /dev/null +++ b/visualizations/performance_by_threshold.py @@ -0,0 +1,80 @@ +import argparse +import os + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +from utils.paths import LOGS_PATH, DATA_VIS_PATH, DATA_PATH + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "extension_name", + type=str, + help="Experiment extension name" + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + extension_name = args.extension_name + + thresholds = np.array([0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]) + f1_mean, f1_std = [], [] + precision_mean, precision_std = [], [] + 
diff --git a/visualizations/performance_by_threshold.py b/visualizations/performance_by_threshold.py
new file mode 100644
index 0000000000000000000000000000000000000000..11495bd03b33a59b8b09cc9ebef6b321614d36f0
--- /dev/null
+++ b/visualizations/performance_by_threshold.py
@@ -0,0 +1,80 @@
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from utils.paths import LOGS_PATH, DATA_VIS_PATH, DATA_PATH
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "extension_name",
+        type=str,
+        help="Experiment extension name"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    extension_name = args.extension_name
+
+    thresholds = np.array([0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])
+    f1_mean, f1_std = [], []
+    precision_mean, precision_std = [], []
+    recall_mean, recall_std = [], []
+
+    csv_pattern = os.path.join(LOGS_PATH, f"dl_evaluation_{extension_name}", f"dl_evaluation_{extension_name}_{{}}.csv")
+    for threshold in thresholds:
+        performance_csv_filename = csv_pattern.format(threshold)
+        perf_df = pd.read_csv(performance_csv_filename)
+
+        mean_row = perf_df.iloc[-2]
+        std_row = perf_df.iloc[-1]
+
+        # Precision, Recall, F1Score
+        f1_mean.append(mean_row['F1Score'])
+        f1_std.append(std_row['F1Score'])
+        precision_mean.append(mean_row['Precision'])
+        precision_std.append(std_row['Precision'])
+        recall_mean.append(mean_row['Recall'])
+        recall_std.append(std_row['Recall'])
+
+    f1_mean, f1_std = np.array(f1_mean), np.array(f1_std)
+    precision_mean, precision_std = np.array(precision_mean), np.array(precision_std)
+    recall_mean, recall_std = np.array(recall_mean), np.array(recall_std)
+
+    df_to_save = pd.DataFrame({'threshold': thresholds,
+                               'f1score_mean': f1_mean, 'f1score_std': f1_std,
+                               'precision_mean': precision_mean, 'precision_std': precision_std,
+                               'recall_mean': recall_mean, 'recall_std': recall_std})
+    csv_filename = os.path.join(DATA_PATH, f"performance_threshold_{extension_name}.csv")
+    df_to_save.to_csv(csv_filename, index=False)
+
+    plt.figure()
+    plt.plot(thresholds, f1_mean, color='k', linestyle='-', label='F1Score')
+    plt.plot(thresholds, precision_mean, color='k', linestyle='--', label='Precision')
+    plt.plot(thresholds, recall_mean, color='k', linestyle=':', label='Recall')
+
+    f1_high, f1_low = f1_mean + f1_std, f1_mean - f1_std
+    plt.fill_between(thresholds, f1_high, f1_low, where=f1_high >= f1_low, facecolor='#fccfcf', interpolate=True, alpha=0.5)
+
+    precision_high, precision_low = precision_mean + precision_std, precision_mean - precision_std
+    plt.fill_between(thresholds, precision_high, precision_low, where=precision_high >= precision_low, facecolor='#cfeffc', interpolate=True, alpha=0.5)
+
+    recall_high, recall_low = recall_mean + recall_std, recall_mean - recall_std
+    plt.fill_between(thresholds, recall_high, recall_low, where=recall_high >= recall_low, facecolor='#d6ffd1', interpolate=True, alpha=0.5)
+
+    plt.xlabel('Threshold')
+    plt.xticks(thresholds[1::2])
+    plt.yticks(np.arange(0.1, 1, 0.1))
+    plt.ylim(0, 1)
+
+    plt.grid(alpha=0.3)
+
+    plt.legend()
+    plot_filename = os.path.join(DATA_VIS_PATH, f"performance_threshold_{extension_name}.png")
+    plt.savefig(plot_filename, bbox_inches='tight', pad_inches=0.0)
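Once the script above has written `performance_threshold_<extension>.csv`, choosing an operating threshold is a one-liner. A sketch, assuming the `replicate` extension name used elsewhere in this patch; the column names match what the script saves:

```python
# Pick the threshold maximizing mean F1 from the saved summary CSV.
import pandas as pd

perf = pd.read_csv("performance_threshold_replicate.csv")  # hypothetical path
best = perf.loc[perf["f1score_mean"].idxmax()]
print(f"best threshold: {best['threshold']:.2f} "
      f"(F1 = {best['f1score_mean']:.3f} +/- {best['f1score_std']:.3f})")
```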
diff --git a/visualizations/performance_by_threshold_kr.py b/visualizations/performance_by_threshold_kr.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2eb34e88cdbdfc758e5ff82bb394c382d4595b6
--- /dev/null
+++ b/visualizations/performance_by_threshold_kr.py
@@ -0,0 +1,59 @@
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from utils.paths import LOGS_PATH, DATA_VIS_PATH, DATA_PATH
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "extension_name",
+        type=str,
+        help="Experiment extension name"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    extension_name = args.extension_name
+
+    # thresholds = np.array([0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])
+
+    thresholds = np.array([0.89])
+    f1_mean, f1_std = [], []
+    precision_mean, precision_std = [], []
+    recall_mean, recall_std = [], []
+
+    csv_pattern = os.path.join(LOGS_PATH, f"dl_evaluation_{extension_name}", f"dl_evaluation_{extension_name}_{{}}.csv")
+    for threshold in thresholds:
+        performance_csv_filename = csv_pattern.format(threshold)
+        perf_df = pd.read_csv(performance_csv_filename)
+
+        mean_row = perf_df.iloc[-2]
+        std_row = perf_df.iloc[-1]
+
+        # Precision, Recall, F1Score
+        f1_mean.append(mean_row['F1Score'])
+        f1_std.append(std_row['F1Score'])
+        precision_mean.append(mean_row['Precision'])
+        precision_std.append(std_row['Precision'])
+        recall_mean.append(mean_row['Recall'])
+        recall_std.append(std_row['Recall'])
+
+    f1_mean, f1_std = np.array(f1_mean), np.array(f1_std)
+    precision_mean, precision_std = np.array(precision_mean), np.array(precision_std)
+    recall_mean, recall_std = np.array(recall_mean), np.array(recall_std)
+
+    print(f1_mean, precision_mean, recall_mean)
+
+    # df_to_save = pd.DataFrame({'threshold': thresholds,
+    #                            'f1score_mean': f1_mean, 'f1score_std': f1_std,
+    #                            'precision_mean': precision_mean, 'precision_std': precision_std,
+    #                            'recall_mean': recall_mean, 'recall_std': recall_std})
+    # csv_filename = os.path.join(DATA_PATH, f"performance_threshold_{extension_name}.csv")
+    # df_to_save.to_csv(csv_filename, index=False)
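`performance_by_threshold_kr.py` is a single-threshold variant of the previous script, with the sweep commented out and `0.89` hard-coded. If editing the source to change the value becomes a nuisance, the thresholds could instead be taken from the command line; a sketch only, where the `--thresholds` flag is an assumption and not part of this patch:

```python
# Hypothetical CLI for the single-threshold check (not in the repo).
import argparse
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("extension_name", type=str, help="Experiment extension name")
parser.add_argument("--thresholds", type=float, nargs="+", default=[0.89],
                    help="Detection thresholds to report")
args = parser.parse_args()
thresholds = np.array(args.thresholds)
```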
+ """ + x_scale = 1 + y_scale = 1 + z_scale = 0.1 + + scale = np.diag([x_scale, y_scale, z_scale, 1.0]) + scale = scale * (1.0 / scale.max()) + scale[3, 3] = 1.0 + + X = np.arange(0, 512, 1) + Y = np.arange(0, 512, 1) + X, Y = np.meshgrid(X, Y) + + # fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) + fig = plt.figure(figsize=(10, 10)) + ax = fig.gca(projection='3d') + ax.get_proj = short_proj + surf = ax.plot_surface(X, Y, pred_map, cmap=cm.coolwarm, + rstride=2, cstride=2, + linewidth=0.2, antialiased=True) + + ax.set_axis_off() + + img_name = os.path.splitext(os.path.basename(image_path))[0] + landscape_output_path = os.path.join(LANDS_VIS_PATH, f"{img_name}_landscape_{extension_name}_{threshold}.png") + plt.savefig(landscape_output_path, bbox_inches='tight', pad_inches=0.0, transparent=True) + # plt.show() diff --git a/visualizations/pred_map_to_table.py b/visualizations/pred_map_to_table.py new file mode 100644 index 0000000000000000000000000000000000000000..dbdce2e758400f99fbc93c21795d2597a615f529 --- /dev/null +++ b/visualizations/pred_map_to_table.py @@ -0,0 +1,55 @@ +import os +from hashlib import sha1 + +import numpy as np +import pandas as pd + +from atoms_detection.dl_detection import DLDetection +from atoms_detection.dataset import CoordinatesDataset +from utils.constants import Split, ModelArgs +from utils.paths import PT_DATASET, PREDS_PATH, DETECTION_PATH, PRED_MAP_TABLE_LOGS + + +threshold = 0.89 +extension_name = "replicate" +detections_path = os.path.join(DETECTION_PATH, f"dl_detection_{extension_name}_{threshold}") +inference_cache_path = os.path.join(PREDS_PATH, os.path.basename(detections_path)) + + +def get_pred_map(img_filename: str) -> np.ndarray: + img_hash = sha1(img_filename.encode()).hexdigest() + prediciton_cache = os.path.join(inference_cache_path, f"{img_hash}.npy") + if not os.path.exists(prediciton_cache): + detection = DLDetection( + model_name=ModelArgs.BASICCNN, + ckpt_filename="/home/fpares/PycharmProjects/stem_atoms/models/basic_replicate.ckpt", + dataset_csv="/home/fpares/PycharmProjects/stem_atoms/dataset/Coordinate_image_pairs.csv", + threshold=threshold, + detections_path=detections_path + ) + img = DLDetection.open_image(image_path) + pred_map = detection.image_to_pred_map(img) + np.save(prediciton_cache, pred_map) + else: + pred_map = np.load(prediciton_cache) + return pred_map + + +if not os.path.exists(PRED_MAP_TABLE_LOGS): + os.makedirs(PRED_MAP_TABLE_LOGS) + +coordinates_dataset = CoordinatesDataset(PT_DATASET) +for image_path, coordinates_path in coordinates_dataset.iterate_data(Split.TEST): + pred_map = get_pred_map(image_path) + + pred_table = {'X': [], 'Y': [], 'Z': []} + for index, likelihood in np.ndenumerate(pred_map): + pred_table['X'].append(index[0]) + pred_table['Y'].append(index[1]) + pred_table['Z'].append(likelihood) + + pred_df = pd.DataFrame(pred_table) + + img_name = os.path.splitext(os.path.basename(image_path))[0] + pred_table_output_path = os.path.join(PRED_MAP_TABLE_LOGS, f"{img_name}_likelihood_{extension_name}_{threshold}.csv") + pred_df.to_csv(pred_table_output_path, index=False) diff --git a/visualizations/prediction_gt_images.py b/visualizations/prediction_gt_images.py new file mode 100644 index 0000000000000000000000000000000000000000..6803213d8cef21862f4d57e972820da1b7e7b49a --- /dev/null +++ b/visualizations/prediction_gt_images.py @@ -0,0 +1,110 @@ +import argparse +import os +from enum import Enum + +import numpy as np +import pandas as pd +from PIL import Image +import matplotlib.pyplot as plt + +from 
diff --git a/visualizations/prediction_gt_images.py b/visualizations/prediction_gt_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..6803213d8cef21862f4d57e972820da1b7e7b49a
--- /dev/null
+++ b/visualizations/prediction_gt_images.py
@@ -0,0 +1,110 @@
+import argparse
+import os
+from enum import Enum
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+import matplotlib.pyplot as plt
+
+from atoms_detection.dataset import CoordinatesDataset
+from utils.constants import Split, Catalyst, Method
+from utils.paths import DETECTION_LOGS, IMG_PATH, PRED_GT_VIS_PATH, PT_DATASET, FE_DATASET, DETECTION_PATH
+from visualizations.utils import plot_gt_pred_on_img
+
+
+def main(args):
+    catalyst = args.catalyst
+    method = args.method
+    if not os.path.exists(PRED_GT_VIS_PATH):
+        os.makedirs(PRED_GT_VIS_PATH)
+    if catalyst == Catalyst.Pt:
+        coordinates_dataset = CoordinatesDataset(PT_DATASET)
+        if method == Method.DL:
+            detection_path = "data/detection_data/dl_detection_sac_cnn/dl_detection_sac_cnn_0.89"
+        elif method == Method.CV:
+            detection_path = os.path.join(DETECTION_PATH, "cv_detection_trial_0.18")
+        elif method == Method.TEM:
+            detection_path = os.path.join(DETECTION_PATH, "tem_imagenet_pt",
+                                          "tem_imagenet_pt_denoise-bg_Gen1GaussianMask")
+        else:
+            raise NotImplementedError
+
+    elif catalyst == Catalyst.Fe:
+        coordinates_dataset = CoordinatesDataset(FE_DATASET)
+        if method == Method.DL:
+            detection_path = os.path.join(DETECTION_PATH, "dl_fe_detection_trial",
+                                          "dl_fe_detection_trial_0.97")
+        elif method == Method.CV:
+            detection_path = os.path.join(DETECTION_PATH, "cv_fe_detection_trial",
+                                          "cv_fe_detection_trial_0.21")
+        elif method == Method.TEM:
+            detection_path = os.path.join(DETECTION_PATH, "tem_imagenet_fe",
+                                          "tem_imagenet_fe_denoise-bg_Gen1GaussianMask")
+        else:
+            raise NotImplementedError
+
+    else:
+        raise NotImplementedError
+
+    gt_coords_dict = get_gt_coords(coordinates_dataset)
+    for name_file in os.listdir(detection_path):
+        image_name = os.path.splitext(name_file)[0] + ".tif"
+        print(image_name)
+        if image_name not in gt_coords_dict:
+            continue
+
+        filepath = os.path.join(detection_path, name_file)
+        image_filename = os.path.join(IMG_PATH, image_name)
+        img = Image.open(image_filename)
+
+        gt_coords = gt_coords_dict[image_name]
+        df_predicted = pd.read_csv(filepath)
+        pred_coords = [(row['x'], row['y']) for _, row in df_predicted.iterrows()]
+
+        img_arr = np.array(img).astype(np.float32)
+        img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())
+
+        plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
+
+        vis_folder = os.path.join(PRED_GT_VIS_PATH, f"{catalyst}-Catalyst_{method}-Method")
+        if not os.path.exists(vis_folder):
+            os.makedirs(vis_folder)
+
+        clean_image_name = os.path.splitext(image_name)[0]
+        vis_path = os.path.join(vis_folder, f'{clean_image_name}.png')
+        plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
+        plt.close()
+
+
+def get_gt_coords(coordinates_dataset):
+    gt_coords_dict = {}
+    for image_path, coordinates_path in coordinates_dataset.iterate_data(Split.TEST):
+        # orig: image_name = os.path.splitext(os.path.basename(image_path))[0] + ".tif"
+        image_name = os.path.basename(image_path)
+        gt_coords = coordinates_dataset.load_coordinates(coordinates_path)
+        gt_coords_dict[image_name] = gt_coords
+    return gt_coords_dict
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "catalyst",
+        type=Catalyst,
+        choices=Catalyst,
+        help="Select data by catalyst"
+    )
+    parser.add_argument(
+        "method",
+        type=Method,
+        choices=Method,
+        help="Select method"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
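The overlay images make disagreements visible but do not quantify them. For a quick numeric sanity check alongside the plots, one could count predictions falling within a pixel tolerance of a ground-truth atom. An illustrative sketch only, not the repository's evaluation code; the function name and the 5-pixel default are assumptions:

```python
# Illustrative agreement check between predicted and ground-truth coordinates.
import numpy as np

def fraction_matched(pred_coords, gt_coords, tol=5.0):
    """Fraction of predictions within `tol` pixels of some ground-truth atom."""
    pred = np.asarray(pred_coords, dtype=float)
    gt = np.asarray(gt_coords, dtype=float)
    if len(pred) == 0 or len(gt) == 0:
        return 0.0
    # Pairwise distances between (N, 2) predictions and (M, 2) ground truth.
    dists = np.linalg.norm(pred[:, None, :] - gt[None, :, :], axis=-1)
    return float((dists.min(axis=1) <= tol).mean())
```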
diff --git a/visualizations/prediction_images.py b/visualizations/prediction_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1b3f9f17ad7d26e1a837f2c49df6bc31d2afac1
--- /dev/null
+++ b/visualizations/prediction_images.py
@@ -0,0 +1,37 @@
+import os
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+import matplotlib.pyplot as plt
+
+from atoms_detection.dataset import CoordinatesDataset
+from utils.constants import Split
+from utils.paths import DETECTION_LOGS, IMG_PATH, PRED_VIS_PATH, PT_DATASET
+
+
+if __name__ == "__main__":
+    for name_file in os.listdir(DETECTION_LOGS):
+        print(name_file)
+        filepath = os.path.join(DETECTION_LOGS, name_file)
+
+        image_name = os.path.splitext(name_file)[0] + ".tif"
+        image_filename = os.path.join(IMG_PATH, image_name)
+        img = Image.open(image_filename)
+
+        df = pd.read_csv(filepath)
+        x, y = [], []
+        for idx, row in df.iterrows():
+            x.append(row['x'])
+            y.append(row['y'])
+
+        img_arr = np.array(img).astype(np.float32)
+        img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())
+        plt.figure(figsize=(10, 10))
+        plt.axis('off')
+        plt.imshow(img_normed)
+        plt.scatter(x, y, s=300, linewidths=3, c='#FFDB1A', marker='+')
+
+        vis_path = os.path.join(PRED_VIS_PATH, '{}.png'.format(os.path.splitext(image_name)[0]))
+        plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
+        plt.close()
diff --git a/visualizations/prepro_images.py b/visualizations/prepro_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f02a09925d9ebec2f71eeacbd6c45e01dc33aea
--- /dev/null
+++ b/visualizations/prepro_images.py
@@ -0,0 +1,32 @@
+import os
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+import matplotlib.pyplot as plt
+
+from atoms_detection.image_preprocessing import dl_prepro_image
+from utils.paths import PT_DATASET, FE_DATASET, IMG_PATH, PREPRO_VIS_PATH
+
+
+def generate_prepro_plots(dataset_path: str, vis_folder: str):
+    if not os.path.exists(vis_folder):
+        os.makedirs(vis_folder)
+
+    df = pd.read_csv(dataset_path)
+    for image_name in df['Filename']:
+        image_filename = os.path.join(IMG_PATH, image_name)
+        img = Image.open(image_filename)
+        np_img = np.array(img).astype(np.float32)
+        np_prepro_img = dl_prepro_image(np_img)
+        plt.figure()
+        plt.axis('off')
+        plt.imshow(np_prepro_img)
+        vis_path = os.path.join(vis_folder, '{}.png'.format(os.path.splitext(image_name)[0]))
+        plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0)
+        plt.close()
+
+
+if __name__ == "__main__":
+    generate_prepro_plots(PT_DATASET, os.path.join(PREPRO_VIS_PATH, 'Pt-Catalyst'))
+    generate_prepro_plots(FE_DATASET, os.path.join(PREPRO_VIS_PATH, 'Fe-Catalyst'))
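The same min-max normalization `(arr - min) / (max - min)` is repeated in `label_images.py`, `prediction_gt_images.py`, `prediction_images.py`, and `tif_images.py`, and it divides by zero on a constant image. A shared helper could guard that case; a sketch, hypothetical and not part of this patch:

```python
# Shared min-max normalization with a guard for constant images.
import numpy as np

def min_max_normalize(img_arr: np.ndarray) -> np.ndarray:
    """Scale an image to [0, 1]; constant images map to all zeros."""
    img_arr = img_arr.astype(np.float32)
    value_range = img_arr.max() - img_arr.min()
    if value_range == 0:
        return np.zeros_like(img_arr)
    return (img_arr - img_arr.min()) / value_range
```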
diff --git a/visualizations/tif_images.py b/visualizations/tif_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aded89941ae49d9d5950ca2e1a96a41dff6c810
--- /dev/null
+++ b/visualizations/tif_images.py
@@ -0,0 +1,23 @@
+import os
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+import matplotlib.pyplot as plt
+
+from utils.paths import HAADF_DATASET, IMG_PATH, ORIG_VIS_PATH
+
+
+if __name__ == "__main__":
+    df = pd.read_csv(HAADF_DATASET)
+    for image_name in df['Filename']:
+        image_filename = os.path.join(IMG_PATH, image_name)
+        img = Image.open(image_filename)
+        img_arr = np.array(img).astype(np.float32)
+        img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())
+        plt.figure()
+        plt.axis('off')
+        plt.imshow(img_normed)
+        vis_path = os.path.join(ORIG_VIS_PATH, '{}.png'.format(os.path.splitext(image_name)[0]))
+        plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0)
+        plt.close()
diff --git a/visualizations/utils.py b/visualizations/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..618fa3c7fb76fe49b1a9c20bb0e1c2d4f9ba0707
--- /dev/null
+++ b/visualizations/utils.py
@@ -0,0 +1,20 @@
+
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+
+from atoms_detection.image_preprocessing import preprocess_jpg
+from utils.constants import Split
+
+
+def plot_gt_pred_on_img(img_normed, gt_coords, pred_coords):
+    imgsize = img_normed.shape[0] / 512
+    plt.figure(figsize=(imgsize * 10, imgsize * 10))
+    plt.axis('off')
+    plt.imshow(img_normed)
+    if pred_coords is not None:
+        x, y = zip(*pred_coords)
+        plt.scatter(x, y, s=300, linewidths=3, c='#FFDB1A', marker='+')
+    if gt_coords is not None:
+        gt_x, gt_y = zip(*gt_coords)
+        plt.scatter(gt_x, gt_y, s=300, linewidths=2, facecolors='none', edgecolors='r')
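A minimal usage sketch for `plot_gt_pred_on_img`, with synthetic data standing in for a normalized HAADF frame and its detections (coordinates below are made up for illustration):

```python
# Overlay synthetic ground truth (red circles) and predictions (yellow crosses).
import numpy as np
import matplotlib.pyplot as plt
from visualizations.utils import plot_gt_pred_on_img

rng = np.random.default_rng(0)
img_normed = rng.random((512, 512)).astype(np.float32)  # stand-in image in [0, 1]
gt_coords = [(100, 120), (300, 310)]    # ground-truth atom positions
pred_coords = [(102, 118), (450, 40)]   # detections: one match, one false positive
plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
plt.savefig("overlay_example.png", bbox_inches='tight')
```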