kashif's picture
kashif HF staff
fix typo
eed7913
raw
history blame
No virus
6.16 kB
import os
import gradio as gr
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch.model.deepar import DeepAREstimator
from gluonts.torch.distributions import (
NegativeBinomialOutput,
StudentTOutput,
NormalOutput,
)
from gluonts.evaluation import Evaluator, make_evaluation_predictions
from make_plot import plot_forecast, plot_train_test
def offset_calculation(prediction_length, rolling_windows, length):
    """Return the negative row offset separating training rows from test rows.

    The last ``prediction_length * rolling_windows`` rows are held out, so the
    result is that product negated, ready to be used as a slice bound counted
    from the end of the data.

    Args:
        prediction_length: Number of time steps forecast per window.
        rolling_windows: Number of evaluation windows.
        length: Total number of rows in the series.

    Returns:
        ``-1 * prediction_length * rolling_windows``.

    Raises:
        gr.Error: If either hyperparameter is missing or not positive, or if
            the held-out span is larger than 95% of ``length``.
    """
    for value in (prediction_length, rolling_windows):
        # A zero offset would leave the training slice empty, and a positive
        # offset would slice from the start of the data instead of the end,
        # so reject missing or non-positive hyperparameters up front.
        if value is None or value <= 0:
            raise gr.Error("prediction_length and rolling_windows must be positive")

    row_offset = -1 * prediction_length * rolling_windows
    # Keep at least 5% of the rows available for training.
    if abs(row_offset) > 0.95 * length:
        raise gr.Error("Reduce prediction_length * rolling_windows")
    return row_offset
def preprocess(
    input_data,
    prediction_length,
    rolling_windows,
    progress=gr.Progress(track_tqdm=True),
):
    """Load a CSV and plot how it will be divided into train and test rows.

    Args:
        input_data: Object whose ``.name`` attribute is the path of the CSV.
            The first column is used as the index and parsed as dates.
        prediction_length: Number of time steps forecast per window.
        rolling_windows: Number of evaluation windows.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        Whatever ``plot_train_test`` returns for the two slices.

    Raises:
        gr.Error: Propagated from ``offset_calculation`` when the held-out
            span is too large for the data.
    """
    series = pd.read_csv(
        input_data.name, index_col=0, parse_dates=True
    ).sort_index()
    split_at = offset_calculation(prediction_length, rolling_windows, len(series))

    # split_at is counted from the end: rows before it train, the rest test.
    train_part = series.iloc[:split_at]
    test_part = series.iloc[split_at:]
    return plot_train_test(train_part, test_part)
def train_and_forecast(
    input_data,
    file_data,
    prediction_length,
    rolling_windows,
    epochs,
    distribution,
    progress=gr.Progress(track_tqdm=True),
):
    """Train a DeepAR model on the supplied CSV and forecast the held-out rows.

    Args:
        input_data: Uploaded file object (``.name`` is the CSV path); takes
            precedence over ``file_data`` when truthy.
        file_data: Fallback file object used when ``input_data`` is falsy.
        prediction_length: Number of time steps forecast per window.
        rolling_windows: Number of evaluation windows.
        epochs: Passed to the trainer as ``max_epochs``.
        distribution: ``"StudentT"`` or ``"Normal"``; any other value selects
            the negative binomial output.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        A tuple ``(plot, agg_metrics)``: the result of ``plot_forecast`` and
        the aggregate metrics produced by the evaluator.

    Raises:
        gr.Error: If no file was supplied, if the supplied object has no
            usable ``.name``, or (via ``offset_calculation``) if the held-out
            span is too large for the data.
    """
    source = input_data or file_data
    if not source:
        raise gr.Error("Upload a file with the Upload button")

    try:
        df = pd.read_csv(source.name, index_col=0, parse_dates=True).sort_index()
    except AttributeError:
        raise gr.Error("Upload a file with the Upload button")

    row_offset = offset_calculation(prediction_length, rolling_windows, len(df))

    try:
        dataset = PandasDataset(df, target=df.columns[0])
    except TypeError:
        # Fallback: infer a frequency from the first three timestamps and
        # reindex onto a regular date range before building the dataset.
        freq = pd.infer_freq(df.index[:3])
        regular_index = pd.date_range(df.index[0], df.index[-1], freq=freq)
        regular_df = df.reindex(regular_index)
        dataset = PandasDataset(
            regular_df, target=regular_df.columns[0], freq=freq
        )

    training_data, test_template = split(dataset, offset=row_offset)

    # Anything that is not one of the named choices falls back to the
    # negative binomial output.
    output_by_name = {"StudentT": StudentTOutput, "Normal": NormalOutput}
    distr_output = output_by_name.get(distribution, NegativeBinomialOutput)()

    estimator = DeepAREstimator(
        distr_output=distr_output,
        prediction_length=prediction_length,
        freq=dataset.freq,
        trainer_kwargs={"max_epochs": epochs},
    )
    predictor = estimator.train(training_data=training_data)

    test_data = test_template.generate_instances(
        prediction_length=prediction_length, windows=rolling_windows
    )

    evaluator = Evaluator(num_workers=0)
    # NOTE(review): metrics come from make_evaluation_predictions on
    # test_data.input, while the plotted forecasts come from a separate
    # predict() call below - confirm both cover the intended windows.
    forecast_it, ts_it = make_evaluation_predictions(
        dataset=test_data.input, predictor=predictor
    )
    agg_metrics, _ = evaluator(ts_it, forecast_it)

    forecasts = list(predictor.predict(test_data.input))
    return plot_forecast(df, forecasts), agg_metrics
