"""
python inference.py \
--variant mobilenetv3 \
--checkpoint "CHECKPOINT" \
--device cuda \
--input-source "input.mp4" \
--output-type video \
--output-composition "composition.mp4" \
--output-alpha "alpha.mp4" \
--output-foreground "foreground.mp4" \
--output-video-mbps 4 \
--seq-chunk 1
"""
import torch
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from typing import Optional, Tuple
from tqdm.auto import tqdm
from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter

def convert_video(model,
                  input_source: str,
                  input_resize: Optional[Tuple[int, int]] = None,
                  downsample_ratio: Optional[float] = None,
                  output_type: str = 'video',
                  output_composition: Optional[str] = None,
                  output_alpha: Optional[str] = None,
                  output_foreground: Optional[str] = None,
                  output_video_mbps: Optional[float] = None,
                  seq_chunk: int = 1,
                  num_workers: int = 0,
                  progress: bool = True,
                  device: Optional[str] = None,
                  dtype: Optional[torch.dtype] = None):
"""
Args:
input_source:A video file, or an image sequence directory. Images must be sorted in accending order, support png and jpg.
input_resize: If provided, the input are first resized to (w, h).
downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, model automatically set one.
output_type: Options: ["video", "png_sequence"].
output_composition:
The composition output path. File path if output_type == 'video'. Directory path if output_type == 'png_sequence'.
If output_type == 'video', the composition has green screen background.
If output_type == 'png_sequence'. the composition is RGBA png images.
output_alpha: The alpha output from the model.
output_foreground: The foreground output from the model.
seq_chunk: Number of frames to process at once. Increase it for better parallelism.
num_workers: PyTorch's DataLoader workers. Only use >0 for image input.
progress: Show progress bar.
device: Only need to manually provide if model is a TorchScript freezed model.
dtype: Only need to manually provide if model is a TorchScript freezed model.
"""
    assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
    assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.'
    assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.'
    assert seq_chunk >= 1, 'Sequence chunk must be >= 1'
    assert num_workers >= 0, 'Number of workers must be >= 0'

    # Initialize transform
    if input_resize is not None:
        transform = transforms.Compose([
            transforms.Resize(input_resize[::-1]),  # Resize expects (h, w); input_resize is (w, h)
            transforms.ToTensor()
        ])
    else:
        transform = transforms.ToTensor()

    # Initialize reader
    if os.path.isfile(input_source):
        source = VideoReader(input_source, transform)
    else:
        source = ImageSequenceReader(input_source, transform)
    reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)

    # Initialize writers
    if output_type == 'video':
        frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
        output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
        if output_composition is not None:
            writer_com = VideoWriter(
                path=output_composition,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
        if output_alpha is not None:
            writer_pha = VideoWriter(
                path=output_alpha,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
        if output_foreground is not None:
            writer_fgr = VideoWriter(
                path=output_foreground,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
    else:
        if output_composition is not None:
            writer_com = ImageSequenceWriter(output_composition, 'png')
        if output_alpha is not None:
            writer_pha = ImageSequenceWriter(output_alpha, 'png')
        if output_foreground is not None:
            writer_fgr = ImageSequenceWriter(output_foreground, 'png')

    # Inference
    model = model.eval()
    if device is None or dtype is None:
        param = next(model.parameters())
        dtype = param.dtype
        device = param.device

    if (output_composition is not None) and (output_type == 'video'):
        # Green-screen background color, shaped [B, T, C, H, W] for broadcasting.
        bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)

    try:
        with torch.no_grad():
            bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
            rec = [None] * 4  # Recurrent states, carried across chunks.
            for src in reader:
                if downsample_ratio is None:
                    downsample_ratio = auto_downsample_ratio(*src.shape[2:])

                src = src.to(device, dtype, non_blocking=True).unsqueeze(0)  # [B, T, C, H, W]
                fgr, pha, *rec = model(src, *rec, downsample_ratio)

                if output_foreground is not None:
                    writer_fgr.write(fgr[0])
                if output_alpha is not None:
                    writer_pha.write(pha[0])
                if output_composition is not None:
                    if output_type == 'video':
                        com = fgr * pha + bgr * (1 - pha)
                    else:
                        fgr = fgr * pha.gt(0)
                        com = torch.cat([fgr, pha], dim=-3)
                    writer_com.write(com[0])

                bar.update(src.size(1))
    finally:
        # Clean up
        if output_composition is not None:
            writer_com.close()
        if output_alpha is not None:
            writer_pha.close()
        if output_foreground is not None:
            writer_fgr.close()


def auto_downsample_ratio(h, w):
    """
    Automatically find a downsample ratio so that the largest side of the resolution is at most 512 px.
    """
    return min(512 / max(h, w), 1)
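
# For example (illustrative numbers): a 1920x1080 input gives
# min(512 / 1920, 1) ≈ 0.2667, so the model's base network runs at roughly
# 512x288, while a 320x240 input keeps ratio 1 (no downsampling).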


class Converter:
    def __init__(self, variant: str, checkpoint: str, device: str):
        # Imported here rather than at module top so that convert_video stays
        # usable without model.py (e.g., when passing a TorchScript model).
        from model import MattingNetwork
        self.model = MattingNetwork(variant).eval().to(device)
        self.model.load_state_dict(torch.load(checkpoint, map_location=device))
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)
        self.device = device

    def convert(self, *args, **kwargs):
        convert_video(self.model, *args, device=self.device, dtype=torch.float32, **kwargs)
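
# Usage sketch for the Converter class (the checkpoint path is illustrative):
#
#   converter = Converter('mobilenetv3', 'rvm_mobilenetv3.pth', 'cuda')
#   converter.convert(input_source='input.mp4',
#                     output_type='video',
#                     output_composition='composition.mp4')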


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--device', type=str, required=True)
    parser.add_argument('--input-source', type=str, required=True)
    parser.add_argument('--input-resize', type=int, default=None, nargs=2)
    parser.add_argument('--downsample-ratio', type=float)
    parser.add_argument('--output-composition', type=str)
    parser.add_argument('--output-alpha', type=str)
    parser.add_argument('--output-foreground', type=str)
    parser.add_argument('--output-type', type=str, required=True, choices=['video', 'png_sequence'])
    parser.add_argument('--output-video-mbps', type=int, default=1)
    parser.add_argument('--seq-chunk', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--disable-progress', action='store_true')
    args = parser.parse_args()

    converter = Converter(args.variant, args.checkpoint, args.device)
    converter.convert(
        input_source=args.input_source,
        input_resize=args.input_resize,
        downsample_ratio=args.downsample_ratio,
        output_type=args.output_type,
        output_composition=args.output_composition,
        output_alpha=args.output_alpha,
        output_foreground=args.output_foreground,
        output_video_mbps=args.output_video_mbps,
        seq_chunk=args.seq_chunk,
        num_workers=args.num_workers,
        progress=not args.disable_progress
    )
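
# Example invocation for png_sequence output, where the composition is written
# as RGBA pngs into a directory (paths are illustrative):
#
#   python inference.py \
#       --variant mobilenetv3 \
#       --checkpoint "CHECKPOINT" \
#       --device cuda \
#       --input-source "frames/" \
#       --output-type png_sequence \
#       --output-composition "composition/"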