|
import gradio as gr |
|
import pandas as pd |
|
from gluonts.dataset.pandas import PandasDataset |
|
from gluonts.dataset.split import split |
|
from gluonts.torch.model.deepar import DeepAREstimator |
|
|
|
from make_plot import plot_forecast, plot_train_test |
|
|
|
|
|
def offset_calculation(prediction_length, rolling_windows, length):
    """Return the negative row offset separating training rows from held-out rows.

    The last ``prediction_length * rolling_windows`` rows are held out, so the
    offset is that product negated. It is suitable for ``df.iloc[:offset]`` /
    ``df.iloc[offset:]`` and for gluonts ``split(offset=...)``.

    Args:
        prediction_length: Steps forecast per window; must be >= 1.
        rolling_windows: Number of held-out windows; must be >= 1.
        length: Total number of rows in the series.

    Returns:
        A negative ``int`` offset.

    Raises:
        gr.Error: If either hyperparameter is missing or below 1, or if the
            held-out span would exceed 95% of the series.
    """
    # A cleared gr.Number arrives as None. A zero/negative value would yield an
    # offset of 0 (empty training slice) or a positive offset (slices swapped).
    if prediction_length is None or rolling_windows is None:
        raise gr.Error("Prediction Length and Number of Windows must be set")
    prediction_length = int(prediction_length)
    rolling_windows = int(rolling_windows)
    if prediction_length < 1 or rolling_windows < 1:
        raise gr.Error("Prediction Length and Number of Windows must be at least 1")

    row_offset = -prediction_length * rolling_windows
    # Keep at least ~5% of the rows available for training.
    if abs(row_offset) > 0.95 * length:
        raise gr.Error("Reduce prediction_length * rolling_windows")
    return row_offset
|
|
|
|
|
def preprocess(
    input_data,
    prediction_length,
    rolling_windows,
    item_id=None,
    progress=gr.Progress(track_tqdm=True),
):
    """Load the uploaded CSV and plot its train/test split.

    Args:
        input_data: Value emitted by the upload button. Either an object
            exposing ``.name`` (the uploaded file's path) or the path string
            itself. NOTE(review): which one depends on the Gradio version.
        prediction_length: Steps forecast per held-out window.
        rolling_windows: Number of held-out windows.
        item_id: Unused by this function. It defaults to ``None`` so the
            three-input upload handler can call it.
        progress: Gradio progress tracker.

    Returns:
        Whatever ``plot_train_test`` returns for the training rows and the
        held-out rows.

    Raises:
        gr.Error: If no file was uploaded or the split is invalid.
    """
    if not input_data:
        raise gr.Error("Upload a file with the Upload button")
    # Accept both a file wrapper with ``.name`` and a plain path string.
    path = getattr(input_data, "name", input_data)
    df = pd.read_csv(path, index_col=0, parse_dates=True)
    df.sort_index(inplace=True)

    row_offset = offset_calculation(prediction_length, rolling_windows, len(df))
    return plot_train_test(df.iloc[:row_offset], df.iloc[row_offset:])
|
|
|
|
|
def train_and_forecast(
    input_data,
    prediction_length,
    rolling_windows,
    epochs,
    progress=gr.Progress(track_tqdm=True),
):
    """Train DeepAR on the uploaded series and plot its rolling-window forecasts.

    The last ``prediction_length * rolling_windows`` rows are held out (see
    ``offset_calculation``). The model is trained on the remaining rows and
    then predicts each of the ``rolling_windows`` held-out windows.

    Args:
        input_data: Value from the upload button. Either an object exposing
            ``.name`` (the uploaded file's path) or the path string itself.
        prediction_length: Steps forecast per window.
        rolling_windows: Number of held-out evaluation windows.
        epochs: Maximum number of training epochs; must be >= 1.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        Whatever ``plot_forecast`` returns for the full dataframe and the list
        of forecasts.

    Raises:
        gr.Error: If no file was uploaded, ``epochs`` is invalid, the split is
            invalid, or the series' frequency cannot be inferred.
    """
    if not input_data:
        raise gr.Error("Upload a file with the Upload button")
    # NOTE(review): the upload value's type depends on the Gradio version;
    # accept both a file wrapper with ``.name`` and a plain path string.
    path = getattr(input_data, "name", input_data)
    try:
        df = pd.read_csv(path, index_col=0, parse_dates=True)
        df.sort_index(inplace=True)
    except AttributeError:
        raise gr.Error("Upload a file with the Upload button")

    row_offset = offset_calculation(prediction_length, rolling_windows, len(df))

    # epochs feeds max_epochs, presumably forwarded to the Lightning Trainer,
    # where -1 means unbounded training. Reject anything below 1 up front.
    if epochs is None or int(epochs) < 1:
        raise gr.Error("Number of Epochs must be at least 1")
    epochs = int(epochs)

    try:
        gluon_df = PandasDataset(df, target=df.columns[0])
    except TypeError:
        # Presumably raised when no regular frequency can be read off the index
        # (e.g. missing dates). Infer the step from the first three timestamps
        # and reindex onto a regular grid; dates absent from the CSV become NaN.
        try:
            freq = pd.infer_freq(df.index[:3])
        except (TypeError, ValueError):
            freq = None
        if freq is None:
            raise gr.Error(
                "Could not infer the frequency of the dates; "
                "make sure the first column holds evenly spaced dates"
            )
        date_range = pd.date_range(df.index[0], df.index[-1], freq=freq)
        new_df = df.reindex(date_range)
        gluon_df = PandasDataset(new_df, target=new_df.columns[0], freq=freq)

    training_data, test_gen = split(gluon_df, offset=row_offset)

    model = DeepAREstimator(
        prediction_length=prediction_length,
        freq=gluon_df.freq,
        trainer_kwargs=dict(max_epochs=epochs),
    ).train(training_data=training_data)

    test_data = test_gen.generate_instances(
        prediction_length=prediction_length, windows=rolling_windows
    )
    forecasts = list(model.predict(test_data.input))
    return plot_forecast(df, forecasts)
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("""
# How to use

Upload a univariate csv with the first column showing your dates and the second column having your data

# How it works

1. Click **Upload** to upload your data
2. Click **Train and Forecast**
    - This app will visualize your data and then train an estimator and show its predictions
""")
    # Hyperparameters shared by both handlers below; precision=0 makes each
    # gr.Number round its value to an integer.
    with gr.Accordion(label='Hyperparameters'):
        with gr.Row():
            prediction_length = gr.Number(value=12, label='Prediction Length', precision=0)
            windows = gr.Number(value=3, label='Number of Windows', precision=0)
            epochs = gr.Number(value=10, label='Number of Epochs', precision=0)
    with gr.Row():
        # Both handlers parse the upload with pd.read_csv, so only offer CSVs.
        upload_btn = gr.UploadButton(label="Upload", file_types=[".csv"])
        # A Button's caption is its ``value``; passing ``label`` does not set it
        # (the button fell back to its default "Run" text).
        train_btn = gr.Button(value="Train and Forecast")
    plot = gr.Plot()

    # Uploading plots the train/test split immediately; training reads the
    # uploaded file from the same button and replaces the plot with forecasts.
    upload_btn.upload(fn=preprocess, inputs=[upload_btn, prediction_length, windows], outputs=plot)
    train_btn.click(fn=train_and_forecast, inputs=[upload_btn, prediction_length, windows, epochs], outputs=plot)
|
|
|
if __name__ == "__main__":
    # Turn on Gradio's request queue (progress tracking relies on it in
    # Gradio 3.x), then serve the app.
    queued_demo = demo.queue()
    queued_demo.launch()
|
|