import argparse
import copy
import glob
import json
import os.path
import pickle
import random
import statistics
from collections import Counter

import gradio as gr
import matplotlib.pyplot as plt
import tqdm

import MIDI
from midi_synthesizer import synthesis

#==========================================================================================================

in_space = os.getenv("SYSTEM") == "spaces"

# Path the matched MIDI is written to AND returned to the Gradio File output.
OUTPUT_MIDI_PATH = "MIDI-Match-Sample.mid"

# One colour per MIDI channel for the scatter plot.
CHANNEL_COLORS = ['red', 'yellow', 'green', 'cyan',
                  'blue', 'pink', 'orange', 'purple',
                  'gray', 'white', 'gold', 'silver',
                  'lightgreen', 'indigo', 'maroon', 'turquoise']

#==========================================================================================================


def _min_max_ratio(a, b):
    """Return min(a, b) / max(a, b); equal values (including 0 and 0) give 1.0."""
    if a == b:
        return 1.0
    larger = max(a, b)
    if larger == 0:
        return 0.0
    return min(a, b) / larger


def _average_ratio(source_stats, match_stats):
    """Return the mean min/max ratio over paired source/match statistics."""
    ratios = [_min_max_ratio(s, m) for s, m in zip(source_stats, match_stats)]
    if not ratios:
        return 0
    return sum(ratios) / len(ratios)


def _summarize(values):
    """Return [average, median, mode] of values, average and median truncated to int."""
    return [int(sum(values) / len(values)),
            int(statistics.median(values)),
            statistics.mode(values)]


def _transposed_pitch_counts(notes):
    """Return one [[pitch, count], ...] list per transposition in range(-6, 6).

    Notes on channel 9 are never transposed; their pitch is offset by 128 so
    they cannot collide with pitches from other channels. Index 6 of the
    result is the untransposed (shift == 0) table.
    """
    tables = []
    for shift in range(-6, 6):
        counter = Counter(
            (e[4] % 128) + 128 if e[3] == 9 else (e[4] % 128) + shift
            for e in notes
        )
        counts = [[pitch, count] for pitch, count in counter.most_common()]
        counts.sort(key=lambda x: x[0], reverse=True)
        tables.append(counts)
    return tables


def _trim_counts(counts, cutoff_ratio):
    """Keep entries whose count is >= cutoff_ratio * max count; return (kept, total)."""
    threshold = max(y[1] for y in counts) * cutoff_ratio
    kept = [y for y in counts if y[1] >= threshold]
    return kept, sum(y[1] for y in kept)


def _score_entry(entry_data, search_pitches, source_times, source_durs, *,
                 cutoff_ratio, skip_exact_matches, add_pitches_counts_ratios,
                 add_timings_ratios, add_durations_ratios):
    """Return the best match ratio between one meta-data entry and the source MIDI.

    entry_data is the second element of a meta_data record. This code reads
    its pitch counts from entry_data[10][1], timing stats from
    entry_data[3][1] and duration stats from entry_data[4][1].
    The ratio is the maximum, over all tables in search_pitches, of the
    average of the enabled partial ratios.
    """
    p_counts = entry_data[10][1]
    if not p_counts:
        # Nothing to compare against; score 0 so the entry can never win.
        return 0

    entry_kept, entry_total = _trim_counts(p_counts, cutoff_ratio)
    entry_pitches = {y[0] for y in entry_kept}

    # Timing/duration ratios do not depend on the transposition, so compute once.
    # Compare against THIS entry's stats (not the first entry of meta_data).
    extra_ratios = []
    if add_timings_ratios:
        extra_ratios.append(_average_ratio(source_times, entry_data[3][1]))
    if add_durations_ratios:
        extra_ratios.append(_average_ratio(source_durs, entry_data[4][1]))

    avg_ratios = []

    for source_counts in search_pitches:
        source_kept, source_total = _trim_counts(source_counts, cutoff_ratio)
        shared = entry_pitches & {y[0] for y in source_kept}

        if len(shared) == len(source_kept):
            same_pitches_ratio = len(shared) / len(entry_kept)
        else:
            same_pitches_ratio = len(shared) / max(len(entry_kept), len(source_kept))

        if skip_exact_matches and same_pitches_ratio == 1:
            same_pitches_ratio = 0

        ratios = [same_pitches_ratio]

        if add_pitches_counts_ratios:
            entry_share = {y[0]: y[1] / entry_total for y in entry_kept if y[0] in shared}
            source_share = {y[0]: y[1] / source_total for y in source_kept if y[0] in shared}
            share_ratios = [_min_max_ratio(entry_share[p], source_share[p])
                            for p in sorted(shared, reverse=True)]
            if share_ratios:
                ratios.append(sum(share_ratios) / len(share_ratios))
            else:
                ratios.append(0)

        ratios.extend(extra_ratios)
        avg_ratios.append(sum(ratios) / len(ratios))

    return max(avg_ratios)


