"""Gradio demo for the ibm/biomed.sm.mv-te-84m small-molecule multi-view model.

The "Pretrained" checkpoint returns the base model's embedding vector; every
other checkpoint returns the output of a model finetuned on one task. All
outputs are rendered as strings for display in a Gradio textbox.
"""

from bmfm_sm.api.smmv_api import SmallMoleculeMultiViewModel
from bmfm_sm.core.data_modules.namespace import LateFusionStrategy
from bmfm_sm.api.dataset_registry import DatasetRegistry
import gradio as gr

# Original (SMILES, checkpoint) example set.
# NOTE: not wired into the interface below, which uses `examples_new`.
examples = [
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "BACE"],
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "BBBP"],
    ["[N+](=O)([O-])[O-]", "CLINTOX"],
    ["OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O", "ESOL"],
    ["CN(C)C(=O)c1ccc(cc1)OC", "FREESOLV"],
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "HIV"],
    ["Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14", "LIPOPHILICITY"],
    ["Cc1cccc(N2CCN(C(=O)C34CC5CC(CC(C5)C3)C4)CC2)c1C", "MUV"],
    ["C([H])([H])([H])[H]", "QM7"],
    ["C(CNCCNCCNCCN)N", "SIDER"],
    ["CCOc1ccc2nc(S(N)(=O)=O)sc2c1", "TOX21"],
    ["CSc1nc(N)nc(-c2cccc(-c3ccc4[nH]ccc4c3)c2)n1", "Pretrained"],
]

# (SMILES, checkpoint) pairs shown as clickable examples in the interface.
examples_new = [
    ["O=C1CCCN1", "ESOL"],
    ["CC1=CC(=O)[C@@H](CC1)C(C)C", "FREESOLV"],
    ["Clc1ccc(CN2CCNCC2)cc1C(=O)NCC34CC5CC(CC(C5)C3)C4", "LIPOPHILICITY"],
    ["Clc1ccc(nc1)C(=O)Nc1cc([C@]2([NH+]=C(N)[C@@H]3[C@H](C2)C3)C)c(F)cc1", "BACE"],
    ["OC(C1CCCCN1)c2cc(nc3c2cccc3C(F)(F)F)C(F)(F)F", "BBBP"],
    ["C1CN(CCN1C(=O)CCBr)C(=O)CCBr", "CLINTOX"],
    ["COc1cc2c(c(OC3OC(CO)C(O)C(O)C3O)c1)C(=O)CC(c1ccc(O)cc1)O2", "HIV"],
    ["[H]C1=C([H])C2([H])OC2([H])C([H])([H])C1([H])[H]", "QM7"],
    ["CCCC1=CC(=O)NC(=S)N1", "SIDER"],
    ["CCCC(=O)O[C@]1(C(=O)CCl)[C@@H](C)C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C@@]3(F)C(=O)C[C@@]21C", "TOX21"],
    ["O=C(Nc1cccc2c1N=S=N2)C1CC(=O)N(c2ccccc2)C1", "MUV"],
    ["CSc1nc(N)nc(-c2cccc(-c3ccc4[nH]ccc4c3)c2)n1", "Pretrained"],
]

# Model path of the pretrained base model.
base_huggingface_path = "ibm/biomed.sm.mv-te-84m"
# Infix placed between the base path and the task name in finetuned model paths.
finetuned_huggingface_path = "-MoleculeNet-ligand_scaffold-"
# Trailing tag shared by every finetuned model path.
finetuned_huggingface_suffix = "-101"
# Tasks with a finetuned checkpoint. Each name is also the key handed to
# DatasetRegistry.get_dataset_info, so it must match the registry's naming.
finetuned_tasks = (
    "BACE",
    "BBBP",
    "CLINTOX",
    "ESOL",
    "FREESOLV",
    "HIV",
    "LIPOPHILICITY",
    "MUV",
    "QM7",
    "SIDER",
    "TOX21",
)

# Task name -> finetuned model path, built from the shared pieces above, e.g.
# "BACE" -> "ibm/biomed.sm.mv-te-84m-MoleculeNet-ligand_scaffold-BACE-101".
# Dict insertion order follows `finetuned_tasks`, which fixes the dropdown order.
available_datasets = {
    task: (
        f"{base_huggingface_path}{finetuned_huggingface_path}"
        f"{task}{finetuned_huggingface_suffix}"
    )
    for task in finetuned_tasks
}


