|
""" |
|
This module provides a simple predict function for the MHNfs model. |
|
It loads the model from the provided checkpoint, creates the necessary helper
inputs, and makes predictions for a list of query molecules.
|
""" |
|
|
|
|
|
|
|
from pathlib import Path

import pandas as pd
import pytorch_lightning as pl
import streamlit as st
import torch

from src.data_preprocessing.create_model_inputs import (
    create_query_input,
    create_support_set_input,
)
from src.mhnfs.model import MHNfs
|
|
|
|
|
|
|
|
|
class ActivityPredictor:
    """Wrap a pretrained MHNfs model for few-shot activity prediction.

    The model is loaded from ``assets/mhnfs_data/mhnfs_checkpoint.ckpt``,
    resolved relative to the parent of the directory containing this file.
    Loading goes through a ``st.cache_resource``-decorated function, so
    Streamlit reuses the loaded model instead of reloading it each time.
    """

    def __init__(self):
        @st.cache_resource
        def load_model():
            # Seed all RNGs via Lightning so model setup is reproducible.
            pl.seed_everything(1234)

            # Checkpoint lives under the parent of this file's directory.
            # pathlib copes with relative ``__file__`` values and non-"/"
            # separators, unlike a plain ``rsplit("/", 2)``.
            project_root = Path(__file__).resolve().parents[1]
            checkpoint_path = (
                project_root / "assets" / "mhnfs_data" / "mhnfs_checkpoint.ckpt"
            )

            model = MHNfs.load_from_checkpoint(str(checkpoint_path))
            model._update_context_set_embedding()
            model.eval()
            return model

        self.model = load_model()
        # Query molecules of the most recent successful ``predict`` call;
        # read back by ``_return_query_mols_as_list``.
        self.query_molecules = None

    def predict(self, query_smiles, support_activces_smiles,
                support_inactives_smiles):
        """Predict activity scores for query molecules given a support set.

        Args:
            query_smiles: Query molecules, passed to ``create_query_input``.
                Presumably one of the forms ``_return_query_mols_as_list``
                understands (list of SMILES, comma-separated string, or a
                DataFrame with a ``smiles`` column) — confirm with the helper.
            support_activces_smiles: Known active molecules for the task.
                The parameter keeps its original spelling so existing
                keyword callers keep working.
            support_inactives_smiles: Known inactive molecules for the task.

        Returns:
            numpy.ndarray: The model output flattened to 1-D (presumably one
            score per query molecule).
        """
        query_input = create_query_input(query_smiles)
        support_actives_input, support_actives_size = create_support_set_input(
            support_activces_smiles
        )
        support_inactives_input, support_inactives_size = (
            create_support_set_input(support_inactives_smiles)
        )

        # Inference only: don't build an autograd graph (saves memory/time).
        with torch.no_grad():
            predictions = self.model(
                query_input,
                support_actives_input,
                support_inactives_input,
                support_actives_size,
                support_inactives_size,
            )

        # Remember the queries only after prediction succeeded, so the stored
        # molecules always correspond to the scores returned last.
        self.query_molecules = query_smiles

        # ``.cpu()`` is a no-op for CPU tensors and required before
        # ``.numpy()`` if the output lives on an accelerator.
        return predictions.detach().cpu().numpy().flatten()

    def _return_query_mols_as_list(self):
        """Return the query molecules of the last ``predict`` call as a list.

        Supported stored forms:
            * ``list`` – returned as-is (same object);
            * ``str`` – split on commas, surrounding whitespace stripped;
            * ``pandas.DataFrame`` – values of its ``smiles`` column.

        Raises:
            ValueError: If no query molecules have been stored yet.
            TypeError: If the stored queries have an unsupported type.
        """
        queries = self.query_molecules
        if queries is None:
            raise ValueError("No query molecules have been stored yet. "
                             "Run predict-function first.")
        if isinstance(queries, list):
            return queries
        if isinstance(queries, str):
            return [smiles.strip() for smiles in queries.split(",")]
        if isinstance(queries, pd.DataFrame):
            return queries.smiles.tolist()
        raise TypeError("Type of query molecules not recognized. "
                        "Please check input type.")
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load the model, then score four copies of cyclohexane
    # against a tiny support set where the same molecule is listed as both
    # active and inactive, and print the raw predictions.
    predictor = ActivityPredictor()

    cyclohexane = "C1CCCCC1"
    query_smiles = [cyclohexane] * 4
    support_actives_smiles = [cyclohexane] * 2
    support_inactives_smiles = [cyclohexane] * 2

    predictions = predictor.predict(
        query_smiles,
        support_actives_smiles,
        support_inactives_smiles,
    )
    print(predictions)