mrfakename committed
Commit 0374441
1 Parent(s): 38f1310
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ Data/OOD_texts.txt filter=lfs diff=lfs merge=lfs -text
+ Utils/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
+ Utils/PLBERT/step_1000000.t7 filter=lfs diff=lfs merge=lfs -text
API_DOCS.md ADDED
@@ -0,0 +1 @@
+ Coming soon
Configs/config.yml ADDED
@@ -0,0 +1,116 @@
+ log_dir: "Models/LJSpeech"
+ first_stage_path: "first_stage.pth"
+ save_freq: 2
+ log_interval: 10
+ device: "cuda"
+ epochs_1st: 200 # number of epochs for first stage training (pre-training)
+ epochs_2nd: 100 # number of epochs for second stage training (joint training)
+ batch_size: 16
+ max_len: 400 # maximum number of frames
+ pretrained_model: ""
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+ load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
+
+ F0_path: "Utils/JDC/bst.t7"
+ ASR_config: "Utils/ASR/config.yml"
+ ASR_path: "Utils/ASR/epoch_00080.pth"
+ PLBERT_dir: 'Utils/PLBERT/'
+
+ data_params:
+   train_data: "Data/train_list.txt"
+   val_data: "Data/val_list.txt"
+   root_path: "/local/LJSpeech-1.1/wavs"
+   OOD_data: "Data/OOD_texts.txt"
+   min_length: 50 # sample OOD texts until one of at least this length is obtained
+
+ preprocess_params:
+   sr: 24000
+   spect_params:
+     n_fft: 2048
+     win_length: 1200
+     hop_length: 300
+
+ model_params:
+   multispeaker: false
+
+   dim_in: 64
+   hidden_dim: 512
+   max_conv_dim: 512
+   n_layer: 3
+   n_mels: 80
+
+   n_token: 178 # number of phoneme tokens
+   max_dur: 50 # maximum duration of a single phoneme
+   style_dim: 128 # style vector size
+
+   dropout: 0.2
+
+   # config for decoder
+   decoder:
+     type: 'istftnet' # either hifigan or istftnet
+     resblock_kernel_sizes: [3,7,11]
+     upsample_rates: [10, 6]
+     upsample_initial_channel: 512
+     resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+     upsample_kernel_sizes: [20, 12]
+     gen_istft_n_fft: 20
+     gen_istft_hop_size: 5
+
+   # speech language model config
+   slm:
+     model: 'microsoft/wavlm-base-plus'
+     sr: 16000 # sampling rate of SLM
+     hidden: 768 # hidden size of SLM
+     nlayers: 13 # number of layers of SLM
+     initial_channel: 64 # initial channels of SLM discriminator head
+
+   # style diffusion model config
+   diffusion:
+     embedding_mask_proba: 0.1
+     # transformer config
+     transformer:
+       num_layers: 3
+       num_heads: 8
+       head_features: 64
+       multiplier: 2
+
+     # diffusion distribution config
+     dist:
+       sigma_data: 0.2 # placeholder, used if estimate_sigma_data is set to false
+       estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+       mean: -3.0
+       std: 1.0
+
+ loss_params:
+   lambda_mel: 5. # mel reconstruction loss
+   lambda_gen: 1. # generator loss
+   lambda_slm: 1. # slm feature matching loss
+
+   lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
+   lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
+   TMA_epoch: 50 # TMA starting epoch (1st stage)
+
+   lambda_F0: 1. # F0 reconstruction loss (2nd stage)
+   lambda_norm: 1. # norm reconstruction loss (2nd stage)
+   lambda_dur: 1. # duration loss (2nd stage)
+   lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
+   lambda_sty: 1. # style reconstruction loss (2nd stage)
+   lambda_diff: 1. # score matching loss (2nd stage)
+
+   diff_epoch: 20 # style diffusion starting epoch (2nd stage)
+   joint_epoch: 50 # joint training starting epoch (2nd stage)
+
+ optimizer_params:
+   lr: 0.0001 # general learning rate
+   bert_lr: 0.00001 # learning rate for PLBERT
+   ft_lr: 0.00001 # learning rate for acoustic modules
+
+ slmadv_params:
+   min_len: 400 # minimum length of samples
+   max_len: 500 # maximum length of samples
+   batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
+   iter: 10 # update the discriminator every this many generator updates
+   thresh: 5 # gradient norm above which the gradient is scaled
+   scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+   sig: 1.5 # sigma for differentiable duration modeling
+
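Note: the config above is consumed as plain YAML. A minimal sketch of loading it and reading a few fields, for illustration only (the training scripts are not part of this commit, and PyYAML is assumed to be available):

```python
import yaml  # PyYAML, assumed available in the training environment

with open("Configs/config.yml") as f:
    config = yaml.safe_load(f)

# Nested sections become plain dicts.
print(config["model_params"]["decoder"]["type"])  # 'istftnet'
print(config["loss_params"]["lambda_mel"])        # 5.0
print(config["data_params"]["OOD_data"])          # 'Data/OOD_texts.txt'
```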
Configs/config_ft.yml ADDED
@@ -0,0 +1,111 @@
+ log_dir: "Models/LJSpeech"
+ save_freq: 5
+ log_interval: 10
+ device: "cuda"
+ epochs: 50 # number of finetuning epochs (1 hour of data)
+ batch_size: 8
+ max_len: 400 # maximum number of frames
+ pretrained_model: "Models/LibriTTS/epochs_2nd_00020.pth"
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+ load_only_params: true # set to true if you do not want to load epoch numbers and optimizer parameters
+
+ F0_path: "Utils/JDC/bst.t7"
+ ASR_config: "Utils/ASR/config.yml"
+ ASR_path: "Utils/ASR/epoch_00080.pth"
+ PLBERT_dir: 'Utils/PLBERT/'
+
+ data_params:
+   train_data: "Data/train_list.txt"
+   val_data: "Data/val_list.txt"
+   root_path: "/local/LJSpeech-1.1/wavs"
+   OOD_data: "Data/OOD_texts.txt"
+   min_length: 50 # sample OOD texts until one of at least this length is obtained
+
+ preprocess_params:
+   sr: 24000
+   spect_params:
+     n_fft: 2048
+     win_length: 1200
+     hop_length: 300
+
+ model_params:
+   multispeaker: true
+
+   dim_in: 64
+   hidden_dim: 512
+   max_conv_dim: 512
+   n_layer: 3
+   n_mels: 80
+
+   n_token: 178 # number of phoneme tokens
+   max_dur: 50 # maximum duration of a single phoneme
+   style_dim: 128 # style vector size
+
+   dropout: 0.2
+
+   # config for decoder
+   decoder:
+     type: 'hifigan' # either hifigan or istftnet
+     resblock_kernel_sizes: [3,7,11]
+     upsample_rates: [10,5,3,2]
+     upsample_initial_channel: 512
+     resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+     upsample_kernel_sizes: [20,10,6,4]
+
+   # speech language model config
+   slm:
+     model: 'microsoft/wavlm-base-plus'
+     sr: 16000 # sampling rate of SLM
+     hidden: 768 # hidden size of SLM
+     nlayers: 13 # number of layers of SLM
+     initial_channel: 64 # initial channels of SLM discriminator head
+
+   # style diffusion model config
+   diffusion:
+     embedding_mask_proba: 0.1
+     # transformer config
+     transformer:
+       num_layers: 3
+       num_heads: 8
+       head_features: 64
+       multiplier: 2
+
+     # diffusion distribution config
+     dist:
+       sigma_data: 0.2 # placeholder, used if estimate_sigma_data is set to false
+       estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+       mean: -3.0
+       std: 1.0
+
+ loss_params:
+   lambda_mel: 5. # mel reconstruction loss
+   lambda_gen: 1. # generator loss
+   lambda_slm: 1. # slm feature matching loss
+
+   lambda_mono: 1. # monotonic alignment loss (TMA)
+   lambda_s2s: 1. # sequence-to-sequence loss (TMA)
+
+   lambda_F0: 1. # F0 reconstruction loss
+   lambda_norm: 1. # norm reconstruction loss
+   lambda_dur: 1. # duration loss
+   lambda_ce: 20. # duration predictor probability output CE loss
+   lambda_sty: 1. # style reconstruction loss
+   lambda_diff: 1. # score matching loss
+
+   diff_epoch: 10 # style diffusion starting epoch
+   joint_epoch: 30 # joint training starting epoch
+
+ optimizer_params:
+   lr: 0.0001 # general learning rate
+   bert_lr: 0.00001 # learning rate for PLBERT
+   ft_lr: 0.0001 # learning rate for acoustic modules
+
+ slmadv_params:
+   min_len: 400 # minimum length of samples
+   max_len: 500 # maximum length of samples
+   batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
+   iter: 10 # update the discriminator every this many generator updates
+   thresh: 5 # gradient norm above which the gradient is scaled
+   scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+   sig: 1.5 # sigma for differentiable duration modeling
+
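Note: the finetuning config swaps the decoder from istftnet to hifigan and changes the upsampling schedule. In both cases the total upsampling factor should reproduce one spectrogram hop (hop_length: 300 samples at 24 kHz); the quick consistency check below is stated as an assumption about how the decoders use these values, not as repo-verified behaviour:

```python
from functools import reduce
from operator import mul

hop_length = 300  # preprocess_params.spect_params.hop_length

# hifigan decoder (config_ft.yml / config_libritts.yml): 10*5*3*2 = 300
assert reduce(mul, [10, 5, 3, 2]) == hop_length

# istftnet decoder (config.yml): 10*6 = 60 upsampling, times the iSTFT hop of 5 = 300
assert reduce(mul, [10, 6]) * 5 == hop_length
```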
Configs/config_libritts.yml ADDED
@@ -0,0 +1,113 @@
+ log_dir: "Models/LibriTTS"
+ first_stage_path: "first_stage.pth"
+ save_freq: 1
+ log_interval: 10
+ device: "cuda"
+ epochs_1st: 50 # number of epochs for first stage training (pre-training)
+ epochs_2nd: 30 # number of epochs for second stage training (joint training)
+ batch_size: 16
+ max_len: 300 # maximum number of frames
+ pretrained_model: ""
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+ load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
+
+ F0_path: "Utils/JDC/bst.t7"
+ ASR_config: "Utils/ASR/config.yml"
+ ASR_path: "Utils/ASR/epoch_00080.pth"
+ PLBERT_dir: 'Utils/PLBERT/'
+
+ data_params:
+   train_data: "Data/train_list.txt"
+   val_data: "Data/val_list.txt"
+   root_path: ""
+   OOD_data: "Data/OOD_texts.txt"
+   min_length: 50 # sample OOD texts until one of at least this length is obtained
+
+ preprocess_params:
+   sr: 24000
+   spect_params:
+     n_fft: 2048
+     win_length: 1200
+     hop_length: 300
+
+ model_params:
+   multispeaker: true
+
+   dim_in: 64
+   hidden_dim: 512
+   max_conv_dim: 512
+   n_layer: 3
+   n_mels: 80
+
+   n_token: 178 # number of phoneme tokens
+   max_dur: 50 # maximum duration of a single phoneme
+   style_dim: 128 # style vector size
+
+   dropout: 0.2
+
+   # config for decoder
+   decoder:
+     type: 'hifigan' # either hifigan or istftnet
+     resblock_kernel_sizes: [3,7,11]
+     upsample_rates: [10,5,3,2]
+     upsample_initial_channel: 512
+     resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+     upsample_kernel_sizes: [20,10,6,4]
+
+   # speech language model config
+   slm:
+     model: 'microsoft/wavlm-base-plus'
+     sr: 16000 # sampling rate of SLM
+     hidden: 768 # hidden size of SLM
+     nlayers: 13 # number of layers of SLM
+     initial_channel: 64 # initial channels of SLM discriminator head
+
+   # style diffusion model config
+   diffusion:
+     embedding_mask_proba: 0.1
+     # transformer config
+     transformer:
+       num_layers: 3
+       num_heads: 8
+       head_features: 64
+       multiplier: 2
+
+     # diffusion distribution config
+     dist:
+       sigma_data: 0.2 # placeholder, used if estimate_sigma_data is set to false
+       estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+       mean: -3.0
+       std: 1.0
+
+ loss_params:
+   lambda_mel: 5. # mel reconstruction loss
+   lambda_gen: 1. # generator loss
+   lambda_slm: 1. # slm feature matching loss
+
+   lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
+   lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
+   TMA_epoch: 5 # TMA starting epoch (1st stage)
+
+   lambda_F0: 1. # F0 reconstruction loss (2nd stage)
+   lambda_norm: 1. # norm reconstruction loss (2nd stage)
+   lambda_dur: 1. # duration loss (2nd stage)
+   lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
+   lambda_sty: 1. # style reconstruction loss (2nd stage)
+   lambda_diff: 1. # score matching loss (2nd stage)
+
+   diff_epoch: 10 # style diffusion starting epoch (2nd stage)
+   joint_epoch: 15 # joint training starting epoch (2nd stage)
+
+ optimizer_params:
+   lr: 0.0001 # general learning rate
+   bert_lr: 0.00001 # learning rate for PLBERT
+   ft_lr: 0.00001 # learning rate for acoustic modules
+
+ slmadv_params:
+   min_len: 400 # minimum length of samples
+   max_len: 500 # maximum length of samples
+   batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
+   iter: 20 # update the discriminator every this many generator updates
+   thresh: 5 # gradient norm above which the gradient is scaled
+   scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+   sig: 1.5 # sigma for differentiable duration modeling
Data/OOD_texts.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0989ef6a9873b711befefcbe60660ced7a65532359277f766f4db504c558a72
+ size 31758898
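Note: Data/OOD_texts.txt is tracked through Git LFS, so what is committed here is only a pointer (spec version, object hash, size in bytes); the roughly 31 MB text file itself lives in LFS storage. A small sketch of reading such a pointer, for illustration only:

```python
def parse_lfs_pointer(path: str) -> dict:
    """Return the key/value fields of a Git LFS pointer file."""
    fields = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# Expected keys: 'version', 'oid' (sha256:...), 'size' (bytes of the real file).
```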
Data/train_list.txt ADDED
The diff for this file is too large to render. See raw diff
 
Data/val_list.txt ADDED
@@ -0,0 +1,100 @@
1
+ LJ022-0023.wav|ðɪ ˌoʊvɚwˈɛlmɪŋ mədʒˈɔːɹᵻɾi ʌv pˈiːpəl ɪn ðɪs kˈʌntɹi nˈoʊ hˌaʊ tə sˈɪft ðə wˈiːt fɹʌmðə tʃˈæf ɪn wʌt ðeɪ hˈɪɹ ænd wʌt ðeɪ ɹˈiːd .|0
2
+ LJ043-0030.wav|ɪf sˈʌmbɑːdi dˈɪd ðˈæt tə mˌiː , ɐ lˈaʊsi tɹˈɪk lˈaɪk ðˈæt , tə tˈeɪk maɪ wˈaɪf ɐwˈeɪ , ænd ˈɔːl ðə fˈɜːnɪtʃɚ , aɪ wʊd biː mˈæd æz hˈɛl , tˈuː .|0
3
+ LJ005-0201.wav|ˌæzˌɪz ʃˈoʊn baɪ ðə ɹᵻpˈoːɹt ʌvðə kəmˈɪʃənɚz tʊ ɪŋkwˈaɪɚɹ ˌɪntʊ ðə stˈeɪt ʌvðə mjuːnˈɪsɪpəl kˌɔːɹpɚɹˈeɪʃənz ɪn ˈeɪtiːn θˈɜːɾi fˈaɪv .|0
4
+ LJ001-0110.wav|ˈiːvən ðə kˈæslɑːn tˈaɪp wɛn ɛnlˈɑːɹdʒd ʃˈoʊz ɡɹˈeɪt ʃˈɔːɹtkʌmɪŋz ɪn ðɪs ɹᵻspˈɛkt :|0
5
+ LJ003-0345.wav|ˈɔːl ðə kəmˈɪɾi kʊd dˈuː ɪn ðɪs ɹᵻspˈɛkt wʌz tə θɹˈoʊ ðə ɹᵻspˌɑːnsəbˈɪlɪɾi ˌɔn ˈʌðɚz .|0
6
+ LJ007-0154.wav|ðiːz pˈʌndʒənt ænd wˈɛl ɡɹˈaʊndᵻd stɹˈɪktʃɚz ɐplˈaɪd wɪð stˈɪl ɡɹˈeɪɾɚ fˈoːɹs tə ðɪ ʌŋkənvˈɪktᵻd pɹˈɪzənɚ , ðə mˈæn hˌuː kˈeɪm tə ðə pɹˈɪzən ˈɪnəsənt , ænd stˈɪl ʌŋkəntˈæmᵻnˌeɪɾᵻd ,|0
7
+ LJ018-0098.wav|ænd ɹˈɛkəɡnˌaɪzd æz wˈʌn ʌvðə fɹˈiːkwɛntɚz ʌvðə bˈoʊɡəs lˈɔː stˈeɪʃənɚz . hɪz ɚɹˈɛst lˈɛd tə ðæt ʌv ˈʌðɚz .|0
8
+ LJ047-0044.wav|ˈɑːswəld wʌz , haʊˈɛvɚ , wˈɪlɪŋ tə dɪskˈʌs hɪz kˈɑːntækts wɪð sˈoʊviət ɐθˈɔːɹɪɾiz . hiː dᵻnˈaɪd hˌævɪŋ ˌɛni ɪnvˈɑːlvmənt wɪð sˈoʊviət ɪntˈɛlɪdʒəns ˈeɪdʒənsiz|0
9
+ LJ031-0038.wav|ðə fˈɜːst fɪzˈɪʃən tə sˈiː ðə pɹˈɛzɪdənt æt pˈɑːɹklənd hˈɑːspɪɾəl wʌz dˈɑːktɚ . tʃˈɑːɹlz dʒˈeɪ . kˈæɹɪkˌoʊ , ɐ ɹˈɛzᵻdənt ɪn dʒˈɛnɚɹəl sˈɜːdʒɚɹi .|0
10
+ LJ048-0194.wav|dˈʊɹɹɪŋ ðə mˈɔːɹnɪŋ ʌv noʊvˈɛmbɚ twˈɛnti tˈuː pɹˈaɪɚ tə ðə mˈoʊɾɚkˌeɪd .|0
11
+ LJ049-0026.wav|ˌɔn əkˈeɪʒən ðə sˈiːkɹᵻt sˈɜːvɪs hɐzbɪn pɚmˈɪɾᵻd tə hæv ɐn ˈeɪdʒənt ɹˈaɪdɪŋ ɪnðə pˈæsɪndʒɚ kəmpˈɑːɹtmənt wɪððə pɹˈɛzɪdənt .|0
12
+ LJ004-0152.wav|ɔːlðˈoʊ æt mˈɪstɚ . bˈʌkstənz vˈɪzɪt ɐ nˈuː dʒˈeɪl wʌz ɪn pɹˈɑːsɛs ʌv ɪɹˈɛkʃən , ðə fˈɜːst stˈɛp təwˈɔːɹdz ɹᵻfˈɔːɹm sˈɪns hˈaʊɚdz vˌɪzɪtˈeɪʃən ɪn sˈɛvəntˌiːn sˈɛvənti fˈoːɹ .|0
13
+ LJ008-0278.wav|ɔːɹ ðˈɛɹz mˌaɪt biː wˈʌn ʌv mˈɛni , ænd ɪt mˌaɪt biː kənsˈɪdɚd nˈɛsᵻsɚɹi tə dˈɑːlɚ mˌeɪk ɐn ɛɡzˈæmpəl.dˈɑːlɚ|0
14
+ LJ043-0002.wav|ðə wˈɔːɹəŋ kəmˈɪʃən ɹᵻpˈoːɹt . baɪ ðə pɹˈɛzɪdənts kəmˈɪʃən ɔnðɪ ɐsˌæsᵻnˈeɪʃən ʌv pɹˈɛzɪdənt kˈɛnədi . tʃˈæptɚ sˈɛvən . lˈiː hˈɑːɹvi ˈɑːswəld :|0
15
+ LJ009-0114.wav|mˈɪstɚ . wˈeɪkfiːld wˈaɪndz ˈʌp hɪz ɡɹˈæfɪk bˌʌt sˈʌmwʌt sɛnsˈeɪʃənəl ɐkˈaʊnt baɪ dᵻskɹˈaɪbɪŋ ɐnˈʌðɚ ɹᵻlˈɪdʒəs sˈɜːvɪs , wˌɪtʃ mˈeɪ ɐpɹˈoʊpɹɪˌeɪtli biː ɪnsˈɜːɾᵻd hˈɪɹ .|0
16
+ LJ028-0506.wav|ɐ mˈɑːdɚn ˈɑːɹɾɪst wʊdhɐv dˈɪfɪkˌʌlti ɪn dˌuːɪŋ sˈʌtʃ ˈækjʊɹət wˈɜːk .|0
17
+ LJ050-0168.wav|wɪððə pɚtˈɪkjʊlɚ pˈɜːpəsᵻz ʌvðɪ ˈeɪdʒənsi ɪnvˈɑːlvd . ðə kəmˈɪʃən ɹˈɛkəɡnˌaɪzᵻz ðæt ðɪs ɪz ɐ kˌɑːntɹəvˈɜːʃəl ˈɛɹiə|0
18
+ LJ039-0223.wav|ˈɑːswəldz mɚɹˈiːn tɹˈeɪnɪŋ ɪn mˈɑːɹksmənʃˌɪp , hɪz ˈʌðɚ ɹˈaɪfəl ɛkspˈiəɹɪəns ænd hɪz ɪstˈæblɪʃt fəmˌɪliˈæɹɪɾi wɪð ðɪs pɚtˈɪkjʊlɚ wˈɛpən|0
19
+ LJ029-0032.wav|ɐkˈoːɹdɪŋ tʊ oʊdˈɑːnəl , kwˈoʊt , wiː hæd ɐ mˈoʊɾɚkˌeɪd wɛɹˈɛvɚ kplˈʌsplʌs wˌɪtʃ hɐdbɪn bˌɪn hˈeɪstili sˈʌmənd fɚðə ðə pˈɜːpəs wiː wˈɛnt , ˈɛnd kwˈoʊt .|0
20
+ LJ031-0070.wav|dˈɑːktɚ . klˈɑːɹk , hˌuː mˈoʊst klˈoʊsli əbzˈɜːvd ðə hˈɛd wˈuːnd ,|0
21
+ LJ034-0198.wav|jˈuːɪnz , hˌuː wʌz ɔnðə saʊθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən stɹˈiːts tˈɛstᵻfˌaɪd ðæt hiː kʊd nˌɑːt dᵻskɹˈaɪb ðə mˈæn hiː sˈɔː ɪnðə wˈɪndoʊ .|0
22
+ LJ026-0068.wav|ˈɛnɚdʒi ˈɛntɚz ðə plˈænt , tʊ ɐ smˈɔːl ɛkstˈɛnt ,|0
23
+ LJ039-0075.wav|wˈʌns juː nˈoʊ ðæt juː mˈʌst pˌʊt ðə kɹˈɔshɛɹz ɔnðə tˈɑːɹɡɪt ænd ðæt ɪz ˈɔːl ðæt ɪz nˈɛsᵻsɚɹi .|0
24
+ LJ004-0096.wav|ðə fˈeɪɾəl kˈɑːnsɪkwənsᵻz wˈɛɹɑːf mˌaɪt biː pɹɪvˈɛntᵻd ɪf ðə dʒˈʌstɪsᵻz ʌvðə pˈiːs wɜː djˈuːli ˈɔːθɚɹˌaɪzd|0
25
+ LJ005-0014.wav|spˈiːkɪŋ ˌɔn ɐ dᵻbˈeɪt ˌɔn pɹˈɪzən mˈæɾɚz , hiː dᵻklˈɛɹd ðˈæt|0
26
+ LJ012-0161.wav|hiː wʌz ɹᵻpˈoːɹɾᵻd tə hæv fˈɔːlən ɐwˈeɪ tʊ ɐ ʃˈædoʊ .|0
27
+ LJ018-0239.wav|hɪz dˌɪsɐpˈɪɹəns ɡˈeɪv kˈʌlɚ ænd sˈʌbstəns tʊ ˈiːvəl ɹᵻpˈoːɹts ɔːlɹˌɛdi ɪn sˌɜːkjʊlˈeɪʃən ðætðə wɪl ænd kənvˈeɪəns əbˌʌv ɹᵻfˈɜːd tuː|0
28
+ LJ019-0257.wav|hˈɪɹ ðə tɹˈɛd wˈiːl wʌz ɪn jˈuːs , ðɛɹ sˈɛljʊlɚ kɹˈæŋks , ɔːɹ hˈɑːɹd lˈeɪbɚ məʃˈiːnz .|0
29
+ LJ028-0008.wav|juː tˈæp dʒˈɛntli wɪð jʊɹ hˈiːl əpˌɑːn ðə ʃˈoʊldɚɹ ʌvðə dɹˈoʊmdɚɹi tʊ ˈɜːdʒ hɜːɹ ˈɔn .|0
30
+ LJ024-0083.wav|ðɪs plˈæn ʌv mˈaɪn ɪz nˈoʊ ɐtˈæk ɔnðə kˈoːɹt ;|0
31
+ LJ042-0129.wav|nˈoʊ nˈaɪt klˈʌbz ɔːɹ bˈoʊlɪŋ ˈælɪz , nˈoʊ plˈeɪsᵻz ʌv ɹˌɛkɹiːˈeɪʃən ɛksˈɛpt ðə tɹˈeɪd jˈuːniən dˈænsᵻz . aɪ hæv hæd ɪnˈʌf .|0
32
+ LJ036-0103.wav|ðə pəlˈiːs ˈæskt hˌɪm wˈɛðɚ hiː kʊd pˈɪk ˈaʊt hɪz pˈæsɪndʒɚ fɹʌmðə lˈaɪnʌp .|0
33
+ LJ046-0058.wav|dˈʊɹɹɪŋ hɪz pɹˈɛzɪdənsi , fɹˈæŋklɪn dˈiː . ɹˈoʊzəvˌɛlt mˌeɪd ˈɔːlmoʊst fˈoːɹ hˈʌndɹɪd dʒˈɜːniz ænd tɹˈævəld mˈoːɹ ðɐn θɹˈiː hˈʌndɹɪd fˈɪfti θˈaʊzənd mˈaɪlz .|0
34
+ LJ014-0076.wav|hiː wʌz sˈiːn ˈæftɚwɚdz smˈoʊkɪŋ ænd tˈɔːkɪŋ wɪð hɪz hˈoʊsts ɪn ðɛɹ bˈæk pˈɑːɹlɚ , ænd nˈɛvɚ sˈiːn ɐɡˈɛn ɐlˈaɪv .|0
35
+ LJ002-0043.wav|lˈɔŋ nˈæɹoʊ ɹˈuːmz wˈʌn θˈɜːɾi sˈɪks fˈiːt , sˈɪks twˈɛnti θɹˈiː fˈiːt , ænd ðɪ ˈeɪtθ ˈeɪtiːn ,|0
36
+ LJ009-0076.wav|wiː kˈʌm tə ðə sˈɜːmən .|0
37
+ LJ017-0131.wav|ˈiːvən wɛn ðə hˈaɪ ʃˈɛɹɪf hæd tˈoʊld hˌɪm ðɛɹwˌʌz nˈoʊ pˌɑːsəbˈɪlɪɾi əvɚ ɹᵻpɹˈiːv , ænd wɪðˌɪn ɐ fjˈuː ˈaʊɚz ʌv ˌɛksɪkjˈuːʃən .|0
38
+ LJ046-0184.wav|bˌʌt ðɛɹ ɪz ɐ sˈɪstəm fɚðɪ ɪmˈiːdɪət nˌoʊɾɪfɪkˈeɪʃən ʌvðə sˈiːkɹᵻt sˈɜːvɪs baɪ ðə kənfˈaɪnɪŋ ˌɪnstɪtˈuːʃən wɛn ɐ sˈʌbdʒɛkt ɪz ɹᵻlˈiːst ɔːɹ ɛskˈeɪps .|0
39
+ LJ014-0263.wav|wˌɛn ˈʌðɚ plˈɛʒɚz pˈɔːld hiː tˈʊk ɐ θˈiəɾɚ , ænd pˈoʊzd æz ɐ mjuːnˈɪfɪsənt pˈeɪtɹən ʌvðə dɹəmˈæɾɪk ˈɑːɹt .|0
40
+ LJ042-0096.wav|ˈoʊld ɛkstʃˈeɪndʒ ɹˈeɪt ɪn ɐdˈɪʃən tə hɪz fˈæktɚɹi sˈælɚɹi ʌv ɐpɹˈɑːksɪmətli ˈiːkwəl ɐmˈaʊnt|0
41
+ LJ049-0050.wav|hˈɪl hæd bˈoʊθ fˈiːt ɔnðə kˈɑːɹ ænd wʌz klˈaɪmɪŋ ɐbˈoːɹd tʊ ɐsˈɪst pɹˈɛzɪdənt ænd mˈɪsɪz . kˈɛnədi .|0
42
+ LJ019-0186.wav|sˈiːɪŋ ðæt sˈɪns ðɪ ɪstˈæblɪʃmənt ʌvðə sˈɛntɹəl kɹˈɪmɪnəl kˈoːɹt , nˈuːɡeɪt ɹᵻsˈiːvd pɹˈɪzənɚz fɔːɹ tɹˈaɪəl fɹʌm sˈɛvɹəl kˈaʊntiz ,|0
43
+ LJ028-0307.wav|ðˈɛn lˈɛt twˈɛnti dˈeɪz pˈæs , ænd æt ðɪ ˈɛnd ʌv ðæt tˈaɪm stˈeɪʃən nˌɪɹ ðə tʃˈældæsəŋ ɡˈeɪts ɐ bˈɑːdi ʌv fˈoːɹ θˈaʊzənd .|0
44
+ LJ012-0235.wav|wˌaɪl ðeɪ wɜːɹ ɪn ɐ stˈeɪt ʌv ɪnsˌɛnsəbˈɪlɪɾi ðə mˈɜːdɚ wʌz kəmˈɪɾᵻd .|0
45
+ LJ034-0053.wav|ɹˈiːtʃt ðə sˈeɪm kəŋklˈuːʒən æz lætˈoʊnə ðætðə pɹˈɪnts fˈaʊnd ɔnðə kˈɑːɹtənz wɜː ðoʊz ʌv lˈiː hˈɑːɹvi ˈɑːswəld .|0
46
+ LJ014-0030.wav|ðiːz wɜː dˈæmnətˌoːɹi fˈækts wˌɪtʃ wˈɛl səpˈoːɹɾᵻd ðə pɹˌɑːsɪkjˈuːʃən .|0
47
+ LJ015-0203.wav|bˌʌt wɜː ðə pɹɪkˈɔːʃənz tˈuː mˈɪnɪt , ðə vˈɪdʒɪləns tˈuː klˈoʊs təbi ᵻlˈuːdᵻd ɔːɹ ˌoʊvɚkˈʌm ?|0
48
+ LJ028-0093.wav|bˌʌt hɪz skɹˈaɪb ɹˈoʊt ɪɾ ɪnðə mˈænɚ kˈʌstəmˌɛɹi fɚðə skɹˈaɪbz ʌv ðoʊz dˈeɪz tə ɹˈaɪt ʌv ðɛɹ ɹˈɔɪəl mˈæstɚz .|0
49
+ LJ002-0018.wav|ðɪ ɪnˈædɪkwəsi ʌvðə dʒˈeɪl wʌz nˈoʊɾɪst ænd ɹᵻpˈoːɹɾᵻd əpˌɑːn ɐɡˈɛn ænd ɐɡˈɛn baɪ ðə ɡɹˈænd dʒˈʊɹɹiz ʌvðə sˈɪɾi ʌv lˈʌndən ,|0
50
+ LJ028-0275.wav|æt lˈæst , ɪnðə twˈɛntiəθ mˈʌnθ ,|0
51
+ LJ012-0042.wav|wˌɪtʃ hiː kˈɛpt kənsˈiːld ɪn ɐ hˈaɪdɪŋ plˈeɪs wɪð ɐ tɹˈæp dˈoːɹ dʒˈʌst ˌʌndɚ hɪz bˈɛd .|0
52
+ LJ011-0096.wav|hiː mˈæɹid ɐ lˈeɪdi ˈɔːlsoʊ bᵻlˈɔŋɪŋ tə ðə səsˈaɪəɾi ʌv fɹˈɛndz , hˌuː bɹˈɔːt hˌɪm ɐ lˈɑːɹdʒ fˈɔːɹtʃʊn , wˈɪtʃ , ænd hɪz ˈoʊn mˈʌni , hiː pˌʊt ˌɪntʊ ɐ sˈɪɾi fˈɜːm ,|0
53
+ LJ036-0077.wav|ɹˈɑːdʒɚ dˈiː . kɹˈeɪɡ , ɐ dˈɛpjuːɾi ʃˈɛɹɪf ʌv dˈæləs kˈaʊnti ,|0
54
+ LJ016-0318.wav|ˈʌðɚɹ əfˈɪʃəlz , ɡɹˈeɪt lˈɔɪɚz , ɡˈʌvɚnɚz ʌv pɹˈɪzənz , ænd tʃˈæplɪnz səpˈoːɹɾᵻd ðɪs vjˈuː .|0
55
+ LJ013-0164.wav|hˌuː kˈeɪm fɹʌm hɪz ɹˈuːm ɹˈɛdi dɹˈɛst , ɐ səspˈɪʃəs sˈɜːkəmstˌæns , æz hiː wʌz ˈɔːlweɪz lˈeɪt ɪnðə mˈɔːɹnɪŋ .|0
56
+ LJ027-0141.wav|ɪz klˈoʊsli ɹᵻpɹədˈuːst ɪnðə lˈaɪf hˈɪstɚɹi ʌv ɛɡzˈɪstɪŋ dˈɪɹ . ɔːɹ , ɪn ˈʌðɚ wˈɜːdz ,|0
57
+ LJ028-0335.wav|ɐkˈoːɹdɪŋli ðeɪ kəmˈɪɾᵻd tə hˌɪm ðə kəmˈænd ʌv ðɛɹ hˈoʊl ˈɑːɹmi , ænd pˌʊt ðə kˈiːz ʌv ðɛɹ sˈɪɾi ˌɪntʊ hɪz hˈændz .|0
58
+ LJ031-0202.wav|mˈɪsɪz . kˈɛnədi tʃˈoʊz ðə hˈɑːspɪɾəl ɪn bəθˈɛzdə fɚðɪ ˈɔːtɑːpsi bɪkˈʌz ðə pɹˈɛzɪdənt hæd sˈɜːvd ɪnðə nˈeɪvi .|0
59
+ LJ021-0145.wav|fɹʌm ðoʊz wˈɪlɪŋ tə dʒˈɔɪn ɪn ɪstˈæblɪʃɪŋ ðɪs hˈoʊpt fɔːɹ pˈiəɹɪəd ʌv pˈiːs ,|0
60
+ LJ016-0288.wav|dˈɑːlɚ mˈuːlɚ , mˈuːlɚ , hiːz ðə mˈæn , dˈɑːlɚ tˈɪl ɐ daɪvˈɜːʒən wʌz kɹiːˈeɪɾᵻd baɪ ðɪ ɐpˈɪɹəns ʌvðə ɡˈæloʊz , wˌɪtʃ wʌz ɹᵻsˈiːvd wɪð kəntˈɪnjuːəs jˈɛlz .|0
61
+ LJ028-0081.wav|jˈɪɹz lˈeɪɾɚ , wˌɛn ðɪ ˌɑːɹkiːˈɑːlədʒˌɪsts kʊd ɹˈɛdili dɪstˈɪŋɡwɪʃ ðə fˈɔls fɹʌmðə tɹˈuː ,|0
62
+ LJ018-0081.wav|hɪz dᵻfˈɛns bˌiːɪŋ ðæt hiː hæd ɪntˈɛndᵻd tə kəmˈɪt sˈuːɪsˌaɪd , bˌʌt ðˈæt , ɔnðɪ ɐpˈɪɹəns ʌv ðɪs ˈɑːfɪsɚ hˌuː hæd ɹˈɔŋd hˌɪm ,|0
63
+ LJ021-0066.wav|təɡˌɛðɚ wɪð ɐ ɡɹˈeɪt ˈɪŋkɹiːs ɪnðə pˈeɪɹoʊlz , ðɛɹ hɐz kˈʌm ɐ səbstˈænʃəl ɹˈaɪz ɪnðə tˈoʊɾəl ʌv ɪndˈʌstɹɪəl pɹˈɑːfɪts|0
64
+ LJ009-0238.wav|ˈæftɚ ðɪs ðə ʃˈɛɹɪfs sˈɛnt fɔːɹ ɐnˈʌðɚ ɹˈoʊp , bˌʌt ðə spɛktˈeɪɾɚz ˌɪntəfˈɪɹd , ænd ðə mˈæn wʌz kˈæɹid bˈæk tə dʒˈeɪl .|0
65
+ LJ005-0079.wav|ænd ɪmpɹˈuːv ðə mˈɔːɹəlz ʌvðə pɹˈɪzənɚz , ænd ʃˌæl ɪnʃˈʊɹ ðə pɹˈɑːpɚ mˈɛʒɚɹ ʌv pˈʌnɪʃmənt tə kənvˈɪktᵻd əfˈɛndɚz .|0
66
+ LJ035-0019.wav|dɹˈoʊv tə ðə nɔːɹθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən , ænd pˈɑːɹkt ɐpɹˈɑːksɪmətli tˈɛn fˈiːt fɹʌmðə tɹˈæfɪk sˈɪɡnəl .|0
67
+ LJ036-0174.wav|ðɪs ɪz ðɪ ɐpɹˈɑːksɪmət tˈaɪm hiː ˈɛntɚd ðə ɹˈuːmɪŋhˌaʊs , ɐkˈoːɹdɪŋ tʊ ˈɜːliːn ɹˈɑːbɚts , ðə hˈaʊskiːpɚ ðˈɛɹ .|0
68
+ LJ046-0146.wav|ðə kɹaɪtˈiəɹɪə ɪn ɪfˈɛkt pɹˈaɪɚ tə noʊvˈɛmbɚ twˈɛnti tˈuː , nˈaɪntiːn sˈɪksti θɹˈiː , fɔːɹ dɪtˈɜːmɪnɪŋ wˈɛðɚ tʊ ɐksˈɛpt mətˈɪɹiəl fɚðə pˌiːˌɑːɹɹˈɛs dʒˈɛnɚɹəl fˈaɪlz|0
69
+ LJ017-0044.wav|ænd ðə dˈiːpɪst æŋzˈaɪəɾi wʌz fˈɛlt ðætðə kɹˈaɪm , ɪf kɹˈaɪm ðˈɛɹ hɐdbɪn , ʃˌʊd biː bɹˈɔːt hˈoʊm tʊ ɪts pˈɜːpɪtɹˌeɪɾɚ .|0
70
+ LJ017-0070.wav|bˌʌt hɪz spˈoːɹɾɪŋ ˌɑːpɚɹˈeɪʃənz dɪdnˌɑːt pɹˈɑːspɚ , ænd hiː bɪkˌeɪm ɐ nˈiːdi mˈæn , ˈɔːlweɪz dɹˈɪvən tə dˈɛspɚɹət stɹˈeɪts fɔːɹ kˈæʃ .|0
71
+ LJ014-0020.wav|hiː wʌz sˈuːn ˈæftɚwɚdz ɚɹˈɛstᵻd ˌɔn səspˈɪʃən , ænd ɐ sˈɜːtʃ ʌv hɪz lˈɑːdʒɪŋz bɹˈɔːt tə lˈaɪt sˈɛvɹəl ɡˈɑːɹmənts sˈætʃɚɹˌeɪɾᵻd wɪð blˈʌd ;|0
72
+ LJ016-0020.wav|hiː nˈɛvɚ ɹˈiːtʃt ðə sˈɪstɚn , bˌʌt fˈɛl bˈæk ˌɪntʊ ðə jˈɑːɹd , ˈɪndʒɚɹɪŋ hɪz lˈɛɡz sᵻvˈɪɹli .|0
73
+ LJ045-0230.wav|wˌɛn hiː wʌz fˈaɪnəli ˌæpɹihˈɛndᵻd ɪnðə tˈɛksəs θˈiəɾɚ . ɔːlðˈoʊ ɪɾ ɪz nˌɑːt fˈʊli kɚɹˈɑːbɚɹˌeɪɾᵻd baɪ ˈʌðɚz hˌuː wɜː pɹˈɛzənt ,|0
74
+ LJ035-0129.wav|ænd ʃiː mˈʌstɐv ɹˈʌn dˌaʊn ðə stˈɛɹz ɐhˈɛd ʌv ˈɑːswəld ænd wʊd pɹˈɑːbəbli hæv sˈiːn ɔːɹ hˈɜːd hˌɪm .|0
75
+ LJ008-0307.wav|ˈæftɚwɚdz ɛkspɹˈɛs ɐ wˈɪʃ tə mˈɜːdɚ ðə ɹᵻkˈoːɹdɚ fɔːɹ hˌævɪŋ kˈɛpt ðˌɛm sˌoʊ lˈɔŋ ɪn səspˈɛns .|0
76
+ LJ008-0294.wav|nˌɪɹli ɪndˈɛfɪnətli dᵻfˈɜːd .|0
77
+ LJ047-0148.wav|ˌɔn ɑːktˈoʊbɚ twˈɛnti fˈaɪv ,|0
78
+ LJ008-0111.wav|ðeɪ ˈɛntɚd ɐ dˈɑːlɚ stˈoʊŋ kˈoʊld ɹˈuːm , dˈɑːlɚɹ ænd wɜː pɹˈɛzəntli dʒˈɔɪnd baɪ ðə pɹˈɪzənɚ .|0
79
+ LJ034-0042.wav|ðæt hiː kʊd ˈoʊnli tˈɛstᵻfˌaɪ wɪð sˈɜːtənti ðætðə pɹˈɪnt wʌz lˈɛs ðɐn θɹˈiː dˈeɪz ˈoʊld .|0
80
+ LJ037-0234.wav|mˈɪsɪz . mˈɛɹi bɹˈɑːk , ðə wˈaɪf əvə mɪkˈænɪk hˌuː wˈɜːkt æt ðə stˈeɪʃən , wʌz ðɛɹ æt ðə tˈaɪm ænd ʃiː sˈɔː ɐ wˈaɪt mˈeɪl ,|0
81
+ LJ040-0002.wav|tʃˈæptɚ sˈɛvən . lˈiː hˈɑːɹvi ˈɑːswəld : bˈækɡɹaʊnd ænd pˈɑːsᵻbəl mˈoʊɾɪvz , pˈɑːɹt wˌʌn .|0
82
+ LJ045-0140.wav|ðɪ ˈɑːɹɡjuːmənts hiː jˈuːzd tə dʒˈʌstᵻfˌaɪ hɪz jˈuːs ʌvðɪ ˈeɪliəs sədʒˈɛst ðæt ˈɑːswəld mˌeɪhɐv kˈʌm tə θˈɪŋk ðætðə hˈoʊl wˈɜːld wʌz bᵻkˈʌmɪŋ ɪnvˈɑːlvd|0
83
+ LJ012-0035.wav|ðə nˈʌmbɚ ænd nˈeɪmz ˌɔn wˈɑːtʃᵻz , wɜː kˈɛɹfəli ɹᵻmˈuːvd ɔːɹ əblˈɪɾɚɹˌeɪɾᵻd ˈæftɚ ðə ɡˈʊdz pˈæst ˌaʊɾəv hɪz hˈændz .|0
84
+ LJ012-0250.wav|ɔnðə sˈɛvənθ dʒuːlˈaɪ , ˈeɪtiːn θˈɜːɾi sˈɛvən ,|0
85
+ LJ016-0179.wav|kəntɹˈæktᵻd wɪð ʃˈɛɹɪfs ænd kənvˈiːnɚz tə wˈɜːk baɪ ðə dʒˈɑːb .|0
86
+ LJ016-0138.wav|æɾə dˈɪstəns fɹʌmðə pɹˈɪzən .|0
87
+ LJ027-0052.wav|ðiːz pɹˈɪnsɪpəlz ʌv həmˈɑːlədʒi ɑːɹ ᵻsˈɛnʃəl tʊ ɐ kɚɹˈɛkt ɪntˌɜːpɹɪtˈeɪʃən ʌvðə fˈækts ʌv mɔːɹfˈɑːlədʒi .|0
88
+ LJ031-0134.wav|ˌɔn wˈʌn əkˈeɪʒən mˈɪsɪz . dʒˈɑːnsən , ɐkˈʌmpənid baɪ tˈuː sˈiːkɹᵻt sˈɜːvɪs ˈeɪdʒənts , lˈɛft ðə ɹˈuːm tə sˈiː mˈɪsɪz . kˈɛnədi ænd mˈɪsɪz . kˈɑːnæli .|0
89
+ LJ019-0273.wav|wˌɪtʃ sˌɜː dʒˈɑːʃjuːə dʒˈɛb tˈoʊld ðə kəmˈɪɾi hiː kənsˈɪdɚd ðə pɹˈɑːpɚɹ ˈɛlɪmənts ʌv pˈiːnəl dˈɪsɪplˌɪn .|0
90
+ LJ014-0110.wav|æt ðə fˈɜːst ðə bˈɑːksᵻz wɜːɹ ɪmpˈaʊndᵻd , ˈoʊpənd , ænd fˈaʊnd tə kəntˈeɪn mˈɛnɪəv oʊkˈɑːnɚz ɪfˈɛkts .|0
91
+ LJ034-0160.wav|ˌɔn bɹˈɛnənz sˈʌbsᵻkwənt sˈɜːʔn̩ aɪdˈɛntɪfɪkˈeɪʃən ʌv lˈiː hˈɑːɹvi ˈɑːswəld æz ðə mˈæn hiː sˈɔː fˈaɪɚ ðə ɹˈaɪfəl .|0
92
+ LJ038-0199.wav|ᵻlˈɛvən . ɪf aɪɐm ɐlˈaɪv ænd tˈeɪkən pɹˈɪzənɚ ,|0
93
+ LJ014-0010.wav|jˈɛt hiː kʊd nˌɑːt ˌoʊvɚkˈʌm ðə stɹˈeɪndʒ fˌæsᵻnˈeɪʃən ɪt hˈæd fɔːɹ hˌɪm , ænd ɹᵻmˈeɪnd baɪ ðə sˈaɪd ʌvðə kˈɔːɹps tˈɪl ðə stɹˈɛtʃɚ kˈeɪm .|0
94
+ LJ033-0047.wav|aɪ nˈoʊɾɪst wɛn aɪ wɛnt ˈaʊt ðætðə lˈaɪt wʌz ˈɔn , ˈɛnd kwˈoʊt ,|0
95
+ LJ040-0027.wav|hiː wʌz nˈɛvɚ sˈæɾɪsfˌaɪd wɪð ˈɛnɪθˌɪŋ .|0
96
+ LJ048-0228.wav|ænd ˈʌðɚz hˌuː wɜː pɹˈɛzənt sˈeɪ ðæt nˈoʊ ˈeɪdʒənt wʌz ɪnˈiːbɹɪˌeɪɾᵻd ɔːɹ ˈæktᵻd ɪmpɹˈɑːpɚli .|0
97
+ LJ003-0111.wav|hiː wʌz ɪŋ kˈɑːnsɪkwəns pˌʊt ˌaʊɾəv ðə pɹətˈɛkʃən ʌv ðɛɹ ɪntˈɜːnəl lˈɔː , ˈɛnd kwˈoʊt . ðɛɹ kˈoʊd wʌzɐ sˈʌbdʒɛkt ʌv sˌʌm kjˌʊɹɹɪˈɔsɪɾi .|0
98
+ LJ008-0258.wav|lˈɛt mˌiː ɹᵻtɹˈeɪs maɪ stˈɛps , ænd spˈiːk mˈoːɹ ɪn diːtˈeɪl ʌvðə tɹˈiːtmənt ʌvðə kəndˈɛmd ɪn ðoʊz blˈʌdθɜːsti ænd bɹˈuːɾəli ɪndˈɪfɹənt dˈeɪz ,|0
99
+ LJ029-0022.wav|ðɪ ɚɹˈɪdʒɪnəl plˈæŋ kˈɔːld fɚðə pɹˈɛzɪdənt tə spˈɛnd ˈoʊnli wˈʌn dˈeɪ ɪnðə stˈeɪt , mˌeɪkɪŋ wˈɜːlwɪnd vˈɪzɪts tə dˈæləs , fˈɔːɹt wˈɜːθ , sˌæn æntˈoʊnɪˌoʊ , ænd hjˈuːstən .|0
100
+ LJ004-0045.wav|mˈɪstɚ . stˈɜːdʒᵻz bˈoːɹn , sˌɜː dʒˈeɪmz mˈækɪntˌɑːʃ , sˌɜː dʒˈeɪmz skˈɑːɹlɪt , ænd wˈɪljəm wˈɪlbɚfˌoːɹs .|0
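Note: each entry in val_list.txt (and train_list.txt) is a pipe-separated triple: wav filename, phonemized text in IPA, and a speaker index (always 0 in this single-speaker LJSpeech setup). A minimal parsing sketch, assuming that three-field format holds for every line:

```python
def read_filelist(path: str):
    entries = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            wav, phonemes, speaker = line.rstrip("\n").split("|")
            entries.append((wav, phonemes, int(speaker)))
    return entries

val = read_filelist("Data/val_list.txt")
# e.g. ('LJ022-0023.wav', 'ðɪ ˌoʊvɚwˈɛlmɪŋ ...', 0)
```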
Modules/__init__.py ADDED
@@ -0,0 +1 @@
+
Modules/diffusion/__init__.py ADDED
@@ -0,0 +1 @@
+
Modules/diffusion/diffusion.py ADDED
@@ -0,0 +1,94 @@
+ from math import pi
+ from random import randint
+ from typing import Any, Optional, Sequence, Tuple, Union
+
+ import torch
+ from einops import rearrange
+ from torch import Tensor, nn
+ from tqdm import tqdm
+
+ from .utils import *
+ from .sampler import *
+
+ """
+ Diffusion Classes (generic for 1d data)
+ """
+
+
+ class Model1d(nn.Module):
+     def __init__(self, unet_type: str = "base", **kwargs):
+         super().__init__()
+         diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
+         self.unet = None
+         self.diffusion = None
+
+     def forward(self, x: Tensor, **kwargs) -> Tensor:
+         return self.diffusion(x, **kwargs)
+
+     def sample(self, *args, **kwargs) -> Tensor:
+         return self.diffusion.sample(*args, **kwargs)
+
+
+ """
+ Audio Diffusion Classes (specific for 1d audio data)
+ """
+
+
+ def get_default_model_kwargs():
+     return dict(
+         channels=128,
+         patch_size=16,
+         multipliers=[1, 2, 4, 4, 4, 4, 4],
+         factors=[4, 4, 4, 2, 2, 2],
+         num_blocks=[2, 2, 2, 2, 2, 2],
+         attentions=[0, 0, 0, 1, 1, 1, 1],
+         attention_heads=8,
+         attention_features=64,
+         attention_multiplier=2,
+         attention_use_rel_pos=False,
+         diffusion_type="v",
+         diffusion_sigma_distribution=UniformDistribution(),
+     )
+
+
+ def get_default_sampling_kwargs():
+     return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
+
+
+ class AudioDiffusionModel(Model1d):
+     def __init__(self, **kwargs):
+         super().__init__(**{**get_default_model_kwargs(), **kwargs})
+
+     def sample(self, *args, **kwargs):
+         return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
+
+
+ class AudioDiffusionConditional(Model1d):
+     def __init__(
+         self,
+         embedding_features: int,
+         embedding_max_length: int,
+         embedding_mask_proba: float = 0.1,
+         **kwargs,
+     ):
+         self.embedding_mask_proba = embedding_mask_proba
+         default_kwargs = dict(
+             **get_default_model_kwargs(),
+             unet_type="cfg",
+             context_embedding_features=embedding_features,
+             context_embedding_max_length=embedding_max_length,
+         )
+         super().__init__(**{**default_kwargs, **kwargs})
+
+     def forward(self, *args, **kwargs):
+         default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
+         return super().forward(*args, **{**default_kwargs, **kwargs})
+
+     def sample(self, *args, **kwargs):
+         default_kwargs = dict(
+             **get_default_sampling_kwargs(),
+             embedding_scale=5.0,
+         )
+         return super().sample(*args, **{**default_kwargs, **kwargs})
+
+
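Note: the wrappers above all rely on the same default-override idiom: defaults and caller kwargs are merged with dict unpacking, and later entries win, so anything passed by the caller takes precedence over get_default_model_kwargs() / get_default_sampling_kwargs(). A tiny illustration of that merge order (values are made up):

```python
defaults = dict(embedding_scale=5.0, clamp=True)
overrides = dict(embedding_scale=1.0)

merged = {**defaults, **overrides}  # right-most dict wins on key collisions
assert merged == {"embedding_scale": 1.0, "clamp": True}
```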
Modules/diffusion/modules.py ADDED
@@ -0,0 +1,693 @@
1
+ from math import floor, log, pi
2
+ from typing import Any, List, Optional, Sequence, Tuple, Union
3
+
4
+ from .utils import *
5
+
6
+ import torch
7
+ import torch.nn as nn
+ import torch.nn.functional as F  # F.layer_norm is used by AdaLayerNorm below
8
+ from einops import rearrange, reduce, repeat
9
+ from einops.layers.torch import Rearrange
10
+ from einops_exts import rearrange_many
11
+ from torch import Tensor, einsum
12
+
13
+
14
+ """
15
+ Utils
16
+ """
17
+
18
+ class AdaLayerNorm(nn.Module):
19
+ def __init__(self, style_dim, channels, eps=1e-5):
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.eps = eps
23
+
24
+ self.fc = nn.Linear(style_dim, channels*2)
25
+
26
+ def forward(self, x, s):
27
+ x = x.transpose(-1, -2)
28
+ x = x.transpose(1, -1)
29
+
30
+ h = self.fc(s)
31
+ h = h.view(h.size(0), h.size(1), 1)
32
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
33
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
34
+
35
+
36
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
37
+ x = (1 + gamma) * x + beta
38
+ return x.transpose(1, -1).transpose(-1, -2)
39
+
40
+ class StyleTransformer1d(nn.Module):
41
+ def __init__(
42
+ self,
43
+ num_layers: int,
44
+ channels: int,
45
+ num_heads: int,
46
+ head_features: int,
47
+ multiplier: int,
48
+ use_context_time: bool = True,
49
+ use_rel_pos: bool = False,
50
+ context_features_multiplier: int = 1,
51
+ rel_pos_num_buckets: Optional[int] = None,
52
+ rel_pos_max_distance: Optional[int] = None,
53
+ context_features: Optional[int] = None,
54
+ context_embedding_features: Optional[int] = None,
55
+ embedding_max_length: int = 512,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.blocks = nn.ModuleList(
60
+ [
61
+ StyleTransformerBlock(
62
+ features=channels + context_embedding_features,
63
+ head_features=head_features,
64
+ num_heads=num_heads,
65
+ multiplier=multiplier,
66
+ style_dim=context_features,
67
+ use_rel_pos=use_rel_pos,
68
+ rel_pos_num_buckets=rel_pos_num_buckets,
69
+ rel_pos_max_distance=rel_pos_max_distance,
70
+ )
71
+ for i in range(num_layers)
72
+ ]
73
+ )
74
+
75
+ self.to_out = nn.Sequential(
76
+ Rearrange("b t c -> b c t"),
77
+ nn.Conv1d(
78
+ in_channels=channels + context_embedding_features,
79
+ out_channels=channels,
80
+ kernel_size=1,
81
+ ),
82
+ )
83
+
84
+ use_context_features = exists(context_features)
85
+ self.use_context_features = use_context_features
86
+ self.use_context_time = use_context_time
87
+
88
+ if use_context_time or use_context_features:
89
+ context_mapping_features = channels + context_embedding_features
90
+
91
+ self.to_mapping = nn.Sequential(
92
+ nn.Linear(context_mapping_features, context_mapping_features),
93
+ nn.GELU(),
94
+ nn.Linear(context_mapping_features, context_mapping_features),
95
+ nn.GELU(),
96
+ )
97
+
98
+ if use_context_time:
99
+ assert exists(context_mapping_features)
100
+ self.to_time = nn.Sequential(
101
+ TimePositionalEmbedding(
102
+ dim=channels, out_features=context_mapping_features
103
+ ),
104
+ nn.GELU(),
105
+ )
106
+
107
+ if use_context_features:
108
+ assert exists(context_features) and exists(context_mapping_features)
109
+ self.to_features = nn.Sequential(
110
+ nn.Linear(
111
+ in_features=context_features, out_features=context_mapping_features
112
+ ),
113
+ nn.GELU(),
114
+ )
115
+
116
+ self.fixed_embedding = FixedEmbedding(
117
+ max_length=embedding_max_length, features=context_embedding_features
118
+ )
119
+
120
+
121
+ def get_mapping(
122
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
123
+ ) -> Optional[Tensor]:
124
+ """Combines context time features and features into mapping"""
125
+ items, mapping = [], None
126
+ # Compute time features
127
+ if self.use_context_time:
128
+ assert_message = "use_context_time=True but no time features provided"
129
+ assert exists(time), assert_message
130
+ items += [self.to_time(time)]
131
+ # Compute features
132
+ if self.use_context_features:
133
+ assert_message = "context_features exists but no features provided"
134
+ assert exists(features), assert_message
135
+ items += [self.to_features(features)]
136
+
137
+ # Compute joint mapping
138
+ if self.use_context_time or self.use_context_features:
139
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
140
+ mapping = self.to_mapping(mapping)
141
+
142
+ return mapping
143
+
144
+ def run(self, x, time, embedding, features):
145
+
146
+ mapping = self.get_mapping(time, features)
147
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
148
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
149
+
150
+ for block in self.blocks:
151
+ x = x + mapping
152
+ x = block(x, features)
153
+
154
+ x = x.mean(axis=1).unsqueeze(1)
155
+ x = self.to_out(x)
156
+ x = x.transpose(-1, -2)
157
+
158
+ return x
159
+
160
+ def forward(self, x: Tensor,
161
+ time: Tensor,
162
+ embedding_mask_proba: float = 0.0,
163
+ embedding: Optional[Tensor] = None,
164
+ features: Optional[Tensor] = None,
165
+ embedding_scale: float = 1.0) -> Tensor:
166
+
167
+ b, device = embedding.shape[0], embedding.device
168
+ fixed_embedding = self.fixed_embedding(embedding)
169
+ if embedding_mask_proba > 0.0:
170
+ # Randomly mask embedding
171
+ batch_mask = rand_bool(
172
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
173
+ )
174
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
175
+
176
+ if embedding_scale != 1.0:
177
+ # Compute both normal and fixed embedding outputs
178
+ out = self.run(x, time, embedding=embedding, features=features)
179
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
180
+ # Scale conditional output using classifier-free guidance
181
+ return out_masked + (out - out_masked) * embedding_scale
182
+ else:
183
+ return self.run(x, time, embedding=embedding, features=features)
184
+
185
+ return x
186
+
187
+
188
+ class StyleTransformerBlock(nn.Module):
189
+ def __init__(
190
+ self,
191
+ features: int,
192
+ num_heads: int,
193
+ head_features: int,
194
+ style_dim: int,
195
+ multiplier: int,
196
+ use_rel_pos: bool,
197
+ rel_pos_num_buckets: Optional[int] = None,
198
+ rel_pos_max_distance: Optional[int] = None,
199
+ context_features: Optional[int] = None,
200
+ ):
201
+ super().__init__()
202
+
203
+ self.use_cross_attention = exists(context_features) and context_features > 0
204
+
205
+ self.attention = StyleAttention(
206
+ features=features,
207
+ style_dim=style_dim,
208
+ num_heads=num_heads,
209
+ head_features=head_features,
210
+ use_rel_pos=use_rel_pos,
211
+ rel_pos_num_buckets=rel_pos_num_buckets,
212
+ rel_pos_max_distance=rel_pos_max_distance,
213
+ )
214
+
215
+ if self.use_cross_attention:
216
+ self.cross_attention = StyleAttention(
217
+ features=features,
218
+ style_dim=style_dim,
219
+ num_heads=num_heads,
220
+ head_features=head_features,
221
+ context_features=context_features,
222
+ use_rel_pos=use_rel_pos,
223
+ rel_pos_num_buckets=rel_pos_num_buckets,
224
+ rel_pos_max_distance=rel_pos_max_distance,
225
+ )
226
+
227
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
228
+
229
+ def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
230
+ x = self.attention(x, s) + x
231
+ if self.use_cross_attention:
232
+ x = self.cross_attention(x, s, context=context) + x
233
+ x = self.feed_forward(x) + x
234
+ return x
235
+
236
+ class StyleAttention(nn.Module):
237
+ def __init__(
238
+ self,
239
+ features: int,
240
+ *,
241
+ style_dim: int,
242
+ head_features: int,
243
+ num_heads: int,
244
+ context_features: Optional[int] = None,
245
+ use_rel_pos: bool,
246
+ rel_pos_num_buckets: Optional[int] = None,
247
+ rel_pos_max_distance: Optional[int] = None,
248
+ ):
249
+ super().__init__()
250
+ self.context_features = context_features
251
+ mid_features = head_features * num_heads
252
+ context_features = default(context_features, features)
253
+
254
+ self.norm = AdaLayerNorm(style_dim, features)
255
+ self.norm_context = AdaLayerNorm(style_dim, context_features)
256
+ self.to_q = nn.Linear(
257
+ in_features=features, out_features=mid_features, bias=False
258
+ )
259
+ self.to_kv = nn.Linear(
260
+ in_features=context_features, out_features=mid_features * 2, bias=False
261
+ )
262
+ self.attention = AttentionBase(
263
+ features,
264
+ num_heads=num_heads,
265
+ head_features=head_features,
266
+ use_rel_pos=use_rel_pos,
267
+ rel_pos_num_buckets=rel_pos_num_buckets,
268
+ rel_pos_max_distance=rel_pos_max_distance,
269
+ )
270
+
271
+ def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
272
+ assert_message = "You must provide a context when using context_features"
273
+ assert not self.context_features or exists(context), assert_message
274
+ # Use context if provided
275
+ context = default(context, x)
276
+ # Normalize then compute q from input and k,v from context
277
+ x, context = self.norm(x, s), self.norm_context(context, s)
278
+
279
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
280
+ # Compute and return attention
281
+ return self.attention(q, k, v)
282
+
283
+ class Transformer1d(nn.Module):
284
+ def __init__(
285
+ self,
286
+ num_layers: int,
287
+ channels: int,
288
+ num_heads: int,
289
+ head_features: int,
290
+ multiplier: int,
291
+ use_context_time: bool = True,
292
+ use_rel_pos: bool = False,
293
+ context_features_multiplier: int = 1,
294
+ rel_pos_num_buckets: Optional[int] = None,
295
+ rel_pos_max_distance: Optional[int] = None,
296
+ context_features: Optional[int] = None,
297
+ context_embedding_features: Optional[int] = None,
298
+ embedding_max_length: int = 512,
299
+ ):
300
+ super().__init__()
301
+
302
+ self.blocks = nn.ModuleList(
303
+ [
304
+ TransformerBlock(
305
+ features=channels + context_embedding_features,
306
+ head_features=head_features,
307
+ num_heads=num_heads,
308
+ multiplier=multiplier,
309
+ use_rel_pos=use_rel_pos,
310
+ rel_pos_num_buckets=rel_pos_num_buckets,
311
+ rel_pos_max_distance=rel_pos_max_distance,
312
+ )
313
+ for i in range(num_layers)
314
+ ]
315
+ )
316
+
317
+ self.to_out = nn.Sequential(
318
+ Rearrange("b t c -> b c t"),
319
+ nn.Conv1d(
320
+ in_channels=channels + context_embedding_features,
321
+ out_channels=channels,
322
+ kernel_size=1,
323
+ ),
324
+ )
325
+
326
+ use_context_features = exists(context_features)
327
+ self.use_context_features = use_context_features
328
+ self.use_context_time = use_context_time
329
+
330
+ if use_context_time or use_context_features:
331
+ context_mapping_features = channels + context_embedding_features
332
+
333
+ self.to_mapping = nn.Sequential(
334
+ nn.Linear(context_mapping_features, context_mapping_features),
335
+ nn.GELU(),
336
+ nn.Linear(context_mapping_features, context_mapping_features),
337
+ nn.GELU(),
338
+ )
339
+
340
+ if use_context_time:
341
+ assert exists(context_mapping_features)
342
+ self.to_time = nn.Sequential(
343
+ TimePositionalEmbedding(
344
+ dim=channels, out_features=context_mapping_features
345
+ ),
346
+ nn.GELU(),
347
+ )
348
+
349
+ if use_context_features:
350
+ assert exists(context_features) and exists(context_mapping_features)
351
+ self.to_features = nn.Sequential(
352
+ nn.Linear(
353
+ in_features=context_features, out_features=context_mapping_features
354
+ ),
355
+ nn.GELU(),
356
+ )
357
+
358
+ self.fixed_embedding = FixedEmbedding(
359
+ max_length=embedding_max_length, features=context_embedding_features
360
+ )
361
+
362
+
363
+ def get_mapping(
364
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
365
+ ) -> Optional[Tensor]:
366
+ """Combines context time features and features into mapping"""
367
+ items, mapping = [], None
368
+ # Compute time features
369
+ if self.use_context_time:
370
+ assert_message = "use_context_time=True but no time features provided"
371
+ assert exists(time), assert_message
372
+ items += [self.to_time(time)]
373
+ # Compute features
374
+ if self.use_context_features:
375
+ assert_message = "context_features exists but no features provided"
376
+ assert exists(features), assert_message
377
+ items += [self.to_features(features)]
378
+
379
+ # Compute joint mapping
380
+ if self.use_context_time or self.use_context_features:
381
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
382
+ mapping = self.to_mapping(mapping)
383
+
384
+ return mapping
385
+
386
+ def run(self, x, time, embedding, features):
387
+
388
+ mapping = self.get_mapping(time, features)
389
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
390
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
391
+
392
+ for block in self.blocks:
393
+ x = x + mapping
394
+ x = block(x)
395
+
396
+ x = x.mean(axis=1).unsqueeze(1)
397
+ x = self.to_out(x)
398
+ x = x.transpose(-1, -2)
399
+
400
+ return x
401
+
402
+ def forward(self, x: Tensor,
403
+ time: Tensor,
404
+ embedding_mask_proba: float = 0.0,
405
+ embedding: Optional[Tensor] = None,
406
+ features: Optional[Tensor] = None,
407
+ embedding_scale: float = 1.0) -> Tensor:
408
+
409
+ b, device = embedding.shape[0], embedding.device
410
+ fixed_embedding = self.fixed_embedding(embedding)
411
+ if embedding_mask_proba > 0.0:
412
+ # Randomly mask embedding
413
+ batch_mask = rand_bool(
414
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
415
+ )
416
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
417
+
418
+ if embedding_scale != 1.0:
419
+ # Compute both normal and fixed embedding outputs
420
+ out = self.run(x, time, embedding=embedding, features=features)
421
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
422
+ # Scale conditional output using classifier-free guidance
423
+ return out_masked + (out - out_masked) * embedding_scale
424
+ else:
425
+ return self.run(x, time, embedding=embedding, features=features)
426
+
427
+ return x
428
+
429
+
430
+ """
431
+ Attention Components
432
+ """
433
+
434
+
435
+ class RelativePositionBias(nn.Module):
436
+ def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
437
+ super().__init__()
438
+ self.num_buckets = num_buckets
439
+ self.max_distance = max_distance
440
+ self.num_heads = num_heads
441
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
442
+
443
+ @staticmethod
444
+ def _relative_position_bucket(
445
+ relative_position: Tensor, num_buckets: int, max_distance: int
446
+ ):
447
+ num_buckets //= 2
448
+ ret = (relative_position >= 0).to(torch.long) * num_buckets
449
+ n = torch.abs(relative_position)
450
+
451
+ max_exact = num_buckets // 2
452
+ is_small = n < max_exact
453
+
454
+ val_if_large = (
455
+ max_exact
456
+ + (
457
+ torch.log(n.float() / max_exact)
458
+ / log(max_distance / max_exact)
459
+ * (num_buckets - max_exact)
460
+ ).long()
461
+ )
462
+ val_if_large = torch.min(
463
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
464
+ )
465
+
466
+ ret += torch.where(is_small, n, val_if_large)
467
+ return ret
468
+
469
+ def forward(self, num_queries: int, num_keys: int) -> Tensor:
470
+ i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
471
+ q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
472
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
473
+ rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
474
+
475
+ relative_position_bucket = self._relative_position_bucket(
476
+ rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
477
+ )
478
+
479
+ bias = self.relative_attention_bias(relative_position_bucket)
480
+ bias = rearrange(bias, "m n h -> 1 h m n")
481
+ return bias
482
+
483
+
484
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
485
+ mid_features = features * multiplier
486
+ return nn.Sequential(
487
+ nn.Linear(in_features=features, out_features=mid_features),
488
+ nn.GELU(),
489
+ nn.Linear(in_features=mid_features, out_features=features),
490
+ )
491
+
492
+
493
+ class AttentionBase(nn.Module):
494
+ def __init__(
495
+ self,
496
+ features: int,
497
+ *,
498
+ head_features: int,
499
+ num_heads: int,
500
+ use_rel_pos: bool,
501
+ out_features: Optional[int] = None,
502
+ rel_pos_num_buckets: Optional[int] = None,
503
+ rel_pos_max_distance: Optional[int] = None,
504
+ ):
505
+ super().__init__()
506
+ self.scale = head_features ** -0.5
507
+ self.num_heads = num_heads
508
+ self.use_rel_pos = use_rel_pos
509
+ mid_features = head_features * num_heads
510
+
511
+ if use_rel_pos:
512
+ assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
513
+ self.rel_pos = RelativePositionBias(
514
+ num_buckets=rel_pos_num_buckets,
515
+ max_distance=rel_pos_max_distance,
516
+ num_heads=num_heads,
517
+ )
518
+ if out_features is None:
519
+ out_features = features
520
+
521
+ self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
522
+
523
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
524
+ # Split heads
525
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
526
+ # Compute similarity matrix
527
+ sim = einsum("... n d, ... m d -> ... n m", q, k)
528
+ sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
529
+ sim = sim * self.scale
530
+ # Get attention matrix with softmax
531
+ attn = sim.softmax(dim=-1)
532
+ # Compute values
533
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
534
+ out = rearrange(out, "b h n d -> b n (h d)")
535
+ return self.to_out(out)
536
+
537
+
538
+ class Attention(nn.Module):
539
+ def __init__(
540
+ self,
541
+ features: int,
542
+ *,
543
+ head_features: int,
544
+ num_heads: int,
545
+ out_features: Optional[int] = None,
546
+ context_features: Optional[int] = None,
547
+ use_rel_pos: bool,
548
+ rel_pos_num_buckets: Optional[int] = None,
549
+ rel_pos_max_distance: Optional[int] = None,
550
+ ):
551
+ super().__init__()
552
+ self.context_features = context_features
553
+ mid_features = head_features * num_heads
554
+ context_features = default(context_features, features)
555
+
556
+ self.norm = nn.LayerNorm(features)
557
+ self.norm_context = nn.LayerNorm(context_features)
558
+ self.to_q = nn.Linear(
559
+ in_features=features, out_features=mid_features, bias=False
560
+ )
561
+ self.to_kv = nn.Linear(
562
+ in_features=context_features, out_features=mid_features * 2, bias=False
563
+ )
564
+
565
+ self.attention = AttentionBase(
566
+ features,
567
+ out_features=out_features,
568
+ num_heads=num_heads,
569
+ head_features=head_features,
570
+ use_rel_pos=use_rel_pos,
571
+ rel_pos_num_buckets=rel_pos_num_buckets,
572
+ rel_pos_max_distance=rel_pos_max_distance,
573
+ )
574
+
575
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
576
+ assert_message = "You must provide a context when using context_features"
577
+ assert not self.context_features or exists(context), assert_message
578
+ # Use context if provided
579
+ context = default(context, x)
580
+ # Normalize then compute q from input and k,v from context
581
+ x, context = self.norm(x), self.norm_context(context)
582
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
583
+ # Compute and return attention
584
+ return self.attention(q, k, v)
585
+
586
+
587
+ """
588
+ Transformer Blocks
589
+ """
590
+
591
+
592
+ class TransformerBlock(nn.Module):
593
+ def __init__(
594
+ self,
595
+ features: int,
596
+ num_heads: int,
597
+ head_features: int,
598
+ multiplier: int,
599
+ use_rel_pos: bool,
600
+ rel_pos_num_buckets: Optional[int] = None,
601
+ rel_pos_max_distance: Optional[int] = None,
602
+ context_features: Optional[int] = None,
603
+ ):
604
+ super().__init__()
605
+
606
+ self.use_cross_attention = exists(context_features) and context_features > 0
607
+
608
+ self.attention = Attention(
609
+ features=features,
610
+ num_heads=num_heads,
611
+ head_features=head_features,
612
+ use_rel_pos=use_rel_pos,
613
+ rel_pos_num_buckets=rel_pos_num_buckets,
614
+ rel_pos_max_distance=rel_pos_max_distance,
615
+ )
616
+
617
+ if self.use_cross_attention:
618
+ self.cross_attention = Attention(
619
+ features=features,
620
+ num_heads=num_heads,
621
+ head_features=head_features,
622
+ context_features=context_features,
623
+ use_rel_pos=use_rel_pos,
624
+ rel_pos_num_buckets=rel_pos_num_buckets,
625
+ rel_pos_max_distance=rel_pos_max_distance,
626
+ )
627
+
628
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
629
+
630
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
631
+ x = self.attention(x) + x
632
+ if self.use_cross_attention:
633
+ x = self.cross_attention(x, context=context) + x
634
+ x = self.feed_forward(x) + x
635
+ return x
636
+
637
+
638
+
639
+ """
640
+ Time Embeddings
641
+ """
642
+
643
+
644
+ class SinusoidalEmbedding(nn.Module):
645
+ def __init__(self, dim: int):
646
+ super().__init__()
647
+ self.dim = dim
648
+
649
+ def forward(self, x: Tensor) -> Tensor:
650
+ device, half_dim = x.device, self.dim // 2
651
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
652
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
653
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
654
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
655
+
656
+
657
+ class LearnedPositionalEmbedding(nn.Module):
658
+ """Used for continuous time"""
659
+
660
+ def __init__(self, dim: int):
661
+ super().__init__()
662
+ assert (dim % 2) == 0
663
+ half_dim = dim // 2
664
+ self.weights = nn.Parameter(torch.randn(half_dim))
665
+
666
+ def forward(self, x: Tensor) -> Tensor:
667
+ x = rearrange(x, "b -> b 1")
668
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
669
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
670
+ fouriered = torch.cat((x, fouriered), dim=-1)
671
+ return fouriered
672
+
673
+
674
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
675
+ return nn.Sequential(
676
+ LearnedPositionalEmbedding(dim),
677
+ nn.Linear(in_features=dim + 1, out_features=out_features),
678
+ )
679
+
680
+ class FixedEmbedding(nn.Module):
681
+ def __init__(self, max_length: int, features: int):
682
+ super().__init__()
683
+ self.max_length = max_length
684
+ self.embedding = nn.Embedding(max_length, features)
685
+
686
+ def forward(self, x: Tensor) -> Tensor:
687
+ batch_size, length, device = *x.shape[0:2], x.device
688
+ assert_message = "Input sequence length must be <= max_length"
689
+ assert length <= self.max_length, assert_message
690
+ position = torch.arange(length, device=device)
691
+ fixed_embedding = self.embedding(position)
692
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
693
+ return fixed_embedding
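Note: StyleTransformer1d.forward and Transformer1d.forward above implement classifier-free guidance: when embedding_scale != 1.0 the network is run twice, once with the real embedding and once with the learned fixed (unconditional) embedding, and the two outputs are blended as out_masked + (out - out_masked) * embedding_scale. A small numeric sketch of that blend (values are made up for illustration):

```python
import torch

out = torch.tensor([2.0])         # conditional output
out_masked = torch.tensor([1.0])  # output with the fixed (unconditional) embedding

for scale in (1.0, 5.0):
    guided = out_masked + (out - out_masked) * scale
    print(scale, guided.item())   # scale 1.0 -> 2.0 (pure conditional), 5.0 -> 6.0
```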
Modules/diffusion/sampler.py ADDED
@@ -0,0 +1,691 @@
1
+ from math import atan, cos, pi, sin, sqrt
2
+ from typing import Any, Callable, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, reduce
8
+ from torch import Tensor
9
+
10
+ from .utils import *
11
+
12
+ """
13
+ Diffusion Training
14
+ """
15
+
16
+ """ Distributions """
17
+
18
+
19
+ class Distribution:
20
+ def __call__(self, num_samples: int, device: torch.device):
21
+ raise NotImplementedError()
22
+
23
+
24
+ class LogNormalDistribution(Distribution):
25
+ def __init__(self, mean: float, std: float):
26
+ self.mean = mean
27
+ self.std = std
28
+
29
+ def __call__(
30
+ self, num_samples: int, device: torch.device = torch.device("cpu")
31
+ ) -> Tensor:
32
+ normal = self.mean + self.std * torch.randn((num_samples,), device=device)
33
+ return normal.exp()
34
+
35
+
36
+ class UniformDistribution(Distribution):
37
+ def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
38
+ return torch.rand(num_samples, device=device)
39
+
40
+
41
+ class VKDistribution(Distribution):
42
+ def __init__(
43
+ self,
44
+ min_value: float = 0.0,
45
+ max_value: float = float("inf"),
46
+ sigma_data: float = 1.0,
47
+ ):
48
+ self.min_value = min_value
49
+ self.max_value = max_value
50
+ self.sigma_data = sigma_data
51
+
52
+ def __call__(
53
+ self, num_samples: int, device: torch.device = torch.device("cpu")
54
+ ) -> Tensor:
55
+ sigma_data = self.sigma_data
56
+ min_cdf = atan(self.min_value / sigma_data) * 2 / pi
57
+ max_cdf = atan(self.max_value / sigma_data) * 2 / pi
58
+ u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
59
+ return torch.tan(u * pi / 2) * sigma_data
60
+
61
+
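A quick illustration of how these noise-level distributions are drawn; the mean/std values are placeholders, not training settings:

import torch

sigma_dist = LogNormalDistribution(mean=-3.0, std=1.0)
sigmas = sigma_dist(num_samples=16)         # (16,) strictly positive noise levels
t = UniformDistribution()(num_samples=16)   # (16,) values in [0, 1)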
62
+ """ Diffusion Classes """
63
+
64
+
65
+ def pad_dims(x: Tensor, ndim: int) -> Tensor:
66
+ # Pads additional ndims to the right of the tensor
67
+ return x.view(*x.shape, *((1,) * ndim))
68
+
69
+
70
+ def clip(x: Tensor, dynamic_threshold: float = 0.0):
71
+ if dynamic_threshold == 0.0:
72
+ return x.clamp(-1.0, 1.0)
73
+ else:
74
+ # Dynamic thresholding
75
+ # Find dynamic threshold quantile for each batch
76
+ x_flat = rearrange(x, "b ... -> b (...)")
77
+ scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
78
+ # Clamp to a min of 1.0
79
+ scale.clamp_(min=1.0)
80
+ # Clamp all values and scale
81
+ scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
82
+ x = x.clamp(-scale, scale) / scale
83
+ return x
84
+
85
+
86
+ def to_batch(
87
+ batch_size: int,
88
+ device: torch.device,
89
+ x: Optional[float] = None,
90
+ xs: Optional[Tensor] = None,
91
+ ) -> Tensor:
92
+ assert exists(x) ^ exists(xs), "Either x or xs must be provided"
93
+ # If x provided use the same for all batch items
94
+ if exists(x):
95
+ xs = torch.full(size=(batch_size,), fill_value=x).to(device)
96
+ assert exists(xs)
97
+ return xs
98
+
99
+
100
+ class Diffusion(nn.Module):
101
+
102
+ alias: str = ""
103
+
104
+ """Base diffusion class"""
105
+
106
+ def denoise_fn(
107
+ self,
108
+ x_noisy: Tensor,
109
+ sigmas: Optional[Tensor] = None,
110
+ sigma: Optional[float] = None,
111
+ **kwargs,
112
+ ) -> Tensor:
113
+ raise NotImplementedError("Diffusion class missing denoise_fn")
114
+
115
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
116
+ raise NotImplementedError("Diffusion class missing forward function")
117
+
118
+
119
+ class VDiffusion(Diffusion):
120
+
121
+ alias = "v"
122
+
123
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
124
+ super().__init__()
125
+ self.net = net
126
+ self.sigma_distribution = sigma_distribution
127
+
128
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
129
+ angle = sigmas * pi / 2
130
+ alpha = torch.cos(angle)
131
+ beta = torch.sin(angle)
132
+ return alpha, beta
133
+
134
+ def denoise_fn(
135
+ self,
136
+ x_noisy: Tensor,
137
+ sigmas: Optional[Tensor] = None,
138
+ sigma: Optional[float] = None,
139
+ **kwargs,
140
+ ) -> Tensor:
141
+ batch_size, device = x_noisy.shape[0], x_noisy.device
142
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
143
+ return self.net(x_noisy, sigmas, **kwargs)
144
+
145
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
146
+ batch_size, device = x.shape[0], x.device
147
+
148
+ # Sample amount of noise to add for each batch element
149
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
150
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
151
+
152
+ # Get noise
153
+ noise = default(noise, lambda: torch.randn_like(x))
154
+
155
+ # Combine input and noise weighted by half-circle
156
+ alpha, beta = self.get_alpha_beta(sigmas_padded)
157
+ x_noisy = x * alpha + noise * beta
158
+ x_target = noise * alpha - x * beta
159
+
160
+ # Denoise and return loss
161
+ x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
162
+ return F.mse_loss(x_denoised, x_target)
163
+
164
+
165
+ class KDiffusion(Diffusion):
166
+ """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
167
+
168
+ alias = "k"
169
+
170
+ def __init__(
171
+ self,
172
+ net: nn.Module,
173
+ *,
174
+ sigma_distribution: Distribution,
175
+ sigma_data: float, # data distribution standard deviation
176
+ dynamic_threshold: float = 0.0,
177
+ ):
178
+ super().__init__()
179
+ self.net = net
180
+ self.sigma_data = sigma_data
181
+ self.sigma_distribution = sigma_distribution
182
+ self.dynamic_threshold = dynamic_threshold
183
+
184
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
185
+ sigma_data = self.sigma_data
186
+ c_noise = torch.log(sigmas) * 0.25
187
+ sigmas = rearrange(sigmas, "b -> b 1 1")
188
+ c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
189
+ c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
190
+ c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
191
+ return c_skip, c_out, c_in, c_noise
192
+
193
+ def denoise_fn(
194
+ self,
195
+ x_noisy: Tensor,
196
+ sigmas: Optional[Tensor] = None,
197
+ sigma: Optional[float] = None,
198
+ **kwargs,
199
+ ) -> Tensor:
200
+ batch_size, device = x_noisy.shape[0], x_noisy.device
201
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
202
+
203
+ # Predict network output and add skip connection
204
+ c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
205
+ x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
206
+ x_denoised = c_skip * x_noisy + c_out * x_pred
207
+
208
+ return x_denoised
209
+
210
+ def loss_weight(self, sigmas: Tensor) -> Tensor:
211
+ # Computes weight depending on data distribution
212
+ return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
213
+
214
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
215
+ batch_size, device = x.shape[0], x.device
217
+
218
+ # Sample amount of noise to add for each batch element
219
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
220
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
221
+
222
+ # Add noise to input
223
+ noise = default(noise, lambda: torch.randn_like(x))
224
+ x_noisy = x + sigmas_padded * noise
225
+
226
+ # Compute denoised values
227
+ x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
228
+
229
+ # Compute weighted loss
230
+ losses = F.mse_loss(x_denoised, x, reduction="none")
231
+ losses = reduce(losses, "b ... -> b", "mean")
232
+ losses = losses * self.loss_weight(sigmas)
233
+ loss = losses.mean()
234
+ return loss
235
+
236
+
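A hedged sketch of the KDiffusion training path with a toy denoiser; ToyNet and every size below are invented for illustration, and only the net(x, sigmas, **kwargs) call signature matters. It assumes this file is importable as Modules.diffusion.sampler.

import torch
import torch.nn as nn
from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution

class ToyNet(nn.Module):
    def __init__(self, channels=4):
        super().__init__()
        self.proj = nn.Conv1d(channels, channels, 3, padding=1)

    def forward(self, x, sigmas, **kwargs):
        # sigmas is the (batch,) c_noise conditioning; this toy net simply ignores it
        return self.proj(x)

diffusion = KDiffusion(
    net=ToyNet(),
    sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
    sigma_data=0.2,
)
x = torch.randn(8, 4, 32)   # (batch, channels, length)
loss = diffusion(x)         # scalar, sigma-weighted denoising MSE
loss.backward()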
237
+ class VKDiffusion(Diffusion):
238
+
239
+ alias = "vk"
240
+
241
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
242
+ super().__init__()
243
+ self.net = net
244
+ self.sigma_distribution = sigma_distribution
245
+
246
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
247
+ sigma_data = 1.0
248
+ sigmas = rearrange(sigmas, "b -> b 1 1")
249
+ c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
250
+ c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
251
+ c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
252
+ return c_skip, c_out, c_in
253
+
254
+ def sigma_to_t(self, sigmas: Tensor) -> Tensor:
255
+ return sigmas.atan() / pi * 2
256
+
257
+ def t_to_sigma(self, t: Tensor) -> Tensor:
258
+ return (t * pi / 2).tan()
259
+
260
+ def denoise_fn(
261
+ self,
262
+ x_noisy: Tensor,
263
+ sigmas: Optional[Tensor] = None,
264
+ sigma: Optional[float] = None,
265
+ **kwargs,
266
+ ) -> Tensor:
267
+ batch_size, device = x_noisy.shape[0], x_noisy.device
268
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
269
+
270
+ # Predict network output and add skip connection
271
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
272
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
273
+ x_denoised = c_skip * x_noisy + c_out * x_pred
274
+ return x_denoised
275
+
276
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
277
+ batch_size, device = x.shape[0], x.device
278
+
279
+ # Sample amount of noise to add for each batch element
280
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
281
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
282
+
283
+ # Add noise to input
284
+ noise = default(noise, lambda: torch.randn_like(x))
285
+ x_noisy = x + sigmas_padded * noise
286
+
287
+ # Compute model output
288
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
289
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
290
+
291
+ # Compute v-objective target
292
+ v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
293
+
294
+ # Compute loss
295
+ loss = F.mse_loss(x_pred, v_target)
296
+ return loss
297
+
298
+
299
+ """
300
+ Diffusion Sampling
301
+ """
302
+
303
+ """ Schedules """
304
+
305
+
306
+ class Schedule(nn.Module):
307
+ """Interface used by different sampling schedules"""
308
+
309
+ def forward(self, num_steps: int, device: torch.device) -> Tensor:
310
+ raise NotImplementedError()
311
+
312
+
313
+ class LinearSchedule(Schedule):
314
+ def forward(self, num_steps: int, device: Any) -> Tensor:
315
+ sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
316
+ return sigmas
317
+
318
+
319
+ class KarrasSchedule(Schedule):
320
+ """https://arxiv.org/abs/2206.00364 equation 5"""
321
+
322
+ def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
323
+ super().__init__()
324
+ self.sigma_min = sigma_min
325
+ self.sigma_max = sigma_max
326
+ self.rho = rho
327
+
328
+ def forward(self, num_steps: int, device: Any) -> Tensor:
329
+ rho_inv = 1.0 / self.rho
330
+ steps = torch.arange(num_steps, device=device, dtype=torch.float32)
331
+ sigmas = (
332
+ self.sigma_max ** rho_inv
333
+ + (steps / (num_steps - 1))
334
+ * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
335
+ ) ** self.rho
336
+ sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
337
+ return sigmas
338
+
339
+
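For intuition, KarrasSchedule returns num_steps decreasing noise levels followed by an appended 0.0; the sigma_min, sigma_max and rho values below are illustrative:

import torch

schedule = KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0)
sigmas = schedule(num_steps=5, device=torch.device("cpu"))
print(sigmas.shape)   # torch.Size([6]): five sigmas from sigma_max down to sigma_min, then 0.0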
340
+ """ Samplers """
341
+
342
+
343
+ class Sampler(nn.Module):
344
+
345
+ diffusion_types: List[Type[Diffusion]] = []
346
+
347
+ def forward(
348
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
349
+ ) -> Tensor:
350
+ raise NotImplementedError()
351
+
352
+ def inpaint(
353
+ self,
354
+ source: Tensor,
355
+ mask: Tensor,
356
+ fn: Callable,
357
+ sigmas: Tensor,
358
+ num_steps: int,
359
+ num_resamples: int,
360
+ ) -> Tensor:
361
+ raise NotImplementedError("Inpainting not available with current sampler")
362
+
363
+
364
+ class VSampler(Sampler):
365
+
366
+ diffusion_types = [VDiffusion]
367
+
368
+ def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
369
+ angle = sigma * pi / 2
370
+ alpha = cos(angle)
371
+ beta = sin(angle)
372
+ return alpha, beta
373
+
374
+ def forward(
375
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
376
+ ) -> Tensor:
377
+ x = sigmas[0] * noise
378
+ alpha, beta = self.get_alpha_beta(sigmas[0].item())
379
+
380
+ for i in range(num_steps - 1):
381
+ is_last = i == num_steps - 1
382
+
383
+ x_denoised = fn(x, sigma=sigmas[i])
384
+ x_pred = x * alpha - x_denoised * beta
385
+ x_eps = x * beta + x_denoised * alpha
386
+
387
+ if not is_last:
388
+ alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
389
+ x = x_pred * alpha + x_eps * beta
390
+
391
+ return x_pred
392
+
393
+
394
+ class KarrasSampler(Sampler):
395
+ """https://arxiv.org/abs/2206.00364 algorithm 1"""
396
+
397
+ diffusion_types = [KDiffusion, VKDiffusion]
398
+
399
+ def __init__(
400
+ self,
401
+ s_tmin: float = 0,
402
+ s_tmax: float = float("inf"),
403
+ s_churn: float = 0.0,
404
+ s_noise: float = 1.0,
405
+ ):
406
+ super().__init__()
407
+ self.s_tmin = s_tmin
408
+ self.s_tmax = s_tmax
409
+ self.s_noise = s_noise
410
+ self.s_churn = s_churn
411
+
412
+ def step(
413
+ self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
414
+ ) -> Tensor:
415
+ """Algorithm 2 (step)"""
416
+ # Select temporarily increased noise level
417
+ sigma_hat = sigma + gamma * sigma
418
+ # Add noise to move from sigma to sigma_hat
419
+ epsilon = self.s_noise * torch.randn_like(x)
420
+ x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
421
+ # Evaluate ∂x/∂sigma at sigma_hat
422
+ d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
423
+ # Take euler step from sigma_hat to sigma_next
424
+ x_next = x_hat + (sigma_next - sigma_hat) * d
425
+ # Second order correction
426
+ if sigma_next != 0:
427
+ model_out_next = fn(x_next, sigma=sigma_next)
428
+ d_prime = (x_next - model_out_next) / sigma_next
429
+ x_next = x_hat + 0.5 * (sigma_next - sigma_hat) * (d + d_prime)
430
+ return x_next
431
+
432
+ def forward(
433
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
434
+ ) -> Tensor:
435
+ x = sigmas[0] * noise
436
+ # Compute gammas
437
+ gammas = torch.where(
438
+ (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
439
+ min(self.s_churn / num_steps, sqrt(2) - 1),
440
+ 0.0,
441
+ )
442
+ # Denoise to sample
443
+ for i in range(num_steps - 1):
444
+ x = self.step(
445
+ x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
446
+ )
447
+
448
+ return x
449
+
450
+
451
+ class AEulerSampler(Sampler):
452
+
453
+ diffusion_types = [KDiffusion, VKDiffusion]
454
+
455
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
456
+ sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
457
+ sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
458
+ return sigma_up, sigma_down
459
+
460
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
461
+ # Sigma steps
462
+ sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
463
+ # Derivative at sigma (∂x/∂sigma)
464
+ d = (x - fn(x, sigma=sigma)) / sigma
465
+ # Euler method
466
+ x_next = x + d * (sigma_down - sigma)
467
+ # Add randomness
468
+ x_next = x_next + torch.randn_like(x) * sigma_up
469
+ return x_next
470
+
471
+ def forward(
472
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
473
+ ) -> Tensor:
474
+ x = sigmas[0] * noise
475
+ # Denoise to sample
476
+ for i in range(num_steps - 1):
477
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
478
+ return x
479
+
480
+
481
+ class ADPM2Sampler(Sampler):
482
+ """https://www.desmos.com/calculator/jbxjlqd9mb"""
483
+
484
+ diffusion_types = [KDiffusion, VKDiffusion]
485
+
486
+ def __init__(self, rho: float = 1.0):
487
+ super().__init__()
488
+ self.rho = rho
489
+
490
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
491
+ r = self.rho
492
+ sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
493
+ sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
494
+ sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
495
+ return sigma_up, sigma_down, sigma_mid
496
+
497
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
498
+ # Sigma steps
499
+ sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
500
+ # Derivative at sigma (∂x/∂sigma)
501
+ d = (x - fn(x, sigma=sigma)) / sigma
502
+ # Denoise to midpoint
503
+ x_mid = x + d * (sigma_mid - sigma)
504
+ # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
505
+ d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
506
+ # Denoise to next
507
+ x = x + d_mid * (sigma_down - sigma)
508
+ # Add randomness
509
+ x_next = x + torch.randn_like(x) * sigma_up
510
+ return x_next
511
+
512
+ def forward(
513
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
514
+ ) -> Tensor:
515
+ x = sigmas[0] * noise
516
+ # Denoise to sample
517
+ for i in range(num_steps - 1):
518
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
519
+ return x
520
+
521
+ def inpaint(
522
+ self,
523
+ source: Tensor,
524
+ mask: Tensor,
525
+ fn: Callable,
526
+ sigmas: Tensor,
527
+ num_steps: int,
528
+ num_resamples: int,
529
+ ) -> Tensor:
530
+ x = sigmas[0] * torch.randn_like(source)
531
+
532
+ for i in range(num_steps - 1):
533
+ # Noise source to current noise level
534
+ source_noisy = source + sigmas[i] * torch.randn_like(source)
535
+ for r in range(num_resamples):
536
+ # Merge noisy source and current then denoise
537
+ x = source_noisy * mask + x * ~mask
538
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
539
+ # Renoise if not last resample step
540
+ if r < num_resamples - 1:
541
+ sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
542
+ x = x + sigma * torch.randn_like(x)
543
+
544
+ return source * mask + x * ~mask
545
+
546
+
547
+ """ Main Classes """
548
+
549
+
550
+ class DiffusionSampler(nn.Module):
551
+ def __init__(
552
+ self,
553
+ diffusion: Diffusion,
554
+ *,
555
+ sampler: Sampler,
556
+ sigma_schedule: Schedule,
557
+ num_steps: Optional[int] = None,
558
+ clamp: bool = True,
559
+ ):
560
+ super().__init__()
561
+ self.denoise_fn = diffusion.denoise_fn
562
+ self.sampler = sampler
563
+ self.sigma_schedule = sigma_schedule
564
+ self.num_steps = num_steps
565
+ self.clamp = clamp
566
+
567
+ # Check sampler is compatible with diffusion type
568
+ sampler_class = sampler.__class__.__name__
569
+ diffusion_class = diffusion.__class__.__name__
570
+ message = f"{sampler_class} incompatible with {diffusion_class}"
571
+ assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
572
+
573
+ def forward(
574
+ self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
575
+ ) -> Tensor:
576
+ device = noise.device
577
+ num_steps = default(num_steps, self.num_steps) # type: ignore
578
+ assert exists(num_steps), "Parameter `num_steps` must be provided"
579
+ # Compute sigmas using schedule
580
+ sigmas = self.sigma_schedule(num_steps, device)
581
+ # Append additional kwargs to denoise function (used e.g. for conditional unet)
582
+ fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
583
+ # Sample using sampler
584
+ x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
585
+ x = x.clamp(-1.0, 1.0) if self.clamp else x
586
+ return x
587
+
588
+
589
+ class DiffusionInpainter(nn.Module):
590
+ def __init__(
591
+ self,
592
+ diffusion: Diffusion,
593
+ *,
594
+ num_steps: int,
595
+ num_resamples: int,
596
+ sampler: Sampler,
597
+ sigma_schedule: Schedule,
598
+ ):
599
+ super().__init__()
600
+ self.denoise_fn = diffusion.denoise_fn
601
+ self.num_steps = num_steps
602
+ self.num_resamples = num_resamples
603
+ self.inpaint_fn = sampler.inpaint
604
+ self.sigma_schedule = sigma_schedule
605
+
606
+ @torch.no_grad()
607
+ def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
608
+ x = self.inpaint_fn(
609
+ source=inpaint,
610
+ mask=inpaint_mask,
611
+ fn=self.denoise_fn,
612
+ sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
613
+ num_steps=self.num_steps,
614
+ num_resamples=self.num_resamples,
615
+ )
616
+ return x
617
+
618
+
619
+ def sequential_mask(like: Tensor, start: int) -> Tensor:
620
+ length, device = like.shape[2], like.device
621
+ mask = torch.ones_like(like, dtype=torch.bool)
622
+ mask[:, :, start:] = torch.zeros((length - start,), device=device)
623
+ return mask
624
+
625
+
626
+ class SpanBySpanComposer(nn.Module):
627
+ def __init__(
628
+ self,
629
+ inpainter: DiffusionInpainter,
630
+ *,
631
+ num_spans: int,
632
+ ):
633
+ super().__init__()
634
+ self.inpainter = inpainter
635
+ self.num_spans = num_spans
636
+
637
+ def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
638
+ half_length = start.shape[2] // 2
639
+
640
+ spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
641
+ # Inpaint second half from first half
642
+ inpaint = torch.zeros_like(start)
643
+ inpaint[:, :, :half_length] = start[:, :, half_length:]
644
+ inpaint_mask = sequential_mask(like=start, start=half_length)
645
+
646
+ for i in range(self.num_spans):
647
+ # Inpaint second half
648
+ span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
649
+ # Replace first half with generated second half
650
+ second_half = span[:, :, half_length:]
651
+ inpaint[:, :, :half_length] = second_half
652
+ # Save generated span
653
+ spans.append(second_half)
654
+
655
+ return torch.cat(spans, dim=2)
656
+
657
+
658
+ class XDiffusion(nn.Module):
659
+ def __init__(self, type: str, net: nn.Module, **kwargs):
660
+ super().__init__()
661
+
662
+ diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
663
+ aliases = [t.alias for t in diffusion_classes] # type: ignore
664
+ message = f"type='{type}' must be one of {*aliases,}"
665
+ assert type in aliases, message
666
+ self.net = net
667
+
668
+ for XDiffusion in diffusion_classes:
669
+ if XDiffusion.alias == type: # type: ignore
670
+ self.diffusion = XDiffusion(net=net, **kwargs)
671
+
672
+ def forward(self, *args, **kwargs) -> Tensor:
673
+ return self.diffusion(*args, **kwargs)
674
+
675
+ def sample(
676
+ self,
677
+ noise: Tensor,
678
+ num_steps: int,
679
+ sigma_schedule: Schedule,
680
+ sampler: Sampler,
681
+ clamp: bool,
682
+ **kwargs,
683
+ ) -> Tensor:
684
+ diffusion_sampler = DiffusionSampler(
685
+ diffusion=self.diffusion,
686
+ sampler=sampler,
687
+ sigma_schedule=sigma_schedule,
688
+ num_steps=num_steps,
689
+ clamp=clamp,
690
+ )
691
+ return diffusion_sampler(noise, **kwargs)
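An end-to-end sketch tying the pieces of this file together: a KDiffusion wrapper around a toy network, sampled with a Karras schedule and the ADPM2 sampler. Everything here (the network, channel counts, schedule values, step count) is a made-up example of the API, not the repository's configuration.

import torch
import torch.nn as nn
from Modules.diffusion.sampler import (
    ADPM2Sampler, DiffusionSampler, KarrasSchedule,
    KDiffusion, LogNormalDistribution,
)

class ToyNet(nn.Module):
    def __init__(self, channels=4):
        super().__init__()
        self.proj = nn.Conv1d(channels, channels, 3, padding=1)

    def forward(self, x, sigmas, **kwargs):
        return self.proj(x)

diffusion = KDiffusion(
    net=ToyNet(),
    sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
    sigma_data=0.2,
)
sampler = DiffusionSampler(
    diffusion,
    sampler=ADPM2Sampler(rho=1.0),
    sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
    clamp=False,
)
noise = torch.randn(1, 4, 32)
samples = sampler(noise, num_steps=10)   # (1, 4, 32), denoised from pure noise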
Modules/diffusion/utils.py ADDED
@@ -0,0 +1,82 @@
1
+ from functools import reduce
2
+ from inspect import isfunction
3
+ from math import ceil, floor, log2, pi
4
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Generator, Tensor
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
+ def iff(condition: bool, value: T) -> Optional[T]:
20
+ return value if condition else None
21
+
22
+
23
+ def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
+ return isinstance(obj, list) or isinstance(obj, tuple)
25
+
26
+
27
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
28
+ if exists(val):
29
+ return val
30
+ return d() if isfunction(d) else d
31
+
32
+
33
+ def to_list(val: Union[T, Sequence[T]]) -> List[T]:
34
+ if isinstance(val, tuple):
35
+ return list(val)
36
+ if isinstance(val, list):
37
+ return val
38
+ return [val] # type: ignore
39
+
40
+
41
+ def prod(vals: Sequence[int]) -> int:
42
+ return reduce(lambda x, y: x * y, vals)
43
+
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+ def rand_bool(shape, proba, device = None):
52
+ if proba == 1:
53
+ return torch.ones(shape, device=device, dtype=torch.bool)
54
+ elif proba == 0:
55
+ return torch.zeros(shape, device=device, dtype=torch.bool)
56
+ else:
57
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
58
+
59
+
60
+ """
61
+ Kwargs Utils
62
+ """
63
+
64
+
65
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
66
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
67
+ for key in d.keys():
68
+ no_prefix = int(not key.startswith(prefix))
69
+ return_dicts[no_prefix][key] = d[key]
70
+ return return_dicts
71
+
72
+
73
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
74
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
75
+ if keep_prefix:
76
+ return kwargs_with_prefix, kwargs
77
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
78
+ return kwargs_no_prefix, kwargs
79
+
80
+
81
+ def prefix_dict(prefix: str, d: Dict) -> Dict:
82
+ return {prefix + str(k): v for k, v in d.items()}
Modules/discriminators.py ADDED
@@ -0,0 +1,188 @@
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, spectral_norm
6
+
7
+ from .utils import get_padding
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+ def stft(x, fft_size, hop_size, win_length, window):
12
+ """Perform STFT and convert to magnitude spectrogram.
13
+ Args:
14
+ x (Tensor): Input signal tensor (B, T).
15
+ fft_size (int): FFT size.
16
+ hop_size (int): Hop size.
17
+ win_length (int): Window length.
18
+ window (str): Window function type.
19
+ Returns:
20
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
21
+ """
22
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
23
+ return_complex=True)
24
+ real = x_stft.real
25
+ imag = x_stft.imag
26
+
27
+ return torch.abs(x_stft).transpose(2, 1)
28
+
29
+ class SpecDiscriminator(nn.Module):
30
+ """docstring for Discriminator."""
31
+
32
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
33
+ super(SpecDiscriminator, self).__init__()
34
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
35
+ self.fft_size = fft_size
36
+ self.shift_size = shift_size
37
+ self.win_length = win_length
38
+ self.window = getattr(torch, window)(win_length)
39
+ self.discriminators = nn.ModuleList([
40
+ norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
41
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
42
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
43
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
44
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
45
+ ])
46
+
47
+ self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
48
+
49
+ def forward(self, y):
50
+
51
+ fmap = []
52
+ y = y.squeeze(1)
53
+ y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
54
+ y = y.unsqueeze(1)
55
+ for i, d in enumerate(self.discriminators):
56
+ y = d(y)
57
+ y = F.leaky_relu(y, LRELU_SLOPE)
58
+ fmap.append(y)
59
+
60
+ y = self.out(y)
61
+ fmap.append(y)
62
+
63
+ return torch.flatten(y, 1, -1), fmap
64
+
65
+ class MultiResSpecDiscriminator(torch.nn.Module):
66
+
67
+ def __init__(self,
68
+ fft_sizes=[1024, 2048, 512],
69
+ hop_sizes=[120, 240, 50],
70
+ win_lengths=[600, 1200, 240],
71
+ window="hann_window"):
72
+
73
+ super(MultiResSpecDiscriminator, self).__init__()
74
+ self.discriminators = nn.ModuleList([
75
+ SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
76
+ SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
77
+ SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
78
+ ])
79
+
80
+ def forward(self, y, y_hat):
81
+ y_d_rs = []
82
+ y_d_gs = []
83
+ fmap_rs = []
84
+ fmap_gs = []
85
+ for i, d in enumerate(self.discriminators):
86
+ y_d_r, fmap_r = d(y)
87
+ y_d_g, fmap_g = d(y_hat)
88
+ y_d_rs.append(y_d_r)
89
+ fmap_rs.append(fmap_r)
90
+ y_d_gs.append(y_d_g)
91
+ fmap_gs.append(fmap_g)
92
+
93
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
94
+
95
+
96
+ class DiscriminatorP(torch.nn.Module):
97
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
98
+ super(DiscriminatorP, self).__init__()
99
+ self.period = period
100
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
101
+ self.convs = nn.ModuleList([
102
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
103
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
104
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
105
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
106
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
107
+ ])
108
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
109
+
110
+ def forward(self, x):
111
+ fmap = []
112
+
113
+ # 1d to 2d
114
+ b, c, t = x.shape
115
+ if t % self.period != 0: # pad first
116
+ n_pad = self.period - (t % self.period)
117
+ x = F.pad(x, (0, n_pad), "reflect")
118
+ t = t + n_pad
119
+ x = x.view(b, c, t // self.period, self.period)
120
+
121
+ for l in self.convs:
122
+ x = l(x)
123
+ x = F.leaky_relu(x, LRELU_SLOPE)
124
+ fmap.append(x)
125
+ x = self.conv_post(x)
126
+ fmap.append(x)
127
+ x = torch.flatten(x, 1, -1)
128
+
129
+ return x, fmap
130
+
131
+
132
+ class MultiPeriodDiscriminator(torch.nn.Module):
133
+ def __init__(self):
134
+ super(MultiPeriodDiscriminator, self).__init__()
135
+ self.discriminators = nn.ModuleList([
136
+ DiscriminatorP(2),
137
+ DiscriminatorP(3),
138
+ DiscriminatorP(5),
139
+ DiscriminatorP(7),
140
+ DiscriminatorP(11),
141
+ ])
142
+
143
+ def forward(self, y, y_hat):
144
+ y_d_rs = []
145
+ y_d_gs = []
146
+ fmap_rs = []
147
+ fmap_gs = []
148
+ for i, d in enumerate(self.discriminators):
149
+ y_d_r, fmap_r = d(y)
150
+ y_d_g, fmap_g = d(y_hat)
151
+ y_d_rs.append(y_d_r)
152
+ fmap_rs.append(fmap_r)
153
+ y_d_gs.append(y_d_g)
154
+ fmap_gs.append(fmap_g)
155
+
156
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
157
+
158
+ class WavLMDiscriminator(nn.Module):
159
+ """docstring for Discriminator."""
160
+
161
+ def __init__(self, slm_hidden=768,
162
+ slm_layers=13,
163
+ initial_channel=64,
164
+ use_spectral_norm=False):
165
+ super(WavLMDiscriminator, self).__init__()
166
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
167
+ self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))
168
+
169
+ self.convs = nn.ModuleList([
170
+ norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
171
+ norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
172
+ norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
173
+ ])
174
+
175
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
176
+
177
+ def forward(self, x):
178
+ x = self.pre(x)
179
+
180
+ fmap = []
181
+ for l in self.convs:
182
+ x = l(x)
183
+ x = F.leaky_relu(x, LRELU_SLOPE)
184
+ fmap.append(x)
185
+ x = self.conv_post(x)
186
+ x = torch.flatten(x, 1, -1)
187
+
188
+ return x
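A shape sketch for the multi-period discriminator defined above, using random stand-in waveforms; it assumes the Modules package (including Modules/utils.py with get_padding) is importable:

import torch
from Modules.discriminators import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()
real = torch.randn(2, 1, 8192)   # (batch, 1, samples)
fake = torch.randn(2, 1, 8192)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(real, fake)
assert len(y_d_rs) == len(fmap_rs) == 5   # one logit tensor and feature-map list per period (2, 3, 5, 7, 11)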
Modules/hifigan.py ADDED
@@ -0,0 +1,477 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+ class AdaIN1d(nn.Module):
15
+ def __init__(self, style_dim, num_features):
16
+ super().__init__()
17
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
18
+ self.fc = nn.Linear(style_dim, num_features*2)
19
+
20
+ def forward(self, x, s):
21
+ h = self.fc(s)
22
+ h = h.view(h.size(0), h.size(1), 1)
23
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
24
+ return (1 + gamma) * self.norm(x) + beta
25
+
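AdaIN1d maps a style vector to a per-channel scale and shift applied after instance normalization; a minimal shape check with arbitrary sizes:

import torch
from Modules.hifigan import AdaIN1d

adain = AdaIN1d(style_dim=128, num_features=64)
x = torch.randn(4, 64, 200)   # (batch, channels, frames)
s = torch.randn(4, 128)       # style vectors
y = adain(x, s)
assert y.shape == x.shape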
26
+ class AdaINResBlock1(torch.nn.Module):
27
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
28
+ super(AdaINResBlock1, self).__init__()
29
+ self.convs1 = nn.ModuleList([
30
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
31
+ padding=get_padding(kernel_size, dilation[0]))),
32
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
33
+ padding=get_padding(kernel_size, dilation[1]))),
34
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
35
+ padding=get_padding(kernel_size, dilation[2])))
36
+ ])
37
+ self.convs1.apply(init_weights)
38
+
39
+ self.convs2 = nn.ModuleList([
40
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
41
+ padding=get_padding(kernel_size, 1))),
42
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
43
+ padding=get_padding(kernel_size, 1))),
44
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
45
+ padding=get_padding(kernel_size, 1)))
46
+ ])
47
+ self.convs2.apply(init_weights)
48
+
49
+ self.adain1 = nn.ModuleList([
50
+ AdaIN1d(style_dim, channels),
51
+ AdaIN1d(style_dim, channels),
52
+ AdaIN1d(style_dim, channels),
53
+ ])
54
+
55
+ self.adain2 = nn.ModuleList([
56
+ AdaIN1d(style_dim, channels),
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ ])
60
+
61
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
62
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
63
+
64
+
65
+ def forward(self, x, s):
66
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
67
+ xt = n1(x, s)
68
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
69
+ xt = c1(xt)
70
+ xt = n2(xt, s)
71
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
72
+ xt = c2(xt)
73
+ x = xt + x
74
+ return x
75
+
76
+ def remove_weight_norm(self):
77
+ for l in self.convs1:
78
+ remove_weight_norm(l)
79
+ for l in self.convs2:
80
+ remove_weight_norm(l)
81
+
82
+ class SineGen(torch.nn.Module):
83
+ """ Definition of sine generator
84
+ SineGen(samp_rate, harmonic_num = 0,
85
+ sine_amp = 0.1, noise_std = 0.003,
86
+ voiced_threshold = 0,
87
+ flag_for_pulse=False)
88
+ samp_rate: sampling rate in Hz
89
+ harmonic_num: number of harmonic overtones (default 0)
90
+ sine_amp: amplitude of sine waveform (default 0.1)
91
+ noise_std: std of Gaussian noise (default 0.003)
92
+ voiced_threshold: F0 threshold for U/V classification (default 0)
93
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
94
+ Note: when flag_for_pulse is True, the first time step of a voiced
95
+ segment is always sin(np.pi) or cos(0)
96
+ """
97
+
98
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
99
+ sine_amp=0.1, noise_std=0.003,
100
+ voiced_threshold=0,
101
+ flag_for_pulse=False):
102
+ super(SineGen, self).__init__()
103
+ self.sine_amp = sine_amp
104
+ self.noise_std = noise_std
105
+ self.harmonic_num = harmonic_num
106
+ self.dim = self.harmonic_num + 1
107
+ self.sampling_rate = samp_rate
108
+ self.voiced_threshold = voiced_threshold
109
+ self.flag_for_pulse = flag_for_pulse
110
+ self.upsample_scale = upsample_scale
111
+
112
+ def _f02uv(self, f0):
113
+ # generate uv signal
114
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
115
+ return uv
116
+
117
+ def _f02sine(self, f0_values):
118
+ """ f0_values: (batchsize, length, dim)
119
+ where dim indicates fundamental tone and overtones
120
+ """
121
+ # convert to F0 in rad. The integer part n can be ignored
122
+ # because 2 * np.pi * n doesn't affect phase
123
+ rad_values = (f0_values / self.sampling_rate) % 1
124
+
125
+ # initial phase noise (no noise for fundamental component)
126
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
127
+ device=f0_values.device)
128
+ rand_ini[:, 0] = 0
129
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
130
+
131
+ # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
132
+ if not self.flag_for_pulse:
133
+ # # for normal case
134
+
135
+ # # To prevent torch.cumsum numerical overflow,
136
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
137
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
138
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
139
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
140
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
141
+ # cumsum_shift = torch.zeros_like(rad_values)
142
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
143
+
144
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
145
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
146
+ scale_factor=1/self.upsample_scale,
147
+ mode="linear").transpose(1, 2)
148
+
149
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
150
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
151
+ # cumsum_shift = torch.zeros_like(rad_values)
152
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
153
+
154
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
155
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
156
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
157
+ sines = torch.sin(phase)
158
+
159
+ else:
160
+ # If necessary, make sure that the first time step of every
161
+ # voiced segments is sin(pi) or cos(0)
162
+ # This is used for pulse-train generation
163
+
164
+ # identify the last time step in unvoiced segments
165
+ uv = self._f02uv(f0_values)
166
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
167
+ uv_1[:, -1, :] = 1
168
+ u_loc = (uv < 1) * (uv_1 > 0)
169
+
170
+ # get the instantaneous phase
171
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
172
+ # different batch needs to be processed differently
173
+ for idx in range(f0_values.shape[0]):
174
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
175
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
176
+ # stores the accumulation of i.phase within
177
+ # each voiced segments
178
+ tmp_cumsum[idx, :, :] = 0
179
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
180
+
181
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
182
+ # within the previous voiced segment.
183
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
184
+
185
+ # get the sines
186
+ sines = torch.cos(i_phase * 2 * np.pi)
187
+ return sines
188
+
189
+ def forward(self, f0):
190
+ """ sine_tensor, uv = forward(f0)
191
+ input F0: tensor(batchsize=1, length, dim=1)
192
+ f0 for unvoiced steps should be 0
193
+ output sine_tensor: tensor(batchsize=1, length, dim)
194
+ output uv: tensor(batchsize=1, length, 1)
195
+ """
196
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
197
+ device=f0.device)
198
+ # fundamental component
199
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
200
+
201
+ # generate sine waveforms
202
+ sine_waves = self._f02sine(fn) * self.sine_amp
203
+
204
+ # generate uv signal
205
+ # uv = torch.ones(f0.shape)
206
+ # uv = uv * (f0 > self.voiced_threshold)
207
+ uv = self._f02uv(f0)
208
+
209
+ # noise: for unvoiced should be similar to sine_amp
210
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
211
+ # . for voiced regions is self.noise_std
212
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
213
+ noise = noise_amp * torch.randn_like(sine_waves)
214
+
215
+ # first: set the unvoiced part to 0 by uv
216
+ # then: additive noise
217
+ sine_waves = sine_waves * uv + noise
218
+ return sine_waves, uv, noise
219
+
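SineGen expects an F0 track that has already been stretched to waveform resolution (its length should be a multiple of upsample_scale). A toy example with a constant 220 Hz tone; the rates are illustrative:

import torch
from Modules.hifigan import SineGen

sine_gen = SineGen(samp_rate=24000, upsample_scale=300, harmonic_num=8)
f0 = torch.full((1, 3000, 1), 220.0)   # (batch, samples, 1) fundamental frequency in Hz
sines, uv, noise = sine_gen(f0)
# sines: (1, 3000, 9) fundamental plus 8 harmonics, uv: (1, 3000, 1) voiced/unvoiced mask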
220
+
221
+ class SourceModuleHnNSF(torch.nn.Module):
222
+ """ SourceModule for hn-nsf
223
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
224
+ add_noise_std=0.003, voiced_threshod=0)
225
+ sampling_rate: sampling_rate in Hz
226
+ harmonic_num: number of harmonic above F0 (default: 0)
227
+ sine_amp: amplitude of sine source signal (default: 0.1)
228
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
229
+ note that amplitude of noise in unvoiced is decided
230
+ by sine_amp
231
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
232
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
233
+ F0_sampled (batchsize, length, 1)
234
+ Sine_source (batchsize, length, 1)
235
+ noise_source (batchsize, length 1)
236
+ uv (batchsize, length, 1)
237
+ """
238
+
239
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
240
+ add_noise_std=0.003, voiced_threshod=0):
241
+ super(SourceModuleHnNSF, self).__init__()
242
+
243
+ self.sine_amp = sine_amp
244
+ self.noise_std = add_noise_std
245
+
246
+ # to produce sine waveforms
247
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
248
+ sine_amp, add_noise_std, voiced_threshod)
249
+
250
+ # to merge source harmonics into a single excitation
251
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
252
+ self.l_tanh = torch.nn.Tanh()
253
+
254
+ def forward(self, x):
255
+ """
256
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
257
+ F0_sampled (batchsize, length, 1)
258
+ Sine_source (batchsize, length, 1)
259
+ noise_source (batchsize, length 1)
260
+ """
261
+ # source for harmonic branch
262
+ with torch.no_grad():
263
+ sine_wavs, uv, _ = self.l_sin_gen(x)
264
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
265
+
266
+ # source for noise branch, in the same shape as uv
267
+ noise = torch.randn_like(uv) * self.sine_amp / 3
268
+ return sine_merge, noise, uv
269
+ def padDiff(x):
270
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
271
+
272
+ class Generator(torch.nn.Module):
273
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
274
+ super(Generator, self).__init__()
275
+ self.num_kernels = len(resblock_kernel_sizes)
276
+ self.num_upsamples = len(upsample_rates)
277
+ resblock = AdaINResBlock1
278
+
279
+ self.m_source = SourceModuleHnNSF(
280
+ sampling_rate=24000,
281
+ upsample_scale=np.prod(upsample_rates),
282
+ harmonic_num=8, voiced_threshod=10)
283
+
284
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
285
+ self.noise_convs = nn.ModuleList()
286
+ self.ups = nn.ModuleList()
287
+ self.noise_res = nn.ModuleList()
288
+
289
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
290
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
291
+
292
+ self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel//(2**i),
293
+ upsample_initial_channel//(2**(i+1)),
294
+ k, u, padding=(u//2 + u%2), output_padding=u%2)))
295
+
296
+ if i + 1 < len(upsample_rates): #
297
+ stride_f0 = np.prod(upsample_rates[i + 1:])
298
+ self.noise_convs.append(Conv1d(
299
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
300
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
301
+ else:
302
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
303
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
304
+
305
+ self.resblocks = nn.ModuleList()
306
+
307
+ self.alphas = nn.ParameterList()
308
+ self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
309
+
310
+ for i in range(len(self.ups)):
311
+ ch = upsample_initial_channel//(2**(i+1))
312
+ self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
313
+
314
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
315
+ self.resblocks.append(resblock(ch, k, d, style_dim))
316
+
317
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
318
+ self.ups.apply(init_weights)
319
+ self.conv_post.apply(init_weights)
320
+
321
+ def forward(self, x, s, f0):
322
+
323
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
324
+
325
+ har_source, noi_source, uv = self.m_source(f0)
326
+ har_source = har_source.transpose(1, 2)
327
+
328
+ for i in range(self.num_upsamples):
329
+ x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
330
+ x_source = self.noise_convs[i](har_source)
331
+ x_source = self.noise_res[i](x_source, s)
332
+
333
+ x = self.ups[i](x)
334
+ x = x + x_source
335
+
336
+ xs = None
337
+ for j in range(self.num_kernels):
338
+ if xs is None:
339
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
340
+ else:
341
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
342
+ x = xs / self.num_kernels
343
+ x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)
344
+ x = self.conv_post(x)
345
+ x = torch.tanh(x)
346
+
347
+ return x
348
+
349
+ def remove_weight_norm(self):
350
+ print('Removing weight norm...')
351
+ for l in self.ups:
352
+ remove_weight_norm(l)
353
+ for l in self.resblocks:
354
+ l.remove_weight_norm()
355
356
+ remove_weight_norm(self.conv_post)
357
+
358
+
359
+ class AdainResBlk1d(nn.Module):
360
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
361
+ upsample='none', dropout_p=0.0):
362
+ super().__init__()
363
+ self.actv = actv
364
+ self.upsample_type = upsample
365
+ self.upsample = UpSample1d(upsample)
366
+ self.learned_sc = dim_in != dim_out
367
+ self._build_weights(dim_in, dim_out, style_dim)
368
+ self.dropout = nn.Dropout(dropout_p)
369
+
370
+ if upsample == 'none':
371
+ self.pool = nn.Identity()
372
+ else:
373
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
374
+
375
+
376
+ def _build_weights(self, dim_in, dim_out, style_dim):
377
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
378
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
379
+ self.norm1 = AdaIN1d(style_dim, dim_in)
380
+ self.norm2 = AdaIN1d(style_dim, dim_out)
381
+ if self.learned_sc:
382
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
383
+
384
+ def _shortcut(self, x):
385
+ x = self.upsample(x)
386
+ if self.learned_sc:
387
+ x = self.conv1x1(x)
388
+ return x
389
+
390
+ def _residual(self, x, s):
391
+ x = self.norm1(x, s)
392
+ x = self.actv(x)
393
+ x = self.pool(x)
394
+ x = self.conv1(self.dropout(x))
395
+ x = self.norm2(x, s)
396
+ x = self.actv(x)
397
+ x = self.conv2(self.dropout(x))
398
+ return x
399
+
400
+ def forward(self, x, s):
401
+ out = self._residual(x, s)
402
+ out = (out + self._shortcut(x)) / math.sqrt(2)
403
+ return out
404
+
405
+ class UpSample1d(nn.Module):
406
+ def __init__(self, layer_type):
407
+ super().__init__()
408
+ self.layer_type = layer_type
409
+
410
+ def forward(self, x):
411
+ if self.layer_type == 'none':
412
+ return x
413
+ else:
414
+ return F.interpolate(x, scale_factor=2, mode='nearest')
415
+
416
+ class Decoder(nn.Module):
417
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
418
+ resblock_kernel_sizes = [3,7,11],
419
+ upsample_rates = [10,5,3,2],
420
+ upsample_initial_channel=512,
421
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
422
+ upsample_kernel_sizes=[20,10,6,4]):
423
+ super().__init__()
424
+
425
+ self.decode = nn.ModuleList()
426
+
427
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
428
+
429
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
430
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
431
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
432
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
433
+
434
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
435
+
436
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
437
+
438
+ self.asr_res = nn.Sequential(
439
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
440
+ )
441
+
442
+
443
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
444
+
445
+
446
+ def forward(self, asr, F0_curve, N, s):
447
+ if self.training:
448
+ downlist = [0, 3, 7]
449
+ F0_down = downlist[random.randint(0, 2)]
450
+ downlist = [0, 3, 7, 15]
451
+ N_down = downlist[random.randint(0, 3)]
452
+ if F0_down:
453
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
454
+ if N_down:
455
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down
456
+
457
+
458
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
459
+ N = self.N_conv(N.unsqueeze(1))
460
+
461
+ x = torch.cat([asr, F0, N], axis=1)
462
+ x = self.encode(x, s)
463
+
464
+ asr_res = self.asr_res(asr)
465
+
466
+ res = True
467
+ for block in self.decode:
468
+ if res:
469
+ x = torch.cat([x, asr_res, F0, N], axis=1)
470
+ x = block(x, s)
471
+ if block.upsample_type != "none":
472
+ res = False
473
+
474
+ x = self.generator(x, s, F0_curve)
475
+ return x
476
+
477
+
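A rough inference-time shape check for the Decoder above, with random stand-in inputs; eval mode avoids the CUDA-only augmentation branch in forward, and all sizes are illustrative:

import torch
from Modules.hifigan import Decoder

decoder = Decoder(dim_in=512, style_dim=64, dim_out=80).eval()
asr = torch.randn(1, 512, 10)          # aligned text features, (batch, dim_in, frames)
F0_curve = 200.0 * torch.rand(1, 20)   # frame-level F0, twice the feature length
N = torch.rand(1, 20)                  # frame-level energy curve, same length as F0
s = torch.randn(1, 64)                 # style vector
with torch.no_grad():
    wav = decoder(asr, F0_curve, N, s)
print(wav.shape)   # torch.Size([1, 1, 6000]): 300x upsampling of the 20 F0/energy frames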
Modules/istftnet.py ADDED
@@ -0,0 +1,530 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+ from scipy.signal import get_window
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+ class AdaIN1d(nn.Module):
16
+ def __init__(self, style_dim, num_features):
17
+ super().__init__()
18
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
19
+ self.fc = nn.Linear(style_dim, num_features*2)
20
+
21
+ def forward(self, x, s):
22
+ h = self.fc(s)
23
+ h = h.view(h.size(0), h.size(1), 1)
24
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
25
+ return (1 + gamma) * self.norm(x) + beta
26
+
27
+ class AdaINResBlock1(torch.nn.Module):
28
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
29
+ super(AdaINResBlock1, self).__init__()
30
+ self.convs1 = nn.ModuleList([
31
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
32
+ padding=get_padding(kernel_size, dilation[0]))),
33
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
34
+ padding=get_padding(kernel_size, dilation[1]))),
35
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
36
+ padding=get_padding(kernel_size, dilation[2])))
37
+ ])
38
+ self.convs1.apply(init_weights)
39
+
40
+ self.convs2 = nn.ModuleList([
41
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
42
+ padding=get_padding(kernel_size, 1))),
43
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
44
+ padding=get_padding(kernel_size, 1))),
45
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
46
+ padding=get_padding(kernel_size, 1)))
47
+ ])
48
+ self.convs2.apply(init_weights)
49
+
50
+ self.adain1 = nn.ModuleList([
51
+ AdaIN1d(style_dim, channels),
52
+ AdaIN1d(style_dim, channels),
53
+ AdaIN1d(style_dim, channels),
54
+ ])
55
+
56
+ self.adain2 = nn.ModuleList([
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ AdaIN1d(style_dim, channels),
60
+ ])
61
+
62
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
63
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
64
+
65
+
66
+ def forward(self, x, s):
67
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
68
+ xt = n1(x, s)
69
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
70
+ xt = c1(xt)
71
+ xt = n2(xt, s)
72
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
73
+ xt = c2(xt)
74
+ x = xt + x
75
+ return x
76
+
77
+ def remove_weight_norm(self):
78
+ for l in self.convs1:
79
+ remove_weight_norm(l)
80
+ for l in self.convs2:
81
+ remove_weight_norm(l)
82
+
83
+ class TorchSTFT(torch.nn.Module):
84
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
85
+ super().__init__()
86
+ self.filter_length = filter_length
87
+ self.hop_length = hop_length
88
+ self.win_length = win_length
89
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
90
+
91
+ def transform(self, input_data):
92
+ forward_transform = torch.stft(
93
+ input_data,
94
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
95
+ return_complex=True)
96
+
97
+ return torch.abs(forward_transform), torch.angle(forward_transform)
98
+
99
+ def inverse(self, magnitude, phase):
100
+ inverse_transform = torch.istft(
101
+ magnitude * torch.exp(phase * 1j),
102
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
103
+
104
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
105
+
106
+ def forward(self, input_data):
107
+ self.magnitude, self.phase = self.transform(input_data)
108
+ reconstruction = self.inverse(self.magnitude, self.phase)
109
+ return reconstruction
110
+
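TorchSTFT splits a waveform into magnitude and phase and resynthesizes it with an inverse STFT; a small round trip with illustrative window sizes:

import torch
from Modules.istftnet import TorchSTFT

stft = TorchSTFT(filter_length=20, hop_length=5, win_length=20)
audio = torch.randn(1, 2400)
mag, phase = stft.transform(audio)   # each (1, 11, 481)
recon = stft(audio)                  # (1, 1, 2400) reconstructed waveform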
111
+ class SineGen(torch.nn.Module):
112
+ """ Definition of sine generator
113
+ SineGen(samp_rate, harmonic_num = 0,
114
+ sine_amp = 0.1, noise_std = 0.003,
115
+ voiced_threshold = 0,
116
+ flag_for_pulse=False)
117
+ samp_rate: sampling rate in Hz
118
+ harmonic_num: number of harmonic overtones (default 0)
119
+ sine_amp: amplitude of sine waveform (default 0.1)
120
+ noise_std: std of Gaussian noise (default 0.003)
121
+ voiced_threshold: F0 threshold for U/V classification (default 0)
122
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
123
+ Note: when flag_for_pulse is True, the first time step of a voiced
124
+ segment is always sin(np.pi) or cos(0)
125
+ """
126
+
127
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
128
+ sine_amp=0.1, noise_std=0.003,
129
+ voiced_threshold=0,
130
+ flag_for_pulse=False):
131
+ super(SineGen, self).__init__()
132
+ self.sine_amp = sine_amp
133
+ self.noise_std = noise_std
134
+ self.harmonic_num = harmonic_num
135
+ self.dim = self.harmonic_num + 1
136
+ self.sampling_rate = samp_rate
137
+ self.voiced_threshold = voiced_threshold
138
+ self.flag_for_pulse = flag_for_pulse
139
+ self.upsample_scale = upsample_scale
140
+
141
+ def _f02uv(self, f0):
142
+ # generate uv signal
143
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
144
+ return uv
145
+
146
+ def _f02sine(self, f0_values):
147
+ """ f0_values: (batchsize, length, dim)
148
+ where dim indicates fundamental tone and overtones
149
+ """
150
+ # convert to F0 in rad. The integer part n can be ignored
151
+ # because 2 * np.pi * n doesn't affect phase
152
+ rad_values = (f0_values / self.sampling_rate) % 1
153
+
154
+ # initial phase noise (no noise for fundamental component)
155
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
156
+ device=f0_values.device)
157
+ rand_ini[:, 0] = 0
158
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
159
+
160
+ # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
161
+ if not self.flag_for_pulse:
162
+ # # for normal case
163
+
164
+ # # To prevent torch.cumsum numerical overflow,
165
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
166
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
167
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
168
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
169
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
170
+ # cumsum_shift = torch.zeros_like(rad_values)
171
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
172
+
173
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
174
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
175
+ scale_factor=1/self.upsample_scale,
176
+ mode="linear").transpose(1, 2)
177
+
178
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
179
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
180
+ # cumsum_shift = torch.zeros_like(rad_values)
181
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
182
+
183
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
184
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
185
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
186
+ sines = torch.sin(phase)
187
+
188
+ else:
189
+ # If necessary, make sure that the first time step of every
190
+ # voiced segment is sin(pi) or cos(0)
191
+ # This is used for pulse-train generation
192
+
193
+ # identify the last time step in unvoiced segments
194
+ uv = self._f02uv(f0_values)
195
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
196
+ uv_1[:, -1, :] = 1
197
+ u_loc = (uv < 1) * (uv_1 > 0)
198
+
199
+ # get the instantaneous phase
200
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
201
+ # different batch needs to be processed differently
202
+ for idx in range(f0_values.shape[0]):
203
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
204
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
205
+ # stores the accumulation of i.phase within
206
+ # each voiced segment
207
+ tmp_cumsum[idx, :, :] = 0
208
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
209
+
210
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
211
+ # within the previous voiced segment.
212
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
213
+
214
+ # get the sines
215
+ sines = torch.cos(i_phase * 2 * np.pi)
216
+ return sines
217
+
218
+ def forward(self, f0):
219
+ """ sine_tensor, uv = forward(f0)
220
+ input F0: tensor(batchsize=1, length, dim=1)
221
+ f0 for unvoiced steps should be 0
222
+ output sine_tensor: tensor(batchsize=1, length, dim)
223
+ output uv: tensor(batchsize=1, length, 1)
224
+ """
225
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
226
+ device=f0.device)
227
+ # fundamental component
228
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
229
+
230
+ # generate sine waveforms
231
+ sine_waves = self._f02sine(fn) * self.sine_amp
232
+
233
+ # generate uv signal
234
+ # uv = torch.ones(f0.shape)
235
+ # uv = uv * (f0 > self.voiced_threshold)
236
+ uv = self._f02uv(f0)
237
+
238
+ # noise: for unvoiced should be similar to sine_amp
239
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
240
+ # for voiced regions the std is self.noise_std
241
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
242
+ noise = noise_amp * torch.randn_like(sine_waves)
243
+
244
+ # first: set the unvoiced part to 0 by uv
245
+ # then: additive noise
246
+ sine_waves = sine_waves * uv + noise
247
+ return sine_waves, uv, noise
248
+
249
+
250
+ class SourceModuleHnNSF(torch.nn.Module):
251
+ """ SourceModule for hn-nsf
252
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
253
+ add_noise_std=0.003, voiced_threshod=0)
254
+ sampling_rate: sampling_rate in Hz
255
+ harmonic_num: number of harmonic above F0 (default: 0)
256
+ sine_amp: amplitude of sine source signal (default: 0.1)
257
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
258
+ note that amplitude of noise in unvoiced is decided
259
+ by sine_amp
260
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
261
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
262
+ F0_sampled (batchsize, length, 1)
263
+ Sine_source (batchsize, length, 1)
264
+ noise_source (batchsize, length, 1)
265
+ uv (batchsize, length, 1)
266
+ """
267
+
268
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
269
+ add_noise_std=0.003, voiced_threshod=0):
270
+ super(SourceModuleHnNSF, self).__init__()
271
+
272
+ self.sine_amp = sine_amp
273
+ self.noise_std = add_noise_std
274
+
275
+ # to produce sine waveforms
276
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
277
+ sine_amp, add_noise_std, voiced_threshod)
278
+
279
+ # to merge source harmonics into a single excitation
280
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
281
+ self.l_tanh = torch.nn.Tanh()
282
+
283
+ def forward(self, x):
284
+ """
285
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
286
+ F0_sampled (batchsize, length, 1)
287
+ Sine_source (batchsize, length, 1)
288
+ noise_source (batchsize, length, 1)
289
+ """
290
+ # source for harmonic branch
291
+ with torch.no_grad():
292
+ sine_wavs, uv, _ = self.l_sin_gen(x)
293
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
294
+
295
+ # source for noise branch, in the same shape as uv
296
+ noise = torch.randn_like(uv) * self.sine_amp / 3
297
+ return sine_merge, noise, uv
298
+ def padDiff(x):
299
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
300
+
301
+
302
+ class Generator(torch.nn.Module):
303
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
304
+ super(Generator, self).__init__()
305
+
306
+ self.num_kernels = len(resblock_kernel_sizes)
307
+ self.num_upsamples = len(upsample_rates)
308
+ resblock = AdaINResBlock1
309
+
310
+ self.m_source = SourceModuleHnNSF(
311
+ sampling_rate=24000,
312
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
313
+ harmonic_num=8, voiced_threshod=10)
314
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
315
+ self.noise_convs = nn.ModuleList()
316
+ self.noise_res = nn.ModuleList()
317
+
318
+ self.ups = nn.ModuleList()
319
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
320
+ self.ups.append(weight_norm(
321
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
322
+ k, u, padding=(k-u)//2)))
323
+
324
+ self.resblocks = nn.ModuleList()
325
+ for i in range(len(self.ups)):
326
+ ch = upsample_initial_channel//(2**(i+1))
327
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
328
+ self.resblocks.append(resblock(ch, k, d, style_dim))
329
+
330
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
331
+
332
+ if i + 1 < len(upsample_rates): #
333
+ stride_f0 = np.prod(upsample_rates[i + 1:])
334
+ self.noise_convs.append(Conv1d(
335
+ gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
336
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
337
+ else:
338
+ self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
339
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
340
+
341
+
342
+ self.post_n_fft = gen_istft_n_fft
343
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
344
+ self.ups.apply(init_weights)
345
+ self.conv_post.apply(init_weights)
346
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
347
+ self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
348
+
349
+
350
+ def forward(self, x, s, f0):
351
+ with torch.no_grad():
352
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
353
+
354
+ har_source, noi_source, uv = self.m_source(f0)
355
+ har_source = har_source.transpose(1, 2).squeeze(1)
356
+ har_spec, har_phase = self.stft.transform(har_source)
357
+ har = torch.cat([har_spec, har_phase], dim=1)
358
+
359
+ for i in range(self.num_upsamples):
360
+ x = F.leaky_relu(x, LRELU_SLOPE)
361
+ x_source = self.noise_convs[i](har)
362
+ x_source = self.noise_res[i](x_source, s)
363
+
364
+ x = self.ups[i](x)
365
+ if i == self.num_upsamples - 1:
366
+ x = self.reflection_pad(x)
367
+
368
+ x = x + x_source
369
+ xs = None
370
+ for j in range(self.num_kernels):
371
+ if xs is None:
372
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
373
+ else:
374
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
375
+ x = xs / self.num_kernels
376
+ x = F.leaky_relu(x)
377
+ x = self.conv_post(x)
378
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
379
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
380
+ return self.stft.inverse(spec, phase)
381
+
382
+ def fw_phase(self, x, s):
383
+ for i in range(self.num_upsamples):
384
+ x = F.leaky_relu(x, LRELU_SLOPE)
385
+ x = self.ups[i](x)
386
+ xs = None
387
+ for j in range(self.num_kernels):
388
+ if xs is None:
389
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
390
+ else:
391
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
392
+ x = xs / self.num_kernels
393
+ x = F.leaky_relu(x)
394
+ x = self.reflection_pad(x)
395
+ x = self.conv_post(x)
396
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
397
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
398
+ return spec, phase
399
+
400
+ def remove_weight_norm(self):
401
+ print('Removing weight norm...')
402
+ for l in self.ups:
403
+ remove_weight_norm(l)
404
+ for l in self.resblocks:
405
+ l.remove_weight_norm()
406
+ # (this generator defines no conv_pre layer, so there is none to un-norm here)
407
+ remove_weight_norm(self.conv_post)
408
+
409
+
410
+ class AdainResBlk1d(nn.Module):
411
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
412
+ upsample='none', dropout_p=0.0):
413
+ super().__init__()
414
+ self.actv = actv
415
+ self.upsample_type = upsample
416
+ self.upsample = UpSample1d(upsample)
417
+ self.learned_sc = dim_in != dim_out
418
+ self._build_weights(dim_in, dim_out, style_dim)
419
+ self.dropout = nn.Dropout(dropout_p)
420
+
421
+ if upsample == 'none':
422
+ self.pool = nn.Identity()
423
+ else:
424
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
425
+
426
+
427
+ def _build_weights(self, dim_in, dim_out, style_dim):
428
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
429
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
430
+ self.norm1 = AdaIN1d(style_dim, dim_in)
431
+ self.norm2 = AdaIN1d(style_dim, dim_out)
432
+ if self.learned_sc:
433
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
434
+
435
+ def _shortcut(self, x):
436
+ x = self.upsample(x)
437
+ if self.learned_sc:
438
+ x = self.conv1x1(x)
439
+ return x
440
+
441
+ def _residual(self, x, s):
442
+ x = self.norm1(x, s)
443
+ x = self.actv(x)
444
+ x = self.pool(x)
445
+ x = self.conv1(self.dropout(x))
446
+ x = self.norm2(x, s)
447
+ x = self.actv(x)
448
+ x = self.conv2(self.dropout(x))
449
+ return x
450
+
451
+ def forward(self, x, s):
452
+ out = self._residual(x, s)
453
+ out = (out + self._shortcut(x)) / math.sqrt(2)
454
+ return out
455
+
456
+ class UpSample1d(nn.Module):
457
+ def __init__(self, layer_type):
458
+ super().__init__()
459
+ self.layer_type = layer_type
460
+
461
+ def forward(self, x):
462
+ if self.layer_type == 'none':
463
+ return x
464
+ else:
465
+ return F.interpolate(x, scale_factor=2, mode='nearest')
466
+
467
+ class Decoder(nn.Module):
468
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
469
+ resblock_kernel_sizes = [3,7,11],
470
+ upsample_rates = [10, 6],
471
+ upsample_initial_channel=512,
472
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
473
+ upsample_kernel_sizes=[20, 12],
474
+ gen_istft_n_fft=20, gen_istft_hop_size=5):
475
+ super().__init__()
476
+
477
+ self.decode = nn.ModuleList()
478
+
479
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
480
+
481
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
482
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
483
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
484
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
485
+
486
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
487
+
488
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
489
+
490
+ self.asr_res = nn.Sequential(
491
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
492
+ )
493
+
494
+
495
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
496
+ upsample_initial_channel, resblock_dilation_sizes,
497
+ upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
498
+
499
+ def forward(self, asr, F0_curve, N, s):
500
+ if self.training:
501
+ downlist = [0, 3, 7]
502
+ F0_down = downlist[random.randint(0, 2)]
503
+ downlist = [0, 3, 7, 15]
504
+ N_down = downlist[random.randint(0, 3)]
505
+ if F0_down:
506
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
507
+ if N_down:
508
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down
509
+
510
+
511
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
512
+ N = self.N_conv(N.unsqueeze(1))
513
+
514
+ x = torch.cat([asr, F0, N], axis=1)
515
+ x = self.encode(x, s)
516
+
517
+ asr_res = self.asr_res(asr)
518
+
519
+ res = True
520
+ for block in self.decode:
521
+ if res:
522
+ x = torch.cat([x, asr_res, F0, N], axis=1)
523
+ x = block(x, s)
524
+ if block.upsample_type != "none":
525
+ res = False
526
+
527
+ x = self.generator(x, s, F0_curve)
528
+ return x
529
+
530
+
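For orientation, the Generator above ends by predicting a log-magnitude and a phase tensor and handing them to TorchSTFT.inverse. Below is a minimal round-trip sketch of that class, assuming it is importable from the module above; the 0.1 s random test signal and batch size are illustrative, while the STFT settings mirror gen_istft_n_fft/gen_istft_hop_size in Configs/config.yml.

import torch

# assumes: from Modules.istftnet import TorchSTFT  (the class defined above)
stft = TorchSTFT(filter_length=20, hop_length=5, win_length=20)  # iSTFTNet-style settings
wave = torch.randn(1, 2400)                                      # (batch, samples) ~ 0.1 s at 24 kHz
mag, phase = stft.transform(wave)                                # each (1, n_fft // 2 + 1, frames)
recon = stft.inverse(mag, phase)                                 # (1, 1, samples), via magnitude * exp(1j * phase)
print(mag.shape, phase.shape, recon.shape)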
Modules/slmadv.py ADDED
@@ -0,0 +1,195 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+ class SLMAdversarialLoss(torch.nn.Module):
6
+
7
+ def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5):
8
+ super(SLMAdversarialLoss, self).__init__()
9
+ self.model = model
10
+ self.wl = wl
11
+ self.sampler = sampler
12
+
13
+ self.min_len = min_len
14
+ self.max_len = max_len
15
+ self.batch_percentage = batch_percentage
16
+
17
+ self.sig = sig
18
+ self.skip_update = skip_update
19
+
20
+ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None):
21
+ text_mask = length_to_mask(ref_lengths).to(ref_text.device)
22
+ bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
23
+ d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
24
+
25
+ if use_ind and np.random.rand() < 0.5:
26
+ s_preds = s_trg
27
+ else:
28
+ num_steps = np.random.randint(3, 5)
29
+ if ref_s is not None:
30
+ s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
31
+ embedding=bert_dur,
32
+ embedding_scale=1,
33
+ features=ref_s, # reference from the same speaker as the embedding
34
+ embedding_mask_proba=0.1,
35
+ num_steps=num_steps).squeeze(1)
36
+ else:
37
+ s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
38
+ embedding=bert_dur,
39
+ embedding_scale=1,
40
+ embedding_mask_proba=0.1,
41
+ num_steps=num_steps).squeeze(1)
42
+
43
+ s_dur = s_preds[:, 128:]
44
+ s = s_preds[:, :128]
45
+
46
+ d, _ = self.model.predictor(d_en, s_dur,
47
+ ref_lengths,
48
+ torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
49
+ text_mask)
50
+
51
+ bib = 0
52
+
53
+ output_lengths = []
54
+ attn_preds = []
55
+
56
+ # differentiable duration modeling
57
+ for _s2s_pred, _text_length in zip(d, ref_lengths):
58
+
59
+ _s2s_pred_org = _s2s_pred[:_text_length, :]
60
+
61
+ _s2s_pred = torch.sigmoid(_s2s_pred_org)
62
+ _dur_pred = _s2s_pred.sum(axis=-1)
63
+
64
+ l = int(torch.round(_s2s_pred.sum()).item())
65
+ t = torch.arange(0, l).expand(l)
66
+
67
+ t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device)
68
+ loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
69
+
70
+ h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2)
71
+
72
+ out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
73
+ h.unsqueeze(1),
74
+ padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
75
+ attn_preds.append(F.softmax(out.squeeze(), dim=0))
76
+
77
+ output_lengths.append(l)
78
+
79
+ max_len = max(output_lengths)
80
+
81
+ with torch.no_grad():
82
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
83
+
84
+ s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device)
85
+ for bib in range(len(output_lengths)):
86
+ s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]
87
+
88
+ asr_pred = t_en @ s2s_attn
89
+
90
+ _, p_pred = self.model.predictor(d_en, s_dur,
91
+ ref_lengths,
92
+ s2s_attn,
93
+ text_mask)
94
+
95
+ mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
96
+ mel_len = min(mel_len, self.max_len // 2)
97
+
98
+ # get clips
99
+
100
+ en = []
101
+ p_en = []
102
+ sp = []
103
+
104
+ F0_fakes = []
105
+ N_fakes = []
106
+
107
+ wav = []
108
+
109
+ for bib in range(len(output_lengths)):
110
+ mel_length_pred = output_lengths[bib]
111
+ mel_length_gt = int(mel_input_length[bib].item() / 2)
112
+ if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
113
+ continue
114
+
115
+ sp.append(s_preds[bib])
116
+
117
+ random_start = np.random.randint(0, mel_length_pred - mel_len)
118
+ en.append(asr_pred[bib, :, random_start:random_start+mel_len])
119
+ p_en.append(p_pred[bib, :, random_start:random_start+mel_len])
120
+
121
+ # get ground truth clips
122
+ random_start = np.random.randint(0, mel_length_gt - mel_len)
123
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
124
+ wav.append(torch.from_numpy(y).to(ref_text.device))
125
+
126
+ if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths
127
+ break
128
+
129
+ if len(sp) <= 1:
130
+ return None
131
+
132
+ sp = torch.stack(sp)
133
+ wav = torch.stack(wav).float()
134
+ en = torch.stack(en)
135
+ p_en = torch.stack(p_en)
136
+
137
+ F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
138
+ y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
139
+
140
+ # discriminator loss
141
+ if (iters + 1) % self.skip_update == 0:
142
+ if np.random.randint(0, 2) == 0:
143
+ wav = y_rec_gt_pred
144
+ use_rec = True
145
+ else:
146
+ use_rec = False
147
+
148
+ crop_size = min(wav.size(-1), y_pred.size(-1))
149
+ if use_rec: # use reconstructed (shorter lengths), do length invariant regularization
150
+ if wav.size(-1) > y_pred.size(-1):
151
+ real_GP = wav[:, : , :crop_size]
152
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
153
+ out_org = self.wl.discriminator_forward(wav.detach().squeeze())
154
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
155
+
156
+ if np.random.randint(0, 2) == 0:
157
+ d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean()
158
+ else:
159
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
160
+ else:
161
+ real_GP = y_pred[:, : , :crop_size]
162
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
163
+ out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
164
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
165
+
166
+ if np.random.randint(0, 2) == 0:
167
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean()
168
+ else:
169
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
170
+
171
+ # regularization (ignore length variation)
172
+ d_loss += loss_reg
173
+
174
+ out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
175
+ out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())
176
+
177
+ # regularization (ignore reconstruction artifacts)
178
+ d_loss += F.l1_loss(out_gt, out_rec)
179
+
180
+ else:
181
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
182
+ else:
183
+ d_loss = 0
184
+
185
+ # generator loss
186
+ gen_loss = self.wl.generator(y_pred.squeeze())
187
+
188
+ gen_loss = gen_loss.mean()
189
+
190
+ return d_loss, gen_loss, y_pred.detach().cpu().numpy()
191
+
192
+ def length_to_mask(lengths):
193
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
194
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
195
+ return mask
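The "differentiable duration modeling" loop above turns per-phoneme sigmoid duration predictions into a soft monotonic alignment by placing a Gaussian of width sig around each phoneme's cumulative-duration midpoint. The following is a simplified sketch of that idea only; it skips the grouped-convolution formulation used above, and the duration values are made up.

import torch

sig = 1.5                                          # same role as the sig hyperparameter above
dur = torch.tensor([2.0, 3.0, 1.5])                # hypothetical per-phoneme durations (in frames)
l = int(torch.round(dur.sum()).item())             # total number of output frames
t = torch.arange(l).unsqueeze(0).expand(len(dur), l)
loc = torch.cumsum(dur, dim=0) - dur / 2           # soft centre of each phoneme
w = torch.exp(-0.5 * torch.square(t - loc.unsqueeze(-1)) / sig ** 2)
attn = torch.softmax(w, dim=0)                     # (n_phonemes, n_frames), each column sums to 1
print(attn.shape, attn.sum(dim=0))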
Modules/utils.py ADDED
@@ -0,0 +1,14 @@
1
+ from torch.nn.utils import weight_norm
+
+ def init_weights(m, mean=0.0, std=0.01):
2
+ classname = m.__class__.__name__
3
+ if classname.find("Conv") != -1:
4
+ m.weight.data.normal_(mean, std)
5
+
6
+
7
+ def apply_weight_norm(m):
8
+ classname = m.__class__.__name__
9
+ if classname.find("Conv") != -1:
10
+ weight_norm(m)
11
+
12
+
13
+ def get_padding(kernel_size, dilation=1):
14
+ return int((kernel_size*dilation - dilation)/2)
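get_padding above is the usual "same"-length padding formula for a stride-1 dilated convolution. A quick sanity check (the channel count, kernel size and input length are arbitrary):

import torch
import torch.nn as nn

kernel_size, dilation = 7, 3
pad = int((kernel_size * dilation - dilation) / 2)      # = get_padding(7, 3) = 9
conv = nn.Conv1d(16, 16, kernel_size, dilation=dilation, padding=pad)
x = torch.randn(1, 16, 100)
print(conv(x).shape)                                    # torch.Size([1, 16, 100]) -- length preserved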
Utils/ASR/__init__.py ADDED
@@ -0,0 +1 @@
1
+
Utils/ASR/config.yml ADDED
@@ -0,0 +1,29 @@
1
+ log_dir: "logs/20201006"
2
+ save_freq: 5
3
+ device: "cuda"
4
+ epochs: 180
5
+ batch_size: 64
6
+ pretrained_model: ""
7
+ train_data: "ASRDataset/train_list.txt"
8
+ val_data: "ASRDataset/val_list.txt"
9
+
10
+ dataset_params:
11
+ data_augmentation: false
12
+
13
+ preprocess_parasm:
14
+ sr: 24000
15
+ spect_params:
16
+ n_fft: 2048
17
+ win_length: 1200
18
+ hop_length: 300
19
+ mel_params:
20
+ n_mels: 80
21
+
22
+ model_params:
23
+ input_dim: 80
24
+ hidden_dim: 256
25
+ n_token: 178
26
+ token_embedding_dim: 512
27
+
28
+ optimizer_params:
29
+ lr: 0.0005
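This config is consumed as plain YAML (the inference code later in this commit loads configs with yaml.safe_load). A small sketch of reading it, assuming a local checkout with this relative path:

import yaml

with open("Utils/ASR/config.yml") as f:
    asr_config = yaml.safe_load(f)
print(asr_config["model_params"]["n_token"])   # 178, matching n_token in Configs/config.yml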
Utils/ASR/epoch_00080.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
+ size 94552811
Utils/ASR/layers.py ADDED
@@ -0,0 +1,354 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
+ class LinearNorm(torch.nn.Module):
25
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
+ super(LinearNorm, self).__init__()
27
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
+
29
+ torch.nn.init.xavier_uniform_(
30
+ self.linear_layer.weight,
31
+ gain=torch.nn.init.calculate_gain(w_init_gain))
32
+
33
+ def forward(self, x):
34
+ return self.linear_layer(x)
35
+
36
+
37
+ class ConvNorm(torch.nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
+ super(ConvNorm, self).__init__()
41
+ if padding is None:
42
+ assert(kernel_size % 2 == 1)
43
+ padding = int(dilation * (kernel_size - 1) / 2)
44
+
45
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
+ kernel_size=kernel_size, stride=stride,
47
+ padding=padding, dilation=dilation,
48
+ bias=bias)
49
+
50
+ torch.nn.init.xavier_uniform_(
51
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
+
53
+ def forward(self, signal):
54
+ conv_signal = self.conv(signal)
55
+ return conv_signal
56
+
57
+ class CausualConv(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
+ super(CausualConv, self).__init__()
60
+ if padding is None:
61
+ assert(kernel_size % 2 == 1)
62
+ self.padding = int(dilation * (kernel_size - 1) / 2) * 2
63
+ else:
64
+ self.padding = padding * 2
65
+ self.conv = nn.Conv1d(in_channels, out_channels,
66
+ kernel_size=kernel_size, stride=stride,
67
+ padding=self.padding,
68
+ dilation=dilation,
69
+ bias=bias)
70
+
71
+ torch.nn.init.xavier_uniform_(
72
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
+
74
+ def forward(self, x):
75
+ x = self.conv(x)
76
+ x = x[:, :, :-self.padding]
77
+ return x
78
+
79
+ class CausualBlock(nn.Module):
80
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
+ super(CausualBlock, self).__init__()
82
+ self.blocks = nn.ModuleList([
83
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
+ for i in range(n_conv)])
85
+
86
+ def forward(self, x):
87
+ for block in self.blocks:
88
+ res = x
89
+ x = block(x)
90
+ x += res
91
+ return x
92
+
93
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
+ layers = [
95
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
+ _get_activation_fn(activ),
97
+ nn.BatchNorm1d(hidden_dim),
98
+ nn.Dropout(p=dropout_p),
99
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
+ _get_activation_fn(activ),
101
+ nn.Dropout(p=dropout_p)
102
+ ]
103
+ return nn.Sequential(*layers)
104
+
105
+ class ConvBlock(nn.Module):
106
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
+ super().__init__()
108
+ self._n_groups = 8
109
+ self.blocks = nn.ModuleList([
110
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
+ for i in range(n_conv)])
112
+
113
+
114
+ def forward(self, x):
115
+ for block in self.blocks:
116
+ res = x
117
+ x = block(x)
118
+ x += res
119
+ return x
120
+
121
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
+ layers = [
123
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
+ _get_activation_fn(activ),
125
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
+ nn.Dropout(p=dropout_p),
127
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
+ _get_activation_fn(activ),
129
+ nn.Dropout(p=dropout_p)
130
+ ]
131
+ return nn.Sequential(*layers)
132
+
133
+ class LocationLayer(nn.Module):
134
+ def __init__(self, attention_n_filters, attention_kernel_size,
135
+ attention_dim):
136
+ super(LocationLayer, self).__init__()
137
+ padding = int((attention_kernel_size - 1) / 2)
138
+ self.location_conv = ConvNorm(2, attention_n_filters,
139
+ kernel_size=attention_kernel_size,
140
+ padding=padding, bias=False, stride=1,
141
+ dilation=1)
142
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
+ bias=False, w_init_gain='tanh')
144
+
145
+ def forward(self, attention_weights_cat):
146
+ processed_attention = self.location_conv(attention_weights_cat)
147
+ processed_attention = processed_attention.transpose(1, 2)
148
+ processed_attention = self.location_dense(processed_attention)
149
+ return processed_attention
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
+ attention_location_n_filters, attention_location_kernel_size):
155
+ super(Attention, self).__init__()
156
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
+ bias=False, w_init_gain='tanh')
158
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
+ w_init_gain='tanh')
160
+ self.v = LinearNorm(attention_dim, 1, bias=False)
161
+ self.location_layer = LocationLayer(attention_location_n_filters,
162
+ attention_location_kernel_size,
163
+ attention_dim)
164
+ self.score_mask_value = -float("inf")
165
+
166
+ def get_alignment_energies(self, query, processed_memory,
167
+ attention_weights_cat):
168
+ """
169
+ PARAMS
170
+ ------
171
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
+ RETURNS
175
+ -------
176
+ alignment (batch, max_time)
177
+ """
178
+
179
+ processed_query = self.query_layer(query.unsqueeze(1))
180
+ processed_attention_weights = self.location_layer(attention_weights_cat)
181
+ energies = self.v(torch.tanh(
182
+ processed_query + processed_attention_weights + processed_memory))
183
+
184
+ energies = energies.squeeze(-1)
185
+ return energies
186
+
187
+ def forward(self, attention_hidden_state, memory, processed_memory,
188
+ attention_weights_cat, mask):
189
+ """
190
+ PARAMS
191
+ ------
192
+ attention_hidden_state: attention rnn last output
193
+ memory: encoder outputs
194
+ processed_memory: processed encoder outputs
195
+ attention_weights_cat: previous and cummulative attention weights
196
+ mask: binary mask for padded data
197
+ """
198
+ alignment = self.get_alignment_energies(
199
+ attention_hidden_state, processed_memory, attention_weights_cat)
200
+
201
+ if mask is not None:
202
+ alignment.data.masked_fill_(mask, self.score_mask_value)
203
+
204
+ attention_weights = F.softmax(alignment, dim=1)
205
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
+ attention_context = attention_context.squeeze(1)
207
+
208
+ return attention_context, attention_weights
209
+
210
+
211
+ class ForwardAttentionV2(nn.Module):
212
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
+ attention_location_n_filters, attention_location_kernel_size):
214
+ super(ForwardAttentionV2, self).__init__()
215
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
+ bias=False, w_init_gain='tanh')
217
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
+ w_init_gain='tanh')
219
+ self.v = LinearNorm(attention_dim, 1, bias=False)
220
+ self.location_layer = LocationLayer(attention_location_n_filters,
221
+ attention_location_kernel_size,
222
+ attention_dim)
223
+ self.score_mask_value = -float(1e20)
224
+
225
+ def get_alignment_energies(self, query, processed_memory,
226
+ attention_weights_cat):
227
+ """
228
+ PARAMS
229
+ ------
230
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
+ RETURNS
234
+ -------
235
+ alignment (batch, max_time)
236
+ """
237
+
238
+ processed_query = self.query_layer(query.unsqueeze(1))
239
+ processed_attention_weights = self.location_layer(attention_weights_cat)
240
+ energies = self.v(torch.tanh(
241
+ processed_query + processed_attention_weights + processed_memory))
242
+
243
+ energies = energies.squeeze(-1)
244
+ return energies
245
+
246
+ def forward(self, attention_hidden_state, memory, processed_memory,
247
+ attention_weights_cat, mask, log_alpha):
248
+ """
249
+ PARAMS
250
+ ------
251
+ attention_hidden_state: attention rnn last output
252
+ memory: encoder outputs
253
+ processed_memory: processed encoder outputs
254
+ attention_weights_cat: previous and cummulative attention weights
255
+ mask: binary mask for padded data
256
+ """
257
+ log_energy = self.get_alignment_energies(
258
+ attention_hidden_state, processed_memory, attention_weights_cat)
259
+
260
+ #log_energy =
261
+
262
+ if mask is not None:
263
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
264
+
265
+ #attention_weights = F.softmax(alignment, dim=1)
266
+
267
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
+
270
+ #log_total_score = log_alpha + content_score
271
+
272
+ #previous_attention_weights = attention_weights_cat[:,0,:]
273
+
274
+ log_alpha_shift_padded = []
275
+ max_time = log_energy.size(1)
276
+ for sft in range(2):
277
+ shifted = log_alpha[:,:max_time-sft]
278
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
+
281
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
+
283
+ log_alpha_new = biased + log_energy
284
+
285
+ attention_weights = F.softmax(log_alpha_new, dim=1)
286
+
287
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
+ attention_context = attention_context.squeeze(1)
289
+
290
+ return attention_context, attention_weights, log_alpha_new
291
+
292
+
293
+ class PhaseShuffle2d(nn.Module):
294
+ def __init__(self, n=2):
295
+ super(PhaseShuffle2d, self).__init__()
296
+ self.n = n
297
+ self.random = random.Random(1)
298
+
299
+ def forward(self, x, move=None):
300
+ # x.size = (B, C, M, L)
301
+ if move is None:
302
+ move = self.random.randint(-self.n, self.n)
303
+
304
+ if move == 0:
305
+ return x
306
+ else:
307
+ left = x[:, :, :, :move]
308
+ right = x[:, :, :, move:]
309
+ shuffled = torch.cat([right, left], dim=3)
310
+ return shuffled
311
+
312
+ class PhaseShuffle1d(nn.Module):
313
+ def __init__(self, n=2):
314
+ super(PhaseShuffle1d, self).__init__()
315
+ self.n = n
316
+ self.random = random.Random(1)
317
+
318
+ def forward(self, x, move=None):
319
+ # x.size = (B, C, M, L)
320
+ if move is None:
321
+ move = self.random.randint(-self.n, self.n)
322
+
323
+ if move == 0:
324
+ return x
325
+ else:
326
+ left = x[:, :, :move]
327
+ right = x[:, :, move:]
328
+ shuffled = torch.cat([right, left], dim=2)
329
+
330
+ return shuffled
331
+
332
+ class MFCC(nn.Module):
333
+ def __init__(self, n_mfcc=40, n_mels=80):
334
+ super(MFCC, self).__init__()
335
+ self.n_mfcc = n_mfcc
336
+ self.n_mels = n_mels
337
+ self.norm = 'ortho'
338
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
+ self.register_buffer('dct_mat', dct_mat)
340
+
341
+ def forward(self, mel_specgram):
342
+ if len(mel_specgram.shape) == 2:
343
+ mel_specgram = mel_specgram.unsqueeze(0)
344
+ unsqueezed = True
345
+ else:
346
+ unsqueezed = False
347
+ # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
348
+ # -> (channel, time, n_mfcc).transpose(...)
349
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
+
351
+ # unpack batch
352
+ if unsqueezed:
353
+ mfcc = mfcc.squeeze(0)
354
+ return mfcc
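The MFCC module above is simply an orthonormal DCT applied along the mel axis. A shape sketch with random input, assuming the class is importable from this file:

import torch

# assumes: from Utils.ASR.layers import MFCC
mfcc = MFCC(n_mfcc=40, n_mels=80)
mel = torch.randn(4, 80, 120)      # (batch, n_mels, frames)
out = mfcc(mel)
print(out.shape)                   # torch.Size([4, 40, 120])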
Utils/ASR/models.py ADDED
@@ -0,0 +1,186 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import TransformerEncoder
5
+ import torch.nn.functional as F
6
+ from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
7
+
8
+ class ASRCNN(nn.Module):
9
+ def __init__(self,
10
+ input_dim=80,
11
+ hidden_dim=256,
12
+ n_token=35,
13
+ n_layers=6,
14
+ token_embedding_dim=256,
15
+
16
+ ):
17
+ super().__init__()
18
+ self.n_token = n_token
19
+ self.n_down = 1
20
+ self.to_mfcc = MFCC()
21
+ self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2)
22
+ self.cnns = nn.Sequential(
23
+ *[nn.Sequential(
24
+ ConvBlock(hidden_dim),
25
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim)
26
+ ) for n in range(n_layers)])
27
+ self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
28
+ self.ctc_linear = nn.Sequential(
29
+ LinearNorm(hidden_dim//2, hidden_dim),
30
+ nn.ReLU(),
31
+ LinearNorm(hidden_dim, n_token))
32
+ self.asr_s2s = ASRS2S(
33
+ embedding_dim=token_embedding_dim,
34
+ hidden_dim=hidden_dim//2,
35
+ n_token=n_token)
36
+
37
+ def forward(self, x, src_key_padding_mask=None, text_input=None):
38
+ x = self.to_mfcc(x)
39
+ x = self.init_cnn(x)
40
+ x = self.cnns(x)
41
+ x = self.projection(x)
42
+ x = x.transpose(1, 2)
43
+ ctc_logit = self.ctc_linear(x)
44
+ if text_input is not None:
45
+ _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
46
+ return ctc_logit, s2s_logit, s2s_attn
47
+ else:
48
+ return ctc_logit
49
+
50
+ def get_feature(self, x):
51
+ x = self.to_mfcc(x.squeeze(1))
52
+ x = self.init_cnn(x)
53
+ x = self.cnns(x)
54
+ x = self.projection(x)
55
+ return x
56
+
57
+ def length_to_mask(self, lengths):
58
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
59
+ mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device)
60
+ return mask
61
+
62
+ def get_future_mask(self, out_length, unmask_future_steps=0):
63
+ """
64
+ Args:
65
+ out_length (int): returned mask shape is (out_length, out_length).
66
+ unmask_future_steps (int): number of future steps left unmasked.
67
+ Return:
68
+ mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
69
+ """
70
+ index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
71
+ mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
72
+ return mask
73
+
74
+ class ASRS2S(nn.Module):
75
+ def __init__(self,
76
+ embedding_dim=256,
77
+ hidden_dim=512,
78
+ n_location_filters=32,
79
+ location_kernel_size=63,
80
+ n_token=40):
81
+ super(ASRS2S, self).__init__()
82
+ self.embedding = nn.Embedding(n_token, embedding_dim)
83
+ val_range = math.sqrt(6 / hidden_dim)
84
+ self.embedding.weight.data.uniform_(-val_range, val_range)
85
+
86
+ self.decoder_rnn_dim = hidden_dim
87
+ self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
88
+ self.attention_layer = Attention(
89
+ self.decoder_rnn_dim,
90
+ hidden_dim,
91
+ hidden_dim,
92
+ n_location_filters,
93
+ location_kernel_size
94
+ )
95
+ self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
96
+ self.project_to_hidden = nn.Sequential(
97
+ LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
98
+ nn.Tanh())
99
+ self.sos = 1
100
+ self.eos = 2
101
+
102
+ def initialize_decoder_states(self, memory, mask):
103
+ """
104
+ memory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
105
+ """
106
+ B, L, H = memory.shape
107
+ self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
108
+ self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
109
+ self.attention_weights = torch.zeros((B, L)).type_as(memory)
110
+ self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
111
+ self.attention_context = torch.zeros((B, H)).type_as(memory)
112
+ self.memory = memory
113
+ self.processed_memory = self.attention_layer.memory_layer(memory)
114
+ self.mask = mask
115
+ self.unk_index = 3
116
+ self.random_mask = 0.1
117
+
118
+ def forward(self, memory, memory_mask, text_input):
119
+ """
120
+ memory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
121
+ memory_mask.shape = (B, L)
122
+ text_input.shape = (B, T)
123
+ """
124
+ self.initialize_decoder_states(memory, memory_mask)
125
+ # text random mask
126
+ random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
127
+ _text_input = text_input.clone()
128
+ _text_input.masked_fill_(random_mask, self.unk_index)
129
+ decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
130
+ start_embedding = self.embedding(
131
+ torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
132
+ decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
133
+
134
+ hidden_outputs, logit_outputs, alignments = [], [], []
135
+ while len(hidden_outputs) < decoder_inputs.size(0):
136
+
137
+ decoder_input = decoder_inputs[len(hidden_outputs)]
138
+ hidden, logit, attention_weights = self.decode(decoder_input)
139
+ hidden_outputs += [hidden]
140
+ logit_outputs += [logit]
141
+ alignments += [attention_weights]
142
+
143
+ hidden_outputs, logit_outputs, alignments = \
144
+ self.parse_decoder_outputs(
145
+ hidden_outputs, logit_outputs, alignments)
146
+
147
+ return hidden_outputs, logit_outputs, alignments
148
+
149
+
150
+ def decode(self, decoder_input):
151
+
152
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
153
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
154
+ cell_input,
155
+ (self.decoder_hidden, self.decoder_cell))
156
+
157
+ attention_weights_cat = torch.cat(
158
+ (self.attention_weights.unsqueeze(1),
159
+ self.attention_weights_cum.unsqueeze(1)),dim=1)
160
+
161
+ self.attention_context, self.attention_weights = self.attention_layer(
162
+ self.decoder_hidden,
163
+ self.memory,
164
+ self.processed_memory,
165
+ attention_weights_cat,
166
+ self.mask)
167
+
168
+ self.attention_weights_cum += self.attention_weights
169
+
170
+ hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
171
+ hidden = self.project_to_hidden(hidden_and_context)
172
+
173
+ # dropout on the hidden state before projecting to output logits
174
+ logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
175
+
176
+ return hidden, logit, self.attention_weights
177
+
178
+ def parse_decoder_outputs(self, hidden, logit, alignments):
179
+
180
+ # -> [B, T_out + 1, max_time]
181
+ alignments = torch.stack(alignments).transpose(0,1)
182
+ # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
183
+ logit = torch.stack(logit).transpose(0, 1).contiguous()
184
+ hidden = torch.stack(hidden).transpose(0, 1).contiguous()
185
+
186
+ return hidden, logit, alignments
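A shape sketch of the CTC branch of ASRCNN, using the values from Utils/ASR/config.yml (n_layers is left at its default; the mel input is random). Calling the model without text_input returns only the CTC logits:

import torch

# assumes: from Utils.ASR.models import ASRCNN
asr = ASRCNN(input_dim=80, hidden_dim=256, n_token=178, n_layers=6, token_embedding_dim=512)
mel = torch.randn(2, 80, 100)          # (batch, n_mels, frames)
ctc_logit = asr(mel)
print(ctc_logit.shape)                 # torch.Size([2, 50, 178]); time is halved by the stride-2 init_cnn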
Utils/JDC/__init__.py ADDED
@@ -0,0 +1 @@
1
+
Utils/JDC/bst.t7 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
+ size 21029926
Utils/JDC/model.py ADDED
@@ -0,0 +1,190 @@
1
+ """
2
+ Implementation of model from:
3
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
4
+ Convolutional Recurrent Neural Networks" (2019)
5
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
6
+ """
7
+ import torch
8
+ from torch import nn
9
+
10
+ class JDCNet(nn.Module):
11
+ """
12
+ Joint Detection and Classification Network model for singing voice melody.
13
+ """
14
+ def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
15
+ super().__init__()
16
+ self.num_class = num_class
17
+
18
+ # input = (b, 1, 31, 513), b = batch size
19
+ self.conv_block = nn.Sequential(
20
+ nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False), # out: (b, 64, 31, 513)
21
+ nn.BatchNorm2d(num_features=64),
22
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
23
+ nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
24
+ )
25
+
26
+ # res blocks
27
+ self.res_block1 = ResBlock(in_channels=64, out_channels=128) # (b, 128, 31, 128)
28
+ self.res_block2 = ResBlock(in_channels=128, out_channels=192) # (b, 192, 31, 32)
29
+ self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
30
+
31
+ # pool block
32
+ self.pool_block = nn.Sequential(
33
+ nn.BatchNorm2d(num_features=256),
34
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
35
+ nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
36
+ nn.Dropout(p=0.2),
37
+ )
38
+
39
+ # maxpool layers (for auxiliary network inputs)
40
+ # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
41
+ self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
42
+ # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
43
+ self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
44
+ # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
45
+ self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
46
+
47
+ # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
48
+ self.detector_conv = nn.Sequential(
49
+ nn.Conv2d(640, 256, 1, bias=False),
50
+ nn.BatchNorm2d(256),
51
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
52
+ nn.Dropout(p=0.2),
53
+ )
54
+
55
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
56
+ self.bilstm_classifier = nn.LSTM(
57
+ input_size=512, hidden_size=256,
58
+ batch_first=True, bidirectional=True) # (b, 31, 512)
59
+
60
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
61
+ self.bilstm_detector = nn.LSTM(
62
+ input_size=512, hidden_size=256,
63
+ batch_first=True, bidirectional=True) # (b, 31, 512)
64
+
65
+ # input: (b * 31, 512)
66
+ self.classifier = nn.Linear(in_features=512, out_features=self.num_class) # (b * 31, num_class)
67
+
68
+ # input: (b * 31, 512)
69
+ self.detector = nn.Linear(in_features=512, out_features=2) # (b * 31, 2) - binary classifier
70
+
71
+ # initialize weights
72
+ self.apply(self.init_weights)
73
+
74
+ def get_feature_GAN(self, x):
75
+ seq_len = x.shape[-2]
76
+ x = x.float().transpose(-1, -2)
77
+
78
+ convblock_out = self.conv_block(x)
79
+
80
+ resblock1_out = self.res_block1(convblock_out)
81
+ resblock2_out = self.res_block2(resblock1_out)
82
+ resblock3_out = self.res_block3(resblock2_out)
83
+ poolblock_out = self.pool_block[0](resblock3_out)
84
+ poolblock_out = self.pool_block[1](poolblock_out)
85
+
86
+ return poolblock_out.transpose(-1, -2)
87
+
88
+ def get_feature(self, x):
89
+ seq_len = x.shape[-2]
90
+ x = x.float().transpose(-1, -2)
91
+
92
+ convblock_out = self.conv_block(x)
93
+
94
+ resblock1_out = self.res_block1(convblock_out)
95
+ resblock2_out = self.res_block2(resblock1_out)
96
+ resblock3_out = self.res_block3(resblock2_out)
97
+ poolblock_out = self.pool_block[0](resblock3_out)
98
+ poolblock_out = self.pool_block[1](poolblock_out)
99
+
100
+ return self.pool_block[2](poolblock_out)
101
+
102
+ def forward(self, x):
103
+ """
104
+ Returns:
105
+ classification_prediction, detection_prediction
106
+ sizes: (b, 31, 722), (b, 31, 2)
107
+ """
108
+ ###############################
109
+ # forward pass for classifier #
110
+ ###############################
111
+ seq_len = x.shape[-1]
112
+ x = x.float().transpose(-1, -2)
113
+
114
+ convblock_out = self.conv_block(x)
115
+
116
+ resblock1_out = self.res_block1(convblock_out)
117
+ resblock2_out = self.res_block2(resblock1_out)
118
+ resblock3_out = self.res_block3(resblock2_out)
119
+
120
+
121
+ poolblock_out = self.pool_block[0](resblock3_out)
122
+ poolblock_out = self.pool_block[1](poolblock_out)
123
+ GAN_feature = poolblock_out.transpose(-1, -2)
124
+ poolblock_out = self.pool_block[2](poolblock_out)
125
+
126
+ # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
127
+ classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
128
+ classifier_out, _ = self.bilstm_classifier(classifier_out) # ignore the hidden states
129
+
130
+ classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
131
+ classifier_out = self.classifier(classifier_out)
132
+ classifier_out = classifier_out.view((-1, seq_len, self.num_class)) # (b, 31, num_class)
133
+
134
+ # sizes: (b, 31, 722), (b, 31, 2)
135
+ # classifier output consists of predicted pitch classes per frame
136
+ # detector output consists of: (isvoice, notvoice) estimates per frame
137
+ return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
138
+
139
+ @staticmethod
140
+ def init_weights(m):
141
+ if isinstance(m, nn.Linear):
142
+ nn.init.kaiming_uniform_(m.weight)
143
+ if m.bias is not None:
144
+ nn.init.constant_(m.bias, 0)
145
+ elif isinstance(m, nn.Conv2d):
146
+ nn.init.xavier_normal_(m.weight)
147
+ elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
148
+ for p in m.parameters():
149
+ if p.data is None:
150
+ continue
151
+
152
+ if len(p.shape) >= 2:
153
+ nn.init.orthogonal_(p.data)
154
+ else:
155
+ nn.init.normal_(p.data)
156
+
157
+
158
+ class ResBlock(nn.Module):
159
+ def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
160
+ super().__init__()
161
+ self.downsample = in_channels != out_channels
162
+
163
+ # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
164
+ self.pre_conv = nn.Sequential(
165
+ nn.BatchNorm2d(num_features=in_channels),
166
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
167
+ nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
168
+ )
169
+
170
+ # conv layers
171
+ self.conv = nn.Sequential(
172
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
173
+ kernel_size=3, padding=1, bias=False),
174
+ nn.BatchNorm2d(out_channels),
175
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
176
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
177
+ )
178
+
179
+ # 1 x 1 convolution layer to match the feature dimensions
180
+ self.conv1by1 = None
181
+ if self.downsample:
182
+ self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
183
+
184
+ def forward(self, x):
185
+ x = self.pre_conv(x)
186
+ if self.downsample:
187
+ x = self.conv(x) + self.conv1by1(x)
188
+ else:
189
+ x = self.conv(x) + x
190
+ return x
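In this project the JDC network is repurposed as an F0 (pitch) extractor; Configs/config.yml points F0_path at Utils/JDC/bst.t7. The sketch below uses a randomly initialised JDCNet, and num_class=1 is an assumption about how that checkpoint is instantiated rather than something stated in this file:

import torch

# assumes: from Utils.JDC.model import JDCNet
jdc = JDCNet(num_class=1, seq_len=192)
mel = torch.randn(2, 1, 80, 192)       # (batch, 1, n_mels, frames)
f0, gan_feature, features = jdc(mel)
print(f0.shape)                        # torch.Size([2, 192]) -- one pitch value per frame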
Utils/PLBERT/config.yml ADDED
@@ -0,0 +1,30 @@
1
+ log_dir: "Checkpoint"
2
+ mixed_precision: "fp16"
3
+ data_folder: "wikipedia_20220301.en.processed"
4
+ batch_size: 192
5
+ save_interval: 5000
6
+ log_interval: 10
7
+ num_process: 1 # number of GPUs
8
+ num_steps: 1000000
9
+
10
+ dataset_params:
11
+ tokenizer: "transfo-xl-wt103"
12
+ token_separator: " " # token used for phoneme separator (space)
13
+ token_mask: "M" # token used for phoneme mask (M)
14
+ word_separator: 3039 # token used for word separator (<formula>)
15
+ token_maps: "token_maps.pkl" # token map path
16
+
17
+ max_mel_length: 512 # max phoneme length
18
+
19
+ word_mask_prob: 0.15 # probability to mask the entire word
20
+ phoneme_mask_prob: 0.1 # probability to mask each phoneme
21
+ replace_prob: 0.2 # probability to replace phonemes
22
+
23
+ model_params:
24
+ vocab_size: 178
25
+ hidden_size: 768
26
+ num_attention_heads: 12
27
+ intermediate_size: 2048
28
+ max_position_embeddings: 512
29
+ num_hidden_layers: 12
30
+ dropout: 0.1
Utils/PLBERT/step_1000000.t7 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0
3
+ size 25185187
Utils/PLBERT/util.py ADDED
@@ -0,0 +1,42 @@
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from transformers import AlbertConfig, AlbertModel
5
+
6
+ class CustomAlbert(AlbertModel):
7
+ def forward(self, *args, **kwargs):
8
+ # Call the original forward method
9
+ outputs = super().forward(*args, **kwargs)
10
+
11
+ # Only return the last_hidden_state
12
+ return outputs.last_hidden_state
13
+
14
+
15
+ def load_plbert(log_dir):
16
+ config_path = os.path.join(log_dir, "config.yml")
17
+ plbert_config = yaml.safe_load(open(config_path))
18
+
19
+ albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
20
+ bert = CustomAlbert(albert_base_configuration)
21
+
22
+ files = os.listdir(log_dir)
23
+ ckpts = []
24
+ for f in os.listdir(log_dir):
25
+ if f.startswith("step_"): ckpts.append(f)
26
+
27
+ iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
28
+ iters = sorted(iters)[-1]
29
+
30
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
31
+ state_dict = checkpoint['net']
32
+ from collections import OrderedDict
33
+ new_state_dict = OrderedDict()
34
+ for k, v in state_dict.items():
35
+ name = k[7:] # remove `module.`
36
+ if name.startswith('encoder.'):
37
+ name = name[8:] # remove `encoder.`
38
+ new_state_dict[name] = v
39
+ del new_state_dict["embeddings.position_ids"]
40
+ bert.load_state_dict(new_state_dict, strict=False)
41
+
42
+ return bert
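load_plbert scans the given directory for the newest step_*.t7 checkpoint and loads it into an ALBERT built from config.yml. A usage sketch; the path matches PLBERT_dir in Configs/config.yml and running it requires the LFS checkpoint above to be present locally:

from Utils.PLBERT.util import load_plbert

plbert = load_plbert("Utils/PLBERT/")
print(sum(p.numel() for p in plbert.parameters()))   # parameter count of the phoneme-level BERT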
Utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+
app.py ADDED
@@ -0,0 +1,83 @@
1
+ # StyleTTS 2 HTTP Streaming API by @fakerybakery - Copyright (c) 2023 mrfakename. All rights reserved.
2
+ # Docs: API_DOCS.md
3
+ # To-Do:
4
+ # * Support voice cloning
5
+ # * Implement authentication, user "credits" system w/ SQLite3
6
+ import io
7
+ import markdown
8
+ from tortoise.utils.text import split_and_recombine_text
9
+ from flask import Flask, Response, request, jsonify
10
+ import numpy as np
11
+ import ljinference
12
+ import torch
13
+ import hashlib
14
+ from scipy.io.wavfile import read, write
15
+ from flask_cors import CORS
16
+ import os
17
+ import torchaudio
18
+
19
+ def genHeader(sampleRate, bitsPerSample, channels):
20
+ datasize = 2000 * 10**6
21
+ o = bytes("RIFF", "ascii")
22
+ o += (datasize + 36).to_bytes(4, "little")
23
+ o += bytes("WAVE", "ascii")
24
+ o += bytes("fmt ", "ascii")
25
+ o += (16).to_bytes(4, "little")
26
+ o += (1).to_bytes(2, "little")
27
+ o += (channels).to_bytes(2, "little")
28
+ o += (sampleRate).to_bytes(4, "little")
29
+ o += (sampleRate * channels * bitsPerSample // 8).to_bytes(4, "little")
30
+ o += (channels * bitsPerSample // 8).to_bytes(2, "little")
31
+ o += (bitsPerSample).to_bytes(2, "little")
32
+ o += bytes("data", "ascii")
33
+ o += (datasize).to_bytes(4, "little")
34
+ return o
35
+
36
+ import phonemizer
37
+ global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
38
+ print("Starting Flask app")
39
+
40
+ app = Flask(__name__)
41
+ cors = CORS(app)
42
+
43
+ @app.route("/")
44
+ def index():
45
+ with open('API_DOCS.md', 'r') as f:
46
+ return markdown.markdown(f.read())
47
+
48
+ cache_dir = 'cache'
49
+ if not os.path.exists(cache_dir):
50
+ os.makedirs(cache_dir, exist_ok=True)
51
+
52
+ @app.route("/api", methods=['GET', 'POST'])
53
+ def serve_wav():
54
+ if request.method == 'GET':
55
+ request.form = request.args
56
+ if 'text' not in request.form:
57
+ if 'text' not in request.json:
58
+ error_response = {'error': 'Missing required fields. Please include "text" in your request.'}
59
+ return jsonify(error_response), 400
60
+ else:
61
+ text = request.json['text']
62
+ else:
63
+ text = request.form['text'].strip()
64
+
65
+ texts = split_and_recombine_text(text)
66
+ audios = []
67
+ noise = torch.randn(1,1,256).to('cuda' if torch.cuda.is_available() else 'cpu')
68
+ for t in texts:
69
+ # check for cache
70
+ hash = hashlib.sha256(t.encode()).hexdigest()
71
+ if os.path.exists(os.path.join(cache_dir, hash + '.wav')):
72
+ audios.append(read(os.path.join(cache_dir, hash + '.wav'))[1])
73
+ else:
74
+ aud = ljinference.inference(t, noise, diffusion_steps=7, embedding_scale=1)
75
+ write(os.path.join(cache_dir, hash + '.wav'), 24000, aud)
76
+ audios.append(aud)
77
+ output_buffer = io.BytesIO()
78
+ write(output_buffer, 24000, np.concatenate(audios))
79
+ response = Response(output_buffer.getvalue())
80
+ response.headers["Content-Type"] = "audio/wav"
81
+ return response
82
+ if __name__ == "__main__":
83
+ app.run("0.0.0.0")
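A minimal client sketch for the /api route above; it accepts form data or JSON with a single "text" field and returns a 24 kHz WAV. Host and port are assumptions (Flask's default development server), and the requests package is only used for illustration:

import requests

resp = requests.post(
    "http://127.0.0.1:5000/api",              # assumed host/port for the dev server
    json={"text": "Hello from StyleTTS 2."},  # same field the route reads from request.json
)
resp.raise_for_status()

with open("out.wav", "wb") as f:              # body is the concatenated 24 kHz WAV
    f.write(resp.content)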
ljinference.py ADDED
@@ -0,0 +1,225 @@
1
+ from cached_path import cached_path
2
+
3
+
4
+ import torch
5
+ torch.manual_seed(0)
6
+ torch.backends.cudnn.benchmark = False
7
+ torch.backends.cudnn.deterministic = True
8
+
9
+ import random
10
+ random.seed(0)
11
+
12
+ import numpy as np
13
+ np.random.seed(0)
14
+
15
+ import nltk
16
+ nltk.download('punkt')
17
+
18
+ # load packages
19
+ import time
20
+ import random
21
+ import yaml
22
+ from munch import Munch
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+ import torch.nn.functional as F
27
+ import torchaudio
28
+ import librosa
29
+ from nltk.tokenize import word_tokenize
30
+
31
+ from models import *
32
+ from utils import *
33
+ from text_utils import TextCleaner
34
+ textclenaer = TextCleaner()
35
+
36
+
37
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
38
+
39
+ to_mel = torchaudio.transforms.MelSpectrogram(
40
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
41
+ mean, std = -4, 4
42
+
43
+ def length_to_mask(lengths):
44
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
45
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
46
+ return mask
47
+
48
+ def preprocess(wave):
49
+ wave_tensor = torch.from_numpy(wave).float()
50
+ mel_tensor = to_mel(wave_tensor)
51
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
52
+ return mel_tensor
53
+
54
+ def compute_style(ref_dicts):
55
+ reference_embeddings = {}
56
+ for key, path in ref_dicts.items():
57
+ wave, sr = librosa.load(path, sr=24000)
58
+ audio, index = librosa.effects.trim(wave, top_db=30)
59
+ if sr != 24000:
60
+ audio = librosa.resample(audio, sr, 24000)
61
+ mel_tensor = preprocess(audio).to(device)
62
+
63
+ with torch.no_grad():
64
+ ref = model.style_encoder(mel_tensor.unsqueeze(1))
65
+ reference_embeddings[key] = (ref.squeeze(1), audio)
66
+
67
+ return reference_embeddings
68
+
69
+ # load phonemizer
70
+ import phonemizer
71
+ global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')
72
+
73
+ # phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
74
+
75
+
76
+ config = yaml.safe_load(open(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/config.yml'))))
77
+
78
+ # load pretrained ASR model
79
+ ASR_config = config.get('ASR_config', False)
80
+ ASR_path = config.get('ASR_path', False)
81
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
82
+
83
+ # load pretrained F0 model
84
+ F0_path = config.get('F0_path', False)
85
+ pitch_extractor = load_F0_models(F0_path)
86
+
87
+ # load BERT model
88
+ from Utils.PLBERT.util import load_plbert
89
+ BERT_path = config.get('PLBERT_dir', False)
90
+ plbert = load_plbert(BERT_path)
91
+
92
+ model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
93
+ _ = [model[key].eval() for key in model]
94
+ _ = [model[key].to(device) for key in model]
95
+
96
+ # params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu')
97
+ params_whole = torch.load(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/epoch_2nd_00100.pth')), map_location='cpu')
98
+ params = params_whole['net']
99
+
100
+ for key in model:
101
+ if key in params:
102
+ print('%s loaded' % key)
103
+ try:
104
+ model[key].load_state_dict(params[key])
105
+ except:
106
+ from collections import OrderedDict
107
+ state_dict = params[key]
108
+ new_state_dict = OrderedDict()
109
+ for k, v in state_dict.items():
110
+ name = k[7:] # remove `module.`
111
+ new_state_dict[name] = v
112
+ # load params
113
+ model[key].load_state_dict(new_state_dict, strict=False)
114
+ # except:
115
+ # _load(params[key], model[key])
116
+ _ = [model[key].eval() for key in model]
117
+
118
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
119
+
120
+ sampler = DiffusionSampler(
121
+ model.diffusion.diffusion,
122
+ sampler=ADPM2Sampler(),
123
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
124
+ clamp=False
125
+ )
126
+
127
+ def inference(text, noise, diffusion_steps=5, embedding_scale=1):
128
+ text = text.strip()
129
+ text = text.replace('"', '')
130
+ ps = global_phonemizer.phonemize([text])
131
+ ps = word_tokenize(ps[0])
132
+ ps = ' '.join(ps)
133
+
134
+ tokens = textclenaer(ps)
135
+ tokens.insert(0, 0)
136
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
137
+
138
+ with torch.no_grad():
139
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
140
+ text_mask = length_to_mask(input_lengths).to(tokens.device)
141
+
142
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
143
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
144
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
145
+
146
+ s_pred = sampler(noise,
147
+ embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
148
+ embedding_scale=embedding_scale).squeeze(0)
149
+
150
+ s = s_pred[:, 128:]
151
+ ref = s_pred[:, :128]
152
+
153
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
154
+
155
+ x, _ = model.predictor.lstm(d)
156
+ duration = model.predictor.duration_proj(x)
157
+ duration = torch.sigmoid(duration).sum(axis=-1)
158
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
159
+
160
+ pred_dur[-1] += 5
161
+
162
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
163
+ c_frame = 0
164
+ for i in range(pred_aln_trg.size(0)):
165
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
166
+ c_frame += int(pred_dur[i].data)
167
+
168
+ # encode prosody
169
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
170
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
171
+ out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
172
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
173
+
174
+ return out.squeeze().cpu().numpy()
175
+
176
+ def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):
177
+ text = text.strip()
178
+ text = text.replace('"', '')
179
+ ps = global_phonemizer.phonemize([text])
180
+ ps = word_tokenize(ps[0])
181
+ ps = ' '.join(ps)
182
+
183
+ tokens = textclenaer(ps)
184
+ tokens.insert(0, 0)
185
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
186
+
187
+ with torch.no_grad():
188
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
189
+ text_mask = length_to_mask(input_lengths).to(tokens.device)
190
+
191
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
192
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
193
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
194
+
195
+ s_pred = sampler(noise,
196
+ embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
197
+ embedding_scale=embedding_scale).squeeze(0)
198
+
199
+ if s_prev is not None:
200
+ # convex combination of previous and current style
201
+ s_pred = alpha * s_prev + (1 - alpha) * s_pred
202
+
203
+ s = s_pred[:, 128:]
204
+ ref = s_pred[:, :128]
205
+
206
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
207
+
208
+ x, _ = model.predictor.lstm(d)
209
+ duration = model.predictor.duration_proj(x)
210
+ duration = torch.sigmoid(duration).sum(axis=-1)
211
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
212
+
213
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
214
+ c_frame = 0
215
+ for i in range(pred_aln_trg.size(0)):
216
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
217
+ c_frame += int(pred_dur[i].data)
218
+
219
+ # encode prosody
220
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
221
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
222
+ out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
223
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
224
+
225
+ return out.squeeze().cpu().numpy(), s_pred
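A minimal driver sketch for LFinference above: feeding the returned style back in as s_prev keeps longer passages consistent, since each new style is a convex combination with the previous one. Importing this module is heavy (it downloads punkt and loads the LJSpeech checkpoint at import time); the text and output path below are placeholders:

import numpy as np
import torch
from nltk.tokenize import sent_tokenize
from scipy.io.wavfile import write

import ljinference   # builds the model and loads weights on import

text = ("StyleTTS 2 synthesizes speech sentence by sentence. "
        "Passing the previous style keeps the delivery consistent.")

s_prev = None
pieces = []
for sent in sent_tokenize(text):
    noise = torch.randn(1, 1, 256).to(ljinference.device)
    wav, s_prev = ljinference.LFinference(
        sent, s_prev, noise,
        alpha=0.7,              # weight of the previous style in the convex combination
        diffusion_steps=5,
        embedding_scale=1,
    )
    pieces.append(wav)

write("longform.wav", 24000, np.concatenate(pieces))   # 24 kHz mono, matching app.py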
losses.py ADDED
@@ -0,0 +1,253 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ from transformers import AutoModel
6
+
7
+ class SpectralConvergengeLoss(torch.nn.Module):
8
+ """Spectral convergence loss module."""
9
+
10
+ def __init__(self):
11
+ """Initialize spectral convergence loss module."""
12
+ super(SpectralConvergengeLoss, self).__init__()
13
+
14
+ def forward(self, x_mag, y_mag):
15
+ """Calculate forward propagation.
16
+ Args:
17
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
18
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
19
+ Returns:
20
+ Tensor: Spectral convergence loss value.
21
+ """
22
+ return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
23
+
24
+ class STFTLoss(torch.nn.Module):
25
+ """STFT loss module."""
26
+
27
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window):
28
+ """Initialize STFT loss module."""
29
+ super(STFTLoss, self).__init__()
30
+ self.fft_size = fft_size
31
+ self.shift_size = shift_size
32
+ self.win_length = win_length
33
+ self.to_mel = torchaudio.transforms.MelSpectrogram(sample_rate=24000, n_fft=fft_size, win_length=win_length, hop_length=shift_size, window_fn=window)
34
+
35
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
36
+
37
+ def forward(self, x, y):
38
+ """Calculate forward propagation.
39
+ Args:
40
+ x (Tensor): Predicted signal (B, T).
41
+ y (Tensor): Groundtruth signal (B, T).
42
+ Returns:
43
+ Tensor: Spectral convergence loss value.
44
+ Tensor: Log STFT magnitude loss value.
45
+ """
46
+ x_mag = self.to_mel(x)
47
+ mean, std = -4, 4
48
+ x_mag = (torch.log(1e-5 + x_mag) - mean) / std
49
+
50
+ y_mag = self.to_mel(y)
51
+ mean, std = -4, 4
52
+ y_mag = (torch.log(1e-5 + y_mag) - mean) / std
53
+
54
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
55
+ return sc_loss
56
+
57
+
58
+ class MultiResolutionSTFTLoss(torch.nn.Module):
59
+ """Multi resolution STFT loss module."""
60
+
61
+ def __init__(self,
62
+ fft_sizes=[1024, 2048, 512],
63
+ hop_sizes=[120, 240, 50],
64
+ win_lengths=[600, 1200, 240],
65
+ window=torch.hann_window):
66
+ """Initialize Multi resolution STFT loss module.
67
+ Args:
68
+ fft_sizes (list): List of FFT sizes.
69
+ hop_sizes (list): List of hop sizes.
70
+ win_lengths (list): List of window lengths.
71
+ window (str): Window function type.
72
+ """
73
+ super(MultiResolutionSTFTLoss, self).__init__()
74
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
75
+ self.stft_losses = torch.nn.ModuleList()
76
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
77
+ self.stft_losses += [STFTLoss(fs, ss, wl, window)]
78
+
79
+ def forward(self, x, y):
80
+ """Calculate forward propagation.
81
+ Args:
82
+ x (Tensor): Predicted signal (B, T).
83
+ y (Tensor): Groundtruth signal (B, T).
84
+ Returns:
85
+ Tensor: Multi resolution spectral convergence loss value.
86
+ Tensor: Multi resolution log STFT magnitude loss value.
87
+ """
88
+ sc_loss = 0.0
89
+ for f in self.stft_losses:
90
+ sc_l = f(x, y)
91
+ sc_loss += sc_l
92
+ sc_loss /= len(self.stft_losses)
93
+
94
+ return sc_loss
95
+
96
+
97
+ def feature_loss(fmap_r, fmap_g):
98
+ loss = 0
99
+ for dr, dg in zip(fmap_r, fmap_g):
100
+ for rl, gl in zip(dr, dg):
101
+ loss += torch.mean(torch.abs(rl - gl))
102
+
103
+ return loss*2
104
+
105
+
106
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
107
+ loss = 0
108
+ r_losses = []
109
+ g_losses = []
110
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
111
+ r_loss = torch.mean((1-dr)**2)
112
+ g_loss = torch.mean(dg**2)
113
+ loss += (r_loss + g_loss)
114
+ r_losses.append(r_loss.item())
115
+ g_losses.append(g_loss.item())
116
+
117
+ return loss, r_losses, g_losses
118
+
119
+
120
+ def generator_loss(disc_outputs):
121
+ loss = 0
122
+ gen_losses = []
123
+ for dg in disc_outputs:
124
+ l = torch.mean((1-dg)**2)
125
+ gen_losses.append(l)
126
+ loss += l
127
+
128
+ return loss, gen_losses
129
+
130
+ """ https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
131
+ def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
132
+ loss = 0
133
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
134
+ tau = 0.04
135
+ m_DG = torch.median((dr-dg))
136
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
137
+ loss += tau - F.relu(tau - L_rel)
138
+ return loss
139
+
140
+ def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
141
+ loss = 0
142
+ for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
143
+ tau = 0.04
144
+ m_DG = torch.median((dr-dg))
145
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
146
+ loss += tau - F.relu(tau - L_rel)
147
+ return loss
148
+
149
+ class GeneratorLoss(torch.nn.Module):
150
+
151
+ def __init__(self, mpd, msd):
152
+ super(GeneratorLoss, self).__init__()
153
+ self.mpd = mpd
154
+ self.msd = msd
155
+
156
+ def forward(self, y, y_hat):
157
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
158
+ y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
159
+ loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
160
+ loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
161
+ loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
162
+ loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
163
+
164
+ loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
165
+
166
+ loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
167
+
168
+ return loss_gen_all.mean()
169
+
170
+ class DiscriminatorLoss(torch.nn.Module):
171
+
172
+ def __init__(self, mpd, msd):
173
+ super(DiscriminatorLoss, self).__init__()
174
+ self.mpd = mpd
175
+ self.msd = msd
176
+
177
+ def forward(self, y, y_hat):
178
+ # MPD
179
+ y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
180
+ loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
181
+ # MSD
182
+ y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
183
+ loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
184
+
185
+ loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
186
+
187
+
188
+ d_loss = loss_disc_s + loss_disc_f + loss_rel
189
+
190
+ return d_loss.mean()
191
+
192
+
193
+ class WavLMLoss(torch.nn.Module):
194
+
195
+ def __init__(self, model, wd, model_sr, slm_sr=16000):
196
+ super(WavLMLoss, self).__init__()
197
+ self.wavlm = AutoModel.from_pretrained(model)
198
+ self.wd = wd
199
+ self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
200
+
201
+ def forward(self, wav, y_rec):
202
+ with torch.no_grad():
203
+ wav_16 = self.resample(wav)
204
+ wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
205
+ y_rec_16 = self.resample(y_rec)
206
+ y_rec_embeddings = self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True).hidden_states
207
+
208
+ floss = 0
209
+ for er, eg in zip(wav_embeddings, y_rec_embeddings):
210
+ floss += torch.mean(torch.abs(er - eg))
211
+
212
+ return floss.mean()
213
+
214
+ def generator(self, y_rec):
215
+ y_rec_16 = self.resample(y_rec)
216
+ y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states
217
+ y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
218
+ y_df_hat_g = self.wd(y_rec_embeddings)
219
+ loss_gen = torch.mean((1-y_df_hat_g)**2)
220
+
221
+ return loss_gen
222
+
223
+ def discriminator(self, wav, y_rec):
224
+ with torch.no_grad():
225
+ wav_16 = self.resample(wav)
226
+ wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
227
+ y_rec_16 = self.resample(y_rec)
228
+ y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states
229
+
230
+ y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
231
+ y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
232
+
233
+ y_d_rs = self.wd(y_embeddings)
234
+ y_d_gs = self.wd(y_rec_embeddings)
235
+
236
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
237
+
238
+ r_loss = torch.mean((1-y_df_hat_r)**2)
239
+ g_loss = torch.mean((y_df_hat_g)**2)
240
+
241
+ loss_disc_f = r_loss + g_loss
242
+
243
+ return loss_disc_f.mean()
244
+
245
+ def discriminator_forward(self, wav):
246
+ with torch.no_grad():
247
+ wav_16 = self.resample(wav)
248
+ wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
249
+ y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
250
+
251
+ y_d_rs = self.wd(y_embeddings)
252
+
253
+ return y_d_rs
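A minimal sketch of the multi-resolution STFT loss defined above; it averages a spectral-convergence term over the three mel resolutions and expects raw waveforms shaped (batch, samples):

import torch
from losses import MultiResolutionSTFTLoss

stft_loss = MultiResolutionSTFTLoss()      # default FFT/hop/window sizes from the class above

y     = torch.randn(2, 24000)              # ground-truth waveforms (dummy tensors, 1 s at 24 kHz)
y_hat = torch.randn(2, 24000)              # generated waveforms
print(stft_loss(y_hat, y).item())          # scalar loss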
meldataset.py ADDED
@@ -0,0 +1,255 @@
1
+ #coding: utf-8
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ import random
6
+ import numpy as np
7
+ import random
8
+ import soundfile as sf
9
+ import librosa
10
+
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ import torchaudio
15
+ from torch.utils.data import DataLoader
16
+
17
+ import logging
18
+ logger = logging.getLogger(__name__)
19
+ logger.setLevel(logging.DEBUG)
20
+
21
+ import pandas as pd
22
+
23
+ _pad = "$"
24
+ _punctuation = ';:,.!?¡¿—…"«»“” '
25
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
26
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
27
+
28
+ # Export all symbols:
29
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
30
+
31
+ dicts = {}
32
+ for i in range(len((symbols))):
33
+ dicts[symbols[i]] = i
34
+
35
+ class TextCleaner:
36
+ def __init__(self, dummy=None):
37
+ self.word_index_dictionary = dicts
38
+ def __call__(self, text):
39
+ indexes = []
40
+ for char in text:
41
+ try:
42
+ indexes.append(self.word_index_dictionary[char])
43
+ except KeyError:
44
+ print(text)
45
+ return indexes
46
+
47
+ np.random.seed(1)
48
+ random.seed(1)
49
+ SPECT_PARAMS = {
50
+ "n_fft": 2048,
51
+ "win_length": 1200,
52
+ "hop_length": 300
53
+ }
54
+ MEL_PARAMS = {
55
+ "n_mels": 80,
56
+ }
57
+
58
+ to_mel = torchaudio.transforms.MelSpectrogram(
59
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
60
+ mean, std = -4, 4
61
+
62
+ def preprocess(wave):
63
+ wave_tensor = torch.from_numpy(wave).float()
64
+ mel_tensor = to_mel(wave_tensor)
65
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
66
+ return mel_tensor
67
+
68
+ class FilePathDataset(torch.utils.data.Dataset):
69
+ def __init__(self,
70
+ data_list,
71
+ root_path,
72
+ sr=24000,
73
+ data_augmentation=False,
74
+ validation=False,
75
+ OOD_data="Data/OOD_texts.txt",
76
+ min_length=50,
77
+ ):
78
+
79
+ spect_params = SPECT_PARAMS
80
+ mel_params = MEL_PARAMS
81
+
82
+ _data_list = [l.strip().split('|') for l in data_list]
83
+ self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
84
+ self.text_cleaner = TextCleaner()
85
+ self.sr = sr
86
+
87
+ self.df = pd.DataFrame(self.data_list)
88
+
89
+ self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
90
+
91
+ self.mean, self.std = -4, 4
92
+ self.data_augmentation = data_augmentation and (not validation)
93
+ self.max_mel_length = 192
94
+
95
+ self.min_length = min_length
96
+ with open(OOD_data, 'r', encoding='utf-8') as f:
97
+ tl = f.readlines()
98
+ idx = 1 if '.wav' in tl[0].split('|')[0] else 0
99
+ self.ptexts = [t.split('|')[idx] for t in tl]
100
+
101
+ self.root_path = root_path
102
+
103
+ def __len__(self):
104
+ return len(self.data_list)
105
+
106
+ def __getitem__(self, idx):
107
+ data = self.data_list[idx]
108
+ path = data[0]
109
+
110
+ wave, text_tensor, speaker_id = self._load_tensor(data)
111
+
112
+ mel_tensor = preprocess(wave).squeeze()
113
+
114
+ acoustic_feature = mel_tensor.squeeze()
115
+ length_feature = acoustic_feature.size(1)
116
+ acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
117
+
118
+ # get reference sample
119
+ ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist()
120
+ ref_mel_tensor, ref_label = self._load_data(ref_data[:3])
121
+
122
+ # get OOD text
123
+
124
+ ps = ""
125
+
126
+ while len(ps) < self.min_length:
127
+ rand_idx = np.random.randint(0, len(self.ptexts) - 1)
128
+ ps = self.ptexts[rand_idx]
129
+
130
+ text = self.text_cleaner(ps)
131
+ text.insert(0, 0)
132
+ text.append(0)
133
+
134
+ ref_text = torch.LongTensor(text)
135
+
136
+ return speaker_id, acoustic_feature, text_tensor, ref_text, ref_mel_tensor, ref_label, path, wave
137
+
138
+ def _load_tensor(self, data):
139
+ wave_path, text, speaker_id = data
140
+ speaker_id = int(speaker_id)
141
+ wave, sr = sf.read(osp.join(self.root_path, wave_path))
142
+ if wave.shape[-1] == 2:
143
+ wave = wave[:, 0].squeeze()
144
+ if sr != 24000:
145
+ wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
146
+ print(wave_path, sr)
147
+
148
+ wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
149
+
150
+ text = self.text_cleaner(text)
151
+
152
+ text.insert(0, 0)
153
+ text.append(0)
154
+
155
+ text = torch.LongTensor(text)
156
+
157
+ return wave, text, speaker_id
158
+
159
+ def _load_data(self, data):
160
+ wave, text_tensor, speaker_id = self._load_tensor(data)
161
+ mel_tensor = preprocess(wave).squeeze()
162
+
163
+ mel_length = mel_tensor.size(1)
164
+ if mel_length > self.max_mel_length:
165
+ random_start = np.random.randint(0, mel_length - self.max_mel_length)
166
+ mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]
167
+
168
+ return mel_tensor, speaker_id
169
+
170
+
171
+ class Collater(object):
172
+ """
173
+ Args:
174
+ adaptive_batch_size (bool): if true, decrease batch size when long data comes.
175
+ """
176
+
177
+ def __init__(self, return_wave=False):
178
+ self.text_pad_index = 0
179
+ self.min_mel_length = 192
180
+ self.max_mel_length = 192
181
+ self.return_wave = return_wave
182
+
183
+
184
+ def __call__(self, batch):
185
+ # batch[0] = wave, mel, text, f0, speakerid
186
+ batch_size = len(batch)
187
+
188
+ # sort by mel length
189
+ lengths = [b[1].shape[1] for b in batch]
190
+ batch_indexes = np.argsort(lengths)[::-1]
191
+ batch = [batch[bid] for bid in batch_indexes]
192
+
193
+ nmels = batch[0][1].size(0)
194
+ max_mel_length = max([b[1].shape[1] for b in batch])
195
+ max_text_length = max([b[2].shape[0] for b in batch])
196
+ max_rtext_length = max([b[3].shape[0] for b in batch])
197
+
198
+ labels = torch.zeros((batch_size)).long()
199
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
200
+ texts = torch.zeros((batch_size, max_text_length)).long()
201
+ ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
202
+
203
+ input_lengths = torch.zeros(batch_size).long()
204
+ ref_lengths = torch.zeros(batch_size).long()
205
+ output_lengths = torch.zeros(batch_size).long()
206
+ ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
207
+ ref_labels = torch.zeros((batch_size)).long()
208
+ paths = ['' for _ in range(batch_size)]
209
+ waves = [None for _ in range(batch_size)]
210
+
211
+ for bid, (label, mel, text, ref_text, ref_mel, ref_label, path, wave) in enumerate(batch):
212
+ mel_size = mel.size(1)
213
+ text_size = text.size(0)
214
+ rtext_size = ref_text.size(0)
215
+ labels[bid] = label
216
+ mels[bid, :, :mel_size] = mel
217
+ texts[bid, :text_size] = text
218
+ ref_texts[bid, :rtext_size] = ref_text
219
+ input_lengths[bid] = text_size
220
+ ref_lengths[bid] = rtext_size
221
+ output_lengths[bid] = mel_size
222
+ paths[bid] = path
223
+ ref_mel_size = ref_mel.size(1)
224
+ ref_mels[bid, :, :ref_mel_size] = ref_mel
225
+
226
+ ref_labels[bid] = ref_label
227
+ waves[bid] = wave
228
+
229
+ return waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels
230
+
231
+
232
+
233
+ def build_dataloader(path_list,
234
+ root_path,
235
+ validation=False,
236
+ OOD_data="Data/OOD_texts.txt",
237
+ min_length=50,
238
+ batch_size=4,
239
+ num_workers=1,
240
+ device='cpu',
241
+ collate_config={},
242
+ dataset_config={}):
243
+
244
+ dataset = FilePathDataset(path_list, root_path, OOD_data=OOD_data, min_length=min_length, validation=validation, **dataset_config)
245
+ collate_fn = Collater(**collate_config)
246
+ data_loader = DataLoader(dataset,
247
+ batch_size=batch_size,
248
+ shuffle=(not validation),
249
+ num_workers=num_workers,
250
+ drop_last=(not validation),
251
+ collate_fn=collate_fn,
252
+ pin_memory=(device != 'cpu'))
253
+
254
+ return data_loader
255
+
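A minimal sketch of build_dataloader above. Each line of the filelist is expected as path|text|speaker (the speaker field defaults to 0 when missing); all paths below are placeholders:

from meldataset import build_dataloader

with open("Data/train_list.txt", "r", encoding="utf-8") as f:   # placeholder filelist
    train_list = f.readlines()

train_loader = build_dataloader(
    train_list,
    root_path="/path/to/wavs",          # placeholder: directory holding the 24 kHz wavs
    OOD_data="Data/OOD_texts.txt",
    min_length=50,
    batch_size=4,
    num_workers=1,
    device="cpu",
)

waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels = next(iter(train_loader))
print(mels.shape)                       # (batch, 80, max_mel_frames)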
models.py ADDED
@@ -0,0 +1,713 @@
1
+ #coding:utf-8
2
+
3
+ import os
4
+ import os.path as osp
5
+
6
+ import copy
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
+
15
+ from Utils.ASR.models import ASRCNN
16
+ from Utils.JDC.model import JDCNet
17
+
18
+ from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
19
+ from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
20
+ from Modules.diffusion.diffusion import AudioDiffusionConditional
21
+
22
+ from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
23
+
24
+ from munch import Munch
25
+ import yaml
26
+
27
+ class LearnedDownSample(nn.Module):
28
+ def __init__(self, layer_type, dim_in):
29
+ super().__init__()
30
+ self.layer_type = layer_type
31
+
32
+ if self.layer_type == 'none':
33
+ self.conv = nn.Identity()
34
+ elif self.layer_type == 'timepreserve':
35
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
36
+ elif self.layer_type == 'half':
37
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
38
+ else:
39
+ raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
40
+
41
+ def forward(self, x):
42
+ return self.conv(x)
43
+
44
+ class LearnedUpSample(nn.Module):
45
+ def __init__(self, layer_type, dim_in):
46
+ super().__init__()
47
+ self.layer_type = layer_type
48
+
49
+ if self.layer_type == 'none':
50
+ self.conv = nn.Identity()
51
+ elif self.layer_type == 'timepreserve':
52
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
53
+ elif self.layer_type == 'half':
54
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
55
+ else:
56
+ raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
57
+
58
+
59
+ def forward(self, x):
60
+ return self.conv(x)
61
+
62
+ class DownSample(nn.Module):
63
+ def __init__(self, layer_type):
64
+ super().__init__()
65
+ self.layer_type = layer_type
66
+
67
+ def forward(self, x):
68
+ if self.layer_type == 'none':
69
+ return x
70
+ elif self.layer_type == 'timepreserve':
71
+ return F.avg_pool2d(x, (2, 1))
72
+ elif self.layer_type == 'half':
73
+ if x.shape[-1] % 2 != 0:
74
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
75
+ return F.avg_pool2d(x, 2)
76
+ else:
77
+ raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
78
+
79
+
80
+ class UpSample(nn.Module):
81
+ def __init__(self, layer_type):
82
+ super().__init__()
83
+ self.layer_type = layer_type
84
+
85
+ def forward(self, x):
86
+ if self.layer_type == 'none':
87
+ return x
88
+ elif self.layer_type == 'timepreserve':
89
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
90
+ elif self.layer_type == 'half':
91
+ return F.interpolate(x, scale_factor=2, mode='nearest')
92
+ else:
93
+ raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
94
+
95
+
96
+ class ResBlk(nn.Module):
97
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
98
+ normalize=False, downsample='none'):
99
+ super().__init__()
100
+ self.actv = actv
101
+ self.normalize = normalize
102
+ self.downsample = DownSample(downsample)
103
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
104
+ self.learned_sc = dim_in != dim_out
105
+ self._build_weights(dim_in, dim_out)
106
+
107
+ def _build_weights(self, dim_in, dim_out):
108
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
109
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
110
+ if self.normalize:
111
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
112
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
113
+ if self.learned_sc:
114
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
115
+
116
+ def _shortcut(self, x):
117
+ if self.learned_sc:
118
+ x = self.conv1x1(x)
119
+ if self.downsample:
120
+ x = self.downsample(x)
121
+ return x
122
+
123
+ def _residual(self, x):
124
+ if self.normalize:
125
+ x = self.norm1(x)
126
+ x = self.actv(x)
127
+ x = self.conv1(x)
128
+ x = self.downsample_res(x)
129
+ if self.normalize:
130
+ x = self.norm2(x)
131
+ x = self.actv(x)
132
+ x = self.conv2(x)
133
+ return x
134
+
135
+ def forward(self, x):
136
+ x = self._shortcut(x) + self._residual(x)
137
+ return x / math.sqrt(2) # unit variance
138
+
139
+ class StyleEncoder(nn.Module):
140
+ def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
141
+ super().__init__()
142
+ blocks = []
143
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
144
+
145
+ repeat_num = 4
146
+ for _ in range(repeat_num):
147
+ dim_out = min(dim_in*2, max_conv_dim)
148
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
149
+ dim_in = dim_out
150
+
151
+ blocks += [nn.LeakyReLU(0.2)]
152
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
153
+ blocks += [nn.AdaptiveAvgPool2d(1)]
154
+ blocks += [nn.LeakyReLU(0.2)]
155
+ self.shared = nn.Sequential(*blocks)
156
+
157
+ self.unshared = nn.Linear(dim_out, style_dim)
158
+
159
+ def forward(self, x):
160
+ h = self.shared(x)
161
+ h = h.view(h.size(0), -1)
162
+ s = self.unshared(h)
163
+
164
+ return s
165
+
166
+ class LinearNorm(torch.nn.Module):
167
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
168
+ super(LinearNorm, self).__init__()
169
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
170
+
171
+ torch.nn.init.xavier_uniform_(
172
+ self.linear_layer.weight,
173
+ gain=torch.nn.init.calculate_gain(w_init_gain))
174
+
175
+ def forward(self, x):
176
+ return self.linear_layer(x)
177
+
178
+ class Discriminator2d(nn.Module):
179
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
180
+ super().__init__()
181
+ blocks = []
182
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
183
+
184
+ for lid in range(repeat_num):
185
+ dim_out = min(dim_in*2, max_conv_dim)
186
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
187
+ dim_in = dim_out
188
+
189
+ blocks += [nn.LeakyReLU(0.2)]
190
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
191
+ blocks += [nn.LeakyReLU(0.2)]
192
+ blocks += [nn.AdaptiveAvgPool2d(1)]
193
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
194
+ self.main = nn.Sequential(*blocks)
195
+
196
+ def get_feature(self, x):
197
+ features = []
198
+ for l in self.main:
199
+ x = l(x)
200
+ features.append(x)
201
+ out = features[-1]
202
+ out = out.view(out.size(0), -1) # (batch, num_domains)
203
+ return out, features
204
+
205
+ def forward(self, x):
206
+ out, features = self.get_feature(x)
207
+ out = out.squeeze() # (batch)
208
+ return out, features
209
+
210
+ class ResBlk1d(nn.Module):
211
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
212
+ normalize=False, downsample='none', dropout_p=0.2):
213
+ super().__init__()
214
+ self.actv = actv
215
+ self.normalize = normalize
216
+ self.downsample_type = downsample
217
+ self.learned_sc = dim_in != dim_out
218
+ self._build_weights(dim_in, dim_out)
219
+ self.dropout_p = dropout_p
220
+
221
+ if self.downsample_type == 'none':
222
+ self.pool = nn.Identity()
223
+ else:
224
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
225
+
226
+ def _build_weights(self, dim_in, dim_out):
227
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
228
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
229
+ if self.normalize:
230
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
231
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
232
+ if self.learned_sc:
233
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
234
+
235
+ def downsample(self, x):
236
+ if self.downsample_type == 'none':
237
+ return x
238
+ else:
239
+ if x.shape[-1] % 2 != 0:
240
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
241
+ return F.avg_pool1d(x, 2)
242
+
243
+ def _shortcut(self, x):
244
+ if self.learned_sc:
245
+ x = self.conv1x1(x)
246
+ x = self.downsample(x)
247
+ return x
248
+
249
+ def _residual(self, x):
250
+ if self.normalize:
251
+ x = self.norm1(x)
252
+ x = self.actv(x)
253
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
254
+
255
+ x = self.conv1(x)
256
+ x = self.pool(x)
257
+ if self.normalize:
258
+ x = self.norm2(x)
259
+
260
+ x = self.actv(x)
261
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
262
+
263
+ x = self.conv2(x)
264
+ return x
265
+
266
+ def forward(self, x):
267
+ x = self._shortcut(x) + self._residual(x)
268
+ return x / math.sqrt(2) # unit variance
269
+
270
+ class LayerNorm(nn.Module):
271
+ def __init__(self, channels, eps=1e-5):
272
+ super().__init__()
273
+ self.channels = channels
274
+ self.eps = eps
275
+
276
+ self.gamma = nn.Parameter(torch.ones(channels))
277
+ self.beta = nn.Parameter(torch.zeros(channels))
278
+
279
+ def forward(self, x):
280
+ x = x.transpose(1, -1)
281
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
282
+ return x.transpose(1, -1)
283
+
284
+ class TextEncoder(nn.Module):
285
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
286
+ super().__init__()
287
+ self.embedding = nn.Embedding(n_symbols, channels)
288
+
289
+ padding = (kernel_size - 1) // 2
290
+ self.cnn = nn.ModuleList()
291
+ for _ in range(depth):
292
+ self.cnn.append(nn.Sequential(
293
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
294
+ LayerNorm(channels),
295
+ actv,
296
+ nn.Dropout(0.2),
297
+ ))
298
+ # self.cnn = nn.Sequential(*self.cnn)
299
+
300
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
301
+
302
+ def forward(self, x, input_lengths, m):
303
+ x = self.embedding(x) # [B, T, emb]
304
+ x = x.transpose(1, 2) # [B, emb, T]
305
+ m = m.to(input_lengths.device).unsqueeze(1)
306
+ x.masked_fill_(m, 0.0)
307
+
308
+ for c in self.cnn:
309
+ x = c(x)
310
+ x.masked_fill_(m, 0.0)
311
+
312
+ x = x.transpose(1, 2) # [B, T, chn]
313
+
314
+ input_lengths = input_lengths.cpu().numpy()
315
+ x = nn.utils.rnn.pack_padded_sequence(
316
+ x, input_lengths, batch_first=True, enforce_sorted=False)
317
+
318
+ self.lstm.flatten_parameters()
319
+ x, _ = self.lstm(x)
320
+ x, _ = nn.utils.rnn.pad_packed_sequence(
321
+ x, batch_first=True)
322
+
323
+ x = x.transpose(-1, -2)
324
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
325
+
326
+ x_pad[:, :, :x.shape[-1]] = x
327
+ x = x_pad.to(x.device)
328
+
329
+ x.masked_fill_(m, 0.0)
330
+
331
+ return x
332
+
333
+ def inference(self, x):
334
+ x = self.embedding(x)
335
+ x = x.transpose(1, 2)
336
+ x = self.cnn(x)
337
+ x = x.transpose(1, 2)
338
+ self.lstm.flatten_parameters()
339
+ x, _ = self.lstm(x)
340
+ return x
341
+
342
+ def length_to_mask(self, lengths):
343
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
344
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
345
+ return mask
346
+
347
+
348
+
349
+ class AdaIN1d(nn.Module):
350
+ def __init__(self, style_dim, num_features):
351
+ super().__init__()
352
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
353
+ self.fc = nn.Linear(style_dim, num_features*2)
354
+
355
+ def forward(self, x, s):
356
+ h = self.fc(s)
357
+ h = h.view(h.size(0), h.size(1), 1)
358
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
359
+ return (1 + gamma) * self.norm(x) + beta
360
+
361
+ class UpSample1d(nn.Module):
362
+ def __init__(self, layer_type):
363
+ super().__init__()
364
+ self.layer_type = layer_type
365
+
366
+ def forward(self, x):
367
+ if self.layer_type == 'none':
368
+ return x
369
+ else:
370
+ return F.interpolate(x, scale_factor=2, mode='nearest')
371
+
372
+ class AdainResBlk1d(nn.Module):
373
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
374
+ upsample='none', dropout_p=0.0):
375
+ super().__init__()
376
+ self.actv = actv
377
+ self.upsample_type = upsample
378
+ self.upsample = UpSample1d(upsample)
379
+ self.learned_sc = dim_in != dim_out
380
+ self._build_weights(dim_in, dim_out, style_dim)
381
+ self.dropout = nn.Dropout(dropout_p)
382
+
383
+ if upsample == 'none':
384
+ self.pool = nn.Identity()
385
+ else:
386
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
387
+
388
+
389
+ def _build_weights(self, dim_in, dim_out, style_dim):
390
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
391
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
392
+ self.norm1 = AdaIN1d(style_dim, dim_in)
393
+ self.norm2 = AdaIN1d(style_dim, dim_out)
394
+ if self.learned_sc:
395
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
396
+
397
+ def _shortcut(self, x):
398
+ x = self.upsample(x)
399
+ if self.learned_sc:
400
+ x = self.conv1x1(x)
401
+ return x
402
+
403
+ def _residual(self, x, s):
404
+ x = self.norm1(x, s)
405
+ x = self.actv(x)
406
+ x = self.pool(x)
407
+ x = self.conv1(self.dropout(x))
408
+ x = self.norm2(x, s)
409
+ x = self.actv(x)
410
+ x = self.conv2(self.dropout(x))
411
+ return x
412
+
413
+ def forward(self, x, s):
414
+ out = self._residual(x, s)
415
+ out = (out + self._shortcut(x)) / math.sqrt(2)
416
+ return out
417
+
418
+ class AdaLayerNorm(nn.Module):
419
+ def __init__(self, style_dim, channels, eps=1e-5):
420
+ super().__init__()
421
+ self.channels = channels
422
+ self.eps = eps
423
+
424
+ self.fc = nn.Linear(style_dim, channels*2)
425
+
426
+ def forward(self, x, s):
427
+ x = x.transpose(-1, -2)
428
+ x = x.transpose(1, -1)
429
+
430
+ h = self.fc(s)
431
+ h = h.view(h.size(0), h.size(1), 1)
432
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
433
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
434
+
435
+
436
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
437
+ x = (1 + gamma) * x + beta
438
+ return x.transpose(1, -1).transpose(-1, -2)
439
+
440
+ class ProsodyPredictor(nn.Module):
441
+
442
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
443
+ super().__init__()
444
+
445
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
446
+ d_model=d_hid,
447
+ nlayers=nlayers,
448
+ dropout=dropout)
449
+
450
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
451
+ self.duration_proj = LinearNorm(d_hid, max_dur)
452
+
453
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
454
+ self.F0 = nn.ModuleList()
455
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
456
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
457
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
458
+
459
+ self.N = nn.ModuleList()
460
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
461
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
462
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
463
+
464
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
465
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
466
+
467
+
468
+ def forward(self, texts, style, text_lengths, alignment, m):
469
+ d = self.text_encoder(texts, style, text_lengths, m)
470
+
471
+ batch_size = d.shape[0]
472
+ text_size = d.shape[1]
473
+
474
+ # predict duration
475
+ input_lengths = text_lengths.cpu().numpy()
476
+ x = nn.utils.rnn.pack_padded_sequence(
477
+ d, input_lengths, batch_first=True, enforce_sorted=False)
478
+
479
+ m = m.to(text_lengths.device).unsqueeze(1)
480
+
481
+ self.lstm.flatten_parameters()
482
+ x, _ = self.lstm(x)
483
+ x, _ = nn.utils.rnn.pad_packed_sequence(
484
+ x, batch_first=True)
485
+
486
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
487
+
488
+ x_pad[:, :x.shape[1], :] = x
489
+ x = x_pad.to(x.device)
490
+
491
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
492
+
493
+ en = (d.transpose(-1, -2) @ alignment)
494
+
495
+ return duration.squeeze(-1), en
496
+
497
+ def F0Ntrain(self, x, s):
498
+ x, _ = self.shared(x.transpose(-1, -2))
499
+
500
+ F0 = x.transpose(-1, -2)
501
+ for block in self.F0:
502
+ F0 = block(F0, s)
503
+ F0 = self.F0_proj(F0)
504
+
505
+ N = x.transpose(-1, -2)
506
+ for block in self.N:
507
+ N = block(N, s)
508
+ N = self.N_proj(N)
509
+
510
+ return F0.squeeze(1), N.squeeze(1)
511
+
512
+ def length_to_mask(self, lengths):
513
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
514
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
515
+ return mask
516
+
517
+ class DurationEncoder(nn.Module):
518
+
519
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
520
+ super().__init__()
521
+ self.lstms = nn.ModuleList()
522
+ for _ in range(nlayers):
523
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
524
+ d_model // 2,
525
+ num_layers=1,
526
+ batch_first=True,
527
+ bidirectional=True,
528
+ dropout=dropout))
529
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
530
+
531
+
532
+ self.dropout = dropout
533
+ self.d_model = d_model
534
+ self.sty_dim = sty_dim
535
+
536
+ def forward(self, x, style, text_lengths, m):
537
+ masks = m.to(text_lengths.device)
538
+
539
+ x = x.permute(2, 0, 1)
540
+ s = style.expand(x.shape[0], x.shape[1], -1)
541
+ x = torch.cat([x, s], axis=-1)
542
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
543
+
544
+ x = x.transpose(0, 1)
545
+ input_lengths = text_lengths.cpu().numpy()
546
+ x = x.transpose(-1, -2)
547
+
548
+ for block in self.lstms:
549
+ if isinstance(block, AdaLayerNorm):
550
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
551
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
552
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
553
+ else:
554
+ x = x.transpose(-1, -2)
555
+ x = nn.utils.rnn.pack_padded_sequence(
556
+ x, input_lengths, batch_first=True, enforce_sorted=False)
557
+ block.flatten_parameters()
558
+ x, _ = block(x)
559
+ x, _ = nn.utils.rnn.pad_packed_sequence(
560
+ x, batch_first=True)
561
+ x = F.dropout(x, p=self.dropout, training=self.training)
562
+ x = x.transpose(-1, -2)
563
+
564
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
565
+
566
+ x_pad[:, :, :x.shape[-1]] = x
567
+ x = x_pad.to(x.device)
568
+
569
+ return x.transpose(-1, -2)
570
+
571
+ def inference(self, x, style):
572
+ x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
573
+ style = style.expand(x.shape[0], x.shape[1], -1)
574
+ x = torch.cat([x, style], axis=-1)
575
+ src = self.pos_encoder(x)
576
+ output = self.transformer_encoder(src).transpose(0, 1)
577
+ return output
578
+
579
+ def length_to_mask(self, lengths):
580
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
581
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
582
+ return mask
583
+
584
+ def load_F0_models(path):
585
+ # load F0 model
586
+
587
+ F0_model = JDCNet(num_class=1, seq_len=192)
588
+ params = torch.load(path, map_location='cpu')['net']
589
+ F0_model.load_state_dict(params)
590
+ _ = F0_model.train()
591
+
592
+ return F0_model
593
+
594
+ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
595
+ # load ASR model
596
+ def _load_config(path):
597
+ with open(path) as f:
598
+ config = yaml.safe_load(f)
599
+ model_config = config['model_params']
600
+ return model_config
601
+
602
+ def _load_model(model_config, model_path):
603
+ model = ASRCNN(**model_config)
604
+ params = torch.load(model_path, map_location='cpu')['model']
605
+ model.load_state_dict(params)
606
+ return model
607
+
608
+ asr_model_config = _load_config(ASR_MODEL_CONFIG)
609
+ asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
610
+ _ = asr_model.train()
611
+
612
+ return asr_model
613
+
614
+ def build_model(args, text_aligner, pitch_extractor, bert):
615
+ assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
616
+
617
+ if args.decoder.type == "istftnet":
618
+ from Modules.istftnet import Decoder
619
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
620
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
621
+ upsample_rates = args.decoder.upsample_rates,
622
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
623
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
624
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
625
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
626
+ else:
627
+ from Modules.hifigan import Decoder
628
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
629
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
630
+ upsample_rates = args.decoder.upsample_rates,
631
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
632
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
633
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
634
+
635
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
636
+
637
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
638
+
639
+ style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
640
+ predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
641
+
642
+ # define diffusion model
643
+ if args.multispeaker:
644
+ transformer = StyleTransformer1d(channels=args.style_dim*2,
645
+ context_embedding_features=bert.config.hidden_size,
646
+ context_features=args.style_dim*2,
647
+ **args.diffusion.transformer)
648
+ else:
649
+ transformer = Transformer1d(channels=args.style_dim*2,
650
+ context_embedding_features=bert.config.hidden_size,
651
+ **args.diffusion.transformer)
652
+
653
+ diffusion = AudioDiffusionConditional(
654
+ in_channels=1,
655
+ embedding_max_length=bert.config.max_position_embeddings,
656
+ embedding_features=bert.config.hidden_size,
657
+ embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
658
+ channels=args.style_dim*2,
659
+ context_features=args.style_dim*2,
660
+ )
661
+
662
+ diffusion.diffusion = KDiffusion(
663
+ net=diffusion.unet,
664
+ sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
665
+ sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
666
+ dynamic_threshold=0.0
667
+ )
668
+ diffusion.diffusion.net = transformer
669
+ diffusion.unet = transformer
670
+
671
+
672
+ nets = Munch(
673
+ bert=bert,
674
+ bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
675
+
676
+ predictor=predictor,
677
+ decoder=decoder,
678
+ text_encoder=text_encoder,
679
+
680
+ predictor_encoder=predictor_encoder,
681
+ style_encoder=style_encoder,
682
+ diffusion=diffusion,
683
+
684
+ text_aligner = text_aligner,
685
+ pitch_extractor=pitch_extractor,
686
+
687
+ mpd = MultiPeriodDiscriminator(),
688
+ msd = MultiResSpecDiscriminator(),
689
+
690
+ # slm discriminator head
691
+ wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
692
+ )
693
+
694
+ return nets
695
+
696
+ def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
697
+ state = torch.load(path, map_location='cpu')
698
+ params = state['net']
699
+ for key in model:
700
+ if key in params and key not in ignore_modules:
701
+ print('%s loaded' % key)
702
+ model[key].load_state_dict(params[key], strict=False)
703
+ _ = [model[key].eval() for key in model]
704
+
705
+ if not load_only_params:
706
+ epoch = state["epoch"]
707
+ iters = state["iters"]
708
+ optimizer.load_state_dict(state["optimizer"])
709
+ else:
710
+ epoch = 0
711
+ iters = 0
712
+
713
+ return model, optimizer, epoch, iters
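A condensed sketch of how build_model and the loaders above are wired together by the inference scripts in this commit (the config path is an assumption; recursive_munch comes from the repo's utils module):

import yaml
import torch
from models import build_model, load_ASR_models, load_F0_models
from utils import recursive_munch
from Utils.PLBERT.util import load_plbert

config = yaml.safe_load(open("Configs/config.yml"))     # assumed config location

text_aligner    = load_ASR_models(config["ASR_path"], config["ASR_config"])
pitch_extractor = load_F0_models(config["F0_path"])
plbert          = load_plbert(config["PLBERT_dir"])

model = build_model(recursive_munch(config["model_params"]),
                    text_aligner, pitch_extractor, plbert)   # Munch of sub-networks keyed by name

device = "cuda" if torch.cuda.is_available() else "cpu"
for key in model:                                            # iterate sub-network names
    model[key].eval()
    model[key].to(device)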
msinference.py ADDED
@@ -0,0 +1,356 @@
1
+ from cached_path import cached_path
2
+ import nltk
3
+ nltk.download('punkt')
4
+ from scipy.io.wavfile import write
5
+ import torch
6
+ torch.manual_seed(0)
7
+ torch.backends.cudnn.benchmark = False
8
+ torch.backends.cudnn.deterministic = True
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+ import numpy as np
14
+ np.random.seed(0)
15
+
16
+ # load packages
17
+ import time
18
+ import random
19
+ import yaml
20
+ from munch import Munch
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ import torch.nn.functional as F
25
+ import torchaudio
26
+ import librosa
27
+ from nltk.tokenize import word_tokenize
28
+
29
+ from models import *
30
+ from utils import *
31
+ from text_utils import TextCleaner
32
+ textclenaer = TextCleaner()
33
+
34
+
35
+ to_mel = torchaudio.transforms.MelSpectrogram(
36
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
37
+ mean, std = -4, 4
38
+
39
+ def length_to_mask(lengths):
40
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
41
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
42
+ return mask
43
+
44
+ def preprocess(wave):
45
+ wave_tensor = torch.from_numpy(wave).float()
46
+ mel_tensor = to_mel(wave_tensor)
47
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
48
+ return mel_tensor
49
+
50
+ def compute_style(path):
51
+ wave, sr = librosa.load(path, sr=24000)
52
+ audio, index = librosa.effects.trim(wave, top_db=30)
53
+ if sr != 24000:
54
+ audio = librosa.resample(audio, sr, 24000)
55
+ mel_tensor = preprocess(audio).to(device)
56
+
57
+ with torch.no_grad():
58
+ ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
59
+ ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
60
+
61
+ return torch.cat([ref_s, ref_p], dim=1)
62
+
63
+ device = 'cpu'
64
+ if torch.cuda.is_available():
65
+ device = 'cuda'
66
+ elif torch.backends.mps.is_available():
67
+ # print("MPS would be available but cannot be used rn")
68
+ pass
69
+ # device = 'mps'
70
+
71
+ import phonemizer
72
+ global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
73
+ # phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
74
+
75
+
76
+ # config = yaml.safe_load(open("Models/LibriTTS/config.yml"))
77
+ config = yaml.safe_load(open(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/config.yml"))))
78
+
79
+ # load pretrained ASR model
80
+ ASR_config = config.get('ASR_config', False)
81
+ ASR_path = config.get('ASR_path', False)
82
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
83
+
84
+ # load pretrained F0 model
85
+ F0_path = config.get('F0_path', False)
86
+ pitch_extractor = load_F0_models(F0_path)
87
+
88
+ # load BERT model
89
+ from Utils.PLBERT.util import load_plbert
90
+ BERT_path = config.get('PLBERT_dir', False)
91
+ plbert = load_plbert(BERT_path)
92
+
93
+ model_params = recursive_munch(config['model_params'])
94
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
95
+ _ = [model[key].eval() for key in model]
96
+ _ = [model[key].to(device) for key in model]
97
+
98
+ # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
99
+ params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
100
+ params = params_whole['net']
101
+
102
+ for key in model:
103
+ if key in params:
104
+ print('%s loaded' % key)
105
+ try:
106
+ model[key].load_state_dict(params[key])
107
+ except:
108
+ from collections import OrderedDict
109
+ state_dict = params[key]
110
+ new_state_dict = OrderedDict()
111
+ for k, v in state_dict.items():
112
+ name = k[7:] # remove `module.`
113
+ new_state_dict[name] = v
114
+ # load params
115
+ model[key].load_state_dict(new_state_dict, strict=False)
116
+ # except:
117
+ # _load(params[key], model[key])
118
+ _ = [model[key].eval() for key in model]
119
+
120
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
121
+
122
+ sampler = DiffusionSampler(
123
+ model.diffusion.diffusion,
124
+ sampler=ADPM2Sampler(),
125
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
126
+ clamp=False
127
+ )
128
+
129
+ def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
130
+ text = text.strip()
131
+ ps = global_phonemizer.phonemize([text])
132
+ ps = word_tokenize(ps[0])
133
+ ps = ' '.join(ps)
134
+ tokens = textcleaner(ps)
135
+ tokens.insert(0, 0)
136
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
137
+
138
+ with torch.no_grad():
139
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
140
+ text_mask = length_to_mask(input_lengths).to(device)
141
+
142
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
143
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
144
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
145
+
146
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
147
+ embedding=bert_dur,
148
+ embedding_scale=embedding_scale,
149
+ features=ref_s, # reference from the same speaker as the embedding
150
+ num_steps=diffusion_steps).squeeze(1)
151
+
152
+
153
+ s = s_pred[:, 128:]
154
+ ref = s_pred[:, :128]
155
+
156
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
157
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
158
+
159
+ d = model.predictor.text_encoder(d_en,
160
+ s, input_lengths, text_mask)
161
+
162
+ x, _ = model.predictor.lstm(d)
163
+ duration = model.predictor.duration_proj(x)
164
+
165
+ duration = torch.sigmoid(duration).sum(axis=-1)
166
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
167
+
168
+
169
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
170
+ c_frame = 0
171
+ for i in range(pred_aln_trg.size(0)):
172
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
173
+ c_frame += int(pred_dur[i].data)
174
+
175
+ # encode prosody
176
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
177
+ if model_params.decoder.type == "hifigan":
178
+ asr_new = torch.zeros_like(en)
179
+ asr_new[:, :, 0] = en[:, :, 0]
180
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
181
+ en = asr_new
182
+
183
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
184
+
185
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
186
+ if model_params.decoder.type == "hifigan":
187
+ asr_new = torch.zeros_like(asr)
188
+ asr_new[:, :, 0] = asr[:, :, 0]
189
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
190
+ asr = asr_new
191
+
192
+ out = model.decoder(asr,
193
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
194
+
195
+
196
+ return out.squeeze().cpu().numpy()[..., :-50] # trim a weird pulse at the end of the model output; needs to be fixed later
197
+
198
+ def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
199
+ text = text.strip()
200
+ ps = global_phonemizer.phonemize([text])
201
+ ps = word_tokenize(ps[0])
202
+ ps = ' '.join(ps)
203
+ ps = ps.replace('``', '"')
204
+ ps = ps.replace("''", '"')
205
+
206
+ tokens = textcleaner(ps)
207
+ tokens.insert(0, 0)
208
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
209
+
210
+ with torch.no_grad():
211
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
212
+ text_mask = length_to_mask(input_lengths).to(device)
213
+
214
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
215
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
216
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
217
+
218
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
219
+ embedding=bert_dur,
220
+ embedding_scale=embedding_scale,
221
+ features=ref_s, # reference from the same speaker as the embedding
222
+ num_steps=diffusion_steps).squeeze(1)
223
+
224
+ if s_prev is not None:
225
+ # convex combination of previous and current style
226
+ s_pred = t * s_prev + (1 - t) * s_pred
227
+
228
+ s = s_pred[:, 128:]
229
+ ref = s_pred[:, :128]
230
+
231
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
232
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
233
+
234
+ s_pred = torch.cat([ref, s], dim=-1)
235
+
236
+ d = model.predictor.text_encoder(d_en,
237
+ s, input_lengths, text_mask)
238
+
239
+ x, _ = model.predictor.lstm(d)
240
+ duration = model.predictor.duration_proj(x)
241
+
242
+ duration = torch.sigmoid(duration).sum(axis=-1)
243
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
244
+
245
+
246
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
247
+ c_frame = 0
248
+ for i in range(pred_aln_trg.size(0)):
249
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
250
+ c_frame += int(pred_dur[i].data)
251
+
252
+ # encode prosody
253
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
254
+ if model_params.decoder.type == "hifigan":
255
+ asr_new = torch.zeros_like(en)
256
+ asr_new[:, :, 0] = en[:, :, 0]
257
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
258
+ en = asr_new
259
+
260
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
261
+
262
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
263
+ if model_params.decoder.type == "hifigan":
264
+ asr_new = torch.zeros_like(asr)
265
+ asr_new[:, :, 0] = asr[:, :, 0]
266
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
267
+ asr = asr_new
268
+
269
+ out = model.decoder(asr,
270
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
271
+
272
+
273
+ return out.squeeze().cpu().numpy()[..., :-100], s_pred # trim a weird pulse at the end of the model output; needs to be fixed later
274
+
275
+ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
276
+ text = text.strip()
277
+ ps = global_phonemizer.phonemize([text])
278
+ ps = word_tokenize(ps[0])
279
+ ps = ' '.join(ps)
280
+
281
+ tokens = textcleaner(ps)
282
+ tokens.insert(0, 0)
283
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
284
+
285
+ ref_text = ref_text.strip()
286
+ ps = global_phonemizer.phonemize([ref_text])
287
+ ps = word_tokenize(ps[0])
288
+ ps = ' '.join(ps)
289
+
290
+ ref_tokens = textcleaner(ps)
291
+ ref_tokens.insert(0, 0)
292
+ ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)
293
+
294
+
295
+ with torch.no_grad():
296
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
297
+ text_mask = length_to_mask(input_lengths).to(device)
298
+
299
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
300
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
301
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
302
+
303
+ ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
304
+ ref_text_mask = length_to_mask(ref_input_lengths).to(device)
305
+ ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
306
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
307
+ embedding=bert_dur,
308
+ embedding_scale=embedding_scale,
309
+ features=ref_s, # reference from the same speaker as the embedding
310
+ num_steps=diffusion_steps).squeeze(1)
311
+
312
+
313
+ s = s_pred[:, 128:]
314
+ ref = s_pred[:, :128]
315
+
316
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
317
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
318
+
319
+ d = model.predictor.text_encoder(d_en,
320
+ s, input_lengths, text_mask)
321
+
322
+ x, _ = model.predictor.lstm(d)
323
+ duration = model.predictor.duration_proj(x)
324
+
325
+ duration = torch.sigmoid(duration).sum(axis=-1)
326
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
327
+
328
+
329
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
330
+ c_frame = 0
331
+ for i in range(pred_aln_trg.size(0)):
332
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
333
+ c_frame += int(pred_dur[i].data)
334
+
335
+ # encode prosody
336
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
337
+ if model_params.decoder.type == "hifigan":
338
+ asr_new = torch.zeros_like(en)
339
+ asr_new[:, :, 0] = en[:, :, 0]
340
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
341
+ en = asr_new
342
+
343
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
344
+
345
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
346
+ if model_params.decoder.type == "hifigan":
347
+ asr_new = torch.zeros_like(asr)
348
+ asr_new[:, :, 0] = asr[:, :, 0]
349
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
350
+ asr = asr_new
351
+
352
+ out = model.decoder(asr,
353
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
354
+
355
+
356
+ return out.squeeze().cpu().numpy()[..., :-50] # trim a weird pulse at the end of the model output; needs to be fixed later
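A minimal usage sketch for the inference helpers above; the reference WAV path and output file name are placeholders, everything else is defined in this script:

    ref_s = compute_style("reference_speaker.wav")   # placeholder: any ~24 kHz reference clip
    wav = inference("StyleTTS 2 synthesizes speech in the style of the reference clip.", ref_s,
                    alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)
    write("output.wav", 24000, wav)                  # scipy.io.wavfile.write, imported at the top

For longer passages, LFinference() is meant to be called sentence by sentence, feeding the returned s_pred back in as s_prev so that consecutive sentences keep a consistent style.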
optimizers.py ADDED
@@ -0,0 +1,73 @@
1
+ #coding:utf-8
2
+ import os, sys
3
+ import os.path as osp
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ from torch.optim import Optimizer
8
+ from functools import reduce
9
+ from torch.optim import AdamW
10
+
11
+ class MultiOptimizer:
12
+ def __init__(self, optimizers={}, schedulers={}):
13
+ self.optimizers = optimizers
14
+ self.schedulers = schedulers
15
+ self.keys = list(optimizers.keys())
16
+ self.param_groups = reduce(lambda x,y: x+y, [v.param_groups for v in self.optimizers.values()])
17
+
18
+ def state_dict(self):
19
+ state_dicts = [(key, self.optimizers[key].state_dict())\
20
+ for key in self.keys]
21
+ return state_dicts
22
+
23
+ def load_state_dict(self, state_dict):
24
+ for key, val in state_dict:
25
+ try:
26
+ self.optimizers[key].load_state_dict(val)
27
+ except:
28
+ print("Unloaded %s" % key)
29
+
30
+ def step(self, key=None, scaler=None):
31
+ keys = [key] if key is not None else self.keys
32
+ _ = [self._step(key, scaler) for key in keys]
33
+
34
+ def _step(self, key, scaler=None):
35
+ if scaler is not None:
36
+ scaler.step(self.optimizers[key])
37
+ scaler.update()
38
+ else:
39
+ self.optimizers[key].step()
40
+
41
+ def zero_grad(self, key=None):
42
+ if key is not None:
43
+ self.optimizers[key].zero_grad()
44
+ else:
45
+ _ = [self.optimizers[key].zero_grad() for key in self.keys]
46
+
47
+ def scheduler(self, *args, key=None):
48
+ if key is not None:
49
+ self.schedulers[key].step(*args)
50
+ else:
51
+ _ = [self.schedulers[key].step(*args) for key in self.keys]
52
+
53
+ def define_scheduler(optimizer, params):
54
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
55
+ optimizer,
56
+ max_lr=params.get('max_lr', 2e-4),
57
+ epochs=params.get('epochs', 200),
58
+ steps_per_epoch=params.get('steps_per_epoch', 1000),
59
+ pct_start=params.get('pct_start', 0.0),
60
+ div_factor=1,
61
+ final_div_factor=1)
62
+
63
+ return scheduler
64
+
65
+ def build_optimizer(parameters_dict, scheduler_params_dict, lr):
66
+ optim = dict([(key, AdamW(params, lr=lr, weight_decay=1e-4, betas=(0.0, 0.99), eps=1e-9))
67
+ for key, params in parameters_dict.items()])
68
+
69
+ schedulers = dict([(key, define_scheduler(opt, scheduler_params_dict[key])) \
70
+ for key, opt in optim.items()])
71
+
72
+ multi_optim = MultiOptimizer(optim, schedulers)
73
+ return multi_optim
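A minimal, self-contained sketch of how build_optimizer() and MultiOptimizer above are driven; the two toy modules and the learning-rate numbers are illustrative only:

    import torch
    from torch import nn

    nets = {"decoder": nn.Linear(16, 16), "predictor": nn.Linear(16, 16)}  # placeholder modules
    sched = {k: {"max_lr": 1e-4, "epochs": 10, "steps_per_epoch": 100, "pct_start": 0.0} for k in nets}
    opt = build_optimizer({k: m.parameters() for k, m in nets.items()},
                          scheduler_params_dict=sched, lr=1e-4)

    x = torch.randn(4, 16)
    loss = sum(nets[k](x).pow(2).mean() for k in nets)
    opt.zero_grad()            # zeroes every sub-optimizer
    loss.backward()
    opt.step()                 # or opt.step('decoder') to update a single module
    opt.scheduler()            # advances every OneCycleLR schedule by one step

The training scripts below use exactly this pattern, stepping individual keys ('msd', 'mpd', 'bert', ...) depending on which loss was just back-propagated.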
requirements.txt ADDED
@@ -0,0 +1,24 @@
1
+ SoundFile
2
+ torchaudio
3
+ munch
4
+ torch
5
+ pydub
6
+ pyyaml
7
+ librosa
8
+ nltk
9
+ matplotlib
10
+ accelerate
11
+ transformers
12
+ einops
13
+ einops-exts
14
+ tqdm
15
+ typing
16
+ typing-extensions
17
+ git+https://github.com/resemble-ai/monotonic_align.git # or resemble-monotonic-align
18
+ gradio
19
+ phonemizer
20
+ cached-path
21
+ tortoise-tts # for the Gradio demo, splitting text
22
+ flask # for api
23
+ markdown # for api
24
+ flask-cors
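These dependencies are typically installed with pip install -r requirements.txt. Note that phonemizer additionally expects a system espeak/espeak-ng backend (the scripts above construct an EspeakBackend), and monotonic_align is pulled directly from GitHub as listed.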
text_utils.py ADDED
@@ -0,0 +1,26 @@
1
+ # IPA Phonemizer: https://github.com/bootphon/phonemizer
2
+
3
+ _pad = "$"
4
+ _punctuation = ';:,.!?¡¿—…"«»“” '
5
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
6
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
7
+
8
+ # Export all symbols:
9
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
10
+
11
+ dicts = {}
12
+ for i in range(len((symbols))):
13
+ dicts[symbols[i]] = i
14
+
15
+ class TextCleaner:
16
+ def __init__(self, dummy=None):
17
+ self.word_index_dictionary = dicts
18
+ print(len(dicts))
19
+ def __call__(self, text):
20
+ indexes = []
21
+ for char in text:
22
+ try:
23
+ indexes.append(self.word_index_dictionary[char])
24
+ except KeyError:
25
+ print(text)
26
+ return indexes
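A quick sketch of the TextCleaner above: it maps a phonemized string to integer ids over the symbols table, skipping any character it does not recognize (it only prints the offending text). The IPA string here is illustrative; in the scripts above the input comes from the espeak phonemizer.

    cleaner = TextCleaner()            # prints the vocabulary size on construction
    tokens = cleaner("ðɪs ɪz ɐ tˈɛst .")
    print(tokens)                      # a list of indices into `symbols`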
train_finetune.py ADDED
@@ -0,0 +1,707 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+ warnings.simplefilter('ignore')
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from meldataset import build_dataloader
19
+
20
+ from Utils.ASR.models import ASRCNN
21
+ from Utils.JDC.model import JDCNet
22
+ from Utils.PLBERT.util import load_plbert
23
+
24
+ from models import *
25
+ from losses import *
26
+ from utils import *
27
+
28
+ from Modules.slmadv import SLMAdversarialLoss
29
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
30
+
31
+ from optimizers import build_optimizer
32
+
33
+ # simple fix for dataparallel that allows access to class attributes
34
+ class MyDataParallel(torch.nn.DataParallel):
35
+ def __getattr__(self, name):
36
+ try:
37
+ return super().__getattr__(name)
38
+ except AttributeError:
39
+ return getattr(self.module, name)
40
+
41
+ import logging
42
+ from logging import StreamHandler
43
+ logger = logging.getLogger(__name__)
44
+ logger.setLevel(logging.DEBUG)
45
+ handler = StreamHandler()
46
+ handler.setLevel(logging.DEBUG)
47
+ logger.addHandler(handler)
48
+
49
+
50
+ @click.command()
51
+ @click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
52
+ def main(config_path):
53
+ config = yaml.safe_load(open(config_path))
54
+
55
+ log_dir = config['log_dir']
56
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
57
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
58
+ writer = SummaryWriter(log_dir + "/tensorboard")
59
+
60
+ # write logs
61
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
62
+ file_handler.setLevel(logging.DEBUG)
63
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
64
+ logger.addHandler(file_handler)
65
+
66
+
67
+ batch_size = config.get('batch_size', 10)
68
+
69
+ epochs = config.get('epochs', 200)
70
+ save_freq = config.get('save_freq', 2)
71
+ log_interval = config.get('log_interval', 10)
72
+ saving_epoch = config.get('save_freq', 2)
73
+
74
+ data_params = config.get('data_params', None)
75
+ sr = config['preprocess_params'].get('sr', 24000)
76
+ train_path = data_params['train_data']
77
+ val_path = data_params['val_data']
78
+ root_path = data_params['root_path']
79
+ min_length = data_params['min_length']
80
+ OOD_data = data_params['OOD_data']
81
+
82
+ max_len = config.get('max_len', 200)
83
+
84
+ loss_params = Munch(config['loss_params'])
85
+ diff_epoch = loss_params.diff_epoch
86
+ joint_epoch = loss_params.joint_epoch
87
+
88
+ optimizer_params = Munch(config['optimizer_params'])
89
+
90
+ train_list, val_list = get_data_path_list(train_path, val_path)
91
+ device = 'cuda'
92
+
93
+ train_dataloader = build_dataloader(train_list,
94
+ root_path,
95
+ OOD_data=OOD_data,
96
+ min_length=min_length,
97
+ batch_size=batch_size,
98
+ num_workers=2,
99
+ dataset_config={},
100
+ device=device)
101
+
102
+ val_dataloader = build_dataloader(val_list,
103
+ root_path,
104
+ OOD_data=OOD_data,
105
+ min_length=min_length,
106
+ batch_size=batch_size,
107
+ validation=True,
108
+ num_workers=0,
109
+ device=device,
110
+ dataset_config={})
111
+
112
+ # load pretrained ASR model
113
+ ASR_config = config.get('ASR_config', False)
114
+ ASR_path = config.get('ASR_path', False)
115
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
116
+
117
+ # load pretrained F0 model
118
+ F0_path = config.get('F0_path', False)
119
+ pitch_extractor = load_F0_models(F0_path)
120
+
121
+ # load PL-BERT model
122
+ BERT_path = config.get('PLBERT_dir', False)
123
+ plbert = load_plbert(BERT_path)
124
+
125
+ # build model
126
+ model_params = recursive_munch(config['model_params'])
127
+ multispeaker = model_params.multispeaker
128
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
129
+ _ = [model[key].to(device) for key in model]
130
+
131
+ # DP
132
+ for key in model:
133
+ if key != "mpd" and key != "msd" and key != "wd":
134
+ model[key] = MyDataParallel(model[key])
135
+
136
+ start_epoch = 0
137
+ iters = 0
138
+
139
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
140
+
141
+ if not load_pretrained:
142
+ if config.get('first_stage_path', '') != '':
143
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
144
+ print('Loading the first stage model at %s ...' % first_stage_path)
145
+ model, _, start_epoch, iters = load_checkpoint(model,
146
+ None,
147
+ first_stage_path,
148
+ load_only_params=True,
149
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
150
+
151
+ # these epochs should be counted from the start epoch
152
+ diff_epoch += start_epoch
153
+ joint_epoch += start_epoch
154
+ epochs += start_epoch
155
+
156
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
157
+ else:
158
+ raise ValueError('You need to specify the path to the first stage model.')
159
+
160
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
161
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
162
+ wl = WavLMLoss(model_params.slm.model,
163
+ model.wd,
164
+ sr,
165
+ model_params.slm.sr).to(device)
166
+
167
+ gl = MyDataParallel(gl)
168
+ dl = MyDataParallel(dl)
169
+ wl = MyDataParallel(wl)
170
+
171
+ sampler = DiffusionSampler(
172
+ model.diffusion.diffusion,
173
+ sampler=ADPM2Sampler(),
174
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
175
+ clamp=False
176
+ )
177
+
178
+ scheduler_params = {
179
+ "max_lr": optimizer_params.lr,
180
+ "pct_start": float(0),
181
+ "epochs": epochs,
182
+ "steps_per_epoch": len(train_dataloader),
183
+ }
184
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
185
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
186
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
187
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
188
+
189
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
190
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
191
+
192
+ # adjust BERT learning rate
193
+ for g in optimizer.optimizers['bert'].param_groups:
194
+ g['betas'] = (0.9, 0.99)
195
+ g['lr'] = optimizer_params.bert_lr
196
+ g['initial_lr'] = optimizer_params.bert_lr
197
+ g['min_lr'] = 0
198
+ g['weight_decay'] = 0.01
199
+
200
+ # adjust acoustic module learning rate
201
+ for module in ["decoder", "style_encoder"]:
202
+ for g in optimizer.optimizers[module].param_groups:
203
+ g['betas'] = (0.0, 0.99)
204
+ g['lr'] = optimizer_params.ft_lr
205
+ g['initial_lr'] = optimizer_params.ft_lr
206
+ g['min_lr'] = 0
207
+ g['weight_decay'] = 1e-4
208
+
209
+ # load a pretrained checkpoint if one is provided
210
+ if load_pretrained:
211
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
212
+ load_only_params=config.get('load_only_params', True))
213
+
214
+ n_down = model.text_aligner.n_down
215
+
216
+ best_loss = float('inf') # best test loss
217
+ loss_train_record = list([])
218
+ loss_test_record = list([])
219
+ iters = 0
220
+
221
+ criterion = nn.L1Loss() # F0 loss (regression)
222
+ torch.cuda.empty_cache()
223
+
224
+ stft_loss = MultiResolutionSTFTLoss().to(device)
225
+
226
+ print('BERT', optimizer.optimizers['bert'])
227
+ print('decoder', optimizer.optimizers['decoder'])
228
+
229
+ start_ds = False
230
+
231
+ running_std = []
232
+
233
+ slmadv_params = Munch(config['slmadv_params'])
234
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
235
+ slmadv_params.min_len,
236
+ slmadv_params.max_len,
237
+ batch_percentage=slmadv_params.batch_percentage,
238
+ skip_update=slmadv_params.iter,
239
+ sig=slmadv_params.sig
240
+ )
241
+
242
+
243
+ for epoch in range(start_epoch, epochs):
244
+ running_loss = 0
245
+ start_time = time.time()
246
+
247
+ _ = [model[key].eval() for key in model]
248
+
249
+ model.text_aligner.train()
250
+ model.text_encoder.train()
251
+
252
+ model.predictor.train()
253
+ model.bert_encoder.train()
254
+ model.bert.train()
255
+ model.msd.train()
256
+ model.mpd.train()
257
+
258
+ for i, batch in enumerate(train_dataloader):
259
+ waves = batch[0]
260
+ batch = [b.to(device) for b in batch[1:]]
261
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
262
+ with torch.no_grad():
263
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
264
+ mel_mask = length_to_mask(mel_input_length).to(device)
265
+ text_mask = length_to_mask(input_lengths).to(texts.device)
266
+
267
+ # compute reference styles
268
+ if multispeaker and epoch >= diff_epoch:
269
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
270
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
271
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
272
+
273
+ try:
274
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
275
+ s2s_attn = s2s_attn.transpose(-1, -2)
276
+ s2s_attn = s2s_attn[..., 1:]
277
+ s2s_attn = s2s_attn.transpose(-1, -2)
278
+ except:
279
+ continue
280
+
281
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
282
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
283
+
284
+ # encode
285
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
286
+
287
+ # 50% chance of using the monotonic version
288
+ if bool(random.getrandbits(1)):
289
+ asr = (t_en @ s2s_attn)
290
+ else:
291
+ asr = (t_en @ s2s_attn_mono)
292
+
293
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
294
+
295
+ # compute the style of the entire utterance
296
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
297
+ ss = []
298
+ gs = []
299
+ for bib in range(len(mel_input_length)):
300
+ mel_length = int(mel_input_length[bib].item())
301
+ mel = mels[bib, :, :mel_input_length[bib]]
302
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
303
+ ss.append(s)
304
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
305
+ gs.append(s)
306
+
307
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
308
+ gs = torch.stack(gs).squeeze() # global acoustic styles
309
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
310
+
311
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
312
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
313
+
314
+ # denoiser training
315
+ if epoch >= diff_epoch:
316
+ num_steps = np.random.randint(3, 5)
317
+
318
+ if model_params.diffusion.dist.estimate_sigma_data:
319
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
320
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
321
+
322
+ if multispeaker:
323
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
324
+ embedding=bert_dur,
325
+ embedding_scale=1,
326
+ features=ref, # reference from the same speaker as the embedding
327
+ embedding_mask_proba=0.1,
328
+ num_steps=num_steps).squeeze(1)
329
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
330
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
331
+ else:
332
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
333
+ embedding=bert_dur,
334
+ embedding_scale=1,
335
+ embedding_mask_proba=0.1,
336
+ num_steps=num_steps).squeeze(1)
337
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
338
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
339
+ else:
340
+ loss_sty = 0
341
+ loss_diff = 0
342
+
343
+
344
+ s_loss = 0
345
+
346
+
347
+ d, p = model.predictor(d_en, s_dur,
348
+ input_lengths,
349
+ s2s_attn_mono,
350
+ text_mask)
351
+
352
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
353
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
354
+ en = []
355
+ gt = []
356
+ p_en = []
357
+ wav = []
358
+ st = []
359
+
360
+ for bib in range(len(mel_input_length)):
361
+ mel_length = int(mel_input_length[bib].item() / 2)
362
+
363
+ random_start = np.random.randint(0, mel_length - mel_len)
364
+ en.append(asr[bib, :, random_start:random_start+mel_len])
365
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
366
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
367
+
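+ # random_start indexes half-resolution aligned frames: * 2 recovers mel frames, * 300 (the mel hop length) gives waveform samples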
368
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
369
+ wav.append(torch.from_numpy(y).to(device))
370
+
371
+ # style reference (better to be different from the GT)
372
+ random_start = np.random.randint(0, mel_length - mel_len_st)
373
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
374
+
375
+ wav = torch.stack(wav).float().detach()
376
+
377
+ en = torch.stack(en)
378
+ p_en = torch.stack(p_en)
379
+ gt = torch.stack(gt).detach()
380
+ st = torch.stack(st).detach()
381
+
382
+
383
+ if gt.size(-1) < 80:
384
+ continue
385
+
386
+ s = model.style_encoder(gt.unsqueeze(1))
387
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
388
+
389
+ with torch.no_grad():
390
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
391
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
392
+
393
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
394
+
395
+ y_rec_gt = wav.unsqueeze(1)
396
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
397
+
398
+ wav = y_rec_gt
399
+
400
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
401
+
402
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
403
+
404
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
405
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
406
+
407
+ optimizer.zero_grad()
408
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
409
+ d_loss.backward()
410
+ optimizer.step('msd')
411
+ optimizer.step('mpd')
412
+
413
+ # generator loss
414
+ optimizer.zero_grad()
415
+
416
+ loss_mel = stft_loss(y_rec, wav)
417
+ loss_gen_all = gl(wav, y_rec).mean()
418
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
419
+
420
+ loss_ce = 0
421
+ loss_dur = 0
422
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
423
+ _s2s_pred = _s2s_pred[:_text_length, :]
424
+ _text_input = _text_input[:_text_length].long()
425
+ _s2s_trg = torch.zeros_like(_s2s_pred)
426
+ for p in range(_s2s_trg.shape[0]):
427
+ _s2s_trg[p, :_text_input[p]] = 1
428
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
429
+
430
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
431
+ _text_input[1:_text_length-1])
432
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
433
+
434
+ loss_ce /= texts.size(0)
435
+ loss_dur /= texts.size(0)
436
+
437
+ loss_s2s = 0
438
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
439
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
440
+ loss_s2s /= texts.size(0)
441
+
442
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
443
+
444
+ g_loss = loss_params.lambda_mel * loss_mel + \
445
+ loss_params.lambda_F0 * loss_F0_rec + \
446
+ loss_params.lambda_ce * loss_ce + \
447
+ loss_params.lambda_norm * loss_norm_rec + \
448
+ loss_params.lambda_dur * loss_dur + \
449
+ loss_params.lambda_gen * loss_gen_all + \
450
+ loss_params.lambda_slm * loss_lm + \
451
+ loss_params.lambda_sty * loss_sty + \
452
+ loss_params.lambda_diff * loss_diff + \
453
+ loss_params.lambda_mono * loss_mono + \
454
+ loss_params.lambda_s2s * loss_s2s
455
+
456
+ running_loss += loss_mel.item()
457
+ g_loss.backward()
458
+ if torch.isnan(g_loss):
459
+ from IPython.core.debugger import set_trace
460
+ set_trace()
461
+
462
+ optimizer.step('bert_encoder')
463
+ optimizer.step('bert')
464
+ optimizer.step('predictor')
465
+ optimizer.step('predictor_encoder')
466
+ optimizer.step('style_encoder')
467
+ optimizer.step('decoder')
468
+
469
+ optimizer.step('text_encoder')
470
+ optimizer.step('text_aligner')
471
+
472
+ if epoch >= diff_epoch:
473
+ optimizer.step('diffusion')
474
+
475
+ d_loss_slm, loss_gen_lm = 0, 0
476
+ if epoch >= joint_epoch:
477
+ # randomly pick whether to use in-distribution text
478
+ if np.random.rand() < 0.5:
479
+ use_ind = True
480
+ else:
481
+ use_ind = False
482
+
483
+ if use_ind:
484
+ ref_lengths = input_lengths
485
+ ref_texts = texts
486
+
487
+ slm_out = slmadv(i,
488
+ y_rec_gt,
489
+ y_rec_gt_pred,
490
+ waves,
491
+ mel_input_length,
492
+ ref_texts,
493
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
494
+
495
+ if slm_out is not None:
496
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
497
+
498
+ # SLM generator loss
499
+ optimizer.zero_grad()
500
+ loss_gen_lm.backward()
501
+
502
+ # compute the gradient norm
503
+ total_norm = {}
504
+ for key in model.keys():
505
+ total_norm[key] = 0
506
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
507
+ for p in parameters:
508
+ param_norm = p.grad.detach().data.norm(2)
509
+ total_norm[key] += param_norm.item() ** 2
510
+ total_norm[key] = total_norm[key] ** 0.5
511
+
512
+ # gradient scaling
513
+ if total_norm['predictor'] > slmadv_params.thresh:
514
+ for key in model.keys():
515
+ for p in model[key].parameters():
516
+ if p.grad is not None:
517
+ p.grad *= (1 / total_norm['predictor'])
518
+
519
+ for p in model.predictor.duration_proj.parameters():
520
+ if p.grad is not None:
521
+ p.grad *= slmadv_params.scale
522
+
523
+ for p in model.predictor.lstm.parameters():
524
+ if p.grad is not None:
525
+ p.grad *= slmadv_params.scale
526
+
527
+ for p in model.diffusion.parameters():
528
+ if p.grad is not None:
529
+ p.grad *= slmadv_params.scale
530
+
531
+ optimizer.step('bert_encoder')
532
+ optimizer.step('bert')
533
+ optimizer.step('predictor')
534
+ optimizer.step('diffusion')
535
+
536
+ # SLM discriminator loss
537
+ if d_loss_slm != 0:
538
+ optimizer.zero_grad()
539
+ d_loss_slm.backward(retain_graph=True)
540
+ optimizer.step('wd')
541
+
542
+ iters = iters + 1
543
+
544
+ if (i+1)%log_interval == 0:
545
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
546
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono))
547
+
548
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
549
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
550
+ writer.add_scalar('train/d_loss', d_loss, iters)
551
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
552
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
553
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
554
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
555
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
556
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
557
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
558
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
559
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
560
+
561
+ running_loss = 0
562
+
563
+ print('Time elapsed:', time.time()-start_time)
564
+
565
+ loss_test = 0
566
+ loss_align = 0
567
+ loss_f = 0
568
+ _ = [model[key].eval() for key in model]
569
+
570
+ with torch.no_grad():
571
+ iters_test = 0
572
+ for batch_idx, batch in enumerate(val_dataloader):
573
+ optimizer.zero_grad()
574
+
575
+ try:
576
+ waves = batch[0]
577
+ batch = [b.to(device) for b in batch[1:]]
578
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
579
+ with torch.no_grad():
580
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
581
+ text_mask = length_to_mask(input_lengths).to(texts.device)
582
+
583
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
584
+ s2s_attn = s2s_attn.transpose(-1, -2)
585
+ s2s_attn = s2s_attn[..., 1:]
586
+ s2s_attn = s2s_attn.transpose(-1, -2)
587
+
588
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
589
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
590
+
591
+ # encode
592
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
593
+ asr = (t_en @ s2s_attn_mono)
594
+
595
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
596
+
597
+ ss = []
598
+ gs = []
599
+
600
+ for bib in range(len(mel_input_length)):
601
+ mel_length = int(mel_input_length[bib].item())
602
+ mel = mels[bib, :, :mel_input_length[bib]]
603
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
604
+ ss.append(s)
605
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
606
+ gs.append(s)
607
+
608
+ s = torch.stack(ss).squeeze()
609
+ gs = torch.stack(gs).squeeze()
610
+ s_trg = torch.cat([s, gs], dim=-1).detach()
611
+
612
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
613
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
614
+ d, p = model.predictor(d_en, s,
615
+ input_lengths,
616
+ s2s_attn_mono,
617
+ text_mask)
618
+ # get clips
619
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
620
+ en = []
621
+ gt = []
622
+
623
+ p_en = []
624
+ wav = []
625
+
626
+ for bib in range(len(mel_input_length)):
627
+ mel_length = int(mel_input_length[bib].item() / 2)
628
+
629
+ random_start = np.random.randint(0, mel_length - mel_len)
630
+ en.append(asr[bib, :, random_start:random_start+mel_len])
631
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
632
+
633
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
634
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
635
+ wav.append(torch.from_numpy(y).to(device))
636
+
637
+ wav = torch.stack(wav).float().detach()
638
+
639
+ en = torch.stack(en)
640
+ p_en = torch.stack(p_en)
641
+ gt = torch.stack(gt).detach()
642
+ s = model.predictor_encoder(gt.unsqueeze(1))
643
+
644
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
645
+
646
+ loss_dur = 0
647
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
648
+ _s2s_pred = _s2s_pred[:_text_length, :]
649
+ _text_input = _text_input[:_text_length].long()
650
+ _s2s_trg = torch.zeros_like(_s2s_pred)
651
+ for bib in range(_s2s_trg.shape[0]):
652
+ _s2s_trg[bib, :_text_input[bib]] = 1
653
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
654
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
655
+ _text_input[1:_text_length-1])
656
+
657
+ loss_dur /= texts.size(0)
658
+
659
+ s = model.style_encoder(gt.unsqueeze(1))
660
+
661
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
662
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
663
+
664
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
665
+
666
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
667
+
668
+ loss_test += (loss_mel).mean()
669
+ loss_align += (loss_dur).mean()
670
+ loss_f += (loss_F0).mean()
671
+
672
+ iters_test += 1
673
+ except:
674
+ continue
675
+
676
+ print('Epochs:', epoch + 1)
677
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
678
+ print('\n\n\n')
679
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
680
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
681
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
682
+
683
+
684
+ if (epoch + 1) % save_freq == 0 :
685
+ if (loss_test / iters_test) < best_loss:
686
+ best_loss = loss_test / iters_test
687
+ print('Saving..')
688
+ state = {
689
+ 'net': {key: model[key].state_dict() for key in model},
690
+ 'optimizer': optimizer.state_dict(),
691
+ 'iters': iters,
692
+ 'val_loss': loss_test / iters_test,
693
+ 'epoch': epoch,
694
+ }
695
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
696
+ torch.save(state, save_path)
697
+
698
+ # if sigma was estimated, save the estimated sigma
699
+ if model_params.diffusion.dist.estimate_sigma_data:
700
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
701
+
702
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
703
+ yaml.dump(config, outfile, default_flow_style=True)
704
+
705
+
706
+ if __name__=="__main__":
707
+ main()
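The fine-tuning entry point above is a click CLI, so a typical launch (assuming a prepared Configs/config_ft.yml containing the data_params, loss_params, optimizer_params and slmadv_params sections read above) is: python train_finetune.py -p Configs/config_ft.yml. The variant added below runs the same loop through Hugging Face Accelerate and would instead be started with: accelerate launch train_finetune_accelerate.py -p Configs/config_ft.yml.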
train_finetune_accelerate.py ADDED
@@ -0,0 +1,714 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+ warnings.simplefilter('ignore')
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from meldataset import build_dataloader
19
+
20
+ from Utils.ASR.models import ASRCNN
21
+ from Utils.JDC.model import JDCNet
22
+ from Utils.PLBERT.util import load_plbert
23
+
24
+ from models import *
25
+ from losses import *
26
+ from utils import *
27
+
28
+ from Modules.slmadv import SLMAdversarialLoss
29
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
30
+
31
+ from optimizers import build_optimizer
32
+
33
+ from accelerate import Accelerator
34
+
35
+ accelerator = Accelerator()
36
+
37
+ # simple fix for dataparallel that allows access to class attributes
38
+ class MyDataParallel(torch.nn.DataParallel):
39
+ def __getattr__(self, name):
40
+ try:
41
+ return super().__getattr__(name)
42
+ except AttributeError:
43
+ return getattr(self.module, name)
44
+
45
+ import logging
46
+ from logging import StreamHandler
47
+ logger = logging.getLogger(__name__)
48
+ logger.setLevel(logging.DEBUG)
49
+ handler = StreamHandler()
50
+ handler.setLevel(logging.DEBUG)
51
+ logger.addHandler(handler)
52
+
53
+
54
+ @click.command()
55
+ @click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
56
+ def main(config_path):
57
+ config = yaml.safe_load(open(config_path))
58
+
59
+ log_dir = config['log_dir']
60
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
61
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
62
+ writer = SummaryWriter(log_dir + "/tensorboard")
63
+
64
+ # write logs
65
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
66
+ file_handler.setLevel(logging.DEBUG)
67
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
68
+ logger.addHandler(file_handler)
69
+
70
+
71
+ batch_size = config.get('batch_size', 10)
72
+
73
+ epochs = config.get('epochs', 200)
74
+ save_freq = config.get('save_freq', 2)
75
+ log_interval = config.get('log_interval', 10)
76
+ saving_epoch = config.get('save_freq', 2)
77
+
78
+ data_params = config.get('data_params', None)
79
+ sr = config['preprocess_params'].get('sr', 24000)
80
+ train_path = data_params['train_data']
81
+ val_path = data_params['val_data']
82
+ root_path = data_params['root_path']
83
+ min_length = data_params['min_length']
84
+ OOD_data = data_params['OOD_data']
85
+
86
+ max_len = config.get('max_len', 200)
87
+
88
+ loss_params = Munch(config['loss_params'])
89
+ diff_epoch = loss_params.diff_epoch
90
+ joint_epoch = loss_params.joint_epoch
91
+
92
+ optimizer_params = Munch(config['optimizer_params'])
93
+
94
+ train_list, val_list = get_data_path_list(train_path, val_path)
95
+ device = accelerator.device
96
+
97
+ train_dataloader = build_dataloader(train_list,
98
+ root_path,
99
+ OOD_data=OOD_data,
100
+ min_length=min_length,
101
+ batch_size=batch_size,
102
+ num_workers=2,
103
+ dataset_config={},
104
+ device=device)
105
+
106
+ val_dataloader = build_dataloader(val_list,
107
+ root_path,
108
+ OOD_data=OOD_data,
109
+ min_length=min_length,
110
+ batch_size=batch_size,
111
+ validation=True,
112
+ num_workers=0,
113
+ device=device,
114
+ dataset_config={})
115
+
116
+ # load pretrained ASR model
117
+ ASR_config = config.get('ASR_config', False)
118
+ ASR_path = config.get('ASR_path', False)
119
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
120
+
121
+ # load pretrained F0 model
122
+ F0_path = config.get('F0_path', False)
123
+ pitch_extractor = load_F0_models(F0_path)
124
+
125
+ # load PL-BERT model
126
+ BERT_path = config.get('PLBERT_dir', False)
127
+ plbert = load_plbert(BERT_path)
128
+
129
+ # build model
130
+ model_params = recursive_munch(config['model_params'])
131
+ multispeaker = model_params.multispeaker
132
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
133
+ _ = [model[key].to(device) for key in model]
134
+
135
+ # DP
136
+ for key in model:
137
+ if key != "mpd" and key != "msd" and key != "wd":
138
+ model[key] = MyDataParallel(model[key])
139
+
140
+ start_epoch = 0
141
+ iters = 0
142
+
143
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
144
+
145
+ if not load_pretrained:
146
+ if config.get('first_stage_path', '') != '':
147
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
148
+ print('Loading the first stage model at %s ...' % first_stage_path)
149
+ model, _, start_epoch, iters = load_checkpoint(model,
150
+ None,
151
+ first_stage_path,
152
+ load_only_params=True,
153
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
154
+
155
+ # these epochs should be counted from the start epoch
156
+ diff_epoch += start_epoch
157
+ joint_epoch += start_epoch
158
+ epochs += start_epoch
159
+
160
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
161
+ else:
162
+ raise ValueError('You need to specify the path to the first stage model.')
163
+
164
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
165
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
166
+ wl = WavLMLoss(model_params.slm.model,
167
+ model.wd,
168
+ sr,
169
+ model_params.slm.sr).to(device)
170
+
171
+ gl = MyDataParallel(gl)
172
+ dl = MyDataParallel(dl)
173
+ wl = MyDataParallel(wl)
174
+
175
+ sampler = DiffusionSampler(
176
+ model.diffusion.diffusion,
177
+ sampler=ADPM2Sampler(),
178
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
179
+ clamp=False
180
+ )
181
+
182
+ scheduler_params = {
183
+ "max_lr": optimizer_params.lr,
184
+ "pct_start": float(0),
185
+ "epochs": epochs,
186
+ "steps_per_epoch": len(train_dataloader),
187
+ }
188
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
189
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
190
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
191
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
192
+
193
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
194
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
195
+
196
+ # adjust BERT learning rate
197
+ for g in optimizer.optimizers['bert'].param_groups:
198
+ g['betas'] = (0.9, 0.99)
199
+ g['lr'] = optimizer_params.bert_lr
200
+ g['initial_lr'] = optimizer_params.bert_lr
201
+ g['min_lr'] = 0
202
+ g['weight_decay'] = 0.01
203
+
204
+ # adjust acoustic module learning rate
205
+ for module in ["decoder", "style_encoder"]:
206
+ for g in optimizer.optimizers[module].param_groups:
207
+ g['betas'] = (0.0, 0.99)
208
+ g['lr'] = optimizer_params.ft_lr
209
+ g['initial_lr'] = optimizer_params.ft_lr
210
+ g['min_lr'] = 0
211
+ g['weight_decay'] = 1e-4
212
+
213
+ # load a pretrained checkpoint if one is provided
214
+ if load_pretrained:
215
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
216
+ load_only_params=config.get('load_only_params', True))
217
+
218
+ n_down = model.text_aligner.n_down
219
+
220
+ best_loss = float('inf') # best test loss
221
+ loss_train_record = list([])
222
+ loss_test_record = list([])
223
+ iters = 0
224
+
225
+ criterion = nn.L1Loss() # F0 loss (regression)
226
+ torch.cuda.empty_cache()
227
+
228
+ stft_loss = MultiResolutionSTFTLoss().to(device)
229
+
230
+ print('BERT', optimizer.optimizers['bert'])
231
+ print('decoder', optimizer.optimizers['decoder'])
232
+
233
+ start_ds = False
234
+
235
+ running_std = []
236
+
237
+ slmadv_params = Munch(config['slmadv_params'])
238
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
239
+ slmadv_params.min_len,
240
+ slmadv_params.max_len,
241
+ batch_percentage=slmadv_params.batch_percentage,
242
+ skip_update=slmadv_params.iter,
243
+ sig=slmadv_params.sig
244
+ )
245
+
246
+ model, optimizer, train_dataloader = accelerator.prepare(
247
+ model, optimizer, train_dataloader
248
+ )
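+ # hand the model dict, the MultiOptimizer and the dataloader to Accelerate so that device placement and accelerator.backward() below go through it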
249
+
250
+ for epoch in range(start_epoch, epochs):
251
+ running_loss = 0
252
+ start_time = time.time()
253
+
254
+ _ = [model[key].eval() for key in model]
255
+
256
+ model.text_aligner.train()
257
+ model.text_encoder.train()
258
+
259
+ model.predictor.train()
260
+ model.bert_encoder.train()
261
+ model.bert.train()
262
+ model.msd.train()
263
+ model.mpd.train()
264
+
265
+ for i, batch in enumerate(train_dataloader):
266
+ waves = batch[0]
267
+ batch = [b.to(device) for b in batch[1:]]
268
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
269
+ with torch.no_grad():
270
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
271
+ mel_mask = length_to_mask(mel_input_length).to(device)
272
+ text_mask = length_to_mask(input_lengths).to(texts.device)
273
+
274
+ # compute reference styles
275
+ if multispeaker and epoch >= diff_epoch:
276
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
277
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
278
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
279
+
280
+ try:
281
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
282
+ s2s_attn = s2s_attn.transpose(-1, -2)
283
+ s2s_attn = s2s_attn[..., 1:]
284
+ s2s_attn = s2s_attn.transpose(-1, -2)
285
+ except:
286
+ continue
287
+
288
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
289
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
290
+
291
+ # encode
292
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
293
+
294
+ # 50% chance of using the monotonic version
295
+ if bool(random.getrandbits(1)):
296
+ asr = (t_en @ s2s_attn)
297
+ else:
298
+ asr = (t_en @ s2s_attn_mono)
299
+
300
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
301
+
302
+ # compute the style of the entire utterance
303
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
304
+ ss = []
305
+ gs = []
306
+ for bib in range(len(mel_input_length)):
307
+ mel_length = int(mel_input_length[bib].item())
308
+ mel = mels[bib, :, :mel_input_length[bib]]
309
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
310
+ ss.append(s)
311
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
312
+ gs.append(s)
313
+
314
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
315
+ gs = torch.stack(gs).squeeze() # global acoustic styles
316
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
317
+
318
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
319
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
320
+
321
+ # denoiser training
322
+ if epoch >= diff_epoch:
323
+ num_steps = np.random.randint(3, 5)
324
+
325
+ if model_params.diffusion.dist.estimate_sigma_data:
326
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
327
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
328
+
329
+ if multispeaker:
330
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
331
+ embedding=bert_dur,
332
+ embedding_scale=1,
333
+ features=ref, # reference from the same speaker as the embedding
334
+ embedding_mask_proba=0.1,
335
+ num_steps=num_steps).squeeze(1)
336
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
337
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
338
+ else:
339
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
340
+ embedding=bert_dur,
341
+ embedding_scale=1,
342
+ embedding_mask_proba=0.1,
343
+ num_steps=num_steps).squeeze(1)
344
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
345
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
346
+ else:
347
+ loss_sty = 0
348
+ loss_diff = 0
349
+
350
+
351
+ s_loss = 0
352
+
353
+
354
+ d, p = model.predictor(d_en, s_dur,
355
+ input_lengths,
356
+ s2s_attn_mono,
357
+ text_mask)
358
+
359
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
360
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
361
+ en = []
362
+ gt = []
363
+ p_en = []
364
+ wav = []
365
+ st = []
366
+
367
+ for bib in range(len(mel_input_length)):
368
+ mel_length = int(mel_input_length[bib].item() / 2)
369
+
370
+ random_start = np.random.randint(0, mel_length - mel_len)
371
+ en.append(asr[bib, :, random_start:random_start+mel_len])
372
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
373
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
374
+
375
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
376
+ wav.append(torch.from_numpy(y).to(device))
377
+
378
+ # style reference (better to be different from the GT)
379
+ random_start = np.random.randint(0, mel_length - mel_len_st)
380
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
381
+
382
+ wav = torch.stack(wav).float().detach()
383
+
384
+ en = torch.stack(en)
385
+ p_en = torch.stack(p_en)
386
+ gt = torch.stack(gt).detach()
387
+ st = torch.stack(st).detach()
388
+
389
+
390
+ if gt.size(-1) < 80:
391
+ continue
392
+
393
+ s = model.style_encoder(gt.unsqueeze(1))
394
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
395
+
396
+ with torch.no_grad():
397
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
398
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
399
+
400
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
401
+
402
+ y_rec_gt = wav.unsqueeze(1)
403
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
404
+
405
+ wav = y_rec_gt
406
+
407
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
408
+
409
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
410
+
411
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
412
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
413
+
414
+ optimizer.zero_grad()
415
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
416
+ accelerator.backward(d_loss)
417
+ optimizer.step('msd')
418
+ optimizer.step('mpd')
419
+
420
+ # generator loss
421
+ optimizer.zero_grad()
422
+
423
+ loss_mel = stft_loss(y_rec, wav)
424
+ loss_gen_all = gl(wav, y_rec).mean()
425
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
426
+
427
+ loss_ce = 0
428
+ loss_dur = 0
429
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
430
+ _s2s_pred = _s2s_pred[:_text_length, :]
431
+ _text_input = _text_input[:_text_length].long()
432
+ _s2s_trg = torch.zeros_like(_s2s_pred)
433
+ for p in range(_s2s_trg.shape[0]):
434
+ _s2s_trg[p, :_text_input[p]] = 1
435
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
436
+
437
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
438
+ _text_input[1:_text_length-1])
439
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
440
+
441
+ loss_ce /= texts.size(0)
442
+ loss_dur /= texts.size(0)
443
+
444
+ loss_s2s = 0
445
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
446
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
447
+ loss_s2s /= texts.size(0)
448
+
449
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
450
+
451
+ g_loss = loss_params.lambda_mel * loss_mel + \
452
+ loss_params.lambda_F0 * loss_F0_rec + \
453
+ loss_params.lambda_ce * loss_ce + \
454
+ loss_params.lambda_norm * loss_norm_rec + \
455
+ loss_params.lambda_dur * loss_dur + \
456
+ loss_params.lambda_gen * loss_gen_all + \
457
+ loss_params.lambda_slm * loss_lm + \
458
+ loss_params.lambda_sty * loss_sty + \
459
+ loss_params.lambda_diff * loss_diff + \
460
+ loss_params.lambda_mono * loss_mono + \
461
+ loss_params.lambda_s2s * loss_s2s
462
+
463
+ running_loss += loss_mel.item()
464
+ accelerator.backward(g_loss)
465
+ if torch.isnan(g_loss):
466
+ from IPython.core.debugger import set_trace
467
+ set_trace()
468
+
469
+ optimizer.step('bert_encoder')
470
+ optimizer.step('bert')
471
+ optimizer.step('predictor')
472
+ optimizer.step('predictor_encoder')
473
+ optimizer.step('style_encoder')
474
+ optimizer.step('decoder')
475
+
476
+ optimizer.step('text_encoder')
477
+ optimizer.step('text_aligner')
478
+
479
+ if epoch >= diff_epoch:
480
+ optimizer.step('diffusion')
481
+
482
+ d_loss_slm, loss_gen_lm = 0, 0
483
+ if epoch >= joint_epoch:
484
+ # randomly pick whether to use in-distribution text
485
+ if np.random.rand() < 0.5:
486
+ use_ind = True
487
+ else:
488
+ use_ind = False
489
+
490
+ if use_ind:
491
+ ref_lengths = input_lengths
492
+ ref_texts = texts
493
+
494
+ slm_out = slmadv(i,
495
+ y_rec_gt,
496
+ y_rec_gt_pred,
497
+ waves,
498
+ mel_input_length,
499
+ ref_texts,
500
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
501
+
502
+ if slm_out is not None:
503
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
504
+
505
+ # SLM generator loss
506
+ optimizer.zero_grad()
507
+ accelerator.backward(loss_gen_lm)
508
+
509
+ # compute the gradient norm
510
+ total_norm = {}
511
+ for key in model.keys():
512
+ total_norm[key] = 0
513
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
514
+ for p in parameters:
515
+ param_norm = p.grad.detach().data.norm(2)
516
+ total_norm[key] += param_norm.item() ** 2
517
+ total_norm[key] = total_norm[key] ** 0.5
518
+
519
+ # gradient scaling
520
+ if total_norm['predictor'] > slmadv_params.thresh:
521
+ for key in model.keys():
522
+ for p in model[key].parameters():
523
+ if p.grad is not None:
524
+ p.grad *= (1 / total_norm['predictor'])
525
+
526
+ for p in model.predictor.duration_proj.parameters():
527
+ if p.grad is not None:
528
+ p.grad *= slmadv_params.scale
529
+
530
+ for p in model.predictor.lstm.parameters():
531
+ if p.grad is not None:
532
+ p.grad *= slmadv_params.scale
533
+
534
+ for p in model.diffusion.parameters():
535
+ if p.grad is not None:
536
+ p.grad *= slmadv_params.scale
537
+
538
+ optimizer.step('bert_encoder')
539
+ optimizer.step('bert')
540
+ optimizer.step('predictor')
541
+ optimizer.step('diffusion')
542
+
543
+ # SLM discriminator loss
544
+ if d_loss_slm != 0:
545
+ optimizer.zero_grad()
546
+ accelerator.backward(d_loss_slm)
547
+ optimizer.step('wd')
548
+
549
+ iters = iters + 1
550
+
551
+ if (i+1)%log_interval == 0:
552
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
553
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono))
554
+
555
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
556
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
557
+ writer.add_scalar('train/d_loss', d_loss, iters)
558
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
559
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
560
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
561
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
562
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
563
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
564
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
565
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
566
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
567
+
568
+ running_loss = 0
569
+
570
+ print('Time elapsed:', time.time()-start_time)
571
+
572
+ loss_test = 0
573
+ loss_align = 0
574
+ loss_f = 0
575
+ _ = [model[key].eval() for key in model]
576
+
577
+ with torch.no_grad():
578
+ iters_test = 0
579
+ for batch_idx, batch in enumerate(val_dataloader):
580
+ optimizer.zero_grad()
581
+
582
+ try:
583
+ waves = batch[0]
584
+ batch = [b.to(device) for b in batch[1:]]
585
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
586
+ with torch.no_grad():
587
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
588
+ text_mask = length_to_mask(input_lengths).to(texts.device)
589
+
590
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
591
+ s2s_attn = s2s_attn.transpose(-1, -2)
592
+ s2s_attn = s2s_attn[..., 1:]
593
+ s2s_attn = s2s_attn.transpose(-1, -2)
594
+
595
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
596
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
597
+
598
+ # encode
599
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
600
+ asr = (t_en @ s2s_attn_mono)
601
+
602
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
603
+
604
+ ss = []
605
+ gs = []
606
+
607
+ for bib in range(len(mel_input_length)):
608
+ mel_length = int(mel_input_length[bib].item())
609
+ mel = mels[bib, :, :mel_input_length[bib]]
610
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
611
+ ss.append(s)
612
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
613
+ gs.append(s)
614
+
615
+ s = torch.stack(ss).squeeze()
616
+ gs = torch.stack(gs).squeeze()
617
+ s_trg = torch.cat([s, gs], dim=-1).detach()
618
+
619
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
620
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
621
+ d, p = model.predictor(d_en, s,
622
+ input_lengths,
623
+ s2s_attn_mono,
624
+ text_mask)
625
+ # get clips
626
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
627
+ en = []
628
+ gt = []
629
+
630
+ p_en = []
631
+ wav = []
632
+
633
+ for bib in range(len(mel_input_length)):
634
+ mel_length = int(mel_input_length[bib].item() / 2)
635
+
636
+ random_start = np.random.randint(0, mel_length - mel_len)
637
+ en.append(asr[bib, :, random_start:random_start+mel_len])
638
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
639
+
640
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
641
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
642
+ wav.append(torch.from_numpy(y).to(device))
643
+
644
+ wav = torch.stack(wav).float().detach()
645
+
646
+ en = torch.stack(en)
647
+ p_en = torch.stack(p_en)
648
+ gt = torch.stack(gt).detach()
649
+ s = model.predictor_encoder(gt.unsqueeze(1))
650
+
651
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
652
+
653
+ loss_dur = 0
654
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
655
+ _s2s_pred = _s2s_pred[:_text_length, :]
656
+ _text_input = _text_input[:_text_length].long()
657
+ _s2s_trg = torch.zeros_like(_s2s_pred)
658
+ for bib in range(_s2s_trg.shape[0]):
659
+ _s2s_trg[bib, :_text_input[bib]] = 1
660
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
661
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
662
+ _text_input[1:_text_length-1])
663
+
664
+ loss_dur /= texts.size(0)
665
+
666
+ s = model.style_encoder(gt.unsqueeze(1))
667
+
668
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
669
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
670
+
671
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
672
+
673
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
674
+
675
+ loss_test += (loss_mel).mean()
676
+ loss_align += (loss_dur).mean()
677
+ loss_f += (loss_F0).mean()
678
+
679
+ iters_test += 1
680
+ except:
681
+ continue
682
+
683
+ print('Epochs:', epoch + 1)
684
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
685
+ print('\n\n\n')
686
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
687
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
688
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
689
+
690
+
691
+ if (epoch + 1) % save_freq == 0 :
692
+ if (loss_test / iters_test) < best_loss:
693
+ best_loss = loss_test / iters_test
694
+ print('Saving..')
695
+ state = {
696
+ 'net': {key: model[key].state_dict() for key in model},
697
+ 'optimizer': optimizer.state_dict(),
698
+ 'iters': iters,
699
+ 'val_loss': loss_test / iters_test,
700
+ 'epoch': epoch,
701
+ }
702
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
703
+ torch.save(state, save_path)
704
+
705
+ # if estimating sigma, save the estimated sigma
706
+ if model_params.diffusion.dist.estimate_sigma_data:
707
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
708
+
709
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
710
+ yaml.dump(config, outfile, default_flow_style=True)
711
+
712
+
713
+ if __name__=="__main__":
714
+ main()
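Note on the duration losses in the training loop above: for each token the loop builds a binary "staircase" target whose first d_gt[k] entries are 1, scores it with binary cross-entropy against the predictor logits, and reads the predicted duration as the sigmoid-sum over the same row. A minimal, self-contained sketch of that construction (d_gt_row, s2s_pred and the toy values are placeholders for one batch element, not the repo's API):

import torch
import torch.nn.functional as F

def duration_targets(d_gt_row, max_dur):
    # Binary "staircase" target: row k has d_gt_row[k] leading ones,
    # as built inline in the loss_dur / loss_ce loop above.
    trg = torch.zeros(len(d_gt_row), max_dur)
    for k, dur in enumerate(d_gt_row.long()):
        trg[k, :int(dur)] = 1
    return trg

# toy batch element: 4 phoneme tokens lasting 2, 5, 1 and 3 downsampled frames
d_gt_row = torch.tensor([2, 5, 1, 3])
s2s_pred = torch.randn(4, 50)                    # stand-in for the predictor logits (max_dur = 50)
trg = duration_targets(d_gt_row, 50)
dur_pred = torch.sigmoid(s2s_pred).sum(axis=1)   # expected duration per token
loss_ce = F.binary_cross_entropy_with_logits(s2s_pred.flatten(), trg.flatten())
loss_dur = F.l1_loss(dur_pred[1:-1], d_gt_row[1:-1].float())   # endpoints excluded, as above
print(loss_ce.item(), loss_dur.item())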
train_first.py ADDED
@@ -0,0 +1,445 @@
1
+ import os
2
+ import os.path as osp
3
+ import re
4
+ import sys
5
+ import yaml
6
+ import shutil
7
+ import numpy as np
8
+ import torch
9
+ import click
10
+ import warnings
11
+ warnings.simplefilter('ignore')
12
+
13
+ # load packages
14
+ import random
15
+ import yaml
16
+ from munch import Munch
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+ import torch.nn.functional as F
21
+ import torchaudio
22
+ import librosa
23
+
24
+ from models import *
25
+ from meldataset import build_dataloader
26
+ from utils import *
27
+ from losses import *
28
+ from optimizers import build_optimizer
29
+ import time
30
+
31
+ from accelerate import Accelerator
32
+ from accelerate.utils import LoggerType
33
+ from accelerate import DistributedDataParallelKwargs
34
+
35
+ from torch.utils.tensorboard import SummaryWriter
36
+
37
+ import logging
38
+ from accelerate.logging import get_logger
39
+ logger = get_logger(__name__, log_level="DEBUG")
40
+
41
+ @click.command()
42
+ @click.option('-p', '--config_path', default='Configs/config.yml', type=str)
43
+ def main(config_path):
44
+ config = yaml.safe_load(open(config_path))
45
+
46
+ log_dir = config['log_dir']
47
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
48
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
49
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
50
+ accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs])
51
+ if accelerator.is_main_process:
52
+ writer = SummaryWriter(log_dir + "/tensorboard")
53
+
54
+ # write logs
55
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
56
+ file_handler.setLevel(logging.DEBUG)
57
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
58
+ logger.logger.addHandler(file_handler)
59
+
60
+ batch_size = config.get('batch_size', 10)
61
+ device = accelerator.device
62
+
63
+ epochs = config.get('epochs_1st', 200)
64
+ save_freq = config.get('save_freq', 2)
65
+ log_interval = config.get('log_interval', 10)
66
+ saving_epoch = config.get('save_freq', 2)
67
+
68
+ data_params = config.get('data_params', None)
69
+ sr = config['preprocess_params'].get('sr', 24000)
70
+ train_path = data_params['train_data']
71
+ val_path = data_params['val_data']
72
+ root_path = data_params['root_path']
73
+ min_length = data_params['min_length']
74
+ OOD_data = data_params['OOD_data']
75
+
76
+ max_len = config.get('max_len', 200)
77
+
78
+ # load data
79
+ train_list, val_list = get_data_path_list(train_path, val_path)
80
+
81
+ train_dataloader = build_dataloader(train_list,
82
+ root_path,
83
+ OOD_data=OOD_data,
84
+ min_length=min_length,
85
+ batch_size=batch_size,
86
+ num_workers=2,
87
+ dataset_config={},
88
+ device=device)
89
+
90
+ val_dataloader = build_dataloader(val_list,
91
+ root_path,
92
+ OOD_data=OOD_data,
93
+ min_length=min_length,
94
+ batch_size=batch_size,
95
+ validation=True,
96
+ num_workers=0,
97
+ device=device,
98
+ dataset_config={})
99
+
100
+ with accelerator.main_process_first():
101
+ # load pretrained ASR model
102
+ ASR_config = config.get('ASR_config', False)
103
+ ASR_path = config.get('ASR_path', False)
104
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
105
+
106
+ # load pretrained F0 model
107
+ F0_path = config.get('F0_path', False)
108
+ pitch_extractor = load_F0_models(F0_path)
109
+
110
+ # load BERT model
111
+ from Utils.PLBERT.util import load_plbert
112
+ BERT_path = config.get('PLBERT_dir', False)
113
+ plbert = load_plbert(BERT_path)
114
+
115
+ scheduler_params = {
116
+ "max_lr": float(config['optimizer_params'].get('lr', 1e-4)),
117
+ "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
118
+ "epochs": epochs,
119
+ "steps_per_epoch": len(train_dataloader),
120
+ }
121
+
122
+ model_params = recursive_munch(config['model_params'])
123
+ multispeaker = model_params.multispeaker
124
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
125
+
126
+ best_loss = float('inf') # best test loss
127
+ loss_train_record = list([])
128
+ loss_test_record = list([])
129
+
130
+ loss_params = Munch(config['loss_params'])
131
+ TMA_epoch = loss_params.TMA_epoch
132
+
133
+ for k in model:
134
+ model[k] = accelerator.prepare(model[k])
135
+
136
+ train_dataloader, val_dataloader = accelerator.prepare(
137
+ train_dataloader, val_dataloader
138
+ )
139
+
140
+ _ = [model[key].to(device) for key in model]
141
+
142
+ # initialize optimizers after preparing models for compatibility with FSDP
143
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
144
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model},
145
+ lr=float(config['optimizer_params'].get('lr', 1e-4)))
146
+
147
+ for k, v in optimizer.optimizers.items():
148
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
149
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
150
+
151
+ with accelerator.main_process_first():
152
+ if config.get('pretrained_model', '') != '':
153
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
154
+ load_only_params=config.get('load_only_params', True))
155
+ else:
156
+ start_epoch = 0
157
+ iters = 0
158
+
159
+ # in case not distributed
160
+ try:
161
+ n_down = model.text_aligner.module.n_down
162
+ except:
163
+ n_down = model.text_aligner.n_down
164
+
165
+ # wrapped losses for compatibility with mixed precision
166
+ stft_loss = MultiResolutionSTFTLoss().to(device)
167
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
168
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
169
+ wl = WavLMLoss(model_params.slm.model,
170
+ model.wd,
171
+ sr,
172
+ model_params.slm.sr).to(device)
173
+
174
+ for epoch in range(start_epoch, epochs):
175
+ running_loss = 0
176
+ start_time = time.time()
177
+
178
+ _ = [model[key].train() for key in model]
179
+
180
+ for i, batch in enumerate(train_dataloader):
181
+ waves = batch[0]
182
+ batch = [b.to(device) for b in batch[1:]]
183
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
184
+
185
+ with torch.no_grad():
186
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
187
+ text_mask = length_to_mask(input_lengths).to(texts.device)
188
+
189
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
190
+
191
+ s2s_attn = s2s_attn.transpose(-1, -2)
192
+ s2s_attn = s2s_attn[..., 1:]
193
+ s2s_attn = s2s_attn.transpose(-1, -2)
194
+
195
+ with torch.no_grad():
196
+ attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
197
+ attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
198
+ attn_mask = (attn_mask < 1)
199
+
200
+ s2s_attn.masked_fill_(attn_mask, 0.0)
201
+
202
+ with torch.no_grad():
203
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
204
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
205
+
206
+ # encode
207
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
208
+
209
+ # 50% chance of using the monotonic version
210
+ if bool(random.getrandbits(1)):
211
+ asr = (t_en @ s2s_attn)
212
+ else:
213
+ asr = (t_en @ s2s_attn_mono)
214
+
215
+ # get clips
216
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
217
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
218
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
219
+
220
+ en = []
221
+ gt = []
222
+ wav = []
223
+ st = []
224
+
225
+ for bib in range(len(mel_input_length)):
226
+ mel_length = int(mel_input_length[bib].item() / 2)
227
+
228
+ random_start = np.random.randint(0, mel_length - mel_len)
229
+ en.append(asr[bib, :, random_start:random_start+mel_len])
230
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
231
+
232
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
233
+ wav.append(torch.from_numpy(y).to(device))
234
+
235
+ # style reference (ideally different from the GT segment)
236
+ random_start = np.random.randint(0, mel_length - mel_len_st)
237
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
238
+
239
+ en = torch.stack(en)
240
+ gt = torch.stack(gt).detach()
241
+ st = torch.stack(st).detach()
242
+
243
+ wav = torch.stack(wav).float().detach()
244
+
245
+ # clip too short to be used by the style encoder
246
+ if gt.shape[-1] < 80:
247
+ continue
248
+
249
+ with torch.no_grad():
250
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
251
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
252
+
253
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
254
+
255
+ y_rec = model.decoder(en, F0_real, real_norm, s)
256
+
257
+ # discriminator loss
258
+
259
+ if epoch >= TMA_epoch:
260
+ optimizer.zero_grad()
261
+ d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
262
+ accelerator.backward(d_loss)
263
+ optimizer.step('msd')
264
+ optimizer.step('mpd')
265
+ else:
266
+ d_loss = 0
267
+
268
+ # generator loss
269
+ optimizer.zero_grad()
270
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
271
+
272
+ if epoch >= TMA_epoch: # start TMA training
273
+ loss_s2s = 0
274
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
275
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
276
+ loss_s2s /= texts.size(0)
277
+
278
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
279
+
280
+ loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
281
+ loss_slm = wl(wav.detach(), y_rec).mean()
282
+
283
+ g_loss = loss_params.lambda_mel * loss_mel + \
284
+ loss_params.lambda_mono * loss_mono + \
285
+ loss_params.lambda_s2s * loss_s2s + \
286
+ loss_params.lambda_gen * loss_gen_all + \
287
+ loss_params.lambda_slm * loss_slm
288
+
289
+ else:
290
+ loss_s2s = 0
291
+ loss_mono = 0
292
+ loss_gen_all = 0
293
+ loss_slm = 0
294
+ g_loss = loss_mel
295
+
296
+ running_loss += accelerator.gather(loss_mel).mean().item()
297
+
298
+ accelerator.backward(g_loss)
299
+
300
+ optimizer.step('text_encoder')
301
+ optimizer.step('style_encoder')
302
+ optimizer.step('decoder')
303
+
304
+ if epoch >= TMA_epoch:
305
+ optimizer.step('text_aligner')
306
+ optimizer.step('pitch_extractor')
307
+
308
+ iters = iters + 1
309
+
310
+ if (i+1)%log_interval == 0 and accelerator.is_main_process:
311
+ log_print ('Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f'
312
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, loss_gen_all, d_loss, loss_mono, loss_s2s, loss_slm), logger)
313
+
314
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
315
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
316
+ writer.add_scalar('train/d_loss', d_loss, iters)
317
+ writer.add_scalar('train/mono_loss', loss_mono, iters)
318
+ writer.add_scalar('train/s2s_loss', loss_s2s, iters)
319
+ writer.add_scalar('train/slm_loss', loss_slm, iters)
320
+
321
+ running_loss = 0
322
+
323
+ print('Time elapsed:', time.time()-start_time)
324
+
325
+ loss_test = 0
326
+
327
+ _ = [model[key].eval() for key in model]
328
+
329
+ with torch.no_grad():
330
+ iters_test = 0
331
+ for batch_idx, batch in enumerate(val_dataloader):
332
+ optimizer.zero_grad()
333
+
334
+ waves = batch[0]
335
+ batch = [b.to(device) for b in batch[1:]]
336
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
337
+
338
+ with torch.no_grad():
339
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
340
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
341
+
342
+ s2s_attn = s2s_attn.transpose(-1, -2)
343
+ s2s_attn = s2s_attn[..., 1:]
344
+ s2s_attn = s2s_attn.transpose(-1, -2)
345
+
346
+ text_mask = length_to_mask(input_lengths).to(texts.device)
347
+ attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
348
+ attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
349
+ attn_mask = (attn_mask < 1)
350
+ s2s_attn.masked_fill_(attn_mask, 0.0)
351
+
352
+ # encode
353
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
354
+
355
+ asr = (t_en @ s2s_attn)
356
+
357
+ # get clips
358
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
359
+ mel_len = min([int(mel_input_length.min().item() / 2 - 1), max_len // 2])
360
+
361
+ en = []
362
+ gt = []
363
+ wav = []
364
+ for bib in range(len(mel_input_length)):
365
+ mel_length = int(mel_input_length[bib].item() / 2)
366
+
367
+ random_start = np.random.randint(0, mel_length - mel_len)
368
+ en.append(asr[bib, :, random_start:random_start+mel_len])
369
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
370
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
371
+ wav.append(torch.from_numpy(y).to('cuda'))
372
+
373
+ wav = torch.stack(wav).float().detach()
374
+
375
+ en = torch.stack(en)
376
+ gt = torch.stack(gt).detach()
377
+
378
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
379
+ s = model.style_encoder(gt.unsqueeze(1))
380
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
381
+ y_rec = model.decoder(en, F0_real, real_norm, s)
382
+
383
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
384
+
385
+ loss_test += accelerator.gather(loss_mel).mean().item()
386
+ iters_test += 1
387
+
388
+ if accelerator.is_main_process:
389
+ print('Epochs:', epoch + 1)
390
+ log_print('Validation loss: %.3f' % (loss_test / iters_test) + '\n\n\n\n', logger)
391
+ print('\n\n\n')
392
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
393
+ attn_image = get_image(s2s_attn[0].cpu().numpy().squeeze())
394
+ writer.add_figure('eval/attn', attn_image, epoch)
395
+
396
+ with torch.no_grad():
397
+ for bib in range(len(asr)):
398
+ mel_length = int(mel_input_length[bib].item())
399
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
400
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
401
+
402
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
403
+ F0_real = F0_real.unsqueeze(0)
404
+ s = model.style_encoder(gt.unsqueeze(1))
405
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
406
+
407
+ y_rec = model.decoder(en, F0_real, real_norm, s)
408
+
409
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
410
+ if epoch == 0:
411
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
412
+
413
+ if bib >= 6:
414
+ break
415
+
416
+ if epoch % saving_epoch == 0:
417
+ if (loss_test / iters_test) < best_loss:
418
+ best_loss = loss_test / iters_test
419
+ print('Saving..')
420
+ state = {
421
+ 'net': {key: model[key].state_dict() for key in model},
422
+ 'optimizer': optimizer.state_dict(),
423
+ 'iters': iters,
424
+ 'val_loss': loss_test / iters_test,
425
+ 'epoch': epoch,
426
+ }
427
+ save_path = osp.join(log_dir, 'epoch_1st_%05d.pth' % epoch)
428
+ torch.save(state, save_path)
429
+
430
+ if accelerator.is_main_process:
431
+ print('Saving..')
432
+ state = {
433
+ 'net': {key: model[key].state_dict() for key in model},
434
+ 'optimizer': optimizer.state_dict(),
435
+ 'iters': iters,
436
+ 'val_loss': loss_test / iters_test,
437
+ 'epoch': epoch,
438
+ }
439
+ save_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
440
+ torch.save(state, save_path)
441
+
442
+
443
+
444
+ if __name__=="__main__":
445
+ main()
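Note on the "# get clips" sections in train_first.py above: the index arithmetic relies on two fixed ratios, namely that the aligner downsamples mels by 2 (2 ** n_down) and that each mel frame covers hop_length = 300 audio samples at 24 kHz, which is where the *2 and *300 factors come from. A small sketch of that mapping with made-up lengths (clip_indices and the toy sizes are illustrative, not functions from this repo):

import numpy as np

HOP = 300    # hop_length from the preprocessing config
DOWN = 2     # mel downsampling factor inside the aligner (2 ** n_down)

def clip_indices(mel_frames, clip_len_ds, rng):
    # Pick a random clip on the aligner's downsampled grid and return the
    # matching slices in mel frames and in raw audio samples.
    start_ds = int(rng.integers(0, mel_frames // DOWN - clip_len_ds))
    mel_slice = (start_ds * DOWN, (start_ds + clip_len_ds) * DOWN)
    wav_slice = (mel_slice[0] * HOP, mel_slice[1] * HOP)
    return start_ds, mel_slice, wav_slice

rng = np.random.default_rng(0)
start_ds, mel_slice, wav_slice = clip_indices(mel_frames=400, clip_len_ds=120, rng=rng)
print(start_ds, mel_slice, wav_slice)   # 240 mel frames -> 72000 samples (3 s at 24 kHz)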
train_second.py ADDED
@@ -0,0 +1,792 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import traceback
15
+ import warnings
16
+ warnings.simplefilter('ignore')
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ from meldataset import build_dataloader
20
+
21
+ from Utils.ASR.models import ASRCNN
22
+ from Utils.JDC.model import JDCNet
23
+ from Utils.PLBERT.util import load_plbert
24
+
25
+ from models import *
26
+ from losses import *
27
+ from utils import *
28
+
29
+ from Modules.slmadv import SLMAdversarialLoss
30
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
31
+
32
+ from optimizers import build_optimizer
33
+
34
+ # simple fix for dataparallel that allows access to class attributes
35
+ class MyDataParallel(torch.nn.DataParallel):
36
+ def __getattr__(self, name):
37
+ try:
38
+ return super().__getattr__(name)
39
+ except AttributeError:
40
+ return getattr(self.module, name)
41
+
42
+ import logging
43
+ from logging import StreamHandler
44
+ logger = logging.getLogger(__name__)
45
+ logger.setLevel(logging.DEBUG)
46
+ handler = StreamHandler()
47
+ handler.setLevel(logging.DEBUG)
48
+ logger.addHandler(handler)
49
+
50
+
51
+ @click.command()
52
+ @click.option('-p', '--config_path', default='Configs/config.yml', type=str)
53
+ def main(config_path):
54
+ config = yaml.safe_load(open(config_path))
55
+
56
+ log_dir = config['log_dir']
57
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
58
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
59
+ writer = SummaryWriter(log_dir + "/tensorboard")
60
+
61
+ # write logs
62
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
63
+ file_handler.setLevel(logging.DEBUG)
64
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
65
+ logger.addHandler(file_handler)
66
+
67
+
68
+ batch_size = config.get('batch_size', 10)
69
+
70
+ epochs = config.get('epochs_2nd', 200)
71
+ save_freq = config.get('save_freq', 2)
72
+ log_interval = config.get('log_interval', 10)
73
+ saving_epoch = config.get('save_freq', 2)
74
+
75
+ data_params = config.get('data_params', None)
76
+ sr = config['preprocess_params'].get('sr', 24000)
77
+ train_path = data_params['train_data']
78
+ val_path = data_params['val_data']
79
+ root_path = data_params['root_path']
80
+ min_length = data_params['min_length']
81
+ OOD_data = data_params['OOD_data']
82
+
83
+ max_len = config.get('max_len', 200)
84
+
85
+ loss_params = Munch(config['loss_params'])
86
+ diff_epoch = loss_params.diff_epoch
87
+ joint_epoch = loss_params.joint_epoch
88
+
89
+ optimizer_params = Munch(config['optimizer_params'])
90
+
91
+ train_list, val_list = get_data_path_list(train_path, val_path)
92
+ device = 'cuda'
93
+
94
+ train_dataloader = build_dataloader(train_list,
95
+ root_path,
96
+ OOD_data=OOD_data,
97
+ min_length=min_length,
98
+ batch_size=batch_size,
99
+ num_workers=2,
100
+ dataset_config={},
101
+ device=device)
102
+
103
+ val_dataloader = build_dataloader(val_list,
104
+ root_path,
105
+ OOD_data=OOD_data,
106
+ min_length=min_length,
107
+ batch_size=batch_size,
108
+ validation=True,
109
+ num_workers=0,
110
+ device=device,
111
+ dataset_config={})
112
+
113
+ # load pretrained ASR model
114
+ ASR_config = config.get('ASR_config', False)
115
+ ASR_path = config.get('ASR_path', False)
116
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
117
+
118
+ # load pretrained F0 model
119
+ F0_path = config.get('F0_path', False)
120
+ pitch_extractor = load_F0_models(F0_path)
121
+
122
+ # load PL-BERT model
123
+ BERT_path = config.get('PLBERT_dir', False)
124
+ plbert = load_plbert(BERT_path)
125
+
126
+ # build model
127
+ model_params = recursive_munch(config['model_params'])
128
+ multispeaker = model_params.multispeaker
129
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
130
+ _ = [model[key].to(device) for key in model]
131
+
132
+ # DP
133
+ for key in model:
134
+ if key != "mpd" and key != "msd" and key != "wd":
135
+ model[key] = MyDataParallel(model[key])
136
+
137
+ start_epoch = 0
138
+ iters = 0
139
+
140
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
141
+
142
+ if not load_pretrained:
143
+ if config.get('first_stage_path', '') != '':
144
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
145
+ print('Loading the first stage model at %s ...' % first_stage_path)
146
+ model, _, start_epoch, iters = load_checkpoint(model,
147
+ None,
148
+ first_stage_path,
149
+ load_only_params=True,
150
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
151
+
152
+ # these epochs should be counted from the start epoch
153
+ diff_epoch += start_epoch
154
+ joint_epoch += start_epoch
155
+ epochs += start_epoch
156
+
157
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
158
+ else:
159
+ raise ValueError('You need to specify the path to the first stage model.')
160
+
161
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
162
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
163
+ wl = WavLMLoss(model_params.slm.model,
164
+ model.wd,
165
+ sr,
166
+ model_params.slm.sr).to(device)
167
+
168
+ gl = MyDataParallel(gl)
169
+ dl = MyDataParallel(dl)
170
+ wl = MyDataParallel(wl)
171
+
172
+ sampler = DiffusionSampler(
173
+ model.diffusion.diffusion,
174
+ sampler=ADPM2Sampler(),
175
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
176
+ clamp=False
177
+ )
178
+
179
+ scheduler_params = {
180
+ "max_lr": optimizer_params.lr,
181
+ "pct_start": float(0),
182
+ "epochs": epochs,
183
+ "steps_per_epoch": len(train_dataloader),
184
+ }
185
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
186
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
187
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
188
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
189
+
190
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
191
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
192
+
193
+ # adjust BERT learning rate
194
+ for g in optimizer.optimizers['bert'].param_groups:
195
+ g['betas'] = (0.9, 0.99)
196
+ g['lr'] = optimizer_params.bert_lr
197
+ g['initial_lr'] = optimizer_params.bert_lr
198
+ g['min_lr'] = 0
199
+ g['weight_decay'] = 0.01
200
+
201
+ # adjust acoustic module learning rate
202
+ for module in ["decoder", "style_encoder"]:
203
+ for g in optimizer.optimizers[module].param_groups:
204
+ g['betas'] = (0.0, 0.99)
205
+ g['lr'] = optimizer_params.ft_lr
206
+ g['initial_lr'] = optimizer_params.ft_lr
207
+ g['min_lr'] = 0
208
+ g['weight_decay'] = 1e-4
209
+
210
+ # load models if there is a model
211
+ if load_pretrained:
212
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
213
+ load_only_params=config.get('load_only_params', True))
214
+
215
+ n_down = model.text_aligner.n_down
216
+
217
+ best_loss = float('inf') # best test loss
218
+ loss_train_record = list([])
219
+ loss_test_record = list([])
220
+ iters = 0
221
+
222
+ criterion = nn.L1Loss() # F0 loss (regression)
223
+ torch.cuda.empty_cache()
224
+
225
+ stft_loss = MultiResolutionSTFTLoss().to(device)
226
+
227
+ print('BERT', optimizer.optimizers['bert'])
228
+ print('decoder', optimizer.optimizers['decoder'])
229
+
230
+ start_ds = False
231
+
232
+ running_std = []
233
+
234
+ slmadv_params = Munch(config['slmadv_params'])
235
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
236
+ slmadv_params.min_len,
237
+ slmadv_params.max_len,
238
+ batch_percentage=slmadv_params.batch_percentage,
239
+ skip_update=slmadv_params.iter,
240
+ sig=slmadv_params.sig
241
+ )
242
+
243
+
244
+ for epoch in range(start_epoch, epochs):
245
+ running_loss = 0
246
+ start_time = time.time()
247
+
248
+ _ = [model[key].eval() for key in model]
249
+
250
+ model.predictor.train()
251
+ model.bert_encoder.train()
252
+ model.bert.train()
253
+ model.msd.train()
254
+ model.mpd.train()
255
+
256
+
257
+ if epoch >= diff_epoch:
258
+ start_ds = True
259
+
260
+ for i, batch in enumerate(train_dataloader):
261
+ waves = batch[0]
262
+ batch = [b.to(device) for b in batch[1:]]
263
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
264
+
265
+ with torch.no_grad():
266
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
267
+ mel_mask = length_to_mask(mel_input_length).to(device)
268
+ text_mask = length_to_mask(input_lengths).to(texts.device)
269
+
270
+ try:
271
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
272
+ s2s_attn = s2s_attn.transpose(-1, -2)
273
+ s2s_attn = s2s_attn[..., 1:]
274
+ s2s_attn = s2s_attn.transpose(-1, -2)
275
+ except:
276
+ continue
277
+
278
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
279
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
280
+
281
+ # encode
282
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
283
+ asr = (t_en @ s2s_attn_mono)
284
+
285
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
286
+
287
+ # compute reference styles
288
+ if multispeaker and epoch >= diff_epoch:
289
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
290
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
291
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
292
+
293
+ # compute the style of the entire utterance
294
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
295
+ ss = []
296
+ gs = []
297
+ for bib in range(len(mel_input_length)):
298
+ mel_length = int(mel_input_length[bib].item())
299
+ mel = mels[bib, :, :mel_input_length[bib]]
300
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
301
+ ss.append(s)
302
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
303
+ gs.append(s)
304
+
305
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
306
+ gs = torch.stack(gs).squeeze() # global acoustic styles
307
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
308
+
309
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
310
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
311
+
312
+ # denoiser training
313
+ if epoch >= diff_epoch:
314
+ num_steps = np.random.randint(3, 5)
315
+
316
+ if model_params.diffusion.dist.estimate_sigma_data:
317
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
318
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
319
+
320
+ if multispeaker:
321
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
322
+ embedding=bert_dur,
323
+ embedding_scale=1,
324
+ features=ref, # reference from the same speaker as the embedding
325
+ embedding_mask_proba=0.1,
326
+ num_steps=num_steps).squeeze(1)
327
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
328
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
329
+ else:
330
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
331
+ embedding=bert_dur,
332
+ embedding_scale=1,
333
+ embedding_mask_proba=0.1,
334
+ num_steps=num_steps).squeeze(1)
335
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
336
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
337
+ else:
338
+ loss_sty = 0
339
+ loss_diff = 0
340
+
341
+ d, p = model.predictor(d_en, s_dur,
342
+ input_lengths,
343
+ s2s_attn_mono,
344
+ text_mask)
345
+
346
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
347
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
348
+ en = []
349
+ gt = []
350
+ st = []
351
+ p_en = []
352
+ wav = []
353
+
354
+ for bib in range(len(mel_input_length)):
355
+ mel_length = int(mel_input_length[bib].item() / 2)
356
+
357
+ random_start = np.random.randint(0, mel_length - mel_len)
358
+ en.append(asr[bib, :, random_start:random_start+mel_len])
359
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
360
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
361
+
362
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
363
+ wav.append(torch.from_numpy(y).to(device))
364
+
365
+ # style reference (better to be different from the GT)
366
+ random_start = np.random.randint(0, mel_length - mel_len_st)
367
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
368
+
369
+ wav = torch.stack(wav).float().detach()
370
+
371
+ en = torch.stack(en)
372
+ p_en = torch.stack(p_en)
373
+ gt = torch.stack(gt).detach()
374
+ st = torch.stack(st).detach()
375
+
376
+ if gt.size(-1) < 80:
377
+ continue
378
+
379
+ s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
380
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
381
+
382
+ with torch.no_grad():
383
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
384
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
385
+
386
+ asr_real = model.text_aligner.get_feature(gt)
387
+
388
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
389
+
390
+ y_rec_gt = wav.unsqueeze(1)
391
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
392
+
393
+ if epoch >= joint_epoch:
394
+ # ground truth from recording
395
+ wav = y_rec_gt # use recording since decoder is tuned
396
+ else:
397
+ # ground truth from reconstruction
398
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
399
+
400
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
401
+
402
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
403
+
404
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
405
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
406
+
407
+ if start_ds:
408
+ optimizer.zero_grad()
409
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
410
+ d_loss.backward()
411
+ optimizer.step('msd')
412
+ optimizer.step('mpd')
413
+ else:
414
+ d_loss = 0
415
+
416
+ # generator loss
417
+ optimizer.zero_grad()
418
+
419
+ loss_mel = stft_loss(y_rec, wav)
420
+ if start_ds:
421
+ loss_gen_all = gl(wav, y_rec).mean()
422
+ else:
423
+ loss_gen_all = 0
424
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
425
+
426
+ loss_ce = 0
427
+ loss_dur = 0
428
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
429
+ _s2s_pred = _s2s_pred[:_text_length, :]
430
+ _text_input = _text_input[:_text_length].long()
431
+ _s2s_trg = torch.zeros_like(_s2s_pred)
432
+ for p in range(_s2s_trg.shape[0]):
433
+ _s2s_trg[p, :_text_input[p]] = 1
434
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
435
+
436
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
437
+ _text_input[1:_text_length-1])
438
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
439
+
440
+ loss_ce /= texts.size(0)
441
+ loss_dur /= texts.size(0)
442
+
443
+ g_loss = loss_params.lambda_mel * loss_mel + \
444
+ loss_params.lambda_F0 * loss_F0_rec + \
445
+ loss_params.lambda_ce * loss_ce + \
446
+ loss_params.lambda_norm * loss_norm_rec + \
447
+ loss_params.lambda_dur * loss_dur + \
448
+ loss_params.lambda_gen * loss_gen_all + \
449
+ loss_params.lambda_slm * loss_lm + \
450
+ loss_params.lambda_sty * loss_sty + \
451
+ loss_params.lambda_diff * loss_diff
452
+
453
+ running_loss += loss_mel.item()
454
+ g_loss.backward()
455
+ if torch.isnan(g_loss):
456
+ from IPython.core.debugger import set_trace
457
+ set_trace()
458
+
459
+ optimizer.step('bert_encoder')
460
+ optimizer.step('bert')
461
+ optimizer.step('predictor')
462
+ optimizer.step('predictor_encoder')
463
+
464
+ if epoch >= diff_epoch:
465
+ optimizer.step('diffusion')
466
+
467
+ if epoch >= joint_epoch:
468
+ optimizer.step('style_encoder')
469
+ optimizer.step('decoder')
470
+
471
+ # randomly pick whether to use in-distribution text
472
+ if np.random.rand() < 0.5:
473
+ use_ind = True
474
+ else:
475
+ use_ind = False
476
+
477
+ if use_ind:
478
+ ref_lengths = input_lengths
479
+ ref_texts = texts
480
+
481
+ slm_out = slmadv(i,
482
+ y_rec_gt,
483
+ y_rec_gt_pred,
484
+ waves,
485
+ mel_input_length,
486
+ ref_texts,
487
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
488
+
489
+ if slm_out is None:
490
+ continue
491
+
492
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
493
+
494
+ # SLM generator loss
495
+ optimizer.zero_grad()
496
+ loss_gen_lm.backward()
497
+
498
+ # compute the gradient norm
499
+ total_norm = {}
500
+ for key in model.keys():
501
+ total_norm[key] = 0
502
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
503
+ for p in parameters:
504
+ param_norm = p.grad.detach().data.norm(2)
505
+ total_norm[key] += param_norm.item() ** 2
506
+ total_norm[key] = total_norm[key] ** 0.5
507
+
508
+ # gradient scaling
509
+ if total_norm['predictor'] > slmadv_params.thresh:
510
+ for key in model.keys():
511
+ for p in model[key].parameters():
512
+ if p.grad is not None:
513
+ p.grad *= (1 / total_norm['predictor'])
514
+
515
+ for p in model.predictor.duration_proj.parameters():
516
+ if p.grad is not None:
517
+ p.grad *= slmadv_params.scale
518
+
519
+ for p in model.predictor.lstm.parameters():
520
+ if p.grad is not None:
521
+ p.grad *= slmadv_params.scale
522
+
523
+ for p in model.diffusion.parameters():
524
+ if p.grad is not None:
525
+ p.grad *= slmadv_params.scale
526
+
527
+ optimizer.step('bert_encoder')
528
+ optimizer.step('bert')
529
+ optimizer.step('predictor')
530
+ optimizer.step('diffusion')
531
+
532
+ # SLM discriminator loss
533
+ if d_loss_slm != 0:
534
+ optimizer.zero_grad()
535
+ d_loss_slm.backward(retain_graph=True)
536
+ optimizer.step('wd')
537
+
538
+ else:
539
+ d_loss_slm, loss_gen_lm = 0, 0
540
+
541
+ iters = iters + 1
542
+
543
+ if (i+1)%log_interval == 0:
544
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
545
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm))
546
+
547
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
548
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
549
+ writer.add_scalar('train/d_loss', d_loss, iters)
550
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
551
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
552
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
553
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
554
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
555
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
556
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
557
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
558
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
559
+
560
+ running_loss = 0
561
+
562
+ print('Time elasped:', time.time()-start_time)
563
+
564
+ loss_test = 0
565
+ loss_align = 0
566
+ loss_f = 0
567
+ _ = [model[key].eval() for key in model]
568
+
569
+ with torch.no_grad():
570
+ iters_test = 0
571
+ for batch_idx, batch in enumerate(val_dataloader):
572
+ optimizer.zero_grad()
573
+
574
+ try:
575
+ waves = batch[0]
576
+ batch = [b.to(device) for b in batch[1:]]
577
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
578
+ with torch.no_grad():
579
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
580
+ text_mask = length_to_mask(input_lengths).to(texts.device)
581
+
582
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
583
+ s2s_attn = s2s_attn.transpose(-1, -2)
584
+ s2s_attn = s2s_attn[..., 1:]
585
+ s2s_attn = s2s_attn.transpose(-1, -2)
586
+
587
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
588
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
589
+
590
+ # encode
591
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
592
+ asr = (t_en @ s2s_attn_mono)
593
+
594
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
595
+
596
+ ss = []
597
+ gs = []
598
+
599
+ for bib in range(len(mel_input_length)):
600
+ mel_length = int(mel_input_length[bib].item())
601
+ mel = mels[bib, :, :mel_input_length[bib]]
602
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
603
+ ss.append(s)
604
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
605
+ gs.append(s)
606
+
607
+ s = torch.stack(ss).squeeze()
608
+ gs = torch.stack(gs).squeeze()
609
+ s_trg = torch.cat([s, gs], dim=-1).detach()
610
+
611
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
612
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
613
+ d, p = model.predictor(d_en, s,
614
+ input_lengths,
615
+ s2s_attn_mono,
616
+ text_mask)
617
+ # get clips
618
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
619
+ en = []
620
+ gt = []
621
+ p_en = []
622
+ wav = []
623
+
624
+ for bib in range(len(mel_input_length)):
625
+ mel_length = int(mel_input_length[bib].item() / 2)
626
+
627
+ random_start = np.random.randint(0, mel_length - mel_len)
628
+ en.append(asr[bib, :, random_start:random_start+mel_len])
629
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
630
+
631
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
632
+
633
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
634
+ wav.append(torch.from_numpy(y).to(device))
635
+
636
+ wav = torch.stack(wav).float().detach()
637
+
638
+ en = torch.stack(en)
639
+ p_en = torch.stack(p_en)
640
+ gt = torch.stack(gt).detach()
641
+
642
+ s = model.predictor_encoder(gt.unsqueeze(1))
643
+
644
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
645
+
646
+ loss_dur = 0
647
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
648
+ _s2s_pred = _s2s_pred[:_text_length, :]
649
+ _text_input = _text_input[:_text_length].long()
650
+ _s2s_trg = torch.zeros_like(_s2s_pred)
651
+ for bib in range(_s2s_trg.shape[0]):
652
+ _s2s_trg[bib, :_text_input[bib]] = 1
653
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
654
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
655
+ _text_input[1:_text_length-1])
656
+
657
+ loss_dur /= texts.size(0)
658
+
659
+ s = model.style_encoder(gt.unsqueeze(1))
660
+
661
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
662
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
663
+
664
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
665
+
666
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
667
+
668
+ loss_test += (loss_mel).mean()
669
+ loss_align += (loss_dur).mean()
670
+ loss_f += (loss_F0).mean()
671
+
672
+ iters_test += 1
673
+ except Exception as e:
674
+ print("Ran into exception:", e)
675
+ traceback.print_exc()
676
+ continue
677
+
678
+ print('Epochs:', epoch + 1)
679
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
680
+ print('\n\n\n')
681
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
682
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
683
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
684
+
685
+ if epoch < joint_epoch:
686
+ # generating reconstruction examples with GT duration
687
+
688
+ with torch.no_grad():
689
+ for bib in range(len(asr)):
690
+ mel_length = int(mel_input_length[bib].item())
691
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
692
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
693
+
694
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
695
+ F0_real = F0_real.unsqueeze(0)
696
+ s = model.style_encoder(gt.unsqueeze(1))
697
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
698
+
699
+ y_rec = model.decoder(en, F0_real, real_norm, s)
700
+
701
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
702
+
703
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
704
+ p_en = p[bib, :, :mel_length // 2].unsqueeze(0)
705
+
706
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
707
+
708
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
709
+
710
+ writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)
711
+
712
+ if epoch == 0:
713
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
714
+
715
+ if bib >= 5:
716
+ break
717
+ else:
718
+ # generating sampled speech from text directly
719
+ with torch.no_grad():
720
+ # compute reference styles
721
+ if multispeaker and epoch >= diff_epoch:
722
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
723
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
724
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
725
+
726
+ for bib in range(len(d_en)):
727
+ if multispeaker:
728
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
729
+ embedding=bert_dur[bib].unsqueeze(0),
730
+ embedding_scale=1,
731
+ features=ref_s[bib].unsqueeze(0), # reference from the same speaker as the embedding
732
+ num_steps=5).squeeze(1)
733
+ else:
734
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
735
+ embedding=bert_dur[bib].unsqueeze(0),
736
+ embedding_scale=1,
737
+ num_steps=5).squeeze(1)
738
+
739
+ s = s_pred[:, 128:]
740
+ ref = s_pred[:, :128]
741
+
742
+ d = model.predictor.text_encoder(d_en[bib, :, :input_lengths[bib]].unsqueeze(0),
743
+ s, input_lengths[bib, ...].unsqueeze(0), text_mask[bib, :input_lengths[bib]].unsqueeze(0))
744
+
745
+ x, _ = model.predictor.lstm(d)
746
+ duration = model.predictor.duration_proj(x)
747
+
748
+ duration = torch.sigmoid(duration).sum(axis=-1)
749
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
750
+
751
+ pred_dur[-1] += 5
752
+
753
+ pred_aln_trg = torch.zeros(input_lengths[bib], int(pred_dur.sum().data))
754
+ c_frame = 0
755
+ for i in range(pred_aln_trg.size(0)):
756
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
757
+ c_frame += int(pred_dur[i].data)
758
+
759
+ # encode prosody
760
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(texts.device))
761
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
762
+ out = model.decoder((t_en[bib, :, :input_lengths[bib]].unsqueeze(0) @ pred_aln_trg.unsqueeze(0).to(texts.device)),
763
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
764
+
765
+ writer.add_audio('pred/y' + str(bib), out.cpu().numpy().squeeze(), epoch, sample_rate=sr)
766
+
767
+ if bib >= 5:
768
+ break
769
+
770
+ if epoch % saving_epoch == 0:
771
+ if (loss_test / iters_test) < best_loss:
772
+ best_loss = loss_test / iters_test
773
+ print('Saving..')
774
+ state = {
775
+ 'net': {key: model[key].state_dict() for key in model},
776
+ 'optimizer': optimizer.state_dict(),
777
+ 'iters': iters,
778
+ 'val_loss': loss_test / iters_test,
779
+ 'epoch': epoch,
780
+ }
781
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
782
+ torch.save(state, save_path)
783
+
784
+ # if estimate sigma, save the estimated simga
785
+ if model_params.diffusion.dist.estimate_sigma_data:
786
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
787
+
788
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
789
+ yaml.dump(config, outfile, default_flow_style=True)
790
+
791
+ if __name__=="__main__":
792
+ main()
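Note on the text-to-speech sampling branch in train_second.py above: the rounded per-token durations are expanded into a hard 0/1 alignment matrix (pred_aln_trg), and token features are spread over output frames with a matrix product. A minimal sketch of that expansion using only standard PyTorch (durations_to_alignment and the toy shapes are illustrative, not part of the repo):

import torch

def durations_to_alignment(pred_dur):
    # Expand integer durations [T] into a [T, sum(pred_dur)] alignment matrix
    # where row i is 1 on the frames assigned to token i.
    total = int(pred_dur.sum().item())
    aln = torch.zeros(len(pred_dur), total)
    frame = 0
    for i, dur in enumerate(pred_dur.long()):
        aln[i, frame:frame + int(dur)] = 1
        frame += int(dur)
    return aln

pred_dur = torch.tensor([3, 1, 4, 2])        # toy durations in frames
aln = durations_to_alignment(pred_dur)
t_en = torch.randn(1, 512, 4)                # [batch, channels, tokens] stand-in for text encodings
frames = t_en @ aln.unsqueeze(0)             # [1, 512, 10]: each token repeated over its frames
print(aln.shape, frames.shape)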
utils.py ADDED
@@ -0,0 +1,74 @@
1
+ from monotonic_align import maximum_path
2
+ from monotonic_align import mask_from_lens
3
+ from monotonic_align.core import maximum_path_c
4
+ import numpy as np
5
+ import torch
6
+ import copy
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ import torchaudio
10
+ import librosa
11
+ import matplotlib.pyplot as plt
12
+ from munch import Munch
13
+
14
+ def maximum_path(neg_cent, mask):
15
+ """ Cython optimized version.
16
+ neg_cent: [b, t_t, t_s]
17
+ mask: [b, t_t, t_s]
18
+ """
19
+ device = neg_cent.device
20
+ dtype = neg_cent.dtype
21
+ neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32))
22
+ path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32))
23
+
24
+ t_t_max = np.ascontiguousarray(mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32))
25
+ t_s_max = np.ascontiguousarray(mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32))
26
+ maximum_path_c(path, neg_cent, t_t_max, t_s_max)
27
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
28
+
29
+ def get_data_path_list(train_path=None, val_path=None):
30
+ if train_path is None:
31
+ train_path = "Data/train_list.txt"
32
+ if val_path is None:
33
+ val_path = "Data/val_list.txt"
34
+
35
+ with open(train_path, 'r', encoding='utf-8', errors='ignore') as f:
36
+ train_list = f.readlines()
37
+ with open(val_path, 'r', encoding='utf-8', errors='ignore') as f:
38
+ val_list = f.readlines()
39
+
40
+ return train_list, val_list
41
+
42
+ def length_to_mask(lengths):
43
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
44
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
45
+ return mask
46
+
47
+ # for norm consistency loss
48
+ def log_norm(x, mean=-4, std=4, dim=2):
49
+ """
50
+ normalized log mel -> mel -> norm -> log(norm)
51
+ """
52
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
53
+ return x
54
+
55
+ def get_image(arrs):
56
+ plt.switch_backend('agg')
57
+ fig = plt.figure()
58
+ ax = plt.gca()
59
+ ax.imshow(arrs)
60
+
61
+ return fig
62
+
63
+ def recursive_munch(d):
64
+ if isinstance(d, dict):
65
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
66
+ elif isinstance(d, list):
67
+ return [recursive_munch(v) for v in d]
68
+ else:
69
+ return d
70
+
71
+ def log_print(message, logger):
72
+ logger.info(message)
73
+ print(message)
74
+
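Quick usage sketch for the small helpers in utils.py above, assuming the module is importable in the repo environment (i.e. the monotonic_align extension it imports is built): length_to_mask returns True on padded positions, which is why the training scripts pass (~text_mask).int() as the BERT attention mask, and log_norm collapses normalized log-mels to a per-frame log energy. The tensor shapes below are toy values.

import torch
from utils import length_to_mask, log_norm   # helpers defined above

lengths = torch.tensor([3, 5])
mask = length_to_mask(lengths)
print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
print((~mask).int())    # 1 = real position, 0 = padding, as passed to model.bert

mels = torch.randn(2, 1, 80, 120)        # [batch, 1, n_mels, frames] of normalized log-mels
energy = log_norm(mels).squeeze(1)       # [batch, frames] log energy per frame
print(energy.shape)                      # torch.Size([2, 120])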