File size: 7,335 Bytes
ec0c8fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
from typing import *
import time
from pathlib import Path
from numbers import Number


def catch_exception(fn):
    """
    Decorator that converts any ``Exception`` raised by ``fn`` into a ``None`` return.

    When ``fn`` raises, the call (function name plus the repr of every positional
    and keyword argument) and the traceback are printed, and ``None`` is returned
    instead of propagating the error. Exceptions that do not derive from
    ``Exception`` (e.g. ``KeyboardInterrupt``) still propagate.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same call signature and metadata (``__name__``,
        ``__doc__``, ...) as ``fn``.
    """
    import functools
    import traceback

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Render positional and keyword arguments as one comma-separated list so
            # the message has no dangling separator when either group is empty.
            call_args = [repr(arg) for arg in args]
            call_args.extend(f"{k}={v!r}" for k, v in kwargs.items())
            fn_name = getattr(fn, '__name__', repr(fn))
            print(f"Exception in {fn_name}({', '.join(call_args)})")
            traceback.print_exc(chain=False)
            # NOTE(review): short pause kept from the original implementation;
            # presumably it keeps output from concurrent failures readable.
            time.sleep(0.1)
            return None
    return wrapper


class CallbackOnException:
    """
    Context manager that swallows a given exception type and runs a callback instead.

    If the body of the ``with`` block raises an instance of ``exception``, the
    ``callback`` is invoked with no arguments and the exception is suppressed.
    Any other exception propagates unchanged and the callback is not called.
    """

    def __init__(self, callback: Callable, exception: type):
        self.callback = callback
        self.exception = exception

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # ``exc_val`` is None on a clean exit, which never matches.
        matched = isinstance(exc_val, self.exception)
        if matched:
            self.callback()
        # Returning True tells the interpreter to suppress the exception.
        return matched
    
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
    """
    Yield the key path (as a tuple) of every leaf value in a nested dictionary.

    A value counts as a leaf when it is not a ``dict``. Nested dictionaries are
    descended into depth-first in insertion order; an empty nested dictionary
    contributes no paths.
    """
    for key, value in d.items():
        if not isinstance(value, dict):
            yield (key, )
            continue
        # Prefix every path found below with the current key.
        yield from ((key, *tail) for tail in traverse_nested_dict_keys(value))


def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
    """
    Look up a value in a nested dictionary by following a path of keys.

    Args:
        d: The (possibly nested) mapping to read from.
        keys: Path of keys, outermost first. An empty path returns ``d`` itself.
        default: Value returned when the path cannot be followed.

    Returns:
        The value stored at ``keys``, or ``default`` when a key along the path is
        missing or an intermediate value is not a mapping (so a path that runs
        "through" a leaf value yields ``default`` instead of raising).
    """
    current = d
    for k in keys:
        # Stop as soon as the path cannot be followed any further; calling
        # ``.get`` on a non-mapping (or on ``default``) would raise.
        if not isinstance(current, Mapping) or k not in current:
            return default
        current = current[k]
    return current

def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
    """
    Store ``value`` in a nested dictionary at the path given by ``keys``.

    Intermediate dictionaries that do not exist yet are created on the way.
    ``d`` is modified in place; nothing is returned. An empty ``keys`` raises
    ``IndexError``.
    """
    leaf = keys[-1]
    node = d
    for part in keys[:-1]:
        # Reuse the existing child, or create an empty one if it is missing.
        node = node.setdefault(part, {})
    node[leaf] = value


def key_average(list_of_dicts: list) -> Dict[str, Any]:
    """
    Average the leaf values of several (possibly nested) dictionaries key by key.

    The union of all leaf key paths across the inputs is collected. For each
    path, the values that are present (not ``None``) are averaged; a path with
    no usable value gets ``nan``. The result mirrors the nesting of the inputs.
    """
    key_paths = sorted({
        path
        for entry in list_of_dicts
        for path in traverse_nested_dict_keys(entry)
    })

    averages = {}
    for path in key_paths:
        # Dictionaries that lack this path (or store None there) are skipped.
        looked_up = (get_nested_dict(entry, path) for entry in list_of_dicts)
        present = [value for value in looked_up if value is not None]
        if present:
            mean = sum(present) / len(present)
        else:
            mean = float('nan')
        set_nested_dict(averages, path, mean)
    return averages


