Upamanyu098's picture
End of training
ef4d689 verified
|
raw
history blame
15.5 kB

ํŒŒ์ดํ”„๋ผ์ธ, ๋ชจ๋ธ ๋ฐ ์Šค์ผ€์ค„๋Ÿฌ ์ดํ•ดํ•˜๊ธฐ

[[open-in-colab]]

๐Ÿงจ Diffusers๋Š” ์‚ฌ์šฉ์ž ์นœํ™”์ ์ด๋ฉฐ ์œ ์—ฐํ•œ ๋„๊ตฌ ์ƒ์ž๋กœ, ์‚ฌ์šฉ์‚ฌ๋ก€์— ๋งž๊ฒŒ diffusion ์‹œ์Šคํ…œ์„ ๊ตฌ์ถ• ํ•  ์ˆ˜ ์žˆ๋„๋ก ์„ค๊ณ„๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ด ๋„๊ตฌ ์ƒ์ž์˜ ํ•ต์‹ฌ์€ ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ์ž…๋‹ˆ๋‹ค. [DiffusionPipeline]์€ ํŽธ์˜๋ฅผ ์œ„ํ•ด ์ด๋Ÿฌํ•œ ๊ตฌ์„ฑ ์š”์†Œ๋ฅผ ๋ฒˆ๋“ค๋กœ ์ œ๊ณตํ•˜์ง€๋งŒ, ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ถ„๋ฆฌํ•˜๊ณ  ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๊ฐœ๋ณ„์ ์œผ๋กœ ์‚ฌ์šฉํ•ด ์ƒˆ๋กœ์šด diffusion ์‹œ์Šคํ…œ์„ ๋งŒ๋“ค ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ๋Š” ๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ด Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ๊นŒ์ง€ ์ง„ํ–‰ํ•˜๋ฉฐ ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์‚ฌ์šฉํ•ด ์ถ”๋ก ์„ ์œ„ํ•œ diffusion ์‹œ์Šคํ…œ์„ ์กฐ๋ฆฝํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ฐฐ์›๋‹ˆ๋‹ค.

๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ ํ•ด์ฒดํ•˜๊ธฐ

ํŒŒ์ดํ”„๋ผ์ธ์€ ์ถ”๋ก ์„ ์œ„ํ•ด ๋ชจ๋ธ์„ ์‹คํ–‰ํ•˜๋Š” ๋น ๋ฅด๊ณ  ์‰ฌ์šด ๋ฐฉ๋ฒ•์œผ๋กœ, ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋ฐ ์ฝ”๋“œ๊ฐ€ 4์ค„ ์ด์ƒ ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค:

>>> from diffusers import DDPMPipeline

>>> ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to("cuda")
>>> image = ddpm(num_inference_steps=25).images[0]
>>> image
Image of cat created from DDPMPipeline

์ •๋ง ์‰ฝ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ๋ฐ ํŒŒ์ดํ”„๋ผ์ธ์€ ์–ด๋–ป๊ฒŒ ์ด๋ ‡๊ฒŒ ํ•  ์ˆ˜ ์žˆ์—ˆ์„๊นŒ์š”? ํŒŒ์ดํ”„๋ผ์ธ์„ ์„ธ๋ถ„ํ™”ํ•˜์—ฌ ๋‚ด๋ถ€์—์„œ ์–ด๋–ค ์ผ์ด ์ผ์–ด๋‚˜๊ณ  ์žˆ๋Š”์ง€ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

