import streamlit as st |
import numpy as np |
from wordllama import WordLlama |
import plotly.graph_objects as go |
import plotly.express as px |
from sklearn.manifold import MDS |
import pandas as pd |
# Browser-tab title and icon, plus the wide page layout used by the app.
st.set_page_config(
    page_title="WordLlama Explorer",
    page_icon="🦙",
    layout="wide",
)

# CSS for the `title-font` class applied to the subtitle paragraph in main().
st.markdown("""
<style>
.title-font {
    font-size: 28px !important;
    font-weight: bold;
    color: #2c3e50;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_wordllama():
    """Load the default WordLlama model.

    The function is wrapped in ``st.cache_resource`` so Streamlit can hand
    back the already-loaded model instead of loading it again on each rerun.

    Returns:
        The object returned by ``WordLlama.load()``.
    """
    model = WordLlama.load()
    return model


# Shared model instance used by every similarity computation in this file.
wl = load_wordllama()
def _score_text_pairs(texts):
    """Return the symmetric matrix of ``wl.similarity`` scores for *texts*."""
    n = len(texts)
    # The diagonal stays at 1.0: a text is treated as fully similar to itself.
    scores = np.eye(n)
    for i in range(n):
        for j in range(i + 1, n):
            # Score each unordered pair once and mirror it, so the matrix is
            # exactly symmetric and the model is called half as often.
            scores[i, j] = scores[j, i] = wl.similarity(texts[i], texts[j])
    return scores


def _two_text_figure(texts, similarity, distance):
    """Draw two texts on one axis, the second at *distance* from the first."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=[0, distance],
        y=[0, 0],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(size=10, color=['blue', 'red']),
    ))
    fig.update_layout(
        title=f"Text Similarity Visualization (Similarity: {similarity:.3f})",
        xaxis_title="Relative Distance",
        yaxis_title="",
        height=400,
        showlegend=False,
        # A negative similarity gives a distance above 1; widen the axis in
        # that case so the second marker is not pushed out of view.
        xaxis=dict(range=[-0.1, max(1.0, distance) + 0.1]),
        yaxis=dict(range=[-0.5, 0.5]),
    )
    return fig


def _multi_text_figure(texts, similarity_matrix, distance_matrix):
    """Place texts in 3D with MDS and connect every pair with a grey line."""
    n = len(texts)
    if n < 2:
        # Nothing to lay out relative to anything else: keep the (at most
        # one) point at the origin instead of running MDS.
        coords = np.zeros((n, 3))
    else:
        mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
        coords = mds.fit_transform(distance_matrix)

    fig = go.Figure(data=[go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(
            size=10,
            color=list(range(n)),
            colorscale='Viridis',
            opacity=0.8,
        ),
    )])

    for i in range(n):
        for j in range(i + 1, n):
            # The similarity doubles as the line's alpha channel, which must
            # lie in [0, 1] to form a valid rgba() colour.
            alpha = min(max(similarity_matrix[i, j], 0.0), 1.0)
            fig.add_trace(go.Scatter3d(
                x=[coords[i, 0], coords[j, 0]],
                y=[coords[i, 1], coords[j, 1]],
                z=[coords[i, 2], coords[j, 2]],
                mode='lines',
                line=dict(
                    color=f'rgba(100,100,100,{alpha:.2f})',
                    width=2,
                ),
                showlegend=False,
            ))

    fig.update_layout(
        title="3D Similarity Visualization",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
        ),
        height=700,
    )
    return fig


