Jyothirmai committed on
Commit
26e26de
1 Parent(s): e96b01f

Upload 10 files

Files changed (7)
  1. __init__.py +0 -0
  2. build_vocab.py +80 -0
  3. callbacks.py +1066 -0
  4. dataset.py +1 -1
  5. logger.py +71 -0
  6. models_debugger.py +816 -0
  7. tcn.py +83 -0
__init__.py ADDED
File without changes
build_vocab.py ADDED
@@ -0,0 +1,80 @@
+ import pickle
+ from collections import Counter
+ import json
+
+
+ class JsonReader(object):
+     def __init__(self, json_file):
+         self.data = self.__read_json(json_file)
+         self.keys = list(self.data.keys())
+
+     def __read_json(self, filename):
+         with open(filename, 'r') as f:
+             data = json.load(f)
+         return data
+
+     def __getitem__(self, item):
+         # self.data is keyed by strings, so index positionally via self.keys;
+         # iteration (as in build_vocab below) supplies integer indices, and
+         # self.data[item] would raise KeyError for them.
+         return self.data[self.keys[item]]
+
+     def __len__(self):
+         return len(self.data)
+
+
+ class Vocabulary(object):
+     def __init__(self):
+         self.word2idx = {}
+         self.id2word = {}
+         self.idx = 0
+         self.add_word('<pad>')
+         self.add_word('<end>')
+         self.add_word('<start>')
+         self.add_word('<unk>')
+
+     def add_word(self, word):
+         if word not in self.word2idx:
+             self.word2idx[word] = self.idx
+             self.id2word[self.idx] = word
+             self.idx += 1
+
+     def get_word_by_id(self, id):
+         return self.id2word[id]
+
+     def __call__(self, word):
+         if word not in self.word2idx:
+             return self.word2idx['<unk>']
+         return self.word2idx[word]
+
+     def __len__(self):
+         return len(self.word2idx)
+
+
+ def build_vocab(json_file, threshold):
+     caption_reader = JsonReader(json_file)
+     counter = Counter()
+
+     for items in caption_reader:
+         text = items.replace('.', '').replace(',', '')
+         counter.update(text.lower().split(' '))
+     words = [word for word, cnt in counter.items() if cnt > threshold and word != '']
+     vocab = Vocabulary()
+
+     for word in words:
+         print(word)
+         vocab.add_word(word)
+     return vocab
+
+
+ def main(json_file, threshold, vocab_path):
+     vocab = build_vocab(json_file=json_file,
+                         threshold=threshold)
+     with open(vocab_path, 'wb') as f:
+         pickle.dump(vocab, f)
+     print("Total vocabulary size: {}".format(len(vocab)))
+     print("Saved the vocabulary to {}".format(vocab_path))
+
+
+ if __name__ == '__main__':
+     main(json_file='../data/new_data/debugging_captions.json',
+          threshold=0,
+          vocab_path='../data/new_data/debug_vocab.pkl')
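Note: a minimal sketch of how the pickled vocabulary is consumed downstream (the pickle path comes from `main()` above; the import path is whatever the module resolves to in the loading script):

```python
import pickle
# Vocabulary must be resolvable in the loading process for unpickling to
# succeed; this is why dataset.py imports it even without calling it directly.
from build_vocab import Vocabulary

with open('../data/new_data/debug_vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

caption = 'The heart is normal.'
tokens = caption.replace('.', '').replace(',', '').lower().split(' ')
ids = [vocab('<start>')] + [vocab(t) for t in tokens] + [vocab('<end>')]
print(ids)  # words below the threshold map to the <unk> index
print([vocab.get_word_by_id(i) for i in ids])
```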
callbacks.py ADDED
@@ -0,0 +1,1066 @@
+ """Callbacks: utilities called at certain points during model training.
+
+ # Adapted from
+
+ - https://github.com/keras-team/keras
+ - https://github.com/bstriner/keras-tqdm/blob/master/keras_tqdm/tqdm_callback.py
+
+ """
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import os
+ import csv
+ import six
+
+ import numpy as np
+ import time
+ import json
+ import warnings
+ from sys import stderr  # default output stream for TQDMCallback below
+ from tqdm import tqdm
+
+ from collections import deque
+ from collections import OrderedDict
+ try:
+     from collections.abc import Iterable  # Python 3.3+
+ except ImportError:
+     from collections import Iterable  # Python 2 fallback
+
+ try:
+     import requests
+ except ImportError:
+     requests = None
+
+
+ class CallbackList(object):
+     """Container abstracting a list of callbacks.
+
+     # Arguments
+         callbacks: List of `Callback` instances.
+         queue_length: Queue length for keeping
+             running statistics over callback execution time.
+     """
+
+     def __init__(self, callbacks=None, queue_length=10):
+         callbacks = callbacks or []
+         self.callbacks = [c for c in callbacks]
+         self.queue_length = queue_length
+
+     def append(self, callback):
+         self.callbacks.append(callback)
+
+     def set_params(self, params):
+         for callback in self.callbacks:
+             callback.set_params(params)
+
+     def set_model(self, model):
+         for callback in self.callbacks:
+             callback.set_model(model)
+
+     def on_epoch_begin(self, epoch, logs=None):
+         """Called at the start of an epoch.
+
+         # Arguments
+             epoch: integer, index of epoch.
+             logs: dictionary of logs.
+         """
+         logs = logs or {}
+         for callback in self.callbacks:
+             callback.on_epoch_begin(epoch, logs)
+         self._delta_t_batch = 0.
+         self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
+         self._delta_ts_batch_end = deque([], maxlen=self.queue_length)
+
+     def on_epoch_end(self, epoch, logs=None):
+         """Called at the end of an epoch.
+
+         # Arguments
+             epoch: integer, index of epoch.
+             logs: dictionary of logs.
+         """
+         logs = logs or {}
+         for callback in self.callbacks:
+             callback.on_epoch_end(epoch, logs)
+
+     def on_batch_begin(self, batch, logs=None):
+         """Called right before processing a batch.
+
+         # Arguments
+             batch: integer, index of batch within the current epoch.
+             logs: dictionary of logs.
+         """
+         logs = logs or {}
+         t_before_callbacks = time.time()
+         for callback in self.callbacks:
+             callback.on_batch_begin(batch, logs)
+         self._delta_ts_batch_begin.append(time.time() - t_before_callbacks)
+         delta_t_median = np.median(self._delta_ts_batch_begin)
+         if (self._delta_t_batch > 0. and
+                 delta_t_median > 0.95 * self._delta_t_batch and
+                 delta_t_median > 0.1):
+             warnings.warn('Method on_batch_begin() is slow compared '
+                           'to the batch update (%f). Check your callbacks.'
+                           % delta_t_median)
+         self._t_enter_batch = time.time()
+
+     def on_batch_end(self, batch, logs=None):
+         """Called at the end of a batch.
+
+         # Arguments
+             batch: integer, index of batch within the current epoch.
+             logs: dictionary of logs.
+         """
+         logs = logs or {}
+         if not hasattr(self, '_t_enter_batch'):
+             self._t_enter_batch = time.time()
+         self._delta_t_batch = time.time() - self._t_enter_batch
+         t_before_callbacks = time.time()
+         for callback in self.callbacks:
+             callback.on_batch_end(batch, logs)
+         self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
+         delta_t_median = np.median(self._delta_ts_batch_end)
+         if (self._delta_t_batch > 0. and
+                 (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
+             warnings.warn('Method on_batch_end() is slow compared '
+                           'to the batch update (%f). Check your callbacks.'
+                           % delta_t_median)
+
+     def on_train_begin(self, logs=None):
+         """Called at the beginning of training.
+
+         # Arguments
+             logs: dictionary of logs.
+         """
+         logs = logs or {}
+         for callback in self.callbacks:
+             callback.on_train_begin(logs)
+
+     def on_train_end(self, logs=None):
+         """Called at the end of training.
+
+         # Arguments
+             logs: dictionary of logs.
+         """
+         logs = logs or {}
+         for callback in self.callbacks:
+             callback.on_train_end(logs)
+
+     def __iter__(self):
+         return iter(self.callbacks)
+
+
+ class Callback(object):
+     """Abstract base class used to build new callbacks.
+
+     # Properties
+         params: dict. Training parameters
+             (eg. verbosity, batch size, number of epochs...).
+         model: instance of `keras.models.Model`.
+             Reference of the model being trained.
+
+     The `logs` dictionary that callback methods
+     take as argument will contain keys for quantities relevant to
+     the current batch or epoch.
+
+     Currently, the `.fit()` method of the `Sequential` model class
+     will include the following quantities in the `logs` that
+     it passes to its callbacks:
+
+     on_epoch_end: logs include `acc` and `loss`, and
+         optionally include `val_loss`
+         (if validation is enabled in `fit`), and `val_acc`
+         (if validation and accuracy monitoring are enabled).
+     on_batch_begin: logs include `size`,
+         the number of samples in the current batch.
+     on_batch_end: logs include `loss`, and optionally `acc`
+         (if accuracy monitoring is enabled).
+     """
+
+     def __init__(self):
+         self.validation_data = None
+         self.model = None
+
+     def set_params(self, params):
+         self.params = params
+
+     def set_model(self, model):
+         self.model = model
+
+     def on_epoch_begin(self, epoch, logs=None):
+         pass
+
+     def on_epoch_end(self, epoch, logs=None):
+         pass
+
+     def on_batch_begin(self, batch, logs=None):
+         pass
+
+     def on_batch_end(self, batch, logs=None):
+         pass
+
+     def on_train_begin(self, logs=None):
+         pass
+
+     def on_train_end(self, logs=None):
+         pass
+
+
+ class BaseLogger(Callback):
+     """Callback that accumulates epoch averages of metrics.
+
+     This callback is automatically applied to every Keras model.
+     """
+
+     def on_epoch_begin(self, epoch, logs=None):
+         self.seen = 0
+         self.totals = {}
+
+     def on_batch_end(self, batch, logs=None):
+         logs = logs or {}
+         batch_size = logs.get('size', 0)
+         self.seen += batch_size
+
+         for k, v in logs.items():
+             if k in self.totals:
+                 self.totals[k] += v * batch_size
+             else:
+                 self.totals[k] = v * batch_size
+
+     def on_epoch_end(self, epoch, logs=None):
+         if logs is not None:
+             for k in self.params['metrics']:
+                 if k in self.totals:
+                     # Make value available to next callbacks.
+                     logs[k] = self.totals[k] / self.seen
+
+
+ class TerminateOnNaN(Callback):
+     """Callback that terminates training when a NaN loss is encountered."""
+
+     def __init__(self):
+         super(TerminateOnNaN, self).__init__()
+
+     def on_batch_end(self, batch, logs=None):
+         logs = logs or {}
+         loss = logs.get('loss')
+         if loss is not None:
+             if np.isnan(loss) or np.isinf(loss):
+                 print('Batch %d: Invalid loss, terminating training' % (batch))
+                 self.model.stop_training = True
+
+
+ class History(Callback):
+     """Callback that records events into a `History` object.
+
+     This callback is automatically applied to
+     every Keras model. The `History` object
+     gets returned by the `fit` method of models.
+     """
+
+     def on_train_begin(self, logs=None):
+         self.epoch = []
+         self.history = {}
+
+     def on_epoch_end(self, epoch, logs=None):
+         logs = logs or {}
+         self.epoch.append(epoch)
+         for k, v in logs.items():
+             self.history.setdefault(k, []).append(v)
+
+
+ class ModelCheckpoint(Callback):
+     """Save the model after every epoch.
+
+     `filepath` can contain named formatting options,
+     which will be filled with the value of `epoch` and
+     keys in `logs` (passed in `on_epoch_end`).
+
+     For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
+     then the model checkpoints will be saved with the epoch number and
+     the validation loss in the filename.
+
+     # Arguments
+         filepath: string, path to save the model file.
+         monitor: quantity to monitor.
+         verbose: verbosity mode, 0 or 1.
+         save_best_only: if `save_best_only=True`,
+             the latest best model according to
+             the quantity monitored will not be overwritten.
+         mode: one of {auto, min, max}.
+             If `save_best_only=True`, the decision
+             to overwrite the current save file is made
+             based on either the maximization or the
+             minimization of the monitored quantity. For `val_acc`,
+             this should be `max`, for `val_loss` this should
+             be `min`, etc. In `auto` mode, the direction is
+             automatically inferred from the name of the monitored quantity.
+         save_weights_only: if True, then only the model's weights will be
+             saved (`torch.save(self.model.state_dict(), filepath)`), else
+             the full model is saved (`torch.save(self.model, filepath)`).
+         period: Interval (number of epochs) between checkpoints.
+     """
+
+     def __init__(self, filepath, monitor='val_loss', verbose=0,
+                  save_best_only=False, save_weights_only=False,
+                  mode='auto', period=1):
+         super(ModelCheckpoint, self).__init__()
+         self.monitor = monitor
+         self.verbose = verbose
+         self.filepath = filepath
+         self.save_best_only = save_best_only
+         self.save_weights_only = save_weights_only
+         self.period = period
+         self.epochs_since_last_save = 0
+
+         if mode not in ['auto', 'min', 'max']:
+             warnings.warn('ModelCheckpoint mode %s is unknown, '
+                           'fallback to auto mode.' % (mode),
+                           RuntimeWarning)
+             mode = 'auto'
+
+         if mode == 'min':
+             self.monitor_op = np.less
+             self.best = np.Inf
+         elif mode == 'max':
+             self.monitor_op = np.greater
+             self.best = -np.Inf
+         else:
+             if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
+                 self.monitor_op = np.greater
+                 self.best = -np.Inf
+             else:
+                 self.monitor_op = np.less
+                 self.best = np.Inf
+
+     def on_epoch_end(self, epoch, logs=None):
+         import torch
+         logs = logs or {}
+         self.epochs_since_last_save += 1
+         if self.epochs_since_last_save >= self.period:
+             self.epochs_since_last_save = 0
+             filepath = self.filepath.format(epoch=epoch + 1, **logs)
+             if self.save_best_only:
+                 current = logs.get(self.monitor)
+                 if current is None:
+                     warnings.warn('Can save best model only with %s available, '
+                                   'skipping.' % (self.monitor), RuntimeWarning)
+                 else:
+                     if self.monitor_op(current, self.best):
+                         if self.verbose > 0:
+                             print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
+                                   ' saving model to %s'
+                                   % (epoch + 1, self.monitor, self.best,
+                                      current, filepath))
+                         self.best = current
+                         if self.save_weights_only:
+                             torch.save(self.model.state_dict(), filepath)
+                         else:
+                             # Save the full model object, not just the weights.
+                             torch.save(self.model, filepath)
+                     else:
+                         if self.verbose > 0:
+                             print('\nEpoch %05d: %s did not improve' %
+                                   (epoch + 1, self.monitor))
+             else:
+                 if self.verbose > 0:
+                     print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
+                 if self.save_weights_only:
+                     torch.save(self.model.state_dict(), filepath)
+                 else:
+                     torch.save(self.model, filepath)
+
+
+ class EarlyStopping(Callback):
+     """Stop training when a monitored quantity has stopped improving.
+
+     # Arguments
+         monitor: quantity to be monitored.
+         min_delta: minimum change in the monitored quantity
+             to qualify as an improvement, i.e. an absolute
+             change of less than min_delta will count as no
+             improvement.
+         patience: number of epochs with no improvement
+             after which training will be stopped.
+         verbose: verbosity mode.
+         mode: one of {auto, min, max}. In `min` mode,
+             training will stop when the quantity
+             monitored has stopped decreasing; in `max`
+             mode it will stop when the quantity
+             monitored has stopped increasing; in `auto`
+             mode, the direction is automatically inferred
+             from the name of the monitored quantity.
+     """
+
+     def __init__(self, monitor='val_loss',
+                  min_delta=0, patience=0, verbose=0, mode='auto'):
+         super(EarlyStopping, self).__init__()
+
+         self.monitor = monitor
+         self.patience = patience
+         self.verbose = verbose
+         self.min_delta = min_delta
+         self.wait = 0
+         self.stopped_epoch = 0
+
+         if mode not in ['auto', 'min', 'max']:
+             warnings.warn('EarlyStopping mode %s is unknown, '
+                           'fallback to auto mode.' % mode,
+                           RuntimeWarning)
+             mode = 'auto'
+
+         if mode == 'min':
+             self.monitor_op = np.less
+         elif mode == 'max':
+             self.monitor_op = np.greater
+         else:
+             if 'acc' in self.monitor:
+                 self.monitor_op = np.greater
+             else:
+                 self.monitor_op = np.less
+
+         if self.monitor_op == np.greater:
+             self.min_delta *= 1
+         else:
+             self.min_delta *= -1
+
+     def on_train_begin(self, logs=None):
+         # Allow instances to be re-used
+         self.wait = 0
+         self.stopped_epoch = 0
+         self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+     def on_epoch_end(self, epoch, logs=None):
+         current = logs.get(self.monitor)
+         if current is None:
+             warnings.warn(
+                 'Early stopping conditioned on metric `%s` '
+                 'which is not available. Available metrics are: %s' %
+                 (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
+             )
+             return
+         if self.monitor_op(current - self.min_delta, self.best):
+             self.best = current
+             self.wait = 0
+         else:
+             self.wait += 1
+             if self.wait >= self.patience:
+                 self.stopped_epoch = epoch
+                 self.model.stop_training = True
+
+     def on_train_end(self, logs=None):
+         if self.stopped_epoch > 0 and self.verbose > 0:
+             print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
+
+
+ class RemoteMonitor(Callback):
+     """Callback used to stream events to a server.
+
+     Requires the `requests` library.
+     Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
+     HTTP POST, with an `images` argument which is a
+     JSON-encoded dictionary of event data.
+
+     # Arguments
+         root: String; root url of the target server.
+         path: String; path relative to `root` to which the events will be sent.
+         field: String; JSON field under which the images will be stored.
+         headers: Dictionary; optional custom HTTP headers.
+     """
+
+     def __init__(self,
+                  root='http://localhost:9000',
+                  path='/publish/epoch/end/',
+                  field='images',
+                  headers=None):
+         super(RemoteMonitor, self).__init__()
+
+         self.root = root
+         self.path = path
+         self.field = field
+         self.headers = headers
+
+     def on_epoch_end(self, epoch, logs=None):
+         if requests is None:
+             raise ImportError('RemoteMonitor requires '
+                               'the `requests` library.')
+         logs = logs or {}
+         send = {}
+         send['epoch'] = epoch
+         for k, v in logs.items():
+             if isinstance(v, (np.ndarray, np.generic)):
+                 send[k] = v.item()
+             else:
+                 send[k] = v
+         try:
+             requests.post(self.root + self.path,
+                           {self.field: json.dumps(send)},
+                           headers=self.headers)
+         except requests.exceptions.RequestException:
+             warnings.warn('Warning: could not reach RemoteMonitor '
+                           'root server at ' + str(self.root))
+
+
+ class TensorBoard(Callback):
+     """TensorBoard basic visualizations.
+
+     [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
+     is a visualization tool provided with TensorFlow.
+
+     This callback writes a log for TensorBoard, which allows
+     you to visualize dynamic graphs of your training and test
+     metrics, as well as activation histograms for the different
+     layers in your model.
+
+     If you have installed TensorFlow with pip, you should be able
+     to launch TensorBoard from the command line:
+     ```sh
+     tensorboard --logdir=/full_path_to_your_logs
+     ```
+
+     When using a backend other than TensorFlow, TensorBoard will still work
+     (if you have TensorFlow installed), but the only feature available will
+     be the display of the losses and metrics plots.
+
+     # Arguments
+         log_dir: the path of the directory where to save the log
+             files to be parsed by TensorBoard.
+         histogram_freq: frequency (in epochs) at which to compute activation
+             and weight histograms for the layers of the model. If set to 0,
+             histograms won't be computed. Validation images (or split) must be
+             specified for histogram visualizations.
+         write_graph: whether to visualize the graph in TensorBoard.
+             The log file can become quite large when
+             write_graph is set to True.
+         write_grads: whether to visualize gradient histograms in TensorBoard.
+             `histogram_freq` must be greater than 0.
+         batch_size: size of batch of inputs to feed to the network
+             for histograms computation.
+         write_images: whether to write model weights to visualize as
+             image in TensorBoard.
+         embeddings_freq: frequency (in epochs) at which selected embedding
+             layers will be saved.
+         embeddings_layer_names: a list of names of layers to keep an eye on.
+             If None or an empty list, all the embedding layers will be watched.
+         embeddings_metadata: a dictionary which maps layer name to a file name
+             in which metadata for this embedding layer is saved. See the
+             [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
+             about metadata file format. If the same metadata file is
+             used for all embedding layers, a single string can be passed.
+     """
+
+     def __init__(self, log_dir='./logs',
+                  histogram_freq=0,
+                  batch_size=32,
+                  write_graph=True,
+                  write_grads=False,
+                  write_images=False,
+                  embeddings_freq=0,
+                  embeddings_layer_names=None,
+                  embeddings_metadata=None):
+         super(TensorBoard, self).__init__()
+         global tf, projector, K
+         try:
+             import tensorflow as tf
+             from tensorflow.contrib.tensorboard.plugins import projector
+             # `K` (the Keras backend) is referenced throughout this class.
+             from keras import backend as K
+         except ImportError:
+             raise ImportError('You need the TensorFlow and Keras modules '
+                               'installed to use TensorBoard.')
+
+         if K.backend() != 'tensorflow':
+             if histogram_freq != 0:
+                 warnings.warn('You are not using the TensorFlow backend. '
+                               'histogram_freq was set to 0')
+                 histogram_freq = 0
+             if write_graph:
+                 warnings.warn('You are not using the TensorFlow backend. '
+                               'write_graph was set to False')
+                 write_graph = False
+             if write_images:
+                 warnings.warn('You are not using the TensorFlow backend. '
+                               'write_images was set to False')
+                 write_images = False
+             if embeddings_freq != 0:
+                 warnings.warn('You are not using the TensorFlow backend. '
+                               'embeddings_freq was set to 0')
+                 embeddings_freq = 0
+
+         self.log_dir = log_dir
+         self.histogram_freq = histogram_freq
+         self.merged = None
+         self.write_graph = write_graph
+         self.write_grads = write_grads
+         self.write_images = write_images
+         self.embeddings_freq = embeddings_freq
+         self.embeddings_layer_names = embeddings_layer_names
+         self.embeddings_metadata = embeddings_metadata or {}
+         self.batch_size = batch_size
+
+     def set_model(self, model):
+         self.model = model
+         if K.backend() == 'tensorflow':
+             self.sess = K.get_session()
+         if self.histogram_freq and self.merged is None:
+             for layer in self.model.layers:
+
+                 for weight in layer.weights:
+                     mapped_weight_name = weight.name.replace(':', '_')
+                     tf.summary.histogram(mapped_weight_name, weight)
+                     if self.write_grads:
+                         grads = model.optimizer.get_gradients(model.total_loss,
+                                                               weight)
+
+                         def is_indexed_slices(grad):
+                             return type(grad).__name__ == 'IndexedSlices'
+                         grads = [
+                             grad.values if is_indexed_slices(grad) else grad
+                             for grad in grads]
+                         tf.summary.histogram('{}_grad'.format(mapped_weight_name), grads)
+                     if self.write_images:
+                         w_img = tf.squeeze(weight)
+                         shape = K.int_shape(w_img)
+                         if len(shape) == 2:  # dense layer kernel case
+                             if shape[0] > shape[1]:
+                                 w_img = tf.transpose(w_img)
+                                 shape = K.int_shape(w_img)
+                             w_img = tf.reshape(w_img, [1,
+                                                        shape[0],
+                                                        shape[1],
+                                                        1])
+                         elif len(shape) == 3:  # convnet case
+                             if K.image_data_format() == 'channels_last':
+                                 # switch to channels_first to display
+                                 # every kernel as a separate image
+                                 w_img = tf.transpose(w_img, perm=[2, 0, 1])
+                                 shape = K.int_shape(w_img)
+                             w_img = tf.reshape(w_img, [shape[0],
+                                                        shape[1],
+                                                        shape[2],
+                                                        1])
+                         elif len(shape) == 1:  # bias case
+                             w_img = tf.reshape(w_img, [1,
+                                                        shape[0],
+                                                        1,
+                                                        1])
+                         else:
+                             # not possible to handle 3D convnets etc.
+                             continue
+
+                         shape = K.int_shape(w_img)
+                         assert len(shape) == 4 and shape[-1] in [1, 3, 4]
+                         tf.summary.image(mapped_weight_name, w_img)
+
+                 if hasattr(layer, 'output'):
+                     tf.summary.histogram('{}_out'.format(layer.name),
+                                          layer.output)
+         self.merged = tf.summary.merge_all()
+
+         if self.write_graph:
+             self.writer = tf.summary.FileWriter(self.log_dir,
+                                                 self.sess.graph)
+         else:
+             self.writer = tf.summary.FileWriter(self.log_dir)
+
+         if self.embeddings_freq:
+             embeddings_layer_names = self.embeddings_layer_names
+
+             if not embeddings_layer_names:
+                 embeddings_layer_names = [layer.name for layer in self.model.layers
+                                           if type(layer).__name__ == 'Embedding']
+
+             embeddings = {layer.name: layer.weights[0]
+                           for layer in self.model.layers
+                           if layer.name in embeddings_layer_names}
+
+             self.saver = tf.train.Saver(list(embeddings.values()))
+
+             embeddings_metadata = {}
+
+             if not isinstance(self.embeddings_metadata, str):
+                 embeddings_metadata = self.embeddings_metadata
+             else:
+                 embeddings_metadata = {layer_name: self.embeddings_metadata
+                                        for layer_name in embeddings.keys()}
+
+             config = projector.ProjectorConfig()
+             self.embeddings_ckpt_path = os.path.join(self.log_dir,
+                                                      'keras_embedding.ckpt')
+
+             for layer_name, tensor in embeddings.items():
+                 embedding = config.embeddings.add()
+                 embedding.tensor_name = tensor.name
+
+                 if layer_name in embeddings_metadata:
+                     embedding.metadata_path = embeddings_metadata[layer_name]
+
+             projector.visualize_embeddings(self.writer, config)
+
+     def on_epoch_end(self, epoch, logs=None):
+         logs = logs or {}
+
+         if not self.validation_data and self.histogram_freq:
+             raise ValueError('If printing histograms, validation_data must be '
+                              'provided, and cannot be a generator.')
+         if self.validation_data and self.histogram_freq:
+             if epoch % self.histogram_freq == 0:
+
+                 val_data = self.validation_data
+                 tensors = (self.model.inputs +
+                            self.model.targets +
+                            self.model.sample_weights)
+
+                 if self.model.uses_learning_phase:
+                     tensors += [K.learning_phase()]
+
+                 assert len(val_data) == len(tensors)
+                 val_size = val_data[0].shape[0]
+                 i = 0
+                 while i < val_size:
+                     step = min(self.batch_size, val_size - i)
+                     if self.model.uses_learning_phase:
+                         # do not slice the learning phase
+                         batch_val = [x[i:i + step] for x in val_data[:-1]]
+                         batch_val.append(val_data[-1])
+                     else:
+                         batch_val = [x[i:i + step] for x in val_data]
+                     assert len(batch_val) == len(tensors)
+                     feed_dict = dict(zip(tensors, batch_val))
+                     result = self.sess.run([self.merged], feed_dict=feed_dict)
+                     summary_str = result[0]
+                     self.writer.add_summary(summary_str, epoch)
+                     i += self.batch_size
+
+         if self.embeddings_freq and self.embeddings_ckpt_path:
+             if epoch % self.embeddings_freq == 0:
+                 self.saver.save(self.sess,
+                                 self.embeddings_ckpt_path,
+                                 epoch)
+
+         for name, value in logs.items():
+             if name in ['batch', 'size']:
+                 continue
+             summary = tf.Summary()
+             summary_value = summary.value.add()
+             summary_value.simple_value = value.item()
+             summary_value.tag = name
+             self.writer.add_summary(summary, epoch)
+         self.writer.flush()
+
+     def on_train_end(self, _):
+         self.writer.close()
+
+
+ class CSVLogger(Callback):
+     """Callback that streams epoch results to a csv file.
+
+     Supports all values that can be represented as a string,
+     including 1D iterables such as np.ndarray.
+
+     # Example
+
+     ```python
+     csv_logger = CSVLogger('training.log')
+     model.fit(X_train, Y_train, callbacks=[csv_logger])
+     ```
+
+     # Arguments
+         filename: filename of the csv file, e.g. 'run/log.csv'.
+         separator: string used to separate elements in the csv file.
+         append: True: append if file exists (useful for continuing
+             training). False: overwrite existing file.
+         output_on_train_end: an additional stream to write the full log to
+             when training ends, e.g.
+             CSVLogger(filename='./mylog.csv', output_on_train_end=os.sys.stdout)
+     """
+
+     def __init__(self, filename, separator=',', append=False, output_on_train_end=None):
+         self.sep = separator
+         self.filename = filename
+         self.append = append
+         self.writer = None
+         self.keys = None
+         self.append_header = True
+         self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''
+         self.output_on_train_end = output_on_train_end
+         super(CSVLogger, self).__init__()
+
+     def on_train_begin(self, logs=None):
+         if self.append:
+             if os.path.exists(self.filename):
+                 with open(self.filename, 'r' + self.file_flags) as f:
+                     self.append_header = not bool(len(f.readline()))
+             self.csv_file = open(self.filename, 'a' + self.file_flags)
+         else:
+             self.csv_file = open(self.filename, 'w' + self.file_flags)
+
+     def on_epoch_end(self, epoch, logs=None):
+         logs = logs or {}
+
+         def handle_value(k):
+             is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
+             if isinstance(k, six.string_types):
+                 return k
+             elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
+                 return '"[%s]"' % (', '.join(map(str, k)))
+             else:
+                 return k
+
+         if self.keys is None:
+             self.keys = sorted(logs.keys())
+
+         if self.model is not None and getattr(self.model, 'stop_training', False):
+             # We set NA so that csv parsers do not fail for this last epoch.
+             logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])
+
+         if not self.writer:
+             class CustomDialect(csv.excel):
+                 delimiter = self.sep
+
+             self.writer = csv.DictWriter(self.csv_file,
+                                          fieldnames=['epoch'] + self.keys, dialect=CustomDialect)
+             if self.append_header:
+                 self.writer.writeheader()
+
+         row_dict = OrderedDict({'epoch': epoch})
+         row_dict.update((key, handle_value(logs[key])) for key in self.keys)
+         self.writer.writerow(row_dict)
+         self.csv_file.flush()
+
+     def on_train_end(self, logs=None):
+         self.csv_file.close()
+         # Echo the full log to the additional stream, if one was given.
+         if self.output_on_train_end is not None and os.path.exists(self.filename):
+             with open(self.filename, 'r' + self.file_flags) as f:
+                 print(f.read(), file=self.output_on_train_end)
+         self.writer = None
+
+
+ class LambdaCallback(Callback):
+     r"""Callback for creating simple, custom callbacks on-the-fly.
+
+     This callback is constructed with anonymous functions that will be called
+     at the appropriate time. Note that the callbacks expect positional
+     arguments, as:
+
+     - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
+       `epoch`, `logs`
+     - `on_batch_begin` and `on_batch_end` expect two positional arguments:
+       `batch`, `logs`
+     - `on_train_begin` and `on_train_end` expect one positional argument:
+       `logs`
+
+     # Arguments
+         on_epoch_begin: called at the beginning of every epoch.
+         on_epoch_end: called at the end of every epoch.
+         on_batch_begin: called at the beginning of every batch.
+         on_batch_end: called at the end of every batch.
+         on_train_begin: called at the beginning of model training.
+         on_train_end: called at the end of model training.
+
+     # Example
+
+     ```python
+     # Print the batch number at the beginning of every batch.
+     batch_print_callback = LambdaCallback(
+         on_batch_begin=lambda batch,logs: print(batch))
+
+     # Stream the epoch loss to a file in JSON format. The file content
+     # is not well-formed JSON but rather has a JSON object per line.
+     import json
+     json_log = open('loss_log.json', mode='wt', buffering=1)
+     json_logging_callback = LambdaCallback(
+         on_epoch_end=lambda epoch, logs: json_log.write(
+             json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
+         on_train_end=lambda logs: json_log.close()
+     )
+
+     # Terminate some processes after having finished model training.
+     processes = ...
+     cleanup_callback = LambdaCallback(
+         on_train_end=lambda logs: [
+             p.terminate() for p in processes if p.is_alive()])
+
+     model.fit(...,
+               callbacks=[batch_print_callback,
+                          json_logging_callback,
+                          cleanup_callback])
+     ```
+     """
+
+     def __init__(self,
+                  on_epoch_begin=None,
+                  on_epoch_end=None,
+                  on_batch_begin=None,
+                  on_batch_end=None,
+                  on_train_begin=None,
+                  on_train_end=None,
+                  **kwargs):
+         super(LambdaCallback, self).__init__()
+         self.__dict__.update(kwargs)
+         if on_epoch_begin is not None:
+             self.on_epoch_begin = on_epoch_begin
+         else:
+             self.on_epoch_begin = lambda epoch, logs: None
+         if on_epoch_end is not None:
+             self.on_epoch_end = on_epoch_end
+         else:
+             self.on_epoch_end = lambda epoch, logs: None
+         if on_batch_begin is not None:
+             self.on_batch_begin = on_batch_begin
+         else:
+             self.on_batch_begin = lambda batch, logs: None
+         if on_batch_end is not None:
+             self.on_batch_end = on_batch_end
+         else:
+             self.on_batch_end = lambda batch, logs: None
+         if on_train_begin is not None:
+             self.on_train_begin = on_train_begin
+         else:
+             self.on_train_begin = lambda logs: None
+         if on_train_end is not None:
+             self.on_train_end = on_train_end
+         else:
+             self.on_train_end = lambda logs: None
+
+
+ class TQDMCallback(Callback):
+     def __init__(self, outer_description="Training",
+                  inner_description_initial="Epoch: {epoch}",
+                  inner_description_update="Epoch: {epoch} - {metrics}",
+                  metric_format="{name}: {value:0.3f}",
+                  separator=", ",
+                  leave_inner=True,
+                  leave_outer=True,
+                  show_inner=True,
+                  show_outer=True,
+                  output_file=stderr,
+                  initial=0):
+         """
+         Construct a callback that will create and update progress bars.
+
+         :param outer_description: string for outer progress bar
+         :param inner_description_initial: initial format for epoch ("Epoch: {epoch}")
+         :param inner_description_update: format after metrics collected ("Epoch: {epoch} - {metrics}")
+         :param metric_format: format for each metric name/value pair ("{name}: {value:0.3f}")
+         :param separator: separator between metrics (", ")
+         :param leave_inner: True to leave inner bars
+         :param leave_outer: True to leave outer bars
+         :param show_inner: False to hide inner bars
+         :param show_outer: False to hide outer bar
+         :param output_file: output file (default sys.stderr)
+         :param initial: Initial counter state
+         """
+         super(TQDMCallback, self).__init__()
+         self.outer_description = outer_description
+         self.inner_description_initial = inner_description_initial
+         self.inner_description_update = inner_description_update
+         self.metric_format = metric_format
+         self.separator = separator
+         self.leave_inner = leave_inner
+         self.leave_outer = leave_outer
+         self.show_inner = show_inner
+         self.show_outer = show_outer
+         self.output_file = output_file
+         self.tqdm_outer = None
+         self.tqdm_inner = None
+         self.epoch = None
+         self.running_logs = None
+         self.inner_count = None
+         self.initial = initial
+
+     def tqdm(self, desc, total, leave, initial=0):
+         """
+         Extension point. Override to provide custom options to tqdm initializer.
+         :param desc: Description string
+         :param total: Total number of updates
+         :param leave: Leave progress bar when done
+         :param initial: Initial counter state
+         :return: new progress bar
+         """
+         return tqdm(desc=desc, total=total, leave=leave, file=self.output_file, initial=initial)
+
+     def build_tqdm_outer(self, desc, total):
+         """
+         Extension point. Override to provide custom options to outer progress bars (Epoch loop)
+         :param desc: Description
+         :param total: Number of epochs
+         :return: new progress bar
+         """
+         return self.tqdm(desc=desc, total=total, leave=self.leave_outer, initial=self.initial)
+
+     def build_tqdm_inner(self, desc, total):
+         """
+         Extension point. Override to provide custom options to inner progress bars (Batch loop)
+         :param desc: Description
+         :param total: Number of batches
+         :return: new progress bar
+         """
+         return self.tqdm(desc=desc, total=total, leave=self.leave_inner)
+
+     def on_epoch_begin(self, epoch, logs={}):
+         self.epoch = epoch
+         desc = self.inner_description_initial.format(epoch=self.epoch)
+         self.mode = 0  # samples
+         if 'samples' in self.params:
+             self.inner_total = self.params['samples']
+         elif 'nb_sample' in self.params:
+             self.inner_total = self.params['nb_sample']
+         else:
+             self.mode = 1  # steps
+             self.inner_total = self.params['steps']
+         if self.show_inner:
+             self.tqdm_inner = self.build_tqdm_inner(desc=desc, total=self.inner_total)
+         self.inner_count = 0
+         self.running_logs = {}
+
+     def on_epoch_end(self, epoch, logs={}):
+         metrics = self.format_metrics(logs)
+         desc = self.inner_description_update.format(epoch=epoch, metrics=metrics)
+         if self.show_inner:
+             self.tqdm_inner.desc = desc
+             # set miniters and mininterval to 0 so last update displays
+             self.tqdm_inner.miniters = 0
+             self.tqdm_inner.mininterval = 0
+             self.tqdm_inner.update(self.inner_total - self.tqdm_inner.n)
+             self.tqdm_inner.close()
+         if self.show_outer:
+             self.tqdm_outer.update(1)
+
+     def on_batch_begin(self, batch, logs={}):
+         pass
+
+     def on_batch_end(self, batch, logs={}):
+         if self.mode == 0:
+             update = logs['size']
+         else:
+             update = 1
+         self.inner_count += update
+         if self.inner_count < self.inner_total:
+             self.append_logs(logs)
+             metrics = self.format_metrics(self.running_logs)
+             desc = self.inner_description_update.format(epoch=self.epoch, metrics=metrics)
+             if self.show_inner:
+                 self.tqdm_inner.desc = desc
+                 self.tqdm_inner.update(update)
+
+     def on_train_begin(self, logs={}):
+         if self.show_outer:
+             epochs = (self.params['epochs'] if 'epochs' in self.params
+                       else self.params['nb_epoch'])
+             self.tqdm_outer = self.build_tqdm_outer(desc=self.outer_description,
+                                                     total=epochs)
+
+     def on_train_end(self, logs={}):
+         if self.show_outer:
+             self.tqdm_outer.close()
+
+     def append_logs(self, logs):
+         metrics = self.params['metrics']
+         for metric, value in six.iteritems(logs):
+             if metric in metrics:
+                 if metric in self.running_logs:
+                     self.running_logs[metric].append(value[()])
+                 else:
+                     self.running_logs[metric] = [value[()]]
+
+     def format_metrics(self, logs):
+         metrics = self.params['metrics']
+         strings = [self.metric_format.format(name=metric, value=np.mean(logs[metric], axis=None))
+                    for metric in metrics if metric in logs]
+         return self.separator.join(strings)
dataset.py CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import Dataset
  from PIL import Image
  import os
  import json
- from build_vocab import Vocabulary, JsonReader
+ from utils.build_vocab import Vocabulary, JsonReader
  import numpy as np
  from torchvision import transforms
  import pickle
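One caveat worth flagging with this rename (a general pickle property, not something the diff addresses): if a vocabulary was pickled while the class was importable as `build_vocab.Vocabulary`, the pickle records that module path, so loading it after the move to `utils.build_vocab` can raise `ModuleNotFoundError`. A workaround sketch:

```python
import pickle

class RenamedUnpickler(pickle.Unpickler):
    """Redirect the old 'build_vocab' module path to 'utils.build_vocab'."""
    def find_class(self, module, name):
        if module == 'build_vocab':
            module = 'utils.build_vocab'
        return super(RenamedUnpickler, self).find_class(module, name)

with open('../data/new_data/debug_vocab.pkl', 'rb') as f:
    vocab = RenamedUnpickler(f).load()
```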
logger.py ADDED
@@ -0,0 +1,71 @@
+ # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
+ import tensorflow as tf
+ import numpy as np
+ import scipy.misc  # NOTE: scipy.misc.toimage was removed in SciPy 1.2; this needs an older SciPy.
+ try:
+     from StringIO import StringIO  # Python 2.7
+ except ImportError:
+     from io import BytesIO  # Python 3.x
+
+
+ class Logger(object):
+
+     def __init__(self, log_dir):
+         """Create a summary writer logging to log_dir."""
+         self.writer = tf.summary.FileWriter(log_dir)
+
+     def scalar_summary(self, tag, value, step):
+         """Log a scalar variable."""
+         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
+         self.writer.add_summary(summary, step)
+
+     def image_summary(self, tag, images, step):
+         """Log a list of images."""
+
+         img_summaries = []
+         for i, img in enumerate(images):
+             # Write the image to a string; StringIO only exists on Python 2,
+             # so fall back to BytesIO when the name is undefined.
+             try:
+                 s = StringIO()
+             except NameError:
+                 s = BytesIO()
+             scipy.misc.toimage(img).save(s, format="png")
+
+             # Create an Image object
+             img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
+                                        height=img.shape[0],
+                                        width=img.shape[1])
+             # Create a Summary value
+             img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
+
+         # Create and write Summary
+         summary = tf.Summary(value=img_summaries)
+         self.writer.add_summary(summary, step)
+
+     def histo_summary(self, tag, values, step, bins=1000):
+         """Log a histogram of the tensor of values."""
+
+         # Create a histogram using numpy
+         counts, bin_edges = np.histogram(values, bins=bins)
+
+         # Fill the fields of the histogram proto
+         hist = tf.HistogramProto()
+         hist.min = float(np.min(values))
+         hist.max = float(np.max(values))
+         hist.num = int(np.prod(values.shape))
+         hist.sum = float(np.sum(values))
+         hist.sum_squares = float(np.sum(values**2))
+
+         # Drop the start of the first bin
+         bin_edges = bin_edges[1:]
+
+         # Add bin edges and counts
+         for edge in bin_edges:
+             hist.bucket_limit.append(edge)
+         for c in counts:
+             hist.bucket.append(c)
+
+         # Create and write Summary
+         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
+         self.writer.add_summary(summary, step)
+         self.writer.flush()
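A minimal usage sketch for `Logger` (the log directory and tag names are illustrative; requires the TF 1.x summary API used above):

```python
import numpy as np
from logger import Logger

logger = Logger('./logs')
for step in range(100):
    loss = 1.0 / (step + 1)                              # stand-in training loss
    logger.scalar_summary('train/loss', loss, step)
    logger.histo_summary('weights/fc', np.random.randn(1000), step)
# Then inspect with: tensorboard --logdir=./logs
```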
models_debugger.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+ import numpy as np
5
+ from torch.autograd import Variable
6
+ from torchvision.models.vgg import model_urls as vgg_model_urls
7
+ import torchvision.models as models
8
+
9
+ from utils.tcn import *
10
+
11
+
12
+ class DenseNet121(nn.Module):
13
+ def __init__(self, classes=14, pretrained=True):
14
+ super(DenseNet121, self).__init__()
15
+ self.model = torchvision.models.densenet121(pretrained=pretrained)
16
+ num_in_features = self.model.classifier.in_features
17
+ self.model.classifier = nn.Sequential(
18
+ nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
19
+ # nn.Sigmoid()
20
+ )
21
+
22
+ def forward(self, x) -> object:
23
+ """
24
+
25
+ :rtype: object
26
+ """
27
+ x = self.densenet121(x)
28
+ return x
29
+
30
+
31
+ class DenseNet161(nn.Module):
32
+ def __init__(self, classes=156, pretrained=True):
33
+ super(DenseNet161, self).__init__()
34
+ self.model = torchvision.models.densenet161(pretrained=pretrained)
35
+ num_in_features = self.model.classifier.in_features
36
+ self.model.classifier = nn.Sequential(
37
+ self.__init_linear(in_features=num_in_features, out_features=classes),
38
+ # nn.Sigmoid()
39
+ )
40
+
41
+ def __init_linear(self, in_features, out_features):
42
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
43
+ func.weight.data.normal_(0, 0.1)
44
+ return func
45
+
46
+ def forward(self, x) -> object:
47
+ """
48
+
49
+ :rtype: object
50
+ """
51
+ x = self.model(x)
52
+ return x
53
+
54
+
55
+ class DenseNet169(nn.Module):
56
+ def __init__(self, classes=156, pretrained=True):
57
+ super(DenseNet169, self).__init__()
58
+ self.model = torchvision.models.densenet169(pretrained=pretrained)
59
+ num_in_features = self.model.classifier.in_features
60
+ self.model.classifier = nn.Sequential(
61
+ self.__init_linear(in_features=num_in_features, out_features=classes),
62
+ # nn.Sigmoid()
63
+ )
64
+
65
+ def __init_linear(self, in_features, out_features):
66
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
67
+ func.weight.data.normal_(0, 0.1)
68
+ return func
69
+
70
+ def forward(self, x) -> object:
71
+ """
72
+
73
+ :rtype: object
74
+ """
75
+ x = self.model(x)
76
+ return x
77
+
78
+
79
+ class DenseNet201(nn.Module):
80
+ def __init__(self, classes=156, pretrained=True):
81
+ super(DenseNet201, self).__init__()
82
+ self.model = torchvision.models.densenet201(pretrained=pretrained)
83
+ num_in_features = self.model.classifier.in_features
84
+ self.model.classifier = nn.Sequential(
85
+ self.__init_linear(in_features=num_in_features, out_features=classes),
86
+ nn.Sigmoid()
87
+ )
88
+
89
+ def __init_linear(self, in_features, out_features):
90
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
91
+ func.weight.data.normal_(0, 0.1)
92
+ return func
93
+
94
+ def forward(self, x) -> object:
95
+ """
96
+
97
+ :rtype: object
98
+ """
99
+ x = self.model(x)
100
+ return x
101
+
102
+
103
+ class ResNet18(nn.Module):
104
+ def __init__(self, classes=156, pretrained=True):
105
+ super(ResNet18, self).__init__()
106
+ self.model = torchvision.models.resnet18(pretrained=pretrained)
107
+ num_in_features = self.model.fc.in_features
108
+ self.model.fc = nn.Sequential(
109
+ self.__init_linear(in_features=num_in_features, out_features=classes),
110
+ # nn.Sigmoid()
111
+ )
112
+
113
+ def __init_linear(self, in_features, out_features):
114
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
115
+ func.weight.data.normal_(0, 0.1)
116
+ return func
117
+
118
+ def forward(self, x) -> object:
119
+ """
120
+
121
+ :rtype: object
122
+ """
123
+ x = self.model(x)
124
+ return x
125
+
126
+
127
+ class ResNet34(nn.Module):
128
+ def __init__(self, classes=156, pretrained=True):
129
+ super(ResNet34, self).__init__()
130
+ self.model = torchvision.models.resnet34(pretrained=pretrained)
131
+ num_in_features = self.model.fc.in_features
132
+ self.model.fc = nn.Sequential(
133
+ self.__init_linear(in_features=num_in_features, out_features=classes),
134
+ # nn.Sigmoid()
135
+ )
136
+
137
+ def __init_linear(self, in_features, out_features):
138
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
139
+ func.weight.data.normal_(0, 0.1)
140
+ return func
141
+
142
+ def forward(self, x) -> object:
143
+ """
144
+
145
+ :rtype: object
146
+ """
147
+ x = self.model(x)
148
+ return x
149
+
150
+
151
+ class ResNet50(nn.Module):
152
+ def __init__(self, classes=156, pretrained=True):
153
+ super(ResNet50, self).__init__()
154
+ self.model = torchvision.models.resnet50(pretrained=pretrained)
155
+ num_in_features = self.model.fc.in_features
156
+ self.model.fc = nn.Sequential(
157
+ self.__init_linear(in_features=num_in_features, out_features=classes),
158
+ # nn.Sigmoid()
159
+ )
160
+
161
+ def __init_linear(self, in_features, out_features):
162
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
163
+ func.weight.data.normal_(0, 0.1)
164
+ return func
165
+
166
+ def forward(self, x) -> object:
167
+ """
168
+
169
+ :rtype: object
170
+ """
171
+ x = self.model(x)
172
+ return x
173
+
174
+
175
+ class ResNet101(nn.Module):
176
+ def __init__(self, classes=156, pretrained=True):
177
+ super(ResNet101, self).__init__()
178
+ self.model = torchvision.models.resnet101(pretrained=pretrained)
179
+ num_in_features = self.model.fc.in_features
180
+ self.model.fc = nn.Sequential(
181
+ self.__init_linear(in_features=num_in_features, out_features=classes),
182
+ # nn.Sigmoid()
183
+ )
184
+
185
+ def __init_linear(self, in_features, out_features):
186
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
187
+ func.weight.data.normal_(0, 0.1)
188
+ return func
189
+
190
+ def forward(self, x) -> object:
191
+ """
192
+
193
+ :rtype: object
194
+ """
195
+ x = self.model(x)
196
+ return x
197
+
198
+
199
+ class ResNet152(nn.Module):
200
+ def __init__(self, classes=156, pretrained=True):
201
+ super(ResNet152, self).__init__()
202
+ self.model = torchvision.models.resnet152(pretrained=pretrained)
203
+ num_in_features = self.model.fc.in_features
204
+ self.model.fc = nn.Sequential(
205
+ self.__init_linear(in_features=num_in_features, out_features=classes),
206
+ # nn.Sigmoid()
207
+ )
208
+
209
+ def __init_linear(self, in_features, out_features):
210
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
211
+ func.weight.data.normal_(0, 0.1)
212
+ return func
213
+
214
+ def forward(self, x) -> object:
215
+ """
216
+
217
+ :rtype: object
218
+ """
219
+ x = self.model(x)
220
+ return x
221
+
222
+
223
+ class VGG19(nn.Module):
224
+ def __init__(self, classes=14, pretrained=True):
225
+ super(VGG19, self).__init__()
226
+ self.model = torchvision.models.vgg19_bn(pretrained=pretrained)
227
+ self.model.classifier = nn.Sequential(
228
+ self.__init_linear(in_features=25088, out_features=4096),
229
+ nn.ReLU(),
230
+ nn.Dropout(0.5),
231
+ self.__init_linear(in_features=4096, out_features=4096),
232
+ nn.ReLU(),
233
+ nn.Dropout(0.5),
234
+ self.__init_linear(in_features=4096, out_features=classes),
235
+ # nn.Sigmoid()
236
+ )
237
+
238
+ def __init_linear(self, in_features, out_features):
239
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
240
+ func.weight.data.normal_(0, 0.1)
241
+ return func
242
+
243
+ def forward(self, x) -> object:
244
+ """
245
+
246
+ :rtype: object
247
+ """
248
+ x = self.model(x)
249
+ return x
250
+
251
+
252
+ class VGG(nn.Module):
253
+ def __init__(self, tags_num):
254
+ super(VGG, self).__init__()
255
+ vgg_model_urls['vgg19'] = vgg_model_urls['vgg19'].replace('https://', 'http://')
256
+ self.vgg19 = models.vgg19(pretrained=True)
257
+ vgg19_classifier = list(self.vgg19.classifier.children())[:-1]
258
+ self.classifier = nn.Sequential(*vgg19_classifier)
259
+ self.fc = nn.Linear(4096, tags_num)
260
+ self.fc.apply(self.init_weights)
261
+ self.bn = nn.BatchNorm1d(tags_num, momentum=0.1)
262
+ # self.init_weights()
263
+
264
+ def init_weights(self, m):
265
+ if type(m) == nn.Linear:
266
+ self.fc.weight.data.normal_(0, 0.1)
267
+ self.fc.bias.data.fill_(0)
268
+
269
+ def forward(self, images) -> object:
270
+ """
271
+
272
+ :rtype: object
273
+ """
274
+ visual_feats = self.vgg19.features(images)
275
+ tags_classifier = visual_feats.view(visual_feats.size(0), -1)
276
+ tags_classifier = self.bn(self.fc(self.classifier(tags_classifier)))
277
+ return tags_classifier
278
+
279
+
280
+ class InceptionV3(nn.Module):
281
+ def __init__(self, classes=156, pretrained=True):
282
+ super(InceptionV3, self).__init__()
283
+ self.model = torchvision.models.inception_v3(pretrained=pretrained)
284
+ num_in_features = self.model.classifier.in_features
285
+ self.model.classifier = nn.Sequential(
286
+ self.__init_linear(in_features=num_in_features, out_features=classes),
287
+ # nn.Sigmoid()
288
+ )
289
+
290
+ def __init_linear(self, in_features, out_features):
291
+ func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
292
+ func.weight.data.normal_(0, 0.1)
293
+ return func
294
+
295
+ def forward(self, x) -> object:
296
+ """
297
+
298
+ :rtype: object
299
+ """
300
+ x = self.model(x)
301
+ return x
302
+
303
+
304
+ class CheXNetDenseNet121(nn.Module):
305
+ def __init__(self, classes=14, pretrained=True):
306
+ super(CheXNetDenseNet121, self).__init__()
307
+ self.densenet121 = torchvision.models.densenet121(pretrained=pretrained)
308
+ num_in_features = self.densenet121.classifier.in_features
309
+ self.densenet121.classifier = nn.Sequential(
310
+ nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
311
+ nn.Sigmoid()
312
+ )
313
+
314
+ def forward(self, x) -> object:
315
+ """
316
+
317
+ :rtype: object
318
+ """
319
+ x = self.densenet121(x)
320
+ return x
321
+
322
+
323
+ class CheXNet(nn.Module):
+     def __init__(self, classes=156):
+         super(CheXNet, self).__init__()
+         # Load the 14-class CheXNet checkpoint, then swap in a new head
+         # sized for this task's label set.
+         self.densenet121 = CheXNetDenseNet121(classes=14)
+         self.densenet121 = torch.nn.DataParallel(self.densenet121).cuda()
+         self.densenet121.load_state_dict(torch.load('./models/CheXNet.pth.tar')['state_dict'])
+         self.densenet121.module.densenet121.classifier = nn.Sequential(
+             self.__init_linear(1024, classes),
+             nn.Sigmoid()
+         )
+
+     def __init_linear(self, in_features, out_features):
+         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
+         func.weight.data.normal_(0, 0.1)
+         return func
+
+     def forward(self, x) -> object:
+         """Run the fine-tuned CheXNet backbone."""
+         x = self.densenet121(x)
+         return x
+
+
+ class ModelFactory(object):
+     def __init__(self, model_name, pretrained, classes):
+         self.model_name = model_name
+         self.pretrained = pretrained
+         self.classes = classes
+
+     def create_model(self):
+         if self.model_name == 'VGG19':
+             _model = VGG19(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'DenseNet121':
+             _model = DenseNet121(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'DenseNet161':
+             _model = DenseNet161(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'DenseNet169':
+             _model = DenseNet169(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'DenseNet201':
+             _model = DenseNet201(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'CheXNet':
+             _model = CheXNet(classes=self.classes)
+         elif self.model_name == 'ResNet18':
+             _model = ResNet18(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'ResNet34':
+             _model = ResNet34(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'ResNet50':
+             _model = ResNet50(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'ResNet101':
+             _model = ResNet101(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'ResNet152':
+             _model = ResNet152(pretrained=self.pretrained, classes=self.classes)
+         elif self.model_name == 'VGG':
+             _model = VGG(tags_num=self.classes)
+         else:
+             # Fall back to CheXNet for unrecognised names.
+             _model = CheXNet(classes=self.classes)
+
+         return _model
+
+
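A minimal usage sketch for the factory above (illustrative, not part of this commit; the model name and class count are example values):

    factory = ModelFactory(model_name='VGG19', pretrained=False, classes=156)
    model = factory.create_model()
    scores = model(torch.randn(2, 3, 224, 224))  # -> (2, 156) tag scores
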
+ class EncoderCNN(nn.Module):
+     def __init__(self, embed_size, pretrained=True):
+         super(EncoderCNN, self).__init__()
+         # TODO Extract Image features from CNN based on other models
+         resnet = models.resnet152(pretrained=pretrained)
+         modules = list(resnet.children())[:-1]  # drop the final fc layer
+         self.resnet = nn.Sequential(*modules)
+         self.linear = nn.Linear(resnet.fc.in_features, embed_size)
+         self.bn = nn.BatchNorm1d(embed_size, momentum=0.1)
+         self.__init_weights()
+
+     def __init_weights(self):
+         self.linear.weight.data.normal_(0.0, 0.1)
+         self.linear.bias.data.fill_(0)
+
+     def forward(self, images) -> object:
+         """Encode images into embed_size-dim feature vectors."""
+         features = self.resnet(images)
+         # Detach so gradients do not flow into the frozen ResNet backbone
+         # (replaces the deprecated Variable(features.data) idiom).
+         features = features.detach()
+         features = features.view(features.size(0), -1)
+         features = self.bn(self.linear(features))
+         return features
+
+
+ class DecoderRNN(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
+         super(DecoderRNN, self).__init__()
+         self.embed = nn.Embedding(vocab_size, embed_size)
+         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+         self.__init_weights()
+         self.n_max = n_max
+
+     def __init_weights(self):
+         self.embed.weight.data.uniform_(-0.1, 0.1)
+         self.linear.weight.data.uniform_(-0.1, 0.1)
+         self.linear.bias.data.fill_(0)
+
+     def forward(self, features, captions) -> object:
+         """Teacher-forced step: prepend image features, predict the next word."""
+         embeddings = self.embed(captions)
+         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
+         hidden, _ = self.lstm(embeddings)
+         outputs = self.linear(hidden[:, -1, :])
+         return outputs
+
+     def sample(self, features, start_tokens):
+         # Greedy decoding: feed the argmax of each step back in, up to n_max words.
+         sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
+         predicted = start_tokens
+         embeddings = features.unsqueeze(1)
+
+         for i in range(self.n_max):
+             predicted = self.embed(predicted)
+             embeddings = torch.cat([embeddings, predicted], dim=1)
+             hidden_states, _ = self.lstm(embeddings)
+             hidden_states = hidden_states[:, -1, :]
+             outputs = self.linear(hidden_states)
+             predicted = torch.max(outputs, 1)[1]
+             sampled_ids[:, i] = predicted.cpu().numpy()
+             predicted = predicted.unsqueeze(1)
+         return sampled_ids
+
+
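Roughly how the greedy sampler above is driven (illustrative shapes; the start-token id 2 is an assumption):

    decoder = DecoderRNN(embed_size=512, hidden_size=512, vocab_size=1000, num_layers=1)
    feats = torch.randn(4, 512)                      # e.g. from EncoderCNN
    start = torch.full((4, 1), 2, dtype=torch.long)  # assumed <start> id
    ids = decoder.sample(feats, start)               # numpy array of shape (4, 50)
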
+ class VisualFeatureExtractor(nn.Module):
+     def __init__(self, pretrained=False):
+         super(VisualFeatureExtractor, self).__init__()
+         resnet = models.resnet152(pretrained=pretrained)
+         modules = list(resnet.children())[:-1]  # keep everything up to the avgpool
+         self.resnet = nn.Sequential(*modules)
+         self.out_features = resnet.fc.in_features
+
+     def forward(self, images) -> object:
+         """Return flattened ResNet-152 features of shape (batch, 2048)."""
+         features = self.resnet(images)
+         features = features.view(features.size(0), -1)
+         return features
+
+
+ class MLC(nn.Module):
+     def __init__(self, classes=156, sementic_features_dim=512, fc_in_features=2048, k=10):
+         super(MLC, self).__init__()
+         self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
+         self.embed = nn.Embedding(classes, sementic_features_dim)
+         self.k = k
+         self.softmax = nn.Softmax(dim=-1)  # explicit dim; the implicit form is deprecated
+
+     def forward(self, visual_features) -> object:
+         """Predict tag distributions and embed the top-k tags as semantic features."""
+         tags = self.softmax(self.classifier(visual_features))
+         semantic_features = self.embed(torch.topk(tags, self.k)[1])
+         return tags, semantic_features
+
+
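The top-k tag indices are looked up in the embedding table, so the semantic branch yields one vector per retained tag. With the defaults (a quick, illustrative check):

    mlc = MLC()
    tags, sem = mlc(torch.randn(4, 2048))  # tags: (4, 156), sem: (4, 10, 512)
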
+ class CoAttention(nn.Module):
+     def __init__(self, embed_size=512, hidden_size=512, visual_size=2048):
+         super(CoAttention, self).__init__()
+         self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
+         self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
+
+         self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
+         self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
+
+         self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
+         self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
+
+         self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
+         self.bn_a = nn.BatchNorm1d(num_features=10, momentum=0.1)
+
+         self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
+         self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=0.1)
+
+         self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+         self.bn_a_att = nn.BatchNorm1d(num_features=10, momentum=0.1)
+
+         self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
+         self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)
+
+         self.tanh = nn.Tanh()
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, visual_features, semantic_features, h_sent) -> object:
+         """Training-time co-attention over visual and semantic features."""
+         # Visual attention, conditioned on the sentence LSTM's hidden state.
+         W_v = self.bn_v(self.W_v(visual_features))
+         W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
+
+         alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
+         v_att = torch.mul(alpha_v, visual_features)
+         # v_att = torch.mul(alpha_v, visual_features).sum(1).unsqueeze(1)
+
+         # Semantic attention over the k embedded tags.
+         W_a_h = self.bn_a_h(self.W_a_h(h_sent))
+         W_a = self.bn_a(self.W_a(semantic_features))
+         alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
+         a_att = torch.mul(alpha_a, semantic_features).sum(1)
+         # a_att = (alpha_a * semantic_features).sum(1)
+         ctx = self.bn_fc(self.W_fc(torch.cat([v_att, a_att], dim=1)))
+         # return self.W_fc(self.bn_fc(torch.cat([v_att, a_att], dim=1)))
+         return ctx, v_att
+
+
+ class SentenceLSTM(nn.Module):
+     def __init__(self, embed_size=512, hidden_size=512, num_layers=1):
+         super(SentenceLSTM, self).__init__()
+         self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers)
+         self.W_t_h = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
+         self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=0.1)
+
+         self.W_t_ctx = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
+         self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=0.1)
+
+         self.W_stop_s_1 = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
+         self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=0.1)
+
+         self.W_stop_s = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
+         self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=0.1)
+
+         self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
+         self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=0.1)
+
+         self.W_topic = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
+         self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=0.1)
+
+         self.W_topic_2 = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
+         self.bn_topic_2 = nn.BatchNorm1d(num_features=1, momentum=0.1)
+
+         self.sigmoid = nn.Sigmoid()
+         self.tanh = nn.Tanh()
+
+     # v1 (kept for reference): sigmoid gating with a batch-normed ctx term.
+     # def forward(self, ctx, prev_hidden_state, states=None) -> object:
+     #     ctx = ctx.unsqueeze(1)
+     #     hidden_state, states = self.lstm(ctx, states)
+     #     topic = self.bn_topic(self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
+     #                                                     + self.bn_t_ctx(self.W_t_ctx(ctx)))))
+     #     p_stop = self.bn_stop(self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
+     #                                                    + self.bn_stop_s(self.W_stop_s(hidden_state)))))
+     #     return topic, p_stop, hidden_state, states
+
+     def forward(self, ctx, prev_hidden_state, states=None) -> object:
+         """v2: emit a topic vector and a stop distribution for the current sentence."""
+         ctx = ctx.unsqueeze(1)
+         hidden_state, states = self.lstm(ctx, states)
+         topic = self.bn_topic(self.W_topic(self.tanh(self.bn_t_h(self.W_t_h(hidden_state)
+                                                                  + self.W_t_ctx(ctx)))))
+         p_stop = self.bn_stop(self.W_stop(self.tanh(self.bn_stop_s(self.W_stop_s_1(prev_hidden_state)
+                                                                    + self.W_stop_s(hidden_state)))))
+         return topic, p_stop, hidden_state, states
+
+
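Generation with these modules is hierarchical: the sentence LSTM proposes one topic per sentence plus a stop signal, and a word decoder expands each topic into words. A rough sketch of that loop (names reuse the modules in this file; the threshold and sentence cap are illustrative):

    states, h_sent = None, torch.zeros(4, 1, 512)
    for s in range(6):  # at most 6 sentences
        ctx, _ = co_att(visual_features, semantic_features, h_sent)
        topic, p_stop, h_sent, states = sent_lstm(ctx, h_sent, states)
        if p_stop.squeeze(1).softmax(-1)[:, 1].mean() > 0.5:
            break  # the model votes to stop
        sent_ids = word_lstm.sample(topic, start_tokens)
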
+ class SentenceTCN(nn.Module):
+     def __init__(self,
+                  input_channel=10,
+                  embed_size=512,
+                  output_size=512,
+                  nhid=512,
+                  levels=8,
+                  kernel_size=2,
+                  dropout=0):
+         super(SentenceTCN, self).__init__()
+         channel_sizes = [nhid] * levels
+         self.tcn = TCN(input_size=input_channel,
+                        output_size=output_size,
+                        num_channels=channel_sizes,
+                        kernel_size=kernel_size,
+                        dropout=dropout)
+         self.W_t_h = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
+         self.W_t_ctx = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
+         self.W_stop_s_1 = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
+         self.W_stop_s = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
+         self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
+         self.t_w = nn.Linear(in_features=5120, out_features=2, bias=True)
+         self.tanh = nn.Tanh()
+
+     def forward(self, ctx, prev_output) -> object:
+         """TCN variant of the sentence module: topic + stop signal from a ctx history."""
+         output = self.tcn.forward(ctx)
+         topic = self.tanh(self.W_t_h(output) + self.W_t_ctx(ctx[:, -1, :]).squeeze(1))
+         p_stop = self.W_stop(self.tanh(self.W_stop_s_1(prev_output) + self.W_stop_s(output)))
+         return topic, p_stop, output
+
+
+ class WordLSTM(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
+         super(WordLSTM, self).__init__()
+         self.embed = nn.Embedding(vocab_size, embed_size)
+         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+         self.__init_weights()
+         self.n_max = n_max
+         self.vocab_size = vocab_size
+
+     def __init_weights(self):
+         self.embed.weight.data.uniform_(-0.1, 0.1)
+         self.linear.weight.data.uniform_(-0.1, 0.1)
+         self.linear.bias.data.fill_(0)
+
+     def forward(self, topic_vec, captions) -> object:
+         """Teacher-forced step: prepend the topic vector, predict the next word."""
+         embeddings = self.embed(captions)
+         embeddings = torch.cat((topic_vec, embeddings), 1)
+         hidden, _ = self.lstm(embeddings)
+         outputs = self.linear(hidden[:, -1, :])
+         return outputs
+
+     def val(self, features, start_tokens):
+         # Greedy decoding that keeps the full logit sequence (for validation loss).
+         samples = torch.zeros((np.shape(features)[0], self.n_max, self.vocab_size))
+         samples[:, 0, start_tokens[0]] = 1
+         predicted = start_tokens
+         embeddings = features
+
+         for i in range(1, self.n_max):
+             predicted = self.embed(predicted)
+             embeddings = torch.cat([embeddings, predicted], dim=1)
+             hidden_states, _ = self.lstm(embeddings)
+             hidden_states = hidden_states[:, -1, :]
+             outputs = self.linear(hidden_states)
+             samples[:, i, :] = outputs
+             predicted = torch.max(outputs, 1)[1]
+             predicted = predicted.unsqueeze(1)
+         return samples
+
+     def sample(self, features, start_tokens):
+         # Greedy decoding that keeps only the sampled word ids.
+         sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
+         sampled_ids[:, 0] = start_tokens.view(-1).cpu().numpy()
+         predicted = start_tokens
+         embeddings = features
+
+         for i in range(1, self.n_max):
+             predicted = self.embed(predicted)
+             embeddings = torch.cat([embeddings, predicted], dim=1)
+             hidden_states, _ = self.lstm(embeddings)
+             hidden_states = hidden_states[:, -1, :]
+             outputs = self.linear(hidden_states)
+             predicted = torch.max(outputs, 1)[1]
+             sampled_ids[:, i] = predicted.cpu().numpy()
+             predicted = predicted.unsqueeze(1)
+         return sampled_ids
+
+
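Note that WordLSTM.forward returns only the last timestep's logits, so training calls it once per target word with a growing caption prefix. An illustrative call:

    word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=1000, num_layers=1)
    topic = torch.randn(4, 1, 512)
    prefix = torch.ones(4, 3).long()        # words generated so far
    next_logits = word_lstm(topic, prefix)  # (4, 1000): scores for the next word
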
+ class WordTCN(nn.Module):
+     def __init__(self,
+                  input_channel=11,
+                  vocab_size=1000,
+                  embed_size=512,
+                  output_size=512,
+                  nhid=512,
+                  levels=8,
+                  kernel_size=2,
+                  dropout=0,
+                  n_max=50):
+         super(WordTCN, self).__init__()
+         self.vocab_size = vocab_size
+         self.embed_size = embed_size
+         self.output_size = output_size
+         channel_sizes = [nhid] * levels
+         self.kernel_size = kernel_size
+         self.dropout = dropout
+         self.n_max = n_max
+         self.embed = nn.Embedding(vocab_size, embed_size)
+         self.W_out = nn.Linear(in_features=output_size, out_features=vocab_size, bias=True)
+         self.tcn = TCN(input_size=input_channel,
+                        output_size=output_size,
+                        num_channels=channel_sizes,
+                        kernel_size=kernel_size,
+                        dropout=dropout)
+
+     def forward(self, topic_vec, captions) -> object:
+         """TCN variant of the word module: next-word logits from topic + caption prefix."""
+         captions = self.embed(captions)
+         embeddings = torch.cat([topic_vec, captions], dim=1)
+         output = self.tcn.forward(embeddings)
+         words = self.W_out(output)
+         return words
+
+
+ if __name__ == '__main__':
+     import warnings
+     warnings.filterwarnings("ignore")
+     images = torch.randn((4, 3, 224, 224))
+     captions = torch.ones((4, 10)).long()
+     hidden_state = torch.randn((4, 1, 512))
+
+     print("images:{}".format(images.shape))
+     print("captions:{}".format(captions.shape))
+     print("hidden_states:{}".format(hidden_state.shape))
+
+     extractor = VisualFeatureExtractor()
+     visual_features = extractor.forward(images)
+     print("visual_features:{}".format(visual_features.shape))
+
+     mlc = MLC()
+     tags, semantic_features = mlc.forward(visual_features)
+     print("tags:{}".format(tags.shape))
+     print("semantic_features:{}".format(semantic_features.shape))
+
+     co_att = CoAttention()
+     ctx, v_att = co_att.forward(visual_features, semantic_features, hidden_state)
+     print("ctx:{}".format(ctx.shape))
+     print("v_att:{}".format(v_att.shape))
+
+     sent_lstm = SentenceLSTM()
+     topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state)
+     print("Topic:{}".format(topic.shape))
+     print("P_STOP:{}".format(p_stop.shape))
+
+     word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)
+     words = word_lstm.forward(topic, captions)
+     print("words:{}".format(words.shape))
+
+     # Expected Output
+     # images: torch.Size([4, 3, 224, 224])
+     # captions: torch.Size([4, 10])
+     # hidden_states: torch.Size([4, 1, 512])
+     # visual_features: torch.Size([4, 2048])
+     # tags: torch.Size([4, 156])
+     # semantic_features: torch.Size([4, 10, 512])
+     # ctx: torch.Size([4, 512])
+     # v_att: torch.Size([4, 2048])
+     # Topic: torch.Size([4, 1, 512])
+     # P_STOP: torch.Size([4, 1, 2])
+     # words: torch.Size([4, 100])
+
+     # TCN variant of the same smoke test (kept for reference):
+     # images = torch.randn((4, 3, 224, 224))
+     # captions = torch.ones((4, 3, 10)).long()
+     # prev_outputs = torch.randn((4, 512))
+     # now_words = torch.ones((4, 1))
+     #
+     # ctx_records = torch.zeros((4, 10, 512))
+     # captions = torch.zeros((4, 10)).long()
+     #
+     # print("images:{}".format(images.shape))
+     # print("captions:{}".format(captions.shape))
+     # print("hidden_states:{}".format(prev_outputs.shape))
+     #
+     # extractor = VisualFeatureExtractor()
+     # visual_features = extractor.forward(images)
+     # print("visual_features:{}".format(visual_features.shape))
+     #
+     # mlc = MLC()
+     # tags, semantic_features = mlc.forward(visual_features)
+     # print("tags:{}".format(tags.shape))
+     # print("semantic_features:{}".format(semantic_features.shape))
+     #
+     # co_att = CoAttention()
+     # ctx = co_att.forward(visual_features, semantic_features, prev_outputs)
+     # print("ctx:{}".format(ctx.shape))
+     #
+     # ctx_records[:, 0, :] = ctx
+     #
+     # sent_tcn = SentenceTCN()
+     # topic, p_stop, prev_outputs = sent_tcn.forward(ctx_records, prev_outputs)
+     # print("Topic:{}".format(topic.shape))
+     # print("P_STOP:{}".format(p_stop.shape))
+     # print("Prev_Outputs:{}".format(prev_outputs.shape))
+     #
+     # captions[:, 0] = now_words.view(-1,)
+     #
+     # word_tcn = WordTCN()
+     # words = word_tcn.forward(topic, captions)
+     # print("words:{}".format(words.shape))
+
tcn.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn.utils import weight_norm
+
+
+ class Chomp1d(nn.Module):
+     """Trim the trailing padding so each convolution stays causal."""
+     def __init__(self, chomp_size):
+         super(Chomp1d, self).__init__()
+         self.chomp_size = chomp_size
+
+     def forward(self, x) -> object:
+         # Drop the last chomp_size timesteps added by the symmetric padding.
+         return x[:, :, :-self.chomp_size].contiguous()
+
+
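Why the chomp: Conv1d pads both ends of the sequence, so each convolution gains padding extra trailing steps; slicing them off leaves every output position depending only on current and past inputs. A quick, illustrative check:

    conv = nn.Conv1d(1, 1, kernel_size=2, padding=1, dilation=1)
    x = torch.randn(1, 1, 5)
    y = Chomp1d(1)(conv(x))  # causal output, same length as x: (1, 1, 5)
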
+ class TemporalBlock(nn.Module):
+     def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
+         super(TemporalBlock, self).__init__()
+         self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
+                                            stride=stride, padding=padding, dilation=dilation))
+         self.chomp1 = Chomp1d(padding)
+         self.relu1 = nn.ReLU(inplace=False)
+         self.dropout1 = nn.Dropout(dropout)
+
+         self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
+                                            stride=stride, padding=padding, dilation=dilation))
+         self.chomp2 = Chomp1d(padding)
+         self.relu2 = nn.ReLU(inplace=False)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
+                                  self.conv2, self.chomp2, self.relu2, self.dropout2)
+         # 1x1 conv to match channel counts on the residual path.
+         self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
+         self.relu = nn.ReLU(inplace=False)
+         self.init_weights()
+
+     def init_weights(self):
+         self.conv1.weight.data.normal_(0, 0.01)
+         self.conv2.weight.data.normal_(0, 0.01)
+         if self.downsample is not None:
+             self.downsample.weight.data.normal_(0, 0.01)
+
+     def forward(self, x) -> object:
+         out = self.net(x)
+         res = x if self.downsample is None else self.downsample(x)
+         return self.relu(out + res)
+
+
+ class TemporalConvNet(nn.Module):
+     def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
+         super(TemporalConvNet, self).__init__()
+         layers = []
+         num_levels = len(num_channels)
+         for i in range(num_levels):
+             # Dilation doubles at every level, so the receptive field grows exponentially.
+             dilation_size = 2 ** i
+             in_channels = num_inputs if i == 0 else num_channels[i-1]
+             out_channels = num_channels[i]
+             layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
+                                      padding=(kernel_size-1) * dilation_size, dropout=dropout)]
+
+         self.network = nn.Sequential(*layers)
+
+     def forward(self, x) -> object:
+         return self.network(x)
+
+
+ class TCN(nn.Module):
+     def __init__(self, input_size, output_size, num_channels, kernel_size=2, dropout=0):
+         super(TCN, self).__init__()
+         self.tcn = TemporalConvNet(num_inputs=input_size,
+                                    num_channels=num_channels,
+                                    kernel_size=kernel_size,
+                                    dropout=dropout)
+         self.linear = nn.Linear(num_channels[-1], output_size)
+         self.init_weights()
+
+     def init_weights(self):
+         self.linear.weight.data.normal_(0, 0.01)
+         self.linear.bias.data.fill_(0)
+
+     def forward(self, inputs) -> object:
+         # Project the features at the last timestep to the output size.
+         y = self.tcn.forward(inputs)
+         output = self.linear(y[:, :, -1])
+         return output
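For intuition about the defaults used elsewhere in this commit: with kernel_size=2 and 8 levels, each block applies two convolutions at dilation 2^i, so the receptive field is 1 + 2*(2-1)*(2^8 - 1) = 511 timesteps. A minimal usage sketch (input is (batch, channels, time); values illustrative):

    tcn = TCN(input_size=10, output_size=512, num_channels=[512] * 8, kernel_size=2)
    out = tcn(torch.randn(4, 10, 20))  # -> (4, 512), read from the last timestep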