fuxialexander commited on
Commit
690c5f2
1 Parent(s): 7af1806

regulatory demo

Browse files
Files changed (2) hide show
  1. app/main.py +198 -92
  2. modules/atac_rna_data_processing +1 -1
app/main.py CHANGED
@@ -1,13 +1,12 @@
1
- #%%
2
  import argparse
3
  import os
4
 
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
 
7
  import pkg_resources
8
- from proscope.data import get_genename_to_uniprot, get_lddt, get_seq
9
- import pandas as pd
10
  from dash_bio import Clustergram
 
11
 
12
  seq = get_seq()
13
  genename_to_uniprot = get_genename_to_uniprot()
@@ -23,38 +22,46 @@ from proscope.af2 import AFPairseg
23
  from proscope.protein import Protein
24
  from proscope.viewer import view_pdb_html
25
 
26
- #%%
27
  args = argparse.ArgumentParser()
28
  args.add_argument("-p", "--port", type=int, default=7860, help="Port number")
29
  args.add_argument("-s", "--share", action="store_true", help="Share on network")
30
  args.add_argument("-d", "--data", type=str, default="/data", help="Data directory")
31
- # args = args.parse_args()
32
  # set pseudo args
33
- args = args.parse_args(['-p', '7869', '-s', '-d', '/manitou/pmg/users/xf2217/demo_data'])
34
- #%%
35
  gene_pairs = glob(f"{args.data}/structures/causal/*")
36
  gene_pairs = [os.path.basename(pair) for pair in gene_pairs]
37
- GET_CONFIG = load_config('/manitou/pmg/users/xf2217/atac_rna_data_processing/atac_rna_data_processing/config/GET')
38
- GET_CONFIG.celltype.jacob=True
39
- GET_CONFIG.celltype.num_cls=2
40
- GET_CONFIG.celltype.input=True
41
- GET_CONFIG.celltype.embed=True
42
- GET_CONFIG.celltype.data_dir = '/manitou/pmg/users/xf2217/pretrain_human_bingren_shendure_apr2023/fetal_adult/'
43
- GET_CONFIG.celltype.interpret_dir='/manitou/pmg/users/xf2217/Interpretation_all_hg38_allembed_v4_natac/'
44
- GET_CONFIG.motif_dir = '/manitou/pmg/users/xf2217/interpret_natac/motif-clustering'
 
 
 
 
 
 
45
  motif = NrMotifV1.load_from_pickle(
46
  pkg_resources.resource_filename("atac_rna_data_processing", "data/NrMotifV1.pkl"),
47
- GET_CONFIG.motif_dir
48
  )
49
- cell_type_annot = pd.read_csv(GET_CONFIG.celltype.data_dir.split('fetal_adult')[0] + 'data/cell_type_pretrain_human_bingren_shendure_apr2023.txt')
50
- cell_type_id_to_name = dict(zip(cell_type_annot['id'], cell_type_annot['celltype']))
51
- cell_type_name_to_id = dict(zip(cell_type_annot['celltype'], cell_type_annot['id']))
52
- avaliable_celltypes = sorted([cell_type_id_to_name[f.split('/')[-1]] for f in glob(GET_CONFIG.celltype.interpret_dir+'*')])
53
- #%%
54
- # fill this in...
55
- # set plot ppi to 100
56
- plt.rcParams['figure.dpi'] = 100
57
-
 
 
 
 
58
 
59
 
60
  def visualize_AF2(tf_pair, a):
@@ -71,127 +78,226 @@ def visualize_AF2(tf_pair, a):
71
  fig4, ax4 = a.protein2.plot_plddt()
72
  fig5, ax5 = a.plot_score_heatmap()
73
  plt.tight_layout()
74
- new_dropdown = update_dropdown(list(a.pairs_data.keys()), 'Segment pair')
75
  return fig1, fig2, fig3, fig4, fig5, new_dropdown, a
76
 
 
77
  def view_pdb(seg_pair, a):
78
  pdb_path = a.pairs_data[seg_pair].pdb
79
  return view_pdb_html(pdb_path), a, pdb_path
80
 
81
 
