File size: 2,598 Bytes
0998a82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# adapted from https://huggingface.co/pere/pk-nb-t5x/blob/main/tasks.py

import functools

import seqio
import tensorflow as tf
import t5.data
from datasets import load_dataset, load_from_disk
from t5.data import postprocessors
from t5.data import preprocessors
from t5.evaluation import metrics
from seqio import FunctionDataSource, utils

TaskRegistry = seqio.TaskRegistry

# SentencePiece vocabulary shared by every feature of the tasks in this file.
# NOTE(review): extra_ids=0 appends no sentinel ids to this vocabulary. T5's
# span_corruption (used by the task below) presumably takes sentinel ids from
# the top of the vocabulary, so spiece.model is assumed to already reserve
# those pieces — confirm.
vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)

# Both features append EOS. "inputs" is optional (required=False), so tasks
# may declare only "targets".
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    targets=seqio.Feature(vocabulary=vocabulary, add_eos=True),
)


def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    """Yield ``row[column]`` for every row of ``dataset[str(split)]``, cycling forever.

    Args:
      split: split key; converted with ``str()`` before indexing ``dataset``.
      shuffle: if true, shuffle the selected split once before the first
        pass. Every later pass repeats that same order.
      seed: forwarded as ``shuffle(seed=seed)``. ``None`` calls ``shuffle()``
        without a seed. ``0`` is a valid seed and is honoured.
      column: field of each row to yield.
      dataset: mapping from split name to an iterable of rows (for example the
        ``datasets.DatasetDict`` loaded in this module). Only the selected
        split is shuffled, not every split in the mapping.

    Yields:
      ``row[column]`` for each row, restarting from the first row after the
      last one.

    Raises:
      ValueError: if ``dataset`` is None, or if a full pass over the split
        yields no rows (otherwise the endless loop would spin without ever
        yielding).
    """
    if dataset is None:
        raise ValueError("gen_dataset() requires a `dataset` mapping split names to rows")
    rows = dataset[str(split)]
    if shuffle:
        # Compare against None so an explicit seed of 0 is not treated as "no seed".
        rows = rows.shuffle(seed=seed) if seed is not None else rows.shuffle()
    while True:
        yielded_any = False
        for item in rows:
            yielded_any = True
            yield item[column]
        if not yielded_any:
            # Empty split (or an exhausted one-shot iterable): fail loudly
            # instead of busy-looping forever.
            raise ValueError(f"split {split!r} contains no rows")


def dataset_fn(split, shuffle_files, seed=None, dataset=None, column="text"):
    """Wrap ``gen_dataset`` as a ``tf.data.Dataset`` of scalar ``tf.string`` values.

    Keeps the ``(split, shuffle_files, seed)`` parameters used as a
    ``seqio.FunctionDataSource`` ``dataset_fn``. ``dataset`` is bound with
    ``functools.partial`` when the task is registered.

    Args:
      split: split name forwarded to ``gen_dataset``.
      shuffle_files: forwarded to ``gen_dataset`` as ``shuffle``.
      seed: forwarded to ``gen_dataset`` as ``seed``.
      dataset: mapping from split name to rows, forwarded to ``gen_dataset``.
      column: field of each row to emit. Defaults to ``"text"``, the value
        previously hard-wired through ``gen_dataset``'s default.

    Returns:
      A ``tf.data.Dataset`` built with ``from_generator``. Because
      ``gen_dataset`` cycles over the split indefinitely, the dataset has no
      natural end.
    """
    # Pass everything after `split` by keyword so the shuffle_files -> shuffle
    # mapping is explicit rather than positional.
    generator = functools.partial(
        gen_dataset,
        split,
        shuffle=shuffle_files,
        seed=seed,
        column=column,
        dataset=dataset,
    )
    return tf.data.Dataset.from_generator(
        generator,
        # `dataset_name` is a module-level global assigned later in this file.
        # It is used here only as the TensorSpec's name.
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name),
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Return a copy of ``key_map`` with ``x`` stored under ``target_key``.

    ``key_map`` itself is left unmodified. If it already contains
    ``target_key``, that entry's value is replaced by ``x`` in the copy.
    """
    features = dict(key_map)
    features[target_key] = x
    return features


# Final pretraining task from Raffel et al., 2019 (T5 span corruption), adapted
# to read a locally stored Hugging Face dataset.
dataset_name = "/researchdisk/lm_training_dataset_full"
dataset_params = {"from_disk_path": dataset_name}

# A "from_disk_path" entry selects `load_from_disk`. Otherwise the params are
# forwarded unchanged to `load_dataset`.
if "from_disk_path" not in dataset_params:
    dataset = load_dataset(**dataset_params)
else:
    dataset = load_from_disk(dataset_params["from_disk_path"])

_pretrain_splits = ("train", "validation")

# Per-split row counts, reported to seqio as `num_input_examples`.
dataset_shapes = {split: dataset[split].num_rows for split in _pretrain_splits}

TaskRegistry.add(
    "pretrain_finnish",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=_pretrain_splits,
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Each raw string x becomes {"inputs": None, "targets": x}.
        functools.partial(
            target_to_key,
            key_map={"inputs": None, "targets": None},
            target_key="targets",
        ),
        seqio.preprocessors.tokenize,
        # No seqio.CacheDatasetPlaceholder() here: caching is disabled for
        # this source (caching_permitted=False above).
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    # NOTE(review): only "targets" is declared as an output feature, so
    # append_eos_after_trim presumably leaves "inputs" without EOS (T5's own
    # span-corruption tasks declare both features) — confirm this is intended.
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[metrics.accuracy],
)