{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VGrGd6__l5ch"
   },
   "source": [
    "# Melody2Song Seq2Seq Music Transformer (ver. 1.0)\n",
    "\n",
    "***\n",
    "\n",
    "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
    "\n",
    "***\n",
    "\n",
    "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/\n",
    "\n",
    "***\n",
    "\n",
    "#### Project Los Angeles\n",
    "\n",
    "#### Tegridy Code 2024\n",
    "\n",
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "shLrgoXdl5cj"
   },
   "source": [
    "# (GPU CHECK)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "X3rABEpKCO02",
    "cellView": "form"
   },
   "outputs": [],
   "source": [
    "# @title NVIDIA GPU Check\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0RcVC4btl5ck"
   },
   "source": [
    "# (SETUP ENVIRONMENT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "viHgEaNACPTs",
    "cellView": "form"
   },
   "outputs": [],
   "source": [
    "# @title Install requirements\n",
    "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools\n",
    "!pip install einops\n",
    "!pip install torch-summary\n",
    "!apt install fluidsynth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DzCOZU_gBiQV",
    "cellView": "form"
   },
   "outputs": [],
   "source": [
    "# @title Load all needed modules\n",
    "\n",
    "print('=' * 70)\n",
    "print('Loading needed modules...')\n",
    "print('=' * 70)\n",
    "\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import secrets\n",
    "import tqdm\n",
    "import math\n",
    "import torch\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torchsummary import summary\n",
    "\n",
    "# tegridy-tools modules live in the cloned repo, so cd into it before importing them\n",
    "%cd /content/tegridy-tools/tegridy-tools/\n",
    "\n",
    "import TMIDIX\n",
    "from midi_to_colab_audio import midi_to_colab_audio\n",
    "\n",
    "%cd /content/tegridy-tools/tegridy-tools/X-Transformer\n",
    "\n",
    "# Supplies TransformerWrapper, Decoder, AutoregressiveWrapper and top_k used below\n",
    "from x_transformer_1_23_2 import *\n",
    "\n",
    "%cd /content/\n",
    "\n",
    "import random\n",
    "\n",
    "from sklearn import metrics\n",
    "\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "from google.colab import files\n",
    "\n",
    "print('=' * 70)\n",
    "print('Done')\n",
    "print('=' * 70)\n",
    "print('Torch version:', torch.__version__)\n",
    "print('=' * 70)\n",
    "print('Enjoy! :)')\n",
    "print('=' * 70)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SQ1_7P4bLdtB"
   },
   "source": [
    "# (SETUP DATA AND MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "z7QLJ6FajxPA"
   },
   "outputs": [],
   "source": [
    "#@title Load Melody2Song Seq2Seq Music Transformer Data and Pre-Trained Model\n",
    "\n",
    "#@markdown Model precision option\n",
    "\n",
    "model_precision = \"bfloat16\" # @param [\"bfloat16\", \"float16\"]\n",
    "\n",
    "plot_tokens_embeddings = True # @param {type:\"boolean\"}\n",
    "\n",
    "print('=' * 70)\n",
    "print('Downloading Melody2Song Seq2Seq Music Transformer Data File...')\n",
    "print('=' * 70)\n",
    "\n",
    "data_path = '/content'\n",
    "data_file_name = 'Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data'\n",
    "\n",
    "if os.path.isfile(data_path+'/'+data_file_name+'.pickle'):\n",
    "  print('Data file already exists...')\n",
    "\n",
    "else:\n",
    "  hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',\n",
    "                  repo_type='space',\n",
    "                  filename=data_file_name+'.pickle',\n",
    "                  local_dir=data_path,\n",
    "                  )\n",
    "\n",
    "print('=' * 70)\n",
    "# Read from data_path (where the file was checked for / downloaded to) instead of\n",
    "# relying on the current working directory. The reader takes the path WITHOUT the\n",
    "# '.pickle' extension.\n",
    "# NOTE: pickle files can execute arbitrary code when loaded - only use trusted sources.\n",
    "seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader(data_path+'/'+data_file_name)\n",
    "\n",
    "print('=' * 70)\n",
    "print('Loading Melody2Song Seq2Seq Music Transformer Pre-Trained Model...')\n",
    "print('Please wait...')\n",
    "print('=' * 70)\n",
    "\n",
    "full_path_to_models_dir = \"/content\"\n",
    "\n",
    "model_checkpoint_file_name = 'Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth'\n",
    "model_path = full_path_to_models_dir+'/'+model_checkpoint_file_name\n",
    "num_layers = 24\n",
    "\n",
    "if os.path.isfile(model_path):\n",
    "  print('Model already exists...')\n",
    "\n",
    "else:\n",
    "  hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',\n",
    "                  repo_type='space',\n",
    "                  filename=model_checkpoint_file_name,\n",
    "                  local_dir=full_path_to_models_dir,\n",
    "                  )\n",
    "\n",
    "print('=' * 70)\n",
    "print('Instantiating model...')\n",
    "\n",
    "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
    "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
    "device_type = 'cuda'\n",
    "\n",
    "# Use bfloat16 only when it is requested AND the GPU supports it; otherwise use float16\n",
    "if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported():\n",
    "  dtype = 'bfloat16'\n",
    "else:\n",
    "  dtype = 'float16'\n",
    "\n",
    "ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n",
    "ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)\n",
    "\n",
    "SEQ_LEN = 2560\n",
    "PAD_IDX = 514\n",
    "\n",
    "# instantiate the model (vocabulary = token ids 0..PAD_IDX)\n",
    "\n",
    "model = TransformerWrapper(\n",
    "    num_tokens = PAD_IDX+1,\n",
    "    max_seq_len = SEQ_LEN,\n",
    "    attn_layers = Decoder(dim = 1024, depth = num_layers, heads = 16, attn_flash = True)\n",
    ")\n",
    "\n",
    "model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)\n",
    "\n",
    "model.cuda()\n",
    "print('=' * 70)\n",
    "\n",
    "print('Loading model checkpoint...')\n",
    "\n",
    "model.load_state_dict(torch.load(model_path))\n",
    "print('=' * 70)\n",
    "\n",
    "model.eval()\n",
    "\n",
    "print('Done!')\n",
    "print('=' * 70)\n",
    "\n",
    "print('Model will use', dtype, 'precision...')\n",
    "print('=' * 70)\n",
    "\n",
    "# Model stats\n",
    "print('Model summary...')\n",
    "summary(model)\n",
    "\n",
    "if plot_tokens_embeddings:\n",
    "\n",
    "  # One embedding row per token id\n",
    "  tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()\n",
    "\n",
    "  # Pairwise cosine DISTANCE (1 - cosine similarity) between all token embeddings\n",
    "  cos_dist = metrics.pairwise_distances(\n",
    "    tok_emb, metric='cosine'\n",
    "  )\n",
    "  plt.figure(figsize=(7, 7))\n",
    "  plt.imshow(cos_dist, cmap=\"inferno\", interpolation=\"nearest\")\n",
    "  im_ratio = cos_dist.shape[0] / cos_dist.shape[1]\n",
    "  plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)\n",
    "  plt.xlabel(\"Token\")\n",
    "  plt.ylabel(\"Token\")\n",
    "  plt.title(\"Tokens embeddings cosine distance\")\n",
    "  plt.tight_layout()\n",
    "  # Save before show(): the inline backend closes the figure after displaying it\n",
    "  plt.savefig(\"/content/Melody2Song-Seq2Seq-Music-Transformer-Tokens-Embeddings-Plot.png\", bbox_inches=\"tight\")\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NdJ1_A8gNoV3"
   },
   "source": [
    "# (LOAD SEED MELODY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AIvb6MmSO9R3",
    "cellView": "form"
   },
   "outputs": [],
   "source": [
    "# @title Load desired seed melody\n",
    "\n",
    "#@markdown NOTE: If custom MIDI file is not provided, sample seed melody will be used instead\n",
    "\n",
    "full_path_to_custom_seed_melody_MIDI_file = \"/content/tegridy-tools/tegridy-tools/seed-melody.mid\" # @param {type:\"string\"}\n",
    "sample_seed_melody_number = 0 # @param {type:\"slider\", min:0, max:203664, step:1}\n",
    "\n",
    "print('=' * 70)\n",
    "print('Loading seed melody...')\n",
    "print('=' * 70)\n",
    "\n",
    "melody = []\n",
    "\n",
    "if full_path_to_custom_seed_melody_MIDI_file != '':\n",
    "\n",
    "  #===============================================================================\n",
    "  # Raw single-track ms score\n",
    "\n",
    "  raw_score = TMIDIX.midi2single_track_ms_score(full_path_to_custom_seed_melody_MIDI_file)\n",
    "\n",
    "  #===============================================================================\n",
    "  # Enhanced score notes\n",
    "\n",
    "  escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]\n",
    "\n",
    "  #===============================================================================\n",
    "  # Augmented enhanced score notes\n",
    "\n",
    "  escore_notes = TMIDIX.recalculate_score_timings(TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32))\n",
    "\n",
    "  cscore = TMIDIX.chordify_score([1000, escore_notes])\n",
    "\n",
    "  # Take the first note of each chord as the melody line\n",
    "  fixed_mel_score = TMIDIX.fix_monophonic_score_durations([c[0] for c in cscore])\n",
    "\n",
    "  if fixed_mel_score:\n",
    "\n",
    "    pe = fixed_mel_score[0]\n",
    "\n",
    "    # Encode every note as 3 tokens: delta-time (0..127), duration (+128), pitch on channel 1 (+256)\n",
    "    for s in fixed_mel_score:\n",
    "\n",
    "      dtime = max(0, min(127, s[1]-pe[1]))\n",
    "      dur = max(1, min(127, s[2]))\n",
    "      ptc = max(1, min(127, s[4]))\n",
    "\n",
    "      chan = 1\n",
    "\n",
    "      melody.extend([dtime, dur+128, (128 * chan)+ptc+256])\n",
    "\n",
    "      pe = s\n",
    "\n",
    "    # Fit the melody to exactly 192 tokens (64 notes): loop short melodies,\n",
    "    # truncate long ones, then wrap with the 512 / 513 marker tokens\n",
    "    if len(melody) < 192:\n",
    "      melody = melody * math.ceil(192 / len(melody))\n",
    "\n",
    "    melody = [512] + melody[:192] + [513]\n",
    "\n",
    "    print('Loaded custom MIDI melody:', full_path_to_custom_seed_melody_MIDI_file)\n",
    "    print('=' * 70)\n",
    "\n",
    "  else:\n",
    "    # An empty melody would crash above (index [0] / division by zero)\n",
    "    print('Custom MIDI file has no notes! Using sample seed melody instead...')\n",
    "    print('=' * 70)\n",
    "\n",
    "if not melody:\n",
    "  melody = seed_melodies_data[sample_seed_melody_number]\n",
    "  print('Loaded sample seed melody #', sample_seed_melody_number)\n",
    "  print('=' * 70)\n",
    "\n",
    "print('Sample melody INTs:', melody[:10])\n",
    "print('=' * 70)\n",
    "print('Done!')\n",
    "print('=' * 70)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "feXay_Ed7mG5"
   },
   "source": [
    "# (GENERATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "naf65RxUXwDg",
    "cellView": "form"
   },
   "outputs": [],
   "source": [
    "# @title Generate song from melody\n",
    "\n",
    "melody_MIDI_patch_number = 40 # @param {type:\"slider\", min:0, max:127, step:1}\n",
    "accompaniment_MIDI_patch_number = 0 # @param {type:\"slider\", min:0, max:127, step:1}\n",
    "number_of_tokens_to_generate = 900 # @param {type:\"slider\", min:15, max:2354, step:3}\n",
    "number_of_batches_to_generate = 4 # @param {type:\"slider\", min:1, max:16, step:1}\n",
    "top_k_value = 25 # @param {type:\"slider\", min:1, max:50, step:1}\n",
    "temperature = 0.9 # @param {type:\"slider\", min:0.1, max:1, step:0.05}\n",
    "render_MIDI_to_audio = True # @param {type:\"boolean\"}\n",
    "\n",
    "# NOTE: a custom seed melody is 194 tokens (192 + 2 markers); 194 + 2354 (slider max) = 2548 <= SEQ_LEN (2560)\n",
    "\n",
    "print('=' * 70)\n",
    "print('Melody2Song Seq2Seq Music Transformer Model Generator')\n",
    "print('=' * 70)\n",
    "\n",
    "print('Generating...')\n",
    "print('=' * 70)\n",
    "\n",
    "model.eval()\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# The same seed melody is used as the prime for every batch element\n",
    "x = torch.tensor([melody] * number_of_batches_to_generate, dtype=torch.long, device='cuda')\n",
    "\n",
    "with ctx:\n",
    "  out = model.generate(x,\n",
    "                       number_of_tokens_to_generate,\n",
    "                       filter_logits_fn=top_k,\n",
    "                       filter_kwargs={'k': top_k_value},\n",
    "                       temperature=temperature, # use the form value (was hard-coded to 0.9)\n",
    "                       return_prime=False,\n",
    "                       verbose=True)\n",
    "\n",
    "output = out.tolist()\n",
    "\n",
    "print('=' * 70)\n",
    "print('Done!')\n",
    "print('=' * 70)\n",
    "\n",
    "#======================================================================\n",
    "print('Rendering results...')\n",
    "\n",
    "for i in range(number_of_batches_to_generate):\n",
    "\n",
    "  print('=' * 70)\n",
    "  print('Batch #', i)\n",
    "  print('=' * 70)\n",
    "\n",
    "  out1 = output[i]\n",
    "\n",
    "  print('Sample INTs', out1[:12])\n",
    "  print('=' * 70)\n",
    "\n",
    "  if len(out1) != 0:\n",
    "\n",
    "    song_f = []\n",
    "\n",
    "    time = 0\n",
    "    dur = 0\n",
    "\n",
    "    # Accompaniment is rendered on channel 0, the melody on channel 3\n",
    "    patches = [0] * 16\n",
    "    patches[0] = accompaniment_MIDI_patch_number\n",
    "    patches[3] = melody_MIDI_patch_number\n",
    "\n",
    "    # Decode tokens; times/durations are scaled x32 to match the\n",
    "    # timings_divider=32 used when encoding the seed melody\n",
    "    for ss in out1:\n",
    "\n",
    "      if 0 < ss < 128:\n",
    "        time += (ss * 32)\n",
    "\n",
    "      elif 128 < ss < 256:\n",
    "        dur = (ss-128) * 32\n",
    "\n",
    "      elif 256 < ss < 512:\n",
    "\n",
    "        pitch = (ss-256) % 128\n",
    "        channel = (ss-256) // 128\n",
    "\n",
    "        if channel == 1:\n",
    "          # melody note\n",
    "          channel = 3\n",
    "          vel = 110 + (pitch % 12)\n",
    "          song_f.append(['note', time, dur, channel, pitch, vel, melody_MIDI_patch_number])\n",
    "\n",
    "        else:\n",
    "          # accompaniment note\n",
    "          vel = 80 + (pitch % 12)\n",
    "          channel = 0\n",
    "          song_f.append(['note', time, dur, channel, pitch, vel, accompaniment_MIDI_patch_number])\n",
    "\n",
    "    fname = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i)\n",
    "\n",
    "    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n",
    "                                                              output_signature = 'Melody2Song Seq2Seq Music Transformer',\n",
    "                                                              output_file_name = fname,\n",
    "                                                              track_name='Project Los Angeles',\n",
    "                                                              list_of_MIDI_patches=patches\n",
    "                                                              )\n",
    "    print('=' * 70)\n",
    "    print('Displaying resulting composition...')\n",
    "    print('=' * 70)\n",
    "\n",
    "    if render_MIDI_to_audio:\n",
    "      midi_audio = midi_to_colab_audio(fname + '.mid')\n",
    "      display(Audio(midi_audio, rate=16000, normalize=False))\n",
    "\n",
    "    TMIDIX.plot_ms_SONG(song_f, plot_title=fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z87TlDTVl5cp"
   },
   "source": [
    "# Congrats! You did it! :)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuClass": "premium",
   "gpuType": "L4",
   "private_outputs": true,
   "provenance": [],
   "machine_shape": "hm"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}