"""
This file defines the layout of the app including the header, sidebar, and tabs in the
main content area. 
"""

#---------------------------------------------------------------------------------------
# Imports
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
import pandas as pd
import yaml

from src.data_preprocessing.create_descriptors import handle_inputs
from src.app.constants import (summary_text,
                               mhnfs_text,
                               citation_text,
                               few_shot_learning_text,
                               under_the_hood_text,
                               usage_text,
                               data_text,
                               trust_text,
                               example_trustworthy_text,
                               example_nottrustworthy_text)
#---------------------------------------------------------------------------------------
# Global variables
MAX_INPUT_LENGTH = 20

#---------------------------------------------------------------------------------------
# Functions
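
# Illustrative sketch (hypothetical helper, not called by LayoutMaker): the size check in
# `_fill_tab_with_results_content` only tracks the length of the longest input. A variant
# that reports *which* inputs exceed the limit, so the error message could name them,
# might look like this. The name `find_oversized_inputs` is an assumption; `limit` would
# typically be MAX_INPUT_LENGTH and `inputs_lists` the parsed lists in `self.inputs_lists`.
from typing import Dict, List


def find_oversized_inputs(inputs_lists: Dict[str, List[str]], limit: int) -> List[str]:
    """Return the keys (e.g. "query") whose molecule list is longer than `limit`."""
    return [key for key, molecules in inputs_lists.items() if len(molecules) > limit]
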

class LayoutMaker:
    """
    This class includes all the design choices regarding the layout of the app. This
    class can be used in the main file to define header, sidebar, and main content area.
    """
    
    def __init__(self):
        
        # Initialize the inputs dictionary
        self.inputs = dict() # this will be the storage for query and support set inputs
        self.inputs_lists = dict()
        
        # Initialize prediction storage
        self.predictions = None
        
        # Buttons
        self.buttons = dict() # this will be the storage for buttons
        
        # content
        self.summary_text = summary_text
        self.mhnfs_text = mhnfs_text
        self.citation_text = citation_text
        self.few_shot_learning_text = few_shot_learning_text
        self.under_the_hood_text = under_the_hood_text
        self.usage_text = usage_text
        self.data_text = data_text
        self.trust_text = trust_text
        self.example_trustworthy_text = example_trustworthy_text
        self.example_nottrustworthy_text = example_nottrustworthy_text
        
        self.df_trustworthy = pd.read_csv("./assets/example_csv/predictions/"
                                          "trustworthy_example.csv")
        self.df_nottrustworthy = pd.read_csv("./assets/example_csv/predictions/"
                                            "nottrustworthy_example.csv")
        
        self.max_input_length = MAX_INPUT_LENGTH
    
    def make_sidebar(self):
        """
        This function defines the sidebar of the app. It includes the logo, the query box,
        the support set boxes, and the predict and reset buttons.
        It returns the stored inputs (query and support set) and the buttons that capture
        user interactions.
        """
        with st.sidebar:
            # Logo
            logo = Image.open("./assets/logo.png")
            st.image(logo)
            st.divider()
            
            # Query box
            self._make_query_box()
            st.divider()
            
            # Support set actives box
            self._make_active_support_set_box()
            st.divider()
            
            # Support set inactives box
            self._make_inactive_support_set_box()
            st.divider()
            
            # Predict buttons
            self.buttons["predict"] = st.button("Predict...")
            self.buttons["reset"] = st.button("Reset")
            
        return self.inputs, self.buttons
    
    def make_header(self):
        """
        This function defines the header of the app. It consists only of a PNG image
        that shows the title and an overview.
        """
        
        header_container = st.container()
        with header_container:
            header = Image.open("./assets/header.png")
            st.image(header)
   
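    def make_main_content_area(self,
                               predictor,
                               inputs,
                               buttons,
                               create_prediction_df: callable,
                               create_molecule_grid_plot: callable):
        """
        This function defines the main content area of the app. It consists of four
        tabs: predictions, paper and citation, additional information, and examples.
        """
        tab1, tab2, tab3, tab4 = st.tabs(["Predictions",
                                          "Paper / Cite",
                                          "Additional Information",
                                          "Examples"])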
        
        # Results tab
        with tab1:
            self._fill_tab_with_results_content(predictor,
                                                inputs,
                                                buttons,
                                                create_prediction_df,
                                                create_molecule_grid_plot)

        # Paper tab
        with tab2:
            self._fill_paper_and_citation_tab()
        
        # More explanations tab
        with tab3:
            self._fill_more_explanations_tab()
        
        # Examples tab
        with tab4:
            self._fill_examples_tab()

    def _make_query_box(self):
        """
        This function
        a) defines the query box and
        b) stores the query input in the inputs dictionary 
        """
        
        st.info(":blue[Molecules to predict:]", icon="❓")
        
        query_container = st.container()
        with query_container:
            input_choice = st.radio(
                "Input your data in SMILES notation via:", ["Text box", "CSV upload"]
            )
            if input_choice == "Text box":
                query_input = st.text_area(
                    label="SMILES input for query molecules",
                    label_visibility="hidden",
                    key="query_textbox",
                    value= "Cc1nc(N2CCN(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
                           "N#Cc1c(-c2ccccc2)nc(-c2cccc3c(Br)cccc23)n(CC(=O)O)c1=O, "
                           "Cc1nc(N2CCC(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
                           "CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
                           "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O, "
                           "COC(=O)c1c(SC)nc(C2CCCCC2)n(CC(=O)O)c1=O, "
                           "Cc1nc(NCc2cccnc2)c(C#N)c(=O)n1CC(=O)O, "
                           "CC(C)c1nc(SCc2ccccc2)c(C#N)c(=O)n1CC(=O)O, "
                           "N#Cc1c(OCC(=O)O)nc(-c2cccc3ccccc23)nc1-c1ccccc1, "
                           "COc1ccc2c(C(=S)N(C)CC(=O)O)cccc2c1C(F)(F)F"
                )
            elif input_choice == "CSV upload":
                query_file = st.file_uploader(key="query_csv",
                                              label="CSV upload for query molecules",
                                              label_visibility="hidden")
                if query_file is not None:
                    query_input = pd.read_csv(query_file)
                else:
                    query_input = None
        
        # Update storage
        self.inputs["query"] = query_input
    
    def _make_active_support_set_box(self):
        """
        This function
        a) defines the active support set box and
        b) stores the active support set input in the inputs dictionary
        """
        
        st.info(":blue[Known active molecules:]", icon="✨")
        active_container = st.container()
        with active_container:
            active_input_choice = st.radio(
                "Input your data in SMILES notation via:",
                ["Text box", "CSV upload"],
                key="active_input_choice",
            )

            if active_input_choice == "Text box":
                support_active_input = st.text_area(
                    label="SMILES input for active support set molecules",
                    label_visibility="hidden",
                    key="active_textbox",
                    value="CC(C)(C)c1nc(OCC(=O)O)c(C#N)c(SCC2CCCCC2)n1, "
                          "Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O"
                )
            elif active_input_choice == "CSV upload":
                support_active_file = st.file_uploader(
                    key="support_active_csv",
                    label="CSV upload for active support set molecules",
                    label_visibility="hidden"
                )
                if support_active_file is not None:
                    support_active_input = pd.read_csv(support_active_file)
                else:
                    support_active_input = None
        
        # Update storage
        self.inputs["support_active"] = support_active_input

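    def _make_inactive_support_set_box(self):
        """
        This function
        a) defines the inactive support set box and
        b) stores the inactive support set input in the inputs dictionary
        """

        st.info(":blue[Known inactive molecules:]", icon="✨")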
        inactive_container = st.container()
        with inactive_container:
            inactive_input_choice = st.radio(
                "Input your data in SMILES notation via:",
                ["Text box", "CSV upload"],
                key="inactive_input_choice",
            )
            if inactive_input_choice == "Text box":
                support_inactive_input = st.text_area(
                    label="SMILES input for inactive support set molecules",
                    label_visibility="hidden",
                    key="inactive_textbox",
                    value="CSc1nc(C2CCCCC2)n(CC(=O)O)c(=O)c1S(=O)(=O)c1ccccc1, "
                          "CSc1nc(C)nc(OCC(=O)O)c1C#N"
                )
            elif inactive_input_choice == "CSV upload":
                support_inactive_file = st.file_uploader(
                    key="support_inactive_csv",
                    label="CSV upload for inactive support set molecules",
                    label_visibility="hidden"
                )
                if support_inactive_file is not None:
                    support_inactive_input = pd.read_csv(support_inactive_file)
                else:
                    support_inactive_input = None

        # Update storage
        self.inputs["support_inactive"] = support_inactive_input

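    def _fill_tab_with_results_content(self, predictor, inputs, buttons,
                                       create_prediction_df, create_molecule_grid_plot):
        """
        This function fills the results tab: it shows the summary, validates the inputs
        when the predict button is hit, and displays the predictions as a table and as a
        molecule grid plot. Hitting the reset button clears the session state.
        """
        tab_container = st.container()
        with tab_container:
            # Info
            st.info(":blue[Summary:]", icon="🚀")
            st.markdown(self.summary_text)

            # Results
            st.info(":blue[Results:]", icon="👨‍💻")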
            
            if buttons['predict']:
                
                # Check 1: Are all inputs provided?
                if (inputs['query'] is None or
                        inputs['support_active'] is None or
                        inputs['support_inactive'] is None):
                    st.error("You didn't provide all necessary inputs.\n\n"
                             "Please provide all three necessary inputs via the "
                             "sidebar and hit the predict button again.")
                else:
                    # Check 2: Does each input contain at most max_input_length molecules?
                    max_input_length = 0
                    for key, input_data in inputs.items():
                        input_list = handle_inputs(input_data)
                        self.inputs_lists[key] = input_list
                        max_input_length = max(max_input_length, len(input_list))

                    if max_input_length > self.max_input_length:
                        st.error("You provided too many molecules. The number of "
                                 "molecules for each input is restricted to "
                                 f"{self.max_input_length}.\n\n"
                                 "For larger screenings, we suggest cloning the repo "
                                 "and running the model locally.")
                    else:
                        # Progress bar
                        progress_bar_text = ("I'm predicting activities. This might "
                                             "take a few minutes. Please wait...")
                        progress_bar = st.progress(50, text=progress_bar_text)

                        # Results table
                        df = self._predict_and_create_results_table(predictor,
                                                                    inputs,
                                                                    create_prediction_df)

                        progress_bar_text = "Done. Here are the results:"
                        progress_bar.progress(100, text=progress_bar_text)
                        st.dataframe(df, use_container_width=True)
                        
                        col1, col2, col3, col4 = st.columns([1,1,1,1])
                        # Provide download button for predictions
                        with col2:
                            self.buttons["download_results"] = st.download_button(
                                "Download predictions as CSV",
                                self._convert_df_to_binary(df),
                                file_name="predictions.csv",
                            )
                        
                        # Provide download button for inputs
                        with col3:
                            self.buttons["download_inputs"] = st.download_button(
                                "Download inputs as YML",
                                self._convert_to_yml(self.inputs_lists),
                                file_name="inputs.yml",
                            )
                        st.divider()
                                
                        # Results grid
                        st.info(":blue[Grid plot of the predicted molecules:]",
                                icon="📊")
                        mol_html_grid = create_molecule_grid_plot(df)
                        components.html(mol_html_grid, height=1000, scrolling=True)
                
            elif buttons['reset']:
                self._reset()
    
    def _fill_paper_and_citation_tab(self):
        st.info(":blue[**Paper: Context-enriched molecule representations improve "
                "few-shot drug discovery**]", icon="📄")
        st.markdown(self.mhnfs_text, unsafe_allow_html=True)
        st.image("./assets/mhnfs_overview.png")
        st.write("")
        st.write("")
        st.write("")
        st.info(":blue[**Cite us / BibTeX**]", icon="📚")
        st.markdown(self.citation_text)

    def _fill_more_explanations_tab(self):
        st.info(":blue[**Under the hood**]", icon="⚙️")
        st.markdown(self.under_the_hood_text, unsafe_allow_html=True)
        st.write("")
        st.write("")
        
        st.info(":blue[**About few-shot learning and the model MHNfs**]", icon="🎯")
        st.markdown(self.few_shot_learning_text, unsafe_allow_html=True)
        st.write("")
        st.write("")
        
        st.info(":blue[**Usage**]", icon="🎛️")
        st.markdown(self.usage_text, unsafe_allow_html=True)
        st.write("")
        st.write("")
        
        st.info(":blue[**How to provide the data**]", icon="📀")
        st.markdown(self.data_text, unsafe_allow_html=True)
        st.write("")
        st.write("")
        
        st.info(":blue[**When to trust the predictions**]", icon="🔍")
        st.markdown(self.trust_text, unsafe_allow_html=True)
   
    def _fill_examples_tab(self):
        st.info(":blue[**Example for trustworthy predictions**]", icon="✅")
        st.markdown(self.example_trustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_trustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.96)**")
        prediction_plot_tw = Image.open("./assets/example_csv/predictions/"
                                        "trustworthy_example.png")
        st.image(prediction_plot_tw)
        st.write("")
        st.write("")
        
        st.info(":blue[**Example for not trustworthy predictions**]", icon="⛔️")
        st.markdown(self.example_nottrustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_nottrustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.42)**")
        prediction_plot_ntw = Image.open("./assets/example_csv/predictions/"
                                        "nottrustworthy_example.png")
        st.image(prediction_plot_ntw)
    
    def _predict_and_create_results_table(self,
                                          predictor,
                                          inputs,
                                          create_prediction_df: callable):
        """Run the predictor on the query and support set inputs; return the results table."""
        df = create_prediction_df(predictor,
                                  inputs['query'],
                                  inputs['support_active'],
                                  inputs['support_inactive'])
        return df
    
    def _reset(self):
        keys = list(st.session_state.keys())
        for key in keys:
            st.session_state.pop(key)
    
    def _convert_df_to_binary(self, df):
        return df.to_csv(index=False).encode('utf-8')

    def _convert_to_yml(self, inputs):
        return yaml.dump(inputs)
        content = """
        # Usage
        As soon as you have a few active and inactive molecules for your task, you can
        provide them here and make predictions for new molecules.

        ## About few-shot learning and the model MHNfs
        **Few-shot learning** is a machine learning subfield that aims to provide
        predictive models for scenarios in which only little data is available.

        **MHNfs** is a few-shot learning model which is specifically designed for drug
        discovery applications. It is built to use the input prompts such that the
        available knowledge (i.e., the known active and inactive molecules) serves as
        context for predicting the activity of the requested new molecules.
        More precisely, the provided active and inactive molecules are associated with a
        large set of general molecules, called context molecules, to enrich the
        provided information and to remove spurious correlations arising from the
        decoration of molecules. This is analogous to a large language model that would
        use not only the information in the current prompt as context but would also
        have access to much more information, e.g. a prompting history.

        ## How to provide the data
        * Molecules have to be provided in SMILES format.
        * You can provide the molecules via the text boxes or via CSV upload.
            - Text box: Replace the example input by typing your molecules directly into
            the text box. Please separate the molecules with commas.
            - CSV upload: Upload a CSV file with the molecules.
                * The CSV file should include a SMILES column (both the upper-case header
                "SMILES" and the lower-case header "smiles" are accepted).
                * All other columns will be ignored.

        ## When to trust the predictions
        Like all other machine learning models, MHNfs varies in performance. Generally,
        the model works well if the task is similar to the tasks that were used to train
        the model. The model performance for very different tasks is unclear and might
        be poor.

        MHNfs was trained on the FS-Mol dataset, which includes 5120 tasks (roughly
        5000 tasks were used for training, the rest for evaluation). The training tasks
        are listed here: https://github.com/microsoft/FS-Mol/tree/main/datasets/targets.
        """
        return content
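

# ---------------------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not called by the app) of the input rules
# stated in the help text: text-box SMILES are separated by commas, and a CSV upload must
# contain a "SMILES" or "smiles" column while all other columns are ignored. The app
# itself delegates parsing to `handle_inputs`, whose exact behaviour may differ. The
# names `split_smiles_text` and `read_smiles_column` are assumptions; a Streamlit upload
# could be passed as `uploaded_file.getvalue().decode("utf-8")`.
import csv
import io
from typing import List


def split_smiles_text(text: str) -> List[str]:
    """Split comma-separated SMILES from a text box, dropping empty entries."""
    return [smiles.strip() for smiles in text.split(",") if smiles.strip()]


def read_smiles_column(csv_text: str) -> List[str]:
    """Return the values of the SMILES column (header matched case-insensitively)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    column = next(
        (name for name in (reader.fieldnames or []) if name.strip().lower() == "smiles"),
        None,
    )
    if column is None:
        raise ValueError("The CSV file needs a 'SMILES' or 'smiles' column.")
    # Strip whitespace and skip empty cells; every other column is ignored.
    values = ((row[column] or "").strip() for row in reader)
    return [smiles for smiles in values if smiles]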