import base64
import os
from io import BytesIO

import numpy as np
import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen


@st.cache_resource
def load_model():
    """Load the pretrained ``facebook/musicgen-small`` model.

    Decorated with ``st.cache_resource`` so Streamlit keeps a single model
    instance across reruns instead of reloading the weights every time.
    """
    model = MusicGen.get_pretrained("facebook/musicgen-small")
    return model


def generate_music_tensors(description, duration: int):
    """Generate audio for a text description.

    Args:
        description: Free-text prompt describing the desired music.
        duration: Length of the clip to generate, in seconds.

    Returns:
        The first element of ``model.generate(...)``'s result. Because
        ``return_tokens=True`` is passed, ``generate`` returns an
        ``(audio, tokens)`` pair, so this is the batched audio tensor
        (documented by audiocraft as ``(batch, channels, time)``).
    """
    print("Description:", description)
    print("Duration:", duration)

    model = load_model()

    # Sampling parameters chosen to trade diversity against coherence.
    model.set_generation_params(
        use_sampling=True,
        top_k=300,  # Increase top_k for more diversity
        top_p=0.85,  # Probability threshold for token sampling
        temperature=0.8,  # Control randomness; lower values = more focused output
        duration=duration,
    )

    output = model.generate(
        descriptions=[description],
        progress=True,
        return_tokens=True,
    )

    return output[0]


def save_audio_to_bytes(samples: torch.Tensor, sample_rate: int = 32000) -> BytesIO:
    """Encode a waveform tensor as an in-memory WAV file.

    Args:
        samples: Audio tensor shaped ``(channels, time)`` or
            ``(batch, channels, time)``. For a batched tensor only the first
            item is encoded, because a WAV file holds a single clip.
        sample_rate: Sample rate written to the WAV header. Defaults to
            32000, the value this app previously hard-coded.

    Returns:
        A ``BytesIO`` positioned at offset 0 that contains the WAV data.

    Raises:
        ValueError: If ``samples`` is neither 2-D nor 3-D.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if samples.dim() not in (2, 3):
        raise ValueError(
            f"expected a 2-D or 3-D audio tensor, got shape {tuple(samples.shape)}"
        )

    samples = samples.detach().cpu()
    if samples.dim() == 3:
        # torchaudio.save only accepts (channels, time); take the first batch item.
        samples = samples[0]

    # Write to a memory buffer rather than a file so it can be served directly.
    byte_io = BytesIO()
    torchaudio.save(byte_io, samples, sample_rate=sample_rate, format="wav")
    byte_io.seek(0)  # Rewind so readers start from the beginning.
    return byte_io


st.set_page_config(
    page_icon=":musical_note:",
    page_title="Music Gen",
)


def main():
    """Render the app: collect a prompt and duration, then play and offer the audio."""
    st.title("Your Music")

    with st.expander("See Explanation"):
        st.write(
            "App is developed using Meta's Audiocraft Music Gen model. "
            "Write a description and we will generate audio."
        )

    text_area = st.text_area("Enter description")
    time_slider = st.slider("Select time duration (seconds)", 2, 20, 10)

    if not (text_area and time_slider):
        return

    st.json(
        {
            "Description": text_area,
            "Selected duration": time_slider,
        }
    )
    st.write(
        "We will be back with your music....please enjoy doing the rest of "
        "your tasks while we come back in some time :)"
    )

    st.subheader("Generated Music")

    # Streamlit reruns the whole script on every interaction, including the
    # download button click. Without caching, each rerun would sample a NEW
    # clip, and the player would no longer match what was downloaded. Reuse
    # the encoded audio while the inputs are unchanged.
    request_key = (text_area, time_slider)
    cached = st.session_state.get("generated_audio")
    if cached is not None and cached[0] == request_key:
        audio_bytes = cached[1]
    else:
        with st.spinner("Generating music..."):
            music_tensors = generate_music_tensors(text_area, time_slider)
        # Take the sample rate from the model instead of assuming 32 kHz.
        audio_buffer = save_audio_to_bytes(
            music_tensors, sample_rate=load_model().sample_rate
        )
        # Immutable bytes are safe to hand to several widgets; a shared
        # BytesIO has a read position that consumers could disturb.
        audio_bytes = audio_buffer.getvalue()
        st.session_state["generated_audio"] = (request_key, audio_bytes)

    # Play and download audio.
    st.audio(audio_bytes, format="audio/wav")
    st.download_button(
        label="Download Audio",
        data=audio_bytes,
        file_name="generated_music.wav",
        mime="audio/wav",
    )


if __name__ == "__main__":
    main()