annapurnapadmaprema-ji committed on
Commit
0bb0ef0
1 Parent(s): fe49032

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -27
app.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import torchaudio
6
  import numpy as np
7
  import base64
 
8
 
9
  @st.cache_resource
10
  def load_model():
@@ -16,9 +17,12 @@ def generate_music_tensors(description, duration: int):
16
  print("Duration:", duration)
17
  model = load_model()
18
 
 
19
  model.set_generation_params(
20
  use_sampling=True,
21
- top_k=250,
 
 
22
  duration=duration
23
  )
24
 
@@ -29,27 +33,19 @@ def generate_music_tensors(description, duration: int):
29
  )
30
  return output[0]
31
 
32
- def save_audio(samples: torch.Tensor):
33
- sample_rate = 32000 # corrected to integer
34
- save_path = "audio_output/"
35
- os.makedirs(save_path, exist_ok=True) # ensure directory exists
36
-
37
  assert samples.dim() == 2 or samples.dim() == 3
38
  samples = samples.detach().cpu()
39
 
40
  if samples.dim() == 2:
41
  samples = samples[None, ...]
42
- for idx, audio in enumerate(samples):
43
- audio_path = os.path.join(save_path, f"audio_{idx}.wav")
44
- torchaudio.save(audio_path, audio, sample_rate)
45
- return os.path.join(save_path, "audio_0.wav")
46
-
47
- def get_binary_file_downloader_html(bin_file, file_label='File'):
48
- with open(bin_file, 'rb') as f:
49
- data = f.read()
50
- bin_str = base64.b64encode(data).decode()
51
- href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{file_label}">Download {file_label} from here</a>'
52
- return href
53
 
54
  st.set_page_config(
55
  page_icon=":musical_note:",
@@ -60,25 +56,33 @@ def main():
60
  st.title("Your Music")
61
 
62
  with st.expander("See Explanation"):
63
- st.write("App is developed by using Meta's Audiocraft Music Gen model. Write your text and we will generate audio")
64
-
65
  text_area = st.text_area("Enter description")
66
- time_slider = st.slider("Select time duration(s)", 2, 5, 20)
67
-
68
  if text_area and time_slider:
69
  st.json(
70
  {
71
  "Description": text_area,
72
  "Selected duration": time_slider
73
  }
 
74
  )
75
  st.subheader("Generated Music")
76
  music_tensors = generate_music_tensors(text_area, time_slider)
77
- audio_file_path = save_audio(music_tensors)
78
- audio_file = open(audio_file_path, 'rb')
79
- audio_bytes = audio_file.read()
80
- st.audio(audio_bytes)
81
- st.markdown(get_binary_file_downloader_html(audio_file_path, 'Audio'), unsafe_allow_html=True)
 
 
 
 
 
 
 
82
 
83
  if __name__ == "__main__":
84
- main()
 
5
  import torchaudio
6
  import numpy as np
7
  import base64
8
+ from io import BytesIO
9
 
10
  @st.cache_resource
11
  def load_model():
 
17
  print("Duration:", duration)
18
  model = load_model()
19
 
20
+ # Experiment with different generation parameters for improved quality
21
  model.set_generation_params(
22
  use_sampling=True,
23
+ top_k=300, # Increase top_k for more diversity
24
+ top_p=0.85, # Probability threshold for token sampling
25
+ temperature=0.8, # Control randomness; lower values = more focused output
26
  duration=duration
27
  )
28
 
 
33
  )
34
  return output[0]
35
 
36
def save_audio_to_bytes(samples: torch.Tensor):
    """Encode generated audio as WAV into an in-memory buffer.

    Args:
        samples: A 2-D or 3-D tensor of audio samples. A 2-D tensor is
            written as-is (torchaudio expects ``(channels, time)``). A 3-D
            tensor is treated as a batch with the clip index on the first
            axis, and only the first clip is encoded.
            NOTE(review): the axis layout is assumed from torchaudio's
            contract -- confirm against the model output.

    Returns:
        A ``BytesIO`` holding the WAV data, rewound to position 0 so it
        can be read immediately.

    Raises:
        AssertionError: If ``samples`` is neither 2-D nor 3-D.
    """
    sample_rate = 32000  # fixed output rate used by this app
    assert samples.dim() == 2 or samples.dim() == 3
    samples = samples.detach().cpu()

    if samples.dim() == 3:
        # torchaudio.save only accepts a 2-D tensor; handing it the batched
        # 3-D tensor fails, so reduce the batch to its first clip.
        samples = samples[0]

    # Save audio to a byte buffer instead of a file for easier download.
    byte_io = BytesIO()
    torchaudio.save(byte_io, samples, sample_rate=sample_rate, format="wav")
    byte_io.seek(0)  # Reset buffer position to the beginning for reading
    return byte_io
 
 
 
 
 
49
 
50
  st.set_page_config(
51
  page_icon=":musical_note:",
 
56
  st.title("Your Music")
57
 
58
  with st.expander("See Explanation"):
59
+ st.write("App is developed using Meta's Audiocraft Music Gen model. Write a description and we will generate audio.")
60
+
61
  text_area = st.text_area("Enter description")
62
+ time_slider = st.slider("Select time duration (seconds)", 2, 20, 10)
63
+
64
  if text_area and time_slider:
65
  st.json(
66
  {
67
  "Description": text_area,
68
  "Selected duration": time_slider
69
  }
70
+ st.write("We will back with your music....please enjoy doing the rest of your tasks while we come back in some time :)")
71
  )
72
  st.subheader("Generated Music")
73
  music_tensors = generate_music_tensors(text_area, time_slider)
74
+
75
+ # Save to byte buffer for download
76
+ audio_file = save_audio_to_bytes(music_tensors)
77
+
78
+ # Play and download audio
79
+ st.audio(audio_file, format="audio/wav")
80
+ st.download_button(
81
+ label="Download Audio",
82
+ data=audio_file,
83
+ file_name="generated_music.wav",
84
+ mime="audio/wav"
85
+ )
86
 
87
  if __name__ == "__main__":
88
+ main()