def match_midi(midi, progress=gr.Progress()):
    """Find the meta-data entry closest to the uploaded MIDI and render it.

    Args:
        midi: raw MIDI file contents, passed to MIDI.midi2score and
            MIDI.midi2ms_score.
        progress: Gradio progress tracker used to wrap the search loop.

    Yields:
        A single tuple (metadata text, path of the written MIDI file,
        (44100, audio samples), matplotlib.pyplot module holding the plot).

    Raises:
        gr.Error: if the uploaded file has no notes or the meta-data is empty.

    Reads the module-level names meta_data and soundfont_path, which are set
    in the __main__ section of this file.
    """

    print('=' * 70)
    print('Loading MIDI file...')

    #==================================================

    # Element 0 of a score is the ticks value; the remaining elements are tracks.
    score = MIDI.midi2score(midi)
    score_notes = [e for track in score[1:] for e in track if e[0] == 'note']

    ms_score = MIDI.midi2ms_score(midi)
    ms_notes = sorted(
        (e for track in ms_score[1:] for e in track if e[0] == 'note'),
        key=lambda e: e[1],
    )

    if not score_notes or not ms_notes:
        raise gr.Error('No notes found in the uploaded MIDI file.')

    if not meta_data:
        raise gr.Error('Meta-data is empty; nothing to match against.')

    mult_pitches_counts = _transposed_pitch_counts(score_notes)

    # Start-time deltas: a leading 0 for the first note, then every non-zero
    # gap between consecutive note start times.
    starts = [e[1] for e in ms_notes]
    times = [0] + [b - a for a, b in zip(starts, starts[1:]) if b != a]
    durs = [e[2] for e in ms_notes]

    source_times = _summarize(times)
    source_durs = _summarize(durs)

    #==================================================

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #==========================================================================================================

    # Match ratio control option
    maximum_match_ratio_to_search_for = 1

    # MIDI pitches search options
    pitches_counts_cutoff_threshold_ratio = 0
    search_transposed_pitches = False
    skip_exact_matches = True

    # Additional search options
    add_pitches_counts_ratios = False
    add_timings_ratios = False
    add_durations_ratios = False

    print('=' * 70)
    print('MIDI Pitches Search')
    print('=' * 70)

    if search_transposed_pitches:
        search_pitches = mult_pitches_counts
    else:
        # Index 6 is the untransposed table (shift 0 in range(-6, 6)).
        search_pitches = [mult_pitches_counts[6]]

    final_ratios = []

    for d in progress.tqdm(meta_data):
        final_ratio = _score_entry(
            d[1], search_pitches, source_times, source_durs,
            cutoff_ratio=pitches_counts_cutoff_threshold_ratio,
            skip_exact_matches=skip_exact_matches,
            add_pitches_counts_ratios=add_pitches_counts_ratios,
            add_timings_ratios=add_timings_ratios,
            add_durations_ratios=add_durations_ratios,
        )

        if final_ratio > maximum_match_ratio_to_search_for:
            final_ratio = 0

        final_ratios.append(final_ratio)

    #===================================================

    max_ratio = max(final_ratios)
    max_ratio_index = final_ratios.index(max_ratio)

    print('FOUND')
    print('=' * 70)
    print('Match ratio', max_ratio)
    print('MIDI file name', meta_data[max_ratio_index][0])
    print('=' * 70)

    #==========================================================================================================

    md = meta_data[max_ratio_index]

    # Slices as used by this app: first 16 items are (name, value) metadata
    # pairs, item 16 holds the ticks value, the rest (minus the last item)
    # is the event sequence.
    mid_seq = md[1][17:-1]
    mid_seq_ticks = md[1][16][1]
    mdata = md[1][:16]

    txt_mdata = ''.join(f"{m[0]}:{m[1]}\n" for m in mdata)

    seq_notes = [m for m in mid_seq if m[0] == 'note']
    x = [s[1] for s in seq_notes]
    y = [s[4] for s in seq_notes]
    # Modulo keeps an out-of-range channel from raising IndexError.
    c = [CHANNEL_COLORS[s[3] % len(CHANNEL_COLORS)] for s in seq_notes]

    plt.close()
    plt.figure(figsize=(14, 5))
    ax = plt.axes(title='MIDI Search Plot')
    ax.set_facecolor('black')

    plt.scatter(x, y, c=c)
    plt.xlabel("Time")
    plt.ylabel("Pitch")

    # Write to the same path that is handed back to the File output below.
    with open(OUTPUT_MIDI_PATH, 'wb') as f:
        f.write(MIDI.score2midi([mid_seq_ticks, mid_seq]))

    audio = synthesis(MIDI.score2opus([mid_seq_ticks, mid_seq]), soundfont_path)

    yield txt_mdata, OUTPUT_MIDI_PATH, (44100, audio), plt

#==========================================================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    opt = parser.parse_args()

    soundfont_path = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
    meta_data_path = "meta-data/LAMD_META_10000.pickle"

    print('Loading meta-data...')
    # NOTE: pickle.load executes arbitrary code from the file; only ship a
    # trusted meta-data pickle with the app.
    with open(meta_data_path, 'rb') as f:
        meta_data = pickle.load(f)
    print('Done!')

    app = gr.Blocks()
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Match</h1>")
        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Match&style=flat)\n\n"
                    "MIDI Match\n\n"
                    "Demo for [MIDI Match](https://github.com/asigalov61)\n\n"
                    "[Open In Colab]"
                    "(https://colab.research.google.com/github/asigalov61/MIDI-Match/blob/main/demo.ipynb)"
                    " for faster running and longer generation"
                    )

        gr.Markdown("# Upload any MIDI file to find its closest match")

        input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")

        output_plot = gr.Plot(label="output midi match sample plot")
        output_audio = gr.Audio(label="output midi match sample audio", format="mp3", elem_id="midi_audio")
        output_midi = gr.File(label="output midi match sample file", file_types=[".mid"])
        output_midi_seq = gr.Textbox(label="output midi match metadata")

        run_event = input_midi.upload(match_midi, [input_midi],
                                      [output_midi_seq, output_midi, output_audio, output_plot])

    app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)