import itertools

import streamlit as st
import numpy as np
from wordllama import WordLlama
import plotly.graph_objects as go
import plotly.express as px
from sklearn.decomposition import PCA
import pandas as pd

# Page configuration
st.set_page_config(
    page_title="WordLlama Explorer",
    page_icon="🦙",
    layout="wide"
)

# Custom CSS
# NOTE(review): the payload is blank, so this injects no styles. It looks like
# the original <style> block was lost -- restore it or drop this call.
st.markdown(""" """, unsafe_allow_html=True)

# Score cut-offs shared by every place that turns a similarity into a label.
VERY_SIMILAR_THRESHOLD = 0.8
MODERATELY_SIMILAR_THRESHOLD = 0.5

# Icon prefixed to each interpretation label in the pairwise listing.
_INTERPRETATION_ICONS = {
    "Very Similar": "🟢",
    "Moderately Similar": "🟡",
    "Different": "🔴",
}


@st.cache_resource
def load_wordllama():
    """Load the default WordLlama model once and reuse it across reruns.

    ``st.cache_resource`` keeps the returned object for the lifetime of the
    Streamlit server, so widget interactions do not reload the model.
    """
    return WordLlama.load()


wl = load_wordllama()


def _interpret_similarity(similarity):
    """Map a similarity score to a human-readable label.

    Strict ``>`` comparisons mean a score exactly on a threshold falls into
    the lower bucket; a NaN score compares False everywhere -> "Different".
    """
    if similarity > VERY_SIMILAR_THRESHOLD:
        return "Very Similar"
    if similarity > MODERATELY_SIMILAR_THRESHOLD:
        return "Moderately Similar"
    return "Different"


def _compute_similarity_matrix(texts):
    """Return an (n, n) float array whose [i, j] is wl.similarity(texts[i], texts[j])."""
    n = len(texts)
    matrix = np.zeros((n, n))
    for i, j in itertools.product(range(n), repeat=2):
        matrix[i, j] = wl.similarity(texts[i], texts[j])
    return matrix


def create_visualization(texts, embeddings):
    """Create a plot suited to the number of texts.

    Args:
        texts: Input strings, in the same order as ``embeddings``.
        embeddings: One embedding row per text (as returned by ``wl.embed``).

    Returns:
        A Plotly figure. For exactly two texts it is a 2D scatter in which the
        second point's height is ``wl.similarity`` of the pair (the first text
        sits at y=0 as a baseline). Otherwise the embeddings are projected
        with PCA onto up to three components and drawn as a 3D scatter,
        zero-padded to three axes when fewer components are available.
    """
    n_samples = len(embeddings)

    if n_samples == 2:
        # For a pair, plot the similarity score directly instead of a projection.
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=[0, 1],
            y=[0, wl.similarity(texts[0], texts[1])],
            mode='markers+text',
            text=texts,
            textposition='top center',
            marker=dict(size=10)
        ))
        fig.update_layout(
            title="Text Similarity Visualization",
            xaxis_title="Position",
            yaxis_title="Similarity",
            height=400,
            showlegend=False
        )
    else:
        # For 3 or more samples, use PCA for 3D visualization
        pca = PCA(n_components=min(3, n_samples))
        embeddings_reduced = pca.fit_transform(embeddings)

        # Fewer than three components are only possible with < 3 samples;
        # pad with zero columns so the 3D plot always has X/Y/Z.
        if embeddings_reduced.shape[1] < 3:
            padding = np.zeros((embeddings_reduced.shape[0], 3 - embeddings_reduced.shape[1]))
            embeddings_reduced = np.hstack([embeddings_reduced, padding])

        df_plot = pd.DataFrame(embeddings_reduced, columns=['X', 'Y', 'Z'])
        df_plot['text'] = texts

        fig = px.scatter_3d(
            df_plot,
            x='X',
            y='Y',
            z='Z',
            text='text',
            title='Text Embeddings Visualization'
        )
        fig.update_traces(
            marker=dict(size=8, opacity=0.8),
            textposition='top center'
        )
        fig.update_layout(
            scene=dict(
                xaxis_title='Component 1',
                yaxis_title='Component 2',
                zaxis_title='Component 3'
            ),
            height=700
        )

    return fig


