# MIDI-Search / app.py
# asigalov61's HuggingFace Space
# (page header from the hosted file view: "Update app.py", commit 67342fc verified,
#  "raw / history / blame", 8.43 kB — kept here as provenance, commented out so the
#  file remains valid Python)
# https://huggingface.co/spaces/asigalov61/MIDI-Search
# --- standard library ---
import copy
import datetime
import os
import pickle
import random
import time as reqtime
import zlib

# --- third-party ---
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from pytz import timezone
# Required by existing code: model = SentenceTransformer(...) and util.cos_sim(...)
# were called without ever being imported.
from sentence_transformers import SentenceTransformer, util

# --- local ---
from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX
#==========================================================================================================
def find_midi(title, artist):
    """Find the corpus MIDI best matching a title/artist query and render it.

    Encodes "<title> by <artist>" with the module-level SentenceTransformer
    ``model``, picks the nearest entry in ``corpus_embeddings`` by cosine
    similarity, loads the matching score from its zlib-compressed pickle,
    converts the token stream to notes, writes a .mid file and renders audio.

    Args:
        title: desired song title ('' allowed).
        artist: desired song artist ('' allowed).

    Returns:
        Tuple of (midi_title, midi_summary, midi_file_path,
        (sample_rate, audio_ndarray), matplotlib_figure) for the Gradio outputs.

    Relies on module globals created at startup: ``PDT``, ``model``,
    ``corpus_embeddings``, ``all_MIDI_files_names``, ``soundfont_path``.
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    print('-' * 70)
    print('Req title:', title)
    print('Req artist:', artist)
    print('-' * 70)

    input_text = ''
    if title != '':
        input_text += title
    if artist != '':
        input_text += ' by ' + artist

    print('Searching...')
    query_embedding = model.encode([input_text])

    # Cosine similarity between the query and every corpus sentence.
    similarities = util.cos_sim(query_embedding, corpus_embeddings)

    # Index (and score) of the most similar corpus entry.
    closest_index = np.argmax(similarities)
    closest_index_match_ratio = max(similarities[0]).tolist()

    best_corpus_match = all_MIDI_files_names[closest_index]

    print('Done!')
    print('=' * 70)
    print('Match corpus index', closest_index)
    print('Match corpus ratio', closest_index_match_ratio)
    print('=' * 70)

    song_artist = best_corpus_match[0]
    zlib_file_name = best_corpus_match[1]

    print('Fetching MIDI score...')
    with open(zlib_file_name, 'rb') as f:
        compressed_data = f.read()

    # Decompress, then unpickle the list of (name, score) pairs.
    decompressed_data = zlib.decompress(compressed_data)
    # NOTE(review): pickle.loads on bundled data files shipped with the Space;
    # do not point this at untrusted input.
    scores_data = pickle.loads(decompressed_data)

    fnames = [f[0] for f in scores_data]
    fnameidx = fnames.index(song_artist)
    MIDI_score_data = scores_data[fnameidx][1]

    print('Sample INTs', MIDI_score_data[:12])
    print('=' * 70)

    # Token stream -> note events.
    # Initialized up front so the rendering code below can never hit a
    # NameError, even on an empty score or an unusual token ordering.
    song_f = []

    # BUGFIX: original referenced an undefined name `outy` here; the decoded
    # token list is `MIDI_score_data`.
    if len(MIDI_score_data) != 0:

        song = MIDI_score_data

        time = 0
        dur = 0
        vel = 90
        ptc = 0            # BUGFIX: was `pitch = 0` but the loop uses `ptc`
        patch = 0          # BUGFIX: could be read before first patch token
        channel = 0

        patches = [-1] * 16
        channels = [0] * 16
        channels[9] = 1    # channel 9 is reserved for drums

        for ss in song:

            # 0..255: delta time (16 ms units in token space)
            if 0 <= ss < 256:
                time += ss * 16

            # 256..511: duration
            if 256 <= ss < 512:
                dur = (ss - 256) * 16

            # 512..640: instrument patch; 128 means drums
            if 512 <= ss <= 640:
                patch = ss - 512

                if patch < 128:
                    if patch not in patches:
                        # Assign the patch to the first free channel
                        # (fall back to channel 15 when all are taken).
                        if 0 in channels:
                            cha = channels.index(0)
                            channels[cha] = 1
                        else:
                            cha = 15
                        patches[cha] = patch
                    channel = patches.index(patch)

                if patch == 128:
                    channel = 9

            # 641..767: pitch
            if 640 < ss < 768:
                ptc = ss - 640

            # 769..895: velocity — this token completes a note event
            if 768 < ss < 896:
                vel = ss - 768
                song_f.append(['note', time, dur, channel, ptc, vel, patch])

        # Unused channels keep patch 0 (Acoustic Grand Piano).
        patches = [0 if x == -1 else x for x in patches]

    print('=' * 70)
    #===============================================================================
    print('Rendering results...')
    print('=' * 70)
    print('Sample INTs', song_f[:3])
    print('=' * 70)

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                                              output_signature='Los Angeles MIDI Dataset Search',
                                                              output_file_name=song_artist,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches,
                                                              timings_multiplier=16
                                                              )

    new_fn = song_artist + '.mid'

    # BUGFIX: original passed undefined name `soundfont`; the startup code
    # defines `soundfont_path`.
    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont_path,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)
    #========================================================
    output_midi_title = str(song_artist)
    output_midi_summary = str(song_f[:3])
    output_midi = str(new_fn)
    output_audio = (16000, audio)
    output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi_title, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)
    #========================================================
    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
#==========================================================================================================
if __name__ == "__main__":

    # Timezone used for all request/start-time log lines.
    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    # Soundfont used by midi_to_colab_audio inside find_midi.
    soundfont_path = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    print('Loading files list...')
    all_MIDI_files_names = TMIDIX.Tegridy_Any_Pickle_File_Reader('all_MIDI_files_names')
    print('Done!')
    print('=' * 70)

    print('Loading clean_midi corpus...')
    clean_midi_artist_song_description_summaries_lyrics_score = TMIDIX.Tegridy_Any_Pickle_File_Reader('clean_midi_artist_song_description_summaries_lyrics_scores')
    print('Done!')
    print('=' * 70)

    print('Loading MIDI corpus embeddings...')
    # Embeddings were produced with the all-mpnet-base-v2 model below; the
    # model and this matrix must stay in sync.
    corpus_embeddings = np.load('MIDI_corpus_embeddings_all-mpnet-base-v2.npz')['data']
    print('Done!')
    print('=' * 70)

    print('Loading Sentence Transformer model...')
    model = SentenceTransformer('all-mpnet-base-v2')
    print('Done!')
    print('=' * 70)

    # ------------------------------------------------------------------
    # Gradio UI
    # ------------------------------------------------------------------
    app = gr.Blocks()

    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Advanced MIDI Search</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Search and explore 179k+ MIDI titles</h1>")
        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Search&style=flat)\n\n"
                    "Giant Music Transformer Aux Data Demo\n\n"
                    "Please see [Giant Music Transformer](https://github.com/asigalov61/Giant-Music-Transformer) for more information and features\n\n"
                    "[Open In Colab]"
                    "(https://colab.research.google.com/github/asigalov61/Giant-Music-Transformer/blob/main/Giant_Music_Transformer_TTM.ipynb)"
                    " for all features"
                    )

        title = gr.Textbox(label="Desired Song Title", value="Family Guy")
        artist = gr.Textbox(label="Desired Song Artist", value="TV Themes")
        submit = gr.Button()

        gr.Markdown("# Search results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = submit.click(find_midi, [title, artist],
                                 [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])

    app.launch()