import copy
import multiprocessing as mp
import time
from typing import List

from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast, SampleForecast

from .abstract import AbstractPredictor


class AutoPyTorchPredictor(AbstractPredictor):
    """Wrapper around Auto-PyTorch's time series forecasting API.

    Runs a model and hyperparameter search within ``time_limit`` seconds
    and produces one forecast per series in the dataset.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: int = 6 * 60 * 60,
        optimize_metric: str = "mean_MASE_forecasting",
        seed: int = 1,
        **kwargs
    ):
        super().__init__(prediction_length, freq, seasonality)
        self.optimize_metric = optimize_metric
        self.run_time = time_limit
        self.seed = seed

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        from autoPyTorch.api.time_series_forecasting import TimeSeriesForecastingTask
        from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes

        # Extract training targets and series start times from the GluonTS dataset
        y_train = [item["target"] for item in dataset]
        start_times = [item["start"].to_timestamp(how="S") for item in dataset]

        api = TimeSeriesForecastingTask(
            seed=self.seed,
            ensemble_size=20,
            resampling_strategy=HoldoutValTypes.time_series_hold_out_validation,
            resampling_strategy_args=None,
        )
        api.set_pipeline_options(early_stopping=20, torch_num_threads=mp.cpu_count())

        fit_start_time = time.time()
        api.search(
            X_train=None,
            y_train=copy.deepcopy(y_train),
            optimize_metric=self.optimize_metric,
            n_prediction_steps=self.prediction_length,
            memory_limit=16 * 1024,
            freq="1" + self.freq,
            start_times=start_times,
            normalize_y=False,
            total_walltime_limit=self.run_time,
            min_num_test_instances=1000,
            budget_type="epochs",
            max_budget=50,
            min_budget=5,
        )
        # Skip refitting as this raises exceptions for all models as of v0.2.1:
        # refit_dataset = api.dataset.create_refit_set()
        # api.refit(refit_dataset, 0)

        # Predict for the test set
        test_sets = api.dataset.generate_test_seqs()
        predictions = api.predict(test_sets)
        self.save_runtime(time.time() - fit_start_time)
        # Convert the point forecasts into GluonTS SampleForecast objects;
        # the extra leading axis marks each prediction as a single sample.
        forecast_list = []
        for ts, pred in zip(dataset, predictions):
            forecast_list.append(
                SampleForecast(
                    samples=pred[None],
                    # The forecast starts one step after the end of the training target
                    start_date=ts["start"] + len(ts["target"]),
                    item_id=ts["item_id"],
                )
            )
        return forecast_list
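# The conversion loop above relies on two conventions that are easy to get
# wrong: SampleForecast expects samples shaped (num_samples, prediction_length),
# and adding an integer to a pandas Period advances it by that many periods.
# A minimal sketch of both, using toy values (no Auto-PyTorch or GluonTS needed):
#
# ```python
# import numpy as np
# import pandas as pd
#
# # A point forecast of length 4 (toy values, not from the real model)
# pred = np.array([1.0, 2.0, 3.0, 4.0])
#
# # `pred[None]` adds a leading axis, wrapping the point forecast
# # as a single "sample" of shape (1, prediction_length)
# samples = pred[None]
#
# # The forecast start is the period right after the training range:
# # adding len(target) steps to the series start period
# start = pd.Period("2020-01-01", freq="D")
# target_length = 30
# forecast_start = start + target_length
# ```

```python
import numpy as np
import pandas as pd

# A point forecast of length 4 (toy values, not from the real model)
pred = np.array([1.0, 2.0, 3.0, 4.0])

# `pred[None]` adds a leading axis, wrapping the point forecast
# as a single "sample" of shape (1, prediction_length)
samples = pred[None]
print(samples.shape)  # (1, 4)

# The forecast start is the period right after the training range:
# adding len(target) steps to the series start period
start = pd.Period("2020-01-01", freq="D")
target_length = 30
forecast_start = start + target_length
print(forecast_start)  # 2020-01-31
```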