82
-
83
  def update_dropdown(x, label):
84
  return gr.Dropdown.update(choices=x, label=label)
85
 
 
86
  def load_and_plot_celltype(celltype_name, GET_CONFIG, cell):
87
  celltype_id = cell_type_name_to_id[celltype_name]
88
  cell = GETCellType(celltype_id, GET_CONFIG)
89
  cell.celltype_name = celltype_name
90
  gene_exp_fig = cell.plotly_gene_exp()
91
- gene_exp_table = cell.gene_annot.groupby('gene_name')[['pred', 'obs', 'accessibility']].mean().reset_index()
92
- return gene_exp_fig, gene_exp_table, cell
93
-
94
  def plot_gene_regions(cell, gene_name, plotly=True):
95
  return cell.plot_gene_regions(gene_name, plotly=plotly), cell
96
 
 
97
  def plot_gene_motifs(cell, gene_name, motif, overwrite=False):
98
  return cell.plot_gene_motifs(gene_name, motif, overwrite=overwrite)[0], cell
99
 
100
- def plot_motif_subnet(cell, motif_collection, m, type='neighbors', threshold=0.1):
101
- return cell.plotly_motif_subnet(motif_collection, m, type=type, threshold=threshold), cell
 
 
 
 
 
102
 
103
  def plot_gene_exp(cell, plotly=True):
104
  return cell.plotly_gene_exp(plotly=plotly), cell
105
 
 
106
  def plot_motif_corr(cell):
107
- fig = Clustergram(data=cell.gene_by_motif.corr,
108
- column_labels=list(cell.gene_by_motif.corr.columns.values),
109
- row_labels=list(cell.gene_by_motif.corr.index),
110
- hidden_labels=['row', 'col'],
111
- link_method='average',
112
- display_ratio=0.1,
113
- width=600,
114
- height=400,
115
- color_map='rdbu_r',
116
- )
 
117
  return fig, cell
118
 
119
- #%%
120
- # fill this in...
121
 
122
- # main
123
- if __name__ == '__main__':
124
- with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
125
-
126
- seg_pairs = gr.State([''])
127
  af = gr.State(None)
128
  cell = gr.State(None)
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  with gr.Row() as row:
131
  # Left column: Plot gene expression and gene regions
132
  with gr.Column():
 
 
 
 
 
 
133
  with gr.Row() as row:
134
- celltype_name = gr.Dropdown(label='Cell Type', choices=avaliable_celltypes)
135
- celltype_btn = gr.Button(value='Load & Plot Gene Expression')
136
- gene_exp_plot = gr.Plot(label='Gene Expression Pred vs Obs')
137
- gene_exp_table = gr.DataFrame(label='Gene Expression Table', max_rows=10)
138
-
 
139
  # Right column: Plot gene motifs
140
  with gr.Column():
141
- gene_name_for_region = gr.Textbox(label='Get important regions or motifs for gene:')
 
 
 
 
 
 
 
 
142
  with gr.Row() as row:
143
- region_plot_btn = gr.Button(value='Regions')
144
- motif_plot_btn = gr.Button(value='Motifs')
145
 
146
- region_plot = gr.Plot(label='Gene Regions')
147
- motif_plot = gr.Plot(label='Gene Motifs')
148
 
149
-
 
 
 
 
 
 
150
  with gr.Row() as row:
151
  with gr.Column():
152
- clustergram_btn = gr.Button(value='Plot Motif Correlation Heatmap')
153
- clustergram_plot = gr.Plot(label='Motif Correlation')
154
 
155
-
156
  # Right column: Motif subnet plot
157
  with gr.Column():
158
  with gr.Row() as row:
159
- motif_for_subnet = gr.Dropdown(label='Motif Causal Subnetwork', choices=motif.cluster_names)
160
- subnet_type = gr.Dropdown(label='Type', choices=['neighbors', 'parents', 'children'], default='neighbors')
 
 
 
 
 
 
161
  # slider for threshold 0.01-0.2
