File size: 36,051 Bytes
4c65bff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
# coding=utf-8
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 Cvt model."""


from __future__ import annotations

import collections.abc
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import tensorflow as tf

from ...modeling_tf_outputs import TFImageClassifierOutputWithNoAttention
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    get_initializer,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_cvt import CvtConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "CvtConfig"

TF_CVT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/cvt-13",
    "microsoft/cvt-13-384",
    "microsoft/cvt-13-384-22k",
    "microsoft/cvt-21",
    "microsoft/cvt-21-384",
    "microsoft/cvt-21-384-22k",
    # See all Cvt models at https://huggingface.co/models?filter=cvt
]


@dataclass
class TFBaseModelOutputWithCLSToken(ModelOutput):
    """
    Base class for model's outputs.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cls_token_value (`tf.Tensor` of shape `(batch_size, 1, hidden_size)`):
            Classification token at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
    """

    last_hidden_state: tf.Tensor = None
    cls_token_value: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None


class TFCvtDropPath(tf.keras.layers.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_prob: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x: tf.Tensor, training=None):
        if self.drop_prob == 0.0 or not training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
        random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype)
        random_tensor = tf.floor(random_tensor)
        return (x / keep_prob) * random_tensor


class TFCvtEmbeddings(tf.keras.layers.Layer):
    """Construct the Convolutional Token Embeddings."""

    def __init__(
        self,
        config: CvtConfig,
        patch_size: int,
        embed_dim: int,
        stride: int,
        padding: int,
        dropout_rate: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.convolution_embeddings = TFCvtConvEmbeddings(
            config,
            patch_size=patch_size,
            embed_dim=embed_dim,
            stride=stride,
            padding=padding,
            name="convolution_embeddings",
        )
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_state = self.convolution_embeddings(pixel_values)
        hidden_state = self.dropout(hidden_state, training=training)
        return hidden_state


class TFCvtConvEmbeddings(tf.keras.layers.Layer):
    """Image to Convolution Embeddings. This convolutional operation aims to model local spatial contexts."""

    def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: int, padding: int, **kwargs):
        super().__init__(**kwargs)
        self.padding = tf.keras.layers.ZeroPadding2D(padding=padding)
        self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        self.projection = tf.keras.layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=stride,
            padding="valid",
            data_format="channels_last",
            kernel_initializer=get_initializer(config.initializer_range),
            name="projection",
        )
        # Using the same default epsilon as PyTorch
        self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="normalization")

    def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
        if isinstance(pixel_values, dict):
            pixel_values = pixel_values["pixel_values"]

        pixel_values = self.projection(self.padding(pixel_values))

        # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
        batch_size, height, width, num_channels = shape_list(pixel_values)
        hidden_size = height * width
        pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels))
        pixel_values = self.normalization(pixel_values)

        # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
        pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels))
        return pixel_values


class TFCvtSelfAttentionConvProjection(tf.keras.layers.Layer):
    """Convolutional projection layer."""

    def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs):
        super().__init__(**kwargs)
        self.padding = tf.keras.layers.ZeroPadding2D(padding=padding)
        self.convolution = tf.keras.layers.Conv2D(
            filters=embed_dim,
            kernel_size=kernel_size,
            kernel_initializer=get_initializer(config.initializer_range),
            padding="valid",
            strides=stride,
            use_bias=False,
            name="convolution",
            groups=embed_dim,
        )
        # Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum)
        self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_state = self.convolution(self.padding(hidden_state))
        hidden_state = self.normalization(hidden_state, training=training)
        return hidden_state


class TFCvtSelfAttentionLinearProjection(tf.keras.layers.Layer):
    """Linear projection layer used to flatten tokens into 1D."""

    def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
        # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
        batch_size, height, width, num_channels = shape_list(hidden_state)
        hidden_size = height * width
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))
        return hidden_state


class TFCvtSelfAttentionProjection(tf.keras.layers.Layer):
    """Convolutional Projection for Attention."""

    def __init__(
        self,
        config: CvtConfig,
        embed_dim: int,
        kernel_size: int,
        stride: int,
        padding: int,
        projection_method: str = "dw_bn",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if projection_method == "dw_bn":
            self.convolution_projection = TFCvtSelfAttentionConvProjection(
                config, embed_dim, kernel_size, stride, padding, name="convolution_projection"
            )
        self.linear_projection = TFCvtSelfAttentionLinearProjection()

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_state = self.convolution_projection(hidden_state, training=training)
        hidden_state = self.linear_projection(hidden_state)
        return hidden_state


class TFCvtSelfAttention(tf.keras.layers.Layer):
    """
    Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection), is applied for
    query, key, and value embeddings.
    """

    def __init__(
        self,
        config: CvtConfig,
        num_heads: int,
        embed_dim: int,
        kernel_size: int,
        stride_q: int,
        stride_kv: int,
        padding_q: int,
        padding_kv: int,
        qkv_projection_method: str,
        qkv_bias: bool,
        attention_drop_rate: float,
        with_cls_token: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.scale = embed_dim**-0.5
        self.with_cls_token = with_cls_token
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.convolution_projection_query = TFCvtSelfAttentionProjection(
            config,
            embed_dim,
            kernel_size,
            stride_q,
            padding_q,
            projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
            name="convolution_projection_query",
        )
        self.convolution_projection_key = TFCvtSelfAttentionProjection(
            config,
            embed_dim,
            kernel_size,
            stride_kv,
            padding_kv,
            projection_method=qkv_projection_method,
            name="convolution_projection_key",
        )
        self.convolution_projection_value = TFCvtSelfAttentionProjection(
            config,
            embed_dim,
            kernel_size,
            stride_kv,
            padding_kv,
            projection_method=qkv_projection_method,
            name="convolution_projection_value",
        )

        self.projection_query = tf.keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            use_bias=qkv_bias,
            bias_initializer="zeros",
            name="projection_query",
        )
        self.projection_key = tf.keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            use_bias=qkv_bias,
            bias_initializer="zeros",
            name="projection_key",
        )
        self.projection_value = tf.keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            use_bias=qkv_bias,
            bias_initializer="zeros",
            name="projection_value",
        )
        self.dropout = tf.keras.layers.Dropout(attention_drop_rate)

    def rearrange_for_multi_head_attention(self, hidden_state: tf.Tensor) -> tf.Tensor:
        batch_size, hidden_size, _ = shape_list(hidden_state)
        head_dim = self.embed_dim // self.num_heads
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, self.num_heads, head_dim))
        hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3))
        return hidden_state

    def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
        if self.with_cls_token:
            cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)

        # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
        batch_size, hidden_size, num_channels = shape_list(hidden_state)
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))

        key = self.convolution_projection_key(hidden_state, training=training)
        query = self.convolution_projection_query(hidden_state, training=training)
        value = self.convolution_projection_value(hidden_state, training=training)

        if self.with_cls_token:
            query = tf.concat((cls_token, query), axis=1)
            key = tf.concat((cls_token, key), axis=1)
            value = tf.concat((cls_token, value), axis=1)

        head_dim = self.embed_dim // self.num_heads

        query = self.rearrange_for_multi_head_attention(self.projection_query(query))
        key = self.rearrange_for_multi_head_attention(self.projection_key(key))
        value = self.rearrange_for_multi_head_attention(self.projection_value(value))

        attention_score = tf.matmul(query, key, transpose_b=True) * self.scale
        attention_probs = stable_softmax(logits=attention_score, axis=-1)
        attention_probs = self.dropout(attention_probs, training=training)

        context = tf.matmul(attention_probs, value)
        # "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)"
        _, _, hidden_size, _ = shape_list(context)
        context = tf.transpose(context, perm=(0, 2, 1, 3))
        context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim))
        return context


class TFCvtSelfOutput(tf.keras.layers.Layer):
    """Output of the Attention layer ."""

    def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = tf.keras.layers.Dropout(drop_rate)

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_state = self.dense(inputs=hidden_state)
        hidden_state = self.dropout(inputs=hidden_state, training=training)
        return hidden_state


class TFCvtAttention(tf.keras.layers.Layer):
    """Attention layer. First chunk of the convolutional transformer block."""

    def __init__(
        self,
        config: CvtConfig,
        num_heads: int,
        embed_dim: int,
        kernel_size: int,
        stride_q: int,
        stride_kv: int,
        padding_q: int,
        padding_kv: int,
        qkv_projection_method: str,
        qkv_bias: bool,
        attention_drop_rate: float,
        drop_rate: float,
        with_cls_token: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attention = TFCvtSelfAttention(
            config,
            num_heads,
            embed_dim,
            kernel_size,
            stride_q,
            stride_kv,
            padding_q,
            padding_kv,
            qkv_projection_method,
            qkv_bias,
            attention_drop_rate,
            with_cls_token,
            name="attention",
        )
        self.dense_output = TFCvtSelfOutput(config, embed_dim, drop_rate, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False):
        self_output = self.attention(hidden_state, height, width, training=training)
        attention_output = self.dense_output(self_output, training=training)
        return attention_output


class TFCvtIntermediate(tf.keras.layers.Layer):
    """Intermediate dense layer. Second chunk of the convolutional transformer block."""

    def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            units=int(embed_dim * mlp_ratio),
            kernel_initializer=get_initializer(config.initializer_range),
            activation="gelu",
            name="dense",
        )

    def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
        hidden_state = self.dense(hidden_state)
        return hidden_state


class TFCvtOutput(tf.keras.layers.Layer):
    """
    Output of the Convolutional Transformer Block (last chunk). It consists of a MLP and a residual connection.
    """

    def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: int, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = tf.keras.layers.Dropout(drop_rate)

    def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_state = self.dense(inputs=hidden_state)
        hidden_state = self.dropout(inputs=hidden_state, training=training)
        hidden_state = hidden_state + input_tensor
        return hidden_state


class TFCvtLayer(tf.keras.layers.Layer):
    """
    Convolutional Transformer Block composed by attention layers, normalization and multi-layer perceptrons (mlps). It
    consists of 3 chunks : an attention layer, an intermediate dense layer and an output layer. This corresponds to the
    `Block` class in the original implementation.
    """

    def __init__(
        self,
        config: CvtConfig,
        num_heads: int,
        embed_dim: int,
        kernel_size: int,
        stride_q: int,
        stride_kv: int,
        padding_q: int,
        padding_kv: int,
        qkv_projection_method: str,
        qkv_bias: bool,
        attention_drop_rate: float,
        drop_rate: float,
        mlp_ratio: float,
        drop_path_rate: float,
        with_cls_token: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attention = TFCvtAttention(
            config,
            num_heads,
            embed_dim,
            kernel_size,
            stride_q,
            stride_kv,
            padding_q,
            padding_kv,
            qkv_projection_method,
            qkv_bias,
            attention_drop_rate,
            drop_rate,
            with_cls_token,
            name="attention",
        )
        self.intermediate = TFCvtIntermediate(config, embed_dim, mlp_ratio, name="intermediate")
        self.dense_output = TFCvtOutput(config, embed_dim, drop_rate, name="output")
        # Using `layers.Activation` instead of `tf.identity` to better control `training` behaviour.
        self.drop_path = (
            TFCvtDropPath(drop_path_rate, name="drop_path")
            if drop_path_rate > 0.0
            else tf.keras.layers.Activation("linear", name="drop_path")
        )
        # Using the same default epsilon as PyTorch
        self.layernorm_before = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before")
        self.layernorm_after = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after")

    def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
        # in Cvt, layernorm is applied before self-attention
        attention_output = self.attention(self.layernorm_before(hidden_state), height, width, training=training)
        attention_output = self.drop_path(attention_output, training=training)

        # first residual connection
        hidden_state = attention_output + hidden_state

        # in Cvt, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_state)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.dense_output(layer_output, hidden_state)
        layer_output = self.drop_path(layer_output, training=training)
        return layer_output


class TFCvtStage(tf.keras.layers.Layer):
    """
    Cvt stage (encoder block). Each stage has 2 parts :
    - (1) A Convolutional Token Embedding layer
    - (2) A Convolutional Transformer Block (layer).
    The classification token is added only in the last stage.

    Args:
        config ([`CvtConfig`]): Model configuration class.
        stage (`int`): Stage number.
    """

    def __init__(self, config: CvtConfig, stage: int, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.stage = stage
        if self.config.cls_token[self.stage]:
            self.cls_token = self.add_weight(
                shape=(1, 1, self.config.embed_dim[-1]),
                initializer=get_initializer(self.config.initializer_range),
                trainable=True,
                name="cvt.encoder.stages.2.cls_token",
            )

        self.embedding = TFCvtEmbeddings(
            self.config,
            patch_size=config.patch_sizes[self.stage],
            stride=config.patch_stride[self.stage],
            embed_dim=config.embed_dim[self.stage],
            padding=config.patch_padding[self.stage],
            dropout_rate=config.drop_rate[self.stage],
            name="embedding",
        )

        drop_path_rates = tf.linspace(0.0, config.drop_path_rate[self.stage], config.depth[stage])
        drop_path_rates = [x.numpy().item() for x in drop_path_rates]
        self.layers = [
            TFCvtLayer(
                config,
                num_heads=config.num_heads[self.stage],
                embed_dim=config.embed_dim[self.stage],
                kernel_size=config.kernel_qkv[self.stage],
                stride_q=config.stride_q[self.stage],
                stride_kv=config.stride_kv[self.stage],
                padding_q=config.padding_q[self.stage],
                padding_kv=config.padding_kv[self.stage],
                qkv_projection_method=config.qkv_projection_method[self.stage],
                qkv_bias=config.qkv_bias[self.stage],
                attention_drop_rate=config.attention_drop_rate[self.stage],
                drop_rate=config.drop_rate[self.stage],
                mlp_ratio=config.mlp_ratio[self.stage],
                drop_path_rate=drop_path_rates[self.stage],
                with_cls_token=config.cls_token[self.stage],
                name=f"layers.{j}",
            )
            for j in range(config.depth[self.stage])
        ]

    def call(self, hidden_state: tf.Tensor, training: bool = False):
        cls_token = None
        hidden_state = self.embedding(hidden_state, training)

        # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
        batch_size, height, width, num_channels = shape_list(hidden_state)
        hidden_size = height * width
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))

        if self.config.cls_token[self.stage]:
            cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
            hidden_state = tf.concat((cls_token, hidden_state), axis=1)

        for layer in self.layers:
            layer_outputs = layer(hidden_state, height, width, training=training)
            hidden_state = layer_outputs

        if self.config.cls_token[self.stage]:
            cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)

        # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))
        return hidden_state, cls_token


class TFCvtEncoder(tf.keras.layers.Layer):
    """
    Convolutional Vision Transformer encoder. CVT has 3 stages of encoder blocks with their respective number of layers
    (depth) being 1, 2 and 10.

    Args:
        config ([`CvtConfig`]): Model configuration class.
    """

    config_class = CvtConfig

    def __init__(self, config: CvtConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.stages = [
            TFCvtStage(config, stage_idx, name=f"stages.{stage_idx}") for stage_idx in range(len(config.depth))
        ]

    def call(
        self,
        pixel_values: TFModelInputType,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
        all_hidden_states = () if output_hidden_states else None
        hidden_state = pixel_values
        # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width)
        # as input format. So change the input format to (batch_size, height, width, num_channels).
        hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1))

        cls_token = None
        for _, (stage_module) in enumerate(self.stages):
            hidden_state, cls_token = stage_module(hidden_state, training=training)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

        # Change back to (batch_size, num_channels, height, width) format to have uniformity in the modules
        hidden_state = tf.transpose(hidden_state, perm=(0, 3, 1, 2))
        if output_hidden_states:
            all_hidden_states = tuple([tf.transpose(hs, perm=(0, 3, 1, 2)) for hs in all_hidden_states])

        if not return_dict:
            return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)

        return TFBaseModelOutputWithCLSToken(
            last_hidden_state=hidden_state,
            cls_token_value=cls_token,
            hidden_states=all_hidden_states,
        )


@keras_serializable
class TFCvtMainLayer(tf.keras.layers.Layer):
    """Construct the Cvt model."""

    config_class = CvtConfig

    def __init__(self, config: CvtConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.encoder = TFCvtEncoder(config, name="encoder")

    @unpack_inputs
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs = self.encoder(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return TFBaseModelOutputWithCLSToken(
            last_hidden_state=sequence_output,
            cls_token_value=encoder_outputs.cls_token_value,
            hidden_states=encoder_outputs.hidden_states,
        )


class TFCvtPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CvtConfig
    base_model_prefix = "cvt"
    main_input_name = "pixel_values"


TFCVT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TF 2.0 models accepts two formats as inputs:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional arguments.

    This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the
    tensors in the first argument of the model call function: `model(inputs)`.

    </Tip>

    Args:
        config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

TFCVT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
            for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False``):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.",
    TFCVT_START_DOCSTRING,
)
class TFCvtModel(TFCvtPreTrainedModel):
    def __init__(self, config: CvtConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.cvt = TFCvtMainLayer(config, name="cvt")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFBaseModelOutputWithCLSToken, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, TFCvtModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
        >>> model = TFCvtModel.from_pretrained("microsoft/cvt-13")

        >>> inputs = image_processor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        outputs = self.cvt(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if not return_dict:
            return (outputs[0],) + outputs[1:]

        return TFBaseModelOutputWithCLSToken(
            last_hidden_state=outputs.last_hidden_state,
            cls_token_value=outputs.cls_token_value,
            hidden_states=outputs.hidden_states,
        )


@add_start_docstrings(
    """
    Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    TFCVT_START_DOCSTRING,
)
class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: CvtConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.cvt = TFCvtMainLayer(config, name="cvt")
        # Using same default epsilon as in the original implementation.
        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm")

        # Classifier head
        self.classifier = tf.keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            use_bias=True,
            bias_initializer="zeros",
            name="classifier",
        )

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        labels: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFImageClassifierOutputWithNoAttention, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, TFCvtForImageClassification
        >>> import tensorflow as tf
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
        >>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13")

        >>> inputs = image_processor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
        >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
        ```"""

        outputs = self.cvt(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]
        cls_token = outputs[1]
        if self.config.cls_token[-1]:
            sequence_output = self.layernorm(cls_token)
        else:
            # rearrange "batch_size, num_channels, height, width -> batch_size, (height*width), num_channels"
            batch_size, num_channels, height, width = shape_list(sequence_output)
            sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width))
            sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1))
            sequence_output = self.layernorm(sequence_output)

        sequence_output_mean = tf.reduce_mean(sequence_output, axis=1)
        logits = self.classifier(sequence_output_mean)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)