# DexterSptizu's picture
# Update app.py
# 8afde48 verified
# raw
# history blame
# 6.67 kB
import streamlit as st
import numpy as np
from wordllama import WordLlama
import plotly.graph_objects as go
import plotly.express as px
from sklearn.decomposition import PCA
import pandas as pd
# Page configuration: browser-tab title, favicon and wide layout.
st.set_page_config(
    page_title="WordLlama Explorer",
    page_icon="🦙",  # llama emoji (U+1F999)
    layout="wide",
)
# Custom CSS: page background, tab-strip spacing/sizing, and the
# `.title-font` class used for the subtitle rendered in main().
_CUSTOM_CSS = """
<style>
.main {
background-color: #f8f9fa;
}
.stTabs [data-baseweb="tab-list"] {
gap: 24px;
}
.stTabs [data-baseweb="tab"] {
height: 50px;
padding-left: 20px;
padding-right: 20px;
}
.title-font {
font-size: 28px !important;
font-weight: bold;
color: #2c3e50;
}
</style>
"""
# unsafe_allow_html is needed so the <style> tag is emitted as markup.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_wordllama():
    """Load and return the default WordLlama model.

    Decorated with ``st.cache_resource`` so Streamlit caches the loaded
    model instead of calling ``WordLlama.load()`` again.
    """
    model = WordLlama.load()
    return model


# Module-level model handle shared by the plotting helpers and main().
wl = load_wordllama()
def create_visualization(texts, embeddings):
    """Build a Plotly figure for a set of texts and their embeddings.

    Exactly two samples produce a 2-D scatter: the first text is drawn at
    ``(0, 0)`` and the second at ``x=1`` with ``wl.similarity`` of the two
    texts as its y value. Any other sample count is projected with PCA onto
    at most three components, zero-padded to three columns, and drawn as a
    3-D scatter.

    Args:
        texts: Sequence of strings, one label per embedding row.
        embeddings: 2-D array-like of embedding vectors, one row per text.

    Returns:
        A Plotly figure ready to pass to ``st.plotly_chart``.

    Raises:
        ValueError: If ``embeddings`` is empty.
    """
    points = np.asarray(embeddings)
    n_samples = len(points)
    if n_samples == 0:
        raise ValueError("At least one embedding is required for visualization")

    if n_samples == 2:
        # Two points cannot be meaningfully reduced, so plot the pair's
        # similarity score directly instead of running PCA.
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=[0, 1],
            y=[0, wl.similarity(texts[0], texts[1])],
            mode='markers+text',
            text=texts,
            textposition='top center',
            marker=dict(size=10),
        ))
        fig.update_layout(
            title="Text Similarity Visualization",
            xaxis_title="Position",
            yaxis_title="Similarity",
            height=400,
            showlegend=False,
        )
        return fig

    # PCA accepts at most min(n_samples, n_features) components, so cap the
    # requested 3 by both dimensions of the input.
    n_components = min(3, *points.shape)
    reduced = PCA(n_components=n_components).fit_transform(points)

    # scatter_3d needs three coordinate columns; pad missing ones with zeros.
    missing = 3 - reduced.shape[1]
    if missing > 0:
        reduced = np.hstack([reduced, np.zeros((reduced.shape[0], missing))])

    df_plot = pd.DataFrame(reduced, columns=['X', 'Y', 'Z'])
    df_plot['text'] = texts

    fig = px.scatter_3d(
        df_plot, x='X', y='Y', z='Z',
        text='text',
        title='Text Embeddings Visualization',
    )
    fig.update_traces(
        marker=dict(size=8, opacity=0.8),
        textposition='top center',
    )
    fig.update_layout(
        scene=dict(
            xaxis_title='Component 1',
            yaxis_title='Component 2',
            zaxis_title='Component 3',
        ),
        height=700,
    )
    return fig
