add distribution type
Browse files
app.py
CHANGED
@@ -1,8 +1,14 @@
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
|
|
3 |
from gluonts.dataset.pandas import PandasDataset
|
4 |
from gluonts.dataset.split import split
|
5 |
from gluonts.torch.model.deepar import DeepAREstimator
|
|
|
|
|
|
|
|
|
|
|
6 |
from gluonts.evaluation import Evaluator, make_evaluation_predictions
|
7 |
|
8 |
from make_plot import plot_forecast, plot_train_test
|
@@ -32,6 +38,7 @@ def train_and_forecast(
|
|
32 |
prediction_length,
|
33 |
rolling_windows,
|
34 |
epochs,
|
|
|
35 |
progress=gr.Progress(track_tqdm=True),
|
36 |
):
|
37 |
if not input_data:
|
@@ -54,7 +61,14 @@ def train_and_forecast(
|
|
54 |
|
55 |
training_data, test_gen = split(gluon_df, offset=row_offset)
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
estimator = DeepAREstimator(
|
|
|
58 |
prediction_length=prediction_length,
|
59 |
freq=gluon_df.freq,
|
60 |
trainer_kwargs=dict(max_epochs=epochs),
|
@@ -108,6 +122,11 @@ with gr.Blocks() as demo:
|
|
108 |
)
|
109 |
windows = gr.Number(value=3, label="Number of Windows", precision=0)
|
110 |
epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
with gr.Row(label="Dataset"):
|
113 |
upload_btn = gr.UploadButton(label="Upload")
|
@@ -122,7 +141,7 @@ with gr.Blocks() as demo:
|
|
122 |
)
|
123 |
train_btn.click(
|
124 |
fn=train_and_forecast,
|
125 |
-
inputs=[upload_btn, prediction_length, windows, epochs],
|
126 |
outputs=[plot, json],
|
127 |
)
|
128 |
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
+
|
4 |
from gluonts.dataset.pandas import PandasDataset
|
5 |
from gluonts.dataset.split import split
|
6 |
from gluonts.torch.model.deepar import DeepAREstimator
|
7 |
+
from gluonts.torch.distributions import (
|
8 |
+
NegativeBinomialOutput,
|
9 |
+
StudentTOutput,
|
10 |
+
NormalOutput,
|
11 |
+
)
|
12 |
from gluonts.evaluation import Evaluator, make_evaluation_predictions
|
13 |
|
14 |
from make_plot import plot_forecast, plot_train_test
|
|
|
38 |
prediction_length,
|
39 |
rolling_windows,
|
40 |
epochs,
|
41 |
+
distribution,
|
42 |
progress=gr.Progress(track_tqdm=True),
|
43 |
):
|
44 |
if not input_data:
|
|
|
61 |
|
62 |
training_data, test_gen = split(gluon_df, offset=row_offset)
|
63 |
|
64 |
+
if distribution == "StudentT":
|
65 |
+
distr_output = StudentTOutput()
|
66 |
+
elif distribution == "Normal":
|
67 |
+
distr_output = NormalOutput()
|
68 |
+
else:
|
69 |
+
distr_output = NegativeBinomialOutput()
|
70 |
estimator = DeepAREstimator(
|
71 |
+
distr_output=distr_output,
|
72 |
prediction_length=prediction_length,
|
73 |
freq=gluon_df.freq,
|
74 |
trainer_kwargs=dict(max_epochs=epochs),
|
|
|
122 |
)
|
123 |
windows = gr.Number(value=3, label="Number of Windows", precision=0)
|
124 |
epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
|
125 |
+
distribution = gr.Radio(
|
126 |
+
choices=["StudentT", "Negative Binomial", "Normal"],
|
127 |
+
value="StudentT",
|
128 |
+
label="Distribution",
|
129 |
+
)
|
130 |
|
131 |
with gr.Row(label="Dataset"):
|
132 |
upload_btn = gr.UploadButton(label="Upload")
|
|
|
141 |
)
|
142 |
train_btn.click(
|
143 |
fn=train_and_forecast,
|
144 |
+
inputs=[upload_btn, prediction_length, windows, epochs, distribution],
|
145 |
outputs=[plot, json],
|
146 |
)
|
147 |
|