|
from transformers import PretrainedConfig |
|
from typing import List |
|
|
|
|
|
class ViTMixConfig(PretrainedConfig):
    """Configuration for the ``VitMix`` model.

    Every named constructor argument is stored unchanged as an attribute of
    the same name (``self.image_size``, ``self.patch_size``, ...). Any other
    keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    The descriptions below of how each value is used are inferred from the
    parameter names. The consuming model is not visible here; confirm
    against it.

    Args:
        image_size: Input image side length (default 28). Must be a positive
            multiple of ``patch_size``. Presumably square images measured in
            pixels; TODO confirm.
        patch_size: Patch side length (default 14). Must be positive.
        num_classes: Presumably the number of output classes (default 10).
        dim: Presumably the embedding / hidden width (default 1024).
        depth: Presumably the number of transformer blocks (default 6).
        heads: Presumably the number of attention heads (default 16).
        mlp_dim: Presumably the feed-forward hidden width (default 2048).
        num_experts: Presumably the number of mixture-of-experts branches
            (default 12).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``image_size`` or ``patch_size`` is not positive, or
            if ``image_size`` is not divisible by ``patch_size``.
    """

    model_type = "VitMix"

    def __init__(
        self,
        image_size: int = 28,
        patch_size: int = 14,
        num_classes: int = 10,
        dim: int = 1024,
        depth: int = 6,
        heads: int = 16,
        mlp_dim: int = 2048,
        num_experts: int = 12,
        **kwargs,
    ):
        # Reject non-positive sizes explicitly. Otherwise patch_size == 0
        # would surface as an unhelpful ZeroDivisionError in the modulo below.
        if image_size <= 0 or patch_size <= 0:
            raise ValueError(
                "image_size and patch_size must be positive "
                f"(got image_size={image_size}, patch_size={patch_size})"
            )
        # Raise instead of printing and continuing. The model presumably
        # splits the image into a grid of non-overlapping patches, which
        # requires the patch size to divide the image size exactly.
        if image_size % patch_size != 0:
            raise ValueError(
                "image_size must be divisible by patch_size "
                f"(got image_size={image_size}, patch_size={patch_size})"
            )

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.num_experts = num_experts

        super().__init__(**kwargs)