File size: 12,756 Bytes
ef4d689 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# JAX / Flaxμμμ 𧨠Stable Diffusion!
[[open-in-colab]]
π€ Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) λ λ²μ 0.5.1λΆν° Flaxλ₯Ό μ§μν©λλ€! μ΄λ₯Ό ν΅ν΄ Colab, Kaggle, Google Cloud Platformμμ μ¬μ©ν μ μλ κ²μ²λΌ Google TPUμμ μ΄κ³ μ μΆλ‘ μ΄ κ°λ₯ν©λλ€.
μ΄ λ
ΈνΈλΆμ JAX / Flaxλ₯Ό μ¬μ©ν΄ μΆλ‘ μ μ€ννλ λ°©λ²μ 보μ¬μ€λλ€. Stable Diffusionμ μλ λ°©μμ λν μμΈν λ΄μ©μ μνκ±°λ GPUμμ μ€ννλ €λ©΄ μ΄ [λ
ΈνΈλΆ] ](https://huggingface.co/docs/diffusers/stable_diffusion)μ μ°Έμ‘°νμΈμ.
λ¨Όμ , TPU λ°±μλλ₯Ό μ¬μ©νκ³ μλμ§ νμΈν©λλ€. Colabμμ μ΄ λ
ΈνΈλΆμ μ€ννλ κ²½μ°, λ©λ΄μμ λ°νμμ μ νν λ€μ "λ°νμ μ ν λ³κ²½" μ΅μ
μ μ νν λ€μ νλμ¨μ΄ κ°μκΈ° μ€μ μμ TPUλ₯Ό μ νν©λλ€.
JAXλ TPU μ μ©μ μλμ§λ§ κ° TPU μλ²μλ 8κ°μ TPU κ°μκΈ°κ° λ³λ ¬λ‘ μλνκΈ° λλ¬Έμ ν΄λΉ νλμ¨μ΄μμ λ λΉμ λ°νλ€λ μ μ μμλμΈμ.
## Setup
λ¨Όμ diffusersκ° μ€μΉλμ΄ μλμ§ νμΈν©λλ€.
```bash
!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
!pip install diffusers
```
```python
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
```
```python
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
"TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
```
```python out
Found 8 JAX devices of type Cloud TPU.
```
κ·Έλ° λ€μ λͺ¨λ dependenciesλ₯Ό κ°μ Έμ΅λλ€.
```python
import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
```
## λͺ¨λΈ λΆλ¬μ€κΈ°
TPU μ₯μΉλ ν¨μ¨μ μΈ half-float μ νμΈ bfloat16μ μ§μν©λλ€. ν
μ€νΈμλ μ΄ μ νμ μ¬μ©νμ§λ§ λμ float32λ₯Ό μ¬μ©νμ¬ μ 체 μ λ°λ(full precision)λ₯Ό μ¬μ©ν μλ μμ΅λλ€.
```python
dtype = jnp.bfloat16
```
Flaxλ ν¨μν νλ μμν¬μ΄λ―λ‘ λͺ¨λΈμ 무μν(stateless)νμ΄λ©° 맀κ°λ³μλ λͺ¨λΈ μΈλΆμ μ μ₯λ©λλ€. μ¬μ νμ΅λ Flax νμ΄νλΌμΈμ λΆλ¬μ€λ©΄ νμ΄νλΌμΈ μ체μ λͺ¨λΈ κ°μ€μΉ(λλ 맀κ°λ³μ)κ° λͺ¨λ λ°νλ©λλ€. μ ν¬λ bf16 λ²μ μ κ°μ€μΉλ₯Ό μ¬μ©νκ³ μμΌλ―λ‘ μ ν κ²½κ³ κ° νμλμ§λ§ 무μν΄λ λ©λλ€.
```python
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
dtype=dtype,
)
```
## μΆλ‘
TPUμλ μΌλ°μ μΌλ‘ 8κ°μ λλ°μ΄μ€κ° λ³λ ¬λ‘ μλνλ―λ‘ λ³΄μ ν λλ°μ΄μ€ μλ§νΌ ν둬ννΈλ₯Ό 볡μ ν©λλ€. κ·Έλ° λ€μ κ°κ° νλμ μ΄λ―Έμ§ μμ±μ λ΄λΉνλ 8κ°μ λλ°μ΄μ€μμ ν λ²μ μΆλ‘ μ μνν©λλ€. λ°λΌμ νλμ μΉ©μ΄ νλμ μ΄λ―Έμ§λ₯Ό μμ±νλ λ° κ±Έλ¦¬λ μκ°κ³Ό λμΌν μκ°μ 8κ°μ μ΄λ―Έμ§λ₯Ό μ»μ μ μμ΅λλ€.
ν둬ννΈλ₯Ό 볡μ νκ³ λλ©΄ νμ΄νλΌμΈμ `prepare_inputs` ν¨μλ₯Ό νΈμΆνμ¬ ν ν°νλ ν
μ€νΈ IDλ₯Ό μ»μ΅λλ€. ν ν°νλ ν
μ€νΈμ κΈΈμ΄λ κΈ°λ³Έ CLIP ν
μ€νΈ λͺ¨λΈμ ꡬμ±μ λ°λΌ 77ν ν°μΌλ‘ μ€μ λ©λλ€.
```python
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
```
```python out
(8, 77)
```
### 볡μ¬(Replication) λ° μ λ ¬ν
λͺ¨λΈ 맀κ°λ³μμ μ
λ ₯κ°μ μ°λ¦¬κ° 보μ ν 8κ°μ λ³λ ¬ μ₯μΉμ 볡μ¬(Replication)λμ΄μΌ ν©λλ€. 맀κ°λ³μ λμ
λ리λ `flax.jax_utils.replicate`(λμ
λ리λ₯Ό μννλ©° κ°μ€μΉμ λͺ¨μμ λ³κ²½νμ¬ 8λ² λ°λ³΅νλ ν¨μ)λ₯Ό μ¬μ©νμ¬ λ³΅μ¬λ©λλ€. λ°°μ΄μ `shard`λ₯Ό μ¬μ©νμ¬ λ³΅μ λ©λλ€.
```python
p_params = replicate(params)
```
```python
prompt_ids = shard(prompt_ids)
prompt_ids.shape
```
```python out
(8, 1, 77)
```
μ΄ shapeμ 8κ°μ λλ°μ΄μ€ κ°κ°μ΄ shape `(1, 77)`μ jnp λ°°μ΄μ μ
λ ₯κ°μΌλ‘ λ°λλ€λ μλ―Έμ
λλ€. μ¦ 1μ λλ°μ΄μ€λΉ batch(λ°°μΉ) ν¬κΈ°μ
λλ€. λ©λͺ¨λ¦¬κ° μΆ©λΆν TPUμμλ ν λ²μ μ¬λ¬ μ΄λ―Έμ§(μΉ©λΉ)λ₯Ό μμ±νλ €λ κ²½μ° 1λ³΄λ€ ν΄ μ μμ΅λλ€.
μ΄λ―Έμ§λ₯Ό μμ±ν μ€λΉκ° κ±°μ μλ£λμμ΅λλ€! μ΄μ μμ± ν¨μμ μ λ¬ν λμ μμ±κΈ°λ§ λ§λ€λ©΄ λ©λλ€. μ΄κ²μ λμλ₯Ό λ€λ£¨λ λͺ¨λ ν¨μμ λμ μμ±κΈ°κ° μμ΄μΌ νλ€λ, λμμ λν΄ λ§€μ° μ§μ§νκ³ λ
λ¨μ μΈ Flaxμ νμ€ μ μ°¨μ
λλ€. μ΄λ κ² νλ©΄ μ¬λ¬ λΆμ°λ κΈ°κΈ°μμ νλ ¨ν λμλ μ¬νμ±μ΄ 보μ₯λ©λλ€.
μλ ν¬νΌ ν¨μλ μλλ₯Ό μ¬μ©νμ¬ λμ μμ±κΈ°λ₯Ό μ΄κΈ°νν©λλ€. λμΌν μλλ₯Ό μ¬μ©νλ ν μ νν λμΌν κ²°κ³Όλ₯Ό μ»μ μ μμ΅λλ€. λμ€μ λ
ΈνΈλΆμμ κ²°κ³Όλ₯Ό νμν λμ λ€λ₯Έ μλλ₯Ό μμ λ‘κ² μ¬μ©νμΈμ.
```python
def create_key(seed=0):
return jax.random.PRNGKey(seed)
```
rngλ₯Ό μ»μ λ€μ 8λ² 'λΆν 'νμ¬ κ° λλ°μ΄μ€κ° λ€λ₯Έ μ λλ μ΄ν°λ₯Ό μμ νλλ‘ ν©λλ€. λ°λΌμ κ° λλ°μ΄μ€λ§λ€ λ€λ₯Έ μ΄λ―Έμ§κ° μμ±λλ©° μ 체 νλ‘μΈμ€λ₯Ό μ¬νν μ μμ΅λλ€.
```python
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
```
JAX μ½λλ λ§€μ° λΉ λ₯΄κ² μ€νλλ ν¨μ¨μ μΈ ννμΌλ‘ μ»΄νμΌν μ μμ΅λλ€. νμ§λ§ νμ νΈμΆμμ λͺ¨λ μ
λ ₯μ΄ λμΌν λͺ¨μμ κ°λλ‘ ν΄μΌ νλ©°, κ·Έλ μ§ μμΌλ©΄ JAXκ° μ½λλ₯Ό λ€μ μ»΄νμΌν΄μΌ νλ―λ‘ μ΅μ νλ μλλ₯Ό νμ©ν μ μμ΅λλ€.
`jit = True`λ₯Ό μΈμλ‘ μ λ¬νλ©΄ Flax νμ΄νλΌμΈμ΄ μ½λλ₯Ό μ»΄νμΌν μ μμ΅λλ€. λν λͺ¨λΈμ΄ μ¬μ© κ°λ₯ν 8κ°μ λλ°μ΄μ€μμ λ³λ ¬λ‘ μ€νλλλ‘ λ³΄μ₯ν©λλ€.
λ€μ μ
μ μ²μ μ€ννλ©΄ μ»΄νμΌνλ λ° μκ°μ΄ μ€λ 걸리μ§λ§ μ΄ν νΈμΆ(μ
λ ₯μ΄ λ€λ₯Έ κ²½μ°μλ)μ ν¨μ¬ λΉ¨λΌμ§λλ€. μλ₯Ό λ€μ΄, ν
μ€νΈνμ λ TPU v2-8μμ μ»΄νμΌνλ λ° 1λΆ μ΄μ 걸리μ§λ§ μ΄ν μΆλ‘ μ€νμλ μ½ 7μ΄κ° 걸립λλ€.
```
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
```
```python out
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s
```
λ°νλ λ°°μ΄μ shapeμ `(8, 1, 512, 512, 3)`μ
λλ€. μ΄λ₯Ό μ¬κ΅¬μ±νμ¬ λ λ²μ§Έ μ°¨μμ μ κ±°νκ³ 512 Γ 512 Γ 3μ μ΄λ―Έμ§ 8κ°λ₯Ό μ»μ λ€μ PILλ‘ λ³νν©λλ€.
```python
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
```
### μκ°ν
μ΄λ―Έμ§λ₯Ό 그리λμ νμνλ λμ°λ―Έ ν¨μλ₯Ό λ§λ€μ΄ λ³΄κ² μ΅λλ€.
```python
def image_grid(imgs, rows, cols):
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
```
```python
image_grid(images, 2, 4)
```
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)
## λ€λ₯Έ ν둬ννΈ μ¬μ©
λͺ¨λ λλ°μ΄μ€μμ λμΌν ν둬ννΈλ₯Ό 볡μ ν νμλ μμ΅λλ€. ν둬ννΈ 2κ°λ₯Ό κ°κ° 4λ²μ© μμ±νκ±°λ ν λ²μ 8κ°μ μλ‘ λ€λ₯Έ ν둬ννΈλ₯Ό μμ±νλ λ± μνλ κ²μ 무μμ΄λ ν μ μμ΅λλ€. νλ² ν΄λ³΄μΈμ!
λ¨Όμ μ
λ ₯ μ€λΉ μ½λλ₯Ό νΈλ¦¬ν ν¨μλ‘ λ¦¬ν©ν°λ§νκ² μ΅λλ€:
```python
prompts = [
"Labrador in the style of Hokusai",
"Painting of a squirrel skating in New York",
"HAL-9000 in the style of Van Gogh",
"Times Square under water, with fish and a dolphin swimming around",
"Ancient Roman fresco showing a man working on his laptop",
"Close-up photograph of young black woman against urban background, high quality, bokeh",
"Armchair in the shape of an avocado",
"Clown astronaut in space, with Earth in the background",
]
```
```python
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)
```
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)
## λ³λ ¬ν(parallelization)λ μ΄λ»κ² μλνλκ°?
μμ `diffusers` Flax νμ΄νλΌμΈμ΄ λͺ¨λΈμ μλμΌλ‘ μ»΄νμΌνκ³ μ¬μ© κ°λ₯ν λͺ¨λ κΈ°κΈ°μμ λ³λ ¬λ‘ μ€ννλ€κ³ λ§μλλ Έμ΅λλ€. μ΄μ κ·Έ νλ‘μΈμ€λ₯Ό κ°λ΅νκ² μ΄ν΄λ³΄κ³ μλ λ°©μμ 보μ¬λλ¦¬κ² μ΅λλ€.
JAX λ³λ ¬νλ μ¬λ¬ κ°μ§ λ°©λ²μΌλ‘ μνν μ μμ΅λλ€. κ°μ₯ μ¬μ΄ λ°©λ²μ jax.pmap ν¨μλ₯Ό μ¬μ©νμ¬ λ¨μΌ νλ‘κ·Έλ¨, λ€μ€ λ°μ΄ν°(SPMD) λ³λ ¬νλ₯Ό λ¬μ±νλ κ²μ
λλ€. μ¦, λμΌν μ½λμ 볡μ¬λ³Έμ κ°κ° λ€λ₯Έ λ°μ΄ν° μ
λ ₯μ λν΄ μ¬λ¬ κ° μ€ννλ κ²μ
λλ€. λ μ κ΅ν μ κ·Ό λ°©μλ κ°λ₯νλ―λ‘ κ΄μ¬μ΄ μμΌμλ€λ©΄ [JAX λ¬Έμ](https://jax.readthedocs.io/en/latest/index.html)μ [`pjit` νμ΄μ§](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit)μμ μ΄ μ£Όμ λ₯Ό μ΄ν΄λ³΄μκΈ° λ°λλλ€!
`jax.pmap`μ λ κ°μ§ κΈ°λ₯μ μνν©λλ€:
- `jax.jit()`λ₯Ό νΈμΆν κ²μ²λΌ μ½λλ₯Ό μ»΄νμΌ(λλ `jit`)ν©λλ€. μ΄ μμ
μ `pmap`μ νΈμΆν λκ° μλλΌ pmapped ν¨μκ° μ²μ νΈμΆλ λ μνλ©λλ€.
- μ»΄νμΌλ μ½λκ° μ¬μ© κ°λ₯ν λͺ¨λ κΈ°κΈ°μμ λ³λ ¬λ‘ μ€νλλλ‘ ν©λλ€.
μλ λ°©μμ 보μ¬λ리기 μν΄ μ΄λ―Έμ§ μμ±μ μ€ννλ λΉκ³΅κ° λ©μλμΈ νμ΄νλΌμΈμ `_generate` λ©μλλ₯Ό `pmap`ν©λλ€. μ΄ λ©μλλ ν₯ν `Diffusers` 릴리μ€μμ μ΄λ¦μ΄ λ³κ²½λκ±°λ μ κ±°λ μ μλ€λ μ μ μ μνμΈμ.
```python
p_generate = pmap(pipeline._generate)
```
`pmap`μ μ¬μ©ν ν μ€λΉλ ν¨μ `p_generate`λ κ°λ
μ μΌλ‘ λ€μμ μνν©λλ€:
* κ° μ₯μΉμμ κΈ°λ³Έ ν¨μ `pipeline._generate`μ 볡μ¬λ³Έμ νΈμΆν©λλ€.
* κ° μ₯μΉμ μ
λ ₯ μΈμμ λ€λ₯Έ λΆλΆμ 보λ
λλ€. μ΄κ²μ΄ λ°λ‘ μ€λ©μ΄ μ¬μ©λλ μ΄μ μ
λλ€. μ΄ κ²½μ° `prompt_ids`μ shapeμ `(8, 1, 77, 768)`μ
λλ€. μ΄ λ°°μ΄μ 8κ°λ‘ λΆν λκ³ `_generate`μ κ° λ³΅μ¬λ³Έμ `(1, 77, 768)`μ shapeμ κ°μ§ μ
λ ₯μ λ°κ² λ©λλ€.
λ³λ ¬λ‘ νΈμΆλλ€λ μ¬μ€μ μμ ν 무μνκ³ `_generate`λ₯Ό μ½λ©ν μ μμ΅λλ€. batch(λ°°μΉ) ν¬κΈ°(μ΄ μμ μμλ `1`)μ μ½λμ μ ν©ν μ°¨μλ§ μ κ²½ μ°λ©΄ λλ©°, λ³λ ¬λ‘ μλνκΈ° μν΄ μ무κ²λ λ³κ²½ν νμκ° μμ΅λλ€.
νμ΄νλΌμΈ νΈμΆμ μ¬μ©ν λμ λ§μ°¬κ°μ§λ‘, λ€μ μ
μ μ²μ μ€νν λλ μκ°μ΄ 걸리μ§λ§ κ·Έ μ΄νμλ ν¨μ¬ λΉ¨λΌμ§λλ€.
```
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
```
```python out
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
```
```python
images.shape
```
```python out
(8, 1, 512, 512, 3)
```
JAXλ λΉλκΈ° λμ€ν¨μΉλ₯Ό μ¬μ©νκ³ κ°λ₯ν ν 빨리 μ μ΄κΆμ Python 루νμ λ°ννκΈ° λλ¬Έμ μΆλ‘ μκ°μ μ ννκ² μΈ‘μ νκΈ° μν΄ `block_until_ready()`λ₯Ό μ¬μ©ν©λλ€. μμ§ κ΅¬μ²΄νλμ§ μμ κ³μ° κ²°κ³Όλ₯Ό μ¬μ©νλ €λ κ²½μ° μλμΌλ‘ μ°¨λ¨μ΄ μνλλ―λ‘ μ½λμμ μ΄ ν¨μλ₯Ό μ¬μ©ν νμκ° μμ΅λλ€. |