"""Fine-tune Microsoft's TrOCR base handwritten model on custom text-line data.

Reads training/validation CSVs consumed by ``dataset.TextlineDataset`` and
writes checkpoints, the final model, and the processor under ``--output_path``.
"""

import argparse
import os

import pandas as pd
import torch
from evaluate import load
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    default_data_collator,
)

from dataset import TextlineDataset

# NOTE(review): dropped the unused PIL.Image / torchvision.transforms imports
# and the unused transformers.AdamW import — AdamW has been removed from recent
# transformers releases, so importing it breaks on upgrade. The optimizer is
# selected via ``optim="adamw_torch"`` in the training arguments instead.


def parse_args():
    """Parse command-line options for data locations and training settings."""
    parser = argparse.ArgumentParser('arguments for the code')
    parser.add_argument('--root_path', type=str, default="",
                        help='Root path to data files.')
    parser.add_argument('--tr_data_path', type=str, default="/path/to/train/data.csv",
                        help='Path to .csv file containing the training data.')
    parser.add_argument('--val_data_path', type=str, default="/path/to/val/data.csv",
                        help='Path to .csv file containing the validation data.')
    parser.add_argument('--output_path', type=str, default="/output/path/",
                        help='Path for saving training results.')
    parser.add_argument('--batch_size', type=int, default=24,
                        help='Batch size per device.')
    parser.add_argument('--epochs', type=int, default=13,
                        help='Number of training epochs.')
    return parser.parse_args()


def main():
    """Run the full pipeline: load data, configure, train, and save the model."""
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Device: ', device)

    # Initialize processor and model
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    model.to(device)

    # Initialize metrics
    cer_metric = load("cer")
    wer_metric = load("wer")

    # Load train and validation data to dataframes and reset the indices
    # to start from zero
    train_df = pd.read_csv(args.tr_data_path)
    val_df = pd.read_csv(args.val_data_path)
    train_df.reset_index(drop=True, inplace=True)
    val_df.reset_index(drop=True, inplace=True)

    # Create train and validation datasets
    train_dataset = TextlineDataset(root_dir=args.root_path, df=train_df,
                                    processor=processor, augment=False)
    eval_dataset = TextlineDataset(root_dir=args.root_path, df=val_df,
                                   processor=processor, augment=False)
    print("Number of training examples:", len(train_dataset))
    print("Number of validation examples:", len(eval_dataset))

    # Define model configuration.
    # Set special tokens used for creating the decoder_input_ids from the labels.
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    # Make sure vocab size is set correctly
    model.config.vocab_size = model.config.decoder.vocab_size
    # Set beam search / generation parameters
    model.config.eos_token_id = processor.tokenizer.sep_token_id
    model.config.max_length = 64
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    # Set arguments for model training.
    # For all arguments see
    # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        load_best_model_at_end=True,
        metric_for_best_model='cer',   # lower CER = better model
        greater_is_better=False,
        #fp16=True,
        num_train_epochs=args.epochs,
        save_total_limit=2,
        output_dir=args.output_path,
        optim="adamw_torch"
    )

    def compute_metrics(pred):
        """Compute CER and WER between generated predictions and references."""
        labels_ids = pred.label_ids
        pred_ids = pred.predictions
        pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
        # -100 marks positions ignored by the loss; restore pad tokens so the
        # tokenizer can decode (and then skip) them.
        labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
        label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
        cer = cer_metric.compute(predictions=pred_str, references=label_str)
        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"cer": cer, "wer": wer}

    # Instantiate trainer.
    # For all parameters see:
    # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=processor.feature_extractor,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
    )

    # Train the model
    trainer.train()
    #trainer.train(resume_from_checkpoint = True)

    model.save_pretrained(args.output_path)
    # os.path.join avoids a doubled slash when --output_path ends with "/"
    processor.save_pretrained(os.path.join(args.output_path, "processor"))


if __name__ == "__main__":
    main()