MolmoE-1B-0924 / utils.py
Muennighoff's picture
Cp over files
18652d8
raw
history blame
6.28 kB
import dataclasses
import hashlib
import sys
import typing
import warnings
import socket
from typing import Optional, Any, Dict
import os
import logging
import absl.flags
from flax.traverse_util import flatten_dict
from ml_collections import ConfigDict, config_flags
from ml_collections.config_dict import placeholder
from mlxu import function_args_to_config
# Module-level mapping from field name to value; starts out empty.
# NOTE(review): nothing in the visible part of this file reads or writes it --
# presumably it holds extra fields attached to log records elsewhere; confirm.
_log_extra_fields: Dict[str, Any] = {}
def is_float_printable(x):
    """Report whether `x` can be rendered with the ``0.2f`` format spec.

    Returns True when formatting succeeds, and False when the attempt raises
    ValueError or TypeError (e.g. strings, None, containers).
    """
    try:
        format(x, "0.2f")
    except (ValueError, TypeError):
        return False
    else:
        return True
def compute_hash(string: str) -> str:
    """Return the hexadecimal SHA-256 digest of `string` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(string.encode("utf-8"))
    return hasher.hexdigest()
def pop_metadata(data):
    """Split the entries whose key starts with "metadata" out of `data`.

    The matching keys are removed from `data` in place and collected into a
    new dict. Returns the tuple ``(data, meta)``, where `data` is the same
    (now mutated) mapping that was passed in.
    """
    meta = {}
    # Iterate over a snapshot of the keys because `data` is mutated in the loop.
    for key in tuple(data):
        if key.startswith("metadata"):
            meta[key] = data.pop(key)
    return data, meta
def setup_logging():
    """Configure root logging to write INFO-level records to stdout.

    Installs a single stdout stream handler with a compact
    "[<level initial> <HH:MM:SS> <file>:<line>] <message>" format via
    `logging.basicConfig`, routes `warnings` through the logging system, and
    raises the "urllib3" logger's threshold to ERROR.
    """
    stdout_handler: logging.Handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(
        logging.Formatter(
            fmt="[%(levelname)-.1s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    logging.basicConfig(level=logging.INFO, handlers=[stdout_handler])
    logging.captureWarnings(True)
    logging.getLogger("urllib3").setLevel(logging.ERROR)
def get_maybe_optional_type(field_type):
    """Strip the ``Optional`` wrapper from a type annotation, if there is one.

    Args:
        field_type: A type annotation, e.g. ``int``, ``Optional[int]`` or
            ``int | None``.

    Returns:
        The single non-None member when `field_type` is a union that contains
        ``None``; otherwise `field_type` unchanged.

    Raises:
        AssertionError: If `field_type` is a union containing ``None`` and
            more than one other member (e.g. ``Optional[Union[int, str]]``).
    """
    import types  # local import: only needed here to recognise PEP 604 unions

    # Only unions can be "optional". Without this check, any generic that has
    # NoneType among its arguments (``Dict[str, None]``, ``Callable[..., None]``)
    # would be wrongly unwrapped to one of its other arguments.
    origin = typing.get_origin(field_type)
    pep604_union = getattr(types, "UnionType", typing.Union)
    if origin is not typing.Union and origin is not pep604_union:
        return field_type

    members = typing.get_args(field_type)
    if type(None) not in members:
        return field_type

    non_none = [member for member in members if member is not type(None)]
    if len(non_none) != 1:
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O`; the exception type matches the original assertion.
        raise AssertionError(
            f"Expected exactly one non-None type in {field_type}, got {non_none}"
        )
    return non_none[0]
def config_from_dataclass(dataclass, defaults_to_none=False) -> ConfigDict:
    """Build a `ConfigDict` mirroring a (possibly nested) dataclass.

    Args:
        dataclass: A dataclass instance or a dataclass type. For an instance,
            defaults are the instance's current attribute values. For a type,
            defaults are the field defaults, or None when the field has no
            plain default.
        defaults_to_none: If True, every default is None.

    Returns:
        A `ConfigDict` with one entry per `init` field. Nested dataclass
        fields become nested `ConfigDict`s; a nested field whose default is
        None is omitted unless `defaults_to_none` is set. Other fields with a
        None default become typed placeholders.
    """
    out = {}
    for field in dataclasses.fields(dataclass):
        if not field.init:
            continue

        if defaults_to_none:
            default = None
        elif hasattr(dataclass, field.name):
            default = getattr(dataclass, field.name)
        elif field.default is dataclasses.MISSING:
            default = None
        else:
            default = field.default

        field_type = get_maybe_optional_type(field.type)
        if hasattr(field_type, "__dataclass_fields__"):
            if default is None and not defaults_to_none:
                # No nested value to mirror, so leave the entry out entirely.
                continue
            # Recurse on the unwrapped type, not `field.type`: for an
            # `Optional[SubConfig]` field the raw annotation is a Union, which
            # `dataclasses.fields` rejects with a TypeError. The explicit None
            # test also keeps a dataclass instance that happens to be falsy.
            nested_source = default if default is not None else field_type
            out[field.name] = config_from_dataclass(
                nested_source, defaults_to_none=defaults_to_none)
        elif default is None:
            assert not field_type == typing.Any
            # Reduce a parameterized generic to its `__origin__` before
            # creating the placeholder.
            origin = getattr(field_type, "__origin__", None)
            if origin is not None:
                field_type = origin
            out[field.name] = placeholder(field_type)
        else:
            out[field.name] = default
    return ConfigDict(out)
def dataclass_with_none(cls):
    """Instantiate dataclass `cls` with every `init` field set to None.

    Fields whose declared type is itself a dataclass are filled recursively
    with an all-None instance of that type. Fields with ``init=False`` are
    not passed to the constructor.
    """
    kwargs = {
        field.name: (
            dataclass_with_none(field.type)
            if dataclasses.is_dataclass(field.type)
            else None
        )
        for field in dataclasses.fields(cls)
        if field.init
    }
    return cls(**kwargs)
def dataclass_from_config(cls, config: Dict):
    """Build an instance of `cls` with attributes from `config`.

    Args:
        cls: The dataclass type to instantiate.
        config: Mapping from field names to values. Nested dataclass fields
            are given as nested mappings.

    Returns:
        An instance of `cls`. Nested dataclass fields that are missing or
        None in `config` are set to None. Other fields missing from `config`
        are not passed, so the dataclass's own defaults apply.

    Raises:
        ValueError: If `config` contains a key that is not a field of `cls`.
    """
    fields = dataclasses.fields(cls)
    known_names = {f.name for f in fields}
    for key in config.keys():
        if key not in known_names:
            raise ValueError(f"Config has unknown arg {key} for {cls}")

    args = {}
    for field in fields:
        if not field.init:
            continue
        field_type = get_maybe_optional_type(field.type)
        if hasattr(field_type, "__dataclass_fields__"):
            if config.get(field.name) is None:
                args[field.name] = None
            elif hasattr(field_type, "from_dict"):
                # `from_dict` receives a plain dict, so a ConfigDict is
                # converted first.
                src = config[field.name]
                if isinstance(src, ConfigDict):
                    src = src.to_dict()
                args[field.name] = field_type.from_dict(src)
            else:
                args[field.name] = dataclass_from_config(field_type, config[field.name])
        elif field.name in config:
            # Non-dataclass fields are stored as plain Python values, so a
            # ConfigDict value is converted to a dict.
            if isinstance(config[field.name], ConfigDict):
                args[field.name] = config[field.name].to_dict()
            else:
                args[field.name] = config[field.name]
    return cls(**args)
def update_dataclass(obj, updates):
    """Copy the non-None entries of `updates` onto the dataclass `obj` in place.

    Only `init` fields are considered. When the current attribute value is
    itself a dataclass, the update is applied to it recursively. When the
    current value is not a dataclass but the update is a mapping, the mapping
    must contain only None leaves and nothing is assigned.
    """
    for field in dataclasses.fields(obj):
        if not field.init:
            continue
        new_value = updates.get(field.name)
        if new_value is None:
            continue

        existing = getattr(obj, field.name)
        if dataclasses.is_dataclass(existing):
            update_dataclass(existing, new_value)
        elif isinstance(new_value, (ConfigDict, dict)):
            # A mapping aimed at a non-dataclass attribute carries no usable
            # values; it is only checked to be empty of real updates.
            assert all(leaf is None for leaf in flatten_dict(new_value).values())
        else:
            setattr(obj, field.name, new_value)
def log_metrics_to_console(prefix: str, metrics: Dict[str, float]):
    """Log `metrics` at INFO level as one multi-line message headed by `prefix`.

    Each metric is written on its own line as " name=value". Metrics whose
    name starts with "optim/" are left out. String values are printed as-is;
    numeric values get fewer decimals the larger they are.
    """
    # Adapted from the OLMo codebase.

    def format_value(value: float) -> str:
        if isinstance(value, str):
            return value
        if value < 0.0001:
            return str(value)  # scientific notation
        if value > 1000:
            return format(int(value), ",d")
        # Larger magnitudes get fewer decimal places.
        for threshold, spec in ((100, ".1f"), (10, ".2f"), (1, ".3f")):
            if value > threshold:
                return format(value, spec)
        return format(value, ".4f")

    lines = []
    for name, value in metrics.items():
        if name.startswith("optim/"):  # there's too many optimizer metrics
            continue
        lines.append(f" {name}={format_value(value)}")
    logging.info(f"{prefix}\n" + "\n".join(lines))