|
import streamlit as st |
|
import numpy as np |
|
from wordllama import WordLlama |
|
import plotly.graph_objects as go |
|
import plotly.express as px |
|
from sklearn.manifold import MDS |
|
import pandas as pd |
|
|
|
|
|
# Browser-tab title and icon, plus the wide page layout.
# The icon was previously the mis-encoded string "π¦" (UTF-8 llama emoji
# decoded with the wrong code page); it is restored to the real emoji.
st.set_page_config(
    page_title="WordLlama Explorer",
    page_icon="🦙",
    layout="wide",
)

# Global CSS. The `.title-font` class is applied to the subtitle paragraph
# that main() renders with unsafe_allow_html=True.
st.markdown("""
<style>
.title-font {
    font-size: 28px !important;
    font-weight: bold;
    color: #2c3e50;
}
</style>
""", unsafe_allow_html=True)
|
|
|
@st.cache_resource
def load_wordllama():
    """Return the default WordLlama model.

    Decorated with ``st.cache_resource`` so Streamlit caches the returned
    model object instead of reloading it on every script rerun.
    """
    model = WordLlama.load()
    return model


# Single model instance shared by every ``wl.similarity`` call in this app.
wl = load_wordllama()
|
|
|
def create_similarity_based_visualization(texts):
    """Create a Plotly figure that lays texts out by pairwise dissimilarity.

    Every pair of texts is scored with ``wl.similarity``. Each score becomes a
    distance of ``1 - similarity``.

    * Fewer than two texts: an empty placeholder figure (nothing to compare).
    * Exactly two texts: a 1-D plot. The first text sits at 0 and the second
      at ``1 - similarity``.
    * Three or more texts: a 3-D layout from metric MDS on the precomputed
      distance matrix. One line connects each pair, with opacity equal to
      the pair's similarity clipped to ``[0, 1]``.

    Args:
        texts: Sequence of strings to compare.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    texts = list(texts)
    if len(texts) < 2:
        # MDS cannot embed an empty set, and a lone text has no partner.
        fig = go.Figure()
        fig.update_layout(
            title="Enter at least two texts to visualize their similarity",
            height=400,
        )
        return fig

    similarity_matrix = _pairwise_similarity_matrix(texts)
    if len(texts) == 2:
        return _two_text_figure(texts, similarity_matrix[0, 1])
    return _mds_3d_figure(texts, similarity_matrix)


def _pairwise_similarity_matrix(texts):
    """Return the symmetric n x n matrix of ``wl.similarity`` scores.

    Only the upper triangle is queried and then mirrored, which assumes
    similarity is symmetric. The diagonal is fixed at 1.0. Together these
    make ``1 - matrix`` exactly symmetric with a zero diagonal, as a
    precomputed dissimilarity matrix for MDS should be. This also halves the
    number of model calls compared with filling all n² cells.
    """
    n = len(texts)
    matrix = np.eye(n)
    for i in range(n):
        for j in range(i + 1, n):
            matrix[i, j] = matrix[j, i] = wl.similarity(texts[i], texts[j])
    return matrix


def _two_text_figure(texts, similarity):
    """1-D scatter with the first text at 0 and the second at ``1 - similarity``."""
    distance = 1 - similarity
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=[0, distance],
        y=[0, 0],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(size=10, color=['blue', 'red']),
    ))
    # Similarity scores are not guaranteed to stay inside [0, 1] (e.g. a
    # negative cosine), so widen the fixed [-0.1, 1.1] window when needed
    # instead of letting the second point fall off-screen.
    x_min = min(0.0, distance) - 0.1
    x_max = max(1.0, distance) + 0.1
    fig.update_layout(
        title=f"Text Similarity Visualization (Similarity: {similarity:.3f})",
        xaxis_title="Relative Distance",
        yaxis_title="",
        height=400,
        showlegend=False,
        xaxis=dict(range=[x_min, x_max]),
        yaxis=dict(range=[-0.5, 0.5]),
    )
    return fig


def _mds_3d_figure(texts, similarity_matrix):
    """3-D MDS embedding of the texts, plus one similarity-shaded line per pair."""
    # Clip so float noise (a similarity a hair above 1) cannot produce a
    # negative distance.
    distance_matrix = np.clip(1 - similarity_matrix, 0.0, None)
    mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
    coords = mds.fit_transform(distance_matrix)

    n = len(texts)
    fig = go.Figure(data=[go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(
            size=10,
            color=list(range(n)),
            colorscale='Viridis',
            opacity=0.8,
        ),
    )])

    for i in range(n):
        for j in range(i + 1, n):
            # Plotly only accepts an rgba alpha in [0, 1]. A negative
            # similarity would otherwise make the colour string invalid and
            # raise a ValueError.
            alpha = float(np.clip(similarity_matrix[i, j], 0.0, 1.0))
            fig.add_trace(go.Scatter3d(
                x=[coords[i, 0], coords[j, 0]],
                y=[coords[i, 1], coords[j, 1]],
                z=[coords[i, 2], coords[j, 2]],
                mode='lines',
                line=dict(color=f'rgba(100,100,100,{alpha:.2f})', width=2),
                showlegend=False,
            ))

    fig.update_layout(
        title="3D Similarity Visualization",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
        ),
        height=700,
        # Without this, the marker trace shows up as a lone "trace 0" entry.
        showlegend=False,
    )
    return fig
|
|
|
def main():
    """Render the WordLlama Similarity Explorer page.

    The page has a title, a collapsible guide to reading the plots, and two
    tabs. One tab compares a pair of texts; the other analyzes 2-6 texts
    together.
    """
    # Emoji below replace mis-encoded sequences ("π¦", "βΉοΈ", ...) that
    # rendered as garbage in the UI.
    st.title("🦙 WordLlama Similarity Explorer")
    st.markdown("<p class='title-font'>Visualize text similarities in 3D space</p>",
                unsafe_allow_html=True)

    with st.expander("ℹ️ How to interpret the visualization", expanded=True):
        st.markdown("""
        - **Distance between points** represents dissimilarity (farther = less similar)
        - **Line opacity** indicates similarity strength (darker = more similar)
        - **Colors** help distinguish different texts
        - **Hover** over points to see full text content
        """)

    pair_tab, multi_tab = st.tabs(["💫 Text Similarity", "🎯 Multi-Text Analysis"])
    with pair_tab:
        _render_pair_tab()
    with multi_tab:
        _render_multi_tab()


def _render_pair_tab():
    """Tab 1: score two texts against each other and plot them side by side."""
    st.markdown("### Compare Two Texts")

    col1, col2 = st.columns(2)
    with col1:
        text1 = st.text_area(
            "First Text",
            value="I love programming in Python",
            height=100,
        )
    with col2:
        text2 = st.text_area(
            "Second Text",
            value="Coding with Python is amazing",
            height=100,
        )

    if not st.button("Analyze Similarity", key="sim_button"):
        return
    # An empty box gives the model nothing to compare; ask for input instead
    # of scoring it.
    if not (text1.strip() and text2.strip()):
        st.warning("Please enter text in both boxes before analyzing.")
        return

    similarity = wl.similarity(text1, text2)

    score_col, plot_col = st.columns(2)
    with score_col:
        st.metric(
            label="Similarity Score",
            value=f"{similarity:.4f}",
            help="1.0 = identical, 0.0 = completely different",
        )
        if similarity > 0.8:
            interpretation = "🟢 Very Similar"
        elif similarity > 0.5:
            interpretation = "🟡 Moderately Similar"
        else:
            interpretation = "🔴 Different"
        st.info(f"Interpretation: {interpretation}")

    with plot_col:
        st.plotly_chart(
            create_similarity_based_visualization([text1, text2]),
            use_container_width=True,
        )


def _render_multi_tab():
    """Tab 2: edit 2-6 texts, then show a 3-D layout and a similarity heatmap."""
    st.markdown("### Analyze Multiple Texts")

    examples = {
        "Similar Texts": [
            "I love programming in Python",
            "Python programming is my passion",
            "I enjoy coding with Python",
        ],
        "Mixed Similarity": [
            "The cat sleeps on the mat",
            "A cat is sleeping on the rug",
            "Python is a programming language",
        ],
        "Different Topics": [
            "The weather is sunny today",
            "Python is a programming language",
            "Cats are wonderful pets",
        ],
    }

    col1, col2 = st.columns([3, 1])
    with col1:
        selected_example = st.selectbox(
            "Choose an example set:",
            list(examples.keys()),
        )
    with col2:
        # Copy the chosen example set into the text boxes' session-state keys.
        # This must run before the text areas below are created, because
        # Streamlit forbids assigning a widget's key after that widget has
        # been instantiated in the same run. (The previous code stored the
        # set in `st.session_state.texts`, which nothing read.)
        if st.button("Load Example"):
            for i, example_text in enumerate(examples[selected_example]):
                st.session_state[f"text_{i}"] = example_text

    num_texts = st.slider("Number of texts:", 2, 6, 3)
    example_texts = examples[selected_example]

    texts = []
    for i in range(num_texts):
        key = f"text_{i}"
        # Seed a box only the first time it appears. After that its content
        # lives in session state (user edits or a loaded example), so no
        # `value=` is passed to the widget.
        if key not in st.session_state:
            st.session_state[key] = (
                example_texts[i] if i < len(example_texts)
                else f"Example text {i+1}"
            )
        texts.append(st.text_area(f"Text {i+1}", height=100, key=key))

    if not st.button("Analyze Texts", key="analyze_button"):
        return
    if not all(text.strip() for text in texts):
        st.warning("Please fill in every text box before analyzing.")
        return

    st.plotly_chart(
        create_similarity_based_visualization(texts),
        use_container_width=True,
    )

    st.markdown("### Similarity Matrix")
    _render_similarity_heatmap(texts)


def _render_similarity_heatmap(texts):
    """Show the pairwise ``wl.similarity`` scores of *texts* as an annotated heatmap."""
    n = len(texts)
    # Query only the upper triangle (similarity is assumed symmetric) and
    # mirror it. A text compared with itself is shown as 1.0.
    similarity_matrix = np.eye(n)
    for i in range(n):
        for j in range(i + 1, n):
            score = wl.similarity(texts[i], texts[j])
            similarity_matrix[i, j] = similarity_matrix[j, i] = score

    labels = [f"Text {i+1}" for i in range(n)]
    fig = go.Figure(data=go.Heatmap(
        z=similarity_matrix,
        x=labels,
        y=labels,
        colorscale='Viridis',
        text=np.round(similarity_matrix, 3),
        texttemplate='%{text}',
        textfont={"size": 12},
    ))
    fig.update_layout(
        title="Similarity Matrix",
        height=400,
    )
    st.plotly_chart(fig, use_container_width=True)
|
|
|
# Entry point: build the UI only when this file is executed as the main
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()