Romain Graux committed on
Commit
b2ffc9b
1 Parent(s): 12414cb

Initial commit with ml code and webapp

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. .gitignore +163 -0
  2. .vscode/settings.json +5 -0
  3. README.md +11 -0
  4. app/app.py +373 -0
  5. app/assets/ETH_Zurich_Logo_black.svg +58 -0
  6. app/assets/logo-ace.png +0 -0
  7. app/backup_tiff_utils.py +55 -0
  8. app/dl_inference.py +187 -0
  9. app/knn.py +203 -0
  10. app/logger.py +14 -0
  11. app/tiff_utils.py +42 -0
  12. atoms_detection/README.md +12 -0
  13. atoms_detection/__init__.py +0 -0
  14. atoms_detection/create_crop_dataset.py +408 -0
  15. atoms_detection/create_crop_dataset_1024.py +197 -0
  16. atoms_detection/create_crop_dataset_2048.py +197 -0
  17. atoms_detection/create_crop_dataset_512.py +190 -0
  18. atoms_detection/cv_detection.py +31 -0
  19. atoms_detection/cv_fe_detection_evaluation.py +37 -0
  20. atoms_detection/cv_full_pipeline.py +71 -0
  21. atoms_detection/dataset.py +265 -0
  22. atoms_detection/detection.py +96 -0
  23. atoms_detection/dl_contrastive_pipeline.py +198 -0
  24. atoms_detection/dl_detection.py +118 -0
  25. atoms_detection/dl_detection_evaluation.py +130 -0
  26. atoms_detection/dl_detection_scaled.py +42 -0
  27. atoms_detection/dl_detection_with_gmm.py +151 -0
  28. atoms_detection/dl_full_pipeline.py +96 -0
  29. atoms_detection/evaluation.py +254 -0
  30. atoms_detection/fast_filters.cpp +83 -0
  31. atoms_detection/fast_filters.py +103 -0
  32. atoms_detection/image_preprocessing.py +83 -0
  33. atoms_detection/model.py +108 -0
  34. atoms_detection/multimetallic_analysis.py +229 -0
  35. atoms_detection/testing_model.py +53 -0
  36. atoms_detection/training.py +145 -0
  37. atoms_detection/training_model.py +111 -0
  38. atoms_detection/vae_image_utils.py +53 -0
  39. atoms_detection/vae_model.py +345 -0
  40. atoms_detection/vae_svi_train.py +121 -0
  41. requirements.txt +12 -0
  42. setup.py +28 -0
  43. utils/__init__.py +0 -0
  44. utils/cf_matrix.py +111 -0
  45. utils/constants.py +61 -0
  46. utils/crops_visualization.py +26 -0
  47. utils/paths.py +42 -0
  48. visualizations/__init__.py +0 -0
  49. visualizations/crop_images.py +21 -0
  50. visualizations/dl_intermediate_layers_visualization.py +188 -0
.gitignore ADDED
@@ -0,0 +1,163 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ # *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ data
+ models
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "[python]": {
+         "editor.defaultFormatter": "ms-python.black-formatter"
+     }
+ }
README.md ADDED
@@ -0,0 +1,11 @@
+ # Atoms Detection
+
+ This repository contains the code and webapp for the publication *[Quantitative Description of Metal Center Organization and Interactions in Single-Atom Catalysts](https://doi.org/10.1002/adma.202307991)*.
+
+ ## Reference
+
+ To cite this work, please use the following:
+
+ ```
+ K. Rossi, A. Ruiz-Ferrando, D. F. Akl, V. G. Abalos, J. Heras-Domingo, R. Graux, X. Hai, J. Lu, D. Garcia-Gasulla, N. López, J. Pérez-Ramírez, S. Mitchell, Quantitative Description of Metal Center Organization and Interactions in Single-Atom Catalysts. Adv. Mater. 2024, 36, 2307991. https://doi.org/10.1002/adma.202307991
+ ```
app/app.py ADDED
@@ -0,0 +1,373 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ @author : Romain Graux
+ @date : 2023 April 25, 14:39:03
+ @last modified : 2023 September 20, 15:35:23
+ """
+
+ # TODO : add the training of the vae
+ # TODO : add the description of the settings
+
+ import sys
+
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import gradio as gr
+ from tiff_utils import extract_physical_metadata
+ from dl_inference import inference_fn
+ from knn import knn, segment_image, bokeh_plot_knn, color_palette
+
+ import tempfile
+ import shutil
+ import json
+ from zipfile import ZipFile
+ from datetime import datetime
+
+ from collections import namedtuple
+
+ block_state_entry = namedtuple(
+     "block_state", ["results", "knn_results", "physical_metadata"]
+ )
+
+ if ".." not in sys.path:
+     sys.path.append("..")
+
+ from utils.constants import ModelArgs
+
+
+ def inf(img, n_species, threshold, architecture):
+     # Get the coordinates of the atoms
+     img, results = inference_fn(architecture, img, threshold, n_species=n_species)
+     draw = ImageDraw.Draw(img)
+     for (k, v), color in zip(results["species"].items(), color_palette):
+         color = "#" + "".join([f"{int(255 * x):02x}" for x in color])
+         draw.text((5, 5 + 15 * k), f"species {k}", fill=color)
+         for x, y in v["coords"]:
+             draw.ellipse(
+                 [x - 5, y - 5, x + 5, y + 5],
+                 outline=color,
+                 width=2,
+             )
+     return img, results
+
+
+ def batch_fn(files, n_species, threshold, architecture, block_state):
+     block_state = {}
+     if not files:
+         raise ValueError("No files were uploaded")
+
+     gallery = []
+     for file in files:
+         error_physical_metadata = None
+         try:
+             physical_metadata = extract_physical_metadata(file.name)
+             if physical_metadata.unit != "nm":
+                 raise ValueError(f"Unit of {file.name} is not nm, cannot process it")
+         except Exception as e:
+             error_physical_metadata = e
+             physical_metadata = None
+
+         original_file_name = file.name.split("/")[-1]
+         img, results = inf(file.name, n_species, threshold, architecture)
+         mask = segment_image(file.name)
+         gallery.append((img, original_file_name))
+
+         if physical_metadata is not None:
+             factor = 1.0 - np.mean(mask)
+             scale = physical_metadata.pixel_width
+             edge = physical_metadata.pixel_width * physical_metadata.width
+             knn_results = {
+                 k: knn(results["species"][k]["coords"], scale, factor, edge)
+                 for k in results["species"]
+             }
+         else:
+             knn_results = None
+
+         block_state[original_file_name] = block_state_entry(
+             results, knn_results, physical_metadata
+         )
+
+     knn_args = [
+         (
+             original_file_name,
+             {
+                 k: block_state[original_file_name].knn_results[k]["distances"]
+                 for k in block_state[original_file_name].knn_results
+             },
+         )
+         for original_file_name in block_state
+         if block_state[original_file_name].knn_results is not None
+     ]
+     if len(knn_args) > 0:
+         bokeh_plot = gr.update(
+             value=bokeh_plot_knn(knn_args, with_cumulative=True), visible=True
+         )
+     else:
+         bokeh_plot = gr.update(visible=False)
+     return (
+         gallery,
+         block_state,
+         gr.update(visible=True),
+         bokeh_plot,
+         gr.HTML.update(
+             value=f"<p style='width:fit-content; background-color:rgba(255, 0, 0, 0.75); border-radius:5px; padding:5px; color:white;'>{error_physical_metadata}</p>",
+             visible=bool(error_physical_metadata),
+         ),
+     )
+
+
+ class NumpyEncoder(json.JSONEncoder):
+     """Special json encoder for numpy types"""
+
+     def default(self, obj):
+         if isinstance(obj, np.integer):
+             return int(obj)
+         elif isinstance(obj, np.floating):
+             return float(obj)
+         elif isinstance(obj, np.ndarray):
+             return obj.tolist()
+         return json.JSONEncoder.default(self, obj)
+
+
+ def batch_export_files(gallery, block_state):
+     # Return images, coords as csv and a zip containing everything
+     files = []
+     tmpdir = tempfile.mkdtemp()
+     with ZipFile(
+         f"{tmpdir}/all_results_{datetime.now().isoformat()}.zip", "w"
+     ) as zipObj:
+         # Add all metadata
+         for data_dict, original_file_name in gallery:
+             file_name = original_file_name.split(".")[0]
+
+             # Save the image
+             pred_map_path = f"{tmpdir}/pred_map_{file_name}.png"
+             file_path = data_dict["name"]
+             shutil.copy(file_path, pred_map_path)
+             zipObj.write(pred_map_path, arcname=f"{file_name}/pred_map.png")
+             files.append(pred_map_path)
+
+             # Save the coords
+             results = block_state[original_file_name].results
+             coords_path = f"{tmpdir}/coords_{file_name}.csv"
+             with open(coords_path, "w") as f:
+                 f.write("x,y,likelihood,specie,confidence\n")
+                 for k, v in results["species"].items():
+                     for (x, y), likelihood, confidence in zip(
+                         v["coords"], v["likelihood"], v["confidence"]
+                     ):
+                         f.write(f"{x},{y},{likelihood},{k},{confidence}\n")
+             zipObj.write(coords_path, arcname=f"{file_name}/coords.csv")
+             files.append(coords_path)
+
+             # Save the knn results
+             if block_state[original_file_name].knn_results is not None:
+                 knn_results = block_state[original_file_name].knn_results
+                 knn_path = f"{tmpdir}/knn_results_{file_name}.json"
+                 with open(knn_path, "w") as f:
+                     json.dump(knn_results, f, cls=NumpyEncoder)
+                 zipObj.write(knn_path, arcname=f"{file_name}/knn_results.json")
+                 files.append(knn_path)
+
+             # Save the physical metadata
+             if block_state[original_file_name].physical_metadata is not None:
+                 physical_metadata = block_state[original_file_name].physical_metadata
+                 metadata_path = f"{tmpdir}/physical_metadata_{file_name}.json"
+                 with open(metadata_path, "w") as f:
+                     json.dump(physical_metadata._asdict(), f, cls=NumpyEncoder)
+                 zipObj.write(
+                     metadata_path, arcname=f"{file_name}/physical_metadata.json"
+                 )
+                 files.append(metadata_path)
+
+     files.append(zipObj.filename)
+     return gr.update(value=files, visible=True)
+
+
+ CSS = """
+ .header {
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     padding: var(--block-padding);
+     border-radius: var(--block-radius);
+     background: var(--button-secondary-background-hover);
+ }
+
+ img {
+     width: 150px;
+     margin-right: 40px;
+ }
+
+ .title {
+     text-align: left;
+ }
+
+ h1 {
+     font-size: 36px;
+     margin-bottom: 10px;
+ }
+
+ p {
+     font-size: 18px;
+ }
+
+ input {
+     width: 70px;
+ }
+
+ @media (max-width: 600px) {
+     h1 {
+         font-size: 24px;
+     }
+
+     p {
+         font-size: 14px;
+     }
+ }
+ """
+
+
+ with gr.Blocks(css=CSS) as block:
+     block_state = gr.State({})
+     gr.HTML(
+         """
+         <div class="header">
+             <a href="https://www.nccr-catalysis.ch/" target="_blank">
+                 <img src="https://www.nccr-catalysis.ch/site/assets/files/1/nccr_catalysis_logo.svg" alt="NCCR Catalysis">
+             </a>
+             <div class="title">
+                 <h1>Atom Detection</h1>
+                 <p>Quantitative description of metal center organization in single-atom catalysts</p>
+             </div>
+         </div>
+         """
+     )
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 n_species = gr.Number(
+                     label="Number of species",
+                     min=1,
+                     max=10,
+                     value=1,
+                     step=1,
+                     precision=0,
+                     visible=True,
+                 )
+                 threshold = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.8,
+                     label="Threshold",
+                     visible=True,
+                 )
+                 architecture = gr.Dropdown(
+                     label="Architecture",
+                     choices=[
+                         ModelArgs.BASICCNN,
+                         # ModelArgs.RESNET18,
+                     ],
+                     value=ModelArgs.BASICCNN,
+                     visible=False,
+                 )
+             files = gr.Files(
+                 label="Images",
+                 file_types=[".tif", ".tiff"],
+                 type="file",
+                 interactive=True,
+             )
+             button = gr.Button(value="Run")
+         with gr.Column():
+             with gr.Tab("Masked prediction") as masked_tab:
+                 masked_prediction_gallery = gr.Gallery(
+                     label="Masked predictions"
+                 ).style(columns=3)
+             with gr.Tab("Nearest neighbors") as nn_tab:
+                 bokeh_plot = gr.Plot(show_label=False)
+             error_html = gr.HTML(visible=False)
+             export_btn = gr.Button(value="Export files", visible=False)
+             exported_files = gr.File(
+                 label="Exported files",
+                 file_count="multiple",
+                 type="file",
+                 interactive=False,
+                 visible=False,
+             )
+     button.click(
+         batch_fn,
+         inputs=[files, n_species, threshold, architecture, block_state],
+         outputs=[
+             masked_prediction_gallery,
+             block_state,
+             export_btn,
+             bokeh_plot,
+             error_html,
+         ],
+     )
+     export_btn.click(
+         batch_export_files, [masked_prediction_gallery, block_state], [exported_files]
+     )
+     with gr.Accordion(label="How to ✨", open=True):
+         gr.HTML(
+             """
+             <div style="font-size: 14px;">
+             <ol>
+                 <li>Select one or multiple microscopy images as <b>.tiff files</b> 📷🔬</li>
+                 <li>Upload individual or multiple .tif images for processing 📤🔢</li>
+                 <li>Export the output files. The generated zip archive will contain:
+                     <ul>
+                         <li>An image with overlaid atomic positions 🌟🔍</li>
+                         <li>A table of atomic positions (in px) along with their probability 📊💎</li>
+                         <li>Physical metadata of the respective images 📄🔍</li>
+                         <li>JSON-formatted plot data 📊📝</li>
+                     </ul>
+                 </li>
+             </ol>
+             <details style="padding: 5px; border-radius: 5px; background: var(--button-secondary-background-hover); font-size: 14px;">
+                 <summary>Note</summary>
+                 <ul style="padding-left: 10px;">
+                     <li>
+                         Structural descriptors beyond pixel-wise atom detections are available as outputs only if images present an embedded real-space calibration (e.g., in <a href="https://imagej.nih.gov/ij/docs/guide/146-30.html#sub:Set-Scale...">nm px-1</a>) 📷🔬
+                     </li>
+                     <li>
+                         32-bit images will be processed correctly, but appear as mostly white in the image preview window
+                     </li>
+                 </ul>
+             </details>
+             </div>
+             """
+         )
+     with gr.Accordion(label="Disclaimer and License", open=False):
+         gr.HTML(
+             """
+             <div class="acknowledgments">
+                 <h3>Disclaimer</h3>
+                 <p>NCCR licenses the Atom Detection Web-App utilisation “as is” with no express or implied warranty of any kind. NCCR specifically disclaims all express or implied warranties to the fullest extent allowed by applicable law, including without limitation all implied warranties of merchantability, title or fitness for any particular purpose or non-infringement. No oral or written information or advice given by the authors shall create or form the basis of any warranty of any kind.</p>
+                 <h3>License</h3>
+                 <p>Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+                 <br>
+                 The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+                 <br>
+                 The software is provided “as is”, without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose and noninfringement. In no event shall the authors or copyright holders be liable for any claim, damages or other liability, whether in an action of contract, tort or otherwise, arising from, out of or in connection with the software or the use or other dealings in the software.</p>
+             </div>
+             """
+         )
+     gr.HTML(
+         """
+         <div style="background-color: var(--secondary-100); border-radius: 5px; padding: 10px;">
+             <p style='font-size: 14px; color: black'>To reference the use of this web app in a publication, please refer to the Atom Detection web app and the development described in this publication: K. Rossi et al. Adv. Mater. 2023, <a href="https://doi.org/10.1002/adma.202307991">doi:10.1002/adma.202307991</a>.</p>
+         </div>
+         """
+     )
+
+
+ block.launch(
+     share=False,
+     show_error=True,
+     server_name="0.0.0.0",
+     server_port=9003,
+     enable_queue=True,
+ )
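A note on `NumpyEncoder` above: `json.dump` cannot serialize NumPy scalars or arrays, which is why the export path passes `cls=NumpyEncoder` when writing `knn_results.json` and `physical_metadata.json`. A minimal standalone sketch of the same pattern (the `payload` data below is made up for illustration):

```python
import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    """Fall back to built-in types for NumPy values json cannot handle."""
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# Hypothetical payload mimicking the knn results written by batch_export_files
payload = {"distances": np.array([0.21, 0.35, 0.48]), "n_samples": np.int64(3)}
print(json.dumps(payload, cls=NumpyEncoder))
# {"distances": [0.21, 0.35, 0.48], "n_samples": 3}
```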
app/assets/ETH_Zurich_Logo_black.svg ADDED
app/assets/logo-ace.png ADDED
app/backup_tiff_utils.py ADDED
@@ -0,0 +1,55 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ @author : Romain Graux
+ @date : 2023 April 25, 11:59:06
+ @last modified : 2023 June 20, 15:04:37
+ """
+
+ import re
+ import imageio
+ import numpy as np
+ from collections import namedtuple
+ from typing import Protocol
+
+ physical_metadata = namedtuple("physical_metadata", ["width", "height", "pixel_width", "pixel_height", "unit"])
+
+
+ class ImageMetadataExtractor(Protocol):
+     @classmethod
+     def __call__(cls, image_path: str, strict: bool = True) -> physical_metadata:
+         ...
+
+ class TIFFMetadataExtractor(ImageMetadataExtractor):
+     @classmethod
+     def __call__(cls, image_path: str, strict: bool = True) -> physical_metadata:
+         """
+         Extracts the physical metadata of an image (only tiff for now)
+         """
+         with open(image_path, "rb") as f:
+             data = f.read()
+         reader = imageio.get_reader(data, format=".tif")
+         metadata = reader.get_meta_data()
+
+         if strict and not metadata['is_imagej']:
+             for key, value in metadata.items():
+                 if key.startswith("is_") and value == True:  # Require the value to be exactly True, since a random truthy object could also pass the condition
+                     raise ValueError(f"The image is not a TIFF image, but it seems to be a {key[3:]} image")
+             raise ValueError("The image is not in TIFF format")
+         h, w = reader.get_next_data().shape
+         ipw, iph, _ = metadata['resolution']
+         result = re.search(r"unit=(.+)", metadata['description'])
+         if strict and not result:
+             raise ValueError(f"No scale unit found in the image description: {metadata['description']}")
+         unit = result and result.group(1)
+         return physical_metadata(w, h, 1. / ipw, 1. / iph, unit)
+
+ def extract_physical_metadata(image_path: str, strict: bool = True) -> physical_metadata:
+     if image_path.endswith(".tif"):
+         # __call__ is defined as a classmethod, so invoke it on the class explicitly
+         return TIFFMetadataExtractor.__call__(image_path, strict)
+
+ def tiff_to_png(image, inplace=True):
+     img = image if inplace else image.copy()
+     if np.array(img.getdata()).max() <= 1:
+         img = img.point(lambda p: p * 255)
+     return img.convert("RGB")
app/dl_inference.py ADDED
@@ -0,0 +1,187 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ @author : Romain Graux
+ @date : 2023 March 17, 10:56:06
+ @last modified : 2023 July 18, 10:25:32
+ """
+
+ # Naive import of atomdetection, maybe should make a package out of it
+ from functools import lru_cache
+ import sys
+
+ if ".." not in sys.path:
+     sys.path.append("..")
+
+ import os
+ import torch
+ import numpy as np
+ from PIL import Image
+ from PIL.Image import Image as PILImage
+ from typing import Union
+ from utils.constants import ModelArgs
+ from utils.paths import MODELS_PATH, DATASET_PATH
+ from atoms_detection.dl_detection import DLDetection
+ from atoms_detection.evaluation import Evaluation
+ from tiff_utils import tiff_to_png
+
+ LOGS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
+ VOID_DS = os.path.join(DATASET_PATH, "void.csv")
+ DET_PATH = os.path.join(LOGS_PATH, "detections")
+ INF_PATH = os.path.join(LOGS_PATH, "inference_cache")
+
+ from atoms_detection.create_crop_dataset import create_crop
+ from atoms_detection.vae_svi_train import SVItrainer, init_dataloader
+ from atoms_detection.vae_model import rVAE
+ from sklearn.mixture import GaussianMixture
+
+
+ @lru_cache(maxsize=100)
+ def get_vae_model(
+     in_dim: tuple = (21, 21),
+     latent_dim: int = 50,
+     coord: int = 3,
+     seed: int = 42,
+ ):
+     return rVAE(in_dim=in_dim, latent_dim=latent_dim, coord=coord, seed=seed)
+
+
+ def multimers_classification(
+     img,
+     coords,
+     likelihood,
+     n_species,
+     latent_dim: int = 50,
+     coord: int = 3,
+     reg_covar: float = 0.0001,
+     seed: int = 42,
+     epochs: int = 20,
+     scale_factor: float = 3.0,
+     batch_size: int = 100,
+ ):
+     def get_crops(img, coords):
+         """Get crops from image and coords"""
+         crops = np.array(
+             [np.array(create_crop(Image.fromarray(img), x, y)) for x, y in coords]
+         )  # TODO : can be optimized if computationally heavy (multithreading)
+         return crops
+
+     # Get crops to train VAE on
+     crops = get_crops(img, coords)
+     # Initialize VAE
+     rvae = rVAE(in_dim=(21, 21), latent_dim=latent_dim, coord=coord, seed=seed)
+
+     # Train VAE to reconstruct crops
+     torch_crops = torch.tensor(crops).float()
+     train_loader = init_dataloader(torch_crops, batch_size=batch_size)
+     trainer = SVItrainer(rvae)
+     for e in range(epochs):
+         trainer.step(train_loader, scale_factor=scale_factor)
+         trainer.print_statistics()
+
+     # Extract latent space (only mean) from VAE
+     z_mean, _ = rvae.encode(torch_crops)
+
+     # Cluster latent space with GMM
+     gmm = GaussianMixture(
+         n_components=n_species, reg_covar=reg_covar, random_state=seed
+     )
+     preds = gmm.fit_predict(z_mean)
+     pred_proba = gmm.predict_proba(z_mean)
+     pred_proba = np.array([pred_proba[i, pred] for i, pred in enumerate(preds)])
+
+     # To order clusters, signal-to-noise ratio OR median (across crops) of some intensity quality (eg mean top-5% int)
+     cluster_median_values = list()
+     for k in range(n_species):
+         relevant_crops = crops[preds == k]
+         crop_95_percentile = np.percentile(relevant_crops, q=95, axis=0)
+         img_means = []
+         for crop, q in zip(relevant_crops, crop_95_percentile):
+             if (crop >= q).any():
+                 img_means.append(crop.mean())
+         cluster_median_value = np.median(np.array(img_means))
+         cluster_median_values.append(cluster_median_value)
+     # Sort clusters by median value
+     sorted_clusters = sorted(
+         [(mval, c_id) for c_id, mval in enumerate(cluster_median_values)]
+     )
+
+     # Return results in a dict with cluster id as key
+     results = {}
+     for _, c_id in sorted_clusters:
+         c_idd = np.array([c_id])
+         results[c_id] = {
+             "coords": coords[preds == c_idd],
+             "likelihood": likelihood[preds == c_idd],
+             "confidence": pred_proba[preds == c_idd],
+         }
+     return results
+
+
+ def inference_fn(
+     architecture: ModelArgs,
+     image: Union[str, PILImage],
+     threshold: float,
+     n_species: int,
+ ):
+     if architecture != ModelArgs.BASICCNN:
+         raise ValueError(f"Architecture {architecture} not supported yet")
+     ckpt_filename = os.path.join(
+         MODELS_PATH,
+         {
+             ModelArgs.BASICCNN: "model_C_NT_CLIP.ckpt",
+             # ModelArgs.BASICCNN: "model_replicate20.ckpt",
+             # ModelArgs.RESNET18: "inference_resnet.ckpt",
+         }[architecture],
+     )
+     detection = DLDetection(
+         model_name=architecture,
+         ckpt_filename=ckpt_filename,
+         dataset_csv=VOID_DS,
+         threshold=threshold,
+         detections_path=DET_PATH,
+         inference_cache_path=INF_PATH,
+         batch_size=512,
+     )
+     # Force the image to be in float32 because otherwise it will output wrong results (probably due to the median filter)
+     if type(image) == str:
+         image = Image.open(image)
+     img = np.asarray(image, dtype=np.float32)
+     # if img.max() <= 1:
+     #     raise ValueError("Gradio seems to preprocess badly the tiff images. Did you adapt the preprocessing function as mentioned in the app.py file comments?")
+     prepro_img, _, pred_map = detection.image_to_pred_map(img, return_intermediate=True)
+     center_coords_list, likelihood_list = (np.array(x) for x in detection.pred_map_to_atoms(pred_map))
+     results = (
+         multimers_classification(
+             img=prepro_img,
+             coords=center_coords_list,
+             likelihood=likelihood_list,
+             n_species=n_species,
+         )
+         if n_species > 1
+         else {
+             0: {
+                 "coords": center_coords_list,
+                 "likelihood": likelihood_list,
+                 "confidence": np.ones(len(center_coords_list)),
+             }
+         }
+     )
+     for k, v in results.items():
+         results[k]["atoms_bbs"] = [
+             Evaluation.center_coords_to_bbox(center_coords)
+             for center_coords in v["coords"]
+         ]
+     return tiff_to_png(image), {
+         "image": tiff_to_png(image),
+         "pred_map": pred_map,
+         "species": results,
+     }
+
+
+ if __name__ == "__main__":
+     from utils.paths import IMG_PATH
+
+     img_path = os.path.join(IMG_PATH, "091_HAADF_15nm_Sample_PtNC_21Oct20.tif")
+     _ = inference_fn(ModelArgs.BASICCNN, Image.open(img_path), 0.8, n_species=1)
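For orientation, `inference_fn` is the single entry point the webapp calls: it returns a PNG-converted copy of the input image plus a dict of results keyed by species id. A usage sketch mirroring the call in `app/app.py` (the `"sample.tif"` path is a placeholder, not a file shipped in this commit):

```python
# Hypothetical usage of inference_fn; "sample.tif" is a placeholder path.
from dl_inference import inference_fn
from utils.constants import ModelArgs

png_img, out = inference_fn(ModelArgs.BASICCNN, "sample.tif", threshold=0.8, n_species=2)
for species_id, species in out["species"].items():
    print(species_id, len(species["coords"]), float(species["confidence"].mean()))
```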
app/knn.py ADDED
@@ -0,0 +1,203 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ @author : Romain Graux
+ @date : 2023 May 16, 16:18:43
+ @last modified : 2023 August 07, 11:54:19
+ """
+
+ from typing import List, Tuple
+
+ from PIL import Image
+ from collections import defaultdict
+ from tempfile import mktemp
+ import matplotlib
+ import numpy as np
+ import os
+ import seaborn as sns
+ from logger import logger
+
+
+ matplotlib.use("agg")
+ import matplotlib.pyplot as plt
+ from scipy.stats import rayleigh
+ from sklearn.neighbors import NearestNeighbors
+
+
+ def segment_image(filename, alpha=5):
+     # Create temporary png file paths
+     filename = filename.replace(" ", r"\ ")  # escape spaces for the shell commands below
+     png_img = mktemp(suffix=".png")
+     segmented_img = mktemp(suffix=".png")
+     logger.debug(f"Segmenting image {filename}...")
+     logger.debug(f"Saving image to {png_img}...")
+     logger.debug(f"Saving segmented image to {segmented_img}...")
+     try:
+         # Segment with ImageMagick in the terminal
+         ret = os.system(f"convert {filename} {png_img}")
+         if ret != 0:
+             raise RuntimeError(f"PNG conversion failed with return code {ret}")
+         ret = os.system(
+             f"convert {png_img} -alpha on -fill none -fuzz {alpha}% -draw 'color 0,0 replace' {segmented_img}"
+         )
+         if ret != 0:
+             raise RuntimeError(f"Segmentation failed with return code {ret}")
+         # Load the image
+         img = Image.open(segmented_img)
+         # Get mask from image
+         mask = np.array(img) == 0
+     finally:
+         # Delete the temporary files
+         if os.path.exists(png_img):
+             os.remove(png_img)
+         if os.path.exists(segmented_img):
+             os.remove(segmented_img)
+     return mask
+
+
+ # An atom counts as part of a multimer if its first NN distance is below 0.23 nm
+ condition = lambda x: x < 0.23
+
+
+ def knn(coords: List[Tuple[int, int]], scale: float, factor: float, edge: float):
+     coords = np.array(coords)  # B, 2
+     x, y = coords.T * scale
+
+     print(f"edge: {edge}, scale: {scale}, factor: {factor}, edge*scale: {edge*scale}")
+     # edge = edge * scale
+
+     data = np.c_[x, y]
+
+     neighbors = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(data)
+     distances = neighbors.kneighbors(data)[0][:, 1]
+
+     # loc, scale = rayleigh.fit(distances, floc=0)
+     # r_KNN = scale * np.sqrt(np.pi / 2.0)
+
+     lamda_RNN = len(x) / (edge * edge * factor)
+     r_RNN = 1 / (2 * np.sqrt(lamda_RNN))
+
+     n_samples = len(distances)
+     n_multimers = sum(condition(x) for x in distances)
+     percentage_multimers = 100.0 * n_multimers / n_samples
+     density = n_samples / (factor * edge**2)
+     min_dist = min(distances)
+     μ_dist = np.mean(distances)
+
+     return {
+         "n_samples": {
+             "description": "Number of atoms detected (units = #atoms)",
+             "value": n_samples,
+         },
+         "number of atoms in multimers": {
+             "description": "Number of atoms detected to belong to a multimer (units = #atoms)",
+             "value": n_multimers,
+         },
+         "share of multimers": {
+             "description": "Percentage of atoms that belong to a multimer (units = %)",
+             "value": percentage_multimers,
+         },
+         "density": {
+             "description": "Number of atoms / area in the micrograph (units = #atoms/nm²)",
+             "value": density,
+         },
+         "min_dist": {
+             "description": "Lowest first nearest neighbour distance detected (units = nm)",
+             "value": min_dist,
+         },
+         "<NND>": {
+             "description": "Mean first nearest neighbour distance (units = nm)",
+             "value": μ_dist,
+         },
+         "r_RNN": {
+             "description": "First neighbour distance expected from a purely random distribution (units = nm)",
+             "value": r_RNN,
+         },
+         "distances": distances,
+     }
+
+
+ from bokeh.plotting import figure
+ from bokeh.models import ColumnDataSource, HoverTool
+ from scipy.stats import gaussian_kde
+
+ color_palette = sns.color_palette("Set3")[2:]
+
+
+ def bokeh_plot_knn(distances, with_cumulative=False):
+     """
+     Plot the KNN distances for the different images with the possibility to zoom in and out and toggle the lines
+     """
+     p = figure(title="K=1 NN distances", background_fill_color="#fafafa")
+     p.xaxis.axis_label = "Distances (nm)"
+     p.yaxis.axis_label = "Density"
+     p.x_range.start = 0
+
+     if with_cumulative:
+         cum_dists = defaultdict(list)
+         for _, dists in distances:
+             for specie, dist in dists.items():
+                 cum_dists[specie].extend(dist)
+         cum_dists = {specie: np.array(dist) for specie, dist in cum_dists.items()}
+         distances.append(("Cumulative", cum_dists))
+
+     plot_dict = defaultdict(dict)
+     base_colors = color_palette
+     for (image_name, species_distances), base_color in zip(distances, base_colors):
+         palette = (
+             sns.light_palette(
+                 base_color, n_colors=len(species_distances) + 1, reverse=True
+             )[1::-1]
+             if len(species_distances) > 1
+             else [base_color]
+         )
+         colors = [
+             f"#{int(255*r):02x}{int(255*g):02x}{int(255*b):02x}" for r, g, b in palette
+         ]
+         for (specie, dists), color in zip(species_distances.items(), colors):
+             kde = gaussian_kde(dists)
+             # Reduce smoothing
+             kde.set_bandwidth(bw_method=0.1)
+             x = np.linspace(-0.5, 1.2 * dists.max(), 100)
+             source = ColumnDataSource(
+                 dict(
+                     x=x,
+                     y=kde(x),
+                     species=[specie] * len(x),
+                     p_below=[np.mean(dists < 0.22)] * len(x),
+                     mean=[np.mean(dists)] * len(x),
+                     filename=[image_name] * len(x),
+                 )
+             )
+             plot_dict[image_name][specie] = [
+                 p.line(
+                     line_width=2,
+                     alpha=0.8,
+                     legend_label=image_name,
+                     line_color=color,
+                     source=source,
+                 ),
+                 p.varea(
+                     y1="y",
+                     y2=0,
+                     alpha=0.3,
+                     legend_label=image_name,
+                     source=source,
+                     fill_color=color,
+                 ),
+             ]
+     p.add_tools(
+         HoverTool(
+             show_arrow=False,
+             line_policy="next",
+             tooltips=[
+                 ("First NN distances < 0.22nm", "@p_below{0.00%}"),
+                 ("<NND>", "@mean{0.00} nm"),
+                 ("species", "@species"),
+                 ("filename", "@filename"),
+             ],
+         )
+     )
+     p.legend.click_policy = "hide"
+     return p
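The `r_RNN` baseline in `knn` is the classic nearest-neighbour result for a homogeneous 2-D Poisson point process (Clark–Evans): with intensity λ points per unit area, the expected first-nearest-neighbour distance is 1/(2√λ). The code estimates λ from the atom count over the unmasked field of view, so a measured ⟨NND⟩ well below `r_RNN` suggests clustering:

```latex
% Expected first-NN distance for a 2-D Poisson process of intensity \lambda
\mathbb{E}\left[r_{1\mathrm{NN}}\right] = \frac{1}{2\sqrt{\lambda}},
\qquad
\hat{\lambda} = \frac{N_{\mathrm{atoms}}}{f \cdot \mathrm{edge}^{2}}
```

where `f` is the unmasked fraction of the image (`factor`) and `edge` the physical edge length of the field of view in nm.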
app/logger.py ADDED
@@ -0,0 +1,14 @@
+ import logging
+
+ name = 'atomdetection-app'
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+
+ # Create a file handler and set its level and formatter
+ file_handler = logging.FileHandler(f'{name}.log')
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(formatter)
+
+ # Add the file handler to the logger
+ logger.addHandler(file_handler)
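The module configures the file logger as an import side effect; a minimal usage sketch:

```python
from logger import logger

logger.debug("segmenting image ...")  # appended to atomdetection-app.log
```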
app/tiff_utils.py ADDED
@@ -0,0 +1,42 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ @author : Romain Graux
+ @date : 2023 April 25, 11:59:06
+ @last modified : 2023 September 19, 11:18:36
+ """
+
+ import re
+ import imageio
+ import numpy as np
+ from collections import namedtuple
+
+ physical_metadata = namedtuple("physical_metadata", ["width", "height", "pixel_width", "pixel_height", "unit"])
+
+ def extract_physical_metadata(image_path: str, strict: bool = True) -> physical_metadata:
+     """
+     Extracts the physical metadata of an image (only tiff for now)
+     """
+     with open(image_path, "rb") as f:
+         data = f.read()
+     reader = imageio.get_reader(data, format=".tif")
+     metadata = reader.get_meta_data()
+
+     if strict and not metadata['is_imagej']:
+         for key, value in metadata.items():
+             if key.startswith("is_") and value == True:  # Require the value to be exactly True, since a random truthy object could also pass the condition
+                 raise ValueError(f"The image is not a TIFF image, but it seems to be a {key[3:]} image")
+         raise ValueError("Impossible to extract metadata from the image (ImageJ)")
+     h, w = reader.get_next_data().shape
+     ipw, iph, _ = metadata['resolution']
+     result = re.search(r"unit=(.+)", metadata['description'])
+     if strict and not result:
+         raise ValueError(f"No scale unit found in the image description: {metadata['description']}")
+     unit = result and result.group(1)
+     return physical_metadata(w, h, 1. / ipw, 1. / iph, unit)
+
+ def tiff_to_png(image, inplace=True):
+     img = image if inplace else image.copy()
+     if np.array(img.getdata()).max() <= 1:
+         img = img.point(lambda p: p * 255)
+     return img.convert("RGB")
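Downstream, `app/app.py` derives the physical scale from this metadata. A usage sketch (the path is a placeholder for an ImageJ-calibrated TIFF, not a file in this commit):

```python
# Hypothetical usage; "sample.tif" stands in for a calibrated micrograph.
from tiff_utils import extract_physical_metadata

meta = extract_physical_metadata("sample.tif")
print(meta.unit)                      # expected to be "nm" by the webapp
print(meta.pixel_width)               # physical size of one pixel
print(meta.pixel_width * meta.width)  # field-of-view edge length used for the KNN stats
```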
atoms_detection/README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Sss
+ emoji: 🚀
+ colorFrom: yellow
+ colorTo: red
+ sdk: gradio
+ sdk_version: 3.24.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
atoms_detection/__init__.py ADDED
File without changes
atoms_detection/create_crop_dataset.py ADDED
@@ -0,0 +1,408 @@
+ import os
+
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+
+ from atoms_detection.image_preprocessing import dl_prepro_image
+ from atoms_detection.dataset import CoordinatesDataset
+ from utils.paths import CROPS_PATH, CROPS_DATASET, PT_DATASET
+ from utils.constants import Split, CropsColumns
+ import matplotlib.pyplot as plt  # I don't know why, but it doesn't work if this import is removed
+
+ np.random.seed(777)
+
+ window_size = (21, 21)
+ halfx_window = ((window_size[0] - 1) // 2)
+ halfy_window = ((window_size[1] - 1) // 2)
+
+
+ def get_gaussian_kernel(size=21, mean=0, sigma=0.2):
+     # Initializing value of x-axis and y-axis
+     # in the range -1 to 1
+     x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size))
+     dst = np.sqrt(x * x + y * y)
+
+     # Calculating Gaussian array
+     kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2)))
+     return kernel
+
+
+ def generate_support_img(coordinates, window_size):
+     support_img = np.zeros((512, 512))
+     kernel = get_gaussian_kernel(size=window_size[0])
+     halfx_window = ((window_size[0] - 1) // 2)
+     halfy_window = ((window_size[1] - 1) // 2)
+     for x, y in coordinates:
+         x_range = (x - halfx_window, x + halfx_window + 1)
+         y_range = (y - halfy_window, y + halfy_window + 1)
+
+         x_diff = [0, 0]
+         y_diff = [0, 0]
+         if x_range[0] < 0:
+             x_diff[0] = 0 - x_range[0]
+         if x_range[1] > 512:
+             x_diff[1] = x_range[1] - 512
+         if y_range[0] < 0:
+             y_diff[0] = 0 - y_range[0]
+         if y_range[1] > 512:
+             y_diff[1] = y_range[1] - 512
+
+         real_kernel = kernel[x_diff[0]:window_size[0] - x_diff[1], y_diff[0]:window_size[1] - y_diff[1]]
+         real_x_crop = (x_range[0] + x_diff[0], x_range[1] - x_diff[1])
+         real_y_crop = (y_range[0] + y_diff[0], y_range[1] - y_diff[1])
+
+         support_img[real_x_crop[0]:real_x_crop[1], real_y_crop[0]:real_y_crop[1]] += real_kernel
+
+     support_img = support_img.T
+     return support_img
+
+
+ def open_image(img_filename):
+     img = Image.open(img_filename)
+     np_img = np.asarray(img).astype(np.float32)
+     np_img = dl_prepro_image(np_img)
+     img = Image.fromarray(np_img)
+     return img
+
+
+ def create_crop(img: Image, x_center: int, y_center: int):
+     crop_coords = (
+         x_center - halfx_window,
+         y_center - halfy_window,
+         x_center + halfx_window + 1,
+         y_center + halfy_window + 1
+     )
+     crop = img.crop(crop_coords)
+     return crop
+
+
+ def create_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str):
+     if not os.path.exists(crops_folder):
+         os.makedirs(crops_folder)
+
+     crop_name_list = []
+     orig_name_list = []
+     x_list = []
+     y_list = []
+     label_list = []
+
+     n_positives = 0
+     label = 1
+     dataset = CoordinatesDataset(coords_csv)
+     print('Creating positive crops...')
+     for data_filename, label_filename in dataset.iterate_data(Split.TRAIN):
+         if label_filename is None:
+             continue
+
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+
+         img = open_image(data_filename)
+         coordinates = dataset.load_coordinates(label_filename)
+
+         for x_center, y_center in coordinates:
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+
+             n_positives += 1
+
+     label = 0
+     no_train_images = dataset.split_length(Split.TRAIN)
+     neg_crops_per_image = [n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0) for x in range(no_train_images)]
+     print('Creating negative crops...')
+     for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), neg_crops_per_image):
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+         img = open_image(data_filename)
+
+         if label_filename:
+             coordinates = dataset.load_coordinates(label_filename)
+             support_map = generate_support_img(coordinates, window_size)
+         else:
+             support_map = None
+
+         for _ in range(no_neg_crops):
+             x_rand = np.random.randint(0, 512)
+             y_rand = np.random.randint(0, 512)
+
+             if support_map is not None:
+                 while support_map[x_rand, y_rand] != 0:
+                     x_rand = np.random.randint(0, 512)
+                     y_rand = np.random.randint(0, 512)
+
+             x_center, y_center = x_rand, y_rand
+
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+
+     df_data = {
+         CropsColumns.FILENAME: crop_name_list,
+         CropsColumns.ORIGINAL: orig_name_list,
+         CropsColumns.X: x_list,
+         CropsColumns.Y: y_list,
+         CropsColumns.LABEL: label_list
+     }
+     df = pd.DataFrame(df_data, columns=[
+         CropsColumns.FILENAME,
+         CropsColumns.ORIGINAL,
+         CropsColumns.X,
+         CropsColumns.Y,
+         CropsColumns.LABEL
+     ])
+
+     df_pos = df[df.Label == 1]
+     df_neg = df[df.Label == 0]
+
+     pos_len = len(df_pos)
+     neg_len = len(df_neg)
+
+     pos_train, pos_val, pos_test = np.split(df_pos.sample(frac=1), [int(0.8*pos_len), int(0.9*pos_len)])
+     neg_train, neg_val, neg_test = np.split(df_neg.sample(frac=1), [int(0.8*neg_len), int(0.9*neg_len)])
+     pos_train[CropsColumns.SPLIT] = Split.TRAIN
+     pos_val[CropsColumns.SPLIT] = Split.VAL
+     pos_test[CropsColumns.SPLIT] = Split.TEST
+     neg_train[CropsColumns.SPLIT] = Split.TRAIN
+     neg_val[CropsColumns.SPLIT] = Split.VAL
+     neg_test[CropsColumns.SPLIT] = Split.TEST
+
+     df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val, pos_test, neg_test), axis=0)
+     df_with_splits.to_csv(crops_dataset, header=True, index=False)
+
+
+ def create_contrastive_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str,
+                                      show_sampling_result: bool = False, contrastive_samples_percent: float = 0.25,
+                                      contrastive_distance_multiplier: float = 1.1, pos_data_upsampling: bool = False,
+                                      pos_upsample_dist: int = 3, neg_upsample_multiplier: float = 0):
+     global plt  # don't ask why.
+     if not os.path.exists(crops_folder):
+         os.makedirs(crops_folder)
+
+     crop_name_list = []
+     orig_name_list = []
+     x_list = []
+     y_list = []
+     label_list = []
+
+     n_positives = 0
+     label = 1
+     dataset = CoordinatesDataset(coords_csv)
+     print('Creating positive crops...')
+     firstx, firsty = True, True
+     for data_filename, label_filename in dataset.iterate_data(Split.TRAIN):
+         if label_filename is None:
+             continue
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+
+         img = open_image(data_filename)
+         coordinates = dataset.load_coordinates(label_filename)
+
+         for x_center, y_center in coordinates:
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+             if firstx:
+                 firstx = False
+                 crop_save(crop, "pos.png")
+                 print('saved')
+
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+             if pos_data_upsampling:
+                 x_rand, y_rand = None, None
+                 while x_rand is None:
+                     rand_angle = np.random.uniform(0, 2 * np.pi)
+                     x_rand = round(pos_upsample_dist * np.cos(rand_angle)) + x_center
+                     y_rand = round(pos_upsample_dist * np.sin(rand_angle)) + y_center
+                     out_of_bounds = x_rand >= img.size[0] or y_rand >= img.size[1] or \
+                         x_rand < 0 or y_rand < 0
+                     if out_of_bounds:
+                         x_rand, y_rand = None, None
+
+                 crop = create_crop(img, x_rand, y_rand)
+                 crop_name = "{}_{}_{}.tif".format(img_name, x_rand, y_rand)
+                 crop.save(os.path.join(crops_folder, crop_name))
+                 crop_name_list.append(crop_name)
+                 orig_name_list.append(orig_img_name)
+                 x_list.append(x_center)
+                 y_list.append(y_center)
+                 label_list.append(label)
+
+                 if firsty:
+                     firsty = False
+                     crop_save(crop, "pos_jit.png")
+
+             n_positives += 1
+
+     label = 0
+     no_train_images = dataset.split_length(Split.TRAIN)
+     contrastive_sampling_distance = (window_size[0] * contrastive_distance_multiplier) // 2
+     neg_crops_per_image = [round((n_positives // no_train_images) * (1+neg_upsample_multiplier)) + (1 if x < n_positives % no_train_images else 0) for x in
+                            range(no_train_images)]
+     neg_non_constrastive_crops_per_image, neg_contrastive_crops_per_image = \
+         list(zip(*[(n_crops - round(contrastive_samples_percent * n_crops),
+                     round(contrastive_samples_percent * n_crops))
+                    for n_crops in neg_crops_per_image]))
+     firstx, firsty = True, True
+     # neg_non_constrastive_crops_per_image, neg_contrastive_crops_per_image = 30*[0], 30*[44]
+     print(contrastive_sampling_distance)
+     print('Creating contrastive negative crops...')
+     for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN),
+                                                              neg_contrastive_crops_per_image):
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+         img = open_image(data_filename)
+
+         if label_filename:
+             coordinates = dataset.load_coordinates(label_filename)
+             support_map = generate_support_img(coordinates, window_size)
+         else:
+             support_map = None
+
+         for idx in np.random.choice(len(coordinates), no_neg_crops):
+             atom_rand = coordinates[idx]
+             x_center, y_center = atom_rand
+             x_rand, y_rand = None, None
+             if support_map is not None:
+                 retries = 0
+                 while x_rand is None and retries < 50:  # Extremely unlikely: sample impossible
+                     retries += 1
+                     rand_angle = np.random.uniform(0, 2 * np.pi)
+                     x_rand = round(contrastive_sampling_distance * np.cos(rand_angle)) + x_center
+                     y_rand = round(contrastive_sampling_distance * np.sin(rand_angle)) + y_center
+                     out_of_bounds = x_rand >= img.size[0] or y_rand >= img.size[1] or \
+                         x_rand < 0 or y_rand < 0
+                     if out_of_bounds or support_map[x_rand, y_rand] != 0:
+                         x_rand, y_rand = None, None
+
+             x_center, y_center = x_rand, y_rand
+
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+             if firsty:
+                 firsty = False
+                 crop_save(crop, "neg_con.png")
+
+     print('Creating non-contrastive negative crops...')
+     for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN),
+                                                              neg_non_constrastive_crops_per_image):
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+         img = open_image(data_filename)
+
+         if label_filename:
+             coordinates = dataset.load_coordinates(label_filename)
+             support_map = generate_support_img(coordinates, window_size)
+         else:
+             support_map = None
+
+         for _ in range(no_neg_crops):
+             x_rand = np.random.randint(0, 512)
+             y_rand = np.random.randint(0, 512)
+
+             if support_map is not None:
+                 while support_map[x_rand, y_rand] != 0:
+                     x_rand = np.random.randint(0, 512)
+                     y_rand = np.random.randint(0, 512)
+
+             x_center, y_center = x_rand, y_rand
+
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+             if firstx:
+                 firstx = False
+                 crop_save(crop, "neg_ncon.png")
+
+     if show_sampling_result:
+         # Only works for single img data.
+         positives = [(x, y) for x, y, l in zip(x_list, y_list, label_list) if l == 1]
+         negatives = [(x, y) for x, y, l in zip(x_list, y_list, label_list) if l == 0]
+         from matplotlib import pyplot as plt
+         plt.imshow(img)
+         plt.scatter(*zip(*positives))
+         plt.scatter(*zip(*negatives))
+         plt.show()
+
+     df_data = {
+         CropsColumns.FILENAME: crop_name_list,
+         CropsColumns.ORIGINAL: orig_name_list,
+         CropsColumns.X: x_list,
+         CropsColumns.Y: y_list,
+         CropsColumns.LABEL: label_list
+     }
+     df = pd.DataFrame(df_data, columns=[
+         CropsColumns.FILENAME,
+         CropsColumns.ORIGINAL,
+         CropsColumns.X,
+         CropsColumns.Y,
+         CropsColumns.LABEL
+     ])
+
+     df_pos = df[df.Label == 1]
+     df_neg = df[df.Label == 0]
+
+     pos_len = len(df_pos)
+     neg_len = len(df_neg)
+
+     pos_train, pos_val = np.split(df_pos.sample(frac=1), [int(0.9 * pos_len)])
+     neg_train, neg_val = np.split(df_neg.sample(frac=1), [int(0.9 * neg_len)])
+     pos_train[CropsColumns.SPLIT] = Split.TRAIN
+     pos_val[CropsColumns.SPLIT] = Split.VAL
+     neg_train[CropsColumns.SPLIT] = Split.TRAIN
+     neg_val[CropsColumns.SPLIT] = Split.VAL
+     print("Final size for train (P vs N):", len(pos_train), len(neg_train))
+     print("Final size for val (P vs N):", len(pos_val), len(neg_val))
+     df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val), axis=0)
+     df_with_splits.to_csv(crops_dataset, header=True, index=False)
+
+
+ def crop_save(crop, im_name):
+     crop = np.array(crop)
+     crop = (crop + crop.min()) * 500
+     crop = Image.fromarray(crop)
+     crop = crop.convert("L")
+     crop.save(im_name, 'png')
+
+
+ if __name__ == "__main__":
+     create_crops_dataset(CROPS_PATH, PT_DATASET, CROPS_DATASET)
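The `neg_crops_per_image` expression in `create_crops_dataset` distributes one negative crop per positive crop as evenly as possible across training images. A standalone check of that formula:

```python
# Balanced distribution of n_positives negative crops over no_train_images images
n_positives, no_train_images = 10, 3
neg_crops_per_image = [
    n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0)
    for x in range(no_train_images)
]
print(neg_crops_per_image)  # [4, 3, 3] -> sums back to n_positives
```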
atoms_detection/create_crop_dataset_1024.py ADDED
@@ -0,0 +1,197 @@
+ import os
+ 
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ 
+ from atoms_detection.image_preprocessing import dl_prepro_image
+ from atoms_detection.dataset import CoordinatesDataset
+ from utils.paths import CROPS_PATH, CROPS_DATASET, PT_DATASET
+ from utils.constants import Split, CropsColumns
+ 
+ 
+ np.random.seed(777)
+ 
+ window_size = (21, 21)
+ halfx_window = ((window_size[0] - 1) // 2)
+ halfy_window = ((window_size[1] - 1) // 2)
+ 
+ 
+ def get_gaussian_kernel(size=33, mean=0, sigma=0.2):
+     # Initializing value of x-axis and y-axis
+     # in the range -1 to 1
+     x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size))
+     dst = np.sqrt(x * x + y * y)
+ 
+     # Calculating Gaussian array
+     kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2)))
+     return kernel
+ 
+ 
+ def generate_support_img(coordinates, window_size):
+     support_img = np.zeros((1024, 1024))
+     kernel = get_gaussian_kernel(size=window_size[0])
+     halfx_window = ((window_size[0] - 1) // 2)
+     halfy_window = ((window_size[1] - 1) // 2)
+     for x, y in coordinates:
+         x_range = (x - halfx_window, x + halfx_window + 1)
+         y_range = (y - halfy_window, y + halfy_window + 1)
+ 
+         x_diff = [0, 0]
+         y_diff = [0, 0]
+         if x_range[0] < 0:
+             x_diff[0] = 0 - x_range[0]
+         if x_range[1] > 1024:
+             x_diff[1] = x_range[1] - 1024
+         if y_range[0] < 0:
+             y_diff[0] = 0 - y_range[0]
+         if y_range[1] > 1024:
+             y_diff[1] = y_range[1] - 1024
+ 
+         x_diff = tuple(int(item) for item in x_diff)
+         y_diff = tuple(int(item) for item in y_diff)
+ 
+         real_kernel = kernel[x_diff[0]:window_size[0] - x_diff[1], y_diff[0]:window_size[1] - y_diff[1]]
+ 
+         real_x_crop = (x_range[0] + x_diff[0], x_range[1] - x_diff[1])
+         real_y_crop = (y_range[0] + y_diff[0], y_range[1] - y_diff[1])
+ 
+         real_x_crop = tuple(int(item) for item in real_x_crop)
+         real_y_crop = tuple(int(item) for item in real_y_crop)
+ 
+         support_img[real_x_crop[0]:real_x_crop[1], real_y_crop[0]:real_y_crop[1]] += real_kernel
+ 
+     support_img = support_img.T
+     return support_img
+ 
+ 
+ def open_image(img_filename):
+     img = Image.open(img_filename)
+     np_img = np.asarray(img).astype(np.float32)
+     np_img = dl_prepro_image(np_img)
+     img = Image.fromarray(np_img)
+     return img
+ 
+ 
+ def create_crop(img: Image, x_center: int, y_center: int):
+     crop_coords = (
+         x_center - halfx_window,
+         y_center - halfy_window,
+         x_center + halfx_window + 1,
+         y_center + halfy_window + 1
+     )
+     crop = img.crop(crop_coords)
+     return crop
+ 
+ 
+ def create_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str):
+     if not os.path.exists(crops_folder):
+         os.makedirs(crops_folder)
+ 
+     crop_name_list = []
+     orig_name_list = []
+     x_list = []
+     y_list = []
+     label_list = []
+ 
+     n_positives = 0
+     label = 1
+     dataset = CoordinatesDataset(coords_csv)
+     print('Creating positive crops...')
+     for data_filename, label_filename in dataset.iterate_data(Split.TRAIN):
+         if label_filename is None:
+             continue
+ 
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+ 
+         img = open_image(data_filename)
+         coordinates = dataset.load_coordinates(label_filename)
+ 
+         for x_center, y_center in coordinates:
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+ 
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+ 
+             n_positives += 1
+ 
+     label = 0
+     no_train_images = dataset.split_length(Split.TRAIN)
+     neg_crops_per_image = [n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0) for x in range(no_train_images)]
+     print('Creating negative crops...')
+     for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), neg_crops_per_image):
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+         img = open_image(data_filename)
+ 
+         if label_filename:
+             coordinates = dataset.load_coordinates(label_filename)
+             support_map = generate_support_img(coordinates, window_size)
+         else:
+             support_map = None
+ 
+         for _ in range(no_neg_crops):
+             x_rand = np.random.randint(0, 1024)
+             y_rand = np.random.randint(0, 1024)
+ 
+             if support_map is not None:
+                 while support_map[x_rand, y_rand] != 0:
+                     x_rand = np.random.randint(0, 1024)
+                     y_rand = np.random.randint(0, 1024)
+ 
+             x_center, y_center = x_rand, y_rand
+ 
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+ 
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+ 
+     df_data = {
+         CropsColumns.FILENAME: crop_name_list,
+         CropsColumns.ORIGINAL: orig_name_list,
+         CropsColumns.X: x_list,
+         CropsColumns.Y: y_list,
+         CropsColumns.LABEL: label_list
+     }
+     df = pd.DataFrame(df_data, columns=[
+         CropsColumns.FILENAME,
+         CropsColumns.ORIGINAL,
+         CropsColumns.X,
+         CropsColumns.Y,
+         CropsColumns.LABEL
+     ])
+ 
+     df_pos = df[df.Label == 1]
+     df_neg = df[df.Label == 0]
+ 
+     pos_len = len(df_pos)
+     neg_len = len(df_neg)
+ 
+     pos_train, pos_val, pos_test = np.split(df_pos.sample(frac=1), [int(0.8 * pos_len), int(0.9 * pos_len)])
+     neg_train, neg_val, neg_test = np.split(df_neg.sample(frac=1), [int(0.8 * neg_len), int(0.9 * neg_len)])
+     pos_train[CropsColumns.SPLIT] = Split.TRAIN
+     pos_val[CropsColumns.SPLIT] = Split.VAL
+     pos_test[CropsColumns.SPLIT] = Split.TEST
+     neg_train[CropsColumns.SPLIT] = Split.TRAIN
+     neg_val[CropsColumns.SPLIT] = Split.VAL
+     neg_test[CropsColumns.SPLIT] = Split.TEST
+ 
+     df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val, pos_test, neg_test), axis=0)
+     df_with_splits.to_csv(crops_dataset, header=True, index=False)
+ 
+ 
+ if __name__ == "__main__":
+     create_crops_dataset(CROPS_PATH, PT_DATASET, CROPS_DATASET)
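The negative-sampling loop above rejects any random pixel whose value in the Gaussian support map is non-zero, i.e. any pixel inside a kernel footprint around a labelled atom. A condensed sketch of the same rejection-sampling idea (illustrative only; the coordinates are made up):

    import numpy as np

    support = np.zeros((1024, 1024))
    for x, y in [(100, 200), (512, 512)]:          # hypothetical atom coordinates
        support[max(0, x - 10):x + 11, max(0, y - 10):y + 11] = 1.0  # kernel footprint

    rng = np.random.default_rng(777)
    x, y = rng.integers(0, 1024, 2)
    while support[x, y] != 0:                      # resample until clear of all atoms
        x, y = rng.integers(0, 1024, 2)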
atoms_detection/create_crop_dataset_2048.py ADDED
@@ -0,0 +1,197 @@
+ import os
+ 
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ 
+ from atoms_detection.image_preprocessing import dl_prepro_image
+ from atoms_detection.dataset import CoordinatesDataset
+ from utils.paths import CROPS_PATH, CROPS_DATASET, PT_DATASET
+ from utils.constants import Split, CropsColumns
+ 
+ 
+ np.random.seed(777)
+ 
+ window_size = (21, 21)
+ halfx_window = ((window_size[0] - 1) // 2)
+ halfy_window = ((window_size[1] - 1) // 2)
+ 
+ 
+ def get_gaussian_kernel(size=33, mean=0, sigma=0.2):
+     # Initializing value of x-axis and y-axis
+     # in the range -1 to 1
+     x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size))
+     dst = np.sqrt(x * x + y * y)
+ 
+     # Calculating Gaussian array
+     kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2)))
+     return kernel
+ 
+ 
+ def generate_support_img(coordinates, window_size):
+     support_img = np.zeros((2048, 2048))
+     kernel = get_gaussian_kernel(size=window_size[0])
+     halfx_window = ((window_size[0] - 1) // 2)
+     halfy_window = ((window_size[1] - 1) // 2)
+     for x, y in coordinates:
+         x_range = (x - halfx_window, x + halfx_window + 1)
+         y_range = (y - halfy_window, y + halfy_window + 1)
+ 
+         x_diff = [0, 0]
+         y_diff = [0, 0]
+         if x_range[0] < 0:
+             x_diff[0] = 0 - x_range[0]
+         if x_range[1] > 2048:
+             x_diff[1] = x_range[1] - 2048
+         if y_range[0] < 0:
+             y_diff[0] = 0 - y_range[0]
+         if y_range[1] > 2048:
+             y_diff[1] = y_range[1] - 2048
+ 
+         x_diff = tuple(int(item) for item in x_diff)
+         y_diff = tuple(int(item) for item in y_diff)
+ 
+         real_kernel = kernel[x_diff[0]:window_size[0] - x_diff[1], y_diff[0]:window_size[1] - y_diff[1]]
+ 
+         real_x_crop = (x_range[0] + x_diff[0], x_range[1] - x_diff[1])
+         real_y_crop = (y_range[0] + y_diff[0], y_range[1] - y_diff[1])
+ 
+         real_x_crop = tuple(int(item) for item in real_x_crop)
+         real_y_crop = tuple(int(item) for item in real_y_crop)
+ 
+         support_img[real_x_crop[0]:real_x_crop[1], real_y_crop[0]:real_y_crop[1]] += real_kernel
+ 
+     support_img = support_img.T
+     return support_img
+ 
+ 
+ def open_image(img_filename):
+     img = Image.open(img_filename)
+     np_img = np.asarray(img).astype(np.float32)
+     np_img = dl_prepro_image(np_img)
+     img = Image.fromarray(np_img)
+     return img
+ 
+ 
+ def create_crop(img: Image, x_center: int, y_center: int):
+     crop_coords = (
+         x_center - halfx_window,
+         y_center - halfy_window,
+         x_center + halfx_window + 1,
+         y_center + halfy_window + 1
+     )
+     crop = img.crop(crop_coords)
+     return crop
+ 
+ 
+ def create_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str):
+     if not os.path.exists(crops_folder):
+         os.makedirs(crops_folder)
+ 
+     crop_name_list = []
+     orig_name_list = []
+     x_list = []
+     y_list = []
+     label_list = []
+ 
+     n_positives = 0
+     label = 1
+     dataset = CoordinatesDataset(coords_csv)
+     print('Creating positive crops...')
+     for data_filename, label_filename in dataset.iterate_data(Split.TRAIN):
+         if label_filename is None:
+             continue
+ 
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+ 
+         img = open_image(data_filename)
+         coordinates = dataset.load_coordinates(label_filename)
+ 
+         for x_center, y_center in coordinates:
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+ 
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+ 
+             n_positives += 1
+ 
+     label = 0
+     no_train_images = dataset.split_length(Split.TRAIN)
+     neg_crops_per_image = [n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0) for x in range(no_train_images)]
+     print('Creating negative crops...')
+     for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), neg_crops_per_image):
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+         img = open_image(data_filename)
+ 
+         if label_filename:
+             coordinates = dataset.load_coordinates(label_filename)
+             support_map = generate_support_img(coordinates, window_size)
+         else:
+             support_map = None
+ 
+         for _ in range(no_neg_crops):
+             x_rand = np.random.randint(0, 2048)
+             y_rand = np.random.randint(0, 2048)
+ 
+             if support_map is not None:
+                 while support_map[x_rand, y_rand] != 0:
+                     x_rand = np.random.randint(0, 2048)
+                     y_rand = np.random.randint(0, 2048)
+ 
+             x_center, y_center = x_rand, y_rand
+ 
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+ 
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+ 
+     df_data = {
+         CropsColumns.FILENAME: crop_name_list,
+         CropsColumns.ORIGINAL: orig_name_list,
+         CropsColumns.X: x_list,
+         CropsColumns.Y: y_list,
+         CropsColumns.LABEL: label_list
+     }
+     df = pd.DataFrame(df_data, columns=[
+         CropsColumns.FILENAME,
+         CropsColumns.ORIGINAL,
+         CropsColumns.X,
+         CropsColumns.Y,
+         CropsColumns.LABEL
+     ])
+ 
+     df_pos = df[df.Label == 1]
+     df_neg = df[df.Label == 0]
+ 
+     pos_len = len(df_pos)
+     neg_len = len(df_neg)
+ 
+     pos_train, pos_val, pos_test = np.split(df_pos.sample(frac=1), [int(0.8 * pos_len), int(0.9 * pos_len)])
+     neg_train, neg_val, neg_test = np.split(df_neg.sample(frac=1), [int(0.8 * neg_len), int(0.9 * neg_len)])
+     pos_train[CropsColumns.SPLIT] = Split.TRAIN
+     pos_val[CropsColumns.SPLIT] = Split.VAL
+     pos_test[CropsColumns.SPLIT] = Split.TEST
+     neg_train[CropsColumns.SPLIT] = Split.TRAIN
+     neg_val[CropsColumns.SPLIT] = Split.VAL
+     neg_test[CropsColumns.SPLIT] = Split.TEST
+ 
+     df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val, pos_test, neg_test), axis=0)
+     df_with_splits.to_csv(crops_dataset, header=True, index=False)
+ 
+ 
+ if __name__ == "__main__":
+     create_crops_dataset(CROPS_PATH, PT_DATASET, CROPS_DATASET)
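The `neg_crops_per_image` expression distributes `n_positives` negative crops as evenly as possible across the training images, giving the first `n_positives % no_train_images` images one extra crop. An equivalent check (illustrative, with toy numbers):

    n_positives, n_images = 10, 3
    quota = [n_positives // n_images + (1 if i < n_positives % n_images else 0)
             for i in range(n_images)]
    assert quota == [4, 3, 3] and sum(quota) == n_positives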
atoms_detection/create_crop_dataset_512.py ADDED
@@ -0,0 +1,190 @@
+ import os
+ 
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ 
+ from atoms_detection.image_preprocessing import dl_prepro_image
+ from atoms_detection.dataset import CoordinatesDataset
+ from utils.paths import CROPS_PATH, CROPS_DATASET, PT_DATASET
+ from utils.constants import Split, CropsColumns
+ 
+ 
+ np.random.seed(777)
+ 
+ window_size = (21, 21)
+ halfx_window = ((window_size[0] - 1) // 2)
+ halfy_window = ((window_size[1] - 1) // 2)
+ 
+ 
+ def get_gaussian_kernel(size=21, mean=0, sigma=0.2):
+     # Initializing value of x-axis and y-axis
+     # in the range -1 to 1
+     x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size))
+     dst = np.sqrt(x * x + y * y)
+ 
+     # Calculating Gaussian array
+     kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2)))
+     return kernel
+ 
+ 
+ def generate_support_img(coordinates, window_size):
+     support_img = np.zeros((512, 512))
+     kernel = get_gaussian_kernel(size=window_size[0])
+     halfx_window = ((window_size[0] - 1) // 2)
+     halfy_window = ((window_size[1] - 1) // 2)
+     for x, y in coordinates:
+         x_range = (x - halfx_window, x + halfx_window + 1)
+         y_range = (y - halfy_window, y + halfy_window + 1)
+ 
+         x_diff = [0, 0]
+         y_diff = [0, 0]
+         if x_range[0] < 0:
+             x_diff[0] = 0 - x_range[0]
+         if x_range[1] > 512:
+             x_diff[1] = x_range[1] - 512
+         if y_range[0] < 0:
+             y_diff[0] = 0 - y_range[0]
+         if y_range[1] > 512:
+             y_diff[1] = y_range[1] - 512
+ 
+         real_kernel = kernel[x_diff[0]:window_size[0] - x_diff[1], y_diff[0]:window_size[1] - y_diff[1]]
+         real_x_crop = (x_range[0] + x_diff[0], x_range[1] - x_diff[1])
+         real_y_crop = (y_range[0] + y_diff[0], y_range[1] - y_diff[1])
+ 
+         support_img[real_x_crop[0]:real_x_crop[1], real_y_crop[0]:real_y_crop[1]] += real_kernel
+ 
+     support_img = support_img.T
+     return support_img
+ 
+ 
+ def open_image(img_filename):
+     img = Image.open(img_filename)
+     np_img = np.asarray(img).astype(np.float32)
+     np_img = dl_prepro_image(np_img)
+     img = Image.fromarray(np_img)
+     return img
+ 
+ 
+ def create_crop(img: Image, x_center: int, y_center: int):
+     crop_coords = (
+         x_center - halfx_window,
+         y_center - halfy_window,
+         x_center + halfx_window + 1,
+         y_center + halfy_window + 1
+     )
+     crop = img.crop(crop_coords)
+     return crop
+ 
+ 
+ def create_crops_dataset(crops_folder: str, coords_csv: str, crops_dataset: str):
+     if not os.path.exists(crops_folder):
+         os.makedirs(crops_folder)
+ 
+     crop_name_list = []
+     orig_name_list = []
+     x_list = []
+     y_list = []
+     label_list = []
+ 
+     n_positives = 0
+     label = 1
+     dataset = CoordinatesDataset(coords_csv)
+     print('Creating positive crops...')
+     for data_filename, label_filename in dataset.iterate_data(Split.TRAIN):
+         if label_filename is None:
+             continue
+ 
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+ 
+         img = open_image(data_filename)
+         coordinates = dataset.load_coordinates(label_filename)
+ 
+         for x_center, y_center in coordinates:
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+ 
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+ 
+             n_positives += 1
+ 
+     label = 0
+     no_train_images = dataset.split_length(Split.TRAIN)
+     neg_crops_per_image = [n_positives // no_train_images + (1 if x < n_positives % no_train_images else 0) for x in range(no_train_images)]
+     print('Creating negative crops...')
+     for (data_filename, label_filename), no_neg_crops in zip(dataset.iterate_data(Split.TRAIN), neg_crops_per_image):
+         print(data_filename)
+         orig_img_name = os.path.basename(data_filename)
+         img_name = os.path.splitext(orig_img_name)[0]
+         img = open_image(data_filename)
+ 
+         if label_filename:
+             coordinates = dataset.load_coordinates(label_filename)
+             support_map = generate_support_img(coordinates, window_size)
+         else:
+             support_map = None
+ 
+         for _ in range(no_neg_crops):
+             x_rand = np.random.randint(0, 512)
+             y_rand = np.random.randint(0, 512)
+ 
+             if support_map is not None:
+                 while support_map[x_rand, y_rand] != 0:
+                     x_rand = np.random.randint(0, 512)
+                     y_rand = np.random.randint(0, 512)
+ 
+             x_center, y_center = x_rand, y_rand
+ 
+             crop = create_crop(img, x_center, y_center)
+             crop_name = "{}_{}_{}.tif".format(img_name, x_center, y_center)
+             crop.save(os.path.join(crops_folder, crop_name))
+ 
+             crop_name_list.append(crop_name)
+             orig_name_list.append(orig_img_name)
+             x_list.append(x_center)
+             y_list.append(y_center)
+             label_list.append(label)
+ 
+     df_data = {
+         CropsColumns.FILENAME: crop_name_list,
+         CropsColumns.ORIGINAL: orig_name_list,
+         CropsColumns.X: x_list,
+         CropsColumns.Y: y_list,
+         CropsColumns.LABEL: label_list
+     }
+     df = pd.DataFrame(df_data, columns=[
+         CropsColumns.FILENAME,
+         CropsColumns.ORIGINAL,
+         CropsColumns.X,
+         CropsColumns.Y,
+         CropsColumns.LABEL
+     ])
+ 
+     df_pos = df[df.Label == 1]
+     df_neg = df[df.Label == 0]
+ 
+     pos_len = len(df_pos)
+     neg_len = len(df_neg)
+ 
+     pos_train, pos_val, pos_test = np.split(df_pos.sample(frac=1), [int(0.8 * pos_len), int(0.9 * pos_len)])
+     neg_train, neg_val, neg_test = np.split(df_neg.sample(frac=1), [int(0.8 * neg_len), int(0.9 * neg_len)])
+     pos_train[CropsColumns.SPLIT] = Split.TRAIN
+     pos_val[CropsColumns.SPLIT] = Split.VAL
+     pos_test[CropsColumns.SPLIT] = Split.TEST
+     neg_train[CropsColumns.SPLIT] = Split.TRAIN
+     neg_val[CropsColumns.SPLIT] = Split.VAL
+     neg_test[CropsColumns.SPLIT] = Split.TEST
+ 
+     df_with_splits = pd.concat((pos_train, neg_train, pos_val, neg_val, pos_test, neg_test), axis=0)
+     df_with_splits.to_csv(crops_dataset, header=True, index=False)
+ 
+ 
+ if __name__ == "__main__":
+     create_crops_dataset(CROPS_PATH, PT_DATASET, CROPS_DATASET)
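`generate_support_img` clips the kernel when an atom sits near the border: the out-of-bounds excess on each side is trimmed from both the kernel and the destination slice, so the two shapes always match. A tiny worked case of the same clipping (illustrative, with made-up values):

    import numpy as np

    size, half = 21, 10
    kernel = np.ones((size, size))
    x = 3                                # atom 3 px from the top edge
    x_lo, x_hi = x - half, x + half + 1  # (-7, 14): 7 rows fall outside
    cut = max(0, -x_lo)
    dest = np.zeros((512, 512))
    dest[x_lo + cut:x_hi, :size] += kernel[cut:, :]   # shapes agree: (14, 21)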
atoms_detection/cv_detection.py ADDED
@@ -0,0 +1,31 @@
+ import cv2
+ import numpy as np
+ 
+ from atoms_detection.image_preprocessing import dl_prepro_image
+ from atoms_detection.detection import Detection
+ 
+ 
+ class CVDetection(Detection):
+ 
+     @staticmethod
+     def get_gaussian_kernel(size=21, mean=0, sigma=0.22, offset=0.0):
+         # Initializing value of x-axis and y-axis
+         # in the range -1 to 1
+         x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size))
+         dst = np.sqrt(x * x + y * y)
+         # Calculating Gaussian array
+         kernel = np.exp(-((dst - mean) ** 2 / (2.0 * sigma ** 2))) - offset
+         return kernel
+ 
+     def filter_image(self, img_arr: np.ndarray, **kwargs):
+         gauss_kernel = self.get_gaussian_kernel(**kwargs)
+         max_kernel_value = gauss_kernel.flatten().sum()
+         filtered_img = cv2.filter2D(img_arr, -1, gauss_kernel)
+         filtered_img /= max_kernel_value
+         return filtered_img
+ 
+     def image_to_pred_map(self, img: np.ndarray, img_filename=None) -> np.ndarray:
+         prepro_img = dl_prepro_image(img)
+         filtered_img = self.filter_image(prepro_img)
+         filtered_img = filtered_img.transpose()
+         return filtered_img
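Dividing the `cv2.filter2D` response by the kernel's total weight turns it into a weighted local mean, so on an image normalised to [0, 1] the prediction map also stays in [0, 1] and a fixed threshold is meaningful. A quick numeric check of that bound (illustrative sketch using scipy's correlation instead of OpenCV):

    import numpy as np
    from scipy.ndimage import correlate

    img = np.random.rand(64, 64).astype(np.float32)   # stand-in for a normalised image
    k = np.ones((5, 5), dtype=np.float32)             # stand-in for the Gaussian kernel
    resp = correlate(img, k) / k.sum()                # weighted local mean
    assert 0.0 <= resp.min() and resp.max() <= 1.0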
atoms_detection/cv_fe_detection_evaluation.py ADDED
@@ -0,0 +1,37 @@
+ import os
+ 
+ from atoms_detection.cv_detection import CVDetection
+ from atoms_detection.evaluation import Evaluation
+ from utils.paths import CROPS_PATH, CROPS_DATASET, MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH, DATASET_PATH
+ from utils.constants import ModelArgs
+ 
+ 
+ extension_name = "trial"
+ threshold = 0.21
+ architecture = ModelArgs.BASICCNN
+ ckpt_filename = os.path.join(MODELS_PATH, "basic_replicate.ckpt")
+ dataset_csv = os.path.join(DATASET_PATH, "Fe_dataset.csv")
+ 
+ 
+ inference_cache_path = os.path.join(PREDS_PATH, f"cv_fe_detection_{extension_name}")
+ 
+ for threshold in [0.1, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25]:
+     detections_path = os.path.join(DETECTION_PATH, f"cv_fe_detection_{extension_name}",
+                                    f"cv_fe_detection_{extension_name}_{threshold}")
+     print(f"Detecting atoms on test data with threshold={threshold}...")
+     detection = CVDetection(
+         dataset_csv=dataset_csv,
+         threshold=threshold,
+         detections_path=detections_path,
+         inference_cache_path=inference_cache_path
+     )
+     detection.run()
+ 
+     logging_filename = os.path.join(LOGS_PATH, f"cv_fe_evaluation_{extension_name}",
+                                     f"cv_fe_evaluation_{extension_name}_{threshold}.csv")
+     evaluation = Evaluation(
+         coords_csv=dataset_csv,
+         predictions_path=detections_path,
+         logging_filename=logging_filename
+     )
+     evaluation.run()
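Because every iteration of the sweep shares the same `inference_cache_path`, the expensive filtering runs once per image and only the thresholding is repeated. The cost asymmetry is easy to see in isolation (illustrative sketch with a random stand-in for a cached prediction map):

    import numpy as np

    pred_map = np.random.rand(512, 512)           # stand-in for a cached prediction map
    for t in (0.18, 0.21, 0.24):
        print(t, int((pred_map > t).sum()))       # detections shrink as t grows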
atoms_detection/cv_full_pipeline.py ADDED
@@ -0,0 +1,71 @@
+ from typing import List
+ 
+ import argparse
+ import os
+ 
+ from atoms_detection.cv_detection import CVDetection
+ from atoms_detection.evaluation import Evaluation
+ from utils.paths import LOGS_PATH, DETECTION_PATH, PREDS_PATH
+ 
+ 
+ def cv_full_pipeline(
+     extension_name: str,
+     coords_csv: str,
+     thresholds_list: List[float],
+     force: bool = False
+ ):
+ 
+     # CV detection & evaluation
+     for threshold in thresholds_list:
+         inference_cache_path = os.path.join(PREDS_PATH, f"cv_detection_{extension_name}")
+         detections_path = os.path.join(DETECTION_PATH, f"cv_detection_{extension_name}_{threshold}")
+         if force or not os.path.exists(detections_path):
+             print(f"Detecting atoms on test data with threshold={threshold}...")
+             detection = CVDetection(
+                 dataset_csv=coords_csv,
+                 threshold=threshold,
+                 detections_path=detections_path,
+                 inference_cache_path=inference_cache_path
+             )
+             detection.run()
+ 
+         logging_filename = os.path.join(LOGS_PATH, f"cv_detection_{extension_name}_{threshold}.csv")
+         if force or not os.path.exists(logging_filename):
+             evaluation = Evaluation(
+                 coords_csv=coords_csv,
+                 predictions_path=detections_path,
+                 logging_filename=logging_filename
+             )
+             evaluation.run()
+ 
+ 
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "extension_name",
+         type=str,
+         help="Experiment extension name"
+     )
+     parser.add_argument(
+         "coords_csv",
+         type=str,
+         help="Coordinates CSV file to use as input"
+     )
+     parser.add_argument(
+         "-t",
+         "--thresholds",
+         nargs="+",
+         type=float,
+         help="Detection threshold values to sweep"
+     )
+     parser.add_argument(
+         "--force",
+         action="store_true"
+     )
+     return parser.parse_args()
+ 
+ 
+ if __name__ == "__main__":
+     args = get_args()
+     print(args)
+     cv_full_pipeline(args.extension_name, args.coords_csv, args.thresholds, args.force)
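The pipeline can also be driven directly from Python rather than the CLI. A minimal sketch, assuming the package is importable and using placeholder file names:

    from atoms_detection.cv_full_pipeline import cv_full_pipeline

    cv_full_pipeline(
        extension_name="trial",
        coords_csv="data/Fe_dataset.csv",   # placeholder path
        thresholds_list=[0.18, 0.21, 0.24],
    )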
atoms_detection/dataset.py ADDED
@@ -0,0 +1,265 @@
+ from typing import List, Optional, Tuple
+ import os
+ import glob
+ 
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ from scipy.ndimage import gaussian_filter, median_filter, rank_filter
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ 
+ from utils.constants import Split, Columns, CropsColumns, ProbsColumns
+ from utils.paths import CROPS_DATASET, CROPS_PATH, COORDS_PATH, IMG_PATH, PROBS_DATASET, PROBS_PATH, HAADF_DATASET, PT_DATASET
+ 
+ 
+ class ImageClassificationDataset(Dataset):
+ 
+     def __init__(self, image_paths, image_labels, include_filename=False):
+         self.image_paths = image_paths
+         self.image_labels = image_labels
+         self.include_filename = include_filename
+         self.transform = transforms.Compose([
+             transforms.ToTensor()
+             # transforms.Normalize(mean=[0.5], std=[0.5])
+         ])
+ 
+     def get_n_labels(self):
+         return len(set(self.image_labels))
+ 
+     def __len__(self):
+         return len(self.image_paths)
+ 
+     @staticmethod
+     def load_image(img_filename):
+         img = Image.open(img_filename)
+         np_img = np.asarray(img).astype(np.float32)
+         np_bg = median_filter(np_img, size=(40, 40))
+         np_clean = np_img - np_bg
+         np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min())
+         return np_normed
+ 
+     def __getitem__(self, idx):
+         img_path = self.image_paths[idx]
+         image = self.load_image(img_path)
+         image = self.transform(image)
+         label = self.image_labels[idx]
+ 
+         if self.include_filename:
+             return image, label, os.path.basename(img_path)
+         else:
+             return image, label
+ 
+     @staticmethod
+     def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
+         raise NotImplementedError
+ 
+     @classmethod
+     def train_dataset(cls, **kwargs):
+         filenames, labels = cls.get_filenames_labels(Split.TRAIN)
+         return cls(filenames, labels, **kwargs)
+ 
+     @classmethod
+     def val_dataset(cls, **kwargs):
+         filenames, labels = cls.get_filenames_labels(Split.VAL)
+         return cls(filenames, labels, **kwargs)
+ 
+     @classmethod
+     def test_dataset(cls, **kwargs):
+         filenames, labels = cls.get_filenames_labels(Split.TEST)
+         return cls(filenames, labels, **kwargs)
+ 
+ 
+ class HaadfDataset(ImageClassificationDataset):
+     @staticmethod
+     def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
+         df = pd.read_csv(HAADF_DATASET)
+         split_df = df[df[Columns.SPLIT] == split]
+         filenames = (IMG_PATH + os.sep + split_df[Columns.FILENAME]).to_list()
+         labels = (split_df[Columns.LABEL]).to_list()
+         return filenames, labels
+ 
+ 
+ class ImageDataset:
+     FILENAME_COL = "Filename"
+     SPLIT_COL = "Split"
+     RULER_UNITS = "Ruler Units"
+ 
+     def __init__(self, dataset_csv: str):
+         self.df = pd.read_csv(dataset_csv)
+ 
+     def iterate_data(self, split: Split):
+         df = self.df[self.df[self.SPLIT_COL] == split]
+         for idx, row in df.iterrows():
+             image_filename = os.path.join(IMG_PATH, row[self.FILENAME_COL])
+             yield image_filename
+ 
+     def get_ruler_units_by_img_name(self, name):
+         print(name)
+         return self.df[self.df[self.FILENAME_COL] == name][self.RULER_UNITS].values[0]
+ 
+ 
+ class CoordinatesDataset:
+     FILENAME_COL = "Filename"
+     COORDS_COL = "Coords"
+     SPLIT_COL = "Split"
+ 
+     def __init__(self, coord_image_csv: str):
+         self.df = pd.read_csv(coord_image_csv)
+ 
+     def iterate_data(self, split: Split):
+         df = self.df[self.df[self.SPLIT_COL] == split]
+         for idx, row in df.iterrows():
+             image_filename = os.path.join(IMG_PATH, row[self.FILENAME_COL])
+             if isinstance(row[self.COORDS_COL], str):
+                 coords_filename = os.path.join(COORDS_PATH, row[self.COORDS_COL])
+             else:
+                 coords_filename = None
+             yield image_filename, coords_filename
+ 
+     @staticmethod
+     def load_coordinates(label_filename: str) -> List[Tuple[int, int]]:
+         atom_coordinates = pd.read_csv(label_filename)
+         return list(zip(atom_coordinates['X'], atom_coordinates['Y']))
+ 
+     def split_length(self, split: Split):
+         df = self.df[self.df[self.SPLIT_COL] == split]
+         return len(df)
+ 
+ 
+ class HaadfCoordinates(CoordinatesDataset):
+     def __init__(self):
+         super().__init__(coord_image_csv=PT_DATASET)
+ 
+ 
+ class CropsDataset(ImageClassificationDataset):
+     @staticmethod
+     def get_filenames_labels(split: Split):
+         df = pd.read_csv(CROPS_DATASET)
+         split_df = df[df[CropsColumns.SPLIT] == split]
+         filenames = (CROPS_PATH + os.sep + split_df[CropsColumns.FILENAME]).to_list()
+         labels = (split_df[CropsColumns.LABEL]).to_list()
+         return filenames, labels
+ 
+ 
+ class CropsCustomDataset(ImageClassificationDataset):
+ 
+     @staticmethod
+     def get_filenames_labels(split: Split, crops_dataset: str, crops_path: str):
+         df = pd.read_csv(crops_dataset)
+         split_df = df[df[CropsColumns.SPLIT] == split]
+         filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list()
+         labels = (split_df[CropsColumns.LABEL]).to_list()
+         return filenames, labels
+ 
+ 
+ class ProbsDataset(ImageClassificationDataset):
+     @staticmethod
+     def get_filenames_labels(split: Split):
+         df = pd.read_csv(PROBS_DATASET)
+         split_df = df[df[ProbsColumns.SPLIT] == split]
+         filenames = (PROBS_PATH + os.sep + split_df[ProbsColumns.FILENAME]).to_list()
+         labels = (split_df[ProbsColumns.LABEL]).to_list()
+         return filenames, labels
+ 
+ 
+ class SlidingCropDataset(Dataset):
+ 
+     def __init__(self, tif_filename, include_coords=True):
+         self.filename = tif_filename
+         self.include_coords = include_coords
+         self.transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.5], std=[0.5])
+         ])
+ 
+         self.n_labels = 2
+         self.step_size = 2
+         self.window_size = (21, 21)
+         self.loaded_crops = []
+         self.loaded_coords = []
+         self.load_crops()
+ 
+     def sliding_window(self, image):
+         # Slide a window across the image
+         for x in range(0, image.shape[0] - self.window_size[0], self.step_size):
+             for y in range(0, image.shape[1] - self.window_size[1], self.step_size):
+                 # Yield the current window and its center coordinates
+                 center_x = x + ((self.window_size[0] - 1) // 2)
+                 center_y = y + ((self.window_size[1] - 1) // 2)
+                 yield center_x, center_y, image[x:x + self.window_size[0], y:y + self.window_size[1]]
+ 
+     @staticmethod
+     def load_image(img_filename):
+         img = Image.open(img_filename)
+         np_img = np.asarray(img).astype(np.float32)
+         np_bg = median_filter(np_img, size=(40, 40))
+         np_clean = np_img - np_bg
+         np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min())
+         return np_normed
+ 
+     def load_crops(self):
+         img = self.load_image(self.filename)
+         for x_center, y_center, img_crop in self.sliding_window(img):
+             self.loaded_crops.append(img_crop)
+             self.loaded_coords.append((x_center, y_center))
+ 
+     def get_n_labels(self):
+         return self.n_labels
+ 
+     def __len__(self):
+         return len(self.loaded_crops)
+ 
+     def __getitem__(self, idx):
+         crop = self.loaded_crops[idx]
+         x, y = self.loaded_coords[idx]
+         crop = self.transform(crop)
+ 
+         return crop, x, y
+ 
+ 
+ def get_image_path_without_coords(split: Optional[str] = None):
+     coords_prefix_set = set()
+     for coords_name in os.listdir(COORDS_PATH):
+         coord_prefix = coords_name.split('_')[0]
+         coords_prefix_set.add(coord_prefix)
+ 
+     all_prefixes_set = set()
+     for tif_name in os.listdir(IMG_PATH):
+         coord_prefix = tif_name.split('_')[0]
+         all_prefixes_set.add(coord_prefix)
+ 
+     if split == Split.TRAIN:
+         missing_prefixes = coords_prefix_set
+     elif split == Split.TEST:
+         missing_prefixes = all_prefixes_set - coords_prefix_set
+     elif split is None:
+         missing_prefixes = all_prefixes_set
+     else:
+         raise ValueError
+     tif_filenames_list = []
+     labels_list = []
+     for prefix in missing_prefixes:
+         filename_matches = glob.glob(os.path.join(IMG_PATH, f'{prefix}_HAADF*NC*'))
+         if len(filename_matches) == 0:
+             continue
+         pos_filenames = [filename for filename in filename_matches if '_PtNC' in filename]
+         neg_filenames = [filename for filename in filename_matches if '_NC' in filename]
+ 
+         if len(pos_filenames) > 0:
+             pos_filename = sorted(pos_filenames)[-1]
+             tif_filenames_list.append(pos_filename)
+             labels_list.append(1)
+         if len(neg_filenames) > 0:
+             neg_filename = sorted(neg_filenames)[-1]
+             tif_filenames_list.append(neg_filename)
+             labels_list.append(0)
+ 
+     return tif_filenames_list, labels_list
+ 
+ 
+ if __name__ == "__main__":
+     # get_image_path_without_coords returns (filenames, labels)
+     filenames_list, labels_list = get_image_path_without_coords()
+     filename = filenames_list[0]
+     dataset = SlidingCropDataset(filename)
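`load_image` removes slow background variation with a 40x40 median filter before min-max normalisation. The same preprocessing in isolation (illustrative sketch; the random array stands in for a HAADF frame):

    import numpy as np
    from scipy.ndimage import median_filter

    raw = np.random.rand(256, 256).astype(np.float32)    # stand-in for a HAADF frame
    background = median_filter(raw, size=(40, 40))       # slow spatial component
    clean = raw - background
    normed = (clean - clean.min()) / (clean.max() - clean.min())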
atoms_detection/detection.py ADDED
@@ -0,0 +1,96 @@
+ from typing import Tuple, List
+ 
+ import os
+ from hashlib import sha1
+ 
+ import numpy as np
+ from PIL import Image
+ from scipy.ndimage import label
+ 
+ from utils.constants import Split
+ from utils.paths import PREDS_PATH
+ from atoms_detection.dataset import ImageDataset
+ 
+ 
+ class Detection:
+     def __init__(self, dataset_csv: str, threshold: float, detections_path: str, inference_cache_path: str):
+         self.image_dataset = ImageDataset(dataset_csv)
+         self.threshold = threshold
+         self.detections_path = detections_path
+         self.inference_cache_path = inference_cache_path
+         self.currently_processing = None
+         if not os.path.exists(self.detections_path):
+             os.makedirs(self.detections_path)
+         if not os.path.exists(self.inference_cache_path):
+             os.makedirs(self.inference_cache_path)
+ 
+     def image_to_pred_map(self, img: np.ndarray) -> np.ndarray:
+         raise NotImplementedError
+ 
+     def pred_map_to_atoms(self, pred_map: np.ndarray) -> Tuple[List[Tuple[int, int]], List[float]]:
+         pred_mask = pred_map > self.threshold
+         labeled_array, num_features = label(pred_mask)
+ 
+         # Convert the labelled array into centre coordinates and likelihoods
+         center_coords_list = []
+         likelihood_list = []
+         for label_idx in range(num_features + 1):
+             if label_idx == 0:
+                 continue
+             label_mask = np.where(labeled_array == label_idx)
+             likelihood = np.max(pred_map[label_mask])
+             likelihood_list.append(likelihood)
+             # label_size = len(label_mask[0])
+             # print(f"\t\tAtom {label_idx}: {label_size}")
+             atom_bbox = (label_mask[1].min(), label_mask[1].max(), label_mask[0].min(), label_mask[0].max())
+             center_coord = self.bbox_to_center_coords(atom_bbox)
+             center_coords_list.append(center_coord)
+         return center_coords_list, likelihood_list
+ 
+     def detect_atoms(self, img_filename: str) -> Tuple[List[Tuple[int, int]], List[float]]:
+         img_hash = self.cache_image_identifier(img_filename)
+         prediction_cache = os.path.join(self.inference_cache_path, f"{img_hash}.npy")
+         if not os.path.exists(prediction_cache):
+             self.currently_processing = os.path.split(img_filename)[-1]
+             img = self.open_image(img_filename)
+             pred_map = self.image_to_pred_map(img)
+             np.save(prediction_cache, pred_map)
+         else:
+             pred_map = np.load(prediction_cache)
+         center_coords_list, likelihood_list = self.pred_map_to_atoms(pred_map)
+         return center_coords_list, likelihood_list
+ 
+     def cache_image_identifier(self, img_filename):
+         return sha1(img_filename.encode()).hexdigest()
+ 
+     @staticmethod
+     def bbox_to_center_coords(bbox: Tuple[int, int, int, int]) -> Tuple[int, int]:
+         x_center = (bbox[0] + bbox[1]) // 2
+         y_center = (bbox[2] + bbox[3]) // 2
+         return x_center, y_center
+ 
+     @staticmethod
+     def open_image(img_filename: str):
+         img = Image.open(img_filename)
+         np_img = np.asarray(img).astype(np.float32)
+         return np_img
+ 
+     def run_single(self, image_path: str):
+         print(f"Running detection on {os.path.basename(image_path)}")
+         center_coords_list, likelihood_list = self.detect_atoms(image_path)
+ 
+         image_filename = os.path.basename(image_path)
+         img_name = os.path.splitext(image_filename)[0]
+         detection_csv = os.path.join(self.detections_path, f"{img_name}.csv")
+         with open(detection_csv, "w") as _csv:
+             _csv.write("Filename,x,y,Likelihood\n")
+             for (x, y), likelihood in zip(center_coords_list, likelihood_list):
+                 _csv.write(f"{image_filename},{x},{y},{likelihood}\n")
+         return center_coords_list, likelihood_list
+ 
+     def run(self):
+         if not os.path.exists(self.detections_path):
+             os.makedirs(self.detections_path)
+ 
+         for image_path in self.image_dataset.iterate_data(Split.TEST):
+             self.run_single(image_path)
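`pred_map_to_atoms` reduces each connected super-threshold region to the centre of its bounding box plus the region's peak probability. The same reduction on a toy mask (illustrative only):

    import numpy as np
    from scipy.ndimage import label

    pred_map = np.zeros((8, 8))
    pred_map[1:3, 1:3] = 0.9                      # one small blob
    labeled, n = label(pred_map > 0.5)
    ys, xs = np.where(labeled == 1)
    center = ((xs.min() + xs.max()) // 2, (ys.min() + ys.max()) // 2)
    assert center == (1, 1) and pred_map[ys, xs].max() == 0.9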
atoms_detection/dl_contrastive_pipeline.py ADDED
@@ -0,0 +1,198 @@
+ from typing import List
+ 
+ import argparse
+ import os
+ 
+ from atoms_detection.create_crop_dataset import create_contrastive_crops_dataset
+ from atoms_detection.dl_detection import DLDetection
+ from atoms_detection.dl_detection_with_gmm import DLGMMdetection
+ from atoms_detection.evaluation import Evaluation
+ from atoms_detection.training_model import train_model
+ from utils.paths import (
+     CROPS_PATH,
+     CROPS_DATASET,
+     MODELS_PATH,
+     LOGS_PATH,
+     DETECTION_PATH,
+     PREDS_PATH,
+     PRED_GT_VIS_PATH,
+ )
+ from utils.constants import ModelArgs, Split
+ from matplotlib import pyplot as plt
+ import pandas as pd
+ from PIL import Image
+ import numpy as np
+ 
+ from visualizations.prediction_gt_images import get_gt_coords
+ from visualizations.utils import plot_gt_pred_on_img
+ 
+ 
+ def dl_full_pipeline(
+     extension_name: str,
+     architecture: ModelArgs,
+     coords_csv: str,
+     thresholds_list: List[float],
+     force_create_dataset: bool = False,
+     force_evaluation: bool = False,
+     show_sampling_image: bool = False,
+     train: bool = False,
+     visualise: bool = False,
+     upsample: bool = False,
+     upsample_neg_amount: float = 0,
+     clip_max: float = 1,
+     negative_dist: float = 1.1,
+     run_gmm_for_multimers: bool = False,
+ ):
+     # Create crops data
+     crops_folder = CROPS_PATH + f"_{extension_name}"
+     crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv")
+     print(f"Crops dataset already exists: {os.path.exists(crops_dataset)}")
+     if force_create_dataset or not os.path.exists(crops_dataset):
+         print("Creating crops dataset...")
+         create_contrastive_crops_dataset(
+             crops_folder,
+             coords_csv,
+             crops_dataset,
+             show_sampling_result=show_sampling_image,
+             pos_data_upsampling=upsample,
+             neg_upsample_multiplier=upsample_neg_amount,
+             contrastive_distance_multiplier=negative_dist,
+         )  # , clip=clip_max
+     # Train the DL crops model
+     ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt")
+     if train or not os.path.exists(ckpt_filename):
+         print("Training DL crops model...")
+         train_model(architecture, crops_dataset, crops_folder, ckpt_filename)
+ 
+     for threshold in thresholds_list:
+         inference_cache_path = os.path.join(
+             PREDS_PATH, f"dl_detection_{extension_name}"
+         )
+         detections_path = os.path.join(
+             DETECTION_PATH,
+             f"dl_detection_{extension_name}",
+             f"dl_detection_{extension_name}_{threshold}",
+         )
+         if force_evaluation or visualise or not os.path.exists(detections_path):
+             print(f"Detecting atoms on test data with threshold={threshold}...")
+             if run_gmm_for_multimers:
+                 detection_class = DLGMMdetection
+             else:
+                 detection_class = DLDetection
+ 
+             detection = detection_class(
+                 model_name=architecture,
+                 ckpt_filename=ckpt_filename,
+                 dataset_csv=coords_csv,
+                 threshold=threshold,
+                 detections_path=detections_path,
+                 inference_cache_path=inference_cache_path,
+             )
+             detection.run()
+ 
+         logging_filename = os.path.join(
+             LOGS_PATH,
+             f"dl_evaluation_{extension_name}",
+             f"dl_evaluation_{extension_name}_{threshold}.csv",
+         )
+         if force_evaluation or visualise or not os.path.exists(logging_filename):
+             evaluation = Evaluation(
+                 coords_csv=coords_csv,
+                 predictions_path=detections_path,
+                 logging_filename=logging_filename,
+             )
+             evaluation.run()
+         if visualise:
+             vis_folder = os.path.join(
+                 PRED_GT_VIS_PATH, f"dl_detection_{extension_name}"
+             )
+             if not os.path.exists(vis_folder):
+                 os.makedirs(vis_folder)
+ 
+             vis_folder = os.path.join(
+                 vis_folder, f"dl_detection_{extension_name}_{threshold}"
+             )
+             if not os.path.exists(vis_folder):
+                 os.makedirs(vis_folder)
+             is_evaluation = True
+             if is_evaluation:
+                 gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset)
+ 
+             for image_path in detection.image_dataset.iterate_data(Split.TEST):
+                 img_name = os.path.split(image_path)[-1]
+                 gt_coords = gt_coords_dict[img_name] if is_evaluation else None
+                 pred_df_path = os.path.join(
+                     detections_path, os.path.splitext(img_name)[0] + ".csv"
+                 )
+                 df_predicted = pd.read_csv(pred_df_path)
+                 pred_coords = [
+                     (row["x"], row["y"]) for _, row in df_predicted.iterrows()
+                 ]
+                 img = Image.open(image_path)
+                 img_arr = np.array(img).astype(np.float32)
+                 img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())
+ 
+                 plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
+                 clean_image_name = os.path.splitext(img_name)[0]
+                 vis_path = os.path.join(vis_folder, f"{clean_image_name}.png")
+                 plt.savefig(
+                     vis_path, bbox_inches="tight", pad_inches=0.0, transparent=True
+                 )
+                 plt.close()
+ 
+ 
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("extension_name", type=str, help="Experiment extension name")
+     parser.add_argument(
+         "architecture", type=ModelArgs, choices=ModelArgs, help="Architecture name"
+     )
+     parser.add_argument(
+         "coords_csv", type=str, help="Coordinates CSV file to use as input"
+     )
+     parser.add_argument(
+         "-t", "--thresholds", nargs="+", type=float, help="Threshold values"
+     )
+     parser.add_argument(
+         "-c", type=float, default=1, help="Clipping quantile (0..1]. CURRENTLY UNUSED!"
+     )
+     parser.add_argument(
+         "-nd", type=float, default=1.1, help="Negative contrastive crop distance"
+     )
+     parser.add_argument("--force_create_dataset", action="store_true")
+     parser.add_argument("--force_evaluation", action="store_true")
+     parser.add_argument("--show_sampling_result", action="store_true")
+     parser.add_argument("--train", action="store_true")
+     parser.add_argument("--visualise", action="store_true")
+     parser.add_argument("--upsample", action="store_true")
+     parser.add_argument(
+         "--run_gmm_for_multimers",
+         action="store_true",
+         help="If selected, a postprocessing step splits large atoms (possible multimers) with a GMM",
+     )
+     parser.add_argument(
+         "--upsample_neg",
+         type=float,
+         default=0,
+         help="Upsample amount for negative crops during training",
+     )
+     return parser.parse_args()
+ 
+ 
+ if __name__ == "__main__":
+     args = get_args()
+     print(args)
+     dl_full_pipeline(
+         args.extension_name,
+         args.architecture,
+         args.coords_csv,
+         args.thresholds,
+         args.force_create_dataset,
+         args.force_evaluation,
+         args.show_sampling_result,
+         args.train,
+         args.visualise,
+         args.upsample,
+         args.upsample_neg,
+         args.c,
+         args.nd,
+         run_gmm_for_multimers=args.run_gmm_for_multimers,
+     )
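The visualisation branch normalises the raw image and overlays ground-truth and predicted coordinates via plot_gt_pred_on_img. Stripped to its core, such an overlay is just two scatter calls; a standalone sketch (illustrative, with made-up coordinates, not the commit's helper):

    import numpy as np
    from matplotlib import pyplot as plt

    img = np.random.rand(128, 128)
    gt = [(20, 30), (64, 64)]          # hypothetical ground-truth coordinates
    pred = [(22, 31), (100, 40)]       # hypothetical detections
    plt.imshow(img, cmap="gray")
    plt.scatter(*zip(*gt), marker="o", facecolors="none", edgecolors="lime")
    plt.scatter(*zip(*pred), marker="x", color="red")
    plt.savefig("overlay.png", bbox_inches="tight")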
atoms_detection/dl_detection.py ADDED
@@ -0,0 +1,118 @@
+ import os
+ from typing import Iterator, Tuple, List
+ 
+ import torch
+ import numpy as np
+ import torch.nn
+ import torch.nn.functional
+ 
+ from atoms_detection.detection import Detection
+ from atoms_detection.training_model import model_pipeline
+ from atoms_detection.image_preprocessing import dl_prepro_image
+ from utils.constants import ModelArgs
+ from utils.paths import PREDS_PATH
+ 
+ 
+ class DLDetection(Detection):
+     def __init__(self,
+                  model_name: ModelArgs,
+                  ckpt_filename: str,
+                  dataset_csv: str,
+                  threshold: float,
+                  detections_path: str,
+                  inference_cache_path: str,
+                  batch_size: int = 64,
+                  ):
+         self.model_name = model_name
+         self.ckpt_filename = ckpt_filename
+         self.device = self.get_torch_device()
+         self.batch_size = batch_size
+ 
+         self.stride = 1
+         self.padding = 10
+         self.window_size = (21, 21)
+ 
+         super().__init__(dataset_csv, threshold, detections_path, inference_cache_path)
+ 
+     @staticmethod
+     def get_torch_device():
+         if torch.backends.mps.is_available():
+             device = torch.device("mps")
+         elif torch.cuda.is_available():
+             device = torch.device("cuda")
+         else:
+             device = torch.device("cpu")
+         return device
+ 
+     def sliding_window(self, image: np.ndarray, padding: int = 0) -> Iterator[Tuple[int, int, np.ndarray]]:
+         # Slide a window across the image
+         x_to_center = self.window_size[0] // 2 - 1 if self.window_size[0] % 2 == 0 else self.window_size[0] // 2
+         y_to_center = self.window_size[1] // 2 - 1 if self.window_size[1] % 2 == 0 else self.window_size[1] // 2
+ 
+         for y in range(0, image.shape[0] - self.window_size[1] + 1, self.stride):
+             for x in range(0, image.shape[1] - self.window_size[0] + 1, self.stride):
+                 # Yield the current window and its centre in unpadded coordinates
+                 center_x = x + x_to_center
+                 center_y = y + y_to_center
+                 yield center_x - padding, center_y - padding, image[y:y + self.window_size[1], x:x + self.window_size[0]]
+ 
+     def batch_sliding_window(self, image: np.ndarray, padding: int = 0) -> Iterator[Tuple[List[int], List[int], List[np.ndarray]]]:
+         x_idx_list = []
+         y_idx_list = []
+         images_list = []
+         count = 0
+         for _x, _y, _img in self.sliding_window(image, padding=padding):
+             x_idx_list.append(_x)
+             y_idx_list.append(_y)
+             images_list.append(_img)
+             count += 1
+             if count == self.batch_size:
+                 yield x_idx_list, y_idx_list, images_list
+                 x_idx_list = []
+                 y_idx_list = []
+                 images_list = []
+                 count = 0
+         if count != 0:
+             yield x_idx_list, y_idx_list, images_list
+ 
+     def padding_image(self, img: np.ndarray) -> np.ndarray:
+         image_padded = np.zeros((img.shape[0] + self.padding * 2, img.shape[1] + self.padding * 2))
+         image_padded[self.padding:-self.padding, self.padding:-self.padding] = img
+         return image_padded
+ 
+     def load_model(self) -> torch.nn.Module:
+         checkpoint = torch.load(self.ckpt_filename, map_location=self.device)
+ 
+         model = model_pipeline[self.model_name](num_classes=2).to(self.device)
+         model.load_state_dict(checkpoint['state_dict'])
+         model.eval()
+         return model
+ 
+     def images_to_torch_input(self, images_list: List[np.ndarray]) -> torch.Tensor:
+         expanded_img = np.expand_dims(images_list, axis=1)
+         input_tensor = torch.from_numpy(expanded_img).float()
+         input_tensor = input_tensor.to(self.device)
+         return input_tensor
+ 
+     def get_prediction_map(self, padded_image: np.ndarray) -> np.ndarray:
+         _shape = padded_image.shape
+         pred_map = np.zeros((_shape[0] - self.padding * 2, _shape[1] - self.padding * 2))
+         model = self.load_model()
+         for x_idxs, y_idxs, image_crops in self.batch_sliding_window(padded_image, padding=self.padding):
+             torch_input = self.images_to_torch_input(image_crops)
+             output = model(torch_input)
+             pred_prob = torch.nn.functional.softmax(output, 1)
+             pred_prob = pred_prob.detach().cpu().numpy()[:, 1]
+             pred_map[np.array(y_idxs), np.array(x_idxs)] = pred_prob
+         return pred_map
+ 
+     def image_to_pred_map(self, img: np.ndarray, return_intermediate: bool = False) -> np.ndarray:
+         preprocessed_img = dl_prepro_image(img)
+         print(f"preprocessed_img.shape: {preprocessed_img.shape}, μ: {np.mean(preprocessed_img)}, σ: {np.std(preprocessed_img)}")
+         padded_image = self.padding_image(preprocessed_img)
+         print(f"padded_image.shape: {padded_image.shape}, μ: {np.mean(padded_image)}, σ: {np.std(padded_image)}")
+         pred_map = self.get_prediction_map(padded_image)
+         print(f"pred_map.shape: {pred_map.shape}, μ: {np.mean(pred_map)}, σ: {np.std(pred_map)}")
+         if return_intermediate:
+             return preprocessed_img, padded_image, pred_map
+         return pred_map
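With stride 1, padding 10 and a 21x21 window, the prediction map has exactly the input's shape: the padded image is (H+20, W+20) and yields (H+20) - 21 + 1 = H window positions per axis. A quick check of that arithmetic (illustrative):

    H, W, pad, win = 512, 512, 10, 21
    positions_y = (H + 2 * pad) - win + 1
    positions_x = (W + 2 * pad) - win + 1
    assert (positions_y, positions_x) == (H, W)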
atoms_detection/dl_detection_evaluation.py ADDED
@@ -0,0 +1,130 @@
+ import argparse
+ import os
+ 
+ import numpy as np
+ import pandas as pd
+ from matplotlib import pyplot as plt
+ from PIL import Image
+ 
+ from atoms_detection.dl_detection import DLDetection
+ from atoms_detection.dl_detection_scaled import DLScaled
+ from atoms_detection.evaluation import Evaluation
+ from utils.paths import MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH, PRED_GT_VIS_PATH
+ from utils.constants import ModelArgs, Split
+ from visualizations.prediction_gt_images import plot_gt_pred_on_img, get_gt_coords
+ 
+ 
+ def detection_pipeline(args):
+     extension_name = args.extension_name
+     print(f"Storing at {extension_name}")
+     architecture = ModelArgs.BASICCNN
+     ckpt_filename = os.path.join(MODELS_PATH, "model_sac_cnn.ckpt")
+ 
+     inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}")
+ 
+     # testing_thresholds = [0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99]
+     testing_thresholds = [0.8, 0.85, 0.9, 0.95]
+     for threshold in testing_thresholds:
+         detections_path = os.path.join(DETECTION_PATH, f"dl_detection_{extension_name}",
+                                        f"dl_detection_{extension_name}_{threshold}")
+         print(f"Detecting atoms on test data with threshold={threshold}...")
+         if args.experimental_rescale:
+             print("Using experimental ruler rescaling")
+             detection = DLScaled(
+                 model_name=architecture,
+                 ckpt_filename=ckpt_filename,
+                 dataset_csv=args.dataset,
+                 threshold=threshold,
+                 detections_path=detections_path,
+                 inference_cache_path=inference_cache_path
+             )
+         else:
+             detection = DLDetection(
+                 model_name=architecture,
+                 ckpt_filename=ckpt_filename,
+                 dataset_csv=args.dataset,
+                 threshold=threshold,
+                 detections_path=detections_path,
+                 inference_cache_path=inference_cache_path
+             )
+         detection.run()
+         if args.eval:
+             logging_filename = os.path.join(LOGS_PATH, f"dl_detection_{extension_name}",
+                                             f"dl_detection_{extension_name}_{threshold}.csv")
+             evaluation = Evaluation(
+                 coords_csv=args.dataset,
+                 predictions_path=detections_path,
+                 logging_filename=logging_filename
+             )
+             evaluation.run()
+         if args.visualise:
+             vis_folder = os.path.join(PRED_GT_VIS_PATH, f"dl_detection_{extension_name}")
+             if not os.path.exists(vis_folder):
+                 os.makedirs(vis_folder)
+ 
+             vis_folder = os.path.join(vis_folder, f"dl_detection_{extension_name}_{threshold}")
+             if not os.path.exists(vis_folder):
+                 os.makedirs(vis_folder)
+ 
+             if args.eval:
+                 gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset)
+ 
+             for image_path in detection.image_dataset.iterate_data(Split.TEST):
+                 print(image_path)
+                 img_name = os.path.split(image_path)[-1]
+                 gt_coords = gt_coords_dict[img_name] if args.eval else None
+                 pred_df_path = os.path.join(detections_path, os.path.splitext(img_name)[0] + '.csv')
+                 df_predicted = pd.read_csv(pred_df_path)
+                 pred_coords = [(row['x'], row['y']) for _, row in df_predicted.iterrows()]
+                 img = Image.open(image_path)
+                 img_arr = np.array(img).astype(np.float32)
+                 img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())
+ 
+                 plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
+                 clean_image_name = os.path.splitext(img_name)[0]
+                 vis_path = os.path.join(vis_folder, f'{clean_image_name}.png')
+                 plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
+                 plt.close()
+ 
+     print(f"Experiment {extension_name} completed")
+ 
+ 
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "extension_name",
+         type=str,
+         help="Experiment extension name"
+     )
+     parser.add_argument(
+         "dataset",
+         type=str,
+         help="Dataset file upon which to do inference"
+     )
+     parser.add_argument(
+         "--eval",
+         action='store_true',
+         help="Whether to perform evaluation after inference",
+         default=False
+     )
+     parser.add_argument(
+         "--visualise",
+         action='store_true',
+         help="Whether to store inference results as visual png images",
+         default=False
+     )
+     parser.add_argument(
+         "--experimental_rescale",
+         action='store_true',
+         help="Whether to rescale inputs based on the ruler of the image as preprocess",
+         default=False
+     )
+     return parser.parse_args()
+ 
+ 
+ if __name__ == '__main__':
+     args = get_args()
+     detection_pipeline(args)
atoms_detection/dl_detection_scaled.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ from hashlib import sha1
+
+ import numpy as np
+ from PIL import Image
+
+ from atoms_detection.dl_detection import DLDetection
+ from atoms_detection.image_preprocessing import dl_prepro_image
+ from utils.constants import ModelArgs
+
+
+ class DLScaled(DLDetection):
+     # Should take into account for the resize:
+     #   - ruler of the image (pixels x nm)
+     #   - covalent radius
+     #   - beam size/voltage/exposure? (can create larger distortions)
+     #     (metadata should be in dm3 files, if it can be read)
+     def __init__(self,
+                  model_name: ModelArgs,
+                  ckpt_filename: str,
+                  dataset_csv: str,
+                  threshold: float,
+                  detections_path: str,
+                  inference_cache_path: str):
+         super().__init__(model_name, ckpt_filename, dataset_csv, threshold, detections_path, inference_cache_path)
+
+     def image_to_pred_map(self, img: np.ndarray) -> np.ndarray:
+         ruler_units = self.image_dataset.get_ruler_units_by_img_name(self.currently_processing)
+         preprocessed_img, scale_factor = dl_prepro_image(img, ruler_units=ruler_units)
+         padded_image = self.padding_image(preprocessed_img)
+         pred_map = self.get_prediction_map(padded_image)
+
+         # NOTE: PIL's resize takes (width, height); this is only equivalent for square images
+         new_dimensions = img.shape[0], img.shape[1]
+         pred_map = np.array(Image.fromarray(pred_map).resize(new_dimensions))
+         return pred_map
+
+     def cache_image_identifier(self, img_filename):
+         # Salt the hash so scaled predictions do not collide with the unscaled cache
+         return sha1((img_filename + 'scaled').encode()).hexdigest()
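Editor's note: the prediction map is computed at the rescaled resolution and then resized back to the original image size. A minimal round-trip sketch of that restore step (illustrative values only), keeping in mind that PIL expects a `(width, height)` tuple:

import numpy as np
from PIL import Image

pred_small = np.random.rand(256, 256).astype(np.float32)  # map computed at rescaled size
h, w = 512, 512                                           # original image size to restore
pred_full = np.array(Image.fromarray(pred_small).resize((w, h)))  # PIL takes (width, height)
assert pred_full.shape == (h, w)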
atoms_detection/dl_detection_with_gmm.py ADDED
@@ -0,0 +1,151 @@
+ import math
+ from typing import Tuple, List
+
+ import numpy as np
+ from scipy.ndimage import label
+ from sklearn.mixture import GaussianMixture
+
+ from atoms_detection.dl_detection import DLDetection
+ from utils.constants import ModelArgs
+
+
+ class DLGMMdetection(DLDetection):
+     MAX_SINGLE_ATOM_AREA = 200
+     MAX_ATOMS_PER_AREA = 3
+     COVARIANCE_TYPE = "full"
+
+     def __init__(
+         self,
+         model_name: ModelArgs,
+         ckpt_filename: str,
+         dataset_csv: str,
+         threshold: float,
+         detections_path: str,
+         inference_cache_path: str,
+         covariance_penalisation: float = 0.03,
+         n_clusters_penalisation: float = 0.33,
+         distance_penalisation: float = 0.11,
+         n_samples_per_gmm: int = 600,
+     ):
+         super(DLGMMdetection, self).__init__(
+             model_name,
+             ckpt_filename,
+             dataset_csv,
+             threshold,
+             detections_path,
+             inference_cache_path,
+         )
+         self.covariance_penalisation = covariance_penalisation
+         self.n_clusters_penalisation = n_clusters_penalisation
+         self.distance_penalisation = distance_penalisation
+         self.n_samples_per_gmm = n_samples_per_gmm
+
+     def pred_map_to_atoms(
+         self, pred_map: np.ndarray
+     ) -> Tuple[List[Tuple[int, int]], List[float]]:
+         pred_mask = pred_map > self.threshold
+         labeled_array, num_features = label(pred_mask)
+         self.current_pred_map = pred_map
+
+         # Convert the labelled array to centre coordinates
+         center_coords_list = []
+         likelihood_list = []
+         for label_idx in range(num_features + 1):
+             if label_idx == 0:
+                 continue
+             label_mask = np.where(labeled_array == label_idx)
+             likelihood = np.max(pred_map[label_mask])
+             atom_bbox = (
+                 label_mask[1].min(),
+                 label_mask[1].max(),
+                 label_mask[0].min(),
+                 label_mask[0].max(),
+             )
+             center_coord = self.bbox_to_center_coords(atom_bbox)
+             center_coords_list += center_coord
+             pixel_area = (atom_bbox[1] - atom_bbox[0]) * (atom_bbox[3] - atom_bbox[2])
+             if pixel_area < self.MAX_SINGLE_ATOM_AREA:
+                 likelihood_list.append(likelihood)
+             else:
+                 # One likelihood entry per atom found by the GMM in this region
+                 for i in range(0, len(center_coord)):
+                     likelihood_list.append(likelihood)
+         self.current_pred_map = None
+         print(f"number of atoms {len(center_coords_list)}")
+         return center_coords_list, likelihood_list
+
+     def bbox_to_center_coords(
+         self, bbox: Tuple[int, int, int, int]
+     ) -> List[Tuple[int, int]]:
+         pixel_area = (bbox[1] - bbox[0]) * (bbox[3] - bbox[2])
+         if pixel_area < self.MAX_SINGLE_ATOM_AREA:
+             return super().bbox_to_center_coords(bbox)
+         else:
+             pmap = self.get_current_prediction_map_region(bbox)
+             local_atom_center_list = self.run_gmm_pipeline(pmap)
+             atom_center_list = [
+                 (x + bbox[0], y + bbox[2]) for x, y in local_atom_center_list
+             ]
+             return atom_center_list
+
+     def sample_img_hist(self, img_region):
+         x_bin_midpoints = list(range(img_region.shape[1]))
+         y_bin_midpoints = list(range(img_region.shape[0]))
+         cdf = np.cumsum(img_region.ravel())
+         cdf = cdf / cdf[-1]
+         values = np.random.rand(self.n_samples_per_gmm)
+         value_bins = np.searchsorted(cdf, values)
+         x_idx, y_idx = np.unravel_index(
+             value_bins, (len(x_bin_midpoints), len(y_bin_midpoints))
+         )
+         random_from_cdf = np.column_stack((x_idx, y_idx))
+         new_x, new_y = random_from_cdf.T
+         return new_x, new_y
+
+     def run_gmm_pipeline(self, prediction_map: np.ndarray) -> List[Tuple[int, int]]:
+         retries = 2
+         new_x, new_y = self.sample_img_hist(prediction_map)
+         best_gmm, best_score = None, -np.inf
+         obs = np.array((new_x, new_y)).T
+         for k in range(1, self.MAX_ATOMS_PER_AREA + 1):
+             for i in range(retries):
+                 gmm = GaussianMixture(
+                     n_components=k, covariance_type=self.COVARIANCE_TYPE
+                 )
+                 gmm.fit(obs)
+                 log_like = gmm.score(obs)
+                 covar = np.linalg.norm(gmm.covariances_)
+                 if k == 1:
+                     score = (
+                         log_like
+                         - covar * self.covariance_penalisation
+                         - k * self.n_clusters_penalisation
+                     )
+                     print(k, score)
+                 else:
+                     distances = [
+                         math.dist(p1, p2)
+                         for i, p1 in enumerate(gmm.means_[:-1])
+                         for p2 in gmm.means_[i + 1:]
+                     ]
+                     dist_penalisation = sum([max(12 - d, 0) ** 2 for d in distances])
+                     score = (
+                         log_like
+                         - covar * self.covariance_penalisation
+                         - k * self.n_clusters_penalisation
+                         - dist_penalisation * self.distance_penalisation
+                     )
+                     print(
+                         k,
+                         score,
+                         log_like,
+                         covar * self.covariance_penalisation,
+                         k * self.n_clusters_penalisation,
+                         dist_penalisation * self.distance_penalisation,
+                     )
+                 if score > best_score:
+                     best_gmm, best_score = gmm, score
+         return [(x, y) for y, x in best_gmm.means_.tolist()]
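Editor's note: `sample_img_hist` treats the prediction map as an unnormalised 2-D density and draws pixel coordinates by inverse-transform sampling on its cumulative sum. A self-contained sketch of the same trick on a toy map (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
pmap = np.zeros((20, 20), dtype=np.float32)
pmap[5:8, 12:15] = 1.0                      # a single bright blob
cdf = np.cumsum(pmap.ravel())
cdf = cdf / cdf[-1]
bins = np.searchsorted(cdf, rng.random(500))
ys, xs = np.unravel_index(bins, pmap.shape)
# All samples fall inside the blob, proportional to its mass
assert ys.min() >= 5 and ys.max() <= 7 and xs.min() >= 12 and xs.max() <= 14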
atoms_detection/dl_full_pipeline.py ADDED
@@ -0,0 +1,96 @@
+ from typing import List
+
+ import argparse
+ import os
+
+ from atoms_detection.create_crop_dataset import create_crops_dataset
+ from atoms_detection.dl_detection import DLDetection
+ from atoms_detection.evaluation import Evaluation
+ from atoms_detection.training_model import train_model
+ from utils.paths import CROPS_PATH, CROPS_DATASET, MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH
+ from utils.constants import ModelArgs
+
+
+ def dl_full_pipeline(
+     extension_name: str,
+     architecture: ModelArgs,
+     coords_csv: str,
+     thresholds_list: List[float],
+     force: bool = False
+ ):
+     # Create crops data
+     crops_folder = CROPS_PATH + f"_{extension_name}"
+     crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv")
+     if force or not os.path.exists(crops_dataset):
+         print("Creating crops dataset...")
+         create_crops_dataset(crops_folder, coords_csv, crops_dataset)
+
+     # training DL model
+     ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt")
+     if force or not os.path.exists(ckpt_filename):
+         print("Training DL crops model...")
+         train_model(architecture, crops_dataset, crops_folder, ckpt_filename)
+
+     force = True  # NOTE: hard-coded override; detection & evaluation below always re-run
+     # DL Detection & Evaluation
+     for threshold in thresholds_list:
+         inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}")
+         detections_path = os.path.join(DETECTION_PATH, f"dl_detection_{extension_name}", f"dl_detection_{extension_name}_{threshold}")
+         if force or not os.path.exists(detections_path):
+             print(f"Detecting atoms on test data with threshold={threshold}...")
+             detection = DLDetection(
+                 model_name=architecture,
+                 ckpt_filename=ckpt_filename,
+                 dataset_csv=coords_csv,
+                 threshold=threshold,
+                 detections_path=detections_path,
+                 inference_cache_path=inference_cache_path
+             )
+             detection.run()
+
+         logging_filename = os.path.join(LOGS_PATH, f"dl_evaluation_{extension_name}", f"dl_evaluation_{extension_name}_{threshold}.csv")
+         if force or not os.path.exists(logging_filename):
+             evaluation = Evaluation(
+                 coords_csv=coords_csv,
+                 predictions_path=detections_path,
+                 logging_filename=logging_filename
+             )
+             evaluation.run()
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "extension_name",
+         type=str,
+         help="Experiment extension name"
+     )
+     parser.add_argument(
+         "architecture",
+         type=ModelArgs,
+         choices=ModelArgs,
+         help="Architecture name"
+     )
+     parser.add_argument(
+         "coords_csv",
+         type=str,
+         help="Coordinates CSV file to use as input"
+     )
+     parser.add_argument(
+         "-t",
+         "--thresholds",
+         nargs="+",
+         type=float,
+         help="Detection thresholds to evaluate"
+     )
+     parser.add_argument(
+         "--force",
+         action="store_true"
+     )
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     print(args)
+     dl_full_pipeline(args.extension_name, args.architecture, args.coords_csv, args.thresholds, args.force)
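Editor's note: besides the CLI, the pipeline can be driven programmatically, e.g. for a quick threshold sweep. A sketch, assuming the default data layout under `utils.paths` (the coordinates CSV path below is a placeholder):

from utils.constants import ModelArgs
from atoms_detection.dl_full_pipeline import dl_full_pipeline

thresholds = [round(0.1 * t, 1) for t in range(5, 10)]  # 0.5 .. 0.9
dl_full_pipeline(
    extension_name="sweep_demo",
    architecture=ModelArgs.BASICCNN,
    coords_csv="data/coords/my_coords.csv",  # placeholder path
    thresholds_list=thresholds,
)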
atoms_detection/evaluation.py ADDED
@@ -0,0 +1,254 @@
+ from typing import Tuple, List
+
+ import os
+
+ import numpy as np
+ import scipy.optimize
+ from PIL import Image
+ import pandas as pd
+ from matplotlib import pyplot as plt
+ from matplotlib import patches
+
+ from utils.constants import Split
+ from atoms_detection.dataset import CoordinatesDataset
+
+
+ def bbox_iou(bb1, bb2):
+     assert bb1[0] <= bb1[1]
+     assert bb1[2] <= bb1[3]
+     assert bb2[0] <= bb2[1]
+     assert bb2[2] <= bb2[3]
+
+     # determine the coordinates of the intersection rectangle
+     x_left = max(bb1[0], bb2[0])
+     y_top = max(bb1[2], bb2[2])
+     x_right = min(bb1[1], bb2[1])
+     y_bottom = min(bb1[3], bb2[3])
+
+     if x_right < x_left or y_bottom < y_top:
+         return 0.0
+
+     # The intersection of two axis-aligned bounding boxes is always an
+     # axis-aligned bounding box
+     intersection_area = (x_right - x_left) * (y_bottom - y_top)
+
+     # compute the area of both AABBs
+     bb1_area = (bb1[1] - bb1[0]) * (bb1[3] - bb1[2])
+     bb2_area = (bb2[1] - bb2[0]) * (bb2[3] - bb2[2])
+
+     # compute the intersection over union by taking the intersection
+     # area and dividing it by the sum of prediction + ground-truth
+     # areas - the intersection area
+     iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
+     assert iou >= 0.0
+     assert iou <= 1.0
+     return iou
+
+
+ def match_bboxes(iou_matrix, IOU_THRESH=0.5):
+     '''
+     Given sets of true and predicted bounding-boxes,
+     determine the best possible match.
+
+     Returns
+     -------
+     (idxs_true, idxs_pred, ious, labels)
+         idxs_true, idxs_pred : indices into gt and pred for matches
+         ious : corresponding IOU value of each match
+         labels : vector of 0/1 values for the list of detections
+     '''
+     n_true, n_pred = iou_matrix.shape
+     MIN_IOU = 0.0
+
+     if n_pred > n_true:
+         # there are more predictions than ground-truth - add dummy rows
+         diff = n_pred - n_true
+         iou_matrix = np.concatenate((iou_matrix,
+                                      np.full((diff, n_pred), MIN_IOU)),
+                                     axis=0)
+
+     if n_true > n_pred:
+         # more ground-truth than predictions - add dummy columns
+         diff = n_true - n_pred
+         iou_matrix = np.concatenate((iou_matrix,
+                                      np.full((n_true, diff), MIN_IOU)),
+                                     axis=1)
+
+     # call the Hungarian matching
+     idxs_true, idxs_pred = scipy.optimize.linear_sum_assignment(1 - iou_matrix)
+
+     if (not idxs_true.size) or (not idxs_pred.size):
+         ious = np.array([])
+     else:
+         ious = iou_matrix[idxs_true, idxs_pred]
+
+     # remove dummy assignments
+     sel_pred = idxs_pred < n_pred
+     idx_pred_actual = idxs_pred[sel_pred]
+     idx_gt_actual = idxs_true[sel_pred]
+     ious_actual = iou_matrix[idx_gt_actual, idx_pred_actual]
+     sel_valid = (ious_actual > IOU_THRESH)
+     label = sel_valid.astype(int)
+
+     return idx_gt_actual[sel_valid], idx_pred_actual[sel_valid], ious_actual[sel_valid], label
+
+
+ class Evaluation:
+     def __init__(self, coords_csv: str, predictions_path: str, logging_filename: str):
+         self.coordinates_dataset = CoordinatesDataset(coords_csv)
+         self.predictions_path = predictions_path
+         self.logging_filename = logging_filename
+         if not os.path.exists(os.path.dirname(self.logging_filename)):
+             os.makedirs(os.path.dirname(self.logging_filename))
+         self.logs_df = pd.DataFrame(columns=["Filename", "Precision", "Recall", "F1Score"])
+         self.threshold = 0.5
+
+     def get_predictions_dict(self, image_filename: str) -> List[Tuple[int, int]]:
+         img_name = os.path.splitext(os.path.basename(image_filename))[0]
+         preds_csv = os.path.join(self.predictions_path, f"{img_name}.csv")
+         df = pd.read_csv(preds_csv)
+         pred_coords_list = []
+         for idx, row in df.iterrows():
+             pred_coords_list.append((row["x"], row["y"]))
+         return pred_coords_list
+
+     @staticmethod
+     def center_coords_to_bbox(gt_coord: Tuple[int, int]) -> Tuple[int, int, int, int]:
+         box_rwidth, box_rheight = 10, 10
+         gt_bbox = (
+             gt_coord[0] - box_rwidth,
+             gt_coord[0] + box_rwidth + 1,
+             gt_coord[1] - box_rheight,
+             gt_coord[1] + box_rheight + 1
+         )
+         return gt_bbox
+
+     def eval_matches(
+         self,
+         gt_bboxes_dict: List[Tuple[int, int, int, int]],
+         atoms_bbox_dict: List[Tuple[int, int, int, int]],
+         iou_threshold: float = 0.5
+     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+         iou_matrix = np.zeros((len(gt_bboxes_dict), len(atoms_bbox_dict))).astype(np.float32)
+
+         for gt_idx, gt_bbox in enumerate(gt_bboxes_dict):
+             for atom_idx, atom_bbox in enumerate(atoms_bbox_dict):
+                 iou = bbox_iou(gt_bbox, atom_bbox)
+                 iou_matrix[gt_idx, atom_idx] = iou
+         idxs_true, idxs_pred, ious, labels = match_bboxes(iou_matrix, IOU_THRESH=iou_threshold)
+         return idxs_true, idxs_pred, ious, labels
+
+     @staticmethod
+     def eval_metrics(n_matches: int, n_gt: int, n_pred: int) -> Tuple[float, float]:
+         precision = n_matches / n_pred if n_pred > 0 else 0.0
+         if n_gt == 0:
+             raise RuntimeError("No ground truth atoms???")
+         recall = n_matches / n_gt
+         return precision, recall
+
+     def atom_coords_to_bboxes(self, atoms_coords_dict: List[Tuple[int, int]]) -> List[Tuple[int, int, int, int]]:
+         atom_bboxes_dict = []
+         for atom_center in atoms_coords_dict:
+             atom_fixed_bbox = self.center_coords_to_bbox(atom_center)
+             atom_bboxes_dict.append(atom_fixed_bbox)
+         return atom_bboxes_dict
+
+     def gt_coord_to_bboxes(self, gt_coordinates_dict: List[Tuple[int, int]]) -> List[Tuple[int, int, int, int]]:
+         gt_bboxes_list = []
+         for gt_coord in gt_coordinates_dict:
+             gt_bbox = self.center_coords_to_bbox(gt_coord)
+             gt_bboxes_list.append(gt_bbox)
+         return gt_bboxes_list
+
+     @staticmethod
+     def open_image(img_filename: str):
+         img = Image.open(img_filename)
+         np_img = np.asarray(img).astype(np.float32)
+         return np_img
+
+     def run(self, plot=False):
+         for image_path, coordinates_path in self.coordinates_dataset.iterate_data(Split.TEST):
+             img = self.open_image(image_path)
+
+             center_coords_dict = self.get_predictions_dict(image_path)
+             atoms_bboxes_dict = self.atom_coords_to_bboxes(center_coords_dict)
+
+             gt_coordinates = self.coordinates_dataset.load_coordinates(coordinates_path)
+             gt_bboxes_dict = self.gt_coord_to_bboxes(gt_coordinates)
+
+             # Visualize gt & pred bboxes!
+             if plot:
+                 plt.figure(figsize=(20, 20))
+                 ax = plt.gca()
+                 ax.imshow(img)
+                 for gt_bbox in gt_bboxes_dict:
+                     xy = (gt_bbox[0], gt_bbox[2])
+                     width = gt_bbox[1] - gt_bbox[0]
+                     height = gt_bbox[3] - gt_bbox[2]
+                     rect = patches.Rectangle(xy, width, height, linewidth=3, edgecolor='r', facecolor='none')
+                     ax.add_patch(rect)
+                 for atom_bbox in atoms_bboxes_dict:
+                     xy = (atom_bbox[0], atom_bbox[2])
+                     width = atom_bbox[1] - atom_bbox[0]
+                     height = atom_bbox[3] - atom_bbox[2]
+                     rect = patches.Rectangle(xy, width, height, linewidth=2, edgecolor='g', facecolor='none')
+                     ax.add_patch(rect)
+                 plt.tight_layout()
+                 plt.show()
+
+             idxs_true, idxs_pred, ious, labels = self.eval_matches(gt_bboxes_dict, atoms_bboxes_dict)
+             precision, recall = self.eval_metrics(n_matches=len(idxs_pred), n_gt=len(gt_coordinates), n_pred=len(atoms_bboxes_dict))
+             f1_score = 2*(precision*recall)/(precision+recall) if precision+recall > 0 else 0
+             if self.logging_filename:
+                 # DataFrame.append is deprecated; use pd.concat to avoid the warning
+                 self.logs_df = pd.concat([self.logs_df, pd.DataFrame({
+                     "Filename": os.path.basename(image_path),
+                     "Precision": precision,
+                     "Recall": recall,
+                     "F1Score": f1_score
+                 }, index=[0])], ignore_index=True)
+             format_args = (os.path.basename(image_path), f1_score, precision, recall)
+             print("{}: F1Score: {}, Precision: {}, Recall: {}".format(*format_args))
+
+         if self.logging_filename:
+             mean_precision = self.logs_df["Precision"].mean()
+             mean_recall = self.logs_df["Recall"].mean()
+             mean_f1_score = self.logs_df["F1Score"].mean()
+             std_precision = self.logs_df["Precision"].std()
+             std_recall = self.logs_df["Recall"].std()
+             std_f1_score = self.logs_df["F1Score"].std()
+             print(f"F1Score: {mean_f1_score}, Precision: {mean_precision}, Recall: {mean_recall}")
+             self.logs_df = pd.concat([self.logs_df, pd.DataFrame({
+                 "Filename": "Mean",
+                 "Precision": mean_precision,
+                 "Recall": mean_recall,
+                 "F1Score": mean_f1_score
+             }, index=[0])], ignore_index=True)
+             self.logs_df = pd.concat([self.logs_df, pd.DataFrame({
+                 "Filename": "Std",
+                 "Precision": std_precision,
+                 "Recall": std_recall,
+                 "F1Score": std_f1_score
+             }, index=[0])], ignore_index=True)
+             self.logs_df.to_csv(self.logging_filename, index=False, float_format='%.4f')
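Editor's note: the matching logic can be checked end-to-end on a toy case: two ground-truth boxes, one of which is found, giving precision 1.0 and recall 0.5. A sketch reusing `bbox_iou` and `match_bboxes` from this module:

import numpy as np
from atoms_detection.evaluation import bbox_iou, match_bboxes

gt = [(0, 10, 0, 10), (50, 60, 50, 60)]   # (x_min, x_max, y_min, y_max)
pred = [(1, 11, 1, 11)]                   # overlaps the first gt box only
iou = np.zeros((len(gt), len(pred)), dtype=np.float32)
for i, g in enumerate(gt):
    for j, p in enumerate(pred):
        iou[i, j] = bbox_iou(g, p)
idxs_true, idxs_pred, ious, labels = match_bboxes(iou, IOU_THRESH=0.5)
precision = len(idxs_pred) / len(pred)    # 1.0
recall = len(idxs_true) / len(gt)         # 0.5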
atoms_detection/fast_filters.cpp ADDED
@@ -0,0 +1,83 @@
+ /*
+ @author : Romain Graux
+ @date : 2023 April 03, 15:30:30
+ @last modified : 2023 April 03, 18:34:08
+ */
+
+ #include <iostream>
+ #include <algorithm>
+
+ using namespace std;
+
+ extern "C" {
+     void median_filter(float* data, int width, int height, int window_size, float* out);
+     void reflect_borders(float* data, int width, int height, int span, float* out);
+ }
+
+
+ void reflect_borders(float* data, int width, int height, int span, float* out){
+     int out_width = width + 2*span;
+     int out_height = height + 2*span;
+     // First copy the same data but with a border of span pixels
+     for(int i=0; i<height; i++){
+         for(int j=0; j<width; j++){
+             out[(i+span)*out_width + j + span] = data[i*width + j];
+         }
+     }
+
+     // Then reflect the top and bottom borders
+     for(int j=0; j<width; j++){
+         for(int h=0; h<span; h++){
+             out[(span-h-1)*out_width + j + span] = out[(span+h)*out_width + j + span];
+             out[(out_height-span+h)*out_width + j + span] = out[(out_height-span-h-1)*out_width + j + span];
+         }
+     }
+
+     // Then reflect the left and right borders
+     for(int i=0; i<out_height; i++){
+         for(int w=0; w<span; w++){
+             out[i*out_width + span - w - 1] = out[i*out_width + span + w];
+             out[i*out_width + width + span + w] = out[i*out_width + width + span - w - 1];
+         }
+     }
+ }
+
+
+ void median_filter(float* data, int width, int height, int windowSize, float* out){
+     int span = windowSize/2;
+     int padded_width = width + 2*span;
+
+     float* window = new float[windowSize*windowSize];
+     for (int y = span; y < height + span; y++){
+         for (int x = span; x < width + span; x++){
+             for (int i = 0; i < windowSize; i++){
+                 for (int j = 0; j < windowSize; j++){
+                     window[i*windowSize + j] = data[(y-span+i)*padded_width + (x-span+j)];
+                 }
+             }
+             std::nth_element(window, window + windowSize*windowSize/2, window + windowSize*windowSize);
+             out[(y-span)*width + (x-span)] = window[windowSize*windowSize/2];
+         }
+     }
+     delete[] window;
+ }
+
+
+ int main(){
+     int width = 4;
+     int height = 4;
+     int window_size = 2;
+     float* data = new float[width*height];
+     float* out = new float[(width+2*window_size/2)*(height+2*window_size/2)];
+     for(int i=0; i<width*height; i++){
+         data[i] = i;
+     }
+     reflect_borders(data, width, height, window_size/2, out);
+     for (int i=0; i<height+2*window_size/2; i++){
+         for (int j=0; j<width+2*window_size/2; j++){
+             cout << out[i*(width+2*window_size/2) + j] << " ";
+         }
+         cout << endl;
+     }
+     delete[] data;
+     delete[] out;
+     return 0;
+ }
atoms_detection/fast_filters.py ADDED
@@ -0,0 +1,103 @@
+ import os
+ import math
+ import ctypes
+ import numpy as np
+ from glob import glob
+ from functools import partial
+ from multiprocessing import Pool
+ from utils.paths import LIB_PATH
+
+ # Load the shared library
+ try:
+     lib_path = glob(os.path.join(LIB_PATH, "fast_filters*.so"))[0]
+     lib = ctypes.cdll.LoadLibrary(lib_path)
+ except (IndexError, OSError) as e:
+     raise OSError(
+         "Did you compile the shared library? Please run `python setup.py build_ext`"
+     ) from e
+
+ # Define the functions' argument and return types
+ lib.reflect_borders.argtypes = [
+     ctypes.POINTER(ctypes.c_float),
+     ctypes.c_int,
+     ctypes.c_int,
+     ctypes.c_int,
+     ctypes.POINTER(ctypes.c_float),
+ ]
+ lib.reflect_borders.restype = None
+ lib.median_filter.argtypes = [
+     ctypes.POINTER(ctypes.c_float),
+     ctypes.c_int,
+     ctypes.c_int,
+     ctypes.c_int,
+     ctypes.POINTER(ctypes.c_float),
+ ]
+ lib.median_filter.restype = None
+
+
+ # Median filter wrapping the C implementation
+ def median_filter(data, window_size, width, height):
+     out = np.zeros((height, width), dtype=np.float32)
+     lib.median_filter(
+         data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+         width,
+         height,
+         window_size,
+         out.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+     )
+     return out
+
+
+ # Border reflection wrapping the C implementation
+ def reflecting_borders(data, span):
+     out = np.zeros(
+         (data.shape[0] + 2 * span, data.shape[1] + 2 * span), dtype=np.float32
+     )
+     lib.reflect_borders(
+         data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+         data.shape[1],
+         data.shape[0],
+         span,
+         out.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+     )
+     return out
+
+
+ def reflecting_borders_py(data, span):
+     # Pure-NumPy fallback; "symmetric" duplicates the edge pixel, matching the C code
+     return np.pad(data, span, mode="symmetric")
+
+
+ def median_filter_parallel(data, window_size, splits=None, inplace=False):
+     if splits is None:
+         splits = (
+             int(math.sqrt(os.cpu_count())) + 1
+         )  # at least 1 split per core (n_splits = splits**2 because 2 dimensions)
+     out = data if inplace else data.copy()
+     height, width = data.shape
+     span = window_size // 2
+     padded = reflecting_borders(data, span)
+     # TODO: split the data evenly and handle the remaining subarray shape
+     # if height/width are not divisible by the number of splits
+     subarrays = []
+     for i in range(splits):
+         for j in range(splits):
+             istart = i * height // splits
+             jstart = j * width // splits
+             iend = istart + height // splits + 2 * span
+             jend = jstart + width // splits + 2 * span
+             subarrays.append(padded[istart:iend, jstart:jend])
+     f = partial(
+         median_filter,
+         window_size=window_size,
+         width=width // splits,
+         height=height // splits,
+     )
+     with Pool(processes=splits * splits) as pool:
+         filtered_subs = list(pool.map(f, subarrays))
+     # TODO: merge subarrays faster (maybe with cython)
+     for i in range(splits):
+         for j in range(splits):
+             out[
+                 i * height // splits : (i + 1) * height // splits,
+                 j * width // splits : (j + 1) * width // splits,
+             ] = filtered_subs[i * splits + j]
+     return out
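Editor's note: once the shared library is compiled (`python setup.py build_ext`), the ctypes wrappers can be sanity-checked against their NumPy/SciPy equivalents. A sketch, assuming the C padding duplicates the edge pixel (NumPy's "symmetric", which SciPy calls mode "reflect"):

import numpy as np
from scipy.ndimage import median_filter as scipy_median_filter
from atoms_detection.fast_filters import median_filter_parallel, reflecting_borders

if __name__ == "__main__":  # guard needed because median_filter_parallel spawns workers
    img = np.random.rand(128, 128).astype(np.float32)
    # Border reflection should match NumPy's symmetric padding
    assert np.allclose(reflecting_borders(img, 5), np.pad(img, 5, mode="symmetric"))
    # The parallel median filter should match SciPy for an odd window size
    ours = median_filter_parallel(img, 9, splits=2)
    ref = scipy_median_filter(img, size=9, mode="reflect")
    assert np.allclose(ours, ref)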
atoms_detection/image_preprocessing.py ADDED
@@ -0,0 +1,83 @@
+ import numpy as np
+
+ from scipy.ndimage import gaussian_filter, median_filter
+ from atoms_detection.fast_filters import median_filter_parallel
+ from PIL import Image
+
+
+ def preprocess_jpg(np_img: np.ndarray) -> np.ndarray:
+     return np_img[:, :, 0]
+
+
+ def dl_prepro_image(np_img: np.ndarray, ruler_units=None, clip=1):
+     if len(np_img.shape) == 3:
+         np_img = preprocess_jpg(np_img)
+     scale_factor = None
+     if ruler_units is not None:
+         try:
+             ruler_size = get_ruler_size(np_img)
+             np_img, scale_factor = rescale_img_to_target_pxls_nm(
+                 np_img, ruler_size, ruler_units
+             )
+         except Exception:
+             pass
+
+     print("WARNING, MANUAL CLIP USAGE")
+     clip = 0.999  # NOTE: hard-coded override of the clip argument
+     max_allowed = np.quantile(np_img, q=clip)
+     np_img = np.clip(np_img, a_min=0, a_max=max_allowed)
+     try:
+         np_bg = median_filter_parallel(np_img, 40, splits=4)
+     except Exception as e:
+         print(e)
+         print("Median filter failed, using slower scipy version")
+         np_bg = median_filter(np_img, 40)
+     np_clean = np_img - np_bg
+     np_clean[np_clean < 0] = 0
+     np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min())
+
+     if scale_factor is not None:
+         return np_normed, scale_factor
+     return np_normed
+
+
+ def cv_prepro_image(img: np.ndarray) -> np.ndarray:
+     bg_img = gaussian_filter(img, sigma=10)
+     clean_img = img - bg_img
+     normed_img = (clean_img - clean_img.min()) / (clean_img.max() - clean_img.min())
+     return normed_img
+
+
+ def get_ruler_size(img: np.ndarray) -> int:
+     ruler_start_location_percent = 0.0625  # empirically located here in samples
+     ruler_start_coords = int(
+         img.shape[0] * (1 - ruler_start_location_percent) - 1
+     ), int(img.shape[1] * ruler_start_location_percent)
+     if img[ruler_start_coords] != img.max():
+         print("Ruler start position verification failed, skipping rescaling")
+         raise ValueError("Ruler start position verification failed")
+     else:
+         ruler_iterator = ruler_start_coords
+         while img[ruler_iterator] == img[ruler_start_coords]:
+             ruler_iterator = ruler_iterator[0], ruler_iterator[1] + 1
+         return ruler_iterator[1] - ruler_start_coords[1]
+
+
+ def rescale_img_to_target_pxls_nm(
+     img: np.ndarray, ruler_pixels: int, ruler_units: int, atom_prior=None
+ ):
+     target_scale = (
+         512 / 15
+     )  # original images were 512x512 and labelled 15nm across, ~34 pixels per nm
+     pixels_per_nanometer = ruler_pixels / ruler_units  # current pixels per nm
+     scaling_factor = target_scale / pixels_per_nanometer
+     new_dimensions = int(img.shape[0] * scaling_factor), int(
+         img.shape[1] * scaling_factor
+     )
+     if atom_prior is None:
+         return np.array(Image.fromarray(img).resize(new_dimensions)), scaling_factor
+     else:
+         raise NotImplementedError
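Editor's note: the scaling arithmetic is easy to check by hand: training images were 512 px across 15 nm, i.e. ~34.1 px/nm, so an image whose ruler spans 100 px for 5 nm (20 px/nm) is upscaled by ~1.71. A quick sketch using the function above:

import numpy as np
from atoms_detection.image_preprocessing import rescale_img_to_target_pxls_nm

img = np.random.rand(256, 256).astype(np.float32)
rescaled, factor = rescale_img_to_target_pxls_nm(img, ruler_pixels=100, ruler_units=5)
# target 512/15 ≈ 34.13 px/nm; current 20 px/nm -> factor ≈ 1.707
assert abs(factor - (512 / 15) / 20) < 1e-9
assert rescaled.shape == (int(256 * factor), int(256 * factor))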
atoms_detection/model.py ADDED
@@ -0,0 +1,108 @@
+ import torch
+ from torch import nn
+
+
+ class BasicCNN(nn.Module):
+
+     def __init__(self, num_classes):
+         super().__init__()
+
+         self.features = nn.Sequential(
+             nn.Conv2d(1, 32, kernel_size=3, stride=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(32, 64, kernel_size=3, stride=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(64, 128, kernel_size=3, stride=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=True)
+         )
+         self.adaptive = nn.AdaptiveAvgPool2d((3, 3))
+
+         self.fc_layers = nn.Sequential(
+             nn.Linear(3 * 3 * 128, 128),
+             nn.ReLU(inplace=True),
+             nn.Linear(128, 128),
+             nn.ReLU(inplace=True)
+         )
+         self.fc3 = nn.Linear(128, num_classes)
+
+         self._initialize_weights()
+
+     # Defining the forward pass
+     def forward(self, x):
+         x = self.features(x)
+         x = self.adaptive(x)
+         x = torch.flatten(x, 1)
+         x = self.fc_layers(x)
+         x = self.fc3(x)
+         return x
+
+     def _initialize_weights(self):
+         for layer in self.features:
+             if isinstance(layer, nn.Conv2d):
+                 nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
+                 nn.init.constant_(layer.bias, 0)
+             elif isinstance(layer, nn.BatchNorm2d):
+                 nn.init.constant_(layer.weight, 1)
+                 nn.init.constant_(layer.bias, 0)
+         for layer in self.fc_layers:
+             if isinstance(layer, nn.Linear):
+                 nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
+                 nn.init.constant_(layer.bias, 0)
+         nn.init.normal_(self.fc3.weight, 0, 0.01)
+
+
+ class HeatCNN(nn.Module):
+
+     def __init__(self, num_classes):
+         super().__init__()
+
+         self.features = nn.Sequential(
+             nn.Conv2d(1, 32, kernel_size=3, stride=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(32, 64, kernel_size=3, stride=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(64, 128, kernel_size=3, stride=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=True)
+         )
+         self.adaptive = nn.AdaptiveAvgPool2d((3, 3))
+
+         self.fc_layers = nn.Sequential(
+             nn.Linear(3 * 3 * 128, 64),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(64, 64),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+         )
+         self.fc3 = nn.Linear(64, num_classes)
+
+         self._initialize_weights()
+
+     # Defining the forward pass
+     def forward(self, x):
+         x = self.features(x)
+         x = self.adaptive(x)
+         x = torch.flatten(x, 1)
+         x = self.fc_layers(x)
+         x = self.fc3(x)
+         return x
+
+     def _initialize_weights(self):
+         for layer in self.features:
+             if isinstance(layer, nn.Conv2d):
+                 nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
+                 nn.init.constant_(layer.bias, 0)
+             elif isinstance(layer, nn.BatchNorm2d):
+                 nn.init.constant_(layer.weight, 1)
+                 nn.init.constant_(layer.bias, 0)
+         for layer in self.fc_layers:
+             if isinstance(layer, nn.Linear):
+                 nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
+                 nn.init.constant_(layer.bias, 0)
+         nn.init.normal_(self.fc3.weight, 0, 0.01)
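Editor's note: both networks accept single-channel crops of arbitrary size, since the adaptive pooling fixes the flattened dimension. A quick shape check, assuming the 21x21 crops used elsewhere in this repo:

import torch
from atoms_detection.model import BasicCNN, HeatCNN

x = torch.randn(8, 1, 21, 21)  # batch of 8 single-channel crops
for net_cls in (BasicCNN, HeatCNN):
    net = net_cls(num_classes=2)
    net.eval()  # disable dropout in HeatCNN for a deterministic pass
    with torch.no_grad():
        logits = net(x)
    assert logits.shape == (8, 2)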
atoms_detection/multimetallic_analysis.py ADDED
@@ -0,0 +1,229 @@
+ # run VAE + GMM assignment
+ import argparse
+ from typing import List
+
+ import numpy as np
+ # import rasterio  # only needed for get_crops_from_folder
+ import torch
+ import warnings
+ import os
+ import re
+ import pandas as pd
+ from PIL import Image
+
+ from sklearn.mixture import GaussianMixture
+
+ from atoms_detection.create_crop_dataset import create_crop
+ from atoms_detection.vae_model import rVAE
+ from atoms_detection.vae_svi_train import init_dataloader, SVItrainer
+ from atoms_detection.image_preprocessing import dl_prepro_image
+
+ """
+ Code sourced from:
+ https://colab.research.google.com/github/ziatdinovmax/notebooks_for_medium/blob/main/pyroVAE_MNIST_medium.ipynb
+ """
+
+ numbers = re.compile(r'(\d+)')
+
+
+ def numericalSort(value):
+     parts = numbers.split(value)
+     parts[1::2] = map(int, parts[1::2])
+     return parts
+
+
+ warnings.filterwarnings("ignore", module="torchvision.datasets")
+
+
+ def get_crops_from_prediction_csvs(pred_crop_file):
+     data = pd.read_csv(pred_crop_file)
+     xx = data['x'].values
+     yy = data['y'].values
+     coords = zip(xx, yy)
+
+     img_file = data['Filename'][0]
+     likelihood = data['Likelihood'].values
+     img_path = os.path.join('data/tif_data', img_file)
+
+     img = Image.open(img_path)
+     np_img = np.asarray(img).astype(np.float64)
+     np_img = dl_prepro_image(np_img)
+     img = Image.fromarray(np_img)
+
+     crops = list()
+     coords_list = []
+     for x, y in coords:
+         coords_list.append([x, y])
+         new_crop = create_crop(img, x, y)
+         crops.append(new_crop)
+
+     print(coords_list[0])
+     print(np_img[0])
+
+     coords_array = np.array(coords_list)
+
+     return crops, coords_array, likelihood, img_file
+
+
+ def classify_crop_species(args):
+     # crop_list = get_crops_from_folder(crops_source_folder='./Ni')
+     crop_list, crop_coords, likelihood, img_filename = get_crops_from_prediction_csvs(args.pred_crop_file)
+     crop_tensor = np.array(crop_list)
+
+     # Convert each crop Image to a float32 NumPy array
+     processed_images = []
+     for image in crop_tensor:
+         image_array = np.array(image)
+         processed_images.append(image_array)
+     processed_images = np.array(processed_images)
+     processed_images = processed_images.astype(np.float32)
+
+     rvae = rVAE(in_dim=(21, 21), latent_dim=args.latent_dim, coord=args.coord, seed=args.seed)
+
+     train_data = torch.from_numpy(processed_images).float()
+     train_loader = init_dataloader(train_data, batch_size=args.batchsize)
+     latent_crop_tensor = train_vae(rvae, train_data, train_loader, args)
+
+     gmm = GaussianMixture(n_components=args.n_species, reg_covar=args.GMMcovar, random_state=args.seed).fit(
+         latent_crop_tensor)
+     preds = gmm.predict(latent_crop_tensor)
+     print(preds)
+     pred_proba = gmm.predict_proba(latent_crop_tensor)
+     pred_proba = [pred_proba[i, pred] for i, pred in enumerate(preds)]
+
+     # To order clusters, use the signal-to-noise ratio OR the median (across crops) of
+     # some intensity statistic (e.g. mean of the top-5% intensities)
+     cluster_median_values = list()
+     for k in range(args.n_species):
+         print(k)
+         relevant_crops = processed_images[preds == k]
+         crop_95_percentile = np.percentile(relevant_crops, q=95, axis=0)
+         img_means = []
+         for crop, q in zip(relevant_crops, crop_95_percentile):
+             if (crop >= q).any():
+                 print(crop.mean())
+                 img_means.append(crop.mean())
+         cluster_median_value = np.median(np.array(img_means))
+         cluster_median_values.append(cluster_median_value)
+     sorted_clusters = sorted([(mval, c_id) for c_id, mval in enumerate(cluster_median_values)])
+
+     with open(f"data/detection_data/Multimetallic_{img_filename}.csv", "a") as f:
+         f.write("Filename,x,y,Likelihood,cluster,cluster_confidence\n")
+         for _, c_id in sorted_clusters:
+             c_idd = np.array([c_id])
+             pred_proba = np.array(pred_proba)
+             relevant_crops_coords = crop_coords[preds == c_idd]
+             relevant_crops_likelihood = likelihood[preds == c_idd]
+             relevant_crops_confidence = pred_proba[preds == c_idd]
+             for coords, l, c in zip(relevant_crops_coords, relevant_crops_likelihood, relevant_crops_confidence):
+                 x, y = coords
+                 f.write(f"{img_filename},{x},{y},{l},{c_id},{c}\n")
+
+
+ def train_vae(rvae, train_data, train_loader, args):
+     # Initialize SVI trainer
+     trainer = SVItrainer(rvae)
+     for e in range(args.epochs):
+         trainer.step(train_loader, scale_factor=args.scale_factor)
+         trainer.print_statistics()
+     z_mean, z_sd = rvae.encode(train_data)
+     latent_crop_tensor = z_mean
+     return latent_crop_tensor
+
+
+ def get_crops_from_folder(crops_source_folder) -> List[np.ndarray]:
+     ffiles = []
+     files = []
+     for dirname, dirnames, filenames in os.walk(crops_source_folder):
+         # collect paths to all subdirectories first
+         for subdirname in dirnames:
+             files.append(os.path.join(dirname, subdirname))
+
+         # collect paths to all filenames
+         for filename in filenames:
+             files.append(os.path.join(dirname, filename))
+
+         for filename in sorted(filenames, key=numericalSort):
+             ffiles.append(os.path.join(filename))
+     crops = ffiles
+     path_crops = './Ni/'
+     all_img = []
+     for i in range(0, len(crops)):
+         src_path = path_crops + crops[i]
+         img = rasterio.open(src_path)  # requires the rasterio import above
+         test = np.reshape(img.read([1]), (21, 21))
+         all_img.append(np.array(test))
+     return all_img
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         'pred_crop_file',
+         type=str,
+         help="Path to the CSV of predicted crop locations (e.g. in data/detection_data/X/Y.csv)"
+     )
+     parser.add_argument(
+         "-latent_dim",
+         type=int,
+         default=50,
+         help="Dimensionality of the VAE latent space"
+     )
+     parser.add_argument(
+         "-seed",
+         type=int,
+         default=444,
+         help="Random seed"
+     )
+     parser.add_argument(
+         "-coord",
+         type=int,
+         default=3,
+         help="Amount of equivariances, 0: None, 1: Rotational, 2: Translational, 3: Rotational and Translational"
+     )
+     parser.add_argument(
+         "-batchsize",
+         type=int,
+         default=100,
+         help="Batch size for the VAE model"
+     )
+     parser.add_argument(
+         "-epochs",
+         type=int,
+         default=20,
+         help="Number of training epochs for the VAE"
+     )
+     parser.add_argument(
+         "-scale_factor",
+         type=int,
+         default=3,
+         help="Scale factor applied to the SVI loss during training"
+     )
+     parser.add_argument(
+         "-n_species",
+         type=int,
+         default=2,
+         help="Number of chemical species expected in the sample."
+     )
+     parser.add_argument(
+         "-GMMcovar",
+         type=float,
+         default=0.0001,
+         help="reg_covar for the training of the GMM clustering algorithm."
+     )
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     print(args)
+     classify_crop_species(args)
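Editor's note: the `pred_proba[i, pred]` pattern above pulls, for each crop, the posterior probability of its own assigned cluster. A self-contained sketch of the same idea on toy 2-D data:

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(444)
pts = np.vstack([rng.normal(0, 0.3, (50, 2)), rng.normal(3, 0.3, (50, 2))])
gmm = GaussianMixture(n_components=2, random_state=444).fit(pts)
preds = gmm.predict(pts)
proba = gmm.predict_proba(pts)
confidence = proba[np.arange(len(pts)), preds]  # per-sample cluster confidence
assert confidence.min() >= 0.5  # the argmax of a 2-component posterior is at least 0.5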
atoms_detection/testing_model.py ADDED
@@ -0,0 +1,53 @@
+ import os
+
+ import torch
+ from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
+ from torch.utils.data import DataLoader
+ import matplotlib.pyplot as plt
+
+ from atoms_detection.training_model import model_pipeline, get_args
+ from atoms_detection.dataset import CropsDataset
+ from atoms_detection.training import test_epoch
+ from utils.cf_matrix import make_confusion_matrix
+ from utils.paths import MODELS_PATH, CM_VIS_PATH
+
+
+ def main(args):
+     # Device selection for PyTorch (MPS if available, else CPU)
+     # use_cuda = torch.cuda.is_available()
+     use_cuda = torch.backends.mps.is_available()
+     device = torch.device("mps" if use_cuda else "cpu")
+
+     test_dataset = CropsDataset.test_dataset()
+     test_dataloader = DataLoader(test_dataset, batch_size=64)
+
+     ckpt_filename = os.path.join(MODELS_PATH, f'{args.experiment_name}.ckpt')
+     checkpoint = torch.load(ckpt_filename, map_location=device)
+
+     model = model_pipeline[args.model](num_classes=test_dataset.get_n_labels()).to(device)
+     model.load_state_dict(checkpoint['state_dict'])
+
+     if torch.cuda.device_count() > 1:
+         print("Using {} GPUs!".format(torch.cuda.device_count()))
+         model = torch.nn.DataParallel(model)
+
+     loss_function = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
+
+     y_true, y_pred = test_epoch(test_dataloader, model, loss_function, device)
+
+     cm = confusion_matrix(y_true, y_pred)
+     labels = ["True Neg", "False Pos", "False Neg", "True Pos"]
+     make_confusion_matrix(cm, group_names=labels, cbar_range=(0, 110))
+     if not os.path.exists(CM_VIS_PATH):
+         os.makedirs(CM_VIS_PATH)
+     plt.savefig(os.path.join(CM_VIS_PATH, f"cm_{args.experiment_name}.jpg"))
+     f1 = f1_score(y_true, y_pred)
+     acc = accuracy_score(y_true, y_pred)
+     with open(os.path.join(CM_VIS_PATH, f"metrics_{args.experiment_name}.txt"), 'w') as _log:
+         _log.write(f"F1_score: {f1}\nACCURACY: {acc}\n")
+     print(f"F1_score: {f1}")
+     print(f"ACCURACY: {acc}")
+
+
+ if __name__ == "__main__":
+     main(get_args())
atoms_detection/training.py ADDED
@@ -0,0 +1,145 @@
+ import time
+
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+
+
+ def train_epoch(train_loader, model, loss_function, optimizer, device, epoch):
+     model.train()
+
+     correct = 0
+     total = 0
+     losses = 0
+     t0 = time.time()
+     for idx, (batch_images, batch_labels) in enumerate(train_loader):
+         # Loading tensors in the used device
+         step_images, step_labels = batch_images.to(device), batch_labels.to(device)
+
+         # zero the parameter gradients
+         optimizer.zero_grad()
+
+         step_output = model(step_images)
+         loss = loss_function(step_output, step_labels)
+         loss.backward()
+         optimizer.step()
+
+         step_total = step_labels.size(0)
+         step_loss = loss.item()
+         losses += step_loss*step_total
+         total += step_total
+
+         step_preds = torch.max(step_output.data, 1)[1]
+         step_correct = (step_preds == step_labels).sum().item()
+         correct += step_correct
+
+     train_loss = losses / total
+     train_acc = correct / total
+     format_args = (epoch, train_acc, train_loss, time.time() - t0)
+     print('EPOCH {} :: train accuracy: {:.4f} - train loss: {:.4f} at {:.1f}s'.format(*format_args))
+
+
+ def val_epoch(val_loader, model, loss_function, device, epoch):
+     model.eval()
+
+     y_true = []
+     y_pred = []
+
+     correct = 0
+     total = 0
+     losses = 0
+     t0 = time.time()
+     with torch.no_grad():
+         for batch_images, batch_labels in val_loader:
+             # Loading tensors in the used device
+             step_images, step_labels = batch_images.to(device), batch_labels.to(device)
+
+             step_output = model(step_images)
+             loss = loss_function(step_output, step_labels)
+
+             step_total = step_labels.size(0)
+             step_loss = loss.item()
+             losses += step_loss*step_total
+             total += step_total
+
+             step_preds = torch.max(step_output.data, 1)[1]
+             y_pred.append(step_preds.cpu().detach().numpy())
+             y_true.append(step_labels.cpu().detach().numpy())
+             step_correct = (step_preds == step_labels).sum().item()
+             correct += step_correct
+
+     val_loss = losses / total
+     val_acc = correct / total
+     format_args = (epoch, val_acc, val_loss, time.time() - t0)
+     print('EPOCH {} :: val accuracy: {:.4f} - val loss: {:.4f} at {:.1f}s'.format(*format_args))
+
+     y_pred = np.concatenate(y_pred, axis=0)
+     y_true = np.concatenate(y_true, axis=0)
+     return y_true, y_pred
+
+
+ def test_epoch(test_loader, model, loss_function, device):
+     model.eval()
+
+     correct = 0
+     total = 0
+     losses = 0
+     all_true = []
+     all_pred = []
+     t0 = time.time()
+     with torch.no_grad():
+         for batch_images, batch_labels in test_loader:
+             # Loading tensors in the used device
+             step_images, step_labels = batch_images.to(device), batch_labels.to(device)
+
+             step_output = model(step_images)
+             loss = loss_function(step_output, step_labels)
+
+             step_total = step_labels.size(0)
+             step_loss = loss.item()
+             losses += step_loss*step_total
+             total += step_total
+
+             step_preds = torch.max(step_output.data, 1)[1]
+             step_correct = (step_preds == step_labels).sum().item()
+             correct += step_correct
+
+             all_true.append(step_labels.cpu().numpy())
+             all_pred.append(step_preds.cpu().numpy())
+
+     val_loss = losses / total
+     val_acc = correct / total
+     format_args = (val_acc, val_loss, time.time() - t0)
+     print('EPOCH :: test accuracy: {:.4f} - test loss: {:.4f} at {:.1f}s'.format(*format_args))
+
+     all_pred = np.concatenate(all_pred, axis=0)
+     all_true = np.concatenate(all_true, axis=0)
+     return all_true, all_pred
+
+
+ def detection_epoch(detection_loader, model, device):
+     model.eval()
+
+     pred_probs = []
+     coords_x = []
+     coords_y = []
+     t0 = time.time()
+     with torch.no_grad():
+         for batch_images, batch_x, batch_y in detection_loader:
+             # Loading tensors in the used device
+             step_images = batch_images.to(device)
+             step_output = model(step_images)
+             step_pred_probs = F.softmax(step_output, 1)
+
+             step_pred_probs = step_pred_probs.cpu().numpy()
+             step_x = batch_x.numpy()
+             step_y = batch_y.numpy()
+
+             coords_x.append(step_x)
+             coords_y.append(step_y)
+             pred_probs.append(step_pred_probs)
+
+     return_pred_probs = np.concatenate(pred_probs, axis=0)
+     return_coords_x = np.concatenate(coords_x, axis=0)
+     return_coords_y = np.concatenate(coords_y, axis=0)
+     return return_pred_probs, return_coords_x, return_coords_y
atoms_detection/training_model.py ADDED
@@ -0,0 +1,111 @@
+ import os
+ import argparse
+ import random
+
+ import torch
+ import numpy as np
+ import pandas as pd
+ from torch.utils.data import DataLoader
+ from torchvision.models import resnet18
+
+ from utils.paths import MODELS_PATH, CROPS_PATH, CROPS_DATASET
+ from utils.constants import ModelArgs, Split, CropsColumns
+ from atoms_detection.training import train_epoch, val_epoch
+ from atoms_detection.dataset import ImageClassificationDataset
+ from atoms_detection.model import BasicCNN
+
+
+ torch.manual_seed(777)
+ random.seed(777)
+ np.random.seed(777)
+
+
+ def get_basic_cnn(*args, **kwargs):
+     model = BasicCNN(*args, **kwargs)
+     return model
+
+
+ def get_resnet(*args, **kwargs):
+     model = resnet18(*args, **kwargs)
+     # Adapt the stem to single-channel inputs
+     model.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
+     return model
+
+
+ model_pipeline = {
+     ModelArgs.BASICCNN: get_basic_cnn,
+     ModelArgs.RESNET18: get_resnet
+ }
+
+ epochs_pipeline = {
+     ModelArgs.BASICCNN: 12,
+     ModelArgs.RESNET18: 3
+ }
+
+
+ def train_model(model_arg: ModelArgs, crops_dataset: str, crops_path: str, ckpt_filename: str):
+
+     class CropsDataset(ImageClassificationDataset):
+         @staticmethod
+         def get_filenames_labels(split: Split):
+             df = pd.read_csv(crops_dataset)
+             split_df = df[df[CropsColumns.SPLIT] == split]
+             filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list()
+             labels = (split_df[CropsColumns.LABEL]).to_list()
+             return filenames, labels
+
+     # Device selection for PyTorch (MPS if available, else CPU)
+     # use_cuda = torch.cuda.is_available()
+     use_cuda = torch.backends.mps.is_available()
+     device = torch.device("mps" if use_cuda else "cpu")
+
+     train_dataset = CropsDataset.train_dataset()
+     val_dataset = CropsDataset.val_dataset()
+     train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
+     val_dataloader = DataLoader(val_dataset, batch_size=64)
+     model = model_pipeline[model_arg](num_classes=train_dataset.get_n_labels()).to(device)
+
+     if torch.cuda.device_count() > 1:
+         print("Using {} GPUs!".format(torch.cuda.device_count()))
+         model = torch.nn.DataParallel(model)
+
+     loss_function = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
+     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)
+
+     epoch = 0
+     for epoch in range(epochs_pipeline[model_arg]):
+         train_epoch(train_dataloader, model, loss_function, optimizer, device, epoch)
+         val_epoch(val_dataloader, model, loss_function, device, epoch)
+
+     if not os.path.exists(MODELS_PATH):
+         os.makedirs(MODELS_PATH)
+
+     state = {
+         'state_dict': model.state_dict(),
+         'optimizer': optimizer.state_dict(),
+         'epoch': epoch
+     }
+     torch.save(state, ckpt_filename)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "experiment_name",
+         type=str,
+         help="Experiment name"
+     )
+     parser.add_argument(
+         "model",
+         type=ModelArgs,
+         help="model architecture",
+         choices=list(ModelArgs)
+     )
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     extension_name = "replicate"
+     ckpt_filename = os.path.join(MODELS_PATH, "basic_replicate2.ckpt")
+     crops_folder = CROPS_PATH + f"_{extension_name}"
+     train_model(ModelArgs.BASICCNN, CROPS_DATASET, CROPS_PATH, ckpt_filename)
atoms_detection/vae_image_utils.py ADDED
@@ -0,0 +1,53 @@
+ # @title Load utility functions for working with image coordinates and labels
+
+ from typing import Tuple, Union
+
+ import torch
+
+ import warnings
+ warnings.filterwarnings("ignore", module="torchvision.datasets")
+
+
+ def to_onehot(idx: torch.Tensor, n: int) -> torch.Tensor:
+     """
+     One-hot encoding of a label
+     """
+     if torch.max(idx).item() >= n:
+         raise AssertionError(
+             "Labelling must start from 0 and "
+             "maximum label value must be less than total number of classes")
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     if idx.dim() == 1:
+         idx = idx.unsqueeze(1)
+     onehot = torch.zeros(idx.size(0), n, device=device)
+     return onehot.scatter_(1, idx.to(device), 1)
+
+
+ def grid2xy(X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor:
+     X = torch.cat((X1[None], X2[None]), 0)
+     d0, d1 = X.shape[0], X.shape[1] * X.shape[2]
+     X = X.reshape(d0, d1).T
+     return X
+
+
+ def imcoordgrid(im_dim: Tuple) -> torch.Tensor:
+     xx = torch.linspace(-1, 1, im_dim[0])
+     yy = torch.linspace(1, -1, im_dim[1])
+     x0, x1 = torch.meshgrid(xx, yy)
+     return grid2xy(x0, x1)
+
+
+ def transform_coordinates(coord: torch.Tensor,
+                           phi: Union[torch.Tensor, float] = 0,
+                           coord_dx: Union[torch.Tensor, float] = 0,
+                           ) -> torch.Tensor:
+     if torch.sum(phi) == 0:
+         phi = coord.new_zeros(coord.shape[0])
+     rotmat_r1 = torch.stack([torch.cos(phi), torch.sin(phi)], 1)
+     rotmat_r2 = torch.stack([-torch.sin(phi), torch.cos(phi)], 1)
+     rotmat = torch.stack([rotmat_r1, rotmat_r2], 1)
+     coord = torch.bmm(coord, rotmat)
+
+     return coord + coord_dx
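Editor's note: a quick check of the coordinate helpers. Under this parameterisation a batched rotation by pi/2 maps each point (x, y) to (-y, x), so the grid corner (-1, 1) lands on (-1, -1):

import math
import torch
from atoms_detection.vae_image_utils import imcoordgrid, transform_coordinates

grid = imcoordgrid((4, 4)).unsqueeze(0)   # (1, 16, 2) grid over [-1, 1]^2
phi = torch.tensor([math.pi / 2])         # rotate by 90 degrees
rotated = transform_coordinates(grid, phi)
x, y = grid[0, 0]                         # first grid point: (-1, 1)
rx, ry = rotated[0, 0]
assert torch.allclose(rx, -y) and torch.allclose(ry, x)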
atoms_detection/vae_model.py ADDED
@@ -0,0 +1,345 @@
+ import torch
+ import torch.nn as nn
+ from torch import tensor as tt
+
+ from typing import Optional, Tuple
+
+ import pyro
+ import pyro.distributions as dist
+
+ import warnings
+
+ from atoms_detection.vae_image_utils import imcoordgrid, to_onehot, transform_coordinates
+
+ warnings.filterwarnings("ignore", module="torchvision.datasets")
+
+ # VAE model set-up
+
+
+ def set_deterministic_mode(seed: int) -> None:
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+
+ def make_fc_layers(in_dim: int,
+                    hidden_dim: int = 128,
+                    num_layers: int = 2,
+                    activation: str = "tanh"
+                    ) -> nn.Sequential:
+     """
+     Generates a module with stacked fully-connected (aka dense) layers
+     """
+     activations = {"tanh": nn.Tanh, "lrelu": nn.LeakyReLU, "softplus": nn.Softplus}
+     fc_layers = []
+     for i in range(num_layers):
+         hidden_dim_ = in_dim if i == 0 else hidden_dim
+         fc_layers.extend(
+             [nn.Linear(hidden_dim_, hidden_dim), activations[activation]()])
+     return nn.Sequential(*fc_layers)
+
+
+ class fcEncoderNet(nn.Module):
+     """
+     Simple fully-connected inference (encoder) network
+     """
+     def __init__(self,
+                  in_dim: Tuple[int, int],
+                  latent_dim: int = 2,
+                  hidden_dim: int = 128,
+                  num_layers: int = 2,
+                  activation: str = 'tanh',
+                  softplus_out: bool = False
+                  ) -> None:
+         """
+         Initializes module parameters
+         """
+         super(fcEncoderNet, self).__init__()
+         if len(in_dim) not in [1, 2, 3]:
+             raise ValueError("in_dim must be (h, w), (h, w, c), or (h*w*c,)")
+         self.in_dim = torch.prod(tt(in_dim)).item()
+
+         self.fc_layers = make_fc_layers(
+             self.in_dim, hidden_dim, num_layers, activation)
+         self.fc11 = nn.Linear(hidden_dim, latent_dim)
+         self.fc12 = nn.Linear(hidden_dim, latent_dim)
+         self.activation_out = nn.Softplus() if softplus_out else lambda x: x
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Forward pass
+         """
+         x = x.view(-1, self.in_dim)
+         x = self.fc_layers(x)
+         mu = self.fc11(x)
+         log_sigma = self.activation_out(self.fc12(x))
+         return mu, log_sigma
+
+
+ class fcDecoderNet(nn.Module):
+     """
+     Standard decoder for VAE
+     """
+     def __init__(self,
+                  out_dim: Tuple[int],
+                  latent_dim: int,
+                  hidden_dim: int = 128,
+                  num_layers: int = 2,
+                  activation: str = 'tanh',
+                  sigmoid_out: bool = True,
+                  ) -> None:
+         super(fcDecoderNet, self).__init__()
+         if len(out_dim) not in [1, 2, 3]:
+             raise ValueError("out_dim must be (h, w), (h, w, c), or (h*w*c,)")
+         self.reshape = out_dim
+         out_dim = torch.prod(tt(out_dim)).item()
+
+         self.fc_layers = make_fc_layers(
+             latent_dim, hidden_dim, num_layers, activation)
+         self.out = nn.Linear(hidden_dim, out_dim)
+         self.activation_out = nn.Sigmoid() if sigmoid_out else lambda x: x
+
+     def forward(self, z: torch.Tensor) -> torch.Tensor:
+         x = self.fc_layers(z)
+         x = self.activation_out(self.out(x))
+         return x.view(-1, *self.reshape)
+
+
+ class rDecoderNet(nn.Module):
+     """
+     Spatial generator (decoder) network with fully-connected layers
+     """
+     def __init__(self,
+                  out_dim: Tuple[int],
+                  latent_dim: int,
+                  hidden_dim: int = 128,
+                  num_layers: int = 2,
+                  activation: str = 'tanh',
+                  sigmoid_out: bool = True
+                  ) -> None:
+         """
+         Initializes module parameters
+         """
+         super(rDecoderNet, self).__init__()
+         if len(out_dim) not in [1, 2, 3]:
+             raise ValueError("out_dim must be (h, w), (h, w, c), or (h*w*c,)")
+         self.reshape = out_dim
+         out_dim = torch.prod(tt(out_dim)).item()
+
+         self.coord_latent = coord_latent(latent_dim, hidden_dim)
+         self.fc_layers = make_fc_layers(
+             hidden_dim, hidden_dim, num_layers, activation)
+         self.out = nn.Linear(hidden_dim, 1)  # need to generalize to multi-channel (c > 1)
+         self.activation_out = nn.Sigmoid() if sigmoid_out else lambda x: x
+
+     def forward(self, x_coord: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass
+         """
+         x = self.coord_latent(x_coord, z)
+         x = self.fc_layers(x)
+         x = self.activation_out(self.out(x))
+         return x.view(-1, *self.reshape)
+
+
+ class coord_latent(nn.Module):
+     """
+     The "spatial" part of the rVAE's decoder that allows for translational
+     and rotational invariance (based on https://arxiv.org/abs/1909.11663)
+     """
+     def __init__(self,
+                  latent_dim: int,
+                  out_dim: int,
+                  activation_out: bool = True) -> None:
+         """
+         Initializes module parameters
+         """
+         super(coord_latent, self).__init__()
+         self.fc_coord = nn.Linear(2, out_dim)
+         self.fc_latent = nn.Linear(latent_dim, out_dim, bias=False)
+         self.activation = nn.Tanh() if activation_out else None
+
+     def forward(self,
+                 x_coord: torch.Tensor,
+                 z: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass
+         """
+         batch_dim, n = x_coord.size()[:2]
+         x_coord = x_coord.reshape(batch_dim * n, -1)
+         h_x = self.fc_coord(x_coord)
+         h_x = h_x.reshape(batch_dim, n, -1)
+         h_z = self.fc_latent(z)
+         h = h_x.add(h_z.unsqueeze(1))
+         h = h.reshape(batch_dim * n, -1)
+         if self.activation is not None:
+             h = self.activation(h)
+         return h
+
+
+ class rVAE(nn.Module):
+     """
+     Variational autoencoder with rotational and/or translational invariance
+     """
+     def __init__(self,
+                  in_dim: Tuple[int, int],
+                  latent_dim: int = 2,
+                  coord: int = 3,
+                  num_classes: int = 0,
+                  hidden_dim_e: int = 128,
+                  hidden_dim_d: int = 128,
+                  num_layers_e: int = 2,
+                  num_layers_d: int = 2,
+                  activation: str = "tanh",
+                  softplus_sd: bool = True,
+                  sigmoid_out: bool = True,
+                  seed: int = 1,
+                  **kwargs
+                  ) -> None:
+         """
+         Initializes rVAE's modules and parameters
+         """
+         super(rVAE, self).__init__()
+         pyro.clear_param_store()
+         set_deterministic_mode(seed)
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.encoder_net = fcEncoderNet(
+             in_dim, latent_dim+coord, hidden_dim_e,
+             num_layers_e, activation, softplus_sd)
+         if coord not in [0, 1, 2, 3]:
+             raise ValueError("'coord' argument must be 0, 1, 2 or 3")
+         dnet = rDecoderNet if coord in [1, 2, 3] else fcDecoderNet
+         self.decoder_net = dnet(
+             in_dim, latent_dim+num_classes, hidden_dim_d,
+             num_layers_d, activation, sigmoid_out)
+         self.z_dim = latent_dim + coord
+         self.coord = coord
+         self.num_classes = num_classes
+         self.grid = imcoordgrid(in_dim).to(self.device)
+         self.dx_prior = tt(kwargs.get("dx_prior", 0.1)).to(self.device)
+         self.to(self.device)
+
+     def model(self,
+               x: torch.Tensor,
+               y: Optional[torch.Tensor] = None,
+               **kwargs: float) -> None:
+         """
+         Defines the model p(x|z)p(z)
+         """
+         # register PyTorch module `decoder_net` with Pyro
+         pyro.module("decoder_net", self.decoder_net)
+         # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
+         beta = kwargs.get("scale_factor", 1.)
+         reshape_ = torch.prod(tt(x.shape[1:])).item()
+         with pyro.plate("data", x.shape[0]):
+             # setup hyperparameters for prior p(z)
+             z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
+             z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
+             # sample from prior (value will be sampled by guide when computing the ELBO)
+             with pyro.poutine.scale(scale=beta):
+                 z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
+             if self.coord > 0:  # rotationally- and/or translationally-invariant mode
+                 # Split latent variable into parts for rotation
+                 # and/or translation and image content
+                 phi, dx, z = self.split_latent(z)
+                 if torch.sum(dx) != 0:
+                     dx = (dx * self.dx_prior).unsqueeze(1)
+                 # transform coordinate grid
+                 grid = self.grid.expand(x.shape[0], *self.grid.shape)
+                 x_coord_prime = transform_coordinates(grid, phi, dx)
+             # Add class label (if any)
+             if y is not None:
+                 y = to_onehot(y, self.num_classes)
+                 z = torch.cat([z, y], dim=-1)
+             # decode the latent code z together with the transformed coordinates (if any)
+             dec_args = (x_coord_prime, z) if self.coord else (z,)
+             loc_img = self.decoder_net(*dec_args)
+             # score against actual images ("binary cross-entropy loss")
+             pyro.sample(
+                 "obs", dist.Bernoulli(loc_img.view(-1, reshape_), validate_args=False).to_event(1),
+                 obs=x.view(-1, reshape_))
+
+     def guide(self,
+               x: torch.Tensor,
+               y: Optional[torch.Tensor] = None,
+               **kwargs: float) -> None:
+         """
+         Defines the guide q(z|x)
+         """
+         # register PyTorch module `encoder_net` with Pyro
+         pyro.module("encoder_net", self.encoder_net)
+         # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
+         beta = kwargs.get("scale_factor", 1.)
+         with pyro.plate("data", x.shape[0]):
+             # use the encoder to get the parameters used to define q(z|x)
+             z_loc, z_scale = self.encoder_net(x)
+             # sample the latent code z
+             with pyro.poutine.scale(scale=beta):
+                 pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
+
+     def split_latent(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Split latent variable into parts for rotation
+         and/or translation and image content
+         """
+         phi, dx = tt(0), tt(0)
+         # rotation + translation
+         if self.coord == 3:
+             phi = z[:, 0]    # encoded angle
+             dx = z[:, 1:3]   # translation
+             z = z[:, 3:]     # image content
+         # translation only
+         elif self.coord == 2:
+             dx = z[:, :2]
+             z = z[:, 2:]
+         # rotation only
+         elif self.coord == 1:
+             phi = z[:, 0]
+             z = z[:, 1:]
+         return phi, dx, z
+
+     def _encode(self, x_new: torch.Tensor, **kwargs: int) -> torch.Tensor:
+         """
+         Encodes data using a trained inference (encoder) network
+         in a batch-by-batch fashion
+         """
+         def inference() -> torch.Tensor:
+             with torch.no_grad():
+                 encoded = self.encoder_net(x_i)
+                 encoded = torch.cat(encoded, -1).cpu()
+             return encoded
+
+         x_new = x_new.to(self.device)
+         num_batches = kwargs.get("num_batches", 10)
+         batch_size = len(x_new) // num_batches
+         z_encoded = []
+         for i in range(num_batches):
+             x_i = x_new[i*batch_size:(i+1)*batch_size]
+             z_encoded_i = inference()
+             z_encoded.append(z_encoded_i)
+         # encode the remainder that did not fit into the equal-sized batches
+         x_i = x_new[(i+1)*batch_size:]
+         if len(x_i) > 0:
+             z_encoded_i = inference()
+             z_encoded.append(z_encoded_i)
+         return torch.cat(z_encoded)
+
+     def encode(self, x_new: torch.Tensor, **kwargs: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Encodes data using a trained inference (encoder) network
+         (this is basically a wrapper for self._encode)
+         """
+         if isinstance(x_new, torch.utils.data.DataLoader):
+             x_new = x_new.dataset.tensors[0]  # was `train_loader.dataset...`, an undefined name
+         z = self._encode(x_new, **kwargs)
+         z_loc = z[:, :self.z_dim]
+         z_scale = z[:, self.z_dim:]
+         return z_loc, z_scale
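A short sketch of driving the model above end to end (the 21x21 crop size matches the sliding-window size used elsewhere in this commit; the random tensor is a placeholder for real atom crops):

    import torch
    from atoms_detection.vae_model import rVAE

    model = rVAE(in_dim=(21, 21), latent_dim=2, coord=3)  # rotation + translation invariant
    crops = torch.rand(64, 21, 21)                        # placeholder batch of crops in [0, 1]
    z_loc, z_scale = model.encode(crops)                  # each of shape (64, latent_dim + coord)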
atoms_detection/vae_svi_train.py ADDED
@@ -0,0 +1,121 @@
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ import pyro
+ import pyro.infer as infer
+ import pyro.optim as optim
+
+ import warnings
+
+ from atoms_detection.vae_model import set_deterministic_mode
+
+ warnings.filterwarnings("ignore", module="torchvision.datasets")
+
+
+ class SVItrainer:
+     """
+     Stochastic variational inference (SVI) trainer for
+     unsupervised and class-conditioned variational models
+     """
+     def __init__(self,
+                  model: nn.Module,
+                  optimizer: Optional[optim.PyroOptim] = None,
+                  loss: Optional[infer.ELBO] = None,
+                  seed: int = 1
+                  ) -> None:
+         """
+         Initializes the trainer's parameters
+         """
+         pyro.clear_param_store()
+         set_deterministic_mode(seed)
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         if optimizer is None:
+             optimizer = optim.Adam({"lr": 1.0e-3})
+         if loss is None:
+             loss = infer.Trace_ELBO()
+         self.svi = infer.SVI(model.model, model.guide, optimizer, loss=loss)
+         self.loss_history = {"training_loss": [], "test_loss": []}
+         self.current_epoch = 0
+
+     def train(self,
+               train_loader: torch.utils.data.DataLoader,
+               **kwargs: float) -> float:
+         """
+         Trains a single epoch
+         """
+         # initialize loss accumulator
+         epoch_loss = 0.
+         # do a training epoch over each mini-batch returned by the data loader
+         for data in train_loader:
+             if len(data) == 1:  # VAE mode
+                 x = data[0]
+                 loss = self.svi.step(x.to(self.device), **kwargs)
+             else:  # VED or cVAE mode
+                 x, y = data
+                 loss = self.svi.step(
+                     x.to(self.device), y.to(self.device), **kwargs)
+             # accumulate loss (svi.step already performed the ELBO gradient update)
+             epoch_loss += loss
+
+         return epoch_loss / len(train_loader.dataset)
+
+     def evaluate(self,
+                  test_loader: torch.utils.data.DataLoader,
+                  **kwargs: float) -> float:
+         """
+         Evaluates the current model state on a single epoch
+         """
+         # initialize loss accumulator
+         test_loss = 0.
+         # compute the loss over the entire test set
+         with torch.no_grad():
+             for data in test_loader:
+                 if len(data) == 1:  # VAE mode
+                     x = data[0]
+                     # evaluate_loss (not step), so no gradient update happens during evaluation
+                     loss = self.svi.evaluate_loss(x.to(self.device), **kwargs)
+                 else:  # VED or cVAE mode
+                     x, y = data
+                     loss = self.svi.evaluate_loss(
+                         x.to(self.device), y.to(self.device), **kwargs)
+                 test_loss += loss
+
+         return test_loss / len(test_loader.dataset)
+
+     def step(self,
+              train_loader: torch.utils.data.DataLoader,
+              test_loader: Optional[torch.utils.data.DataLoader] = None,
+              **kwargs: float) -> None:
+         """
+         Single training and (optionally) evaluation step
+         """
+         self.loss_history["training_loss"].append(self.train(train_loader, **kwargs))
+         if test_loader is not None:
+             self.loss_history["test_loss"].append(self.evaluate(test_loader, **kwargs))
+         self.current_epoch += 1
+
+     def print_statistics(self) -> None:
+         """
+         Prints training and test (if any) losses for current epoch
+         """
+         e = self.current_epoch
+         if len(self.loss_history["test_loss"]) > 0:
+             template = 'Epoch: {} Training loss: {:.4f}, Test loss: {:.4f}'
+             print(template.format(e, self.loss_history["training_loss"][-1],
+                                   self.loss_history["test_loss"][-1]))
+         else:
+             template = 'Epoch: {} Training loss: {:.4f}'
+             print(template.format(e, self.loss_history["training_loss"][-1]))
+
+
+ def init_dataloader(*args: torch.Tensor, **kwargs: int
+                     ) -> torch.utils.data.DataLoader:
+     batch_size = kwargs.get("batch_size", 100)
+     tensor_set = torch.utils.data.dataset.TensorDataset(*args)
+     data_loader = torch.utils.data.DataLoader(
+         dataset=tensor_set, batch_size=batch_size, shuffle=True)
+     return data_loader
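Putting the trainer and the rVAE together, a minimal (hypothetical) training loop looks like this; the random data tensor stands in for a real stack of crops scaled to [0, 1], as the Bernoulli likelihood expects:

    import torch
    from atoms_detection.vae_model import rVAE
    from atoms_detection.vae_svi_train import SVItrainer, init_dataloader

    data = torch.rand(512, 21, 21)             # placeholder crops in [0, 1]
    loader = init_dataloader(data, batch_size=64)
    model = rVAE(in_dim=(21, 21), latent_dim=2, coord=3)
    trainer = SVItrainer(model)
    for _ in range(10):                        # 10 epochs
        trainer.step(loader)
        trainer.print_statistics()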
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio==3.24.1
+ matplotlib==3.7.1
+ networkx==3.0
+ numpy==1.23.5
+ opencv_contrib_python==4.7.0.72
+ pandas==1.5.3
+ Pillow==9.5.0
+ scikit_learn==1.2.2
+ scipy==1.10.1
+ seaborn==0.12.2
+ torch==1.13.1
+ torchvision==0.14.1
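The pinned versions above install in the usual way:

    pip install -r requirements.txt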
setup.py ADDED
@@ -0,0 +1,28 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ @author : Romain Graux
+ @date : 2023 April 06, 17:33:28
+ @last modified : 2023 April 24, 15:59:09
+ """
+
+ import os
+ import logging
+ from distutils.core import setup, Extension
+
+ logging.basicConfig(level=logging.INFO)
+
+ os.environ["CC"] = "g++"
+
+ fast_filters_module = Extension(
+     "fast_filters",
+     sources=["atoms_detection/fast_filters.cpp"],
+ )
+
+ setup(
+     name="atoms_detection",
+     version="0.0.1a0",
+     description="",
+     ext_modules=[fast_filters_module],
+ )
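Because utils/paths.py (below) resolves LIB_PATH from a build/lib* directory, the fast_filters extension is expected to be compiled before the detection code runs; with this setup.py that is the standard distutils build step:

    python setup.py build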
utils/__init__.py ADDED
File without changes
utils/cf_matrix.py ADDED
@@ -0,0 +1,111 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import seaborn as sns
+
+
+ def make_confusion_matrix(cf,
+                           group_names=None,
+                           categories='auto',
+                           count=True,
+                           percent=True,
+                           cbar=True,
+                           cbar_range=(None, None),
+                           xyticks=True,
+                           xyplotlabels=True,
+                           sum_stats=True,
+                           figsize=None,
+                           cmap='Blues',
+                           title=None):
+     '''
+     This function will make a pretty plot of an sklearn Confusion Matrix cm using a Seaborn heatmap visualization.
+
+     Arguments
+     ---------
+     cf: confusion matrix to be passed in
+
+     group_names: List of strings that represent the labels row by row to be shown in each square.
+
+     categories: List of strings containing the categories to be displayed on the x,y axis. Default is 'auto'
+
+     count: If True, show the raw number in the confusion matrix. Default is True.
+
+     percent: If True, show the proportions for each category. Default is True.
+
+     cbar: If True, show the color bar. The cbar values are based off the values in the confusion matrix.
+           Default is True.
+
+     cbar_range: Tuple (vmin, vmax) passed to the heatmap's color scale. Default is (None, None).
+
+     xyticks: If True, show x and y ticks. Default is True.
+
+     xyplotlabels: If True, show 'True Label' and 'Predicted Label' on the figure. Default is True.
+
+     sum_stats: If True, display summary statistics below the figure. Default is True.
+
+     figsize: Tuple representing the figure size. Default will be the matplotlib rcParams value.
+
+     cmap: Colormap of the values displayed from matplotlib.pyplot.cm. Default is 'Blues'
+           See http://matplotlib.org/examples/color/colormaps_reference.html
+
+     title: Title for the heatmap. Default is None.
+     '''
+
+     # CODE TO GENERATE TEXT INSIDE EACH SQUARE
+     blanks = ['' for i in range(cf.size)]
+
+     if group_names and len(group_names) == cf.size:
+         group_labels = ["{}\n".format(value) for value in group_names]
+     else:
+         group_labels = blanks
+
+     if count:
+         group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()]
+     else:
+         group_counts = blanks
+
+     if percent:
+         group_percentages = ["{0:.2%}".format(value) for value in cf.flatten() / np.sum(cf)]
+     else:
+         group_percentages = blanks
+
+     box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)]
+     box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1])
+
+     # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS
+     if sum_stats:
+         # Accuracy is sum of diagonal divided by total observations
+         accuracy = np.trace(cf) / float(np.sum(cf))
+
+         # if it is a binary confusion matrix, show some more stats
+         if len(cf) == 2:
+             # Metrics for Binary Confusion Matrices
+             precision = cf[1, 1] / sum(cf[:, 1])
+             recall = cf[1, 1] / sum(cf[1, :])
+             f1_score = 2 * precision * recall / (precision + recall)
+             stats_text = "\n\nAccuracy={:0.3f}\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(
+                 accuracy, precision, recall, f1_score)
+         else:
+             stats_text = "\n\nAccuracy={:0.3f}".format(accuracy)
+     else:
+         stats_text = ""
+
+     # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS
+     if figsize is None:
+         # Get default figure size if not set
+         figsize = plt.rcParams.get('figure.figsize')
+
+     if not xyticks:
+         # Do not show categories if xyticks is False
+         categories = False
+
+     # MAKE THE HEATMAP VISUALIZATION
+     plt.figure(figsize=figsize)
+     sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, vmin=cbar_range[0], vmax=cbar_range[1], xticklabels=categories, yticklabels=categories)
+
+     if xyplotlabels:
+         plt.ylabel('True label')
+         plt.xlabel('Predicted label' + stats_text)
+     else:
+         plt.xlabel(stats_text)
+
+     if title:
+         plt.title(title)
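A small usage example for make_confusion_matrix with a hypothetical 2x2 matrix (rows are true labels, columns are predictions), which also exercises the binary summary statistics:

    import numpy as np
    from utils.cf_matrix import make_confusion_matrix

    cf = np.array([[50, 5],
                   [3, 42]])
    make_confusion_matrix(cf,
                          group_names=['TN', 'FP', 'FN', 'TP'],
                          categories=['background', 'atom'],
                          title='Atom detection')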
utils/constants.py ADDED
@@ -0,0 +1,61 @@
+ from enum import Enum
+
+
+ class Catalyst(Enum):
+     Pt = 'Pt'
+     Fe = 'Fe'
+
+     def __str__(self):
+         return str(self.value)
+
+
+ class Method(Enum):
+     DL = 'DL'
+     CV = 'CV'
+     TEM = 'TEMImageNet'
+
+     def __str__(self):
+         return str(self.value)
+
+
+ class Split:
+     TRAIN = 'train'
+     VAL = 'val'
+     TEST = 'test'
+
+
+ class Columns:
+     FILENAME = 'Filename'
+     LABEL = 'Label'
+     SPLIT = 'Split'
+
+
+ class CropsColumns:
+     FILENAME = 'Filename'
+     ORIGINAL = 'Original'
+     X = 'X'
+     Y = 'Y'
+     LABEL = 'Label'
+     SPLIT = 'Split'
+
+
+ class BoxColumns:
+     FILENAME = 'Filename'
+     X1 = 'X1'
+     X2 = 'X2'
+     Y1 = 'Y1'
+     Y2 = 'Y2'
+     LABEL = 'Label'
+     SPLIT = 'Split'
+
+
+ class ProbsColumns:
+     FILENAME = 'Filename'
+     ORIGINAL = 'Original'
+     LABEL = 'Label'
+     SPLIT = 'Split'
+
+
+ class ModelArgs(str, Enum):
+     BASICCNN = 'basic'
+     RESNET18 = 'resnet18'
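Since ModelArgs mixes in str, its members compare equal to their plain string values, which is what lets the argparse-based scripts in this commit pass it directly as both type and choices:

    from utils.constants import ModelArgs

    assert ModelArgs.RESNET18 == 'resnet18'          # str mixin: equal to its value
    assert ModelArgs('basic') is ModelArgs.BASICCNN  # round-trip from a CLI string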
utils/crops_visualization.py ADDED
@@ -0,0 +1,26 @@
+ import os
+
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ from matplotlib import pyplot as plt
+
+ from utils.constants import CropsColumns
+ from utils.paths import CROPS_DATASET, CROPS_PATH, CROPS_VIS_PATH
+
+
+ if not os.path.exists(CROPS_VIS_PATH):
+     os.makedirs(CROPS_VIS_PATH)
+
+
+ dataset_df = pd.read_csv(CROPS_DATASET)
+ for tif_name in dataset_df[CropsColumns.FILENAME]:
+     tif_filename = os.path.join(CROPS_PATH, tif_name)
+     img = Image.open(tif_filename)
+     img = np.array(img).astype(np.float32)
+     # min-max normalization to [0, 1] (was divided by the unshifted max)
+     img = (img - img.min()) / (img.max() - img.min())
+     plt.tight_layout()
+     plt.imshow(img)
+     vis_name = "{}.jpg".format(os.path.splitext(tif_name)[0])
+     vis_filename = os.path.join(CROPS_VIS_PATH, vis_name)
+     plt.savefig(vis_filename)
+     plt.close()  # free the figure before the next iteration
utils/paths.py ADDED
@@ -0,0 +1,42 @@
+ import os
+ from glob import glob
+
+ PROJECT_PATH = os.path.abspath(os.path.join(__file__, *(os.path.pardir for _ in range(2))))
+
+ MINIO_KEYS = os.path.join(PROJECT_PATH, 'minio.json')
+ LOGS_PATH = os.path.join(PROJECT_PATH, 'logs')
+ DETECTION_LOGS = os.path.join(LOGS_PATH, 'detection_coords')
+ PRED_MAP_TABLE_LOGS = os.path.join(LOGS_PATH, 'pred_map_to_table')
+
+ DATA_PATH = os.path.join(PROJECT_PATH, 'data')
+ IMG_PATH = os.path.join(DATA_PATH, 'tif_data')
+ COORDS_PATH = os.path.join(DATA_PATH, 'label_coordinates')
+ CROPS_PATH = os.path.join(DATA_PATH, 'atom_crops_data')
+ PROBS_PATH = os.path.join(DATA_PATH, 'probs_data')
+ BOX_PATH = os.path.join(DATA_PATH, 'box_data')
+ PREDS_PATH = os.path.join(DATA_PATH, 'prediction_cache')
+ DETECTION_PATH = os.path.join(DATA_PATH, 'detection_data')
+
+ DATASET_PATH = os.path.join(PROJECT_PATH, 'dataset')
+ CROPS_DATASET = os.path.join(DATASET_PATH, 'atom_crops.csv')
+ PROBS_DATASET = os.path.join(DATASET_PATH, 'probs_dataset.csv')
+ BF_DATASET = os.path.join(DATASET_PATH, 'BF_dataset.csv')
+ HAADF_DATASET = os.path.join(DATASET_PATH, 'HAADF_dataset.csv')
+ PT_DATASET = os.path.join(DATASET_PATH, 'Pt_dataset.csv')
+ FE_DATASET = os.path.join(DATASET_PATH, 'Fe_dataset.csv')
+ BOX_DATASET = os.path.join(DATASET_PATH, 'box_dataset.csv')
+
+ MODELS_PATH = os.path.join(PROJECT_PATH, 'models')
+
+ DATA_VIS_PATH = os.path.join(PROJECT_PATH, 'data_vis')
+ CROPS_VIS_PATH = os.path.join(DATA_VIS_PATH, 'crops')
+ CM_VIS_PATH = os.path.join(DATA_VIS_PATH, 'cm_vis')
+ ORIG_VIS_PATH = os.path.join(DATA_VIS_PATH, 'orig')
+ PREPRO_VIS_PATH = os.path.join(DATA_VIS_PATH, 'preprocessed')
+ LABEL_VIS_PATH = os.path.join(DATA_VIS_PATH, 'label')
+ PRED_VIS_PATH = os.path.join(DATA_VIS_PATH, 'predictions')
+ PRED_GT_VIS_PATH = os.path.join(DATA_VIS_PATH, 'predictions_gt')
+ LANDS_VIS_PATH = os.path.join(DATA_VIS_PATH, 'landscapes')
+ ACTIVATIONS_VIS_PATH = os.path.join(DATA_VIS_PATH, 'activations')
+
+ # assumes the fast_filters extension has been built (python setup.py build), which creates build/lib*
+ LIB_PATH = glob(f"{os.path.join(PROJECT_PATH, 'build')}/lib*")[0]
visualizations/__init__.py ADDED
File without changes
visualizations/crop_images.py ADDED
@@ -0,0 +1,21 @@
+ import os
+
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ import matplotlib.pyplot as plt
+
+ from utils.paths import CROPS_VIS_PATH
+
+ df = pd.read_csv("dataset/atom_crops_replicate.csv")
+ for crop_name in df['Filename']:
+     crop_filename = os.path.join("data/atom_crops_data_sac_cnn", crop_name)
+     crop = Image.open(crop_filename)
+     crop_arr = np.array(crop).astype(np.float32)
+     plt.figure()
+     plt.axis('off')
+     plt.imshow(crop_arr)
+     vis_path = os.path.join(CROPS_VIS_PATH, '{}.png'.format(os.path.splitext(crop_name)[0]))
+     plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0)
+     plt.close()
@@ -0,0 +1,188 @@
 
+ import os
+ from typing import Iterator, Tuple, Dict
+
+ import argparse
+
+ from PIL import Image
+ import numpy as np
+ import torch
+ from matplotlib import pyplot as plt
+
+ from atoms_detection.dataset import CoordinatesDataset
+ from atoms_detection.image_preprocessing import dl_prepro_image
+ from atoms_detection.model import BasicCNN
+ from utils.constants import ModelArgs, Split
+ from utils.paths import ACTIVATIONS_VIS_PATH
+
+
+ class ConvLayerVisualizer:
+     CONV_0 = 'Conv0'
+     CONV_3 = 'Conv3'
+     CONV_6 = 'Conv6'
+
+     def __init__(self, model_name: ModelArgs, ckpt_filename: str):
+         self.model_name = model_name
+         self.ckpt_filename = ckpt_filename
+         self.device = self.get_torch_device()
+         self.batch_size = 64
+
+         self.stride = 1
+         self.padding = 10
+         self.window_size = (21, 21)
+
+     @staticmethod
+     def get_torch_device():
+         use_cuda = torch.cuda.is_available()
+         device = torch.device("cuda" if use_cuda else "cpu")
+         return device
+
+     def sliding_window(self, image: np.ndarray) -> Iterator[Tuple[int, int, np.ndarray]]:
+         # slide a window across the image
+         x_to_center = self.window_size[0] // 2 - 1 if self.window_size[0] % 2 == 0 else self.window_size[0] // 2
+         y_to_center = self.window_size[1] // 2 - 1 if self.window_size[1] % 2 == 0 else self.window_size[1] // 2
+
+         for y in range(0, image.shape[0] - self.window_size[1]+1, self.stride):
+             for x in range(0, image.shape[1] - self.window_size[0]+1, self.stride):
+                 # yield the current window
+                 center_x = x + x_to_center
+                 center_y = y + y_to_center
+                 yield center_x, center_y, image[y:y + self.window_size[1], x:x + self.window_size[0]]
+
+     def padding_image(self, img: np.ndarray) -> np.ndarray:
+         image_padded = np.zeros((img.shape[0] + self.padding*2, img.shape[1] + self.padding*2))
+         image_padded[self.padding:-self.padding, self.padding:-self.padding] = img
+         return image_padded
+
+     def images_to_torch_input(self, image: np.ndarray) -> torch.Tensor:
+         expanded_img = np.expand_dims(image, axis=(0, 1))
+         input_tensor = torch.from_numpy(expanded_img).float()
+         input_tensor = input_tensor.to(self.device)
+         return input_tensor
+
+     def load_model(self) -> BasicCNN:
+         checkpoint = torch.load(self.ckpt_filename, map_location=self.device)
+         model = BasicCNN(num_classes=2).to(self.device)
+         model.load_state_dict(checkpoint['state_dict'])
+         model.eval()
+         return model
+
+     @staticmethod
+     def center_to_slice(x_center: int, y_center: int, width: int, height: int) -> Tuple[slice, slice]:
+         x_to_center = width // 2 - 1 if width % 2 == 0 else width // 2
+         y_to_center = height // 2 - 1 if height % 2 == 0 else height // 2
+         x = x_center - x_to_center
+         y = y_center - y_to_center
+         return slice(x, x + width), slice(y, y + height)
+
+     def get_prediction_map(self, padded_image: np.ndarray) -> Dict[str, np.ndarray]:
+         _shape = padded_image.shape
+         # per layer: (counting map, accumulated activation map)
+         convs_activations_dict = {
+             self.CONV_0: (np.zeros(_shape), np.zeros(_shape)),
+             self.CONV_3: (np.zeros(_shape), np.zeros(_shape)),
+             self.CONV_6: (np.zeros(_shape), np.zeros(_shape))
+         }
+         model = self.load_model()
+         for x, y, image_crop in self.sliding_window(padded_image):
+             torch_input = self.images_to_torch_input(image_crop)
+             conv_outputs = self.get_conv_activations(torch_input, model)
+             for conv_layer_key, activations_blob in conv_outputs.items():
+                 activation_map = self.sum_channels(activations_blob)
+                 h, w = activation_map.shape
+                 x_slice, y_slice = self.center_to_slice(x, y, w, h)
+                 convs_activations_dict[conv_layer_key][0][y_slice, x_slice] += 1
+                 convs_activations_dict[conv_layer_key][1][y_slice, x_slice] += activation_map
+
+         activations_dict = {}
+         for conv_layer_key, (counting_map, output_map) in convs_activations_dict.items():
+             zero_rows = np.sum(counting_map, axis=1)
+             zero_cols = np.sum(counting_map, axis=0)
+
+             output_map = np.delete(output_map, np.where(zero_rows == 0), axis=0)
+             clean_output_map = np.delete(output_map, np.where(zero_cols == 0), axis=1)
+             counting_map = np.delete(counting_map, np.where(zero_rows == 0), axis=0)
+             clean_counting_map = np.delete(counting_map, np.where(zero_cols == 0), axis=1)
+
+             activations_dict[conv_layer_key] = clean_output_map / clean_counting_map
+
+         return activations_dict
+
+     def get_conv_activations(self, input_image: torch.Tensor, model: BasicCNN) -> Dict[str, np.ndarray]:
+         conv_activations = {}
+         activations = input_image
+         for i, layer in enumerate(model.features):
+             activations = layer(activations)
+             if i == 0:
+                 conv_activations[self.CONV_0] = activations.squeeze(0).detach().cpu().numpy()
+             elif i == 3:
+                 conv_activations[self.CONV_3] = activations.squeeze(0).detach().cpu().numpy()
+             elif i == 6:
+                 conv_activations[self.CONV_6] = activations.squeeze(0).detach().cpu().numpy()
+
+         return conv_activations
+
+     @staticmethod
+     def sum_channels(activations: np.ndarray):
+         aggregated_activations = np.sum(activations, axis=0)
+         return aggregated_activations
+
+     def image_to_pred_map(self, img: np.ndarray) -> Dict[str, np.ndarray]:
+         preprocessed_img = dl_prepro_image(img)
+         padded_image = self.padding_image(preprocessed_img)
+         activations_dict = self.get_prediction_map(padded_image)
+         return activations_dict
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "architecture",
+         type=ModelArgs,
+         choices=ModelArgs,
+         help="Architecture name"
+     )
+     parser.add_argument(
+         "ckpt_filename",
+         type=str,
+         help="Path to model checkpoint"
+     )
+     parser.add_argument(
+         "coords_csv",
+         type=str,
+         help="Coordinates CSV file to use as input"
+     )
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     print(args)
+
+     conv_visualizer = ConvLayerVisualizer(
+         model_name=args.architecture,
+         ckpt_filename=args.ckpt_filename
+     )
+
+     coordinates_dataset = CoordinatesDataset(args.coords_csv)
+     for image_path, coordinates_path in coordinates_dataset.iterate_data(Split.TEST):
+         img = Image.open(image_path)
+         np_img = np.array(img)
+         activations_dict = conv_visualizer.image_to_pred_map(np_img)
+
+         img_name = os.path.splitext(os.path.basename(image_path))[0]
+
+         output_folder = os.path.join(ACTIVATIONS_VIS_PATH, f"{img_name}")
+         if not os.path.exists(output_folder):
+             os.makedirs(output_folder)
+
+         for conv_layer_key, activation_map in activations_dict.items():
+             fig = plt.figure()
+             plt.title(f"{conv_layer_key} -- {img_name}")
+             plt.imshow(activation_map)
+
+             output_path = os.path.join(output_folder, f"{conv_layer_key}_{img_name}.png")
+             plt.savefig(output_path, bbox_inches='tight')
+             plt.close(fig)
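A hypothetical invocation of the script above (both paths are placeholders; the positional arguments follow get_args):

    python visualizations/dl_intermediate_layers_visualization.py basic path/to/checkpoint.ckpt path/to/coordinates.csv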