|
import os
from typing import Dict, Optional

from Bio import SeqIO
from datasets import Dataset
from datasets.dataset_dict import DatasetDict

from protein_lm.modeling.getters.ptm_dataset import DatasetConfig, train_val_test_split

|
def read_fasta_file(fasta_file_path: str, subsample_size: Optional[int]) -> Dict[str, list]:
    """Parse a FASTA file into parallel id/sequence lists.

    If subsample_size is set, stop after that many records.
    """
    ids = []
    seqs = []
    with open(fasta_file_path, "r") as fasta_file:
        for i, record in enumerate(SeqIO.parse(fasta_file, "fasta")):
            if subsample_size is not None and i >= subsample_size:
                break
            ids.append(record.id)
            seqs.append(str(record.seq))
    return {"id": ids, "seq": seqs}

|
def load_uniref_dataset(seq_dict: Dict[str, list], config: DatasetConfig) -> DatasetDict:
    """Wrap the parsed sequences in a Dataset and split it into train/val/test."""
    ds = Dataset.from_dict(seq_dict)
    ds_dict = DatasetDict({"train": ds})
    return train_val_test_split(ds_dict, config)

|
def seq2token(batch, tokenizer, sequence_column_name, max_sequence_length):
    """Tokenize a batch of sequences, storing the result in an input_ids column."""
    # The project tokenizer is expected to return token ids directly
    # (not a BatchEncoding), capped at max_sequence_length.
    batch["input_ids"] = tokenizer(
        batch[sequence_column_name],
        add_special_tokens=True,
        max_sequence_length=max_sequence_length,
    )
    return batch

|
def get_uniref_dataset(config: DatasetConfig, tokenizer) -> DatasetDict:
    """Load, split, and tokenize a UniRef FASTA dataset, caching the result on disk."""
    # Reuse a previously tokenized dataset if a cache directory already exists.
    if config.cache_dir is not None and os.path.exists(config.cache_dir):
        split_dict = DatasetDict.load_from_disk(config.cache_dir)
        return split_dict

    seq_dict = read_fasta_file(config.dataset_loc, config.subsample_size)
    split_dict = load_uniref_dataset(seq_dict, config)
    split_dict = split_dict.map(
        lambda e: seq2token(
            batch=e,
            tokenizer=tokenizer,
            sequence_column_name="seq",
            max_sequence_length=config.max_sequence_length,
        ),
        batched=True,
    )
    if config.cache_dir is not None:
        os.makedirs(config.cache_dir, exist_ok=True)
        split_dict.save_to_disk(config.cache_dir)
    return split_dict
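

# Usage sketch (illustrative only, not part of the module API). The
# SimpleNamespace below stands in for DatasetConfig and carries the fields
# this file reads, plus assumed val_size/test_size split fields consumed by
# train_val_test_split; the AptTokenizer import path, the FASTA path, and the
# field values are placeholders, not verified against the real config schema.
if __name__ == "__main__":
    from types import SimpleNamespace

    from protein_lm.tokenizer.tokenizer import AptTokenizer  # assumed module path

    demo_config = SimpleNamespace(
        dataset_loc="data/uniref50.fasta",  # placeholder FASTA path
        subsample_size=100,                 # keep the smoke test small
        max_sequence_length=1024,
        cache_dir=None,                     # set a path to enable on-disk caching
        val_size=10,                        # assumed split-size fields
        test_size=10,
    )
    splits = get_uniref_dataset(config=demo_config, tokenizer=AptTokenizer())
    print({name: len(split) for name, split in splits.items()})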
|
|