|
import os

# Switch the process working directory to the folder that contains this file.
# os.path.split(...)[0] is the directory part of the absolute script path.
current_directory = os.path.split(os.path.abspath(__file__))[0]
os.chdir(current_directory)
|
|
|
import torch
import torch.nn as nn
from torch.nn import functional as F

# Pick the compute device once: use CUDA when PyTorch reports it as
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
|
|
from tokenizer import Tokenizer

# Build the project tokenizer with its default configuration.
tokenizer = Tokenizer()

# The value returned by get_vocab() is passed straight to the Transformer
# constructor below as its vocabulary argument.
# NOTE(review): presumably this is the vocabulary size (an int) - confirm
# against tokenizer.py.
vocab_size = tokenizer.get_vocab()
|
|
|
from model import Transformer

# Build the network and restore its weights from the saved checkpoint.
model = Transformer(vocab_size)
checkpoint_path = '/content/drive/MyDrive/base-500m.pth'

# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written from a CUDA run can still be loaded on a CPU-only host.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)

# The checkpoint is passed directly to load_state_dict, i.e. the file is
# expected to contain a bare state dict.
model.load_state_dict(checkpoint)

m = model.to(device)

# This script only generates text; eval mode switches layers such as dropout
# and batch norm (if the model has any) to their inference behaviour.
m.eval()
|
|
|
class Generate:
    """Sampling helpers built around a trained Transformer.

    By default the instance wraps the module-level model ``m``; another
    model object can be supplied through the optional ``model`` argument.
    """

    def __init__(self, model=None):
        """
        set up the generator

        Args:
            - model (optional): model to generate with; defaults to the
                module-level model `m`
        """
        # Fall back to the global checkpoint-loaded model so that the plain
        # `Generate()` call keeps working.
        self.model = m if model is None else model
        self.vocab_size = vocab_size
        self.block_size = self.model.block_size

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
        """
        generate new tokens using the trained model

        Args:
            - idx (Tensor): input tensor representing initial token indices;
                indexed as (batch, time). Only a single sequence is supported
                because each sampled token is converted with `.item()`
            - max_new_tokens (int): max no of new tokens to generate
            - temperature (float): softmax temperature for sampling, must be > 0
            - top_k (int): no of top tokens to consider in sampling
                (0 disables top-k filtering)

        Returns:
            - generated_tokens (list): list of generated token indices

        Raises:
            - ValueError: if `temperature` is not positive
        """
        if temperature <= 0:
            raise ValueError("temperature must be a positive number")

        generated_tokens = []

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Keep only the most recent `block_size` tokens as context.
                idx_cond = idx[:, -self.block_size:]

                # Call the model itself; this class defines no __call__, so
                # `self(idx_cond)` would raise TypeError. The second element
                # of the model's output is not used here.
                logits, _unused = self.model(idx_cond)
                logits = logits[:, -1, :]  # logits of the last position only

                scaled_logits = logits / temperature
                if top_k > 0:
                    scaled_logits = self._top_k_filtering(scaled_logits, top_k)

                probs = F.softmax(scaled_logits, dim=-1)
                sampled_idx = torch.multinomial(probs, num_samples=1)
                generated_tokens.append(sampled_idx.item())

                # Feed the sampled token back in for the next step.
                idx = torch.cat((idx, sampled_idx), dim=1)

        return generated_tokens

    def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0,
                               mask_token_id=6):
        """
        Generate predictions for masked tokens using the trained model.

        The prediction is the arg-max over the (scaled, optionally top-k
        filtered) logits, so for a positive temperature neither `temperature`
        nor `top_k` changes which token is returned.

        Args:
            - idx (Tensor): input tensor representing token indices, shape (B, T)
            - masked_indices (Tensor): tensor of indices indicating masked
                positions; used directly to index the hidden states and logits
            - temperature (float): softmax temperature, must be > 0
            - top_k (int): no of top tokens to keep (0 disables filtering)
            - mask_token_id (int): token id whose embedding is written into
                the masked positions (default 6, the previously hard-coded value)

        Returns:
            - predicted_tokens (Tensor): tensor of predicted token indices

        Raises:
            - ValueError: if `temperature` is not positive
        """
        if temperature <= 0:
            raise ValueError("temperature must be a positive number")

        model = self.model
        _, T = idx.shape

        with torch.no_grad():
            token_emb = model.toked_model(idx)
            position_emb = model.pos_encod(torch.arange(T, device=idx.device))
            x = token_emb + position_emb

            # NOTE(review): every layer is applied to the same `x` and only
            # the last layer's output survives; the layers are not chained.
            # Kept as-is because the model's own forward pass is not visible
            # here - confirm this matches model.py.
            for layer in model.enc_layer:
                x_out = layer(x)

            for layer in model.dec_layer:
                x_final = layer(x, x_out)

            # Overwrite the hidden states at the masked positions with the
            # embedding of the mask token.
            x_masked = x_final.clone()
            mask_ids = torch.tensor([mask_token_id], device=idx.device)
            x_masked[masked_indices] = model.toked_model(mask_ids)

            x_masked = model.norm_final(x_masked)
            logits = model.linear_final(x_masked)

            # One row of vocabulary logits per masked position.
            masked_logits = logits[masked_indices].view(-1, logits.size(-1))
            scaled_logits = masked_logits / temperature
            if top_k > 0:
                scaled_logits = self._top_k_filtering(scaled_logits, top_k)

            probs = F.softmax(scaled_logits, dim=-1)
            predicted_indices = torch.argmax(probs, dim=-1)

        return predicted_indices

    def _top_k_filtering(self, logits, top_k):
        """
        filter logits to keep only the top-k tokens

        Entries smaller than the k-th largest value along the last dimension
        are replaced with -inf; entries tied with the k-th value are kept.

        Args:
            - logits (Tensor): input tensor of logits, vocabulary on the last dim
            - top_k (int): no of top tokens to keep; values larger than the
                vocabulary size keep everything, values <= 0 disable filtering

        Returns:
            - filtered_logits (Tensor): filtered logits with only top-k tokens remaining
        """
        if top_k <= 0:
            return logits

        # torch.topk raises when k exceeds the dimension size, so clamp it.
        k = min(top_k, logits.size(-1))
        top_values, _ = torch.topk(logits, k, dim=-1)

        # k-th largest value per row, kept as a trailing dim for broadcasting.
        threshold = top_values[..., -1:]
        return logits.masked_fill(logits < threshold, float('-inf'))
|
|
|
generator = Generate()

# Demo run: encode a prompt, sample 50 continuation tokens, print the result.
target_text = "I was in the market when"
prompt_ids = tokenizer.encode(target_text)

# Wrap the ids in an outer list so the tensor has a leading batch dimension.
context = torch.tensor([prompt_ids], dtype=torch.long, device=device)

new_token_ids = generator.generate(context, max_new_tokens=50)
generated_output = tokenizer.decode(new_token_ids)

print(target_text, generated_output)