probabilistic-forecast / make_plot.py
derek-thomas's picture
derek-thomas HF staff
Duplicate from derek-thomas/probabilistic-forecast
97ab62b
raw
history blame
3.42 kB
from typing import List
import numpy as np
import pandas as pd
import plotly.graph_objects as go
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Plot the training and test datasets using Plotly.

    Only the first column of each DataFrame is plotted, against its index.

    Args:
        df1 (pd.DataFrame): Train dataset; its first column is drawn in steelblue.
        df2 (pd.DataFrame): Test dataset; its first column is drawn in gold.

    Returns:
        go.Figure: Plotly figure containing both series as line traces.
    """
    # Create a Plotly figure
    fig = go.Figure()

    # Training series: first column vs. index, in steelblue
    fig.add_trace(go.Scatter(
        x=df1.index,
        y=df1.iloc[:, 0],
        mode='lines',
        name='Training Data',
        line=dict(color='steelblue'),
        marker=dict(color='steelblue')
    ))

    # Test series: first column vs. index, in gold
    fig.add_trace(go.Scatter(
        x=df2.index,
        y=df2.iloc[:, 0],
        mode='lines',
        name='Test Data',
        line=dict(color='gold'),
        marker=dict(color='gold')
    ))

    # Customize the layout
    fig.update_layout(
        title='Univariate Time Series',
        xaxis=dict(title='Date'),
        yaxis=dict(title='Value'),
        showlegend=True,
        template='plotly_white'
    )

    return fig
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
    """
    Plot the true values and probabilistic forecasts using Plotly.

    Every sample path of every forecast is drawn as a faint line in that
    forecast's color, and the per-time-step mean of each forecast's samples
    is drawn as a dashed red line.

    Args:
        df (pd.DataFrame): DataFrame with the true values; its first column is
            plotted against its index. A ``PeriodIndex`` is converted with
            ``to_timestamp()``; any other index is passed to ``pd.to_datetime``.
        forecasts (List[pd.DataFrame]): Forecasts to overlay. Despite the
            annotation, each element must expose ``.samples`` (sample paths,
            one per row; averaged over axis 0) and ``.index`` with a
            ``to_timestamp()`` method (e.g. a ``PeriodIndex``).
            NOTE(review): these look like GluonTS ``SampleForecast`` objects;
            confirm against the caller. Any iterable is accepted.

    Returns:
        go.Figure: Plotly figure object.
    """
    # Materialize so len() works even if a generator is passed.
    forecasts = list(forecasts)

    # Create a Plotly figure
    fig = go.Figure()

    # pd.to_datetime rejects PeriodIndex input, so convert periods explicitly.
    if isinstance(df.index, pd.PeriodIndex):
        true_x = df.index.to_timestamp()
    else:
        true_x = pd.to_datetime(df.index)

    # Add the true values trace
    fig.add_trace(go.Scatter(
        x=true_x,
        y=df.iloc[:, 0],
        mode='lines',
        name='True values',
        line=dict(color='black')
    ))

    colors = ["green", "blue", "purple"]
    # Number the mean traces only when several would otherwise share one label.
    number_means = len(forecasts) > 1

    for i, forecast in enumerate(forecasts):
        # Cycle the palette so more than len(colors) forecasts don't raise IndexError.
        color = colors[i % len(colors)]
        # Same x-axis for every sample of this forecast; compute it once.
        forecast_x = forecast.index.to_timestamp()

        for sample in forecast.samples:
            fig.add_trace(go.Scatter(
                x=forecast_x,
                y=sample,
                mode='lines',
                opacity=0.15,  # Faint lines so the sample cloud shows density
                name=f'Forecast {i + 1}',
                showlegend=False,  # Hide the individual sample paths from the legend
                hoverinfo='none',  # Disable hover information for the sample paths
                line=dict(color=color)
            ))

        # Add the per-time-step average over this forecast's samples
        mean_forecast = np.mean(forecast.samples, axis=0)
        fig.add_trace(go.Scatter(
            x=forecast_x,
            y=mean_forecast,
            mode='lines',
            name=f'Mean Forecast {i + 1}' if number_means else 'Mean Forecast',
            line=dict(color='red', dash='dash')
        ))

    # Customize the layout
    fig.update_layout(
        title='Passenger Forecast',
        xaxis=dict(title='Index'),
        yaxis=dict(title='Passenger Count'),
        showlegend=True,
        legend=dict(x=0, y=1, font=dict(size=16)),
        hovermode='x'  # Enable x-axis hover for better interactivity
    )

    return fig