# aptlm / annealer.py
# abwer
# Initial commit
# 29134bd
# raw
# history blame contribute delete
# No virus
# 3.17 kB
import esm
import numpy as np
from utils import tokenize_sequences, rna2vec
import torch
import pickle
import random
class SimAnnealer:
    """Simulated-annealing search over RNA aptamer sequences.

    Starting from a random sequence of ``length`` bases, each step proposes a
    single-base substitution and scores it with ``model`` against the target
    protein. Higher scores are treated as better: an improving neighbour is
    always accepted, a non-improving one is accepted with the Metropolis
    probability ``exp((new - old) / temperature)``.
    """

    # Nucleotide alphabet used for both random initialisation and mutation.
    BASES = ('U', 'A', 'C', 'G')
    # Width the protein token tensor is padded to before it is fed to the
    # model (per the original comment: the max protein sequence length used
    # during model training).
    PROT_PAD_LEN = 1678

    def __init__(self, temperature, model, steps, target, length):
        """Set up the annealer and load the protein word vocabulary.

        Args:
            temperature: Starting temperature of the cooling schedule.
            model: Scoring module; called as ``model(aptamer_tokens,
                protein_tokens)`` and moved to ``self.device``.
            steps: Number of annealing iterations run by ``simulation``.
            target: Target protein sequence handed to the ESM batch converter.
            length: Number of bases in the aptamer being optimised.
        """
        # ``temp`` is the current (cooled) temperature; ``initial_temp`` keeps
        # the starting value so the schedule is always computed from it.
        self.initial_temp = temperature
        self.temp = temperature
        self.device = 'cpu'
        self.model = model.to(self.device)
        self.length = length
        self.initial_state = self.random_state_generator()
        self.alpha = 0.95
        self.steps = steps
        self.target = target
        self.n_prot_vocabs = 1 + 713 + 1  # pad + voca + msk
        self.n_prot_target_vocabs = 1 + 584  # pad + voca
        self.prot_max_len = 867
        # NOTE(security): pickle.load can execute arbitrary code; this file
        # must come from a trusted source.
        with open('./data/protein_word_freq_3.pickle', 'rb') as fr:
            words = pickle.load(fr)
        # Presumably a pandas DataFrame with "freq" and "seq" columns (that is
        # how it is indexed here); keep only words more frequent than average.
        words = words[words["freq"] > words.freq.mean()].seq.values
        # Word -> 1-based index (0 is left free, presumably for padding).
        self.prot_words = {word: i + 1 for i, word in enumerate(words)}

    def simulation(self):
        """Run the annealing loop.

        Returns:
            tuple: ``(state, score)`` where ``state`` is the final aptamer
            sequence and ``score`` is the model output for that sequence.
        """
        state = self.initial_state
        tokenized_target = self.prot_tokenizer().to(self.device)
        self.model.eval()

        # Score the starting state once; afterwards the current score is
        # carried along instead of being recomputed every iteration (the model
        # is in eval mode under no_grad, so the value would be the same).
        with torch.no_grad():
            score = self.model(
                self.apta_tokenizer(state).to(self.device), tokenized_target
            )

        for i in range(self.steps):
            if i % 10 == 0:
                print(f'Running simulation at step {i}')
            self.temp = self.temp_scheduler(i)
            neighbor = self.mutate(state)
            tokenized_neighbor = self.apta_tokenizer(neighbor).to(self.device)
            with torch.no_grad():
                neighbor_score = self.model(tokenized_neighbor, tokenized_target)
            if self._accept(score, neighbor_score):
                state, score = neighbor, neighbor_score
        return state, score

    def _accept(self, current_score, neighbor_score):
        """Metropolis acceptance test for a maximisation objective."""
        if neighbor_score > current_score:
            return True
        if self.temp <= 0:
            # Fully cooled: behave greedily and avoid dividing by zero.
            return False
        # neighbor_score <= current_score here, so the exponent is <= 0 and
        # the probability lies in (0, 1].
        p = torch.exp((neighbor_score - current_score) / self.temp)
        return bool(torch.rand(size=(1,)) < torch.squeeze(p))

    def prot_tokenizer(self):
        """Tokenize ``self.target`` with the ESM-2 alphabet and pad it.

        Returns:
            torch.Tensor: int64 tensor with ``PROT_PAD_LEN`` columns on
            ``self.device``, filled with the alphabet's padding index beyond
            the tokenized sequence.
        """
        # The pretrained loader is used only for its alphabet; the model
        # weights it returns are discarded.
        _, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # ESM-2 Encoder
        batch_converter = esm_alphabet.get_batch_converter()
        _, _, prot_tokens = batch_converter([(1, self.target)])
        # Already a tensor: cast instead of re-wrapping with torch.tensor().
        prot_tokens = prot_tokens.to(dtype=torch.int64)
        # adjusting for max protein sequence length during model training
        padded = torch.full(
            (prot_tokens.shape[0], self.PROT_PAD_LEN),
            esm_alphabet.padding_idx,
            dtype=torch.int64,
        )
        padded[:, :prot_tokens.shape[1]] = prot_tokens
        return padded.to(self.device)

    def apta_tokenizer(self, aptamer):
        """Encode one aptamer sequence with ``rna2vec`` as an int64 tensor."""
        return torch.tensor(rna2vec(np.array([aptamer])), dtype=torch.int64)

    def temp_scheduler(self, t):
        """Return the temperature at step ``t``: ``initial_temp * alpha**t``.

        Computed from the starting temperature rather than the current one,
        so repeatedly storing the result in ``self.temp`` does not compound
        the decay.
        """
        return self.initial_temp * self.alpha ** t

    def mutate(self, state):
        """Return a copy of ``state`` with one randomly chosen base replaced.

        The replacement is drawn uniformly from the bases that differ from the
        one currently at the chosen position.

        Raises:
            ValueError: If ``state`` is empty.
        """
        bases = list(state)
        if not bases:
            raise ValueError("cannot mutate an empty sequence")
        base_ind = np.random.choice(len(bases))
        alternatives = [b for b in self.BASES if b != bases[base_ind]]
        bases[base_ind] = alternatives[np.random.choice(len(alternatives))]
        return "".join(bases)

    def random_state_generator(self):
        """Return a random sequence of ``self.length`` bases as a string."""
        # Joined into a str so the initial state has the same type as the
        # states produced by ``mutate``.
        return "".join(random.choices(self.BASES, k=self.length))