# MusicFromVibe / app.py
# Last change by multimodalart: "Update app.py" (commit 06c1a5a, verified), 2.5 kB.
import html
import pickle
from urllib.parse import parse_qs, quote, urlparse

import gradio as gr
import spaces
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import ClapModel, ClapProcessor
def load_results_from_pickle(input_file):
    """Deserialize and return the object stored in the pickle file *input_file*.

    The file is re-read on every call; nothing is cached.

    NOTE: unpickling can execute arbitrary code, so only pass trusted, locally
    stored files (here: the app's own embeddings file).
    """
    with open(input_file, mode="rb") as handle:
        raw_bytes = handle.read()
    return pickle.loads(raw_bytes)
@spaces.GPU
def compare_text_to_audio_embeddings(text, pickle_file):
    """Rank stored audio embeddings by cosine similarity to a text query.

    Args:
        text: Free-text query, embedded with the CLAP text encoder.
        pickle_file: Path to a pickle holding a sequence of dicts, each with
            ``"filename"``, ``"url"`` and ``"embedding"`` keys. Row 0 of each
            ``"embedding"`` is compared against the text embedding.

    Returns:
        A list of ``(filename, url, similarity)`` tuples sorted from most to
        least similar; empty when the pickle holds no entries.
    """
    # NOTE(review): the model and processor are re-loaded on every request,
    # which is slow. It is presumably kept inside the @spaces.GPU function so
    # the GPU is only touched while the decorated call runs — confirm before
    # hoisting the load to module level.
    model = ClapModel.from_pretrained("laion/larger_clap_music_and_speech").to(0)
    processor = ClapProcessor.from_pretrained("laion/larger_clap_music_and_speech")

    # Embed the text query on device 0, then bring it back to NumPy for sklearn.
    text_inputs = processor(text=text, return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**text_inputs.to(0))
    text_embedding = text_embedding.cpu().numpy()

    audio_embeddings = load_results_from_pickle(pickle_file)
    if not audio_embeddings:
        # cosine_similarity rejects an empty second operand.
        return []

    # One vectorised similarity call instead of one sklearn call per clip.
    # Taking row 0 of each stored embedding matches the per-item
    # ``cosine_similarity(text_embedding, embedding)[0][0]`` it replaces.
    scores = cosine_similarity(
        text_embedding,
        [item["embedding"][0] for item in audio_embeddings],
    )[0]

    similarities = [
        (item["filename"], item["url"], score)
        for item, score in zip(audio_embeddings, scores)
    ]
    # Highest similarity first.
    similarities.sort(key=lambda entry: entry[2], reverse=True)
    return similarities
def _youtube_video_id(url):
    """Return the video ID from a YouTube *url*, or None if none can be found.

    ``watch?v=ID`` URLs are read from the ``v`` query parameter, so extra
    parameters such as ``&t=42s`` or ``&list=...`` are not glued onto the ID.
    Path-style URLs (``youtu.be/ID``, ``/shorts/ID``, ``/embed/ID``) use the
    last non-empty path segment.
    """
    parsed = urlparse(url)
    query_ids = parse_qs(parsed.query).get("v")
    if query_ids:
        return query_ids[0]
    segments = [segment for segment in parsed.path.split("/") if segment]
    return segments[-1] if segments else None


def get_matches(text_query):
    """Render the best-matching audio clips for *text_query* as Markdown/HTML.

    Each match shows its similarity score, followed by an embedded YouTube
    player when the URL points to YouTube, or by a plain link otherwise.

    Args:
        text_query: The user's free-text description.

    Returns:
        A Markdown string for the results component.
    """
    top_k = 3  # number of matches rendered

    # Skip the (GPU-backed) embedding call when there is nothing to search for.
    if not text_query or not text_query.strip():
        return "Please enter a text query."

    matches = compare_text_to_audio_embeddings(text_query, "audio_embeddings_v3.pkl")
    top_matches = matches[:top_k]

    # Report the number of matches actually rendered (fewer than top_k when
    # the index is small).
    parts = [f"## Top {len(top_matches)} matches for '{text_query}'\n\n"]
    for _filename, url, similarity in top_matches:
        parts.append(f"{similarity:.4f}\n")
        is_youtube = "youtube.com" in url or "youtu.be" in url
        video_id = _youtube_video_id(url) if is_youtube else None
        if video_id:
            embed_src = f"https://www.youtube.com/embed/{quote(video_id, safe='')}"
            parts.append(
                f'<iframe width="560" height="315" src="{embed_src}" '
                'frameborder="0" allow="autoplay; encrypted-media" '
                "allowfullscreen></iframe>\n\n"
            )
        else:
            # Escape so quotes or angle brackets in the URL cannot break the tag.
            safe_url = html.escape(url, quote=True)
            parts.append(f'<a href="{safe_url}">{safe_url}</a>\n\n')
    return "".join(parts)
# Build the Gradio UI: a query textbox, a Markdown results pane and a
# Submit button, all wired to get_matches.
with gr.Blocks() as demo:
    gr.Markdown("# Text to Audio Comparison")
    with gr.Row():
        text_input = gr.Textbox(label="Enter your text query")
        output = gr.Markdown(label="Results")
    submit_button = gr.Button("Submit")
    submit_button.click(fn=get_matches, inputs=text_input, outputs=output)
    # Pressing Enter in the textbox runs the same search as the button.
    text_input.submit(fn=get_matches, inputs=text_input, outputs=output)

# Launch only when executed as a script, so importing this module does not
# start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()