kashif HF staff commited on
Commit
ccf318d
1 Parent(s): f1bfd9d

updated title of plots

Browse files
Files changed (1) hide show
  1. make_plot.py +51 -42
make_plot.py CHANGED
@@ -21,37 +21,41 @@ def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
21
  fig = go.Figure()
22
 
23
  # Add the first scatter plot with steelblue color
24
- fig.add_trace(go.Scatter(
 
25
  x=df1.index,
26
  y=df1.iloc[:, 0],
27
- mode='lines',
28
- name='Training Data',
29
- line=dict(color='steelblue'),
30
- marker=dict(color='steelblue')
31
- ))
 
32
 
33
  # Add the second scatter plot with yellow color
34
- fig.add_trace(go.Scatter(
 
35
  x=df2.index,
36
  y=df2.iloc[:, 0],
37
- mode='lines',
38
- name='Test Data',
39
- line=dict(color='gold'),
40
- marker=dict(color='gold')
41
- ))
 
42
 
43
  # Customize the layout
44
  fig.update_layout(
45
- title='Univariate Time Series',
46
- xaxis=dict(title='Date'),
47
- yaxis=dict(title='Value'),
48
- showlegend=True,
49
- template='plotly_white'
50
- )
51
  return fig
52
 
53
 
54
- def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
55
  """
56
  Plot the true values and forecasts using Plotly.
57
 
@@ -67,48 +71,53 @@ def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
67
  fig = go.Figure()
68
 
69
  # Add the true values trace
70
- fig.add_trace(go.Scatter(
 
71
  x=pd.to_datetime(df.index),
72
  y=df.iloc[:, 0],
73
- mode='lines',
74
- name='True values',
75
- line=dict(color='black')
76
- ))
 
77
 
78
  # Add the forecast traces
79
  colors = ["green", "blue", "purple"]
80
  for i, forecast in enumerate(forecasts):
81
  color = colors[i]
82
  for sample in forecast.samples:
83
- fig.add_trace(go.Scatter(
 
84
  x=forecast.index.to_timestamp(),
85
  y=sample,
86
- mode='lines',
87
  opacity=0.15, # Adjust opacity to control visibility of individual samples
88
- name=f'Forecast {i + 1}',
89
  showlegend=False, # Hide the individual forecast series from the legend
90
- hoverinfo='none', # Disable hover information for the forecast series
91
- line=dict(color=color)
92
- ))
 
93
  # Add the average
94
  mean_forecast = np.mean(forecast.samples, axis=0)
95
- fig.add_trace(go.Scatter(
 
96
  x=forecast.index.to_timestamp(),
97
  y=mean_forecast,
98
- mode='lines',
99
- name=f'Mean Forecast',
100
- line=dict(color='red', dash='dash')
101
- ))
 
102
 
103
  # Customize the layout
104
  fig.update_layout(
105
- title='Passenger Forecast',
106
- xaxis=dict(title='Index'),
107
- yaxis=dict(title='Passenger Count'),
108
- showlegend=True,
109
- legend=dict(x=0, y=1, font=dict(size=16)),
110
- hovermode='x' # Enable x-axis hover for better interactivity
111
- )
112
 
113
  # Return the figure
114
  return fig
 
21
  fig = go.Figure()
22
 
23
  # Add the first scatter plot with steelblue color
24
+ fig.add_trace(
25
+ go.Scatter(
26
  x=df1.index,
27
  y=df1.iloc[:, 0],
28
+ mode="lines",
29
+ name="Training Data",
30
+ line=dict(color="steelblue"),
31
+ marker=dict(color="steelblue"),
32
+ )
33
+ )
34
 
35
  # Add the second scatter plot with yellow color
36
+ fig.add_trace(
37
+ go.Scatter(
38
  x=df2.index,
39
  y=df2.iloc[:, 0],
40
+ mode="lines",
41
+ name="Test Data",
42
+ line=dict(color="gold"),
43
+ marker=dict(color="gold"),
44
+ )
45
+ )
46
 
47
  # Customize the layout
48
  fig.update_layout(
49
+ title="Univariate Time Series",
50
+ xaxis=dict(title="Date"),
51
+ yaxis=dict(title="Value"),
52
+ showlegend=True,
53
+ template="plotly_white",
54
+ )
55
  return fig
56
 
57
 
58
+ def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]):
59
  """
60
  Plot the true values and forecasts using Plotly.
61
 
 
71
  fig = go.Figure()
72
 
73
  # Add the true values trace
74
+ fig.add_trace(
75
+ go.Scatter(
76
  x=pd.to_datetime(df.index),
77
  y=df.iloc[:, 0],
78
+ mode="lines",
79
+ name="True values",
80
+ line=dict(color="black"),
81
+ )
82
+ )
83
 
84
  # Add the forecast traces
85
  colors = ["green", "blue", "purple"]
86
  for i, forecast in enumerate(forecasts):
87
  color = colors[i]
88
  for sample in forecast.samples:
89
+ fig.add_trace(
90
+ go.Scatter(
91
  x=forecast.index.to_timestamp(),
92
  y=sample,
93
+ mode="lines",
94
  opacity=0.15, # Adjust opacity to control visibility of individual samples
95
+ name=f"Forecast {i + 1}",
96
  showlegend=False, # Hide the individual forecast series from the legend
97
+ hoverinfo="none", # Disable hover information for the forecast series
98
+ line=dict(color=color),
99
+ )
100
+ )
101
  # Add the average
102
  mean_forecast = np.mean(forecast.samples, axis=0)
103
+ fig.add_trace(
104
+ go.Scatter(
105
  x=forecast.index.to_timestamp(),
106
  y=mean_forecast,
107
+ mode="lines",
108
+ name="Mean Forecast",
109
+ line=dict(color="red", dash="dash"),
110
+ )
111
+ )
112
 
113
  # Customize the layout
114
  fig.update_layout(
115
+ title=f"{df.columns[0]} Forecast",
116
+ yaxis=dict(title=df.columns[0]),
117
+ showlegend=True,
118
+ legend=dict(x=0, y=1, font=dict(size=16)),
119
+ hovermode="x", # Enable x-axis hover for better interactivity
120
+ )
 
121
 
122
  # Return the figure
123
  return fig