FrankZxShen
commited on
Commit
•
0f66f70
1
Parent(s):
29755d6
Update utils.py
Browse files
utils.py
CHANGED
@@ -10,8 +10,6 @@ from scipy.io.wavfile import read
|
|
10 |
import torch
|
11 |
import regex as re
|
12 |
|
13 |
-
import loralib as lora
|
14 |
-
|
15 |
MATPLOTLIB_FLAG = False
|
16 |
|
17 |
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
@@ -146,91 +144,6 @@ def tag_cke(text,prev_sentence=None):
|
|
146 |
return prev_lang,tagged_text
|
147 |
|
148 |
|
149 |
-
def load_lora_checkpoint(checkpoint_path, model, optimizer=None, generator_path = "./pretrained_models/G_latest.pth"):
|
150 |
-
assert os.path.isfile(checkpoint_path)
|
151 |
-
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
152 |
-
iteration = checkpoint_dict['iteration']
|
153 |
-
learning_rate = checkpoint_dict['learning_rate']
|
154 |
-
if optimizer is not None:
|
155 |
-
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
156 |
-
# saved_state_dict = checkpoint_dict['model']
|
157 |
-
generator_state_dict = torch.load(generator_path)['model']
|
158 |
-
lora_state_dict = checkpoint_dict['model']
|
159 |
-
new_state_dict = {}
|
160 |
-
for k, v in lora_state_dict.items():
|
161 |
-
try:
|
162 |
-
if k == 'emb_g.weight':
|
163 |
-
if drop_speaker_emb:
|
164 |
-
new_state_dict[k] = v
|
165 |
-
continue
|
166 |
-
v[:lora_state_dict[k].shape[0], :] = lora_state_dict[k]
|
167 |
-
new_state_dict[k] = v
|
168 |
-
else:
|
169 |
-
new_state_dict[k] = lora_state_dict[k]
|
170 |
-
except:
|
171 |
-
logger.info("%s is not in the checkpoint" % k)
|
172 |
-
new_state_dict[k] = v
|
173 |
-
if hasattr(model, 'module'):
|
174 |
-
model.module.load_state_dict(generator_state_dict, strict=False)
|
175 |
-
model.module.load_state_dict(new_state_dict, strict=False)
|
176 |
-
# lora.mark_only_lora_as_trainable(model.module)
|
177 |
-
else:
|
178 |
-
model.load_state_dict(generator_state_dict, strict=False)
|
179 |
-
model.load_state_dict(new_state_dict, strict=False)
|
180 |
-
# lora.mark_only_lora_as_trainable(model)
|
181 |
-
logger.info("Loaded checkpoint '{}' (iteration {})" .format(
|
182 |
-
checkpoint_path, iteration))
|
183 |
-
|
184 |
-
return model, optimizer, learning_rate, iteration
|
185 |
-
|
186 |
-
def save_lora_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
187 |
-
logger.info("Saving model and optimizer state at iteration {} to {}".format(
|
188 |
-
iteration, checkpoint_path))
|
189 |
-
if hasattr(model, 'module'):
|
190 |
-
state_dict = lora.lora_state_dict(model.module)
|
191 |
-
else:
|
192 |
-
state_dict = lora.lora_state_dict(model)
|
193 |
-
torch.save({'model': state_dict,
|
194 |
-
'iteration': iteration,
|
195 |
-
'optimizer': optimizer.state_dict() if optimizer is not None else None,
|
196 |
-
'learning_rate': learning_rate}, checkpoint_path)
|
197 |
-
|
198 |
-
def load_lora_checkpoint_fix(checkpoint_path, model, optimizer=None, drop_speaker_emb=False):
|
199 |
-
assert os.path.isfile(checkpoint_path)
|
200 |
-
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
201 |
-
iteration = checkpoint_dict['iteration']
|
202 |
-
learning_rate = checkpoint_dict['learning_rate']
|
203 |
-
if optimizer is not None:
|
204 |
-
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
205 |
-
saved_state_dict = checkpoint_dict['model']
|
206 |
-
if hasattr(model, 'module'):
|
207 |
-
state_dict = model.module.state_dict()
|
208 |
-
else:
|
209 |
-
state_dict = model.state_dict()
|
210 |
-
new_state_dict = {}
|
211 |
-
for k, v in state_dict.items():
|
212 |
-
try:
|
213 |
-
if k == 'emb_g.weight':
|
214 |
-
if drop_speaker_emb:
|
215 |
-
new_state_dict[k] = v
|
216 |
-
continue
|
217 |
-
v[:saved_state_dict[k].shape[0], :] = saved_state_dict[k]
|
218 |
-
new_state_dict[k] = v
|
219 |
-
else:
|
220 |
-
new_state_dict[k] = saved_state_dict[k]
|
221 |
-
except:
|
222 |
-
logger.info("%s is not in the checkpoint" % k)
|
223 |
-
new_state_dict[k] = v
|
224 |
-
if hasattr(model, 'module'):
|
225 |
-
model.module.load_state_dict(new_state_dict, strict=False)
|
226 |
-
lora.mark_only_lora_as_trainable(model.module)
|
227 |
-
else:
|
228 |
-
model.load_state_dict(new_state_dict, strict=False)
|
229 |
-
lora.mark_only_lora_as_trainable(model)
|
230 |
-
logger.info("Loaded checkpoint '{}' (iteration {})".format(
|
231 |
-
checkpoint_path, iteration))
|
232 |
-
return model, optimizer, learning_rate, iteration
|
233 |
-
|
234 |
def load_checkpoint(checkpoint_path, model, optimizer=None, drop_speaker_emb=False):
|
235 |
assert os.path.isfile(checkpoint_path)
|
236 |
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
|
|
10 |
import torch
|
11 |
import regex as re
|
12 |
|
|
|
|
|
13 |
MATPLOTLIB_FLAG = False
|
14 |
|
15 |
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
|
|
144 |
return prev_lang,tagged_text
|
145 |
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
def load_checkpoint(checkpoint_path, model, optimizer=None, drop_speaker_emb=False):
|
148 |
assert os.path.isfile(checkpoint_path)
|
149 |
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|