Spaces:
Sleeping
Sleeping
from huggingface_hub import HfApi, ModelFilter | |
import torch | |
from transformers import AutoTokenizer, AutoModelForMaskedLM | |
# Fetch suitable ESM models from the HuggingFace Hub: fill-mask checkpoints by
# author "facebook" whose name matches "esm", sorted on lastModified with
# direction=-1 (descending).
MODELS = [
    m.modelId
    for m in HfApi().list_models(
        filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"),
        sort="lastModified",
        direction=-1,
    )
]
# Fail at import time when the listing came back empty, rather than later
# when a model is selected.
if not MODELS:
    raise RuntimeError("Error while retrieving models from HuggingFace Hub")
class Model:
    """Wrapper for ESM models.

    Operators are overloaded as shorthand:

    * ``model << sequence`` tokenizes a string into a batch of token IDs,
    * ``model >> tokens`` runs the network and returns its logits,
    * ``model[token]`` looks a token up in the tokenizer vocabulary.
    """

    def __init__(self, model_name: str = ""):
        """Load the selected model and tokenizer.

        Args:
            model_name: identifier handed to ``from_pretrained``. When empty,
                nothing is loaded and only ``model_name`` is set on the
                instance.
        """
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Run the model on a batch of tokens and return the logits.

        The tokens are moved to the model's device first: the tokenizer
        returns CPU tensors while ``__init__`` may have moved the model to
        CUDA, and feeding CPU tokens to a CUDA model raises a device-mismatch
        error. ``Tensor.to`` is a no-op when the devices already agree.
        """
        return self.model(batch_tokens.to(self.model.device))["logits"]

    def __lshift__(self, input: str) -> torch.Tensor:
        """Convert an input string to a batch of token IDs."""
        return self.batch_converter(input, return_tensors="pt")["input_ids"]

    def __getitem__(self, key: str) -> int:
        """Return the token ID stored in the vocabulary for ``key``."""
        return self.alphabet[key]

    def _score_mutation(self, mutation: str, token_probs: torch.Tensor) -> float:
        """Score one substitution as log-prob(mutant) - log-prob(wild type).

        ``mutation`` is read as ``<wt char><1-based position><mt char>``:
        the first and last characters are the two residues and everything in
        between is parsed as an integer position.
        """
        wt, mt = mutation[0], mutation[-1]
        idx = int(mutation[1:-1]) - 1  # 0-based position within the sequence
        # Token index is shifted by one relative to the sequence index;
        # presumably this skips a special token the tokenizer prepends —
        # TODO confirm for the tokenizer in use.
        position_probs = token_probs[0, 1 + idx]
        return (position_probs[self[mt]] - position_probs[self[wt]]).item()

    def _wt_marginals(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Return log-softmax of the logits from one pass over the unmasked tokens."""
        with torch.no_grad():
            return torch.log_softmax(self >> batch_tokens, dim=-1)

    def _masked_marginals(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        """Return log-softmax rows obtained by masking one position at a time.

        One forward pass is run per token position; from each pass only the
        row of the masked position is kept. The rows are then reassembled
        with a leading batch dimension of 1.
        """
        mask_id = self['<mask>']  # loop-invariant, looked up once
        per_position = []
        with torch.no_grad():
            for i in range(batch_tokens.size(1)):
                masked_tokens = batch_tokens.clone()
                masked_tokens[0, i] = mask_id
                log_probs = torch.log_softmax(self >> masked_tokens, dim=-1)
                per_position.append(log_probs[:, i])
        return torch.cat(per_position, dim=0).unsqueeze(0)

    def run_model(self, data):
        """Score every substitution in ``data`` and store the result on it.

        Args:
            data: object exposing ``seq`` (string handed to the tokenizer),
                ``scoring_strategy`` (string; prefixes ``"wt-marginals"`` and
                ``"masked-marginals"`` are recognised), ``sub`` (its
                ``apply(..., axis=1)`` is called and each row's ``'0'`` entry
                is read as the mutation string — presumably a pandas
                DataFrame, confirm against the caller) and ``out`` (receives
                the scores under the key ``self.model_name``).

        Returns:
            None. With an unrecognised scoring strategy nothing is written
            to ``data.out``.
        """
        batch_tokens = self << data.seq
        # Scoring strategies follow the original ESM paper.
        if data.scoring_strategy.startswith("wt-marginals"):
            token_probs = self._wt_marginals(batch_tokens)
        elif data.scoring_strategy.startswith("masked-marginals"):
            token_probs = self._masked_marginals(batch_tokens)
        else:
            # Deliberately a no-op so other strategies keep the prior behaviour.
            return
        data.out[self.model_name] = data.sub.apply(
            lambda row: self._score_mutation(row['0'], token_probs),
            axis=1,
        )