File size: 2,408 Bytes
bc62b2b
 
0b3a9f2
bc62b2b
 
0bb0ef0
bc62b2b
 
 
fe49032
bc62b2b
 
fe49032
 
 
 
bc62b2b
 
 
0b3a9f2
bc62b2b
 
 
fe49032
bc62b2b
 
 
 
 
 
0bb0ef0
 
0b3a9f2
fe49032
bc62b2b
0b3a9f2
 
 
 
 
 
 
bc62b2b
 
 
 
 
 
 
 
 
 
0b3a9f2
0bb0ef0
fe49032
0b3a9f2
0bb0ef0
bc62b2b
0b3a9f2
 
 
 
 
 
 
bc62b2b
fe49032
0b3a9f2
 
 
 
 
 
 
 
0bb0ef0
 
0b3a9f2
0bb0ef0
 
 
fe49032
 
0bb0ef0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from audiocraft.models import MusicGen
import streamlit as st
import os
import torch
import torchaudio
from io import BytesIO

@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The function is decorated with ``st.cache_resource`` so Streamlit can
    reuse the loaded model rather than fetching it again on each call.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")

def generate_music_tensors(description, duration: int):
    """Generate audio for a text prompt with the shared MusicGen model.

    Args:
        description: Text prompt handed to the model as the only entry of
            the ``descriptions`` list.
        duration: Value forwarded as the ``duration`` generation parameter
            (the UI labels it as seconds).

    Returns:
        The first element of whatever ``model.generate`` returns when called
        with ``return_tokens=True`` (presumably the audio tensor -- confirm
        against the Audiocraft API).
    """
    # Echo the request to stdout for debugging.
    print("Description:", description)
    print("Duration:", duration)

    musicgen = load_model()

    generation_params = {
        "use_sampling": True,
        "top_k": 250,
        "duration": duration,
    }
    musicgen.set_generation_params(**generation_params)

    prompts = [description]
    generated = musicgen.generate(
        descriptions=prompts,
        progress=True,
        return_tokens=True,
    )

    # Only the first item of the result is used by the caller.
    first_item = generated[0]
    return first_item

def save_audio_to_bytes(samples: torch.Tensor, sample_rate: int = 32000) -> BytesIO:
    """Encode an audio tensor as WAV into an in-memory buffer.

    Args:
        samples: A 2-D or 3-D tensor. For a 3-D tensor only the first entry
            along dimension 0 is written (presumably the first item of a
            batch -- confirm against the generator's output layout).
        sample_rate: Sample rate passed to ``torchaudio.save``. Defaults to
            32000, the value this function previously hard-coded.

    Returns:
        A ``BytesIO`` holding the WAV data, rewound to position 0 so it can
        be read immediately.

    Raises:
        ValueError: If ``samples`` is neither 2-D nor 3-D.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    n_dims = samples.dim()
    if n_dims not in (2, 3):
        raise ValueError(
            f"samples must be a 2-D or 3-D tensor, got {n_dims} dimension(s)"
        )

    # Detach from autograd and move to CPU before handing off to torchaudio.
    samples = samples.detach().cpu()

    # A 2-D input is written as-is; a 3-D input contributes its first entry.
    clip = samples if n_dims == 2 else samples[0]

    audio_buffer = BytesIO()
    torchaudio.save(audio_buffer, clip, sample_rate=sample_rate, format="wav")
    audio_buffer.seek(0)  # Rewind so consumers read from the beginning.
    return audio_buffer

# Page-level settings: browser tab title and icon.
st.set_page_config(page_title="Music Gen", page_icon=":musical_note:")

def main():
    """Render the Streamlit UI and generate music from the user's prompt.

    Shows a description text area and a duration slider. Once a description
    has been entered, the request is echoed back as JSON, audio is generated
    via ``generate_music_tensors``, encoded to WAV via
    ``save_audio_to_bytes``, and offered for playback and download.
    """
    st.title("Your Music")

    with st.expander("See Explanation"):
        st.write("This app uses Meta's Audiocraft Music Gen model to generate audio based on your description.")

    text_area = st.text_area("Enter description")
    time_slider = st.slider("Select time duration (seconds)", 2, 20, 5)

    if text_area and time_slider:
        # Echo the request back to the user. The st.json(...) call must be
        # closed before the next statement; previously st.write(...) sat
        # inside the open parenthesis, which was a syntax error.
        st.json(
            {
                "Description": text_area,
                "Selected duration": time_slider,
            }
        )
        st.write("We will back with your music....please enjoy doing the rest of your tasks while we come back in some time :)")

        st.subheader("Generated Music")
        music_tensors = generate_music_tensors(text_area, time_slider)

        # Convert audio to bytes for playback and download
        audio_buffer = save_audio_to_bytes(music_tensors)

        # Play audio
        st.audio(audio_buffer, format="audio/wav")

        # Download button for audio
        st.download_button(
            label="Download Audio",
            data=audio_buffer,
            file_name="generated_music.wav",
            mime="audio/wav",
        )

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()