DexterSptizu's picture
Create app.py
ded7cbd verified
raw
history blame
5.98 kB
import html

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from sklearn.decomposition import PCA
from wordllama import WordLlama
# Page configuration: browser-tab title, favicon and wide layout.
# The icon must be the actual llama emoji; the previous literal was its UTF-8
# bytes mis-decoded through a Windows code page, which is not a valid emoji.
st.set_page_config(
    page_title="WordLlama Explorer",
    page_icon="🦙",
    layout="wide",
)
# Custom CSS injected once per run: a light page background, wider spacing
# and padding for the tab strip, and the `.title-font` class that main()
# applies to its subtitle paragraph.
_CUSTOM_CSS = """
<style>
.main {
background-color: #f8f9fa;
}
.stTabs [data-baseweb="tab-list"] {
gap: 24px;
}
.stTabs [data-baseweb="tab"] {
height: 50px;
padding-left: 20px;
padding-right: 20px;
}
.title-font {
font-size: 28px !important;
font-weight: bold;
color: #2c3e50;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_wordllama():
    """Return the default WordLlama model.

    Wrapped in ``st.cache_resource`` so Streamlit keeps the loaded model
    between script reruns instead of reloading it on every interaction.
    """
    model = WordLlama.load()
    return model


# Shared model handle used by every tab in main().
wl = load_wordllama()
def create_3d_visualization(texts, embeddings):
    """Build an interactive 3-D scatter plot of text embeddings.

    The embeddings are projected with PCA onto at most three principal
    components, and each point is labelled with its source text.

    Args:
        texts: Point labels, one per row of ``embeddings``.
        embeddings: Array-like of embedding vectors. Assumed to be 2-D with
            shape (len(texts), n_dims), as returned by ``wl.embed(texts)``.

    Returns:
        A plotly Figure containing the 3-D scatter plot.
    """
    embeddings = np.asarray(embeddings)
    n_samples = embeddings.shape[0]
    n_features = embeddings.shape[1] if embeddings.ndim == 2 else 0

    # PCA cannot extract more components than min(n_samples, n_features).
    # Requesting 3 with only two texts (always the case in the similarity
    # tab) raises ValueError, so fit only as many components as the data
    # supports and leave the remaining axes at 0 to keep the plot 3-D.
    # With fewer than two points there is no spread to project: centring
    # puts a lone point at the origin, so PCA is skipped entirely.
    coords = np.zeros((n_samples, 3))
    n_components = min(3, n_samples, n_features)
    if n_samples >= 2 and n_components > 0:
        pca = PCA(n_components=n_components)
        coords[:, :n_components] = pca.fit_transform(embeddings)

    df = pd.DataFrame(coords, columns=['X', 'Y', 'Z'])
    df['text'] = list(texts)

    fig = px.scatter_3d(
        df, x='X', y='Y', z='Z',
        text='text',
        title='Word Embeddings in 3D Space',
    )
    fig.update_traces(
        marker=dict(size=8, opacity=0.8),
        textposition='top center',
    )
    fig.update_layout(
        scene=dict(
            xaxis_title='Component 1',
            yaxis_title='Component 2',
            zaxis_title='Component 3',
        ),
        height=700,
    )
    return fig
def _doc_card(body, score=None):
    """Render user-supplied text inside a styled HTML card, optionally headed by a score."""
    # The card is emitted with unsafe_allow_html=True, so the text must be
    # escaped: otherwise any "<" or tag in a document is treated as markup.
    # Newlines become <br> so multi-line text stays inside the card.
    content = html.escape(str(body)).replace("\n", "<br>")
    if score is not None:
        content = f"<b>Score: {score:.4f}</b><br>{content}"
    st.markdown(
        "<div style='padding: 10px; margin: 5px; background-color: #f0f2f6; "
        f"border-radius: 5px;'>{content}</div>",
        unsafe_allow_html=True,
    )


def _document_inputs(count, key_prefix):
    """Render `count` pre-filled document text areas keyed `<key_prefix><i>` and return their values."""
    # 68 px is the minimum height st.text_area accepts; newer Streamlit
    # versions reject smaller values such as the previous 50.
    return [
        st.text_area(
            f"Document {i+1}",
            value=f"Example document {i+1}",
            height=68,
            key=f"{key_prefix}{i}",
        )
        for i in range(count)
    ]


def _show_embedding_plot(texts):
    """Embed `texts` with WordLlama and display them as a 3-D PCA scatter plot."""
    embeddings = wl.embed(texts)
    st.plotly_chart(
        create_3d_visualization(texts, embeddings),
        use_container_width=True,
    )


def _render_similarity_tab():
    """Tab 1: cosine similarity between two user-entered texts."""
    st.markdown("### Compare Text Similarity")
    left, right = st.columns(2)
    with left:
        text1 = st.text_area("First Text", value="I love programming in Python", height=100)
    with right:
        text2 = st.text_area("Second Text", value="Coding with Python is amazing", height=100)
    if st.button("Calculate Similarity", key="sim_button"):
        similarity = wl.similarity(text1, text2)
        st.markdown("### Similarity Score")
        st.metric(
            label="Cosine Similarity",
            value=f"{similarity:.4f}",
            help="Score ranges from 0 (different) to 1 (identical)",
        )
        # Visualize both texts in 3D space
        _show_embedding_plot([text1, text2])


def _render_ranking_tab():
    """Tab 2: rank user-entered documents by similarity to a query."""
    st.markdown("### Rank Documents by Similarity")
    query = st.text_area("Query Text", value="I went to the car", height=100)
    st.markdown("### Enter Documents to Rank")
    num_docs = st.slider("Number of documents:", 2, 6, 4)
    documents = _document_inputs(num_docs, "doc_")
    if st.button("Rank Documents", key="rank_button"):
        ranked_docs = wl.rank(query, documents)
        st.markdown("### Ranking Results")
        for doc, score in ranked_docs:
            _doc_card(doc, score)
        # Visualize all texts, with the query as the first point
        _show_embedding_plot([query] + documents)


def _render_dedup_tab():
    """Tab 3: fuzzy deduplication of user-entered documents."""
    st.markdown("### Fuzzy Deduplication")
    st.markdown(
        "Remove similar documents based on a similarity threshold.\n"
        "Documents with similarity above the threshold will be considered duplicates."
    )
    st.markdown("### Enter Documents")
    num_dedup_docs = st.slider("Number of documents:", 2, 8, 4, key="dedup_slider")
    dedup_docs = _document_inputs(num_dedup_docs, "dedup_doc_")
    threshold = st.slider("Similarity Threshold:", 0.0, 1.0, 0.8)
    if st.button("Find Duplicates", key="dedup_button"):
        unique_docs = wl.deduplicate(dedup_docs, threshold=threshold)
        st.markdown("### Results")
        st.markdown(f"Found {len(unique_docs)} unique documents:")
        for doc in unique_docs:
            _doc_card(doc)
        # Visualize every input document, including the duplicates
        _show_embedding_plot(dedup_docs)


def main():
    """Render the WordLlama Explorer page: a header followed by three tabs."""
    st.title("🦙 WordLlama Embedding Explorer")
    st.markdown("<p class='title-font'>Explore the power of WordLlama embeddings</p>",
                unsafe_allow_html=True)
    similarity_tab, ranking_tab, dedup_tab = st.tabs(
        ["💫 Similarity Explorer", "🎯 Document Ranking", "🔍 Fuzzy Deduplication"]
    )
    with similarity_tab:
        _render_similarity_tab()
    with ranking_tab:
        _render_ranking_tab()
    with dedup_tab:
        _render_dedup_tab()
# `streamlit run` executes this file as the __main__ module, so the guard
# renders the app on every rerun while plain imports stay side-effect free.
if __name__ == "__main__":
    main()