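"""Export the TF frame-interpolation SavedModel to PyTorch.

Loads the TF SavedModel, copies its weights into the PyTorch Interpolator,
optionally casts to FP16 and compiles with TorchScript, saves the result,
and checks the PyTorch outputs against the original TF model.
"""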
import warnings
import numpy as np
import tensorflow as tf
import torch
from interpolator import Interpolator

def translate_state_dict(var_dict, state_dict):
    """Copy TF weights from var_dict into state_dict, relying on matching iteration order."""
    for name, (prev_name, weight) in zip(state_dict, var_dict.items()):
        print('Mapping', prev_name, '->', name)
        weight = torch.from_numpy(weight)
        if 'kernel' in prev_name:
            # Transpose conv2d kernel weights: TF stores (H, W, C, K), PyTorch expects (K, C, H, W)
            weight = weight.permute(3, 2, 0, 1)
        assert state_dict[name].shape == weight.shape, f'Shape mismatch {state_dict[name].shape} != {weight.shape}'
        state_dict[name] = weight
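
# A minimal sketch of the kernel transpose above; the 3x3, 16-in/32-out shape
# is hypothetical, purely to illustrate the layout conversion:
#
#     tf_kernel = np.zeros((3, 3, 16, 32), dtype=np.float32)   # TF: (H, W, C, K)
#     pt_kernel = torch.from_numpy(tf_kernel).permute(3, 2, 0, 1)
#     assert tuple(pt_kernel.shape) == (32, 16, 3, 3)          # PyTorch: (K, C, H, W)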

def import_state_dict(interpolator: Interpolator, saved_model):
    """Copy every variable of the TF SavedModel into the matching PyTorch submodule."""
    variables = saved_model.keras_api.variables

    extract_dict = interpolator.extract.state_dict()
    flow_dict = interpolator.predict_flow.state_dict()
    fuse_dict = interpolator.fuse.state_dict()

    # Bucket the TF variables by submodule, stripping the 'feat_net/',
    # 'predict_flow/' and 'fusion/' name prefixes
    extract_vars = {}
    _flow_vars = {}
    _fuse_vars = {}
    for var in variables:
        name = var.name
        if name.startswith('feat_net'):
            extract_vars[name[9:]] = var.numpy()
        elif name.startswith('predict_flow'):
            _flow_vars[name[13:]] = var.numpy()
        elif name.startswith('fusion'):
            _fuse_vars[name[7:]] = var.numpy()

    # Reverse the order of modules to allow JIT export
    # TODO: improve this hack
    flow_vars = dict(sorted(_flow_vars.items(), key=lambda x: x[0].split('/')[0], reverse=True))
    fuse_vars = dict(sorted(_fuse_vars.items(), key=lambda x: int((x[0].split('/')[0].split('_')[1:] or [0])[0]) // 3, reverse=True))

    assert len(extract_vars) == len(extract_dict), f'{len(extract_vars)} != {len(extract_dict)}'
    assert len(flow_vars) == len(flow_dict), f'{len(flow_vars)} != {len(flow_dict)}'
    assert len(fuse_vars) == len(fuse_dict), f'{len(fuse_vars)} != {len(fuse_dict)}'

    for state_dict, var_dict in ((extract_dict, extract_vars), (flow_dict, flow_vars), (fuse_dict, fuse_vars)):
        translate_state_dict(var_dict, state_dict)

    interpolator.extract.load_state_dict(extract_dict)
    interpolator.predict_flow.load_state_dict(flow_dict)
    interpolator.fuse.load_state_dict(fuse_dict)
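
# A sketch of the fuse_vars sort key above, on hypothetical variable names:
# 'conv_7/kernel' -> split('/')[0].split('_')[1:] gives ['7'], so the key is
# int('7') // 3 == 2; a bare 'conv/kernel' has no '_<n>' suffix, falls back
# to the [0] default, and gets key 0.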

def verify_debug_outputs(pt_outputs, tf_outputs):
    """Compare the intermediate PyTorch outputs against the TF reference outputs."""
    max_error = 0
    for name, predicted in pt_outputs.items():
        if name == 'image':
            continue
        # Convert PyTorch NCHW tensors to NHWC numpy arrays to match the TF layout
        pred_frfp = [f.permute(0, 2, 3, 1).detach().cpu().numpy() for f in predicted]
        true_frfp = [f.numpy() for f in tf_outputs[name]]
        for i, (pred, true) in enumerate(zip(pred_frfp, true_frfp)):
            assert pred.shape == true.shape, f'{name} {i} shape mismatch {pred.shape} != {true.shape}'
            error = np.max(np.abs(pred - true))
            max_error = max(max_error, error)
            assert error < 1, f'{name} {i} max error: {error}'
    print('Max intermediate error:', max_error)
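
# A minimal sketch of the NCHW -> NHWC conversion used above, with a
# hypothetical 1x3x4x4 tensor:
#
#     x = torch.rand(1, 3, 4, 4)              # PyTorch layout: (N, C, H, W)
#     x_nhwc = x.permute(0, 2, 3, 1).numpy()  # TF layout: (N, H, W, C)
#     assert x_nhwc.shape == (1, 4, 4, 3)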

def test_model(interpolator, model, half=False, gpu=False):
    """Feed identical random inputs to both models and report the maximum output error."""
    torch.manual_seed(0)
    time = torch.full((1, 1), .5)
    x0 = torch.rand(1, 3, 256, 256)
    x1 = torch.rand(1, 3, 256, 256)

    # The TF model expects NHWC inputs
    x0_ = tf.convert_to_tensor(x0.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
    x1_ = tf.convert_to_tensor(x1.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
    time_ = tf.convert_to_tensor(time.numpy(), dtype=tf.float32)
    tf_outputs = model({'x0': x0_, 'x1': x1_, 'time': time_}, training=False)

    if half:
        x0 = x0.half()
        x1 = x1.half()
        time = time.half()
    if gpu and torch.cuda.is_available():
        x0 = x0.cuda()
        x1 = x1.cuda()
        time = time.cuda()

    with torch.no_grad():
        pt_outputs = interpolator.debug_forward(x0, x1, time)
    verify_debug_outputs(pt_outputs, tf_outputs)

    with torch.no_grad():
        prediction = interpolator(x0, x1, time)
    output_color = prediction.permute(0, 2, 3, 1).detach().cpu().numpy()
    true_color = tf_outputs['image'].numpy()
    error = np.abs(output_color - true_color).max()
    print('Color max error:', error)

def main(model_path, save_path, export_to_torchscript=True, use_gpu=False, fp16=True, skiptest=False):
    print(f'Exporting model to FP{"16" if fp16 else "32"} '
          f'{"torchscript" if export_to_torchscript else "state_dict"} '
          f'using {"GPU" if use_gpu else "CPU"}')
    model = tf.compat.v2.saved_model.load(model_path)
    interpolator = Interpolator()
    interpolator.eval()
    import_state_dict(interpolator, model)

    if use_gpu and torch.cuda.is_available():
        interpolator = interpolator.cuda()
    else:
        use_gpu = False
    if fp16:
        interpolator = interpolator.half()

    if export_to_torchscript:
        interpolator = torch.jit.script(interpolator)
        interpolator.save(save_path)
    else:
        torch.save(interpolator.state_dict(), save_path)

    if not skiptest:
        if not use_gpu and fp16:
            warnings.warn('Testing an FP16 model on CPU is not supported; casting it back to FP32')
            interpolator = interpolator.float()
            fp16 = False
        test_model(interpolator, model, fp16, use_gpu)
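
# A hedged sketch of loading the exported artifact later (file names assumed):
#
#     loaded = torch.jit.load('film_net_fp16.pt')  # TorchScript export
#
#     # or, for a --statedict export:
#     interp = Interpolator()
#     interp.load_state_dict(torch.load('film_net_fp32.pth'))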

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Export the frame-interpolation model from a TF SavedModel to PyTorch')
    parser.add_argument('model_path', type=str, help='Path to the TF SavedModel')
    parser.add_argument('save_path', type=str, help='Path to save the exported PyTorch model')
    parser.add_argument('--statedict', action='store_true', help='Export a state dict instead of TorchScript')
    parser.add_argument('--fp32', action='store_true', help='Save at full (FP32) precision')
    parser.add_argument('--skiptest', action='store_true', help='Skip the TF/PyTorch consistency test after saving')
    parser.add_argument('--gpu', action='store_true', help='Use GPU')
    args = parser.parse_args()

    main(args.model_path, args.save_path, not args.statedict, args.gpu, not args.fp32, args.skiptest)
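
# Example invocations (the script name and paths are placeholders):
#
#     python export.py /path/to/tf_saved_model film_net_fp16.pt
#     python export.py /path/to/tf_saved_model film_net_fp32.pth --statedict --fp32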