|
import pandas as pd |
|
import copy |
|
from typing import Tuple |
|
from gluonts.dataset.common import TrainDatasets |
|
from gluonts.dataset.repository.datasets import get_dataset |
|
|
|
# Integer looked up per pandas-style frequency alias (yearly, quarterly,
# monthly, weekly, daily, hourly).
# NOTE(review): judging by the name, each value is presumably the seasonal
# period (number of steps per season) for that frequency -- confirm against
# the code that consumes this mapping.
SEASONALITY_MAP = {
    "Y": 1,
    "Q": 4,
    "M": 12,
    "W": 1,
    "D": 7,
    "H": 24,
}
|
|
|
|
|
def fix_m3_other_start(ts: dict) -> dict:
    """Return a shallow copy of *ts* whose ``"start"`` is the yearly period 1750.

    The input entry is not modified; all other fields of the copy refer to
    the same objects as in *ts* (shallow copy).

    NOTE(review): presumably the start timestamps shipped with the
    ``m3_other`` dataset are unusable and 1750 is an arbitrary placeholder --
    confirm with the dataset's documentation.
    """
    new_ts = copy.copy(ts)
    new_ts["start"] = pd.Period("1750", freq="Y")
    return new_ts
|
|
|
|
|
def load_dataset(dataset_name: str) -> TrainDatasets:
    """Load *dataset_name* via ``get_dataset``, patching ``m3_other`` on the way.

    For every name except ``"m3_other"`` the result of ``get_dataset`` is
    returned unchanged.  For ``"m3_other"`` each train and test entry is
    passed through ``fix_m3_other_start`` (which overrides its ``"start"``),
    the entries are materialised into lists, and the metadata frequency is
    set to ``"Y"``.

    Note that the metadata object returned by ``get_dataset`` is reused, so
    the ``freq`` override is applied to that same object in place.
    """
    data = get_dataset(dataset_name)

    if dataset_name == "m3_other":
        fixed_train = [fix_m3_other_start(ts) for ts in data.train]
        fixed_test = [fix_m3_other_start(ts) for ts in data.test]
        data = TrainDatasets(
            metadata=data.metadata, train=fixed_train, test=fixed_test
        )
        # Keep the declared frequency consistent with the yearly start
        # periods assigned by fix_m3_other_start.
        data.metadata.freq = "Y"

    return data
|
|