# Spaces:
# Running
# Running
import gradio as gr
from bmfm_sm.api.dataset_registry import DatasetRegistry
from bmfm_sm.api.smmv_api import SmallMoleculeMultiViewModel
from bmfm_sm.core.data_modules.namespace import LateFusionStrategy
# Example rows in the form [SMILES string, checkpoint name].
# NOTE(review): deploy() in this file passes `examples_new` to the interface,
# so this list appears to be unused here -- confirm before removing it.
examples = [
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "BACE"],
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "BBBP"],
    ["[N+](=O)([O-])[O-]", "CLINTOX"],
    ["OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O", "ESOL"],
    ["CN(C)C(=O)c1ccc(cc1)OC", "FREESOLV"],
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "HIV"],
    ["Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14", "LIPOPHILICITY"],
    ["Cc1cccc(N2CCN(C(=O)C34CC5CC(CC(C5)C3)C4)CC2)c1C", "MUV"],
    ["C([H])([H])([H])[H]", "QM7"],
    ["C(CNCCNCCNCCN)N", "SIDER"],
    ["CCOc1ccc2nc(S(N)(=O)=O)sc2c1", "TOX21"],
    ["CSc1nc(N)nc(-c2cccc(-c3ccc4[nH]ccc4c3)c2)n1", "Pretrained"],
]
# Example rows in the form [SMILES string, checkpoint name]; this is the list
# that deploy() hands to the Gradio interface as its clickable examples.
examples_new = [
    ["O=C1CCCN1", "ESOL"],
    ["CC1=CC(=O)[C@@H](CC1)C(C)C", "FREESOLV"],
    ["Clc1ccc(CN2CCNCC2)cc1C(=O)NCC34CC5CC(CC(C5)C3)C4", "LIPOPHILICITY"],
    ["Clc1ccc(nc1)C(=O)Nc1cc([C@]2([NH+]=C(N)[C@@H]3[C@H](C2)C3)C)c(F)cc1", "BACE"],
    ["OC(C1CCCCN1)c2cc(nc3c2cccc3C(F)(F)F)C(F)(F)F", "BBBP"],
    ["C1CN(CCN1C(=O)CCBr)C(=O)CCBr", "CLINTOX"],
    ["COc1cc2c(c(OC3OC(CO)C(O)C(O)C3O)c1)C(=O)CC(c1ccc(O)cc1)O2", "HIV"],
    ["[H]C1=C([H])C2([H])OC2([H])C([H])([H])C1([H])[H]", "QM7"],
    ["CCCC1=CC(=O)NC(=S)N1", "SIDER"],
    ["CCCC(=O)O[C@]1(C(=O)CCl)[C@@H](C)C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C@@]3(F)C(=O)C[C@@]21C", "TOX21"],
    ["O=C(Nc1cccc2c1N=S=N2)C1CC(=O)N(c2ccccc2)C1", "MUV"],
    ["CSc1nc(N)nc(-c2cccc(-c3ccc4[nH]ccc4c3)c2)n1", "Pretrained"],
]
# Hub path of the pretrained base checkpoint.
base_huggingface_path = "ibm/biomed.sm.mv-te-84m"

# Infix that sits between the base path and the task name in every
# fine-tuned checkpoint path below.
finetuned_huggingface_path = "-MoleculeNet-ligand_scaffold-"

# Maps each task name to the path of its fine-tuned checkpoint:
#   <base_huggingface_path><finetuned_huggingface_path><TASK>-101
# Built from the two constants above so the shared prefix is written once.
# Insertion order is preserved: it is the order checkpoints are loaded in and
# the order they are listed in the UI dropdown.
available_datasets = {
    task: f"{base_huggingface_path}{finetuned_huggingface_path}{task}-101"
    for task in (
        "BACE",
        "BBBP",
        "CLINTOX",
        "ESOL",
        "FREESOLV",
        "HIV",
        "LIPOPHILICITY",
        "MUV",
        "QM7",
        "SIDER",
        "TOX21",
    )
}
class PretrainedSMMVPipeline:
    """Callable that turns a SMILES string into the base model's embedding, as text.

    The checkpoint is loaded once in ``__init__``; each call then asks
    ``SmallMoleculeMultiViewModel.get_embeddings`` for the embedding and
    returns it stringified.
    """

    def __init__(self, pretrained_model_name_or_path: str) -> None:
        """Load the pretrained checkpoint.

        Args:
            pretrained_model_name_or_path: Forwarded as ``model_path`` to
                ``SmallMoleculeMultiViewModel.from_pretrained`` together with
                ``huggingface=True`` and the attentional late-fusion strategy.
        """
        self.model = SmallMoleculeMultiViewModel.from_pretrained(
            LateFusionStrategy.ATTENTIONAL,
            model_path=pretrained_model_name_or_path,
            huggingface=True,
        )

    def __call__(self, smiles: str) -> str:
        """Return the embedding of ``smiles`` rendered as a string.

        Args:
            smiles: SMILES string of the molecule to embed.

        Returns:
            ``str(emb.tolist())`` -- the textual form of the embedding, not a
            float (the previous ``-> float`` annotation was incorrect).
        """
        emb = SmallMoleculeMultiViewModel.get_embeddings(
            smiles=smiles,
            pretrained_model=self.model,
        )
        return str(emb.tolist())
