time-series-score / src /models /statsforecast.py
kashif's picture
kashif HF staff
Upload 10 files
45e60de
raw
history blame
4.28 kB
import time
from typing import List
import pandas as pd
from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast, QuantileForecast
from .abstract import AbstractPredictor
class StatsForecastPredictor(AbstractPredictor):
    """Base predictor that runs StatsForecast models and returns GluonTS forecasts.

    Subclasses choose the models by implementing ``_get_models``. StatsForecast
    is asked for exactly the symmetric interval levels whose bounds are the
    quantiles in ``self.quantile_levels``. When several models are returned,
    each quantile forecast is the median across models.
    """

    def __init__(self, prediction_length: int, freq: str, seasonality: int, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility but ignored.
        super().__init__(prediction_length, freq, seasonality)

    def _get_models(self) -> list:
        """Return the StatsForecast model instances to run (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _quantile_to_level(q: float) -> int:
        """Map quantile ``q`` to the StatsForecast interval level bounded by it.

        ``q`` and ``1 - q`` are the lower/upper bounds of the ``|200*q - 100|``%
        interval, e.g. 0.1 and 0.9 -> 80, 0.5 -> 0. Rounding (instead of
        truncating with ``int``) guards against float noise such as
        ``100 - 200 * 0.30000000000000004 == 39.999...``.

        Raises:
            ValueError: if ``q`` does not correspond to an integer level, since
                the output column names are built from integer levels.
        """
        level = abs(200 * q - 100)
        rounded = round(level)
        if abs(level - rounded) > 1e-6:
            raise ValueError(
                f"Quantile {q} maps to non-integer interval level {level}; "
                "only quantiles with integer levels are supported."
            )
        return int(rounded)

    @classmethod
    def _quantile_to_suffix(cls, q: float) -> str:
        """Return the column suffix (``-lo-<level>`` / ``-hi-<level>``) holding quantile ``q``."""
        prefix = "-lo-" if q < 0.5 else "-hi-"
        return prefix + str(cls._quantile_to_level(q))

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Run the models on every series in ``dataset`` and forecast ``prediction_length`` steps.

        Only the ``StatsForecast.forecast`` call is timed and reported through
        ``save_runtime``.

        Returns:
            One ``QuantileForecast`` per entry of ``dataset``, in dataset order.
        """
        from statsforecast import StatsForecast
        from statsforecast.models import SeasonalNaive

        df = self._to_statsforecast_df(dataset)
        models = self._get_models()
        predictor = StatsForecast(
            df=df,
            freq=self.freq,
            models=models,
            # NOTE(review): per the StatsForecast docs this model is used when a
            # primary model fails on a series — confirm for the pinned version.
            fallback_model=SeasonalNaive(season_length=self.seasonality),
            n_jobs=-1,
        )
        # Request exactly the interval levels whose bounds are the target quantiles
        # (for quantiles 0.1..0.9 this is [0, 20, 40, 60, 80]).
        levels = sorted({self._quantile_to_level(q) for q in self.quantile_levels})
        start_time = time.perf_counter()
        predictions_df = predictor.forecast(h=self.prediction_length, level=levels)
        self.save_runtime(time.perf_counter() - start_time)
        return self._predictions_df_to_gluonts_forecast(
            predictions_df, dataset, model_names=[str(m) for m in models]
        )

    def _predictions_df_to_gluonts_forecast(
        self,
        predictions_df: pd.DataFrame,
        dataset: Dataset,
        model_names: List[str],
    ) -> List[Forecast]:
        """Convert StatsForecast output into one ``QuantileForecast`` per series.

        Args:
            predictions_df: output of ``StatsForecast.forecast``. It has a ``ds``
                column and ``<model><suffix>`` interval columns, and its rows are
                looked up by ``item_id`` via ``.loc``. This presumes the index holds
                the series ids, as in the StatsForecast version this targets
                (confirm when upgrading).
            dataset: the GluonTS dataset that was forecast; fixes the output order.
            model_names: column prefixes of the models. Each quantile is the
                median of the matching columns across these models.

        Returns:
            Forecasts in ``dataset`` order, keyed by ``str(q)`` for each quantile.
        """
        # Convert StatsForecast output -> DataFrame with quantile_levels as outputs.
        columns = {}
        for q in self.quantile_levels:
            suffix = self._quantile_to_suffix(q)
            columns[str(q)] = predictions_df[[m + suffix for m in model_names]].median(
                axis=1
            )
        forecast_df = pd.DataFrame(columns)

        # Convert quantiles DataFrame -> list of QuantileForecasts.
        forecast_list = []
        for ts in dataset:
            item_id = ts["item_id"]
            # Index with a list so a single-row item (prediction_length == 1)
            # still yields a DataFrame / Series instead of a row Series / scalar.
            f = forecast_df.loc[[item_id]]
            first_ds = predictions_df["ds"].loc[[item_id]].iloc[0]
            forecast_list.append(
                QuantileForecast(
                    forecast_arrays=f.values.T,
                    forecast_keys=list(f.columns),
                    start_date=pd.Period(first_ds, freq=self.freq),
                    item_id=item_id,
                )
            )
        return forecast_list

    def _to_statsforecast_df(self, dataset: Dataset) -> pd.DataFrame:
        """Convert a GluonTS dataset into StatsForecast's long format.

        Produces one row per observation with columns ``unique_id`` (the entry's
        ``item_id``), ``ds`` (timestamps at ``self.freq`` starting from the
        entry's ``start`` period) and ``y`` (the target value).

        Raises:
            ValueError: if ``dataset`` has no entries.
        """
        dfs = []
        for item in dataset:
            target = item["target"]
            timestamps = pd.date_range(
                start=item["start"].to_timestamp(how="S"),
                periods=len(target),
                freq=self.freq,
            )
            dfs.append(
                pd.DataFrame(
                    {
                        "unique_id": [item["item_id"]] * len(target),
                        "ds": timestamps,
                        "y": target,
                    }
                )
            )
        if not dfs:
            raise ValueError("Cannot build a StatsForecast DataFrame from an empty dataset.")
        # ignore_index avoids the duplicate 0..n-1 labels each per-item frame carries.
        return pd.concat(dfs, ignore_index=True)
class AutoARIMAPredictor(StatsForecastPredictor):
    """StatsForecast predictor that uses a single ``AutoARIMA`` model."""

    def _get_models(self):
        """Return a one-element model list: ``AutoARIMA`` with this predictor's seasonality."""
        from statsforecast import models as sf_models

        arima = sf_models.AutoARIMA(season_length=self.seasonality)
        return [arima]
class AutoETSPredictor(StatsForecastPredictor):
    """StatsForecast predictor that uses a single ``AutoETS`` model."""

    def _get_models(self):
        """Return a one-element model list: ``AutoETS`` with this predictor's seasonality."""
        from statsforecast import models as sf_models

        ets = sf_models.AutoETS(season_length=self.seasonality)
        return [ets]
class AutoThetaPredictor(StatsForecastPredictor):
    """StatsForecast predictor that uses a single ``AutoTheta`` model."""

    def _get_models(self):
        """Return a one-element model list: ``AutoTheta`` with this predictor's seasonality."""
        from statsforecast import models as sf_models

        theta = sf_models.AutoTheta(season_length=self.seasonality)
        return [theta]
class StatsEnsemblePredictor(StatsForecastPredictor):
    """StatsForecast predictor that ensembles ``AutoETS``, ``AutoTheta`` and ``AutoARIMA``.

    All three models share this predictor's seasonality.
    """

    def _get_models(self):
        """Return the ensemble members, in the order AutoETS, AutoTheta, AutoARIMA."""
        from statsforecast import models as sf_models

        member_classes = (sf_models.AutoETS, sf_models.AutoTheta, sf_models.AutoARIMA)
        return [model_cls(season_length=self.seasonality) for model_cls in member_classes]