Add log info
Browse files- src/run_clm_flax.py +3 -1
src/run_clm_flax.py
CHANGED
@@ -64,6 +64,8 @@ from data_utils import (
|
|
64 |
normalizer
|
65 |
)
|
66 |
|
|
|
|
|
67 |
logger = logging.getLogger(__name__)
|
68 |
|
69 |
# Cache the result
|
@@ -366,7 +368,7 @@ def main():
|
|
366 |
# dataset = dataset.map(normalizer)
|
367 |
# logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
|
368 |
dataset = raw_dataset
|
369 |
-
|
370 |
# Load pretrained model and tokenizer
|
371 |
|
372 |
# Distributed training:
|
|
|
64 |
normalizer
|
65 |
)
|
66 |
|
67 |
+
print(jax.devices())
|
68 |
+
|
69 |
logger = logging.getLogger(__name__)
|
70 |
|
71 |
# Cache the result
|
|
|
368 |
# dataset = dataset.map(normalizer)
|
369 |
# logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
|
370 |
dataset = raw_dataset
|
371 |
+
|
372 |
# Load pretrained model and tokenizer
|
373 |
|
374 |
# Distributed training:
|