import streamlit as st
import numpy as np
from itertools import combinations
from transformers import AutoModel
import plotly.graph_objects as go
from sklearn.manifold import MDS
import pandas as pd
import torch

# Page configuration
st.set_page_config(
    page_title="Jina Embeddings Explorer",
    page_icon="🔮",
    layout="wide",
)

# Custom CSS injected into the page. The stylesheet is currently empty, so
# nothing is emitted; add rules here to restyle the app.
CUSTOM_CSS = ""
if CUSTOM_CSS:
    st.markdown(f"<style>{CUSTOM_CSS}</style>", unsafe_allow_html=True)

# Upper bound of the "Number of texts" slider; also the number of text boxes
# that "Load Example" pre-fills in session state.
MAX_TEXTS = 6

ABOUT_TEXT = """
This tool uses Jina Embeddings v3, a powerful multilingual embedding model that supports:
- Multiple tasks: text-matching, retrieval, classification, separation
- Long sequences: up to 8192 tokens
- 30+ languages
- State-of-the-art performance
"""

TASKS = [
    "text-matching",
    "retrieval.query",
    "retrieval.passage",
    "separation",
    "classification",
]

# Example templates selectable from the UI.
EXAMPLES = {
    "Similar Concepts": [
        "I love programming in Python",
        "Coding with Python is amazing",
        "Software development is fun",
        "I enjoy writing code",
    ],
    "Multilingual": [
        "Hello, how are you?",
        "Hola, ¿cómo estás?",
        "Bonjour, comment allez-vous?",
        "你好,你好吗?",
    ],
    "Technical Concepts": [
        "Machine learning is a subset of artificial intelligence",
        "AI systems can learn from data",
        "Neural networks process information",
        "Deep learning models require training",
    ],
}


@st.cache_resource
def load_model():
    """Load the Jina Embeddings v3 model; cached so reruns reuse one instance."""
    return AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)


model = load_model()


def get_embeddings(texts, task="text-matching"):
    """Encode *texts* with Jina v3 using the given task adapter.

    Args:
        texts: list of strings to embed.
        task: Jina v3 task name, e.g. ``"text-matching"`` or ``"retrieval.query"``.

    Returns:
        Whatever ``model.encode`` returns — presumably an array of shape
        (len(texts), dim); callers convert it with ``np.asarray``.
    """
    with torch.no_grad():
        embeddings = model.encode(texts, task=task)
    return embeddings


def _cosine_similarity_matrix(embeddings):
    """Return the symmetric pairwise cosine-similarity matrix of the rows of *embeddings*.

    Zero-norm rows are not divided by zero; their similarities come out as 0.
    Values are clipped to [-1, 1] to remove floating-point overshoot.
    """
    vectors = np.asarray(embeddings, dtype=np.float64)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    unit = vectors / np.where(norms == 0, 1.0, norms)
    similarity = unit @ unit.T
    # Matrix products are not guaranteed bit-for-bit symmetric; MDS with a
    # precomputed dissimilarity checks symmetry, so enforce it exactly.
    similarity = (similarity + similarity.T) / 2.0
    return np.clip(similarity, -1.0, 1.0)


def create_similarity_based_visualization(texts, task="text-matching"):
    """Build a 3D scatter plot of *texts* laid out by embedding similarity.

    The cosine distances between embeddings are embedded in 3D with MDS, and
    every pair of points is joined by a line whose opacity follows similarity.

    Args:
        texts: list of strings to embed and plot.
        task: Jina v3 task adapter passed through to ``get_embeddings``.

    Returns:
        ``(fig, similarity_matrix)``: the Plotly figure and the (n, n)
        cosine-similarity ndarray.
    """
    n = len(texts)
    embeddings = get_embeddings(texts, task=task)
    similarity_matrix = _cosine_similarity_matrix(embeddings)

    # Cosine distance lies in [0, 2] after clipping; pin the diagonal to an
    # exact zero so MDS receives a proper dissimilarity matrix.
    distance_matrix = 1.0 - similarity_matrix
    np.fill_diagonal(distance_matrix, 0.0)

    mds = MDS(n_components=3, dissimilarity="precomputed", random_state=42)
    coords = mds.fit_transform(distance_matrix)

    fig = go.Figure()

    # One marker per text, labelled with the text itself.
    fig.add_trace(go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode="markers+text",
        text=texts,
        textposition="top center",
        marker=dict(
            size=10,
            color=list(range(n)),
            colorscale="Viridis",
            opacity=0.8,
        ),
        name="Texts",
    ))

    # Edges between every pair; less similar pairs are drawn fainter, with a
    # floor of 0.1 so every edge stays visible.
    for i, j in combinations(range(n), 2):
        opacity = float(np.clip(similarity_matrix[i, j], 0.1, 1.0))
        fig.add_trace(go.Scatter3d(
            x=[coords[i, 0], coords[j, 0]],
            y=[coords[i, 1], coords[j, 1]],
            z=[coords[i, 2], coords[j, 2]],
            mode="lines",
            line=dict(color="gray", width=2),
            opacity=opacity,
            showlegend=False,
            hoverinfo="skip",
        ))

    fig.update_layout(
        title=f"3D Similarity Visualization (Task: {task})",
        scene=dict(
            xaxis_title="Dimension 1",
            yaxis_title="Dimension 2",
            zaxis_title="Dimension 3",
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=1.5, y=1.5, z=1.5),
            ),
        ),
        height=700,
    )

    return fig, similarity_matrix


