misnaej commited on
Commit
e996b6d
1 Parent(s): fc2f33e
Files changed (2) hide show
  1. generate.py +15 -98
  2. generation_utils.py +5 -4
generate.py CHANGED
@@ -190,7 +190,6 @@ class GenerateMidiText:
190
  verbose=True,
191
  expected_length=None,
192
  ):
193
-
194
  """generate until the TRACK_END token is reached
195
  full_piece = input_prompt + generated"""
196
  if expected_length is None:
@@ -270,14 +269,17 @@ class GenerateMidiText:
270
 
271
  def generate_piece(self, instrument_list, density_list, temperature_list):
272
  """generate a sequence with mutiple tracks
273
- - inst_list sets the list of instruments of the order of generation
274
- - density is paired with inst_list
275
- Each track/intrument is generated on a prompt which contains the previously generated track/instrument
276
- This means that the first instrument is generated with less bias than the next one, and so on.
277
 
278
- 'generated_piece' keeps track of the entire piece
279
- 'generated_piece' is returned by self.generate_until_track_end
280
- # it is returned by self.generate_until_track_end"""
 
 
 
 
 
 
 
281
 
282
  generated_piece = "PIECE_START "
283
  for instrument, density, temperature in zip(
@@ -321,7 +323,9 @@ class GenerateMidiText:
321
  for bar in track["bars"][-self.model_n_bar :]:
322
  pre_promt += bar
323
  pre_promt += "TRACK_END "
324
- elif False: # len_diff <= 0: # THIS GENERATES EMPTINESS
 
 
325
  # adding an empty bars at the end of the other tracks if they have not been processed yet
326
  pre_promt += othertracks["bars"][0]
327
  for bar in track["bars"][-(self.model_n_bar - 1) :]:
@@ -378,7 +382,7 @@ class GenerateMidiText:
378
  def check_the_piece_for_errors(self, piece: str = None):
379
 
380
  if piece is None:
381
- piece = generate_midi.get_whole_piece_from_bar_dict()
382
  errors = []
383
  errors.append(
384
  [
@@ -396,91 +400,4 @@ class GenerateMidiText:
396
 
397
 
398
  if __name__ == "__main__":
399
-
400
- # worker
401
- DEVICE = "cpu"
402
-
403
- # define generation parameters
404
- N_FILES_TO_GENERATE = 2
405
- Temperatures_to_try = [0.7]
406
-
407
- USE_FAMILIZED_MODEL = True
408
- force_sequence_length = True
409
-
410
- if USE_FAMILIZED_MODEL:
411
- # model_repo = "misnaej/the-jam-machine-elec-famil"
412
- # model_repo = "misnaej/the-jam-machine-elec-famil-ft32"
413
-
414
- # model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
415
- # n_bar_generated = 8
416
-
417
- model_repo = "JammyMachina/improved_4bars-mdl"
418
- n_bar_generated = 4
419
- instrument_promt_list = ["4", "DRUMS", "3"]
420
- # DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
421
- density_list = [3, 2, 2]
422
- # temperature_list = [0.7, 0.7, 0.75]
423
- else:
424
- model_repo = "misnaej/the-jam-machine"
425
- instrument_promt_list = ["30"] # , "DRUMS", "0"]
426
- density_list = [3] # , 2, 3]
427
- # temperature_list = [0.7, 0.5, 0.75]
428
- pass
429
-
430
- # define generation directory
431
- generated_sequence_files_path = define_generation_dir(model_repo)
432
-
433
- # load model and tokenizer
434
- model, tokenizer = LoadModel(
435
- model_repo, from_huggingface=True
436
- ).load_model_and_tokenizer()
437
-
438
- # does the prompt make sense
439
- check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)
440
-
441
- for temperature in Temperatures_to_try:
442
- print(f"================= TEMPERATURE {temperature} =======================")
443
- for _ in range(N_FILES_TO_GENERATE):
444
- print(f"========================================")
445
- # 1 - instantiate
446
- generate_midi = GenerateMidiText(model, tokenizer)
447
- # 0 - set the n_bar for this model
448
- generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
449
- # 1 - defines the instruments, densities and temperatures
450
- # 2- generate the first 8 bars for each instrument
451
- generate_midi.set_improvisation_level(30)
452
- generate_midi.generate_piece(
453
- instrument_promt_list,
454
- density_list,
455
- [temperature for _ in density_list],
456
- )
457
- # 3 - force the model to improvise
458
- # generate_midi.set_improvisation_level(20)
459
- # 4 - generate the next 4 bars for each instrument
460
- # generate_midi.generate_n_more_bars(n_bar_generated)
461
- # 5 - lower the improvisation level
462
- generate_midi.generated_piece = (
463
- generate_midi.get_whole_piece_from_bar_dict()
464
- )
465
-
466
- # print the generated sequence in terminal
467
- print("=========================================")
468
- print(generate_midi.generated_piece)
469
- print("=========================================")
470
-
471
- # write to JSON file
472
- filename = WriteTextMidiToFile(
473
- generate_midi,
474
- generated_sequence_files_path,
475
- ).text_midi_to_file()
476
-
477
- # decode the sequence to MIDI """
478
- decode_tokenizer = get_miditok()
479
- TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
480
- generate_midi.generated_piece, filename=filename.split(".")[0] + ".mid"
481
- )
482
- inst_midi, mixed_audio = get_music(filename.split(".")[0] + ".mid")
483
- max_time = get_max_time(inst_midi)
484
- plot_piano_roll(inst_midi)
485
-
486
- print("Et voilà! Your MIDI file is ready! GO JAM!")
 
190
  verbose=True,
191
  expected_length=None,
192
  ):
 
