KerasHub
Divyasreepat commited on
Commit
d1e0e6a
1 Parent(s): 9bcb07c

Update README.md with new model card content

Browse files
Files changed (1) hide show
  1. README.md +92 -15
README.md CHANGED
@@ -1,18 +1,95 @@
1
  ---
2
  library_name: keras-hub
3
  ---
4
- This is a [`MiT` model](https://keras.io/api/keras_hub/models/mi_t) uploaded using the KerasHub library and can be used with JAX, TensorFlow, and PyTorch backends.
5
- Model config:
6
- * **name:** mi_t_backbone
7
- * **trainable:** True
8
- * **depths:** [3, 6, 40, 3]
9
- * **hidden_dims:** [64, 128, 320, 512]
10
- * **image_shape:** [224, 224, 3]
11
- * **num_layers:** 4
12
- * **blockwise_num_heads:** [1, 2, 5, 8]
13
- * **blockwise_sr_ratios:** [8, 4, 2, 1]
14
- * **max_drop_path_rate:** 0.1
15
- * **patch_sizes:** [7, 3, 3, 3]
16
- * **strides:** [4, 2, 2, 2]
17
-
18
- This model card has been generated automatically and should be completed by the model author. See [Model Cards documentation](https://huggingface.co/docs/hub/model-cards) for more information.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  library_name: keras-hub
3
  ---
4
+ ### Model Overview
5
+ A Keras model implementing the MixTransformer architecture to be used as a backbone for the SegFormer architecture. This model is supported in both KerasCV and KerasHub. KerasCV will no longer be actively developed, so please try to use KerasHub.
6
+
7
+ References:
8
+ - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) # noqa: E501
9
+ - [Based on the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer) # noqa: E501
10
+
11
+ ## Links
12
+ * [MiT Quickstart Notebook: coming soon]()
13
+ * [MiT API Documentation: coming soon]()
14
+
15
+ ## Installation
16
+
17
+ Keras and KerasHub can be installed with:
18
+
19
+ ```
20
+ pip install -U -q keras-Hub
21
+ pip install -U -q keras>=3
22
+ ```
23
+
24
+ Jax, TensorFlow, and Torch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment see the [Keras Getting Started](https://keras.io/getting_started/) page.
25
+
26
+ ## Presets
27
+
28
+ The following model checkpoints are provided by the Keras team. Weights have been ported from https://dl.fbaipublicfiles.com/segment_anything/. Full code examples for each are available below.
29
+ Here's the table formatted similarly to the given pattern:
30
+
31
+ Here's the updated table with the input resolutions included in the descriptions:
32
+
33
+ | Preset name | Parameters | Description |
34
+ |--------------------------|------------|--------------------------------------------------------------------------------------------------|
35
+ | mit_b0_ade20k_512 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
36
+ | mit_b1_ade20k_512 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
37
+ | mit_b2_ade20k_512 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
38
+ | mit_b3_ade20k_512 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
39
+ | mit_b4_ade20k_512 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
40
+ | mit_b5_ade20k_640 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks, trained on the ADE20K dataset with an input resolution of 640x640 pixels. |
41
+ | mit_b0_cityscapes_1024 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
42
+ | mit_b1_cityscapes_1024 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
43
+ | mit_b2_cityscapes_1024 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
44
+ | mit_b3_cityscapes_1024 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
45
+ | mit_b4_cityscapes_1024 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
46
+ | mit_b5_cityscapes_1024 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
47
+
48
+ ### Example Usage
49
+ Using the class with a `backbone`:
50
+
51
+ ```
52
+ import tensorflow as tf
53
+ import keras_cv
54
+ import numpy as np
55
+
56
+ images = np.ones(shape=(1, 96, 96, 3))
57
+ labels = np.zeros(shape=(1, 96, 96, 1))
58
+ backbone = keras_cv.models.MiTBackbone.from_preset("mit_b5_ade20k_640")
59
+
60
+ # Evaluate model
61
+ model(images)
62
+
63
+ # Train model
64
+ model.compile(
65
+ optimizer="adam",
66
+ loss=keras.losses.BinaryCrossentropy(from_logits=False),
67
+ metrics=["accuracy"],
68
+ )
69
+ model.fit(images, labels, epochs=3)
70
+ ```
71
+
72
+ ## Example Usage with Hugging Face URI
73
+
74
+ Using the class with a `backbone`:
75
+
76
+ ```
77
+ import tensorflow as tf
78
+ import keras_cv
79
+ import numpy as np
80
+
81
+ images = np.ones(shape=(1, 96, 96, 3))
82
+ labels = np.zeros(shape=(1, 96, 96, 1))
83
+ backbone = keras_cv.models.MiTBackbone.from_preset("hf://keras/mit_b5_ade20k_640")
84
+
85
+ # Evaluate model
86
+ model(images)
87
+
88
+ # Train model
89
+ model.compile(
90
+ optimizer="adam",
91
+ loss=keras.losses.BinaryCrossentropy(from_logits=False),
92
+ metrics=["accuracy"],
93
+ )
94
+ model.fit(images, labels, epochs=3)
95
+ ```