File size: 5,476 Bytes
8b76ef3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger from transformers' logging utilities (not referenced by the visible code below).
logger = logging.get_logger(__name__)
class CodeFuseCGESmallConfig(PretrainedConfig):
    """Configuration for the CodeFuse-CGE-Small model.

    Holds the Phi-3 style decoder hyper-parameters (``model_type = "phi3"``)
    plus a set of extra options (``embedding_method``, ``pma_*``,
    ``compress_dim``, ...) that are stored verbatim for the modeling code.
    NOTE(review): the meaning of the extra options is defined by the modeling
    code, which is not visible here; ``"pma"`` presumably refers to an
    attention-based pooling head — confirm against the model implementation.

    Args:
        vocab_size: Size of the token vocabulary.
        hidden_size: Dimension of the hidden representations.
        intermediate_size: Dimension of the MLP intermediate representations.
        num_hidden_layers: Number of decoder layers.
        num_attention_heads: Number of attention heads.
        num_key_value_heads: Number of key/value heads. ``None`` means "same as
            ``num_attention_heads``".
        resid_pdrop: Residual dropout probability.
        embd_pdrop: Embedding dropout probability.
        attention_dropout: Attention dropout probability.
        hidden_act: Name of the activation function.
        max_position_embeddings: Maximum sequence length.
        original_max_position_embeddings: Sequence length used in pre-training.
        initializer_range: Std-dev used for weight initialization.
        rms_norm_eps: Epsilon of the RMS normalization layers.
        use_cache: Whether the model should return past key/values.
        tie_word_embeddings: Forwarded to ``PretrainedConfig``.
        rope_theta: Base period of the rotary position embeddings.
        rope_scaling: ``None`` or a dict with exactly the three keys ``type``,
            ``short_factor`` and ``long_factor``. Legacy ``type`` values
            ``"su"`` / ``"yarn"`` are mapped to ``"longrope"``. Both factor
            lists must contain ``hidden_size // num_attention_heads // 2``
            numbers. The caller's dict is never modified.
        bos_token_id, eos_token_id, pad_token_id: Forwarded to
            ``PretrainedConfig``.
        sliding_window: Stored as-is.
        embedding_method, inf_seq_length, padding_side, compress_dim,
        keep_max_layer, pma_num_heads, pma_ln, pma_norm, pma_norm_mode:
            Embedding-head options, stored as-is.
        **kwargs: Forwarded to ``PretrainedConfig``.

    Raises:
        ValueError: If ``rope_scaling`` is not ``None`` and is malformed.
    """

    model_type = "phi3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32064,
        hidden_size=3072,
        intermediate_size=8192,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attention_dropout=0.0,
        hidden_act="silu",
        max_position_embeddings=4096,
        original_max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        bos_token_id=1,
        eos_token_id=32000,
        pad_token_id=32000,
        sliding_window=None,
        embedding_method="pma",
        inf_seq_length=1024,
        padding_side="right",
        compress_dim=1024,
        keep_max_layer=32,
        pma_num_heads=32,
        pma_ln=True,
        pma_norm=False,
        pma_norm_mode="post_normal",
        **kwargs,
    ):
        # Decoder architecture.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Default: one key/value head per attention head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attention_dropout = attention_dropout
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # Rotary embeddings. Validation needs hidden_size and
        # num_attention_heads, so it must run after they are assigned above.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_adjustment()
        self._rope_scaling_validation()
        self.sliding_window = sliding_window

        # Embedding-head options (stored verbatim for the modeling code).
        self.embedding_method = embedding_method
        self.inf_seq_length = inf_seq_length
        self.padding_side = padding_side
        self.compress_dim = compress_dim
        self.keep_max_layer = keep_max_layer
        self.pma_num_heads = pma_num_heads
        self.pma_ln = pma_ln
        self.pma_norm = pma_norm
        self.pma_norm_mode = pma_norm_mode

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_adjustment(self):
        """
        Adjust the `type` of the `rope_scaling` configuration for backward compatibility.

        Previous versions used ``"su"`` or ``"yarn"``; both are rewritten to
        ``"longrope"``. A new dict is assigned rather than mutating the one the
        caller passed in. Non-dict values are left alone so that
        ``_rope_scaling_validation`` rejects them with a ``ValueError``
        instead of this method failing with an ``AttributeError``.
        """
        if not isinstance(self.rope_scaling, dict):
            return
        if self.rope_scaling.get("type", None) in ("su", "yarn"):
            # Copy-on-write: keeps the caller's (possibly shared) dict intact.
            self.rope_scaling = {**self.rope_scaling, "type": "longrope"}

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: If ``rope_scaling`` is not a dict with exactly three
                keys, if its ``type`` is not ``"longrope"``, or if
                ``short_factor`` / ``long_factor`` is not a list of numbers of
                length ``hidden_size // num_attention_heads // 2``.
        """
        if self.rope_scaling is None:
            return
        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
            raise ValueError(
                "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")

        # One scaling factor per rotary frequency, i.e. half the head dimension.
        expected_len = self.hidden_size // self.num_attention_heads // 2
        # Same check order as before: short (type, length), then long (type, length).
        for field_name in ("short_factor", "long_factor"):
            factor = self.rope_scaling.get(field_name, None)
            if not (isinstance(factor, list) and all(isinstance(x, (int, float)) for x in factor)):
                raise ValueError(f"`rope_scaling`'s {field_name} field must be a list of numbers, got {factor}")
            if len(factor) != expected_len:
                raise ValueError(
                    f"`rope_scaling`'s {field_name} field must have length {expected_len}, got {len(factor)}"
                )
|