|
import copy |
|
import multiprocessing as mp |
|
import time |
|
from typing import List |
|
|
|
from gluonts.dataset.common import Dataset |
|
from gluonts.model.forecast import Forecast, SampleForecast |
|
|
|
from .abstract import AbstractPredictor |
|
|
|
|
|
class AutoPyTorchPredictor(AbstractPredictor):
    """Forecast with autoPyTorch's ``TimeSeriesForecastingTask``.

    Runs an autoPyTorch search on the training series of a GluonTS dataset and
    wraps the resulting predictions as single-sample ``SampleForecast`` objects.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: int = 6 * 60 * 60,
        optimize_metric: str = "mean_MASE_forecasting",
        seed: int = 1,
        **kwargs
    ):
        """
        Args:
            prediction_length: Number of future steps to forecast; forwarded to
                ``AbstractPredictor`` and to autoPyTorch as ``n_prediction_steps``.
            freq: Pandas-style frequency string of the series, forwarded to
                ``AbstractPredictor``.
            seasonality: Seasonal period, forwarded to ``AbstractPredictor``.
            time_limit: Total wall-clock budget of the autoPyTorch search, passed
                as ``total_walltime_limit`` (seconds; default is 6 hours).
            optimize_metric: Name of the autoPyTorch metric to optimize.
            seed: Random seed for the autoPyTorch task.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__(prediction_length, freq, seasonality)
        self.optimize_metric = optimize_metric
        self.run_time = time_limit
        self.seed = seed

    @staticmethod
    def _to_autopytorch_freq(freq: str) -> str:
        """Return ``freq`` with an explicit multiplier for autoPyTorch.

        A bare unit such as ``"H"`` becomes ``"1H"``. A string that already
        carries a multiplier (e.g. ``"15min"``) is returned unchanged instead of
        being corrupted into ``"115min"``.
        """
        return freq if freq[:1].isdigit() else "1" + freq

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Run the autoPyTorch search on ``dataset`` and forecast every series.

        Args:
            dataset: GluonTS dataset. Each entry must provide ``"target"`` and
                ``"start"`` (a ``pandas.Period``-like value supporting
                ``to_timestamp`` and integer addition); ``"item_id"`` is optional.

        Returns:
            One ``SampleForecast`` per input series, in input order, each holding
            a single sample path that starts right after the series' last point.

        Raises:
            ValueError: If autoPyTorch returns a different number of predictions
                than there are series in ``dataset``.
        """
        from autoPyTorch.api.time_series_forecasting import TimeSeriesForecastingTask
        from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes

        # Materialize once: the entries are read several times below, and a
        # one-shot iterable would otherwise be exhausted after the first pass.
        entries = list(dataset)
        y_train = [entry["target"] for entry in entries]
        start_times = [entry["start"].to_timestamp(how="S") for entry in entries]

        api = TimeSeriesForecastingTask(
            seed=self.seed,
            ensemble_size=20,
            resampling_strategy=HoldoutValTypes.time_series_hold_out_validation,
            resampling_strategy_args=None,
        )
        # Let torch use every available CPU core.
        api.set_pipeline_options(early_stopping=20, torch_num_threads=mp.cpu_count())

        # The recorded runtime covers both the search and the prediction step.
        fit_start_time = time.time()
        api.search(
            X_train=None,
            # Pass a copy, presumably so autoPyTorch cannot modify the target
            # arrays that are still used below to compute forecast start dates.
            y_train=copy.deepcopy(y_train),
            optimize_metric=self.optimize_metric,
            n_prediction_steps=self.prediction_length,
            memory_limit=16 * 1024,  # autoPyTorch expects MB, i.e. 16 GB
            freq=self._to_autopytorch_freq(self.freq),
            start_times=start_times,
            normalize_y=False,
            total_walltime_limit=self.run_time,
            min_num_test_instances=1000,
            budget_type="epochs",
            max_budget=50,
            min_budget=5,
        )

        test_sets = api.dataset.generate_test_seqs()
        predictions = api.predict(test_sets)
        self.save_runtime(time.time() - fit_start_time)

        # zip() would silently drop series on a count mismatch; fail loudly instead.
        if len(predictions) != len(entries):
            raise ValueError(
                f"autoPyTorch returned {len(predictions)} predictions "
                f"for {len(entries)} time series"
            )

        return [
            SampleForecast(
                # Add a leading sample axis so each prediction becomes a single
                # sample path (assumes ``pred`` is a numpy array — confirm).
                samples=pred[None],
                # The forecast starts one step after the last observed point.
                start_date=entry["start"] + len(entry["target"]),
                item_id=entry.get("item_id"),
            )
            for entry, pred in zip(entries, predictions)
        ]
|
|