import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas as pd

try:
    # EsmModel is not referenced in this module; the import is kept for code
    # that may import it from here, but guarded so the dataset, model and loss
    # below remain usable where `transformers` is not installed.
    from transformers import EsmModel  # noqa: F401
except ImportError:  # pragma: no cover - depends on the environment
    EsmModel = None


class ProteinDataset(Dataset):
    """Map-style dataset yielding AptaTrans and ESM tokenizations of one protein.

    Each item reads the protein sequence from column 0 of ``data`` and
    tokenizes it twice: once with ``apta_tokenizer``, once with
    ``esm_tokenizer``.

    Args:
        data: positionally indexable table (``.iloc``), e.g. the DataFrame
            returned by ``load_dataset``; column 0 holds the sequence.
        esm_tokenizer: callable accepting HF-tokenizer-style keyword
            arguments (``return_tensors``, ``padding``, ``truncation``).
        apta_tokenizer: same calling convention as ``esm_tokenizer``.
    """

    def __init__(self, data, esm_tokenizer, apta_tokenizer):
        self.data = data
        self.esm_tokenizer = esm_tokenizer
        self.apta_tokenizer = apta_tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(apta_tokens, esm_tokens)`` for the sequence at row ``idx``."""
        protein_seq = self.data.iloc[idx, 0]
        # NOTE(review): for a single string, HF-style tokenizers with
        # return_tensors="pt" add a leading batch dim of 1, and padding=True
        # has nothing to pad against. Batching items of different lengths
        # therefore needs a custom collate_fn — confirm the DataLoader setup.
        tokenize_kwargs = dict(return_tensors="pt", padding=True, truncation=True)
        protein_apta_tokens = self.apta_tokenizer(protein_seq, **tokenize_kwargs)
        protein_esm_tokens = self.esm_tokenizer(protein_seq, **tokenize_kwargs)
        return protein_apta_tokens, protein_esm_tokens


def _masked_mean(embedding, mask=None):
    """Average ``embedding`` over dim 1, optionally skipping masked positions.

    Args:
        embedding: tensor of shape (BS, T, D).
        mask: optional (BS, T) tensor (bool or numeric); nonzero entries mark
            positions to include. ``None`` gives ``embedding.mean(dim=1)``.

    Returns:
        Tensor of shape (BS, D). A row whose mask is all zero yields zeros.
    """
    if mask is None:
        return embedding.mean(dim=1)
    weights = mask.to(embedding.dtype).unsqueeze(-1)  # (BS, T, 1)
    # clamp keeps an all-padding row from dividing by zero
    return (embedding * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)


class ContrastiveNetwork(nn.Module):
    """Project AptaTrans and ESM protein embeddings into a shared space.

    Each encoder's per-token output is mean-pooled over the token axis and
    linearly projected to ``out_tokens`` features. The two projections are
    concatenated and lifted by ``final_proj`` to a
    (BS, 2 * out_tokens, output_dim) tensor.

    Args:
        apta_encoder: callable mapping AptaTrans tokens to per-token
            embeddings of width ``apta_embed_dim``.
        esm_encoder: callable invoked as
            ``esm_encoder(tokens, repr_layers=[layer], return_contacts=False)``
            and returning a mapping with ``["representations"][layer]``.
            This matches the fair-esm model interface — TODO confirm that is
            the model actually passed in.
        out_tokens: width of each branch's linear projection.
        output_dim: size of the last axis of the final embedding.
        apta_embed_dim: AptaTrans embedding width (default 128).
        esm_embed_dim: ESM embedding width (default 1280; presumably the
            33-layer ESM-2 650M model — verify).
        esm_repr_layer: index of the ESM layer whose representations are
            pooled (default 33).
    """

    def __init__(self, apta_encoder, esm_encoder, out_tokens=64, output_dim=128,
                 apta_embed_dim=128, esm_embed_dim=1280, esm_repr_layer=33):
        super().__init__()
        self.apta_encoder = apta_encoder
        self.esm_encoder = esm_encoder
        self.esm_repr_layer = esm_repr_layer
        # Creation order matches the original so seeded initialization is unchanged.
        self.apta_proj = nn.Linear(apta_embed_dim, out_tokens)
        self.esm_proj = nn.Linear(esm_embed_dim, out_tokens)
        self.final_proj = nn.Linear(1, output_dim)

    def forward(self, protein_apta_tokens, protein_esm_tokens, *,
                apta_mask=None, esm_mask=None):
        """Embed one batch of proteins with both encoders.

        Args:
            protein_apta_tokens: AptaTrans tokens, expected (BS, #aptaTokens).
            protein_esm_tokens: ESM tokens, expected (BS, #esmTokens).
            apta_mask: optional (BS, #aptaTokens) mask; nonzero entries are
                included in the mean pool. ``None`` pools over every
                position, padding included.
            esm_mask: like ``apta_mask`` but for the ESM representations. It
                must align with the encoder's output positions (ESM may add
                special tokens — confirm the mask covers them).

        Returns:
            ``(apta_proj, esm_proj, output_embedding)`` with shapes
            (BS, out_tokens), (BS, out_tokens) and
            (BS, 2 * out_tokens, output_dim).
        """
        layer = self.esm_repr_layer
        # Expected sizes: (BS, #aptaTokens, apta_embed_dim) and
        # (BS, #esmTokens, esm_embed_dim).
        protein_apta_embedding = self.apta_encoder(protein_apta_tokens)
        esm_output = self.esm_encoder(
            protein_esm_tokens, repr_layers=[layer], return_contacts=False
        )
        protein_esm_embedding = esm_output['representations'][layer]

        protein_apta_pooled = _masked_mean(protein_apta_embedding, apta_mask)
        protein_esm_pooled = _masked_mean(protein_esm_embedding, esm_mask)

        # Project both branches to a common width.
        protein_apta_proj = self.apta_proj(protein_apta_pooled)
        protein_esm_proj = self.esm_proj(protein_esm_pooled)

        # (BS, 2 * out_tokens) -> (BS, 2 * out_tokens, 1) -> (BS, 2 * out_tokens, output_dim),
        # matching the original AptaTrans embedding size so an interaction map
        # can be built from it.
        combined_embedding = torch.cat((protein_apta_proj, protein_esm_proj), dim=1)
        output_embedding = self.final_proj(combined_embedding.unsqueeze(-1))
        return protein_apta_proj, protein_esm_proj, output_embedding


def load_dataset(file_path):
    """Read a CSV into a pandas DataFrame.

    ``ProteinDataset`` reads the protein sequence from the first column.
    """
    return pd.read_csv(file_path)


def contrastive_loss(embedding_a, embedding_b, margin=1.0):
    """Return the mean pairwise (p=2) distance between matching rows.

    NOTE(review): despite the name, this has no negative-pair term. It only
    pulls positive pairs together, so on its own it can be minimized by
    mapping every input to the same point. ``margin`` is accepted for API
    compatibility but is currently unused.
    """
    positive_distance = nn.functional.pairwise_distance(embedding_a, embedding_b)
    return positive_distance.mean()