StableTTS1.1/utils/load.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def continue_training(checkpoint_path, model: DDP, optimizer: optim.Optimizer) -> int:
    """Load the latest model and optimizer checkpoints and return the epoch to resume from."""
    model_dict = {}
    optimizer_dict = {}

    # collect all checkpoint files in the directory, keyed by epoch number
    for file in os.listdir(checkpoint_path):
        if file.endswith(".pt") and '_' in file:
            name, epoch_str = file.rsplit('_', 1)
            epoch = int(epoch_str.split('.')[0])
            if name.startswith("checkpoint"):
                model_dict[epoch] = file
            elif name.startswith("optimizer"):
                optimizer_dict[epoch] = file

    # resume from the largest epoch that has both a model and an optimizer checkpoint
    common_epochs = set(model_dict.keys()) & set(optimizer_dict.keys())
    if common_epochs:
        max_epoch = max(common_epochs)
        model_path = os.path.join(checkpoint_path, model_dict[max_epoch])
        optimizer_path = os.path.join(checkpoint_path, optimizer_dict[max_epoch])

        # load model and optimizer states
        model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
        optimizer.load_state_dict(torch.load(optimizer_path, map_location='cpu'))
        print(f'Resumed model and optimizer from epoch {max_epoch}')
        return max_epoch + 1
    else:
        # no matching pair found: load the latest model checkpoint as a pretrained initialization
        if model_dict:
            model_path = os.path.join(checkpoint_path, model_dict[max(model_dict.keys())])
            model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
        return 0
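

# --- Usage sketch (illustrative only, not part of the original file) ---
# A minimal, hypothetical example of how continue_training might be called from a
# training script. The names `checkpoint_dir`, `MyModel`, `local_rank`, `num_epochs`,
# and `train_one_epoch` are assumptions and must be replaced with the project's own
# setup. The function expects checkpoint files saved as `checkpoint_{epoch}.pt` and
# `optimizer_{epoch}.pt` inside `checkpoint_dir`.
#
#   model = DDP(MyModel().cuda(), device_ids=[local_rank])
#   optimizer = optim.AdamW(model.parameters(), lr=1e-4)
#   start_epoch = continue_training(checkpoint_dir, model, optimizer)
#   for epoch in range(start_epoch, num_epochs):
#       train_one_epoch(model, optimizer, epoch)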