JohnAlexander23 committed on
Commit
69f4183
1 Parent(s): 67d8d51

Upload 23 files
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: SeemoRe
+ emoji: 💻
+ colorFrom: purple
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 4.31.5
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ import yaml
+ import torch
+ import argparse
+ import numpy as np
+ import gradio as gr
+
+ from PIL import Image
+ from copy import deepcopy
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+ from huggingface_hub import hf_hub_download
+ from gradio_imageslider import ImageSlider
+
+ ## local code
+ from models import seemore
+
+
+ def dict2namespace(config):
+     namespace = argparse.Namespace()
+     for key, value in config.items():
+         if isinstance(value, dict):
+             new_value = dict2namespace(value)
+         else:
+             new_value = value
+         setattr(namespace, key, new_value)
+     return namespace
+
+ def load_img(filename, norm=True):
+     img = np.array(Image.open(filename).convert("RGB"))
+     h, w = img.shape[:2]
+
+     if w > 1920 or h > 1080:
+         new_h, new_w = h // 4, w // 4
+         img = np.array(Image.fromarray(img).resize((new_w, new_h), Image.BICUBIC))
+
+     if norm:
+         img = img / 255.
+         img = img.astype(np.float32)
+     return img
+
+ def process_img(image):
+     img = np.array(image)
+     img = img / 255.
+     img = img.astype(np.float32)
+     y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         x_hat = model(y)
+
+     restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
+     restored_img = np.clip(restored_img, 0., 1.)
+
+     restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 to uint8
+     # return Image.fromarray(restored_img)
+     return (image, Image.fromarray(restored_img))
+
+ def load_network(net, load_path, strict=True, param_key='params'):
+     if isinstance(net, (DataParallel, DistributedDataParallel)):
+         net = net.module
+     load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
+     if param_key is not None:
+         if param_key not in load_net and 'params' in load_net:
+             param_key = 'params'
+         load_net = load_net[param_key]
+     # remove unnecessary 'module.'
+     for k, v in deepcopy(load_net).items():
+         if k.startswith('module.'):
+             load_net[k[7:]] = v
+             load_net.pop(k)
+     net.load_state_dict(load_net, strict=strict)
+
+ CONFIG = "configs/eval_seemore_t_x4.yml"
+ hf_hub_download(repo_id="eduardzamfir/SeemoRe-T", filename="SeemoRe_T_X4.pth", local_dir="./")
+ MODEL_NAME = "SeemoRe_T_X4.pth"
+
+ # parse config file
+ with open(os.path.join(CONFIG), "r") as f:
+     config = yaml.safe_load(f)
+
+ cfg = dict2namespace(config)
+
+ device = torch.device("cpu")
+ model = seemore.SeemoRe(scale=cfg.model.scale, in_chans=cfg.model.in_chans,
+                         num_experts=cfg.model.num_experts, num_layers=cfg.model.num_layers, embedding_dim=cfg.model.embedding_dim,
+                         img_range=cfg.model.img_range, use_shuffle=cfg.model.use_shuffle, global_kernel_size=cfg.model.global_kernel_size,
+                         recursive=cfg.model.recursive, lr_space=cfg.model.lr_space, topk=cfg.model.topk)
+
+ model = model.to(device)
+ print("IMAGE MODEL CKPT:", MODEL_NAME)
+ load_network(model, MODEL_NAME, strict=True, param_key='params')
+
+
+
+
+ title = "See More Details"
+ description = ''' ### See More Details: Efficient Image Super-Resolution by Experts Mining - ICML 2024, Vienna, Austria
+
+ #### [Eduard Zamfir<sup>1</sup>](https://eduardzamfir.github.io), [Zongwei Wu<sup>1*</sup>](https://sites.google.com/view/zwwu/accueil), [Nancy Mehta<sup>1</sup>](https://scholar.google.com/citations?user=WwdYdlUAAAAJ&hl=en&oi=ao), [Yulun Zhang<sup>2,3*</sup>](http://yulunzhang.com/) and [Radu Timofte<sup>1</sup>](https://www.informatik.uni-wuerzburg.de/computervision/)
+
+ #### **<sup>1</sup> University of Würzburg, Germany - <sup>2</sup> Shanghai Jiao Tong University, China - <sup>3</sup> ETH Zürich, Switzerland**
+ #### **<sup>*</sup> Corresponding authors**
+
+ <details>
+ <summary> <b>Abstract</b> (click me to read)</summary>
+ <p>
+ Reconstructing high-resolution (HR) images from low-resolution (LR) inputs poses a significant challenge in image super-resolution (SR). While recent approaches have demonstrated the efficacy of intricate operations customized for various objectives, the straightforward stacking of these disparate operations can result in a substantial computational burden, hampering their practical utility. In response, we introduce **S**eemo**R**e, an efficient SR model employing expert mining. Our approach strategically incorporates experts at different levels, adopting a collaborative methodology. At the macro scale, our experts address rank-wise and spatial-wise informative features, providing a holistic understanding. Subsequently, the model delves into the subtleties of rank choice by leveraging a mixture of low-rank experts. By tapping into experts specialized in distinct key factors crucial for accurate SR, our model excels in uncovering intricate intra-feature details. This collaborative approach is reminiscent of the concept of **see more**, allowing our model to achieve an optimal performance with minimal computational costs in efficient settings.
+ </p>
+ </details>
+
+
+ #### Drag the slider on the super-resolved image left and right to see the changes in the image details. SeemoRe performs x4 upscaling on the input image.
+
+ <br>
+
+ <code>
+ @inproceedings{zamfir2024details,
+   title={See More Details: Efficient Image Super-Resolution by Experts Mining},
+   author={Eduard Zamfir and Zongwei Wu and Nancy Mehta and Yulun Zhang and Radu Timofte},
+   booktitle={International Conference on Machine Learning},
+   year={2024},
+   organization={PMLR}
+ }
+ </code>
+ <br>
+ '''
+
+
+ article = "<p style='text-align: center'><a href='https://eduardzamfir.github.io/seemore' target='_blank'>See More Details: Efficient Image Super-Resolution by Experts Mining</a></p>"
+
+ # Image examples
+ examples = [
+     ['images/0801x4.png'],
+     ['images/0840x4.png'],
+     ['images/0841x4.png'],
+     ['images/0870x4.png'],
+     ['images/0878x4.png'],
+     ['images/0884x4.png'],
+     ['images/0900x4.png'],
+     ['images/img002x4.png'],
+     ['images/img003x4.png'],
+     ['images/img004x4.png'],
+     ['images/img035x4.png'],
+     ['images/img053x4.png'],
+     ['images/img064x4.png'],
+     ['images/img083x4.png'],
+     ['images/img092x4.png'],
+ ]
+
+ css = """
+ .image-frame img, .image-container img {
+     width: auto;
+     height: auto;
+     max-width: none;
+ }
+ """
+
+ demo = gr.Interface(
+     fn=process_img,
+     inputs=[gr.Image(type="pil", label="Input", value="images/0878x4.png")],
+     outputs=ImageSlider(label="Super-Resolved Image",
+                         type="pil",
+                         show_download_button=True,
+                         ),  # [gr.Image(type="pil", label="Output", min_width=500)]
+     title=title,
+     description=description,
+     article=article,
+     examples=examples,
+     css=css,
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
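
The Gradio interface above is a thin wrapper around a single `process_img` call. A minimal sketch for exercising the same pipeline without the UI, assuming the repository is cloned locally, the requirements are installed, and importing `app` is allowed to download the checkpoint exactly as it does on the Space (the output filename is hypothetical):

from PIL import Image
import app   # importing app.py downloads SeemoRe_T_X4.pth and loads the model on CPU

lr = Image.open("images/0878x4.png").convert("RGB")
original, restored = app.process_img(lr)     # returns (input image, x4 super-resolved image)
print(original.size, "->", restored.size)    # width and height grow by the scale factor
restored.save("restored_0878x4.png")
# The full demo can also be launched locally with: python app.py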
assets/arch.svg ADDED
configs/eval_seemore_t_x4.yml ADDED
@@ -0,0 +1,14 @@
+ model:
+   arch: "SeemoRe"
+   scale: 4
+   in_chans: 3
+   num_experts: 3
+   img_range: 1.0
+   num_layers: 6
+   embedding_dim: 36
+   use_shuffle: True
+   lr_space: exp
+   topk: 1
+   recursive: 2
+   global_kernel_size: 11
+
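
For reference, `MoEBlock` in `models/seemore.py` (below) expands `lr_space` and `num_experts` into the per-expert low-rank dimensions, while `topk` controls how many experts are routed per input. A minimal sketch of that expansion for the values above (`grow_func` is the name used in the model code; `low_dims` is only illustrative):

grow_func = lambda i: 2 ** (i + 1)            # lr_space: exp
low_dims = [grow_func(i) for i in range(3)]   # num_experts: 3  ->  [2, 4, 8]
print(low_dims)                               # topk: 1 routes each input through a single expert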
gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__/
+ flagged/
+ checkpoints
+ *.pt
+ *.gif
+ *.pth
images/0801x4.png ADDED
images/0840x4.png ADDED
images/0841x4.png ADDED
images/0870x4.png ADDED
images/0878x4.png ADDED
images/0884x4.png ADDED
images/0900x4.png ADDED
images/img002x4.png ADDED
images/img003x4.png ADDED
images/img004x4.png ADDED
images/img035x4.png ADDED
images/img053x4.png ADDED
images/img064x4.png ADDED
images/img083x4.png ADDED
images/img092x4.png ADDED
models/seemore.py ADDED
@@ -0,0 +1,416 @@
+ from typing import Tuple, List
+ from torch import Tensor
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops.layers.torch import Rearrange
+
+
+ ######################
+ # Meta Architecture
+ ######################
+ class SeemoRe(nn.Module):
+     def __init__(self,
+                  scale: int = 4,
+                  in_chans: int = 3,
+                  num_experts: int = 6,
+                  num_layers: int = 6,
+                  embedding_dim: int = 64,
+                  img_range: float = 1.0,
+                  use_shuffle: bool = False,
+                  global_kernel_size: int = 11,
+                  recursive: int = 2,
+                  lr_space: int = 1,
+                  topk: int = 2,):
+         super().__init__()
+         self.scale = scale
+         self.num_in_channels = in_chans
+         self.num_out_channels = in_chans
+         self.img_range = img_range
+
+         rgb_mean = (0.4488, 0.4371, 0.4040)
+         self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+
+         # -- SHALLOW FEATURES --
+         self.conv_1 = nn.Conv2d(self.num_in_channels, embedding_dim, kernel_size=3, padding=1)
+
+         # -- DEEP FEATURES --
+         self.body = nn.ModuleList(
+             [ResGroup(in_ch=embedding_dim,
+                       num_experts=num_experts,
+                       use_shuffle=use_shuffle,
+                       topk=topk,
+                       lr_space=lr_space,
+                       recursive=recursive,
+                       global_kernel_size=global_kernel_size) for i in range(num_layers)]
+         )
+
+         # -- UPSCALE --
+         self.norm = LayerNorm(embedding_dim, data_format='channels_first')
+         self.conv_2 = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1)
+         self.upsampler = nn.Sequential(
+             nn.Conv2d(embedding_dim, (scale**2) * self.num_out_channels, kernel_size=3, padding=1),
+             nn.PixelShuffle(scale)
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         self.mean = self.mean.type_as(x)
+         x = (x - self.mean) * self.img_range
+
+         # -- SHALLOW FEATURES --
+         x = self.conv_1(x)
+         res = x
+
+         # -- DEEP FEATURES --
+         for idx, layer in enumerate(self.body):
+             x = layer(x)
+
+         x = self.norm(x)
+
+         # -- HR IMAGE RECONSTRUCTION --
+         x = self.conv_2(x) + res
+         x = self.upsampler(x)
+
+         x = x / self.img_range + self.mean
+         return x
+
+
+
+ #############################
+ # Components
+ #############################
+ class ResGroup(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  num_experts: int,
+                  global_kernel_size: int = 11,
+                  lr_space: int = 1,
+                  topk: int = 2,
+                  recursive: int = 2,
+                  use_shuffle: bool = False):
+         super().__init__()
+
+         self.local_block = RME(in_ch=in_ch,
+                                num_experts=num_experts,
+                                use_shuffle=use_shuffle,
+                                lr_space=lr_space,
+                                topk=topk,
+                                recursive=recursive)
+         self.global_block = SME(in_ch=in_ch,
+                                 kernel_size=global_kernel_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.local_block(x)
+         x = self.global_block(x)
+         return x
+
+
+
+ #############################
+ # Global Block
+ #############################
+ class SME(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  kernel_size: int = 11):
+         super().__init__()
+
+         self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
+         self.block = StripedConvFormer(in_ch=in_ch, kernel_size=kernel_size)
+
+         self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
+         self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.block(self.norm_1(x)) + x
+         x = self.ffn(self.norm_2(x)) + x
+         return x
+
+
+
+
+ class StripedConvFormer(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  kernel_size: int):
+         super().__init__()
+         self.in_ch = in_ch
+         self.kernel_size = kernel_size
+         self.padding = kernel_size // 2
+
+         self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
+         self.to_qv = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, padding=0),
+             nn.GELU(),
+         )
+
+         self.attn = StripedConv2d(in_ch, kernel_size=kernel_size, depthwise=True)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         q, v = self.to_qv(x).chunk(2, dim=1)
+         q = self.attn(q)
+         x = self.proj(q * v)
+         return x
+
+
+
+ #############################
+ # Local Blocks
+ #############################
+ class RME(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  num_experts: int,
+                  topk: int,
+                  lr_space: int = 1,
+                  recursive: int = 2,
+                  use_shuffle: bool = False,):
+         super().__init__()
+
+         self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
+         self.block = MoEBlock(in_ch=in_ch, num_experts=num_experts, topk=topk, use_shuffle=use_shuffle, recursive=recursive, lr_space=lr_space,)
+
+         self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
+         self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.block(self.norm_1(x)) + x
+         x = self.ffn(self.norm_2(x)) + x
+         return x
+
+
+
+ #################
+ # MoE Layer
+ #################
+ class MoEBlock(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  num_experts: int,
+                  topk: int,
+                  use_shuffle: bool = False,
+                  lr_space: str = "linear",
+                  recursive: int = 2):
+         super().__init__()
+         self.use_shuffle = use_shuffle
+         self.recursive = recursive
+
+         self.conv_1 = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
+             nn.GELU(),
+             nn.Conv2d(in_ch, 2*in_ch, kernel_size=1, padding=0)
+         )
+
+         self.agg_conv = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=4, stride=4, groups=in_ch),
+             nn.GELU())
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch),
+             nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
+         )
+
+         self.conv_2 = nn.Sequential(
+             StripedConv2d(in_ch, kernel_size=3, depthwise=True),
+             nn.GELU())
+
+         if lr_space == "linear":
+             grow_func = lambda i: i+2
+         elif lr_space == "exp":
+             grow_func = lambda i: 2**(i+1)
+         elif lr_space == "double":
+             grow_func = lambda i: 2*i+2
+         else:
+             raise NotImplementedError(f"lr_space {lr_space} not implemented")
+
+         self.moe_layer = MoELayer(
+             experts=[Expert(in_ch=in_ch, low_dim=grow_func(i)) for i in range(num_experts)],  # add here multiple of 2 as low_dim
+             gate=Router(in_ch=in_ch, num_experts=num_experts),
+             num_expert=topk,
+         )
+
+         self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
+
+     def calibrate(self, x: torch.Tensor) -> torch.Tensor:
+         b, c, h, w = x.shape
+         res = x
+
+         for _ in range(self.recursive):
+             x = self.agg_conv(x)
+         x = self.conv(x)
+         x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
+         return res + x
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv_1(x)
+
+         if self.use_shuffle:
+             x = channel_shuffle(x, groups=2)
+         x, k = torch.chunk(x, chunks=2, dim=1)
+
+         x = self.conv_2(x)
+         k = self.calibrate(k)
+
+         x = self.moe_layer(x, k)
+         x = self.proj(x)
+         return x
+
+
+ class MoELayer(nn.Module):
+     def __init__(self, experts: List[nn.Module], gate: nn.Module, num_expert: int = 1):
+         super().__init__()
+         assert len(experts) > 0
+         self.experts = nn.ModuleList(experts)
+         self.gate = gate
+         self.num_expert = num_expert
+
+     def forward(self, inputs: torch.Tensor, k: torch.Tensor):
+         out = self.gate(inputs)
+         weights = F.softmax(out, dim=1, dtype=torch.float).to(inputs.dtype)
+         topk_weights, topk_experts = torch.topk(weights, self.num_expert)
+         out = inputs.clone()
+
+         if self.training:
+             exp_weights = torch.zeros_like(weights)
+             exp_weights.scatter_(1, topk_experts, weights.gather(1, topk_experts))
+             for i, expert in enumerate(self.experts):
+                 out += expert(inputs, k) * exp_weights[:, i:i+1, None, None]
+         else:
+             selected_experts = [self.experts[i] for i in topk_experts.squeeze(dim=0)]  # NOTE: this eval-time routing assumes batch size 1
+             for i, expert in enumerate(selected_experts):
+                 out += expert(inputs, k) * topk_weights[:, i:i+1, None, None]
+
+         return out
+
+
+
+ class Expert(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  low_dim: int,):
+         super().__init__()
+         self.conv_1 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
+         self.conv_2 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
+         self.conv_3 = nn.Conv2d(low_dim, in_ch, kernel_size=1, padding=0)
+
+     def forward(self, x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+         x = self.conv_1(x)
+         x = self.conv_2(k) * x  # here no more sigmoid
+         x = self.conv_3(x)
+         return x
+
+
+ class Router(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  num_experts: int):
+         super().__init__()
+
+         self.body = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             Rearrange('b c 1 1 -> b c'),
+             nn.Linear(in_ch, num_experts, bias=False),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.body(x)
+
+
+
+ #################
+ # Utilities
+ #################
+ class StripedConv2d(nn.Module):
+     def __init__(self,
+                  in_ch: int,
+                  kernel_size: int,
+                  depthwise: bool = False):
+         super().__init__()
+         self.in_ch = in_ch
+         self.kernel_size = kernel_size
+         self.padding = kernel_size // 2
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=(1, self.kernel_size), padding=(0, self.padding), groups=in_ch if depthwise else 1),
+             nn.Conv2d(in_ch, in_ch, kernel_size=(self.kernel_size, 1), padding=(self.padding, 0), groups=in_ch if depthwise else 1),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.conv(x)
+
+
+
+ def channel_shuffle(x, groups=2):
+     bat_size, channels, w, h = x.shape
+     group_c = channels // groups
+     x = x.view(bat_size, groups, group_c, w, h)
+     x = torch.transpose(x, 1, 2).contiguous()
+     x = x.view(bat_size, -1, w, h)
+     return x
+
+
+ class GatedFFN(nn.Module):
+     def __init__(self,
+                  in_ch,
+                  mlp_ratio,
+                  kernel_size,
+                  act_layer,):
+         super().__init__()
+         mlp_ch = in_ch * mlp_ratio
+
+         self.fn_1 = nn.Sequential(
+             nn.Conv2d(in_ch, mlp_ch, kernel_size=1, padding=0),
+             act_layer,
+         )
+         self.fn_2 = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0),
+             act_layer,
+         )
+
+         self.gate = nn.Conv2d(mlp_ch // 2, mlp_ch // 2,
+                               kernel_size=kernel_size, padding=kernel_size // 2, groups=mlp_ch // 2)
+
+     def feat_decompose(self, x):  # NOTE: unused here; relies on self.sigma, which this module never defines
+         s = x - self.gate(x)
+         x = x + self.sigma * s
+         return x
+
+     def forward(self, x: torch.Tensor):
+         x = self.fn_1(x)
+         x, gate = torch.chunk(x, 2, dim=1)
+
+         gate = self.gate(gate)
+         x = x * gate
+
+         x = self.fn_2(x)
+         return x
+
+
+
+ class LayerNorm(nn.Module):
+     r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+     The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+     shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+     with shape (batch_size, channels, height, width).
+     """
+     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.data_format = data_format
+         if self.data_format not in ["channels_last", "channels_first"]:
+             raise NotImplementedError
+         self.normalized_shape = (normalized_shape, )
+
+     def forward(self, x):
+         if self.data_format == "channels_last":
+             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         elif self.data_format == "channels_first":
+             u = x.mean(1, keepdim=True)
+             s = (x - u).pow(2).mean(1, keepdim=True)
+             x = (x - u) / torch.sqrt(s + self.eps)
+             x = self.weight[:, None, None] * x + self.bias[:, None, None]
+             return x
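
A minimal shape check for the module above, a sketch assuming the hyper-parameters from `configs/eval_seemore_t_x4.yml` and a random input (so the numbers are illustrative, not outputs of the released checkpoint):

import torch
from models.seemore import SeemoRe

model = SeemoRe(scale=4, in_chans=3, num_experts=3, num_layers=6, embedding_dim=36,
                img_range=1.0, use_shuffle=True, global_kernel_size=11,
                recursive=2, lr_space="exp", topk=1).eval()

x = torch.rand(1, 3, 64, 64)   # batch size 1: the eval-mode routing path squeezes the batch dimension
with torch.no_grad():
    y = model(x)
print(y.shape)                 # expected: torch.Size([1, 3, 256, 256]) for x4 upscaling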
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ numpy
+ einops
+ PyYAML
+ Pillow>=6.2.2
+ gradio_imageslider==0.0.20