def _default_text(example_name, index):
    """Return the *index*-th text of an example set, or a numbered placeholder."""
    example = EXAMPLES.get(example_name, [])
    return example[index] if index < len(example) else f"Example text {index + 1}"


def _load_example_into_state(example_name):
    """Overwrite the session state of every text box with the chosen example set."""
    for i in range(MAX_TEXTS):
        st.session_state[f"text_{i}"] = _default_text(example_name, i)


def _interpret(similarity):
    """Map a cosine similarity to a coarse human-readable label."""
    if similarity > 0.8:
        return "🟢 Very Similar"
    if similarity > 0.5:
        return "🟡 Moderately Similar"
    return "🔴 Different"


def _render_results(texts, task):
    """Embed *texts*, then draw the 3D plot, the similarity heatmap and a pairwise summary."""
    fig, similarity_matrix = create_similarity_based_visualization(texts, task)

    # Display visualization
    st.plotly_chart(fig, use_container_width=True)

    # Show similarity matrix
    st.markdown("### Similarity Matrix")
    labels = [f"Text {i + 1}" for i in range(len(texts))]
    fig_matrix = go.Figure(data=go.Heatmap(
        z=similarity_matrix,
        x=labels,
        y=labels,
        colorscale="Viridis",
        text=np.round(similarity_matrix, 3),
        texttemplate="%{text}",
        textfont={"size": 12},
    ))
    fig_matrix.update_layout(
        title=f"Similarity Matrix (Task: {task})",
        height=400,
    )
    st.plotly_chart(fig_matrix, use_container_width=True)

    # Interpretation
    st.markdown("### 📊 Similarity Analysis")
    for i, j in combinations(range(len(texts)), 2):
        similarity = similarity_matrix[i, j]
        st.write(f"{_interpret(similarity)} ({similarity:.3f}): Text {i+1} vs Text {j+1}")


def main():
    """Render the Jina Embeddings v3 explorer page."""
    st.title("🔮 Jina Embeddings v3 Explorer")
    st.markdown("Explore text similarities using state-of-the-art embeddings")

    with st.expander("ℹ️ About Jina Embeddings v3", expanded=True):
        st.markdown(ABOUT_TEXT)

    # Task selection
    task = st.selectbox(
        "Select Task",
        TASKS,
        help="Different tasks optimize embeddings for specific use cases",
    )

    col1, col2 = st.columns([3, 1])
    with col1:
        selected_example = st.selectbox("Choose an example set:", list(EXAMPLES))
    with col2:
        if st.button("Load Example"):
            # This runs before the text areas are created in this rerun, so
            # writing their keyed session state here is what fills them.
            _load_example_into_state(selected_example)

    # Text input
    num_texts = st.slider("Number of texts:", 2, MAX_TEXTS, 4)
    texts = []
    for i in range(num_texts):
        key = f"text_{i}"
        # Seed a box only when it has no state yet; afterwards the user's
        # edits (or a loaded example) live in session state under `key`.
        if key not in st.session_state:
            st.session_state[key] = _default_text(selected_example, i)
        texts.append(st.text_area(f"Text {i + 1}", height=100, key=key))

    if st.button("Analyze Texts", type="primary"):
        # Reject empty and whitespace-only inputs instead of silently ignoring the click.
        if all(text.strip() for text in texts):
            _render_results(texts, task)
        else:
            st.warning("Please fill in every text box before analyzing.")


if __name__ == "__main__":
    main()