Spaces:
Runtime error
Runtime error
watchtowerss
commited on
Commit
•
4d1ebf3
1
Parent(s):
663e9a6
track-anything --version 1
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- LICENSE +21 -0
- README.md +47 -13
- XMem-s012.pth +3 -0
- app.py +362 -0
- app_save.py +381 -0
- app_test.py +23 -0
- assets/demo_version_1.MP4 +3 -0
- assets/inpainting.gif +3 -0
- assets/poster_demo_version_1.png +0 -0
- assets/qingming.mp4 +3 -0
- demo.py +87 -0
- images/groceries.jpg +0 -0
- images/mask_painter.png +0 -0
- images/painter_input_image.jpg +0 -0
- images/painter_input_mask.jpg +0 -0
- images/painter_output_image.png +0 -0
- images/painter_output_image__.png +0 -0
- images/point_painter.png +0 -0
- images/point_painter_1.png +0 -0
- images/point_painter_2.png +0 -0
- images/truck.jpg +0 -0
- images/truck_both.jpg +0 -0
- images/truck_mask.jpg +0 -0
- images/truck_point.jpg +0 -0
- inpainter/.DS_Store +0 -0
- inpainter/base_inpainter.py +160 -0
- inpainter/config/config.yaml +4 -0
- inpainter/model/e2fgvi.py +350 -0
- inpainter/model/e2fgvi_hq.py +350 -0
- inpainter/model/modules/feat_prop.py +149 -0
- inpainter/model/modules/flow_comp.py +450 -0
- inpainter/model/modules/spectral_norm.py +288 -0
- inpainter/model/modules/tfocal_transformer.py +536 -0
- inpainter/model/modules/tfocal_transformer_hq.py +565 -0
- inpainter/util/__init__.py +0 -0
- inpainter/util/tensor_util.py +24 -0
- requirements.txt +17 -0
- sam_vit_h_4b8939.pth +3 -0
- template.html +27 -0
- templates/index.html +50 -0
- text_server.py +72 -0
- tools/__init__.py +0 -0
- tools/base_segmenter.py +129 -0
- tools/interact_tools.py +265 -0
- tools/mask_painter.py +288 -0
- tools/painter.py +215 -0
- track_anything.py +93 -0
- tracker/.DS_Store +0 -0
- tracker/base_tracker.py +233 -0
.gitattributes
CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
assets/demo_version_1.MP4 filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/inpainting.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/qingming.mp4 filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Mingqi Gao
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,47 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Track-Anything
|
2 |
+
|
3 |
+
***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is developed upon [Segment Anything](https://github.com/facebookresearch/segment-anything), can specify anything to track and segment via user clicks only. During tracking, users can flexibly change the objects they wanna track or correct the region of interest if there are any ambiguities. These characteristics enable ***Track-Anything*** to be suitable for:
|
4 |
+
- Video object tracking and segmentation with shot changes.
|
5 |
+
- Data annnotation for video object tracking and segmentation.
|
6 |
+
- Object-centric downstream video tasks, such as video inpainting and editing.
|
7 |
+
|
8 |
+
## Demo
|
9 |
+
|
10 |
+
https://user-images.githubusercontent.com/28050374/232842703-8395af24-b13e-4b8e-aafb-e94b61e6c449.MP4
|
11 |
+
|
12 |
+
### Multiple Object Tracking and Segmentation (with [XMem](https://github.com/hkchengrex/XMem))
|
13 |
+
|
14 |
+
https://user-images.githubusercontent.com/39208339/233035206-0a151004-6461-4deb-b782-d1dbfe691493.mp4
|
15 |
+
|
16 |
+
### Video Object Tracking and Segmentation with Shot Changes (with [XMem](https://github.com/hkchengrex/XMem))
|
17 |
+
|
18 |
+
https://user-images.githubusercontent.com/30309970/232848349-f5e29e71-2ea4-4529-ac9a-94b9ca1e7055.mp4
|
19 |
+
|
20 |
+
### Video Inpainting (with [E2FGVI](https://github.com/MCG-NKU/E2FGVI))
|
21 |
+
|
22 |
+
https://user-images.githubusercontent.com/28050374/232959816-07f2826f-d267-4dda-8ae5-a5132173b8f4.mp4
|
23 |
+
|
24 |
+
## Get Started
|
25 |
+
#### Linux
|
26 |
+
```bash
|
27 |
+
# Clone the repository:
|
28 |
+
git clone https://github.com/gaomingqi/Track-Anything.git
|
29 |
+
cd Track-Anything
|
30 |
+
|
31 |
+
# Install dependencies:
|
32 |
+
pip install -r requirements.txt
|
33 |
+
|
34 |
+
# Install dependencies for inpainting:
|
35 |
+
pip install -U openmim
|
36 |
+
mim install mmcv
|
37 |
+
|
38 |
+
# Install dependencies for editing
|
39 |
+
pip install madgrad
|
40 |
+
|
41 |
+
# Run the Track-Anything gradio demo.
|
42 |
+
python app.py --device cuda:0 --sam_model_type vit_h --port 12212
|
43 |
+
```
|
44 |
+
|
45 |
+
## Acknowledgements
|
46 |
+
|
47 |
+
The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything), [XMem](https://github.com/hkchengrex/XMem), and [E2FGVI](https://github.com/MCG-NKU/E2FGVI). Thanks for the authors for their efforts.
|
XMem-s012.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:16205ad04bfc55b442bd4d7af894382e09868b35e10721c5afc09a24ea8d72d9
|
3 |
+
size 249026057
|
app.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from demo import automask_image_app, automask_video_app, sahi_autoseg_app
|
3 |
+
import argparse
|
4 |
+
import cv2
|
5 |
+
import time
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
sys.path.append(sys.path[0]+"/tracker")
|
11 |
+
sys.path.append(sys.path[0]+"/tracker/model")
|
12 |
+
from track_anything import TrackingAnything
|
13 |
+
from track_anything import parse_augment
|
14 |
+
import requests
|
15 |
+
import json
|
16 |
+
import torchvision
|
17 |
+
import torch
|
18 |
+
import concurrent.futures
|
19 |
+
import queue
|
20 |
+
|
21 |
+
# download checkpoints
|
22 |
+
def download_checkpoint(url, folder, filename):
|
23 |
+
os.makedirs(folder, exist_ok=True)
|
24 |
+
filepath = os.path.join(folder, filename)
|
25 |
+
|
26 |
+
if not os.path.exists(filepath):
|
27 |
+
print("download checkpoints ......")
|
28 |
+
response = requests.get(url, stream=True)
|
29 |
+
with open(filepath, "wb") as f:
|
30 |
+
for chunk in response.iter_content(chunk_size=8192):
|
31 |
+
if chunk:
|
32 |
+
f.write(chunk)
|
33 |
+
|
34 |
+
print("download successfully!")
|
35 |
+
|
36 |
+
return filepath
|
37 |
+
|
38 |
+
# convert points input to prompt state
|
39 |
+
def get_prompt(click_state, click_input):
|
40 |
+
inputs = json.loads(click_input)
|
41 |
+
points = click_state[0]
|
42 |
+
labels = click_state[1]
|
43 |
+
for input in inputs:
|
44 |
+
points.append(input[:2])
|
45 |
+
labels.append(input[2])
|
46 |
+
click_state[0] = points
|
47 |
+
click_state[1] = labels
|
48 |
+
prompt = {
|
49 |
+
"prompt_type":["click"],
|
50 |
+
"input_point":click_state[0],
|
51 |
+
"input_label":click_state[1],
|
52 |
+
"multimask_output":"True",
|
53 |
+
}
|
54 |
+
return prompt
|
55 |
+
|
56 |
+
# extract frames from upload video
|
57 |
+
def get_frames_from_video(video_input, video_state):
|
58 |
+
"""
|
59 |
+
Args:
|
60 |
+
video_path:str
|
61 |
+
timestamp:float64
|
62 |
+
Return
|
63 |
+
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
64 |
+
"""
|
65 |
+
video_path = video_input
|
66 |
+
frames = []
|
67 |
+
try:
|
68 |
+
cap = cv2.VideoCapture(video_path)
|
69 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
70 |
+
while cap.isOpened():
|
71 |
+
ret, frame = cap.read()
|
72 |
+
if ret == True:
|
73 |
+
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
74 |
+
else:
|
75 |
+
break
|
76 |
+
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
|
77 |
+
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
78 |
+
|
79 |
+
# initialize video_state
|
80 |
+
video_state = {
|
81 |
+
"video_name": os.path.split(video_path)[-1],
|
82 |
+
"origin_images": frames,
|
83 |
+
"painted_images": frames.copy(),
|
84 |
+
"masks": [None]*len(frames),
|
85 |
+
"logits": [None]*len(frames),
|
86 |
+
"select_frame_number": 0,
|
87 |
+
"fps": 30
|
88 |
+
}
|
89 |
+
return video_state, gr.update(visible=True, maximum=len(frames), value=1)
|
90 |
+
|
91 |
+
# get the select frame from gradio slider
|
92 |
+
def select_template(image_selection_slider, video_state):
|
93 |
+
|
94 |
+
# images = video_state[1]
|
95 |
+
image_selection_slider -= 1
|
96 |
+
video_state["select_frame_number"] = image_selection_slider
|
97 |
+
|
98 |
+
# once select a new template frame, set the image in sam
|
99 |
+
|
100 |
+
model.samcontroler.sam_controler.reset_image()
|
101 |
+
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
|
102 |
+
|
103 |
+
|
104 |
+
return video_state["painted_images"][image_selection_slider], video_state
|
105 |
+
|
106 |
+
# use sam to get the mask
|
107 |
+
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
|
108 |
+
"""
|
109 |
+
Args:
|
110 |
+
template_frame: PIL.Image
|
111 |
+
point_prompt: flag for positive or negative button click
|
112 |
+
click_state: [[points], [labels]]
|
113 |
+
"""
|
114 |
+
if point_prompt == "Positive":
|
115 |
+
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
|
116 |
+
interactive_state["positive_click_times"] += 1
|
117 |
+
else:
|
118 |
+
coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
|
119 |
+
interactive_state["negative_click_times"] += 1
|
120 |
+
|
121 |
+
# prompt for sam model
|
122 |
+
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
123 |
+
|
124 |
+
mask, logit, painted_image = model.first_frame_click(
|
125 |
+
image=video_state["origin_images"][video_state["select_frame_number"]],
|
126 |
+
points=np.array(prompt["input_point"]),
|
127 |
+
labels=np.array(prompt["input_label"]),
|
128 |
+
multimask=prompt["multimask_output"],
|
129 |
+
)
|
130 |
+
video_state["masks"][video_state["select_frame_number"]] = mask
|
131 |
+
video_state["logits"][video_state["select_frame_number"]] = logit
|
132 |
+
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
|
133 |
+
|
134 |
+
return painted_image, video_state, interactive_state
|
135 |
+
|
136 |
+
# tracking vos
|
137 |
+
def vos_tracking_video(video_state, interactive_state):
|
138 |
+
model.xmem.clear_memory()
|
139 |
+
following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
|
140 |
+
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
141 |
+
fps = video_state["fps"]
|
142 |
+
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
143 |
+
|
144 |
+
video_state["masks"][video_state["select_frame_number"]:] = masks
|
145 |
+
video_state["logits"][video_state["select_frame_number"]:] = logits
|
146 |
+
video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
|
147 |
+
|
148 |
+
video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
149 |
+
interactive_state["inference_times"] += 1
|
150 |
+
|
151 |
+
print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
|
152 |
+
interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
|
153 |
+
interactive_state["positive_click_times"],
|
154 |
+
interactive_state["negative_click_times"]))
|
155 |
+
|
156 |
+
#### shanggao code for mask save
|
157 |
+
if interactive_state["mask_save"]:
|
158 |
+
if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
|
159 |
+
os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0]))
|
160 |
+
i = 0
|
161 |
+
print("save mask")
|
162 |
+
for mask in video_state["masks"]:
|
163 |
+
np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
|
164 |
+
i+=1
|
165 |
+
# save_mask(video_state["masks"], video_state["video_name"])
|
166 |
+
#### shanggao code for mask save
|
167 |
+
return video_output, video_state, interactive_state
|
168 |
+
|
169 |
+
# generate video after vos inference
|
170 |
+
def generate_video_from_frames(frames, output_path, fps=30):
|
171 |
+
"""
|
172 |
+
Generates a video from a list of frames.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
frames (list of numpy arrays): The frames to include in the video.
|
176 |
+
output_path (str): The path to save the generated video.
|
177 |
+
fps (int, optional): The frame rate of the output video. Defaults to 30.
|
178 |
+
"""
|
179 |
+
frames = torch.from_numpy(np.asarray(frames))
|
180 |
+
if not os.path.exists(os.path.dirname(output_path)):
|
181 |
+
os.makedirs(os.path.dirname(output_path))
|
182 |
+
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
183 |
+
return output_path
|
184 |
+
|
185 |
+
# check and download checkpoints if needed
|
186 |
+
SAM_checkpoint = "sam_vit_h_4b8939.pth"
|
187 |
+
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
188 |
+
xmem_checkpoint = "XMem-s012.pth"
|
189 |
+
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
190 |
+
folder ="./checkpoints"
|
191 |
+
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
|
192 |
+
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
193 |
+
|
194 |
+
# args, defined in track_anything.py
|
195 |
+
args = parse_augment()
|
196 |
+
# args.port = 12212
|
197 |
+
# args.device = "cuda:4"
|
198 |
+
# args.mask_save = True
|
199 |
+
|
200 |
+
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
201 |
+
|
202 |
+
with gr.Blocks() as iface:
|
203 |
+
"""
|
204 |
+
state for
|
205 |
+
"""
|
206 |
+
click_state = gr.State([[],[]])
|
207 |
+
interactive_state = gr.State({
|
208 |
+
"inference_times": 0,
|
209 |
+
"negative_click_times" : 0,
|
210 |
+
"positive_click_times": 0,
|
211 |
+
"mask_save": args.mask_save
|
212 |
+
})
|
213 |
+
video_state = gr.State(
|
214 |
+
{
|
215 |
+
"video_name": "",
|
216 |
+
"origin_images": None,
|
217 |
+
"painted_images": None,
|
218 |
+
"masks": None,
|
219 |
+
"logits": None,
|
220 |
+
"select_frame_number": 0,
|
221 |
+
"fps": 30
|
222 |
+
}
|
223 |
+
)
|
224 |
+
|
225 |
+
with gr.Row():
|
226 |
+
|
227 |
+
# for user video input
|
228 |
+
with gr.Column(scale=1.0):
|
229 |
+
video_input = gr.Video().style(height=360)
|
230 |
+
|
231 |
+
|
232 |
+
|
233 |
+
with gr.Row(scale=1):
|
234 |
+
# put the template frame under the radio button
|
235 |
+
with gr.Column(scale=0.5):
|
236 |
+
# extract frames
|
237 |
+
with gr.Column():
|
238 |
+
extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
|
239 |
+
|
240 |
+
# click points settins, negative or positive, mode continuous or single
|
241 |
+
with gr.Row():
|
242 |
+
with gr.Row(scale=0.5):
|
243 |
+
point_prompt = gr.Radio(
|
244 |
+
choices=["Positive", "Negative"],
|
245 |
+
value="Positive",
|
246 |
+
label="Point Prompt",
|
247 |
+
interactive=True)
|
248 |
+
click_mode = gr.Radio(
|
249 |
+
choices=["Continuous", "Single"],
|
250 |
+
value="Continuous",
|
251 |
+
label="Clicking Mode",
|
252 |
+
interactive=True)
|
253 |
+
with gr.Row(scale=0.5):
|
254 |
+
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
|
255 |
+
clear_button_image = gr.Button(value="Clear Image", interactive=True)
|
256 |
+
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
|
257 |
+
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", invisible=False)
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
|
262 |
+
with gr.Column(scale=0.5):
|
263 |
+
video_output = gr.Video().style(height=360)
|
264 |
+
tracking_video_predict_button = gr.Button(value="Tracking")
|
265 |
+
|
266 |
+
# first step: get the video information
|
267 |
+
extract_frames_button.click(
|
268 |
+
fn=get_frames_from_video,
|
269 |
+
inputs=[
|
270 |
+
video_input, video_state
|
271 |
+
],
|
272 |
+
outputs=[video_state, image_selection_slider],
|
273 |
+
)
|
274 |
+
|
275 |
+
# second step: select images from slider
|
276 |
+
image_selection_slider.release(fn=select_template,
|
277 |
+
inputs=[image_selection_slider, video_state],
|
278 |
+
outputs=[template_frame, video_state], api_name="select_image")
|
279 |
+
|
280 |
+
|
281 |
+
template_frame.select(
|
282 |
+
fn=sam_refine,
|
283 |
+
inputs=[video_state, point_prompt, click_state, interactive_state],
|
284 |
+
outputs=[template_frame, video_state, interactive_state]
|
285 |
+
)
|
286 |
+
|
287 |
+
tracking_video_predict_button.click(
|
288 |
+
fn=vos_tracking_video,
|
289 |
+
inputs=[video_state, interactive_state],
|
290 |
+
outputs=[video_output, video_state, interactive_state]
|
291 |
+
)
|
292 |
+
|
293 |
+
|
294 |
+
# clear input
|
295 |
+
video_input.clear(
|
296 |
+
lambda: (
|
297 |
+
{
|
298 |
+
"origin_images": None,
|
299 |
+
"painted_images": None,
|
300 |
+
"masks": None,
|
301 |
+
"logits": None,
|
302 |
+
"select_frame_number": 0,
|
303 |
+
"fps": 30
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"inference_times": 0,
|
307 |
+
"negative_click_times" : 0,
|
308 |
+
"positive_click_times": 0,
|
309 |
+
"mask_save": args.mask_save
|
310 |
+
},
|
311 |
+
[[],[]]
|
312 |
+
),
|
313 |
+
[],
|
314 |
+
[
|
315 |
+
video_state,
|
316 |
+
interactive_state,
|
317 |
+
click_state,
|
318 |
+
],
|
319 |
+
queue=False,
|
320 |
+
show_progress=False
|
321 |
+
)
|
322 |
+
clear_button_image.click(
|
323 |
+
lambda: (
|
324 |
+
{
|
325 |
+
"origin_images": None,
|
326 |
+
"painted_images": None,
|
327 |
+
"masks": None,
|
328 |
+
"logits": None,
|
329 |
+
"select_frame_number": 0,
|
330 |
+
"fps": 30
|
331 |
+
},
|
332 |
+
{
|
333 |
+
"inference_times": 0,
|
334 |
+
"negative_click_times" : 0,
|
335 |
+
"positive_click_times": 0,
|
336 |
+
"mask_save": args.mask_save
|
337 |
+
},
|
338 |
+
[[],[]]
|
339 |
+
),
|
340 |
+
[],
|
341 |
+
[
|
342 |
+
video_state,
|
343 |
+
interactive_state,
|
344 |
+
click_state,
|
345 |
+
],
|
346 |
+
|
347 |
+
queue=False,
|
348 |
+
show_progress=False
|
349 |
+
|
350 |
+
)
|
351 |
+
clear_button_clike.click(
|
352 |
+
lambda: ([[],[]]),
|
353 |
+
[],
|
354 |
+
[click_state],
|
355 |
+
queue=False,
|
356 |
+
show_progress=False
|
357 |
+
)
|
358 |
+
iface.queue(concurrency_count=1)
|
359 |
+
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
|
360 |
+
|
361 |
+
|
362 |
+
|
app_save.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from demo import automask_image_app, automask_video_app, sahi_autoseg_app
|
3 |
+
import argparse
|
4 |
+
import cv2
|
5 |
+
import time
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
sys.path.append(sys.path[0]+"/tracker")
|
11 |
+
sys.path.append(sys.path[0]+"/tracker/model")
|
12 |
+
from track_anything import TrackingAnything
|
13 |
+
from track_anything import parse_augment
|
14 |
+
import requests
|
15 |
+
import json
|
16 |
+
import torchvision
|
17 |
+
import torch
|
18 |
+
import concurrent.futures
|
19 |
+
import queue
|
20 |
+
|
21 |
+
def download_checkpoint(url, folder, filename):
|
22 |
+
os.makedirs(folder, exist_ok=True)
|
23 |
+
filepath = os.path.join(folder, filename)
|
24 |
+
|
25 |
+
if not os.path.exists(filepath):
|
26 |
+
print("download checkpoints ......")
|
27 |
+
response = requests.get(url, stream=True)
|
28 |
+
with open(filepath, "wb") as f:
|
29 |
+
for chunk in response.iter_content(chunk_size=8192):
|
30 |
+
if chunk:
|
31 |
+
f.write(chunk)
|
32 |
+
|
33 |
+
print("download successfully!")
|
34 |
+
|
35 |
+
return filepath
|
36 |
+
|
37 |
+
def pause_video(play_state):
|
38 |
+
print("user pause_video")
|
39 |
+
play_state.append(time.time())
|
40 |
+
return play_state
|
41 |
+
|
42 |
+
def play_video(play_state):
|
43 |
+
print("user play_video")
|
44 |
+
play_state.append(time.time())
|
45 |
+
return play_state
|
46 |
+
|
47 |
+
# convert points input to prompt state
|
48 |
+
def get_prompt(click_state, click_input):
|
49 |
+
inputs = json.loads(click_input)
|
50 |
+
points = click_state[0]
|
51 |
+
labels = click_state[1]
|
52 |
+
for input in inputs:
|
53 |
+
points.append(input[:2])
|
54 |
+
labels.append(input[2])
|
55 |
+
click_state[0] = points
|
56 |
+
click_state[1] = labels
|
57 |
+
prompt = {
|
58 |
+
"prompt_type":["click"],
|
59 |
+
"input_point":click_state[0],
|
60 |
+
"input_label":click_state[1],
|
61 |
+
"multimask_output":"True",
|
62 |
+
}
|
63 |
+
return prompt
|
64 |
+
|
65 |
+
def get_frames_from_video(video_input, play_state):
|
66 |
+
"""
|
67 |
+
Args:
|
68 |
+
video_path:str
|
69 |
+
timestamp:float64
|
70 |
+
Return
|
71 |
+
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
72 |
+
"""
|
73 |
+
video_path = video_input
|
74 |
+
# video_name = video_path.split('/')[-1]
|
75 |
+
|
76 |
+
try:
|
77 |
+
timestamp = play_state[1] - play_state[0]
|
78 |
+
except:
|
79 |
+
timestamp = 0
|
80 |
+
frames = []
|
81 |
+
try:
|
82 |
+
cap = cv2.VideoCapture(video_path)
|
83 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
84 |
+
while cap.isOpened():
|
85 |
+
ret, frame = cap.read()
|
86 |
+
if ret == True:
|
87 |
+
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
88 |
+
else:
|
89 |
+
break
|
90 |
+
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
|
91 |
+
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
92 |
+
|
93 |
+
# for index, frame in enumerate(frames):
|
94 |
+
# frames[index] = np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
|
95 |
+
|
96 |
+
key_frame_index = int(timestamp * fps)
|
97 |
+
nearest_frame = frames[key_frame_index]
|
98 |
+
frames_split = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
|
99 |
+
# output_path='./seperate.mp4'
|
100 |
+
# torchvision.io.write_video(output_path, frames[1], fps=fps, video_codec="libx264")
|
101 |
+
|
102 |
+
# set image in sam when select the template frame
|
103 |
+
model.samcontroler.sam_controler.set_image(nearest_frame)
|
104 |
+
return frames_split, nearest_frame, nearest_frame, fps
|
105 |
+
|
106 |
+
def generate_video_from_frames(frames, output_path, fps=30):
|
107 |
+
"""
|
108 |
+
Generates a video from a list of frames.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
frames (list of numpy arrays): The frames to include in the video.
|
112 |
+
output_path (str): The path to save the generated video.
|
113 |
+
fps (int, optional): The frame rate of the output video. Defaults to 30.
|
114 |
+
"""
|
115 |
+
# height, width, layers = frames[0].shape
|
116 |
+
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
117 |
+
# video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
118 |
+
|
119 |
+
# for frame in frames:
|
120 |
+
# video.write(frame)
|
121 |
+
|
122 |
+
# video.release()
|
123 |
+
frames = torch.from_numpy(np.asarray(frames))
|
124 |
+
output_path='./output.mp4'
|
125 |
+
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
126 |
+
return output_path
|
127 |
+
|
128 |
+
def model_reset():
|
129 |
+
model.xmem.clear_memory()
|
130 |
+
return None
|
131 |
+
|
132 |
+
def sam_refine(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
|
133 |
+
"""
|
134 |
+
Args:
|
135 |
+
template_frame: PIL.Image
|
136 |
+
point_prompt: flag for positive or negative button click
|
137 |
+
click_state: [[points], [labels]]
|
138 |
+
"""
|
139 |
+
if point_prompt == "Positive":
|
140 |
+
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
|
141 |
+
else:
|
142 |
+
coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
|
143 |
+
|
144 |
+
# prompt for sam model
|
145 |
+
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
146 |
+
|
147 |
+
# default value
|
148 |
+
# points = np.array([[evt.index[0],evt.index[1]]])
|
149 |
+
# labels= np.array([1])
|
150 |
+
if len(logit)==0:
|
151 |
+
logit = None
|
152 |
+
|
153 |
+
mask, logit, painted_image = model.first_frame_click(
|
154 |
+
image=origin_frame,
|
155 |
+
points=np.array(prompt["input_point"]),
|
156 |
+
labels=np.array(prompt["input_label"]),
|
157 |
+
multimask=prompt["multimask_output"],
|
158 |
+
)
|
159 |
+
return painted_image, click_state, logit, mask
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
def vos_tracking_video(video_state, template_mask,fps,video_input):
|
164 |
+
|
165 |
+
masks, logits, painted_images = model.generator(images=video_state[1], template_mask=template_mask)
|
166 |
+
video_output = generate_video_from_frames(painted_images, output_path="./output.mp4", fps=fps)
|
167 |
+
# image_selection_slider = gr.Slider(minimum=1, maximum=len(video_state[1]), value=1, label="Image Selection", interactive=True)
|
168 |
+
video_name = video_input.split('/')[-1].split('.')[0]
|
169 |
+
result_path = os.path.join('/hhd3/gaoshang/Track-Anything/results/'+video_name)
|
170 |
+
if not os.path.exists(result_path):
|
171 |
+
os.makedirs(result_path)
|
172 |
+
i=0
|
173 |
+
for mask in masks:
|
174 |
+
np.save(os.path.join(result_path,'{:05}.npy'.format(i)), mask)
|
175 |
+
i+=1
|
176 |
+
return video_output, painted_images, masks, logits
|
177 |
+
|
178 |
+
def vos_tracking_image(image_selection_slider, painted_images):
|
179 |
+
|
180 |
+
# images = video_state[1]
|
181 |
+
percentage = image_selection_slider / 100
|
182 |
+
select_frame_num = int(percentage * len(painted_images))
|
183 |
+
return painted_images[select_frame_num], select_frame_num
|
184 |
+
|
185 |
+
def interactive_correction(video_state, point_prompt, click_state, select_correction_frame, evt: gr.SelectData):
|
186 |
+
"""
|
187 |
+
Args:
|
188 |
+
template_frame: PIL.Image
|
189 |
+
point_prompt: flag for positive or negative button click
|
190 |
+
click_state: [[points], [labels]]
|
191 |
+
"""
|
192 |
+
refine_image = video_state[1][select_correction_frame]
|
193 |
+
if point_prompt == "Positive":
|
194 |
+
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
|
195 |
+
else:
|
196 |
+
coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
|
197 |
+
|
198 |
+
# prompt for sam model
|
199 |
+
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
200 |
+
model.samcontroler.seg_again(refine_image)
|
201 |
+
corrected_mask, corrected_logit, corrected_painted_image = model.first_frame_click(
|
202 |
+
image=refine_image,
|
203 |
+
points=np.array(prompt["input_point"]),
|
204 |
+
labels=np.array(prompt["input_label"]),
|
205 |
+
multimask=prompt["multimask_output"],
|
206 |
+
)
|
207 |
+
return corrected_painted_image, [corrected_mask, corrected_logit, corrected_painted_image]
|
208 |
+
|
209 |
+
def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images, fps, video_input):
|
210 |
+
model.xmem.clear_memory()
|
211 |
+
# inference the following images
|
212 |
+
following_images = video_state[1][select_correction_frame:]
|
213 |
+
corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, template_mask=corrected_state[0])
|
214 |
+
masks = masks[:select_correction_frame] + corrected_masks
|
215 |
+
logits = logits[:select_correction_frame] + corrected_logits
|
216 |
+
painted_images = painted_images[:select_correction_frame] + corrected_painted_images
|
217 |
+
video_output = generate_video_from_frames(painted_images, output_path="./output.mp4", fps=fps)
|
218 |
+
|
219 |
+
video_name = video_input.split('/')[-1].split('.')[0]
|
220 |
+
result_path = os.path.join('/hhd3/gaoshang/Track-Anything/results/'+video_name)
|
221 |
+
if not os.path.exists(result_path):
|
222 |
+
os.makedirs(result_path)
|
223 |
+
i=0
|
224 |
+
for mask in masks:
|
225 |
+
np.save(os.path.join(result_path,'{:05}.npy'.format(i)), mask)
|
226 |
+
i+=1
|
227 |
+
return video_output, painted_images, logits, masks
|
228 |
+
|
229 |
+
# check and download checkpoints if needed
|
230 |
+
SAM_checkpoint = "sam_vit_h_4b8939.pth"
|
231 |
+
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
232 |
+
xmem_checkpoint = "XMem-s012.pth"
|
233 |
+
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
234 |
+
folder ="./checkpoints"
|
235 |
+
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
|
236 |
+
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
237 |
+
|
238 |
+
# args, defined in track_anything.py
|
239 |
+
args = parse_augment()
|
240 |
+
args.port = 12207
|
241 |
+
args.device = "cuda:5"
|
242 |
+
|
243 |
+
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
244 |
+
|
245 |
+
with gr.Blocks() as iface:
|
246 |
+
"""
|
247 |
+
state for
|
248 |
+
"""
|
249 |
+
state = gr.State([])
|
250 |
+
play_state = gr.State([])
|
251 |
+
video_state = gr.State([[],[],[]])
|
252 |
+
click_state = gr.State([[],[]])
|
253 |
+
logits = gr.State([])
|
254 |
+
masks = gr.State([])
|
255 |
+
painted_images = gr.State([])
|
256 |
+
origin_image = gr.State(None)
|
257 |
+
template_mask = gr.State(None)
|
258 |
+
select_correction_frame = gr.State(None)
|
259 |
+
corrected_state = gr.State([[],[],[]])
|
260 |
+
fps = gr.State([])
|
261 |
+
# video_name = gr.State([])
|
262 |
+
# queue value for image refresh, origin image, mask, logits, painted image
|
263 |
+
|
264 |
+
|
265 |
+
|
266 |
+
with gr.Row():
|
267 |
+
|
268 |
+
# for user video input
|
269 |
+
with gr.Column(scale=1.0):
|
270 |
+
video_input = gr.Video().style(height=720)
|
271 |
+
|
272 |
+
# listen to the user action for play and pause input video
|
273 |
+
video_input.play(fn=play_video, inputs=play_state, outputs=play_state, scroll_to_output=True, show_progress=True)
|
274 |
+
video_input.pause(fn=pause_video, inputs=play_state, outputs=play_state)
|
275 |
+
|
276 |
+
|
277 |
+
with gr.Row(scale=1):
|
278 |
+
# put the template frame under the radio button
|
279 |
+
with gr.Column(scale=0.5):
|
280 |
+
# click points settins, negative or positive, mode continuous or single
|
281 |
+
with gr.Row():
|
282 |
+
with gr.Row(scale=0.5):
|
283 |
+
point_prompt = gr.Radio(
|
284 |
+
choices=["Positive", "Negative"],
|
285 |
+
value="Positive",
|
286 |
+
label="Point Prompt",
|
287 |
+
interactive=True)
|
288 |
+
click_mode = gr.Radio(
|
289 |
+
choices=["Continuous", "Single"],
|
290 |
+
value="Continuous",
|
291 |
+
label="Clicking Mode",
|
292 |
+
interactive=True)
|
293 |
+
with gr.Row(scale=0.5):
|
294 |
+
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
|
295 |
+
clear_button_image = gr.Button(value="Clear Image", interactive=True)
|
296 |
+
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
|
297 |
+
with gr.Column():
|
298 |
+
template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
with gr.Column(scale=0.5):
|
303 |
+
|
304 |
+
|
305 |
+
# for intermedia result check and correction
|
306 |
+
# intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
|
307 |
+
video_output = gr.Video().style(height=360)
|
308 |
+
tracking_video_predict_button = gr.Button(value="Tracking")
|
309 |
+
|
310 |
+
image_output = gr.Image(type="pil", interactive=True, elem_id="image_output").style(height=360)
|
311 |
+
image_selection_slider = gr.Slider(minimum=0, maximum=100, step=0.1, value=0, label="Image Selection", interactive=True)
|
312 |
+
correct_track_button = gr.Button(value="Interactive Correction")
|
313 |
+
|
314 |
+
template_frame.select(
|
315 |
+
fn=sam_refine,
|
316 |
+
inputs=[
|
317 |
+
origin_image, point_prompt, click_state, logits
|
318 |
+
],
|
319 |
+
outputs=[
|
320 |
+
template_frame, click_state, logits, template_mask
|
321 |
+
]
|
322 |
+
)
|
323 |
+
|
324 |
+
template_select_button.click(
|
325 |
+
fn=get_frames_from_video,
|
326 |
+
inputs=[
|
327 |
+
video_input,
|
328 |
+
play_state
|
329 |
+
],
|
330 |
+
# outputs=[video_state, template_frame, origin_image, fps, video_name],
|
331 |
+
outputs=[video_state, template_frame, origin_image, fps],
|
332 |
+
)
|
333 |
+
|
334 |
+
tracking_video_predict_button.click(
|
335 |
+
fn=vos_tracking_video,
|
336 |
+
inputs=[video_state, template_mask, fps, video_input],
|
337 |
+
outputs=[video_output, painted_images, masks, logits]
|
338 |
+
)
|
339 |
+
image_selection_slider.release(fn=vos_tracking_image,
|
340 |
+
inputs=[image_selection_slider, painted_images], outputs=[image_output, select_correction_frame], api_name="select_image")
|
341 |
+
# correction
|
342 |
+
image_output.select(
|
343 |
+
fn=interactive_correction,
|
344 |
+
inputs=[video_state, point_prompt, click_state, select_correction_frame],
|
345 |
+
outputs=[image_output, corrected_state]
|
346 |
+
)
|
347 |
+
correct_track_button.click(
|
348 |
+
fn=correct_track,
|
349 |
+
inputs=[video_state, select_correction_frame, corrected_state, masks, logits, painted_images, fps,video_input],
|
350 |
+
outputs=[video_output, painted_images, logits, masks ]
|
351 |
+
)
|
352 |
+
|
353 |
+
|
354 |
+
|
355 |
+
# clear input
|
356 |
+
video_input.clear(
|
357 |
+
lambda: ([], [], [[], [], []],
|
358 |
+
None, "", "", "", "", "", "", "", [[],[]],
|
359 |
+
None),
|
360 |
+
[],
|
361 |
+
[ state, play_state, video_state,
|
362 |
+
template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
|
363 |
+
select_correction_frame],
|
364 |
+
queue=False,
|
365 |
+
show_progress=False
|
366 |
+
)
|
367 |
+
clear_button_image.click(
|
368 |
+
fn=model_reset
|
369 |
+
)
|
370 |
+
clear_button_clike.click(
|
371 |
+
lambda: ([[],[]]),
|
372 |
+
[],
|
373 |
+
[click_state],
|
374 |
+
queue=False,
|
375 |
+
show_progress=False
|
376 |
+
)
|
377 |
+
iface.queue(concurrency_count=1)
|
378 |
+
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
|
379 |
+
|
380 |
+
|
381 |
+
|
app_test.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
def update_iframe(slider_value):
|
4 |
+
return f'''
|
5 |
+
<script>
|
6 |
+
window.addEventListener('message', function(event) {{
|
7 |
+
if (event.data.sliderValue !== undefined) {{
|
8 |
+
var iframe = document.getElementById("text_iframe");
|
9 |
+
iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
|
10 |
+
}}
|
11 |
+
}}, false);
|
12 |
+
</script>
|
13 |
+
<iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
|
14 |
+
'''
|
15 |
+
|
16 |
+
iface = gr.Interface(
|
17 |
+
fn=update_iframe,
|
18 |
+
inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
|
19 |
+
outputs=gr.outputs.HTML(),
|
20 |
+
allow_flagging=False,
|
21 |
+
)
|
22 |
+
|
23 |
+
iface.launch(server_name='0.0.0.0', server_port=12212)
|
assets/demo_version_1.MP4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2b61b54bc6eb0d0f7416f95aa3cd6a48d850ca7473022ec1aff48310911b0233
|
3 |
+
size 27053146
|
assets/inpainting.gif
ADDED
Git LFS Details
|
assets/poster_demo_version_1.png
ADDED
assets/qingming.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:58b34bbce0bd0a18ab5fc5450d4046e1cfc6bd55c508046695545819d8fc46dc
|
3 |
+
size 4483842
|
demo.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor, SahiAutoSegmentation, sahi_sliced_predict
|
2 |
+
|
3 |
+
# For image
|
4 |
+
|
5 |
+
def automask_image_app(image_path, model_type, points_per_side, points_per_batch, min_area):
|
6 |
+
SegAutoMaskPredictor().image_predict(
|
7 |
+
source=image_path,
|
8 |
+
model_type=model_type, # vit_l, vit_h, vit_b
|
9 |
+
points_per_side=points_per_side,
|
10 |
+
points_per_batch=points_per_batch,
|
11 |
+
min_area=min_area,
|
12 |
+
output_path="output.png",
|
13 |
+
show=False,
|
14 |
+
save=True,
|
15 |
+
)
|
16 |
+
return "output.png"
|
17 |
+
|
18 |
+
|
19 |
+
# For video
|
20 |
+
|
21 |
+
def automask_video_app(video_path, model_type, points_per_side, points_per_batch, min_area):
|
22 |
+
SegAutoMaskPredictor().video_predict(
|
23 |
+
source=video_path,
|
24 |
+
model_type=model_type, # vit_l, vit_h, vit_b
|
25 |
+
points_per_side=points_per_side,
|
26 |
+
points_per_batch=points_per_batch,
|
27 |
+
min_area=min_area,
|
28 |
+
output_path="output.mp4",
|
29 |
+
)
|
30 |
+
return "output.mp4"
|
31 |
+
|
32 |
+
|
33 |
+
# For manuel box and point selection
|
34 |
+
|
35 |
+
def manual_app(image_path, model_type, input_point, input_label, input_box, multimask_output, random_color):
|
36 |
+
SegManualMaskPredictor().image_predict(
|
37 |
+
source=image_path,
|
38 |
+
model_type=model_type, # vit_l, vit_h, vit_b
|
39 |
+
input_point=input_point,
|
40 |
+
input_label=input_label,
|
41 |
+
input_box=input_box,
|
42 |
+
multimask_output=multimask_output,
|
43 |
+
random_color=random_color,
|
44 |
+
output_path="output.png",
|
45 |
+
show=False,
|
46 |
+
save=True,
|
47 |
+
)
|
48 |
+
return "output.png"
|
49 |
+
|
50 |
+
|
51 |
+
# For sahi sliced prediction
|
52 |
+
|
53 |
+
def sahi_autoseg_app(
|
54 |
+
image_path,
|
55 |
+
sam_model_type,
|
56 |
+
detection_model_type,
|
57 |
+
detection_model_path,
|
58 |
+
conf_th,
|
59 |
+
image_size,
|
60 |
+
slice_height,
|
61 |
+
slice_width,
|
62 |
+
overlap_height_ratio,
|
63 |
+
overlap_width_ratio,
|
64 |
+
):
|
65 |
+
boxes = sahi_sliced_predict(
|
66 |
+
image_path=image_path,
|
67 |
+
detection_model_type=detection_model_type, # yolov8, detectron2, mmdetection, torchvision
|
68 |
+
detection_model_path=detection_model_path,
|
69 |
+
conf_th=conf_th,
|
70 |
+
image_size=image_size,
|
71 |
+
slice_height=slice_height,
|
72 |
+
slice_width=slice_width,
|
73 |
+
overlap_height_ratio=overlap_height_ratio,
|
74 |
+
overlap_width_ratio=overlap_width_ratio,
|
75 |
+
)
|
76 |
+
|
77 |
+
SahiAutoSegmentation().predict(
|
78 |
+
source=image_path,
|
79 |
+
model_type=sam_model_type,
|
80 |
+
input_box=boxes,
|
81 |
+
multimask_output=False,
|
82 |
+
random_color=False,
|
83 |
+
show=False,
|
84 |
+
save=True,
|
85 |
+
)
|
86 |
+
|
87 |
+
return "output.png"
|
images/groceries.jpg
ADDED
images/mask_painter.png
ADDED
images/painter_input_image.jpg
ADDED
images/painter_input_mask.jpg
ADDED
images/painter_output_image.png
ADDED
images/painter_output_image__.png
ADDED
images/point_painter.png
ADDED
images/point_painter_1.png
ADDED
images/point_painter_2.png
ADDED
images/truck.jpg
ADDED
images/truck_both.jpg
ADDED
images/truck_mask.jpg
ADDED
images/truck_point.jpg
ADDED
inpainter/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
inpainter/base_inpainter.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import yaml
|
7 |
+
import cv2
|
8 |
+
import importlib
|
9 |
+
import numpy as np
|
10 |
+
from util.tensor_util import resize_frames, resize_masks
|
11 |
+
|
12 |
+
|
13 |
+
class BaseInpainter:
|
14 |
+
def __init__(self, E2FGVI_checkpoint, device) -> None:
|
15 |
+
"""
|
16 |
+
E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support)
|
17 |
+
"""
|
18 |
+
net = importlib.import_module('model.e2fgvi_hq')
|
19 |
+
self.model = net.InpaintGenerator().to(device)
|
20 |
+
self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
|
21 |
+
self.model.eval()
|
22 |
+
self.device = device
|
23 |
+
# load configurations
|
24 |
+
with open("inpainter/config/config.yaml", 'r') as stream:
|
25 |
+
config = yaml.safe_load(stream)
|
26 |
+
self.neighbor_stride = config['neighbor_stride']
|
27 |
+
self.num_ref = config['num_ref']
|
28 |
+
self.step = config['step']
|
29 |
+
|
30 |
+
# sample reference frames from the whole video
|
31 |
+
def get_ref_index(self, f, neighbor_ids, length):
|
32 |
+
ref_index = []
|
33 |
+
if self.num_ref == -1:
|
34 |
+
for i in range(0, length, self.step):
|
35 |
+
if i not in neighbor_ids:
|
36 |
+
ref_index.append(i)
|
37 |
+
else:
|
38 |
+
start_idx = max(0, f - self.step * (self.num_ref // 2))
|
39 |
+
end_idx = min(length, f + self.step * (self.num_ref // 2))
|
40 |
+
for i in range(start_idx, end_idx + 1, self.step):
|
41 |
+
if i not in neighbor_ids:
|
42 |
+
if len(ref_index) > self.num_ref:
|
43 |
+
break
|
44 |
+
ref_index.append(i)
|
45 |
+
return ref_index
|
46 |
+
|
47 |
+
def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
|
48 |
+
"""
|
49 |
+
frames: numpy array, T, H, W, 3
|
50 |
+
masks: numpy array, T, H, W
|
51 |
+
dilate_radius: radius when applying dilation on masks
|
52 |
+
ratio: down-sample ratio
|
53 |
+
|
54 |
+
Output:
|
55 |
+
inpainted_frames: numpy array, T, H, W, 3
|
56 |
+
"""
|
57 |
+
assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
|
58 |
+
assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
|
59 |
+
masks = masks.copy()
|
60 |
+
masks = np.clip(masks, 0, 1)
|
61 |
+
kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
|
62 |
+
masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
|
63 |
+
|
64 |
+
T, H, W = masks.shape
|
65 |
+
# size: (w, h)
|
66 |
+
if ratio == 1:
|
67 |
+
size = None
|
68 |
+
else:
|
69 |
+
size = (int(W*ratio), int(H*ratio))
|
70 |
+
|
71 |
+
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
|
72 |
+
binary_masks = resize_masks(masks, size)
|
73 |
+
frames = resize_frames(frames, size) # T, H, W, 3
|
74 |
+
# frames and binary_masks are numpy arrays
|
75 |
+
|
76 |
+
h, w = frames.shape[1:3]
|
77 |
+
video_length = T
|
78 |
+
|
79 |
+
# convert to tensor
|
80 |
+
imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
|
81 |
+
masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
|
82 |
+
|
83 |
+
imgs, masks = imgs.to(self.device), masks.to(self.device)
|
84 |
+
comp_frames = [None] * video_length
|
85 |
+
|
86 |
+
for f in range(0, video_length, self.neighbor_stride):
|
87 |
+
neighbor_ids = [
|
88 |
+
i for i in range(max(0, f - self.neighbor_stride),
|
89 |
+
min(video_length, f + self.neighbor_stride + 1))
|
90 |
+
]
|
91 |
+
ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
|
92 |
+
selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
|
93 |
+
selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
|
94 |
+
with torch.no_grad():
|
95 |
+
masked_imgs = selected_imgs * (1 - selected_masks)
|
96 |
+
mod_size_h = 60
|
97 |
+
mod_size_w = 108
|
98 |
+
h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
|
99 |
+
w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
|
100 |
+
masked_imgs = torch.cat(
|
101 |
+
[masked_imgs, torch.flip(masked_imgs, [3])],
|
102 |
+
3)[:, :, :, :h + h_pad, :]
|
103 |
+
masked_imgs = torch.cat(
|
104 |
+
[masked_imgs, torch.flip(masked_imgs, [4])],
|
105 |
+
4)[:, :, :, :, :w + w_pad]
|
106 |
+
pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
|
107 |
+
pred_imgs = pred_imgs[:, :, :h, :w]
|
108 |
+
pred_imgs = (pred_imgs + 1) / 2
|
109 |
+
pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
|
110 |
+
for i in range(len(neighbor_ids)):
|
111 |
+
idx = neighbor_ids[i]
|
112 |
+
img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
|
113 |
+
1 - binary_masks[idx])
|
114 |
+
if comp_frames[idx] is None:
|
115 |
+
comp_frames[idx] = img
|
116 |
+
else:
|
117 |
+
comp_frames[idx] = comp_frames[idx].astype(
|
118 |
+
np.float32) * 0.5 + img.astype(np.float32) * 0.5
|
119 |
+
|
120 |
+
inpainted_frames = np.stack(comp_frames, 0)
|
121 |
+
return inpainted_frames.astype(np.uint8)
|
122 |
+
|
123 |
+
if __name__ == '__main__':
|
124 |
+
|
125 |
+
frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
|
126 |
+
frame_path.sort()
|
127 |
+
mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', "*.png"))
|
128 |
+
mask_path.sort()
|
129 |
+
save_path = '/ssd1/gaomingqi/results/inpainting/parkour'
|
130 |
+
|
131 |
+
if not os.path.exists(save_path):
|
132 |
+
os.mkdir(save_path)
|
133 |
+
|
134 |
+
frames = []
|
135 |
+
masks = []
|
136 |
+
for fid, mid in zip(frame_path, mask_path):
|
137 |
+
frames.append(Image.open(fid).convert('RGB'))
|
138 |
+
masks.append(Image.open(mid).convert('P'))
|
139 |
+
|
140 |
+
frames = np.stack(frames, 0)
|
141 |
+
masks = np.stack(masks, 0)
|
142 |
+
|
143 |
+
# ----------------------------------------------
|
144 |
+
# how to use
|
145 |
+
# ----------------------------------------------
|
146 |
+
# 1/3: set checkpoint and device
|
147 |
+
checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
|
148 |
+
device = 'cuda:6'
|
149 |
+
# 2/3: initialise inpainter
|
150 |
+
base_inpainter = BaseInpainter(checkpoint, device)
|
151 |
+
# 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
|
152 |
+
# ratio: (0, 1], ratio for down sample, default value is 1
|
153 |
+
inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=1) # numpy array, T, H, W, 3
|
154 |
+
# ----------------------------------------------
|
155 |
+
# end
|
156 |
+
# ----------------------------------------------
|
157 |
+
# save
|
158 |
+
for ti, inpainted_frame in enumerate(inpainted_frames):
|
159 |
+
frame = Image.fromarray(inpainted_frame).convert('RGB')
|
160 |
+
frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
|
inpainter/config/config.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# config info for E2FGVI
|
2 |
+
neighbor_stride: 5
|
3 |
+
num_ref: -1
|
4 |
+
step: 10
|
inpainter/model/e2fgvi.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
''' Towards An End-to-End Framework for Video Inpainting
|
2 |
+
'''
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from model.modules.flow_comp import SPyNet
|
9 |
+
from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
|
10 |
+
from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp
|
11 |
+
from model.modules.spectral_norm import spectral_norm as _spectral_norm
|
12 |
+
|
13 |
+
|
14 |
+
class BaseNetwork(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super(BaseNetwork, self).__init__()
|
17 |
+
|
18 |
+
def print_network(self):
|
19 |
+
if isinstance(self, list):
|
20 |
+
self = self[0]
|
21 |
+
num_params = 0
|
22 |
+
for param in self.parameters():
|
23 |
+
num_params += param.numel()
|
24 |
+
print(
|
25 |
+
'Network [%s] was created. Total number of parameters: %.1f million. '
|
26 |
+
'To see the architecture, do print(network).' %
|
27 |
+
(type(self).__name__, num_params / 1000000))
|
28 |
+
|
29 |
+
def init_weights(self, init_type='normal', gain=0.02):
|
30 |
+
'''
|
31 |
+
initialize network's weights
|
32 |
+
init_type: normal | xavier | kaiming | orthogonal
|
33 |
+
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
|
34 |
+
'''
|
35 |
+
def init_func(m):
|
36 |
+
classname = m.__class__.__name__
|
37 |
+
if classname.find('InstanceNorm2d') != -1:
|
38 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
39 |
+
nn.init.constant_(m.weight.data, 1.0)
|
40 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
41 |
+
nn.init.constant_(m.bias.data, 0.0)
|
42 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1
|
43 |
+
or classname.find('Linear') != -1):
|
44 |
+
if init_type == 'normal':
|
45 |
+
nn.init.normal_(m.weight.data, 0.0, gain)
|
46 |
+
elif init_type == 'xavier':
|
47 |
+
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
48 |
+
elif init_type == 'xavier_uniform':
|
49 |
+
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
|
50 |
+
elif init_type == 'kaiming':
|
51 |
+
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
52 |
+
elif init_type == 'orthogonal':
|
53 |
+
nn.init.orthogonal_(m.weight.data, gain=gain)
|
54 |
+
elif init_type == 'none': # uses pytorch's default init method
|
55 |
+
m.reset_parameters()
|
56 |
+
else:
|
57 |
+
raise NotImplementedError(
|
58 |
+
'initialization method [%s] is not implemented' %
|
59 |
+
init_type)
|
60 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
61 |
+
nn.init.constant_(m.bias.data, 0.0)
|
62 |
+
|
63 |
+
self.apply(init_func)
|
64 |
+
|
65 |
+
# propagate to children
|
66 |
+
for m in self.children():
|
67 |
+
if hasattr(m, 'init_weights'):
|
68 |
+
m.init_weights(init_type, gain)
|
69 |
+
|
70 |
+
|
71 |
+
class Encoder(nn.Module):
|
72 |
+
def __init__(self):
|
73 |
+
super(Encoder, self).__init__()
|
74 |
+
self.group = [1, 2, 4, 8, 1]
|
75 |
+
self.layers = nn.ModuleList([
|
76 |
+
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
|
77 |
+
nn.LeakyReLU(0.2, inplace=True),
|
78 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
79 |
+
nn.LeakyReLU(0.2, inplace=True),
|
80 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
81 |
+
nn.LeakyReLU(0.2, inplace=True),
|
82 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
|
83 |
+
nn.LeakyReLU(0.2, inplace=True),
|
84 |
+
nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
|
85 |
+
nn.LeakyReLU(0.2, inplace=True),
|
86 |
+
nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
|
87 |
+
nn.LeakyReLU(0.2, inplace=True),
|
88 |
+
nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
|
89 |
+
nn.LeakyReLU(0.2, inplace=True),
|
90 |
+
nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
|
91 |
+
nn.LeakyReLU(0.2, inplace=True),
|
92 |
+
nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
|
93 |
+
nn.LeakyReLU(0.2, inplace=True)
|
94 |
+
])
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
bt, c, h, w = x.size()
|
98 |
+
h, w = h // 4, w // 4
|
99 |
+
out = x
|
100 |
+
for i, layer in enumerate(self.layers):
|
101 |
+
if i == 8:
|
102 |
+
x0 = out
|
103 |
+
if i > 8 and i % 2 == 0:
|
104 |
+
g = self.group[(i - 8) // 2]
|
105 |
+
x = x0.view(bt, g, -1, h, w)
|
106 |
+
o = out.view(bt, g, -1, h, w)
|
107 |
+
out = torch.cat([x, o], 2).view(bt, -1, h, w)
|
108 |
+
out = layer(out)
|
109 |
+
return out
|
110 |
+
|
111 |
+
|
112 |
+
class deconv(nn.Module):
|
113 |
+
def __init__(self,
|
114 |
+
input_channel,
|
115 |
+
output_channel,
|
116 |
+
kernel_size=3,
|
117 |
+
padding=0):
|
118 |
+
super().__init__()
|
119 |
+
self.conv = nn.Conv2d(input_channel,
|
120 |
+
output_channel,
|
121 |
+
kernel_size=kernel_size,
|
122 |
+
stride=1,
|
123 |
+
padding=padding)
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
x = F.interpolate(x,
|
127 |
+
scale_factor=2,
|
128 |
+
mode='bilinear',
|
129 |
+
align_corners=True)
|
130 |
+
return self.conv(x)
|
131 |
+
|
132 |
+
|
133 |
+
class InpaintGenerator(BaseNetwork):
|
134 |
+
def __init__(self, init_weights=True):
|
135 |
+
super(InpaintGenerator, self).__init__()
|
136 |
+
channel = 256
|
137 |
+
hidden = 512
|
138 |
+
|
139 |
+
# encoder
|
140 |
+
self.encoder = Encoder()
|
141 |
+
|
142 |
+
# decoder
|
143 |
+
self.decoder = nn.Sequential(
|
144 |
+
deconv(channel // 2, 128, kernel_size=3, padding=1),
|
145 |
+
nn.LeakyReLU(0.2, inplace=True),
|
146 |
+
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
|
147 |
+
nn.LeakyReLU(0.2, inplace=True),
|
148 |
+
deconv(64, 64, kernel_size=3, padding=1),
|
149 |
+
nn.LeakyReLU(0.2, inplace=True),
|
150 |
+
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
|
151 |
+
|
152 |
+
# feature propagation module
|
153 |
+
self.feat_prop_module = BidirectionalPropagation(channel // 2)
|
154 |
+
|
155 |
+
# soft split and soft composition
|
156 |
+
kernel_size = (7, 7)
|
157 |
+
padding = (3, 3)
|
158 |
+
stride = (3, 3)
|
159 |
+
output_size = (60, 108)
|
160 |
+
t2t_params = {
|
161 |
+
'kernel_size': kernel_size,
|
162 |
+
'stride': stride,
|
163 |
+
'padding': padding,
|
164 |
+
'output_size': output_size
|
165 |
+
}
|
166 |
+
self.ss = SoftSplit(channel // 2,
|
167 |
+
hidden,
|
168 |
+
kernel_size,
|
169 |
+
stride,
|
170 |
+
padding,
|
171 |
+
t2t_param=t2t_params)
|
172 |
+
self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
|
173 |
+
stride, padding)
|
174 |
+
|
175 |
+
n_vecs = 1
|
176 |
+
for i, d in enumerate(kernel_size):
|
177 |
+
n_vecs *= int((output_size[i] + 2 * padding[i] -
|
178 |
+
(d - 1) - 1) / stride[i] + 1)
|
179 |
+
|
180 |
+
blocks = []
|
181 |
+
depths = 8
|
182 |
+
num_heads = [4] * depths
|
183 |
+
window_size = [(5, 9)] * depths
|
184 |
+
focal_windows = [(5, 9)] * depths
|
185 |
+
focal_levels = [2] * depths
|
186 |
+
pool_method = "fc"
|
187 |
+
|
188 |
+
for i in range(depths):
|
189 |
+
blocks.append(
|
190 |
+
TemporalFocalTransformerBlock(dim=hidden,
|
191 |
+
num_heads=num_heads[i],
|
192 |
+
window_size=window_size[i],
|
193 |
+
focal_level=focal_levels[i],
|
194 |
+
focal_window=focal_windows[i],
|
195 |
+
n_vecs=n_vecs,
|
196 |
+
t2t_params=t2t_params,
|
197 |
+
pool_method=pool_method))
|
198 |
+
self.transformer = nn.Sequential(*blocks)
|
199 |
+
|
200 |
+
if init_weights:
|
201 |
+
self.init_weights()
|
202 |
+
# Need to initial the weights of MSDeformAttn specifically
|
203 |
+
for m in self.modules():
|
204 |
+
if isinstance(m, SecondOrderDeformableAlignment):
|
205 |
+
m.init_offset()
|
206 |
+
|
207 |
+
# flow completion network
|
208 |
+
self.update_spynet = SPyNet()
|
209 |
+
|
210 |
+
def forward_bidirect_flow(self, masked_local_frames):
|
211 |
+
b, l_t, c, h, w = masked_local_frames.size()
|
212 |
+
|
213 |
+
# compute forward and backward flows of masked frames
|
214 |
+
masked_local_frames = F.interpolate(masked_local_frames.view(
|
215 |
+
-1, c, h, w),
|
216 |
+
scale_factor=1 / 4,
|
217 |
+
mode='bilinear',
|
218 |
+
align_corners=True,
|
219 |
+
recompute_scale_factor=True)
|
220 |
+
masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
|
221 |
+
w // 4)
|
222 |
+
mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
|
223 |
+
-1, c, h // 4, w // 4)
|
224 |
+
mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
|
225 |
+
-1, c, h // 4, w // 4)
|
226 |
+
pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
|
227 |
+
pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
|
228 |
+
|
229 |
+
pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
|
230 |
+
w // 4)
|
231 |
+
pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
|
232 |
+
w // 4)
|
233 |
+
|
234 |
+
return pred_flows_forward, pred_flows_backward
|
235 |
+
|
236 |
+
def forward(self, masked_frames, num_local_frames):
|
237 |
+
l_t = num_local_frames
|
238 |
+
b, t, ori_c, ori_h, ori_w = masked_frames.size()
|
239 |
+
|
240 |
+
# normalization before feeding into the flow completion module
|
241 |
+
masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
|
242 |
+
pred_flows = self.forward_bidirect_flow(masked_local_frames)
|
243 |
+
|
244 |
+
# extracting features and performing the feature propagation on local features
|
245 |
+
enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
|
246 |
+
_, c, h, w = enc_feat.size()
|
247 |
+
local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
|
248 |
+
ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
|
249 |
+
local_feat = self.feat_prop_module(local_feat, pred_flows[0],
|
250 |
+
pred_flows[1])
|
251 |
+
enc_feat = torch.cat((local_feat, ref_feat), dim=1)
|
252 |
+
|
253 |
+
# content hallucination through stacking multiple temporal focal transformer blocks
|
254 |
+
trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
|
255 |
+
trans_feat = self.transformer(trans_feat)
|
256 |
+
trans_feat = self.sc(trans_feat, t)
|
257 |
+
trans_feat = trans_feat.view(b, t, -1, h, w)
|
258 |
+
enc_feat = enc_feat + trans_feat
|
259 |
+
|
260 |
+
# decode frames from features
|
261 |
+
output = self.decoder(enc_feat.view(b * t, c, h, w))
|
262 |
+
output = torch.tanh(output)
|
263 |
+
return output, pred_flows
|
264 |
+
|
265 |
+
|
266 |
+
# ######################################################################
|
267 |
+
# Discriminator for Temporal Patch GAN
|
268 |
+
# ######################################################################
|
269 |
+
|
270 |
+
|
271 |
+
class Discriminator(BaseNetwork):
|
272 |
+
def __init__(self,
|
273 |
+
in_channels=3,
|
274 |
+
use_sigmoid=False,
|
275 |
+
use_spectral_norm=True,
|
276 |
+
init_weights=True):
|
277 |
+
super(Discriminator, self).__init__()
|
278 |
+
self.use_sigmoid = use_sigmoid
|
279 |
+
nf = 32
|
280 |
+
|
281 |
+
self.conv = nn.Sequential(
|
282 |
+
spectral_norm(
|
283 |
+
nn.Conv3d(in_channels=in_channels,
|
284 |
+
out_channels=nf * 1,
|
285 |
+
kernel_size=(3, 5, 5),
|
286 |
+
stride=(1, 2, 2),
|
287 |
+
padding=1,
|
288 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
289 |
+
# nn.InstanceNorm2d(64, track_running_stats=False),
|
290 |
+
nn.LeakyReLU(0.2, inplace=True),
|
291 |
+
spectral_norm(
|
292 |
+
nn.Conv3d(nf * 1,
|
293 |
+
nf * 2,
|
294 |
+
kernel_size=(3, 5, 5),
|
295 |
+
stride=(1, 2, 2),
|
296 |
+
padding=(1, 2, 2),
|
297 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
298 |
+
# nn.InstanceNorm2d(128, track_running_stats=False),
|
299 |
+
nn.LeakyReLU(0.2, inplace=True),
|
300 |
+
spectral_norm(
|
301 |
+
nn.Conv3d(nf * 2,
|
302 |
+
nf * 4,
|
303 |
+
kernel_size=(3, 5, 5),
|
304 |
+
stride=(1, 2, 2),
|
305 |
+
padding=(1, 2, 2),
|
306 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
307 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
308 |
+
nn.LeakyReLU(0.2, inplace=True),
|
309 |
+
spectral_norm(
|
310 |
+
nn.Conv3d(nf * 4,
|
311 |
+
nf * 4,
|
312 |
+
kernel_size=(3, 5, 5),
|
313 |
+
stride=(1, 2, 2),
|
314 |
+
padding=(1, 2, 2),
|
315 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
316 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
317 |
+
nn.LeakyReLU(0.2, inplace=True),
|
318 |
+
spectral_norm(
|
319 |
+
nn.Conv3d(nf * 4,
|
320 |
+
nf * 4,
|
321 |
+
kernel_size=(3, 5, 5),
|
322 |
+
stride=(1, 2, 2),
|
323 |
+
padding=(1, 2, 2),
|
324 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
325 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
326 |
+
nn.LeakyReLU(0.2, inplace=True),
|
327 |
+
nn.Conv3d(nf * 4,
|
328 |
+
nf * 4,
|
329 |
+
kernel_size=(3, 5, 5),
|
330 |
+
stride=(1, 2, 2),
|
331 |
+
padding=(1, 2, 2)))
|
332 |
+
|
333 |
+
if init_weights:
|
334 |
+
self.init_weights()
|
335 |
+
|
336 |
+
def forward(self, xs):
|
337 |
+
# T, C, H, W = xs.shape (old)
|
338 |
+
# B, T, C, H, W (new)
|
339 |
+
xs_t = torch.transpose(xs, 1, 2)
|
340 |
+
feat = self.conv(xs_t)
|
341 |
+
if self.use_sigmoid:
|
342 |
+
feat = torch.sigmoid(feat)
|
343 |
+
out = torch.transpose(feat, 1, 2) # B, T, C, H, W
|
344 |
+
return out
|
345 |
+
|
346 |
+
|
347 |
+
def spectral_norm(module, mode=True):
|
348 |
+
if mode:
|
349 |
+
return _spectral_norm(module)
|
350 |
+
return module
|
inpainter/model/e2fgvi_hq.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
''' Towards An End-to-End Framework for Video Inpainting
|
2 |
+
'''
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from model.modules.flow_comp import SPyNet
|
9 |
+
from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
|
10 |
+
from model.modules.tfocal_transformer_hq import TemporalFocalTransformerBlock, SoftSplit, SoftComp
|
11 |
+
from model.modules.spectral_norm import spectral_norm as _spectral_norm
|
12 |
+
|
13 |
+
|
14 |
+
class BaseNetwork(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super(BaseNetwork, self).__init__()
|
17 |
+
|
18 |
+
def print_network(self):
|
19 |
+
if isinstance(self, list):
|
20 |
+
self = self[0]
|
21 |
+
num_params = 0
|
22 |
+
for param in self.parameters():
|
23 |
+
num_params += param.numel()
|
24 |
+
print(
|
25 |
+
'Network [%s] was created. Total number of parameters: %.1f million. '
|
26 |
+
'To see the architecture, do print(network).' %
|
27 |
+
(type(self).__name__, num_params / 1000000))
|
28 |
+
|
29 |
+
def init_weights(self, init_type='normal', gain=0.02):
|
30 |
+
'''
|
31 |
+
initialize network's weights
|
32 |
+
init_type: normal | xavier | kaiming | orthogonal
|
33 |
+
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
|
34 |
+
'''
|
35 |
+
def init_func(m):
|
36 |
+
classname = m.__class__.__name__
|
37 |
+
if classname.find('InstanceNorm2d') != -1:
|
38 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
39 |
+
nn.init.constant_(m.weight.data, 1.0)
|
40 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
41 |
+
nn.init.constant_(m.bias.data, 0.0)
|
42 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1
|
43 |
+
or classname.find('Linear') != -1):
|
44 |
+
if init_type == 'normal':
|
45 |
+
nn.init.normal_(m.weight.data, 0.0, gain)
|
46 |
+
elif init_type == 'xavier':
|
47 |
+
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
48 |
+
elif init_type == 'xavier_uniform':
|
49 |
+
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
|
50 |
+
elif init_type == 'kaiming':
|
51 |
+
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
52 |
+
elif init_type == 'orthogonal':
|
53 |
+
nn.init.orthogonal_(m.weight.data, gain=gain)
|
54 |
+
elif init_type == 'none': # uses pytorch's default init method
|
55 |
+
m.reset_parameters()
|
56 |
+
else:
|
57 |
+
raise NotImplementedError(
|
58 |
+
'initialization method [%s] is not implemented' %
|
59 |
+
init_type)
|
60 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
61 |
+
nn.init.constant_(m.bias.data, 0.0)
|
62 |
+
|
63 |
+
self.apply(init_func)
|
64 |
+
|
65 |
+
# propagate to children
|
66 |
+
for m in self.children():
|
67 |
+
if hasattr(m, 'init_weights'):
|
68 |
+
m.init_weights(init_type, gain)
|
69 |
+
|
70 |
+
|
71 |
+
class Encoder(nn.Module):
|
72 |
+
def __init__(self):
|
73 |
+
super(Encoder, self).__init__()
|
74 |
+
self.group = [1, 2, 4, 8, 1]
|
75 |
+
self.layers = nn.ModuleList([
|
76 |
+
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
|
77 |
+
nn.LeakyReLU(0.2, inplace=True),
|
78 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
79 |
+
nn.LeakyReLU(0.2, inplace=True),
|
80 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
81 |
+
nn.LeakyReLU(0.2, inplace=True),
|
82 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
|
83 |
+
nn.LeakyReLU(0.2, inplace=True),
|
84 |
+
nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
|
85 |
+
nn.LeakyReLU(0.2, inplace=True),
|
86 |
+
nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
|
87 |
+
nn.LeakyReLU(0.2, inplace=True),
|
88 |
+
nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
|
89 |
+
nn.LeakyReLU(0.2, inplace=True),
|
90 |
+
nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
|
91 |
+
nn.LeakyReLU(0.2, inplace=True),
|
92 |
+
nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
|
93 |
+
nn.LeakyReLU(0.2, inplace=True)
|
94 |
+
])
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
bt, c, _, _ = x.size()
|
98 |
+
# h, w = h//4, w//4
|
99 |
+
out = x
|
100 |
+
for i, layer in enumerate(self.layers):
|
101 |
+
if i == 8:
|
102 |
+
x0 = out
|
103 |
+
_, _, h, w = x0.size()
|
104 |
+
if i > 8 and i % 2 == 0:
|
105 |
+
g = self.group[(i - 8) // 2]
|
106 |
+
x = x0.view(bt, g, -1, h, w)
|
107 |
+
o = out.view(bt, g, -1, h, w)
|
108 |
+
out = torch.cat([x, o], 2).view(bt, -1, h, w)
|
109 |
+
out = layer(out)
|
110 |
+
return out
|
111 |
+
|
112 |
+
|
113 |
+
class deconv(nn.Module):
|
114 |
+
def __init__(self,
|
115 |
+
input_channel,
|
116 |
+
output_channel,
|
117 |
+
kernel_size=3,
|
118 |
+
padding=0):
|
119 |
+
super().__init__()
|
120 |
+
self.conv = nn.Conv2d(input_channel,
|
121 |
+
output_channel,
|
122 |
+
kernel_size=kernel_size,
|
123 |
+
stride=1,
|
124 |
+
padding=padding)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
x = F.interpolate(x,
|
128 |
+
scale_factor=2,
|
129 |
+
mode='bilinear',
|
130 |
+
align_corners=True)
|
131 |
+
return self.conv(x)
|
132 |
+
|
133 |
+
|
134 |
+
class InpaintGenerator(BaseNetwork):
|
135 |
+
def __init__(self, init_weights=True):
|
136 |
+
super(InpaintGenerator, self).__init__()
|
137 |
+
channel = 256
|
138 |
+
hidden = 512
|
139 |
+
|
140 |
+
# encoder
|
141 |
+
self.encoder = Encoder()
|
142 |
+
|
143 |
+
# decoder
|
144 |
+
self.decoder = nn.Sequential(
|
145 |
+
deconv(channel // 2, 128, kernel_size=3, padding=1),
|
146 |
+
nn.LeakyReLU(0.2, inplace=True),
|
147 |
+
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
|
148 |
+
nn.LeakyReLU(0.2, inplace=True),
|
149 |
+
deconv(64, 64, kernel_size=3, padding=1),
|
150 |
+
nn.LeakyReLU(0.2, inplace=True),
|
151 |
+
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
|
152 |
+
|
153 |
+
# feature propagation module
|
154 |
+
self.feat_prop_module = BidirectionalPropagation(channel // 2)
|
155 |
+
|
156 |
+
# soft split and soft composition
|
157 |
+
kernel_size = (7, 7)
|
158 |
+
padding = (3, 3)
|
159 |
+
stride = (3, 3)
|
160 |
+
output_size = (60, 108)
|
161 |
+
t2t_params = {
|
162 |
+
'kernel_size': kernel_size,
|
163 |
+
'stride': stride,
|
164 |
+
'padding': padding
|
165 |
+
}
|
166 |
+
self.ss = SoftSplit(channel // 2,
|
167 |
+
hidden,
|
168 |
+
kernel_size,
|
169 |
+
stride,
|
170 |
+
padding,
|
171 |
+
t2t_param=t2t_params)
|
172 |
+
self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)
|
173 |
+
|
174 |
+
n_vecs = 1
|
175 |
+
for i, d in enumerate(kernel_size):
|
176 |
+
n_vecs *= int((output_size[i] + 2 * padding[i] -
|
177 |
+
(d - 1) - 1) / stride[i] + 1)
|
178 |
+
|
179 |
+
blocks = []
|
180 |
+
depths = 8
|
181 |
+
num_heads = [4] * depths
|
182 |
+
window_size = [(5, 9)] * depths
|
183 |
+
focal_windows = [(5, 9)] * depths
|
184 |
+
focal_levels = [2] * depths
|
185 |
+
pool_method = "fc"
|
186 |
+
|
187 |
+
for i in range(depths):
|
188 |
+
blocks.append(
|
189 |
+
TemporalFocalTransformerBlock(dim=hidden,
|
190 |
+
num_heads=num_heads[i],
|
191 |
+
window_size=window_size[i],
|
192 |
+
focal_level=focal_levels[i],
|
193 |
+
focal_window=focal_windows[i],
|
194 |
+
n_vecs=n_vecs,
|
195 |
+
t2t_params=t2t_params,
|
196 |
+
pool_method=pool_method))
|
197 |
+
self.transformer = nn.Sequential(*blocks)
|
198 |
+
|
199 |
+
if init_weights:
|
200 |
+
self.init_weights()
|
201 |
+
# Need to initial the weights of MSDeformAttn specifically
|
202 |
+
for m in self.modules():
|
203 |
+
if isinstance(m, SecondOrderDeformableAlignment):
|
204 |
+
m.init_offset()
|
205 |
+
|
206 |
+
# flow completion network
|
207 |
+
self.update_spynet = SPyNet()
|
208 |
+
|
209 |
+
def forward_bidirect_flow(self, masked_local_frames):
|
210 |
+
b, l_t, c, h, w = masked_local_frames.size()
|
211 |
+
|
212 |
+
# compute forward and backward flows of masked frames
|
213 |
+
masked_local_frames = F.interpolate(masked_local_frames.view(
|
214 |
+
-1, c, h, w),
|
215 |
+
scale_factor=1 / 4,
|
216 |
+
mode='bilinear',
|
217 |
+
align_corners=True,
|
218 |
+
recompute_scale_factor=True)
|
219 |
+
masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
|
220 |
+
w // 4)
|
221 |
+
mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
|
222 |
+
-1, c, h // 4, w // 4)
|
223 |
+
mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
|
224 |
+
-1, c, h // 4, w // 4)
|
225 |
+
pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
|
226 |
+
pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
|
227 |
+
|
228 |
+
pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
|
229 |
+
w // 4)
|
230 |
+
pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
|
231 |
+
w // 4)
|
232 |
+
|
233 |
+
return pred_flows_forward, pred_flows_backward
|
234 |
+
|
235 |
+
def forward(self, masked_frames, num_local_frames):
|
236 |
+
l_t = num_local_frames
|
237 |
+
b, t, ori_c, ori_h, ori_w = masked_frames.size()
|
238 |
+
|
239 |
+
# normalization before feeding into the flow completion module
|
240 |
+
masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
|
241 |
+
pred_flows = self.forward_bidirect_flow(masked_local_frames)
|
242 |
+
|
243 |
+
# extracting features and performing the feature propagation on local features
|
244 |
+
enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
|
245 |
+
_, c, h, w = enc_feat.size()
|
246 |
+
fold_output_size = (h, w)
|
247 |
+
local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
|
248 |
+
ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
|
249 |
+
local_feat = self.feat_prop_module(local_feat, pred_flows[0],
|
250 |
+
pred_flows[1])
|
251 |
+
enc_feat = torch.cat((local_feat, ref_feat), dim=1)
|
252 |
+
|
253 |
+
# content hallucination through stacking multiple temporal focal transformer blocks
|
254 |
+
trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)
|
255 |
+
trans_feat = self.transformer([trans_feat, fold_output_size])
|
256 |
+
trans_feat = self.sc(trans_feat[0], t, fold_output_size)
|
257 |
+
trans_feat = trans_feat.view(b, t, -1, h, w)
|
258 |
+
enc_feat = enc_feat + trans_feat
|
259 |
+
|
260 |
+
# decode frames from features
|
261 |
+
output = self.decoder(enc_feat.view(b * t, c, h, w))
|
262 |
+
output = torch.tanh(output)
|
263 |
+
return output, pred_flows
|
264 |
+
|
265 |
+
|
266 |
+
# ######################################################################
|
267 |
+
# Discriminator for Temporal Patch GAN
|
268 |
+
# ######################################################################
|
269 |
+
|
270 |
+
|
271 |
+
class Discriminator(BaseNetwork):
|
272 |
+
def __init__(self,
|
273 |
+
in_channels=3,
|
274 |
+
use_sigmoid=False,
|
275 |
+
use_spectral_norm=True,
|
276 |
+
init_weights=True):
|
277 |
+
super(Discriminator, self).__init__()
|
278 |
+
self.use_sigmoid = use_sigmoid
|
279 |
+
nf = 32
|
280 |
+
|
281 |
+
self.conv = nn.Sequential(
|
282 |
+
spectral_norm(
|
283 |
+
nn.Conv3d(in_channels=in_channels,
|
284 |
+
out_channels=nf * 1,
|
285 |
+
kernel_size=(3, 5, 5),
|
286 |
+
stride=(1, 2, 2),
|
287 |
+
padding=1,
|
288 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
289 |
+
# nn.InstanceNorm2d(64, track_running_stats=False),
|
290 |
+
nn.LeakyReLU(0.2, inplace=True),
|
291 |
+
spectral_norm(
|
292 |
+
nn.Conv3d(nf * 1,
|
293 |
+
nf * 2,
|
294 |
+
kernel_size=(3, 5, 5),
|
295 |
+
stride=(1, 2, 2),
|
296 |
+
padding=(1, 2, 2),
|
297 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
298 |
+
# nn.InstanceNorm2d(128, track_running_stats=False),
|
299 |
+
nn.LeakyReLU(0.2, inplace=True),
|
300 |
+
spectral_norm(
|
301 |
+
nn.Conv3d(nf * 2,
|
302 |
+
nf * 4,
|
303 |
+
kernel_size=(3, 5, 5),
|
304 |
+
stride=(1, 2, 2),
|
305 |
+
padding=(1, 2, 2),
|
306 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
307 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
308 |
+
nn.LeakyReLU(0.2, inplace=True),
|
309 |
+
spectral_norm(
|
310 |
+
nn.Conv3d(nf * 4,
|
311 |
+
nf * 4,
|
312 |
+
kernel_size=(3, 5, 5),
|
313 |
+
stride=(1, 2, 2),
|
314 |
+
padding=(1, 2, 2),
|
315 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
316 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
317 |
+
nn.LeakyReLU(0.2, inplace=True),
|
318 |
+
spectral_norm(
|
319 |
+
nn.Conv3d(nf * 4,
|
320 |
+
nf * 4,
|
321 |
+
kernel_size=(3, 5, 5),
|
322 |
+
stride=(1, 2, 2),
|
323 |
+
padding=(1, 2, 2),
|
324 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
325 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
326 |
+
nn.LeakyReLU(0.2, inplace=True),
|
327 |
+
nn.Conv3d(nf * 4,
|
328 |
+
nf * 4,
|
329 |
+
kernel_size=(3, 5, 5),
|
330 |
+
stride=(1, 2, 2),
|
331 |
+
padding=(1, 2, 2)))
|
332 |
+
|
333 |
+
if init_weights:
|
334 |
+
self.init_weights()
|
335 |
+
|
336 |
+
def forward(self, xs):
|
337 |
+
# T, C, H, W = xs.shape (old)
|
338 |
+
# B, T, C, H, W (new)
|
339 |
+
xs_t = torch.transpose(xs, 1, 2)
|
340 |
+
feat = self.conv(xs_t)
|
341 |
+
if self.use_sigmoid:
|
342 |
+
feat = torch.sigmoid(feat)
|
343 |
+
out = torch.transpose(feat, 1, 2) # B, T, C, H, W
|
344 |
+
return out
|
345 |
+
|
346 |
+
|
347 |
+
def spectral_norm(module, mode=True):
|
348 |
+
if mode:
|
349 |
+
return _spectral_norm(module)
|
350 |
+
return module
|
inpainter/model/modules/feat_prop.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
|
8 |
+
from mmengine.model import constant_init
|
9 |
+
|
10 |
+
from model.modules.flow_comp import flow_warp
|
11 |
+
|
12 |
+
|
13 |
+
class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
|
14 |
+
"""Second-order deformable alignment module."""
|
15 |
+
def __init__(self, *args, **kwargs):
|
16 |
+
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
|
17 |
+
|
18 |
+
super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
|
19 |
+
|
20 |
+
self.conv_offset = nn.Sequential(
|
21 |
+
nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
|
22 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
23 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
24 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
25 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
26 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
27 |
+
nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
|
28 |
+
)
|
29 |
+
|
30 |
+
self.init_offset()
|
31 |
+
|
32 |
+
def init_offset(self):
|
33 |
+
constant_init(self.conv_offset[-1], val=0, bias=0)
|
34 |
+
|
35 |
+
def forward(self, x, extra_feat, flow_1, flow_2):
|
36 |
+
extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
|
37 |
+
out = self.conv_offset(extra_feat)
|
38 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
39 |
+
|
40 |
+
# offset
|
41 |
+
offset = self.max_residue_magnitude * torch.tanh(
|
42 |
+
torch.cat((o1, o2), dim=1))
|
43 |
+
offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
|
44 |
+
offset_1 = offset_1 + flow_1.flip(1).repeat(1,
|
45 |
+
offset_1.size(1) // 2, 1,
|
46 |
+
1)
|
47 |
+
offset_2 = offset_2 + flow_2.flip(1).repeat(1,
|
48 |
+
offset_2.size(1) // 2, 1,
|
49 |
+
1)
|
50 |
+
offset = torch.cat([offset_1, offset_2], dim=1)
|
51 |
+
|
52 |
+
# mask
|
53 |
+
mask = torch.sigmoid(mask)
|
54 |
+
|
55 |
+
return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
|
56 |
+
self.stride, self.padding,
|
57 |
+
self.dilation, self.groups,
|
58 |
+
self.deform_groups)
|
59 |
+
|
60 |
+
|
61 |
+
class BidirectionalPropagation(nn.Module):
|
62 |
+
def __init__(self, channel):
|
63 |
+
super(BidirectionalPropagation, self).__init__()
|
64 |
+
modules = ['backward_', 'forward_']
|
65 |
+
self.deform_align = nn.ModuleDict()
|
66 |
+
self.backbone = nn.ModuleDict()
|
67 |
+
self.channel = channel
|
68 |
+
|
69 |
+
for i, module in enumerate(modules):
|
70 |
+
self.deform_align[module] = SecondOrderDeformableAlignment(
|
71 |
+
2 * channel, channel, 3, padding=1, deform_groups=16)
|
72 |
+
|
73 |
+
self.backbone[module] = nn.Sequential(
|
74 |
+
nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
|
75 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
76 |
+
nn.Conv2d(channel, channel, 3, 1, 1),
|
77 |
+
)
|
78 |
+
|
79 |
+
self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
|
80 |
+
|
81 |
+
def forward(self, x, flows_backward, flows_forward):
|
82 |
+
"""
|
83 |
+
x shape : [b, t, c, h, w]
|
84 |
+
return [b, t, c, h, w]
|
85 |
+
"""
|
86 |
+
b, t, c, h, w = x.shape
|
87 |
+
feats = {}
|
88 |
+
feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
|
89 |
+
|
90 |
+
for module_name in ['backward_', 'forward_']:
|
91 |
+
|
92 |
+
feats[module_name] = []
|
93 |
+
|
94 |
+
frame_idx = range(0, t)
|
95 |
+
flow_idx = range(-1, t - 1)
|
96 |
+
mapping_idx = list(range(0, len(feats['spatial'])))
|
97 |
+
mapping_idx += mapping_idx[::-1]
|
98 |
+
|
99 |
+
if 'backward' in module_name:
|
100 |
+
frame_idx = frame_idx[::-1]
|
101 |
+
flows = flows_backward
|
102 |
+
else:
|
103 |
+
flows = flows_forward
|
104 |
+
|
105 |
+
feat_prop = x.new_zeros(b, self.channel, h, w)
|
106 |
+
for i, idx in enumerate(frame_idx):
|
107 |
+
feat_current = feats['spatial'][mapping_idx[idx]]
|
108 |
+
|
109 |
+
if i > 0:
|
110 |
+
flow_n1 = flows[:, flow_idx[i], :, :, :]
|
111 |
+
cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
|
112 |
+
|
113 |
+
# initialize second-order features
|
114 |
+
feat_n2 = torch.zeros_like(feat_prop)
|
115 |
+
flow_n2 = torch.zeros_like(flow_n1)
|
116 |
+
cond_n2 = torch.zeros_like(cond_n1)
|
117 |
+
if i > 1:
|
118 |
+
feat_n2 = feats[module_name][-2]
|
119 |
+
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
|
120 |
+
flow_n2 = flow_n1 + flow_warp(
|
121 |
+
flow_n2, flow_n1.permute(0, 2, 3, 1))
|
122 |
+
cond_n2 = flow_warp(feat_n2,
|
123 |
+
flow_n2.permute(0, 2, 3, 1))
|
124 |
+
|
125 |
+
cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
|
126 |
+
feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
|
127 |
+
feat_prop = self.deform_align[module_name](feat_prop, cond,
|
128 |
+
flow_n1,
|
129 |
+
flow_n2)
|
130 |
+
|
131 |
+
feat = [feat_current] + [
|
132 |
+
feats[k][idx]
|
133 |
+
for k in feats if k not in ['spatial', module_name]
|
134 |
+
] + [feat_prop]
|
135 |
+
|
136 |
+
feat = torch.cat(feat, dim=1)
|
137 |
+
feat_prop = feat_prop + self.backbone[module_name](feat)
|
138 |
+
feats[module_name].append(feat_prop)
|
139 |
+
|
140 |
+
if 'backward' in module_name:
|
141 |
+
feats[module_name] = feats[module_name][::-1]
|
142 |
+
|
143 |
+
outputs = []
|
144 |
+
for i in range(0, t):
|
145 |
+
align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
|
146 |
+
align_feats = torch.cat(align_feats, dim=1)
|
147 |
+
outputs.append(self.fusion(align_feats))
|
148 |
+
|
149 |
+
return torch.stack(outputs, dim=1) + x
|
inpainter/model/modules/flow_comp.py
ADDED
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from mmcv.cnn import ConvModule
|
8 |
+
from mmengine.runner import load_checkpoint
|
9 |
+
|
10 |
+
|
11 |
+
class FlowCompletionLoss(nn.Module):
|
12 |
+
"""Flow completion loss"""
|
13 |
+
def __init__(self):
|
14 |
+
super().__init__()
|
15 |
+
self.fix_spynet = SPyNet()
|
16 |
+
for p in self.fix_spynet.parameters():
|
17 |
+
p.requires_grad = False
|
18 |
+
|
19 |
+
self.l1_criterion = nn.L1Loss()
|
20 |
+
|
21 |
+
def forward(self, pred_flows, gt_local_frames):
|
22 |
+
b, l_t, c, h, w = gt_local_frames.size()
|
23 |
+
|
24 |
+
with torch.no_grad():
|
25 |
+
# compute gt forward and backward flows
|
26 |
+
gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
|
27 |
+
scale_factor=1 / 4,
|
28 |
+
mode='bilinear',
|
29 |
+
align_corners=True,
|
30 |
+
recompute_scale_factor=True)
|
31 |
+
gt_local_frames = gt_local_frames.view(b, l_t, c, h // 4, w // 4)
|
32 |
+
gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
|
33 |
+
-1, c, h // 4, w // 4)
|
34 |
+
gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
|
35 |
+
-1, c, h // 4, w // 4)
|
36 |
+
gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
|
37 |
+
gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)
|
38 |
+
|
39 |
+
# calculate loss for flow completion
|
40 |
+
forward_flow_loss = self.l1_criterion(
|
41 |
+
pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
|
42 |
+
backward_flow_loss = self.l1_criterion(
|
43 |
+
pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
|
44 |
+
flow_loss = forward_flow_loss + backward_flow_loss
|
45 |
+
|
46 |
+
return flow_loss
|
47 |
+
|
48 |
+
|
49 |
+
class SPyNet(nn.Module):
|
50 |
+
"""SPyNet network structure.
|
51 |
+
The difference to the SPyNet in [tof.py] is that
|
52 |
+
1. more SPyNetBasicModule is used in this version, and
|
53 |
+
2. no batch normalization is used in this version.
|
54 |
+
Paper:
|
55 |
+
Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
|
56 |
+
Args:
|
57 |
+
pretrained (str): path for pre-trained SPyNet. Default: None.
|
58 |
+
"""
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
use_pretrain=True,
|
62 |
+
pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'
|
63 |
+
):
|
64 |
+
super().__init__()
|
65 |
+
|
66 |
+
self.basic_module = nn.ModuleList(
|
67 |
+
[SPyNetBasicModule() for _ in range(6)])
|
68 |
+
|
69 |
+
if use_pretrain:
|
70 |
+
if isinstance(pretrained, str):
|
71 |
+
print("load pretrained SPyNet...")
|
72 |
+
load_checkpoint(self, pretrained, strict=True)
|
73 |
+
elif pretrained is not None:
|
74 |
+
raise TypeError('[pretrained] should be str or None, '
|
75 |
+
f'but got {type(pretrained)}.')
|
76 |
+
|
77 |
+
self.register_buffer(
|
78 |
+
'mean',
|
79 |
+
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
80 |
+
self.register_buffer(
|
81 |
+
'std',
|
82 |
+
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
83 |
+
|
84 |
+
def compute_flow(self, ref, supp):
|
85 |
+
"""Compute flow from ref to supp.
|
86 |
+
Note that in this function, the images are already resized to a
|
87 |
+
multiple of 32.
|
88 |
+
Args:
|
89 |
+
ref (Tensor): Reference image with shape of (n, 3, h, w).
|
90 |
+
supp (Tensor): Supporting image with shape of (n, 3, h, w).
|
91 |
+
Returns:
|
92 |
+
Tensor: Estimated optical flow: (n, 2, h, w).
|
93 |
+
"""
|
94 |
+
n, _, h, w = ref.size()
|
95 |
+
|
96 |
+
# normalize the input images
|
97 |
+
ref = [(ref - self.mean) / self.std]
|
98 |
+
supp = [(supp - self.mean) / self.std]
|
99 |
+
|
100 |
+
# generate downsampled frames
|
101 |
+
for level in range(5):
|
102 |
+
ref.append(
|
103 |
+
F.avg_pool2d(input=ref[-1],
|
104 |
+
kernel_size=2,
|
105 |
+
stride=2,
|
106 |
+
count_include_pad=False))
|
107 |
+
supp.append(
|
108 |
+
F.avg_pool2d(input=supp[-1],
|
109 |
+
kernel_size=2,
|
110 |
+
stride=2,
|
111 |
+
count_include_pad=False))
|
112 |
+
ref = ref[::-1]
|
113 |
+
supp = supp[::-1]
|
114 |
+
|
115 |
+
# flow computation
|
116 |
+
flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
|
117 |
+
for level in range(len(ref)):
|
118 |
+
if level == 0:
|
119 |
+
flow_up = flow
|
120 |
+
else:
|
121 |
+
flow_up = F.interpolate(input=flow,
|
122 |
+
scale_factor=2,
|
123 |
+
mode='bilinear',
|
124 |
+
align_corners=True) * 2.0
|
125 |
+
|
126 |
+
# add the residue to the upsampled flow
|
127 |
+
flow = flow_up + self.basic_module[level](torch.cat([
|
128 |
+
ref[level],
|
129 |
+
flow_warp(supp[level],
|
130 |
+
flow_up.permute(0, 2, 3, 1).contiguous(),
|
131 |
+
padding_mode='border'), flow_up
|
132 |
+
], 1))
|
133 |
+
|
134 |
+
return flow
|
135 |
+
|
136 |
+
def forward(self, ref, supp):
|
137 |
+
"""Forward function of SPyNet.
|
138 |
+
This function computes the optical flow from ref to supp.
|
139 |
+
Args:
|
140 |
+
ref (Tensor): Reference image with shape of (n, 3, h, w).
|
141 |
+
supp (Tensor): Supporting image with shape of (n, 3, h, w).
|
142 |
+
Returns:
|
143 |
+
Tensor: Estimated optical flow: (n, 2, h, w).
|
144 |
+
"""
|
145 |
+
|
146 |
+
# upsize to a multiple of 32
|
147 |
+
h, w = ref.shape[2:4]
|
148 |
+
w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
|
149 |
+
h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
|
150 |
+
ref = F.interpolate(input=ref,
|
151 |
+
size=(h_up, w_up),
|
152 |
+
mode='bilinear',
|
153 |
+
align_corners=False)
|
154 |
+
supp = F.interpolate(input=supp,
|
155 |
+
size=(h_up, w_up),
|
156 |
+
mode='bilinear',
|
157 |
+
align_corners=False)
|
158 |
+
|
159 |
+
# compute flow, and resize back to the original resolution
|
160 |
+
flow = F.interpolate(input=self.compute_flow(ref, supp),
|
161 |
+
size=(h, w),
|
162 |
+
mode='bilinear',
|
163 |
+
align_corners=False)
|
164 |
+
|
165 |
+
# adjust the flow values
|
166 |
+
flow[:, 0, :, :] *= float(w) / float(w_up)
|
167 |
+
flow[:, 1, :, :] *= float(h) / float(h_up)
|
168 |
+
|
169 |
+
return flow
|
170 |
+
|
171 |
+
|
172 |
+
class SPyNetBasicModule(nn.Module):
|
173 |
+
"""Basic Module for SPyNet.
|
174 |
+
Paper:
|
175 |
+
Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
|
176 |
+
"""
|
177 |
+
def __init__(self):
|
178 |
+
super().__init__()
|
179 |
+
|
180 |
+
self.basic_module = nn.Sequential(
|
181 |
+
ConvModule(in_channels=8,
|
182 |
+
out_channels=32,
|
183 |
+
kernel_size=7,
|
184 |
+
stride=1,
|
185 |
+
padding=3,
|
186 |
+
norm_cfg=None,
|
187 |
+
act_cfg=dict(type='ReLU')),
|
188 |
+
ConvModule(in_channels=32,
|
189 |
+
out_channels=64,
|
190 |
+
kernel_size=7,
|
191 |
+
stride=1,
|
192 |
+
padding=3,
|
193 |
+
norm_cfg=None,
|
194 |
+
act_cfg=dict(type='ReLU')),
|
195 |
+
ConvModule(in_channels=64,
|
196 |
+
out_channels=32,
|
197 |
+
kernel_size=7,
|
198 |
+
stride=1,
|
199 |
+
padding=3,
|
200 |
+
norm_cfg=None,
|
201 |
+
act_cfg=dict(type='ReLU')),
|
202 |
+
ConvModule(in_channels=32,
|
203 |
+
out_channels=16,
|
204 |
+
kernel_size=7,
|
205 |
+
stride=1,
|
206 |
+
padding=3,
|
207 |
+
norm_cfg=None,
|
208 |
+
act_cfg=dict(type='ReLU')),
|
209 |
+
ConvModule(in_channels=16,
|
210 |
+
out_channels=2,
|
211 |
+
kernel_size=7,
|
212 |
+
stride=1,
|
213 |
+
padding=3,
|
214 |
+
norm_cfg=None,
|
215 |
+
act_cfg=None))
|
216 |
+
|
217 |
+
def forward(self, tensor_input):
|
218 |
+
"""
|
219 |
+
Args:
|
220 |
+
tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
|
221 |
+
8 channels contain:
|
222 |
+
[reference image (3), neighbor image (3), initial flow (2)].
|
223 |
+
Returns:
|
224 |
+
Tensor: Refined flow with shape (b, 2, h, w)
|
225 |
+
"""
|
226 |
+
return self.basic_module(tensor_input)
|
227 |
+
|
228 |
+
|
229 |
+
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
230 |
+
def make_colorwheel():
|
231 |
+
"""
|
232 |
+
Generates a color wheel for optical flow visualization as presented in:
|
233 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
234 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
235 |
+
|
236 |
+
Code follows the original C++ source code of Daniel Scharstein.
|
237 |
+
Code follows the the Matlab source code of Deqing Sun.
|
238 |
+
|
239 |
+
Returns:
|
240 |
+
np.ndarray: Color wheel
|
241 |
+
"""
|
242 |
+
|
243 |
+
RY = 15
|
244 |
+
YG = 6
|
245 |
+
GC = 4
|
246 |
+
CB = 11
|
247 |
+
BM = 13
|
248 |
+
MR = 6
|
249 |
+
|
250 |
+
ncols = RY + YG + GC + CB + BM + MR
|
251 |
+
colorwheel = np.zeros((ncols, 3))
|
252 |
+
col = 0
|
253 |
+
|
254 |
+
# RY
|
255 |
+
colorwheel[0:RY, 0] = 255
|
256 |
+
colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
|
257 |
+
col = col + RY
|
258 |
+
# YG
|
259 |
+
colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
|
260 |
+
colorwheel[col:col + YG, 1] = 255
|
261 |
+
col = col + YG
|
262 |
+
# GC
|
263 |
+
colorwheel[col:col + GC, 1] = 255
|
264 |
+
colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
|
265 |
+
col = col + GC
|
266 |
+
# CB
|
267 |
+
colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
|
268 |
+
colorwheel[col:col + CB, 2] = 255
|
269 |
+
col = col + CB
|
270 |
+
# BM
|
271 |
+
colorwheel[col:col + BM, 2] = 255
|
272 |
+
colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
|
273 |
+
col = col + BM
|
274 |
+
# MR
|
275 |
+
colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
|
276 |
+
colorwheel[col:col + MR, 0] = 255
|
277 |
+
return colorwheel
|
278 |
+
|
279 |
+
|
280 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
281 |
+
"""
|
282 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
283 |
+
|
284 |
+
According to the C++ source code of Daniel Scharstein
|
285 |
+
According to the Matlab source code of Deqing Sun
|
286 |
+
|
287 |
+
Args:
|
288 |
+
u (np.ndarray): Input horizontal flow of shape [H,W]
|
289 |
+
v (np.ndarray): Input vertical flow of shape [H,W]
|
290 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
294 |
+
"""
|
295 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
296 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
297 |
+
ncols = colorwheel.shape[0]
|
298 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
299 |
+
a = np.arctan2(-v, -u) / np.pi
|
300 |
+
fk = (a + 1) / 2 * (ncols - 1)
|
301 |
+
k0 = np.floor(fk).astype(np.int32)
|
302 |
+
k1 = k0 + 1
|
303 |
+
k1[k1 == ncols] = 0
|
304 |
+
f = fk - k0
|
305 |
+
for i in range(colorwheel.shape[1]):
|
306 |
+
tmp = colorwheel[:, i]
|
307 |
+
col0 = tmp[k0] / 255.0
|
308 |
+
col1 = tmp[k1] / 255.0
|
309 |
+
col = (1 - f) * col0 + f * col1
|
310 |
+
idx = (rad <= 1)
|
311 |
+
col[idx] = 1 - rad[idx] * (1 - col[idx])
|
312 |
+
col[~idx] = col[~idx] * 0.75 # out of range
|
313 |
+
# Note the 2-i => BGR instead of RGB
|
314 |
+
ch_idx = 2 - i if convert_to_bgr else i
|
315 |
+
flow_image[:, :, ch_idx] = np.floor(255 * col)
|
316 |
+
return flow_image
|
317 |
+
|
318 |
+
|
319 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
320 |
+
"""
|
321 |
+
Expects a two dimensional flow image of shape.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
325 |
+
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
326 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
327 |
+
|
328 |
+
Returns:
|
329 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
330 |
+
"""
|
331 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
332 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
333 |
+
if clip_flow is not None:
|
334 |
+
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
335 |
+
u = flow_uv[:, :, 0]
|
336 |
+
v = flow_uv[:, :, 1]
|
337 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
338 |
+
rad_max = np.max(rad)
|
339 |
+
epsilon = 1e-5
|
340 |
+
u = u / (rad_max + epsilon)
|
341 |
+
v = v / (rad_max + epsilon)
|
342 |
+
return flow_uv_to_colors(u, v, convert_to_bgr)
|
343 |
+
|
344 |
+
|
345 |
+
def flow_warp(x,
|
346 |
+
flow,
|
347 |
+
interpolation='bilinear',
|
348 |
+
padding_mode='zeros',
|
349 |
+
align_corners=True):
|
350 |
+
"""Warp an image or a feature map with optical flow.
|
351 |
+
Args:
|
352 |
+
x (Tensor): Tensor with size (n, c, h, w).
|
353 |
+
flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
|
354 |
+
a two-channel, denoting the width and height relative offsets.
|
355 |
+
Note that the values are not normalized to [-1, 1].
|
356 |
+
interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
|
357 |
+
Default: 'bilinear'.
|
358 |
+
padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
|
359 |
+
Default: 'zeros'.
|
360 |
+
align_corners (bool): Whether align corners. Default: True.
|
361 |
+
Returns:
|
362 |
+
Tensor: Warped image or feature map.
|
363 |
+
"""
|
364 |
+
if x.size()[-2:] != flow.size()[1:3]:
|
365 |
+
raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
|
366 |
+
f'flow ({flow.size()[1:3]}) are not the same.')
|
367 |
+
_, _, h, w = x.size()
|
368 |
+
# create mesh grid
|
369 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
|
370 |
+
grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
|
371 |
+
grid.requires_grad = False
|
372 |
+
|
373 |
+
grid_flow = grid + flow
|
374 |
+
# scale grid_flow to [-1,1]
|
375 |
+
grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
|
376 |
+
grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
|
377 |
+
grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
|
378 |
+
output = F.grid_sample(x,
|
379 |
+
grid_flow,
|
380 |
+
mode=interpolation,
|
381 |
+
padding_mode=padding_mode,
|
382 |
+
align_corners=align_corners)
|
383 |
+
return output
|
384 |
+
|
385 |
+
|
386 |
+
def initial_mask_flow(mask):
|
387 |
+
"""
|
388 |
+
mask 1 indicates valid pixel 0 indicates unknown pixel
|
389 |
+
"""
|
390 |
+
B, T, C, H, W = mask.shape
|
391 |
+
|
392 |
+
# calculate relative position
|
393 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
|
394 |
+
|
395 |
+
grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
|
396 |
+
abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
|
397 |
+
relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])
|
398 |
+
|
399 |
+
abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
|
400 |
+
relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])
|
401 |
+
|
402 |
+
# calculate the nearest indices
|
403 |
+
pos_up = mask.unsqueeze(3).repeat(
|
404 |
+
1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
|
405 |
+
relative_pos_y <= H)[None, None, None]
|
406 |
+
nearest_indice_up = pos_up.max(dim=4)[1]
|
407 |
+
|
408 |
+
pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
|
409 |
+
None, None, None] * (relative_pos_y <= H)[None, None, None]
|
410 |
+
nearest_indice_down = (pos_down).max(dim=4)[1]
|
411 |
+
|
412 |
+
pos_left = mask.unsqueeze(4).repeat(
|
413 |
+
1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
|
414 |
+
relative_pos_x <= W)[None, None, None]
|
415 |
+
nearest_indice_left = (pos_left).max(dim=5)[1]
|
416 |
+
|
417 |
+
pos_right = mask.unsqueeze(4).repeat(
|
418 |
+
1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
|
419 |
+
relative_pos_x <= W)[None, None, None]
|
420 |
+
nearest_indice_right = (pos_right).max(dim=5)[1]
|
421 |
+
|
422 |
+
# NOTE: IMPORTANT !!! depending on how to use this offset
|
423 |
+
initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
|
424 |
+
initial_offset_down = nearest_indice_down - grid_y[None, None, None]
|
425 |
+
|
426 |
+
initial_offset_left = -(nearest_indice_left -
|
427 |
+
grid_x[None, None, None]).flip(4)
|
428 |
+
initial_offset_right = nearest_indice_right - grid_x[None, None, None]
|
429 |
+
|
430 |
+
# nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
|
431 |
+
# initial_offset_x = nearest_indice_x - grid_x
|
432 |
+
|
433 |
+
# handle the boundary cases
|
434 |
+
final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
|
435 |
+
initial_offset_down > 0) * initial_offset_down
|
436 |
+
final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
|
437 |
+
initial_offset_up < 0) * initial_offset_up
|
438 |
+
final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
|
439 |
+
initial_offset_right > 0) * initial_offset_right
|
440 |
+
final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
|
441 |
+
initial_offset_left < 0) * initial_offset_left
|
442 |
+
zero_offset = torch.zeros_like(final_offset_down)
|
443 |
+
# out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
|
444 |
+
out = torch.cat([
|
445 |
+
zero_offset, final_offset_left, zero_offset, final_offset_right,
|
446 |
+
final_offset_up, zero_offset, final_offset_down, zero_offset
|
447 |
+
],
|
448 |
+
dim=2)
|
449 |
+
|
450 |
+
return out
|
inpainter/model/modules/spectral_norm.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Spectral Normalization from https://arxiv.org/abs/1802.05957
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
from torch.nn.functional import normalize
|
6 |
+
|
7 |
+
|
8 |
+
class SpectralNorm(object):
|
9 |
+
# Invariant before and after each forward call:
|
10 |
+
# u = normalize(W @ v)
|
11 |
+
# NB: At initialization, this invariant is not enforced
|
12 |
+
|
13 |
+
_version = 1
|
14 |
+
|
15 |
+
# At version 1:
|
16 |
+
# made `W` not a buffer,
|
17 |
+
# added `v` as a buffer, and
|
18 |
+
# made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
|
19 |
+
|
20 |
+
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
|
21 |
+
self.name = name
|
22 |
+
self.dim = dim
|
23 |
+
if n_power_iterations <= 0:
|
24 |
+
raise ValueError(
|
25 |
+
'Expected n_power_iterations to be positive, but '
|
26 |
+
'got n_power_iterations={}'.format(n_power_iterations))
|
27 |
+
self.n_power_iterations = n_power_iterations
|
28 |
+
self.eps = eps
|
29 |
+
|
30 |
+
def reshape_weight_to_matrix(self, weight):
|
31 |
+
weight_mat = weight
|
32 |
+
if self.dim != 0:
|
33 |
+
# permute dim to front
|
34 |
+
weight_mat = weight_mat.permute(
|
35 |
+
self.dim,
|
36 |
+
*[d for d in range(weight_mat.dim()) if d != self.dim])
|
37 |
+
height = weight_mat.size(0)
|
38 |
+
return weight_mat.reshape(height, -1)
|
39 |
+
|
40 |
+
def compute_weight(self, module, do_power_iteration):
|
41 |
+
# NB: If `do_power_iteration` is set, the `u` and `v` vectors are
|
42 |
+
# updated in power iteration **in-place**. This is very important
|
43 |
+
# because in `DataParallel` forward, the vectors (being buffers) are
|
44 |
+
# broadcast from the parallelized module to each module replica,
|
45 |
+
# which is a new module object created on the fly. And each replica
|
46 |
+
# runs its own spectral norm power iteration. So simply assigning
|
47 |
+
# the updated vectors to the module this function runs on will cause
|
48 |
+
# the update to be lost forever. And the next time the parallelized
|
49 |
+
# module is replicated, the same randomly initialized vectors are
|
50 |
+
# broadcast and used!
|
51 |
+
#
|
52 |
+
# Therefore, to make the change propagate back, we rely on two
|
53 |
+
# important behaviors (also enforced via tests):
|
54 |
+
# 1. `DataParallel` doesn't clone storage if the broadcast tensor
|
55 |
+
# is already on correct device; and it makes sure that the
|
56 |
+
# parallelized module is already on `device[0]`.
|
57 |
+
# 2. If the out tensor in `out=` kwarg has correct shape, it will
|
58 |
+
# just fill in the values.
|
59 |
+
# Therefore, since the same power iteration is performed on all
|
60 |
+
# devices, simply updating the tensors in-place will make sure that
|
61 |
+
# the module replica on `device[0]` will update the _u vector on the
|
62 |
+
# parallized module (by shared storage).
|
63 |
+
#
|
64 |
+
# However, after we update `u` and `v` in-place, we need to **clone**
|
65 |
+
# them before using them to normalize the weight. This is to support
|
66 |
+
# backproping through two forward passes, e.g., the common pattern in
|
67 |
+
# GAN training: loss = D(real) - D(fake). Otherwise, engine will
|
68 |
+
# complain that variables needed to do backward for the first forward
|
69 |
+
# (i.e., the `u` and `v` vectors) are changed in the second forward.
|
70 |
+
weight = getattr(module, self.name + '_orig')
|
71 |
+
u = getattr(module, self.name + '_u')
|
72 |
+
v = getattr(module, self.name + '_v')
|
73 |
+
weight_mat = self.reshape_weight_to_matrix(weight)
|
74 |
+
|
75 |
+
if do_power_iteration:
|
76 |
+
with torch.no_grad():
|
77 |
+
for _ in range(self.n_power_iterations):
|
78 |
+
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
|
79 |
+
# are the first left and right singular vectors.
|
80 |
+
# This power iteration produces approximations of `u` and `v`.
|
81 |
+
v = normalize(torch.mv(weight_mat.t(), u),
|
82 |
+
dim=0,
|
83 |
+
eps=self.eps,
|
84 |
+
out=v)
|
85 |
+
u = normalize(torch.mv(weight_mat, v),
|
86 |
+
dim=0,
|
87 |
+
eps=self.eps,
|
88 |
+
out=u)
|
89 |
+
if self.n_power_iterations > 0:
|
90 |
+
# See above on why we need to clone
|
91 |
+
u = u.clone()
|
92 |
+
v = v.clone()
|
93 |
+
|
94 |
+
sigma = torch.dot(u, torch.mv(weight_mat, v))
|
95 |
+
weight = weight / sigma
|
96 |
+
return weight
|
97 |
+
|
98 |
+
def remove(self, module):
|
99 |
+
with torch.no_grad():
|
100 |
+
weight = self.compute_weight(module, do_power_iteration=False)
|
101 |
+
delattr(module, self.name)
|
102 |
+
delattr(module, self.name + '_u')
|
103 |
+
delattr(module, self.name + '_v')
|
104 |
+
delattr(module, self.name + '_orig')
|
105 |
+
module.register_parameter(self.name,
|
106 |
+
torch.nn.Parameter(weight.detach()))
|
107 |
+
|
108 |
+
def __call__(self, module, inputs):
|
109 |
+
setattr(
|
110 |
+
module, self.name,
|
111 |
+
self.compute_weight(module, do_power_iteration=module.training))
|
112 |
+
|
113 |
+
def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
|
114 |
+
# Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
|
115 |
+
# (the invariant at top of this class) and `u @ W @ v = sigma`.
|
116 |
+
# This uses pinverse in case W^T W is not invertible.
|
117 |
+
v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
|
118 |
+
weight_mat.t(), u.unsqueeze(1)).squeeze(1)
|
119 |
+
return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
|
120 |
+
|
121 |
+
@staticmethod
|
122 |
+
def apply(module, name, n_power_iterations, dim, eps):
|
123 |
+
for k, hook in module._forward_pre_hooks.items():
|
124 |
+
if isinstance(hook, SpectralNorm) and hook.name == name:
|
125 |
+
raise RuntimeError(
|
126 |
+
"Cannot register two spectral_norm hooks on "
|
127 |
+
"the same parameter {}".format(name))
|
128 |
+
|
129 |
+
fn = SpectralNorm(name, n_power_iterations, dim, eps)
|
130 |
+
weight = module._parameters[name]
|
131 |
+
|
132 |
+
with torch.no_grad():
|
133 |
+
weight_mat = fn.reshape_weight_to_matrix(weight)
|
134 |
+
|
135 |
+
h, w = weight_mat.size()
|
136 |
+
# randomly initialize `u` and `v`
|
137 |
+
u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
|
138 |
+
v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
|
139 |
+
|
140 |
+
delattr(module, fn.name)
|
141 |
+
module.register_parameter(fn.name + "_orig", weight)
|
142 |
+
# We still need to assign weight back as fn.name because all sorts of
|
143 |
+
# things may assume that it exists, e.g., when initializing weights.
|
144 |
+
# However, we can't directly assign as it could be an nn.Parameter and
|
145 |
+
# gets added as a parameter. Instead, we register weight.data as a plain
|
146 |
+
# attribute.
|
147 |
+
setattr(module, fn.name, weight.data)
|
148 |
+
module.register_buffer(fn.name + "_u", u)
|
149 |
+
module.register_buffer(fn.name + "_v", v)
|
150 |
+
|
151 |
+
module.register_forward_pre_hook(fn)
|
152 |
+
|
153 |
+
module._register_state_dict_hook(SpectralNormStateDictHook(fn))
|
154 |
+
module._register_load_state_dict_pre_hook(
|
155 |
+
SpectralNormLoadStateDictPreHook(fn))
|
156 |
+
return fn
|
157 |
+
|
158 |
+
|
159 |
+
# This is a top level class because Py2 pickle doesn't like inner class nor an
|
160 |
+
# instancemethod.
|
161 |
+
class SpectralNormLoadStateDictPreHook(object):
|
162 |
+
# See docstring of SpectralNorm._version on the changes to spectral_norm.
|
163 |
+
def __init__(self, fn):
|
164 |
+
self.fn = fn
|
165 |
+
|
166 |
+
# For state_dict with version None, (assuming that it has gone through at
|
167 |
+
# least one training forward), we have
|
168 |
+
#
|
169 |
+
# u = normalize(W_orig @ v)
|
170 |
+
# W = W_orig / sigma, where sigma = u @ W_orig @ v
|
171 |
+
#
|
172 |
+
# To compute `v`, we solve `W_orig @ x = u`, and let
|
173 |
+
# v = x / (u @ W_orig @ x) * (W / W_orig).
|
174 |
+
def __call__(self, state_dict, prefix, local_metadata, strict,
|
175 |
+
missing_keys, unexpected_keys, error_msgs):
|
176 |
+
fn = self.fn
|
177 |
+
version = local_metadata.get('spectral_norm',
|
178 |
+
{}).get(fn.name + '.version', None)
|
179 |
+
if version is None or version < 1:
|
180 |
+
with torch.no_grad():
|
181 |
+
weight_orig = state_dict[prefix + fn.name + '_orig']
|
182 |
+
# weight = state_dict.pop(prefix + fn.name)
|
183 |
+
# sigma = (weight_orig / weight).mean()
|
184 |
+
weight_mat = fn.reshape_weight_to_matrix(weight_orig)
|
185 |
+
u = state_dict[prefix + fn.name + '_u']
|
186 |
+
# v = fn._solve_v_and_rescale(weight_mat, u, sigma)
|
187 |
+
# state_dict[prefix + fn.name + '_v'] = v
|
188 |
+
|
189 |
+
|
190 |
+
# This is a top level class because Py2 pickle doesn't like inner class nor an
|
191 |
+
# instancemethod.
|
192 |
+
class SpectralNormStateDictHook(object):
|
193 |
+
# See docstring of SpectralNorm._version on the changes to spectral_norm.
|
194 |
+
def __init__(self, fn):
|
195 |
+
self.fn = fn
|
196 |
+
|
197 |
+
def __call__(self, module, state_dict, prefix, local_metadata):
|
198 |
+
if 'spectral_norm' not in local_metadata:
|
199 |
+
local_metadata['spectral_norm'] = {}
|
200 |
+
key = self.fn.name + '.version'
|
201 |
+
if key in local_metadata['spectral_norm']:
|
202 |
+
raise RuntimeError(
|
203 |
+
"Unexpected key in metadata['spectral_norm']: {}".format(key))
|
204 |
+
local_metadata['spectral_norm'][key] = self.fn._version
|
205 |
+
|
206 |
+
|
207 |
+
def spectral_norm(module,
|
208 |
+
name='weight',
|
209 |
+
n_power_iterations=1,
|
210 |
+
eps=1e-12,
|
211 |
+
dim=None):
|
212 |
+
r"""Applies spectral normalization to a parameter in the given module.
|
213 |
+
|
214 |
+
.. math::
|
215 |
+
\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
|
216 |
+
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
|
217 |
+
|
218 |
+
Spectral normalization stabilizes the training of discriminators (critics)
|
219 |
+
in Generative Adversarial Networks (GANs) by rescaling the weight tensor
|
220 |
+
with spectral norm :math:`\sigma` of the weight matrix calculated using
|
221 |
+
power iteration method. If the dimension of the weight tensor is greater
|
222 |
+
than 2, it is reshaped to 2D in power iteration method to get spectral
|
223 |
+
norm. This is implemented via a hook that calculates spectral norm and
|
224 |
+
rescales weight before every :meth:`~Module.forward` call.
|
225 |
+
|
226 |
+
See `Spectral Normalization for Generative Adversarial Networks`_ .
|
227 |
+
|
228 |
+
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
|
229 |
+
|
230 |
+
Args:
|
231 |
+
module (nn.Module): containing module
|
232 |
+
name (str, optional): name of weight parameter
|
233 |
+
n_power_iterations (int, optional): number of power iterations to
|
234 |
+
calculate spectral norm
|
235 |
+
eps (float, optional): epsilon for numerical stability in
|
236 |
+
calculating norms
|
237 |
+
dim (int, optional): dimension corresponding to number of outputs,
|
238 |
+
the default is ``0``, except for modules that are instances of
|
239 |
+
ConvTranspose{1,2,3}d, when it is ``1``
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
The original module with the spectral norm hook
|
243 |
+
|
244 |
+
Example::
|
245 |
+
|
246 |
+
>>> m = spectral_norm(nn.Linear(20, 40))
|
247 |
+
>>> m
|
248 |
+
Linear(in_features=20, out_features=40, bias=True)
|
249 |
+
>>> m.weight_u.size()
|
250 |
+
torch.Size([40])
|
251 |
+
|
252 |
+
"""
|
253 |
+
if dim is None:
|
254 |
+
if isinstance(module,
|
255 |
+
(torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
|
256 |
+
torch.nn.ConvTranspose3d)):
|
257 |
+
dim = 1
|
258 |
+
else:
|
259 |
+
dim = 0
|
260 |
+
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
|
261 |
+
return module
|
262 |
+
|
263 |
+
|
264 |
+
def remove_spectral_norm(module, name='weight'):
|
265 |
+
r"""Removes the spectral normalization reparameterization from a module.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
module (Module): containing module
|
269 |
+
name (str, optional): name of weight parameter
|
270 |
+
|
271 |
+
Example:
|
272 |
+
>>> m = spectral_norm(nn.Linear(40, 10))
|
273 |
+
>>> remove_spectral_norm(m)
|
274 |
+
"""
|
275 |
+
for k, hook in module._forward_pre_hooks.items():
|
276 |
+
if isinstance(hook, SpectralNorm) and hook.name == name:
|
277 |
+
hook.remove(module)
|
278 |
+
del module._forward_pre_hooks[k]
|
279 |
+
return module
|
280 |
+
|
281 |
+
raise ValueError("spectral_norm of '{}' not found in {}".format(
|
282 |
+
name, module))
|
283 |
+
|
284 |
+
|
285 |
+
def use_spectral_norm(module, use_sn=False):
|
286 |
+
if use_sn:
|
287 |
+
return spectral_norm(module)
|
288 |
+
return module
|
inpainter/model/modules/tfocal_transformer.py
ADDED
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This code is based on:
|
3 |
+
[1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
|
4 |
+
https://github.com/ruiliu-ai/FuseFormer
|
5 |
+
[2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
|
6 |
+
https://github.com/yitu-opensource/T2T-ViT
|
7 |
+
[3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
|
8 |
+
https://github.com/microsoft/Focal-Transformer
|
9 |
+
"""
|
10 |
+
|
11 |
+
import math
|
12 |
+
from functools import reduce
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
class SoftSplit(nn.Module):
|
20 |
+
def __init__(self, channel, hidden, kernel_size, stride, padding,
|
21 |
+
t2t_param):
|
22 |
+
super(SoftSplit, self).__init__()
|
23 |
+
self.kernel_size = kernel_size
|
24 |
+
self.t2t = nn.Unfold(kernel_size=kernel_size,
|
25 |
+
stride=stride,
|
26 |
+
padding=padding)
|
27 |
+
c_in = reduce((lambda x, y: x * y), kernel_size) * channel
|
28 |
+
self.embedding = nn.Linear(c_in, hidden)
|
29 |
+
|
30 |
+
self.f_h = int(
|
31 |
+
(t2t_param['output_size'][0] + 2 * t2t_param['padding'][0] -
|
32 |
+
(t2t_param['kernel_size'][0] - 1) - 1) / t2t_param['stride'][0] +
|
33 |
+
1)
|
34 |
+
self.f_w = int(
|
35 |
+
(t2t_param['output_size'][1] + 2 * t2t_param['padding'][1] -
|
36 |
+
(t2t_param['kernel_size'][1] - 1) - 1) / t2t_param['stride'][1] +
|
37 |
+
1)
|
38 |
+
|
39 |
+
def forward(self, x, b):
|
40 |
+
feat = self.t2t(x)
|
41 |
+
feat = feat.permute(0, 2, 1)
|
42 |
+
# feat shape [b*t, num_vec, ks*ks*c]
|
43 |
+
feat = self.embedding(feat)
|
44 |
+
# feat shape after embedding [b, t*num_vec, hidden]
|
45 |
+
feat = feat.view(b, -1, self.f_h, self.f_w, feat.size(2))
|
46 |
+
return feat
|
47 |
+
|
48 |
+
|
49 |
+
class SoftComp(nn.Module):
|
50 |
+
def __init__(self, channel, hidden, output_size, kernel_size, stride,
|
51 |
+
padding):
|
52 |
+
super(SoftComp, self).__init__()
|
53 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
54 |
+
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
|
55 |
+
self.embedding = nn.Linear(hidden, c_out)
|
56 |
+
self.t2t = torch.nn.Fold(output_size=output_size,
|
57 |
+
kernel_size=kernel_size,
|
58 |
+
stride=stride,
|
59 |
+
padding=padding)
|
60 |
+
h, w = output_size
|
61 |
+
self.bias = nn.Parameter(torch.zeros((channel, h, w),
|
62 |
+
dtype=torch.float32),
|
63 |
+
requires_grad=True)
|
64 |
+
|
65 |
+
def forward(self, x, t):
|
66 |
+
b_, _, _, _, c_ = x.shape
|
67 |
+
x = x.view(b_, -1, c_)
|
68 |
+
feat = self.embedding(x)
|
69 |
+
b, _, c = feat.size()
|
70 |
+
feat = feat.view(b * t, -1, c).permute(0, 2, 1)
|
71 |
+
feat = self.t2t(feat) + self.bias[None]
|
72 |
+
return feat
|
73 |
+
|
74 |
+
|
75 |
+
class FusionFeedForward(nn.Module):
|
76 |
+
def __init__(self, d_model, n_vecs=None, t2t_params=None):
|
77 |
+
super(FusionFeedForward, self).__init__()
|
78 |
+
# We set d_ff as a default to 1960
|
79 |
+
hd = 1960
|
80 |
+
self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
|
81 |
+
self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
|
82 |
+
assert t2t_params is not None and n_vecs is not None
|
83 |
+
tp = t2t_params.copy()
|
84 |
+
self.fold = nn.Fold(**tp)
|
85 |
+
del tp['output_size']
|
86 |
+
self.unfold = nn.Unfold(**tp)
|
87 |
+
self.n_vecs = n_vecs
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
x = self.conv1(x)
|
91 |
+
b, n, c = x.size()
|
92 |
+
normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs,
|
93 |
+
49).permute(0, 2, 1)
|
94 |
+
x = self.unfold(
|
95 |
+
self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) /
|
96 |
+
self.fold(normalizer)).permute(0, 2, 1).contiguous().view(b, n, c)
|
97 |
+
x = self.conv2(x)
|
98 |
+
return x
|
99 |
+
|
100 |
+
|
101 |
+
def window_partition(x, window_size):
|
102 |
+
"""
|
103 |
+
Args:
|
104 |
+
x: shape is (B, T, H, W, C)
|
105 |
+
window_size (tuple[int]): window size
|
106 |
+
Returns:
|
107 |
+
windows: (B*num_windows, T*window_size*window_size, C)
|
108 |
+
"""
|
109 |
+
B, T, H, W, C = x.shape
|
110 |
+
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
|
111 |
+
window_size[1], C)
|
112 |
+
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
|
113 |
+
-1, T * window_size[0] * window_size[1], C)
|
114 |
+
return windows
|
115 |
+
|
116 |
+
|
117 |
+
def window_partition_noreshape(x, window_size):
|
118 |
+
"""
|
119 |
+
Args:
|
120 |
+
x: shape is (B, T, H, W, C)
|
121 |
+
window_size (tuple[int]): window size
|
122 |
+
Returns:
|
123 |
+
windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
|
124 |
+
"""
|
125 |
+
B, T, H, W, C = x.shape
|
126 |
+
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
|
127 |
+
window_size[1], C)
|
128 |
+
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
|
129 |
+
return windows
|
130 |
+
|
131 |
+
|
132 |
+
def window_reverse(windows, window_size, T, H, W):
|
133 |
+
"""
|
134 |
+
Args:
|
135 |
+
windows: shape is (num_windows*B, T, window_size, window_size, C)
|
136 |
+
window_size (tuple[int]): Window size
|
137 |
+
T (int): Temporal length of video
|
138 |
+
H (int): Height of image
|
139 |
+
W (int): Width of image
|
140 |
+
Returns:
|
141 |
+
x: (B, T, H, W, C)
|
142 |
+
"""
|
143 |
+
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
|
144 |
+
x = windows.view(B, H // window_size[0], W // window_size[1], T,
|
145 |
+
window_size[0], window_size[1], -1)
|
146 |
+
x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
class WindowAttention(nn.Module):
|
151 |
+
"""Temporal focal window attention
|
152 |
+
"""
|
153 |
+
def __init__(self, dim, expand_size, window_size, focal_window,
|
154 |
+
focal_level, num_heads, qkv_bias, pool_method):
|
155 |
+
|
156 |
+
super().__init__()
|
157 |
+
self.dim = dim
|
158 |
+
self.expand_size = expand_size
|
159 |
+
self.window_size = window_size # Wh, Ww
|
160 |
+
self.pool_method = pool_method
|
161 |
+
self.num_heads = num_heads
|
162 |
+
head_dim = dim // num_heads
|
163 |
+
self.scale = head_dim**-0.5
|
164 |
+
self.focal_level = focal_level
|
165 |
+
self.focal_window = focal_window
|
166 |
+
|
167 |
+
if any(i > 0 for i in self.expand_size) and focal_level > 0:
|
168 |
+
# get mask for rolled k and rolled v
|
169 |
+
mask_tl = torch.ones(self.window_size[0], self.window_size[1])
|
170 |
+
mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
|
171 |
+
mask_tr = torch.ones(self.window_size[0], self.window_size[1])
|
172 |
+
mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
|
173 |
+
mask_bl = torch.ones(self.window_size[0], self.window_size[1])
|
174 |
+
mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
|
175 |
+
mask_br = torch.ones(self.window_size[0], self.window_size[1])
|
176 |
+
mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
|
177 |
+
mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
|
178 |
+
0).flatten(0)
|
179 |
+
self.register_buffer("valid_ind_rolled",
|
180 |
+
mask_rolled.nonzero(as_tuple=False).view(-1))
|
181 |
+
|
182 |
+
if pool_method != "none" and focal_level > 1:
|
183 |
+
self.unfolds = nn.ModuleList()
|
184 |
+
|
185 |
+
# build relative position bias between local patch and pooled windows
|
186 |
+
for k in range(focal_level - 1):
|
187 |
+
stride = 2**k
|
188 |
+
kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
|
189 |
+
for i in self.focal_window)
|
190 |
+
# define unfolding operations
|
191 |
+
self.unfolds += [
|
192 |
+
nn.Unfold(kernel_size=kernel_size,
|
193 |
+
stride=stride,
|
194 |
+
padding=tuple(i // 2 for i in kernel_size))
|
195 |
+
]
|
196 |
+
|
197 |
+
# define unfolding index for focal_level > 0
|
198 |
+
if k > 0:
|
199 |
+
mask = torch.zeros(kernel_size)
|
200 |
+
mask[(2**k) - 1:, (2**k) - 1:] = 1
|
201 |
+
self.register_buffer(
|
202 |
+
"valid_ind_unfold_{}".format(k),
|
203 |
+
mask.flatten(0).nonzero(as_tuple=False).view(-1))
|
204 |
+
|
205 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
206 |
+
self.proj = nn.Linear(dim, dim)
|
207 |
+
|
208 |
+
self.softmax = nn.Softmax(dim=-1)
|
209 |
+
|
210 |
+
def forward(self, x_all, mask_all=None):
|
211 |
+
"""
|
212 |
+
Args:
|
213 |
+
x: input features with shape of (B, T, Wh, Ww, C)
|
214 |
+
mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
|
215 |
+
|
216 |
+
output: (nW*B, Wh*Ww, C)
|
217 |
+
"""
|
218 |
+
x = x_all[0]
|
219 |
+
|
220 |
+
B, T, nH, nW, C = x.shape
|
221 |
+
qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
|
222 |
+
C).permute(4, 0, 1, 2, 3, 5).contiguous()
|
223 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
|
224 |
+
|
225 |
+
# partition q map
|
226 |
+
(q_windows, k_windows, v_windows) = map(
|
227 |
+
lambda t: window_partition(t, self.window_size).view(
|
228 |
+
-1, T, self.window_size[0] * self.window_size[1], self.
|
229 |
+
num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
|
230 |
+
contiguous().view(-1, self.num_heads, T * self.window_size[
|
231 |
+
0] * self.window_size[1], C // self.num_heads), (q, k, v))
|
232 |
+
# q(k/v)_windows shape : [16, 4, 225, 128]
|
233 |
+
|
234 |
+
if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
|
235 |
+
(k_tl, v_tl) = map(
|
236 |
+
lambda t: torch.roll(t,
|
237 |
+
shifts=(-self.expand_size[0], -self.
|
238 |
+
expand_size[1]),
|
239 |
+
dims=(2, 3)), (k, v))
|
240 |
+
(k_tr, v_tr) = map(
|
241 |
+
lambda t: torch.roll(t,
|
242 |
+
shifts=(-self.expand_size[0], self.
|
243 |
+
expand_size[1]),
|
244 |
+
dims=(2, 3)), (k, v))
|
245 |
+
(k_bl, v_bl) = map(
|
246 |
+
lambda t: torch.roll(t,
|
247 |
+
shifts=(self.expand_size[0], -self.
|
248 |
+
expand_size[1]),
|
249 |
+
dims=(2, 3)), (k, v))
|
250 |
+
(k_br, v_br) = map(
|
251 |
+
lambda t: torch.roll(t,
|
252 |
+
shifts=(self.expand_size[0], self.
|
253 |
+
expand_size[1]),
|
254 |
+
dims=(2, 3)), (k, v))
|
255 |
+
|
256 |
+
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
|
257 |
+
lambda t: window_partition(t, self.window_size).view(
|
258 |
+
-1, T, self.window_size[0] * self.window_size[1], self.
|
259 |
+
num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
|
260 |
+
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
|
261 |
+
lambda t: window_partition(t, self.window_size).view(
|
262 |
+
-1, T, self.window_size[0] * self.window_size[1], self.
|
263 |
+
num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
|
264 |
+
k_rolled = torch.cat(
|
265 |
+
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
|
266 |
+
2).permute(0, 3, 1, 2, 4).contiguous()
|
267 |
+
v_rolled = torch.cat(
|
268 |
+
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
|
269 |
+
2).permute(0, 3, 1, 2, 4).contiguous()
|
270 |
+
|
271 |
+
# mask out tokens in current window
|
272 |
+
k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
|
273 |
+
v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
|
274 |
+
temp_N = k_rolled.shape[3]
|
275 |
+
k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
|
276 |
+
C // self.num_heads)
|
277 |
+
v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
|
278 |
+
C // self.num_heads)
|
279 |
+
k_rolled = torch.cat((k_windows, k_rolled), 2)
|
280 |
+
v_rolled = torch.cat((v_windows, v_rolled), 2)
|
281 |
+
else:
|
282 |
+
k_rolled = k_windows
|
283 |
+
v_rolled = v_windows
|
284 |
+
|
285 |
+
# q(k/v)_windows shape : [16, 4, 225, 128]
|
286 |
+
# k_rolled.shape : [16, 4, 5, 165, 128]
|
287 |
+
# ideal expanded window size 153 ((5+2*2)*(9+2*4))
|
288 |
+
# k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
|
289 |
+
|
290 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
291 |
+
k_pooled = []
|
292 |
+
v_pooled = []
|
293 |
+
for k in range(self.focal_level - 1):
|
294 |
+
stride = 2**k
|
295 |
+
x_window_pooled = x_all[k + 1].permute(
|
296 |
+
0, 3, 1, 2, 4).contiguous() # B, T, nWh, nWw, C
|
297 |
+
|
298 |
+
nWh, nWw = x_window_pooled.shape[2:4]
|
299 |
+
|
300 |
+
# generate mask for pooled windows
|
301 |
+
mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
|
302 |
+
# unfold mask: [nWh*nWw//s//s, k*k, 1]
|
303 |
+
unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
|
304 |
+
1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
|
305 |
+
view(nWh*nWw // stride // stride, -1, 1)
|
306 |
+
|
307 |
+
if k > 0:
|
308 |
+
valid_ind_unfold_k = getattr(
|
309 |
+
self, "valid_ind_unfold_{}".format(k))
|
310 |
+
unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
|
311 |
+
|
312 |
+
x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
|
313 |
+
x_window_masks = x_window_masks.masked_fill(
|
314 |
+
x_window_masks == 0,
|
315 |
+
float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
|
316 |
+
mask_all[k + 1] = x_window_masks
|
317 |
+
|
318 |
+
# generate k and v for pooled windows
|
319 |
+
qkv_pooled = self.qkv(x_window_pooled).reshape(
|
320 |
+
B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
|
321 |
+
3).view(3, -1, C, nWh,
|
322 |
+
nWw).contiguous()
|
323 |
+
k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[
|
324 |
+
2] # B*T, C, nWh, nWw
|
325 |
+
# k_pooled_k shape: [5, 512, 4, 4]
|
326 |
+
# self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
|
327 |
+
|
328 |
+
(k_pooled_k, v_pooled_k) = map(
|
329 |
+
lambda t: self.unfolds[k](t).view(
|
330 |
+
B, T, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 5, 1, 3, 4, 2).contiguous().\
|
331 |
+
view(-1, T, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).contiguous(),
|
332 |
+
(k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
|
333 |
+
)
|
334 |
+
# k_pooled_k shape : [16, 4, 5, 45, 128]
|
335 |
+
|
336 |
+
# select valid unfolding index
|
337 |
+
if k > 0:
|
338 |
+
(k_pooled_k, v_pooled_k) = map(
|
339 |
+
lambda t: t[:, :, :, valid_ind_unfold_k],
|
340 |
+
(k_pooled_k, v_pooled_k))
|
341 |
+
|
342 |
+
k_pooled_k = k_pooled_k.view(
|
343 |
+
-1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
|
344 |
+
self.unfolds[k].kernel_size[1], C // self.num_heads)
|
345 |
+
v_pooled_k = v_pooled_k.view(
|
346 |
+
-1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
|
347 |
+
self.unfolds[k].kernel_size[1], C // self.num_heads)
|
348 |
+
|
349 |
+
k_pooled += [k_pooled_k]
|
350 |
+
v_pooled += [v_pooled_k]
|
351 |
+
|
352 |
+
# k_all (v_all) shape : [16, 4, 5 * 210, 128]
|
353 |
+
k_all = torch.cat([k_rolled] + k_pooled, 2)
|
354 |
+
v_all = torch.cat([v_rolled] + v_pooled, 2)
|
355 |
+
else:
|
356 |
+
k_all = k_rolled
|
357 |
+
v_all = v_rolled
|
358 |
+
|
359 |
+
N = k_all.shape[-2]
|
360 |
+
q_windows = q_windows * self.scale
|
361 |
+
attn = (
|
362 |
+
q_windows @ k_all.transpose(-2, -1)
|
363 |
+
) # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
|
364 |
+
# T * 45
|
365 |
+
window_area = T * self.window_size[0] * self.window_size[1]
|
366 |
+
# T * 165
|
367 |
+
window_area_rolled = k_rolled.shape[2]
|
368 |
+
|
369 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
370 |
+
offset = window_area_rolled
|
371 |
+
for k in range(self.focal_level - 1):
|
372 |
+
# add attentional mask
|
373 |
+
# mask_all[1] shape [1, 16, T * 45]
|
374 |
+
|
375 |
+
bias = tuple((i + 2**k - 1) for i in self.focal_window)
|
376 |
+
|
377 |
+
if mask_all[k + 1] is not None:
|
378 |
+
attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
|
379 |
+
attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
|
380 |
+
mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
|
381 |
+
|
382 |
+
offset += T * bias[0] * bias[1]
|
383 |
+
|
384 |
+
if mask_all[0] is not None:
|
385 |
+
nW = mask_all[0].shape[0]
|
386 |
+
attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
|
387 |
+
window_area, N)
|
388 |
+
attn[:, :, :, :, :
|
389 |
+
window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
|
390 |
+
None, :, None, :, :]
|
391 |
+
attn = attn.view(-1, self.num_heads, window_area, N)
|
392 |
+
attn = self.softmax(attn)
|
393 |
+
else:
|
394 |
+
attn = self.softmax(attn)
|
395 |
+
|
396 |
+
x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
|
397 |
+
C)
|
398 |
+
x = self.proj(x)
|
399 |
+
return x
|
400 |
+
|
401 |
+
|
402 |
+
class TemporalFocalTransformerBlock(nn.Module):
|
403 |
+
r""" Temporal Focal Transformer Block.
|
404 |
+
Args:
|
405 |
+
dim (int): Number of input channels.
|
406 |
+
num_heads (int): Number of attention heads.
|
407 |
+
window_size (tuple[int]): Window size.
|
408 |
+
shift_size (int): Shift size for SW-MSA.
|
409 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
410 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
411 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
412 |
+
focal_level (int): The number level of focal window.
|
413 |
+
focal_window (int): Window size of each focal window.
|
414 |
+
n_vecs (int): Required for F3N.
|
415 |
+
t2t_params (int): T2T parameters for F3N.
|
416 |
+
"""
|
417 |
+
def __init__(self,
|
418 |
+
dim,
|
419 |
+
num_heads,
|
420 |
+
window_size=(5, 9),
|
421 |
+
mlp_ratio=4.,
|
422 |
+
qkv_bias=True,
|
423 |
+
pool_method="fc",
|
424 |
+
focal_level=2,
|
425 |
+
focal_window=(5, 9),
|
426 |
+
norm_layer=nn.LayerNorm,
|
427 |
+
n_vecs=None,
|
428 |
+
t2t_params=None):
|
429 |
+
super().__init__()
|
430 |
+
self.dim = dim
|
431 |
+
self.num_heads = num_heads
|
432 |
+
self.window_size = window_size
|
433 |
+
self.expand_size = tuple(i // 2 for i in window_size) # TODO
|
434 |
+
self.mlp_ratio = mlp_ratio
|
435 |
+
self.pool_method = pool_method
|
436 |
+
self.focal_level = focal_level
|
437 |
+
self.focal_window = focal_window
|
438 |
+
|
439 |
+
self.window_size_glo = self.window_size
|
440 |
+
|
441 |
+
self.pool_layers = nn.ModuleList()
|
442 |
+
if self.pool_method != "none":
|
443 |
+
for k in range(self.focal_level - 1):
|
444 |
+
window_size_glo = tuple(
|
445 |
+
math.floor(i / (2**k)) for i in self.window_size_glo)
|
446 |
+
self.pool_layers.append(
|
447 |
+
nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
|
448 |
+
self.pool_layers[-1].weight.data.fill_(
|
449 |
+
1. / (window_size_glo[0] * window_size_glo[1]))
|
450 |
+
self.pool_layers[-1].bias.data.fill_(0)
|
451 |
+
|
452 |
+
self.norm1 = norm_layer(dim)
|
453 |
+
|
454 |
+
self.attn = WindowAttention(dim,
|
455 |
+
expand_size=self.expand_size,
|
456 |
+
window_size=self.window_size,
|
457 |
+
focal_window=focal_window,
|
458 |
+
focal_level=focal_level,
|
459 |
+
num_heads=num_heads,
|
460 |
+
qkv_bias=qkv_bias,
|
461 |
+
pool_method=pool_method)
|
462 |
+
|
463 |
+
self.norm2 = norm_layer(dim)
|
464 |
+
self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
|
465 |
+
|
466 |
+
def forward(self, x):
|
467 |
+
B, T, H, W, C = x.shape
|
468 |
+
|
469 |
+
shortcut = x
|
470 |
+
x = self.norm1(x)
|
471 |
+
|
472 |
+
shifted_x = x
|
473 |
+
|
474 |
+
x_windows_all = [shifted_x]
|
475 |
+
x_window_masks_all = [None]
|
476 |
+
|
477 |
+
# partition windows tuple(i // 2 for i in window_size)
|
478 |
+
if self.focal_level > 1 and self.pool_method != "none":
|
479 |
+
# if we add coarser granularity and the pool method is not none
|
480 |
+
for k in range(self.focal_level - 1):
|
481 |
+
window_size_glo = tuple(
|
482 |
+
math.floor(i / (2**k)) for i in self.window_size_glo)
|
483 |
+
pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
|
484 |
+
pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
|
485 |
+
H_pool = pooled_h * window_size_glo[0]
|
486 |
+
W_pool = pooled_w * window_size_glo[1]
|
487 |
+
|
488 |
+
x_level_k = shifted_x
|
489 |
+
# trim or pad shifted_x depending on the required size
|
490 |
+
if H > H_pool:
|
491 |
+
trim_t = (H - H_pool) // 2
|
492 |
+
trim_b = H - H_pool - trim_t
|
493 |
+
x_level_k = x_level_k[:, :, trim_t:-trim_b]
|
494 |
+
elif H < H_pool:
|
495 |
+
pad_t = (H_pool - H) // 2
|
496 |
+
pad_b = H_pool - H - pad_t
|
497 |
+
x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
|
498 |
+
|
499 |
+
if W > W_pool:
|
500 |
+
trim_l = (W - W_pool) // 2
|
501 |
+
trim_r = W - W_pool - trim_l
|
502 |
+
x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
|
503 |
+
elif W < W_pool:
|
504 |
+
pad_l = (W_pool - W) // 2
|
505 |
+
pad_r = W_pool - W - pad_l
|
506 |
+
x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
|
507 |
+
|
508 |
+
x_windows_noreshape = window_partition_noreshape(
|
509 |
+
x_level_k.contiguous(), window_size_glo
|
510 |
+
) # B, nw, nw, T, window_size, window_size, C
|
511 |
+
nWh, nWw = x_windows_noreshape.shape[1:3]
|
512 |
+
x_windows_noreshape = x_windows_noreshape.view(
|
513 |
+
B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
|
514 |
+
C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
|
515 |
+
x_windows_pooled = self.pool_layers[k](
|
516 |
+
x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
|
517 |
+
|
518 |
+
x_windows_all += [x_windows_pooled]
|
519 |
+
x_window_masks_all += [None]
|
520 |
+
|
521 |
+
attn_windows = self.attn(
|
522 |
+
x_windows_all,
|
523 |
+
mask_all=x_window_masks_all) # nW*B, T*window_size*window_size, C
|
524 |
+
|
525 |
+
# merge windows
|
526 |
+
attn_windows = attn_windows.view(-1, T, self.window_size[0],
|
527 |
+
self.window_size[1], C)
|
528 |
+
shifted_x = window_reverse(attn_windows, self.window_size, T, H,
|
529 |
+
W) # B T H' W' C
|
530 |
+
|
531 |
+
# FFN
|
532 |
+
x = shortcut + shifted_x
|
533 |
+
y = self.norm2(x)
|
534 |
+
x = x + self.mlp(y.view(B, T * H * W, C)).view(B, T, H, W, C)
|
535 |
+
|
536 |
+
return x
|
inpainter/model/modules/tfocal_transformer_hq.py
ADDED
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This code is based on:
|
3 |
+
[1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
|
4 |
+
https://github.com/ruiliu-ai/FuseFormer
|
5 |
+
[2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
|
6 |
+
https://github.com/yitu-opensource/T2T-ViT
|
7 |
+
[3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
|
8 |
+
https://github.com/microsoft/Focal-Transformer
|
9 |
+
"""
|
10 |
+
|
11 |
+
import math
|
12 |
+
from functools import reduce
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
class SoftSplit(nn.Module):
|
20 |
+
def __init__(self, channel, hidden, kernel_size, stride, padding,
|
21 |
+
t2t_param):
|
22 |
+
super(SoftSplit, self).__init__()
|
23 |
+
self.kernel_size = kernel_size
|
24 |
+
self.t2t = nn.Unfold(kernel_size=kernel_size,
|
25 |
+
stride=stride,
|
26 |
+
padding=padding)
|
27 |
+
c_in = reduce((lambda x, y: x * y), kernel_size) * channel
|
28 |
+
self.embedding = nn.Linear(c_in, hidden)
|
29 |
+
|
30 |
+
self.t2t_param = t2t_param
|
31 |
+
|
32 |
+
def forward(self, x, b, output_size):
|
33 |
+
f_h = int((output_size[0] + 2 * self.t2t_param['padding'][0] -
|
34 |
+
(self.t2t_param['kernel_size'][0] - 1) - 1) /
|
35 |
+
self.t2t_param['stride'][0] + 1)
|
36 |
+
f_w = int((output_size[1] + 2 * self.t2t_param['padding'][1] -
|
37 |
+
(self.t2t_param['kernel_size'][1] - 1) - 1) /
|
38 |
+
self.t2t_param['stride'][1] + 1)
|
39 |
+
|
40 |
+
feat = self.t2t(x)
|
41 |
+
feat = feat.permute(0, 2, 1)
|
42 |
+
# feat shape [b*t, num_vec, ks*ks*c]
|
43 |
+
feat = self.embedding(feat)
|
44 |
+
# feat shape after embedding [b, t*num_vec, hidden]
|
45 |
+
feat = feat.view(b, -1, f_h, f_w, feat.size(2))
|
46 |
+
return feat
|
47 |
+
|
48 |
+
|
49 |
+
class SoftComp(nn.Module):
|
50 |
+
def __init__(self, channel, hidden, kernel_size, stride, padding):
|
51 |
+
super(SoftComp, self).__init__()
|
52 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
53 |
+
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
|
54 |
+
self.embedding = nn.Linear(hidden, c_out)
|
55 |
+
self.kernel_size = kernel_size
|
56 |
+
self.stride = stride
|
57 |
+
self.padding = padding
|
58 |
+
self.bias_conv = nn.Conv2d(channel,
|
59 |
+
channel,
|
60 |
+
kernel_size=3,
|
61 |
+
stride=1,
|
62 |
+
padding=1)
|
63 |
+
# TODO upsample conv
|
64 |
+
# self.bias_conv = nn.Conv2d()
|
65 |
+
# self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True)
|
66 |
+
|
67 |
+
def forward(self, x, t, output_size):
|
68 |
+
b_, _, _, _, c_ = x.shape
|
69 |
+
x = x.view(b_, -1, c_)
|
70 |
+
feat = self.embedding(x)
|
71 |
+
b, _, c = feat.size()
|
72 |
+
feat = feat.view(b * t, -1, c).permute(0, 2, 1)
|
73 |
+
feat = F.fold(feat,
|
74 |
+
output_size=output_size,
|
75 |
+
kernel_size=self.kernel_size,
|
76 |
+
stride=self.stride,
|
77 |
+
padding=self.padding)
|
78 |
+
feat = self.bias_conv(feat)
|
79 |
+
return feat
|
80 |
+
|
81 |
+
|
82 |
+
class FusionFeedForward(nn.Module):
|
83 |
+
def __init__(self, d_model, n_vecs=None, t2t_params=None):
|
84 |
+
super(FusionFeedForward, self).__init__()
|
85 |
+
# We set d_ff as a default to 1960
|
86 |
+
hd = 1960
|
87 |
+
self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
|
88 |
+
self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
|
89 |
+
assert t2t_params is not None and n_vecs is not None
|
90 |
+
self.t2t_params = t2t_params
|
91 |
+
|
92 |
+
def forward(self, x, output_size):
|
93 |
+
n_vecs = 1
|
94 |
+
for i, d in enumerate(self.t2t_params['kernel_size']):
|
95 |
+
n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
|
96 |
+
(d - 1) - 1) / self.t2t_params['stride'][i] + 1)
|
97 |
+
|
98 |
+
x = self.conv1(x)
|
99 |
+
b, n, c = x.size()
|
100 |
+
normalizer = x.new_ones(b, n, 49).view(-1, n_vecs, 49).permute(0, 2, 1)
|
101 |
+
normalizer = F.fold(normalizer,
|
102 |
+
output_size=output_size,
|
103 |
+
kernel_size=self.t2t_params['kernel_size'],
|
104 |
+
padding=self.t2t_params['padding'],
|
105 |
+
stride=self.t2t_params['stride'])
|
106 |
+
|
107 |
+
x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
|
108 |
+
output_size=output_size,
|
109 |
+
kernel_size=self.t2t_params['kernel_size'],
|
110 |
+
padding=self.t2t_params['padding'],
|
111 |
+
stride=self.t2t_params['stride'])
|
112 |
+
|
113 |
+
x = F.unfold(x / normalizer,
|
114 |
+
kernel_size=self.t2t_params['kernel_size'],
|
115 |
+
padding=self.t2t_params['padding'],
|
116 |
+
stride=self.t2t_params['stride']).permute(
|
117 |
+
0, 2, 1).contiguous().view(b, n, c)
|
118 |
+
x = self.conv2(x)
|
119 |
+
return x
|
120 |
+
|
121 |
+
|
122 |
+
def window_partition(x, window_size):
|
123 |
+
"""
|
124 |
+
Args:
|
125 |
+
x: shape is (B, T, H, W, C)
|
126 |
+
window_size (tuple[int]): window size
|
127 |
+
Returns:
|
128 |
+
windows: (B*num_windows, T*window_size*window_size, C)
|
129 |
+
"""
|
130 |
+
B, T, H, W, C = x.shape
|
131 |
+
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
|
132 |
+
window_size[1], C)
|
133 |
+
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
|
134 |
+
-1, T * window_size[0] * window_size[1], C)
|
135 |
+
return windows
|
136 |
+
|
137 |
+
|
138 |
+
def window_partition_noreshape(x, window_size):
|
139 |
+
"""
|
140 |
+
Args:
|
141 |
+
x: shape is (B, T, H, W, C)
|
142 |
+
window_size (tuple[int]): window size
|
143 |
+
Returns:
|
144 |
+
windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
|
145 |
+
"""
|
146 |
+
B, T, H, W, C = x.shape
|
147 |
+
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
|
148 |
+
window_size[1], C)
|
149 |
+
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
|
150 |
+
return windows
|
151 |
+
|
152 |
+
|
153 |
+
def window_reverse(windows, window_size, T, H, W):
|
154 |
+
"""
|
155 |
+
Args:
|
156 |
+
windows: shape is (num_windows*B, T, window_size, window_size, C)
|
157 |
+
window_size (tuple[int]): Window size
|
158 |
+
T (int): Temporal length of video
|
159 |
+
H (int): Height of image
|
160 |
+
W (int): Width of image
|
161 |
+
Returns:
|
162 |
+
x: (B, T, H, W, C)
|
163 |
+
"""
|
164 |
+
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
|
165 |
+
x = windows.view(B, H // window_size[0], W // window_size[1], T,
|
166 |
+
window_size[0], window_size[1], -1)
|
167 |
+
x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
|
168 |
+
return x
|
169 |
+
|
170 |
+
|
171 |
+
class WindowAttention(nn.Module):
|
172 |
+
"""Temporal focal window attention
|
173 |
+
"""
|
174 |
+
def __init__(self, dim, expand_size, window_size, focal_window,
|
175 |
+
focal_level, num_heads, qkv_bias, pool_method):
|
176 |
+
|
177 |
+
super().__init__()
|
178 |
+
self.dim = dim
|
179 |
+
self.expand_size = expand_size
|
180 |
+
self.window_size = window_size # Wh, Ww
|
181 |
+
self.pool_method = pool_method
|
182 |
+
self.num_heads = num_heads
|
183 |
+
head_dim = dim // num_heads
|
184 |
+
self.scale = head_dim**-0.5
|
185 |
+
self.focal_level = focal_level
|
186 |
+
self.focal_window = focal_window
|
187 |
+
|
188 |
+
if any(i > 0 for i in self.expand_size) and focal_level > 0:
|
189 |
+
# get mask for rolled k and rolled v
|
190 |
+
mask_tl = torch.ones(self.window_size[0], self.window_size[1])
|
191 |
+
mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
|
192 |
+
mask_tr = torch.ones(self.window_size[0], self.window_size[1])
|
193 |
+
mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
|
194 |
+
mask_bl = torch.ones(self.window_size[0], self.window_size[1])
|
195 |
+
mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
|
196 |
+
mask_br = torch.ones(self.window_size[0], self.window_size[1])
|
197 |
+
mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
|
198 |
+
mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
|
199 |
+
0).flatten(0)
|
200 |
+
self.register_buffer("valid_ind_rolled",
|
201 |
+
mask_rolled.nonzero(as_tuple=False).view(-1))
|
202 |
+
|
203 |
+
if pool_method != "none" and focal_level > 1:
|
204 |
+
self.unfolds = nn.ModuleList()
|
205 |
+
|
206 |
+
# build relative position bias between local patch and pooled windows
|
207 |
+
for k in range(focal_level - 1):
|
208 |
+
stride = 2**k
|
209 |
+
kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
|
210 |
+
for i in self.focal_window)
|
211 |
+
# define unfolding operations
|
212 |
+
self.unfolds += [
|
213 |
+
nn.Unfold(kernel_size=kernel_size,
|
214 |
+
stride=stride,
|
215 |
+
padding=tuple(i // 2 for i in kernel_size))
|
216 |
+
]
|
217 |
+
|
218 |
+
# define unfolding index for focal_level > 0
|
219 |
+
if k > 0:
|
220 |
+
mask = torch.zeros(kernel_size)
|
221 |
+
mask[(2**k) - 1:, (2**k) - 1:] = 1
|
222 |
+
self.register_buffer(
|
223 |
+
"valid_ind_unfold_{}".format(k),
|
224 |
+
mask.flatten(0).nonzero(as_tuple=False).view(-1))
|
225 |
+
|
226 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
227 |
+
self.proj = nn.Linear(dim, dim)
|
228 |
+
|
229 |
+
self.softmax = nn.Softmax(dim=-1)
|
230 |
+
|
231 |
+
def forward(self, x_all, mask_all=None):
|
232 |
+
"""
|
233 |
+
Args:
|
234 |
+
x: input features with shape of (B, T, Wh, Ww, C)
|
235 |
+
mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
|
236 |
+
|
237 |
+
output: (nW*B, Wh*Ww, C)
|
238 |
+
"""
|
239 |
+
x = x_all[0]
|
240 |
+
|
241 |
+
B, T, nH, nW, C = x.shape
|
242 |
+
qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
|
243 |
+
C).permute(4, 0, 1, 2, 3, 5).contiguous()
|
244 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
|
245 |
+
|
246 |
+
# partition q map
|
247 |
+
(q_windows, k_windows, v_windows) = map(
|
248 |
+
lambda t: window_partition(t, self.window_size).view(
|
249 |
+
-1, T, self.window_size[0] * self.window_size[1], self.
|
250 |
+
num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
|
251 |
+
contiguous().view(-1, self.num_heads, T * self.window_size[
|
252 |
+
0] * self.window_size[1], C // self.num_heads), (q, k, v))
|
253 |
+
# q(k/v)_windows shape : [16, 4, 225, 128]
|
254 |
+
|
255 |
+
if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
|
256 |
+
(k_tl, v_tl) = map(
|
257 |
+
lambda t: torch.roll(t,
|
258 |
+
shifts=(-self.expand_size[0], -self.
|
259 |
+
expand_size[1]),
|
260 |
+
dims=(2, 3)), (k, v))
|
261 |
+
(k_tr, v_tr) = map(
|
262 |
+
lambda t: torch.roll(t,
|
263 |
+
shifts=(-self.expand_size[0], self.
|
264 |
+
expand_size[1]),
|
265 |
+
dims=(2, 3)), (k, v))
|
266 |
+
(k_bl, v_bl) = map(
|
267 |
+
lambda t: torch.roll(t,
|
268 |
+
shifts=(self.expand_size[0], -self.
|
269 |
+
expand_size[1]),
|
270 |
+
dims=(2, 3)), (k, v))
|
271 |
+
(k_br, v_br) = map(
|
272 |
+
lambda t: torch.roll(t,
|
273 |
+
shifts=(self.expand_size[0], self.
|
274 |
+
expand_size[1]),
|
275 |
+
dims=(2, 3)), (k, v))
|
276 |
+
|
277 |
+
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
|
278 |
+
lambda t: window_partition(t, self.window_size).view(
|
279 |
+
-1, T, self.window_size[0] * self.window_size[1], self.
|
280 |
+
num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
|
281 |
+
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
|
282 |
+
lambda t: window_partition(t, self.window_size).view(
|
283 |
+
-1, T, self.window_size[0] * self.window_size[1], self.
|
284 |
+
num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
|
285 |
+
k_rolled = torch.cat(
|
286 |
+
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
|
287 |
+
2).permute(0, 3, 1, 2, 4).contiguous()
|
288 |
+
v_rolled = torch.cat(
|
289 |
+
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
|
290 |
+
2).permute(0, 3, 1, 2, 4).contiguous()
|
291 |
+
|
292 |
+
# mask out tokens in current window
|
293 |
+
k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
|
294 |
+
v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
|
295 |
+
temp_N = k_rolled.shape[3]
|
296 |
+
k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
|
297 |
+
C // self.num_heads)
|
298 |
+
v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
|
299 |
+
C // self.num_heads)
|
300 |
+
k_rolled = torch.cat((k_windows, k_rolled), 2)
|
301 |
+
v_rolled = torch.cat((v_windows, v_rolled), 2)
|
302 |
+
else:
|
303 |
+
k_rolled = k_windows
|
304 |
+
v_rolled = v_windows
|
305 |
+
|
306 |
+
# q(k/v)_windows shape : [16, 4, 225, 128]
|
307 |
+
# k_rolled.shape : [16, 4, 5, 165, 128]
|
308 |
+
# ideal expanded window size 153 ((5+2*2)*(9+2*4))
|
309 |
+
# k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
|
310 |
+
|
311 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
312 |
+
k_pooled = []
|
313 |
+
v_pooled = []
|
314 |
+
for k in range(self.focal_level - 1):
|
315 |
+
stride = 2**k
|
316 |
+
# B, T, nWh, nWw, C
|
317 |
+
x_window_pooled = x_all[k + 1].permute(0, 3, 1, 2,
|
318 |
+
4).contiguous()
|
319 |
+
|
320 |
+
nWh, nWw = x_window_pooled.shape[2:4]
|
321 |
+
|
322 |
+
# generate mask for pooled windows
|
323 |
+
mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
|
324 |
+
# unfold mask: [nWh*nWw//s//s, k*k, 1]
|
325 |
+
unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
|
326 |
+
1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
|
327 |
+
view(nWh*nWw // stride // stride, -1, 1)
|
328 |
+
|
329 |
+
if k > 0:
|
330 |
+
valid_ind_unfold_k = getattr(
|
331 |
+
self, "valid_ind_unfold_{}".format(k))
|
332 |
+
unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
|
333 |
+
|
334 |
+
x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
|
335 |
+
x_window_masks = x_window_masks.masked_fill(
|
336 |
+
x_window_masks == 0,
|
337 |
+
float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
|
338 |
+
mask_all[k + 1] = x_window_masks
|
339 |
+
|
340 |
+
# generate k and v for pooled windows
|
341 |
+
qkv_pooled = self.qkv(x_window_pooled).reshape(
|
342 |
+
B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
|
343 |
+
3).view(3, -1, C, nWh,
|
344 |
+
nWw).contiguous()
|
345 |
+
# B*T, C, nWh, nWw
|
346 |
+
k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]
|
347 |
+
# k_pooled_k shape: [5, 512, 4, 4]
|
348 |
+
# self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
|
349 |
+
|
350 |
+
(k_pooled_k, v_pooled_k) = map(
|
351 |
+
lambda t: self.unfolds[k]
|
352 |
+
(t).view(B, T, C, self.unfolds[k].kernel_size[0], self.
|
353 |
+
unfolds[k].kernel_size[1], -1)
|
354 |
+
.permute(0, 5, 1, 3, 4, 2).contiguous().view(
|
355 |
+
-1, T, self.unfolds[k].kernel_size[0] * self.unfolds[
|
356 |
+
k].kernel_size[1], self.num_heads, C // self.
|
357 |
+
num_heads).permute(0, 3, 1, 2, 4).contiguous(),
|
358 |
+
# (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
|
359 |
+
(k_pooled_k, v_pooled_k))
|
360 |
+
# k_pooled_k shape : [16, 4, 5, 45, 128]
|
361 |
+
|
362 |
+
# select valid unfolding index
|
363 |
+
if k > 0:
|
364 |
+
(k_pooled_k, v_pooled_k) = map(
|
365 |
+
lambda t: t[:, :, :, valid_ind_unfold_k],
|
366 |
+
(k_pooled_k, v_pooled_k))
|
367 |
+
|
368 |
+
k_pooled_k = k_pooled_k.view(
|
369 |
+
-1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
|
370 |
+
self.unfolds[k].kernel_size[1], C // self.num_heads)
|
371 |
+
v_pooled_k = v_pooled_k.view(
|
372 |
+
-1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
|
373 |
+
self.unfolds[k].kernel_size[1], C // self.num_heads)
|
374 |
+
|
375 |
+
k_pooled += [k_pooled_k]
|
376 |
+
v_pooled += [v_pooled_k]
|
377 |
+
|
378 |
+
# k_all (v_all) shape : [16, 4, 5 * 210, 128]
|
379 |
+
k_all = torch.cat([k_rolled] + k_pooled, 2)
|
380 |
+
v_all = torch.cat([v_rolled] + v_pooled, 2)
|
381 |
+
else:
|
382 |
+
k_all = k_rolled
|
383 |
+
v_all = v_rolled
|
384 |
+
|
385 |
+
N = k_all.shape[-2]
|
386 |
+
q_windows = q_windows * self.scale
|
387 |
+
# B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
|
388 |
+
attn = (q_windows @ k_all.transpose(-2, -1))
|
389 |
+
# T * 45
|
390 |
+
window_area = T * self.window_size[0] * self.window_size[1]
|
391 |
+
# T * 165
|
392 |
+
window_area_rolled = k_rolled.shape[2]
|
393 |
+
|
394 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
395 |
+
offset = window_area_rolled
|
396 |
+
for k in range(self.focal_level - 1):
|
397 |
+
# add attentional mask
|
398 |
+
# mask_all[1] shape [1, 16, T * 45]
|
399 |
+
|
400 |
+
bias = tuple((i + 2**k - 1) for i in self.focal_window)
|
401 |
+
|
402 |
+
if mask_all[k + 1] is not None:
|
403 |
+
attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
|
404 |
+
attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
|
405 |
+
mask_all[k+1][:, :, None, None, :].repeat(
|
406 |
+
attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
|
407 |
+
|
408 |
+
offset += T * bias[0] * bias[1]
|
409 |
+
|
410 |
+
if mask_all[0] is not None:
|
411 |
+
nW = mask_all[0].shape[0]
|
412 |
+
attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
|
413 |
+
window_area, N)
|
414 |
+
attn[:, :, :, :, :
|
415 |
+
window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
|
416 |
+
None, :, None, :, :]
|
417 |
+
attn = attn.view(-1, self.num_heads, window_area, N)
|
418 |
+
attn = self.softmax(attn)
|
419 |
+
else:
|
420 |
+
attn = self.softmax(attn)
|
421 |
+
|
422 |
+
x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
|
423 |
+
C)
|
424 |
+
x = self.proj(x)
|
425 |
+
return x
|
426 |
+
|
427 |
+
|
428 |
+
class TemporalFocalTransformerBlock(nn.Module):
|
429 |
+
r""" Temporal Focal Transformer Block.
|
430 |
+
Args:
|
431 |
+
dim (int): Number of input channels.
|
432 |
+
num_heads (int): Number of attention heads.
|
433 |
+
window_size (tuple[int]): Window size.
|
434 |
+
shift_size (int): Shift size for SW-MSA.
|
435 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
436 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
437 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
438 |
+
focal_level (int): The number level of focal window.
|
439 |
+
focal_window (int): Window size of each focal window.
|
440 |
+
n_vecs (int): Required for F3N.
|
441 |
+
t2t_params (int): T2T parameters for F3N.
|
442 |
+
"""
|
443 |
+
def __init__(self,
|
444 |
+
dim,
|
445 |
+
num_heads,
|
446 |
+
window_size=(5, 9),
|
447 |
+
mlp_ratio=4.,
|
448 |
+
qkv_bias=True,
|
449 |
+
pool_method="fc",
|
450 |
+
focal_level=2,
|
451 |
+
focal_window=(5, 9),
|
452 |
+
norm_layer=nn.LayerNorm,
|
453 |
+
n_vecs=None,
|
454 |
+
t2t_params=None):
|
455 |
+
super().__init__()
|
456 |
+
self.dim = dim
|
457 |
+
self.num_heads = num_heads
|
458 |
+
self.window_size = window_size
|
459 |
+
self.expand_size = tuple(i // 2 for i in window_size) # TODO
|
460 |
+
self.mlp_ratio = mlp_ratio
|
461 |
+
self.pool_method = pool_method
|
462 |
+
self.focal_level = focal_level
|
463 |
+
self.focal_window = focal_window
|
464 |
+
|
465 |
+
self.window_size_glo = self.window_size
|
466 |
+
|
467 |
+
self.pool_layers = nn.ModuleList()
|
468 |
+
if self.pool_method != "none":
|
469 |
+
for k in range(self.focal_level - 1):
|
470 |
+
window_size_glo = tuple(
|
471 |
+
math.floor(i / (2**k)) for i in self.window_size_glo)
|
472 |
+
self.pool_layers.append(
|
473 |
+
nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
|
474 |
+
self.pool_layers[-1].weight.data.fill_(
|
475 |
+
1. / (window_size_glo[0] * window_size_glo[1]))
|
476 |
+
self.pool_layers[-1].bias.data.fill_(0)
|
477 |
+
|
478 |
+
self.norm1 = norm_layer(dim)
|
479 |
+
|
480 |
+
self.attn = WindowAttention(dim,
|
481 |
+
expand_size=self.expand_size,
|
482 |
+
window_size=self.window_size,
|
483 |
+
focal_window=focal_window,
|
484 |
+
focal_level=focal_level,
|
485 |
+
num_heads=num_heads,
|
486 |
+
qkv_bias=qkv_bias,
|
487 |
+
pool_method=pool_method)
|
488 |
+
|
489 |
+
self.norm2 = norm_layer(dim)
|
490 |
+
self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
|
491 |
+
|
492 |
+
def forward(self, x):
|
493 |
+
output_size = x[1]
|
494 |
+
x = x[0]
|
495 |
+
|
496 |
+
B, T, H, W, C = x.shape
|
497 |
+
|
498 |
+
shortcut = x
|
499 |
+
x = self.norm1(x)
|
500 |
+
|
501 |
+
shifted_x = x
|
502 |
+
|
503 |
+
x_windows_all = [shifted_x]
|
504 |
+
x_window_masks_all = [None]
|
505 |
+
|
506 |
+
# partition windows tuple(i // 2 for i in window_size)
|
507 |
+
if self.focal_level > 1 and self.pool_method != "none":
|
508 |
+
# if we add coarser granularity and the pool method is not none
|
509 |
+
for k in range(self.focal_level - 1):
|
510 |
+
window_size_glo = tuple(
|
511 |
+
math.floor(i / (2**k)) for i in self.window_size_glo)
|
512 |
+
pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
|
513 |
+
pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
|
514 |
+
H_pool = pooled_h * window_size_glo[0]
|
515 |
+
W_pool = pooled_w * window_size_glo[1]
|
516 |
+
|
517 |
+
x_level_k = shifted_x
|
518 |
+
# trim or pad shifted_x depending on the required size
|
519 |
+
if H > H_pool:
|
520 |
+
trim_t = (H - H_pool) // 2
|
521 |
+
trim_b = H - H_pool - trim_t
|
522 |
+
x_level_k = x_level_k[:, :, trim_t:-trim_b]
|
523 |
+
elif H < H_pool:
|
524 |
+
pad_t = (H_pool - H) // 2
|
525 |
+
pad_b = H_pool - H - pad_t
|
526 |
+
x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
|
527 |
+
|
528 |
+
if W > W_pool:
|
529 |
+
trim_l = (W - W_pool) // 2
|
530 |
+
trim_r = W - W_pool - trim_l
|
531 |
+
x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
|
532 |
+
elif W < W_pool:
|
533 |
+
pad_l = (W_pool - W) // 2
|
534 |
+
pad_r = W_pool - W - pad_l
|
535 |
+
x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
|
536 |
+
|
537 |
+
x_windows_noreshape = window_partition_noreshape(
|
538 |
+
x_level_k.contiguous(), window_size_glo
|
539 |
+
) # B, nw, nw, T, window_size, window_size, C
|
540 |
+
nWh, nWw = x_windows_noreshape.shape[1:3]
|
541 |
+
x_windows_noreshape = x_windows_noreshape.view(
|
542 |
+
B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
|
543 |
+
C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
|
544 |
+
x_windows_pooled = self.pool_layers[k](
|
545 |
+
x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
|
546 |
+
|
547 |
+
x_windows_all += [x_windows_pooled]
|
548 |
+
x_window_masks_all += [None]
|
549 |
+
|
550 |
+
# nW*B, T*window_size*window_size, C
|
551 |
+
attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all)
|
552 |
+
|
553 |
+
# merge windows
|
554 |
+
attn_windows = attn_windows.view(-1, T, self.window_size[0],
|
555 |
+
self.window_size[1], C)
|
556 |
+
shifted_x = window_reverse(attn_windows, self.window_size, T, H,
|
557 |
+
W) # B T H' W' C
|
558 |
+
|
559 |
+
# FFN
|
560 |
+
x = shortcut + shifted_x
|
561 |
+
y = self.norm2(x)
|
562 |
+
x = x + self.mlp(y.view(B, T * H * W, C), output_size).view(
|
563 |
+
B, T, H, W, C)
|
564 |
+
|
565 |
+
return x, output_size
|
inpainter/util/__init__.py
ADDED
File without changes
|
inpainter/util/tensor_util.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
# resize frames
|
5 |
+
def resize_frames(frames, size=None):
|
6 |
+
"""
|
7 |
+
size: (w, h)
|
8 |
+
"""
|
9 |
+
if size is not None:
|
10 |
+
frames = [cv2.resize(f, size) for f in frames]
|
11 |
+
frames = np.stack(frames, 0)
|
12 |
+
|
13 |
+
return frames
|
14 |
+
|
15 |
+
# resize frames
|
16 |
+
def resize_masks(masks, size=None):
|
17 |
+
"""
|
18 |
+
size: (w, h)
|
19 |
+
"""
|
20 |
+
if size is not None:
|
21 |
+
masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks]
|
22 |
+
masks = np.stack(masks, 0)
|
23 |
+
|
24 |
+
return masks
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
progressbar2
|
2 |
+
gdown
|
3 |
+
gitpython
|
4 |
+
git+https://github.com/cheind/py-thin-plate-spline
|
5 |
+
hickle
|
6 |
+
tensorboard
|
7 |
+
numpy
|
8 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
9 |
+
gradio==3.25.0
|
10 |
+
opencv-python
|
11 |
+
pycocotools
|
12 |
+
matplotlib
|
13 |
+
onnxruntime
|
14 |
+
onnx
|
15 |
+
metaseg
|
16 |
+
pyyaml
|
17 |
+
av
|
sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
|
3 |
+
size 2564550879
|
template.html
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- template.html -->
|
2 |
+
<!DOCTYPE html>
|
3 |
+
<html lang="en">
|
4 |
+
<head>
|
5 |
+
<meta charset="UTF-8">
|
6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
7 |
+
<title>Gradio Video Pause Time</title>
|
8 |
+
</head>
|
9 |
+
<body>
|
10 |
+
<video id="video" controls>
|
11 |
+
<source src="{{VIDEO_URL}}" type="video/mp4">
|
12 |
+
Your browser does not support the video tag.
|
13 |
+
</video>
|
14 |
+
<script>
|
15 |
+
const video = document.getElementById("video");
|
16 |
+
let pauseTime = null;
|
17 |
+
|
18 |
+
video.addEventListener("pause", () => {
|
19 |
+
pauseTime = video.currentTime;
|
20 |
+
});
|
21 |
+
|
22 |
+
function getPauseTime() {
|
23 |
+
return pauseTime;
|
24 |
+
}
|
25 |
+
</script>
|
26 |
+
</body>
|
27 |
+
</html>
|
templates/index.html
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
7 |
+
<title>Video Object Segmentation</title>
|
8 |
+
<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
|
9 |
+
</head>
|
10 |
+
<body>
|
11 |
+
<h1>Video Object Segmentation</h1>
|
12 |
+
|
13 |
+
<input type="file" id="video-input" accept="video/*">
|
14 |
+
<button id="upload-video">Upload Video</button>
|
15 |
+
<br>
|
16 |
+
<button id="template-select">Template Select</button>
|
17 |
+
<button id="sam-refine">SAM Refine</button>
|
18 |
+
<br>
|
19 |
+
<button id="track-video">Track Video</button>
|
20 |
+
<button id="track-image">Track Image</button>
|
21 |
+
<br>
|
22 |
+
<a href="/download_video" id="download-video" download>Download Video</a>
|
23 |
+
|
24 |
+
<script>
|
25 |
+
// JavaScript code for handling interactions with the server
|
26 |
+
$("#upload-video").click(function() {
|
27 |
+
var videoInput = document.getElementById("video-input");
|
28 |
+
var formData = new FormData();
|
29 |
+
formData.append("video", videoInput.files[0]);
|
30 |
+
|
31 |
+
$.ajax({
|
32 |
+
url: "/upload_video",
|
33 |
+
type: "POST",
|
34 |
+
data: formData,
|
35 |
+
processData: false,
|
36 |
+
contentType: false,
|
37 |
+
success: function(response) {
|
38 |
+
console.log(response);
|
39 |
+
// Process the response and update the UI accordingly
|
40 |
+
},
|
41 |
+
error: function(jqXHR, textStatus, errorThrown) {
|
42 |
+
console.log(textStatus, errorThrown);
|
43 |
+
}
|
44 |
+
});
|
45 |
+
});
|
46 |
+
|
47 |
+
</script>
|
48 |
+
</body>
|
49 |
+
</html>
|
50 |
+
|
text_server.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import cv2
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
import queue
|
7 |
+
import numpy as np
|
8 |
+
import requests
|
9 |
+
import concurrent.futures
|
10 |
+
from PIL import Image
|
11 |
+
from flask import Flask, render_template, request, jsonify, send_file
|
12 |
+
import torchvision
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from demo import automask_image_app, automask_video_app, sahi_autoseg_app
|
16 |
+
sys.path.append(sys.path[0] + "/tracker")
|
17 |
+
sys.path.append(sys.path[0] + "/tracker/model")
|
18 |
+
from track_anything import TrackingAnything
|
19 |
+
from track_anything import parse_augment
|
20 |
+
|
21 |
+
# ... (all the functions defined in the original code except the Gradio part)
|
22 |
+
|
23 |
+
app = Flask(__name__)
|
24 |
+
app.config['UPLOAD_FOLDER'] = './uploaded_videos'
|
25 |
+
app.config['ALLOWED_EXTENSIONS'] = {'mp4', 'avi', 'mov', 'mkv'}
|
26 |
+
|
27 |
+
|
28 |
+
def allowed_file(filename):
|
29 |
+
return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
|
30 |
+
|
31 |
+
@app.route("/")
|
32 |
+
def index():
|
33 |
+
return render_template("index.html")
|
34 |
+
|
35 |
+
@app.route("/upload_video", methods=["POST"])
|
36 |
+
def upload_video():
|
37 |
+
# ... (handle video upload and processing)
|
38 |
+
return jsonify(status="success", data=video_data)
|
39 |
+
|
40 |
+
@app.route("/template_select", methods=["POST"])
|
41 |
+
def template_select():
|
42 |
+
# ... (handle template selection and processing)
|
43 |
+
return jsonify(status="success", data=template_data)
|
44 |
+
|
45 |
+
@app.route("/sam_refine", methods=["POST"])
|
46 |
+
def sam_refine_request():
|
47 |
+
# ... (handle sam refine and processing)
|
48 |
+
return jsonify(status="success", data=sam_data)
|
49 |
+
|
50 |
+
@app.route("/track_video", methods=["POST"])
|
51 |
+
def track_video():
|
52 |
+
# ... (handle video tracking and processing)
|
53 |
+
return jsonify(status="success", data=tracking_data)
|
54 |
+
|
55 |
+
@app.route("/track_image", methods=["POST"])
|
56 |
+
def track_image():
|
57 |
+
# ... (handle image tracking and processing)
|
58 |
+
return jsonify(status="success", data=tracking_data)
|
59 |
+
|
60 |
+
@app.route("/download_video", methods=["GET"])
|
61 |
+
def download_video():
|
62 |
+
try:
|
63 |
+
return send_file("output.mp4", attachment_filename="output.mp4")
|
64 |
+
except Exception as e:
|
65 |
+
return str(e)
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
app.run(debug=True, host="0.0.0.0", port=args.port)
|
69 |
+
|
70 |
+
|
71 |
+
if __name__ == '__main__':
|
72 |
+
app.run(host="0.0.0.0",port=12212, debug=True)
|
tools/__init__.py
ADDED
File without changes
|
tools/base_segmenter.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
from PIL import Image, ImageDraw, ImageOps
|
5 |
+
import numpy as np
|
6 |
+
from typing import Union
|
7 |
+
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import PIL
|
10 |
+
from .mask_painter import mask_painter
|
11 |
+
|
12 |
+
|
13 |
+
class BaseSegmenter:
|
14 |
+
def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
|
15 |
+
"""
|
16 |
+
device: model device
|
17 |
+
SAM_checkpoint: path of SAM checkpoint
|
18 |
+
model_type: vit_b, vit_l, vit_h
|
19 |
+
"""
|
20 |
+
print(f"Initializing BaseSegmenter to {device}")
|
21 |
+
assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
|
22 |
+
|
23 |
+
self.device = device
|
24 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
25 |
+
self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
|
26 |
+
self.model.to(device=self.device)
|
27 |
+
self.predictor = SamPredictor(self.model)
|
28 |
+
self.embedded = False
|
29 |
+
|
30 |
+
@torch.no_grad()
|
31 |
+
def set_image(self, image: np.ndarray):
|
32 |
+
# PIL.open(image_path) 3channel: RGB
|
33 |
+
# image embedding: avoid encode the same image multiple times
|
34 |
+
self.orignal_image = image
|
35 |
+
if self.embedded:
|
36 |
+
print('repeat embedding, please reset_image.')
|
37 |
+
return
|
38 |
+
self.predictor.set_image(image)
|
39 |
+
self.embedded = True
|
40 |
+
return
|
41 |
+
|
42 |
+
@torch.no_grad()
|
43 |
+
def reset_image(self):
|
44 |
+
# reset image embeding
|
45 |
+
self.predictor.reset_image()
|
46 |
+
self.embedded = False
|
47 |
+
|
48 |
+
def predict(self, prompts, mode, multimask=True):
|
49 |
+
"""
|
50 |
+
image: numpy array, h, w, 3
|
51 |
+
prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
|
52 |
+
prompts['point_coords']: numpy array [N,2]
|
53 |
+
prompts['point_labels']: numpy array [1,N]
|
54 |
+
prompts['mask_input']: numpy array [1,256,256]
|
55 |
+
mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
|
56 |
+
mask_outputs: True (return 3 masks), False (return 1 mask only)
|
57 |
+
whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
|
58 |
+
"""
|
59 |
+
assert self.embedded, 'prediction is called before set_image (feature embedding).'
|
60 |
+
assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
|
61 |
+
|
62 |
+
if mode == 'point':
|
63 |
+
masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
|
64 |
+
point_labels=prompts['point_labels'],
|
65 |
+
multimask_output=multimask)
|
66 |
+
elif mode == 'mask':
|
67 |
+
masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
|
68 |
+
multimask_output=multimask)
|
69 |
+
elif mode == 'both': # both
|
70 |
+
masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
|
71 |
+
point_labels=prompts['point_labels'],
|
72 |
+
mask_input=prompts['mask_input'],
|
73 |
+
multimask_output=multimask)
|
74 |
+
else:
|
75 |
+
raise("Not implement now!")
|
76 |
+
# masks (n, h, w), scores (n,), logits (n, 256, 256)
|
77 |
+
return masks, scores, logits
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
# load and show an image
|
82 |
+
image = cv2.imread('/hhd3/gaoshang/truck.jpg')
|
83 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
|
84 |
+
|
85 |
+
# initialise BaseSegmenter
|
86 |
+
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
87 |
+
model_type = 'vit_h'
|
88 |
+
device = "cuda:4"
|
89 |
+
base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
|
90 |
+
|
91 |
+
# image embedding (once embedded, multiple prompts can be applied)
|
92 |
+
base_segmenter.set_image(image)
|
93 |
+
|
94 |
+
# examples
|
95 |
+
# point only ------------------------
|
96 |
+
mode = 'point'
|
97 |
+
prompts = {
|
98 |
+
'point_coords': np.array([[500, 375], [1125, 625]]),
|
99 |
+
'point_labels': np.array([1, 1]),
|
100 |
+
}
|
101 |
+
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
102 |
+
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
103 |
+
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
104 |
+
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
|
105 |
+
|
106 |
+
# both ------------------------
|
107 |
+
mode = 'both'
|
108 |
+
mask_input = logits[np.argmax(scores), :, :]
|
109 |
+
prompts = {'mask_input': mask_input [None, :, :]}
|
110 |
+
prompts = {
|
111 |
+
'point_coords': np.array([[500, 375], [1125, 625]]),
|
112 |
+
'point_labels': np.array([1, 0]),
|
113 |
+
'mask_input': mask_input[None, :, :]
|
114 |
+
}
|
115 |
+
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
116 |
+
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
117 |
+
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
118 |
+
cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
|
119 |
+
|
120 |
+
# mask only ------------------------
|
121 |
+
mode = 'mask'
|
122 |
+
mask_input = logits[np.argmax(scores), :, :]
|
123 |
+
|
124 |
+
prompts = {'mask_input': mask_input[None, :, :]}
|
125 |
+
|
126 |
+
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
127 |
+
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
128 |
+
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
129 |
+
cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
|
tools/interact_tools.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
from PIL import Image, ImageDraw, ImageOps
|
5 |
+
import numpy as np
|
6 |
+
from typing import Union
|
7 |
+
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import PIL
|
10 |
+
from .mask_painter import mask_painter as mask_painter2
|
11 |
+
from .base_segmenter import BaseSegmenter
|
12 |
+
from .painter import mask_painter, point_painter
|
13 |
+
import os
|
14 |
+
import requests
|
15 |
+
import sys
|
16 |
+
|
17 |
+
|
18 |
+
mask_color = 3
|
19 |
+
mask_alpha = 0.7
|
20 |
+
contour_color = 1
|
21 |
+
contour_width = 5
|
22 |
+
point_color_ne = 8
|
23 |
+
point_color_ps = 50
|
24 |
+
point_alpha = 0.9
|
25 |
+
point_radius = 15
|
26 |
+
contour_color = 2
|
27 |
+
contour_width = 5
|
28 |
+
|
29 |
+
|
30 |
+
class SamControler():
|
31 |
+
def __init__(self, SAM_checkpoint, model_type, device):
|
32 |
+
'''
|
33 |
+
initialize sam controler
|
34 |
+
'''
|
35 |
+
|
36 |
+
|
37 |
+
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
38 |
+
|
39 |
+
|
40 |
+
def seg_again(self, image: np.ndarray):
|
41 |
+
'''
|
42 |
+
it is used when interact in video
|
43 |
+
'''
|
44 |
+
self.sam_controler.reset_image()
|
45 |
+
self.sam_controler.set_image(image)
|
46 |
+
return
|
47 |
+
|
48 |
+
|
49 |
+
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
50 |
+
'''
|
51 |
+
it is used in first frame in video
|
52 |
+
return: mask, logit, painted image(mask+point)
|
53 |
+
'''
|
54 |
+
# self.sam_controler.set_image(image)
|
55 |
+
origal_image = self.sam_controler.orignal_image
|
56 |
+
neg_flag = labels[-1]
|
57 |
+
if neg_flag==1:
|
58 |
+
#find neg
|
59 |
+
prompts = {
|
60 |
+
'point_coords': points,
|
61 |
+
'point_labels': labels,
|
62 |
+
}
|
63 |
+
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
64 |
+
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
65 |
+
prompts = {
|
66 |
+
'point_coords': points,
|
67 |
+
'point_labels': labels,
|
68 |
+
'mask_input': logit[None, :, :]
|
69 |
+
}
|
70 |
+
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
71 |
+
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
72 |
+
else:
|
73 |
+
#find positive
|
74 |
+
prompts = {
|
75 |
+
'point_coords': points,
|
76 |
+
'point_labels': labels,
|
77 |
+
}
|
78 |
+
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
79 |
+
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
80 |
+
|
81 |
+
|
82 |
+
assert len(points)==len(labels)
|
83 |
+
|
84 |
+
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
85 |
+
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
86 |
+
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
87 |
+
painted_image = Image.fromarray(painted_image)
|
88 |
+
|
89 |
+
return mask, logit, painted_image
|
90 |
+
|
91 |
+
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
92 |
+
origal_image = self.sam_controler.orignal_image
|
93 |
+
if same:
|
94 |
+
'''
|
95 |
+
true; loop in the same image
|
96 |
+
'''
|
97 |
+
prompts = {
|
98 |
+
'point_coords': points,
|
99 |
+
'point_labels': labels,
|
100 |
+
'mask_input': logits[None, :, :]
|
101 |
+
}
|
102 |
+
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
103 |
+
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
104 |
+
|
105 |
+
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
106 |
+
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
107 |
+
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
108 |
+
painted_image = Image.fromarray(painted_image)
|
109 |
+
|
110 |
+
return mask, logit, painted_image
|
111 |
+
else:
|
112 |
+
'''
|
113 |
+
loop in the different image, interact in the video
|
114 |
+
'''
|
115 |
+
if image is None:
|
116 |
+
raise('Image error')
|
117 |
+
else:
|
118 |
+
self.seg_again(image)
|
119 |
+
prompts = {
|
120 |
+
'point_coords': points,
|
121 |
+
'point_labels': labels,
|
122 |
+
}
|
123 |
+
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
124 |
+
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
125 |
+
|
126 |
+
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
127 |
+
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
128 |
+
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
129 |
+
painted_image = Image.fromarray(painted_image)
|
130 |
+
|
131 |
+
return mask, logit, painted_image
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
# def initialize():
|
139 |
+
# '''
|
140 |
+
# initialize sam controler
|
141 |
+
# '''
|
142 |
+
# checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
143 |
+
# folder = "segmenter"
|
144 |
+
# SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
|
145 |
+
# download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
146 |
+
|
147 |
+
|
148 |
+
# model_type = 'vit_h'
|
149 |
+
# device = "cuda:0"
|
150 |
+
# sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
151 |
+
# return sam_controler
|
152 |
+
|
153 |
+
|
154 |
+
# def seg_again(sam_controler, image: np.ndarray):
|
155 |
+
# '''
|
156 |
+
# it is used when interact in video
|
157 |
+
# '''
|
158 |
+
# sam_controler.reset_image()
|
159 |
+
# sam_controler.set_image(image)
|
160 |
+
# return
|
161 |
+
|
162 |
+
|
163 |
+
# def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
164 |
+
# '''
|
165 |
+
# it is used in first frame in video
|
166 |
+
# return: mask, logit, painted image(mask+point)
|
167 |
+
# '''
|
168 |
+
# sam_controler.set_image(image)
|
169 |
+
# prompts = {
|
170 |
+
# 'point_coords': points,
|
171 |
+
# 'point_labels': labels,
|
172 |
+
# }
|
173 |
+
# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
|
174 |
+
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
175 |
+
|
176 |
+
# assert len(points)==len(labels)
|
177 |
+
|
178 |
+
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
179 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
180 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
181 |
+
# painted_image = Image.fromarray(painted_image)
|
182 |
+
|
183 |
+
# return mask, logit, painted_image
|
184 |
+
|
185 |
+
# def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
186 |
+
# if same:
|
187 |
+
# '''
|
188 |
+
# true; loop in the same image
|
189 |
+
# '''
|
190 |
+
# prompts = {
|
191 |
+
# 'point_coords': points,
|
192 |
+
# 'point_labels': labels,
|
193 |
+
# 'mask_input': logits[None, :, :]
|
194 |
+
# }
|
195 |
+
# masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
|
196 |
+
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
197 |
+
|
198 |
+
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
199 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
200 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
201 |
+
# painted_image = Image.fromarray(painted_image)
|
202 |
+
|
203 |
+
# return mask, logit, painted_image
|
204 |
+
# else:
|
205 |
+
# '''
|
206 |
+
# loop in the different image, interact in the video
|
207 |
+
# '''
|
208 |
+
# if image is None:
|
209 |
+
# raise('Image error')
|
210 |
+
# else:
|
211 |
+
# seg_again(sam_controler, image)
|
212 |
+
# prompts = {
|
213 |
+
# 'point_coords': points,
|
214 |
+
# 'point_labels': labels,
|
215 |
+
# }
|
216 |
+
# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
|
217 |
+
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
218 |
+
|
219 |
+
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
220 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
221 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
222 |
+
# painted_image = Image.fromarray(painted_image)
|
223 |
+
|
224 |
+
# return mask, logit, painted_image
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
if __name__ == "__main__":
|
230 |
+
points = np.array([[500, 375], [1125, 625]])
|
231 |
+
labels = np.array([1, 1])
|
232 |
+
image = cv2.imread('/hhd3/gaoshang/truck.jpg')
|
233 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
234 |
+
|
235 |
+
sam_controler = initialize()
|
236 |
+
mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
|
237 |
+
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
238 |
+
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
239 |
+
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
|
240 |
+
cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
|
241 |
+
painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
|
242 |
+
|
243 |
+
mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
|
244 |
+
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
245 |
+
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
246 |
+
cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
|
247 |
+
painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
|
248 |
+
|
249 |
+
mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
|
250 |
+
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
251 |
+
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
252 |
+
cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
|
253 |
+
painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
|
262 |
+
|
263 |
+
|
264 |
+
|
265 |
+
|
tools/mask_painter.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import copy
|
6 |
+
import time
|
7 |
+
|
8 |
+
|
9 |
+
def colormap(rgb=True):
|
10 |
+
color_list = np.array(
|
11 |
+
[
|
12 |
+
0.000, 0.000, 0.000,
|
13 |
+
1.000, 1.000, 1.000,
|
14 |
+
1.000, 0.498, 0.313,
|
15 |
+
0.392, 0.581, 0.929,
|
16 |
+
0.000, 0.447, 0.741,
|
17 |
+
0.850, 0.325, 0.098,
|
18 |
+
0.929, 0.694, 0.125,
|
19 |
+
0.494, 0.184, 0.556,
|
20 |
+
0.466, 0.674, 0.188,
|
21 |
+
0.301, 0.745, 0.933,
|
22 |
+
0.635, 0.078, 0.184,
|
23 |
+
0.300, 0.300, 0.300,
|
24 |
+
0.600, 0.600, 0.600,
|
25 |
+
1.000, 0.000, 0.000,
|
26 |
+
1.000, 0.500, 0.000,
|
27 |
+
0.749, 0.749, 0.000,
|
28 |
+
0.000, 1.000, 0.000,
|
29 |
+
0.000, 0.000, 1.000,
|
30 |
+
0.667, 0.000, 1.000,
|
31 |
+
0.333, 0.333, 0.000,
|
32 |
+
0.333, 0.667, 0.000,
|
33 |
+
0.333, 1.000, 0.000,
|
34 |
+
0.667, 0.333, 0.000,
|
35 |
+
0.667, 0.667, 0.000,
|
36 |
+
0.667, 1.000, 0.000,
|
37 |
+
1.000, 0.333, 0.000,
|
38 |
+
1.000, 0.667, 0.000,
|
39 |
+
1.000, 1.000, 0.000,
|
40 |
+
0.000, 0.333, 0.500,
|
41 |
+
0.000, 0.667, 0.500,
|
42 |
+
0.000, 1.000, 0.500,
|
43 |
+
0.333, 0.000, 0.500,
|
44 |
+
0.333, 0.333, 0.500,
|
45 |
+
0.333, 0.667, 0.500,
|
46 |
+
0.333, 1.000, 0.500,
|
47 |
+
0.667, 0.000, 0.500,
|
48 |
+
0.667, 0.333, 0.500,
|
49 |
+
0.667, 0.667, 0.500,
|
50 |
+
0.667, 1.000, 0.500,
|
51 |
+
1.000, 0.000, 0.500,
|
52 |
+
1.000, 0.333, 0.500,
|
53 |
+
1.000, 0.667, 0.500,
|
54 |
+
1.000, 1.000, 0.500,
|
55 |
+
0.000, 0.333, 1.000,
|
56 |
+
0.000, 0.667, 1.000,
|
57 |
+
0.000, 1.000, 1.000,
|
58 |
+
0.333, 0.000, 1.000,
|
59 |
+
0.333, 0.333, 1.000,
|
60 |
+
0.333, 0.667, 1.000,
|
61 |
+
0.333, 1.000, 1.000,
|
62 |
+
0.667, 0.000, 1.000,
|
63 |
+
0.667, 0.333, 1.000,
|
64 |
+
0.667, 0.667, 1.000,
|
65 |
+
0.667, 1.000, 1.000,
|
66 |
+
1.000, 0.000, 1.000,
|
67 |
+
1.000, 0.333, 1.000,
|
68 |
+
1.000, 0.667, 1.000,
|
69 |
+
0.167, 0.000, 0.000,
|
70 |
+
0.333, 0.000, 0.000,
|
71 |
+
0.500, 0.000, 0.000,
|
72 |
+
0.667, 0.000, 0.000,
|
73 |
+
0.833, 0.000, 0.000,
|
74 |
+
1.000, 0.000, 0.000,
|
75 |
+
0.000, 0.167, 0.000,
|
76 |
+
0.000, 0.333, 0.000,
|
77 |
+
0.000, 0.500, 0.000,
|
78 |
+
0.000, 0.667, 0.000,
|
79 |
+
0.000, 0.833, 0.000,
|
80 |
+
0.000, 1.000, 0.000,
|
81 |
+
0.000, 0.000, 0.167,
|
82 |
+
0.000, 0.000, 0.333,
|
83 |
+
0.000, 0.000, 0.500,
|
84 |
+
0.000, 0.000, 0.667,
|
85 |
+
0.000, 0.000, 0.833,
|
86 |
+
0.000, 0.000, 1.000,
|
87 |
+
0.143, 0.143, 0.143,
|
88 |
+
0.286, 0.286, 0.286,
|
89 |
+
0.429, 0.429, 0.429,
|
90 |
+
0.571, 0.571, 0.571,
|
91 |
+
0.714, 0.714, 0.714,
|
92 |
+
0.857, 0.857, 0.857
|
93 |
+
]
|
94 |
+
).astype(np.float32)
|
95 |
+
color_list = color_list.reshape((-1, 3)) * 255
|
96 |
+
if not rgb:
|
97 |
+
color_list = color_list[:, ::-1]
|
98 |
+
return color_list
|
99 |
+
|
100 |
+
|
101 |
+
color_list = colormap()
|
102 |
+
color_list = color_list.astype('uint8').tolist()
|
103 |
+
|
104 |
+
|
105 |
+
def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
|
106 |
+
background_color = np.array(background_color)
|
107 |
+
contour_color = np.array(contour_color)
|
108 |
+
|
109 |
+
# background_mask = 1 - background_mask
|
110 |
+
# contour_mask = 1 - contour_mask
|
111 |
+
|
112 |
+
for i in range(3):
|
113 |
+
image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
|
114 |
+
+ background_color[i] * (background_alpha-background_mask*background_alpha)
|
115 |
+
|
116 |
+
image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
|
117 |
+
+ contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
|
118 |
+
|
119 |
+
return image.astype('uint8')
|
120 |
+
|
121 |
+
|
122 |
+
def mask_generator_00(mask, background_radius, contour_radius):
|
123 |
+
# no background width when '00'
|
124 |
+
# distance map
|
125 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
126 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
127 |
+
dist_map = dist_transform_fore - dist_transform_back
|
128 |
+
# ...:::!!!:::...
|
129 |
+
contour_radius += 2
|
130 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
131 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
132 |
+
contour_mask[contour_mask>0.5] = 1.
|
133 |
+
|
134 |
+
return mask, contour_mask
|
135 |
+
|
136 |
+
|
137 |
+
def mask_generator_01(mask, background_radius, contour_radius):
|
138 |
+
# no background width when '00'
|
139 |
+
# distance map
|
140 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
141 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
142 |
+
dist_map = dist_transform_fore - dist_transform_back
|
143 |
+
# ...:::!!!:::...
|
144 |
+
contour_radius += 2
|
145 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
146 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
147 |
+
return mask, contour_mask
|
148 |
+
|
149 |
+
|
150 |
+
def mask_generator_10(mask, background_radius, contour_radius):
|
151 |
+
# distance map
|
152 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
153 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
154 |
+
dist_map = dist_transform_fore - dist_transform_back
|
155 |
+
# .....:::::!!!!!
|
156 |
+
background_mask = np.clip(dist_map, -background_radius, background_radius)
|
157 |
+
background_mask = (background_mask - np.min(background_mask))
|
158 |
+
background_mask = background_mask / np.max(background_mask)
|
159 |
+
# ...:::!!!:::...
|
160 |
+
contour_radius += 2
|
161 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
162 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
163 |
+
contour_mask[contour_mask>0.5] = 1.
|
164 |
+
return background_mask, contour_mask
|
165 |
+
|
166 |
+
|
167 |
+
def mask_generator_11(mask, background_radius, contour_radius):
|
168 |
+
# distance map
|
169 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
170 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
171 |
+
dist_map = dist_transform_fore - dist_transform_back
|
172 |
+
# .....:::::!!!!!
|
173 |
+
background_mask = np.clip(dist_map, -background_radius, background_radius)
|
174 |
+
background_mask = (background_mask - np.min(background_mask))
|
175 |
+
background_mask = background_mask / np.max(background_mask)
|
176 |
+
# ...:::!!!:::...
|
177 |
+
contour_radius += 2
|
178 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
179 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
180 |
+
return background_mask, contour_mask
|
181 |
+
|
182 |
+
|
183 |
+
def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
|
184 |
+
"""
|
185 |
+
Input:
|
186 |
+
input_image: numpy array
|
187 |
+
input_mask: numpy array
|
188 |
+
background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
|
189 |
+
background_blur_radius: radius of background blur, must be odd number
|
190 |
+
contour_width: width of mask contour, must be odd number
|
191 |
+
contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
|
192 |
+
contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
|
193 |
+
mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
|
194 |
+
|
195 |
+
Output:
|
196 |
+
painted_image: numpy array
|
197 |
+
"""
|
198 |
+
assert input_image.shape[:2] == input_mask.shape, 'different shape'
|
199 |
+
assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
|
200 |
+
assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
|
201 |
+
|
202 |
+
# downsample input image and mask
|
203 |
+
width, height = input_image.shape[0], input_image.shape[1]
|
204 |
+
res = 1024
|
205 |
+
ratio = min(1.0 * res / max(width, height), 1.0)
|
206 |
+
input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
|
207 |
+
input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
|
208 |
+
|
209 |
+
# 0: background, 1: foreground
|
210 |
+
msk = np.clip(input_mask, 0, 1)
|
211 |
+
|
212 |
+
# generate masks for background and contour pixels
|
213 |
+
background_radius = (background_blur_radius - 1) // 2
|
214 |
+
contour_radius = (contour_width - 1) // 2
|
215 |
+
generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
|
216 |
+
background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
|
217 |
+
|
218 |
+
# paint
|
219 |
+
painted_image = vis_add_mask\
|
220 |
+
(input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
|
221 |
+
|
222 |
+
return painted_image
|
223 |
+
|
224 |
+
|
225 |
+
if __name__ == '__main__':
|
226 |
+
|
227 |
+
background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
|
228 |
+
background_blur_radius = 31 # radius of background blur, must be odd number
|
229 |
+
contour_width = 11 # contour width, must be odd number
|
230 |
+
contour_color = 3 # id in color map, 0: black, 1: white, >1: others
|
231 |
+
contour_alpha = 1 # transparency of background, 0: no contour highlighted
|
232 |
+
|
233 |
+
# load input image and mask
|
234 |
+
input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
|
235 |
+
input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
|
236 |
+
|
237 |
+
# paint
|
238 |
+
overall_time_1 = 0
|
239 |
+
overall_time_2 = 0
|
240 |
+
overall_time_3 = 0
|
241 |
+
overall_time_4 = 0
|
242 |
+
overall_time_5 = 0
|
243 |
+
|
244 |
+
for i in range(50):
|
245 |
+
t2 = time.time()
|
246 |
+
painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
|
247 |
+
e2 = time.time()
|
248 |
+
|
249 |
+
t3 = time.time()
|
250 |
+
painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
|
251 |
+
e3 = time.time()
|
252 |
+
|
253 |
+
t1 = time.time()
|
254 |
+
painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
|
255 |
+
e1 = time.time()
|
256 |
+
|
257 |
+
t4 = time.time()
|
258 |
+
painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
|
259 |
+
e4 = time.time()
|
260 |
+
|
261 |
+
t5 = time.time()
|
262 |
+
painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
|
263 |
+
e5 = time.time()
|
264 |
+
|
265 |
+
overall_time_1 += (e1 - t1)
|
266 |
+
overall_time_2 += (e2 - t2)
|
267 |
+
overall_time_3 += (e3 - t3)
|
268 |
+
overall_time_4 += (e4 - t4)
|
269 |
+
overall_time_5 += (e5 - t5)
|
270 |
+
|
271 |
+
print(f'average time w gaussian: {overall_time_1/50}')
|
272 |
+
print(f'average time w/o gaussian00: {overall_time_2/50}')
|
273 |
+
print(f'average time w/o gaussian10: {overall_time_3/50}')
|
274 |
+
print(f'average time w/o gaussian01: {overall_time_4/50}')
|
275 |
+
print(f'average time w/o gaussian11: {overall_time_5/50}')
|
276 |
+
|
277 |
+
# save
|
278 |
+
painted_image_00 = Image.fromarray(painted_image_00)
|
279 |
+
painted_image_00.save('./test_img/painter_output_image_00.png')
|
280 |
+
|
281 |
+
painted_image_10 = Image.fromarray(painted_image_10)
|
282 |
+
painted_image_10.save('./test_img/painter_output_image_10.png')
|
283 |
+
|
284 |
+
painted_image_01 = Image.fromarray(painted_image_01)
|
285 |
+
painted_image_01.save('./test_img/painter_output_image_01.png')
|
286 |
+
|
287 |
+
painted_image_11 = Image.fromarray(painted_image_11)
|
288 |
+
painted_image_11.save('./test_img/painter_output_image_11.png')
|
tools/painter.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# paint masks, contours, or points on images, with specified colors
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import copy
|
7 |
+
import time
|
8 |
+
|
9 |
+
|
10 |
+
def colormap(rgb=True):
|
11 |
+
color_list = np.array(
|
12 |
+
[
|
13 |
+
0.000, 0.000, 0.000,
|
14 |
+
1.000, 1.000, 1.000,
|
15 |
+
1.000, 0.498, 0.313,
|
16 |
+
0.392, 0.581, 0.929,
|
17 |
+
0.000, 0.447, 0.741,
|
18 |
+
0.850, 0.325, 0.098,
|
19 |
+
0.929, 0.694, 0.125,
|
20 |
+
0.494, 0.184, 0.556,
|
21 |
+
0.466, 0.674, 0.188,
|
22 |
+
0.301, 0.745, 0.933,
|
23 |
+
0.635, 0.078, 0.184,
|
24 |
+
0.300, 0.300, 0.300,
|
25 |
+
0.600, 0.600, 0.600,
|
26 |
+
1.000, 0.000, 0.000,
|
27 |
+
1.000, 0.500, 0.000,
|
28 |
+
0.749, 0.749, 0.000,
|
29 |
+
0.000, 1.000, 0.000,
|
30 |
+
0.000, 0.000, 1.000,
|
31 |
+
0.667, 0.000, 1.000,
|
32 |
+
0.333, 0.333, 0.000,
|
33 |
+
0.333, 0.667, 0.000,
|
34 |
+
0.333, 1.000, 0.000,
|
35 |
+
0.667, 0.333, 0.000,
|
36 |
+
0.667, 0.667, 0.000,
|
37 |
+
0.667, 1.000, 0.000,
|
38 |
+
1.000, 0.333, 0.000,
|
39 |
+
1.000, 0.667, 0.000,
|
40 |
+
1.000, 1.000, 0.000,
|
41 |
+
0.000, 0.333, 0.500,
|
42 |
+
0.000, 0.667, 0.500,
|
43 |
+
0.000, 1.000, 0.500,
|
44 |
+
0.333, 0.000, 0.500,
|
45 |
+
0.333, 0.333, 0.500,
|
46 |
+
0.333, 0.667, 0.500,
|
47 |
+
0.333, 1.000, 0.500,
|
48 |
+
0.667, 0.000, 0.500,
|
49 |
+
0.667, 0.333, 0.500,
|
50 |
+
0.667, 0.667, 0.500,
|
51 |
+
0.667, 1.000, 0.500,
|
52 |
+
1.000, 0.000, 0.500,
|
53 |
+
1.000, 0.333, 0.500,
|
54 |
+
1.000, 0.667, 0.500,
|
55 |
+
1.000, 1.000, 0.500,
|
56 |
+
0.000, 0.333, 1.000,
|
57 |
+
0.000, 0.667, 1.000,
|
58 |
+
0.000, 1.000, 1.000,
|
59 |
+
0.333, 0.000, 1.000,
|
60 |
+
0.333, 0.333, 1.000,
|
61 |
+
0.333, 0.667, 1.000,
|
62 |
+
0.333, 1.000, 1.000,
|
63 |
+
0.667, 0.000, 1.000,
|
64 |
+
0.667, 0.333, 1.000,
|
65 |
+
0.667, 0.667, 1.000,
|
66 |
+
0.667, 1.000, 1.000,
|
67 |
+
1.000, 0.000, 1.000,
|
68 |
+
1.000, 0.333, 1.000,
|
69 |
+
1.000, 0.667, 1.000,
|
70 |
+
0.167, 0.000, 0.000,
|
71 |
+
0.333, 0.000, 0.000,
|
72 |
+
0.500, 0.000, 0.000,
|
73 |
+
0.667, 0.000, 0.000,
|
74 |
+
0.833, 0.000, 0.000,
|
75 |
+
1.000, 0.000, 0.000,
|
76 |
+
0.000, 0.167, 0.000,
|
77 |
+
0.000, 0.333, 0.000,
|
78 |
+
0.000, 0.500, 0.000,
|
79 |
+
0.000, 0.667, 0.000,
|
80 |
+
0.000, 0.833, 0.000,
|
81 |
+
0.000, 1.000, 0.000,
|
82 |
+
0.000, 0.000, 0.167,
|
83 |
+
0.000, 0.000, 0.333,
|
84 |
+
0.000, 0.000, 0.500,
|
85 |
+
0.000, 0.000, 0.667,
|
86 |
+
0.000, 0.000, 0.833,
|
87 |
+
0.000, 0.000, 1.000,
|
88 |
+
0.143, 0.143, 0.143,
|
89 |
+
0.286, 0.286, 0.286,
|
90 |
+
0.429, 0.429, 0.429,
|
91 |
+
0.571, 0.571, 0.571,
|
92 |
+
0.714, 0.714, 0.714,
|
93 |
+
0.857, 0.857, 0.857
|
94 |
+
]
|
95 |
+
).astype(np.float32)
|
96 |
+
color_list = color_list.reshape((-1, 3)) * 255
|
97 |
+
if not rgb:
|
98 |
+
color_list = color_list[:, ::-1]
|
99 |
+
return color_list
|
100 |
+
|
101 |
+
|
102 |
+
color_list = colormap()
|
103 |
+
color_list = color_list.astype('uint8').tolist()
|
104 |
+
|
105 |
+
|
106 |
+
def vis_add_mask(image, mask, color, alpha):
|
107 |
+
color = np.array(color_list[color])
|
108 |
+
mask = mask > 0.5
|
109 |
+
image[mask] = image[mask] * (1-alpha) + color * alpha
|
110 |
+
return image.astype('uint8')
|
111 |
+
|
112 |
+
def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
|
113 |
+
h, w = input_image.shape[:2]
|
114 |
+
point_mask = np.zeros((h, w)).astype('uint8')
|
115 |
+
for point in input_points:
|
116 |
+
point_mask[point[1], point[0]] = 1
|
117 |
+
|
118 |
+
kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
|
119 |
+
point_mask = cv2.dilate(point_mask, kernel)
|
120 |
+
|
121 |
+
contour_radius = (contour_width - 1) // 2
|
122 |
+
dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
|
123 |
+
dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
|
124 |
+
dist_map = dist_transform_fore - dist_transform_back
|
125 |
+
# ...:::!!!:::...
|
126 |
+
contour_radius += 2
|
127 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
128 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
129 |
+
contour_mask[contour_mask>0.5] = 1.
|
130 |
+
|
131 |
+
# paint mask
|
132 |
+
painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
|
133 |
+
# paint contour
|
134 |
+
painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
|
135 |
+
return painted_image
|
136 |
+
|
137 |
+
def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
|
138 |
+
assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
|
139 |
+
# 0: background, 1: foreground
|
140 |
+
mask = np.clip(input_mask, 0, 1)
|
141 |
+
contour_radius = (contour_width - 1) // 2
|
142 |
+
|
143 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
144 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
145 |
+
dist_map = dist_transform_fore - dist_transform_back
|
146 |
+
# ...:::!!!:::...
|
147 |
+
contour_radius += 2
|
148 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
149 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
150 |
+
contour_mask[contour_mask>0.5] = 1.
|
151 |
+
|
152 |
+
# paint mask
|
153 |
+
painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
|
154 |
+
# paint contour
|
155 |
+
painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
|
156 |
+
|
157 |
+
return painted_image
|
158 |
+
|
159 |
+
def background_remover(input_image, input_mask):
|
160 |
+
"""
|
161 |
+
input_image: H, W, 3, np.array
|
162 |
+
input_mask: H, W, np.array
|
163 |
+
|
164 |
+
image_wo_background: PIL.Image
|
165 |
+
"""
|
166 |
+
assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
|
167 |
+
# 0: background, 1: foreground
|
168 |
+
mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
|
169 |
+
image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
|
170 |
+
image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
|
171 |
+
|
172 |
+
return image_wo_background
|
173 |
+
|
174 |
+
if __name__ == '__main__':
|
175 |
+
input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
|
176 |
+
input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
|
177 |
+
|
178 |
+
# example of mask painter
|
179 |
+
mask_color = 3
|
180 |
+
mask_alpha = 0.7
|
181 |
+
contour_color = 1
|
182 |
+
contour_width = 5
|
183 |
+
|
184 |
+
# save
|
185 |
+
painted_image = Image.fromarray(input_image)
|
186 |
+
painted_image.save('images/original.png')
|
187 |
+
|
188 |
+
painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
|
189 |
+
# save
|
190 |
+
painted_image = Image.fromarray(input_image)
|
191 |
+
painted_image.save('images/original1.png')
|
192 |
+
|
193 |
+
# example of point painter
|
194 |
+
input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
|
195 |
+
input_points = np.array([[500, 375], [70, 600]]) # x, y
|
196 |
+
point_color = 5
|
197 |
+
point_alpha = 0.9
|
198 |
+
point_radius = 15
|
199 |
+
contour_color = 2
|
200 |
+
contour_width = 5
|
201 |
+
painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
|
202 |
+
# save
|
203 |
+
painted_image = Image.fromarray(painted_image_1)
|
204 |
+
painted_image.save('images/point_painter_1.png')
|
205 |
+
|
206 |
+
input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
|
207 |
+
painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
|
208 |
+
# save
|
209 |
+
painted_image = Image.fromarray(painted_image_2)
|
210 |
+
painted_image.save('images/point_painter_2.png')
|
211 |
+
|
212 |
+
# example of background remover
|
213 |
+
input_image = np.array(Image.open('images/original.png').convert('RGB'))
|
214 |
+
image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
|
215 |
+
image_wo_background.save('images/image_wo_background.png')
|
track_anything.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append("/hhd3/gaoshang/Track-Anything/tracker")
|
3 |
+
import PIL
|
4 |
+
from tools.interact_tools import SamControler
|
5 |
+
from tracker.base_tracker import BaseTracker
|
6 |
+
import numpy as np
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class TrackingAnything():
|
12 |
+
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
|
13 |
+
self.args = args
|
14 |
+
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
|
15 |
+
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
|
16 |
+
|
17 |
+
|
18 |
+
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
19 |
+
same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
20 |
+
if first_flag:
|
21 |
+
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
22 |
+
return mask, logit, painted_image
|
23 |
+
|
24 |
+
if interact_flag:
|
25 |
+
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
26 |
+
return mask, logit, painted_image
|
27 |
+
|
28 |
+
mask, logit, painted_image = self.xmem.track(image, logit)
|
29 |
+
return mask, logit, painted_image
|
30 |
+
|
31 |
+
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
32 |
+
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
33 |
+
return mask, logit, painted_image
|
34 |
+
|
35 |
+
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
36 |
+
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
37 |
+
return mask, logit, painted_image
|
38 |
+
|
39 |
+
def generator(self, images: list, template_mask:np.ndarray):
|
40 |
+
|
41 |
+
masks = []
|
42 |
+
logits = []
|
43 |
+
painted_images = []
|
44 |
+
for i in range(len(images)):
|
45 |
+
if i ==0:
|
46 |
+
mask, logit, painted_image = self.xmem.track(images[i], template_mask)
|
47 |
+
masks.append(mask)
|
48 |
+
logits.append(logit)
|
49 |
+
painted_images.append(painted_image)
|
50 |
+
|
51 |
+
else:
|
52 |
+
mask, logit, painted_image = self.xmem.track(images[i])
|
53 |
+
masks.append(mask)
|
54 |
+
logits.append(logit)
|
55 |
+
painted_images.append(painted_image)
|
56 |
+
return masks, logits, painted_images
|
57 |
+
|
58 |
+
|
59 |
+
def parse_augment():
|
60 |
+
parser = argparse.ArgumentParser()
|
61 |
+
parser.add_argument('--device', type=str, default="cuda:0")
|
62 |
+
parser.add_argument('--sam_model_type', type=str, default="vit_h")
|
63 |
+
parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
|
64 |
+
parser.add_argument('--debug', action="store_true")
|
65 |
+
parser.add_argument('--mask_save', default=True)
|
66 |
+
args = parser.parse_args()
|
67 |
+
|
68 |
+
if args.debug:
|
69 |
+
print(args)
|
70 |
+
return args
|
71 |
+
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
masks = None
|
75 |
+
logits = None
|
76 |
+
painted_images = None
|
77 |
+
images = []
|
78 |
+
image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg'))
|
79 |
+
args = parse_augment()
|
80 |
+
# images.append(np.ones((20,20,3)).astype('uint8'))
|
81 |
+
# images.append(np.ones((20,20,3)).astype('uint8'))
|
82 |
+
images.append(image)
|
83 |
+
images.append(image)
|
84 |
+
|
85 |
+
mask = np.zeros_like(image)[:,:,0]
|
86 |
+
mask[0,0]= 1
|
87 |
+
trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args)
|
88 |
+
masks, logits ,painted_images= trackany.generator(images, mask)
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
tracker/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
tracker/base_tracker.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import for debugging
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
# import for base_tracker
|
7 |
+
import torch
|
8 |
+
import yaml
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from model.network import XMem
|
11 |
+
from inference.inference_core import InferenceCore
|
12 |
+
from util.mask_mapper import MaskMapper
|
13 |
+
from torchvision import transforms
|
14 |
+
from util.range_transform import im_normalization
|
15 |
+
import sys
|
16 |
+
sys.path.insert(0, sys.path[0]+"/../")
|
17 |
+
from tools.painter import mask_painter
|
18 |
+
from tools.base_segmenter import BaseSegmenter
|
19 |
+
from torchvision.transforms import Resize
|
20 |
+
|
21 |
+
|
22 |
+
class BaseTracker:
|
23 |
+
def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) -> None:
|
24 |
+
"""
|
25 |
+
device: model device
|
26 |
+
xmem_checkpoint: checkpoint of XMem model
|
27 |
+
"""
|
28 |
+
# load configurations
|
29 |
+
with open("tracker/config/config.yaml", 'r') as stream:
|
30 |
+
config = yaml.safe_load(stream)
|
31 |
+
# initialise XMem
|
32 |
+
network = XMem(config, xmem_checkpoint).to(device).eval()
|
33 |
+
# initialise IncerenceCore
|
34 |
+
self.tracker = InferenceCore(network, config)
|
35 |
+
# data transformation
|
36 |
+
self.im_transform = transforms.Compose([
|
37 |
+
transforms.ToTensor(),
|
38 |
+
im_normalization,
|
39 |
+
])
|
40 |
+
self.device = device
|
41 |
+
|
42 |
+
# changable properties
|
43 |
+
self.mapper = MaskMapper()
|
44 |
+
self.initialised = False
|
45 |
+
|
46 |
+
# # SAM-based refinement
|
47 |
+
# self.sam_model = sam_model
|
48 |
+
# self.resizer = Resize([256, 256])
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def resize_mask(self, mask):
|
52 |
+
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
|
53 |
+
h, w = mask.shape[-2:]
|
54 |
+
min_hw = min(h, w)
|
55 |
+
return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
|
56 |
+
mode='nearest')
|
57 |
+
|
58 |
+
@torch.no_grad()
|
59 |
+
def track(self, frame, first_frame_annotation=None):
|
60 |
+
"""
|
61 |
+
Input:
|
62 |
+
frames: numpy arrays (H, W, 3)
|
63 |
+
logit: numpy array (H, W), logit
|
64 |
+
|
65 |
+
Output:
|
66 |
+
mask: numpy arrays (H, W)
|
67 |
+
logit: numpy arrays, probability map (H, W)
|
68 |
+
painted_image: numpy array (H, W, 3)
|
69 |
+
"""
|
70 |
+
if first_frame_annotation is not None: # first frame mask
|
71 |
+
# initialisation
|
72 |
+
mask, labels = self.mapper.convert_mask(first_frame_annotation)
|
73 |
+
mask = torch.Tensor(mask).to(self.device)
|
74 |
+
self.tracker.set_all_labels(list(self.mapper.remappings.values()))
|
75 |
+
else:
|
76 |
+
mask = None
|
77 |
+
labels = None
|
78 |
+
# prepare inputs
|
79 |
+
frame_tensor = self.im_transform(frame).to(self.device)
|
80 |
+
# track one frame
|
81 |
+
probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
|
82 |
+
# # refine
|
83 |
+
# if first_frame_annotation is None:
|
84 |
+
# out_mask = self.sam_refinement(frame, logits[1], ti)
|
85 |
+
|
86 |
+
# convert to mask
|
87 |
+
out_mask = torch.argmax(probs, dim=0)
|
88 |
+
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
89 |
+
|
90 |
+
num_objs = out_mask.max()
|
91 |
+
painted_image = frame
|
92 |
+
for obj in range(1, num_objs+1):
|
93 |
+
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
|
94 |
+
|
95 |
+
return out_mask, out_mask, painted_image
|
96 |
+
|
97 |
+
@torch.no_grad()
|
98 |
+
def sam_refinement(self, frame, logits, ti):
|
99 |
+
"""
|
100 |
+
refine segmentation results with mask prompt
|
101 |
+
"""
|
102 |
+
# convert to 1, 256, 256
|
103 |
+
self.sam_model.set_image(frame)
|
104 |
+
mode = 'mask'
|
105 |
+
logits = logits.unsqueeze(0)
|
106 |
+
logits = self.resizer(logits).cpu().numpy()
|
107 |
+
prompts = {'mask_input': logits} # 1 256 256
|
108 |
+
masks, scores, logits = self.sam_model.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
109 |
+
painted_image = mask_painter(frame, masks[np.argmax(scores)].astype('uint8'), mask_alpha=0.8)
|
110 |
+
painted_image = Image.fromarray(painted_image)
|
111 |
+
painted_image.save(f'/ssd1/gaomingqi/refine/{ti:05d}.png')
|
112 |
+
self.sam_model.reset_image()
|
113 |
+
|
114 |
+
@torch.no_grad()
|
115 |
+
def clear_memory(self):
|
116 |
+
self.tracker.clear_memory()
|
117 |
+
self.mapper.clear_labels()
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
# video frames (multiple objects)
|
122 |
+
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
|
123 |
+
video_path_list.sort()
|
124 |
+
# first frame
|
125 |
+
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
|
126 |
+
# load frames
|
127 |
+
frames = []
|
128 |
+
for video_path in video_path_list:
|
129 |
+
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
130 |
+
frames = np.stack(frames, 0) # N, H, W, C
|
131 |
+
# load first frame annotation
|
132 |
+
first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
|
133 |
+
|
134 |
+
# ----------------------------------------------------------
|
135 |
+
# initalise tracker
|
136 |
+
# ----------------------------------------------------------
|
137 |
+
device = 'cuda:4'
|
138 |
+
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
139 |
+
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
140 |
+
model_type = 'vit_h'
|
141 |
+
|
142 |
+
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
|
143 |
+
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
144 |
+
|
145 |
+
# test for storage efficiency
|
146 |
+
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
147 |
+
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
148 |
+
|
149 |
+
for ti, frame in enumerate(frames):
|
150 |
+
print(ti)
|
151 |
+
if ti > 200:
|
152 |
+
break
|
153 |
+
if ti == 0:
|
154 |
+
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
155 |
+
else:
|
156 |
+
mask, prob, painted_image = tracker.track(frame)
|
157 |
+
# save
|
158 |
+
painted_image = Image.fromarray(painted_image)
|
159 |
+
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
160 |
+
|
161 |
+
tracker.clear_memory()
|
162 |
+
for ti, frame in enumerate(frames):
|
163 |
+
print(ti)
|
164 |
+
# if ti > 200:
|
165 |
+
# break
|
166 |
+
if ti == 0:
|
167 |
+
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
168 |
+
else:
|
169 |
+
mask, prob, painted_image = tracker.track(frame)
|
170 |
+
# save
|
171 |
+
painted_image = Image.fromarray(painted_image)
|
172 |
+
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
173 |
+
|
174 |
+
# # track anything given in the first frame annotation
|
175 |
+
# for ti, frame in enumerate(frames):
|
176 |
+
# if ti == 0:
|
177 |
+
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
178 |
+
# else:
|
179 |
+
# mask, prob, painted_image = tracker.track(frame)
|
180 |
+
# # save
|
181 |
+
# painted_image = Image.fromarray(painted_image)
|
182 |
+
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
|
183 |
+
|
184 |
+
# # ----------------------------------------------------------
|
185 |
+
# # another video
|
186 |
+
# # ----------------------------------------------------------
|
187 |
+
# # video frames
|
188 |
+
# video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/camel', '*.jpg'))
|
189 |
+
# video_path_list.sort()
|
190 |
+
# # first frame
|
191 |
+
# first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/camel/00000.png'
|
192 |
+
# # load frames
|
193 |
+
# frames = []
|
194 |
+
# for video_path in video_path_list:
|
195 |
+
# frames.append(np.array(Image.open(video_path).convert('RGB')))
|
196 |
+
# frames = np.stack(frames, 0) # N, H, W, C
|
197 |
+
# # load first frame annotation
|
198 |
+
# first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
|
199 |
+
|
200 |
+
# print('first video done. clear.')
|
201 |
+
|
202 |
+
# tracker.clear_memory()
|
203 |
+
# # track anything given in the first frame annotation
|
204 |
+
# for ti, frame in enumerate(frames):
|
205 |
+
# if ti == 0:
|
206 |
+
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
207 |
+
# else:
|
208 |
+
# mask, prob, painted_image = tracker.track(frame)
|
209 |
+
# # save
|
210 |
+
# painted_image = Image.fromarray(painted_image)
|
211 |
+
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
|
212 |
+
|
213 |
+
# # failure case test
|
214 |
+
# failure_path = '/ssd1/gaomingqi/failure'
|
215 |
+
# frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
|
216 |
+
# # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
|
217 |
+
# first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
|
218 |
+
# first_mask = np.clip(first_mask, 0, 1)
|
219 |
+
|
220 |
+
# for ti, frame in enumerate(frames):
|
221 |
+
# if ti == 0:
|
222 |
+
# mask, probs, painted_image = tracker.track(frame, first_mask)
|
223 |
+
# else:
|
224 |
+
# mask, probs, painted_image = tracker.track(frame)
|
225 |
+
# # save
|
226 |
+
# painted_image = Image.fromarray(painted_image)
|
227 |
+
# painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
|
228 |
+
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
229 |
+
|
230 |
+
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
231 |
+
|
232 |
+
|
233 |
+
|