# DexterSptizu — Create app.py (commit c58af7d, verified) — 7.1 kB
import streamlit as st
import numpy as np
from transformers import AutoModel
import plotly.graph_objects as go
from sklearn.manifold import MDS
import pandas as pd
import torch
# Page-level settings: browser tab title, tab icon and wide layout.
st.set_page_config(page_title="Jina Embeddings Explorer", page_icon="🔮", layout="wide")

# Stylesheet injected into the page; it defines the `title-font` class that
# the subtitle paragraph rendered by the app refers to.
CUSTOM_CSS = """
<style>
.title-font {
font-size: 28px !important;
font-weight: bold;
color: #2c3e50;
}
</style>
"""

# Raw HTML is required here because the payload is a <style> element.
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_model():
    """Load the Jina Embeddings v3 model.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the loaded model instead of loading it again on every script rerun.

    Returns:
        The object returned by ``AutoModel.from_pretrained`` for the
        ``jinaai/jina-embeddings-v3`` checkpoint.
    """
    model_id = "jinaai/jina-embeddings-v3"
    # NOTE: trust_remote_code=True allows the checkpoint's own Python code to
    # run locally; keep the model id pinned to a source you trust.
    jina_model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
    return jina_model


# Module-level model instance shared by the embedding helper in this file.
model = load_model()
def get_embeddings(texts, task="text-matching"):
    """Encode ``texts`` with the module-level Jina v3 ``model``.

    Args:
        texts: Input forwarded unchanged to ``model.encode``.
        task: Task name forwarded to ``model.encode`` as ``task``.

    Returns:
        Whatever ``model.encode`` returns for the given inputs.
    """
    # Pure inference: no gradients are needed, so tracking is switched off.
    with torch.no_grad():
        return model.encode(texts, task=task)
def _cosine_similarity_matrix(embeddings):
    """Return the pairwise cosine-similarity matrix of row-wise embeddings."""
    vectors = np.asarray(embeddings, dtype=np.float64)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # An all-zero embedding would divide by zero and fill the matrix with NaN;
    # dividing by 1 instead gives that row a similarity of 0 to everything.
    norms[norms == 0] = 1.0
    unit_vectors = vectors / norms
    # Rounding can push values marginally outside the valid cosine range.
    return np.clip(unit_vectors @ unit_vectors.T, -1.0, 1.0)


def create_similarity_based_visualization(texts, task="text-matching"):
    """Build a 3D scatter plot whose geometry reflects text similarity.

    The texts are embedded, compared pairwise with cosine similarity, and
    the resulting distances (``1 - similarity``) are projected to three
    dimensions with MDS. Each text becomes a labelled marker, and every
    pair of texts is joined by a gray line whose opacity follows the pair's
    similarity (limited to the range 0.1 to 1.0).

    Args:
        texts: Sequence of strings to embed and compare.
        task: Task name forwarded to ``get_embeddings``.

    Returns:
        A ``(fig, similarity_matrix)`` tuple: the Plotly figure and the
        ``n x n`` NumPy array of cosine similarities.
    """
    n = len(texts)

    embeddings = get_embeddings(texts, task=task)
    similarity_matrix = _cosine_similarity_matrix(embeddings)

    # Convert similarities to distances. MDS receives these as precomputed
    # dissimilarities, so keep them non-negative with an exactly-zero
    # diagonal even when rounding leaves a self-similarity slightly off 1.
    distance_matrix = np.clip(1.0 - similarity_matrix, 0.0, None)
    np.fill_diagonal(distance_matrix, 0.0)

    # Fixed random_state keeps the layout reproducible between runs.
    mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
    coords = mds.fit_transform(distance_matrix)

    fig = go.Figure()

    # One labelled marker per text, coloured by its position in the input.
    fig.add_trace(go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(
            size=10,
            color=list(range(n)),
            colorscale='Viridis',
            opacity=0.8,
        ),
        name='Texts',
    ))

    # One line per unordered pair; a separate trace is used for each line
    # because opacity is set per trace.
    for i in range(n):
        for j in range(i + 1, n):
            opacity = max(0.1, min(1.0, float(similarity_matrix[i, j])))
            fig.add_trace(go.Scatter3d(
                x=[coords[i, 0], coords[j, 0]],
                y=[coords[i, 1], coords[j, 1]],
                z=[coords[i, 2], coords[j, 2]],
                mode='lines',
                line=dict(color='gray', width=2),
                opacity=opacity,
                showlegend=False,
                hoverinfo='skip',
            ))

    fig.update_layout(
        title=f"3D Similarity Visualization (Task: {task})",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=1.5, y=1.5, z=1.5),
            ),
        ),
        height=700,
    )

    return fig, similarity_matrix
