franciszzj committed
Commit c964d4c
1 Parent(s): c6b26ba
README.md CHANGED
@@ -1,13 +1,31 @@
1
- ---
2
- title: TreeFormer
3
- emoji: 🚀
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 4.32.2
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+
2
+ # TreeFormer
3
+
4
+ This is the code base for the IEEE Transactions on Geoscience and Remote Sensing (TGRS 2023) paper ['TreeFormer: a Semi-Supervised Transformer-based Framework for Tree Counting from a Single High Resolution Image'](https://arxiv.org/abs/2307.06118).
5
+
6
+ <img src="sample_imgs/overview.png">
7
+
8
+ ## Installation
9
+
10
+ Python ≥ 3.7.
11
+
12
+ To install the required packages, please run:
13
+
14
+
15
+ ```bash
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+ ## Dataset
20
+ Download the dataset from [Google Drive](https://drive.google.com/file/d/1xcjv8967VvvzcDM4aqAi7Corkb11T0i2/view?usp=drive_link).
21
+ ## Evaluation
22
+ Download our model trained on the [London](https://drive.google.com/file/d/14uuOF5758sxtM5EgeGcRtSln5lUXAHge/view?usp=sharing) dataset.
23
+
24
+ Modify the paths to the dataset and the model for evaluation in `test.py`.
25
+
26
+ Run `test.py`.
27
+ ## Acknowledgements
28
+
29
+ - Part of the code is borrowed from [PVT](https://github.com/whai362/PVT) and [DM Count](https://github.com/cvlab-stonybrook/DM-Count). Thanks for their great work!
30
+
31
+
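As a companion to the Evaluation notes above, here is a minimal sketch of what pointing `test.py` at the downloaded checkpoint amounts to, reusing the model-loading calls that `demo.py` in this commit uses; the checkpoint path below is a placeholder, not the repository's actual configuration:

```python
# Hypothetical evaluation setup; './models/best_model.pth' is a placeholder path.
import torch
from network import pvt_cls as TCN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TCN.pvt_treeformer(pretrained=False)          # same constructor demo.py uses
model.load_state_dict(torch.load('./models/best_model.pth', map_location=device))
model.to(device).eval()
```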
assets/EU.png ADDED
assets/reset.png ADDED
datasets/__init__.py ADDED
File without changes
datasets/crowd.py ADDED
@@ -0,0 +1,268 @@
1
+ from PIL import Image
2
+ import torch.utils.data as data
3
+ import os
4
+ from glob import glob
5
+ import torch
6
+ import torchvision.transforms.functional as F
7
+ from torchvision import transforms
8
+ import random
9
+ import numpy as np
10
+ import scipy.io as sio
11
+
12
+ def random_crop(im_h, im_w, crop_h, crop_w):
13
+ res_h = im_h - crop_h
14
+ res_w = im_w - crop_w
15
+ i = random.randint(0, res_h)
16
+ j = random.randint(0, res_w)
17
+ return i, j, crop_h, crop_w
18
+
19
+
20
+ def gen_discrete_map(im_height, im_width, points):
21
+ """
22
+ func: generate the discrete map.
23
+ points: [num_gt, 2], for each row: [width, height]
24
+ """
25
+ discrete_map = np.zeros([im_height, im_width], dtype=np.float32)
26
+ h, w = discrete_map.shape[:2]
27
+ num_gt = points.shape[0]
28
+ if num_gt == 0:
29
+ return discrete_map
30
+
31
+ # fast create discrete map
32
+ points_np = np.array(points).round().astype(int)
33
+ p_h = np.minimum(points_np[:, 1], np.array([h-1]*num_gt).astype(int))
34
+ p_w = np.minimum(points_np[:, 0], np.array([w-1]*num_gt).astype(int))
35
+ p_index = torch.from_numpy(p_h* im_width + p_w).to(torch.int64)
36
+ discrete_map = torch.zeros(im_width * im_height).scatter_add_(0, index=p_index, src=torch.ones(im_width*im_height)).view(im_height, im_width).numpy()
37
+
38
+ ''' slow method
39
+ for p in points:
40
+ p = np.round(p).astype(int)
41
+ p[0], p[1] = min(h - 1, p[1]), min(w - 1, p[0])
42
+ discrete_map[p[0], p[1]] += 1
43
+ '''
44
+ assert np.sum(discrete_map) == num_gt
45
+ return discrete_map
46
+
47
+
48
+ class Base(data.Dataset):
49
+ def __init__(self, root_path, crop_size, downsample_ratio=8):
50
+
51
+ self.root_path = root_path
52
+ self.c_size = crop_size
53
+ self.d_ratio = downsample_ratio
54
+ assert self.c_size % self.d_ratio == 0
55
+ self.dc_size = self.c_size // self.d_ratio
56
+ self.trans = transforms.Compose([
57
+ transforms.ToTensor(),
58
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
+ ])
60
+
61
+ def __len__(self):
62
+ pass
63
+
64
+ def __getitem__(self, item):
65
+ pass
66
+
67
+ def train_transform(self, img, keypoints, gauss_im):
68
+ wd, ht = img.size
69
+ st_size = 1.0 * min(wd, ht)
70
+ assert st_size >= self.c_size
71
+ assert len(keypoints) >= 0
72
+ i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
73
+ img = F.crop(img, i, j, h, w)
74
+ gauss_im = F.crop(gauss_im, i, j, h, w)
75
+ if len(keypoints) > 0:
76
+ keypoints = keypoints - [j, i]
77
+ idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
78
+ (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
79
+ keypoints = keypoints[idx_mask]
80
+ else:
81
+ keypoints = np.empty([0, 2])
82
+
83
+ gt_discrete = gen_discrete_map(h, w, keypoints)
84
+ down_w = w // self.d_ratio
85
+ down_h = h // self.d_ratio
86
+ gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
87
+ assert np.sum(gt_discrete) == len(keypoints)
88
+
89
+ if len(keypoints) > 0:
90
+ if random.random() > 0.5:
91
+ img = F.hflip(img)
92
+ gauss_im = F.hflip(gauss_im)
93
+ gt_discrete = np.fliplr(gt_discrete)
94
+ keypoints[:, 0] = w - keypoints[:, 0]
95
+ else:
96
+ if random.random() > 0.5:
97
+ img = F.hflip(img)
98
+ gauss_im = F.hflip(gauss_im)
99
+ gt_discrete = np.fliplr(gt_discrete)
100
+ gt_discrete = np.expand_dims(gt_discrete, 0)
101
+
102
+ return self.trans(img), gauss_im, torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(gt_discrete.copy()).float()
103
+
104
+
105
+
106
+ class Crowd_TC(Base):
107
+ def __init__(self, root_path, crop_size, downsample_ratio=8, method='train'):
108
+ super().__init__(root_path, crop_size, downsample_ratio)
109
+ self.method = method
110
+ if method not in ['train', 'val']:
111
+ raise Exception("not implement")
112
+
113
+ self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg')))
114
+
115
+ print('number of img [{}]: {}'.format(method, len(self.im_list)))
116
+
117
+ def __len__(self):
118
+ return len(self.im_list)
119
+
120
+ def __getitem__(self, item):
121
+ img_path = self.im_list[item]
122
+ name = os.path.basename(img_path).split('.')[0]
123
+ gd_path = os.path.join(self.root_path, 'ground_truth', 'GT_{}.mat'.format(name))
124
+ img = Image.open(img_path).convert('RGB')
125
+ keypoints = sio.loadmat(gd_path)['image_info'][0][0][0][0][0]
126
+ gauss_path = os.path.join(self.root_path, 'ground_truth', '{}_densitymap.npy'.format(name))
127
+ gauss_im = torch.from_numpy(np.load(gauss_path)).float()
128
+ #import pdb;pdb.set_trace()
129
+ #print("label {}", item)
130
+
131
+ if self.method == 'train':
132
+ return self.train_transform(img, keypoints, gauss_im)
133
+ elif self.method == 'val':
134
+ wd, ht = img.size
135
+ st_size = 1.0 * min(wd, ht)
136
+ if st_size < self.c_size:
137
+ rr = 1.0 * self.c_size / st_size
138
+ wd = round(wd * rr)
139
+ ht = round(ht * rr)
140
+ st_size = 1.0 * min(wd, ht)
141
+ img = img.resize((wd, ht), Image.BICUBIC)
142
+ img = self.trans(img)
143
+ #import pdb;pdb.set_trace()
144
+
145
+ return img, len(keypoints), name, gauss_im
146
+
147
+ def train_transform(self, img, keypoints, gauss_im):
148
+ wd, ht = img.size
149
+ st_size = 1.0 * min(wd, ht)
150
+ # resize the image to fit the crop size
151
+ if st_size < self.c_size:
152
+ rr = 1.0 * self.c_size / st_size
153
+ wd = round(wd * rr)
154
+ ht = round(ht * rr)
155
+ st_size = 1.0 * min(wd, ht)
156
+ img = img.resize((wd, ht), Image.BICUBIC)
157
+ #gauss_im = gauss_im.resize((wd, ht), Image.BICUBIC)
158
+ keypoints = keypoints * rr
159
+ assert st_size >= self.c_size, print(wd, ht)
160
+ assert len(keypoints) >= 0
161
+ i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
162
+ img = F.crop(img, i, j, h, w)
163
+ gauss_im = F.crop(gauss_im, i, j, h, w)
164
+ if len(keypoints) > 0:
165
+ keypoints = keypoints - [j, i]
166
+ idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
167
+ (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
168
+ keypoints = keypoints[idx_mask]
169
+ else:
170
+ keypoints = np.empty([0, 2])
171
+
172
+ gt_discrete = gen_discrete_map(h, w, keypoints)
173
+ down_w = w // self.d_ratio
174
+ down_h = h // self.d_ratio
175
+ gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
176
+ assert np.sum(gt_discrete) == len(keypoints)
177
+
178
+
179
+ if len(keypoints) > 0:
180
+ if random.random() > 0.5:
181
+ img = F.hflip(img)
182
+ gauss_im = F.hflip(gauss_im)
183
+ gt_discrete = np.fliplr(gt_discrete)
184
+ keypoints[:, 0] = w - keypoints[:, 0] - 1
185
+ else:
186
+ if random.random() > 0.5:
187
+ img = F.hflip(img)
188
+ gauss_im = F.hflip(gauss_im)
189
+ gt_discrete = np.fliplr(gt_discrete)
190
+ gt_discrete = np.expand_dims(gt_discrete, 0)
191
+ #import pdb;pdb.set_trace()
192
+
193
+ return self.trans(img), gauss_im, torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(gt_discrete.copy()).float()
194
+
195
+
196
+ class Base_UL(data.Dataset):
197
+ def __init__(self, root_path, crop_size, downsample_ratio=8):
198
+ self.root_path = root_path
199
+ self.c_size = crop_size
200
+ self.d_ratio = downsample_ratio
201
+ assert self.c_size % self.d_ratio == 0
202
+ self.dc_size = self.c_size // self.d_ratio
203
+ self.trans = transforms.Compose([
204
+ transforms.ToTensor(),
205
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
206
+ ])
207
+
208
+ def __len__(self):
209
+ pass
210
+
211
+ def __getitem__(self, item):
212
+ pass
213
+
214
+ def train_transform_ul(self, img):
215
+ wd, ht = img.size
216
+ st_size = 1.0 * min(wd, ht)
217
+ assert st_size >= self.c_size
218
+ i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
219
+ img = F.crop(img, i, j, h, w)
220
+
221
+ if random.random() > 0.5:
222
+ img = F.hflip(img)
223
+
224
+ return self.trans(img)
225
+
226
+
227
+ class Crowd_UL_TC(Base_UL):
228
+ def __init__(self, root_path, crop_size, downsample_ratio=8, method='train_ul'):
229
+ super().__init__(root_path, crop_size, downsample_ratio)
230
+ self.method = method
231
+ if method not in ['train_ul']:
232
+ raise Exception("not implement")
233
+
234
+ self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg')))
235
+ print('number of img [{}]: {}'.format(method, len(self.im_list)))
236
+
237
+ def __len__(self):
238
+ return len(self.im_list)
239
+
240
+ def __getitem__(self, item):
241
+ img_path = self.im_list[item]
242
+ name = os.path.basename(img_path).split('.')[0]
243
+ img = Image.open(img_path).convert('RGB')
244
+ #print("un_label {}", item)
245
+
246
+ return self.train_transform_ul(img)
247
+
248
+
249
+ def train_transform_ul(self, img):
250
+ wd, ht = img.size
251
+ st_size = 1.0 * min(wd, ht)
252
+ # resize the image to fit the crop size
253
+ if st_size < self.c_size:
254
+ rr = 1.0 * self.c_size / st_size
255
+ wd = round(wd * rr)
256
+ ht = round(ht * rr)
257
+ st_size = 1.0 * min(wd, ht)
258
+ img = img.resize((wd, ht), Image.BICUBIC)
259
+
260
+ assert st_size >= self.c_size, print(wd, ht)
261
+
262
+ i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
263
+ img = F.crop(img, i, j, h, w)
264
+ if random.random() > 0.5:
265
+ img = F.hflip(img)
266
+
267
+ return self.trans(img), 1
268
+
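A quick usage sketch for `gen_discrete_map` from `datasets/crowd.py` above: it scatters each annotated point into an integer count map whose sum equals the number of points, which the dataset classes then downsample by `downsample_ratio`. The coordinates below are made up, and the snippet assumes it is run from the repository root:

```python
# Usage sketch for gen_discrete_map (point coordinates are made up).
import numpy as np
from datasets.crowd import gen_discrete_map

points = np.array([[3.2, 5.7], [3.2, 5.7], [100.0, 0.0]])  # (x, y) tree locations; two overlap
dmap = gen_discrete_map(im_height=64, im_width=128, points=points)

print(dmap.shape)   # (64, 128)
print(dmap.sum())   # 3.0 -- total equals the number of annotated points
print(dmap[6, 3])   # 2.0 -- the two overlapping points fall into the same cell
```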
demo.py ADDED
@@ -0,0 +1,148 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torchvision import transforms
7
+
8
+ from PIL import Image
9
+ from network import pvt_cls as TCN
10
+
11
+ import gradio as gr
12
+
13
+
14
+ def demo(img_path):
15
+ # config
16
+ batch_size = 8
17
+ crop_size = 256
18
+ model_path = '/users/k21163430/workspace/TreeFormer/models/best_model.pth'
19
+
20
+ device = torch.device('cuda')
21
+
22
+ # prepare model
23
+ model = TCN.pvt_treeformer(pretrained=False)
24
+ model.to(device)
25
+ model.load_state_dict(torch.load(model_path, device))
26
+ model.eval()
27
+
28
+ # preprocess
29
+ img = Image.open(img_path).convert('RGB')
30
+ show_img = np.array(img)
31
+ wd, ht = img.size
32
+ st_size = 1.0 * min(wd, ht)
33
+ if st_size < crop_size:
34
+ rr = 1.0 * crop_size / st_size
35
+ wd = round(wd * rr)
36
+ ht = round(ht * rr)
37
+ st_size = 1.0 * min(wd, ht)
38
+ img = img.resize((wd, ht), Image.BICUBIC)
39
+ transform = transforms.Compose([
40
+ transforms.ToTensor(),
41
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
42
+ ])
43
+ img = transform(img)
44
+ img = img.unsqueeze(0)
45
+
46
+ # model forward
47
+ with torch.no_grad():
48
+ inputs = img.to(device)
49
+ crop_imgs, crop_masks = [], []
50
+ b, c, h, w = inputs.size()
51
+ rh, rw = crop_size, crop_size
52
+
53
+ for i in range(0, h, rh):
54
+ gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
55
+
56
+ for j in range(0, w, rw):
57
+ gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
58
+ crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
59
+ mask = torch.zeros([b, 1, h, w]).to(device)
60
+ mask[:, :, gis:gie, gjs:gje].fill_(1.0)
61
+ crop_masks.append(mask)
62
+ crop_imgs, crop_masks = map(lambda x: torch.cat(
63
+ x, dim=0), (crop_imgs, crop_masks))
64
+
65
+ crop_preds = []
66
+ nz, bz = crop_imgs.size(0), batch_size
67
+ for i in range(0, nz, bz):
68
+
69
+ gs, gt = i, min(nz, i + bz)
70
+ crop_pred, _ = model(crop_imgs[gs:gt])
71
+ crop_pred = crop_pred[0]
72
+
73
+ _, _, h1, w1 = crop_pred.size()
74
+ crop_pred = F.interpolate(crop_pred, size=(
75
+ h1 * 4, w1 * 4), mode='bilinear', align_corners=True) / 16
76
+ crop_preds.append(crop_pred)
77
+ crop_preds = torch.cat(crop_preds, dim=0)
78
+
79
+ # splice them to the original size
80
+ idx = 0
81
+ pred_map = torch.zeros([b, 1, h, w]).to(device)
82
+ for i in range(0, h, rh):
83
+ gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
84
+ for j in range(0, w, rw):
85
+ gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
86
+ pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
87
+ idx += 1
88
+ # for the overlapping area, compute average value
89
+ mask = crop_masks.sum(dim=0).unsqueeze(0)
90
+ outputs = pred_map / mask
91
+
92
+ outputs = F.interpolate(outputs, size=(
93
+ h, w), mode='bilinear', align_corners=True)/4
94
+ outputs = pred_map / mask
95
+ model_output = round(torch.sum(outputs).item())
96
+
97
+ print("{}: {}".format(img_path, model_output))
98
+ outputs = outputs.squeeze().cpu().numpy()
99
+ outputs = (outputs - np.min(outputs)) / \
100
+ (np.max(outputs) - np.min(outputs))
101
+
102
+ show_img = show_img / 255.0
103
+ show_img = show_img * 0.2 + outputs[:, :, None] * 0.8
104
+
105
+ return model_output, show_img
106
+
107
+
108
+ if __name__ == "__main__":
109
+ # test
110
+ # img_path = sys.argv[1]
111
+ # demo(img)
112
+
113
+ # Launch a gr.Interface
114
+ gr_demo = gr.Interface(fn=demo,
115
+ inputs=gr.Image(source="upload",
116
+ type="filepath",
117
+ label="Input Image",
118
+ width=768,
119
+ height=768,
120
+ ),
121
+ outputs=[
122
+ gr.Number(label="Predicted Tree Count"),
123
+ gr.Image(label="Density Map",
124
+ width=768,
125
+ height=768,
126
+ )
127
+ ],
128
+ title="TreeFormer",
129
+ description="TreeFormer is a semi-supervised transformer-based framework for tree counting from a single high resolution image. Upload an image and TreeFormer will predict the number of trees in the image and generate a density map of the trees.",
130
+ article="This work has been developed as part of the ReSET project, which has received funding from the European Union's Horizon 2020 FET Proactive Programme under grant agreement No 101017857. The contents of this publication are the sole responsibility of the ReSET consortium and do not necessarily reflect the opinion of the European Union.",
131
+ examples=[
132
+ ["./examples/IMG_101.jpg"],
133
+ ["./examples/IMG_125.jpg"],
134
+ ["./examples/IMG_138.jpg"],
135
+ ["./examples/IMG_180.jpg"],
136
+ ["./examples/IMG_18.jpg"],
137
+ ["./examples/IMG_206.jpg"],
138
+ ["./examples/IMG_223.jpg"],
139
+ ["./examples/IMG_247.jpg"],
140
+ ["./examples/IMG_270.jpg"],
141
+ ["./examples/IMG_306.jpg"],
142
+ ],
143
+ # cache_examples=True,
144
+ examples_per_page=10,
145
+ allow_flagging=False,
146
+ theme=gr.themes.Default(),
147
+ )
148
+ gr_demo.launch(share=True, server_port=7861, favicon_path="./assets/reset.png")
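The inference loop in `demo.py` tiles the input into `crop_size` windows, predicts each tile in batches, pastes the tile predictions back, and divides by the accumulated mask so that pixels covered by several tiles are averaged. A stand-alone sketch of that tile-and-average scheme, using an identity "model" and made-up sizes instead of TreeFormer:

```python
# Sketch of the tile-and-average scheme used in demo.py, with a dummy identity "model".
import torch

h, w, crop = 300, 500, 256
inputs = torch.rand(1, 1, h, w)
pred_map = torch.zeros(1, 1, h, w)
mask = torch.zeros(1, 1, h, w)

for i in range(0, h, crop):
    gis, gie = max(min(h - crop, i), 0), min(h, i + crop)
    for j in range(0, w, crop):
        gjs, gje = max(min(w - crop, j), 0), min(w, j + crop)
        tile = inputs[:, :, gis:gie, gjs:gje]      # the model would run on this tile
        pred_map[:, :, gis:gie, gjs:gje] += tile   # paste the tile prediction back
        mask[:, :, gis:gie, gjs:gje] += 1.0        # count how many tiles cover each pixel

outputs = pred_map / mask                          # average the overlapping regions
print(torch.allclose(outputs, inputs))             # True for the identity "model"
```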
examples/IMG_101.jpg ADDED
examples/IMG_125.jpg ADDED
examples/IMG_138.jpg ADDED
examples/IMG_18.jpg ADDED
examples/IMG_180.jpg ADDED
examples/IMG_206.jpg ADDED
examples/IMG_223.jpg ADDED
examples/IMG_247.jpg ADDED
examples/IMG_270.jpg ADDED
examples/IMG_306.jpg ADDED
losses/__init__.py ADDED
@@ -0,0 +1 @@
1
+
losses/bregman_pytorch.py ADDED
@@ -0,0 +1,484 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Rewrite ot.bregman.sinkhorn in Python Optimal Transport (https://pythonot.github.io/_modules/ot/bregman.html#sinkhorn)
4
+ using PyTorch operations.
5
+ Bregman projections for regularized OT (Sinkhorn distance).
6
+ """
7
+
8
+ import torch
9
+
10
+ M_EPS = 1e-16
11
+
12
+
13
+ def sinkhorn(a, b, C, reg=1e-1, method='sinkhorn', maxIter=1000, tau=1e3,
14
+ stopThr=1e-9, verbose=False, log=True, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
15
+ """
16
+ Solve the entropic regularization optimal transport
17
+ The input should be PyTorch tensors
18
+ The function solves the following optimization problem:
19
+
20
+ .. math::
21
+ \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
22
+ s.t. \gamma 1 = a
23
+ \gamma^T 1= b
24
+ \gamma\geq 0
25
+ where :
26
+ - C is the (ns,nt) metric cost matrix
27
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
28
+ - a and b are target and source measures (sum to 1)
29
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
30
+
31
+ Parameters
32
+ ----------
33
+ a : torch.tensor (na,)
34
+ samples measure in the target domain
35
+ b : torch.tensor (nb,)
36
+ samples in the source domain
37
+ C : torch.tensor (na,nb)
38
+ loss matrix
39
+ reg : float
40
+ Regularization term > 0
41
+ method : str
42
+ method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
43
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
44
+ maxIter : int, optional
45
+ Max number of iterations
46
+ stopThr : float, optional
47
+ Stop threshol on error ( > 0 )
48
+ verbose : bool, optional
49
+ Print information along iterations
50
+ log : bool, optional
51
+ record log if True
52
+
53
+ Returns
54
+ -------
55
+ gamma : (na x nb) torch.tensor
56
+ Optimal transportation matrix for the given parameters
57
+ log : dict
58
+ log dictionary return only if log==True in parameters
59
+
60
+ References
61
+ ----------
62
+ [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
63
+ See Also
64
+ --------
65
+
66
+ """
67
+
68
+ if method.lower() == 'sinkhorn':
69
+ return sinkhorn_knopp(a, b, C, reg, maxIter=maxIter,
70
+ stopThr=stopThr, verbose=verbose, log=log,
71
+ warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
72
+ **kwargs)
73
+ elif method.lower() == 'sinkhorn_stabilized':
74
+ return sinkhorn_stabilized(a, b, C, reg, maxIter=maxIter, tau=tau,
75
+ stopThr=stopThr, verbose=verbose, log=log,
76
+ warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
77
+ **kwargs)
78
+ elif method.lower() == 'sinkhorn_epsilon_scaling':
79
+ return sinkhorn_epsilon_scaling(a, b, C, reg,
80
+ maxIter=maxIter, maxInnerIter=100, tau=tau,
81
+ scaling_base=0.75, scaling_coef=None, stopThr=stopThr,
82
+ verbose=False, log=log, warm_start=warm_start, eval_freq=eval_freq,
83
+ print_freq=print_freq, **kwargs)
84
+ else:
85
+ raise ValueError("Unknown method '%s'." % method)
86
+
87
+
88
+ def sinkhorn_knopp(a, b, C, reg=1e-1, maxIter=1000, stopThr=1e-9,
89
+ verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
90
+ """
91
+ Solve the entropic regularization optimal transport
92
+ The input should be PyTorch tensors
93
+ The function solves the following optimization problem:
94
+
95
+ .. math::
96
+ \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
97
+ s.t. \gamma 1 = a
98
+ \gamma^T 1= b
99
+ \gamma\geq 0
100
+ where :
101
+ - C is the (ns,nt) metric cost matrix
102
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
103
+ - a and b are target and source measures (sum to 1)
104
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
105
+
106
+ Parameters
107
+ ----------
108
+ a : torch.tensor (na,)
109
+ samples measure in the target domain
110
+ b : torch.tensor (nb,)
111
+ samples in the source domain
112
+ C : torch.tensor (na,nb)
113
+ loss matrix
114
+ reg : float
115
+ Regularization term > 0
116
+ maxIter : int, optional
117
+ Max number of iterations
118
+ stopThr : float, optional
119
+ Stop threshol on error ( > 0 )
120
+ verbose : bool, optional
121
+ Print information along iterations
122
+ log : bool, optional
123
+ record log if True
124
+
125
+ Returns
126
+ -------
127
+ gamma : (na x nb) torch.tensor
128
+ Optimal transportation matrix for the given parameters
129
+ log : dict
130
+ log dictionary return only if log==True in parameters
131
+
132
+ References
133
+ ----------
134
+ [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
135
+ See Also
136
+ --------
137
+
138
+ """
139
+
140
+ device = a.device
141
+ na, nb = C.shape
142
+
143
+ assert na >= 1 and nb >= 1, 'C needs to be 2d'
144
+ assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match that of C"
145
+ assert reg > 0, 'reg should be greater than 0'
146
+ assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
147
+
148
+ if log:
149
+ log = {'err': []}
150
+
151
+ if warm_start is not None:
152
+ u = warm_start['u']
153
+ v = warm_start['v']
154
+ else:
155
+ u = torch.ones(na, dtype=a.dtype).to(device) / na
156
+ v = torch.ones(nb, dtype=b.dtype).to(device) / nb
157
+
158
+ K = torch.empty(C.shape, dtype=C.dtype).to(device)
159
+ torch.div(C, -reg, out=K)
160
+ torch.exp(K, out=K)
161
+
162
+ b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
163
+
164
+ it = 1
165
+ err = 1
166
+
167
+ # allocate memory beforehand
168
+ KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
169
+ Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
170
+
171
+ while (err > stopThr and it <= maxIter):
172
+ upre, vpre = u, v
173
+ torch.matmul(u, K, out=KTu)
174
+ v = torch.div(b, KTu + M_EPS)
175
+ torch.matmul(K, v, out=Kv)
176
+ u = torch.div(a, Kv + M_EPS)
177
+
178
+ if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
179
+ torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
180
+ print('Warning: numerical errors at iteration', it)
181
+ u, v = upre, vpre
182
+ break
183
+
184
+ if log and it % eval_freq == 0:
185
+ # we can speed up the process by checking for the error only all
186
+ # the eval_freq iterations
187
+ # below is equivalent to:
188
+ # b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0)
189
+ # but with more memory efficient
190
+ b_hat = torch.matmul(u, K) * v
191
+ err = (b - b_hat).pow(2).sum().item()
192
+ # err = (b - b_hat).abs().sum().item()
193
+ log['err'].append(err)
194
+
195
+ if verbose and it % print_freq == 0:
196
+ print('iteration {:5d}, constraint error {:5e}'.format(it, err))
197
+
198
+ it += 1
199
+
200
+ if log:
201
+ log['u'] = u
202
+ log['v'] = v
203
+ log['alpha'] = reg * torch.log(u + M_EPS)
204
+ log['beta'] = reg * torch.log(v + M_EPS)
205
+
206
+ # transport plan
207
+ P = u.reshape(-1, 1) * K * v.reshape(1, -1)
208
+ if log:
209
+ return P, log
210
+ else:
211
+ return P
212
+
213
+
214
+ def sinkhorn_stabilized(a, b, C, reg=1e-1, maxIter=1000, tau=1e3, stopThr=1e-9,
215
+ verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
216
+ """
217
+ Solve the entropic regularization OT problem with log stabilization
218
+ The function solves the following optimization problem:
219
+
220
+ .. math::
221
+ \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
222
+ s.t. \gamma 1 = a
223
+ \gamma^T 1= b
224
+ \gamma\geq 0
225
+ where :
226
+ - C is the (ns,nt) metric cost matrix
227
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
228
+ - a and b are target and source measures (sum to 1)
229
+
230
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]
231
+ but with the log stabilization proposed in [3] and defined in [2] (Algo 3.1)
232
+
233
+ Parameters
234
+ ----------
235
+ a : torch.tensor (na,)
236
+ samples measure in the target domain
237
+ b : torch.tensor (nb,)
238
+ samples in the source domain
239
+ C : torch.tensor (na,nb)
240
+ loss matrix
241
+ reg : float
242
+ Regularization term > 0
243
+ tau : float
244
+ threshold for max value in u or v for log scaling
245
+ maxIter : int, optional
246
+ Max number of iterations
247
+ stopThr : float, optional
248
+ Stop threshol on error ( > 0 )
249
+ verbose : bool, optional
250
+ Print information along iterations
251
+ log : bool, optional
252
+ record log if True
253
+
254
+ Returns
255
+ -------
256
+ gamma : (na x nb) torch.tensor
257
+ Optimal transportation matrix for the given parameters
258
+ log : dict
259
+ log dictionary return only if log==True in parameters
260
+
261
+ References
262
+ ----------
263
+ [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
264
+ [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019
265
+ [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
266
+
267
+ See Also
268
+ --------
269
+
270
+ """
271
+
272
+ device = a.device
273
+ na, nb = C.shape
274
+
275
+ assert na >= 1 and nb >= 1, 'C needs to be 2d'
276
+ assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match that of C"
277
+ assert reg > 0, 'reg should be greater than 0'
278
+ assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
279
+
280
+ if log:
281
+ log = {'err': []}
282
+
283
+ if warm_start is not None:
284
+ alpha = warm_start['alpha']
285
+ beta = warm_start['beta']
286
+ else:
287
+ alpha = torch.zeros(na, dtype=a.dtype).to(device)
288
+ beta = torch.zeros(nb, dtype=b.dtype).to(device)
289
+
290
+ u = torch.ones(na, dtype=a.dtype).to(device) / na
291
+ v = torch.ones(nb, dtype=b.dtype).to(device) / nb
292
+
293
+ def update_K(alpha, beta):
294
+ """log space computation"""
295
+ """memory efficient"""
296
+ torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=K)
297
+ torch.add(K, -C, out=K)
298
+ torch.div(K, reg, out=K)
299
+ torch.exp(K, out=K)
300
+
301
+ def update_P(alpha, beta, u, v, ab_updated=False):
302
+ """log space P (gamma) computation"""
303
+ torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=P)
304
+ torch.add(P, -C, out=P)
305
+ torch.div(P, reg, out=P)
306
+ if not ab_updated:
307
+ torch.add(P, torch.log(u + M_EPS).reshape(-1, 1), out=P)
308
+ torch.add(P, torch.log(v + M_EPS).reshape(1, -1), out=P)
309
+ torch.exp(P, out=P)
310
+
311
+ K = torch.empty(C.shape, dtype=C.dtype).to(device)
312
+ update_K(alpha, beta)
313
+
314
+ b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
315
+
316
+ it = 1
317
+ err = 1
318
+ ab_updated = False
319
+
320
+ # allocate memory beforehand
321
+ KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
322
+ Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
323
+ P = torch.empty(C.shape, dtype=C.dtype).to(device)
324
+
325
+ while (err > stopThr and it <= maxIter):
326
+ upre, vpre = u, v
327
+ torch.matmul(u, K, out=KTu)
328
+ v = torch.div(b, KTu + M_EPS)
329
+ torch.matmul(K, v, out=Kv)
330
+ u = torch.div(a, Kv + M_EPS)
331
+
332
+ ab_updated = False
333
+ # remove numerical problems and store them in K
334
+ if u.abs().sum() > tau or v.abs().sum() > tau:
335
+ alpha += reg * torch.log(u + M_EPS)
336
+ beta += reg * torch.log(v + M_EPS)
337
+ u.fill_(1. / na)
338
+ v.fill_(1. / nb)
339
+ update_K(alpha, beta)
340
+ ab_updated = True
341
+
342
+ if log and it % eval_freq == 0:
343
+ # we can speed up the process by checking for the error only all
344
+ # the eval_freq iterations
345
+ update_P(alpha, beta, u, v, ab_updated)
346
+ b_hat = torch.sum(P, 0)
347
+ err = (b - b_hat).pow(2).sum().item()
348
+ log['err'].append(err)
349
+
350
+ if verbose and it % print_freq == 0:
351
+ print('iteration {:5d}, constraint error {:5e}'.format(it, err))
352
+
353
+ it += 1
354
+
355
+ if log:
356
+ log['u'] = u
357
+ log['v'] = v
358
+ log['alpha'] = alpha + reg * torch.log(u + M_EPS)
359
+ log['beta'] = beta + reg * torch.log(v + M_EPS)
360
+
361
+ # transport plan
362
+ update_P(alpha, beta, u, v, False)
363
+
364
+ if log:
365
+ return P, log
366
+ else:
367
+ return P
368
+
369
+
370
+ def sinkhorn_epsilon_scaling(a, b, C, reg=1e-1, maxIter=100, maxInnerIter=100, tau=1e3, scaling_base=0.75,
371
+ scaling_coef=None, stopThr=1e-9, verbose=False, log=False, warm_start=None, eval_freq=10,
372
+ print_freq=200, **kwargs):
373
+ """
374
+ Solve the entropic regularization OT problem with log stabilization
375
+ The function solves the following optimization problem:
376
+
377
+ .. math::
378
+ \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
379
+ s.t. \gamma 1 = a
380
+ \gamma^T 1= b
381
+ \gamma\geq 0
382
+ where :
383
+ - C is the (ns,nt) metric cost matrix
384
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
385
+ - a and b are target and source measures (sum to 1)
386
+
387
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
388
+ scaling algorithm as proposed in [1] but with the log stabilization
389
+ proposed in [3] and the log scaling proposed in [2] algorithm 3.2
390
+
391
+ Parameters
392
+ ----------
393
+ a : torch.tensor (na,)
394
+ samples measure in the target domain
395
+ b : torch.tensor (nb,)
396
+ samples in the source domain
397
+ C : torch.tensor (na,nb)
398
+ loss matrix
399
+ reg : float
400
+ Regularization term > 0
401
+ tau : float
402
+ thershold for max value in u or v for log scaling
403
+ maxIter : int, optional
404
+ Max number of iterations
405
+ stopThr : float, optional
406
+ Stop threshol on error ( > 0 )
407
+ verbose : bool, optional
408
+ Print information along iterations
409
+ log : bool, optional
410
+ record log if True
411
+
412
+ Returns
413
+ -------
414
+ gamma : (na x nb) torch.tensor
415
+ Optimal transportation matrix for the given parameters
416
+ log : dict
417
+ log dictionary return only if log==True in parameters
418
+
419
+ References
420
+ ----------
421
+ [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
422
+ [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019
423
+ [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
424
+
425
+ See Also
426
+ --------
427
+
428
+ """
429
+
430
+ na, nb = C.shape
431
+
432
+ assert na >= 1 and nb >= 1, 'C needs to be 2d'
433
+ assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match that of C"
434
+ assert reg > 0, 'reg should be greater than 0'
435
+ assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
436
+
437
+ def get_reg(it, reg, pre_reg):
438
+ if it == 1:
439
+ return scaling_coef
440
+ else:
441
+ if (pre_reg - reg) * scaling_base < M_EPS:
442
+ return reg
443
+ else:
444
+ return (pre_reg - reg) * scaling_base + reg
445
+
446
+ if scaling_coef is None:
447
+ scaling_coef = C.max() + reg
448
+
449
+ it = 1
450
+ err = 1
451
+ running_reg = scaling_coef
452
+
453
+ if log:
454
+ log = {'err': []}
455
+
456
+ warm_start = None
457
+
458
+ while (err > stopThr and it <= maxIter):
459
+ running_reg = get_reg(it, reg, running_reg)
460
+ P, _log = sinkhorn_stabilized(a, b, C, running_reg, maxIter=maxInnerIter, tau=tau,
461
+ stopThr=stopThr, verbose=False, log=True,
462
+ warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
463
+ **kwargs)
464
+
465
+ warm_start = {}
466
+ warm_start['alpha'] = _log['alpha']
467
+ warm_start['beta'] = _log['beta']
468
+
469
+ primal_val = (C * P).sum() + reg * (P * torch.log(P)).sum() - reg * P.sum()
470
+ dual_val = (_log['alpha'] * a).sum() + (_log['beta'] * b).sum() - reg * P.sum()
471
+ err = primal_val - dual_val
472
+ log['err'].append(err)
473
+
474
+ if verbose and it % print_freq == 0:
475
+ print('iteration {:5d}, constraint error {:5e}'.format(it, err))
476
+
477
+ it += 1
478
+
479
+ if log:
480
+ log['alpha'] = _log['alpha']
481
+ log['beta'] = _log['beta']
482
+ return P, log
483
+ else:
484
+ return P
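A tiny usage sketch for the `sinkhorn` solver in `losses/bregman_pytorch.py`: the returned plan `P` should (approximately) have `a` and `b` as its row and column marginals. All values below are made up, and the snippet assumes it is run from the repository root:

```python
# Tiny made-up problem for the sinkhorn solver above.
import torch
from losses.bregman_pytorch import sinkhorn

a = torch.tensor([0.5, 0.5])          # marginal over the 2 "target" bins (sums to 1)
b = torch.tensor([0.25, 0.25, 0.5])   # marginal over the 3 "source" bins (sums to 1)
C = torch.rand(2, 3)                  # cost matrix, shape (na, nb)

P, log = sinkhorn(a, b, C, reg=0.1, maxIter=1000, log=True)

print(P.shape)        # torch.Size([2, 3])
print(P.sum(dim=1))   # ~ a
print(P.sum(dim=0))   # ~ b
print((P * C).sum())  # transport cost under the regularized plan
```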
losses/consistency_loss.py ADDED
@@ -0,0 +1,294 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.nn as nn
5
+ from losses import ramps
6
+
7
+
8
+
9
+ class consistency_weight(object):
10
+ """
11
+ ramp_types = ['sigmoid_rampup', 'linear_rampup', 'cosine_rampup', 'log_rampup', 'exp_rampup']
12
+ """
13
+ def __init__(self, final_w, iters_per_epoch, rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup'):
14
+ self.final_w = final_w
15
+ self.iters_per_epoch = iters_per_epoch
16
+ self.rampup_starts = rampup_starts * iters_per_epoch
17
+ self.rampup_ends = rampup_ends * iters_per_epoch
18
+ self.rampup_length = (self.rampup_ends - self.rampup_starts)
19
+ self.rampup_func = getattr(ramps, ramp_type)
20
+ self.current_rampup = 0
21
+
22
+ def __call__(self, epoch, curr_iter):
23
+ cur_total_iter = self.iters_per_epoch * epoch + curr_iter
24
+ if cur_total_iter < self.rampup_starts:
25
+ return 0
26
+ self.current_rampup = self.rampup_func(cur_total_iter - self.rampup_starts, self.rampup_length)
27
+ return self.final_w * self.current_rampup
28
+
29
+
30
+ def CE_loss(input_logits, target_targets, ignore_index, temperature=1):
31
+ return F.cross_entropy(input_logits/temperature, target_targets, ignore_index=ignore_index)
32
+
33
+ # for FocalLoss
34
+ def softmax_helper(x):
35
+ # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
36
+ rpt = [1 for _ in range(len(x.size()))]
37
+ rpt[1] = x.size(1)
38
+ x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
39
+ e_x = torch.exp(x - x_max)
40
+ return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
41
+
42
+ def get_alpha(supervised_loader):
43
+ # get number of classes
44
+ num_labels = 0
45
+ for image_batch, label_batch in supervised_loader:
46
+ label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
47
+ l_unique = torch.unique(label_batch.data)
48
+ list_unique = [element.item() for element in l_unique.flatten()]
49
+ num_labels = max(max(list_unique),num_labels)
50
+ num_classes = num_labels + 1
51
+ # count class occurrences
52
+ alpha = [0 for i in range(num_classes)]
53
+ for image_batch, label_batch in supervised_loader:
54
+ label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
55
+ l_unique = torch.unique(label_batch.data)
56
+ list_unique = [element.item() for element in l_unique.flatten()]
57
+ l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480])
58
+ list_count = [count.item() for count in l_unique_count.flatten()]
59
+ for index in list_unique:
60
+ alpha[index] += list_count[list_unique.index(index)]
61
+ return alpha
62
+
63
+ # for FocalLoss
64
+ def softmax_helper(x):
65
+ # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
66
+ rpt = [1 for _ in range(len(x.size()))]
67
+ rpt[1] = x.size(1)
68
+ x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
69
+ e_x = torch.exp(x - x_max)
70
+ return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
71
+
72
+
73
+ class FocalLoss(nn.Module):
74
+ """
75
+ copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
76
+ This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
77
+ 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
78
+ Focal_Loss= -1*alpha*(1-pt)*log(pt)
79
+ :param num_class:
80
+ :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
81
+ :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
82
+ focus on hard misclassified example
83
+ :param smooth: (float,double) smooth value when cross entropy
84
+ :param balance_index: (int) balance class index, should be specific when alpha is float
85
+ :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
86
+ """
87
+
88
+ def __init__(self, apply_nonlin=None, ignore_index = None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
89
+ super(FocalLoss, self).__init__()
90
+ self.apply_nonlin = apply_nonlin
+ self.ignore_index = ignore_index
91
+ self.alpha = alpha
92
+ self.gamma = gamma
93
+ self.balance_index = balance_index
94
+ self.smooth = smooth
95
+ self.size_average = size_average
96
+
97
+ if self.smooth is not None:
98
+ if self.smooth < 0 or self.smooth > 1.0:
99
+ raise ValueError('smooth value should be in [0,1]')
100
+
101
+ def forward(self, logit, target):
102
+ if self.apply_nonlin is not None:
103
+ logit = self.apply_nonlin(logit)
104
+ num_class = logit.shape[1]
105
+
106
+ if logit.dim() > 2:
107
+ # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
108
+ logit = logit.view(logit.size(0), logit.size(1), -1)
109
+ logit = logit.permute(0, 2, 1).contiguous()
110
+ logit = logit.view(-1, logit.size(-1))
111
+ target = torch.squeeze(target, 1)
112
+ target = target.view(-1, 1)
113
+
114
+ valid_mask = None
115
+ if self.ignore_index is not None:
116
+ valid_mask = target != self.ignore_index
117
+ target = target * valid_mask
118
+
119
+ alpha = self.alpha
120
+
121
+ if alpha is None:
122
+ alpha = torch.ones(num_class, 1)
123
+ elif isinstance(alpha, (list, np.ndarray)):
124
+ assert len(alpha) == num_class
125
+ alpha = torch.FloatTensor(alpha).view(num_class, 1)
126
+ alpha = alpha / alpha.sum()
127
+ alpha = 1/alpha # inverse of class frequency
128
+ elif isinstance(alpha, float):
129
+ alpha = torch.ones(num_class, 1)
130
+ alpha = alpha * (1 - self.alpha)
131
+ alpha[self.balance_index] = self.alpha
132
+
133
+ else:
134
+ raise TypeError('Not support alpha type')
135
+
136
+ if alpha.device != logit.device:
137
+ alpha = alpha.to(logit.device)
138
+
139
+ idx = target.cpu().long()
140
+
141
+ one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
142
+
143
+ # to resolve error in idx in scatter_
144
+ idx[idx==225]=0
145
+
146
+ one_hot_key = one_hot_key.scatter_(1, idx, 1)
147
+ if one_hot_key.device != logit.device:
148
+ one_hot_key = one_hot_key.to(logit.device)
149
+
150
+ if self.smooth:
151
+ one_hot_key = torch.clamp(
152
+ one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
153
+ pt = (one_hot_key * logit).sum(1) + self.smooth
154
+ logpt = pt.log()
155
+
156
+ gamma = self.gamma
157
+
158
+ alpha = alpha[idx]
159
+ alpha = torch.squeeze(alpha)
160
+ loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
161
+
162
+ if valid_mask is not None:
163
+ loss = loss * valid_mask.squeeze()
164
+
165
+ if self.size_average:
166
+ loss = loss.mean()
167
+ else:
168
+ loss = loss.sum()
169
+ return loss
170
+
171
+
172
+ class abCE_loss(nn.Module):
173
+ """
174
+ Annealed-Bootstrapped cross-entropy loss
175
+ """
176
+ def __init__(self, iters_per_epoch, epochs, num_classes, weight=None,
177
+ reduction='mean', thresh=0.7, min_kept=1, ramp_type='log_rampup'):
178
+ super(abCE_loss, self).__init__()
179
+ self.weight = torch.FloatTensor(weight) if weight is not None else weight
180
+ self.reduction = reduction
181
+ self.thresh = thresh
182
+ self.min_kept = min_kept
183
+ self.ramp_type = ramp_type
184
+
185
+ if ramp_type is not None:
186
+ self.rampup_func = getattr(ramps, ramp_type)
187
+ self.iters_per_epoch = iters_per_epoch
188
+ self.num_classes = num_classes
189
+ self.start = 1/num_classes
190
+ self.end = 0.9
191
+ self.total_num_iters = (epochs - (0.6 * epochs)) * iters_per_epoch
192
+
193
+ def threshold(self, curr_iter, epoch):
194
+ cur_total_iter = self.iters_per_epoch * epoch + curr_iter
195
+ current_rampup = self.rampup_func(cur_total_iter, self.total_num_iters)
196
+ return current_rampup * (self.end - self.start) + self.start
197
+
198
+ def forward(self, predict, target, ignore_index, curr_iter, epoch):
199
+ batch_kept = self.min_kept * target.size(0)
200
+ prob_out = F.softmax(predict, dim=1)
201
+ tmp_target = target.clone()
202
+ tmp_target[tmp_target == ignore_index] = 0
203
+ prob = prob_out.gather(1, tmp_target.unsqueeze(1))
204
+ mask = target.contiguous().view(-1, ) != ignore_index
205
+ sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort()
206
+
207
+ if self.ramp_type is not None:
208
+ thresh = self.threshold(curr_iter=curr_iter, epoch=epoch)
209
+ else:
210
+ thresh = self.thresh
211
+
212
+ min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0
213
+ threshold = max(min_threshold, thresh)
214
+ loss_matrix = F.cross_entropy(predict, target,
215
+ weight=self.weight.to(predict.device) if self.weight is not None else None,
216
+ ignore_index=ignore_index, reduction='none')
217
+ loss_matirx = loss_matrix.contiguous().view(-1, )
218
+ sort_loss_matirx = loss_matirx[mask][sort_indices]
219
+ select_loss_matrix = sort_loss_matirx[sort_prob < threshold]
220
+ if self.reduction == 'sum' or select_loss_matrix.numel() == 0:
221
+ return select_loss_matrix.sum()
222
+ elif self.reduction == 'mean':
223
+ return select_loss_matrix.mean()
224
+ else:
225
+ raise NotImplementedError('Reduction Error!')
226
+
227
+
228
+
229
+ def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
230
+ assert inputs.requires_grad == True and targets.requires_grad == False
231
+ assert inputs.size() == targets.size() # (batch_size * num_classes * H * W)
232
+ inputs = F.softmax(inputs, dim=1)
233
+ if use_softmax:
234
+ targets = F.softmax(targets, dim=1)
235
+
236
+ if conf_mask:
237
+ loss_mat = F.mse_loss(inputs, targets, reduction='none')
238
+ mask = (targets.max(1)[0] > threshold)
239
+ loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
240
+ if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
241
+ return loss_mat.mean()
242
+ else:
243
+ return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size
244
+
245
+
246
+ def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
247
+ assert inputs.requires_grad == True and targets.requires_grad == False
248
+ assert inputs.size() == targets.size()
249
+
250
+ input_log_softmax = F.log_softmax(inputs, dim=1)
+ if use_softmax:
251
+ targets = F.softmax(targets, dim=1)
252
+ if conf_mask:
253
+ loss_mat = F.kl_div(input_log_softmax, targets, reduction='none')
254
+ mask = (targets.max(1)[0] > threshold)
255
+ loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
256
+ if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
257
+ return loss_mat.sum() / mask.shape.numel()
258
+ else:
259
+ return F.kl_div(input_log_softmax, targets, reduction='mean')
260
+
261
+
262
+ def softmax_js_loss(inputs, targets, **_):
263
+ assert inputs.requires_grad == True and targets.requires_grad == False
264
+ assert inputs.size() == targets.size()
265
+ epsilon = 1e-5
266
+
267
+ M = (F.softmax(inputs, dim=1) + targets) * 0.5
268
+ kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean')
269
+ kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean')
270
+ return (kl1 + kl2) * 0.5
271
+
272
+
273
+
274
+ def pair_wise_loss(unsup_outputs, size_average=True, nbr_of_pairs=8):
275
+ """
276
+ Pair-wise loss in the sup. mat.
277
+ """
278
+ if isinstance(unsup_outputs, list):
279
+ unsup_outputs = torch.stack(unsup_outputs)
280
+
281
+ # Only for a subset of the aux outputs to reduce computation and memory
282
+ unsup_outputs = unsup_outputs[torch.randperm(unsup_outputs.size(0))]
283
+ unsup_outputs = unsup_outputs[:nbr_of_pairs]
284
+
285
+ temp = torch.zeros_like(unsup_outputs) # For grad purposes
286
+ for i, u in enumerate(unsup_outputs):
287
+ temp[i] = F.softmax(u, dim=1)
288
+ mean_prediction = temp.mean(0).unsqueeze(0) # Mean over the auxiliary outputs
289
+ pw_loss = ((temp - mean_prediction)**2).mean(0) # Variance
290
+ pw_loss = pw_loss.sum(1) # Sum over classes
291
+ if size_average:
292
+ return pw_loss.mean()
293
+ return pw_loss.sum()
294
+
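A short sketch of how `consistency_weight` from `losses/consistency_loss.py` ramps the unsupervised loss weight from roughly 0 up to `final_w` between `rampup_starts` and `rampup_ends` epochs; the numbers below are made up, and the snippet assumes it is run from the repository root:

```python
# Made-up schedule for the consistency ramp-up.
from losses.consistency_loss import consistency_weight

cons_w = consistency_weight(final_w=1.0, iters_per_epoch=100,
                            rampup_starts=0, rampup_ends=7,
                            ramp_type='sigmoid_rampup')

for epoch in (0, 3, 7, 10):
    print(epoch, cons_w(epoch, curr_iter=0))
# the weight grows from ~0.007 towards final_w=1.0 and stays there after rampup_ends
```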
losses/dm_loss.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from losses.consistency_loss import *
5
+ from losses.ot_loss import OT_Loss
6
+
7
+ class DMLoss(nn.Module):
8
+ def __init__(self):
9
+ super(DMLoss, self).__init__()
10
+ self.DMLoss = 0.0
11
+ self.losses = {}
12
+
13
+ def forward(self, results, points, gt_discrete):
14
+ self.DMLoss = 0.0
15
+ self.losses = {}
16
+
17
+ if results is None:
18
+ self.DMLoss = 0.0
19
+ elif isinstance(results, list) and len(results) > 0:
20
+ count = 0
21
+ for i in range(len(results[0])):
22
+ with torch.set_grad_enabled(False):
23
+ preds_mean = (results[0][i])/len(results[0][0][0])
24
+
25
+ for j in range(len(results)):
26
+ var_sel = softmax_kl_loss(results[j][i], preds_mean)
27
+ exp_var = torch.exp(-var_sel)
28
+ consistency_dist = (preds_mean - results[j][i]) ** 2
29
+ temploss = (torch.mean(consistency_dist * exp_var) /(exp_var + 1e-8) + var_sel)
30
+
31
+ self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
32
+ self.DMLoss += temploss
33
+
34
+ # Compute counting loss.
35
+ count_loss = self.mae(outputs_L[0].sum(1).sum(1).sum(1),
36
+ torch.from_numpy(gd_count).float().to(self.device))*self.args.reg
37
+ epoch_count_loss.update(count_loss.item(), N)
38
+
39
+ # Compute OT loss.
40
+ ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs_L[0], points)
41
+
42
+ ot_loss = ot_loss * self.args.ot
43
+ ot_obj_value = ot_obj_value * self.args.ot
44
+ epoch_ot_loss.update(ot_loss.item(), N)
45
+ epoch_ot_obj_value.update(ot_obj_value.item(), N)
46
+ epoch_wd.update(wd, N)
47
+
48
+ gd_count_tensor = (torch.from_numpy(gd_count).float()
49
+ .to(self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3))
50
+
51
+ gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
52
+ tv_loss = (self.tvloss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)*
53
+ torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.tv
54
+ epoch_tv_loss.update(tv_loss.item(), N)
55
+
56
+ count += 1
57
+ if count > 0:
58
+ self.DMLoss = self.DMLoss / count
59
+
60
+
61
+ return self.DMLoss
62
+
losses/multi_con_loss.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from losses.consistency_loss import *
5
+
6
+
7
+ class MultiConLoss(nn.Module):
8
+ def __init__(self):
9
+ super(MultiConLoss, self).__init__()
10
+ self.countloss_criterion = nn.MSELoss(reduction='sum')
11
+ self.multiconloss = 0.0
12
+ self.losses = {}
13
+
14
+ def forward(self, unlabeled_results):
15
+ self.multiconloss = 0.0
16
+ self.losses = {}
17
+
18
+ if unlabeled_results is None:
19
+ self.multiconloss = 0.0
20
+ elif isinstance(unlabeled_results, list) and len(unlabeled_results) > 0:
21
+ count = 0
22
+ for i in range(len(unlabeled_results[0])):
23
+ with torch.set_grad_enabled(False):
24
+ preds_mean = (unlabeled_results[0][i] + unlabeled_results[1][i] + unlabeled_results[2][i])/len(unlabeled_results)
25
+ for j in range(len(unlabeled_results)):
26
+
27
+ var_sel = softmax_kl_loss(unlabeled_results[j][i], preds_mean)
28
+ exp_var = torch.exp(-var_sel)
29
+ consistency_dist = (preds_mean - unlabeled_results[j][i]) ** 2
30
+ temploss = (torch.mean(consistency_dist * exp_var) /(exp_var + 1e-8) + var_sel)
31
+
32
+ self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
33
+ self.multiconloss += temploss
34
+
35
+ count += 1
36
+ if count > 0:
37
+ self.multiconloss = self.multiconloss / count
38
+
39
+
40
+ return self.multiconloss
41
+
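A usage sketch for `MultiConLoss` from `losses/multi_con_loss.py`, which averages the predictions of three decoder branches and penalises each branch's deviation from that mean. The tensors below are random placeholders; they are created with `requires_grad=True` because `softmax_kl_loss` asserts that the inputs require gradients. Run from the repository root:

```python
# Random placeholder predictions for MultiConLoss.
import torch
from losses.multi_con_loss import MultiConLoss

con_loss = MultiConLoss()

# 3 decoder branches, each with 2 prediction maps of shape (B, 1, 32, 32).
unlabeled_results = [[torch.rand(2, 1, 32, 32, requires_grad=True) for _ in range(2)]
                     for _ in range(3)]

print(con_loss(unlabeled_results).item())
```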
losses/ot_loss.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ from torch.nn import Module
3
+ from .bregman_pytorch import sinkhorn
4
+
5
+ class OT_Loss(Module):
6
+ def __init__(self, c_size, stride, norm_cood, device, num_of_iter_in_ot=100, reg=10.0):
7
+ super(OT_Loss, self).__init__()
8
+ assert c_size % stride == 0
9
+
10
+ self.c_size = c_size
11
+ self.device = device
12
+ self.norm_cood = norm_cood
13
+ self.num_of_iter_in_ot = num_of_iter_in_ot
14
+ self.reg = reg
15
+
16
+ # coordinate is same to image space, set to constant since crop size is same
17
+ self.cood = torch.arange(0, c_size, step=stride,
18
+ dtype=torch.float32, device=device) + stride / 2
19
+ self.density_size = self.cood.size(0)
20
+ self.cood.unsqueeze_(0) # [1, #cood]
21
+ if self.norm_cood:
22
+ self.cood = self.cood / c_size * 2 - 1 # map to [-1, 1]
23
+ self.output_size = self.cood.size(1)
24
+
25
+
26
+ def forward(self, normed_density, unnormed_density, points):
27
+ batch_size = normed_density.size(0)
28
+ assert len(points) == batch_size
29
+ assert self.output_size == normed_density.size(2)
30
+ loss = torch.zeros([1]).to(self.device)
31
+ ot_obj_values = torch.zeros([1]).to(self.device)
32
+ wd = 0 # wasserstein distance
33
+ for idx, im_points in enumerate(points):
34
+ if len(im_points) > 0:
35
+ # compute l2 square distance, it should be source target distance. [#gt, #cood * #cood]
36
+ if self.norm_cood:
37
+ im_points = im_points / self.c_size * 2 - 1 # map to [-1, 1]
38
+ x = im_points[:, 0].unsqueeze_(1) # [#gt, 1]
39
+ y = im_points[:, 1].unsqueeze_(1)
40
+ x_dis = -2 * torch.matmul(x, self.cood) + x * x + self.cood * self.cood # [#gt, #cood]
41
+ y_dis = -2 * torch.matmul(y, self.cood) + y * y + self.cood * self.cood
42
+ y_dis.unsqueeze_(2)
43
+ x_dis.unsqueeze_(1)
44
+ dis = y_dis + x_dis
45
+ dis = dis.view((dis.size(0), -1)) # size of [#gt, #cood * #cood]
46
+
47
+ source_prob = normed_density[idx][0].view([-1]).detach()
48
+ target_prob = (torch.ones([len(im_points)]) / len(im_points)).to(self.device)
49
+
50
+ # use sinkhorn to solve OT, compute optimal beta.
51
+ P, log = sinkhorn(target_prob, source_prob, dis, self.reg, maxIter=self.num_of_iter_in_ot, log=True)
52
+ beta = log['beta'] # size is the same as source_prob: [#cood * #cood]
53
+ ot_obj_values += torch.sum(normed_density[idx] * beta.view([1, self.output_size, self.output_size]))
54
+ # compute the gradient of OT loss to predicted density (unnormed_density).
55
+ # im_grad = beta / source_count - < beta, source_density> / (source_count)^2
56
+ source_density = unnormed_density[idx][0].view([-1]).detach()
57
+ source_count = source_density.sum()
58
+ im_grad_1 = (source_count) / (source_count * source_count+1e-8) * beta # size of [#cood * #cood]
59
+ im_grad_2 = (source_density * beta).sum() / (source_count * source_count + 1e-8) # size of 1
60
+ im_grad = im_grad_1 - im_grad_2
61
+ im_grad = im_grad.detach().view([1, self.output_size, self.output_size])
62
+ # Define loss = <im_grad, predicted density>. The gradient of loss w.r.t. predicted density is im_grad.
63
+ loss += torch.sum(unnormed_density[idx] * im_grad)
64
+ wd += torch.sum(dis * P).item()
65
+
66
+ return loss, wd, ot_obj_values
67
+
68
+
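A usage sketch for `OT_Loss` from `losses/ot_loss.py` on a single crop; the crop size, stride, and point coordinates below are made up, and the snippet assumes it is run from the repository root:

```python
# Made-up crop, stride, and annotations for OT_Loss.
import torch
from losses.ot_loss import OT_Loss

device = torch.device('cpu')
crop_size, stride = 256, 8          # -> 32x32 density map for a 256x256 crop
ot = OT_Loss(crop_size, stride, norm_cood=False, device=device,
             num_of_iter_in_ot=100, reg=10.0)

pred = torch.rand(1, 1, crop_size // stride, crop_size // stride)  # unnormalised density
pred_normed = pred / (pred.sum() + 1e-6)                           # normalised to sum to 1
points = [torch.tensor([[30.0, 40.0], [200.0, 100.0]])]            # (x, y) per image

ot_loss, wd, ot_obj = ot(pred_normed, pred, points)
print(ot_loss.item(), wd, ot_obj.item())
```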
losses/ramps.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Functions for ramping hyperparameters up or down
9
+
10
+ Each function takes the current training step or epoch, and the
11
+ ramp length in the same format, and returns a multiplier between
12
+ 0 and 1.
13
+ """
14
+
15
+
16
+ import numpy as np
17
+
18
+
19
+ def sigmoid_rampup(current, rampup_length):
20
+ """Exponential rampup from https://arxiv.org/abs/1610.02242"""
21
+ if rampup_length == 0:
22
+ return 1.0
23
+ else:
24
+ current = np.clip(current, 0.0, rampup_length)
25
+ phase = 1.0 - current / rampup_length
26
+ return float(np.exp(-5.0 * phase * phase))
27
+
28
+
29
+ def linear_rampup(current, rampup_length):
30
+ """Linear rampup"""
31
+ assert current >= 0 and rampup_length >= 0
32
+ if current >= rampup_length:
33
+ return 1.0
34
+ else:
35
+ return current / rampup_length
36
+
37
+
38
+ def cosine_rampdown(current, rampdown_length):
39
+ """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
40
+ assert 0 <= current <= rampdown_length
41
+ return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
losses/rank_loss.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class MarginRankLoss(nn.Module):
+     def __init__(self):
+         super(MarginRankLoss, self).__init__()
+         self.loss = 0.0
+
+     def forward(self, img_list, margin=0):
+         length = len(img_list)
+         self.loss = 0.0
+         B, C, H, W = img_list[0].shape
+         for i in range(length - 1):
+             for j in range(i + 1, length):
+                 self.loss = self.loss + torch.sum(F.relu(img_list[j].sum(-1).sum(-1).sum(-1) - img_list[i].sum(-1).sum(-1).sum(-1) + margin))
+
+         self.loss = self.loss / (B*length*(length-1)/2)
+         return self.loss
+
+
+ class RankLoss(nn.Module):
+     def __init__(self):
+         super(RankLoss, self).__init__()
+         self.countloss_criterion = nn.MSELoss(reduction='sum')
+         self.rankloss_criterion = MarginRankLoss()
+         self.rankloss = 0.0
+         self.losses = {}
+
+     def forward(self, unlabeled_results):
+         self.rankloss = 0.0
+         self.losses = {}
+
+
+         if unlabeled_results is None:
+             self.rankloss = 0.0
+         elif isinstance(unlabeled_results, tuple) and len(unlabeled_results) > 0:
+             self.rankloss = self.rankloss_criterion(unlabeled_results)
+         elif isinstance(unlabeled_results, list) and len(unlabeled_results) > 0:
+             count = 0
+             for i in range(len(unlabeled_results)):
+                 if isinstance(unlabeled_results[i], tuple) and len(unlabeled_results[i]) > 0:
+                     temploss = self.rankloss_criterion(unlabeled_results[i])
+                     self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
+                     self.rankloss += temploss
+
+                     count += 1
+             if count > 0:
+                 self.rankloss = self.rankloss / count
+
+         return self.rankloss
+
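From the pairwise hinge in `MarginRankLoss.forward`, inputs are expected in order of decreasing spatial extent: the hinge activates whenever a later (smaller-crop) map's total plus the margin exceeds an earlier (larger-crop) map's total. A minimal self-contained check (tensor shapes and values are illustrative only):

```python
import torch
from losses.rank_loss import MarginRankLoss

# Three non-negative "density maps" for nested regions, ordered from the
# full map down to its center crops, so the totals are non-increasing.
full   = torch.rand(2, 1, 64, 64)
inner  = full[:, :, 8:56, 8:56].clone()
inner2 = full[:, :, 16:48, 16:48].clone()

criterion = MarginRankLoss()
loss = criterion([full, inner, inner2], margin=0)
print(loss.item())  # 0.0 here, since the containment ordering is respected
```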
network/pvt_cls.py ADDED
@@ -0,0 +1,623 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+ from timm.models.registry import register_model
8
+ from timm.models.vision_transformer import _cfg
9
+
10
+ import math
11
+ from torch.distributions.uniform import Uniform
12
+ import numpy as np
13
+ import random
14
+
15
+ __all__ = [
16
+ 'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large'
17
+ ]
18
+
19
+
20
+ class SELayer(nn.Module):
21
+ def __init__(self, channel, reduction=16):
22
+ super(SELayer, self).__init__()
23
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
24
+ self.fc = nn.Sequential(
25
+ nn.Linear(channel, channel // reduction, bias=False),
26
+ nn.ReLU(inplace=True),
27
+ nn.Linear(channel // reduction, channel, bias=False),
28
+ nn.Sigmoid()
29
+ )
30
+
31
+ def forward(self, x):
32
+ b, c, _, _ = x.size()
33
+ y = self.avg_pool(x).view(b, c)
34
+ y = self.fc(y).view(b, c, 1, 1)
35
+ return x * y.expand_as(x)
36
+
37
+
38
+ class Regression(nn.Module):
39
+ def __init__(self):
40
+ super(Regression, self).__init__()
41
+
42
+ self.v1 = nn.Sequential(
43
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
44
+ nn.Conv2d(256, 128, 3, padding=1, dilation=1),
45
+ nn.BatchNorm2d(128), nn.ReLU(inplace=True))
46
+
47
+ self.v2 = nn.Sequential(
48
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
49
+ nn.Conv2d(512, 256, 3, padding=1, dilation=1),
50
+ nn.BatchNorm2d(256), nn.ReLU(inplace=True))
51
+
52
+ self.v3 = nn.Sequential(
53
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
54
+ nn.Conv2d(1024, 512, 3, padding=1, dilation=1), nn.BatchNorm2d(512),
55
+ nn.ReLU(inplace=True))
56
+
57
+ self.ca2 = nn.Sequential(ChannelAttention(512),
58
+ nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1 ),
59
+ nn.BatchNorm2d(512), nn.ReLU(inplace=True))
60
+
61
+ self.ca1 = nn.Sequential(ChannelAttention(256),
62
+ nn.Conv2d(256, 256, kernel_size = 3, stride = 1, padding = 1 ),
63
+ nn.BatchNorm2d(256), nn.ReLU(inplace=True))
64
+
65
+ self.ca0 = nn.Sequential(ChannelAttention(128),
66
+ nn.Conv2d(128, 128, kernel_size = 3, stride = 1, padding = 1 ),
67
+ nn.BatchNorm2d(128), nn.ReLU(inplace=True))
68
+
69
+ self.res2 = nn.Sequential(
70
+ nn.Conv2d(512, 256, 3, padding=1, dilation=1), nn.BatchNorm2d(256),
71
+ nn.ReLU(inplace=True),
72
+ nn.Conv2d(256, 128, 3, padding=1, dilation=1), nn.BatchNorm2d(128),
73
+ nn.ReLU(inplace=True),
74
+ nn.Conv2d(128, 1, 3, padding=1, dilation=1),
75
+ nn.ReLU(inplace=True))
76
+
77
+ self.res1 = nn.Sequential(
78
+ nn.Conv2d(256, 128, 3, padding=1, dilation=1), nn.BatchNorm2d(128),
79
+ nn.ReLU(inplace=True),
80
+ nn.Conv2d(128, 64, 3, padding=1, dilation=1), nn.BatchNorm2d(64),
81
+ nn.ReLU(inplace=True),
82
+ nn.Conv2d(64, 1, 3, padding=1, dilation=1),
83
+ nn.ReLU(inplace=True))
84
+
85
+ self.res0 = nn.Sequential(
86
+ nn.Conv2d(128, 64, 3, padding=1, dilation=1), nn.BatchNorm2d(64),
87
+ nn.ReLU(inplace=True),
88
+ nn.Conv2d(64, 1, 3, padding=1, dilation=1),
89
+ nn.ReLU(inplace=True))
90
+
91
+ self.noise2 = DropOutDecoder(1, 512, 512)
92
+ self.noise1 = FeatureDropDecoder(1, 256, 256)
93
+ self.noise0 = FeatureNoiseDecoder(1, 128, 128)
94
+
95
+ self.upsam2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
96
+ self.upsam4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
97
+
98
+ self.conv1 = nn.Conv2d(1024, 512, kernel_size=1, bias=False)
99
+ self.conv2 = nn.Conv2d(512, 256, kernel_size=1, bias=False)
100
+ self.conv3 = nn.Conv2d(256, 128, kernel_size=1, bias=False)
101
+ self.conv4 = nn.Conv2d(128, 1, kernel_size=1, bias=False)
102
+
103
+ #cls2.view(8, 1024, 1, 1))
104
+
105
+ self.init_param()
106
+
107
+ def forward(self, x, cls):
108
+ x0 = x[0]; x1 = x[1]; x2 = x[2]; x3 = x[3]
109
+ cls0 = cls[0].view(cls[0].shape[0], cls[0].shape[1], 1, 1)
110
+ cls1 = cls[1].view(cls[1].shape[0], cls[1].shape[1], 1, 1)
111
+ cls2 = cls[2].view(cls[2].shape[0], cls[2].shape[1], 1, 1)
112
+
113
+ x2_1 = self.ca2(x2)+self.v3(x3)
114
+ x1_1 = self.ca1(x1)+self.v2(x2_1)
115
+ x0_1 = self.ca0(x0)+self.v1(x1_1)
116
+
117
+ if self.training:
118
+ yc2 = self.conv4(self.conv3(self.conv2(self.noise2(self.conv1(cls2))))).squeeze()
119
+ yc1 = self.conv4(self.conv3(self.noise1(self.conv2(cls1)))).squeeze()
120
+ yc0 = self.conv4(self.noise0(self.conv3(cls0))).squeeze()
121
+
122
+ y2 = self.res2(self.upsam4(self.noise2(x2_1)))
123
+ y1 = self.res1(self.upsam2(self.noise1(x1_1)))
124
+ y0 = self.res0(self.noise0(x0_1))
125
+
126
+ else:
127
+ yc2 = self.conv4(self.conv3(self.conv2(self.conv1(cls2)))).squeeze()
128
+ yc1 = self.conv4(self.conv3(self.conv2(cls1))).squeeze()
129
+ yc0 = self.conv4(self.conv3(cls0)).squeeze()
130
+
131
+ y2 = self.res2(self.upsam4(x2_1))
132
+ y1 = self.res1(self.upsam2(x1_1))
133
+ y0 = self.res0(x0_1)
134
+
135
+ return [y0, y1, y2], [yc0, yc1, yc2]
136
+
137
+ def init_param(self):
138
+ for m in self.modules():
139
+ if isinstance(m, nn.Conv2d):
140
+ nn.init.normal_(m.weight, std=0.01)
141
+ if m.bias is not None:
142
+ nn.init.constant_(m.bias, 0)
143
+ elif isinstance(m, nn.BatchNorm2d):
144
+ nn.init.constant_(m.weight, 1)
145
+ nn.init.constant_(m.bias, 0)
146
+
147
+
148
+ class Mlp(nn.Module):
149
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
150
+ super().__init__()
151
+ out_features = out_features or in_features
152
+ hidden_features = hidden_features or in_features
153
+ self.fc1 = nn.Linear(in_features, hidden_features)
154
+ self.act = act_layer()
155
+ self.fc2 = nn.Linear(hidden_features, out_features)
156
+ self.drop = nn.Dropout(drop)
157
+
158
+ def forward(self, x):
159
+ x = self.fc1(x)
160
+ x = self.act(x)
161
+ x = self.drop(x)
162
+ x = self.fc2(x)
163
+ x = self.drop(x)
164
+ return x
165
+
166
+
167
+
168
+ def upsample(in_channels, out_channels, upscale, kernel_size=3):
169
+ # A series of x2 upsampling steps until we reach the requested upscale factor
170
+ layers = []
171
+ conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
172
+ nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu')
173
+ layers.append(conv1x1)
174
+ for i in range(int(math.log(upscale, 2))):
175
+ layers.append(PixelShuffle(out_channels, scale=2))
176
+ return nn.Sequential(*layers)
177
+
178
+
179
+
180
+ class FeatureDropDecoder(nn.Module):
181
+ def __init__(self, upscale, conv_in_ch, num_classes):
182
+ super(FeatureDropDecoder, self).__init__()
183
+ self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
184
+
185
+ def feature_dropout(self, x):
186
+ attention = torch.mean(x, dim=1, keepdim=True)
187
+ max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)
188
+ threshold = max_val * np.random.uniform(0.7, 0.9)
189
+ threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
190
+ drop_mask = (attention < threshold).float()
191
+ return x.mul(drop_mask)
192
+
193
+ def forward(self, x):
194
+ x = self.feature_dropout(x)
195
+ return x
196
+
197
+
198
+ class FeatureNoiseDecoder(nn.Module):
199
+ def __init__(self, upscale, conv_in_ch, num_classes, uniform_range=0.3):
200
+ super(FeatureNoiseDecoder, self).__init__()
201
+ self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
202
+ self.uni_dist = Uniform(-uniform_range, uniform_range)
203
+
204
+ def feature_based_noise(self, x):
205
+ noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)
206
+ x_noise = x.mul(noise_vector) + x
207
+ return x_noise
208
+
209
+ def forward(self, x):
210
+ x = self.feature_based_noise(x)
211
+ return x
212
+
213
+ class DropOutDecoder(nn.Module):
214
+ def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True):
215
+ super(DropOutDecoder, self).__init__()
216
+ self.dropout = nn.Dropout2d(p=drop_rate) if spatial_dropout else nn.Dropout(drop_rate)
217
+ self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
218
+
219
+ def forward(self, x):
220
+ x = self.dropout(x)
221
+ return x
222
+
223
+
224
+ ## ChannelAttention
225
+ class ChannelAttention(nn.Module):
226
+ def __init__(self, in_planes, ratio=16):
227
+ super(ChannelAttention, self).__init__()
228
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
229
+
230
+ self.fc = nn.Sequential(
231
+ nn.Linear(in_planes,in_planes // ratio, bias = False),
232
+ nn.ReLU(inplace = True),
233
+ nn.Linear(in_planes // ratio, in_planes, bias = False)
234
+ )
235
+ self.sigmoid = nn.Sigmoid()
236
+ for m in self.modules():
237
+ if isinstance(m, nn.Conv2d):
238
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
239
+
240
+ def forward(self, in_feature):
241
+ x = in_feature
242
+ b, c, _, _ = in_feature.size()
243
+ avg_out = self.fc(self.avg_pool(x).view(b,c)).view(b, c, 1, 1)
244
+ out = avg_out
245
+ return self.sigmoid(out).expand_as(in_feature) * in_feature
246
+
247
+
248
+ class Attention(nn.Module):
249
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
250
+ super().__init__()
251
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
252
+
253
+ self.dim = dim
254
+ self.num_heads = num_heads
255
+ head_dim = dim // num_heads
256
+ self.scale = qk_scale or head_dim ** -0.5
257
+
258
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
259
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
260
+ self.attn_drop = nn.Dropout(attn_drop)
261
+ self.proj = nn.Linear(dim, dim)
262
+ self.proj_drop = nn.Dropout(proj_drop)
263
+
264
+ self.sr_ratio = sr_ratio
265
+ if sr_ratio > 1:
266
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
267
+ self.norm = nn.LayerNorm(dim)
268
+
269
+ def forward(self, x, H, W):
270
+ B, N, C = x.shape
271
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
272
+
273
+ if self.sr_ratio > 4:
274
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
275
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
276
+ x_ = self.norm(x_)
277
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
278
+ else:
279
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
280
+ k, v = kv[0], kv[1]
281
+
282
+ attn = (q @ k.transpose(-2, -1)) * self.scale
283
+ attn = attn.softmax(dim=-1)
284
+ attn = self.attn_drop(attn)
285
+
286
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
287
+ x = self.proj(x)
288
+ x = self.proj_drop(x)
289
+
290
+ return x
291
+
292
+
293
+ class Block(nn.Module):
294
+
295
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
296
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
297
+ super().__init__()
298
+ self.norm1 = norm_layer(dim)
299
+ self.attn = Attention(
300
+ dim,
301
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
302
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
303
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
304
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
305
+ self.norm2 = norm_layer(dim)
306
+ mlp_hidden_dim = int(dim * mlp_ratio)
307
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
308
+
309
+ def forward(self, x, H, W):
310
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
311
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
312
+
313
+ return x
314
+
315
+
316
+ class PatchEmbed(nn.Module):
317
+ """ Image to Patch Embedding
318
+ """
319
+
320
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
321
+ super().__init__()
322
+ img_size = to_2tuple(img_size)
323
+ patch_size = to_2tuple(patch_size)
324
+
325
+ self.img_size = img_size
326
+ self.patch_size = patch_size
327
+ # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
328
+ # f"img_size {img_size} should be divided by patch_size {patch_size}."
329
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
330
+ self.num_patches = self.H * self.W
331
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
332
+ self.norm = nn.LayerNorm(embed_dim)
333
+
334
+ def forward(self, x):
335
+ B, C, H, W = x.shape
336
+
337
+ x = self.proj(x).flatten(2).transpose(1, 2)
338
+ x = self.norm(x)
339
+ H, W = H // self.patch_size[0], W // self.patch_size[1]
340
+
341
+ return x, (H, W)
342
+
343
+
344
+ class PyramidVisionTransformer(nn.Module):
345
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
346
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
347
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
348
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4):
349
+ super().__init__()
350
+ self.num_classes = num_classes
351
+ self.depths = depths
352
+ self.num_stages = num_stages
353
+
354
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
355
+ cur = 0
356
+
357
+ for i in range(num_stages):
358
+ patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
359
+ patch_size=patch_size if i == 0 else 2,
360
+ in_chans=in_chans if i == 0 else embed_dims[i - 1],
361
+ embed_dim=embed_dims[i])
362
+ num_patches = patch_embed.num_patches if i == 0 else patch_embed.num_patches + 1
363
+ pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))
364
+ pos_drop = nn.Dropout(p=drop_rate)
365
+
366
+ block = nn.ModuleList([Block(
367
+ dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
368
+ qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j],
369
+ norm_layer=norm_layer, sr_ratio=sr_ratios[i])
370
+ for j in range(depths[i])])
371
+ cur += depths[i]
372
+
373
+ setattr(self, f"patch_embed{i + 1}", patch_embed)
374
+ setattr(self, f"pos_embed{i + 1}", pos_embed)
375
+ setattr(self, f"pos_drop{i + 1}", pos_drop)
376
+ setattr(self, f"block{i + 1}", block)
377
+
378
+ self.norm = norm_layer(embed_dims[3])
379
+
380
+ # cls_token
381
+ self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
382
+ self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
383
+ self.cls_token_3 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
384
+
385
+ # classification head
386
+ self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
387
+
388
+
389
+ self.regression = Regression()
390
+
391
+ # init weights
392
+ for i in range(num_stages):
393
+ pos_embed = getattr(self, f"pos_embed{i + 1}")
394
+ trunc_normal_(pos_embed, std=.02)
395
+ trunc_normal_(self.cls_token_1, std=.02)
396
+ trunc_normal_(self.cls_token_2, std=.02)
397
+ trunc_normal_(self.cls_token_3, std=.02)
398
+ self.apply(self._init_weights)
399
+
400
+
401
+ def _init_weights(self, m):
402
+ if isinstance(m, nn.Linear):
403
+ trunc_normal_(m.weight, std=.02)
404
+ if isinstance(m, nn.Linear) and m.bias is not None:
405
+ nn.init.constant_(m.bias, 0)
406
+ elif isinstance(m, nn.LayerNorm):
407
+ nn.init.constant_(m.bias, 0)
408
+ nn.init.constant_(m.weight, 1.0)
409
+
410
+ @torch.jit.ignore
411
+ def no_weight_decay(self):
412
+ # return {'pos_embed', 'cls_token'} # has pos_embed may be better
413
+ return {'cls_token'}
414
+
415
+ def get_classifier(self):
416
+ return self.head
417
+
418
+ def reset_classifier(self, num_classes, global_pool=''):
419
+ self.num_classes = num_classes
420
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
421
+
422
+ def _get_pos_embed(self, pos_embed, patch_embed, H, W):
423
+ if H * W == self.patch_embed1.num_patches:
424
+ return pos_embed
425
+ else:
426
+ return F.interpolate(
427
+ pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
428
+ size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)
429
+
430
+ def forward_features(self, x):
431
+ B = x.shape[0]
432
+ outputs = list()
433
+ cls_output = list()
434
+
435
+ for i in range(self.num_stages):
436
+ patch_embed = getattr(self, f"patch_embed{i + 1}")
437
+ pos_embed = getattr(self, f"pos_embed{i + 1}")
438
+ pos_drop = getattr(self, f"pos_drop{i + 1}")
439
+ block = getattr(self, f"block{i + 1}")
440
+ x, (H, W) = patch_embed(x)
441
+
442
+ if i == 0:
443
+ pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)
444
+ elif i == 1:
445
+ cls_tokens = self.cls_token_1.expand(B, -1, -1)
446
+ x = torch.cat((cls_tokens, x), dim=1)
447
+ pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
448
+ pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
449
+ elif i == 2:
450
+ cls_tokens = self.cls_token_2.expand(B, -1, -1)
451
+ x = torch.cat((cls_tokens, x), dim=1)
452
+ pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
453
+ pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
454
+
455
+ elif i == 3:
456
+ cls_tokens = self.cls_token_3.expand(B, -1, -1)
457
+ x = torch.cat((cls_tokens, x), dim=1)
458
+ pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
459
+ pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
460
+
461
+
462
+ x = pos_drop(x + pos_embed)
463
+ for blk in block:
464
+ x = blk(x, H, W)
465
+
466
+ if i == 0:
467
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
468
+ else:
469
+
470
+ x_cls = x[:,1,:]
471
+ x = x[:,1:,:].reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
472
+ cls_output.append(x_cls)
473
+
474
+
475
+ outputs.append(x)
476
+ return outputs, cls_output
477
+
478
+
479
+ def forward(self, label_x, unlabel_x=None):
480
+
481
+ if self.training:
482
+ # labeled image processing
483
+ label_x, l_cls = self.forward_features(label_x)
484
+ out_label_x, out_cls_l = self.regression(label_x, l_cls)
485
+ label_x_1, label_x_2, label_x_3 = out_label_x
486
+
487
+ B,C,H,W = label_x_1.size()
488
+ label_sum = label_x_1.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
489
+ label_normed = label_x_1 / (label_sum + 1e-6)
490
+
491
+ # unlabeled image processing
492
+ B,C,H,W = unlabel_x.shape
493
+ unlabel_x, ul_cls = self.forward_features(unlabel_x)
494
+ out_unlabel_x, out_cls_ul = self.regression(unlabel_x, ul_cls)
495
+ y0, y1, y2 = out_unlabel_x
496
+
497
+ unlabel_x_1 = self.generate_feature_patches(y0)
498
+ unlabel_x_2 = self.generate_feature_patches(y1)
499
+ unlabel_x_3 = self.generate_feature_patches(y2)
500
+
501
+ assert unlabel_x_1.shape[0] == B * 5
502
+ assert unlabel_x_2.shape[0] == B * 5
503
+ assert unlabel_x_3.shape[0] == B * 5
504
+
505
+ unlabel_x_1 = torch.split(unlabel_x_1, split_size_or_sections=B, dim=0)
506
+ unlabel_x_2 = torch.split(unlabel_x_2, split_size_or_sections=B, dim=0)
507
+ unlabel_x_3 = torch.split(unlabel_x_3, split_size_or_sections=B, dim=0)
508
+
509
+ return [label_x_1, label_x_2, label_x_3], [unlabel_x_1, unlabel_x_2, unlabel_x_3], label_normed, out_cls_l, out_cls_ul
510
+
511
+
512
+ else:
513
+
514
+ label_x, l_cls = self.forward_features(label_x)
515
+ out_label_x, out_cls_l = self.regression(label_x, l_cls)
516
+ label_x_1, label_x_2, label_x_3 = out_label_x
517
+ B,C,H,W = label_x_1.size()
518
+ label_sum = label_x_1.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
519
+ label_normed = label_x_1 / (label_sum + 1e-6)
520
+
521
+ return [label_x_1, label_x_2, label_x_3], label_normed
522
+
523
+
524
+ def generate_feature_patches(self, unlabel_x, ratio=0.75):
525
+ # unlabeled image processing
526
+
527
+ unlabel_x_1 = unlabel_x
528
+ b, c, h, w = unlabel_x.shape
529
+
530
+ center_x = random.randint(h // 2 - (h - h * ratio) // 2, h // 2 + (h - h * ratio) // 2)
531
+ center_y = random.randint(w // 2 - (w - w * ratio) // 2, w // 2 + (w - w * ratio) // 2)
532
+
533
+ new_h2 = int(h * ratio)
534
+ new_w2 = int(w * ratio) # 48*48
535
+ unlabel_x_2 = unlabel_x[:, :, center_x - new_h2 // 2:center_x + new_h2 // 2,
536
+ center_y - new_w2 // 2:center_y + new_w2 // 2]
537
+
538
+ new_h3 = int(new_h2 * ratio)
539
+ new_w3 = int(new_w2 * ratio)
540
+ unlabel_x_3 = unlabel_x[:, :, center_x - new_h3 // 2:center_x + new_h3 // 2,
541
+ center_y - new_w3 // 2:center_y + new_w3 // 2]
542
+
543
+ new_h4 = int(new_h3 * ratio)
544
+ new_w4 = int(new_w3 * ratio)
545
+ unlabel_x_4 = unlabel_x[:, :, center_x - new_h4 // 2:center_x + new_h4 // 2,
546
+ center_y - new_w4 // 2:center_y + new_w4 // 2]
547
+
548
+ new_h5 = int(new_h4 * ratio)
549
+ new_w5 = int(new_w4 * ratio)
550
+ unlabel_x_5 = unlabel_x[:, :, center_x - new_h5 // 2:center_x + new_h5 // 2,
551
+ center_y - new_w5 // 2:center_y + new_w5 // 2]
552
+
553
+ unlabel_x_2 = nn.functional.interpolate(unlabel_x_2, size=(h, w), mode='bilinear')
554
+ unlabel_x_3 = nn.functional.interpolate(unlabel_x_3, size=(h, w), mode='bilinear')
555
+ unlabel_x_4 = nn.functional.interpolate(unlabel_x_4, size=(h, w), mode='bilinear')
556
+ unlabel_x_5 = nn.functional.interpolate(unlabel_x_5, size=(h, w), mode='bilinear')
557
+
558
+ unlabel_x = torch.cat([unlabel_x_1, unlabel_x_2, unlabel_x_3, unlabel_x_4, unlabel_x_5], dim=0)
559
+
560
+ return unlabel_x
561
+
562
+ def _conv_filter(state_dict, patch_size=16):
563
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
564
+ out_dict = {}
565
+ for k, v in state_dict.items():
566
+ if 'patch_embed.proj.weight' in k:
567
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
568
+ out_dict[k] = v
569
+
570
+ return out_dict
571
+
572
+
573
+ @register_model
574
+ def pvt_tiny(pretrained=False, **kwargs):
575
+ model = PyramidVisionTransformer(
576
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
577
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
578
+ **kwargs)
579
+ model.default_cfg = _cfg()
580
+
581
+ return model
582
+
583
+
584
+ @register_model
585
+ def pvt_small(pretrained=False, **kwargs):
586
+ model = PyramidVisionTransformer(
587
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
588
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
589
+ model.default_cfg = _cfg()
590
+
591
+ return model
592
+
593
+
594
+ @register_model
595
+ def pvt_medium(pretrained=False, **kwargs):
596
+ model = PyramidVisionTransformer(
597
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
598
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
599
+ **kwargs)
600
+ model.default_cfg = _cfg()
601
+
602
+ return model
603
+
604
+
605
+ @register_model
606
+ def pvt_large(pretrained=False, **kwargs):
607
+ model = PyramidVisionTransformer(
608
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
609
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
610
+ **kwargs)
611
+ model.default_cfg = _cfg()
612
+
613
+ return model
614
+
615
+ @register_model
616
+ def pvt_treeformer(pretrained=False, **kwargs):
617
+ model = PyramidVisionTransformer(
618
+ patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
619
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
620
+ **kwargs)
621
+ model.default_cfg = _cfg()
622
+
623
+ return model
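A minimal inference sketch for the model defined above (not part of the repository): build `pvt_treeformer`, switch to eval mode, and run a single 256×256 crop through the evaluation branch of `forward`, which returns the three predicted density maps plus a normalized map. The random input is a placeholder; real usage loads a trained checkpoint and real imagery as in `test.py`.

```python
import torch
from network import pvt_cls as TCN

model = TCN.pvt_treeformer(pretrained=False)
model.eval()

x = torch.randn(1, 3, 256, 256)          # one RGB crop at the training/test crop size
with torch.no_grad():
    densities, normed = model(x)          # eval branch: [y0, y1, y2], normalized y0
pred = densities[0]                       # first density map, 1/4 of the input resolution
print(pred.shape, pred.sum().item())      # predicted tree count = sum of the density map
```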
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ numpy==1.21.5
+ Pillow==9.4.0
+ scikit_learn==1.2.2
+ scipy==1.7.3
+ timm==0.4.12
+ torch==1.12.1
+ torchvision==0.13.1
sample_imgs/overview.png ADDED
test.py ADDED
@@ -0,0 +1,117 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import numpy as np
5
+ import datasets.crowd as crowd
6
+ from network import pvt_cls as TCN
7
+ import torch.nn.functional as F
8
+ from scipy.io import savemat
9
+ from sklearn.metrics import r2_score
10
+
11
+ parser = argparse.ArgumentParser(description='Test ')
12
+ parser.add_argument('--device', default='0', help='assign device')
13
+ parser.add_argument('--batch-size', type=int, default=8, help='train batch size')
14
+ parser.add_argument('--crop-size', type=int, default=256, help='the crop size of the train image')
15
+ parser.add_argument('--model-path', type=str, default='/scratch/users/k2254235/ckpts/SEMI/Treeformer/best_model_mae-21.49_epoch-1759.pth', help='saved model path')
16
+ parser.add_argument('--data-path', type=str, default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='dataset path')
17
+ parser.add_argument('--dataset', type=str, default='TC')
18
+
19
+ def test(args, isSave = True):
20
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.device # set vis gpu
21
+ device = torch.device('cuda')
22
+
23
+ model_path = args.model_path
24
+ crop_size = args.crop_size
25
+ data_path = args.data_path
26
+
27
+ dataset = crowd.Crowd_TC(os.path.join(data_path, 'test_data'), crop_size, 1, method='val')
28
+ dataloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, num_workers=1, pin_memory=True)
29
+
30
+ model = TCN.pvt_treeformer(pretrained=False)
31
+ model.to(device)
32
+ model.load_state_dict(torch.load(model_path, device))
33
+ model.eval()
34
+ image_errs = []
35
+ result = []
36
+ R2_es = []
37
+ R2_gt = []
38
+ l=0;
39
+ for inputs, count, name, imgauss in dataloader:
40
+ with torch.no_grad():
41
+ inputs = inputs.to(device)
42
+ crop_imgs, crop_masks = [], []
43
+ b, c, h, w = inputs.size()
44
+ rh, rw = args.crop_size, args.crop_size
45
+
46
+ for i in range(0, h, rh):
47
+ gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
48
+
49
+ for j in range(0, w, rw):
50
+ gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
51
+ crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
52
+ mask = torch.zeros([b, 1, h, w]).to(device)
53
+ mask[:, :, gis:gie, gjs:gje].fill_(1.0)
54
+ crop_masks.append(mask)
55
+ crop_imgs, crop_masks = map(lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
56
+
57
+ crop_preds = []
58
+ nz, bz = crop_imgs.size(0), args.batch_size
59
+ for i in range(0, nz, bz):
60
+
61
+ gs, gt = i, min(nz, i + bz)
62
+ crop_pred, _ = model(crop_imgs[gs:gt])
63
+ crop_pred = crop_pred[0]
64
+
65
+ _, _, h1, w1 = crop_pred.size()
66
+ crop_pred = F.interpolate(crop_pred, size=(h1 * 4, w1 * 4), mode='bilinear', align_corners=True) / 16
67
+ crop_preds.append(crop_pred)
68
+ crop_preds = torch.cat(crop_preds, dim=0)
69
+ #import pdb;pdb.set_trace()
70
+
71
+ # splice them to the original size
72
+ idx = 0
73
+ pred_map = torch.zeros([b, 1, h, w]).to(device)
74
+ for i in range(0, h, rh):
75
+ gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
76
+ for j in range(0, w, rw):
77
+ gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
78
+ pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
79
+ idx += 1
80
+ # for the overlapping area, compute average value
81
+ mask = crop_masks.sum(dim=0).unsqueeze(0)
82
+ outputs = pred_map / mask
83
+
84
+ outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)/4
85
+ outputs = pred_map / mask
86
+
87
+ img_err = count[0].item() - torch.sum(outputs).item()
88
+ R2_gt.append(count[0].item())
89
+ R2_es.append(torch.sum(outputs).item())
90
+
91
+ print("Img name: ", name, "Error: ", img_err, "GT count: ", count[0].item(), "Model out: ", torch.sum(outputs).item())
92
+ image_errs.append(img_err)
93
+ result.append([name, count[0].item(), torch.sum(outputs).item(), img_err])
94
+
95
+ savemat('predictions/'+name[0]+'.mat', {'estimation':np.squeeze(outputs.cpu().data.numpy()),
96
+ 'image': np.squeeze(inputs.cpu().data.numpy()), 'gt': np.squeeze(imgauss.cpu().data.numpy())})
97
+ l=l+1
98
+
99
+
100
+ image_errs = np.array(image_errs)
101
+
102
+ mse = np.sqrt(np.mean(np.square(image_errs)))
103
+ mae = np.mean(np.abs(image_errs))
104
+ R_2 = r2_score(R2_gt,R2_es)
105
+
106
+ print('{}: mae {}, mse {}, R2 {}\n'.format(model_path, mae, mse,R_2))
107
+
108
+ if isSave:
109
+ with open("test.txt","w") as f:
110
+ for i in range(len(result)):
111
+ f.write(str(result[i]).replace('[','').replace(']','').replace(',', ' ')+"\n")
112
+ f.close()
113
+
114
+ if __name__ == '__main__':
115
+ args = parser.parse_args()
116
+ test(args, isSave= True)
117
+
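One detail of the evaluation script worth spelling out is the tiling rule: crops of size `crop_size` start at every multiple of `crop_size`, but each start index is clamped with `max(min(h - rh, i), 0)` so the last window stays inside the image, and the resulting overlap is averaged out by dividing the stitched prediction by the mask sum. A small illustration (the image height of 600 is an arbitrary example):

```python
h, rh = 600, 256                                   # illustrative image height and crop size
starts = [max(min(h - rh, i), 0) for i in range(0, h, rh)]
print(starts)  # [0, 256, 344] -> the last window [344, 600) overlaps [256, 512)
```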
train.py ADDED
@@ -0,0 +1,369 @@
1
+ import os
2
+ import time
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import optim
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.data.dataloader import default_collate
8
+ import numpy as np
9
+ from datetime import datetime
10
+ import torch.nn.functional as F
11
+ from datasets.crowd import Crowd_TC, Crowd_UL_TC
12
+
13
+ from network import pvt_cls as TCN
14
+ from losses.multi_con_loss import MultiConLoss
15
+
16
+ from utils.pytorch_utils import Save_Handle, AverageMeter
17
+ import utils.log_utils as log_utils
18
+ import argparse
19
+ from losses.rank_loss import RankLoss
20
+
21
+ from losses import ramps
22
+ from losses.ot_loss import OT_Loss
23
+ from losses.consistency_loss import *
24
+
25
+ parser = argparse.ArgumentParser(description='Train')
26
+ parser.add_argument('--data-dir', default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='data path')
27
+
28
+ parser.add_argument('--dataset', default='TC')
29
+ parser.add_argument('--lr', type=float, default=1e-5, help='the initial learning rate')
30
+ parser.add_argument('--weight-decay', type=float, default=1e-4, help='the weight decay')
31
+ parser.add_argument('--resume', default='', type=str, help='the path of resume training model')
32
+ parser.add_argument('--max-epoch', type=int, default=4000, help='max training epoch')
33
+ parser.add_argument('--val-epoch', type=int, default=1, help='run validation every this many epochs')
34
+ parser.add_argument('--val-start', type=int, default=0, help='the epoch at which to start validation')
35
+ parser.add_argument('--batch-size', type=int, default=16, help='train batch size')
36
+ parser.add_argument('--batch-size-ul', type=int, default=16, help='unlabeled train batch size')
37
+ parser.add_argument('--device', default='0', help='assign device')
38
+ parser.add_argument('--num-workers', type=int, default=0, help='the num of training process')
39
+ parser.add_argument('--crop-size', type=int, default= 256, help='the crop size of the train image')
40
+ parser.add_argument('--rl', type=float, default=1, help='weight of the ranking loss')
41
+ parser.add_argument('--reg', type=float, default=1, help='entropy regularization in sinkhorn')
42
+ parser.add_argument('--ot', type=float, default=0.1, help='weight of the OT loss')
43
+ parser.add_argument('--tv', type=float, default=0.01, help='weight of the TV (total variation) loss')
44
+ parser.add_argument('--num-of-iter-in-ot', type=int, default=100, help='sinkhorn iterations')
45
+ parser.add_argument('--norm-cood', type=int, default=0, help='whether to norm cood when computing distance')
46
+ parser.add_argument('--run-name', default='Treeformer_test', help='run name for wandb interface/logging')
47
+ parser.add_argument('--consistency', type=int, default=1, help='weight of the consistency loss')
48
+ args = parser.parse_args()
49
+
50
+
51
+ def train_collate(batch):
52
+ transposed_batch = list(zip(*batch))
53
+ images = torch.stack(transposed_batch[0], 0)
54
+ gauss = torch.stack(transposed_batch[1], 0)
55
+ points = transposed_batch[2]
56
+ gt_discretes = torch.stack(transposed_batch[3], 0)
57
+ return images, gauss, points, gt_discretes
58
+
59
+
60
+ def train_collate_UL(batch):
61
+ transposed_batch = list(zip(*batch))
62
+ images = torch.stack(transposed_batch[0], 0)
63
+
64
+ return images
65
+
66
+ def get_current_consistency_weight(epoch):
67
+ # Consistency ramp-up from https://arxiv.org/abs/1610.02242. Note: this helper reads args.consistency_ramp, which is not defined by the parser above, and it is not called in the training loop below.
68
+ return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_ramp)
69
+
70
+
71
+ class Trainer(object):
72
+ def __init__(self, args):
73
+ self.args = args
74
+
75
+ def setup(self):
76
+ args = self.args
77
+ sub_dir = (
78
+ "SEMI/{}_12-1-input-{}_reg-{}_nIter-{}_normCood-{}".format(
79
+ args.run_name,args.crop_size,args.reg,
80
+ args.num_of_iter_in_ot,args.norm_cood))
81
+
82
+ self.save_dir = os.path.join("/scratch/users/k2254235","ckpts", sub_dir)
83
+ if not os.path.exists(self.save_dir):
84
+ os.makedirs(self.save_dir)
85
+
86
+ time_str = datetime.strftime(datetime.now(), "%m%d-%H%M%S")
87
+ self.logger = log_utils.get_logger(
88
+ os.path.join(self.save_dir, "train-{:s}.log".format(time_str)))
89
+
90
+ log_utils.print_config(vars(args), self.logger)
91
+
92
+ if torch.cuda.is_available():
93
+ self.device = torch.device("cuda")
94
+ self.device_count = torch.cuda.device_count()
95
+ self.logger.info("using {} gpus".format(self.device_count))
96
+ else:
97
+ raise Exception("gpu is not available")
98
+
99
+
100
+ downsample_ratio = 4
101
+ self.datasets = {"train": Crowd_TC(os.path.join(args.data_dir, "train_data"), args.crop_size,
102
+ downsample_ratio, "train"), "val": Crowd_TC(os.path.join(args.data_dir, "valid_data"),
103
+ args.crop_size, downsample_ratio, "val")}
104
+
105
+ self.datasets_ul = { "train_ul": Crowd_UL_TC(os.path.join(args.data_dir, "train_data_ul"),
106
+ args.crop_size, downsample_ratio, "train_ul")}
107
+
108
+
109
+ self.dataloaders = {
110
+ x: DataLoader(self.datasets[x],
111
+ collate_fn=(train_collate if x == "train" else default_collate),
112
+ batch_size=(args.batch_size if x == "train" else 1),
113
+ shuffle=(True if x == "train" else False),
114
+ num_workers=args.num_workers * self.device_count,
115
+ pin_memory=(True if x == "train" else False))
116
+ for x in ["train", "val"]}
117
+
118
+ self.dataloaders_ul = {
119
+ x: DataLoader(self.datasets_ul[x],
120
+ collate_fn=(train_collate_UL ),
121
+ batch_size=(args.batch_size_ul),
122
+ shuffle=(True),
123
+ num_workers=args.num_workers * self.device_count,
124
+ pin_memory=(True if x == "train" else False))
125
+ for x in ["train_ul"]}
126
+
127
+
128
+ self.model = TCN.pvt_treeformer(pretrained=False)
129
+
130
+ self.model.to(self.device)
131
+ self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
132
+ self.start_epoch = 0
133
+
134
+ if args.resume:
135
+ self.logger.info("loading pretrained model from " + args.resume)
136
+ suf = args.resume.rsplit(".", 1)[-1]
137
+ if suf == "tar":
138
+ checkpoint = torch.load(args.resume, self.device)
139
+ self.model.load_state_dict(checkpoint["model_state_dict"])
140
+ self.optimizer.load_state_dict(
141
+ checkpoint["optimizer_state_dict"])
142
+ self.start_epoch = checkpoint["epoch"] + 1
143
+ elif suf == "pth":
144
+ self.model.load_state_dict(
145
+ torch.load(args.resume, self.device))
146
+ else:
147
+ self.logger.info("random initialization")
148
+
149
+ self.ot_loss = OT_Loss(args.crop_size, downsample_ratio, args.norm_cood,
150
+ self.device, args.num_of_iter_in_ot, args.reg)
151
+
152
+ self.tvloss = nn.L1Loss(reduction="none").to(self.device)
153
+ self.mse = nn.MSELoss().to(self.device)
154
+ self.mae = nn.L1Loss().to(self.device)
155
+ self.save_list = Save_Handle(max_num=1)
156
+ self.best_mae = np.inf
157
+ self.best_mse = np.inf
158
+ self.rankloss = RankLoss().to(self.device)
159
+ self.kl_distance = nn.KLDivLoss(reduction='none')
160
+ self.multiconloss = MultiConLoss().to(self.device)
161
+
162
+
163
+ def train(self):
164
+ """training process"""
165
+ args = self.args
166
+ for epoch in range(self.start_epoch, args.max_epoch + 1):
167
+ self.logger.info("-" * 5 + "Epoch {}/{}".format(epoch, args.max_epoch) + "-" * 5)
168
+ self.epoch = epoch
169
+ self.train_epoch()
170
+ if epoch % args.val_epoch == 0 and epoch >= args.val_start:
171
+ self.val_epoch()
172
+
173
+ def train_epoch(self):
174
+ epoch_ot_loss = AverageMeter()
175
+ epoch_ot_obj_value = AverageMeter()
176
+ epoch_wd = AverageMeter()
177
+ epoch_tv_loss = AverageMeter()
178
+ epoch_count_loss = AverageMeter()
179
+ epoch_count_consistency_l = AverageMeter()
180
+ epoch_count_consistency_ul = AverageMeter()
181
+ epoch_loss = AverageMeter()
182
+ epoch_mae = AverageMeter()
183
+ epoch_mse = AverageMeter()
184
+ epoch_start = time.time()
185
+ epoch_rank_loss = AverageMeter()
186
+ epoch_consistensy_loss = AverageMeter()
187
+
188
+ self.model.train() # Set model to training mode
189
+
190
+ for step, (inputs, gausss, points, gt_discrete) in enumerate(self.dataloaders["train"]):
191
+ inputs = inputs.to(self.device)
192
+ gausss = gausss.to(self.device)
193
+ gd_count = np.array([len(p) for p in points], dtype=np.float32)
194
+
195
+ points = [p.to(self.device) for p in points]
196
+ gt_discrete = gt_discrete.to(self.device)
197
+ N = inputs.size(0)
198
+
199
+ for st, unlabel_data in enumerate(self.dataloaders_ul["train_ul"]):
200
+ inputs_ul = unlabel_data.to(self.device)
201
+ break
202
+
203
+
204
+ with torch.set_grad_enabled(True):
205
+ outputs_L, outputs_UL, outputs_normed, CLS_L, CLS_UL = self.model(inputs, inputs_ul)
206
+ outputs_L = outputs_L[0]
207
+
208
+ with torch.set_grad_enabled(False):
209
+ preds_UL = (outputs_UL[0][0] + outputs_UL[1][0] + outputs_UL[2][0])/3
210
+
211
+ # Compute counting loss.
212
+ count_loss = self.mae(outputs_L.sum(1).sum(1).sum(1),torch.from_numpy(gd_count).float().to(self.device))*self.args.reg
213
+
214
+ # Compute OT loss.
215
+ ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs_L, points)
216
+ ot_loss = ot_loss* self.args.ot
217
+ ot_obj_value = ot_obj_value* self.args.ot
218
+
219
+ gd_count_tensor = (torch.from_numpy(gd_count).float().to(self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3))
220
+ gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
221
+ tv_loss = (self.tvloss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)*
222
+ torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.tv
223
+
224
+ epoch_ot_loss.update(ot_loss.item(), N)
225
+ epoch_ot_obj_value.update(ot_obj_value.item(), N)
226
+ epoch_wd.update(wd, N)
227
+ epoch_count_loss.update(count_loss.item(), N)
228
+ epoch_tv_loss.update(tv_loss.item(), N)
229
+
230
+ # Compute ranking loss.
231
+ rank_loss = self.rankloss(outputs_UL)*self.args.rl
232
+ epoch_rank_loss.update(rank_loss.item(), N)
233
+
234
+ # Compute multi-level consistency loss
235
+ consistency_loss = args.consistency * self.multiconloss(outputs_UL)
236
+ epoch_consistensy_loss.update(consistency_loss.item(), N)
237
+
238
+
239
+ # Compute consistency count
240
+ Con_cls_UL = (CLS_UL[0] + CLS_UL[1] + CLS_UL[2])/3
241
+ Con_cls_L = torch.from_numpy(gd_count).float().to(self.device)
242
+
243
+ count_loss_l = self.mae(torch.stack((CLS_L[0],CLS_L[1],CLS_L[2])), torch.stack((Con_cls_L, Con_cls_L, Con_cls_L)))
244
+ count_loss_ul = self.mae(torch.stack((CLS_UL[0],CLS_UL[1],CLS_UL[2])), torch.stack((Con_cls_UL, Con_cls_UL, Con_cls_UL)))
245
+ epoch_count_consistency_l.update(count_loss_l.item(), N)
246
+ epoch_count_consistency_ul.update(count_loss_ul.item(), N)
247
+
248
+
249
+ loss = count_loss + ot_loss + tv_loss + rank_loss + count_loss_l + count_loss_ul + consistency_loss
250
+
251
+
252
+ self.optimizer.zero_grad()
253
+ loss.backward()
254
+ self.optimizer.step()
255
+
256
+ pred_count = (torch.sum(outputs_L.view(N, -1),
257
+ dim=1).detach().cpu().numpy())
258
+
259
+ pred_err = pred_count - gd_count
260
+ epoch_loss.update(loss.item(), N)
261
+ epoch_mse.update(np.mean(pred_err * pred_err), N)
262
+ epoch_mae.update(np.mean(abs(pred_err)), N)
263
+
264
+
265
+ self.logger.info(
266
+ "Epoch {} Train, Loss: {:.2f}, Count Loss: {:.2f}, OT Loss: {:.2e}, TV Loss: {:.2e}, Rank Loss: {:.2f},"
267
+ "Consistensy Loss: {:.2f}, MSE: {:.2f}, MAE: {:.2f},LC Loss: {:.2f}, ULC Loss: {:.2f}, Cost {:.1f} sec".format(
268
+ self.epoch, epoch_loss.get_avg(), epoch_count_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_tv_loss.get_avg(), epoch_rank_loss.get_avg(),
269
+ epoch_consistensy_loss.get_avg(), np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(), epoch_count_consistency_l.get_avg(),
270
+ epoch_count_consistency_ul.get_avg(), time.time() - epoch_start))
271
+
272
+
273
+
274
+ model_state_dic = self.model.state_dict()
275
+ save_path = os.path.join(self.save_dir, "{}_ckpt.tar".format(self.epoch))
276
+
277
+ torch.save({"epoch": self.epoch, "optimizer_state_dict": self.optimizer.state_dict(),
278
+ "model_state_dict": model_state_dic}, save_path)
279
+ self.save_list.append(save_path)
280
+
281
+ def val_epoch(self):
282
+ args = self.args
283
+ epoch_start = time.time()
284
+ self.model.eval() # Set model to evaluate mode
285
+ epoch_res = []
286
+ for inputs, count, name, gauss_im in self.dataloaders["val"]:
287
+ with torch.no_grad():
288
+ inputs = inputs.to(self.device)
289
+ crop_imgs, crop_masks = [], []
290
+ b, c, h, w = inputs.size()
291
+ rh, rw = args.crop_size, args.crop_size
292
+ for i in range(0, h, rh):
293
+ gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
294
+ for j in range(0, w, rw):
295
+ gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
296
+ crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
297
+ mask = torch.zeros([b, 1, h, w]).to(self.device)
298
+ mask[:, :, gis:gie, gjs:gje].fill_(1.0)
299
+ crop_masks.append(mask)
300
+ crop_imgs, crop_masks = map(
301
+ lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
302
+
303
+ crop_preds = []
304
+ nz, bz = crop_imgs.size(0), args.batch_size
305
+ for i in range(0, nz, bz):
306
+ gs, gt = i, min(nz, i + bz)
307
+
308
+ crop_pred, _ = self.model(crop_imgs[gs:gt])
309
+ crop_pred = crop_pred[0]
310
+ _, _, h1, w1 = crop_pred.size()
311
+ crop_pred = (F.interpolate(crop_pred, size=(h1 * 4, w1 * 4),
312
+ mode="bilinear", align_corners=True) / 16 )
313
+
314
+ crop_preds.append(crop_pred)
315
+ crop_preds = torch.cat(crop_preds, dim=0)
316
+
317
+ # splice them to the original size
318
+ idx = 0
319
+ pred_map = torch.zeros([b, 1, h, w]).to(self.device)
320
+ for i in range(0, h, rh):
321
+ gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
322
+ for j in range(0, w, rw):
323
+ gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
324
+ pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
325
+ idx += 1
326
+ # for the overlapping area, compute average value
327
+ mask = crop_masks.sum(dim=0).unsqueeze(0)
328
+ outputs = pred_map / mask
329
+
330
+ res = count[0].item() - torch.sum(outputs).item()
331
+ epoch_res.append(res)
332
+ epoch_res = np.array(epoch_res)
333
+ mse = np.sqrt(np.mean(np.square(epoch_res)))
334
+ mae = np.mean(np.abs(epoch_res))
335
+
336
+ self.logger.info("Epoch {} Val, MSE: {:.2f}, MAE: {:.2f}, Cost {:.1f} sec".format(
337
+ self.epoch, mse, mae, time.time() - epoch_start ))
338
+
339
+
340
+ model_state_dic = self.model.state_dict()
341
+ print("Comaprison", mae, self.best_mae)
342
+ if mae < self.best_mae:
343
+ self.best_mse = mse
344
+ self.best_mae = mae
345
+ self.logger.info(
346
+ "save best mse {:.2f} mae {:.2f} model epoch {}".format(
347
+ self.best_mse, self.best_mae, self.epoch))
348
+
349
+ print("Saving best model at {} epoch".format(self.epoch))
350
+ model_path = os.path.join(
351
+ self.save_dir, "best_model_mae-{:.2f}_epoch-{}.pth".format(
352
+ self.best_mae, self.epoch))
353
+
354
+ torch.save(model_state_dic, model_path)
355
+
356
+
357
+ if __name__ == "__main__":
358
+ import torch
359
+ torch.backends.cudnn.benchmark = True
360
+ trainer = Trainer(args)
361
+ trainer.setup()
362
+ trainer.train()
363
+
364
+
365
+
366
+
367
+
368
+
369
+
utils/__init__.py ADDED
File without changes
utils/log_utils.py ADDED
@@ -0,0 +1,24 @@
+ import logging
+
+
+ def get_logger(log_file):
+     logger = logging.getLogger(log_file)
+     logger.setLevel(logging.DEBUG)
+     fh = logging.FileHandler(log_file)
+     fh.setLevel(logging.DEBUG)
+     ch = logging.StreamHandler()
+     ch.setLevel(logging.INFO)
+     formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+     ch.setFormatter(formatter)
+     fh.setFormatter(formatter)
+     logger.addHandler(ch)
+     logger.addHandler(fh)
+     return logger
+
+
+ def print_config(config, logger):
+     """
+     Print configuration of the model
+     """
+     for k, v in config.items():
+         logger.info("{}:\t{}".format(k.ljust(15), v))
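A minimal usage sketch of `get_logger` (the log file name is illustrative): it logs INFO and above to the console and everything from DEBUG up to the given file, which is how the trainer records each epoch.

```python
from utils.log_utils import get_logger

logger = get_logger("example-train.log")   # illustrative log file name
logger.info("logging to console (INFO) and to example-train.log (DEBUG)")
```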
utils/pytorch_utils.py ADDED
@@ -0,0 +1,58 @@
+ import os
+
+ def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10):
+     """Sets the learning rate to the initial LR decayed by 10 every `decay_epoch` epochs"""
+     lr = max(initial_lr * (0.1 ** (epoch // decay_epoch)), 1e-6)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+
+ class Save_Handle(object):
+     """Keep at most `max_num` saved checkpoint paths, deleting the oldest file once the limit is exceeded"""
+     def __init__(self, max_num):
+         self.save_list = []
+         self.max_num = max_num
+
+     def append(self, save_path):
+         if len(self.save_list) < self.max_num:
+             self.save_list.append(save_path)
+         else:
+             remove_path = self.save_list[0]
+             del self.save_list[0]
+             self.save_list.append(save_path)
+             if os.path.exists(remove_path):
+                 os.remove(remove_path)
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value"""
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = 1.0 * self.sum / self.count
+
+     def get_avg(self):
+         return self.avg
+
+     def get_count(self):
+         return self.count
+
+
+ def set_trainable(model, requires_grad):
+     for param in model.parameters():
+         param.requires_grad = requires_grad
+
+
+
+ def get_num_params(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
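A minimal usage sketch of `AverageMeter` (values are illustrative): it keeps a sample-weighted running mean, which is how the training script averages per-batch losses over an epoch.

```python
from utils.pytorch_utils import AverageMeter

meter = AverageMeter()
meter.update(0.50, n=16)   # a batch of 16 samples with mean loss 0.50
meter.update(0.30, n=8)    # a batch of 8 samples with mean loss 0.30
print(meter.get_avg())     # (0.50*16 + 0.30*8) / 24 ≈ 0.4333
```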