162
- subnet_threshold = gr.Slider(label='Threshold', minimum=0.01, maximum=0.25, step=0.01, value=0.1)
163
- subnet_btn = gr.Button(value='Plot Motif Causal Subnetwork')
164
- subnet_plot = gr.Plot(label='Motif Causal Subnetwork')
165
-
 
 
 
 
 
166
 
 
 
 
 
 
 
 
167
  with gr.Row() as row:
168
  with gr.Column():
169
  with gr.Row() as row:
170
- tf_pairs = gr.Dropdown(label='TF pair', choices=gene_pairs)
171
- tf_pairs_btn = gr.Button(value='Load & Plot')
172
- interact_plddt1 = gr.Plot(label='Interact pLDDT 1')
173
- interact_plddt2 = gr.Plot(label='Interact pLDDT 2')
174
- protein1_plddt = gr.Plot(label='Protein 1 pLDDT')
175
- protein2_plddt = gr.Plot(label='Protein 2 pLDDT')
176
-
177
- heatmap = gr.Plot(label='Heatmap')
178
-
179
  with gr.Column():
180
  with gr.Row() as row:
181
- segpair = gr.Dropdown(label='Seg pair', choices=seg_pairs.value)
182
- segpair_btn = gr.Button(value='Get PDB')
183
  pdb_html = gr.HTML(label="PDB HTML")
184
- pdb_file = gr.File(label='Download PDB')
185
 
186
- tf_pairs_btn.click(visualize_AF2, inputs = [tf_pairs, af], outputs = [ interact_plddt1, interact_plddt2, protein1_plddt, protein2_plddt, heatmap, segpair, af])
187
- segpair_btn.click(view_pdb, inputs=[segpair, af], outputs=[pdb_html, af, pdb_file])
188
- celltype_btn.click(load_and_plot_celltype, inputs=[celltype_name, gr.State(GET_CONFIG), cell], outputs=[gene_exp_plot, gene_exp_table, cell])
189
- region_plot_btn.click(plot_gene_regions, inputs=[cell, gene_name_for_region], outputs=[region_plot, cell])
190
- motif_plot_btn.click(plot_gene_motifs, inputs=[cell, gene_name_for_region, gr.State(motif)], outputs=[motif_plot, cell])
191
- clustergram_btn.click(plot_motif_corr, inputs=[cell], outputs=[clustergram_plot, cell])
192
- subnet_btn.click(plot_motif_subnet, inputs=[cell, gr.State(motif), motif_for_subnet, subnet_type, subnet_threshold], outputs=[subnet_plot, cell])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  demo.launch(share=args.share, server_port=args.port)
195
-
196
-
197
- # %%
 
 
1
  import argparse
2
  import os
3
 
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
+ import pandas as pd
7
  import pkg_resources
 
 
8
  from dash_bio import Clustergram
9
+ from proscope.data import get_genename_to_uniprot, get_lddt, get_seq
10
 
11
  seq = get_seq()
12
  genename_to_uniprot = get_genename_to_uniprot()
 
22
  from proscope.protein import Protein
23
  from proscope.viewer import view_pdb_html
24
 
 
25
  args = argparse.ArgumentParser()
26
  args.add_argument("-p", "--port", type=int, default=7860, help="Port number")
27
  args.add_argument("-s", "--share", action="store_true", help="Share on network")
28
  args.add_argument("-d", "--data", type=str, default="/data", help="Data directory")
29
+ args = args.parse_args()
30
  # set pseudo args
31
+ # args = args.parse_args(['-p', '7869', '-s', '-d', '/manitou/pmg/users/xf2217/demo_data'])
 
32
  gene_pairs = glob(f"{args.data}/structures/causal/*")
33
  gene_pairs = [os.path.basename(pair) for pair in gene_pairs]
