"""A wrapper class for running frame interpolation with the FILM model.

The saved model is downloaded from the Hugging Face Hub repository named by
``FILM_REPO_ID`` and loaded with ``tf.saved_model.load``.

Usage:
  interpolator = Interpolator()
  result_batch = interpolator(image_batch_0, image_batch_1, batch_dt)
  Where image_batch_0 and image_batch_1 are numpy tensors with TF standard
  (B,H,W,C) layout, batch_dt is the sub-frame time in range [0..1], (B,) layout.
"""
from typing import Generator, Iterable, List, Optional

import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download

# Hugging Face Hub repository that hosts the FILM saved model.
FILM_REPO_ID = "leonelhs/film"


def _pad_to_align(x, align):
    """Pads image batch x so width and height divide by align.

    Args:
      x: Image batch to align. Must be rank 4; height and width are read from
        the second and third axes.
      align: Number to align to. Must be positive.

    Returns:
      1) An image padded so width % align == 0 and height % align == 0.
      2) A bounding box that can be fed readily to
         tf.image.crop_to_bounding_box to undo the padding.
    """
    # Input checking.
    assert np.ndim(x) == 4
    assert align > 0, 'align must be a positive number.'

    height, width = x.shape[-3:-1]
    height_to_pad = (align - height % align) if height % align != 0 else 0
    width_to_pad = (align - width % align) if width % align != 0 else 0

    # The original content is offset by half of the padding on each axis; the
    # crop box below uses the same offsets so cropping exactly undoes padding.
    offset_height = height_to_pad // 2
    offset_width = width_to_pad // 2

    bbox_to_pad = {
        'offset_height': offset_height,
        'offset_width': offset_width,
        'target_height': height + height_to_pad,
        'target_width': width + width_to_pad,
    }
    padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
    bbox_to_crop = {
        'offset_height': offset_height,
        'offset_width': offset_width,
        'target_height': height,
        'target_width': width,
    }
    return padded_x, bbox_to_crop


class Interpolator:
    """A class for generating interpolated frames between two input frames.

    Uses the FILM saved model downloaded from the Hugging Face Hub repository
    ``FILM_REPO_ID``.
    """

    def __init__(self, times_to_interpolate: int = 6,
                 align: Optional[int] = 64) -> None:
        """Downloads and loads the saved model.

        Args:
          times_to_interpolate: Recursion depth stored on the instance and read
            by interpolate_recursively().
          align: Inputs to __call__ are padded so that height and width divide
            by this value before inference. Pass None to disable padding.
        """
        self.times_to_interpolate = times_to_interpolate
        model_path = snapshot_download(FILM_REPO_ID)
        self._model = tf.saved_model.load(model_path)
        # Alternative source (not used here):
        #   hub.load("https://tfhub.dev/google/film/1")
        self._align = align

    def __call__(self, x0: np.ndarray, x1: np.ndarray,
                 dt: np.ndarray) -> np.ndarray:
        """Generates an interpolated frame between given two batches of frames.

        All inputs should be np.float32 datatype.

        Args:
          x0: First image batch. Dimensions: (batch_size, height, width, channels)
          x1: Second image batch. Dimensions: (batch_size, height, width, channels)
          dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)

        Returns:
          The result with dimensions (batch_size, height, width, channels).
        """
        if self._align is not None:
            x0, bbox_to_crop = _pad_to_align(x0, self._align)
            # Both inputs are padded the same way, so the crop box computed for
            # x0 is reused for the output.
            x1, _ = _pad_to_align(x1, self._align)

        # The model receives the time with a trailing axis appended.
        inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
        result = self._model(inputs, training=False)
        image = result['image']

        if self._align is not None:
            # Undo the padding applied above.
            image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
        return image.numpy()

    def preview_frames(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        """Returns the first two frames with their midpoint frame in between.

        The model is called directly on the unpadded frames; unlike __call__,
        no alignment padding is applied here.

        Args:
          frames: Sequence whose first two entries are the frames to
            interpolate between (each without a batch dimension).

        Returns:
          A list [frames[0], mid_frame, frames[1]].
        """
        time = np.array([0.5], dtype=np.float32)
        media_input = {
            # Adding the batch dimension to the time.
            'time': np.expand_dims(time, axis=0),
            # Adding the batch dimension to the images.
            'x0': np.expand_dims(frames[0], axis=0),
            'x1': np.expand_dims(frames[1], axis=0),
        }
        mid = self._model(media_input)
        return [frames[0], mid['image'][0].numpy(), frames[1]]


def _recursive_generator(
        frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
        interpolator: Interpolator) -> Generator[np.ndarray, None, None]:
    """Splits halfway to repeatedly generate more frames.

    Args:
      frame1: Input image 1.
      frame2: Input image 2.
      num_recursions: How many times to interpolate the consecutive image pairs.
      interpolator: The frame interpolator instance.

    Yields:
      The interpolated frames, including the first frame (frame1), but excluding
      the final frame2.
    """
    if num_recursions == 0:
        yield frame1
    else:
        # Adds the batch dimension to all inputs before calling the
        # interpolator, and removes it afterwards.
        time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
        mid_frame = interpolator(np.expand_dims(frame1, axis=0),
                                 np.expand_dims(frame2, axis=0), time)[0]
        # Recurse on each half; frame2 is emitted by the caller, not here.
        yield from _recursive_generator(frame1, mid_frame, num_recursions - 1,
                                        interpolator)
        yield from _recursive_generator(mid_frame, frame2, num_recursions - 1,
                                        interpolator)


def interpolate_recursively(
        frames: List[np.ndarray],
        interpolator: Interpolator) -> Iterable[np.ndarray]:
    """Generates interpolated frames by repeatedly interpolating the midpoint.

    The number of recursive midpoint interpolations is read from
    ``interpolator.times_to_interpolate``.

    Args:
      frames: List of input frames. Expected shape (H, W, 3). The colors should
        be in the range[0, 1] and in gamma space.
      interpolator: The frame interpolation model to use.

    Yields:
      The interpolated frames (including the inputs). Nothing is yielded when
      frames is empty.
    """
    n = len(frames)
    if n == 0:
        # Nothing to interpolate; avoids an IndexError on frames[-1] below.
        return

    times_to_interpolate = interpolator.times_to_interpolate
    for i in range(1, n):
        yield from _recursive_generator(frames[i - 1], frames[i],
                                        times_to_interpolate, interpolator)
    # Separately yield the final frame.
    yield frames[-1]