|
import time |
|
from typing import List |
|
import pandas as pd |
|
|
|
from gluonts.dataset.common import Dataset |
|
from gluonts.model.forecast import Forecast, QuantileForecast |
|
|
|
from .abstract import AbstractPredictor |
|
|
|
|
|
class StatsForecastPredictor(AbstractPredictor):
    """Base predictor that fits StatsForecast models and returns GluonTS forecasts.

    Subclasses supply the models by implementing ``_get_models``, which must
    return a list of StatsForecast model instances. When several models are
    returned, each quantile of the resulting forecast is the median of that
    quantile across the models.
    """

    def __init__(self, prediction_length: int, freq: str, seasonality: int, **kwargs):
        # Extra keyword arguments are accepted but not used by this class.
        super().__init__(prediction_length, freq, seasonality)

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Fit the models on every series of ``dataset`` and forecast ahead.

        Parameters
        ----------
        dataset
            GluonTS dataset; each entry must provide ``target``, ``start`` and
            ``item_id``.

        Returns
        -------
        List[Forecast]
            One ``QuantileForecast`` per entry of ``dataset``, in dataset order.
        """
        from statsforecast import StatsForecast
        from statsforecast.models import SeasonalNaive

        df = self._to_statsforecast_df(dataset)
        models = self._get_models()
        predictor = StatsForecast(
            df=df,
            freq=self.freq,
            models=models,
            # Per the StatsForecast docs, used for a series if a model fails on it.
            fallback_model=SeasonalNaive(season_length=self.seasonality),
            n_jobs=-1,
        )

        start_time = time.time()
        predictions_df = predictor.forecast(
            h=self.prediction_length, level=self._prediction_levels()
        )
        self.save_runtime(time.time() - start_time)

        return self._predictions_df_to_gluonts_forecast(
            predictions_df, dataset, model_names=[str(m) for m in models]
        )

    @staticmethod
    def _quantile_to_level(q: float) -> int:
        """Return the central interval level (in percent) whose bound is quantile ``q``.

        E.g. q=0.1 and q=0.9 both map to 80; q=0.5 maps to 0.
        """
        # round() rather than int() so float noise (e.g. 79.999...) cannot
        # truncate to the wrong level.
        return round(abs(200 * q - 100))

    def _prediction_levels(self) -> List[int]:
        """Return the sorted interval levels needed to cover ``self.quantile_levels``."""
        return sorted({self._quantile_to_level(q) for q in self.quantile_levels})

    def _predictions_df_to_gluonts_forecast(
        self,
        predictions_df: pd.DataFrame,
        dataset: Dataset,
        model_names: List[str],
    ) -> List[Forecast]:
        """Convert StatsForecast output into GluonTS ``QuantileForecast`` objects.

        For each quantile ``q`` in ``self.quantile_levels`` the bound column
        ``<model>-lo-<level>`` (q < 0.5) or ``<model>-hi-<level>`` (q >= 0.5)
        is read for every model in ``model_names`` and the median across
        models becomes the ``str(q)`` forecast key.

        Parameters
        ----------
        predictions_df
            Output of ``StatsForecast.forecast``; must contain a ``ds`` column
            and be indexed (or have a column) by ``unique_id``.
        dataset
            The dataset that was forecast; its ``item_id`` values select rows.
        model_names
            Column-name prefixes of the models whose quantiles are combined.

        Returns
        -------
        List[Forecast]
            One forecast per entry of ``dataset``, in dataset order.
        """
        # Some StatsForecast versions return ``unique_id`` as a column rather
        # than the index; normalise so the label lookups below work either way.
        if "unique_id" in predictions_df.columns:
            predictions_df = predictions_df.set_index("unique_id")

        def quantile_to_suffix(q: float) -> str:
            prefix = "-lo-" if q < 0.5 else "-hi-"
            return prefix + str(self._quantile_to_level(q))

        columns = {}
        for q in self.quantile_levels:
            suffix = quantile_to_suffix(q)
            columns[str(q)] = predictions_df[
                [name + suffix for name in model_names]
            ].median(axis=1)
        forecast_df = pd.DataFrame(columns)

        forecast_list = []
        for ts in dataset:
            item_id = ts["item_id"]
            # Select with a list so a DataFrame / Series is returned even when
            # the item has a single forecast row (prediction_length == 1).
            f = forecast_df.loc[[item_id]]
            first_ds = predictions_df["ds"].loc[[item_id]].iloc[0]
            forecast_list.append(
                QuantileForecast(
                    forecast_arrays=f.values.T,
                    forecast_keys=list(f.columns),
                    start_date=pd.Period(first_ds, freq=self.freq),
                    item_id=item_id,
                )
            )
        return forecast_list

    def _to_statsforecast_df(self, dataset: Dataset) -> pd.DataFrame:
        """Convert a GluonTS dataset into StatsForecast's long format.

        Returns a DataFrame with columns ``unique_id``, ``ds`` (timestamps
        generated from each entry's ``start`` with ``self.freq``) and ``y``
        (the target values), one row per observation.
        """
        dfs = []
        for item in dataset:
            target = item["target"]
            timestamps = pd.date_range(
                start=item["start"].to_timestamp(how="S"),
                periods=len(target),
                freq=self.freq,
            )
            dfs.append(
                pd.DataFrame(
                    {
                        "unique_id": [item["item_id"]] * len(target),
                        "ds": timestamps,
                        "y": target,
                    }
                )
            )
        # ignore_index avoids repeating 0..len-1 index labels for every series.
        return pd.concat(dfs, ignore_index=True)
|
|
|
|
|
class AutoARIMAPredictor(StatsForecastPredictor):
    """StatsForecast predictor backed by a single AutoARIMA model."""

    def _get_models(self):
        """Return the models to fit: one AutoARIMA using ``self.seasonality``."""
        from statsforecast.models import AutoARIMA

        arima = AutoARIMA(season_length=self.seasonality)
        return [arima]
|
|
|
|
|
class AutoETSPredictor(StatsForecastPredictor):
    """StatsForecast predictor backed by a single AutoETS model."""

    def _get_models(self):
        """Return the models to fit: one AutoETS using ``self.seasonality``."""
        from statsforecast.models import AutoETS

        ets = AutoETS(season_length=self.seasonality)
        return [ets]
|
|
|
|
|
class AutoThetaPredictor(StatsForecastPredictor):
    """StatsForecast predictor backed by a single AutoTheta model."""

    def _get_models(self):
        """Return the models to fit: one AutoTheta using ``self.seasonality``."""
        from statsforecast.models import AutoTheta

        theta = AutoTheta(season_length=self.seasonality)
        return [theta]
|
|
|
|
|
class StatsEnsemblePredictor(StatsForecastPredictor):
    """Ensemble of AutoETS, AutoTheta and AutoARIMA.

    The base class combines the members by taking, per quantile, the median
    across models.
    """

    def _get_models(self):
        """Return one instance of each ensemble member, all sharing ``self.seasonality``."""
        from statsforecast.models import AutoARIMA, AutoETS, AutoTheta

        # Order kept as ETS, Theta, ARIMA.
        member_classes = (AutoETS, AutoTheta, AutoARIMA)
        return [cls(season_length=self.seasonality) for cls in member_classes]
|
|