|
from itertools import combinations

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from sklearn.decomposition import PCA
from wordllama import WordLlama
|
|
|
|
|
# Configure the browser-tab title, the favicon and the wide page layout.
st.set_page_config(
    page_title="WordLlama Explorer",
    # Llama emoji (U+1F999). The literal had been corrupted into two
    # unrelated characters, which is not a valid emoji icon; it is written as
    # an escape so it survives future re-encoding of this file.
    page_icon="\U0001F999",
    layout="wide",
)
|
|
|
|
|
# Inject the app-wide CSS overrides: page background, spacing and sizing of
# the tab bar, and the ``title-font`` class used for the subtitle in ``main``.
# The literal must contain nothing but the <style> element; stray characters
# between rules would be sent to the browser and invalidate the CSS.
# ``unsafe_allow_html=True`` asks Streamlit to emit the markup as HTML.
st.markdown(
    """
<style>
.main {
    background-color: #f8f9fa;
}
.stTabs [data-baseweb="tab-list"] {
    gap: 24px;
}
.stTabs [data-baseweb="tab"] {
    height: 50px;
    padding-left: 20px;
    padding-right: 20px;
}
.title-font {
    font-size: 28px !important;
    font-weight: bold;
    color: #2c3e50;
}
</style>
""",
    unsafe_allow_html=True,
)
|
|
|
@st.cache_resource
def load_wordllama():
    """Load and return the default WordLlama model.

    Decorated with ``st.cache_resource`` so the loaded model object is cached
    by Streamlit instead of being rebuilt each time the script runs.
    """
    model = WordLlama.load()
    return model


# Shared model handle used by the plotting helpers and by ``main``.
wl = load_wordllama()
|
|
|
def _pair_similarity_figure(texts):
    """Return a 2-D figure with two texts at (0, 0) and (1, similarity)."""
    similarity = wl.similarity(texts[0], texts[1])

    fig = go.Figure(
        go.Scatter(
            x=[0, 1],
            y=[0, similarity],
            mode="markers+text",
            text=list(texts),
            textposition="top center",
            marker=dict(size=10),
        )
    )
    fig.update_layout(
        title="Text Similarity Visualization",
        xaxis_title="Position",
        yaxis_title="Similarity",
        height=400,
        showlegend=False,
    )
    return fig


def _pca_scatter_figure(texts, embeddings):
    """Return a labelled 3-D scatter of ``embeddings`` reduced with PCA."""
    n_samples, n_features = embeddings.shape
    # PCA cannot produce more components than there are samples or features.
    n_components = min(3, n_samples, n_features)

    # Start from all-zero coordinates so that any axis PCA cannot supply is
    # already padded with zeros.
    coords = np.zeros((n_samples, 3))
    # A single sample has no variance to decompose; it stays at the origin.
    if n_samples > 1 and n_components > 0:
        reduced = PCA(n_components=n_components).fit_transform(embeddings)
        coords[:, :n_components] = reduced

    df_plot = pd.DataFrame(coords, columns=["X", "Y", "Z"])
    df_plot["text"] = list(texts)

    fig = px.scatter_3d(
        df_plot,
        x="X",
        y="Y",
        z="Z",
        text="text",
        title="Text Embeddings Visualization",
    )
    fig.update_traces(
        marker=dict(size=8, opacity=0.8),
        textposition="top center",
    )
    fig.update_layout(
        scene=dict(
            xaxis_title="Component 1",
            yaxis_title="Component 2",
            zaxis_title="Component 3",
        ),
        height=700,
    )
    return fig


def create_visualization(texts, embeddings):
    """Create the appropriate visualization for the number of samples.

    Exactly two samples yield a 2-D marker plot whose second point is placed
    at the pair's similarity score. Any other count is reduced with PCA to at
    most three components and drawn as a labelled 3-D scatter.

    Args:
        texts: Sequence of strings used as the point labels.
        embeddings: 2-D array-like holding one embedding row per text.

    Returns:
        A Plotly figure.

    Raises:
        ValueError: If ``embeddings`` is empty, or if its number of rows
            differs from ``len(texts)``.
    """
    embeddings = np.asarray(embeddings)
    n_samples = len(embeddings)

    if n_samples == 0:
        raise ValueError("At least one embedding is required to build a figure")
    # Labels are attached to points by position, so the counts must agree.
    if len(texts) != n_samples:
        raise ValueError(
            f"Got {len(texts)} texts but {n_samples} embeddings; counts must match"
        )

    if n_samples == 2:
        return _pair_similarity_figure(texts)
    return _pca_scatter_figure(texts, embeddings)