์œ„ ์˜ˆ์‹œ์—์„œ ํŒŒ์ดํ”„๋ผ์ธ์—๋Š” [UNet2DModel] ๋ชจ๋ธ๊ณผ [DDPMScheduler]๊ฐ€ ํฌํ•จ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ํŒŒ์ดํ”„๋ผ์ธ์€ ์›ํ•˜๋Š” ์ถœ๋ ฅ ํฌ๊ธฐ์˜ ๋žœ๋ค ๋…ธ์ด์ฆˆ๋ฅผ ๋ฐ›์•„ ๋ชจ๋ธ์„ ์—ฌ๋Ÿฌ๋ฒˆ ํ†ต๊ณผ์‹œ์ผœ ์ด๋ฏธ์ง€์˜ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. ๊ฐ timestep์—์„œ ๋ชจ๋ธ์€ noise residual์„ ์˜ˆ์ธกํ•˜๊ณ  ์Šค์ผ€์ค„๋Ÿฌ๋Š” ์ด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋…ธ์ด์ฆˆ๊ฐ€ ์ ์€ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ํŒŒ์ดํ”„๋ผ์ธ์€ ์ง€์ •๋œ ์ถ”๋ก  ์Šคํ…์ˆ˜์— ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€ ์ด ๊ณผ์ •์„ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.

๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋ณ„๋„๋กœ ์‚ฌ์šฉํ•˜์—ฌ ํŒŒ์ดํ”„๋ผ์ธ์„ ๋‹ค์‹œ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•ด ์ž์ฒด์ ์ธ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ํ”„๋กœ์„ธ์Šค๋ฅผ ์ž‘์„ฑํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

  1. ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค:

    >>> from diffusers import DDPMScheduler, UNet2DModel
    
    >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
    >>> model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
    
  2. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ํ”„๋กœ์„ธ์Šค๋ฅผ ์‹คํ–‰ํ•  timestep ์ˆ˜๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค:

    >>> scheduler.set_timesteps(50)
    
  3. ์Šค์ผ€์ค„๋Ÿฌ์˜ timestep์„ ์„ค์ •ํ•˜๋ฉด ๊ท ๋“ฑํ•œ ๊ฐ„๊ฒฉ์˜ ๊ตฌ์„ฑ ์š”์†Œ๋ฅผ ๊ฐ€์ง„ ํ…์„œ๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.(์ด ์˜ˆ์‹œ์—์„œ๋Š” 50๊ฐœ) ๊ฐ ์š”์†Œ๋Š” ๋ชจ๋ธ์ด ์ด๋ฏธ์ง€์˜ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๋Š” ์‹œ๊ฐ„ ๊ฐ„๊ฒฉ์— ํ•ด๋‹นํ•ฉ๋‹ˆ๋‹ค. ๋‚˜์ค‘์— ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋ฅผ ๋งŒ๋“ค ๋•Œ ์ด ํ…์„œ๋ฅผ ๋ฐ˜๋ณตํ•˜์—ฌ ์ด๋ฏธ์ง€์˜ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค:

    >>> scheduler.timesteps
    tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720,
        700, 680, 660, 640, 620, 600, 580, 560, 540, 520, 500, 480, 460, 440,
        420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160,
        140, 120, 100,  80,  60,  40,  20,   0])
    
  4. ์›ํ•˜๋Š” ์ถœ๋ ฅ๊ณผ ๊ฐ™์€ ๋ชจ์–‘์„ ๊ฐ€์ง„ ๋žœ๋ค ๋…ธ์ด์ฆˆ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค:

    >>> import torch
    
    >>> sample_size = model.config.sample_size
    >>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
    
  5. ์ด์ œ timestep์„ ๋ฐ˜๋ณตํ•˜๋Š” ๋ฃจํ”„๋ฅผ ์ž‘์„ฑํ•ฉ๋‹ˆ๋‹ค. ๊ฐ timestep์—์„œ ๋ชจ๋ธ์€ [UNet2DModel.forward]๋ฅผ ํ†ตํ•ด noisy residual์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ์Šค์ผ€์ค„๋Ÿฌ์˜ [~DDPMScheduler.step] ๋ฉ”์„œ๋“œ๋Š” noisy residual, timestep, ๊ทธ๋ฆฌ๊ณ  ์ž…๋ ฅ์„ ๋ฐ›์•„ ์ด์ „ timestep์—์„œ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ์ด ์ถœ๋ ฅ์€ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„์˜ ๋ชจ๋ธ์— ๋Œ€ํ•œ ๋‹ค์Œ ์ž…๋ ฅ์ด ๋˜๋ฉฐ, timesteps ๋ฐฐ์—ด์˜ ๋์— ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€ ๋ฐ˜๋ณต๋ฉ๋‹ˆ๋‹ค.

    >>> input = noise
    
    >>> for t in scheduler.timesteps:
    ...     with torch.no_grad():
    ...         noisy_residual = model(input, t).sample
    ...     previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
    ...     input = previous_noisy_sample
    

    ์ด๊ฒƒ์ด ์ „์ฒด ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ํ”„๋กœ์„ธ์Šค์ด๋ฉฐ, ๋™์ผํ•œ ํŒจํ„ด์„ ์‚ฌ์šฉํ•ด ๋ชจ๋“  diffusion ์‹œ์Šคํ…œ์„ ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  6. ๋งˆ์ง€๋ง‰ ๋‹จ๊ณ„๋Š” ๋…ธ์ด์ฆˆ๊ฐ€ ์ œ๊ฑฐ๋œ ์ถœ๋ ฅ์„ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค:

    >>> from PIL import Image
    >>> import numpy as np
    
    >>> image = (input / 2 + 0.5).clamp(0, 1)
    >>> image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    >>> image = Image.fromarray((image * 255).round().astype("uint8"))
    >>> image
    

