"""
Subclasses VisionTextDualEncoderModel to customize text pooler.
"""
from typing import Optional, Type
import torch
from transformers import AutoModel, VisionTextDualEncoderModel
from .configuration_custom_clip import CustomCLIPConfig, get_text_model_pooler


# @add_start_docstrings(CUSTOM_CLIP_START_DOCSTRING)
class CustomCLIPModel(VisionTextDualEncoderModel):
    config_class = CustomCLIPConfig
    # `get_text_model_pooler` returns a pooler class, so this attribute holds
    # a type (instantiated in `__init__`), not a module instance.
    DEFAULT_TEXT_MODEL_POOLER_TYPE: Type[torch.nn.Module] = get_text_model_pooler(
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_STR
    )
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = (
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_KWARGS
    )

    def __init__(
        self, config: Optional[CustomCLIPConfig.__base__] = None, *args, **kwargs
    ):
        # Upgrade a plain base config to a CustomCLIPConfig; `None` passes
        # through so the parent can build the config from the `vision_model`
        # and `text_model` kwargs instead.
        config = None if config is None else CustomCLIPConfig.from_base(config)
        super().__init__(
            config,  # surprisingly, `super` is unnecessary here, possibly due to the implementation of CustomCLIPConfig.__init__?
            *args,
            **kwargs,
        )
        # Swap the parent-built text pooler: fall back to the class-level
        # defaults when no config was given, otherwise use the pooler type
        # and kwargs recorded in the config.
        if config is None:
            self.text_model.pooler = self.DEFAULT_TEXT_MODEL_POOLER_TYPE(
                **self.DEFAULT_TEXT_MODEL_POOLER_KWARGS
            )
        else:
            self.text_model.pooler = get_text_model_pooler(
                config.text_model_pooler
            )(**config.text_model_pooler_kwargs)


AutoModel.register(CustomCLIPConfig, CustomCLIPModel)
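

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test under assumptions: a BERT-style text tower that exposes
# a `.pooler` attribute, and placeholder checkpoint names that this project
# does not necessarily use. With `config=None`, the parent class assembles a
# base config from the two towers, and the class-level default pooler is
# installed by `__init__` above.
if __name__ == "__main__":
    from transformers import BertModel, CLIPVisionModel

    vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
    text_model = BertModel.from_pretrained("bert-base-uncased")
    model = CustomCLIPModel(vision_model=vision_model, text_model=text_model)
    # The text pooler should now be the custom default rather than BERT's own.
    print(type(model.text_model.pooler).__name__)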