Upamanyu098's picture
End of training
ef4d689 verified
|
raw
history blame
15.5 kB
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ํŒŒ์ดํ”„๋ผ์ธ, ๋ชจ๋ธ ๋ฐ ์Šค์ผ€์ค„๋Ÿฌ ์ดํ•ดํ•˜๊ธฐ
[[open-in-colab]]
๐Ÿงจ Diffusers๋Š” ์‚ฌ์šฉ์ž ์นœํ™”์ ์ด๋ฉฐ ์œ ์—ฐํ•œ ๋„๊ตฌ ์ƒ์ž๋กœ, ์‚ฌ์šฉ์‚ฌ๋ก€์— ๋งž๊ฒŒ diffusion ์‹œ์Šคํ…œ์„ ๊ตฌ์ถ• ํ•  ์ˆ˜ ์žˆ๋„๋ก ์„ค๊ณ„๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ด ๋„๊ตฌ ์ƒ์ž์˜ ํ•ต์‹ฌ์€ ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ์ž…๋‹ˆ๋‹ค. [`DiffusionPipeline`]์€ ํŽธ์˜๋ฅผ ์œ„ํ•ด ์ด๋Ÿฌํ•œ ๊ตฌ์„ฑ ์š”์†Œ๋ฅผ ๋ฒˆ๋“ค๋กœ ์ œ๊ณตํ•˜์ง€๋งŒ, ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ถ„๋ฆฌํ•˜๊ณ  ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๊ฐœ๋ณ„์ ์œผ๋กœ ์‚ฌ์šฉํ•ด ์ƒˆ๋กœ์šด diffusion ์‹œ์Šคํ…œ์„ ๋งŒ๋“ค ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.
์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ๋Š” ๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ด Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ๊นŒ์ง€ ์ง„ํ–‰ํ•˜๋ฉฐ ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์‚ฌ์šฉํ•ด ์ถ”๋ก ์„ ์œ„ํ•œ diffusion ์‹œ์Šคํ…œ์„ ์กฐ๋ฆฝํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ฐฐ์›๋‹ˆ๋‹ค.
## ๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ ํ•ด์ฒดํ•˜๊ธฐ
ํŒŒ์ดํ”„๋ผ์ธ์€ ์ถ”๋ก ์„ ์œ„ํ•ด ๋ชจ๋ธ์„ ์‹คํ–‰ํ•˜๋Š” ๋น ๋ฅด๊ณ  ์‰ฌ์šด ๋ฐฉ๋ฒ•์œผ๋กœ, ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋ฐ ์ฝ”๋“œ๊ฐ€ 4์ค„ ์ด์ƒ ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค:
```py
>>> from diffusers import DDPMPipeline
>>> ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to("cuda")
>>> image = ddpm(num_inference_steps=25).images[0]
>>> image
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ddpm-cat.png" alt="Image of cat created from DDPMPipeline"/>
</div>
์ •๋ง ์‰ฝ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ๋ฐ ํŒŒ์ดํ”„๋ผ์ธ์€ ์–ด๋–ป๊ฒŒ ์ด๋ ‡๊ฒŒ ํ•  ์ˆ˜ ์žˆ์—ˆ์„๊นŒ์š”? ํŒŒ์ดํ”„๋ผ์ธ์„ ์„ธ๋ถ„ํ™”ํ•˜์—ฌ ๋‚ด๋ถ€์—์„œ ์–ด๋–ค ์ผ์ด ์ผ์–ด๋‚˜๊ณ  ์žˆ๋Š”์ง€ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
์œ„ ์˜ˆ์‹œ์—์„œ ํŒŒ์ดํ”„๋ผ์ธ์—๋Š” [`UNet2DModel`] ๋ชจ๋ธ๊ณผ [`DDPMScheduler`]๊ฐ€ ํฌํ•จ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ํŒŒ์ดํ”„๋ผ์ธ์€ ์›ํ•˜๋Š” ์ถœ๋ ฅ ํฌ๊ธฐ์˜ ๋žœ๋ค ๋…ธ์ด์ฆˆ๋ฅผ ๋ฐ›์•„ ๋ชจ๋ธ์„ ์—ฌ๋Ÿฌ๋ฒˆ ํ†ต๊ณผ์‹œ์ผœ ์ด๋ฏธ์ง€์˜ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. ๊ฐ timestep์—์„œ ๋ชจ๋ธ์€ *noise residual*์„ ์˜ˆ์ธกํ•˜๊ณ  ์Šค์ผ€์ค„๋Ÿฌ๋Š” ์ด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋…ธ์ด์ฆˆ๊ฐ€ ์ ์€ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ํŒŒ์ดํ”„๋ผ์ธ์€ ์ง€์ •๋œ ์ถ”๋ก  ์Šคํ…์ˆ˜์— ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€ ์ด ๊ณผ์ •์„ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.
๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋ณ„๋„๋กœ ์‚ฌ์šฉํ•˜์—ฌ ํŒŒ์ดํ”„๋ผ์ธ์„ ๋‹ค์‹œ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•ด ์ž์ฒด์ ์ธ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ํ”„๋กœ์„ธ์Šค๋ฅผ ์ž‘์„ฑํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
1. ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค:
```py
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
>>> model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
```
2. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ํ”„๋กœ์„ธ์Šค๋ฅผ ์‹คํ–‰ํ•  timestep ์ˆ˜๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค:
```py
>>> scheduler.set_timesteps(50)
```
3. ์Šค์ผ€์ค„๋Ÿฌ์˜ timestep์„ ์„ค์ •ํ•˜๋ฉด ๊ท ๋“ฑํ•œ ๊ฐ„๊ฒฉ์˜ ๊ตฌ์„ฑ ์š”์†Œ๋ฅผ ๊ฐ€์ง„ ํ…์„œ๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.(์ด ์˜ˆ์‹œ์—์„œ๋Š” 50๊ฐœ) ๊ฐ ์š”์†Œ๋Š” ๋ชจ๋ธ์ด ์ด๋ฏธ์ง€์˜ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๋Š” ์‹œ๊ฐ„ ๊ฐ„๊ฒฉ์— ํ•ด๋‹นํ•ฉ๋‹ˆ๋‹ค. ๋‚˜์ค‘์— ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋ฅผ ๋งŒ๋“ค ๋•Œ ์ด ํ…์„œ๋ฅผ ๋ฐ˜๋ณตํ•˜์—ฌ ์ด๋ฏธ์ง€์˜ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค:
```py
>>> scheduler.timesteps
tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720,
700, 680, 660, 640, 620, 600, 580, 560, 540, 520, 500, 480, 460, 440,
420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160,
140, 120, 100, 80, 60, 40, 20, 0])
```
4. ์›ํ•˜๋Š” ์ถœ๋ ฅ๊ณผ ๊ฐ™์€ ๋ชจ์–‘์„ ๊ฐ€์ง„ ๋žœ๋ค ๋…ธ์ด์ฆˆ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค:
```py
>>> import torch
>>> sample_size = model.config.sample_size
>>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
```
5. ์ด์ œ timestep์„ ๋ฐ˜๋ณตํ•˜๋Š” ๋ฃจํ”„๋ฅผ ์ž‘์„ฑํ•ฉ๋‹ˆ๋‹ค. ๊ฐ timestep์—์„œ ๋ชจ๋ธ์€ [`UNet2DModel.forward`]๋ฅผ ํ†ตํ•ด noisy residual์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ์Šค์ผ€์ค„๋Ÿฌ์˜ [`~DDPMScheduler.step`] ๋ฉ”์„œ๋“œ๋Š” noisy residual, timestep, ๊ทธ๋ฆฌ๊ณ  ์ž…๋ ฅ์„ ๋ฐ›์•„ ์ด์ „ timestep์—์„œ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ์ด ์ถœ๋ ฅ์€ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„์˜ ๋ชจ๋ธ์— ๋Œ€ํ•œ ๋‹ค์Œ ์ž…๋ ฅ์ด ๋˜๋ฉฐ, `timesteps` ๋ฐฐ์—ด์˜ ๋์— ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€ ๋ฐ˜๋ณต๋ฉ๋‹ˆ๋‹ค.
```py
>>> input = noise
>>> for t in scheduler.timesteps:
... with torch.no_grad():
... noisy_residual = model(input, t).sample
... previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
... input = previous_noisy_sample
```
์ด๊ฒƒ์ด ์ „์ฒด ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ํ”„๋กœ์„ธ์Šค์ด๋ฉฐ, ๋™์ผํ•œ ํŒจํ„ด์„ ์‚ฌ์šฉํ•ด ๋ชจ๋“  diffusion ์‹œ์Šคํ…œ์„ ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
6. ๋งˆ์ง€๋ง‰ ๋‹จ๊ณ„๋Š” ๋…ธ์ด์ฆˆ๊ฐ€ ์ œ๊ฑฐ๋œ ์ถœ๋ ฅ์„ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค:
```py
>>> from PIL import Image
>>> import numpy as np
>>> image = (input / 2 + 0.5).clamp(0, 1)
>>> image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
>>> image = Image.fromarray((image * 255).round().astype("uint8"))
>>> image
```
๋‹ค์Œ ์„น์…˜์—์„œ๋Š” ์—ฌ๋Ÿฌ๋ถ„์˜ ๊ธฐ์ˆ ์„ ์‹œํ—˜ํ•ด๋ณด๊ณ  ์ข€ ๋” ๋ณต์žกํ•œ Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ถ„์„ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๋ฐฉ๋ฒ•์€ ๊ฑฐ์˜ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. ํ•„์š”ํ•œ ๊ตฌ์„ฑ์š”์†Œ๋“ค์„ ์ดˆ๊ธฐํ™”ํ•˜๊ณ  timestep์ˆ˜๋ฅผ ์„ค์ •ํ•˜์—ฌ `timestep` ๋ฐฐ์—ด์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„์—์„œ `timestep` ๋ฐฐ์—ด์ด ์‚ฌ์šฉ๋˜๋ฉฐ, ์ด ๋ฐฐ์—ด์˜ ๊ฐ ์š”์†Œ์— ๋Œ€ํ•ด ๋ชจ๋ธ์€ ๋…ธ์ด์ฆˆ๊ฐ€ ์ ์€ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋Š” `timestep`์„ ๋ฐ˜๋ณตํ•˜๊ณ  ๊ฐ timestep์—์„œ noise residual์„ ์ถœ๋ ฅํ•˜๊ณ  ์Šค์ผ€์ค„๋Ÿฌ๋Š” ์ด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ด์ „ timestep์—์„œ ๋…ธ์ด์ฆˆ๊ฐ€ ๋œํ•œ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ์ด ํ”„๋กœ์„ธ์Šค๋Š” `timestep` ๋ฐฐ์—ด์˜ ๋์— ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€ ๋ฐ˜๋ณต๋ฉ๋‹ˆ๋‹ค.
ํ•œ๋ฒˆ ์‚ฌ์šฉํ•ด ๋ด…์‹œ๋‹ค!
## Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ ํ•ด์ฒดํ•˜๊ธฐ
Stable Diffusion ์€ text-to-image *latent diffusion* ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. latent diffusion ๋ชจ๋ธ์ด๋ผ๊ณ  ๋ถˆ๋ฆฌ๋Š” ์ด์œ ๋Š” ์‹ค์ œ ํ”ฝ์…€ ๊ณต๊ฐ„ ๋Œ€์‹  ์ด๋ฏธ์ง€์˜ ์ €์ฐจ์›์˜ ํ‘œํ˜„์œผ๋กœ ์ž‘์—…ํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๊ณ , ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์ด ๋” ๋†’์Šต๋‹ˆ๋‹ค. ์ธ์ฝ”๋”๋Š” ์ด๋ฏธ์ง€๋ฅผ ๋” ์ž‘์€ ํ‘œํ˜„์œผ๋กœ ์••์ถ•ํ•˜๊ณ , ๋””์ฝ”๋”๋Š” ์••์ถ•๋œ ํ‘œํ˜„์„ ๋‹ค์‹œ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค. text-to-image ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•ด tokenizer์™€ ์ธ์ฝ”๋”๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด์ „ ์˜ˆ์ œ์—์„œ ์ด๋ฏธ UNet ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๊ฐ€ ํ•„์š”ํ•˜๋‹ค๋Š” ๊ฒƒ์€ ์•Œ๊ณ  ๊ณ„์…จ์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค.
๋ณด์‹œ๋‹ค์‹œํ”ผ, ์ด๊ฒƒ์€ UNet ๋ชจ๋ธ๋งŒ ํฌํ•จ๋œ DDPM ํŒŒ์ดํ”„๋ผ์ธ๋ณด๋‹ค ๋” ๋ณต์žกํ•ฉ๋‹ˆ๋‹ค. Stable Diffusion ๋ชจ๋ธ์—๋Š” ์„ธ ๊ฐœ์˜ ๊ฐœ๋ณ„ ์‚ฌ์ „ํ•™์Šต๋œ ๋ชจ๋ธ์ด ์žˆ์Šต๋‹ˆ๋‹ค.
<Tip>
๐Ÿ’ก VAE, UNet ๋ฐ ํ…์ŠคํŠธ ์ธ์ฝ”๋” ๋ชจ๋ธ์˜ ์ž‘๋™๋ฐฉ์‹์— ๋Œ€ํ•œ ์ž์„ธํ•œ ๋‚ด์šฉ์€ [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) ๋ธ”๋กœ๊ทธ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.
</Tip>
์ด์ œ Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ์— ํ•„์š”ํ•œ ๊ตฌ์„ฑ์š”์†Œ๋“ค์ด ๋ฌด์—‡์ธ์ง€ ์•Œ์•˜์œผ๋‹ˆ, [`~ModelMixin.from_pretrained`] ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•ด ๋ชจ๋“  ๊ตฌ์„ฑ์š”์†Œ๋ฅผ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค. ์‚ฌ์ „ํ•™์Šต๋œ ์ฒดํฌํฌ์ธํŠธ [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)์—์„œ ์ฐพ์„ ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๊ฐ ๊ตฌ์„ฑ์š”์†Œ๋“ค์€ ๋ณ„๋„์˜ ํ•˜์œ„ ํด๋”์— ์ €์žฅ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค:
```py
>>> from PIL import Image
>>> import torch
>>> from transformers import CLIPTextModel, CLIPTokenizer
>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
>>> vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
>>> tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
>>> text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
>>> unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
```
๊ธฐ๋ณธ [`PNDMScheduler`] ๋Œ€์‹ , [`UniPCMultistepScheduler`]๋กœ ๊ต์ฒดํ•˜์—ฌ ๋‹ค๋ฅธ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์–ผ๋งˆ๋‚˜ ์‰ฝ๊ฒŒ ์—ฐ๊ฒฐํ•  ์ˆ˜ ์žˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค:
```py
>>> from diffusers import UniPCMultistepScheduler
>>> scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
```
์ถ”๋ก  ์†๋„๋ฅผ ๋†’์ด๋ ค๋ฉด ์Šค์ผ€์ค„๋Ÿฌ์™€ ๋‹ฌ๋ฆฌ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ๊ฐ€์ค‘์น˜๊ฐ€ ์žˆ์œผ๋ฏ€๋กœ ๋ชจ๋ธ์„ GPU๋กœ ์˜ฎ๊ธฐ์„ธ์š”:
```py
>>> torch_device = "cuda"
>>> vae.to(torch_device)
>>> text_encoder.to(torch_device)
>>> unet.to(torch_device)
```
### ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑํ•˜๊ธฐ
๋‹ค์Œ ๋‹จ๊ณ„๋Š” ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•ด ํ…์ŠคํŠธ๋ฅผ ํ† ํฐํ™”ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด ํ…์ŠคํŠธ๋Š” UNet ๋ชจ๋ธ์—์„œ condition์œผ๋กœ ์‚ฌ์šฉ๋˜๊ณ  ์ž…๋ ฅ ํ”„๋กฌํ”„ํŠธ์™€ ์œ ์‚ฌํ•œ ๋ฐฉํ–ฅ์œผ๋กœ diffusion ํ”„๋กœ์„ธ์Šค๋ฅผ ์กฐ์ •ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
<Tip>
๐Ÿ’ก `guidance_scale` ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•  ๋•Œ ํ”„๋กฌํ”„ํŠธ์— ์–ผ๋งˆ๋‚˜ ๋งŽ์€ ๊ฐ€์ค‘์น˜๋ฅผ ๋ถ€์—ฌํ• ์ง€ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
</Tip>
๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ƒ์„ฑํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ์›ํ•˜๋Š” ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž์œ ๋กญ๊ฒŒ ์„ ํƒํ•˜์„ธ์š”!
```py
>>> prompt = ["a photograph of an astronaut riding a horse"]
>>> height = 512 # Stable Diffusion์˜ ๊ธฐ๋ณธ ๋†’์ด
>>> width = 512 # Stable Diffusion์˜ ๊ธฐ๋ณธ ๋„ˆ๋น„
>>> num_inference_steps = 25 # ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ์Šคํ… ์ˆ˜
>>> guidance_scale = 7.5 # classifier-free guidance๋ฅผ ์œ„ํ•œ scale
>>> generator = torch.manual_seed(0) # ์ดˆ๊ธฐ ์ž ์žฌ ๋…ธ์ด์ฆˆ๋ฅผ ์ƒ์„ฑํ•˜๋Š” seed generator
>>> batch_size = len(prompt)
```
ํ…์ŠคํŠธ๋ฅผ ํ† ํฐํ™”ํ•˜๊ณ  ํ”„๋กฌํ”„ํŠธ์—์„œ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค:
```py
>>> text_input = tokenizer(
... prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
... )
>>> with torch.no_grad():
... text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
```
๋˜ํ•œ ํŒจ๋”ฉ ํ† ํฐ์˜ ์ž„๋ฒ ๋”ฉ์ธ *unconditional ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ*์„ ์ƒ์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด ์ž„๋ฒ ๋”ฉ์€ ์กฐ๊ฑด๋ถ€ `text_embeddings`๊ณผ ๋™์ผํ•œ shape(`batch_size` ๊ทธ๋ฆฌ๊ณ  `seq_length`)์„ ๊ฐ€์ ธ์•ผ ํ•ฉ๋‹ˆ๋‹ค:
```py
>>> max_length = text_input.input_ids.shape[-1]
>>> uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
```
๋‘๋ฒˆ์˜ forward pass๋ฅผ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด conditional ์ž„๋ฒ ๋”ฉ๊ณผ unconditional ์ž„๋ฒ ๋”ฉ์„ ๋ฐฐ์น˜(batch)๋กœ ์—ฐ๊ฒฐํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:
```py
>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
```
### ๋žœ๋ค ๋…ธ์ด์ฆˆ ์ƒ์„ฑ
๊ทธ๋‹ค์Œ diffusion ํ”„๋กœ์„ธ์Šค์˜ ์‹œ์ž‘์ ์œผ๋กœ ์ดˆ๊ธฐ ๋žœ๋ค ๋…ธ์ด์ฆˆ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ์ด๊ฒƒ์ด ์ด๋ฏธ์ง€์˜ ์ž ์žฌ์  ํ‘œํ˜„์ด๋ฉฐ ์ ์ฐจ์ ์œผ๋กœ ๋…ธ์ด์ฆˆ๊ฐ€ ์ œ๊ฑฐ๋ฉ๋‹ˆ๋‹ค. ์ด ์‹œ์ ์—์„œ `latent` ์ด๋ฏธ์ง€๋Š” ์ตœ์ข… ์ด๋ฏธ์ง€ ํฌ๊ธฐ๋ณด๋‹ค ์ž‘์ง€๋งŒ ๋‚˜์ค‘์— ๋ชจ๋ธ์ด ์ด๋ฅผ 512x512 ์ด๋ฏธ์ง€ ํฌ๊ธฐ๋กœ ๋ณ€ํ™˜ํ•˜๋ฏ€๋กœ ๊ดœ์ฐฎ์Šต๋‹ˆ๋‹ค.
<Tip>
๐Ÿ’ก `vae` ๋ชจ๋ธ์—๋Š” 3๊ฐœ์˜ ๋‹ค์šด ์ƒ˜ํ”Œ๋ง ๋ ˆ์ด์–ด๊ฐ€ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ๋†’์ด์™€ ๋„ˆ๋น„๊ฐ€ 8๋กœ ๋‚˜๋‰ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ์„ ์‹คํ–‰ํ•˜์—ฌ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:
```py
2 ** (len(vae.config.block_out_channels) - 1) == 8
```
</Tip>
```py
>>> latents = torch.randn(
... (batch_size, unet.config.in_channels, height // 8, width // 8),
... generator=generator,
... device=torch_device,
... )
```
### ์ด๋ฏธ์ง€ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ
๋จผ์ € [`UniPCMultistepScheduler`]์™€ ๊ฐ™์€ ํ–ฅ์ƒ๋œ ์Šค์ผ€์ค„๋Ÿฌ์— ํ•„์š”ํ•œ ๋…ธ์ด์ฆˆ ์Šค์ผ€์ผ ๊ฐ’์ธ ์ดˆ๊ธฐ ๋…ธ์ด์ฆˆ ๋ถ„ํฌ *sigma* ๋กœ ์ž…๋ ฅ์„ ์Šค์ผ€์ผ๋ง ํ•˜๋Š” ๊ฒƒ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค:
```py
>>> latents = latents * scheduler.init_noise_sigma
```
๋งˆ์ง€๋ง‰ ๋‹จ๊ณ„๋Š” `latent`์˜ ์ˆœ์ˆ˜ํ•œ ๋…ธ์ด์ฆˆ๋ฅผ ์ ์ง„์ ์œผ๋กœ ํ”„๋กฌํ”„ํŠธ์— ์„ค๋ช…๋œ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋Š” ์„ธ ๊ฐ€์ง€ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•œ๋‹ค๋Š” ์ ์„ ๊ธฐ์–ตํ•˜์„ธ์š”:
1. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ์ค‘์— ์‚ฌ์šฉํ•  ์Šค์ผ€์ค„๋Ÿฌ์˜ timesteps๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
2. timestep์„ ๋”ฐ๋ผ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.
3. ๊ฐ timestep์—์„œ UNet ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•˜์—ฌ noise residual์„ ์˜ˆ์ธกํ•˜๊ณ  ์Šค์ผ€์ค„๋Ÿฌ์— ์ „๋‹ฌํ•˜์—ฌ ์ด์ „ ๋…ธ์ด์ฆˆ ์ƒ˜ํ”Œ์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
```py
>>> from tqdm.auto import tqdm
>>> scheduler.set_timesteps(num_inference_steps)
>>> for t in tqdm(scheduler.timesteps):
... # classifier-free guidance๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒฝ์šฐ ๋‘๋ฒˆ์˜ forward pass๋ฅผ ์ˆ˜ํ–‰ํ•˜์ง€ ์•Š๋„๋ก latent๋ฅผ ํ™•์žฅ.
... latent_model_input = torch.cat([latents] * 2)
... latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
... # noise residual ์˜ˆ์ธก
... with torch.no_grad():
... noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
... # guidance ์ˆ˜ํ–‰
... noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
... noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
... # ์ด์ „ ๋…ธ์ด์ฆˆ ์ƒ˜ํ”Œ์„ ๊ณ„์‚ฐ x_t -> x_t-1
... latents = scheduler.step(noise_pred, t, latents).prev_sample
```
### ์ด๋ฏธ์ง€ ๋””์ฝ”๋”ฉ
๋งˆ์ง€๋ง‰ ๋‹จ๊ณ„๋Š” `vae`๋ฅผ ์ด์šฉํ•˜์—ฌ ์ž ์žฌ ํ‘œํ˜„์„ ์ด๋ฏธ์ง€๋กœ ๋””์ฝ”๋”ฉํ•˜๊ณ  `sample`๊ณผ ํ•จ๊ป˜ ๋””์ฝ”๋”ฉ๋œ ์ถœ๋ ฅ์„ ์–ป๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค:
```py
# latent๋ฅผ ์Šค์ผ€์ผ๋งํ•˜๊ณ  vae๋กœ ์ด๋ฏธ์ง€ ๋””์ฝ”๋”ฉ
latents = 1 / 0.18215 * latents
with torch.no_grad():
image = vae.decode(latents).sample
```
๋งˆ์ง€๋ง‰์œผ๋กœ ์ด๋ฏธ์ง€๋ฅผ `PIL.Image`๋กœ ๋ณ€ํ™˜ํ•˜๋ฉด ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!
```py
>>> image = (image / 2 + 0.5).clamp(0, 1)
>>> image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
>>> images = (image * 255).round().astype("uint8")
>>> pil_images = [Image.fromarray(image) for image in images]
>>> pil_images[0]
```
<div class="flex justify-center">
<img src="https://huggingface.co/blog/assets/98_stable_diffusion/stable_diffusion_k_lms.png"/>
</div>
## ๋‹ค์Œ ๋‹จ๊ณ„
๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ๋ถ€ํ„ฐ ๋ณต์žกํ•œ ํŒŒ์ดํ”„๋ผ์ธ๊นŒ์ง€, ์ž์‹ ๋งŒ์˜ diffusion ์‹œ์Šคํ…œ์„ ์ž‘์„ฑํ•˜๋Š” ๋ฐ ํ•„์š”ํ•œ ๊ฒƒ์€ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋ฟ์ด๋ผ๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค. ์ด ๋ฃจํ”„๋Š” ์Šค์ผ€์ค„๋Ÿฌ์˜ timesteps๋ฅผ ์„ค์ •ํ•˜๊ณ , ์ด๋ฅผ ๋ฐ˜๋ณตํ•˜๋ฉฐ, UNet ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•˜์—ฌ noise residual์„ ์˜ˆ์ธกํ•˜๊ณ  ์Šค์ผ€์ค„๋Ÿฌ์— ์ „๋‹ฌํ•˜์—ฌ ์ด์ „ ๋…ธ์ด์ฆˆ ์ƒ˜ํ”Œ์„ ๊ณ„์‚ฐํ•˜๋Š” ๊ณผ์ •์„ ๋ฒˆ๊ฐˆ์•„ ๊ฐ€๋ฉฐ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
์ด๊ฒƒ์ด ๋ฐ”๋กœ ๐Ÿงจ Diffusers๊ฐ€ ์„ค๊ณ„๋œ ๋ชฉ์ ์ž…๋‹ˆ๋‹ค: ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์‚ฌ์šฉํ•ด ์ž์‹ ๋งŒ์˜ diffusion ์‹œ์Šคํ…œ์„ ์ง๊ด€์ ์ด๊ณ  ์‰ฝ๊ฒŒ ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๊ธฐ ์œ„ํ•ด์„œ์ž…๋‹ˆ๋‹ค.
๋‹ค์Œ ๋‹จ๊ณ„๋ฅผ ์ž์œ ๋กญ๊ฒŒ ์ง„ํ–‰ํ•˜์„ธ์š”:
* ๐Ÿงจ Diffusers์— [ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์ถ• ๋ฐ ๊ธฐ์—ฌ](using-diffusers/#contribute_pipeline)ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ์•„๋ณด์„ธ์š”. ์—ฌ๋Ÿฌ๋ถ„์ด ์–ด๋–ค ์•„์ด๋””์–ด๋ฅผ ๋‚ด๋†“์„์ง€ ๊ธฐ๋Œ€๋ฉ๋‹ˆ๋‹ค!
* ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์—์„œ [๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ](./api/pipelines/overview)์„ ์‚ดํŽด๋ณด๊ณ , ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋ณ„๋„๋กœ ์‚ฌ์šฉํ•˜์—ฌ ํŒŒ์ดํ”„๋ผ์ธ์„ ์ฒ˜์Œ๋ถ€ํ„ฐ ํ•ด์ฒดํ•˜๊ณ  ๋นŒ๋“œํ•  ์ˆ˜ ์žˆ๋Š”์ง€ ํ™•์ธํ•ด ๋ณด์„ธ์š”.