class PretrainedSMMVPipeline:
    """Callable wrapper that returns the base model's embedding for a SMILES string."""

    def __init__(self, pretrained_model_name_or_path: str):
        """Load the pretrained model using the ATTENTIONAL late-fusion strategy.

        Args:
            pretrained_model_name_or_path: model path; loaded with
                ``huggingface=True`` (presumably a Hub repo id -- confirm in bmfm_sm).
        """
        self.model = SmallMoleculeMultiViewModel.from_pretrained(
            LateFusionStrategy.ATTENTIONAL,
            model_path=pretrained_model_name_or_path,
            huggingface=True,
        )

    def __call__(self, smiles: str) -> str:
        """Return the embedding of ``smiles`` as the ``str`` of its ``tolist()`` form."""
        emb = SmallMoleculeMultiViewModel.get_embeddings(
            smiles=smiles, pretrained_model=self.model
        )
        return str(emb.tolist())


class FinetunedSMMVPipeline:
    """Callable wrapper that returns a finetuned model's prediction for a SMILES string."""

    def __init__(self, dataset: str, pretrained_model_name_or_path: str):
        """Look up ``dataset`` in the registry and load the matching finetuned model.

        Args:
            dataset: task name understood by ``DatasetRegistry.get_dataset_info``.
            pretrained_model_name_or_path: model path of the finetuned checkpoint;
                loaded with ``huggingface=True``.
        """
        dataset_registry = DatasetRegistry()
        # Dataset metadata is needed both to build the model and to run predictions.
        self.ds = dataset_registry.get_dataset_info(dataset)
        self.model = SmallMoleculeMultiViewModel.from_finetuned(
            self.ds,
            model_path=pretrained_model_name_or_path,
            inference_mode=True,
            huggingface=True,
        )

    def __call__(self, smiles: str) -> str:
        """Return the prediction for ``smiles`` as the ``str`` of its ``tolist()`` form."""
        prediction = SmallMoleculeMultiViewModel.get_predictions(
            smiles, self.ds, finetuned_model=self.model
        )
        return str(prediction.tolist())


def deploy():
    """Load every checkpoint, then build and launch the Gradio interface.

    All models are loaded eagerly before ``launch()``, so start-up is slow but
    requests do not pay any model-loading cost.
    """
    print(f"Loading checkpoint: Pretrained from {base_huggingface_path}")
    pipeline_pretrained = PretrainedSMMVPipeline(base_huggingface_path)

    # Dropdown choice -> callable. "Pretrained" comes first, then the finetuned tasks.
    pipelines_finetuned = {"Pretrained": pipeline_pretrained}
    for dataset, huggingface_path in available_datasets.items():
        print(f"Loading checkpoint: {dataset} from {huggingface_path}")
        pipelines_finetuned[dataset] = FinetunedSMMVPipeline(
            dataset=dataset, pretrained_model_name_or_path=huggingface_path
        )

    def pipeline(smiles: str, dataset: str) -> str:
        """Gradio handler: run the selected checkpoint on ``smiles``.

        Raises:
            gr.Error: if no valid checkpoint is selected or the SMILES is blank.
                Gradio shows the message in the UI instead of a generic error.
        """
        # The dropdown has no default value, so ``dataset`` is None until the user picks one.
        if dataset not in pipelines_finetuned:
            raise gr.Error("Please select a checkpoint.")
        if not smiles or not smiles.strip():
            raise gr.Error("Please enter a SMILES string.")
        return pipelines_finetuned[dataset](smiles.strip())

    smiles_input = gr.Textbox(placeholder="SMILES", label="SMILES")
    datasets_input = gr.Dropdown(
        choices=list(pipelines_finetuned.keys()),
        label="Checkpoint",
    )
    text_output = gr.Textbox(
        max_lines=10,
        label="Prediction",
    )

    gradio_app = gr.Interface(
        pipeline,
        inputs=[smiles_input, datasets_input],
        outputs=text_output,
        examples=examples_new,
        cache_mode="lazy",
        examples_per_page=20,
        title="ibm/biomed.sm.mv-te-84m Property Prediction Tasks",
        description="Predictions for Pretrained show embedding vector of base model. Predictions for datasets show output of model finetuned on that task",
        theme="Zarkel/IBM_Carbon_Theme",
    )
    gradio_app.launch()


if __name__ == "__main__":
    deploy()