|
from transformers import PreTrainedModel |
|
from .configuration_ss4m import SimpleStories4MConfig |
|
from .nano_gpt_model import NanoGPT |
|
|
|
class SimpleStories4MModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a ``NanoGPT`` network.

    The architecture hyperparameters are read from a ``SimpleStories4MConfig``
    and handed to the ``NanoGPT`` constructor as a single plain dict; the
    resulting network is stored on ``self.model`` and ``forward`` delegates
    to it.
    """

    config_class = SimpleStories4MConfig

    def __init__(self, config):
        """Initialise the base class and build the wrapped ``NanoGPT``.

        Args:
            config: a ``SimpleStories4MConfig``. Its ``vocab_size``,
                ``block_size``, ``n_embed``, ``n_heads``, ``n_layers`` and
                ``dropout`` attributes are copied into the dict passed to
                ``NanoGPT``.
        """
        super().__init__(config)

        # NanoGPT is configured from a dict whose keys are the same names
        # the config uses for these attributes, in this order.
        field_names = (
            "vocab_size",
            "block_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "dropout",
        )
        gpt_settings = {name: getattr(config, name) for name in field_names}
        self.model = NanoGPT(gpt_settings)

    def forward(self, tensor, targets=None):
        """Run the wrapped ``NanoGPT`` on the given input.

        Args:
            tensor: model input, passed straight through to ``NanoGPT``
                (presumably a batch of token ids — TODO confirm against
                ``NanoGPT``).
            targets: optional second argument, passed through unchanged
                (presumably labels used to compute a loss — TODO confirm).

        Returns:
            Whatever ``NanoGPT`` returns for ``(tensor, targets)``.
        """
        wrapped = self.model
        return wrapped(tensor, targets)