cleanup
Browse files- generate.py +15 -98
- generation_utils.py +5 -4
generate.py
CHANGED
@@ -190,7 +190,6 @@ class GenerateMidiText:
|
|
190 |
verbose=True,
|
191 |
expected_length=None,
|
192 |
):
|
193 |
-
|
194 |
"""generate until the TRACK_END token is reached
|
195 |
full_piece = input_prompt + generated"""
|
196 |
if expected_length is None:
|
@@ -270,14 +269,17 @@ class GenerateMidiText:
|
|
270 |
|
271 |
def generate_piece(self, instrument_list, density_list, temperature_list):
|
272 |
"""generate a sequence with mutiple tracks
|
273 |
-
- inst_list sets the list of instruments of the order of generation
|
274 |
-
- density is paired with inst_list
|
275 |
-
Each track/intrument is generated on a prompt which contains the previously generated track/instrument
|
276 |
-
This means that the first instrument is generated with less bias than the next one, and so on.
|
277 |
|
278 |
-
|
279 |
-
|
280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
generated_piece = "PIECE_START "
|
283 |
for instrument, density, temperature in zip(
|
@@ -321,7 +323,9 @@ class GenerateMidiText:
|
|
321 |
for bar in track["bars"][-self.model_n_bar :]:
|
322 |
pre_promt += bar
|
323 |
pre_promt += "TRACK_END "
|
324 |
-
elif
|
|
|
|
|
325 |
# adding an empty bars at the end of the other tracks if they have not been processed yet
|
326 |
pre_promt += othertracks["bars"][0]
|
327 |
for bar in track["bars"][-(self.model_n_bar - 1) :]:
|
@@ -378,7 +382,7 @@ class GenerateMidiText:
|
|
378 |
def check_the_piece_for_errors(self, piece: str = None):
|
379 |
|
380 |
if piece is None:
|
381 |
-
piece =
|
382 |
errors = []
|
383 |
errors.append(
|
384 |
[
|
@@ -396,91 +400,4 @@ class GenerateMidiText:
|
|
396 |
|
397 |
|
398 |
if __name__ == "__main__":
|
399 |
-
|
400 |
-
# worker
|
401 |
-
DEVICE = "cpu"
|
402 |
-
|
403 |
-
# define generation parameters
|
404 |
-
N_FILES_TO_GENERATE = 2
|
405 |
-
Temperatures_to_try = [0.7]
|
406 |
-
|
407 |
-
USE_FAMILIZED_MODEL = True
|
408 |
-
force_sequence_length = True
|
409 |
-
|
410 |
-
if USE_FAMILIZED_MODEL:
|
411 |
-
# model_repo = "misnaej/the-jam-machine-elec-famil"
|
412 |
-
# model_repo = "misnaej/the-jam-machine-elec-famil-ft32"
|
413 |
-
|
414 |
-
# model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
|
415 |
-
# n_bar_generated = 8
|
416 |
-
|
417 |
-
model_repo = "JammyMachina/improved_4bars-mdl"
|
418 |
-
n_bar_generated = 4
|
419 |
-
instrument_promt_list = ["4", "DRUMS", "3"]
|
420 |
-
# DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
|
421 |
-
density_list = [3, 2, 2]
|
422 |
-
# temperature_list = [0.7, 0.7, 0.75]
|
423 |
-
else:
|
424 |
-
model_repo = "misnaej/the-jam-machine"
|
425 |
-
instrument_promt_list = ["30"] # , "DRUMS", "0"]
|
426 |
-
density_list = [3] # , 2, 3]
|
427 |
-
# temperature_list = [0.7, 0.5, 0.75]
|
428 |
-
pass
|
429 |
-
|
430 |
-
# define generation directory
|
431 |
-
generated_sequence_files_path = define_generation_dir(model_repo)
|
432 |
-
|
433 |
-
# load model and tokenizer
|
434 |
-
model, tokenizer = LoadModel(
|
435 |
-
model_repo, from_huggingface=True
|
436 |
-
).load_model_and_tokenizer()
|
437 |
-
|
438 |
-
# does the prompt make sense
|
439 |
-
check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)
|
440 |
-
|
441 |
-
for temperature in Temperatures_to_try:
|
442 |
-
print(f"================= TEMPERATURE {temperature} =======================")
|
443 |
-
for _ in range(N_FILES_TO_GENERATE):
|
444 |
-
print(f"========================================")
|
445 |
-
# 1 - instantiate
|
446 |
-
generate_midi = GenerateMidiText(model, tokenizer)
|
447 |
-
# 0 - set the n_bar for this model
|
448 |
-
generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
|
449 |
-
# 1 - defines the instruments, densities and temperatures
|
450 |
-
# 2- generate the first 8 bars for each instrument
|
451 |
-
generate_midi.set_improvisation_level(30)
|
452 |
-
generate_midi.generate_piece(
|
453 |
-
instrument_promt_list,
|
454 |
-
density_list,
|
455 |
-
[temperature for _ in density_list],
|
456 |
-
)
|
457 |
-
# 3 - force the model to improvise
|
458 |
-
# generate_midi.set_improvisation_level(20)
|
459 |
-
# 4 - generate the next 4 bars for each instrument
|
460 |
-
# generate_midi.generate_n_more_bars(n_bar_generated)
|
461 |
-
# 5 - lower the improvisation level
|
462 |
-
generate_midi.generated_piece = (
|
463 |
-
generate_midi.get_whole_piece_from_bar_dict()
|
464 |
-
)
|
465 |
-
|
466 |
-
# print the generated sequence in terminal
|
467 |
-
print("=========================================")
|
468 |
-
print(generate_midi.generated_piece)
|
469 |
-
print("=========================================")
|
470 |
-
|
471 |
-
# write to JSON file
|
472 |
-
filename = WriteTextMidiToFile(
|
473 |
-
generate_midi,
|
474 |
-
generated_sequence_files_path,
|
475 |
-
).text_midi_to_file()
|
476 |
-
|
477 |
-
# decode the sequence to MIDI """
|
478 |
-
decode_tokenizer = get_miditok()
|
479 |
-
TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
|
480 |
-
generate_midi.generated_piece, filename=filename.split(".")[0] + ".mid"
|
481 |
-
)
|
482 |
-
inst_midi, mixed_audio = get_music(filename.split(".")[0] + ".mid")
|
483 |
-
max_time = get_max_time(inst_midi)
|
484 |
-
plot_piano_roll(inst_midi)
|
485 |
-
|
486 |
-
print("Et voilà! Your MIDI file is ready! GO JAM!")
|
|
|
190 |
verbose=True,
|
191 |
expected_length=None,
|
192 |
):
|
|
|
193 |
"""generate until the TRACK_END token is reached
|
194 |
full_piece = input_prompt + generated"""
|
195 |
if expected_length is None:
|
|
|
269 |
|
270 |
def generate_piece(self, instrument_list, density_list, temperature_list):
|
271 |
"""generate a sequence with mutiple tracks
|
|
|
|
|
|
|
|
|
272 |
|
273 |
+
Args:
|
274 |
+
- inst_list sets the list of instruments and the the order of generation
|
275 |
+
- density and
|
276 |
+
- temperature are paired with inst_list
|
277 |
+
|
278 |
+
Each track/intrument is generated based on a prompt which contains the previously generated track/instrument
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
'generated_piece' which keeps track of the entire piece
|
282 |
+
"""
|
283 |
|
284 |
generated_piece = "PIECE_START "
|
285 |
for instrument, density, temperature in zip(
|
|
|
323 |
for bar in track["bars"][-self.model_n_bar :]:
|
324 |
pre_promt += bar
|
325 |
pre_promt += "TRACK_END "
|
326 |
+
elif (
|
327 |
+
False
|
328 |
+
): # len_diff <= 0: # THIS DOES NOT WORK - It just fills things with empty bars
|
329 |
# adding an empty bars at the end of the other tracks if they have not been processed yet
|
330 |
pre_promt += othertracks["bars"][0]
|
331 |
for bar in track["bars"][-(self.model_n_bar - 1) :]:
|
|
|
382 |
def check_the_piece_for_errors(self, piece: str = None):
|
383 |
|
384 |
if piece is None:
|
385 |
+
piece = self.get_whole_piece_from_bar_dict()
|
386 |
errors = []
|
387 |
errors.append(
|
388 |
[
|
|
|
400 |
|
401 |
|
402 |
if __name__ == "__main__":
|
403 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generation_utils.py
CHANGED
@@ -15,10 +15,6 @@ matplotlib.rcParams["axes.edgecolor"] = "grey"
|
|
15 |
|
16 |
|
17 |
def define_generation_dir(model_repo_path):
|
18 |
-
#### to remove later ####
|
19 |
-
if model_repo_path == "models/model_2048_fake_wholedataset":
|
20 |
-
model_repo_path = "misnaej/the-jam-machine"
|
21 |
-
#### to remove later ####
|
22 |
generated_sequence_files_path = f"midi/generated/{model_repo_path}"
|
23 |
if not os.path.exists(generated_sequence_files_path):
|
24 |
os.makedirs(generated_sequence_files_path)
|
@@ -61,6 +57,11 @@ def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
|
|
61 |
)
|
62 |
|
63 |
|
|
|
|
|
|
|
|
|
|
|
64 |
def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
|
65 |
"""Forcing the generated sequence to have the expected length
|
66 |
expected_length and bar_count refers to the length of newly_generated_only (without input prompt)"""
|
|
|
15 |
|
16 |
|
17 |
def define_generation_dir(model_repo_path):
|
|
|
|
|
|
|
|
|
18 |
generated_sequence_files_path = f"midi/generated/{model_repo_path}"
|
19 |
if not os.path.exists(generated_sequence_files_path):
|
20 |
os.makedirs(generated_sequence_files_path)
|
|
|
57 |
)
|
58 |
|
59 |
|
60 |
+
# TODO
|
61 |
+
def check_if_prompt_density_in_tokenizer_vocab(tokenizer, density_prompt_list):
|
62 |
+
pass
|
63 |
+
|
64 |
+
|
65 |
def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
|
66 |
"""Forcing the generated sequence to have the expected length
|
67 |
expected_length and bar_count refers to the length of newly_generated_only (without input prompt)"""
|