Music_Generator / app.py
annapurnapadmaprema-ji's picture
Update app.py
4d80d41 verified
raw
history blame
2.39 kB
import logging
from io import BytesIO

import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The ``st.cache_resource`` decorator memoises the result, so Streamlit
    reruns reuse the already-loaded model instead of fetching it again.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
def generate_music_tensors(description, duration: int):
    """Generate audio for a single text prompt with the cached MusicGen model.

    Args:
        description: Text prompt describing the desired music.
        duration: Length of the clip to generate, in seconds.

    Returns:
        The waveform tensor produced by ``model.generate`` for a one-prompt
        batch. ``save_audio_to_bytes`` consumes it and expects the layout
        (batch, channels, samples).
    """
    logger = logging.getLogger(__name__)
    logger.info("Description: %s", description)
    logger.info("Duration: %s", duration)

    model = load_model()
    # NOTE(review): in audiocraft a non-zero top_p appears to take precedence
    # over top_k during sampling, so top_k=300 may have no effect -- confirm.
    model.set_generation_params(
        use_sampling=True,
        top_k=300,
        top_p=0.85,
        temperature=0.8,
        duration=duration,
    )
    # Only the waveform is needed, so don't request the (wav, tokens) tuple.
    return model.generate(descriptions=[description], progress=True)
def save_audio_to_bytes(samples: torch.Tensor, sample_rate: int = 32000) -> BytesIO:
    """Encode the first item of a batched waveform as an in-memory WAV file.

    Args:
        samples: Tensor laid out as (batch, channels, samples); only batch
            item 0 is encoded.
        sample_rate: Sample rate written into the WAV header. Defaults to
            32000 Hz; pass the model's own rate if a different checkpoint
            is used.

    Returns:
        A ``BytesIO`` rewound to offset 0, ready to be read or served.

    Raises:
        ValueError: If ``samples`` is not 3-D or its batch dimension is empty.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if samples.dim() != 3 or samples.shape[0] == 0:
        raise ValueError(
            "expected a non-empty (batch, channels, samples) tensor, "
            f"got shape {tuple(samples.shape)}"
        )
    first_item = samples[0].detach().cpu()
    # Write to a byte buffer instead of a file so the caller can play and
    # offer it for download without touching the filesystem.
    buffer = BytesIO()
    torchaudio.save(buffer, first_item, sample_rate=sample_rate, format="wav")
    buffer.seek(0)  # Rewind so readers start at the beginning of the WAV data.
    return buffer
# Browser-tab configuration for the app: tab title plus a music-note emoji icon.
st.set_page_config(page_title="Music Gen", page_icon=":musical_note:")
def main():
    """Render the MusicGen page and play/offer the generated audio.

    Streamlit re-executes this script on every widget interaction, including
    a click on the download button. The rendered WAV bytes are therefore kept
    in ``st.session_state`` keyed on (description, duration), so the model only
    runs again when the user actually changes the inputs.
    """
    st.title("Your Music")
    with st.expander("See Explanation"):
        st.write("App is developed using Meta's Audiocraft Music Gen model. Write a description and we will generate audio.")
    text_area = st.text_area("Enter description")
    time_slider = st.slider("Select time duration (seconds)", 2, 20, 10)

    # The slider always yields a value in [2, 20], so only the prompt needs
    # checking; a whitespace-only prompt is treated as no prompt.
    if not text_area.strip():
        return

    st.json({
        "Description": text_area,
        "Selected duration": time_slider
    })
    st.write("We will be back with your music....please enjoy doing the rest of your tasks while we come back in some time :)")
    st.subheader("Generated Music")

    request = (text_area, time_slider)
    if st.session_state.get("music_request") != request:
        # Generation is slow and stochastic; only run it for new inputs.
        with st.spinner("Generating music..."):
            music_tensors = generate_music_tensors(text_area, time_slider)
            wav_bytes = save_audio_to_bytes(music_tensors).getvalue()
        st.session_state["music_request"] = request
        st.session_state["music_wav"] = wav_bytes
    wav_bytes = st.session_state["music_wav"]

    # Play and download the same cached WAV bytes.
    st.audio(wav_bytes, format="audio/wav")
    st.download_button(
        label="Download Audio",
        data=wav_bytes,
        file_name="generated_music.wav",
        mime="audio/wav"
    )


if __name__ == "__main__":
    main()