# annt's picture
# Add application file
# daeb223
from dataclasses import dataclass, field
from typing import Optional

from .. import models
@dataclass
class RetroDataModelArguments:
    """Field-less root dataclass shared by the argument classes in this module."""
@dataclass
class DataArguments(RetroDataModelArguments):
    """Data-related settings, declared as dataclass fields with defaults.

    Every field carries an empty ``help`` string in its metadata. The
    section comments below group the fields by what their names suggest;
    the code that consumes them is not visible in this module, so treat the
    groupings as hints to confirm rather than as a specification.
    """

    # Presumably input tokenization / windowing -- TODO confirm against callers.
    max_seq_length: int = field(default=512, metadata={"help": ""})
    max_answer_length: int = field(default=30, metadata={"help": ""})
    doc_stride: int = field(default=128, metadata={"help": ""})
    return_token_type_ids: bool = field(default=True, metadata={"help": ""})
    pad_to_max_length: bool = field(default=True, metadata={"help": ""})

    # Presumably dataset preprocessing controls -- TODO confirm.
    preprocessing_num_workers: int = field(default=5, metadata={"help": ""})
    overwrite_cache: bool = field(default=False, metadata={"help": ""})

    # Presumably answer post-processing / thresholds -- TODO confirm.
    version_2_with_negative: bool = field(default=True, metadata={"help": ""})
    null_score_diff_threshold: float = field(default=0.0, metadata={"help": ""})
    rear_threshold: float = field(default=0.0, metadata={"help": ""})
    n_best_size: int = field(default=20, metadata={"help": ""})
    use_choice_logits: bool = field(default=False, metadata={"help": ""})
    start_n_top: int = field(default=-1, metadata={"help": ""})
    end_n_top: int = field(default=-1, metadata={"help": ""})

    # NOTE(review): meaning of these coefficients is not shown here; they are
    # declared as ints with a default of 1.
    beta1: int = field(default=1, metadata={"help": ""})
    beta2: int = field(default=1, metadata={"help": ""})
    best_cof: int = field(default=1, metadata={"help": ""})
@dataclass
class ModelArguments(RetroDataModelArguments):
    """Options common to both model argument groups defined in this module."""

    # Off by default; the help text is intentionally left as the original
    # empty string.
    use_auth_token: bool = field(default=False, metadata={"help": ""})
@dataclass
class SketchModelArguments(ModelArguments):
    """Settings that select the "sketch" model.

    ``sketch_tokenizer_name`` may be left as ``None``; its annotation is
    ``Optional[str]`` so that the declared type agrees with that default.
    ``sketch_architectures`` holds the name of a class that is looked up as
    an attribute of the project's ``models`` package.
    """

    sketch_revision: str = field(
        default="main",
        metadata={"help": ""},
    )
    sketch_model_name: str = field(
        default="monologg/koelectra-small-v3-discriminator",
        metadata={"help": ""},
    )
    # None means "not specified"; the annotation is Optional to match.
    sketch_tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": ""},
    )
    sketch_architectures: str = field(
        default="ElectraForSequenceClassification",
        metadata={"help": ""},
    )
@dataclass
class IntensiveModelArguments(ModelArguments):
    """Settings that select the "intensive" model.

    ``intensive_tokenizer_name`` may be left as ``None``; its annotation is
    ``Optional[str]`` so that the declared type agrees with that default.
    ``intensive_architectures`` holds the name of a class that is looked up
    as an attribute of the project's ``models`` package.
    """

    intensive_revision: str = field(
        default="main",
        metadata={"help": ""},
    )
    intensive_model_name: str = field(
        default="monologg/koelectra-small-v3-discriminator",
        metadata={"help": ""},
    )
    # None means "not specified"; the annotation is Optional to match.
    intensive_tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": ""},
    )
    intensive_architectures: str = field(
        default="ElectraForQuestionAnsweringAVPool",
        metadata={"help": ""},
    )
@dataclass
class RetroArguments(
    DataArguments,
    SketchModelArguments,
    IntensiveModelArguments,
):
    """Combined data, sketch-model and intensive-model arguments.

    After dataclass initialisation, ``__post_init__`` resolves both
    architecture names to classes from the ``models`` package and stores,
    for each of the two models:

    * ``<prefix>_model_cls``  -- the resolved class,
    * ``<prefix>_model_type`` -- that class's ``model_type`` attribute,

    and replaces a ``None`` tokenizer name with the corresponding model name.
    """

    def __post_init__(self):
        """Resolve model classes and fill in default tokenizer names.

        Raises:
            AttributeError: if ``sketch_architectures`` or
                ``intensive_architectures`` does not name an attribute of
                the ``models`` package.
        """
        # Sketch
        sketch_cls = self._lookup_model_cls(self.sketch_architectures)
        self.sketch_model_cls = sketch_cls
        self.sketch_model_type = sketch_cls.model_type
        if self.sketch_tokenizer_name is None:
            # No explicit tokenizer given: reuse the model name.
            self.sketch_tokenizer_name = self.sketch_model_name

        # Intensive
        intensive_cls = self._lookup_model_cls(self.intensive_architectures)
        self.intensive_model_cls = intensive_cls
        self.intensive_model_type = intensive_cls.model_type
        if self.intensive_tokenizer_name is None:
            # No explicit tokenizer given: reuse the model name.
            self.intensive_tokenizer_name = self.intensive_model_name

    @staticmethod
    def _lookup_model_cls(architecture):
        """Return the attribute of ``models`` named *architecture*.

        Raises ``AttributeError`` with a descriptive message when the name is
        missing (or resolves to ``None``), so a misconfigured architecture is
        easy to identify.
        """
        model_cls = getattr(models, architecture, None)
        if model_cls is None:
            raise AttributeError(
                f"module {models.__name__!r} has no model class "
                f"named {architecture!r}"
            )
        return model_cls