import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import EsmModel
import pandas as pd


class ProteinDataset(Dataset):
    """Yields two tokenized views of the same protein sequence: one for the
    aptamer-side encoder and one for the ESM protein language model."""

    def __init__(self, data, esm_tokenizer, apta_tokenizer):
        self.data = data
        self.esm_tokenizer = esm_tokenizer
        self.apta_tokenizer = apta_tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        protein_seq = self.data.iloc[idx, 0]
        # Tokenizing one sequence with return_tensors="pt" gives (1, seq_len)
        # tensors; padding a single sequence is a no-op, so variable-length
        # sequences must be padded at batch time by a collate_fn (see below).
        protein_apta_tokens = self.apta_tokenizer(protein_seq, return_tensors="pt", truncation=True)
        protein_esm_tokens = self.esm_tokenizer(protein_seq, return_tensors="pt", truncation=True)
        return protein_apta_tokens, protein_esm_tokens
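

# A minimal collate sketch for batching the (1, seq_len) tensors returned by
# ProteinDataset. The helper name `pad_collate` and padding_value=0 are
# assumptions for illustration, not part of the original pipeline; in practice
# the pad value should come from each tokenizer's pad token id.
def pad_collate(batch):
    apta_list, esm_list = zip(*batch)

    def pad_stack(token_dicts):
        # Pad every tokenizer field (input_ids, attention_mask, ...) to the
        # longest sequence in the batch, stacking into (batch, max_len) tensors.
        padded = {}
        for key in token_dicts[0]:
            seqs = [d[key].squeeze(0) for d in token_dicts]
            padded[key] = nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
        return padded

    return pad_stack(apta_list), pad_stack(esm_list)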


class ContrastiveNetwork(nn.Module):
    """Projects the two protein views into a shared space for contrastive training."""

    def __init__(self, apta_encoder, esm_encoder, out_tokens=64, output_dim=128):
        super().__init__()
        self.apta_encoder = apta_encoder
        self.esm_encoder = esm_encoder
        # 128 is the aptamer-side encoder's hidden size; 1280 is the hidden
        # size of the 33-layer ESM-2 650M model (esm2_t33_650M_UR50D).
        self.apta_proj = nn.Linear(128, out_tokens)
        self.esm_proj = nn.Linear(1280, out_tokens)
        # Expands each scalar of the concatenated projection into an
        # output_dim-wide vector, yielding 2 * out_tokens output tokens.
        self.final_proj = nn.Linear(1, output_dim)

    def forward(self, protein_apta_tokens, protein_esm_tokens):
        # Per-token embeddings: (batch, seq_len, 128) and (batch, seq_len, 1280).
        # transformers' EsmModel takes keyword tensors and exposes the final
        # layer as last_hidden_state.
        protein_apta_embedding = self.apta_encoder(protein_apta_tokens)
        protein_esm_embedding = self.esm_encoder(**protein_esm_tokens).last_hidden_state

        # Mean-pool over the sequence dimension: (batch, 128) and (batch, 1280).
        protein_apta_pooled = protein_apta_embedding.mean(dim=1)
        protein_esm_pooled = protein_esm_embedding.mean(dim=1)

        # Project both views into the shared contrastive space: (batch, out_tokens).
        protein_apta_proj = self.apta_proj(protein_apta_pooled)
        protein_esm_proj = self.esm_proj(protein_esm_pooled)

        # Concatenate the two views: (batch, 2 * out_tokens).
        combined_embedding = torch.cat((protein_apta_proj, protein_esm_proj), dim=1)

        # (batch, 2 * out_tokens, 1) -> (batch, 2 * out_tokens, output_dim).
        combined_embedding_reshape = torch.unsqueeze(combined_embedding, -1)
        output_embedding = self.final_proj(combined_embedding_reshape)

        return protein_apta_proj, protein_esm_proj, output_embedding
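

# A quick shape check of the forward pass using stub encoders. The stubs are
# illustrative assumptions that only mimic the expected interfaces: an
# aptamer-side encoder returning (batch, seq_len, 128), and an EsmModel-style
# encoder exposing last_hidden_state of width 1280.
def _shape_check():
    from types import SimpleNamespace

    class StubAptaEncoder(nn.Module):
        def forward(self, tokens):
            batch, seq_len = tokens["input_ids"].shape
            return torch.randn(batch, seq_len, 128)

    class StubEsmEncoder(nn.Module):
        def forward(self, input_ids=None, attention_mask=None):
            batch, seq_len = input_ids.shape
            return SimpleNamespace(last_hidden_state=torch.randn(batch, seq_len, 1280))

    model = ContrastiveNetwork(StubAptaEncoder(), StubEsmEncoder())
    tokens = {"input_ids": torch.ones(2, 12, dtype=torch.long),
              "attention_mask": torch.ones(2, 12, dtype=torch.long)}
    apta_proj, esm_proj, output = model(tokens, tokens)
    print(apta_proj.shape, esm_proj.shape, output.shape)
    # torch.Size([2, 64]) torch.Size([2, 64]) torch.Size([2, 128, 128])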


def load_dataset(file_path):
    """Loads a CSV whose first column holds protein sequences, as indexed by
    ProteinDataset.__getitem__."""
    data = pd.read_csv(file_path)
    return data


def contrastive_loss(embedding_a, embedding_b, margin=1.0):
    """Positive-pair contrastive loss: pulls the two views of the same protein
    together by minimizing their mean pairwise distance.

    Note: with only positive pairs, `margin` is unused; it becomes relevant
    once negative (mismatched) pairs are added and pushed at least `margin`
    apart.
    """
    positive_distance = nn.functional.pairwise_distance(embedding_a, embedding_b)
    loss = positive_distance.mean()
    return loss
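

# A minimal training-loop sketch tying the pieces together. Everything named
# here (the CSV path, batch size, learning rate, epoch count, and the use of
# the pad_collate helper above) is an assumption for illustration; the two
# encoders and tokenizers must be supplied by the caller.
def train(apta_encoder, esm_encoder, esm_tokenizer, apta_tokenizer,
          csv_path="proteins.csv", epochs=1, batch_size=8, lr=1e-4):
    from torch.utils.data import DataLoader

    data = load_dataset(csv_path)
    dataset = ProteinDataset(data, esm_tokenizer, apta_tokenizer)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        collate_fn=pad_collate)

    model = ContrastiveNetwork(apta_encoder, esm_encoder)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        for apta_tokens, esm_tokens in loader:
            apta_proj, esm_proj, _ = model(apta_tokens, esm_tokens)
            # Align the two projections of the same protein in the shared space.
            loss = contrastive_loss(apta_proj, esm_proj)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model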