"""Dispatch helper that fits a named forecasting model and returns its predictions."""

from gluonts.dataset.common import Dataset

from .models import (
    AbstractPredictor,
    AutoGluonPredictor,
    AutoPyTorchPredictor,
    DeepARPredictor,
    TFTPredictor,
    AutoARIMAPredictor,
    AutoETSPredictor,
    AutoThetaPredictor,
    StatsEnsemblePredictor,
)

# Registry of supported models, keyed by lowercase name. Lookups in
# ``fit_predict_with_model`` lowercase the requested name first, so keys here
# must stay lowercase.
MODEL_NAME_TO_CLASS = {
    "autogluon": AutoGluonPredictor,
    "autopytorch": AutoPyTorchPredictor,
    "deepar": DeepARPredictor,
    "tft": TFTPredictor,
    "autoarima": AutoARIMAPredictor,
    "autoets": AutoETSPredictor,
    "autotheta": AutoThetaPredictor,
    "statsensemble": StatsEnsemblePredictor,
}


def fit_predict_with_model(
    model_name: str,
    dataset: Dataset,
    prediction_length: int,
    freq: str,
    seasonality: int,
    **model_kwargs,
):
    """Instantiate the named model, fit it on ``dataset`` and return its predictions.

    Args:
        model_name: Key into ``MODEL_NAME_TO_CLASS``; matched case-insensitively.
        dataset: GluonTS dataset passed unchanged to the model's ``fit_predict``.
        prediction_length: Forwarded to the model constructor.
        freq: Forwarded to the model constructor.
        seasonality: Forwarded to the model constructor.
        **model_kwargs: Extra keyword arguments forwarded to the model constructor.

    Returns:
        A ``(predictions, info)`` tuple. ``predictions`` is whatever the model's
        ``fit_predict(dataset)`` returns. ``info`` is a dict whose ``"run_time"``
        entry holds the value of ``model.get_runtime()``.

    Raises:
        KeyError: If ``model_name`` (lowercased) is not a registered model. The
            message lists the available names.
    """
    try:
        model_class = MODEL_NAME_TO_CLASS[model_name.lower()]
    except KeyError:
        # Keep KeyError for backward compatibility with existing callers, but
        # make the message actionable by listing the valid choices.
        available = ", ".join(sorted(MODEL_NAME_TO_CLASS))
        raise KeyError(
            f"Unknown model name {model_name!r}; available models: {available}"
        ) from None

    model: AbstractPredictor = model_class(
        prediction_length=prediction_length,
        freq=freq,
        seasonality=seasonality,
        **model_kwargs,
    )
    predictions = model.fit_predict(dataset)
    info = {"run_time": model.get_runtime()}
    return predictions, info