|
import html

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from sklearn.decomposition import PCA
from wordllama import WordLlama
|
|
|
|
|
# Configure the browser tab and page layout before any element is rendered.
st.set_page_config(
    page_title="WordLlama Explorer",
    # A real emoji is required here: Streamlit treats a non-emoji string as an
    # image path/URL, so a mis-encoded glyph yields a broken favicon.
    page_icon="🦙",
    layout="wide",
)
|
|
|
|
|
# Page-wide CSS overrides: a light page background, wider gaps between tab
# headers, taller padded tabs, and the `.title-font` class that main() applies
# to its subtitle paragraph.
_PAGE_CSS = """
<style>
    .main {
        background-color: #f8f9fa;
    }
    .stTabs [data-baseweb="tab-list"] {
        gap: 24px;
    }
    .stTabs [data-baseweb="tab"] {
        height: 50px;
        padding-left: 20px;
        padding-right: 20px;
    }
    .title-font {
        font-size: 28px !important;
        font-weight: bold;
        color: #2c3e50;
    }
</style>
"""

st.markdown(_PAGE_CSS, unsafe_allow_html=True)
|
|
|
@st.cache_resource
def load_wordllama():
    """Return the default WordLlama model, shared across reruns.

    ``st.cache_resource`` memoizes the returned object, so ``WordLlama.load()``
    runs only on the first call and later script reruns reuse the same model.
    """
    model = WordLlama.load()
    return model
|
|
|
# Module-level model handle; main() calls its similarity, embed, rank and
# deduplicate methods.
wl = load_wordllama()
|
|
|
def create_3d_visualization(texts, embeddings):
    """Build a 3-D Plotly scatter plot of *embeddings*, one labelled point per text.

    The embeddings are reduced with PCA. Once centred, n points span at most
    n - 1 dimensions, and PCA refuses more components than
    min(n_samples, n_features) (it raises ``ValueError``). The component count
    is therefore clamped to ``min(3, n_samples - 1, n_features)``, and any axis
    PCA does not fill stays at 0. This keeps small inputs working, such as the
    two-text similarity comparison.

    Args:
        texts: Point labels, in the same order as the rows of *embeddings*.
        embeddings: 2-D array-like with one embedding vector per row.

    Returns:
        The Plotly figure produced by ``px.scatter_3d``.
    """
    vectors = np.asarray(embeddings, dtype=np.float64)
    n_samples, n_features = vectors.shape

    # Always three columns so the X/Y/Z frame below has a fixed shape.
    coords = np.zeros((n_samples, 3))
    n_components = min(3, n_samples - 1, n_features)
    if n_components > 0:
        pca = PCA(n_components=n_components)
        coords[:, :n_components] = pca.fit_transform(vectors)

    df = pd.DataFrame(coords, columns=['X', 'Y', 'Z'])
    df['text'] = list(texts)

    fig = px.scatter_3d(
        df, x='X', y='Y', z='Z',
        text='text',
        title='Word Embeddings in 3D Space',
    )
    fig.update_traces(
        marker=dict(size=8, opacity=0.8),
        textposition='top center',
    )
    fig.update_layout(
        scene=dict(
            xaxis_title='Component 1',
            yaxis_title='Component 2',
            zaxis_title='Component 3',
        ),
        height=700,
    )
    return fig
|
|
|
# st.text_area rejects heights below 68 px (two lines of text) on current
# Streamlit releases, so document inputs use that minimum.
_DOC_TEXT_AREA_HEIGHT = 68

# Inline style shared by every result card on the ranking and dedup tabs.
_CARD_STYLE = (
    "padding: 10px; margin: 5px; background-color: #f0f2f6; border-radius: 5px;"
)


def _escape_for_card(text):
    """Return user *text* as safe HTML for a result card.

    The text is HTML-escaped so typed markup is shown literally instead of
    being injected into a page rendered with ``unsafe_allow_html=True``.
    Its lines are joined with ``<br>`` so the card stays on one source line;
    a blank line inside a Markdown HTML block would end the block early.
    """
    return "<br>".join(html.escape(text).splitlines())


def _render_card(inner_html):
    """Render already-safe *inner_html* inside a light-grey result card."""
    st.markdown(
        f"<div style='{_CARD_STYLE}'>{inner_html}</div>",
        unsafe_allow_html=True,
    )


def _all_filled(texts):
    """Return True if every string in *texts* has non-whitespace content."""
    return all(text.strip() for text in texts)


