|
from generation_utils import * |
|
import random |
|
|
|
|
|
class GenerateMidiText: |
|
"""Generating music with Class |
|
|
|
LOGIC: |
|
|
|
FOR GENERATING FROM SCRATCH: |
|
- self.generate_one_new_track() |
|
it calls |
|
- self.generate_until_track_end() |
|
|
|
FOR GENERATING NEW BARS: |
|
- self.generate_one_more_bar() |
|
it calls |
|
- self.process_prompt_for_next_bar() |
|
- self.generate_until_track_end()""" |
|
|
|
    def __init__(self, model, tokenizer, piece_by_track=None):
        self.model = model
        self.tokenizer = tokenizer

        self.initialize_default_parameters()
        # avoid the mutable-default-argument pitfall: start from a fresh list when none is given
        self.initialize_dictionaries(piece_by_track if piece_by_track is not None else [])
|
|
|
"""Setters""" |
|
|
|
def initialize_default_parameters(self): |
|
self.set_device() |
|
self.set_attention_length() |
|
self.generate_until = "TRACK_END" |
|
self.set_force_sequence_lenth() |
|
self.set_nb_bars_generated() |
|
self.set_improvisation_level(0) |
|
|
|
def initialize_dictionaries(self, piece_by_track): |
|
self.piece_by_track = piece_by_track |
|
|
|
    def set_device(self, device="cpu"):
        self.device = device
|
|
|
def set_attention_length(self): |
|
self.max_length = self.model.config.n_positions |
|
print( |
|
f"Attention length set to {self.max_length} -> 'model.config.n_positions'" |
|
) |
|
|
|
def set_force_sequence_lenth(self, force_sequence_length=True): |
|
self.force_sequence_length = force_sequence_length |
|
|
|
def set_improvisation_level(self, improvisation_value): |
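        """Map the improvisation level to no_repeat_ngram_size (passed to model.generate()):
        0 disables the constraint; higher values forbid repeating n-grams of that size."""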
|
self.no_repeat_ngram_size = improvisation_value |
|
print("--------------------") |
|
print(f"no_repeat_ngram_size set to {improvisation_value}") |
|
print("--------------------") |
|
|
|
def reset_temperatures(self, track_id, temperature): |
|
self.piece_by_track[track_id]["temperature"] = temperature |
|
|
|
def set_nb_bars_generated(self, n_bars=8): |
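        """Set the number of bars expected per generated track; also used as the bar window when building prompts."""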
|
self.model_n_bar = n_bars |
|
|
|
""" Generation Tools - Dictionnaries """ |
|
|
|
def initiate_track_dict(self, instr, density, temperature): |
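        """Append a new track entry (label, instrument, density, temperature, empty bar list) to piece_by_track."""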
|
label = len(self.piece_by_track) |
|
self.piece_by_track.append( |
|
{ |
|
"label": f"track_{label}", |
|
"instrument": instr, |
|
"density": density, |
|
"temperature": temperature, |
|
"bars": [], |
|
} |
|
) |
|
|
|
def update_track_dict__add_bars(self, bars, track_id): |
|
"""Add bars to the track dictionnary""" |
|
for bar in self.striping_track_ends(bars).split("BAR_START "): |
|
if bar == "": |
|
continue |
|
else: |
|
if "TRACK_START" in bar: |
|
self.piece_by_track[track_id]["bars"].append(bar) |
|
else: |
|
self.piece_by_track[track_id]["bars"].append("BAR_START " + bar) |
|
|
|
def get_all_instr_bars(self, track_id): |
|
return self.piece_by_track[track_id]["bars"] |
|
|
|
    def striping_track_ends(self, text):
        """Remove a trailing TRACK_END token so it can be re-added consistently later"""
        if "TRACK_END" in text:
            text = text.rstrip(" ")
            if text.endswith("TRACK_END"):
                text = text[: -len("TRACK_END")]
        return text
|
|
|
def get_last_generated_track(self, piece): |
|
"""Get the last track from a piece written as a single long string""" |
|
track = self.get_tracks_from_a_piece(piece)[-1] |
|
return track |
|
|
|
def get_tracks_from_a_piece(self, piece): |
|
"""Get all the tracks from a piece written as a single long string""" |
|
        all_tracks = [
            "TRACK_START " + self.striping_track_ends(the_track) + "TRACK_END "
            for the_track in piece.split("TRACK_START ")[1:]
        ]
|
return all_tracks |
|
|
|
def get_piece_from_track_list(self, track_list): |
|
piece = "PIECE_START " |
|
for track in track_list: |
|
piece += track |
|
return piece |
|
|
|
def get_whole_track_from_bar_dict(self, track_id): |
|
text = "" |
|
for bar in self.piece_by_track[track_id]["bars"]: |
|
text += bar |
|
text += "TRACK_END " |
|
return text |
|
|
|
@staticmethod |
|
def get_newly_generated_text(input_prompt, full_piece): |
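        """Return only the portion of full_piece generated after input_prompt."""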
|
return full_piece[len(input_prompt) :] |
|
|
|
def get_whole_piece_from_bar_dict(self): |
|
text = "PIECE_START " |
|
for track_id, _ in enumerate(self.piece_by_track): |
|
text += self.get_whole_track_from_bar_dict(track_id) |
|
return text |
|
|
|
def delete_one_track(self, track): |
|
self.piece_by_track.pop(track) |
|
|
|
"""Basic generation tools""" |
|
|
|
def tokenize_input_prompt(self, input_prompt, verbose=True): |
|
"""Tokenizing prompt |
|
|
|
Args: |
|
- input_prompt (str): prompt to tokenize |
|
|
|
Returns: |
|
- input_prompt_ids (torch.tensor): tokenized prompt |
|
""" |
|
if verbose: |
|
print("Tokenizing input_prompt...") |
|
|
|
return self.tokenizer.encode(input_prompt, return_tensors="pt") |
|
|
|
def generate_sequence_of_token_ids( |
|
self, |
|
input_prompt_ids, |
|
temperature, |
|
verbose=True, |
|
): |
|
""" |
|
generate a sequence of token ids based on input_prompt_ids |
|
The sequence length depends on the trained model (self.model_n_bar) |
|
""" |
|
        if verbose:
            print("Generating a token_id sequence...")
        generated_ids = self.model.generate(
|
input_prompt_ids, |
|
max_length=self.max_length, |
|
do_sample=True, |
|
temperature=temperature, |
|
no_repeat_ngram_size=self.no_repeat_ngram_size, |
|
eos_token_id=self.tokenizer.encode(self.generate_until)[0], |
|
) |
|
|
|
        return generated_ids
|
|
|
def convert_ids_to_text(self, generated_ids, verbose=True): |
|
"""converts the token_ids to text""" |
|
generated_text = self.tokenizer.decode(generated_ids[0]) |
|
if verbose: |
|
print("Converting token sequence to MidiText...") |
|
return generated_text |
|
|
|
def generate_until_track_end( |
|
self, |
|
input_prompt="PIECE_START ", |
|
instrument=None, |
|
density=None, |
|
temperature=None, |
|
verbose=True, |
|
expected_length=None, |
|
): |
|
"""generate until the TRACK_END token is reached |
|
full_piece = input_prompt + generated""" |
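        # When an instrument (and optionally a density) is given, the prompt is extended
        # with a track header, e.g. "TRACK_START INST=DRUMS DENSITY=2 " (values illustrative).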
|
if expected_length is None: |
|
expected_length = self.model_n_bar |
|
|
|
if instrument is not None: |
|
input_prompt = f"{input_prompt}TRACK_START INST={str(instrument)} " |
|
if density is not None: |
|
input_prompt = f"{input_prompt}DENSITY={str(density)} " |
|
|
|
if instrument is None and density is not None: |
|
print("Density cannot be defined without an input_prompt instrument #TOFIX") |
|
|
|
        if temperature is None:
            raise ValueError("Temperature must be defined")
|
|
|
if verbose: |
|
print("--------------------") |
|
print( |
|
f"Generating {instrument} - Density {density} - temperature {temperature}" |
|
) |
|
bar_count_checks = False |
|
failed = 0 |
|
while not bar_count_checks: |
|
input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose) |
|
generated_tokens = self.generate_sequence_of_token_ids( |
|
input_prompt_ids, temperature, verbose=verbose |
|
) |
|
full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose) |
|
generated = self.get_newly_generated_text(input_prompt, full_piece) |
|
|
|
bar_count_checks, bar_count = bar_count_check(generated, expected_length) |
|
|
|
            if not self.force_sequence_length:
                # accept the generated sequence regardless of its bar count
                bar_count_checks = True

            if not bar_count_checks and self.force_sequence_length:
                # try to force the generation to the expected bar count
                full_piece, bar_count_checks = forcing_bar_count(
                    input_prompt,
                    generated,
                    bar_count,
                    expected_length,
                )

            if not bar_count_checks:
                print("--- Wrong length - Regenerating ---")
                failed += 1
                if failed > 2:
                    # give up after three failed attempts and keep the last generation
                    bar_count_checks = True
|
|
|
return full_piece |
|
|
|
def generate_one_new_track( |
|
self, |
|
instrument, |
|
density, |
|
temperature, |
|
input_prompt="PIECE_START ", |
|
): |
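        """Generate one new track, register it in piece_by_track, and return the whole piece rebuilt from the bar dictionary."""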
|
self.initiate_track_dict(instrument, density, temperature) |
|
full_piece = self.generate_until_track_end( |
|
input_prompt=input_prompt, |
|
instrument=instrument, |
|
density=density, |
|
temperature=temperature, |
|
) |
|
|
|
track = self.get_last_generated_track(full_piece) |
|
self.update_track_dict__add_bars(track, -1) |
|
full_piece = self.get_whole_piece_from_bar_dict() |
|
return full_piece |
|
|
|
""" Piece generation - Basics """ |
|
|
|
def generate_piece(self, instrument_list, density_list, temperature_list): |
|
"""generate a sequence with mutiple tracks |
|
|
|
Args: |
|
- inst_list sets the list of instruments and the the order of generation |
|
- density and |
|
- temperature are paired with inst_list |
|
|
|
Each track/intrument is generated based on a prompt which contains the previously generated track/instrument |
|
|
|
Returns: |
|
'generated_piece' which keeps track of the entire piece |
|
""" |
|
|
|
generated_piece = "PIECE_START " |
|
for instrument, density, temperature in zip( |
|
instrument_list, density_list, temperature_list |
|
): |
|
generated_piece = self.generate_one_new_track( |
|
instrument, |
|
density, |
|
temperature, |
|
input_prompt=generated_piece, |
|
) |
|
|
|
|
|
self.check_the_piece_for_errors() |
|
return generated_piece |
|
|
|
""" Piece generation - Extra Bars """ |
|
|
|
def process_prompt_for_next_bar(self, track_idx, verbose=True): |
|
"""Processing the prompt for the model to generate one more bar only. |
|
The prompt containts: |
|
if not the first bar: the previous, already processed, bars of the track |
|
the bar initialization (ex: "TRACK_START INST=DRUMS DENSITY=2 ") |
|
the last (self.model_n_bar)-1 bars of the track |
|
Args: |
|
track_idx (int): the index of the track to be processed |
|
|
|
Returns: |
|
the processed prompt for generating the next bar |
|
""" |
|
track = self.piece_by_track[track_idx] |
|
|
|
        pre_prompt = "PIECE_START "
        for i, othertrack in enumerate(self.piece_by_track):
            if i != track_idx:
                len_diff = len(othertrack["bars"]) - len(track["bars"])
                if len_diff > 0:
                    # the side track is ahead of the target track:
                    # add its header and its last bars to the prompt as context
                    if verbose:
                        print(
                            f"Adding bars - {len(othertrack['bars'][-self.model_n_bar :])} selected from SIDE track: {i} for prompt"
                        )
                    pre_prompt += othertrack["bars"][0]
                    for bar in othertrack["bars"][-self.model_n_bar :]:
                        pre_prompt += bar
                    pre_prompt += "TRACK_END "
|
                elif False:
                    # disabled branch: pad the shorter side track with empty bars
                    # so that all tracks stay aligned in the prompt
                    pre_prompt += othertrack["bars"][0]
                    for bar in track["bars"][-(self.model_n_bar - 1) :]:
                        pre_prompt += bar
                    for _ in range(abs(len_diff) + 1):
                        pre_prompt += "BAR_START BAR_END "
                    pre_prompt += "TRACK_END "
|
|
|
|
|
|
|
processed_prompt = track["bars"][0] |
|
if verbose: |
|
print( |
|
f"Adding bars - {len(track['bars'][-(self.model_n_bar - 1) :])} selected from MAIN track: {track_idx} for prompt" |
|
) |
|
for bar in track["bars"][-(self.model_n_bar - 1) :]: |
|
|
|
processed_prompt += bar |
|
|
|
processed_prompt += "BAR_START " |
|
|
|
|
|
        # truncate the context if the prompt gets too long (1500 space-separated tokens)
        pre_prompt = self.force_prompt_length(pre_prompt, 1500)
|
|
|
        print(
            f"--- prompt length = {len((pre_prompt + processed_prompt).split(' '))} ---"
        )

        return pre_prompt + processed_prompt
|
|
|
def force_prompt_length(self, prompt, expected_length): |
|
"""remove one instrument/track from the prompt it too long |
|
Args: |
|
prompt (str): the prompt to be processed |
|
expected_length (int): the expected length of the prompt |
|
Returns: |
|
the truncated prompt""" |
|
if len(prompt.split(" ")) < expected_length: |
|
truncated_prompt = prompt |
|
else: |
|
tracks = self.get_tracks_from_a_piece(prompt) |
|
selected_tracks = random.sample(tracks, len(tracks) - 1) |
|
truncated_prompt = self.get_piece_from_track_list(selected_tracks) |
|
print(f"Prompt too long - deleting one track") |
|
|
|
return truncated_prompt |
|
|
|
def generate_one_more_bar(self, track_index): |
|
"""Generate one more bar from the input_prompt""" |
|
processed_prompt = self.process_prompt_for_next_bar(track_index) |
|
|
|
prompt_plus_bar = self.generate_until_track_end( |
|
input_prompt=processed_prompt, |
|
temperature=self.piece_by_track[track_index]["temperature"], |
|
expected_length=1, |
|
verbose=False, |
|
) |
|
added_bar = self.get_newly_generated_bar(prompt_plus_bar) |
|
self.update_track_dict__add_bars(added_bar, track_index) |
|
|
|
def get_newly_generated_bar(self, prompt_plus_bar): |
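        """Return the newly generated bar (the text after the last BAR_START), re-prefixed with BAR_START and stripped of TRACK_END."""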
|
return "BAR_START " + self.striping_track_ends( |
|
prompt_plus_bar.split("BAR_START ")[-1] |
|
) |
|
|
|
    def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
        """Generate n more bars for every track (or only for `only_this_track`)"""
        print("================== ")
        print(f"Adding {n_bars} more bars to the piece ")
        for bar_id in range(n_bars):
            print(f"----- added bar #{bar_id + 1} --")
|
for i, track in enumerate(self.piece_by_track): |
|
if only_this_track is None or i == only_this_track: |
|
print(f"--------- {track['label']}") |
|
self.generate_one_more_bar(i) |
|
self.check_the_piece_for_errors() |
|
|
|
    def check_the_piece_for_errors(self, piece: str = None):
        """Print any token of the piece that is missing from the tokenizer vocabulary (or is UNK)"""
        if piece is None:
            piece = self.get_whole_piece_from_bar_dict()
        errors = [
            (token, position)
            for position, token in enumerate(piece.split(" "))
            if token not in self.tokenizer.vocab or token == "UNK"
        ]
        for token, position in errors:
            print(f"Token not found in the vocabulary at {position}: {token}")
            print(piece.split(" ")[position - 5 : position + 5])
|
|
|
|
|
if __name__ == "__main__": |
|
pass |
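    # Minimal usage sketch (assumptions: `model` is a causal LM exposing .generate()
    # and `tokenizer` exposes .encode/.decode/.vocab for this MIDI-text vocabulary,
    # e.g. loaded with Hugging Face transformers; the checkpoint path is a placeholder
    # and the instrument/density/temperature values are illustrative):
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")
    # tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
    # generator = GenerateMidiText(model, tokenizer)
    # generator.generate_piece(
    #     instrument_list=["DRUMS", "4", "3"],
    #     density_list=[2, 3, 2],
    #     temperature_list=[0.7, 0.7, 0.75],
    # )
    # generator.generate_n_more_bars(8)  # append 8 more bars to every track
    # piece = generator.get_whole_piece_from_bar_dict()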
|
|