def main():
    """Render the explorer page and run the analysis on request.

    The page shows an introduction, a task selector, an example-set picker
    with a "Load Example" button, a slider-controlled number of text areas
    and an "Analyze Texts" button. Pressing the latter draws the 3D
    similarity plot, a similarity heatmap and a per-pair textual summary.
    """
    st.title("🔮 Jina Embeddings v3 Explorer")
    st.markdown(
        "<p class='title-font'>Explore text similarities using state-of-the-art embeddings</p>",
        unsafe_allow_html=True,
    )

    with st.expander("ℹ️ About Jina Embeddings v3", expanded=True):
        st.markdown("""
This tool uses Jina Embeddings v3, a powerful multilingual embedding model that supports:
- Multiple tasks: text-matching, retrieval, classification, separation
- Long sequences: up to 8192 tokens
- 30+ languages
- State-of-the-art performance
""")

    # Task selection
    task = st.selectbox(
        "Select Task",
        ["text-matching", "retrieval.query", "retrieval.passage", "separation", "classification"],
        help="Different tasks optimize embeddings for specific use cases",
    )

    # Example templates
    examples = {
        "Similar Concepts": [
            "I love programming in Python",
            "Coding with Python is amazing",
            "Software development is fun",
            "I enjoy writing code",
        ],
        "Multilingual": [
            "Hello, how are you?",
            "Hola, ¿cómo estás?",
            "Bonjour, comment allez-vous?",
            "你好,你好吗?",
        ],
        "Technical Concepts": [
            "Machine learning is a subset of artificial intelligence",
            "AI systems can learn from data",
            "Neural networks process information",
            "Deep learning models require training",
        ],
    }

    col1, col2 = st.columns([3, 1])
    with col1:
        selected_example = st.selectbox("Choose an example set:", list(examples.keys()))
    with col2:
        load_clicked = st.button("Load Example")

    example_texts = examples[selected_example]

    # Reload the example texts when the button is pressed or when a different
    # example set is selected (which also covers the very first run).
    reload_examples = (
        load_clicked
        or st.session_state.get("loaded_example") != selected_example
    )
    if reload_examples:
        st.session_state.texts = example_texts
        st.session_state["loaded_example"] = selected_example

    # Text input
    num_texts = st.slider("Number of texts:", 2, 6, 4)
    texts = []
    for i in range(num_texts):
        widget_key = f"text_{i}"
        default_text = example_texts[i] if i < len(example_texts) else f"Example text {i+1}"
        # The text is written to the widget's own session-state key *before*
        # the widget is created. Storing it only under `st.session_state.texts`
        # had no effect, because nothing in this file reads that entry.
        if reload_examples or widget_key not in st.session_state:
            st.session_state[widget_key] = default_text
        text = st.text_area(f"Text {i+1}", height=100, key=widget_key)
        texts.append(text)

    if st.button("Analyze Texts", type="primary"):
        # Whitespace-only input counts as empty; tell the user instead of
        # silently doing nothing.
        if not all(text.strip() for text in texts):
            st.warning("Please enter some text in every field before analyzing.")
            return

        fig, similarity_matrix = create_similarity_based_visualization(texts, task)

        # Display visualization
        st.plotly_chart(fig, use_container_width=True)

        # Show similarity matrix
        st.markdown("### Similarity Matrix")
        n = len(texts)
        labels = [f"Text {i+1}" for i in range(n)]
        fig_matrix = go.Figure(data=go.Heatmap(
            z=similarity_matrix,
            x=labels,
            y=labels,
            colorscale='Viridis',
            text=np.round(similarity_matrix, 3),
            texttemplate='%{text}',
            textfont={"size": 12},
        ))
        fig_matrix.update_layout(
            title=f"Similarity Matrix (Task: {task})",
            height=400,
        )
        st.plotly_chart(fig_matrix, use_container_width=True)

        # Interpretation: one line per unordered pair of texts.
        st.markdown("### 📊 Similarity Analysis")
        for i in range(n):
            for j in range(i + 1, n):
                similarity = similarity_matrix[i][j]
                if similarity > 0.8:
                    interpretation = "🟢 Very Similar"
                elif similarity > 0.5:
                    interpretation = "🟡 Moderately Similar"
                else:
                    interpretation = "🔴 Different"
                st.write(f"{interpretation} ({similarity:.3f}): Text {i+1} vs Text {j+1}")


if __name__ == "__main__":
    main()