Pedro Cavalcanti committed on
Commit 07050ac
1 Parent(s): a2f0be7

changed: file

Files changed (1)
  1. modeling_florence2.py +581 -280
modeling_florence2.py CHANGED
@@ -23,7 +23,7 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss
 from collections import OrderedDict
 from einops import rearrange
 from timm.models.layers import DropPath, trunc_normal_
@@ -39,7 +39,7 @@ from transformers.utils import (
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
 )
-from .configuration_florence2 import Florence2Config
+from .configuration_florence2 import Florence2Config
 from .configuration_florence2 import Florence2LanguageConfig
 from .configuration_florence2 import Florence2VisionConfig

@@ -66,6 +66,7 @@ logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "Florence2Config"

+
 class LearnedAbsolutePositionEmbedding2D(nn.Module):
     """
     This module learns positional embeddings up to a fixed maximum size.
@@ -74,22 +75,30 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module):
     def __init__(self, embedding_dim=256, num_pos=50):
         super().__init__()
         self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
-        self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2))
+        self.column_embeddings = nn.Embedding(
+            num_pos, embedding_dim - (embedding_dim // 2)
+        )

     def forward(self, pixel_values):
         """
-        pixel_values: (batch_size, height, width, num_channels)
+        pixel_values: (batch_size, height, width, num_channels)
         returns: (batch_size, height, width, embedding_dim * 2)
         """
         if len(pixel_values.shape) != 4:
-            raise ValueError('pixel_values must be a 4D tensor')
+            raise ValueError("pixel_values must be a 4D tensor")
         height, width = pixel_values.shape[1:3]
         width_values = torch.arange(width, device=pixel_values.device)
         height_values = torch.arange(height, device=pixel_values.device)
         x_emb = self.column_embeddings(width_values)
         y_emb = self.row_embeddings(height_values)
         # (height, width, embedding_dim * 2)
-        pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
+        pos = torch.cat(
+            [
+                x_emb.unsqueeze(0).repeat(height, 1, 1),
+                y_emb.unsqueeze(1).repeat(1, width, 1),
+            ],
+            dim=-1,
+        )
         # (embedding_dim * 2, height, width)
         pos = pos.permute(2, 0, 1)
         pos = pos.unsqueeze(0)
@@ -99,6 +108,7 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module):
         pos = pos.permute(0, 2, 3, 1)
         return pos

+
 class PositionalEmbeddingCosine1D(nn.Module):
     """
     This class implements a very simple positional encoding. It follows closely
@@ -110,22 +120,21 @@ class PositionalEmbeddingCosine1D(nn.Module):
     dropout_prob: The dropout probability.
     max_seq_len: The maximum length to precompute the positional encodings.
     """
-    def __init__(
-            self,
-            embed_dim: int = 512,
-            max_seq_len: int = 1024) -> None:
+
+    def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
         super(PositionalEmbeddingCosine1D, self).__init__()
         self.embed_dim = embed_dim
         self.max_seq_len = max_seq_len
         # Generate the sinusoidal arrays.
         factor = math.log(10000)
         denominator = torch.exp(
-            -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim)
+            -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim
+        )
         # Matrix where rows correspond to a positional embedding as a function
         # of the position index (i.e., the row index).
-        frequencies = \
-            torch.arange(0, self.max_seq_len) \
-            .reshape(self.max_seq_len, 1) * denominator
+        frequencies = (
+            torch.arange(0, self.max_seq_len).reshape(self.max_seq_len, 1) * denominator
+        )
         pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
         # Populate uneven entries.
         pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
@@ -149,11 +158,10 @@ class PositionalEmbeddingCosine1D(nn.Module):
         assert 2 <= shape_len <= 3
         len_seq = seq_embeds.size(-2)
         assert len_seq <= self.max_seq_len
-        pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]
+        pos_embeds = self.pos_idx_to_embed[0 : seq_embeds.size(-2), :]
         # Adapt pre-computed positional embeddings to the input.
         if shape_len == 3:
-            pos_embeds = pos_embeds.view(
-                (1, pos_embeds.size(0), pos_embeds.size(1)))
+            pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
         return pos_embeds


@@ -165,10 +173,8 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
     embed_dim: The dimension of the embeddings.
     max_seq_len: The maximum length to precompute the positional encodings.
     """
-    def __init__(
-            self,
-            embedding_dim: int = 512,
-            num_pos: int = 1024) -> None:
+
+    def __init__(self, embedding_dim: int = 512, num_pos: int = 1024) -> None:
         super(LearnedAbsolutePositionEmbedding1D, self).__init__()
         self.embeddings = nn.Embedding(num_pos, embedding_dim)
         self.num_pos = num_pos
@@ -193,12 +199,10 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
         pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device))
         # Adapt pre-computed positional embeddings to the input.
         if shape_len == 3:
-            pos_embeds = pos_embeds.view(
-                (1, pos_embeds.size(0), pos_embeds.size(1)))
+            pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
         return pos_embeds


-
 class MySequential(nn.Sequential):
     def forward(self, *inputs):
         for module in self._modules.values():
@@ -242,11 +246,15 @@ class Mlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.net = nn.Sequential(OrderedDict([
-            ("fc1", nn.Linear(in_features, hidden_features)),
-            ("act", act_layer()),
-            ("fc2", nn.Linear(hidden_features, out_features))
-        ]))
+        self.net = nn.Sequential(
+            OrderedDict(
+                [
+                    ("fc1", nn.Linear(in_features, hidden_features)),
+                    ("act", act_layer()),
+                    ("fc2", nn.Linear(hidden_features, out_features)),
+                ]
+            )
+        )

     def forward(self, x, size):
         return self.net(x), size
@@ -263,12 +271,13 @@ class DepthWiseConv2d(nn.Module):
     ):
         super().__init__()
         self.dw = nn.Conv2d(
-            dim_in, dim_in,
+            dim_in,
+            dim_in,
             kernel_size=kernel_size,
             padding=padding,
             groups=dim_in,
             stride=stride,
-            bias=bias
+            bias=bias,
         )

     def forward(self, x, size):
@@ -283,8 +292,7 @@


 class ConvEmbed(nn.Module):
-    """ Image to Patch Embedding
-    """
+    """Image to Patch Embedding"""

     def __init__(
         self,
@@ -294,16 +302,13 @@ class ConvEmbed(nn.Module):
294
  stride=4,
295
  padding=2,
296
  norm_layer=None,
297
- pre_norm=True
298
  ):
299
  super().__init__()
300
  self.patch_size = patch_size
301
 
302
  self.proj = nn.Conv2d(
303
- in_chans, embed_dim,
304
- kernel_size=patch_size,
305
- stride=stride,
306
- padding=padding
307
  )
308
 
309
  dim_norm = in_chans if pre_norm else embed_dim
@@ -316,15 +321,12 @@ class ConvEmbed(nn.Module):
316
  if len(x.size()) == 3:
317
  if self.norm and self.pre_norm:
318
  x = self.norm(x)
319
- x = rearrange(
320
- x, 'b (h w) c -> b c h w',
321
- h=H, w=W
322
- )
323
 
324
  x = self.proj(x)
325
 
326
  _, _, H, W = x.shape
327
- x = rearrange(x, 'b c h w -> b (h w) c')
328
  if self.norm and not self.pre_norm:
329
  x = self.norm(x)
330
 
@@ -343,7 +345,11 @@ class ChannelAttention(nn.Module):
343
  def forward(self, x, size):
344
  B, N, C = x.shape
345
 
346
- qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
 
 
 
 
347
  q, k, v = qkv[0], qkv[1], qkv[2]
348
 
349
  q = q * (float(N) ** -0.5)
@@ -357,24 +363,41 @@ class ChannelAttention(nn.Module):
357
 
358
  class ChannelBlock(nn.Module):
359
 
360
- def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
361
- drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
362
- conv_at_attn=True, conv_at_ffn=True):
 
 
 
 
 
 
 
 
 
363
  super().__init__()
364
 
365
- drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
366
 
367
- self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
 
 
368
  self.channel_attn = PreNorm(
369
  norm_layer(dim),
370
  ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
371
- drop_path
 
 
 
372
  )
373
- self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
374
  self.ffn = PreNorm(
375
  norm_layer(dim),
376
- Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
377
- drop_path
 
 
 
 
378
  )
379
 
380
  def forward(self, x, size):
@@ -392,15 +415,19 @@ class ChannelBlock(nn.Module):
392
  def window_partition(x, window_size: int):
393
  B, H, W, C = x.shape
394
  x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
395
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
 
 
396
  return windows
397
 
398
 
399
  def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
400
- B = batch_size
401
  # this will cause onnx conversion failed for dynamic axis, because treated as constant
402
- # int(windows.shape[0] / (H * W / window_size / window_size))
403
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
 
 
404
  x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
405
  return x
406
 
@@ -441,20 +468,22 @@ class WindowAttention(nn.Module):
441
  # attn_windows = self.attn(x_windows)
442
 
443
  B_, N, C = x.shape
444
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
 
 
 
445
  q, k, v = qkv[0], qkv[1], qkv[2]
446
 
447
  q = q * self.scale
448
- attn = (q @ k.transpose(-2, -1))
449
  attn = self.softmax(attn)
450
 
451
  x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
452
  x = self.proj(x)
453
 
454
  # merge windows
455
- x = x.view(
456
- -1, self.window_size, self.window_size, C
457
- )
458
  x = window_reverse(x, B, self.window_size, Hp, Wp)
459
 
460
  if pad_r > 0 or pad_b > 0:
@@ -467,24 +496,42 @@ class WindowAttention(nn.Module):
467
 
468
  class SpatialBlock(nn.Module):
469
 
470
- def __init__(self, dim, num_heads, window_size,
471
- mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
472
- norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
 
 
 
 
 
 
 
 
 
 
473
  super().__init__()
474
 
475
- drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
476
 
477
- self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
 
 
478
  self.window_attn = PreNorm(
479
  norm_layer(dim),
480
  WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
481
- drop_path
 
 
 
482
  )
483
- self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
484
  self.ffn = PreNorm(
485
  norm_layer(dim),
486
- Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
487
- drop_path
 
 
 
 
488
  )
489
 
490
  def forward(self, x, size):
@@ -499,7 +546,7 @@ class SpatialBlock(nn.Module):
499
 
500
 
501
  class DaViT(nn.Module):
502
- """ DaViT: Dual-Attention Transformer
503
 
504
  Args:
505
  in_chans (int): Number of input image channels. Default: 3.
@@ -534,14 +581,14 @@ class DaViT(nn.Module):
534
  num_heads=(3, 6, 12, 24),
535
  num_groups=(3, 6, 12, 24),
536
  window_size=7,
537
- mlp_ratio=4.,
538
  qkv_bias=True,
539
  drop_path_rate=0.1,
540
  norm_layer=nn.LayerNorm,
541
  enable_checkpoint=False,
542
  conv_at_attn=True,
543
  conv_at_ffn=True,
544
- ):
545
  super().__init__()
546
 
547
  self.num_classes = num_classes
@@ -553,7 +600,7 @@ class DaViT(nn.Module):
553
  assert self.num_stages == len(self.num_heads) == len(self.num_groups)
554
 
555
  num_stages = len(embed_dims)
556
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)]
557
 
558
  depth_offset = 0
559
  convs = []
@@ -566,48 +613,59 @@ class DaViT(nn.Module):
566
  in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
567
  embed_dim=self.embed_dims[i],
568
  norm_layer=norm_layer,
569
- pre_norm=patch_prenorm[i]
570
  )
571
  convs.append(conv_embed)
572
 
573
  block = MySequential(
574
  *[
575
- MySequential(OrderedDict([
576
- (
577
- 'spatial_block', SpatialBlock(
578
- embed_dims[i],
579
- num_heads[i],
580
- window_size,
581
- drop_path_rate=dpr[depth_offset+j*2],
582
- qkv_bias=qkv_bias,
583
- mlp_ratio=mlp_ratio,
584
- conv_at_attn=conv_at_attn,
585
- conv_at_ffn=conv_at_ffn,
586
- )
587
- ),
588
- (
589
- 'channel_block', ChannelBlock(
590
- embed_dims[i],
591
- num_groups[i],
592
- drop_path_rate=dpr[depth_offset+j*2+1],
593
- qkv_bias=qkv_bias,
594
- mlp_ratio=mlp_ratio,
595
- conv_at_attn=conv_at_attn,
596
- conv_at_ffn=conv_at_ffn,
597
- )
 
 
 
 
 
 
598
  )
599
- ])) for j in range(depths[i])
 
600
  ]
601
  )
602
  blocks.append(block)
603
- depth_offset += depths[i]*2
604
 
605
  self.convs = nn.ModuleList(convs)
606
  self.blocks = nn.ModuleList(blocks)
607
 
608
  self.norms = norm_layer(self.embed_dims[-1])
609
  self.avgpool = nn.AdaptiveAvgPool1d(1)
610
- self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
 
 
 
 
611
 
612
  self.apply(self._init_weights)
613
 
@@ -623,7 +681,7 @@ class DaViT(nn.Module):
623
  elif isinstance(m, nn.Conv2d):
624
  nn.init.normal_(m.weight, std=0.02)
625
  for name, _ in m.named_parameters():
626
- if name in ['bias']:
627
  nn.init.constant_(m.bias, 0)
628
  elif isinstance(m, nn.LayerNorm):
629
  nn.init.constant_(m.weight, 1.0)
@@ -634,7 +692,7 @@ class DaViT(nn.Module):
634
 
635
  def forward_features_unpool(self, x):
636
  """
637
- forward until avg pooling
638
  Args:
639
  x (_type_): input image tensor
640
  """
@@ -662,7 +720,7 @@ class DaViT(nn.Module):
662
  x = self.forward_features(x)
663
  x = self.head(x)
664
  return x
665
-
666
  @classmethod
667
  def from_config(cls, config):
668
  return cls(
@@ -679,12 +737,11 @@ class DaViT(nn.Module):
679
  )
680
 
681
 
682
-
683
-
684
  if is_flash_attn_2_available():
685
  from flash_attn import flash_attn_func, flash_attn_varlen_func
686
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
687
 
 
688
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
689
  def _get_unpad_data(attention_mask):
690
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -698,7 +755,9 @@ def _get_unpad_data(attention_mask):
698
  )
699
 
700
 
701
- def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
 
 
702
  """
703
  Shift input ids one token to the right.
704
  """
@@ -730,7 +789,10 @@ class Florence2LearnedPositionalEmbedding(nn.Embedding):
730
 
731
  bsz, seq_len = input_ids.shape[:2]
732
  positions = torch.arange(
733
- past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
 
 
 
734
  ).expand(bsz, -1)
735
 
736
  return super().forward(positions + self.offset)
@@ -741,7 +803,13 @@ class Florence2ScaledWordEmbedding(nn.Embedding):
741
  This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
742
  """
743
 
744
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
 
 
 
 
 
 
745
  super().__init__(num_embeddings, embedding_dim, padding_idx)
746
  self.embed_scale = embed_scale
747
 
@@ -784,7 +852,11 @@ class Florence2Attention(nn.Module):
784
  self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
785
 
786
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
787
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
 
 
 
788
 
789
  def forward(
790
  self,
@@ -861,7 +933,10 @@ class Florence2Attention(nn.Module):
861
  raise ValueError(
862
  f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
863
  )
864
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
 
 
 
865
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
866
 
867
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -872,7 +947,9 @@ class Florence2Attention(nn.Module):
872
  f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
873
  f" {layer_head_mask.size()}"
874
  )
875
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
 
 
876
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
877
 
878
  if output_attentions:
@@ -880,12 +957,18 @@ class Florence2Attention(nn.Module):
880
  # make sure that attn_weights keeps its gradient.
881
  # In order to do so, attn_weights have to be reshaped
882
  # twice and have to be reused in the following
883
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
884
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
 
 
 
 
885
  else:
886
  attn_weights_reshaped = None
887
 
888
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
 
889
 
890
  attn_output = torch.bmm(attn_probs, value_states)
891
 
@@ -937,7 +1020,9 @@ class Florence2FlashAttention2(Florence2Attention):
937
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
938
  # Florence2FlashAttention2 attention does not support output_attentions
939
  if output_attentions:
940
- raise ValueError("Florence2FlashAttention2 attention does not support output_attentions")
 
 
941
 
942
  # if key_value_states are provided this layer is used as a cross-attention layer
943
  # for the decoder
@@ -967,8 +1052,12 @@ class Florence2FlashAttention2(Florence2Attention):
967
  # reuse k, v, self_attention
968
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
969
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
970
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
971
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
 
 
 
 
972
  else:
973
  # self_attention
974
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
@@ -1015,7 +1104,12 @@ class Florence2FlashAttention2(Florence2Attention):
1015
  value_states = value_states.to(target_dtype)
1016
 
1017
  attn_output = self._flash_attention_forward(
1018
- query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
 
 
 
 
 
1019
  )
1020
 
1021
  attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -1028,7 +1122,14 @@ class Florence2FlashAttention2(Florence2Attention):
1028
 
1029
  # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
1030
  def _flash_attention_forward(
1031
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
 
 
 
 
 
 
 
1032
  ):
1033
  """
1034
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -1058,7 +1159,14 @@ class Florence2FlashAttention2(Florence2Attention):
1058
  # Contains at least one padding token in the sequence
1059
  if attention_mask is not None:
1060
  batch_size = query_states.shape[0]
1061
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
 
 
 
 
 
 
 
1062
  query_states, key_states, value_states, attention_mask, query_length
1063
  )
1064
 
@@ -1078,28 +1186,40 @@ class Florence2FlashAttention2(Florence2Attention):
1078
  causal=causal,
1079
  )
1080
 
1081
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
 
1082
  else:
1083
  attn_output = flash_attn_func(
1084
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
 
 
 
 
 
1085
  )
1086
 
1087
  return attn_output
1088
 
1089
  # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
1090
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
 
 
1091
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
1092
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1093
 
1094
  key_layer = index_first_axis(
1095
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
1096
  )
1097
  value_layer = index_first_axis(
1098
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
1099
  )
1100
  if query_length == kv_seq_len:
1101
  query_layer = index_first_axis(
1102
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
1103
  )
1104
  cu_seqlens_q = cu_seqlens_k
1105
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -1114,7 +1234,9 @@ class Florence2FlashAttention2(Florence2Attention):
1114
  else:
1115
  # The -q_len: slice assumes left padding.
1116
  attention_mask = attention_mask[:, -query_length:]
1117
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
 
 
1118
 
1119
  return (
1120
  query_layer,
@@ -1202,7 +1324,9 @@ class Florence2SdpaAttention(Florence2Attention):
1202
  # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
1203
  # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
1204
  # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
1205
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
 
 
1206
 
1207
  # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
1208
  # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
@@ -1283,15 +1407,21 @@ class Florence2EncoderLayer(nn.Module):
1283
  layer_head_mask=layer_head_mask,
1284
  output_attentions=output_attentions,
1285
  )
1286
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
 
1287
  hidden_states = residual + hidden_states
1288
  hidden_states = self.self_attn_layer_norm(hidden_states)
1289
 
1290
  residual = hidden_states
1291
  hidden_states = self.activation_fn(self.fc1(hidden_states))
1292
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 
 
1293
  hidden_states = self.fc2(hidden_states)
1294
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
 
1295
  hidden_states = residual + hidden_states
1296
  hidden_states = self.final_layer_norm(hidden_states)
1297
 
@@ -1299,7 +1429,9 @@ class Florence2EncoderLayer(nn.Module):
1299
  torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
1300
  ):
1301
  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
1302
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
 
 
1303
 
1304
  outputs = (hidden_states,)
1305
 
@@ -1350,7 +1482,9 @@ class Florence2DecoderLayer(nn.Module):
1350
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
1351
  output_attentions: Optional[bool] = False,
1352
  use_cache: Optional[bool] = True,
1353
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
1354
  """
1355
  Args:
1356
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -1373,7 +1507,9 @@ class Florence2DecoderLayer(nn.Module):
1373
 
1374
  # Self Attention
1375
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
1376
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
 
 
1377
  # add present self-attn cache to positions 1,2 of present_key_value tuple
1378
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
1379
  hidden_states=hidden_states,
@@ -1382,7 +1518,9 @@ class Florence2DecoderLayer(nn.Module):
1382
  layer_head_mask=layer_head_mask,
1383
  output_attentions=output_attentions,
1384
  )
1385
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
 
1386
  hidden_states = residual + hidden_states
1387
  hidden_states = self.self_attn_layer_norm(hidden_states)
1388
 
@@ -1393,16 +1531,22 @@ class Florence2DecoderLayer(nn.Module):
1393
  residual = hidden_states
1394
 
1395
  # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
1396
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
1397
- hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
1398
- hidden_states=hidden_states,
1399
- key_value_states=encoder_hidden_states,
1400
- attention_mask=encoder_attention_mask,
1401
- layer_head_mask=cross_attn_layer_head_mask,
1402
- past_key_value=cross_attn_past_key_value,
1403
- output_attentions=output_attentions,
 
 
 
 
 
 
 
1404
  )
1405
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1406
  hidden_states = residual + hidden_states
1407
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
1408
 
@@ -1412,9 +1556,13 @@ class Florence2DecoderLayer(nn.Module):
1412
  # Fully Connected
1413
  residual = hidden_states
1414
  hidden_states = self.activation_fn(self.fc1(hidden_states))
1415
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 
 
1416
  hidden_states = self.fc2(hidden_states)
1417
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
 
1418
  hidden_states = residual + hidden_states
1419
  hidden_states = self.final_layer_norm(hidden_states)
1420
 
@@ -1429,7 +1577,6 @@ class Florence2DecoderLayer(nn.Module):
1429
  return outputs
1430
 
1431
 
1432
-
1433
  class Florence2LanguagePreTrainedModel(PreTrainedModel):
1434
  config_class = Florence2LanguageConfig
1435
  base_model_prefix = "model"
@@ -1454,7 +1601,9 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
1454
  @property
1455
  def dummy_inputs(self):
1456
  pad_token = self.config.pad_token_id
1457
- input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
 
 
1458
  dummy_inputs = {
1459
  "attention_mask": input_ids.ne(pad_token),
1460
  "input_ids": input_ids,
@@ -1472,7 +1621,11 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1472
  embed_tokens (nn.Embedding): output embedding
1473
  """
1474
 
1475
- def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None):
 
 
 
 
1476
  super().__init__(config)
1477
 
1478
  self.dropout = config.dropout
@@ -1494,7 +1647,9 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1494
  config.max_position_embeddings,
1495
  embed_dim,
1496
  )
1497
- self.layers = nn.ModuleList([Florence2EncoderLayer(config) for _ in range(config.encoder_layers)])
 
 
1498
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1499
  self._use_sdpa = config._attn_implementation == "sdpa"
1500
  self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -1555,15 +1710,25 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1555
  return_dict (`bool`, *optional*):
1556
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1557
  """
1558
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1559
  output_hidden_states = (
1560
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
1561
  )
1562
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1563
 
1564
  # retrieve input_ids and inputs_embeds
1565
  if input_ids is not None and inputs_embeds is not None:
1566
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
 
1567
  elif input_ids is not None:
1568
  input = input_ids
1569
  input_ids = input_ids.view(-1, input_ids.shape[-1])
@@ -1580,7 +1745,9 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1580
 
1581
  hidden_states = inputs_embeds + embed_pos
1582
  hidden_states = self.layernorm_embedding(hidden_states)
1583
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
 
1584
 
1585
  # expand attention_mask
1586
  if attention_mask is not None:
@@ -1590,10 +1757,14 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1590
  # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1591
  # the manual implementation that requires a 4D causal mask in all cases.
1592
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1593
- attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
 
 
1594
  else:
1595
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1596
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
 
 
1597
 
1598
  encoder_states = () if output_hidden_states else None
1599
  all_attentions = () if output_attentions else None
@@ -1631,7 +1802,9 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1631
  layer_outputs = encoder_layer(
1632
  hidden_states,
1633
  attention_mask,
1634
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
 
 
1635
  output_attentions=output_attentions,
1636
  )
1637
 
@@ -1644,9 +1817,15 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1644
  encoder_states = encoder_states + (hidden_states,)
1645
 
1646
  if not return_dict:
1647
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
 
 
 
 
1648
  return BaseModelOutput(
1649
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
 
 
1650
  )
1651
 
1652
 
@@ -1659,7 +1838,11 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1659
  embed_tokens (nn.Embedding): output embedding
1660
  """
1661
 
1662
- def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None):
 
 
 
 
1663
  super().__init__(config)
1664
  self.dropout = config.dropout
1665
  self.layerdrop = config.decoder_layerdrop
@@ -1678,7 +1861,9 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1678
  config.max_position_embeddings,
1679
  config.d_model,
1680
  )
1681
- self.layers = nn.ModuleList([Florence2DecoderLayer(config) for _ in range(config.decoder_layers)])
 
 
1682
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1683
  self._use_sdpa = config._attn_implementation == "sdpa"
1684
 
@@ -1774,16 +1959,26 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1774
  return_dict (`bool`, *optional*):
1775
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1776
  """
1777
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1778
  output_hidden_states = (
1779
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
1780
  )
1781
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1782
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1783
 
1784
  # retrieve input_ids and inputs_embeds
1785
  if input_ids is not None and inputs_embeds is not None:
1786
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
 
 
1787
  elif input_ids is not None:
1788
  input = input_ids
1789
  input_shape = input.shape
@@ -1792,17 +1987,25 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1792
  input_shape = inputs_embeds.size()[:-1]
1793
  input = inputs_embeds[:, :, -1]
1794
  else:
1795
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
 
1796
 
1797
  # past_key_values_length
1798
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
 
1799
 
1800
  if inputs_embeds is None:
1801
  inputs_embeds = self.embed_tokens(input)
1802
 
1803
  if self._use_flash_attention_2:
1804
  # 2d mask is passed through the layers
1805
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
 
 
 
 
1806
  elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1807
  # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1808
  # the manual implementation that requires a 4D causal mask in all cases.
@@ -1821,8 +2024,14 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1821
  # expand encoder attention mask
1822
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
1823
  if self._use_flash_attention_2:
1824
- encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
1825
- elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
 
 
 
 
 
 
1826
  # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1827
  # the manual implementation that requires a 4D causal mask in all cases.
1828
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1844,7 +2053,9 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1844
  hidden_states = inputs_embeds + positions
1845
  hidden_states = self.layernorm_embedding(hidden_states)
1846
 
1847
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
 
1848
 
1849
  if self.gradient_checkpointing and self.training:
1850
  if use_cache:
@@ -1856,11 +2067,15 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1856
  # decoder layers
1857
  all_hidden_states = () if output_hidden_states else None
1858
  all_self_attns = () if output_attentions else None
1859
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
 
 
1860
  next_decoder_cache = () if use_cache else None
1861
 
1862
  # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1863
- for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
 
 
1864
  if attn_mask is not None:
1865
  if attn_mask.size()[0] != (len(self.layers)):
1866
  raise ValueError(
@@ -1877,7 +2092,9 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1877
  if dropout_probability < self.layerdrop:
1878
  continue
1879
 
1880
- past_key_value = past_key_values[idx] if past_key_values is not None else None
 
 
1881
 
1882
  if self.gradient_checkpointing and self.training:
1883
  layer_outputs = self._gradient_checkpointing_func(
@@ -1887,7 +2104,11 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1887
  encoder_hidden_states,
1888
  encoder_attention_mask,
1889
  head_mask[idx] if head_mask is not None else None,
1890
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
 
 
 
 
1891
  None,
1892
  output_attentions,
1893
  use_cache,
@@ -1900,7 +2121,9 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1900
  encoder_attention_mask=encoder_attention_mask,
1901
  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1902
  cross_attn_layer_head_mask=(
1903
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
 
 
1904
  ),
1905
  past_key_value=past_key_value,
1906
  output_attentions=output_attentions,
@@ -1925,7 +2148,13 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1925
  if not return_dict:
1926
  return tuple(
1927
  v
1928
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
 
 
 
 
 
 
1929
  if v is not None
1930
  )
1931
  return BaseModelOutputWithPastAndCrossAttentions(
@@ -2003,12 +2232,20 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
2003
  input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
2004
  )
2005
 
2006
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
2007
  output_hidden_states = (
2008
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
2009
  )
2010
  use_cache = use_cache if use_cache is not None else self.config.use_cache
2011
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
2012
 
2013
  if encoder_outputs is None:
2014
  encoder_outputs = self.encoder(
@@ -2061,14 +2298,22 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
2061
 
2062
  class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel):
2063
  base_model_prefix = "model"
2064
- _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
 
 
 
2065
  _keys_to_ignore_on_load_missing = ["final_logits_bias"]
2066
 
2067
  def __init__(self, config: Florence2LanguageConfig):
2068
  super().__init__(config)
2069
  self.model = Florence2LanguageModel(config)
2070
- self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
2071
- self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
 
 
 
 
2072
 
2073
  # Initialize weights and apply final processing
2074
  self.post_init()
@@ -2079,8 +2324,12 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2079
  def get_decoder(self):
2080
  return self.model.get_decoder()
2081
 
2082
- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
2083
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
 
 
 
 
2084
  self._resize_final_logits_bias(new_embeddings.weight.shape[0])
2085
  return new_embeddings
2086
 
@@ -2089,7 +2338,10 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2089
  if new_num_tokens <= old_num_tokens:
2090
  new_bias = self.final_logits_bias[:, :new_num_tokens]
2091
  else:
2092
- extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
 
 
 
2093
  new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
2094
  self.register_buffer("final_logits_bias", new_bias)
2095
 
@@ -2126,11 +2378,15 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2126
 
2127
  Returns:
2128
  """
2129
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
2130
 
2131
  if labels is not None:
2132
  if use_cache:
2133
- logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
 
 
2134
  use_cache = False
2135
  if decoder_input_ids is None and decoder_inputs_embeds is None:
2136
  decoder_input_ids = shift_tokens_right(
@@ -2162,11 +2418,15 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2162
  if labels is not None:
2163
  labels = labels.to(lm_logits.device)
2164
  loss_fct = CrossEntropyLoss()
2165
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
 
 
2166
 
2167
  if not return_dict:
2168
  output = (lm_logits,) + outputs[1:]
2169
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
 
 
2170
 
2171
  return Seq2SeqLMOutput(
2172
  loss=masked_lm_loss,
@@ -2220,7 +2480,9 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2220
  }
2221
 
2222
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
2223
- return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
 
2224
 
2225
  @staticmethod
2226
  def _reorder_cache(past_key_values, beam_idx):
@@ -2228,11 +2490,15 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2228
  for layer_past in past_key_values:
2229
  # cached cross_attention states don't have to be reordered -> they are always the same
2230
  reordered_past += (
2231
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
 
 
 
2232
  + layer_past[2:],
2233
  )
2234
  return reordered_past
2235
 
 
2236
  @dataclass
2237
  class Florence2Seq2SeqLMOutput(ModelOutput):
2238
  """
@@ -2289,6 +2555,7 @@ class Florence2Seq2SeqLMOutput(ModelOutput):
2289
  image_hidden_states of the model produced by the vision encoder
2290
  """
2291
 
 
2292
  last_hidden_state: torch.FloatTensor = None
2293
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
2294
  decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -2408,6 +2675,7 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
2408
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
2409
  """
2410
 
 
2411
  @add_start_docstrings(
2412
  """The FLORENCE2 vision model without any head""",
2413
  FLORENCE2_START_DOCSTRING,
@@ -2415,16 +2683,16 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
2415
  class Florence2VisionModel(Florence2PreTrainedModel):
2416
  def __init__(self, config: Florence2VisionConfig):
2417
  super().__init__(config)
2418
- assert config.model_type == 'davit', 'only DaViT is supported for now'
2419
  self.vision_tower = DaViT.from_config(config=config)
2420
 
2421
  self.post_init()
2422
-
2423
  def forward(self, pixel_values):
2424
  if len(pixel_values.shape) == 4:
2425
  x = self.vision_tower.forward_features_unpool(pixel_values)
2426
  else:
2427
- raise ValueError(f'invalid image shape {pixel_values.shape}')
2428
  return x
2429
 
2430
 
@@ -2435,40 +2703,38 @@ class Florence2VisionModel(Florence2PreTrainedModel):
2435
  class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2436
  def __init__(self, config: Florence2VisionConfig):
2437
  super().__init__(config)
2438
- assert config.model_type == 'davit', 'only DaViT is supported for now'
2439
  self.vision_tower = DaViT.from_config(config=config)
2440
 
2441
  self._build_image_projection_layers(config)
2442
 
2443
  self.post_init()
2444
-
2445
  def _build_image_projection_layers(self, config):
2446
  image_dim_out = config.dim_embed[-1]
2447
  dim_projection = config.projection_dim
2448
- self.image_projection = nn.Parameter(
2449
- torch.empty(image_dim_out, dim_projection)
2450
- )
2451
  self.image_proj_norm = nn.LayerNorm(dim_projection)
2452
  image_pos_embed_config = config.image_pos_embed
2453
- if image_pos_embed_config['type'] == 'learned_abs_2d':
2454
  self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
2455
  embedding_dim=image_dim_out,
2456
- num_pos=image_pos_embed_config['max_pos_embeddings']
2457
  )
2458
  else:
2459
- raise NotImplementedError('Not implemented yet')
2460
 
2461
  self.image_feature_source = config.image_feature_source
2462
 
2463
  # temporal embedding
2464
  visual_temporal_embedding_config = config.visual_temporal_embedding
2465
- if visual_temporal_embedding_config['type'] == 'COSINE':
2466
  self.visual_temporal_embed = PositionalEmbeddingCosine1D(
2467
  embed_dim=image_dim_out,
2468
- max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
2469
  )
2470
  else:
2471
- raise NotImplementedError('Not implemented yet')
2472
 
2473
  def forward(self, pixel_values):
2474
  if len(pixel_values.shape) == 4:
@@ -2476,37 +2742,43 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2476
  T = 1
2477
  x = self.vision_tower.forward_features_unpool(pixel_values)
2478
  else:
2479
- raise ValueError(f'invalid image shape {pixel_values.shape}')
2480
-
2481
  if self.image_pos_embed is not None:
2482
  x = x.view(batch_size * T, -1, x.shape[-1])
2483
  num_tokens = x.shape[-2]
2484
- h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
2485
- assert h * w == num_tokens, 'only support square feature maps for now'
2486
  x = x.view(batch_size * T, h, w, x.shape[-1])
2487
  pos_embed = self.image_pos_embed(x)
2488
  x = x + pos_embed
2489
- x = x.view(batch_size, T * h*w, x.shape[-1])
2490
 
2491
  if self.visual_temporal_embed is not None:
2492
- visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
2493
- x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
 
 
 
 
2494
 
2495
  x_feat_dict = {}
2496
 
2497
  spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
2498
- x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
2499
 
2500
  temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
2501
- x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
2502
 
2503
  x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
2504
- x_feat_dict['last_frame'] = x
2505
 
2506
  new_x = []
2507
  for _image_feature_source in self.image_feature_source:
2508
  if _image_feature_source not in x_feat_dict:
2509
- raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
 
 
2510
  new_x.append(x_feat_dict[_image_feature_source])
2511
 
2512
  x = torch.cat(new_x, dim=1)
@@ -2514,11 +2786,9 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2514
  x = x @ self.image_projection
2515
  x = self.image_proj_norm(x)
2516
 
2517
-
2518
  return x
2519
 
2520
 
2521
-
2522
  @add_start_docstrings(
2523
  """The FLORENCE2 model which consists of a vision backbone and a language model.""",
2524
  FLORENCE2_START_DOCSTRING,
@@ -2526,10 +2796,12 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2526
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2527
  def __init__(self, config: Florence2Config):
2528
  super().__init__(config)
2529
- assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
 
 
2530
  del config.vision_config.model_type
2531
  self.vision_tower = DaViT.from_config(config=config.vision_config)
2532
- # remove unused layers
2533
  del self.vision_tower.head
2534
  del self.vision_tower.norms
2535
 
@@ -2537,42 +2809,48 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2537
  self._attn_implementation = config._attn_implementation
2538
  self._build_image_projection_layers(config)
2539
 
2540
- language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
 
 
2541
 
2542
  if language_model._tied_weights_keys is not None:
2543
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
 
 
2544
  self.language_model = language_model
2545
 
2546
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
 
 
2547
  self.post_init()
2548
-
2549
  def _build_image_projection_layers(self, config):
2550
  image_dim_out = config.vision_config.dim_embed[-1]
2551
  dim_projection = config.vision_config.projection_dim
2552
- self.image_projection = nn.Parameter(
2553
- torch.empty(image_dim_out, dim_projection)
2554
- )
2555
  self.image_proj_norm = nn.LayerNorm(dim_projection)
2556
  image_pos_embed_config = config.vision_config.image_pos_embed
2557
- if image_pos_embed_config['type'] == 'learned_abs_2d':
2558
  self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
2559
  embedding_dim=image_dim_out,
2560
- num_pos=image_pos_embed_config['max_pos_embeddings']
2561
  )
2562
  else:
2563
- raise NotImplementedError('Not implemented yet')
2564
 
2565
  self.image_feature_source = config.vision_config.image_feature_source
2566
 
2567
  # temporal embedding
2568
- visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
2569
- if visual_temporal_embedding_config['type'] == 'COSINE':
 
 
2570
  self.visual_temporal_embed = PositionalEmbeddingCosine1D(
2571
  embed_dim=image_dim_out,
2572
- max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
2573
  )
2574
  else:
2575
- raise NotImplementedError('Not implemented yet')
2576
 
2577
  def get_encoder(self):
2578
  return self.language_model.get_encoder()
@@ -2583,51 +2861,61 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2583
  def get_input_embeddings(self):
2584
  return self.language_model.get_input_embeddings()
2585
 
2586
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
2587
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
 
 
 
 
2588
  # update vocab size
2589
  self.config.text_config.vocab_size = model_embeds.num_embeddings
2590
  self.config.vocab_size = model_embeds.num_embeddings
2591
  self.vocab_size = model_embeds.num_embeddings
2592
  return model_embeds
2593
-
2594
  def _encode_image(self, pixel_values):
2595
  if len(pixel_values.shape) == 4:
2596
  batch_size, C, H, W = pixel_values.shape
2597
  T = 1
2598
  x = self.vision_tower.forward_features_unpool(pixel_values)
2599
  else:
2600
- raise ValueError(f'invalid image shape {pixel_values.shape}')
2601
-
2602
  if self.image_pos_embed is not None:
2603
  x = x.view(batch_size * T, -1, x.shape[-1])
2604
  num_tokens = x.shape[-2]
2605
- h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
2606
- assert h * w == num_tokens, 'only support square feature maps for now'
2607
  x = x.view(batch_size * T, h, w, x.shape[-1])
2608
  pos_embed = self.image_pos_embed(x)
2609
  x = x + pos_embed
2610
- x = x.view(batch_size, T * h*w, x.shape[-1])
2611
 
2612
  if self.visual_temporal_embed is not None:
2613
- visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
2614
- x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
 
 
 
 
2615
 
2616
  x_feat_dict = {}
2617
 
2618
  spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
2619
- x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
2620
 
2621
  temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
2622
- x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
2623
 
2624
  x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
2625
- x_feat_dict['last_frame'] = x
2626
 
2627
  new_x = []
2628
  for _image_feature_source in self.image_feature_source:
2629
  if _image_feature_source not in x_feat_dict:
2630
- raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
 
 
2631
  new_x.append(x_feat_dict[_image_feature_source])
2632
 
2633
  x = torch.cat(new_x, dim=1)
@@ -2635,11 +2923,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2635
  x = x @ self.image_projection
2636
  x = self.image_proj_norm(x)
2637
 
2638
- return x
2639
 
2640
- def _merge_input_ids_with_image_features(
2641
- self, image_features, inputs_embeds
2642
- ):
2643
  batch_size, image_token_length = image_features.size()[:-1]
2644
  device = image_features.device
2645
  image_attention_mask = torch.ones(batch_size, image_token_length, device=device)
@@ -2650,20 +2936,25 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2650
  return image_features, image_attention_mask
2651
 
2652
  task_prefix_embeds = inputs_embeds
2653
- task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
 
2654
 
2655
  if len(task_prefix_attention_mask.shape) == 3:
2656
  task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
2657
 
2658
  # concat [image embeds, task prefix embeds]
2659
  inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
2660
- attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
 
 
2661
 
2662
  return inputs_embeds, attention_mask
2663
 
2664
-
2665
  @add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING)
2666
- @replace_return_docstrings(output_type=Florence2Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
 
 
2667
  def forward(
2668
  self,
2669
  input_ids: torch.LongTensor = None,
@@ -2714,11 +3005,19 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2714
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
2715
  "A green car parked in front of a yellow building."
2716
  ```"""
2717
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
2718
  output_hidden_states = (
2719
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
2720
  )
2721
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2722
 
2723
  image_features = None
2724
  if inputs_embeds is None:
@@ -2729,7 +3028,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2729
  if pixel_values is not None:
2730
  # (batch_size, num_image_tokens, hidden_size)
2731
  image_features = self._encode_image(pixel_values)
2732
- inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
 
 
 
 
2733
 
2734
  attention_mask = attention_mask.to(inputs_embeds.dtype)
2735
  outputs = self.language_model(
@@ -2757,6 +3060,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2757
  output = (logits,) + outputs[1:]
2758
  return (loss,) + output if loss is not None else output
2759
 
 
 
2760
  return Florence2Seq2SeqLMOutput(
2761
  loss=loss,
2762
  logits=logits,
@@ -2767,16 +3072,10 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2767
  encoder_last_hidden_state=outputs.encoder_last_hidden_state,
2768
  encoder_hidden_states=outputs.encoder_hidden_states,
2769
  encoder_attentions=outputs.encoder_attentions,
2770
- image_hidden_states=image_features
2771
  )
2772
 
2773
- def generate(
2774
- self,
2775
- input_ids,
2776
- inputs_embeds=None,
2777
- pixel_values=None,
2778
- **kwargs
2779
- ):
2780
 
2781
  if inputs_embeds is None:
2782
  # 1. Extra the input embeddings
@@ -2785,12 +3084,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2785
  # 2. Merge text and images
2786
  if pixel_values is not None:
2787
  image_features = self._encode_image(pixel_values)
2788
- inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
2789
-
 
 
 
 
2790
  return self.language_model.generate(
2791
- input_ids=None,
2792
- inputs_embeds=inputs_embeds,
2793
- **kwargs
2794
  )
2795
 
2796
  def prepare_inputs_for_generation(
@@ -2819,7 +3120,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2819
  remove_prefix_length = decoder_input_ids.shape[1] - 1
2820
 
2821
  decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
2822
-
2823
  return {
2824
  "input_ids": None, # encoder_outputs is defined. input_ids not needed
2825
  "encoder_outputs": encoder_outputs,
@@ -2833,9 +3134,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2833
  "cross_attn_head_mask": cross_attn_head_mask,
2834
  "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
2835
  }
2836
-
2837
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
2838
  return self.language_model.shift_tokens_right(labels)
2839
 
2840
  def _reorder_cache(self, *args, **kwargs):
2841
- return self.language_model._reorder_cache(*args, **kwargs)
 
23
  from torch import nn
24
  import torch.nn.functional as F
25
  import torch.utils.checkpoint as checkpoint
26
+ from torch.nn import CrossEntropyLoss
27
  from collections import OrderedDict
28
  from einops import rearrange
29
  from timm.models.layers import DropPath, trunc_normal_
 
39
  is_flash_attn_2_available,
40
  is_flash_attn_greater_or_equal_2_10,
41
  )
42
+ from .configuration_florence2 import Florence2Config
43
  from .configuration_florence2 import Florence2LanguageConfig
44
  from .configuration_florence2 import Florence2VisionConfig
45
 
 
66
 
67
  _CONFIG_FOR_DOC = "Florence2Config"
68
 
69
+
70
  class LearnedAbsolutePositionEmbedding2D(nn.Module):
71
  """
72
  This module learns positional embeddings up to a fixed maximum size.
 
75
  def __init__(self, embedding_dim=256, num_pos=50):
76
  super().__init__()
77
  self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
78
+ self.column_embeddings = nn.Embedding(
79
+ num_pos, embedding_dim - (embedding_dim // 2)
80
+ )
81
 
82
  def forward(self, pixel_values):
83
  """
84
+ pixel_values: (batch_size, height, width, num_channels)
85
  returns: (batch_size, height, width, embedding_dim * 2)
86
  """
87
  if len(pixel_values.shape) != 4:
88
+ raise ValueError("pixel_values must be a 4D tensor")
89
  height, width = pixel_values.shape[1:3]
90
  width_values = torch.arange(width, device=pixel_values.device)
91
  height_values = torch.arange(height, device=pixel_values.device)
92
  x_emb = self.column_embeddings(width_values)
93
  y_emb = self.row_embeddings(height_values)
94
  # (height, width, embedding_dim * 2)
95
+ pos = torch.cat(
96
+ [
97
+ x_emb.unsqueeze(0).repeat(height, 1, 1),
98
+ y_emb.unsqueeze(1).repeat(1, width, 1),
99
+ ],
100
+ dim=-1,
101
+ )
102
  # (embedding_dim, height, width)
103
  pos = pos.permute(2, 0, 1)
104
  pos = pos.unsqueeze(0)
 
108
  pos = pos.permute(0, 2, 3, 1)
109
  return pos
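
A minimal standalone sketch of the row/column split above (illustrative sizes, not tied to any checkpoint): each axis gets half of embedding_dim, and the concatenation restores the full channel dimension.

    import torch
    from torch import nn

    embedding_dim, num_pos = 256, 50
    row = nn.Embedding(num_pos, embedding_dim // 2)
    col = nn.Embedding(num_pos, embedding_dim - embedding_dim // 2)

    height, width = 12, 12
    y_emb = row(torch.arange(height))              # (12, 128)
    x_emb = col(torch.arange(width))               # (12, 128)
    pos = torch.cat(
        [x_emb.unsqueeze(0).repeat(height, 1, 1),  # (12, 12, 128)
         y_emb.unsqueeze(1).repeat(1, width, 1)],  # (12, 12, 128)
        dim=-1,
    )
    print(pos.shape)                               # torch.Size([12, 12, 256])
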
110
 
111
+
112
  class PositionalEmbeddingCosine1D(nn.Module):
113
  """
114
  This class implements a very simple positional encoding. It follows closely
 
120
  dropout_prob: The dropout probability.
121
  max_seq_len: The maximum length to precompute the positional encodings.
122
  """
123
+
124
+ def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
 
 
125
  super(PositionalEmbeddingCosine1D, self).__init__()
126
  self.embed_dim = embed_dim
127
  self.max_seq_len = max_seq_len
128
  # Generate the sinusoidal arrays.
129
  factor = math.log(10000)
130
  denominator = torch.exp(
131
+ -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim
132
+ )
133
  # Matrix where rows correspond to a positional embedding as a function
134
  # of the position index (i.e., the row index).
135
+ frequencies = (
136
+ torch.arange(0, self.max_seq_len).reshape(self.max_seq_len, 1) * denominator
137
+ )
138
  pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
139
  # Populate uneven entries.
140
  pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
 
158
  assert 2 <= shape_len <= 3
159
  len_seq = seq_embeds.size(-2)
160
  assert len_seq <= self.max_seq_len
161
+ pos_embeds = self.pos_idx_to_embed[0 : seq_embeds.size(-2), :]
162
  # Adapt pre-computed positional embeddings to the input.
163
  if shape_len == 3:
164
+ pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
 
165
  return pos_embeds
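
For reference, a self-contained sketch of the sin/cos table built in __init__ above, with toy sizes (the cos half of the table is filled the same way in the lines not shown in this hunk):

    import math
    import torch

    embed_dim, max_seq_len = 8, 5
    factor = math.log(10000)
    denominator = torch.exp(-factor * torch.arange(0, embed_dim, 2) / embed_dim)
    frequencies = torch.arange(0, max_seq_len).reshape(max_seq_len, 1) * denominator
    table = torch.zeros((max_seq_len, embed_dim))
    table[:, 0::2] = torch.sin(frequencies)   # even columns
    table[:, 1::2] = torch.cos(frequencies)   # odd columns
    print(table.shape)                        # torch.Size([5, 8])
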
166
 
167
 
 
173
  embed_dim: The dimension of the embeddings.
174
  max_seq_len: The maximum length to precompute the positional encodings.
175
  """
176
+
177
+ def __init__(self, embedding_dim: int = 512, num_pos: int = 1024) -> None:
 
 
178
  super(LearnedAbsolutePositionEmbedding1D, self).__init__()
179
  self.embeddings = nn.Embedding(num_pos, embedding_dim)
180
  self.num_pos = num_pos
 
199
  pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device))
200
  # Adapt pre-computed positional embeddings to the input.
201
  if shape_len == 3:
202
+ pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
 
203
  return pos_embeds
204
 
205
 
 
206
  class MySequential(nn.Sequential):
207
  def forward(self, *inputs):
208
  for module in self._modules.values():
 
246
  super().__init__()
247
  out_features = out_features or in_features
248
  hidden_features = hidden_features or in_features
249
+ self.net = nn.Sequential(
250
+ OrderedDict(
251
+ [
252
+ ("fc1", nn.Linear(in_features, hidden_features)),
253
+ ("act", act_layer()),
254
+ ("fc2", nn.Linear(hidden_features, out_features)),
255
+ ]
256
+ )
257
+ )
258
 
259
  def forward(self, x, size):
260
  return self.net(x), size
 
271
  ):
272
  super().__init__()
273
  self.dw = nn.Conv2d(
274
+ dim_in,
275
+ dim_in,
276
  kernel_size=kernel_size,
277
  padding=padding,
278
  groups=dim_in,
279
  stride=stride,
280
+ bias=bias,
281
  )
282
 
283
  def forward(self, x, size):
 
292
 
293
 
294
  class ConvEmbed(nn.Module):
295
+ """Image to Patch Embedding"""
 
296
 
297
  def __init__(
298
  self,
 
302
  stride=4,
303
  padding=2,
304
  norm_layer=None,
305
+ pre_norm=True,
306
  ):
307
  super().__init__()
308
  self.patch_size = patch_size
309
 
310
  self.proj = nn.Conv2d(
311
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding
 
 
 
312
  )
313
 
314
  dim_norm = in_chans if pre_norm else embed_dim
 
321
  if len(x.size()) == 3:
322
  if self.norm and self.pre_norm:
323
  x = self.norm(x)
324
+ x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
 
 
 
325
 
326
  x = self.proj(x)
327
 
328
  _, _, H, W = x.shape
329
+ x = rearrange(x, "b c h w -> b (h w) c")
330
  if self.norm and not self.pre_norm:
331
  x = self.norm(x)
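
Shape-wise, the strided projection above turns an image into a token sequence. A rough sketch with illustrative channel sizes (the real patch_size, stride, and embed_dim come from the vision config):

    import torch
    from torch import nn
    from einops import rearrange

    proj = nn.Conv2d(3, 128, kernel_size=7, stride=4, padding=2)   # illustrative dims
    image = torch.randn(1, 3, 224, 224)
    feat = proj(image)                                             # (1, 128, 56, 56)
    tokens = rearrange(feat, "b c h w -> b (h w) c")
    print(tokens.shape)                                            # torch.Size([1, 3136, 128])
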
332
 
 
345
  def forward(self, x, size):
346
  B, N, C = x.shape
347
 
348
+ qkv = (
349
+ self.qkv(x)
350
+ .reshape(B, N, 3, self.groups, C // self.groups)
351
+ .permute(2, 0, 3, 1, 4)
352
+ )
353
  q, k, v = qkv[0], qkv[1], qkv[2]
354
 
355
  q = q * (float(N) ** -0.5)
 
363
 
364
  class ChannelBlock(nn.Module):
365
 
366
+ def __init__(
367
+ self,
368
+ dim,
369
+ groups,
370
+ mlp_ratio=4.0,
371
+ qkv_bias=True,
372
+ drop_path_rate=0.0,
373
+ act_layer=nn.GELU,
374
+ norm_layer=nn.LayerNorm,
375
+ conv_at_attn=True,
376
+ conv_at_ffn=True,
377
+ ):
378
  super().__init__()
379
 
380
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
381
 
382
+ self.conv1 = (
383
+ PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
384
+ )
385
  self.channel_attn = PreNorm(
386
  norm_layer(dim),
387
  ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
388
+ drop_path,
389
+ )
390
+ self.conv2 = (
391
+ PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
392
  )
 
393
  self.ffn = PreNorm(
394
  norm_layer(dim),
395
+ Mlp(
396
+ in_features=dim,
397
+ hidden_features=int(dim * mlp_ratio),
398
+ act_layer=act_layer,
399
+ ),
400
+ drop_path,
401
  )
402
 
403
  def forward(self, x, size):
 
415
  def window_partition(x, window_size: int):
416
  B, H, W, C = x.shape
417
  x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
418
+ windows = (
419
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
420
+ )
421
  return windows
422
 
423
 
424
  def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
425
+ B = batch_size
426
  # this will cause onnx conversion failed for dynamic axis, because treated as constant
427
+ # int(windows.shape[0] / (H * W / window_size / window_size))
428
+ x = windows.view(
429
+ B, H // window_size, W // window_size, window_size, window_size, -1
430
+ )
431
  x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
432
  return x
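
The two helpers above are inverses of each other; a quick round-trip check, re-stating the two reshapes so the snippet runs on its own:

    import torch

    def window_partition(x, window_size):
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

    def window_reverse(windows, batch_size, window_size, H, W):
        x = windows.view(batch_size, H // window_size, W // window_size, window_size, window_size, -1)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, H, W, -1)

    x = torch.randn(2, 14, 14, 32)
    wins = window_partition(x, 7)          # (8, 7, 7, 32): 2 batches * 2 * 2 windows
    y = window_reverse(wins, 2, 7, 14, 14)
    print(torch.equal(x, y))               # True
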
433
 
 
468
  # attn_windows = self.attn(x_windows)
469
 
470
  B_, N, C = x.shape
471
+ qkv = (
472
+ self.qkv(x)
473
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
474
+ .permute(2, 0, 3, 1, 4)
475
+ )
476
  q, k, v = qkv[0], qkv[1], qkv[2]
477
 
478
  q = q * self.scale
479
+ attn = q @ k.transpose(-2, -1)
480
  attn = self.softmax(attn)
481
 
482
  x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
483
  x = self.proj(x)
484
 
485
  # merge windows
486
+ x = x.view(-1, self.window_size, self.window_size, C)
 
 
487
  x = window_reverse(x, B, self.window_size, Hp, Wp)
488
 
489
  if pad_r > 0 or pad_b > 0:
 
496
 
497
  class SpatialBlock(nn.Module):
498
 
499
+ def __init__(
500
+ self,
501
+ dim,
502
+ num_heads,
503
+ window_size,
504
+ mlp_ratio=4.0,
505
+ qkv_bias=True,
506
+ drop_path_rate=0.0,
507
+ act_layer=nn.GELU,
508
+ norm_layer=nn.LayerNorm,
509
+ conv_at_attn=True,
510
+ conv_at_ffn=True,
511
+ ):
512
  super().__init__()
513
 
514
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
515
 
516
+ self.conv1 = (
517
+ PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
518
+ )
519
  self.window_attn = PreNorm(
520
  norm_layer(dim),
521
  WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
522
+ drop_path,
523
+ )
524
+ self.conv2 = (
525
+ PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
526
  )
 
527
  self.ffn = PreNorm(
528
  norm_layer(dim),
529
+ Mlp(
530
+ in_features=dim,
531
+ hidden_features=int(dim * mlp_ratio),
532
+ act_layer=act_layer,
533
+ ),
534
+ drop_path,
535
  )
536
 
537
  def forward(self, x, size):
 
546
 
547
 
548
  class DaViT(nn.Module):
549
+ """DaViT: Dual-Attention Transformer
550
 
551
  Args:
552
  in_chans (int): Number of input image channels. Default: 3.
 
581
  num_heads=(3, 6, 12, 24),
582
  num_groups=(3, 6, 12, 24),
583
  window_size=7,
584
+ mlp_ratio=4.0,
585
  qkv_bias=True,
586
  drop_path_rate=0.1,
587
  norm_layer=nn.LayerNorm,
588
  enable_checkpoint=False,
589
  conv_at_attn=True,
590
  conv_at_ffn=True,
591
+ ):
592
  super().__init__()
593
 
594
  self.num_classes = num_classes
 
600
  assert self.num_stages == len(self.num_heads) == len(self.num_groups)
601
 
602
  num_stages = len(embed_dims)
603
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]
604
 
605
  depth_offset = 0
606
  convs = []
 
613
  in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
614
  embed_dim=self.embed_dims[i],
615
  norm_layer=norm_layer,
616
+ pre_norm=patch_prenorm[i],
617
  )
618
  convs.append(conv_embed)
619
 
620
  block = MySequential(
621
  *[
622
+ MySequential(
623
+ OrderedDict(
624
+ [
625
+ (
626
+ "spatial_block",
627
+ SpatialBlock(
628
+ embed_dims[i],
629
+ num_heads[i],
630
+ window_size,
631
+ drop_path_rate=dpr[depth_offset + j * 2],
632
+ qkv_bias=qkv_bias,
633
+ mlp_ratio=mlp_ratio,
634
+ conv_at_attn=conv_at_attn,
635
+ conv_at_ffn=conv_at_ffn,
636
+ ),
637
+ ),
638
+ (
639
+ "channel_block",
640
+ ChannelBlock(
641
+ embed_dims[i],
642
+ num_groups[i],
643
+ drop_path_rate=dpr[depth_offset + j * 2 + 1],
644
+ qkv_bias=qkv_bias,
645
+ mlp_ratio=mlp_ratio,
646
+ conv_at_attn=conv_at_attn,
647
+ conv_at_ffn=conv_at_ffn,
648
+ ),
649
+ ),
650
+ ]
651
  )
652
+ )
653
+ for j in range(depths[i])
654
  ]
655
  )
656
  blocks.append(block)
657
+ depth_offset += depths[i] * 2
658
 
659
  self.convs = nn.ModuleList(convs)
660
  self.blocks = nn.ModuleList(blocks)
661
 
662
  self.norms = norm_layer(self.embed_dims[-1])
663
  self.avgpool = nn.AdaptiveAvgPool1d(1)
664
+ self.head = (
665
+ nn.Linear(self.embed_dims[-1], num_classes)
666
+ if num_classes > 0
667
+ else nn.Identity()
668
+ )
669
 
670
  self.apply(self._init_weights)
671
 
 
681
  elif isinstance(m, nn.Conv2d):
682
  nn.init.normal_(m.weight, std=0.02)
683
  for name, _ in m.named_parameters():
684
+ if name in ["bias"]:
685
  nn.init.constant_(m.bias, 0)
686
  elif isinstance(m, nn.LayerNorm):
687
  nn.init.constant_(m.weight, 1.0)
 
692
 
693
  def forward_features_unpool(self, x):
694
  """
695
+ forward until avg pooling
696
  Args:
697
  x (_type_): input image tensor
698
  """
 
720
  x = self.forward_features(x)
721
  x = self.head(x)
722
  return x
723
+
724
  @classmethod
725
  def from_config(cls, config):
726
  return cls(
 
737
  )
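
One detail worth calling out from DaViT.__init__ above: every depth unit expands into a spatial block plus a channel block, so the drop-path schedule is built over sum(depths) * 2 entries. A small sketch with illustrative values:

    import torch

    depths, drop_path_rate = (1, 1, 9, 1), 0.1      # illustrative values
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]
    print(len(dpr), round(dpr[0], 3), round(dpr[-1], 3))   # 24 0.0 0.1
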
738
 
739
 
 
 
740
  if is_flash_attn_2_available():
741
  from flash_attn import flash_attn_func, flash_attn_varlen_func
742
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
743
 
744
+
745
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
746
  def _get_unpad_data(attention_mask):
747
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 
755
  )
756
 
757
 
758
+ def shift_tokens_right(
759
+ input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
760
+ ):
761
  """
762
  Shift input ids one token to the right.
763
  """
 
789
 
790
  bsz, seq_len = input_ids.shape[:2]
791
  positions = torch.arange(
792
+ past_key_values_length,
793
+ past_key_values_length + seq_len,
794
+ dtype=torch.long,
795
+ device=self.weight.device,
796
  ).expand(bsz, -1)
797
 
798
  return super().forward(positions + self.offset)
 
803
  This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
804
  """
805
 
806
+ def __init__(
807
+ self,
808
+ num_embeddings: int,
809
+ embedding_dim: int,
810
+ padding_idx: int,
811
+ embed_scale: Optional[float] = 1.0,
812
+ ):
813
  super().__init__(num_embeddings, embedding_dim, padding_idx)
814
  self.embed_scale = embed_scale
815
 
 
852
  self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
853
 
854
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
855
+ return (
856
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
857
+ .transpose(1, 2)
858
+ .contiguous()
859
+ )
860
 
861
  def forward(
862
  self,
 
933
  raise ValueError(
934
  f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
935
  )
936
+ attn_weights = (
937
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
938
+ + attention_mask
939
+ )
940
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
941
 
942
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
947
  f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
948
  f" {layer_head_mask.size()}"
949
  )
950
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
951
+ bsz, self.num_heads, tgt_len, src_len
952
+ )
953
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
954
 
955
  if output_attentions:
 
957
  # make sure that attn_weights keeps its gradient.
958
  # In order to do so, attn_weights have to be reshaped
959
  # twice and have to be reused in the following
960
+ attn_weights_reshaped = attn_weights.view(
961
+ bsz, self.num_heads, tgt_len, src_len
962
+ )
963
+ attn_weights = attn_weights_reshaped.view(
964
+ bsz * self.num_heads, tgt_len, src_len
965
+ )
966
  else:
967
  attn_weights_reshaped = None
968
 
969
+ attn_probs = nn.functional.dropout(
970
+ attn_weights, p=self.dropout, training=self.training
971
+ )
972
 
973
  attn_output = torch.bmm(attn_probs, value_states)
974
 
 
1020
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1021
  # Florence2FlashAttention2 attention does not support output_attentions
1022
  if output_attentions:
1023
+ raise ValueError(
1024
+ "Florence2FlashAttention2 attention does not support output_attentions"
1025
+ )
1026
 
1027
  # if key_value_states are provided this layer is used as a cross-attention layer
1028
  # for the decoder
 
1052
  # reuse k, v, self_attention
1053
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
1054
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
1055
+ key_states = torch.cat(
1056
+ [past_key_value[0].transpose(1, 2), key_states], dim=1
1057
+ )
1058
+ value_states = torch.cat(
1059
+ [past_key_value[1].transpose(1, 2), value_states], dim=1
1060
+ )
1061
  else:
1062
  # self_attention
1063
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
 
1104
  value_states = value_states.to(target_dtype)
1105
 
1106
  attn_output = self._flash_attention_forward(
1107
+ query_states,
1108
+ key_states,
1109
+ value_states,
1110
+ attention_mask,
1111
+ q_len,
1112
+ dropout=self.dropout,
1113
  )
1114
 
1115
  attn_output = attn_output.reshape(bsz, q_len, -1)
 
1122
 
1123
  # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
1124
  def _flash_attention_forward(
1125
+ self,
1126
+ query_states,
1127
+ key_states,
1128
+ value_states,
1129
+ attention_mask,
1130
+ query_length,
1131
+ dropout=0.0,
1132
+ softmax_scale=None,
1133
  ):
1134
  """
1135
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
 
1159
  # Contains at least one padding token in the sequence
1160
  if attention_mask is not None:
1161
  batch_size = query_states.shape[0]
1162
+ (
1163
+ query_states,
1164
+ key_states,
1165
+ value_states,
1166
+ indices_q,
1167
+ cu_seq_lens,
1168
+ max_seq_lens,
1169
+ ) = self._upad_input(
1170
  query_states, key_states, value_states, attention_mask, query_length
1171
  )
1172
 
 
1186
  causal=causal,
1187
  )
1188
 
1189
+ attn_output = pad_input(
1190
+ attn_output_unpad, indices_q, batch_size, query_length
1191
+ )
1192
  else:
1193
  attn_output = flash_attn_func(
1194
+ query_states,
1195
+ key_states,
1196
+ value_states,
1197
+ dropout,
1198
+ softmax_scale=softmax_scale,
1199
+ causal=causal,
1200
  )
1201
 
1202
  return attn_output
1203
 
1204
  # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
1205
+ def _upad_input(
1206
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
1207
+ ):
1208
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
1209
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1210
 
1211
  key_layer = index_first_axis(
1212
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1213
+ indices_k,
1214
  )
1215
  value_layer = index_first_axis(
1216
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1217
+ indices_k,
1218
  )
1219
  if query_length == kv_seq_len:
1220
  query_layer = index_first_axis(
1221
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
1222
+ indices_k,
1223
  )
1224
  cu_seqlens_q = cu_seqlens_k
1225
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
1234
  else:
1235
  # The -q_len: slice assumes left padding.
1236
  attention_mask = attention_mask[:, -query_length:]
1237
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
1238
+ query_layer, attention_mask
1239
+ )
1240
 
1241
  return (
1242
  query_layer,
 
1324
  # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
1325
  # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
1326
  # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
1327
+ is_causal = (
1328
+ True if self.is_causal and attention_mask is None and tgt_len > 1 else False
1329
+ )
1330
 
1331
  # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
1332
  # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
 
1407
  layer_head_mask=layer_head_mask,
1408
  output_attentions=output_attentions,
1409
  )
1410
+ hidden_states = nn.functional.dropout(
1411
+ hidden_states, p=self.dropout, training=self.training
1412
+ )
1413
  hidden_states = residual + hidden_states
1414
  hidden_states = self.self_attn_layer_norm(hidden_states)
1415
 
1416
  residual = hidden_states
1417
  hidden_states = self.activation_fn(self.fc1(hidden_states))
1418
+ hidden_states = nn.functional.dropout(
1419
+ hidden_states, p=self.activation_dropout, training=self.training
1420
+ )
1421
  hidden_states = self.fc2(hidden_states)
1422
+ hidden_states = nn.functional.dropout(
1423
+ hidden_states, p=self.dropout, training=self.training
1424
+ )
1425
  hidden_states = residual + hidden_states
1426
  hidden_states = self.final_layer_norm(hidden_states)
1427
 
 
1429
  torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
1430
  ):
1431
  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
1432
+ hidden_states = torch.clamp(
1433
+ hidden_states, min=-clamp_value, max=clamp_value
1434
+ )
1435
 
1436
  outputs = (hidden_states,)
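
The clamp above is a float16 overflow guard; in isolation it behaves like this:

    import torch

    hidden = torch.tensor([1.0, float("inf")], dtype=torch.float16)
    clamp_value = torch.finfo(hidden.dtype).max - 1000
    hidden = torch.clamp(hidden, min=-clamp_value, max=clamp_value)
    print(torch.isfinite(hidden).all())   # tensor(True)
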
1437
 
 
1482
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
1483
  output_attentions: Optional[bool] = False,
1484
  use_cache: Optional[bool] = True,
1485
+ ) -> Tuple[
1486
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
1487
+ ]:
1488
  """
1489
  Args:
1490
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
1507
 
1508
  # Self Attention
1509
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
1510
+ self_attn_past_key_value = (
1511
+ past_key_value[:2] if past_key_value is not None else None
1512
+ )
1513
  # add present self-attn cache to positions 1,2 of present_key_value tuple
1514
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
1515
  hidden_states=hidden_states,
 
1518
  layer_head_mask=layer_head_mask,
1519
  output_attentions=output_attentions,
1520
  )
1521
+ hidden_states = nn.functional.dropout(
1522
+ hidden_states, p=self.dropout, training=self.training
1523
+ )
1524
  hidden_states = residual + hidden_states
1525
  hidden_states = self.self_attn_layer_norm(hidden_states)
1526
 
 
1531
  residual = hidden_states
1532
 
1533
  # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
1534
+ cross_attn_past_key_value = (
1535
+ past_key_value[-2:] if past_key_value is not None else None
1536
+ )
1537
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = (
1538
+ self.encoder_attn(
1539
+ hidden_states=hidden_states,
1540
+ key_value_states=encoder_hidden_states,
1541
+ attention_mask=encoder_attention_mask,
1542
+ layer_head_mask=cross_attn_layer_head_mask,
1543
+ past_key_value=cross_attn_past_key_value,
1544
+ output_attentions=output_attentions,
1545
+ )
1546
+ )
1547
+ hidden_states = nn.functional.dropout(
1548
+ hidden_states, p=self.dropout, training=self.training
1549
  )
 
1550
  hidden_states = residual + hidden_states
1551
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
1552
 
 
1556
  # Fully Connected
1557
  residual = hidden_states
1558
  hidden_states = self.activation_fn(self.fc1(hidden_states))
1559
+ hidden_states = nn.functional.dropout(
1560
+ hidden_states, p=self.activation_dropout, training=self.training
1561
+ )
1562
  hidden_states = self.fc2(hidden_states)
1563
+ hidden_states = nn.functional.dropout(
1564
+ hidden_states, p=self.dropout, training=self.training
1565
+ )
1566
  hidden_states = residual + hidden_states
1567
  hidden_states = self.final_layer_norm(hidden_states)
1568
 
 
1577
  return outputs
1578
 
1579
 
 
1580
  class Florence2LanguagePreTrainedModel(PreTrainedModel):
1581
  config_class = Florence2LanguageConfig
1582
  base_model_prefix = "model"
 
1601
  @property
1602
  def dummy_inputs(self):
1603
  pad_token = self.config.pad_token_id
1604
+ input_ids = torch.tensor(
1605
+ [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device
1606
+ )
1607
  dummy_inputs = {
1608
  "attention_mask": input_ids.ne(pad_token),
1609
  "input_ids": input_ids,
 
1621
  embed_tokens (nn.Embedding): output embedding
1622
  """
1623
 
1624
+ def __init__(
1625
+ self,
1626
+ config: Florence2LanguageConfig,
1627
+ embed_tokens: Optional[nn.Embedding] = None,
1628
+ ):
1629
  super().__init__(config)
1630
 
1631
  self.dropout = config.dropout
 
1647
  config.max_position_embeddings,
1648
  embed_dim,
1649
  )
1650
+ self.layers = nn.ModuleList(
1651
+ [Florence2EncoderLayer(config) for _ in range(config.encoder_layers)]
1652
+ )
1653
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1654
  self._use_sdpa = config._attn_implementation == "sdpa"
1655
  self.layernorm_embedding = nn.LayerNorm(embed_dim)
 
1710
  return_dict (`bool`, *optional*):
1711
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1712
  """
1713
+ output_attentions = (
1714
+ output_attentions
1715
+ if output_attentions is not None
1716
+ else self.config.output_attentions
1717
+ )
1718
  output_hidden_states = (
1719
+ output_hidden_states
1720
+ if output_hidden_states is not None
1721
+ else self.config.output_hidden_states
1722
+ )
1723
+ return_dict = (
1724
+ return_dict if return_dict is not None else self.config.use_return_dict
1725
  )
 
1726
 
1727
  # retrieve input_ids and inputs_embeds
1728
  if input_ids is not None and inputs_embeds is not None:
1729
+ raise ValueError(
1730
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1731
+ )
1732
  elif input_ids is not None:
1733
  input = input_ids
1734
  input_ids = input_ids.view(-1, input_ids.shape[-1])
 
1745
 
1746
  hidden_states = inputs_embeds + embed_pos
1747
  hidden_states = self.layernorm_embedding(hidden_states)
1748
+ hidden_states = nn.functional.dropout(
1749
+ hidden_states, p=self.dropout, training=self.training
1750
+ )
1751
 
1752
  # expand attention_mask
1753
  if attention_mask is not None:
 
1757
  # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1758
  # the manual implementation that requires a 4D causal mask in all cases.
1759
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1760
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(
1761
+ attention_mask, inputs_embeds.dtype
1762
+ )
1763
  else:
1764
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1765
+ attention_mask = _prepare_4d_attention_mask(
1766
+ attention_mask, inputs_embeds.dtype
1767
+ )
1768
 
1769
  encoder_states = () if output_hidden_states else None
1770
  all_attentions = () if output_attentions else None
 
1802
  layer_outputs = encoder_layer(
1803
  hidden_states,
1804
  attention_mask,
1805
+ layer_head_mask=(
1806
+ head_mask[idx] if head_mask is not None else None
1807
+ ),
1808
  output_attentions=output_attentions,
1809
  )
1810
 
 
1817
  encoder_states = encoder_states + (hidden_states,)
1818
 
1819
  if not return_dict:
1820
+ return tuple(
1821
+ v
1822
+ for v in [hidden_states, encoder_states, all_attentions]
1823
+ if v is not None
1824
+ )
1825
  return BaseModelOutput(
1826
+ last_hidden_state=hidden_states,
1827
+ hidden_states=encoder_states,
1828
+ attentions=all_attentions,
1829
  )
1830
 
1831
 
 
1838
  embed_tokens (nn.Embedding): output embedding
1839
  """
1840
 
1841
+ def __init__(
1842
+ self,
1843
+ config: Florence2LanguageConfig,
1844
+ embed_tokens: Optional[nn.Embedding] = None,
1845
+ ):
1846
  super().__init__(config)
1847
  self.dropout = config.dropout
1848
  self.layerdrop = config.decoder_layerdrop
 
1861
  config.max_position_embeddings,
1862
  config.d_model,
1863
  )
1864
+ self.layers = nn.ModuleList(
1865
+ [Florence2DecoderLayer(config) for _ in range(config.decoder_layers)]
1866
+ )
1867
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1868
  self._use_sdpa = config._attn_implementation == "sdpa"
1869
 
 
1959
  return_dict (`bool`, *optional*):
1960
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1961
  """
1962
+ output_attentions = (
1963
+ output_attentions
1964
+ if output_attentions is not None
1965
+ else self.config.output_attentions
1966
+ )
1967
  output_hidden_states = (
1968
+ output_hidden_states
1969
+ if output_hidden_states is not None
1970
+ else self.config.output_hidden_states
1971
  )
1972
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1973
+ return_dict = (
1974
+ return_dict if return_dict is not None else self.config.use_return_dict
1975
+ )
1976
 
1977
  # retrieve input_ids and inputs_embeds
1978
  if input_ids is not None and inputs_embeds is not None:
1979
+ raise ValueError(
1980
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1981
+ )
1982
  elif input_ids is not None:
1983
  input = input_ids
1984
  input_shape = input.shape
 
1987
  input_shape = inputs_embeds.size()[:-1]
1988
  input = inputs_embeds[:, :, -1]
1989
  else:
1990
+ raise ValueError(
1991
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1992
+ )
1993
 
1994
  # past_key_values_length
1995
+ past_key_values_length = (
1996
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1997
+ )
1998
 
1999
  if inputs_embeds is None:
2000
  inputs_embeds = self.embed_tokens(input)
2001
 
2002
  if self._use_flash_attention_2:
2003
  # 2d mask is passed through the layers
2004
+ attention_mask = (
2005
+ attention_mask
2006
+ if (attention_mask is not None and 0 in attention_mask)
2007
+ else None
2008
+ )
2009
  elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
2010
  # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
2011
  # the manual implementation that requires a 4D causal mask in all cases.
 
2024
  # expand encoder attention mask
2025
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
2026
  if self._use_flash_attention_2:
2027
+ encoder_attention_mask = (
2028
+ encoder_attention_mask if 0 in encoder_attention_mask else None
2029
+ )
2030
+ elif (
2031
+ self._use_sdpa
2032
+ and cross_attn_head_mask is None
2033
+ and not output_attentions
2034
+ ):
2035
  # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
2036
  # the manual implementation that requires a 4D causal mask in all cases.
2037
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
 
2053
  hidden_states = inputs_embeds + positions
2054
  hidden_states = self.layernorm_embedding(hidden_states)
2055
 
2056
+ hidden_states = nn.functional.dropout(
2057
+ hidden_states, p=self.dropout, training=self.training
2058
+ )
2059
 
2060
  if self.gradient_checkpointing and self.training:
2061
  if use_cache:
 
2067
  # decoder layers
2068
  all_hidden_states = () if output_hidden_states else None
2069
  all_self_attns = () if output_attentions else None
2070
+ all_cross_attentions = (
2071
+ () if (output_attentions and encoder_hidden_states is not None) else None
2072
+ )
2073
  next_decoder_cache = () if use_cache else None
2074
 
2075
  # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
2076
+ for attn_mask, mask_name in zip(
2077
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
2078
+ ):
2079
  if attn_mask is not None:
2080
  if attn_mask.size()[0] != (len(self.layers)):
2081
  raise ValueError(
 
2092
  if dropout_probability < self.layerdrop:
2093
  continue
2094
 
2095
+ past_key_value = (
2096
+ past_key_values[idx] if past_key_values is not None else None
2097
+ )
2098
 
2099
  if self.gradient_checkpointing and self.training:
2100
  layer_outputs = self._gradient_checkpointing_func(
 
2104
  encoder_hidden_states,
2105
  encoder_attention_mask,
2106
  head_mask[idx] if head_mask is not None else None,
2107
+ (
2108
+ cross_attn_head_mask[idx]
2109
+ if cross_attn_head_mask is not None
2110
+ else None
2111
+ ),
2112
  None,
2113
  output_attentions,
2114
  use_cache,
 
2121
  encoder_attention_mask=encoder_attention_mask,
2122
  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
2123
  cross_attn_layer_head_mask=(
2124
+ cross_attn_head_mask[idx]
2125
+ if cross_attn_head_mask is not None
2126
+ else None
2127
  ),
2128
  past_key_value=past_key_value,
2129
  output_attentions=output_attentions,
 
2148
  if not return_dict:
2149
  return tuple(
2150
  v
2151
+ for v in [
2152
+ hidden_states,
2153
+ next_cache,
2154
+ all_hidden_states,
2155
+ all_self_attns,
2156
+ all_cross_attentions,
2157
+ ]
2158
  if v is not None
2159
  )
2160
  return BaseModelOutputWithPastAndCrossAttentions(
 
2232
  input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
2233
  )
2234
 
2235
+ output_attentions = (
2236
+ output_attentions
2237
+ if output_attentions is not None
2238
+ else self.config.output_attentions
2239
+ )
2240
  output_hidden_states = (
2241
+ output_hidden_states
2242
+ if output_hidden_states is not None
2243
+ else self.config.output_hidden_states
2244
  )
2245
  use_cache = use_cache if use_cache is not None else self.config.use_cache
2246
+ return_dict = (
2247
+ return_dict if return_dict is not None else self.config.use_return_dict
2248
+ )
2249
 
2250
  if encoder_outputs is None:
2251
  encoder_outputs = self.encoder(
 
2298
 
2299
  class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel):
2300
  base_model_prefix = "model"
2301
+ _tied_weights_keys = [
2302
+ "encoder.embed_tokens.weight",
2303
+ "decoder.embed_tokens.weight",
2304
+ "lm_head.weight",
2305
+ ]
2306
  _keys_to_ignore_on_load_missing = ["final_logits_bias"]
2307
 
2308
  def __init__(self, config: Florence2LanguageConfig):
2309
  super().__init__(config)
2310
  self.model = Florence2LanguageModel(config)
2311
+ self.register_buffer(
2312
+ "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
2313
+ )
2314
+ self.lm_head = nn.Linear(
2315
+ config.d_model, self.model.shared.num_embeddings, bias=False
2316
+ )
2317
 
2318
  # Initialize weights and apply final processing
2319
  self.post_init()
 
2324
  def get_decoder(self):
2325
  return self.model.get_decoder()
2326
 
2327
+ def resize_token_embeddings(
2328
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None
2329
+ ) -> nn.Embedding:
2330
+ new_embeddings = super().resize_token_embeddings(
2331
+ new_num_tokens, pad_to_multiple_of
2332
+ )
2333
  self._resize_final_logits_bias(new_embeddings.weight.shape[0])
2334
  return new_embeddings
2335
 
 
2338
  if new_num_tokens <= old_num_tokens:
2339
  new_bias = self.final_logits_bias[:, :new_num_tokens]
2340
  else:
2341
+ extra_bias = torch.zeros(
2342
+ (1, new_num_tokens - old_num_tokens),
2343
+ device=self.final_logits_bias.device,
2344
+ )
2345
  new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
2346
  self.register_buffer("final_logits_bias", new_bias)
2347
 
 
2378
 
2379
  Returns:
2380
  """
2381
+ return_dict = (
2382
+ return_dict if return_dict is not None else self.config.use_return_dict
2383
+ )
2384
 
2385
  if labels is not None:
2386
  if use_cache:
2387
+ logger.warning(
2388
+ "The `use_cache` argument is changed to `False` since `labels` is provided."
2389
+ )
2390
  use_cache = False
2391
  if decoder_input_ids is None and decoder_inputs_embeds is None:
2392
  decoder_input_ids = shift_tokens_right(
 
2418
  if labels is not None:
2419
  labels = labels.to(lm_logits.device)
2420
  loss_fct = CrossEntropyLoss()
2421
+ masked_lm_loss = loss_fct(
2422
+ lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
2423
+ )
2424
 
2425
  if not return_dict:
2426
  output = (lm_logits,) + outputs[1:]
2427
+ return (
2428
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
2429
+ )
2430
 
2431
  return Seq2SeqLMOutput(
2432
  loss=masked_lm_loss,
 
2480
  }
2481
 
2482
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
2483
+ return shift_tokens_right(
2484
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
2485
+ )
2486
 
2487
  @staticmethod
2488
  def _reorder_cache(past_key_values, beam_idx):
 
2490
  for layer_past in past_key_values:
2491
  # cached cross_attention states don't have to be reordered -> they are always the same
2492
  reordered_past += (
2493
+ tuple(
2494
+ past_state.index_select(0, beam_idx.to(past_state.device))
2495
+ for past_state in layer_past[:2]
2496
+ )
2497
  + layer_past[2:],
2498
  )
2499
  return reordered_past
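
index_select on the batch axis is what keeps cached keys/values aligned with the surviving beams; a toy example:

    import torch

    past_state = torch.arange(6.0).view(3, 2)   # 3 beams, one tiny cached tensor
    beam_idx = torch.tensor([2, 0, 0])          # new beams take their cache from old beams 2, 0, 0
    print(past_state.index_select(0, beam_idx))
    # tensor([[4., 5.],
    #         [0., 1.],
    #         [0., 1.]])
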
2500
 
2501
+
2502
  @dataclass
2503
  class Florence2Seq2SeqLMOutput(ModelOutput):
2504
  """
 
2555
  image_hidden_states of the model produced by the vision encoder
2556
  """
2557
 
2558
+ loss: torch.FloatTensor = None
2559
  last_hidden_state: torch.FloatTensor = None
2560
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
2561
  decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
 
2675
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
2676
  """
2677
 
2678
+
2679
  @add_start_docstrings(
2680
  """The FLORENCE2 vision model without any head""",
2681
  FLORENCE2_START_DOCSTRING,
 
2683
  class Florence2VisionModel(Florence2PreTrainedModel):
2684
  def __init__(self, config: Florence2VisionConfig):
2685
  super().__init__(config)
2686
+ assert config.model_type == "davit", "only DaViT is supported for now"
2687
  self.vision_tower = DaViT.from_config(config=config)
2688
 
2689
  self.post_init()
2690
+
2691
  def forward(self, pixel_values):
2692
  if len(pixel_values.shape) == 4:
2693
  x = self.vision_tower.forward_features_unpool(pixel_values)
2694
  else:
2695
+ raise ValueError(f"invalid image shape {pixel_values.shape}")
2696
  return x
2697
 
2698
 
 
2703
  class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2704
  def __init__(self, config: Florence2VisionConfig):
2705
  super().__init__(config)
2706
+ assert config.model_type == "davit", "only DaViT is supported for now"
2707
  self.vision_tower = DaViT.from_config(config=config)
2708
 
2709
  self._build_image_projection_layers(config)
2710
 
2711
  self.post_init()
2712
+
2713
  def _build_image_projection_layers(self, config):
2714
  image_dim_out = config.dim_embed[-1]
2715
  dim_projection = config.projection_dim
2716
+ self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection))
 
 
2717
  self.image_proj_norm = nn.LayerNorm(dim_projection)
2718
  image_pos_embed_config = config.image_pos_embed
2719
+ if image_pos_embed_config["type"] == "learned_abs_2d":
2720
  self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
2721
  embedding_dim=image_dim_out,
2722
+ num_pos=image_pos_embed_config["max_pos_embeddings"],
2723
  )
2724
  else:
2725
+ raise NotImplementedError("Not implemented yet")
2726
 
2727
  self.image_feature_source = config.image_feature_source
2728
 
2729
  # temporal embedding
2730
  visual_temporal_embedding_config = config.visual_temporal_embedding
2731
+ if visual_temporal_embedding_config["type"] == "COSINE":
2732
  self.visual_temporal_embed = PositionalEmbeddingCosine1D(
2733
  embed_dim=image_dim_out,
2734
+ max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"],
2735
  )
2736
  else:
2737
+ raise NotImplementedError("Not implemented yet")
2738
 
2739
  def forward(self, pixel_values):
2740
  if len(pixel_values.shape) == 4:
 
2742
  T = 1
2743
  x = self.vision_tower.forward_features_unpool(pixel_values)
2744
  else:
2745
+ raise ValueError(f"invalid image shape {pixel_values.shape}")
2746
+
2747
  if self.image_pos_embed is not None:
2748
  x = x.view(batch_size * T, -1, x.shape[-1])
2749
  num_tokens = x.shape[-2]
2750
+ h, w = int(num_tokens**0.5), int(num_tokens**0.5)
2751
+ assert h * w == num_tokens, "only support square feature maps for now"
2752
  x = x.view(batch_size * T, h, w, x.shape[-1])
2753
  pos_embed = self.image_pos_embed(x)
2754
  x = x + pos_embed
2755
+ x = x.view(batch_size, T * h * w, x.shape[-1])
2756
 
2757
  if self.visual_temporal_embed is not None:
2758
+ visual_temporal_embed = self.visual_temporal_embed(
2759
+ x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
2760
+ )
2761
+ x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(
2762
+ 1, T, 1, x.shape[-1]
2763
+ )
2764
 
2765
  x_feat_dict = {}
2766
 
2767
  spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
2768
+ x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
2769
 
2770
  temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
2771
+ x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
2772
 
2773
  x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
2774
+ x_feat_dict["last_frame"] = x
2775
 
2776
  new_x = []
2777
  for _image_feature_source in self.image_feature_source:
2778
  if _image_feature_source not in x_feat_dict:
2779
+ raise ValueError(
2780
+ "invalid image feature source: {}".format(_image_feature_source)
2781
+ )
2782
  new_x.append(x_feat_dict[_image_feature_source])
2783
 
2784
  x = torch.cat(new_x, dim=1)
 
2786
  x = x @ self.image_projection
2787
  x = self.image_proj_norm(x)
2788
 
 
2789
  return x
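
A shape-level sketch of the pooling and feature-source selection above, with toy sizes; the dictionary keys match the ones built in the code, and the selected sources are just an example of what image_feature_source may contain:

    import torch

    batch_size, T, num_tokens, dim = 1, 1, 576, 1024      # toy sizes
    x = torch.randn(batch_size, T, num_tokens, dim)
    x_feat_dict = {
        "spatial_avg_pool": x.mean(dim=2),                # (1, T, dim): one token per frame
        "temporal_avg_pool": x.mean(dim=1),               # (1, num_tokens, dim)
        "last_frame": x[:, -1],                           # (1, num_tokens, dim)
    }
    selected = torch.cat(
        [x_feat_dict[s] for s in ("spatial_avg_pool", "temporal_avg_pool")], dim=1
    )
    print(selected.shape)                                 # torch.Size([1, 577, 1024])
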
2790
 
2791
 
 
2792
  @add_start_docstrings(
2793
  """The FLORENCE2 model which consists of a vision backbone and a language model.""",
2794
  FLORENCE2_START_DOCSTRING,
 
2796
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2797
  def __init__(self, config: Florence2Config):
2798
  super().__init__(config)
2799
+ assert (
2800
+ config.vision_config.model_type == "davit"
2801
+ ), "only DaViT is supported for now"
2802
  del config.vision_config.model_type
2803
  self.vision_tower = DaViT.from_config(config=config.vision_config)
2804
+ # remove unused layers
2805
  del self.vision_tower.head
2806
  del self.vision_tower.norms
2807
 
 
2809
  self._attn_implementation = config._attn_implementation
2810
  self._build_image_projection_layers(config)
2811
 
2812
+ language_model = Florence2LanguageForConditionalGeneration(
2813
+ config=config.text_config
2814
+ )
2815
 
2816
  if language_model._tied_weights_keys is not None:
2817
+ self._tied_weights_keys = [
2818
+ f"language_model.{k}" for k in language_model._tied_weights_keys
2819
+ ]
2820
  self.language_model = language_model
2821
 
2822
+ self.pad_token_id = (
2823
+ self.config.pad_token_id if self.config.pad_token_id is not None else -1
2824
+ )
2825
  self.post_init()
2826
+
2827
  def _build_image_projection_layers(self, config):
2828
  image_dim_out = config.vision_config.dim_embed[-1]
2829
  dim_projection = config.vision_config.projection_dim
2830
+ self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection))
 
 
2831
  self.image_proj_norm = nn.LayerNorm(dim_projection)
2832
  image_pos_embed_config = config.vision_config.image_pos_embed
2833
+ if image_pos_embed_config["type"] == "learned_abs_2d":
2834
  self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
2835
  embedding_dim=image_dim_out,
2836
+ num_pos=image_pos_embed_config["max_pos_embeddings"],
2837
  )
2838
  else:
2839
+ raise NotImplementedError("Not implemented yet")
2840
 
2841
  self.image_feature_source = config.vision_config.image_feature_source
2842
 
2843
  # temporal embedding
2844
+ visual_temporal_embedding_config = (
2845
+ config.vision_config.visual_temporal_embedding
2846
+ )
2847
+ if visual_temporal_embedding_config["type"] == "COSINE":
2848
  self.visual_temporal_embed = PositionalEmbeddingCosine1D(
2849
  embed_dim=image_dim_out,
2850
+ max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"],
2851
  )
2852
  else:
2853
+ raise NotImplementedError("Not implemented yet")
2854
 
2855
  def get_encoder(self):
2856
  return self.language_model.get_encoder()
 
2861
  def get_input_embeddings(self):
2862
  return self.language_model.get_input_embeddings()
2863
 
2864
+ def resize_token_embeddings(
2865
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
2866
+ ) -> nn.Embedding:
2867
+ model_embeds = self.language_model.resize_token_embeddings(
2868
+ new_num_tokens, pad_to_multiple_of
2869
+ )
2870
  # update vocab size
2871
  self.config.text_config.vocab_size = model_embeds.num_embeddings
2872
  self.config.vocab_size = model_embeds.num_embeddings
2873
  self.vocab_size = model_embeds.num_embeddings
2874
  return model_embeds
2875
+
2876
  def _encode_image(self, pixel_values):
2877
  if len(pixel_values.shape) == 4:
2878
  batch_size, C, H, W = pixel_values.shape
2879
  T = 1
2880
  x = self.vision_tower.forward_features_unpool(pixel_values)
2881
  else:
2882
+ raise ValueError(f"invalid image shape {pixel_values.shape}")
2883
+
2884
  if self.image_pos_embed is not None:
2885
  x = x.view(batch_size * T, -1, x.shape[-1])
2886
  num_tokens = x.shape[-2]
2887
+ h, w = int(num_tokens**0.5), int(num_tokens**0.5)
2888
+ assert h * w == num_tokens, "only support square feature maps for now"
2889
  x = x.view(batch_size * T, h, w, x.shape[-1])
2890
  pos_embed = self.image_pos_embed(x)
2891
  x = x + pos_embed
2892
+ x = x.view(batch_size, T * h * w, x.shape[-1])
2893
 
2894
  if self.visual_temporal_embed is not None:
2895
+ visual_temporal_embed = self.visual_temporal_embed(
2896
+ x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
2897
+ )
2898
+ x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(
2899
+ 1, T, 1, x.shape[-1]
2900
+ )
2901
 
2902
  x_feat_dict = {}
2903
 
2904
  spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
2905
+ x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
2906
 
2907
  temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
2908
+ x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
2909
 
2910
  x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
2911
+ x_feat_dict["last_frame"] = x
2912
 
2913
  new_x = []
2914
  for _image_feature_source in self.image_feature_source:
2915
  if _image_feature_source not in x_feat_dict:
2916
+ raise ValueError(
2917
+ "invalid image feature source: {}".format(_image_feature_source)
2918
+ )
2919
  new_x.append(x_feat_dict[_image_feature_source])
2920
 
2921
  x = torch.cat(new_x, dim=1)
 
2923
  x = x @ self.image_projection
2924
  x = self.image_proj_norm(x)
2925
 
2926
+ return x
2927
 
2928
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds):
 
 
2929
  batch_size, image_token_length = image_features.size()[:-1]
2930
  device = image_features.device
2931
  image_attention_mask = torch.ones(batch_size, image_token_length, device=device)
 
2936
  return image_features, image_attention_mask
2937
 
2938
  task_prefix_embeds = inputs_embeds
2939
+ task_prefix_attention_mask = torch.ones(
2940
+ batch_size, task_prefix_embeds.size(1), device=device
2941
+ )
2942
 
2943
  if len(task_prefix_attention_mask.shape) == 3:
2944
  task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
2945
 
2946
  # concat [image embeds, task prefix embeds]
2947
  inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
2948
+ attention_mask = torch.cat(
2949
+ [image_attention_mask, task_prefix_attention_mask], dim=1
2950
+ )
2951
 
2952
  return inputs_embeds, attention_mask
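
At the tensor level, the merge above just prepends the image tokens to the prompt embeddings and gives them an all-ones mask; with toy sizes:

    import torch

    batch_size, num_image_tokens, prompt_len, d_model = 2, 577, 12, 1024   # toy sizes
    image_features = torch.randn(batch_size, num_image_tokens, d_model)
    task_prefix_embeds = torch.randn(batch_size, prompt_len, d_model)
    inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
    attention_mask = torch.cat(
        [torch.ones(batch_size, num_image_tokens), torch.ones(batch_size, prompt_len)],
        dim=1,
    )
    print(inputs_embeds.shape, attention_mask.shape)      # (2, 589, 1024) (2, 589)
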
2953
 
 
2954
  @add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING)
2955
+ @replace_return_docstrings(
2956
+ output_type=Florence2Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
2957
+ )
2958
  def forward(
2959
  self,
2960
  input_ids: torch.LongTensor = None,
 
3005
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
3006
  "A green car parked in front of a yellow building."
3007
  ```"""
3008
+ output_attentions = (
3009
+ output_attentions
3010
+ if output_attentions is not None
3011
+ else self.config.output_attentions
3012
+ )
3013
  output_hidden_states = (
3014
+ output_hidden_states
3015
+ if output_hidden_states is not None
3016
+ else self.config.output_hidden_states
3017
+ )
3018
+ return_dict = (
3019
+ return_dict if return_dict is not None else self.config.use_return_dict
3020
  )
 
3021
 
3022
  image_features = None
3023
  if inputs_embeds is None:
 
3028
  if pixel_values is not None:
3029
  # (batch_size, num_image_tokens, hidden_size)
3030
  image_features = self._encode_image(pixel_values)
3031
+ inputs_embeds, attention_mask = (
3032
+ self._merge_input_ids_with_image_features(
3033
+ image_features, inputs_embeds
3034
+ )
3035
+ )
3036
 
3037
  attention_mask = attention_mask.to(inputs_embeds.dtype)
3038
  outputs = self.language_model(
 
3060
  output = (logits,) + outputs[1:]
3061
  return (loss,) + output if loss is not None else output
3062
 
3064
+
3065
  return Florence2Seq2SeqLMOutput(
3066
  loss=loss,
3067
  logits=logits,
 
3072
  encoder_last_hidden_state=outputs.encoder_last_hidden_state,
3073
  encoder_hidden_states=outputs.encoder_hidden_states,
3074
  encoder_attentions=outputs.encoder_attentions,
3075
+ image_hidden_states=image_features,
3076
  )
3077
 
3078
+ def generate(self, input_ids, inputs_embeds=None, pixel_values=None, **kwargs):
3079
 
3080
  if inputs_embeds is None:
3081
  # 1. Extract the input embeddings
 
3084
  # 2. Merge text and images
3085
  if pixel_values is not None:
3086
  image_features = self._encode_image(pixel_values)
3087
+ inputs_embeds, attention_mask = (
3088
+ self._merge_input_ids_with_image_features(
3089
+ image_features, inputs_embeds
3090
+ )
3091
+ )
3092
+
3093
  return self.language_model.generate(
3094
+ input_ids=None, inputs_embeds=inputs_embeds, **kwargs
 
 
3095
  )
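
For context, the way this generate is normally reached from user code (a sketch; the model id and task prompt are examples, and the checkpoint has to be loaded with trust_remote_code=True so this file is used):

    import torch
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoProcessor

    model_id = "microsoft/Florence-2-base"                # example checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    image = Image.new("RGB", (640, 480))                  # any RGB image works here
    inputs = processor(text="<CAPTION>", images=image, return_tensors="pt")
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=64,
            num_beams=3,
        )
    print(processor.batch_decode(generated_ids, skip_special_tokens=False)[0])
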
3096
 
3097
  def prepare_inputs_for_generation(
 
3120
  remove_prefix_length = decoder_input_ids.shape[1] - 1
3121
 
3122
  decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
3123
+
3124
  return {
3125
  "input_ids": None, # encoder_outputs is defined. input_ids not needed
3126
  "encoder_outputs": encoder_outputs,
 
3134
  "cross_attn_head_mask": cross_attn_head_mask,
3135
  "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
3136
  }
3137
+
3138
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
3139
  return self.language_model.shift_tokens_right(labels)
3140
 
3141
  def _reorder_cache(self, *args, **kwargs):
3142
+ return self.language_model._reorder_cache(*args, **kwargs)