File size: 458 Bytes
864ec44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Author: Bingxin Ke
# Last modified: 2024-05-17

from .marigold_trainer import MarigoldTrainer
from .marigold_xl_trainer import MarigoldXLTrainer
from .marigold_inpaint_trainer import MarigoldInpaintTrainer

# Registry of available trainer classes, keyed by the class name string that
# get_trainer_cls() is called with. Register new trainers here.
trainer_cls_name_dict = dict(
    MarigoldTrainer=MarigoldTrainer,
    MarigoldXLTrainer=MarigoldXLTrainer,
    MarigoldInpaintTrainer=MarigoldInpaintTrainer,
)


def get_trainer_cls(trainer_name):
    """Look up a trainer class by its registered name.

    Args:
        trainer_name: Key in ``trainer_cls_name_dict``, e.g. ``"MarigoldTrainer"``.

    Returns:
        The class object registered under ``trainer_name``.

    Raises:
        KeyError: If ``trainer_name`` is not registered. The message lists the
            registered names so a typo in a config is easy to spot.
    """
    try:
        return trainer_cls_name_dict[trainer_name]
    except KeyError:
        # Re-raise as KeyError so callers that already catch it keep working,
        # but include the valid choices instead of just echoing the bad key.
        available = ", ".join(sorted(trainer_cls_name_dict))
        raise KeyError(
            f"Unknown trainer {trainer_name!r}; available trainers: {available}"
        ) from None