# NOTE: header text captured from the hosting page; it is not part of the program.
# Uploader: DexterSptizu | Commit: "Update app.py" (005b8a9, verified) | Size: 8.43 kB
import streamlit as st
import numpy as np
from wordllama import WordLlama
import plotly.graph_objects as go
import plotly.express as px
from sklearn.manifold import MDS
import pandas as pd
# Page configuration: browser-tab title and icon, full-width layout.
# The icon literal was mis-decoded UTF-8 (bytes F0 9F A6 99); it is the
# llama emoji.
st.set_page_config(
    page_title="WordLlama Explorer",
    page_icon="🦙",
    layout="wide",
)

# Custom CSS: defines the `title-font` class applied to the subtitle
# paragraph rendered in main(). unsafe_allow_html is needed so the <style>
# tag is emitted as HTML instead of being escaped.
st.markdown(
    """
<style>
.title-font {
font-size: 28px !important;
font-weight: bold;
color: #2c3e50;
}
</style>
""",
    unsafe_allow_html=True,
)
@st.cache_resource
def load_wordllama():
    """Load and return the default WordLlama model.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    object instead of calling ``WordLlama.load()`` on every script rerun.
    """
    model = WordLlama.load()
    return model


# Shared model handle used by the similarity helpers in this module.
wl = load_wordllama()
def create_similarity_based_visualization(texts):
    """Build a Plotly figure that positions texts by pairwise similarity.

    Pairwise scores come from ``wl.similarity`` and are converted to
    distances as ``1 - similarity`` (floored at 0).

    * Exactly two texts: a 2D scatter with the first text at the origin and
      the second at ``x = distance``; the title shows the similarity score.
    * Three or more texts: a 3D scatter whose coordinates are produced by
      MDS on the precomputed distance matrix, plus one line per pair whose
      opacity follows that pair's similarity.
    * Fewer than two texts: the 3D figure is returned with every point at
      the origin, since there is no pair to embed.

    Args:
        texts: Sequence of strings to compare.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    n = len(texts)

    # Score each unordered pair once and mirror it. This halves the number
    # of model calls and guarantees an exactly symmetric matrix, which MDS
    # with a precomputed dissimilarity expects.
    # NOTE(review): assumes wl.similarity(a, b) == wl.similarity(b, a).
    similarity_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i, n):
            score = wl.similarity(texts[i], texts[j])
            similarity_matrix[i, j] = score
            similarity_matrix[j, i] = score

    # Convert similarities to distances; floor at 0 so rounding that pushes
    # a score slightly above 1 cannot produce a negative distance.
    distance_matrix = np.clip(1 - similarity_matrix, 0, None)

    if n == 2:
        similarity = similarity_matrix[0, 1]
        distance = distance_matrix[0, 1]

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=[0, distance],  # Gap between the points equals the distance
            y=[0, 0],
            mode='markers+text',
            text=texts,
            textposition='top center',
            marker=dict(size=10, color=['blue', 'red']),
        ))
        fig.update_layout(
            title=f"Text Similarity Visualization (Similarity: {similarity:.3f})",
            xaxis_title="Relative Distance",
            yaxis_title="",
            height=400,
            showlegend=False,
            # Widen the axis when the distance exceeds 1 (negative
            # similarity) so the second point is never drawn off-screen.
            xaxis=dict(range=[-0.1, max(1.1, distance + 0.1)]),
            yaxis=dict(range=[-0.5, 0.5]),
        )
        return fig

    if n > 2:
        # Embed the texts in 3D so that layout distances approximate the
        # pairwise distances; fixed seed keeps the layout reproducible.
        mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
        coords = mds.fit_transform(distance_matrix)
    else:
        # Zero or one text: nothing to embed, place any point at the origin.
        coords = np.zeros((n, 3))

    fig = go.Figure(data=[go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(
            size=10,
            color=list(range(n)),
            colorscale='Viridis',
            opacity=0.8,
        ),
    )])

    # Add lines between points to show connections.
    for i in range(n):
        for j in range(i + 1, n):
            # Keep alpha inside [0, 1]: scores are not guaranteed to be
            # non-negative, and an out-of-range alpha would give an rgba()
            # string that Plotly's color validation is not expected to accept.
            alpha = min(max(float(similarity_matrix[i, j]), 0.0), 1.0)
            fig.add_trace(go.Scatter3d(
                x=[coords[i, 0], coords[j, 0]],
                y=[coords[i, 1], coords[j, 1]],
                z=[coords[i, 2], coords[j, 2]],
                mode='lines',
                line=dict(
                    color=f'rgba(100,100,100,{alpha:.2f})',
                    width=2,
                ),
                showlegend=False,
            ))

    fig.update_layout(
        title="3D Similarity Visualization",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
        ),
        height=700,
    )
    return fig
def main():
    """Render the WordLlama Similarity Explorer page.

    The page has two tabs:

    * "Text Similarity": two text boxes; on click, shows the similarity
      score, a coarse interpretation, and the two-point plot.
    * "Multi-Text Analysis": two to six text boxes that can be filled from
      an example set; on click, shows the 3D plot and a similarity heatmap.

    Blank inputs are reported with a warning instead of being scored.
    """
    max_texts = 6

    st.title("🦙 WordLlama Similarity Explorer")
    st.markdown(
        "<p class='title-font'>Visualize text similarities in 3D space</p>",
        unsafe_allow_html=True,
    )

    with st.expander("ℹ️ How to interpret the visualization", expanded=True):
        st.markdown("""
