|
import argparse |
|
import logging |
|
from torch.utils.data import Dataset, IterableDataset |
|
import gzip |
|
import json |
|
from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments |
|
import sys |
|
from datetime import datetime |
|
import torch |
|
import random |
|
from shutil import copyfile |
|
import os |
|
import wandb |
|
import random |
|
import re |
|
|
|
|
|
# Route all log records to stdout with timestamps; the root level is left
# at the logging default (WARNING).
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
|
|
|
# Command-line interface; the parsed `args` namespace is consumed throughout
# the module (wandb run name, dataset construction, trainer settings).
parser = argparse.ArgumentParser()
# Model / run identity.
parser.add_argument("--model_name", default="google/t5-v1_1-base")
parser.add_argument("--name", required=True)
# Training data: one or more gzipped JSONL files.
parser.add_argument("--train_files", required=True, nargs="+", default=[])
# Optimization schedule.
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=32)
# Tokenization limits (source = passage, target = query).
parser.add_argument("--max_source_length", type=int, default=320)
parser.add_argument("--max_target_length", type=int, default=64)
# Nominal dataset sizes; the training stream itself is infinite.
parser.add_argument("--train_size", type=int, default=10_000_000)
parser.add_argument("--eval_size", type=int, default=10000)
# Mixed-precision training toggle.
parser.add_argument("--fp16", default=False, action="store_true")
args = parser.parse_args()
|
|
|
# Track this run under the W&B "doc2query" project; the run name combines the
# CLI-provided name with the base model identifier.
wandb.init(project="doc2query", name=f"{args.name}-{args.model_name}")
|
|
|
|
|
|
|
|
|
class PairDataset:
    """Infinite stream of [target, input] pairs from a gzipped JSONL file.

    The first pass reads the file lazily, memoizing every parsed example in
    ``self.examples``; every later pass reshuffles that cache and serves it
    again, so iteration never terminates.
    """

    def __init__(self, filepath):
        self.filepath = filepath
        # Filled during the first pass over the file; reused forever after.
        self.examples = []

    def __iter__(self):
        print("open", self.filepath)
        # First pass: stream from disk while populating the cache.
        with gzip.open(self.filepath, 'rt') as fIn:
            for line in fIn:
                example = self.get_example(json.loads(line))
                if example is not None:
                    self.examples.append(example)
                    yield example

        # Subsequent passes: endlessly re-serve the cached examples,
        # reshuffled before each round.
        while True:
            random.shuffle(self.examples)
            yield from self.examples

    def get_example(self, raw_example):
        """Convert one parsed JSON record into a [target, input] pair.

        Dict records use the 'query' field as target and a random positive
        passage as input; list records are taken as-is (first two items).
        """
        if not isinstance(raw_example, dict):
            return [raw_example[0], raw_example[1]]
        return [raw_example['query'], random.choice(raw_example['pos'])]
|
|
|
|
|
class RedditTitleDataset(PairDataset):
    """Yields [title, body] pairs from Reddit dumps, with titles stripped of
    subreddit/tag markup and HTML-escaped ampersands."""

    def get_example(self, raw_example):
        # Title is the generation target; the post body is the input document.
        return [self.clean_title(raw_example['title']), raw_example['body']]

    def clean_title(self, text):
        """Strip Reddit-specific markup from a post title.

        Removes a leading "[tag]" marker, a trailing "[tag]" marker, and a
        leading "/r/subreddit:"-style prefix, and unescapes "&amp;".
        """
        # Bug fix: the original called text.replace("&", "&"), a no-op — the
        # intent is to undo HTML escaping of ampersands.
        text = text.replace("&amp;", "&").strip()
        # Raw strings so "\[" is a regex escape, not a Python escape
        # (non-raw "\[" raises SyntaxWarning on modern Python).
        if text.startswith("["):
            text = re.sub(r"^\[[a-zA-Z0-9]+\]", "", text).strip()

        if text.endswith("]"):
            text = re.sub(r"\[[a-zA-Z0-9\.]+\]$", "", text).strip()

        if text.startswith("/r"):
            text = re.sub(r"^/[a-zA-Z0-9/]+[;,: \-]+", "", text).strip()

        return text