def create_similarity_matrix(texts):
    """Build a heatmap of pairwise ``wl.similarity`` scores.

    Args:
        texts: Sequence of strings; also used as the axis labels.

    Returns:
        A Plotly figure containing an ``n x n`` heatmap whose cells are
        annotated with the score rounded to three decimals.
    """
    n = len(texts)
    similarity_matrix = np.zeros((n, n))
    for i in range(n):
        # Score each unordered pair once (diagonal included) and mirror it,
        # rather than calling the model for both (i, j) and (j, i).
        # NOTE(review): relies on wl.similarity being symmetric — confirm
        # against the WordLlama docs if a non-default config is ever used.
        for j in range(i, n):
            score = wl.similarity(texts[i], texts[j])
            similarity_matrix[i, j] = score
            similarity_matrix[j, i] = score

    fig = go.Figure(data=go.Heatmap(
        z=similarity_matrix,
        x=texts,
        y=texts,
        colorscale='Viridis',
        text=np.round(similarity_matrix, 3),
        texttemplate='%{text}',
        textfont={"size": 10},
    ))
    fig.update_layout(
        title="Similarity Matrix",
        height=400,
    )
    return fig
def _similarity_band(similarity):
    """Map a similarity score to a ``(marker, label)`` pair.

    Scores above 0.8 are "Very Similar", scores above 0.5 are
    "Moderately Similar", and everything else is "Different".
    """
    if similarity > 0.8:
        return "🟢", "Very Similar"
    if similarity > 0.5:
        return "🟡", "Moderately Similar"
    return "🔴", "Different"


def main():
    """Render the WordLlama explorer UI.

    The page has two tabs: one compares two texts and shows their
    similarity score and plot, the other analyses two to six texts with an
    embedding plot, a similarity heatmap and a pairwise listing.
    """
    st.title("🦙 WordLlama Embedding Explorer")
    st.markdown(
        "<p class='title-font'>Explore the power of WordLlama embeddings</p>",
        unsafe_allow_html=True,
    )
    similarity_tab, multi_text_tab = st.tabs(
        ["💫 Text Similarity", "🎯 Multi-Text Analysis"]
    )

    with similarity_tab:
        st.markdown("### Compare Two Texts")
        col1, col2 = st.columns(2)
        with col1:
            text1 = st.text_area(
                "First Text", value="I love programming in Python", height=100
            )
        with col2:
            text2 = st.text_area(
                "Second Text", value="Coding with Python is amazing", height=100
            )

        if st.button("Calculate Similarity", key="sim_button"):
            similarity = wl.similarity(text1, text2)
            st.markdown("### Results")
            col1, col2 = st.columns(2)
            with col1:
                st.metric(
                    label="Similarity Score",
                    value=f"{similarity:.4f}",
                    help="1.0 = identical, 0.0 = completely different",
                )
                # This tab shows the plain label only, without the marker.
                _, interpretation = _similarity_band(similarity)
                st.info(f"Interpretation: {interpretation}")
            with col2:
                embeddings = wl.embed([text1, text2])
                st.plotly_chart(
                    create_visualization([text1, text2], embeddings),
                    use_container_width=True,
                )

    with multi_text_tab:
        st.markdown("### Analyze Multiple Texts")
        num_texts = st.slider("Number of texts:", 2, 6, 3)
        texts = []
        for i in range(num_texts):
            # A distinct key per widget keeps the text areas independent.
            text = st.text_area(
                f"Text {i+1}",
                value=f"Example text {i+1}",
                height=100,
                key=f"text_{i}",
            )
            texts.append(text)

        if st.button("Analyze Texts", key="analyze_button"):
            embeddings = wl.embed(texts)

            st.markdown("### Visualization")
            st.plotly_chart(
                create_visualization(texts, embeddings),
                use_container_width=True,
            )

            st.markdown("### Similarity Matrix")
            st.plotly_chart(
                create_similarity_matrix(texts),
                use_container_width=True,
            )

            # Pairwise similarity analysis: each unordered pair is listed once.
            st.markdown("### Pairwise Similarities")
            for i, first in enumerate(texts):
                for j in range(i + 1, len(texts)):
                    similarity = wl.similarity(first, texts[j])
                    marker, label = _similarity_band(similarity)
                    st.write(
                        f"{marker} {label} ({similarity:.3f}): "
                        f"Text {i+1} vs Text {j+1}"
                    )


if __name__ == "__main__":
    main()