import pickle
import random

import esm
import numpy as np
import torch

from utils import tokenize_sequences, rna2vec


class SimAnnealer:
    """Simulated annealing over candidate RNA aptamers against a fixed protein target."""

    def __init__(self, temperature, model, steps, target, length):
        self.init_temp = temperature
        self.temp = temperature
        self.device = 'cpu'
        self.model = model.to(self.device)
        self.length = length
        self.initial_state = self.random_state_generator()
        self.alpha = 0.95  # geometric cooling rate
        self.steps = steps
        self.target = target
        self.n_prot_vocabs = 1 + 713 + 1  # pad + vocab + msk
        self.n_prot_target_vocabs = 1 + 584  # pad + vocab
        self.prot_max_len = 867

        # Protein 3-mer vocabulary: keep words occurring more often than average.
        with open('./data/protein_word_freq_3.pickle', 'rb') as fr:
            words = pickle.load(fr)
        words = words[words["freq"] > words.freq.mean()].seq.values
        self.prot_words = {word: i + 1 for i, word in enumerate(words)}

    def simulation(self):
        state = self.initial_state
        tokenized_target = self.prot_tokenizer().to(self.device)
        self.model.eval()
        out_p = None

        for i in range(self.steps):
            if i % 10 == 0:
                print(f'Running simulation at step {i}')

            tokenized_state = self.apta_tokenizer(state).to(self.device)
            self.temp = self.temp_scheduler(i)

            # Propose a single-base mutation and score both sequences.
            neighbor = self.mutate(state)
            tokenized_neighbor = self.apta_tokenizer(neighbor).to(self.device)
            with torch.no_grad():
                d1 = self.model(tokenized_state, tokenized_target)
                d2 = self.model(tokenized_neighbor, tokenized_target)

            if d2 > d1:
                # Better neighbor: always accept.
                state = neighbor
                out_p = d2
            else:
                # Worse neighbor: accept with Metropolis probability exp(-(d1 - d2) / T).
                p = torch.exp(-(d1 - d2) / self.temp)
                if torch.rand(size=(1,)) < torch.squeeze(p):
                    state = neighbor
                    out_p = d2
                else:
                    out_p = d1

        return state, out_p

    def prot_tokenizer(self):
        _, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # ESM-2 encoder alphabet
        bc = esm_alphabet.get_batch_converter()
        _, _, prot_tokens = bc([("target", self.target)])
        prot_tokenized = prot_tokens.to(torch.int64)

        # Pad to the maximum protein sequence length used during model training.
        prot_ex = torch.full((prot_tokenized.shape[0], 1678),
                             esm_alphabet.padding_idx, dtype=torch.int64)
        prot_ex[:, :prot_tokenized.shape[1]] = prot_tokenized
        return prot_ex.to(self.device)

    def apta_tokenizer(self, aptamer):
        return torch.tensor(rna2vec(np.array([aptamer])), dtype=torch.int64)

    def temp_scheduler(self, t):
        # Geometric cooling schedule: T_t = T_0 * alpha^t.
        return self.init_temp * self.alpha ** t

    def mutate(self, state):
        # Flip one uniformly chosen position to a different random base.
        base_ind = np.random.choice(len(state))
        base = state[base_ind]
        cands = ['U', 'A', 'C', 'G']
        choice = base
        while choice == base:
            choice = cands[np.random.choice(len(cands))]
        state = list(state)
        state[base_ind] = choice
        return "".join(state)

    def random_state_generator(self):
        # Uniformly random RNA sequence of the requested length.
        cands = ['U', 'A', 'C', 'G']
        return "".join(random.choices(cands, k=self.length))
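

# -----------------------------------------------------------------------------
# Minimal usage sketch (assumptions, not part of the original module): the
# scoring model is assumed to take (aptamer_tokens, protein_tokens) and return
# a scalar interaction score to be maximized. `DummyScorer`, the example target
# sequence, and all hyperparameter values below are placeholders; swap in the
# repository's trained aptamer-protein interaction model for a real run. Note
# that simulation() loads the ESM-2 650M checkpoint to tokenize the target.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class DummyScorer(nn.Module):
        """Placeholder scorer: ignores its inputs and returns a random score."""

        def forward(self, apta_tokens, prot_tokens):
            return torch.rand(())

    annealer = SimAnnealer(
        temperature=10.0,     # initial temperature T_0 (illustrative)
        model=DummyScorer(),  # replace with the trained interaction model
        steps=100,            # number of annealing iterations
        target="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # arbitrary example protein
        length=30,            # aptamer length to optimize
    )
    best_state, best_score = annealer.simulation()
    print("best aptamer:", best_state)
    print("score:", float(best_score))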