#!/usr/bin/env python import getopt import math import numpy import PIL import PIL.Image import sys import torch import typing import softsplat # the custom softmax splatting layer try: from .correlation import correlation # the custom cost volume layer except: sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python # end ########################################################## torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance ########################################################## args_strModel = 'lf' args_strOne = './images/one.png' args_strTwo = './images/two.png' args_strVideo = './videos/car-turn.mp4' args_strOut = './out.png' for strOption, strArg in getopt.getopt(sys.argv[1:], '', [ 'model=', 'one=', 'two=', 'video=', 'out=', ])[0]: if strOption == '--model' and strArg != '': args_strModel = strArg # which model to use if strOption == '--one' and strArg != '': args_strOne = strArg # path to the first frame if strOption == '--two' and strArg != '': args_strTwo = strArg # path to the second frame if strOption == '--video' and strArg != '': args_strVideo = strArg # path to a video if strOption == '--out' and strArg != '': args_strOut = strArg # path to where the output should be stored # end ########################################################## def read_flo(strFile): with open(strFile, 'rb') as objFile: strFlow = objFile.read() # end assert(numpy.frombuffer(buffer=strFlow, dtype=numpy.float32, count=1, offset=0) == 202021.25) intWidth = numpy.frombuffer(buffer=strFlow, dtype=numpy.int32, count=1, offset=4)[0] intHeight = numpy.frombuffer(buffer=strFlow, dtype=numpy.int32, count=1, offset=8)[0] return numpy.frombuffer(buffer=strFlow, dtype=numpy.float32, count=intHeight * intWidth * 2, offset=12).reshape(intHeight, intWidth, 2) # end ########################################################## backwarp_tenGrid = {} def backwarp(tenIn, tenFlow): if str(tenFlow.shape) not in backwarp_tenGrid: tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() # end tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1) return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) # end ########################################################## class Flow(torch.nn.Module): def __init__(self): super().__init__() class Extractor(torch.nn.Module): def __init__(self): super().__init__() self.netFirst = torch.nn.Sequential( torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netSecond = torch.nn.Sequential( torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netThird = torch.nn.Sequential( torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFourth = torch.nn.Sequential( torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFifth = torch.nn.Sequential( torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netSixth = torch.nn.Sequential( torch.nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) # end def forward(self, tenInput): tenFirst = self.netFirst(tenInput) tenSecond = self.netSecond(tenFirst) tenThird = self.netThird(tenSecond) tenFourth = self.netFourth(tenThird) tenFifth = self.netFifth(tenFourth) tenSixth = self.netSixth(tenFifth) return [tenFirst, tenSecond, tenThird, tenFourth, tenFifth, tenSixth] # end # end class Decoder(torch.nn.Module): def __init__(self, intChannels): super().__init__() self.netMain = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intChannels, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1) ) # end def forward(self, tenOne, tenTwo, objPrevious): intWidth = tenOne.shape[3] and tenTwo.shape[3] intHeight = tenOne.shape[2] and tenTwo.shape[2] tenMain = None if objPrevious is None: tenVolume = correlation.FunctionCorrelation(tenOne=tenOne, tenTwo=tenTwo) tenMain = torch.cat([tenOne, tenVolume], 1) elif objPrevious is not None: tenForward = torch.nn.functional.interpolate(input=objPrevious['tenForward'], size=(intHeight, intWidth), mode='bilinear', align_corners=False) / float(objPrevious['tenForward'].shape[3]) * float(intWidth) tenVolume = correlation.FunctionCorrelation(tenOne=tenOne, tenTwo=backwarp(tenTwo, tenForward)) tenMain = torch.cat([tenOne, tenVolume, tenForward], 1) # end return { 'tenForward': self.netMain(tenMain) } # end # end self.netExtractor = Extractor() self.netFirst = Decoder(16 + 81 + 2) self.netSecond = Decoder(32 + 81 + 2) self.netThird = Decoder(64 + 81 + 2) self.netFourth = Decoder(96 + 81 + 2) self.netFifth = Decoder(128 + 81 + 2) self.netSixth = Decoder(192 + 81) # end def forward(self, tenOne, tenTwo): intWidth = tenOne.shape[3] and tenTwo.shape[3] intHeight = tenOne.shape[2] and tenTwo.shape[2] tenOne = self.netExtractor(tenOne) tenTwo = self.netExtractor(tenTwo) objForward = None objBackward = None objForward = self.netSixth(tenOne[-1], tenTwo[-1], objForward) objBackward = self.netSixth(tenTwo[-1], tenOne[-1], objBackward) objForward = self.netFifth(tenOne[-2], tenTwo[-2], objForward) objBackward = self.netFifth(tenTwo[-2], tenOne[-2], objBackward) objForward = self.netFourth(tenOne[-3], tenTwo[-3], objForward) objBackward = self.netFourth(tenTwo[-3], tenOne[-3], objBackward) objForward = self.netThird(tenOne[-4], tenTwo[-4], objForward) objBackward = self.netThird(tenTwo[-4], tenOne[-4], objBackward) objForward = self.netSecond(tenOne[-5], tenTwo[-5], objForward) objBackward = self.netSecond(tenTwo[-5], tenOne[-5], objBackward) objForward = self.netFirst(tenOne[-6], tenTwo[-6], objForward) objBackward = self.netFirst(tenTwo[-6], tenOne[-6], objBackward) return { 'tenForward': torch.nn.functional.interpolate(input=objForward['tenForward'], size=(intHeight, intWidth), mode='bilinear', align_corners=False) * (float(intWidth) / float(objForward['tenForward'].shape[3])), 'tenBackward': torch.nn.functional.interpolate(input=objBackward['tenForward'], size=(intHeight, intWidth), mode='bilinear', align_corners=False) * (float(intWidth) / float(objBackward['tenForward'].shape[3])) } # end # end ########################################################## class Synthesis(torch.nn.Module): def __init__(self): super().__init__() class Basic(torch.nn.Module): def __init__(self, strType, intChannels, boolSkip): super().__init__() if strType == 'relu-conv-relu-conv': self.netMain = torch.nn.Sequential( torch.nn.PReLU(num_parameters=intChannels[0], init=0.25), torch.nn.Conv2d(in_channels=intChannels[0], out_channels=intChannels[1], kernel_size=3, stride=1, padding=1, bias=False), torch.nn.PReLU(num_parameters=intChannels[1], init=0.25), torch.nn.Conv2d(in_channels=intChannels[1], out_channels=intChannels[2], kernel_size=3, stride=1, padding=1, bias=False) ) elif strType == 'conv-relu-conv': self.netMain = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intChannels[0], out_channels=intChannels[1], kernel_size=3, stride=1, padding=1, bias=False), torch.nn.PReLU(num_parameters=intChannels[1], init=0.25), torch.nn.Conv2d(in_channels=intChannels[1], out_channels=intChannels[2], kernel_size=3, stride=1, padding=1, bias=False) ) # end self.boolSkip = boolSkip if boolSkip == True: if intChannels[0] == intChannels[2]: self.netShortcut = None elif intChannels[0] != intChannels[2]: self.netShortcut = torch.nn.Conv2d(in_channels=intChannels[0], out_channels=intChannels[2], kernel_size=1, stride=1, padding=0, bias=False) # end # end # end def forward(self, tenInput): if self.boolSkip == False: return self.netMain(tenInput) # end if self.netShortcut is None: return self.netMain(tenInput) + tenInput elif self.netShortcut is not None: return self.netMain(tenInput) + self.netShortcut(tenInput) # end # end # end class Downsample(torch.nn.Module): def __init__(self, intChannels): super().__init__() self.netMain = torch.nn.Sequential( torch.nn.PReLU(num_parameters=intChannels[0], init=0.25), torch.nn.Conv2d(in_channels=intChannels[0], out_channels=intChannels[1], kernel_size=3, stride=2, padding=1, bias=False), torch.nn.PReLU(num_parameters=intChannels[1], init=0.25), torch.nn.Conv2d(in_channels=intChannels[1], out_channels=intChannels[2], kernel_size=3, stride=1, padding=1, bias=False) ) # end def forward(self, tenInput): return self.netMain(tenInput) # end # end class Upsample(torch.nn.Module): def __init__(self, intChannels): super().__init__() self.netMain = torch.nn.Sequential( torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), torch.nn.PReLU(num_parameters=intChannels[0], init=0.25), torch.nn.Conv2d(in_channels=intChannels[0], out_channels=intChannels[1], kernel_size=3, stride=1, padding=1, bias=False), torch.nn.PReLU(num_parameters=intChannels[1], init=0.25), torch.nn.Conv2d(in_channels=intChannels[1], out_channels=intChannels[2], kernel_size=3, stride=1, padding=1, bias=False) ) # end def forward(self, tenInput): return self.netMain(tenInput) # end # end class Encode(torch.nn.Module): def __init__(self): super().__init__() self.netOne = torch.nn.Sequential( torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False), torch.nn.PReLU(num_parameters=32, init=0.25), torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False), torch.nn.PReLU(num_parameters=32, init=0.25) ) self.netTwo = torch.nn.Sequential( torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False), torch.nn.PReLU(num_parameters=64, init=0.25), torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False), torch.nn.PReLU(num_parameters=64, init=0.25) ) self.netThr = torch.nn.Sequential( torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1, bias=False), torch.nn.PReLU(num_parameters=96, init=0.25), torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1, bias=False), torch.nn.PReLU(num_parameters=96, init=0.25) ) # end def forward(self, tenInput): tenOutput = [] tenOutput.append(self.netOne(tenInput)) tenOutput.append(self.netTwo(tenOutput[-1])) tenOutput.append(self.netThr(tenOutput[-1])) return [torch.cat([tenInput, tenOutput[0]], 1)] + tenOutput[1:] # end # end class Softmetric(torch.nn.Module): def __init__(self): super().__init__() self.netInput = torch.nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1, bias=False) self.netError = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1, bias=False) for intRow, intFeatures in [(0, 16), (1, 32), (2, 64), (3, 96)]: self.add_module(str(intRow) + 'x0' + ' - ' + str(intRow) + 'x1', Basic('relu-conv-relu-conv', [intFeatures, intFeatures, intFeatures], True)) # end for intCol in [0]: self.add_module('0x' + str(intCol) + ' - ' + '1x' + str(intCol), Downsample([16, 32, 32])) self.add_module('1x' + str(intCol) + ' - ' + '2x' + str(intCol), Downsample([32, 64, 64])) self.add_module('2x' + str(intCol) + ' - ' + '3x' + str(intCol), Downsample([64, 96, 96])) # end for intCol in [1]: self.add_module('3x' + str(intCol) + ' - ' + '2x' + str(intCol), Upsample([96, 64, 64])) self.add_module('2x' + str(intCol) + ' - ' + '1x' + str(intCol), Upsample([64, 32, 32])) self.add_module('1x' + str(intCol) + ' - ' + '0x' + str(intCol), Upsample([32, 16, 16])) # end self.netOutput = Basic('conv-relu-conv', [16, 16, 1], True) # end def forward(self, tenEncone, tenEnctwo, tenFlow): tenColumn = [None, None, None, None] tenColumn[0] = torch.cat([self.netInput(tenEncone[0][:, 0:3, :, :]), self.netError(torch.nn.functional.l1_loss(input=tenEncone[0], target=backwarp(tenEnctwo[0], tenFlow), reduction='none').mean([1], True))], 1) tenColumn[1] = self._modules['0x0 - 1x0'](tenColumn[0]) tenColumn[2] = self._modules['1x0 - 2x0'](tenColumn[1]) tenColumn[3] = self._modules['2x0 - 3x0'](tenColumn[2]) intColumn = 1 for intRow in range(len(tenColumn) -1, -1, -1): tenColumn[intRow] = self._modules[str(intRow) + 'x' + str(intColumn - 1) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow]) if intRow != len(tenColumn) - 1: tenUp = self._modules[str(intRow + 1) + 'x' + str(intColumn) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow + 1]) if tenUp.shape[2] != tenColumn[intRow].shape[2]: tenUp = torch.nn.functional.pad(input=tenUp, pad=[0, 0, 0, -1], mode='constant', value=0.0) if tenUp.shape[3] != tenColumn[intRow].shape[3]: tenUp = torch.nn.functional.pad(input=tenUp, pad=[0, -1, 0, 0], mode='constant', value=0.0) tenColumn[intRow] = tenColumn[intRow] + tenUp # end # end return self.netOutput(tenColumn[0]) # end # end class Warp(torch.nn.Module): def __init__(self): super().__init__() self.netOne = Basic('conv-relu-conv', [3 + 3 + 32 + 32 + 1 + 1, 32, 32], True) self.netTwo = Basic('conv-relu-conv', [0 + 0 + 64 + 64 + 1 + 1, 64, 64], True) self.netThr = Basic('conv-relu-conv', [0 + 0 + 96 + 96 + 1 + 1, 96, 96], True) # end def forward(self, tenEncone, tenEnctwo, tenMetricone, tenMetrictwo, tenForward, tenBackward): tenOutput = [] for intLevel in range(3): if intLevel != 0: tenMetricone = torch.nn.functional.interpolate(input=tenMetricone, size=(tenEncone[intLevel].shape[2], tenEncone[intLevel].shape[3]), mode='bilinear', align_corners=False) tenMetrictwo = torch.nn.functional.interpolate(input=tenMetrictwo, size=(tenEnctwo[intLevel].shape[2], tenEnctwo[intLevel].shape[3]), mode='bilinear', align_corners=False) tenForward = torch.nn.functional.interpolate(input=tenForward, size=(tenEncone[intLevel].shape[2], tenEncone[intLevel].shape[3]), mode='bilinear', align_corners=False) * (float(tenEncone[intLevel].shape[3]) / float(tenForward.shape[3])) tenBackward = torch.nn.functional.interpolate(input=tenBackward, size=(tenEnctwo[intLevel].shape[2], tenEnctwo[intLevel].shape[3]), mode='bilinear', align_corners=False) * (float(tenEnctwo[intLevel].shape[3]) / float(tenBackward.shape[3])) # end tenOutput.append([self.netOne, self.netTwo, self.netThr][intLevel](torch.cat([ softsplat.softsplat(tenIn=torch.cat([tenEncone[intLevel], tenMetricone], 1), tenFlow=tenForward, tenMetric=tenMetricone.neg().clip(-20.0, 20.0), strMode='soft'), softsplat.softsplat(tenIn=torch.cat([tenEnctwo[intLevel], tenMetrictwo], 1), tenFlow=tenBackward, tenMetric=tenMetrictwo.neg().clip(-20.0, 20.0), strMode='soft') ], 1))) # end return tenOutput # end # end self.netEncode = Encode() self.netSoftmetric = Softmetric() self.netWarp = Warp() for intRow, intFeatures in [(0, 32), (1, 64), (2, 96)]: self.add_module(str(intRow) + 'x0' + ' - ' + str(intRow) + 'x1', Basic('relu-conv-relu-conv', [intFeatures, intFeatures, intFeatures], True)) self.add_module(str(intRow) + 'x1' + ' - ' + str(intRow) + 'x2', Basic('relu-conv-relu-conv', [intFeatures, intFeatures, intFeatures], True)) self.add_module(str(intRow) + 'x2' + ' - ' + str(intRow) + 'x3', Basic('relu-conv-relu-conv', [intFeatures, intFeatures, intFeatures], True)) self.add_module(str(intRow) + 'x3' + ' - ' + str(intRow) + 'x4', Basic('relu-conv-relu-conv', [intFeatures, intFeatures, intFeatures], True)) self.add_module(str(intRow) + 'x4' + ' - ' + str(intRow) + 'x5', Basic('relu-conv-relu-conv', [intFeatures, intFeatures, intFeatures], True)) # end for intCol in [0, 1, 2]: self.add_module('0x' + str(intCol) + ' - ' + '1x' + str(intCol), Downsample([32, 64, 64])) self.add_module('1x' + str(intCol) + ' - ' + '2x' + str(intCol), Downsample([64, 96, 96])) # end for intCol in [3, 4, 5]: self.add_module('2x' + str(intCol) + ' - ' + '1x' + str(intCol), Upsample([96, 64, 64])) self.add_module('1x' + str(intCol) + ' - ' + '0x' + str(intCol), Upsample([64, 32, 32])) # end self.netOutput = Basic('conv-relu-conv', [32, 32, 3], True) # end def forward(self, tenOne, tenTwo, tenForward, tenBackward, fltTime): tenEncone = self.netEncode(tenOne) tenEnctwo = self.netEncode(tenTwo) tenMetricone = self.netSoftmetric(tenEncone, tenEnctwo, tenForward) * 2.0 * fltTime tenMetrictwo = self.netSoftmetric(tenEnctwo, tenEncone, tenBackward) * 2.0 * (1.0 - fltTime) tenForward = tenForward * fltTime tenBackward = tenBackward * (1.0 - fltTime) tenWarp = self.netWarp(tenEncone, tenEnctwo, tenMetricone, tenMetrictwo, tenForward, tenBackward) tenColumn = [None, None, None] tenColumn[0] = tenWarp[0] tenColumn[1] = tenWarp[1] + self._modules['0x0 - 1x0'](tenColumn[0]) tenColumn[2] = tenWarp[2] + self._modules['1x0 - 2x0'](tenColumn[1]) intColumn = 1 for intRow in range(len(tenColumn)): tenColumn[intRow] = self._modules[str(intRow) + 'x' + str(intColumn - 1) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow]) if intRow != 0: tenColumn[intRow] = tenColumn[intRow] + self._modules[str(intRow - 1) + 'x' + str(intColumn) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow - 1]) # end # end intColumn = 2 for intRow in range(len(tenColumn)): tenColumn[intRow] = self._modules[str(intRow) + 'x' + str(intColumn - 1) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow]) if intRow != 0: tenColumn[intRow] = tenColumn[intRow] + self._modules[str(intRow - 1) + 'x' + str(intColumn) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow - 1]) # end # end intColumn = 3 for intRow in range(len(tenColumn) -1, -1, -1): tenColumn[intRow] = self._modules[str(intRow) + 'x' + str(intColumn - 1) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow]) if intRow != len(tenColumn) - 1: tenUp = self._modules[str(intRow + 1) + 'x' + str(intColumn) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow + 1]) if tenUp.shape[2] != tenColumn[intRow].shape[2]: tenUp = torch.nn.functional.pad(input=tenUp, pad=[0, 0, 0, -1], mode='constant', value=0.0) if tenUp.shape[3] != tenColumn[intRow].shape[3]: tenUp = torch.nn.functional.pad(input=tenUp, pad=[0, -1, 0, 0], mode='constant', value=0.0) tenColumn[intRow] = tenColumn[intRow] + tenUp # end # end intColumn = 4 for intRow in range(len(tenColumn) -1, -1, -1): tenColumn[intRow] = self._modules[str(intRow) + 'x' + str(intColumn - 1) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow]) if intRow != len(tenColumn) - 1: tenUp = self._modules[str(intRow + 1) + 'x' + str(intColumn) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow + 1]) if tenUp.shape[2] != tenColumn[intRow].shape[2]: tenUp = torch.nn.functional.pad(input=tenUp, pad=[0, 0, 0, -1], mode='constant', value=0.0) if tenUp.shape[3] != tenColumn[intRow].shape[3]: tenUp = torch.nn.functional.pad(input=tenUp, pad=[0, -1, 0, 0], mode='constant', value=0.0) tenColumn[intRow] = tenColumn[intRow] + tenUp # end # end intColumn = 5 for intRow in range(len(tenColumn) -1, -1, -1): tenColumn[intRow] = self._modules[str(intRow) + 'x' + str(intColumn - 1) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow]) if intRow != len(tenColumn) - 1: tenUp = self._modules[str(intRow + 1) + 'x' + str(intColumn) + ' - ' + str(intRow) + 'x' + str(intColumn)](tenColumn[intRow + 1]) if tenUp.shape[2] != tenColumn[intRow].shape[2]: tenUp = torch.nn.functional.pad(input=tenUp, pad=[0, 0, 0, -1], mode='constant', value=0.0) if tenUp.shape[3] != tenColumn[intRow].shape[3]: tenUp = torch.nn.functional.pad(input=tenUp, pad=[0, -1, 0, 0], mode='constant', value=0.0) tenColumn[intRow] = tenColumn[intRow] + tenUp # end # end return self.netOutput(tenColumn[0]) # end # end ########################################################## class Network(torch.nn.Module): def __init__(self): super().__init__() self.netFlow = Flow() self.netSynthesis = Synthesis() self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/softsplat/network-' + args_strModel + '.pytorch', file_name='softsplat-' + args_strModel).items()}) # end def forward(self, tenOne, tenTwo, fltTimes): with torch.set_grad_enabled(False): tenStats = [tenOne, tenTwo] tenMean = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) tenStd = (sum([tenIn.std([1, 2, 3], False, True).square() + (tenMean - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() tenOne = ((tenOne - tenMean) / (tenStd + 0.0000001)).detach() tenTwo = ((tenTwo - tenMean) / (tenStd + 0.0000001)).detach() # end objFlow = self.netFlow(tenOne, tenTwo) tenImages = [self.netSynthesis(tenOne, tenTwo, objFlow['tenForward'], objFlow['tenBackward'], fltTime) for fltTime in fltTimes] return [(tenImage * tenStd) + tenMean for tenImage in tenImages] # end # end netNetwork = None ########################################################## def estimate(tenOne, tenTwo, fltTimes): global netNetwork if netNetwork is None: netNetwork = Network().cuda().eval() # end assert(tenOne.shape[1] == tenTwo.shape[1]) assert(tenOne.shape[2] == tenTwo.shape[2]) intWidth = tenOne.shape[2] intHeight = tenOne.shape[1] tenPreprocessedOne = tenOne.cuda().view(1, 3, intHeight, intWidth) tenPreprocessedTwo = tenTwo.cuda().view(1, 3, intHeight, intWidth) intPadr = (2 - (intWidth % 2)) % 2 intPadb = (2 - (intHeight % 2)) % 2 tenPreprocessedOne = torch.nn.functional.pad(input=tenPreprocessedOne, pad=[0, intPadr, 0, intPadb], mode='replicate') tenPreprocessedTwo = torch.nn.functional.pad(input=tenPreprocessedTwo, pad=[0, intPadr, 0, intPadb], mode='replicate') return [tenImage[0, :, :intHeight, :intWidth].cpu() for tenImage in netNetwork(tenPreprocessedOne, tenPreprocessedTwo, fltTimes)] # end ########################################################## if __name__ == '__main__': if args_strOut.split('.')[-1] in ['bmp', 'jpg', 'jpeg', 'png']: tenOne = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(args_strOne))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) tenTwo = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(args_strTwo))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) tenOutput = estimate(tenOne, tenTwo, [0.5])[0] PIL.Image.fromarray((tenOutput.clip(0.0, 1.0).numpy().transpose(1, 2, 0)[:, :, ::-1] * 255.0).astype(numpy.uint8)).save(args_strOut) elif args_strOut.split('.')[-1] in ['avi', 'mp4', 'webm', 'wmv']: import moviepy import moviepy.editor import moviepy.video.io.ffmpeg_writer objVideoreader = moviepy.editor.VideoFileClip(filename=args_strVideo) intWidth = objVideoreader.w intHeight = objVideoreader.h tenFrames = [None, None, None, None, None] with moviepy.video.io.ffmpeg_writer.FFMPEG_VideoWriter(filename=args_strOut, size=(intWidth, intHeight), fps=objVideoreader.fps) as objVideowriter: for npyFrame in objVideoreader.iter_frames(): tenFrames[4] = torch.FloatTensor(numpy.ascontiguousarray(npyFrame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) if tenFrames[0] is not None: tenFrames[1:4] = estimate(tenFrames[0], tenFrames[4], [0.25, 0.5, 0.75]) objVideowriter.write_frame((tenFrames[0].clip(0.0, 1.0).numpy().transpose(1, 2, 0)[:, :, ::-1] * 255.0).astype(numpy.uint8)) objVideowriter.write_frame((tenFrames[1].clip(0.0, 1.0).numpy().transpose(1, 2, 0)[:, :, ::-1] * 255.0).astype(numpy.uint8)) objVideowriter.write_frame((tenFrames[2].clip(0.0, 1.0).numpy().transpose(1, 2, 0)[:, :, ::-1] * 255.0).astype(numpy.uint8)) objVideowriter.write_frame((tenFrames[3].clip(0.0, 1.0).numpy().transpose(1, 2, 0)[:, :, ::-1] * 255.0).astype(numpy.uint8)) # end tenFrames[0] = torch.FloatTensor(numpy.ascontiguousarray(npyFrame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) # end # end # end # end