|
""" |
|
This file defines the layout of the app including the header, sidebar, and tabs in the |
|
main content area. |
|
""" |
|
|
|
|
|
|
|
import streamlit as st |
|
import streamlit.components.v1 as components |
|
from PIL import Image |
|
import pandas as pd |
|
import yaml |
|
|
|
from src.data_preprocessing.create_descriptors import handle_inputs |
|
from src.app.constants import (summary_text, |
|
mhnfs_text, |
|
citation_text, |
|
few_shot_learning_text, |
|
under_the_hood_text, |
|
usage_text, |
|
data_text, |
|
trust_text, |
|
example_trustworthy_text, |
|
example_nottrustworthy_text) |
|
|
|
|
|
# Upper bound on the number of molecules accepted per input (query, active support
# set, inactive support set). The results tab shows an error instead of predicting
# when any parsed input is longer than this.
MAX_INPUT_LENGTH = 20
|
|
|
|
|
|
|
|
|
class LayoutMaker():
    """
    This class includes all the design choices regarding the layout of the app. This
    class can be used in the main file to define header, sidebar, and main content area.

    NOTE(review): most ``icon=`` literals in this class (e.g. "β", "π") look like
    mis-encoded emoji. They are kept as found because the intended characters cannot
    be recovered from this file alone - confirm them against the original source.
    """

    def __init__(self):
        """
        Set up the state shared by all layout methods.

        Besides creating empty containers for inputs and buttons, this reads the two
        example prediction tables from ``./assets/example_csv/predictions/`` (paths
        are relative to the working directory), so any error raised by
        ``pandas.read_csv`` for a missing file propagates from here.
        """
        # Raw user inputs per box (text, DataFrame or None) and their parsed versions.
        self.inputs = {}
        self.inputs_lists = {}

        # Placeholder attribute; it is not assigned anywhere else in this class.
        self.predictions = None

        # Return values of the Streamlit buttons, keyed by purpose.
        self.buttons = {}

        # Static texts shown in the different tabs.
        self.summary_text = summary_text
        self.mhnfs_text = mhnfs_text
        self.citation_text = citation_text
        self.few_shot_learning_text = few_shot_learning_text
        self.under_the_hood_text = under_the_hood_text
        self.usage_text = usage_text
        self.data_text = data_text
        self.trust_text = trust_text
        self.example_trustworthy_text = example_trustworthy_text
        self.example_nottrustworthy_text = example_nottrustworthy_text

        # Example tables displayed in the "Examples" tab.
        self.df_trustworthy = pd.read_csv("./assets/example_csv/predictions/"
                                          "trustworthy_example.csv")
        self.df_nottrustworthy = pd.read_csv("./assets/example_csv/predictions/"
                                             "nottrustworthy_example.csv")

        self.max_input_length = MAX_INPUT_LENGTH

    def make_sidebar(self):
        """
        This function defines the sidebar of the app. It includes the logo, query box,
        support set boxes, and predict buttons.

        Returns:
            tuple: ``(self.inputs, self.buttons)`` - the stored inputs (for query and
            support set) and the buttons which allow for user interactions.
        """
        with st.sidebar:
            logo = Image.open("./assets/logo.png")
            st.image(logo)
            st.divider()

            self._make_query_box()
            st.divider()

            self._make_active_support_set_box()
            st.divider()

            self._make_inactive_support_set_box()
            st.divider()

            self.buttons["predict"] = st.button("Predict...")
            self.buttons["reset"] = st.button("Reset")

        return self.inputs, self.buttons

    def make_header(self):
        """
        This function defines the header of the app. It consists only of a png image
        in which the title and an overview is given.
        """
        with st.container():
            header = Image.open("./assets/header.png")
            st.image(header)

    def make_main_content_area(self,
                               predictor,
                               inputs,
                               buttons,
                               create_prediction_df: callable,
                               create_molecule_grid_plot: callable):
        """
        Create the four tabs of the main content area and fill each of them.

        Args:
            predictor: Passed through unchanged to ``create_prediction_df``.
            inputs: Dict with the keys "query", "support_active" and
                "support_inactive" (as returned by ``make_sidebar``).
            buttons: Dict with the keys "predict" and "reset".
            create_prediction_df: Called as ``create_prediction_df(predictor, query,
                support_active, support_inactive)``; its result is shown as a table.
            create_molecule_grid_plot: Called with that result; its return value is
                embedded as HTML.
        """
        tab1, tab2, tab3, tab4 = st.tabs(["Predictions",
                                          "Paper / Cite",
                                          "Additional Information",
                                          "Examples"])

        with tab1:
            self._fill_tab_with_results_content(predictor,
                                                inputs,
                                                buttons,
                                                create_prediction_df,
                                                create_molecule_grid_plot)

        with tab2:
            self._fill_paper_and_citation_tab()

        with tab3:
            self._fill_more_explanations_tab()

        with tab4:
            self._fill_examples_tab()

    def _make_smiles_input_box(self, *, input_key, title, icon, radio_key,
                               textbox_label, textbox_key, default_smiles,
                               uploader_label, uploader_key):
        """
        Render one SMILES input box and store the user's input.

        The box offers a radio choice between a pre-filled text area and a CSV upload.
        The result is stored in ``self.inputs[input_key]`` as the text area content,
        the uploaded CSV read into a DataFrame, or ``None`` when "CSV upload" is
        selected but no file has been uploaded yet.

        Args:
            input_key: Key under which the input is stored in ``self.inputs``.
            title: Text of the info banner above the box.
            icon: Icon of the info banner.
            radio_key: Streamlit key of the radio widget (``None`` for no key).
            textbox_label: Hidden label of the text area.
            textbox_key: Streamlit key of the text area.
            default_smiles: Initial content of the text area.
            uploader_label: Hidden label of the file uploader.
            uploader_key: Streamlit key of the file uploader.
        """
        st.info(title, icon=icon)

        with st.container():
            input_choice = st.radio(
                "Input your data in SMILES notation via:",
                ["Text box", "CSV upload"],
                key=radio_key,
            )

            if input_choice == "Text box":
                user_input = st.text_area(
                    label=textbox_label,
                    label_visibility="hidden",
                    key=textbox_key,
                    value=default_smiles,
                )
            else:
                # The radio only offers two options, so this is "CSV upload".
                uploaded_file = st.file_uploader(
                    key=uploader_key,
                    label=uploader_label,
                    label_visibility="hidden",
                )
                user_input = None
                if uploaded_file is not None:
                    user_input = pd.read_csv(uploaded_file)

        self.inputs[input_key] = user_input

    def _make_query_box(self):
        """
        This function
        a) defines the query box and
        b) stores the query input in the inputs dictionary
        """
        self._make_smiles_input_box(
            input_key="query",
            title=":blue[Molecules to predict:]",
            icon="β",
            radio_key=None,  # the query radio is the only one without an explicit key
            textbox_label="SMILES input for query molecules",
            textbox_key="query_textbox",
            default_smiles=(
                "Cc1nc(N2CCN(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
                "N#Cc1c(-c2ccccc2)nc(-c2cccc3c(Br)cccc23)n(CC(=O)O)c1=O, "
                "Cc1nc(N2CCC(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
                "CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
                "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O, "
                "COC(=O)c1c(SC)nc(C2CCCCC2)n(CC(=O)O)c1=O, "
                "Cc1nc(NCc2cccnc2)c(C#N)c(=O)n1CC(=O)O, "
                "CC(C)c1nc(SCc2ccccc2)c(C#N)c(=O)n1CC(=O)O, "
                "N#Cc1c(OCC(=O)O)nc(-c2cccc3ccccc23)nc1-c1ccccc1, "
                "COc1ccc2c(C(=S)N(C)CC(=O)O)cccc2c1C(F)(F)F"
            ),
            uploader_label="CSV upload for query mols",
            uploader_key="query_csv",
        )

    def _make_active_support_set_box(self):
        """
        This function
        a) defines the active support set box and
        b) stores the active support set input in the inputs dictionary
        """
        self._make_smiles_input_box(
            input_key="support_active",
            title=":blue[Known active molecules:]",
            icon="β¨",
            radio_key="active_input_choice",
            textbox_label="SMILES input for active support set molecules",
            textbox_key="active_textbox",
            default_smiles=(
                "CC(C)(C)c1nc(OCC(=O)O)c(C#N)c(SCC2CCCCC2)n1, "
                "Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O"
            ),
            uploader_label="CSV upload for active support set molecules",
            uploader_key="support_active_csv",
        )

    def _make_inactive_support_set_box(self):
        """
        This function
        a) defines the inactive support set box and
        b) stores the inactive support set input in the inputs dictionary
        """
        self._make_smiles_input_box(
            input_key="support_inactive",
            title=":blue[Known inactive molecules:]",
            icon="β¨",
            radio_key="inactive_input_choice",
            textbox_label="SMILES input for inactive support set molecules",
            textbox_key="inactive_textbox",
            default_smiles=(
                "CSc1nc(C2CCCCC2)n(CC(=O)O)c(=O)c1S(=O)(=O)c1ccccc1, "
                "CSc1nc(C)nc(OCC(=O)O)c1C#N"
            ),
            uploader_label="CSV upload for inactive support set molecules",
            uploader_key="support_inactive_csv",
        )

    def _fill_tab_with_results_content(self, predictor, inputs, buttons,
                                       create_prediction_df, create_molecule_grid_plot):
        """
        Fill the "Predictions" tab: summary text plus, on demand, the results.

        When the "predict" button was pressed the inputs are validated and, if valid,
        predictions are computed and displayed. Otherwise, when the "reset" button
        was pressed, the Streamlit session state is cleared.

        Args:
            predictor: Passed through to ``create_prediction_df``.
            inputs: Dict with the keys "query", "support_active", "support_inactive".
            buttons: Dict with the keys "predict" and "reset".
            create_prediction_df: Callable producing the results table.
            create_molecule_grid_plot: Callable turning the results table into HTML.
        """
        with st.container():
            st.info(":blue[Summary:]", icon="π")
            st.markdown(self.summary_text)

            st.info(":blue[Results:]", icon="π¨βπ»")

            if buttons['predict']:
                self._run_prediction(predictor,
                                     inputs,
                                     create_prediction_df,
                                     create_molecule_grid_plot)
            elif buttons['reset']:
                self._reset()

    def _run_prediction(self, predictor, inputs, create_prediction_df,
                        create_molecule_grid_plot):
        """
        Validate the inputs, run the prediction and render all result widgets.

        Shows an error message and stops early when an input is missing or when any
        parsed input is longer than ``self.max_input_length``.
        """
        if (inputs['query'] is None or
                inputs['support_active'] is None or
                inputs['support_inactive'] is None):
            st.error("You didn't provide all necessary inputs.\n\n"
                     "Please provide all three necessary inputs via the "
                     "sidebar and hit the predict button again.")
            return

        # Parse every input and remember the parsed version for the YML download.
        # NOTE(review): handle_inputs is defined elsewhere; it presumably returns a
        # list of SMILES strings - only len() of its result is relied on here.
        largest_input = 0
        for key, raw_input in inputs.items():
            input_list = handle_inputs(raw_input)
            self.inputs_lists[key] = input_list
            largest_input = max(largest_input, len(input_list))

        if largest_input > self.max_input_length:
            st.error("You provided too many molecules. The number of "
                     "molecules for each input is restricted to "
                     f"{self.max_input_length}.\n\n"
                     "For larger screenings, we suggest to clone the repo "
                     "and to run the model locally.")
            return

        progress_bar_text = ("I'm predicting activities. This might "
                             "need some minutes. Please wait...")
        progress_bar = st.progress(50, text=progress_bar_text)

        df = self._predict_and_create_results_table(predictor,
                                                    inputs,
                                                    create_prediction_df)

        progress_bar.progress(100, text="Done. Here are the results:")
        st.dataframe(df, use_container_width=True)

        # Four equal columns; only the two middle ones hold the download buttons.
        _, col2, col3, _ = st.columns([1, 1, 1, 1])

        with col2:
            self.buttons["download_results"] = st.download_button(
                "Download predictions as CSV",
                self._convert_df_to_binary(df),
                file_name="predictions.csv",
            )

        with col3:
            # The YML content is built in memory; no file on disk is needed for the
            # download, so nothing is opened or written here.
            self.buttons["download_inputs"] = st.download_button(
                "Download inputs as YML",
                self._convert_to_yml(self.inputs_lists),
                file_name="inputs.yml",
            )

        st.divider()

        st.info(":blue[Grid plot of the predicted molecules:]",
                icon="π")
        mol_html_grid = create_molecule_grid_plot(df)
        components.html(mol_html_grid, height=1000, scrolling=True)

    def _fill_paper_and_citation_tab(self):
        """
        Fill the "Paper / Cite" tab with the paper text, the overview image and the
        citation text.
        """
        st.info(":blue[**Paper: Context-enriched molecule representations improve "
                "few-shot drug discovery**]", icon="π")
        st.markdown(self.mhnfs_text, unsafe_allow_html=True)
        st.image("./assets/mhnfs_overview.png")

        # Empty writes act as vertical spacing before the citation section.
        st.write("")
        st.write("")
        st.write("")

        st.info(":blue[**Cite us / BibTex**]", icon="π")
        st.markdown(self.citation_text)

    def _fill_more_explanations_tab(self):
        """
        Fill the "Additional Information" tab with its five explanation sections.

        Each section is an info banner followed by its text; consecutive sections are
        separated by two empty writes.
        """
        sections = [
            (":blue[**Under the hood**]", "βοΈ", self.under_the_hood_text),
            (":blue[**About few-shot learning and the model MHNfs**]", "π―",
             self.few_shot_learning_text),
            (":blue[**Usage**]", "ποΈ", self.usage_text),
            (":blue[**How to provide the data**]", "π", self.data_text),
            (":blue[**When to trust the predictions**]", "π", self.trust_text),
        ]

        for position, (title, icon, text) in enumerate(sections):
            if position > 0:
                # Vertical spacing between two sections (none after the last one).
                st.write("")
                st.write("")
            st.info(title, icon=icon)
            st.markdown(text, unsafe_allow_html=True)

    def _fill_examples_tab(self):
        """
        Fill the "Examples" tab with one trustworthy and one not trustworthy example.

        Each example consists of an explanation text, the example predictions table
        loaded in ``__init__`` and a plot image read from the assets folder.
        """
        st.info(":blue[**Example for trustworthy predictions**]", icon="β")
        st.markdown(self.example_trustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_trustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.96**)")
        prediction_plot_tw = Image.open("./assets/example_csv/predictions/"
                                        "trustworthy_example.png")
        st.image(prediction_plot_tw)
        st.write("")
        st.write("")

        st.info(":blue[**Example for not trustworthy predictions**]", icon="βοΈ")
        st.markdown(self.example_nottrustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_nottrustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.42**)")
        prediction_plot_ntw = Image.open("./assets/example_csv/predictions/"
                                         "nottrustworthy_example.png")
        st.image(prediction_plot_ntw)

    def _predict_and_create_results_table(self,
                                          predictor,
                                          inputs,
                                          create_prediction_df: callable):
        """
        Call ``create_prediction_df`` with the predictor and the three raw inputs.

        Args:
            predictor: First argument handed to ``create_prediction_df``.
            inputs: Dict providing "query", "support_active" and "support_inactive".
            create_prediction_df: Callable taking ``(predictor, query,
                support_active, support_inactive)``.

        Returns:
            Whatever ``create_prediction_df`` returns.
        """
        return create_prediction_df(predictor,
                                    inputs['query'],
                                    inputs['support_active'],
                                    inputs['support_inactive'])

    def _reset(self):
        """Remove every entry from the Streamlit session state."""
        # Snapshot the keys first so the state is not mutated while iterating it.
        for key in list(st.session_state.keys()):
            st.session_state.pop(key)

    # NOTE(review): the ``_self`` parameter name of the two converters below matches
    # Streamlit's convention for excluding an argument from cache hashing, but no
    # caching decorator is visible on them - confirm whether one is intended.
    def _convert_df_to_binary(_self, df):
        """Return ``df`` as UTF-8 encoded CSV bytes without the index column."""
        return df.to_csv(index=False).encode('utf-8')

    def _convert_to_yml(_self, inputs):
        """Return ``inputs`` serialized with ``yaml.dump``."""
        return yaml.dump(inputs)
