์ค์ผ์ค๋ฌ
diffusion ํ์ดํ๋ผ์ธ์ diffusion ๋ชจ๋ธ, ์ค์ผ์ค๋ฌ ๋ฑ์ ์ปดํฌ๋ํธ๋ค๋ก ๊ตฌ์ฑ๋ฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ํ์ดํ๋ผ์ธ ์์ ์ผ๋ถ ์ปดํฌ๋ํธ๋ฅผ ๋ค๋ฅธ ์ปดํฌ๋ํธ๋ก ๊ต์ฒดํ๋ ์์ ์ปค์คํฐ๋ง์ด์ง ์ญ์ ๊ฐ๋ฅํฉ๋๋ค. ์ด์ ๊ฐ์ ์ปดํฌ๋ํธ ์ปค์คํฐ๋ง์ด์ง์ ๊ฐ์ฅ ๋ํ์ ์ธ ์์๊ฐ ๋ฐ๋ก ์ค์ผ์ค๋ฌ๋ฅผ ๊ต์ฒดํ๋ ๊ฒ์ ๋๋ค.
์ค์ผ์ฅด๋ฌ๋ ๋ค์๊ณผ ๊ฐ์ด diffusion ์์คํ ์ ์ ๋ฐ์ ์ธ ๋๋ ธ์ด์ง ํ๋ก์ธ์ค๋ฅผ ์ ์ํฉ๋๋ค.
- ๋๋ ธ์ด์ง ์คํ ์ ์ผ๋ง๋ ๊ฐ์ ธ๊ฐ์ผ ํ ๊น?
- ํ๋ฅ ์ ์ผ๋ก(stochastic) ํน์ ํ์ ์ ์ผ๋ก(deterministic)?
- ๋๋ ธ์ด์ง ๋ ์ํ์ ์ฐพ์๋ด๊ธฐ ์ํด ์ด๋ค ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํด์ผ ํ ๊น?
์ด๋ฌํ ํ๋ก์ธ์ค๋ ๋ค์ ๋ํดํ๊ณ , ๋๋ ธ์ด์ง ์๋์ ๋๋ ธ์ด์ง ํ๋ฆฌํฐ ์ฌ์ด์ ํธ๋ ์ด๋ ์คํ๋ฅผ ์ ์ํด์ผ ํ๋ ๋ฌธ์ ๊ฐ ๋ ์ ์์ต๋๋ค. ์ฃผ์ด์ง ํ์ดํ๋ผ์ธ์ ์ด๋ค ์ค์ผ์ค๋ฌ๊ฐ ๊ฐ์ฅ ์ ํฉํ์ง๋ฅผ ์ ๋์ ์ผ๋ก ํ๋จํ๋ ๊ฒ์ ๋งค์ฐ ์ด๋ ค์ด ์ผ์ ๋๋ค. ์ด๋ก ์ธํด ์ผ๋จ ํด๋น ์ค์ผ์ค๋ฌ๋ฅผ ์ง์ ์ฌ์ฉํ์ฌ, ์์ฑ๋๋ ์ด๋ฏธ์ง๋ฅผ ์ง์ ๋์ผ๋ก ๋ณด๋ฉฐ, ์ ์ฑ์ ์ผ๋ก ์ฑ๋ฅ์ ํ๋จํด๋ณด๋ ๊ฒ์ด ์ถ์ฒ๋๊ณค ํฉ๋๋ค.
ํ์ดํ๋ผ์ธ ๋ถ๋ฌ์ค๊ธฐ
๋จผ์ ์คํ ์ด๋ธ diffusion ํ์ดํ๋ผ์ธ์ ๋ถ๋ฌ์ค๋๋ก ํด๋ณด๊ฒ ์ต๋๋ค. ๋ฌผ๋ก ์คํ ์ด๋ธ diffusion์ ์ฌ์ฉํ๊ธฐ ์ํด์๋, ํ๊น ํ์ด์ค ํ๋ธ์ ๋ฑ๋ก๋ ์ฌ์ฉ์์ฌ์ผ ํ๋ฉฐ, ๊ด๋ จ ๋ผ์ด์ผ์ค์ ๋์ํด์ผ ํ๋ค๋ ์ ์ ์์ง ๋ง์์ฃผ์ธ์.
์ญ์ ์ฃผ: ๋ค๋ง, ํ์ฌ ์ ๊ท๋ก ์์ฑํ ํ๊น ํ์ด์ค ๊ณ์ ์ ๋ํด์๋ ๋ผ์ด์ผ์ค ๋์๋ฅผ ์๊ตฌํ์ง ์๋ ๊ฒ์ผ๋ก ๋ณด์ ๋๋ค!
from huggingface_hub import login
from diffusers import DiffusionPipeline
import torch
# first we need to login with our access token
login()
# Now we can download the pipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
๋ค์์ผ๋ก, GPU๋ก ์ด๋ํฉ๋๋ค.
pipeline.to("cuda")
์ค์ผ์ค๋ฌ ์ก์ธ์ค
์ค์ผ์ค๋ฌ๋ ์ธ์ ๋ ํ์ดํ๋ผ์ธ์ ์ปดํฌ๋ํธ๋ก์ ์กด์ฌํ๋ฉฐ, ์ผ๋ฐ์ ์ผ๋ก ํ์ดํ๋ผ์ธ ์ธ์คํด์ค ๋ด์ scheduler
๋ผ๋ ์ด๋ฆ์ ์์ฑ(property)์ผ๋ก ์ ์๋์ด ์์ต๋๋ค.
pipeline.scheduler
Output:
PNDMScheduler {
"_class_name": "PNDMScheduler",
"_diffusers_version": "0.8.0.dev0",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
"trained_betas": null
}
์ถ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ํตํด, ์ฐ๋ฆฌ๋ ํด๋น ์ค์ผ์ค๋ฌ๊ฐ [PNDMScheduler
]์ ์ธ์คํด์ค๋ผ๋ ๊ฒ์ ์ ์ ์์ต๋๋ค. ์ด์ [PNDMScheduler
]์ ๋ค๋ฅธ ์ค์ผ์ค๋ฌ๋ค์ ์ฑ๋ฅ์ ๋น๊ตํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค. ๋จผ์ ํ
์คํธ์ ์ฌ์ฉํ ํ๋กฌํํธ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์ ์ํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
๋ค์์ผ๋ก ์ ์ฌํ ์ด๋ฏธ์ง ์์ฑ์ ๋ณด์ฅํ๊ธฐ ์ํด์, ๋ค์๊ณผ ๊ฐ์ด ๋๋ค์๋๋ฅผ ๊ณ ์ ํด์ฃผ๋๋ก ํ๊ฒ ์ต๋๋ค.
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
์ค์ผ์ค๋ฌ ๊ต์ฒดํ๊ธฐ
๋ค์์ผ๋ก ํ์ดํ๋ผ์ธ์ ์ค์ผ์ค๋ฌ๋ฅผ ๋ค๋ฅธ ์ค์ผ์ค๋ฌ๋ก ๊ต์ฒดํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค. ๋ชจ๋ ์ค์ผ์ค๋ฌ๋ [SchedulerMixin.compatibles
]๋ผ๋ ์์ฑ(property)์ ๊ฐ๊ณ ์์ต๋๋ค. ํด๋น ์์ฑ์ ํธํ ๊ฐ๋ฅํ ์ค์ผ์ค๋ฌ๋ค์ ๋ํ ์ ๋ณด๋ฅผ ๋ด๊ณ ์์ต๋๋ค.
pipeline.scheduler.compatibles
Output:
[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
ํธํ๋๋ ์ค์ผ์ค๋ฌ๋ค์ ์ดํด๋ณด๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
- [
LMSDiscreteScheduler
], - [
DDIMScheduler
], - [
DPMSolverMultistepScheduler
], - [
EulerDiscreteScheduler
], - [
PNDMScheduler
], - [
DDPMScheduler
], - [
EulerAncestralDiscreteScheduler
].
์์ ์ ์ํ๋ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํด์ ๊ฐ๊ฐ์ ์ค์ผ์ค๋ฌ๋ค์ ๋น๊ตํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
๋จผ์ ํ์ดํ๋ผ์ธ ์์ ์ค์ผ์ค๋ฌ๋ฅผ ๋ฐ๊พธ๊ธฐ ์ํด [ConfigMixin.config
] ์์ฑ๊ณผ [ConfigMixin.from_config
] ๋ฉ์๋๋ฅผ ํ์ฉํด๋ณด๋ ค๊ณ ํฉ๋๋ค.
pipeline.scheduler.config
Output:
FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('steps_offset', 1),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.8.0.dev0'),
('clip_sample', False)])
๊ธฐ์กด ์ค์ผ์ค๋ฌ์ config๋ฅผ ํธํ ๊ฐ๋ฅํ ๋ค๋ฅธ ์ค์ผ์ค๋ฌ์ ์ด์ํ๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค.
๋ค์ ์์๋ ๊ธฐ์กด ์ค์ผ์ค๋ฌ(pipeline.scheduler
)๋ฅผ ๋ค๋ฅธ ์ข
๋ฅ์ ์ค์ผ์ค๋ฌ(DDIMScheduler
)๋ก ๋ฐ๊พธ๋ ์ฝ๋์
๋๋ค. ๊ธฐ์กด ์ค์ผ์ค๋ฌ๊ฐ ๊ฐ๊ณ ์๋ config๋ฅผ .from_config
๋ฉ์๋์ ์ธ์๋ก ์ ๋ฌํ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
from diffusers import DDIMScheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
์ด์ ํ์ดํ๋ผ์ธ์ ์คํํด์ ๋ ์ค์ผ์ค๋ฌ ์ฌ์ด์ ์์ฑ๋ ์ด๋ฏธ์ง์ ํ๋ฆฌํฐ๋ฅผ ๋น๊ตํด๋ด ์๋ค.
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
์ค์ผ์ค๋ฌ๋ค ๋น๊ตํด๋ณด๊ธฐ
์ง๊ธ๊น์ง๋ [PNDMScheduler
]์ [DDIMScheduler
] ์ค์ผ์ค๋ฌ๋ฅผ ์คํํด๋ณด์์ต๋๋ค. ์์ง ๋น๊ตํด๋ณผ ์ค์ผ์ค๋ฌ๋ค์ด ๋ ๋ง์ด ๋จ์์์ผ๋ ๊ณ์ ๋น๊ตํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
[LMSDiscreteScheduler
]์ ์ผ๋ฐ์ ์ผ๋ก ๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค.
from diffusers import LMSDiscreteScheduler
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
[EulerDiscreteScheduler
]์ [EulerAncestralDiscreteScheduler
] ๊ณ ์ 30๋ฒ์ inference step๋ง์ผ๋ก๋ ๋์ ํ๋ฆฌํฐ์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
from diffusers import EulerDiscreteScheduler
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
from diffusers import EulerAncestralDiscreteScheduler
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
์ง๊ธ ์ด ๋ฌธ์๋ฅผ ์์ฑํ๋ ํ์์ ๊ธฐ์ค์์ , [DPMSolverMultistepScheduler
]๊ฐ ์๊ฐ ๋๋น ๊ฐ์ฅ ์ข์ ํ์ง์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ฒ ๊ฐ์ต๋๋ค. 20๋ฒ ์ ๋์ ์คํ
๋ง์ผ๋ก๋ ์คํ๋ ์ ์์ต๋๋ค.
from diffusers import DPMSolverMultistepScheduler
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
image
๋ณด์๋ค์ํผ ์์ฑ๋ ์ด๋ฏธ์ง๋ค์ ๋งค์ฐ ๋น์ทํ๊ณ , ๋น์ทํ ํ๋ฆฌํฐ๋ฅผ ๋ณด์ด๋ ๊ฒ ๊ฐ์ต๋๋ค. ์ค์ ๋ก ์ด๋ค ์ค์ผ์ค๋ฌ๋ฅผ ์ ํํ ๊ฒ์ธ๊ฐ๋ ์ข ์ข ํน์ ์ด์ฉ ์ฌ๋ก์ ๊ธฐ๋ฐํด์ ๊ฒฐ์ ๋๊ณค ํฉ๋๋ค. ๊ฒฐ๊ตญ ์ฌ๋ฌ ์ข ๋ฅ์ ์ค์ผ์ค๋ฌ๋ฅผ ์ง์ ์คํ์์ผ๋ณด๊ณ ๋์ผ๋ก ์ง์ ๋น๊ตํด์ ํ๋จํ๋ ๊ฒ ์ข์ ์ ํ์ผ ๊ฒ ๊ฐ์ต๋๋ค.
Flax์์ ์ค์ผ์ค๋ฌ ๊ต์ฒดํ๊ธฐ
JAX/Flax ์ฌ์ฉ์์ธ ๊ฒฝ์ฐ ๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ ์ค์ผ์ค๋ฌ๋ฅผ ๋ณ๊ฒฝํ ์๋ ์์ต๋๋ค. ๋ค์์ Flax Stable Diffusion ํ์ดํ๋ผ์ธ๊ณผ ์ด๊ณ ์ DDPM-Solver++ ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์คํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์์์ ๋๋ค .
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
model_id = "runwayml/stable-diffusion-v1-5"
scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler"
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
model_id,
scheduler=scheduler,
revision="bf16",
dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state
# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
prompt = "a photo of an astronaut riding a horse on mars"
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 25
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
๋ค์ Flax ์ค์ผ์ค๋ฌ๋ ์์ง Flax Stable Diffusion ํ์ดํ๋ผ์ธ๊ณผ ํธํ๋์ง ์์ต๋๋ค.
FlaxLMSDiscreteScheduler
FlaxDDPMScheduler