import os

import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments

# Lift the MPS allocator's high-watermark cap (0.0 = no upper limit) so training
# on Apple-silicon GPUs is not aborted when memory use nears the default ceiling.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

class StoryPromptDataset(Dataset):
    def __init__(self, file_path, tokenizer, block_size=450):
        self.examples = []
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            text = f.read()

        printed_sample = False
        for line in text.splitlines():
            if not line.strip():
                continue
            story_part = line.split('[PAD]')

            # Field 0 is the story and field 3 the abilities string, so a
            # usable line needs at least four '[PAD]'-separated fields;
            # indexing story_part[3] with fewer would raise an IndexError.
            if len(story_part) >= 4:
                story = story_part[0].strip()
                abilities = story_part[3].strip()

                story_with_end = story + " <|endoftext|>"
                combined = story_with_end + " " + abilities
                if not printed_sample:
                    printed_sample = True
                    print("\n" + combined + "\n\n")

                tokenized_combined = tokenizer.encode(combined, add_special_tokens=True)
                # Truncate to block_size so no example exceeds the intended context window.
                self.examples.append(torch.tensor(tokenized_combined[:block_size]))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]
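
# Each nonblank line of the training file is assumed to carry at least four
# '[PAD]'-separated fields; only field 0 (the story) and field 3 (the
# abilities) are used. A hypothetical example line (not taken from the data):
#   A knight guards the old bridge... [PAD] field1 [PAD] field2 [PAD] haste, stone skin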

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 ships without a pad token; reusing <|endoftext|> is the usual workaround.
tokenizer.pad_token = tokenizer.eos_token

dataset = StoryPromptDataset("batch_ds_v2.txt", tokenizer)
print(f"Number of examples: {len(dataset)}")

model = GPT2LMHeadModel.from_pretrained("gpt2")
# Prefer Apple's Metal backend when available, otherwise fall back to CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model.to(device)

training_args = TrainingArguments(
    output_dir="./v2/midjourney/small",
    overwrite_output_dir=True,
    run_name="small-4of4",
    lr_scheduler_type="cosine",
    learning_rate=1e-4,
    num_train_epochs=15,
    max_steps=5000,          # when set, max_steps overrides num_train_epochs
    save_steps=1000,
    auto_find_batch_size=True,
    max_grad_norm=1.0,
    logging_steps=1,
)
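
# With save_steps=1000, the Trainer writes a checkpoint to
# output_dir/checkpoint-<step> every 1000 optimizer steps.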

def data_collator(features):
    # Right-pad to the batch max. Since pad == eos here, build the attention
    # mask from sequence lengths so the real <|endoftext|> is still trained on.
    input_ids = torch.nn.utils.rnn.pad_sequence(features, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_mask = torch.stack([torch.arange(input_ids.size(1)) < len(f) for f in features]).long()
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # ignore padding positions in the loss
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

trainer.train()

# Persist the final weights and tokenizer alongside the Trainer checkpoints.
model.save_pretrained("./v2/midjourney/small")
tokenizer.save_pretrained("./v2/midjourney/small")
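
# Optional smoke test: reload the saved weights and sample one completion.
# A minimal sketch; the prompt string and sampling settings are illustrative
# assumptions, not values taken from the training setup.
ft_tokenizer = GPT2Tokenizer.from_pretrained("./v2/midjourney/small")
ft_model = GPT2LMHeadModel.from_pretrained("./v2/midjourney/small").to(device)
ft_model.eval()

prompt = "A lone ranger walks into the ruined keep"  # hypothetical story prompt
inputs = ft_tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = ft_model.generate(
        **inputs,
        max_new_tokens=80,
        do_sample=True,
        top_p=0.95,
        temperature=0.8,
        pad_token_id=ft_tokenizer.eos_token_id,
    )
print(ft_tokenizer.decode(output_ids[0], skip_special_tokens=False))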