from typing import Any, List

import numpy as np
import pandas as pd
import plotly.graph_objects as go


def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Plot the training and test datasets using Plotly.

    Only the first column of each DataFrame is plotted, against its index.

    Args:
        df1 (pd.DataFrame): Train dataset
        df2 (pd.DataFrame): Test dataset

    Returns:
        go.Figure: Plotly figure with one line trace per dataset.
    """
    # Create a Plotly figure
    fig = go.Figure()

    # Training data as a steelblue line
    fig.add_trace(
        go.Scatter(
            x=df1.index,
            y=df1.iloc[:, 0],
            mode="lines",
            name="Training Data",
            line=dict(color="steelblue"),
            marker=dict(color="steelblue"),
        )
    )

    # Test data as a gold line
    fig.add_trace(
        go.Scatter(
            x=df2.index,
            y=df2.iloc[:, 0],
            mode="lines",
            name="Test Data",
            line=dict(color="gold"),
            marker=dict(color="gold"),
        )
    )

    # Customize the layout
    fig.update_layout(
        title="Univariate Time Series",
        xaxis=dict(title="Date"),
        yaxis=dict(title="Value"),
        showlegend=True,
        template="plotly_white",
    )

    return fig


def _as_timestamps(index: pd.Index) -> pd.DatetimeIndex:
    """
    Convert *index* to a DatetimeIndex suitable for a Plotly date axis.

    A PeriodIndex is converted with ``to_timestamp()`` because
    ``pd.to_datetime`` rejects period data. Any other index goes through
    ``pd.to_datetime``.
    """
    if isinstance(index, pd.PeriodIndex):
        return index.to_timestamp()
    return pd.to_datetime(index)


def plot_forecast(df: pd.DataFrame, forecasts: List[Any]) -> go.Figure:
    """
    Plot the true values and forecasts using Plotly.

    Args:
        df (pd.DataFrame): DataFrame with the true values. Its first column is
            plotted against its index. The index may be a DatetimeIndex, a
            PeriodIndex, or anything ``pd.to_datetime`` accepts.
        forecasts (List[Any]): Forecast objects, each exposing ``.index`` and
            ``.samples``. ``.samples`` is a 2-D array-like with one sample path
            per row, each row aligned with ``.index``. These are presumably
            GluonTS ``SampleForecast`` objects (NOTE(review): confirm against
            callers).

    Returns:
        go.Figure: Plotly figure object.
    """
    # Materialise once so any iterable works and the forecast count is known.
    forecasts = list(forecasts)
    numbered_means = len(forecasts) > 1

    # Create a Plotly figure
    fig = go.Figure()

    # Add the true values trace
    fig.add_trace(
        go.Scatter(
            x=_as_timestamps(df.index),
            y=df.iloc[:, 0],
            mode="lines",
            name="True values",
            line=dict(color="black"),
        )
    )

    # Add the forecast traces
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        # Cycle the palette so more than three forecasts don't raise IndexError.
        color = colors[i % len(colors)]
        # Every sample path and the mean share the same x-values: convert once.
        x_values = _as_timestamps(forecast.index)

        for sample in forecast.samples:
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=sample,
                    mode="lines",
                    opacity=0.15,  # Low opacity keeps individual sample paths faint
                    name=f"Forecast {i + 1}",
                    showlegend=False,  # Hide the individual sample paths from the legend
                    hoverinfo="none",  # Disable hover information for sample paths
                    line=dict(color=color),
                )
            )

        # Add the average across sample paths (row-wise mean, one value per step)
        mean_forecast = np.mean(forecast.samples, axis=0)
        # Number the legend entry only when several forecasts would otherwise
        # produce identical "Mean Forecast" labels.
        mean_name = f"Mean Forecast {i + 1}" if numbered_means else "Mean Forecast"
        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=mean_forecast,
                mode="lines",
                name=mean_name,
                line=dict(color="red", dash="dash"),
            )
        )

    # Customize the layout
    fig.update_layout(
        title=f"{df.columns[0]} Forecast",
        yaxis=dict(title=df.columns[0]),
        showlegend=True,
        legend=dict(x=0, y=1, font=dict(size=16)),
        hovermode="x",  # Unified x-axis hover for easier comparison
    )

    return fig