from datasets import load_dataset
from transformers import TrainingArguments

from span_marker import SpanMarkerModel, Trainer


def main() -> None:
    # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
    dataset = "Babelscape/multinerd"
    train_dataset = load_dataset(dataset, split="train")
    eval_dataset = load_dataset(dataset, split="validation").shuffle().select(range(3000))
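    # NOTE: `.shuffle()` is unseeded here, so the 3,000-sentence eval subset
    # differs between runs; `datasets` accepts `.shuffle(seed=...)` if a
    # reproducible subset is preferred.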
    labels = [
        "O",
        "B-PER",
        "I-PER",
        "B-ORG",
        "I-ORG",
        "B-LOC",
        "I-LOC",
        "B-ANIM",
        "I-ANIM",
        "B-BIO",
        "I-BIO",
        "B-CEL",
        "I-CEL",
        "B-DIS",
        "I-DIS",
        "B-EVE",
        "I-EVE",
        "B-FOOD",
        "I-FOOD",
        "B-INST",
        "I-INST",
        "B-MEDIA",
        "I-MEDIA",
        "B-MYTH",
        "I-MYTH",
        "B-PLANT",
        "I-PLANT",
        "B-TIME",
        "I-TIME",
        "B-VEHI",
        "I-VEHI",
    ]
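    # NOTE: these labels are written out by hand. If the dataset's "ner_tags"
    # column were typed as a Sequence of ClassLabel, the same list could likely
    # be read programmatically instead, e.g. (untested sketch):
    #     labels = train_dataset.features["ner_tags"].feature.names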

    # Initialize a SpanMarker model using a pretrained BERT-style encoder
    model_name = "xlm-roberta-base"
    model = SpanMarkerModel.from_pretrained(
        model_name,
        labels=labels,
        # SpanMarker hyperparameters:
        model_max_length=256,
        marker_max_length=128,
        entity_max_length=6,
    )
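    # As the analysis in the docstring at the bottom of this file shows,
    # `entity_max_length=6` and `model_max_length=256` together drop roughly
    # 2.24% of the annotated training entities; raising either limit would
    # recover longer entities, presumably at some cost in memory and speed.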

    # Prepare the 🤗 transformers training arguments
    args = TrainingArguments(
        output_dir="models/span_marker_xlm_roberta_base_multinerd",
        # Training Hyperparameters:
        learning_rate=1e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        # gradient_accumulation_steps=2,
        num_train_epochs=1,
        weight_decay=0.01,
        warmup_ratio=0.1,
        bf16=True,  # Replace `bf16` with `fp16` if your hardware can't use bf16.
        # Other Training parameters
        logging_first_step=True,
        logging_steps=50,
        evaluation_strategy="steps",  # Recent transformers releases may expect `eval_strategy` instead.
        save_strategy="steps",
        eval_steps=1000,
        save_total_limit=2,
        dataloader_num_workers=2,
    )

    # Initialize the trainer using our model, training args & dataset, and train
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.save_model("models/span_marker_xlm_roberta_base_multinerd/checkpoint-final")

    test_dataset = load_dataset(dataset, split="test")
    # Compute & save the metrics on the test set
    metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")
    trainer.save_metrics("test", metrics)


if __name__ == "__main__":
    main()
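

# After training completes, the saved checkpoint can be loaded back for
# inference. A minimal sketch, assuming the checkpoint path used above and
# SpanMarker's `predict` API:
#
#     from span_marker import SpanMarkerModel
#
#     model = SpanMarkerModel.from_pretrained(
#         "models/span_marker_xlm_roberta_base_multinerd/checkpoint-final"
#     )
#     entities = model.predict("Amelia Earhart flew her Lockheed Vega 5B across the Atlantic.")
#     # `predict` returns one dict per found entity, including its text span and label.
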
"""
This SpanMarker model will ignore 2.239322% of all annotated entities in the train dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words and the maximum model input length of 256 tokens.
These are the frequencies of the missed entities due to maximum entity length out of 4111958 total entities:
- 35814 missed entities with 7 words (0.870972%)
- 21246 missed entities with 8 words (0.516688%)
- 12680 missed entities with 9 words (0.308369%)
- 7308 missed entities with 10 words (0.177726%)
- 4414 missed entities with 11 words (0.107345%)
- 2474 missed entities with 12 words (0.060166%)
- 1894 missed entities with 13 words (0.046061%)
- 1130 missed entities with 14 words (0.027481%)
- 744 missed entities with 15 words (0.018094%)
- 582 missed entities with 16 words (0.014154%)
- 344 missed entities with 17 words (0.008366%)
- 226 missed entities with 18 words (0.005496%)
- 84 missed entities with 19 words (0.002043%)
- 46 missed entities with 20 words (0.001119%)
- 20 missed entities with 21 words (0.000486%)
- 20 missed entities with 22 words (0.000486%)
- 12 missed entities with 23 words (0.000292%)
- 18 missed entities with 24 words (0.000438%)
- 2 missed entities with 25 words (0.000049%)
- 4 missed entities with 26 words (0.000097%)
- 4 missed entities with 27 words (0.000097%)
- 2 missed entities with 31 words (0.000049%)
- 8 missed entities with 32 words (0.000195%)
- 6 missed entities with 33 words (0.000146%)
- 2 missed entities with 34 words (0.000049%)
- 4 missed entities with 36 words (0.000097%)
- 8 missed entities with 37 words (0.000195%)
- 2 missed entities with 38 words (0.000049%)
- 2 missed entities with 41 words (0.000049%)
- 2 missed entities with 72 words (0.000049%)
Additionally, a total of 2978 (0.072423%) entities were missed due to the maximum input length.
This SpanMarker model won't be able to predict 2.501087% of all annotated entities in the evaluation dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words.
These are the frequencies of the missed entities due to maximum entity length out of 4598 total entities:
- 45 missed entities with 7 words (0.978686%)
- 27 missed entities with 8 words (0.587212%)
- 21 missed entities with 9 words (0.456720%)
- 9 missed entities with 10 words (0.195737%)
- 3 missed entities with 12 words (0.065246%)
- 4 missed entities with 13 words (0.086994%)
- 3 missed entities with 14 words (0.065246%)
- 1 missed entities with 15 words (0.021749%)
- 1 missed entities with 16 words (0.021749%)
- 1 missed entities with 20 words (0.021749%)
"""
"""
wandb: Run summary:
wandb: eval/loss 0.00594
wandb: eval/overall_accuracy 0.98181
wandb: eval/overall_f1 0.90333
wandb: eval/overall_precision 0.91259
wandb: eval/overall_recall 0.89427
wandb: eval/runtime 21.4308
wandb: eval/samples_per_second 154.171
wandb: eval/steps_per_second 4.853
wandb: test/loss 0.00559
wandb: test/overall_accuracy 0.98247
wandb: test/overall_f1 0.91314
wandb: test/overall_precision 0.91994
wandb: test/overall_recall 0.90643
wandb: test/runtime 2202.6894
wandb: test/samples_per_second 169.652
wandb: test/steps_per_second 5.302
wandb: train/epoch 1.0
wandb: train/global_step 93223
wandb: train/learning_rate 0.0
wandb: train/loss 0.0049
wandb: train/total_flos 7.851073325660897e+17
wandb: train/train_loss 0.01782
wandb: train/train_runtime 41756.9748
wandb: train/train_samples_per_second 71.44
wandb: train/train_steps_per_second 2.233
"""