|
from typing import Any, List

import numpy as np
import pandas as pd
import plotly.graph_objects as go
|
|
|
|
|
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """Plot the training and test datasets as line traces using Plotly.

    Only the first column of each DataFrame is plotted, against that
    DataFrame's index.

    Args:
        df1 (pd.DataFrame): Train dataset, drawn in steel blue as "Training Data".
        df2 (pd.DataFrame): Test dataset, drawn in gold as "Test Data".

    Returns:
        go.Figure: Figure with one trace per dataset (train first, then test),
        titled "Univariate Time Series" and using the "plotly_white" template.
    """
    fig = go.Figure()

    # (data, legend label, colour) for each trace, in drawing order.
    series = (
        (df1, "Training Data", "steelblue"),
        (df2, "Test Data", "gold"),
    )
    for frame, label, color in series:
        fig.add_trace(
            go.Scatter(
                x=frame.index,
                y=frame.iloc[:, 0],
                mode="lines",
                name=label,
                line=dict(color=color),
                marker=dict(color=color),
            )
        )

    fig.update_layout(
        title="Univariate Time Series",
        xaxis=dict(title="Date"),
        yaxis=dict(title="Value"),
        showlegend=True,
        template="plotly_white",
    )
    return fig
|
|
|
|
|
def _as_datetime_index(index: pd.Index) -> pd.DatetimeIndex:
    """Convert a pandas index into timestamps for use as a Plotly x-axis.

    ``pd.to_datetime`` rejects period data, so a ``PeriodIndex`` is converted
    with ``to_timestamp()``. Every other index goes through ``pd.to_datetime``.

    Args:
        index (pd.Index): Index to convert.

    Returns:
        pd.DatetimeIndex: The index expressed as timestamps.
    """
    if isinstance(index, pd.PeriodIndex):
        return index.to_timestamp()
    return pd.to_datetime(index)


def plot_forecast(df: pd.DataFrame, forecasts: List[Any]) -> "go.Figure":
    """Plot the true values and sampled forecasts using Plotly.

    Args:
        df (pd.DataFrame): DataFrame with the true values. Only the first
            column is plotted. Its name is used for the figure title and the
            y-axis label. The index may be a ``PeriodIndex`` or anything that
            ``pd.to_datetime`` accepts.
        forecasts (List[Any]): Forecast objects to overlay. Each must expose:

            - ``samples``: an array-like whose items along axis 0 are
              individual sample paths.
            - ``index``: a ``PeriodIndex`` or datetime-like index aligned
              with each sample path.

            NOTE(review): this looks like the GluonTS ``SampleForecast``
            interface; confirm against callers.

    Returns:
        go.Figure: Plotly figure object.
    """
    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=_as_datetime_index(df.index),
            y=df.iloc[:, 0],
            mode="lines",
            name="True values",
            line=dict(color="black"),
        )
    )

    # Sample paths cycle through these colours, one colour per forecast.
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        color = colors[i % len(colors)]
        # The x-axis is the same for every sample path of this forecast,
        # so convert it once instead of once per sample.
        x = _as_datetime_index(forecast.index)

        # Draw each sample path as its own faint trace, with no legend entry
        # and no hover, so many paths can be drawn on top of each other.
        for sample in forecast.samples:
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=sample,
                    mode="lines",
                    opacity=0.15,
                    name=f"Forecast {i + 1}",
                    showlegend=False,
                    hoverinfo="none",
                    line=dict(color=color),
                )
            )

        # Point-wise mean across the sample paths (axis 0).
        mean_forecast = np.mean(forecast.samples, axis=0)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean_forecast,
                mode="lines",
                name="Mean Forecast",
                line=dict(color="red", dash="dash"),
                # Every mean trace joins one legend group, and only the first
                # shows a legend entry, so the legend lists "Mean Forecast" once.
                legendgroup="mean forecast",
                showlegend=i == 0,
            )
        )

    fig.update_layout(
        title=f"{df.columns[0]} Forecast",
        yaxis=dict(title=df.columns[0]),
        showlegend=True,
        legend=dict(x=0, y=1),
        hovermode="x",
    )

    return fig
|
|