MusicFromVibe / app.py
multimodalart's picture
Update app.py
6d9ab6d verified
raw
history blame contribute delete
No virus
2.75 kB
import functools
import html
import pickle

import gradio as gr
import torch
from transformers import ClapModel, ClapProcessor
from sklearn.metrics.pairwise import cosine_similarity
import spaces
def load_results_from_pickle(input_file):
    """Deserialize and return the single object stored in the pickle file *input_file*.

    NOTE(review): unpickling can execute arbitrary code, so only call this
    on files you trust.
    """
    with open(input_file, mode="rb") as handle:
        results = pickle.load(handle)
    return results
_CLAP_CHECKPOINT = "laion/larger_clap_music_and_speech"


@functools.lru_cache(maxsize=1)
def _load_clap():
    """Load the CLAP model (on CUDA device 0) and its processor, caching the pair.

    The original code re-instantiated both on every request. Caching avoids
    that repeated cost whenever the same process serves more than one call.
    NOTE(review): under ZeroGPU the @spaces.GPU call may execute in a forked
    worker, in which case this cache may not survive between requests.
    Verify against the ZeroGPU docs, which favour loading models at import time.
    """
    model = ClapModel.from_pretrained(_CLAP_CHECKPOINT).to(0)
    processor = ClapProcessor.from_pretrained(_CLAP_CHECKPOINT)
    return model, processor


@spaces.GPU
def compare_text_to_audio_embeddings(text, pickle_file):
    """Rank stored audio clips by cosine similarity to a text prompt.

    Embeds *text* with the CLAP text encoder, then compares that embedding
    against every entry loaded from *pickle_file*.

    Args:
        text: Text prompt to embed.
        pickle_file: Path to a pickle holding an iterable of dicts with
            ``"filename"``, ``"url"`` and ``"embedding"`` keys. The embedding
            must be 2-D for ``cosine_similarity``; presumably shape (1, dim),
            TODO confirm.

    Returns:
        A list of ``(filename, url, similarity)`` tuples, highest similarity first.
    """
    model, processor = _load_clap()

    # Generate the text embedding on the GPU, then bring it back as a NumPy array.
    text_inputs = processor(text=text, return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**text_inputs.to(0))
    text_embedding = text_embedding.cpu().numpy()

    audio_embeddings = load_results_from_pickle(pickle_file)

    # [0][0]: similarity between the first text row and the first embedding row.
    similarities = [
        (
            item["filename"],
            item["url"],
            cosine_similarity(text_embedding, item["embedding"])[0][0],
        )
        for item in audio_embeddings
    ]
    similarities.sort(key=lambda entry: entry[2], reverse=True)
    return similarities
def get_matches(text_query):
    """Return an HTML/Markdown summary of the top 3 songs matching *text_query*.

    Every match is rendered as ``<song name>: <link>``, where the song name is
    the stored filename with its last extension removed. Matches are listed
    whatever host their URL points to. Previously, non-YouTube matches were
    silently dropped, so fewer than 3 results could appear.

    Args:
        text_query: Free-text description of the desired music.

    Returns:
        A string rendered by the results ``gr.Markdown`` component.
    """
    # Nothing to embed: skip the GPU model call instead of ranking an empty prompt.
    if not text_query or not text_query.strip():
        return "<h2>Please describe what you want to hear.</h2>"

    matches = compare_text_to_audio_embeddings(text_query, "audio_embeddings_v3.pkl")

    # User- and data-supplied text is escaped before being embedded in HTML.
    parts = [f"<h2>Top 3 matches for '{html.escape(text_query)}'</h2>\n\n"]
    for filename, url, _similarity in matches[:3]:
        song_name = filename.rsplit(".", 1)[0]
        safe_url = html.escape(url, quote=True)
        parts.append(f'{html.escape(song_name)}: <a href="{safe_url}">{safe_url}</a>\n\n')
    return "".join(parts)
# Create the Gradio interface.
# NOTE(review): the original indentation was lost; the textbox is assumed to be
# the only child of the Row — confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Music from Vibe")
    gr.Markdown("Match the text embedding with CLAP music embeddings against Billboard's 100 Greatest Songs of All Time")
    with gr.Row():
        text_input = gr.Textbox(label="Describe what you want to hear")
    output = gr.Markdown(label="Results")
    submit_button = gr.Button("Submit")
    submit_button.click(fn=get_matches, inputs=text_input, outputs=output)
    # Also run the search when the user presses Enter in the textbox.
    text_input.submit(fn=get_matches, inputs=text_input, outputs=output)

# Launch the app only when run as a script, so that importing this module
# (e.g. from tests or tooling) does not start a web server.
if __name__ == "__main__":
    demo.launch()