# Spaces: Running
from itertools import combinations

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
import torch
from sklearn.manifold import MDS
from transformers import AutoModel
# Page configuration: browser-tab title, tab icon and a full-width layout.
_PAGE_CONFIG = {
    "page_title": "Jina Embeddings Explorer",
    "page_icon": "🔮",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)

# Custom CSS for the `title-font` class used by the page tagline.
_CUSTOM_CSS = """
<style>
.title-font {
font-size: 28px !important;
font-weight: bold;
color: #2c3e50;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_model():
    """Load the Jina Embeddings v3 model from the Hugging Face Hub.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, the model would be re-instantiated on each rerun.
    ``st.cache_resource`` keeps one shared instance for the server process.

    Returns:
        The model object produced by ``AutoModel.from_pretrained``.
    """
    # trust_remote_code=True lets transformers run the custom model code
    # hosted in the repo. The ``encode(..., task=...)`` method used by
    # get_embeddings is presumably provided by that code.
    return AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)


# Module-level handle used by get_embeddings. The call is cheap after the
# first run because load_model is cached.
model = load_model()
@torch.no_grad()
def get_embeddings(texts, task="text-matching"):
    """Encode ``texts`` with the module-level Jina v3 ``model``.

    Gradient tracking is disabled for the whole call, since this is
    inference only.

    Args:
        texts: Strings to embed.
        task: Task name forwarded as the ``task`` argument of ``model.encode``.

    Returns:
        Whatever ``model.encode`` returns for the batch. Callers in this file
        index it per text.
    """
    return model.encode(texts, task=task)
def cosine_similarity_matrix(embeddings):
    """Return the pairwise cosine-similarity matrix of a 2-D embedding array.

    Rows are normalised to unit length and multiplied in one matrix product,
    instead of an O(n^2) Python double loop. All-zero rows, which would
    otherwise divide 0/0 and produce NaN, get similarity 0 to every other
    row. The result is symmetrised, clipped to [-1, 1] to remove
    floating-point overshoot, and its diagonal is set to exactly 1.

    Args:
        embeddings: Array-like of shape (n, d), one embedding per row.

    Returns:
        ``np.ndarray`` of shape (n, n) with float64 entries.
    """
    vectors = np.asarray(embeddings, dtype=np.float64)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Divide zero-norm rows by 1 instead of 0; they stay all-zero.
    unit = vectors / np.where(norms == 0, 1.0, norms)
    similarity = unit @ unit.T
    # MDS with precomputed dissimilarities requires a symmetric matrix with a
    # zero diagonal (distance = 1 - similarity). Rounding can break both.
    similarity = np.clip((similarity + similarity.T) / 2.0, -1.0, 1.0)
    np.fill_diagonal(similarity, 1.0)
    return similarity


def create_similarity_based_visualization(texts, task="text-matching"):
    """Embed ``texts`` and build a 3-D plot whose layout reflects their similarity.

    Pairwise cosine similarities are converted to distances
    (``1 - similarity``) and embedded in three dimensions with MDS. Each text
    is drawn as a labelled marker. Every pair of texts is joined by a grey
    line whose opacity is the pair's similarity clamped to [0.1, 1.0].

    Args:
        texts: Texts to embed and plot.
        task: Task name forwarded to ``get_embeddings``.

    Returns:
        ``(fig, similarity_matrix)``: the Plotly 3-D figure and the (n, n)
        cosine-similarity matrix it was built from.
    """
    embeddings = get_embeddings(texts, task=task)
    similarity_matrix = cosine_similarity_matrix(embeddings)

    # Convert similarities to distances: 0 on the diagonal, at most 2.
    distance_matrix = 1.0 - similarity_matrix

    # MDS on the precomputed distances. The fixed seed makes the layout
    # reproducible for identical inputs.
    mds = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
    coords = mds.fit_transform(distance_matrix)

    fig = go.Figure()

    # One marker per text, labelled with the text itself.
    fig.add_trace(go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        text=texts,
        textposition='top center',
        marker=dict(
            size=10,
            color=list(range(len(texts))),
            colorscale='Viridis',
            opacity=0.8
        ),
        name='Texts'
    ))

    # One edge per unordered pair; more similar pairs are drawn more opaque.
    for i, j in combinations(range(len(texts)), 2):
        opacity = float(np.clip(similarity_matrix[i, j], 0.1, 1.0))
        fig.add_trace(go.Scatter3d(
            x=[coords[i, 0], coords[j, 0]],
            y=[coords[i, 1], coords[j, 1]],
            z=[coords[i, 2], coords[j, 2]],
            mode='lines',
            line=dict(
                color='gray',
                width=2
            ),
            opacity=opacity,
            showlegend=False,
            hoverinfo='skip'
        ))

    fig.update_layout(
        title=f"3D Similarity Visualization (Task: {task})",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=1.5, y=1.5, z=1.5)
            )
        ),
        height=700
    )
    return fig, similarity_matrix
# Example sets offered by the "Choose an example set" selector.
EXAMPLES = {
    "Similar Concepts": [
        "I love programming in Python",
        "Coding with Python is amazing",
        "Software development is fun",
        "I enjoy writing code"
    ],
    "Multilingual": [
        "Hello, how are you?",
        "Hola, ¿cómo estás?",
        "Bonjour, comment allez-vous?",
        "你好,你好吗?"
    ],
    "Technical Concepts": [
        "Machine learning is a subset of artificial intelligence",
        "AI systems can learn from data",
        "Neural networks process information",
        "Deep learning models require training"
    ]
}

# Task names selectable in the UI; forwarded as the embedding ``task``.
TASKS = ["text-matching", "retrieval.query", "retrieval.passage", "separation", "classification"]


def interpret_similarity(similarity):
    """Map a cosine similarity to a coloured, human-readable label.

    Values above 0.8 are "Very Similar", above 0.5 "Moderately Similar",
    and anything else "Different".
    """
    if similarity > 0.8:
        return "🟢 Very Similar"
    if similarity > 0.5:
        return "🟡 Moderately Similar"
    return "🔴 Different"


def _render_header():
    """Render the page title, tagline and the collapsible model description."""
    st.title("🔮 Jina Embeddings v3 Explorer")
    st.markdown("<p class='title-font'>Explore text similarities using state-of-the-art embeddings</p>",
                unsafe_allow_html=True)
    with st.expander("ℹ️ About Jina Embeddings v3", expanded=True):
        st.markdown("""
