JAX / Flaxμμμ 𧨠Stable Diffusion!
[[open-in-colab]]
π€ Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) λ λ²μ 0.5.1λΆν° Flaxλ₯Ό μ§μν©λλ€! μ΄λ₯Ό ν΅ν΄ Colab, Kaggle, Google Cloud Platformμμ μ¬μ©ν μ μλ κ²μ²λΌ Google TPUμμ μ΄κ³ μ μΆλ‘ μ΄ κ°λ₯ν©λλ€.
μ΄ λ ΈνΈλΆμ JAX / Flaxλ₯Ό μ¬μ©ν΄ μΆλ‘ μ μ€ννλ λ°©λ²μ 보μ¬μ€λλ€. Stable Diffusionμ μλ λ°©μμ λν μμΈν λ΄μ©μ μνκ±°λ GPUμμ μ€ννλ €λ©΄ μ΄ [λ ΈνΈλΆ] ](https://huggingface.co/docs/diffusers/stable_diffusion)μ μ°Έμ‘°νμΈμ.
λ¨Όμ , TPU λ°±μλλ₯Ό μ¬μ©νκ³ μλμ§ νμΈν©λλ€. Colabμμ μ΄ λ ΈνΈλΆμ μ€ννλ κ²½μ°, λ©λ΄μμ λ°νμμ μ νν λ€μ "λ°νμ μ ν λ³κ²½" μ΅μ μ μ νν λ€μ νλμ¨μ΄ κ°μκΈ° μ€μ μμ TPUλ₯Ό μ νν©λλ€.
JAXλ TPU μ μ©μ μλμ§λ§ κ° TPU μλ²μλ 8κ°μ TPU κ°μκΈ°κ° λ³λ ¬λ‘ μλνκΈ° λλ¬Έμ ν΄λΉ νλμ¨μ΄μμ λ λΉμ λ°νλ€λ μ μ μμλμΈμ.
Setup
λ¨Όμ diffusersκ° μ€μΉλμ΄ μλμ§ νμΈν©λλ€.
!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
!pip install diffusers
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
"TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
Found 8 JAX devices of type Cloud TPU.
κ·Έλ° λ€μ λͺ¨λ dependenciesλ₯Ό κ°μ Έμ΅λλ€.
import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
λͺ¨λΈ λΆλ¬μ€κΈ°
TPU μ₯μΉλ ν¨μ¨μ μΈ half-float μ νμΈ bfloat16μ μ§μν©λλ€. ν μ€νΈμλ μ΄ μ νμ μ¬μ©νμ§λ§ λμ float32λ₯Ό μ¬μ©νμ¬ μ 체 μ λ°λ(full precision)λ₯Ό μ¬μ©ν μλ μμ΅λλ€.
dtype = jnp.bfloat16
Flaxλ ν¨μν νλ μμν¬μ΄λ―λ‘ λͺ¨λΈμ 무μν(stateless)νμ΄λ©° 맀κ°λ³μλ λͺ¨λΈ μΈλΆμ μ μ₯λ©λλ€. μ¬μ νμ΅λ Flax νμ΄νλΌμΈμ λΆλ¬μ€λ©΄ νμ΄νλΌμΈ μ체μ λͺ¨λΈ κ°μ€μΉ(λλ 맀κ°λ³μ)κ° λͺ¨λ λ°νλ©λλ€. μ ν¬λ bf16 λ²μ μ κ°μ€μΉλ₯Ό μ¬μ©νκ³ μμΌλ―λ‘ μ ν κ²½κ³ κ° νμλμ§λ§ 무μν΄λ λ©λλ€.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
dtype=dtype,
)
μΆλ‘
TPUμλ μΌλ°μ μΌλ‘ 8κ°μ λλ°μ΄μ€κ° λ³λ ¬λ‘ μλνλ―λ‘ λ³΄μ ν λλ°μ΄μ€ μλ§νΌ ν둬ννΈλ₯Ό 볡μ ν©λλ€. κ·Έλ° λ€μ κ°κ° νλμ μ΄λ―Έμ§ μμ±μ λ΄λΉνλ 8κ°μ λλ°μ΄μ€μμ ν λ²μ μΆλ‘ μ μνν©λλ€. λ°λΌμ νλμ μΉ©μ΄ νλμ μ΄λ―Έμ§λ₯Ό μμ±νλ λ° κ±Έλ¦¬λ μκ°κ³Ό λμΌν μκ°μ 8κ°μ μ΄λ―Έμ§λ₯Ό μ»μ μ μμ΅λλ€.
ν둬ννΈλ₯Ό 볡μ νκ³ λλ©΄ νμ΄νλΌμΈμ prepare_inputs
ν¨μλ₯Ό νΈμΆνμ¬ ν ν°νλ ν
μ€νΈ IDλ₯Ό μ»μ΅λλ€. ν ν°νλ ν
μ€νΈμ κΈΈμ΄λ κΈ°λ³Έ CLIP ν
μ€νΈ λͺ¨λΈμ ꡬμ±μ λ°λΌ 77ν ν°μΌλ‘ μ€μ λ©λλ€.
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
(8, 77)
볡μ¬(Replication) λ° μ λ ¬ν
λͺ¨λΈ 맀κ°λ³μμ μ
λ ₯κ°μ μ°λ¦¬κ° 보μ ν 8κ°μ λ³λ ¬ μ₯μΉμ 볡μ¬(Replication)λμ΄μΌ ν©λλ€. 맀κ°λ³μ λμ
λ리λ flax.jax_utils.replicate
(λμ
λ리λ₯Ό μννλ©° κ°μ€μΉμ λͺ¨μμ λ³κ²½νμ¬ 8λ² λ°λ³΅νλ ν¨μ)λ₯Ό μ¬μ©νμ¬ λ³΅μ¬λ©λλ€. λ°°μ΄μ shard
λ₯Ό μ¬μ©νμ¬ λ³΅μ λ©λλ€.
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
(8, 1, 77)
μ΄ shapeμ 8κ°μ λλ°μ΄μ€ κ°κ°μ΄ shape (1, 77)
μ jnp λ°°μ΄μ μ
λ ₯κ°μΌλ‘ λ°λλ€λ μλ―Έμ
λλ€. μ¦ 1μ λλ°μ΄μ€λΉ batch(λ°°μΉ) ν¬κΈ°μ
λλ€. λ©λͺ¨λ¦¬κ° μΆ©λΆν TPUμμλ ν λ²μ μ¬λ¬ μ΄λ―Έμ§(μΉ©λΉ)λ₯Ό μμ±νλ €λ κ²½μ° 1λ³΄λ€ ν΄ μ μμ΅λλ€.
μ΄λ―Έμ§λ₯Ό μμ±ν μ€λΉκ° κ±°μ μλ£λμμ΅λλ€! μ΄μ μμ± ν¨μμ μ λ¬ν λμ μμ±κΈ°λ§ λ§λ€λ©΄ λ©λλ€. μ΄κ²μ λμλ₯Ό λ€λ£¨λ λͺ¨λ ν¨μμ λμ μμ±κΈ°κ° μμ΄μΌ νλ€λ, λμμ λν΄ λ§€μ° μ§μ§νκ³ λ λ¨μ μΈ Flaxμ νμ€ μ μ°¨μ λλ€. μ΄λ κ² νλ©΄ μ¬λ¬ λΆμ°λ κΈ°κΈ°μμ νλ ¨ν λμλ μ¬νμ±μ΄ 보μ₯λ©λλ€.
μλ ν¬νΌ ν¨μλ μλλ₯Ό μ¬μ©νμ¬ λμ μμ±κΈ°λ₯Ό μ΄κΈ°νν©λλ€. λμΌν μλλ₯Ό μ¬μ©νλ ν μ νν λμΌν κ²°κ³Όλ₯Ό μ»μ μ μμ΅λλ€. λμ€μ λ ΈνΈλΆμμ κ²°κ³Όλ₯Ό νμν λμ λ€λ₯Έ μλλ₯Ό μμ λ‘κ² μ¬μ©νμΈμ.
def create_key(seed=0):
return jax.random.PRNGKey(seed)
rngλ₯Ό μ»μ λ€μ 8λ² 'λΆν 'νμ¬ κ° λλ°μ΄μ€κ° λ€λ₯Έ μ λλ μ΄ν°λ₯Ό μμ νλλ‘ ν©λλ€. λ°λΌμ κ° λλ°μ΄μ€λ§λ€ λ€λ₯Έ μ΄λ―Έμ§κ° μμ±λλ©° μ 체 νλ‘μΈμ€λ₯Ό μ¬νν μ μμ΅λλ€.
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
JAX μ½λλ λ§€μ° λΉ λ₯΄κ² μ€νλλ ν¨μ¨μ μΈ ννμΌλ‘ μ»΄νμΌν μ μμ΅λλ€. νμ§λ§ νμ νΈμΆμμ λͺ¨λ μ λ ₯μ΄ λμΌν λͺ¨μμ κ°λλ‘ ν΄μΌ νλ©°, κ·Έλ μ§ μμΌλ©΄ JAXκ° μ½λλ₯Ό λ€μ μ»΄νμΌν΄μΌ νλ―λ‘ μ΅μ νλ μλλ₯Ό νμ©ν μ μμ΅λλ€.
jit = True
λ₯Ό μΈμλ‘ μ λ¬νλ©΄ Flax νμ΄νλΌμΈμ΄ μ½λλ₯Ό μ»΄νμΌν μ μμ΅λλ€. λν λͺ¨λΈμ΄ μ¬μ© κ°λ₯ν 8κ°μ λλ°μ΄μ€μμ λ³λ ¬λ‘ μ€νλλλ‘ λ³΄μ₯ν©λλ€.
λ€μ μ μ μ²μ μ€ννλ©΄ μ»΄νμΌνλ λ° μκ°μ΄ μ€λ 걸리μ§λ§ μ΄ν νΈμΆ(μ λ ₯μ΄ λ€λ₯Έ κ²½μ°μλ)μ ν¨μ¬ λΉ¨λΌμ§λλ€. μλ₯Ό λ€μ΄, ν μ€νΈνμ λ TPU v2-8μμ μ»΄νμΌνλ λ° 1λΆ μ΄μ 걸리μ§λ§ μ΄ν μΆλ‘ μ€νμλ μ½ 7μ΄κ° 걸립λλ€.
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s
λ°νλ λ°°μ΄μ shapeμ (8, 1, 512, 512, 3)
μ
λλ€. μ΄λ₯Ό μ¬κ΅¬μ±νμ¬ λ λ²μ§Έ μ°¨μμ μ κ±°νκ³ 512 Γ 512 Γ 3μ μ΄λ―Έμ§ 8κ°λ₯Ό μ»μ λ€μ PILλ‘ λ³νν©λλ€.
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
μκ°ν
μ΄λ―Έμ§λ₯Ό 그리λμ νμνλ λμ°λ―Έ ν¨μλ₯Ό λ§λ€μ΄ λ³΄κ² μ΅λλ€.
def image_grid(imgs, rows, cols):
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
image_grid(images, 2, 4)
λ€λ₯Έ ν둬ννΈ μ¬μ©
λͺ¨λ λλ°μ΄μ€μμ λμΌν ν둬ννΈλ₯Ό 볡μ ν νμλ μμ΅λλ€. ν둬ννΈ 2κ°λ₯Ό κ°κ° 4λ²μ© μμ±νκ±°λ ν λ²μ 8κ°μ μλ‘ λ€λ₯Έ ν둬ννΈλ₯Ό μμ±νλ λ± μνλ κ²μ 무μμ΄λ ν μ μμ΅λλ€. νλ² ν΄λ³΄μΈμ!
λ¨Όμ μ λ ₯ μ€λΉ μ½λλ₯Ό νΈλ¦¬ν ν¨μλ‘ λ¦¬ν©ν°λ§νκ² μ΅λλ€:
prompts = [
"Labrador in the style of Hokusai",
"Painting of a squirrel skating in New York",
"HAL-9000 in the style of Van Gogh",
"Times Square under water, with fish and a dolphin swimming around",
"Ancient Roman fresco showing a man working on his laptop",
"Close-up photograph of young black woman against urban background, high quality, bokeh",
"Armchair in the shape of an avocado",
"Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)
λ³λ ¬ν(parallelization)λ μ΄λ»κ² μλνλκ°?
μμ diffusers
Flax νμ΄νλΌμΈμ΄ λͺ¨λΈμ μλμΌλ‘ μ»΄νμΌνκ³ μ¬μ© κ°λ₯ν λͺ¨λ κΈ°κΈ°μμ λ³λ ¬λ‘ μ€ννλ€κ³ λ§μλλ Έμ΅λλ€. μ΄μ κ·Έ νλ‘μΈμ€λ₯Ό κ°λ΅νκ² μ΄ν΄λ³΄κ³ μλ λ°©μμ 보μ¬λλ¦¬κ² μ΅λλ€.
JAX λ³λ ¬νλ μ¬λ¬ κ°μ§ λ°©λ²μΌλ‘ μνν μ μμ΅λλ€. κ°μ₯ μ¬μ΄ λ°©λ²μ jax.pmap ν¨μλ₯Ό μ¬μ©νμ¬ λ¨μΌ νλ‘κ·Έλ¨, λ€μ€ λ°μ΄ν°(SPMD) λ³λ ¬νλ₯Ό λ¬μ±νλ κ²μ
λλ€. μ¦, λμΌν μ½λμ 볡μ¬λ³Έμ κ°κ° λ€λ₯Έ λ°μ΄ν° μ
λ ₯μ λν΄ μ¬λ¬ κ° μ€ννλ κ²μ
λλ€. λ μ κ΅ν μ κ·Ό λ°©μλ κ°λ₯νλ―λ‘ κ΄μ¬μ΄ μμΌμλ€λ©΄ JAX λ¬Έμμ pjit
νμ΄μ§μμ μ΄ μ£Όμ λ₯Ό μ΄ν΄λ³΄μκΈ° λ°λλλ€!
jax.pmap
μ λ κ°μ§ κΈ°λ₯μ μνν©λλ€:
jax.jit()
λ₯Ό νΈμΆν κ²μ²λΌ μ½λλ₯Ό μ»΄νμΌ(λλjit
)ν©λλ€. μ΄ μμ μpmap
μ νΈμΆν λκ° μλλΌ pmapped ν¨μκ° μ²μ νΈμΆλ λ μνλ©λλ€.- μ»΄νμΌλ μ½λκ° μ¬μ© κ°λ₯ν λͺ¨λ κΈ°κΈ°μμ λ³λ ¬λ‘ μ€νλλλ‘ ν©λλ€.
μλ λ°©μμ 보μ¬λ리기 μν΄ μ΄λ―Έμ§ μμ±μ μ€ννλ λΉκ³΅κ° λ©μλμΈ νμ΄νλΌμΈμ _generate
λ©μλλ₯Ό pmap
ν©λλ€. μ΄ λ©μλλ ν₯ν Diffusers
릴리μ€μμ μ΄λ¦μ΄ λ³κ²½λκ±°λ μ κ±°λ μ μλ€λ μ μ μ μνμΈμ.
p_generate = pmap(pipeline._generate)
pmap
μ μ¬μ©ν ν μ€λΉλ ν¨μ p_generate
λ κ°λ
μ μΌλ‘ λ€μμ μνν©λλ€:
- κ° μ₯μΉμμ κΈ°λ³Έ ν¨μ
pipeline._generate
μ 볡μ¬λ³Έμ νΈμΆν©λλ€. - κ° μ₯μΉμ μ
λ ₯ μΈμμ λ€λ₯Έ λΆλΆμ 보λ
λλ€. μ΄κ²μ΄ λ°λ‘ μ€λ©μ΄ μ¬μ©λλ μ΄μ μ
λλ€. μ΄ κ²½μ°
prompt_ids
μ shapeμ(8, 1, 77, 768)
μ λλ€. μ΄ λ°°μ΄μ 8κ°λ‘ λΆν λκ³_generate
μ κ° λ³΅μ¬λ³Έμ(1, 77, 768)
μ shapeμ κ°μ§ μ λ ₯μ λ°κ² λ©λλ€.
λ³λ ¬λ‘ νΈμΆλλ€λ μ¬μ€μ μμ ν 무μνκ³ _generate
λ₯Ό μ½λ©ν μ μμ΅λλ€. batch(λ°°μΉ) ν¬κΈ°(μ΄ μμ μμλ 1
)μ μ½λμ μ ν©ν μ°¨μλ§ μ κ²½ μ°λ©΄ λλ©°, λ³λ ¬λ‘ μλνκΈ° μν΄ μ무κ²λ λ³κ²½ν νμκ° μμ΅λλ€.
νμ΄νλΌμΈ νΈμΆμ μ¬μ©ν λμ λ§μ°¬κ°μ§λ‘, λ€μ μ μ μ²μ μ€νν λλ μκ°μ΄ 걸리μ§λ§ κ·Έ μ΄νμλ ν¨μ¬ λΉ¨λΌμ§λλ€.
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
images.shape
(8, 1, 512, 512, 3)
JAXλ λΉλκΈ° λμ€ν¨μΉλ₯Ό μ¬μ©νκ³ κ°λ₯ν ν 빨리 μ μ΄κΆμ Python 루νμ λ°ννκΈ° λλ¬Έμ μΆλ‘ μκ°μ μ ννκ² μΈ‘μ νκΈ° μν΄ block_until_ready()
λ₯Ό μ¬μ©ν©λλ€. μμ§ κ΅¬μ²΄νλμ§ μμ κ³μ° κ²°κ³Όλ₯Ό μ¬μ©νλ €λ κ²½μ° μλμΌλ‘ μ°¨λ¨μ΄ μνλλ―λ‘ μ½λμμ μ΄ ν¨μλ₯Ό μ¬μ©ν νμκ° μμ΅λλ€.