# JAX / Flaxμ—μ„œμ˜ 🧨 Stable Diffusion! [[open-in-colab]] πŸ€— Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) λŠ” 버전 0.5.1λΆ€ν„° Flaxλ₯Ό μ§€μ›ν•©λ‹ˆλ‹€! 이λ₯Ό 톡해 Colab, Kaggle, Google Cloud Platformμ—μ„œ μ‚¬μš©ν•  수 μžˆλŠ” κ²ƒμ²˜λŸΌ Google TPUμ—μ„œ μ΄ˆκ³ μ† 좔둠이 κ°€λŠ₯ν•©λ‹ˆλ‹€. 이 λ…ΈνŠΈλΆμ€ JAX / Flaxλ₯Ό μ‚¬μš©ν•΄ 좔둠을 μ‹€ν–‰ν•˜λŠ” 방법을 λ³΄μ—¬μ€λ‹ˆλ‹€. Stable Diffusion의 μž‘λ™ 방식에 λŒ€ν•œ μžμ„Έν•œ λ‚΄μš©μ„ μ›ν•˜κ±°λ‚˜ GPUμ—μ„œ μ‹€ν–‰ν•˜λ €λ©΄ 이 [λ…ΈνŠΈλΆ] ](https://huggingface.co/docs/diffusers/stable_diffusion)을 μ°Έμ‘°ν•˜μ„Έμš”. λ¨Όμ €, TPU λ°±μ—”λ“œλ₯Ό μ‚¬μš©ν•˜κ³  μžˆλŠ”μ§€ ν™•μΈν•©λ‹ˆλ‹€. Colabμ—μ„œ 이 λ…ΈνŠΈλΆμ„ μ‹€ν–‰ν•˜λŠ” 경우, λ©”λ‰΄μ—μ„œ λŸ°νƒ€μž„μ„ μ„ νƒν•œ λ‹€μŒ "λŸ°νƒ€μž„ μœ ν˜• λ³€κ²½" μ˜΅μ…˜μ„ μ„ νƒν•œ λ‹€μŒ ν•˜λ“œμ›¨μ–΄ 가속기 μ„€μ •μ—μ„œ TPUλ₯Ό μ„ νƒν•©λ‹ˆλ‹€. JAXλŠ” TPU μ „μš©μ€ μ•„λ‹ˆμ§€λ§Œ 각 TPU μ„œλ²„μ—λŠ” 8개의 TPU 가속기가 λ³‘λ ¬λ‘œ μž‘λ™ν•˜κΈ° λ•Œλ¬Έμ— ν•΄λ‹Ή ν•˜λ“œμ›¨μ–΄μ—μ„œ 더 빛을 λ°œν•œλ‹€λŠ” 점은 μ•Œμ•„λ‘μ„Έμš”. ## Setup λ¨Όμ € diffusersκ°€ μ„€μΉ˜λ˜μ–΄ μžˆλŠ”μ§€ ν™•μΈν•©λ‹ˆλ‹€. ```bash !pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy !pip install diffusers ``` ```python import jax.tools.colab_tpu jax.tools.colab_tpu.setup_tpu() import jax ``` ```python num_devices = jax.device_count() device_type = jax.devices()[0].device_kind print(f"Found {num_devices} JAX devices of type {device_type}.") assert ( "TPU" in device_type ), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator" ``` ```python out Found 8 JAX devices of type Cloud TPU. ``` 그런 λ‹€μŒ λͺ¨λ“  dependenciesλ₯Ό κ°€μ Έμ˜΅λ‹ˆλ‹€. ```python import numpy as np import jax import jax.numpy as jnp from pathlib import Path from jax import pmap from flax.jax_utils import replicate from flax.training.common_utils import shard from PIL import Image from huggingface_hub import notebook_login from diffusers import FlaxStableDiffusionPipeline ``` ## λͺ¨λΈ 뢈러였기 TPU μž₯μΉ˜λŠ” 효율적인 half-float μœ ν˜•μΈ bfloat16을 μ§€μ›ν•©λ‹ˆλ‹€. ν…ŒμŠ€νŠΈμ—λŠ” 이 μœ ν˜•μ„ μ‚¬μš©ν•˜μ§€λ§Œ λŒ€μ‹  float32λ₯Ό μ‚¬μš©ν•˜μ—¬ 전체 정밀도(full precision)λ₯Ό μ‚¬μš©ν•  μˆ˜λ„ μžˆμŠ΅λ‹ˆλ‹€. ```python dtype = jnp.bfloat16 ``` FlaxλŠ” ν•¨μˆ˜ν˜• ν”„λ ˆμž„μ›Œν¬μ΄λ―€λ‘œ λͺ¨λΈμ€ λ¬΄μƒνƒœ(stateless)ν˜•μ΄λ©° λ§€κ°œλ³€μˆ˜λŠ” λͺ¨λΈ 외뢀에 μ €μž₯λ©λ‹ˆλ‹€. μ‚¬μ „ν•™μŠ΅λœ Flax νŒŒμ΄ν”„λΌμΈμ„ 뢈러였면 νŒŒμ΄ν”„λΌμΈ μžμ²΄μ™€ λͺ¨λΈ κ°€μ€‘μΉ˜(λ˜λŠ” λ§€κ°œλ³€μˆ˜)κ°€ λͺ¨λ‘ λ°˜ν™˜λ©λ‹ˆλ‹€. μ €ν¬λŠ” bf16 λ²„μ „μ˜ κ°€μ€‘μΉ˜λ₯Ό μ‚¬μš©ν•˜κ³  μžˆμœΌλ―€λ‘œ μœ ν˜• κ²½κ³ κ°€ ν‘œμ‹œλ˜μ§€λ§Œ λ¬΄μ‹œν•΄λ„ λ©λ‹ˆλ‹€. ```python pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=dtype, ) ``` ## μΆ”λ‘  TPUμ—λŠ” 일반적으둜 8개의 λ””λ°”μ΄μŠ€κ°€ λ³‘λ ¬λ‘œ μž‘λ™ν•˜λ―€λ‘œ λ³΄μœ ν•œ λ””λ°”μ΄μŠ€ 수만큼 ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•©λ‹ˆλ‹€. 그런 λ‹€μŒ 각각 ν•˜λ‚˜μ˜ 이미지 생성을 λ‹΄λ‹Ήν•˜λŠ” 8개의 λ””λ°”μ΄μŠ€μ—μ„œ ν•œ λ²ˆμ— 좔둠을 μˆ˜ν–‰ν•©λ‹ˆλ‹€. λ”°λΌμ„œ ν•˜λ‚˜μ˜ 칩이 ν•˜λ‚˜μ˜ 이미지λ₯Ό μƒμ„±ν•˜λŠ” 데 κ±Έλ¦¬λŠ” μ‹œκ°„κ³Ό λ™μΌν•œ μ‹œκ°„μ— 8개의 이미지λ₯Ό 얻을 수 μžˆμŠ΅λ‹ˆλ‹€. ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•˜κ³  λ‚˜λ©΄ νŒŒμ΄ν”„λΌμΈμ˜ `prepare_inputs` ν•¨μˆ˜λ₯Ό ν˜ΈμΆœν•˜μ—¬ ν† ν°ν™”λœ ν…μŠ€νŠΈ IDλ₯Ό μ–»μŠ΅λ‹ˆλ‹€. ν† ν°ν™”λœ ν…μŠ€νŠΈμ˜ κΈΈμ΄λŠ” κΈ°λ³Έ CLIP ν…μŠ€νŠΈ λͺ¨λΈμ˜ ꡬ성에 따라 77ν† ν°μœΌλ‘œ μ„€μ •λ©λ‹ˆλ‹€. ```python prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic" prompt = [prompt] * jax.device_count() prompt_ids = pipeline.prepare_inputs(prompt) prompt_ids.shape ``` ```python out (8, 77) ``` ### 볡사(Replication) 및 μ •λ ¬ν™” λͺ¨λΈ λ§€κ°œλ³€μˆ˜μ™€ μž…λ ₯값은 μš°λ¦¬κ°€ λ³΄μœ ν•œ 8개의 병렬 μž₯μΉ˜μ— 볡사(Replication)λ˜μ–΄μ•Ό ν•©λ‹ˆλ‹€. λ§€κ°œλ³€μˆ˜ λ”•μ…”λ„ˆλ¦¬λŠ” `flax.jax_utils.replicate`(λ”•μ…”λ„ˆλ¦¬λ₯Ό μˆœνšŒν•˜λ©° κ°€μ€‘μΉ˜μ˜ λͺ¨μ–‘을 λ³€κ²½ν•˜μ—¬ 8번 λ°˜λ³΅ν•˜λŠ” ν•¨μˆ˜)λ₯Ό μ‚¬μš©ν•˜μ—¬ λ³΅μ‚¬λ©λ‹ˆλ‹€. 배열은 `shard`λ₯Ό μ‚¬μš©ν•˜μ—¬ λ³΅μ œλ©λ‹ˆλ‹€. ```python p_params = replicate(params) ``` ```python prompt_ids = shard(prompt_ids) prompt_ids.shape ``` ```python out (8, 1, 77) ``` 이 shape은 8개의 λ””λ°”μ΄μŠ€ 각각이 shape `(1, 77)`의 jnp 배열을 μž…λ ₯κ°’μœΌλ‘œ λ°›λŠ”λ‹€λŠ” μ˜λ―Έμž…λ‹ˆλ‹€. 즉 1은 λ””λ°”μ΄μŠ€λ‹Ή batch(배치) ν¬κΈ°μž…λ‹ˆλ‹€. λ©”λͺ¨λ¦¬κ°€ μΆ©λΆ„ν•œ TPUμ—μ„œλŠ” ν•œ λ²ˆμ— μ—¬λŸ¬ 이미지(μΉ©λ‹Ή)λ₯Ό μƒμ„±ν•˜λ €λŠ” 경우 1보닀 클 수 μžˆμŠ΅λ‹ˆλ‹€. 이미지λ₯Ό 생성할 μ€€λΉ„κ°€ 거의 μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€! 이제 생성 ν•¨μˆ˜μ— 전달할 λ‚œμˆ˜ μƒμ„±κΈ°λ§Œ λ§Œλ“€λ©΄ λ©λ‹ˆλ‹€. 이것은 λ‚œμˆ˜λ₯Ό λ‹€λ£¨λŠ” λͺ¨λ“  ν•¨μˆ˜μ— λ‚œμˆ˜ 생성기가 μžˆμ–΄μ•Ό ν•œλ‹€λŠ”, λ‚œμˆ˜μ— λŒ€ν•΄ 맀우 μ§„μ§€ν•˜κ³  독단적인 Flax의 ν‘œμ€€ μ ˆμ°¨μž…λ‹ˆλ‹€. μ΄λ ‡κ²Œ ν•˜λ©΄ μ—¬λŸ¬ λΆ„μ‚°λœ κΈ°κΈ°μ—μ„œ ν›ˆλ ¨ν•  λ•Œμ—λ„ μž¬ν˜„μ„±μ΄ 보μž₯λ©λ‹ˆλ‹€. μ•„λž˜ 헬퍼 ν•¨μˆ˜λŠ” μ‹œλ“œλ₯Ό μ‚¬μš©ν•˜μ—¬ λ‚œμˆ˜ 생성기λ₯Ό μ΄ˆκΈ°ν™”ν•©λ‹ˆλ‹€. λ™μΌν•œ μ‹œλ“œλ₯Ό μ‚¬μš©ν•˜λŠ” ν•œ μ •ν™•νžˆ λ™μΌν•œ κ²°κ³Όλ₯Ό 얻을 수 μžˆμŠ΅λ‹ˆλ‹€. λ‚˜μ€‘μ— λ…ΈνŠΈλΆμ—μ„œ κ²°κ³Όλ₯Ό 탐색할 λ•Œμ—” λ‹€λ₯Έ μ‹œλ“œλ₯Ό 자유둭게 μ‚¬μš©ν•˜μ„Έμš”. ```python def create_key(seed=0): return jax.random.PRNGKey(seed) ``` rngλ₯Ό 얻은 λ‹€μŒ 8번 'λΆ„ν• 'ν•˜μ—¬ 각 λ””λ°”μ΄μŠ€κ°€ λ‹€λ₯Έ μ œλ„ˆλ ˆμ΄ν„°λ₯Ό μˆ˜μ‹ ν•˜λ„λ‘ ν•©λ‹ˆλ‹€. λ”°λΌμ„œ 각 λ””λ°”μ΄μŠ€λ§ˆλ‹€ λ‹€λ₯Έ 이미지가 μƒμ„±λ˜λ©° 전체 ν”„λ‘œμ„ΈμŠ€λ₯Ό μž¬ν˜„ν•  수 μžˆμŠ΅λ‹ˆλ‹€. ```python rng = create_key(0) rng = jax.random.split(rng, jax.device_count()) ``` JAX μ½”λ“œλŠ” 맀우 λΉ λ₯΄κ²Œ μ‹€ν–‰λ˜λŠ” 효율적인 ν‘œν˜„μœΌλ‘œ μ»΄νŒŒμΌν•  수 μžˆμŠ΅λ‹ˆλ‹€. ν•˜μ§€λ§Œ 후속 ν˜ΈμΆœμ—μ„œ λͺ¨λ“  μž…λ ₯이 λ™μΌν•œ λͺ¨μ–‘을 갖도둝 ν•΄μ•Ό ν•˜λ©°, 그렇지 μ•ŠμœΌλ©΄ JAXκ°€ μ½”λ“œλ₯Ό λ‹€μ‹œ μ»΄νŒŒμΌν•΄μ•Ό ν•˜λ―€λ‘œ μ΅œμ ν™”λœ 속도λ₯Ό ν™œμš©ν•  수 μ—†μŠ΅λ‹ˆλ‹€. `jit = True`λ₯Ό 인수둜 μ „λ‹¬ν•˜λ©΄ Flax νŒŒμ΄ν”„λΌμΈμ΄ μ½”λ“œλ₯Ό μ»΄νŒŒμΌν•  수 μžˆμŠ΅λ‹ˆλ‹€. λ˜ν•œ λͺ¨λΈμ΄ μ‚¬μš© κ°€λŠ₯ν•œ 8개의 λ””λ°”μ΄μŠ€μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰λ˜λ„λ‘ 보μž₯ν•©λ‹ˆλ‹€. λ‹€μŒ 셀을 처음 μ‹€ν–‰ν•˜λ©΄ μ»΄νŒŒμΌν•˜λŠ” 데 μ‹œκ°„μ΄ 였래 κ±Έλ¦¬μ§€λ§Œ 이후 호좜(μž…λ ₯이 λ‹€λ₯Έ κ²½μš°μ—λ„)은 훨씬 λΉ¨λΌμ§‘λ‹ˆλ‹€. 예λ₯Ό λ“€μ–΄, ν…ŒμŠ€νŠΈν–ˆμ„ λ•Œ TPU v2-8μ—μ„œ μ»΄νŒŒμΌν•˜λŠ” 데 1λΆ„ 이상 κ±Έλ¦¬μ§€λ§Œ 이후 μΆ”λ‘  μ‹€ν–‰μ—λŠ” μ•½ 7μ΄ˆκ°€ κ±Έλ¦½λ‹ˆλ‹€. ``` %%time images = pipeline(prompt_ids, p_params, rng, jit=True)[0] ``` ```python out CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s Wall time: 1min 29s ``` λ°˜ν™˜λœ λ°°μ—΄μ˜ shape은 `(8, 1, 512, 512, 3)`μž…λ‹ˆλ‹€. 이λ₯Ό μž¬κ΅¬μ„±ν•˜μ—¬ 두 번째 차원을 μ œκ±°ν•˜κ³  512 Γ— 512 Γ— 3의 이미지 8개λ₯Ό 얻은 λ‹€μŒ PIL둜 λ³€ν™˜ν•©λ‹ˆλ‹€. ```python images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) images = pipeline.numpy_to_pil(images) ``` ### μ‹œκ°ν™” 이미지λ₯Ό κ·Έλ¦¬λ“œμ— ν‘œμ‹œν•˜λŠ” λ„μš°λ―Έ ν•¨μˆ˜λ₯Ό λ§Œλ“€μ–΄ λ³΄κ² μŠ΅λ‹ˆλ‹€. ```python def image_grid(imgs, rows, cols): w, h = imgs[0].size grid = Image.new("RGB", size=(cols * w, rows * h)) for i, img in enumerate(imgs): grid.paste(img, box=(i % cols * w, i // cols * h)) return grid ``` ```python image_grid(images, 2, 4) ``` ![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg) ## λ‹€λ₯Έ ν”„λ‘¬ν”„νŠΈ μ‚¬μš© λͺ¨λ“  λ””λ°”μ΄μŠ€μ—μ„œ λ™μΌν•œ ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•  ν•„μš”λŠ” μ—†μŠ΅λ‹ˆλ‹€. ν”„λ‘¬ν”„νŠΈ 2개λ₯Ό 각각 4λ²ˆμ”© μƒμ„±ν•˜κ±°λ‚˜ ν•œ λ²ˆμ— 8개의 μ„œλ‘œ λ‹€λ₯Έ ν”„λ‘¬ν”„νŠΈλ₯Ό μƒμ„±ν•˜λŠ” λ“± μ›ν•˜λŠ” 것은 무엇이든 ν•  수 μžˆμŠ΅λ‹ˆλ‹€. ν•œλ²ˆ ν•΄λ³΄μ„Έμš”! λ¨Όμ € μž…λ ₯ μ€€λΉ„ μ½”λ“œλ₯Ό νŽΈλ¦¬ν•œ ν•¨μˆ˜λ‘œ λ¦¬νŒ©ν„°λ§ν•˜κ² μŠ΅λ‹ˆλ‹€: ```python prompts = [ "Labrador in the style of Hokusai", "Painting of a squirrel skating in New York", "HAL-9000 in the style of Van Gogh", "Times Square under water, with fish and a dolphin swimming around", "Ancient Roman fresco showing a man working on his laptop", "Close-up photograph of young black woman against urban background, high quality, bokeh", "Armchair in the shape of an avocado", "Clown astronaut in space, with Earth in the background", ] ``` ```python prompt_ids = pipeline.prepare_inputs(prompts) prompt_ids = shard(prompt_ids) images = pipeline(prompt_ids, p_params, rng, jit=True).images images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) images = pipeline.numpy_to_pil(images) image_grid(images, 2, 4) ``` ![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg) ## 병렬화(parallelization)λŠ” μ–΄λ–»κ²Œ μž‘λ™ν•˜λŠ”κ°€? μ•žμ„œ `diffusers` Flax νŒŒμ΄ν”„λΌμΈμ΄ λͺ¨λΈμ„ μžλ™μœΌλ‘œ μ»΄νŒŒμΌν•˜κ³  μ‚¬μš© κ°€λŠ₯ν•œ λͺ¨λ“  κΈ°κΈ°μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰ν•œλ‹€κ³  λ§μ”€λ“œλ ΈμŠ΅λ‹ˆλ‹€. 이제 κ·Έ ν”„λ‘œμ„ΈμŠ€λ₯Ό κ°„λž΅ν•˜κ²Œ μ‚΄νŽ΄λ³΄κ³  μž‘λ™ 방식을 λ³΄μ—¬λ“œλ¦¬κ² μŠ΅λ‹ˆλ‹€. JAX λ³‘λ ¬ν™”λŠ” μ—¬λŸ¬ 가지 λ°©λ²•μœΌλ‘œ μˆ˜ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€. κ°€μž₯ μ‰¬μš΄ 방법은 jax.pmap ν•¨μˆ˜λ₯Ό μ‚¬μš©ν•˜μ—¬ 단일 ν”„λ‘œκ·Έλž¨, 닀쀑 데이터(SPMD) 병렬화λ₯Ό λ‹¬μ„±ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. 즉, λ™μΌν•œ μ½”λ“œμ˜ 볡사본을 각각 λ‹€λ₯Έ 데이터 μž…λ ₯에 λŒ€ν•΄ μ—¬λŸ¬ 개 μ‹€ν–‰ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. 더 μ •κ΅ν•œ μ ‘κ·Ό 방식도 κ°€λŠ₯ν•˜λ―€λ‘œ 관심이 μžˆμœΌμ‹œλ‹€λ©΄ [JAX λ¬Έμ„œ](https://jax.readthedocs.io/en/latest/index.html)와 [`pjit` νŽ˜μ΄μ§€](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit)μ—μ„œ 이 주제λ₯Ό μ‚΄νŽ΄λ³΄μ‹œκΈ° λ°”λžλ‹ˆλ‹€! `jax.pmap`은 두 가지 κΈ°λŠ₯을 μˆ˜ν–‰ν•©λ‹ˆλ‹€: - `jax.jit()`λ₯Ό ν˜ΈμΆœν•œ κ²ƒμ²˜λŸΌ μ½”λ“œλ₯Ό 컴파일(λ˜λŠ” `jit`)ν•©λ‹ˆλ‹€. 이 μž‘μ—…μ€ `pmap`을 ν˜ΈμΆœν•  λ•Œκ°€ μ•„λ‹ˆλΌ pmapped ν•¨μˆ˜κ°€ 처음 호좜될 λ•Œ μˆ˜ν–‰λ©λ‹ˆλ‹€. - 컴파일된 μ½”λ“œκ°€ μ‚¬μš© κ°€λŠ₯ν•œ λͺ¨λ“  κΈ°κΈ°μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰λ˜λ„λ‘ ν•©λ‹ˆλ‹€. μž‘λ™ 방식을 λ³΄μ—¬λ“œλ¦¬κΈ° μœ„ν•΄ 이미지 생성을 μ‹€ν–‰ν•˜λŠ” λΉ„κ³΅κ°œ λ©”μ„œλ“œμΈ νŒŒμ΄ν”„λΌμΈμ˜ `_generate` λ©”μ„œλ“œλ₯Ό `pmap`ν•©λ‹ˆλ‹€. 이 λ©”μ„œλ“œλŠ” ν–₯ν›„ `Diffusers` λ¦΄λ¦¬μŠ€μ—μ„œ 이름이 λ³€κ²½λ˜κ±°λ‚˜ 제거될 수 μžˆλ‹€λŠ” 점에 μœ μ˜ν•˜μ„Έμš”. ```python p_generate = pmap(pipeline._generate) ``` `pmap`을 μ‚¬μš©ν•œ ν›„ μ€€λΉ„λœ ν•¨μˆ˜ `p_generate`λŠ” κ°œλ…μ μœΌλ‘œ λ‹€μŒμ„ μˆ˜ν–‰ν•©λ‹ˆλ‹€: * 각 μž₯μΉ˜μ—μ„œ κΈ°λ³Έ ν•¨μˆ˜ `pipeline._generate`의 볡사본을 ν˜ΈμΆœν•©λ‹ˆλ‹€. * 각 μž₯μΉ˜μ— μž…λ ₯ 인수의 λ‹€λ₯Έ 뢀뢄을 λ³΄λƒ…λ‹ˆλ‹€. 이것이 λ°”λ‘œ 샀딩이 μ‚¬μš©λ˜λŠ” μ΄μœ μž…λ‹ˆλ‹€. 이 경우 `prompt_ids`의 shape은 `(8, 1, 77, 768)`μž…λ‹ˆλ‹€. 이 배열은 8개둜 λΆ„ν• λ˜κ³  `_generate`의 각 볡사본은 `(1, 77, 768)`의 shape을 가진 μž…λ ₯을 λ°›κ²Œ λ©λ‹ˆλ‹€. λ³‘λ ¬λ‘œ ν˜ΈμΆœλœλ‹€λŠ” 사싀을 μ™„μ „νžˆ λ¬΄μ‹œν•˜κ³  `_generate`λ₯Ό μ½”λ”©ν•  수 μžˆμŠ΅λ‹ˆλ‹€. batch(배치) 크기(이 μ˜ˆμ œμ—μ„œλŠ” `1`)와 μ½”λ“œμ— μ ν•©ν•œ μ°¨μ›λ§Œ μ‹ κ²½ μ“°λ©΄ 되며, λ³‘λ ¬λ‘œ μž‘λ™ν•˜κΈ° μœ„ν•΄ 아무것도 λ³€κ²½ν•  ν•„μš”κ°€ μ—†μŠ΅λ‹ˆλ‹€. νŒŒμ΄ν”„λΌμΈ ν˜ΈμΆœμ„ μ‚¬μš©ν•  λ•Œμ™€ λ§ˆμ°¬κ°€μ§€λ‘œ, λ‹€μŒ 셀을 처음 μ‹€ν–‰ν•  λ•ŒλŠ” μ‹œκ°„μ΄ κ±Έλ¦¬μ§€λ§Œ κ·Έ μ΄ν›„μ—λŠ” 훨씬 λΉ¨λΌμ§‘λ‹ˆλ‹€. ``` %%time images = p_generate(prompt_ids, p_params, rng) images = images.block_until_ready() images.shape ``` ```python out CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s Wall time: 1min 15s ``` ```python images.shape ``` ```python out (8, 1, 512, 512, 3) ``` JAXλŠ” 비동기 λ””μŠ€νŒ¨μΉ˜λ₯Ό μ‚¬μš©ν•˜κ³  κ°€λŠ₯ν•œ ν•œ 빨리 μ œμ–΄κΆŒμ„ Python 루프에 λ°˜ν™˜ν•˜κΈ° λ•Œλ¬Έμ— μΆ”λ‘  μ‹œκ°„μ„ μ •ν™•ν•˜κ²Œ μΈ‘μ •ν•˜κΈ° μœ„ν•΄ `block_until_ready()`λ₯Ό μ‚¬μš©ν•©λ‹ˆλ‹€. 아직 κ΅¬μ²΄ν™”λ˜μ§€ μ•Šμ€ 계산 κ²°κ³Όλ₯Ό μ‚¬μš©ν•˜λ €λŠ” 경우 μžλ™μœΌλ‘œ 차단이 μˆ˜ν–‰λ˜λ―€λ‘œ μ½”λ“œμ—μ„œ 이 ν•¨μˆ˜λ₯Ό μ‚¬μš©ν•  ν•„μš”κ°€ μ—†μŠ΅λ‹ˆλ‹€.