Romain Graux commited on
Commit
60fece7
1 Parent(s): 0bb9cde

New extractor for physical metadata

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. app/dl_inference.py +7 -5
  3. app/tiff_utils.py +75 -16
app.py CHANGED
@@ -85,10 +85,10 @@ def batch_fn(
85
  try:
86
  physical_metadata = extract_physical_metadata(file.name)
87
  if physical_metadata.unit != "nm":
88
- raise ValueError(f"Unit of {file.name} is not nm, cannot process it")
89
  except Exception as e:
90
  error_messages.append(f"Error processing {file.name}: {str(e)}")
91
- continue # Skip to the next file
92
 
93
  original_file_name = os.path.basename(file.name)
94
  sanitized_file_name = original_file_name.replace(" ", "_")
 
85
  try:
86
  physical_metadata = extract_physical_metadata(file.name)
87
  if physical_metadata.unit != "nm":
88
+ raise gr.Error(f"Unit of {file.name} is not nm, cannot process it")
89
  except Exception as e:
90
  error_messages.append(f"Error processing {file.name}: {str(e)}")
91
+ raise gr.Error(f"Error processing {file.name}: {str(e)}")
92
 
93
  original_file_name = os.path.basename(file.name)
94
  sanitized_file_name = original_file_name.replace(" ", "_")
app/dl_inference.py CHANGED
@@ -11,6 +11,7 @@ from functools import lru_cache
11
  import sys
12
 
13
  from .tiff_utils import tiff_to_png
 
14
  if ".." not in sys.path:
15
  sys.path.append("..")
16
 
@@ -78,7 +79,7 @@ def multimers_classification(
78
  for e in range(epochs):
79
  trainer.step(train_loader, scale_factor=scale_factor)
80
  trainer.print_statistics()
81
-
82
  # Extract latent space (only mean) from VAE
83
  z_mean, _ = rvae.encode(torch_crops)
84
 
@@ -150,7 +151,9 @@ def inference_fn(
150
  # if img.max() <= 1:
151
  # raise ValueError("Gradio seems to preprocess badly the tiff images. Did you adapt the preprocessing function as mentionned in the app.py file comments?")
152
  prepro_img, _, pred_map = detection.image_to_pred_map(img, return_intermediate=True)
153
- center_coords_list, likelihood_list = (np.array(x) for x in detection.pred_map_to_atoms(pred_map))
 
 
154
  results = (
155
  multimers_classification(
156
  img=prepro_img,
@@ -161,7 +164,6 @@ def inference_fn(
161
  if n_species > 1
162
  else {
163
  0: {
164
-
165
  "coords": center_coords_list,
166
  "likelihood": likelihood_list,
167
  "confidence": np.ones(len(center_coords_list)),
@@ -173,8 +175,8 @@ def inference_fn(
173
  Evaluation.center_coords_to_bbox(center_coords)
174
  for center_coords in v["coords"]
175
  ]
176
- return tiff_to_png(image), {
177
- "image": tiff_to_png(image),
178
  "pred_map": pred_map,
179
  "species": results,
180
  }
 
11
  import sys
12
 
13
  from .tiff_utils import tiff_to_png
14
+
15
  if ".." not in sys.path:
16
  sys.path.append("..")
17
 
 
79
  for e in range(epochs):
80
  trainer.step(train_loader, scale_factor=scale_factor)
81
  trainer.print_statistics()
82
+
83
  # Extract latent space (only mean) from VAE
84
  z_mean, _ = rvae.encode(torch_crops)
85
 
 
151
  # if img.max() <= 1:
152
  # raise ValueError("Gradio seems to preprocess badly the tiff images. Did you adapt the preprocessing function as mentionned in the app.py file comments?")
153
  prepro_img, _, pred_map = detection.image_to_pred_map(img, return_intermediate=True)
154
+ center_coords_list, likelihood_list = (
155
+ np.array(x) for x in detection.pred_map_to_atoms(pred_map)
156
+ )
157
  results = (
158
  multimers_classification(
159
  img=prepro_img,
 
164
  if n_species > 1
165
  else {
166
  0: {
 
167
  "coords": center_coords_list,
168
  "likelihood": likelihood_list,
169
  "confidence": np.ones(len(center_coords_list)),
 
175
  Evaluation.center_coords_to_bbox(center_coords)
176
  for center_coords in v["coords"]
177
  ]
178
+ return tiff_to_png(Image.fromarray(prepro_img)), {
179
+ "image": tiff_to_png(Image.fromarray(prepro_img)),
180
  "pred_map": pred_map,
181
  "species": results,
182
  }
app/tiff_utils.py CHANGED
@@ -6,34 +6,93 @@
6
  @last modified : 2023 September 19, 11:18:36
7
  """
8
 
 
9
  import re
10
  import imageio
11
- import numpy as np
12
  from collections import namedtuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- physical_metadata = namedtuple("physical_metadata", ["width", "height", "pixel_width", "pixel_height", "unit"])
 
 
 
 
 
 
 
 
 
15
 
16
- def extract_physical_metadata(image_path : str, strict:bool=True) -> physical_metadata:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  """
18
- Extracts the physical metadata of an image (only tiff for now)
 
19
  """
20
  with open(image_path, "rb") as f:
21
  data = f.read()
22
- reader = imageio.get_reader(data, format=".tif")
23
  metadata = reader.get_meta_data()
24
 
25
- if strict and not metadata['is_imagej']:
26
- for key, value in metadata.items():
27
- if key.startswith("is_") and value == True: # Force bool to be True, because it can also pass the condition while being an random object
28
- raise ValueError(f"The image is not TIFF image, but it seems to be a {key[3:]} image")
29
- raise ValueError("Impossible to extract metadata from the image (ImageJ)")
30
  h, w = reader.get_next_data().shape
31
- ipw, iph, _ = metadata['resolution']
32
- result = re.search(r"unit=(.+)", metadata['description'])
33
- if strict and not result:
34
- raise ValueError(f"No scale unit found in the image description : {metadata['description']}")
35
- unit = result and result.group(1)
36
- return physical_metadata(w, h, 1. / ipw, 1. / iph, unit)
 
 
 
37
 
38
  def tiff_to_png(image, inplace=True):
39
  img = image if inplace else image.copy()
 
6
  @last modified : 2023 September 19, 11:18:36
7
  """
8
 
9
+ from typing import Callable, Optional
10
  import re
11
  import imageio
 
12
  from collections import namedtuple
13
+ import numpy as np
14
+
15
+ PhysicalMetadata = namedtuple(
16
+ "PhysicalMetadata", ["width", "height", "pixel_width", "pixel_height", "unit"]
17
+ )
18
+
19
+ MetadataExtractor = Callable[[dict, int, int], Optional[PhysicalMetadata]]
20
+
21
+
22
+ def extract_imagej_metadata(
23
+ metadata: dict, width: int, height: int
24
+ ) -> Optional[PhysicalMetadata]:
25
+ try:
26
+ ipw, iph, _ = metadata["resolution"]
27
+ result = re.search(r"unit=(.+)", metadata["description"])
28
+ if not result:
29
+ return None
30
+ unit = result.group(1)
31
+ return PhysicalMetadata(width, height, 1.0 / ipw, 1.0 / iph, unit.lower())
32
+ except (KeyError, AttributeError):
33
+ return None
34
+
35
 
36
+ def extract_resolution_metadata(
37
+ metadata: dict, width: int, height: int
38
+ ) -> Optional[PhysicalMetadata]:
39
+ try:
40
+ ipw, iph, _ = metadata["resolution"]
41
+ # It looks like the resolution unit is not really reliable, so let's just assume nm
42
+ unit = "nm"
43
+ return PhysicalMetadata(width, height, 1.0 / ipw, 1.0 / iph, unit)
44
+ except (KeyError, AttributeError):
45
+ return None
46
 
47
+
48
+ METADATA_EXTRACTORS: list[MetadataExtractor] = [
49
+ extract_imagej_metadata,
50
+ extract_resolution_metadata,
51
+ ]
52
+
53
+
54
+ def normalize_metadata(metadata: PhysicalMetadata) -> PhysicalMetadata:
55
+ conversion_factor = {
56
+ "inch": 2.54e7,
57
+ "m": 1e9,
58
+ "dm": 1e8,
59
+ "cm": 1e7,
60
+ "mm": 1e6,
61
+ "µm": 1e3,
62
+ "nm": 1,
63
+ }
64
+ if metadata.unit not in conversion_factor:
65
+ raise ValueError(f"Unknown unit: {metadata.unit}")
66
+ factor = conversion_factor[metadata.unit]
67
+ return PhysicalMetadata(
68
+ metadata.width,
69
+ metadata.height,
70
+ metadata.pixel_width * factor,
71
+ metadata.pixel_height * factor,
72
+ "nm",
73
+ )
74
+
75
+
76
+ def extract_physical_metadata(image_path: str, strict: bool = True) -> PhysicalMetadata:
77
  """
78
+ Extracts the physical metadata of an image by trying all available extractors.
79
+ Raises ValueError if no extractor succeeds.
80
  """
81
  with open(image_path, "rb") as f:
82
  data = f.read()
83
+ reader = imageio.get_reader(data)
84
  metadata = reader.get_meta_data()
85
 
 
 
 
 
 
86
  h, w = reader.get_next_data().shape
87
+ for extractor in METADATA_EXTRACTORS:
88
+ result = extractor(metadata, w, h)
89
+ if result is not None:
90
+ return normalize_metadata(result)
91
+
92
+ raise ValueError(
93
+ "Failed to extract metadata from the image using any available method."
94
+ )
95
+
96
 
97
  def tiff_to_png(image, inplace=True):
98
  img = image if inplace else image.copy()