๋‹ค์Œ ์„น์…˜์—์„œ๋Š” ์—ฌ๋Ÿฌ๋ถ„์˜ ๊ธฐ์ˆ ์„ ์‹œํ—˜ํ•ด๋ณด๊ณ  ์ข€ ๋” ๋ณต์žกํ•œ Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ถ„์„ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๋ฐฉ๋ฒ•์€ ๊ฑฐ์˜ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. ํ•„์š”ํ•œ ๊ตฌ์„ฑ์š”์†Œ๋“ค์„ ์ดˆ๊ธฐํ™”ํ•˜๊ณ  timestep์ˆ˜๋ฅผ ์„ค์ •ํ•˜์—ฌ timestep ๋ฐฐ์—ด์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„์—์„œ timestep ๋ฐฐ์—ด์ด ์‚ฌ์šฉ๋˜๋ฉฐ, ์ด ๋ฐฐ์—ด์˜ ๊ฐ ์š”์†Œ์— ๋Œ€ํ•ด ๋ชจ๋ธ์€ ๋…ธ์ด์ฆˆ๊ฐ€ ์ ์€ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋Š” timestep์„ ๋ฐ˜๋ณตํ•˜๊ณ  ๊ฐ timestep์—์„œ noise residual์„ ์ถœ๋ ฅํ•˜๊ณ  ์Šค์ผ€์ค„๋Ÿฌ๋Š” ์ด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ด์ „ timestep์—์„œ ๋…ธ์ด์ฆˆ๊ฐ€ ๋œํ•œ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ์ด ํ”„๋กœ์„ธ์Šค๋Š” timestep ๋ฐฐ์—ด์˜ ๋์— ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€ ๋ฐ˜๋ณต๋ฉ๋‹ˆ๋‹ค.

ํ•œ๋ฒˆ ์‚ฌ์šฉํ•ด ๋ด…์‹œ๋‹ค!

Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ ํ•ด์ฒดํ•˜๊ธฐ

