import os
import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def continue_training(checkpoint_path: str, model: DDP, optimizer: optim.Optimizer) -> int:
    """Load the latest matching model/optimizer checkpoint pair and return the epoch to resume from."""
    model_dict = {}
    optimizer_dict = {}
    
    # Collect all checkpoints in the directory, keyed by epoch number
    # (expected file names: checkpoint_<epoch>.pt and optimizer_<epoch>.pt)
    for file in os.listdir(checkpoint_path):
        if file.endswith(".pt") and '_' in file:
            name, epoch_str = file.rsplit('_', 1)
            epoch_str = epoch_str.split('.')[0]
            if not epoch_str.isdigit():
                continue  # skip files without a numeric epoch suffix
            epoch = int(epoch_str)
            
            if name.startswith("checkpoint"):
                model_dict[epoch] = file
            elif name.startswith("optimizer"):
                optimizer_dict[epoch] = file
    
    # Resume from the largest epoch for which both a model and an optimizer checkpoint exist
    common_epochs = set(model_dict.keys()) & set(optimizer_dict.keys())
    if common_epochs:
        max_epoch = max(common_epochs)
        model_path = os.path.join(checkpoint_path, model_dict[max_epoch])
        optimizer_path = os.path.join(checkpoint_path, optimizer_dict[max_epoch])
        
        # Load model and optimizer state on CPU; the model is DDP-wrapped, so load into model.module
        model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
        optimizer.load_state_dict(torch.load(optimizer_path, map_location='cpu'))
        
        print(f'Resumed model and optimizer from epoch {max_epoch}')
        return max_epoch + 1
    
    else:
        # No matching model/optimizer pair: fall back to the latest model-only
        # checkpoint (e.g. a pretrained model) and start training from epoch 0
        if model_dict:
            model_path = os.path.join(checkpoint_path, model_dict[max(model_dict.keys())])
            model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
            
        return 0
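

# Usage sketch (illustrative only): how this helper would typically be called in a
# DDP training script launched with torchrun. The model, checkpoint directory, and
# hyperparameters below are hypothetical placeholders, not part of this file's API.
#
#   import torch.distributed as dist
#   import torch.nn as nn
#
#   dist.init_process_group("nccl")
#   local_rank = int(os.environ["LOCAL_RANK"])
#   net = nn.Linear(10, 2).cuda(local_rank)            # placeholder model
#   ddp_model = DDP(net, device_ids=[local_rank])
#   opt = optim.SGD(ddp_model.parameters(), lr=0.1)
#
#   start_epoch = continue_training("./checkpoints", ddp_model, opt)
#   for epoch in range(start_epoch, num_epochs):
#       ...  # train; save checkpoint_<epoch>.pt and optimizer_<epoch>.pt each epoch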