tree3po's picture
Upload 21 files
12aae2e verified
raw
history blame
2.83 kB
"""
Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
Action format derived from VPT https://github.com/openai/Video-Pre-Training
"""
import math
import torch
from torch import nn
from einops import rearrange, parse_shape
from typing import Mapping, Sequence
import torch
from einops import rearrange
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """
    Sigmoid noise schedule proposed in https://arxiv.org/abs/2212.11972 (Figure 8).

    Reported there to work better for images larger than 64x64 when used
    during training.

    Args:
        timesteps: number of diffusion steps; one beta is produced per step.
        start, end: interval of sigmoid inputs swept as time goes from 0 to 1.
        tau: temperature dividing the sigmoid input (larger -> flatter curve).
        clamp_min: accepted for signature compatibility; not used by this
            implementation.

    Returns:
        1-D float32 tensor with ``timesteps`` betas, clipped to [0, 0.999].
    """
    # timesteps + 1 evenly spaced times in [0, 1]; ratios of consecutive
    # cumulative alphas below then yield exactly `timesteps` betas.
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float32) / timesteps
    sig_start = torch.tensor(start / tau).sigmoid()
    sig_end = torch.tensor(end / tau).sigmoid()

    # Map each time onto [start, end], squash it, and rescale so the curve
    # runs from 1 at t=0 down to 0 at t=1.
    squashed = ((grid * (end - start) + start) / tau).sigmoid()
    alpha_bar = (sig_end - squashed) / (sig_end - sig_start)
    alpha_bar = alpha_bar / alpha_bar[0]

    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}; the final step has
    # alpha_bar == 0, so its beta is capped at 0.999 by the clamp.
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clamp(0, 0.999)
# Column order of the vector produced by `one_hot_actions`. "cameraX" and
# "cameraY" are read from the "camera" entry of each action mapping; every
# other key is read directly from the mapping.
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraX",
    "cameraY",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]


def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    """Encode per-frame action mappings as a dense ``(frames, len(ACTION_KEYS))`` tensor.

    Despite the name, the encoding is not strictly one-hot. Column ``j`` holds
    the value for ``ACTION_KEYS[j]``, so several columns may be non-zero in
    one row. The two camera columns hold continuous values in roughly [-1, 1].

    Args:
        actions: one mapping per frame. Every non-camera entry of
            ``ACTION_KEYS`` must be present with a value in [0, 1].
            ``"camera"`` must be indexable as ``[0]`` (cameraX) and ``[1]``
            (cameraY). Each holds a quantized camera bucket that is mapped to
            ``(bucket - num_buckets) / num_buckets``.

    Returns:
        Tensor of shape ``(len(actions), len(ACTION_KEYS))`` created with
        ``torch.zeros`` (torch's default floating dtype).

    Raises:
        KeyError: (typically) if a required key is missing from a mapping.
        ValueError: if a value lies outside its valid range.
    """
    # NOTE these numbers are specific to the camera quantization used in
    # https://github.com/etched-ai/dreamcraft/blob/216e952f795bb3da598639a109bcdba4d2067b69/spark/preprocess_vpt_to_videos_actions.py#L312
    # see method `compress_mouse`. They are constant, so compute them once.
    max_val = 20
    bin_size = 0.5
    num_buckets = int(max_val / bin_size)
    # Position inside current_actions["camera"] for each camera column.
    camera_axis = {"cameraX": 0, "cameraY": 1}

    actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
    for i, current_actions in enumerate(actions):
        for j, action_key in enumerate(ACTION_KEYS):
            if action_key.startswith("camera"):
                if action_key not in camera_axis:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                bucket = current_actions["camera"][camera_axis[action_key]]
                # Buckets in [0, 2 * num_buckets] map onto [-1, 1].
                value = (bucket - num_buckets) / num_buckets
                # Explicit raise instead of `assert`: asserts are stripped
                # under `python -O`, which would let bad values through.
                if not (-1 - 1e-3 <= value <= 1 + 1e-3):
                    raise ValueError(f"Camera action value must be in [-1, 1], got {value}")
            else:
                value = current_actions[action_key]
                if not (0 <= value <= 1):
                    raise ValueError(f"Action value must be in [0, 1] got {value}")
            actions_one_hot[i, j] = value
    return actions_one_hot