34
+ GET_CONFIG = load_config(
35
+ "/manitou/pmg/users/xf2217/atac_rna_data_processing/atac_rna_data_processing/config/GET"
36
+ )
37
+ GET_CONFIG.celltype.jacob = True
38
+ GET_CONFIG.celltype.num_cls = 2
39
+ GET_CONFIG.celltype.input = True
40
+ GET_CONFIG.celltype.embed = True
41
+ GET_CONFIG.celltype.data_dir = (
42
+ "/manitou/pmg/users/xf2217/pretrain_human_bingren_shendure_apr2023/fetal_adult/"
43
+ )
44
+ GET_CONFIG.celltype.interpret_dir = (
45
+ "/manitou/pmg/users/xf2217/Interpretation_all_hg38_allembed_v4_natac/"
46
+ )
47
+ GET_CONFIG.motif_dir = "/manitou/pmg/users/xf2217/interpret_natac/motif-clustering"
48
  motif = NrMotifV1.load_from_pickle(
49
  pkg_resources.resource_filename("atac_rna_data_processing", "data/NrMotifV1.pkl"),
50
+ GET_CONFIG.motif_dir,
51
  )
52
+ cell_type_annot = pd.read_csv(
53
+ GET_CONFIG.celltype.data_dir.split("fetal_adult")[0]
54
+ + "data/cell_type_pretrain_human_bingren_shendure_apr2023.txt"
55
+ )
56
+ cell_type_id_to_name = dict(zip(cell_type_annot["id"], cell_type_annot["celltype"]))
57
+ cell_type_name_to_id = dict(zip(cell_type_annot["celltype"], cell_type_annot["id"]))
58
+ avaliable_celltypes = sorted(
59
+ [
60
+ cell_type_id_to_name[f.split("/")[-1]]
61
+ for f in glob(GET_CONFIG.celltype.interpret_dir + "*")
62
+ ]
63
+ )
64
+ plt.rcParams["figure.dpi"] = 100
65
 
66
 
67
  def visualize_AF2(tf_pair, a):
 
78
  fig4, ax4 = a.protein2.plot_plddt()
79
  fig5, ax5 = a.plot_score_heatmap()
80
  plt.tight_layout()
81
+ new_dropdown = update_dropdown(list(a.pairs_data.keys()), "Segment pair")
82
  return fig1, fig2, fig3, fig4, fig5, new_dropdown, a
83
 
84
+
85
  def view_pdb(seg_pair, a):
86
  pdb_path = a.pairs_data[seg_pair].pdb
87
  return view_pdb_html(pdb_path), a, pdb_path
88
 
89
 
 
90
  def update_dropdown(x, label):
91
  return gr.Dropdown.update(choices=x, label=label)
92
 
93
+
94
  def load_and_plot_celltype(celltype_name, GET_CONFIG, cell):
95
  celltype_id = cell_type_name_to_id[celltype_name]
96
  cell = GETCellType(celltype_id, GET_CONFIG)
97
  cell.celltype_name = celltype_name
98
  gene_exp_fig = cell.plotly_gene_exp()
99
+ return gene_exp_fig, cell
100
+
101
+
102
  def plot_gene_regions(cell, gene_name, plotly=True):
103
  return cell.plot_gene_regions(gene_name, plotly=plotly), cell
104
 
105
+
106
  def plot_gene_motifs(cell, gene_name, motif, overwrite=False):
107
  return cell.plot_gene_motifs(gene_name, motif, overwrite=overwrite)[0], cell
108
 
109
+
110
+ def plot_motif_subnet(cell, motif_collection, m, type="neighbors", threshold=0.1):
111
+ return (
112
+ cell.plotly_motif_subnet(motif_collection, m, type=type, threshold=threshold),
113
+ cell,
114
+ )
115
+
116
 
117
  def plot_gene_exp(cell, plotly=True):
118
  return cell.plotly_gene_exp(plotly=plotly), cell
119
 
120
+
121
  def plot_motif_corr(cell):
122
+ fig = Clustergram(
123
+ data=cell.gene_by_motif.corr,
124
+ column_labels=list(cell.gene_by_motif.corr.columns.values),
125
+ row_labels=list(cell.gene_by_motif.corr.index),
126
+ hidden_labels=["row", "col"],
127
+ link_method="average",
128
+ display_ratio=0.1,
129
+ width=600,
130
+ height=500,
131
+ color_map="rdbu_r",
132
+ )
133
  return fig, cell
134
 
 
 
135
 
