{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "411c59b3-f177-4a10-8925-d931ce572eaa", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL\n", "from PIL import Image\n", "\n", "from ip_adapter import IPAdapterPlus" ] }, { "cell_type": "code", "execution_count": 2, "id": "6b6dc69c-192d-4d74-8b1e-f0d9ccfbdb49", "metadata": {}, "outputs": [], "source": [ "base_model_path = \"SG161222/Realistic_Vision_V4.0_noVAE\"\n", "vae_model_path = \"stabilityai/sd-vae-ft-mse\"\n", "image_encoder_path = \"models/image_encoder\"\n", "ip_ckpt = \"models/ip-adapter-plus_sd15.bin\"\n", "device = \"cuda\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "63ec542f-8474-4f38-9457-073425578073", "metadata": {}, "outputs": [], "source": [ "def image_grid(imgs, rows, cols):\n", "    \"\"\"Tile PIL images into a single `rows` x `cols` RGB image.\n", "\n", "    Images are pasted row by row (left to right, then top to bottom).\n", "    The tile size is taken from `imgs[0]`; every image is pasted at a\n", "    multiple of that size, so all images are assumed to share it.\n", "\n", "    Args:\n", "        imgs: sequence of PIL images, exactly `rows * cols` of them.\n", "        rows: number of grid rows.\n", "        cols: number of grid columns.\n", "\n", "    Returns:\n", "        A new RGB image of size `(cols * w, rows * h)`.\n", "\n", "    Raises:\n", "        AssertionError: if `len(imgs) != rows * cols`.\n", "        ValueError: if `imgs` is empty.\n", "    \"\"\"\n", "    assert len(imgs) == rows * cols, f\"expected {rows * cols} images, got {len(imgs)}\"\n", "    if not imgs:\n", "        # an empty grid has no tile size to read from imgs[0]\n", "        raise ValueError(\"image_grid needs at least one image\")\n", "\n", "    w, h = imgs[0].size\n", "    grid = Image.new('RGB', size=(cols * w, rows * h))\n", "    for i, img in enumerate(imgs):\n", "        row, col = divmod(i, cols)\n", "        grid.paste(img, box=(col * w, row * h))\n", "    return grid\n", "\n", "# DDIM scheduler with a scaled-linear beta schedule\n", "noise_scheduler = DDIMScheduler(\n", " num_train_timesteps=1000,\n", " beta_start=0.00085,\n", " beta_end=0.012,\n", " beta_schedule=\"scaled_linear\",\n", " clip_sample=False,\n", " set_alpha_to_one=False,\n", " steps_offset=1,\n", ")\n", "# load the VAE from vae_model_path and cast its weights to float16\n", "vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)" ] }, { "cell_type": "code", "execution_count": 4, "id": "3849f9d0-5f68-4a49-9190-69dd50720cae", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6da401960cd8491bb042f2b0e41066bb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/5 [00:00" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# read image prompt\n", "image = 
Image.open(\"assets/images/statue.png\")\n", "# preview only: resize() returns a new image and the result is not assigned, so `image` is unchanged\n", "image.resize((256, 256))" ] }, { "cell_type": "code", "execution_count": 6, "id": "a23de3d2-169e-470b-8012-960e3d07b04b", "metadata": {}, "outputs": [], "source": [ "# load ip-adapter: build the IP-Adapter Plus model from the SD pipeline, the image encoder path and the adapter checkpoint\n", "ip_model = IPAdapterPlus(pipe, image_encoder_path, ip_ckpt, device, num_tokens=16)" ] }, { "cell_type": "code", "execution_count": 7, "id": "d83df45f-717d-4bb3-a5fd-0ea30930a431", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0bfeb757632142a9aeecca0629ff1731", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/50 [00:00" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# image prompt only (no text prompt): 4 samples, 50 inference steps, fixed seed\n", "images = ip_model.generate(pil_image=image, num_samples=4, num_inference_steps=50, seed=42)\n", "grid = image_grid(images, 1, 4)\n", "grid" ] }, { "cell_type": "code", "execution_count": 8, "id": "b77f52de-a9e4-44e1-aeec-8165414f1273", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5f63b14047f34511b5e7f850e4b05259", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/50 [00:00" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# multimodal prompt: the image prompt plus a text prompt (hat on the beach)\n", "# NOTE(review): scale=0.6 presumably weights the image prompt against the text prompt -- confirm in ip_adapter\n", "images = ip_model.generate(pil_image=image, num_samples=4, num_inference_steps=50, seed=42,\n", " prompt=\"best quality, high quality, wearing a hat on the beach\", scale=0.6)\n", "grid = image_grid(images, 1, 4)\n", "grid" ] }, { "cell_type": "code", "execution_count": 9, "id": "5d3d874a-49b2-4c7e-ad58-b0ecc085c1fd", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "211b7bb9032847c6aaaab8befcbd707c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/50 [00:00" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# multimodal prompt: same image prompt, seed and scale, different text prompt (sunglasses in a garden)\n", "images = 
ip_model.generate(pil_image=image, num_samples=4, num_inference_steps=50, seed=42,\n", " prompt=\"best quality, high quality, wearing sunglasses in a garden\", scale=0.6)\n", "grid = image_grid(images, 1, 4)\n", "grid" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }