File size: 458 Bytes
864ec44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
# Author: Bingxin Ke
# Last modified: 2024-05-17
from .marigold_trainer import MarigoldTrainer
from .marigold_xl_trainer import MarigoldXLTrainer
from .marigold_inpaint_trainer import MarigoldInpaintTrainer
# Registry mapping each trainer's name string to the trainer class imported
# above; get_trainer_cls() resolves names through this mapping.
trainer_cls_name_dict = dict(
    MarigoldTrainer=MarigoldTrainer,
    MarigoldXLTrainer=MarigoldXLTrainer,
    MarigoldInpaintTrainer=MarigoldInpaintTrainer,
)
def get_trainer_cls(trainer_name):
    """Return the trainer class registered under ``trainer_name``.

    Args:
        trainer_name: A key of ``trainer_cls_name_dict``, e.g.
            ``"MarigoldTrainer"``, ``"MarigoldXLTrainer"`` or
            ``"MarigoldInpaintTrainer"``.

    Returns:
        The value stored in ``trainer_cls_name_dict`` for ``trainer_name``
        (the trainer class itself, not an instance).

    Raises:
        KeyError: If ``trainer_name`` is not registered. The message lists
            the registered names so a mistyped name is easy to diagnose.
    """
    try:
        return trainer_cls_name_dict[trainer_name]
    except KeyError:
        # Re-raise the same exception type (callers may catch KeyError),
        # but with an actionable message instead of just the bad key.
        available = ", ".join(sorted(trainer_cls_name_dict))
        raise KeyError(
            f"Unknown trainer {trainer_name!r}; expected one of: {available}"
        ) from None
|