import warnings

import numpy as np
import tensorflow as tf
import torch

from interpolator import Interpolator


def translate_state_dict(var_dict, state_dict):
    """Copy TF weights into a PyTorch state dict, relying on matching iteration order."""
    for name, (prev_name, weight) in zip(state_dict, var_dict.items()):
        print('Mapping', prev_name, '->', name)
        weight = torch.from_numpy(weight)
        if 'kernel' in prev_name:
            # Transpose the conv2d kernel weights, since TF uses (H, W, C, K) and PyTorch uses (K, C, H, W)
            weight = weight.permute(3, 2, 0, 1)
        assert state_dict[name].shape == weight.shape, f'Shape mismatch {state_dict[name].shape} != {weight.shape}'
        state_dict[name] = weight


def import_state_dict(interpolator: Interpolator, saved_model):
    variables = saved_model.keras_api.variables
    extract_dict = interpolator.extract.state_dict()
    flow_dict = interpolator.predict_flow.state_dict()
    fuse_dict = interpolator.fuse.state_dict()

    extract_vars = {}
    _flow_vars = {}
    _fuse_vars = {}
    # Group the TF variables by submodule, stripping the 'feat_net/',
    # 'predict_flow/' and 'fusion/' name prefixes
    for var in variables:
        name = var.name
        if name.startswith('feat_net'):
            extract_vars[name[9:]] = var.numpy()
        elif name.startswith('predict_flow'):
            _flow_vars[name[13:]] = var.numpy()
        elif name.startswith('fusion'):
            _fuse_vars[name[7:]] = var.numpy()

    # Reverse the order of the modules to allow JIT export
    # TODO: improve this hack
    flow_vars = dict(sorted(_flow_vars.items(), key=lambda x: x[0].split('/')[0], reverse=True))
    fuse_vars = dict(sorted(_fuse_vars.items(),
                            key=lambda x: int((x[0].split('/')[0].split('_')[1:] or [0])[0]) // 3,
                            reverse=True))

    assert len(extract_vars) == len(extract_dict), f'{len(extract_vars)} != {len(extract_dict)}'
    assert len(flow_vars) == len(flow_dict), f'{len(flow_vars)} != {len(flow_dict)}'
    assert len(fuse_vars) == len(fuse_dict), f'{len(fuse_vars)} != {len(fuse_dict)}'

    for state_dict, var_dict in ((extract_dict, extract_vars), (flow_dict, flow_vars), (fuse_dict, fuse_vars)):
        translate_state_dict(var_dict, state_dict)

    interpolator.extract.load_state_dict(extract_dict)
    interpolator.predict_flow.load_state_dict(flow_dict)
    interpolator.fuse.load_state_dict(fuse_dict)


def verify_debug_outputs(pt_outputs, tf_outputs):
    """Compare intermediate activations of the PyTorch model against the TF reference."""
    max_error = 0
    for name, predicted in pt_outputs.items():
        if name == 'image':
            continue
        # Convert PyTorch NCHW tensors to NHWC numpy arrays to match TF
        pred_frfp = [f.permute(0, 2, 3, 1).detach().cpu().numpy() for f in predicted]
        true_frfp = [f.numpy() for f in tf_outputs[name]]

        for i, (pred, true) in enumerate(zip(pred_frfp, true_frfp)):
            assert pred.shape == true.shape, f'{name} {i} shape mismatch {pred.shape} != {true.shape}'
            error = np.max(np.abs(pred - true))
            max_error = max(max_error, error)
            # Loose tolerance: intermediate activations are not normalized
            assert error < 1, f'{name} {i} max error: {error}'
    print('Max intermediate error:', max_error)


def test_model(interpolator, model, half=False, gpu=False):
    torch.manual_seed(0)
    time = torch.full((1, 1), .5)
    x0 = torch.rand(1, 3, 256, 256)
    x1 = torch.rand(1, 3, 256, 256)

    # TF expects NHWC inputs; run the reference model in FP32 on CPU
    x0_ = tf.convert_to_tensor(x0.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
    x1_ = tf.convert_to_tensor(x1.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
    time_ = tf.convert_to_tensor(time.numpy(), dtype=tf.float32)
    tf_outputs = model({'x0': x0_, 'x1': x1_, 'time': time_}, training=False)

    if half:
        x0 = x0.half()
        x1 = x1.half()
        time = time.half()
    if gpu and torch.cuda.is_available():
        x0 = x0.cuda()
        x1 = x1.cuda()
        time = time.cuda()

    with torch.no_grad():
        pt_outputs = interpolator.debug_forward(x0, x1, time)
    verify_debug_outputs(pt_outputs, tf_outputs)

    with torch.no_grad():
        prediction = interpolator(x0, x1, time)
    output_color = prediction.permute(0, 2, 3, 1).detach().cpu().numpy()
    true_color = tf_outputs['image'].numpy()
    error = np.abs(output_color - true_color).max()
    print('Color max error:', error)
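
# A minimal usage sketch for the exported file, assuming the default
# TorchScript export at FP32 on CPU. 'film_net.pt' and the 256x256 inputs
# are illustrative placeholders, not requirements of the exporter.
def example_inference(path='film_net.pt'):
    interpolator = torch.jit.load(path).eval()
    x0 = torch.rand(1, 3, 256, 256)  # first frame, NCHW in [0, 1)
    x1 = torch.rand(1, 3, 256, 256)  # second frame
    time = torch.full((1, 1), .5)  # interpolate halfway between the frames
    with torch.no_grad():
        return interpolator(x0, x1, time)  # predicted in-between frame, NCHW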
def main(model_path, save_path, export_to_torchscript=True, use_gpu=False, fp16=True, skiptest=False):
    precision = 'FP16' if fp16 else 'FP32'
    export_format = 'torchscript' if export_to_torchscript else 'state_dict'
    print(f'Exporting model to {precision} {export_format} using {"GPU" if use_gpu else "CPU"}')

    model = tf.compat.v2.saved_model.load(model_path)
    interpolator = Interpolator()
    interpolator.eval()
    import_state_dict(interpolator, model)

    if use_gpu and torch.cuda.is_available():
        interpolator = interpolator.cuda()
    else:
        use_gpu = False
    if fp16:
        interpolator = interpolator.half()

    if export_to_torchscript:
        interpolator = torch.jit.script(interpolator)
        interpolator.save(save_path)
    else:
        torch.save(interpolator.state_dict(), save_path)

    if not skiptest:
        if not use_gpu and fp16:
            warnings.warn('Cannot test an FP16 model on CPU, casting it back to FP32')
            interpolator = interpolator.float()
            fp16 = False
        test_model(interpolator, model, fp16, use_gpu)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Export a frame-interpolator TF SavedModel to PyTorch')
    parser.add_argument('model_path', type=str, help='Path to the TF SavedModel')
    parser.add_argument('save_path', type=str, help='Path to save the exported PyTorch model')
    parser.add_argument('--statedict', action='store_true', help='Export a state dict instead of TorchScript')
    parser.add_argument('--fp32', action='store_true', help='Save at full precision')
    parser.add_argument('--skiptest', action='store_true', help='Save the model without comparing it against the TF reference')
    parser.add_argument('--gpu', action='store_true', help='Use GPU')
    args = parser.parse_args()
    main(args.model_path, args.save_path, not args.statedict, args.gpu, not args.fp32, args.skiptest)
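
# Example invocation (a sketch; the script name and paths are placeholders):
#
#   python export.py /path/to/film/saved_model film_net_fp16.pt --gpu
#
# By default this writes an FP16 TorchScript module; pass --statedict for a
# plain state dict and --fp32 to keep full precision.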