Spaces:
Build error
Build error
Commit
•
ff43e05
1
Parent(s):
0f6f21e
initial commit
Browse files- .gitattributes +2 -0
- .gitignore +166 -0
- ckpt/Model_LA_e/best81.21325494388027_1117766.pkl +3 -0
- inference.ipynb +0 -0
- layers/fc.py +37 -0
- layers/layer_norm.py +16 -0
- model_LA.py +343 -0
- model_LAV.py +367 -0
- token_to_ix.pkl +3 -0
- train_glove.npy +3 -0
- utils/__init__.py +0 -0
- utils/audio.py +163 -0
- utils/audio_params.py +47 -0
- utils/compute_args.py +28 -0
- utils/plot.py +13 -0
- utils/pred_func.py +9 -0
- utils/tokenize.py +103 -0
.gitattributes
CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.tar.gz
|
2 |
+
data
|
3 |
+
|
4 |
+
# Initially taken from Github's Python gitignore file
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# tests and logs
|
15 |
+
tests/fixtures/cached_*_text.txt
|
16 |
+
logs/
|
17 |
+
lightning_logs/
|
18 |
+
lang_code_data/
|
19 |
+
|
20 |
+
# Distribution / packaging
|
21 |
+
.Python
|
22 |
+
build/
|
23 |
+
develop-eggs/
|
24 |
+
dist/
|
25 |
+
downloads/
|
26 |
+
eggs/
|
27 |
+
.eggs/
|
28 |
+
lib/
|
29 |
+
lib64/
|
30 |
+
parts/
|
31 |
+
sdist/
|
32 |
+
var/
|
33 |
+
wheels/
|
34 |
+
*.egg-info/
|
35 |
+
.installed.cfg
|
36 |
+
*.egg
|
37 |
+
MANIFEST
|
38 |
+
|
39 |
+
# PyInstaller
|
40 |
+
# Usually these files are written by a python script from a template
|
41 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
42 |
+
*.manifest
|
43 |
+
*.spec
|
44 |
+
|
45 |
+
# Installer logs
|
46 |
+
pip-log.txt
|
47 |
+
pip-delete-this-directory.txt
|
48 |
+
|
49 |
+
# Unit test / coverage reports
|
50 |
+
htmlcov/
|
51 |
+
.tox/
|
52 |
+
.nox/
|
53 |
+
.coverage
|
54 |
+
.coverage.*
|
55 |
+
.cache
|
56 |
+
nosetests.xml
|
57 |
+
coverage.xml
|
58 |
+
*.cover
|
59 |
+
.hypothesis/
|
60 |
+
.pytest_cache/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
db.sqlite3
|
70 |
+
|
71 |
+
# Flask stuff:
|
72 |
+
instance/
|
73 |
+
.webassets-cache
|
74 |
+
|
75 |
+
# Scrapy stuff:
|
76 |
+
.scrapy
|
77 |
+
|
78 |
+
# Sphinx documentation
|
79 |
+
docs/_build/
|
80 |
+
|
81 |
+
# PyBuilder
|
82 |
+
target/
|
83 |
+
|
84 |
+
# Jupyter Notebook
|
85 |
+
.ipynb_checkpoints
|
86 |
+
|
87 |
+
# IPython
|
88 |
+
profile_default/
|
89 |
+
ipython_config.py
|
90 |
+
|
91 |
+
# pyenv
|
92 |
+
.python-version
|
93 |
+
|
94 |
+
# celery beat schedule file
|
95 |
+
celerybeat-schedule
|
96 |
+
|
97 |
+
# SageMath parsed files
|
98 |
+
*.sage.py
|
99 |
+
|
100 |
+
# Environments
|
101 |
+
.env
|
102 |
+
.venv
|
103 |
+
env/
|
104 |
+
venv/
|
105 |
+
ENV/
|
106 |
+
env.bak/
|
107 |
+
venv.bak/
|
108 |
+
|
109 |
+
# Spyder project settings
|
110 |
+
.spyderproject
|
111 |
+
.spyproject
|
112 |
+
|
113 |
+
# Rope project settings
|
114 |
+
.ropeproject
|
115 |
+
|
116 |
+
# mkdocs documentation
|
117 |
+
/site
|
118 |
+
|
119 |
+
# mypy
|
120 |
+
.mypy_cache/
|
121 |
+
.dmypy.json
|
122 |
+
dmypy.json
|
123 |
+
|
124 |
+
# Pyre type checker
|
125 |
+
.pyre/
|
126 |
+
|
127 |
+
# vscode
|
128 |
+
.vs
|
129 |
+
.vscode
|
130 |
+
|
131 |
+
# Pycharm
|
132 |
+
.idea
|
133 |
+
|
134 |
+
# TF code
|
135 |
+
tensorflow_code
|
136 |
+
|
137 |
+
# Models
|
138 |
+
proc_data
|
139 |
+
|
140 |
+
# examples
|
141 |
+
runs
|
142 |
+
/runs_old
|
143 |
+
/wandb
|
144 |
+
/examples/runs
|
145 |
+
/examples/**/*.args
|
146 |
+
/examples/rag/sweep
|
147 |
+
|
148 |
+
# data
|
149 |
+
/data
|
150 |
+
serialization_dir
|
151 |
+
|
152 |
+
# emacs
|
153 |
+
*.*~
|
154 |
+
debug.env
|
155 |
+
|
156 |
+
# vim
|
157 |
+
.*.swp
|
158 |
+
|
159 |
+
#ctags
|
160 |
+
tags
|
161 |
+
|
162 |
+
# pre-commit
|
163 |
+
.pre-commit*
|
164 |
+
|
165 |
+
# .lock
|
166 |
+
*.lock
|
ckpt/Model_LA_e/best81.21325494388027_1117766.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c49968f8c2bcd7ec0489bd88c1f41418d15f01932264487c6d088807dcaaf4c
|
3 |
+
size 391671429
|
inference.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
layers/fc.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
class FC(nn.Module):
|
4 |
+
def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
|
5 |
+
super(FC, self).__init__()
|
6 |
+
self.dropout_r = dropout_r
|
7 |
+
self.use_relu = use_relu
|
8 |
+
|
9 |
+
self.linear = nn.Linear(in_size, out_size)
|
10 |
+
|
11 |
+
if use_relu:
|
12 |
+
self.relu = nn.ReLU(inplace=True)
|
13 |
+
|
14 |
+
if dropout_r > 0:
|
15 |
+
self.dropout = nn.Dropout(dropout_r)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
x = self.linear(x)
|
19 |
+
|
20 |
+
if self.use_relu:
|
21 |
+
x = self.relu(x)
|
22 |
+
|
23 |
+
if self.dropout_r > 0:
|
24 |
+
x = self.dropout(x)
|
25 |
+
|
26 |
+
return x
|
27 |
+
|
28 |
+
|
29 |
+
class MLP(nn.Module):
|
30 |
+
def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
|
31 |
+
super(MLP, self).__init__()
|
32 |
+
|
33 |
+
self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
|
34 |
+
self.linear = nn.Linear(mid_size, out_size)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
return self.linear(self.fc(x))
|
layers/layer_norm.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class LayerNorm(nn.Module):
|
5 |
+
def __init__(self, size, eps=1e-6):
|
6 |
+
super(LayerNorm, self).__init__()
|
7 |
+
self.eps = eps
|
8 |
+
|
9 |
+
self.a_2 = nn.Parameter(torch.ones(size))
|
10 |
+
self.b_2 = nn.Parameter(torch.zeros(size))
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
mean = x.mean(-1, keepdim=True)
|
14 |
+
std = x.std(-1, keepdim=True)
|
15 |
+
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
16 |
+
|
model_LA.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from layers.fc import MLP, FC
|
7 |
+
from layers.layer_norm import LayerNorm
|
8 |
+
|
9 |
+
# ------------------------------------
|
10 |
+
# ---------- Masking sequence --------
|
11 |
+
# ------------------------------------
|
12 |
+
def make_mask(feature):
|
13 |
+
return (torch.sum(
|
14 |
+
torch.abs(feature),
|
15 |
+
dim=-1
|
16 |
+
) == 0).unsqueeze(1).unsqueeze(2)
|
17 |
+
|
18 |
+
# ------------------------------
|
19 |
+
# ---------- Flattening --------
|
20 |
+
# ------------------------------
|
21 |
+
|
22 |
+
|
23 |
+
class AttFlat(nn.Module):
|
24 |
+
def __init__(self, args, flat_glimpse, merge=False):
|
25 |
+
super(AttFlat, self).__init__()
|
26 |
+
self.args = args
|
27 |
+
self.merge = merge
|
28 |
+
self.flat_glimpse = flat_glimpse
|
29 |
+
self.mlp = MLP(
|
30 |
+
in_size=args.hidden_size,
|
31 |
+
mid_size=args.ff_size,
|
32 |
+
out_size=flat_glimpse,
|
33 |
+
dropout_r=args.dropout_r,
|
34 |
+
use_relu=True
|
35 |
+
)
|
36 |
+
|
37 |
+
if self.merge:
|
38 |
+
self.linear_merge = nn.Linear(
|
39 |
+
args.hidden_size * flat_glimpse,
|
40 |
+
args.hidden_size * 2
|
41 |
+
)
|
42 |
+
|
43 |
+
def forward(self, x, x_mask):
|
44 |
+
att = self.mlp(x)
|
45 |
+
if x_mask is not None:
|
46 |
+
att = att.masked_fill(
|
47 |
+
x_mask.squeeze(1).squeeze(1).unsqueeze(2),
|
48 |
+
-1e9
|
49 |
+
)
|
50 |
+
att = F.softmax(att, dim=1)
|
51 |
+
|
52 |
+
att_list = []
|
53 |
+
for i in range(self.flat_glimpse):
|
54 |
+
att_list.append(
|
55 |
+
torch.sum(att[:, :, i: i + 1] * x, dim=1)
|
56 |
+
)
|
57 |
+
|
58 |
+
if self.merge:
|
59 |
+
x_atted = torch.cat(att_list, dim=1)
|
60 |
+
x_atted = self.linear_merge(x_atted)
|
61 |
+
|
62 |
+
return x_atted
|
63 |
+
|
64 |
+
return torch.stack(att_list).transpose_(0, 1)
|
65 |
+
|
66 |
+
# ------------------------
|
67 |
+
# ---- Self Attention ----
|
68 |
+
# ------------------------
|
69 |
+
|
70 |
+
class SA(nn.Module):
|
71 |
+
def __init__(self, args):
|
72 |
+
super(SA, self).__init__()
|
73 |
+
|
74 |
+
self.mhatt = MHAtt(args)
|
75 |
+
self.ffn = FFN(args)
|
76 |
+
|
77 |
+
self.dropout1 = nn.Dropout(args.dropout_r)
|
78 |
+
self.norm1 = LayerNorm(args.hidden_size)
|
79 |
+
|
80 |
+
self.dropout2 = nn.Dropout(args.dropout_r)
|
81 |
+
self.norm2 = LayerNorm(args.hidden_size)
|
82 |
+
|
83 |
+
def forward(self, y, y_mask):
|
84 |
+
y = self.norm1(y + self.dropout1(
|
85 |
+
self.mhatt(y, y, y, y_mask)
|
86 |
+
))
|
87 |
+
|
88 |
+
y = self.norm2(y + self.dropout2(
|
89 |
+
self.ffn(y)
|
90 |
+
))
|
91 |
+
|
92 |
+
return y
|
93 |
+
|
94 |
+
|
95 |
+
# -------------------------------
|
96 |
+
# ---- Self Guided Attention ----
|
97 |
+
# -------------------------------
|
98 |
+
|
99 |
+
class SGA(nn.Module):
|
100 |
+
def __init__(self, args):
|
101 |
+
super(SGA, self).__init__()
|
102 |
+
|
103 |
+
self.mhatt1 = MHAtt(args)
|
104 |
+
self.mhatt2 = MHAtt(args)
|
105 |
+
self.ffn = FFN(args)
|
106 |
+
|
107 |
+
self.dropout1 = nn.Dropout(args.dropout_r)
|
108 |
+
self.norm1 = LayerNorm(args.hidden_size)
|
109 |
+
|
110 |
+
self.dropout2 = nn.Dropout(args.dropout_r)
|
111 |
+
self.norm2 = LayerNorm(args.hidden_size)
|
112 |
+
|
113 |
+
self.dropout3 = nn.Dropout(args.dropout_r)
|
114 |
+
self.norm3 = LayerNorm(args.hidden_size)
|
115 |
+
|
116 |
+
def forward(self, x, y, x_mask, y_mask):
|
117 |
+
x = self.norm1(x + self.dropout1(
|
118 |
+
self.mhatt1(v=x, k=x, q=x, mask=x_mask)
|
119 |
+
))
|
120 |
+
|
121 |
+
x = self.norm2(x + self.dropout2(
|
122 |
+
self.mhatt2(v=y, k=y, q=x, mask=y_mask)
|
123 |
+
))
|
124 |
+
|
125 |
+
x = self.norm3(x + self.dropout3(
|
126 |
+
self.ffn(x)
|
127 |
+
))
|
128 |
+
|
129 |
+
return x
|
130 |
+
|
131 |
+
# ------------------------------
|
132 |
+
# ---- Multi-Head Attention ----
|
133 |
+
# ------------------------------
|
134 |
+
|
135 |
+
class MHAtt(nn.Module):
|
136 |
+
def __init__(self, args):
|
137 |
+
super(MHAtt, self).__init__()
|
138 |
+
self.args = args
|
139 |
+
|
140 |
+
self.linear_v = nn.Linear(args.hidden_size, args.hidden_size)
|
141 |
+
self.linear_k = nn.Linear(args.hidden_size, args.hidden_size)
|
142 |
+
self.linear_q = nn.Linear(args.hidden_size, args.hidden_size)
|
143 |
+
self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size)
|
144 |
+
|
145 |
+
self.dropout = nn.Dropout(args.dropout_r)
|
146 |
+
|
147 |
+
def forward(self, v, k, q, mask):
|
148 |
+
n_batches = q.size(0)
|
149 |
+
v = self.linear_v(v).view(
|
150 |
+
n_batches,
|
151 |
+
-1,
|
152 |
+
self.args.multi_head,
|
153 |
+
int(self.args.hidden_size / self.args.multi_head)
|
154 |
+
).transpose(1, 2)
|
155 |
+
|
156 |
+
k = self.linear_k(k).view(
|
157 |
+
n_batches,
|
158 |
+
-1,
|
159 |
+
self.args.multi_head,
|
160 |
+
int(self.args.hidden_size / self.args.multi_head)
|
161 |
+
).transpose(1, 2)
|
162 |
+
|
163 |
+
q = self.linear_q(q).view(
|
164 |
+
n_batches,
|
165 |
+
-1,
|
166 |
+
self.args.multi_head,
|
167 |
+
int(self.args.hidden_size / self.args.multi_head)
|
168 |
+
).transpose(1, 2)
|
169 |
+
|
170 |
+
atted = self.att(v, k, q, mask)
|
171 |
+
|
172 |
+
atted = atted.transpose(1, 2).contiguous().view(
|
173 |
+
n_batches,
|
174 |
+
-1,
|
175 |
+
self.args.hidden_size
|
176 |
+
)
|
177 |
+
atted = self.linear_merge(atted)
|
178 |
+
|
179 |
+
return atted
|
180 |
+
|
181 |
+
def att(self, value, key, query, mask):
|
182 |
+
d_k = query.size(-1)
|
183 |
+
|
184 |
+
scores = torch.matmul(
|
185 |
+
query, key.transpose(-2, -1)
|
186 |
+
) / math.sqrt(d_k)
|
187 |
+
|
188 |
+
if mask is not None:
|
189 |
+
scores = scores.masked_fill(mask, -1e9)
|
190 |
+
|
191 |
+
att_map = F.softmax(scores, dim=-1)
|
192 |
+
att_map = self.dropout(att_map)
|
193 |
+
|
194 |
+
return torch.matmul(att_map, value)
|
195 |
+
|
196 |
+
|
197 |
+
# ---------------------------
|
198 |
+
# ---- Feed Forward Nets ----
|
199 |
+
# ---------------------------
|
200 |
+
|
201 |
+
class FFN(nn.Module):
|
202 |
+
def __init__(self, args):
|
203 |
+
super(FFN, self).__init__()
|
204 |
+
|
205 |
+
self.mlp = MLP(
|
206 |
+
in_size=args.hidden_size,
|
207 |
+
mid_size=args.ff_size,
|
208 |
+
out_size=args.hidden_size,
|
209 |
+
dropout_r=args.dropout_r,
|
210 |
+
use_relu=True
|
211 |
+
)
|
212 |
+
|
213 |
+
def forward(self, x):
|
214 |
+
return self.mlp(x)
|
215 |
+
|
216 |
+
# ---------------------------
|
217 |
+
# ---- FF + norm -----------
|
218 |
+
# ---------------------------
|
219 |
+
class FFAndNorm(nn.Module):
|
220 |
+
def __init__(self, args):
|
221 |
+
super(FFAndNorm, self).__init__()
|
222 |
+
|
223 |
+
self.ffn = FFN(args)
|
224 |
+
self.norm1 = LayerNorm(args.hidden_size)
|
225 |
+
self.dropout2 = nn.Dropout(args.dropout_r)
|
226 |
+
self.norm2 = LayerNorm(args.hidden_size)
|
227 |
+
|
228 |
+
def forward(self, x):
|
229 |
+
x = self.norm1(x)
|
230 |
+
x = self.norm2(x + self.dropout2(self.ffn(x)))
|
231 |
+
return x
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
class Block(nn.Module):
|
236 |
+
def __init__(self, args, i):
|
237 |
+
super(Block, self).__init__()
|
238 |
+
self.args = args
|
239 |
+
self.sa1 = SA(args)
|
240 |
+
self.sa3 = SGA(args)
|
241 |
+
|
242 |
+
self.last = (i == args.layer-1)
|
243 |
+
if not self.last:
|
244 |
+
self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
|
245 |
+
self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
|
246 |
+
self.norm_l = LayerNorm(args.hidden_size)
|
247 |
+
self.norm_i = LayerNorm(args.hidden_size)
|
248 |
+
self.dropout = nn.Dropout(args.dropout_r)
|
249 |
+
|
250 |
+
def forward(self, x, x_mask, y, y_mask):
|
251 |
+
|
252 |
+
ax = self.sa1(x, x_mask)
|
253 |
+
ay = self.sa3(y, x, y_mask, x_mask)
|
254 |
+
|
255 |
+
x = ax + x
|
256 |
+
y = ay + y
|
257 |
+
|
258 |
+
if self.last:
|
259 |
+
return x, y
|
260 |
+
|
261 |
+
ax = self.att_lang(x, x_mask)
|
262 |
+
ay = self.att_audio(y, y_mask)
|
263 |
+
|
264 |
+
return self.norm_l(x + self.dropout(ax)), \
|
265 |
+
self.norm_i(y + self.dropout(ay))
|
266 |
+
|
267 |
+
|
268 |
+
class Model_LA(nn.Module):
|
269 |
+
def __init__(self, args, vocab_size, pretrained_emb):
|
270 |
+
super(Model_LA, self).__init__()
|
271 |
+
|
272 |
+
self.args = args
|
273 |
+
|
274 |
+
# LSTM
|
275 |
+
self.embedding = nn.Embedding(
|
276 |
+
num_embeddings=vocab_size,
|
277 |
+
embedding_dim=args.word_embed_size
|
278 |
+
)
|
279 |
+
|
280 |
+
# Loading the GloVe embedding weights
|
281 |
+
self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))
|
282 |
+
|
283 |
+
self.lstm_x = nn.LSTM(
|
284 |
+
input_size=args.word_embed_size,
|
285 |
+
hidden_size=args.hidden_size,
|
286 |
+
num_layers=1,
|
287 |
+
batch_first=True
|
288 |
+
)
|
289 |
+
|
290 |
+
# self.lstm_y = nn.LSTM(
|
291 |
+
# input_size=args.audio_feat_size,
|
292 |
+
# hidden_size=args.hidden_size,
|
293 |
+
# num_layers=1,
|
294 |
+
# batch_first=True
|
295 |
+
# )
|
296 |
+
|
297 |
+
# Feature size to hid size
|
298 |
+
self.adapter = nn.Linear(args.audio_feat_size, args.hidden_size)
|
299 |
+
|
300 |
+
# Encoder blocks
|
301 |
+
self.enc_list = nn.ModuleList([Block(args, i) for i in range(args.layer)])
|
302 |
+
|
303 |
+
# Flattenting features before proj
|
304 |
+
self.attflat_img = AttFlat(args, 1, merge=True)
|
305 |
+
self.attflat_lang = AttFlat(args, 1, merge=True)
|
306 |
+
|
307 |
+
# Classification layers
|
308 |
+
self.proj_norm = LayerNorm(2 * args.hidden_size)
|
309 |
+
self.proj = self.proj = nn.Linear(2 * args.hidden_size, args.ans_size)
|
310 |
+
|
311 |
+
def forward(self, x, y, _):
|
312 |
+
x_mask = make_mask(x.unsqueeze(2))
|
313 |
+
y_mask = make_mask(y)
|
314 |
+
|
315 |
+
embedding = self.embedding(x)
|
316 |
+
|
317 |
+
x, _ = self.lstm_x(embedding)
|
318 |
+
# y, _ = self.lstm_y(y)
|
319 |
+
|
320 |
+
y = self.adapter(y)
|
321 |
+
|
322 |
+
for i, dec in enumerate(self.enc_list):
|
323 |
+
x_m, x_y = None, None
|
324 |
+
if i == 0:
|
325 |
+
x_m, x_y = x_mask, y_mask
|
326 |
+
x, y = dec(x, x_m, y, x_y)
|
327 |
+
|
328 |
+
x = self.attflat_lang(
|
329 |
+
x,
|
330 |
+
None
|
331 |
+
)
|
332 |
+
|
333 |
+
y = self.attflat_img(
|
334 |
+
y,
|
335 |
+
None
|
336 |
+
)
|
337 |
+
|
338 |
+
# Classification layers
|
339 |
+
proj_feat = x + y
|
340 |
+
proj_feat = self.proj_norm(proj_feat)
|
341 |
+
ans = self.proj(proj_feat)
|
342 |
+
|
343 |
+
return ans
|
model_LAV.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from layers.fc import MLP
|
7 |
+
from layers.layer_norm import LayerNorm
|
8 |
+
|
9 |
+
# ------------------------------------
|
10 |
+
# ---------- Masking sequence --------
|
11 |
+
# ------------------------------------
|
12 |
+
def make_mask(feature):
|
13 |
+
return (torch.sum(
|
14 |
+
torch.abs(feature),
|
15 |
+
dim=-1
|
16 |
+
) == 0).unsqueeze(1).unsqueeze(2)
|
17 |
+
|
18 |
+
# ------------------------------
|
19 |
+
# ---------- Flattening --------
|
20 |
+
# ------------------------------
|
21 |
+
|
22 |
+
|
23 |
+
class AttFlat(nn.Module):
|
24 |
+
def __init__(self, args, flat_glimpse, merge=False):
|
25 |
+
super(AttFlat, self).__init__()
|
26 |
+
self.args = args
|
27 |
+
self.merge = merge
|
28 |
+
self.flat_glimpse = flat_glimpse
|
29 |
+
self.mlp = MLP(
|
30 |
+
in_size=args.hidden_size,
|
31 |
+
mid_size=args.ff_size,
|
32 |
+
out_size=flat_glimpse,
|
33 |
+
dropout_r=args.dropout_r,
|
34 |
+
use_relu=True
|
35 |
+
)
|
36 |
+
|
37 |
+
if self.merge:
|
38 |
+
self.linear_merge = nn.Linear(
|
39 |
+
args.hidden_size * flat_glimpse,
|
40 |
+
args.hidden_size * 2
|
41 |
+
)
|
42 |
+
|
43 |
+
def forward(self, x, x_mask):
|
44 |
+
att = self.mlp(x)
|
45 |
+
if x_mask is not None:
|
46 |
+
att = att.masked_fill(
|
47 |
+
x_mask.squeeze(1).squeeze(1).unsqueeze(2),
|
48 |
+
-1e9
|
49 |
+
)
|
50 |
+
att = F.softmax(att, dim=1)
|
51 |
+
|
52 |
+
att_list = []
|
53 |
+
for i in range(self.flat_glimpse):
|
54 |
+
att_list.append(
|
55 |
+
torch.sum(att[:, :, i: i + 1] * x, dim=1)
|
56 |
+
)
|
57 |
+
|
58 |
+
if self.merge:
|
59 |
+
x_atted = torch.cat(att_list, dim=1)
|
60 |
+
x_atted = self.linear_merge(x_atted)
|
61 |
+
|
62 |
+
return x_atted
|
63 |
+
|
64 |
+
return torch.stack(att_list).transpose_(0, 1)
|
65 |
+
|
66 |
+
# ------------------------
|
67 |
+
# ---- Self Attention ----
|
68 |
+
# ------------------------
|
69 |
+
|
70 |
+
class SA(nn.Module):
|
71 |
+
def __init__(self, args):
|
72 |
+
super(SA, self).__init__()
|
73 |
+
|
74 |
+
self.mhatt = MHAtt(args)
|
75 |
+
self.ffn = FFN(args)
|
76 |
+
|
77 |
+
self.dropout1 = nn.Dropout(args.dropout_r)
|
78 |
+
self.norm1 = LayerNorm(args.hidden_size)
|
79 |
+
|
80 |
+
self.dropout2 = nn.Dropout(args.dropout_r)
|
81 |
+
self.norm2 = LayerNorm(args.hidden_size)
|
82 |
+
|
83 |
+
def forward(self, y, y_mask):
|
84 |
+
y = self.norm1(y + self.dropout1(
|
85 |
+
self.mhatt(y, y, y, y_mask)
|
86 |
+
))
|
87 |
+
|
88 |
+
y = self.norm2(y + self.dropout2(
|
89 |
+
self.ffn(y)
|
90 |
+
))
|
91 |
+
|
92 |
+
return y
|
93 |
+
|
94 |
+
|
95 |
+
# -------------------------------
|
96 |
+
# ---- Self Guided Attention ----
|
97 |
+
# -------------------------------
|
98 |
+
|
99 |
+
class SGA(nn.Module):
|
100 |
+
def __init__(self, args):
|
101 |
+
super(SGA, self).__init__()
|
102 |
+
|
103 |
+
self.mhatt1 = MHAtt(args)
|
104 |
+
self.mhatt2 = MHAtt(args)
|
105 |
+
self.ffn = FFN(args)
|
106 |
+
|
107 |
+
self.dropout1 = nn.Dropout(args.dropout_r)
|
108 |
+
self.norm1 = LayerNorm(args.hidden_size)
|
109 |
+
|
110 |
+
self.dropout2 = nn.Dropout(args.dropout_r)
|
111 |
+
self.norm2 = LayerNorm(args.hidden_size)
|
112 |
+
|
113 |
+
self.dropout3 = nn.Dropout(args.dropout_r)
|
114 |
+
self.norm3 = LayerNorm(args.hidden_size)
|
115 |
+
|
116 |
+
def forward(self, x, y, x_mask, y_mask):
|
117 |
+
x = self.norm1(x + self.dropout1(
|
118 |
+
self.mhatt1(v=x, k=x, q=x, mask=x_mask)
|
119 |
+
))
|
120 |
+
|
121 |
+
x = self.norm2(x + self.dropout2(
|
122 |
+
self.mhatt2(v=y, k=y, q=x, mask=y_mask)
|
123 |
+
))
|
124 |
+
|
125 |
+
x = self.norm3(x + self.dropout3(
|
126 |
+
self.ffn(x)
|
127 |
+
))
|
128 |
+
|
129 |
+
return x
|
130 |
+
|
131 |
+
# ------------------------------
|
132 |
+
# ---- Multi-Head Attention ----
|
133 |
+
# ------------------------------
|
134 |
+
|
135 |
+
class MHAtt(nn.Module):
|
136 |
+
def __init__(self, args):
|
137 |
+
super(MHAtt, self).__init__()
|
138 |
+
self.args = args
|
139 |
+
|
140 |
+
self.linear_v = nn.Linear(args.hidden_size, args.hidden_size)
|
141 |
+
self.linear_k = nn.Linear(args.hidden_size, args.hidden_size)
|
142 |
+
self.linear_q = nn.Linear(args.hidden_size, args.hidden_size)
|
143 |
+
self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size)
|
144 |
+
|
145 |
+
self.dropout = nn.Dropout(args.dropout_r)
|
146 |
+
|
147 |
+
def forward(self, v, k, q, mask):
|
148 |
+
n_batches = q.size(0)
|
149 |
+
v = self.linear_v(v).view(
|
150 |
+
n_batches,
|
151 |
+
-1,
|
152 |
+
self.args.multi_head,
|
153 |
+
int(self.args.hidden_size / self.args.multi_head)
|
154 |
+
).transpose(1, 2)
|
155 |
+
|
156 |
+
k = self.linear_k(k).view(
|
157 |
+
n_batches,
|
158 |
+
-1,
|
159 |
+
self.args.multi_head,
|
160 |
+
int(self.args.hidden_size / self.args.multi_head)
|
161 |
+
).transpose(1, 2)
|
162 |
+
|
163 |
+
q = self.linear_q(q).view(
|
164 |
+
n_batches,
|
165 |
+
-1,
|
166 |
+
self.args.multi_head,
|
167 |
+
int(self.args.hidden_size / self.args.multi_head)
|
168 |
+
).transpose(1, 2)
|
169 |
+
|
170 |
+
atted = self.att(v, k, q, mask)
|
171 |
+
|
172 |
+
atted = atted.transpose(1, 2).contiguous().view(
|
173 |
+
n_batches,
|
174 |
+
-1,
|
175 |
+
self.args.hidden_size
|
176 |
+
)
|
177 |
+
atted = self.linear_merge(atted)
|
178 |
+
|
179 |
+
return atted
|
180 |
+
|
181 |
+
def att(self, value, key, query, mask):
|
182 |
+
d_k = query.size(-1)
|
183 |
+
|
184 |
+
scores = torch.matmul(
|
185 |
+
query, key.transpose(-2, -1)
|
186 |
+
) / math.sqrt(d_k)
|
187 |
+
|
188 |
+
if mask is not None:
|
189 |
+
scores = scores.masked_fill(mask, -1e9)
|
190 |
+
|
191 |
+
att_map = F.softmax(scores, dim=-1)
|
192 |
+
att_map = self.dropout(att_map)
|
193 |
+
|
194 |
+
return torch.matmul(att_map, value)
|
195 |
+
|
196 |
+
|
197 |
+
# ---------------------------
|
198 |
+
# ---- Feed Forward Nets ----
|
199 |
+
# ---------------------------
|
200 |
+
|
201 |
+
class FFN(nn.Module):
|
202 |
+
def __init__(self, args):
|
203 |
+
super(FFN, self).__init__()
|
204 |
+
|
205 |
+
self.mlp = MLP(
|
206 |
+
in_size=args.hidden_size,
|
207 |
+
mid_size=args.ff_size,
|
208 |
+
out_size=args.hidden_size,
|
209 |
+
dropout_r=args.dropout_r,
|
210 |
+
use_relu=True
|
211 |
+
)
|
212 |
+
|
213 |
+
def forward(self, x):
|
214 |
+
return self.mlp(x)
|
215 |
+
|
216 |
+
# ---------------------------
|
217 |
+
# ---- FF + norm -----------
|
218 |
+
# ---------------------------
|
219 |
+
class FFAndNorm(nn.Module):
|
220 |
+
def __init__(self, args):
|
221 |
+
super(FFAndNorm, self).__init__()
|
222 |
+
|
223 |
+
self.ffn = FFN(args)
|
224 |
+
self.norm1 = LayerNorm(args.hidden_size)
|
225 |
+
self.dropout2 = nn.Dropout(args.dropout_r)
|
226 |
+
self.norm2 = LayerNorm(args.hidden_size)
|
227 |
+
|
228 |
+
def forward(self, x):
|
229 |
+
x = self.norm1(x)
|
230 |
+
x = self.norm2(x + self.dropout2(self.ffn(x)))
|
231 |
+
return x
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
class Block(nn.Module):
|
236 |
+
def __init__(self, args, i):
|
237 |
+
super(Block, self).__init__()
|
238 |
+
self.args = args
|
239 |
+
self.sa1 = SA(args)
|
240 |
+
self.sa2 = SGA(args)
|
241 |
+
self.sa3 = SGA(args)
|
242 |
+
|
243 |
+
self.last = (i == args.layer-1)
|
244 |
+
if not self.last:
|
245 |
+
self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
|
246 |
+
self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
|
247 |
+
self.att_vid = AttFlat(args, args.video_seq_len, merge=False)
|
248 |
+
self.norm_l = LayerNorm(args.hidden_size)
|
249 |
+
self.norm_a = LayerNorm(args.hidden_size)
|
250 |
+
self.norm_v = LayerNorm(args.hidden_size)
|
251 |
+
self.dropout = nn.Dropout(args.dropout_r)
|
252 |
+
|
253 |
+
def forward(self, x, x_mask, y, y_mask, z, z_mask):
|
254 |
+
|
255 |
+
ax = self.sa1(x, x_mask)
|
256 |
+
ay = self.sa2(y, x, y_mask, x_mask)
|
257 |
+
az = self.sa3(z, x, z_mask, x_mask)
|
258 |
+
|
259 |
+
x = ax + x
|
260 |
+
y = ay + y
|
261 |
+
z = az + z
|
262 |
+
|
263 |
+
if self.last:
|
264 |
+
return x, y, z
|
265 |
+
|
266 |
+
ax = self.att_lang(x, x_mask)
|
267 |
+
ay = self.att_audio(y, y_mask)
|
268 |
+
az = self.att_vid(z, y_mask)
|
269 |
+
|
270 |
+
return self.norm_l(x + self.dropout(ax)), \
|
271 |
+
self.norm_a(y + self.dropout(ay)), \
|
272 |
+
self.norm_v(z + self.dropout(az))
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
class Model_LAV(nn.Module):
|
277 |
+
def __init__(self, args, vocab_size, pretrained_emb):
|
278 |
+
super(Model_LAV, self).__init__()
|
279 |
+
|
280 |
+
self.args = args
|
281 |
+
|
282 |
+
# LSTM
|
283 |
+
self.embedding = nn.Embedding(
|
284 |
+
num_embeddings=vocab_size,
|
285 |
+
embedding_dim=args.word_embed_size
|
286 |
+
)
|
287 |
+
|
288 |
+
# Loading the GloVe embedding weights
|
289 |
+
self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))
|
290 |
+
|
291 |
+
self.lstm_x = nn.LSTM(
|
292 |
+
input_size=args.word_embed_size,
|
293 |
+
hidden_size=args.hidden_size,
|
294 |
+
num_layers=1,
|
295 |
+
batch_first=True
|
296 |
+
)
|
297 |
+
|
298 |
+
# self.lstm_y = nn.LSTM(
|
299 |
+
# input_size=args.audio_feat_size,
|
300 |
+
# hidden_size=args.hidden_size,
|
301 |
+
# num_layers=1,
|
302 |
+
# batch_first=True
|
303 |
+
# )
|
304 |
+
|
305 |
+
# Feature size to hid size
|
306 |
+
self.adapter_y = nn.Linear(args.audio_feat_size, args.hidden_size)
|
307 |
+
self.adapter_z = nn.Linear(args.video_feat_size, args.hidden_size)
|
308 |
+
|
309 |
+
# Encoder blocks
|
310 |
+
self.enc_list = nn.ModuleList([Block(args, i) for i in range(args.layer)])
|
311 |
+
|
312 |
+
# Flattenting features before proj
|
313 |
+
self.attflat_ac = AttFlat(args, 1, merge=True)
|
314 |
+
self.attflat_vid = AttFlat(args, 1, merge=True)
|
315 |
+
self.attflat_lang = AttFlat(args, 1, merge=True)
|
316 |
+
|
317 |
+
# Classification layers
|
318 |
+
self.proj_norm = LayerNorm(2 * args.hidden_size)
|
319 |
+
if self.args.task == "sentiment":
|
320 |
+
if self.args.task_binary:
|
321 |
+
self.proj = nn.Linear(2 * args.hidden_size, 2)
|
322 |
+
else:
|
323 |
+
self.proj = nn.Linear(2 * args.hidden_size, 7)
|
324 |
+
if self.args.task == "emotion":
|
325 |
+
self.proj = self.proj = nn.Linear(2 * args.hidden_size, 6)
|
326 |
+
|
327 |
+
def forward(self, x, y, z):
|
328 |
+
x_mask = make_mask(x.unsqueeze(2))
|
329 |
+
y_mask = make_mask(y)
|
330 |
+
z_mask = make_mask(z)
|
331 |
+
|
332 |
+
|
333 |
+
embedding = self.embedding(x)
|
334 |
+
|
335 |
+
x, _ = self.lstm_x(embedding)
|
336 |
+
# y, _ = self.lstm_y(y)
|
337 |
+
|
338 |
+
y, z = self.adapter_y(y), self.adapter_z(z)
|
339 |
+
|
340 |
+
for i, dec in enumerate(self.enc_list):
|
341 |
+
x_m, y_m, z_m = None, None, None
|
342 |
+
if i == 0:
|
343 |
+
x_m, y_m, z_m = x_mask, y_mask, z_mask
|
344 |
+
x, y, z = dec(x, x_m, y, y_m, z, z_m)
|
345 |
+
|
346 |
+
x = self.attflat_lang(
|
347 |
+
x,
|
348 |
+
None
|
349 |
+
)
|
350 |
+
|
351 |
+
y = self.attflat_ac(
|
352 |
+
y,
|
353 |
+
None
|
354 |
+
)
|
355 |
+
|
356 |
+
z = self.attflat_vid(
|
357 |
+
z,
|
358 |
+
None
|
359 |
+
)
|
360 |
+
|
361 |
+
|
362 |
+
# Classification layers
|
363 |
+
proj_feat = x + y + z
|
364 |
+
proj_feat = self.proj_norm(proj_feat)
|
365 |
+
ans = self.proj(proj_feat)
|
366 |
+
|
367 |
+
return ans
|
token_to_ix.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e1b468b2048c2ac08aaae32ba38c69fc9535af97bf7946e39ba4888794a8574d
|
3 |
+
size 286216
|
train_glove.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c73c457f2e7d047538488d411bcc851ae45b53cf3526482c5b0f6d4b745ebd55
|
3 |
+
size 17012528
|
utils/__init__.py
ADDED
File without changes
|
utils/audio.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#/usr/bin/python2
|
3 |
+
'''
|
4 |
+
By kyubyong park. [email protected].
|
5 |
+
https://www.github.com/kyubyong/dc_tts
|
6 |
+
'''
|
7 |
+
from __future__ import print_function, division
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import librosa
|
11 |
+
import os, copy
|
12 |
+
import matplotlib
|
13 |
+
matplotlib.use('pdf')
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
from scipy import signal
|
16 |
+
|
17 |
+
from .audio_params import Hyperparams as hp
|
18 |
+
import tensorflow as tf
|
19 |
+
|
20 |
+
def get_spectrograms(fpath):
|
21 |
+
'''Parse the wave file in `fpath` and
|
22 |
+
Returns normalized melspectrogram and linear spectrogram.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
fpath: A string. The full path of a sound file.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
mel: A 2d array of shape (T, n_mels) and dtype of float32.
|
29 |
+
mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32.
|
30 |
+
'''
|
31 |
+
# Loading sound file
|
32 |
+
y, sr = librosa.load(fpath, sr=hp.sr)
|
33 |
+
|
34 |
+
# Trimming
|
35 |
+
y, _ = librosa.effects.trim(y)
|
36 |
+
|
37 |
+
# Preemphasis
|
38 |
+
y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1])
|
39 |
+
|
40 |
+
# stft
|
41 |
+
linear = librosa.stft(y=y,
|
42 |
+
n_fft=hp.n_fft,
|
43 |
+
hop_length=hp.hop_length,
|
44 |
+
win_length=hp.win_length)
|
45 |
+
|
46 |
+
# magnitude spectrogram
|
47 |
+
mag = np.abs(linear) # (1+n_fft//2, T)
|
48 |
+
|
49 |
+
# mel spectrogram
|
50 |
+
mel_basis = librosa.filters.mel(hp.sr, hp.n_fft, hp.n_mels) # (n_mels, 1+n_fft//2)
|
51 |
+
mel = np.dot(mel_basis, mag) # (n_mels, t)
|
52 |
+
|
53 |
+
# to decibel
|
54 |
+
mel = 20 * np.log10(np.maximum(1e-5, mel))
|
55 |
+
mag = 20 * np.log10(np.maximum(1e-5, mag))
|
56 |
+
|
57 |
+
# normalize
|
58 |
+
mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)
|
59 |
+
mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)
|
60 |
+
|
61 |
+
# Transpose
|
62 |
+
mel = mel.T.astype(np.float32) # (T, n_mels)
|
63 |
+
mag = mag.T.astype(np.float32) # (T, 1+n_fft//2)
|
64 |
+
|
65 |
+
return mel, mag
|
66 |
+
|
67 |
+
def spectrogram2wav(mag):
|
68 |
+
'''# Generate wave file from linear magnitude spectrogram
|
69 |
+
|
70 |
+
Args:
|
71 |
+
mag: A numpy array of (T, 1+n_fft//2)
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
wav: A 1-D numpy array.
|
75 |
+
'''
|
76 |
+
# transpose
|
77 |
+
mag = mag.T
|
78 |
+
|
79 |
+
# de-noramlize
|
80 |
+
mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db
|
81 |
+
|
82 |
+
# to amplitude
|
83 |
+
mag = np.power(10.0, mag * 0.05)
|
84 |
+
|
85 |
+
# wav reconstruction
|
86 |
+
wav = griffin_lim(mag**hp.power)
|
87 |
+
|
88 |
+
# de-preemphasis
|
89 |
+
wav = signal.lfilter([1], [1, -hp.preemphasis], wav)
|
90 |
+
|
91 |
+
# trim
|
92 |
+
wav, _ = librosa.effects.trim(wav)
|
93 |
+
|
94 |
+
return wav.astype(np.float32)
|
95 |
+
|
96 |
+
def griffin_lim(spectrogram):
|
97 |
+
'''Applies Griffin-Lim's raw.'''
|
98 |
+
X_best = copy.deepcopy(spectrogram)
|
99 |
+
for i in range(hp.n_iter):
|
100 |
+
X_t = invert_spectrogram(X_best)
|
101 |
+
est = librosa.stft(X_t, hp.n_fft, hp.hop_length, win_length=hp.win_length)
|
102 |
+
phase = est / np.maximum(1e-8, np.abs(est))
|
103 |
+
X_best = spectrogram * phase
|
104 |
+
X_t = invert_spectrogram(X_best)
|
105 |
+
y = np.real(X_t)
|
106 |
+
|
107 |
+
return y
|
108 |
+
|
109 |
+
def invert_spectrogram(spectrogram):
|
110 |
+
'''Applies inverse fft.
|
111 |
+
Args:
|
112 |
+
spectrogram: [1+n_fft//2, t]
|
113 |
+
'''
|
114 |
+
return librosa.istft(spectrogram, hp.hop_length, win_length=hp.win_length, window="hann")
|
115 |
+
|
116 |
+
def plot_alignment(alignment, gs, dir=hp.logdir):
|
117 |
+
"""Plots the alignment.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
alignment: A numpy array with shape of (encoder_steps, decoder_steps)
|
121 |
+
gs: (int) global step.
|
122 |
+
dir: Output path.
|
123 |
+
"""
|
124 |
+
if not os.path.exists(dir): os.mkdir(dir)
|
125 |
+
|
126 |
+
fig, ax = plt.subplots()
|
127 |
+
im = ax.imshow(alignment)
|
128 |
+
|
129 |
+
fig.colorbar(im)
|
130 |
+
plt.title('{} Steps'.format(gs))
|
131 |
+
plt.savefig('{}/alignment_{}.png'.format(dir, gs), format='png')
|
132 |
+
plt.close(fig)
|
133 |
+
|
134 |
+
def guided_attention(g=0.2):
|
135 |
+
'''Guided attention. Refer to page 3 on the paper.'''
|
136 |
+
W = np.zeros((hp.max_N, hp.max_T), dtype=np.float32)
|
137 |
+
for n_pos in range(W.shape[0]):
|
138 |
+
for t_pos in range(W.shape[1]):
|
139 |
+
W[n_pos, t_pos] = 1 - np.exp(-(t_pos / float(hp.max_T) - n_pos / float(hp.max_N)) ** 2 / (2 * g * g))
|
140 |
+
return W
|
141 |
+
|
142 |
+
def learning_rate_decay(init_lr, global_step, warmup_steps = 4000.0):
|
143 |
+
'''Noam scheme from tensor2tensor'''
|
144 |
+
step = tf.to_float(global_step + 1)
|
145 |
+
return init_lr * warmup_steps**0.5 * tf.minimum(step * warmup_steps**-1.5, step**-0.5)
|
146 |
+
|
147 |
+
def load_spectrograms(fpath):
|
148 |
+
'''Read the wave file in `fpath`
|
149 |
+
and extracts spectrograms'''
|
150 |
+
|
151 |
+
fname = os.path.basename(fpath)
|
152 |
+
mel, mag = get_spectrograms(fpath)
|
153 |
+
t = mel.shape[0]
|
154 |
+
|
155 |
+
# Marginal padding for reduction shape sync.
|
156 |
+
num_paddings = hp.r - (t % hp.r) if t % hp.r != 0 else 0
|
157 |
+
mel = np.pad(mel, [[0, num_paddings], [0, 0]], mode="constant")
|
158 |
+
mag = np.pad(mag, [[0, num_paddings], [0, 0]], mode="constant")
|
159 |
+
|
160 |
+
# Reduction
|
161 |
+
mel = mel[::hp.r, :]
|
162 |
+
return fname, mel, mag
|
163 |
+
|
utils/audio_params.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#/usr/bin/python2
|
3 |
+
'''
|
4 |
+
By kyubyong park. [email protected].
|
5 |
+
https://www.github.com/kyubyong/dc_tts
|
6 |
+
'''
|
7 |
+
class Hyperparams:
|
8 |
+
'''Hyper parameters'''
|
9 |
+
# pipeline
|
10 |
+
prepro = True # if True, run `python prepro.py` first before running `python train.py`.
|
11 |
+
|
12 |
+
# signal processing
|
13 |
+
sr = 22050 # Sampling rate.
|
14 |
+
n_fft = 2048 # fft points (samples)
|
15 |
+
frame_shift = 0.0125 # seconds
|
16 |
+
frame_length = 0.05 # seconds
|
17 |
+
hop_length = int(sr * frame_shift) # samples. =276.
|
18 |
+
win_length = int(sr * frame_length) # samples. =1102.
|
19 |
+
n_mels = 80 # Number of Mel banks to generate
|
20 |
+
power = 1.5 # Exponent for amplifying the predicted magnitude
|
21 |
+
n_iter = 50 # Number of inversion iterations
|
22 |
+
preemphasis = .97
|
23 |
+
max_db = 100
|
24 |
+
ref_db = 20
|
25 |
+
|
26 |
+
# Model
|
27 |
+
r = 4 # Reduction factor. Do not change this.
|
28 |
+
dropout_rate = 0.05
|
29 |
+
e = 128 # == embedding
|
30 |
+
d = 256 # == hidden units of Text2Mel
|
31 |
+
c = 512 # == hidden units of SSRN
|
32 |
+
attention_win_size = 3
|
33 |
+
|
34 |
+
# data
|
35 |
+
data = "/data/private/voice/LJSpeech-1.0"
|
36 |
+
# data = "/data/private/voice/kate"
|
37 |
+
test_data = 'harvard_sentences.txt'
|
38 |
+
vocab = "PE abcdefghijklmnopqrstuvwxyz'.?" # P: Padding, E: EOS.
|
39 |
+
max_N = 180 # Maximum number of characters.
|
40 |
+
max_T = 210 # Maximum number of mel frames.
|
41 |
+
|
42 |
+
# training scheme
|
43 |
+
lr = 0.001 # Initial learning rate.
|
44 |
+
logdir = "logdir/LJ01"
|
45 |
+
sampledir = 'samples'
|
46 |
+
B = 32 # batch size
|
47 |
+
num_iterations = 2000000
|
utils/compute_args.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def compute_args(args):
|
5 |
+
# DataLoader
|
6 |
+
if not hasattr(args, 'dataset'): # fix for previous version
|
7 |
+
args.dataset = 'MOSEI'
|
8 |
+
|
9 |
+
if args.dataset == "MOSEI": args.dataloader = 'Mosei_Dataset'
|
10 |
+
if args.dataset == "MELD": args.dataloader = 'Meld_Dataset'
|
11 |
+
|
12 |
+
# Loss function to use
|
13 |
+
if args.dataset == 'MOSEI' and args.task == 'sentiment': args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
|
14 |
+
if args.dataset == 'MOSEI' and args.task == 'emotion': args.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="sum")
|
15 |
+
if args.dataset == 'MELD': args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
|
16 |
+
|
17 |
+
# Answer size
|
18 |
+
if args.dataset == 'MOSEI' and args.task == "sentiment": args.ans_size = 7
|
19 |
+
if args.dataset == 'MOSEI' and args.task == "sentiment" and args.task_binary: args.ans_size = 2
|
20 |
+
if args.dataset == 'MOSEI' and args.task == "emotion": args.ans_size = 6
|
21 |
+
if args.dataset == 'MELD' and args.task == "emotion": args.ans_size = 7
|
22 |
+
if args.dataset == 'MELD' and args.task == "sentiment": args.ans_size = 3
|
23 |
+
|
24 |
+
if args.dataset == 'MOSEI': args.pred_func = "amax"
|
25 |
+
if args.dataset == 'MOSEI' and args.task == "emotion": args.pred_func = "multi_label"
|
26 |
+
if args.dataset == 'MELD': args.pred_func = "amax"
|
27 |
+
|
28 |
+
return args
|
utils/plot.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import matplotlib.pyplot as plt
|
2 |
+
# import numpy as np
|
3 |
+
#
|
4 |
+
# def plot(d):
|
5 |
+
# # An "interface" to matplotlib.axes.Axes.hist() method
|
6 |
+
# n, bins, patches = plt.hist(x=d, bins='auto', color='#0504aa',
|
7 |
+
# alpha=0.7, rwidth=0.85)
|
8 |
+
# plt.grid(axis='y', alpha=0.75)
|
9 |
+
# plt.title('My Very Own Histogram')
|
10 |
+
# maxfreq = n.max()
|
11 |
+
# # Set a clean upper y-axis limit.
|
12 |
+
# plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10)
|
13 |
+
# plt.show()
|
utils/pred_func.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def amax(x):
|
5 |
+
return np.argmax(x, axis=1)
|
6 |
+
|
7 |
+
|
8 |
+
def multi_label(x):
|
9 |
+
return (x > 0)
|
utils/tokenize.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# $ wget https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz -O en_vectors_web_lg-2.1.0.tar.gz
|
2 |
+
# $ pip install en_vectors_web_lg-2.1.0.tar.gz
|
3 |
+
import en_vectors_web_lg
|
4 |
+
import re
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
import pickle
|
8 |
+
|
9 |
+
def clean(w):
|
10 |
+
return re.sub(
|
11 |
+
r"([.,'!?\"()*#:;])",
|
12 |
+
'',
|
13 |
+
w.lower()
|
14 |
+
).replace('-', ' ').replace('/', ' ')
|
15 |
+
|
16 |
+
|
17 |
+
def tokenize(key_to_word):
|
18 |
+
key_to_sentence = {}
|
19 |
+
for k, v in key_to_word.items():
|
20 |
+
key_to_sentence[k] = [clean(w) for w in v if clean(w) != '']
|
21 |
+
return key_to_sentence
|
22 |
+
|
23 |
+
|
24 |
+
def create_dict(key_to_sentence, dataroot, use_glove=True):
|
25 |
+
token_file = dataroot+"/token_to_ix.pkl"
|
26 |
+
glove_file = dataroot+"/train_glove.npy"
|
27 |
+
if os.path.exists(glove_file) and os.path.exists(token_file):
|
28 |
+
print("Loading train language files")
|
29 |
+
return pickle.load(open(token_file, "rb")), np.load(glove_file)
|
30 |
+
|
31 |
+
print("Creating train language files")
|
32 |
+
token_to_ix = {
|
33 |
+
'UNK': 1,
|
34 |
+
}
|
35 |
+
|
36 |
+
spacy_tool = None
|
37 |
+
pretrained_emb = []
|
38 |
+
if use_glove:
|
39 |
+
spacy_tool = en_vectors_web_lg.load()
|
40 |
+
pretrained_emb.append(spacy_tool('UNK').vector)
|
41 |
+
|
42 |
+
for k, v in key_to_sentence.items():
|
43 |
+
for word in v:
|
44 |
+
if word not in token_to_ix:
|
45 |
+
token_to_ix[word] = len(token_to_ix)
|
46 |
+
if use_glove:
|
47 |
+
pretrained_emb.append(spacy_tool(word).vector)
|
48 |
+
|
49 |
+
pretrained_emb = np.array(pretrained_emb)
|
50 |
+
np.save(glove_file, pretrained_emb)
|
51 |
+
pickle.dump(token_to_ix, open(token_file, "wb"))
|
52 |
+
return token_to_ix, pretrained_emb
|
53 |
+
|
54 |
+
def sent_to_ix(s, token_to_ix, max_token=100):
|
55 |
+
ques_ix = np.zeros(max_token, np.int64)
|
56 |
+
|
57 |
+
for ix, word in enumerate(s):
|
58 |
+
if word in token_to_ix:
|
59 |
+
ques_ix[ix] = token_to_ix[word]
|
60 |
+
else:
|
61 |
+
ques_ix[ix] = token_to_ix['UNK']
|
62 |
+
|
63 |
+
if ix + 1 == max_token:
|
64 |
+
break
|
65 |
+
|
66 |
+
return ques_ix
|
67 |
+
|
68 |
+
|
69 |
+
def cmumosei_7(a):
|
70 |
+
if a < -2:
|
71 |
+
res = 0
|
72 |
+
if -2 <= a and a < -1:
|
73 |
+
res = 1
|
74 |
+
if -1 <= a and a < 0:
|
75 |
+
res = 2
|
76 |
+
if 0 <= a and a <= 0:
|
77 |
+
res = 3
|
78 |
+
if 0 < a and a <= 1:
|
79 |
+
res = 4
|
80 |
+
if 1 < a and a <= 2:
|
81 |
+
res = 5
|
82 |
+
if a > 2:
|
83 |
+
res = 6
|
84 |
+
return res
|
85 |
+
|
86 |
+
def cmumosei_2(a):
|
87 |
+
if a < 0:
|
88 |
+
return 0
|
89 |
+
if a >= 0:
|
90 |
+
return 1
|
91 |
+
|
92 |
+
def pad_feature(feat, max_len):
|
93 |
+
if feat.shape[0] > max_len:
|
94 |
+
feat = feat[:max_len]
|
95 |
+
|
96 |
+
feat = np.pad(
|
97 |
+
feat,
|
98 |
+
((0, max_len - feat.shape[0]), (0, 0)),
|
99 |
+
mode='constant',
|
100 |
+
constant_values=0
|
101 |
+
)
|
102 |
+
|
103 |
+
return feat
|