# Residue from the Hugging Face file viewer (not part of the module itself):
# uploader "kashif" (HF staff), commit 45e60de "Upload 10 files", 874 bytes.
import pandas as pd
import copy
from typing import Tuple
from gluonts.dataset.common import TrainDatasets
from gluonts.dataset.repository.datasets import get_dataset
# Seasonal period length per frequency code. The keys are presumably
# pandas-style frequency aliases (yearly, quarterly, monthly, weekly, daily,
# hourly) -- confirm against whatever code looks values up in this map.
SEASONALITY_MAP = {
    "Y": 1,
    "Q": 4,
    "M": 12,
    "W": 1,
    "D": 7,
    "H": 24,
}
def fix_m3_other_start(ts: dict):
    """Return a shallow copy of *ts* with its ``"start"`` entry replaced.

    The copy's ``"start"`` is set to the yearly period 1750. The input
    mapping itself is left unmodified; because the copy is shallow, all
    other values are shared between the input and the returned object.
    """
    yearly_start = pd.Period("1750", freq="Y")
    patched = copy.copy(ts)
    patched["start"] = yearly_start
    return patched
def load_dataset(dataset_name) -> TrainDatasets:
    """Load a dataset by name via ``get_dataset``, patching ``m3_other``.

    For every name except ``"m3_other"`` the result of ``get_dataset`` is
    returned as-is. For ``"m3_other"`` each train and test entry is passed
    through ``fix_m3_other_start``, the dataset is rebuilt around the
    patched entries, and the metadata frequency is set to ``"Y"``.
    """
    dataset = get_dataset(dataset_name)
    if dataset_name != "m3_other":
        return dataset

    # m3_other provided by GluonTS has incorrect freq Q that should be
    # replaced by Y.
    patched_train = list(map(fix_m3_other_start, dataset.train))
    patched_test = list(map(fix_m3_other_start, dataset.test))
    dataset = TrainDatasets(
        metadata=dataset.metadata,
        train=patched_train,
        test=patched_test,
    )
    # The metadata object is reused from the loaded dataset, so its freq
    # field is overwritten in place.
    dataset.metadata.freq = "Y"
    return dataset