|
import click |
|
import datetime |
|
import pprint |
|
from typing import Optional |
|
|
|
|
|
from src import ( |
|
load_dataset, |
|
fit_predict_with_model, |
|
score_predictions, |
|
AVAILABLE_DATASETS, |
|
AVAILABLE_MODELS, |
|
SEASONALITY_MAP, |
|
) |
|
|
|
|
|
def apply_ablation(ablation: str, model_kwargs: dict) -> dict:
    """Adjust ``model_kwargs`` in place for the requested ablation and return it.

    ``"NoEnsemble"`` sets ``enable_ensemble`` to ``False``.  The three
    ``"No*Models"`` ablations overwrite ``hyperparameters`` with a mapping of
    the retained model names to empty config dicts.  Any other value leaves
    ``model_kwargs`` untouched.

    Args:
        ablation: Name of the ablation to apply.
        model_kwargs: Keyword arguments for the model; mutated in place.

    Returns:
        The same ``model_kwargs`` object that was passed in.
    """
    # Model families, grouped the way the ablation names refer to them.
    statistical = (
        "Naive",
        "SeasonalNaive",
        "ARIMA",
        "ETS",
        "AutoETS",
        "AutoARIMA",
        "Theta",
    )
    tabular = ("AutoGluonTabular",)
    deep = ("DeepAR", "SimpleFeedForward", "TemporalFusionTransformer")

    if ablation == "NoEnsemble":
        model_kwargs["enable_ensemble"] = False
        return model_kwargs

    # Each ablation keeps every family except the one it names.
    retained_by_ablation = {
        "NoDeepModels": statistical + tabular,
        "NoStatModels": tabular + deep,
        "NoTreeModels": statistical + deep,
    }
    retained = retained_by_ablation.get(ablation)
    if retained is not None:
        # A fresh (empty) config dict per model, so entries are never shared.
        model_kwargs["hyperparameters"] = {model: {} for model in retained}
    return model_kwargs
|
|
|
|
|
def _parse_extra_model_kwargs(extra_args):
    """Convert leftover ``--key value`` CLI tokens into a dict of model kwargs.

    Args:
        extra_args: Tokens click did not recognise (``ctx.args``).

    Returns:
        Mapping of option name (without the leading ``--``) to its raw string
        value.

    Raises:
        click.UsageError: If the tokens do not form ``--key value`` pairs.
    """
    if len(extra_args) % 2 != 0:
        raise click.UsageError(
            "Extra model arguments must be given as '--key value' pairs, "
            f"got: {' '.join(extra_args)}"
        )
    parsed = {}
    for flag, value in zip(extra_args[::2], extra_args[1::2]):
        # Reject tokens such as "-x" or "--": slicing them blindly would
        # yield a truncated or empty keyword name.
        if not flag.startswith("--") or len(flag) == 2:
            raise click.UsageError(
                f"Expected an option of the form '--key', got '{flag}'"
            )
        parsed[flag[2:]] = value
    return parsed


@click.command(
    context_settings=dict(
        # Unrecognised options are collected in ``ctx.args`` and forwarded to
        # the model instead of being rejected by click.
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
@click.option(
    "--dataset_name",
    "-d",
    # No ``required=True``: it has no effect when a default is supplied.
    default="m3_other",
    help="The dataset to train the model on",
    type=click.Choice(AVAILABLE_DATASETS),
)
@click.option(
    "--model_name",
    "-m",
    default="autogluon",
    help="Model to train",
    type=click.Choice(AVAILABLE_MODELS),
)
@click.option(
    "--eval_metric",
    "-e",
    default="MASE",
    type=click.Choice(["MASE", "mean_wQuantileLoss"]),
)
@click.option(
    "--seed",
    "-s",
    default=1,
    type=int,
)
@click.option(
    "--time_limit",
    "-t",
    default=4 * 3600,
    type=int,
)
@click.option(
    "--ablation",
    "-a",
    default=None,
    type=click.Choice(["NoEnsemble", "NoDeepModels", "NoStatModels", "NoTreeModels"]),
)
@click.pass_context
def main(
    ctx,
    dataset_name: str,
    model_name: str,
    eval_metric: str,
    seed: int,
    time_limit: int,
    ablation: Optional[str],
):
    """Train a model on a dataset and print its run time and test metrics.

    Any additional ``--key value`` pairs on the command line are forwarded
    to the model as keyword arguments (values are passed as strings).
    """
    print(f"Evaluating {model_name} on {dataset_name}")
    dataset = load_dataset(dataset_name)
    task_kwargs = {
        "prediction_length": dataset.metadata.prediction_length,
        "freq": dataset.metadata.freq,
        "eval_metric": eval_metric,
        "seasonality": SEASONALITY_MAP[dataset.metadata.freq],
    }
    print("Task definition:")
    pprint.pprint(task_kwargs)

    model_kwargs = _parse_extra_model_kwargs(ctx.args)
    # Assigned after parsing so the dedicated --seed / --time_limit options
    # take precedence over same-named extra arguments.
    model_kwargs["seed"] = seed
    model_kwargs["time_limit"] = time_limit

    if ablation is not None:
        # Raised explicitly rather than asserted: asserts are stripped when
        # Python runs with -O, which would silently skip this validation.
        if model_name != "autogluon":
            raise click.UsageError(f"{model_name} does not support ablations")
        model_kwargs = apply_ablation(ablation, model_kwargs)

    # model_kwargs always holds at least seed and time_limit, so it is never
    # empty at this point.
    print("Model kwargs:")
    pprint.pprint(model_kwargs)

    print(f"Starting training {datetime.datetime.now()}")

    predictions, info = fit_predict_with_model(
        model_name, dataset.train, **task_kwargs, **model_kwargs
    )

    metrics = score_predictions(
        dataset=dataset.test,
        predictions=predictions,
        prediction_length=task_kwargs["prediction_length"],
        seasonality=task_kwargs["seasonality"],
    )
    print("================================================")
    print(f"model: {model_name}")
    print(f"dataset: {dataset_name}")
    print(f"total_run_time: {info['run_time']:.2f}")
    print(f"mase: {metrics['MASE']:.4f}")
    print(f"mean_wQuantileLoss: {metrics['mean_wQuantileLoss']:.4f}")
|
|
|
|
|
if __name__ == "__main__":
    # Run the click command only when this file is executed as a script;
    # click reads the options from the process command line.
    main()
|
|