|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""TensorFlow Auto Model classes."""
|
|
|
|
|
import warnings |
|
from collections import OrderedDict |
|
|
|
from ...utils import logging |
|
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update |
|
from .configuration_auto import CONFIG_MAPPING_NAMES |
|
|
|
|
|
# Module-level logger from the library's logging helpers (not referenced in the visible part of this module).
logger = logging.get_logger(__name__)
|
|
|
|
|
# Base model class names, keyed by model identifier (e.g. "bert"). Values are class-name strings;
# this table is paired with CONFIG_MAPPING_NAMES via _LazyAutoMapping further down in this module.
TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "TFAlbertModel"),
        ("bart", "TFBartModel"),
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
        ("convnext", "TFConvNextModel"),
        ("ctrl", "TFCTRLModel"),
        ("cvt", "TFCvtModel"),
        ("data2vec-vision", "TFData2VecVisionModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("deit", "TFDeiTModel"),
        ("distilbert", "TFDistilBertModel"),
        # Only the question encoder is registered as the base model for "dpr".
        ("dpr", "TFDPRQuestionEncoder"),
        ("efficientformer", "TFEfficientFormerModel"),
        ("electra", "TFElectraModel"),
        ("esm", "TFEsmModel"),
        ("flaubert", "TFFlaubertModel"),
        # A tuple value lists more than one class name for the same model identifier.
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        # gpt-sw3 reuses the GPT-2 TF implementation.
        ("gpt-sw3", "TFGPT2Model"),
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("layoutlmv3", "TFLayoutLMv3Model"),
        ("led", "TFLEDModel"),
        ("longformer", "TFLongformerModel"),
        ("lxmert", "TFLxmertModel"),
        ("marian", "TFMarianModel"),
        ("mbart", "TFMBartModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mobilevit", "TFMobileViTModel"),
        ("mpnet", "TFMPNetModel"),
        ("mt5", "TFMT5Model"),
        ("openai-gpt", "TFOpenAIGPTModel"),
        ("opt", "TFOPTModel"),
        ("pegasus", "TFPegasusModel"),
        ("regnet", "TFRegNetModel"),
        ("rembert", "TFRemBertModel"),
        ("resnet", "TFResNetModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("sam", "TFSamModel"),
        ("segformer", "TFSegformerModel"),
        ("speech_to_text", "TFSpeech2TextModel"),
        ("swin", "TFSwinModel"),
        ("t5", "TFT5Model"),
        ("tapas", "TFTapasModel"),
        ("transfo-xl", "TFTransfoXLModel"),
        ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
        ("vit", "TFViTModel"),
        ("vit_mae", "TFViTMAEModel"),
        ("wav2vec2", "TFWav2Vec2Model"),
        ("whisper", "TFWhisperModel"),
        ("xglm", "TFXGLMModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
        ("xlnet", "TFXLNetModel"),
    ]
)
|
|
|
# Pre-training head class names, keyed by model identifier. Several entries point at a masked-LM,
# LM-head or conditional-generation class because no dedicated *ForPreTraining class is listed for them.
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "TFAlbertForPreTraining"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForPreTraining"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForPreTraining"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForPreTraining"),
        # gpt-sw3 reuses the GPT-2 TF implementation.
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("lxmert", "TFLxmertForPreTraining"),
        ("mobilebert", "TFMobileBertForPreTraining"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("vit_mae", "TFViTMAEForPreTraining"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
|
|
|
# Generic language-modeling head class names (masked, causal and seq2seq heads mixed), keyed by model
# identifier. Used by the private _TFAutoModelWithLMHead class later in this module.
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        # gpt-sw3 reuses the GPT-2 TF implementation.
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("led", "TFLEDForConditionalGeneration"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("marian", "TFMarianMTModel"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("whisper", "TFWhisperForConditionalGeneration"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
|
|
|
# Causal (left-to-right) language-modeling head class names, keyed by model identifier.
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bert", "TFBertLMHeadModel"),
        ("camembert", "TFCamembertForCausalLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        # gpt-sw3 reuses the GPT-2 TF implementation.
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("opt", "TFOPTForCausalLM"),
        ("rembert", "TFRemBertForCausalLM"),
        ("roberta", "TFRobertaForCausalLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
        ("roformer", "TFRoFormerForCausalLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xglm", "TFXGLMForCausalLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForCausalLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
|
|
|
# Masked-image-modeling head class names, keyed by model identifier.
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    (
        ("deit", "TFDeiTForMaskedImageModeling"),
        ("swin", "TFSwinForMaskedImageModeling"),
    )
)
|
|
|
# Image-classification head class names, keyed by model identifier.
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image-classification mapping
        ("convnext", "TFConvNextForImageClassification"),
        ("cvt", "TFCvtForImageClassification"),
        ("data2vec-vision", "TFData2VecVisionForImageClassification"),
        # Tuple values list two class names for one identifier: the plain head and a "WithTeacher"
        # variant (presumably the distillation head — confirm against the model docs).
        ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
        (
            "efficientformer",
            ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
        ),
        ("mobilevit", "TFMobileViTForImageClassification"),
        ("regnet", "TFRegNetForImageClassification"),
        ("resnet", "TFResNetForImageClassification"),
        ("segformer", "TFSegformerForImageClassification"),
        ("swin", "TFSwinForImageClassification"),
        ("vit", "TFViTForImageClassification"),
    ]
)
|
|
|
|
|
# Zero-shot image classification, keyed by model identifier. Both entries reuse the same
# class names that the base-model table registers for these identifiers.
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    )
)
|
|
|
|
|
# Semantic-segmentation head class names, keyed by model identifier.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    (
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
        ("mobilevit", "TFMobileViTForSemanticSegmentation"),
        ("segformer", "TFSegformerForSemanticSegmentation"),
    )
)
|
|
|
# Image-to-text (vision-to-sequence) class names, keyed by model identifier.
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    (
        ("blip", "TFBlipForConditionalGeneration"),
        ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
    )
)
|
|
|
# Masked language-modeling head class names, keyed by model identifier.
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("deberta", "TFDebertaForMaskedLM"),
        ("deberta-v2", "TFDebertaV2ForMaskedLM"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("tapas", "TFTapasForMaskedLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
    ]
)
|
|
|
# Encoder-decoder (sequence-to-sequence) text-generation class names, keyed by model identifier.
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "TFBartForConditionalGeneration"),
        ("blenderbot", "TFBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "TFEncoderDecoderModel"),
        ("led", "TFLEDForConditionalGeneration"),
        ("marian", "TFMarianMTModel"),
        ("mbart", "TFMBartForConditionalGeneration"),
        ("mt5", "TFMT5ForConditionalGeneration"),
        ("pegasus", "TFPegasusForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
    ]
)
|
|
|
# Speech-to-text sequence-to-sequence class names, keyed by model identifier.
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    (
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("whisper", "TFWhisperForConditionalGeneration"),
    )
)
|
|
|
# Sequence-classification head class names, keyed by model identifier.
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "TFAlbertForSequenceClassification"),
        ("bart", "TFBartForSequenceClassification"),
        ("bert", "TFBertForSequenceClassification"),
        ("camembert", "TFCamembertForSequenceClassification"),
        ("convbert", "TFConvBertForSequenceClassification"),
        ("ctrl", "TFCTRLForSequenceClassification"),
        ("deberta", "TFDebertaForSequenceClassification"),
        ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
        ("distilbert", "TFDistilBertForSequenceClassification"),
        ("electra", "TFElectraForSequenceClassification"),
        ("esm", "TFEsmForSequenceClassification"),
        ("flaubert", "TFFlaubertForSequenceClassification"),
        ("funnel", "TFFunnelForSequenceClassification"),
        # gpt-sw3 reuses the GPT-2 TF implementation.
        ("gpt-sw3", "TFGPT2ForSequenceClassification"),
        ("gpt2", "TFGPT2ForSequenceClassification"),
        ("gptj", "TFGPTJForSequenceClassification"),
        ("layoutlm", "TFLayoutLMForSequenceClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
        ("longformer", "TFLongformerForSequenceClassification"),
        ("mobilebert", "TFMobileBertForSequenceClassification"),
        ("mpnet", "TFMPNetForSequenceClassification"),
        ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
        ("rembert", "TFRemBertForSequenceClassification"),
        ("roberta", "TFRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "TFRoFormerForSequenceClassification"),
        ("tapas", "TFTapasForSequenceClassification"),
        ("transfo-xl", "TFTransfoXLForSequenceClassification"),
        ("xlm", "TFXLMForSequenceClassification"),
        ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
        ("xlnet", "TFXLNetForSequenceClassification"),
    ]
)
|
|
|
# Extractive question-answering head class names, keyed by model identifier.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "TFAlbertForQuestionAnswering"),
        ("bert", "TFBertForQuestionAnswering"),
        ("camembert", "TFCamembertForQuestionAnswering"),
        ("convbert", "TFConvBertForQuestionAnswering"),
        ("deberta", "TFDebertaForQuestionAnswering"),
        ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
        ("distilbert", "TFDistilBertForQuestionAnswering"),
        ("electra", "TFElectraForQuestionAnswering"),
        # flaubert, xlm and xlnet register their "*Simple" QA variants here.
        ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
        ("funnel", "TFFunnelForQuestionAnswering"),
        ("gptj", "TFGPTJForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
        ("longformer", "TFLongformerForQuestionAnswering"),
        ("mobilebert", "TFMobileBertForQuestionAnswering"),
        ("mpnet", "TFMPNetForQuestionAnswering"),
        ("rembert", "TFRemBertForQuestionAnswering"),
        ("roberta", "TFRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "TFRoFormerForQuestionAnswering"),
        ("xlm", "TFXLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
        ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
    ]
)
|
# Audio-classification head class names, keyed by model identifier.
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("wav2vec2", "TFWav2Vec2ForSequenceClassification"),
    )
)
|
|
|
# Document (layout-aware) question-answering class names, keyed by model identifier.
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("layoutlm", "TFLayoutLMForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
    )
)
|
|
|
|
|
# Table question-answering class names, keyed by model identifier.
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("tapas", "TFTapasForQuestionAnswering"),
    )
)
|
|
|
# Token-classification head class names, keyed by model identifier.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "TFAlbertForTokenClassification"),
        ("bert", "TFBertForTokenClassification"),
        ("camembert", "TFCamembertForTokenClassification"),
        ("convbert", "TFConvBertForTokenClassification"),
        ("deberta", "TFDebertaForTokenClassification"),
        ("deberta-v2", "TFDebertaV2ForTokenClassification"),
        ("distilbert", "TFDistilBertForTokenClassification"),
        ("electra", "TFElectraForTokenClassification"),
        ("esm", "TFEsmForTokenClassification"),
        ("flaubert", "TFFlaubertForTokenClassification"),
        ("funnel", "TFFunnelForTokenClassification"),
        ("layoutlm", "TFLayoutLMForTokenClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
        ("longformer", "TFLongformerForTokenClassification"),
        ("mobilebert", "TFMobileBertForTokenClassification"),
        ("mpnet", "TFMPNetForTokenClassification"),
        ("rembert", "TFRemBertForTokenClassification"),
        ("roberta", "TFRobertaForTokenClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
        ("roformer", "TFRoFormerForTokenClassification"),
        ("xlm", "TFXLMForTokenClassification"),
        ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
        ("xlnet", "TFXLNetForTokenClassification"),
    ]
)
|
|
|
# Multiple-choice head class names, keyed by model identifier.
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "TFAlbertForMultipleChoice"),
        ("bert", "TFBertForMultipleChoice"),
        ("camembert", "TFCamembertForMultipleChoice"),
        ("convbert", "TFConvBertForMultipleChoice"),
        ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
        ("distilbert", "TFDistilBertForMultipleChoice"),
        ("electra", "TFElectraForMultipleChoice"),
        ("flaubert", "TFFlaubertForMultipleChoice"),
        ("funnel", "TFFunnelForMultipleChoice"),
        ("longformer", "TFLongformerForMultipleChoice"),
        ("mobilebert", "TFMobileBertForMultipleChoice"),
        ("mpnet", "TFMPNetForMultipleChoice"),
        ("rembert", "TFRemBertForMultipleChoice"),
        ("roberta", "TFRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "TFRoFormerForMultipleChoice"),
        ("xlm", "TFXLMForMultipleChoice"),
        ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
        ("xlnet", "TFXLNetForMultipleChoice"),
    ]
)
|
|
|
# Next-sentence-prediction head class names, keyed by model identifier.
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    (
        ("bert", "TFBertForNextSentencePrediction"),
        ("mobilebert", "TFMobileBertForNextSentencePrediction"),
    )
)
|
# Mask-generation class names, keyed by model identifier. "sam" reuses the same class name
# that the base-model table registers for it.
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    (
        ("sam", "TFSamModel"),
    )
)
|
# Text-encoding class names, keyed by model identifier. Most entries reuse the base-model class;
# mt5 and t5 instead point at their encoder-only classes.
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertModel"),
        ("bert", "TFBertModel"),
        ("convbert", "TFConvBertModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("distilbert", "TFDistilBertModel"),
        ("electra", "TFElectraModel"),
        ("flaubert", "TFFlaubertModel"),
        ("longformer", "TFLongformerModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mt5", "TFMT5EncoderModel"),
        ("rembert", "TFRemBertModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("t5", "TFT5EncoderModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
    ]
)
|
|
|
# Each public *_MAPPING below pairs CONFIG_MAPPING_NAMES with one of the *_MAPPING_NAMES tables above.
# _LazyAutoMapping is imported from auto_factory; presumably it resolves the class-name strings on
# access rather than at import time (confirm in auto_factory).
TF_MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_MAPPING_NAMES,
)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES,
)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
)
|
|
|
|
|
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
    # Auto class backed by the mask-generation mapping.
    # NOTE(review): unlike most public auto classes in this module, this one is not passed through
    # auto_class_update — confirm whether that omission is intentional.
    _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
|
|
|
|
|
class TFAutoModelForTextEncoding(_BaseAutoModelClass):
    # Auto class backed by the text-encoding mapping.
    # NOTE(review): unlike most public auto classes in this module, this one is not passed through
    # auto_class_update — confirm whether that omission is intentional.
    _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
|
|
|
|
|
class TFAutoModel(_BaseAutoModelClass):
    # Auto class backed by the base-model mapping.
    _model_mapping = TF_MODEL_MAPPING


# Only auto class here passed to auto_class_update without head_doc, so that function's default applies.
TFAutoModel = auto_class_update(TFAutoModel)
|
|
|
|
|
class TFAutoModelForAudioClassification(_BaseAutoModelClass):
    # Mapping consulted by this auto class: audio-classification heads.
    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForAudioClassification = auto_class_update(
    TFAutoModelForAudioClassification,
    head_doc="audio classification",
)
|
|
|
|
|
class TFAutoModelForPreTraining(_BaseAutoModelClass):
    # Mapping consulted by this auto class: pre-training heads.
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForPreTraining = auto_class_update(
    TFAutoModelForPreTraining,
    head_doc="pretraining",
)
|
|
|
|
|
|
|
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    # Private base of the deprecated TFAutoModelWithLMHead defined later in this module;
    # backed by the generic language-modeling head mapping.
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
_TFAutoModelWithLMHead = auto_class_update(
    _TFAutoModelWithLMHead,
    head_doc="language modeling",
)
|
|
|
|
|
class TFAutoModelForCausalLM(_BaseAutoModelClass):
    # Mapping consulted by this auto class: causal language-modeling heads.
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForCausalLM = auto_class_update(
    TFAutoModelForCausalLM,
    head_doc="causal language modeling",
)
|
|
|
|
|
class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
    # Mapping consulted by this auto class: masked-image-modeling heads.
    _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForMaskedImageModeling = auto_class_update(
    TFAutoModelForMaskedImageModeling,
    head_doc="masked image modeling",
)
|
|
|
|
|
class TFAutoModelForImageClassification(_BaseAutoModelClass):
    # Mapping consulted by this auto class: image-classification heads.
    _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForImageClassification = auto_class_update(
    TFAutoModelForImageClassification,
    head_doc="image classification",
)
|
|
|
|
|
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    # Mapping consulted by this auto class: zero-shot image classification.
    _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForZeroShotImageClassification = auto_class_update(
    TFAutoModelForZeroShotImageClassification,
    head_doc="zero-shot image classification",
)
|
|
|
|
|
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    # Mapping consulted by this auto class: semantic-segmentation heads.
    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation,
    head_doc="semantic segmentation",
)
|
|
|
|
|
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
    # Mapping consulted by this auto class: vision-to-sequence models.
    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForVision2Seq = auto_class_update(
    TFAutoModelForVision2Seq,
    head_doc="vision-to-text modeling",
)
|
|
|
|
|
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    # Mapping consulted by this auto class: masked language-modeling heads.
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForMaskedLM = auto_class_update(
    TFAutoModelForMaskedLM,
    head_doc="masked language modeling",
)
|
|
|
|
|
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Mapping consulted by this auto class: encoder-decoder text generation.
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Decorate with auto_class_update; this task overrides the example checkpoint.
TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
|
|
|
|
|
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    # Mapping consulted by this auto class: sequence-classification heads.
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification,
    head_doc="sequence classification",
)
|
|
|
|
|
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Mapping consulted by this auto class: extractive question-answering heads.
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForQuestionAnswering = auto_class_update(
    TFAutoModelForQuestionAnswering,
    head_doc="question answering",
)
|
|
|
|
|
class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by the document question-answering mapping.
    _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


TFAutoModelForDocumentQuestionAnswering = auto_class_update(
    TFAutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    # NOTE(review): the embedded `", revision="` looks deliberate — presumably auto_class_update
    # interpolates this value inside a double-quoted example, so the generated example reads
    # `"impira/layoutlm-document-qa", revision="52e01b3"`. Keep verbatim; confirm in auto_factory.
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
|
|
|
|
|
class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by the table question-answering mapping.
    _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Decorated with a task-specific head_doc and a TAPAS checkpoint as the example.
TFAutoModelForTableQuestionAnswering = auto_class_update(
    TFAutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
|
|
|
|
|
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    # Mapping consulted by this auto class: token-classification heads.
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification,
    head_doc="token classification",
)
|
|
|
|
|
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    # Mapping consulted by this auto class: multiple-choice heads.
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForMultipleChoice = auto_class_update(
    TFAutoModelForMultipleChoice,
    head_doc="multiple choice",
)
|
|
|
|
|
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Mapping consulted by this auto class: next-sentence-prediction heads.
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
|
|
|
|
|
class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Mapping consulted by this auto class: speech-to-text sequence-to-sequence models.
    _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Decorate with auto_class_update, labelling the task via head_doc.
TFAutoModelForSpeechSeq2Seq = auto_class_update(
    TFAutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)
|
|
|
|
|
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    """Deprecated generic language-modeling auto class.

    Behaves like its parent, except that both constructors first emit a ``FutureWarning`` that
    recommends `TFAutoModelForCausalLM`, `TFAutoModelForMaskedLM` or `TFAutoModelForSeq2SeqLM` instead.
    """

    # One shared message so the two constructors cannot drift apart.
    _deprecation_message = (
        "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
        " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
        " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        """Warn about the deprecation, then delegate to the parent ``from_config``.

        Args:
            config: Configuration object, forwarded unchanged to the parent ``from_config``.

        Returns:
            Whatever the parent ``from_config`` returns.
        """
        # stacklevel=2 attributes the warning to the caller's line instead of to this module.
        warnings.warn(cls._deprecation_message, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Warn about the deprecation, then delegate to the parent ``from_pretrained``.

        Args:
            pretrained_model_name_or_path: Forwarded unchanged to the parent ``from_pretrained``.
            *model_args: Extra positional arguments, forwarded unchanged.
            **kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            Whatever the parent ``from_pretrained`` returns.
        """
        # stacklevel=2 attributes the warning to the caller's line instead of to this module.
        warnings.warn(cls._deprecation_message, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
|