|
""" |
|
Loading a model:
|
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") |
|
model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50") |
|
|
|
Loading the video converter API:
|
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") |
|
""" |
|
|
|
|
|
# Packages required by the entrypoints in this file. torch.hub reads this
# module-level list and checks that each package is importable before it
# loads an entrypoint.
dependencies = ['torch', 'torchvision']
|
|
|
import torch |
|
from model import MattingNetwork |
|
|
|
|
|
def mobilenetv3(pretrained: bool = True, progress: bool = True):
    """Build the ``'mobilenetv3'`` variant of ``MattingNetwork``.

    Args:
        pretrained: When True, fetch the v1.0.0 release checkpoint with
            ``torch.hub.load_state_dict_from_url`` (tensors mapped to CPU)
            and load it into the network.
        progress: Forwarded to ``load_state_dict_from_url`` to toggle the
            download progress bar.

    Returns:
        The constructed ``MattingNetwork`` instance.
    """
    network = MattingNetwork('mobilenetv3')
    if not pretrained:
        # Randomly initialised network requested; skip the download.
        return network

    weights_url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth'
    state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu', progress=progress)
    network.load_state_dict(state_dict)
    return network
|
|
|
|
|
def resnet50(pretrained: bool = True, progress: bool = True):
    """Build the ``'resnet50'`` variant of ``MattingNetwork``.

    Args:
        pretrained: When True, fetch the v1.0.0 release checkpoint with
            ``torch.hub.load_state_dict_from_url`` (tensors mapped to CPU)
            and load it into the network.
        progress: Forwarded to ``load_state_dict_from_url`` to toggle the
            download progress bar.

    Returns:
        The constructed ``MattingNetwork`` instance.
    """
    matting_net = MattingNetwork('resnet50')
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth',
            map_location='cpu',
            progress=progress,
        )
        matting_net.load_state_dict(checkpoint)
    return matting_net
|
|
|
|
|
def converter():
    """Return the ``convert_video`` function from the ``inference`` module.

    The import is deferred to call time so that the model entrypoints in
    this file do not need the converter's extra packages (the install hint
    below names ``av``, ``tqdm`` and ``pims``).

    Returns:
        The ``inference.convert_video`` callable.

    Raises:
        ModuleNotFoundError: If ``inference`` or a module it imports cannot
            be found. The message includes the original error and the pip
            install hint. ``.name`` is preserved, and the original
            exception is chained as ``__cause__``.
    """
    try:
        from inference import convert_video
    except ModuleNotFoundError as error:
        # Previously the error was printed and None was returned implicitly,
        # so torch.hub.load(...) handed callers a None that failed later
        # with an unrelated "'NoneType' object is not callable". Fail here
        # instead, keeping the actionable hint in the exception message.
        raise ModuleNotFoundError(
            f'{error}. Please run "pip install av tqdm pims"',
            name=error.name,
        ) from error
    return convert_video
|
|