# BarlowDTI / model/model.py
# mschuh's picture
# Fix ZeroGPU
# fa7f380 verified
import sys
from typing import List
from tqdm import tqdm
import pandas as pd
import numpy as np
import threading
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import time
import requests
import joblib
# from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, ProtTransT5XLU50Embedder
from Bio import SeqIO
import rdkit
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import torch
from typing import *
from rdkit import RDLogger
RDLogger.DisableLog("rdApp.*")
from xgboost import XGBClassifier, DMatrix
from model.barlow_twins import BarlowTwins
# sys.path.append("../utils/")
from utils.sequence import uniprot2sequence, encode_sequences
class DTIModel:
    """Drug-target interaction predictor.

    Wraps a ``BarlowTwins`` model and an ``XGBClassifier``: drug and target
    embeddings are combined with ``bt_model.zero_shot`` and the result is
    scored by the classifier. Encoded SMILES and sequences are memoised in
    per-instance dictionaries.
    """

    def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
        """Load both models from disk and initialise empty caches.

        Args:
            bt_model_path: Path passed to ``BarlowTwins.load_model``.
            gbm_model_path: Path passed to ``XGBClassifier.load_model``.
            encoder: Encoder name forwarded to ``encode_sequences``.
        """
        self.bt_model = BarlowTwins()
        self.bt_model.load_model(bt_model_path)
        self.gbm_model = XGBClassifier()
        self.gbm_model.load_model(gbm_model_path)
        self.encoder = encoder
        # Keyed by (smiles, radius, bits, features) so fingerprints computed
        # with different parameters never shadow each other.
        self.smiles_cache = {}
        # Keyed by the raw sequence string.
        self.sequence_cache = {}

    def _encode_smiles(self, smiles: str, radius: int = 2, bits: int = 1024, features: bool = False):
        """Return the Morgan bit-vector fingerprint of ``smiles`` as a NumPy array.

        Args:
            smiles: SMILES string, or ``None``.
            radius: Morgan fingerprint radius.
            bits: Number of bits in the fingerprint.
            features: Forwarded as ``useFeatures`` to RDKit.

        Returns:
            The fingerprint array, or ``None`` when ``smiles`` is ``None`` or
            encoding fails (the failure is printed, not raised).
        """
        if smiles is None:
            return None

        # The fingerprint depends on every parameter, so all of them belong
        # in the cache key; keying on the SMILES alone would hand back a
        # fingerprint built with different settings.
        cache_key = (smiles, radius, bits, features)
        if cache_key in self.smiles_cache:
            return self.smiles_cache[cache_key]

        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles signals a parse failure by returning None.
                raise ValueError("RDKit could not parse the SMILES string")
            morgan = np.array(
                AllChem.GetMorganFingerprintAsBitVect(
                    mol,
                    radius=radius,
                    nBits=bits,
                    useFeatures=features,
                )
            )
        except Exception as e:
            # Best-effort by design: report and let the caller see None.
            print(f"Failed to encode SMILES: {smiles}")
            print(e)
            return None

        self.smiles_cache[cache_key] = morgan
        return morgan

    def _encode_smiles_mult(self, smiles: List[str], radius: int = 2, bits: int = 1024, features: bool = False):
        """Encode several SMILES strings and stack the results with ``np.array``.

        Entries that fail to encode are ``None`` in the intermediate list.
        """
        morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
        return np.array(morgan)

    def _encode_sequence(self, sequence: str):
        """Return the embedding of ``sequence`` produced by ``encode_sequences``.

        Returns:
            Whatever ``encode_sequences([sequence], encoder=self.encoder)``
            returns, or ``None`` when ``sequence`` is ``None`` or encoding
            fails (the failure is printed, not raised).
        """
        # Clear torch cache
        torch.cuda.empty_cache()
        if sequence is None:
            return None

        if sequence in self.sequence_cache:
            return self.sequence_cache[sequence]

        try:
            encoded_sequence = encode_sequences([sequence], encoder=self.encoder)
        except Exception as e:
            # Best-effort by design: report and let the caller see None.
            print(f"Failed to encode sequence: {sequence}")
            print(e)
            return None

        self.sequence_cache[sequence] = encoded_sequence
        return encoded_sequence

    def _encode_sequence_mult(self, sequences: List[str]):
        """Encode several sequences and stack the results with ``np.array``."""
        seq = [self._encode_sequence(sequence) for sequence in sequences]
        return np.array(seq)

    def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool = False):
        """Score drug/target embeddings with the Barlow Twins + XGBoost pipeline.

        Args:
            drug_emb: Drug embedding(s).
            target_emb: Target embedding(s).
            pred_leaf: When true, return the booster's ``pred_leaf`` output
                instead of probabilities. Defaults to ``False`` so callers
                that only want probabilities can omit it.

        Returns:
            ``predict_proba(emb)[:, 1]`` by default, or the booster's
            ``predict(..., pred_leaf=True)`` output.

        Raises:
            ValueError: If either embedding is ``None`` (encoding failed).
        """
        if drug_emb is None or target_emb is None:
            raise ValueError(
                "Cannot predict: drug or target embedding is missing (encoding failed)."
            )

        # Repeat the shorter input so both have the same leading length.
        # NOTE(review): for a single SMILES, drug_emb comes straight from
        # _encode_smiles and is presumably 1-D, so len() is the bit count
        # rather than a batch size - confirm against zero_shot/encode_sequences.
        if drug_emb.shape[0] < target_emb.shape[0]:
            drug_emb = np.tile(drug_emb, (len(target_emb), 1))
        elif len(drug_emb) > len(target_emb):
            target_emb = np.tile(target_emb, (len(drug_emb), 1))

        emb = self.bt_model.zero_shot(drug_emb, target_emb)

        if pred_leaf:
            d_emb = DMatrix(emb)
            return self.gbm_model.get_booster().predict(d_emb, pred_leaf=True)
        return self.gbm_model.predict_proba(emb)[:, 1]

    def predict(self, drug: Union[List[str], str], target: str, pred_leaf: bool = False):
        """Predict interaction between one or more drugs and a target sequence.

        Args:
            drug: A SMILES string or a list of SMILES strings.
            target: Target sequence passed to ``_encode_sequence``.
            pred_leaf: Forwarded to the pair predictor.

        Raises:
            ValueError: If the drug or the target could not be encoded.
        """
        if isinstance(drug, str):
            drug_emb = self._encode_smiles(drug)
        else:
            drug_emb = self._encode_smiles_mult(drug)
        target_emb = self._encode_sequence(target)
        return self.__predict_pair(drug_emb, target_emb, pred_leaf)

    def get_leaf_weights(self):
        """Return the booster's ``get_score(importance_type="weight")`` mapping."""
        return self.gbm_model.get_booster().get_score(importance_type="weight")

    def _score_fasta_record(self, drug: str, drug_emb: np.ndarray, target):
        """Score one parsed FASTA record; return a result row or ``None`` if it cannot be encoded."""
        target_emb = self._encode_sequence(str(target.seq))
        if target_emb is None:
            print(f"Skipping target {target.id}: sequence could not be encoded")
            return None
        pred = self.__predict_pair(drug_emb, target_emb, pred_leaf=False)
        return {
            "drug": drug,
            "target": target.id,
            "name": target.name,
            "description": target.description,
            "prediction": pred[0],
        }

    def _predict_fasta(self, drug: str, fasta_path: str):
        """Score ``drug`` against every record of a FASTA file, sequentially.

        Returns:
            A DataFrame with columns ``drug``, ``target``, ``name``,
            ``description`` and ``prediction``; records whose sequence cannot
            be encoded are skipped.

        Raises:
            ValueError: If ``drug`` cannot be encoded.
        """
        drug_emb = self._encode_smiles(drug)
        if drug_emb is None:
            raise ValueError(f"Could not encode drug SMILES: {drug}")

        results = []
        for target in tqdm(SeqIO.parse(fasta_path, "fasta"), desc="Predicting targets"):
            row = self._score_fasta_record(drug, drug_emb, target)
            if row is not None:
                results.append(row)
        return pd.DataFrame(results)

    def predict_fasta(self, drug: str, fasta_path: str, timeout_seconds: int = 120):
        """Score ``drug`` against every FASTA record, skipping slow targets.

        Each record is processed in its own thread and abandoned if it has
        not finished after ``timeout_seconds``.

        Args:
            drug: SMILES string of the drug.
            fasta_path: FASTA source handed to ``SeqIO.parse``.
            timeout_seconds: Per-target time budget in seconds.

        Returns:
            A DataFrame with columns ``drug``, ``target``, ``name``,
            ``description`` and ``prediction`` for every target that
            completed in time.

        Raises:
            ValueError: If ``drug`` cannot be encoded.
        """
        drug_emb = self._encode_smiles(drug)
        if drug_emb is None:
            raise ValueError(f"Could not encode drug SMILES: {drug}")

        def process_target(target, out):
            # An exception escaping a thread would only be printed by the
            # threading machinery and the target would vanish silently, so
            # report it here with the target id instead.
            try:
                row = self._score_fasta_record(drug, drug_emb, target)
            except Exception as e:
                print(f"Skipping target {target.id} due to error: {e}")
                return
            if row is not None:
                out.append(row)

        # Parse once: gives the progress-bar total without a second pass over
        # the input (a second pass would be empty for an already-read handle).
        records = list(SeqIO.parse(fasta_path, "fasta"))

        results = []
        for target in tqdm(records, total=len(records), desc="Predicting targets"):
            thread_results = []
            # daemon=True so a hung encoding cannot keep the interpreter alive.
            # A timed-out thread cannot be killed; it keeps running in the
            # background and its late result is simply discarded.
            thread = threading.Thread(
                target=process_target,
                args=(target, thread_results),
                daemon=True,
            )
            thread.start()
            thread.join(timeout_seconds)
            if thread.is_alive():
                print(f"Skipping target {target.id} due to timeout")
                continue
            results.extend(thread_results)
        return pd.DataFrame(results)

    def predict_uniprot(self, drug: Union[List[str], str], uniprot_id: str, pred_leaf: bool = False):
        """Resolve ``uniprot_id`` with ``uniprot2sequence`` and delegate to ``predict``."""
        return self.predict(drug, uniprot2sequence(uniprot_id), pred_leaf)