This tool uses Jina Embeddings v3, a powerful multilingual embedding model that supports:
- Multiple tasks: text-matching, retrieval, classification, separation
- Long sequences: up to 8192 tokens
- 30+ languages
- State-of-the-art performance
""")


def _choose_example():
    """Render the example selector and "Load Example" button; return the set name.

    Pressing the button copies the chosen example into the session-state keys
    (``text_{i}``) of the text areas. This must happen before those widgets
    are created in the current script run, which is why it runs here.
    """
    col1, col2 = st.columns([3, 1])
    with col1:
        selected_example = st.selectbox("Choose an example set:", list(EXAMPLES.keys()))
    with col2:
        if st.button("Load Example"):
            for i, example_text in enumerate(EXAMPLES[selected_example]):
                st.session_state[f"text_{i}"] = example_text
    return selected_example


def _collect_texts(selected_example):
    """Render the text-count slider and one text area per text; return their contents.

    The first time a ``text_{i}`` key is seen, it is seeded from the selected
    example set, or with a placeholder past the end of the set. After that,
    session state holds the user's edits. No ``value=`` is passed, so the
    widget reads only from session state.
    """
    example_texts = EXAMPLES[selected_example]
    num_texts = st.slider("Number of texts:", 2, 6, 4)
    texts = []
    for i in range(num_texts):
        key = f"text_{i}"
        if key not in st.session_state:
            st.session_state[key] = (example_texts[i] if i < len(example_texts)
                                     else f"Example text {i+1}")
        texts.append(st.text_area(f"Text {i+1}", height=100, key=key))
    return texts


def _render_results(texts, task):
    """Embed ``texts`` and render the 3-D plot, the heatmap and the pairwise verdicts."""
    with st.spinner("Computing embeddings..."):
        fig, similarity_matrix = create_similarity_based_visualization(texts, task)

    # Display visualization
    st.plotly_chart(fig, use_container_width=True)

    # Show similarity matrix
    st.markdown("### Similarity Matrix")
    labels = [f"Text {i+1}" for i in range(len(texts))]
    fig_matrix = go.Figure(data=go.Heatmap(
        z=similarity_matrix,
        x=labels,
        y=labels,
        colorscale='Viridis',
        text=np.round(similarity_matrix, 3),
        texttemplate='%{text}',
        textfont={"size": 12},
    ))
    fig_matrix.update_layout(
        title=f"Similarity Matrix (Task: {task})",
        height=400
    )
    st.plotly_chart(fig_matrix, use_container_width=True)

    # Interpretation: one line per unordered pair of texts.
    st.markdown("### 📊 Similarity Analysis")
    for i, j in combinations(range(len(texts)), 2):
        similarity = similarity_matrix[i][j]
        interpretation = interpret_similarity(similarity)
        st.write(f"{interpretation} ({similarity:.3f}): Text {i+1} vs Text {j+1}")


def main():
    """Streamlit entry point: render the page and run the analysis on demand."""
    _render_header()

    # Task selection
    task = st.selectbox(
        "Select Task",
        TASKS,
        help="Different tasks optimize embeddings for specific use cases"
    )

    selected_example = _choose_example()
    texts = _collect_texts(selected_example)

    if st.button("Analyze Texts", type="primary"):
        # Require non-blank text in every box, and tell the user when one is missing.
        if all(text.strip() for text in texts):
            _render_results(texts, task)
        else:
            st.warning("Please fill in every text box before analyzing.")


if __name__ == "__main__":
    main()