|
from typing import List |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
|
|
|
|
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """Plot the training and test datasets as two line traces using Plotly.

    Only the first column of each DataFrame is plotted, against that
    DataFrame's index.

    Args:
        df1 (pd.DataFrame): Train dataset; its first column is drawn in
            steelblue under the legend entry 'Training Data'.
        df2 (pd.DataFrame): Test dataset; its first column is drawn in
            gold under the legend entry 'Test Data'.

    Returns:
        go.Figure: Plotly figure containing both traces, with the
        'plotly_white' template applied.
    """
    fig = go.Figure()

    # Training series: first column only.
    # The marker colour has no visible effect while mode='lines'; it only
    # matters if the mode is changed to include markers.
    fig.add_trace(go.Scatter(
        x=df1.index,
        y=df1.iloc[:, 0],
        mode='lines',
        name='Training Data',
        line=dict(color='steelblue'),
        marker=dict(color='steelblue'),
    ))

    # Test series: first column only, in a contrasting colour.
    fig.add_trace(go.Scatter(
        x=df2.index,
        y=df2.iloc[:, 0],
        mode='lines',
        name='Test Data',
        line=dict(color='gold'),
        marker=dict(color='gold'),
    ))

    fig.update_layout(
        title='Univariate Time Series',
        xaxis=dict(title='Date'),
        yaxis=dict(title='Value'),
        showlegend=True,
        template='plotly_white',
    )

    return fig
|
|
|
|
|
def _to_datetime_index(index) -> pd.DatetimeIndex:
    """Convert an index into datetime x-values for Plotly.

    A ``PeriodIndex`` is converted with ``to_timestamp()`` because
    ``pd.to_datetime`` does not accept period data; any other index goes
    through ``pd.to_datetime``.
    """
    if isinstance(index, pd.PeriodIndex):
        return index.to_timestamp()
    return pd.to_datetime(index)


def plot_forecast(
    df: pd.DataFrame,
    forecasts: List[pd.DataFrame],
    *,
    title: str = 'Passenger Forecast',
    xaxis_title: str = 'Index',
    yaxis_title: str = 'Passenger Count',
) -> go.Figure:
    """Plot the true values and sampled forecast paths using Plotly.

    The first column of ``df`` is drawn as a black line. For each
    forecast, every sample path is drawn as a faint line in that
    forecast's colour, and the per-timestep mean of the samples is drawn
    as a red dashed line.

    Args:
        df (pd.DataFrame): DataFrame with the true values; only its first
            column is plotted. The index may be datetime-like or a
            ``PeriodIndex``.
        forecasts (List[pd.DataFrame]): Forecasts to overlay. Despite the
            annotation, each element must expose ``.samples`` (iterated
            row-wise, each row aligned with ``.index``) and ``.index``
            (datetime-like or ``PeriodIndex``). Presumably these are
            sample-based forecast objects (e.g. GluonTS ``SampleForecast``)
            -- confirm against callers. Any iterable is accepted.
        title (str): Figure title.
        xaxis_title (str): X-axis title.
        yaxis_title (str): Y-axis title.

    Returns:
        go.Figure: Plotly figure object.
    """
    # Materialize so iterators are accepted and len() is available below.
    forecasts = list(forecasts)

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=_to_datetime_index(df.index),
        y=df.iloc[:, 0],
        mode='lines',
        name='True values',
        line=dict(color='black'),
    ))

    colors = ["green", "blue", "purple"]
    # Number the mean traces only when they would otherwise share a name.
    number_means = len(forecasts) > 1

    for i, forecast in enumerate(forecasts):
        # Cycle the palette so more than three forecasts do not raise
        # IndexError.
        color = colors[i % len(colors)]
        # All sample paths share the same x-values; convert them once.
        x_values = _to_datetime_index(forecast.index)

        for sample in forecast.samples:
            fig.add_trace(go.Scatter(
                x=x_values,
                y=sample,
                mode='lines',
                opacity=0.15,
                name=f'Forecast {i + 1}',
                showlegend=False,
                hoverinfo='none',
                line=dict(color=color),
            ))

        mean_forecast = np.mean(forecast.samples, axis=0)
        mean_name = f'Mean Forecast {i + 1}' if number_means else 'Mean Forecast'
        fig.add_trace(go.Scatter(
            x=x_values,
            y=mean_forecast,
            mode='lines',
            name=mean_name,
            line=dict(color='red', dash='dash'),
        ))

    fig.update_layout(
        title=title,
        xaxis=dict(title=xaxis_title),
        yaxis=dict(title=yaxis_title),
        showlegend=True,
        legend=dict(x=0, y=1, font=dict(size=16)),
        hovermode='x',
    )

    return fig
|
|