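# [Residue from the hosting web page, commented out so this file parses as Python]
# Last commit by zhengchong: 6eb1d7d "chore: Update dependencies and code structure"
# File size: 4.84 kB (page controls: raw, history, blame, contribute, delete; scan result: No virus)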
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
import logging
import numpy as np
import pickle
from enum import Enum
from typing import Optional
import torch
from torch import nn
from detectron2.config import CfgNode
from detectron2.utils.file_io import PathManager
from .vertex_direct_embedder import VertexDirectEmbedder
from .vertex_feature_embedder import VertexFeatureEmbedder
class EmbedderType(Enum):
    """
    Strategy used to map mesh vertices into the embedding space.

    The member values are the strings accepted in an embedder spec's ``TYPE``
    field:
        VERTEX_DIRECT ("vertex_direct"): vertices are embedded directly
        VERTEX_FEATURE ("vertex_feature"): vertex features are embedded
    """

    # direct vertex embedding
    VERTEX_DIRECT = "vertex_direct"
    # embedding of per-vertex features
    VERTEX_FEATURE = "vertex_feature"
def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Build a single mesh embedder from its configuration node.

    Args:
        embedder_spec (CfgNode): embedder configuration. Fields read here:
            TYPE, NUM_VERTICES, INIT_FILE and IS_TRAINABLE, plus FEATURE_DIM and
            FEATURES_TRAINABLE for the "vertex_feature" type
        embedder_dim (int): dimensionality of the embedding space
    Return:
        The constructed embedder module; its parameters are frozen when
        ``IS_TRAINABLE`` is false
    Raises:
        ValueError: if ``TYPE`` does not name a known ``EmbedderType``
    """
    # Enum lookup by value raises ValueError for unknown type strings
    embedder_type = EmbedderType(embedder_spec.TYPE)

    if embedder_type == EmbedderType.VERTEX_DIRECT:
        embedder = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif embedder_type == EmbedderType.VERTEX_FEATURE:
        embedder = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        raise ValueError(f"Unexpected embedder type {embedder_type}")

    # Both embedder kinds expose load(); an empty INIT_FILE means "no init file"
    init_file = embedder_spec.INIT_FILE
    if init_file != "":
        embedder.load(init_file)

    if not embedder_spec.IS_TRAINABLE:
        embedder.requires_grad_(False)
    return embedder
class Embedder(nn.Module):
    """
    Container module holding one vertex embedder per mesh.

    The embedder for mesh ``<name>`` is registered as submodule
    ``embedder_<name>``. Because it is a submodule, its parameters are saved
    and restored through the regular ``nn.Module`` state dict machinery.
    """

    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Create an embedder for every entry of
        ``cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS``.

        If ``cfg.MODEL.WEIGHTS`` is non-empty, matching embedder weights are
        then loaded from that checkpoint.

        Args:
            cfg (CfgNode): configuration options
        """
        super().__init__()
        self.mesh_names = set()
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        embed_size = cse_cfg.EMBED_SIZE
        log = logging.getLogger(__name__)
        for mesh_name, spec in cse_cfg.EMBEDDERS.items():
            log.info(f"Adding embedder embedder_{mesh_name} with spec {spec}")
            submodule = create_embedder(spec, embed_size)
            self.add_module(f"embedder_{mesh_name}", submodule)
            self.mesh_names.add(mesh_name)
        weights_path = cfg.MODEL.WEIGHTS
        if weights_path != "":
            self.load_from_model_checkpoint(weights_path)

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
        """
        Load embedder weights from a full-model checkpoint.

        Only entries of the checkpoint's ``"model"`` mapping whose keys start
        with ``prefix`` are used. The prefix is stripped from each key, numpy
        arrays are converted to tensors, and loading is non-strict. If the
        checkpoint has no ``"model"`` entry, nothing is loaded.

        Args:
            fpath (str): checkpoint path. Files ending in ".pkl" are read with
                ``pickle`` (latin1 encoding); any other file is read with
                ``torch.load`` mapped onto the CPU.
            prefix (Optional[str]): key prefix that selects embedder weights;
                ``None`` means ``Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX``
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX
        # NOTE: pickle.load and torch.load can both execute arbitrary code while
        # unpickling; only use this with trusted checkpoint files.
        with PathManager.open(fpath, "rb") as stream:
            if fpath.endswith(".pkl"):
                checkpoint = pickle.load(stream, encoding="latin1")
            else:
                checkpoint = torch.load(stream, map_location=torch.device("cpu"))
        if checkpoint is None or "model" not in checkpoint:
            return
        selected = {}
        for full_key, value in checkpoint["model"].items():
            if not full_key.startswith(prefix):
                continue
            # .pkl checkpoints may store weights as numpy arrays
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            selected[full_key[len(prefix) :]] = value
        # non-strict: the checkpoint may cover a different set of meshes (finetuning)
        self.load_state_dict(selected, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Return the vertex embeddings of the given mesh.

        Per the original contract, the result is a tensor of shape [N, D]:
            N = number of vertices
            D = number of dimensions in the embedding space

        Args:
            mesh_name (str): name of the mesh whose vertex embeddings are requested
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        mesh_embedder = getattr(self, f"embedder_{mesh_name}")
        return mesh_embedder()

    def has_embeddings(self, mesh_name: str) -> bool:
        """Return True if an embedder submodule is registered for ``mesh_name``."""
        return hasattr(self, f"embedder_{mesh_name}")