# aptlm/encoders.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from sklearn.metrics import f1_score  # used by find_opt_threshold below
class Encoder(nn.Module):
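    """A single post-norm Transformer encoder block:
    self-attention -> add & norm -> position-wise feed-forward -> add & norm.
    """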
def __init__(self, d_ff=512, d_model=128, n_heads=8, dropout=.3, max_len=512):
super(Encoder, self).__init__()
#hyperparameters
self.d_ff = d_ff
        self.d_model = d_model
self.n_heads = n_heads
self.dropout_rate = dropout
self.max_len = max_len
#layers
self.MultiHeadAttention = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)
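        # Note: a single LayerNorm (normalizing jointly over the sequence and
        # feature dimensions) is shared by both sub-layers, unlike the
        # per-sub-layer, feature-only LayerNorm of the original Transformer.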
self.LayerNorm = nn.LayerNorm(normalized_shape=[max_len, d_model], eps=1e-6)
self.fc_ff = nn.Linear(d_model, d_ff)
self.relu = nn.ReLU()
self.fc_model = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, padding_mask):
residual = x
        x = self.MultiHeadAttention(x, x, x, key_padding_mask=padding_mask, need_weights=False)[0]  # self-attention: the same tensor is passed as query, key, and value
x = self.dropout(x)
x = self.LayerNorm(x + residual)
residual = x
x = self.fc_ff(x)
x = self.relu(x)
x = self.fc_model(x)
x = self.dropout(x)
x = self.LayerNorm(x + residual)
return x
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encodings (Vaswani et al., 2017), batch-first."""
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # precompute the encodings once and register them as a non-trainable buffer
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x: torch.Tensor):
        # x: (BS x seq_len x d_model); slice the encodings to the input's
        # sequence length and broadcast them over the batch dimension
        x = x + self.pe[:, :x.size(1)]
        return x
class Encoders(nn.Module):
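    """A stack of n_layers Encoder blocks over learned token embeddings plus
    sinusoidal positional encodings; token id 0 is treated as padding.
    """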
def __init__(self, n_vocabs, n_layers=6, d_ff=512, d_model=128, n_heads=8, dropout=.3, max_len=512):
super(Encoders, self).__init__()
#hyperparameters
self.n_vocabs = n_vocabs
self.n_layers = n_layers
self.d_ff = d_ff
        self.d_model = d_model
self.n_heads = n_heads
self.dropout_rate = dropout
self.max_len = max_len
#layers
self.Embedding = nn.Embedding(num_embeddings=n_vocabs, embedding_dim=d_model, padding_idx=0)
self.PositionalEncoding = PositionalEncoding(d_model, max_len=max_len)
self.encoders = nn.ModuleList([Encoder(d_ff=d_ff, d_model=d_model, n_heads=n_heads, dropout=dropout, max_len=max_len) for _ in range(n_layers)])
def forward(self, x):
padding_mask = self.create_padding_mask(x) # (BS x max_len) -> (BS x max_len)
x = self.Embedding(x) # (BS x max_len) -> (BS x max_len x d_model)
x = self.PositionalEncoding(x)
        x[padding_mask] = 0.0  # zero out the embeddings at padded positions
for i in range(self.n_layers):
x = self.encoders[i](x, padding_mask)
            x[padding_mask] = 0.0  # re-zero the padded positions after each block
return x
    def create_padding_mask(self, x):
        # True wherever the token id equals the padding index (0)
        return torch.eq(x, 0)
class Token_Pretrained_Model(nn.Module):
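    """Pretraining wrapper with two heads over a shared Encoders stack:
    masked language modelling (MLM) and secondary-structure prediction (SSP).
    """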
def __init__(self, n_vocabs, n_target_vocabs, d_ff, d_model, n_layers, n_heads, dropout, max_len):
super(Token_Pretrained_Model, self).__init__()
self.encoder = Encoders(n_vocabs=n_vocabs, d_ff=d_ff, d_model=d_model, n_layers=n_layers, n_heads=n_heads, dropout=dropout, max_len=max_len)
self.fc1_mlm = nn.Linear(d_model, d_model)
self.gelu_mlm = nn.GELU()
self.norm_mlm = nn.LayerNorm(normalized_shape=[max_len, d_model], eps=1e-6)
self.fc2_mlm = nn.Linear(d_model, n_vocabs)
self.fc1_ssp = nn.Linear(d_model, d_model)
self.gelu_ssp = nn.GELU()
self.norm_ssp = nn.LayerNorm(normalized_shape=[max_len, d_model], eps=1e-6)
self.fc2_ssp = nn.Linear(d_model, n_target_vocabs)
def forward(self, inputs_mlm, inputs_ssp):
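        # MLM head: project each position back onto the token vocabulary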
enc_mlm = self.encoder(inputs_mlm)
output_mlm = self.fc1_mlm(enc_mlm)
output_mlm = self.gelu_mlm(output_mlm)
output_mlm = self.norm_mlm(output_mlm)
output_mlm = self.fc2_mlm(output_mlm)
        output_mlm = F.log_softmax(output_mlm, dim=-1)
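        # SSP head: project each position onto the target (structure) vocabulary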
enc_ssp = self.encoder(inputs_ssp)
output_ssp = self.fc1_ssp(enc_ssp)
output_ssp = self.gelu_ssp(output_ssp)
output_ssp = self.norm_ssp(output_ssp)
output_ssp = self.fc2_ssp(output_ssp)
        output_ssp = F.log_softmax(output_ssp, dim=-1)
return output_mlm, output_ssp
class Convolution_Block(nn.Module):
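    """Residual block: two 4x4 same-padded convolutions with BatchNorm and GELU,
    plus an identity skip connection.
    """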
    def __init__(self, n_channels):
        super(Convolution_Block, self).__init__()
        self.conv1 = nn.Conv2d(n_channels, n_channels, (4, 4), padding='same')
        self.batchnorm1 = nn.BatchNorm2d(n_channels)
        self.conv2 = nn.Conv2d(n_channels, n_channels, (4, 4), padding='same')
        self.batchnorm2 = nn.BatchNorm2d(n_channels)
self.gelu = nn.GELU()
def forward(self, inputs):
output = self.conv1(inputs)
output = self.batchnorm1(output)
output = self.gelu(output)
output = self.conv2(output)
output = self.batchnorm2(output)
output = self.gelu(output)
output = output + inputs
return output
class Downsized_Convolution_Block(nn.Module):
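    """Downsampling block: 2x2 max-pooling halves the spatial size, then two
    4x4 same-padded convolutions change the channel count.
    """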
    def __init__(self, in_channels, out_channels):
        super(Downsized_Convolution_Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, (4, 4), padding='same')
        self.batchnorm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, (4, 4), padding='same')
        self.batchnorm2 = nn.BatchNorm2d(out_channels)
        self.maxpool = nn.MaxPool2d((2, 2))
        self.gelu = nn.GELU()
def forward(self, inputs):
output = self.maxpool(inputs)
output = self.conv1(output)
output = self.batchnorm1(output)
output = self.gelu(output)
output = self.conv2(output)
output = self.batchnorm2(output)
output = self.gelu(output)
return output
class AptaTrans(nn.Module):
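    """Aptamer-protein interaction predictor: encodes both sequences, forms an
    interaction map via a batched matrix product of the two embeddings, and
    scores it with a residual CNN ending in a sigmoid.
    """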
def __init__(self, apta_encoder, prot_encoder, n_apta_vocabs, n_prot_vocabs, dropout, apta_max_len, prot_max_len):
super(AptaTrans, self).__init__()
#hyperparameters
self.n_apta_vocabs = n_apta_vocabs
self.n_prot_vocabs = n_prot_vocabs
self.apta_max_len = apta_max_len
self.prot_max_len = prot_max_len
self.dropout = dropout
self.apta_encoder = apta_encoder
self.prot_encoder = prot_encoder
        self.c = 85  # base channel width of the interaction-map CNN
        self.n_channels = self.c
        self.batchnorm_fm = nn.BatchNorm2d(1)
        self.conv = nn.Conv2d(1, self.n_channels, (4, 4))
        self.batchnorm = nn.BatchNorm2d(self.n_channels)
self.gelu = nn.GELU()
self.maxpool = nn.MaxPool2d((2,2))
self.conv64_1 = Convolution_Block(self.c)
self.conv64_2 = Convolution_Block(self.c)
self.conv64_3 = Convolution_Block(self.c)
self.dconv128 = Downsized_Convolution_Block(self.c, 2*self.c)
self.conv128_1 = Convolution_Block(2*self.c)
self.conv128_2 = Convolution_Block(2*self.c)
self.dconv256 = Downsized_Convolution_Block(2*self.c, 4*self.c)
self.conv256_1 = Convolution_Block(4*self.c)
self.conv256_2 = Convolution_Block(4*self.c)
# self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.flatten = nn.Flatten()
        # self.fc = nn.Linear(256 * 34 * 108, 128)  # pre-ESM-2 input size
        self.fc = nn.Linear(2416040, 128)  # TODO: clean this up | input size follows from the combined ESM-2 and AptaTrans encoder outputs
        self.fc_reshape = nn.Linear(1280, 128)  # projects 1280-dim ESM-2 encodings to match the AptaTrans embedding dim
self.fc1 = nn.Linear(128, 1)
def forward(self, apta, esm_prot):
        apta = self.apta_encoder(apta)  # (BS x #apta_toks) -> (BS x #apta_toks x apta_embed_dim); pretrained MLM/secondary-structure features
prot = self.prot_encoder(esm_prot, repr_layers=[33], return_contacts=False)['representations'][33]
        # linear projection so the ESM-2 embedding dim matches the AptaTrans embedding dim
prot = self.fc_reshape(prot)
prot = torch.transpose(prot, 1, 2) # (BS x embed_dim x #toks)
output = torch.bmm(apta, prot) # (BS x #apt_toks x apt_embed_dim) x (BS x apt_embed_dim x #prot_toks) = (BS x #apt_toks x #prot_toks)
output = torch.unsqueeze(output, 1)
output = self.batchnorm_fm(output)
output = self.conv(output)
output = self.batchnorm(output)
output = self.gelu(output)
output = self.conv64_1(output)
output = self.conv64_2(output)
output = self.conv64_3(output)
output = self.dconv128(output)
output = self.conv128_1(output)
output = self.conv128_2(output)
output = self.dconv256(output)
output = self.conv256_1(output)
output = self.conv256_2(output)
output = self.maxpool(output)
# print(output.shape)
output = self.flatten(output)
output = self.fc(output)
output = self.gelu(output)
output = self.fc1(output)
output = torch.sigmoid(output)
return output
def generate_interaction_map(self, apta, prot):
with torch.no_grad():
apta = self.apta_encoder(apta)
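            # NOTE: unlike forward(), this passes prot straight through prot_encoder;
            # with an ESM-2 encoder the 'representations' dict would need unpacking here as well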
prot = self.prot_encoder(prot)
prot = torch.transpose(prot, 1, 2)
interaction_map = torch.bmm(apta, prot)
interaction_map = torch.unsqueeze(interaction_map, 1)
interaction_map = self.batchnorm_fm(interaction_map)
return interaction_map
    def conv_block_proba(self, interaction_map):
        with torch.no_grad():
            # accept numpy or tensor input and move it to the model's device
            device = next(self.parameters()).device
            output = torch.as_tensor(interaction_map, dtype=torch.float32, device=device)
            output = torch.unsqueeze(output, 1)
            output = torch.mean(output, 4)
output = self.conv(output)
output = self.batchnorm(output)
output = self.gelu(output)
output = self.conv64_1(output)
output = self.conv64_2(output)
output = self.conv64_3(output)
output = self.dconv128(output)
output = self.conv128_1(output)
output = self.conv128_2(output)
output = self.dconv256(output)
output = self.conv256_1(output)
output = self.conv256_2(output)
output = self.maxpool(output)
output = self.flatten(output)
output = self.fc(output)
output = self.gelu(output)
output = self.fc1(output)
output = torch.sigmoid(output)
            probs = output.detach().cpu().numpy()
            output = np.array([[1 - p[0], p[0]] for p in probs])
return output
def find_opt_threshold(target, pred):
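    # sweep thresholds 0.000-0.999 in steps of 0.001 and return the one that
    # maximizes the F1 score of the binarized predictions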
result = 0
best = 0
for i in range(0, 1000):
pred_threshold = np.where(pred > i/1000, 1, 0)
now = f1_score(target, pred_threshold)
if now > best:
result = i/1000
best = now
return result
def augment_seqset(seqset):
    # double the (sequence, secondary structure) set by appending reversed copies
    aug_seqset = []
    for s, ss in seqset:
        aug_seqset.append([s, ss])
        aug_seqset.append([s[::-1], ss[::-1]])
    return aug_seqset
def augment_apis(apta, prot, ys):
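    # 4x augmentation: every combination of original and reversed aptamer/protein sequences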
aug_apta = []
aug_prot = []
aug_y = []
for a, p, y in zip(apta, prot, ys):
aug_apta.append(a)
aug_prot.append(p)
aug_y.append(y)
aug_apta.append(a[::-1])
aug_prot.append(p)
aug_y.append(y)
aug_apta.append(a)
aug_prot.append(p[::-1])
aug_y.append(y)
aug_apta.append(a[::-1])
aug_prot.append(p[::-1])
aug_y.append(y)
return np.array(aug_apta), np.array(aug_prot), np.array(aug_y)
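# Minimal smoke test for the encoder stack (illustrative only; the sizes and
# random token ids below are made-up values, not part of the training pipeline).
# Note that inputs must be exactly max_len tokens long, since the blocks'
# LayerNorm normalizes over [max_len, d_model].
if __name__ == "__main__":
    model = Encoders(n_vocabs=32, n_layers=2, d_ff=64, d_model=16, n_heads=4, dropout=0.1, max_len=10)
    tokens = torch.randint(1, 32, (2, 10))  # (BS x max_len); id 0 is reserved for padding
    tokens[:, 8:] = 0  # pad the tail of each sequence
    out = model(tokens)
    print(out.shape)  # torch.Size([2, 10, 16])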