|
<!--Copyright 2024 The HuggingFace Team. All rights reserved. |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
--> |
|
|
|
# ํ์ดํ๋ผ์ธ, ๋ชจ๋ธ ๋ฐ ์ค์ผ์ค๋ฌ ์ดํดํ๊ธฐ |
|
|
|
[[open-in-colab]] |
|
|
|
๐งจ Diffusers๋ ์ฌ์ฉ์ ์นํ์ ์ด๋ฉฐ ์ ์ฐํ ๋๊ตฌ ์์๋ก, ์ฌ์ฉ์ฌ๋ก์ ๋ง๊ฒ diffusion ์์คํ
์ ๊ตฌ์ถ ํ ์ ์๋๋ก ์ค๊ณ๋์์ต๋๋ค. ์ด ๋๊ตฌ ์์์ ํต์ฌ์ ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ์
๋๋ค. [`DiffusionPipeline`]์ ํธ์๋ฅผ ์ํด ์ด๋ฌํ ๊ตฌ์ฑ ์์๋ฅผ ๋ฒ๋ค๋ก ์ ๊ณตํ์ง๋ง, ํ์ดํ๋ผ์ธ์ ๋ถ๋ฆฌํ๊ณ ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๊ฐ๋ณ์ ์ผ๋ก ์ฌ์ฉํด ์๋ก์ด diffusion ์์คํ
์ ๋ง๋ค ์๋ ์์ต๋๋ค. |
|
|
|
์ด ํํ ๋ฆฌ์ผ์์๋ ๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ๋ถํฐ ์์ํด Stable Diffusion ํ์ดํ๋ผ์ธ๊น์ง ์งํํ๋ฉฐ ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํด ์ถ๋ก ์ ์ํ diffusion ์์คํ
์ ์กฐ๋ฆฝํ๋ ๋ฐฉ๋ฒ์ ๋ฐฐ์๋๋ค. |
|
|
|
## ๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ ํด์ฒดํ๊ธฐ |
|
|
|
ํ์ดํ๋ผ์ธ์ ์ถ๋ก ์ ์ํด ๋ชจ๋ธ์ ์คํํ๋ ๋น ๋ฅด๊ณ ์ฌ์ด ๋ฐฉ๋ฒ์ผ๋ก, ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๋ฐ ์ฝ๋๊ฐ 4์ค ์ด์ ํ์ํ์ง ์์ต๋๋ค: |
|
|
|
```py |
|
>>> from diffusers import DDPMPipeline |
|
|
|
>>> ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to("cuda") |
|
>>> image = ddpm(num_inference_steps=25).images[0] |
|
>>> image |
|
``` |
|
|
|
<div class="flex justify-center"> |
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ddpm-cat.png" alt="Image of cat created from DDPMPipeline"/> |
|
</div> |
|
|
|
์ ๋ง ์ฝ์ต๋๋ค. ๊ทธ๋ฐ๋ฐ ํ์ดํ๋ผ์ธ์ ์ด๋ป๊ฒ ์ด๋ ๊ฒ ํ ์ ์์์๊น์? ํ์ดํ๋ผ์ธ์ ์ธ๋ถํํ์ฌ ๋ด๋ถ์์ ์ด๋ค ์ผ์ด ์ผ์ด๋๊ณ ์๋์ง ์ดํด๋ณด๊ฒ ์ต๋๋ค. |
|
|
|
์ ์์์์ ํ์ดํ๋ผ์ธ์๋ [`UNet2DModel`] ๋ชจ๋ธ๊ณผ [`DDPMScheduler`]๊ฐ ํฌํจ๋์ด ์์ต๋๋ค. ํ์ดํ๋ผ์ธ์ ์ํ๋ ์ถ๋ ฅ ํฌ๊ธฐ์ ๋๋ค ๋
ธ์ด์ฆ๋ฅผ ๋ฐ์ ๋ชจ๋ธ์ ์ฌ๋ฌ๋ฒ ํต๊ณผ์์ผ ์ด๋ฏธ์ง์ ๋
ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํฉ๋๋ค. ๊ฐ timestep์์ ๋ชจ๋ธ์ *noise residual*์ ์์ธกํ๊ณ ์ค์ผ์ค๋ฌ๋ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ๋
ธ์ด์ฆ๊ฐ ์ ์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํฉ๋๋ค. ํ์ดํ๋ผ์ธ์ ์ง์ ๋ ์ถ๋ก ์คํ
์์ ๋๋ฌํ ๋๊น์ง ์ด ๊ณผ์ ์ ๋ฐ๋ณตํฉ๋๋ค. |
|
|
|
๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๋ณ๋๋ก ์ฌ์ฉํ์ฌ ํ์ดํ๋ผ์ธ์ ๋ค์ ์์ฑํ๊ธฐ ์ํด ์์ฒด์ ์ธ ๋
ธ์ด์ฆ ์ ๊ฑฐ ํ๋ก์ธ์ค๋ฅผ ์์ฑํด ๋ณด๊ฒ ์ต๋๋ค. |
|
|
|
1. ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๋ถ๋ฌ์ต๋๋ค: |
|
|
|
```py |
|
>>> from diffusers import DDPMScheduler, UNet2DModel |
|
|
|
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256") |
|
>>> model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda") |
|
``` |
|
|
|
2. ๋
ธ์ด์ฆ ์ ๊ฑฐ ํ๋ก์ธ์ค๋ฅผ ์คํํ timestep ์๋ฅผ ์ค์ ํฉ๋๋ค: |
|
|
|
```py |
|
>>> scheduler.set_timesteps(50) |
|
``` |
|
|
|
3. ์ค์ผ์ค๋ฌ์ timestep์ ์ค์ ํ๋ฉด ๊ท ๋ฑํ ๊ฐ๊ฒฉ์ ๊ตฌ์ฑ ์์๋ฅผ ๊ฐ์ง ํ
์๊ฐ ์์ฑ๋ฉ๋๋ค.(์ด ์์์์๋ 50๊ฐ) ๊ฐ ์์๋ ๋ชจ๋ธ์ด ์ด๋ฏธ์ง์ ๋
ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ๋ ์๊ฐ ๊ฐ๊ฒฉ์ ํด๋นํฉ๋๋ค. ๋์ค์ ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ฅผ ๋ง๋ค ๋ ์ด ํ
์๋ฅผ ๋ฐ๋ณตํ์ฌ ์ด๋ฏธ์ง์ ๋
ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํฉ๋๋ค: |
|
|
|
```py |
|
>>> scheduler.timesteps |
|
tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720, |
|
700, 680, 660, 640, 620, 600, 580, 560, 540, 520, 500, 480, 460, 440, |
|
420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160, |
|
140, 120, 100, 80, 60, 40, 20, 0]) |
|
``` |
|
|
|
4. ์ํ๋ ์ถ๋ ฅ๊ณผ ๊ฐ์ ๋ชจ์์ ๊ฐ์ง ๋๋ค ๋
ธ์ด์ฆ๋ฅผ ์์ฑํฉ๋๋ค: |
|
|
|
```py |
|
>>> import torch |
|
|
|
>>> sample_size = model.config.sample_size |
|
>>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda") |
|
``` |
|
|
|
5. ์ด์ timestep์ ๋ฐ๋ณตํ๋ ๋ฃจํ๋ฅผ ์์ฑํฉ๋๋ค. ๊ฐ timestep์์ ๋ชจ๋ธ์ [`UNet2DModel.forward`]๋ฅผ ํตํด noisy residual์ ๋ฐํํฉ๋๋ค. ์ค์ผ์ค๋ฌ์ [`~DDPMScheduler.step`] ๋ฉ์๋๋ noisy residual, timestep, ๊ทธ๋ฆฌ๊ณ ์
๋ ฅ์ ๋ฐ์ ์ด์ timestep์์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํฉ๋๋ค. ์ด ์ถ๋ ฅ์ ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ์ ๋ชจ๋ธ์ ๋ํ ๋ค์ ์
๋ ฅ์ด ๋๋ฉฐ, `timesteps` ๋ฐฐ์ด์ ๋์ ๋๋ฌํ ๋๊น์ง ๋ฐ๋ณต๋ฉ๋๋ค. |
|
|
|
```py |
|
>>> input = noise |
|
|
|
>>> for t in scheduler.timesteps: |
|
... with torch.no_grad(): |
|
... noisy_residual = model(input, t).sample |
|
... previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample |
|
... input = previous_noisy_sample |
|
``` |
|
|
|
์ด๊ฒ์ด ์ ์ฒด ๋
ธ์ด์ฆ ์ ๊ฑฐ ํ๋ก์ธ์ค์ด๋ฉฐ, ๋์ผํ ํจํด์ ์ฌ์ฉํด ๋ชจ๋ diffusion ์์คํ
์ ์์ฑํ ์ ์์ต๋๋ค. |
|
|
|
6. ๋ง์ง๋ง ๋จ๊ณ๋ ๋
ธ์ด์ฆ๊ฐ ์ ๊ฑฐ๋ ์ถ๋ ฅ์ ์ด๋ฏธ์ง๋ก ๋ณํํ๋ ๊ฒ์
๋๋ค: |
|
|
|
```py |
|
>>> from PIL import Image |
|
>>> import numpy as np |
|
|
|
>>> image = (input / 2 + 0.5).clamp(0, 1) |
|
>>> image = image.cpu().permute(0, 2, 3, 1).numpy()[0] |
|
>>> image = Image.fromarray((image * 255).round().astype("uint8")) |
|
>>> image |
|
``` |
|
|
|
๋ค์ ์น์
์์๋ ์ฌ๋ฌ๋ถ์ ๊ธฐ์ ์ ์ํํด๋ณด๊ณ ์ข ๋ ๋ณต์กํ Stable Diffusion ํ์ดํ๋ผ์ธ์ ๋ถ์ํด ๋ณด๊ฒ ์ต๋๋ค. ๋ฐฉ๋ฒ์ ๊ฑฐ์ ๋์ผํฉ๋๋ค. ํ์ํ ๊ตฌ์ฑ์์๋ค์ ์ด๊ธฐํํ๊ณ timestep์๋ฅผ ์ค์ ํ์ฌ `timestep` ๋ฐฐ์ด์ ์์ฑํฉ๋๋ค. ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ์์ `timestep` ๋ฐฐ์ด์ด ์ฌ์ฉ๋๋ฉฐ, ์ด ๋ฐฐ์ด์ ๊ฐ ์์์ ๋ํด ๋ชจ๋ธ์ ๋
ธ์ด์ฆ๊ฐ ์ ์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํฉ๋๋ค. ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ `timestep`์ ๋ฐ๋ณตํ๊ณ ๊ฐ timestep์์ noise residual์ ์ถ๋ ฅํ๊ณ ์ค์ผ์ค๋ฌ๋ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ์ด์ timestep์์ ๋
ธ์ด์ฆ๊ฐ ๋ํ ์ด๋ฏธ์ง๋ฅผ ์์ธกํฉ๋๋ค. ์ด ํ๋ก์ธ์ค๋ `timestep` ๋ฐฐ์ด์ ๋์ ๋๋ฌํ ๋๊น์ง ๋ฐ๋ณต๋ฉ๋๋ค. |
|
|
|
ํ๋ฒ ์ฌ์ฉํด ๋ด
์๋ค! |
|
|
|
## Stable Diffusion ํ์ดํ๋ผ์ธ ํด์ฒดํ๊ธฐ |
|
|
|
Stable Diffusion ์ text-to-image *latent diffusion* ๋ชจ๋ธ์
๋๋ค. latent diffusion ๋ชจ๋ธ์ด๋ผ๊ณ ๋ถ๋ฆฌ๋ ์ด์ ๋ ์ค์ ํฝ์
๊ณต๊ฐ ๋์ ์ด๋ฏธ์ง์ ์ ์ฐจ์์ ํํ์ผ๋ก ์์
ํ๊ธฐ ๋๋ฌธ์ด๊ณ , ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ด ๋ ๋์ต๋๋ค. ์ธ์ฝ๋๋ ์ด๋ฏธ์ง๋ฅผ ๋ ์์ ํํ์ผ๋ก ์์ถํ๊ณ , ๋์ฝ๋๋ ์์ถ๋ ํํ์ ๋ค์ ์ด๋ฏธ์ง๋ก ๋ณํํฉ๋๋ค. text-to-image ๋ชจ๋ธ์ ๊ฒฝ์ฐ ํ
์คํธ ์๋ฒ ๋ฉ์ ์์ฑํ๊ธฐ ์ํด tokenizer์ ์ธ์ฝ๋๊ฐ ํ์ํฉ๋๋ค. ์ด์ ์์ ์์ ์ด๋ฏธ UNet ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๊ฐ ํ์ํ๋ค๋ ๊ฒ์ ์๊ณ ๊ณ์
จ์ ๊ฒ์
๋๋ค. |
|
|
|
๋ณด์๋ค์ํผ, ์ด๊ฒ์ UNet ๋ชจ๋ธ๋ง ํฌํจ๋ DDPM ํ์ดํ๋ผ์ธ๋ณด๋ค ๋ ๋ณต์กํฉ๋๋ค. Stable Diffusion ๋ชจ๋ธ์๋ ์ธ ๊ฐ์ ๊ฐ๋ณ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ด ์์ต๋๋ค. |
|
|
|
<Tip> |
|
|
|
๐ก VAE, UNet ๋ฐ ํ
์คํธ ์ธ์ฝ๋ ๋ชจ๋ธ์ ์๋๋ฐฉ์์ ๋ํ ์์ธํ ๋ด์ฉ์ [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) ๋ธ๋ก๊ทธ๋ฅผ ์ฐธ์กฐํ์ธ์. |
|
|
|
</Tip> |
|
|
|
์ด์ Stable Diffusion ํ์ดํ๋ผ์ธ์ ํ์ํ ๊ตฌ์ฑ์์๋ค์ด ๋ฌด์์ธ์ง ์์์ผ๋, [`~ModelMixin.from_pretrained`] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด ๋ชจ๋ ๊ตฌ์ฑ์์๋ฅผ ๋ถ๋ฌ์ต๋๋ค. ์ฌ์ ํ์ต๋ ์ฒดํฌํฌ์ธํธ [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)์์ ์ฐพ์ ์ ์์ผ๋ฉฐ, ๊ฐ ๊ตฌ์ฑ์์๋ค์ ๋ณ๋์ ํ์ ํด๋์ ์ ์ฅ๋์ด ์์ต๋๋ค: |
|
|
|
```py |
|
>>> from PIL import Image |
|
>>> import torch |
|
>>> from transformers import CLIPTextModel, CLIPTokenizer |
|
>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler |
|
|
|
>>> vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae") |
|
>>> tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer") |
|
>>> text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder") |
|
>>> unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet") |
|
``` |
|
|
|
๊ธฐ๋ณธ [`PNDMScheduler`] ๋์ , [`UniPCMultistepScheduler`]๋ก ๊ต์ฒดํ์ฌ ๋ค๋ฅธ ์ค์ผ์ค๋ฌ๋ฅผ ์ผ๋ง๋ ์ฝ๊ฒ ์ฐ๊ฒฐํ ์ ์๋์ง ํ์ธํฉ๋๋ค: |
|
|
|
```py |
|
>>> from diffusers import UniPCMultistepScheduler |
|
|
|
>>> scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") |
|
``` |
|
|
|
์ถ๋ก ์๋๋ฅผ ๋์ด๋ ค๋ฉด ์ค์ผ์ค๋ฌ์ ๋ฌ๋ฆฌ ํ์ต ๊ฐ๋ฅํ ๊ฐ์ค์น๊ฐ ์์ผ๋ฏ๋ก ๋ชจ๋ธ์ GPU๋ก ์ฎ๊ธฐ์ธ์: |
|
|
|
```py |
|
>>> torch_device = "cuda" |
|
>>> vae.to(torch_device) |
|
>>> text_encoder.to(torch_device) |
|
>>> unet.to(torch_device) |
|
``` |
|
|
|
### ํ
์คํธ ์๋ฒ ๋ฉ ์์ฑํ๊ธฐ |
|
|
|
๋ค์ ๋จ๊ณ๋ ์๋ฒ ๋ฉ์ ์์ฑํ๊ธฐ ์ํด ํ
์คํธ๋ฅผ ํ ํฐํํ๋ ๊ฒ์
๋๋ค. ์ด ํ
์คํธ๋ UNet ๋ชจ๋ธ์์ condition์ผ๋ก ์ฌ์ฉ๋๊ณ ์
๋ ฅ ํ๋กฌํํธ์ ์ ์ฌํ ๋ฐฉํฅ์ผ๋ก diffusion ํ๋ก์ธ์ค๋ฅผ ์กฐ์ ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. |
|
|
|
<Tip> |
|
|
|
๐ก `guidance_scale` ๋งค๊ฐ๋ณ์๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ๋ ํ๋กฌํํธ์ ์ผ๋ง๋ ๋ง์ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ ์ง ๊ฒฐ์ ํฉ๋๋ค. |
|
|
|
</Tip> |
|
|
|
๋ค๋ฅธ ํ๋กฌํํธ๋ฅผ ์์ฑํ๊ณ ์ถ๋ค๋ฉด ์ํ๋ ํ๋กฌํํธ๋ฅผ ์์ ๋กญ๊ฒ ์ ํํ์ธ์! |
|
|
|
```py |
|
>>> prompt = ["a photograph of an astronaut riding a horse"] |
|
>>> height = 512 # Stable Diffusion์ ๊ธฐ๋ณธ ๋์ด |
|
>>> width = 512 # Stable Diffusion์ ๊ธฐ๋ณธ ๋๋น |
|
>>> num_inference_steps = 25 # ๋
ธ์ด์ฆ ์ ๊ฑฐ ์คํ
์ |
|
>>> guidance_scale = 7.5 # classifier-free guidance๋ฅผ ์ํ scale |
|
>>> generator = torch.manual_seed(0) # ์ด๊ธฐ ์ ์ฌ ๋
ธ์ด์ฆ๋ฅผ ์์ฑํ๋ seed generator |
|
>>> batch_size = len(prompt) |
|
``` |
|
|
|
ํ
์คํธ๋ฅผ ํ ํฐํํ๊ณ ํ๋กฌํํธ์์ ์๋ฒ ๋ฉ์ ์์ฑํฉ๋๋ค: |
|
|
|
```py |
|
>>> text_input = tokenizer( |
|
... prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt" |
|
... ) |
|
|
|
>>> with torch.no_grad(): |
|
... text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] |
|
``` |
|
|
|
๋ํ ํจ๋ฉ ํ ํฐ์ ์๋ฒ ๋ฉ์ธ *unconditional ํ
์คํธ ์๋ฒ ๋ฉ*์ ์์ฑํด์ผ ํฉ๋๋ค. ์ด ์๋ฒ ๋ฉ์ ์กฐ๊ฑด๋ถ `text_embeddings`๊ณผ ๋์ผํ shape(`batch_size` ๊ทธ๋ฆฌ๊ณ `seq_length`)์ ๊ฐ์ ธ์ผ ํฉ๋๋ค: |
|
|
|
```py |
|
>>> max_length = text_input.input_ids.shape[-1] |
|
>>> uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt") |
|
>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] |
|
``` |
|
|
|
๋๋ฒ์ forward pass๋ฅผ ํผํ๊ธฐ ์ํด conditional ์๋ฒ ๋ฉ๊ณผ unconditional ์๋ฒ ๋ฉ์ ๋ฐฐ์น(batch)๋ก ์ฐ๊ฒฐํ๊ฒ ์ต๋๋ค: |
|
|
|
```py |
|
>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) |
|
``` |
|
|
|
### ๋๋ค ๋
ธ์ด์ฆ ์์ฑ |
|
|
|
๊ทธ๋ค์ diffusion ํ๋ก์ธ์ค์ ์์์ ์ผ๋ก ์ด๊ธฐ ๋๋ค ๋
ธ์ด์ฆ๋ฅผ ์์ฑํฉ๋๋ค. ์ด๊ฒ์ด ์ด๋ฏธ์ง์ ์ ์ฌ์ ํํ์ด๋ฉฐ ์ ์ฐจ์ ์ผ๋ก ๋
ธ์ด์ฆ๊ฐ ์ ๊ฑฐ๋ฉ๋๋ค. ์ด ์์ ์์ `latent` ์ด๋ฏธ์ง๋ ์ต์ข
์ด๋ฏธ์ง ํฌ๊ธฐ๋ณด๋ค ์์ง๋ง ๋์ค์ ๋ชจ๋ธ์ด ์ด๋ฅผ 512x512 ์ด๋ฏธ์ง ํฌ๊ธฐ๋ก ๋ณํํ๋ฏ๋ก ๊ด์ฐฎ์ต๋๋ค. |
|
|
|
<Tip> |
|
|
|
๐ก `vae` ๋ชจ๋ธ์๋ 3๊ฐ์ ๋ค์ด ์ํ๋ง ๋ ์ด์ด๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋์ด์ ๋๋น๊ฐ 8๋ก ๋๋ฉ๋๋ค. ๋ค์์ ์คํํ์ฌ ํ์ธํ ์ ์์ต๋๋ค: |
|
|
|
```py |
|
2 ** (len(vae.config.block_out_channels) - 1) == 8 |
|
``` |
|
|
|
</Tip> |
|
|
|
```py |
|
>>> latents = torch.randn( |
|
... (batch_size, unet.config.in_channels, height // 8, width // 8), |
|
... generator=generator, |
|
... device=torch_device, |
|
... ) |
|
``` |
|
|
|
### ์ด๋ฏธ์ง ๋
ธ์ด์ฆ ์ ๊ฑฐ |
|
|
|
๋จผ์ [`UniPCMultistepScheduler`]์ ๊ฐ์ ํฅ์๋ ์ค์ผ์ค๋ฌ์ ํ์ํ ๋
ธ์ด์ฆ ์ค์ผ์ผ ๊ฐ์ธ ์ด๊ธฐ ๋
ธ์ด์ฆ ๋ถํฌ *sigma* ๋ก ์
๋ ฅ์ ์ค์ผ์ผ๋ง ํ๋ ๊ฒ๋ถํฐ ์์ํฉ๋๋ค: |
|
|
|
```py |
|
>>> latents = latents * scheduler.init_noise_sigma |
|
``` |
|
|
|
๋ง์ง๋ง ๋จ๊ณ๋ `latent`์ ์์ํ ๋
ธ์ด์ฆ๋ฅผ ์ ์ง์ ์ผ๋ก ํ๋กฌํํธ์ ์ค๋ช
๋ ์ด๋ฏธ์ง๋ก ๋ณํํ๋ ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ฅผ ์์ฑํ๋ ๊ฒ์
๋๋ค. ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ ์ธ ๊ฐ์ง ์์
์ ์ํํด์ผ ํ๋ค๋ ์ ์ ๊ธฐ์ตํ์ธ์: |
|
|
|
1. ๋
ธ์ด์ฆ ์ ๊ฑฐ ์ค์ ์ฌ์ฉํ ์ค์ผ์ค๋ฌ์ timesteps๋ฅผ ์ค์ ํฉ๋๋ค. |
|
2. timestep์ ๋ฐ๋ผ ๋ฐ๋ณตํฉ๋๋ค. |
|
3. ๊ฐ timestep์์ UNet ๋ชจ๋ธ์ ํธ์ถํ์ฌ noise residual์ ์์ธกํ๊ณ ์ค์ผ์ค๋ฌ์ ์ ๋ฌํ์ฌ ์ด์ ๋
ธ์ด์ฆ ์ํ์ ๊ณ์ฐํฉ๋๋ค. |
|
|
|
```py |
|
>>> from tqdm.auto import tqdm |
|
|
|
>>> scheduler.set_timesteps(num_inference_steps) |
|
|
|
>>> for t in tqdm(scheduler.timesteps): |
|
... # classifier-free guidance๋ฅผ ์ํํ๋ ๊ฒฝ์ฐ ๋๋ฒ์ forward pass๋ฅผ ์ํํ์ง ์๋๋ก latent๋ฅผ ํ์ฅ. |
|
... latent_model_input = torch.cat([latents] * 2) |
|
|
|
... latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t) |
|
|
|
... # noise residual ์์ธก |
|
... with torch.no_grad(): |
|
... noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
|
|
|
... # guidance ์ํ |
|
... noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
... noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
... # ์ด์ ๋
ธ์ด์ฆ ์ํ์ ๊ณ์ฐ x_t -> x_t-1 |
|
... latents = scheduler.step(noise_pred, t, latents).prev_sample |
|
``` |
|
|
|
### ์ด๋ฏธ์ง ๋์ฝ๋ฉ |
|
|
|
๋ง์ง๋ง ๋จ๊ณ๋ `vae`๋ฅผ ์ด์ฉํ์ฌ ์ ์ฌ ํํ์ ์ด๋ฏธ์ง๋ก ๋์ฝ๋ฉํ๊ณ `sample`๊ณผ ํจ๊ป ๋์ฝ๋ฉ๋ ์ถ๋ ฅ์ ์ป๋ ๊ฒ์
๋๋ค: |
|
|
|
```py |
|
# latent๋ฅผ ์ค์ผ์ผ๋งํ๊ณ vae๋ก ์ด๋ฏธ์ง ๋์ฝ๋ฉ |
|
latents = 1 / 0.18215 * latents |
|
with torch.no_grad(): |
|
image = vae.decode(latents).sample |
|
``` |
|
|
|
๋ง์ง๋ง์ผ๋ก ์ด๋ฏธ์ง๋ฅผ `PIL.Image`๋ก ๋ณํํ๋ฉด ์์ฑ๋ ์ด๋ฏธ์ง๋ฅผ ํ์ธํ ์ ์์ต๋๋ค! |
|
|
|
```py |
|
>>> image = (image / 2 + 0.5).clamp(0, 1) |
|
>>> image = image.detach().cpu().permute(0, 2, 3, 1).numpy() |
|
>>> images = (image * 255).round().astype("uint8") |
|
>>> pil_images = [Image.fromarray(image) for image in images] |
|
>>> pil_images[0] |
|
``` |
|
|
|
<div class="flex justify-center"> |
|
<img src="https://huggingface.co/blog/assets/98_stable_diffusion/stable_diffusion_k_lms.png"/> |
|
</div> |
|
|
|
## ๋ค์ ๋จ๊ณ |
|
|
|
๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ๋ถํฐ ๋ณต์กํ ํ์ดํ๋ผ์ธ๊น์ง, ์์ ๋ง์ diffusion ์์คํ
์ ์์ฑํ๋ ๋ฐ ํ์ํ ๊ฒ์ ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ฟ์ด๋ผ๋ ๊ฒ์ ์ ์ ์์์ต๋๋ค. ์ด ๋ฃจํ๋ ์ค์ผ์ค๋ฌ์ timesteps๋ฅผ ์ค์ ํ๊ณ , ์ด๋ฅผ ๋ฐ๋ณตํ๋ฉฐ, UNet ๋ชจ๋ธ์ ํธ์ถํ์ฌ noise residual์ ์์ธกํ๊ณ ์ค์ผ์ค๋ฌ์ ์ ๋ฌํ์ฌ ์ด์ ๋
ธ์ด์ฆ ์ํ์ ๊ณ์ฐํ๋ ๊ณผ์ ์ ๋ฒ๊ฐ์ ๊ฐ๋ฉฐ ์ํํด์ผ ํฉ๋๋ค. |
|
|
|
์ด๊ฒ์ด ๋ฐ๋ก ๐งจ Diffusers๊ฐ ์ค๊ณ๋ ๋ชฉ์ ์
๋๋ค: ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํด ์์ ๋ง์ diffusion ์์คํ
์ ์ง๊ด์ ์ด๊ณ ์ฝ๊ฒ ์์ฑํ ์ ์๋๋ก ํ๊ธฐ ์ํด์์
๋๋ค. |
|
|
|
๋ค์ ๋จ๊ณ๋ฅผ ์์ ๋กญ๊ฒ ์งํํ์ธ์: |
|
|
|
* ๐งจ Diffusers์ [ํ์ดํ๋ผ์ธ ๊ตฌ์ถ ๋ฐ ๊ธฐ์ฌ](using-diffusers/#contribute_pipeline)ํ๋ ๋ฐฉ๋ฒ์ ์์๋ณด์ธ์. ์ฌ๋ฌ๋ถ์ด ์ด๋ค ์์ด๋์ด๋ฅผ ๋ด๋์์ง ๊ธฐ๋๋ฉ๋๋ค! |
|
* ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ [๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ](./api/pipelines/overview)์ ์ดํด๋ณด๊ณ , ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๋ณ๋๋ก ์ฌ์ฉํ์ฌ ํ์ดํ๋ผ์ธ์ ์ฒ์๋ถํฐ ํด์ฒดํ๊ณ ๋น๋ํ ์ ์๋์ง ํ์ธํด ๋ณด์ธ์. |
|
|