ํ์ดํ๋ผ์ธ, ๋ชจ๋ธ ๋ฐ ์ค์ผ์ค๋ฌ ์ดํดํ๊ธฐ
[[open-in-colab]]
๐งจ Diffusers๋ ์ฌ์ฉ์ ์นํ์ ์ด๋ฉฐ ์ ์ฐํ ๋๊ตฌ ์์๋ก, ์ฌ์ฉ์ฌ๋ก์ ๋ง๊ฒ diffusion ์์คํ
์ ๊ตฌ์ถ ํ ์ ์๋๋ก ์ค๊ณ๋์์ต๋๋ค. ์ด ๋๊ตฌ ์์์ ํต์ฌ์ ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ์
๋๋ค. [DiffusionPipeline
]์ ํธ์๋ฅผ ์ํด ์ด๋ฌํ ๊ตฌ์ฑ ์์๋ฅผ ๋ฒ๋ค๋ก ์ ๊ณตํ์ง๋ง, ํ์ดํ๋ผ์ธ์ ๋ถ๋ฆฌํ๊ณ ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๊ฐ๋ณ์ ์ผ๋ก ์ฌ์ฉํด ์๋ก์ด diffusion ์์คํ
์ ๋ง๋ค ์๋ ์์ต๋๋ค.
์ด ํํ ๋ฆฌ์ผ์์๋ ๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ๋ถํฐ ์์ํด Stable Diffusion ํ์ดํ๋ผ์ธ๊น์ง ์งํํ๋ฉฐ ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํด ์ถ๋ก ์ ์ํ diffusion ์์คํ ์ ์กฐ๋ฆฝํ๋ ๋ฐฉ๋ฒ์ ๋ฐฐ์๋๋ค.
๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ ํด์ฒดํ๊ธฐ
ํ์ดํ๋ผ์ธ์ ์ถ๋ก ์ ์ํด ๋ชจ๋ธ์ ์คํํ๋ ๋น ๋ฅด๊ณ ์ฌ์ด ๋ฐฉ๋ฒ์ผ๋ก, ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๋ฐ ์ฝ๋๊ฐ 4์ค ์ด์ ํ์ํ์ง ์์ต๋๋ค:
>>> from diffusers import DDPMPipeline
>>> ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to("cuda")
>>> image = ddpm(num_inference_steps=25).images[0]
>>> image
์ ๋ง ์ฝ์ต๋๋ค. ๊ทธ๋ฐ๋ฐ ํ์ดํ๋ผ์ธ์ ์ด๋ป๊ฒ ์ด๋ ๊ฒ ํ ์ ์์์๊น์? ํ์ดํ๋ผ์ธ์ ์ธ๋ถํํ์ฌ ๋ด๋ถ์์ ์ด๋ค ์ผ์ด ์ผ์ด๋๊ณ ์๋์ง ์ดํด๋ณด๊ฒ ์ต๋๋ค.
์ ์์์์ ํ์ดํ๋ผ์ธ์๋ [UNet2DModel
] ๋ชจ๋ธ๊ณผ [DDPMScheduler
]๊ฐ ํฌํจ๋์ด ์์ต๋๋ค. ํ์ดํ๋ผ์ธ์ ์ํ๋ ์ถ๋ ฅ ํฌ๊ธฐ์ ๋๋ค ๋
ธ์ด์ฆ๋ฅผ ๋ฐ์ ๋ชจ๋ธ์ ์ฌ๋ฌ๋ฒ ํต๊ณผ์์ผ ์ด๋ฏธ์ง์ ๋
ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํฉ๋๋ค. ๊ฐ timestep์์ ๋ชจ๋ธ์ noise residual์ ์์ธกํ๊ณ ์ค์ผ์ค๋ฌ๋ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ๋
ธ์ด์ฆ๊ฐ ์ ์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํฉ๋๋ค. ํ์ดํ๋ผ์ธ์ ์ง์ ๋ ์ถ๋ก ์คํ
์์ ๋๋ฌํ ๋๊น์ง ์ด ๊ณผ์ ์ ๋ฐ๋ณตํฉ๋๋ค.
๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๋ณ๋๋ก ์ฌ์ฉํ์ฌ ํ์ดํ๋ผ์ธ์ ๋ค์ ์์ฑํ๊ธฐ ์ํด ์์ฒด์ ์ธ ๋ ธ์ด์ฆ ์ ๊ฑฐ ํ๋ก์ธ์ค๋ฅผ ์์ฑํด ๋ณด๊ฒ ์ต๋๋ค.
๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๋ถ๋ฌ์ต๋๋ค:
>>> from diffusers import DDPMScheduler, UNet2DModel >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256") >>> model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
๋ ธ์ด์ฆ ์ ๊ฑฐ ํ๋ก์ธ์ค๋ฅผ ์คํํ timestep ์๋ฅผ ์ค์ ํฉ๋๋ค:
>>> scheduler.set_timesteps(50)
์ค์ผ์ค๋ฌ์ timestep์ ์ค์ ํ๋ฉด ๊ท ๋ฑํ ๊ฐ๊ฒฉ์ ๊ตฌ์ฑ ์์๋ฅผ ๊ฐ์ง ํ ์๊ฐ ์์ฑ๋ฉ๋๋ค.(์ด ์์์์๋ 50๊ฐ) ๊ฐ ์์๋ ๋ชจ๋ธ์ด ์ด๋ฏธ์ง์ ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ๋ ์๊ฐ ๊ฐ๊ฒฉ์ ํด๋นํฉ๋๋ค. ๋์ค์ ๋ ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ฅผ ๋ง๋ค ๋ ์ด ํ ์๋ฅผ ๋ฐ๋ณตํ์ฌ ์ด๋ฏธ์ง์ ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํฉ๋๋ค:
>>> scheduler.timesteps tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720, 700, 680, 660, 640, 620, 600, 580, 560, 540, 520, 500, 480, 460, 440, 420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160, 140, 120, 100, 80, 60, 40, 20, 0])
์ํ๋ ์ถ๋ ฅ๊ณผ ๊ฐ์ ๋ชจ์์ ๊ฐ์ง ๋๋ค ๋ ธ์ด์ฆ๋ฅผ ์์ฑํฉ๋๋ค:
>>> import torch >>> sample_size = model.config.sample_size >>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
์ด์ timestep์ ๋ฐ๋ณตํ๋ ๋ฃจํ๋ฅผ ์์ฑํฉ๋๋ค. ๊ฐ timestep์์ ๋ชจ๋ธ์ [
UNet2DModel.forward
]๋ฅผ ํตํด noisy residual์ ๋ฐํํฉ๋๋ค. ์ค์ผ์ค๋ฌ์ [~DDPMScheduler.step
] ๋ฉ์๋๋ noisy residual, timestep, ๊ทธ๋ฆฌ๊ณ ์ ๋ ฅ์ ๋ฐ์ ์ด์ timestep์์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํฉ๋๋ค. ์ด ์ถ๋ ฅ์ ๋ ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ์ ๋ชจ๋ธ์ ๋ํ ๋ค์ ์ ๋ ฅ์ด ๋๋ฉฐ,timesteps
๋ฐฐ์ด์ ๋์ ๋๋ฌํ ๋๊น์ง ๋ฐ๋ณต๋ฉ๋๋ค.>>> input = noise >>> for t in scheduler.timesteps: ... with torch.no_grad(): ... noisy_residual = model(input, t).sample ... previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample ... input = previous_noisy_sample
์ด๊ฒ์ด ์ ์ฒด ๋ ธ์ด์ฆ ์ ๊ฑฐ ํ๋ก์ธ์ค์ด๋ฉฐ, ๋์ผํ ํจํด์ ์ฌ์ฉํด ๋ชจ๋ diffusion ์์คํ ์ ์์ฑํ ์ ์์ต๋๋ค.
๋ง์ง๋ง ๋จ๊ณ๋ ๋ ธ์ด์ฆ๊ฐ ์ ๊ฑฐ๋ ์ถ๋ ฅ์ ์ด๋ฏธ์ง๋ก ๋ณํํ๋ ๊ฒ์ ๋๋ค:
>>> from PIL import Image >>> import numpy as np >>> image = (input / 2 + 0.5).clamp(0, 1) >>> image = image.cpu().permute(0, 2, 3, 1).numpy()[0] >>> image = Image.fromarray((image * 255).round().astype("uint8")) >>> image
๋ค์ ์น์
์์๋ ์ฌ๋ฌ๋ถ์ ๊ธฐ์ ์ ์ํํด๋ณด๊ณ ์ข ๋ ๋ณต์กํ Stable Diffusion ํ์ดํ๋ผ์ธ์ ๋ถ์ํด ๋ณด๊ฒ ์ต๋๋ค. ๋ฐฉ๋ฒ์ ๊ฑฐ์ ๋์ผํฉ๋๋ค. ํ์ํ ๊ตฌ์ฑ์์๋ค์ ์ด๊ธฐํํ๊ณ timestep์๋ฅผ ์ค์ ํ์ฌ timestep
๋ฐฐ์ด์ ์์ฑํฉ๋๋ค. ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ์์ timestep
๋ฐฐ์ด์ด ์ฌ์ฉ๋๋ฉฐ, ์ด ๋ฐฐ์ด์ ๊ฐ ์์์ ๋ํด ๋ชจ๋ธ์ ๋
ธ์ด์ฆ๊ฐ ์ ์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํฉ๋๋ค. ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ timestep
์ ๋ฐ๋ณตํ๊ณ ๊ฐ timestep์์ noise residual์ ์ถ๋ ฅํ๊ณ ์ค์ผ์ค๋ฌ๋ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ์ด์ timestep์์ ๋
ธ์ด์ฆ๊ฐ ๋ํ ์ด๋ฏธ์ง๋ฅผ ์์ธกํฉ๋๋ค. ์ด ํ๋ก์ธ์ค๋ timestep
๋ฐฐ์ด์ ๋์ ๋๋ฌํ ๋๊น์ง ๋ฐ๋ณต๋ฉ๋๋ค.
ํ๋ฒ ์ฌ์ฉํด ๋ด ์๋ค!
Stable Diffusion ํ์ดํ๋ผ์ธ ํด์ฒดํ๊ธฐ
Stable Diffusion ์ text-to-image latent diffusion ๋ชจ๋ธ์ ๋๋ค. latent diffusion ๋ชจ๋ธ์ด๋ผ๊ณ ๋ถ๋ฆฌ๋ ์ด์ ๋ ์ค์ ํฝ์ ๊ณต๊ฐ ๋์ ์ด๋ฏธ์ง์ ์ ์ฐจ์์ ํํ์ผ๋ก ์์ ํ๊ธฐ ๋๋ฌธ์ด๊ณ , ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ด ๋ ๋์ต๋๋ค. ์ธ์ฝ๋๋ ์ด๋ฏธ์ง๋ฅผ ๋ ์์ ํํ์ผ๋ก ์์ถํ๊ณ , ๋์ฝ๋๋ ์์ถ๋ ํํ์ ๋ค์ ์ด๋ฏธ์ง๋ก ๋ณํํฉ๋๋ค. text-to-image ๋ชจ๋ธ์ ๊ฒฝ์ฐ ํ ์คํธ ์๋ฒ ๋ฉ์ ์์ฑํ๊ธฐ ์ํด tokenizer์ ์ธ์ฝ๋๊ฐ ํ์ํฉ๋๋ค. ์ด์ ์์ ์์ ์ด๋ฏธ UNet ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๊ฐ ํ์ํ๋ค๋ ๊ฒ์ ์๊ณ ๊ณ์ จ์ ๊ฒ์ ๋๋ค.
๋ณด์๋ค์ํผ, ์ด๊ฒ์ UNet ๋ชจ๋ธ๋ง ํฌํจ๋ DDPM ํ์ดํ๋ผ์ธ๋ณด๋ค ๋ ๋ณต์กํฉ๋๋ค. Stable Diffusion ๋ชจ๋ธ์๋ ์ธ ๊ฐ์ ๊ฐ๋ณ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ด ์์ต๋๋ค.
๐ก VAE, UNet ๋ฐ ํ ์คํธ ์ธ์ฝ๋ ๋ชจ๋ธ์ ์๋๋ฐฉ์์ ๋ํ ์์ธํ ๋ด์ฉ์ How does Stable Diffusion work? ๋ธ๋ก๊ทธ๋ฅผ ์ฐธ์กฐํ์ธ์.
์ด์ Stable Diffusion ํ์ดํ๋ผ์ธ์ ํ์ํ ๊ตฌ์ฑ์์๋ค์ด ๋ฌด์์ธ์ง ์์์ผ๋, [~ModelMixin.from_pretrained
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด ๋ชจ๋ ๊ตฌ์ฑ์์๋ฅผ ๋ถ๋ฌ์ต๋๋ค. ์ฌ์ ํ์ต๋ ์ฒดํฌํฌ์ธํธ runwayml/stable-diffusion-v1-5
์์ ์ฐพ์ ์ ์์ผ๋ฉฐ, ๊ฐ ๊ตฌ์ฑ์์๋ค์ ๋ณ๋์ ํ์ ํด๋์ ์ ์ฅ๋์ด ์์ต๋๋ค:
>>> from PIL import Image
>>> import torch
>>> from transformers import CLIPTextModel, CLIPTokenizer
>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
>>> vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
>>> tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
>>> text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
>>> unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
๊ธฐ๋ณธ [PNDMScheduler
] ๋์ , [UniPCMultistepScheduler
]๋ก ๊ต์ฒดํ์ฌ ๋ค๋ฅธ ์ค์ผ์ค๋ฌ๋ฅผ ์ผ๋ง๋ ์ฝ๊ฒ ์ฐ๊ฒฐํ ์ ์๋์ง ํ์ธํฉ๋๋ค:
>>> from diffusers import UniPCMultistepScheduler
>>> scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
์ถ๋ก ์๋๋ฅผ ๋์ด๋ ค๋ฉด ์ค์ผ์ค๋ฌ์ ๋ฌ๋ฆฌ ํ์ต ๊ฐ๋ฅํ ๊ฐ์ค์น๊ฐ ์์ผ๋ฏ๋ก ๋ชจ๋ธ์ GPU๋ก ์ฎ๊ธฐ์ธ์:
>>> torch_device = "cuda"
>>> vae.to(torch_device)
>>> text_encoder.to(torch_device)
>>> unet.to(torch_device)
ํ ์คํธ ์๋ฒ ๋ฉ ์์ฑํ๊ธฐ
๋ค์ ๋จ๊ณ๋ ์๋ฒ ๋ฉ์ ์์ฑํ๊ธฐ ์ํด ํ ์คํธ๋ฅผ ํ ํฐํํ๋ ๊ฒ์ ๋๋ค. ์ด ํ ์คํธ๋ UNet ๋ชจ๋ธ์์ condition์ผ๋ก ์ฌ์ฉ๋๊ณ ์ ๋ ฅ ํ๋กฌํํธ์ ์ ์ฌํ ๋ฐฉํฅ์ผ๋ก diffusion ํ๋ก์ธ์ค๋ฅผ ์กฐ์ ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
๐ก guidance_scale
๋งค๊ฐ๋ณ์๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ๋ ํ๋กฌํํธ์ ์ผ๋ง๋ ๋ง์ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ ์ง ๊ฒฐ์ ํฉ๋๋ค.
๋ค๋ฅธ ํ๋กฌํํธ๋ฅผ ์์ฑํ๊ณ ์ถ๋ค๋ฉด ์ํ๋ ํ๋กฌํํธ๋ฅผ ์์ ๋กญ๊ฒ ์ ํํ์ธ์!
>>> prompt = ["a photograph of an astronaut riding a horse"]
>>> height = 512 # Stable Diffusion์ ๊ธฐ๋ณธ ๋์ด
>>> width = 512 # Stable Diffusion์ ๊ธฐ๋ณธ ๋๋น
>>> num_inference_steps = 25 # ๋
ธ์ด์ฆ ์ ๊ฑฐ ์คํ
์
>>> guidance_scale = 7.5 # classifier-free guidance๋ฅผ ์ํ scale
>>> generator = torch.manual_seed(0) # ์ด๊ธฐ ์ ์ฌ ๋
ธ์ด์ฆ๋ฅผ ์์ฑํ๋ seed generator
>>> batch_size = len(prompt)
ํ ์คํธ๋ฅผ ํ ํฐํํ๊ณ ํ๋กฌํํธ์์ ์๋ฒ ๋ฉ์ ์์ฑํฉ๋๋ค:
>>> text_input = tokenizer(
... prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
... )
>>> with torch.no_grad():
... text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
๋ํ ํจ๋ฉ ํ ํฐ์ ์๋ฒ ๋ฉ์ธ unconditional ํ
์คํธ ์๋ฒ ๋ฉ์ ์์ฑํด์ผ ํฉ๋๋ค. ์ด ์๋ฒ ๋ฉ์ ์กฐ๊ฑด๋ถ text_embeddings
๊ณผ ๋์ผํ shape(batch_size
๊ทธ๋ฆฌ๊ณ seq_length
)์ ๊ฐ์ ธ์ผ ํฉ๋๋ค:
>>> max_length = text_input.input_ids.shape[-1]
>>> uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
๋๋ฒ์ forward pass๋ฅผ ํผํ๊ธฐ ์ํด conditional ์๋ฒ ๋ฉ๊ณผ unconditional ์๋ฒ ๋ฉ์ ๋ฐฐ์น(batch)๋ก ์ฐ๊ฒฐํ๊ฒ ์ต๋๋ค:
>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
๋๋ค ๋ ธ์ด์ฆ ์์ฑ
๊ทธ๋ค์ diffusion ํ๋ก์ธ์ค์ ์์์ ์ผ๋ก ์ด๊ธฐ ๋๋ค ๋
ธ์ด์ฆ๋ฅผ ์์ฑํฉ๋๋ค. ์ด๊ฒ์ด ์ด๋ฏธ์ง์ ์ ์ฌ์ ํํ์ด๋ฉฐ ์ ์ฐจ์ ์ผ๋ก ๋
ธ์ด์ฆ๊ฐ ์ ๊ฑฐ๋ฉ๋๋ค. ์ด ์์ ์์ latent
์ด๋ฏธ์ง๋ ์ต์ข
์ด๋ฏธ์ง ํฌ๊ธฐ๋ณด๋ค ์์ง๋ง ๋์ค์ ๋ชจ๋ธ์ด ์ด๋ฅผ 512x512 ์ด๋ฏธ์ง ํฌ๊ธฐ๋ก ๋ณํํ๋ฏ๋ก ๊ด์ฐฎ์ต๋๋ค.
๐ก vae
๋ชจ๋ธ์๋ 3๊ฐ์ ๋ค์ด ์ํ๋ง ๋ ์ด์ด๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋์ด์ ๋๋น๊ฐ 8๋ก ๋๋ฉ๋๋ค. ๋ค์์ ์คํํ์ฌ ํ์ธํ ์ ์์ต๋๋ค:
2 ** (len(vae.config.block_out_channels) - 1) == 8
>>> latents = torch.randn(
... (batch_size, unet.config.in_channels, height // 8, width // 8),
... generator=generator,
... device=torch_device,
... )
์ด๋ฏธ์ง ๋ ธ์ด์ฆ ์ ๊ฑฐ
๋จผ์ [UniPCMultistepScheduler
]์ ๊ฐ์ ํฅ์๋ ์ค์ผ์ค๋ฌ์ ํ์ํ ๋
ธ์ด์ฆ ์ค์ผ์ผ ๊ฐ์ธ ์ด๊ธฐ ๋
ธ์ด์ฆ ๋ถํฌ sigma ๋ก ์
๋ ฅ์ ์ค์ผ์ผ๋ง ํ๋ ๊ฒ๋ถํฐ ์์ํฉ๋๋ค:
>>> latents = latents * scheduler.init_noise_sigma
๋ง์ง๋ง ๋จ๊ณ๋ latent
์ ์์ํ ๋
ธ์ด์ฆ๋ฅผ ์ ์ง์ ์ผ๋ก ํ๋กฌํํธ์ ์ค๋ช
๋ ์ด๋ฏธ์ง๋ก ๋ณํํ๋ ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ฅผ ์์ฑํ๋ ๊ฒ์
๋๋ค. ๋
ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ ์ธ ๊ฐ์ง ์์
์ ์ํํด์ผ ํ๋ค๋ ์ ์ ๊ธฐ์ตํ์ธ์:
- ๋ ธ์ด์ฆ ์ ๊ฑฐ ์ค์ ์ฌ์ฉํ ์ค์ผ์ค๋ฌ์ timesteps๋ฅผ ์ค์ ํฉ๋๋ค.
- timestep์ ๋ฐ๋ผ ๋ฐ๋ณตํฉ๋๋ค.
- ๊ฐ timestep์์ UNet ๋ชจ๋ธ์ ํธ์ถํ์ฌ noise residual์ ์์ธกํ๊ณ ์ค์ผ์ค๋ฌ์ ์ ๋ฌํ์ฌ ์ด์ ๋ ธ์ด์ฆ ์ํ์ ๊ณ์ฐํฉ๋๋ค.
>>> from tqdm.auto import tqdm
>>> scheduler.set_timesteps(num_inference_steps)
>>> for t in tqdm(scheduler.timesteps):
... # classifier-free guidance๋ฅผ ์ํํ๋ ๊ฒฝ์ฐ ๋๋ฒ์ forward pass๋ฅผ ์ํํ์ง ์๋๋ก latent๋ฅผ ํ์ฅ.
... latent_model_input = torch.cat([latents] * 2)
... latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
... # noise residual ์์ธก
... with torch.no_grad():
... noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
... # guidance ์ํ
... noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
... noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
... # ์ด์ ๋
ธ์ด์ฆ ์ํ์ ๊ณ์ฐ x_t -> x_t-1
... latents = scheduler.step(noise_pred, t, latents).prev_sample
์ด๋ฏธ์ง ๋์ฝ๋ฉ
๋ง์ง๋ง ๋จ๊ณ๋ vae
๋ฅผ ์ด์ฉํ์ฌ ์ ์ฌ ํํ์ ์ด๋ฏธ์ง๋ก ๋์ฝ๋ฉํ๊ณ sample
๊ณผ ํจ๊ป ๋์ฝ๋ฉ๋ ์ถ๋ ฅ์ ์ป๋ ๊ฒ์
๋๋ค:
# latent๋ฅผ ์ค์ผ์ผ๋งํ๊ณ vae๋ก ์ด๋ฏธ์ง ๋์ฝ๋ฉ
latents = 1 / 0.18215 * latents
with torch.no_grad():
image = vae.decode(latents).sample
๋ง์ง๋ง์ผ๋ก ์ด๋ฏธ์ง๋ฅผ PIL.Image
๋ก ๋ณํํ๋ฉด ์์ฑ๋ ์ด๋ฏธ์ง๋ฅผ ํ์ธํ ์ ์์ต๋๋ค!
>>> image = (image / 2 + 0.5).clamp(0, 1)
>>> image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
>>> images = (image * 255).round().astype("uint8")
>>> pil_images = [Image.fromarray(image) for image in images]
>>> pil_images[0]
๋ค์ ๋จ๊ณ
๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ๋ถํฐ ๋ณต์กํ ํ์ดํ๋ผ์ธ๊น์ง, ์์ ๋ง์ diffusion ์์คํ ์ ์์ฑํ๋ ๋ฐ ํ์ํ ๊ฒ์ ๋ ธ์ด์ฆ ์ ๊ฑฐ ๋ฃจํ๋ฟ์ด๋ผ๋ ๊ฒ์ ์ ์ ์์์ต๋๋ค. ์ด ๋ฃจํ๋ ์ค์ผ์ค๋ฌ์ timesteps๋ฅผ ์ค์ ํ๊ณ , ์ด๋ฅผ ๋ฐ๋ณตํ๋ฉฐ, UNet ๋ชจ๋ธ์ ํธ์ถํ์ฌ noise residual์ ์์ธกํ๊ณ ์ค์ผ์ค๋ฌ์ ์ ๋ฌํ์ฌ ์ด์ ๋ ธ์ด์ฆ ์ํ์ ๊ณ์ฐํ๋ ๊ณผ์ ์ ๋ฒ๊ฐ์ ๊ฐ๋ฉฐ ์ํํด์ผ ํฉ๋๋ค.
์ด๊ฒ์ด ๋ฐ๋ก ๐งจ Diffusers๊ฐ ์ค๊ณ๋ ๋ชฉ์ ์ ๋๋ค: ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํด ์์ ๋ง์ diffusion ์์คํ ์ ์ง๊ด์ ์ด๊ณ ์ฝ๊ฒ ์์ฑํ ์ ์๋๋ก ํ๊ธฐ ์ํด์์ ๋๋ค.
๋ค์ ๋จ๊ณ๋ฅผ ์์ ๋กญ๊ฒ ์งํํ์ธ์:
- ๐งจ Diffusers์ ํ์ดํ๋ผ์ธ ๊ตฌ์ถ ๋ฐ ๊ธฐ์ฌํ๋ ๋ฐฉ๋ฒ์ ์์๋ณด์ธ์. ์ฌ๋ฌ๋ถ์ด ์ด๋ค ์์ด๋์ด๋ฅผ ๋ด๋์์ง ๊ธฐ๋๋ฉ๋๋ค!
- ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ์ ์ดํด๋ณด๊ณ , ๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๋ณ๋๋ก ์ฌ์ฉํ์ฌ ํ์ดํ๋ผ์ธ์ ์ฒ์๋ถํฐ ํด์ฒดํ๊ณ ๋น๋ํ ์ ์๋์ง ํ์ธํด ๋ณด์ธ์.