DexterSptizu commited on
Commit
9cafd02
1 Parent(s): 005b8a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -21
app.py CHANGED
@@ -40,7 +40,7 @@ def create_similarity_based_visualization(texts):
40
  for j in range(n):
41
  similarity_matrix[i][j] = wl.similarity(texts[i], texts[j])
42
 
43
- # Convert similarities to distances (1 - similarity)
44
  distance_matrix = 1 - similarity_matrix
45
 
46
  if n == 2:
@@ -50,12 +50,15 @@ def create_similarity_based_visualization(texts):
50
  # Place points based on similarity
51
  similarity = similarity_matrix[0][1]
52
  fig.add_trace(go.Scatter(
53
- x=[0, 1-similarity], # Distance proportional to similarity
54
  y=[0, 0],
55
  mode='markers+text',
56
  text=texts,
57
  textposition='top center',
58
- marker=dict(size=10, color=['blue', 'red'])
 
 
 
59
  ))
60
 
61
  fig.update_layout(
@@ -73,18 +76,14 @@ def create_similarity_based_visualization(texts):
73
  mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
74
  coords = mds.fit_transform(distance_matrix)
75
 
76
- # Create DataFrame for plotting
77
- df = pd.DataFrame(
78
- coords,
79
- columns=['X', 'Y', 'Z']
80
- )
81
- df['text'] = texts
82
-
83
  # Create 3D scatter plot
84
- fig = go.Figure(data=[go.Scatter3d(
85
- x=df['X'],
86
- y=df['Y'],
87
- z=df['Z'],
 
 
 
88
  mode='markers+text',
89
  text=texts,
90
  textposition='top center',
@@ -93,22 +92,28 @@ def create_similarity_based_visualization(texts):
93
  color=list(range(len(texts))),
94
  colorscale='Viridis',
95
  opacity=0.8
96
- )
97
- )])
 
98
 
99
- # Add lines between points to show connections
100
  for i in range(n):
101
  for j in range(i+1, n):
 
 
 
102
  fig.add_trace(go.Scatter3d(
103
  x=[coords[i,0], coords[j,0]],
104
  y=[coords[i,1], coords[j,1]],
105
  z=[coords[i,2], coords[j,2]],
106
  mode='lines',
107
  line=dict(
108
- color=f'rgba(100,100,100,{similarity_matrix[i,j]:.2f})',
109
  width=2
110
  ),
111
- showlegend=False
 
 
112
  ))
113
 
114
  fig.update_layout(
@@ -116,9 +121,21 @@ def create_similarity_based_visualization(texts):
116
  scene=dict(
117
  xaxis_title="Dimension 1",
118
  yaxis_title="Dimension 2",
119
- zaxis_title="Dimension 3"
 
 
 
 
 
120
  ),
121
- height=700
 
 
 
 
 
 
 
122
  )
123
 
124
  return fig
 
40
  for j in range(n):
41
  similarity_matrix[i][j] = wl.similarity(texts[i], texts[j])
42
 
43
+ # Convert similarities to distances
44
  distance_matrix = 1 - similarity_matrix
45
 
46
  if n == 2:
 
50
  # Place points based on similarity
51
  similarity = similarity_matrix[0][1]
52
  fig.add_trace(go.Scatter(
53
+ x=[0, 1-similarity],
54
  y=[0, 0],
55
  mode='markers+text',
56
  text=texts,
57
  textposition='top center',
58
+ marker=dict(
59
+ size=10,
60
+ color=['blue', 'red']
61
+ )
62
  ))
63
 
64
  fig.update_layout(
 
76
  mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
77
  coords = mds.fit_transform(distance_matrix)
78
 
 
 
 
 
 
 
 
79
  # Create 3D scatter plot
80
+ fig = go.Figure()
81
+
82
+ # Add points
83
+ fig.add_trace(go.Scatter3d(
84
+ x=coords[:, 0],
85
+ y=coords[:, 1],
86
+ z=coords[:, 2],
87
  mode='markers+text',
88
  text=texts,
89
  textposition='top center',
 
92
  color=list(range(len(texts))),
93
  colorscale='Viridis',
94
  opacity=0.8
95
+ ),
96
+ name='Texts'
97
+ ))
98
 
99
+ # Add lines between points with valid opacity values
100
  for i in range(n):
101
  for j in range(i+1, n):
102
+ # Calculate opacity based on similarity (ensure it's between 0.1 and 1)
103
+ opacity = max(0.1, min(1.0, similarity_matrix[i,j]))
104
+
105
  fig.add_trace(go.Scatter3d(
106
  x=[coords[i,0], coords[j,0]],
107
  y=[coords[i,1], coords[j,1]],
108
  z=[coords[i,2], coords[j,2]],
109
  mode='lines',
110
  line=dict(
111
+ color='gray',
112
  width=2
113
  ),
114
+ opacity=opacity,
115
+ showlegend=False,
116
+ hoverinfo='skip'
117
  ))
118
 
119
  fig.update_layout(
 
121
  scene=dict(
122
  xaxis_title="Dimension 1",
123
  yaxis_title="Dimension 2",
124
+ zaxis_title="Dimension 3",
125
+ camera=dict(
126
+ up=dict(x=0, y=0, z=1),
127
+ center=dict(x=0, y=0, z=0),
128
+ eye=dict(x=1.5, y=1.5, z=1.5)
129
+ )
130
  ),
131
+ height=700,
132
+ showlegend=True,
133
+ legend=dict(
134
+ yanchor="top",
135
+ y=0.99,
136
+ xanchor="left",
137
+ x=0.01
138
+ )
139
  )
140
 
141
  return fig