File size: 4,025 Bytes
a028d0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import click
import datetime
import pprint
from typing import Optional


from src import (
    load_dataset,
    fit_predict_with_model,
    score_predictions,
    AVAILABLE_DATASETS,
    AVAILABLE_MODELS,
    SEASONALITY_MAP,
)


def apply_ablation(ablation: str, model_kwargs: dict) -> dict:
    """Adjust ``model_kwargs`` so that the requested ablation is applied.

    ``"NoEnsemble"`` sets ``enable_ensemble`` to ``False``. The other three
    ablations overwrite ``hyperparameters`` with a mapping from every retained
    model name to an empty config dict, leaving out one family of models.
    Any other ``ablation`` value leaves ``model_kwargs`` untouched.

    Args:
        ablation: Name of the ablation to apply.
        model_kwargs: Keyword arguments destined for the model. This dict is
            modified in place.

    Returns:
        The same ``model_kwargs`` object that was passed in.
    """
    # Model families, listed in the order they appear in the resulting dict.
    statistical = (
        "Naive",
        "SeasonalNaive",
        "ARIMA",
        "ETS",
        "AutoETS",
        "AutoARIMA",
        "Theta",
    )
    tree_based = ("AutoGluonTabular",)
    deep_learning = ("DeepAR", "SimpleFeedForward", "TemporalFusionTransformer")

    if ablation == "NoEnsemble":
        model_kwargs["enable_ensemble"] = False
        return model_kwargs

    if ablation == "NoDeepModels":
        retained = statistical + tree_based
    elif ablation == "NoStatModels":
        retained = tree_based + deep_learning
    elif ablation == "NoTreeModels":
        retained = statistical + deep_learning
    else:
        # Unrecognised ablation names are a no-op.
        return model_kwargs

    # Every retained model gets its own fresh, empty hyperparameter dict.
    model_kwargs["hyperparameters"] = {model: {} for model in retained}
    return model_kwargs


def _parse_extra_model_kwargs(extra_args: list) -> dict:
    """Convert leftover CLI tokens ``--name value ...`` into ``{"name": "value"}``.

    Args:
        extra_args: Tokens that click did not match to a declared option.

    Returns:
        Mapping from option name (without the leading ``--``) to its value.
        Values are kept as the raw strings given on the command line.

    Raises:
        click.UsageError: If the tokens do not form ``--name value`` pairs.
    """
    if len(extra_args) % 2 != 0:
        raise click.UsageError(
            "Extra model arguments must be given as `--name value` pairs, got: "
            + " ".join(extra_args)
        )
    parsed = {}
    for flag, value in zip(extra_args[::2], extra_args[1::2]):
        # Reject tokens such as `-x` or a bare `--`, which would otherwise
        # yield an empty or truncated keyword name.
        if not flag.startswith("--") or len(flag) <= 2:
            raise click.UsageError(
                f"Expected an extra option of the form `--name`, got {flag!r}"
            )
        parsed[flag[2:]] = value
    return parsed


@click.command(
    context_settings=dict(
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
# NOTE(review): `required=True` is combined with a default below, so the
# default appears to satisfy the requirement -- confirm which was intended.
@click.option(
    "--dataset_name",
    "-d",
    required=True,
    default="m3_other",
    help="The dataset to train the model on",
    type=click.Choice(AVAILABLE_DATASETS),
)
@click.option(
    "--model_name",
    "-m",
    default="autogluon",
    help="Model to train",
    type=click.Choice(AVAILABLE_MODELS),
)
@click.option(
    "--eval_metric",
    "-e",
    default="MASE",
    type=click.Choice(["MASE", "mean_wQuantileLoss"]),
)
@click.option(
    "--seed",
    "-s",
    default=1,
    type=int,
)
@click.option(
    "--time_limit",
    "-t",
    default=4 * 3600,
    type=int,
)
@click.option(
    "--ablation",
    "-a",
    default=None,
    type=click.Choice(["NoEnsemble", "NoDeepModels", "NoStatModels", "NoTreeModels"]),
)
@click.pass_context
def main(
    ctx,
    dataset_name: str,
    model_name: str,
    eval_metric: str,
    seed: int,
    time_limit: int,
    ablation: Optional[str],
):
    """Fit a model on a dataset's training split and score it on the test split.

    Any additional `--name value` pairs on the command line are forwarded to
    the model as keyword arguments.
    """
    print(f"Evaluating {model_name} on {dataset_name}")
    dataset = load_dataset(dataset_name)
    task_kwargs = {
        "prediction_length": dataset.metadata.prediction_length,
        "freq": dataset.metadata.freq,
        "eval_metric": eval_metric,
        "seasonality": SEASONALITY_MAP[dataset.metadata.freq],
    }
    print("Task definition:")
    pprint.pprint(task_kwargs)

    # Additional command line arguments like `--name value` are parsed as {"name": "value"}
    model_kwargs = _parse_extra_model_kwargs(ctx.args)
    # The declared options take precedence over same-named extra arguments.
    model_kwargs["seed"] = seed
    model_kwargs["time_limit"] = time_limit

    if ablation is not None:
        # Raised explicitly rather than asserted, so the check survives `python -O`.
        if model_name != "autogluon":
            raise click.UsageError(f"{model_name} does not support ablations")
        model_kwargs = apply_ablation(ablation, model_kwargs)

    # model_kwargs always holds at least `seed` and `time_limit`, so it is never empty.
    print("Model kwargs:")
    pprint.pprint(model_kwargs)

    print(f"Starting training {datetime.datetime.now()}")

    predictions, info = fit_predict_with_model(
        model_name, dataset.train, **task_kwargs, **model_kwargs
    )

    metrics = score_predictions(
        dataset=dataset.test,
        predictions=predictions,
        prediction_length=task_kwargs["prediction_length"],
        seasonality=task_kwargs["seasonality"],
    )
    print("================================================")
    print(f"model: {model_name}")
    print(f"dataset: {dataset_name}")
    print(f"total_run_time: {info['run_time']:.2f}")
    print(f"mase: {metrics['MASE']:.4f}")
    print(f"mean_wQuantileLoss: {metrics['mean_wQuantileLoss']:.4f}")


# Run the CLI only when this file is executed as a script, not when it is
# imported; click parses the command line and supplies main()'s parameters.
if __name__ == "__main__":
    main()