|
import pandas as pd |
|
from typing import Dict, Literal, Optional |
|
from datasets import Dataset, load_dataset |
|
from datasets.dataset_dict import DatasetDict |
|
from pydantic import BaseModel |
|
from protein_lm.modeling.getters.ptm_dataset import get_ptm_dataset |
|
from protein_lm.modeling.getters.uniref_dataset import get_uniref_dataset |
|
|
|
|
|
def get_dataset(config_dict: Dict, tokenizer) -> Dataset:
    """Build the dataset selected by ``config_dict["dataset"]``.

    Picks a builder based on the ``"dataset"`` entry and calls it with the
    full config mapping and the tokenizer, both forwarded unchanged.

    Args:
        config_dict: Configuration mapping. Its ``"dataset"`` entry selects
            the builder: ``"ptm"`` uses ``get_ptm_dataset`` and
            ``"uniref50"`` uses ``get_uniref_dataset``. The whole mapping is
            passed through to the chosen builder.
        tokenizer: Forwarded as-is to the chosen builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        KeyError: If ``config_dict`` has no ``"dataset"`` entry.
        ValueError: If the dataset name is not one of the supported names.
    """
    dataset_name = config_dict["dataset"]

    # Choose the builder first so there is a single call site below.
    # The builders are referenced only on their own branch, matching the
    # original's lazy lookup of these names.
    if dataset_name == "ptm":
        builder = get_ptm_dataset
    elif dataset_name == "uniref50":
        builder = get_uniref_dataset
    else:
        raise ValueError(f"Invalid dataset {dataset_name}!")

    return builder(config_dict, tokenizer)
|
|