Spaces:
Runtime error
Runtime error
Update app_multi.py
Browse files- app_multi.py +5 -4
app_multi.py
CHANGED
@@ -106,11 +106,12 @@ for model_name in multi_cfg.get('models'):
|
|
106 |
if 'enc_p.emb_phone.weight' in cpt['weight']:
|
107 |
old_shape = cpt['weight']['enc_p.emb_phone.weight'].shape
|
108 |
new_shape = net_g.enc_p.emb_phone.weight.shape
|
109 |
-
|
110 |
-
print(f"
|
111 |
weight = cpt['weight']['enc_p.emb_phone.weight']
|
112 |
-
|
113 |
-
|
|
|
114 |
|
115 |
del net_g.enc_q
|
116 |
|
|
|
106 |
if 'enc_p.emb_phone.weight' in cpt['weight']:
|
107 |
old_shape = cpt['weight']['enc_p.emb_phone.weight'].shape
|
108 |
new_shape = net_g.enc_p.emb_phone.weight.shape
|
109 |
+
if old_shape != new_shape:
|
110 |
+
print(f"Upgrading enc_p.emb_phone.weight size: {old_shape} -> {new_shape}")
|
111 |
weight = cpt['weight']['enc_p.emb_phone.weight']
|
112 |
+
upgraded_weight = torch.zeros(new_shape)
|
113 |
+
upgraded_weight[:, :old_shape[1]] = weight
|
114 |
+
cpt['weight']['enc_p.emb_phone.weight'] = upgraded_weight
|
115 |
|
116 |
del net_g.enc_q
|
117 |
|