import dataclasses
import hashlib
import logging
import math
import os
import socket
import sys
import typing
import warnings
from typing import Optional, Any, Dict

import absl.flags
from flax.traverse_util import flatten_dict
from ml_collections import ConfigDict, config_flags
from ml_collections.config_dict import placeholder
from mlxu import function_args_to_config


# Not read or written in this module; presumably populated elsewhere — verify.
_log_extra_fields: Dict[str, Any] = {}


def is_float_printable(x):
    """Return True if `x` can be formatted with the ``0.2f`` format spec."""
    try:
        f"{x:0.2f}"
        return True
    except (ValueError, TypeError):
        return False


def compute_hash(string: str) -> str:
    """Computes the hash of a string.

    Returns the hex-encoded SHA-256 digest of the UTF-8 encoding of `string`.
    """
    return hashlib.sha256(string.encode("utf-8")).hexdigest()


def pop_metadata(data):
    """Split out every key starting with "metadata" from `data`.

    The matching keys are popped from `data` in place.

    Returns:
        A `(data, meta)` tuple: the mutated input and a dict of the removed
        entries.
    """
    meta = {k: data.pop(k) for k in list(data) if k.startswith("metadata")}
    return data, meta


def setup_logging():
    """Configure root logging to stdout at INFO with a compact line format.

    Also routes `warnings` through logging and limits urllib3 to ERROR.
    `logging.basicConfig` has no effect if the root logger already has
    handlers.
    """
    handler: logging.Handler
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter(
        "[%(levelname)-.1s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
        datefmt="%H:%M:%S",
    )
    handler.setFormatter(formatter)
    logging.basicConfig(handlers=[handler], level=logging.INFO)
    logging.captureWarnings(True)
    logging.getLogger("urllib3").setLevel(logging.ERROR)


def get_maybe_optional_type(field_type):
    """Unwrap ``Optional[X]`` to ``X``; return any other type unchanged.

    Asserts that a union containing None has exactly one other member.
    """
    if type(None) in typing.get_args(field_type):
        # Handle optional type
        args = [x for x in typing.get_args(field_type) if x != type(None)]
        assert len(args) == 1
        field_type = args[0]
    return field_type


def config_from_dataclass(dataclass, defaults_to_none=False) -> ConfigDict:
    """Build a `ConfigDict` matching the possibly nested dataclass.

    Args:
        dataclass: A dataclass instance or a dataclass type. For an instance,
            defaults are set to the instance's values. For a class, defaults
            are set to the field defaults, or None if the field is required.
        defaults_to_none: Make all defaults None.

    Returns:
        A `ConfigDict` with one entry per init field:
        - Fields whose default is None become typed `placeholder`s.
        - Nested dataclass fields become nested `ConfigDict`s. They are
          omitted entirely when their default is None and `defaults_to_none`
          is False.
    """
    out = {}
    for field in dataclasses.fields(dataclass):
        if not field.init:
            continue

        if defaults_to_none:
            default = None
        elif hasattr(dataclass, field.name):
            default = getattr(dataclass, field.name)
        elif field.default is dataclasses.MISSING:
            default = None
        else:
            default = field.default

        field_type = get_maybe_optional_type(field.type)
        if hasattr(field_type, "__dataclass_fields__"):
            if defaults_to_none or default is not None:
                # Recurse on the unwrapped `field_type`, not `field.type`.
                # For an `Optional[SomeDataclass]` field, `dataclasses.fields`
                # rejects the `Optional[...]` wrapper. An explicit None check
                # also avoids discarding a falsy dataclass instance.
                source = default if default is not None else field_type
                out[field.name] = config_from_dataclass(
                    source, defaults_to_none=defaults_to_none)
            # Otherwise the nested dataclass has no default and is left out.
        elif default is None:
            assert not field_type == typing.Any
            # `placeholder` is given the bare origin class (e.g. `list` for
            # `List[int]`), presumably because it needs a runtime type.
            origin = getattr(field_type, "__origin__", None)
            if origin is not None:
                field_type = origin
            out[field.name] = placeholder(field_type)
        else:
            out[field.name] = default
    return ConfigDict(out)


