import os
import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments

# Disable the MPS high-watermark limit so the allocator may use all available memory.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"


class PromptDataset(Dataset):
    """Reads prompt lines from a text file, keeps only the text before the first
    [PAD] marker, appends the EOS token, and splits the tokenized result into
    chunks of at most block_size tokens."""

    def __init__(self, file_path, tokenizer, block_size=256):
        self.input_examples = []
        with open(file_path, "r", encoding="utf-8", errors="replace") as f:
            text = f.read()
        for line in text.splitlines():
            if not line.strip():
                continue
            parts = line.split("[PAD]")
            if len(parts) >= 3:
                # Only keep the part before the first [PAD] marker.
                input_part = parts[0].strip()
                input_part += tokenizer.eos_token
                tokenized_input = tokenizer.encode(
                    input_part, add_special_tokens=True, truncation=True
                )
                # Split sequences longer than the block size into block_size chunks.
                for i in range(0, len(tokenized_input), block_size):
                    input_chunk = tokenized_input[i:i + block_size]
                    self.input_examples.append(torch.tensor(input_chunk))

    def __len__(self):
        return len(self.input_examples)

    def __getitem__(self, i):
        return self.input_examples[i]


tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
print(tokenizer.eos_token)
# GPT-2 has no dedicated pad token; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

dataset = PromptDataset("batch_ds_v2.txt", tokenizer)
print(f"Number of examples: {len(dataset)}")

model = GPT2LMHeadModel.from_pretrained("gpt2-medium")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

training_args = TrainingArguments(
    lr_scheduler_type="cosine",
    run_name="medium-1of3_v3",
    output_dir="./v2/medium_2",
    overwrite_output_dir=True,
    num_train_epochs=15,  # was 3; note that max_steps below takes precedence
    max_steps=500,
    save_steps=50,
    auto_find_batch_size=True,
    learning_rate=1e-4,
    max_grad_norm=1.0,
    logging_steps=1,
)


def data_collator(features):
    # Pad variable-length examples to the longest sequence in the batch.
    input_ids = torch.nn.utils.rnn.pad_sequence(
        features, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    labels = input_ids.clone()
    # Uncomment to ignore padded positions in the loss; since pad_token_id == eos_token_id
    # here, this would also mask the genuine EOS token.
    # labels[labels == tokenizer.pad_token_id] = -100
    return {"input_ids": input_ids, "labels": labels}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

trainer.train()

model.save_pretrained("./v2/medium_2")
tokenizer.save_pretrained("./v2/medium_2")
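
# --- Optional sanity check: generate from the saved checkpoint ---
# A minimal sketch, not part of the original script. It reloads the checkpoint saved
# above and generates one completion. The prompt text and the generation parameters
# (max_new_tokens, temperature, sampling) are illustrative assumptions, not values
# taken from the training setup.
ft_tokenizer = GPT2Tokenizer.from_pretrained("./v2/medium_2")
ft_model = GPT2LMHeadModel.from_pretrained("./v2/medium_2")
ft_model.eval()

prompt = "Example prompt"  # hypothetical prompt for illustration only
inputs = ft_tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = ft_model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.8,
        pad_token_id=ft_tokenizer.eos_token_id,
    )
print(ft_tokenizer.decode(output_ids[0], skip_special_tokens=True))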