# Source: time-series-score/src/models/autopytorch.py
# Uploaded by kashif (HF staff), commit 45e60de ("Upload 10 files"), 2.65 kB.
import copy
import multiprocessing as mp
import time
from typing import List
from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast, SampleForecast
from .abstract import AbstractPredictor
class AutoPyTorchPredictor(AbstractPredictor):
    """Predictor that runs an Auto-PyTorch time series forecasting search.

    Parameters
    ----------
    prediction_length
        Number of future steps to forecast; passed to the search as
        ``n_prediction_steps``.
    freq
        Frequency string of the series, with or without a leading
        multiplier (e.g. ``"H"`` or ``"15min"``).
    seasonality
        Forwarded to ``AbstractPredictor``; not used directly in this class.
    time_limit
        Value passed to the search as ``total_walltime_limit``
        (default ``6 * 60 * 60``, i.e. six hours expressed in seconds).
    optimize_metric
        Name of the metric the search optimizes.
    seed
        Seed passed to ``TimeSeriesForecastingTask``.
    **kwargs
        Accepted for interface compatibility and ignored.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: int = 6 * 60 * 60,
        optimize_metric: str = "mean_MASE_forecasting",
        seed: int = 1,
        **kwargs
    ):
        super().__init__(prediction_length, freq, seasonality)
        self.optimize_metric = optimize_metric
        self.run_time = time_limit
        self.seed = seed

    def _autopytorch_freq(self) -> str:
        """Return ``self.freq`` with an explicit multiplier (``"H"`` -> ``"1H"``).

        A frequency that already starts with a digit (``"15min"``) is returned
        unchanged; blindly prefixing ``"1"`` would turn it into ``"115min"``.
        """
        freq = self.freq
        if freq[:1].isdigit():
            return freq
        return "1" + freq

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Run the Auto-PyTorch search on ``dataset`` and forecast each series.

        Parameters
        ----------
        dataset
            Iterable of entries providing ``"target"``, ``"start"`` and
            ``"item_id"``.

        Returns
        -------
        List[Forecast]
            One ``SampleForecast`` per entry, holding a single sample path
            that starts right after the last observed value of that entry.
        """
        # Imported here rather than at module level, presumably so that the
        # module can be loaded without Auto-PyTorch installed.
        from autoPyTorch.api.time_series_forecasting import TimeSeriesForecastingTask
        from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes

        # Materialize once: the entries are needed for the targets, the start
        # times and the final pairing with predictions, and iterating the
        # dataset repeatedly may be expensive or impossible for one-shot
        # iterables.
        entries = list(dataset)
        y_train = [entry["target"] for entry in entries]
        start_times = [entry["start"].to_timestamp(how="S") for entry in entries]

        api = TimeSeriesForecastingTask(
            seed=self.seed,
            ensemble_size=20,
            resampling_strategy=HoldoutValTypes.time_series_hold_out_validation,
            resampling_strategy_args=None,
        )
        api.set_pipeline_options(early_stopping=20, torch_num_threads=mp.cpu_count())

        fit_start_time = time.time()
        api.search(
            X_train=None,
            # Pass a copy so the search cannot modify the caller's arrays.
            y_train=copy.deepcopy(y_train),
            optimize_metric=self.optimize_metric,
            n_prediction_steps=self.prediction_length,
            memory_limit=16 * 1024,
            freq=self._autopytorch_freq(),
            start_times=start_times,
            normalize_y=False,
            total_walltime_limit=self.run_time,
            min_num_test_instances=1000,
            budget_type="epochs",
            max_budget=50,
            min_budget=5,
        )

        # Refitting is skipped as this raises exceptions for all models as of v0.2.1:
        # refit_dataset = api.dataset.create_refit_set()
        # api.refit(refit_dataset, 0)

        # Predict for the test set.
        test_sets = api.dataset.generate_test_seqs()
        predictions = api.predict(test_sets)
        # The recorded runtime covers both the search and the prediction step.
        self.save_runtime(time.time() - fit_start_time)

        return [
            SampleForecast(
                # Prepend an axis so the prediction is stored as one sample path.
                samples=pred[None],
                start_date=entry["start"] + len(entry["target"]),
                item_id=entry["item_id"],
            )
            for entry, pred in zip(entries, predictions)
        ]