from typing import * import time from pathlib import Path from numbers import Number def catch_exception(fn): def wrapper(*args, **kwargs): try: return fn(*args, **kwargs) except Exception as e: import traceback print(f"Exception in {fn.__name__}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})") traceback.print_exc(chain=False) time.sleep(0.1) return None return wrapper class CallbackOnException: def __init__(self, callback: Callable, exception: type): self.exception = exception self.callback = callback def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_val, self.exception): self.callback() return True return False def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]: for k, v in d.items(): if isinstance(v, dict): for sub_key in traverse_nested_dict_keys(v): yield (k, ) + sub_key else: yield (k, ) def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None): for k in keys: d = d.get(k, default) if d is None: break return d def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any): for k in keys[:-1]: d = d.setdefault(k, {}) d[keys[-1]] = value def key_average(list_of_dicts: list) -> Dict[str, Any]: """ Returns a dictionary with the average value of each key in the input list of dictionaries. """ _nested_dict_keys = set() for d in list_of_dicts: _nested_dict_keys.update(traverse_nested_dict_keys(d)) _nested_dict_keys = sorted(_nested_dict_keys) result = {} for k in _nested_dict_keys: values = [ get_nested_dict(d, k) for d in list_of_dicts if get_nested_dict(d, k) is not None ] avg = sum(values) / len(values) if values else float('nan') set_nested_dict(result, k, avg) return result def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]: """ Flattens a nested dictionary into a single-level dictionary, with keys as tuples. """ items = [] if parent_key is None: parent_key = () for k, v in d.items(): new_key = parent_key + (k, ) if isinstance(v, MutableMapping): items.extend(flatten_nested_dict(v, new_key).items()) else: items.append((new_key, v)) return dict(items) def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]: """ Unflattens a single-level dictionary into a nested dictionary, with keys as tuples. """ result = {} for k, v in d.items(): sub_dict = result for k_ in k[:-1]: if k_ not in sub_dict: sub_dict[k_] = {} sub_dict = sub_dict[k_] sub_dict[k[-1]] = v return result def read_jsonl(file): import json with open(file, 'r') as f: data = f.readlines() return [json.loads(line) for line in data] def write_jsonl(data: List[dict], file): import json with open(file, 'w') as f: for item in data: f.write(json.dumps(item) + '\n') def save_metrics(save_path: Union[str, Path], all_metrics: Dict[str, List[Dict]]): import pandas as pd import json with open(save_path, 'w') as f: json.dump(all_metrics, f, indent=4) def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]): import pandas as pd data = [flatten_nested_dict(d) for d in data] df = pd.DataFrame(data) df = df.sort_index(axis=1) df.columns = pd.MultiIndex.from_tuples(df.columns) return df def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]): if isinstance(d, str): for old, new in mapping.items(): d = d.replace(old, new) elif isinstance(d, list): for i, item in enumerate(d): d[i] = recursive_replace(item, mapping) elif isinstance(d, dict): for k, v in d.items(): d[k] = recursive_replace(v, mapping) return d class timeit: _history: Dict[str, List['timeit']] = {} def __init__(self, name: str = None, verbose: bool = True, multiple: bool = False): self.name = name self.verbose = verbose self.start = None self.end = None self.multiple = multiple if multiple and name not in timeit._history: timeit._history[name] = [] def __call__(self, func: Callable): import inspect if inspect.iscoroutinefunction(func): async def wrapper(*args, **kwargs): with timeit(self.name or func.__qualname__): ret = await func(*args, **kwargs) return ret return wrapper else: def wrapper(*args, **kwargs): with timeit(self.name or func.__qualname__): ret = func(*args, **kwargs) return ret return wrapper def __enter__(self): self.start = time.time() @property def time(self) -> float: assert self.start is not None, "Time not yet started." assert self.end is not None, "Time not yet ended." return self.end - self.start @property def history(self) -> List['timeit']: return timeit._history.get(self.name, []) def __exit__(self, exc_type, exc_val, exc_tb): self.end = time.time() if self.multiple: timeit._history[self.name].append(self) if self.verbose: if self.multiple: avg = sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name]) print(f"{self.name or 'It'} took {avg} seconds in average.") else: print(f"{self.name or 'It'} took {self.time} seconds.") def strip_common_prefix_suffix(strings: List[str]) -> List[str]: first = strings[0] for start in range(len(first)): if any(s[start] != strings[0][start] for s in strings): break for end in range(1, min(len(s) for s in strings)): if any(s[-end] != first[-end] for s in strings): break return [s[start:len(s) - end + 1] for s in strings] def multithead_execute(inputs: List[Any], num_workers: int, pbar = None): from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from tqdm import tqdm if pbar is not None: pbar.total = len(inputs) if hasattr(inputs, '__len__') else None else: pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None) def decorator(fn: Callable): with ( ThreadPoolExecutor(max_workers=num_workers) as executor, pbar ): pbar.refresh() @catch_exception def _fn(input): ret = fn(input) pbar.update() return ret executor.map(_fn, inputs) executor.shutdown(wait=True) return decorator