File size: 849 Bytes
45e60de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import pandas as pd
import copy
from gluonts.dataset.common import TrainDatasets
from gluonts.dataset.repository.datasets import get_dataset
# Seasonal period keyed by frequency code ("Y" yearly, "Q" quarterly,
# "M" monthly, "W" weekly, "D" daily, "H" hourly).
# A value of 1 (used for "Y" and "W") presumably means "no seasonal cycle
# assumed"; confirm against the callers that consume this map.
SEASONALITY_MAP = dict(
    Y=1,
    Q=4,
    M=12,
    W=1,
    D=7,
    H=24,
)
def fix_m3_other_start(ts: dict) -> dict:
    """Return a shallow copy of a dataset entry re-anchored to a yearly start.

    The entry's ``"start"`` is replaced with the yearly period ``1750``.
    All other keys are carried over unchanged, and their values are shared
    with *ts* because the copy is shallow. The input entry itself is not
    modified.

    Args:
        ts: A dataset entry mapping; only its ``"start"`` key is overwritten.

    Returns:
        A new mapping of the same type as *ts* with a yearly ``"start"``.
    """
    patched = copy.copy(ts)
    # The year 1750 is a fixed placeholder start; the frequency "Y" is the part
    # that corrects the entry.
    patched.update(start=pd.Period("1750", freq="Y"))
    return patched
def load_dataset(dataset_name: str) -> TrainDatasets:
    """Load a GluonTS repository dataset, patching known metadata problems.

    Every dataset except ``"m3_other"`` is returned exactly as
    ``get_dataset`` produced it. For ``"m3_other"``, GluonTS provides an
    incorrect ``"Q"`` frequency. Each train and test entry is therefore
    passed through ``fix_m3_other_start``, which gives it a yearly start.
    The entries are materialized into lists, and ``metadata.freq`` is set
    to ``"Y"``.

    Args:
        dataset_name: Name of the dataset in the GluonTS repository.

    Returns:
        The (possibly patched) ``TrainDatasets``.
    """
    data = get_dataset(dataset_name)
    if dataset_name != "m3_other":
        return data

    train_entries = list(map(fix_m3_other_start, data.train))
    test_entries = list(map(fix_m3_other_start, data.test))
    patched = TrainDatasets(
        metadata=data.metadata, train=train_entries, test=test_entries
    )
    # The metadata object is shared with the loaded dataset, so this
    # assignment updates that same object in place.
    patched.metadata.freq = "Y"
    return patched
|