Stable Diffusion ์€ text-to-image latent diffusion ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. latent diffusion ๋ชจ๋ธ์ด๋ผ๊ณ  ๋ถˆ๋ฆฌ๋Š” ์ด์œ ๋Š” ์‹ค์ œ ํ”ฝ์…€ ๊ณต๊ฐ„ ๋Œ€์‹  ์ด๋ฏธ์ง€์˜ ์ €์ฐจ์›์˜ ํ‘œํ˜„์œผ๋กœ ์ž‘์—…ํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๊ณ , ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์ด ๋” ๋†’์Šต๋‹ˆ๋‹ค. ์ธ์ฝ”๋”๋Š” ์ด๋ฏธ์ง€๋ฅผ ๋” ์ž‘์€ ํ‘œํ˜„์œผ๋กœ ์••์ถ•ํ•˜๊ณ , ๋””์ฝ”๋”๋Š” ์••์ถ•๋œ ํ‘œํ˜„์„ ๋‹ค์‹œ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค. text-to-image ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•ด tokenizer์™€ ์ธ์ฝ”๋”๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด์ „ ์˜ˆ์ œ์—์„œ ์ด๋ฏธ UNet ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๊ฐ€ ํ•„์š”ํ•˜๋‹ค๋Š” ๊ฒƒ์€ ์•Œ๊ณ  ๊ณ„์…จ์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๋ณด์‹œ๋‹ค์‹œํ”ผ, ์ด๊ฒƒ์€ UNet ๋ชจ๋ธ๋งŒ ํฌํ•จ๋œ DDPM ํŒŒ์ดํ”„๋ผ์ธ๋ณด๋‹ค ๋” ๋ณต์žกํ•ฉ๋‹ˆ๋‹ค. Stable Diffusion ๋ชจ๋ธ์—๋Š” ์„ธ ๊ฐœ์˜ ๊ฐœ๋ณ„ ์‚ฌ์ „ํ•™์Šต๋œ ๋ชจ๋ธ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

๐Ÿ’ก VAE, UNet ๋ฐ ํ…์ŠคํŠธ ์ธ์ฝ”๋” ๋ชจ๋ธ์˜ ์ž‘๋™๋ฐฉ์‹์— ๋Œ€ํ•œ ์ž์„ธํ•œ ๋‚ด์šฉ์€ How does Stable Diffusion work? ๋ธ”๋กœ๊ทธ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.

์ด์ œ Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ์— ํ•„์š”ํ•œ ๊ตฌ์„ฑ์š”์†Œ๋“ค์ด ๋ฌด์—‡์ธ์ง€ ์•Œ์•˜์œผ๋‹ˆ, [~ModelMixin.from_pretrained] ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•ด ๋ชจ๋“  ๊ตฌ์„ฑ์š”์†Œ๋ฅผ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค. ์‚ฌ์ „ํ•™์Šต๋œ ์ฒดํฌํฌ์ธํŠธ runwayml/stable-diffusion-v1-5์—์„œ ์ฐพ์„ ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๊ฐ ๊ตฌ์„ฑ์š”์†Œ๋“ค์€ ๋ณ„๋„์˜ ํ•˜์œ„ ํด๋”์— ์ €์žฅ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค:

>>> from PIL import Image
>>> import torch
>>> from transformers import CLIPTextModel, CLIPTokenizer
>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler

>>> vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
>>> tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
>>> text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
>>> unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

๊ธฐ๋ณธ [PNDMScheduler] ๋Œ€์‹ , [UniPCMultistepScheduler]๋กœ ๊ต์ฒดํ•˜์—ฌ ๋‹ค๋ฅธ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์–ผ๋งˆ๋‚˜ ์‰ฝ๊ฒŒ ์—ฐ๊ฒฐํ•  ์ˆ˜ ์žˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค:

>>> from diffusers import UniPCMultistepScheduler

>>> scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

์ถ”๋ก  ์†๋„๋ฅผ ๋†’์ด๋ ค๋ฉด ์Šค์ผ€์ค„๋Ÿฌ์™€ ๋‹ฌ๋ฆฌ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ๊ฐ€์ค‘์น˜๊ฐ€ ์žˆ์œผ๋ฏ€๋กœ ๋ชจ๋ธ์„ GPU๋กœ ์˜ฎ๊ธฐ์„ธ์š”:

>>> torch_device = "cuda"
>>> vae.to(torch_device)
>>> text_encoder.to(torch_device)
>>> unet.to(torch_device)

ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑํ•˜๊ธฐ

๋‹ค์Œ ๋‹จ๊ณ„๋Š” ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•˜๊ธฐ ์œ„ํ•ด ํ…์ŠคํŠธ๋ฅผ ํ† ํฐํ™”ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด ํ…์ŠคํŠธ๋Š” UNet ๋ชจ๋ธ์—์„œ condition์œผ๋กœ ์‚ฌ์šฉ๋˜๊ณ  ์ž…๋ ฅ ํ”„๋กฌํ”„ํŠธ์™€ ์œ ์‚ฌํ•œ ๋ฐฉํ–ฅ์œผ๋กœ diffusion ํ”„๋กœ์„ธ์Šค๋ฅผ ์กฐ์ •ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.