136
+ if __name__ == "__main__":
137
+ with gr.Blocks(theme="sudeepshouche/minimalist") as demo:
138
+ seg_pairs = gr.State([""])
 
 
139
  af = gr.State(None)
140
  cell = gr.State(None)
141
+
142
+ gr.Markdown(
143
+ """
144
+ # GET: A Foundation Model of Transcription Across Human Cell Types
145
+
146
+ _Transcriptional regulation, involving the complex interplay between regulatory sequences and proteins,
147
+ directs all biological processes. Computational models of transcriptions lack generalizability
148
+ to accurately extrapolate in unseen cell types and conditions. Here, we introduce GET,
149
+ an interpretable foundation model, designed to uncover deep regulatory patterns across 235 human fetal and adult cell types.
150
+ Relying exclusively on chromatin accessibility data and sequence information, GET achieves experimental-level accuracy
151
+ in predicting gene expression even in previously unseen cell types. GET showcases remarkable adaptability across new sequencing platforms and assays,
152
+ making it possible to infer regulatory activity across a broad range of cell types and conditions,
153
+ and to uncover universal and cell type specific transcription factor interaction networks.
154
+ We tested its performance on prediction of chromatin regulatory activity,
155
+ inference of regulatory elements and regulators of fetal hemoglobin,
156
+ and identification of known physical interactions between transcription factors.
157
+ In particular, we show GET outperforms current models in predicting lentivirus-based massive parallel reporter assay readout with reduced input data.
158
+ In fetal erythroblast, we are able to identify distant (>1Mbps) regulatory regions that were missed by previous models.
159
+ In sum, we provide a generalizable and predictive cell type specific model for transcription together with catalogs of gene regulation and transcription factor interactions.
160
+ Benefit from this catalog, we are able to provide mechanistic understanding of previously unknown significance germline coding variants in disordered regions of PAX5, a lymphoma associated transcription factor._
161
+ """
162
+ )
163
+
164
  with gr.Row() as row:
165
  # Left column: Plot gene expression and gene regions
166
  with gr.Column():
167
+ gr.Markdown(
168
+ """
169
+ ## Prediction performance
170
+ This section allows the selection of cell types and provides a plot depicting the observed versus predicted gene expression levels.
171
+ """
172
+ )
173
  with gr.Row() as row:
174
+ celltype_name = gr.Dropdown(
175
+ label="Cell Type", choices=avaliable_celltypes
176
+ )
177
+ celltype_btn = gr.Button(value="Load & Plot Gene Expression")
178
+ gene_exp_plot = gr.Plot(label="Gene Expression Pred vs Obs")
179
+
180
  # Right column: Plot gene motifs
181
  with gr.Column():
182
+ gr.Markdown(
183
+ """
184
+ ## Cell-type specific regulatory inference
185
+ This section allows the selection of a gene and provides plots of its cell-type specific regulatory regions and motifs.
186
+ """
187
+ )
188
+ gene_name_for_region = gr.Textbox(
189
+ label="Get important regions or motifs for gene:"
190
+ )
191
  with gr.Row() as row:
192
+ region_plot_btn = gr.Button(value="Regions")
193
+ motif_plot_btn = gr.Button(value="Motifs")
194
 
195
+ region_plot = gr.Plot(label="Gene Regions")
196
+ motif_plot = gr.Plot(label="Gene Motifs")
197
 
198
+ gr.Markdown(
199
+ """
200
+ ## Motif Correlation and Causal Subnetworks
201
+
202
+ Here, you can generate a heatmap to visualize motif correlations. Alternatively, you can explore the causal subnetworks related to specific motifs by selecting the motif and the type of subnetwork you are interested in, along with a effect size threshold.
203
+ """
204
+ )
205
  with gr.Row() as row:
206
  with gr.Column():
207
+ clustergram_btn = gr.Button(value="Plot Motif Correlation Heatmap")
208
+ clustergram_plot = gr.Plot(label="Motif Correlation")
209
 
 
210
  # Right column: Motif subnet plot
211
  with gr.Column():
212
  with gr.Row() as row:
