from typing import Any, List

import numpy as np
import pandas as pd
import plotly.graph_objects as go


def _to_timestamps(index: pd.Index) -> pd.DatetimeIndex:
    """Convert *index* to a DatetimeIndex usable as a Plotly x-axis.

    A ``PeriodIndex`` is converted with ``to_timestamp()`` because
    ``pd.to_datetime`` rejects period data; any other index is passed
    through ``pd.to_datetime``.
    """
    if isinstance(index, pd.PeriodIndex):
        return index.to_timestamp()
    return pd.to_datetime(index)


def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Plot the training and test datasets using Plotly.

    Only the first column of each DataFrame is plotted, against its index.

    Args:
        df1 (pd.DataFrame): Train dataset
        df2 (pd.DataFrame): Test dataset

    Returns:
        go.Figure: Plotly figure with one line trace per dataset.
    """
    # Create a Plotly figure
    fig = go.Figure()

    # Training data in steelblue
    fig.add_trace(go.Scatter(
        x=df1.index,
        y=df1.iloc[:, 0],
        mode='lines',
        name='Training Data',
        line=dict(color='steelblue'),
        marker=dict(color='steelblue')
    ))

    # Test data in gold
    fig.add_trace(go.Scatter(
        x=df2.index,
        y=df2.iloc[:, 0],
        mode='lines',
        name='Test Data',
        line=dict(color='gold'),
        marker=dict(color='gold')
    ))

    # Customize the layout
    fig.update_layout(
        title='Univariate Time Series',
        xaxis=dict(title='Date'),
        yaxis=dict(title='Value'),
        showlegend=True,
        template='plotly_white'
    )

    return fig


def plot_forecast(
    df: pd.DataFrame,
    forecasts: List[Any],
    *,
    title: str = 'Passenger Forecast',
    xaxis_title: str = 'Date',
    yaxis_title: str = 'Passenger Count',
) -> go.Figure:
    """
    Plot the true values and forecasts using Plotly.

    Args:
        df (pd.DataFrame): DataFrame with the true values. Only the first
            column is plotted, against the index (a PeriodIndex or anything
            ``pd.to_datetime`` accepts).
        forecasts (List[Any]): Forecast objects (any iterable is accepted).
            Each must expose ``samples`` (an iterable of sample paths that
            ``np.mean(..., axis=0)`` can average) and ``index`` (the time
            points of each path). NOTE(review): these look like GluonTS
            ``SampleForecast`` objects - confirm against callers.
        title (str): Figure title.
        xaxis_title (str): X-axis label.
        yaxis_title (str): Y-axis label.

    Returns:
        go.Figure: Plotly figure object.
    """
    # Materialise so generators/iterators work and len() is available.
    forecasts = list(forecasts)

    # Create a Plotly figure
    fig = go.Figure()

    # Add the true values trace
    fig.add_trace(go.Scatter(
        x=_to_timestamps(df.index),
        y=df.iloc[:, 0],
        mode='lines',
        name='True values',
        line=dict(color='black')
    ))

    # Sample-path colors; cycled so more than len(colors) forecasts still work.
    colors = ["green", "blue", "purple"]
    single_forecast = len(forecasts) == 1

    for i, forecast in enumerate(forecasts):
        color = colors[i % len(colors)]
        # All sample paths and the mean share the same x-values; compute once.
        x_forecast = _to_timestamps(forecast.index)

        for sample in forecast.samples:
            fig.add_trace(go.Scatter(
                x=x_forecast,
                y=sample,
                mode='lines',
                opacity=0.15,  # Low opacity so overlapping samples show density
                name=f'Forecast {i + 1}',
                showlegend=False,  # Individual samples would flood the legend
                hoverinfo='none',  # Hovering hundreds of samples is noise
                line=dict(color=color)
            ))

        # Point-wise mean across sample paths.
        mean_forecast = np.mean(forecast.samples, axis=0)
        fig.add_trace(go.Scatter(
            x=x_forecast,
            y=mean_forecast,
            mode='lines',
            # Number the legend entries so multiple forecasts are distinguishable.
            name='Mean Forecast' if single_forecast else f'Mean Forecast {i + 1}',
            line=dict(color='red', dash='dash')
        ))

    # Customize the layout
    fig.update_layout(
        title=title,
        xaxis=dict(title=xaxis_title),
        yaxis=dict(title=yaxis_title),
        showlegend=True,
        legend=dict(x=0, y=1, font=dict(size=16)),
        hovermode='x'  # Unified x-axis hover for easier comparison
    )

    return fig