class FinetunedSMMVPipeline:
    """Callable that turns a SMILES string into a fine-tuned model's prediction, as text.

    The dataset info and checkpoint are resolved once in ``__init__``; each
    call then asks ``SmallMoleculeMultiViewModel.get_predictions`` for the
    prediction and returns it stringified.
    """

    def __init__(self, dataset: str, pretrained_model_name_or_path: str) -> None:
        """Look up the dataset info and load the matching fine-tuned checkpoint.

        Args:
            dataset: Name passed to ``DatasetRegistry().get_dataset_info``.
            pretrained_model_name_or_path: Forwarded as ``model_path`` to
                ``SmallMoleculeMultiViewModel.from_finetuned`` together with
                ``inference_mode=True`` and ``huggingface=True``.
        """
        dataset_registry = DatasetRegistry()
        # Kept on the instance because get_predictions needs it on every call.
        self.ds = dataset_registry.get_dataset_info(dataset)
        self.model = SmallMoleculeMultiViewModel.from_finetuned(
            self.ds,
            model_path=pretrained_model_name_or_path,
            inference_mode=True,
            huggingface=True,
        )

    def __call__(self, smiles: str) -> str:
        """Return the model's prediction for ``smiles`` rendered as a string.

        Args:
            smiles: SMILES string of the molecule to score.

        Returns:
            ``str(prediction.tolist())`` -- the textual form of the prediction,
            not a float (the previous ``-> float`` annotation was incorrect).
        """
        prediction = SmallMoleculeMultiViewModel.get_predictions(
            smiles,
            self.ds,
            finetuned_model=self.model,
        )
        return str(prediction.tolist())
def deploy() -> None:
    """Load every checkpoint, build the Gradio interface, and launch it.

    The pretrained pipeline is registered under the key ``"Pretrained"`` and
    one fine-tuned pipeline is registered per entry of ``available_datasets``.
    All checkpoints are loaded up front, before the UI is created. The call
    ends in ``gradio_app.launch()``.
    """
    print(f"Loading checkpoint: Pretrained from {base_huggingface_path}")
    # One callable per checkpoint name; the keys double as the dropdown choices.
    pipelines = {"Pretrained": PretrainedSMMVPipeline(base_huggingface_path)}
    for dataset, huggingface_path in available_datasets.items():
        print(f"Loading checkpoint: {dataset} from {huggingface_path}")
        pipelines[dataset] = FinetunedSMMVPipeline(
            dataset=dataset,
            pretrained_model_name_or_path=huggingface_path,
        )

    def pipeline(smiles: str, dataset: str) -> str:
        """Run ``smiles`` through the pipeline registered for ``dataset``.

        Raises:
            gr.Error: If ``dataset`` is not a loaded checkpoint name (for
                example when nothing is selected in the dropdown), so the UI
                shows a readable message instead of a bare ``KeyError``.
        """
        try:
            selected = pipelines[dataset]
        except KeyError:
            raise gr.Error(
                f"Unknown checkpoint {dataset!r}; choose one of: {', '.join(pipelines)}"
            ) from None
        return selected(smiles)

    smiles_input = gr.Textbox(placeholder="SMILES", label="SMILES")
    datasets_input = gr.Dropdown(
        choices=list(pipelines),
        label="Checkpoint",
    )
    text_output = gr.Textbox(
        max_lines=10,
        label="Prediction",
    )
    gradio_app = gr.Interface(
        pipeline,
        inputs=[smiles_input, datasets_input],
        outputs=text_output,
        examples=examples_new,
        cache_mode="lazy",
        examples_per_page=20,
        title="ibm/biomed.sm.mv-te-84m Property Prediction Tasks",
        description="Predictions for Pretrained show embedding vector of base model. Predictions for datasets show output of model finetuned on that task",
        theme="Zarkel/IBM_Carbon_Theme",
    )
    gradio_app.launch()
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    deploy()