|
|
|
def create_similarity_matrix(texts):
    """Build an annotated heatmap of the pairwise similarity between texts.

    Args:
        texts: Sequence of strings. They are scored against each other with
            ``wl.similarity`` and also used as the labels of both axes.

    Returns:
        A Plotly figure containing a single ``Heatmap`` trace whose cells are
        annotated with the score rounded to three decimals.
    """
    n = len(texts)
    similarity_matrix = np.zeros((n, n))

    # Score every unordered pair once (diagonal included) and mirror the
    # value, instead of calling the model for both (i, j) and (j, i).
    # Assumes wl.similarity(a, b) == wl.similarity(b, a).
    for i, first in enumerate(texts):
        for j in range(i, n):
            score = wl.similarity(first, texts[j])
            similarity_matrix[i, j] = score
            similarity_matrix[j, i] = score

    # NOTE(review): the raw texts are the axis categories, so two identical
    # input strings would presumably share one category - confirm whether
    # duplicate inputs need distinct labels.
    heatmap = go.Heatmap(
        z=similarity_matrix,
        x=texts,
        y=texts,
        colorscale="Viridis",
        text=np.round(similarity_matrix, 3),
        texttemplate="%{text}",
        textfont={"size": 10},
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title="Similarity Matrix",
        height=400,
    )
    return fig
|
|
|
def _interpret_similarity(similarity):
    """Return a ``(marker, label)`` pair describing a similarity score.

    Scores above 0.8 are "Very Similar", scores above 0.5 are "Moderately
    Similar" and everything else is "Different". The marker is the matching
    green, yellow or red circle emoji.
    """
    if similarity > 0.8:
        return "\U0001F7E2", "Very Similar"
    if similarity > 0.5:
        return "\U0001F7E1", "Moderately Similar"
    return "\U0001F534", "Different"


def _render_similarity_tab():
    """Render the tab that compares two texts."""
    st.markdown("### Compare Two Texts")

    left, right = st.columns(2)
    with left:
        text1 = st.text_area(
            "First Text", value="I love programming in Python", height=100
        )
    with right:
        text2 = st.text_area(
            "Second Text", value="Coding with Python is amazing", height=100
        )

    if not st.button("Calculate Similarity", key="sim_button"):
        return
    # A blank box leaves nothing to embed, so ask for input instead.
    if not text1.strip() or not text2.strip():
        st.warning("Please enter text in both boxes before comparing.")
        return

    similarity = wl.similarity(text1, text2)

    st.markdown("### Results")
    metric_col, chart_col = st.columns(2)

    with metric_col:
        st.metric(
            label="Similarity Score",
            value=f"{similarity:.4f}",
            help="1.0 = identical, 0.0 = completely different",
        )
        _, interpretation = _interpret_similarity(similarity)
        st.info(f"Interpretation: {interpretation}")

    with chart_col:
        embeddings = wl.embed([text1, text2])
        st.plotly_chart(
            create_visualization([text1, text2], embeddings),
            use_container_width=True,
        )


def _render_multi_text_tab():
    """Render the tab that analyses between two and six texts together."""
    st.markdown("### Analyze Multiple Texts")

    num_texts = st.slider("Number of texts:", 2, 6, 3)
    texts = [
        st.text_area(
            f"Text {i + 1}",
            value=f"Example text {i + 1}",
            height=100,
            key=f"text_{i}",
        )
        for i in range(num_texts)
    ]

    if not st.button("Analyze Texts", key="analyze_button"):
        return
    # A blank box leaves nothing to embed, so ask for input instead.
    if any(not text.strip() for text in texts):
        st.warning("Please fill in every text box before analyzing.")
        return

    embeddings = wl.embed(texts)

    st.markdown("### Visualization")
    st.plotly_chart(
        create_visualization(texts, embeddings),
        use_container_width=True,
    )

    st.markdown("### Similarity Matrix")
    st.plotly_chart(
        create_similarity_matrix(texts),
        use_container_width=True,
    )

    st.markdown("### Pairwise Similarities")
    # combinations() yields each unordered pair once, in the same order as a
    # nested "j > i" index loop; numbering starts at 1 for display.
    for (i, first), (j, second) in combinations(enumerate(texts, start=1), 2):
        similarity = wl.similarity(first, second)
        marker, label = _interpret_similarity(similarity)
        st.write(f"{marker} {label} ({similarity:.3f}): Text {i} vs Text {j}")


def main():
    """Render the page: title, subtitle and the two analysis tabs.

    Emoji are written as ``\\U`` escapes because the literal characters in
    this file had been corrupted by a wrong text encoding.
    """
    st.title("\U0001F999 WordLlama Embedding Explorer")
    st.markdown(
        "<p class='title-font'>Explore the power of WordLlama embeddings</p>",
        unsafe_allow_html=True,
    )

    # NOTE(review): the two tab icons could not be recovered with certainty
    # from the corrupted text; U+1F4AB and U+1F3AF are best guesses - confirm.
    similarity_tab, multi_text_tab = st.tabs(
        ["\U0001F4AB Text Similarity", "\U0001F3AF Multi-Text Analysis"]
    )

    with similarity_tab:
        _render_similarity_tab()

    with multi_text_tab:
        _render_multi_text_tab()
|
|
|
# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()