|
|
|
|
|
class StackExchangeTitleBodyDataset(PairDataset):
    """StackExchange records already carry the [title, body] pair."""

    def get_example(self, raw_example):
        # The 'texts' field is already the [target, input] pair — pass through.
        return raw_example['texts']
|
|
|
|
|
class MultiDataset(IterableDataset):
    """Round-robin mixture over several pair datasets.

    Draws one example from each underlying dataset in turn and reshuffles
    the visiting order after every full round. Reports a fixed nominal
    length so the Trainer can derive steps per epoch, even though the
    underlying iterators never terminate.
    """

    def __init__(self, filepaths, num_samples):
        self.num_samples = num_samples
        self.datasets = []
        self.data_iterators = []

        for filepath in filepaths:
            # Choose the record parser by file origin, defaulting to the
            # generic [query, passage] format.
            if 'reddit_title_text' in filepath:
                dataset = RedditTitleDataset(filepath)
            elif 'stackexchange_archive/jsonl' in filepath:
                dataset = StackExchangeTitleBodyDataset(filepath)
            else:
                dataset = PairDataset(filepath)
            self.datasets.append(dataset)
            self.data_iterators.append(iter(dataset))

    def __len__(self):
        # Nominal size only — iteration itself is unbounded.
        return self.num_samples

    def __iter__(self):
        while True:
            for iterator in self.data_iterators:
                yield next(iterator)
            # Reshuffle the dataset visiting order between rounds.
            random.shuffle(self.data_iterators)

    def delete_examples_cache(self):
        # Drop the per-dataset example caches (e.g. after carving out an
        # eval split) so those examples are re-read from disk.
        for dataset in self.datasets:
            dataset.examples = []
|
|
|
|
|
|
|
def main():
    """Fine-tune a seq2seq model to generate queries from passages (doc2query).

    Builds the mixed training stream from ``args.train_files``, carves off a
    fixed eval split, and runs ``Seq2SeqTrainer`` with a custom collator that
    tokenizes [target, input] pairs on the fly.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    save_steps = 1000

    # Timestamped output directory so repeated runs never collide.
    output_dir = 'output/'+args.name+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print("Output dir:", output_dir)

    os.makedirs(output_dir, exist_ok=True)

    # Save this script alongside the checkpoints for reproducibility, with
    # the exact command line appended as a comment.
    train_script_path = os.path.join(output_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        fp16=args.fp16,
        fp16_backend="amp",
        per_device_train_batch_size=args.batch_size,
        evaluation_strategy="steps",
        save_steps=save_steps,
        logging_steps=100,
        eval_steps=save_steps,  # evaluate whenever a checkpoint is saved
        warmup_steps=1000,
        save_total_limit=1,
        num_train_epochs=args.epochs,
        report_to="wandb",
    )

    # The training stream is infinite; take the first eval_size examples as
    # a fixed eval set, then clear the dataset caches so those examples are
    # re-read from disk during training.
    train_dataset = MultiDataset(args.train_files, args.train_size)
    train_dataset_iter = iter(train_dataset)
    eval_dataset = [next(train_dataset_iter) for _ in range(args.eval_size)]
    train_dataset.delete_examples_cache()
    # Sanity print: each example is a [target, input] pair.
    print("Target:", eval_dataset[0][0])
    print("Input:", eval_dataset[0][1])

    print("Train dataset len:", len(train_dataset))

    def data_collator(examples):
        """Tokenize a batch of [target, input] pairs into model tensors.

        Pads to a multiple of 8 under fp16 (tensor-core friendly) and
        replaces label padding ids with -100 so the loss ignores them.
        """
        targets = [row[0] for row in examples]
        inputs = [row[1] for row in examples]
        label_pad_token_id = -100

        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)

        # Tokenize targets in target mode (matters for models with a
        # separate target vocabulary/processor).
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)

        # Mask padding positions so CrossEntropyLoss skips them.
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
        ]

        model_inputs["labels"] = torch.tensor(labels["input_ids"])
        return model_inputs

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    # Fix: the original bound the result to an unused `train_result` local.
    trainer.train()
    trainer.save_model()
|
|
|
|
|
# Standard script entry guard: run training only when executed directly,
# not when the saved copy of this script is imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|