|
content = """ |
|
# Usage |
|
As soon as you have a few active and inactive molecules for your task, you can |
|
provide them here and make predictions for new molecules. |
|
|
|
## About few-shot learning and the model MHNfs |
|
**Few-shot learning** is a machine learning sub-field which aims to provide |
|
predictive models for scenarios in which only little data is known/available. |
|
|
|
**MHNfs** is a few-shot learning model which is specifically designed for drug |
|
discovery applications. It is built to use the input prompts in a way such that |
|
the provided available knowledge - i.e. the known active and inactive molecules - |
|
functions as context to predict the activity of the new requested molecules. |
|
Precisely, the provided active and inactive molecules are associated with a |
|
large set of general molecules - called context molecules - to enrich the |
|
provided information and to remove spurious correlations arising from the |
|
decoration of molecules. This is analogous to a Large Language Model which would |
|
not only use the provided information in the current prompt as context but would |
|
also have access to way more information, e.g. a prompting history. |
|
|
|
## How to provide the data |
|
* Molecules have to be provided in SMILES format. |
|
* You can provide the molecules via the text boxes or via CSV upload. |
|
- Text box: Replace the pseudo input by directly typing your molecules into |
|
the text box. Please separate the molecules by comma. |
|
- CSV upload: Upload a CSV file with the molecules. |
|
* The CSV file should include a smiles column (both upper and lower |
|
case "SMILES" are accepted). |
|
* All other columns will be ignored. |
|
|
|
## When to trust the predictions |
|
Just like all other machine learning models, the performance of MHNfs varies |
|
and, generally, the model works well if the task is somehow close to tasks which |
|
were used to train the model. The model performance for very different tasks is |
|
unclear and might be poor. |
|
|
|
MHNfs was trained on a the FS-Mol dataset which includes 5120 tasks (Roughly |
|
5000 tasks were used for training, rest for evaluation). The training tasks are |
|
listed here: https://github.com/microsoft/FS-Mol/tree/main/datasets/targets. |
|
""" |
|
return content |