Spaces:
Build error
Build error
PascalNotin
committed on
Commit
•
7275554
1
Parent(s):
b66f09d
Adjusted heatmap size and device printing
Browse files
app.py
CHANGED
@@ -39,12 +39,12 @@ def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=No
|
|
39 |
return all_single_mutants
|
40 |
|
41 |
def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None,annotate=True,fontsize=20):
|
|
|
|
|
42 |
piv=scores.pivot(index='position',columns='target_AA',values='avg_score').round(4)
|
43 |
-
fig, ax = plt.subplots(figsize=(50,len(sequence)
|
44 |
scores_dict = {}
|
45 |
valid_mutant_set=set(scores.mutant)
|
46 |
-
if mutation_range_start is None: mutation_range_start=1
|
47 |
-
if mutation_range_end is None: mutation_range_end=len(sequence)
|
48 |
ax.tick_params(bottom=True, top=True, left=True, right=True)
|
49 |
ax.tick_params(labelbottom=True, labeltop=True, labelleft=True, labelright=True)
|
50 |
if annotate:
|
@@ -63,7 +63,6 @@ def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_rang
|
|
63 |
cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
|
64 |
heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
|
65 |
heat.figure.axes[-1].yaxis.set_ticklabels(heat.figure.axes[-1].yaxis.get_ticklabels(), fontsize=fontsize)
|
66 |
-
#heat.figure.axes[-1].yaxis.set_ticklabels(fontsize=fontsize)
|
67 |
heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=fontsize*2, pad=40)
|
68 |
heat.set_ylabel("Sequence position", fontsize = fontsize*2)
|
69 |
heat.set_xlabel("Amino Acid mutation", fontsize = fontsize*2)
|
@@ -87,7 +86,6 @@ def suggest_mutations(scores):
|
|
87 |
positive_scores = scores[scores.avg_score > 0]
|
88 |
positive_scores_position_avg = positive_scores.groupby(['position']).mean()
|
89 |
top_positions=list(positive_scores_position_avg.sort_values(by=['avg_score'],ascending=False).head(5).index.astype(str))
|
90 |
-
print(top_positions)
|
91 |
position_recos = "The positions with the highest average fitness increase are (only positions with at least one fitness increase are considered):\n {}".format(", ".join(top_positions))
|
92 |
return intro_message+mutant_recos+position_recos
|
93 |
|
@@ -115,6 +113,11 @@ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutat
|
|
115 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Medium")
|
116 |
elif model_type=="Large":
|
117 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Large")
|
|
|
|
|
|
|
|
|
|
|
118 |
model.config.tokenizer = tokenizer
|
119 |
all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
|
120 |
scores = model.score_mutants(DMS_data=all_single_mutants,
|
@@ -205,10 +208,10 @@ with tranception_design:
|
|
205 |
)
|
206 |
gr.Markdown("<br>")
|
207 |
gr.Markdown("# Fitness predictions for all single amino acid substitutions in mutation range")
|
208 |
-
|
209 |
#output_plot = gr.Plot(label="Fitness predictions for all single amino acid substitutions in mutation range")
|
210 |
#output_image = gr.Image(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath")
|
211 |
-
output_image = gr.Gallery(label="Fitness predictions
|
212 |
|
213 |
output_recommendations = gr.Textbox(label="Mutation recommendations")
|
214 |
|
|
|
39 |
return all_single_mutants
|
40 |
|
41 |
def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None,annotate=True,fontsize=20):
|
42 |
+
if mutation_range_start is None: mutation_range_start=1
|
43 |
+
if mutation_range_end is None: mutation_range_end=len(sequence)
|
44 |
piv=scores.pivot(index='position',columns='target_AA',values='avg_score').round(4)
|
45 |
+
fig, ax = plt.subplots(figsize=(min(len(sequence),50),len(sequence)))
|
46 |
scores_dict = {}
|
47 |
valid_mutant_set=set(scores.mutant)
|
|
|
|
|
48 |
ax.tick_params(bottom=True, top=True, left=True, right=True)
|
49 |
ax.tick_params(labelbottom=True, labeltop=True, labelleft=True, labelright=True)
|
50 |
if annotate:
|
|
|
63 |
cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
|
64 |
heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
|
65 |
heat.figure.axes[-1].yaxis.set_ticklabels(heat.figure.axes[-1].yaxis.get_ticklabels(), fontsize=fontsize)
|
|
|
66 |
heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=fontsize*2, pad=40)
|
67 |
heat.set_ylabel("Sequence position", fontsize = fontsize*2)
|
68 |
heat.set_xlabel("Amino Acid mutation", fontsize = fontsize*2)
|
|
|
86 |
positive_scores = scores[scores.avg_score > 0]
|
87 |
positive_scores_position_avg = positive_scores.groupby(['position']).mean()
|
88 |
top_positions=list(positive_scores_position_avg.sort_values(by=['avg_score'],ascending=False).head(5).index.astype(str))
|
|
|
89 |
position_recos = "The positions with the highest average fitness increase are (only positions with at least one fitness increase are considered):\n {}".format(", ".join(top_positions))
|
90 |
return intro_message+mutant_recos+position_recos
|
91 |
|
|
|
113 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Medium")
|
114 |
elif model_type=="Large":
|
115 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Large")
|
116 |
+
if torch.cuda.is_available():
|
117 |
+
model.cuda()
|
118 |
+
print("Inference will take place on GPU")
|
119 |
+
else:
|
120 |
+
print("Inference will take place on CPU")
|
121 |
model.config.tokenizer = tokenizer
|
122 |
all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
|
123 |
scores = model.score_mutants(DMS_data=all_single_mutants,
|
|
|
208 |
)
|
209 |
gr.Markdown("<br>")
|
210 |
gr.Markdown("# Fitness predictions for all single amino acid substitutions in mutation range")
|
211 |
+
gr.Markdown("Inference may take a few seconds for short proteins & mutation ranges to several minutes for longer ones")
|
212 |
#output_plot = gr.Plot(label="Fitness predictions for all single amino acid substitutions in mutation range")
|
213 |
#output_image = gr.Image(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath")
|
214 |
+
output_image = gr.Gallery(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath") #Using Gallery to be able to scroll large matrix images
|
215 |
|
216 |
output_recommendations = gr.Textbox(label="Mutation recommendations")
|
217 |
|