multimodalart HF staff committed on
Commit
4ba8f5e
1 Parent(s): a4b564f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import pickle
4
+ from transformers import ClapModel, ClapProcessor
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+
7
def load_results_from_pickle(input_file):
    """Deserialize and return the object stored in a pickle file.

    Args:
        input_file: Path of the pickle file to read.

    Returns:
        Whatever object was pickled into ``input_file``.
    """
    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # files produced by a trusted source.
    with open(input_file, mode='rb') as stream:
        reader = pickle.Unpickler(stream)
        return reader.load()
10
+
11
_CLAP_CHECKPOINT = "laion/larger_clap_music_and_speech"

# Filled on first use so the checkpoint is loaded once per process rather than
# once per query.
_CLAP_CACHE = {}


def _get_clap():
    """Return the cached ``(model, processor, device)`` triple for CLAP.

    The model is moved to the first CUDA device when CUDA is available and to
    the CPU otherwise.
    """
    if "clap" not in _CLAP_CACHE:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = ClapModel.from_pretrained(_CLAP_CHECKPOINT).to(device)
        processor = ClapProcessor.from_pretrained(_CLAP_CHECKPOINT)
        # Single assignment so the cache never holds a half-built entry.
        _CLAP_CACHE["clap"] = (model, processor, device)
    return _CLAP_CACHE["clap"]


def compare_text_to_audio_embeddings(text, pickle_file):
    """Rank stored audio embeddings by cosine similarity to a text query.

    Args:
        text: Query text handed to the CLAP processor.
        pickle_file: Path of a pickle file containing an iterable of dicts,
            each with ``'filename'``, ``'url'`` and ``'embedding'`` keys.

    Returns:
        A list of ``(filename, url, similarity)`` tuples, sorted with the
        highest similarity first.
    """
    model, processor, device = _get_clap()

    # Generate text embedding
    text_inputs = processor(text=text, return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**text_inputs.to(device))
    text_embedding = text_embedding.cpu().numpy()

    # Load audio embeddings
    audio_embeddings = load_results_from_pickle(pickle_file)

    # Compare embeddings
    # NOTE(review): cosine_similarity needs 2-D inputs, so each stored
    # embedding is presumably shaped (1, dim) -- confirm against the file.
    similarities = [
        (
            item['filename'],
            item["url"],
            cosine_similarity(text_embedding, item['embedding'])[0][0],
        )
        for item in audio_embeddings
    ]

    # Sort by similarity (highest first)
    similarities.sort(key=lambda entry: entry[2], reverse=True)

    return similarities
34
+
35
def get_matches(text_query, top_k=5, pickle_file="audio_embeddings_v3.pkl"):
    """Format the best audio matches for a text query as display text.

    Args:
        text_query: Query text to compare against the stored audio embeddings.
        top_k: Number of highest-ranked matches to include (default 5).
        pickle_file: Path of the pickle file holding the audio embeddings.

    Returns:
        A string with a header line, a blank line, and one
        ``"filename, url: similarity"`` line per match, with the similarity
        formatted to four decimal places.
    """
    matches = compare_text_to_audio_embeddings(text_query, pickle_file)

    # Format the output
    header = f"Top {top_k} matches for '{text_query}':\n\n"
    rows = [
        f"{filename}, {url}: {similarity:.4f}\n"
        for filename, url, similarity in matches[:top_k]
    ]
    return header + "".join(rows)
44
+
45
+ # Create the Gradio interface: a Markdown title, a query textbox, a results textbox, and a Submit button whose click calls get_matches with the query and writes the returned text to the results box
46
+ with gr.Blocks() as demo:
47
+ gr.Markdown("# Text to Audio Comparison")
48
+ with gr.Row():
49
+ text_input = gr.Textbox(label="Enter your text query")
50
+ output = gr.Textbox(label="Results", lines=10)
51
+ submit_button = gr.Button("Submit")
52
+ submit_button.click(fn=get_matches, inputs=text_input, outputs=output)
53
+
54
+ # Launch the app (launch() is called with no arguments, so Gradio's default launch settings apply)
55
+ demo.launch()