import torch
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForMaskedLM

# fetch suitable ESM models from the HuggingFace Hub, most recently updated first
MODELS = [
    m.id
    for m in HfApi().list_models(
        author="facebook",
        model_name="esm",
        task="fill-mask",
        sort="last_modified",
        direction=-1,
    )
]
if not MODELS:
    raise RuntimeError("no fill-mask ESM models found on the HuggingFace Hub")

class Model:
    """Wrapper around an ESM masked language model and its tokenizer"""
    def __init__(self, model_name: str = ""):
        "load the selected model and tokenizer"
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.model.eval()  # scoring only: make sure dropout is disabled
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        "run the model on a batch of tokens: `logits = model >> tokens`"
        return self.model(batch_tokens)["logits"]

    def __lshift__(self, input: str) -> torch.Tensor:
        "tokenize an input string on the model's device: `tokens = model << seq`"
        return self.batch_converter(input, return_tensors="pt")["input_ids"].to(self.model.device)

    def __getitem__(self, key: str) -> int:
        "get the token ID for a character: `model['A']`"
        return self.alphabet[key]

    def run_model(self, data):
        "score every substitution in `data.sub` against the wild-type sequence"
        def label_row(row, token_probs):
            "score one substitution string such as 'A123G'"
            # wild-type residue, 0-based position, mutant residue
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # offset by 1 to skip the BOS token the tokenizer prepends
            score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            return score.item()

        batch_tokens = self << data.seq

        # run the model with the selected scoring strategy
        # (both strategies are described in the original ESM paper)
        if data.scoring_strategy.startswith("wt-marginals"):
            # a single forward pass over the unmasked wild-type sequence
            with torch.no_grad():
                token_probs = torch.log_softmax(self >> batch_tokens, dim=-1)
        elif data.scoring_strategy.startswith("masked-marginals"):
            # one forward pass per position, masking that position each time
            all_token_probs = []
            for i in range(batch_tokens.size(1)):
                batch_tokens_masked = batch_tokens.clone()
                batch_tokens_masked[0, i] = self['<mask>']
                with torch.no_grad():
                    token_probs = torch.log_softmax(self >> batch_tokens_masked, dim=-1)
                all_token_probs.append(token_probs[:, i])  # keep only position i
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
        else:
            raise ValueError(f"unknown scoring strategy: {data.scoring_strategy}")
        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(row['0'], token_probs), axis=1
        )
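
# --- example usage -------------------------------------------------------
# A minimal sketch of how the wrapper is meant to be driven; the `data`
# container below is hypothetical (not part of this module) and merely
# bundles the attributes `run_model` expects: `seq` (wild-type sequence),
# `sub` (a DataFrame whose column '0' holds substitutions such as 'K2A'),
# `scoring_strategy`, and `out` (a DataFrame collecting one score column
# per model).
if __name__ == "__main__":
    from types import SimpleNamespace
    import pandas as pd

    data = SimpleNamespace(
        seq="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # toy wild-type sequence
        sub=pd.DataFrame({'0': ["K2A", "T3W"]}),   # substitutions to score
        scoring_strategy="wt-marginals",
        out=pd.DataFrame(),
    )
    model = Model(MODELS[0])  # most recently updated facebook ESM model
    model.run_model(data)
    print(data.out)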