# mhnfs / src/app/layout.py
# Last commit: ce05748 (verified) by Tschoui - "Changed default molecules in text boxes"
"""
This file defines the layout of the app including the header, sidebar, and tabs in the
main content area.
"""
#---------------------------------------------------------------------------------------
# Imports
from collections.abc import Callable

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
import yaml
from PIL import Image

from src.app.constants import (summary_text,
                               mhnfs_text,
                               citation_text,
                               few_shot_learning_text,
                               under_the_hood_text,
                               usage_text,
                               data_text,
                               trust_text,
                               example_trustworthy_text,
                               example_nottrustworthy_text)
from src.data_preprocessing.create_descriptors import handle_inputs
#---------------------------------------------------------------------------------------
# Global variables
# Maximum number of molecules accepted for each single input (query molecules,
# known actives, known inactives). Larger inputs are rejected in the results tab.
MAX_INPUT_LENGTH: int = 20
#---------------------------------------------------------------------------------------
# Functions
class LayoutMaker:
    """
    This class includes all the design choices regarding the layout of the app. This
    class can be used in the main file to define header, sidebar, and main content area.

    Instance attributes:
        inputs: raw user inputs under the keys "query", "support_active" and
            "support_inactive". Each value is the text-box string, a DataFrame read
            from an uploaded CSV, or None if no CSV has been uploaded yet.
        inputs_lists: the same keys mapped to the parsed inputs returned by
            ``handle_inputs``; filled when the predict button is pressed.
        buttons: widget states ("predict", "reset", and the download buttons).
        max_input_length: per-input molecule limit (``MAX_INPUT_LENGTH``).
    """

    def __init__(self):
        # Storage for the raw query and support-set inputs (filled by the sidebar)
        self.inputs = {}
        # Parsed inputs, filled when a prediction is requested
        self.inputs_lists = {}
        # Prediction storage
        self.predictions = None
        # Storage for buttons
        self.buttons = {}
        # Static text content shown in the tabs
        self.summary_text = summary_text
        self.mhnfs_text = mhnfs_text
        self.citation_text = citation_text
        self.few_shot_learning_text = few_shot_learning_text
        self.under_the_hood_text = under_the_hood_text
        self.usage_text = usage_text
        self.data_text = data_text
        self.trust_text = trust_text
        self.example_trustworthy_text = example_trustworthy_text
        self.example_nottrustworthy_text = example_nottrustworthy_text
        # Example prediction tables shown in the "Examples" tab
        self.df_trustworthy = pd.read_csv(
            "./assets/example_csv/predictions/trustworthy_example.csv")
        self.df_nottrustworthy = pd.read_csv(
            "./assets/example_csv/predictions/nottrustworthy_example.csv")
        self.max_input_length = MAX_INPUT_LENGTH

    # ----------------------------------------------------------------------------
    # Public layout entry points

    def make_sidebar(self):
        """
        Define the sidebar: logo, query box, support-set boxes and the predict/reset
        buttons.

        Returns:
            (inputs, buttons): the stored raw inputs (query and support sets) and
            the button states, which drive the user interaction.
        """
        with st.sidebar:
            # Logo (passing the path lets Streamlit read the file itself)
            st.image("./assets/logo.png")
            st.divider()
            # Query box
            self._make_query_box()
            st.divider()
            # Support set actives box
            self._make_active_support_set_box()
            st.divider()
            # Support set inactives box
            self._make_inactive_support_set_box()
            st.divider()
            # Predict buttons
            self.buttons["predict"] = st.button("Predict...")
            self.buttons["reset"] = st.button("Reset")
        return self.inputs, self.buttons

    def make_header(self):
        """
        Define the header of the app: a single PNG image containing the title and
        an overview.
        """
        with st.container():
            st.image("./assets/header.png")

    def make_main_content_area(self,
                               predictor,
                               inputs,
                               buttons,
                               create_prediction_df: Callable,
                               create_molecule_grid_plot: Callable):
        """
        Render the four main tabs: Predictions, Paper / Cite, Additional
        Information and Examples.

        Args:
            predictor: model object, passed through unchanged to
                ``create_prediction_df``.
            inputs: raw inputs as returned by ``make_sidebar``.
            buttons: button states as returned by ``make_sidebar`` ("predict",
                "reset").
            create_prediction_df: called as ``create_prediction_df(predictor,
                query, support_active, support_inactive)``; its result is shown
                with ``st.dataframe`` and exported with ``to_csv``.
            create_molecule_grid_plot: called with that result; must return HTML
                for ``components.html``.
        """
        tab1, tab2, tab3, tab4 = st.tabs(["Predictions",
                                          "Paper / Cite",
                                          "Additional Information",
                                          "Examples"])
        # Results tab
        with tab1:
            self._fill_tab_with_results_content(predictor,
                                                inputs,
                                                buttons,
                                                create_prediction_df,
                                                create_molecule_grid_plot)
        # Paper tab
        with tab2:
            self._fill_paper_and_citation_tab()
        # More explanations tab
        with tab3:
            self._fill_more_explanations_tab()
        # Examples tab
        with tab4:
            self._fill_examples_tab()

    # ----------------------------------------------------------------------------
    # Sidebar input boxes

    def _make_smiles_input_box(self, *, storage_key, title, icon, radio_key,
                               text_label, text_key, default_text,
                               csv_label, csv_key):
        """
        Shared builder for the three SMILES input boxes.

        Lets the user choose between a text box and a CSV upload, then stores the
        resulting raw input in ``self.inputs[storage_key]``: the text-box string,
        a DataFrame read from the uploaded CSV, or None if nothing was uploaded.
        """
        st.info(title, icon=icon)
        with st.container():
            input_choice = st.radio(
                "Input your data in SMILES notation via:",
                ["Text box", "CSV upload"],
                key=radio_key,
            )
            if input_choice == "Text box":
                raw_input = st.text_area(
                    label=text_label,
                    label_visibility="hidden",
                    key=text_key,
                    value=default_text,
                )
            else:  # "CSV upload" - the only other radio option
                uploaded_file = st.file_uploader(
                    key=csv_key,
                    label=csv_label,
                    label_visibility="hidden",
                )
                raw_input = (pd.read_csv(uploaded_file)
                             if uploaded_file is not None else None)
        # Update storage
        self.inputs[storage_key] = raw_input

    def _make_query_box(self):
        """
        Define the query box and store the query input in ``self.inputs["query"]``.
        """
        self._make_smiles_input_box(
            storage_key="query",
            title=":blue[Molecules to predict:]",
            icon="❓",
            # Explicit key, like the support-set radios, so that _reset (which
            # clears st.session_state) also resets this widget.
            radio_key="query_input_choice",
            text_label="SMILES input for query molecules",
            text_key="query_textbox",
            default_text=("Cc1nc(N2CCN(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
                          "N#Cc1c(-c2ccccc2)nc(-c2cccc3c(Br)cccc23)n(CC(=O)O)c1=O, "
                          "Cc1nc(N2CCC(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
                          "CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
                          "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O, "
                          "COC(=O)c1c(SC)nc(C2CCCCC2)n(CC(=O)O)c1=O, "
                          "Cc1nc(NCc2cccnc2)c(C#N)c(=O)n1CC(=O)O, "
                          "CC(C)c1nc(SCc2ccccc2)c(C#N)c(=O)n1CC(=O)O, "
                          "N#Cc1c(OCC(=O)O)nc(-c2cccc3ccccc23)nc1-c1ccccc1, "
                          "COc1ccc2c(C(=S)N(C)CC(=O)O)cccc2c1C(F)(F)F"),
            csv_label="CSV upload for query mols",
            csv_key="query_csv",
        )

    def _make_active_support_set_box(self):
        """
        Define the active support set box and store its input in
        ``self.inputs["support_active"]``.
        """
        self._make_smiles_input_box(
            storage_key="support_active",
            title=":blue[Known active molecules:]",
            icon="✨",
            radio_key="active_input_choice",
            text_label="SMILES input for active support set molecules",
            text_key="active_textbox",
            default_text=("CC(C)(C)c1nc(OCC(=O)O)c(C#N)c(SCC2CCCCC2)n1, "
                          "Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O"),
            csv_label="CSV upload for active support set molecules",
            csv_key="support_active_csv",
        )

    def _make_inactive_support_set_box(self):
        """
        Define the inactive support set box and store its input in
        ``self.inputs["support_inactive"]``.
        """
        self._make_smiles_input_box(
            storage_key="support_inactive",
            title=":blue[Known inactive molecules:]",
            icon="✨",
            radio_key="inactive_input_choice",
            text_label="SMILES input for inactive support set molecules",
            text_key="inactive_textbox",
            default_text=("CSc1nc(C2CCCCC2)n(CC(=O)O)c(=O)c1S(=O)(=O)c1ccccc1, "
                          "CSc1nc(C)nc(OCC(=O)O)c1C#N"),
            csv_label="CSV upload for inactive support set molecules",
            csv_key="support_inactive_csv",
        )

    # ----------------------------------------------------------------------------
    # Results tab

    def _fill_tab_with_results_content(self, predictor, inputs, buttons,
                                       create_prediction_df,
                                       create_molecule_grid_plot):
        """
        Fill the "Predictions" tab: a summary, then either the prediction results
        (when "predict" was pressed) or a state reset (when "reset" was pressed).
        """
        with st.container():
            # Info
            st.info(":blue[Summary:]", icon="🚀")
            st.markdown(self.summary_text)
            # Results
            st.info(":blue[Results:]", icon="👨‍💻")
            if buttons['predict']:
                self._handle_predict_request(predictor, inputs,
                                             create_prediction_df,
                                             create_molecule_grid_plot)
            elif buttons['reset']:
                self._reset()

    @staticmethod
    def _is_missing(value) -> bool:
        """
        Return True if an input slot holds no usable data.

        A slot is missing when it is None (no CSV uploaded), a string that is
        empty or whitespace-only (cleared text box), or an empty DataFrame.
        """
        if value is None:
            return True
        if isinstance(value, str):
            return not value.strip()
        if isinstance(value, pd.DataFrame):
            return value.empty
        return False

    def _parse_inputs(self, inputs):
        """
        Parse every raw input with ``handle_inputs``, store the results in
        ``self.inputs_lists`` and return the largest number of parsed entries.
        """
        longest = 0
        for key, raw_input in inputs.items():
            parsed = handle_inputs(raw_input)
            self.inputs_lists[key] = parsed
            longest = max(longest, len(parsed))
        return longest

    def _handle_predict_request(self, predictor, inputs, create_prediction_df,
                                create_molecule_grid_plot):
        """Validate the inputs and, if they are valid, render the predictions."""
        # Check 1: Are all inputs provided?
        required_keys = ('query', 'support_active', 'support_inactive')
        if any(self._is_missing(inputs[key]) for key in required_keys):
            st.error("You didn't provide all necessary inputs.\n\n"
                     "Please provide all three necessary inputs via the "
                     "sidebar and hit the predict button again.")
            return
        # Check 2: Less than max allowed molecules provided?
        if self._parse_inputs(inputs) > self.max_input_length:
            st.error("You provided too many molecules. The number of "
                     "molecules for each input is restricted to "
                     f"{self.max_input_length}.\n\n"
                     "For larger screenings, we suggest to clone the repo "
                     "and to run the model locally.")
            return
        self._show_predictions(predictor, inputs, create_prediction_df,
                               create_molecule_grid_plot)

    def _show_predictions(self, predictor, inputs, create_prediction_df,
                          create_molecule_grid_plot):
        """
        Run the prediction and show the results table, the download buttons and
        the molecule grid plot.
        """
        # Progress bar
        progress_bar = st.progress(
            50,
            text=("I'm predicting activities. This might "
                  "need some minutes. Please wait..."),
        )
        # Results table
        df = self._predict_and_create_results_table(predictor,
                                                    inputs,
                                                    create_prediction_df)
        progress_bar.progress(100, text="Done. Here are the results:")
        st.dataframe(df, use_container_width=True)
        # Four equal columns; only the middle two hold buttons (centering)
        _, col2, col3, _ = st.columns([1, 1, 1, 1])
        # Provide download button for predictions
        with col2:
            self.buttons["download_results"] = st.download_button(
                "Download predictions as CSV",
                self._convert_df_to_binary(df),
                file_name="predictions.csv",
            )
        # Provide download button for inputs (built in memory; nothing is
        # written to the server's working directory)
        with col3:
            self.buttons["download_inputs"] = st.download_button(
                "Download inputs as YML",
                self._convert_to_yml(self.inputs_lists),
                file_name="inputs.yml",
            )
        st.divider()
        # Results grid
        st.info(":blue[Grid plot of the predicted molecules:]",
                icon="📊")
        mol_html_grid = create_molecule_grid_plot(df)
        components.html(mol_html_grid, height=1000, scrolling=True)

    # ----------------------------------------------------------------------------
    # Information tabs

    @staticmethod
    def _add_vertical_space(lines):
        """Emit ``lines`` empty ``st.write`` calls to separate page sections."""
        for _ in range(lines):
            st.write("")

    def _fill_paper_and_citation_tab(self):
        """Fill the "Paper / Cite" tab with the paper summary and the BibTeX entry."""
        st.info(":blue[**Paper: Context-enriched molecule representations improve "
                "few-shot drug discovery**]", icon="📄")
        st.markdown(self.mhnfs_text, unsafe_allow_html=True)
        st.image("./assets/mhnfs_overview.png")
        self._add_vertical_space(3)
        st.info(":blue[**Cite us / BibTex**]", icon="📚")
        st.markdown(self.citation_text)

    def _fill_more_explanations_tab(self):
        """Fill the "Additional Information" tab with its explanatory sections."""
        sections = [
            (":blue[**Under the hood**]", "⚙️", self.under_the_hood_text),
            (":blue[**About few-shot learning and the model MHNfs**]", "🎯",
             self.few_shot_learning_text),
            (":blue[**Usage**]", "🎛️", self.usage_text),
            (":blue[**How to provide the data**]", "📀", self.data_text),
            (":blue[**When to trust the predictions**]", "🔍", self.trust_text),
        ]
        for index, (title, icon, text) in enumerate(sections):
            if index:
                self._add_vertical_space(2)
            st.info(title, icon=icon)
            st.markdown(text, unsafe_allow_html=True)

    def _show_example_section(self, title, icon, text, df, plot_caption,
                              plot_path):
        """Render one example: explanation, predictions table and prediction plot."""
        st.info(title, icon=icon)
        st.markdown(text, unsafe_allow_html=True)
        st.dataframe(df, use_container_width=True)
        st.markdown(plot_caption)
        st.image(plot_path)

    def _fill_examples_tab(self):
        """Fill the "Examples" tab with a trustworthy and a not-trustworthy example."""
        self._show_example_section(
            ":blue[**Example for trustworthy predictions**]",
            "✅",
            self.example_trustworthy_text,
            self.df_trustworthy,
            "**Plot: Predictions for active and inactive molecules (model AUC="
            "0.96**)",
            "./assets/example_csv/predictions/trustworthy_example.png",
        )
        self._add_vertical_space(2)
        self._show_example_section(
            ":blue[**Example for not trustworthy predictions**]",
            "⛔️",
            self.example_nottrustworthy_text,
            self.df_nottrustworthy,
            "**Plot: Predictions for active and inactive molecules (model AUC="
            "0.42**)",
            "./assets/example_csv/predictions/nottrustworthy_example.png",
        )

    # ----------------------------------------------------------------------------
    # Helpers

    def _predict_and_create_results_table(self,
                                          predictor,
                                          inputs,
                                          create_prediction_df: Callable):
        """Call ``create_prediction_df`` with the predictor and the three raw inputs."""
        df = create_prediction_df(predictor,
                                  inputs['query'],
                                  inputs['support_active'],
                                  inputs['support_inactive'])
        return df

    def _reset(self):
        """Remove every key from ``st.session_state``."""
        for key in list(st.session_state.keys()):
            st.session_state.pop(key)

    def _convert_df_to_binary(_self, df):
        """Serialize ``df`` to UTF-8 encoded CSV bytes without the index column."""
        return df.to_csv(index=False).encode('utf-8')

    def _convert_to_yml(_self, inputs):
        """Serialize ``inputs`` to a YAML string with ``yaml.dump``."""
        return yaml.dump(inputs)