"""
Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
Action format derived from VPT https://github.com/openai/Video-Pre-Training
"""

# NOTE(review): math, nn, parse_shape and rearrange are unused in this module;
# they are kept in case other modules import them from here.
import math
from typing import Any, Mapping, Sequence

import torch
from einops import parse_shape, rearrange
from torch import nn


def sigmoid_beta_schedule(
    timesteps: int,
    start: float = -3,
    end: float = 3,
    tau: float = 1,
    clamp_min: float = 1e-5,
) -> torch.Tensor:
    """
    Sigmoid beta schedule proposed in https://arxiv.org/abs/2212.11972 (Figure 8).

    Better for images > 64x64, when used during training.

    Args:
        timesteps: number of diffusion steps; the result has this many betas.
        start: sigmoid input (before dividing by ``tau``) at t = 0.
        end: sigmoid input (before dividing by ``tau``) at t = 1.
        tau: temperature dividing the sigmoid input.
        clamp_min: unused by this implementation; kept so existing call sites
            that pass it keep working.

    Returns:
        A float32 tensor of ``timesteps`` betas, clipped to ``[0, 0.999]``.
    """
    steps = timesteps + 1
    # Normalised time grid: timesteps + 1 points evenly spaced in [0, 1].
    t = torch.linspace(0, timesteps, steps, dtype=torch.float32) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    # Rescaled sigmoid: equals 1 at t = 0 and 0 at t = 1.
    alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # The last alpha_bar is 0, which would give a beta of exactly 1; clip below 1.
    return torch.clip(betas, 0, 0.999)


# Column order of the tensor produced by `one_hot_actions`. "cameraX"/"cameraY"
# are read from the two components of an action's "camera" entry; every other
# key is read directly from the action mapping.
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraX",
    "cameraY",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

# NOTE these numbers are specific to the camera quantization used in
# https://github.com/etched-ai/dreamcraft/blob/216e952f795bb3da598639a109bcdba4d2067b69/spark/preprocess_vpt_to_videos_actions.py#L312
# see method `compress_mouse`
_CAMERA_MAX_VAL = 20
_CAMERA_BIN_SIZE = 0.5
_CAMERA_NUM_BUCKETS = int(_CAMERA_MAX_VAL / _CAMERA_BIN_SIZE)  # == 40
# Which component of the "camera" pair feeds each camera column.
_CAMERA_AXES = {"cameraX": 0, "cameraY": 1}


def one_hot_actions(actions: Sequence[Mapping[str, Any]]) -> torch.Tensor:
    """
    Encode a sequence of per-frame action mappings as a float tensor.

    Despite the name the encoding is not strictly one-hot: several button
    columns may be 1 in the same row, and the two camera columns hold
    continuous values.

    Args:
        actions: one mapping per frame. For every key in ``ACTION_KEYS`` other
            than ``cameraX``/``cameraY`` the mapping must hold that key with a
            value in ``[0, 1]``. It must also hold ``"camera"``, an indexable
            pair whose ``[0]`` and ``[1]`` fill ``cameraX`` and ``cameraY``.
            Each camera component ``v`` is mapped to ``(v - 40) / 40``; the
            inputs are presumably quantized bucket indices in ``[0, 80]``
            (see the camera-constant note above) — confirm against the
            preprocessing code.

    Returns:
        Tensor of shape ``(len(actions), len(ACTION_KEYS))`` in torch's default
        floating dtype, with columns ordered as ``ACTION_KEYS``.

    Raises:
        KeyError: if an action mapping lacks a required key.
        ValueError: if ``ACTION_KEYS`` contains a ``camera*`` key other than
            ``cameraX``/``cameraY``.
        AssertionError: if a value is out of range (camera values get a 1e-3
            tolerance). These checks are skipped under ``python -O``.
    """
    num_buckets = _CAMERA_NUM_BUCKETS
    rows = []
    for current_actions in actions:
        row = []
        for action_key in ACTION_KEYS:
            if action_key.startswith("camera"):
                axis = _CAMERA_AXES.get(action_key)
                if axis is None:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                # Shift/scale the bucket index so the centre bucket maps to 0.
                value = (current_actions["camera"][axis] - num_buckets) / num_buckets
                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
            else:
                value = current_actions[action_key]
                assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
            row.append(float(value))
        rows.append(row)
    # One tensor construction instead of a tensor __setitem__ per cell; the
    # reshape keeps the (0, len(ACTION_KEYS)) shape when `actions` is empty.
    return torch.tensor(rows, dtype=torch.get_default_dtype()).reshape(len(rows), len(ACTION_KEYS))