Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import argparse
import glob
import os
import re
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import transforms as pth_transforms
from tqdm import tqdm

import utils
import vision_transformer as vits
# OpenCV FourCC codec code used for each supported --video_format value.
FOURCC = {
    extension: cv2.VideoWriter_fourcc(*codec)
    for extension, codec in (("mp4", "MP4V"), ("avi", "XVID"))
}
# Run the model on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class VideoGenerator:
    """Turn a video (or a folder of frames) into DINO self-attention videos.

    Depending on ``args`` it:

    * extracts the frames of a video file,
    * runs the ViT on every ``.jpg`` frame and saves the mean self-attention
      heatmap as ``attn-<frame>.jpg``,
    * assembles the heatmaps, and the original frames resized to the
      heatmap size, into videos.

    With ``--video_only`` it only assembles an existing folder of
    ``attn-*.jpg`` images into a video.
    """

    def __init__(self, args):
        """Store the parsed CLI ``args`` and load the model unless ``video_only``."""
        self.args = args
        # The model is not needed when we only assemble a video from images.
        self.model = None
        if not self.args.video_only:
            self.model = self.__load_model()

    def run(self):
        """Dispatch on ``args`` and produce the requested outputs.

        Exits the process with status 1 when ``input_path`` is missing, does
        not exist, or is neither a file nor a directory.
        """
        input_path = self.args.input_path
        output_path = self.args.output_path
        if input_path is None:
            print(f"Provided input path {input_path} is non valid.")
            sys.exit(1)

        if self.args.video_only:
            self._generate_video_from_images(input_path, output_path)
            return

        if not os.path.exists(input_path):
            print(f"Provided input path {input_path} doesn't exists.")
            sys.exit(1)

        frames_folder = os.path.join(output_path, "frames")
        os.makedirs(frames_folder, exist_ok=True)
        if os.path.isfile(input_path):
            # Video file: extract its frames first (this may update args.fps).
            self._extract_frames_from_video(input_path, frames_folder)
            source_folder = frames_folder
        elif os.path.isdir(input_path):
            # Folder of already extracted frames.
            source_folder = input_path
        else:
            print(f"Provided input path {input_path} is non valid.")
            sys.exit(1)

        attention_folder = os.path.join(output_path, "attention")
        os.makedirs(attention_folder, exist_ok=True)
        self._inference(source_folder, attention_folder)
        self._generate_video_from_images(attention_folder, output_path)
        # _inference writes the resized original frames as
        # <output_path>/frames/reshaped-*.jpg; encode them as a video too.
        self._generate_video_from_images(
            frames_folder,
            output_path,
            file_pattern="reshaped-*.jpg",
            out_video_name="original-reshaped",
        )

    @staticmethod
    def _sorted_glob(folder, pattern):
        """Return paths in ``folder`` matching ``pattern`` in natural order.

        Digit runs compare numerically, so ``frame-10000.jpg`` sorts after
        ``frame-9999.jpg``. A plain lexicographic sort would place it between
        ``frame-1000.jpg`` and ``frame-1001.jpg``.
        """

        def natural_key(path):
            # re.split with a capture group alternates text / digit-run pieces,
            # so odd indices are always the digit runs.
            parts = re.split(r"(\d+)", path)
            return [int(p) if i % 2 else p for i, p in enumerate(parts)], path

        return sorted(glob.glob(os.path.join(folder, pattern)), key=natural_key)

    @staticmethod
    def _read_bgr(path):
        """Load an image file as an RGB-decoded, BGR-ordered array for OpenCV."""
        with open(path, "rb") as f:
            rgb = np.array(Image.open(f).convert("RGB"))
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    def _extract_frames_from_video(self, inp: str, out: str):
        """Write every frame of video ``inp`` to ``out/frame-NNNN.jpg``.

        Sets ``self.args.fps`` from the video when OpenCV reports a positive
        frame rate; otherwise the ``--fps`` value is kept. Exits with
        status 1 if the video cannot be opened.
        """
        vidcap = cv2.VideoCapture(inp)
        try:
            if not vidcap.isOpened():
                print(f"Could not open video {inp}.")
                sys.exit(1)
            fps = vidcap.get(cv2.CAP_PROP_FPS)
            # Some streams report no usable frame rate (<= 0); writing a video
            # at that rate would fail, so keep the --fps fallback instead.
            if fps > 0:
                self.args.fps = fps
            print(f"Video: {inp} ({self.args.fps} fps)")
            print(f"Extracting frames to {out}")
            count = 0
            success, image = vidcap.read()
            while success:
                cv2.imwrite(os.path.join(out, f"frame-{count:04}.jpg"), image)
                count += 1
                success, image = vidcap.read()
        finally:
            # Release the capture handle even if writing a frame fails.
            vidcap.release()

    def _generate_video_from_images(
        self, inp: str, out: str, file_pattern="attn-*.jpg", out_video_name="video"
    ):
        """Encode the images in ``inp`` that match ``file_pattern`` into a video.

        The video is written to ``out/<out_video_name>.<args.video_format>``
        at ``args.fps``. Its frame size comes from the first image. Images are
        streamed to the writer one at a time instead of all being held in
        memory.

        Raises:
            FileNotFoundError: no image in ``inp`` matches ``file_pattern``.
            RuntimeError: OpenCV could not open the video writer.
        """
        image_paths = self._sorted_glob(inp, file_pattern)
        if not image_paths:
            raise FileNotFoundError(
                f"No images matching {file_pattern!r} found in {inp}"
            )

        first_frame = self._read_bgr(image_paths[0])
        size = (first_frame.shape[1], first_frame.shape[0])  # (width, height)
        print(f"Generating video {size} to {out}")

        video_path = os.path.join(out, f"{out_video_name}." + self.args.video_format)
        writer = cv2.VideoWriter(
            video_path, FOURCC[self.args.video_format], self.args.fps, size
        )
        if not writer.isOpened():
            # Otherwise OpenCV silently drops every frame written below.
            raise RuntimeError(f"Could not open video writer for {video_path}")
        try:
            writer.write(first_frame)
            for filename in tqdm(image_paths[1:]):
                writer.write(self._read_bgr(filename))
        finally:
            writer.release()
        print("Done")

    def _inference(self, inp: str, out: str):
        """Save a self-attention heatmap for every ``.jpg`` frame in ``inp``.

        For each frame ``<name>.jpg`` this writes two files:

        * ``out/attn-<name>.jpg``: the mean attention over heads, using the
          "inferno" colormap.
        * ``<dirname(out)>/frames/reshaped-<name>.jpg``: the original frame
          resized to the heatmap size.

        Inputs already named ``reshaped-*.jpg`` are skipped, so outputs of a
        previous run in the same frames folder are not fed back in.

        Note: ``--threshold`` is not used here; only the mean heatmap is saved.
        """
        print(f"Generating attention images to {out}")
        patch = self.args.patch_size
        reshaped_folder = os.path.join(os.path.dirname(out), "frames")

        # The transform does not depend on the frame, so build it once.
        steps = [pth_transforms.ToTensor()]
        if self.args.resize is not None:
            steps.append(pth_transforms.Resize(self.args.resize))
        steps.append(
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        )
        transform = pth_transforms.Compose(steps)

        frame_paths = [
            path
            for path in self._sorted_glob(inp, "*.jpg")
            if not os.path.basename(path).startswith("reshaped-")
        ]
        for img_path in tqdm(frame_paths):
            with open(img_path, "rb") as f:
                img_in = Image.open(f).convert("RGB")
            img = transform(img_in)  # C x H x W

            # Crop H and W down to multiples of the patch size.
            height = img.shape[1] - img.shape[1] % patch
            width = img.shape[2] - img.shape[2] % patch
            img = img[:, :height, :width].unsqueeze(0)
            h_featmap = height // patch
            w_featmap = width // patch

            with torch.no_grad():
                attentions = self.model.get_last_selfattention(img.to(DEVICE))
            nh = attentions.shape[1]  # number of heads
            # Attention from token 0 (presumably [CLS]) to every patch token,
            # laid out on the patch grid: (heads, H / patch, W / patch).
            attentions = attentions[0, :, 0, 1:].reshape(nh, h_featmap, w_featmap)
            # Upsample each head's map back to pixel resolution.
            attentions = (
                nn.functional.interpolate(
                    attentions.unsqueeze(0), scale_factor=patch, mode="nearest"
                )[0]
                .cpu()
                .numpy()
            )

            # Mean attention over heads, saved as a heatmap image.
            plt.imsave(
                fname=os.path.join(out, "attn-" + os.path.basename(img_path)),
                arr=attentions.mean(axis=0),
                cmap="inferno",
                format="jpg",
            )
            # Original frame resized to the heatmap size (PIL takes width, height).
            img_in = img_in.resize((attentions.shape[2], attentions.shape[1]))
            img_in.save(
                os.path.join(reshaped_folder, "reshaped-" + os.path.basename(img_path))
            )

    def __load_model(self):
        """Build the ViT described by ``args`` and load its weights.

        Weights come from ``--pretrained_weights`` when that is an existing
        file. Otherwise the reference DINO checkpoint for ``arch`` /
        ``patch_size`` is downloaded if one exists; if not, the model keeps
        random weights. The returned model is in eval mode on ``DEVICE``
        with ``requires_grad`` disabled on all parameters.
        """
        model = vits.__dict__[self.args.arch](
            patch_size=self.args.patch_size, num_classes=0
        )
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        model.to(DEVICE)

        if os.path.isfile(self.args.pretrained_weights):
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code; only load checkpoints from trusted sources.
            state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
            key = self.args.checkpoint_key
            if key is not None and key in state_dict:
                print(f"Take key {key} in provided checkpoint dict")
                state_dict = state_dict[key]
            # remove `module.` prefix (presumably from a DataParallel wrapper)
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            # remove `backbone.` prefix induced by multicrop wrapper
            state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
            msg = model.load_state_dict(state_dict, strict=False)
            print(
                f"Pretrained weights found at {self.args.pretrained_weights} "
                f"and loaded with msg: {msg}"
            )
            return model

        print(
            "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
        )
        # Reference checkpoints, relative to https://dl.fbaipublicfiles.com/dino/
        reference_urls = {
            ("vit_small", 16): "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
            # model used for visualizations in the DINO paper
            ("vit_small", 8): "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth",
            ("vit_base", 16): "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
            ("vit_base", 8): "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
        }
        url = reference_urls.get((self.args.arch, self.args.patch_size))
        if url is not None:
            print(
                "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
            )
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/" + url
            )
            model.load_state_dict(state_dict, strict=True)
        else:
            print(
                "There is no reference weights available for this model => We use random weights."
            )
        return model
