import functools

import seqio
from t5.evaluation import metrics
from t5.data import preprocessors

# Shared SentencePiece vocabulary used for both inputs and targets.
vocabulary = seqio.SentencePieceVocabulary('spiece.model')
output_features = {
    'inputs': seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    'targets': seqio.Feature(vocabulary=vocabulary, add_eos=True)
}

# Register the Finnish span-corruption pretraining task.
seqio.TaskRegistry.add(
    'pretrain_finnish',
    source=seqio.TextLineDataSource({
        "train": "/researchdisk/lm_training_dataset_full_sentences/train.txt",
        "validation": "/researchdisk/lm_training_dataset_full_sentences/validation.txt"
    }),
    preprocessors=[
        # Each line of the text files becomes a single "text" field.
        functools.partial(
            preprocessors.parse_tsv,
            field_names=["text"],
            field_delim="\n"),
        # Map the raw text to "targets"; "inputs" are produced later by the
        # span-corruption preprocessor.
        functools.partial(
            preprocessors.rekey, key_map={
                "inputs": None,
                "targets": "text"
            }),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[metrics.accuracy],
    output_features=output_features)

# dataset = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
#     sequence_length={"inputs": 512, "targets": 114},
#     split="train",
#     shuffle=True,
#     num_epochs=1,
#     # shard_info=seqio.ShardInfo(index=0, num_shards=10),
#     use_cached=False,
#     seed=42
# )

# # Print the first 5 examples.
# for _, ex in zip(range(5), dataset.as_numpy_iterator()):
#     print(ex)
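
# Optional sanity-check sketch (kept commented out like the preview above,
# since this module is presumably imported by the training job): decode one
# tokenized example back to text with the same vocabulary to eyeball the
# span-corruption sentinels. This is a minimal illustration under those
# assumptions, not part of the task definition itself.
# dataset = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
#     sequence_length={"inputs": 512, "targets": 114},
#     split="validation", shuffle=False, num_epochs=1, use_cached=False, seed=42)
# for _, ex in zip(range(1), dataset.as_numpy_iterator()):
#     print("inputs :", vocabulary.decode(ex["inputs"].tolist()))
#     print("targets:", vocabulary.decode(ex["targets"].tolist()))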