# Music_Generator / app.py
# annapurnapadmaprema-ji's picture
# Update app.py
# 0bb0ef0 verified
# raw
# history blame
# 2.7 kB
from audiocraft.models import MusicGen
import streamlit as st
import os
import torch
import torchaudio
import numpy as np
import base64
from io import BytesIO
@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    Wrapped in ``st.cache_resource`` so Streamlit reuses one loaded model
    across script reruns instead of reloading the checkpoint every time.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
def generate_music_tensors(description, duration: int, *, top_k=300, top_p=0.85, temperature=0.8):
    """Generate audio for a single text prompt with the cached MusicGen model.

    Args:
        description: Text prompt describing the desired music.
        duration: Length of the generated clip, in seconds.
        top_k: Top-k sampling cutoff passed to ``set_generation_params``.
        top_p: Nucleus (top-p) sampling threshold. NOTE: in audiocraft's
            sampler a positive ``top_p`` takes precedence over ``top_k``, so
            with the defaults ``top_k`` has no effect.
        temperature: Softmax temperature; lower values give more focused,
            less random output.

    Returns:
        The waveform batch produced by ``model.generate`` for the single
        prompt. It is presumably shaped (1, channels, samples); TODO confirm
        against the installed audiocraft version.
    """
    print("Description:", description)
    print("Duration:", duration)
    model = load_model()
    model.set_generation_params(
        use_sampling=True,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        duration=duration,
    )
    # Only the waveform is used, so tokens are not requested. Previously
    # return_tokens=True built a (wav, tokens) tuple just to keep element 0.
    return model.generate(descriptions=[description], progress=True)
def save_audio_to_bytes(samples: torch.Tensor, sample_rate: int = 32000):
    """Encode a generated waveform as WAV into an in-memory buffer.

    Args:
        samples: Either one clip shaped (channels, frames) or a batch shaped
            (batch, channels, frames). For a batch only the first clip is
            encoded, because the app generates one prompt at a time.
        sample_rate: Rate in Hz written to the WAV header. The default 32000
            is the value this app has always used and is assumed to match
            MusicGen's output rate; pass ``model.sample_rate`` to be certain.

    Returns:
        A ``BytesIO`` holding the WAV data, rewound to offset 0 so it can be
        read or streamed immediately.

    Raises:
        ValueError: If ``samples`` is neither 2-D nor 3-D.
    """
    if samples.dim() not in (2, 3):
        raise ValueError(f"Expected a 2-D or 3-D audio tensor, got {samples.dim()}-D")
    samples = samples.detach().cpu()
    # torchaudio.save only accepts a 2-D (channels, frames) tensor. The old
    # code promoted 2-D input to 3-D, which torchaudio rejects. Drop the
    # batch axis instead.
    if samples.dim() == 3:
        samples = samples[0]
    # Write to memory rather than disk so the result can be played and
    # downloaded without temp files.
    byte_io = BytesIO()
    torchaudio.save(byte_io, samples, sample_rate=sample_rate, format="wav")
    byte_io.seek(0)  # rewind so consumers read from the start
    return byte_io
# Configure the browser tab (title and icon) once, at import time, before
# main() renders anything.
st.set_page_config(page_title="Music Gen", page_icon=":musical_note:")
def main():
    """Render the MusicGen UI and generate audio for the entered description.

    Streamlit re-executes this script on every widget interaction, including
    a click on the download button. The encoded WAV bytes are therefore kept
    in ``st.session_state`` keyed by (description, duration), so unrelated
    reruns neither re-run the model nor replace the audio the user is
    listening to.
    """
    st.title("Your Music")
    with st.expander("See Explanation"):
        st.write("App is developed using Meta's Audiocraft Music Gen model. Write a description and we will generate audio.")
    text_area = st.text_area("Enter description")
    time_slider = st.slider("Select time duration (seconds)", 2, 20, 10)
    # Nothing to do until the user has typed a non-blank prompt. The slider
    # always has a value of at least 2, so it needs no check.
    if not text_area or not text_area.strip():
        return
    st.json(
        {
            "Description": text_area,
            "Selected duration": time_slider
        }
    )
    st.write("We will be back with your music....please enjoy doing the rest of your tasks while we come back in some time :)")
    st.subheader("Generated Music")
    request = (text_area, time_slider)
    cached = st.session_state.get("generated_audio")
    if cached is not None and cached[0] == request:
        audio_bytes = cached[1]
    else:
        with st.spinner("Generating music..."):
            music_tensors = generate_music_tensors(text_area, time_slider)
            # Keep raw bytes so both widgets below get identical, reusable data.
            audio_bytes = save_audio_to_bytes(music_tensors).getvalue()
        st.session_state["generated_audio"] = (request, audio_bytes)
    # Play and download audio
    st.audio(audio_bytes, format="audio/wav")
    st.download_button(
        label="Download Audio",
        data=audio_bytes,
        file_name="generated_music.wav",
        mime="audio/wav"
    )


if __name__ == "__main__":
    main()