def parse_args():
    """Parse the command line of the self-attention video generator.

    Returns:
        argparse.Namespace. ``resize`` is ``None`` or a list of one or two
        ints (``nargs="+"``). ``fps`` may later be overwritten with the frame
        rate read from an input video.
    """
    # The first positional argument of ArgumentParser is `prog`, not a
    # description, so pass the text by keyword.
    parser = argparse.ArgumentParser(description="Generate a self-attention video.")
    parser.add_argument(
        "--arch",
        default="vit_small",
        type=str,
        choices=["vit_tiny", "vit_small", "vit_base"],
        help="Architecture (support only ViT atm).",
    )
    parser.add_argument(
        "--patch_size", default=8, type=int, help="Patch resolution of the model."
    )
    parser.add_argument(
        "--pretrained_weights",
        default="",
        type=str,
        help="Path to pretrained weights to load.",
    )
    parser.add_argument(
        "--checkpoint_key",
        default="teacher",
        type=str,
        help='Key to use in the checkpoint (example: "teacher")',
    )
    parser.add_argument(
        "--input_path",
        required=True,
        type=str,
        help="""Path to a video file to extract frames from, to a folder of
        frames you already extracted, or (with --video_only) to a folder of
        attention images.""",
    )
    parser.add_argument(
        "--output_path",
        default="./",
        type=str,
        help="""Where to store the folder of frames and/or the folder of
        attention images and the final video(s). Defaults to the current
        directory.""",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.6,
        help="""Fraction of the self-attention mass to keep when thresholding
        attention maps. The generated images show the mean attention heatmap,
        so this value does not currently affect the output.""",
    )
    parser.add_argument(
        "--resize",
        default=None,
        type=int,
        nargs="+",
        help="""Apply a resize transformation to input image(s). Use if OOM error.
        Usage (single or H W): --resize 512, --resize 720 1280. A single value
        is the target size of the shorter side; two values are the output
        height and width (torchvision Resize semantics).""",
    )
    parser.add_argument(
        "--video_only",
        action="store_true",
        help="""Use this flag if you only want to generate a video and not all
        attention images. If used, --input_path must be set to the folder of
        attention images. Ex: ./attention/""",
    )
    parser.add_argument(
        "--fps",
        default=30.0,
        type=float,
        help="FPS of input / output video. Automatically set if you extract frames from a video.",
    )
    parser.add_argument(
        "--video_format",
        default="mp4",
        type=str,
        choices=["mp4", "avi"],
        help="Format of generated video (mp4 or avi).",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Parse the CLI options. Constructing VideoGenerator also builds the ViT
    # and loads its weights unless --video_only is set.
    # NOTE(review): no `vg.run()` call is visible here. If the script really
    # ends at this line, nothing is processed; confirm that the call follows.
    args = parse_args()
    vg = VideoGenerator(args)