# MusicFromVibe / app.py
# multimodalart's picture
# Update app.py
# 16d9442 verified
# raw
# history blame
# No virus
# 1.93 kB
import gradio as gr
import torch
import pickle
from transformers import ClapModel, ClapProcessor
from sklearn.metrics.pairwise import cosine_similarity
import spaces
def load_results_from_pickle(input_file):
    """Read one pickled object from ``input_file`` and return it.

    Args:
        input_file: Path of the pickle file; it is opened in binary mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.
    """
    # NOTE: unpickling can execute arbitrary code, so only point this at
    # files that come from a trusted source.
    with open(input_file, mode="rb") as stream:
        contents = pickle.load(stream)
    return contents
@spaces.GPU
def compare_text_to_audio_embeddings(text, pickle_file):
    """Rank stored audio embeddings by cosine similarity to a text query.

    Loads the ``laion/larger_clap_music_and_speech`` CLAP checkpoint onto
    device 0, embeds ``text`` with it, then scores every record unpickled
    from ``pickle_file`` against that text embedding.

    Args:
        text: Query text handed to the CLAP processor.
        pickle_file: Path to a pickle holding an iterable of records, each
            providing ``'embedding'``, ``'filename'`` and ``'url'`` entries.

    Returns:
        A list of ``(filename, url, similarity)`` tuples, ordered from the
        highest similarity to the lowest.
    """
    checkpoint = "laion/larger_clap_music_and_speech"
    clap = ClapModel.from_pretrained(checkpoint).to(0)
    clap_processor = ClapProcessor.from_pretrained(checkpoint)

    # Embed the query text without tracking gradients, then move the result
    # to a NumPy array so scikit-learn can consume it.
    encoded = clap_processor(text=text, return_tensors="pt", padding=True)
    with torch.no_grad():
        encoded = encoded.to(0)
        query_vector = clap.get_text_features(**encoded)
    query_vector = query_vector.cpu().numpy()

    def _score(record):
        # Only the [0][0] entry of the similarity matrix is used.
        # NOTE(review): presumably each stored embedding is a single-row 2-D
        # array -- confirm against the script that produced the pickle.
        score = cosine_similarity(query_vector, record['embedding'])[0][0]
        return record['filename'], record["url"], score

    # Score every stored record and order the results best match first.
    return sorted(
        (_score(record) for record in load_results_from_pickle(pickle_file)),
        key=lambda row: row[2],
        reverse=True,
    )
def get_matches(text_query):
    """Format the best-scoring audio matches for ``text_query`` as text.

    Ranks the embeddings stored in ``audio_embeddings_v3.pkl`` against the
    query and reports at most the first five entries of that ranking.

    Args:
        text_query: The text to compare against the stored audio embeddings.

    Returns:
        A string made of a header line, a blank line, and then one
        ``filename, url: similarity`` line per reported match, with the
        similarity printed to four decimal places.
    """
    ranked = compare_text_to_audio_embeddings(text_query, "audio_embeddings_v3.pkl")
    header = f"Top 5 matches for '{text_query}':\n\n"
    rows = [
        f"{name}, {link}: {score:.4f}\n"
        for name, link, score in ranked[:5]
    ]
    return header + "".join(rows)
# Assemble the Gradio UI: a query textbox, a results textbox and a button.
# NOTE(review): the nesting below is reconstructed -- confirm that the row is
# meant to hold both textboxes, with the button placed underneath it.
with gr.Blocks() as demo:
    gr.Markdown(value="# Text to Audio Comparison")
    with gr.Row():
        text_input = gr.Textbox(label="Enter your text query")
        output = gr.Textbox(label="Results", lines=10)
    submit_button = gr.Button(value="Submit")
    # Clicking the button passes the query text to get_matches and shows the
    # string it returns in the results textbox.
    submit_button.click(get_matches, text_input, output)

# Start serving the app.
demo.launch()