|
import dataclasses
import os

import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Trainer, TrainingArguments
|
|
|
|
|
class HandwrittenMathDataset(Dataset):
    """Image/label pairs for fine-tuning a TrOCR model on handwritten math.

    Reads an annotations CSV, splits it deterministically into train/test
    partitions, and serves processor-ready ``pixel_values`` / ``labels``
    tensors for one of them.
    """

    def __init__(self, annotations_file, img_dir, processor, subset="train",
                 test_size=0.2, random_state=42, max_target_length=128):
        """Load the annotations and select one split.

        Parameters:
            annotations_file (str): Path to the annotations CSV. Column 0 holds
                the image path, column 1 the target text.
            img_dir (str): Directory that image paths from the CSV are
                resolved against.
            processor: TrOCR processor used to build pixel values and to
                tokenize labels.
            subset (str): ``"train"`` selects the training split; any other
                value selects the held-out split.
            test_size (float): Fraction of rows held out (default 0.2).
            random_state (int): Seed for the split. A fixed seed makes separate
                "train" and "test" instances partition the same rows disjointly.
            max_target_length (int): Length labels are padded/truncated to.
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.train_data, self.test_data = train_test_split(
            self.img_labels, test_size=test_size, random_state=random_state
        )
        self.data = self.train_data if subset == "train" else self.test_data
        self.img_dir = img_dir
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def _image_path(self, idx):
        """Resolve the image path stored in row ``idx`` against ``img_dir``."""
        raw_path = str(self.data.iloc[idx, 0])
        # os.path.join leaves absolute paths untouched.
        joined = os.path.join(self.img_dir, raw_path)
        # Earlier versions opened the CSV path as written and ignored img_dir;
        # keep such annotation files working when only the raw path exists.
        if not os.path.exists(joined) and os.path.exists(raw_path):
            return raw_path
        return joined

    def __getitem__(self, idx):
        """Return a ``{"pixel_values", "labels"}`` dict of tensors for row ``idx``.

        Padding positions in ``labels`` are replaced with -100, the index the
        training loss conventionally ignores.
        """
        # ``with`` closes the file handle deterministically; convert() returns
        # an independent in-memory copy that outlives the handle.
        with Image.open(self._image_path(idx)) as img:
            image = img.convert("RGB")

        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # pandas may parse purely numeric labels (e.g. "42") as numbers;
        # the tokenizer needs text.
        label = str(self.data.iloc[idx, 1])
        labels = self.processor.tokenizer(
            label,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids

        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        # Drop only the leading batch dimension added by return_tensors="pt".
        return {"pixel_values": pixel_values.squeeze(0), "labels": labels.squeeze(0)}
|
|
|
|
|
def main():
    """Fine-tune ``microsoft/trocr-base-handwritten`` on the handwritten math dataset.

    Builds train/test splits from ./dataset/annotations.csv and runs a Hugging
    Face ``Trainer`` that evaluates and checkpoints once per epoch into ./model.
    """
    annotations_file = './dataset/annotations.csv'
    img_dir = './dataset/images/'
    model_id = 'microsoft/trocr-base-handwritten'

    processor = TrOCRProcessor.from_pretrained(model_id)
    # No explicit .to("cuda"): Trainer moves the model to ``args.device`` itself,
    # so the script also runs on machines without a CUDA GPU.
    model = VisionEncoderDecoderModel.from_pretrained(model_id)

    # Used when the model shifts ``labels`` right to build decoder inputs.
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id

    train_dataset = HandwrittenMathDataset(annotations_file=annotations_file, img_dir=img_dir,
                                           processor=processor, subset="train")
    test_dataset = HandwrittenMathDataset(annotations_file=annotations_file, img_dir=img_dir,
                                          processor=processor, subset="test")

    # ``evaluation_strategy`` was renamed to ``eval_strategy`` in newer
    # transformers releases and the old name was later removed; use whichever
    # keyword the installed version accepts.
    arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
    eval_key = "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir='./model',
        per_device_train_batch_size=2,
        num_train_epochs=20,
        logging_dir='./training_logs',
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=1,
        weight_decay=0.1,
        learning_rate=1e-4,
        gradient_checkpointing=True,
        gradient_accumulation_steps=2,
        **{eval_key: "epoch"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )

    trainer.train()
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
|
|