|
from typing import Generator, List, Iterable |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from huggingface_hub import snapshot_download |
|
|
|
"""A wrapper class for running frame interpolation with the FILM model (a TF SavedModel downloaded from the Hugging Face Hub).
|
|
|
Usage: |
|
interpolator = Interpolator() |
|
result_batch = interpolator(image_batch_0, image_batch_1, batch_dt) |
|
Where image_batch_0 and image_batch_1 are numpy arrays with TF standard
|
(B,H,W,C) layout, batch_dt is the sub-frame time in range [0..1], (B,) layout. |
|
""" |
|
|
|
# Hugging Face Hub repository id of the FILM SavedModel; passed to
# snapshot_download() when an Interpolator is constructed.
FILM_REPO_ID = 'leonelhs/film'
|
|
|
|
|
def _pad_to_align(x, align):
    """Pads image batch x so width and height divide by align.

    The padding for each spatial axis is split between both sides; when the
    amount is odd, the extra row/column goes to the bottom/right (the offset
    is the floor of half the padding).

    Args:
        x: Rank-4 image batch to align, laid out as (batch, height, width, channels).
        align: Positive number to align to.

    Returns:
        1) An image padded so width % align == 0 and height % align == 0.
        2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
           to undo the padding.

    Raises:
        ValueError: If x is not rank 4 or align is not positive.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if np.ndim(x) != 4:
        raise ValueError(f'x must be a rank-4 image batch, got rank {np.ndim(x)}.')
    if align <= 0:
        raise ValueError('align must be a positive number.')

    height, width = x.shape[-3:-1]
    # `-n % align` is the distance to the next multiple of align (0 if already aligned).
    height_to_pad = -height % align
    width_to_pad = -width % align

    # The same top-left offset is used to pad and later to crop the padding away.
    offsets = {
        'offset_height': height_to_pad // 2,
        'offset_width': width_to_pad // 2,
    }
    padded_x = tf.image.pad_to_bounding_box(
        x,
        **offsets,
        target_height=height + height_to_pad,
        target_width=width + width_to_pad,
    )
    bbox_to_crop = {**offsets, 'target_height': height, 'target_width': width}
    return padded_x, bbox_to_crop
|
|
|
|
|
class Interpolator:
    """A class for generating interpolated frames between two input frames.

    Uses the FILM model, downloaded from the Hugging Face Hub repository
    FILM_REPO_ID and loaded with tf.saved_model.load.
    """

    def __init__(self, times_to_interpolate=6, align: int = 64) -> None:
        """Downloads and loads the saved model.

        Args:
            times_to_interpolate: Stored as `self.times_to_interpolate`; the
                number of recursive midpoint subdivisions that callers doing
                recursive interpolation should apply to each frame pair.
            align: If >1, pad the input size so it divides with this before
                inference. Must be positive, or None to skip padding entirely.
        """
        self.times_to_interpolate = times_to_interpolate
        model_path = snapshot_download(FILM_REPO_ID)
        self._model = tf.saved_model.load(model_path)
        self._align = align

    def __call__(self, x0: np.ndarray, x1: np.ndarray,
                 dt: np.ndarray) -> np.ndarray:
        """Generates an interpolated frame between given two batches of frames.

        All inputs should be np.float32 datatype.

        Args:
            x0: First image batch. Dimensions: (batch_size, height, width, channels)
            x1: Second image batch. Dimensions: (batch_size, height, width, channels)
            dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)

        Returns:
            The result with dimensions (batch_size, height, width, channels).
        """
        if self._align is not None:
            x0, bbox_to_crop = _pad_to_align(x0, self._align)
            x1, _ = _pad_to_align(x1, self._align)

        # dt goes from (batch_size,) to (batch_size, 1) before being fed to the model.
        inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
        result = self._model(inputs, training=False)
        image = result['image']

        if self._align is not None:
            # Remove the alignment padding so the output matches the input size.
            image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
        return image.numpy()

    def preview_frames(self, frames: List[np.ndarray]):
        """Returns the first two frames with their temporal midpoint in between.

        The midpoint (dt = 0.5) is produced through `__call__`, so the same
        alignment padding/cropping and `training=False` flag used for regular
        interpolation apply here as well.

        Args:
            frames: Sequence whose first two entries are single frames of shape
                (height, width, channels); any further entries are ignored.

        Returns:
            A three-element list: frames[0], the interpolated middle frame as a
            numpy array, and frames[1].
        """
        time = np.array([0.5], dtype=np.float32)
        first, second = frames[0], frames[1]
        # __call__ operates on batches: add a batch axis of size 1, then strip it.
        mid = self(np.expand_dims(first, axis=0),
                   np.expand_dims(second, axis=0),
                   time)[0]
        return [first, mid, second]
|
|
|
|
|
def _recursive_generator(
        frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
        interpolator: Interpolator) -> Generator[np.ndarray, None, None]:
    """Splits halfway to repeatedly generate more frames.

    The midpoint between frame1 and frame2 is interpolated, then each half is
    subdivided again until num_recursions levels have been applied, giving
    2 ** num_recursions frames in temporal order.

    Args:
        frame1: Input image 1.
        frame2: Input image 2.
        num_recursions: How many times to interpolate the consecutive image pairs.
            Must be non-negative.
        interpolator: The frame interpolator instance.

    Yields:
        The interpolated frames, including the first frame (frame1), but excluding
        the final frame2.

    Raises:
        ValueError: If num_recursions is negative (raised when iteration starts).
    """
    if num_recursions < 0:
        # A negative depth never reaches the base case below; it would recurse,
        # running the model at every level, until RecursionError.
        raise ValueError(f'num_recursions must be non-negative, got {num_recursions}.')
    if num_recursions == 0:
        yield frame1
        return

    # Interpolate the pair as a batch of one at the temporal midpoint.
    time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
    mid_frame = interpolator(
        np.expand_dims(frame1, axis=0), np.expand_dims(frame2, axis=0), time)[0]
    yield from _recursive_generator(frame1, mid_frame, num_recursions - 1, interpolator)
    yield from _recursive_generator(mid_frame, frame2, num_recursions - 1, interpolator)
|
|
|
|
|
def interpolate_recursively(
        frames: List[np.ndarray], interpolator: Interpolator) -> Iterable[np.ndarray]:
    """Generates interpolated frames by repeatedly interpolating the midpoint.

    Each consecutive pair of input frames is subdivided
    `interpolator.times_to_interpolate` times by recursive midpoint
    interpolation, so every pair contributes 2 ** times_to_interpolate frames
    (its first frame plus the in-betweens); the last input frame is emitted once
    at the end.

    Args:
        frames: List of input frames. Expected shape (H, W, 3). The colors should be
            in the range [0, 1] and in gamma space.
        interpolator: The frame interpolation model to use. Its
            `times_to_interpolate` attribute sets the recursion depth.

    Yields:
        The interpolated frames (including the inputs). Nothing is yielded when
        `frames` is empty.
    """
    # len() rather than truthiness so a stacked numpy array of frames also works.
    if len(frames) == 0:
        return
    times_to_interpolate = interpolator.times_to_interpolate

    for start, end in zip(frames, frames[1:]):
        yield from _recursive_generator(start, end, times_to_interpolate, interpolator)

    yield frames[-1]
|
|