import os
from typing import Dict

from Bio import SeqIO
from datasets import Dataset
from datasets.dataset_dict import DatasetDict

from protein_lm.modeling.getters.ptm_dataset import DatasetConfig, train_val_test_split


def read_fasta_file(fasta_file_path, subsample_size):
    """Parse a FASTA file into parallel lists of record ids and sequences.

    If subsample_size is truthy, parsing stops after that many records.
    """
    ids = []
    seqs = []
    with open(fasta_file_path, "r") as fasta_file:
        for i, record in enumerate(SeqIO.parse(fasta_file, "fasta")):
            if subsample_size and i >= subsample_size:
                break
            ids.append(record.id)
            seqs.append(str(record.seq))
    return {"id": ids, "seq": seqs}


def load_uniref_dataset(seq_dict, config) -> DatasetDict:
    """Wrap the parsed sequences in a Dataset and split into train/val/test."""
    ds = Dataset.from_dict(seq_dict)
    ds_dict = DatasetDict({"train": ds})
    return train_val_test_split(ds_dict, config)


def seq2token(batch, tokenizer, sequence_column_name, max_sequence_length):
    """Tokenize a batch of sequences; meant for use via Dataset.map(batched=True).

    Assumes the project tokenizer accepts add_special_tokens and
    max_sequence_length keyword arguments.
    """
    batch["input_ids"] = tokenizer(
        batch[sequence_column_name],
        add_special_tokens=True,
        max_sequence_length=max_sequence_length,
    )
    return batch


def get_uniref_dataset(config_dict: Dict, tokenizer) -> DatasetDict:
    """Build (or load from cache) tokenized train/val/test splits from a UniRef FASTA file."""
    config = DatasetConfig(**config_dict)
    if config.cache_dir is not None and os.path.exists(config.cache_dir):
        # Reuse a previously tokenized and split dataset.
        return DatasetDict.load_from_disk(config.cache_dir)
    seq_dict = read_fasta_file(config.dataset_loc, config.subsample_size)
    split_dict = load_uniref_dataset(seq_dict, config)
    split_dict = split_dict.map(
        lambda e: seq2token(
            batch=e,
            tokenizer=tokenizer,
            sequence_column_name="seq",
            max_sequence_length=config.max_sequence_length,
        ),
        batched=True,
    )
    if config.cache_dir is not None:
        os.makedirs(config.cache_dir, exist_ok=True)
        split_dict.save_to_disk(config.cache_dir)
    return split_dict
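

# A minimal usage sketch (commented out, not part of the module). The config
# keys below mirror the attributes accessed above (dataset_loc, subsample_size,
# max_sequence_length, cache_dir); any further fields required by
# DatasetConfig / train_val_test_split, and the AptTokenizer import path, are
# assumptions rather than confirmed API.
#
# from protein_lm.tokenizer.tokenizer import AptTokenizer
#
# splits = get_uniref_dataset(
#     config_dict={
#         "dataset_loc": "data/uniref50.fasta",
#         "subsample_size": 1000,
#         "max_sequence_length": 1024,
#         "cache_dir": "cache/uniref",
#     },
#     tokenizer=AptTokenizer(),
# )
# print(splits)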