import gradio as gr
from huggingface_hub import HfApi, ModelFilter
import pandas as pd
from re import fullmatch
from tempfile import NamedTemporaryFile
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# fetch suitable ESM models from HuggingFace Hub
MODELS = [m.modelId for m in HfApi().list_models(filter=ModelFilter(author="facebook",
                                                                    model_name="esm",
                                                                    task="fill-mask"),
                                                 sort="lastModified",
                                                 direction=-1)]
if not MODELS:
    raise RuntimeError("Error while retrieving models from HuggingFace Hub")

# scoring strategies
SCORING = ["masked-marginals (more accurate)", "wt-marginals (faster)"]


class Model:
    """Wrapper for ESM models"""

    def __init__(self, model_name: str = ""):
        "load selected model and tokenizer"
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        "run model on batch of tokens"
        return self.model(batch_tokens)["logits"]

    def __lshift__(self, input: str) -> torch.Tensor:
        "convert input string to batch of tokens"
        # move tokens to the same device as the model (needed when it sits on GPU)
        return self.batch_converter(input, return_tensors="pt")["input_ids"].to(self.model.device)

    def __getitem__(self, key: str) -> int:
        "get token ID from character"
        return self.alphabet[key]

    def run_model(self, data):
        "run model on data"

        def label_row(row, token_probs):
            "score one substitution as log-odds of mutant vs wild-type residue"
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # the +1 offset accounts for the BOS token prepended by the tokenizer
            score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            return score.item()

        batch_tokens = self << data.seq
        if data.scoring_strategy.startswith("wt-marginals"):
            # single forward pass over the wild-type sequence
            with torch.no_grad():
                token_probs = torch.log_softmax(self >> batch_tokens, dim=-1)
            data.out[self.model_name] = data.sub.apply(
                lambda row: label_row(
                    row['0'],
                    token_probs,
                ),
                axis=1,
            )
        elif data.scoring_strategy.startswith("masked-marginals"):
            # one forward pass per position, masking each token in turn
            all_token_probs = []
            for i in range(batch_tokens.size()[1]):
                batch_tokens_masked = batch_tokens.clone()
                batch_tokens_masked[0, i] = self['<mask>']
                with torch.no_grad():
                    token_probs = torch.log_softmax(
                        self >> batch_tokens_masked, dim=-1
                    )
                all_token_probs.append(token_probs[:, i])
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
            data.out[self.model_name] = data.sub.apply(
                lambda row: label_row(
                    row['0'],
                    token_probs,
                ),
                axis=1,
            )


class Data:
    """Container for input and output data"""

    # initialise empty model as static class member for efficiency
    model = Model()

    def parse_seq(self, src: str):
        "parse input sequence"
        self.seq = src.strip().upper()
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def parse_sub(self, trg: str):
        "parse input substitutions"
        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper()
        # identify running mode
        if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq):
            # single string of the same length as the sequence: seq-vs-seq mode
            self.mode = 'SVS'
            for resi, (src, trg) in enumerate(zip(self.seq, self.trg), 1):
                if src != trg:
                    self.sub.append(f"{src}{resi}{trg}")
        else:
            self.trg = self.trg.split()
            if all(fullmatch(r'\d+', x) for x in self.trg):
                # all strings are numbers: deep mutational scanning mode
                self.mode = 'DMS'
                for resi in map(int, self.trg):
                    src = self.seq[resi - 1]
                    for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
                        self.sub.append(f"{src}{resi}{trg}")
            elif all(fullmatch(r'[A-Z]\d+[A-Z]', x) for x in self.trg):
                # all strings of the form X#Y: single substitution mode
                self.mode = 'MUT'
                self.sub = self.trg
            else:
                raise RuntimeError("Unrecognised running mode; wrong inputs?")
        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def __init__(self, src: str, trg: str, model_name: str, scoring_strategy: str, out_file):
        "initialise data"
        # if the model has changed, load the new one into the shared class member
        if self.model.model_name != model_name:
            Data.model = Model(model_name)
        self.model_name = model_name
        self.parse_seq(src)
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_buffer = out_file.name

    def parse_output(self) -> str:
        "format output data for visualisation"
        if self.mode == 'MUT':
            # single substitution mode: sort by score
            self.out = self.out.sort_values(self.model_name, ascending=False)
        elif self.mode == 'DMS':
            # deep mutational scanning mode: sort by residue and score, then lay
            # out the 19 variants of each residue as side-by-side column pairs
            # FIX: this doesn't work if the input sequence contains wildcard characters
            self.out = pd.concat([(self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                                   .sort_values(['resi', self.model_name], ascending=[True, False])
                                   .groupby(['resi'])
                                   .head(19)
                                   .drop(['resi'], axis=1)).iloc[19*x:19*(x+1)]
                                                           .reset_index(drop=True)
                                  for x in range(self.out.shape[0]//19)],
                                 axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
        # save to temporary file to be downloaded
        self.out.round(2).to_csv(self.out_buffer, index=False)
        return (self.out.style
                .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                .hide(axis=0)
                .hide(axis=1)
                .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
                .to_html())

    def calculate(self):
        "run model and parse output"
        self.model.run_model(self)
        return self, self.parse_output()


def app(*argv):
    "run app"
    seq, trg, model_name, scoring_strategy, out_file, *_ = argv
    data, html = Data(seq, trg, model_name, scoring_strategy, out_file).calculate()
    return html, gr.File.update(value=out_file.name, visible=True)
    # df = pd.DataFrame((pd.np.random.random((10, 5))-0.5)*10, columns=list('ABCDE'))
    # df.to_csv(out_file.name, index=False)
    # return df.to_html(), gr.File.update(value=out_file.name, visible=True)


with gr.Blocks() as demo, \
     NamedTemporaryFile(mode='w+', prefix='out_', suffix='.csv') as out_file, \
     open("instructions.md", "r") as md:
    gr.Markdown(md.read())
    seq = gr.Textbox(lines=2, label="Sequence", placeholder="Sequence here...",
                     value='MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ')
    trg = gr.Textbox(lines=1, label="Substitutions", placeholder="Substitutions here...",
                     value="61 214 19 30 122 140")
    model_name = gr.Dropdown(MODELS, label="Model", value=MODELS[1])
    scoring_strategy = gr.Dropdown(SCORING, label="Scoring strategy", value=SCORING[1])
    btn = gr.Button(value="Submit")
    out = gr.HTML()
    bto = gr.File(value=out_file.name, visible=False, label="Download",
                  file_count='single', interactive=False)
    btn.click(fn=app, inputs=[seq, trg, model_name, scoring_strategy, bto],
              outputs=[out, bto])

    # demo.launch(share=True, server_name="0.0.0.0", server_port=7878)