annapurnapadmaprema-ji commited on
Commit
bc62b2b
1 Parent(s): d400b65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -36
app.py CHANGED
@@ -1,37 +1,81 @@
1
- from audiocraft.models import MusicGen
2
- import streamlit as st
3
- import os
4
- import torch
5
- import torchaudio
6
- import numpy as np
7
- import base64
8
-
9
- @st.cache_resource
10
- def load_model():
11
- model=MusicGen.get_pretrained("facebook/musicgen-small")
12
- return model
13
-
14
- st.set_page_config(
15
- page_icon=":musical_note:",
16
- page_title="Music Gen"
17
- )
18
-
19
- def main():
20
- st.title("Your Music")
21
-
22
- with st.expander("See Explanation"):
23
- st.write("App is developed by using Meta's Audiocraft Music Gen model. Write your text and we will generate audio")
24
- text_area=st.text_area("Enter description")
25
- time_slider=st.slider("Select time duration(s)",2,5,20)
26
-
27
- if text_area and time_slider:
28
- st.json(
29
- {
30
- "Description":text_area,
31
- "Selected duration:":time_slider
32
- }
33
- )
34
- st.subheader("Generated Music")
35
-
36
- if __name__=="__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  main()
 
1
+ from audiocraft.models import MusicGen
2
+ import streamlit as st
3
+ import os
4
+ import torch
5
+ import torchaudio
6
+ import numpy as np
7
+ import base64
8
+
9
+ @st.cache_resource
10
+ def load_model():
11
+ model=MusicGen.get_pretrained("facebook/musicgen-small")
12
+ return model
13
+
14
+ def generate_music_tensors(description,duration:int):
15
+ print("Description:",description)
16
+ print("Duration:",duration)
17
+ model=load_model()
18
+
19
+ model.set_generation_params(
20
+ use_sampling=True,
21
+ top_k=250,
22
+ duration=duration
23
+ )
24
+
25
+ output=model.generate(
26
+ descriptions=[description],
27
+ progress=True,
28
+ return_tokens=True
29
+ )
30
+ return output[0]
31
+
32
+ def save_audio(samples:torch.tensor):
33
+ sample_rate=32000,
34
+ save_path="audio_output/"
35
+
36
+ assert samples.dim()==2 or samples.dim()==3
37
+ samples=samples.detach().cpu()
38
+
39
+ if samples.dim()==2:
40
+ samples=samples[None,...]
41
+ for idx,audio in enumerate(samples):
42
+ audio_path=os.path.join(save_path,f"audio_{idx}.wav")
43
+ torchaudio.save(audio_path,audio,sample_rate)
44
+
45
+ def get_binary_file_downloader_html(bin_file,file_label='File'):
46
+ with open(bin_file,'rb') as f:
47
+ data=f.read()
48
+ bin_str=base64.b64encode(data).decode()
49
+ href=f'<a href="data:application/octet-stream;base64,{bin_str} download {(bin_file)}">Download {file_label} from here</a>'
50
+ return href
51
+
52
+ st.set_page_config(
53
+ page_icon=":musical_note:",
54
+ page_title="Music Gen"
55
+ )
56
+
57
+ def main():
58
+ st.title("Your Music")
59
+
60
+ with st.expander("See Explanation"):
61
+ st.write("App is developed by using Meta's Audiocraft Music Gen model. Write your text and we will generate audio")
62
+ text_area=st.text_area("Enter description")
63
+ time_slider=st.slider("Select time duration(s)",2,5,20)
64
+
65
+ if text_area and time_slider:
66
+ st.json(
67
+ {
68
+ "Description":text_area,
69
+ "Selected duration:":time_slider
70
+ }
71
+ )
72
+ st.subheader("Generated Music")
73
+ music_tensors=generate_music_tensors(text_area,time_slider)
74
+ save_music_file=save_audio(music_tensors)
75
+ audio_file_path='audio_output/audio_0.wav'
76
+ audio_file=open(audio_file_path,'rb')
77
+ audio_bytes=audio_file.read()
78
+ st.audio(audio_bytes)
79
+ st.markdown(get_binary_file_downloader_html,audio_file_path,'Audio',unsafe_allow_html=True)
80
+ if __name__=="__main__":
81
  main()