# gpt2-small-theseus-bg / modeling_gpt2.py
# Source: Hugging Face model repository (uploaded by rmihaylov, commit 1812276, "add model").
from torch import nn
from transformers import GPT2LMHeadModel as GPT2LMHeadModelBase
from transformers.models.gpt2.modeling_gpt2 import GPT2Block as GPT2BlockBase
class GPT2Block(GPT2BlockBase):
    """GPT-2 transformer block with a simplified forward pass.

    NOTE(review): the attention residual is added to the *post*-``ln_1``
    tensor rather than the raw block input, which deviates from the stock
    pre-LN scheme in upstream ``transformers`` — presumably intentional for
    this theseus-compressed checkpoint; confirm against the training code.
    The ``encoder_hidden_states``, ``encoder_attention_mask`` and
    ``output_attentions`` arguments are accepted for signature compatibility
    but never forwarded to the attention module.
    """

    def forward(self, x, layer_past=None,
                attention_mask=None, head_mask=None, use_cache=False,
                encoder_hidden_states=None, encoder_attention_mask=None, output_attentions=None):
        # Normalize first; the normalized tensor also serves as the residual.
        hidden = self.ln_1(x)
        attn_outputs = self.attn(
            hidden,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        # attn_outputs[0] is the attention output; the tail carries the
        # present key/value cache (and attentions, if the module emits them).
        hidden = hidden + attn_outputs[0]
        # Feed-forward sub-layer with its own residual connection.
        hidden = hidden + self.mlp(self.ln_2(hidden))
        # Same tuple shape as the original: (hidden_states, *extras).
        return (hidden,) + attn_outputs[1:]
class GPT2LMHeadModel(GPT2LMHeadModelBase):
    """GPT-2 LM head model whose transformer layers use the local GPT2Block.

    Builds the base model normally, then swaps every stock block in
    ``transformer.h`` for the simplified ``GPT2Block`` variant so the
    modified forward pass is used end to end.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the stack layer-by-layer; state-dict weights load by
        # position, so the count must stay exactly config.n_layer.
        replacement_blocks = [
            GPT2Block(config, idx) for idx in range(config.n_layer)
        ]
        self.transformer.h = nn.ModuleList(replacement_blocks)