Spaces:
Paused
Paused
franciszzj
commited on
Commit
•
c964d4c
1
Parent(s):
c6b26ba
init
Browse files- README.md +31 -13
- assets/EU.png +0 -0
- assets/reset.png +0 -0
- datasets/__init__.py +0 -0
- datasets/crowd.py +268 -0
- demo.py +148 -0
- examples/IMG_101.jpg +0 -0
- examples/IMG_125.jpg +0 -0
- examples/IMG_138.jpg +0 -0
- examples/IMG_18.jpg +0 -0
- examples/IMG_180.jpg +0 -0
- examples/IMG_206.jpg +0 -0
- examples/IMG_223.jpg +0 -0
- examples/IMG_247.jpg +0 -0
- examples/IMG_270.jpg +0 -0
- examples/IMG_306.jpg +0 -0
- losses/__init__.py +1 -0
- losses/bregman_pytorch.py +484 -0
- losses/consistency_loss.py +294 -0
- losses/dm_loss.py +62 -0
- losses/multi_con_loss.py +41 -0
- losses/ot_loss.py +68 -0
- losses/ramps.py +41 -0
- losses/rank_loss.py +53 -0
- network/pvt_cls.py +623 -0
- requirements.txt +7 -0
- sample_imgs/overview.png +0 -0
- test.py +117 -0
- train.py +369 -0
- utils/__init__.py +0 -0
- utils/log_utils.py +24 -0
- utils/pytorch_utils.py +58 -0
README.md
CHANGED
@@ -1,13 +1,31 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# TreeFormer
|
3 |
+
|
4 |
+
This is the code base for IEEE TRANSACTIONS ON GEOSCIENCE AND REMOTE SENSING (TGRS 2023) paper ['TreeFormer: a Semi-Supervised Transformer-based Framework for Tree Counting from a Single High Resolution Image'](https://arxiv.org/abs/2307.06118)
|
5 |
+
|
6 |
+
<img src="sample_imgs/overview.png">
|
7 |
+
|
8 |
+
## Installation
|
9 |
+
|
10 |
+
Python ≥ 3.7.
|
11 |
+
|
12 |
+
To install the required packages, please run:
|
13 |
+
|
14 |
+
|
15 |
+
```bash
|
16 |
+
pip install -r requirements.txt
|
17 |
+
```
|
18 |
+
|
19 |
+
## Dataset
|
20 |
+
Download the dataset from [google drive](https://drive.google.com/file/d/1xcjv8967VvvzcDM4aqAi7Corkb11T0i2/view?usp=drive_link).
|
21 |
+
## Evaluation
|
22 |
+
Download our trained model on [London](https://drive.google.com/file/d/14uuOF5758sxtM5EgeGcRtSln5lUXAHge/view?usp=sharing) dataset.
|
23 |
+
|
24 |
+
Modify the path to the dataset and model for evaluation in 'test.py'.
|
25 |
+
|
26 |
+
Run 'test.py'
|
27 |
+
## Acknowledgements
|
28 |
+
|
29 |
+
- Part of codes are borrowed from [PVT](https://github.com/whai362/PVT) and [DM Count](https://github.com/cvlab-stonybrook/DM-Count). Thanks for their great work!
|
30 |
+
|
31 |
+
|
assets/EU.png
ADDED
assets/reset.png
ADDED
datasets/__init__.py
ADDED
File without changes
|
datasets/crowd.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch.utils.data as data
|
3 |
+
import os
|
4 |
+
from glob import glob
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms.functional as F
|
7 |
+
from torchvision import transforms
|
8 |
+
import random
|
9 |
+
import numpy as np
|
10 |
+
import scipy.io as sio
|
11 |
+
|
12 |
+
def random_crop(im_h, im_w, crop_h, crop_w):
|
13 |
+
res_h = im_h - crop_h
|
14 |
+
res_w = im_w - crop_w
|
15 |
+
i = random.randint(0, res_h)
|
16 |
+
j = random.randint(0, res_w)
|
17 |
+
return i, j, crop_h, crop_w
|
18 |
+
|
19 |
+
|
20 |
+
def gen_discrete_map(im_height, im_width, points):
|
21 |
+
"""
|
22 |
+
func: generate the discrete map.
|
23 |
+
points: [num_gt, 2], for each row: [width, height]
|
24 |
+
"""
|
25 |
+
discrete_map = np.zeros([im_height, im_width], dtype=np.float32)
|
26 |
+
h, w = discrete_map.shape[:2]
|
27 |
+
num_gt = points.shape[0]
|
28 |
+
if num_gt == 0:
|
29 |
+
return discrete_map
|
30 |
+
|
31 |
+
# fast create discrete map
|
32 |
+
points_np = np.array(points).round().astype(int)
|
33 |
+
p_h = np.minimum(points_np[:, 1], np.array([h-1]*num_gt).astype(int))
|
34 |
+
p_w = np.minimum(points_np[:, 0], np.array([w-1]*num_gt).astype(int))
|
35 |
+
p_index = torch.from_numpy(p_h* im_width + p_w).to(torch.int64)
|
36 |
+
discrete_map = torch.zeros(im_width * im_height).scatter_add_(0, index=p_index, src=torch.ones(im_width*im_height)).view(im_height, im_width).numpy()
|
37 |
+
|
38 |
+
''' slow method
|
39 |
+
for p in points:
|
40 |
+
p = np.round(p).astype(int)
|
41 |
+
p[0], p[1] = min(h - 1, p[1]), min(w - 1, p[0])
|
42 |
+
discrete_map[p[0], p[1]] += 1
|
43 |
+
'''
|
44 |
+
assert np.sum(discrete_map) == num_gt
|
45 |
+
return discrete_map
|
46 |
+
|
47 |
+
|
48 |
+
class Base(data.Dataset):
|
49 |
+
def __init__(self, root_path, crop_size, downsample_ratio=8):
|
50 |
+
|
51 |
+
self.root_path = root_path
|
52 |
+
self.c_size = crop_size
|
53 |
+
self.d_ratio = downsample_ratio
|
54 |
+
assert self.c_size % self.d_ratio == 0
|
55 |
+
self.dc_size = self.c_size // self.d_ratio
|
56 |
+
self.trans = transforms.Compose([
|
57 |
+
transforms.ToTensor(),
|
58 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
59 |
+
])
|
60 |
+
|
61 |
+
def __len__(self):
|
62 |
+
pass
|
63 |
+
|
64 |
+
def __getitem__(self, item):
|
65 |
+
pass
|
66 |
+
|
67 |
+
def train_transform(self, img, keypoints, gauss_im):
|
68 |
+
wd, ht = img.size
|
69 |
+
st_size = 1.0 * min(wd, ht)
|
70 |
+
assert st_size >= self.c_size
|
71 |
+
assert len(keypoints) >= 0
|
72 |
+
i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
|
73 |
+
img = F.crop(img, i, j, h, w)
|
74 |
+
gauss_im = F.crop(img, i, j, h, w)
|
75 |
+
if len(keypoints) > 0:
|
76 |
+
keypoints = keypoints - [j, i]
|
77 |
+
idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
|
78 |
+
(keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
|
79 |
+
keypoints = keypoints[idx_mask]
|
80 |
+
else:
|
81 |
+
keypoints = np.empty([0, 2])
|
82 |
+
|
83 |
+
gt_discrete = gen_discrete_map(h, w, keypoints)
|
84 |
+
down_w = w // self.d_ratio
|
85 |
+
down_h = h // self.d_ratio
|
86 |
+
gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
|
87 |
+
assert np.sum(gt_discrete) == len(keypoints)
|
88 |
+
|
89 |
+
if len(keypoints) > 0:
|
90 |
+
if random.random() > 0.5:
|
91 |
+
img = F.hflip(img)
|
92 |
+
gauss_im = F.hflip(gauss_im)
|
93 |
+
gt_discrete = np.fliplr(gt_discrete)
|
94 |
+
keypoints[:, 0] = w - keypoints[:, 0]
|
95 |
+
else:
|
96 |
+
if random.random() > 0.5:
|
97 |
+
img = F.hflip(img)
|
98 |
+
gauss_im = F.hflip(gauss_im)
|
99 |
+
gt_discrete = np.fliplr(gt_discrete)
|
100 |
+
gt_discrete = np.expand_dims(gt_discrete, 0)
|
101 |
+
|
102 |
+
return self.trans(img), gauss_im, torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(gt_discrete.copy()).float()
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
class Crowd_TC(Base):
|
107 |
+
def __init__(self, root_path, crop_size, downsample_ratio=8, method='train'):
|
108 |
+
super().__init__(root_path, crop_size, downsample_ratio)
|
109 |
+
self.method = method
|
110 |
+
if method not in ['train', 'val']:
|
111 |
+
raise Exception("not implement")
|
112 |
+
|
113 |
+
self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg')))
|
114 |
+
|
115 |
+
print('number of img [{}]: {}'.format(method, len(self.im_list)))
|
116 |
+
|
117 |
+
def __len__(self):
|
118 |
+
return len(self.im_list)
|
119 |
+
|
120 |
+
def __getitem__(self, item):
|
121 |
+
img_path = self.im_list[item]
|
122 |
+
name = os.path.basename(img_path).split('.')[0]
|
123 |
+
gd_path = os.path.join(self.root_path, 'ground_truth', 'GT_{}.mat'.format(name))
|
124 |
+
img = Image.open(img_path).convert('RGB')
|
125 |
+
keypoints = sio.loadmat(gd_path)['image_info'][0][0][0][0][0]
|
126 |
+
gauss_path = os.path.join(self.root_path, 'ground_truth', '{}_densitymap.npy'.format(name))
|
127 |
+
gauss_im = torch.from_numpy(np.load(gauss_path)).float()
|
128 |
+
#import pdb;pdb.set_trace()
|
129 |
+
#print("label {}", item)
|
130 |
+
|
131 |
+
if self.method == 'train':
|
132 |
+
return self.train_transform(img, keypoints, gauss_im)
|
133 |
+
elif self.method == 'val':
|
134 |
+
wd, ht = img.size
|
135 |
+
st_size = 1.0 * min(wd, ht)
|
136 |
+
if st_size < self.c_size:
|
137 |
+
rr = 1.0 * self.c_size / st_size
|
138 |
+
wd = round(wd * rr)
|
139 |
+
ht = round(ht * rr)
|
140 |
+
st_size = 1.0 * min(wd, ht)
|
141 |
+
img = img.resize((wd, ht), Image.BICUBIC)
|
142 |
+
img = self.trans(img)
|
143 |
+
#import pdb;pdb.set_trace()
|
144 |
+
|
145 |
+
return img, len(keypoints), name, gauss_im
|
146 |
+
|
147 |
+
def train_transform(self, img, keypoints, gauss_im):
|
148 |
+
wd, ht = img.size
|
149 |
+
st_size = 1.0 * min(wd, ht)
|
150 |
+
# resize the image to fit the crop size
|
151 |
+
if st_size < self.c_size:
|
152 |
+
rr = 1.0 * self.c_size / st_size
|
153 |
+
wd = round(wd * rr)
|
154 |
+
ht = round(ht * rr)
|
155 |
+
st_size = 1.0 * min(wd, ht)
|
156 |
+
img = img.resize((wd, ht), Image.BICUBIC)
|
157 |
+
#gauss_im = gauss_im.resize((wd, ht), Image.BICUBIC)
|
158 |
+
keypoints = keypoints * rr
|
159 |
+
assert st_size >= self.c_size, print(wd, ht)
|
160 |
+
assert len(keypoints) >= 0
|
161 |
+
i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
|
162 |
+
img = F.crop(img, i, j, h, w)
|
163 |
+
gauss_im = F.crop(gauss_im, i, j, h, w)
|
164 |
+
if len(keypoints) > 0:
|
165 |
+
keypoints = keypoints - [j, i]
|
166 |
+
idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
|
167 |
+
(keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
|
168 |
+
keypoints = keypoints[idx_mask]
|
169 |
+
else:
|
170 |
+
keypoints = np.empty([0, 2])
|
171 |
+
|
172 |
+
gt_discrete = gen_discrete_map(h, w, keypoints)
|
173 |
+
down_w = w // self.d_ratio
|
174 |
+
down_h = h // self.d_ratio
|
175 |
+
gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
|
176 |
+
assert np.sum(gt_discrete) == len(keypoints)
|
177 |
+
|
178 |
+
|
179 |
+
if len(keypoints) > 0:
|
180 |
+
if random.random() > 0.5:
|
181 |
+
img = F.hflip(img)
|
182 |
+
gauss_im = F.hflip(gauss_im)
|
183 |
+
gt_discrete = np.fliplr(gt_discrete)
|
184 |
+
keypoints[:, 0] = w - keypoints[:, 0] - 1
|
185 |
+
else:
|
186 |
+
if random.random() > 0.5:
|
187 |
+
img = F.hflip(img)
|
188 |
+
gauss_im = F.hflip(gauss_im)
|
189 |
+
gt_discrete = np.fliplr(gt_discrete)
|
190 |
+
gt_discrete = np.expand_dims(gt_discrete, 0)
|
191 |
+
#import pdb;pdb.set_trace()
|
192 |
+
|
193 |
+
return self.trans(img), gauss_im, torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(gt_discrete.copy()).float()
|
194 |
+
|
195 |
+
|
196 |
+
class Base_UL(data.Dataset):
|
197 |
+
def __init__(self, root_path, crop_size, downsample_ratio=8):
|
198 |
+
self.root_path = root_path
|
199 |
+
self.c_size = crop_size
|
200 |
+
self.d_ratio = downsample_ratio
|
201 |
+
assert self.c_size % self.d_ratio == 0
|
202 |
+
self.dc_size = self.c_size // self.d_ratio
|
203 |
+
self.trans = transforms.Compose([
|
204 |
+
transforms.ToTensor(),
|
205 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
206 |
+
])
|
207 |
+
|
208 |
+
def __len__(self):
|
209 |
+
pass
|
210 |
+
|
211 |
+
def __getitem__(self, item):
|
212 |
+
pass
|
213 |
+
|
214 |
+
def train_transform_ul(self, img):
|
215 |
+
wd, ht = img.size
|
216 |
+
st_size = 1.0 * min(wd, ht)
|
217 |
+
assert st_size >= self.c_size
|
218 |
+
i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
|
219 |
+
img = F.crop(img, i, j, h, w)
|
220 |
+
|
221 |
+
if random.random() > 0.5:
|
222 |
+
img = F.hflip(img)
|
223 |
+
|
224 |
+
return self.trans(img)
|
225 |
+
|
226 |
+
|
227 |
+
class Crowd_UL_TC(Base_UL):
|
228 |
+
def __init__(self, root_path, crop_size, downsample_ratio=8, method='train_ul'):
|
229 |
+
super().__init__(root_path, crop_size, downsample_ratio)
|
230 |
+
self.method = method
|
231 |
+
if method not in ['train_ul']:
|
232 |
+
raise Exception("not implement")
|
233 |
+
|
234 |
+
self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg')))
|
235 |
+
print('number of img [{}]: {}'.format(method, len(self.im_list)))
|
236 |
+
|
237 |
+
def __len__(self):
|
238 |
+
return len(self.im_list)
|
239 |
+
|
240 |
+
def __getitem__(self, item):
|
241 |
+
img_path = self.im_list[item]
|
242 |
+
name = os.path.basename(img_path).split('.')[0]
|
243 |
+
img = Image.open(img_path).convert('RGB')
|
244 |
+
#print("un_label {}", item)
|
245 |
+
|
246 |
+
return self.train_transform_ul(img)
|
247 |
+
|
248 |
+
|
249 |
+
def train_transform_ul(self, img):
|
250 |
+
wd, ht = img.size
|
251 |
+
st_size = 1.0 * min(wd, ht)
|
252 |
+
# resize the image to fit the crop size
|
253 |
+
if st_size < self.c_size:
|
254 |
+
rr = 1.0 * self.c_size / st_size
|
255 |
+
wd = round(wd * rr)
|
256 |
+
ht = round(ht * rr)
|
257 |
+
st_size = 1.0 * min(wd, ht)
|
258 |
+
img = img.resize((wd, ht), Image.BICUBIC)
|
259 |
+
|
260 |
+
assert st_size >= self.c_size, print(wd, ht)
|
261 |
+
|
262 |
+
i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
|
263 |
+
img = F.crop(img, i, j, h, w)
|
264 |
+
if random.random() > 0.5:
|
265 |
+
img = F.hflip(img)
|
266 |
+
|
267 |
+
return self.trans(img),1
|
268 |
+
|
demo.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torchvision import transforms
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
from network import pvt_cls as TCN
|
10 |
+
|
11 |
+
import gradio as gr
|
12 |
+
|
13 |
+
|
14 |
+
def demo(img_path):
|
15 |
+
# config
|
16 |
+
batch_size = 8
|
17 |
+
crop_size = 256
|
18 |
+
model_path = '/users/k21163430/workspace/TreeFormer/models/best_model.pth'
|
19 |
+
|
20 |
+
device = torch.device('cuda')
|
21 |
+
|
22 |
+
# prepare model
|
23 |
+
model = TCN.pvt_treeformer(pretrained=False)
|
24 |
+
model.to(device)
|
25 |
+
model.load_state_dict(torch.load(model_path, device))
|
26 |
+
model.eval()
|
27 |
+
|
28 |
+
# preprocess
|
29 |
+
img = Image.open(img_path).convert('RGB')
|
30 |
+
show_img = np.array(img)
|
31 |
+
wd, ht = img.size
|
32 |
+
st_size = 1.0 * min(wd, ht)
|
33 |
+
if st_size < crop_size:
|
34 |
+
rr = 1.0 * crop_size / st_size
|
35 |
+
wd = round(wd * rr)
|
36 |
+
ht = round(ht * rr)
|
37 |
+
st_size = 1.0 * min(wd, ht)
|
38 |
+
img = img.resize((wd, ht), Image.BICUBIC)
|
39 |
+
transform = transforms.Compose([
|
40 |
+
transforms.ToTensor(),
|
41 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
42 |
+
])
|
43 |
+
img = transform(img)
|
44 |
+
img = img.unsqueeze(0)
|
45 |
+
|
46 |
+
# model forward
|
47 |
+
with torch.no_grad():
|
48 |
+
inputs = img.to(device)
|
49 |
+
crop_imgs, crop_masks = [], []
|
50 |
+
b, c, h, w = inputs.size()
|
51 |
+
rh, rw = crop_size, crop_size
|
52 |
+
|
53 |
+
for i in range(0, h, rh):
|
54 |
+
gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
|
55 |
+
|
56 |
+
for j in range(0, w, rw):
|
57 |
+
gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
|
58 |
+
crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
|
59 |
+
mask = torch.zeros([b, 1, h, w]).to(device)
|
60 |
+
mask[:, :, gis:gie, gjs:gje].fill_(1.0)
|
61 |
+
crop_masks.append(mask)
|
62 |
+
crop_imgs, crop_masks = map(lambda x: torch.cat(
|
63 |
+
x, dim=0), (crop_imgs, crop_masks))
|
64 |
+
|
65 |
+
crop_preds = []
|
66 |
+
nz, bz = crop_imgs.size(0), batch_size
|
67 |
+
for i in range(0, nz, bz):
|
68 |
+
|
69 |
+
gs, gt = i, min(nz, i + bz)
|
70 |
+
crop_pred, _ = model(crop_imgs[gs:gt])
|
71 |
+
crop_pred = crop_pred[0]
|
72 |
+
|
73 |
+
_, _, h1, w1 = crop_pred.size()
|
74 |
+
crop_pred = F.interpolate(crop_pred, size=(
|
75 |
+
h1 * 4, w1 * 4), mode='bilinear', align_corners=True) / 16
|
76 |
+
crop_preds.append(crop_pred)
|
77 |
+
crop_preds = torch.cat(crop_preds, dim=0)
|
78 |
+
|
79 |
+
# splice them to the original size
|
80 |
+
idx = 0
|
81 |
+
pred_map = torch.zeros([b, 1, h, w]).to(device)
|
82 |
+
for i in range(0, h, rh):
|
83 |
+
gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
|
84 |
+
for j in range(0, w, rw):
|
85 |
+
gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
|
86 |
+
pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
|
87 |
+
idx += 1
|
88 |
+
# for the overlapping area, compute average value
|
89 |
+
mask = crop_masks.sum(dim=0).unsqueeze(0)
|
90 |
+
outputs = pred_map / mask
|
91 |
+
|
92 |
+
outputs = F.interpolate(outputs, size=(
|
93 |
+
h, w), mode='bilinear', align_corners=True)/4
|
94 |
+
outputs = pred_map / mask
|
95 |
+
model_output = round(torch.sum(outputs).item())
|
96 |
+
|
97 |
+
print("{}: {}".format(img_path, model_output))
|
98 |
+
outputs = outputs.squeeze().cpu().numpy()
|
99 |
+
outputs = (outputs - np.min(outputs)) / \
|
100 |
+
(np.max(outputs) - np.min(outputs))
|
101 |
+
|
102 |
+
show_img = show_img / 255.0
|
103 |
+
show_img = show_img * 0.2 + outputs[:, :, None] * 0.8
|
104 |
+
|
105 |
+
return model_output, show_img
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
# test
|
110 |
+
# img_path = sys.argv[1]
|
111 |
+
# demo(img)
|
112 |
+
|
113 |
+
# Launch a gr.Interface
|
114 |
+
gr_demo = gr.Interface(fn=demo,
|
115 |
+
inputs=gr.Image(source="upload",
|
116 |
+
type="filepath",
|
117 |
+
label="Input Image",
|
118 |
+
width=768,
|
119 |
+
height=768,
|
120 |
+
),
|
121 |
+
outputs=[
|
122 |
+
gr.Number(label="Predicted Tree Count"),
|
123 |
+
gr.Image(label="Density Map",
|
124 |
+
width=768,
|
125 |
+
height=768,
|
126 |
+
)
|
127 |
+
],
|
128 |
+
title="TreeFormer",
|
129 |
+
description="TreeFormer is a semi-supervised transformer-based framework for tree counting from a single high resolution image. Upload an image and TreeFormer will predict the number of trees in the image and generate a density map of the trees.",
|
130 |
+
article="This work has been developed a spart of the ReSET project which has received funding from the European Union's Horizon 2020 FET Proactive Programme under grant agreement No 101017857. The contents of this publication are the sole responsibility of the ReSET consortium and do not necessarily reflect the opinion of the European Union.",
|
131 |
+
examples=[
|
132 |
+
["./examples/IMG_101.jpg"],
|
133 |
+
["./examples/IMG_125.jpg"],
|
134 |
+
["./examples/IMG_138.jpg"],
|
135 |
+
["./examples/IMG_180.jpg"],
|
136 |
+
["./examples/IMG_18.jpg"],
|
137 |
+
["./examples/IMG_206.jpg"],
|
138 |
+
["./examples/IMG_223.jpg"],
|
139 |
+
["./examples/IMG_247.jpg"],
|
140 |
+
["./examples/IMG_270.jpg"],
|
141 |
+
["./examples/IMG_306.jpg"],
|
142 |
+
],
|
143 |
+
# cache_examples=True,
|
144 |
+
examples_per_page=10,
|
145 |
+
allow_flagging=False,
|
146 |
+
theme=gr.themes.Default(),
|
147 |
+
)
|
148 |
+
gr_demo.launch(share=True, server_port=7861, favicon_path="./assets/reset.png")
|
examples/IMG_101.jpg
ADDED
examples/IMG_125.jpg
ADDED
examples/IMG_138.jpg
ADDED
examples/IMG_18.jpg
ADDED
examples/IMG_180.jpg
ADDED
examples/IMG_206.jpg
ADDED
examples/IMG_223.jpg
ADDED
examples/IMG_247.jpg
ADDED
examples/IMG_270.jpg
ADDED
examples/IMG_306.jpg
ADDED
losses/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
losses/bregman_pytorch.py
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Rewrite ot.bregman.sinkhorn in Python Optimal Transport (https://pythonot.github.io/_modules/ot/bregman.html#sinkhorn)
|
4 |
+
using pytorch operations.
|
5 |
+
Bregman projections for regularized OT (Sinkhorn distance).
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
M_EPS = 1e-16
|
11 |
+
|
12 |
+
|
13 |
+
def sinkhorn(a, b, C, reg=1e-1, method='sinkhorn', maxIter=1000, tau=1e3,
|
14 |
+
stopThr=1e-9, verbose=False, log=True, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
|
15 |
+
"""
|
16 |
+
Solve the entropic regularization optimal transport
|
17 |
+
The input should be PyTorch tensors
|
18 |
+
The function solves the following optimization problem:
|
19 |
+
|
20 |
+
.. math::
|
21 |
+
\gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
|
22 |
+
s.t. \gamma 1 = a
|
23 |
+
\gamma^T 1= b
|
24 |
+
\gamma\geq 0
|
25 |
+
where :
|
26 |
+
- C is the (ns,nt) metric cost matrix
|
27 |
+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
|
28 |
+
- a and b are target and source measures (sum to 1)
|
29 |
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
|
30 |
+
|
31 |
+
Parameters
|
32 |
+
----------
|
33 |
+
a : torch.tensor (na,)
|
34 |
+
samples measure in the target domain
|
35 |
+
b : torch.tensor (nb,)
|
36 |
+
samples in the source domain
|
37 |
+
C : torch.tensor (na,nb)
|
38 |
+
loss matrix
|
39 |
+
reg : float
|
40 |
+
Regularization term > 0
|
41 |
+
method : str
|
42 |
+
method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
|
43 |
+
'sinkhorn_epsilon_scaling', see those function for specific parameters
|
44 |
+
maxIter : int, optional
|
45 |
+
Max number of iterations
|
46 |
+
stopThr : float, optional
|
47 |
+
Stop threshol on error ( > 0 )
|
48 |
+
verbose : bool, optional
|
49 |
+
Print information along iterations
|
50 |
+
log : bool, optional
|
51 |
+
record log if True
|
52 |
+
|
53 |
+
Returns
|
54 |
+
-------
|
55 |
+
gamma : (na x nb) torch.tensor
|
56 |
+
Optimal transportation matrix for the given parameters
|
57 |
+
log : dict
|
58 |
+
log dictionary return only if log==True in parameters
|
59 |
+
|
60 |
+
References
|
61 |
+
----------
|
62 |
+
[1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
|
63 |
+
See Also
|
64 |
+
--------
|
65 |
+
|
66 |
+
"""
|
67 |
+
|
68 |
+
if method.lower() == 'sinkhorn':
|
69 |
+
return sinkhorn_knopp(a, b, C, reg, maxIter=maxIter,
|
70 |
+
stopThr=stopThr, verbose=verbose, log=log,
|
71 |
+
warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
|
72 |
+
**kwargs)
|
73 |
+
elif method.lower() == 'sinkhorn_stabilized':
|
74 |
+
return sinkhorn_stabilized(a, b, C, reg, maxIter=maxIter, tau=tau,
|
75 |
+
stopThr=stopThr, verbose=verbose, log=log,
|
76 |
+
warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
|
77 |
+
**kwargs)
|
78 |
+
elif method.lower() == 'sinkhorn_epsilon_scaling':
|
79 |
+
return sinkhorn_epsilon_scaling(a, b, C, reg,
|
80 |
+
maxIter=maxIter, maxInnerIter=100, tau=tau,
|
81 |
+
scaling_base=0.75, scaling_coef=None, stopThr=stopThr,
|
82 |
+
verbose=False, log=log, warm_start=warm_start, eval_freq=eval_freq,
|
83 |
+
print_freq=print_freq, **kwargs)
|
84 |
+
else:
|
85 |
+
raise ValueError("Unknown method '%s'." % method)
|
86 |
+
|
87 |
+
|
88 |
+
def sinkhorn_knopp(a, b, C, reg=1e-1, maxIter=1000, stopThr=1e-9,
|
89 |
+
verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
|
90 |
+
"""
|
91 |
+
Solve the entropic regularization optimal transport
|
92 |
+
The input should be PyTorch tensors
|
93 |
+
The function solves the following optimization problem:
|
94 |
+
|
95 |
+
.. math::
|
96 |
+
\gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
|
97 |
+
s.t. \gamma 1 = a
|
98 |
+
\gamma^T 1= b
|
99 |
+
\gamma\geq 0
|
100 |
+
where :
|
101 |
+
- C is the (ns,nt) metric cost matrix
|
102 |
+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
|
103 |
+
- a and b are target and source measures (sum to 1)
|
104 |
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
|
105 |
+
|
106 |
+
Parameters
|
107 |
+
----------
|
108 |
+
a : torch.tensor (na,)
|
109 |
+
samples measure in the target domain
|
110 |
+
b : torch.tensor (nb,)
|
111 |
+
samples in the source domain
|
112 |
+
C : torch.tensor (na,nb)
|
113 |
+
loss matrix
|
114 |
+
reg : float
|
115 |
+
Regularization term > 0
|
116 |
+
maxIter : int, optional
|
117 |
+
Max number of iterations
|
118 |
+
stopThr : float, optional
|
119 |
+
Stop threshol on error ( > 0 )
|
120 |
+
verbose : bool, optional
|
121 |
+
Print information along iterations
|
122 |
+
log : bool, optional
|
123 |
+
record log if True
|
124 |
+
|
125 |
+
Returns
|
126 |
+
-------
|
127 |
+
gamma : (na x nb) torch.tensor
|
128 |
+
Optimal transportation matrix for the given parameters
|
129 |
+
log : dict
|
130 |
+
log dictionary return only if log==True in parameters
|
131 |
+
|
132 |
+
References
|
133 |
+
----------
|
134 |
+
[1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
|
135 |
+
See Also
|
136 |
+
--------
|
137 |
+
|
138 |
+
"""
|
139 |
+
|
140 |
+
device = a.device
|
141 |
+
na, nb = C.shape
|
142 |
+
|
143 |
+
assert na >= 1 and nb >= 1, 'C needs to be 2d'
|
144 |
+
assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C"
|
145 |
+
assert reg > 0, 'reg should be greater than 0'
|
146 |
+
assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
|
147 |
+
|
148 |
+
if log:
|
149 |
+
log = {'err': []}
|
150 |
+
|
151 |
+
if warm_start is not None:
|
152 |
+
u = warm_start['u']
|
153 |
+
v = warm_start['v']
|
154 |
+
else:
|
155 |
+
u = torch.ones(na, dtype=a.dtype).to(device) / na
|
156 |
+
v = torch.ones(nb, dtype=b.dtype).to(device) / nb
|
157 |
+
|
158 |
+
K = torch.empty(C.shape, dtype=C.dtype).to(device)
|
159 |
+
torch.div(C, -reg, out=K)
|
160 |
+
torch.exp(K, out=K)
|
161 |
+
|
162 |
+
b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
|
163 |
+
|
164 |
+
it = 1
|
165 |
+
err = 1
|
166 |
+
|
167 |
+
# allocate memory beforehand
|
168 |
+
KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
|
169 |
+
Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
|
170 |
+
|
171 |
+
while (err > stopThr and it <= maxIter):
|
172 |
+
upre, vpre = u, v
|
173 |
+
torch.matmul(u, K, out=KTu)
|
174 |
+
v = torch.div(b, KTu + M_EPS)
|
175 |
+
torch.matmul(K, v, out=Kv)
|
176 |
+
u = torch.div(a, Kv + M_EPS)
|
177 |
+
|
178 |
+
if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
|
179 |
+
torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
|
180 |
+
print('Warning: numerical errors at iteration', it)
|
181 |
+
u, v = upre, vpre
|
182 |
+
break
|
183 |
+
|
184 |
+
if log and it % eval_freq == 0:
|
185 |
+
# we can speed up the process by checking for the error only all
|
186 |
+
# the eval_freq iterations
|
187 |
+
# below is equivalent to:
|
188 |
+
# b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0)
|
189 |
+
# but with more memory efficient
|
190 |
+
b_hat = torch.matmul(u, K) * v
|
191 |
+
err = (b - b_hat).pow(2).sum().item()
|
192 |
+
# err = (b - b_hat).abs().sum().item()
|
193 |
+
log['err'].append(err)
|
194 |
+
|
195 |
+
if verbose and it % print_freq == 0:
|
196 |
+
print('iteration {:5d}, constraint error {:5e}'.format(it, err))
|
197 |
+
|
198 |
+
it += 1
|
199 |
+
|
200 |
+
if log:
|
201 |
+
log['u'] = u
|
202 |
+
log['v'] = v
|
203 |
+
log['alpha'] = reg * torch.log(u + M_EPS)
|
204 |
+
log['beta'] = reg * torch.log(v + M_EPS)
|
205 |
+
|
206 |
+
# transport plan
|
207 |
+
P = u.reshape(-1, 1) * K * v.reshape(1, -1)
|
208 |
+
if log:
|
209 |
+
return P, log
|
210 |
+
else:
|
211 |
+
return P
|
212 |
+
|
213 |
+
|
214 |
+
def sinkhorn_stabilized(a, b, C, reg=1e-1, maxIter=1000, tau=1e3, stopThr=1e-9,
|
215 |
+
verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
|
216 |
+
"""
|
217 |
+
Solve the entropic regularization OT problem with log stabilization
|
218 |
+
The function solves the following optimization problem:
|
219 |
+
|
220 |
+
.. math::
|
221 |
+
\gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
|
222 |
+
s.t. \gamma 1 = a
|
223 |
+
\gamma^T 1= b
|
224 |
+
\gamma\geq 0
|
225 |
+
where :
|
226 |
+
- C is the (ns,nt) metric cost matrix
|
227 |
+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
|
228 |
+
- a and b are target and source measures (sum to 1)
|
229 |
+
|
230 |
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]
|
231 |
+
but with the log stabilization proposed in [3] an defined in [2] (Algo 3.1)
|
232 |
+
|
233 |
+
Parameters
|
234 |
+
----------
|
235 |
+
a : torch.tensor (na,)
|
236 |
+
samples measure in the target domain
|
237 |
+
b : torch.tensor (nb,)
|
238 |
+
samples in the source domain
|
239 |
+
C : torch.tensor (na,nb)
|
240 |
+
loss matrix
|
241 |
+
reg : float
|
242 |
+
Regularization term > 0
|
243 |
+
tau : float
|
244 |
+
thershold for max value in u or v for log scaling
|
245 |
+
maxIter : int, optional
|
246 |
+
Max number of iterations
|
247 |
+
stopThr : float, optional
|
248 |
+
Stop threshol on error ( > 0 )
|
249 |
+
verbose : bool, optional
|
250 |
+
Print information along iterations
|
251 |
+
log : bool, optional
|
252 |
+
record log if True
|
253 |
+
|
254 |
+
Returns
|
255 |
+
-------
|
256 |
+
gamma : (na x nb) torch.tensor
|
257 |
+
Optimal transportation matrix for the given parameters
|
258 |
+
log : dict
|
259 |
+
log dictionary return only if log==True in parameters
|
260 |
+
|
261 |
+
References
|
262 |
+
----------
|
263 |
+
[1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
|
264 |
+
[2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019
|
265 |
+
[3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
|
266 |
+
|
267 |
+
See Also
|
268 |
+
--------
|
269 |
+
|
270 |
+
"""
|
271 |
+
|
272 |
+
device = a.device
|
273 |
+
na, nb = C.shape
|
274 |
+
|
275 |
+
assert na >= 1 and nb >= 1, 'C needs to be 2d'
|
276 |
+
assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C"
|
277 |
+
assert reg > 0, 'reg should be greater than 0'
|
278 |
+
assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
|
279 |
+
|
280 |
+
if log:
|
281 |
+
log = {'err': []}
|
282 |
+
|
283 |
+
if warm_start is not None:
|
284 |
+
alpha = warm_start['alpha']
|
285 |
+
beta = warm_start['beta']
|
286 |
+
else:
|
287 |
+
alpha = torch.zeros(na, dtype=a.dtype).to(device)
|
288 |
+
beta = torch.zeros(nb, dtype=b.dtype).to(device)
|
289 |
+
|
290 |
+
u = torch.ones(na, dtype=a.dtype).to(device) / na
|
291 |
+
v = torch.ones(nb, dtype=b.dtype).to(device) / nb
|
292 |
+
|
293 |
+
def update_K(alpha, beta):
|
294 |
+
"""log space computation"""
|
295 |
+
"""memory efficient"""
|
296 |
+
torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=K)
|
297 |
+
torch.add(K, -C, out=K)
|
298 |
+
torch.div(K, reg, out=K)
|
299 |
+
torch.exp(K, out=K)
|
300 |
+
|
301 |
+
def update_P(alpha, beta, u, v, ab_updated=False):
|
302 |
+
"""log space P (gamma) computation"""
|
303 |
+
torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=P)
|
304 |
+
torch.add(P, -C, out=P)
|
305 |
+
torch.div(P, reg, out=P)
|
306 |
+
if not ab_updated:
|
307 |
+
torch.add(P, torch.log(u + M_EPS).reshape(-1, 1), out=P)
|
308 |
+
torch.add(P, torch.log(v + M_EPS).reshape(1, -1), out=P)
|
309 |
+
torch.exp(P, out=P)
|
310 |
+
|
311 |
+
K = torch.empty(C.shape, dtype=C.dtype).to(device)
|
312 |
+
update_K(alpha, beta)
|
313 |
+
|
314 |
+
b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
|
315 |
+
|
316 |
+
it = 1
|
317 |
+
err = 1
|
318 |
+
ab_updated = False
|
319 |
+
|
320 |
+
# allocate memory beforehand
|
321 |
+
KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
|
322 |
+
Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
|
323 |
+
P = torch.empty(C.shape, dtype=C.dtype).to(device)
|
324 |
+
|
325 |
+
while (err > stopThr and it <= maxIter):
|
326 |
+
upre, vpre = u, v
|
327 |
+
torch.matmul(u, K, out=KTu)
|
328 |
+
v = torch.div(b, KTu + M_EPS)
|
329 |
+
torch.matmul(K, v, out=Kv)
|
330 |
+
u = torch.div(a, Kv + M_EPS)
|
331 |
+
|
332 |
+
ab_updated = False
|
333 |
+
# remove numerical problems and store them in K
|
334 |
+
if u.abs().sum() > tau or v.abs().sum() > tau:
|
335 |
+
alpha += reg * torch.log(u + M_EPS)
|
336 |
+
beta += reg * torch.log(v + M_EPS)
|
337 |
+
u.fill_(1. / na)
|
338 |
+
v.fill_(1. / nb)
|
339 |
+
update_K(alpha, beta)
|
340 |
+
ab_updated = True
|
341 |
+
|
342 |
+
if log and it % eval_freq == 0:
|
343 |
+
# we can speed up the process by checking for the error only all
|
344 |
+
# the eval_freq iterations
|
345 |
+
update_P(alpha, beta, u, v, ab_updated)
|
346 |
+
b_hat = torch.sum(P, 0)
|
347 |
+
err = (b - b_hat).pow(2).sum().item()
|
348 |
+
log['err'].append(err)
|
349 |
+
|
350 |
+
if verbose and it % print_freq == 0:
|
351 |
+
print('iteration {:5d}, constraint error {:5e}'.format(it, err))
|
352 |
+
|
353 |
+
it += 1
|
354 |
+
|
355 |
+
if log:
|
356 |
+
log['u'] = u
|
357 |
+
log['v'] = v
|
358 |
+
log['alpha'] = alpha + reg * torch.log(u + M_EPS)
|
359 |
+
log['beta'] = beta + reg * torch.log(v + M_EPS)
|
360 |
+
|
361 |
+
# transport plan
|
362 |
+
update_P(alpha, beta, u, v, False)
|
363 |
+
|
364 |
+
if log:
|
365 |
+
return P, log
|
366 |
+
else:
|
367 |
+
return P
|
368 |
+
|
369 |
+
|
370 |
+
def sinkhorn_epsilon_scaling(a, b, C, reg=1e-1, maxIter=100, maxInnerIter=100, tau=1e3, scaling_base=0.75,
|
371 |
+
scaling_coef=None, stopThr=1e-9, verbose=False, log=False, warm_start=None, eval_freq=10,
|
372 |
+
print_freq=200, **kwargs):
|
373 |
+
"""
|
374 |
+
Solve the entropic regularization OT problem with log stabilization
|
375 |
+
The function solves the following optimization problem:
|
376 |
+
|
377 |
+
.. math::
|
378 |
+
\gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
|
379 |
+
s.t. \gamma 1 = a
|
380 |
+
\gamma^T 1= b
|
381 |
+
\gamma\geq 0
|
382 |
+
where :
|
383 |
+
- C is the (ns,nt) metric cost matrix
|
384 |
+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
|
385 |
+
- a and b are target and source measures (sum to 1)
|
386 |
+
|
387 |
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
|
388 |
+
scaling algorithm as proposed in [1] but with the log stabilization
|
389 |
+
proposed in [3] and the log scaling proposed in [2] algorithm 3.2
|
390 |
+
|
391 |
+
Parameters
|
392 |
+
----------
|
393 |
+
a : torch.tensor (na,)
|
394 |
+
samples measure in the target domain
|
395 |
+
b : torch.tensor (nb,)
|
396 |
+
samples in the source domain
|
397 |
+
C : torch.tensor (na,nb)
|
398 |
+
loss matrix
|
399 |
+
reg : float
|
400 |
+
Regularization term > 0
|
401 |
+
tau : float
|
402 |
+
thershold for max value in u or v for log scaling
|
403 |
+
maxIter : int, optional
|
404 |
+
Max number of iterations
|
405 |
+
stopThr : float, optional
|
406 |
+
Stop threshol on error ( > 0 )
|
407 |
+
verbose : bool, optional
|
408 |
+
Print information along iterations
|
409 |
+
log : bool, optional
|
410 |
+
record log if True
|
411 |
+
|
412 |
+
Returns
|
413 |
+
-------
|
414 |
+
gamma : (na x nb) torch.tensor
|
415 |
+
Optimal transportation matrix for the given parameters
|
416 |
+
log : dict
|
417 |
+
log dictionary return only if log==True in parameters
|
418 |
+
|
419 |
+
References
|
420 |
+
----------
|
421 |
+
[1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
|
422 |
+
[2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019
|
423 |
+
[3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
|
424 |
+
|
425 |
+
See Also
|
426 |
+
--------
|
427 |
+
|
428 |
+
"""
|
429 |
+
|
430 |
+
na, nb = C.shape
|
431 |
+
|
432 |
+
assert na >= 1 and nb >= 1, 'C needs to be 2d'
|
433 |
+
assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C"
|
434 |
+
assert reg > 0, 'reg should be greater than 0'
|
435 |
+
assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
|
436 |
+
|
437 |
+
def get_reg(it, reg, pre_reg):
|
438 |
+
if it == 1:
|
439 |
+
return scaling_coef
|
440 |
+
else:
|
441 |
+
if (pre_reg - reg) * scaling_base < M_EPS:
|
442 |
+
return reg
|
443 |
+
else:
|
444 |
+
return (pre_reg - reg) * scaling_base + reg
|
445 |
+
|
446 |
+
if scaling_coef is None:
|
447 |
+
scaling_coef = C.max() + reg
|
448 |
+
|
449 |
+
it = 1
|
450 |
+
err = 1
|
451 |
+
running_reg = scaling_coef
|
452 |
+
|
453 |
+
if log:
|
454 |
+
log = {'err': []}
|
455 |
+
|
456 |
+
warm_start = None
|
457 |
+
|
458 |
+
while (err > stopThr and it <= maxIter):
|
459 |
+
running_reg = get_reg(it, reg, running_reg)
|
460 |
+
P, _log = sinkhorn_stabilized(a, b, C, running_reg, maxIter=maxInnerIter, tau=tau,
|
461 |
+
stopThr=stopThr, verbose=False, log=True,
|
462 |
+
warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
|
463 |
+
**kwargs)
|
464 |
+
|
465 |
+
warm_start = {}
|
466 |
+
warm_start['alpha'] = _log['alpha']
|
467 |
+
warm_start['beta'] = _log['beta']
|
468 |
+
|
469 |
+
primal_val = (C * P).sum() + reg * (P * torch.log(P)).sum() - reg * P.sum()
|
470 |
+
dual_val = (_log['alpha'] * a).sum() + (_log['beta'] * b).sum() - reg * P.sum()
|
471 |
+
err = primal_val - dual_val
|
472 |
+
log['err'].append(err)
|
473 |
+
|
474 |
+
if verbose and it % print_freq == 0:
|
475 |
+
print('iteration {:5d}, constraint error {:5e}'.format(it, err))
|
476 |
+
|
477 |
+
it += 1
|
478 |
+
|
479 |
+
if log:
|
480 |
+
log['alpha'] = _log['alpha']
|
481 |
+
log['beta'] = _log['beta']
|
482 |
+
return P, log
|
483 |
+
else:
|
484 |
+
return P
|
losses/consistency_loss.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.nn as nn
|
5 |
+
from losses import ramps
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
class consistency_weight(object):
|
10 |
+
"""
|
11 |
+
ramp_types = ['sigmoid_rampup', 'linear_rampup', 'cosine_rampup', 'log_rampup', 'exp_rampup']
|
12 |
+
"""
|
13 |
+
def __init__(self, final_w, iters_per_epoch, rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup'):
|
14 |
+
self.final_w = final_w
|
15 |
+
self.iters_per_epoch = iters_per_epoch
|
16 |
+
self.rampup_starts = rampup_starts * iters_per_epoch
|
17 |
+
self.rampup_ends = rampup_ends * iters_per_epoch
|
18 |
+
self.rampup_length = (self.rampup_ends - self.rampup_starts)
|
19 |
+
self.rampup_func = getattr(ramps, ramp_type)
|
20 |
+
self.current_rampup = 0
|
21 |
+
|
22 |
+
def __call__(self, epoch, curr_iter):
|
23 |
+
cur_total_iter = self.iters_per_epoch * epoch + curr_iter
|
24 |
+
if cur_total_iter < self.rampup_starts:
|
25 |
+
return 0
|
26 |
+
self.current_rampup = self.rampup_func(cur_total_iter - self.rampup_starts, self.rampup_length)
|
27 |
+
return self.final_w * self.current_rampup
|
28 |
+
|
29 |
+
|
30 |
+
def CE_loss(input_logits, target_targets, ignore_index, temperature=1):
|
31 |
+
return F.cross_entropy(input_logits/temperature, target_targets, ignore_index=ignore_index)
|
32 |
+
|
33 |
+
# for FocalLoss
|
34 |
+
def softmax_helper(x):
|
35 |
+
# copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
|
36 |
+
rpt = [1 for _ in range(len(x.size()))]
|
37 |
+
rpt[1] = x.size(1)
|
38 |
+
x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
|
39 |
+
e_x = torch.exp(x - x_max)
|
40 |
+
return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
|
41 |
+
|
42 |
+
def get_alpha(supervised_loader):
|
43 |
+
# get number of classes
|
44 |
+
num_labels = 0
|
45 |
+
for image_batch, label_batch in supervised_loader:
|
46 |
+
label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
|
47 |
+
l_unique = torch.unique(label_batch.data)
|
48 |
+
list_unique = [element.item() for element in l_unique.flatten()]
|
49 |
+
num_labels = max(max(list_unique),num_labels)
|
50 |
+
num_classes = num_labels + 1
|
51 |
+
# count class occurrences
|
52 |
+
alpha = [0 for i in range(num_classes)]
|
53 |
+
for image_batch, label_batch in supervised_loader:
|
54 |
+
label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
|
55 |
+
l_unique = torch.unique(label_batch.data)
|
56 |
+
list_unique = [element.item() for element in l_unique.flatten()]
|
57 |
+
l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480])
|
58 |
+
list_count = [count.item() for count in l_unique_count.flatten()]
|
59 |
+
for index in list_unique:
|
60 |
+
alpha[index] += list_count[list_unique.index(index)]
|
61 |
+
return alpha
|
62 |
+
|
63 |
+
# for FocalLoss
|
64 |
+
def softmax_helper(x):
|
65 |
+
# copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
|
66 |
+
rpt = [1 for _ in range(len(x.size()))]
|
67 |
+
rpt[1] = x.size(1)
|
68 |
+
x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
|
69 |
+
e_x = torch.exp(x - x_max)
|
70 |
+
return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
|
71 |
+
|
72 |
+
|
73 |
+
class FocalLoss(nn.Module):
|
74 |
+
"""
|
75 |
+
copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
|
76 |
+
This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
|
77 |
+
'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
|
78 |
+
Focal_Loss= -1*alpha*(1-pt)*log(pt)
|
79 |
+
:param num_class:
|
80 |
+
:param alpha: (tensor) 3D or 4D the scalar factor for this criterion
|
81 |
+
:param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
|
82 |
+
focus on hard misclassified example
|
83 |
+
:param smooth: (float,double) smooth value when cross entropy
|
84 |
+
:param balance_index: (int) balance class index, should be specific when alpha is float
|
85 |
+
:param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self, apply_nonlin=None, ignore_index = None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
|
89 |
+
super(FocalLoss, self).__init__()
|
90 |
+
self.apply_nonlin = apply_nonlin
|
91 |
+
self.alpha = alpha
|
92 |
+
self.gamma = gamma
|
93 |
+
self.balance_index = balance_index
|
94 |
+
self.smooth = smooth
|
95 |
+
self.size_average = size_average
|
96 |
+
|
97 |
+
if self.smooth is not None:
|
98 |
+
if self.smooth < 0 or self.smooth > 1.0:
|
99 |
+
raise ValueError('smooth value should be in [0,1]')
|
100 |
+
|
101 |
+
def forward(self, logit, target):
|
102 |
+
if self.apply_nonlin is not None:
|
103 |
+
logit = self.apply_nonlin(logit)
|
104 |
+
num_class = logit.shape[1]
|
105 |
+
|
106 |
+
if logit.dim() > 2:
|
107 |
+
# N,C,d1,d2 -> N,C,m (m=d1*d2*...)
|
108 |
+
logit = logit.view(logit.size(0), logit.size(1), -1)
|
109 |
+
logit = logit.permute(0, 2, 1).contiguous()
|
110 |
+
logit = logit.view(-1, logit.size(-1))
|
111 |
+
target = torch.squeeze(target, 1)
|
112 |
+
target = target.view(-1, 1)
|
113 |
+
|
114 |
+
valid_mask = None
|
115 |
+
if self.ignore_index is not None:
|
116 |
+
valid_mask = target != self.ignore_index
|
117 |
+
target = target * valid_mask
|
118 |
+
|
119 |
+
alpha = self.alpha
|
120 |
+
|
121 |
+
if alpha is None:
|
122 |
+
alpha = torch.ones(num_class, 1)
|
123 |
+
elif isinstance(alpha, (list, np.ndarray)):
|
124 |
+
assert len(alpha) == num_class
|
125 |
+
alpha = torch.FloatTensor(alpha).view(num_class, 1)
|
126 |
+
alpha = alpha / alpha.sum()
|
127 |
+
alpha = 1/alpha # inverse of class frequency
|
128 |
+
elif isinstance(alpha, float):
|
129 |
+
alpha = torch.ones(num_class, 1)
|
130 |
+
alpha = alpha * (1 - self.alpha)
|
131 |
+
alpha[self.balance_index] = self.alpha
|
132 |
+
|
133 |
+
else:
|
134 |
+
raise TypeError('Not support alpha type')
|
135 |
+
|
136 |
+
if alpha.device != logit.device:
|
137 |
+
alpha = alpha.to(logit.device)
|
138 |
+
|
139 |
+
idx = target.cpu().long()
|
140 |
+
|
141 |
+
one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
|
142 |
+
|
143 |
+
# to resolve error in idx in scatter_
|
144 |
+
idx[idx==225]=0
|
145 |
+
|
146 |
+
one_hot_key = one_hot_key.scatter_(1, idx, 1)
|
147 |
+
if one_hot_key.device != logit.device:
|
148 |
+
one_hot_key = one_hot_key.to(logit.device)
|
149 |
+
|
150 |
+
if self.smooth:
|
151 |
+
one_hot_key = torch.clamp(
|
152 |
+
one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
|
153 |
+
pt = (one_hot_key * logit).sum(1) + self.smooth
|
154 |
+
logpt = pt.log()
|
155 |
+
|
156 |
+
gamma = self.gamma
|
157 |
+
|
158 |
+
alpha = alpha[idx]
|
159 |
+
alpha = torch.squeeze(alpha)
|
160 |
+
loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
|
161 |
+
|
162 |
+
if valid_mask is not None:
|
163 |
+
loss = loss * valid_mask.squeeze()
|
164 |
+
|
165 |
+
if self.size_average:
|
166 |
+
loss = loss.mean()
|
167 |
+
else:
|
168 |
+
loss = loss.sum()
|
169 |
+
return loss
|
170 |
+
|
171 |
+
|
172 |
+
class abCE_loss(nn.Module):
|
173 |
+
"""
|
174 |
+
Annealed-Bootstrapped cross-entropy loss
|
175 |
+
"""
|
176 |
+
def __init__(self, iters_per_epoch, epochs, num_classes, weight=None,
|
177 |
+
reduction='mean', thresh=0.7, min_kept=1, ramp_type='log_rampup'):
|
178 |
+
super(abCE_loss, self).__init__()
|
179 |
+
self.weight = torch.FloatTensor(weight) if weight is not None else weight
|
180 |
+
self.reduction = reduction
|
181 |
+
self.thresh = thresh
|
182 |
+
self.min_kept = min_kept
|
183 |
+
self.ramp_type = ramp_type
|
184 |
+
|
185 |
+
if ramp_type is not None:
|
186 |
+
self.rampup_func = getattr(ramps, ramp_type)
|
187 |
+
self.iters_per_epoch = iters_per_epoch
|
188 |
+
self.num_classes = num_classes
|
189 |
+
self.start = 1/num_classes
|
190 |
+
self.end = 0.9
|
191 |
+
self.total_num_iters = (epochs - (0.6 * epochs)) * iters_per_epoch
|
192 |
+
|
193 |
+
def threshold(self, curr_iter, epoch):
|
194 |
+
cur_total_iter = self.iters_per_epoch * epoch + curr_iter
|
195 |
+
current_rampup = self.rampup_func(cur_total_iter, self.total_num_iters)
|
196 |
+
return current_rampup * (self.end - self.start) + self.start
|
197 |
+
|
198 |
+
def forward(self, predict, target, ignore_index, curr_iter, epoch):
|
199 |
+
batch_kept = self.min_kept * target.size(0)
|
200 |
+
prob_out = F.softmax(predict, dim=1)
|
201 |
+
tmp_target = target.clone()
|
202 |
+
tmp_target[tmp_target == ignore_index] = 0
|
203 |
+
prob = prob_out.gather(1, tmp_target.unsqueeze(1))
|
204 |
+
mask = target.contiguous().view(-1, ) != ignore_index
|
205 |
+
sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort()
|
206 |
+
|
207 |
+
if self.ramp_type is not None:
|
208 |
+
thresh = self.threshold(curr_iter=curr_iter, epoch=epoch)
|
209 |
+
else:
|
210 |
+
thresh = self.thresh
|
211 |
+
|
212 |
+
min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0
|
213 |
+
threshold = max(min_threshold, thresh)
|
214 |
+
loss_matrix = F.cross_entropy(predict, target,
|
215 |
+
weight=self.weight.to(predict.device) if self.weight is not None else None,
|
216 |
+
ignore_index=ignore_index, reduction='none')
|
217 |
+
loss_matirx = loss_matrix.contiguous().view(-1, )
|
218 |
+
sort_loss_matirx = loss_matirx[mask][sort_indices]
|
219 |
+
select_loss_matrix = sort_loss_matirx[sort_prob < threshold]
|
220 |
+
if self.reduction == 'sum' or select_loss_matrix.numel() == 0:
|
221 |
+
return select_loss_matrix.sum()
|
222 |
+
elif self.reduction == 'mean':
|
223 |
+
return select_loss_matrix.mean()
|
224 |
+
else:
|
225 |
+
raise NotImplementedError('Reduction Error!')
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
|
230 |
+
assert inputs.requires_grad == True and targets.requires_grad == False
|
231 |
+
assert inputs.size() == targets.size() # (batch_size * num_classes * H * W)
|
232 |
+
inputs = F.softmax(inputs, dim=1)
|
233 |
+
if use_softmax:
|
234 |
+
targets = F.softmax(targets, dim=1)
|
235 |
+
|
236 |
+
if conf_mask:
|
237 |
+
loss_mat = F.mse_loss(inputs, targets, reduction='none')
|
238 |
+
mask = (targets.max(1)[0] > threshold)
|
239 |
+
loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
|
240 |
+
if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
|
241 |
+
return loss_mat.mean()
|
242 |
+
else:
|
243 |
+
return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size
|
244 |
+
|
245 |
+
|
246 |
+
def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
|
247 |
+
assert inputs.requires_grad == True and targets.requires_grad == False
|
248 |
+
assert inputs.size() == targets.size()
|
249 |
+
|
250 |
+
if use_softmax:
|
251 |
+
targets = F.softmax(targets, dim=1)
|
252 |
+
if conf_mask:
|
253 |
+
loss_mat = F.kl_div(input_log_softmax, targets, reduction='none')
|
254 |
+
mask = (targets.max(1)[0] > threshold)
|
255 |
+
loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
|
256 |
+
if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
|
257 |
+
return loss_mat.sum() / mask.shape.numel()
|
258 |
+
else:
|
259 |
+
return F.kl_div(inputs, targets, reduction='mean')
|
260 |
+
|
261 |
+
|
262 |
+
def softmax_js_loss(inputs, targets, **_):
|
263 |
+
assert inputs.requires_grad == True and targets.requires_grad == False
|
264 |
+
assert inputs.size() == targets.size()
|
265 |
+
epsilon = 1e-5
|
266 |
+
|
267 |
+
M = (F.softmax(inputs, dim=1) + targets) * 0.5
|
268 |
+
kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean')
|
269 |
+
kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean')
|
270 |
+
return (kl1 + kl2) * 0.5
|
271 |
+
|
272 |
+
|
273 |
+
|
274 |
+
def pair_wise_loss(unsup_outputs, size_average=True, nbr_of_pairs=8):
|
275 |
+
"""
|
276 |
+
Pair-wise loss in the sup. mat.
|
277 |
+
"""
|
278 |
+
if isinstance(unsup_outputs, list):
|
279 |
+
unsup_outputs = torch.stack(unsup_outputs)
|
280 |
+
|
281 |
+
# Only for a subset of the aux outputs to reduce computation and memory
|
282 |
+
unsup_outputs = unsup_outputs[torch.randperm(unsup_outputs.size(0))]
|
283 |
+
unsup_outputs = unsup_outputs[:nbr_of_pairs]
|
284 |
+
|
285 |
+
temp = torch.zeros_like(unsup_outputs) # For grad purposes
|
286 |
+
for i, u in enumerate(unsup_outputs):
|
287 |
+
temp[i] = F.softmax(u, dim=1)
|
288 |
+
mean_prediction = temp.mean(0).unsqueeze(0) # Mean over the auxiliary outputs
|
289 |
+
pw_loss = ((temp - mean_prediction)**2).mean(0) # Variance
|
290 |
+
pw_loss = pw_loss.sum(1) # Sum over classes
|
291 |
+
if size_average:
|
292 |
+
return pw_loss.mean()
|
293 |
+
return pw_loss.sum()
|
294 |
+
|
losses/dm_loss.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from losses.consistency_loss import *
|
5 |
+
from losses.ot_loss import OT_Loss
|
6 |
+
|
7 |
+
class DMLoss(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(DMLoss, self).__init__()
|
10 |
+
self.DMLoss = 0.0
|
11 |
+
self.losses = {}
|
12 |
+
|
13 |
+
def forward(self, results, points, gt_discrete):
|
14 |
+
self.DMLoss = 0.0
|
15 |
+
self.losses = {}
|
16 |
+
|
17 |
+
if results is None:
|
18 |
+
self.DMLoss = 0.0
|
19 |
+
elif isinstance(results, list) and len(results) > 0:
|
20 |
+
count = 0
|
21 |
+
for i in range(len(results[0])):
|
22 |
+
with torch.set_grad_enabled(False):
|
23 |
+
preds_mean = (results[0][i])/len(results[0][0][0])
|
24 |
+
|
25 |
+
for j in range(len(results)):
|
26 |
+
var_sel = softmax_kl_loss(results[j][i], preds_mean)
|
27 |
+
exp_var = torch.exp(-var_sel)
|
28 |
+
consistency_dist = (preds_mean - results[j][i]) ** 2
|
29 |
+
temploss = (torch.mean(consistency_dist * exp_var) /(exp_var + 1e-8) + var_sel)
|
30 |
+
|
31 |
+
self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
|
32 |
+
self.DMLoss += temploss
|
33 |
+
|
34 |
+
# Compute counting loss.
|
35 |
+
count_loss = self.mae(outputs_L[0].sum(1).sum(1).sum(1),
|
36 |
+
torch.from_numpy(gd_count).float().to(self.device))*self.args.reg
|
37 |
+
epoch_count_loss.update(count_loss.item(), N)
|
38 |
+
|
39 |
+
# Compute OT loss.
|
40 |
+
ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs_L[0], points)
|
41 |
+
|
42 |
+
ot_loss = ot_loss * self.args.ot
|
43 |
+
ot_obj_value = ot_obj_value * self.args.ot
|
44 |
+
epoch_ot_loss.update(ot_loss.item(), N)
|
45 |
+
epoch_ot_obj_value.update(ot_obj_value.item(), N)
|
46 |
+
epoch_wd.update(wd, N)
|
47 |
+
|
48 |
+
gd_count_tensor = (torch.from_numpy(gd_count).float()
|
49 |
+
.to(self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3))
|
50 |
+
|
51 |
+
gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
|
52 |
+
tv_loss = (self.tvloss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)*
|
53 |
+
torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.tv
|
54 |
+
epoch_tv_loss.update(tv_loss.item(), N)
|
55 |
+
|
56 |
+
count += 1
|
57 |
+
if count > 0:
|
58 |
+
self.multiconloss = self.multiconloss / count
|
59 |
+
|
60 |
+
|
61 |
+
return self.multiconloss
|
62 |
+
|
losses/multi_con_loss.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from losses.consistency_loss import *
|
5 |
+
|
6 |
+
|
7 |
+
class MultiConLoss(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(MultiConLoss, self).__init__()
|
10 |
+
self.countloss_criterion = nn.MSELoss(reduction='sum')
|
11 |
+
self.multiconloss = 0.0
|
12 |
+
self.losses = {}
|
13 |
+
|
14 |
+
def forward(self, unlabeled_results):
|
15 |
+
self.multiconloss = 0.0
|
16 |
+
self.losses = {}
|
17 |
+
|
18 |
+
if unlabeled_results is None:
|
19 |
+
self.multiconloss = 0.0
|
20 |
+
elif isinstance(unlabeled_results, list) and len(unlabeled_results) > 0:
|
21 |
+
count = 0
|
22 |
+
for i in range(len(unlabeled_results[0])):
|
23 |
+
with torch.set_grad_enabled(False):
|
24 |
+
preds_mean = (unlabeled_results[0][i] + unlabeled_results[1][i] + unlabeled_results[2][i])/len(unlabeled_results)
|
25 |
+
for j in range(len(unlabeled_results)):
|
26 |
+
|
27 |
+
var_sel = softmax_kl_loss(unlabeled_results[j][i], preds_mean)
|
28 |
+
exp_var = torch.exp(-var_sel)
|
29 |
+
consistency_dist = (preds_mean - unlabeled_results[j][i]) ** 2
|
30 |
+
temploss = (torch.mean(consistency_dist * exp_var) /(exp_var + 1e-8) + var_sel)
|
31 |
+
|
32 |
+
self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
|
33 |
+
self.multiconloss += temploss
|
34 |
+
|
35 |
+
count += 1
|
36 |
+
if count > 0:
|
37 |
+
self.multiconloss = self.multiconloss / count
|
38 |
+
|
39 |
+
|
40 |
+
return self.multiconloss
|
41 |
+
|
losses/ot_loss.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import Module
|
3 |
+
from .bregman_pytorch import sinkhorn
|
4 |
+
|
5 |
+
class OT_Loss(Module):
|
6 |
+
def __init__(self, c_size, stride, norm_cood, device, num_of_iter_in_ot=100, reg=10.0):
|
7 |
+
super(OT_Loss, self).__init__()
|
8 |
+
assert c_size % stride == 0
|
9 |
+
|
10 |
+
self.c_size = c_size
|
11 |
+
self.device = device
|
12 |
+
self.norm_cood = norm_cood
|
13 |
+
self.num_of_iter_in_ot = num_of_iter_in_ot
|
14 |
+
self.reg = reg
|
15 |
+
|
16 |
+
# coordinate is same to image space, set to constant since crop size is same
|
17 |
+
self.cood = torch.arange(0, c_size, step=stride,
|
18 |
+
dtype=torch.float32, device=device) + stride / 2
|
19 |
+
self.density_size = self.cood.size(0)
|
20 |
+
self.cood.unsqueeze_(0) # [1, #cood]
|
21 |
+
if self.norm_cood:
|
22 |
+
self.cood = self.cood / c_size * 2 - 1 # map to [-1, 1]
|
23 |
+
self.output_size = self.cood.size(1)
|
24 |
+
|
25 |
+
|
26 |
+
def forward(self, normed_density, unnormed_density, points):
|
27 |
+
batch_size = normed_density.size(0)
|
28 |
+
assert len(points) == batch_size
|
29 |
+
assert self.output_size == normed_density.size(2)
|
30 |
+
loss = torch.zeros([1]).to(self.device)
|
31 |
+
ot_obj_values = torch.zeros([1]).to(self.device)
|
32 |
+
wd = 0 # wasserstain distance
|
33 |
+
for idx, im_points in enumerate(points):
|
34 |
+
if len(im_points) > 0:
|
35 |
+
# compute l2 square distance, it should be source target distance. [#gt, #cood * #cood]
|
36 |
+
if self.norm_cood:
|
37 |
+
im_points = im_points / self.c_size * 2 - 1 # map to [-1, 1]
|
38 |
+
x = im_points[:, 0].unsqueeze_(1) # [#gt, 1]
|
39 |
+
y = im_points[:, 1].unsqueeze_(1)
|
40 |
+
x_dis = -2 * torch.matmul(x, self.cood) + x * x + self.cood * self.cood # [#gt, #cood]
|
41 |
+
y_dis = -2 * torch.matmul(y, self.cood) + y * y + self.cood * self.cood
|
42 |
+
y_dis.unsqueeze_(2)
|
43 |
+
x_dis.unsqueeze_(1)
|
44 |
+
dis = y_dis + x_dis
|
45 |
+
dis = dis.view((dis.size(0), -1)) # size of [#gt, #cood * #cood]
|
46 |
+
|
47 |
+
source_prob = normed_density[idx][0].view([-1]).detach()
|
48 |
+
target_prob = (torch.ones([len(im_points)]) / len(im_points)).to(self.device)
|
49 |
+
|
50 |
+
# use sinkhorn to solve OT, compute optimal beta.
|
51 |
+
P, log = sinkhorn(target_prob, source_prob, dis, self.reg, maxIter=self.num_of_iter_in_ot, log=True)
|
52 |
+
beta = log['beta'] # size is the same as source_prob: [#cood * #cood]
|
53 |
+
ot_obj_values += torch.sum(normed_density[idx] * beta.view([1, self.output_size, self.output_size]))
|
54 |
+
# compute the gradient of OT loss to predicted density (unnormed_density).
|
55 |
+
# im_grad = beta / source_count - < beta, source_density> / (source_count)^2
|
56 |
+
source_density = unnormed_density[idx][0].view([-1]).detach()
|
57 |
+
source_count = source_density.sum()
|
58 |
+
im_grad_1 = (source_count) / (source_count * source_count+1e-8) * beta # size of [#cood * #cood]
|
59 |
+
im_grad_2 = (source_density * beta).sum() / (source_count * source_count + 1e-8) # size of 1
|
60 |
+
im_grad = im_grad_1 - im_grad_2
|
61 |
+
im_grad = im_grad.detach().view([1, self.output_size, self.output_size])
|
62 |
+
# Define loss = <im_grad, predicted density>. The gradient of loss w.r.t prediced density is im_grad.
|
63 |
+
loss += torch.sum(unnormed_density[idx] * im_grad)
|
64 |
+
wd += torch.sum(dis * P).item()
|
65 |
+
|
66 |
+
return loss, wd, ot_obj_values
|
67 |
+
|
68 |
+
|
losses/ramps.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2018, Curious AI Ltd. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Functions for ramping hyperparameters up or down
|
9 |
+
|
10 |
+
Each function takes the current training step or epoch, and the
|
11 |
+
ramp length in the same format, and returns a multiplier between
|
12 |
+
0 and 1.
|
13 |
+
"""
|
14 |
+
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
|
19 |
+
def sigmoid_rampup(current, rampup_length):
|
20 |
+
"""Exponential rampup from https://arxiv.org/abs/1610.02242"""
|
21 |
+
if rampup_length == 0:
|
22 |
+
return 1.0
|
23 |
+
else:
|
24 |
+
current = np.clip(current, 0.0, rampup_length)
|
25 |
+
phase = 1.0 - current / rampup_length
|
26 |
+
return float(np.exp(-5.0 * phase * phase))
|
27 |
+
|
28 |
+
|
29 |
+
def linear_rampup(current, rampup_length):
|
30 |
+
"""Linear rampup"""
|
31 |
+
assert current >= 0 and rampup_length >= 0
|
32 |
+
if current >= rampup_length:
|
33 |
+
return 1.0
|
34 |
+
else:
|
35 |
+
return current / rampup_length
|
36 |
+
|
37 |
+
|
38 |
+
def cosine_rampdown(current, rampdown_length):
|
39 |
+
"""Cosine rampdown from https://arxiv.org/abs/1608.03983"""
|
40 |
+
assert 0 <= current <= rampdown_length
|
41 |
+
return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
|
losses/rank_loss.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class MarginRankLoss(nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super(MarginRankLoss, self).__init__()
|
9 |
+
self.loss = 0.0
|
10 |
+
|
11 |
+
def forward(self, img_list, margin=0):
|
12 |
+
length = len(img_list)
|
13 |
+
self.loss = 0.0
|
14 |
+
B, C, H, W = img_list[0].shape
|
15 |
+
for i in range(length - 1):
|
16 |
+
for j in range(i + 1, length):
|
17 |
+
self.loss = self.loss + torch.sum(F.relu(img_list[j].sum(-1).sum(-1).sum(-1) - img_list[i].sum(-1).sum(-1).sum(-1) + margin))
|
18 |
+
|
19 |
+
self.loss = self.loss / (B*length*(length-1)/2)
|
20 |
+
return self.loss
|
21 |
+
|
22 |
+
|
23 |
+
class RankLoss(nn.Module):
|
24 |
+
def __init__(self):
|
25 |
+
super(RankLoss, self).__init__()
|
26 |
+
self.countloss_criterion = nn.MSELoss(reduction='sum')
|
27 |
+
self.rankloss_criterion = MarginRankLoss()
|
28 |
+
self.rankloss = 0.0
|
29 |
+
self.losses = {}
|
30 |
+
|
31 |
+
def forward(self, unlabeled_results):
|
32 |
+
self.rankloss = 0.0
|
33 |
+
self.losses = {}
|
34 |
+
|
35 |
+
|
36 |
+
if unlabeled_results is None:
|
37 |
+
self.rankloss = 0.0
|
38 |
+
elif isinstance(unlabeled_results, tuple) and len(unlabeled_results) > 0:
|
39 |
+
self.rankloss = self.rankloss_criterion(unlabeled_results)
|
40 |
+
elif isinstance(unlabeled_results, list) and len(unlabeled_results) > 0:
|
41 |
+
count = 0
|
42 |
+
for i in range(len(unlabeled_results)):
|
43 |
+
if isinstance(unlabeled_results[i], tuple) and len(unlabeled_results[i]) > 0:
|
44 |
+
temploss = self.rankloss_criterion(unlabeled_results[i])
|
45 |
+
self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
|
46 |
+
self.rankloss += temploss
|
47 |
+
|
48 |
+
count += 1
|
49 |
+
if count > 0:
|
50 |
+
self.rankloss = self.rankloss / count
|
51 |
+
|
52 |
+
return self.rankloss
|
53 |
+
|
network/pvt_cls.py
ADDED
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
7 |
+
from timm.models.registry import register_model
|
8 |
+
from timm.models.vision_transformer import _cfg
|
9 |
+
|
10 |
+
import math
|
11 |
+
from torch.distributions.uniform import Uniform
|
12 |
+
import numpy as np
|
13 |
+
import random
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large'
|
17 |
+
]
|
18 |
+
|
19 |
+
|
20 |
+
class SELayer(nn.Module):
|
21 |
+
def __init__(self, channel, reduction=16):
|
22 |
+
super(SELayer, self).__init__()
|
23 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
24 |
+
self.fc = nn.Sequential(
|
25 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
26 |
+
nn.ReLU(inplace=True),
|
27 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
28 |
+
nn.Sigmoid()
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
b, c, _, _ = x.size()
|
33 |
+
y = self.avg_pool(x).view(b, c)
|
34 |
+
y = self.fc(y).view(b, c, 1, 1)
|
35 |
+
return x * y.expand_as(x)
|
36 |
+
|
37 |
+
|
38 |
+
class Regression(nn.Module):
|
39 |
+
def __init__(self):
|
40 |
+
super(Regression, self).__init__()
|
41 |
+
|
42 |
+
self.v1 = nn.Sequential(
|
43 |
+
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
|
44 |
+
nn.Conv2d(256, 128, 3, padding=1, dilation=1),
|
45 |
+
nn.BatchNorm2d(128), nn.ReLU(inplace=True))
|
46 |
+
|
47 |
+
self.v2 = nn.Sequential(
|
48 |
+
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
|
49 |
+
nn.Conv2d(512, 256, 3, padding=1, dilation=1),
|
50 |
+
nn.BatchNorm2d(256), nn.ReLU(inplace=True))
|
51 |
+
|
52 |
+
self.v3 = nn.Sequential(
|
53 |
+
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
|
54 |
+
nn.Conv2d(1024, 512, 3, padding=1, dilation=1), nn.BatchNorm2d(512),
|
55 |
+
nn.ReLU(inplace=True))
|
56 |
+
|
57 |
+
self.ca2 = nn.Sequential(ChannelAttention(512),
|
58 |
+
nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1 ),
|
59 |
+
nn.BatchNorm2d(512), nn.ReLU(inplace=True))
|
60 |
+
|
61 |
+
self.ca1 = nn.Sequential(ChannelAttention(256),
|
62 |
+
nn.Conv2d(256, 256, kernel_size = 3, stride = 1, padding = 1 ),
|
63 |
+
nn.BatchNorm2d(256), nn.ReLU(inplace=True))
|
64 |
+
|
65 |
+
self.ca0 = nn.Sequential(ChannelAttention(128),
|
66 |
+
nn.Conv2d(128, 128, kernel_size = 3, stride = 1, padding = 1 ),
|
67 |
+
nn.BatchNorm2d(128), nn.ReLU(inplace=True))
|
68 |
+
|
69 |
+
self.res2 = nn.Sequential(
|
70 |
+
nn.Conv2d(512, 256, 3, padding=1, dilation=1), nn.BatchNorm2d(256),
|
71 |
+
nn.ReLU(inplace=True),
|
72 |
+
nn.Conv2d(256, 128, 3, padding=1, dilation=1), nn.BatchNorm2d(128),
|
73 |
+
nn.ReLU(inplace=True),
|
74 |
+
nn.Conv2d(128, 1, 3, padding=1, dilation=1),
|
75 |
+
nn.ReLU(inplace=True))
|
76 |
+
|
77 |
+
self.res1 = nn.Sequential(
|
78 |
+
nn.Conv2d(256, 128, 3, padding=1, dilation=1), nn.BatchNorm2d(128),
|
79 |
+
nn.ReLU(inplace=True),
|
80 |
+
nn.Conv2d(128, 64, 3, padding=1, dilation=1), nn.BatchNorm2d(64),
|
81 |
+
nn.ReLU(inplace=True),
|
82 |
+
nn.Conv2d(64, 1, 3, padding=1, dilation=1),
|
83 |
+
nn.ReLU(inplace=True))
|
84 |
+
|
85 |
+
self.res0 = nn.Sequential(
|
86 |
+
nn.Conv2d(128, 64, 3, padding=1, dilation=1), nn.BatchNorm2d(64),
|
87 |
+
nn.ReLU(inplace=True),
|
88 |
+
nn.Conv2d(64, 1, 3, padding=1, dilation=1),
|
89 |
+
nn.ReLU(inplace=True))
|
90 |
+
|
91 |
+
self.noise2 = DropOutDecoder(1, 512, 512)
|
92 |
+
self.noise1 = FeatureDropDecoder(1, 256, 256)
|
93 |
+
self.noise0 = FeatureNoiseDecoder(1, 128, 128)
|
94 |
+
|
95 |
+
self.upsam2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
96 |
+
self.upsam4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
|
97 |
+
|
98 |
+
self.conv1 = nn.Conv2d(1024, 512, kernel_size=1, bias=False)
|
99 |
+
self.conv2 = nn.Conv2d(512, 256, kernel_size=1, bias=False)
|
100 |
+
self.conv3 = nn.Conv2d(256, 128, kernel_size=1, bias=False)
|
101 |
+
self.conv4 = nn.Conv2d(128, 1, kernel_size=1, bias=False)
|
102 |
+
|
103 |
+
#cls2.view(8, 1024, 1, 1))
|
104 |
+
|
105 |
+
self.init_param()
|
106 |
+
|
107 |
+
def forward(self, x, cls):
|
108 |
+
x0 = x[0]; x1 = x[1]; x2 = x[2]; x3 = x[3]
|
109 |
+
cls0 = cls[0].view(cls[0].shape[0], cls[0].shape[1], 1, 1)
|
110 |
+
cls1 = cls[1].view(cls[1].shape[0], cls[1].shape[1], 1, 1)
|
111 |
+
cls2 = cls[2].view(cls[2].shape[0], cls[2].shape[1], 1, 1)
|
112 |
+
|
113 |
+
x2_1 = self.ca2(x2)+self.v3(x3)
|
114 |
+
x1_1 = self.ca1(x1)+self.v2(x2_1)
|
115 |
+
x0_1 = self.ca0(x0)+self.v1(x1_1)
|
116 |
+
|
117 |
+
if self.training:
|
118 |
+
yc2 = self.conv4(self.conv3(self.conv2(self.noise2(self.conv1(cls2))))).squeeze()
|
119 |
+
yc1 = self.conv4(self.conv3(self.noise1(self.conv2(cls1)))).squeeze()
|
120 |
+
yc0 = self.conv4(self.noise0(self.conv3(cls0))).squeeze()
|
121 |
+
|
122 |
+
y2 = self.res2(self.upsam4(self.noise2(x2_1)))
|
123 |
+
y1 = self.res1(self.upsam2(self.noise1(x1_1)))
|
124 |
+
y0 = self.res0(self.noise0(x0_1))
|
125 |
+
|
126 |
+
else:
|
127 |
+
yc2 = self.conv4(self.conv3(self.conv2(self.conv1(cls2)))).squeeze()
|
128 |
+
yc1 = self.conv4(self.conv3(self.conv2(cls1))).squeeze()
|
129 |
+
yc0 = self.conv4(self.conv3(cls0)).squeeze()
|
130 |
+
|
131 |
+
y2 = self.res2(self.upsam4(x2_1))
|
132 |
+
y1 = self.res1(self.upsam2(x1_1))
|
133 |
+
y0 = self.res0(x0_1)
|
134 |
+
|
135 |
+
return [y0, y1, y2], [yc0, yc1, yc2]
|
136 |
+
|
137 |
+
def init_param(self):
|
138 |
+
for m in self.modules():
|
139 |
+
if isinstance(m, nn.Conv2d):
|
140 |
+
nn.init.normal_(m.weight, std=0.01)
|
141 |
+
if m.bias is not None:
|
142 |
+
nn.init.constant_(m.bias, 0)
|
143 |
+
elif isinstance(m, nn.BatchNorm2d):
|
144 |
+
nn.init.constant_(m.weight, 1)
|
145 |
+
nn.init.constant_(m.bias, 0)
|
146 |
+
|
147 |
+
|
148 |
+
class Mlp(nn.Module):
|
149 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
150 |
+
super().__init__()
|
151 |
+
out_features = out_features or in_features
|
152 |
+
hidden_features = hidden_features or in_features
|
153 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
154 |
+
self.act = act_layer()
|
155 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
156 |
+
self.drop = nn.Dropout(drop)
|
157 |
+
|
158 |
+
def forward(self, x):
|
159 |
+
x = self.fc1(x)
|
160 |
+
x = self.act(x)
|
161 |
+
x = self.drop(x)
|
162 |
+
x = self.fc2(x)
|
163 |
+
x = self.drop(x)
|
164 |
+
return x
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
def upsample(in_channels, out_channels, upscale, kernel_size=3):
|
169 |
+
# A series of x 2 upsamling until we get to the upscale we want
|
170 |
+
layers = []
|
171 |
+
conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
172 |
+
nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu')
|
173 |
+
layers.append(conv1x1)
|
174 |
+
for i in range(int(math.log(upscale, 2))):
|
175 |
+
layers.append(PixelShuffle(out_channels, scale=2))
|
176 |
+
return nn.Sequential(*layers)
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
class FeatureDropDecoder(nn.Module):
|
181 |
+
def __init__(self, upscale, conv_in_ch, num_classes):
|
182 |
+
super(FeatureDropDecoder, self).__init__()
|
183 |
+
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
|
184 |
+
|
185 |
+
def feature_dropout(self, x):
|
186 |
+
attention = torch.mean(x, dim=1, keepdim=True)
|
187 |
+
max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)
|
188 |
+
threshold = max_val * np.random.uniform(0.7, 0.9)
|
189 |
+
threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
|
190 |
+
drop_mask = (attention < threshold).float()
|
191 |
+
return x.mul(drop_mask)
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
x = self.feature_dropout(x)
|
195 |
+
return x
|
196 |
+
|
197 |
+
|
198 |
+
class FeatureNoiseDecoder(nn.Module):
|
199 |
+
def __init__(self, upscale, conv_in_ch, num_classes, uniform_range=0.3):
|
200 |
+
super(FeatureNoiseDecoder, self).__init__()
|
201 |
+
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
|
202 |
+
self.uni_dist = Uniform(-uniform_range, uniform_range)
|
203 |
+
|
204 |
+
def feature_based_noise(self, x):
|
205 |
+
noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)
|
206 |
+
x_noise = x.mul(noise_vector) + x
|
207 |
+
return x_noise
|
208 |
+
|
209 |
+
def forward(self, x):
|
210 |
+
x = self.feature_based_noise(x)
|
211 |
+
return x
|
212 |
+
|
213 |
+
class DropOutDecoder(nn.Module):
|
214 |
+
def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True):
|
215 |
+
super(DropOutDecoder, self).__init__()
|
216 |
+
self.dropout = nn.Dropout2d(p=drop_rate) if spatial_dropout else nn.Dropout(drop_rate)
|
217 |
+
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
|
218 |
+
|
219 |
+
def forward(self, x):
|
220 |
+
x = self.dropout(x)
|
221 |
+
return x
|
222 |
+
|
223 |
+
|
224 |
+
## ChannelAttetion
|
225 |
+
class ChannelAttention(nn.Module):
|
226 |
+
def __init__(self, in_planes, ratio=16):
|
227 |
+
super(ChannelAttention, self).__init__()
|
228 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
229 |
+
|
230 |
+
self.fc = nn.Sequential(
|
231 |
+
nn.Linear(in_planes,in_planes // ratio, bias = False),
|
232 |
+
nn.ReLU(inplace = True),
|
233 |
+
nn.Linear(in_planes // ratio, in_planes, bias = False)
|
234 |
+
)
|
235 |
+
self.sigmoid = nn.Sigmoid()
|
236 |
+
for m in self.modules():
|
237 |
+
if isinstance(m, nn.Conv2d):
|
238 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
239 |
+
|
240 |
+
def forward(self, in_feature):
|
241 |
+
x = in_feature
|
242 |
+
b, c, _, _ = in_feature.size()
|
243 |
+
avg_out = self.fc(self.avg_pool(x).view(b,c)).view(b, c, 1, 1)
|
244 |
+
out = avg_out
|
245 |
+
return self.sigmoid(out).expand_as(in_feature) * in_feature
|
246 |
+
|
247 |
+
|
248 |
+
class Attention(nn.Module):
|
249 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
|
250 |
+
super().__init__()
|
251 |
+
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
|
252 |
+
|
253 |
+
self.dim = dim
|
254 |
+
self.num_heads = num_heads
|
255 |
+
head_dim = dim // num_heads
|
256 |
+
self.scale = qk_scale or head_dim ** -0.5
|
257 |
+
|
258 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
259 |
+
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
260 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
261 |
+
self.proj = nn.Linear(dim, dim)
|
262 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
263 |
+
|
264 |
+
self.sr_ratio = sr_ratio
|
265 |
+
if sr_ratio > 1:
|
266 |
+
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
|
267 |
+
self.norm = nn.LayerNorm(dim)
|
268 |
+
|
269 |
+
def forward(self, x, H, W):
|
270 |
+
B, N, C = x.shape
|
271 |
+
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
272 |
+
|
273 |
+
if self.sr_ratio > 4:
|
274 |
+
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
|
275 |
+
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
|
276 |
+
x_ = self.norm(x_)
|
277 |
+
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
278 |
+
else:
|
279 |
+
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
280 |
+
k, v = kv[0], kv[1]
|
281 |
+
|
282 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
283 |
+
attn = attn.softmax(dim=-1)
|
284 |
+
attn = self.attn_drop(attn)
|
285 |
+
|
286 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
287 |
+
x = self.proj(x)
|
288 |
+
x = self.proj_drop(x)
|
289 |
+
|
290 |
+
return x
|
291 |
+
|
292 |
+
|
293 |
+
class Block(nn.Module):
|
294 |
+
|
295 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
296 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
|
297 |
+
super().__init__()
|
298 |
+
self.norm1 = norm_layer(dim)
|
299 |
+
self.attn = Attention(
|
300 |
+
dim,
|
301 |
+
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
302 |
+
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
|
303 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
304 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
305 |
+
self.norm2 = norm_layer(dim)
|
306 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
307 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
308 |
+
|
309 |
+
def forward(self, x, H, W):
|
310 |
+
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
|
311 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
312 |
+
|
313 |
+
return x
|
314 |
+
|
315 |
+
|
316 |
+
class PatchEmbed(nn.Module):
|
317 |
+
""" Image to Patch Embedding
|
318 |
+
"""
|
319 |
+
|
320 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
321 |
+
super().__init__()
|
322 |
+
img_size = to_2tuple(img_size)
|
323 |
+
patch_size = to_2tuple(patch_size)
|
324 |
+
|
325 |
+
self.img_size = img_size
|
326 |
+
self.patch_size = patch_size
|
327 |
+
# assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
|
328 |
+
# f"img_size {img_size} should be divided by patch_size {patch_size}."
|
329 |
+
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
330 |
+
self.num_patches = self.H * self.W
|
331 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
332 |
+
self.norm = nn.LayerNorm(embed_dim)
|
333 |
+
|
334 |
+
def forward(self, x):
|
335 |
+
B, C, H, W = x.shape
|
336 |
+
|
337 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
338 |
+
x = self.norm(x)
|
339 |
+
H, W = H // self.patch_size[0], W // self.patch_size[1]
|
340 |
+
|
341 |
+
return x, (H, W)
|
342 |
+
|
343 |
+
|
344 |
+
class PyramidVisionTransformer(nn.Module):
|
345 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
|
346 |
+
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
|
347 |
+
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
|
348 |
+
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4):
|
349 |
+
super().__init__()
|
350 |
+
self.num_classes = num_classes
|
351 |
+
self.depths = depths
|
352 |
+
self.num_stages = num_stages
|
353 |
+
|
354 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
355 |
+
cur = 0
|
356 |
+
|
357 |
+
for i in range(num_stages):
|
358 |
+
patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
|
359 |
+
patch_size=patch_size if i == 0 else 2,
|
360 |
+
in_chans=in_chans if i == 0 else embed_dims[i - 1],
|
361 |
+
embed_dim=embed_dims[i])
|
362 |
+
num_patches = patch_embed.num_patches if i == 0 else patch_embed.num_patches + 1
|
363 |
+
pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))
|
364 |
+
pos_drop = nn.Dropout(p=drop_rate)
|
365 |
+
|
366 |
+
block = nn.ModuleList([Block(
|
367 |
+
dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
|
368 |
+
qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j],
|
369 |
+
norm_layer=norm_layer, sr_ratio=sr_ratios[i])
|
370 |
+
for j in range(depths[i])])
|
371 |
+
cur += depths[i]
|
372 |
+
|
373 |
+
setattr(self, f"patch_embed{i + 1}", patch_embed)
|
374 |
+
setattr(self, f"pos_embed{i + 1}", pos_embed)
|
375 |
+
setattr(self, f"pos_drop{i + 1}", pos_drop)
|
376 |
+
setattr(self, f"block{i + 1}", block)
|
377 |
+
|
378 |
+
self.norm = norm_layer(embed_dims[3])
|
379 |
+
|
380 |
+
# cls_token
|
381 |
+
self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
|
382 |
+
self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
|
383 |
+
self.cls_token_3 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
|
384 |
+
|
385 |
+
# classification head
|
386 |
+
self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
|
387 |
+
|
388 |
+
|
389 |
+
self.regression = Regression()
|
390 |
+
|
391 |
+
# init weights
|
392 |
+
for i in range(num_stages):
|
393 |
+
pos_embed = getattr(self, f"pos_embed{i + 1}")
|
394 |
+
trunc_normal_(pos_embed, std=.02)
|
395 |
+
trunc_normal_(self.cls_token_1, std=.02)
|
396 |
+
trunc_normal_(self.cls_token_2, std=.02)
|
397 |
+
trunc_normal_(self.cls_token_3, std=.02)
|
398 |
+
self.apply(self._init_weights)
|
399 |
+
|
400 |
+
|
401 |
+
def _init_weights(self, m):
|
402 |
+
if isinstance(m, nn.Linear):
|
403 |
+
trunc_normal_(m.weight, std=.02)
|
404 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
405 |
+
nn.init.constant_(m.bias, 0)
|
406 |
+
elif isinstance(m, nn.LayerNorm):
|
407 |
+
nn.init.constant_(m.bias, 0)
|
408 |
+
nn.init.constant_(m.weight, 1.0)
|
409 |
+
|
410 |
+
@torch.jit.ignore
|
411 |
+
def no_weight_decay(self):
|
412 |
+
# return {'pos_embed', 'cls_token'} # has pos_embed may be better
|
413 |
+
return {'cls_token'}
|
414 |
+
|
415 |
+
def get_classifier(self):
|
416 |
+
return self.head
|
417 |
+
|
418 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
419 |
+
self.num_classes = num_classes
|
420 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
421 |
+
|
422 |
+
def _get_pos_embed(self, pos_embed, patch_embed, H, W):
|
423 |
+
if H * W == self.patch_embed1.num_patches:
|
424 |
+
return pos_embed
|
425 |
+
else:
|
426 |
+
return F.interpolate(
|
427 |
+
pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
|
428 |
+
size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)
|
429 |
+
|
430 |
+
def forward_features(self, x):
|
431 |
+
B = x.shape[0]
|
432 |
+
outputs = list()
|
433 |
+
cls_output = list()
|
434 |
+
|
435 |
+
for i in range(self.num_stages):
|
436 |
+
patch_embed = getattr(self, f"patch_embed{i + 1}")
|
437 |
+
pos_embed = getattr(self, f"pos_embed{i + 1}")
|
438 |
+
pos_drop = getattr(self, f"pos_drop{i + 1}")
|
439 |
+
block = getattr(self, f"block{i + 1}")
|
440 |
+
x, (H, W) = patch_embed(x)
|
441 |
+
|
442 |
+
if i == 0:
|
443 |
+
pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)
|
444 |
+
elif i == 1:
|
445 |
+
cls_tokens = self.cls_token_1.expand(B, -1, -1)
|
446 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
447 |
+
pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
|
448 |
+
pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
|
449 |
+
elif i == 2:
|
450 |
+
cls_tokens = self.cls_token_2.expand(B, -1, -1)
|
451 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
452 |
+
pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
|
453 |
+
pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
|
454 |
+
|
455 |
+
elif i == 3:
|
456 |
+
cls_tokens = self.cls_token_3.expand(B, -1, -1)
|
457 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
458 |
+
pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
|
459 |
+
pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
|
460 |
+
|
461 |
+
|
462 |
+
x = pos_drop(x + pos_embed)
|
463 |
+
for blk in block:
|
464 |
+
x = blk(x, H, W)
|
465 |
+
|
466 |
+
if i == 0:
|
467 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
468 |
+
else:
|
469 |
+
|
470 |
+
x_cls = x[:,1,:]
|
471 |
+
x = x[:,1:,:].reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
472 |
+
cls_output.append(x_cls)
|
473 |
+
|
474 |
+
|
475 |
+
outputs.append(x)
|
476 |
+
return outputs, cls_output
|
477 |
+
|
478 |
+
|
479 |
+
def forward(self, label_x, unlabel_x=None):
|
480 |
+
|
481 |
+
if self.training:
|
482 |
+
# labeled image processing
|
483 |
+
label_x, l_cls = self.forward_features(label_x)
|
484 |
+
out_label_x, out_cls_l = self.regression(label_x, l_cls)
|
485 |
+
label_x_1, label_x_2, label_x_3 = out_label_x
|
486 |
+
|
487 |
+
B,C,H,W = label_x_1.size()
|
488 |
+
label_sum = label_x_1.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
|
489 |
+
label_normed = label_x_1 / (label_sum + 1e-6)
|
490 |
+
|
491 |
+
# unlabeled image processing
|
492 |
+
B,C,H,W = unlabel_x.shape
|
493 |
+
unlabel_x, ul_cls = self.forward_features(unlabel_x)
|
494 |
+
out_unlabel_x, out_cls_ul = self.regression(unlabel_x, ul_cls)
|
495 |
+
y0, y1, y2 = out_unlabel_x
|
496 |
+
|
497 |
+
unlabel_x_1 = self.generate_feature_patches(y0)
|
498 |
+
unlabel_x_2 = self.generate_feature_patches(y1)
|
499 |
+
unlabel_x_3 = self.generate_feature_patches(y2)
|
500 |
+
|
501 |
+
assert unlabel_x_1.shape[0] == B * 5
|
502 |
+
assert unlabel_x_2.shape[0] == B * 5
|
503 |
+
assert unlabel_x_3.shape[0] == B * 5
|
504 |
+
|
505 |
+
unlabel_x_1 = torch.split(unlabel_x_1, split_size_or_sections=B, dim=0)
|
506 |
+
unlabel_x_2 = torch.split(unlabel_x_2, split_size_or_sections=B, dim=0)
|
507 |
+
unlabel_x_3 = torch.split(unlabel_x_3, split_size_or_sections=B, dim=0)
|
508 |
+
|
509 |
+
return [label_x_1, label_x_2, label_x_3], [unlabel_x_1, unlabel_x_2, unlabel_x_3], label_normed, out_cls_l, out_cls_ul
|
510 |
+
|
511 |
+
|
512 |
+
else:
|
513 |
+
|
514 |
+
label_x, l_cls = self.forward_features(label_x)
|
515 |
+
out_label_x, out_cls_l = self.regression(label_x, l_cls)
|
516 |
+
label_x_1, label_x_2, label_x_3 = out_label_x
|
517 |
+
B,C,H,W = label_x_1.size()
|
518 |
+
label_sum = label_x_1.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
|
519 |
+
label_normed = label_x_1 / (label_sum + 1e-6)
|
520 |
+
|
521 |
+
return [label_x_1, label_x_2, label_x_3], label_normed
|
522 |
+
|
523 |
+
|
524 |
+
def generate_feature_patches(self, unlabel_x, ratio=0.75):
|
525 |
+
# unlabeled image processing
|
526 |
+
|
527 |
+
unlabel_x_1 = unlabel_x
|
528 |
+
b, c, h, w = unlabel_x.shape
|
529 |
+
|
530 |
+
center_x = random.randint(h // 2 - (h - h * ratio) // 2, h // 2 + (h - h * ratio) // 2)
|
531 |
+
center_y = random.randint(w // 2 - (w - w * ratio) // 2, w // 2 + (w - w * ratio) // 2)
|
532 |
+
|
533 |
+
new_h2 = int(h * ratio)
|
534 |
+
new_w2 = int(w * ratio) # 48*48
|
535 |
+
unlabel_x_2 = unlabel_x[:, :, center_x - new_h2 // 2:center_x + new_h2 // 2,
|
536 |
+
center_y - new_w2 // 2:center_y + new_w2 // 2]
|
537 |
+
|
538 |
+
new_h3 = int(new_h2 * ratio)
|
539 |
+
new_w3 = int(new_w2 * ratio)
|
540 |
+
unlabel_x_3 = unlabel_x[:, :, center_x - new_h3 // 2:center_x + new_h3 // 2,
|
541 |
+
center_y - new_w3 // 2:center_y + new_w3 // 2]
|
542 |
+
|
543 |
+
new_h4 = int(new_h3 * ratio)
|
544 |
+
new_w4 = int(new_w3 * ratio)
|
545 |
+
unlabel_x_4 = unlabel_x[:, :, center_x - new_h4 // 2:center_x + new_h4 // 2,
|
546 |
+
center_y - new_w4 // 2:center_y + new_w4 // 2]
|
547 |
+
|
548 |
+
new_h5 = int(new_h4 * ratio)
|
549 |
+
new_w5 = int(new_w4 * ratio)
|
550 |
+
unlabel_x_5 = unlabel_x[:, :, center_x - new_h5 // 2:center_x + new_h5 // 2,
|
551 |
+
center_y - new_w5 // 2:center_y + new_w5 // 2]
|
552 |
+
|
553 |
+
unlabel_x_2 = nn.functional.interpolate(unlabel_x_2, size=(h, w), mode='bilinear')
|
554 |
+
unlabel_x_3 = nn.functional.interpolate(unlabel_x_3, size=(h, w), mode='bilinear')
|
555 |
+
unlabel_x_4 = nn.functional.interpolate(unlabel_x_4, size=(h, w), mode='bilinear')
|
556 |
+
unlabel_x_5 = nn.functional.interpolate(unlabel_x_5, size=(h, w), mode='bilinear')
|
557 |
+
|
558 |
+
unlabel_x = torch.cat([unlabel_x_1, unlabel_x_2, unlabel_x_3, unlabel_x_4, unlabel_x_5], dim=0)
|
559 |
+
|
560 |
+
return unlabel_x
|
561 |
+
|
562 |
+
def _conv_filter(state_dict, patch_size=16):
|
563 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
564 |
+
out_dict = {}
|
565 |
+
for k, v in state_dict.items():
|
566 |
+
if 'patch_embed.proj.weight' in k:
|
567 |
+
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
568 |
+
out_dict[k] = v
|
569 |
+
|
570 |
+
return out_dict
|
571 |
+
|
572 |
+
|
573 |
+
@register_model
|
574 |
+
def pvt_tiny(pretrained=False, **kwargs):
|
575 |
+
model = PyramidVisionTransformer(
|
576 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
|
577 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
578 |
+
**kwargs)
|
579 |
+
model.default_cfg = _cfg()
|
580 |
+
|
581 |
+
return model
|
582 |
+
|
583 |
+
|
584 |
+
@register_model
|
585 |
+
def pvt_small(pretrained=False, **kwargs):
|
586 |
+
model = PyramidVisionTransformer(
|
587 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
|
588 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
|
589 |
+
model.default_cfg = _cfg()
|
590 |
+
|
591 |
+
return model
|
592 |
+
|
593 |
+
|
594 |
+
@register_model
|
595 |
+
def pvt_medium(pretrained=False, **kwargs):
|
596 |
+
model = PyramidVisionTransformer(
|
597 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
|
598 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
599 |
+
**kwargs)
|
600 |
+
model.default_cfg = _cfg()
|
601 |
+
|
602 |
+
return model
|
603 |
+
|
604 |
+
|
605 |
+
@register_model
|
606 |
+
def pvt_large(pretrained=False, **kwargs):
|
607 |
+
model = PyramidVisionTransformer(
|
608 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
|
609 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
|
610 |
+
**kwargs)
|
611 |
+
model.default_cfg = _cfg()
|
612 |
+
|
613 |
+
return model
|
614 |
+
|
615 |
+
@register_model
|
616 |
+
def pvt_treeformer(pretrained=False, **kwargs):
|
617 |
+
model = PyramidVisionTransformer(
|
618 |
+
patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
|
619 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
620 |
+
**kwargs)
|
621 |
+
model.default_cfg = _cfg()
|
622 |
+
|
623 |
+
return model
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.21.5
|
2 |
+
Pillow==9.4.0
|
3 |
+
scikit_learn==1.2.2
|
4 |
+
scipy==1.7.3
|
5 |
+
timm==0.4.12
|
6 |
+
torch==1.12.1
|
7 |
+
torchvision==0.13.1
|
sample_imgs/overview.png
ADDED
test.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import datasets.crowd as crowd
|
6 |
+
from network import pvt_cls as TCN
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from scipy.io import savemat
|
9 |
+
from sklearn.metrics import r2_score
|
10 |
+
|
11 |
+
parser = argparse.ArgumentParser(description='Test ')
|
12 |
+
parser.add_argument('--device', default='0', help='assign device')
|
13 |
+
parser.add_argument('--batch-size', type=int, default=8, help='train batch size')
|
14 |
+
parser.add_argument('--crop-size', type=int, default=256, help='the crop size of the train image')
|
15 |
+
parser.add_argument('--model-path', type=str, default='/scratch/users/k2254235/ckpts/SEMI/Treeformer/best_model_mae-21.49_epoch-1759.pth', help='saved model path')
|
16 |
+
parser.add_argument('--data-path', type=str, default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='dataset path')
|
17 |
+
parser.add_argument('--dataset', type=str, default='TC')
|
18 |
+
|
19 |
+
def test(args, isSave = True):
|
20 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = args.device # set vis gpu
|
21 |
+
device = torch.device('cuda')
|
22 |
+
|
23 |
+
model_path = args.model_path
|
24 |
+
crop_size = args.crop_size
|
25 |
+
data_path = args.data_path
|
26 |
+
|
27 |
+
dataset = crowd.Crowd_TC(os.path.join(data_path, 'test_data'), crop_size, 1, method='val')
|
28 |
+
dataloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, num_workers=1, pin_memory=True)
|
29 |
+
|
30 |
+
model = TCN.pvt_treeformer(pretrained=False)
|
31 |
+
model.to(device)
|
32 |
+
model.load_state_dict(torch.load(model_path, device))
|
33 |
+
model.eval()
|
34 |
+
image_errs = []
|
35 |
+
result = []
|
36 |
+
R2_es = []
|
37 |
+
R2_gt = []
|
38 |
+
l=0;
|
39 |
+
for inputs, count, name, imgauss in dataloader:
|
40 |
+
with torch.no_grad():
|
41 |
+
inputs = inputs.to(device)
|
42 |
+
crop_imgs, crop_masks = [], []
|
43 |
+
b, c, h, w = inputs.size()
|
44 |
+
rh, rw = args.crop_size, args.crop_size
|
45 |
+
|
46 |
+
for i in range(0, h, rh):
|
47 |
+
gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
|
48 |
+
|
49 |
+
for j in range(0, w, rw):
|
50 |
+
gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
|
51 |
+
crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
|
52 |
+
mask = torch.zeros([b, 1, h, w]).to(device)
|
53 |
+
mask[:, :, gis:gie, gjs:gje].fill_(1.0)
|
54 |
+
crop_masks.append(mask)
|
55 |
+
crop_imgs, crop_masks = map(lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
|
56 |
+
|
57 |
+
crop_preds = []
|
58 |
+
nz, bz = crop_imgs.size(0), args.batch_size
|
59 |
+
for i in range(0, nz, bz):
|
60 |
+
|
61 |
+
gs, gt = i, min(nz, i + bz)
|
62 |
+
crop_pred, _ = model(crop_imgs[gs:gt])
|
63 |
+
crop_pred = crop_pred[0]
|
64 |
+
|
65 |
+
_, _, h1, w1 = crop_pred.size()
|
66 |
+
crop_pred = F.interpolate(crop_pred, size=(h1 * 4, w1 * 4), mode='bilinear', align_corners=True) / 16
|
67 |
+
crop_preds.append(crop_pred)
|
68 |
+
crop_preds = torch.cat(crop_preds, dim=0)
|
69 |
+
#import pdb;pdb.set_trace()
|
70 |
+
|
71 |
+
# splice them to the original size
|
72 |
+
idx = 0
|
73 |
+
pred_map = torch.zeros([b, 1, h, w]).to(device)
|
74 |
+
for i in range(0, h, rh):
|
75 |
+
gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
|
76 |
+
for j in range(0, w, rw):
|
77 |
+
gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
|
78 |
+
pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
|
79 |
+
idx += 1
|
80 |
+
# for the overlapping area, compute average value
|
81 |
+
mask = crop_masks.sum(dim=0).unsqueeze(0)
|
82 |
+
outputs = pred_map / mask
|
83 |
+
|
84 |
+
outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)/4
|
85 |
+
outputs = pred_map / mask
|
86 |
+
|
87 |
+
img_err = count[0].item() - torch.sum(outputs).item()
|
88 |
+
R2_gt.append(count[0].item())
|
89 |
+
R2_es.append(torch.sum(outputs).item())
|
90 |
+
|
91 |
+
print("Img name: ", name, "Error: ", img_err, "GT count: ", count[0].item(), "Model out: ", torch.sum(outputs).item())
|
92 |
+
image_errs.append(img_err)
|
93 |
+
result.append([name, count[0].item(), torch.sum(outputs).item(), img_err])
|
94 |
+
|
95 |
+
savemat('predictions/'+name[0]+'.mat', {'estimation':np.squeeze(outputs.cpu().data.numpy()),
|
96 |
+
'image': np.squeeze(inputs.cpu().data.numpy()), 'gt': np.squeeze(imgauss.cpu().data.numpy())})
|
97 |
+
l=l+1
|
98 |
+
|
99 |
+
|
100 |
+
image_errs = np.array(image_errs)
|
101 |
+
|
102 |
+
mse = np.sqrt(np.mean(np.square(image_errs)))
|
103 |
+
mae = np.mean(np.abs(image_errs))
|
104 |
+
R_2 = r2_score(R2_gt,R2_es)
|
105 |
+
|
106 |
+
print('{}: mae {}, mse {}, R2 {}\n'.format(model_path, mae, mse,R_2))
|
107 |
+
|
108 |
+
if isSave:
|
109 |
+
with open("test.txt","w") as f:
|
110 |
+
for i in range(len(result)):
|
111 |
+
f.write(str(result[i]).replace('[','').replace(']','').replace(',', ' ')+"\n")
|
112 |
+
f.close()
|
113 |
+
|
114 |
+
if __name__ == '__main__':
|
115 |
+
args = parser.parse_args()
|
116 |
+
test(args, isSave= True)
|
117 |
+
|
train.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch import optim
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from torch.utils.data.dataloader import default_collate
|
8 |
+
import numpy as np
|
9 |
+
from datetime import datetime
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from datasets.crowd import Crowd_TC, Crowd_UL_TC
|
12 |
+
|
13 |
+
from network import pvt_cls as TCN
|
14 |
+
from losses.multi_con_loss import MultiConLoss
|
15 |
+
|
16 |
+
from utils.pytorch_utils import Save_Handle, AverageMeter
|
17 |
+
import utils.log_utils as log_utils
|
18 |
+
import argparse
|
19 |
+
from losses.rank_loss import RankLoss
|
20 |
+
|
21 |
+
from losses import ramps
|
22 |
+
from losses.ot_loss import OT_Loss
|
23 |
+
from losses.consistency_loss import *
|
24 |
+
|
25 |
+
parser = argparse.ArgumentParser(description='Train')
|
26 |
+
parser.add_argument('--data-dir', default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='data path')
|
27 |
+
|
28 |
+
parser.add_argument('--dataset', default='TC')
|
29 |
+
parser.add_argument('--lr', type=float, default=1e-5, help='the initial learning rate')
|
30 |
+
parser.add_argument('--weight-decay', type=float, default=1e-4, help='the weight decay')
|
31 |
+
parser.add_argument('--resume', default='', type=str, help='the path of resume training model')
|
32 |
+
parser.add_argument('--max-epoch', type=int, default=4000, help='max training epoch')
|
33 |
+
parser.add_argument('--val-epoch', type=int, default=1, help='the num of steps to log training information')
|
34 |
+
parser.add_argument('--val-start', type=int, default=0, help='the epoch start to val')
|
35 |
+
parser.add_argument('--batch-size', type=int, default=16, help='train batch size')
|
36 |
+
parser.add_argument('--batch-size-ul', type=int, default=16, help='train batch size')
|
37 |
+
parser.add_argument('--device', default='0', help='assign device')
|
38 |
+
parser.add_argument('--num-workers', type=int, default=0, help='the num of training process')
|
39 |
+
parser.add_argument('--crop-size', type=int, default= 256, help='the crop size of the train image')
|
40 |
+
parser.add_argument('--rl', type=float, default=1, help='entropy regularization in sinkhorn')
|
41 |
+
parser.add_argument('--reg', type=float, default=1, help='entropy regularization in sinkhorn')
|
42 |
+
parser.add_argument('--ot', type=float, default=0.1, help='entropy regularization in sinkhorn')
|
43 |
+
parser.add_argument('--tv', type=float, default=0.01, help='entropy regularization in sinkhorn')
|
44 |
+
parser.add_argument('--num-of-iter-in-ot', type=int, default=100, help='sinkhorn iterations')
|
45 |
+
parser.add_argument('--norm-cood', type=int, default=0, help='whether to norm cood when computing distance')
|
46 |
+
parser.add_argument('--run-name', default='Treeformer_test', help='run name for wandb interface/logging')
|
47 |
+
parser.add_argument('--consistency', type=int, default=1, help='whether to norm cood when computing distance')
|
48 |
+
args = parser.parse_args()
|
49 |
+
|
50 |
+
|
51 |
+
def train_collate(batch):
|
52 |
+
transposed_batch = list(zip(*batch))
|
53 |
+
images = torch.stack(transposed_batch[0], 0)
|
54 |
+
gauss = torch.stack(transposed_batch[1], 0)
|
55 |
+
points = transposed_batch[2]
|
56 |
+
gt_discretes = torch.stack(transposed_batch[3], 0)
|
57 |
+
return images, gauss, points, gt_discretes
|
58 |
+
|
59 |
+
|
60 |
+
def train_collate_UL(batch):
|
61 |
+
transposed_batch = list(zip(*batch))
|
62 |
+
images = torch.stack(transposed_batch[0], 0)
|
63 |
+
|
64 |
+
return images
|
65 |
+
|
66 |
+
def get_current_consistency_weight(epoch):
|
67 |
+
# Consistency ramp-up from https://arxiv.org/abs/1610.02242
|
68 |
+
return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_ramp)
|
69 |
+
|
70 |
+
|
71 |
+
class Trainer(object):
|
72 |
+
def __init__(self, args):
|
73 |
+
self.args = args
|
74 |
+
|
75 |
+
def setup(self):
|
76 |
+
args = self.args
|
77 |
+
sub_dir = (
|
78 |
+
"SEMI/{}_12-1-input-{}_reg-{}_nIter-{}_normCood-{}".format(
|
79 |
+
args.run_name,args.crop_size,args.reg,
|
80 |
+
args.num_of_iter_in_ot,args.norm_cood))
|
81 |
+
|
82 |
+
self.save_dir = os.path.join("/scratch/users/k2254235","ckpts", sub_dir)
|
83 |
+
if not os.path.exists(self.save_dir):
|
84 |
+
os.makedirs(self.save_dir)
|
85 |
+
|
86 |
+
time_str = datetime.strftime(datetime.now(), "%m%d-%H%M%S")
|
87 |
+
self.logger = log_utils.get_logger(
|
88 |
+
os.path.join(self.save_dir, "train-{:s}.log".format(time_str)))
|
89 |
+
|
90 |
+
log_utils.print_config(vars(args), self.logger)
|
91 |
+
|
92 |
+
if torch.cuda.is_available():
|
93 |
+
self.device = torch.device("cuda")
|
94 |
+
self.device_count = torch.cuda.device_count()
|
95 |
+
self.logger.info("using {} gpus".format(self.device_count))
|
96 |
+
else:
|
97 |
+
raise Exception("gpu is not available")
|
98 |
+
|
99 |
+
|
100 |
+
downsample_ratio = 4
|
101 |
+
self.datasets = {"train": Crowd_TC(os.path.join(args.data_dir, "train_data"), args.crop_size,
|
102 |
+
downsample_ratio, "train"), "val": Crowd_TC(os.path.join(args.data_dir, "valid_data"),
|
103 |
+
args.crop_size, downsample_ratio, "val")}
|
104 |
+
|
105 |
+
self.datasets_ul = { "train_ul": Crowd_UL_TC(os.path.join(args.data_dir, "train_data_ul"),
|
106 |
+
args.crop_size, downsample_ratio, "train_ul")}
|
107 |
+
|
108 |
+
|
109 |
+
self.dataloaders = {
|
110 |
+
x: DataLoader(self.datasets[x],
|
111 |
+
collate_fn=(train_collate if x == "train" else default_collate),
|
112 |
+
batch_size=(args.batch_size if x == "train" else 1),
|
113 |
+
shuffle=(True if x == "train" else False),
|
114 |
+
num_workers=args.num_workers * self.device_count,
|
115 |
+
pin_memory=(True if x == "train" else False))
|
116 |
+
for x in ["train", "val"]}
|
117 |
+
|
118 |
+
self.dataloaders_ul = {
|
119 |
+
x: DataLoader(self.datasets_ul[x],
|
120 |
+
collate_fn=(train_collate_UL ),
|
121 |
+
batch_size=(args.batch_size_ul),
|
122 |
+
shuffle=(True),
|
123 |
+
num_workers=args.num_workers * self.device_count,
|
124 |
+
pin_memory=(True if x == "train" else False))
|
125 |
+
for x in ["train_ul"]}
|
126 |
+
|
127 |
+
|
128 |
+
self.model = TCN.pvt_treeformer(pretrained=False)
|
129 |
+
|
130 |
+
self.model.to(self.device)
|
131 |
+
self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
132 |
+
self.start_epoch = 0
|
133 |
+
|
134 |
+
if args.resume:
|
135 |
+
self.logger.info("loading pretrained model from " + args.resume)
|
136 |
+
suf = args.resume.rsplit(".", 1)[-1]
|
137 |
+
if suf == "tar":
|
138 |
+
checkpoint = torch.load(args.resume, self.device)
|
139 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
140 |
+
self.optimizer.load_state_dict(
|
141 |
+
checkpoint["optimizer_state_dict"])
|
142 |
+
self.start_epoch = checkpoint["epoch"] + 1
|
143 |
+
elif suf == "pth":
|
144 |
+
self.model.load_state_dict(
|
145 |
+
torch.load(args.resume, self.device))
|
146 |
+
else:
|
147 |
+
self.logger.info("random initialization")
|
148 |
+
|
149 |
+
self.ot_loss = OT_Loss(args.crop_size, downsample_ratio, args.norm_cood,
|
150 |
+
self.device, args.num_of_iter_in_ot, args.reg)
|
151 |
+
|
152 |
+
self.tvloss = nn.L1Loss(reduction="none").to(self.device)
|
153 |
+
self.mse = nn.MSELoss().to(self.device)
|
154 |
+
self.mae = nn.L1Loss().to(self.device)
|
155 |
+
self.save_list = Save_Handle(max_num=1)
|
156 |
+
self.best_mae = np.inf
|
157 |
+
self.best_mse = np.inf
|
158 |
+
self.rankloss = RankLoss().to(self.device)
|
159 |
+
self.kl_distance = nn.KLDivLoss(reduction='none')
|
160 |
+
self.multiconloss = MultiConLoss().to(self.device)
|
161 |
+
|
162 |
+
|
163 |
+
def train(self):
|
164 |
+
"""training process"""
|
165 |
+
args = self.args
|
166 |
+
for epoch in range(self.start_epoch, args.max_epoch + 1):
|
167 |
+
self.logger.info("-" * 5 + "Epoch {}/{}".format(epoch, args.max_epoch) + "-" * 5)
|
168 |
+
self.epoch = epoch
|
169 |
+
self.train_epoch()
|
170 |
+
if epoch % args.val_epoch == 0 and epoch >= args.val_start:
|
171 |
+
self.val_epoch()
|
172 |
+
|
173 |
+
def train_epoch(self):
|
174 |
+
epoch_ot_loss = AverageMeter()
|
175 |
+
epoch_ot_obj_value = AverageMeter()
|
176 |
+
epoch_wd = AverageMeter()
|
177 |
+
epoch_tv_loss = AverageMeter()
|
178 |
+
epoch_count_loss = AverageMeter()
|
179 |
+
epoch_count_consistency_l = AverageMeter()
|
180 |
+
epoch_count_consistency_ul = AverageMeter()
|
181 |
+
epoch_loss = AverageMeter()
|
182 |
+
epoch_mae = AverageMeter()
|
183 |
+
epoch_mse = AverageMeter()
|
184 |
+
epoch_start = time.time()
|
185 |
+
epoch_rank_loss = AverageMeter()
|
186 |
+
epoch_consistensy_loss = AverageMeter()
|
187 |
+
|
188 |
+
self.model.train() # Set model to training mode
|
189 |
+
|
190 |
+
for step, (inputs, gausss, points, gt_discrete) in enumerate(self.dataloaders["train"]):
|
191 |
+
inputs = inputs.to(self.device)
|
192 |
+
gausss = gausss.to(self.device)
|
193 |
+
gd_count = np.array([len(p) for p in points], dtype=np.float32)
|
194 |
+
|
195 |
+
points = [p.to(self.device) for p in points]
|
196 |
+
gt_discrete = gt_discrete.to(self.device)
|
197 |
+
N = inputs.size(0)
|
198 |
+
|
199 |
+
for st, unlabel_data in enumerate(self.dataloaders_ul["train_ul"]):
|
200 |
+
inputs_ul = unlabel_data.to(self.device)
|
201 |
+
break
|
202 |
+
|
203 |
+
|
204 |
+
with torch.set_grad_enabled(True):
|
205 |
+
outputs_L, outputs_UL, outputs_normed, CLS_L, CLS_UL = self.model(inputs, inputs_ul)
|
206 |
+
outputs_L = outputs_L[0]
|
207 |
+
|
208 |
+
with torch.set_grad_enabled(False):
|
209 |
+
preds_UL = (outputs_UL[0][0] + outputs_UL[1][0] + outputs_UL[2][0])/3
|
210 |
+
|
211 |
+
# Compute counting loss.
|
212 |
+
count_loss = self.mae(outputs_L.sum(1).sum(1).sum(1),torch.from_numpy(gd_count).float().to(self.device))*self.args.reg
|
213 |
+
|
214 |
+
# Compute OT loss.
|
215 |
+
ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs_L, points)
|
216 |
+
ot_loss = ot_loss* self.args.ot
|
217 |
+
ot_obj_value = ot_obj_value* self.args.ot
|
218 |
+
|
219 |
+
gd_count_tensor = (torch.from_numpy(gd_count).float().to(self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3))
|
220 |
+
gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
|
221 |
+
tv_loss = (self.tvloss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)*
|
222 |
+
torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.tv
|
223 |
+
|
224 |
+
epoch_ot_loss.update(ot_loss.item(), N)
|
225 |
+
epoch_ot_obj_value.update(ot_obj_value.item(), N)
|
226 |
+
epoch_wd.update(wd, N)
|
227 |
+
epoch_count_loss.update(count_loss.item(), N)
|
228 |
+
epoch_tv_loss.update(tv_loss.item(), N)
|
229 |
+
|
230 |
+
# Compute ranking loss.
|
231 |
+
rank_loss = self.rankloss(outputs_UL)*self.args.rl
|
232 |
+
epoch_rank_loss.update(rank_loss.item(), N)
|
233 |
+
|
234 |
+
# Compute multi level consistancy loss
|
235 |
+
consistency_loss = args.consistency * self.multiconloss(outputs_UL)
|
236 |
+
epoch_consistensy_loss.update(consistency_loss.item(), N)
|
237 |
+
|
238 |
+
|
239 |
+
# Compute consistency count
|
240 |
+
Con_cls_UL = (CLS_UL[0] + CLS_UL[1] + CLS_UL[2])/3
|
241 |
+
Con_cls_L = torch.from_numpy(gd_count).float().to(self.device)
|
242 |
+
|
243 |
+
count_loss_l = self.mae(torch.stack((CLS_L[0],CLS_L[1],CLS_L[2])), torch.stack((Con_cls_L, Con_cls_L, Con_cls_L)))
|
244 |
+
count_loss_ul = self.mae(torch.stack((CLS_UL[0],CLS_UL[1],CLS_UL[2])), torch.stack((Con_cls_UL, Con_cls_UL, Con_cls_UL)))
|
245 |
+
epoch_count_consistency_l.update(count_loss_l.item(), N)
|
246 |
+
epoch_count_consistency_ul.update(count_loss_ul.item(), N)
|
247 |
+
|
248 |
+
|
249 |
+
loss = count_loss + ot_loss + tv_loss + rank_loss + count_loss_l + count_loss_ul + consistency_loss
|
250 |
+
|
251 |
+
|
252 |
+
self.optimizer.zero_grad()
|
253 |
+
loss.backward()
|
254 |
+
self.optimizer.step()
|
255 |
+
|
256 |
+
pred_count = (torch.sum(outputs_L.view(N, -1),
|
257 |
+
dim=1).detach().cpu().numpy())
|
258 |
+
|
259 |
+
pred_err = pred_count - gd_count
|
260 |
+
epoch_loss.update(loss.item(), N)
|
261 |
+
epoch_mse.update(np.mean(pred_err * pred_err), N)
|
262 |
+
epoch_mae.update(np.mean(abs(pred_err)), N)
|
263 |
+
|
264 |
+
|
265 |
+
self.logger.info(
|
266 |
+
"Epoch {} Train, Loss: {:.2f}, Count Loss: {:.2f}, OT Loss: {:.2e}, TV Loss: {:.2e}, Rank Loss: {:.2f},"
|
267 |
+
"Consistensy Loss: {:.2f}, MSE: {:.2f}, MAE: {:.2f},LC Loss: {:.2f}, ULC Loss: {:.2f}, Cost {:.1f} sec".format(
|
268 |
+
self.epoch, epoch_loss.get_avg(), epoch_count_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_tv_loss.get_avg(), epoch_rank_loss.get_avg(),
|
269 |
+
epoch_consistensy_loss.get_avg(), np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(), epoch_count_consistency_l.get_avg(),
|
270 |
+
epoch_count_consistency_ul.get_avg(), time.time() - epoch_start))
|
271 |
+
|
272 |
+
|
273 |
+
|
274 |
+
model_state_dic = self.model.state_dict()
|
275 |
+
save_path = os.path.join(self.save_dir, "{}_ckpt.tar".format(self.epoch))
|
276 |
+
|
277 |
+
torch.save({"epoch": self.epoch, "optimizer_state_dict": self.optimizer.state_dict(),
|
278 |
+
"model_state_dict": model_state_dic}, save_path)
|
279 |
+
self.save_list.append(save_path)
|
280 |
+
|
281 |
+
def val_epoch(self):
|
282 |
+
args = self.args
|
283 |
+
epoch_start = time.time()
|
284 |
+
self.model.eval() # Set model to evaluate mode
|
285 |
+
epoch_res = []
|
286 |
+
for inputs, count, name, gauss_im in self.dataloaders["val"]:
|
287 |
+
with torch.no_grad():
|
288 |
+
inputs = inputs.to(self.device)
|
289 |
+
crop_imgs, crop_masks = [], []
|
290 |
+
b, c, h, w = inputs.size()
|
291 |
+
rh, rw = args.crop_size, args.crop_size
|
292 |
+
for i in range(0, h, rh):
|
293 |
+
gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
|
294 |
+
for j in range(0, w, rw):
|
295 |
+
gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
|
296 |
+
crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
|
297 |
+
mask = torch.zeros([b, 1, h, w]).to(self.device)
|
298 |
+
mask[:, :, gis:gie, gjs:gje].fill_(1.0)
|
299 |
+
crop_masks.append(mask)
|
300 |
+
crop_imgs, crop_masks = map(
|
301 |
+
lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
|
302 |
+
|
303 |
+
crop_preds = []
|
304 |
+
nz, bz = crop_imgs.size(0), args.batch_size
|
305 |
+
for i in range(0, nz, bz):
|
306 |
+
gs, gt = i, min(nz, i + bz)
|
307 |
+
|
308 |
+
crop_pred, _ = self.model(crop_imgs[gs:gt])
|
309 |
+
crop_pred = crop_pred[0]
|
310 |
+
_, _, h1, w1 = crop_pred.size()
|
311 |
+
crop_pred = (F.interpolate(crop_pred, size=(h1 * 4, w1 * 4),
|
312 |
+
mode="bilinear", align_corners=True) / 16 )
|
313 |
+
|
314 |
+
crop_preds.append(crop_pred)
|
315 |
+
crop_preds = torch.cat(crop_preds, dim=0)
|
316 |
+
|
317 |
+
# splice them to the original size
|
318 |
+
idx = 0
|
319 |
+
pred_map = torch.zeros([b, 1, h, w]).to(self.device)
|
320 |
+
for i in range(0, h, rh):
|
321 |
+
gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
|
322 |
+
for j in range(0, w, rw):
|
323 |
+
gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
|
324 |
+
pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
|
325 |
+
idx += 1
|
326 |
+
# for the overlapping area, compute average value
|
327 |
+
mask = crop_masks.sum(dim=0).unsqueeze(0)
|
328 |
+
outputs = pred_map / mask
|
329 |
+
|
330 |
+
res = count[0].item() - torch.sum(outputs).item()
|
331 |
+
epoch_res.append(res)
|
332 |
+
epoch_res = np.array(epoch_res)
|
333 |
+
mse = np.sqrt(np.mean(np.square(epoch_res)))
|
334 |
+
mae = np.mean(np.abs(epoch_res))
|
335 |
+
|
336 |
+
self.logger.info("Epoch {} Val, MSE: {:.2f}, MAE: {:.2f}, Cost {:.1f} sec".format(
|
337 |
+
self.epoch, mse, mae, time.time() - epoch_start ))
|
338 |
+
|
339 |
+
|
340 |
+
model_state_dic = self.model.state_dict()
|
341 |
+
print("Comaprison", mae, self.best_mae)
|
342 |
+
if mae < self.best_mae:
|
343 |
+
self.best_mse = mse
|
344 |
+
self.best_mae = mae
|
345 |
+
self.logger.info(
|
346 |
+
"save best mse {:.2f} mae {:.2f} model epoch {}".format(
|
347 |
+
self.best_mse, self.best_mae, self.epoch))
|
348 |
+
|
349 |
+
print("Saving best model at {} epoch".format(self.epoch))
|
350 |
+
model_path = os.path.join(
|
351 |
+
self.save_dir, "best_model_mae-{:.2f}_epoch-{}.pth".format(
|
352 |
+
self.best_mae, self.epoch))
|
353 |
+
|
354 |
+
torch.save(model_state_dic, model_path)
|
355 |
+
|
356 |
+
|
357 |
+
if __name__ == "__main__":
|
358 |
+
import torch
|
359 |
+
torch.backends.cudnn.benchmark = True
|
360 |
+
trainer = Trainer(args)
|
361 |
+
trainer.setup()
|
362 |
+
trainer.train()
|
363 |
+
|
364 |
+
|
365 |
+
|
366 |
+
|
367 |
+
|
368 |
+
|
369 |
+
|
utils/__init__.py
ADDED
File without changes
|
utils/log_utils.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
|
4 |
+
def get_logger(log_file):
|
5 |
+
logger = logging.getLogger(log_file)
|
6 |
+
logger.setLevel(logging.DEBUG)
|
7 |
+
fh = logging.FileHandler(log_file)
|
8 |
+
fh.setLevel(logging.DEBUG)
|
9 |
+
ch = logging.StreamHandler()
|
10 |
+
ch.setLevel(logging.INFO)
|
11 |
+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
12 |
+
ch.setFormatter(formatter)
|
13 |
+
fh.setFormatter(formatter)
|
14 |
+
logger.addHandler(ch)
|
15 |
+
logger.addHandler(fh)
|
16 |
+
return logger
|
17 |
+
|
18 |
+
|
19 |
+
def print_config(config, logger):
|
20 |
+
"""
|
21 |
+
Print configuration of the model
|
22 |
+
"""
|
23 |
+
for k, v in config.items():
|
24 |
+
logger.info("{}:\t{}".format(k.ljust(15), v))
|
utils/pytorch_utils.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10):
|
4 |
+
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
|
5 |
+
lr = max(initial_lr * (0.1 ** (epoch // decay_epoch)), 1e-6)
|
6 |
+
for param_group in optimizer.param_groups:
|
7 |
+
param_group['lr'] = lr
|
8 |
+
|
9 |
+
|
10 |
+
class Save_Handle(object):
|
11 |
+
"""handle the number of """
|
12 |
+
def __init__(self, max_num):
|
13 |
+
self.save_list = []
|
14 |
+
self.max_num = max_num
|
15 |
+
|
16 |
+
def append(self, save_path):
|
17 |
+
if len(self.save_list) < self.max_num:
|
18 |
+
self.save_list.append(save_path)
|
19 |
+
else:
|
20 |
+
remove_path = self.save_list[0]
|
21 |
+
del self.save_list[0]
|
22 |
+
self.save_list.append(save_path)
|
23 |
+
if os.path.exists(remove_path):
|
24 |
+
os.remove(remove_path)
|
25 |
+
|
26 |
+
|
27 |
+
class AverageMeter(object):
|
28 |
+
"""Computes and stores the average and current value"""
|
29 |
+
def __init__(self):
|
30 |
+
self.reset()
|
31 |
+
|
32 |
+
def reset(self):
|
33 |
+
self.val = 0
|
34 |
+
self.avg = 0
|
35 |
+
self.sum = 0
|
36 |
+
self.count = 0
|
37 |
+
|
38 |
+
def update(self, val, n=1):
|
39 |
+
self.val = val
|
40 |
+
self.sum += val * n
|
41 |
+
self.count += n
|
42 |
+
self.avg = 1.0 * self.sum / self.count
|
43 |
+
|
44 |
+
def get_avg(self):
|
45 |
+
return self.avg
|
46 |
+
|
47 |
+
def get_count(self):
|
48 |
+
return self.count
|
49 |
+
|
50 |
+
|
51 |
+
def set_trainable(model, requires_grad):
|
52 |
+
for param in model.parameters():
|
53 |
+
param.requires_grad = requires_grad
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
def get_num_params(model):
|
58 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|