๐Ÿ’ก guidance_scale ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•  ๋•Œ ํ”„๋กฌํ”„ํŠธ์— ์–ผ๋งˆ๋‚˜ ๋งŽ์€ ๊ฐ€์ค‘์น˜๋ฅผ ๋ถ€์—ฌํ• ์ง€ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.

๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ƒ์„ฑํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ์›ํ•˜๋Š” ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž์œ ๋กญ๊ฒŒ ์„ ํƒํ•˜์„ธ์š”!

>>> prompt = ["a photograph of an astronaut riding a horse"]
>>> height = 512  # Stable Diffusion์˜ ๊ธฐ๋ณธ ๋†’์ด
>>> width = 512  # Stable Diffusion์˜ ๊ธฐ๋ณธ ๋„ˆ๋น„
>>> num_inference_steps = 25  # ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ์Šคํ… ์ˆ˜
>>> guidance_scale = 7.5  # classifier-free guidance๋ฅผ ์œ„ํ•œ scale
>>> generator = torch.manual_seed(0)  # ์ดˆ๊ธฐ ์ž ์žฌ ๋…ธ์ด์ฆˆ๋ฅผ ์ƒ์„ฑํ•˜๋Š” seed generator
>>> batch_size = len(prompt)

ํ…์ŠคํŠธ๋ฅผ ํ† ํฐํ™”ํ•˜๊ณ  ํ”„๋กฌํ”„ํŠธ์—์„œ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค:

>>> text_input = tokenizer(
...     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
... )

>>> with torch.no_grad():
...     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

๋˜ํ•œ ํŒจ๋”ฉ ํ† ํฐ์˜ ์ž„๋ฒ ๋”ฉ์ธ unconditional ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด ์ž„๋ฒ ๋”ฉ์€ ์กฐ๊ฑด๋ถ€ text_embeddings๊ณผ ๋™์ผํ•œ shape(batch_size ๊ทธ๋ฆฌ๊ณ  seq_length)์„ ๊ฐ€์ ธ์•ผ ํ•ฉ๋‹ˆ๋‹ค:

>>> max_length = text_input.input_ids.shape[-1]
>>> uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

๋‘๋ฒˆ์˜ forward pass๋ฅผ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด conditional ์ž„๋ฒ ๋”ฉ๊ณผ unconditional ์ž„๋ฒ ๋”ฉ์„ ๋ฐฐ์น˜(batch)๋กœ ์—ฐ๊ฒฐํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:

>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

๋žœ๋ค ๋…ธ์ด์ฆˆ ์ƒ์„ฑ

๊ทธ๋‹ค์Œ diffusion ํ”„๋กœ์„ธ์Šค์˜ ์‹œ์ž‘์ ์œผ๋กœ ์ดˆ๊ธฐ ๋žœ๋ค ๋…ธ์ด์ฆˆ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ์ด๊ฒƒ์ด ์ด๋ฏธ์ง€์˜ ์ž ์žฌ์  ํ‘œํ˜„์ด๋ฉฐ ์ ์ฐจ์ ์œผ๋กœ ๋…ธ์ด์ฆˆ๊ฐ€ ์ œ๊ฑฐ๋ฉ๋‹ˆ๋‹ค. ์ด ์‹œ์ ์—์„œ latent ์ด๋ฏธ์ง€๋Š” ์ตœ์ข… ์ด๋ฏธ์ง€ ํฌ๊ธฐ๋ณด๋‹ค ์ž‘์ง€๋งŒ ๋‚˜์ค‘์— ๋ชจ๋ธ์ด ์ด๋ฅผ 512x512 ์ด๋ฏธ์ง€ ํฌ๊ธฐ๋กœ ๋ณ€ํ™˜ํ•˜๋ฏ€๋กœ ๊ดœ์ฐฎ์Šต๋‹ˆ๋‹ค.