def dataclass_with_none(cls):
    """Build an instance of possibly nested dataclass `cls` with all attributes None.

    Init fields whose declared type is itself a dataclass are filled
    recursively in the same way. Non-init fields are not passed to `cls`.
    """
    args = {}
    for field in dataclasses.fields(cls):
        if not field.init:
            continue
        if dataclasses.is_dataclass(field.type):
            args[field.name] = dataclass_with_none(field.type)
        else:
            args[field.name] = None
    return cls(**args)


def dataclass_from_config(cls, config: Dict):
    """Build an instance of `cls` with attributes from `config`.

    Behavior per field:
    - Nested dataclass fields are built recursively, or via the nested type's
      `from_dict` when it defines one.
    - A missing or None nested entry yields None.
    - A `ConfigDict` value for a plain field is converted to a dict.
    - Plain fields absent from `config` fall back to the dataclass default.

    Raises:
        ValueError: If `config` contains a key that is not a field of `cls`.
    """
    fields = dataclasses.fields(cls)
    known_names = {f.name for f in fields}
    for k in config.keys():
        if k not in known_names:
            raise ValueError(f"Config has unknown arg {k} for {cls}")

    args = {}
    for field in fields:
        if not field.init:
            continue
        field_type = get_maybe_optional_type(field.type)
        if hasattr(field_type, "__dataclass_fields__"):
            if config.get(field.name) is None:
                args[field.name] = None
            elif hasattr(field_type, "from_dict"):
                src = config[field.name]
                if isinstance(src, ConfigDict):
                    src = src.to_dict()
                args[field.name] = field_type.from_dict(src)
            else:
                args[field.name] = dataclass_from_config(field_type, config[field.name])
        elif field.name in config:
            if isinstance(config[field.name], ConfigDict):
                args[field.name] = config[field.name].to_dict()
            else:
                args[field.name] = config[field.name]
    return cls(**args)


def update_dataclass(obj, updates):
    """Sets attributes in `obj` to match non-None fields in `updates`.

    Behavior per init field:
    - Nested dataclass attributes are updated recursively, in place.
    - A mapping update for a non-dataclass attribute must contain only None
      leaves; this is asserted, and the attribute is then left unchanged.
    """
    for field in dataclasses.fields(obj):
        if not field.init:
            continue
        update = updates.get(field.name)
        if update is None:
            continue
        current_value = getattr(obj, field.name)
        if dataclasses.is_dataclass(current_value):
            update_dataclass(current_value, update)
        elif isinstance(update, (ConfigDict, dict)):
            # flax's `flatten_dict` only accepts (frozen) dicts, so convert a
            # ConfigDict (and the ConfigDicts nested in it) to plain dicts.
            if isinstance(update, ConfigDict):
                update = update.to_dict()
            assert all(x is None for x in flatten_dict(update).values())
        else:
            setattr(obj, field.name, update)


def log_metrics_to_console(prefix: str, metrics: Dict[str, float]):
    """Log `metrics` at INFO level under `prefix`, one ``name=value`` per line.

    Metrics whose name starts with "optim/" are skipped.
    """
    # Stolen from the OLMo codebase
    def format_value(value: float) -> str:
        if isinstance(value, str):
            return value
        if not math.isfinite(value):
            # inf / nan: `int(value)` below would raise OverflowError on inf.
            return str(value)
        # Bucket on magnitude so negative values get the same precision rules.
        magnitude = abs(value)
        if magnitude < 0.0001:
            return str(value)  # str() uses scientific notation for tiny floats
        elif magnitude > 1000:
            return f"{int(value):,d}"
        elif magnitude > 100:
            return f"{value:.1f}"
        elif magnitude > 10:
            return f"{value:.2f}"
        elif magnitude > 1:
            return f"{value:.3f}"
        else:
            return f"{value:.4f}"

    logging.info(
        f"{prefix}\n"
        + "\n".join(
            [
                f"    {name}={format_value(value)}"
                for name, value in metrics.items()
                if not name.startswith("optim/")  # there's too many optimizer metrics
            ]
        )
    )