kashif's picture
kashif HF staff
fix unused import
4d941d1
raw
history blame contribute delete
No virus
849 Bytes
import pandas as pd
import copy
from gluonts.dataset.common import TrainDatasets
from gluonts.dataset.repository.datasets import get_dataset
# Seasonal period (number of observations per seasonal cycle), keyed by
# frequency alias: e.g. 12 for monthly data, 24 for hourly data, 7 for daily
# data (weekly cycle). Yearly and weekly data are treated as non-seasonal (1).
# NOTE(review): not referenced in this file; presumably consumed by callers
# (seasonal baselines / scaled metrics) — confirm before changing values.
SEASONALITY_MAP = dict(
    Y=1,
    Q=4,
    M=12,
    W=1,
    D=7,
    H=24,
)
def fix_m3_other_start(ts: dict):
    """Return a copy of the data entry ``ts`` with a yearly ``"start"``.

    The copy is shallow. ``ts`` itself is left untouched, but every other
    value (for example the target array) is shared with the input rather
    than duplicated. Every entry receives the same anchor,
    ``pd.Period("1750", freq="Y")``, so the original start is discarded.
    """
    # Fixed yearly anchor shared by all patched entries (value presumably
    # arbitrary; only the yearly frequency matters — confirm with callers).
    yearly_anchor = pd.Period("1750", freq="Y")
    patched = copy.copy(ts)  # shallow: keeps the caller's dict unmodified
    patched["start"] = yearly_anchor
    return patched
def load_dataset(dataset_name) -> TrainDatasets:
    """Load a dataset from the GluonTS repository by name.

    Datasets are returned exactly as ``get_dataset`` produces them, with one
    exception. For ``"m3_other"``, the copy provided by GluonTS is labelled
    with frequency ``"Q"`` although it should be ``"Y"``. For that dataset:

    - every train and test entry is rewritten with ``fix_m3_other_start``;
    - both splits are materialized into lists;
    - ``metadata.freq`` is set to ``"Y"``.

    The metadata object is reused from the loaded dataset and mutated in
    place.
    """
    data = get_dataset(dataset_name)
    if dataset_name != "m3_other":
        return data

    def _patch_split(entries):
        # Re-anchor every entry of one split onto a yearly start.
        return [fix_m3_other_start(entry) for entry in entries]

    # m3_other provided by GluonTS has incorrect freq Q that should be replaced by Y
    fixed = TrainDatasets(
        metadata=data.metadata,
        train=_patch_split(data.train),
        test=_patch_split(data.test),
    )
    fixed.metadata.freq = "Y"
    return fixed