GenSim / cliport /models /untrained_rn50_bert_lingunet.py
LeroyWaa's picture
add gensim code
8fc2b4e
raw
history blame
No virus
1.69 kB
import torch.nn as nn
import torchvision.models as models
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
from cliport.models.core import fusion
from cliport.models.rn50_bert_lingunet import RN50BertLingUNet
class UntrainedRN50BertLingUNet(RN50BertLingUNet):
    """Untrained ImageNet RN50 & Bert with U-Net skip connections.

    Variant of ``RN50BertLingUNet`` in which both the ResNet-50 vision
    backbone and the DistilBERT text encoder start from random weights.
    Only the tokenizer (vocabulary) is loaded from a pre-trained checkpoint.
    """

    def __init__(self, input_shape, output_dim, cfg, device, preprocess):
        """Forward every argument unchanged to ``RN50BertLingUNet``."""
        super().__init__(input_shape, output_dim, cfg, device, preprocess)

    def _load_vision_fcn(self):
        """Create a randomly initialised ResNet-50 and expose its stages.

        Sets ``self.stem`` and ``self.layer1`` .. ``self.layer4``.
        """
        # No arguments => no pre-trained weights on every torchvision
        # release. The explicit ``pretrained=False`` keyword is deprecated
        # since torchvision 0.13 and only triggers a warning there.
        resnet50 = models.resnet50()

        # Drop the last two children (global average pool and the
        # classification head in torchvision's ResNet) to keep a fully
        # convolutional feature extractor.
        modules = list(resnet50.children())[:-2]

        # First four children (conv1, bn1, relu, maxpool) form the stem;
        # the remaining four are the residual stages.
        self.stem = nn.Sequential(*modules[:4])
        self.layer1 = modules[4]
        self.layer2 = modules[5]
        self.layer3 = modules[6]
        self.layer4 = modules[7]

    def _load_lang_enc(self):
        """Create the language encoder, fusion modules and projections.

        Reads ``self.lang_fusion_type`` and ``self.input_dim``; these are
        presumably set by the parent class before this hook runs -- verify
        against ``RN50BertLingUNet``.
        """
        # Only the tokenizer is pre-trained.
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

        # Building the model from a default config (instead of
        # ``from_pretrained``) leaves the encoder weights randomly
        # initialised.
        distilbert_config = DistilBertConfig()
        self.text_encoder = DistilBertModel(distilbert_config)

        # Take the input width from the config so the layer cannot drift
        # out of sync with the encoder (768 for the default config).
        self.text_fc = nn.Linear(distilbert_config.dim, 1024)

        # One fusion module per decoder scale, each with a narrower input.
        fuser_cls = fusion.names[self.lang_fusion_type]
        self.lang_fuser1 = fuser_cls(input_dim=self.input_dim // 2)
        self.lang_fuser2 = fuser_cls(input_dim=self.input_dim // 4)
        self.lang_fuser3 = fuser_cls(input_dim=self.input_dim // 8)

        # Fusion types whose name contains 'word' use a 512-wide language
        # feature; all others use 1024.
        self.proj_input_dim = 512 if 'word' in self.lang_fusion_type else 1024
        self.lang_proj1 = nn.Linear(self.proj_input_dim, 1024)
        self.lang_proj2 = nn.Linear(self.proj_input_dim, 512)
        self.lang_proj3 = nn.Linear(self.proj_input_dim, 256)