"""Datasets for converting to MDS Shards.""" import os import warnings from typing import Dict, Iterable, Union import datasets as hf_datasets import numpy as np from torch.utils.data import IterableDataset from transformers import PreTrainedTokenizerBase class NoConcatDataset(IterableDataset): """An IterableDataset that returns text samples for MDSWriter. Returns dicts of {'text': bytes} """ def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset]): self.hf_dataset = hf_dataset def __iter__(self) -> Iterable[Dict[str, bytes]]: for sample in self.hf_dataset: yield {'text': sample['text'].encode('utf-8')} class ConcatTokensDataset(IterableDataset): """An IterableDataset that returns token samples for MDSWriter. Returns dicts of {'tokens': bytes} To use data created by this class and written to MDS format: ```python import torch from streaming.base import StreamingDataset from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained('your/tokenizer') ds = StreamingDataset(local='mds-data-folder', split='val') # note, you need to copy the numpy array because the original is non-writeable # and torch does not support non-writeable tensors, so you get a scary warning and # if you do try to write to the tensor you get undefined behavior tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy()) print(tokenizer.decode(tokens)) ``` """ def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset], tokenizer: PreTrainedTokenizerBase, max_length: int, bos_text: str, eos_text: str, no_wrap: bool): self.hf_dataset = hf_dataset self.tokenizer = tokenizer os.environ['TOKENIZERS_PARALLELISM'] = 'false' self.max_length = max_length self.bos_text = bos_text self.eos_text = eos_text self.should_wrap = not no_wrap self.bos_tokens = self.tokenizer(self.bos_text, truncation=False, padding=False, add_special_tokens=False)['input_ids'] if len(self.bos_tokens) > 1: warnings.warn(f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token , instead we got {self.bos_tokens}. Quit if this was in error.') self.eos_tokens = self.tokenizer(self.eos_text, truncation=False, padding=False, add_special_tokens=False)['input_ids'] if len(self.eos_tokens) > 1: warnings.warn(f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token , instead we got {self.eos_tokens}. Quit if this was in error.') eos_text_provided = self.eos_text != '' bos_text_provided = self.bos_text != '' test_text = self.tokenizer('') if len(test_text['input_ids']) > 0 and (eos_text_provided or bos_text_provided): message = 'both eos and bos' if eos_text_provided and bos_text_provided else 'eos_text' if eos_text_provided else 'bos_text' warnings.warn(f'The provided tokenizer adds special tokens, but you also specified {message}. This may result ' + 'in duplicated special tokens. Please be sure this is what you intend.') def __iter__(self) -> Iterable[Dict[str, bytes]]: buffer = [] for sample in self.hf_dataset: encoded = self.tokenizer(sample['text'], truncation=False, padding=False) iids = encoded['input_ids'] buffer = buffer + self.bos_tokens + iids + self.eos_tokens while len(buffer) >= self.max_length: concat_sample = buffer[:self.max_length] buffer = buffer[self.max_length:] if self.should_wrap else [] yield {'tokens': np.asarray(concat_sample).tobytes()}