def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
    """
    Collapse a nested dictionary into one level whose keys are tuples of the key path.

    ``parent_key`` is the path prefix prepended to every produced key (used for
    the recursive calls; defaults to the empty path). Values that are mutable
    mappings are descended into; an empty nested mapping produces no entries.
    """
    prefix = () if parent_key is None else parent_key
    flat = {}
    for key, value in d.items():
        path = prefix + (key, )
        if isinstance(value, MutableMapping):
            flat.update(flatten_nested_dict(value, path))
        else:
            flat[path] = value
    return flat


def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a nested dictionary from a one-level dictionary keyed by key-path tuples.

    Every key of ``d`` is treated as a path: all elements but the last name the
    nested dictionaries to descend into (created when absent), and the last
    element is the key under which the value is stored.
    """
    nested = {}
    for path, value in d.items():
        node = nested
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return nested


def read_jsonl(file):
    """
    Read a JSON Lines file and return the decoded records as a list.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    instead of raising a decode error.

    Args:
        file: Path of the file to read; it is opened as UTF-8 text.

    Returns:
        List of decoded objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    import json
    with open(file, 'r', encoding='utf-8') as f:
        # Iterate the file lazily instead of loading every raw line up front.
        return [json.loads(line) for line in f if line.strip()]


def write_jsonl(data: List[dict], file):
    """
    Write records to ``file`` in JSON Lines format.

    Each item of ``data`` is serialised with ``json.dumps`` and written on its
    own line. An existing file is overwritten.
    """
    import json
    with open(file, 'w') as f:
        f.writelines(json.dumps(record) + '\n' for record in data)


def save_metrics(save_path: Union[str, Path], all_metrics: Dict[str, List[Dict]]):
    """
    Dump ``all_metrics`` to ``save_path`` as JSON indented by four spaces.

    Args:
        save_path: Destination file; an existing file is overwritten.
        all_metrics: JSON-serialisable metrics structure to store.
    """
    # Only the standard-library json module is needed here; the previously
    # unused pandas import made pandas a hard requirement for a plain JSON dump.
    import json

    with open(save_path, 'w') as f:
        json.dump(all_metrics, f, indent=4)


def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
    """
    Build a pandas DataFrame with hierarchical (MultiIndex) columns from nested records.

    Every record is flattened with ``flatten_nested_dict`` so that its keys
    become key-path tuples. The records form the rows of the frame, the columns
    are sorted, and the column labels are converted to a ``pd.MultiIndex``.
    """
    import pandas as pd
    flat_rows = [flatten_nested_dict(record) for record in data]
    frame = pd.DataFrame(flat_rows).sort_index(axis=1)
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)
    return frame


def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
    """
    Apply substring replacements to every string inside a nested list/dict structure.

    For a string, each ``old -> new`` pair of ``mapping`` is applied in turn
    with ``str.replace`` and the new string is returned. Lists and dicts are
    updated in place (dict keys are left alone, only values are processed) and
    the same container is returned. Any other value is returned unchanged.
    """
    if isinstance(d, str):
        text = d
        for needle, substitute in mapping.items():
            text = text.replace(needle, substitute)
        return text

    if isinstance(d, list):
        for index, element in enumerate(d):
            d[index] = recursive_replace(element, mapping)
        return d

    if isinstance(d, dict):
        # Only values are reassigned, so the key set is stable while iterating.
        for key in d:
            d[key] = recursive_replace(d[key], mapping)
        return d

    return d


