Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +25 -0
- .gitignore +135 -0
- .gitmodules +7 -0
- CITATION.cff +8 -0
- Dockerfile +30 -0
- EfficientSAM/EdgeSAM/common.py +118 -0
- EfficientSAM/EdgeSAM/rep_vit.py +370 -0
- EfficientSAM/EdgeSAM/setup_edge_sam.py +90 -0
- EfficientSAM/FastSAM/tools.py +413 -0
- EfficientSAM/LightHQSAM/example_light_hqsam.png +3 -0
- EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg +0 -0
- EfficientSAM/LightHQSAM/setup_light_hqsam.py +45 -0
- EfficientSAM/LightHQSAM/tiny_vit_sam.py +724 -0
- EfficientSAM/MobileSAM/setup_mobile_sam.py +44 -0
- EfficientSAM/MobileSAM/tiny_vit_sam.py +716 -0
- EfficientSAM/README.md +194 -0
- EfficientSAM/RepViTSAM/repvit.py +364 -0
- EfficientSAM/RepViTSAM/setup_repvit_sam.py +53 -0
- EfficientSAM/grounded_edge_sam.py +107 -0
- EfficientSAM/grounded_efficient_sam.py +118 -0
- EfficientSAM/grounded_fast_sam.py +141 -0
- EfficientSAM/grounded_light_hqsam.py +109 -0
- EfficientSAM/grounded_mobile_sam.py +145 -0
- EfficientSAM/grounded_repvit_sam.py +107 -0
- GroundingDINO/.asset/COCO.png +0 -0
- GroundingDINO/.asset/GD_GLIGEN.png +3 -0
- GroundingDINO/.asset/GD_SD.png +3 -0
- GroundingDINO/.asset/ODinW.png +0 -0
- GroundingDINO/.asset/arch.png +0 -0
- GroundingDINO/.asset/cats.png +0 -0
- GroundingDINO/.asset/hero_figure.png +3 -0
- GroundingDINO/LICENSE +201 -0
- GroundingDINO/README.md +163 -0
- GroundingDINO/demo/gradio_app.py +125 -0
- GroundingDINO/demo/inference_on_a_image.py +172 -0
- GroundingDINO/groundingdino/__init__.py +0 -0
- GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py +43 -0
- GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py +43 -0
- GroundingDINO/groundingdino/datasets/__init__.py +0 -0
- GroundingDINO/groundingdino/datasets/transforms.py +311 -0
- GroundingDINO/groundingdino/models/GroundingDINO/__init__.py +15 -0
- GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py +1 -0
- GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py +221 -0
- GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py +186 -0
- GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py +802 -0
- GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py +273 -0
- GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h +64 -0
- GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp +43 -0
- GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h +35 -0
- GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu +156 -0
.gitattributes
CHANGED
@@ -33,3 +33,28 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
EfficientSAM/LightHQSAM/example_light_hqsam.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
GroundingDINO/.asset/GD_GLIGEN.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
GroundingDINO/.asset/GD_SD.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
GroundingDINO/.asset/hero_figure.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
VISAM/thirdparty/segment_anything/assets/masks1.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
VISAM/thirdparty/segment_anything/assets/notebook2.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
VISAM/visam.gif filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/acoustics/gsam_whisper_inpainting_demo.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/acoustics/gsam_whisper_inpainting_pipeline.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/demo9.jpg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
assets/gradio_demo.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
assets/grounded_sam_demo3_demo4.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
assets/grounded_sam_inpainting_demo.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
assets/grounded_sam_new_demo_image.png filter=lfs diff=lfs merge=lfs -text
|
50 |
+
assets/mask_3dbox.png filter=lfs diff=lfs merge=lfs -text
|
51 |
+
assets/osx/grounded_sam_osx_demo.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
assets/osx/grouned_sam_osx_demo.gif filter=lfs diff=lfs merge=lfs -text
|
53 |
+
assets/ram_grounded_sam_new.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
segment_anything/assets/masks1.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
segment_anything/assets/notebook2.png filter=lfs diff=lfs merge=lfs -text
|
56 |
+
voxelnext_3d_box/images/image_boxes1.png filter=lfs diff=lfs merge=lfs -text
|
57 |
+
voxelnext_3d_box/images/image_boxes2.png filter=lfs diff=lfs merge=lfs -text
|
58 |
+
voxelnext_3d_box/images/image_boxes3.png filter=lfs diff=lfs merge=lfs -text
|
59 |
+
voxelnext_3d_box/images/mask_box.png filter=lfs diff=lfs merge=lfs -text
|
60 |
+
voxelnext_3d_box/images/sam-voxelnext.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# checkpoint
|
132 |
+
*.pth
|
133 |
+
outputs/
|
134 |
+
|
135 |
+
.idea/
|
.gitmodules
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
[submodule "grounded-sam-osx"]
|
3 |
+
path = grounded-sam-osx
|
4 |
+
url = https://github.com/linjing7/grounded-sam-osx.git
|
5 |
+
[submodule "VISAM"]
|
6 |
+
path = VISAM
|
7 |
+
url = https://github.com/BingfengYan/VISAM
|
CITATION.cff
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cff-version: 1.2.0
|
2 |
+
message: "If you use this software, please cite it as below."
|
3 |
+
authors:
|
4 |
+
- name: "Grounded-SAM Contributors"
|
5 |
+
title: "Grounded-Segment-Anything"
|
6 |
+
date-released: 2023-04-06
|
7 |
+
url: "https://github.com/IDEA-Research/Grounded-Segment-Anything"
|
8 |
+
license: Apache-2.0
|
Dockerfile
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
|
2 |
+
|
3 |
+
# Arguments to build Docker Image using CUDA
|
4 |
+
ARG USE_CUDA=0
|
5 |
+
ARG TORCH_ARCH=
|
6 |
+
|
7 |
+
ENV AM_I_DOCKER True
|
8 |
+
ENV BUILD_WITH_CUDA "${USE_CUDA}"
|
9 |
+
ENV TORCH_CUDA_ARCH_LIST "${TORCH_ARCH}"
|
10 |
+
ENV CUDA_HOME /usr/local/cuda-11.6/
|
11 |
+
|
12 |
+
RUN mkdir -p /home/appuser/Grounded-Segment-Anything
|
13 |
+
COPY . /home/appuser/Grounded-Segment-Anything/
|
14 |
+
|
15 |
+
RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
|
16 |
+
libsm6=2:* libxext6=2:* git=1:* nano=2.* \
|
17 |
+
vim=2:* -y \
|
18 |
+
&& apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
|
19 |
+
|
20 |
+
WORKDIR /home/appuser/Grounded-Segment-Anything
|
21 |
+
RUN python -m pip install --no-cache-dir -e segment_anything
|
22 |
+
|
23 |
+
# When using build isolation, PyTorch with newer CUDA is installed and can't compile GroundingDINO
|
24 |
+
RUN python -m pip install --no-cache-dir wheel
|
25 |
+
RUN python -m pip install --no-cache-dir --no-build-isolation -e GroundingDINO
|
26 |
+
|
27 |
+
WORKDIR /home/appuser
|
28 |
+
RUN pip install --no-cache-dir diffusers[torch]==0.15.1 opencv-python==4.7.0.72 \
|
29 |
+
pycocotools==2.0.6 matplotlib==3.5.3 \
|
30 |
+
onnxruntime==1.14.1 onnx==1.13.1 ipykernel==6.16.2 scipy gradio openai
|
EfficientSAM/EdgeSAM/common.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from typing import Type
|
12 |
+
|
13 |
+
|
14 |
+
class MLPBlock(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
embedding_dim: int,
|
18 |
+
mlp_dim: int,
|
19 |
+
act: Type[nn.Module] = nn.GELU,
|
20 |
+
) -> None:
|
21 |
+
super().__init__()
|
22 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
23 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
24 |
+
self.act = act()
|
25 |
+
|
26 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
27 |
+
return self.lin2(self.act(self.lin1(x)))
|
28 |
+
|
29 |
+
|
30 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
31 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
32 |
+
class LayerNorm2d(nn.Module):
|
33 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
34 |
+
super().__init__()
|
35 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
36 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
37 |
+
self.eps = eps
|
38 |
+
|
39 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
40 |
+
u = x.mean(1, keepdim=True)
|
41 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
42 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
43 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
def val2list(x: list or tuple or any, repeat_time=1) -> list:
|
48 |
+
if isinstance(x, (list, tuple)):
|
49 |
+
return list(x)
|
50 |
+
return [x for _ in range(repeat_time)]
|
51 |
+
|
52 |
+
|
53 |
+
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
|
54 |
+
x = val2list(x)
|
55 |
+
|
56 |
+
# repeat elements if necessary
|
57 |
+
if len(x) > 0:
|
58 |
+
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
|
59 |
+
|
60 |
+
return tuple(x)
|
61 |
+
|
62 |
+
|
63 |
+
def list_sum(x: list) -> any:
|
64 |
+
return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
|
65 |
+
|
66 |
+
|
67 |
+
def resize(
|
68 |
+
x: torch.Tensor,
|
69 |
+
size: any or None = None,
|
70 |
+
scale_factor=None,
|
71 |
+
mode: str = "bicubic",
|
72 |
+
align_corners: bool or None = False,
|
73 |
+
) -> torch.Tensor:
|
74 |
+
if mode in ["bilinear", "bicubic"]:
|
75 |
+
return F.interpolate(
|
76 |
+
x,
|
77 |
+
size=size,
|
78 |
+
scale_factor=scale_factor,
|
79 |
+
mode=mode,
|
80 |
+
align_corners=align_corners,
|
81 |
+
)
|
82 |
+
elif mode in ["nearest", "area"]:
|
83 |
+
return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
|
84 |
+
else:
|
85 |
+
raise NotImplementedError(f"resize(mode={mode}) not implemented.")
|
86 |
+
|
87 |
+
|
88 |
+
class UpSampleLayer(nn.Module):
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
mode="bicubic",
|
92 |
+
size=None,
|
93 |
+
factor=2,
|
94 |
+
align_corners=False,
|
95 |
+
):
|
96 |
+
super(UpSampleLayer, self).__init__()
|
97 |
+
self.mode = mode
|
98 |
+
self.size = val2list(size, 2) if size is not None else None
|
99 |
+
self.factor = None if self.size is not None else factor
|
100 |
+
self.align_corners = align_corners
|
101 |
+
|
102 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
103 |
+
return resize(x, self.size, self.factor, self.mode, self.align_corners)
|
104 |
+
|
105 |
+
|
106 |
+
class OpSequential(nn.Module):
|
107 |
+
def __init__(self, op_list):
|
108 |
+
super(OpSequential, self).__init__()
|
109 |
+
valid_op_list = []
|
110 |
+
for op in op_list:
|
111 |
+
if op is not None:
|
112 |
+
valid_op_list.append(op)
|
113 |
+
self.op_list = nn.ModuleList(valid_op_list)
|
114 |
+
|
115 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
116 |
+
for op in self.op_list:
|
117 |
+
x = op(x)
|
118 |
+
return x
|
EfficientSAM/EdgeSAM/rep_vit.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from EdgeSAM.common import LayerNorm2d, UpSampleLayer, OpSequential
|
3 |
+
|
4 |
+
__all__ = ['rep_vit_m1', 'rep_vit_m2', 'rep_vit_m3', 'RepViT']
|
5 |
+
|
6 |
+
m1_cfgs = [
|
7 |
+
# k, t, c, SE, HS, s
|
8 |
+
[3, 2, 48, 1, 0, 1],
|
9 |
+
[3, 2, 48, 0, 0, 1],
|
10 |
+
[3, 2, 48, 0, 0, 1],
|
11 |
+
[3, 2, 96, 0, 0, 2],
|
12 |
+
[3, 2, 96, 1, 0, 1],
|
13 |
+
[3, 2, 96, 0, 0, 1],
|
14 |
+
[3, 2, 96, 0, 0, 1],
|
15 |
+
[3, 2, 192, 0, 1, 2],
|
16 |
+
[3, 2, 192, 1, 1, 1],
|
17 |
+
[3, 2, 192, 0, 1, 1],
|
18 |
+
[3, 2, 192, 1, 1, 1],
|
19 |
+
[3, 2, 192, 0, 1, 1],
|
20 |
+
[3, 2, 192, 1, 1, 1],
|
21 |
+
[3, 2, 192, 0, 1, 1],
|
22 |
+
[3, 2, 192, 1, 1, 1],
|
23 |
+
[3, 2, 192, 0, 1, 1],
|
24 |
+
[3, 2, 192, 1, 1, 1],
|
25 |
+
[3, 2, 192, 0, 1, 1],
|
26 |
+
[3, 2, 192, 1, 1, 1],
|
27 |
+
[3, 2, 192, 0, 1, 1],
|
28 |
+
[3, 2, 192, 1, 1, 1],
|
29 |
+
[3, 2, 192, 0, 1, 1],
|
30 |
+
[3, 2, 192, 0, 1, 1],
|
31 |
+
[3, 2, 384, 0, 1, 2],
|
32 |
+
[3, 2, 384, 1, 1, 1],
|
33 |
+
[3, 2, 384, 0, 1, 1]
|
34 |
+
]
|
35 |
+
|
36 |
+
m2_cfgs = [
|
37 |
+
# k, t, c, SE, HS, s
|
38 |
+
[3, 2, 64, 1, 0, 1],
|
39 |
+
[3, 2, 64, 0, 0, 1],
|
40 |
+
[3, 2, 64, 0, 0, 1],
|
41 |
+
[3, 2, 128, 0, 0, 2],
|
42 |
+
[3, 2, 128, 1, 0, 1],
|
43 |
+
[3, 2, 128, 0, 0, 1],
|
44 |
+
[3, 2, 128, 0, 0, 1],
|
45 |
+
[3, 2, 256, 0, 1, 2],
|
46 |
+
[3, 2, 256, 1, 1, 1],
|
47 |
+
[3, 2, 256, 0, 1, 1],
|
48 |
+
[3, 2, 256, 1, 1, 1],
|
49 |
+
[3, 2, 256, 0, 1, 1],
|
50 |
+
[3, 2, 256, 1, 1, 1],
|
51 |
+
[3, 2, 256, 0, 1, 1],
|
52 |
+
[3, 2, 256, 1, 1, 1],
|
53 |
+
[3, 2, 256, 0, 1, 1],
|
54 |
+
[3, 2, 256, 1, 1, 1],
|
55 |
+
[3, 2, 256, 0, 1, 1],
|
56 |
+
[3, 2, 256, 1, 1, 1],
|
57 |
+
[3, 2, 256, 0, 1, 1],
|
58 |
+
[3, 2, 256, 0, 1, 1],
|
59 |
+
[3, 2, 512, 0, 1, 2],
|
60 |
+
[3, 2, 512, 1, 1, 1],
|
61 |
+
[3, 2, 512, 0, 1, 1]
|
62 |
+
]
|
63 |
+
|
64 |
+
m3_cfgs = [
|
65 |
+
# k, t, c, SE, HS, s
|
66 |
+
[3, 2, 64, 1, 0, 1],
|
67 |
+
[3, 2, 64, 0, 0, 1],
|
68 |
+
[3, 2, 64, 1, 0, 1],
|
69 |
+
[3, 2, 64, 0, 0, 1],
|
70 |
+
[3, 2, 64, 0, 0, 1],
|
71 |
+
[3, 2, 128, 0, 0, 2],
|
72 |
+
[3, 2, 128, 1, 0, 1],
|
73 |
+
[3, 2, 128, 0, 0, 1],
|
74 |
+
[3, 2, 128, 1, 0, 1],
|
75 |
+
[3, 2, 128, 0, 0, 1],
|
76 |
+
[3, 2, 128, 0, 0, 1],
|
77 |
+
[3, 2, 256, 0, 1, 2],
|
78 |
+
[3, 2, 256, 1, 1, 1],
|
79 |
+
[3, 2, 256, 0, 1, 1],
|
80 |
+
[3, 2, 256, 1, 1, 1],
|
81 |
+
[3, 2, 256, 0, 1, 1],
|
82 |
+
[3, 2, 256, 1, 1, 1],
|
83 |
+
[3, 2, 256, 0, 1, 1],
|
84 |
+
[3, 2, 256, 1, 1, 1],
|
85 |
+
[3, 2, 256, 0, 1, 1],
|
86 |
+
[3, 2, 256, 1, 1, 1],
|
87 |
+
[3, 2, 256, 0, 1, 1],
|
88 |
+
[3, 2, 256, 1, 1, 1],
|
89 |
+
[3, 2, 256, 0, 1, 1],
|
90 |
+
[3, 2, 256, 1, 1, 1],
|
91 |
+
[3, 2, 256, 0, 1, 1],
|
92 |
+
[3, 2, 256, 1, 1, 1],
|
93 |
+
[3, 2, 256, 0, 1, 1],
|
94 |
+
[3, 2, 256, 1, 1, 1],
|
95 |
+
[3, 2, 256, 0, 1, 1],
|
96 |
+
[3, 2, 256, 0, 1, 1],
|
97 |
+
[3, 2, 512, 0, 1, 2],
|
98 |
+
[3, 2, 512, 1, 1, 1],
|
99 |
+
[3, 2, 512, 0, 1, 1]
|
100 |
+
]
|
101 |
+
|
102 |
+
|
103 |
+
def _make_divisible(v, divisor, min_value=None):
|
104 |
+
"""
|
105 |
+
This function is taken from the original tf repo.
|
106 |
+
It ensures that all layers have a channel number that is divisible by 8
|
107 |
+
It can be seen here:
|
108 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
109 |
+
:param v:
|
110 |
+
:param divisor:
|
111 |
+
:param min_value:
|
112 |
+
:return:
|
113 |
+
"""
|
114 |
+
if min_value is None:
|
115 |
+
min_value = divisor
|
116 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
117 |
+
# Make sure that round down does not go down by more than 10%.
|
118 |
+
if new_v < 0.9 * v:
|
119 |
+
new_v += divisor
|
120 |
+
return new_v
|
121 |
+
|
122 |
+
|
123 |
+
from timm.models.layers import SqueezeExcite
|
124 |
+
|
125 |
+
import torch
|
126 |
+
|
127 |
+
|
128 |
+
class Conv2d_BN(torch.nn.Sequential):
|
129 |
+
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
|
130 |
+
groups=1, bn_weight_init=1, resolution=-10000):
|
131 |
+
super().__init__()
|
132 |
+
self.add_module('c', torch.nn.Conv2d(
|
133 |
+
a, b, ks, stride, pad, dilation, groups, bias=False))
|
134 |
+
self.add_module('bn', torch.nn.BatchNorm2d(b))
|
135 |
+
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
|
136 |
+
torch.nn.init.constant_(self.bn.bias, 0)
|
137 |
+
|
138 |
+
@torch.no_grad()
|
139 |
+
def fuse(self):
|
140 |
+
c, bn = self._modules.values()
|
141 |
+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
142 |
+
w = c.weight * w[:, None, None, None]
|
143 |
+
b = bn.bias - bn.running_mean * bn.weight / \
|
144 |
+
(bn.running_var + bn.eps) ** 0.5
|
145 |
+
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
|
146 |
+
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation,
|
147 |
+
groups=self.c.groups,
|
148 |
+
device=c.weight.device)
|
149 |
+
m.weight.data.copy_(w)
|
150 |
+
m.bias.data.copy_(b)
|
151 |
+
return m
|
152 |
+
|
153 |
+
|
154 |
+
class Residual(torch.nn.Module):
|
155 |
+
def __init__(self, m, drop=0.):
|
156 |
+
super().__init__()
|
157 |
+
self.m = m
|
158 |
+
self.drop = drop
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
if self.training and self.drop > 0:
|
162 |
+
return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
|
163 |
+
device=x.device).ge_(self.drop).div(1 - self.drop).detach()
|
164 |
+
else:
|
165 |
+
return x + self.m(x)
|
166 |
+
|
167 |
+
@torch.no_grad()
|
168 |
+
def fuse(self):
|
169 |
+
if isinstance(self.m, Conv2d_BN):
|
170 |
+
m = self.m.fuse()
|
171 |
+
assert (m.groups == m.in_channels)
|
172 |
+
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
|
173 |
+
identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
|
174 |
+
m.weight += identity.to(m.weight.device)
|
175 |
+
return m
|
176 |
+
elif isinstance(self.m, torch.nn.Conv2d):
|
177 |
+
m = self.m
|
178 |
+
assert (m.groups != m.in_channels)
|
179 |
+
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
|
180 |
+
identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
|
181 |
+
m.weight += identity.to(m.weight.device)
|
182 |
+
return m
|
183 |
+
else:
|
184 |
+
return self
|
185 |
+
|
186 |
+
|
187 |
+
class RepVGGDW(torch.nn.Module):
|
188 |
+
def __init__(self, ed) -> None:
|
189 |
+
super().__init__()
|
190 |
+
self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
|
191 |
+
self.conv1 = Conv2d_BN(ed, ed, 1, 1, 0, groups=ed)
|
192 |
+
self.dim = ed
|
193 |
+
|
194 |
+
def forward(self, x):
|
195 |
+
return self.conv(x) + self.conv1(x) + x
|
196 |
+
|
197 |
+
@torch.no_grad()
|
198 |
+
def fuse(self):
|
199 |
+
conv = self.conv.fuse()
|
200 |
+
conv1 = self.conv1.fuse()
|
201 |
+
|
202 |
+
conv_w = conv.weight
|
203 |
+
conv_b = conv.bias
|
204 |
+
conv1_w = conv1.weight
|
205 |
+
conv1_b = conv1.bias
|
206 |
+
|
207 |
+
conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1])
|
208 |
+
|
209 |
+
identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device),
|
210 |
+
[1, 1, 1, 1])
|
211 |
+
|
212 |
+
final_conv_w = conv_w + conv1_w + identity
|
213 |
+
final_conv_b = conv_b + conv1_b
|
214 |
+
|
215 |
+
conv.weight.data.copy_(final_conv_w)
|
216 |
+
conv.bias.data.copy_(final_conv_b)
|
217 |
+
return conv
|
218 |
+
|
219 |
+
|
220 |
+
class RepViTBlock(nn.Module):
|
221 |
+
def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs, skip_downsample=False):
|
222 |
+
super(RepViTBlock, self).__init__()
|
223 |
+
assert stride in [1, 2]
|
224 |
+
|
225 |
+
self.identity = stride == 1 and inp == oup
|
226 |
+
assert (hidden_dim == 2 * inp)
|
227 |
+
|
228 |
+
if stride == 2:
|
229 |
+
if skip_downsample:
|
230 |
+
stride = 1
|
231 |
+
self.token_mixer = nn.Sequential(
|
232 |
+
Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
|
233 |
+
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
|
234 |
+
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
|
235 |
+
)
|
236 |
+
self.channel_mixer = Residual(nn.Sequential(
|
237 |
+
# pw
|
238 |
+
Conv2d_BN(oup, 2 * oup, 1, 1, 0),
|
239 |
+
nn.GELU() if use_hs else nn.GELU(),
|
240 |
+
# pw-linear
|
241 |
+
Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
|
242 |
+
))
|
243 |
+
else:
|
244 |
+
assert (self.identity)
|
245 |
+
self.token_mixer = nn.Sequential(
|
246 |
+
RepVGGDW(inp),
|
247 |
+
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
|
248 |
+
)
|
249 |
+
self.channel_mixer = Residual(nn.Sequential(
|
250 |
+
# pw
|
251 |
+
Conv2d_BN(inp, hidden_dim, 1, 1, 0),
|
252 |
+
nn.GELU() if use_hs else nn.GELU(),
|
253 |
+
# pw-linear
|
254 |
+
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
|
255 |
+
))
|
256 |
+
|
257 |
+
def forward(self, x):
|
258 |
+
return self.channel_mixer(self.token_mixer(x))
|
259 |
+
|
260 |
+
|
261 |
+
from timm.models.vision_transformer import trunc_normal_
|
262 |
+
|
263 |
+
|
264 |
+
class BN_Linear(torch.nn.Sequential):
|
265 |
+
def __init__(self, a, b, bias=True, std=0.02):
|
266 |
+
super().__init__()
|
267 |
+
self.add_module('bn', torch.nn.BatchNorm1d(a))
|
268 |
+
self.add_module('l', torch.nn.Linear(a, b, bias=bias))
|
269 |
+
trunc_normal_(self.l.weight, std=std)
|
270 |
+
if bias:
|
271 |
+
torch.nn.init.constant_(self.l.bias, 0)
|
272 |
+
|
273 |
+
@torch.no_grad()
|
274 |
+
def fuse(self):
|
275 |
+
bn, l = self._modules.values()
|
276 |
+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
277 |
+
b = bn.bias - self.bn.running_mean * \
|
278 |
+
self.bn.weight / (bn.running_var + bn.eps) ** 0.5
|
279 |
+
w = l.weight * w[None, :]
|
280 |
+
if l.bias is None:
|
281 |
+
b = b @ self.l.weight.T
|
282 |
+
else:
|
283 |
+
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
|
284 |
+
m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
|
285 |
+
m.weight.data.copy_(w)
|
286 |
+
m.bias.data.copy_(b)
|
287 |
+
return m
|
288 |
+
|
289 |
+
|
290 |
+
class RepViT(nn.Module):
|
291 |
+
arch_settings = {
|
292 |
+
'm1': m1_cfgs,
|
293 |
+
'm2': m2_cfgs,
|
294 |
+
'm3': m3_cfgs
|
295 |
+
}
|
296 |
+
|
297 |
+
def __init__(self, arch, img_size=1024, upsample_mode='bicubic'):
|
298 |
+
super(RepViT, self).__init__()
|
299 |
+
# setting of inverted residual blocks
|
300 |
+
self.cfgs = self.arch_settings[arch]
|
301 |
+
self.img_size = img_size
|
302 |
+
|
303 |
+
# building first layer
|
304 |
+
input_channel = self.cfgs[0][2]
|
305 |
+
patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
|
306 |
+
Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
|
307 |
+
layers = [patch_embed]
|
308 |
+
# building inverted residual blocks
|
309 |
+
block = RepViTBlock
|
310 |
+
self.stage_idx = []
|
311 |
+
prev_c = input_channel
|
312 |
+
for idx, (k, t, c, use_se, use_hs, s) in enumerate(self.cfgs):
|
313 |
+
output_channel = _make_divisible(c, 8)
|
314 |
+
exp_size = _make_divisible(input_channel * t, 8)
|
315 |
+
skip_downsample = False
|
316 |
+
if c != prev_c:
|
317 |
+
self.stage_idx.append(idx - 1)
|
318 |
+
prev_c = c
|
319 |
+
layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs, skip_downsample))
|
320 |
+
input_channel = output_channel
|
321 |
+
self.stage_idx.append(idx)
|
322 |
+
self.features = nn.ModuleList(layers)
|
323 |
+
|
324 |
+
stage2_channels = _make_divisible(self.cfgs[self.stage_idx[2]][2], 8)
|
325 |
+
stage3_channels = _make_divisible(self.cfgs[self.stage_idx[3]][2], 8)
|
326 |
+
self.fuse_stage2 = nn.Conv2d(stage2_channels, 256, kernel_size=1, bias=False)
|
327 |
+
self.fuse_stage3 = OpSequential([
|
328 |
+
nn.Conv2d(stage3_channels, 256, kernel_size=1, bias=False),
|
329 |
+
UpSampleLayer(factor=2, mode=upsample_mode),
|
330 |
+
])
|
331 |
+
|
332 |
+
self.neck = nn.Sequential(
|
333 |
+
nn.Conv2d(256, 256, kernel_size=1, bias=False),
|
334 |
+
LayerNorm2d(256),
|
335 |
+
nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
|
336 |
+
LayerNorm2d(256),
|
337 |
+
)
|
338 |
+
|
339 |
+
def forward(self, x):
|
340 |
+
counter = 0
|
341 |
+
output_dict = dict()
|
342 |
+
# patch_embed
|
343 |
+
x = self.features[0](x)
|
344 |
+
output_dict['stem'] = x
|
345 |
+
# stages
|
346 |
+
for idx, f in enumerate(self.features[1:]):
|
347 |
+
x = f(x)
|
348 |
+
if idx in self.stage_idx:
|
349 |
+
output_dict[f'stage{counter}'] = x
|
350 |
+
counter += 1
|
351 |
+
|
352 |
+
x = self.fuse_stage2(output_dict['stage2']) + self.fuse_stage3(output_dict['stage3'])
|
353 |
+
|
354 |
+
x = self.neck(x)
|
355 |
+
# hack this place because we modified the predictor of SAM for HQ-SAM in
|
356 |
+
# segment_anything/segment_anything/predictor.py line 91 to return intern features of the backbone
|
357 |
+
# self.features, self.interm_features = self.model.image_encoder(input_image)
|
358 |
+
return x, None
|
359 |
+
|
360 |
+
|
361 |
+
def rep_vit_m1(img_size=1024, **kwargs):
|
362 |
+
return RepViT('m1', img_size, **kwargs)
|
363 |
+
|
364 |
+
|
365 |
+
def rep_vit_m2(img_size=1024, **kwargs):
|
366 |
+
return RepViT('m2', img_size, **kwargs)
|
367 |
+
|
368 |
+
|
369 |
+
def rep_vit_m3(img_size=1024, **kwargs):
|
370 |
+
return RepViT('m3', img_size, **kwargs)
|
EfficientSAM/EdgeSAM/setup_edge_sam.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
from segment_anything.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
|
12 |
+
from EdgeSAM.rep_vit import RepViT
|
13 |
+
|
14 |
+
|
15 |
+
prompt_embed_dim = 256
|
16 |
+
image_size = 1024
|
17 |
+
vit_patch_size = 16
|
18 |
+
image_embedding_size = image_size // vit_patch_size
|
19 |
+
|
20 |
+
|
21 |
+
def build_edge_sam(checkpoint=None, upsample_mode="bicubic"):
|
22 |
+
image_encoder = RepViT(
|
23 |
+
arch="m1",
|
24 |
+
img_size=image_size,
|
25 |
+
upsample_mode=upsample_mode
|
26 |
+
)
|
27 |
+
return _build_sam(image_encoder, checkpoint)
|
28 |
+
|
29 |
+
|
30 |
+
sam_model_registry = {
|
31 |
+
"default": build_edge_sam,
|
32 |
+
"edge_sam": build_edge_sam,
|
33 |
+
}
|
34 |
+
|
35 |
+
def _build_sam_encoder(
|
36 |
+
encoder_embed_dim,
|
37 |
+
encoder_depth,
|
38 |
+
encoder_num_heads,
|
39 |
+
encoder_global_attn_indexes,
|
40 |
+
):
|
41 |
+
image_encoder = ImageEncoderViT(
|
42 |
+
depth=encoder_depth,
|
43 |
+
embed_dim=encoder_embed_dim,
|
44 |
+
img_size=image_size,
|
45 |
+
mlp_ratio=4,
|
46 |
+
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
47 |
+
num_heads=encoder_num_heads,
|
48 |
+
patch_size=vit_patch_size,
|
49 |
+
qkv_bias=True,
|
50 |
+
use_rel_pos=True,
|
51 |
+
global_attn_indexes=encoder_global_attn_indexes,
|
52 |
+
window_size=14,
|
53 |
+
out_chans=prompt_embed_dim,
|
54 |
+
)
|
55 |
+
return image_encoder
|
56 |
+
|
57 |
+
|
58 |
+
def _build_sam(
|
59 |
+
image_encoder,
|
60 |
+
checkpoint=None,
|
61 |
+
):
|
62 |
+
sam = Sam(
|
63 |
+
image_encoder=image_encoder,
|
64 |
+
prompt_encoder=PromptEncoder(
|
65 |
+
embed_dim=prompt_embed_dim,
|
66 |
+
image_embedding_size=(image_embedding_size, image_embedding_size),
|
67 |
+
input_image_size=(image_size, image_size),
|
68 |
+
mask_in_chans=16,
|
69 |
+
),
|
70 |
+
mask_decoder=MaskDecoder(
|
71 |
+
num_multimask_outputs=3,
|
72 |
+
transformer=TwoWayTransformer(
|
73 |
+
depth=2,
|
74 |
+
embedding_dim=prompt_embed_dim,
|
75 |
+
mlp_dim=2048,
|
76 |
+
num_heads=8,
|
77 |
+
),
|
78 |
+
transformer_dim=prompt_embed_dim,
|
79 |
+
iou_head_depth=3,
|
80 |
+
iou_head_hidden_dim=256,
|
81 |
+
),
|
82 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
83 |
+
pixel_std=[58.395, 57.12, 57.375],
|
84 |
+
)
|
85 |
+
sam.eval()
|
86 |
+
if checkpoint is not None:
|
87 |
+
with open(checkpoint, "rb") as f:
|
88 |
+
state_dict = torch.load(f, map_location="cpu")
|
89 |
+
sam.load_state_dict(state_dict)
|
90 |
+
return sam
|
EfficientSAM/FastSAM/tools.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
import clip
|
8 |
+
|
9 |
+
|
10 |
+
def convert_box_xywh_to_xyxy(box):
|
11 |
+
x1 = box[0]
|
12 |
+
y1 = box[1]
|
13 |
+
x2 = box[0] + box[2]
|
14 |
+
y2 = box[1] + box[3]
|
15 |
+
return [x1, y1, x2, y2]
|
16 |
+
|
17 |
+
|
18 |
+
def segment_image(image, bbox):
|
19 |
+
image_array = np.array(image)
|
20 |
+
segmented_image_array = np.zeros_like(image_array)
|
21 |
+
x1, y1, x2, y2 = bbox
|
22 |
+
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
|
23 |
+
segmented_image = Image.fromarray(segmented_image_array)
|
24 |
+
black_image = Image.new("RGB", image.size, (255, 255, 255))
|
25 |
+
# transparency_mask = np.zeros_like((), dtype=np.uint8)
|
26 |
+
transparency_mask = np.zeros(
|
27 |
+
(image_array.shape[0], image_array.shape[1]), dtype=np.uint8
|
28 |
+
)
|
29 |
+
transparency_mask[y1:y2, x1:x2] = 255
|
30 |
+
transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
|
31 |
+
black_image.paste(segmented_image, mask=transparency_mask_image)
|
32 |
+
return black_image
|
33 |
+
|
34 |
+
|
35 |
+
def format_results(result, filter=0):
|
36 |
+
annotations = []
|
37 |
+
n = len(result.masks.data)
|
38 |
+
for i in range(n):
|
39 |
+
annotation = {}
|
40 |
+
mask = result.masks.data[i] == 1.0
|
41 |
+
|
42 |
+
if torch.sum(mask) < filter:
|
43 |
+
continue
|
44 |
+
annotation["id"] = i
|
45 |
+
annotation["segmentation"] = mask.cpu().numpy()
|
46 |
+
annotation["bbox"] = result.boxes.data[i]
|
47 |
+
annotation["score"] = result.boxes.conf[i]
|
48 |
+
annotation["area"] = annotation["segmentation"].sum()
|
49 |
+
annotations.append(annotation)
|
50 |
+
return annotations
|
51 |
+
|
52 |
+
|
53 |
+
def filter_masks(annotations): # filte the overlap mask
|
54 |
+
annotations.sort(key=lambda x: x["area"], reverse=True)
|
55 |
+
to_remove = set()
|
56 |
+
for i in range(0, len(annotations)):
|
57 |
+
a = annotations[i]
|
58 |
+
for j in range(i + 1, len(annotations)):
|
59 |
+
b = annotations[j]
|
60 |
+
if i != j and j not in to_remove:
|
61 |
+
# check if
|
62 |
+
if b["area"] < a["area"]:
|
63 |
+
if (a["segmentation"] & b["segmentation"]).sum() / b[
|
64 |
+
"segmentation"
|
65 |
+
].sum() > 0.8:
|
66 |
+
to_remove.add(j)
|
67 |
+
|
68 |
+
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
|
69 |
+
|
70 |
+
|
71 |
+
def get_bbox_from_mask(mask):
|
72 |
+
mask = mask.astype(np.uint8)
|
73 |
+
contours, hierarchy = cv2.findContours(
|
74 |
+
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
75 |
+
)
|
76 |
+
x1, y1, w, h = cv2.boundingRect(contours[0])
|
77 |
+
x2, y2 = x1 + w, y1 + h
|
78 |
+
if len(contours) > 1:
|
79 |
+
for b in contours:
|
80 |
+
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
81 |
+
# 将多个bbox合并成一个
|
82 |
+
x1 = min(x1, x_t)
|
83 |
+
y1 = min(y1, y_t)
|
84 |
+
x2 = max(x2, x_t + w_t)
|
85 |
+
y2 = max(y2, y_t + h_t)
|
86 |
+
h = y2 - y1
|
87 |
+
w = x2 - x1
|
88 |
+
return [x1, y1, x2, y2]
|
89 |
+
|
90 |
+
|
91 |
+
def fast_process(
|
92 |
+
annotations, args, mask_random_color, bbox=None, points=None, edges=False
|
93 |
+
):
|
94 |
+
if isinstance(annotations[0], dict):
|
95 |
+
annotations = [annotation["segmentation"] for annotation in annotations]
|
96 |
+
result_name = os.path.basename(args.img_path)
|
97 |
+
image = cv2.imread(args.img_path)
|
98 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
99 |
+
original_h = image.shape[0]
|
100 |
+
original_w = image.shape[1]
|
101 |
+
plt.figure(figsize=(original_w/100, original_h/100))
|
102 |
+
plt.imshow(image)
|
103 |
+
if args.better_quality == True:
|
104 |
+
if isinstance(annotations[0], torch.Tensor):
|
105 |
+
annotations = np.array(annotations.cpu())
|
106 |
+
for i, mask in enumerate(annotations):
|
107 |
+
mask = cv2.morphologyEx(
|
108 |
+
mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
|
109 |
+
)
|
110 |
+
annotations[i] = cv2.morphologyEx(
|
111 |
+
mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
|
112 |
+
)
|
113 |
+
if args.device == "cpu":
|
114 |
+
annotations = np.array(annotations)
|
115 |
+
fast_show_mask(
|
116 |
+
annotations,
|
117 |
+
plt.gca(),
|
118 |
+
random_color=mask_random_color,
|
119 |
+
bbox=bbox,
|
120 |
+
points=points,
|
121 |
+
pointlabel=args.point_label,
|
122 |
+
retinamask=args.retina,
|
123 |
+
target_height=original_h,
|
124 |
+
target_width=original_w,
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
if isinstance(annotations[0], np.ndarray):
|
128 |
+
annotations = torch.from_numpy(annotations)
|
129 |
+
fast_show_mask_gpu(
|
130 |
+
annotations,
|
131 |
+
plt.gca(),
|
132 |
+
random_color=args.randomcolor,
|
133 |
+
bbox=bbox,
|
134 |
+
points=points,
|
135 |
+
pointlabel=args.point_label,
|
136 |
+
retinamask=args.retina,
|
137 |
+
target_height=original_h,
|
138 |
+
target_width=original_w,
|
139 |
+
)
|
140 |
+
if isinstance(annotations, torch.Tensor):
|
141 |
+
annotations = annotations.cpu().numpy()
|
142 |
+
if args.withContours == True:
|
143 |
+
contour_all = []
|
144 |
+
temp = np.zeros((original_h, original_w, 1))
|
145 |
+
for i, mask in enumerate(annotations):
|
146 |
+
if type(mask) == dict:
|
147 |
+
mask = mask["segmentation"]
|
148 |
+
annotation = mask.astype(np.uint8)
|
149 |
+
if args.retina == False:
|
150 |
+
annotation = cv2.resize(
|
151 |
+
annotation,
|
152 |
+
(original_w, original_h),
|
153 |
+
interpolation=cv2.INTER_NEAREST,
|
154 |
+
)
|
155 |
+
contours, hierarchy = cv2.findContours(
|
156 |
+
annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
157 |
+
)
|
158 |
+
for contour in contours:
|
159 |
+
contour_all.append(contour)
|
160 |
+
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
161 |
+
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
|
162 |
+
contour_mask = temp / 255 * color.reshape(1, 1, -1)
|
163 |
+
plt.imshow(contour_mask)
|
164 |
+
|
165 |
+
save_path = args.output
|
166 |
+
if not os.path.exists(save_path):
|
167 |
+
os.makedirs(save_path)
|
168 |
+
plt.axis("off")
|
169 |
+
fig = plt.gcf()
|
170 |
+
plt.draw()
|
171 |
+
buf = fig.canvas.tostring_rgb()
|
172 |
+
cols, rows = fig.canvas.get_width_height()
|
173 |
+
img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
|
174 |
+
return img_array
|
175 |
+
# cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
# CPU post process
|
180 |
+
def fast_show_mask(
|
181 |
+
annotation,
|
182 |
+
ax,
|
183 |
+
random_color=False,
|
184 |
+
bbox=None,
|
185 |
+
points=None,
|
186 |
+
pointlabel=None,
|
187 |
+
retinamask=True,
|
188 |
+
target_height=960,
|
189 |
+
target_width=960,
|
190 |
+
):
|
191 |
+
msak_sum = annotation.shape[0]
|
192 |
+
height = annotation.shape[1]
|
193 |
+
weight = annotation.shape[2]
|
194 |
+
# 将annotation 按照面积 排序
|
195 |
+
areas = np.sum(annotation, axis=(1, 2))
|
196 |
+
sorted_indices = np.argsort(areas)
|
197 |
+
annotation = annotation[sorted_indices]
|
198 |
+
|
199 |
+
index = (annotation != 0).argmax(axis=0)
|
200 |
+
if random_color == True:
|
201 |
+
color = np.random.random((msak_sum, 1, 1, 3))
|
202 |
+
else:
|
203 |
+
color = np.ones((msak_sum, 1, 1, 3)) * np.array(
|
204 |
+
[30 / 255, 144 / 255, 255 / 255]
|
205 |
+
)
|
206 |
+
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
|
207 |
+
visual = np.concatenate([color, transparency], axis=-1)
|
208 |
+
mask_image = np.expand_dims(annotation, -1) * visual
|
209 |
+
|
210 |
+
show = np.zeros((height, weight, 4))
|
211 |
+
h_indices, w_indices = np.meshgrid(
|
212 |
+
np.arange(height), np.arange(weight), indexing="ij"
|
213 |
+
)
|
214 |
+
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
215 |
+
# 使用向量化索引更新show的值
|
216 |
+
show[h_indices, w_indices, :] = mask_image[indices]
|
217 |
+
if bbox is not None:
|
218 |
+
x1, y1, x2, y2 = bbox
|
219 |
+
ax.add_patch(
|
220 |
+
plt.Rectangle(
|
221 |
+
(x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
|
222 |
+
)
|
223 |
+
)
|
224 |
+
# draw point
|
225 |
+
if points is not None:
|
226 |
+
plt.scatter(
|
227 |
+
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
228 |
+
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
229 |
+
s=20,
|
230 |
+
c="y",
|
231 |
+
)
|
232 |
+
plt.scatter(
|
233 |
+
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
234 |
+
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
235 |
+
s=20,
|
236 |
+
c="m",
|
237 |
+
)
|
238 |
+
|
239 |
+
if retinamask == False:
|
240 |
+
show = cv2.resize(
|
241 |
+
show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
|
242 |
+
)
|
243 |
+
ax.imshow(show)
|
244 |
+
|
245 |
+
|
246 |
+
def fast_show_mask_gpu(
|
247 |
+
annotation,
|
248 |
+
ax,
|
249 |
+
random_color=False,
|
250 |
+
bbox=None,
|
251 |
+
points=None,
|
252 |
+
pointlabel=None,
|
253 |
+
retinamask=True,
|
254 |
+
target_height=960,
|
255 |
+
target_width=960,
|
256 |
+
):
|
257 |
+
msak_sum = annotation.shape[0]
|
258 |
+
height = annotation.shape[1]
|
259 |
+
weight = annotation.shape[2]
|
260 |
+
areas = torch.sum(annotation, dim=(1, 2))
|
261 |
+
sorted_indices = torch.argsort(areas, descending=False)
|
262 |
+
annotation = annotation[sorted_indices]
|
263 |
+
# 找每个位置第一个非零值下标
|
264 |
+
index = (annotation != 0).to(torch.long).argmax(dim=0)
|
265 |
+
if random_color == True:
|
266 |
+
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
|
267 |
+
else:
|
268 |
+
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
|
269 |
+
[30 / 255, 144 / 255, 255 / 255]
|
270 |
+
).to(annotation.device)
|
271 |
+
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
|
272 |
+
visual = torch.cat([color, transparency], dim=-1)
|
273 |
+
mask_image = torch.unsqueeze(annotation, -1) * visual
|
274 |
+
# 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
|
275 |
+
show = torch.zeros((height, weight, 4)).to(annotation.device)
|
276 |
+
h_indices, w_indices = torch.meshgrid(
|
277 |
+
torch.arange(height), torch.arange(weight), indexing="ij"
|
278 |
+
)
|
279 |
+
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
280 |
+
# 使用向量化索引更新show的值
|
281 |
+
show[h_indices, w_indices, :] = mask_image[indices]
|
282 |
+
show_cpu = show.cpu().numpy()
|
283 |
+
if bbox is not None:
|
284 |
+
x1, y1, x2, y2 = bbox
|
285 |
+
ax.add_patch(
|
286 |
+
plt.Rectangle(
|
287 |
+
(x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
|
288 |
+
)
|
289 |
+
)
|
290 |
+
# draw point
|
291 |
+
if points is not None:
|
292 |
+
plt.scatter(
|
293 |
+
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
294 |
+
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
295 |
+
s=20,
|
296 |
+
c="y",
|
297 |
+
)
|
298 |
+
plt.scatter(
|
299 |
+
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
300 |
+
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
301 |
+
s=20,
|
302 |
+
c="m",
|
303 |
+
)
|
304 |
+
if retinamask == False:
|
305 |
+
show_cpu = cv2.resize(
|
306 |
+
show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
|
307 |
+
)
|
308 |
+
ax.imshow(show_cpu)
|
309 |
+
|
310 |
+
|
311 |
+
# clip
|
312 |
+
@torch.no_grad()
|
313 |
+
def retriev(
|
314 |
+
model, preprocess, elements, search_text: str, device
|
315 |
+
) -> int:
|
316 |
+
preprocessed_images = [preprocess(image).to(device) for image in elements]
|
317 |
+
tokenized_text = clip.tokenize([search_text]).to(device)
|
318 |
+
stacked_images = torch.stack(preprocessed_images)
|
319 |
+
image_features = model.encode_image(stacked_images)
|
320 |
+
text_features = model.encode_text(tokenized_text)
|
321 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
322 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
323 |
+
probs = 100.0 * image_features @ text_features.T
|
324 |
+
return probs[:, 0].softmax(dim=0)
|
325 |
+
|
326 |
+
|
327 |
+
def crop_image(annotations, image_path):
|
328 |
+
image = Image.open(image_path)
|
329 |
+
ori_w, ori_h = image.size
|
330 |
+
mask_h, mask_w = annotations[0]["segmentation"].shape
|
331 |
+
if ori_w != mask_w or ori_h != mask_h:
|
332 |
+
image = image.resize((mask_w, mask_h))
|
333 |
+
cropped_boxes = []
|
334 |
+
cropped_images = []
|
335 |
+
not_crop = []
|
336 |
+
filter_id = []
|
337 |
+
# annotations, _ = filter_masks(annotations)
|
338 |
+
# filter_id = list(_)
|
339 |
+
for _, mask in enumerate(annotations):
|
340 |
+
if np.sum(mask["segmentation"]) <= 100:
|
341 |
+
filter_id.append(_)
|
342 |
+
continue
|
343 |
+
bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox
|
344 |
+
cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片
|
345 |
+
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
|
346 |
+
cropped_images.append(bbox) # 保存裁剪的图片的bbox
|
347 |
+
|
348 |
+
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
|
349 |
+
|
350 |
+
|
351 |
+
def box_prompt(masks, bbox, target_height, target_width):
|
352 |
+
h = masks.shape[1]
|
353 |
+
w = masks.shape[2]
|
354 |
+
if h != target_height or w != target_width:
|
355 |
+
bbox = [
|
356 |
+
int(bbox[0] * w / target_width),
|
357 |
+
int(bbox[1] * h / target_height),
|
358 |
+
int(bbox[2] * w / target_width),
|
359 |
+
int(bbox[3] * h / target_height),
|
360 |
+
]
|
361 |
+
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
|
362 |
+
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
|
363 |
+
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
|
364 |
+
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
|
365 |
+
|
366 |
+
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
367 |
+
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
368 |
+
|
369 |
+
masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
|
370 |
+
orig_masks_area = torch.sum(masks, dim=(1, 2))
|
371 |
+
|
372 |
+
union = bbox_area + orig_masks_area - masks_area
|
373 |
+
IoUs = masks_area / union
|
374 |
+
max_iou_index = torch.argmax(IoUs)
|
375 |
+
|
376 |
+
return masks[max_iou_index].cpu().numpy(), max_iou_index
|
377 |
+
|
378 |
+
|
379 |
+
def point_prompt(masks, points, pointlabel, target_height, target_width): # numpy 处理
|
380 |
+
h = masks[0]["segmentation"].shape[0]
|
381 |
+
w = masks[0]["segmentation"].shape[1]
|
382 |
+
if h != target_height or w != target_width:
|
383 |
+
points = [
|
384 |
+
[int(point[0] * w / target_width), int(point[1] * h / target_height)]
|
385 |
+
for point in points
|
386 |
+
]
|
387 |
+
onemask = np.zeros((h, w))
|
388 |
+
for i, annotation in enumerate(masks):
|
389 |
+
if type(annotation) == dict:
|
390 |
+
mask = annotation["segmentation"]
|
391 |
+
else:
|
392 |
+
mask = annotation
|
393 |
+
for i, point in enumerate(points):
|
394 |
+
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
395 |
+
onemask += mask
|
396 |
+
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
|
397 |
+
onemask -= mask
|
398 |
+
onemask = onemask >= 1
|
399 |
+
return onemask, 0
|
400 |
+
|
401 |
+
|
402 |
+
def text_prompt(annotations, args):
|
403 |
+
cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image(
|
404 |
+
annotations, args.img_path
|
405 |
+
)
|
406 |
+
clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
|
407 |
+
scores = retriev(
|
408 |
+
clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
|
409 |
+
)
|
410 |
+
max_idx = scores.argsort()
|
411 |
+
max_idx = max_idx[-1]
|
412 |
+
max_idx += sum(np.array(filter_id) <= int(max_idx))
|
413 |
+
return annotaions[max_idx]["segmentation"], max_idx
|
EfficientSAM/LightHQSAM/example_light_hqsam.png
ADDED
Git LFS Details
|
EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg
ADDED
EfficientSAM/LightHQSAM/setup_light_hqsam.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from LightHQSAM.tiny_vit_sam import TinyViT
|
2 |
+
from segment_anything.modeling import MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
|
3 |
+
|
4 |
+
def setup_model():
|
5 |
+
prompt_embed_dim = 256
|
6 |
+
image_size = 1024
|
7 |
+
vit_patch_size = 16
|
8 |
+
image_embedding_size = image_size // vit_patch_size
|
9 |
+
mobile_sam = Sam(
|
10 |
+
image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
|
11 |
+
embed_dims=[64, 128, 160, 320],
|
12 |
+
depths=[2, 2, 6, 2],
|
13 |
+
num_heads=[2, 4, 5, 10],
|
14 |
+
window_sizes=[7, 7, 14, 7],
|
15 |
+
mlp_ratio=4.,
|
16 |
+
drop_rate=0.,
|
17 |
+
drop_path_rate=0.0,
|
18 |
+
use_checkpoint=False,
|
19 |
+
mbconv_expand_ratio=4.0,
|
20 |
+
local_conv_size=3,
|
21 |
+
layer_lr_decay=0.8
|
22 |
+
),
|
23 |
+
prompt_encoder=PromptEncoder(
|
24 |
+
embed_dim=prompt_embed_dim,
|
25 |
+
image_embedding_size=(image_embedding_size, image_embedding_size),
|
26 |
+
input_image_size=(image_size, image_size),
|
27 |
+
mask_in_chans=16,
|
28 |
+
),
|
29 |
+
mask_decoder=MaskDecoderHQ(
|
30 |
+
num_multimask_outputs=3,
|
31 |
+
transformer=TwoWayTransformer(
|
32 |
+
depth=2,
|
33 |
+
embedding_dim=prompt_embed_dim,
|
34 |
+
mlp_dim=2048,
|
35 |
+
num_heads=8,
|
36 |
+
),
|
37 |
+
transformer_dim=prompt_embed_dim,
|
38 |
+
iou_head_depth=3,
|
39 |
+
iou_head_hidden_dim=256,
|
40 |
+
vit_dim=160,
|
41 |
+
),
|
42 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
43 |
+
pixel_std=[58.395, 57.12, 57.375],
|
44 |
+
)
|
45 |
+
return mobile_sam
|
EfficientSAM/LightHQSAM/tiny_vit_sam.py
ADDED
@@ -0,0 +1,724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# TinyViT Model Architecture
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Adapted from LeViT and Swin Transformer
|
5 |
+
# LeViT: (https://github.com/facebookresearch/levit)
|
6 |
+
# Swin: (https://github.com/microsoft/swin-transformer)
|
7 |
+
# Build the TinyViT Model
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import itertools
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torch.utils.checkpoint as checkpoint
|
15 |
+
from timm.models.layers import DropPath as TimmDropPath,\
|
16 |
+
to_2tuple, trunc_normal_
|
17 |
+
from timm.models.registry import register_model
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
|
21 |
+
class Conv2d_BN(torch.nn.Sequential):
|
22 |
+
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
|
23 |
+
groups=1, bn_weight_init=1):
|
24 |
+
super().__init__()
|
25 |
+
self.add_module('c', torch.nn.Conv2d(
|
26 |
+
a, b, ks, stride, pad, dilation, groups, bias=False))
|
27 |
+
bn = torch.nn.BatchNorm2d(b)
|
28 |
+
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
29 |
+
torch.nn.init.constant_(bn.bias, 0)
|
30 |
+
self.add_module('bn', bn)
|
31 |
+
|
32 |
+
@torch.no_grad()
|
33 |
+
def fuse(self):
|
34 |
+
c, bn = self._modules.values()
|
35 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
36 |
+
w = c.weight * w[:, None, None, None]
|
37 |
+
b = bn.bias - bn.running_mean * bn.weight / \
|
38 |
+
(bn.running_var + bn.eps)**0.5
|
39 |
+
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
|
40 |
+
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
|
41 |
+
m.weight.data.copy_(w)
|
42 |
+
m.bias.data.copy_(b)
|
43 |
+
return m
|
44 |
+
|
45 |
+
|
46 |
+
class DropPath(TimmDropPath):
|
47 |
+
def __init__(self, drop_prob=None):
|
48 |
+
super().__init__(drop_prob=drop_prob)
|
49 |
+
self.drop_prob = drop_prob
|
50 |
+
|
51 |
+
def __repr__(self):
|
52 |
+
msg = super().__repr__()
|
53 |
+
msg += f'(drop_prob={self.drop_prob})'
|
54 |
+
return msg
|
55 |
+
|
56 |
+
|
57 |
+
class PatchEmbed(nn.Module):
|
58 |
+
def __init__(self, in_chans, embed_dim, resolution, activation):
|
59 |
+
super().__init__()
|
60 |
+
img_size: Tuple[int, int] = to_2tuple(resolution)
|
61 |
+
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
62 |
+
self.num_patches = self.patches_resolution[0] * \
|
63 |
+
self.patches_resolution[1]
|
64 |
+
self.in_chans = in_chans
|
65 |
+
self.embed_dim = embed_dim
|
66 |
+
n = embed_dim
|
67 |
+
self.seq = nn.Sequential(
|
68 |
+
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
|
69 |
+
activation(),
|
70 |
+
Conv2d_BN(n // 2, n, 3, 2, 1),
|
71 |
+
)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
return self.seq(x)
|
75 |
+
|
76 |
+
|
77 |
+
class MBConv(nn.Module):
|
78 |
+
def __init__(self, in_chans, out_chans, expand_ratio,
|
79 |
+
activation, drop_path):
|
80 |
+
super().__init__()
|
81 |
+
self.in_chans = in_chans
|
82 |
+
self.hidden_chans = int(in_chans * expand_ratio)
|
83 |
+
self.out_chans = out_chans
|
84 |
+
|
85 |
+
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
|
86 |
+
self.act1 = activation()
|
87 |
+
|
88 |
+
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans,
|
89 |
+
ks=3, stride=1, pad=1, groups=self.hidden_chans)
|
90 |
+
self.act2 = activation()
|
91 |
+
|
92 |
+
self.conv3 = Conv2d_BN(
|
93 |
+
self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
|
94 |
+
self.act3 = activation()
|
95 |
+
|
96 |
+
self.drop_path = DropPath(
|
97 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
shortcut = x
|
101 |
+
|
102 |
+
x = self.conv1(x)
|
103 |
+
x = self.act1(x)
|
104 |
+
|
105 |
+
x = self.conv2(x)
|
106 |
+
x = self.act2(x)
|
107 |
+
|
108 |
+
x = self.conv3(x)
|
109 |
+
|
110 |
+
x = self.drop_path(x)
|
111 |
+
|
112 |
+
x += shortcut
|
113 |
+
x = self.act3(x)
|
114 |
+
|
115 |
+
return x
|
116 |
+
|
117 |
+
|
118 |
+
class PatchMerging(nn.Module):
|
119 |
+
def __init__(self, input_resolution, dim, out_dim, activation):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.input_resolution = input_resolution
|
123 |
+
self.dim = dim
|
124 |
+
self.out_dim = out_dim
|
125 |
+
self.act = activation()
|
126 |
+
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
127 |
+
stride_c=2
|
128 |
+
if(out_dim==320 or out_dim==448 or out_dim==576):
|
129 |
+
stride_c=1
|
130 |
+
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
131 |
+
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
if x.ndim == 3:
|
135 |
+
H, W = self.input_resolution
|
136 |
+
B = len(x)
|
137 |
+
# (B, C, H, W)
|
138 |
+
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
139 |
+
|
140 |
+
x = self.conv1(x)
|
141 |
+
x = self.act(x)
|
142 |
+
|
143 |
+
x = self.conv2(x)
|
144 |
+
x = self.act(x)
|
145 |
+
x = self.conv3(x)
|
146 |
+
x = x.flatten(2).transpose(1, 2)
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
class ConvLayer(nn.Module):
|
151 |
+
def __init__(self, dim, input_resolution, depth,
|
152 |
+
activation,
|
153 |
+
drop_path=0., downsample=None, use_checkpoint=False,
|
154 |
+
out_dim=None,
|
155 |
+
conv_expand_ratio=4.,
|
156 |
+
):
|
157 |
+
|
158 |
+
super().__init__()
|
159 |
+
self.dim = dim
|
160 |
+
self.input_resolution = input_resolution
|
161 |
+
self.depth = depth
|
162 |
+
self.use_checkpoint = use_checkpoint
|
163 |
+
|
164 |
+
# build blocks
|
165 |
+
self.blocks = nn.ModuleList([
|
166 |
+
MBConv(dim, dim, conv_expand_ratio, activation,
|
167 |
+
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
168 |
+
)
|
169 |
+
for i in range(depth)])
|
170 |
+
|
171 |
+
# patch merging layer
|
172 |
+
if downsample is not None:
|
173 |
+
self.downsample = downsample(
|
174 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
175 |
+
else:
|
176 |
+
self.downsample = None
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
for blk in self.blocks:
|
180 |
+
if self.use_checkpoint:
|
181 |
+
x = checkpoint.checkpoint(blk, x)
|
182 |
+
else:
|
183 |
+
x = blk(x)
|
184 |
+
if self.downsample is not None:
|
185 |
+
x = self.downsample(x)
|
186 |
+
return x
|
187 |
+
|
188 |
+
|
189 |
+
class Mlp(nn.Module):
|
190 |
+
def __init__(self, in_features, hidden_features=None,
|
191 |
+
out_features=None, act_layer=nn.GELU, drop=0.):
|
192 |
+
super().__init__()
|
193 |
+
out_features = out_features or in_features
|
194 |
+
hidden_features = hidden_features or in_features
|
195 |
+
self.norm = nn.LayerNorm(in_features)
|
196 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
197 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
198 |
+
self.act = act_layer()
|
199 |
+
self.drop = nn.Dropout(drop)
|
200 |
+
|
201 |
+
def forward(self, x):
|
202 |
+
x = self.norm(x)
|
203 |
+
|
204 |
+
x = self.fc1(x)
|
205 |
+
x = self.act(x)
|
206 |
+
x = self.drop(x)
|
207 |
+
x = self.fc2(x)
|
208 |
+
x = self.drop(x)
|
209 |
+
return x
|
210 |
+
|
211 |
+
|
212 |
+
class Attention(torch.nn.Module):
|
213 |
+
def __init__(self, dim, key_dim, num_heads=8,
|
214 |
+
attn_ratio=4,
|
215 |
+
resolution=(14, 14),
|
216 |
+
):
|
217 |
+
super().__init__()
|
218 |
+
# (h, w)
|
219 |
+
assert isinstance(resolution, tuple) and len(resolution) == 2
|
220 |
+
self.num_heads = num_heads
|
221 |
+
self.scale = key_dim ** -0.5
|
222 |
+
self.key_dim = key_dim
|
223 |
+
self.nh_kd = nh_kd = key_dim * num_heads
|
224 |
+
self.d = int(attn_ratio * key_dim)
|
225 |
+
self.dh = int(attn_ratio * key_dim) * num_heads
|
226 |
+
self.attn_ratio = attn_ratio
|
227 |
+
h = self.dh + nh_kd * 2
|
228 |
+
|
229 |
+
self.norm = nn.LayerNorm(dim)
|
230 |
+
self.qkv = nn.Linear(dim, h)
|
231 |
+
self.proj = nn.Linear(self.dh, dim)
|
232 |
+
|
233 |
+
points = list(itertools.product(
|
234 |
+
range(resolution[0]), range(resolution[1])))
|
235 |
+
N = len(points)
|
236 |
+
attention_offsets = {}
|
237 |
+
idxs = []
|
238 |
+
for p1 in points:
|
239 |
+
for p2 in points:
|
240 |
+
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
241 |
+
if offset not in attention_offsets:
|
242 |
+
attention_offsets[offset] = len(attention_offsets)
|
243 |
+
idxs.append(attention_offsets[offset])
|
244 |
+
self.attention_biases = torch.nn.Parameter(
|
245 |
+
torch.zeros(num_heads, len(attention_offsets)))
|
246 |
+
self.register_buffer('attention_bias_idxs',
|
247 |
+
torch.LongTensor(idxs).view(N, N),
|
248 |
+
persistent=False)
|
249 |
+
|
250 |
+
@torch.no_grad()
|
251 |
+
def train(self, mode=True):
|
252 |
+
super().train(mode)
|
253 |
+
if mode and hasattr(self, 'ab'):
|
254 |
+
del self.ab
|
255 |
+
else:
|
256 |
+
self.register_buffer('ab',
|
257 |
+
self.attention_biases[:, self.attention_bias_idxs],
|
258 |
+
persistent=False)
|
259 |
+
|
260 |
+
def forward(self, x): # x (B,N,C)
|
261 |
+
B, N, _ = x.shape
|
262 |
+
|
263 |
+
# Normalization
|
264 |
+
x = self.norm(x)
|
265 |
+
|
266 |
+
qkv = self.qkv(x)
|
267 |
+
# (B, N, num_heads, d)
|
268 |
+
q, k, v = qkv.view(B, N, self.num_heads, -
|
269 |
+
1).split([self.key_dim, self.key_dim, self.d], dim=3)
|
270 |
+
# (B, num_heads, N, d)
|
271 |
+
q = q.permute(0, 2, 1, 3)
|
272 |
+
k = k.permute(0, 2, 1, 3)
|
273 |
+
v = v.permute(0, 2, 1, 3)
|
274 |
+
|
275 |
+
attn = (
|
276 |
+
(q @ k.transpose(-2, -1)) * self.scale
|
277 |
+
+
|
278 |
+
(self.attention_biases[:, self.attention_bias_idxs]
|
279 |
+
if self.training else self.ab)
|
280 |
+
)
|
281 |
+
attn = attn.softmax(dim=-1)
|
282 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
283 |
+
x = self.proj(x)
|
284 |
+
return x
|
285 |
+
|
286 |
+
|
287 |
+
class TinyViTBlock(nn.Module):
|
288 |
+
r""" TinyViT Block.
|
289 |
+
|
290 |
+
Args:
|
291 |
+
dim (int): Number of input channels.
|
292 |
+
input_resolution (tuple[int, int]): Input resolution.
|
293 |
+
num_heads (int): Number of attention heads.
|
294 |
+
window_size (int): Window size.
|
295 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
296 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
297 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
298 |
+
local_conv_size (int): the kernel size of the convolution between
|
299 |
+
Attention and MLP. Default: 3
|
300 |
+
activation: the activation function. Default: nn.GELU
|
301 |
+
"""
|
302 |
+
|
303 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7,
|
304 |
+
mlp_ratio=4., drop=0., drop_path=0.,
|
305 |
+
local_conv_size=3,
|
306 |
+
activation=nn.GELU,
|
307 |
+
):
|
308 |
+
super().__init__()
|
309 |
+
self.dim = dim
|
310 |
+
self.input_resolution = input_resolution
|
311 |
+
self.num_heads = num_heads
|
312 |
+
assert window_size > 0, 'window_size must be greater than 0'
|
313 |
+
self.window_size = window_size
|
314 |
+
self.mlp_ratio = mlp_ratio
|
315 |
+
|
316 |
+
self.drop_path = DropPath(
|
317 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
318 |
+
|
319 |
+
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
|
320 |
+
head_dim = dim // num_heads
|
321 |
+
|
322 |
+
window_resolution = (window_size, window_size)
|
323 |
+
self.attn = Attention(dim, head_dim, num_heads,
|
324 |
+
attn_ratio=1, resolution=window_resolution)
|
325 |
+
|
326 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
327 |
+
mlp_activation = activation
|
328 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
329 |
+
act_layer=mlp_activation, drop=drop)
|
330 |
+
|
331 |
+
pad = local_conv_size // 2
|
332 |
+
self.local_conv = Conv2d_BN(
|
333 |
+
dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
|
334 |
+
|
335 |
+
def forward(self, x):
|
336 |
+
H, W = self.input_resolution
|
337 |
+
B, L, C = x.shape
|
338 |
+
assert L == H * W, "input feature has wrong size"
|
339 |
+
res_x = x
|
340 |
+
if H == self.window_size and W == self.window_size:
|
341 |
+
x = self.attn(x)
|
342 |
+
else:
|
343 |
+
x = x.view(B, H, W, C)
|
344 |
+
pad_b = (self.window_size - H %
|
345 |
+
self.window_size) % self.window_size
|
346 |
+
pad_r = (self.window_size - W %
|
347 |
+
self.window_size) % self.window_size
|
348 |
+
padding = pad_b > 0 or pad_r > 0
|
349 |
+
|
350 |
+
if padding:
|
351 |
+
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
|
352 |
+
|
353 |
+
pH, pW = H + pad_b, W + pad_r
|
354 |
+
nH = pH // self.window_size
|
355 |
+
nW = pW // self.window_size
|
356 |
+
# window partition
|
357 |
+
x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
|
358 |
+
B * nH * nW, self.window_size * self.window_size, C)
|
359 |
+
x = self.attn(x)
|
360 |
+
# window reverse
|
361 |
+
x = x.view(B, nH, nW, self.window_size, self.window_size,
|
362 |
+
C).transpose(2, 3).reshape(B, pH, pW, C)
|
363 |
+
|
364 |
+
if padding:
|
365 |
+
x = x[:, :H, :W].contiguous()
|
366 |
+
|
367 |
+
x = x.view(B, L, C)
|
368 |
+
|
369 |
+
x = res_x + self.drop_path(x)
|
370 |
+
|
371 |
+
x = x.transpose(1, 2).reshape(B, C, H, W)
|
372 |
+
x = self.local_conv(x)
|
373 |
+
x = x.view(B, C, L).transpose(1, 2)
|
374 |
+
|
375 |
+
x = x + self.drop_path(self.mlp(x))
|
376 |
+
return x
|
377 |
+
|
378 |
+
def extra_repr(self) -> str:
|
379 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
380 |
+
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
|
381 |
+
|
382 |
+
|
383 |
+
class BasicLayer(nn.Module):
|
384 |
+
""" A basic TinyViT layer for one stage.
|
385 |
+
|
386 |
+
Args:
|
387 |
+
dim (int): Number of input channels.
|
388 |
+
input_resolution (tuple[int]): Input resolution.
|
389 |
+
depth (int): Number of blocks.
|
390 |
+
num_heads (int): Number of attention heads.
|
391 |
+
window_size (int): Local window size.
|
392 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
393 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
394 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
395 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
396 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
397 |
+
local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
|
398 |
+
activation: the activation function. Default: nn.GELU
|
399 |
+
out_dim: the output dimension of the layer. Default: dim
|
400 |
+
"""
|
401 |
+
|
402 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
403 |
+
mlp_ratio=4., drop=0.,
|
404 |
+
drop_path=0., downsample=None, use_checkpoint=False,
|
405 |
+
local_conv_size=3,
|
406 |
+
activation=nn.GELU,
|
407 |
+
out_dim=None,
|
408 |
+
):
|
409 |
+
|
410 |
+
super().__init__()
|
411 |
+
self.dim = dim
|
412 |
+
self.input_resolution = input_resolution
|
413 |
+
self.depth = depth
|
414 |
+
self.use_checkpoint = use_checkpoint
|
415 |
+
|
416 |
+
# build blocks
|
417 |
+
self.blocks = nn.ModuleList([
|
418 |
+
TinyViTBlock(dim=dim, input_resolution=input_resolution,
|
419 |
+
num_heads=num_heads, window_size=window_size,
|
420 |
+
mlp_ratio=mlp_ratio,
|
421 |
+
drop=drop,
|
422 |
+
drop_path=drop_path[i] if isinstance(
|
423 |
+
drop_path, list) else drop_path,
|
424 |
+
local_conv_size=local_conv_size,
|
425 |
+
activation=activation,
|
426 |
+
)
|
427 |
+
for i in range(depth)])
|
428 |
+
|
429 |
+
# patch merging layer
|
430 |
+
if downsample is not None:
|
431 |
+
self.downsample = downsample(
|
432 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
433 |
+
else:
|
434 |
+
self.downsample = None
|
435 |
+
|
436 |
+
def forward(self, x):
|
437 |
+
for blk in self.blocks:
|
438 |
+
if self.use_checkpoint:
|
439 |
+
x = checkpoint.checkpoint(blk, x)
|
440 |
+
else:
|
441 |
+
x = blk(x)
|
442 |
+
if self.downsample is not None:
|
443 |
+
x = self.downsample(x)
|
444 |
+
return x
|
445 |
+
|
446 |
+
def extra_repr(self) -> str:
|
447 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
448 |
+
|
449 |
+
class LayerNorm2d(nn.Module):
|
450 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
451 |
+
super().__init__()
|
452 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
453 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
454 |
+
self.eps = eps
|
455 |
+
|
456 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
457 |
+
u = x.mean(1, keepdim=True)
|
458 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
459 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
460 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
461 |
+
return x
|
462 |
+
class TinyViT(nn.Module):
|
463 |
+
def __init__(self, img_size=224, in_chans=3, num_classes=1000,
|
464 |
+
embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
|
465 |
+
num_heads=[3, 6, 12, 24],
|
466 |
+
window_sizes=[7, 7, 14, 7],
|
467 |
+
mlp_ratio=4.,
|
468 |
+
drop_rate=0.,
|
469 |
+
drop_path_rate=0.1,
|
470 |
+
use_checkpoint=False,
|
471 |
+
mbconv_expand_ratio=4.0,
|
472 |
+
local_conv_size=3,
|
473 |
+
layer_lr_decay=1.0,
|
474 |
+
):
|
475 |
+
super().__init__()
|
476 |
+
self.img_size=img_size
|
477 |
+
self.num_classes = num_classes
|
478 |
+
self.depths = depths
|
479 |
+
self.num_layers = len(depths)
|
480 |
+
self.mlp_ratio = mlp_ratio
|
481 |
+
|
482 |
+
activation = nn.GELU
|
483 |
+
|
484 |
+
self.patch_embed = PatchEmbed(in_chans=in_chans,
|
485 |
+
embed_dim=embed_dims[0],
|
486 |
+
resolution=img_size,
|
487 |
+
activation=activation)
|
488 |
+
|
489 |
+
patches_resolution = self.patch_embed.patches_resolution
|
490 |
+
self.patches_resolution = patches_resolution
|
491 |
+
|
492 |
+
# stochastic depth
|
493 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
|
494 |
+
sum(depths))] # stochastic depth decay rule
|
495 |
+
|
496 |
+
# build layers
|
497 |
+
self.layers = nn.ModuleList()
|
498 |
+
for i_layer in range(self.num_layers):
|
499 |
+
kwargs = dict(dim=embed_dims[i_layer],
|
500 |
+
input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)),
|
501 |
+
patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))),
|
502 |
+
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
503 |
+
# patches_resolution[1] // (2 ** i_layer)),
|
504 |
+
depth=depths[i_layer],
|
505 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
506 |
+
downsample=PatchMerging if (
|
507 |
+
i_layer < self.num_layers - 1) else None,
|
508 |
+
use_checkpoint=use_checkpoint,
|
509 |
+
out_dim=embed_dims[min(
|
510 |
+
i_layer + 1, len(embed_dims) - 1)],
|
511 |
+
activation=activation,
|
512 |
+
)
|
513 |
+
if i_layer == 0:
|
514 |
+
layer = ConvLayer(
|
515 |
+
conv_expand_ratio=mbconv_expand_ratio,
|
516 |
+
**kwargs,
|
517 |
+
)
|
518 |
+
else:
|
519 |
+
layer = BasicLayer(
|
520 |
+
num_heads=num_heads[i_layer],
|
521 |
+
window_size=window_sizes[i_layer],
|
522 |
+
mlp_ratio=self.mlp_ratio,
|
523 |
+
drop=drop_rate,
|
524 |
+
local_conv_size=local_conv_size,
|
525 |
+
**kwargs)
|
526 |
+
self.layers.append(layer)
|
527 |
+
|
528 |
+
# Classifier head
|
529 |
+
self.norm_head = nn.LayerNorm(embed_dims[-1])
|
530 |
+
self.head = nn.Linear(
|
531 |
+
embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
|
532 |
+
|
533 |
+
# init weights
|
534 |
+
self.apply(self._init_weights)
|
535 |
+
self.set_layer_lr_decay(layer_lr_decay)
|
536 |
+
self.neck = nn.Sequential(
|
537 |
+
nn.Conv2d(
|
538 |
+
embed_dims[-1],
|
539 |
+
256,
|
540 |
+
kernel_size=1,
|
541 |
+
bias=False,
|
542 |
+
),
|
543 |
+
LayerNorm2d(256),
|
544 |
+
nn.Conv2d(
|
545 |
+
256,
|
546 |
+
256,
|
547 |
+
kernel_size=3,
|
548 |
+
padding=1,
|
549 |
+
bias=False,
|
550 |
+
),
|
551 |
+
LayerNorm2d(256),
|
552 |
+
)
|
553 |
+
def set_layer_lr_decay(self, layer_lr_decay):
|
554 |
+
decay_rate = layer_lr_decay
|
555 |
+
|
556 |
+
# layers -> blocks (depth)
|
557 |
+
depth = sum(self.depths)
|
558 |
+
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
559 |
+
#print("LR SCALES:", lr_scales)
|
560 |
+
|
561 |
+
def _set_lr_scale(m, scale):
|
562 |
+
for p in m.parameters():
|
563 |
+
p.lr_scale = scale
|
564 |
+
|
565 |
+
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
|
566 |
+
i = 0
|
567 |
+
for layer in self.layers:
|
568 |
+
for block in layer.blocks:
|
569 |
+
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
|
570 |
+
i += 1
|
571 |
+
if layer.downsample is not None:
|
572 |
+
layer.downsample.apply(
|
573 |
+
lambda x: _set_lr_scale(x, lr_scales[i - 1]))
|
574 |
+
assert i == depth
|
575 |
+
for m in [self.norm_head, self.head]:
|
576 |
+
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
|
577 |
+
|
578 |
+
for k, p in self.named_parameters():
|
579 |
+
p.param_name = k
|
580 |
+
|
581 |
+
def _check_lr_scale(m):
|
582 |
+
for p in m.parameters():
|
583 |
+
assert hasattr(p, 'lr_scale'), p.param_name
|
584 |
+
|
585 |
+
self.apply(_check_lr_scale)
|
586 |
+
|
587 |
+
def _init_weights(self, m):
|
588 |
+
if isinstance(m, nn.Linear):
|
589 |
+
trunc_normal_(m.weight, std=.02)
|
590 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
591 |
+
nn.init.constant_(m.bias, 0)
|
592 |
+
elif isinstance(m, nn.LayerNorm):
|
593 |
+
nn.init.constant_(m.bias, 0)
|
594 |
+
nn.init.constant_(m.weight, 1.0)
|
595 |
+
|
596 |
+
@torch.jit.ignore
|
597 |
+
def no_weight_decay_keywords(self):
|
598 |
+
return {'attention_biases'}
|
599 |
+
|
600 |
+
def forward_features(self, x):
|
601 |
+
# x: (N, C, H, W)
|
602 |
+
x = self.patch_embed(x)
|
603 |
+
|
604 |
+
x = self.layers[0](x)
|
605 |
+
start_i = 1
|
606 |
+
|
607 |
+
interm_embeddings=[]
|
608 |
+
for i in range(start_i, len(self.layers)):
|
609 |
+
layer = self.layers[i]
|
610 |
+
x = layer(x)
|
611 |
+
# print('x shape:', x.shape, '---i:', i)
|
612 |
+
if i == 1:
|
613 |
+
interm_embeddings.append(x.view(x.shape[0], 64, 64, -1))
|
614 |
+
|
615 |
+
B,_,C=x.size()
|
616 |
+
x = x.view(B, 64, 64, C)
|
617 |
+
x=x.permute(0, 3, 1, 2)
|
618 |
+
x=self.neck(x)
|
619 |
+
return x, interm_embeddings
|
620 |
+
|
621 |
+
def forward(self, x):
|
622 |
+
x, interm_embeddings = self.forward_features(x)
|
623 |
+
#x = self.norm_head(x)
|
624 |
+
#x = self.head(x)
|
625 |
+
# print('come to here is correct'* 3)
|
626 |
+
return x, interm_embeddings
|
627 |
+
|
628 |
+
|
629 |
+
_checkpoint_url_format = \
|
630 |
+
'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth'
|
631 |
+
_provided_checkpoints = {
|
632 |
+
'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill',
|
633 |
+
'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill',
|
634 |
+
'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill',
|
635 |
+
'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill',
|
636 |
+
'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill',
|
637 |
+
}
|
638 |
+
|
639 |
+
|
640 |
+
def register_tiny_vit_model(fn):
|
641 |
+
'''Register a TinyViT model
|
642 |
+
It is a wrapper of `register_model` with loading the pretrained checkpoint.
|
643 |
+
'''
|
644 |
+
def fn_wrapper(pretrained=False, **kwargs):
|
645 |
+
model = fn()
|
646 |
+
if pretrained:
|
647 |
+
model_name = fn.__name__
|
648 |
+
assert model_name in _provided_checkpoints, \
|
649 |
+
f'Sorry that the checkpoint `{model_name}` is not provided yet.'
|
650 |
+
url = _checkpoint_url_format.format(
|
651 |
+
_provided_checkpoints[model_name])
|
652 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
653 |
+
url=url,
|
654 |
+
map_location='cpu', check_hash=False,
|
655 |
+
)
|
656 |
+
model.load_state_dict(checkpoint['model'])
|
657 |
+
|
658 |
+
return model
|
659 |
+
|
660 |
+
# rename the name of fn_wrapper
|
661 |
+
fn_wrapper.__name__ = fn.__name__
|
662 |
+
return register_model(fn_wrapper)
|
663 |
+
|
664 |
+
|
665 |
+
@register_tiny_vit_model
|
666 |
+
def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0):
|
667 |
+
return TinyViT(
|
668 |
+
num_classes=num_classes,
|
669 |
+
embed_dims=[64, 128, 160, 320],
|
670 |
+
depths=[2, 2, 6, 2],
|
671 |
+
num_heads=[2, 4, 5, 10],
|
672 |
+
window_sizes=[7, 7, 14, 7],
|
673 |
+
drop_path_rate=drop_path_rate,
|
674 |
+
)
|
675 |
+
|
676 |
+
|
677 |
+
@register_tiny_vit_model
|
678 |
+
def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1):
|
679 |
+
return TinyViT(
|
680 |
+
num_classes=num_classes,
|
681 |
+
embed_dims=[64, 128, 256, 448],
|
682 |
+
depths=[2, 2, 6, 2],
|
683 |
+
num_heads=[2, 4, 8, 14],
|
684 |
+
window_sizes=[7, 7, 14, 7],
|
685 |
+
drop_path_rate=drop_path_rate,
|
686 |
+
)
|
687 |
+
|
688 |
+
|
689 |
+
@register_tiny_vit_model
|
690 |
+
def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2):
|
691 |
+
return TinyViT(
|
692 |
+
num_classes=num_classes,
|
693 |
+
embed_dims=[96, 192, 384, 576],
|
694 |
+
depths=[2, 2, 6, 2],
|
695 |
+
num_heads=[3, 6, 12, 18],
|
696 |
+
window_sizes=[7, 7, 14, 7],
|
697 |
+
drop_path_rate=drop_path_rate,
|
698 |
+
)
|
699 |
+
|
700 |
+
|
701 |
+
@register_tiny_vit_model
|
702 |
+
def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1):
|
703 |
+
return TinyViT(
|
704 |
+
img_size=384,
|
705 |
+
num_classes=num_classes,
|
706 |
+
embed_dims=[96, 192, 384, 576],
|
707 |
+
depths=[2, 2, 6, 2],
|
708 |
+
num_heads=[3, 6, 12, 18],
|
709 |
+
window_sizes=[12, 12, 24, 12],
|
710 |
+
drop_path_rate=drop_path_rate,
|
711 |
+
)
|
712 |
+
|
713 |
+
|
714 |
+
@register_tiny_vit_model
|
715 |
+
def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1):
|
716 |
+
return TinyViT(
|
717 |
+
img_size=512,
|
718 |
+
num_classes=num_classes,
|
719 |
+
embed_dims=[96, 192, 384, 576],
|
720 |
+
depths=[2, 2, 6, 2],
|
721 |
+
num_heads=[3, 6, 12, 18],
|
722 |
+
window_sizes=[16, 16, 32, 16],
|
723 |
+
drop_path_rate=drop_path_rate,
|
724 |
+
)
|
EfficientSAM/MobileSAM/setup_mobile_sam.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from MobileSAM.tiny_vit_sam import TinyViT
|
2 |
+
from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
|
3 |
+
|
4 |
+
def setup_model():
|
5 |
+
prompt_embed_dim = 256
|
6 |
+
image_size = 1024
|
7 |
+
vit_patch_size = 16
|
8 |
+
image_embedding_size = image_size // vit_patch_size
|
9 |
+
mobile_sam = Sam(
|
10 |
+
image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
|
11 |
+
embed_dims=[64, 128, 160, 320],
|
12 |
+
depths=[2, 2, 6, 2],
|
13 |
+
num_heads=[2, 4, 5, 10],
|
14 |
+
window_sizes=[7, 7, 14, 7],
|
15 |
+
mlp_ratio=4.,
|
16 |
+
drop_rate=0.,
|
17 |
+
drop_path_rate=0.0,
|
18 |
+
use_checkpoint=False,
|
19 |
+
mbconv_expand_ratio=4.0,
|
20 |
+
local_conv_size=3,
|
21 |
+
layer_lr_decay=0.8
|
22 |
+
),
|
23 |
+
prompt_encoder=PromptEncoder(
|
24 |
+
embed_dim=prompt_embed_dim,
|
25 |
+
image_embedding_size=(image_embedding_size, image_embedding_size),
|
26 |
+
input_image_size=(image_size, image_size),
|
27 |
+
mask_in_chans=16,
|
28 |
+
),
|
29 |
+
mask_decoder=MaskDecoder(
|
30 |
+
num_multimask_outputs=3,
|
31 |
+
transformer=TwoWayTransformer(
|
32 |
+
depth=2,
|
33 |
+
embedding_dim=prompt_embed_dim,
|
34 |
+
mlp_dim=2048,
|
35 |
+
num_heads=8,
|
36 |
+
),
|
37 |
+
transformer_dim=prompt_embed_dim,
|
38 |
+
iou_head_depth=3,
|
39 |
+
iou_head_hidden_dim=256,
|
40 |
+
),
|
41 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
42 |
+
pixel_std=[58.395, 57.12, 57.375],
|
43 |
+
)
|
44 |
+
return mobile_sam
|
EfficientSAM/MobileSAM/tiny_vit_sam.py
ADDED
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# TinyViT Model Architecture
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Adapted from LeViT and Swin Transformer
|
5 |
+
# LeViT: (https://github.com/facebookresearch/levit)
|
6 |
+
# Swin: (https://github.com/microsoft/swin-transformer)
|
7 |
+
# Build the TinyViT Model
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import itertools
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torch.utils.checkpoint as checkpoint
|
15 |
+
from timm.models.layers import DropPath as TimmDropPath,\
|
16 |
+
to_2tuple, trunc_normal_
|
17 |
+
from timm.models.registry import register_model
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
|
21 |
+
class Conv2d_BN(torch.nn.Sequential):
|
22 |
+
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
|
23 |
+
groups=1, bn_weight_init=1):
|
24 |
+
super().__init__()
|
25 |
+
self.add_module('c', torch.nn.Conv2d(
|
26 |
+
a, b, ks, stride, pad, dilation, groups, bias=False))
|
27 |
+
bn = torch.nn.BatchNorm2d(b)
|
28 |
+
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
29 |
+
torch.nn.init.constant_(bn.bias, 0)
|
30 |
+
self.add_module('bn', bn)
|
31 |
+
|
32 |
+
@torch.no_grad()
|
33 |
+
def fuse(self):
|
34 |
+
c, bn = self._modules.values()
|
35 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
36 |
+
w = c.weight * w[:, None, None, None]
|
37 |
+
b = bn.bias - bn.running_mean * bn.weight / \
|
38 |
+
(bn.running_var + bn.eps)**0.5
|
39 |
+
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
|
40 |
+
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
|
41 |
+
m.weight.data.copy_(w)
|
42 |
+
m.bias.data.copy_(b)
|
43 |
+
return m
|
44 |
+
|
45 |
+
|
46 |
+
class DropPath(TimmDropPath):
|
47 |
+
def __init__(self, drop_prob=None):
|
48 |
+
super().__init__(drop_prob=drop_prob)
|
49 |
+
self.drop_prob = drop_prob
|
50 |
+
|
51 |
+
def __repr__(self):
|
52 |
+
msg = super().__repr__()
|
53 |
+
msg += f'(drop_prob={self.drop_prob})'
|
54 |
+
return msg
|
55 |
+
|
56 |
+
|
57 |
+
class PatchEmbed(nn.Module):
|
58 |
+
def __init__(self, in_chans, embed_dim, resolution, activation):
|
59 |
+
super().__init__()
|
60 |
+
img_size: Tuple[int, int] = to_2tuple(resolution)
|
61 |
+
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
62 |
+
self.num_patches = self.patches_resolution[0] * \
|
63 |
+
self.patches_resolution[1]
|
64 |
+
self.in_chans = in_chans
|
65 |
+
self.embed_dim = embed_dim
|
66 |
+
n = embed_dim
|
67 |
+
self.seq = nn.Sequential(
|
68 |
+
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
|
69 |
+
activation(),
|
70 |
+
Conv2d_BN(n // 2, n, 3, 2, 1),
|
71 |
+
)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
return self.seq(x)
|
75 |
+
|
76 |
+
|
77 |
+
class MBConv(nn.Module):
|
78 |
+
def __init__(self, in_chans, out_chans, expand_ratio,
|
79 |
+
activation, drop_path):
|
80 |
+
super().__init__()
|
81 |
+
self.in_chans = in_chans
|
82 |
+
self.hidden_chans = int(in_chans * expand_ratio)
|
83 |
+
self.out_chans = out_chans
|
84 |
+
|
85 |
+
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
|
86 |
+
self.act1 = activation()
|
87 |
+
|
88 |
+
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans,
|
89 |
+
ks=3, stride=1, pad=1, groups=self.hidden_chans)
|
90 |
+
self.act2 = activation()
|
91 |
+
|
92 |
+
self.conv3 = Conv2d_BN(
|
93 |
+
self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
|
94 |
+
self.act3 = activation()
|
95 |
+
|
96 |
+
self.drop_path = DropPath(
|
97 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
shortcut = x
|
101 |
+
|
102 |
+
x = self.conv1(x)
|
103 |
+
x = self.act1(x)
|
104 |
+
|
105 |
+
x = self.conv2(x)
|
106 |
+
x = self.act2(x)
|
107 |
+
|
108 |
+
x = self.conv3(x)
|
109 |
+
|
110 |
+
x = self.drop_path(x)
|
111 |
+
|
112 |
+
x += shortcut
|
113 |
+
x = self.act3(x)
|
114 |
+
|
115 |
+
return x
|
116 |
+
|
117 |
+
|
118 |
+
class PatchMerging(nn.Module):
|
119 |
+
def __init__(self, input_resolution, dim, out_dim, activation):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.input_resolution = input_resolution
|
123 |
+
self.dim = dim
|
124 |
+
self.out_dim = out_dim
|
125 |
+
self.act = activation()
|
126 |
+
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
127 |
+
stride_c=2
|
128 |
+
if(out_dim==320 or out_dim==448 or out_dim==576):#handongshen 576
|
129 |
+
stride_c=1
|
130 |
+
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
131 |
+
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
if x.ndim == 3:
|
135 |
+
H, W = self.input_resolution
|
136 |
+
B = len(x)
|
137 |
+
# (B, C, H, W)
|
138 |
+
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
139 |
+
|
140 |
+
x = self.conv1(x)
|
141 |
+
x = self.act(x)
|
142 |
+
|
143 |
+
x = self.conv2(x)
|
144 |
+
x = self.act(x)
|
145 |
+
x = self.conv3(x)
|
146 |
+
x = x.flatten(2).transpose(1, 2)
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
class ConvLayer(nn.Module):
|
151 |
+
def __init__(self, dim, input_resolution, depth,
|
152 |
+
activation,
|
153 |
+
drop_path=0., downsample=None, use_checkpoint=False,
|
154 |
+
out_dim=None,
|
155 |
+
conv_expand_ratio=4.,
|
156 |
+
):
|
157 |
+
|
158 |
+
super().__init__()
|
159 |
+
self.dim = dim
|
160 |
+
self.input_resolution = input_resolution
|
161 |
+
self.depth = depth
|
162 |
+
self.use_checkpoint = use_checkpoint
|
163 |
+
|
164 |
+
# build blocks
|
165 |
+
self.blocks = nn.ModuleList([
|
166 |
+
MBConv(dim, dim, conv_expand_ratio, activation,
|
167 |
+
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
168 |
+
)
|
169 |
+
for i in range(depth)])
|
170 |
+
|
171 |
+
# patch merging layer
|
172 |
+
if downsample is not None:
|
173 |
+
self.downsample = downsample(
|
174 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
175 |
+
else:
|
176 |
+
self.downsample = None
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
for blk in self.blocks:
|
180 |
+
if self.use_checkpoint:
|
181 |
+
x = checkpoint.checkpoint(blk, x)
|
182 |
+
else:
|
183 |
+
x = blk(x)
|
184 |
+
if self.downsample is not None:
|
185 |
+
x = self.downsample(x)
|
186 |
+
return x
|
187 |
+
|
188 |
+
|
189 |
+
class Mlp(nn.Module):
|
190 |
+
def __init__(self, in_features, hidden_features=None,
|
191 |
+
out_features=None, act_layer=nn.GELU, drop=0.):
|
192 |
+
super().__init__()
|
193 |
+
out_features = out_features or in_features
|
194 |
+
hidden_features = hidden_features or in_features
|
195 |
+
self.norm = nn.LayerNorm(in_features)
|
196 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
197 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
198 |
+
self.act = act_layer()
|
199 |
+
self.drop = nn.Dropout(drop)
|
200 |
+
|
201 |
+
def forward(self, x):
|
202 |
+
x = self.norm(x)
|
203 |
+
|
204 |
+
x = self.fc1(x)
|
205 |
+
x = self.act(x)
|
206 |
+
x = self.drop(x)
|
207 |
+
x = self.fc2(x)
|
208 |
+
x = self.drop(x)
|
209 |
+
return x
|
210 |
+
|
211 |
+
|
212 |
+
class Attention(torch.nn.Module):
|
213 |
+
def __init__(self, dim, key_dim, num_heads=8,
|
214 |
+
attn_ratio=4,
|
215 |
+
resolution=(14, 14),
|
216 |
+
):
|
217 |
+
super().__init__()
|
218 |
+
# (h, w)
|
219 |
+
assert isinstance(resolution, tuple) and len(resolution) == 2
|
220 |
+
self.num_heads = num_heads
|
221 |
+
self.scale = key_dim ** -0.5
|
222 |
+
self.key_dim = key_dim
|
223 |
+
self.nh_kd = nh_kd = key_dim * num_heads
|
224 |
+
self.d = int(attn_ratio * key_dim)
|
225 |
+
self.dh = int(attn_ratio * key_dim) * num_heads
|
226 |
+
self.attn_ratio = attn_ratio
|
227 |
+
h = self.dh + nh_kd * 2
|
228 |
+
|
229 |
+
self.norm = nn.LayerNorm(dim)
|
230 |
+
self.qkv = nn.Linear(dim, h)
|
231 |
+
self.proj = nn.Linear(self.dh, dim)
|
232 |
+
|
233 |
+
points = list(itertools.product(
|
234 |
+
range(resolution[0]), range(resolution[1])))
|
235 |
+
N = len(points)
|
236 |
+
attention_offsets = {}
|
237 |
+
idxs = []
|
238 |
+
for p1 in points:
|
239 |
+
for p2 in points:
|
240 |
+
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
241 |
+
if offset not in attention_offsets:
|
242 |
+
attention_offsets[offset] = len(attention_offsets)
|
243 |
+
idxs.append(attention_offsets[offset])
|
244 |
+
self.attention_biases = torch.nn.Parameter(
|
245 |
+
torch.zeros(num_heads, len(attention_offsets)))
|
246 |
+
self.register_buffer('attention_bias_idxs',
|
247 |
+
torch.LongTensor(idxs).view(N, N),
|
248 |
+
persistent=False)
|
249 |
+
|
250 |
+
@torch.no_grad()
|
251 |
+
def train(self, mode=True):
|
252 |
+
super().train(mode)
|
253 |
+
if mode and hasattr(self, 'ab'):
|
254 |
+
del self.ab
|
255 |
+
else:
|
256 |
+
self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
257 |
+
|
258 |
+
def forward(self, x): # x (B,N,C)
|
259 |
+
B, N, _ = x.shape
|
260 |
+
|
261 |
+
# Normalization
|
262 |
+
x = self.norm(x)
|
263 |
+
|
264 |
+
qkv = self.qkv(x)
|
265 |
+
# (B, N, num_heads, d)
|
266 |
+
q, k, v = qkv.view(B, N, self.num_heads, -
|
267 |
+
1).split([self.key_dim, self.key_dim, self.d], dim=3)
|
268 |
+
# (B, num_heads, N, d)
|
269 |
+
q = q.permute(0, 2, 1, 3)
|
270 |
+
k = k.permute(0, 2, 1, 3)
|
271 |
+
v = v.permute(0, 2, 1, 3)
|
272 |
+
|
273 |
+
attn = (
|
274 |
+
(q @ k.transpose(-2, -1)) * self.scale
|
275 |
+
+
|
276 |
+
(self.attention_biases[:, self.attention_bias_idxs]
|
277 |
+
if self.training else self.ab)
|
278 |
+
)
|
279 |
+
attn = attn.softmax(dim=-1)
|
280 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
281 |
+
x = self.proj(x)
|
282 |
+
return x
|
283 |
+
|
284 |
+
|
285 |
+
class TinyViTBlock(nn.Module):
|
286 |
+
r""" TinyViT Block.
|
287 |
+
|
288 |
+
Args:
|
289 |
+
dim (int): Number of input channels.
|
290 |
+
input_resolution (tuple[int, int]): Input resulotion.
|
291 |
+
num_heads (int): Number of attention heads.
|
292 |
+
window_size (int): Window size.
|
293 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
294 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
295 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
296 |
+
local_conv_size (int): the kernel size of the convolution between
|
297 |
+
Attention and MLP. Default: 3
|
298 |
+
activation: the activation function. Default: nn.GELU
|
299 |
+
"""
|
300 |
+
|
301 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7,
|
302 |
+
mlp_ratio=4., drop=0., drop_path=0.,
|
303 |
+
local_conv_size=3,
|
304 |
+
activation=nn.GELU,
|
305 |
+
):
|
306 |
+
super().__init__()
|
307 |
+
self.dim = dim
|
308 |
+
self.input_resolution = input_resolution
|
309 |
+
self.num_heads = num_heads
|
310 |
+
assert window_size > 0, 'window_size must be greater than 0'
|
311 |
+
self.window_size = window_size
|
312 |
+
self.mlp_ratio = mlp_ratio
|
313 |
+
|
314 |
+
self.drop_path = DropPath(
|
315 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
316 |
+
|
317 |
+
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
|
318 |
+
head_dim = dim // num_heads
|
319 |
+
|
320 |
+
window_resolution = (window_size, window_size)
|
321 |
+
self.attn = Attention(dim, head_dim, num_heads,
|
322 |
+
attn_ratio=1, resolution=window_resolution)
|
323 |
+
|
324 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
325 |
+
mlp_activation = activation
|
326 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
327 |
+
act_layer=mlp_activation, drop=drop)
|
328 |
+
|
329 |
+
pad = local_conv_size // 2
|
330 |
+
self.local_conv = Conv2d_BN(
|
331 |
+
dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
|
332 |
+
|
333 |
+
def forward(self, x):
|
334 |
+
H, W = self.input_resolution
|
335 |
+
B, L, C = x.shape
|
336 |
+
assert L == H * W, "input feature has wrong size"
|
337 |
+
res_x = x
|
338 |
+
if H == self.window_size and W == self.window_size:
|
339 |
+
x = self.attn(x)
|
340 |
+
else:
|
341 |
+
x = x.view(B, H, W, C)
|
342 |
+
pad_b = (self.window_size - H %
|
343 |
+
self.window_size) % self.window_size
|
344 |
+
pad_r = (self.window_size - W %
|
345 |
+
self.window_size) % self.window_size
|
346 |
+
padding = pad_b > 0 or pad_r > 0
|
347 |
+
|
348 |
+
if padding:
|
349 |
+
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
|
350 |
+
|
351 |
+
pH, pW = H + pad_b, W + pad_r
|
352 |
+
nH = pH // self.window_size
|
353 |
+
nW = pW // self.window_size
|
354 |
+
# window partition
|
355 |
+
x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
|
356 |
+
B * nH * nW, self.window_size * self.window_size, C)
|
357 |
+
x = self.attn(x)
|
358 |
+
# window reverse
|
359 |
+
x = x.view(B, nH, nW, self.window_size, self.window_size,
|
360 |
+
C).transpose(2, 3).reshape(B, pH, pW, C)
|
361 |
+
|
362 |
+
if padding:
|
363 |
+
x = x[:, :H, :W].contiguous()
|
364 |
+
|
365 |
+
x = x.view(B, L, C)
|
366 |
+
|
367 |
+
x = res_x + self.drop_path(x)
|
368 |
+
|
369 |
+
x = x.transpose(1, 2).reshape(B, C, H, W)
|
370 |
+
x = self.local_conv(x)
|
371 |
+
x = x.view(B, C, L).transpose(1, 2)
|
372 |
+
|
373 |
+
x = x + self.drop_path(self.mlp(x))
|
374 |
+
return x
|
375 |
+
|
376 |
+
def extra_repr(self) -> str:
|
377 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
378 |
+
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
|
379 |
+
|
380 |
+
|
381 |
+
class BasicLayer(nn.Module):
|
382 |
+
""" A basic TinyViT layer for one stage.
|
383 |
+
|
384 |
+
Args:
|
385 |
+
dim (int): Number of input channels.
|
386 |
+
input_resolution (tuple[int]): Input resolution.
|
387 |
+
depth (int): Number of blocks.
|
388 |
+
num_heads (int): Number of attention heads.
|
389 |
+
window_size (int): Local window size.
|
390 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
391 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
392 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
393 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
394 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
395 |
+
local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
|
396 |
+
activation: the activation function. Default: nn.GELU
|
397 |
+
out_dim: the output dimension of the layer. Default: dim
|
398 |
+
"""
|
399 |
+
|
400 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
401 |
+
mlp_ratio=4., drop=0.,
|
402 |
+
drop_path=0., downsample=None, use_checkpoint=False,
|
403 |
+
local_conv_size=3,
|
404 |
+
activation=nn.GELU,
|
405 |
+
out_dim=None,
|
406 |
+
):
|
407 |
+
|
408 |
+
super().__init__()
|
409 |
+
self.dim = dim
|
410 |
+
self.input_resolution = input_resolution
|
411 |
+
self.depth = depth
|
412 |
+
self.use_checkpoint = use_checkpoint
|
413 |
+
|
414 |
+
# build blocks
|
415 |
+
self.blocks = nn.ModuleList([
|
416 |
+
TinyViTBlock(dim=dim, input_resolution=input_resolution,
|
417 |
+
num_heads=num_heads, window_size=window_size,
|
418 |
+
mlp_ratio=mlp_ratio,
|
419 |
+
drop=drop,
|
420 |
+
drop_path=drop_path[i] if isinstance(
|
421 |
+
drop_path, list) else drop_path,
|
422 |
+
local_conv_size=local_conv_size,
|
423 |
+
activation=activation,
|
424 |
+
)
|
425 |
+
for i in range(depth)])
|
426 |
+
|
427 |
+
# patch merging layer
|
428 |
+
if downsample is not None:
|
429 |
+
self.downsample = downsample(
|
430 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
431 |
+
else:
|
432 |
+
self.downsample = None
|
433 |
+
|
434 |
+
def forward(self, x):
|
435 |
+
for blk in self.blocks:
|
436 |
+
if self.use_checkpoint:
|
437 |
+
x = checkpoint.checkpoint(blk, x)
|
438 |
+
else:
|
439 |
+
x = blk(x)
|
440 |
+
if self.downsample is not None:
|
441 |
+
x = self.downsample(x)
|
442 |
+
return x
|
443 |
+
|
444 |
+
def extra_repr(self) -> str:
|
445 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
446 |
+
|
447 |
+
class LayerNorm2d(nn.Module):
|
448 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
449 |
+
super().__init__()
|
450 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
451 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
452 |
+
self.eps = eps
|
453 |
+
|
454 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
455 |
+
u = x.mean(1, keepdim=True)
|
456 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
457 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
458 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
459 |
+
return x
|
460 |
+
class TinyViT(nn.Module):
|
461 |
+
def __init__(self, img_size=224, in_chans=3, num_classes=1000,
|
462 |
+
embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
|
463 |
+
num_heads=[3, 6, 12, 24],
|
464 |
+
window_sizes=[7, 7, 14, 7],
|
465 |
+
mlp_ratio=4.,
|
466 |
+
drop_rate=0.,
|
467 |
+
drop_path_rate=0.1,
|
468 |
+
use_checkpoint=False,
|
469 |
+
mbconv_expand_ratio=4.0,
|
470 |
+
local_conv_size=3,
|
471 |
+
layer_lr_decay=1.0,
|
472 |
+
):
|
473 |
+
super().__init__()
|
474 |
+
self.img_size=img_size
|
475 |
+
self.num_classes = num_classes
|
476 |
+
self.depths = depths
|
477 |
+
self.num_layers = len(depths)
|
478 |
+
self.mlp_ratio = mlp_ratio
|
479 |
+
|
480 |
+
activation = nn.GELU
|
481 |
+
|
482 |
+
self.patch_embed = PatchEmbed(in_chans=in_chans,
|
483 |
+
embed_dim=embed_dims[0],
|
484 |
+
resolution=img_size,
|
485 |
+
activation=activation)
|
486 |
+
|
487 |
+
patches_resolution = self.patch_embed.patches_resolution
|
488 |
+
self.patches_resolution = patches_resolution
|
489 |
+
|
490 |
+
# stochastic depth
|
491 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
|
492 |
+
sum(depths))] # stochastic depth decay rule
|
493 |
+
|
494 |
+
# build layers
|
495 |
+
self.layers = nn.ModuleList()
|
496 |
+
for i_layer in range(self.num_layers):
|
497 |
+
kwargs = dict(dim=embed_dims[i_layer],
|
498 |
+
input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)),
|
499 |
+
patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))),
|
500 |
+
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
501 |
+
# patches_resolution[1] // (2 ** i_layer)),
|
502 |
+
depth=depths[i_layer],
|
503 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
504 |
+
downsample=PatchMerging if (
|
505 |
+
i_layer < self.num_layers - 1) else None,
|
506 |
+
use_checkpoint=use_checkpoint,
|
507 |
+
out_dim=embed_dims[min(
|
508 |
+
i_layer + 1, len(embed_dims) - 1)],
|
509 |
+
activation=activation,
|
510 |
+
)
|
511 |
+
if i_layer == 0:
|
512 |
+
layer = ConvLayer(
|
513 |
+
conv_expand_ratio=mbconv_expand_ratio,
|
514 |
+
**kwargs,
|
515 |
+
)
|
516 |
+
else:
|
517 |
+
layer = BasicLayer(
|
518 |
+
num_heads=num_heads[i_layer],
|
519 |
+
window_size=window_sizes[i_layer],
|
520 |
+
mlp_ratio=self.mlp_ratio,
|
521 |
+
drop=drop_rate,
|
522 |
+
local_conv_size=local_conv_size,
|
523 |
+
**kwargs)
|
524 |
+
self.layers.append(layer)
|
525 |
+
|
526 |
+
# Classifier head
|
527 |
+
self.norm_head = nn.LayerNorm(embed_dims[-1])
|
528 |
+
self.head = nn.Linear(
|
529 |
+
embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
|
530 |
+
|
531 |
+
# init weights
|
532 |
+
self.apply(self._init_weights)
|
533 |
+
self.set_layer_lr_decay(layer_lr_decay)
|
534 |
+
self.neck = nn.Sequential(
|
535 |
+
nn.Conv2d(
|
536 |
+
embed_dims[-1],#handongshen
|
537 |
+
256,
|
538 |
+
kernel_size=1,
|
539 |
+
bias=False,
|
540 |
+
),
|
541 |
+
LayerNorm2d(256),
|
542 |
+
nn.Conv2d(
|
543 |
+
256,
|
544 |
+
256,
|
545 |
+
kernel_size=3,
|
546 |
+
padding=1,
|
547 |
+
bias=False,
|
548 |
+
),
|
549 |
+
LayerNorm2d(256),
|
550 |
+
)
|
551 |
+
def set_layer_lr_decay(self, layer_lr_decay):
|
552 |
+
decay_rate = layer_lr_decay
|
553 |
+
|
554 |
+
# layers -> blocks (depth)
|
555 |
+
depth = sum(self.depths)
|
556 |
+
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
557 |
+
print("LR SCALES:", lr_scales)
|
558 |
+
|
559 |
+
def _set_lr_scale(m, scale):
|
560 |
+
for p in m.parameters():
|
561 |
+
p.lr_scale = scale
|
562 |
+
|
563 |
+
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
|
564 |
+
i = 0
|
565 |
+
for layer in self.layers:
|
566 |
+
for block in layer.blocks:
|
567 |
+
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
|
568 |
+
i += 1
|
569 |
+
if layer.downsample is not None:
|
570 |
+
layer.downsample.apply(
|
571 |
+
lambda x: _set_lr_scale(x, lr_scales[i - 1]))
|
572 |
+
assert i == depth
|
573 |
+
for m in [self.norm_head, self.head]:
|
574 |
+
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
|
575 |
+
|
576 |
+
for k, p in self.named_parameters():
|
577 |
+
p.param_name = k
|
578 |
+
|
579 |
+
def _check_lr_scale(m):
|
580 |
+
for p in m.parameters():
|
581 |
+
assert hasattr(p, 'lr_scale'), p.param_name
|
582 |
+
|
583 |
+
self.apply(_check_lr_scale)
|
584 |
+
|
585 |
+
def _init_weights(self, m):
|
586 |
+
if isinstance(m, nn.Linear):
|
587 |
+
trunc_normal_(m.weight, std=.02)
|
588 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
589 |
+
nn.init.constant_(m.bias, 0)
|
590 |
+
elif isinstance(m, nn.LayerNorm):
|
591 |
+
nn.init.constant_(m.bias, 0)
|
592 |
+
nn.init.constant_(m.weight, 1.0)
|
593 |
+
|
594 |
+
@torch.jit.ignore
|
595 |
+
def no_weight_decay_keywords(self):
|
596 |
+
return {'attention_biases'}
|
597 |
+
|
598 |
+
def forward_features(self, x):
|
599 |
+
# x: (N, C, H, W)
|
600 |
+
x = self.patch_embed(x)
|
601 |
+
|
602 |
+
x = self.layers[0](x)
|
603 |
+
start_i = 1
|
604 |
+
|
605 |
+
for i in range(start_i, len(self.layers)):
|
606 |
+
layer = self.layers[i]
|
607 |
+
x = layer(x)
|
608 |
+
B,_,C=x.size()
|
609 |
+
x = x.view(B, 64, 64, C)
|
610 |
+
x=x.permute(0, 3, 1, 2)
|
611 |
+
x=self.neck(x)
|
612 |
+
return x
|
613 |
+
|
614 |
+
def forward(self, x):
|
615 |
+
x = self.forward_features(x)
|
616 |
+
|
617 |
+
# We have made some hack changes here to make it compatible with SAM-HQ
|
618 |
+
return x, None
|
619 |
+
|
620 |
+
|
621 |
+
_checkpoint_url_format = \
|
622 |
+
'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth'
|
623 |
+
_provided_checkpoints = {
|
624 |
+
'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill',
|
625 |
+
'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill',
|
626 |
+
'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill',
|
627 |
+
'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill',
|
628 |
+
'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill',
|
629 |
+
}
|
630 |
+
|
631 |
+
|
632 |
+
def register_tiny_vit_model(fn):
|
633 |
+
'''Register a TinyViT model
|
634 |
+
It is a wrapper of `register_model` with loading the pretrained checkpoint.
|
635 |
+
'''
|
636 |
+
def fn_wrapper(pretrained=False, **kwargs):
|
637 |
+
model = fn()
|
638 |
+
if pretrained:
|
639 |
+
model_name = fn.__name__
|
640 |
+
assert model_name in _provided_checkpoints, \
|
641 |
+
f'Sorry that the checkpoint `{model_name}` is not provided yet.'
|
642 |
+
url = _checkpoint_url_format.format(
|
643 |
+
_provided_checkpoints[model_name])
|
644 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
645 |
+
url=url,
|
646 |
+
map_location='cpu', check_hash=False,
|
647 |
+
)
|
648 |
+
model.load_state_dict(checkpoint['model'])
|
649 |
+
|
650 |
+
return model
|
651 |
+
|
652 |
+
# rename the name of fn_wrapper
|
653 |
+
fn_wrapper.__name__ = fn.__name__
|
654 |
+
return register_model(fn_wrapper)
|
655 |
+
|
656 |
+
|
657 |
+
@register_tiny_vit_model
|
658 |
+
def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0):
|
659 |
+
return TinyViT(
|
660 |
+
num_classes=num_classes,
|
661 |
+
embed_dims=[64, 128, 160, 320],
|
662 |
+
depths=[2, 2, 6, 2],
|
663 |
+
num_heads=[2, 4, 5, 10],
|
664 |
+
window_sizes=[7, 7, 14, 7],
|
665 |
+
drop_path_rate=drop_path_rate,
|
666 |
+
)
|
667 |
+
|
668 |
+
|
669 |
+
@register_tiny_vit_model
|
670 |
+
def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1):
|
671 |
+
return TinyViT(
|
672 |
+
num_classes=num_classes,
|
673 |
+
embed_dims=[64, 128, 256, 448],
|
674 |
+
depths=[2, 2, 6, 2],
|
675 |
+
num_heads=[2, 4, 8, 14],
|
676 |
+
window_sizes=[7, 7, 14, 7],
|
677 |
+
drop_path_rate=drop_path_rate,
|
678 |
+
)
|
679 |
+
|
680 |
+
|
681 |
+
@register_tiny_vit_model
|
682 |
+
def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2):
|
683 |
+
return TinyViT(
|
684 |
+
num_classes=num_classes,
|
685 |
+
embed_dims=[96, 192, 384, 576],
|
686 |
+
depths=[2, 2, 6, 2],
|
687 |
+
num_heads=[3, 6, 12, 18],
|
688 |
+
window_sizes=[7, 7, 14, 7],
|
689 |
+
drop_path_rate=drop_path_rate,
|
690 |
+
)
|
691 |
+
|
692 |
+
|
693 |
+
@register_tiny_vit_model
|
694 |
+
def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1):
|
695 |
+
return TinyViT(
|
696 |
+
img_size=384,
|
697 |
+
num_classes=num_classes,
|
698 |
+
embed_dims=[96, 192, 384, 576],
|
699 |
+
depths=[2, 2, 6, 2],
|
700 |
+
num_heads=[3, 6, 12, 18],
|
701 |
+
window_sizes=[12, 12, 24, 12],
|
702 |
+
drop_path_rate=drop_path_rate,
|
703 |
+
)
|
704 |
+
|
705 |
+
|
706 |
+
@register_tiny_vit_model
|
707 |
+
def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1):
|
708 |
+
return TinyViT(
|
709 |
+
img_size=512,
|
710 |
+
num_classes=num_classes,
|
711 |
+
embed_dims=[96, 192, 384, 576],
|
712 |
+
depths=[2, 2, 6, 2],
|
713 |
+
num_heads=[3, 6, 12, 18],
|
714 |
+
window_sizes=[16, 16, 32, 16],
|
715 |
+
drop_path_rate=drop_path_rate,
|
716 |
+
)
|
EfficientSAM/README.md
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Efficient Grounded-SAM
|
2 |
+
|
3 |
+
We're going to combine [Grounding-DINO](https://github.com/IDEA-Research/GroundingDINO) with efficient SAM variants for faster annotating.
|
4 |
+
|
5 |
+
<!-- Combining [Grounding-DINO](https://github.com/IDEA-Research/GroundingDINO) and [Fast-SAM](https://github.com/CASIA-IVA-Lab/FastSAM) for faster zero-shot detect and segment anything. -->
|
6 |
+
|
7 |
+
|
8 |
+
### Table of Contents
|
9 |
+
- [Installation](#installation)
|
10 |
+
- [Efficient SAM Series](#efficient-sams)
|
11 |
+
- [Run Grounded-FastSAM Demo](#run-grounded-fastsam-demo)
|
12 |
+
- [Run Grounded-MobileSAM Demo](#run-grounded-mobilesam-demo)
|
13 |
+
- [Run Grounded-LightHQSAM Demo](#run-grounded-light-hqsam-demo)
|
14 |
+
- [Run Grounded-Efficient-SAM Demo](#run-grounded-efficient-sam-demo)
|
15 |
+
- [Run Grounded-Edge-SAM Demo](#run-grounded-edge-sam-demo)
|
16 |
+
- [Run Grounded-RepViT-SAM Demo](#run-grounded-repvit-sam-demo)
|
17 |
+
|
18 |
+
|
19 |
+
### Installation
|
20 |
+
|
21 |
+
- Install [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything#installation)
|
22 |
+
|
23 |
+
- Install [Fast-SAM](https://github.com/CASIA-IVA-Lab/FastSAM#installation)
|
24 |
+
|
25 |
+
- Note that we may use the sam image as the demo image in order to compare the inference results of different efficient-sam variants.
|
26 |
+
|
27 |
+
### Efficient SAMs
|
28 |
+
Here's the list of Efficient SAM variants:
|
29 |
+
|
30 |
+
<div align="center">
|
31 |
+
|
32 |
+
| Title | Intro | Description | Links |
|
33 |
+
|:----:|:----:|:----:|:----:|
|
34 |
+
| [FastSAM](https://arxiv.org/pdf/2306.12156.pdf) | ![](https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/assets/Overview.png) | The Fast Segment Anything Model(FastSAM) is a CNN Segment Anything Model trained by only 2% of the SA-1B dataset published by SAM authors. The FastSAM achieve a comparable performance with the SAM method at 50× higher run-time speed. | [[Github](https://github.com/CASIA-IVA-Lab/FastSAM)] [[Demo](https://huggingface.co/spaces/An-619/FastSAM)] |
|
35 |
+
| [MobileSAM](https://arxiv.org/pdf/2306.14289.pdf) | ![](https://github.com/ChaoningZhang/MobileSAM/blob/master/assets/model_diagram.jpg?raw=true) | MobileSAM performs on par with the original SAM (at least visually) and keeps exactly the same pipeline as the original SAM except for a change on the image encoder. Specifically, we replace the original heavyweight ViT-H encoder (632M) with a much smaller Tiny-ViT (5M). On a single GPU, MobileSAM runs around 12ms per image: 8ms on the image encoder and 4ms on the mask decoder. | [[Github](https://github.com/ChaoningZhang/MobileSAM)] |
|
36 |
+
| [Light-HQSAM](https://arxiv.org/pdf/2306.01567.pdf) | ![](https://github.com/SysCV/sam-hq/blob/main/figs/sam-hf-framework.png?raw=true) | Light HQ-SAM is based on the tiny vit image encoder provided by MobileSAM. We design a learnable High-Quality Output Token, which is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of only applying it on mask-decoder features, we first fuse them with ViT features for improved mask details. Refer to [Light HQ-SAM vs. MobileSAM](https://github.com/SysCV/sam-hq#light-hq-sam-vs-mobilesam-on-coco) for more details. | [[Github](https://github.com/SysCV/sam-hq)] |
|
37 |
+
| [Efficient-SAM](https://github.com/yformer/EfficientSAM) | ![](https://yformer.github.io/efficient-sam/EfficientSAM_files/overview.png) |Segment Anything Model (SAM) has emerged as a powerful tool for numerous vision applications. However, the huge computation cost of SAM model has limited its applications to wider real-world applications. To address this limitation, we propose EfficientSAMs, light-weight SAM models that exhibit decent performance with largely reduced complexity. Our idea is based on leveraging masked image pretraining, SAMI, which learns to reconstruct features from SAM image encoder for effective visual representation learning. Further, we take SAMI-pretrained light-weight image encoders and mask decoder to build EfficientSAMs, and finetune the models on SA-1B for segment anything task. Refer to [EfficientSAM arXiv](https://arxiv.org/pdf/2312.00863.pdf) for more details.| [[Github](https://github.com/yformer/EfficientSAM)] |
|
38 |
+
| [Edge-SAM](https://github.com/chongzhou96/EdgeSAM) | ![](https://www.mmlab-ntu.com/project/edgesam/img/arch.png) | EdgeSAM involves distilling the original ViT-based SAM image encoder into a purely CNN-based architecture, better suited for edge devices. We carefully benchmark various distillation strategies and demonstrate that task-agnostic encoder distillation fails to capture the full knowledge embodied in SAM. Refer to [Edge-SAM arXiv](https://arxiv.org/abs/2312.06660) for more details. | [[Github](https://github.com/chongzhou96/EdgeSAM)] |
|
39 |
+
| [RepViT-SAM](https://github.com/THU-MIG/RepViT/tree/main/sam) | ![](https://jameslahm.github.io/repvit-sam/static/images/edge.png) | Recently, RepViT achieves the state-of-the-art performance and latency trade-off on mobile devices by incorporating efficient architectural designs of ViTs into CNNs. Here, to achieve real-time segmenting anything on mobile devices, following MobileSAM, we replace the heavyweight image encoder in SAM with RepViT model, ending up with the RepViT-SAM model. Extensive experiments show that RepViT-SAM can enjoy significantly better zero-shot transfer capability than MobileSAM, along with nearly 10× faster inference speed. Refer to [RepViT-SAM arXiv](https://arxiv.org/pdf/2312.05760.pdf) for more details. | [[Github](https://github.com/THU-MIG/RepViT)] |
|
40 |
+
|
41 |
+
</div>
|
42 |
+
|
43 |
+
|
44 |
+
### Run Grounded-FastSAM Demo
|
45 |
+
|
46 |
+
- Firstly, download the pretrained Fast-SAM weight [here](https://github.com/CASIA-IVA-Lab/FastSAM#model-checkpoints)
|
47 |
+
|
48 |
+
- Run the demo with the following script:
|
49 |
+
|
50 |
+
```bash
|
51 |
+
cd Grounded-Segment-Anything
|
52 |
+
|
53 |
+
python EfficientSAM/grounded_fast_sam.py --model_path "./FastSAM-x.pt" --img_path "assets/demo4.jpg" --text "the black dog." --output "./output/"
|
54 |
+
```
|
55 |
+
|
56 |
+
- And the results will be saved in `./output/` as:
|
57 |
+
|
58 |
+
<div style="text-align: center">
|
59 |
+
|
60 |
+
| Input | Text | Output |
|
61 |
+
|:---:|:---:|:---:|
|
62 |
+
|![](/assets/demo4.jpg) | "The black dog." | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/fast_sam/demo4_0_caption_the%20black%20dog.jpg?raw=true) |
|
63 |
+
|
64 |
+
</div>
|
65 |
+
|
66 |
+
|
67 |
+
**Note**: Due to the post process of FastSAM, only one box can be annotated at a time, if there're multiple box prompts, we simply save multiple annotate images to `./output` now, which will be modified in the future release.
|
68 |
+
|
69 |
+
|
70 |
+
### Run Grounded-MobileSAM Demo
|
71 |
+
|
72 |
+
- Firstly, download the pretrained MobileSAM weight [here](https://github.com/ChaoningZhang/MobileSAM/tree/master/weights)
|
73 |
+
|
74 |
+
- Run the demo with the following script:
|
75 |
+
|
76 |
+
```bash
|
77 |
+
cd Grounded-Segment-Anything
|
78 |
+
|
79 |
+
python EfficientSAM/grounded_mobile_sam.py --MOBILE_SAM_CHECKPOINT_PATH "./EfficientSAM/mobile_sam.pt" --SOURCE_IMAGE_PATH "./assets/demo2.jpg" --CAPTION "the running dog"
|
80 |
+
```
|
81 |
+
|
82 |
+
- And the result will be saved as `./gronded_mobile_sam_anontated_image.jpg` as:
|
83 |
+
|
84 |
+
<div style="text-align: center">
|
85 |
+
|
86 |
+
| Input | Text | Output |
|
87 |
+
|:---:|:---:|:---:|
|
88 |
+
|![](/assets/demo2.jpg) | "the running dog" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/mobile_sam/grounded_mobile_sam_annotated_image.jpg?raw=true) |
|
89 |
+
|
90 |
+
</div>
|
91 |
+
|
92 |
+
|
93 |
+
### Run Grounded-Light-HQSAM Demo
|
94 |
+
|
95 |
+
- Firstly, download the pretrained Light-HQSAM weight [here](https://github.com/SysCV/sam-hq#model-checkpoints)
|
96 |
+
|
97 |
+
- Run the demo with the following script:
|
98 |
+
|
99 |
+
```bash
|
100 |
+
cd Grounded-Segment-Anything
|
101 |
+
|
102 |
+
python EfficientSAM/grounded_light_hqsam.py
|
103 |
+
```
|
104 |
+
|
105 |
+
- And the result will be saved as `./gronded_light_hqsam_anontated_image.jpg` as:
|
106 |
+
|
107 |
+
<div style="text-align: center">
|
108 |
+
|
109 |
+
| Input | Text | Output |
|
110 |
+
|:---:|:---:|:---:|
|
111 |
+
|![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "bench" | ![](/EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg) |
|
112 |
+
|
113 |
+
</div>
|
114 |
+
|
115 |
+
|
116 |
+
### Run Grounded-Efficient-SAM Demo
|
117 |
+
|
118 |
+
- Download the pretrained EfficientSAM checkpoint from [here](https://github.com/yformer/EfficientSAM#model) and put it under `Grounded-Segment-Anything/EfficientSAM`
|
119 |
+
|
120 |
+
- Run the demo with the following script:
|
121 |
+
|
122 |
+
```bash
|
123 |
+
cd Grounded-Segment-Anything
|
124 |
+
|
125 |
+
python EfficientSAM/grounded_efficient_sam.py
|
126 |
+
```
|
127 |
+
|
128 |
+
- And the result will be saved as `./gronded_efficient_sam_anontated_image.jpg` as:
|
129 |
+
|
130 |
+
<div style="text-align: center">
|
131 |
+
|
132 |
+
| Input | Text | Output |
|
133 |
+
|:---:|:---:|:---:|
|
134 |
+
|![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "bench" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/efficient_sam/grounded_efficient_sam_annotated_image.jpg?raw=true) |
|
135 |
+
|
136 |
+
</div>
|
137 |
+
|
138 |
+
|
139 |
+
### Run Grounded-Edge-SAM Demo
|
140 |
+
|
141 |
+
- Download the pretrained [Edge-SAM](https://github.com/chongzhou96/EdgeSAM) checkpoint follow the [official instruction](https://github.com/chongzhou96/EdgeSAM?tab=readme-ov-file#usage-) as:
|
142 |
+
|
143 |
+
```bash
|
144 |
+
cd Grounded-Segment-Anything
|
145 |
+
wget -P EfficientSAM/ https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam.pth
|
146 |
+
wget -P EfficientSAM/ https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x.pth
|
147 |
+
```
|
148 |
+
|
149 |
+
- Run the demo with the following script:
|
150 |
+
|
151 |
+
```bash
|
152 |
+
cd Grounded-Segment-Anything
|
153 |
+
|
154 |
+
python EfficientSAM/grounded_edge_sam.py
|
155 |
+
```
|
156 |
+
|
157 |
+
- And the result will be saved as `./gronded_edge_sam_anontated_image.jpg` as:
|
158 |
+
|
159 |
+
<div style="text-align: center">
|
160 |
+
|
161 |
+
| Input | Text | Output |
|
162 |
+
|:---:|:---:|:---:|
|
163 |
+
|![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "bench" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/edge_sam/grounded_edge_sam_annotated_image.jpg?raw=true) |
|
164 |
+
|
165 |
+
</div>
|
166 |
+
|
167 |
+
### Run Grounded-RepViT-SAM Demo
|
168 |
+
|
169 |
+
- Download the pretrained [RepViT-SAM](https://github.com/THU-MIG/RepViT) checkpoint follow the [official instruction](https://github.com/THU-MIG/RepViT/tree/main/sam#installation) as:
|
170 |
+
|
171 |
+
```bash
|
172 |
+
cd Grounded-Segment-Anything
|
173 |
+
wget -P EfficientSAM/ https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_sam.pt
|
174 |
+
```
|
175 |
+
|
176 |
+
- Run the demo with the following script:
|
177 |
+
|
178 |
+
```bash
|
179 |
+
cd Grounded-Segment-Anything
|
180 |
+
|
181 |
+
python EfficientSAM/grounded_repvit_sam.py
|
182 |
+
```
|
183 |
+
|
184 |
+
- And the result will be saved as `./gronded_repvit_sam_anontated_image.jpg` as:
|
185 |
+
|
186 |
+
<div style="text-align: center">
|
187 |
+
|
188 |
+
| Input | Text | Output |
|
189 |
+
|:---:|:---:|:---:|
|
190 |
+
|![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "bench" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/repvit_sam/grounded_repvit_sam_annotated_image.jpg?raw=true) |
|
191 |
+
|
192 |
+
</div>
|
193 |
+
|
194 |
+
|
EfficientSAM/RepViTSAM/repvit.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
__all__ = ['repvit_m1']
|
5 |
+
|
6 |
+
|
7 |
+
def _make_divisible(v, divisor, min_value=None):
|
8 |
+
"""
|
9 |
+
This function is taken from the original tf repo.
|
10 |
+
It ensures that all layers have a channel number that is divisible by 8
|
11 |
+
It can be seen here:
|
12 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
13 |
+
:param v:
|
14 |
+
:param divisor:
|
15 |
+
:param min_value:
|
16 |
+
:return:
|
17 |
+
"""
|
18 |
+
if min_value is None:
|
19 |
+
min_value = divisor
|
20 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
21 |
+
# Make sure that round down does not go down by more than 10%.
|
22 |
+
if new_v < 0.9 * v:
|
23 |
+
new_v += divisor
|
24 |
+
return new_v
|
25 |
+
|
26 |
+
from timm.models.layers import SqueezeExcite
|
27 |
+
|
28 |
+
import torch
|
29 |
+
|
30 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
31 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
32 |
+
class LayerNorm2d(nn.Module):
|
33 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
34 |
+
super().__init__()
|
35 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
36 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
37 |
+
self.eps = eps
|
38 |
+
|
39 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
40 |
+
u = x.mean(1, keepdim=True)
|
41 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
42 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
43 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
44 |
+
return x
|
45 |
+
|
46 |
+
class Conv2d_BN(torch.nn.Sequential):
|
47 |
+
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
|
48 |
+
groups=1, bn_weight_init=1, resolution=-10000):
|
49 |
+
super().__init__()
|
50 |
+
self.add_module('c', torch.nn.Conv2d(
|
51 |
+
a, b, ks, stride, pad, dilation, groups, bias=False))
|
52 |
+
self.add_module('bn', torch.nn.BatchNorm2d(b))
|
53 |
+
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
|
54 |
+
torch.nn.init.constant_(self.bn.bias, 0)
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def fuse(self):
|
58 |
+
c, bn = self._modules.values()
|
59 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
60 |
+
w = c.weight * w[:, None, None, None]
|
61 |
+
b = bn.bias - bn.running_mean * bn.weight / \
|
62 |
+
(bn.running_var + bn.eps)**0.5
|
63 |
+
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
|
64 |
+
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
|
65 |
+
device=c.weight.device)
|
66 |
+
m.weight.data.copy_(w)
|
67 |
+
m.bias.data.copy_(b)
|
68 |
+
return m
|
69 |
+
|
70 |
+
class Residual(torch.nn.Module):
|
71 |
+
def __init__(self, m, drop=0.):
|
72 |
+
super().__init__()
|
73 |
+
self.m = m
|
74 |
+
self.drop = drop
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
if self.training and self.drop > 0:
|
78 |
+
return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
|
79 |
+
device=x.device).ge_(self.drop).div(1 - self.drop).detach()
|
80 |
+
else:
|
81 |
+
return x + self.m(x)
|
82 |
+
|
83 |
+
@torch.no_grad()
|
84 |
+
def fuse(self):
|
85 |
+
if isinstance(self.m, Conv2d_BN):
|
86 |
+
m = self.m.fuse()
|
87 |
+
assert(m.groups == m.in_channels)
|
88 |
+
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
|
89 |
+
identity = torch.nn.functional.pad(identity, [1,1,1,1])
|
90 |
+
m.weight += identity.to(m.weight.device)
|
91 |
+
return m
|
92 |
+
elif isinstance(self.m, torch.nn.Conv2d):
|
93 |
+
m = self.m
|
94 |
+
assert(m.groups != m.in_channels)
|
95 |
+
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
|
96 |
+
identity = torch.nn.functional.pad(identity, [1,1,1,1])
|
97 |
+
m.weight += identity.to(m.weight.device)
|
98 |
+
return m
|
99 |
+
else:
|
100 |
+
return self
|
101 |
+
|
102 |
+
|
103 |
+
class RepVGGDW(torch.nn.Module):
|
104 |
+
def __init__(self, ed) -> None:
|
105 |
+
super().__init__()
|
106 |
+
self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
|
107 |
+
self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
|
108 |
+
self.dim = ed
|
109 |
+
self.bn = torch.nn.BatchNorm2d(ed)
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
return self.bn((self.conv(x) + self.conv1(x)) + x)
|
113 |
+
|
114 |
+
@torch.no_grad()
|
115 |
+
def fuse(self):
|
116 |
+
conv = self.conv.fuse()
|
117 |
+
conv1 = self.conv1
|
118 |
+
|
119 |
+
conv_w = conv.weight
|
120 |
+
conv_b = conv.bias
|
121 |
+
conv1_w = conv1.weight
|
122 |
+
conv1_b = conv1.bias
|
123 |
+
|
124 |
+
conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])
|
125 |
+
|
126 |
+
identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])
|
127 |
+
|
128 |
+
final_conv_w = conv_w + conv1_w + identity
|
129 |
+
final_conv_b = conv_b + conv1_b
|
130 |
+
|
131 |
+
conv.weight.data.copy_(final_conv_w)
|
132 |
+
conv.bias.data.copy_(final_conv_b)
|
133 |
+
|
134 |
+
bn = self.bn
|
135 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
136 |
+
w = conv.weight * w[:, None, None, None]
|
137 |
+
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
|
138 |
+
(bn.running_var + bn.eps)**0.5
|
139 |
+
conv.weight.data.copy_(w)
|
140 |
+
conv.bias.data.copy_(b)
|
141 |
+
return conv
|
142 |
+
|
143 |
+
|
144 |
+
class RepViTBlock(nn.Module):
|
145 |
+
def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
|
146 |
+
super(RepViTBlock, self).__init__()
|
147 |
+
assert stride in [1, 2]
|
148 |
+
|
149 |
+
self.identity = stride == 1 and inp == oup
|
150 |
+
assert(hidden_dim == 2 * inp)
|
151 |
+
|
152 |
+
if stride == 2:
|
153 |
+
self.token_mixer = nn.Sequential(
|
154 |
+
Conv2d_BN(inp, inp, kernel_size, stride if inp != 320 else 1, (kernel_size - 1) // 2, groups=inp),
|
155 |
+
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
|
156 |
+
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
|
157 |
+
)
|
158 |
+
self.channel_mixer = Residual(nn.Sequential(
|
159 |
+
# pw
|
160 |
+
Conv2d_BN(oup, 2 * oup, 1, 1, 0),
|
161 |
+
nn.GELU() if use_hs else nn.GELU(),
|
162 |
+
# pw-linear
|
163 |
+
Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
|
164 |
+
))
|
165 |
+
else:
|
166 |
+
# assert(self.identity)
|
167 |
+
self.token_mixer = nn.Sequential(
|
168 |
+
RepVGGDW(inp),
|
169 |
+
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
|
170 |
+
)
|
171 |
+
if self.identity:
|
172 |
+
self.channel_mixer = Residual(nn.Sequential(
|
173 |
+
# pw
|
174 |
+
Conv2d_BN(inp, hidden_dim, 1, 1, 0),
|
175 |
+
nn.GELU() if use_hs else nn.GELU(),
|
176 |
+
# pw-linear
|
177 |
+
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
|
178 |
+
))
|
179 |
+
else:
|
180 |
+
self.channel_mixer = nn.Sequential(
|
181 |
+
# pw
|
182 |
+
Conv2d_BN(inp, hidden_dim, 1, 1, 0),
|
183 |
+
nn.GELU() if use_hs else nn.GELU(),
|
184 |
+
# pw-linear
|
185 |
+
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
|
186 |
+
)
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
return self.channel_mixer(self.token_mixer(x))
|
190 |
+
|
191 |
+
from timm.models.vision_transformer import trunc_normal_
|
192 |
+
class BN_Linear(torch.nn.Sequential):
|
193 |
+
def __init__(self, a, b, bias=True, std=0.02):
|
194 |
+
super().__init__()
|
195 |
+
self.add_module('bn', torch.nn.BatchNorm1d(a))
|
196 |
+
self.add_module('l', torch.nn.Linear(a, b, bias=bias))
|
197 |
+
trunc_normal_(self.l.weight, std=std)
|
198 |
+
if bias:
|
199 |
+
torch.nn.init.constant_(self.l.bias, 0)
|
200 |
+
|
201 |
+
@torch.no_grad()
|
202 |
+
def fuse(self):
|
203 |
+
bn, l = self._modules.values()
|
204 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
205 |
+
b = bn.bias - self.bn.running_mean * \
|
206 |
+
self.bn.weight / (bn.running_var + bn.eps)**0.5
|
207 |
+
w = l.weight * w[None, :]
|
208 |
+
if l.bias is None:
|
209 |
+
b = b @ self.l.weight.T
|
210 |
+
else:
|
211 |
+
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
|
212 |
+
m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
|
213 |
+
m.weight.data.copy_(w)
|
214 |
+
m.bias.data.copy_(b)
|
215 |
+
return m
|
216 |
+
|
217 |
+
class Classfier(nn.Module):
|
218 |
+
def __init__(self, dim, num_classes, distillation=True):
|
219 |
+
super().__init__()
|
220 |
+
self.classifier = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
|
221 |
+
self.distillation = distillation
|
222 |
+
if distillation:
|
223 |
+
self.classifier_dist = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
|
224 |
+
|
225 |
+
def forward(self, x):
|
226 |
+
if self.distillation:
|
227 |
+
x = self.classifier(x), self.classifier_dist(x)
|
228 |
+
if not self.training:
|
229 |
+
x = (x[0] + x[1]) / 2
|
230 |
+
else:
|
231 |
+
x = self.classifier(x)
|
232 |
+
return x
|
233 |
+
|
234 |
+
@torch.no_grad()
|
235 |
+
def fuse(self):
|
236 |
+
classifier = self.classifier.fuse()
|
237 |
+
if self.distillation:
|
238 |
+
classifier_dist = self.classifier_dist.fuse()
|
239 |
+
classifier.weight += classifier_dist.weight
|
240 |
+
classifier.bias += classifier_dist.bias
|
241 |
+
classifier.weight /= 2
|
242 |
+
classifier.bias /= 2
|
243 |
+
return classifier
|
244 |
+
else:
|
245 |
+
return classifier
|
246 |
+
|
247 |
+
class RepViT(nn.Module):
|
248 |
+
def __init__(self, cfgs, num_classes=1000, distillation=False, img_size=1024):
|
249 |
+
super(RepViT, self).__init__()
|
250 |
+
# setting of inverted residual blocks
|
251 |
+
self.cfgs = cfgs
|
252 |
+
|
253 |
+
self.img_size = img_size
|
254 |
+
|
255 |
+
# building first layer
|
256 |
+
input_channel = self.cfgs[0][2]
|
257 |
+
patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
|
258 |
+
Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
|
259 |
+
layers = [patch_embed]
|
260 |
+
# building inverted residual blocks
|
261 |
+
block = RepViTBlock
|
262 |
+
for k, t, c, use_se, use_hs, s in self.cfgs:
|
263 |
+
output_channel = _make_divisible(c, 8)
|
264 |
+
exp_size = _make_divisible(input_channel * t, 8)
|
265 |
+
layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
|
266 |
+
input_channel = output_channel
|
267 |
+
self.features = nn.ModuleList(layers)
|
268 |
+
# self.classifier = Classfier(output_channel, num_classes, distillation)
|
269 |
+
|
270 |
+
self.neck = nn.Sequential(
|
271 |
+
nn.Conv2d(
|
272 |
+
output_channel,
|
273 |
+
256,
|
274 |
+
kernel_size=1,
|
275 |
+
bias=False,
|
276 |
+
),
|
277 |
+
LayerNorm2d(256),
|
278 |
+
nn.Conv2d(
|
279 |
+
256,
|
280 |
+
256,
|
281 |
+
kernel_size=3,
|
282 |
+
padding=1,
|
283 |
+
bias=False,
|
284 |
+
),
|
285 |
+
LayerNorm2d(256),
|
286 |
+
)
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
# x = self.features(x)
|
290 |
+
for f in self.features:
|
291 |
+
x = f(x)
|
292 |
+
# x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
|
293 |
+
x = self.neck(x)
|
294 |
+
return x, None
|
295 |
+
|
296 |
+
from timm.models import register_model
|
297 |
+
|
298 |
+
@register_model
|
299 |
+
def repvit(pretrained=False, num_classes = 1000, distillation=False, **kwargs):
|
300 |
+
"""
|
301 |
+
Constructs a MobileNetV3-Large model
|
302 |
+
"""
|
303 |
+
cfgs = [
|
304 |
+
# k, t, c, SE, HS, s
|
305 |
+
[3, 2, 80, 1, 0, 1],
|
306 |
+
[3, 2, 80, 0, 0, 1],
|
307 |
+
[3, 2, 80, 1, 0, 1],
|
308 |
+
[3, 2, 80, 0, 0, 1],
|
309 |
+
[3, 2, 80, 1, 0, 1],
|
310 |
+
[3, 2, 80, 0, 0, 1],
|
311 |
+
[3, 2, 80, 0, 0, 1],
|
312 |
+
[3, 2, 160, 0, 0, 2],
|
313 |
+
[3, 2, 160, 1, 0, 1],
|
314 |
+
[3, 2, 160, 0, 0, 1],
|
315 |
+
[3, 2, 160, 1, 0, 1],
|
316 |
+
[3, 2, 160, 0, 0, 1],
|
317 |
+
[3, 2, 160, 1, 0, 1],
|
318 |
+
[3, 2, 160, 0, 0, 1],
|
319 |
+
[3, 2, 160, 0, 0, 1],
|
320 |
+
[3, 2, 320, 0, 1, 2],
|
321 |
+
[3, 2, 320, 1, 1, 1],
|
322 |
+
[3, 2, 320, 0, 1, 1],
|
323 |
+
[3, 2, 320, 1, 1, 1],
|
324 |
+
[3, 2, 320, 0, 1, 1],
|
325 |
+
[3, 2, 320, 1, 1, 1],
|
326 |
+
[3, 2, 320, 0, 1, 1],
|
327 |
+
[3, 2, 320, 1, 1, 1],
|
328 |
+
[3, 2, 320, 0, 1, 1],
|
329 |
+
[3, 2, 320, 1, 1, 1],
|
330 |
+
[3, 2, 320, 0, 1, 1],
|
331 |
+
[3, 2, 320, 1, 1, 1],
|
332 |
+
[3, 2, 320, 0, 1, 1],
|
333 |
+
[3, 2, 320, 1, 1, 1],
|
334 |
+
[3, 2, 320, 0, 1, 1],
|
335 |
+
[3, 2, 320, 1, 1, 1],
|
336 |
+
[3, 2, 320, 0, 1, 1],
|
337 |
+
[3, 2, 320, 1, 1, 1],
|
338 |
+
[3, 2, 320, 0, 1, 1],
|
339 |
+
[3, 2, 320, 1, 1, 1],
|
340 |
+
[3, 2, 320, 0, 1, 1],
|
341 |
+
[3, 2, 320, 1, 1, 1],
|
342 |
+
[3, 2, 320, 0, 1, 1],
|
343 |
+
[3, 2, 320, 1, 1, 1],
|
344 |
+
[3, 2, 320, 0, 1, 1],
|
345 |
+
[3, 2, 320, 1, 1, 1],
|
346 |
+
[3, 2, 320, 0, 1, 1],
|
347 |
+
[3, 2, 320, 1, 1, 1],
|
348 |
+
[3, 2, 320, 0, 1, 1],
|
349 |
+
[3, 2, 320, 1, 1, 1],
|
350 |
+
[3, 2, 320, 0, 1, 1],
|
351 |
+
[3, 2, 320, 1, 1, 1],
|
352 |
+
[3, 2, 320, 0, 1, 1],
|
353 |
+
[3, 2, 320, 1, 1, 1],
|
354 |
+
[3, 2, 320, 0, 1, 1],
|
355 |
+
# [3, 2, 320, 1, 1, 1],
|
356 |
+
# [3, 2, 320, 0, 1, 1],
|
357 |
+
[3, 2, 320, 0, 1, 1],
|
358 |
+
[3, 2, 640, 0, 1, 2],
|
359 |
+
[3, 2, 640, 1, 1, 1],
|
360 |
+
[3, 2, 640, 0, 1, 1],
|
361 |
+
# [3, 2, 640, 1, 1, 1],
|
362 |
+
# [3, 2, 640, 0, 1, 1]
|
363 |
+
]
|
364 |
+
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
|
EfficientSAM/RepViTSAM/setup_repvit_sam.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from functools import partial
|
9 |
+
from segment_anything.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
|
10 |
+
from RepViTSAM import repvit
|
11 |
+
from timm.models import create_model
|
12 |
+
|
13 |
+
def build_sam_repvit(checkpoint=None):
|
14 |
+
prompt_embed_dim = 256
|
15 |
+
image_size = 1024
|
16 |
+
vit_patch_size = 16
|
17 |
+
image_embedding_size = image_size // vit_patch_size
|
18 |
+
repvit_sam = Sam(
|
19 |
+
image_encoder=create_model('repvit'),
|
20 |
+
prompt_encoder=PromptEncoder(
|
21 |
+
embed_dim=prompt_embed_dim,
|
22 |
+
image_embedding_size=(image_embedding_size, image_embedding_size),
|
23 |
+
input_image_size=(image_size, image_size),
|
24 |
+
mask_in_chans=16,
|
25 |
+
),
|
26 |
+
mask_decoder=MaskDecoder(
|
27 |
+
num_multimask_outputs=3,
|
28 |
+
transformer=TwoWayTransformer(
|
29 |
+
depth=2,
|
30 |
+
embedding_dim=prompt_embed_dim,
|
31 |
+
mlp_dim=2048,
|
32 |
+
num_heads=8,
|
33 |
+
),
|
34 |
+
transformer_dim=prompt_embed_dim,
|
35 |
+
iou_head_depth=3,
|
36 |
+
iou_head_hidden_dim=256,
|
37 |
+
),
|
38 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
39 |
+
pixel_std=[58.395, 57.12, 57.375],
|
40 |
+
)
|
41 |
+
|
42 |
+
repvit_sam.eval()
|
43 |
+
if checkpoint is not None:
|
44 |
+
with open(checkpoint, "rb") as f:
|
45 |
+
state_dict = torch.load(f)
|
46 |
+
repvit_sam.load_state_dict(state_dict)
|
47 |
+
return repvit_sam
|
48 |
+
|
49 |
+
from functools import partial
|
50 |
+
|
51 |
+
sam_model_registry = {
|
52 |
+
"repvit": partial(build_sam_repvit),
|
53 |
+
}
|
EfficientSAM/grounded_edge_sam.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import supervision as sv
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
|
8 |
+
from groundingdino.util.inference import Model
|
9 |
+
from segment_anything import SamPredictor
|
10 |
+
from EdgeSAM.setup_edge_sam import build_edge_sam
|
11 |
+
|
12 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
13 |
+
|
14 |
+
# GroundingDINO config and checkpoint
|
15 |
+
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
16 |
+
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
|
17 |
+
|
18 |
+
# Building GroundingDINO inference model
|
19 |
+
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
|
20 |
+
|
21 |
+
# Building MobileSAM predictor
|
22 |
+
EdgeSAM_CHECKPOINT_PATH = "./EfficientSAM/edge_sam_3x.pth"
|
23 |
+
edge_sam = build_edge_sam(checkpoint=EdgeSAM_CHECKPOINT_PATH)
|
24 |
+
edge_sam.to(device=DEVICE)
|
25 |
+
|
26 |
+
sam_predictor = SamPredictor(edge_sam)
|
27 |
+
|
28 |
+
|
29 |
+
# Predict classes and hyper-param for GroundingDINO
|
30 |
+
SOURCE_IMAGE_PATH = "./EfficientSAM/LightHQSAM/example_light_hqsam.png"
|
31 |
+
CLASSES = ["bench"]
|
32 |
+
BOX_THRESHOLD = 0.25
|
33 |
+
TEXT_THRESHOLD = 0.25
|
34 |
+
NMS_THRESHOLD = 0.8
|
35 |
+
|
36 |
+
|
37 |
+
# load image
|
38 |
+
image = cv2.imread(SOURCE_IMAGE_PATH)
|
39 |
+
|
40 |
+
# detect objects
|
41 |
+
detections = grounding_dino_model.predict_with_classes(
|
42 |
+
image=image,
|
43 |
+
classes=CLASSES,
|
44 |
+
box_threshold=BOX_THRESHOLD,
|
45 |
+
text_threshold=TEXT_THRESHOLD
|
46 |
+
)
|
47 |
+
|
48 |
+
# annotate image with detections
|
49 |
+
box_annotator = sv.BoxAnnotator()
|
50 |
+
labels = [
|
51 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
52 |
+
for _, _, confidence, class_id, _, _
|
53 |
+
in detections]
|
54 |
+
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
|
55 |
+
|
56 |
+
# save the annotated grounding dino image
|
57 |
+
cv2.imwrite("EfficientSAM/LightHQSAM/groundingdino_annotated_image.jpg", annotated_frame)
|
58 |
+
|
59 |
+
|
60 |
+
# NMS post process
|
61 |
+
print(f"Before NMS: {len(detections.xyxy)} boxes")
|
62 |
+
nms_idx = torchvision.ops.nms(
|
63 |
+
torch.from_numpy(detections.xyxy),
|
64 |
+
torch.from_numpy(detections.confidence),
|
65 |
+
NMS_THRESHOLD
|
66 |
+
).numpy().tolist()
|
67 |
+
|
68 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
69 |
+
detections.confidence = detections.confidence[nms_idx]
|
70 |
+
detections.class_id = detections.class_id[nms_idx]
|
71 |
+
|
72 |
+
print(f"After NMS: {len(detections.xyxy)} boxes")
|
73 |
+
|
74 |
+
# Prompting SAM with detected boxes
|
75 |
+
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
|
76 |
+
sam_predictor.set_image(image)
|
77 |
+
result_masks = []
|
78 |
+
for box in xyxy:
|
79 |
+
masks, scores, logits = sam_predictor.predict(
|
80 |
+
box=box,
|
81 |
+
multimask_output=False,
|
82 |
+
hq_token_only=True,
|
83 |
+
)
|
84 |
+
index = np.argmax(scores)
|
85 |
+
result_masks.append(masks[index])
|
86 |
+
return np.array(result_masks)
|
87 |
+
|
88 |
+
|
89 |
+
# convert detections to masks
|
90 |
+
detections.mask = segment(
|
91 |
+
sam_predictor=sam_predictor,
|
92 |
+
image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
|
93 |
+
xyxy=detections.xyxy
|
94 |
+
)
|
95 |
+
|
96 |
+
# annotate image with detections
|
97 |
+
box_annotator = sv.BoxAnnotator()
|
98 |
+
mask_annotator = sv.MaskAnnotator()
|
99 |
+
labels = [
|
100 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
101 |
+
for _, _, confidence, class_id, _, _
|
102 |
+
in detections]
|
103 |
+
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
|
104 |
+
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
|
105 |
+
|
106 |
+
# save the annotated grounded-sam image
|
107 |
+
cv2.imwrite("EfficientSAM/grounded_edge_sam_annotated_image.jpg", annotated_image)
|
EfficientSAM/grounded_efficient_sam.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import supervision as sv
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
from torchvision.transforms import ToTensor
|
8 |
+
|
9 |
+
from groundingdino.util.inference import Model
|
10 |
+
|
11 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
+
|
13 |
+
# GroundingDINO config and checkpoint
|
14 |
+
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
15 |
+
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
|
16 |
+
|
17 |
+
# Building GroundingDINO inference model
|
18 |
+
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
|
19 |
+
|
20 |
+
# Building MobileSAM predictor
|
21 |
+
EFFICIENT_SAM_CHECHPOINT_PATH = "./EfficientSAM/efficientsam_s_gpu.jit"
|
22 |
+
efficientsam = torch.jit.load(EFFICIENT_SAM_CHECHPOINT_PATH)
|
23 |
+
|
24 |
+
|
25 |
+
# Predict classes and hyper-param for GroundingDINO
|
26 |
+
SOURCE_IMAGE_PATH = "./EfficientSAM/LightHQSAM/example_light_hqsam.png"
|
27 |
+
CLASSES = ["bench"]
|
28 |
+
BOX_THRESHOLD = 0.25
|
29 |
+
TEXT_THRESHOLD = 0.25
|
30 |
+
NMS_THRESHOLD = 0.8
|
31 |
+
|
32 |
+
|
33 |
+
# load image
|
34 |
+
image = cv2.imread(SOURCE_IMAGE_PATH)
|
35 |
+
|
36 |
+
# detect objects
|
37 |
+
detections = grounding_dino_model.predict_with_classes(
|
38 |
+
image=image,
|
39 |
+
classes=CLASSES,
|
40 |
+
box_threshold=BOX_THRESHOLD,
|
41 |
+
text_threshold=TEXT_THRESHOLD
|
42 |
+
)
|
43 |
+
|
44 |
+
# annotate image with detections
|
45 |
+
box_annotator = sv.BoxAnnotator()
|
46 |
+
labels = [
|
47 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
48 |
+
for _, _, confidence, class_id, _, _
|
49 |
+
in detections]
|
50 |
+
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
|
51 |
+
|
52 |
+
# save the annotated grounding dino image
|
53 |
+
cv2.imwrite("EfficientSAM/LightHQSAM/groundingdino_annotated_image.jpg", annotated_frame)
|
54 |
+
|
55 |
+
|
56 |
+
# NMS post process
|
57 |
+
print(f"Before NMS: {len(detections.xyxy)} boxes")
|
58 |
+
nms_idx = torchvision.ops.nms(
|
59 |
+
torch.from_numpy(detections.xyxy),
|
60 |
+
torch.from_numpy(detections.confidence),
|
61 |
+
NMS_THRESHOLD
|
62 |
+
).numpy().tolist()
|
63 |
+
|
64 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
65 |
+
detections.confidence = detections.confidence[nms_idx]
|
66 |
+
detections.class_id = detections.class_id[nms_idx]
|
67 |
+
|
68 |
+
print(f"After NMS: {len(detections.xyxy)} boxes")
|
69 |
+
|
70 |
+
|
71 |
+
def efficient_sam_box_prompt_segment(image, pts_sampled, model):
|
72 |
+
bbox = torch.reshape(torch.tensor(pts_sampled), [1, 1, 2, 2])
|
73 |
+
bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
|
74 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
75 |
+
img_tensor = ToTensor()(image)
|
76 |
+
|
77 |
+
predicted_logits, predicted_iou = model(
|
78 |
+
img_tensor[None, ...].cuda(),
|
79 |
+
bbox.cuda(),
|
80 |
+
bbox_labels.cuda(),
|
81 |
+
)
|
82 |
+
predicted_logits = predicted_logits.cpu()
|
83 |
+
all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
|
84 |
+
predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
|
85 |
+
|
86 |
+
max_predicted_iou = -1
|
87 |
+
selected_mask_using_predicted_iou = None
|
88 |
+
for m in range(all_masks.shape[0]):
|
89 |
+
curr_predicted_iou = predicted_iou[m]
|
90 |
+
if (
|
91 |
+
curr_predicted_iou > max_predicted_iou
|
92 |
+
or selected_mask_using_predicted_iou is None
|
93 |
+
):
|
94 |
+
max_predicted_iou = curr_predicted_iou
|
95 |
+
selected_mask_using_predicted_iou = all_masks[m]
|
96 |
+
return selected_mask_using_predicted_iou
|
97 |
+
|
98 |
+
|
99 |
+
# collect segment results from EfficientSAM
|
100 |
+
result_masks = []
|
101 |
+
for box in detections.xyxy:
|
102 |
+
mask = efficient_sam_box_prompt_segment(image, box, efficientsam)
|
103 |
+
result_masks.append(mask)
|
104 |
+
|
105 |
+
detections.mask = np.array(result_masks)
|
106 |
+
|
107 |
+
# annotate image with detections
|
108 |
+
box_annotator = sv.BoxAnnotator()
|
109 |
+
mask_annotator = sv.MaskAnnotator()
|
110 |
+
labels = [
|
111 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
112 |
+
for _, _, confidence, class_id, _, _
|
113 |
+
in detections]
|
114 |
+
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
|
115 |
+
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
|
116 |
+
|
117 |
+
# save the annotated grounded-sam image
|
118 |
+
cv2.imwrite("EfficientSAM/gronded_efficient_sam_anontated_image.jpg", annotated_image)
|
EfficientSAM/grounded_fast_sam.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import cv2
|
3 |
+
from ultralytics import YOLO
|
4 |
+
from FastSAM.tools import *
|
5 |
+
from groundingdino.util.inference import load_model, load_image, predict, annotate, Model
|
6 |
+
from torchvision.ops import box_convert
|
7 |
+
import ast
|
8 |
+
|
9 |
+
def parse_args():
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument(
|
12 |
+
"--model_path", type=str, default="./FastSAM/FastSAM-x.pt", help="model"
|
13 |
+
)
|
14 |
+
parser.add_argument(
|
15 |
+
"--img_path", type=str, default="./images/dogs.jpg", help="path to image file"
|
16 |
+
)
|
17 |
+
parser.add_argument(
|
18 |
+
"--text", type=str, default="the black dog.", help="text prompt for GroundingDINO"
|
19 |
+
)
|
20 |
+
parser.add_argument("--imgsz", type=int, default=1024, help="image size")
|
21 |
+
parser.add_argument(
|
22 |
+
"--iou",
|
23 |
+
type=float,
|
24 |
+
default=0.9,
|
25 |
+
help="iou threshold for filtering the annotations",
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"--conf", type=float, default=0.4, help="object confidence threshold"
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--output", type=str, default="./output/", help="image save path"
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"--randomcolor", type=bool, default=True, help="mask random color"
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]"
|
38 |
+
)
|
39 |
+
parser.add_argument(
|
40 |
+
"--point_label",
|
41 |
+
type=str,
|
42 |
+
default="[0]",
|
43 |
+
help="[1,0] 0:background, 1:foreground",
|
44 |
+
)
|
45 |
+
parser.add_argument("--box_prompt", type=str, default="[0,0,0,0]", help="[x,y,w,h]")
|
46 |
+
parser.add_argument(
|
47 |
+
"--better_quality",
|
48 |
+
type=str,
|
49 |
+
default=False,
|
50 |
+
help="better quality using morphologyEx",
|
51 |
+
)
|
52 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
53 |
+
parser.add_argument(
|
54 |
+
"--device", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu"
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--retina",
|
58 |
+
type=bool,
|
59 |
+
default=True,
|
60 |
+
help="draw high-resolution segmentation masks",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--withContours", type=bool, default=False, help="draw the edges of the masks"
|
64 |
+
)
|
65 |
+
return parser.parse_args()
|
66 |
+
|
67 |
+
|
68 |
+
def main(args):
|
69 |
+
|
70 |
+
# Image Path
|
71 |
+
img_path = args.img_path
|
72 |
+
text = args.text
|
73 |
+
|
74 |
+
# path to save img
|
75 |
+
save_path = args.output
|
76 |
+
if not os.path.exists(save_path):
|
77 |
+
os.makedirs(save_path)
|
78 |
+
basename = os.path.basename(args.img_path).split(".")[0]
|
79 |
+
|
80 |
+
# Build Fast-SAM Model
|
81 |
+
# ckpt_path = "/comp_robot/rentianhe/code/Grounded-Segment-Anything/FastSAM/FastSAM-x.pt"
|
82 |
+
model = YOLO(args.model_path)
|
83 |
+
|
84 |
+
results = model(
|
85 |
+
args.img_path,
|
86 |
+
imgsz=args.imgsz,
|
87 |
+
device=args.device,
|
88 |
+
retina_masks=args.retina,
|
89 |
+
iou=args.iou,
|
90 |
+
conf=args.conf,
|
91 |
+
max_det=100,
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
# Build GroundingDINO Model
|
96 |
+
groundingdino_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
97 |
+
groundingdino_ckpt_path = "./groundingdino_swint_ogc.pth"
|
98 |
+
|
99 |
+
image_source, image = load_image(img_path)
|
100 |
+
model = load_model(groundingdino_config, groundingdino_ckpt_path)
|
101 |
+
|
102 |
+
boxes, logits, phrases = predict(
|
103 |
+
model=model,
|
104 |
+
image=image,
|
105 |
+
caption=text,
|
106 |
+
box_threshold=0.3,
|
107 |
+
text_threshold=0.25,
|
108 |
+
device=args.device,
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
+
# Grounded-Fast-SAM
|
113 |
+
|
114 |
+
ori_img = cv2.imread(img_path)
|
115 |
+
ori_h = ori_img.shape[0]
|
116 |
+
ori_w = ori_img.shape[1]
|
117 |
+
|
118 |
+
# Save each frame due to the post process from FastSAM
|
119 |
+
boxes = boxes * torch.Tensor([ori_w, ori_h, ori_w, ori_h])
|
120 |
+
print(f"Detected Boxes: {len(boxes)}")
|
121 |
+
boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").cpu().numpy().tolist()
|
122 |
+
for box_idx in range(len(boxes)):
|
123 |
+
mask, _ = box_prompt(
|
124 |
+
results[0].masks.data,
|
125 |
+
boxes[box_idx],
|
126 |
+
ori_h,
|
127 |
+
ori_w,
|
128 |
+
)
|
129 |
+
annotations = np.array([mask])
|
130 |
+
img_array = fast_process(
|
131 |
+
annotations=annotations,
|
132 |
+
args=args,
|
133 |
+
mask_random_color=True,
|
134 |
+
bbox=boxes[box_idx],
|
135 |
+
)
|
136 |
+
cv2.imwrite(os.path.join(save_path, basename + f"_{str(box_idx)}_caption_{phrases[box_idx]}.jpg"), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
args = parse_args()
|
141 |
+
main(args)
|
EfficientSAM/grounded_light_hqsam.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import supervision as sv
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
|
8 |
+
from groundingdino.util.inference import Model
|
9 |
+
from segment_anything import SamPredictor
|
10 |
+
from LightHQSAM.setup_light_hqsam import setup_model
|
11 |
+
|
12 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
13 |
+
|
14 |
+
# GroundingDINO config and checkpoint
|
15 |
+
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
16 |
+
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
|
17 |
+
|
18 |
+
# Building GroundingDINO inference model
|
19 |
+
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
|
20 |
+
|
21 |
+
# Building MobileSAM predictor
|
22 |
+
HQSAM_CHECKPOINT_PATH = "./EfficientSAM/sam_hq_vit_tiny.pth"
|
23 |
+
checkpoint = torch.load(HQSAM_CHECKPOINT_PATH)
|
24 |
+
light_hqsam = setup_model()
|
25 |
+
light_hqsam.load_state_dict(checkpoint, strict=True)
|
26 |
+
light_hqsam.to(device=DEVICE)
|
27 |
+
|
28 |
+
sam_predictor = SamPredictor(light_hqsam)
|
29 |
+
|
30 |
+
|
31 |
+
# Predict classes and hyper-param for GroundingDINO
|
32 |
+
SOURCE_IMAGE_PATH = "./EfficientSAM/LightHQSAM/example_light_hqsam.png"
|
33 |
+
CLASSES = ["bench"]
|
34 |
+
BOX_THRESHOLD = 0.25
|
35 |
+
TEXT_THRESHOLD = 0.25
|
36 |
+
NMS_THRESHOLD = 0.8
|
37 |
+
|
38 |
+
|
39 |
+
# load image
|
40 |
+
image = cv2.imread(SOURCE_IMAGE_PATH)
|
41 |
+
|
42 |
+
# detect objects
|
43 |
+
detections = grounding_dino_model.predict_with_classes(
|
44 |
+
image=image,
|
45 |
+
classes=CLASSES,
|
46 |
+
box_threshold=BOX_THRESHOLD,
|
47 |
+
text_threshold=TEXT_THRESHOLD
|
48 |
+
)
|
49 |
+
|
50 |
+
# annotate image with detections
|
51 |
+
box_annotator = sv.BoxAnnotator()
|
52 |
+
labels = [
|
53 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
54 |
+
for _, _, confidence, class_id, _, _
|
55 |
+
in detections]
|
56 |
+
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
|
57 |
+
|
58 |
+
# save the annotated grounding dino image
|
59 |
+
cv2.imwrite("EfficientSAM/LightHQSAM/groundingdino_annotated_image.jpg", annotated_frame)
|
60 |
+
|
61 |
+
|
62 |
+
# NMS post process
|
63 |
+
print(f"Before NMS: {len(detections.xyxy)} boxes")
|
64 |
+
nms_idx = torchvision.ops.nms(
|
65 |
+
torch.from_numpy(detections.xyxy),
|
66 |
+
torch.from_numpy(detections.confidence),
|
67 |
+
NMS_THRESHOLD
|
68 |
+
).numpy().tolist()
|
69 |
+
|
70 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
71 |
+
detections.confidence = detections.confidence[nms_idx]
|
72 |
+
detections.class_id = detections.class_id[nms_idx]
|
73 |
+
|
74 |
+
print(f"After NMS: {len(detections.xyxy)} boxes")
|
75 |
+
|
76 |
+
# Prompting SAM with detected boxes
|
77 |
+
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
|
78 |
+
sam_predictor.set_image(image)
|
79 |
+
result_masks = []
|
80 |
+
for box in xyxy:
|
81 |
+
masks, scores, logits = sam_predictor.predict(
|
82 |
+
box=box,
|
83 |
+
multimask_output=False,
|
84 |
+
hq_token_only=True,
|
85 |
+
)
|
86 |
+
index = np.argmax(scores)
|
87 |
+
result_masks.append(masks[index])
|
88 |
+
return np.array(result_masks)
|
89 |
+
|
90 |
+
|
91 |
+
# convert detections to masks
|
92 |
+
detections.mask = segment(
|
93 |
+
sam_predictor=sam_predictor,
|
94 |
+
image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
|
95 |
+
xyxy=detections.xyxy
|
96 |
+
)
|
97 |
+
|
98 |
+
# annotate image with detections
|
99 |
+
box_annotator = sv.BoxAnnotator()
|
100 |
+
mask_annotator = sv.MaskAnnotator()
|
101 |
+
labels = [
|
102 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
103 |
+
for _, _, confidence, class_id, _, _
|
104 |
+
in detections]
|
105 |
+
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
|
106 |
+
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
|
107 |
+
|
108 |
+
# save the annotated grounded-sam image
|
109 |
+
cv2.imwrite("EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg", annotated_image)
|
EfficientSAM/grounded_mobile_sam.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import supervision as sv
|
4 |
+
import argparse
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
|
8 |
+
from groundingdino.util.inference import Model
|
9 |
+
from segment_anything import SamPredictor
|
10 |
+
from MobileSAM.setup_mobile_sam import setup_model
|
11 |
+
|
12 |
+
def parse_args():
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument(
|
15 |
+
"--MOBILE_SAM_CHECKPOINT_PATH", type=str, default="./EfficientSAM/mobile_sam.pt", help="model"
|
16 |
+
)
|
17 |
+
parser.add_argument(
|
18 |
+
"--SOURCE_IMAGE_PATH", type=str, default="./assets/demo2.jpg", help="path to image file"
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
"--CAPTION", type=str, default="The running dog", help="text prompt for GroundingDINO"
|
22 |
+
)
|
23 |
+
parser.add_argument(
|
24 |
+
"--OUT_FILE_BOX", type=str, default="groundingdino_annotated_image.jpg", help="the output filename"
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--OUT_FILE_SEG", type=str, default="grounded_mobile_sam_annotated_image.jpg", help="the output filename"
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--OUT_FILE_BIN_MASK", type=str, default="grounded_mobile_sam_bin_mask.jpg", help="the output filename"
|
31 |
+
)
|
32 |
+
parser.add_argument("--BOX_THRESHOLD", type=float, default=0.25, help="")
|
33 |
+
parser.add_argument("--TEXT_THRESHOLD", type=float, default=0.25, help="")
|
34 |
+
parser.add_argument("--NMS_THRESHOLD", type=float, default=0.8, help="")
|
35 |
+
|
36 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
+
parser.add_argument(
|
38 |
+
"--DEVICE", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu"
|
39 |
+
)
|
40 |
+
return parser.parse_args()
|
41 |
+
|
42 |
+
def main(args):
|
43 |
+
DEVICE = args.DEVICE
|
44 |
+
|
45 |
+
# GroundingDINO config and checkpoint
|
46 |
+
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
47 |
+
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
|
48 |
+
|
49 |
+
# Building GroundingDINO inference model
|
50 |
+
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
|
51 |
+
|
52 |
+
# Building MobileSAM predictor
|
53 |
+
MOBILE_SAM_CHECKPOINT_PATH = args.MOBILE_SAM_CHECKPOINT_PATH
|
54 |
+
checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH)
|
55 |
+
mobile_sam = setup_model()
|
56 |
+
mobile_sam.load_state_dict(checkpoint, strict=True)
|
57 |
+
mobile_sam.to(device=DEVICE)
|
58 |
+
|
59 |
+
sam_predictor = SamPredictor(mobile_sam)
|
60 |
+
|
61 |
+
|
62 |
+
# Predict classes and hyper-param for GroundingDINO
|
63 |
+
SOURCE_IMAGE_PATH = args.SOURCE_IMAGE_PATH
|
64 |
+
CLASSES = [args.CAPTION]
|
65 |
+
BOX_THRESHOLD = args.BOX_THRESHOLD
|
66 |
+
TEXT_THRESHOLD = args.TEXT_THRESHOLD
|
67 |
+
NMS_THRESHOLD = args.NMS_THRESHOLD
|
68 |
+
|
69 |
+
|
70 |
+
# load image
|
71 |
+
image = cv2.imread(SOURCE_IMAGE_PATH)
|
72 |
+
|
73 |
+
# detect objects
|
74 |
+
detections = grounding_dino_model.predict_with_classes(
|
75 |
+
image=image,
|
76 |
+
classes=CLASSES,
|
77 |
+
box_threshold=BOX_THRESHOLD,
|
78 |
+
text_threshold=TEXT_THRESHOLD
|
79 |
+
)
|
80 |
+
|
81 |
+
# annotate image with detections
|
82 |
+
box_annotator = sv.BoxAnnotator()
|
83 |
+
labels = [
|
84 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
85 |
+
for _, _, confidence, class_id, _, _
|
86 |
+
in detections]
|
87 |
+
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
|
88 |
+
|
89 |
+
# save the annotated grounding dino image
|
90 |
+
cv2.imwrite(args.OUT_FILE_BOX, annotated_frame)
|
91 |
+
|
92 |
+
|
93 |
+
# NMS post process
|
94 |
+
print(f"Before NMS: {len(detections.xyxy)} boxes")
|
95 |
+
nms_idx = torchvision.ops.nms(
|
96 |
+
torch.from_numpy(detections.xyxy),
|
97 |
+
torch.from_numpy(detections.confidence),
|
98 |
+
NMS_THRESHOLD
|
99 |
+
).numpy().tolist()
|
100 |
+
|
101 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
102 |
+
detections.confidence = detections.confidence[nms_idx]
|
103 |
+
detections.class_id = detections.class_id[nms_idx]
|
104 |
+
|
105 |
+
print(f"After NMS: {len(detections.xyxy)} boxes")
|
106 |
+
|
107 |
+
# Prompting SAM with detected boxes
|
108 |
+
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
|
109 |
+
sam_predictor.set_image(image)
|
110 |
+
result_masks = []
|
111 |
+
for box in xyxy:
|
112 |
+
masks, scores, logits = sam_predictor.predict(
|
113 |
+
box=box,
|
114 |
+
multimask_output=True
|
115 |
+
)
|
116 |
+
index = np.argmax(scores)
|
117 |
+
result_masks.append(masks[index])
|
118 |
+
return np.array(result_masks)
|
119 |
+
|
120 |
+
|
121 |
+
# convert detections to masks
|
122 |
+
detections.mask = segment(
|
123 |
+
sam_predictor=sam_predictor,
|
124 |
+
image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
|
125 |
+
xyxy=detections.xyxy
|
126 |
+
)
|
127 |
+
|
128 |
+
binary_mask = detections.mask[0].astype(np.uint8)*255
|
129 |
+
cv2.imwrite(args.OUT_FILE_BIN_MASK, binary_mask)
|
130 |
+
|
131 |
+
# annotate image with detections
|
132 |
+
box_annotator = sv.BoxAnnotator()
|
133 |
+
mask_annotator = sv.MaskAnnotator()
|
134 |
+
labels = [
|
135 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
136 |
+
for _, _, confidence, class_id, _, _
|
137 |
+
in detections]
|
138 |
+
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
|
139 |
+
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
|
140 |
+
# save the annotated grounded-sam image
|
141 |
+
cv2.imwrite(args.OUT_FILE_SEG, annotated_image)
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
args = parse_args()
|
145 |
+
main(args)
|
EfficientSAM/grounded_repvit_sam.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import supervision as sv
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
|
8 |
+
from groundingdino.util.inference import Model
|
9 |
+
from segment_anything import SamPredictor
|
10 |
+
from RepViTSAM.setup_repvit_sam import build_sam_repvit
|
11 |
+
|
12 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
13 |
+
|
14 |
+
# GroundingDINO config and checkpoint
|
15 |
+
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
16 |
+
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
|
17 |
+
|
18 |
+
# Building GroundingDINO inference model
|
19 |
+
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
|
20 |
+
|
21 |
+
# Building MobileSAM predictor
|
22 |
+
RepViTSAM_CHECKPOINT_PATH = "./EfficientSAM/repvit_sam.pt"
|
23 |
+
repvit_sam = build_sam_repvit(checkpoint=RepViTSAM_CHECKPOINT_PATH)
|
24 |
+
repvit_sam.to(device=DEVICE)
|
25 |
+
|
26 |
+
sam_predictor = SamPredictor(repvit_sam)
|
27 |
+
|
28 |
+
|
29 |
+
# Predict classes and hyper-param for GroundingDINO
|
30 |
+
SOURCE_IMAGE_PATH = "./EfficientSAM/LightHQSAM/example_light_hqsam.png"
|
31 |
+
CLASSES = ["bench"]
|
32 |
+
BOX_THRESHOLD = 0.25
|
33 |
+
TEXT_THRESHOLD = 0.25
|
34 |
+
NMS_THRESHOLD = 0.8
|
35 |
+
|
36 |
+
|
37 |
+
# load image
|
38 |
+
image = cv2.imread(SOURCE_IMAGE_PATH)
|
39 |
+
|
40 |
+
# detect objects
|
41 |
+
detections = grounding_dino_model.predict_with_classes(
|
42 |
+
image=image,
|
43 |
+
classes=CLASSES,
|
44 |
+
box_threshold=BOX_THRESHOLD,
|
45 |
+
text_threshold=TEXT_THRESHOLD
|
46 |
+
)
|
47 |
+
|
48 |
+
# annotate image with detections
|
49 |
+
box_annotator = sv.BoxAnnotator()
|
50 |
+
labels = [
|
51 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
52 |
+
for _, _, confidence, class_id, _, _
|
53 |
+
in detections]
|
54 |
+
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
|
55 |
+
|
56 |
+
# save the annotated grounding dino image
|
57 |
+
cv2.imwrite("EfficientSAM/LightHQSAM/groundingdino_annotated_image.jpg", annotated_frame)
|
58 |
+
|
59 |
+
|
60 |
+
# NMS post process
|
61 |
+
print(f"Before NMS: {len(detections.xyxy)} boxes")
|
62 |
+
nms_idx = torchvision.ops.nms(
|
63 |
+
torch.from_numpy(detections.xyxy),
|
64 |
+
torch.from_numpy(detections.confidence),
|
65 |
+
NMS_THRESHOLD
|
66 |
+
).numpy().tolist()
|
67 |
+
|
68 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
69 |
+
detections.confidence = detections.confidence[nms_idx]
|
70 |
+
detections.class_id = detections.class_id[nms_idx]
|
71 |
+
|
72 |
+
print(f"After NMS: {len(detections.xyxy)} boxes")
|
73 |
+
|
74 |
+
# Prompting SAM with detected boxes
|
75 |
+
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
|
76 |
+
sam_predictor.set_image(image)
|
77 |
+
result_masks = []
|
78 |
+
for box in xyxy:
|
79 |
+
masks, scores, logits = sam_predictor.predict(
|
80 |
+
box=box,
|
81 |
+
multimask_output=False,
|
82 |
+
hq_token_only=True,
|
83 |
+
)
|
84 |
+
index = np.argmax(scores)
|
85 |
+
result_masks.append(masks[index])
|
86 |
+
return np.array(result_masks)
|
87 |
+
|
88 |
+
|
89 |
+
# convert detections to masks
|
90 |
+
detections.mask = segment(
|
91 |
+
sam_predictor=sam_predictor,
|
92 |
+
image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
|
93 |
+
xyxy=detections.xyxy
|
94 |
+
)
|
95 |
+
|
96 |
+
# annotate image with detections
|
97 |
+
box_annotator = sv.BoxAnnotator()
|
98 |
+
mask_annotator = sv.MaskAnnotator()
|
99 |
+
labels = [
|
100 |
+
f"{CLASSES[class_id]} {confidence:0.2f}"
|
101 |
+
for _, _, confidence, class_id, _, _
|
102 |
+
in detections]
|
103 |
+
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
|
104 |
+
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
|
105 |
+
|
106 |
+
# save the annotated grounded-sam image
|
107 |
+
cv2.imwrite("EfficientSAM/grounded_repvit_sam_annotated_image.jpg", annotated_image)
|
GroundingDINO/.asset/COCO.png
ADDED
GroundingDINO/.asset/GD_GLIGEN.png
ADDED
Git LFS Details
|
GroundingDINO/.asset/GD_SD.png
ADDED
Git LFS Details
|
GroundingDINO/.asset/ODinW.png
ADDED
GroundingDINO/.asset/arch.png
ADDED
GroundingDINO/.asset/cats.png
ADDED
GroundingDINO/.asset/hero_figure.png
ADDED
Git LFS Details
|
GroundingDINO/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2020 - present, Facebook, Inc
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
GroundingDINO/README.md
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Grounding DINO
|
2 |
+
|
3 |
+
---
|
4 |
+
|
5 |
+
[![arXiv](https://img.shields.io/badge/arXiv-2303.05499-b31b1b.svg)](https://arxiv.org/abs/2303.05499)
|
6 |
+
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/wxWDt5UiwY8)
|
7 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
|
8 |
+
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/cMa77r3YrDk)
|
9 |
+
[![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)
|
10 |
+
|
11 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
|
12 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
|
13 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
|
14 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
Official PyTorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now!
|
19 |
+
|
20 |
+
|
21 |
+
## Highlight
|
22 |
+
|
23 |
+
- **Open-Set Detection.** Detect **everything** with language!
|
24 |
+
- **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
|
25 |
+
- **Flexible.** Collaboration with Stable Diffusion for Image Editting.
|
26 |
+
|
27 |
+
## News
|
28 |
+
[2023/03/28] A YouTube [video](https://youtu.be/cMa77r3YrDk) about Grounding DINO and basic object detection prompt engineering. [[SkalskiP](https://github.com/SkalskiP)] \
|
29 |
+
[2023/03/28] Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space! \
|
30 |
+
[2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
|
31 |
+
[2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. [[SkalskiP](https://github.com/SkalskiP)] \
|
32 |
+
[2023/03/22] Code is available Now!
|
33 |
+
|
34 |
+
<details open>
|
35 |
+
<summary><font size="4">
|
36 |
+
Description
|
37 |
+
</font></summary>
|
38 |
+
<img src=".asset/hero_figure.png" alt="ODinW" width="100%">
|
39 |
+
</details>
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
## TODO
|
44 |
+
|
45 |
+
- [x] Release inference code and demo.
|
46 |
+
- [x] Release checkpoints.
|
47 |
+
- [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
|
48 |
+
- [ ] Release training codes.
|
49 |
+
|
50 |
+
## Install
|
51 |
+
|
52 |
+
If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. It will be compiled under CPU-only mode if no CUDA available.
|
53 |
+
|
54 |
+
```bash
|
55 |
+
pip install -e .
|
56 |
+
```
|
57 |
+
|
58 |
+
## Demo
|
59 |
+
|
60 |
+
```bash
|
61 |
+
CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
|
62 |
+
-c /path/to/config \
|
63 |
+
-p /path/to/checkpoint \
|
64 |
+
-i .asset/cats.png \
|
65 |
+
-o "outputs/0" \
|
66 |
+
-t "cat ear." \
|
67 |
+
[--cpu-only] # open it for cpu mode
|
68 |
+
```
|
69 |
+
See the `demo/inference_on_a_image.py` for more details.
|
70 |
+
|
71 |
+
**Web UI**
|
72 |
+
|
73 |
+
We also provide a demo code to integrate Grounding DINO with Gradio Web UI. See the file `demo/gradio_app.py` for more details.
|
74 |
+
|
75 |
+
## Checkpoints
|
76 |
+
|
77 |
+
<!-- insert a table -->
|
78 |
+
<table>
|
79 |
+
<thead>
|
80 |
+
<tr style="text-align: right;">
|
81 |
+
<th></th>
|
82 |
+
<th>name</th>
|
83 |
+
<th>backbone</th>
|
84 |
+
<th>Data</th>
|
85 |
+
<th>box AP on COCO</th>
|
86 |
+
<th>Checkpoint</th>
|
87 |
+
<th>Config</th>
|
88 |
+
</tr>
|
89 |
+
</thead>
|
90 |
+
<tbody>
|
91 |
+
<tr>
|
92 |
+
<th>1</th>
|
93 |
+
<td>GroundingDINO-T</td>
|
94 |
+
<td>Swin-T</td>
|
95 |
+
<td>O365,GoldG,Cap4M</td>
|
96 |
+
<td>48.4 (zero-shot) / 57.2 (fine-tune)</td>
|
97 |
+
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">Github link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth">HF link</a></td>
|
98 |
+
<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinT_OGC.py">link</a></td>
|
99 |
+
</tr>
|
100 |
+
</tbody>
|
101 |
+
</table>
|
102 |
+
|
103 |
+
## Results
|
104 |
+
|
105 |
+
<details open>
|
106 |
+
<summary><font size="4">
|
107 |
+
COCO Object Detection Results
|
108 |
+
</font></summary>
|
109 |
+
<img src=".asset/COCO.png" alt="COCO" width="100%">
|
110 |
+
</details>
|
111 |
+
|
112 |
+
<details open>
|
113 |
+
<summary><font size="4">
|
114 |
+
ODinW Object Detection Results
|
115 |
+
</font></summary>
|
116 |
+
<img src=".asset/ODinW.png" alt="ODinW" width="100%">
|
117 |
+
</details>
|
118 |
+
|
119 |
+
<details open>
|
120 |
+
<summary><font size="4">
|
121 |
+
Marrying Grounding DINO with <a href="https://github.com/Stability-AI/StableDiffusion">Stable Diffusion</a> for Image Editing
|
122 |
+
</font></summary>
|
123 |
+
<img src=".asset/GD_SD.png" alt="GD_SD" width="100%">
|
124 |
+
</details>
|
125 |
+
|
126 |
+
<details open>
|
127 |
+
<summary><font size="4">
|
128 |
+
Marrying Grounding DINO with <a href="https://github.com/gligen/GLIGEN">GLIGEN</a> for more Detailed Image Editing
|
129 |
+
</font></summary>
|
130 |
+
<img src=".asset/GD_GLIGEN.png" alt="GD_GLIGEN" width="100%">
|
131 |
+
</details>
|
132 |
+
|
133 |
+
## Model
|
134 |
+
|
135 |
+
Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder.
|
136 |
+
|
137 |
+
![arch](.asset/arch.png)
|
138 |
+
|
139 |
+
|
140 |
+
## Acknowledgement
|
141 |
+
|
142 |
+
Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
|
143 |
+
|
144 |
+
We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work are available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
|
145 |
+
|
146 |
+
Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
|
147 |
+
|
148 |
+
|
149 |
+
## Citation
|
150 |
+
|
151 |
+
If you find our work helpful for your research, please consider citing the following BibTeX entry.
|
152 |
+
|
153 |
+
```bibtex
|
154 |
+
@inproceedings{ShilongLiu2023GroundingDM,
|
155 |
+
title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection},
|
156 |
+
author={Shilong Liu and Zhaoyang Zeng and Tianhe Ren and Feng Li and Hao Zhang and Jie Yang and Chunyuan Li and Jianwei Yang and Hang Su and Jun Zhu and Lei Zhang},
|
157 |
+
year={2023}
|
158 |
+
}
|
159 |
+
```
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
|
GroundingDINO/demo/gradio_app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from functools import partial
|
3 |
+
import cv2
|
4 |
+
import requests
|
5 |
+
import os
|
6 |
+
from io import BytesIO
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
# prepare the environment
|
17 |
+
os.system("python setup.py build develop --user")
|
18 |
+
os.system("pip install packaging==21.3")
|
19 |
+
os.system("pip install gradio")
|
20 |
+
|
21 |
+
|
22 |
+
warnings.filterwarnings("ignore")
|
23 |
+
|
24 |
+
import gradio as gr
|
25 |
+
|
26 |
+
from groundingdino.models import build_model
|
27 |
+
from groundingdino.util.slconfig import SLConfig
|
28 |
+
from groundingdino.util.utils import clean_state_dict
|
29 |
+
from groundingdino.util.inference import annotate, load_image, predict
|
30 |
+
import groundingdino.datasets.transforms as T
|
31 |
+
|
32 |
+
from huggingface_hub import hf_hub_download
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
# Use this command for evaluate the GLIP-T model
|
37 |
+
config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
38 |
+
ckpt_repo_id = "ShilongLiu/GroundingDINO"
|
39 |
+
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
40 |
+
|
41 |
+
|
42 |
+
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
43 |
+
args = SLConfig.fromfile(model_config_path)
|
44 |
+
model = build_model(args)
|
45 |
+
args.device = device
|
46 |
+
|
47 |
+
cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
|
48 |
+
checkpoint = torch.load(cache_file, map_location='cpu')
|
49 |
+
log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
|
50 |
+
print("Model loaded from {} \n => {}".format(cache_file, log))
|
51 |
+
_ = model.eval()
|
52 |
+
return model
|
53 |
+
|
54 |
+
def image_transform_grounding(init_image):
|
55 |
+
transform = T.Compose([
|
56 |
+
T.RandomResize([800], max_size=1333),
|
57 |
+
T.ToTensor(),
|
58 |
+
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
59 |
+
])
|
60 |
+
image, _ = transform(init_image, None) # 3, h, w
|
61 |
+
return init_image, image
|
62 |
+
|
63 |
+
def image_transform_grounding_for_vis(init_image):
|
64 |
+
transform = T.Compose([
|
65 |
+
T.RandomResize([800], max_size=1333),
|
66 |
+
])
|
67 |
+
image, _ = transform(init_image, None) # 3, h, w
|
68 |
+
return image
|
69 |
+
|
70 |
+
model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
71 |
+
|
72 |
+
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
|
73 |
+
init_image = input_image.convert("RGB")
|
74 |
+
original_size = init_image.size
|
75 |
+
|
76 |
+
_, image_tensor = image_transform_grounding(init_image)
|
77 |
+
image_pil: Image = image_transform_grounding_for_vis(init_image)
|
78 |
+
|
79 |
+
# run grounidng
|
80 |
+
boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
|
81 |
+
annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
|
82 |
+
image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
|
83 |
+
|
84 |
+
|
85 |
+
return image_with_box
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
|
89 |
+
parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
|
90 |
+
parser.add_argument("--debug", action="store_true", help="using debug mode")
|
91 |
+
parser.add_argument("--share", action="store_true", help="share the app")
|
92 |
+
args = parser.parse_args()
|
93 |
+
|
94 |
+
block = gr.Blocks().queue()
|
95 |
+
with block:
|
96 |
+
gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
|
97 |
+
gr.Markdown("### Open-World Detection with Grounding DINO")
|
98 |
+
|
99 |
+
with gr.Row():
|
100 |
+
with gr.Column():
|
101 |
+
input_image = gr.Image(source='upload', type="pil")
|
102 |
+
grounding_caption = gr.Textbox(label="Detection Prompt")
|
103 |
+
run_button = gr.Button(label="Run")
|
104 |
+
with gr.Accordion("Advanced options", open=False):
|
105 |
+
box_threshold = gr.Slider(
|
106 |
+
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
|
107 |
+
)
|
108 |
+
text_threshold = gr.Slider(
|
109 |
+
label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
|
110 |
+
)
|
111 |
+
|
112 |
+
with gr.Column():
|
113 |
+
gallery = gr.outputs.Image(
|
114 |
+
type="pil",
|
115 |
+
# label="grounding results"
|
116 |
+
).style(full_width=True, full_height=True)
|
117 |
+
# gallery = gr.Gallery(label="Generated images", show_label=False).style(
|
118 |
+
# grid=[1], height="auto", container=True, full_width=True, full_height=True)
|
119 |
+
|
120 |
+
run_button.click(fn=run_grounding, inputs=[
|
121 |
+
input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
|
122 |
+
|
123 |
+
|
124 |
+
block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
|
125 |
+
|
GroundingDINO/demo/inference_on_a_image.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from PIL import Image, ImageDraw, ImageFont
|
8 |
+
|
9 |
+
import groundingdino.datasets.transforms as T
|
10 |
+
from groundingdino.models import build_model
|
11 |
+
from groundingdino.util import box_ops
|
12 |
+
from groundingdino.util.slconfig import SLConfig
|
13 |
+
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
14 |
+
|
15 |
+
|
16 |
+
def plot_boxes_to_image(image_pil, tgt):
|
17 |
+
H, W = tgt["size"]
|
18 |
+
boxes = tgt["boxes"]
|
19 |
+
labels = tgt["labels"]
|
20 |
+
assert len(boxes) == len(labels), "boxes and labels must have same length"
|
21 |
+
|
22 |
+
draw = ImageDraw.Draw(image_pil)
|
23 |
+
mask = Image.new("L", image_pil.size, 0)
|
24 |
+
mask_draw = ImageDraw.Draw(mask)
|
25 |
+
|
26 |
+
# draw boxes and masks
|
27 |
+
for box, label in zip(boxes, labels):
|
28 |
+
# from 0..1 to 0..W, 0..H
|
29 |
+
box = box * torch.Tensor([W, H, W, H])
|
30 |
+
# from xywh to xyxy
|
31 |
+
box[:2] -= box[2:] / 2
|
32 |
+
box[2:] += box[:2]
|
33 |
+
# random color
|
34 |
+
color = tuple(np.random.randint(0, 255, size=3).tolist())
|
35 |
+
# draw
|
36 |
+
x0, y0, x1, y1 = box
|
37 |
+
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
|
38 |
+
|
39 |
+
draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
|
40 |
+
# draw.text((x0, y0), str(label), fill=color)
|
41 |
+
|
42 |
+
font = ImageFont.load_default()
|
43 |
+
if hasattr(font, "getbbox"):
|
44 |
+
bbox = draw.textbbox((x0, y0), str(label), font)
|
45 |
+
else:
|
46 |
+
w, h = draw.textsize(str(label), font)
|
47 |
+
bbox = (x0, y0, w + x0, y0 + h)
|
48 |
+
# bbox = draw.textbbox((x0, y0), str(label))
|
49 |
+
draw.rectangle(bbox, fill=color)
|
50 |
+
draw.text((x0, y0), str(label), fill="white")
|
51 |
+
|
52 |
+
mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
|
53 |
+
|
54 |
+
return image_pil, mask
|
55 |
+
|
56 |
+
|
57 |
+
def load_image(image_path):
|
58 |
+
# load image
|
59 |
+
image_pil = Image.open(image_path).convert("RGB") # load image
|
60 |
+
|
61 |
+
transform = T.Compose(
|
62 |
+
[
|
63 |
+
T.RandomResize([800], max_size=1333),
|
64 |
+
T.ToTensor(),
|
65 |
+
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
66 |
+
]
|
67 |
+
)
|
68 |
+
image, _ = transform(image_pil, None) # 3, h, w
|
69 |
+
return image_pil, image
|
70 |
+
|
71 |
+
|
72 |
+
def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
|
73 |
+
args = SLConfig.fromfile(model_config_path)
|
74 |
+
args.device = "cuda" if not cpu_only else "cpu"
|
75 |
+
model = build_model(args)
|
76 |
+
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
77 |
+
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
78 |
+
print(load_res)
|
79 |
+
_ = model.eval()
|
80 |
+
return model
|
81 |
+
|
82 |
+
|
83 |
+
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
|
84 |
+
caption = caption.lower()
|
85 |
+
caption = caption.strip()
|
86 |
+
if not caption.endswith("."):
|
87 |
+
caption = caption + "."
|
88 |
+
device = "cuda" if not cpu_only else "cpu"
|
89 |
+
model = model.to(device)
|
90 |
+
image = image.to(device)
|
91 |
+
with torch.no_grad():
|
92 |
+
outputs = model(image[None], captions=[caption])
|
93 |
+
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
|
94 |
+
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
|
95 |
+
logits.shape[0]
|
96 |
+
|
97 |
+
# filter output
|
98 |
+
logits_filt = logits.clone()
|
99 |
+
boxes_filt = boxes.clone()
|
100 |
+
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
101 |
+
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
102 |
+
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
103 |
+
logits_filt.shape[0]
|
104 |
+
|
105 |
+
# get phrase
|
106 |
+
tokenlizer = model.tokenizer
|
107 |
+
tokenized = tokenlizer(caption)
|
108 |
+
# build pred
|
109 |
+
pred_phrases = []
|
110 |
+
for logit, box in zip(logits_filt, boxes_filt):
|
111 |
+
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
|
112 |
+
if with_logits:
|
113 |
+
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
114 |
+
else:
|
115 |
+
pred_phrases.append(pred_phrase)
|
116 |
+
|
117 |
+
return boxes_filt, pred_phrases
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == "__main__":
|
121 |
+
|
122 |
+
parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
|
123 |
+
parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
|
124 |
+
parser.add_argument(
|
125 |
+
"--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
|
126 |
+
)
|
127 |
+
parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
|
128 |
+
parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
|
129 |
+
parser.add_argument(
|
130 |
+
"--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
|
131 |
+
)
|
132 |
+
|
133 |
+
parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
|
134 |
+
parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
|
135 |
+
|
136 |
+
parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
|
137 |
+
args = parser.parse_args()
|
138 |
+
|
139 |
+
# cfg
|
140 |
+
config_file = args.config_file # change the path of the model config file
|
141 |
+
checkpoint_path = args.checkpoint_path # change the path of the model
|
142 |
+
image_path = args.image_path
|
143 |
+
text_prompt = args.text_prompt
|
144 |
+
output_dir = args.output_dir
|
145 |
+
box_threshold = args.box_threshold
|
146 |
+
text_threshold = args.text_threshold
|
147 |
+
|
148 |
+
# make dir
|
149 |
+
os.makedirs(output_dir, exist_ok=True)
|
150 |
+
# load image
|
151 |
+
image_pil, image = load_image(image_path)
|
152 |
+
# load model
|
153 |
+
model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
|
154 |
+
|
155 |
+
# visualize raw image
|
156 |
+
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
|
157 |
+
|
158 |
+
# run model
|
159 |
+
boxes_filt, pred_phrases = get_grounding_output(
|
160 |
+
model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
|
161 |
+
)
|
162 |
+
|
163 |
+
# visualize pred
|
164 |
+
size = image_pil.size
|
165 |
+
pred_dict = {
|
166 |
+
"boxes": boxes_filt,
|
167 |
+
"size": [size[1], size[0]], # H,W
|
168 |
+
"labels": pred_phrases,
|
169 |
+
}
|
170 |
+
# import ipdb; ipdb.set_trace()
|
171 |
+
image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
|
172 |
+
image_with_box.save(os.path.join(output_dir, "pred.jpg"))
|
GroundingDINO/groundingdino/__init__.py
ADDED
File without changes
|
GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
batch_size = 1
|
2 |
+
modelname = "groundingdino"
|
3 |
+
backbone = "swin_B_384_22k"
|
4 |
+
position_embedding = "sine"
|
5 |
+
pe_temperatureH = 20
|
6 |
+
pe_temperatureW = 20
|
7 |
+
return_interm_indices = [1, 2, 3]
|
8 |
+
backbone_freeze_keywords = None
|
9 |
+
enc_layers = 6
|
10 |
+
dec_layers = 6
|
11 |
+
pre_norm = False
|
12 |
+
dim_feedforward = 2048
|
13 |
+
hidden_dim = 256
|
14 |
+
dropout = 0.0
|
15 |
+
nheads = 8
|
16 |
+
num_queries = 900
|
17 |
+
query_dim = 4
|
18 |
+
num_patterns = 0
|
19 |
+
num_feature_levels = 4
|
20 |
+
enc_n_points = 4
|
21 |
+
dec_n_points = 4
|
22 |
+
two_stage_type = "standard"
|
23 |
+
two_stage_bbox_embed_share = False
|
24 |
+
two_stage_class_embed_share = False
|
25 |
+
transformer_activation = "relu"
|
26 |
+
dec_pred_bbox_embed_share = True
|
27 |
+
dn_box_noise_scale = 1.0
|
28 |
+
dn_label_noise_ratio = 0.5
|
29 |
+
dn_label_coef = 1.0
|
30 |
+
dn_bbox_coef = 1.0
|
31 |
+
embed_init_tgt = True
|
32 |
+
dn_labelbook_size = 2000
|
33 |
+
max_text_len = 256
|
34 |
+
text_encoder_type = "bert-base-uncased"
|
35 |
+
use_text_enhancer = True
|
36 |
+
use_fusion_layer = True
|
37 |
+
use_checkpoint = True
|
38 |
+
use_transformer_ckpt = True
|
39 |
+
use_text_cross_attention = True
|
40 |
+
text_dropout = 0.0
|
41 |
+
fusion_dropout = 0.0
|
42 |
+
fusion_droppath = 0.1
|
43 |
+
sub_sentence_present = True
|
GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
batch_size = 1
|
2 |
+
modelname = "groundingdino"
|
3 |
+
backbone = "swin_T_224_1k"
|
4 |
+
position_embedding = "sine"
|
5 |
+
pe_temperatureH = 20
|
6 |
+
pe_temperatureW = 20
|
7 |
+
return_interm_indices = [1, 2, 3]
|
8 |
+
backbone_freeze_keywords = None
|
9 |
+
enc_layers = 6
|
10 |
+
dec_layers = 6
|
11 |
+
pre_norm = False
|
12 |
+
dim_feedforward = 2048
|
13 |
+
hidden_dim = 256
|
14 |
+
dropout = 0.0
|
15 |
+
nheads = 8
|
16 |
+
num_queries = 900
|
17 |
+
query_dim = 4
|
18 |
+
num_patterns = 0
|
19 |
+
num_feature_levels = 4
|
20 |
+
enc_n_points = 4
|
21 |
+
dec_n_points = 4
|
22 |
+
two_stage_type = "standard"
|
23 |
+
two_stage_bbox_embed_share = False
|
24 |
+
two_stage_class_embed_share = False
|
25 |
+
transformer_activation = "relu"
|
26 |
+
dec_pred_bbox_embed_share = True
|
27 |
+
dn_box_noise_scale = 1.0
|
28 |
+
dn_label_noise_ratio = 0.5
|
29 |
+
dn_label_coef = 1.0
|
30 |
+
dn_bbox_coef = 1.0
|
31 |
+
embed_init_tgt = True
|
32 |
+
dn_labelbook_size = 2000
|
33 |
+
max_text_len = 256
|
34 |
+
text_encoder_type = "bert-base-uncased"
|
35 |
+
use_text_enhancer = True
|
36 |
+
use_fusion_layer = True
|
37 |
+
use_checkpoint = True
|
38 |
+
use_transformer_ckpt = True
|
39 |
+
use_text_cross_attention = True
|
40 |
+
text_dropout = 0.0
|
41 |
+
fusion_dropout = 0.0
|
42 |
+
fusion_droppath = 0.1
|
43 |
+
sub_sentence_present = True
|
GroundingDINO/groundingdino/datasets/__init__.py
ADDED
File without changes
|
GroundingDINO/groundingdino/datasets/transforms.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Transforms and data augmentation for both image + bbox.
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
|
8 |
+
import PIL
|
9 |
+
import torch
|
10 |
+
import torchvision.transforms as T
|
11 |
+
import torchvision.transforms.functional as F
|
12 |
+
|
13 |
+
from groundingdino.util.box_ops import box_xyxy_to_cxcywh
|
14 |
+
from groundingdino.util.misc import interpolate
|
15 |
+
|
16 |
+
|
17 |
+
def crop(image, target, region):
|
18 |
+
cropped_image = F.crop(image, *region)
|
19 |
+
|
20 |
+
target = target.copy()
|
21 |
+
i, j, h, w = region
|
22 |
+
|
23 |
+
# should we do something wrt the original size?
|
24 |
+
target["size"] = torch.tensor([h, w])
|
25 |
+
|
26 |
+
fields = ["labels", "area", "iscrowd", "positive_map"]
|
27 |
+
|
28 |
+
if "boxes" in target:
|
29 |
+
boxes = target["boxes"]
|
30 |
+
max_size = torch.as_tensor([w, h], dtype=torch.float32)
|
31 |
+
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
|
32 |
+
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
|
33 |
+
cropped_boxes = cropped_boxes.clamp(min=0)
|
34 |
+
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
|
35 |
+
target["boxes"] = cropped_boxes.reshape(-1, 4)
|
36 |
+
target["area"] = area
|
37 |
+
fields.append("boxes")
|
38 |
+
|
39 |
+
if "masks" in target:
|
40 |
+
# FIXME should we update the area here if there are no boxes?
|
41 |
+
target["masks"] = target["masks"][:, i : i + h, j : j + w]
|
42 |
+
fields.append("masks")
|
43 |
+
|
44 |
+
# remove elements for which the boxes or masks that have zero area
|
45 |
+
if "boxes" in target or "masks" in target:
|
46 |
+
# favor boxes selection when defining which elements to keep
|
47 |
+
# this is compatible with previous implementation
|
48 |
+
if "boxes" in target:
|
49 |
+
cropped_boxes = target["boxes"].reshape(-1, 2, 2)
|
50 |
+
keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
|
51 |
+
else:
|
52 |
+
keep = target["masks"].flatten(1).any(1)
|
53 |
+
|
54 |
+
for field in fields:
|
55 |
+
if field in target:
|
56 |
+
target[field] = target[field][keep]
|
57 |
+
|
58 |
+
if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
|
59 |
+
# for debug and visualization only.
|
60 |
+
if "strings_positive" in target:
|
61 |
+
target["strings_positive"] = [
|
62 |
+
_i for _i, _j in zip(target["strings_positive"], keep) if _j
|
63 |
+
]
|
64 |
+
|
65 |
+
return cropped_image, target
|
66 |
+
|
67 |
+
|
68 |
+
def hflip(image, target):
|
69 |
+
flipped_image = F.hflip(image)
|
70 |
+
|
71 |
+
w, h = image.size
|
72 |
+
|
73 |
+
target = target.copy()
|
74 |
+
if "boxes" in target:
|
75 |
+
boxes = target["boxes"]
|
76 |
+
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
|
77 |
+
[w, 0, w, 0]
|
78 |
+
)
|
79 |
+
target["boxes"] = boxes
|
80 |
+
|
81 |
+
if "masks" in target:
|
82 |
+
target["masks"] = target["masks"].flip(-1)
|
83 |
+
|
84 |
+
return flipped_image, target
|
85 |
+
|
86 |
+
|
87 |
+
def resize(image, target, size, max_size=None):
|
88 |
+
# size can be min_size (scalar) or (w, h) tuple
|
89 |
+
|
90 |
+
def get_size_with_aspect_ratio(image_size, size, max_size=None):
|
91 |
+
w, h = image_size
|
92 |
+
if max_size is not None:
|
93 |
+
min_original_size = float(min((w, h)))
|
94 |
+
max_original_size = float(max((w, h)))
|
95 |
+
if max_original_size / min_original_size * size > max_size:
|
96 |
+
size = int(round(max_size * min_original_size / max_original_size))
|
97 |
+
|
98 |
+
if (w <= h and w == size) or (h <= w and h == size):
|
99 |
+
return (h, w)
|
100 |
+
|
101 |
+
if w < h:
|
102 |
+
ow = size
|
103 |
+
oh = int(size * h / w)
|
104 |
+
else:
|
105 |
+
oh = size
|
106 |
+
ow = int(size * w / h)
|
107 |
+
|
108 |
+
return (oh, ow)
|
109 |
+
|
110 |
+
def get_size(image_size, size, max_size=None):
|
111 |
+
if isinstance(size, (list, tuple)):
|
112 |
+
return size[::-1]
|
113 |
+
else:
|
114 |
+
return get_size_with_aspect_ratio(image_size, size, max_size)
|
115 |
+
|
116 |
+
size = get_size(image.size, size, max_size)
|
117 |
+
rescaled_image = F.resize(image, size)
|
118 |
+
|
119 |
+
if target is None:
|
120 |
+
return rescaled_image, None
|
121 |
+
|
122 |
+
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
|
123 |
+
ratio_width, ratio_height = ratios
|
124 |
+
|
125 |
+
target = target.copy()
|
126 |
+
if "boxes" in target:
|
127 |
+
boxes = target["boxes"]
|
128 |
+
scaled_boxes = boxes * torch.as_tensor(
|
129 |
+
[ratio_width, ratio_height, ratio_width, ratio_height]
|
130 |
+
)
|
131 |
+
target["boxes"] = scaled_boxes
|
132 |
+
|
133 |
+
if "area" in target:
|
134 |
+
area = target["area"]
|
135 |
+
scaled_area = area * (ratio_width * ratio_height)
|
136 |
+
target["area"] = scaled_area
|
137 |
+
|
138 |
+
h, w = size
|
139 |
+
target["size"] = torch.tensor([h, w])
|
140 |
+
|
141 |
+
if "masks" in target:
|
142 |
+
target["masks"] = (
|
143 |
+
interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
|
144 |
+
)
|
145 |
+
|
146 |
+
return rescaled_image, target
|
147 |
+
|
148 |
+
|
149 |
+
def pad(image, target, padding):
|
150 |
+
# assumes that we only pad on the bottom right corners
|
151 |
+
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
|
152 |
+
if target is None:
|
153 |
+
return padded_image, None
|
154 |
+
target = target.copy()
|
155 |
+
# should we do something wrt the original size?
|
156 |
+
target["size"] = torch.tensor(padded_image.size[::-1])
|
157 |
+
if "masks" in target:
|
158 |
+
target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
|
159 |
+
return padded_image, target
|
160 |
+
|
161 |
+
|
162 |
+
class ResizeDebug(object):
|
163 |
+
def __init__(self, size):
|
164 |
+
self.size = size
|
165 |
+
|
166 |
+
def __call__(self, img, target):
|
167 |
+
return resize(img, target, self.size)
|
168 |
+
|
169 |
+
|
170 |
+
class RandomCrop(object):
|
171 |
+
def __init__(self, size):
|
172 |
+
self.size = size
|
173 |
+
|
174 |
+
def __call__(self, img, target):
|
175 |
+
region = T.RandomCrop.get_params(img, self.size)
|
176 |
+
return crop(img, target, region)
|
177 |
+
|
178 |
+
|
179 |
+
class RandomSizeCrop(object):
|
180 |
+
def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
|
181 |
+
# respect_boxes: True to keep all boxes
|
182 |
+
# False to tolerence box filter
|
183 |
+
self.min_size = min_size
|
184 |
+
self.max_size = max_size
|
185 |
+
self.respect_boxes = respect_boxes
|
186 |
+
|
187 |
+
def __call__(self, img: PIL.Image.Image, target: dict):
|
188 |
+
init_boxes = len(target["boxes"])
|
189 |
+
max_patience = 10
|
190 |
+
for i in range(max_patience):
|
191 |
+
w = random.randint(self.min_size, min(img.width, self.max_size))
|
192 |
+
h = random.randint(self.min_size, min(img.height, self.max_size))
|
193 |
+
region = T.RandomCrop.get_params(img, [h, w])
|
194 |
+
result_img, result_target = crop(img, target, region)
|
195 |
+
if (
|
196 |
+
not self.respect_boxes
|
197 |
+
or len(result_target["boxes"]) == init_boxes
|
198 |
+
or i == max_patience - 1
|
199 |
+
):
|
200 |
+
return result_img, result_target
|
201 |
+
return result_img, result_target
|
202 |
+
|
203 |
+
|
204 |
+
class CenterCrop(object):
|
205 |
+
def __init__(self, size):
|
206 |
+
self.size = size
|
207 |
+
|
208 |
+
def __call__(self, img, target):
|
209 |
+
image_width, image_height = img.size
|
210 |
+
crop_height, crop_width = self.size
|
211 |
+
crop_top = int(round((image_height - crop_height) / 2.0))
|
212 |
+
crop_left = int(round((image_width - crop_width) / 2.0))
|
213 |
+
return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
|
214 |
+
|
215 |
+
|
216 |
+
class RandomHorizontalFlip(object):
|
217 |
+
def __init__(self, p=0.5):
|
218 |
+
self.p = p
|
219 |
+
|
220 |
+
def __call__(self, img, target):
|
221 |
+
if random.random() < self.p:
|
222 |
+
return hflip(img, target)
|
223 |
+
return img, target
|
224 |
+
|
225 |
+
|
226 |
+
class RandomResize(object):
|
227 |
+
def __init__(self, sizes, max_size=None):
|
228 |
+
assert isinstance(sizes, (list, tuple))
|
229 |
+
self.sizes = sizes
|
230 |
+
self.max_size = max_size
|
231 |
+
|
232 |
+
def __call__(self, img, target=None):
|
233 |
+
size = random.choice(self.sizes)
|
234 |
+
return resize(img, target, size, self.max_size)
|
235 |
+
|
236 |
+
|
237 |
+
class RandomPad(object):
|
238 |
+
def __init__(self, max_pad):
|
239 |
+
self.max_pad = max_pad
|
240 |
+
|
241 |
+
def __call__(self, img, target):
|
242 |
+
pad_x = random.randint(0, self.max_pad)
|
243 |
+
pad_y = random.randint(0, self.max_pad)
|
244 |
+
return pad(img, target, (pad_x, pad_y))
|
245 |
+
|
246 |
+
|
247 |
+
class RandomSelect(object):
|
248 |
+
"""
|
249 |
+
Randomly selects between transforms1 and transforms2,
|
250 |
+
with probability p for transforms1 and (1 - p) for transforms2
|
251 |
+
"""
|
252 |
+
|
253 |
+
def __init__(self, transforms1, transforms2, p=0.5):
|
254 |
+
self.transforms1 = transforms1
|
255 |
+
self.transforms2 = transforms2
|
256 |
+
self.p = p
|
257 |
+
|
258 |
+
def __call__(self, img, target):
|
259 |
+
if random.random() < self.p:
|
260 |
+
return self.transforms1(img, target)
|
261 |
+
return self.transforms2(img, target)
|
262 |
+
|
263 |
+
|
264 |
+
class ToTensor(object):
|
265 |
+
def __call__(self, img, target):
|
266 |
+
return F.to_tensor(img), target
|
267 |
+
|
268 |
+
|
269 |
+
class RandomErasing(object):
|
270 |
+
def __init__(self, *args, **kwargs):
|
271 |
+
self.eraser = T.RandomErasing(*args, **kwargs)
|
272 |
+
|
273 |
+
def __call__(self, img, target):
|
274 |
+
return self.eraser(img), target
|
275 |
+
|
276 |
+
|
277 |
+
class Normalize(object):
|
278 |
+
def __init__(self, mean, std):
|
279 |
+
self.mean = mean
|
280 |
+
self.std = std
|
281 |
+
|
282 |
+
def __call__(self, image, target=None):
|
283 |
+
image = F.normalize(image, mean=self.mean, std=self.std)
|
284 |
+
if target is None:
|
285 |
+
return image, None
|
286 |
+
target = target.copy()
|
287 |
+
h, w = image.shape[-2:]
|
288 |
+
if "boxes" in target:
|
289 |
+
boxes = target["boxes"]
|
290 |
+
boxes = box_xyxy_to_cxcywh(boxes)
|
291 |
+
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
|
292 |
+
target["boxes"] = boxes
|
293 |
+
return image, target
|
294 |
+
|
295 |
+
|
296 |
+
class Compose(object):
|
297 |
+
def __init__(self, transforms):
|
298 |
+
self.transforms = transforms
|
299 |
+
|
300 |
+
def __call__(self, image, target):
|
301 |
+
for t in self.transforms:
|
302 |
+
image, target = t(image, target)
|
303 |
+
return image, target
|
304 |
+
|
305 |
+
def __repr__(self):
|
306 |
+
format_string = self.__class__.__name__ + "("
|
307 |
+
for t in self.transforms:
|
308 |
+
format_string += "\n"
|
309 |
+
format_string += " {0}".format(t)
|
310 |
+
format_string += "\n)"
|
311 |
+
return format_string
|
GroundingDINO/groundingdino/models/GroundingDINO/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Grounding DINO
|
3 |
+
# url: https://github.com/IDEA-Research/GroundingDINO
|
4 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------
|
7 |
+
# Conditional DETR
|
8 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
9 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
10 |
+
# ------------------------------------------------------------------------
|
11 |
+
# Copied from DETR (https://github.com/facebookresearch/detr)
|
12 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
13 |
+
# ------------------------------------------------------------------------
|
14 |
+
|
15 |
+
from .groundingdino import build_groundingdino
|
GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .backbone import build_backbone
|
GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Grounding DINO
|
3 |
+
# url: https://github.com/IDEA-Research/GroundingDINO
|
4 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------
|
7 |
+
# Conditional DETR
|
8 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
9 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
10 |
+
# ------------------------------------------------------------------------
|
11 |
+
# Copied from DETR (https://github.com/facebookresearch/detr)
|
12 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
13 |
+
# ------------------------------------------------------------------------
|
14 |
+
|
15 |
+
"""
|
16 |
+
Backbone modules.
|
17 |
+
"""
|
18 |
+
|
19 |
+
from typing import Dict, List
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import torchvision
|
24 |
+
from torch import nn
|
25 |
+
from torchvision.models._utils import IntermediateLayerGetter
|
26 |
+
|
27 |
+
from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
|
28 |
+
|
29 |
+
from .position_encoding import build_position_encoding
|
30 |
+
from .swin_transformer import build_swin_transformer
|
31 |
+
|
32 |
+
|
33 |
+
class FrozenBatchNorm2d(torch.nn.Module):
|
34 |
+
"""
|
35 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
36 |
+
|
37 |
+
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
38 |
+
without which any other models than torchvision.models.resnet[18,34,50,101]
|
39 |
+
produce nans.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, n):
|
43 |
+
super(FrozenBatchNorm2d, self).__init__()
|
44 |
+
self.register_buffer("weight", torch.ones(n))
|
45 |
+
self.register_buffer("bias", torch.zeros(n))
|
46 |
+
self.register_buffer("running_mean", torch.zeros(n))
|
47 |
+
self.register_buffer("running_var", torch.ones(n))
|
48 |
+
|
49 |
+
def _load_from_state_dict(
|
50 |
+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
51 |
+
):
|
52 |
+
num_batches_tracked_key = prefix + "num_batches_tracked"
|
53 |
+
if num_batches_tracked_key in state_dict:
|
54 |
+
del state_dict[num_batches_tracked_key]
|
55 |
+
|
56 |
+
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
57 |
+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
58 |
+
)
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
# move reshapes to the beginning
|
62 |
+
# to make it fuser-friendly
|
63 |
+
w = self.weight.reshape(1, -1, 1, 1)
|
64 |
+
b = self.bias.reshape(1, -1, 1, 1)
|
65 |
+
rv = self.running_var.reshape(1, -1, 1, 1)
|
66 |
+
rm = self.running_mean.reshape(1, -1, 1, 1)
|
67 |
+
eps = 1e-5
|
68 |
+
scale = w * (rv + eps).rsqrt()
|
69 |
+
bias = b - rm * scale
|
70 |
+
return x * scale + bias
|
71 |
+
|
72 |
+
|
73 |
+
class BackboneBase(nn.Module):
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
backbone: nn.Module,
|
77 |
+
train_backbone: bool,
|
78 |
+
num_channels: int,
|
79 |
+
return_interm_indices: list,
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
for name, parameter in backbone.named_parameters():
|
83 |
+
if (
|
84 |
+
not train_backbone
|
85 |
+
or "layer2" not in name
|
86 |
+
and "layer3" not in name
|
87 |
+
and "layer4" not in name
|
88 |
+
):
|
89 |
+
parameter.requires_grad_(False)
|
90 |
+
|
91 |
+
return_layers = {}
|
92 |
+
for idx, layer_index in enumerate(return_interm_indices):
|
93 |
+
return_layers.update(
|
94 |
+
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
|
95 |
+
)
|
96 |
+
|
97 |
+
# if len:
|
98 |
+
# if use_stage1_feature:
|
99 |
+
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
100 |
+
# else:
|
101 |
+
# return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
|
102 |
+
# else:
|
103 |
+
# return_layers = {'layer4': "0"}
|
104 |
+
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
105 |
+
self.num_channels = num_channels
|
106 |
+
|
107 |
+
def forward(self, tensor_list: NestedTensor):
|
108 |
+
xs = self.body(tensor_list.tensors)
|
109 |
+
out: Dict[str, NestedTensor] = {}
|
110 |
+
for name, x in xs.items():
|
111 |
+
m = tensor_list.mask
|
112 |
+
assert m is not None
|
113 |
+
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
114 |
+
out[name] = NestedTensor(x, mask)
|
115 |
+
# import ipdb; ipdb.set_trace()
|
116 |
+
return out
|
117 |
+
|
118 |
+
|
119 |
+
class Backbone(BackboneBase):
|
120 |
+
"""ResNet backbone with frozen BatchNorm."""
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
name: str,
|
125 |
+
train_backbone: bool,
|
126 |
+
dilation: bool,
|
127 |
+
return_interm_indices: list,
|
128 |
+
batch_norm=FrozenBatchNorm2d,
|
129 |
+
):
|
130 |
+
if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
|
131 |
+
backbone = getattr(torchvision.models, name)(
|
132 |
+
replace_stride_with_dilation=[False, False, dilation],
|
133 |
+
pretrained=is_main_process(),
|
134 |
+
norm_layer=batch_norm,
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
raise NotImplementedError("Why you can get here with name {}".format(name))
|
138 |
+
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
139 |
+
assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
|
140 |
+
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
|
141 |
+
num_channels_all = [256, 512, 1024, 2048]
|
142 |
+
num_channels = num_channels_all[4 - len(return_interm_indices) :]
|
143 |
+
super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
|
144 |
+
|
145 |
+
|
146 |
+
class Joiner(nn.Sequential):
|
147 |
+
def __init__(self, backbone, position_embedding):
|
148 |
+
super().__init__(backbone, position_embedding)
|
149 |
+
|
150 |
+
def forward(self, tensor_list: NestedTensor):
|
151 |
+
xs = self[0](tensor_list)
|
152 |
+
out: List[NestedTensor] = []
|
153 |
+
pos = []
|
154 |
+
for name, x in xs.items():
|
155 |
+
out.append(x)
|
156 |
+
# position encoding
|
157 |
+
pos.append(self[1](x).to(x.tensors.dtype))
|
158 |
+
|
159 |
+
return out, pos
|
160 |
+
|
161 |
+
|
162 |
+
def build_backbone(args):
|
163 |
+
"""
|
164 |
+
Useful args:
|
165 |
+
- backbone: backbone name
|
166 |
+
- lr_backbone:
|
167 |
+
- dilation
|
168 |
+
- return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
|
169 |
+
- backbone_freeze_keywords:
|
170 |
+
- use_checkpoint: for swin only for now
|
171 |
+
|
172 |
+
"""
|
173 |
+
position_embedding = build_position_encoding(args)
|
174 |
+
train_backbone = True
|
175 |
+
if not train_backbone:
|
176 |
+
raise ValueError("Please set lr_backbone > 0")
|
177 |
+
return_interm_indices = args.return_interm_indices
|
178 |
+
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
|
179 |
+
args.backbone_freeze_keywords
|
180 |
+
use_checkpoint = getattr(args, "use_checkpoint", False)
|
181 |
+
|
182 |
+
if args.backbone in ["resnet50", "resnet101"]:
|
183 |
+
backbone = Backbone(
|
184 |
+
args.backbone,
|
185 |
+
train_backbone,
|
186 |
+
args.dilation,
|
187 |
+
return_interm_indices,
|
188 |
+
batch_norm=FrozenBatchNorm2d,
|
189 |
+
)
|
190 |
+
bb_num_channels = backbone.num_channels
|
191 |
+
elif args.backbone in [
|
192 |
+
"swin_T_224_1k",
|
193 |
+
"swin_B_224_22k",
|
194 |
+
"swin_B_384_22k",
|
195 |
+
"swin_L_224_22k",
|
196 |
+
"swin_L_384_22k",
|
197 |
+
]:
|
198 |
+
pretrain_img_size = int(args.backbone.split("_")[-2])
|
199 |
+
backbone = build_swin_transformer(
|
200 |
+
args.backbone,
|
201 |
+
pretrain_img_size=pretrain_img_size,
|
202 |
+
out_indices=tuple(return_interm_indices),
|
203 |
+
dilation=False,
|
204 |
+
use_checkpoint=use_checkpoint,
|
205 |
+
)
|
206 |
+
|
207 |
+
bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
|
208 |
+
else:
|
209 |
+
raise NotImplementedError("Unknown backbone {}".format(args.backbone))
|
210 |
+
|
211 |
+
assert len(bb_num_channels) == len(
|
212 |
+
return_interm_indices
|
213 |
+
), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
|
214 |
+
|
215 |
+
model = Joiner(backbone, position_embedding)
|
216 |
+
model.num_channels = bb_num_channels
|
217 |
+
assert isinstance(
|
218 |
+
bb_num_channels, List
|
219 |
+
), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
|
220 |
+
# import ipdb; ipdb.set_trace()
|
221 |
+
return model
|
GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Grounding DINO
|
3 |
+
# url: https://github.com/IDEA-Research/GroundingDINO
|
4 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------
|
7 |
+
# DINO
|
8 |
+
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
9 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
10 |
+
# ------------------------------------------------------------------------
|
11 |
+
# Conditional DETR
|
12 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
13 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
14 |
+
# ------------------------------------------------------------------------
|
15 |
+
# Copied from DETR (https://github.com/facebookresearch/detr)
|
16 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
17 |
+
# ------------------------------------------------------------------------
|
18 |
+
|
19 |
+
"""
|
20 |
+
Various positional encodings for the transformer.
|
21 |
+
"""
|
22 |
+
import math
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
|
27 |
+
from groundingdino.util.misc import NestedTensor
|
28 |
+
|
29 |
+
|
30 |
+
class PositionEmbeddingSine(nn.Module):
|
31 |
+
"""
|
32 |
+
This is a more standard version of the position embedding, very similar to the one
|
33 |
+
used by the Attention is all you need paper, generalized to work on images.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
37 |
+
super().__init__()
|
38 |
+
self.num_pos_feats = num_pos_feats
|
39 |
+
self.temperature = temperature
|
40 |
+
self.normalize = normalize
|
41 |
+
if scale is not None and normalize is False:
|
42 |
+
raise ValueError("normalize should be True if scale is passed")
|
43 |
+
if scale is None:
|
44 |
+
scale = 2 * math.pi
|
45 |
+
self.scale = scale
|
46 |
+
|
47 |
+
def forward(self, tensor_list: NestedTensor):
|
48 |
+
x = tensor_list.tensors
|
49 |
+
mask = tensor_list.mask
|
50 |
+
assert mask is not None
|
51 |
+
not_mask = ~mask
|
52 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
53 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
54 |
+
if self.normalize:
|
55 |
+
eps = 1e-6
|
56 |
+
# if os.environ.get("SHILONG_AMP", None) == '1':
|
57 |
+
# eps = 1e-4
|
58 |
+
# else:
|
59 |
+
# eps = 1e-6
|
60 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
61 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
62 |
+
|
63 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
64 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
65 |
+
|
66 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
67 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
68 |
+
pos_x = torch.stack(
|
69 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
70 |
+
).flatten(3)
|
71 |
+
pos_y = torch.stack(
|
72 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
73 |
+
).flatten(3)
|
74 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
75 |
+
return pos
|
76 |
+
|
77 |
+
|
78 |
+
class PositionEmbeddingSineHW(nn.Module):
|
79 |
+
"""
|
80 |
+
This is a more standard version of the position embedding, very similar to the one
|
81 |
+
used by the Attention is all you need paper, generalized to work on images.
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(
|
85 |
+
self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
self.num_pos_feats = num_pos_feats
|
89 |
+
self.temperatureH = temperatureH
|
90 |
+
self.temperatureW = temperatureW
|
91 |
+
self.normalize = normalize
|
92 |
+
if scale is not None and normalize is False:
|
93 |
+
raise ValueError("normalize should be True if scale is passed")
|
94 |
+
if scale is None:
|
95 |
+
scale = 2 * math.pi
|
96 |
+
self.scale = scale
|
97 |
+
|
98 |
+
def forward(self, tensor_list: NestedTensor):
|
99 |
+
x = tensor_list.tensors
|
100 |
+
mask = tensor_list.mask
|
101 |
+
assert mask is not None
|
102 |
+
not_mask = ~mask
|
103 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
104 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
105 |
+
|
106 |
+
# import ipdb; ipdb.set_trace()
|
107 |
+
|
108 |
+
if self.normalize:
|
109 |
+
eps = 1e-6
|
110 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
111 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
112 |
+
|
113 |
+
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
114 |
+
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
|
115 |
+
pos_x = x_embed[:, :, :, None] / dim_tx
|
116 |
+
|
117 |
+
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
118 |
+
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
|
119 |
+
pos_y = y_embed[:, :, :, None] / dim_ty
|
120 |
+
|
121 |
+
pos_x = torch.stack(
|
122 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
123 |
+
).flatten(3)
|
124 |
+
pos_y = torch.stack(
|
125 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
126 |
+
).flatten(3)
|
127 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
128 |
+
|
129 |
+
# import ipdb; ipdb.set_trace()
|
130 |
+
|
131 |
+
return pos
|
132 |
+
|
133 |
+
|
134 |
+
class PositionEmbeddingLearned(nn.Module):
|
135 |
+
"""
|
136 |
+
Absolute pos embedding, learned.
|
137 |
+
"""
|
138 |
+
|
139 |
+
def __init__(self, num_pos_feats=256):
|
140 |
+
super().__init__()
|
141 |
+
self.row_embed = nn.Embedding(50, num_pos_feats)
|
142 |
+
self.col_embed = nn.Embedding(50, num_pos_feats)
|
143 |
+
self.reset_parameters()
|
144 |
+
|
145 |
+
def reset_parameters(self):
|
146 |
+
nn.init.uniform_(self.row_embed.weight)
|
147 |
+
nn.init.uniform_(self.col_embed.weight)
|
148 |
+
|
149 |
+
def forward(self, tensor_list: NestedTensor):
|
150 |
+
x = tensor_list.tensors
|
151 |
+
h, w = x.shape[-2:]
|
152 |
+
i = torch.arange(w, device=x.device)
|
153 |
+
j = torch.arange(h, device=x.device)
|
154 |
+
x_emb = self.col_embed(i)
|
155 |
+
y_emb = self.row_embed(j)
|
156 |
+
pos = (
|
157 |
+
torch.cat(
|
158 |
+
[
|
159 |
+
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
160 |
+
y_emb.unsqueeze(1).repeat(1, w, 1),
|
161 |
+
],
|
162 |
+
dim=-1,
|
163 |
+
)
|
164 |
+
.permute(2, 0, 1)
|
165 |
+
.unsqueeze(0)
|
166 |
+
.repeat(x.shape[0], 1, 1, 1)
|
167 |
+
)
|
168 |
+
return pos
|
169 |
+
|
170 |
+
|
171 |
+
def build_position_encoding(args):
|
172 |
+
N_steps = args.hidden_dim // 2
|
173 |
+
if args.position_embedding in ("v2", "sine"):
|
174 |
+
# TODO find a better way of exposing other arguments
|
175 |
+
position_embedding = PositionEmbeddingSineHW(
|
176 |
+
N_steps,
|
177 |
+
temperatureH=args.pe_temperatureH,
|
178 |
+
temperatureW=args.pe_temperatureW,
|
179 |
+
normalize=True,
|
180 |
+
)
|
181 |
+
elif args.position_embedding in ("v3", "learned"):
|
182 |
+
position_embedding = PositionEmbeddingLearned(N_steps)
|
183 |
+
else:
|
184 |
+
raise ValueError(f"not supported {args.position_embedding}")
|
185 |
+
|
186 |
+
return position_embedding
|
GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py
ADDED
@@ -0,0 +1,802 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Grounding DINO
|
3 |
+
# url: https://github.com/IDEA-Research/GroundingDINO
|
4 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------
|
7 |
+
# DINO
|
8 |
+
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
9 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
10 |
+
# --------------------------------------------------------
|
11 |
+
# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
|
12 |
+
# --------------------------------------------------------
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torch.utils.checkpoint as checkpoint
|
19 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
20 |
+
|
21 |
+
from groundingdino.util.misc import NestedTensor
|
22 |
+
|
23 |
+
|
24 |
+
class Mlp(nn.Module):
|
25 |
+
"""Multilayer perceptron."""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
out_features = out_features or in_features
|
32 |
+
hidden_features = hidden_features or in_features
|
33 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
34 |
+
self.act = act_layer()
|
35 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
36 |
+
self.drop = nn.Dropout(drop)
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
x = self.fc1(x)
|
40 |
+
x = self.act(x)
|
41 |
+
x = self.drop(x)
|
42 |
+
x = self.fc2(x)
|
43 |
+
x = self.drop(x)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
def window_partition(x, window_size):
|
48 |
+
"""
|
49 |
+
Args:
|
50 |
+
x: (B, H, W, C)
|
51 |
+
window_size (int): window size
|
52 |
+
Returns:
|
53 |
+
windows: (num_windows*B, window_size, window_size, C)
|
54 |
+
"""
|
55 |
+
B, H, W, C = x.shape
|
56 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
57 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
58 |
+
return windows
|
59 |
+
|
60 |
+
|
61 |
+
def window_reverse(windows, window_size, H, W):
|
62 |
+
"""
|
63 |
+
Args:
|
64 |
+
windows: (num_windows*B, window_size, window_size, C)
|
65 |
+
window_size (int): Window size
|
66 |
+
H (int): Height of image
|
67 |
+
W (int): Width of image
|
68 |
+
Returns:
|
69 |
+
x: (B, H, W, C)
|
70 |
+
"""
|
71 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
72 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
73 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
74 |
+
return x
|
75 |
+
|
76 |
+
|
77 |
+
class WindowAttention(nn.Module):
|
78 |
+
"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
79 |
+
It supports both of shifted and non-shifted window.
|
80 |
+
Args:
|
81 |
+
dim (int): Number of input channels.
|
82 |
+
window_size (tuple[int]): The height and width of the window.
|
83 |
+
num_heads (int): Number of attention heads.
|
84 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
85 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
86 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
87 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
dim,
|
93 |
+
window_size,
|
94 |
+
num_heads,
|
95 |
+
qkv_bias=True,
|
96 |
+
qk_scale=None,
|
97 |
+
attn_drop=0.0,
|
98 |
+
proj_drop=0.0,
|
99 |
+
):
|
100 |
+
|
101 |
+
super().__init__()
|
102 |
+
self.dim = dim
|
103 |
+
self.window_size = window_size # Wh, Ww
|
104 |
+
self.num_heads = num_heads
|
105 |
+
head_dim = dim // num_heads
|
106 |
+
self.scale = qk_scale or head_dim**-0.5
|
107 |
+
|
108 |
+
# define a parameter table of relative position bias
|
109 |
+
self.relative_position_bias_table = nn.Parameter(
|
110 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
111 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
112 |
+
|
113 |
+
# get pair-wise relative position index for each token inside the window
|
114 |
+
coords_h = torch.arange(self.window_size[0])
|
115 |
+
coords_w = torch.arange(self.window_size[1])
|
116 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
117 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
118 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
119 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
120 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
121 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
122 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
123 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
124 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
125 |
+
|
126 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
127 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
128 |
+
self.proj = nn.Linear(dim, dim)
|
129 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
130 |
+
|
131 |
+
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
132 |
+
self.softmax = nn.Softmax(dim=-1)
|
133 |
+
|
134 |
+
def forward(self, x, mask=None):
|
135 |
+
"""Forward function.
|
136 |
+
Args:
|
137 |
+
x: input features with shape of (num_windows*B, N, C)
|
138 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
139 |
+
"""
|
140 |
+
B_, N, C = x.shape
|
141 |
+
qkv = (
|
142 |
+
self.qkv(x)
|
143 |
+
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
|
144 |
+
.permute(2, 0, 3, 1, 4)
|
145 |
+
)
|
146 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
147 |
+
|
148 |
+
q = q * self.scale
|
149 |
+
attn = q @ k.transpose(-2, -1)
|
150 |
+
|
151 |
+
relative_position_bias = self.relative_position_bias_table[
|
152 |
+
self.relative_position_index.view(-1)
|
153 |
+
].view(
|
154 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
|
155 |
+
) # Wh*Ww,Wh*Ww,nH
|
156 |
+
relative_position_bias = relative_position_bias.permute(
|
157 |
+
2, 0, 1
|
158 |
+
).contiguous() # nH, Wh*Ww, Wh*Ww
|
159 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
160 |
+
|
161 |
+
if mask is not None:
|
162 |
+
nW = mask.shape[0]
|
163 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
164 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
165 |
+
attn = self.softmax(attn)
|
166 |
+
else:
|
167 |
+
attn = self.softmax(attn)
|
168 |
+
|
169 |
+
attn = self.attn_drop(attn)
|
170 |
+
|
171 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
172 |
+
x = self.proj(x)
|
173 |
+
x = self.proj_drop(x)
|
174 |
+
return x
|
175 |
+
|
176 |
+
|
177 |
+
class SwinTransformerBlock(nn.Module):
|
178 |
+
"""Swin Transformer Block.
|
179 |
+
Args:
|
180 |
+
dim (int): Number of input channels.
|
181 |
+
num_heads (int): Number of attention heads.
|
182 |
+
window_size (int): Window size.
|
183 |
+
shift_size (int): Shift size for SW-MSA.
|
184 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
185 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
186 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
187 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
188 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
189 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
190 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
191 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
192 |
+
"""
|
193 |
+
|
194 |
+
def __init__(
|
195 |
+
self,
|
196 |
+
dim,
|
197 |
+
num_heads,
|
198 |
+
window_size=7,
|
199 |
+
shift_size=0,
|
200 |
+
mlp_ratio=4.0,
|
201 |
+
qkv_bias=True,
|
202 |
+
qk_scale=None,
|
203 |
+
drop=0.0,
|
204 |
+
attn_drop=0.0,
|
205 |
+
drop_path=0.0,
|
206 |
+
act_layer=nn.GELU,
|
207 |
+
norm_layer=nn.LayerNorm,
|
208 |
+
):
|
209 |
+
super().__init__()
|
210 |
+
self.dim = dim
|
211 |
+
self.num_heads = num_heads
|
212 |
+
self.window_size = window_size
|
213 |
+
self.shift_size = shift_size
|
214 |
+
self.mlp_ratio = mlp_ratio
|
215 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
216 |
+
|
217 |
+
self.norm1 = norm_layer(dim)
|
218 |
+
self.attn = WindowAttention(
|
219 |
+
dim,
|
220 |
+
window_size=to_2tuple(self.window_size),
|
221 |
+
num_heads=num_heads,
|
222 |
+
qkv_bias=qkv_bias,
|
223 |
+
qk_scale=qk_scale,
|
224 |
+
attn_drop=attn_drop,
|
225 |
+
proj_drop=drop,
|
226 |
+
)
|
227 |
+
|
228 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
229 |
+
self.norm2 = norm_layer(dim)
|
230 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
231 |
+
self.mlp = Mlp(
|
232 |
+
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
|
233 |
+
)
|
234 |
+
|
235 |
+
self.H = None
|
236 |
+
self.W = None
|
237 |
+
|
238 |
+
def forward(self, x, mask_matrix):
|
239 |
+
"""Forward function.
|
240 |
+
Args:
|
241 |
+
x: Input feature, tensor size (B, H*W, C).
|
242 |
+
H, W: Spatial resolution of the input feature.
|
243 |
+
mask_matrix: Attention mask for cyclic shift.
|
244 |
+
"""
|
245 |
+
B, L, C = x.shape
|
246 |
+
H, W = self.H, self.W
|
247 |
+
assert L == H * W, "input feature has wrong size"
|
248 |
+
|
249 |
+
shortcut = x
|
250 |
+
x = self.norm1(x)
|
251 |
+
x = x.view(B, H, W, C)
|
252 |
+
|
253 |
+
# pad feature maps to multiples of window size
|
254 |
+
pad_l = pad_t = 0
|
255 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
256 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
257 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
258 |
+
_, Hp, Wp, _ = x.shape
|
259 |
+
|
260 |
+
# cyclic shift
|
261 |
+
if self.shift_size > 0:
|
262 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
263 |
+
attn_mask = mask_matrix
|
264 |
+
else:
|
265 |
+
shifted_x = x
|
266 |
+
attn_mask = None
|
267 |
+
|
268 |
+
# partition windows
|
269 |
+
x_windows = window_partition(
|
270 |
+
shifted_x, self.window_size
|
271 |
+
) # nW*B, window_size, window_size, C
|
272 |
+
x_windows = x_windows.view(
|
273 |
+
-1, self.window_size * self.window_size, C
|
274 |
+
) # nW*B, window_size*window_size, C
|
275 |
+
|
276 |
+
# W-MSA/SW-MSA
|
277 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
278 |
+
|
279 |
+
# merge windows
|
280 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
281 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
282 |
+
|
283 |
+
# reverse cyclic shift
|
284 |
+
if self.shift_size > 0:
|
285 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
286 |
+
else:
|
287 |
+
x = shifted_x
|
288 |
+
|
289 |
+
if pad_r > 0 or pad_b > 0:
|
290 |
+
x = x[:, :H, :W, :].contiguous()
|
291 |
+
|
292 |
+
x = x.view(B, H * W, C)
|
293 |
+
|
294 |
+
# FFN
|
295 |
+
x = shortcut + self.drop_path(x)
|
296 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
297 |
+
|
298 |
+
return x
|
299 |
+
|
300 |
+
|
301 |
+
class PatchMerging(nn.Module):
|
302 |
+
"""Patch Merging Layer
|
303 |
+
Args:
|
304 |
+
dim (int): Number of input channels.
|
305 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
306 |
+
"""
|
307 |
+
|
308 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
309 |
+
super().__init__()
|
310 |
+
self.dim = dim
|
311 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
312 |
+
self.norm = norm_layer(4 * dim)
|
313 |
+
|
314 |
+
def forward(self, x, H, W):
|
315 |
+
"""Forward function.
|
316 |
+
Args:
|
317 |
+
x: Input feature, tensor size (B, H*W, C).
|
318 |
+
H, W: Spatial resolution of the input feature.
|
319 |
+
"""
|
320 |
+
B, L, C = x.shape
|
321 |
+
assert L == H * W, "input feature has wrong size"
|
322 |
+
|
323 |
+
x = x.view(B, H, W, C)
|
324 |
+
|
325 |
+
# padding
|
326 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
327 |
+
if pad_input:
|
328 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
329 |
+
|
330 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
331 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
332 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
333 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
334 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
335 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
336 |
+
|
337 |
+
x = self.norm(x)
|
338 |
+
x = self.reduction(x)
|
339 |
+
|
340 |
+
return x
|
341 |
+
|
342 |
+
|
343 |
+
class BasicLayer(nn.Module):
|
344 |
+
"""A basic Swin Transformer layer for one stage.
|
345 |
+
Args:
|
346 |
+
dim (int): Number of feature channels
|
347 |
+
depth (int): Depths of this stage.
|
348 |
+
num_heads (int): Number of attention head.
|
349 |
+
window_size (int): Local window size. Default: 7.
|
350 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
351 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
352 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
353 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
354 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
355 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
356 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
357 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
358 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
359 |
+
"""
|
360 |
+
|
361 |
+
def __init__(
|
362 |
+
self,
|
363 |
+
dim,
|
364 |
+
depth,
|
365 |
+
num_heads,
|
366 |
+
window_size=7,
|
367 |
+
mlp_ratio=4.0,
|
368 |
+
qkv_bias=True,
|
369 |
+
qk_scale=None,
|
370 |
+
drop=0.0,
|
371 |
+
attn_drop=0.0,
|
372 |
+
drop_path=0.0,
|
373 |
+
norm_layer=nn.LayerNorm,
|
374 |
+
downsample=None,
|
375 |
+
use_checkpoint=False,
|
376 |
+
):
|
377 |
+
super().__init__()
|
378 |
+
self.window_size = window_size
|
379 |
+
self.shift_size = window_size // 2
|
380 |
+
self.depth = depth
|
381 |
+
self.use_checkpoint = use_checkpoint
|
382 |
+
|
383 |
+
# build blocks
|
384 |
+
self.blocks = nn.ModuleList(
|
385 |
+
[
|
386 |
+
SwinTransformerBlock(
|
387 |
+
dim=dim,
|
388 |
+
num_heads=num_heads,
|
389 |
+
window_size=window_size,
|
390 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
391 |
+
mlp_ratio=mlp_ratio,
|
392 |
+
qkv_bias=qkv_bias,
|
393 |
+
qk_scale=qk_scale,
|
394 |
+
drop=drop,
|
395 |
+
attn_drop=attn_drop,
|
396 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
397 |
+
norm_layer=norm_layer,
|
398 |
+
)
|
399 |
+
for i in range(depth)
|
400 |
+
]
|
401 |
+
)
|
402 |
+
|
403 |
+
# patch merging layer
|
404 |
+
if downsample is not None:
|
405 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
406 |
+
else:
|
407 |
+
self.downsample = None
|
408 |
+
|
409 |
+
def forward(self, x, H, W):
|
410 |
+
"""Forward function.
|
411 |
+
Args:
|
412 |
+
x: Input feature, tensor size (B, H*W, C).
|
413 |
+
H, W: Spatial resolution of the input feature.
|
414 |
+
"""
|
415 |
+
|
416 |
+
# calculate attention mask for SW-MSA
|
417 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
418 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
419 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device, dtype=x.dtype) # 1 Hp Wp 1
|
420 |
+
h_slices = (
|
421 |
+
slice(0, -self.window_size),
|
422 |
+
slice(-self.window_size, -self.shift_size),
|
423 |
+
slice(-self.shift_size, None),
|
424 |
+
)
|
425 |
+
w_slices = (
|
426 |
+
slice(0, -self.window_size),
|
427 |
+
slice(-self.window_size, -self.shift_size),
|
428 |
+
slice(-self.shift_size, None),
|
429 |
+
)
|
430 |
+
cnt = 0
|
431 |
+
for h in h_slices:
|
432 |
+
for w in w_slices:
|
433 |
+
img_mask[:, h, w, :] = cnt
|
434 |
+
cnt += 1
|
435 |
+
|
436 |
+
mask_windows = window_partition(
|
437 |
+
img_mask, self.window_size
|
438 |
+
) # nW, window_size, window_size, 1
|
439 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
440 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
441 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
|
442 |
+
attn_mask == 0, float(0.0)
|
443 |
+
)
|
444 |
+
|
445 |
+
for blk in self.blocks:
|
446 |
+
blk.H, blk.W = H, W
|
447 |
+
if self.use_checkpoint:
|
448 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
449 |
+
else:
|
450 |
+
x = blk(x, attn_mask)
|
451 |
+
if self.downsample is not None:
|
452 |
+
x_down = self.downsample(x, H, W)
|
453 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
454 |
+
return x, H, W, x_down, Wh, Ww
|
455 |
+
else:
|
456 |
+
return x, H, W, x, H, W
|
457 |
+
|
458 |
+
|
459 |
+
class PatchEmbed(nn.Module):
|
460 |
+
"""Image to Patch Embedding
|
461 |
+
Args:
|
462 |
+
patch_size (int): Patch token size. Default: 4.
|
463 |
+
in_chans (int): Number of input image channels. Default: 3.
|
464 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
465 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
466 |
+
"""
|
467 |
+
|
468 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
469 |
+
super().__init__()
|
470 |
+
patch_size = to_2tuple(patch_size)
|
471 |
+
self.patch_size = patch_size
|
472 |
+
|
473 |
+
self.in_chans = in_chans
|
474 |
+
self.embed_dim = embed_dim
|
475 |
+
|
476 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
477 |
+
if norm_layer is not None:
|
478 |
+
self.norm = norm_layer(embed_dim)
|
479 |
+
else:
|
480 |
+
self.norm = None
|
481 |
+
|
482 |
+
def forward(self, x):
|
483 |
+
"""Forward function."""
|
484 |
+
# padding
|
485 |
+
_, _, H, W = x.size()
|
486 |
+
if W % self.patch_size[1] != 0:
|
487 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
488 |
+
if H % self.patch_size[0] != 0:
|
489 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
490 |
+
|
491 |
+
x = self.proj(x) # B C Wh Ww
|
492 |
+
if self.norm is not None:
|
493 |
+
Wh, Ww = x.size(2), x.size(3)
|
494 |
+
x = x.flatten(2).transpose(1, 2)
|
495 |
+
x = self.norm(x)
|
496 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
497 |
+
|
498 |
+
return x
|
499 |
+
|
500 |
+
|
501 |
+
class SwinTransformer(nn.Module):
|
502 |
+
"""Swin Transformer backbone.
|
503 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
504 |
+
https://arxiv.org/pdf/2103.14030
|
505 |
+
Args:
|
506 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
507 |
+
used in absolute postion embedding. Default 224.
|
508 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
509 |
+
in_chans (int): Number of input image channels. Default: 3.
|
510 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
511 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
512 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
513 |
+
window_size (int): Window size. Default: 7.
|
514 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
515 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
516 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
517 |
+
drop_rate (float): Dropout rate.
|
518 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
519 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
520 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
521 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
522 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
523 |
+
out_indices (Sequence[int]): Output from which stages.
|
524 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
525 |
+
-1 means not freezing any parameters.
|
526 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
527 |
+
dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
|
528 |
+
"""
|
529 |
+
|
530 |
+
def __init__(
|
531 |
+
self,
|
532 |
+
pretrain_img_size=224,
|
533 |
+
patch_size=4,
|
534 |
+
in_chans=3,
|
535 |
+
embed_dim=96,
|
536 |
+
depths=[2, 2, 6, 2],
|
537 |
+
num_heads=[3, 6, 12, 24],
|
538 |
+
window_size=7,
|
539 |
+
mlp_ratio=4.0,
|
540 |
+
qkv_bias=True,
|
541 |
+
qk_scale=None,
|
542 |
+
drop_rate=0.0,
|
543 |
+
attn_drop_rate=0.0,
|
544 |
+
drop_path_rate=0.2,
|
545 |
+
norm_layer=nn.LayerNorm,
|
546 |
+
ape=False,
|
547 |
+
patch_norm=True,
|
548 |
+
out_indices=(0, 1, 2, 3),
|
549 |
+
frozen_stages=-1,
|
550 |
+
dilation=False,
|
551 |
+
use_checkpoint=False,
|
552 |
+
):
|
553 |
+
super().__init__()
|
554 |
+
|
555 |
+
self.pretrain_img_size = pretrain_img_size
|
556 |
+
self.num_layers = len(depths)
|
557 |
+
self.embed_dim = embed_dim
|
558 |
+
self.ape = ape
|
559 |
+
self.patch_norm = patch_norm
|
560 |
+
self.out_indices = out_indices
|
561 |
+
self.frozen_stages = frozen_stages
|
562 |
+
self.dilation = dilation
|
563 |
+
|
564 |
+
# if use_checkpoint:
|
565 |
+
# print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
|
566 |
+
|
567 |
+
# split image into non-overlapping patches
|
568 |
+
self.patch_embed = PatchEmbed(
|
569 |
+
patch_size=patch_size,
|
570 |
+
in_chans=in_chans,
|
571 |
+
embed_dim=embed_dim,
|
572 |
+
norm_layer=norm_layer if self.patch_norm else None,
|
573 |
+
)
|
574 |
+
|
575 |
+
# absolute position embedding
|
576 |
+
if self.ape:
|
577 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
578 |
+
patch_size = to_2tuple(patch_size)
|
579 |
+
patches_resolution = [
|
580 |
+
pretrain_img_size[0] // patch_size[0],
|
581 |
+
pretrain_img_size[1] // patch_size[1],
|
582 |
+
]
|
583 |
+
|
584 |
+
self.absolute_pos_embed = nn.Parameter(
|
585 |
+
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
|
586 |
+
)
|
587 |
+
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
588 |
+
|
589 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
590 |
+
|
591 |
+
# stochastic depth
|
592 |
+
dpr = [
|
593 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
594 |
+
] # stochastic depth decay rule
|
595 |
+
|
596 |
+
# build layers
|
597 |
+
self.layers = nn.ModuleList()
|
598 |
+
# prepare downsample list
|
599 |
+
downsamplelist = [PatchMerging for i in range(self.num_layers)]
|
600 |
+
downsamplelist[-1] = None
|
601 |
+
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
602 |
+
if self.dilation:
|
603 |
+
downsamplelist[-2] = None
|
604 |
+
num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
|
605 |
+
for i_layer in range(self.num_layers):
|
606 |
+
layer = BasicLayer(
|
607 |
+
# dim=int(embed_dim * 2 ** i_layer),
|
608 |
+
dim=num_features[i_layer],
|
609 |
+
depth=depths[i_layer],
|
610 |
+
num_heads=num_heads[i_layer],
|
611 |
+
window_size=window_size,
|
612 |
+
mlp_ratio=mlp_ratio,
|
613 |
+
qkv_bias=qkv_bias,
|
614 |
+
qk_scale=qk_scale,
|
615 |
+
drop=drop_rate,
|
616 |
+
attn_drop=attn_drop_rate,
|
617 |
+
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
618 |
+
norm_layer=norm_layer,
|
619 |
+
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
620 |
+
downsample=downsamplelist[i_layer],
|
621 |
+
use_checkpoint=use_checkpoint,
|
622 |
+
)
|
623 |
+
self.layers.append(layer)
|
624 |
+
|
625 |
+
# num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
626 |
+
self.num_features = num_features
|
627 |
+
|
628 |
+
# add a norm layer for each output
|
629 |
+
for i_layer in out_indices:
|
630 |
+
layer = norm_layer(num_features[i_layer])
|
631 |
+
layer_name = f"norm{i_layer}"
|
632 |
+
self.add_module(layer_name, layer)
|
633 |
+
|
634 |
+
self._freeze_stages()
|
635 |
+
|
636 |
+
def _freeze_stages(self):
|
637 |
+
if self.frozen_stages >= 0:
|
638 |
+
self.patch_embed.eval()
|
639 |
+
for param in self.patch_embed.parameters():
|
640 |
+
param.requires_grad = False
|
641 |
+
|
642 |
+
if self.frozen_stages >= 1 and self.ape:
|
643 |
+
self.absolute_pos_embed.requires_grad = False
|
644 |
+
|
645 |
+
if self.frozen_stages >= 2:
|
646 |
+
self.pos_drop.eval()
|
647 |
+
for i in range(0, self.frozen_stages - 1):
|
648 |
+
m = self.layers[i]
|
649 |
+
m.eval()
|
650 |
+
for param in m.parameters():
|
651 |
+
param.requires_grad = False
|
652 |
+
|
653 |
+
# def init_weights(self, pretrained=None):
|
654 |
+
# """Initialize the weights in backbone.
|
655 |
+
# Args:
|
656 |
+
# pretrained (str, optional): Path to pre-trained weights.
|
657 |
+
# Defaults to None.
|
658 |
+
# """
|
659 |
+
|
660 |
+
# def _init_weights(m):
|
661 |
+
# if isinstance(m, nn.Linear):
|
662 |
+
# trunc_normal_(m.weight, std=.02)
|
663 |
+
# if isinstance(m, nn.Linear) and m.bias is not None:
|
664 |
+
# nn.init.constant_(m.bias, 0)
|
665 |
+
# elif isinstance(m, nn.LayerNorm):
|
666 |
+
# nn.init.constant_(m.bias, 0)
|
667 |
+
# nn.init.constant_(m.weight, 1.0)
|
668 |
+
|
669 |
+
# if isinstance(pretrained, str):
|
670 |
+
# self.apply(_init_weights)
|
671 |
+
# logger = get_root_logger()
|
672 |
+
# load_checkpoint(self, pretrained, strict=False, logger=logger)
|
673 |
+
# elif pretrained is None:
|
674 |
+
# self.apply(_init_weights)
|
675 |
+
# else:
|
676 |
+
# raise TypeError('pretrained must be a str or None')
|
677 |
+
|
678 |
+
def forward_raw(self, x):
|
679 |
+
"""Forward function."""
|
680 |
+
x = self.patch_embed(x)
|
681 |
+
|
682 |
+
Wh, Ww = x.size(2), x.size(3)
|
683 |
+
if self.ape:
|
684 |
+
# interpolate the position embedding to the corresponding size
|
685 |
+
absolute_pos_embed = F.interpolate(
|
686 |
+
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
687 |
+
)
|
688 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
689 |
+
else:
|
690 |
+
x = x.flatten(2).transpose(1, 2)
|
691 |
+
x = self.pos_drop(x)
|
692 |
+
|
693 |
+
outs = []
|
694 |
+
for i in range(self.num_layers):
|
695 |
+
layer = self.layers[i]
|
696 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
697 |
+
# import ipdb; ipdb.set_trace()
|
698 |
+
|
699 |
+
if i in self.out_indices:
|
700 |
+
norm_layer = getattr(self, f"norm{i}")
|
701 |
+
x_out = norm_layer(x_out)
|
702 |
+
|
703 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
704 |
+
outs.append(out)
|
705 |
+
# in:
|
706 |
+
# torch.Size([2, 3, 1024, 1024])
|
707 |
+
# outs:
|
708 |
+
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
709 |
+
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
710 |
+
return tuple(outs)
|
711 |
+
|
712 |
+
def forward(self, tensor_list: NestedTensor):
|
713 |
+
x = tensor_list.tensors
|
714 |
+
|
715 |
+
"""Forward function."""
|
716 |
+
x = self.patch_embed(x)
|
717 |
+
|
718 |
+
Wh, Ww = x.size(2), x.size(3)
|
719 |
+
if self.ape:
|
720 |
+
# interpolate the position embedding to the corresponding size
|
721 |
+
absolute_pos_embed = F.interpolate(
|
722 |
+
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
723 |
+
)
|
724 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
725 |
+
else:
|
726 |
+
x = x.flatten(2).transpose(1, 2)
|
727 |
+
x = self.pos_drop(x)
|
728 |
+
|
729 |
+
outs = []
|
730 |
+
for i in range(self.num_layers):
|
731 |
+
layer = self.layers[i]
|
732 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
733 |
+
|
734 |
+
if i in self.out_indices:
|
735 |
+
norm_layer = getattr(self, f"norm{i}")
|
736 |
+
x_out = norm_layer(x_out)
|
737 |
+
|
738 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
739 |
+
outs.append(out)
|
740 |
+
# in:
|
741 |
+
# torch.Size([2, 3, 1024, 1024])
|
742 |
+
# out:
|
743 |
+
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
744 |
+
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
745 |
+
|
746 |
+
# collect for nesttensors
|
747 |
+
outs_dict = {}
|
748 |
+
for idx, out_i in enumerate(outs):
|
749 |
+
m = tensor_list.mask
|
750 |
+
assert m is not None
|
751 |
+
mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
|
752 |
+
outs_dict[idx] = NestedTensor(out_i, mask)
|
753 |
+
|
754 |
+
return outs_dict
|
755 |
+
|
756 |
+
def train(self, mode=True):
|
757 |
+
"""Convert the model into training mode while keep layers freezed."""
|
758 |
+
super(SwinTransformer, self).train(mode)
|
759 |
+
self._freeze_stages()
|
760 |
+
|
761 |
+
|
762 |
+
def build_swin_transformer(modelname, pretrain_img_size, **kw):
|
763 |
+
assert modelname in [
|
764 |
+
"swin_T_224_1k",
|
765 |
+
"swin_B_224_22k",
|
766 |
+
"swin_B_384_22k",
|
767 |
+
"swin_L_224_22k",
|
768 |
+
"swin_L_384_22k",
|
769 |
+
]
|
770 |
+
|
771 |
+
model_para_dict = {
|
772 |
+
"swin_T_224_1k": dict(
|
773 |
+
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
|
774 |
+
),
|
775 |
+
"swin_B_224_22k": dict(
|
776 |
+
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
|
777 |
+
),
|
778 |
+
"swin_B_384_22k": dict(
|
779 |
+
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
|
780 |
+
),
|
781 |
+
"swin_L_224_22k": dict(
|
782 |
+
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
|
783 |
+
),
|
784 |
+
"swin_L_384_22k": dict(
|
785 |
+
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
|
786 |
+
),
|
787 |
+
}
|
788 |
+
kw_cgf = model_para_dict[modelname]
|
789 |
+
kw_cgf.update(kw)
|
790 |
+
model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
|
791 |
+
return model
|
792 |
+
|
793 |
+
|
794 |
+
if __name__ == "__main__":
|
795 |
+
model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
|
796 |
+
x = torch.rand(2, 3, 1024, 1024)
|
797 |
+
y = model.forward_raw(x)
|
798 |
+
import ipdb
|
799 |
+
|
800 |
+
ipdb.set_trace()
|
801 |
+
x = torch.rand(2, 3, 384, 384)
|
802 |
+
y = model.forward_raw(x)
|
GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Grounding DINO
|
3 |
+
# url: https://github.com/IDEA-Research/GroundingDINO
|
4 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint as checkpoint
|
11 |
+
from torch import Tensor, nn
|
12 |
+
from torchvision.ops.boxes import nms
|
13 |
+
from transformers import BertConfig, BertModel, BertPreTrainedModel
|
14 |
+
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
|
15 |
+
|
16 |
+
|
17 |
+
class BertModelWarper(nn.Module):
|
18 |
+
def __init__(self, bert_model):
|
19 |
+
super().__init__()
|
20 |
+
# self.bert = bert_modelc
|
21 |
+
|
22 |
+
self.config = bert_model.config
|
23 |
+
self.embeddings = bert_model.embeddings
|
24 |
+
self.encoder = bert_model.encoder
|
25 |
+
self.pooler = bert_model.pooler
|
26 |
+
|
27 |
+
self.get_extended_attention_mask = bert_model.get_extended_attention_mask
|
28 |
+
self.invert_attention_mask = bert_model.invert_attention_mask
|
29 |
+
self.get_head_mask = bert_model.get_head_mask
|
30 |
+
|
31 |
+
def forward(
|
32 |
+
self,
|
33 |
+
input_ids=None,
|
34 |
+
attention_mask=None,
|
35 |
+
token_type_ids=None,
|
36 |
+
position_ids=None,
|
37 |
+
head_mask=None,
|
38 |
+
inputs_embeds=None,
|
39 |
+
encoder_hidden_states=None,
|
40 |
+
encoder_attention_mask=None,
|
41 |
+
past_key_values=None,
|
42 |
+
use_cache=None,
|
43 |
+
output_attentions=None,
|
44 |
+
output_hidden_states=None,
|
45 |
+
return_dict=None,
|
46 |
+
):
|
47 |
+
r"""
|
48 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
49 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
50 |
+
the model is configured as a decoder.
|
51 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
52 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
53 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
54 |
+
|
55 |
+
- 1 for tokens that are **not masked**,
|
56 |
+
- 0 for tokens that are **masked**.
|
57 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
58 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
59 |
+
|
60 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
61 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
62 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
63 |
+
use_cache (:obj:`bool`, `optional`):
|
64 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
65 |
+
decoding (see :obj:`past_key_values`).
|
66 |
+
"""
|
67 |
+
output_attentions = (
|
68 |
+
output_attentions if output_attentions is not None else self.config.output_attentions
|
69 |
+
)
|
70 |
+
output_hidden_states = (
|
71 |
+
output_hidden_states
|
72 |
+
if output_hidden_states is not None
|
73 |
+
else self.config.output_hidden_states
|
74 |
+
)
|
75 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
76 |
+
|
77 |
+
if self.config.is_decoder:
|
78 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
79 |
+
else:
|
80 |
+
use_cache = False
|
81 |
+
|
82 |
+
if input_ids is not None and inputs_embeds is not None:
|
83 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
84 |
+
elif input_ids is not None:
|
85 |
+
input_shape = input_ids.size()
|
86 |
+
batch_size, seq_length = input_shape
|
87 |
+
elif inputs_embeds is not None:
|
88 |
+
input_shape = inputs_embeds.size()[:-1]
|
89 |
+
batch_size, seq_length = input_shape
|
90 |
+
else:
|
91 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
92 |
+
|
93 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
94 |
+
|
95 |
+
# past_key_values_length
|
96 |
+
past_key_values_length = (
|
97 |
+
past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
98 |
+
)
|
99 |
+
|
100 |
+
if attention_mask is None:
|
101 |
+
attention_mask = torch.ones(
|
102 |
+
((batch_size, seq_length + past_key_values_length)), device=device
|
103 |
+
)
|
104 |
+
if token_type_ids is None:
|
105 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
106 |
+
|
107 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
108 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
109 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
110 |
+
attention_mask, input_shape, device
|
111 |
+
)
|
112 |
+
|
113 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
114 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
115 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
116 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
117 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
118 |
+
if encoder_attention_mask is None:
|
119 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
120 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
121 |
+
else:
|
122 |
+
encoder_extended_attention_mask = None
|
123 |
+
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
124 |
+
# import ipdb; ipdb.set_trace()
|
125 |
+
|
126 |
+
# Prepare head mask if needed
|
127 |
+
# 1.0 in head_mask indicate we keep the head
|
128 |
+
# attention_probs has shape bsz x n_heads x N x N
|
129 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
130 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
131 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
132 |
+
|
133 |
+
embedding_output = self.embeddings(
|
134 |
+
input_ids=input_ids,
|
135 |
+
position_ids=position_ids,
|
136 |
+
token_type_ids=token_type_ids,
|
137 |
+
inputs_embeds=inputs_embeds,
|
138 |
+
past_key_values_length=past_key_values_length,
|
139 |
+
)
|
140 |
+
|
141 |
+
encoder_outputs = self.encoder(
|
142 |
+
embedding_output,
|
143 |
+
attention_mask=extended_attention_mask,
|
144 |
+
head_mask=head_mask,
|
145 |
+
encoder_hidden_states=encoder_hidden_states,
|
146 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
147 |
+
past_key_values=past_key_values,
|
148 |
+
use_cache=use_cache,
|
149 |
+
output_attentions=output_attentions,
|
150 |
+
output_hidden_states=output_hidden_states,
|
151 |
+
return_dict=return_dict,
|
152 |
+
)
|
153 |
+
sequence_output = encoder_outputs[0]
|
154 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
155 |
+
|
156 |
+
if not return_dict:
|
157 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
158 |
+
|
159 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
160 |
+
last_hidden_state=sequence_output,
|
161 |
+
pooler_output=pooled_output,
|
162 |
+
past_key_values=encoder_outputs.past_key_values,
|
163 |
+
hidden_states=encoder_outputs.hidden_states,
|
164 |
+
attentions=encoder_outputs.attentions,
|
165 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
166 |
+
)
|
167 |
+
|
168 |
+
|
169 |
+
class TextEncoderShell(nn.Module):
|
170 |
+
def __init__(self, text_encoder):
|
171 |
+
super().__init__()
|
172 |
+
self.text_encoder = text_encoder
|
173 |
+
self.config = self.text_encoder.config
|
174 |
+
|
175 |
+
def forward(self, **kw):
|
176 |
+
# feed into text encoder
|
177 |
+
return self.text_encoder(**kw)
|
178 |
+
|
179 |
+
|
180 |
+
def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
|
181 |
+
"""Generate attention mask between each pair of special tokens
|
182 |
+
Args:
|
183 |
+
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
|
184 |
+
special_tokens_mask (list): special tokens mask.
|
185 |
+
Returns:
|
186 |
+
torch.Tensor: attention mask between each special tokens.
|
187 |
+
"""
|
188 |
+
input_ids = tokenized["input_ids"]
|
189 |
+
bs, num_token = input_ids.shape
|
190 |
+
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
|
191 |
+
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
|
192 |
+
for special_token in special_tokens_list:
|
193 |
+
special_tokens_mask |= input_ids == special_token
|
194 |
+
|
195 |
+
# idxs: each row is a list of indices of special tokens
|
196 |
+
idxs = torch.nonzero(special_tokens_mask)
|
197 |
+
|
198 |
+
# generate attention mask and positional ids
|
199 |
+
attention_mask = (
|
200 |
+
torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
|
201 |
+
)
|
202 |
+
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
|
203 |
+
previous_col = 0
|
204 |
+
for i in range(idxs.shape[0]):
|
205 |
+
row, col = idxs[i]
|
206 |
+
if (col == 0) or (col == num_token - 1):
|
207 |
+
attention_mask[row, col, col] = True
|
208 |
+
position_ids[row, col] = 0
|
209 |
+
else:
|
210 |
+
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
|
211 |
+
position_ids[row, previous_col + 1 : col + 1] = torch.arange(
|
212 |
+
0, col - previous_col, device=input_ids.device
|
213 |
+
)
|
214 |
+
|
215 |
+
previous_col = col
|
216 |
+
|
217 |
+
# # padding mask
|
218 |
+
# padding_mask = tokenized['attention_mask']
|
219 |
+
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
|
220 |
+
|
221 |
+
return attention_mask, position_ids.to(torch.long)
|
222 |
+
|
223 |
+
|
224 |
+
def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
|
225 |
+
"""Generate attention mask between each pair of special tokens
|
226 |
+
Args:
|
227 |
+
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
|
228 |
+
special_tokens_mask (list): special tokens mask.
|
229 |
+
Returns:
|
230 |
+
torch.Tensor: attention mask between each special tokens.
|
231 |
+
"""
|
232 |
+
input_ids = tokenized["input_ids"]
|
233 |
+
bs, num_token = input_ids.shape
|
234 |
+
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
|
235 |
+
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
|
236 |
+
for special_token in special_tokens_list:
|
237 |
+
special_tokens_mask |= input_ids == special_token
|
238 |
+
|
239 |
+
# idxs: each row is a list of indices of special tokens
|
240 |
+
idxs = torch.nonzero(special_tokens_mask)
|
241 |
+
|
242 |
+
# generate attention mask and positional ids
|
243 |
+
attention_mask = (
|
244 |
+
torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
|
245 |
+
)
|
246 |
+
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
|
247 |
+
cate_to_token_mask_list = [[] for _ in range(bs)]
|
248 |
+
previous_col = 0
|
249 |
+
for i in range(idxs.shape[0]):
|
250 |
+
row, col = idxs[i]
|
251 |
+
if (col == 0) or (col == num_token - 1):
|
252 |
+
attention_mask[row, col, col] = True
|
253 |
+
position_ids[row, col] = 0
|
254 |
+
else:
|
255 |
+
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
|
256 |
+
position_ids[row, previous_col + 1 : col + 1] = torch.arange(
|
257 |
+
0, col - previous_col, device=input_ids.device
|
258 |
+
)
|
259 |
+
c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
|
260 |
+
c2t_maski[previous_col + 1 : col] = True
|
261 |
+
cate_to_token_mask_list[row].append(c2t_maski)
|
262 |
+
previous_col = col
|
263 |
+
|
264 |
+
cate_to_token_mask_list = [
|
265 |
+
torch.stack(cate_to_token_mask_listi, dim=0)
|
266 |
+
for cate_to_token_mask_listi in cate_to_token_mask_list
|
267 |
+
]
|
268 |
+
|
269 |
+
# # padding mask
|
270 |
+
# padding_mask = tokenized['attention_mask']
|
271 |
+
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
|
272 |
+
|
273 |
+
return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list
|
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#pragma once
|
12 |
+
|
13 |
+
#include "ms_deform_attn_cpu.h"
|
14 |
+
|
15 |
+
#ifdef WITH_CUDA
|
16 |
+
#include "ms_deform_attn_cuda.h"
|
17 |
+
#endif
|
18 |
+
|
19 |
+
namespace groundingdino {
|
20 |
+
|
21 |
+
at::Tensor
|
22 |
+
ms_deform_attn_forward(
|
23 |
+
const at::Tensor &value,
|
24 |
+
const at::Tensor &spatial_shapes,
|
25 |
+
const at::Tensor &level_start_index,
|
26 |
+
const at::Tensor &sampling_loc,
|
27 |
+
const at::Tensor &attn_weight,
|
28 |
+
const int im2col_step)
|
29 |
+
{
|
30 |
+
if (value.type().is_cuda())
|
31 |
+
{
|
32 |
+
#ifdef WITH_CUDA
|
33 |
+
return ms_deform_attn_cuda_forward(
|
34 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
|
35 |
+
#else
|
36 |
+
AT_ERROR("Not compiled with GPU support");
|
37 |
+
#endif
|
38 |
+
}
|
39 |
+
AT_ERROR("Not implemented on the CPU");
|
40 |
+
}
|
41 |
+
|
42 |
+
std::vector<at::Tensor>
|
43 |
+
ms_deform_attn_backward(
|
44 |
+
const at::Tensor &value,
|
45 |
+
const at::Tensor &spatial_shapes,
|
46 |
+
const at::Tensor &level_start_index,
|
47 |
+
const at::Tensor &sampling_loc,
|
48 |
+
const at::Tensor &attn_weight,
|
49 |
+
const at::Tensor &grad_output,
|
50 |
+
const int im2col_step)
|
51 |
+
{
|
52 |
+
if (value.type().is_cuda())
|
53 |
+
{
|
54 |
+
#ifdef WITH_CUDA
|
55 |
+
return ms_deform_attn_cuda_backward(
|
56 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
|
57 |
+
#else
|
58 |
+
AT_ERROR("Not compiled with GPU support");
|
59 |
+
#endif
|
60 |
+
}
|
61 |
+
AT_ERROR("Not implemented on the CPU");
|
62 |
+
}
|
63 |
+
|
64 |
+
} // namespace groundingdino
|
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#include <vector>
|
12 |
+
|
13 |
+
#include <ATen/ATen.h>
|
14 |
+
#include <ATen/cuda/CUDAContext.h>
|
15 |
+
|
16 |
+
namespace groundingdino {
|
17 |
+
|
18 |
+
at::Tensor
|
19 |
+
ms_deform_attn_cpu_forward(
|
20 |
+
const at::Tensor &value,
|
21 |
+
const at::Tensor &spatial_shapes,
|
22 |
+
const at::Tensor &level_start_index,
|
23 |
+
const at::Tensor &sampling_loc,
|
24 |
+
const at::Tensor &attn_weight,
|
25 |
+
const int im2col_step)
|
26 |
+
{
|
27 |
+
AT_ERROR("Not implement on cpu");
|
28 |
+
}
|
29 |
+
|
30 |
+
std::vector<at::Tensor>
|
31 |
+
ms_deform_attn_cpu_backward(
|
32 |
+
const at::Tensor &value,
|
33 |
+
const at::Tensor &spatial_shapes,
|
34 |
+
const at::Tensor &level_start_index,
|
35 |
+
const at::Tensor &sampling_loc,
|
36 |
+
const at::Tensor &attn_weight,
|
37 |
+
const at::Tensor &grad_output,
|
38 |
+
const int im2col_step)
|
39 |
+
{
|
40 |
+
AT_ERROR("Not implement on cpu");
|
41 |
+
}
|
42 |
+
|
43 |
+
} // namespace groundingdino
|
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#pragma once
|
12 |
+
#include <torch/extension.h>
|
13 |
+
|
14 |
+
namespace groundingdino {
|
15 |
+
|
16 |
+
at::Tensor
|
17 |
+
ms_deform_attn_cpu_forward(
|
18 |
+
const at::Tensor &value,
|
19 |
+
const at::Tensor &spatial_shapes,
|
20 |
+
const at::Tensor &level_start_index,
|
21 |
+
const at::Tensor &sampling_loc,
|
22 |
+
const at::Tensor &attn_weight,
|
23 |
+
const int im2col_step);
|
24 |
+
|
25 |
+
std::vector<at::Tensor>
|
26 |
+
ms_deform_attn_cpu_backward(
|
27 |
+
const at::Tensor &value,
|
28 |
+
const at::Tensor &spatial_shapes,
|
29 |
+
const at::Tensor &level_start_index,
|
30 |
+
const at::Tensor &sampling_loc,
|
31 |
+
const at::Tensor &attn_weight,
|
32 |
+
const at::Tensor &grad_output,
|
33 |
+
const int im2col_step);
|
34 |
+
|
35 |
+
} // namespace groundingdino
|
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#include <vector>
|
12 |
+
#include "ms_deform_im2col_cuda.cuh"
|
13 |
+
|
14 |
+
#include <ATen/ATen.h>
|
15 |
+
#include <ATen/cuda/CUDAContext.h>
|
16 |
+
#include <cuda.h>
|
17 |
+
#include <cuda_runtime.h>
|
18 |
+
|
19 |
+
namespace groundingdino {
|
20 |
+
|
21 |
+
at::Tensor ms_deform_attn_cuda_forward(
|
22 |
+
const at::Tensor &value,
|
23 |
+
const at::Tensor &spatial_shapes,
|
24 |
+
const at::Tensor &level_start_index,
|
25 |
+
const at::Tensor &sampling_loc,
|
26 |
+
const at::Tensor &attn_weight,
|
27 |
+
const int im2col_step)
|
28 |
+
{
|
29 |
+
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
30 |
+
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
31 |
+
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
32 |
+
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
33 |
+
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
34 |
+
|
35 |
+
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
36 |
+
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
37 |
+
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
38 |
+
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
39 |
+
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
40 |
+
|
41 |
+
const int batch = value.size(0);
|
42 |
+
const int spatial_size = value.size(1);
|
43 |
+
const int num_heads = value.size(2);
|
44 |
+
const int channels = value.size(3);
|
45 |
+
|
46 |
+
const int num_levels = spatial_shapes.size(0);
|
47 |
+
|
48 |
+
const int num_query = sampling_loc.size(1);
|
49 |
+
const int num_point = sampling_loc.size(4);
|
50 |
+
|
51 |
+
const int im2col_step_ = std::min(batch, im2col_step);
|
52 |
+
|
53 |
+
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
54 |
+
|
55 |
+
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
56 |
+
|
57 |
+
const int batch_n = im2col_step_;
|
58 |
+
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
59 |
+
auto per_value_size = spatial_size * num_heads * channels;
|
60 |
+
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
61 |
+
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
62 |
+
for (int n = 0; n < batch/im2col_step_; ++n)
|
63 |
+
{
|
64 |
+
auto columns = output_n.select(0, n);
|
65 |
+
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
66 |
+
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
67 |
+
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
68 |
+
spatial_shapes.data<int64_t>(),
|
69 |
+
level_start_index.data<int64_t>(),
|
70 |
+
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
71 |
+
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
72 |
+
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
73 |
+
columns.data<scalar_t>());
|
74 |
+
|
75 |
+
}));
|
76 |
+
}
|
77 |
+
|
78 |
+
output = output.view({batch, num_query, num_heads*channels});
|
79 |
+
|
80 |
+
return output;
|
81 |
+
}
|
82 |
+
|
83 |
+
|
84 |
+
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
85 |
+
const at::Tensor &value,
|
86 |
+
const at::Tensor &spatial_shapes,
|
87 |
+
const at::Tensor &level_start_index,
|
88 |
+
const at::Tensor &sampling_loc,
|
89 |
+
const at::Tensor &attn_weight,
|
90 |
+
const at::Tensor &grad_output,
|
91 |
+
const int im2col_step)
|
92 |
+
{
|
93 |
+
|
94 |
+
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
95 |
+
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
96 |
+
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
97 |
+
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
98 |
+
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
99 |
+
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
100 |
+
|
101 |
+
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
102 |
+
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
103 |
+
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
104 |
+
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
105 |
+
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
106 |
+
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
|
107 |
+
|
108 |
+
const int batch = value.size(0);
|
109 |
+
const int spatial_size = value.size(1);
|
110 |
+
const int num_heads = value.size(2);
|
111 |
+
const int channels = value.size(3);
|
112 |
+
|
113 |
+
const int num_levels = spatial_shapes.size(0);
|
114 |
+
|
115 |
+
const int num_query = sampling_loc.size(1);
|
116 |
+
const int num_point = sampling_loc.size(4);
|
117 |
+
|
118 |
+
const int im2col_step_ = std::min(batch, im2col_step);
|
119 |
+
|
120 |
+
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
121 |
+
|
122 |
+
auto grad_value = at::zeros_like(value);
|
123 |
+
auto grad_sampling_loc = at::zeros_like(sampling_loc);
|
124 |
+
auto grad_attn_weight = at::zeros_like(attn_weight);
|
125 |
+
|
126 |
+
const int batch_n = im2col_step_;
|
127 |
+
auto per_value_size = spatial_size * num_heads * channels;
|
128 |
+
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
129 |
+
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
130 |
+
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
131 |
+
|
132 |
+
for (int n = 0; n < batch/im2col_step_; ++n)
|
133 |
+
{
|
134 |
+
auto grad_output_g = grad_output_n.select(0, n);
|
135 |
+
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
136 |
+
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
137 |
+
grad_output_g.data<scalar_t>(),
|
138 |
+
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
139 |
+
spatial_shapes.data<int64_t>(),
|
140 |
+
level_start_index.data<int64_t>(),
|
141 |
+
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
142 |
+
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
143 |
+
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
144 |
+
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
145 |
+
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
146 |
+
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
|
147 |
+
|
148 |
+
}));
|
149 |
+
}
|
150 |
+
|
151 |
+
return {
|
152 |
+
grad_value, grad_sampling_loc, grad_attn_weight
|
153 |
+
};
|
154 |
+
}
|
155 |
+
|
156 |
+
} // namespace groundingdino
|