Мясников Филипп Сергеевич committed on
Commit
6c92b57
1 Parent(s): cc78303
op/__init__.py ADDED
File without changes
op/conv2d_gradfix.py ADDED
@@ -0,0 +1,227 @@
+ import contextlib
+ import warnings
+
+ import torch
+ from torch import autograd
+ from torch.nn import functional as F
+
+ enabled = True
+ weight_gradients_disabled = False
+
+
+ @contextlib.contextmanager
+ def no_weight_gradients():
+     global weight_gradients_disabled
+
+     old = weight_gradients_disabled
+     weight_gradients_disabled = True
+     yield
+     weight_gradients_disabled = old
+
+
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+     if could_use_op(input):
+         return conv2d_gradfix(
+             transpose=False,
+             weight_shape=weight.shape,
+             stride=stride,
+             padding=padding,
+             output_padding=0,
+             dilation=dilation,
+             groups=groups,
+         ).apply(input, weight, bias)
+
+     return F.conv2d(
+         input=input,
+         weight=weight,
+         bias=bias,
+         stride=stride,
+         padding=padding,
+         dilation=dilation,
+         groups=groups,
+     )
+
+
+ def conv_transpose2d(
+     input,
+     weight,
+     bias=None,
+     stride=1,
+     padding=0,
+     output_padding=0,
+     groups=1,
+     dilation=1,
+ ):
+     if could_use_op(input):
+         return conv2d_gradfix(
+             transpose=True,
+             weight_shape=weight.shape,
+             stride=stride,
+             padding=padding,
+             output_padding=output_padding,
+             groups=groups,
+             dilation=dilation,
+         ).apply(input, weight, bias)
+
+     return F.conv_transpose2d(
+         input=input,
+         weight=weight,
+         bias=bias,
+         stride=stride,
+         padding=padding,
+         output_padding=output_padding,
+         dilation=dilation,
+         groups=groups,
+     )
+
+
+ def could_use_op(input):
+     if (not enabled) or (not torch.backends.cudnn.enabled):
+         return False
+
+     if input.device.type != "cuda":
+         return False
+
+     if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
+         return True
+
+     warnings.warn(
+         f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
+     )
+
+     return False
+
+
+ def ensure_tuple(xs, ndim):
+     xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+
+     return xs
+
+
+ conv2d_gradfix_cache = dict()
+
+
+ def conv2d_gradfix(
+     transpose, weight_shape, stride, padding, output_padding, dilation, groups
+ ):
+     ndim = 2
+     weight_shape = tuple(weight_shape)
+     stride = ensure_tuple(stride, ndim)
+     padding = ensure_tuple(padding, ndim)
+     output_padding = ensure_tuple(output_padding, ndim)
+     dilation = ensure_tuple(dilation, ndim)
+
+     key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+     if key in conv2d_gradfix_cache:
+         return conv2d_gradfix_cache[key]
+
+     common_kwargs = dict(
+         stride=stride, padding=padding, dilation=dilation, groups=groups
+     )
+
+     def calc_output_padding(input_shape, output_shape):
+         if transpose:
+             return [0, 0]
+
+         return [
+             input_shape[i + 2]
+             - (output_shape[i + 2] - 1) * stride[i]
+             - (1 - 2 * padding[i])
+             - dilation[i] * (weight_shape[i + 2] - 1)
+             for i in range(ndim)
+         ]
+
+     class Conv2d(autograd.Function):
+         @staticmethod
+         def forward(ctx, input, weight, bias):
+             if not transpose:
+                 out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+
+             else:
+                 out = F.conv_transpose2d(
+                     input=input,
+                     weight=weight,
+                     bias=bias,
+                     output_padding=output_padding,
+                     **common_kwargs,
+                 )
+
+             ctx.save_for_backward(input, weight)
+
+             return out
+
+         @staticmethod
+         def backward(ctx, grad_output):
+             input, weight = ctx.saved_tensors
+             grad_input, grad_weight, grad_bias = None, None, None
+
+             if ctx.needs_input_grad[0]:
+                 p = calc_output_padding(
+                     input_shape=input.shape, output_shape=grad_output.shape
+                 )
+                 grad_input = conv2d_gradfix(
+                     transpose=(not transpose),
+                     weight_shape=weight_shape,
+                     output_padding=p,
+                     **common_kwargs,
+                 ).apply(grad_output, weight, None)
+
+             if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+                 grad_weight = Conv2dGradWeight.apply(grad_output, input)
+
+             if ctx.needs_input_grad[2]:
+                 grad_bias = grad_output.sum((0, 2, 3))
+
+             return grad_input, grad_weight, grad_bias
+
+     class Conv2dGradWeight(autograd.Function):
+         @staticmethod
+         def forward(ctx, grad_output, input):
+             op = torch._C._jit_get_operation(
+                 "aten::cudnn_convolution_backward_weight"
+                 if not transpose
+                 else "aten::cudnn_convolution_transpose_backward_weight"
+             )
+             flags = [
+                 torch.backends.cudnn.benchmark,
+                 torch.backends.cudnn.deterministic,
+                 torch.backends.cudnn.allow_tf32,
+             ]
+             grad_weight = op(
+                 weight_shape,
+                 grad_output,
+                 input,
+                 padding,
+                 stride,
+                 dilation,
+                 groups,
+                 *flags,
+             )
+             ctx.save_for_backward(grad_output, input)
+
+             return grad_weight
+
+         @staticmethod
+         def backward(ctx, grad_grad_weight):
+             grad_output, input = ctx.saved_tensors
+             grad_grad_output, grad_grad_input = None, None
+
+             if ctx.needs_input_grad[0]:
+                 grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
+
+             if ctx.needs_input_grad[1]:
+                 p = calc_output_padding(
+                     input_shape=input.shape, output_shape=grad_output.shape
+                 )
+                 grad_grad_input = conv2d_gradfix(
+                     transpose=(not transpose),
+                     weight_shape=weight_shape,
+                     output_padding=p,
+                     **common_kwargs,
+                 ).apply(grad_output, grad_grad_weight, None)
+
+             return grad_grad_output, grad_grad_input
+
+     conv2d_gradfix_cache[key] = Conv2d
+
+     return Conv2d
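
These wrappers are drop-in replacements for torch.nn.functional.conv2d / conv_transpose2d that route through a custom autograd.Function, so a regularization pass can skip the weight gradient entirely. A minimal usage sketch (illustrative, not part of the commit; the custom path only activates on CUDA under PyTorch 1.7/1.8 and otherwise falls back to F.conv2d):

    import torch
    from op import conv2d_gradfix

    x = torch.randn(4, 3, 32, 32, device="cuda", requires_grad=True)
    w = torch.randn(8, 3, 3, 3, device="cuda", requires_grad=True)

    # Drop-in replacement for F.conv2d.
    y = conv2d_gradfix.conv2d(x, w, padding=1)

    # Backward passes run inside this context skip the gradient w.r.t. the
    # weights, saving memory and compute in gradient-penalty regularizers.
    with conv2d_gradfix.no_weight_gradients():
        y = conv2d_gradfix.conv2d(x, w, padding=1)
        (grad_x,) = torch.autograd.grad(y.sum(), inputs=x, create_graph=True)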
op/fused_act.py ADDED
@@ -0,0 +1,86 @@
+ import os
+
+ import torch
+ from torch import nn
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load
+
+
+ module_path = os.path.dirname(__file__)
+ fused = load(
+     'fused',
+     sources=[
+         os.path.join(module_path, 'fused_bias_act.cpp'),
+         os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+     ],
+ )
+
+
+ class FusedLeakyReLUFunctionBackward(Function):
+     @staticmethod
+     def forward(ctx, grad_output, out, negative_slope, scale):
+         ctx.save_for_backward(out)
+         ctx.negative_slope = negative_slope
+         ctx.scale = scale
+
+         empty = grad_output.new_empty(0)
+
+         grad_input = fused.fused_bias_act(
+             grad_output, empty, out, 3, 1, negative_slope, scale
+         )
+
+         dim = [0]
+
+         if grad_input.ndim > 2:
+             dim += list(range(2, grad_input.ndim))
+
+         grad_bias = grad_input.sum(dim).detach()
+
+         return grad_input, grad_bias
+
+     @staticmethod
+     def backward(ctx, gradgrad_input, gradgrad_bias):
+         out, = ctx.saved_tensors
+         gradgrad_out = fused.fused_bias_act(
+             gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+         )
+
+         return gradgrad_out, None, None, None
+
+
+ class FusedLeakyReLUFunction(Function):
+     @staticmethod
+     def forward(ctx, input, bias, negative_slope, scale):
+         empty = input.new_empty(0)
+         out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+         ctx.save_for_backward(out)
+         ctx.negative_slope = negative_slope
+         ctx.scale = scale
+
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         out, = ctx.saved_tensors
+
+         grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+             grad_output, out, ctx.negative_slope, ctx.scale
+         )
+
+         return grad_input, grad_bias, None, None
+
+
+ class FusedLeakyReLU(nn.Module):
+     def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+         super().__init__()
+
+         self.bias = nn.Parameter(torch.zeros(channel))
+         self.negative_slope = negative_slope
+         self.scale = scale
+
+     def forward(self, input):
+         return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+     return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
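
fused.fused_bias_act fuses three elementwise steps into one CUDA kernel: a per-channel bias add, a leaky ReLU, and a rescale by scale (sqrt(2) by default, which keeps the activation roughly variance-preserving). A pure-PyTorch sketch of what the forward pass computes (an editorial reference, not part of the commit):

    import torch
    from torch.nn import functional as F

    def fused_leaky_relu_reference(input, bias, negative_slope=0.2, scale=2 ** 0.5):
        # Broadcast the per-channel bias over the trailing spatial dims,
        # then activate and rescale; one kernel launch in the CUDA version.
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        out = input + bias.view(1, bias.shape[0], *rest_dim)
        return F.leaky_relu(out, negative_slope) * scale

    x = torch.randn(4, 8, 16, 16)
    y = fused_leaky_relu_reference(x, torch.zeros(8))  # matches FusedLeakyReLU(8)(x)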
op/fused_act_cpu.py ADDED
@@ -0,0 +1,36 @@
+ import os
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ module_path = os.path.dirname(__file__)
+
+
+ class FusedLeakyReLU(nn.Module):
+     def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+         super().__init__()
+
+         self.bias = nn.Parameter(torch.zeros(channel))
+         self.negative_slope = negative_slope
+         self.scale = scale
+
+     def forward(self, input):
+         return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
+     # Pure-PyTorch fallback: add the per-channel bias, apply the leaky ReLU,
+     # then rescale; works on any device, no compiled extension required.
+     if bias is not None:
+         rest_dim = [1] * (input.ndim - bias.ndim - 1)
+         return (
+             F.leaky_relu(
+                 input + bias.view(1, bias.shape[0], *rest_dim),
+                 negative_slope=negative_slope,
+             )
+             * scale
+         )
+
+     return F.leaky_relu(input, negative_slope=negative_slope) * scale
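
This module mirrors the interface of op/fused_act.py so callers can fall back when the CUDA extension cannot be built. A typical selection pattern (an assumption about how the two modules are wired together; the actual dispatch may live elsewhere in the repo):

    try:
        from op.fused_act import FusedLeakyReLU, fused_leaky_relu  # CUDA extension
    except Exception:
        from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu  # pure PyTorch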
op/upfirdn2d.py ADDED
@@ -0,0 +1,188 @@
+ import os
+
+ import torch
+ from torch.autograd import Function
+ from torch.nn import functional as F
+ from torch.utils.cpp_extension import load
+
+
+ module_path = os.path.dirname(__file__)
+ upfirdn2d_op = load(
+     'upfirdn2d',
+     sources=[
+         os.path.join(module_path, 'upfirdn2d.cpp'),
+         os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+     ],
+ )
+
+
+ class UpFirDn2dBackward(Function):
+     @staticmethod
+     def forward(
+         ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+     ):
+
+         up_x, up_y = up
+         down_x, down_y = down
+         g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+         grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+         grad_input = upfirdn2d_op.upfirdn2d(
+             grad_output,
+             grad_kernel,
+             down_x,
+             down_y,
+             up_x,
+             up_y,
+             g_pad_x0,
+             g_pad_x1,
+             g_pad_y0,
+             g_pad_y1,
+         )
+         grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+         ctx.save_for_backward(kernel)
+
+         pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+         ctx.up_x = up_x
+         ctx.up_y = up_y
+         ctx.down_x = down_x
+         ctx.down_y = down_y
+         ctx.pad_x0 = pad_x0
+         ctx.pad_x1 = pad_x1
+         ctx.pad_y0 = pad_y0
+         ctx.pad_y1 = pad_y1
+         ctx.in_size = in_size
+         ctx.out_size = out_size
+
+         return grad_input
+
+     @staticmethod
+     def backward(ctx, gradgrad_input):
+         kernel, = ctx.saved_tensors
+
+         gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+         gradgrad_out = upfirdn2d_op.upfirdn2d(
+             gradgrad_input,
+             kernel,
+             ctx.up_x,
+             ctx.up_y,
+             ctx.down_x,
+             ctx.down_y,
+             ctx.pad_x0,
+             ctx.pad_x1,
+             ctx.pad_y0,
+             ctx.pad_y1,
+         )
+         # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
+         gradgrad_out = gradgrad_out.view(
+             ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
+         )
+
+         return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+ class UpFirDn2d(Function):
+     @staticmethod
+     def forward(ctx, input, kernel, up, down, pad):
+         up_x, up_y = up
+         down_x, down_y = down
+         pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+         kernel_h, kernel_w = kernel.shape
+         batch, channel, in_h, in_w = input.shape
+         ctx.in_size = input.shape
+
+         input = input.reshape(-1, in_h, in_w, 1)
+
+         ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+         out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+         out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+         ctx.out_size = (out_h, out_w)
+
+         ctx.up = (up_x, up_y)
+         ctx.down = (down_x, down_y)
+         ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+         g_pad_x0 = kernel_w - pad_x0 - 1
+         g_pad_y0 = kernel_h - pad_y0 - 1
+         g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+         g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+         ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+         out = upfirdn2d_op.upfirdn2d(
+             input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+         )
+         # out = out.view(major, out_h, out_w, minor)
+         out = out.view(-1, channel, out_h, out_w)
+
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         kernel, grad_kernel = ctx.saved_tensors
+
+         grad_input = UpFirDn2dBackward.apply(
+             grad_output,
+             kernel,
+             grad_kernel,
+             ctx.up,
+             ctx.down,
+             ctx.pad,
+             ctx.g_pad,
+             ctx.in_size,
+             ctx.out_size,
+         )
+
+         return grad_input, None, None, None, None
+
+
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+     out = UpFirDn2d.apply(
+         input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+     )
+
+     return out
+
+
+ def upfirdn2d_native(
+     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ ):
+     _, in_h, in_w, minor = input.shape
+     kernel_h, kernel_w = kernel.shape
+
+     out = input.view(-1, in_h, 1, in_w, 1, minor)
+     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+     out = F.pad(
+         out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+     )
+     out = out[
+         :,
+         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+         max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+         :,
+     ]
+
+     out = out.permute(0, 3, 1, 2)
+     out = out.reshape(
+         [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+     )
+     w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+     out = F.conv2d(out, w)
+     out = out.reshape(
+         -1,
+         minor,
+         in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+     )
+     out = out.permute(0, 2, 3, 1)
+
+     return out[:, ::down_y, ::down_x, :]
+
op/upfirdn2d_cpu.py ADDED
@@ -0,0 +1,60 @@
+ import os
+
+ import torch
+ from torch.autograd import Function
+ from torch.nn import functional as F
+
+
+
+ module_path = os.path.dirname(__file__)
+
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+     out = upfirdn2d_native(
+         input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
+     )
+
+     return out
+
+
+ def upfirdn2d_native(
+     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ ):
+     _, channel, in_h, in_w = input.shape
+     input = input.reshape(-1, in_h, in_w, 1)
+
+     _, in_h, in_w, minor = input.shape
+     kernel_h, kernel_w = kernel.shape
+
+     out = input.view(-1, in_h, 1, in_w, 1, minor)
+     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+     out = F.pad(
+         out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+     )
+     out = out[
+         :,
+         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+         max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+         :,
+     ]
+
+     out = out.permute(0, 3, 1, 2)
+     out = out.reshape(
+         [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+     )
+     w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+     out = F.conv2d(out, w)
+     out = out.reshape(
+         -1,
+         minor,
+         in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+     )
+     out = out.permute(0, 2, 3, 1)
+     out = out[:, ::down_y, ::down_x, :]
+
+     out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+     out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
+
+     return out.view(-1, channel, out_h, out_w)
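
A quick sanity check of the native path (an editorial sketch, not part of the commit): with a uniform 2x2 kernel, up=1, down=2, and no padding, upfirdn2d reduces exactly to 2x2 average pooling:

    import torch
    from torch.nn import functional as F
    from op.upfirdn2d_cpu import upfirdn2d

    x = torch.randn(2, 3, 8, 8)
    kernel = torch.ones(2, 2) / 4

    y = upfirdn2d(x, kernel, up=1, down=2, pad=(0, 0))
    assert y.shape == (2, 3, 4, 4)
    assert torch.allclose(y, F.avg_pool2d(x, 2), atol=1e-6)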