"""Wrapper around facebook's ESM masked-language models for scoring substitutions."""

from huggingface_hub import HfApi, ModelFilter
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# fetch suitable ESM models from HuggingFace Hub (runs at import time).
# NOTE(review): `ModelFilter` is deprecated in recent huggingface_hub releases;
# pin the dependency or migrate to keyword filters on `list_models`.
MODELS = [
    m.modelId
    for m in HfApi().list_models(
        filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"),
        sort="lastModified",
        direction=-1,
    )
]
if not MODELS:
    raise RuntimeError("Error while retrieving models from HuggingFace Hub")


class Model:
    """Wrapper for ESM models.

    Operator shorthands:
        ``model << sequence``  tokenize a string into a batch of token IDs.
        ``model >> tokens``    run the masked LM and return its logits.
        ``model[char]``        look up a token ID in the tokenizer vocabulary.
    """

    def __init__(self, model_name: str = ""):
        """Load the selected model and tokenizer.

        With an empty ``model_name`` only ``self.model_name`` is set; no
        model, tokenizer or vocabulary is loaded.
        """
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Run the model on a batch of tokens and return its logits."""
        # The tokenizer returns CPU tensors while __init__ may have moved the
        # model to the GPU; send the input to wherever the weights live.
        return self.model(batch_tokens.to(self.model.device))["logits"]

    def __lshift__(self, input: str) -> torch.Tensor:
        """Convert ``input`` into a batch of token IDs (indexed ``[0, position]`` by run_model)."""
        return self.batch_converter(input, return_tensors="pt")["input_ids"]

    def __getitem__(self, key: str) -> int:
        """Get the token ID for a character (raises KeyError if not in the vocabulary)."""
        return self.alphabet[key]

    def run_model(self, data):
        """Score every substitution in ``data.sub`` and store the scores.

        A substitution string such as ``"A12G"`` is parsed as wild-type
        residue, 1-based position, mutant residue. Its score is
        ``log p(mutant) - log p(wild type)`` at that position. Results are
        written to ``data.out[self.model_name]``.

        ``data.scoring_strategy`` selects how the log-probabilities are built:

        * prefix ``"wt-marginals"``: one forward pass over the unmasked input.
        * prefix ``"masked-marginals"``: one forward pass per token position,
          with that position replaced by the tokenizer's mask token.

        ``data.sub`` is assumed to be a DataFrame whose column ``"0"`` holds
        the substitution strings (TODO confirm against the caller).

        Raises:
            ValueError: if the scoring strategy matches neither prefix.
        """

        def label_row(sub, token_probs):
            """Return the log-probability difference for one substitution string."""
            wt, idx, mt = sub[0], int(sub[1:-1]) - 1, sub[-1]
            # +1 presumably skips the special token the tokenizer prepends
            # (e.g. <cls>) -- TODO confirm for every model in MODELS.
            pos = 1 + idx
            score = token_probs[0, pos, self[mt]] - token_probs[0, pos, self[wt]]
            return score.item()

        # NOTE(review): this line and the "wt-marginals" branch were
        # reconstructed from a copy whose "<...>" spans had been stripped;
        # confirm `data.seq` and the strategy prefix against version control.
        batch_tokens = self << data.seq

        if data.scoring_strategy.startswith("wt-marginals"):
            with torch.no_grad():
                token_probs = torch.log_softmax(self >> batch_tokens, dim=-1)
        elif data.scoring_strategy.startswith("masked-marginals"):
            # Ask the tokenizer for its mask token instead of hard-coding it.
            mask_id = self[self.batch_converter.mask_token]
            all_token_probs = []
            with torch.no_grad():
                for i in range(batch_tokens.size(1)):
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = mask_id
                    token_probs = torch.log_softmax(self >> batch_tokens_masked, dim=-1)
                    # keep only the distribution predicted at the masked slot
                    all_token_probs.append(token_probs[:, i])
            # (num_tokens, vocab) -> (1, num_tokens, vocab), same layout as wt-marginals
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
        else:
            raise ValueError(f"Unknown scoring strategy: {data.scoring_strategy!r}")

        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(row["0"], token_probs),
            axis=1,
        )