File size: 12,756 Bytes
ef4d689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# JAX / Flaxμ—μ„œμ˜ 🧨 Stable Diffusion!

[[open-in-colab]]

πŸ€— Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) λŠ” 버전 0.5.1λΆ€ν„° Flaxλ₯Ό μ§€μ›ν•©λ‹ˆλ‹€! 이λ₯Ό 톡해 Colab, Kaggle, Google Cloud Platformμ—μ„œ μ‚¬μš©ν•  수 μžˆλŠ” κ²ƒμ²˜λŸΌ Google TPUμ—μ„œ μ΄ˆκ³ μ† 좔둠이 κ°€λŠ₯ν•©λ‹ˆλ‹€.

이 λ…ΈνŠΈλΆμ€ JAX / Flaxλ₯Ό μ‚¬μš©ν•΄ 좔둠을 μ‹€ν–‰ν•˜λŠ” 방법을 λ³΄μ—¬μ€λ‹ˆλ‹€. Stable Diffusion의 μž‘λ™ 방식에 λŒ€ν•œ μžμ„Έν•œ λ‚΄μš©μ„ μ›ν•˜κ±°λ‚˜ GPUμ—μ„œ μ‹€ν–‰ν•˜λ €λ©΄ 이 [λ…ΈνŠΈλΆ] ](https://huggingface.co/docs/diffusers/stable_diffusion)을 μ°Έμ‘°ν•˜μ„Έμš”.

λ¨Όμ €, TPU λ°±μ—”λ“œλ₯Ό μ‚¬μš©ν•˜κ³  μžˆλŠ”μ§€ ν™•μΈν•©λ‹ˆλ‹€. Colabμ—μ„œ 이 λ…ΈνŠΈλΆμ„ μ‹€ν–‰ν•˜λŠ” 경우, λ©”λ‰΄μ—μ„œ λŸ°νƒ€μž„μ„ μ„ νƒν•œ λ‹€μŒ "λŸ°νƒ€μž„ μœ ν˜• λ³€κ²½" μ˜΅μ…˜μ„ μ„ νƒν•œ λ‹€μŒ ν•˜λ“œμ›¨μ–΄ 가속기 μ„€μ •μ—μ„œ TPUλ₯Ό μ„ νƒν•©λ‹ˆλ‹€.

JAXλŠ” TPU μ „μš©μ€ μ•„λ‹ˆμ§€λ§Œ 각 TPU μ„œλ²„μ—λŠ” 8개의 TPU 가속기가 λ³‘λ ¬λ‘œ μž‘λ™ν•˜κΈ° λ•Œλ¬Έμ— ν•΄λ‹Ή ν•˜λ“œμ›¨μ–΄μ—μ„œ 더 빛을 λ°œν•œλ‹€λŠ” 점은 μ•Œμ•„λ‘μ„Έμš”.


## Setup

λ¨Όμ € diffusersκ°€ μ„€μΉ˜λ˜μ–΄ μžˆλŠ”μ§€ ν™•μΈν•©λ‹ˆλ‹€.

```bash
!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
!pip install diffusers
```

```python
import jax.tools.colab_tpu

jax.tools.colab_tpu.setup_tpu()
import jax
```

```python
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
    "TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
```

```python out
Found 8 JAX devices of type Cloud TPU.
```

그런 λ‹€μŒ λͺ¨λ“  dependenciesλ₯Ό κ°€μ Έμ˜΅λ‹ˆλ‹€.

```python
import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
```

## λͺ¨λΈ 뢈러였기

TPU μž₯μΉ˜λŠ” 효율적인 half-float μœ ν˜•μΈ bfloat16을 μ§€μ›ν•©λ‹ˆλ‹€. ν…ŒμŠ€νŠΈμ—λŠ” 이 μœ ν˜•μ„ μ‚¬μš©ν•˜μ§€λ§Œ λŒ€μ‹  float32λ₯Ό μ‚¬μš©ν•˜μ—¬ 전체 정밀도(full precision)λ₯Ό μ‚¬μš©ν•  μˆ˜λ„ μžˆμŠ΅λ‹ˆλ‹€.

```python
dtype = jnp.bfloat16
```

FlaxλŠ” ν•¨μˆ˜ν˜• ν”„λ ˆμž„μ›Œν¬μ΄λ―€λ‘œ λͺ¨λΈμ€ λ¬΄μƒνƒœ(stateless)ν˜•μ΄λ©° λ§€κ°œλ³€μˆ˜λŠ” λͺ¨λΈ 외뢀에 μ €μž₯λ©λ‹ˆλ‹€. μ‚¬μ „ν•™μŠ΅λœ Flax νŒŒμ΄ν”„λΌμΈμ„ 뢈러였면 νŒŒμ΄ν”„λΌμΈ μžμ²΄μ™€ λͺ¨λΈ κ°€μ€‘μΉ˜(λ˜λŠ” λ§€κ°œλ³€μˆ˜)κ°€ λͺ¨λ‘ λ°˜ν™˜λ©λ‹ˆλ‹€. μ €ν¬λŠ” bf16 λ²„μ „μ˜ κ°€μ€‘μΉ˜λ₯Ό μ‚¬μš©ν•˜κ³  μžˆμœΌλ―€λ‘œ μœ ν˜• κ²½κ³ κ°€ ν‘œμ‹œλ˜μ§€λ§Œ λ¬΄μ‹œν•΄λ„ λ©λ‹ˆλ‹€.

```python
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)
```

## μΆ”λ‘ 

TPUμ—λŠ” 일반적으둜 8개의 λ””λ°”μ΄μŠ€κ°€ λ³‘λ ¬λ‘œ μž‘λ™ν•˜λ―€λ‘œ λ³΄μœ ν•œ λ””λ°”μ΄μŠ€ 수만큼 ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•©λ‹ˆλ‹€. 그런 λ‹€μŒ 각각 ν•˜λ‚˜μ˜ 이미지 생성을 λ‹΄λ‹Ήν•˜λŠ” 8개의 λ””λ°”μ΄μŠ€μ—μ„œ ν•œ λ²ˆμ— 좔둠을 μˆ˜ν–‰ν•©λ‹ˆλ‹€. λ”°λΌμ„œ ν•˜λ‚˜μ˜ 칩이 ν•˜λ‚˜μ˜ 이미지λ₯Ό μƒμ„±ν•˜λŠ” 데 κ±Έλ¦¬λŠ” μ‹œκ°„κ³Ό λ™μΌν•œ μ‹œκ°„μ— 8개의 이미지λ₯Ό 얻을 수 μžˆμŠ΅λ‹ˆλ‹€.

ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•˜κ³  λ‚˜λ©΄ νŒŒμ΄ν”„λΌμΈμ˜ `prepare_inputs` ν•¨μˆ˜λ₯Ό ν˜ΈμΆœν•˜μ—¬ ν† ν°ν™”λœ ν…μŠ€νŠΈ IDλ₯Ό μ–»μŠ΅λ‹ˆλ‹€. ν† ν°ν™”λœ ν…μŠ€νŠΈμ˜ κΈΈμ΄λŠ” κΈ°λ³Έ CLIP ν…μŠ€νŠΈ λͺ¨λΈμ˜ ꡬ성에 따라 77ν† ν°μœΌλ‘œ μ„€μ •λ©λ‹ˆλ‹€.

```python
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
```

```python out
(8, 77)
```

### 볡사(Replication) 및 μ •λ ¬ν™”

λͺ¨λΈ λ§€κ°œλ³€μˆ˜μ™€ μž…λ ₯값은 μš°λ¦¬κ°€ λ³΄μœ ν•œ 8개의 병렬 μž₯μΉ˜μ— 볡사(Replication)λ˜μ–΄μ•Ό ν•©λ‹ˆλ‹€. λ§€κ°œλ³€μˆ˜ λ”•μ…”λ„ˆλ¦¬λŠ” `flax.jax_utils.replicate`(λ”•μ…”λ„ˆλ¦¬λ₯Ό μˆœνšŒν•˜λ©° κ°€μ€‘μΉ˜μ˜ λͺ¨μ–‘을 λ³€κ²½ν•˜μ—¬ 8번 λ°˜λ³΅ν•˜λŠ” ν•¨μˆ˜)λ₯Ό μ‚¬μš©ν•˜μ—¬ λ³΅μ‚¬λ©λ‹ˆλ‹€. 배열은 `shard`λ₯Ό μ‚¬μš©ν•˜μ—¬ λ³΅μ œλ©λ‹ˆλ‹€.

```python
p_params = replicate(params)
```

```python
prompt_ids = shard(prompt_ids)
prompt_ids.shape
```

```python out
(8, 1, 77)
```

이 shape은 8개의 λ””λ°”μ΄μŠ€ 각각이 shape `(1, 77)`의 jnp 배열을 μž…λ ₯κ°’μœΌλ‘œ λ°›λŠ”λ‹€λŠ” μ˜λ―Έμž…λ‹ˆλ‹€. 즉 1은 λ””λ°”μ΄μŠ€λ‹Ή batch(배치) ν¬κΈ°μž…λ‹ˆλ‹€. λ©”λͺ¨λ¦¬κ°€ μΆ©λΆ„ν•œ TPUμ—μ„œλŠ” ν•œ λ²ˆμ— μ—¬λŸ¬ 이미지(μΉ©λ‹Ή)λ₯Ό μƒμ„±ν•˜λ €λŠ” 경우 1보닀 클 수 μžˆμŠ΅λ‹ˆλ‹€.

이미지λ₯Ό 생성할 μ€€λΉ„κ°€ 거의 μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€! 이제 생성 ν•¨μˆ˜μ— 전달할 λ‚œμˆ˜ μƒμ„±κΈ°λ§Œ λ§Œλ“€λ©΄ λ©λ‹ˆλ‹€. 이것은 λ‚œμˆ˜λ₯Ό λ‹€λ£¨λŠ” λͺ¨λ“  ν•¨μˆ˜μ— λ‚œμˆ˜ 생성기가 μžˆμ–΄μ•Ό ν•œλ‹€λŠ”, λ‚œμˆ˜μ— λŒ€ν•΄ 맀우 μ§„μ§€ν•˜κ³  독단적인 Flax의 ν‘œμ€€ μ ˆμ°¨μž…λ‹ˆλ‹€. μ΄λ ‡κ²Œ ν•˜λ©΄ μ—¬λŸ¬ λΆ„μ‚°λœ κΈ°κΈ°μ—μ„œ ν›ˆλ ¨ν•  λ•Œμ—λ„ μž¬ν˜„μ„±μ΄ 보μž₯λ©λ‹ˆλ‹€.

μ•„λž˜ 헬퍼 ν•¨μˆ˜λŠ” μ‹œλ“œλ₯Ό μ‚¬μš©ν•˜μ—¬ λ‚œμˆ˜ 생성기λ₯Ό μ΄ˆκΈ°ν™”ν•©λ‹ˆλ‹€. λ™μΌν•œ μ‹œλ“œλ₯Ό μ‚¬μš©ν•˜λŠ” ν•œ μ •ν™•νžˆ λ™μΌν•œ κ²°κ³Όλ₯Ό 얻을 수 μžˆμŠ΅λ‹ˆλ‹€. λ‚˜μ€‘μ— λ…ΈνŠΈλΆμ—μ„œ κ²°κ³Όλ₯Ό 탐색할 λ•Œμ—” λ‹€λ₯Έ μ‹œλ“œλ₯Ό 자유둭게 μ‚¬μš©ν•˜μ„Έμš”.

```python
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
```

rngλ₯Ό 얻은 λ‹€μŒ 8번 'λΆ„ν• 'ν•˜μ—¬ 각 λ””λ°”μ΄μŠ€κ°€ λ‹€λ₯Έ μ œλ„ˆλ ˆμ΄ν„°λ₯Ό μˆ˜μ‹ ν•˜λ„λ‘ ν•©λ‹ˆλ‹€. λ”°λΌμ„œ 각 λ””λ°”μ΄μŠ€λ§ˆλ‹€ λ‹€λ₯Έ 이미지가 μƒμ„±λ˜λ©° 전체 ν”„λ‘œμ„ΈμŠ€λ₯Ό μž¬ν˜„ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

```python
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
```

JAX μ½”λ“œλŠ” 맀우 λΉ λ₯΄κ²Œ μ‹€ν–‰λ˜λŠ” 효율적인 ν‘œν˜„μœΌλ‘œ μ»΄νŒŒμΌν•  수 μžˆμŠ΅λ‹ˆλ‹€. ν•˜μ§€λ§Œ 후속 ν˜ΈμΆœμ—μ„œ λͺ¨λ“  μž…λ ₯이 λ™μΌν•œ λͺ¨μ–‘을 갖도둝 ν•΄μ•Ό ν•˜λ©°, 그렇지 μ•ŠμœΌλ©΄ JAXκ°€ μ½”λ“œλ₯Ό λ‹€μ‹œ μ»΄νŒŒμΌν•΄μ•Ό ν•˜λ―€λ‘œ μ΅œμ ν™”λœ 속도λ₯Ό ν™œμš©ν•  수 μ—†μŠ΅λ‹ˆλ‹€.

`jit = True`λ₯Ό 인수둜 μ „λ‹¬ν•˜λ©΄ Flax νŒŒμ΄ν”„λΌμΈμ΄ μ½”λ“œλ₯Ό μ»΄νŒŒμΌν•  수 μžˆμŠ΅λ‹ˆλ‹€. λ˜ν•œ λͺ¨λΈμ΄ μ‚¬μš© κ°€λŠ₯ν•œ 8개의 λ””λ°”μ΄μŠ€μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰λ˜λ„λ‘ 보μž₯ν•©λ‹ˆλ‹€.

λ‹€μŒ 셀을 처음 μ‹€ν–‰ν•˜λ©΄ μ»΄νŒŒμΌν•˜λŠ” 데 μ‹œκ°„μ΄ 였래 κ±Έλ¦¬μ§€λ§Œ 이후 호좜(μž…λ ₯이 λ‹€λ₯Έ κ²½μš°μ—λ„)은 훨씬 λΉ¨λΌμ§‘λ‹ˆλ‹€. 예λ₯Ό λ“€μ–΄, ν…ŒμŠ€νŠΈν–ˆμ„ λ•Œ TPU v2-8μ—μ„œ μ»΄νŒŒμΌν•˜λŠ” 데 1λΆ„ 이상 κ±Έλ¦¬μ§€λ§Œ 이후 μΆ”λ‘  μ‹€ν–‰μ—λŠ” μ•½ 7μ΄ˆκ°€ κ±Έλ¦½λ‹ˆλ‹€.

```
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
```

```python out
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s
```

λ°˜ν™˜λœ λ°°μ—΄μ˜ shape은 `(8, 1, 512, 512, 3)`μž…λ‹ˆλ‹€. 이λ₯Ό μž¬κ΅¬μ„±ν•˜μ—¬ 두 번째 차원을 μ œκ±°ν•˜κ³  512 Γ— 512 Γ— 3의 이미지 8개λ₯Ό 얻은 λ‹€μŒ PIL둜 λ³€ν™˜ν•©λ‹ˆλ‹€.

```python
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
```

### μ‹œκ°ν™”

이미지λ₯Ό κ·Έλ¦¬λ“œμ— ν‘œμ‹œν•˜λŠ” λ„μš°λ―Έ ν•¨μˆ˜λ₯Ό λ§Œλ“€μ–΄ λ³΄κ² μŠ΅λ‹ˆλ‹€.

```python
def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
```

```python
image_grid(images, 2, 4)
```

![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)


## λ‹€λ₯Έ ν”„λ‘¬ν”„νŠΈ μ‚¬μš©

λͺ¨λ“  λ””λ°”μ΄μŠ€μ—μ„œ λ™μΌν•œ ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•  ν•„μš”λŠ” μ—†μŠ΅λ‹ˆλ‹€. ν”„λ‘¬ν”„νŠΈ 2개λ₯Ό 각각 4λ²ˆμ”© μƒμ„±ν•˜κ±°λ‚˜ ν•œ λ²ˆμ— 8개의 μ„œλ‘œ λ‹€λ₯Έ ν”„λ‘¬ν”„νŠΈλ₯Ό μƒμ„±ν•˜λŠ” λ“± μ›ν•˜λŠ” 것은 무엇이든 ν•  수 μžˆμŠ΅λ‹ˆλ‹€. ν•œλ²ˆ ν•΄λ³΄μ„Έμš”!

λ¨Όμ € μž…λ ₯ μ€€λΉ„ μ½”λ“œλ₯Ό νŽΈλ¦¬ν•œ ν•¨μˆ˜λ‘œ λ¦¬νŒ©ν„°λ§ν•˜κ² μŠ΅λ‹ˆλ‹€:

```python
prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
```

```python
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

image_grid(images, 2, 4)
```

![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)


## 병렬화(parallelization)λŠ” μ–΄λ–»κ²Œ μž‘λ™ν•˜λŠ”κ°€?

μ•žμ„œ `diffusers` Flax νŒŒμ΄ν”„λΌμΈμ΄ λͺ¨λΈμ„ μžλ™μœΌλ‘œ μ»΄νŒŒμΌν•˜κ³  μ‚¬μš© κ°€λŠ₯ν•œ λͺ¨λ“  κΈ°κΈ°μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰ν•œλ‹€κ³  λ§μ”€λ“œλ ΈμŠ΅λ‹ˆλ‹€. 이제 κ·Έ ν”„λ‘œμ„ΈμŠ€λ₯Ό κ°„λž΅ν•˜κ²Œ μ‚΄νŽ΄λ³΄κ³  μž‘λ™ 방식을 λ³΄μ—¬λ“œλ¦¬κ² μŠ΅λ‹ˆλ‹€.

JAX λ³‘λ ¬ν™”λŠ” μ—¬λŸ¬ 가지 λ°©λ²•μœΌλ‘œ μˆ˜ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€. κ°€μž₯ μ‰¬μš΄ 방법은 jax.pmap ν•¨μˆ˜λ₯Ό μ‚¬μš©ν•˜μ—¬ 단일 ν”„λ‘œκ·Έλž¨, 닀쀑 데이터(SPMD) 병렬화λ₯Ό λ‹¬μ„±ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. 즉, λ™μΌν•œ μ½”λ“œμ˜ 볡사본을 각각 λ‹€λ₯Έ 데이터 μž…λ ₯에 λŒ€ν•΄ μ—¬λŸ¬ 개 μ‹€ν–‰ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. 더 μ •κ΅ν•œ μ ‘κ·Ό 방식도 κ°€λŠ₯ν•˜λ―€λ‘œ 관심이 μžˆμœΌμ‹œλ‹€λ©΄ [JAX λ¬Έμ„œ](https://jax.readthedocs.io/en/latest/index.html)와 [`pjit` νŽ˜μ΄μ§€](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit)μ—μ„œ 이 주제λ₯Ό μ‚΄νŽ΄λ³΄μ‹œκΈ° λ°”λžλ‹ˆλ‹€!

`jax.pmap`은 두 가지 κΈ°λŠ₯을 μˆ˜ν–‰ν•©λ‹ˆλ‹€:

- `jax.jit()`λ₯Ό ν˜ΈμΆœν•œ κ²ƒμ²˜λŸΌ μ½”λ“œλ₯Ό 컴파일(λ˜λŠ” `jit`)ν•©λ‹ˆλ‹€. 이 μž‘μ—…μ€ `pmap`을 ν˜ΈμΆœν•  λ•Œκ°€ μ•„λ‹ˆλΌ pmapped ν•¨μˆ˜κ°€ 처음 호좜될 λ•Œ μˆ˜ν–‰λ©λ‹ˆλ‹€.
- 컴파일된 μ½”λ“œκ°€ μ‚¬μš© κ°€λŠ₯ν•œ λͺ¨λ“  κΈ°κΈ°μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰λ˜λ„λ‘ ν•©λ‹ˆλ‹€.

μž‘λ™ 방식을 λ³΄μ—¬λ“œλ¦¬κΈ° μœ„ν•΄ 이미지 생성을 μ‹€ν–‰ν•˜λŠ” λΉ„κ³΅κ°œ λ©”μ„œλ“œμΈ νŒŒμ΄ν”„λΌμΈμ˜ `_generate` λ©”μ„œλ“œλ₯Ό `pmap`ν•©λ‹ˆλ‹€. 이 λ©”μ„œλ“œλŠ” ν–₯ν›„ `Diffusers` λ¦΄λ¦¬μŠ€μ—μ„œ 이름이 λ³€κ²½λ˜κ±°λ‚˜ 제거될 수 μžˆλ‹€λŠ” 점에 μœ μ˜ν•˜μ„Έμš”.

```python
p_generate = pmap(pipeline._generate)
```

`pmap`을 μ‚¬μš©ν•œ ν›„ μ€€λΉ„λœ ν•¨μˆ˜ `p_generate`λŠ” κ°œλ…μ μœΌλ‘œ λ‹€μŒμ„ μˆ˜ν–‰ν•©λ‹ˆλ‹€:
* 각 μž₯μΉ˜μ—μ„œ κΈ°λ³Έ ν•¨μˆ˜ `pipeline._generate`의 볡사본을 ν˜ΈμΆœν•©λ‹ˆλ‹€.
* 각 μž₯μΉ˜μ— μž…λ ₯ 인수의 λ‹€λ₯Έ 뢀뢄을 λ³΄λƒ…λ‹ˆλ‹€. 이것이 λ°”λ‘œ 샀딩이 μ‚¬μš©λ˜λŠ” μ΄μœ μž…λ‹ˆλ‹€. 이 경우 `prompt_ids`의 shape은 `(8, 1, 77, 768)`μž…λ‹ˆλ‹€. 이 배열은 8개둜 λΆ„ν• λ˜κ³  `_generate`의 각 볡사본은 `(1, 77, 768)`의 shape을 가진 μž…λ ₯을 λ°›κ²Œ λ©λ‹ˆλ‹€.

λ³‘λ ¬λ‘œ ν˜ΈμΆœλœλ‹€λŠ” 사싀을 μ™„μ „νžˆ λ¬΄μ‹œν•˜κ³  `_generate`λ₯Ό μ½”λ”©ν•  수 μžˆμŠ΅λ‹ˆλ‹€. batch(배치) 크기(이 μ˜ˆμ œμ—μ„œλŠ” `1`)와 μ½”λ“œμ— μ ν•©ν•œ μ°¨μ›λ§Œ μ‹ κ²½ μ“°λ©΄ 되며, λ³‘λ ¬λ‘œ μž‘λ™ν•˜κΈ° μœ„ν•΄ 아무것도 λ³€κ²½ν•  ν•„μš”κ°€ μ—†μŠ΅λ‹ˆλ‹€.

νŒŒμ΄ν”„λΌμΈ ν˜ΈμΆœμ„ μ‚¬μš©ν•  λ•Œμ™€ λ§ˆμ°¬κ°€μ§€λ‘œ, λ‹€μŒ 셀을 처음 μ‹€ν–‰ν•  λ•ŒλŠ” μ‹œκ°„μ΄ κ±Έλ¦¬μ§€λ§Œ κ·Έ μ΄ν›„μ—λŠ” 훨씬 λΉ¨λΌμ§‘λ‹ˆλ‹€.

```
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
```

```python out
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
```

```python
images.shape
```

```python out
(8, 1, 512, 512, 3)
```

JAXλŠ” 비동기 λ””μŠ€νŒ¨μΉ˜λ₯Ό μ‚¬μš©ν•˜κ³  κ°€λŠ₯ν•œ ν•œ 빨리 μ œμ–΄κΆŒμ„ Python 루프에 λ°˜ν™˜ν•˜κΈ° λ•Œλ¬Έμ— μΆ”λ‘  μ‹œκ°„μ„ μ •ν™•ν•˜κ²Œ μΈ‘μ •ν•˜κΈ° μœ„ν•΄ `block_until_ready()`λ₯Ό μ‚¬μš©ν•©λ‹ˆλ‹€. 아직 κ΅¬μ²΄ν™”λ˜μ§€ μ•Šμ€ 계산 κ²°κ³Όλ₯Ό μ‚¬μš©ν•˜λ €λŠ” 경우 μžλ™μœΌλ‘œ 차단이 μˆ˜ν–‰λ˜λ―€λ‘œ μ½”λ“œμ—μ„œ 이 ν•¨μˆ˜λ₯Ό μ‚¬μš©ν•  ν•„μš”κ°€ μ—†μŠ΅λ‹ˆλ‹€.