Spaces:
Paused
์คํฌ๋ฆฝํธ๋ก ์คํํ๊ธฐ[[train-with-a-script]]
๐ค Transformers ๋ ธํธ๋ถ๊ณผ ํจ๊ป PyTorch, TensorFlow, ๋๋ JAX/Flax๋ฅผ ์ฌ์ฉํด ํน์ ํ์คํฌ์ ๋ํ ๋ชจ๋ธ์ ํ๋ จํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ฃผ๋ ์์ ์คํฌ๋ฆฝํธ๋ ์์ต๋๋ค.
๋ํ ์ฐ๊ตฌ ํ๋ก์ ํธ ๋ฐ ๋ ๊ฑฐ์ ์์ ์์ ๋๋ถ๋ถ ์ปค๋ฎค๋ํฐ์์ ์ ๊ณตํ ์คํฌ๋ฆฝํธ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค. ์ด๋ฌํ ์คํฌ๋ฆฝํธ๋ ์ ๊ทน์ ์ผ๋ก ์ ์ง ๊ด๋ฆฌ๋์ง ์์ผ๋ฉฐ ์ต์ ๋ฒ์ ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ํธํ๋์ง ์์ ๊ฐ๋ฅ์ฑ์ด ๋์ ํน์ ๋ฒ์ ์ ๐ค Transformers๋ฅผ ํ์๋ก ํฉ๋๋ค.
์์ ์คํฌ๋ฆฝํธ๊ฐ ๋ชจ๋ ๋ฌธ์ ์์ ๋ฐ๋ก ์๋ํ๋ ๊ฒ์ ์๋๋ฉฐ, ํด๊ฒฐํ๋ ค๋ ๋ฌธ์ ์ ๋ง๊ฒ ์คํฌ๋ฆฝํธ๋ฅผ ๋ณ๊ฒฝํด์ผ ํ ์๋ ์์ต๋๋ค. ์ด๋ฅผ ์ํด ๋๋ถ๋ถ์ ์คํฌ๋ฆฝํธ์๋ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ๋ฐฉ๋ฒ์ด ๋์์์ด ํ์์ ๋ฐ๋ผ ์์ ํ ์ ์์ต๋๋ค.
์์ ์คํฌ๋ฆฝํธ์ ๊ตฌํํ๊ณ ์ถ์ ๊ธฐ๋ฅ์ด ์์ผ๋ฉด pull request๋ฅผ ์ ์ถํ๊ธฐ ์ ์ ํฌ๋ผ ๋๋ ์ด์์์ ๋ ผ์ํด ์ฃผ์ธ์. ๋ฒ๊ทธ ์์ ์ ํ์ํ์ง๋ง ๊ฐ๋ ์ฑ์ ํฌ์ํ๋ฉด์๊น์ง ๋ ๋ง์ ๊ธฐ๋ฅ์ ์ถ๊ฐํ๋ pull request๋ ๋ณํฉ(merge)ํ์ง ์์ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค.
์ด ๊ฐ์ด๋์์๋ PyTorch ๋ฐ TensorFlow์์ ์์ฝ ํ๋ จํ๋ ์คํฌ๋ฆฝํธ ์์ ๋ฅผ ์คํํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค. ํน๋ณํ ์ค๋ช ์ด ์๋ ํ ๋ชจ๋ ์์ ๋ ๋ ํ๋ ์์ํฌ ๋ชจ๋์์ ์๋ํ ๊ฒ์ผ๋ก ์์๋ฉ๋๋ค.
์ค์ ํ๊ธฐ[[setup]]
์ต์ ๋ฒ์ ์ ์์ ์คํฌ๋ฆฝํธ๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ์คํํ๋ ค๋ฉด ์ ๊ฐ์ ํ๊ฒฝ์์ ์์ค๋ก๋ถํฐ ๐ค Transformers๋ฅผ ์ค์นํด์ผ ํฉ๋๋ค:
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
์ด์ ๋ฒ์ ์ ์์ ์คํฌ๋ฆฝํธ๋ฅผ ๋ณด๋ ค๋ฉด ์๋ ํ ๊ธ์ ํด๋ฆญํ์ธ์:
์ด์ ๋ฒ์ ์ ๐ค Transformers ์์
๊ทธ๋ฆฌ๊ณ ๋ค์๊ณผ ๊ฐ์ด ๋ณต์ (clone)ํด์จ ๐ค Transformers ๋ฒ์ ์ ํน์ ๋ฒ์ (์: v3.5.1)์ผ๋ก ์ ํํ์ธ์:
git checkout tags/v3.5.1
์ฌ๋ฐ๋ฅธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋ฒ์ ์ ์ค์ ํ ํ ์ํ๋ ์์ ํด๋๋ก ์ด๋ํ์ฌ ์์ ๋ณ๋ก ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋ํ ์๊ตฌ ์ฌํญ(requirements)์ ์ค์นํฉ๋๋ค:
pip install -r requirements.txt
์คํฌ๋ฆฝํธ ์คํํ๊ธฐ[[run-a-script]]
์์ ์คํฌ๋ฆฝํธ๋ ๐ค [Datasets](https://huggingface.co/docs/datasets/) ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ค์ด๋ก๋ํ๊ณ ์ ์ฒ๋ฆฌํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ ์คํฌ๋ฆฝํธ๋ ์์ฝ ๊ธฐ๋ฅ์ ์ง์ํ๋ ์ํคํ ์ฒ์์ [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer)๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ฏธ์ธ ์กฐ์ ํฉ๋๋ค. ๋ค์ ์๋ [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) ๋ฐ์ดํฐ ์ธํธ์์ [T5-small](https://huggingface.co/t5-small)์ ๋ฏธ์ธ ์กฐ์ ํฉ๋๋ค. T5 ๋ชจ๋ธ์ ํ๋ จ ๋ฐฉ์์ ๋ฐ๋ผ ์ถ๊ฐ `source_prefix` ์ธ์๊ฐ ํ์ํ๋ฉฐ, ์ด ํ๋กฌํํธ๋ ์์ฝ ์์ ์์ T5์ ์๋ ค์ค๋๋ค.python examples/pytorch/summarization/run_summarization.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate
์์ ์คํฌ๋ฆฝํธ๋ ๐ค [Datasets](https://huggingface.co/docs/datasets/) ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ค์ด๋ก๋ํ๊ณ ์ ์ฒ๋ฆฌํฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ ์คํฌ๋ฆฝํธ๋ ์์ฝ ๊ธฐ๋ฅ์ ์ง์ํ๋ ์ํคํ
์ฒ์์ Keras๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ฏธ์ธ ์กฐ์ ํฉ๋๋ค.
๋ค์ ์๋ [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) ๋ฐ์ดํฐ ์ธํธ์์ [T5-small](https://huggingface.co/t5-small)์ ๋ฏธ์ธ ์กฐ์ ํฉ๋๋ค.
T5 ๋ชจ๋ธ์ ํ๋ จ ๋ฐฉ์์ ๋ฐ๋ผ ์ถ๊ฐ `source_prefix` ์ธ์๊ฐ ํ์ํ๋ฉฐ, ์ด ํ๋กฌํํธ๋ ์์ฝ ์์
์์ T5์ ์๋ ค์ค๋๋ค.
```bash
python examples/tensorflow/summarization/run_summarization.py \
--model_name_or_path t5-small \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 16 \
--num_train_epochs 3 \
--do_train \
--do_eval
```
ํผํฉ ์ ๋ฐ๋(mixed precision)๋ก ๋ถ์ฐ ํ๋ จํ๊ธฐ[[distributed-training-and-mixed-precision]]
Trainer ํด๋์ค๋ ๋ถ์ฐ ํ๋ จ๊ณผ ํผํฉ ์ ๋ฐ๋(mixed precision)๋ฅผ ์ง์ํ๋ฏ๋ก ์คํฌ๋ฆฝํธ์์๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด ๋ ๊ฐ์ง ๊ธฐ๋ฅ์ ๋ชจ๋ ํ์ฑํํ๋ ค๋ฉด ๋ค์ ๋ ๊ฐ์ง๋ฅผ ์ค์ ํด์ผ ํฉ๋๋ค:
fp16
์ธ์๋ฅผ ์ถ๊ฐํด ํผํฉ ์ ๋ฐ๋(mixed precision)๋ฅผ ํ์ฑํํฉ๋๋ค.nproc_per_node
์ธ์๋ฅผ ์ถ๊ฐํด ์ฌ์ฉํ GPU ๊ฐ์๋ฅผ ์ค์ ํฉ๋๋ค.
python -m torch.distributed.launch \
--nproc_per_node 8 pytorch/summarization/run_summarization.py \
--fp16 \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate
TensorFlow ์คํฌ๋ฆฝํธ๋ ๋ถ์ฐ ํ๋ จ์ ์ํด MirroredStrategy
๋ฅผ ํ์ฉํ๋ฉฐ, ํ๋ จ ์คํฌ๋ฆฝํธ์ ์ธ์๋ฅผ ์ถ๊ฐํ ํ์๊ฐ ์์ต๋๋ค.
๋ค์ค GPU ํ๊ฒฝ์ด๋ผ๋ฉด, TensorFlow ์คํฌ๋ฆฝํธ๋ ๊ธฐ๋ณธ์ ์ผ๋ก ์ฌ๋ฌ ๊ฐ์ GPU๋ฅผ ์ฌ์ฉํฉ๋๋ค.
TPU ์์์ ์คํฌ๋ฆฝํธ ์คํํ๊ธฐ[[run-a-script-on-a-tpu]]
Tensor Processing Units (TPUs)๋ ์ฑ๋ฅ์ ๊ฐ์ํํ๊ธฐ ์ํด ํน๋ณํ ์ค๊ณ๋์์ต๋๋ค. PyTorch๋ [XLA](https://www.tensorflow.org/xla) ๋ฅ๋ฌ๋ ์ปดํ์ผ๋ฌ์ ํจ๊ป TPU๋ฅผ ์ง์ํฉ๋๋ค(์์ธํ ๋ด์ฉ์ [์ฌ๊ธฐ](https://github.com/pytorch/xla/blob/master/README.md) ์ฐธ์กฐ). TPU๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด `xla_spawn.py` ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๊ณ `num_cores` ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ์ฉํ๋ ค๋ TPU ์ฝ์ด ์๋ฅผ ์ค์ ํฉ๋๋ค.python xla_spawn.py --num_cores 8 \
summarization/run_summarization.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate
Tensor Processing Units (TPUs)๋ ์ฑ๋ฅ์ ๊ฐ์ํํ๊ธฐ ์ํด ํน๋ณํ ์ค๊ณ๋์์ต๋๋ค.
TensorFlow ์คํฌ๋ฆฝํธ๋ TPU๋ฅผ ํ๋ จ์ ์ฌ์ฉํ๊ธฐ ์ํด [`TPUStrategy`](https://www.tensorflow.org/guide/distributed_training#tpustrategy)๋ฅผ ํ์ฉํฉ๋๋ค.
TPU๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด TPU ๋ฆฌ์์ค์ ์ด๋ฆ์ `tpu` ์ธ์์ ์ ๋ฌํฉ๋๋ค.
python run_summarization.py \
--tpu name_of_tpu_resource \
--model_name_or_path t5-small \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 16 \
--num_train_epochs 3 \
--do_train \
--do_eval
๐ค Accelerate๋ก ์คํฌ๋ฆฝํธ ์คํํ๊ธฐ[[run-a-script-with-accelerate]]
๐ค Accelerate๋ PyTorch ํ๋ จ ๊ณผ์ ์ ๋ํ ์์ ํ ๊ฐ์์ฑ์ ์ ์งํ๋ฉด์ ์ฌ๋ฌ ์ ํ์ ์ค์ (CPU ์ ์ฉ, ๋ค์ค GPU, TPU)์์ ๋ชจ๋ธ์ ํ๋ จํ ์ ์๋ ํตํฉ ๋ฐฉ๋ฒ์ ์ ๊ณตํ๋ PyTorch ์ ์ฉ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋๋ค. ๐ค Accelerate๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์:
์ฐธ๊ณ : Accelerate๋ ๋น ๋ฅด๊ฒ ๊ฐ๋ฐ ์ค์ด๋ฏ๋ก ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๋ ค๋ฉด accelerate๋ฅผ ์ค์นํด์ผ ํฉ๋๋ค.
pip install git+https://github.com/huggingface/accelerate
run_summarization.py
์คํฌ๋ฆฝํธ ๋์ run_summarization_no_trainer.py
์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
๐ค Accelerate ํด๋์ค๊ฐ ์ง์๋๋ ์คํฌ๋ฆฝํธ๋ ํด๋์ task_no_trainer.py
ํ์ผ์ด ์์ต๋๋ค.
๋ค์ ๋ช
๋ น์ ์คํํ์ฌ ๊ตฌ์ฑ ํ์ผ์ ์์ฑํ๊ณ ์ ์ฅํฉ๋๋ค:
accelerate config
์ค์ ์ ํ ์คํธํ์ฌ ์ฌ๋ฐ๋ฅด๊ฒ ๊ตฌ์ฑ๋์๋์ง ํ์ธํฉ๋๋ค:
accelerate test
์ด์ ํ๋ จ์ ์์ํ ์ค๋น๊ฐ ๋์์ต๋๋ค:
accelerate launch run_summarization_no_trainer.py \
--model_name_or_path t5-small \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir ~/tmp/tst-summarization
์ฌ์ฉ์ ์ ์ ๋ฐ์ดํฐ ์ธํธ ์ฌ์ฉํ๊ธฐ[[use-a-custom-dataset]]
์์ฝ ์คํฌ๋ฆฝํธ๋ ์ฌ์ฉ์ ์ง์ ๋ฐ์ดํฐ ์ธํธ๊ฐ CSV ๋๋ JSON ํ์ผ์ธ ๊ฒฝ์ฐ ์ง์ํฉ๋๋ค. ์ฌ์ฉ์ ์ง์ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ์๋ ๋ช ๊ฐ์ง ์ถ๊ฐ ์ธ์๋ฅผ ์ง์ ํด์ผ ํฉ๋๋ค:
train_file
๊ณผvalidation_file
์ ํ๋ จ ๋ฐ ๊ฒ์ฆ ํ์ผ์ ๊ฒฝ๋ก๋ฅผ ์ง์ ํฉ๋๋ค.text_column
์ ์์ฝํ ์ ๋ ฅ ํ ์คํธ์ ๋๋ค.summary_column
์ ์ถ๋ ฅํ ๋์ ํ ์คํธ์ ๋๋ค.
์ฌ์ฉ์ ์ง์ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ฌ์ฉํ๋ ์์ฝ ์คํฌ๋ฆฝํธ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
python examples/pytorch/summarization/run_summarization.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--train_file path_to_csv_or_jsonlines_file \
--validation_file path_to_csv_or_jsonlines_file \
--text_column text_column_name \
--summary_column summary_column_name \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--overwrite_output_dir \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--predict_with_generate
์คํฌ๋ฆฝํธ ํ ์คํธํ๊ธฐ[[test-a-script]]
์ ์ฒด ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋์์ผ๋ก ํ๋ จ์ ์๋ฃํ๋๋ฐ ๊ฝค ์ค๋ ์๊ฐ์ด ๊ฑธ๋ฆฌ๊ธฐ ๋๋ฌธ์, ์์ ๋ฐ์ดํฐ ์ธํธ์์ ๋ชจ๋ ๊ฒ์ด ์์๋๋ก ์คํ๋๋์ง ํ์ธํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๋ค์ ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ต๋ ์ํ ์๋ก ์๋ผ๋ ๋๋ค:
max_train_samples
max_eval_samples
max_predict_samples
python examples/pytorch/summarization/run_summarization.py \
--model_name_or_path t5-small \
--max_train_samples 50 \
--max_eval_samples 50 \
--max_predict_samples 50 \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate
๋ชจ๋ ์์ ์คํฌ๋ฆฝํธ๊ฐ max_predict_samples
์ธ์๋ฅผ ์ง์ํ์ง๋ ์์ต๋๋ค.
์คํฌ๋ฆฝํธ๊ฐ ์ด ์ธ์๋ฅผ ์ง์ํ๋์ง ํ์คํ์ง ์์ ๊ฒฝ์ฐ -h
์ธ์๋ฅผ ์ถ๊ฐํ์ฌ ํ์ธํ์ธ์:
examples/pytorch/summarization/run_summarization.py -h
์ฒดํฌํฌ์ธํธ(checkpoint)์์ ํ๋ จ ์ด์ด์ ํ๊ธฐ[[resume-training-from-checkpoint]]
๋ ๋ค๋ฅธ ์ ์ฉํ ์ต์ ์ ์ด์ ์ฒดํฌํฌ์ธํธ์์ ํ๋ จ์ ์ฌ๊ฐํ๋ ๊ฒ์ ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ํ๋ จ์ด ์ค๋จ๋๋๋ผ๋ ์ฒ์๋ถํฐ ๋ค์ ์์ํ์ง ์๊ณ ์ค๋จํ ๋ถ๋ถ๋ถํฐ ๋ค์ ์์ํ ์ ์์ต๋๋ค. ์ฒดํฌํฌ์ธํธ์์ ํ๋ จ์ ์ฌ๊ฐํ๋ ๋ฐฉ๋ฒ์๋ ๋ ๊ฐ์ง๊ฐ ์์ต๋๋ค.
์ฒซ ๋ฒ์งธ๋ output_dir previous_output_dir
์ธ์๋ฅผ ์ฌ์ฉํ์ฌ output_dir
์ ์ ์ฅ๋ ์ต์ ์ฒดํฌํฌ์ธํธ๋ถํฐ ํ๋ จ์ ์ฌ๊ฐํ๋ ๋ฐฉ๋ฒ์
๋๋ค.
์ด ๊ฒฝ์ฐ overwrite_output_dir
์ ์ ๊ฑฐํด์ผ ํฉ๋๋ค:
python examples/pytorch/summarization/run_summarization.py
--model_name_or_path t5-small \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--output_dir previous_output_dir \
--predict_with_generate
๋ ๋ฒ์งธ๋ resume_from_checkpoint path_to_specific_checkpoint
์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ํน์ ์ฒดํฌํฌ์ธํธ ํด๋์์ ํ๋ จ์ ์ฌ๊ฐํ๋ ๋ฐฉ๋ฒ์
๋๋ค.
python examples/pytorch/summarization/run_summarization.py
--model_name_or_path t5-small \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--resume_from_checkpoint path_to_specific_checkpoint \
--predict_with_generate
๋ชจ๋ธ ๊ณต์ ํ๊ธฐ[[share-your-model]]
๋ชจ๋ ์คํฌ๋ฆฝํธ๋ ์ต์ข ๋ชจ๋ธ์ Model Hub์ ์ ๋ก๋ํ ์ ์์ต๋๋ค. ์์ํ๊ธฐ ์ ์ Hugging Face์ ๋ก๊ทธ์ธํ๋์ง ํ์ธํ์ธ์:
huggingface-cli login
๊ทธ๋ฐ ๋ค์ ์คํฌ๋ฆฝํธ์ push_to_hub
์ธ์๋ฅผ ์ถ๊ฐํฉ๋๋ค.
์ด ์ธ์๋ Hugging Face ์ฌ์ฉ์ ์ด๋ฆ๊ณผ output_dir
์ ์ง์ ๋ ํด๋ ์ด๋ฆ์ผ๋ก ์ ์ฅ์๋ฅผ ์์ฑํฉ๋๋ค.
์ ์ฅ์์ ํน์ ์ด๋ฆ์ ์ง์ ํ๋ ค๋ฉด push_to_hub_model_id
์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๊ฐํฉ๋๋ค.
์ ์ฅ์๋ ๋ค์์คํ์ด์ค ์๋์ ์๋์ผ๋ก ๋์ด๋ฉ๋๋ค.
๋ค์ ์๋ ํน์ ์ ์ฅ์ ์ด๋ฆ์ผ๋ก ๋ชจ๋ธ์ ์
๋ก๋ํ๋ ๋ฐฉ๋ฒ์
๋๋ค:
python examples/pytorch/summarization/run_summarization.py
--model_name_or_path t5-small \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--push_to_hub \
--push_to_hub_model_id finetuned-t5-cnn_dailymail \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate