"""
This file defines the layout of the app including the header, sidebar, and tabs in the
main content area.
"""

#---------------------------------------------------------------------------------------
# Imports
from typing import Callable

import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
import pandas as pd
import yaml

from src.data_preprocessing.create_descriptors import handle_inputs
from src.app.constants import (summary_text,
                               mhnfs_text,
                               citation_text,
                               few_shot_learning_text,
                               under_the_hood_text,
                               usage_text,
                               data_text,
                               trust_text,
                               example_trustworthy_text,
                               example_nottrustworthy_text)

#---------------------------------------------------------------------------------------
# Global variables
MAX_INPUT_LENGTH = 20

#---------------------------------------------------------------------------------------
# Functions


class LayoutMaker():
    """
    This class includes all the design choices regarding the layout of the app. This
    class can be used in the main file to define header, sidebar, and main content
    area.
    """

    def __init__(self):
        # Raw query / support set inputs exactly as collected from the sidebar widgets
        # (text-area content, a DataFrame read from an uploaded CSV, or None).
        self.inputs = {}
        # Results of handle_inputs for each entry of the inputs dictionary; filled when
        # the predict button is processed and offered as YML download.
        self.inputs_lists = {}

        # Initialize prediction storage
        self.predictions = None

        # Storage for the return values of the buttons
        self.buttons = {}

        # Content
        self.summary_text = summary_text
        self.mhnfs_text = mhnfs_text
        self.citation_text = citation_text
        self.few_shot_learning_text = few_shot_learning_text
        self.under_the_hood_text = under_the_hood_text
        self.usage_text = usage_text
        self.data_text = data_text
        self.trust_text = trust_text
        self.example_trustworthy_text = example_trustworthy_text
        self.example_nottrustworthy_text = example_nottrustworthy_text

        # Example prediction tables shown in the examples tab
        self.df_trustworthy = pd.read_csv("./assets/example_csv/predictions/"
                                          "trustworthy_example.csv")
        self.df_nottrustworthy = pd.read_csv("./assets/example_csv/predictions/"
                                             "nottrustworthy_example.csv")

        # Upper bound for the number of molecules accepted per input
        self.max_input_length = MAX_INPUT_LENGTH

    def make_sidebar(self):
        """
        This function defines the sidebar of the app. It includes the logo, query box,
        support set boxes, and predict buttons.
        It returns the stored inputs (for query and support set) and the buttons which
        allow for user interactions.
        """
        with st.sidebar:
            # Logo
            logo = Image.open("./assets/logo.png")
            st.image(logo)
            st.divider()

            # Query box
            self._make_query_box()
            st.divider()

            # Support set actives box
            self._make_active_support_set_box()
            st.divider()

            # Support set inactives box
            self._make_inactive_support_set_box()
            st.divider()

            # Predict buttons
            self.buttons["predict"] = st.button("Predict...")
            self.buttons["reset"] = st.button("Reset")

        return self.inputs, self.buttons

    def make_header(self):
        """
        This function defines the header of the app. It consists only of a png image
        in which the title and an overview is given.
        """
        with st.container():
            header = Image.open("./assets/header.png")
            st.image(header)

    def make_main_content_area(self,
                               predictor,
                               inputs,
                               buttons,
                               create_prediction_df: Callable,
                               create_molecule_grid_plot: Callable):
        """
        This function defines the main content area, which consists of the four tabs
        "Predictions", "Paper / Cite", "Additional Information", and "Examples".

        predictor, inputs, buttons, create_prediction_df, and create_molecule_grid_plot
        are handed over unchanged to the function which fills the predictions tab.
        """
        results_tab, paper_tab, explanations_tab, examples_tab = st.tabs(
            ["Predictions", "Paper / Cite", "Additional Information", "Examples"]
        )

        # Results tab
        with results_tab:
            self._fill_tab_with_results_content(predictor,
                                                inputs,
                                                buttons,
                                                create_prediction_df,
                                                create_molecule_grid_plot)

        # Paper tab
        with paper_tab:
            self._fill_paper_and_citation_tab()

        # More explanations tab
        with explanations_tab:
            self._fill_more_explanations_tab()

        # Examples tab
        with examples_tab:
            self._fill_examples_tab()

    def _make_input_box(self, *, headline, icon, radio_key, text_label, text_key,
                        default_smiles, upload_label, upload_key):
        """
        This function renders one SMILES input box (headline, radio button to choose
        the input mode, and either a text area or a CSV uploader) and returns the
        collected input.

        The return value is the text-area content for the "Text box" mode. For the
        "CSV upload" mode it is the DataFrame read from the uploaded file, or None as
        long as no file was uploaded.
        """
        st.info(headline, icon=icon)

        with st.container():
            # radio_key=None lets streamlit derive the widget identity itself, which
            # is what the query box relies on.
            input_choice = st.radio(
                "Input your data in SMILES notation via:",
                ["Text box", "CSV upload"],
                key=radio_key,
            )

            if input_choice == "Text box":
                return st.text_area(label=text_label,
                                    label_visibility="hidden",
                                    key=text_key,
                                    value=default_smiles)

            uploaded_file = st.file_uploader(key=upload_key,
                                             label=upload_label,
                                             label_visibility="hidden")
            if uploaded_file is None:
                return None
            return pd.read_csv(uploaded_file)

    def _make_query_box(self):
        """
        This function
        a) defines the query box and
        b) stores the query input in the inputs dictionary
        """
        self.inputs["query"] = self._make_input_box(
            headline=":blue[Molecules to predict:]",
            icon="❓",
            radio_key=None,
            text_label="SMILES input for query molecules",
            text_key="query_textbox",
            default_smiles=(
                "Cc1nc(N2CCN(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
                "N#Cc1c(-c2ccccc2)nc(-c2cccc3c(Br)cccc23)n(CC(=O)O)c1=O, "
                "Cc1nc(N2CCC(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
                "CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
                "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O, "
                "COC(=O)c1c(SC)nc(C2CCCCC2)n(CC(=O)O)c1=O, "
                "Cc1nc(NCc2cccnc2)c(C#N)c(=O)n1CC(=O)O, "
                "CC(C)c1nc(SCc2ccccc2)c(C#N)c(=O)n1CC(=O)O, "
                "N#Cc1c(OCC(=O)O)nc(-c2cccc3ccccc23)nc1-c1ccccc1, "
                "COc1ccc2c(C(=S)N(C)CC(=O)O)cccc2c1C(F)(F)F"
            ),
            upload_label="CSV upload for query mols",
            upload_key="query_csv",
        )

    def _make_active_support_set_box(self):
        """
        This function
        a) defines the active support set box and
        b) stores the active support set input in the inputs dictionary
        """
        self.inputs["support_active"] = self._make_input_box(
            headline=":blue[Known active molecules:]",
            icon="✨",
            radio_key="active_input_choice",
            text_label="SMILES input for active support set molecules",
            text_key="active_textbox",
            default_smiles=(
                "CC(C)(C)c1nc(OCC(=O)O)c(C#N)c(SCC2CCCCC2)n1, "
                "Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O"
            ),
            upload_label="CSV upload for active support set molecules",
            upload_key="support_active_csv",
        )

    def _make_inactive_support_set_box(self):
        """
        This function
        a) defines the inactive support set box and
        b) stores the inactive support set input in the inputs dictionary
        """
        self.inputs["support_inactive"] = self._make_input_box(
            headline=":blue[Known inactive molecules:]",
            icon="✨",
            radio_key="inactive_input_choice",
            text_label="SMILES input for inactive support set molecules",
            text_key="inactive_textbox",
            default_smiles=(
                "CSc1nc(C2CCCCC2)n(CC(=O)O)c(=O)c1S(=O)(=O)c1ccccc1, "
                "CSc1nc(C)nc(OCC(=O)O)c1C#N"
            ),
            upload_label="CSV upload for inactive support set molecules",
            upload_key="support_inactive_csv",
        )

    def _fill_tab_with_results_content(self, predictor, inputs, buttons,
                                       create_prediction_df, create_molecule_grid_plot):
        """
        This function fills the predictions tab. It always shows the summary text.
        If the predict button was hit, the inputs are checked and the predictions are
        shown; if the reset button was hit instead, the session state is cleared.
        """
        with st.container():
            # Info
            st.info(":blue[Summary:]", icon="🚀")
            st.markdown(self.summary_text)

            # Results
            st.info(":blue[Results:]", icon="👨‍💻")

            if buttons['predict']:
                self._show_predictions(predictor,
                                       inputs,
                                       create_prediction_df,
                                       create_molecule_grid_plot)
            elif buttons['reset']:
                self._reset()

    def _show_predictions(self, predictor, inputs, create_prediction_df,
                          create_molecule_grid_plot):
        """
        This function validates the inputs and, if they pass both checks, renders the
        results table, the download buttons, and the molecule grid plot. If a check
        fails, an error message is shown instead.
        """
        # Check 1: Are all inputs provided?
        required_inputs = ('query', 'support_active', 'support_inactive')
        if any(inputs[name] is None for name in required_inputs):
            st.error("You didn't provide all necessary inputs.\n\n"
                     "Please provide all three necessary inputs via the "
                     "sidebar and hit the predict button again.")
            return

        # Check 2: Less than max allowed molecules provided?
        largest_input = 0
        for name, raw_input in inputs.items():
            molecules = handle_inputs(raw_input)
            self.inputs_lists[name] = molecules
            largest_input = max(largest_input, len(molecules))

        if largest_input > self.max_input_length:
            st.error("You provided too many molecules. The number of "
                     "molecules for each input is restricted to "
                     f"{self.max_input_length}.\n\n"
                     "For larger screenings, we suggest to clone the repo "
                     "and to run the model locally.")
            return

        # Progress bar
        progress_bar = st.progress(50,
                                   text=("I'm predicting activities. This might "
                                         "need some minutes. Please wait..."))

        # Results table
        df = self._predict_and_create_results_table(predictor,
                                                    inputs,
                                                    create_prediction_df)

        progress_bar.progress(100, text="Done. Here are the results:")
        st.dataframe(df, use_container_width=True)

        # Only the two inner columns of the four equally wide columns hold buttons
        _, predictions_column, inputs_column, _ = st.columns([1, 1, 1, 1])

        # Provide download button for predictions
        with predictions_column:
            self.buttons["download_results"] = st.download_button(
                "Download predictions as CSV",
                self._convert_df_to_binary(df),
                file_name="predictions.csv",
            )

        # Provide download button for inputs. The YML payload is handed to the button
        # in memory, no file has to be written on the server for this.
        with inputs_column:
            self.buttons["download_inputs"] = st.download_button(
                "Download inputs as YML",
                self._convert_to_yml(self.inputs_lists),
                file_name="inputs.yml",
            )

        st.divider()

        # Results grid
        st.info(":blue[Grid plot of the predicted molecules:]", icon="📊")
        mol_html_grid = create_molecule_grid_plot(df)
        components.html(mol_html_grid, height=1000, scrolling=True)

    def _fill_paper_and_citation_tab(self):
        """
        This function fills the paper tab with the paper text, the overview figure,
        and the citation text.
        """
        st.info(":blue[**Paper: Context-enriched molecule representations improve "
                "few-shot drug discovery**]", icon="📄")
        st.markdown(self.mhnfs_text, unsafe_allow_html=True)
        st.image("./assets/mhnfs_overview.png")

        # Vertical space between the paper section and the citation section
        for _ in range(3):
            st.write("")

        st.info(":blue[**Cite us / BibTex**]", icon="📚")
        st.markdown(self.citation_text)

    def _fill_more_explanations_tab(self):
        """
        This function fills the additional information tab. Each section consists of
        a headline and a text; consecutive sections are separated by vertical space.
        """
        sections = (
            (":blue[**Under the hood**]", "⚙️", self.under_the_hood_text),
            (":blue[**About few-shot learning and the model MHNfs**]", "🎯",
             self.few_shot_learning_text),
            (":blue[**Usage**]", "🎛️", self.usage_text),
            (":blue[**How to provide the data**]", "📀", self.data_text),
            (":blue[**When to trust the predictions**]", "🔍", self.trust_text),
        )

        for position, (headline, icon, text) in enumerate(sections):
            # Vertical space in front of every section but the first one
            if position > 0:
                st.write("")
                st.write("")
            st.info(headline, icon=icon)
            st.markdown(text, unsafe_allow_html=True)

    def _fill_examples_tab(self):
        """
        This function fills the examples tab with one example for trustworthy and one
        example for not trustworthy predictions. Each example consists of a text, the
        example predictions table, and a plot.
        """
        # Trustworthy example
        st.info(":blue[**Example for trustworthy predictions**]", icon="✅")
        st.markdown(self.example_trustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_trustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.96**)")
        prediction_plot_tw = Image.open("./assets/example_csv/predictions/"
                                        "trustworthy_example.png")
        st.image(prediction_plot_tw)

        st.write("")
        st.write("")

        # Not trustworthy example
        st.info(":blue[**Example for not trustworthy predictions**]", icon="⛔️")
        st.markdown(self.example_nottrustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_nottrustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.42**)")
        prediction_plot_ntw = Image.open("./assets/example_csv/predictions/"
                                         "nottrustworthy_example.png")
        st.image(prediction_plot_ntw)

    def _predict_and_create_results_table(self,
                                          predictor,
                                          inputs,
                                          create_prediction_df: Callable):
        """
        This function calls create_prediction_df with the predictor, the query input,
        the active support set input, and the inactive support set input (in this
        order) and returns its result.
        """
        return create_prediction_df(predictor,
                                    inputs['query'],
                                    inputs['support_active'],
                                    inputs['support_inactive'])

    def _reset(self):
        """
        This function removes all entries from the streamlit session state.
        """
        # Iterate over a copy of the keys since the state is modified in the loop
        for key in list(st.session_state.keys()):
            st.session_state.pop(key)

    def _convert_df_to_binary(_self, df):
        """
        This function returns the data frame as utf-8 encoded CSV content (without
        index column).
        """
        return df.to_csv(index=False).encode('utf-8')

    def _convert_to_yml(_self, inputs):
        """
        This function returns the inputs serialized with yaml.dump.
        """
        return yaml.dump(inputs)

        # NOTE(review): The statements below follow the return statement above and
        # are therefore never executed. They look like a leftover help text; kept
        # unchanged for now - TODO confirm that the text is not needed anymore (or
        # move it to where it belongs) and remove it from this function.
        content = """
# Usage
As soon as you have a few active and inactive molecules for your task, you can provide
them here and make predictions for new molecules.

## About few-shot learning and the model MHNfs
**Few-shot learning** is a machine learning sub-field which aims to provide predictive
models for scenarios in which only little data is known/available.

**MHNfs** is a few-shot learning model which is specifically designed for drug
discovery applications. It is built to use the input prompts in a way such that the
provided available knowledge - i.e. the known active and inactive molecules - functions
as context to predict the activity of the new requested molecules. Precisely, the
provided active and inactive molecules are associated with a large set of general
molecules - called context molecules - to enrich the provided information and to remove
spurious correlations arising from the decoration of molecules. This is analogous to a
Large Language Model which would not only use the provided information in the current
prompt as context but would also have access to way more information, e.g. a prompting
history.

## How to provide the data
* Molecules have to be provided in SMILES format.
* You can provide the molecules via the text boxes or via CSV upload.
    - Text box: Replace the pseudo input by directly typing your molecules into the
      text box. Please separate the molecules by comma.
    - CSV upload: Upload a CSV file with the molecules.
        * The CSV file should include a smiles column (both upper and lower case
          "SMILES" are accepted).
        * All other columns will be ignored.

## When to trust the predictions
Just like all other machine learning models, the performance of MHNfs varies and,
generally, the model works well if the task is somehow close to tasks which were used
to train the model. The model performance for very different tasks is unclear and might
be poor.

MHNfs was trained on a the FS-Mol dataset which includes 5120 tasks (Roughly 5000 tasks
were used for training, rest for evaluation). The training tasks are listed here:
https://github.com/microsoft/FS-Mol/tree/main/datasets/targets.
"""
        return content