DexterSptizu
committed on
Commit
•
005b8a9
1
Parent(s):
8afde48
Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import numpy as np
|
|
3 |
from wordllama import WordLlama
|
4 |
import plotly.graph_objects as go
|
5 |
import plotly.express as px
|
6 |
-
from sklearn.
|
7 |
import pandas as pd
|
8 |
|
9 |
# Page configuration
|
@@ -16,17 +16,6 @@ st.set_page_config(
|
|
16 |
# Custom CSS
|
17 |
st.markdown("""
|
18 |
<style>
|
19 |
-
.main {
|
20 |
-
background-color: #f8f9fa;
|
21 |
-
}
|
22 |
-
.stTabs [data-baseweb="tab-list"] {
|
23 |
-
gap: 24px;
|
24 |
-
}
|
25 |
-
.stTabs [data-baseweb="tab"] {
|
26 |
-
height: 50px;
|
27 |
-
padding-left: 20px;
|
28 |
-
padding-right: 20px;
|
29 |
-
}
|
30 |
.title-font {
|
31 |
font-size: 28px !important;
|
32 |
font-weight: bold;
|
@@ -41,104 +30,112 @@ def load_wordllama():
|
|
41 |
|
42 |
wl = load_wordllama()
|
43 |
|
44 |
-
def
|
45 |
-
"""Create
|
46 |
-
|
47 |
|
48 |
-
# Create
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
if
|
53 |
-
# For 2
|
54 |
fig = go.Figure()
|
55 |
|
56 |
-
#
|
|
|
57 |
fig.add_trace(go.Scatter(
|
58 |
-
x=[0, 1],
|
59 |
-
y=[0,
|
60 |
mode='markers+text',
|
61 |
text=texts,
|
62 |
textposition='top center',
|
63 |
-
marker=dict(size=10)
|
64 |
))
|
65 |
|
66 |
fig.update_layout(
|
67 |
-
title="Text Similarity Visualization",
|
68 |
-
xaxis_title="
|
69 |
-
yaxis_title="
|
70 |
height=400,
|
71 |
-
showlegend=False
|
|
|
|
|
72 |
)
|
73 |
|
74 |
else:
|
75 |
-
# For 3 or more
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
# Pad with zeros if needed
|
80 |
-
if embeddings_reduced.shape[1] < 3:
|
81 |
-
padding = np.zeros((embeddings_reduced.shape[0], 3 - embeddings_reduced.shape[1]))
|
82 |
-
embeddings_reduced = np.hstack([embeddings_reduced, padding])
|
83 |
|
84 |
# Create DataFrame for plotting
|
85 |
-
|
86 |
-
|
87 |
columns=['X', 'Y', 'Z']
|
88 |
)
|
89 |
-
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
-
fig.update_traces(
|
98 |
-
marker=dict(size=8, opacity=0.8),
|
99 |
-
textposition='top center'
|
100 |
-
)
|
101 |
fig.update_layout(
|
|
|
102 |
scene=dict(
|
103 |
-
xaxis_title=
|
104 |
-
yaxis_title=
|
105 |
-
zaxis_title=
|
106 |
),
|
107 |
height=700
|
108 |
)
|
109 |
|
110 |
return fig
|
111 |
|
112 |
-
def create_similarity_matrix(texts):
|
113 |
-
n = len(texts)
|
114 |
-
similarity_matrix = np.zeros((n, n))
|
115 |
-
|
116 |
-
for i in range(n):
|
117 |
-
for j in range(n):
|
118 |
-
similarity_matrix[i][j] = wl.similarity(texts[i], texts[j])
|
119 |
-
|
120 |
-
fig = go.Figure(data=go.Heatmap(
|
121 |
-
z=similarity_matrix,
|
122 |
-
x=texts,
|
123 |
-
y=texts,
|
124 |
-
colorscale='Viridis',
|
125 |
-
text=np.round(similarity_matrix, 3),
|
126 |
-
texttemplate='%{text}',
|
127 |
-
textfont={"size": 10},
|
128 |
-
))
|
129 |
-
|
130 |
-
fig.update_layout(
|
131 |
-
title="Similarity Matrix",
|
132 |
-
height=400
|
133 |
-
)
|
134 |
-
|
135 |
-
return fig
|
136 |
-
|
137 |
def main():
|
138 |
-
st.title("🦙 WordLlama
|
139 |
-
st.markdown("<p class='title-font'>
|
140 |
unsafe_allow_html=True)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
tabs = st.tabs(["💫 Text Similarity", "🎯 Multi-Text Analysis"])
|
143 |
|
144 |
with tabs[0]:
|
@@ -146,14 +143,21 @@ def main():
|
|
146 |
|
147 |
col1, col2 = st.columns(2)
|
148 |
with col1:
|
149 |
-
text1 = st.text_area(
|
|
|
|
|
|
|
|
|
150 |
with col2:
|
151 |
-
text2 = st.text_area(
|
|
|
|
|
|
|
|
|
152 |
|
153 |
-
if st.button("
|
154 |
similarity = wl.similarity(text1, text2)
|
155 |
|
156 |
-
st.markdown("### Results")
|
157 |
col1, col2 = st.columns(2)
|
158 |
|
159 |
with col1:
|
@@ -164,60 +168,94 @@ def main():
|
|
164 |
)
|
165 |
|
166 |
interpretation = (
|
167 |
-
"Very Similar" if similarity > 0.8
|
168 |
-
else "Moderately Similar" if similarity > 0.5
|
169 |
-
else "Different"
|
170 |
)
|
171 |
st.info(f"Interpretation: {interpretation}")
|
172 |
|
173 |
with col2:
|
174 |
-
embeddings = wl.embed([text1, text2])
|
175 |
st.plotly_chart(
|
176 |
-
|
177 |
use_container_width=True
|
178 |
)
|
179 |
|
180 |
with tabs[1]:
|
181 |
st.markdown("### Analyze Multiple Texts")
|
182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
num_texts = st.slider("Number of texts:", 2, 6, 3)
|
184 |
texts = []
|
185 |
|
186 |
for i in range(num_texts):
|
|
|
|
|
|
|
187 |
text = st.text_area(
|
188 |
-
f"Text {i+1}",
|
189 |
-
value=
|
190 |
height=100,
|
191 |
key=f"text_{i}"
|
192 |
)
|
193 |
texts.append(text)
|
194 |
|
195 |
if st.button("Analyze Texts", key="analyze_button"):
|
196 |
-
embeddings = wl.embed(texts)
|
197 |
-
|
198 |
-
st.markdown("### Visualization")
|
199 |
st.plotly_chart(
|
200 |
-
|
201 |
use_container_width=True
|
202 |
)
|
203 |
|
|
|
204 |
st.markdown("### Similarity Matrix")
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
)
|
209 |
|
210 |
-
|
211 |
-
st.markdown("### Pairwise Similarities")
|
212 |
-
for i in range(len(texts)):
|
213 |
-
for j in range(i+1, len(texts)):
|
214 |
-
similarity = wl.similarity(texts[i], texts[j])
|
215 |
-
interpretation = (
|
216 |
-
"🟢 Very Similar" if similarity > 0.8
|
217 |
-
else "🟡 Moderately Similar" if similarity > 0.5
|
218 |
-
else "🔴 Different"
|
219 |
-
)
|
220 |
-
st.write(f"{interpretation} ({similarity:.3f}): Text {i+1} vs Text {j+1}")
|
221 |
|
222 |
if __name__ == "__main__":
|
223 |
main()
|
|
|
3 |
from wordllama import WordLlama
|
4 |
import plotly.graph_objects as go
|
5 |
import plotly.express as px
|
6 |
+
from sklearn.manifold import MDS
|
7 |
import pandas as pd
|
8 |
|
9 |
# Page configuration
|
|
|
16 |
# Custom CSS
|
17 |
st.markdown("""
|
18 |
<style>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
.title-font {
|
20 |
font-size: 28px !important;
|
21 |
font-weight: bold;
|
|
|
30 |
|
31 |
wl = load_wordllama()
|
32 |
|
33 |
+
def create_similarity_based_visualization(texts):
|
34 |
+
"""Create visualization based on similarity distances"""
|
35 |
+
n = len(texts)
|
36 |
|
37 |
+
# Create similarity matrix
|
38 |
+
similarity_matrix = np.zeros((n, n))
|
39 |
+
for i in range(n):
|
40 |
+
for j in range(n):
|
41 |
+
similarity_matrix[i][j] = wl.similarity(texts[i], texts[j])
|
42 |
+
|
43 |
+
# Convert similarities to distances (1 - similarity)
|
44 |
+
distance_matrix = 1 - similarity_matrix
|
45 |
|
46 |
+
if n == 2:
|
47 |
+
# For 2 texts, create a simple 2D visualization
|
48 |
fig = go.Figure()
|
49 |
|
50 |
+
# Place points based on similarity
|
51 |
+
similarity = similarity_matrix[0][1]
|
52 |
fig.add_trace(go.Scatter(
|
53 |
+
x=[0, 1-similarity], # Distance proportional to similarity
|
54 |
+
y=[0, 0],
|
55 |
mode='markers+text',
|
56 |
text=texts,
|
57 |
textposition='top center',
|
58 |
+
marker=dict(size=10, color=['blue', 'red'])
|
59 |
))
|
60 |
|
61 |
fig.update_layout(
|
62 |
+
title=f"Text Similarity Visualization (Similarity: {similarity:.3f})",
|
63 |
+
xaxis_title="Relative Distance",
|
64 |
+
yaxis_title="",
|
65 |
height=400,
|
66 |
+
showlegend=False,
|
67 |
+
xaxis=dict(range=[-0.1, 1.1]),
|
68 |
+
yaxis=dict(range=[-0.5, 0.5])
|
69 |
)
|
70 |
|
71 |
else:
|
72 |
+
# For 3 or more texts, use MDS for 3D visualization
|
73 |
+
mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
|
74 |
+
coords = mds.fit_transform(distance_matrix)
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
# Create DataFrame for plotting
|
77 |
+
df = pd.DataFrame(
|
78 |
+
coords,
|
79 |
columns=['X', 'Y', 'Z']
|
80 |
)
|
81 |
+
df['text'] = texts
|
82 |
|
83 |
+
# Create 3D scatter plot
|
84 |
+
fig = go.Figure(data=[go.Scatter3d(
|
85 |
+
x=df['X'],
|
86 |
+
y=df['Y'],
|
87 |
+
z=df['Z'],
|
88 |
+
mode='markers+text',
|
89 |
+
text=texts,
|
90 |
+
textposition='top center',
|
91 |
+
marker=dict(
|
92 |
+
size=10,
|
93 |
+
color=list(range(len(texts))),
|
94 |
+
colorscale='Viridis',
|
95 |
+
opacity=0.8
|
96 |
+
)
|
97 |
+
)])
|
98 |
+
|
99 |
+
# Add lines between points to show connections
|
100 |
+
for i in range(n):
|
101 |
+
for j in range(i+1, n):
|
102 |
+
fig.add_trace(go.Scatter3d(
|
103 |
+
x=[coords[i,0], coords[j,0]],
|
104 |
+
y=[coords[i,1], coords[j,1]],
|
105 |
+
z=[coords[i,2], coords[j,2]],
|
106 |
+
mode='lines',
|
107 |
+
line=dict(
|
108 |
+
color=f'rgba(100,100,100,{similarity_matrix[i,j]:.2f})',
|
109 |
+
width=2
|
110 |
+
),
|
111 |
+
showlegend=False
|
112 |
+
))
|
113 |
|
|
|
|
|
|
|
|
|
114 |
fig.update_layout(
|
115 |
+
title="3D Similarity Visualization",
|
116 |
scene=dict(
|
117 |
+
xaxis_title="Dimension 1",
|
118 |
+
yaxis_title="Dimension 2",
|
119 |
+
zaxis_title="Dimension 3"
|
120 |
),
|
121 |
height=700
|
122 |
)
|
123 |
|
124 |
return fig
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
def main():
|
127 |
+
st.title("🦙 WordLlama Similarity Explorer")
|
128 |
+
st.markdown("<p class='title-font'>Visualize text similarities in 3D space</p>",
|
129 |
unsafe_allow_html=True)
|
130 |
|
131 |
+
with st.expander("ℹ️ How to interpret the visualization", expanded=True):
|
132 |
+
st.markdown("""
|
133 |
+
- **Distance between points** represents dissimilarity (farther = less similar)
|
134 |
+
- **Line opacity** indicates similarity strength (darker = more similar)
|
135 |
+
- **Colors** help distinguish different texts
|
136 |
+
- **Hover** over points to see full text content
|
137 |
+
""")
|
138 |
+
|
139 |
tabs = st.tabs(["💫 Text Similarity", "🎯 Multi-Text Analysis"])
|
140 |
|
141 |
with tabs[0]:
|
|
|
143 |
|
144 |
col1, col2 = st.columns(2)
|
145 |
with col1:
|
146 |
+
text1 = st.text_area(
|
147 |
+
"First Text",
|
148 |
+
value="I love programming in Python",
|
149 |
+
height=100
|
150 |
+
)
|
151 |
with col2:
|
152 |
+
text2 = st.text_area(
|
153 |
+
"Second Text",
|
154 |
+
value="Coding with Python is amazing",
|
155 |
+
height=100
|
156 |
+
)
|
157 |
|
158 |
+
if st.button("Analyze Similarity", key="sim_button"):
|
159 |
similarity = wl.similarity(text1, text2)
|
160 |
|
|
|
161 |
col1, col2 = st.columns(2)
|
162 |
|
163 |
with col1:
|
|
|
168 |
)
|
169 |
|
170 |
interpretation = (
|
171 |
+
"🟢 Very Similar" if similarity > 0.8
|
172 |
+
else "🟡 Moderately Similar" if similarity > 0.5
|
173 |
+
else "🔴 Different"
|
174 |
)
|
175 |
st.info(f"Interpretation: {interpretation}")
|
176 |
|
177 |
with col2:
|
|
|
178 |
st.plotly_chart(
|
179 |
+
create_similarity_based_visualization([text1, text2]),
|
180 |
use_container_width=True
|
181 |
)
|
182 |
|
183 |
with tabs[1]:
|
184 |
st.markdown("### Analyze Multiple Texts")
|
185 |
|
186 |
+
# Example templates
|
187 |
+
examples = {
|
188 |
+
"Similar Texts": [
|
189 |
+
"I love programming in Python",
|
190 |
+
"Python programming is my passion",
|
191 |
+
"I enjoy coding with Python"
|
192 |
+
],
|
193 |
+
"Mixed Similarity": [
|
194 |
+
"The cat sleeps on the mat",
|
195 |
+
"A cat is sleeping on the rug",
|
196 |
+
"Python is a programming language"
|
197 |
+
],
|
198 |
+
"Different Topics": [
|
199 |
+
"The weather is sunny today",
|
200 |
+
"Python is a programming language",
|
201 |
+
"Cats are wonderful pets"
|
202 |
+
]
|
203 |
+
}
|
204 |
+
|
205 |
+
col1, col2 = st.columns([3, 1])
|
206 |
+
with col1:
|
207 |
+
selected_example = st.selectbox(
|
208 |
+
"Choose an example set:",
|
209 |
+
list(examples.keys())
|
210 |
+
)
|
211 |
+
with col2:
|
212 |
+
if st.button("Load Example"):
|
213 |
+
st.session_state.texts = examples[selected_example]
|
214 |
+
|
215 |
num_texts = st.slider("Number of texts:", 2, 6, 3)
|
216 |
texts = []
|
217 |
|
218 |
for i in range(num_texts):
|
219 |
+
default_text = (examples[selected_example][i]
|
220 |
+
if selected_example in examples and i < len(examples[selected_example])
|
221 |
+
else f"Example text {i+1}")
|
222 |
text = st.text_area(
|
223 |
+
f"Text {i+1}",
|
224 |
+
value=default_text,
|
225 |
height=100,
|
226 |
key=f"text_{i}"
|
227 |
)
|
228 |
texts.append(text)
|
229 |
|
230 |
if st.button("Analyze Texts", key="analyze_button"):
|
|
|
|
|
|
|
231 |
st.plotly_chart(
|
232 |
+
create_similarity_based_visualization(texts),
|
233 |
use_container_width=True
|
234 |
)
|
235 |
|
236 |
+
# Show similarity matrix
|
237 |
st.markdown("### Similarity Matrix")
|
238 |
+
similarity_matrix = np.zeros((len(texts), len(texts)))
|
239 |
+
for i in range(len(texts)):
|
240 |
+
for j in range(len(texts)):
|
241 |
+
similarity_matrix[i][j] = wl.similarity(texts[i], texts[j])
|
242 |
+
|
243 |
+
fig = go.Figure(data=go.Heatmap(
|
244 |
+
z=similarity_matrix,
|
245 |
+
x=[f"Text {i+1}" for i in range(len(texts))],
|
246 |
+
y=[f"Text {i+1}" for i in range(len(texts))],
|
247 |
+
colorscale='Viridis',
|
248 |
+
text=np.round(similarity_matrix, 3),
|
249 |
+
texttemplate='%{text}',
|
250 |
+
textfont={"size": 12},
|
251 |
+
))
|
252 |
+
|
253 |
+
fig.update_layout(
|
254 |
+
title="Similarity Matrix",
|
255 |
+
height=400
|
256 |
)
|
257 |
|
258 |
+
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
|
260 |
if __name__ == "__main__":
|
261 |
main()
|