# zsp / data.py
# Massimo G. Totaro
# update fix
# fba8f5e
# raw
# history blame
# 8.8 kB
from math import ceil
from re import match
import seaborn as sns
from model import Model
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from model import Model
class Data:
    """Container for input and output data.

    An instance parses a sequence and a substitution request, hands itself to
    the shared ``Model`` through ``run_model`` and formats ``self.out`` for
    display.

    Running modes (``self.mode``), selected by ``parse_sub``:
        * ``'MUT'`` - explicit substitutions, given either as a target
          sequence of the same length as the input or as ``X#Y`` tokens
        * ``'DMS'`` - deep mutational scanning of the listed positions
        * ``'TMS'`` - every position of the sequence is scanned; fallback
          for any other request, including an empty one
    """

    # Initialise empty model as static class member for efficiency
    model = Model()

    # One-letter codes of the 20 standard amino acids; this order is also the
    # row order of the TMS heatmap
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
    # Rows generated per scanned position: every amino acid but the wild type.
    # NOTE(review): a wild-type residue outside AMINO_ACIDS yields one row
    # more than this; the per-position chunking below then truncates it.
    SUBS_PER_SITE = len(AMINO_ACIDS) - 1

    def __init__(self, src: str, trg: str, model_name: str = 'facebook/esm2_t33_650M_UR50D',
                 scoring_strategy: str = 'masked-marginals', out_file=None):
        """Initialise data.

        Args:
            src: input sequence, see ``parse_seq``.
            trg: substitution request, see ``parse_sub``.
            model_name: name passed to ``Model`` when a new model is loaded;
                also the label of the score column of ``self.out``.
            scoring_strategy: stored unchanged in ``self.scoring_strategy``;
                not interpreted by this class.
            out_file: destination of the formatted output, either a path or an
                object exposing its path as ``.name`` (e.g. an open file).

        Raises:
            RuntimeError: if the sequence or the request is invalid.
        """
        # Always record the name: the output table is built from it below,
        # whether or not a new model has to be loaded
        self.model_name = model_name
        # If model has changed, load the new one and cache it on the class so
        # that later instances asking for the same model reuse it
        if self.model.model_name != model_name:
            type(self).model = Model(model_name)
        # Pin the model on the instance so that a later instance requesting a
        # different model cannot swap it underneath this one
        self.model = type(self).model

        self.parse_seq(src)
        # Not used within this class; presumably consumed by Model.run_model
        self.offset = 0
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        # Not used within this class; presumably filled by Model.run_model
        self.token_probs = None
        # ``self.sub`` only has column '0'; the score column starts out empty
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_str = None
        if out_file is None or isinstance(out_file, str) or hasattr(out_file, '__fspath__'):
            # Plain paths are used as they are; ``Path.name`` would only be
            # the basename and redirect the output to the working directory
            self.out_buffer = out_file
        else:
            self.out_buffer = getattr(out_file, 'name', out_file)

    def parse_seq(self, src: str):
        """Parse input sequence.

        All whitespace (leading, trailing and internal, including line
        breaks) is removed and the result is upper-cased into ``self.seq``.

        Raises:
            RuntimeError: if a character is not in ``self.model.alphabet``.
        """
        self.seq = ''.join(src.split()).upper()
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def _residue_at(self, resi: int) -> str:
        """Return the residue of ``self.seq`` at the 1-based position ``resi``.

        Raises:
            RuntimeError: if ``resi`` is outside of the sequence. Without this
                check position 0 would silently wrap to the last residue.
        """
        if not 1 <= resi <= len(self.seq):
            raise RuntimeError(f"Position {resi} is outside of the sequence (1-{len(self.seq)})")
        return self.seq[resi - 1]

    def _scan_position(self, src: str, resi: int) -> None:
        """Queue every substitution of residue ``src`` at position ``resi``."""
        for new in self.AMINO_ACIDS.replace(src, ''):
            self.sub.append(f"{src}{resi}{new}")
            self.resi.append(resi)

    def parse_sub(self, trg: str):
        """Parse input substitutions.

        Sets ``self.mode``, ``self.trg``, ``self.resi`` (the position of each
        queued substitution) and ``self.sub`` (a one-column DataFrame, column
        ``'0'``, holding the substitutions as ``X#Y`` strings).

        Raises:
            RuntimeError: if a position is outside of the sequence or the
                wild-type residue of a substitution does not match it.
        """
        self.mode = None
        self.sub = []
        self.resi = []
        self.trg = trg.strip().upper()
        tokens = self.trg.split()

        # Identify running mode
        # NOTE(review): \w also accepts digits, so a lone X#Y token exactly as
        # long as the sequence is read as a target sequence - confirm intent
        if len(tokens) == 1 and len(tokens[0]) == len(self.seq) and match(r'\w+$', tokens[0]):
            # If single string of same length as sequence, seq vs seq mode
            self.mode = 'MUT'
            for resi, (src, new) in enumerate(zip(self.seq, self.trg), 1):
                if src != new:
                    self.sub.append(f"{src}{resi}{new}")
                    self.resi.append(resi)
        else:
            self.trg = tokens
            # Patterns are anchored at the end so that tokens with trailing
            # junk are not mistaken for numbers or substitutions; ``tokens``
            # must be non-empty because all() is vacuously true otherwise
            if tokens and all(match(r'\d+$', x) for x in tokens):
                # If all strings are numbers, deep mutational scanning mode
                self.mode = 'DMS'
                for resi in map(int, tokens):
                    self._scan_position(self._residue_at(resi), resi)
            elif tokens and all(match(r'[A-Z]\d+[A-Z]$', x) for x in tokens):
                # If all strings are of the form X#Y, single substitution mode
                self.mode = 'MUT'
                self.resi = [int(x[1:-1]) for x in tokens]
                for token, resi in zip(tokens, self.resi):
                    wild_type = self._residue_at(resi)
                    if wild_type != token[0]:
                        raise RuntimeError(f"Unrecognised input substitution {wild_type}{resi} /= {token[0]}{resi}")
                self.sub = tokens
            else:
                # Anything else, including an empty request: scan everything
                self.mode = 'TMS'
                for resi, src in enumerate(self.seq, 1):
                    self._scan_position(src, resi)

        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def parse_output(self) -> None:
        """Format output data for visualisation.

        TMS results are drawn as a heatmap, see ``process_tms_mode``. DMS and
        MUT results are sorted, written as CSV (rounded to two decimals, no
        header) to ``self.out_buffer`` when it is set, and rendered as a
        colour-graded HTML table into ``self.out_str``.

        Raises:
            RuntimeError: if ``self.mode`` is none of TMS, DMS or MUT.
        """
        if self.mode == 'TMS':
            self.process_tms_mode()
            return

        if self.mode == 'DMS':
            self.sort_by_residue_and_score()
        elif self.mode == 'MUT':
            self.sort_by_score()
        else:
            raise RuntimeError(f"Unrecognised mode {self.mode}")

        if self.out_buffer:
            self.out.round(2).to_csv(self.out_buffer, index=False, header=False)
        self.out_str = (self.out.style
                        .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                        .hide(axis=0)
                        .hide(axis=1)
                        .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
                        .to_html(justify='center'))

    def sort_by_score(self):
        """Sort ``self.out`` by the score column, highest first."""
        self.out = self.out.sort_values(self.model_name, ascending=False)

    def sort_by_residue_and_score(self):
        """Lay ``self.out`` out as one (substitution, score) column pair per position.

        Rows are sorted by position, then by descending score; each position
        is capped at ``SUBS_PER_SITE`` rows and the per-position chunks are
        placed side by side with columns renumbered from 0.
        """
        per_site = self.SUBS_PER_SITE
        ranked = (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                  .sort_values(['resi', self.model_name], ascending=[True, False])
                  .groupby(['resi'])
                  .head(per_site)
                  .drop(['resi'], axis=1))
        nsites = ranked.shape[0] // per_site
        chunks = [ranked.iloc[per_site*x:per_site*(x+1)].reset_index(drop=True) for x in range(nsites)]
        self.out = pd.concat(chunks, axis=1).set_axis(range(nsites * 2), axis='columns')

    def process_tms_mode(self):
        """Turn the TMS results into a heatmap split over several rows of subplots."""
        self.out = self.assign_resi_and_group()
        self.out = self.concat_and_set_axis()
        # Scale by the largest absolute value so that the matrix spans [-1, 1]
        self.out /= self.out.abs().max().max()
        divs = self.calculate_divs()
        ncols = min(divs, key=lambda x: abs(x-60))
        nrows = ceil(self.out.shape[1]/ncols)
        ncols = self.adjust_ncols(ncols, nrows)
        self.plot_heatmap(ncols, nrows)

    def assign_resi_and_group(self):
        """Return ``self.out`` with a ``resi`` column, capped at ``SUBS_PER_SITE`` rows per position."""
        return (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                .groupby(['resi'])
                .head(self.SUBS_PER_SITE))

    def concat_and_set_axis(self):
        """Return the amino-acid x position score matrix built from ``self.out``.

        Expects ``self.out`` to be the frame returned by
        ``assign_resi_and_group``. Each chunk of ``SUBS_PER_SITE`` rows gets a
        wild-type row from ``create_dataframe``, is sorted by substitution
        label and becomes one column; rows are labelled with ``AMINO_ACIDS``
        and columns with residue and position (e.g. ``M1``).
        """
        per_site = self.SUBS_PER_SITE
        columns = [(self.out.iloc[per_site*x:per_site*(x+1)]
                    .pipe(self.create_dataframe)
                    .sort_values(['0'], ascending=[True])
                    .drop(['resi', '0'], axis=1)
                    .set_axis(list(self.AMINO_ACIDS))
                    .astype(float))
                   for x in range(self.out.shape[0] // per_site)]
        return (pd.concat(columns, axis=1)
                .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis='columns'))

    def create_dataframe(self, df):
        """Return ``df`` with a wild-type row (e.g. ``A1A``, zeros elsewhere) prepended.

        The label is derived from the first substitution in ``df``; ``df`` is
        expected to have exactly three columns.
        """
        first = df.iloc[0, 0]
        wild_type = pd.Series([first[:-1] + first[0], 0, 0], index=df.columns).to_frame().T
        return pd.concat([wild_type, df], axis=0, ignore_index=True)

    def calculate_divs(self):
        """Return the divisors of the number of columns that lie in 30..60, or ``[60]`` if none."""
        total = self.out.shape[1]
        return [x for x in range(30, min(60, total) + 1) if total % x == 0] or [60]

    def adjust_ncols(self, ncols, nrows):
        """Return the smallest column count that still fits the data in ``nrows`` rows.

        ``ncols`` is only ever reduced, never below 46, and is returned
        unchanged when it cannot shrink - in particular when it divides the
        number of positions exactly, which keeps all rows equally long.
        """
        total = self.out.shape[1]
        # Step down only while one column fewer per row still holds everything
        while ncols > 46 and (ncols - 1) * nrows >= total:
            ncols -= 1
        return ncols

    def plot_heatmap(self, ncols, nrows):
        """Draw the heatmap and, if ``self.out_buffer`` is set, save it as SVG.

        The SVG text is read back into ``self.out_str`` and the figure is
        closed afterwards so that repeated runs do not accumulate open
        figures. Without ``self.out_buffer`` the figure is left open.
        """
        if nrows < 2:
            fig = self.plot_single_heatmap()
        else:
            fig = self.plot_multiple_heatmaps(ncols, nrows)
        if not self.out_buffer:
            return
        if fig is None:
            # An overriding plot method may not return its figure
            fig = plt.gcf()
        try:
            fig.savefig(self.out_buffer, format='svg')
        finally:
            plt.close(fig)
        with open(self.out_buffer, 'r', encoding='utf-8') as f:
            self.out_str = f.read()

    def plot_single_heatmap(self):
        """Draw ``self.out`` as one heatmap and return the figure."""
        fig = plt.figure(figsize=(12, 6))
        sns.heatmap(self.out,
                    cmap='RdBu',
                    cbar=False,
                    square=True,
                    xticklabels=1,
                    yticklabels=1,
                    center=0,
                    # Cells equal to zero are marked with a dot
                    annot=self.out.map(lambda x: ' ' if x != 0 else '·'),
                    fmt='s',
                    annot_kws={'size': 'xx-large'})
        fig.tight_layout()
        return fig

    def plot_multiple_heatmaps(self, ncols, nrows):
        """Draw ``self.out`` as ``nrows`` stacked heatmaps of ``ncols`` columns each and return the figure."""
        fig, ax = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
        for i in range(nrows):
            tmp = self.out.iloc[:, i*ncols:(i+1)*ncols]
            # Cells equal to zero are marked with a dot
            label = tmp.map(lambda x: ' ' if x != 0 else '·')
            sns.heatmap(tmp,
                        ax=ax[i],
                        cmap='RdBu',
                        cbar=False,
                        square=True,
                        xticklabels=1,
                        yticklabels=1,
                        center=0,
                        annot=label,
                        fmt='s',
                        annot_kws={'size': 'xx-large'})
            ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
            ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
        fig.tight_layout()
        return fig

    def calculate(self):
        """Run model and parse output; return ``self`` to allow chaining."""
        self.model.run_model(self)
        self.parse_output()
        return self

    def __str__(self):
        """Return output data in DataFrame format."""
        return str(self.out)

    def __repr__(self):
        """Return the formatted output (HTML or SVG text).

        Falls back to the plain table while no formatted output exists yet,
        since ``repr()`` raises ``TypeError`` on a non-string result.
        """
        if self.out_str is None:
            return str(self.out)
        return self.out_str