File size: 3,166 Bytes
29134bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import esm
import numpy as np
from utils import tokenize_sequences, rna2vec
import torch
import pickle
import random

class SimAnnealer:
    """Simulated-annealing search over RNA aptamer sequences.

    Starting from a random sequence of ``length`` bases, the annealer
    repeatedly proposes single-base mutations and scores each candidate with
    ``model(aptamer_tokens, target_tokens)``.  Higher scores are treated as
    better: improving moves are always accepted, worsening moves are accepted
    with the Metropolis probability ``exp((new - old) / temperature)``.
    """

    # Alphabet used for random initialisation and for mutations.
    BASES = ('U', 'A', 'C', 'G')
    # Width the protein token tensor is padded to (max protein sequence
    # length during model training, per the original inline comment).
    PROT_PAD_LEN = 1678

    def __init__(self, temperature, model, steps, target, length):
        """Set up the annealer.

        Args:
            temperature: Starting temperature of the geometric schedule.
            model: Callable scoring model; must support ``.to(device)`` and
                ``.eval()`` and be callable as ``model(aptamer, target)``.
            steps: Number of annealing iterations run by ``simulation``.
            target: Target protein sequence handed to the ESM tokenizer.
            length: Number of bases in the generated aptamer.
        """
        self.temp = temperature
        # The schedule decays from this fixed value; ``self.temp`` only
        # mirrors the temperature of the step currently being executed.
        self.initial_temp = temperature
        self.device = 'cpu'
        self.model = model.to(self.device)
        self.length = length
        self.initial_state = self.random_state_generator()
        self.alpha = 0.95
        self.steps = steps
        self.target = target
        self.n_prot_vocabs = 1 + 713 + 1 # pad + voca + msk
        self.n_prot_target_vocabs = 1 + 584 # pad + voca
        self.prot_max_len = 867
        # NOTE(review): pickle executes arbitrary code on load; this file must
        # come from a trusted source.
        with open('./data/protein_word_freq_3.pickle', 'rb') as fr:
            words = pickle.load(fr)
            # Presumably a pandas DataFrame with "freq" and "seq" columns —
            # keep only the words whose frequency is above the mean.
            words = words[words["freq"]>words.freq.mean()].seq.values
            self.prot_words = {word:i+1 for i, word in enumerate(words)}

    def simulation(self):
        """Run the annealing loop.

        Returns:
            A ``(state, score)`` tuple: the final aptamer sequence and the
            model output for exactly that sequence.  With ``steps == 0`` the
            initial state and its score are returned.
        """
        tokenized_target = self.prot_tokenizer().to(self.device)
        self.model.eval()

        state = self.initial_state
        # Score the starting point once; afterwards only neighbors need to be
        # evaluated because the score of the current state is carried along.
        with torch.no_grad():
            score = self.model(
                self.apta_tokenizer(state).to(self.device), tokenized_target
            )

        for i in range(self.steps):
            if i % 10 == 0:
                print(f'Running simulation at step {i}')
            self.temp = self.temp_scheduler(i)
            neighbor = self.mutate(state)
            tokenized_neighbor = self.apta_tokenizer(neighbor).to(self.device)
            with torch.no_grad():
                neighbor_score = self.model(tokenized_neighbor, tokenized_target)
            if self._accept(score, neighbor_score):
                state, score = neighbor, neighbor_score
        return state, score

    def _accept(self, current_score, candidate_score):
        """Metropolis acceptance test for a maximisation problem."""
        if candidate_score > current_score:
            return True
        if self.temp <= 0:
            # Fully cooled (or underflowed) schedule: behave greedily instead
            # of dividing by zero.
            return False
        # The difference is <= 0 here, so the probability lies in (0, 1].
        accept_prob = torch.exp((candidate_score - current_score) / self.temp)
        return bool(torch.rand(size=(1,)) < torch.squeeze(accept_prob))

    def prot_tokenizer(self):
        """Tokenize ``self.target`` with the ESM-2 alphabet and pad it.

        Returns:
            An int64 tensor with ``PROT_PAD_LEN`` columns, padded with the
            alphabet's padding index, placed on ``self.device``.

        Raises:
            ValueError: If the tokenized target is longer than
                ``PROT_PAD_LEN``.
        """
        # NOTE(review): this loads the full pretrained model although only
        # its alphabet is used here.
        _, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() # ESM-2 Encoder
        bc = esm_alphabet.get_batch_converter()
        _, _, prot_tokens = bc([(1, self.target)])
        prot_tokenized = torch.as_tensor(prot_tokens, dtype=torch.int64)
        n_seqs, n_tokens = prot_tokenized.shape
        if n_tokens > self.PROT_PAD_LEN:
            raise ValueError(
                f'tokenized target has {n_tokens} tokens, '
                f'more than the supported {self.PROT_PAD_LEN}'
            )
        # adjusting for max protein sequence length during model training
        prot_ex = torch.full(
            (n_seqs, self.PROT_PAD_LEN), esm_alphabet.padding_idx, dtype=torch.int64
        )
        prot_ex[:, :n_tokens] = prot_tokenized
        return prot_ex.to(self.device)

    def apta_tokenizer(self, aptamer):
        """Encode one aptamer sequence with ``rna2vec`` as an int64 tensor."""
        return torch.tensor(rna2vec(np.array([aptamer])), dtype=torch.int64)

    def temp_scheduler(self, t):
        """Return the temperature at step ``t``: ``initial_temp * alpha**t``.

        The decay is anchored at the initial temperature so that feeding the
        result back into ``self.temp`` does not compound the cooling.
        """
        return self.initial_temp * self.alpha ** t

    def mutate(self, state):
        """Return a copy of ``state`` with one random base substituted.

        The position is drawn uniformly over ``state`` and the replacement is
        drawn uniformly from the bases that differ from the current one.

        Raises:
            ValueError: If ``state`` is empty.
        """
        if len(state) == 0:
            raise ValueError('cannot mutate an empty sequence')
        position = np.random.choice(len(state))
        alternatives = [base for base in self.BASES if base != state[position]]
        mutated = list(state)
        mutated[position] = alternatives[np.random.choice(len(alternatives))]
        return "".join(mutated)

    def random_state_generator(self):
        """Return a random sequence of ``self.length`` bases as a string."""
        # A string keeps the state type identical to what ``mutate`` returns.
        return "".join(random.choices(self.BASES, k=self.length))