|
from audiocraft.models import MusicGen |
|
import streamlit as st |
|
import torch |
|
import torchaudio |
|
from io import BytesIO |
|
|
|
@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The function is wrapped in ``st.cache_resource`` so that the model object
    is shared by later calls rather than being fetched again on each one.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
|
|
|
def generate_music_tensors(description, duration: int):
    """Generate music for one text prompt.

    Args:
        description: Text prompt passed to the model as a single-item batch.
        duration: Requested clip length, forwarded to the model's
            ``duration`` generation parameter.

    Returns:
        Element 0 of the value returned by ``model.generate`` when it is
        called with ``return_tokens=True`` (presumably the audio waveform
        batch, with the tokens in a later element -- confirm against the
        audiocraft API).
    """
    # Echo the request to stdout so it shows up in the server log.
    print("Description:", description)
    print("Duration:", duration)

    musicgen = load_model()

    sampling_options = {
        "use_sampling": True,
        "top_k": 300,
        "top_p": 0.85,
        "temperature": 0.8,
        "duration": duration,
    }
    musicgen.set_generation_params(**sampling_options)

    generated = musicgen.generate(
        descriptions=[description],
        progress=True,
        return_tokens=True,
    )
    return generated[0]
|
|
|
def save_audio_to_bytes(samples: torch.Tensor, sample_rate: int = 32000) -> BytesIO:
    """Encode the first clip of a batch as an in-memory WAV file.

    Args:
        samples: A 3-D tensor with at least one entry along its first
            dimension. Only ``samples[0]`` is encoded (presumably the layout
            is (batch, channels, time) -- confirm against the generator).
        sample_rate: Sample rate written to the WAV header. Defaults to
            32000, the value this function previously hard-coded.

    Returns:
        A ``BytesIO`` holding the WAV data, rewound to position 0.

    Raises:
        ValueError: If ``samples`` is not 3-D or its first dimension is empty.
    """
    # Explicit checks instead of ``assert``: assertions are removed when
    # Python runs with -O, which would let malformed input reach the encoder.
    if samples.dim() != 3:
        raise ValueError(
            f"samples must be a 3-D tensor, got {samples.dim()} dimension(s)"
        )
    if samples.size(0) == 0:
        raise ValueError("samples must contain at least one clip")

    # Detach from autograd and move to host memory before encoding.
    first_clip = samples[0].detach().cpu()

    byte_io = BytesIO()
    torchaudio.save(byte_io, first_clip, sample_rate=sample_rate, format="wav")
    # Rewind so consumers read from the beginning of the WAV data.
    byte_io.seek(0)
    return byte_io
|
|
|
# Page metadata: the page title and the emoji shortcode used as the page icon.
st.set_page_config(page_title="Music Gen", page_icon=":musical_note:")
|
|
|
def main():
    """Render the app: collect a prompt and duration, then play and offer the clip.

    Streamlit re-executes the script on widget interaction, including a click
    on the download button. Because generation is slow and uses sampling,
    regenerating on such a rerun would make the user wait again and could
    hand them a different clip from the one they just heard. The encoded WAV
    bytes are therefore kept in ``st.session_state`` and reused while the
    description and duration are unchanged.
    """
    st.title("Your Music")

    with st.expander("See Explanation"):
        st.write("App is developed using Meta's Audiocraft Music Gen model. Write a description and we will generate audio.")

    text_area = st.text_area("Enter description")
    time_slider = st.slider("Select time duration (seconds)", 2, 20, 10)

    # Nothing to generate until the user has entered a description.
    if not (text_area and time_slider):
        return

    st.json({
        "Description": text_area,
        "Selected duration": time_slider,
    })

    st.write("We will be back with your music....please enjoy doing the rest of your tasks while we come back in some time :)")
    st.subheader("Generated Music")

    cache_key = "_generated_audio"
    request = (text_area, time_slider)
    cached = st.session_state.get(cache_key)

    if cached is not None and cached[0] == request:
        # Same inputs as the previous run of this session: reuse the clip.
        audio_bytes = cached[1]
    else:
        music_tensors = generate_music_tensors(text_area, time_slider)
        audio_bytes = save_audio_to_bytes(music_tensors).getvalue()
        st.session_state[cache_key] = (request, audio_bytes)

    st.audio(audio_bytes, format="audio/wav")
    st.download_button(
        label="Download Audio",
        data=audio_bytes,
        file_name="generated_music.wav",
        mime="audio/wav",
    )
|
|
|
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|