
์Šค์ผ€์ค„๋Ÿฌ

diffusion ํŒŒ์ดํ”„๋ผ์ธ์€ diffusion ๋ชจ๋ธ, ์Šค์ผ€์ค„๋Ÿฌ ๋“ฑ์˜ ์ปดํฌ๋„ŒํŠธ๋“ค๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ํŒŒ์ดํ”„๋ผ์ธ ์•ˆ์˜ ์ผ๋ถ€ ์ปดํฌ๋„ŒํŠธ๋ฅผ ๋‹ค๋ฅธ ์ปดํฌ๋„ŒํŠธ๋กœ ๊ต์ฒดํ•˜๋Š” ์‹์˜ ์ปค์Šคํ„ฐ๋งˆ์ด์ง• ์—ญ์‹œ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. ์ด์™€ ๊ฐ™์€ ์ปดํฌ๋„ŒํŠธ ์ปค์Šคํ„ฐ๋งˆ์ด์ง•์˜ ๊ฐ€์žฅ ๋Œ€ํ‘œ์ ์ธ ์˜ˆ์‹œ๊ฐ€ ๋ฐ”๋กœ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๊ต์ฒดํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์Šค์ผ€์ฅด๋Ÿฌ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด diffusion ์‹œ์Šคํ…œ์˜ ์ „๋ฐ˜์ ์ธ ๋””๋…ธ์ด์ง• ํ”„๋กœ์„ธ์Šค๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.

  • ๋””๋…ธ์ด์ง• ์Šคํ…์„ ์–ผ๋งˆ๋‚˜ ๊ฐ€์ ธ๊ฐ€์•ผ ํ• ๊นŒ?
  • ํ™•๋ฅ ์ ์œผ๋กœ(stochastic) ํ˜น์€ ํ™•์ •์ ์œผ๋กœ(deterministic)?
  • ๋””๋…ธ์ด์ง• ๋œ ์ƒ˜ํ”Œ์„ ์ฐพ์•„๋‚ด๊ธฐ ์œ„ํ•ด ์–ด๋–ค ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์‚ฌ์šฉํ•ด์•ผ ํ• ๊นŒ?

์ด๋Ÿฌํ•œ ํ”„๋กœ์„ธ์Šค๋Š” ๋‹ค์†Œ ๋‚œํ•ดํ•˜๊ณ , ๋””๋…ธ์ด์ง• ์†๋„์™€ ๋””๋…ธ์ด์ง• ํ€„๋ฆฌํ‹ฐ ์‚ฌ์ด์˜ ํŠธ๋ ˆ์ด๋“œ ์˜คํ”„๋ฅผ ์ •์˜ํ•ด์•ผ ํ•˜๋Š” ๋ฌธ์ œ๊ฐ€ ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ฃผ์–ด์ง„ ํŒŒ์ดํ”„๋ผ์ธ์— ์–ด๋–ค ์Šค์ผ€์ค„๋Ÿฌ๊ฐ€ ๊ฐ€์žฅ ์ ํ•ฉํ•œ์ง€๋ฅผ ์ •๋Ÿ‰์ ์œผ๋กœ ํŒ๋‹จํ•˜๋Š” ๊ฒƒ์€ ๋งค์šฐ ์–ด๋ ค์šด ์ผ์ž…๋‹ˆ๋‹ค. ์ด๋กœ ์ธํ•ด ์ผ๋‹จ ํ•ด๋‹น ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์ง์ ‘ ์‚ฌ์šฉํ•˜์—ฌ, ์ƒ์„ฑ๋˜๋Š” ์ด๋ฏธ์ง€๋ฅผ ์ง์ ‘ ๋ˆˆ์œผ๋กœ ๋ณด๋ฉฐ, ์ •์„ฑ์ ์œผ๋กœ ์„ฑ๋Šฅ์„ ํŒ๋‹จํ•ด๋ณด๋Š” ๊ฒƒ์ด ์ถ”์ฒœ๋˜๊ณค ํ•ฉ๋‹ˆ๋‹ค.

ํŒŒ์ดํ”„๋ผ์ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

๋จผ์ € ์Šคํ…Œ์ด๋ธ” diffusion ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ถˆ๋Ÿฌ์˜ค๋„๋ก ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๋ฌผ๋ก  ์Šคํ…Œ์ด๋ธ” diffusion์„ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š”, ํ—ˆ๊น…ํŽ˜์ด์Šค ํ—ˆ๋ธŒ์— ๋“ฑ๋ก๋œ ์‚ฌ์šฉ์ž์—ฌ์•ผ ํ•˜๋ฉฐ, ๊ด€๋ จ ๋ผ์ด์„ผ์Šค์— ๋™์˜ํ•ด์•ผ ํ•œ๋‹ค๋Š” ์ ์„ ์žŠ์ง€ ๋ง์•„์ฃผ์„ธ์š”.

์—ญ์ž ์ฃผ: ๋‹ค๋งŒ, ํ˜„์žฌ ์‹ ๊ทœ๋กœ ์ƒ์„ฑํ•œ ํ—ˆ๊น…ํŽ˜์ด์Šค ๊ณ„์ •์— ๋Œ€ํ•ด์„œ๋Š” ๋ผ์ด์„ผ์Šค ๋™์˜๋ฅผ ์š”๊ตฌํ•˜์ง€ ์•Š๋Š” ๊ฒƒ์œผ๋กœ ๋ณด์ž…๋‹ˆ๋‹ค!

from huggingface_hub import login
from diffusers import DiffusionPipeline
import torch

# first we need to login with our access token
login()

# Now we can download the pipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

๋‹ค์Œ์œผ๋กœ, GPU๋กœ ์ด๋™ํ•ฉ๋‹ˆ๋‹ค.

pipeline.to("cuda")

์Šค์ผ€์ค„๋Ÿฌ ์•ก์„ธ์Šค

์Šค์ผ€์ค„๋Ÿฌ๋Š” ์–ธ์ œ๋‚˜ ํŒŒ์ดํ”„๋ผ์ธ์˜ ์ปดํฌ๋„ŒํŠธ๋กœ์„œ ์กด์žฌํ•˜๋ฉฐ, ์ผ๋ฐ˜์ ์œผ๋กœ ํŒŒ์ดํ”„๋ผ์ธ ์ธ์Šคํ„ด์Šค ๋‚ด์— scheduler๋ผ๋Š” ์ด๋ฆ„์˜ ์†์„ฑ(property)์œผ๋กœ ์ •์˜๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

pipeline.scheduler

Output:

PNDMScheduler {
  "_class_name": "PNDMScheduler",
  "_diffusers_version": "0.8.0.dev0",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "num_train_timesteps": 1000,
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "trained_betas": null
}

์ถœ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ํ†ตํ•ด, ์šฐ๋ฆฌ๋Š” ํ•ด๋‹น ์Šค์ผ€์ค„๋Ÿฌ๊ฐ€ [PNDMScheduler]์˜ ์ธ์Šคํ„ด์Šค๋ผ๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด์ œ [PNDMScheduler]์™€ ๋‹ค๋ฅธ ์Šค์ผ€์ค„๋Ÿฌ๋“ค์˜ ์„ฑ๋Šฅ์„ ๋น„๊ตํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ๋จผ์ € ํ…Œ์ŠคํŠธ์— ์‚ฌ์šฉํ•  ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ •์˜ํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."

๋‹ค์Œ์œผ๋กœ ์œ ์‚ฌํ•œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ๋ณด์žฅํ•˜๊ธฐ ์œ„ํ•ด์„œ, ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๋žœ๋ค์‹œ๋“œ๋ฅผ ๊ณ ์ •ํ•ด์ฃผ๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
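As a side note, `torch.Generator` is stateful: each draw advances its internal state, so two pipeline calls that share one generator start from different initial noise. A small CPU-only sketch (illustrative, separate from the GPU workflow above):

```python
import torch

# Each draw advances the generator, so two calls that share one
# generator produce different noise. Re-seeding a fresh generator
# with the same seed reproduces the first draw exactly.
g = torch.Generator(device="cpu").manual_seed(8)
first = torch.randn(3, generator=g)
second = torch.randn(3, generator=g)   # generator state has moved on

g2 = torch.Generator(device="cpu").manual_seed(8)
repeat = torch.randn(3, generator=g2)  # same seed -> same draw as `first`

print(torch.equal(first, repeat))   # True
print(torch.equal(first, second))   # False
```

This is why the snippets in this guide recreate the generator with `manual_seed(8)` before every pipeline call.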



์Šค์ผ€์ค„๋Ÿฌ ๊ต์ฒดํ•˜๊ธฐ

๋‹ค์Œ์œผ๋กœ ํŒŒ์ดํ”„๋ผ์ธ์˜ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋‹ค๋ฅธ ์Šค์ผ€์ค„๋Ÿฌ๋กœ ๊ต์ฒดํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด ์•Œ์•„๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๋ชจ๋“  ์Šค์ผ€์ค„๋Ÿฌ๋Š” [SchedulerMixin.compatibles]๋ผ๋Š” ์†์„ฑ(property)์„ ๊ฐ–๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ํ•ด๋‹น ์†์„ฑ์€ ํ˜ธํ™˜ ๊ฐ€๋Šฅํ•œ ์Šค์ผ€์ค„๋Ÿฌ๋“ค์— ๋Œ€ํ•œ ์ •๋ณด๋ฅผ ๋‹ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

pipeline.scheduler.compatibles

Output:

[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
 diffusers.schedulers.scheduling_ddim.DDIMScheduler,
 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
 diffusers.schedulers.scheduling_pndm.PNDMScheduler,
 diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
 diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]

ํ˜ธํ™˜๋˜๋Š” ์Šค์ผ€์ค„๋Ÿฌ๋“ค์„ ์‚ดํŽด๋ณด๋ฉด ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • [LMSDiscreteScheduler],
  • [DDIMScheduler],
  • [DPMSolverMultistepScheduler],
  • [EulerDiscreteScheduler],
  • [PNDMScheduler],
  • [DDPMScheduler],
  • [EulerAncestralDiscreteScheduler].

์•ž์„œ ์ •์˜ํ–ˆ๋˜ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์‚ฌ์šฉํ•ด์„œ ๊ฐ๊ฐ์˜ ์Šค์ผ€์ค„๋Ÿฌ๋“ค์„ ๋น„๊ตํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

๋จผ์ € ํŒŒ์ดํ”„๋ผ์ธ ์•ˆ์˜ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋ฐ”๊พธ๊ธฐ ์œ„ํ•ด [ConfigMixin.config] ์†์„ฑ๊ณผ [ConfigMixin.from_config] ๋ฉ”์„œ๋“œ๋ฅผ ํ™œ์šฉํ•ด๋ณด๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.

pipeline.scheduler.config

Output:

FrozenDict([('num_train_timesteps', 1000),
            ('beta_start', 0.00085),
            ('beta_end', 0.012),
            ('beta_schedule', 'scaled_linear'),
            ('trained_betas', None),
            ('skip_prk_steps', True),
            ('set_alpha_to_one', False),
            ('steps_offset', 1),
            ('_class_name', 'PNDMScheduler'),
            ('_diffusers_version', '0.8.0.dev0'),
            ('clip_sample', False)])

๊ธฐ์กด ์Šค์ผ€์ค„๋Ÿฌ์˜ config๋ฅผ ํ˜ธํ™˜ ๊ฐ€๋Šฅํ•œ ๋‹ค๋ฅธ ์Šค์ผ€์ค„๋Ÿฌ์— ์ด์‹ํ•˜๋Š” ๊ฒƒ ์—ญ์‹œ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

๋‹ค์Œ ์˜ˆ์‹œ๋Š” ๊ธฐ์กด ์Šค์ผ€์ค„๋Ÿฌ(pipeline.scheduler)๋ฅผ ๋‹ค๋ฅธ ์ข…๋ฅ˜์˜ ์Šค์ผ€์ค„๋Ÿฌ(DDIMScheduler)๋กœ ๋ฐ”๊พธ๋Š” ์ฝ”๋“œ์ž…๋‹ˆ๋‹ค. ๊ธฐ์กด ์Šค์ผ€์ค„๋Ÿฌ๊ฐ€ ๊ฐ–๊ณ  ์žˆ๋˜ config๋ฅผ .from_config ๋ฉ”์„œ๋“œ์˜ ์ธ์ž๋กœ ์ „๋‹ฌํ•˜๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

from diffusers import DDIMScheduler

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

์ด์ œ ํŒŒ์ดํ”„๋ผ์ธ์„ ์‹คํ–‰ํ•ด์„œ ๋‘ ์Šค์ผ€์ค„๋Ÿฌ ์‚ฌ์ด์˜ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€์˜ ํ€„๋ฆฌํ‹ฐ๋ฅผ ๋น„๊ตํ•ด๋ด…์‹œ๋‹ค.

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image



์Šค์ผ€์ค„๋Ÿฌ๋“ค ๋น„๊ตํ•ด๋ณด๊ธฐ

์ง€๊ธˆ๊นŒ์ง€๋Š” [PNDMScheduler]์™€ [DDIMScheduler] ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์‹คํ–‰ํ•ด๋ณด์•˜์Šต๋‹ˆ๋‹ค. ์•„์ง ๋น„๊ตํ•ด๋ณผ ์Šค์ผ€์ค„๋Ÿฌ๋“ค์ด ๋” ๋งŽ์ด ๋‚จ์•„์žˆ์œผ๋‹ˆ ๊ณ„์† ๋น„๊ตํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

[LMSDiscreteScheduler]์„ ์ผ๋ฐ˜์ ์œผ๋กœ ๋” ์ข‹์€ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค.

from diffusers import LMSDiscreteScheduler

pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image



[EulerDiscreteScheduler]์™€ [EulerAncestralDiscreteScheduler] ๊ณ ์ž‘ 30๋ฒˆ์˜ inference step๋งŒ์œผ๋กœ๋„ ๋†’์€ ํ€„๋ฆฌํ‹ฐ์˜ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

from diffusers import EulerDiscreteScheduler

pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image



from diffusers import EulerAncestralDiscreteScheduler

pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image



์ง€๊ธˆ ์ด ๋ฌธ์„œ๋ฅผ ์ž‘์„ฑํ•˜๋Š” ํ˜„์‹œ์  ๊ธฐ์ค€์—์„ , [DPMSolverMultistepScheduler]๊ฐ€ ์‹œ๊ฐ„ ๋Œ€๋น„ ๊ฐ€์žฅ ์ข‹์€ ํ’ˆ์งˆ์˜ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค. 20๋ฒˆ ์ •๋„์˜ ์Šคํ…๋งŒ์œผ๋กœ๋„ ์‹คํ–‰๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

from diffusers import DPMSolverMultistepScheduler

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
image



๋ณด์‹œ๋‹ค์‹œํ”ผ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๋“ค์€ ๋งค์šฐ ๋น„์Šทํ•˜๊ณ , ๋น„์Šทํ•œ ํ€„๋ฆฌํ‹ฐ๋ฅผ ๋ณด์ด๋Š” ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค. ์‹ค์ œ๋กœ ์–ด๋–ค ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์„ ํƒํ•  ๊ฒƒ์ธ๊ฐ€๋Š” ์ข…์ข… ํŠน์ • ์ด์šฉ ์‚ฌ๋ก€์— ๊ธฐ๋ฐ˜ํ•ด์„œ ๊ฒฐ์ •๋˜๊ณค ํ•ฉ๋‹ˆ๋‹ค. ๊ฒฐ๊ตญ ์—ฌ๋Ÿฌ ์ข…๋ฅ˜์˜ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์ง์ ‘ ์‹คํ–‰์‹œ์ผœ๋ณด๊ณ  ๋ˆˆ์œผ๋กœ ์ง์ ‘ ๋น„๊ตํ•ด์„œ ํŒ๋‹จํ•˜๋Š” ๊ฒŒ ์ข‹์€ ์„ ํƒ์ผ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.

Flax์—์„œ ์Šค์ผ€์ค„๋Ÿฌ ๊ต์ฒดํ•˜๊ธฐ

JAX/Flax ์‚ฌ์šฉ์ž์ธ ๊ฒฝ์šฐ ๊ธฐ๋ณธ ํŒŒ์ดํ”„๋ผ์ธ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ๋ณ€๊ฒฝํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‹ค์Œ์€ Flax Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ๊ณผ ์ดˆ๊ณ ์† DDPM-Solver++ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ถ”๋ก ์„ ์‹คํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ์˜ˆ์‹œ์ž…๋‹ˆ๋‹ค .

import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler

model_id = "runwayml/stable-diffusion-v1-5"
scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler"
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    revision="bf16",
    dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state

# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
prompt = "a photo of an astronaut riding a horse on mars"
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 25

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

๋‹ค์Œ Flax ์Šค์ผ€์ค„๋Ÿฌ๋Š” ์•„์ง Flax Stable Diffusion ํŒŒ์ดํ”„๋ผ์ธ๊ณผ ํ˜ธํ™˜๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

  • FlaxLMSDiscreteScheduler
  • FlaxDDPMScheduler