File size: 458 Bytes
1eced3c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from test_model import TestModels
from train import TrainModel
from config import DatasetName, DatasetType
def main(h5_address='./trained_models/AffectNet_6336.h5',
         img_path='./img.jpg',
         arch="xcp",
         weight_path="./"):
    """Run the pre-trained-model test on one image, then run training.

    The testing step builds a ``TestModels`` from ``h5_address`` and calls
    ``recognize_fer`` on ``img_path``. The training step then builds a
    ``TrainModel`` for ``DatasetName.affectnet`` / ``DatasetType.train_7``
    and calls ``train`` with ``arch`` and ``weight_path``. The two steps
    always run in that order.

    The defaults reproduce the original hard-coded script values, so running
    the module as a script behaves exactly as before.

    Args:
        h5_address: Path passed to ``TestModels`` as ``h5_address``
            (presumably a saved Keras/HDF5 model file — confirm in
            ``test_model``).
        img_path: Image path passed to ``TestModels.recognize_fer``.
        arch: Architecture identifier passed to ``TrainModel.train``
            (meaning of ``"xcp"`` is defined in ``train``, not here).
        weight_path: Value passed to ``TrainModel.train`` as
            ``weight_path``. NOTE(review): the default ``"./"`` is a
            directory; confirm what ``train`` expects here.
    """
    # Testing the pre-trained model on a single image.
    tester = TestModels(h5_address=h5_address)
    tester.recognize_fer(img_path=img_path)

    # Training part.
    trainer = TrainModel(dataset_name=DatasetName.affectnet,
                         ds_type=DatasetType.train_7)
    trainer.train(arch=arch, weight_path=weight_path)


if __name__ == "__main__":
    main()