def _document_inputs(count, key_prefix):
    """Render *count* "Document N" text areas and return their values in order.

    Args:
        count: Number of text areas to show.
        key_prefix: Widget-key prefix, so the two tabs' document inputs get
            distinct widget identities (keys are ``f"{key_prefix}{i}"``).
    """
    documents = []
    for i in range(count):
        doc = st.text_area(
            f"Document {i+1}",
            value=f"Example document {i+1}",
            height=_DOC_TEXT_AREA_HEIGHT,
            key=f"{key_prefix}{i}",
        )
        documents.append(doc)
    return documents


def _render_similarity_tab():
    """Tab 1: cosine similarity of two texts, plus a 3-D plot of both."""
    st.markdown("### Compare Text Similarity")

    col1, col2 = st.columns(2)
    with col1:
        text1 = st.text_area("First Text", value="I love programming in Python", height=100)
    with col2:
        text2 = st.text_area("Second Text", value="Coding with Python is amazing", height=100)

    if not st.button("Calculate Similarity", key="sim_button"):
        return
    # Blank input has nothing meaningful to embed.
    if not _all_filled([text1, text2]):
        st.warning("Please enter text in both boxes.")
        return

    similarity = wl.similarity(text1, text2)

    st.markdown("### Similarity Score")
    st.metric(
        label="Cosine Similarity",
        value=f"{similarity:.4f}",
        # Cosine similarity is bounded by [-1, 1], not [0, 1].
        help="Score ranges from -1 (opposite) to 1 (identical)",
    )

    embeddings = wl.embed([text1, text2])
    st.plotly_chart(
        create_3d_visualization([text1, text2], embeddings),
        use_container_width=True,
    )


def _render_ranking_tab():
    """Tab 2: rank user documents against a query and plot all of them."""
    st.markdown("### Rank Documents by Similarity")

    query = st.text_area("Query Text", value="I went to the car", height=100)

    st.markdown("### Enter Documents to Rank")
    num_docs = st.slider("Number of documents:", 2, 6, 4)
    documents = _document_inputs(num_docs, "doc_")

    if not st.button("Rank Documents", key="rank_button"):
        return
    if not _all_filled([query, *documents]):
        st.warning("Please fill in the query and every document.")
        return

    # wl.rank yields (document, score) pairs.
    ranked_docs = wl.rank(query, documents)

    st.markdown("### Ranking Results")
    for doc, score in ranked_docs:
        _render_card(f"<b>Score: {score:.4f}</b><br>{_escape_for_card(doc)}")

    all_texts = [query] + documents
    embeddings = wl.embed(all_texts)
    st.plotly_chart(
        create_3d_visualization(all_texts, embeddings),
        use_container_width=True,
    )


def _render_dedup_tab():
    """Tab 3: drop near-duplicate documents above a similarity threshold."""
    st.markdown("### Fuzzy Deduplication")
    st.markdown("""
Remove similar documents based on a similarity threshold.
Documents with similarity above the threshold will be considered duplicates.
""")

    st.markdown("### Enter Documents")
    num_dedup_docs = st.slider("Number of documents:", 2, 8, 4, key="dedup_slider")
    dedup_docs = _document_inputs(num_dedup_docs, "dedup_doc_")

    threshold = st.slider("Similarity Threshold:", 0.0, 1.0, 0.8)

    if not st.button("Find Duplicates", key="dedup_button"):
        return
    if not _all_filled(dedup_docs):
        st.warning("Please fill in every document.")
        return

    # wl.deduplicate returns the documents that were kept.
    unique_docs = wl.deduplicate(dedup_docs, threshold=threshold)

    st.markdown("### Results")
    st.markdown(f"Found {len(unique_docs)} unique documents:")
    for doc in unique_docs:
        _render_card(_escape_for_card(doc))

    # The plot shows every input document, duplicates included.
    embeddings = wl.embed(dedup_docs)
    st.plotly_chart(
        create_3d_visualization(dedup_docs, embeddings),
        use_container_width=True,
    )


def main():
    """Render the WordLlama Embedding Explorer page.

    Draws the title and subtitle, then three tabs: pairwise similarity,
    query-based document ranking, and fuzzy deduplication. Each tab calls
    the model only after its button is pressed.
    """
    st.title("🦙 WordLlama Embedding Explorer")
    st.markdown(
        "<p class='title-font'>Explore the power of WordLlama embeddings</p>",
        unsafe_allow_html=True,
    )

    similarity_tab, ranking_tab, dedup_tab = st.tabs(
        ["💫 Similarity Explorer", "🎯 Document Ranking", "🔍 Fuzzy Deduplication"]
    )
    with similarity_tab:
        _render_similarity_tab()
    with ranking_tab:
        _render_ranking_tab()
    with dedup_tab:
        _render_dedup_tab()
|
|
# `streamlit run` executes this script with __name__ == "__main__", so this
# renders the page whenever Streamlit runs the script.
if __name__ == "__main__":
    main()