File size: 1,930 Bytes
4ba8f5e
 
 
 
 
16d9442
4ba8f5e
 
 
 
 
16d9442
4ba8f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import functools
import pickle

import gradio as gr
import spaces
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import ClapModel, ClapProcessor

def load_results_from_pickle(input_file):
    """Deserialize and return the object stored in the pickle file *input_file*.

    Unpickling can execute arbitrary code, so only pass trusted files.
    """
    with open(input_file, mode="rb") as stream:
        payload = pickle.load(stream)
    return payload

# Hugging Face checkpoint that provides both the CLAP model and its processor.
_CLAP_CHECKPOINT = "laion/larger_clap_music_and_speech"


@functools.lru_cache(maxsize=1)
def _load_clap():
    """Load the CLAP model (moved to CUDA device 0) and processor once, then reuse them."""
    model = ClapModel.from_pretrained(_CLAP_CHECKPOINT).to(0)
    processor = ClapProcessor.from_pretrained(_CLAP_CHECKPOINT)
    return model, processor


@spaces.GPU
def compare_text_to_audio_embeddings(text, pickle_file):
    """Rank pre-computed audio embeddings by cosine similarity to *text*.

    Args:
        text: Free-text query to embed with the CLAP text encoder.
        pickle_file: Path to a pickle holding an iterable of dicts, each with
            ``'filename'``, ``'url'`` and ``'embedding'`` keys. The file is
            unpickled, so it must come from a trusted source.

    Returns:
        A list of ``(filename, url, similarity)`` tuples sorted from highest
        to lowest similarity.
    """
    # Reuse the cached model/processor instead of reloading the large
    # checkpoint on every query.
    model, processor = _load_clap()

    # Embed the query on device 0, then convert to a NumPy array for sklearn.
    text_inputs = processor(text=text, return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**text_inputs.to(0))
    text_embedding = text_embedding.cpu().numpy()

    audio_embeddings = load_results_from_pickle(pickle_file)

    # cosine_similarity returns an (n_query_rows, n_item_rows) matrix; [0][0]
    # is the score against the item's first row. Presumably each stored
    # embedding is a single (1, D) row -- TODO confirm against the producer.
    similarities = [
        (
            item["filename"],
            item["url"],
            cosine_similarity(text_embedding, item["embedding"])[0][0],
        )
        for item in audio_embeddings
    ]

    # Highest similarity first.
    similarities.sort(key=lambda entry: entry[2], reverse=True)
    return similarities

def get_matches(text_query, *, top_k=5, pickle_file="audio_embeddings_v3.pkl"):
    """Return a human-readable summary of the best audio matches for a query.

    Args:
        text_query: Text entered by the user.
        top_k: Maximum number of matches to list (keyword-only, default 5).
        pickle_file: Path of the embedding pickle to search (keyword-only).

    Returns:
        A multi-line string: a header line, a blank line, then one
        ``"filename, url: score"`` line per match (at most *top_k*).
    """
    matches = compare_text_to_audio_embeddings(text_query, pickle_file)

    header = f"Top {top_k} matches for '{text_query}':\n\n"
    body = "".join(
        f"{filename}, {url}: {similarity:.4f}\n"
        for filename, url, similarity in matches[:top_k]
    )
    return header + body

# Assemble the UI: a query textbox and a results textbox side by side,
# plus a button that runs get_matches on the query.
with gr.Blocks() as demo:
    gr.Markdown("# Text to Audio Comparison")
    with gr.Row():
        text_input = gr.Textbox(label="Enter your text query")
        output = gr.Textbox(label="Results", lines=10)
    submit_button = gr.Button("Submit")
    submit_button.click(get_matches, inputs=[text_input], outputs=[output])

# Start serving the app.
demo.launch()