193
  """generate until the TRACK_END token is reached
194
  full_piece = input_prompt + generated"""
195
  if expected_length is None:
 
269
 
270
  def generate_piece(self, instrument_list, density_list, temperature_list):
271
  """generate a sequence with mutiple tracks
 
 
 
 
272
 
273
+ Args:
274
+ - inst_list sets the list of instruments and the the order of generation
275
+ - density and
276
+ - temperature are paired with inst_list
277
+
278
+ Each track/intrument is generated based on a prompt which contains the previously generated track/instrument
279
+
280
+ Returns:
281
+ 'generated_piece' which keeps track of the entire piece
282
+ """
283
 
284
  generated_piece = "PIECE_START "
285
  for instrument, density, temperature in zip(
 
323
  for bar in track["bars"][-self.model_n_bar :]:
324
  pre_promt += bar
325
  pre_promt += "TRACK_END "
326
+ elif (
327
+ False
328
+ ): # len_diff <= 0: # THIS DOES NOT WORK - It just fills things with empty bars
329
  # adding an empty bars at the end of the other tracks if they have not been processed yet
330
  pre_promt += othertracks["bars"][0]
331
  for bar in track["bars"][-(self.model_n_bar - 1) :]:
 
382
  def check_the_piece_for_errors(self, piece: str = None):
383
 
384
  if piece is None:
385
+ piece = self.get_whole_piece_from_bar_dict()
386
  errors = []
387
  errors.append(
388
  [
 
400
 
401
 
402
  if __name__ == "__main__":
403
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generation_utils.py CHANGED
@@ -15,10 +15,6 @@ matplotlib.rcParams["axes.edgecolor"] = "grey"
15
 
16
 
17
  def define_generation_dir(model_repo_path):
18
- #### to remove later ####
19
- if model_repo_path == "models/model_2048_fake_wholedataset":
20
- model_repo_path = "misnaej/the-jam-machine"
21
- #### to remove later ####
22
  generated_sequence_files_path = f"midi/generated/{model_repo_path}"
23
  if not os.path.exists(generated_sequence_files_path):
24
  os.makedirs(generated_sequence_files_path)
@@ -61,6 +57,11 @@ def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
61
  )
62
 
63
 
 
 
 
 
 
64
  def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
65
  """Forcing the generated sequence to have the expected length
66
  expected_length and bar_count refers to the length of newly_generated_only (without input prompt)"""
 
15
 
16
 
17
  def define_generation_dir(model_repo_path):
 
 
 
 
18
  generated_sequence_files_path = f"midi/generated/{model_repo_path}"
19
  if not os.path.exists(generated_sequence_files_path):
20
  os.makedirs(generated_sequence_files_path)
 
57
  )
58
 
59
 
60
+ # TODO
61
+ def check_if_prompt_density_in_tokenizer_vocab(tokenizer, density_prompt_list):
62
+ pass
63
+
64
+
65
  def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
66
  """Forcing the generated sequence to have the expected length
67
  expected_length and bar_count refers to the length of newly_generated_only (without input prompt)"""