Train with a script[[train-with-a-script]]

Along with the 🤗 Transformers notebooks, there are also example scripts that show how to train a model for a specific task with PyTorch, TensorFlow, or JAX/Flax.

You will also find scripts, mostly contributed by the community, in the research projects and legacy examples. These scripts are not actively maintained and require a specific version of 🤗 Transformers that will most likely be incompatible with the latest version of the library.

The example scripts are not expected to work out of the box on every problem, and you may need to adapt a script to the problem you are trying to solve. To help with this, most scripts expose how the data is preprocessed, so you can edit that step as needed.

If there is a feature you would like to implement in an example script, please discuss it on the forum or in an issue before submitting a pull request. Bug fixes are welcome, but a pull request that adds more functionality at the cost of readability is unlikely to be merged.

This guide shows how to run an example summarization training script in PyTorch and TensorFlow. Unless otherwise noted, all examples are expected to work with both frameworks.

Setup[[setup]]

To run the latest version of the example scripts successfully, install 🤗 Transformers from source in a new virtual environment:

git clone https://github.com/huggingface/transformers
cd transformers
pip install .

To run an older version of the example scripts, switch your clone of 🤗 Transformers to the matching release tag, for example v3.5.1:

git checkout tags/v3.5.1

After setting up the correct library version, navigate to the example folder of your choice and install the example-specific requirements:

pip install -r requirements.txt

Run a script[[run-a-script]]

The PyTorch example script downloads and preprocesses a dataset from the 🤗 [Datasets](https://huggingface.co/docs/datasets/) library. The script then uses the [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) to fine-tune a model on that dataset, using an architecture that supports summarization. The following example fine-tunes [T5-small](https://huggingface.co/t5-small) on the [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) dataset. Because of how it was trained, the T5 model requires an additional `source_prefix` argument. This prompt tells T5 that the task is summarization.

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
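
As a rough illustration of what `source_prefix` does, the sketch below prepends the prefix to every input text before it would be tokenized. The article strings and the `add_prefix` helper are hypothetical and only stand in for the script's own preprocessing.

```python
# Hypothetical article texts; in the real script these come from the dataset.
articles = [
    "The city council approved the new budget on Monday.",
    "Researchers released a new open dataset for summarization.",
]

# Same value as the --source_prefix argument in the command above.
source_prefix = "summarize: "


def add_prefix(texts, prefix):
    """Return a new list with the task prefix prepended to each text."""
    return [prefix + text for text in texts]


model_inputs = add_prefix(articles, source_prefix)
print(model_inputs[0])
```

Every input the model sees then starts with the same task marker, which is how T5 is told that the task is summarization.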

The TensorFlow example script downloads and preprocesses a dataset from the 🤗 [Datasets](https://huggingface.co/docs/datasets/) library. The script then uses Keras to fine-tune a model on that dataset, using an architecture that supports summarization. The following example fine-tunes [T5-small](https://huggingface.co/t5-small) on the [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) dataset. Because of how it was trained, the T5 model requires an additional `source_prefix` argument. This prompt tells T5 that the task is summarization.

```bash
python examples/tensorflow/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_train_epochs 3 \
    --do_train \
    --do_eval
```

Distributed training and mixed precision[[distributed-training-and-mixed-precision]]

The Trainer class supports distributed training and mixed precision, so you can use both in a script. To enable both features, set the following two arguments:

  • Add the `fp16` argument to enable mixed precision.
  • Add the `nproc_per_node` argument to set the number of GPUs to use.

python -m torch.distributed.launch \
    --nproc_per_node 8 pytorch/summarization/run_summarization.py \
    --fp16 \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
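
The launch command above starts 8 processes, each with a per-device batch size of 4. Assuming one process per GPU and no gradient accumulation, the number of samples consumed per optimizer step can be estimated as in the sketch below. The `global_batch_size` helper is illustrative and not part of the example scripts.

```python
def global_batch_size(per_device_batch_size, num_processes, gradient_accumulation_steps=1):
    """Estimate how many samples contribute to one optimizer step.

    Assumes data-parallel training where every process sees its own batch.
    """
    return per_device_batch_size * num_processes * gradient_accumulation_steps


# Values taken from the command above: --nproc_per_node 8 and
# --per_device_train_batch_size=4.
print(global_batch_size(per_device_batch_size=4, num_processes=8))
```

This is worth keeping in mind when moving from a single GPU to several, since the learning rate is often tuned for a particular global batch size.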

TensorFlow scripts use MirroredStrategy for distributed training, so you do not need to add any arguments to the training script. If multiple GPUs are available, the TensorFlow script uses them by default.

Run a script on a TPU[[run-a-script-on-a-tpu]]

Tensor Processing Units (TPUs) are specifically designed to accelerate performance. PyTorch supports TPUs through the [XLA](https://www.tensorflow.org/xla) deep learning compiler (see [here](https://github.com/pytorch/xla/blob/master/README.md) for details). To use a TPU, launch the `xla_spawn.py` script and set the number of TPU cores you want to use with the `num_cores` argument.

python xla_spawn.py --num_cores 8 \
    summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate

TensorFlow scripts use [`TPUStrategy`](https://www.tensorflow.org/guide/distributed_training#tpustrategy) to train on TPUs. To use a TPU, pass the name of the TPU resource to the `tpu` argument.

python run_summarization.py \
    --tpu name_of_tpu_resource \
    --model_name_or_path t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --output_dir /tmp/tst-summarization  \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_train_epochs 3 \
    --do_train \
    --do_eval

Run a script with 🤗 Accelerate[[run-a-script-with-accelerate]]

🤗 Accelerate is a PyTorch-only library that offers a unified way to train a model on several types of setups (CPU-only, multiple GPUs, TPUs) while keeping complete visibility into the PyTorch training loop. Make sure 🤗 Accelerate is installed:

Note: Accelerate is under rapid development, so install it from the git repository to run the scripts.

pip install git+https://github.com/huggingface/accelerate

Use the run_summarization_no_trainer.py script instead of run_summarization.py. Scripts that support 🤗 Accelerate have a task_no_trainer.py file in their folder. Start by running the following command to create and save a configuration file:

accelerate config

Test your setup to make sure it is configured correctly:

accelerate test

Now you are ready to launch the training:

accelerate launch run_summarization_no_trainer.py \
    --model_name_or_path t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir ~/tmp/tst-summarization

Use a custom dataset[[use-a-custom-dataset]]

The summarization script supports custom datasets as long as they are CSV or JSON Lines files. When you use your own dataset, you need to specify a few additional arguments:

  • `train_file` and `validation_file` specify the paths to your training and validation files.
  • `text_column` is the column that holds the input text to summarize.
  • `summary_column` is the column that holds the target text to output.

A summarization command that uses a custom dataset looks like this:

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --train_file path_to_csv_or_jsonlines_file \
    --validation_file path_to_csv_or_jsonlines_file \
    --text_column text_column_name \
    --summary_column summary_column_name \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate
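
To make the expected file layout concrete, here is a minimal sketch that writes a JSON Lines file with one record per line and then reads it back. The column names `article` and `highlights`, the sample records, and the `write_jsonl` helper are all hypothetical. The `.json` file extension is also an assumption about what the script accepts, so check the script's help output if in doubt.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical records: one input text column and one summary column.
records = [
    {"article": "The council met on Monday and approved the budget.",
     "highlights": "Council approves budget."},
    {"article": "A new open dataset for summarization was released.",
     "highlights": "New summarization dataset released."},
]


def write_jsonl(rows, path):
    """Write one JSON object per line (the JSON Lines layout)."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


train_path = Path(tempfile.mkdtemp()) / "train.json"
write_jsonl(records, train_path)

# Read the file back to confirm that every line parses on its own.
with open(train_path, encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f]

print(sorted(loaded[0].keys()))
```

With a file like this, the command above would take `--train_file` pointing at the written path, `--text_column article`, and `--summary_column highlights`.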

Test a script[[test-a-script]]

Training on the full dataset can take a long time, so it is a good idea to first check on a small subset that everything runs as expected.

Use the following arguments to truncate the dataset to a maximum number of samples:

  • `max_train_samples`
  • `max_eval_samples`
  • `max_predict_samples`

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --max_train_samples 50 \
    --max_eval_samples 50 \
    --max_predict_samples 50 \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
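
Conceptually, these arguments cap each split at a fixed number of examples. The sketch below shows that idea on a plain Python list. The `truncate` helper is hypothetical and only approximates what the scripts do with a real dataset object.

```python
def truncate(samples, max_samples=None):
    """Keep at most max_samples items; leave the data untouched if no cap is set."""
    if max_samples is None:
        return samples
    # min() guards against a cap that is larger than the dataset itself.
    return samples[: min(len(samples), max_samples)]


# Stand-in for a training split with 1000 examples.
train_split = list(range(1000))

small_train = truncate(train_split, max_samples=50)
print(len(small_train))
```

A cap that exceeds the split size simply returns the whole split, so the same command works on datasets of any size.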

Not all example scripts support the max_predict_samples argument. If you are not sure whether your script supports it, add the -h argument to check:

python examples/pytorch/summarization/run_summarization.py -h

Resume training from a checkpoint[[resume-training-from-checkpoint]]

Another useful option is resuming training from a previous checkpoint. If training is interrupted, you can pick up where you left off instead of starting over. There are two ways to resume training from a checkpoint.

The first method uses the `output_dir previous_output_dir` argument to resume training from the latest checkpoint stored in output_dir. In this case, remove overwrite_output_dir:

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --output_dir previous_output_dir \
    --predict_with_generate
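
To illustrate what "latest checkpoint" means here, the sketch below scans an output directory for folders named `checkpoint-<step>` and picks the one with the highest step number. The `find_last_checkpoint` helper and the demo directory are hypothetical; this is a standalone approximation, not the library's own lookup.

```python
import re
import tempfile
from pathlib import Path

# Checkpoint folders are assumed to be named "checkpoint-<global step>".
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir):
    """Return the checkpoint folder with the highest step, or None if there is none."""
    candidates = []
    for entry in Path(output_dir).iterdir():
        match = CHECKPOINT_RE.match(entry.name)
        if entry.is_dir() and match:
            candidates.append((int(match.group(1)), entry))
    if not candidates:
        return None
    # Compare steps as integers: as strings, "500" would sort after "1500".
    return max(candidates, key=lambda pair: pair[0])[1]


# Build a throwaway directory that mimics an interrupted training run.
demo_dir = Path(tempfile.mkdtemp())
for step in (500, 1000, 1500):
    (demo_dir / f"checkpoint-{step}").mkdir()

last = find_last_checkpoint(demo_dir)
print(last.name)
```

If the directory contains no checkpoint folders, the helper returns `None`, which corresponds to training starting from scratch.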

The second method uses the `resume_from_checkpoint path_to_specific_checkpoint` argument to resume training from a specific checkpoint folder.

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --resume_from_checkpoint path_to_specific_checkpoint \
    --predict_with_generate

Share your model[[share-your-model]]

All scripts can upload your final model to the Model Hub. Make sure you are logged in to Hugging Face before you begin:

huggingface-cli login

Then add the push_to_hub argument to the script. This argument creates a repository named after your Hugging Face username and the folder name given in output_dir.
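
As a small sketch of that naming rule, the default repository name can be thought of as the username joined with the last component of the output directory. The `default_repo_id` helper and the `my-username` value are hypothetical and only illustrate the rule described above.

```python
from pathlib import Path


def default_repo_id(username, output_dir):
    """Combine the username with the final folder name of output_dir."""
    return f"{username}/{Path(output_dir).name}"


# "my-username" is a placeholder; the output_dir matches the commands in this guide.
print(default_repo_id("my-username", "/tmp/tst-summarization"))
```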

To give your repository a specific name, add it with the push_to_hub_model_id argument. The repository is automatically listed under your namespace. The following example shows how to upload a model with a specific repository name:

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --push_to_hub \
    --push_to_hub_model_id finetuned-t5-cnn_dailymail \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate