""" Auto Model class.""" |
|
|
|
import warnings |
|
from collections import OrderedDict |
|
|
|
from ...utils import logging |
|
from .auto_factory import _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update |
|
from .configuration_auto import CONFIG_MAPPING_NAMES |
|
|
|
|
|
logger = logging.get_logger(__name__) |
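
# Each *_MAPPING_NAMES OrderedDict below pairs a model type key (the same key used in
# CONFIG_MAPPING_NAMES) with the name of the class implementing that architecture or task
# head, e.g. MODEL_MAPPING_NAMES["bert"] == "BertModel". The strings are only resolved to
# real classes on demand, through the _LazyAutoMapping instances built near the end of this
# module, so importing this file does not import every model implementation.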
|
|
|
|
|
MODEL_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("albert", "AlbertModel"), |
|
("align", "AlignModel"), |
|
("altclip", "AltCLIPModel"), |
|
("audio-spectrogram-transformer", "ASTModel"), |
|
("autoformer", "AutoformerModel"), |
|
("bark", "BarkModel"), |
|
("bart", "BartModel"), |
|
("beit", "BeitModel"), |
|
("bert", "BertModel"), |
|
("bert-generation", "BertGenerationEncoder"), |
|
("big_bird", "BigBirdModel"), |
|
("bigbird_pegasus", "BigBirdPegasusModel"), |
|
("biogpt", "BioGptModel"), |
|
("bit", "BitModel"), |
|
("blenderbot", "BlenderbotModel"), |
|
("blenderbot-small", "BlenderbotSmallModel"), |
|
("blip", "BlipModel"), |
|
("blip-2", "Blip2Model"), |
|
("bloom", "BloomModel"), |
|
("bridgetower", "BridgeTowerModel"), |
|
("bros", "BrosModel"), |
|
("camembert", "CamembertModel"), |
|
("canine", "CanineModel"), |
|
("chinese_clip", "ChineseCLIPModel"), |
|
("clap", "ClapModel"), |
|
("clip", "CLIPModel"), |
|
("clipseg", "CLIPSegModel"), |
|
("code_llama", "LlamaModel"), |
|
("codegen", "CodeGenModel"), |
|
("conditional_detr", "ConditionalDetrModel"), |
|
("convbert", "ConvBertModel"), |
|
("convnext", "ConvNextModel"), |
|
("convnextv2", "ConvNextV2Model"), |
|
("cpmant", "CpmAntModel"), |
|
("ctrl", "CTRLModel"), |
|
("cvt", "CvtModel"), |
|
("data2vec-audio", "Data2VecAudioModel"), |
|
("data2vec-text", "Data2VecTextModel"), |
|
("data2vec-vision", "Data2VecVisionModel"), |
|
("deberta", "DebertaModel"), |
|
("deberta-v2", "DebertaV2Model"), |
|
("decision_transformer", "DecisionTransformerModel"), |
|
("deformable_detr", "DeformableDetrModel"), |
|
("deit", "DeiTModel"), |
|
("deta", "DetaModel"), |
|
("detr", "DetrModel"), |
|
("dinat", "DinatModel"), |
|
("dinov2", "Dinov2Model"), |
|
("distilbert", "DistilBertModel"), |
|
("donut-swin", "DonutSwinModel"), |
|
("dpr", "DPRQuestionEncoder"), |
|
("dpt", "DPTModel"), |
|
("efficientformer", "EfficientFormerModel"), |
|
("efficientnet", "EfficientNetModel"), |
|
("electra", "ElectraModel"), |
|
("encodec", "EncodecModel"), |
|
("ernie", "ErnieModel"), |
|
("ernie_m", "ErnieMModel"), |
|
("esm", "EsmModel"), |
|
("falcon", "FalconModel"), |
|
("flaubert", "FlaubertModel"), |
|
("flava", "FlavaModel"), |
|
("fnet", "FNetModel"), |
|
("focalnet", "FocalNetModel"), |
|
("fsmt", "FSMTModel"), |
|
("funnel", ("FunnelModel", "FunnelBaseModel")), |
|
("git", "GitModel"), |
|
("glpn", "GLPNModel"), |
|
("gpt-sw3", "GPT2Model"), |
|
("gpt2", "GPT2Model"), |
|
("gpt_bigcode", "GPTBigCodeModel"), |
|
("gpt_neo", "GPTNeoModel"), |
|
("gpt_neox", "GPTNeoXModel"), |
|
("gpt_neox_japanese", "GPTNeoXJapaneseModel"), |
|
("gptj", "GPTJModel"), |
|
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), |
|
("graphormer", "GraphormerModel"), |
|
("groupvit", "GroupViTModel"), |
|
("hubert", "HubertModel"), |
|
("ibert", "IBertModel"), |
|
("idefics", "IdeficsModel"), |
|
("imagegpt", "ImageGPTModel"), |
|
("informer", "InformerModel"), |
|
("jukebox", "JukeboxModel"), |
|
("layoutlm", "LayoutLMModel"), |
|
("layoutlmv2", "LayoutLMv2Model"), |
|
("layoutlmv3", "LayoutLMv3Model"), |
|
("led", "LEDModel"), |
|
("levit", "LevitModel"), |
|
("lilt", "LiltModel"), |
|
("llama", "LlamaModel"), |
|
("longformer", "LongformerModel"), |
|
("longt5", "LongT5Model"), |
|
("luke", "LukeModel"), |
|
("lxmert", "LxmertModel"), |
|
("m2m_100", "M2M100Model"), |
|
("marian", "MarianModel"), |
|
("markuplm", "MarkupLMModel"), |
|
("mask2former", "Mask2FormerModel"), |
|
("maskformer", "MaskFormerModel"), |
|
("maskformer-swin", "MaskFormerSwinModel"), |
|
("mbart", "MBartModel"), |
|
("mctct", "MCTCTModel"), |
|
("mega", "MegaModel"), |
|
("megatron-bert", "MegatronBertModel"), |
|
("mgp-str", "MgpstrForSceneTextRecognition"), |
|
("mistral", "MistralModel"), |
|
("mobilebert", "MobileBertModel"), |
|
("mobilenet_v1", "MobileNetV1Model"), |
|
("mobilenet_v2", "MobileNetV2Model"), |
|
("mobilevit", "MobileViTModel"), |
|
("mobilevitv2", "MobileViTV2Model"), |
|
("mpnet", "MPNetModel"), |
|
("mpt", "MptModel"), |
|
("mra", "MraModel"), |
|
("mt5", "MT5Model"), |
|
("mvp", "MvpModel"), |
|
("nat", "NatModel"), |
|
("nezha", "NezhaModel"), |
|
("nllb-moe", "NllbMoeModel"), |
|
("nystromformer", "NystromformerModel"), |
|
("oneformer", "OneFormerModel"), |
|
("open-llama", "OpenLlamaModel"), |
|
("openai-gpt", "OpenAIGPTModel"), |
|
("opt", "OPTModel"), |
|
("owlvit", "OwlViTModel"), |
|
("pegasus", "PegasusModel"), |
|
("pegasus_x", "PegasusXModel"), |
|
("perceiver", "PerceiverModel"), |
|
("persimmon", "PersimmonModel"), |
|
("plbart", "PLBartModel"), |
|
("poolformer", "PoolFormerModel"), |
|
("prophetnet", "ProphetNetModel"), |
|
("pvt", "PvtModel"), |
|
("qdqbert", "QDQBertModel"), |
|
("reformer", "ReformerModel"), |
|
("regnet", "RegNetModel"), |
|
("rembert", "RemBertModel"), |
|
("resnet", "ResNetModel"), |
|
("retribert", "RetriBertModel"), |
|
("roberta", "RobertaModel"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormModel"), |
|
("roc_bert", "RoCBertModel"), |
|
("roformer", "RoFormerModel"), |
|
("rwkv", "RwkvModel"), |
|
("sam", "SamModel"), |
|
("segformer", "SegformerModel"), |
|
("sew", "SEWModel"), |
|
("sew-d", "SEWDModel"), |
|
("speech_to_text", "Speech2TextModel"), |
|
("speecht5", "SpeechT5Model"), |
|
("splinter", "SplinterModel"), |
|
("squeezebert", "SqueezeBertModel"), |
|
("swiftformer", "SwiftFormerModel"), |
|
("swin", "SwinModel"), |
|
("swin2sr", "Swin2SRModel"), |
|
("swinv2", "Swinv2Model"), |
|
("switch_transformers", "SwitchTransformersModel"), |
|
("t5", "T5Model"), |
|
("table-transformer", "TableTransformerModel"), |
|
("tapas", "TapasModel"), |
|
("time_series_transformer", "TimeSeriesTransformerModel"), |
|
("timesformer", "TimesformerModel"), |
|
("timm_backbone", "TimmBackbone"), |
|
("trajectory_transformer", "TrajectoryTransformerModel"), |
|
("transfo-xl", "TransfoXLModel"), |
|
("tvlt", "TvltModel"), |
|
("umt5", "UMT5Model"), |
|
("unispeech", "UniSpeechModel"), |
|
("unispeech-sat", "UniSpeechSatModel"), |
|
("van", "VanModel"), |
|
("videomae", "VideoMAEModel"), |
|
("vilt", "ViltModel"), |
|
("vision-text-dual-encoder", "VisionTextDualEncoderModel"), |
|
("visual_bert", "VisualBertModel"), |
|
("vit", "ViTModel"), |
|
("vit_hybrid", "ViTHybridModel"), |
|
("vit_mae", "ViTMAEModel"), |
|
("vit_msn", "ViTMSNModel"), |
|
("vitdet", "VitDetModel"), |
|
("vits", "VitsModel"), |
|
("vivit", "VivitModel"), |
|
("wav2vec2", "Wav2Vec2Model"), |
|
("wav2vec2-conformer", "Wav2Vec2ConformerModel"), |
|
("wavlm", "WavLMModel"), |
|
("whisper", "WhisperModel"), |
|
("xclip", "XCLIPModel"), |
|
("xglm", "XGLMModel"), |
|
("xlm", "XLMModel"), |
|
("xlm-prophetnet", "XLMProphetNetModel"), |
|
("xlm-roberta", "XLMRobertaModel"), |
|
("xlm-roberta-xl", "XLMRobertaXLModel"), |
|
("xlnet", "XLNetModel"), |
|
("xmod", "XmodModel"), |
|
("yolos", "YolosModel"), |
|
("yoso", "YosoModel"), |
|
] |
|
) |
|
|
|
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("albert", "AlbertForPreTraining"), |
|
("bart", "BartForConditionalGeneration"), |
|
("bert", "BertForPreTraining"), |
|
("big_bird", "BigBirdForPreTraining"), |
|
("bloom", "BloomForCausalLM"), |
|
("camembert", "CamembertForMaskedLM"), |
|
("ctrl", "CTRLLMHeadModel"), |
|
("data2vec-text", "Data2VecTextForMaskedLM"), |
|
("deberta", "DebertaForMaskedLM"), |
|
("deberta-v2", "DebertaV2ForMaskedLM"), |
|
("distilbert", "DistilBertForMaskedLM"), |
|
("electra", "ElectraForPreTraining"), |
|
("ernie", "ErnieForPreTraining"), |
|
("flaubert", "FlaubertWithLMHeadModel"), |
|
("flava", "FlavaForPreTraining"), |
|
("fnet", "FNetForPreTraining"), |
|
("fsmt", "FSMTForConditionalGeneration"), |
|
("funnel", "FunnelForPreTraining"), |
|
("gpt-sw3", "GPT2LMHeadModel"), |
|
("gpt2", "GPT2LMHeadModel"), |
|
("gpt_bigcode", "GPTBigCodeForCausalLM"), |
|
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), |
|
("ibert", "IBertForMaskedLM"), |
|
("idefics", "IdeficsForVisionText2Text"), |
|
("layoutlm", "LayoutLMForMaskedLM"), |
|
("longformer", "LongformerForMaskedLM"), |
|
("luke", "LukeForMaskedLM"), |
|
("lxmert", "LxmertForPreTraining"), |
|
("mega", "MegaForMaskedLM"), |
|
("megatron-bert", "MegatronBertForPreTraining"), |
|
("mobilebert", "MobileBertForPreTraining"), |
|
("mpnet", "MPNetForMaskedLM"), |
|
("mpt", "MptForCausalLM"), |
|
("mra", "MraForMaskedLM"), |
|
("mvp", "MvpForConditionalGeneration"), |
|
("nezha", "NezhaForPreTraining"), |
|
("nllb-moe", "NllbMoeForConditionalGeneration"), |
|
("openai-gpt", "OpenAIGPTLMHeadModel"), |
|
("retribert", "RetriBertModel"), |
|
("roberta", "RobertaForMaskedLM"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), |
|
("roc_bert", "RoCBertForPreTraining"), |
|
("rwkv", "RwkvForCausalLM"), |
|
("splinter", "SplinterForPreTraining"), |
|
("squeezebert", "SqueezeBertForMaskedLM"), |
|
("switch_transformers", "SwitchTransformersForConditionalGeneration"), |
|
("t5", "T5ForConditionalGeneration"), |
|
("tapas", "TapasForMaskedLM"), |
|
("transfo-xl", "TransfoXLLMHeadModel"), |
|
("tvlt", "TvltForPreTraining"), |
|
("unispeech", "UniSpeechForPreTraining"), |
|
("unispeech-sat", "UniSpeechSatForPreTraining"), |
|
("videomae", "VideoMAEForPreTraining"), |
|
("visual_bert", "VisualBertForPreTraining"), |
|
("vit_mae", "ViTMAEForPreTraining"), |
|
("wav2vec2", "Wav2Vec2ForPreTraining"), |
|
("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"), |
|
("xlm", "XLMWithLMHeadModel"), |
|
("xlm-roberta", "XLMRobertaForMaskedLM"), |
|
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), |
|
("xlnet", "XLNetLMHeadModel"), |
|
("xmod", "XmodForMaskedLM"), |
|
] |
|
) |
|
|
|
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("albert", "AlbertForMaskedLM"), |
|
("bart", "BartForConditionalGeneration"), |
|
("bert", "BertForMaskedLM"), |
|
("big_bird", "BigBirdForMaskedLM"), |
|
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), |
|
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), |
|
("bloom", "BloomForCausalLM"), |
|
("camembert", "CamembertForMaskedLM"), |
|
("codegen", "CodeGenForCausalLM"), |
|
("convbert", "ConvBertForMaskedLM"), |
|
("cpmant", "CpmAntForCausalLM"), |
|
("ctrl", "CTRLLMHeadModel"), |
|
("data2vec-text", "Data2VecTextForMaskedLM"), |
|
("deberta", "DebertaForMaskedLM"), |
|
("deberta-v2", "DebertaV2ForMaskedLM"), |
|
("distilbert", "DistilBertForMaskedLM"), |
|
("electra", "ElectraForMaskedLM"), |
|
("encoder-decoder", "EncoderDecoderModel"), |
|
("ernie", "ErnieForMaskedLM"), |
|
("esm", "EsmForMaskedLM"), |
|
("flaubert", "FlaubertWithLMHeadModel"), |
|
("fnet", "FNetForMaskedLM"), |
|
("fsmt", "FSMTForConditionalGeneration"), |
|
("funnel", "FunnelForMaskedLM"), |
|
("git", "GitForCausalLM"), |
|
("gpt-sw3", "GPT2LMHeadModel"), |
|
("gpt2", "GPT2LMHeadModel"), |
|
("gpt_bigcode", "GPTBigCodeForCausalLM"), |
|
("gpt_neo", "GPTNeoForCausalLM"), |
|
("gpt_neox", "GPTNeoXForCausalLM"), |
|
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), |
|
("gptj", "GPTJForCausalLM"), |
|
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), |
|
("ibert", "IBertForMaskedLM"), |
|
("layoutlm", "LayoutLMForMaskedLM"), |
|
("led", "LEDForConditionalGeneration"), |
|
("longformer", "LongformerForMaskedLM"), |
|
("longt5", "LongT5ForConditionalGeneration"), |
|
("luke", "LukeForMaskedLM"), |
|
("m2m_100", "M2M100ForConditionalGeneration"), |
|
("marian", "MarianMTModel"), |
|
("mega", "MegaForMaskedLM"), |
|
("megatron-bert", "MegatronBertForCausalLM"), |
|
("mobilebert", "MobileBertForMaskedLM"), |
|
("mpnet", "MPNetForMaskedLM"), |
|
("mpt", "MptForCausalLM"), |
|
("mra", "MraForMaskedLM"), |
|
("mvp", "MvpForConditionalGeneration"), |
|
("nezha", "NezhaForMaskedLM"), |
|
("nllb-moe", "NllbMoeForConditionalGeneration"), |
|
("nystromformer", "NystromformerForMaskedLM"), |
|
("openai-gpt", "OpenAIGPTLMHeadModel"), |
|
("pegasus_x", "PegasusXForConditionalGeneration"), |
|
("plbart", "PLBartForConditionalGeneration"), |
|
("pop2piano", "Pop2PianoForConditionalGeneration"), |
|
("qdqbert", "QDQBertForMaskedLM"), |
|
("reformer", "ReformerModelWithLMHead"), |
|
("rembert", "RemBertForMaskedLM"), |
|
("roberta", "RobertaForMaskedLM"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), |
|
("roc_bert", "RoCBertForMaskedLM"), |
|
("roformer", "RoFormerForMaskedLM"), |
|
("rwkv", "RwkvForCausalLM"), |
|
("speech_to_text", "Speech2TextForConditionalGeneration"), |
|
("squeezebert", "SqueezeBertForMaskedLM"), |
|
("switch_transformers", "SwitchTransformersForConditionalGeneration"), |
|
("t5", "T5ForConditionalGeneration"), |
|
("tapas", "TapasForMaskedLM"), |
|
("transfo-xl", "TransfoXLLMHeadModel"), |
|
("wav2vec2", "Wav2Vec2ForMaskedLM"), |
|
("whisper", "WhisperForConditionalGeneration"), |
|
("xlm", "XLMWithLMHeadModel"), |
|
("xlm-roberta", "XLMRobertaForMaskedLM"), |
|
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), |
|
("xlnet", "XLNetLMHeadModel"), |
|
("xmod", "XmodForMaskedLM"), |
|
("yoso", "YosoForMaskedLM"), |
|
] |
|
) |
|
|
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("bart", "BartForCausalLM"), |
|
("bert", "BertLMHeadModel"), |
|
("bert-generation", "BertGenerationDecoder"), |
|
("big_bird", "BigBirdForCausalLM"), |
|
("bigbird_pegasus", "BigBirdPegasusForCausalLM"), |
|
("biogpt", "BioGptForCausalLM"), |
|
("blenderbot", "BlenderbotForCausalLM"), |
|
("blenderbot-small", "BlenderbotSmallForCausalLM"), |
|
("bloom", "BloomForCausalLM"), |
|
("camembert", "CamembertForCausalLM"), |
|
("code_llama", "LlamaForCausalLM"), |
|
("codegen", "CodeGenForCausalLM"), |
|
("cpmant", "CpmAntForCausalLM"), |
|
("ctrl", "CTRLLMHeadModel"), |
|
("data2vec-text", "Data2VecTextForCausalLM"), |
|
("electra", "ElectraForCausalLM"), |
|
("ernie", "ErnieForCausalLM"), |
|
("falcon", "FalconForCausalLM"), |
|
("git", "GitForCausalLM"), |
|
("gpt-sw3", "GPT2LMHeadModel"), |
|
("gpt2", "GPT2LMHeadModel"), |
|
("gpt_bigcode", "GPTBigCodeForCausalLM"), |
|
("gpt_neo", "GPTNeoForCausalLM"), |
|
("gpt_neox", "GPTNeoXForCausalLM"), |
|
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), |
|
("gptj", "GPTJForCausalLM"), |
|
("llama", "LlamaForCausalLM"), |
|
("marian", "MarianForCausalLM"), |
|
("mbart", "MBartForCausalLM"), |
|
("mega", "MegaForCausalLM"), |
|
("megatron-bert", "MegatronBertForCausalLM"), |
|
("mistral", "MistralForCausalLM"), |
|
("mpt", "MptForCausalLM"), |
|
("musicgen", "MusicgenForCausalLM"), |
|
("mvp", "MvpForCausalLM"), |
|
("open-llama", "OpenLlamaForCausalLM"), |
|
("openai-gpt", "OpenAIGPTLMHeadModel"), |
|
("opt", "OPTForCausalLM"), |
|
("pegasus", "PegasusForCausalLM"), |
|
("persimmon", "PersimmonForCausalLM"), |
|
("plbart", "PLBartForCausalLM"), |
|
("prophetnet", "ProphetNetForCausalLM"), |
|
("qdqbert", "QDQBertLMHeadModel"), |
|
("reformer", "ReformerModelWithLMHead"), |
|
("rembert", "RemBertForCausalLM"), |
|
("roberta", "RobertaForCausalLM"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"), |
|
("roc_bert", "RoCBertForCausalLM"), |
|
("roformer", "RoFormerForCausalLM"), |
|
("rwkv", "RwkvForCausalLM"), |
|
("speech_to_text_2", "Speech2Text2ForCausalLM"), |
|
("transfo-xl", "TransfoXLLMHeadModel"), |
|
("trocr", "TrOCRForCausalLM"), |
|
("xglm", "XGLMForCausalLM"), |
|
("xlm", "XLMWithLMHeadModel"), |
|
("xlm-prophetnet", "XLMProphetNetForCausalLM"), |
|
("xlm-roberta", "XLMRobertaForCausalLM"), |
|
("xlm-roberta-xl", "XLMRobertaXLForCausalLM"), |
|
("xlnet", "XLNetLMHeadModel"), |
|
("xmod", "XmodForCausalLM"), |
|
] |
|
) |
|
|
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("deit", "DeiTForMaskedImageModeling"), |
|
("focalnet", "FocalNetForMaskedImageModeling"), |
|
("swin", "SwinForMaskedImageModeling"), |
|
("swinv2", "Swinv2ForMaskedImageModeling"), |
|
("vit", "ViTForMaskedImageModeling"), |
|
] |
|
) |
|
|
|
|
|
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
("imagegpt", "ImageGPTForCausalImageModeling"), |
|
] |
|
) |
|
|
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("beit", "BeitForImageClassification"), |
|
("bit", "BitForImageClassification"), |
|
("convnext", "ConvNextForImageClassification"), |
|
("convnextv2", "ConvNextV2ForImageClassification"), |
|
("cvt", "CvtForImageClassification"), |
|
("data2vec-vision", "Data2VecVisionForImageClassification"), |
|
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), |
|
("dinat", "DinatForImageClassification"), |
|
("dinov2", "Dinov2ForImageClassification"), |
|
( |
|
"efficientformer", |
|
( |
|
"EfficientFormerForImageClassification", |
|
"EfficientFormerForImageClassificationWithTeacher", |
|
), |
|
), |
|
("efficientnet", "EfficientNetForImageClassification"), |
|
("focalnet", "FocalNetForImageClassification"), |
|
("imagegpt", "ImageGPTForImageClassification"), |
|
("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")), |
|
("mobilenet_v1", "MobileNetV1ForImageClassification"), |
|
("mobilenet_v2", "MobileNetV2ForImageClassification"), |
|
("mobilevit", "MobileViTForImageClassification"), |
|
("mobilevitv2", "MobileViTV2ForImageClassification"), |
|
("nat", "NatForImageClassification"), |
|
( |
|
"perceiver", |
|
( |
|
"PerceiverForImageClassificationLearned", |
|
"PerceiverForImageClassificationFourier", |
|
"PerceiverForImageClassificationConvProcessing", |
|
), |
|
), |
|
("poolformer", "PoolFormerForImageClassification"), |
|
("pvt", "PvtForImageClassification"), |
|
("regnet", "RegNetForImageClassification"), |
|
("resnet", "ResNetForImageClassification"), |
|
("segformer", "SegformerForImageClassification"), |
|
("swiftformer", "SwiftFormerForImageClassification"), |
|
("swin", "SwinForImageClassification"), |
|
("swinv2", "Swinv2ForImageClassification"), |
|
("van", "VanForImageClassification"), |
|
("vit", "ViTForImageClassification"), |
|
("vit_hybrid", "ViTHybridForImageClassification"), |
|
("vit_msn", "ViTMSNForImageClassification"), |
|
] |
|
) |
|
|
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
|
|
("detr", "DetrForSegmentation"), |
|
] |
|
) |
|
|
|
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("beit", "BeitForSemanticSegmentation"), |
|
("data2vec-vision", "Data2VecVisionForSemanticSegmentation"), |
|
("dpt", "DPTForSemanticSegmentation"), |
|
("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"), |
|
("mobilevit", "MobileViTForSemanticSegmentation"), |
|
("mobilevitv2", "MobileViTV2ForSemanticSegmentation"), |
|
("segformer", "SegformerForSemanticSegmentation"), |
|
("upernet", "UperNetForSemanticSegmentation"), |
|
] |
|
) |
|
|
|
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
|
|
("maskformer", "MaskFormerForInstanceSegmentation"), |
|
] |
|
) |
|
|
|
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("detr", "DetrForSegmentation"), |
|
("mask2former", "Mask2FormerForUniversalSegmentation"), |
|
("maskformer", "MaskFormerForInstanceSegmentation"), |
|
("oneformer", "OneFormerForUniversalSegmentation"), |
|
] |
|
) |
|
|
|
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("timesformer", "TimesformerForVideoClassification"), |
|
("videomae", "VideoMAEForVideoClassification"), |
|
("vivit", "VivitForVideoClassification"), |
|
] |
|
) |
|
|
|
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("blip", "BlipForConditionalGeneration"), |
|
("blip-2", "Blip2ForConditionalGeneration"), |
|
("git", "GitForCausalLM"), |
|
("instructblip", "InstructBlipForConditionalGeneration"), |
|
("pix2struct", "Pix2StructForConditionalGeneration"), |
|
("vision-encoder-decoder", "VisionEncoderDecoderModel"), |
|
] |
|
) |
|
|
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("albert", "AlbertForMaskedLM"), |
|
("bart", "BartForConditionalGeneration"), |
|
("bert", "BertForMaskedLM"), |
|
("big_bird", "BigBirdForMaskedLM"), |
|
("camembert", "CamembertForMaskedLM"), |
|
("convbert", "ConvBertForMaskedLM"), |
|
("data2vec-text", "Data2VecTextForMaskedLM"), |
|
("deberta", "DebertaForMaskedLM"), |
|
("deberta-v2", "DebertaV2ForMaskedLM"), |
|
("distilbert", "DistilBertForMaskedLM"), |
|
("electra", "ElectraForMaskedLM"), |
|
("ernie", "ErnieForMaskedLM"), |
|
("esm", "EsmForMaskedLM"), |
|
("flaubert", "FlaubertWithLMHeadModel"), |
|
("fnet", "FNetForMaskedLM"), |
|
("funnel", "FunnelForMaskedLM"), |
|
("ibert", "IBertForMaskedLM"), |
|
("layoutlm", "LayoutLMForMaskedLM"), |
|
("longformer", "LongformerForMaskedLM"), |
|
("luke", "LukeForMaskedLM"), |
|
("mbart", "MBartForConditionalGeneration"), |
|
("mega", "MegaForMaskedLM"), |
|
("megatron-bert", "MegatronBertForMaskedLM"), |
|
("mobilebert", "MobileBertForMaskedLM"), |
|
("mpnet", "MPNetForMaskedLM"), |
|
("mra", "MraForMaskedLM"), |
|
("mvp", "MvpForConditionalGeneration"), |
|
("nezha", "NezhaForMaskedLM"), |
|
("nystromformer", "NystromformerForMaskedLM"), |
|
("perceiver", "PerceiverForMaskedLM"), |
|
("qdqbert", "QDQBertForMaskedLM"), |
|
("reformer", "ReformerForMaskedLM"), |
|
("rembert", "RemBertForMaskedLM"), |
|
("roberta", "RobertaForMaskedLM"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), |
|
("roc_bert", "RoCBertForMaskedLM"), |
|
("roformer", "RoFormerForMaskedLM"), |
|
("squeezebert", "SqueezeBertForMaskedLM"), |
|
("tapas", "TapasForMaskedLM"), |
|
("wav2vec2", "Wav2Vec2ForMaskedLM"), |
|
("xlm", "XLMWithLMHeadModel"), |
|
("xlm-roberta", "XLMRobertaForMaskedLM"), |
|
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), |
|
("xmod", "XmodForMaskedLM"), |
|
("yoso", "YosoForMaskedLM"), |
|
] |
|
) |
|
|
|
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("conditional_detr", "ConditionalDetrForObjectDetection"), |
|
("deformable_detr", "DeformableDetrForObjectDetection"), |
|
("deta", "DetaForObjectDetection"), |
|
("detr", "DetrForObjectDetection"), |
|
("table-transformer", "TableTransformerForObjectDetection"), |
|
("yolos", "YolosForObjectDetection"), |
|
] |
|
) |
|
|
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("owlvit", "OwlViTForObjectDetection") |
|
] |
|
) |
|
|
|
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("dpt", "DPTForDepthEstimation"), |
|
("glpn", "GLPNForDepthEstimation"), |
|
] |
|
) |
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("bart", "BartForConditionalGeneration"), |
|
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), |
|
("blenderbot", "BlenderbotForConditionalGeneration"), |
|
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), |
|
("encoder-decoder", "EncoderDecoderModel"), |
|
("fsmt", "FSMTForConditionalGeneration"), |
|
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), |
|
("led", "LEDForConditionalGeneration"), |
|
("longt5", "LongT5ForConditionalGeneration"), |
|
("m2m_100", "M2M100ForConditionalGeneration"), |
|
("marian", "MarianMTModel"), |
|
("mbart", "MBartForConditionalGeneration"), |
|
("mt5", "MT5ForConditionalGeneration"), |
|
("mvp", "MvpForConditionalGeneration"), |
|
("nllb-moe", "NllbMoeForConditionalGeneration"), |
|
("pegasus", "PegasusForConditionalGeneration"), |
|
("pegasus_x", "PegasusXForConditionalGeneration"), |
|
("plbart", "PLBartForConditionalGeneration"), |
|
("prophetnet", "ProphetNetForConditionalGeneration"), |
|
("switch_transformers", "SwitchTransformersForConditionalGeneration"), |
|
("t5", "T5ForConditionalGeneration"), |
|
("umt5", "UMT5ForConditionalGeneration"), |
|
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), |
|
] |
|
) |
|
|
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("pop2piano", "Pop2PianoForConditionalGeneration"), |
|
("speech-encoder-decoder", "SpeechEncoderDecoderModel"), |
|
("speech_to_text", "Speech2TextForConditionalGeneration"), |
|
("speecht5", "SpeechT5ForSpeechToText"), |
|
("whisper", "WhisperForConditionalGeneration"), |
|
] |
|
) |
|
|
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("albert", "AlbertForSequenceClassification"), |
|
("bart", "BartForSequenceClassification"), |
|
("bert", "BertForSequenceClassification"), |
|
("big_bird", "BigBirdForSequenceClassification"), |
|
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), |
|
("biogpt", "BioGptForSequenceClassification"), |
|
("bloom", "BloomForSequenceClassification"), |
|
("camembert", "CamembertForSequenceClassification"), |
|
("canine", "CanineForSequenceClassification"), |
|
("code_llama", "LlamaForSequenceClassification"), |
|
("convbert", "ConvBertForSequenceClassification"), |
|
("ctrl", "CTRLForSequenceClassification"), |
|
("data2vec-text", "Data2VecTextForSequenceClassification"), |
|
("deberta", "DebertaForSequenceClassification"), |
|
("deberta-v2", "DebertaV2ForSequenceClassification"), |
|
("distilbert", "DistilBertForSequenceClassification"), |
|
("electra", "ElectraForSequenceClassification"), |
|
("ernie", "ErnieForSequenceClassification"), |
|
("ernie_m", "ErnieMForSequenceClassification"), |
|
("esm", "EsmForSequenceClassification"), |
|
("falcon", "FalconForSequenceClassification"), |
|
("flaubert", "FlaubertForSequenceClassification"), |
|
("fnet", "FNetForSequenceClassification"), |
|
("funnel", "FunnelForSequenceClassification"), |
|
("gpt-sw3", "GPT2ForSequenceClassification"), |
|
("gpt2", "GPT2ForSequenceClassification"), |
|
("gpt_bigcode", "GPTBigCodeForSequenceClassification"), |
|
("gpt_neo", "GPTNeoForSequenceClassification"), |
|
("gpt_neox", "GPTNeoXForSequenceClassification"), |
|
("gptj", "GPTJForSequenceClassification"), |
|
("ibert", "IBertForSequenceClassification"), |
|
("layoutlm", "LayoutLMForSequenceClassification"), |
|
("layoutlmv2", "LayoutLMv2ForSequenceClassification"), |
|
("layoutlmv3", "LayoutLMv3ForSequenceClassification"), |
|
("led", "LEDForSequenceClassification"), |
|
("lilt", "LiltForSequenceClassification"), |
|
("llama", "LlamaForSequenceClassification"), |
|
("longformer", "LongformerForSequenceClassification"), |
|
("luke", "LukeForSequenceClassification"), |
|
("markuplm", "MarkupLMForSequenceClassification"), |
|
("mbart", "MBartForSequenceClassification"), |
|
("mega", "MegaForSequenceClassification"), |
|
("megatron-bert", "MegatronBertForSequenceClassification"), |
|
("mistral", "MistralForSequenceClassification"), |
|
("mobilebert", "MobileBertForSequenceClassification"), |
|
("mpnet", "MPNetForSequenceClassification"), |
|
("mpt", "MptForSequenceClassification"), |
|
("mra", "MraForSequenceClassification"), |
|
("mt5", "MT5ForSequenceClassification"), |
|
("mvp", "MvpForSequenceClassification"), |
|
("nezha", "NezhaForSequenceClassification"), |
|
("nystromformer", "NystromformerForSequenceClassification"), |
|
("open-llama", "OpenLlamaForSequenceClassification"), |
|
("openai-gpt", "OpenAIGPTForSequenceClassification"), |
|
("opt", "OPTForSequenceClassification"), |
|
("perceiver", "PerceiverForSequenceClassification"), |
|
("persimmon", "PersimmonForSequenceClassification"), |
|
("plbart", "PLBartForSequenceClassification"), |
|
("qdqbert", "QDQBertForSequenceClassification"), |
|
("reformer", "ReformerForSequenceClassification"), |
|
("rembert", "RemBertForSequenceClassification"), |
|
("roberta", "RobertaForSequenceClassification"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"), |
|
("roc_bert", "RoCBertForSequenceClassification"), |
|
("roformer", "RoFormerForSequenceClassification"), |
|
("squeezebert", "SqueezeBertForSequenceClassification"), |
|
("t5", "T5ForSequenceClassification"), |
|
("tapas", "TapasForSequenceClassification"), |
|
("transfo-xl", "TransfoXLForSequenceClassification"), |
|
("umt5", "UMT5ForSequenceClassification"), |
|
("xlm", "XLMForSequenceClassification"), |
|
("xlm-roberta", "XLMRobertaForSequenceClassification"), |
|
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"), |
|
("xlnet", "XLNetForSequenceClassification"), |
|
("xmod", "XmodForSequenceClassification"), |
|
("yoso", "YosoForSequenceClassification"), |
|
] |
|
) |
|
|
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("albert", "AlbertForQuestionAnswering"), |
|
("bart", "BartForQuestionAnswering"), |
|
("bert", "BertForQuestionAnswering"), |
|
("big_bird", "BigBirdForQuestionAnswering"), |
|
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), |
|
("bloom", "BloomForQuestionAnswering"), |
|
("camembert", "CamembertForQuestionAnswering"), |
|
("canine", "CanineForQuestionAnswering"), |
|
("convbert", "ConvBertForQuestionAnswering"), |
|
("data2vec-text", "Data2VecTextForQuestionAnswering"), |
|
("deberta", "DebertaForQuestionAnswering"), |
|
("deberta-v2", "DebertaV2ForQuestionAnswering"), |
|
("distilbert", "DistilBertForQuestionAnswering"), |
|
("electra", "ElectraForQuestionAnswering"), |
|
("ernie", "ErnieForQuestionAnswering"), |
|
("ernie_m", "ErnieMForQuestionAnswering"), |
|
("falcon", "FalconForQuestionAnswering"), |
|
("flaubert", "FlaubertForQuestionAnsweringSimple"), |
|
("fnet", "FNetForQuestionAnswering"), |
|
("funnel", "FunnelForQuestionAnswering"), |
|
("gpt2", "GPT2ForQuestionAnswering"), |
|
("gpt_neo", "GPTNeoForQuestionAnswering"), |
|
("gpt_neox", "GPTNeoXForQuestionAnswering"), |
|
("gptj", "GPTJForQuestionAnswering"), |
|
("ibert", "IBertForQuestionAnswering"), |
|
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), |
|
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), |
|
("led", "LEDForQuestionAnswering"), |
|
("lilt", "LiltForQuestionAnswering"), |
|
("longformer", "LongformerForQuestionAnswering"), |
|
("luke", "LukeForQuestionAnswering"), |
|
("lxmert", "LxmertForQuestionAnswering"), |
|
("markuplm", "MarkupLMForQuestionAnswering"), |
|
("mbart", "MBartForQuestionAnswering"), |
|
("mega", "MegaForQuestionAnswering"), |
|
("megatron-bert", "MegatronBertForQuestionAnswering"), |
|
("mobilebert", "MobileBertForQuestionAnswering"), |
|
("mpnet", "MPNetForQuestionAnswering"), |
|
("mpt", "MptForQuestionAnswering"), |
|
("mra", "MraForQuestionAnswering"), |
|
("mt5", "MT5ForQuestionAnswering"), |
|
("mvp", "MvpForQuestionAnswering"), |
|
("nezha", "NezhaForQuestionAnswering"), |
|
("nystromformer", "NystromformerForQuestionAnswering"), |
|
("opt", "OPTForQuestionAnswering"), |
|
("qdqbert", "QDQBertForQuestionAnswering"), |
|
("reformer", "ReformerForQuestionAnswering"), |
|
("rembert", "RemBertForQuestionAnswering"), |
|
("roberta", "RobertaForQuestionAnswering"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"), |
|
("roc_bert", "RoCBertForQuestionAnswering"), |
|
("roformer", "RoFormerForQuestionAnswering"), |
|
("splinter", "SplinterForQuestionAnswering"), |
|
("squeezebert", "SqueezeBertForQuestionAnswering"), |
|
("t5", "T5ForQuestionAnswering"), |
|
("umt5", "UMT5ForQuestionAnswering"), |
|
("xlm", "XLMForQuestionAnsweringSimple"), |
|
("xlm-roberta", "XLMRobertaForQuestionAnswering"), |
|
("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"), |
|
("xlnet", "XLNetForQuestionAnsweringSimple"), |
|
("xmod", "XmodForQuestionAnswering"), |
|
("yoso", "YosoForQuestionAnswering"), |
|
] |
|
) |
|
|
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("tapas", "TapasForQuestionAnswering"), |
|
] |
|
) |
|
|
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("blip-2", "Blip2ForConditionalGeneration"), |
|
("vilt", "ViltForQuestionAnswering"), |
|
] |
|
) |
|
|
|
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("layoutlm", "LayoutLMForQuestionAnswering"), |
|
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), |
|
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), |
|
] |
|
) |
|
|
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("albert", "AlbertForTokenClassification"), |
|
("bert", "BertForTokenClassification"), |
|
("big_bird", "BigBirdForTokenClassification"), |
|
("biogpt", "BioGptForTokenClassification"), |
|
("bloom", "BloomForTokenClassification"), |
|
("bros", "BrosForTokenClassification"), |
|
("camembert", "CamembertForTokenClassification"), |
|
("canine", "CanineForTokenClassification"), |
|
("convbert", "ConvBertForTokenClassification"), |
|
("data2vec-text", "Data2VecTextForTokenClassification"), |
|
("deberta", "DebertaForTokenClassification"), |
|
("deberta-v2", "DebertaV2ForTokenClassification"), |
|
("distilbert", "DistilBertForTokenClassification"), |
|
("electra", "ElectraForTokenClassification"), |
|
("ernie", "ErnieForTokenClassification"), |
|
("ernie_m", "ErnieMForTokenClassification"), |
|
("esm", "EsmForTokenClassification"), |
|
("falcon", "FalconForTokenClassification"), |
|
("flaubert", "FlaubertForTokenClassification"), |
|
("fnet", "FNetForTokenClassification"), |
|
("funnel", "FunnelForTokenClassification"), |
|
("gpt-sw3", "GPT2ForTokenClassification"), |
|
("gpt2", "GPT2ForTokenClassification"), |
|
("gpt_bigcode", "GPTBigCodeForTokenClassification"), |
|
("gpt_neo", "GPTNeoForTokenClassification"), |
|
("gpt_neox", "GPTNeoXForTokenClassification"), |
|
("ibert", "IBertForTokenClassification"), |
|
("layoutlm", "LayoutLMForTokenClassification"), |
|
("layoutlmv2", "LayoutLMv2ForTokenClassification"), |
|
("layoutlmv3", "LayoutLMv3ForTokenClassification"), |
|
("lilt", "LiltForTokenClassification"), |
|
("longformer", "LongformerForTokenClassification"), |
|
("luke", "LukeForTokenClassification"), |
|
("markuplm", "MarkupLMForTokenClassification"), |
|
("mega", "MegaForTokenClassification"), |
|
("megatron-bert", "MegatronBertForTokenClassification"), |
|
("mobilebert", "MobileBertForTokenClassification"), |
|
("mpnet", "MPNetForTokenClassification"), |
|
("mpt", "MptForTokenClassification"), |
|
("mra", "MraForTokenClassification"), |
|
("nezha", "NezhaForTokenClassification"), |
|
("nystromformer", "NystromformerForTokenClassification"), |
|
("qdqbert", "QDQBertForTokenClassification"), |
|
("rembert", "RemBertForTokenClassification"), |
|
("roberta", "RobertaForTokenClassification"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"), |
|
("roc_bert", "RoCBertForTokenClassification"), |
|
("roformer", "RoFormerForTokenClassification"), |
|
("squeezebert", "SqueezeBertForTokenClassification"), |
|
("xlm", "XLMForTokenClassification"), |
|
("xlm-roberta", "XLMRobertaForTokenClassification"), |
|
("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"), |
|
("xlnet", "XLNetForTokenClassification"), |
|
("xmod", "XmodForTokenClassification"), |
|
("yoso", "YosoForTokenClassification"), |
|
] |
|
) |
|
|
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("albert", "AlbertForMultipleChoice"), |
|
("bert", "BertForMultipleChoice"), |
|
("big_bird", "BigBirdForMultipleChoice"), |
|
("camembert", "CamembertForMultipleChoice"), |
|
("canine", "CanineForMultipleChoice"), |
|
("convbert", "ConvBertForMultipleChoice"), |
|
("data2vec-text", "Data2VecTextForMultipleChoice"), |
|
("deberta-v2", "DebertaV2ForMultipleChoice"), |
|
("distilbert", "DistilBertForMultipleChoice"), |
|
("electra", "ElectraForMultipleChoice"), |
|
("ernie", "ErnieForMultipleChoice"), |
|
("ernie_m", "ErnieMForMultipleChoice"), |
|
("flaubert", "FlaubertForMultipleChoice"), |
|
("fnet", "FNetForMultipleChoice"), |
|
("funnel", "FunnelForMultipleChoice"), |
|
("ibert", "IBertForMultipleChoice"), |
|
("longformer", "LongformerForMultipleChoice"), |
|
("luke", "LukeForMultipleChoice"), |
|
("mega", "MegaForMultipleChoice"), |
|
("megatron-bert", "MegatronBertForMultipleChoice"), |
|
("mobilebert", "MobileBertForMultipleChoice"), |
|
("mpnet", "MPNetForMultipleChoice"), |
|
("mra", "MraForMultipleChoice"), |
|
("nezha", "NezhaForMultipleChoice"), |
|
("nystromformer", "NystromformerForMultipleChoice"), |
|
("qdqbert", "QDQBertForMultipleChoice"), |
|
("rembert", "RemBertForMultipleChoice"), |
|
("roberta", "RobertaForMultipleChoice"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"), |
|
("roc_bert", "RoCBertForMultipleChoice"), |
|
("roformer", "RoFormerForMultipleChoice"), |
|
("squeezebert", "SqueezeBertForMultipleChoice"), |
|
("xlm", "XLMForMultipleChoice"), |
|
("xlm-roberta", "XLMRobertaForMultipleChoice"), |
|
("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"), |
|
("xlnet", "XLNetForMultipleChoice"), |
|
("xmod", "XmodForMultipleChoice"), |
|
("yoso", "YosoForMultipleChoice"), |
|
] |
|
) |
|
|
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("bert", "BertForNextSentencePrediction"), |
|
("ernie", "ErnieForNextSentencePrediction"), |
|
("fnet", "FNetForNextSentencePrediction"), |
|
("megatron-bert", "MegatronBertForNextSentencePrediction"), |
|
("mobilebert", "MobileBertForNextSentencePrediction"), |
|
("nezha", "NezhaForNextSentencePrediction"), |
|
("qdqbert", "QDQBertForNextSentencePrediction"), |
|
] |
|
) |
|
|
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("audio-spectrogram-transformer", "ASTForAudioClassification"), |
|
("data2vec-audio", "Data2VecAudioForSequenceClassification"), |
|
("hubert", "HubertForSequenceClassification"), |
|
("sew", "SEWForSequenceClassification"), |
|
("sew-d", "SEWDForSequenceClassification"), |
|
("unispeech", "UniSpeechForSequenceClassification"), |
|
("unispeech-sat", "UniSpeechSatForSequenceClassification"), |
|
("wav2vec2", "Wav2Vec2ForSequenceClassification"), |
|
("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"), |
|
("wavlm", "WavLMForSequenceClassification"), |
|
("whisper", "WhisperForAudioClassification"), |
|
] |
|
) |
|
|
|
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("data2vec-audio", "Data2VecAudioForCTC"), |
|
("hubert", "HubertForCTC"), |
|
("mctct", "MCTCTForCTC"), |
|
("sew", "SEWForCTC"), |
|
("sew-d", "SEWDForCTC"), |
|
("unispeech", "UniSpeechForCTC"), |
|
("unispeech-sat", "UniSpeechSatForCTC"), |
|
("wav2vec2", "Wav2Vec2ForCTC"), |
|
("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"), |
|
("wavlm", "WavLMForCTC"), |
|
] |
|
) |
|
|
|
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("data2vec-audio", "Data2VecAudioForAudioFrameClassification"), |
|
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"), |
|
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"), |
|
("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"), |
|
("wavlm", "WavLMForAudioFrameClassification"), |
|
] |
|
) |
|
|
|
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("data2vec-audio", "Data2VecAudioForXVector"), |
|
("unispeech-sat", "UniSpeechSatForXVector"), |
|
("wav2vec2", "Wav2Vec2ForXVector"), |
|
("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"), |
|
("wavlm", "WavLMForXVector"), |
|
] |
|
) |
|
|
|
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("speecht5", "SpeechT5ForTextToSpeech"), |
|
] |
|
) |
|
|
|
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("bark", "BarkModel"), |
|
("musicgen", "MusicgenForConditionalGeneration"), |
|
("vits", "VitsModel"), |
|
] |
|
) |
|
|
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("align", "AlignModel"), |
|
("altclip", "AltCLIPModel"), |
|
("blip", "BlipModel"), |
|
("chinese_clip", "ChineseCLIPModel"), |
|
("clip", "CLIPModel"), |
|
("clipseg", "CLIPSegModel"), |
|
] |
|
) |
|
|
|
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( |
|
[ |
|
|
|
("bit", "BitBackbone"), |
|
("convnext", "ConvNextBackbone"), |
|
("convnextv2", "ConvNextV2Backbone"), |
|
("dinat", "DinatBackbone"), |
|
("dinov2", "Dinov2Backbone"), |
|
("focalnet", "FocalNetBackbone"), |
|
("maskformer-swin", "MaskFormerSwinBackbone"), |
|
("nat", "NatBackbone"), |
|
("resnet", "ResNetBackbone"), |
|
("swin", "SwinBackbone"), |
|
("timm_backbone", "TimmBackbone"), |
|
("vitdet", "VitDetBackbone"), |
|
] |
|
) |
|
|
|
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("sam", "SamModel"), |
|
] |
|
) |
|
|
|
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("albert", "AlbertModel"), |
|
("bert", "BertModel"), |
|
("big_bird", "BigBirdModel"), |
|
("data2vec-text", "Data2VecTextModel"), |
|
("deberta", "DebertaModel"), |
|
("deberta-v2", "DebertaV2Model"), |
|
("distilbert", "DistilBertModel"), |
|
("electra", "ElectraModel"), |
|
("flaubert", "FlaubertModel"), |
|
("ibert", "IBertModel"), |
|
("longformer", "LongformerModel"), |
|
("mobilebert", "MobileBertModel"), |
|
("mt5", "MT5EncoderModel"), |
|
("nystromformer", "NystromformerModel"), |
|
("reformer", "ReformerModel"), |
|
("rembert", "RemBertModel"), |
|
("roberta", "RobertaModel"), |
|
("roberta-prelayernorm", "RobertaPreLayerNormModel"), |
|
("roc_bert", "RoCBertModel"), |
|
("roformer", "RoFormerModel"), |
|
("squeezebert", "SqueezeBertModel"), |
|
("t5", "T5EncoderModel"), |
|
("umt5", "UMT5EncoderModel"), |
|
("xlm", "XLMModel"), |
|
("xlm-roberta", "XLMRobertaModel"), |
|
("xlm-roberta-xl", "XLMRobertaXLModel"), |
|
] |
|
) |
|
|
|
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict( |
|
[ |
|
("swin2sr", "Swin2SRForImageSuperResolution"), |
|
] |
|
) |
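
# _LazyAutoMapping (instantiated below) exposes these name tables as mappings keyed by
# config class and valued by model class, importing each class only when it is first
# looked up.
#
# Illustrative lookup (a sketch, not something this module does itself):
#
#     from transformers import AutoConfig
#     from transformers.models.auto.modeling_auto import MODEL_MAPPING
#
#     config = AutoConfig.from_pretrained("bert-base-uncased")
#     model_class = MODEL_MAPPING[type(config)]  # -> transformers.BertModel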
|
|
|
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) |
|
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) |
|
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) |
|
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) |
|
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES |
|
) |
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) |
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES |
|
) |
|
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES |
|
) |
|
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) |
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES |
|
) |
|
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) |
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES) |
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES |
|
) |
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES |
|
) |
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES |
|
) |
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES) |
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES) |
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) |
|
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES |
|
) |
|
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES) |
|
|
|
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping( |
|
CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES |
|
) |
|
|
|
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES) |
|
|
|
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES) |
|
|
|
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) |
|
|
|
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) |
|
|
|
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES) |
|
|
|
|
|
class AutoModelForMaskGeneration(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING |
|
|
|
|
|
class AutoModelForTextEncoding(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING |
|
|
|
|
|
class AutoModelForImageToImage(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING |
|
|
|
|
|
class AutoModel(_BaseAutoModelClass): |
|
_model_mapping = MODEL_MAPPING |
|
|
|
|
|
AutoModel = auto_class_update(AutoModel) |
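
# Usage sketch (illustrative only; "bert-base-uncased" is simply a common public checkpoint):
#
#     from transformers import AutoModel, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     model = AutoModel.from_pretrained("bert-base-uncased")
#     inputs = tokenizer("Hello world!", return_tensors="pt")
#     outputs = model(**inputs)  # e.g. outputs.last_hidden_state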
|
|
|
|
|
class AutoModelForPreTraining(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_PRETRAINING_MAPPING |
|
|
|
|
|
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") |
|
|
|
|
|
|
|
class _AutoModelWithLMHead(_BaseAutoModelClass): |
|
_model_mapping = MODEL_WITH_LM_HEAD_MAPPING |
|
|
|
|
|
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling") |
|
|
|
|
|
class AutoModelForCausalLM(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING |
|
|
|
|
|
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") |
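
# Usage sketch (illustrative only; "gpt2" is simply a common public checkpoint):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     inputs = tokenizer("The meaning of life is", return_tensors="pt")
#     generated = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(generated[0], skip_special_tokens=True))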
|
|
|
|
|
class AutoModelForMaskedLM(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING |
|
|
|
|
|
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling") |
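
# Usage sketch (illustrative only; "roberta-base" uses "<mask>" as its mask token):
#
#     from transformers import AutoModelForMaskedLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#     model = AutoModelForMaskedLM.from_pretrained("roberta-base")
#     inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
#     logits = model(**inputs).logits  # vocabulary scores for every position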
|
|
|
|
|
class AutoModelForSeq2SeqLM(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING |
|
|
|
|
|
AutoModelForSeq2SeqLM = auto_class_update( |
|
AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base" |
|
) |
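
# Usage sketch (illustrative only; mirrors the "t5-base" checkpoint referenced above):
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("t5-base")
#     model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
#     inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="pt")
#     outputs = model.generate(**inputs, max_new_tokens=40)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))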
|
|
|
|
|
class AutoModelForSequenceClassification(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING |
|
|
|
|
|
AutoModelForSequenceClassification = auto_class_update( |
|
AutoModelForSequenceClassification, head_doc="sequence classification" |
|
) |
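
# Usage sketch (illustrative only; the checkpoint is a commonly used public sentiment model):
#
#     from transformers import AutoModelForSequenceClassification, AutoTokenizer
#
#     name = "distilbert-base-uncased-finetuned-sst-2-english"
#     tokenizer = AutoTokenizer.from_pretrained(name)
#     model = AutoModelForSequenceClassification.from_pretrained(name)
#     inputs = tokenizer("I really enjoyed this movie!", return_tensors="pt")
#     predicted_id = model(**inputs).logits.argmax(-1).item()
#     print(model.config.id2label[predicted_id])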
|
|
|
|
|
class AutoModelForQuestionAnswering(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING |
|
|
|
|
|
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering") |
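
# Usage sketch (illustrative only; extractive QA heads return start/end logits over the context):
#
#     from transformers import AutoModelForQuestionAnswering, AutoTokenizer
#
#     name = "distilbert-base-cased-distilled-squad"
#     tokenizer = AutoTokenizer.from_pretrained(name)
#     model = AutoModelForQuestionAnswering.from_pretrained(name)
#     inputs = tokenizer("Who wrote Hamlet?", "Hamlet was written by Shakespeare.", return_tensors="pt")
#     outputs = model(**inputs)
#     start = outputs.start_logits.argmax()
#     end = outputs.end_logits.argmax()
#     print(tokenizer.decode(inputs["input_ids"][0][start : end + 1]))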
|
|
|
|
|
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING |
|
|
|
|
|
AutoModelForTableQuestionAnswering = auto_class_update( |
|
AutoModelForTableQuestionAnswering, |
|
head_doc="table question answering", |
|
checkpoint_for_example="google/tapas-base-finetuned-wtq", |
|
) |
|
|
|
|
|
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING |
|
|
|
|
|
AutoModelForVisualQuestionAnswering = auto_class_update( |
|
AutoModelForVisualQuestionAnswering, |
|
head_doc="visual question answering", |
|
checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa", |
|
) |
|
|
|
|
|
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING |
|
|
|
|
|
AutoModelForDocumentQuestionAnswering = auto_class_update( |
|
AutoModelForDocumentQuestionAnswering, |
|
head_doc="document question answering", |
|
checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3', |
|
) |
|
|
|
|
|
class AutoModelForTokenClassification(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING |
|
|
|
|
|
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification") |
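
# Usage sketch (illustrative only; a token classification head yields one label per token.
# "dslim/bert-base-NER" is a commonly used community NER checkpoint, given here as an example):
#
#     from transformers import AutoModelForTokenClassification, AutoTokenizer
#
#     name = "dslim/bert-base-NER"
#     tokenizer = AutoTokenizer.from_pretrained(name)
#     model = AutoModelForTokenClassification.from_pretrained(name)
#     inputs = tokenizer("Hugging Face is based in New York City", return_tensors="pt")
#     predictions = model(**inputs).logits.argmax(-1)  # one label id per token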
|
|
|
|
|
class AutoModelForMultipleChoice(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING |
|
|
|
|
|
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice") |
|
|
|
|
|
class AutoModelForNextSentencePrediction(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING |
|
|
|
|
|
AutoModelForNextSentencePrediction = auto_class_update( |
|
AutoModelForNextSentencePrediction, head_doc="next sentence prediction" |
|
) |
|
|
|
|
|
class AutoModelForImageClassification(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING |
|
|
|
|
|
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") |
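
# Usage sketch (illustrative only; assumes a PIL image is already loaded as `image`):
#
#     from transformers import AutoImageProcessor, AutoModelForImageClassification
#
#     name = "google/vit-base-patch16-224"
#     processor = AutoImageProcessor.from_pretrained(name)
#     model = AutoModelForImageClassification.from_pretrained(name)
#     inputs = processor(images=image, return_tensors="pt")
#     predicted_id = model(**inputs).logits.argmax(-1).item()
#     print(model.config.id2label[predicted_id])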
|
|
|
|
|
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING |
|
|
|
|
|
AutoModelForZeroShotImageClassification = auto_class_update( |
|
AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" |
|
) |
|
|
|
|
|
class AutoModelForImageSegmentation(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING |
|
|
|
|
|
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation") |
|
|
|
|
|
class AutoModelForSemanticSegmentation(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING |
|
|
|
|
|
AutoModelForSemanticSegmentation = auto_class_update( |
|
AutoModelForSemanticSegmentation, head_doc="semantic segmentation" |
|
) |
|
|
|
|
|
class AutoModelForUniversalSegmentation(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING |
|
|
|
|
|
AutoModelForUniversalSegmentation = auto_class_update( |
|
AutoModelForUniversalSegmentation, head_doc="universal image segmentation" |
|
) |
|
|
|
|
|
class AutoModelForInstanceSegmentation(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING |
|
|
|
|
|
AutoModelForInstanceSegmentation = auto_class_update( |
|
AutoModelForInstanceSegmentation, head_doc="instance segmentation" |
|
) |
|
|
|
|
|
class AutoModelForObjectDetection(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING |
|
|
|
|
|
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection") |
|
|
|
|
|
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING |
|
|
|
|
|
AutoModelForZeroShotObjectDetection = auto_class_update( |
|
AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection" |
|
) |
|
|
|
|
|
class AutoModelForDepthEstimation(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING |
|
|
|
|
|
AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation") |
|
|
|
|
|
class AutoModelForVideoClassification(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING |
|
|
|
|
|
AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification") |
|
|
|
|
|
class AutoModelForVision2Seq(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING |
|
|
|
|
|
AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling") |
|
|
|
|
|
class AutoModelForAudioClassification(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING |
|
|
|
|
|
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification") |
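
# Usage sketch (illustrative only; assumes a 1-D float waveform `audio` sampled at 16 kHz.
# "superb/wav2vec2-base-superb-ks" is a public keyword-spotting checkpoint, used as an example):
#
#     from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
#
#     name = "superb/wav2vec2-base-superb-ks"
#     feature_extractor = AutoFeatureExtractor.from_pretrained(name)
#     model = AutoModelForAudioClassification.from_pretrained(name)
#     inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
#     predicted_id = model(**inputs).logits.argmax(-1).item()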
|
|
|
|
|
class AutoModelForCTC(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_CTC_MAPPING |
|
|
|
|
|
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") |
|
|
|
|
|
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING |
|
|
|
|
|
AutoModelForSpeechSeq2Seq = auto_class_update( |
|
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" |
|
) |
|
|
|
|
|
class AutoModelForAudioFrameClassification(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING |
|
|
|
|
|
AutoModelForAudioFrameClassification = auto_class_update( |
|
AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification" |
|
) |
|
|
|
|
|
class AutoModelForAudioXVector(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING |
|
|
|
|
|
class AutoModelForTextToSpectrogram(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING |
|
|
|
|
|
class AutoModelForTextToWaveform(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING |
|
|
|
|
|
class AutoBackbone(_BaseAutoBackboneClass): |
|
_model_mapping = MODEL_FOR_BACKBONE_MAPPING |
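
# Usage sketch for AutoBackbone (illustrative only; backbones return intermediate feature
# maps for downstream vision heads, and `out_indices` is assumed here as the usual backbone
# kwarg for selecting stages):
#
#     from transformers import AutoBackbone
#
#     backbone = AutoBackbone.from_pretrained("microsoft/resnet-50", out_indices=[1, 2, 3, 4])
#     # feature_maps = backbone(pixel_values).feature_maps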
|
|
|
|
|
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") |
|
|
|
|
|
class AutoModelForMaskedImageModeling(_BaseAutoModelClass): |
|
_model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING |
|
|
|
|
|
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling") |
|
|
|
|
|
class AutoModelWithLMHead(_AutoModelWithLMHead): |
|
@classmethod |
|
def from_config(cls, config): |
|
warnings.warn( |
|
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " |
|
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " |
|
"`AutoModelForSeq2SeqLM` for encoder-decoder models.", |
|
FutureWarning, |
|
) |
|
return super().from_config(config) |
|
|
|
@classmethod |
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): |
|
warnings.warn( |
|
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " |
|
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " |
|
"`AutoModelForSeq2SeqLM` for encoder-decoder models.", |
|
FutureWarning, |
|
) |
|
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) |
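
# Migration sketch for the deprecation warning above (illustrative only):
#
#     # instead of:  AutoModelWithLMHead.from_pretrained("gpt2")
#     # use:
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("gpt2")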
|
|