File size: 1,185 Bytes
911d980
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import sys

from datasets import load_dataset
from transformers import TrainingArguments
from span_marker import SpanMarkerModel, Trainer


# Evaluation-only script: loads a SpanMarker checkpoint named on the command
# line and reports its metrics on the GermEval 2014 validation and test splits.
#
# Usage: python <script> <model_name_or_path>

# Fail fast with a usage message when the model argument is missing, instead
# of raising a bare IndexError after the dataset has already been loaded.
# Extra arguments are still ignored, as before.
if len(sys.argv) < 2:
    sys.exit(f"Usage: {sys.argv[0]} <model_name_or_path>")
model_name = sys.argv[1]

# Load the dataset and read the list of label names from its "ner_tags" column
dataset = load_dataset("gwlms/germeval2014")
labels = dataset["train"].features["ner_tags"].feature.names

# Initialize a SpanMarker model from the checkpoint given on the command line
model = SpanMarkerModel.from_pretrained(
    model_name,
    labels=labels,
    # SpanMarker hyperparameters:
    model_max_length=256,
    marker_max_length=128,
    entity_max_length=8,
)

# Only evaluation is run below (no train() call); /tmp is used as the
# output directory handed to TrainingArguments.
args = TrainingArguments(
    output_dir="/tmp",
    per_device_eval_batch_size=64,
)

# Initialize the trainer using our model, arguments & dataset splits.
# The splits actually evaluated are passed explicitly to evaluate() below.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)


# Metric keys are prefixed with "eval" for the development split ...
print("Evaluating on development set...")
dev_metrics = trainer.evaluate(dataset["validation"], metric_key_prefix="eval")
print(dev_metrics)

# ... and with "test" for the held-out test split.
print("Evaluating on test set...")
test_metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test")
print(test_metrics)