zsp / app.py
MassimoGregorioTotaro
general reorganisation
2dd6312
raw
history blame
8.76 kB
import gradio as gr
from huggingface_hub import HfApi, ModelFilter
import pandas as pd
from re import match
from tempfile import NamedTemporaryFile
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Query the HuggingFace Hub for facebook's ESM fill-mask models, most recently
# modified first; their IDs populate the "Model" dropdown of the interface.
MODELS = [
    info.modelId
    for info in HfApi().list_models(
        filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"),
        sort="lastModified",
        direction=-1,
    )
]
# Abort at start-up when the Hub returned nothing usable.
if not any(MODELS):
    raise RuntimeError("Error while retrieving models from HuggingFace Hub")

# Scoring strategies offered in the UI; the scoring code only looks at the
# leading word of the selected entry (via str.startswith).
SCORING = ["masked-marginals (more accurate)", "wt-marginals (faster)"]
class Model:
    """Wrapper for ESM masked-language models loaded from the HuggingFace Hub.

    Operator shortcuts used by the rest of the app:
      * ``model << "SEQ"`` tokenises a sequence into a tensor of token IDs;
      * ``model >> tokens`` runs the network and returns its logits;
      * ``model["A"]`` returns the vocabulary ID of a token.
    """

    def __init__(self, model_name: str = ""):
        """Load the model and tokenizer named `model_name`.

        With an empty name nothing is loaded and only `model_name` is set;
        this is used as a placeholder before any model has been selected.
        The weights are moved to the GPU when CUDA is available.
        """
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Run the model on `batch_tokens` and return its logits."""
        # The tokenizer builds tensors on the CPU while the weights may live on
        # the GPU (see __init__): move the tokens next to the weights, otherwise
        # the forward pass fails with a device mismatch.
        return self.model(batch_tokens.to(self.model.device))["logits"]

    def __lshift__(self, input: str) -> torch.Tensor:
        """Tokenise the string `input` and return its token IDs as a tensor."""
        return self.batch_converter(input, return_tensors="pt")["input_ids"]

    def __getitem__(self, key: str) -> int:
        """Return the vocabulary ID of token `key` (raises KeyError if unknown)."""
        return self.alphabet[key]

    def run_model(self, data):
        """Score every substitution of `data` and store the scores in `data.out`.

        Reads `data.seq` (sequence), `data.scoring_strategy` (an entry starting
        with "wt-marginals" or "masked-marginals") and `data.sub` (DataFrame
        whose column '0' holds substitutions such as "A61C"). Writes one score
        per substitution into column `self.model_name` of `data.out`: the
        log-probability of the mutant residue minus that of the wild-type one
        at that position (strategies as described in the original ESM paper).

        Raises:
            RuntimeError: if the scoring strategy is not recognised.
        """
        batch_tokens = self << data.seq
        if data.scoring_strategy.startswith("wt-marginals"):
            token_probs = self._wt_marginals(batch_tokens)
        elif data.scoring_strategy.startswith("masked-marginals"):
            token_probs = self._masked_marginals(batch_tokens)
        else:
            raise RuntimeError(f"Unrecognised scoring strategy: {data.scoring_strategy!r}")
        data.out[self.model_name] = data.sub['0'].map(
            lambda sub: self._score(sub, token_probs)
        )

    def _wt_marginals(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Log-probabilities from a single pass over the unmasked sequence."""
        with torch.no_grad():
            return torch.log_softmax(self >> batch_tokens, dim=-1)

    def _masked_marginals(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Log-probabilities at each position, computed with that position masked.

        Runs one forward pass per token (special tokens included) and keeps
        only the distribution at the masked position.
        """
        mask_id = self['<mask>']
        per_position = []
        with torch.no_grad():
            for i in range(batch_tokens.size(1)):
                masked = batch_tokens.clone()
                masked[0, i] = mask_id
                token_probs = torch.log_softmax(self >> masked, dim=-1)
                per_position.append(token_probs[:, i])
        # (positions, vocab) -> (1, positions, vocab), the same layout as _wt_marginals
        return torch.cat(per_position, dim=0).unsqueeze(0)

    def _score(self, sub: str, token_probs: torch.Tensor) -> float:
        """Score substitution `sub` ("<wt><1-based position><mt>") as a log-odds."""
        wt, idx, mt = sub[0], int(sub[1:-1]) - 1, sub[-1]
        # +1 skips the special token the tokenizer presumably prepends (e.g. <cls>)
        score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
        return score.item()
class Data:
    """Container for the inputs, parsed substitutions and scored output of one request."""

    # Model shared by all instances (class attribute) so weights are only
    # reloaded when a different model is requested; see __init__.
    model = Model()

    def parse_seq(self, src: str):
        """Normalise and validate the input sequence, storing it in `self.seq`.

        All whitespace is removed and letters are upper-cased.

        Raises:
            RuntimeError: if the sequence is empty or contains a character
                that is not in the model's vocabulary.
        """
        self.seq = "".join(src.split()).upper()
        if not self.seq:
            raise RuntimeError("Empty sequence")
        # Validate the normalised sequence (not the raw input), so lower-case
        # letters and line breaks from the text box are accepted.
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def _residue(self, resi: int) -> str:
        """Return the residue at 1-based position `resi` of `self.seq`.

        Raises:
            RuntimeError: if `resi` lies outside the sequence (position 0 would
                otherwise silently wrap around to the last residue).
        """
        if not 1 <= resi <= len(self.seq):
            raise RuntimeError(f"Residue {resi} is outside the sequence (1-{len(self.seq)})")
        return self.seq[resi - 1]

    def parse_sub(self, trg: str):
        """Parse the substitutions field into `self.sub` and set `self.mode`.

        Running modes:
          * 'SVS' - a single string as long as the sequence: every position
            where it differs from the sequence becomes a substitution;
          * 'DMS' - whitespace-separated 1-based residue numbers: each residue
            is replaced by every other standard amino acid (duplicates ignored);
          * 'MUT' - whitespace-separated substitutions of the form X#Y.

        Afterwards `self.trg` holds the normalised input (a string in SVS mode,
        a list of tokens otherwise) and `self.sub` is a one-column DataFrame
        (column '0') of substitutions such as "A61C".

        Raises:
            RuntimeError: on empty input, an unrecognised mode, a residue
                number outside the sequence, or a MUT wild-type residue that
                does not match the sequence.
        """
        from re import fullmatch

        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper()
        tokens = self.trg.split()
        if not tokens:
            raise RuntimeError("No substitutions given")
        if len(tokens) == 1 and len(tokens[0]) == len(self.seq):
            self.mode = 'SVS'
            self.sub = [
                f"{wt}{resi}{mt}"
                for resi, (wt, mt) in enumerate(zip(self.seq, tokens[0]), 1)
                if wt != mt
            ]
        else:
            self.trg = tokens
            # fullmatch, not match: "12A" or "A12CD" must not be accepted
            if all(fullmatch(r'\d+', x) for x in tokens):
                self.mode = 'DMS'
                for resi in dict.fromkeys(map(int, tokens)):  # de-duplicate, keep order
                    wt = self._residue(resi)
                    self.sub.extend(
                        f"{wt}{resi}{mt}" for mt in "ACDEFGHIKLMNPQRSTVWY" if mt != wt
                    )
            elif all(fullmatch(r'[A-Z]\d+[A-Z]', x) for x in tokens):
                self.mode = 'MUT'
                for x in tokens:
                    if self._residue(int(x[1:-1])) != x[0]:
                        raise RuntimeError(f"Wild-type residue of {x} does not match the sequence")
                self.sub = tokens
            else:
                raise RuntimeError("Unrecognised running mode; wrong inputs?")
        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def __init__(self, src: str, trg: str, model_name: str, scoring_strategy: str, out_file):
        """Load the requested model if needed and parse the request inputs.

        `out_file` only needs a `.name` attribute: the path the CSV is saved to.
        """
        self.model_name = model_name
        # Replace the shared model on the class: assigning to self.model would
        # only shadow it for this instance and force a reload on every request.
        if type(self).model.model_name != model_name:
            type(self).model = Model(model_name)
        self.parse_seq(src)
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_buffer = out_file.name

    def _dms_table(self) -> pd.DataFrame:
        """Lay out DMS scores side by side: one (substitution, score) column pair
        per residue, in residue order, each sorted from best to worst score.

        Groups by residue instead of assuming 19 substitutions per residue, so
        non-standard wild-type residues (20 substitutions) and uneven groups work;
        shorter groups are padded with NaN.
        """
        resi = self.out['0'].str.extract(r'(\d+)', expand=False).astype(int)
        ranked = (self.out.assign(resi=resi)
                  .sort_values(['resi', self.model_name], ascending=[True, False]))
        blocks = [grp.drop(columns='resi').reset_index(drop=True)
                  for _, grp in ranked.groupby('resi')]
        table = pd.concat(blocks, axis=1)
        return table.set_axis(range(table.shape[1]), axis='columns')

    def parse_output(self) -> str:
        """Arrange the scores for display, save them as CSV and return an HTML table."""
        if self.mode == 'MUT':  # single substitutions: best score first
            self.out = self.out.sort_values(self.model_name, ascending=False)
        elif self.mode == 'DMS':  # deep mutational scan: one column pair per residue
            self.out = self._dms_table()
        # save to temporary file to be downloaded
        self.out.round(2).to_csv(self.out_buffer, index=False)
        return (self.out.style
                .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x, na_rep='')
                .hide(axis=0)
                .hide(axis=1)
                .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
                .to_html())

    def calculate(self):
        """Score the substitutions with the shared model; return (self, HTML table)."""
        self.model.run_model(self)
        return self, self.parse_output()
def app(*argv):
    """Gradio callback: score the submitted substitutions.

    Positional inputs, in order: sequence, substitutions, model name, scoring
    strategy and the output-file component value (used only for its `.name`
    path); any further arguments are ignored.

    Returns the HTML results table and an update that reveals the CSV
    download component.
    """
    seq, trg, model_name, scoring_strategy, out_file, *_ = argv
    _, html = Data(seq, trg, model_name, scoring_strategy, out_file).calculate()
    return html, gr.File.update(value=out_file.name, visible=True)
# Build the interface. NOTE(review): NamedTemporaryFile deletes its file when
# this `with` exits, presumably before any request is served; the CSV is then
# re-created at the same path by `to_csv` when results are saved.
with gr.Blocks() as demo, NamedTemporaryFile(mode='w+', prefix='out_', suffix='.csv') as out_file, open("instructions.md", "r", encoding="utf-8") as md:
    gr.Markdown(md.read())
    seq = gr.Textbox(lines=2, label="Sequence", placeholder="Sequence here...", value='MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ')
    trg = gr.Textbox(lines=1, label="Substitutions", placeholder="Substitutions here...", value="61 214 19 30 122 140")
    # Default to the second most recently modified model, falling back to the
    # only one when the Hub returned a single model (MODELS[1] would raise).
    model_name = gr.Dropdown(MODELS, label="Model", value=MODELS[1] if len(MODELS) > 1 else MODELS[0])
    scoring_strategy = gr.Dropdown(SCORING, label="Scoring strategy", value=SCORING[1])
    btn = gr.Button(value="Submit")
    out = gr.HTML()
    # Hidden until the first result is available; app() makes it visible.
    bto = gr.File(value=out_file.name, visible=False, label="Download", file_count='single', interactive=False)
    btn.click(fn=app, inputs=[seq, trg, model_name, scoring_strategy, bto], outputs=[out, bto])
# demo.launch(share=True, server_name="0.0.0.0", server_port=7878)