213
+ motif_for_subnet = gr.Dropdown(
214
+ label="Motif Causal Subnetwork", choices=motif.cluster_names
215
+ )
216
+ subnet_type = gr.Dropdown(
217
+ label="Type",
218
+ choices=["neighbors", "parents", "children"],
219
+ default="neighbors",
220
+ )
221
  # slider for threshold 0.01-0.2
222
+ subnet_threshold = gr.Slider(
223
+ label="Threshold",
224
+ minimum=0.01,
225
+ maximum=0.25,
226
+ step=0.01,
227
+ value=0.1,
228
+ )
229
+ subnet_btn = gr.Button(value="Plot Motif Causal Subnetwork")
230
+ subnet_plot = gr.Plot(label="Motif Causal Subnetwork")
231
 
232
+ gr.Markdown(
233
+ """
234
+ ## Structural atlas of TF-TF and TF-EP300 interactions
235
+
236
+ This section allows you to explore transcription factor pairs. You can visualize various metrics such as Heatmaps and pLDDT (predicted Local Distance Difference Test) for both proteins in the interacting pair. You can also download the PDB file for specific segment pairs.
237
+ """
238
+ )
239
  with gr.Row() as row:
240
  with gr.Column():
241
  with gr.Row() as row:
242
+ tf_pairs = gr.Dropdown(label="TF pair", choices=gene_pairs)
243
+ tf_pairs_btn = gr.Button(value="Load & Plot")
244
+ heatmap = gr.Plot(label="Heatmap")
245
+ interact_plddt1 = gr.Plot(label="Interact pLDDT 1")
246
+ interact_plddt2 = gr.Plot(label="Interact pLDDT 2")
247
+ protein1_plddt = gr.Plot(label="Protein 1 pLDDT")
248
+ protein2_plddt = gr.Plot(label="Protein 2 pLDDT")
249
+
 
250
  with gr.Column():
251
  with gr.Row() as row:
252
+ segpair = gr.Dropdown(label="Seg pair", choices=seg_pairs.value)
253
+ segpair_btn = gr.Button(value="Get PDB")
254
  pdb_html = gr.HTML(label="PDB HTML")
255
+ pdb_file = gr.File(label="Download PDB")
256
 
257
+ tf_pairs_btn.click(
258
+ visualize_AF2,
259
+ inputs=[tf_pairs, af],
260
+ outputs=[
261
+ interact_plddt1,
262
+ interact_plddt2,
263
+ protein1_plddt,
264
+ protein2_plddt,
265
+ heatmap,
266
+ segpair,
267
+ af,
268
+ ],
269
+ )
270
+ segpair_btn.click(
271
+ view_pdb, inputs=[segpair, af], outputs=[pdb_html, af, pdb_file]
272
+ )
273
+ celltype_btn.click(
274
+ load_and_plot_celltype,
275
+ inputs=[celltype_name, gr.State(GET_CONFIG), cell],
276
+ outputs=[gene_exp_plot, cell],
277
+ )
278
+ region_plot_btn.click(
279
+ plot_gene_regions,
280
+ inputs=[cell, gene_name_for_region],
281
+ outputs=[region_plot, cell],
282
+ )
283
+ motif_plot_btn.click(
284
+ plot_gene_motifs,
285
+ inputs=[cell, gene_name_for_region, gr.State(motif)],
286
+ outputs=[motif_plot, cell],
287
+ )
288
+ clustergram_btn.click(
289
+ plot_motif_corr, inputs=[cell], outputs=[clustergram_plot, cell]
290
+ )
291
+ subnet_btn.click(
292
+ plot_motif_subnet,
293
+ inputs=[
294
+ cell,
295
+ gr.State(motif),
296
+ motif_for_subnet,
297
+ subnet_type,
298
+ subnet_threshold,
299
+ ],
300
+ outputs=[subnet_plot, cell],
301
+ )
302
 
303
  demo.launch(share=args.share, server_port=args.port)
 
 
 
modules/atac_rna_data_processing CHANGED
@@ -1 +1 @@
1
- Subproject commit fc337002918de1e5f1f864e7ba94864a743fd16c
 
1
+ Subproject commit aa7c0bfaac5719577e892e3226749a4d62c48848