|
|
|
|
|
import torch |
|
import numpy as np |
|
|
|
|
|
def worker_init_fn(_):
    """Re-seed NumPy's global RNG inside a DataLoader worker process.

    The new seed is the first word of the current NumPy (legacy MT19937)
    state offset by the worker id, so workers that start from the same
    inherited NumPy state end up with different seeds.

    Intended to be passed as ``worker_init_fn`` to
    ``torch.utils.data.DataLoader``.

    Args:
        _: Unused. The DataLoader passes the worker id here, but the id is
            read from ``torch.utils.data.get_worker_info()`` instead.

    Returns:
        None. When called outside a worker process the RNG is left untouched.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # get_worker_info() returns None in the main process; there is no
        # worker id to offset by, so leave the RNG state alone.
        return None

    # Convert to a Python int before adding so the sum cannot wrap or be
    # promoted as a NumPy uint32 scalar.
    base_seed = int(np.random.get_state()[1][0])

    # np.random.seed only accepts values in [0, 2**32 - 1]; reduce the sum
    # so a large base seed plus the worker id cannot raise ValueError.
    np.random.seed((base_seed + worker_info.id) % 2**32)
    return None
|
|
|
|
|
def collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of sample dicts into a single dict of batched values.

    The keys of the first sample define the keys of the result; every sample
    must provide all of them. For each key the per-sample values are gathered
    in order, then optionally combined based on the type of the FIRST value:

    * ``int`` / ``float``  -> ``np.array(values)`` when ``combine_scalars``
    * ``torch.Tensor``     -> ``torch.stack(values)`` when ``combine_tensors``
    * ``np.ndarray``       -> ``np.stack(values)`` when ``combine_tensors``
    * anything else        -> left as a plain list

    Args:
        samples (list[dict]): Samples to collate. May be empty.
        combine_tensors: If true, stack tensor and ndarray values into one
            batched tensor/array; otherwise keep them as a list.
        combine_scalars: If true, turn int/float values into a NumPy array;
            otherwise keep them as a list.

    Returns:
        dict: Maps each key of the first sample to the combined (or listed)
        values. An empty ``samples`` yields an empty dict.
    """
    if not samples:
        # No first sample to take keys from; avoid IndexError on samples[0].
        return {}

    result = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        first = values[0]

        # Only the first value is inspected to decide how to combine.
        if isinstance(first, (int, float)):
            if combine_scalars:
                values = np.array(values)
        elif isinstance(first, torch.Tensor):
            if combine_tensors:
                values = torch.stack(values)
        elif isinstance(first, np.ndarray):
            if combine_tensors:
                values = np.stack(values)

        result[key] = values

    return result
|
|