# NOTE: removed non-code page-extraction artifacts ("Spaces:" / "Runtime error")
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
def continue_training(checkpoint_path, model: DDP, optimizer: optim.Optimizer) -> int:
    """Resume training from the newest checkpoint pair in *checkpoint_path*.

    Scans the directory for files named ``checkpoint_<epoch>.pt`` and
    ``optimizer_<epoch>.pt``.  If an epoch has BOTH files, the newest such
    pair is loaded into ``model.module`` and ``optimizer`` and the next
    epoch number is returned.  Otherwise, if any model checkpoint exists,
    the newest one is loaded as a warm start and 0 is returned.

    Args:
        checkpoint_path: directory containing ``*.pt`` checkpoint files.
        model: DDP-wrapped model; weights are loaded into ``model.module``.
        optimizer: optimizer whose state is restored on a full resume.

    Returns:
        ``max_epoch + 1`` when a matching model/optimizer pair was loaded,
        else 0.
    """
    model_dict = {}
    optimizer_dict = {}
    # Collect all checkpoints in the directory, keyed by epoch number.
    for file in os.listdir(checkpoint_path):
        if file.endswith(".pt") and '_' in file:
            name, epoch_str = file.rsplit('_', 1)
            try:
                epoch = int(epoch_str.split('.')[0])
            except ValueError:
                # Skip files whose suffix is not an epoch number
                # (e.g. "checkpoint_best.pt") instead of crashing the scan.
                continue
            if name.startswith("checkpoint"):
                model_dict[epoch] = file
            elif name.startswith("optimizer"):
                optimizer_dict[epoch] = file
    # Resume only from an epoch that has BOTH a model and an optimizer file.
    common_epochs = set(model_dict) & set(optimizer_dict)
    if common_epochs:
        max_epoch = max(common_epochs)
        model_path = os.path.join(checkpoint_path, model_dict[max_epoch])
        optimizer_path = os.path.join(checkpoint_path, optimizer_dict[max_epoch])
        # map_location='cpu' keeps loading device-agnostic across ranks.
        model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
        optimizer.load_state_dict(torch.load(optimizer_path, map_location='cpu'))
        print(f'resume model and optimizer from {max_epoch} epoch')
        return max_epoch + 1
    # No complete pair: optionally warm-start from the newest model checkpoint.
    if model_dict:
        model_path = os.path.join(checkpoint_path, model_dict[max(model_dict)])
        model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
    return 0