from torch import nn

from transformers import GPT2LMHeadModel as GPT2LMHeadModelBase
from transformers.models.gpt2.modeling_gpt2 import GPT2Block as GPT2BlockBase


class GPT2Block(GPT2BlockBase):
    """GPT-2 transformer block with the pre-LayerNorm forward pass written out explicitly."""

    def forward(self, x, layer_past=None,
                attention_mask=None, head_mask=None, use_cache=False,
                encoder_hidden_states=None, encoder_attention_mask=None,
                output_attentions=False):
        # Self-attention sub-block: LayerNorm, attention, then residual connection.
        residual = x
        x = self.ln_1(x)
        output_attn = self.attn(
            x, layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions)
        a = output_attn[0]  # attention output
        x = residual + a

        # Feed-forward sub-block: LayerNorm, MLP, then residual connection.
        m = self.mlp(self.ln_2(x))
        x = x + m

        # Mirror the base class's return convention: keep `present` only when caching.
        if use_cache:
            outputs = (x,) + output_attn[1:]
        else:
            outputs = (x,) + output_attn[2:]
        return outputs


class GPT2LMHeadModel(GPT2LMHeadModelBase):
    """GPT2LMHeadModel whose transformer layers are replaced by the GPT2Block defined above."""

    def __init__(self, config):
        super().__init__(config)
        # Swap in the custom blocks. Parameter names and shapes match the stock
        # GPT2Block, so pretrained checkpoints still load via `from_pretrained`.
        self.transformer.h = nn.ModuleList(
            [GPT2Block(config, layer_idx) for layer_idx in range(config.n_layer)])
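

# A minimal usage sketch (illustrative only, not part of the original module):
# it assumes the standard Hugging Face "gpt2" checkpoint and tokenizer, and
# simply runs one forward pass through the patched model.
if __name__ == "__main__":
    import torch
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    inputs = tokenizer("Hello, world", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print(logits.shape)  # (batch_size, sequence_length, vocab_size)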