BartPoint commited on
Commit
00e5cd7
1 Parent(s): 9e23608

Update app_multi.py

Browse files
Files changed (1) hide show
  1. app_multi.py +5 -4
app_multi.py CHANGED
@@ -106,11 +106,12 @@ for model_name in multi_cfg.get('models'):
106
  if 'enc_p.emb_phone.weight' in cpt['weight']:
107
  old_shape = cpt['weight']['enc_p.emb_phone.weight'].shape
108
  new_shape = net_g.enc_p.emb_phone.weight.shape
109
- if old_shape != new_shape:
110
- print(f"Resizing enc_p.emb_phone.weight: {old_shape} -> {new_shape}")
111
  weight = cpt['weight']['enc_p.emb_phone.weight']
112
- resized_weight = weight[:, :new_shape[1]].resize_(new_shape)
113
- cpt['weight']['enc_p.emb_phone.weight'] = resized_weight
 
114
 
115
  del net_g.enc_q
116
 
 
106
  if 'enc_p.emb_phone.weight' in cpt['weight']:
107
  old_shape = cpt['weight']['enc_p.emb_phone.weight'].shape
108
  new_shape = net_g.enc_p.emb_phone.weight.shape
109
+ if old_shape != new_shape:
110
+ print(f"Upgrading enc_p.emb_phone.weight size: {old_shape} -> {new_shape}")
111
  weight = cpt['weight']['enc_p.emb_phone.weight']
112
+ upgraded_weight = torch.zeros(new_shape)
113
+ upgraded_weight[:, :old_shape[1]] = weight
114
+ cpt['weight']['enc_p.emb_phone.weight'] = upgraded_weight
115
 
116
  del net_g.enc_q
117