Spaces:
Build error
Build error
Commit
•
c731c61
1
Parent(s):
25b69c5
linting
Browse files- app.py +51 -35
- layers/fc.py +3 -2
- layers/layer_norm.py +1 -1
- model_LA.py +64 -74
- model_LAV.py +70 -83
- utils/audio.py +46 -29
- utils/audio_params.py +23 -18
- utils/compute_args.py +28 -15
- utils/plot.py +1 -1
- utils/pred_func.py +1 -1
- utils/tokenize.py +14 -15
app.py
CHANGED
@@ -6,39 +6,47 @@ import torch
|
|
6 |
import numpy as np
|
7 |
from utils.audio import load_spectrograms
|
8 |
from utils.compute_args import compute_args
|
9 |
-
from utils.tokenize import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from model_LA import Model_LA
|
11 |
import gradio as gr
|
12 |
|
13 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
14 |
|
15 |
# load model
|
16 |
-
ckpts_path =
|
17 |
model_name = "Model_LA_e"
|
18 |
# Listing sorted checkpoints
|
19 |
-
ckpts = sorted(glob.glob(os.path.join(ckpts_path, model_name,
|
20 |
|
21 |
# Load original args
|
22 |
-
args = torch.load(ckpts[0], map_location=torch.device(device))[
|
23 |
args = compute_args(args)
|
24 |
pretrained_emb = np.load("train_glove.npy")
|
25 |
-
token_to_ix = pickle.load(open("token_to_ix.pkl", "rb"))
|
26 |
-
state_dict = torch.load(ckpts[0], map_location=torch.device(device))[
|
27 |
|
28 |
net = Model_LA(args, len(token_to_ix), pretrained_emb).to(device)
|
29 |
net.load_state_dict(state_dict)
|
30 |
|
|
|
31 |
def inference(source_video, transcription):
|
32 |
# data preprocessing
|
33 |
# text
|
34 |
def clean(w):
|
35 |
-
return
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
|
41 |
-
s = [clean(w) for w in transcription.split() if clean(w) !=
|
42 |
|
43 |
# Sound
|
44 |
_, mel, mag = load_spectrograms(source_video)
|
@@ -55,32 +63,40 @@ def inference(source_video, transcription):
|
|
55 |
print(f"Processed video shape from {mel.shape} to {V.shape}")
|
56 |
|
57 |
net.train(False)
|
58 |
-
x = np.expand_dims(L,axis=0)
|
59 |
-
y = np.expand_dims(A,axis=0)
|
60 |
-
z = np.expand_dims(V,axis=0)
|
61 |
-
x, y, z =
|
|
|
|
|
|
|
|
|
62 |
pred = net(x, y, z).cpu().data.numpy()[0]
|
63 |
# pred = np.exp(pred) / np.sum(np.exp(pred)) # softmax
|
64 |
-
label_to_ix = [
|
65 |
# result_dict = {label_to_ix[i]: float(pred[i]) for i in range(len(label_to_ix))}
|
66 |
-
result_dict = {label_to_ix[i]: float(pred[i])>0 for i in range(len(label_to_ix))}
|
67 |
return result_dict
|
68 |
|
69 |
|
70 |
-
title="Emotion Recognition"
|
71 |
-
description=""
|
72 |
-
|
73 |
-
examples = [
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
6 |
import numpy as np
|
7 |
from utils.audio import load_spectrograms
|
8 |
from utils.compute_args import compute_args
|
9 |
+
from utils.tokenize import (
|
10 |
+
tokenize,
|
11 |
+
create_dict,
|
12 |
+
sent_to_ix,
|
13 |
+
cmumosei_2,
|
14 |
+
cmumosei_7,
|
15 |
+
pad_feature,
|
16 |
+
)
|
17 |
from model_LA import Model_LA
|
18 |
import gradio as gr
|
19 |
|
20 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
21 |
|
22 |
# load model
|
23 |
+
ckpts_path = "ckpt"
|
24 |
model_name = "Model_LA_e"
|
25 |
# Listing sorted checkpoints
|
26 |
+
ckpts = sorted(glob.glob(os.path.join(ckpts_path, model_name, "best*")), reverse=True)
|
27 |
|
28 |
# Load original args
|
29 |
+
args = torch.load(ckpts[0], map_location=torch.device(device))["args"]
|
30 |
args = compute_args(args)
|
31 |
pretrained_emb = np.load("train_glove.npy")
|
32 |
+
token_to_ix = pickle.load(open("token_to_ix.pkl", "rb"))
|
33 |
+
state_dict = torch.load(ckpts[0], map_location=torch.device(device))["state_dict"]
|
34 |
|
35 |
net = Model_LA(args, len(token_to_ix), pretrained_emb).to(device)
|
36 |
net.load_state_dict(state_dict)
|
37 |
|
38 |
+
|
39 |
def inference(source_video, transcription):
|
40 |
# data preprocessing
|
41 |
# text
|
42 |
def clean(w):
|
43 |
+
return (
|
44 |
+
re.sub(r"([.,'!?\"()*#:;])", "", w.lower())
|
45 |
+
.replace("-", " ")
|
46 |
+
.replace("/", " ")
|
47 |
+
)
|
48 |
|
49 |
+
s = [clean(w) for w in transcription.split() if clean(w) != ""]
|
50 |
|
51 |
# Sound
|
52 |
_, mel, mag = load_spectrograms(source_video)
|
|
|
63 |
print(f"Processed video shape from {mel.shape} to {V.shape}")
|
64 |
|
65 |
net.train(False)
|
66 |
+
x = np.expand_dims(L, axis=0)
|
67 |
+
y = np.expand_dims(A, axis=0)
|
68 |
+
z = np.expand_dims(V, axis=0)
|
69 |
+
x, y, z = (
|
70 |
+
torch.from_numpy(x).to(device),
|
71 |
+
torch.from_numpy(y).to(device),
|
72 |
+
torch.from_numpy(z).float().to(device),
|
73 |
+
)
|
74 |
pred = net(x, y, z).cpu().data.numpy()[0]
|
75 |
# pred = np.exp(pred) / np.sum(np.exp(pred)) # softmax
|
76 |
+
label_to_ix = ["happy", "sad", "angry", "fear", "disgust", "surprise"]
|
77 |
# result_dict = {label_to_ix[i]: float(pred[i]) for i in range(len(label_to_ix))}
|
78 |
+
result_dict = {label_to_ix[i]: float(pred[i]) > 0 for i in range(len(label_to_ix))}
|
79 |
return result_dict
|
80 |
|
81 |
|
82 |
+
title = "Emotion Recognition"
|
83 |
+
description = ""
|
84 |
+
|
85 |
+
examples = [
|
86 |
+
[
|
87 |
+
"examples/0h-zjBukYpk_2.mp4",
|
88 |
+
"NOW IM NOT EVEN GONNA SUGAR COAT THIS THIS MOVIE FRUSTRATED ME TO SUCH AN EXTREME EXTENT THAT I WAS LOUDLY EXCLAIMING WHY AT THE END OF THE FILM",
|
89 |
+
],
|
90 |
+
["examples/0h-zjBukYpk_19.mp4", "NOW OTHER PERFORMANCES ARE BORDERLINE OKAY"],
|
91 |
+
["examples/03bSnISJMiM_1.mp4", "IT WAS REALLY GOOD "],
|
92 |
+
["examples/03bSnISJMiM_5.mp4", "AND THEY SHOULDVE I GUESS "],
|
93 |
+
]
|
94 |
+
|
95 |
+
gr.Interface(
|
96 |
+
inference,
|
97 |
+
inputs=[gr.inputs.Video(type="avi", source="upload"), "text"],
|
98 |
+
outputs=["label"],
|
99 |
+
title=title,
|
100 |
+
description=description,
|
101 |
+
examples=examples,
|
102 |
+
).launch(debug=True)
|
layers/fc.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
import torch.nn as nn
|
2 |
|
|
|
3 |
class FC(nn.Module):
|
4 |
-
def __init__(self, in_size, out_size, dropout_r=0
|
5 |
super(FC, self).__init__()
|
6 |
self.dropout_r = dropout_r
|
7 |
self.use_relu = use_relu
|
@@ -27,7 +28,7 @@ class FC(nn.Module):
|
|
27 |
|
28 |
|
29 |
class MLP(nn.Module):
|
30 |
-
def __init__(self, in_size, mid_size, out_size, dropout_r=0
|
31 |
super(MLP, self).__init__()
|
32 |
|
33 |
self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
|
|
|
1 |
import torch.nn as nn
|
2 |
|
3 |
+
|
4 |
class FC(nn.Module):
|
5 |
+
def __init__(self, in_size, out_size, dropout_r=0.0, use_relu=True):
|
6 |
super(FC, self).__init__()
|
7 |
self.dropout_r = dropout_r
|
8 |
self.use_relu = use_relu
|
|
|
28 |
|
29 |
|
30 |
class MLP(nn.Module):
|
31 |
+
def __init__(self, in_size, mid_size, out_size, dropout_r=0.0, use_relu=True):
|
32 |
super(MLP, self).__init__()
|
33 |
|
34 |
self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
|
layers/layer_norm.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import torch.nn as nn
|
2 |
import torch
|
3 |
|
|
|
4 |
class LayerNorm(nn.Module):
|
5 |
def __init__(self, size, eps=1e-6):
|
6 |
super(LayerNorm, self).__init__()
|
@@ -13,4 +14,3 @@ class LayerNorm(nn.Module):
|
|
13 |
mean = x.mean(-1, keepdim=True)
|
14 |
std = x.std(-1, keepdim=True)
|
15 |
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
16 |
-
|
|
|
1 |
import torch.nn as nn
|
2 |
import torch
|
3 |
|
4 |
+
|
5 |
class LayerNorm(nn.Module):
|
6 |
def __init__(self, size, eps=1e-6):
|
7 |
super(LayerNorm, self).__init__()
|
|
|
14 |
mean = x.mean(-1, keepdim=True)
|
15 |
std = x.std(-1, keepdim=True)
|
16 |
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
|
model_LA.py
CHANGED
@@ -10,10 +10,8 @@ from layers.layer_norm import LayerNorm
|
|
10 |
# ---------- Masking sequence --------
|
11 |
# ------------------------------------
|
12 |
def make_mask(feature):
|
13 |
-
return (torch.sum(
|
14 |
-
|
15 |
-
dim=-1
|
16 |
-
) == 0).unsqueeze(1).unsqueeze(2)
|
17 |
|
18 |
# ------------------------------
|
19 |
# ---------- Flattening --------
|
@@ -31,29 +29,23 @@ class AttFlat(nn.Module):
|
|
31 |
mid_size=args.ff_size,
|
32 |
out_size=flat_glimpse,
|
33 |
dropout_r=args.dropout_r,
|
34 |
-
use_relu=True
|
35 |
)
|
36 |
|
37 |
if self.merge:
|
38 |
self.linear_merge = nn.Linear(
|
39 |
-
args.hidden_size * flat_glimpse,
|
40 |
-
args.hidden_size * 2
|
41 |
)
|
42 |
|
43 |
def forward(self, x, x_mask):
|
44 |
att = self.mlp(x)
|
45 |
if x_mask is not None:
|
46 |
-
att = att.masked_fill(
|
47 |
-
x_mask.squeeze(1).squeeze(1).unsqueeze(2),
|
48 |
-
-1e9
|
49 |
-
)
|
50 |
att = F.softmax(att, dim=1)
|
51 |
|
52 |
att_list = []
|
53 |
for i in range(self.flat_glimpse):
|
54 |
-
att_list.append(
|
55 |
-
torch.sum(att[:, :, i: i + 1] * x, dim=1)
|
56 |
-
)
|
57 |
|
58 |
if self.merge:
|
59 |
x_atted = torch.cat(att_list, dim=1)
|
@@ -63,10 +55,12 @@ class AttFlat(nn.Module):
|
|
63 |
|
64 |
return torch.stack(att_list).transpose_(0, 1)
|
65 |
|
|
|
66 |
# ------------------------
|
67 |
# ---- Self Attention ----
|
68 |
# ------------------------
|
69 |
|
|
|
70 |
class SA(nn.Module):
|
71 |
def __init__(self, args):
|
72 |
super(SA, self).__init__()
|
@@ -81,13 +75,9 @@ class SA(nn.Module):
|
|
81 |
self.norm2 = LayerNorm(args.hidden_size)
|
82 |
|
83 |
def forward(self, y, y_mask):
|
84 |
-
y = self.norm1(y + self.dropout1(
|
85 |
-
self.mhatt(y, y, y, y_mask)
|
86 |
-
))
|
87 |
|
88 |
-
y = self.norm2(y + self.dropout2(
|
89 |
-
self.ffn(y)
|
90 |
-
))
|
91 |
|
92 |
return y
|
93 |
|
@@ -96,6 +86,7 @@ class SA(nn.Module):
|
|
96 |
# ---- Self Guided Attention ----
|
97 |
# -------------------------------
|
98 |
|
|
|
99 |
class SGA(nn.Module):
|
100 |
def __init__(self, args):
|
101 |
super(SGA, self).__init__()
|
@@ -114,24 +105,20 @@ class SGA(nn.Module):
|
|
114 |
self.norm3 = LayerNorm(args.hidden_size)
|
115 |
|
116 |
def forward(self, x, y, x_mask, y_mask):
|
117 |
-
x = self.norm1(x + self.dropout1(
|
118 |
-
self.mhatt1(v=x, k=x, q=x, mask=x_mask)
|
119 |
-
))
|
120 |
|
121 |
-
x = self.norm2(x + self.dropout2(
|
122 |
-
self.mhatt2(v=y, k=y, q=x, mask=y_mask)
|
123 |
-
))
|
124 |
|
125 |
-
x = self.norm3(x + self.dropout3(
|
126 |
-
self.ffn(x)
|
127 |
-
))
|
128 |
|
129 |
return x
|
130 |
|
|
|
131 |
# ------------------------------
|
132 |
# ---- Multi-Head Attention ----
|
133 |
# ------------------------------
|
134 |
|
|
|
135 |
class MHAtt(nn.Module):
|
136 |
def __init__(self, args):
|
137 |
super(MHAtt, self).__init__()
|
@@ -146,33 +133,45 @@ class MHAtt(nn.Module):
|
|
146 |
|
147 |
def forward(self, v, k, q, mask):
|
148 |
n_batches = q.size(0)
|
149 |
-
v =
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
atted = self.att(v, k, q, mask)
|
171 |
|
172 |
-
atted =
|
173 |
-
|
174 |
-
|
175 |
-
self.args.hidden_size
|
176 |
)
|
177 |
atted = self.linear_merge(atted)
|
178 |
|
@@ -181,9 +180,7 @@ class MHAtt(nn.Module):
|
|
181 |
def att(self, value, key, query, mask):
|
182 |
d_k = query.size(-1)
|
183 |
|
184 |
-
scores = torch.matmul(
|
185 |
-
query, key.transpose(-2, -1)
|
186 |
-
) / math.sqrt(d_k)
|
187 |
|
188 |
if mask is not None:
|
189 |
scores = scores.masked_fill(mask, -1e9)
|
@@ -198,6 +195,7 @@ class MHAtt(nn.Module):
|
|
198 |
# ---- Feed Forward Nets ----
|
199 |
# ---------------------------
|
200 |
|
|
|
201 |
class FFN(nn.Module):
|
202 |
def __init__(self, args):
|
203 |
super(FFN, self).__init__()
|
@@ -207,12 +205,13 @@ class FFN(nn.Module):
|
|
207 |
mid_size=args.ff_size,
|
208 |
out_size=args.hidden_size,
|
209 |
dropout_r=args.dropout_r,
|
210 |
-
use_relu=True
|
211 |
)
|
212 |
|
213 |
def forward(self, x):
|
214 |
return self.mlp(x)
|
215 |
|
|
|
216 |
# ---------------------------
|
217 |
# ---- FF + norm -----------
|
218 |
# ---------------------------
|
@@ -231,7 +230,6 @@ class FFAndNorm(nn.Module):
|
|
231 |
return x
|
232 |
|
233 |
|
234 |
-
|
235 |
class Block(nn.Module):
|
236 |
def __init__(self, args, i):
|
237 |
super(Block, self).__init__()
|
@@ -239,7 +237,7 @@ class Block(nn.Module):
|
|
239 |
self.sa1 = SA(args)
|
240 |
self.sa3 = SGA(args)
|
241 |
|
242 |
-
self.last =
|
243 |
if not self.last:
|
244 |
self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
|
245 |
self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
|
@@ -261,8 +259,7 @@ class Block(nn.Module):
|
|
261 |
ax = self.att_lang(x, x_mask)
|
262 |
ay = self.att_audio(y, y_mask)
|
263 |
|
264 |
-
return self.norm_l(x + self.dropout(ax)),
|
265 |
-
self.norm_i(y + self.dropout(ay))
|
266 |
|
267 |
|
268 |
class Model_LA(nn.Module):
|
@@ -273,8 +270,7 @@ class Model_LA(nn.Module):
|
|
273 |
|
274 |
# LSTM
|
275 |
self.embedding = nn.Embedding(
|
276 |
-
num_embeddings=vocab_size,
|
277 |
-
embedding_dim=args.word_embed_size
|
278 |
)
|
279 |
|
280 |
# Loading the GloVe embedding weights
|
@@ -284,7 +280,7 @@ class Model_LA(nn.Module):
|
|
284 |
input_size=args.word_embed_size,
|
285 |
hidden_size=args.hidden_size,
|
286 |
num_layers=1,
|
287 |
-
batch_first=True
|
288 |
)
|
289 |
|
290 |
# self.lstm_y = nn.LSTM(
|
@@ -301,7 +297,7 @@ class Model_LA(nn.Module):
|
|
301 |
self.enc_list = nn.ModuleList([Block(args, i) for i in range(args.layer)])
|
302 |
|
303 |
# Flattenting features before proj
|
304 |
-
self.attflat_img
|
305 |
self.attflat_lang = AttFlat(args, 1, merge=True)
|
306 |
|
307 |
# Classification layers
|
@@ -325,19 +321,13 @@ class Model_LA(nn.Module):
|
|
325 |
x_m, x_y = x_mask, y_mask
|
326 |
x, y = dec(x, x_m, y, x_y)
|
327 |
|
328 |
-
x = self.attflat_lang(
|
329 |
-
x,
|
330 |
-
None
|
331 |
-
)
|
332 |
|
333 |
-
y = self.attflat_img(
|
334 |
-
y,
|
335 |
-
None
|
336 |
-
)
|
337 |
|
338 |
# Classification layers
|
339 |
proj_feat = x + y
|
340 |
proj_feat = self.proj_norm(proj_feat)
|
341 |
ans = self.proj(proj_feat)
|
342 |
|
343 |
-
return ans
|
|
|
10 |
# ---------- Masking sequence --------
|
11 |
# ------------------------------------
|
12 |
def make_mask(feature):
|
13 |
+
return (torch.sum(torch.abs(feature), dim=-1) == 0).unsqueeze(1).unsqueeze(2)
|
14 |
+
|
|
|
|
|
15 |
|
16 |
# ------------------------------
|
17 |
# ---------- Flattening --------
|
|
|
29 |
mid_size=args.ff_size,
|
30 |
out_size=flat_glimpse,
|
31 |
dropout_r=args.dropout_r,
|
32 |
+
use_relu=True,
|
33 |
)
|
34 |
|
35 |
if self.merge:
|
36 |
self.linear_merge = nn.Linear(
|
37 |
+
args.hidden_size * flat_glimpse, args.hidden_size * 2
|
|
|
38 |
)
|
39 |
|
40 |
def forward(self, x, x_mask):
|
41 |
att = self.mlp(x)
|
42 |
if x_mask is not None:
|
43 |
+
att = att.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2), -1e9)
|
|
|
|
|
|
|
44 |
att = F.softmax(att, dim=1)
|
45 |
|
46 |
att_list = []
|
47 |
for i in range(self.flat_glimpse):
|
48 |
+
att_list.append(torch.sum(att[:, :, i : i + 1] * x, dim=1))
|
|
|
|
|
49 |
|
50 |
if self.merge:
|
51 |
x_atted = torch.cat(att_list, dim=1)
|
|
|
55 |
|
56 |
return torch.stack(att_list).transpose_(0, 1)
|
57 |
|
58 |
+
|
59 |
# ------------------------
|
60 |
# ---- Self Attention ----
|
61 |
# ------------------------
|
62 |
|
63 |
+
|
64 |
class SA(nn.Module):
|
65 |
def __init__(self, args):
|
66 |
super(SA, self).__init__()
|
|
|
75 |
self.norm2 = LayerNorm(args.hidden_size)
|
76 |
|
77 |
def forward(self, y, y_mask):
|
78 |
+
y = self.norm1(y + self.dropout1(self.mhatt(y, y, y, y_mask)))
|
|
|
|
|
79 |
|
80 |
+
y = self.norm2(y + self.dropout2(self.ffn(y)))
|
|
|
|
|
81 |
|
82 |
return y
|
83 |
|
|
|
86 |
# ---- Self Guided Attention ----
|
87 |
# -------------------------------
|
88 |
|
89 |
+
|
90 |
class SGA(nn.Module):
|
91 |
def __init__(self, args):
|
92 |
super(SGA, self).__init__()
|
|
|
105 |
self.norm3 = LayerNorm(args.hidden_size)
|
106 |
|
107 |
def forward(self, x, y, x_mask, y_mask):
|
108 |
+
x = self.norm1(x + self.dropout1(self.mhatt1(v=x, k=x, q=x, mask=x_mask)))
|
|
|
|
|
109 |
|
110 |
+
x = self.norm2(x + self.dropout2(self.mhatt2(v=y, k=y, q=x, mask=y_mask)))
|
|
|
|
|
111 |
|
112 |
+
x = self.norm3(x + self.dropout3(self.ffn(x)))
|
|
|
|
|
113 |
|
114 |
return x
|
115 |
|
116 |
+
|
117 |
# ------------------------------
|
118 |
# ---- Multi-Head Attention ----
|
119 |
# ------------------------------
|
120 |
|
121 |
+
|
122 |
class MHAtt(nn.Module):
|
123 |
def __init__(self, args):
|
124 |
super(MHAtt, self).__init__()
|
|
|
133 |
|
134 |
def forward(self, v, k, q, mask):
|
135 |
n_batches = q.size(0)
|
136 |
+
v = (
|
137 |
+
self.linear_v(v)
|
138 |
+
.view(
|
139 |
+
n_batches,
|
140 |
+
-1,
|
141 |
+
self.args.multi_head,
|
142 |
+
int(self.args.hidden_size / self.args.multi_head),
|
143 |
+
)
|
144 |
+
.transpose(1, 2)
|
145 |
+
)
|
146 |
+
|
147 |
+
k = (
|
148 |
+
self.linear_k(k)
|
149 |
+
.view(
|
150 |
+
n_batches,
|
151 |
+
-1,
|
152 |
+
self.args.multi_head,
|
153 |
+
int(self.args.hidden_size / self.args.multi_head),
|
154 |
+
)
|
155 |
+
.transpose(1, 2)
|
156 |
+
)
|
157 |
+
|
158 |
+
q = (
|
159 |
+
self.linear_q(q)
|
160 |
+
.view(
|
161 |
+
n_batches,
|
162 |
+
-1,
|
163 |
+
self.args.multi_head,
|
164 |
+
int(self.args.hidden_size / self.args.multi_head),
|
165 |
+
)
|
166 |
+
.transpose(1, 2)
|
167 |
+
)
|
168 |
|
169 |
atted = self.att(v, k, q, mask)
|
170 |
|
171 |
+
atted = (
|
172 |
+
atted.transpose(1, 2)
|
173 |
+
.contiguous()
|
174 |
+
.view(n_batches, -1, self.args.hidden_size)
|
175 |
)
|
176 |
atted = self.linear_merge(atted)
|
177 |
|
|
|
180 |
def att(self, value, key, query, mask):
|
181 |
d_k = query.size(-1)
|
182 |
|
183 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
|
|
|
|
|
184 |
|
185 |
if mask is not None:
|
186 |
scores = scores.masked_fill(mask, -1e9)
|
|
|
195 |
# ---- Feed Forward Nets ----
|
196 |
# ---------------------------
|
197 |
|
198 |
+
|
199 |
class FFN(nn.Module):
|
200 |
def __init__(self, args):
|
201 |
super(FFN, self).__init__()
|
|
|
205 |
mid_size=args.ff_size,
|
206 |
out_size=args.hidden_size,
|
207 |
dropout_r=args.dropout_r,
|
208 |
+
use_relu=True,
|
209 |
)
|
210 |
|
211 |
def forward(self, x):
|
212 |
return self.mlp(x)
|
213 |
|
214 |
+
|
215 |
# ---------------------------
|
216 |
# ---- FF + norm -----------
|
217 |
# ---------------------------
|
|
|
230 |
return x
|
231 |
|
232 |
|
|
|
233 |
class Block(nn.Module):
|
234 |
def __init__(self, args, i):
|
235 |
super(Block, self).__init__()
|
|
|
237 |
self.sa1 = SA(args)
|
238 |
self.sa3 = SGA(args)
|
239 |
|
240 |
+
self.last = i == args.layer - 1
|
241 |
if not self.last:
|
242 |
self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
|
243 |
self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
|
|
|
259 |
ax = self.att_lang(x, x_mask)
|
260 |
ay = self.att_audio(y, y_mask)
|
261 |
|
262 |
+
return self.norm_l(x + self.dropout(ax)), self.norm_i(y + self.dropout(ay))
|
|
|
263 |
|
264 |
|
265 |
class Model_LA(nn.Module):
|
|
|
270 |
|
271 |
# LSTM
|
272 |
self.embedding = nn.Embedding(
|
273 |
+
num_embeddings=vocab_size, embedding_dim=args.word_embed_size
|
|
|
274 |
)
|
275 |
|
276 |
# Loading the GloVe embedding weights
|
|
|
280 |
input_size=args.word_embed_size,
|
281 |
hidden_size=args.hidden_size,
|
282 |
num_layers=1,
|
283 |
+
batch_first=True,
|
284 |
)
|
285 |
|
286 |
# self.lstm_y = nn.LSTM(
|
|
|
297 |
self.enc_list = nn.ModuleList([Block(args, i) for i in range(args.layer)])
|
298 |
|
299 |
# Flattenting features before proj
|
300 |
+
self.attflat_img = AttFlat(args, 1, merge=True)
|
301 |
self.attflat_lang = AttFlat(args, 1, merge=True)
|
302 |
|
303 |
# Classification layers
|
|
|
321 |
x_m, x_y = x_mask, y_mask
|
322 |
x, y = dec(x, x_m, y, x_y)
|
323 |
|
324 |
+
x = self.attflat_lang(x, None)
|
|
|
|
|
|
|
325 |
|
326 |
+
y = self.attflat_img(y, None)
|
|
|
|
|
|
|
327 |
|
328 |
# Classification layers
|
329 |
proj_feat = x + y
|
330 |
proj_feat = self.proj_norm(proj_feat)
|
331 |
ans = self.proj(proj_feat)
|
332 |
|
333 |
+
return ans
|
model_LAV.py
CHANGED
@@ -10,10 +10,8 @@ from layers.layer_norm import LayerNorm
|
|
10 |
# ---------- Masking sequence --------
|
11 |
# ------------------------------------
|
12 |
def make_mask(feature):
|
13 |
-
return (torch.sum(
|
14 |
-
|
15 |
-
dim=-1
|
16 |
-
) == 0).unsqueeze(1).unsqueeze(2)
|
17 |
|
18 |
# ------------------------------
|
19 |
# ---------- Flattening --------
|
@@ -31,29 +29,23 @@ class AttFlat(nn.Module):
|
|
31 |
mid_size=args.ff_size,
|
32 |
out_size=flat_glimpse,
|
33 |
dropout_r=args.dropout_r,
|
34 |
-
use_relu=True
|
35 |
)
|
36 |
|
37 |
if self.merge:
|
38 |
self.linear_merge = nn.Linear(
|
39 |
-
args.hidden_size * flat_glimpse,
|
40 |
-
args.hidden_size * 2
|
41 |
)
|
42 |
|
43 |
def forward(self, x, x_mask):
|
44 |
att = self.mlp(x)
|
45 |
if x_mask is not None:
|
46 |
-
att = att.masked_fill(
|
47 |
-
x_mask.squeeze(1).squeeze(1).unsqueeze(2),
|
48 |
-
-1e9
|
49 |
-
)
|
50 |
att = F.softmax(att, dim=1)
|
51 |
|
52 |
att_list = []
|
53 |
for i in range(self.flat_glimpse):
|
54 |
-
att_list.append(
|
55 |
-
torch.sum(att[:, :, i: i + 1] * x, dim=1)
|
56 |
-
)
|
57 |
|
58 |
if self.merge:
|
59 |
x_atted = torch.cat(att_list, dim=1)
|
@@ -63,10 +55,12 @@ class AttFlat(nn.Module):
|
|
63 |
|
64 |
return torch.stack(att_list).transpose_(0, 1)
|
65 |
|
|
|
66 |
# ------------------------
|
67 |
# ---- Self Attention ----
|
68 |
# ------------------------
|
69 |
|
|
|
70 |
class SA(nn.Module):
|
71 |
def __init__(self, args):
|
72 |
super(SA, self).__init__()
|
@@ -81,13 +75,9 @@ class SA(nn.Module):
|
|
81 |
self.norm2 = LayerNorm(args.hidden_size)
|
82 |
|
83 |
def forward(self, y, y_mask):
|
84 |
-
y = self.norm1(y + self.dropout1(
|
85 |
-
self.mhatt(y, y, y, y_mask)
|
86 |
-
))
|
87 |
|
88 |
-
y = self.norm2(y + self.dropout2(
|
89 |
-
self.ffn(y)
|
90 |
-
))
|
91 |
|
92 |
return y
|
93 |
|
@@ -96,6 +86,7 @@ class SA(nn.Module):
|
|
96 |
# ---- Self Guided Attention ----
|
97 |
# -------------------------------
|
98 |
|
|
|
99 |
class SGA(nn.Module):
|
100 |
def __init__(self, args):
|
101 |
super(SGA, self).__init__()
|
@@ -114,24 +105,20 @@ class SGA(nn.Module):
|
|
114 |
self.norm3 = LayerNorm(args.hidden_size)
|
115 |
|
116 |
def forward(self, x, y, x_mask, y_mask):
|
117 |
-
x = self.norm1(x + self.dropout1(
|
118 |
-
self.mhatt1(v=x, k=x, q=x, mask=x_mask)
|
119 |
-
))
|
120 |
|
121 |
-
x = self.norm2(x + self.dropout2(
|
122 |
-
self.mhatt2(v=y, k=y, q=x, mask=y_mask)
|
123 |
-
))
|
124 |
|
125 |
-
x = self.norm3(x + self.dropout3(
|
126 |
-
self.ffn(x)
|
127 |
-
))
|
128 |
|
129 |
return x
|
130 |
|
|
|
131 |
# ------------------------------
|
132 |
# ---- Multi-Head Attention ----
|
133 |
# ------------------------------
|
134 |
|
|
|
135 |
class MHAtt(nn.Module):
|
136 |
def __init__(self, args):
|
137 |
super(MHAtt, self).__init__()
|
@@ -146,33 +133,45 @@ class MHAtt(nn.Module):
|
|
146 |
|
147 |
def forward(self, v, k, q, mask):
|
148 |
n_batches = q.size(0)
|
149 |
-
v =
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
atted = self.att(v, k, q, mask)
|
171 |
|
172 |
-
atted =
|
173 |
-
|
174 |
-
|
175 |
-
self.args.hidden_size
|
176 |
)
|
177 |
atted = self.linear_merge(atted)
|
178 |
|
@@ -181,9 +180,7 @@ class MHAtt(nn.Module):
|
|
181 |
def att(self, value, key, query, mask):
|
182 |
d_k = query.size(-1)
|
183 |
|
184 |
-
scores = torch.matmul(
|
185 |
-
query, key.transpose(-2, -1)
|
186 |
-
) / math.sqrt(d_k)
|
187 |
|
188 |
if mask is not None:
|
189 |
scores = scores.masked_fill(mask, -1e9)
|
@@ -198,6 +195,7 @@ class MHAtt(nn.Module):
|
|
198 |
# ---- Feed Forward Nets ----
|
199 |
# ---------------------------
|
200 |
|
|
|
201 |
class FFN(nn.Module):
|
202 |
def __init__(self, args):
|
203 |
super(FFN, self).__init__()
|
@@ -207,12 +205,13 @@ class FFN(nn.Module):
|
|
207 |
mid_size=args.ff_size,
|
208 |
out_size=args.hidden_size,
|
209 |
dropout_r=args.dropout_r,
|
210 |
-
use_relu=True
|
211 |
)
|
212 |
|
213 |
def forward(self, x):
|
214 |
return self.mlp(x)
|
215 |
|
|
|
216 |
# ---------------------------
|
217 |
# ---- FF + norm -----------
|
218 |
# ---------------------------
|
@@ -231,7 +230,6 @@ class FFAndNorm(nn.Module):
|
|
231 |
return x
|
232 |
|
233 |
|
234 |
-
|
235 |
class Block(nn.Module):
|
236 |
def __init__(self, args, i):
|
237 |
super(Block, self).__init__()
|
@@ -240,7 +238,7 @@ class Block(nn.Module):
|
|
240 |
self.sa2 = SGA(args)
|
241 |
self.sa3 = SGA(args)
|
242 |
|
243 |
-
self.last =
|
244 |
if not self.last:
|
245 |
self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
|
246 |
self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
|
@@ -267,10 +265,11 @@ class Block(nn.Module):
|
|
267 |
ay = self.att_audio(y, y_mask)
|
268 |
az = self.att_vid(z, y_mask)
|
269 |
|
270 |
-
return
|
271 |
-
|
272 |
-
|
273 |
-
|
|
|
274 |
|
275 |
|
276 |
class Model_LAV(nn.Module):
|
@@ -281,8 +280,7 @@ class Model_LAV(nn.Module):
|
|
281 |
|
282 |
# LSTM
|
283 |
self.embedding = nn.Embedding(
|
284 |
-
num_embeddings=vocab_size,
|
285 |
-
embedding_dim=args.word_embed_size
|
286 |
)
|
287 |
|
288 |
# Loading the GloVe embedding weights
|
@@ -292,7 +290,7 @@ class Model_LAV(nn.Module):
|
|
292 |
input_size=args.word_embed_size,
|
293 |
hidden_size=args.hidden_size,
|
294 |
num_layers=1,
|
295 |
-
batch_first=True
|
296 |
)
|
297 |
|
298 |
# self.lstm_y = nn.LSTM(
|
@@ -310,8 +308,8 @@ class Model_LAV(nn.Module):
|
|
310 |
self.enc_list = nn.ModuleList([Block(args, i) for i in range(args.layer)])
|
311 |
|
312 |
# Flattenting features before proj
|
313 |
-
self.attflat_ac
|
314 |
-
self.attflat_vid
|
315 |
self.attflat_lang = AttFlat(args, 1, merge=True)
|
316 |
|
317 |
# Classification layers
|
@@ -329,7 +327,6 @@ class Model_LAV(nn.Module):
|
|
329 |
y_mask = make_mask(y)
|
330 |
z_mask = make_mask(z)
|
331 |
|
332 |
-
|
333 |
embedding = self.embedding(x)
|
334 |
|
335 |
x, _ = self.lstm_x(embedding)
|
@@ -343,25 +340,15 @@ class Model_LAV(nn.Module):
|
|
343 |
x_m, y_m, z_m = x_mask, y_mask, z_mask
|
344 |
x, y, z = dec(x, x_m, y, y_m, z, z_m)
|
345 |
|
346 |
-
x = self.attflat_lang(
|
347 |
-
x,
|
348 |
-
None
|
349 |
-
)
|
350 |
-
|
351 |
-
y = self.attflat_ac(
|
352 |
-
y,
|
353 |
-
None
|
354 |
-
)
|
355 |
|
356 |
-
|
357 |
-
z,
|
358 |
-
None
|
359 |
-
)
|
360 |
|
|
|
361 |
|
362 |
# Classification layers
|
363 |
proj_feat = x + y + z
|
364 |
proj_feat = self.proj_norm(proj_feat)
|
365 |
ans = self.proj(proj_feat)
|
366 |
|
367 |
-
return ans
|
|
|
10 |
# ---------- Masking sequence --------
|
11 |
# ------------------------------------
|
12 |
def make_mask(feature):
|
13 |
+
return (torch.sum(torch.abs(feature), dim=-1) == 0).unsqueeze(1).unsqueeze(2)
|
14 |
+
|
|
|
|
|
15 |
|
16 |
# ------------------------------
|
17 |
# ---------- Flattening --------
|
|
|
29 |
mid_size=args.ff_size,
|
30 |
out_size=flat_glimpse,
|
31 |
dropout_r=args.dropout_r,
|
32 |
+
use_relu=True,
|
33 |
)
|
34 |
|
35 |
if self.merge:
|
36 |
self.linear_merge = nn.Linear(
|
37 |
+
args.hidden_size * flat_glimpse, args.hidden_size * 2
|
|
|
38 |
)
|
39 |
|
40 |
def forward(self, x, x_mask):
|
41 |
att = self.mlp(x)
|
42 |
if x_mask is not None:
|
43 |
+
att = att.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2), -1e9)
|
|
|
|
|
|
|
44 |
att = F.softmax(att, dim=1)
|
45 |
|
46 |
att_list = []
|
47 |
for i in range(self.flat_glimpse):
|
48 |
+
att_list.append(torch.sum(att[:, :, i : i + 1] * x, dim=1))
|
|
|
|
|
49 |
|
50 |
if self.merge:
|
51 |
x_atted = torch.cat(att_list, dim=1)
|
|
|
55 |
|
56 |
return torch.stack(att_list).transpose_(0, 1)
|
57 |
|
58 |
+
|
59 |
# ------------------------
|
60 |
# ---- Self Attention ----
|
61 |
# ------------------------
|
62 |
|
63 |
+
|
64 |
class SA(nn.Module):
|
65 |
def __init__(self, args):
|
66 |
super(SA, self).__init__()
|
|
|
75 |
self.norm2 = LayerNorm(args.hidden_size)
|
76 |
|
77 |
def forward(self, y, y_mask):
|
78 |
+
y = self.norm1(y + self.dropout1(self.mhatt(y, y, y, y_mask)))
|
|
|
|
|
79 |
|
80 |
+
y = self.norm2(y + self.dropout2(self.ffn(y)))
|
|
|
|
|
81 |
|
82 |
return y
|
83 |
|
|
|
86 |
# ---- Self Guided Attention ----
|
87 |
# -------------------------------
|
88 |
|
89 |
+
|
90 |
class SGA(nn.Module):
|
91 |
def __init__(self, args):
|
92 |
super(SGA, self).__init__()
|
|
|
105 |
self.norm3 = LayerNorm(args.hidden_size)
|
106 |
|
107 |
def forward(self, x, y, x_mask, y_mask):
|
108 |
+
x = self.norm1(x + self.dropout1(self.mhatt1(v=x, k=x, q=x, mask=x_mask)))
|
|
|
|
|
109 |
|
110 |
+
x = self.norm2(x + self.dropout2(self.mhatt2(v=y, k=y, q=x, mask=y_mask)))
|
|
|
|
|
111 |
|
112 |
+
x = self.norm3(x + self.dropout3(self.ffn(x)))
|
|
|
|
|
113 |
|
114 |
return x
|
115 |
|
116 |
+
|
117 |
# ------------------------------
|
118 |
# ---- Multi-Head Attention ----
|
119 |
# ------------------------------
|
120 |
|
121 |
+
|
122 |
class MHAtt(nn.Module):
|
123 |
def __init__(self, args):
|
124 |
super(MHAtt, self).__init__()
|
|
|
133 |
|
134 |
def forward(self, v, k, q, mask):
|
135 |
n_batches = q.size(0)
|
136 |
+
v = (
|
137 |
+
self.linear_v(v)
|
138 |
+
.view(
|
139 |
+
n_batches,
|
140 |
+
-1,
|
141 |
+
self.args.multi_head,
|
142 |
+
int(self.args.hidden_size / self.args.multi_head),
|
143 |
+
)
|
144 |
+
.transpose(1, 2)
|
145 |
+
)
|
146 |
+
|
147 |
+
k = (
|
148 |
+
self.linear_k(k)
|
149 |
+
.view(
|
150 |
+
n_batches,
|
151 |
+
-1,
|
152 |
+
self.args.multi_head,
|
153 |
+
int(self.args.hidden_size / self.args.multi_head),
|
154 |
+
)
|
155 |
+
.transpose(1, 2)
|
156 |
+
)
|
157 |
+
|
158 |
+
q = (
|
159 |
+
self.linear_q(q)
|
160 |
+
.view(
|
161 |
+
n_batches,
|
162 |
+
-1,
|
163 |
+
self.args.multi_head,
|
164 |
+
int(self.args.hidden_size / self.args.multi_head),
|
165 |
+
)
|
166 |
+
.transpose(1, 2)
|
167 |
+
)
|
168 |
|
169 |
atted = self.att(v, k, q, mask)
|
170 |
|
171 |
+
atted = (
|
172 |
+
atted.transpose(1, 2)
|
173 |
+
.contiguous()
|
174 |
+
.view(n_batches, -1, self.args.hidden_size)
|
175 |
)
|
176 |
atted = self.linear_merge(atted)
|
177 |
|
|
|
180 |
def att(self, value, key, query, mask):
|
181 |
d_k = query.size(-1)
|
182 |
|
183 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
|
|
|
|
|
184 |
|
185 |
if mask is not None:
|
186 |
scores = scores.masked_fill(mask, -1e9)
|
|
|
195 |
# ---- Feed Forward Nets ----
|
196 |
# ---------------------------
|
197 |
|
198 |
+
|
199 |
class FFN(nn.Module):
|
200 |
def __init__(self, args):
|
201 |
super(FFN, self).__init__()
|
|
|
205 |
mid_size=args.ff_size,
|
206 |
out_size=args.hidden_size,
|
207 |
dropout_r=args.dropout_r,
|
208 |
+
use_relu=True,
|
209 |
)
|
210 |
|
211 |
def forward(self, x):
|
212 |
return self.mlp(x)
|
213 |
|
214 |
+
|
215 |
# ---------------------------
|
216 |
# ---- FF + norm -----------
|
217 |
# ---------------------------
|
|
|
230 |
return x
|
231 |
|
232 |
|
|
|
233 |
class Block(nn.Module):
|
234 |
def __init__(self, args, i):
|
235 |
super(Block, self).__init__()
|
|
|
238 |
self.sa2 = SGA(args)
|
239 |
self.sa3 = SGA(args)
|
240 |
|
241 |
+
self.last = i == args.layer - 1
|
242 |
if not self.last:
|
243 |
self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
|
244 |
self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
|
|
|
265 |
ay = self.att_audio(y, y_mask)
|
266 |
az = self.att_vid(z, y_mask)
|
267 |
|
268 |
+
return (
|
269 |
+
self.norm_l(x + self.dropout(ax)),
|
270 |
+
self.norm_a(y + self.dropout(ay)),
|
271 |
+
self.norm_v(z + self.dropout(az)),
|
272 |
+
)
|
273 |
|
274 |
|
275 |
class Model_LAV(nn.Module):
|
|
|
280 |
|
281 |
# LSTM
|
282 |
self.embedding = nn.Embedding(
|
283 |
+
num_embeddings=vocab_size, embedding_dim=args.word_embed_size
|
|
|
284 |
)
|
285 |
|
286 |
# Loading the GloVe embedding weights
|
|
|
290 |
input_size=args.word_embed_size,
|
291 |
hidden_size=args.hidden_size,
|
292 |
num_layers=1,
|
293 |
+
batch_first=True,
|
294 |
)
|
295 |
|
296 |
# self.lstm_y = nn.LSTM(
|
|
|
308 |
self.enc_list = nn.ModuleList([Block(args, i) for i in range(args.layer)])
|
309 |
|
310 |
# Flattenting features before proj
|
311 |
+
self.attflat_ac = AttFlat(args, 1, merge=True)
|
312 |
+
self.attflat_vid = AttFlat(args, 1, merge=True)
|
313 |
self.attflat_lang = AttFlat(args, 1, merge=True)
|
314 |
|
315 |
# Classification layers
|
|
|
327 |
y_mask = make_mask(y)
|
328 |
z_mask = make_mask(z)
|
329 |
|
|
|
330 |
embedding = self.embedding(x)
|
331 |
|
332 |
x, _ = self.lstm_x(embedding)
|
|
|
340 |
x_m, y_m, z_m = x_mask, y_mask, z_mask
|
341 |
x, y, z = dec(x, x_m, y, y_m, z, z_m)
|
342 |
|
343 |
+
x = self.attflat_lang(x, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
+
y = self.attflat_ac(y, None)
|
|
|
|
|
|
|
346 |
|
347 |
+
z = self.attflat_vid(z, None)
|
348 |
|
349 |
# Classification layers
|
350 |
proj_feat = x + y + z
|
351 |
proj_feat = self.proj_norm(proj_feat)
|
352 |
ans = self.proj(proj_feat)
|
353 |
|
354 |
+
return ans
|
utils/audio.py
CHANGED
@@ -1,24 +1,26 @@
|
|
1 |
# -*- coding: utf-8 -*-
|
2 |
-
|
3 |
-
|
4 |
By kyubyong park. [email protected].
|
5 |
https://www.github.com/kyubyong/dc_tts
|
6 |
-
|
7 |
from __future__ import print_function, division
|
8 |
|
9 |
import numpy as np
|
10 |
import librosa
|
11 |
import os, copy
|
12 |
import matplotlib
|
13 |
-
|
|
|
14 |
import matplotlib.pyplot as plt
|
15 |
from scipy import signal
|
16 |
|
17 |
from .audio_params import Hyperparams as hp
|
18 |
import tensorflow as tf
|
19 |
|
|
|
20 |
def get_spectrograms(fpath):
|
21 |
-
|
22 |
Returns normalized melspectrogram and linear spectrogram.
|
23 |
|
24 |
Args:
|
@@ -27,7 +29,7 @@ def get_spectrograms(fpath):
|
|
27 |
Returns:
|
28 |
mel: A 2d array of shape (T, n_mels) and dtype of float32.
|
29 |
mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32.
|
30 |
-
|
31 |
# Loading sound file
|
32 |
y, sr = librosa.load(fpath, sr=hp.sr)
|
33 |
|
@@ -38,10 +40,9 @@ def get_spectrograms(fpath):
|
|
38 |
y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1])
|
39 |
|
40 |
# stft
|
41 |
-
linear = librosa.stft(
|
42 |
-
|
43 |
-
|
44 |
-
win_length=hp.win_length)
|
45 |
|
46 |
# magnitude spectrogram
|
47 |
mag = np.abs(linear) # (1+n_fft//2, T)
|
@@ -64,15 +65,16 @@ def get_spectrograms(fpath):
|
|
64 |
|
65 |
return mel, mag
|
66 |
|
|
|
67 |
def spectrogram2wav(mag):
|
68 |
-
|
69 |
|
70 |
Args:
|
71 |
mag: A numpy array of (T, 1+n_fft//2)
|
72 |
|
73 |
Returns:
|
74 |
wav: A 1-D numpy array.
|
75 |
-
|
76 |
# transpose
|
77 |
mag = mag.T
|
78 |
|
@@ -83,7 +85,7 @@ def spectrogram2wav(mag):
|
|
83 |
mag = np.power(10.0, mag * 0.05)
|
84 |
|
85 |
# wav reconstruction
|
86 |
-
wav = griffin_lim(mag**hp.power)
|
87 |
|
88 |
# de-preemphasis
|
89 |
wav = signal.lfilter([1], [1, -hp.preemphasis], wav)
|
@@ -93,8 +95,9 @@ def spectrogram2wav(mag):
|
|
93 |
|
94 |
return wav.astype(np.float32)
|
95 |
|
|
|
96 |
def griffin_lim(spectrogram):
|
97 |
-
|
98 |
X_best = copy.deepcopy(spectrogram)
|
99 |
for i in range(hp.n_iter):
|
100 |
X_t = invert_spectrogram(X_best)
|
@@ -106,12 +109,16 @@ def griffin_lim(spectrogram):
|
|
106 |
|
107 |
return y
|
108 |
|
|
|
109 |
def invert_spectrogram(spectrogram):
|
110 |
-
|
111 |
Args:
|
112 |
spectrogram: [1+n_fft//2, t]
|
113 |
-
|
114 |
-
return librosa.istft(
|
|
|
|
|
|
|
115 |
|
116 |
def plot_alignment(alignment, gs, dir=hp.logdir):
|
117 |
"""Plots the alignment.
|
@@ -121,32 +128,43 @@ def plot_alignment(alignment, gs, dir=hp.logdir):
|
|
121 |
gs: (int) global step.
|
122 |
dir: Output path.
|
123 |
"""
|
124 |
-
if not os.path.exists(dir):
|
|
|
125 |
|
126 |
fig, ax = plt.subplots()
|
127 |
im = ax.imshow(alignment)
|
128 |
|
129 |
fig.colorbar(im)
|
130 |
-
plt.title(
|
131 |
-
plt.savefig(
|
132 |
plt.close(fig)
|
133 |
|
|
|
134 |
def guided_attention(g=0.2):
|
135 |
-
|
136 |
W = np.zeros((hp.max_N, hp.max_T), dtype=np.float32)
|
137 |
for n_pos in range(W.shape[0]):
|
138 |
for t_pos in range(W.shape[1]):
|
139 |
-
W[n_pos, t_pos] = 1 - np.exp(
|
|
|
|
|
|
|
140 |
return W
|
141 |
|
142 |
-
|
143 |
-
|
|
|
144 |
step = tf.to_float(global_step + 1)
|
145 |
-
return
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
def load_spectrograms(fpath):
|
148 |
-
|
149 |
-
and extracts spectrograms
|
150 |
|
151 |
fname = os.path.basename(fpath)
|
152 |
mel, mag = get_spectrograms(fpath)
|
@@ -158,6 +176,5 @@ def load_spectrograms(fpath):
|
|
158 |
mag = np.pad(mag, [[0, num_paddings], [0, 0]], mode="constant")
|
159 |
|
160 |
# Reduction
|
161 |
-
mel = mel[::hp.r, :]
|
162 |
return fname, mel, mag
|
163 |
-
|
|
|
1 |
# -*- coding: utf-8 -*-
|
2 |
+
# /usr/bin/python2
|
3 |
+
"""
|
4 |
By kyubyong park. [email protected].
|
5 |
https://www.github.com/kyubyong/dc_tts
|
6 |
+
"""
|
7 |
from __future__ import print_function, division
|
8 |
|
9 |
import numpy as np
|
10 |
import librosa
|
11 |
import os, copy
|
12 |
import matplotlib
|
13 |
+
|
14 |
+
matplotlib.use("pdf")
|
15 |
import matplotlib.pyplot as plt
|
16 |
from scipy import signal
|
17 |
|
18 |
from .audio_params import Hyperparams as hp
|
19 |
import tensorflow as tf
|
20 |
|
21 |
+
|
22 |
def get_spectrograms(fpath):
|
23 |
+
"""Parse the wave file in `fpath` and
|
24 |
Returns normalized melspectrogram and linear spectrogram.
|
25 |
|
26 |
Args:
|
|
|
29 |
Returns:
|
30 |
mel: A 2d array of shape (T, n_mels) and dtype of float32.
|
31 |
mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32.
|
32 |
+
"""
|
33 |
# Loading sound file
|
34 |
y, sr = librosa.load(fpath, sr=hp.sr)
|
35 |
|
|
|
40 |
y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1])
|
41 |
|
42 |
# stft
|
43 |
+
linear = librosa.stft(
|
44 |
+
y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length
|
45 |
+
)
|
|
|
46 |
|
47 |
# magnitude spectrogram
|
48 |
mag = np.abs(linear) # (1+n_fft//2, T)
|
|
|
65 |
|
66 |
return mel, mag
|
67 |
|
68 |
+
|
69 |
def spectrogram2wav(mag):
|
70 |
+
"""# Generate wave file from linear magnitude spectrogram
|
71 |
|
72 |
Args:
|
73 |
mag: A numpy array of (T, 1+n_fft//2)
|
74 |
|
75 |
Returns:
|
76 |
wav: A 1-D numpy array.
|
77 |
+
"""
|
78 |
# transpose
|
79 |
mag = mag.T
|
80 |
|
|
|
85 |
mag = np.power(10.0, mag * 0.05)
|
86 |
|
87 |
# wav reconstruction
|
88 |
+
wav = griffin_lim(mag ** hp.power)
|
89 |
|
90 |
# de-preemphasis
|
91 |
wav = signal.lfilter([1], [1, -hp.preemphasis], wav)
|
|
|
95 |
|
96 |
return wav.astype(np.float32)
|
97 |
|
98 |
+
|
99 |
def griffin_lim(spectrogram):
|
100 |
+
"""Applies Griffin-Lim's raw."""
|
101 |
X_best = copy.deepcopy(spectrogram)
|
102 |
for i in range(hp.n_iter):
|
103 |
X_t = invert_spectrogram(X_best)
|
|
|
109 |
|
110 |
return y
|
111 |
|
112 |
+
|
113 |
def invert_spectrogram(spectrogram):
|
114 |
+
"""Applies inverse fft.
|
115 |
Args:
|
116 |
spectrogram: [1+n_fft//2, t]
|
117 |
+
"""
|
118 |
+
return librosa.istft(
|
119 |
+
spectrogram, hp.hop_length, win_length=hp.win_length, window="hann"
|
120 |
+
)
|
121 |
+
|
122 |
|
123 |
def plot_alignment(alignment, gs, dir=hp.logdir):
|
124 |
"""Plots the alignment.
|
|
|
128 |
gs: (int) global step.
|
129 |
dir: Output path.
|
130 |
"""
|
131 |
+
if not os.path.exists(dir):
|
132 |
+
os.mkdir(dir)
|
133 |
|
134 |
fig, ax = plt.subplots()
|
135 |
im = ax.imshow(alignment)
|
136 |
|
137 |
fig.colorbar(im)
|
138 |
+
plt.title("{} Steps".format(gs))
|
139 |
+
plt.savefig("{}/alignment_{}.png".format(dir, gs), format="png")
|
140 |
plt.close(fig)
|
141 |
|
142 |
+
|
143 |
def guided_attention(g=0.2):
|
144 |
+
"""Guided attention. Refer to page 3 on the paper."""
|
145 |
W = np.zeros((hp.max_N, hp.max_T), dtype=np.float32)
|
146 |
for n_pos in range(W.shape[0]):
|
147 |
for t_pos in range(W.shape[1]):
|
148 |
+
W[n_pos, t_pos] = 1 - np.exp(
|
149 |
+
-((t_pos / float(hp.max_T) - n_pos / float(hp.max_N)) ** 2)
|
150 |
+
/ (2 * g * g)
|
151 |
+
)
|
152 |
return W
|
153 |
|
154 |
+
|
155 |
+
def learning_rate_decay(init_lr, global_step, warmup_steps=4000.0):
|
156 |
+
"""Noam scheme from tensor2tensor"""
|
157 |
step = tf.to_float(global_step + 1)
|
158 |
+
return (
|
159 |
+
init_lr
|
160 |
+
* warmup_steps ** 0.5
|
161 |
+
* tf.minimum(step * warmup_steps ** -1.5, step ** -0.5)
|
162 |
+
)
|
163 |
+
|
164 |
|
165 |
def load_spectrograms(fpath):
|
166 |
+
"""Read the wave file in `fpath`
|
167 |
+
and extracts spectrograms"""
|
168 |
|
169 |
fname = os.path.basename(fpath)
|
170 |
mel, mag = get_spectrograms(fpath)
|
|
|
176 |
mag = np.pad(mag, [[0, num_paddings], [0, 0]], mode="constant")
|
177 |
|
178 |
# Reduction
|
179 |
+
mel = mel[:: hp.r, :]
|
180 |
return fname, mel, mag
|
|
utils/audio_params.py
CHANGED
@@ -1,14 +1,19 @@
|
|
1 |
# -*- coding: utf-8 -*-
|
2 |
-
|
3 |
-
|
4 |
By kyubyong park. [email protected].
|
5 |
https://www.github.com/kyubyong/dc_tts
|
6 |
-
|
|
|
|
|
7 |
class Hyperparams:
|
8 |
-
|
|
|
9 |
# pipeline
|
10 |
-
prepro =
|
11 |
-
|
|
|
|
|
12 |
# signal processing
|
13 |
sr = 22050 # Sampling rate.
|
14 |
n_fft = 2048 # fft points (samples)
|
@@ -19,29 +24,29 @@ class Hyperparams:
|
|
19 |
n_mels = 80 # Number of Mel banks to generate
|
20 |
power = 1.5 # Exponent for amplifying the predicted magnitude
|
21 |
n_iter = 50 # Number of inversion iterations
|
22 |
-
preemphasis = .97
|
23 |
max_db = 100
|
24 |
ref_db = 20
|
25 |
|
26 |
# Model
|
27 |
-
r = 4
|
28 |
dropout_rate = 0.05
|
29 |
-
e = 128
|
30 |
-
d = 256
|
31 |
-
c = 512
|
32 |
attention_win_size = 3
|
33 |
|
34 |
# data
|
35 |
data = "/data/private/voice/LJSpeech-1.0"
|
36 |
# data = "/data/private/voice/kate"
|
37 |
-
test_data =
|
38 |
-
vocab = "PE abcdefghijklmnopqrstuvwxyz'.?"
|
39 |
-
max_N = 180
|
40 |
-
max_T = 210
|
41 |
|
42 |
# training scheme
|
43 |
-
lr = 0.001
|
44 |
logdir = "logdir/LJ01"
|
45 |
-
sampledir =
|
46 |
-
B = 32
|
47 |
num_iterations = 2000000
|
|
|
1 |
# -*- coding: utf-8 -*-
|
2 |
+
# /usr/bin/python2
|
3 |
+
"""
|
4 |
By kyubyong park. [email protected].
|
5 |
https://www.github.com/kyubyong/dc_tts
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
class Hyperparams:
|
10 |
+
"""Hyper parameters"""
|
11 |
+
|
12 |
# pipeline
|
13 |
+
prepro = (
|
14 |
+
True # if True, run `python prepro.py` first before running `python train.py`.
|
15 |
+
)
|
16 |
+
|
17 |
# signal processing
|
18 |
sr = 22050 # Sampling rate.
|
19 |
n_fft = 2048 # fft points (samples)
|
|
|
24 |
n_mels = 80 # Number of Mel banks to generate
|
25 |
power = 1.5 # Exponent for amplifying the predicted magnitude
|
26 |
n_iter = 50 # Number of inversion iterations
|
27 |
+
preemphasis = 0.97
|
28 |
max_db = 100
|
29 |
ref_db = 20
|
30 |
|
31 |
# Model
|
32 |
+
r = 4 # Reduction factor. Do not change this.
|
33 |
dropout_rate = 0.05
|
34 |
+
e = 128 # == embedding
|
35 |
+
d = 256 # == hidden units of Text2Mel
|
36 |
+
c = 512 # == hidden units of SSRN
|
37 |
attention_win_size = 3
|
38 |
|
39 |
# data
|
40 |
data = "/data/private/voice/LJSpeech-1.0"
|
41 |
# data = "/data/private/voice/kate"
|
42 |
+
test_data = "harvard_sentences.txt"
|
43 |
+
vocab = "PE abcdefghijklmnopqrstuvwxyz'.?" # P: Padding, E: EOS.
|
44 |
+
max_N = 180 # Maximum number of characters.
|
45 |
+
max_T = 210 # Maximum number of mel frames.
|
46 |
|
47 |
# training scheme
|
48 |
+
lr = 0.001 # Initial learning rate.
|
49 |
logdir = "logdir/LJ01"
|
50 |
+
sampledir = "samples"
|
51 |
+
B = 32 # batch size
|
52 |
num_iterations = 2000000
|
utils/compute_args.py
CHANGED
@@ -3,26 +3,39 @@ import torch
|
|
3 |
|
4 |
def compute_args(args):
|
5 |
# DataLoader
|
6 |
-
if not hasattr(args,
|
7 |
-
args.dataset =
|
8 |
|
9 |
-
if args.dataset == "MOSEI":
|
10 |
-
|
|
|
|
|
11 |
|
12 |
# Loss function to use
|
13 |
-
if args.dataset ==
|
14 |
-
|
15 |
-
if args.dataset ==
|
|
|
|
|
|
|
16 |
|
17 |
# Answer size
|
18 |
-
if args.dataset ==
|
19 |
-
|
20 |
-
if args.dataset ==
|
21 |
-
|
22 |
-
if args.dataset ==
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
if args.dataset ==
|
25 |
-
|
26 |
-
if args.dataset ==
|
|
|
|
|
|
|
27 |
|
28 |
return args
|
|
|
3 |
|
4 |
def compute_args(args):
|
5 |
# DataLoader
|
6 |
+
if not hasattr(args, "dataset"): # fix for previous version
|
7 |
+
args.dataset = "MOSEI"
|
8 |
|
9 |
+
if args.dataset == "MOSEI":
|
10 |
+
args.dataloader = "Mosei_Dataset"
|
11 |
+
if args.dataset == "MELD":
|
12 |
+
args.dataloader = "Meld_Dataset"
|
13 |
|
14 |
# Loss function to use
|
15 |
+
if args.dataset == "MOSEI" and args.task == "sentiment":
|
16 |
+
args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
|
17 |
+
if args.dataset == "MOSEI" and args.task == "emotion":
|
18 |
+
args.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="sum")
|
19 |
+
if args.dataset == "MELD":
|
20 |
+
args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
|
21 |
|
22 |
# Answer size
|
23 |
+
if args.dataset == "MOSEI" and args.task == "sentiment":
|
24 |
+
args.ans_size = 7
|
25 |
+
if args.dataset == "MOSEI" and args.task == "sentiment" and args.task_binary:
|
26 |
+
args.ans_size = 2
|
27 |
+
if args.dataset == "MOSEI" and args.task == "emotion":
|
28 |
+
args.ans_size = 6
|
29 |
+
if args.dataset == "MELD" and args.task == "emotion":
|
30 |
+
args.ans_size = 7
|
31 |
+
if args.dataset == "MELD" and args.task == "sentiment":
|
32 |
+
args.ans_size = 3
|
33 |
|
34 |
+
if args.dataset == "MOSEI":
|
35 |
+
args.pred_func = "amax"
|
36 |
+
if args.dataset == "MOSEI" and args.task == "emotion":
|
37 |
+
args.pred_func = "multi_label"
|
38 |
+
if args.dataset == "MELD":
|
39 |
+
args.pred_func = "amax"
|
40 |
|
41 |
return args
|
utils/plot.py
CHANGED
@@ -10,4 +10,4 @@
|
|
10 |
# maxfreq = n.max()
|
11 |
# # Set a clean upper y-axis limit.
|
12 |
# plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10)
|
13 |
-
# plt.show()
|
|
|
10 |
# maxfreq = n.max()
|
11 |
# # Set a clean upper y-axis limit.
|
12 |
# plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10)
|
13 |
+
# plt.show()
|
utils/pred_func.py
CHANGED
@@ -6,4 +6,4 @@ def amax(x):
|
|
6 |
|
7 |
|
8 |
def multi_label(x):
|
9 |
-
return
|
|
|
6 |
|
7 |
|
8 |
def multi_label(x):
|
9 |
+
return x > 0
|
utils/tokenize.py
CHANGED
@@ -6,38 +6,37 @@ import numpy as np
|
|
6 |
import os
|
7 |
import pickle
|
8 |
|
|
|
9 |
def clean(w):
|
10 |
-
return
|
11 |
-
|
12 |
-
|
13 |
-
w.lower()
|
14 |
-
).replace('-', ' ').replace('/', ' ')
|
15 |
|
16 |
|
17 |
def tokenize(key_to_word):
|
18 |
key_to_sentence = {}
|
19 |
for k, v in key_to_word.items():
|
20 |
-
key_to_sentence[k] = [clean(w) for w in v if clean(w) !=
|
21 |
return key_to_sentence
|
22 |
|
23 |
|
24 |
def create_dict(key_to_sentence, dataroot, use_glove=True):
|
25 |
-
token_file = dataroot+"/token_to_ix.pkl"
|
26 |
-
glove_file = dataroot+"/train_glove.npy"
|
27 |
if os.path.exists(glove_file) and os.path.exists(token_file):
|
28 |
print("Loading train language files")
|
29 |
return pickle.load(open(token_file, "rb")), np.load(glove_file)
|
30 |
|
31 |
print("Creating train language files")
|
32 |
token_to_ix = {
|
33 |
-
|
34 |
}
|
35 |
|
36 |
spacy_tool = None
|
37 |
pretrained_emb = []
|
38 |
if use_glove:
|
39 |
spacy_tool = en_vectors_web_lg.load()
|
40 |
-
pretrained_emb.append(spacy_tool(
|
41 |
|
42 |
for k, v in key_to_sentence.items():
|
43 |
for word in v:
|
@@ -51,6 +50,7 @@ def create_dict(key_to_sentence, dataroot, use_glove=True):
|
|
51 |
pickle.dump(token_to_ix, open(token_file, "wb"))
|
52 |
return token_to_ix, pretrained_emb
|
53 |
|
|
|
54 |
def sent_to_ix(s, token_to_ix, max_token=100):
|
55 |
ques_ix = np.zeros(max_token, np.int64)
|
56 |
|
@@ -58,7 +58,7 @@ def sent_to_ix(s, token_to_ix, max_token=100):
|
|
58 |
if word in token_to_ix:
|
59 |
ques_ix[ix] = token_to_ix[word]
|
60 |
else:
|
61 |
-
ques_ix[ix] = token_to_ix[
|
62 |
|
63 |
if ix + 1 == max_token:
|
64 |
break
|
@@ -83,21 +83,20 @@ def cmumosei_7(a):
|
|
83 |
res = 6
|
84 |
return res
|
85 |
|
|
|
86 |
def cmumosei_2(a):
|
87 |
if a < 0:
|
88 |
return 0
|
89 |
if a >= 0:
|
90 |
return 1
|
91 |
|
|
|
92 |
def pad_feature(feat, max_len):
|
93 |
if feat.shape[0] > max_len:
|
94 |
feat = feat[:max_len]
|
95 |
|
96 |
feat = np.pad(
|
97 |
-
feat,
|
98 |
-
((0, max_len - feat.shape[0]), (0, 0)),
|
99 |
-
mode='constant',
|
100 |
-
constant_values=0
|
101 |
)
|
102 |
|
103 |
return feat
|
|
|
6 |
import os
|
7 |
import pickle
|
8 |
|
9 |
+
|
10 |
def clean(w):
|
11 |
+
return (
|
12 |
+
re.sub(r"([.,'!?\"()*#:;])", "", w.lower()).replace("-", " ").replace("/", " ")
|
13 |
+
)
|
|
|
|
|
14 |
|
15 |
|
16 |
def tokenize(key_to_word):
|
17 |
key_to_sentence = {}
|
18 |
for k, v in key_to_word.items():
|
19 |
+
key_to_sentence[k] = [clean(w) for w in v if clean(w) != ""]
|
20 |
return key_to_sentence
|
21 |
|
22 |
|
23 |
def create_dict(key_to_sentence, dataroot, use_glove=True):
|
24 |
+
token_file = dataroot + "/token_to_ix.pkl"
|
25 |
+
glove_file = dataroot + "/train_glove.npy"
|
26 |
if os.path.exists(glove_file) and os.path.exists(token_file):
|
27 |
print("Loading train language files")
|
28 |
return pickle.load(open(token_file, "rb")), np.load(glove_file)
|
29 |
|
30 |
print("Creating train language files")
|
31 |
token_to_ix = {
|
32 |
+
"UNK": 1,
|
33 |
}
|
34 |
|
35 |
spacy_tool = None
|
36 |
pretrained_emb = []
|
37 |
if use_glove:
|
38 |
spacy_tool = en_vectors_web_lg.load()
|
39 |
+
pretrained_emb.append(spacy_tool("UNK").vector)
|
40 |
|
41 |
for k, v in key_to_sentence.items():
|
42 |
for word in v:
|
|
|
50 |
pickle.dump(token_to_ix, open(token_file, "wb"))
|
51 |
return token_to_ix, pretrained_emb
|
52 |
|
53 |
+
|
54 |
def sent_to_ix(s, token_to_ix, max_token=100):
|
55 |
ques_ix = np.zeros(max_token, np.int64)
|
56 |
|
|
|
58 |
if word in token_to_ix:
|
59 |
ques_ix[ix] = token_to_ix[word]
|
60 |
else:
|
61 |
+
ques_ix[ix] = token_to_ix["UNK"]
|
62 |
|
63 |
if ix + 1 == max_token:
|
64 |
break
|
|
|
83 |
res = 6
|
84 |
return res
|
85 |
|
86 |
+
|
87 |
def cmumosei_2(a):
|
88 |
if a < 0:
|
89 |
return 0
|
90 |
if a >= 0:
|
91 |
return 1
|
92 |
|
93 |
+
|
94 |
def pad_feature(feat, max_len):
|
95 |
if feat.shape[0] > max_len:
|
96 |
feat = feat[:max_len]
|
97 |
|
98 |
feat = np.pad(
|
99 |
+
feat, ((0, max_len - feat.shape[0]), (0, 0)), mode="constant", constant_values=0
|
|
|
|
|
|
|
100 |
)
|
101 |
|
102 |
return feat
|