kashif HF staff commited on
Commit
49f2f3a
1 Parent(s): b359e4d

added initial evaluation

Browse files
Files changed (1) hide show
  1. app.py +44 -17
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  from gluonts.dataset.pandas import PandasDataset
4
  from gluonts.dataset.split import split
5
  from gluonts.torch.model.deepar import DeepAREstimator
 
6
 
7
  from make_plot import plot_forecast, plot_train_test
8
 
@@ -54,16 +55,29 @@ def train_and_forecast(
54
 
55
  training_data, test_gen = split(gluon_df, offset=row_offset)
56
 
57
- model = DeepAREstimator(
58
- prediction_length=prediction_length,
59
- freq=gluon_df.freq,
60
- trainer_kwargs=dict(max_epochs=epochs),
61
- ).train(
62
- training_data=training_data,
63
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- test_data = test_gen.generate_instances(prediction_length=prediction_length, windows=rolling_windows)
66
- forecasts = list(model.predict(test_data.input))
67
  return plot_forecast(df, forecasts)
68
 
69
 
@@ -87,19 +101,32 @@ with gr.Blocks() as demo:
87
  1. Click **Upload** to upload your data
88
  2. Click **Run**
89
  - This app will visualize your data and then train an estimator and show its predictions
90
- """)
91
- with gr.Accordion(label='Hyperparameters'):
 
92
  with gr.Row():
93
- prediction_length = gr.Number(value=12, label='Prediction Length', precision=0)
94
- windows = gr.Number(value=3, label='Number of Windows', precision=0)
95
- epochs = gr.Number(value=10, label='Number of Epochs', precision=0)
96
- with gr.Row():
 
 
 
 
97
  upload_btn = gr.UploadButton(label="Upload")
98
  train_btn = gr.Button(label="Train and Forecast")
99
  plot = gr.Plot()
100
 
101
- upload_btn.upload(fn=preprocess, inputs=[upload_btn, prediction_length, windows], outputs=plot)
102
- train_btn.click(fn=train_and_forecast, inputs=[upload_btn, prediction_length, windows, epochs], outputs=plot)
 
 
 
 
 
 
 
 
103
 
104
  if __name__ == "__main__":
105
  demo.queue().launch()
 
3
  from gluonts.dataset.pandas import PandasDataset
4
  from gluonts.dataset.split import split
5
  from gluonts.torch.model.deepar import DeepAREstimator
6
+ from gluonts.evaluation import Evaluator, make_evaluation_predictions
7
 
8
  from make_plot import plot_forecast, plot_train_test
9
 
 
55
 
56
  training_data, test_gen = split(gluon_df, offset=row_offset)
57
 
58
+ estimator = DeepAREstimator(
59
+ prediction_length=prediction_length,
60
+ freq=gluon_df.freq,
61
+ trainer_kwargs=dict(max_epochs=epochs),
62
+ )
63
+
64
+ predictor = estimator.train(
65
+ training_data=training_data,
66
+ )
67
+
68
+ test_data = test_gen.generate_instances(
69
+ prediction_length=prediction_length, windows=rolling_windows
70
+ )
71
+
72
+ evaluator = Evaluator(num_workers=0)
73
+ forecast_it, ts_it = make_evaluation_predictions(
74
+ dataset=test_data.input, predictor=predictor
75
+ )
76
+
77
+ forecasts = list(predictor.predict(test_data.input))
78
+
79
+ agg_metrics, _ = evaluator(ts_it, forecast_it)
80
 
 
 
81
  return plot_forecast(df, forecasts)
82
 
83
 
 
101
  1. Click **Upload** to upload your data
102
  2. Click **Run**
103
  - This app will visualize your data and then train an estimator and show its predictions
104
+ """
105
+ )
106
+ with gr.Accordion(label="Hyperparameters"):
107
  with gr.Row():
108
+ prediction_length = gr.Number(
109
+ value=12, label="Prediction Length", precision=0
110
+ )
111
+ windows = gr.Number(value=3, label="Number of Windows", precision=0)
112
+ epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
113
+
114
+ with gr.Row(label="Dataset"):
115
+ item_id = gr.Textbox(label="Item ID")
116
  upload_btn = gr.UploadButton(label="Upload")
117
  train_btn = gr.Button(label="Train and Forecast")
118
  plot = gr.Plot()
119
 
120
+ upload_btn.upload(
121
+ fn=preprocess,
122
+ inputs=[upload_btn, prediction_length, windows, item_id],
123
+ outputs=plot,
124
+ )
125
+ train_btn.click(
126
+ fn=train_and_forecast,
127
+ inputs=[upload_btn, prediction_length, windows, epochs],
128
+ outputs=plot,
129
+ )
130
 
131
  if __name__ == "__main__":
132
  demo.queue().launch()