def create_similarity_matrix(texts, similarity_matrix=None):
    """Build a heatmap of pairwise similarities between ``texts``.

    Args:
        texts: Input strings; row/column i corresponds to ``texts[i]``.
        similarity_matrix: Optional precomputed (n, n) array of scores. When
            omitted it is computed with ``wl.similarity`` for every pair.

    Returns:
        A Plotly heatmap figure annotated with scores rounded to 3 decimals.
    """
    if similarity_matrix is None:
        similarity_matrix = _compute_similarity_matrix(texts)

    # Label axes by position rather than by the raw text: Plotly places equal
    # category labels at the same spot, so identical texts would collapse.
    labels = [f"Text {i + 1}" for i in range(len(texts))]

    fig = go.Figure(data=go.Heatmap(
        z=similarity_matrix,
        x=labels,
        y=labels,
        colorscale='Viridis',
        text=np.round(similarity_matrix, 3),
        texttemplate='%{text}',
        textfont={"size": 10},
    ))
    fig.update_layout(
        title="Similarity Matrix",
        height=400
    )
    return fig


def _render_pair_tab():
    """Render the tab that compares two texts."""
    st.markdown("### Compare Two Texts")

    col1, col2 = st.columns(2)
    with col1:
        text1 = st.text_area("First Text", value="I love programming in Python", height=100)
    with col2:
        text2 = st.text_area("Second Text", value="Coding with Python is amazing", height=100)

    if not st.button("Calculate Similarity", key="sim_button"):
        return
    # Blank input leaves the model nothing to embed; tell the user instead of
    # showing a meaningless score.
    if not (text1.strip() and text2.strip()):
        st.warning("Please enter text in both fields.")
        return

    similarity = wl.similarity(text1, text2)
    st.markdown("### Results")

    col1, col2 = st.columns(2)
    with col1:
        st.metric(
            label="Similarity Score",
            value=f"{similarity:.4f}",
            help="1.0 = identical, 0.0 = completely different"
        )
        st.info(f"Interpretation: {_interpret_similarity(similarity)}")
    with col2:
        embeddings = wl.embed([text1, text2])
        st.plotly_chart(
            create_visualization([text1, text2], embeddings),
            use_container_width=True
        )


def _render_multi_text_tab():
    """Render the tab that analyzes 2-6 texts together."""
    st.markdown("### Analyze Multiple Texts")
    num_texts = st.slider("Number of texts:", 2, 6, 3)

    texts = []
    for i in range(num_texts):
        texts.append(st.text_area(
            f"Text {i+1}",
            value=f"Example text {i+1}",
            height=100,
            key=f"text_{i}"
        ))

    if not st.button("Analyze Texts", key="analyze_button"):
        return
    # Blank inputs would be embedded as-is and fed to PCA; stop early instead.
    if not all(text.strip() for text in texts):
        st.warning("Please fill in every text field.")
        return

    embeddings = wl.embed(texts)
    # Computed once and shared by the heatmap and the pairwise list below.
    similarity_matrix = _compute_similarity_matrix(texts)

    st.markdown("### Visualization")
    st.plotly_chart(
        create_visualization(texts, embeddings),
        use_container_width=True
    )

    st.markdown("### Similarity Matrix")
    st.plotly_chart(
        create_similarity_matrix(texts, similarity_matrix),
        use_container_width=True
    )

    # Pairwise similarity analysis (each unordered pair once, i < j)
    st.markdown("### Pairwise Similarities")
    for i, j in itertools.combinations(range(len(texts)), 2):
        similarity = similarity_matrix[i, j]
        label = _interpret_similarity(similarity)
        st.write(f"{_INTERPRETATION_ICONS[label]} {label} ({similarity:.3f}): Text {i+1} vs Text {j+1}")


def main():
    """Entry point: draw the page header and the two analysis tabs."""
    st.title("🦙 WordLlama Embedding Explorer")
    # NOTE(review): unsafe_allow_html suggests this once carried HTML markup;
    # it is kept on a single line so the string literal is valid Python.
    st.markdown("Explore the power of WordLlama embeddings", unsafe_allow_html=True)

    tabs = st.tabs(["💫 Text Similarity", "🎯 Multi-Text Analysis"])
    with tabs[0]:
        _render_pair_tab()
    with tabs[1]:
        _render_multi_text_tab()


if __name__ == "__main__":
    main()