# Gradio UI definition: instructions, hyperparameter inputs, upload/train
# buttons, outputs, and clickable example datasets.
# NOTE(review): the container nesting below was reconstructed from a copy of
# the file whose indentation had been stripped - confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Probabilistic Time Series Forecasting
## How to use
Upload a *univariate* csv where the first column contains date-times and the second column is your data for example:
| ds | y |
|------------|---------------|
| 2007-12-10 | 9.590761 |
| 2007-12-11 | 8.519590 |
| 2007-12-12 | 8.183677 |
| 2007-12-13 | 8.072467 |
| 2007-12-14 | 7.893572 |
## Steps
1. Click **Upload** to upload your data and visualize it **or** select one of the example CSV files.
2. Click **Run**
- This app will then train an estimator and show its predictions as well as evaluation metrics.
"""
    )

    # Model and evaluation hyperparameters.
    with gr.Accordion(label="Hyperparameters"):
        with gr.Row():
            prediction_length = gr.Number(
                value=12, label="Prediction Length", precision=0
            )
            windows = gr.Number(value=3, label="Number of Windows", precision=0)
            epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
            distribution = gr.Radio(
                choices=["StudentT", "Negative Binomial", "Normal"],
                value="StudentT",
                label="Distribution",
            )

    with gr.Row(label="Dataset"):
        upload_btn = gr.UploadButton(label="Upload")
        train_btn = gr.Button(label="Train and Forecast")

    # Outputs shared by the preview and training callbacks.
    plot = gr.Plot()
    json = gr.JSON(label="Evaluation Metrics")

    # Receives the file chosen through the example rows below.
    file_output = gr.File()

    # Uploading a file immediately previews its train/test split.
    upload_btn.upload(
        fn=preprocess,
        inputs=[upload_btn, prediction_length, windows],
        outputs=[plot],
    )

    # Training accepts either the uploaded file or the selected example file.
    train_btn.click(
        fn=train_and_forecast,
        inputs=[
            upload_btn,
            file_output,
            prediction_length,
            windows,
            epochs,
            distribution,
        ],
        outputs=[plot, json],
    )

    with gr.Row(label="Example Data"):
        examples = gr.Examples(
            # One row per bundled CSV: [path, prediction length, windows].
            examples=[
                [
                    os.path.join(os.path.dirname(__file__), "examples", csv_name),
                    12,
                    3,
                ]
                for csv_name in (
                    "example_air_passengers.csv",
                    "example_retail_sales.csv",
                    "example_pedestrians_covid.csv",
                )
            ],
            fn=preprocess,
            inputs=[file_output, prediction_length, windows],
            outputs=[plot],
            run_on_click=True,
        )
if __name__ == "__main__":
    # Script entry point: call queue() on the app, then launch the result.
    queued_app = demo.queue()
    queued_app.launch()