# zsp / model.py
from huggingface_hub import HfApi
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# fetch suitable ESM models from the HuggingFace Hub, most recently updated first
# (list_models now takes the filter criteria directly; ModelFilter is deprecated)
MODELS = [
    m.id
    for m in HfApi().list_models(
        author="facebook",
        model_name="esm",
        task="fill-mask",
        sort="lastModified",
        direction=-1,
    )
]
if not MODELS:
    raise RuntimeError("Error while retrieving models from the HuggingFace Hub")
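# MODELS now holds Hub model IDs such as "facebook/esm2_t33_650M_UR50D";
# the exact entries depend on what the Hub query returns at run time.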

class Model:
    """Wrapper for ESM models"""

    def __init__(self, model_name: str = ""):
        "load the selected model and tokenizer"
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def __rshift__(self, batch_tokens: torch.Tensor) -> torch.Tensor:
        "run the model on a batch of tokens, returning the logits"
        return self.model(batch_tokens)["logits"]

    def __lshift__(self, sequence: str) -> torch.Tensor:
        "convert an input string to a batch of tokens"
        tokens = self.batch_converter(sequence, return_tensors="pt")["input_ids"]
        # move tokens to the model's device so GPU execution works
        return tokens.to(self.model.device)

    def __getitem__(self, key: str) -> int:
        "get the token ID for a character"
        return self.alphabet[key]
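
    # The operators above are shorthand used by run_model below, e.g.:
    #   tokens = model << "MKTAYIA..."  # tokenize a sequence
    #   logits = model >> tokens        # forward pass
    #   aa_id = model["A"]              # vocabulary lookup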
    def run_model(self, data):
        "score every substitution in data.sub against the wild-type sequence"

        def label_row(row, token_probs):
            "label a row such as 'M1A' with its log-odds score: log p(mt) - log p(wt)"
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # offset by 1 to skip the BOS/CLS token prepended by the tokenizer
            score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            return score.item()

        batch_tokens = self << data.seq
        # run the model with the selected scoring strategy
        # (details on each strategy are available in the original ESM paper)
        if data.scoring_strategy.startswith("wt-marginals"):
            # single forward pass on the unmasked wild-type sequence
            with torch.no_grad():
                token_probs = torch.log_softmax(self >> batch_tokens, dim=-1)
            data.out[self.model_name] = data.sub.apply(
                lambda row: label_row(
                    row['0'],
                    token_probs,
                ),
                axis=1,
            )
        elif data.scoring_strategy.startswith("masked-marginals"):
            # one forward pass per position, masking each token in turn;
            # slower but usually more accurate than wt-marginals
            all_token_probs = []
            for i in range(batch_tokens.size()[1]):
                batch_tokens_masked = batch_tokens.clone()
                batch_tokens_masked[0, i] = self['<mask>']
                with torch.no_grad():
                    token_probs = torch.log_softmax(
                        self >> batch_tokens_masked, dim=-1
                    )
                all_token_probs.append(token_probs[:, i])
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
            data.out[self.model_name] = data.sub.apply(
                lambda row: label_row(
                    row['0'],
                    token_probs,
                ),
                axis=1,
            )
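

if __name__ == "__main__":
    # Minimal usage sketch. `Data` below is a hypothetical stand-in for the
    # app's data container, inferred from the attributes run_model accesses
    # (seq, scoring_strategy, sub, out); the real app defines its own.
    from dataclasses import dataclass, field

    import pandas as pd

    @dataclass
    class Data:
        seq: str               # wild-type protein sequence
        scoring_strategy: str  # "wt-marginals" or "masked-marginals"
        sub: pd.DataFrame      # substitutions like "M1A" in a column named '0'
        out: pd.DataFrame = field(default_factory=pd.DataFrame)

    data = Data(
        seq="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        scoring_strategy="wt-marginals",
        sub=pd.DataFrame({'0': ["M1A", "K2R", "T3S"]}),
    )
    model = Model(MODELS[0])  # most recently updated matching model
    model.run_model(data)
    print(data.out)           # one log-odds score per substitution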