{ "cells": [ { "cell_type": "code", "execution_count": 3, "id": "976fbfea", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0,'..')" ] }, { "cell_type": "code", "execution_count": 4, "id": "4b164f6b", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "\n", "from tqdm import tqdm\n", "\n", "\n", "from train import train\n", "import priors\n", "import encoders\n", "import positional_encodings\n", "import utils\n", "import bar_distribution\n", "\n", "\n", "from samlib.utils import chunker" ] }, { "cell_type": "code", "execution_count": null, "id": "29d423b4", "metadata": {}, "outputs": [], "source": [ "mykwargs = \\\n", "{\n", " 'bptt': 5*5+1,\n", "'nlayers': 6,\n", " 'dropout': 0.0, 'steps_per_epoch': 100,\n", " 'batch_size': 100}\n", "mnist_jobs_5shot_pi_prior_search = [\n", " pretrain_and_eval( {'num_features': 28 * 28, 'fuse_x_y': False, 'num_outputs': 5,\n", " 'translations': False, 'jonas_style': True}, priors.stroke.DataLoader, Losses.ce, enc, emsize=emsize, nhead=nhead, warmup_epochs=warmup_epochs, nhid=nhid, y_encoder_generator=encoders.get_Canonical(5), lr=lr, epochs=epochs, single_eval_pos_gen=mykwargs['bptt']-1,\n", " extra_prior_kwargs_dict={'num_features': 28*28, 'fuse_x_y': False, 'num_outputs':5, 'only_train_for_last_idx': True,\n", " 'min_max_strokes': (1,max_strokes), 'min_max_len': (min_len, max_len), 'min_max_width': (min_width, max_width), 'max_offset': max_offset, 'max_target_offset': max_target_offset},\n", " **mykwargs)\n", " for max_strokes, min_len, max_len, min_width, max_width, max_offset, max_target_offset in random_hypers\n", " for enc in [encoders.Linear] for emsize in [1024] for nhead in [4] for nhid in [emsize*2] for warmup_epochs in [5] for lr in [.00001] for epochs in [128,1024] for _ in range(1)]\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "deb93d1e", "metadata": {}, "outputs": [], "source": [ "\n", "@torch.inference_mode()\n", "def get_acc(finetuned_model, eval_pos, device='cpu', steps=100, train_mode=False, **mykwargs):\n", " finetuned_model.to(device)\n", " finetuned_model.eval()\n", "\n", " t_dl = priors.omniglot.DataLoader(steps, batch_size=1000, seq_len=mykwargs['bptt'], train=train_mode,\n", " **mykwargs['extra_prior_kwargs_dict'])\n", "\n", " ps = []\n", " ys = []\n", " for x, y in tqdm(t_dl):\n", " p = finetuned_model(tuple(e.to(device) for e in x), single_eval_pos=eval_pos)\n", " ps.append(p)\n", " ys.append(y)\n", "\n", " ps = torch.cat(ps, 1)\n", " ys = torch.cat(ys, 1)\n", "\n", " def acc(ps, ys):\n", " return (ps.argmax(-1) == ys.to(ps.device)).float().mean()\n", "\n", " a = acc(ps[eval_pos], ys[eval_pos]).cpu()\n", " print(a.item())\n", " return a\n", "\n", "\n", "def train_and_eval(*args, **kwargs):\n", " r = train(*args, **kwargs)\n", " model = r[-1]\n", " acc = get_acc(model, -1, device='cuda:0', **kwargs).cpu()\n", " model.to('cpu')\n", " return [acc]\n", "\n", "def pretrain_and_eval(extra_prior_kwargs_dict_eval,*args, **kwargs):\n", " r = train(*args, **kwargs)\n", " model = r[-1]\n", " kwargs['extra_prior_kwargs_dict'] = extra_prior_kwargs_dict_eval\n", " acc = get_acc(model, -1, device='cuda:0', **kwargs).cpu()\n", " model.to('cpu')\n", " return r, acc" ] }, { "cell_type": "code", "execution_count": null, "id": "706ecbb7", "metadata": {}, "outputs": [], "source": [ "\n", "emsize = 1024\n", "# mnist_jobs_5shot_pi[20].result()[-1].state_dict()\n", "mykwargs = \\\n", " {'bptt': 5 * 5 + 1,\n", " 'nlayers': 6,\n", " 'nhead': 4, 'emsize': 
emsize,\n", " 'encoder_generator': encoders.Linear, 'nhid': emsize * 2}\n", "results = train_and_eval(priors.omniglot.DataLoader, Losses.ce, y_encoder_generator=encoders.get_Canonical(5),\n", " load_weights_from_this_state_dict=mnist_jobs_5shot_pi_prior_search[67][0][-1].state_dict(), epochs=32, lr=.00001, dropout=dropout,\n", " single_eval_pos_gen=mykwargs['bptt'] - 1,\n", " extra_prior_kwargs_dict={'num_features': 28 * 28, 'fuse_x_y': False, 'num_outputs': 5,\n", " 'translations': True, 'jonas_style': True},\n", " batch_size=100, steps_per_epoch=200, **mykwargs)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "611554b2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 5 }