# gpt2-medium-persian / src/prep_dataset.py
# Author: m3hrdadfi
# Commit 7cfca48: Add training script with checkpoint and preprocessing + merge scripts
# (Hugging Face file-viewer residue: raw | history | blame | No virus | 1.08 kB)
from datasets import load_dataset, DatasetDict
from hazm import sent_tokenize
from normalizer import normalize
class Prep_dataset:
    """Load the OSCAR ``unshuffled_deduplicated_fa`` corpus and prepare it.

    On construction the corpus is loaded (optionally truncated to a small
    sample of the training split) and stored on ``self.raw_dataset``.
    ``preprare_dataset`` then drops short documents and passes the text of
    the survivors through the project's ``normalize`` function.
    """

    # Number of leading "train" rows kept when ``subsample`` is requested.
    SAMPLE_SIZE = 100
    # A document is kept only if its text has more than this many characters...
    MIN_TEXT_CHARS = 500
    # ...and ``sent_tokenize`` splits it into more than this many sentences.
    MIN_SENTENCES = 2

    def __init__(self, subsample=False, *args, **kwargs):
        """Load the corpus and store it on ``self.raw_dataset``.

        Args:
            subsample: When truthy, replace the "train" split with its first
                ``SAMPLE_SIZE`` rows (or all rows if it has fewer).
            *args: Accepted for call compatibility; unused.
            **kwargs: Accepted for call compatibility; unused.
        """
        raw_dataset = load_dataset("oscar", "unshuffled_deduplicated_fa")

        if not subsample:
            self.raw_dataset = raw_dataset
            return

        train = raw_dataset["train"]
        # Clamp so that a split shorter than SAMPLE_SIZE does not raise.
        head = train.select(range(min(self.SAMPLE_SIZE, len(train))))

        # Keep every other split untouched; the truncated "train" goes last,
        # matching the key order the original pop/re-insert sequence produced.
        splits = {
            name: split for name, split in raw_dataset.items() if name != "train"
        }
        splits["train"] = head
        self.raw_dataset = DatasetDict(splits)

    def _normalize(self, example):
        """Replace ``example["text"]`` with its normalized form, in place.

        Returns the same ``example`` object so it can be used with ``map``.
        """
        example["text"] = normalize(example["text"])
        return example

    def preprare_dataset(self):
        """Return the filtered and normalized dataset.

        A row is kept when its text is longer than ``MIN_TEXT_CHARS``
        characters and ``sent_tokenize`` yields more than ``MIN_SENTENCES``
        sentences. Each kept row is then run through ``_normalize``.
        ``self.raw_dataset`` itself is not reassigned.
        """
        # Bind thresholds to locals so the predicate closes over plain ints
        # rather than over ``self``.
        min_chars = self.MIN_TEXT_CHARS
        min_sentences = self.MIN_SENTENCES

        def is_rich(example):
            text = example["text"]
            # The cheap length test runs first; sentence tokenization is only
            # done for rows that pass it. One pass gives the same rows as two
            # consecutive filters would.
            return (
                len(text) > min_chars
                and len(sent_tokenize(text)) > min_sentences
            )

        return self.raw_dataset.filter(is_rich).map(self._normalize)

    # Correctly spelled alias; the misspelled name stays for existing callers.
    prepare_dataset = preprare_dataset