def create_similarity_based_visualization(texts):
    """Create a Plotly figure that positions texts by their similarity.

    Pairwise similarities come from ``wl.similarity`` and the distance
    between two texts is ``1 - similarity``. Exactly two texts are drawn on a
    single axis. Any other number of texts is laid out in three dimensions
    with scikit-learn's ``MDS`` on the precomputed distances, and every pair
    is joined by a grey line whose opacity is the pair's similarity.

    Args:
        texts: Sequence of strings to compare.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    texts = list(texts)
    similarity_matrix = _score_text_pairs(texts)

    # Distances handed to MDS have to be non-negative with a zero diagonal;
    # clamp away scores that land marginally above 1.
    distance_matrix = np.clip(1.0 - similarity_matrix, 0.0, None)
    np.fill_diagonal(distance_matrix, 0.0)

    if len(texts) == 2:
        return _two_text_figure(
            texts, similarity_matrix[0, 1], distance_matrix[0, 1]
        )
    return _multi_text_figure(texts, similarity_matrix, distance_matrix)
def main():
    """Render the WordLlama Similarity Explorer page.

    The page has two tabs. The first compares a pair of texts and shows the
    similarity score next to a plot. The second takes between two and six
    texts and shows the similarity plot followed by a heatmap of all pairwise
    scores. Example sets can be copied into the text fields of the second tab
    with the "Load Example" button.
    """
    st.title("🦙 WordLlama Similarity Explorer")
    st.markdown(
        "<p class='title-font'>Visualize text similarities in 3D space</p>",
        unsafe_allow_html=True,
    )

    with st.expander("ℹ️ How to interpret the visualization", expanded=True):
        st.markdown("""
        - **Distance between points** represents dissimilarity (farther = less similar)
        - **Line opacity** indicates similarity strength (darker = more similar)
        - **Colors** help distinguish different texts
        - **Hover** over points to see full text content
        """)

    pair_tab, multi_tab = st.tabs(["💫 Text Similarity", "🎯 Multi-Text Analysis"])

    with pair_tab:
        st.markdown("### Compare Two Texts")
        col1, col2 = st.columns(2)
        with col1:
            text1 = st.text_area(
                "First Text",
                value="I love programming in Python",
                height=100,
            )
        with col2:
            text2 = st.text_area(
                "Second Text",
                value="Coding with Python is amazing",
                height=100,
            )

        if st.button("Analyze Similarity", key="sim_button"):
            if not text1.strip() or not text2.strip():
                # A blank field has no content to compare.
                st.warning("Please enter text in both fields before analyzing.")
            else:
                similarity = wl.similarity(text1, text2)
                col1, col2 = st.columns(2)
                with col1:
                    st.metric(
                        label="Similarity Score",
                        value=f"{similarity:.4f}",
                        help="1.0 = identical, 0.0 = completely different",
                    )
                    if similarity > 0.8:
                        interpretation = "🟢 Very Similar"
                    elif similarity > 0.5:
                        interpretation = "🟡 Moderately Similar"
                    else:
                        interpretation = "🔴 Different"
                    st.info(f"Interpretation: {interpretation}")
                with col2:
                    st.plotly_chart(
                        create_similarity_based_visualization([text1, text2]),
                        use_container_width=True,
                    )

    with multi_tab:
        st.markdown("### Analyze Multiple Texts")
        examples = {
            "Similar Texts": [
                "I love programming in Python",
                "Python programming is my passion",
                "I enjoy coding with Python",
            ],
            "Mixed Similarity": [
                "The cat sleeps on the mat",
                "A cat is sleeping on the rug",
                "Python is a programming language",
            ],
            "Different Topics": [
                "The weather is sunny today",
                "Python is a programming language",
                "Cats are wonderful pets",
            ],
        }

        col1, col2 = st.columns([3, 1])
        with col1:
            selected_example = st.selectbox(
                "Choose an example set:",
                list(examples.keys()),
            )
        preset = examples[selected_example]
        with col2:
            if st.button("Load Example"):
                # Write the example into the text areas' own session-state
                # keys. This happens before those widgets are created below,
                # so they pick the new values up in this same run.
                for i, example_text in enumerate(preset):
                    st.session_state[f"text_{i}"] = example_text

        num_texts = st.slider("Number of texts:", 2, 6, 3)

        texts = []
        for i in range(num_texts):
            field_key = f"text_{i}"
            # Seed a field only when it has no stored value yet, so edits the
            # user already made are kept across reruns.
            if field_key not in st.session_state:
                st.session_state[field_key] = (
                    preset[i] if i < len(preset) else f"Example text {i+1}"
                )
            texts.append(
                st.text_area(f"Text {i+1}", height=100, key=field_key)
            )

        if st.button("Analyze Texts", key="analyze_button"):
            if any(not entry.strip() for entry in texts):
                # A blank field has no content to compare.
                st.warning("Please fill in every text field before analyzing.")
            else:
                st.plotly_chart(
                    create_similarity_based_visualization(texts),
                    use_container_width=True,
                )

                st.markdown("### Similarity Matrix")
                count = len(texts)
                similarity_matrix = np.zeros((count, count))
                for i in range(count):
                    # Score the diagonal and upper triangle, then mirror each
                    # value into the lower triangle.
                    for j in range(i, count):
                        score = wl.similarity(texts[i], texts[j])
                        similarity_matrix[i, j] = score
                        similarity_matrix[j, i] = score

                labels = [f"Text {i+1}" for i in range(count)]
                fig = go.Figure(data=go.Heatmap(
                    z=similarity_matrix,
                    x=labels,
                    y=labels,
                    colorscale='Viridis',
                    text=np.round(similarity_matrix, 3),
                    texttemplate='%{text}',
                    textfont={"size": 12},
                ))
                fig.update_layout(
                    title="Similarity Matrix",
                    height=400,
                )
                st.plotly_chart(fig, use_container_width=True)
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()