# gpt2-story-tinetuned / story / train-small.py
import os
import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
# Disable the MPS high-watermark limit so PyTorch may use all available unified memory
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
class PromptDataset(Dataset):
    """Prompt-only dataset: each non-empty line of the data file is split on the
    [PAD] separator, only the text before the first [PAD] is kept (lines with
    fewer than three parts are skipped), an EOS token is appended, and the
    tokenized result is chunked into block_size pieces."""

    def __init__(self, file_path, tokenizer, block_size=256):
        self.input_examples = []
        with open(file_path, 'r', encoding="utf-8", errors="replace") as f:
            text = f.read()
        lines = text.splitlines()
        printed_sample = False
        for line in lines:
            if line.strip():
                parts = line.split('[PAD]')
                if len(parts) >= 3:
                    # Only keep the part up to the first [PAD]
                    input_part = parts[0].strip() + tokenizer.eos_token
                    if not printed_sample:
                        print(input_part)  # show one example for a quick sanity check
                        printed_sample = True
                    tokenized_input = tokenizer.encode(input_part, add_special_tokens=True, truncation=True)
                    # Split sequences longer than the block size into block_size chunks
                    for i in range(0, len(tokenized_input), block_size):
                        input_chunk = tokenized_input[i:i + block_size]
                        self.input_examples.append(torch.tensor(input_chunk))

    def __len__(self):
        return len(self.input_examples)

    def __getitem__(self, i):
        return self.input_examples[i]
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
print(tokenizer.eos_token)
# GPT-2 has no pad token by default, so reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token
dataset = PromptDataset("batch_ds_v2.txt", tokenizer)
print(f"Number of examples: {len(dataset)}")
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Apple-silicon GPU via MPS; fall back to CPU if MPS is unavailable
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
training_args = TrainingArguments(
    lr_scheduler_type="cosine",
    run_name="small-1of3_v3",
    output_dir="./small",
    overwrite_output_dir=True,
    num_train_epochs=15,  # 3
    max_steps=5000,  # 500
    save_steps=1000,
    auto_find_batch_size=True,
    learning_rate=1e-4,
    max_grad_norm=1.0,
    logging_steps=1,
)
def data_collator(features):
    # Right-pad each example to the longest sequence in the batch using the pad
    # token (set to EOS above). GPT2LMHeadModel shifts the labels internally for
    # the causal-LM loss, so labels can simply mirror input_ids.
    input_ids = torch.nn.utils.rnn.pad_sequence(features, batch_first=True, padding_value=tokenizer.pad_token_id)
    labels = input_ids.clone()
    #labels[labels == tokenizer.pad_token_id] = -100  # Set labels to -100 where input is [PAD] to ignore them in the loss
    return {"input_ids": input_ids, "labels": labels}
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)
trainer.train()
model.save_pretrained("./v2/small_2")
tokenizer.save_pretrained("./v2/small_2")
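# A minimal generation sketch (not part of the training run): it assumes the
# checkpoint saved above at ./v2/small_2 and uses an example prompt.
#
# model = GPT2LMHeadModel.from_pretrained("./v2/small_2").to(device)
# tokenizer = GPT2Tokenizer.from_pretrained("./v2/small_2")
# inputs = tokenizer("Once upon a time", return_tensors="pt").to(device)
# output_ids = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.9,
#                             pad_token_id=tokenizer.eos_token_id)
# print(tokenizer.decode(output_ids[0], skip_special_tokens=True))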