PascalNotin commited on
Commit
7275554
1 Parent(s): b66f09d

Adjusted heatmap size and device printing

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -39,12 +39,12 @@ def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=No
39
  return all_single_mutants
40
 
41
  def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None,annotate=True,fontsize=20):
 
 
42
  piv=scores.pivot(index='position',columns='target_AA',values='avg_score').round(4)
43
- fig, ax = plt.subplots(figsize=(50,len(sequence)*0.6))
44
  scores_dict = {}
45
  valid_mutant_set=set(scores.mutant)
46
- if mutation_range_start is None: mutation_range_start=1
47
- if mutation_range_end is None: mutation_range_end=len(sequence)
48
  ax.tick_params(bottom=True, top=True, left=True, right=True)
49
  ax.tick_params(labelbottom=True, labeltop=True, labelleft=True, labelright=True)
50
  if annotate:
@@ -63,7 +63,6 @@ def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_rang
63
  cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
64
  heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
65
  heat.figure.axes[-1].yaxis.set_ticklabels(heat.figure.axes[-1].yaxis.get_ticklabels(), fontsize=fontsize)
66
- #heat.figure.axes[-1].yaxis.set_ticklabels(fontsize=fontsize)
67
  heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=fontsize*2, pad=40)
68
  heat.set_ylabel("Sequence position", fontsize = fontsize*2)
69
  heat.set_xlabel("Amino Acid mutation", fontsize = fontsize*2)
@@ -87,7 +86,6 @@ def suggest_mutations(scores):
87
  positive_scores = scores[scores.avg_score > 0]
88
  positive_scores_position_avg = positive_scores.groupby(['position']).mean()
89
  top_positions=list(positive_scores_position_avg.sort_values(by=['avg_score'],ascending=False).head(5).index.astype(str))
90
- print(top_positions)
91
  position_recos = "The positions with the highest average fitness increase are (only positions with at least one fitness increase are considered):\n {}".format(", ".join(top_positions))
92
  return intro_message+mutant_recos+position_recos
93
 
@@ -115,6 +113,11 @@ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutat
115
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Medium")
116
  elif model_type=="Large":
117
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Large")
 
 
 
 
 
118
  model.config.tokenizer = tokenizer
119
  all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
120
  scores = model.score_mutants(DMS_data=all_single_mutants,
@@ -205,10 +208,10 @@ with tranception_design:
205
  )
206
  gr.Markdown("<br>")
207
  gr.Markdown("# Fitness predictions for all single amino acid substitutions in mutation range")
208
-
209
  #output_plot = gr.Plot(label="Fitness predictions for all single amino acid substitutions in mutation range")
210
  #output_image = gr.Image(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath")
211
- output_image = gr.Gallery(label="Fitness predictions (inference may take a few seconds for short proteins & mutation ranges to several minutes for longer ones)",type="filepath") #Using Gallery to be able to scroll large matrix images
212
 
213
  output_recommendations = gr.Textbox(label="Mutation recommendations")
214
 
 
39
  return all_single_mutants
40
 
41
  def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None,annotate=True,fontsize=20):
42
+ if mutation_range_start is None: mutation_range_start=1
43
+ if mutation_range_end is None: mutation_range_end=len(sequence)
44
  piv=scores.pivot(index='position',columns='target_AA',values='avg_score').round(4)
45
+ fig, ax = plt.subplots(figsize=(min(len(sequence),50),len(sequence)))
46
  scores_dict = {}
47
  valid_mutant_set=set(scores.mutant)
 
 
48
  ax.tick_params(bottom=True, top=True, left=True, right=True)
49
  ax.tick_params(labelbottom=True, labeltop=True, labelleft=True, labelright=True)
50
  if annotate:
 
63
  cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
64
  heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
65
  heat.figure.axes[-1].yaxis.set_ticklabels(heat.figure.axes[-1].yaxis.get_ticklabels(), fontsize=fontsize)
 
66
  heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=fontsize*2, pad=40)
67
  heat.set_ylabel("Sequence position", fontsize = fontsize*2)
68
  heat.set_xlabel("Amino Acid mutation", fontsize = fontsize*2)
 
86
  positive_scores = scores[scores.avg_score > 0]
87
  positive_scores_position_avg = positive_scores.groupby(['position']).mean()
88
  top_positions=list(positive_scores_position_avg.sort_values(by=['avg_score'],ascending=False).head(5).index.astype(str))
 
89
  position_recos = "The positions with the highest average fitness increase are (only positions with at least one fitness increase are considered):\n {}".format(", ".join(top_positions))
90
  return intro_message+mutant_recos+position_recos
91
 
 
113
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Medium")
114
  elif model_type=="Large":
115
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Large")
116
+ if torch.cuda.is_available():
117
+ model.cuda()
118
+ print("Inference will take place on GPU")
119
+ else:
120
+ print("Inference will take place on CPU")
121
  model.config.tokenizer = tokenizer
122
  all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
123
  scores = model.score_mutants(DMS_data=all_single_mutants,
 
208
  )
209
  gr.Markdown("<br>")
210
  gr.Markdown("# Fitness predictions for all single amino acid substitutions in mutation range")
211
+ gr.Markdown("Inference may take a few seconds for short proteins & mutation ranges to several minutes for longer ones")
212
  #output_plot = gr.Plot(label="Fitness predictions for all single amino acid substitutions in mutation range")
213
  #output_image = gr.Image(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath")
214
+ output_image = gr.Gallery(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath") #Using Gallery to be able to scroll large matrix images
215
 
216
  output_recommendations = gr.Textbox(label="Mutation recommendations")
217