File size: 364 Bytes
a476bbf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import os
from typing import Dict
from transformers import TrainingArguments
def get_training_args(config_dict: Dict) -> TrainingArguments:
    """Build ``TrainingArguments`` from a plain dict and ensure its output directory exists.

    Args:
        config_dict: Mapping whose items are forwarded verbatim as keyword
            arguments to ``TrainingArguments``. Keys it does not accept are
            rejected by its constructor.

    Returns:
        The constructed ``TrainingArguments`` instance.

    Raises:
        FileExistsError: if ``config.output_dir`` exists but is not a directory.
        OSError: if the directory cannot be created (e.g. permissions).
        Any exception raised by ``TrainingArguments`` for invalid settings.
    """
    config = TrainingArguments(**config_dict)
    output_dir = config.output_dir
    if not os.path.isdir(output_dir):
        print(f"creating checkpoint directory at {output_dir}")
        # exist_ok=True: another process (e.g. a different rank of a
        # distributed launch) may create the directory between the isdir()
        # check above and this call; without it the slower process would
        # crash with FileExistsError. A non-directory at the path still raises.
        os.makedirs(output_dir, exist_ok=True)
    return config
|