from collections import namedtuple
from typing import List, Union

import torch
from torch.nn.utils.rnn import pad_sequence

import esm

from protein_lm.modeling.scripts.train import compute_esm_embedding, load_ckpt, make_esm_input_ids
from protein_lm.tokenizer.tokenizer import PTMTokenizer

# Lightweight container mirroring the fields of the model's forward output.
Output = namedtuple("Output", ["logits", "hidden_states"])


class PTMMamba:
    """Inference wrapper around a PTM-Mamba checkpoint, optionally conditioned on ESM-2 embeddings."""

    def __init__(self, ckpt_path: str, device: str = "cuda", use_esm: bool = True) -> None:
        self.use_esm = use_esm
        self._tokenizer = PTMTokenizer()
        self._model = load_ckpt(ckpt_path, self.tokenizer, device)
        self._device = torch.device(device)
        self._model.to(self._device)
        self._model.eval()
        if use_esm:
            # ESM-2 (650M) supplies the residue embeddings consumed in _infer;
            # skip loading it entirely when embeddings are not used.
            self.esm_model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
            self.batch_converter = self.alphabet.get_batch_converter()
            self.esm_model.eval()
        else:
            self.esm_model, self.alphabet, self.batch_converter = None, None, None

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    @property
    def tokenizer(self) -> PTMTokenizer:
        return self._tokenizer

    @property
    def device(self) -> torch.device:
        return self._device

    def infer(self, seq: str) -> Output:
        """Run the model on a single (optionally PTM-annotated) sequence."""
        input_ids = self.tokenizer(seq)
        input_ids = torch.tensor(input_ids, device=self.device).unsqueeze(0)
        return self._infer(input_ids)

    @torch.no_grad()
    def _infer(self, input_ids):
        if self.use_esm:
            # Build ESM-compatible input ids, embed them with ESM-2, and feed
            # the embedding to the Mamba model alongside the token ids.
            esm_input_ids = make_esm_input_ids(input_ids, self.tokenizer)
            embedding = compute_esm_embedding(
                self.tokenizer, self.esm_model, self.batch_converter, esm_input_ids
            )
        else:
            embedding = None
        return self.model(input_ids, embedding=embedding)
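
    # Hypothetical convenience helper, not part of the original wrapper: it
    # assumes hidden_states comes back with a leading batch dimension, i.e.
    # shape (batch, seq_len, d_model), so squeezing it yields per-residue states.
    def embed(self, seq: str) -> torch.Tensor:
        return self.infer(seq).hidden_states.squeeze(0)  # (seq_len, d_model)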

    def infer_batch(self, seqs: List[str]) -> Output:
        input_ids = self.tokenizer(seqs)
        input_ids = pad_sequence(
            [torch.tensor(x) for x in input_ids],
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        # pad_sequence already returns a tensor, so move it to the target device
        # rather than re-wrapping it in torch.tensor (which copies and warns).
        input_ids = input_ids.to(self.device)
        return self._infer(input_ids)

    def __call__(self, seq: Union[str, List[str]]) -> Output:
        if isinstance(seq, str):
            return self.infer(seq)
        elif isinstance(seq, list):
            return self.infer_batch(seq)
        else:
            raise ValueError(f"Input must be a string or a list of strings, got {type(seq)}")
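

# Sketch of masked mean pooling over a padded batch. Assumptions (not from the
# original code): hidden_states has shape (batch, seq_len, d_model), and the
# caller re-tokenizes and pads exactly as infer_batch does to recover input_ids.
def mean_pool_hidden_states(output: Output, input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    mask = (input_ids != pad_token_id).unsqueeze(-1).to(output.hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (output.hidden_states * mask).sum(dim=1)  # (batch, d_model)
    return summed / mask.sum(dim=1).clamp(min=1.0)  # mean over non-pad positions only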


if __name__ == "__main__":
    ckpt_path = "ckpt/bi_mamba-esm-ptm_token_input/best.ckpt"
    mamba = PTMMamba(ckpt_path, device="cuda:0")
    # PTM sites are written inline as special tokens such as <Phosphoserine>.
    seq = '<N-acetylmethionine>EAD<Phosphoserine>PAGPGAPEPLAEGAAAEFS<Phosphoserine>LLRRIKGKLFTWNILKTIALGQMLSLCICGTAITSQYLAERYKVNTPMLQSFINYCLLFLIYTVMLAFRSGSDNLLVILKRKWWKYILLGLADVEANYVIVRAYQYTTLTSVQLLDCFGIPVLMALSWFILHARYRVIHFIAVAVCLLGVGTMVGADILAGREDNSGSDVLIGDILVLLGASLYAISNVCEEYIVKKLSRQEFLGMVGLFGTIISGIQLLIVEYKDIASIHWDWKIALLFVAFALCMFCLYSFMPLVIKVTSATSVNLGILTADLYSLFVGLFLFGYKFSGLYILSFTVIMVGFILYCSTPTRTAEPAESSVPPVTSIGIDNLGLKLEENLQETH<Phosphoserine>AVL'
    output = mamba(seq)
    print(output.logits.shape)
    print(output.hidden_states.shape)
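
    # Illustrative batch call; the second sequence here is made up purely for
    # demonstration. infer_batch right-pads with the tokenizer's pad token, so
    # padded positions also appear in the returned logits and hidden_states.
    batch_output = mamba([seq, "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"])
    print(batch_output.logits.shape)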