import gradio as gr
from huggingface_hub import HfApi, ModelFilter
import pandas as pd
from re import fullmatch
from tempfile import NamedTemporaryFile
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# fetch suitable ESM models from the Hugging Face Hub, most recently modified first
MODELS = [m.modelId for m in HfApi().list_models(
    filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"),
    sort="lastModified",
    direction=-1,
)]
if not MODELS:
    raise RuntimeError("Error while retrieving models from the Hugging Face Hub")

# scoring strategies
SCORING = ["masked-marginals (more accurate)", "wt-marginals (faster)"]


class Model:
    """Wrapper for ESM models"""

    def __init__(self, model_name: str = ""):
        "load the selected model and tokenizer"
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        "run the model on a batch of tokens"
        return self.model(batch_tokens)["logits"]

    def __lshift__(self, input: str) -> torch.Tensor:
        "convert an input string to a batch of tokens"
        return self.batch_converter(input, return_tensors="pt")["input_ids"]

    def __getitem__(self, key: str) -> int:
        "get the token ID of a character"
        return self.alphabet[key]
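
    # usage sketch for the operator shorthand above (hypothetical model name):
    #   model = Model("facebook/esm2_t6_8M_UR50D")
    #   tokens = model << "MVEQ"    # tokenise a sequence
    #   logits = model >> tokens    # run the model, returning logits
    #   aa_id  = model['A']         # token ID for alanine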

    def run_model(self, data):
        "score every requested substitution with the selected scoring strategy"

        def label_row(row, token_probs):
            "score one substitution as the log-probability ratio of mutant over wild type"
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # the +1 offset skips the BOS token prepended by the tokeniser
            score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            return score.item()
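
        # worked example: for the substitution "E2A", wt='E', idx=1, mt='A', and the
        # score is log P(position 2 = A) - log P(position 2 = E); positive scores mean
        # the model prefers the mutant residue over the wild type at that position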

        batch_tokens = self << data.seq
        # run the model with the selected scoring strategy (both are described in the original ESM paper)
        if data.scoring_strategy.startswith("wt-marginals"):
            # a single forward pass on the wild-type sequence scores all substitutions
            with torch.no_grad():
                token_probs = torch.log_softmax(self >> batch_tokens, dim=-1)
            data.out[self.model_name] = data.sub.apply(
                lambda row: label_row(
                    row['0'],
                    token_probs,
                ),
                axis=1,
            )
elif data.scoring_strategy.startswith("masked-marginals"):
all_token_probs = []
for i in range(batch_tokens.size()[1]):
batch_tokens_masked = batch_tokens.clone()
batch_tokens_masked[0, i] = self['<mask>']
with torch.no_grad():
token_probs = torch.log_softmax(
self>>batch_tokens_masked, dim=-1
)
all_token_probs.append(token_probs[:, i])
token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
data.out[self.model_name] = data.sub.apply(
lambda row: label_row(
row['0'],
token_probs,
),
axis=1,
)
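
        # note: masked-marginals runs one forward pass per token (sequence length plus
        # special tokens) versus a single pass for wt-marginals, hence the
        # "(more accurate)" / "(faster)" labels in SCORING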


class Data:
    """Container for input and output data"""

    # initialise an empty model as a class attribute, shared across instances,
    # so that a loaded model is reused between requests
    model = Model()

    def parse_seq(self, src: str):
        "parse and validate the input sequence"
        self.seq = src.strip().upper()
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def parse_sub(self, trg: str):
        "parse the input substitutions and identify the running mode"
        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper()
        if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq):
            # a single string of the same length as the sequence: sequence-vs-sequence mode
            self.mode = 'SVS'
            for resi, (src, trg) in enumerate(zip(self.seq, self.trg), 1):
                if src != trg:
                    self.sub.append(f"{src}{resi}{trg}")
        else:
            self.trg = self.trg.split()
            if all(fullmatch(r'\d+', x) for x in self.trg):
                # all strings are numbers: deep mutational scanning mode
                self.mode = 'DMS'
                for resi in map(int, self.trg):
                    src = self.seq[resi - 1]
                    # scan each listed residue against the 19 other standard amino acids
                    for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
                        self.sub.append(f"{src}{resi}{trg}")
            elif all(fullmatch(r'[A-Z]\d+[A-Z]', x) for x in self.trg):
                # all strings are of the form X#Y: single-substitution mode
                self.mode = 'MUT'
                self.sub = self.trg
            else:
                raise RuntimeError("Unrecognised running mode; wrong inputs?")
        self.sub = pd.DataFrame(self.sub, columns=['0'])
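
    # example inputs for the three modes, matching the parsing above (hypothetical values):
    #   SVS: a full-length variant of the input sequence, differing at the mutated residues
    #   DMS: space-separated residue numbers, e.g. "61 214 19"
    #   MUT: space-separated substitutions in X#Y form, e.g. "E61A Q214C"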

    def __init__(self, src: str, trg: str, model_name: str, scoring_strategy: str, out_file):
        "initialise the data"
        self.model_name = model_name
        # if the model has changed, load the new one into the shared class attribute
        if self.model.model_name != model_name:
            Data.model = Model(model_name)
        self.parse_seq(src)
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_buffer = out_file.name

    def parse_output(self) -> str:
        "format the output data for visualisation"
        if self.mode == 'MUT':
            # single-substitution mode: sort by score
            self.out = self.out.sort_values(self.model_name, ascending=False)
        elif self.mode == 'DMS':
            # deep mutational scanning mode: sort by residue and score, then lay the
            # 19 substitutions of each residue out side by side, one pair of columns per residue
            # FIX: this doesn't work if the input sequence contains wildcard characters
            self.out = pd.concat(
                [(self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                          .sort_values(['resi', self.model_name], ascending=[True, False])
                          .groupby(['resi'])
                          .head(19)
                          .drop(['resi'], axis=1)).iloc[19 * x:19 * (x + 1)]
                          .reset_index(drop=True) for x in range(self.out.shape[0] // 19)],
                axis=1).set_axis(range(self.out.shape[0] // 19 * 2), axis='columns')
        # save to a temporary file to be downloaded
        self.out.round(2).to_csv(self.out_buffer, index=False)
        return (self.out.style
                    .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                    .hide(axis=0)      # hide the index
                    .hide(axis=1)      # hide the column headers
                    .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
                    .to_html())

    def calculate(self):
        "run the model and parse the output"
        self.model.run_model(self)
        return self, self.parse_output()
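

# minimal usage sketch outside the web UI (hypothetical inputs, reusing MODELS and
# SCORING from above):
#   with NamedTemporaryFile(mode='w+', suffix='.csv') as f:
#       data, html = Data("MVEQ", "E3A", MODELS[0], SCORING[1], f).calculate()
#   # html now holds the styled score table; f.name points at the CSV copy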


def app(*argv):
    "run one prediction and return the HTML table plus the downloadable CSV"
    seq, trg, model_name, scoring_strategy, out_file, *_ = argv
    data, html = Data(seq, trg, model_name, scoring_strategy, out_file).calculate()
    return html, gr.File.update(value=out_file.name, visible=True)
    # commented-out stub for testing the UI without loading a model:
    # df = pd.DataFrame((pd.np.random.random((10, 5)) - 0.5) * 10, columns=list('ABCDE'))
    # df.to_csv(out_file.name, index=False)
    # return df.to_html(), gr.File.update(value=out_file.name, visible=True)


with gr.Blocks() as demo, \
     NamedTemporaryFile(mode='w+', prefix='out_', suffix='.csv') as out_file, \
     open("instructions.md", "r") as md:
    gr.Markdown(md.read())
    seq = gr.Textbox(lines=2, label="Sequence", placeholder="Sequence here...", value='MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ')
    trg = gr.Textbox(lines=1, label="Substitutions", placeholder="Substitutions here...", value="61 214 19 30 122 140")
    model_name = gr.Dropdown(MODELS, label="Model", value=MODELS[1])
    scoring_strategy = gr.Dropdown(SCORING, label="Scoring strategy", value=SCORING[1])
    btn = gr.Button(value="Submit")
    out = gr.HTML()
    bto = gr.File(value=out_file.name, visible=False, label="Download", file_count='single', interactive=False)
    btn.click(fn=app, inputs=[seq, trg, model_name, scoring_strategy, bto], outputs=[out, bto])
    # the Spaces runtime launches `demo` automatically; uncomment to serve on a custom host/port
    # demo.launch(share=True, server_name="0.0.0.0", server_port=7878)