Jyothirmai committed
Commit 26e26de
Parent(s): e96b01f
Upload 10 files

Browse files:
- __init__.py +0 -0
- build_vocab.py +80 -0
- callbacks.py +1066 -0
- dataset.py +1 -1
- logger.py +71 -0
- models_debugger.py +816 -0
- tcn.py +83 -0
__init__.py
ADDED
File without changes
build_vocab.py
ADDED
@@ -0,0 +1,80 @@
+import pickle
+from collections import Counter
+import json
+
+
+class JsonReader(object):
+    def __init__(self, json_file):
+        self.data = self.__read_json(json_file)
+        self.keys = list(self.data.keys())
+
+    def __read_json(self, filename):
+        with open(filename, 'r') as f:
+            data = json.load(f)
+        return data
+
+    def __getitem__(self, item):
+        # Index by position so that iterating over a JsonReader works: the
+        # implicit iteration protocol calls __getitem__(0), __getitem__(1), ...
+        # while self.data is keyed by the JSON keys, not by integers.
+        return self.data[self.keys[item]]
+
+    def __len__(self):
+        return len(self.data)
+
+
+class Vocabulary(object):
+    def __init__(self):
+        self.word2idx = {}
+        self.id2word = {}
+        self.idx = 0
+        self.add_word('<pad>')
+        self.add_word('<end>')
+        self.add_word('<start>')
+        self.add_word('<unk>')
+
+    def add_word(self, word):
+        if word not in self.word2idx:
+            self.word2idx[word] = self.idx
+            self.id2word[self.idx] = word
+            self.idx += 1
+
+    def get_word_by_id(self, id):
+        return self.id2word[id]
+
+    def __call__(self, word):
+        if word not in self.word2idx:
+            return self.word2idx['<unk>']
+        return self.word2idx[word]
+
+    def __len__(self):
+        return len(self.word2idx)
+
+
+def build_vocab(json_file, threshold):
+    caption_reader = JsonReader(json_file)
+    counter = Counter()
+
+    for items in caption_reader:
+        text = items.replace('.', '').replace(',', '')
+        counter.update(text.lower().split(' '))
+    words = [word for word, cnt in counter.items() if cnt > threshold and word != '']
+    vocab = Vocabulary()
+
+    for word in words:
+        print(word)
+        vocab.add_word(word)
+    return vocab
+
+
+def main(json_file, threshold, vocab_path):
+    vocab = build_vocab(json_file=json_file,
+                        threshold=threshold)
+    with open(vocab_path, 'wb') as f:
+        pickle.dump(vocab, f)
+    print("Total vocabulary size: {}".format(len(vocab)))
+    print("Saved vocabulary to {}".format(vocab_path))
+
+
+if __name__ == '__main__':
+    main(json_file='../data/new_data/debugging_captions.json',
+         threshold=0,
+         vocab_path='../data/new_data/debug_vocab.pkl')
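For context, here is a minimal sketch of how the pickled vocabulary is consumed downstream; the pickle path matches the `main` defaults above, everything else is illustrative:

```python
import pickle
from build_vocab import Vocabulary  # the class must be importable when unpickling

with open('../data/new_data/debug_vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

tokens = 'the heart is normal'.split()
ids = [vocab(t) for t in tokens]                # unknown words map to '<unk>'
words = [vocab.get_word_by_id(i) for i in ids]  # and back again
print(len(vocab), ids, words)
```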
callbacks.py
ADDED
@@ -0,0 +1,1066 @@
+"""Callbacks: utilities called at certain points during model training.
+
+# Adapted from
+
+- https://github.com/keras-team/keras
+- https://github.com/bstriner/keras-tqdm/blob/master/keras_tqdm/tqdm_callback.py
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import csv
+import six
+
+import numpy as np
+import time
+import json
+import warnings
+from tqdm import tqdm
+
+from collections import deque
+from collections import OrderedDict
+try:
+    from collections.abc import Iterable  # Python 3.3+ (collections.Iterable was removed in 3.10)
+except ImportError:
+    from collections import Iterable
+
+try:
+    import requests
+except ImportError:
+    requests = None
+
+
+class CallbackList(object):
+    """Container abstracting a list of callbacks.
+
+    # Arguments
+        callbacks: List of `Callback` instances.
+        queue_length: Queue length for keeping
+            running statistics over callback execution time.
+    """
+
+    def __init__(self, callbacks=None, queue_length=10):
+        callbacks = callbacks or []
+        self.callbacks = [c for c in callbacks]
+        self.queue_length = queue_length
+
+    def append(self, callback):
+        self.callbacks.append(callback)
+
+    def set_params(self, params):
+        for callback in self.callbacks:
+            callback.set_params(params)
+
+    def set_model(self, model):
+        for callback in self.callbacks:
+            callback.set_model(model)
+
+    def on_epoch_begin(self, epoch, logs=None):
+        """Called at the start of an epoch.
+
+        # Arguments
+            epoch: integer, index of epoch.
+            logs: dictionary of logs.
+        """
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_epoch_begin(epoch, logs)
+        self._delta_t_batch = 0.
+        self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
+        self._delta_ts_batch_end = deque([], maxlen=self.queue_length)
+
+    def on_epoch_end(self, epoch, logs=None):
+        """Called at the end of an epoch.
+
+        # Arguments
+            epoch: integer, index of epoch.
+            logs: dictionary of logs.
+        """
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_epoch_end(epoch, logs)
+
+    def on_batch_begin(self, batch, logs=None):
+        """Called right before processing a batch.
+
+        # Arguments
+            batch: integer, index of batch within the current epoch.
+            logs: dictionary of logs.
+        """
+        logs = logs or {}
+        t_before_callbacks = time.time()
+        for callback in self.callbacks:
+            callback.on_batch_begin(batch, logs)
+        self._delta_ts_batch_begin.append(time.time() - t_before_callbacks)
+        delta_t_median = np.median(self._delta_ts_batch_begin)
+        if (self._delta_t_batch > 0. and
+                delta_t_median > 0.95 * self._delta_t_batch and
+                delta_t_median > 0.1):
+            warnings.warn('Method on_batch_begin() is slow compared '
+                          'to the batch update (%f). Check your callbacks.'
+                          % delta_t_median)
+        self._t_enter_batch = time.time()
+
+    def on_batch_end(self, batch, logs=None):
+        """Called at the end of a batch.
+
+        # Arguments
+            batch: integer, index of batch within the current epoch.
+            logs: dictionary of logs.
+        """
+        logs = logs or {}
+        if not hasattr(self, '_t_enter_batch'):
+            self._t_enter_batch = time.time()
+        self._delta_t_batch = time.time() - self._t_enter_batch
+        t_before_callbacks = time.time()
+        for callback in self.callbacks:
+            callback.on_batch_end(batch, logs)
+        self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
+        delta_t_median = np.median(self._delta_ts_batch_end)
+        if (self._delta_t_batch > 0. and
+                (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
+            warnings.warn('Method on_batch_end() is slow compared '
+                          'to the batch update (%f). Check your callbacks.'
+                          % delta_t_median)
+
+    def on_train_begin(self, logs=None):
+        """Called at the beginning of training.
+
+        # Arguments
+            logs: dictionary of logs.
+        """
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_train_begin(logs)
+
+    def on_train_end(self, logs=None):
+        """Called at the end of training.
+
+        # Arguments
+            logs: dictionary of logs.
+        """
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_train_end(logs)
+
+    def __iter__(self):
+        return iter(self.callbacks)
+
+
+class Callback(object):
+    """Abstract base class used to build new callbacks.
+
+    # Properties
+        params: dict. Training parameters
+            (e.g. verbosity, batch size, number of epochs...).
+        model: instance of `keras.models.Model`.
+            Reference of the model being trained.
+
+    The `logs` dictionary that callback methods
+    take as argument will contain keys for quantities relevant to
+    the current batch or epoch.
+
+    Currently, the `.fit()` method of the `Sequential` model class
+    will include the following quantities in the `logs` that
+    it passes to its callbacks:
+
+        on_epoch_end: logs include `acc` and `loss`, and
+            optionally include `val_loss`
+            (if validation is enabled in `fit`), and `val_acc`
+            (if validation and accuracy monitoring are enabled).
+        on_batch_begin: logs include `size`,
+            the number of samples in the current batch.
+        on_batch_end: logs include `loss`, and optionally `acc`
+            (if accuracy monitoring is enabled).
+    """
+
+    def __init__(self):
+        self.validation_data = None
+        self.model = None
+
+    def set_params(self, params):
+        self.params = params
+
+    def set_model(self, model):
+        self.model = model
+
+    def on_epoch_begin(self, epoch, logs=None):
+        pass
+
+    def on_epoch_end(self, epoch, logs=None):
+        pass
+
+    def on_batch_begin(self, batch, logs=None):
+        pass
+
+    def on_batch_end(self, batch, logs=None):
+        pass
+
+    def on_train_begin(self, logs=None):
+        pass
+
+    def on_train_end(self, logs=None):
+        pass
+
+
+class BaseLogger(Callback):
+    """Callback that accumulates epoch averages of metrics.
+
+    This callback is automatically applied to every Keras model.
+    """
+
+    def on_epoch_begin(self, epoch, logs=None):
+        self.seen = 0
+        self.totals = {}
+
+    def on_batch_end(self, batch, logs=None):
+        logs = logs or {}
+        batch_size = logs.get('size', 0)
+        self.seen += batch_size
+
+        for k, v in logs.items():
+            if k in self.totals:
+                self.totals[k] += v * batch_size
+            else:
+                self.totals[k] = v * batch_size
+
+    def on_epoch_end(self, epoch, logs=None):
+        if logs is not None:
+            for k in self.params['metrics']:
+                if k in self.totals:
+                    # Make value available to next callbacks.
+                    logs[k] = self.totals[k] / self.seen
+
+
+class TerminateOnNaN(Callback):
+    """Callback that terminates training when a NaN loss is encountered.
+    """
+
+    def __init__(self):
+        super(TerminateOnNaN, self).__init__()
+
+    def on_batch_end(self, batch, logs=None):
+        logs = logs or {}
+        loss = logs.get('loss')
+        if loss is not None:
+            if np.isnan(loss) or np.isinf(loss):
+                print('Batch %d: Invalid loss, terminating training' % (batch))
+                self.model.stop_training = True
+
+
+class History(Callback):
+    """Callback that records events into a `History` object.
+
+    This callback is automatically applied to
+    every Keras model. The `History` object
+    gets returned by the `fit` method of models.
+    """
+
+    def on_train_begin(self, logs=None):
+        self.epoch = []
+        self.history = {}
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+        self.epoch.append(epoch)
+        for k, v in logs.items():
+            self.history.setdefault(k, []).append(v)
+
+
+class ModelCheckpoint(Callback):
+    """Save the model after every epoch.
+
+    `filepath` can contain named formatting options,
+    which will be filled with the value of `epoch` and
+    keys in `logs` (passed in `on_epoch_end`).
+
+    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
+    then the model checkpoints will be saved with the epoch number and
+    the validation loss in the filename.
+
+    # Arguments
+        filepath: string, path to save the model file.
+        monitor: quantity to monitor.
+        verbose: verbosity mode, 0 or 1.
+        save_best_only: if `save_best_only=True`,
+            the latest best model according to
+            the quantity monitored will not be overwritten.
+        mode: one of {auto, min, max}.
+            If `save_best_only=True`, the decision
+            to overwrite the current save file is made
+            based on either the maximization or the
+            minimization of the monitored quantity. For `val_acc`,
+            this should be `max`, for `val_loss` this should
+            be `min`, etc. In `auto` mode, the direction is
+            automatically inferred from the name of the monitored quantity.
+        save_weights_only: if True, only the model's weights are saved
+            (`torch.save(self.model.state_dict(), filepath)`), else the
+            full model is saved (`torch.save(self.model, filepath)`).
+        period: Interval (number of epochs) between checkpoints.
+    """
+
+    def __init__(self, filepath, monitor='val_loss', verbose=0,
+                 save_best_only=False, save_weights_only=False,
+                 mode='auto', period=1):
+        super(ModelCheckpoint, self).__init__()
+        self.monitor = monitor
+        self.verbose = verbose
+        self.filepath = filepath
+        self.save_best_only = save_best_only
+        self.save_weights_only = save_weights_only
+        self.period = period
+        self.epochs_since_last_save = 0
+
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn('ModelCheckpoint mode %s is unknown, '
+                          'fallback to auto mode.' % (mode),
+                          RuntimeWarning)
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+            self.best = np.Inf
+        elif mode == 'max':
+            self.monitor_op = np.greater
+            self.best = -np.Inf
+        else:
+            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
+                self.monitor_op = np.greater
+                self.best = -np.Inf
+            else:
+                self.monitor_op = np.less
+                self.best = np.Inf
+
+    def on_epoch_end(self, epoch, logs=None):
+        import torch
+        logs = logs or {}
+        self.epochs_since_last_save += 1
+        if self.epochs_since_last_save >= self.period:
+            self.epochs_since_last_save = 0
+            filepath = self.filepath.format(epoch=epoch + 1, **logs)
+            if self.save_best_only:
+                current = logs.get(self.monitor)
+                if current is None:
+                    warnings.warn('Can save best model only with %s available, '
+                                  'skipping.' % (self.monitor), RuntimeWarning)
+                else:
+                    if self.monitor_op(current, self.best):
+                        if self.verbose > 0:
+                            print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
+                                  ' saving model to %s'
+                                  % (epoch + 1, self.monitor, self.best,
+                                     current, filepath))
+                        self.best = current
+                        if self.save_weights_only:
+                            torch.save(self.model.state_dict(), filepath)
+                        else:
+                            # Save the full model, not just the weights.
+                            torch.save(self.model, filepath)
+                    else:
+                        if self.verbose > 0:
+                            print('\nEpoch %05d: %s did not improve' %
+                                  (epoch + 1, self.monitor))
+            else:
+                if self.verbose > 0:
+                    print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
+                if self.save_weights_only:
+                    torch.save(self.model.state_dict(), filepath)
+                else:
+                    # Save the full model, not just the weights.
+                    torch.save(self.model, filepath)
+
+
+class EarlyStopping(Callback):
+    """Stop training when a monitored quantity has stopped improving.
+
+    # Arguments
+        monitor: quantity to be monitored.
+        min_delta: minimum change in the monitored quantity
+            to qualify as an improvement, i.e. an absolute
+            change of less than min_delta will count as no
+            improvement.
+        patience: number of epochs with no improvement
+            after which training will be stopped.
+        verbose: verbosity mode.
+        mode: one of {auto, min, max}. In `min` mode,
+            training will stop when the quantity
+            monitored has stopped decreasing; in `max`
+            mode it will stop when the quantity
+            monitored has stopped increasing; in `auto`
+            mode, the direction is automatically inferred
+            from the name of the monitored quantity.
+    """
+
+    def __init__(self, monitor='val_loss',
+                 min_delta=0, patience=0, verbose=0, mode='auto'):
+        super(EarlyStopping, self).__init__()
+
+        self.monitor = monitor
+        self.patience = patience
+        self.verbose = verbose
+        self.min_delta = min_delta
+        self.wait = 0
+        self.stopped_epoch = 0
+
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn('EarlyStopping mode %s is unknown, '
+                          'fallback to auto mode.' % mode,
+                          RuntimeWarning)
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+        elif mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            if 'acc' in self.monitor:
+                self.monitor_op = np.greater
+            else:
+                self.monitor_op = np.less
+
+        if self.monitor_op == np.greater:
+            self.min_delta *= 1
+        else:
+            self.min_delta *= -1
+
+    def on_train_begin(self, logs=None):
+        # Allow instances to be re-used
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+        current = logs.get(self.monitor)
+        if current is None:
+            warnings.warn(
+                'Early stopping conditioned on metric `%s` '
+                'which is not available. Available metrics are: %s' %
+                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
+            )
+            return
+        if self.monitor_op(current - self.min_delta, self.best):
+            self.best = current
+            self.wait = 0
+        else:
+            self.wait += 1
+            if self.wait >= self.patience:
+                self.stopped_epoch = epoch
+                self.model.stop_training = True
+
+    def on_train_end(self, logs=None):
+        if self.stopped_epoch > 0 and self.verbose > 0:
+            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
+
+
+class RemoteMonitor(Callback):
+    """Callback used to stream events to a server.
+
+    Requires the `requests` library.
+    Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
+    HTTP POST, with an `images` argument which is a
+    JSON-encoded dictionary of event images.
+
+    # Arguments
+        root: String; root url of the target server.
+        path: String; path relative to `root` to which the events will be sent.
+        field: String; JSON field under which the images will be stored.
+        headers: Dictionary; optional custom HTTP headers.
+    """
+
+    def __init__(self,
+                 root='http://localhost:9000',
+                 path='/publish/epoch/end/',
+                 field='images',
+                 headers=None):
+        super(RemoteMonitor, self).__init__()
+
+        self.root = root
+        self.path = path
+        self.field = field
+        self.headers = headers
+
+    def on_epoch_end(self, epoch, logs=None):
+        if requests is None:
+            raise ImportError('RemoteMonitor requires '
+                              'the `requests` library.')
+        logs = logs or {}
+        send = {}
+        send['epoch'] = epoch
+        for k, v in logs.items():
+            if isinstance(v, (np.ndarray, np.generic)):
+                send[k] = v.item()
+            else:
+                send[k] = v
+        try:
+            requests.post(self.root + self.path,
+                          {self.field: json.dumps(send)},
+                          headers=self.headers)
+        except requests.exceptions.RequestException:
+            warnings.warn('Warning: could not reach RemoteMonitor '
+                          'root server at ' + str(self.root))
+
+
+class TensorBoard(Callback):
+    """TensorBoard basic visualizations.
+
+    [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
+    is a visualization tool provided with TensorFlow.
+
+    This callback writes a log for TensorBoard, which allows
+    you to visualize dynamic graphs of your training and test
+    metrics, as well as activation histograms for the different
+    layers in your model.
+
+    If you have installed TensorFlow with pip, you should be able
+    to launch TensorBoard from the command line:
+    ```sh
+    tensorboard --logdir=/full_path_to_your_logs
+    ```
+
+    When using a backend other than TensorFlow, TensorBoard will still work
+    (if you have TensorFlow installed), but the only feature available will
+    be the display of the losses and metrics plots.
+
+    # Arguments
+        log_dir: the path of the directory where to save the log
+            files to be parsed by TensorBoard.
+        histogram_freq: frequency (in epochs) at which to compute activation
+            and weight histograms for the layers of the model. If set to 0,
+            histograms won't be computed. Validation images (or split) must be
+            specified for histogram visualizations.
+        write_graph: whether to visualize the graph in TensorBoard.
+            The log file can become quite large when
+            write_graph is set to True.
+        write_grads: whether to visualize gradient histograms in TensorBoard.
+            `histogram_freq` must be greater than 0.
+        batch_size: size of batch of inputs to feed to the network
+            for histograms computation.
+        write_images: whether to write model weights to visualize as
+            image in TensorBoard.
+        embeddings_freq: frequency (in epochs) at which selected embedding
+            layers will be saved.
+        embeddings_layer_names: a list of names of layers to keep an eye on. If
+            None or an empty list, all embedding layers will be watched.
+        embeddings_metadata: a dictionary which maps layer name to a file name
+            in which metadata for this embedding layer is saved. See the
+            [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
+            about metadata files format. If the same metadata file is used
+            for all embedding layers, a single string can be passed.
+    """
+
+    def __init__(self, log_dir='./logs',
+                 histogram_freq=0,
+                 batch_size=32,
+                 write_graph=True,
+                 write_grads=False,
+                 write_images=False,
+                 embeddings_freq=0,
+                 embeddings_layer_names=None,
+                 embeddings_metadata=None):
+        super(TensorBoard, self).__init__()
+        global tf, projector, K
+        try:
+            import tensorflow as tf
+            from tensorflow.contrib.tensorboard.plugins import projector
+            # This callback is carried over from Keras and still relies on
+            # the Keras backend (`K`) for session and shape utilities.
+            from keras import backend as K
+        except ImportError:
+            raise ImportError('You need the TensorFlow module installed to use TensorBoard.')
+
+        if K.backend() != 'tensorflow':
+            if histogram_freq != 0:
+                warnings.warn('You are not using the TensorFlow backend. '
+                              'histogram_freq was set to 0')
+                histogram_freq = 0
+            if write_graph:
+                warnings.warn('You are not using the TensorFlow backend. '
+                              'write_graph was set to False')
+                write_graph = False
+            if write_images:
+                warnings.warn('You are not using the TensorFlow backend. '
+                              'write_images was set to False')
+                write_images = False
+            if embeddings_freq != 0:
+                warnings.warn('You are not using the TensorFlow backend. '
+                              'embeddings_freq was set to 0')
+                embeddings_freq = 0
+
+        self.log_dir = log_dir
+        self.histogram_freq = histogram_freq
+        self.merged = None
+        self.write_graph = write_graph
+        self.write_grads = write_grads
+        self.write_images = write_images
+        self.embeddings_freq = embeddings_freq
+        self.embeddings_layer_names = embeddings_layer_names
+        self.embeddings_metadata = embeddings_metadata or {}
+        self.batch_size = batch_size
+
+    def set_model(self, model):
+        self.model = model
+        if K.backend() == 'tensorflow':
+            self.sess = K.get_session()
+        if self.histogram_freq and self.merged is None:
+            for layer in self.model.layers:
+
+                for weight in layer.weights:
+                    mapped_weight_name = weight.name.replace(':', '_')
+                    tf.summary.histogram(mapped_weight_name, weight)
+                    if self.write_grads:
+                        grads = model.optimizer.get_gradients(model.total_loss,
+                                                              weight)
+
+                        def is_indexed_slices(grad):
+                            return type(grad).__name__ == 'IndexedSlices'
+                        grads = [
+                            grad.values if is_indexed_slices(grad) else grad
+                            for grad in grads]
+                        tf.summary.histogram('{}_grad'.format(mapped_weight_name), grads)
+                    if self.write_images:
+                        w_img = tf.squeeze(weight)
+                        shape = K.int_shape(w_img)
+                        if len(shape) == 2:  # dense layer kernel case
+                            if shape[0] > shape[1]:
+                                w_img = tf.transpose(w_img)
+                                shape = K.int_shape(w_img)
+                            w_img = tf.reshape(w_img, [1,
+                                                       shape[0],
+                                                       shape[1],
+                                                       1])
+                        elif len(shape) == 3:  # convnet case
+                            if K.image_data_format() == 'channels_last':
+                                # switch to channels_first to display
+                                # every kernel as a separate image
+                                w_img = tf.transpose(w_img, perm=[2, 0, 1])
+                                shape = K.int_shape(w_img)
+                            w_img = tf.reshape(w_img, [shape[0],
+                                                       shape[1],
+                                                       shape[2],
+                                                       1])
+                        elif len(shape) == 1:  # bias case
+                            w_img = tf.reshape(w_img, [1,
+                                                       shape[0],
+                                                       1,
+                                                       1])
+                        else:
+                            # not possible to handle 3D convnets etc.
+                            continue
+
+                        shape = K.int_shape(w_img)
+                        assert len(shape) == 4 and shape[-1] in [1, 3, 4]
+                        tf.summary.image(mapped_weight_name, w_img)
+
+                if hasattr(layer, 'output'):
+                    tf.summary.histogram('{}_out'.format(layer.name),
+                                         layer.output)
+        self.merged = tf.summary.merge_all()
+
+        if self.write_graph:
+            self.writer = tf.summary.FileWriter(self.log_dir,
+                                                self.sess.graph)
+        else:
+            self.writer = tf.summary.FileWriter(self.log_dir)
+
+        if self.embeddings_freq:
+            embeddings_layer_names = self.embeddings_layer_names
+
+            if not embeddings_layer_names:
+                embeddings_layer_names = [layer.name for layer in self.model.layers
+                                          if type(layer).__name__ == 'Embedding']
+
+            embeddings = {layer.name: layer.weights[0]
+                          for layer in self.model.layers
+                          if layer.name in embeddings_layer_names}
+
+            self.saver = tf.train.Saver(list(embeddings.values()))
+
+            if not isinstance(self.embeddings_metadata, str):
+                embeddings_metadata = self.embeddings_metadata
+            else:
+                embeddings_metadata = {layer_name: self.embeddings_metadata
+                                       for layer_name in embeddings.keys()}
+
+            config = projector.ProjectorConfig()
+            self.embeddings_ckpt_path = os.path.join(self.log_dir,
+                                                     'keras_embedding.ckpt')
+
+            for layer_name, tensor in embeddings.items():
+                embedding = config.embeddings.add()
+                embedding.tensor_name = tensor.name
+
+                if layer_name in embeddings_metadata:
+                    embedding.metadata_path = embeddings_metadata[layer_name]
+
+            projector.visualize_embeddings(self.writer, config)
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+
+        if not self.validation_data and self.histogram_freq:
+            raise ValueError('If printing histograms, validation_data must be '
+                             'provided, and cannot be a generator.')
+        if self.validation_data and self.histogram_freq:
+            if epoch % self.histogram_freq == 0:
+
+                val_data = self.validation_data
+                tensors = (self.model.inputs +
+                           self.model.targets +
+                           self.model.sample_weights)
+
+                if self.model.uses_learning_phase:
+                    tensors += [K.learning_phase()]
+
+                assert len(val_data) == len(tensors)
+                val_size = val_data[0].shape[0]
+                i = 0
+                while i < val_size:
+                    step = min(self.batch_size, val_size - i)
+                    if self.model.uses_learning_phase:
+                        # do not slice the learning phase
+                        batch_val = [x[i:i + step] for x in val_data[:-1]]
+                        batch_val.append(val_data[-1])
+                    else:
+                        batch_val = [x[i:i + step] for x in val_data]
+                    assert len(batch_val) == len(tensors)
+                    feed_dict = dict(zip(tensors, batch_val))
+                    result = self.sess.run([self.merged], feed_dict=feed_dict)
+                    summary_str = result[0]
+                    self.writer.add_summary(summary_str, epoch)
+                    i += self.batch_size
+
+        if self.embeddings_freq and self.embeddings_ckpt_path:
+            if epoch % self.embeddings_freq == 0:
+                self.saver.save(self.sess,
+                                self.embeddings_ckpt_path,
+                                epoch)
+
+        for name, value in logs.items():
+            if name in ['batch', 'size']:
+                continue
+            summary = tf.Summary()
+            summary_value = summary.value.add()
+            # Handle plain Python numbers as well as numpy scalars.
+            summary_value.simple_value = value.item() if hasattr(value, 'item') else value
+            summary_value.tag = name
+            self.writer.add_summary(summary, epoch)
+        self.writer.flush()
+
+    def on_train_end(self, _):
+        self.writer.close()
+
+
+class CSVLogger(Callback):
+    """Callback that streams epoch results to a csv file.
+
+    Supports all values that can be represented as a string,
+    including 1D iterables such as np.ndarray.
+
+    # Example
+
+    ```python
+    csv_logger = CSVLogger('training.log')
+    model.fit(X_train, Y_train, callbacks=[csv_logger])
+    ```
+
+    # Arguments
+        filename: filename of the csv file, e.g. 'run/log.csv'.
+        separator: string used to separate elements in the csv file.
+        append: True: append if file exists (useful for continuing
+            training). False: overwrite existing file.
+        output_on_train_end: an additional output to write the log to
+            when training ends. An example is
+            CSVLogger(filename='./mylog.csv', output_on_train_end=os.sys.stdout)
+    """
+
+    def __init__(self, filename, separator=',', append=False, output_on_train_end=None):
+        self.sep = separator
+        self.filename = filename
+        self.append = append
+        self.writer = None
+        self.keys = None
+        self.append_header = True
+        self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''
+        self.output_on_train_end = output_on_train_end
+        super(CSVLogger, self).__init__()
+
+    def on_train_begin(self, logs=None):
+        if self.append:
+            if os.path.exists(self.filename):
+                with open(self.filename, 'r' + self.file_flags) as f:
+                    self.append_header = not bool(len(f.readline()))
+            self.csv_file = open(self.filename, 'a' + self.file_flags)
+        else:
+            self.csv_file = open(self.filename, 'w' + self.file_flags)
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+
+        def handle_value(k):
+            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
+            if isinstance(k, six.string_types):
+                return k
+            elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
+                return '"[%s]"' % (', '.join(map(str, k)))
+            else:
+                return k
+
+        if self.keys is None:
+            self.keys = sorted(logs.keys())
+
+        if self.model is not None and getattr(self.model, 'stop_training', False):
+            # We set NA so that csv parsers do not fail for this last epoch.
+            logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])
+
+        if not self.writer:
+            class CustomDialect(csv.excel):
+                delimiter = self.sep
+
+            self.writer = csv.DictWriter(self.csv_file,
+                                         fieldnames=['epoch'] + self.keys, dialect=CustomDialect)
+            if self.append_header:
+                self.writer.writeheader()
+
+        row_dict = OrderedDict({'epoch': epoch})
+        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
+        self.writer.writerow(row_dict)
+        self.csv_file.flush()
+
+    def on_train_end(self, logs=None):
+        self.csv_file.close()
+        if self.output_on_train_end and os.path.exists(self.filename):
+            # Echo the finished log only when an extra output target was given.
+            with open(self.filename, 'r' + self.file_flags) as f:
+                print(f.read(), file=self.output_on_train_end)
+        self.writer = None
+
+
+class LambdaCallback(Callback):
+    r"""Callback for creating simple, custom callbacks on-the-fly.
+
+    This callback is constructed with anonymous functions that will be called
+    at the appropriate time. Note that the callbacks expect positional
+    arguments, as:
+
+    - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
+      `epoch`, `logs`
+    - `on_batch_begin` and `on_batch_end` expect two positional arguments:
+      `batch`, `logs`
+    - `on_train_begin` and `on_train_end` expect one positional argument:
+      `logs`
+
+    # Arguments
+        on_epoch_begin: called at the beginning of every epoch.
+        on_epoch_end: called at the end of every epoch.
+        on_batch_begin: called at the beginning of every batch.
+        on_batch_end: called at the end of every batch.
+        on_train_begin: called at the beginning of model training.
+        on_train_end: called at the end of model training.
+
+    # Example
+
+    ```python
+    # Print the batch number at the beginning of every batch.
+    batch_print_callback = LambdaCallback(
+        on_batch_begin=lambda batch, logs: print(batch))
+
+    # Stream the epoch loss to a file in JSON format. The file content
+    # is not well-formed JSON but rather has a JSON object per line.
+    import json
+    json_log = open('loss_log.json', mode='wt', buffering=1)
+    json_logging_callback = LambdaCallback(
+        on_epoch_end=lambda epoch, logs: json_log.write(
+            json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
+        on_train_end=lambda logs: json_log.close()
+    )
+
+    # Terminate some processes after having finished model training.
+    processes = ...
+    cleanup_callback = LambdaCallback(
+        on_train_end=lambda logs: [
+            p.terminate() for p in processes if p.is_alive()])
+
+    model.fit(...,
+              callbacks=[batch_print_callback,
+                         json_logging_callback,
+                         cleanup_callback])
+    ```
+    """
+
+    def __init__(self,
+                 on_epoch_begin=None,
+                 on_epoch_end=None,
+                 on_batch_begin=None,
+                 on_batch_end=None,
+                 on_train_begin=None,
+                 on_train_end=None,
+                 **kwargs):
+        super(LambdaCallback, self).__init__()
+        self.__dict__.update(kwargs)
+        if on_epoch_begin is not None:
+            self.on_epoch_begin = on_epoch_begin
+        else:
+            self.on_epoch_begin = lambda epoch, logs: None
+        if on_epoch_end is not None:
+            self.on_epoch_end = on_epoch_end
+        else:
+            self.on_epoch_end = lambda epoch, logs: None
+        if on_batch_begin is not None:
+            self.on_batch_begin = on_batch_begin
+        else:
+            self.on_batch_begin = lambda batch, logs: None
+        if on_batch_end is not None:
+            self.on_batch_end = on_batch_end
+        else:
+            self.on_batch_end = lambda batch, logs: None
+        if on_train_begin is not None:
+            self.on_train_begin = on_train_begin
+        else:
+            self.on_train_begin = lambda logs: None
+        if on_train_end is not None:
+            self.on_train_end = on_train_end
+        else:
+            self.on_train_end = lambda logs: None
+
+
+from sys import stderr  # default output stream for TQDMCallback below
+
+
+class TQDMCallback(Callback):
+    def __init__(self, outer_description="Training",
+                 inner_description_initial="Epoch: {epoch}",
+                 inner_description_update="Epoch: {epoch} - {metrics}",
+                 metric_format="{name}: {value:0.3f}",
+                 separator=", ",
+                 leave_inner=True,
+                 leave_outer=True,
+                 show_inner=True,
+                 show_outer=True,
+                 output_file=stderr,
+                 initial=0):
+        """
+        Construct a callback that will create and update progress bars.
+
+        :param outer_description: string for outer progress bar
+        :param inner_description_initial: initial format for epoch ("Epoch: {epoch}")
+        :param inner_description_update: format after metrics collected ("Epoch: {epoch} - {metrics}")
+        :param metric_format: format for each metric name/value pair ("{name}: {value:0.3f}")
+        :param separator: separator between metrics (", ")
+        :param leave_inner: True to leave inner bars
+        :param leave_outer: True to leave outer bars
+        :param show_inner: False to hide inner bars
+        :param show_outer: False to hide outer bar
+        :param output_file: output file (default sys.stderr)
+        :param initial: Initial counter state
+        """
+        self.outer_description = outer_description
+        self.inner_description_initial = inner_description_initial
+        self.inner_description_update = inner_description_update
+        self.metric_format = metric_format
+        self.separator = separator
+        self.leave_inner = leave_inner
+        self.leave_outer = leave_outer
+        self.show_inner = show_inner
+        self.show_outer = show_outer
+        self.output_file = output_file
+        self.tqdm_outer = None
+        self.tqdm_inner = None
+        self.epoch = None
+        self.running_logs = None
+        self.inner_count = None
+        self.initial = initial
+
+    def tqdm(self, desc, total, leave, initial=0):
+        """
+        Extension point. Override to provide custom options to tqdm initializer.
+        :param desc: Description string
+        :param total: Total number of updates
+        :param leave: Leave progress bar when done
+        :param initial: Initial counter state
+        :return: new progress bar
+        """
+        return tqdm(desc=desc, total=total, leave=leave, file=self.output_file, initial=initial)
+
+    def build_tqdm_outer(self, desc, total):
+        """
+        Extension point. Override to provide custom options to outer progress bars (Epoch loop)
+        :param desc: Description
+        :param total: Number of epochs
+        :return: new progress bar
+        """
+        return self.tqdm(desc=desc, total=total, leave=self.leave_outer, initial=self.initial)
+
+    def build_tqdm_inner(self, desc, total):
+        """
+        Extension point. Override to provide custom options to inner progress bars (Batch loop)
+        :param desc: Description
+        :param total: Number of batches
+        :return: new progress bar
+        """
+        return self.tqdm(desc=desc, total=total, leave=self.leave_inner)
+
+    def on_epoch_begin(self, epoch, logs={}):
+        self.epoch = epoch
+        desc = self.inner_description_initial.format(epoch=self.epoch)
+        self.mode = 0  # samples
+        if 'samples' in self.params:
+            self.inner_total = self.params['samples']
+        elif 'nb_sample' in self.params:
+            self.inner_total = self.params['nb_sample']
+        else:
+            self.mode = 1  # steps
+            self.inner_total = self.params['steps']
+        if self.show_inner:
+            self.tqdm_inner = self.build_tqdm_inner(desc=desc, total=self.inner_total)
+        self.inner_count = 0
+        self.running_logs = {}
+
+    def on_epoch_end(self, epoch, logs={}):
+        metrics = self.format_metrics(logs)
+        desc = self.inner_description_update.format(epoch=epoch, metrics=metrics)
+        if self.show_inner:
+            self.tqdm_inner.desc = desc
+            # set miniters and mininterval to 0 so the last update displays
+            self.tqdm_inner.miniters = 0
+            self.tqdm_inner.mininterval = 0
+            self.tqdm_inner.update(self.inner_total - self.tqdm_inner.n)
+            self.tqdm_inner.close()
+        if self.show_outer:
+            self.tqdm_outer.update(1)
+
+    def on_batch_begin(self, batch, logs={}):
+        pass
+
+    def on_batch_end(self, batch, logs={}):
+        if self.mode == 0:
+            update = logs['size']
+        else:
+            update = 1
+        self.inner_count += update
+        if self.inner_count < self.inner_total:
+            self.append_logs(logs)
+            metrics = self.format_metrics(self.running_logs)
+            desc = self.inner_description_update.format(epoch=self.epoch, metrics=metrics)
+            if self.show_inner:
+                self.tqdm_inner.desc = desc
+                self.tqdm_inner.update(update)
+
+    def on_train_begin(self, logs={}):
+        if self.show_outer:
+            epochs = (self.params['epochs'] if 'epochs' in self.params
+                      else self.params['nb_epoch'])
+            self.tqdm_outer = self.build_tqdm_outer(desc=self.outer_description,
+                                                    total=epochs)
+
+    def on_train_end(self, logs={}):
+        if self.show_outer:
+            self.tqdm_outer.close()
+
+    def append_logs(self, logs):
+        metrics = self.params['metrics']
+        for metric, value in six.iteritems(logs):
+            if metric in metrics:
+                if metric in self.running_logs:
+                    self.running_logs[metric].append(value[()])
+                else:
+                    self.running_logs[metric] = [value[()]]
+
+    def format_metrics(self, logs):
+        metrics = self.params['metrics']
+        strings = [self.metric_format.format(name=metric, value=np.mean(logs[metric], axis=None))
+                   for metric in metrics if metric in logs]
+        return self.separator.join(strings)
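Since these callbacks are adapted from Keras but save with `torch.save`, they have to be driven by hand from a PyTorch training loop. The following is a minimal sketch of that wiring, not the repository's actual training script; the model, data, and checkpoint pattern are placeholder assumptions:

```python
import torch
import torch.nn as nn
from callbacks import CallbackList, History, ModelCheckpoint, EarlyStopping

model = nn.Linear(10, 1)                 # placeholder model
model.stop_training = False              # flag that EarlyStopping sets
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

history = History()
callbacks = CallbackList([
    history,
    ModelCheckpoint('ckpt_{epoch:02d}.pth', monitor='loss',
                    save_best_only=True, mode='min'),
    EarlyStopping(monitor='loss', patience=2),
])
callbacks.set_model(model)
callbacks.set_params({'epochs': 10, 'metrics': ['loss']})

callbacks.on_train_begin()
for epoch in range(10):
    callbacks.on_epoch_begin(epoch)
    x, y = torch.randn(32, 10), torch.randn(32, 1)   # stand-in batch
    callbacks.on_batch_begin(0, logs={'size': len(x)})
    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    callbacks.on_batch_end(0, logs={'size': len(x), 'loss': loss.item()})
    callbacks.on_epoch_end(epoch, logs={'loss': loss.item()})
    if model.stop_training:
        break
callbacks.on_train_end()
print(history.history)
```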
dataset.py
CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import Dataset
 from PIL import Image
 import os
 import json
-from build_vocab import Vocabulary, JsonReader
+from utils.build_vocab import Vocabulary, JsonReader
 import numpy as np
 from torchvision import transforms
 import pickle
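One practical consequence of this one-line change: `pickle` records the defining module of a class, so a vocabulary pickled back when `build_vocab` was a top-level module will fail to load now that the class lives under `utils.build_vocab` (and vice versa). A small compatibility shim, sketched here with the pickle path from `build_vocab.py`, keeps old files loadable:

```python
import sys
import pickle
import utils.build_vocab

# Let pickles that reference the old top-level `build_vocab` module
# resolve against its new location (hypothetical compatibility shim).
sys.modules['build_vocab'] = utils.build_vocab

with open('../data/new_data/debug_vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
```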
logger.py
ADDED
@@ -0,0 +1,71 @@
+# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
+import tensorflow as tf
+import numpy as np
+import scipy.misc
+try:
+    from StringIO import StringIO  # Python 2.7
+except ImportError:
+    from io import BytesIO  # Python 3.x
+
+
+class Logger(object):
+
+    def __init__(self, log_dir):
+        """Create a summary writer logging to log_dir."""
+        self.writer = tf.summary.FileWriter(log_dir)
+
+    def scalar_summary(self, tag, value, step):
+        """Log a scalar variable."""
+        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
+        self.writer.add_summary(summary, step)
+
+    def image_summary(self, tag, images, step):
+        """Log a list of images."""
+
+        img_summaries = []
+        for i, img in enumerate(images):
+            # Write the image to a string buffer
+            try:
+                s = StringIO()
+            except NameError:  # Python 3: StringIO was never imported
+                s = BytesIO()
+            scipy.misc.toimage(img).save(s, format="png")
+
+            # Create an Image object
+            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
+                                       height=img.shape[0],
+                                       width=img.shape[1])
+            # Create a Summary value
+            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
+
+        # Create and write Summary
+        summary = tf.Summary(value=img_summaries)
+        self.writer.add_summary(summary, step)
+
+    def histo_summary(self, tag, values, step, bins=1000):
+        """Log a histogram of the tensor of values."""
+
+        # Create a histogram using numpy
+        counts, bin_edges = np.histogram(values, bins=bins)
+
+        # Fill the fields of the histogram proto
+        hist = tf.HistogramProto()
+        hist.min = float(np.min(values))
+        hist.max = float(np.max(values))
+        hist.num = int(np.prod(values.shape))
+        hist.sum = float(np.sum(values))
+        hist.sum_squares = float(np.sum(values**2))
+
+        # Drop the start of the first bin
+        bin_edges = bin_edges[1:]
+
+        # Add bin edges and counts
+        for edge in bin_edges:
+            hist.bucket_limit.append(edge)
+        for c in counts:
+            hist.bucket.append(c)
+
+        # Create and write Summary
+        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
+        self.writer.add_summary(summary, step)
+        self.writer.flush()
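A minimal sketch of how this `Logger` might be driven from a training loop; the log directory and metric values are hypothetical. Note that `tf.summary.FileWriter` is the TensorFlow 1.x API, so this module needs TF 1.x (or `tf.compat.v1`):

```python
import numpy as np
from logger import Logger

logger = Logger('./logs')                  # hypothetical log directory
for step in range(100):
    loss = 1.0 / (step + 1)                # stand-in metric
    logger.scalar_summary('train/loss', loss, step)
    logger.histo_summary('weights/fc', np.random.randn(1000), step)
```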
models_debugger.py
ADDED
@@ -0,0 +1,816 @@
import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.autograd import Variable
from torchvision.models.vgg import model_urls as vgg_model_urls
import torchvision.models as models

from utils.tcn import *


class DenseNet121(nn.Module):
    def __init__(self, classes=14, pretrained=True):
        super(DenseNet121, self).__init__()
        self.model = torchvision.models.densenet121(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
            # nn.Sigmoid()
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class DenseNet161(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(DenseNet161, self).__init__()
        self.model = torchvision.models.densenet161(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class DenseNet169(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(DenseNet169, self).__init__()
        self.model = torchvision.models.densenet169(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class DenseNet201(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(DenseNet201, self).__init__()
        self.model = torchvision.models.densenet201(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return per-class probabilities for a batch of images."""
        return self.model(x)


class ResNet18(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet18, self).__init__()
        self.model = torchvision.models.resnet18(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class ResNet34(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet34, self).__init__()
        self.model = torchvision.models.resnet34(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class ResNet50(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet50, self).__init__()
        self.model = torchvision.models.resnet50(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class ResNet101(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet101, self).__init__()
        self.model = torchvision.models.resnet101(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class ResNet152(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet152, self).__init__()
        self.model = torchvision.models.resnet152(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class VGG19(nn.Module):
    def __init__(self, classes=14, pretrained=True):
        super(VGG19, self).__init__()
        self.model = torchvision.models.vgg19_bn(pretrained=pretrained)
        self.model.classifier = nn.Sequential(
            self.__init_linear(in_features=25088, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            self.__init_linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            self.__init_linear(in_features=4096, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class VGG(nn.Module):
    def __init__(self, tags_num):
        super(VGG, self).__init__()
        # Fetch the pretrained weights over plain HTTP.
        vgg_model_urls['vgg19'] = vgg_model_urls['vgg19'].replace('https://', 'http://')
        self.vgg19 = models.vgg19(pretrained=True)
        vgg19_classifier = list(self.vgg19.classifier.children())[:-1]
        self.classifier = nn.Sequential(*vgg19_classifier)
        self.fc = nn.Linear(4096, tags_num)
        self.fc.apply(self.init_weights)
        self.bn = nn.BatchNorm1d(tags_num, momentum=0.1)

    def init_weights(self, m):
        if type(m) == nn.Linear:
            self.fc.weight.data.normal_(0, 0.1)
            self.fc.bias.data.fill_(0)

    def forward(self, images):
        """Return tag scores for a batch of images."""
        visual_feats = self.vgg19.features(images)
        tags_classifier = visual_feats.view(visual_feats.size(0), -1)
        tags_classifier = self.bn(self.fc(self.classifier(tags_classifier)))
        return tags_classifier


class InceptionV3(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(InceptionV3, self).__init__()
        self.model = torchvision.models.inception_v3(pretrained=pretrained)
        # inception_v3 exposes its classification head as `fc`, not `classifier`.
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)


class CheXNetDenseNet121(nn.Module):
    def __init__(self, classes=14, pretrained=True):
        super(CheXNetDenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(pretrained=pretrained)
        num_in_features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-class probabilities for a batch of images."""
        return self.densenet121(x)


class CheXNet(nn.Module):
    def __init__(self, classes=156):
        super(CheXNet, self).__init__()
        # Load the 14-class CheXNet checkpoint, then swap in a new head
        # sized for this task's label set.
        self.densenet121 = CheXNetDenseNet121(classes=14)
        self.densenet121 = torch.nn.DataParallel(self.densenet121).cuda()
        self.densenet121.load_state_dict(torch.load('./models/CheXNet.pth.tar')['state_dict'])
        self.densenet121.module.densenet121.classifier = nn.Sequential(
            self.__init_linear(1024, classes),
            nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        """Return per-class probabilities for a batch of images."""
        return self.densenet121(x)


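# Note on the checkpoint above (an assumption about the external file, not code
# from this repo): CheXNet.pth.tar is expected to hold a dict with a
# 'state_dict' entry whose keys carry the 'module.' prefix that DataParallel
# adds. On a CPU-only machine the same load would need an explicit
# map_location, e.g.:
#
#   state = torch.load('./models/CheXNet.pth.tar', map_location='cpu')

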
class ModelFactory(object):
    def __init__(self, model_name, pretrained, classes):
        self.model_name = model_name
        self.pretrained = pretrained
        self.classes = classes

    def create_model(self):
        if self.model_name == 'VGG19':
            _model = VGG19(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'DenseNet121':
            _model = DenseNet121(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'DenseNet161':
            _model = DenseNet161(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'DenseNet169':
            _model = DenseNet169(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'DenseNet201':
            _model = DenseNet201(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'CheXNet':
            _model = CheXNet(classes=self.classes)
        elif self.model_name == 'ResNet18':
            _model = ResNet18(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'ResNet34':
            _model = ResNet34(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'ResNet50':
            _model = ResNet50(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'ResNet101':
            _model = ResNet101(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'ResNet152':
            _model = ResNet152(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'VGG':
            _model = VGG(tags_num=self.classes)
        else:
            _model = CheXNet(classes=self.classes)

        return _model


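# Usage sketch for the factory (illustrative only; the model name and class
# count below are assumed values, not taken from this repo's configs):
#
#   factory = ModelFactory(model_name='DenseNet121', pretrained=True, classes=156)
#   model = factory.create_model()
#   logits = model(torch.randn(2, 3, 224, 224))  # -> shape (2, 156)

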
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, pretrained=True):
        super(EncoderCNN, self).__init__()
        # TODO: extract image features from CNNs other than ResNet-152
        resnet = models.resnet152(pretrained=pretrained)
        modules = list(resnet.children())[:-1]  # drop the final fc layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.1)
        self.__init_weights()

    def __init_weights(self):
        self.linear.weight.data.normal_(0.0, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, images):
        """Encode a batch of images into embed_size-dimensional features."""
        features = self.resnet(images)
        # Re-wrapping .data cuts the graph, so only the linear projection
        # and batch norm receive gradients, not the ResNet backbone.
        features = Variable(features.data)
        features = features.view(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features


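# Note: Variable(features.data) is the legacy (pre-0.4) idiom for detaching;
# on current PyTorch the equivalent one-liner would simply be
#
#   features = self.resnet(images).detach()

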
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max

    def __init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, features, captions):
        """Predict the next word from the image features and the caption so far."""
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden[:, -1, :])
        return outputs

    def sample(self, features, start_tokens):
        """Greedy decoding: feed back the argmax word for n_max steps."""
        sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
        predicted = start_tokens
        embeddings = features.unsqueeze(1)

        for i in range(self.n_max):
            predicted = self.embed(predicted)
            embeddings = torch.cat([embeddings, predicted], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            hidden_states = hidden_states[:, -1, :]
            outputs = self.linear(hidden_states)
            predicted = torch.max(outputs, 1)[1]
            sampled_ids[:, i] = predicted.cpu().numpy()  # store word ids on the host
            predicted = predicted.unsqueeze(1)
        return sampled_ids


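# Sketch of how sampling is typically driven (shapes are assumptions inferred
# from the code above: features (N, embed_size), start_tokens (N, 1) of ids):
#
#   decoder = DecoderRNN(embed_size=512, hidden_size=512, vocab_size=1000, num_layers=1)
#   ids = decoder.sample(torch.randn(4, 512), torch.ones(4, 1).long())
#   # ids is a (4, 50) numpy array of greedily decoded word ids

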
class VisualFeatureExtractor(nn.Module):
    def __init__(self, pretrained=False):
        super(VisualFeatureExtractor, self).__init__()
        resnet = models.resnet152(pretrained=pretrained)
        modules = list(resnet.children())[:-1]  # keep everything up to the avgpool
        self.resnet = nn.Sequential(*modules)
        self.out_features = resnet.fc.in_features

    def forward(self, images):
        """Return flattened ResNet-152 features for a batch of images."""
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        return features


class MLC(nn.Module):
    def __init__(self, classes=156, semantic_features_dim=512, fc_in_features=2048, k=10):
        super(MLC, self).__init__()
        self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
        self.embed = nn.Embedding(classes, semantic_features_dim)
        self.k = k
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, visual_features):
        """Multi-label classifier: return tag scores and the embeddings of the top-k tags."""
        tags = self.softmax(self.classifier(visual_features))
        semantic_features = self.embed(torch.topk(tags, self.k)[1])
        return tags, semantic_features


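# Shape walkthrough (with the defaults above): visual_features (N, 2048)
# -> tags (N, 156). torch.topk(tags, 10)[1] picks the indices of the 10
# highest-scoring tags, shape (N, 10), and the embedding lookup turns them
# into semantic_features of shape (N, 10, 512), one vector per selected tag.

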
class CoAttention(nn.Module):
    def __init__(self, embed_size=512, hidden_size=512, visual_size=2048):
        super(CoAttention, self).__init__()
        self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
        self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a = nn.BatchNorm1d(num_features=10, momentum=0.1)

        self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
        self.bn_a_att = nn.BatchNorm1d(num_features=10, momentum=0.1)

        self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
        self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)  # normalize attention over the non-batch dimension

    def forward(self, visual_features, semantic_features, h_sent):
        """Fuse visual and semantic features, attended by the sentence state (training only)."""
        # Visual attention: gate the visual features by the current hidden state.
        W_v = self.bn_v(self.W_v(visual_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
        alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
        v_att = torch.mul(alpha_v, visual_features)

        # Semantic attention: weight the k tag embeddings and pool them.
        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)

        # Joint context vector for the sentence model.
        ctx = self.bn_fc(self.W_fc(torch.cat([v_att, a_att], dim=1)))
        return ctx, v_att


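# Shape walkthrough (with the defaults above): visual_features (N, 2048),
# semantic_features (N, 10, 512), h_sent (N, 1, 512). alpha_v gates the visual
# features elementwise, so v_att stays (N, 2048); the semantic branch pools
# over the 10 tags, so a_att is (N, 512); their concatenation (N, 2560) is
# projected down to the context vector ctx of shape (N, embed_size).

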
class SentenceLSTM(nn.Module):
    def __init__(self, embed_size=512, hidden_size=512, num_layers=1):
        super(SentenceLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers)
        self.W_t_h = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_t_ctx = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_stop_s_1 = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_stop_s = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
        self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_topic = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_topic_2 = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_topic_2 = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    # Earlier sigmoid-based variant of forward, kept for reference (training only):
    # def forward(self, ctx, prev_hidden_state, states=None):
    #     ctx = ctx.unsqueeze(1)
    #     hidden_state, states = self.lstm(ctx, states)
    #     topic = self.bn_topic(self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
    #                                                     + self.bn_t_ctx(self.W_t_ctx(ctx)))))
    #     p_stop = self.bn_stop(self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
    #                                                    + self.bn_stop_s(self.W_stop_s(hidden_state)))))
    #     return topic, p_stop, hidden_state, states

    def forward(self, ctx, prev_hidden_state, states=None):
        """v2: produce a topic vector and a stop/continue distribution for the current sentence."""
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.bn_topic(self.W_topic(self.tanh(self.bn_t_h(self.W_t_h(hidden_state)
                                                                 + self.W_t_ctx(ctx)))))
        p_stop = self.bn_stop(self.W_stop(self.tanh(self.bn_stop_s(self.W_stop_s_1(prev_hidden_state)
                                                                   + self.W_stop_s(hidden_state)))))
        return topic, p_stop, hidden_state, states


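# In the hierarchical decoder, this module runs once per sentence: `topic`
# seeds the word-level model below, while the two logits in p_stop are
# trained to decide whether to generate another sentence or stop.

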
class SentenceTCN(nn.Module):
    def __init__(self,
                 input_channel=10,
                 embed_size=512,
                 output_size=512,
                 nhid=512,
                 levels=8,
                 kernel_size=2,
                 dropout=0):
        super(SentenceTCN, self).__init__()
        channel_sizes = [nhid] * levels
        self.tcn = TCN(input_size=input_channel,
                       output_size=output_size,
                       num_channels=channel_sizes,
                       kernel_size=kernel_size,
                       dropout=dropout)
        self.W_t_h = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_t_ctx = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop_s_1 = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop_s = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
        self.t_w = nn.Linear(in_features=5120, out_features=2, bias=True)
        self.tanh = nn.Tanh()

    def forward(self, ctx, prev_output):
        """TCN-based counterpart of SentenceLSTM: topic vector plus stop distribution."""
        output = self.tcn.forward(ctx)
        topic = self.tanh(self.W_t_h(output) + self.W_t_ctx(ctx[:, -1, :]).squeeze(1))
        p_stop = self.W_stop(self.tanh(self.W_stop_s_1(prev_output) + self.W_stop_s(output)))
        return topic, p_stop, output


class WordLSTM(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
        super(WordLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max
        self.vocab_size = vocab_size

    def __init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, topic_vec, captions):
        """Predict the next word given the sentence topic and the words so far."""
        embeddings = self.embed(captions)
        embeddings = torch.cat((topic_vec, embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden[:, -1, :])
        return outputs

    def val(self, features, start_tokens):
        """Greedy decoding that keeps the full logit distribution at every step."""
        samples = torch.zeros((np.shape(features)[0], self.n_max, self.vocab_size))
        samples[:, 0, start_tokens[0]] = 1
        predicted = start_tokens
        embeddings = features

        for i in range(1, self.n_max):
            predicted = self.embed(predicted)
            embeddings = torch.cat([embeddings, predicted], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            hidden_states = hidden_states[:, -1, :]
            outputs = self.linear(hidden_states)
            samples[:, i, :] = outputs
            predicted = torch.max(outputs, 1)[1]
            predicted = predicted.unsqueeze(1)
        return samples

    def sample(self, features, start_tokens):
        """Greedy decoding that keeps only the argmax word ids."""
        sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
        sampled_ids[:, 0] = start_tokens.view(-1).cpu().numpy()
        predicted = start_tokens
        embeddings = features

        for i in range(1, self.n_max):
            predicted = self.embed(predicted)
            embeddings = torch.cat([embeddings, predicted], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            hidden_states = hidden_states[:, -1, :]
            outputs = self.linear(hidden_states)
            predicted = torch.max(outputs, 1)[1]
            sampled_ids[:, i] = predicted.cpu().numpy()  # store word ids on the host
            predicted = predicted.unsqueeze(1)
        return sampled_ids


class WordTCN(nn.Module):
    def __init__(self,
                 input_channel=11,
                 vocab_size=1000,
                 embed_size=512,
                 output_size=512,
                 nhid=512,
                 levels=8,
                 kernel_size=2,
                 dropout=0,
                 n_max=50):
        super(WordTCN, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.output_size = output_size
        channel_sizes = [nhid] * levels
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.n_max = n_max
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.W_out = nn.Linear(in_features=output_size, out_features=vocab_size, bias=True)
        self.tcn = TCN(input_size=input_channel,
                       output_size=output_size,
                       num_channels=channel_sizes,
                       kernel_size=kernel_size,
                       dropout=dropout)

    def forward(self, topic_vec, captions):
        """TCN-based counterpart of WordLSTM: predict word logits from topic plus caption."""
        captions = self.embed(captions)
        embeddings = torch.cat([topic_vec, captions], dim=1)
        output = self.tcn.forward(embeddings)
        words = self.W_out(output)
        return words


if __name__ == '__main__':
    import warnings
    warnings.filterwarnings("ignore")

    # Smoke test for the LSTM-based pipeline.
    images = torch.randn((4, 3, 224, 224))
    captions = torch.ones((4, 10)).long()
    hidden_state = torch.randn((4, 1, 512))

    print("images:{}".format(images.shape))
    print("captions:{}".format(captions.shape))
    print("hidden_states:{}".format(hidden_state.shape))

    extractor = VisualFeatureExtractor()
    visual_features = extractor.forward(images)
    print("visual_features:{}".format(visual_features.shape))

    mlc = MLC()
    tags, semantic_features = mlc.forward(visual_features)
    print("tags:{}".format(tags.shape))
    print("semantic_features:{}".format(semantic_features.shape))

    co_att = CoAttention()
    ctx, v_att = co_att.forward(visual_features, semantic_features, hidden_state)
    print("ctx:{}".format(ctx.shape))
    print("v_att:{}".format(v_att.shape))

    sent_lstm = SentenceLSTM()
    topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state)
    print("Topic:{}".format(topic.shape))
    print("P_STOP:{}".format(p_stop.shape))

    word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)
    words = word_lstm.forward(topic, captions)
    print("words:{}".format(words.shape))

    # Expected output:
    # images: torch.Size([4, 3, 224, 224])
    # captions: torch.Size([4, 10])
    # hidden_states: torch.Size([4, 1, 512])
    # visual_features: torch.Size([4, 2048])
    # tags: torch.Size([4, 156])
    # semantic_features: torch.Size([4, 10, 512])
    # ctx: torch.Size([4, 512])
    # Topic: torch.Size([4, 1, 512])
    # P_STOP: torch.Size([4, 1, 2])
    # words: torch.Size([4, 100])

    # Equivalent smoke test for the TCN-based pipeline:
    # images = torch.randn((4, 3, 224, 224))
    # captions = torch.ones((4, 3, 10)).long()
    # prev_outputs = torch.randn((4, 512))
    # now_words = torch.ones((4, 1))
    #
    # ctx_records = torch.zeros((4, 10, 512))
    # captions = torch.zeros((4, 10)).long()
    #
    # print("images:{}".format(images.shape))
    # print("captions:{}".format(captions.shape))
    # print("hidden_states:{}".format(prev_outputs.shape))
    #
    # extractor = VisualFeatureExtractor()
    # visual_features = extractor.forward(images)
    # print("visual_features:{}".format(visual_features.shape))
    #
    # mlc = MLC()
    # tags, semantic_features = mlc.forward(visual_features)
    # print("tags:{}".format(tags.shape))
    # print("semantic_features:{}".format(semantic_features.shape))
    #
    # co_att = CoAttention()
    # ctx = co_att.forward(visual_features, semantic_features, prev_outputs)
    # print("ctx:{}".format(ctx.shape))
    #
    # ctx_records[:, 0, :] = ctx
    #
    # sent_tcn = SentenceTCN()
    # topic, p_stop, prev_outputs = sent_tcn.forward(ctx_records, prev_outputs)
    # print("Topic:{}".format(topic.shape))
    # print("P_STOP:{}".format(p_stop.shape))
    # print("Prev_Outputs:{}".format(prev_outputs.shape))
    #
    # captions[:, 0] = now_words.view(-1,)
    #
    # word_tcn = WordTCN()
    # words = word_tcn.forward(topic, captions)
    # print("words:{}".format(words.shape))

tcn.py
ADDED
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class Chomp1d(nn.Module):
    """Trim the trailing padding so the convolution stays causal."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        # Two causal convolutions, each followed by chomp, ReLU and dropout.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU(inplace=False)
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU(inplace=False)
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        # A 1x1 convolution matches channel counts for the residual connection.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU(inplace=False)
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


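# Why the chomp is needed: a Conv1d with padding p = (kernel_size - 1) * dilation
# pads both ends, so its output is p steps longer than the input and the extra
# tail positions would see the future. Slicing off the last p steps restores a
# strictly causal convolution. Worked example: with kernel_size=2 and dilation=4,
# p = 4, an input of length L comes out of the convolution with length L + 4,
# and Chomp1d(4) trims it back to L.

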
class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            # Dilation doubles at every level, growing the receptive field exponentially.
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size - 1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


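# Receptive field of the stack: each level contains two convolutions of
# dilation 2^i, so with kernel size k and L levels the last output step sees
# 1 + 2*(k - 1)*(2^L - 1) input steps. With the values used by SentenceTCN and
# WordTCN above (k=2, L=8) that is 1 + 2*255 = 511 steps.

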
class TCN(nn.Module):
    def __init__(self, input_size, output_size, num_channels, kernel_size=2, dropout=0):
        super(TCN, self).__init__()
        self.tcn = TemporalConvNet(num_inputs=input_size,
                                   num_channels=num_channels,
                                   kernel_size=kernel_size,
                                   dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)
        self.init_weights()

    def init_weights(self):
        self.linear.weight.data.normal_(0, 0.01)
        self.linear.bias.data.fill_(0)

    def forward(self, inputs):
        y = self.tcn.forward(inputs)
        # Only the last time step feeds the linear head.
        output = self.linear(y[:, :, -1])
        return output