kashif's picture
kashif HF staff
sort date index
e460697
raw
history blame
3.23 kB
import gradio as gr
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch.model.deepar import DeepAREstimator
from make_plot import plot_forecast, plot_train_test
def offset_calculation(prediction_length, rolling_windows, length):
    """Return the negative row offset separating training rows from held-out rows.

    The held-out tail spans ``prediction_length * rolling_windows`` rows, so the
    offset is that product negated; callers use it both for ``df.iloc`` slicing
    (``[:offset]`` / ``[offset:]``) and for ``split(..., offset=offset)``.

    Args:
        prediction_length: forecast horizon of a single window (must be >= 1).
        rolling_windows: number of consecutive held-out windows (must be >= 1).
        length: total number of rows in the series.

    Returns:
        ``-prediction_length * rolling_windows``.

    Raises:
        gr.Error: if either factor is missing or smaller than 1, or if the
            held-out tail would cover more than 95% of ``length``.
    """
    # Number inputs arrive as None when the field is cleared; comparing None
    # with an int would raise an opaque TypeError, so report it explicitly.
    if prediction_length is None or rolling_windows is None:
        raise gr.Error("Set both Prediction Length and Number of Windows")
    # A zero factor yields offset 0 (empty training slice) and a single negative
    # factor yields a positive offset (head/tail silently swapped).
    if prediction_length < 1 or rolling_windows < 1:
        raise gr.Error("Prediction Length and Number of Windows must be at least 1")
    row_offset = -1 * prediction_length * rolling_windows
    # Keep at least ~5% of the rows for training.
    if abs(row_offset) > 0.95 * length:
        raise gr.Error("Reduce prediction_length * rolling_windows")
    return row_offset
def preprocess(
    input_data,
    prediction_length,
    rolling_windows,
    item_id=None,
    progress=gr.Progress(track_tqdm=True),
):
    """Read the uploaded CSV and plot its train/test split.

    Args:
        input_data: value of the upload button; its ``.name`` attribute is read
            as the path of the CSV file (first column = dates, used as index).
        prediction_length: forecast horizon of a single window.
        rolling_windows: number of held-out windows shown as the test part.
        item_id: unused. It defaults to None because the upload event passes
            only three inputs; without a default every upload would fail with a
            missing-argument error.
        progress: Gradio progress tracker mirroring tqdm bars in the UI.

    Returns:
        Whatever ``plot_train_test`` produces for the training and held-out
        slices (rendered in the app's plot output).

    Raises:
        gr.Error: if no file was uploaded, or the split parameters are invalid.
    """
    if not input_data:
        raise gr.Error("Upload a file with the Upload button")
    try:
        csv_path = input_data.name
    except AttributeError as err:
        raise gr.Error("Upload a file with the Upload button") from err
    df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
    # Slicing below is positional, so rows must be in date order.
    df.sort_index(inplace=True)
    row_offset = offset_calculation(prediction_length, rolling_windows, len(df))
    return plot_train_test(df.iloc[:row_offset], df.iloc[row_offset:])
def _infer_frequency(index):
    """Infer a pandas frequency string from the first three timestamps of *index*.

    Raises:
        gr.Error: if pandas cannot determine a frequency (``infer_freq`` returns
            None, or raises because the index is too short or not datetime-like).
    """
    try:
        freq = pd.infer_freq(index[:3])
    except (TypeError, ValueError):
        freq = None
    if freq is None:
        raise gr.Error("Could not infer the frequency of the date column")
    return freq


def train_and_forecast(
    input_data,
    prediction_length,
    rolling_windows,
    epochs,
    progress=gr.Progress(track_tqdm=True),
):
    """Train a DeepAR model on the uploaded series and plot its rolling forecasts.

    Args:
        input_data: value of the upload button; its ``.name`` attribute is read
            as the path of the CSV file (first column = dates, used as index;
            the first remaining column is the forecast target).
        prediction_length: forecast horizon of a single window.
        rolling_windows: number of consecutive held-out windows to forecast.
        epochs: maximum number of training epochs (must be >= 1).
        progress: Gradio progress tracker mirroring tqdm bars in the UI.

    Returns:
        Whatever ``plot_forecast`` produces for the full frame and the forecasts.

    Raises:
        gr.Error: if no file was uploaded, ``epochs`` is missing or < 1, the
            split parameters are invalid, or the series frequency cannot be
            inferred.
    """
    if not input_data:
        raise gr.Error("Upload a file with the Upload button")
    try:
        csv_path = input_data.name
    except AttributeError as err:
        raise gr.Error("Upload a file with the Upload button") from err
    # trainer_kwargs are handed to the PyTorch Lightning Trainer, where
    # max_epochs=-1 means "train without limit" and None means "use the library
    # default" -- neither matches what a user typing into the epochs box expects.
    if epochs is None or epochs < 1:
        raise gr.Error("Number of Epochs must be at least 1")

    df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
    # The train/test split is positional, so rows must be in date order.
    df.sort_index(inplace=True)
    row_offset = offset_calculation(prediction_length, rolling_windows, len(df))

    target = df.columns[0]
    try:
        dataset = PandasDataset(df, target=target)
    except TypeError:
        # NOTE(review): PandasDataset appears to raise TypeError when it cannot
        # determine the frequency itself; fall back to inferring it explicitly.
        dataset = PandasDataset(df, target=target, freq=_infer_frequency(df.index))

    training_data, test_gen = split(dataset, offset=row_offset)
    predictor = DeepAREstimator(
        prediction_length=prediction_length,
        freq=dataset.freq,
        trainer_kwargs=dict(max_epochs=int(epochs)),
    ).train(training_data=training_data)

    # Forecast each of the `rolling_windows` held-out windows in turn.
    test_data = test_gen.generate_instances(
        prediction_length=prediction_length, windows=rolling_windows
    )
    forecasts = list(predictor.predict(test_data.input))
    return plot_forecast(df, forecasts)
# Gradio UI: hyperparameter inputs, upload/train buttons and one shared plot.
with gr.Blocks() as demo:
    gr.Markdown("""
# How to use
Upload a univariate csv with the first column showing your dates and the second column having your data
# How it works
1. Click **Upload** to upload your data
2. Click **Run**
- This app will visualize your data and then train an estimator and show its predictions
""")
    with gr.Accordion(label='Hyperparameters'):
        with gr.Row():
            # precision=0 makes Gradio round these inputs to whole numbers.
            prediction_length = gr.Number(value=12, label='Prediction Length', precision=0)
            windows = gr.Number(value=3, label='Number of Windows', precision=0)
            epochs = gr.Number(value=10, label='Number of Epochs', precision=0)
    with gr.Row():
        upload_btn = gr.UploadButton(label="Upload")
        # gr.Button takes its caption as `value`; it has no `label` parameter
        # (Gradio 3 ignores it with a warning, newer releases may reject it).
        # "Run" is the caption the instructions above tell users to click.
        train_btn = gr.Button(value="Run")
    plot = gr.Plot()
    # Uploading shows the train/test split; training replaces it with the
    # forecast plot. The upload button's value doubles as the file input.
    upload_btn.upload(fn=preprocess, inputs=[upload_btn, prediction_length, windows], outputs=plot)
    train_btn.click(fn=train_and_forecast, inputs=[upload_btn, prediction_length, windows, epochs], outputs=plot)
if __name__ == "__main__":
    # Enable request queueing before serving; the long-running training handler
    # and its gr.Progress tracker presumably depend on the queue.
    queued_demo = demo.queue()
    queued_demo.launch()