File size: 6,356 Bytes
84bfd88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import sys
from typing import List
from tqdm import tqdm
import pandas as pd
import numpy as np
import threading
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import time
import requests
import joblib
# from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, ProtTransT5XLU50Embedder
from Bio import SeqIO
import rdkit
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import torch
from typing import *
from rdkit import RDLogger
RDLogger.DisableLog("rdApp.*")

from xgboost import XGBClassifier, DMatrix

from model.barlow_twins import BarlowTwins

# sys.path.append("../utils/")
from utils.sequence import uniprot2sequence, encode_sequences



class DTIModel:
    """Score drug-target pairs with a BarlowTwins embedder and an XGBoost classifier.

    Drugs are given as SMILES strings and encoded as Morgan fingerprints;
    targets are given as amino-acid sequences and encoded with
    ``encode_sequences``. The two encodings are combined by
    ``BarlowTwins.zero_shot`` and the result is passed to the XGBoost model.
    Encodings are memoised per instance in ``smiles_cache`` and
    ``sequence_cache``.
    """

    def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
        """Load both models from disk.

        Args:
            bt_model_path: Path handed to ``BarlowTwins.load_model``.
            gbm_model_path: Path handed to ``XGBClassifier.load_model``.
            encoder: Encoder name forwarded to ``encode_sequences``.
        """
        self.bt_model = BarlowTwins()
        self.bt_model.load_model(bt_model_path)

        self.gbm_model = XGBClassifier()
        self.gbm_model.load_model(gbm_model_path)

        self.encoder = encoder

        # Keyed by (smiles, radius, bits, features) so that fingerprints built
        # with different settings never shadow each other.
        self.smiles_cache = {}
        # Keyed by the raw sequence string.
        self.sequence_cache = {}

    @staticmethod
    def _result_row(drug: str, record, prediction) -> dict:
        """Build one output row for a FASTA record and its prediction."""
        return {
            "drug": drug,
            "target": record.id,
            "name": record.name,
            "description": record.description,
            "prediction": prediction,
        }

    def _encode_smiles(
        self, smiles: str, radius: int = 2, bits: int = 1024, features: bool = False
    ) -> Optional[np.ndarray]:
        """Return the Morgan bit-vector fingerprint of ``smiles`` as an array.

        Args:
            smiles: SMILES string, or ``None``.
            radius: Morgan radius.
            bits: Fingerprint length (``nBits``).
            features: Forwarded as ``useFeatures``.

        Returns:
            The fingerprint array, or ``None`` when ``smiles`` is ``None`` or
            RDKit fails to encode it (the failure is printed, not raised, and
            is not cached).
        """
        if smiles is None:
            return None

        # The fingerprint depends on every parameter, so all of them are part
        # of the cache key.
        cache_key = (smiles, radius, bits, features)
        if cache_key in self.smiles_cache:
            return self.smiles_cache[cache_key]

        try:
            mol = Chem.MolFromSmiles(smiles)
            fingerprint = AllChem.GetMorganFingerprintAsBitVect(
                mol,
                radius=radius,
                nBits=bits,
                useFeatures=features,
            )
            morgan = np.array(fingerprint)
        except Exception as e:
            # Best effort: RDKit signals unparsable input in several ways.
            print(f"Failed to encode SMILES: {smiles}")
            print(e)
            return None

        self.smiles_cache[cache_key] = morgan
        return morgan

    def _encode_smiles_mult(
        self, smiles: List[str], radius: int = 2, bits: int = 1024, features: bool = False
    ) -> np.ndarray:
        """Encode several SMILES strings and stack the results with ``np.array``.

        NOTE(review): entries that fail to encode are ``None``, so the result
        is only a clean numeric matrix when every SMILES is valid.
        """
        morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
        return np.array(morgan)

    def _encode_sequence(self, sequence: str):
        """Return the cached or freshly computed encoding of one sequence.

        Returns ``None`` when ``sequence`` is ``None`` or when
        ``encode_sequences`` raises (the failure is printed, not raised, and
        is not cached).
        """
        if sequence is None:
            return None
        if sequence in self.sequence_cache:
            return self.sequence_cache[sequence]

        # Release PyTorch's unused cached CUDA memory only when the encoder is
        # actually about to run; cache hits need no GPU work.
        torch.cuda.empty_cache()
        try:
            encoded_sequence = encode_sequences([sequence], encoder=self.encoder)
        except Exception as e:
            print(f"Failed to encode sequence: {sequence}")
            print(e)
            return None

        self.sequence_cache[sequence] = encoded_sequence
        return encoded_sequence

    def _encode_sequence_mult(self, sequences: List[str]) -> np.ndarray:
        """Encode several sequences and stack the results with ``np.array``."""
        seq = [self._encode_sequence(sequence) for sequence in sequences]
        return np.array(seq)

    def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool = False):
        """Predict for row-aligned drug and target embeddings.

        If one side has fewer rows than the other, it is repeated with
        ``np.tile`` until both sides have the same number of rows.

        Args:
            drug_emb: Drug fingerprint(s); a single 1-D fingerprint is treated
                as one row.
            target_emb: Target embedding(s).
            pred_leaf: When true, return the booster's ``pred_leaf`` output
                instead of probabilities.

        Returns:
            Column 1 of ``predict_proba`` (one value per row), or the
            booster's leaf-index prediction when ``pred_leaf`` is true.

        Raises:
            ValueError: If either embedding is ``None`` (encoding failed).
        """
        if drug_emb is None or target_emb is None:
            raise ValueError(
                "Cannot predict: drug or target embedding is missing (encoding failed)."
            )

        # A single fingerprint is 1-D, so its len() is the number of bits, not
        # the number of drugs. Promote to 2-D so row counts are comparable.
        # NOTE(review): assumes zero_shot expects two 2-D, row-aligned arrays
        # -- confirm against BarlowTwins.zero_shot.
        drug_emb = np.atleast_2d(drug_emb)
        target_emb = np.atleast_2d(target_emb)

        n_drugs, n_targets = len(drug_emb), len(target_emb)
        if n_drugs < n_targets:
            drug_emb = np.tile(drug_emb, (n_targets, 1))
        elif n_drugs > n_targets:
            target_emb = np.tile(target_emb, (n_drugs, 1))

        emb = self.bt_model.zero_shot(drug_emb, target_emb)

        if pred_leaf:
            return self.gbm_model.get_booster().predict(DMatrix(emb), pred_leaf=True)
        return self.gbm_model.predict_proba(emb)[:, 1]

    def predict(self, drug: Union[List[str], str], target: str, pred_leaf: bool = False):
        """Predict interaction of one or many drugs with a single target.

        Args:
            drug: One SMILES string or a list of them.
            target: Target amino-acid sequence.
            pred_leaf: Return leaf indices instead of probabilities.

        Raises:
            ValueError: If the drug or the target could not be encoded.
        """
        if isinstance(drug, str):
            drug_emb = self._encode_smiles(drug)
        else:
            drug_emb = self._encode_smiles_mult(drug)
        target_emb = self._encode_sequence(target)
        return self.__predict_pair(drug_emb, target_emb, pred_leaf)

    def get_leaf_weights(self):
        """Return ``get_score(importance_type="weight")`` from the booster.

        NOTE: in XGBoost, "weight" importance is the number of times each
        feature is used to split across all trees. Despite this method's
        name, the values are not per-leaf weights.
        """
        return self.gbm_model.get_booster().get_score(importance_type="weight")

    def _predict_fasta(self, drug: str, fasta_path: str) -> pd.DataFrame:
        """Predict ``drug`` against every FASTA record, sequentially, with no timeout.

        Returns:
            DataFrame with columns drug, target, name, description, prediction.

        Raises:
            ValueError: If the drug or any target could not be encoded.
        """
        drug_emb = self._encode_smiles(drug)
        if drug_emb is None:
            # Fail before spending time on encoding any target.
            raise ValueError(f"Could not encode drug SMILES: {drug}")

        results = []
        # Extract targets from fasta
        for target in tqdm(SeqIO.parse(fasta_path, "fasta"), desc="Predicting targets"):
            target_emb = self._encode_sequence(str(target.seq))
            pred = self.__predict_pair(drug_emb, target_emb)
            results.append(self._result_row(drug, target, pred[0]))
        return pd.DataFrame(results)

    def predict_fasta(self, drug: str, fasta_path: str, timeout_seconds: int = 120) -> pd.DataFrame:
        """Predict ``drug`` against every FASTA record with a per-target timeout.

        Each target is processed in its own worker thread. A target that does
        not finish within ``timeout_seconds``, or that raises, is reported on
        stdout and left out of the result.

        Args:
            drug: SMILES string of the drug.
            fasta_path: Path to a FASTA file of target sequences.
            timeout_seconds: Maximum time to wait for each target.

        Returns:
            DataFrame with columns drug, target, name, description, prediction.

        Raises:
            ValueError: If the drug could not be encoded.
        """
        drug_emb = self._encode_smiles(drug)
        if drug_emb is None:
            raise ValueError(f"Could not encode drug SMILES: {drug}")

        def process_target(record, sink):
            # Errors are contained per target so one bad record cannot abort
            # the whole run.
            try:
                target_emb = self._encode_sequence(str(record.seq))
                pred = self.__predict_pair(drug_emb, target_emb)
            except Exception as e:
                print(f"Skipping target {record.id} due to error: {e}")
                return
            sink.append(self._result_row(drug, record, pred[0]))

        results = []

        # First, count the total number of records for the progress bar
        total_records = sum(1 for _ in SeqIO.parse(fasta_path, "fasta"))

        # Extract targets from fasta with a properly initialized tqdm progress bar
        for target in tqdm(SeqIO.parse(fasta_path, "fasta"), total=total_records, desc="Predicting targets"):
            thread_results = []
            # Python threads cannot be killed: a timed-out worker keeps running
            # in the background. daemon=True stops it from blocking interpreter
            # exit; its late output lands in a list nobody reads any more.
            thread = threading.Thread(
                target=process_target,
                args=(target, thread_results),
                daemon=True,
            )
            thread.start()
            thread.join(timeout_seconds)
            if thread.is_alive():
                print(f"Skipping target {target.id} due to timeout")
                continue
            results.extend(thread_results)

        return pd.DataFrame(results)

    def predict_uniprot(self, drug: Union[List[str], str], uniprot_id: str, pred_leaf: bool = False):
        """Predict against the sequence that ``uniprot2sequence`` returns for ``uniprot_id``.

        Args:
            drug: One SMILES string or a list of them.
            uniprot_id: Identifier passed to ``uniprot2sequence``.
            pred_leaf: Return leaf indices instead of probabilities.
        """
        return self.predict(drug, uniprot2sequence(uniprot_id), pred_leaf)