# aptlm/aptaESM_head.py
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import EsmModel  # imported but unused here; the ESM encoder is passed in externally


class ProteinDataset(Dataset):
    """Pairs each protein sequence with tokenizations from both tokenizers."""

    def __init__(self, data, esm_tokenizer, apta_tokenizer):
        self.data = data
        self.esm_tokenizer = esm_tokenizer
        self.apta_tokenizer = apta_tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Column 0 of the dataframe holds the raw protein sequence.
        protein_seq = self.data.iloc[idx, 0]
        protein_apta_tokens = self.apta_tokenizer(
            protein_seq, return_tensors="pt", padding=True, truncation=True
        )
        protein_esm_tokens = self.esm_tokenizer(
            protein_seq, return_tensors="pt", padding=True, truncation=True
        )
        return protein_apta_tokens, protein_esm_tokens
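
# Example construction (a hedged sketch; the checkpoint name and CSV layout are
# assumptions, since this file does not show how the tokenizers or data are
# created — the 1280-dim / layer-33 constants below are consistent with the
# ESM-2 checkpoint esm2_t33_650M_UR50D):
#
#   from transformers import EsmTokenizer
#   esm_tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#   dataset = ProteinDataset(load_dataset("proteins.csv"), esm_tok, apta_tok)
#   apta_tokens, esm_tokens = dataset[0]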


class ContrastiveNetwork(nn.Module):
    def __init__(self, apta_encoder, esm_encoder, out_tokens=64, output_dim=128):
        super(ContrastiveNetwork, self).__init__()
        self.apta_encoder = apta_encoder
        self.esm_encoder = esm_encoder
        # 128 = aptamer-side embedding dim, 1280 = ESM-2 (t33) embedding dim.
        self.apta_proj = nn.Linear(128, out_tokens)
        self.esm_proj = nn.Linear(1280, out_tokens)
        self.final_proj = nn.Linear(1, output_dim)

    def forward(self, protein_apta_tokens, protein_esm_tokens):
        # Input sizes: protein_apta_tokens: (BS x #aptaTokens),
        #              protein_esm_tokens:  (BS x #esmTokens)
        protein_apta_embedding = self.apta_encoder(protein_apta_tokens)
        # Note: the call below follows the fair-esm forward signature
        # (repr_layers / return_contacts), not transformers.EsmModel.
        protein_esm_embedding = self.esm_encoder(
            protein_esm_tokens, repr_layers=[33], return_contacts=False
        )["representations"][33]
        # Sizes: protein_apta_embedding: (BS x #aptaTokens x apta_embed_dim),
        #        protein_esm_embedding:  (BS x #esmTokens x esm_embed_dim)
        # Mean-pool over the token dimension.
        protein_apta_pooled = protein_apta_embedding.mean(dim=1)
        protein_esm_pooled = protein_esm_embedding.mean(dim=1)
        # Project both embeddings into a common space.
        protein_apta_proj = self.apta_proj(protein_apta_pooled)
        protein_esm_proj = self.esm_proj(protein_esm_pooled)
        # Concatenate the projected embeddings: (BS x 2*out_tokens).
        combined_embedding = torch.cat((protein_apta_proj, protein_esm_proj), dim=1)
        # Add a trailing singleton dim so each scalar is projected up to output_dim.
        combined_embedding_reshape = torch.unsqueeze(combined_embedding, -1)
        # Final projection to match the original AptaTrans embedding size.
        output_embedding = self.final_proj(combined_embedding_reshape)
        # Output shape: (BS x 2*out_tokens x output_dim) - an interaction map
        # can be constructed from this.
        return protein_apta_proj, protein_esm_proj, output_embedding
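
# Forward-pass sketch (hedged: assumes apta_encoder accepts the tokenized input
# directly and that esm_encoder follows the fair-esm ESM-2 (t33) interface used
# above; variable names are illustrative):
#
#   net = ContrastiveNetwork(apta_encoder, esm_model)
#   apta_proj, esm_proj, out = net(apta_tokens, esm_tokens)
#   # apta_proj, esm_proj: (BS x out_tokens); out: (BS x 2*out_tokens x output_dim)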


def load_dataset(file_path):
    data = pd.read_csv(file_path)
    return data


def contrastive_loss(embedding_a, embedding_b, margin=1.0):
    # Note: only the positive-pair distance is minimized here; the margin
    # argument is accepted but unused (there is no negative-pair term).
    positive_distance = nn.functional.pairwise_distance(embedding_a, embedding_b)
    loss = positive_distance.mean()
    return loss
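

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original training code): stand-in encoders emit random embeddings with
    # the dims the projections above hard-code (128 aptamer-side, 1280
    # ESM-side), so the shape logic runs without downloading model weights.
    class _DummyAptaEncoder(nn.Module):
        def forward(self, tokens):
            return torch.randn(2, 10, 128)  # (BS x #aptaTokens x apta_embed_dim)

    class _DummyEsmEncoder(nn.Module):
        def forward(self, tokens, repr_layers=None, return_contacts=False):
            # Mimic the fair-esm return format used in ContrastiveNetwork.forward.
            return {"representations": {33: torch.randn(2, 12, 1280)}}

    net = ContrastiveNetwork(_DummyAptaEncoder(), _DummyEsmEncoder())
    apta_proj, esm_proj, out = net(None, None)
    print(apta_proj.shape, esm_proj.shape, out.shape)
    # -> torch.Size([2, 64]) torch.Size([2, 64]) torch.Size([2, 128, 128])
    print(contrastive_loss(apta_proj, esm_proj).item())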