import gradio as gr
from huggingface_hub import HfApi, ModelFilter
import pandas as pd
from re import fullmatch
from tempfile import NamedTemporaryFile
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

    # fetch suitable ESM models from the HuggingFace Hub, newest first
MODELS = [m.modelId for m in HfApi().list_models(
    filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"),
    sort="lastModified", direction=-1)]
if not MODELS:
    raise RuntimeError("Error while retrieving models from HuggingFace Hub")

    # scoring strategies
SCORING = ["masked-marginals (more accurate)", "wt-marginals (faster)"]
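    # wt-marginals scores every substitution from a single forward pass on the
    # wild-type sequence; masked-marginals masks each position in turn (one forward
    # pass per position), which the ESM-1v paper reports as more accurate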

class Model:
    """Wrapper for ESM models"""
    def __init__(self, model_name:str=""):
        "load selected model and tokenizer"
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens:torch.Tensor) -> torch.Tensor:
        "run model on batch of tokens"
        return self.model(batch_tokens)["logits"]
    
    def __lshift__(self, input:str) -> torch.Tensor:
        "convert input string to batch of tokens"
        return self.batch_converter(input, return_tensors="pt")["input_ids"]

    def __getitem__(self, key:str) -> int:
        "get token ID from character"
        return self.alphabet[key]
    
    def run_model(self, data):
        "run model on data"
        def label_row(row, token_probs):
            "label row with score"
            wt, idx, mt = row[0], int(row[1:-1])-1, row[-1]
            score = token_probs[0, 1+idx, self[mt]] - token_probs[0, 1+idx, self[wt]]
            return score.item()
        
        batch_tokens = self<<data.seq

            # run model with the selected scoring strategy (both are described in the ESM-1v paper)
        if data.scoring_strategy.startswith("wt-marginals"):
            with torch.no_grad():
                token_probs = torch.log_softmax(self>>batch_tokens, dim=-1)
        elif data.scoring_strategy.startswith("masked-marginals"):
                # mask each position in turn and collect its log-probabilities
            all_token_probs = []
            for i in range(batch_tokens.size(1)):
                batch_tokens_masked = batch_tokens.clone()
                batch_tokens_masked[0, i] = self['<mask>']
                with torch.no_grad():
                    token_probs = torch.log_softmax(
                        self>>batch_tokens_masked, dim=-1
                    )
                all_token_probs.append(token_probs[:, i])
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
            # score all substitutions against the computed probabilities
        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(
                row['0'],
                token_probs,
            ),
            axis=1,
        )
    
class Data:
    """Container for input and output data"""
        # empty model cached as a class attribute so it is reloaded only when the selection changes
    model = Model()

    def parse_seq(self, src:str):
        "parse and validate input sequence"
        self.seq = src.strip().upper()
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def parse_sub(self, trg:str):
        "parse input substitutions"
        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper()

            # identify running mode
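            # accepted formats, e.g. for a sequence starting with "MVE...":
            #   SVS: a full variant sequence of the same length ("MAE..." scores V2A)
            #   DMS: whitespace-separated positions ("2 3": all 19 substitutions at each)
            #   MUT: explicit substitutions in X#Y notation ("V2A E3K")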
        if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq):    # if single string of same length as sequence, seq vs seq mode
            self.mode = 'SVS'
            for resi,(src,trg) in enumerate(zip(self.seq, self.trg), 1):
                if src != trg:
                    self.sub.append(f"{src}{resi}{trg}")
        else:
            self.trg = self.trg.split()
            if all(fullmatch(r'\d+', x) for x in self.trg):                             # if all strings are numbers, deep mutational scanning mode
                self.mode = 'DMS'
                for resi in map(int, self.trg):
                    src = self.seq[resi-1]
                    for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src,''):
                        self.sub.append(f"{src}{resi}{trg}")
            elif all(fullmatch(r'[A-Z]\d+[A-Z]', x) for x in self.trg):                 # if all strings are of the form X#Y, single substitution mode
                self.mode = 'MUT'
                self.sub = self.trg
            else:
                raise RuntimeError("Unrecognised running mode; wrong inputs?")
        
        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def __init__(self, src:str, trg:str, model_name:str, scoring_strategy:str, out_file):
        "initialise data"
        self.model_name = model_name
            # reload the shared model only if a different one was selected
        if self.model.model_name != model_name:
            Data.model = Model(model_name)
        self.parse_seq(src)
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_buffer = out_file.name

    def parse_output(self) -> str:
        "format output data for visualisation"
        if self.mode == 'MUT':      # if single substitution mode, sort by score
            self.out = self.out.sort_values(self.model_name, ascending=False)
        elif self.mode == 'DMS':    # if deep mutational scanning mode, sort by residue and score
            self.out = pd.concat([(self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))      # FIXME: breaks if the input sequence contains wildcard characters
                                    .sort_values(['resi', self.model_name], ascending=[True,False])
                                    .groupby(['resi'])
                                    .head(19)
                                    .drop(['resi'], axis=1)).iloc[19*x:19*(x+1)]
                                                            .reset_index(drop=True) for x in range(self.out.shape[0]//19)]
                                , axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
            # save to temporary file to be downloaded
        self.out.round(2).to_csv(self.out_buffer, index=False)
        return (self.out.style
                        .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                        .hide(axis=0)
                        .hide(axis=1)
                        .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
                        .to_html())
    
    def calculate(self):
        "run model and parse output"
        self.model.run_model(self)
        return self, self.parse_output()

def app(*argv):
    "run app"
    seq, trg, model_name, scoring_strategy, out_file, *_ = argv
    data, html = Data(seq, trg, model_name, scoring_strategy, out_file).calculate()
    return html, gr.File.update(value=out_file.name, visible=True)

with gr.Blocks() as demo, NamedTemporaryFile(mode='w+', prefix='out_', suffix='.csv') as out_file, open("instructions.md", "r") as md:
    gr.Markdown(md.read())
    seq = gr.Textbox(lines=2, label="Sequence", placeholder="Sequence here...", value='MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ')
    trg = gr.Textbox(lines=1, label="Substitutions", placeholder="Substitutions here...", value="61 214 19 30 122 140")
    model_name = gr.Dropdown(MODELS, label="Model", value=MODELS[1])
    scoring_strategy = gr.Dropdown(SCORING, label="Scoring strategy", value=SCORING[1])
    btn = gr.Button(value="Submit")
    out = gr.HTML()
    bto = gr.File(value=out_file.name, visible=False, label="Download", file_count='single', interactive=False)
    btn.click(fn=app, inputs=[seq, trg, model_name, scoring_strategy, bto], outputs=[out, bto]) 

    demo.launch()   # e.g. demo.launch(share=True, server_name="0.0.0.0", server_port=7878)
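
        # minimal programmatic usage sketch, bypassing the UI (the hypothetical I61V
        # substitution refers to the default sequence above, where position 61 is I):
        # with NamedTemporaryFile(mode='w+', suffix='.csv') as f:
        #     data, html = Data(seq.value, "I61V", MODELS[0], SCORING[1], f).calculate()
        #     print(data.out)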