"""
Geneformer tokenizer.

Input data:
Required format: raw counts scRNAseq data without feature selection, as a .loom or anndata (.h5ad) file.
Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene.
Required col (cell) attribute: "n_counts"; total read counts in that cell.
Optional col (cell) attribute: "filter_pass"; binary indicator of whether the cell should be tokenized, based on user-defined filtering criteria.
Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary, as shown below.

Usage:
  from geneformer import TranscriptomeTokenizer
  tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
  tk.tokenize_data("data_directory", "output_directory", "output_prefix")
"""

from __future__ import annotations

import logging
import pickle
import warnings
from pathlib import Path
from typing import Literal

import anndata as ad
import loompy as lp
import numpy as np
import scipy.sparse as sp
from datasets import Dataset

warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
logger = logging.getLogger(__name__)

GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
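
# Illustrative sketch: both bundled pickles are plain dictionaries keyed by Ensembl gene ID
# (see the __init__ docstring below), e.g.
#
#   with open(TOKEN_DICTIONARY_FILE, "rb") as f:
#       token_dict = pickle.load(f)    # Ensembl ID -> integer token
#   with open(GENE_MEDIAN_FILE, "rb") as f:
#       median_dict = pickle.load(f)   # Ensembl ID -> non-zero median expression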


def rank_genes(gene_vector, gene_tokens):
    """
    Rank gene expression vector and return the corresponding gene tokens
    in order of decreasing expression.
    """
    sorted_indices = np.argsort(-gene_vector)
    return gene_tokens[sorted_indices]


def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # Keep only genes with nonzero expression, then rank them by descending expression.
    nonzero_mask = np.nonzero(gene_vector)[0]
    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
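
# Worked example (illustrative values): for a normalized expression vector
# [0.5, 0.0, 2.0] with gene tokens [10, 11, 12], the zero-expression gene is dropped
# and the remaining genes are ordered by decreasing expression, so
#   tokenize_cell(np.array([0.5, 0.0, 2.0]), np.array([10, 11, 12]))
# returns array([12, 10]).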


class TranscriptomeTokenizer:
    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        chunk_size=512,
        gene_median_file=GENE_MEDIAN_FILE,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        Parameters
        ----------
        custom_attr_name_dict : None, dict
            Dictionary of custom attributes to be added to the dataset.
            Keys are the names of the attributes in the loom file.
            Values are the names of the attributes in the dataset.
        nproc : int
            Number of processes to use for dataset mapping.
        chunk_size : int
            Chunk size for anndata tokenizer.
        gene_median_file : Path
            Path to pickle file containing dictionary of non-zero median
            gene expression values across Genecorpus-30M.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl IDs:token).
        """
        self.custom_attr_name_dict = custom_attr_name_dict
        self.nproc = nproc
        self.chunk_size = chunk_size

        with open(gene_median_file, "rb") as f:
            self.gene_median_dict = pickle.load(f)

        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        self.gene_keys = list(self.gene_median_dict.keys())
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
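        # genelist_dict gives O(1) membership checks; during tokenization, only genes
        # present in the Genecorpus-30M median dictionary are retained.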

    def tokenize_data(
        self,
        data_directory: Path | str,
        output_directory: Path | str,
        output_prefix: str,
        file_format: Literal["loom", "h5ad"] = "loom",
        use_generator: bool = False,
    ):
        """
        Tokenize .loom or .h5ad files in data_directory and save as tokenized .dataset in output_directory.

        Parameters
        ----------
        data_directory : Path
            Path to directory containing loom files or anndata files
        output_directory : Path
            Path to directory where tokenized data will be saved as .dataset
        output_prefix : str
            Prefix for output .dataset
        file_format : str
            Format of input files. Can be "loom" or "h5ad".
        use_generator : bool
            Whether to use a generator or dict for tokenization.
        """
        tokenized_cells, cell_metadata = self.tokenize_files(
            Path(data_directory), file_format
        )
        tokenized_dataset = self.create_dataset(
            tokenized_cells, cell_metadata, use_generator=use_generator
        )

        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
        tokenized_dataset.save_to_disk(output_path)
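
    # Example (illustrative): tokenizing a directory of .h5ad files instead of .loom:
    #   tk.tokenize_data("data_directory", "output_directory", "output_prefix", file_format="h5ad")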

    def tokenize_files(
        self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
    ):
        tokenized_cells = []
        if self.custom_attr_name_dict is not None:
            cell_attr = list(self.custom_attr_name_dict.keys())
            cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.values()
            }

        # Tokenize each file in the directory with the tokenizer matching its format.
        file_found = False
        tokenize_file_fn = (
            self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
        )
        for file_path in data_directory.glob("*.{}".format(file_format)):
            file_found = True
            print(f"Tokenizing {file_path}")
            file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
            tokenized_cells += file_tokenized_cells
            if self.custom_attr_name_dict is not None:
                for k in cell_attr:
                    cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
            else:
                cell_metadata = None

        if not file_found:
            logger.error(
                f"No .{file_format} files found in directory {data_directory}."
            )
            raise FileNotFoundError(
                f"No .{file_format} files found in directory {data_directory}."
            )
        return tokenized_cells, cell_metadata

    def tokenize_anndata(self, adata_file_path, target_sum=10_000):
        adata = ad.read(adata_file_path, backed="r")

        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }

        # Locate genes present in the Genecorpus-30M vocabulary and collect their
        # non-zero median expression values and token IDs.
        coding_miRNA_loc = np.where(
            [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
        )[0]
        norm_factor_vector = np.array(
            [
                self.gene_median_dict[i]
                for i in adata.var["ensembl_id"][coding_miRNA_loc]
            ]
        )
        coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
        coding_miRNA_tokens = np.array(
            [self.gene_token_dict[i] for i in coding_miRNA_ids]
        )

        try:
            _ = adata.obs["filter_pass"]
        except KeyError:
            var_exists = False
        else:
            var_exists = True

        if var_exists:
            filter_pass_loc = np.where([i == 1 for i in adata.obs["filter_pass"]])[0]
        else:
            print(
                f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
            )
            filter_pass_loc = np.arange(adata.shape[0])

        tokenized_cells = []

        # Process cells in chunks: scale each cell's counts to target_sum, divide by
        # each gene's non-zero median, then rank-value encode the nonzero genes.
        for i in range(0, len(filter_pass_loc), self.chunk_size):
            idx = filter_pass_loc[i : i + self.chunk_size]

            n_counts = adata[idx].obs["n_counts"].values[:, None]
            X_view = adata[idx, coding_miRNA_loc].X
            X_norm = X_view / n_counts * target_sum / norm_factor_vector
            X_norm = sp.csr_matrix(X_norm)

            tokenized_cells += [
                rank_genes(X_norm[j].data, coding_miRNA_tokens[X_norm[j].indices])
                for j in range(X_norm.shape[0])
            ]

            if self.custom_attr_name_dict is not None:
                for k in file_cell_metadata.keys():
                    file_cell_metadata[k] += adata[idx].obs[k].tolist()
            else:
                file_cell_metadata = None

        return tokenized_cells, file_cell_metadata

    def tokenize_loom(self, loom_file_path, target_sum=10_000):
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }

        with lp.connect(str(loom_file_path)) as data:
            # Locate genes present in the Genecorpus-30M vocabulary and collect their
            # non-zero median expression values and token IDs.
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]
            norm_factor_vector = np.array(
                [
                    self.gene_median_dict[i]
                    for i in data.ra["ensembl_id"][coding_miRNA_loc]
                ]
            )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            try:
                data.ca["filter_pass"]
            except AttributeError:
                var_exists = False
            else:
                var_exists = True

            if var_exists:
                filter_pass_loc = np.where([i == 1 for i in data.ca["filter_pass"]])[0]
            else:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
                )
                filter_pass_loc = np.arange(data.shape[1])

            tokenized_cells = []

            # Scan the loom file cell-wise: scale each cell's counts to target_sum,
            # divide by each gene's non-zero median, then rank-value encode the cell.
            for _ix, _selection, view in data.scan(items=filter_pass_loc, axis=1):
                subview = view.view[coding_miRNA_loc, :]

                subview_norm_array = (
                    subview[:, :]
                    / subview.ca.n_counts
                    * target_sum
                    / norm_factor_vector[:, None]
                )

                tokenized_cells += [
                    tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
                    for i in range(subview_norm_array.shape[1])
                ]

                if self.custom_attr_name_dict is not None:
                    for k in file_cell_metadata.keys():
                        file_cell_metadata[k] += subview.ca[k].tolist()
                else:
                    file_cell_metadata = None

        return tokenized_cells, file_cell_metadata

    def create_dataset(
        self,
        tokenized_cells,
        cell_metadata,
        use_generator=False,
        keep_uncropped_input_ids=False,
    ):
        print("Creating dataset.")

        dataset_dict = {"input_ids": tokenized_cells}
        if self.custom_attr_name_dict is not None:
            dataset_dict.update(cell_metadata)

        # Create the Hugging Face dataset, either from a generator (lower memory) or
        # directly from the in-memory dict.
        if use_generator:

            def dict_generator():
                for i in range(len(tokenized_cells)):
                    yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}

            output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
        else:
            output_dataset = Dataset.from_dict(dataset_dict)

        def format_cell_features(example):
            # Optionally keep the full rank value encoding before cropping.
            if keep_uncropped_input_ids:
                example["input_ids_uncropped"] = example["input_ids"]
                example["length_uncropped"] = len(example["input_ids"])

            # Crop the rank value encoding to at most 2048 genes and record its length.
            example["input_ids"] = example["input_ids"][0:2048]
            example["length"] = len(example["input_ids"])

            return example

        output_dataset_truncated = output_dataset.map(
            format_cell_features, num_proc=self.nproc
        )
        return output_dataset_truncated
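
# Illustrative sketch (assumes the user-supplied directory and prefix from the module
# docstring): the saved output can be reloaded with the Hugging Face datasets API, e.g.
#
#   from datasets import load_from_disk
#   dataset = load_from_disk("output_directory/output_prefix.dataset")
#   print(dataset[0]["input_ids"][:10], dataset[0]["length"])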