File size: 849 Bytes
45e60de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import pandas as pd
import copy
from gluonts.dataset.common import TrainDatasets
from gluonts.dataset.repository.datasets import get_dataset

# Seasonal period length keyed by base frequency string (e.g. 12 for "M").
# A value of 1 ("Y", "W") presumably means the series is treated as having
# no seasonal cycle — NOTE(review): confirm against the consumers of this map.
SEASONALITY_MAP = dict(
    Y=1,
    Q=4,
    M=12,
    W=1,
    D=7,
    H=24,
)


def fix_m3_other_start(ts: dict):
    """Return a shallow copy of ``ts`` whose ``"start"`` is the yearly period 1750.

    ``ts`` itself is not modified; all other fields of the returned copy refer
    to the very same objects as in ``ts`` (``copy.copy`` is shallow).
    """
    # NOTE(review): 1750 looks like an arbitrary placeholder anchor year — confirm.
    yearly_start = pd.Period("1750", freq="Y")
    patched = copy.copy(ts)
    patched["start"] = yearly_start
    return patched


def load_dataset(dataset_name) -> TrainDatasets:
    """Load a GluonTS repository dataset by name, patching ``m3_other``.

    For every name other than ``"m3_other"`` the result of ``get_dataset`` is
    returned unchanged. For ``"m3_other"``, every train and test entry is
    replaced by a copy whose ``start`` is a yearly period (see
    ``fix_m3_other_start``). The entries are materialized into lists, and
    ``metadata.freq`` is set to ``"Y"``. The metadata object is the one
    returned by ``get_dataset`` and is modified in place.
    """
    dataset = get_dataset(dataset_name)
    if dataset_name != "m3_other":
        return dataset

    # m3_other provided by GluonTS has incorrect freq Q that should be replaced by Y
    dataset = TrainDatasets(
        metadata=dataset.metadata,
        train=[fix_m3_other_start(entry) for entry in dataset.train],
        test=[fix_m3_other_start(entry) for entry in dataset.test],
    )
    dataset.metadata.freq = "Y"
    return dataset