MikkoLipsanen committed
Commit 71042e0
1 Parent(s): 1838a16

Update train_trocr.py

Files changed (1)
  1. train_trocr.py +8 -22
train_trocr.py CHANGED
@@ -5,7 +5,6 @@ import argparse
  from evaluate import load
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, AdamW
  import torchvision.transforms as transforms
- #import torch_optimizer as optim

  from dataset import TextlineDataset

@@ -13,14 +12,12 @@ parser = argparse.ArgumentParser('arguments for the code')

  parser.add_argument('--root_path', type=str, default="",
                      help='Root path to data files.')
- parser.add_argument('--tr_data_path', type=str, default="/data/htr/trocr_data/trocr_tuomiokirjat/train/trocr/data.csv",
+ parser.add_argument('--tr_data_path', type=str, default="/path/to/train/data.csv",
                      help='Path to .csv file containing the training data.')
- parser.add_argument('--val_data_path', type=str, default="/data/htr/trocr_data/trocr_tuomiokirjat/val/trocr/data.csv",
+ parser.add_argument('--val_data_path', type=str, default="/path/to/val/data.csv",
                      help='Path to .csv file containing the validation data.')
- parser.add_argument('--output_path', type=str, default="/koodit/htr/text_recognition/trocr/tuomiokirjat/models/22112023/",
+ parser.add_argument('--output_path', type=str, default="/output/path/",
                      help='Path for saving training results.')
- parser.add_argument('--resume_path', type=str, default="/koodit/htr/text_recognition/trocr/tuomiokirjat/models/22112023",
-                     help='Path to the previous model')
  parser.add_argument('--batch_size', type=int, default=24,
                      help='Batch size per device.')
  parser.add_argument('--epochs', type=int, default=13,
@@ -28,20 +25,12 @@ parser.add_argument('--epochs', type=int, default=13,

  args = parser.parse_args()

- # nohup python train_trocr.py > logs/tuomiokirjat_resume_23112023.txt 2>&1 &
- # echo $! > logs/save_pid.txt
-
- # run using 2 GPUs: torchrun --nproc_per_node=2 train_trocr.py > logs/tuomiokirjat_22112023.txt 2>&1 &
-
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print('Device: ', device)

  # Initialize processor and model
- #processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
- #model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
- processor =TrOCRProcessor.from_pretrained(args.resume_path + "/processor")
- model = VisionEncoderDecoderModel.from_pretrained(args.resume_path + "/checkpoint-13094")
-
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
  model.to(device)

  # Initialize metrics
@@ -51,8 +40,6 @@ wer_metric = load("wer")
  # Load train and validation data to dataframes
  train_df = pd.read_csv(args.tr_data_path)
  val_df = pd.read_csv(args.val_data_path)
- #train_df = train_df.iloc[:50]
- #val_df = val_df.iloc[:10]

  # Reset the indices to start from zero
  train_df.reset_index(drop=True, inplace=True)
@@ -88,7 +75,7 @@ model.config.length_penalty = 2.0
  model.config.num_beams = 4

  # Set arguments for model training
- # For all argumenst see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
+ # For all arguments see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
  training_args = Seq2SeqTrainingArguments(
      predict_with_generate=True,
      evaluation_strategy="epoch",
@@ -122,7 +109,7 @@ def compute_metrics(pred):
      return {"cer": cer, "wer": wer}


- # instantiate trainer
+ # Instantiate trainer
  # For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
  trainer = Seq2SeqTrainer(
      model=model,
@@ -138,5 +125,4 @@ trainer = Seq2SeqTrainer(
  trainer.train()
  #trainer.train(resume_from_checkpoint = True)
  model.save_pretrained(args.output_path)
- processor.save_pretrained(args.output_path + "/processor")
-
+ processor.save_pretrained(args.output_path + "/processor")
 
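For context, a minimal sketch of loading the artefacts this script saves once training has finished. This is not part of the commit; the directory names simply mirror the placeholder default of --output_path in the diff and would be replaced with real paths in practice.

# Load the fine-tuned model and processor written by train_trocr.py
# (the script saves the model to output_path and the processor to output_path + "/processor").
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("/output/path/processor")
model = VisionEncoderDecoderModel.from_pretrained("/output/path/")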