Spaces:
Sleeping
Sleeping
File size: 3,388 Bytes
6204178 cae9ed9 6204178 c77f36f 6204178 416ff82 6204178 50b6be3 6204178 50b6be3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import streamlit as st
from qdrant_client import QdrantClient
from transformers import pipeline
from audiocraft.models import MusicGen
import os
import torch
# import baseten
# Page header: app title, tagline, and a short blurb describing what the app
# does and which libraries it is built on.
_INTRO_MARKDOWN = """
The purpose of this app is to help creative people explore the possibilities of Generative AI in the music
domain, while comparing their creations to music made by people with all sorts of instruments.
There are several moving parts to this app and the most important ones are `transformers`, `audiocraft`, and
Qdrant for our vector database.
"""

st.title("Music Recommendation App")
st.subheader("A :red[Generative AI]-to-Real Music Approach")
st.markdown(_INTRO_MARKDOWN)
# Heavy resources are built inside st.cache_resource-decorated loaders so they
# are created once per server process. Without caching, Streamlit re-executes
# this whole script on every widget interaction, so the models would otherwise
# be reloaded on each slider move or keystroke.
# NOTE(review): the cached MusicGen instance is shared by every session, and
# model.set_generation_params() later in this script mutates it. Concurrent users
# could briefly race on each other's duration setting; confirm this is acceptable.
_QDRANT_URL = "https://394294d5-30bb-4958-ad1a-15a3561edce5.us-east-1-0.aws.cloud.qdrant.io:6333"


@st.cache_resource
def _load_qdrant_client(api_key):
    """Return a Qdrant client for the hosted cluster (one per distinct api_key)."""
    return QdrantClient(_QDRANT_URL, api_key=api_key)


@st.cache_resource
def _load_genre_classifier():
    """Return the Hugging Face audio-classification pipeline used for genre prediction."""
    return pipeline("audio-classification", model="ramonpzg/wav2musicgenre")


@st.cache_resource
def _load_musicgen():
    """Return the pretrained 'small' MusicGen text-to-music model."""
    return MusicGen.get_pretrained('small')


# Fail with a readable message instead of a raw KeyError traceback when the
# secret has not been configured.
_qdrant_api_key = os.environ.get('QDRANT_API_KEY')
if not _qdrant_api_key:
    st.error("QDRANT_API_KEY is not set. Configure it as an environment variable/secret to run this app.")
    st.stop()

client = _load_qdrant_client(_qdrant_api_key)
classifier = _load_genre_classifier()
model = _load_musicgen()
# Clip length in seconds chosen by the user; pushed straight into MusicGen's
# sampling configuration together with fixed sampling settings.
val1 = st.slider(
    "How many seconds?",
    min_value=5.0,
    max_value=30.0,
    value=5.0,
    step=0.5,
)
_sampling_config = {"use_sampling": True, "top_k": 250, "duration": val1}
model.set_generation_params(**_sampling_config)

# Free-text description that MusicGen is conditioned on, pre-filled with an example.
music_prompt = st.text_input(
    "Music Prompt",
    value="Fast-paced bachata in the style of Romeo Santos.",
)
if st.button("Generate Some Music!"):
    # Only needed once audio has been generated, so import it lazily here.
    from scipy.signal import resample_poly

    with st.spinner("Wait for it..."):
        # [0, 0, :] selects the first batch item and first channel -> 1-D waveform.
        output = model.generate(descriptions=[music_prompt], progress=True)[0, 0, :].cpu().numpy()
    st.success("Done! :)")
    st.audio(output, sample_rate=model.sample_rate)

    # MusicGen emits audio at model.sample_rate, while the genre classifier's
    # feature extractor works at feature_extractor.sampling_rate. A bare array passed
    # to the pipeline or the extractor is assumed to already be at that rate, so
    # resample once and reuse the result for both classification and the query
    # embedding.
    classifier_sr = classifier.feature_extractor.sampling_rate
    clip = resample_poly(output, classifier_sr, model.sample_rate).astype("float32")

    genres = classifier(clip)
    if genres:
        best, *others = genres
        st.markdown("## Best Prediction")
        label_col, score_col = st.columns(2, gap="small")
        label_col.subheader(best['label'])
        score_col.metric(label="Score", value=f"{best['score']*100:.2f}%")

        st.markdown("### Other Predictions")
        prediction_cols = st.columns(2, gap="small")
        for idx, genre in enumerate(others):
            # Alternate left/right so the predictions fill a two-column grid.
            prediction_cols[idx % 2].metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")

    # Query vector: the classifier backbone's last hidden layer averaged over dim=1
    # (the sequence axis), taken for the single clip in the batch.
    # max_length=16_000 with truncation keeps only the first 16,000 samples. Presumably
    # this mirrors how the vectors in "music_vectors" were computed; TODO confirm
    # against the indexing code.
    features = classifier.feature_extractor(
        clip, sampling_rate=classifier_sr, return_tensors="pt", padding=True,
        return_attention_mask=True, max_length=16_000, truncation=True,
    )
    with torch.no_grad():
        query_vector = classifier.model(**features, output_hidden_states=True).hidden_states[-1].mean(dim=1)[0]
    results = client.search(
        collection_name="music_vectors",
        query_vector=query_vector.tolist(),
        limit=10,
    )

    st.markdown("## Real Recommendations")
    rec_cols = st.columns(2)
    for idx, result in enumerate(results):
        col = rec_cols[idx % 2]
        payload = result.payload
        col.header(f"Genre: {payload['genre']}")
        col.markdown(f"### Artist: {payload['artist']}")
        col.markdown(f"#### Song name: {payload['name']}")
        # Audio preview is best-effort: a missing or unplayable URL must not abort the
        # page. Catch Exception rather than using a bare except, so Streamlit's
        # stop/rerun control exceptions (BaseException subclasses) still propagate.
        try:
            col.audio(payload["urls"])
        except Exception:
            col.caption("Audio preview unavailable.")