# Page header captured along with this file (not part of the program):
# DexterSptizu's picture
# Update app.py
# 9cafd02 verified
# raw | history | blame
# 8.92 kB
import streamlit as st
import numpy as np
from wordllama import WordLlama
import plotly.graph_objects as go
import plotly.express as px
from sklearn.manifold import MDS
import pandas as pd
# Page configuration: browser-tab title, tab icon and wide layout.
# (The icon literal previously held mis-encoded bytes of the llama emoji.)
st.set_page_config(
    page_title="WordLlama Explorer",
    page_icon="🦙",
    layout="wide",
)

# Custom CSS: defines the "title-font" class used for the subtitle in main().
st.markdown("""
<style>
.title-font {
font-size: 28px !important;
font-weight: bold;
color: #2c3e50;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_wordllama():
    """Load and return the default WordLlama model.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    model object instead of calling ``WordLlama.load()`` again.
    """
    model = WordLlama.load()
    return model


# Shared model instance used by every similarity computation in this app.
wl = load_wordllama()
def _pairwise_similarity(texts):
    """Return the symmetric (n, n) matrix of ``wl.similarity`` scores for *texts*.

    Only the upper triangle is computed and then mirrored. This assumes
    ``wl.similarity(a, b) == wl.similarity(b, a)`` (TODO confirm for the
    WordLlama version in use). The diagonal is fixed at 1.0 (a text is
    identical to itself). Non-finite scores are replaced so the downstream
    MDS never sees NaN/inf; whether ``wl.similarity`` can produce them, e.g.
    for empty text, is not verified here.
    """
    n = len(texts)
    matrix = np.eye(n)
    for i in range(n):
        for j in range(i + 1, n):
            matrix[i, j] = matrix[j, i] = wl.similarity(texts[i], texts[j])
    return np.nan_to_num(matrix, nan=0.0, posinf=1.0, neginf=-1.0)


def _two_text_figure(texts, similarity):
    """Plot two texts on a horizontal line, ``1 - similarity`` apart."""
    distance = 1 - similarity
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=[0, distance],
        y=[0, 0],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(
            size=10,
            color=['blue', 'red'],
        ),
    ))
    fig.update_layout(
        title=f"Text Similarity Visualization (Similarity: {similarity:.3f})",
        xaxis_title="Relative Distance",
        yaxis_title="",
        height=400,
        showlegend=False,
        # Default range is [-0.1, 1.1]; widen it when a negative similarity
        # (distance > 1) would otherwise push the second point off-screen.
        xaxis=dict(range=[min(-0.1, distance - 0.1), max(1.1, distance + 0.1)]),
        yaxis=dict(range=[-0.5, 0.5]),
    )
    return fig


def _multi_text_figure(texts, similarity_matrix):
    """Embed 3+ texts in 3D with metric MDS and connect every pair with a line.

    Line opacity follows the pair's similarity, clamped to [0.1, 1.0] so that
    dissimilar pairs stay faintly visible.
    """
    n = len(texts)
    # Similarity -> dissimilarity. Clip float noise so no distance is negative
    # and force self-distance to exactly 0, as expected for a dissimilarity matrix.
    distance_matrix = np.clip(1 - similarity_matrix, 0.0, None)
    np.fill_diagonal(distance_matrix, 0.0)

    mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
    coords = mds.fit_transform(distance_matrix)

    fig = go.Figure()
    fig.add_trace(go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(
            size=10,
            color=list(range(n)),
            colorscale='Viridis',
            opacity=0.8,
        ),
        name='Texts',
    ))

    for i in range(n):
        for j in range(i + 1, n):
            opacity = max(0.1, min(1.0, similarity_matrix[i, j]))
            fig.add_trace(go.Scatter3d(
                x=[coords[i, 0], coords[j, 0]],
                y=[coords[i, 1], coords[j, 1]],
                z=[coords[i, 2], coords[j, 2]],
                mode='lines',
                line=dict(
                    color='gray',
                    width=2,
                ),
                opacity=opacity,
                showlegend=False,
                hoverinfo='skip',
            ))

    fig.update_layout(
        title="3D Similarity Visualization",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=1.5, y=1.5, z=1.5),
            ),
        ),
        height=700,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
        ),
    )
    return fig


def create_similarity_based_visualization(texts):
    """Create a Plotly figure that places *texts* according to their similarity.

    Two texts are drawn on a 1D line separated by ``1 - similarity``. Three or
    more texts are embedded in 3D with MDS on the ``1 - similarity`` distances.

    Args:
        texts: sequence of at least two strings.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        ValueError: if fewer than two texts are given. There is nothing to
            compare, and MDS would fail on a single point.
    """
    if len(texts) < 2:
        raise ValueError(
            "At least two texts are required for a similarity visualization"
        )
    similarity_matrix = _pairwise_similarity(texts)
    if len(texts) == 2:
        return _two_text_figure(texts, similarity_matrix[0, 1])
    return _multi_text_figure(texts, similarity_matrix)
def _render_header():
    """Render the page title, subtitle and the guide for reading the plots."""
    st.title("🦙 WordLlama Similarity Explorer")
    # 'title-font' is styled by the custom CSS injected at module level.
    st.markdown("<p class='title-font'>Visualize text similarities in 3D space</p>",
                unsafe_allow_html=True)
    with st.expander("ℹ️ How to interpret the visualization", expanded=True):
        st.markdown("""
- **Distance between points** represents dissimilarity (farther = less similar)
- **Line opacity** indicates similarity strength (darker = more similar)
- **Colors** help distinguish different texts
- **Hover** over points to see full text content
""")


def _render_pair_tab():
    """Compare two texts: show the score, a verbal interpretation and a plot."""
    st.markdown("### Compare Two Texts")
    col1, col2 = st.columns(2)
    with col1:
        text1 = st.text_area(
            "First Text",
            value="I love programming in Python",
            height=100,
        )
    with col2:
        text2 = st.text_area(
            "Second Text",
            value="Coding with Python is amazing",
            height=100,
        )

    if not st.button("Analyze Similarity", key="sim_button"):
        return
    if not text1.strip() or not text2.strip():
        st.warning("Please enter both texts before analyzing.")
        return

    similarity = wl.similarity(text1, text2)
    col1, col2 = st.columns(2)
    with col1:
        st.metric(
            label="Similarity Score",
            value=f"{similarity:.4f}",
            help="1.0 = identical, 0.0 = completely different",
        )
        interpretation = (
            "🟢 Very Similar" if similarity > 0.8
            else "🟡 Moderately Similar" if similarity > 0.5
            else "🔴 Different"
        )
        st.info(f"Interpretation: {interpretation}")
    with col2:
        st.plotly_chart(
            create_similarity_based_visualization([text1, text2]),
            use_container_width=True,
        )


def _render_multi_text_tab():
    """Analyze 2-6 texts: a 3D/2D similarity plot plus a similarity heatmap.

    The text areas are keyed ``text_0`` ... ``text_5``. Their contents live in
    ``st.session_state`` and are seeded from the selected example the first
    time each area appears. Pressing "Load Example" overwrites them with the
    selected example set.
    """
    st.markdown("### Analyze Multiple Texts")
    examples = {
        "Similar Texts": [
            "I love programming in Python",
            "Python programming is my passion",
            "I enjoy coding with Python",
        ],
        "Mixed Similarity": [
            "The cat sleeps on the mat",
            "A cat is sleeping on the rug",
            "Python is a programming language",
        ],
        "Different Topics": [
            "The weather is sunny today",
            "Python is a programming language",
            "Cats are wonderful pets",
        ],
    }

    col1, col2 = st.columns([3, 1])
    with col1:
        selected_example = st.selectbox(
            "Choose an example set:",
            list(examples.keys()),
        )
    example_texts = examples[selected_example]
    with col2:
        # Previously this stored st.session_state.texts, which nothing read,
        # so the button had no effect. Writing the widgets' own keys works,
        # and is allowed here because the text areas are created further down.
        if st.button("Load Example"):
            for i, sample in enumerate(example_texts):
                st.session_state[f"text_{i}"] = sample

    num_texts = st.slider("Number of texts:", 2, 6, 3)
    texts = []
    for i in range(num_texts):
        key = f"text_{i}"
        if key not in st.session_state:
            st.session_state[key] = (
                example_texts[i] if i < len(example_texts)
                else f"Example text {i+1}"
            )
        # No value= here: the content comes from session state, which avoids
        # Streamlit's "default value + Session State" conflict.
        texts.append(st.text_area(f"Text {i+1}", height=100, key=key))

    if not st.button("Analyze Texts", key="analyze_button"):
        return
    if any(not text.strip() for text in texts):
        st.warning("Please fill in every text before analyzing.")
        return

    st.plotly_chart(
        create_similarity_based_visualization(texts),
        use_container_width=True,
    )

    st.markdown("### Similarity Matrix")
    n = len(texts)
    similarity_matrix = np.zeros((n, n))
    # Fill the upper triangle (diagonal included) and mirror it. This assumes
    # wl.similarity is symmetric in its arguments (TODO confirm).
    for i in range(n):
        for j in range(i, n):
            score = wl.similarity(texts[i], texts[j])
            similarity_matrix[i, j] = similarity_matrix[j, i] = score
    labels = [f"Text {i+1}" for i in range(n)]
    fig = go.Figure(data=go.Heatmap(
        z=similarity_matrix,
        x=labels,
        y=labels,
        colorscale='Viridis',
        text=np.round(similarity_matrix, 3),
        texttemplate='%{text}',
        textfont={"size": 12},
    ))
    fig.update_layout(
        title="Similarity Matrix",
        height=400,
    )
    st.plotly_chart(fig, use_container_width=True)


def main():
    """Streamlit entry point: render the header and the two analysis tabs.

    The emoji in the UI strings previously held mis-encoded bytes.
    """
    _render_header()
    pair_tab, multi_tab = st.tabs(["💫 Text Similarity", "🎯 Multi-Text Analysis"])
    with pair_tab:
        _render_pair_tab()
    with multi_tab:
        _render_multi_text_tab()


if __name__ == "__main__":
    main()