kashif HF staff commited on
Commit
74d9e70
1 Parent(s): 49f2f3a

added evaluation metrics

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -19,7 +19,6 @@ def preprocess(
19
  input_data,
20
  prediction_length,
21
  rolling_windows,
22
- item_id,
23
  progress=gr.Progress(track_tqdm=True),
24
  ):
25
  df = pd.read_csv(input_data.name, index_col=0, parse_dates=True)
@@ -73,12 +72,11 @@ def train_and_forecast(
73
  forecast_it, ts_it = make_evaluation_predictions(
74
  dataset=test_data.input, predictor=predictor
75
  )
 
76
 
77
  forecasts = list(predictor.predict(test_data.input))
78
 
79
- agg_metrics, _ = evaluator(ts_it, forecast_it)
80
-
81
- return plot_forecast(df, forecasts)
82
 
83
 
84
  with gr.Blocks() as demo:
@@ -112,20 +110,20 @@ with gr.Blocks() as demo:
112
  epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
113
 
114
  with gr.Row(label="Dataset"):
115
- item_id = gr.Textbox(label="Item ID")
116
  upload_btn = gr.UploadButton(label="Upload")
117
  train_btn = gr.Button(label="Train and Forecast")
118
  plot = gr.Plot()
 
119
 
120
  upload_btn.upload(
121
  fn=preprocess,
122
- inputs=[upload_btn, prediction_length, windows, item_id],
123
  outputs=plot,
124
  )
125
  train_btn.click(
126
  fn=train_and_forecast,
127
  inputs=[upload_btn, prediction_length, windows, epochs],
128
- outputs=plot,
129
  )
130
 
131
  if __name__ == "__main__":
 
19
  input_data,
20
  prediction_length,
21
  rolling_windows,
 
22
  progress=gr.Progress(track_tqdm=True),
23
  ):
24
  df = pd.read_csv(input_data.name, index_col=0, parse_dates=True)
 
72
  forecast_it, ts_it = make_evaluation_predictions(
73
  dataset=test_data.input, predictor=predictor
74
  )
75
+ agg_metrics, _ = evaluator(ts_it, forecast_it)
76
 
77
  forecasts = list(predictor.predict(test_data.input))
78
 
79
+ return plot_forecast(df, forecasts, agg_metrics)
 
 
80
 
81
 
82
  with gr.Blocks() as demo:
 
110
  epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
111
 
112
  with gr.Row(label="Dataset"):
 
113
  upload_btn = gr.UploadButton(label="Upload")
114
  train_btn = gr.Button(label="Train and Forecast")
115
  plot = gr.Plot()
116
+ json = gr.JSON(label="Evaluation Metrics")
117
 
118
  upload_btn.upload(
119
  fn=preprocess,
120
+ inputs=[upload_btn, prediction_length, windows],
121
  outputs=plot,
122
  )
123
  train_btn.click(
124
  fn=train_and_forecast,
125
  inputs=[upload_btn, prediction_length, windows, epochs],
126
+ outputs=[plot, json],
127
  )
128
 
129
  if __name__ == "__main__":