asigalov61 commited on
Commit
8c8ea80
1 Parent(s): 56ab42f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -46
app.py CHANGED
@@ -84,7 +84,7 @@ def create_msg(name, data):
84
  return {"name": name, "data": data}
85
 
86
 
87
- def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
88
  mid_seq = []
89
  gen_events = int(gen_events)
90
  max_len = gen_events
@@ -92,55 +92,32 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
92
  disable_patch_change = False
93
  disable_channels = None
94
  if tab == 0:
95
- i = 0
96
- mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
97
- patches = {}
98
- for instr in instruments:
99
- patches[i] = patch2number[instr]
100
- i = (i + 1) if i != 8 else 10
101
- if drum_kit != "None":
102
- patches[9] = drum_kits2number[drum_kit]
103
- for i, (c, p) in enumerate(patches.items()):
104
- mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
105
- mid_seq = mid
106
- mid = np.asarray(mid, dtype=np.int64)
107
- if len(instruments) > 0:
108
- disable_patch_change = True
109
- disable_channels = [i for i in range(16) if i not in patches]
110
  elif mid is not None:
111
- mid = tokenizer.tokenize(MIDI.midi2score(mid))
112
- mid = np.asarray(mid, dtype=np.int64)
113
- mid = mid[:int(midi_events)]
114
- max_len += len(mid)
115
- for token_seq in mid:
116
- mid_seq.append(token_seq.tolist())
117
  init_msgs = [create_msg("visualizer_clear", None)]
118
  for tokens in mid_seq:
119
- init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
120
  yield mid_seq, None, None, init_msgs
121
- model = models[model_name]
122
- generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
123
- disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
124
- disable_channels=disable_channels)
125
- for i, token_seq in enumerate(generator):
126
- token_seq = token_seq.tolist()
127
- mid_seq.append(token_seq)
128
- event = tokenizer.tokens2event(token_seq)
129
- yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
130
- mid = tokenizer.detokenize(mid_seq)
131
  with open(f"output.mid", 'wb') as f:
132
- f.write(MIDI.score2midi(mid))
133
- audio = synthesis(MIDI.score2opus(mid), soundfont_path)
134
  yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
135
 
136
 
137
  def cancel_run(mid_seq):
138
  if mid_seq is None:
139
  return None, None
140
- mid = tokenizer.detokenize(mid_seq)
141
  with open(f"output.mid", 'wb') as f:
142
- f.write(MIDI.score2midi(mid))
143
- audio = synthesis(MIDI.score2opus(mid), soundfont_path)
144
  return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
145
 
146
 
@@ -174,11 +151,6 @@ class JSMsgReceiver(gr.HTML):
174
  def get_block_name(self) -> str:
175
  return "html"
176
 
177
- number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
178
- 40: "Blush", 48: "Orchestra"}
179
- patch2number = {v: k for k, v in MIDI.Number2patch.items()}
180
- drum_kits2number = {v: k for k, v in number2drum_kits.items()}
181
-
182
  if __name__ == "__main__":
183
  parser = argparse.ArgumentParser()
184
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
@@ -233,9 +205,7 @@ if __name__ == "__main__":
233
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
234
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
235
  output_midi = gr.File(label="output midi", file_types=[".mid"])
236
- run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
237
- input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
238
- input_allow_cc],
239
  [output_midi_seq, output_midi, output_audio, js_msg])
240
  stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
241
  app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
84
  return {"name": name, "data": data}
85
 
86
 
87
+ def run(search_prompt):
88
  mid_seq = []
89
  gen_events = int(gen_events)
90
  max_len = gen_events
 
92
  disable_patch_change = False
93
  disable_channels = None
94
  if tab == 0:
95
+ mid_seq = []
96
+
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  elif mid is not None:
98
+ mid_seq = MIDI.midi2score(mid)
99
+
 
 
 
 
100
  init_msgs = [create_msg("visualizer_clear", None)]
101
  for tokens in mid_seq:
102
+ init_msgs.append(create_msg("visualizer_append", tokens))
103
  yield mid_seq, None, None, init_msgs
104
+
105
+ for i in range(len(mid_seq)):
106
+ yield mid_seq, None, None, [create_msg("visualizer_append", mid_seq[i]), create_msg("progress", [i + 1, mid_seq[i+1]])]
107
+
 
 
 
 
 
 
108
  with open(f"output.mid", 'wb') as f:
109
+ f.write(MIDI.score2midi(mid_seq))
110
+ audio = synthesis(MIDI.score2opus(mid_seq), soundfont_path)
111
  yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
112
 
113
 
114
  def cancel_run(mid_seq):
115
  if mid_seq is None:
116
  return None, None
117
+
118
  with open(f"output.mid", 'wb') as f:
119
+ f.write(MIDI.score2midi(mid_seq))
120
+ audio = synthesis(MIDI.score2opus(mid_seq), soundfont_path)
121
  return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
122
 
123
 
 
151
  def get_block_name(self) -> str:
152
  return "html"
153
 
 
 
 
 
 
154
  if __name__ == "__main__":
155
  parser = argparse.ArgumentParser()
156
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
 
205
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
206
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
207
  output_midi = gr.File(label="output midi", file_types=[".mid"])
208
+ run_event = search_btn.click(run, [search_prompt],
 
 
209
  [output_midi_seq, output_midi, output_audio, js_msg])
210
  stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
211
  app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)