import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset

from ..utils.generic import ModelOutput


class PipelineDataset(Dataset):
    def __init__(self, dataset, process, params):
        self.dataset = dataset
        self.process = process
        self.params = params

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        item = self.dataset[i]
        processed = self.process(item, **self.params)
        return processed

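# Usage sketch (an illustrative assumption, not part of the original module): wrap a
# plain list of inputs so each item is lazily run through a pipeline's `preprocess`
# step; `pipe` is a placeholder pipeline instance.
#
#     dataset = PipelineDataset(["I love this", "I hate this"], pipe.preprocess, {})
#     dataset[0]  # model-ready inputs for the first sentence

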
class PipelineIterator(IterableDataset):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for item in loader:
            yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or any iterator):
                The iterator that will be used to apply `infer` on.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
            loader_batch_size (`int`, *optional*):
                If specified, the items of `loader` are supposed to come as batches, and are unrolled into
                individual items here, making it roughly behave as


        ```
        for items in loader:
            for i in range(loader_batch_size):
                item = items[i]
                yield infer(item, **params)
        ```"""
        self.loader = loader
        self.infer = infer
        self.params = params
        if loader_batch_size == 1:
            # Let's spare some time by deactivating batch unrolling altogether.
            loader_batch_size = None
        self.loader_batch_size = loader_batch_size

        # Internal bookkeeping used to unroll batched outputs item by item.
        self._loader_batch_index = None
        self._loader_batch_data = None

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def loader_batch_item(self):
        """
        Return item located at `loader_batch_index` within the current `loader_batch_data`.
        """
        if isinstance(self._loader_batch_data, torch.Tensor):
            # Batch data is a plain tensor, just fetch the slice.
            result = self._loader_batch_data[self._loader_batch_index]
        else:
            # Batch data is assumed to be a `ModelOutput` (or dict): unbatch it key by key.
            loader_batched = {}
            for k, element in self._loader_batch_data.items():
                if isinstance(element, ModelOutput):
                    # Convert ModelOutput to a tuple first.
                    element = element.to_tuple()
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
                    # These are stored as tuples of tensors, so they need element-wise unbatching.
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if element is None:
                    # This can happen for optional data that get passed around.
                    loader_batched[k] = None
                elif isinstance(element[self._loader_batch_index], torch.Tensor):
                    # Take the correct batch item, but make it look like batch_size=1
                    # so it stays compatible with the rest of the pipeline code.
                    loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
                elif isinstance(element[self._loader_batch_index], np.ndarray):
                    # Same as above, but for numpy arrays.
                    loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
                else:
                    # This is typically a list, so no need to unsqueeze.
                    loader_batched[k] = element[self._loader_batch_index]
            # Recreate the element by reusing the original class so it still looks
            # like a batch_size=1 output.
            result = self._loader_batch_data.__class__(loader_batched)
        self._loader_batch_index += 1
        return result

    def __next__(self):
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            # We are currently unrolling a batch, so we just return the next
            # item within that batch.
            return self.loader_batch_item()

        # We're out of items within a batch, fetch the next one and run inference on it.
        item = next(self.iterator)
        processed = self.infer(item, **self.params)
        if self.loader_batch_size is not None:
            # Try to infer the actual size of the batch we just got.
            if isinstance(processed, torch.Tensor):
                first_tensor = processed
            else:
                key = list(processed.keys())[0]
                first_tensor = processed[key]
            if isinstance(first_tensor, list):
                observed_batch_size = len(first_tensor)
            else:
                observed_batch_size = first_tensor.shape[0]
            if 0 < observed_batch_size < self.loader_batch_size:
                # This could be the last batch, so we can't unroll as many
                # elements as expected.
                self.loader_batch_size = observed_batch_size
            # Set the internal index so the batch can be unrolled item by item.
            self._loader_batch_data = processed
            self._loader_batch_index = 0
            return self.loader_batch_item()
        else:
            # We're not unrolling batches, return the processed item directly.
            return processed

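# Usage sketch (an illustrative assumption, not part of the original module): run the
# forward pass batch by batch while yielding results one item at a time; `dataloader`,
# `pipe.forward` and the batch size are placeholders.
#
#     iterator = PipelineIterator(dataloader, pipe.forward, {}, loader_batch_size=8)
#     for model_output in iterator:
#         ...  # one unbatched output per original input item

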
class PipelineChunkIterator(PipelineIterator):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for iterator in loader:
            for item in iterator:
                yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or any iterator):
                The iterator that will be used to apply `infer` on.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
        """
        super().__init__(loader, infer, params)

    def __iter__(self):
        self.iterator = iter(self.loader)
        self.subiterator = None
        return self

    def __next__(self):
        if self.subiterator is None:
            # `subiterator` being None means we haven't started a `preprocess` iterator yet, so start it.
            self.subiterator = self.infer(next(self.iterator), **self.params)
        try:
            # Try to return the next chunk of the current sub-iterator.
            processed = next(self.subiterator)
        except StopIteration:
            # The current sub-iterator is exhausted: advance the outer loader to the
            # next input and start a new sub-iterator from it. If the loader itself is
            # exhausted, the StopIteration from `next(self.iterator)` ends the iteration.
            self.subiterator = self.infer(next(self.iterator), **self.params)
            processed = next(self.subiterator)
        return processed

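# Usage sketch (an illustrative assumption, not part of the original module): for chunk
# pipelines, `infer` yields several chunks per input and the chunk iterator flattens
# them into a single stream; `split_sentences` is a made-up stand-in for `preprocess`.
#
#     def split_sentences(text, **params):
#         for sentence in text.split("."):
#             yield {"sentence": sentence}
#
#     flat = PipelineChunkIterator(["a. b", "c"], split_sentences, {})
#     list(flat)  # [{'sentence': 'a'}, {'sentence': ' b'}, {'sentence': 'c'}]

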
class PipelinePackIterator(PipelineIterator):
    """
    Roughly equivalent to

    ```
    packed = []
    for item in loader:
        packed.append(item)
        if item["is_last"]:
            yield packed
            packed = []
    ```

    but it also handles cases where `item` is batched (meaning it's a dict of Tensors whose first dimension
    is > 1). In that case it does

    ```
    packed = []
    for batch in loader:
        # item is batched
        for item in batch:
            packed.append(item)
            if item["is_last"]:
                yield packed
                packed = []
    ```

    Arguments:
        loader (`torch.utils.data.DataLoader` or any iterator):
            The iterator that will be used to apply `infer` on.
        infer (any function):
            The function to apply to each element of `loader`.
        params (`dict`):
            The parameters passed to `infer` along with every item.
        loader_batch_size (`int`, *optional*):
            If specified, the items of `loader` are supposed to come as batches, and are unrolled into
            individual items here, making it roughly behave as


    ```
    for items in loader:
        for i in range(loader_batch_size):
            item = items[i]
            yield infer(item, **params)
    ```"""

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def __next__(self):
        # Very similar to `PipelineIterator` in its unpacking mechanism, but every item
        # is expected to carry an `is_last` key. Items are accumulated (unbatching them
        # first if needed, and dropping `is_last`) until `is_last` is hit, at which point
        # the whole accumulator is returned as one group.
        is_last = False
        accumulator = []
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            while self._loader_batch_index < self.loader_batch_size:
                item = self.loader_batch_item()
                is_last = item.pop("is_last")
                accumulator.append(item)
                if is_last:
                    return accumulator

        while not is_last:
            processed = self.infer(next(self.iterator), **self.params)
            if self.loader_batch_size is not None:
                if isinstance(processed, torch.Tensor):
                    first_tensor = processed
                else:
                    key = list(processed.keys())[0]
                    first_tensor = processed[key]
                if isinstance(first_tensor, list):
                    observed_batch_size = len(first_tensor)
                else:
                    observed_batch_size = first_tensor.shape[0]
                if 0 < observed_batch_size < self.loader_batch_size:
                    # This could be the last batch, so we can't unroll as many
                    # elements as expected.
                    self.loader_batch_size = observed_batch_size
                self._loader_batch_data = processed
                self._loader_batch_index = 0
                while self._loader_batch_index < self.loader_batch_size:
                    item = self.loader_batch_item()
                    is_last = item.pop("is_last")
                    accumulator.append(item)
                    if is_last:
                        return accumulator
            else:
                item = processed
                is_last = item.pop("is_last")
                accumulator.append(item)
        return accumulator

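# Usage sketch (an illustrative assumption, not part of the original module): regroup a
# flat stream of chunk outputs into one list per original input, using the `is_last`
# marker each item carries.
#
#     outputs = [
#         {"score": 0.1, "is_last": False},
#         {"score": 0.2, "is_last": True},
#         {"score": 0.3, "is_last": True},
#     ]
#     packed = PipelinePackIterator(outputs, lambda x, **kwargs: x, {})
#     list(packed)  # [[{'score': 0.1}, {'score': 0.2}], [{'score': 0.3}]]

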
class KeyDataset(Dataset):
    def __init__(self, dataset: Dataset, key: str):
        self.dataset = dataset
        self.key = key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i][self.key]

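# Usage sketch (an illustrative assumption, not part of the original module): feed a
# single column of a `datasets.Dataset` to a pipeline; `pipe` is a placeholder pipeline.
#
#     from datasets import load_dataset
#     dataset = load_dataset("imdb", split="test")
#     for output in pipe(KeyDataset(dataset, "text")):
#         ...

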
class KeyPairDataset(Dataset):
    def __init__(self, dataset: Dataset, key1: str, key2: str):
        self.dataset = dataset
        self.key1 = key1
        self.key2 = key2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return {"text": self.dataset[i][self.key1], "text_pair": self.dataset[i][self.key2]}
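

# Usage sketch (an illustrative assumption, not part of the original module): pair two
# columns for pipelines that expect sentence pairs; the column names are placeholders.
#
#     pairs = KeyPairDataset(dataset, "premise", "hypothesis")
#     pairs[0]  # {"text": ..., "text_pair": ...}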