m41w4r3.exe
committed on
Commit
•
6cc2135
1
Parent(s):
facf84e
fix genesis caching
Browse files- decoder.py +1 -1
- generate.py +4 -4
- generation_utils.py +29 -9
- playground.py +56 -38
decoder.py
CHANGED
@@ -178,7 +178,7 @@ class TextDecoder:
|
|
178 |
inst = 0
|
179 |
is_drum = 1
|
180 |
if self.familized:
|
181 |
-
inst = Familizer(arbitrary=True).get_program_number(int(inst))
|
182 |
instruments.append((int(inst), is_drum))
|
183 |
return tuple(instruments)
|
184 |
|
|
|
178 |
inst = 0
|
179 |
is_drum = 1
|
180 |
if self.familized:
|
181 |
+
inst = Familizer(arbitrary=True).get_program_number(int(inst))
|
182 |
instruments.append((int(inst), is_drum))
|
183 |
return tuple(instruments)
|
184 |
|
generate.py
CHANGED
@@ -21,12 +21,12 @@ class GenerateMidiText:
|
|
21 |
- self.process_prompt_for_next_bar()
|
22 |
- self.generate_until_track_end()"""
|
23 |
|
24 |
-
def __init__(self, model, tokenizer):
|
25 |
self.model = model
|
26 |
self.tokenizer = tokenizer
|
27 |
# default initialization
|
28 |
self.initialize_default_parameters()
|
29 |
-
self.initialize_dictionaries()
|
30 |
|
31 |
"""Setters"""
|
32 |
|
@@ -38,8 +38,8 @@ class GenerateMidiText:
|
|
38 |
self.set_nb_bars_generated()
|
39 |
self.set_improvisation_level(0)
|
40 |
|
41 |
-
def initialize_dictionaries(self):
|
42 |
-
self.piece_by_track =
|
43 |
|
44 |
def set_device(self, device="cpu"):
|
45 |
self.device = ("cpu",)
|
|
|
21 |
- self.process_prompt_for_next_bar()
|
22 |
- self.generate_until_track_end()"""
|
23 |
|
24 |
+
def __init__(self, model, tokenizer, piece_by_track=[]):
|
25 |
self.model = model
|
26 |
self.tokenizer = tokenizer
|
27 |
# default initialization
|
28 |
self.initialize_default_parameters()
|
29 |
+
self.initialize_dictionaries(piece_by_track)
|
30 |
|
31 |
"""Setters"""
|
32 |
|
|
|
38 |
self.set_nb_bars_generated()
|
39 |
self.set_improvisation_level(0)
|
40 |
|
41 |
+
def initialize_dictionaries(self, piece_by_track):
|
42 |
+
self.piece_by_track = piece_by_track
|
43 |
|
44 |
def set_device(self, device="cpu"):
|
45 |
self.device = ("cpu",)
|
generation_utils.py
CHANGED
@@ -2,14 +2,16 @@ import os
|
|
2 |
import numpy as np
|
3 |
import matplotlib.pyplot as plt
|
4 |
import matplotlib
|
|
|
5 |
from constants import INSTRUMENT_CLASSES
|
|
|
6 |
|
7 |
# matplotlib settings
|
8 |
matplotlib.use("Agg") # for server
|
9 |
matplotlib.rcParams["xtick.major.size"] = 0
|
10 |
matplotlib.rcParams["ytick.major.size"] = 0
|
11 |
-
matplotlib.rcParams["axes.facecolor"] = "
|
12 |
-
matplotlib.rcParams["axes.edgecolor"] = "
|
13 |
|
14 |
|
15 |
def define_generation_dir(model_repo_path):
|
@@ -93,7 +95,7 @@ def get_max_time(inst_midi):
|
|
93 |
def plot_piano_roll(inst_midi):
|
94 |
piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
|
95 |
piano_roll_fig.tight_layout()
|
96 |
-
piano_roll_fig.patch.set_alpha(0
|
97 |
inst_count = 0
|
98 |
beats_per_bar = 4
|
99 |
sec_per_beat = 0.5
|
@@ -102,6 +104,14 @@ def plot_piano_roll(inst_midi):
|
|
102 |
int
|
103 |
)
|
104 |
for inst in inst_midi.instruments:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
inst_count += 1
|
106 |
plt.subplot(len(inst_midi.instruments), 1, inst_count)
|
107 |
|
@@ -118,24 +128,34 @@ def plot_piano_roll(inst_midi):
|
|
118 |
for note in p_midi_note_list:
|
119 |
note_time.append([note.start, note.end])
|
120 |
note_pitch.append([note.pitch, note.pitch])
|
|
|
|
|
121 |
|
122 |
plt.plot(
|
123 |
-
|
124 |
-
|
125 |
-
color=
|
126 |
-
linewidth=
|
127 |
solid_capstyle="butt",
|
128 |
)
|
129 |
plt.ylim(0, 128)
|
130 |
xticks = np.array(bars_time)[:-1]
|
131 |
plt.tight_layout()
|
132 |
plt.xlim(min(bars_time), max(bars_time))
|
133 |
-
|
134 |
plt.xticks(
|
135 |
xticks + 0.5 * beats_per_bar * sec_per_beat,
|
136 |
labels=xticks.argsort() + 1,
|
137 |
visible=False,
|
138 |
)
|
139 |
-
plt.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
return piano_roll_fig
|
|
|
2 |
import numpy as np
|
3 |
import matplotlib.pyplot as plt
|
4 |
import matplotlib
|
5 |
+
|
6 |
from constants import INSTRUMENT_CLASSES
|
7 |
+
from playback import get_music, show_piano_roll
|
8 |
|
9 |
# matplotlib settings
|
10 |
matplotlib.use("Agg") # for server
|
11 |
matplotlib.rcParams["xtick.major.size"] = 0
|
12 |
matplotlib.rcParams["ytick.major.size"] = 0
|
13 |
+
matplotlib.rcParams["axes.facecolor"] = "none"
|
14 |
+
matplotlib.rcParams["axes.edgecolor"] = "grey"
|
15 |
|
16 |
|
17 |
def define_generation_dir(model_repo_path):
|
|
|
95 |
def plot_piano_roll(inst_midi):
|
96 |
piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
|
97 |
piano_roll_fig.tight_layout()
|
98 |
+
piano_roll_fig.patch.set_alpha(0)
|
99 |
inst_count = 0
|
100 |
beats_per_bar = 4
|
101 |
sec_per_beat = 0.5
|
|
|
104 |
int
|
105 |
)
|
106 |
for inst in inst_midi.instruments:
|
107 |
+
# hardcoded for now
|
108 |
+
if inst.name == "Drums":
|
109 |
+
color = "purple"
|
110 |
+
elif inst.name == "Synth Bass 1":
|
111 |
+
color = "orange"
|
112 |
+
else:
|
113 |
+
color = "green"
|
114 |
+
|
115 |
inst_count += 1
|
116 |
plt.subplot(len(inst_midi.instruments), 1, inst_count)
|
117 |
|
|
|
128 |
for note in p_midi_note_list:
|
129 |
note_time.append([note.start, note.end])
|
130 |
note_pitch.append([note.pitch, note.pitch])
|
131 |
+
note_pitch = np.array(note_pitch)
|
132 |
+
note_time = np.array(note_time)
|
133 |
|
134 |
plt.plot(
|
135 |
+
note_time.T,
|
136 |
+
note_pitch.T,
|
137 |
+
color=color,
|
138 |
+
linewidth=4,
|
139 |
solid_capstyle="butt",
|
140 |
)
|
141 |
plt.ylim(0, 128)
|
142 |
xticks = np.array(bars_time)[:-1]
|
143 |
plt.tight_layout()
|
144 |
plt.xlim(min(bars_time), max(bars_time))
|
145 |
+
plt.ylim(max([note_pitch.min() - 5, 0]), note_pitch.max() + 5)
|
146 |
plt.xticks(
|
147 |
xticks + 0.5 * beats_per_bar * sec_per_beat,
|
148 |
labels=xticks.argsort() + 1,
|
149 |
visible=False,
|
150 |
)
|
151 |
+
plt.text(
|
152 |
+
0.2,
|
153 |
+
note_pitch.max() + 4,
|
154 |
+
inst.name,
|
155 |
+
fontsize=20,
|
156 |
+
color=color,
|
157 |
+
horizontalalignment="left",
|
158 |
+
verticalalignment="top",
|
159 |
+
)
|
160 |
|
161 |
return piano_roll_fig
|
playground.py
CHANGED
@@ -26,7 +26,6 @@ model, tokenizer = LoadModel(
|
|
26 |
model_repo, from_huggingface=True, revision=revision
|
27 |
).load_model_and_tokenizer()
|
28 |
|
29 |
-
|
30 |
miditok = get_miditok()
|
31 |
decoder = TextDecoder(miditok)
|
32 |
|
@@ -40,32 +39,49 @@ def define_prompt(state, genesis):
|
|
40 |
|
41 |
|
42 |
def generator(
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
):
|
45 |
|
|
|
|
|
46 |
inst = next(
|
47 |
(inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument),
|
48 |
{"family_number": "DRUMS"},
|
49 |
)["family_number"]
|
50 |
|
51 |
-
inst_index =
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
genesis.delete_one_track(inst_index)
|
57 |
-
generated_text = (
|
58 |
-
genesis.get_whole_piece_from_bar_dict()
|
59 |
-
) # maybe not useful here
|
60 |
-
inst_index = -1 # reset to last generated
|
61 |
|
62 |
# Generate
|
63 |
if not add_bars:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
# NEW TRACK
|
65 |
input_prompt = define_prompt(state, genesis)
|
66 |
generated_text = genesis.generate_one_new_track(
|
67 |
inst, density, temp, input_prompt=input_prompt
|
68 |
)
|
|
|
|
|
69 |
else:
|
70 |
# NEW BARS
|
71 |
genesis.generate_n_more_bars(add_bar_count) # for all instruments
|
@@ -79,14 +95,23 @@ def generator(
|
|
79 |
decoder.get_midi(inst_text, inst_midi_name)
|
80 |
_, inst_audio = get_music(inst_midi_name)
|
81 |
piano_roll = plot_piano_roll(mixed_inst_midi)
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
-
def instrument_row(default_inst):
|
88 |
|
|
|
89 |
with gr.Row():
|
|
|
90 |
with gr.Column(scale=1, min_width=50):
|
91 |
inst = gr.Dropdown(
|
92 |
[inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"],
|
@@ -100,35 +125,33 @@ def instrument_row(default_inst):
|
|
100 |
output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
|
101 |
with gr.Column(scale=1, min_width=100):
|
102 |
inst_audio = gr.Audio(label="Audio")
|
103 |
-
regenerate = gr.Checkbox(value=False, label="Regenerate")
|
104 |
# add_bars = gr.Checkbox(value=False, label="Add Bars")
|
105 |
# add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
|
106 |
gen_btn = gr.Button("Generate")
|
107 |
gen_btn.click(
|
108 |
fn=generator,
|
109 |
-
inputs=[
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
state,
|
|
|
|
|
|
|
115 |
],
|
116 |
-
outputs=[output_txt, inst_audio, piano_roll, state, mixed_audio],
|
117 |
)
|
118 |
|
119 |
|
120 |
-
with gr.Blocks(
|
121 |
-
|
122 |
-
model,
|
123 |
-
tokenizer,
|
124 |
-
)
|
125 |
-
genesis.set_nb_bars_generated(n_bars=n_bar_generated)
|
126 |
state = gr.State([])
|
127 |
mixed_audio = gr.Audio(label="Mixed Audio")
|
128 |
piano_roll = gr.Plot(label="Piano Roll")
|
129 |
-
instrument_row("Drums")
|
130 |
-
instrument_row("Bass")
|
131 |
-
instrument_row("Synth Lead")
|
132 |
# instrument_row("Piano")
|
133 |
|
134 |
demo.launch(debug=True)
|
@@ -138,14 +161,9 @@ TODO: DEPLOY
|
|
138 |
TODO: temp file situation
|
139 |
TODO: clear cache situation
|
140 |
TODO: reset button
|
141 |
-
TODO: instrument mapping business
|
142 |
-
TODO: Y lim axis of piano roll
|
143 |
TODO: add a button to save the generated midi
|
144 |
TODO: add improvise button
|
145 |
-
TODO: making the piano roll fit on the horizontal scale
|
146 |
TODO: set values for temperature as it is done for density
|
147 |
-
TODO: set the color situation to be dark background
|
148 |
-
TODO: make regeneration the default when an instrument's track has already been generated
|
149 |
TODO: Add bar should now be set for the whole piece - regenerate should regenerate only the added bars, on all instruments
|
150 |
TODO: row height to fix
|
151 |
|
|
|
26 |
model_repo, from_huggingface=True, revision=revision
|
27 |
).load_model_and_tokenizer()
|
28 |
|
|
|
29 |
miditok = get_miditok()
|
30 |
decoder = TextDecoder(miditok)
|
31 |
|
|
|
39 |
|
40 |
|
41 |
def generator(
|
42 |
+
label,
|
43 |
+
regenerate,
|
44 |
+
temp,
|
45 |
+
density,
|
46 |
+
instrument,
|
47 |
+
state,
|
48 |
+
piece_by_track,
|
49 |
+
add_bars=False,
|
50 |
+
add_bar_count=1,
|
51 |
):
|
52 |
|
53 |
+
genesis = GenerateMidiText(model, tokenizer, piece_by_track)
|
54 |
+
track = {"label": label}
|
55 |
inst = next(
|
56 |
(inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument),
|
57 |
{"family_number": "DRUMS"},
|
58 |
)["family_number"]
|
59 |
|
60 |
+
inst_index = -1 # default to last generated
|
61 |
+
if state != []:
|
62 |
+
for index, instrum in enumerate(state):
|
63 |
+
if instrum["label"] == track["label"]:
|
64 |
+
inst_index = index # changing if exists
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
# Generate
|
67 |
if not add_bars:
|
68 |
+
# Regenerate
|
69 |
+
if regenerate:
|
70 |
+
state.pop(inst_index)
|
71 |
+
genesis.delete_one_track(inst_index)
|
72 |
+
|
73 |
+
generated_text = (
|
74 |
+
genesis.get_whole_piece_from_bar_dict()
|
75 |
+
) # maybe not useful here
|
76 |
+
inst_index = -1 # reset to last generated
|
77 |
+
|
78 |
# NEW TRACK
|
79 |
input_prompt = define_prompt(state, genesis)
|
80 |
generated_text = genesis.generate_one_new_track(
|
81 |
inst, density, temp, input_prompt=input_prompt
|
82 |
)
|
83 |
+
|
84 |
+
regenerate = True # set generate to true
|
85 |
else:
|
86 |
# NEW BARS
|
87 |
genesis.generate_n_more_bars(add_bar_count) # for all instruments
|
|
|
95 |
decoder.get_midi(inst_text, inst_midi_name)
|
96 |
_, inst_audio = get_music(inst_midi_name)
|
97 |
piano_roll = plot_piano_roll(mixed_inst_midi)
|
98 |
+
track["text"] = inst_text
|
99 |
+
state.append(track)
|
100 |
+
|
101 |
+
return (
|
102 |
+
inst_text,
|
103 |
+
(44100, inst_audio),
|
104 |
+
piano_roll,
|
105 |
+
state,
|
106 |
+
(44100, mixed_audio),
|
107 |
+
regenerate,
|
108 |
+
genesis.piece_by_track,
|
109 |
+
)
|
110 |
|
|
|
111 |
|
112 |
+
def instrument_row(default_inst, row_id):
|
113 |
with gr.Row():
|
114 |
+
row = gr.Variable(row_id)
|
115 |
with gr.Column(scale=1, min_width=50):
|
116 |
inst = gr.Dropdown(
|
117 |
[inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"],
|
|
|
125 |
output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
|
126 |
with gr.Column(scale=1, min_width=100):
|
127 |
inst_audio = gr.Audio(label="Audio")
|
128 |
+
regenerate = gr.Checkbox(value=False, label="Regenerate", visible=False)
|
129 |
# add_bars = gr.Checkbox(value=False, label="Add Bars")
|
130 |
# add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
|
131 |
gen_btn = gr.Button("Generate")
|
132 |
gen_btn.click(
|
133 |
fn=generator,
|
134 |
+
inputs=[row, regenerate, temp, density, inst, state, piece_by_track],
|
135 |
+
outputs=[
|
136 |
+
output_txt,
|
137 |
+
inst_audio,
|
138 |
+
piano_roll,
|
139 |
state,
|
140 |
+
mixed_audio,
|
141 |
+
regenerate,
|
142 |
+
piece_by_track,
|
143 |
],
|
|
|
144 |
)
|
145 |
|
146 |
|
147 |
+
with gr.Blocks() as demo:
|
148 |
+
piece_by_track = gr.State([])
|
|
|
|
|
|
|
|
|
149 |
state = gr.State([])
|
150 |
mixed_audio = gr.Audio(label="Mixed Audio")
|
151 |
piano_roll = gr.Plot(label="Piano Roll")
|
152 |
+
instrument_row("Drums", 0)
|
153 |
+
instrument_row("Bass", 1)
|
154 |
+
instrument_row("Synth Lead", 2)
|
155 |
# instrument_row("Piano")
|
156 |
|
157 |
demo.launch(debug=True)
|
|
|
161 |
TODO: temp file situation
|
162 |
TODO: clear cache situation
|
163 |
TODO: reset button
|
|
|
|
|
164 |
TODO: add a button to save the generated midi
|
165 |
TODO: add improvise button
|
|
|
166 |
TODO: set values for temperature as it is done for density
|
|
|
|
|
167 |
TODO: Add bar should now be set for the whole piece - regenerate should regenerate only the added bars, on all instruments
|
168 |
TODO: row height to fix
|
169 |
|