import torch.nn as nn
import torchvision.models as models

from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig

from cliport.models.core import fusion
from cliport.models.rn50_bert_lingunet import RN50BertLingUNet


class UntrainedRN50BertLingUNet(RN50BertLingUNet):
    """ Untrained (randomly initialized) ResNet-50 & DistilBERT with U-Net skip connections """

    def __init__(self, input_shape, output_dim, cfg, device, preprocess):
        super().__init__(input_shape, output_dim, cfg, device, preprocess)

    def _load_vision_fcn(self):
        # ResNet-50 backbone without pre-trained ImageNet weights
        resnet50 = models.resnet50(pretrained=False)
        modules = list(resnet50.children())[:-2]  # drop the avgpool and fc head

        self.stem = nn.Sequential(*modules[:4])
        self.layer1 = modules[4]
        self.layer2 = modules[5]
        self.layer3 = modules[6]
        self.layer4 = modules[7]

    def _load_lang_enc(self):
        # only the tokenizer is pre-trained; the DistilBERT encoder weights are randomly initialized
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        distilbert_config = DistilBertConfig()
        self.text_encoder = DistilBertModel(distilbert_config)

        self.text_fc = nn.Linear(768, 1024)

        self.lang_fuser1 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 2)
        self.lang_fuser2 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 4)
        self.lang_fuser3 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 8)

        self.proj_input_dim = 512 if 'word' in self.lang_fusion_type else 1024
        self.lang_proj1 = nn.Linear(self.proj_input_dim, 1024)
        self.lang_proj2 = nn.Linear(self.proj_input_dim, 512)
        self.lang_proj3 = nn.Linear(self.proj_input_dim, 256)
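

# A minimal standalone sketch (not part of the original class) illustrating the key
# difference from the pre-trained variant: only the DistilBERT tokenizer is loaded
# from a checkpoint, while the text encoder and the ResNet-50 backbone start from
# random weights. The example instruction string and tensor shapes below are
# assumptions for illustration only.
if __name__ == '__main__':
    import torch

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    text_encoder = DistilBertModel(DistilBertConfig())  # random weights, not from_pretrained

    tokens = tokenizer('pack the red block into the brown box', return_tensors='pt')
    with torch.no_grad():
        hidden = text_encoder(**tokens).last_hidden_state  # (1, seq_len, 768)

    backbone = models.resnet50(pretrained=False)         # randomly initialized RN50
    feats = backbone.conv1(torch.zeros(1, 3, 224, 224))  # stem expects a 3-channel image

    print(hidden.shape, feats.shape)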