File size: 2,539 Bytes
4ba8f5e
 
 
 
 
16d9442
4ba8f5e
 
 
 
 
16d9442
4ba8f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cae6a1
4ba8f5e
58cf143
7cae6a1
 
 
 
 
 
 
 
190a2ea
7cae6a1
 
 
 
 
 
4ba8f5e
 
 
 
58cf143
4ba8f5e
190a2ea
06c1a5a
4ba8f5e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import html
import pickle
import urllib.parse

import gradio as gr
import torch
from transformers import ClapModel, ClapProcessor
from sklearn.metrics.pairwise import cosine_similarity
import spaces

def load_results_from_pickle(input_file):
    """Deserialize and return the object stored in the pickle file *input_file*.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only point this at files produced by this project, never at user uploads.
    """
    with open(input_file, mode='rb') as handle:
        payload = pickle.load(handle)
    return payload

@spaces.GPU
def compare_text_to_audio_embeddings(text, pickle_file):
    """Rank stored audio embeddings by cosine similarity to a text query.

    Loads the ``laion/larger_clap_music_and_speech`` CLAP model (on device 0)
    and its processor, embeds *text*, then scores the text embedding against
    every record loaded from *pickle_file*.

    Each record is read via ``record['embedding']``, ``record['filename']``
    and ``record["url"]``. Presumably the embedding is a 2-D array of shape
    ``(1, dim)`` produced by the same CLAP model -- TODO confirm against the
    script that built the pickle.

    Returns:
        A list of ``(filename, url, similarity)`` tuples ordered from highest
        to lowest similarity; equal scores keep their order from the pickle.
    """
    checkpoint = "laion/larger_clap_music_and_speech"
    # NOTE(review): model and processor are reloaded on every call, which is
    # slow; consider loading them once if the GPU decorator permits it.
    clap = ClapModel.from_pretrained(checkpoint).to(0)
    clap_processor = ClapProcessor.from_pretrained(checkpoint)

    # Embed the query text without tracking gradients.
    encoded = clap_processor(text=text, return_tensors="pt", padding=True)
    with torch.no_grad():
        query_features = clap.get_text_features(**encoded.to(0))
    query_vector = query_features.cpu().numpy()

    def _score(record):
        # Score first, then read the metadata fields.
        value = cosine_similarity(query_vector, record['embedding'])[0][0]
        return (record['filename'], record["url"], value)

    scored = [_score(record) for record in load_results_from_pickle(pickle_file)]
    return sorted(scored, key=lambda row: row[2], reverse=True)

def _youtube_video_id(url):
    """Return the YouTube video ID referenced by *url*, or ``None`` if there is none.

    Recognised forms: ``youtube.com/watch?v=ID`` (other query parameters such
    as ``&t=30s`` or ``&list=...`` are ignored), ``youtube.com/embed/ID``,
    ``/shorts/ID``, ``/live/ID``, ``/v/ID`` and ``youtu.be/ID`` (with or
    without a query string). Scheme-less URLs are accepted as well.
    """
    parts = urllib.parse.urlsplit(url)
    if not parts.netloc:
        # Tolerate scheme-less URLs such as "youtu.be/ID".
        parts = urllib.parse.urlsplit("//" + url)
    host = parts.hostname or ""
    segments = [seg for seg in parts.path.split("/") if seg]

    if host == "youtu.be" or host.endswith(".youtu.be"):
        # Short links: the ID is the first path segment.
        video_id = segments[0] if segments else ""
    elif host == "youtube.com" or host.endswith(".youtube.com"):
        # parse_qs isolates v=... from any other parameters.
        query_ids = urllib.parse.parse_qs(parts.query).get("v")
        if query_ids:
            video_id = query_ids[0]
        elif len(segments) >= 2 and segments[0] in ("embed", "shorts", "live", "v"):
            video_id = segments[1]
        else:
            # Channel pages, playlists, etc. do not name a single video.
            video_id = ""
    else:
        return None
    return video_id or None


def _format_matches(text_query, matches, top_k=3):
    """Render the first *top_k* ``(filename, url, similarity)`` matches as HTML.

    YouTube URLs become an embedded player; any other URL becomes a link.
    The query and URLs are HTML-escaped before interpolation.
    """
    shown = matches[:top_k]
    sections = [f"<h2>Top {len(shown)} matches for '{html.escape(text_query)}'</h2>\n\n"]
    for _filename, url, similarity in shown:
        sections.append(f"{similarity:.4f}\n")
        video_id = _youtube_video_id(url)
        if video_id is not None:
            # quote() guarantees the ID cannot break out of the src attribute.
            embed_src = "https://www.youtube.com/embed/" + urllib.parse.quote(video_id, safe="")
            sections.append(
                f'<iframe width="560" height="315" src="{embed_src}" frameborder="0" '
                'allow="autoplay; encrypted-media" allowfullscreen></iframe>\n\n'
            )
        else:
            safe_url = html.escape(url, quote=True)
            sections.append(f'<a href="{safe_url}">{safe_url}</a>\n\n')
    return "".join(sections)


def get_matches(text_query):
    """Search the stored audio embeddings for *text_query* and render the best hits.

    Args:
        text_query: Free-text description of the desired song.

    Returns:
        An HTML string for the results pane. It contains a heading, then the
        similarity score and either an embedded YouTube player or a plain
        link for each of the top three matches in ``audio_embeddings_v3.pkl``.
    """
    matches = compare_text_to_audio_embeddings(text_query, "audio_embeddings_v3.pkl")
    return _format_matches(text_query, matches)

# Build the Gradio interface: a text box for the song description, a submit
# button, and a Markdown pane that receives the HTML produced by get_matches.
# `demo` is kept at module level so code importing this module can reach it.
with gr.Blocks() as demo:
    gr.Markdown("# Music from Vibe")
    with gr.Row():
        text_input = gr.Textbox(label="Describe your song")
        output = gr.Markdown(label="Results")
    submit_button = gr.Button("Submit")
    submit_button.click(fn=get_matches, inputs=text_input, outputs=output)

# Launch the server only when run as a script, not as a side effect of import.
if __name__ == "__main__":
    demo.launch()