import copy

import pandas as pd
from gluonts.dataset.common import TrainDatasets
from gluonts.dataset.repository.datasets import get_dataset

# Seasonal period keyed by pandas frequency alias.
# NOTE(review): presumably consumed by seasonality-aware models/metrics
# elsewhere; weekly data ("W") is deliberately mapped to 1 here — confirm.
SEASONALITY_MAP = dict(
    Y=1,
    Q=4,
    M=12,
    W=1,
    D=7,
    H=24,
)


def fix_m3_other_start(ts: dict):
    """Return a shallow copy of *ts* whose ``"start"`` is a yearly period.

    The input entry is not modified: only the ``"start"`` key of the copy
    is rebound, to ``pd.Period("1750", freq="Y")``. Every other value is
    shared with *ts* because the copy is shallow.
    """
    # NOTE(review): 1750 looks like a fixed placeholder start year; the
    # reason for this particular value is not visible here — confirm.
    patched_entry = copy.copy(ts)
    patched_entry["start"] = pd.Period("1750", freq="Y")
    return patched_entry


def load_dataset(dataset_name) -> TrainDatasets:
    """Load a GluonTS repository dataset by name, patching ``m3_other``.

    Args:
        dataset_name: Name forwarded unchanged to GluonTS ``get_dataset``.

    Returns:
        For any name other than ``"m3_other"``: the ``TrainDatasets`` returned
        by ``get_dataset``, untouched. For ``"m3_other"``: a new
        ``TrainDatasets`` whose train and test splits are lists of patched
        entry copies (see ``fix_m3_other_start``) and whose metadata
        ``freq`` is set to ``"Y"``.
    """
    data = get_dataset(dataset_name)
    if dataset_name != "m3_other":
        return data

    # GluonTS ships m3_other with frequency "Q", but it should be yearly
    # ("Y"), so both the per-series start periods and the metadata are
    # rewritten.
    patched_train = list(map(fix_m3_other_start, data.train))
    patched_test = list(map(fix_m3_other_start, data.test))
    patched = TrainDatasets(
        metadata=data.metadata,
        train=patched_train,
        test=patched_test,
    )
    # The metadata object is shared with the freshly loaded ``data``, so
    # this assignment updates it in place.
    patched.metadata.freq = "Y"
    return patched