import dataclasses
import hashlib
import sys
import typing
import warnings
import socket
from typing import Optional, Any, Dict
import os
import logging
import absl.flags
from flax.traverse_util import flatten_dict
from ml_collections import ConfigDict, config_flags
from ml_collections.config_dict import placeholder
from mlxu import function_args_to_config

_log_extra_fields: Dict[str, Any] = {}


def is_float_printable(x):
    try:
        f"{x:0.2f}"
        return True
    except (ValueError, TypeError):
        return False
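
# Example (illustrative, not part of the original module): values that do and do not
# format as floats.
#
#   is_float_printable(3.14159)   # True
#   is_float_printable("lr")      # False: str rejects the ':0.2f' format spec (ValueError)
#   is_float_printable(None)      # False: TypeError is caught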


def compute_hash(string: str) -> str:
    """Computes the hash of a string."""
    return hashlib.sha256(string.encode("utf-8")).hexdigest()


def pop_metadata(data):
    meta = {k: data.pop(k) for k in list(data) if k.startswith("metadata")}
    return data, meta
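
# Example (illustrative): metadata-prefixed keys are removed from the dict in place
# and returned separately.
#
#   batch = {"input_ids": [1, 2], "metadata/source": "wiki"}
#   batch, meta = pop_metadata(batch)
#   # batch == {"input_ids": [1, 2]}, meta == {"metadata/source": "wiki"}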


def setup_logging():
    handler: logging.Handler
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter(
        "[%(levelname)-.1s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
        datefmt="%H:%M:%S"
    )
    handler.setFormatter(formatter)
    logging.basicConfig(handlers=[handler], level=logging.INFO)
    logging.captureWarnings(True)
    logging.getLogger("urllib3").setLevel(logging.ERROR)
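
# Typical usage (sketch): call once at program start, before any logging calls.
# The filename and timestamp below are illustrative.
#
#   setup_logging()
#   logging.info("starting run")   # -> "[I 12:34:56 main.py:10] starting run"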


def get_maybe_optional_type(field_type):
    """Returns the underlying type of an `Optional` annotation, or the type unchanged."""
    if type(None) in typing.get_args(field_type):
        # Unwrap Optional[X] / Union[X, None] to X
        args = [x for x in typing.get_args(field_type) if x is not type(None)]
        assert len(args) == 1
        field_type = args[0]
    return field_type
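
# Example (illustrative): Optional annotations are unwrapped, everything else is
# returned unchanged.
#
#   get_maybe_optional_type(Optional[int])             # -> int
#   get_maybe_optional_type(typing.Union[str, None])   # -> str
#   get_maybe_optional_type(float)                     # -> float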


def config_from_dataclass(dataclass, defaults_to_none=False) -> ConfigDict:
    """Build a `ConfigDict` matching the possibly nested dataclass `dataclass`.

    dataclass: A dataclass instance or a dataclass type. If an instance, defaults
        are set to the values on the instance; if a class, defaults are set to the
        field defaults, or None if the field is required.
    defaults_to_none: Make all defaults None.
    """
    out = {}
    fields = dataclasses.fields(dataclass)
    for field in fields:
        if not field.init:
            continue
        if defaults_to_none:
            default = None
        elif hasattr(dataclass, field.name):
            default = getattr(dataclass, field.name)
        elif field.default is dataclasses.MISSING:
            default = None
        else:
            default = field.default

        field_type = get_maybe_optional_type(field.type)
        if hasattr(field_type, "__dataclass_fields__"):
            if not defaults_to_none and default is None:
                pass
            else:
                # Recurse with the unwrapped field type so Optional[SomeDataclass]
                # annotations are handled correctly
                out[field.name] = config_from_dataclass(
                    default or field_type, defaults_to_none=defaults_to_none)
        else:
            if default is None:
                assert field_type is not typing.Any
                origin = getattr(field_type, "__origin__", None)
                if origin is not None:
                    field_type = origin
                out[field.name] = placeholder(field_type)
            else:
                out[field.name] = default
    return ConfigDict(out)
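
# Example (a sketch with a hypothetical dataclass, not one defined in this module):
#
#   @dataclasses.dataclass
#   class OptimizerConfig:
#       learning_rate: float = 3e-4
#       weight_decay: Optional[float] = None
#
#   cfg = config_from_dataclass(OptimizerConfig)
#   # cfg.learning_rate == 3e-4; cfg.weight_decay is a float placeholder (None until set)
#
#   cfg = config_from_dataclass(OptimizerConfig, defaults_to_none=True)
#   # both fields become placeholders, to be filled in later (e.g. from flags)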


def dataclass_with_none(cls):
    """Build an instance of possibly nested dataclass `cls` with all attributes None"""
    fields = dataclasses.fields(cls)
    args = {}
    for field in fields:
        if not field.init:
            pass
        elif dataclasses.is_dataclass(field.type):
            args[field.name] = dataclass_with_none(field.type)
        else:
            args[field.name] = None
    return cls(**args)
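
# Example (sketch, reusing the hypothetical OptimizerConfig from above):
#
#   opt = dataclass_with_none(OptimizerConfig)
#   # -> OptimizerConfig(learning_rate=None, weight_decay=None)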


def dataclass_from_config(cls, config: Dict):
    """Build an instance of `cls` with attributes taken from `config`."""
    fields = dataclasses.fields(cls)
    field_names = set(x.name for x in fields)
    for k in config.keys():
        if k not in field_names:
            raise ValueError(f"Config has unknown arg {k} for {cls}")

    args = {}
    for field in fields:
        if not field.init:
            continue
        field_type = get_maybe_optional_type(field.type)
        if hasattr(field_type, "__dataclass_fields__"):
            if config.get(field.name) is None:
                args[field.name] = None
            elif hasattr(field_type, "from_dict"):
                src = config[field.name]
                if isinstance(src, ConfigDict):
                    src = src.to_dict()
                args[field.name] = field_type.from_dict(src)
            else:
                args[field.name] = dataclass_from_config(field_type, config[field.name])
        elif field.name in config:
            if isinstance(config[field.name], ConfigDict):
                args[field.name] = config[field.name].to_dict()
            else:
                args[field.name] = config[field.name]
    return cls(**args)
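
# Example (sketch): round-tripping the hypothetical OptimizerConfig through a ConfigDict.
#
#   cfg = config_from_dataclass(OptimizerConfig)
#   cfg.weight_decay = 0.01
#   opt = dataclass_from_config(OptimizerConfig, cfg)
#   # -> OptimizerConfig(learning_rate=0.0003, weight_decay=0.01)
#
#   dataclass_from_config(OptimizerConfig, {"momentum": 0.9})
#   # raises ValueError (unknown arg 'momentum')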


def update_dataclass(obj, updates):
    """Sets attributes in `obj` to match non-None fields in `updates`"""
    fields = dataclasses.fields(obj)
    for field in fields:
        if not field.init:
            continue
        update = updates.get(field.name)
        if update is None:
            continue
        current_value = getattr(obj, field.name)
        if dataclasses.is_dataclass(current_value):
            update_dataclass(current_value, update)
        else:
            if isinstance(update, (ConfigDict, dict)):
                # A nested update for a non-dataclass field is only allowed if it is
                # all-None, i.e. there is nothing to actually update
                assert all(x is None for x in flatten_dict(update).values())
            else:
                setattr(obj, field.name, update)
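
# Example (sketch): only non-None entries in `updates` overwrite the object in place.
#
#   opt = OptimizerConfig(learning_rate=3e-4, weight_decay=0.1)
#   update_dataclass(opt, {"learning_rate": 1e-4, "weight_decay": None})
#   # opt.learning_rate == 1e-4, opt.weight_decay == 0.1 (unchanged)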


def log_metrics_to_console(prefix: str, metrics: Dict[str, float]):
    # Stolen from the OLMo codebase
    def format_value(value: float) -> str:
        if isinstance(value, str):
            return value
        if value < 0.0001:
            return str(value)  # scientific notation
        elif value > 1000:
            return f"{int(value):,d}"
        elif value > 100:
            return f"{value:.1f}"
        elif value > 10:
            return f"{value:.2f}"
        elif value > 1:
            return f"{value:.3f}"
        else:
            return f"{value:.4f}"

    logging.info(
        f"{prefix}\n"
        + "\n".join(
            [
                f" {name}={format_value(value)}"
                for name, value in metrics.items()
                if not name.startswith("optim/")  # there's too many optimizer metrics
            ]
        )
    )
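
# Example (sketch): approximate console output after setup_logging(); the timestamp
# and filename in the prefix are illustrative.
#
#   log_metrics_to_console("step 100", {"loss": 2.3456, "lr": 0.0003, "optim/grad_norm": 1.2})
#   # [I 12:34:56 utils.py:180] step 100
#   #  loss=2.346
#   #  lr=0.0003
#   # (the optim/ entry is filtered out)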