class timeit:
    """
    Wall-clock timer usable as a context manager or as a decorator.

    As a context manager it measures the body of the ``with`` block::

        with timeit("load") as t:
            ...
        t.time  # elapsed seconds

    As a decorator (``@timeit(...)``) it times every call of the decorated
    function, sync or async, using the function's qualified name when no
    ``name`` is given.

    Args:
        name: Label used in the printed message and as the history key.
        verbose: Print the measurement when the timed section ends.
        multiple: Record every finished measurement under ``name`` and report
            the running average instead of the single measurement.
    """

    # Finished timers created with ``multiple=True``, grouped by name.
    _history: Dict[str, List['timeit']] = {}

    def __init__(self, name: str = None, verbose: bool = True, multiple: bool = False):
        self.name = name
        self.verbose = verbose
        self.start = None
        self.end = None
        self.multiple = multiple
        if multiple and name not in timeit._history:
            timeit._history[name] = []

    def __call__(self, func: Callable):
        import functools
        import inspect

        label = self.name or func.__qualname__

        def new_timer() -> 'timeit':
            # A fresh timer per call keeps each measurement separate while
            # carrying over the settings given to the decorator.
            return timeit(label, verbose=self.verbose, multiple=self.multiple)

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                with new_timer():
                    ret = await func(*args, **kwargs)
                return ret
            return wrapper
        else:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with new_timer():
                    ret = func(*args, **kwargs)
                return ret
            return wrapper

    def __enter__(self):
        self.start = time.time()
        # Return the timer so ``with timeit() as t`` gives access to ``t.time``.
        return self

    @property
    def time(self) -> float:
        """Elapsed seconds between entering and leaving the timed section."""
        assert self.start is not None, "Time not yet started."
        assert self.end is not None, "Time not yet ended."
        return self.end - self.start

    @property
    def history(self) -> List['timeit']:
        """Timers recorded under this timer's name (empty if none were recorded)."""
        return timeit._history.get(self.name, [])

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        if self.multiple:
            timeit._history[self.name].append(self)
        if self.verbose:
            if self.multiple:
                avg = sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])
                print(f"{self.name or 'It'} took {avg} seconds in average.")
            else:
                print(f"{self.name or 'It'} took {self.time} seconds.")


def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
    """
    Remove the longest prefix and the longest suffix shared by all strings.

    The prefix is stripped first; the suffix is then limited to the characters
    that remain in the shortest string, so the two never overlap.

    Args:
        strings: Strings to strip. May be empty and may contain strings of
            different lengths (including empty strings).

    Returns:
        A new list with the common prefix and suffix removed from every string,
        in the same order as the input.
    """
    if not strings:
        return []

    first = strings[0]
    shortest = min(len(s) for s in strings)

    # Length of the common prefix, bounded by the shortest string.
    start = 0
    while start < shortest and all(s[start] == first[start] for s in strings):
        start += 1

    # Length of the common suffix, counted from the end of each string and
    # bounded so that it cannot reach into the already-stripped prefix.
    end = 0
    while end < shortest - start and all(
        s[len(s) - 1 - end] == first[len(first) - 1 - end] for s in strings
    ):
        end += 1

    return [s[start:len(s) - end] for s in strings]


def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
    """
    Build a decorator that immediately runs the decorated function over ``inputs`` in a thread pool.

    The decorated function is NOT kept as a function: applying the decorator
    executes ``fn(item)`` for every item of ``inputs`` on ``num_workers``
    threads, waits for all of them, and binds the decorated name to the list of
    results. Each call is wrapped with ``catch_exception``, so a failing item is
    reported and contributes ``None`` rather than aborting the run.

    Args:
        inputs: Items to process. If it has no ``__len__`` the progress bar
            total is left unset.
        num_workers: Maximum number of worker threads.
        pbar: Optional progress bar to reuse; it must work as a context manager
            and provide ``total``, ``refresh()`` and ``update()``. When omitted,
            a ``tqdm`` bar is created.

    Returns:
        The decorator. Calling it with ``fn`` returns the per-item results in
        input order.
    """
    from concurrent.futures import ThreadPoolExecutor

    total = len(inputs) if hasattr(inputs, '__len__') else None
    if pbar is not None:
        pbar.total = total
    else:
        # tqdm is only required when the caller did not supply a progress bar.
        from tqdm import tqdm
        pbar = tqdm(total=total)

    def decorator(fn: Callable):
        with (
            ThreadPoolExecutor(max_workers=num_workers) as executor,
            pbar
        ):
            pbar.refresh()

            @catch_exception
            def _fn(item):
                ret = fn(item)
                pbar.update()
                return ret

            # Consuming the map waits for every task, so all progress updates
            # happen before the progress bar is closed on leaving the block.
            return list(executor.map(_fn, inputs))

    return decorator