DexterSptizu commited on
Commit
005b8a9
β€’
1 Parent(s): 8afde48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -107
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  from wordllama import WordLlama
4
  import plotly.graph_objects as go
5
  import plotly.express as px
6
- from sklearn.decomposition import PCA
7
  import pandas as pd
8
 
9
  # Page configuration
@@ -16,17 +16,6 @@ st.set_page_config(
16
  # Custom CSS
17
  st.markdown("""
18
  <style>
19
- .main {
20
- background-color: #f8f9fa;
21
- }
22
- .stTabs [data-baseweb="tab-list"] {
23
- gap: 24px;
24
- }
25
- .stTabs [data-baseweb="tab"] {
26
- height: 50px;
27
- padding-left: 20px;
28
- padding-right: 20px;
29
- }
30
  .title-font {
31
  font-size: 28px !important;
32
  font-weight: bold;
@@ -41,104 +30,112 @@ def load_wordllama():
41
 
42
  wl = load_wordllama()
43
 
44
- def create_visualization(texts, embeddings):
45
- """Create appropriate visualization based on number of samples"""
46
- n_samples = len(embeddings)
47
 
48
- # Create DataFrame with original embeddings
49
- df = pd.DataFrame(embeddings)
50
- df['text'] = texts
 
 
 
 
 
51
 
52
- if n_samples == 2:
53
- # For 2 samples, create a 2D visualization
54
  fig = go.Figure()
55
 
56
- # Add points
 
57
  fig.add_trace(go.Scatter(
58
- x=[0, 1],
59
- y=[0, wl.similarity(texts[0], texts[1])],
60
  mode='markers+text',
61
  text=texts,
62
  textposition='top center',
63
- marker=dict(size=10)
64
  ))
65
 
66
  fig.update_layout(
67
- title="Text Similarity Visualization",
68
- xaxis_title="Position",
69
- yaxis_title="Similarity",
70
  height=400,
71
- showlegend=False
 
 
72
  )
73
 
74
  else:
75
- # For 3 or more samples, use PCA for 3D visualization
76
- pca = PCA(n_components=min(3, n_samples))
77
- embeddings_reduced = pca.fit_transform(embeddings)
78
-
79
- # Pad with zeros if needed
80
- if embeddings_reduced.shape[1] < 3:
81
- padding = np.zeros((embeddings_reduced.shape[0], 3 - embeddings_reduced.shape[1]))
82
- embeddings_reduced = np.hstack([embeddings_reduced, padding])
83
 
84
  # Create DataFrame for plotting
85
- df_plot = pd.DataFrame(
86
- embeddings_reduced,
87
  columns=['X', 'Y', 'Z']
88
  )
89
- df_plot['text'] = texts
90
 
91
- fig = px.scatter_3d(
92
- df_plot, x='X', y='Y', z='Z',
93
- text='text',
94
- title='Text Embeddings Visualization'
95
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- fig.update_traces(
98
- marker=dict(size=8, opacity=0.8),
99
- textposition='top center'
100
- )
101
  fig.update_layout(
 
102
  scene=dict(
103
- xaxis_title='Component 1',
104
- yaxis_title='Component 2',
105
- zaxis_title='Component 3'
106
  ),
107
  height=700
108
  )
109
 
110
  return fig
111
 
112
- def create_similarity_matrix(texts):
113
- n = len(texts)
114
- similarity_matrix = np.zeros((n, n))
115
-
116
- for i in range(n):
117
- for j in range(n):
118
- similarity_matrix[i][j] = wl.similarity(texts[i], texts[j])
119
-
120
- fig = go.Figure(data=go.Heatmap(
121
- z=similarity_matrix,
122
- x=texts,
123
- y=texts,
124
- colorscale='Viridis',
125
- text=np.round(similarity_matrix, 3),
126
- texttemplate='%{text}',
127
- textfont={"size": 10},
128
- ))
129
-
130
- fig.update_layout(
131
- title="Similarity Matrix",
132
- height=400
133
- )
134
-
135
- return fig
136
-
137
  def main():
138
- st.title("πŸ¦™ WordLlama Embedding Explorer")
139
- st.markdown("<p class='title-font'>Explore the power of WordLlama embeddings</p>",
140
  unsafe_allow_html=True)
141
 
 
 
 
 
 
 
 
 
142
  tabs = st.tabs(["πŸ’« Text Similarity", "🎯 Multi-Text Analysis"])
143
 
144
  with tabs[0]:
@@ -146,14 +143,21 @@ def main():
146
 
147
  col1, col2 = st.columns(2)
148
  with col1:
149
- text1 = st.text_area("First Text", value="I love programming in Python", height=100)
 
 
 
 
150
  with col2:
151
- text2 = st.text_area("Second Text", value="Coding with Python is amazing", height=100)
 
 
 
 
152
 
153
- if st.button("Calculate Similarity", key="sim_button"):
154
  similarity = wl.similarity(text1, text2)
155
 
156
- st.markdown("### Results")
157
  col1, col2 = st.columns(2)
158
 
159
  with col1:
@@ -164,60 +168,94 @@ def main():
164
  )
165
 
166
  interpretation = (
167
- "Very Similar" if similarity > 0.8
168
- else "Moderately Similar" if similarity > 0.5
169
- else "Different"
170
  )
171
  st.info(f"Interpretation: {interpretation}")
172
 
173
  with col2:
174
- embeddings = wl.embed([text1, text2])
175
  st.plotly_chart(
176
- create_visualization([text1, text2], embeddings),
177
  use_container_width=True
178
  )
179
 
180
  with tabs[1]:
181
  st.markdown("### Analyze Multiple Texts")
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  num_texts = st.slider("Number of texts:", 2, 6, 3)
184
  texts = []
185
 
186
  for i in range(num_texts):
 
 
 
187
  text = st.text_area(
188
- f"Text {i+1}",
189
- value=f"Example text {i+1}",
190
  height=100,
191
  key=f"text_{i}"
192
  )
193
  texts.append(text)
194
 
195
  if st.button("Analyze Texts", key="analyze_button"):
196
- embeddings = wl.embed(texts)
197
-
198
- st.markdown("### Visualization")
199
  st.plotly_chart(
200
- create_visualization(texts, embeddings),
201
  use_container_width=True
202
  )
203
 
 
204
  st.markdown("### Similarity Matrix")
205
- st.plotly_chart(
206
- create_similarity_matrix(texts),
207
- use_container_width=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  )
209
 
210
- # Pairwise similarity analysis
211
- st.markdown("### Pairwise Similarities")
212
- for i in range(len(texts)):
213
- for j in range(i+1, len(texts)):
214
- similarity = wl.similarity(texts[i], texts[j])
215
- interpretation = (
216
- "🟒 Very Similar" if similarity > 0.8
217
- else "🟑 Moderately Similar" if similarity > 0.5
218
- else "πŸ”΄ Different"
219
- )
220
- st.write(f"{interpretation} ({similarity:.3f}): Text {i+1} vs Text {j+1}")
221
 
222
  if __name__ == "__main__":
223
  main()
 
3
  from wordllama import WordLlama
4
  import plotly.graph_objects as go
5
  import plotly.express as px
6
+ from sklearn.manifold import MDS
7
  import pandas as pd
8
 
9
  # Page configuration
 
16
  # Custom CSS
17
  st.markdown("""
18
  <style>
 
 
 
 
 
 
 
 
 
 
 
19
  .title-font {
20
  font-size: 28px !important;
21
  font-weight: bold;
 
30
 
31
  wl = load_wordllama()
32
 
33
+ def create_similarity_based_visualization(texts):
34
+ """Create visualization based on similarity distances"""
35
+ n = len(texts)
36
 
37
+ # Create similarity matrix
38
+ similarity_matrix = np.zeros((n, n))
39
+ for i in range(n):
40
+ for j in range(n):
41
+ similarity_matrix[i][j] = wl.similarity(texts[i], texts[j])
42
+
43
+ # Convert similarities to distances (1 - similarity)
44
+ distance_matrix = 1 - similarity_matrix
45
 
46
+ if n == 2:
47
+ # For 2 texts, create a simple 2D visualization
48
  fig = go.Figure()
49
 
50
+ # Place points based on similarity
51
+ similarity = similarity_matrix[0][1]
52
  fig.add_trace(go.Scatter(
53
+ x=[0, 1-similarity], # Distance proportional to similarity
54
+ y=[0, 0],
55
  mode='markers+text',
56
  text=texts,
57
  textposition='top center',
58
+ marker=dict(size=10, color=['blue', 'red'])
59
  ))
60
 
61
  fig.update_layout(
62
+ title=f"Text Similarity Visualization (Similarity: {similarity:.3f})",
63
+ xaxis_title="Relative Distance",
64
+ yaxis_title="",
65
  height=400,
66
+ showlegend=False,
67
+ xaxis=dict(range=[-0.1, 1.1]),
68
+ yaxis=dict(range=[-0.5, 0.5])
69
  )
70
 
71
  else:
72
+ # For 3 or more texts, use MDS for 3D visualization
73
+ mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
74
+ coords = mds.fit_transform(distance_matrix)
 
 
 
 
 
75
 
76
  # Create DataFrame for plotting
77
+ df = pd.DataFrame(
78
+ coords,
79
  columns=['X', 'Y', 'Z']
80
  )
81
+ df['text'] = texts
82
 
83
+ # Create 3D scatter plot
84
+ fig = go.Figure(data=[go.Scatter3d(
85
+ x=df['X'],
86
+ y=df['Y'],
87
+ z=df['Z'],
88
+ mode='markers+text',
89
+ text=texts,
90
+ textposition='top center',
91
+ marker=dict(
92
+ size=10,
93
+ color=list(range(len(texts))),
94
+ colorscale='Viridis',
95
+ opacity=0.8
96
+ )
97
+ )])
98
+
99
+ # Add lines between points to show connections
100
+ for i in range(n):
101
+ for j in range(i+1, n):
102
+ fig.add_trace(go.Scatter3d(
103
+ x=[coords[i,0], coords[j,0]],
104
+ y=[coords[i,1], coords[j,1]],
105
+ z=[coords[i,2], coords[j,2]],
106
+ mode='lines',
107
+ line=dict(
108
+ color=f'rgba(100,100,100,{similarity_matrix[i,j]:.2f})',
109
+ width=2
110
+ ),
111
+ showlegend=False
112
+ ))
113
 
 
 
 
 
114
  fig.update_layout(
115
+ title="3D Similarity Visualization",
116
  scene=dict(
117
+ xaxis_title="Dimension 1",
118
+ yaxis_title="Dimension 2",
119
+ zaxis_title="Dimension 3"
120
  ),
121
  height=700
122
  )
123
 
124
  return fig
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def main():
127
+ st.title("πŸ¦™ WordLlama Similarity Explorer")
128
+ st.markdown("<p class='title-font'>Visualize text similarities in 3D space</p>",
129
  unsafe_allow_html=True)
130
 
131
+ with st.expander("ℹ️ How to interpret the visualization", expanded=True):
132
+ st.markdown("""
133
+ - **Distance between points** represents dissimilarity (farther = less similar)
134
+ - **Line opacity** indicates similarity strength (darker = more similar)
135
+ - **Colors** help distinguish different texts
136
+ - **Hover** over points to see full text content
137
+ """)
138
+
139
  tabs = st.tabs(["πŸ’« Text Similarity", "🎯 Multi-Text Analysis"])
140
 
141
  with tabs[0]:
 
143
 
144
  col1, col2 = st.columns(2)
145
  with col1:
146
+ text1 = st.text_area(
147
+ "First Text",
148
+ value="I love programming in Python",
149
+ height=100
150
+ )
151
  with col2:
152
+ text2 = st.text_area(
153
+ "Second Text",
154
+ value="Coding with Python is amazing",
155
+ height=100
156
+ )
157
 
158
+ if st.button("Analyze Similarity", key="sim_button"):
159
  similarity = wl.similarity(text1, text2)
160
 
 
161
  col1, col2 = st.columns(2)
162
 
163
  with col1:
 
168
  )
169
 
170
  interpretation = (
171
+ "🟒 Very Similar" if similarity > 0.8
172
+ else "🟑 Moderately Similar" if similarity > 0.5
173
+ else "πŸ”΄ Different"
174
  )
175
  st.info(f"Interpretation: {interpretation}")
176
 
177
  with col2:
 
178
  st.plotly_chart(
179
+ create_similarity_based_visualization([text1, text2]),
180
  use_container_width=True
181
  )
182
 
183
  with tabs[1]:
184
  st.markdown("### Analyze Multiple Texts")
185
 
186
+ # Example templates
187
+ examples = {
188
+ "Similar Texts": [
189
+ "I love programming in Python",
190
+ "Python programming is my passion",
191
+ "I enjoy coding with Python"
192
+ ],
193
+ "Mixed Similarity": [
194
+ "The cat sleeps on the mat",
195
+ "A cat is sleeping on the rug",
196
+ "Python is a programming language"
197
+ ],
198
+ "Different Topics": [
199
+ "The weather is sunny today",
200
+ "Python is a programming language",
201
+ "Cats are wonderful pets"
202
+ ]
203
+ }
204
+
205
+ col1, col2 = st.columns([3, 1])
206
+ with col1:
207
+ selected_example = st.selectbox(
208
+ "Choose an example set:",
209
+ list(examples.keys())
210
+ )
211
+ with col2:
212
+ if st.button("Load Example"):
213
+ st.session_state.texts = examples[selected_example]
214
+
215
  num_texts = st.slider("Number of texts:", 2, 6, 3)
216
  texts = []
217
 
218
  for i in range(num_texts):
219
+ default_text = (examples[selected_example][i]
220
+ if selected_example in examples and i < len(examples[selected_example])
221
+ else f"Example text {i+1}")
222
  text = st.text_area(
223
+ f"Text {i+1}",
224
+ value=default_text,
225
  height=100,
226
  key=f"text_{i}"
227
  )
228
  texts.append(text)
229
 
230
  if st.button("Analyze Texts", key="analyze_button"):
 
 
 
231
  st.plotly_chart(
232
+ create_similarity_based_visualization(texts),
233
  use_container_width=True
234
  )
235
 
236
+ # Show similarity matrix
237
  st.markdown("### Similarity Matrix")
238
+ similarity_matrix = np.zeros((len(texts), len(texts)))
239
+ for i in range(len(texts)):
240
+ for j in range(len(texts)):
241
+ similarity_matrix[i][j] = wl.similarity(texts[i], texts[j])
242
+
243
+ fig = go.Figure(data=go.Heatmap(
244
+ z=similarity_matrix,
245
+ x=[f"Text {i+1}" for i in range(len(texts))],
246
+ y=[f"Text {i+1}" for i in range(len(texts))],
247
+ colorscale='Viridis',
248
+ text=np.round(similarity_matrix, 3),
249
+ texttemplate='%{text}',
250
+ textfont={"size": 12},
251
+ ))
252
+
253
+ fig.update_layout(
254
+ title="Similarity Matrix",
255
+ height=400
256
  )
257
 
258
+ st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
259
 
260
  if __name__ == "__main__":
261
  main()