- **Distance between points** represents dissimilarity (farther = less similar)
- **Line opacity** indicates similarity strength (darker = more similar)
- **Colors** help distinguish different texts
- **Hover** over points to see full text content
""")

    tabs = st.tabs(["💫 Text Similarity", "🎯 Multi-Text Analysis"])

    with tabs[0]:
        st.markdown("### Compare Two Texts")
        col1, col2 = st.columns(2)
        with col1:
            text1 = st.text_area(
                "First Text",
                value="I love programming in Python",
                height=100,
            )
        with col2:
            text2 = st.text_area(
                "Second Text",
                value="Coding with Python is amazing",
                height=100,
            )

        if st.button("Analyze Similarity", key="sim_button"):
            if not text1.strip() or not text2.strip():
                # A score for an empty text carries no meaning, so ask for
                # input rather than passing blanks to the model.
                st.warning("Please enter text in both boxes before analyzing.")
            else:
                similarity = wl.similarity(text1, text2)
                col1, col2 = st.columns(2)
                with col1:
                    st.metric(
                        label="Similarity Score",
                        value=f"{similarity:.4f}",
                        help="1.0 = identical, 0.0 = completely different",
                    )
                    if similarity > 0.8:
                        interpretation = "🟢 Very Similar"
                    elif similarity > 0.5:
                        interpretation = "🟡 Moderately Similar"
                    else:
                        interpretation = "🔴 Different"
                    st.info(f"Interpretation: {interpretation}")
                with col2:
                    st.plotly_chart(
                        create_similarity_based_visualization([text1, text2]),
                        use_container_width=True,
                    )

    with tabs[1]:
        st.markdown("### Analyze Multiple Texts")

        # Example templates
        examples = {
            "Similar Texts": [
                "I love programming in Python",
                "Python programming is my passion",
                "I enjoy coding with Python",
            ],
            "Mixed Similarity": [
                "The cat sleeps on the mat",
                "A cat is sleeping on the rug",
                "Python is a programming language",
            ],
            "Different Topics": [
                "The weather is sunny today",
                "Python is a programming language",
                "Cats are wonderful pets",
            ],
        }

        col1, col2 = st.columns([3, 1])
        with col1:
            selected_example = st.selectbox(
                "Choose an example set:",
                list(examples.keys()),
            )
        chosen = examples[selected_example]

        def default_text(index):
            # Example sentence for this slot, or a placeholder past the end.
            if index < len(chosen):
                return chosen[index]
            return f"Example text {index+1}"

        with col2:
            if st.button("Load Example"):
                st.session_state.texts = chosen
                # The text areas below are keyed widgets, so their content
                # lives in session state under "text_<i>". Writing there,
                # before the widgets are created in this run, is what
                # actually replaces the boxes' contents.
                for i in range(max_texts):
                    st.session_state[f"text_{i}"] = default_text(i)

        num_texts = st.slider("Number of texts:", 2, max_texts, 3)

        texts = []
        for i in range(num_texts):
            key = f"text_{i}"
            # Seed the widget through session state rather than `value=` so
            # the "Load Example" assignment above and the initial default
            # use the same mechanism.
            if key not in st.session_state:
                st.session_state[key] = default_text(i)
            texts.append(st.text_area(f"Text {i+1}", height=100, key=key))

        if st.button("Analyze Texts", key="analyze_button"):
            if any(not entry.strip() for entry in texts):
                st.warning("Please enter text in every box before analyzing.")
            else:
                st.plotly_chart(
                    create_similarity_based_visualization(texts),
                    use_container_width=True,
                )

                # Show similarity matrix
                st.markdown("### Similarity Matrix")
                count = len(texts)
                similarity_matrix = np.zeros((count, count))
                # Score each unordered pair once and mirror the result.
                # NOTE(review): assumes wl.similarity is symmetric.
                for i in range(count):
                    for j in range(i, count):
                        score = wl.similarity(texts[i], texts[j])
                        similarity_matrix[i, j] = score
                        similarity_matrix[j, i] = score

                labels = [f"Text {i+1}" for i in range(count)]
                fig = go.Figure(data=go.Heatmap(
                    z=similarity_matrix,
                    x=labels,
                    y=labels,
                    colorscale='Viridis',
                    text=np.round(similarity_matrix, 3),
                    texttemplate='%{text}',
                    textfont={"size": 12},
                ))
                fig.update_layout(
                    title="Similarity Matrix",
                    height=400,
                )
                st.plotly_chart(fig, use_container_width=True)
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()