import sys

from datasets import load_dataset
from transformers import TrainingArguments

from span_marker import SpanMarkerModel, Trainer
# Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
dataset = load_dataset("gwlms/germeval2014")
labels = dataset["train"].features["ner_tags"].feature.names
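# For reference, "labels" is the list of BIO tag strings defined by the dataset itself
# rather than hard-coded here; for GermEval 2014 this includes "O" plus B-/I- tags for
# classes such as PER, LOC, ORG and OTH (including their -deriv and -part subtypes).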
# Initialize a SpanMarker model using a pretrained BERT-style encoder
model_name = sys.argv[1]  # encoder checkpoint passed as the first command-line argument
model = SpanMarkerModel.from_pretrained(
    model_name,
    labels=labels,
    # SpanMarker hyperparameters:
    model_max_length=256,
    marker_max_length=128,
    entity_max_length=8,
)
# Training arguments; any hyperparameters not set here fall back to the transformers defaults
args = TrainingArguments(
    output_dir="/tmp",
    per_device_eval_batch_size=64,
)
# Initialize the trainer using our model, training args & dataset, and train
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)
trainer.train()
print("Evaluating on development set...") | |
dev_metrics = trainer.evaluate(dataset["validation"], metric_key_prefix="eval") | |
print(dev_metrics) | |
print("Evaluating on test set...") | |
test_metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test") | |
print(test_metrics) | |
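
# A minimal usage sketch; the filename and the encoder name below are illustrative
# assumptions, not fixed by this script:
#
#     python train_and_evaluate.py bert-base-german-cased
#
# Each trainer.evaluate() call returns a dict of metrics whose keys are prefixed with
# the given metric_key_prefix, so the development ("eval_...") and test ("test_...")
# results can be told apart in the printed output.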