๐Ÿ’ก vae ๋ชจ๋ธ์—๋Š” 3๊ฐœ์˜ ๋‹ค์šด ์ƒ˜ํ”Œ๋ง ๋ ˆ์ด์–ด๊ฐ€ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ๋†’์ด์™€ ๋„ˆ๋น„๊ฐ€ 8๋กœ ๋‚˜๋‰ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ์„ ์‹คํ–‰ํ•˜์—ฌ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

2 ** (len(vae.config.block_out_channels) - 1) == 8
>>> latents = torch.randn(
...     (batch_size, unet.config.in_channels, height // 8, width // 8),
...     generator=generator,
...     device=torch_device,
... )

์ด๋ฏธ์ง€ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ

๋จผ์ € [UniPCMultistepScheduler]์™€ ๊ฐ™์€ ํ–ฅ์ƒ๋œ ์Šค์ผ€์ค„๋Ÿฌ์— ํ•„์š”ํ•œ ๋…ธ์ด์ฆˆ ์Šค์ผ€์ผ ๊ฐ’์ธ ์ดˆ๊ธฐ ๋…ธ์ด์ฆˆ ๋ถ„ํฌ sigma ๋กœ ์ž…๋ ฅ์„ ์Šค์ผ€์ผ๋ง ํ•˜๋Š” ๊ฒƒ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค:

>>> latents = latents * scheduler.init_noise_sigma

๋งˆ์ง€๋ง‰ ๋‹จ๊ณ„๋Š” latent์˜ ์ˆœ์ˆ˜ํ•œ ๋…ธ์ด์ฆˆ๋ฅผ ์ ์ง„์ ์œผ๋กœ ํ”„๋กฌํ”„ํŠธ์— ์„ค๋ช…๋œ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋Š” ์„ธ ๊ฐ€์ง€ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•œ๋‹ค๋Š” ์ ์„ ๊ธฐ์–ตํ•˜์„ธ์š”:

  1. ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ์ค‘์— ์‚ฌ์šฉํ•  ์Šค์ผ€์ค„๋Ÿฌ์˜ timesteps๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
  2. timestep์„ ๋”ฐ๋ผ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.
  3. ๊ฐ timestep์—์„œ UNet ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•˜์—ฌ noise residual์„ ์˜ˆ์ธกํ•˜๊ณ  ์Šค์ผ€์ค„๋Ÿฌ์— ์ „๋‹ฌํ•˜์—ฌ ์ด์ „ ๋…ธ์ด์ฆˆ ์ƒ˜ํ”Œ์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
>>> from tqdm.auto import tqdm

>>> scheduler.set_timesteps(num_inference_steps)

>>> for t in tqdm(scheduler.timesteps):
...     # classifier-free guidance๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒฝ์šฐ ๋‘๋ฒˆ์˜ forward pass๋ฅผ ์ˆ˜ํ–‰ํ•˜์ง€ ์•Š๋„๋ก latent๋ฅผ ํ™•์žฅ.
...     latent_model_input = torch.cat([latents] * 2)

...     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

...     # noise residual ์˜ˆ์ธก
...     with torch.no_grad():
...         noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

...     # guidance ์ˆ˜ํ–‰
...     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
...     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

...     # ์ด์ „ ๋…ธ์ด์ฆˆ ์ƒ˜ํ”Œ์„ ๊ณ„์‚ฐ x_t -> x_t-1
...     latents = scheduler.step(noise_pred, t, latents).prev_sample

์ด๋ฏธ์ง€ ๋””์ฝ”๋”ฉ

๋งˆ์ง€๋ง‰ ๋‹จ๊ณ„๋Š” vae๋ฅผ ์ด์šฉํ•˜์—ฌ ์ž ์žฌ ํ‘œํ˜„์„ ์ด๋ฏธ์ง€๋กœ ๋””์ฝ”๋”ฉํ•˜๊ณ  sample๊ณผ ํ•จ๊ป˜ ๋””์ฝ”๋”ฉ๋œ ์ถœ๋ ฅ์„ ์–ป๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค:

# latent๋ฅผ ์Šค์ผ€์ผ๋งํ•˜๊ณ  vae๋กœ ์ด๋ฏธ์ง€ ๋””์ฝ”๋”ฉ
latents = 1 / 0.18215 * latents
with torch.no_grad():
    image = vae.decode(latents).sample

๋งˆ์ง€๋ง‰์œผ๋กœ ์ด๋ฏธ์ง€๋ฅผ PIL.Image๋กœ ๋ณ€ํ™˜ํ•˜๋ฉด ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!

>>> image = (image / 2 + 0.5).clamp(0, 1)
>>> image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
>>> images = (image * 255).round().astype("uint8")
>>> pil_images = [Image.fromarray(image) for image in images]
>>> pil_images[0]

๋‹ค์Œ ๋‹จ๊ณ„

๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ๋ถ€ํ„ฐ ๋ณต์žกํ•œ ํŒŒ์ดํ”„๋ผ์ธ๊นŒ์ง€, ์ž์‹ ๋งŒ์˜ diffusion ์‹œ์Šคํ…œ์„ ์ž‘์„ฑํ•˜๋Š” ๋ฐ ํ•„์š”ํ•œ ๊ฒƒ์€ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฃจํ”„๋ฟ์ด๋ผ๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค. ์ด ๋ฃจํ”„๋Š” ์Šค์ผ€์ค„๋Ÿฌ์˜ timesteps๋ฅผ ์„ค์ •ํ•˜๊ณ , ์ด๋ฅผ ๋ฐ˜๋ณตํ•˜๋ฉฐ, UNet ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•˜์—ฌ noise residual์„ ์˜ˆ์ธกํ•˜๊ณ  ์Šค์ผ€์ค„๋Ÿฌ์— ์ „๋‹ฌํ•˜์—ฌ ์ด์ „ ๋…ธ์ด์ฆˆ ์ƒ˜ํ”Œ์„ ๊ณ„์‚ฐํ•˜๋Š” ๊ณผ์ •์„ ๋ฒˆ๊ฐˆ์•„ ๊ฐ€๋ฉฐ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

์ด๊ฒƒ์ด ๋ฐ”๋กœ ๐Ÿงจ Diffusers๊ฐ€ ์„ค๊ณ„๋œ ๋ชฉ์ ์ž…๋‹ˆ๋‹ค: ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์‚ฌ์šฉํ•ด ์ž์‹ ๋งŒ์˜ diffusion ์‹œ์Šคํ…œ์„ ์ง๊ด€์ ์ด๊ณ  ์‰ฝ๊ฒŒ ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๊ธฐ ์œ„ํ•ด์„œ์ž…๋‹ˆ๋‹ค.

๋‹ค์Œ ๋‹จ๊ณ„๋ฅผ ์ž์œ ๋กญ๊ฒŒ ์ง„ํ–‰ํ•˜์„ธ์š”:

  • ๐Ÿงจ Diffusers์— ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์ถ• ๋ฐ ๊ธฐ์—ฌํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ์•„๋ณด์„ธ์š”. ์—ฌ๋Ÿฌ๋ถ„์ด ์–ด๋–ค ์•„์ด๋””์–ด๋ฅผ ๋‚ด๋†“์„์ง€ ๊ธฐ๋Œ€๋ฉ๋‹ˆ๋‹ค!
  • ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์—์„œ ๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ดํŽด๋ณด๊ณ , ๋ชจ๋ธ๊ณผ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋ณ„๋„๋กœ ์‚ฌ์šฉํ•˜์—ฌ ํŒŒ์ดํ”„๋ผ์ธ์„ ์ฒ˜์Œ๋ถ€ํ„ฐ ํ•ด์ฒดํ•˜๊ณ  ๋นŒ๋“œํ•  ์ˆ˜ ์žˆ๋Š”์ง€ ํ™•์ธํ•ด ๋ณด์„ธ์š”.