# Source: evf-sam2/model/evf_sam2_video.py (uploaded by wondervictor, commit 5e769e6 "update sam2", 12.3 kB)
from typing import List
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from .segment_anything_2.sam2.build_sam import build_sam2, build_sam2_video_predictor
from .unilm.beit3.modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config
from .configuration_evf import EvfConfig
from .segment_anything_2.sam2.utils.misc import load_video_frames
from collections import OrderedDict
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """
    Soft DICE loss between predicted mask logits and binary target masks.

    Args:
        inputs: Float tensor of raw logits with at least three dimensions;
            dimensions 1 and 2 are flattened together before the reduction
            over the last dimension.
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label of every element (0 = negative, 1 = positive).
        num_masks: Normalizer applied to the summed loss.
        scale: Divisor applied to both probabilities and targets before they
            are summed.
        eps: Smoothing term added to the numerator and the denominator.
    Returns:
        Scalar tensor: the per-mask DICE losses summed and divided by
        ``num_masks + 1e-8``.
    """
    probs = inputs.sigmoid().flatten(1, 2)
    flat_targets = targets.flatten(1, 2)

    # Scale the probabilities once and reuse them for overlap and total mass.
    scaled_probs = probs / scale
    overlap = (scaled_probs * flat_targets).sum(-1)
    total_mass = scaled_probs.sum(-1) + (flat_targets / scale).sum(-1)

    per_mask = 1 - (2 * overlap + eps) / (total_mass + eps)
    return per_mask.sum() / (num_masks + 1e-8)
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Binary cross-entropy (on logits) averaged per mask, then normalized.

    Args:
        inputs: Float tensor of raw logits with at least three dimensions;
            dimensions 1 and 2 are flattened together before averaging.
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label of every element (0 = negative, 1 = positive).
        num_masks: Normalizer applied to the summed per-mask losses.
    Returns:
        Scalar loss tensor.
    """
    elementwise = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none")
    # Average over the flattened dimension, giving one value per mask.
    per_mask = elementwise.flatten(1, 2).mean(1)
    return per_mask.sum() / (num_masks + 1e-8)
class EvfSam2Model(PreTrainedModel):
    """
    Text-prompted video segmentation model.

    Combines three sub-modules, all created in ``initialize_evf_modules``:

    * ``visual_model``: a SAM-2 video predictor built with
      ``build_sam2_video_predictor``.
    * ``mm_extractor``: a ``BEiT3Wrapper`` that consumes image and text tokens.
    * ``text_hidden_fcs``: a small MLP projecting the extractor output from
      ``config.hidden_size`` to ``config.out_dim``.
    """

    config_class = EvfConfig

    def __init__(self, config, **kwargs):
        """
        Build the model from ``config`` plus optional keyword settings.

        Recognized ``kwargs`` (all optional):
            vision_pretrained: checkpoint handed to the SAM-2 builder.
            encoder_pretrained: path of a BEiT-3 checkpoint to load.
            dice_loss_weight / bce_loss_weight: loss weights (only referenced
                by the commented-out training ``forward``).
            train_mask_decoder: unfreeze the SAM-2 mask decoder.
            train_prompt_encoder: unfreeze the prompt encoder's
                ``no_mask_embed``.
        """
        super().__init__(config)

        self.config = config
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.encoder_pretrained = kwargs.get("encoder_pretrained", None)
        self.dice_loss_weight = kwargs.get("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.get("bce_loss_weight", None)
        self.train_mask_decoder = kwargs.get("train_mask_decoder", False)
        self.train_prompt_encoder = kwargs.get("train_prompt_encoder", False)
        self.initialize_evf_modules(config)

        # Spatial sizes of the backbone feature levels; only referenced by the
        # commented-out training ``forward`` below.
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    def initialize_evf_modules(self, config):
        """
        Create the SAM-2 predictor, the BEiT-3 extractor and the projection MLP.

        Raises:
            NotImplementedError: ``config.sam_scale`` is not "large" or "tiny".
            AttributeError: ``config.mm_extractor_scale`` is not "base" or
                "large".
            AssertionError: ``config.hidden_size`` differs from the BEiT-3
                embedding dimension.
        """
        # SAM
        if config.sam_scale == "large":
            sam_config_file = "sam2_hiera_l.yaml"
        elif config.sam_scale == "tiny":
            sam_config_file = "sam2_hiera_t.yaml"
        else:
            raise NotImplementedError(
                f"unsupported sam_scale {config.sam_scale!r}; "
                "expected 'large' or 'tiny'")
        self.visual_model = build_sam2_video_predictor(
            sam_config_file, self.vision_pretrained, device=None)

        # Freeze the whole visual model first, then selectively unfreeze.
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if self.train_mask_decoder:
            self.visual_model.sam_mask_decoder.train()
            for param in self.visual_model.sam_mask_decoder.parameters():
                param.requires_grad = True
        if self.train_prompt_encoder:
            self.visual_model.sam_prompt_encoder.no_mask_embed.requires_grad_(
                True)

        # beit-3
        if self.config.mm_extractor_scale == "base":
            beit_config = _get_base_config()
        elif self.config.mm_extractor_scale == "large":
            beit_config = _get_large_config()
        else:
            raise AttributeError(
                "model config should contain key 'mm_extractor_scale', with value 'base' or 'large'."
            )

        self.mm_extractor = BEiT3Wrapper(beit_config)
        if self.encoder_pretrained is not None:
            # map_location="cpu" lets a checkpoint saved on GPU load on a
            # CPU-only host; load_state_dict copies the values into the
            # module's own parameters afterwards.
            # NOTE: torch.load unpickles the file -- only load checkpoints
            # from trusted sources.
            beit_state_dict = torch.load(self.encoder_pretrained,
                                         map_location="cpu")["model"]
            self.mm_extractor.load_state_dict(beit_state_dict, strict=False)

        for param in self.mm_extractor.parameters():
            param.requires_grad = True

        # Projection layer
        in_dim = config.hidden_size
        assert in_dim==beit_config.encoder_embed_dim, \
            f"projection layer dim {in_dim} mismatch with mm_extractor dim {beit_config.encoder_embed_dim}"
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim)
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Perform PostProcessing on output masks.

        Casts ``masks`` to float and bilinearly resizes them to ``orig_hw``
        (the target size handed to ``F.interpolate``).
        """
        masks = masks.float()
        masks = F.interpolate(masks,
                              orig_hw,
                              mode="bilinear",
                              align_corners=False)
        return masks

    # NOTE: the training-time forward pass below is disabled (it was already
    # commented out); it is kept for reference only.
    # def forward(
    #     self,
    #     images: torch.FloatTensor,
    #     images_evf: torch.FloatTensor,
    #     input_ids: torch.LongTensor,
    #     attention_masks: torch.LongTensor,
    #     offset: torch.LongTensor,
    #     masks_list: List[torch.FloatTensor],
    #     label_list: List[torch.Tensor],
    #     resize_list: List[tuple],
    #     inference: bool = False,
    #     **kwargs,
    # ):
    #     # image_embeddings = self.get_visual_embs(images)
    #     backbone_out = self.visual_model.forward_image(images)
    #     # dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
    #     _, image_embeddings, _, _ = self.visual_model._prepare_backbone_features(backbone_out)
    #     image_embeddings = [_.to(images.dtype) for _ in image_embeddings]
    #     batch_size = images.shape[0]
    #     if self.visual_model.directly_add_no_mem_embed:
    #         image_embeddings[-1] = image_embeddings[-1] + self.visual_model.no_mem_embed
    #     feats = [
    #         feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
    #         for feat, feat_size in zip(image_embeddings[::-1], self._bb_feat_sizes[::-1])
    #     ][::-1]
    #     _features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
    #     assert batch_size == len(offset) - 1
    #     images_evf_list = []
    #     for i in range(len(offset) - 1):
    #         start_i, end_i = offset[i], offset[i + 1]
    #         images_evf_i = (
    #             images_evf[i]
    #             .unsqueeze(0)
    #             .expand(end_i - start_i, -1, -1, -1)
    #             .contiguous()
    #         )
    #         images_evf_list.append(images_evf_i)
    #     images_evf = torch.cat(images_evf_list, dim=0)
    #     multimask_output = False
    #     output = self.mm_extractor.beit3(
    #         visual_tokens=images_evf,
    #         textual_tokens=input_ids,
    #         text_padding_position=~attention_masks
    #     )
    #     feat = output["encoder_out"][:, :1, ...]
    #     feat = self.text_hidden_fcs[0](feat)
    #     feat = torch.split(feat, [offset[i+1] - offset[i] for i in range(len(offset)-1)])
    #     pred_masks = []
    #     for i in range(len(feat)):
    #         (
    #             sparse_embeddings,
    #             dense_embeddings,
    #         ) = self.visual_model.sam_prompt_encoder(
    #             points=None,
    #             boxes=None,
    #             masks=None,
    #             text_embeds=feat[i],
    #         )
    #         sparse_embeddings = sparse_embeddings.to(feat[i].dtype)
    #         high_res_features = [
    #             feat_level[i].unsqueeze(0)
    #             for feat_level in _features["high_res_feats"]
    #         ]
    #         low_res_masks, iou_predictions, _, _ = self.visual_model.sam_mask_decoder(
    #             image_embeddings=_features["image_embed"][i].unsqueeze(0),
    #             image_pe=self.visual_model.sam_prompt_encoder.get_dense_pe(),
    #             sparse_prompt_embeddings=sparse_embeddings,
    #             dense_prompt_embeddings=dense_embeddings,
    #             multimask_output=multimask_output,
    #             repeat_image = True,
    #             high_res_features=high_res_features,
    #         )
    #         if multimask_output:
    #             sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
    #             low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)[:, :1]
    #         pred_mask = self.postprocess_masks(
    #             low_res_masks,
    #             orig_hw=label_list[i].shape,
    #         )
    #         pred_masks.append(pred_mask[:, 0])
    #     gt_masks = masks_list
    #     if inference:
    #         return {
    #             "pred_masks": pred_masks,
    #             "gt_masks": gt_masks,
    #         }
    #     mask_bce_loss = 0
    #     mask_dice_loss = 0
    #     num_masks = 0
    #     for batch_idx in range(len(pred_masks)):
    #         gt_mask = gt_masks[batch_idx]
    #         pred_mask = pred_masks[batch_idx]
    #         assert (
    #             gt_mask.shape[0] == pred_mask.shape[0]
    #         ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
    #             gt_mask.shape, pred_mask.shape
    #         )
    #         mask_bce_loss += (
    #             sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
    #             * gt_mask.shape[0]
    #         )
    #         mask_dice_loss += (
    #             dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
    #             * gt_mask.shape[0]
    #         )
    #         num_masks += gt_mask.shape[0]
    #     mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
    #     mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
    #     mask_loss = mask_bce_loss + mask_dice_loss
    #     loss = mask_loss
    #     return {
    #         "loss": loss,
    #         "mask_bce_loss": mask_bce_loss,
    #         "mask_dice_loss": mask_dice_loss,
    #         "mask_loss": mask_loss,
    #     }

    def inference(
        self,
        video_path,
        images_evf,
        input_ids,
        # original_size_list,
        multimask_output=False,
    ):
        """
        Segment one text-referred object through a whole video.

        The text prompt is registered on frame 0 under object id 1, then
        propagated through the video by the SAM-2 predictor.

        Args:
            video_path: forwarded to ``predictor.init_state(video_path=...)``.
            images_evf: visual tokens fed to the BEiT-3 extractor.
            input_ids: textual tokens fed to the BEiT-3 extractor. An all-zero
                tensor of the same shape is passed as
                ``text_padding_position`` (presumably meaning "no padding" --
                confirm against the BEiT-3 wrapper).
            multimask_output: accepted for interface compatibility; currently
                not forwarded to the predictor.
        Returns:
            ``{frame_idx: {obj_id: mask}}`` where ``mask`` is the NumPy array
            ``(mask_logits > 0.0)`` moved to CPU.
        """
        predictor = self.visual_model
        inference_state = predictor.init_state(video_path=video_path)
        predictor.reset_state(inference_state)

        # Every mm_extractor / projection parameter has requires_grad=True and
        # this method only returns thresholded NumPy masks, so no gradient can
        # ever be consumed: skip building the autograd graph.
        with torch.no_grad():
            output = self.mm_extractor.beit3(
                visual_tokens=images_evf,
                textual_tokens=input_ids,
                text_padding_position=torch.zeros_like(input_ids))
            # Keep only the first token of the encoder output as the prompt.
            feat = output["encoder_out"][:, :1, ...]
            feat = self.text_hidden_fcs[0](feat)

        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)
        # The return value is not needed here; propagation below yields the
        # per-frame results.
        predictor.add_new_text(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=ann_obj_id,
            text=feat)

        # run propagation throughout the video and collect the results in a dict
        video_segments = {}  # per-frame segmentation results
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
                inference_state):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
        return video_segments
# Expose the custom "evf" config/model pair through the transformers Auto*
# factories.
AutoConfig.register(model_type="evf", config=EvfConfig)
AutoModelForCausalLM.register(config_class=EvfConfig,
                              model_class=EvfSam2Model)