probabilistic-forecast / make_plot.py
kashif's picture
kashif HF staff
update plotly and show only one legend for mean
188d8b2
raw
history blame
3.47 kB
from typing import List
import numpy as np
import pandas as pd
import plotly.graph_objects as go
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Draw the train and test series as two line traces on one figure.

    For each frame the index is used as the x axis and the first column
    as the y axis.

    Args:
        df1 (pd.DataFrame): Train dataset, drawn in steelblue as "Training Data".
        df2 (pd.DataFrame): Test dataset, drawn in gold as "Test Data".

    Returns:
        go.Figure: Figure containing both traces, with the layout applied.
    """
    # (frame, legend label, color) for each series, in drawing order.
    series_specs = (
        (df1, "Training Data", "steelblue"),
        (df2, "Test Data", "gold"),
    )

    fig = go.Figure()
    for frame, label, shade in series_specs:
        trace = go.Scatter(
            x=frame.index,
            y=frame.iloc[:, 0],
            mode="lines",
            name=label,
            line={"color": shade},
            marker={"color": shade},
        )
        fig.add_trace(trace)

    # Titles, legend and theme.
    fig.update_layout(
        title="Univariate Time Series",
        xaxis={"title": "Date"},
        yaxis={"title": "Value"},
        showlegend=True,
        template="plotly_white",
    )
    return fig
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> "go.Figure":
    """
    Plot the true values and forecasts using Plotly.

    For every forecast, each sample path is drawn as a faint line and the
    mean over the sample paths is drawn as a dashed red line.

    Args:
        df (pd.DataFrame): DataFrame with the true values. Its index is passed
            through ``pd.to_datetime`` for the x axis and its first column is
            used for the y axis, the figure title and the y-axis title.
        forecasts (List[pd.DataFrame]): Forecast objects to overlay. Each item
            must expose ``.samples`` (an iterable of sample paths that can be
            averaged along axis 0) and ``.index`` with a ``to_timestamp()``
            method.
            NOTE(review): the annotation says DataFrame, but these attributes
            look like GluonTS ``SampleForecast`` objects - confirm with the
            caller.

    Returns:
        go.Figure: Plotly figure object.
    """
    # Create a Plotly figure
    fig = go.Figure()

    # Add the true values trace
    fig.add_trace(
        go.Scatter(
            x=pd.to_datetime(df.index),
            y=df.iloc[:, 0],
            mode="lines",
            name="True values",
            line=dict(color="black"),
        )
    )

    # Add the forecast traces; colors are reused cyclically when there are
    # more forecasts than colors.
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        color = colors[i % len(colors)]

        # The x values are the same for every sample path and for the mean,
        # so convert the index once per forecast instead of once per trace.
        timestamps = forecast.index.to_timestamp()
        samples = forecast.samples

        for sample in samples:
            fig.add_trace(
                go.Scatter(
                    x=timestamps,
                    y=sample,
                    mode="lines",
                    opacity=0.15,  # Keep individual sample paths faint
                    name=f"Forecast {i + 1}",
                    showlegend=False,  # Hide the individual sample paths from the legend
                    hoverinfo="none",  # Disable hover information for the sample paths
                    line=dict(color=color),
                )
            )

        # Add the average over the sample paths
        mean_forecast = np.mean(samples, axis=0)
        fig.add_trace(
            go.Scatter(
                x=timestamps,
                y=mean_forecast,
                mode="lines",
                name="Mean Forecast",
                line=dict(color="red", dash="dash"),
                # All mean traces share one legend group; only the first one
                # contributes a legend entry.
                legendgroup="mean forecast",
                showlegend=i == 0,
            )
        )

    # Customize the layout
    fig.update_layout(
        title=f"{df.columns[0]} Forecast",
        yaxis=dict(title=df.columns[0]),
        showlegend=True,
        legend=dict(x=0, y=1),
        hovermode="x",  # Enable x-axis hover for better interactivity
    )

    # Return the figure
    return fig