# Page header captured together with this file (not part of the code):
#   Gosse Minnema - "Re-enable LOME" (2890e34)
#   raw | history | blame - 10.9 kB
import json
from typing import List, Tuple
import pandas as pd
from sftp import SpanPredictor
def main():
    """Evaluate every configured LOME model on one data file.

    Each ``(name, archive path)`` pair below is loaded once with
    ``SpanPredictor.from_path`` and then scored on three subtasks:

    * FD - frame detection (a single run);
    * BD - argument boundary detection, run 1 (predicted frame) and run 2
      (gold frame);
    * AC - argument classification, runs 1-3 (run 1: predicted frame and
      boundaries; run 2: gold frame, predicted boundaries; run 3: gold frame
      and boundaries).
    """
    # Other evaluation sets used during development:
    #   "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_dev.jsonl"
    #   "/home/p289731/cloned/lome/preproc/svm_challenge.jsonl"
    data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_test.jsonl"

    # (display name, path to trained model archive). Disabled alternatives:
    #   ("lome-it-freeze", "/data/p289731/cloned/lome/train-evalita-plus-fn-freeze/model.tar.gz")
    #   ("lome-it-mono", "/data/p289731/cloned/lome/train-evalita-it_mono/model.tar.gz")
    models = [
        (
            "lome-en",
            "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
        ),
        (
            "lome-it-best",
            "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz",
        ),
    ]

    # Multi-run subtasks: (banner tag, evaluation function, run numbers).
    staged_evals = (
        ("BD", eval_boundary_detection, (1, 2)),
        ("AC", eval_argument_classification, (1, 2, 3)),
    )

    for name, archive in models:
        print("testing model: ", name)
        span_predictor = SpanPredictor.from_path(archive)

        print("=== FD (run 1) ===")
        eval_frame_detection(data_file, span_predictor, model_name=name)

        for task, evaluate, runs in staged_evals:
            for run_no in runs:
                print(f"=== {task} (run {run_no}) ===")
                evaluate(data_file, span_predictor, run=run_no)
def predict_frame(
    predictor: SpanPredictor, tokens: List[str], predicate_span: Tuple[int, int]
):
    """Return the frame label the model assigns to ``predicate_span``.

    The predicate span is forced as the only child span of ``force_decode``,
    and the first label of the returned label list is taken as the frame
    (assumes labels come back in child-span order - TODO confirm against
    ``SpanPredictor.force_decode``).
    """
    _spans, frame_labels, _rest = predictor.force_decode(
        tokens, child_spans=[predicate_span]
    )
    return frame_labels[0]
def eval_frame_detection(data_file, predictor, verbose=False, model_name="_"):
    """Score frame detection accuracy and dump per-sentence predictions to CSV.

    Each non-blank line of ``data_file`` is a JSON object with a ``"tokens"``
    list and an ``"annotations"`` list. Only the first annotation is used:
    its ``"span"`` (inclusive start/end token indices) is the predicate and
    its ``"label"`` is the gold frame.

    Args:
        data_file: path to the JSON-lines evaluation file.
        predictor: model object handed to ``predict_frame``.
        verbose: print every sentence with its gold and predicted frame.
        model_name: inserted into the output CSV file name.

    Returns:
        Accuracy (correct / total), or 0.0 when the file holds no sentences.

    Side effects:
        Prints ``ACC = ...`` and writes
        ``frame_prediction_output_{model_name}_{section}.csv`` into the current
        working directory. ``section`` is "rai", "dev" or "test", chosen from
        the data file's *name* only (not its directories).
    """
    import os

    n_correct = 0
    n_wrong = 0
    out = []
    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            if not sent.strip():
                # Tolerate blank lines; json.loads would raise on them.
                continue
            sent_data = json.loads(sent)
            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]
            predicate_span = tuple(annotation["span"])
            # Span end is inclusive, hence the +1.
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
            frame_gold = annotation["label"]
            frame_pred = predict_frame(predictor, tokens, predicate_span)
            if frame_pred == frame_gold:
                n_correct += 1
            else:
                n_wrong += 1
            out.append({
                "sentence": " ".join(tokens),
                "predicate": predicate,
                "frame_gold": frame_gold,
                "frame_pred": frame_pred,
            })
            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t gold: {frame_gold}")
                print(f"\tpredicted: {frame_pred}")
                print()

    n_total = n_correct + n_wrong
    # Guard against an empty evaluation file instead of dividing by zero.
    acc_score = n_correct / n_total if n_total else 0.0
    print("ACC =", acc_score)

    # Decide the data section from the file name so that directory names
    # (e.g. ".../devel/evalita_test.jsonl") cannot mislabel the output.
    file_name = os.path.basename(data_file)
    if "svm_challenge" in file_name:
        data_sect = "rai"
    elif "dev" in file_name:
        data_sect = "dev"
    else:
        data_sect = "test"

    # Explicit columns keep a header row even when no sentences were read.
    df_out = pd.DataFrame(
        out, columns=["sentence", "predicate", "frame_gold", "frame_pred"]
    )
    df_out.to_csv(f"frame_prediction_output_{model_name}_{data_sect}.csv")
    return acc_score
def predict_boundaries(predictor: SpanPredictor, tokens, predicate_span, frame):
    """Predict argument spans for a predicate under a given frame.

    ``force_decode`` is called with the predicate as the parent span and
    ``frame`` as its label. Each returned span is converted to a tuple, and
    the predicate's own span labelled "Target" is dropped.

    Returns:
        A list of ``(start, end)`` tuples in the order ``force_decode``
        returned them.
    """
    spans, span_labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame
    )

    def _is_predicate_target(span, label):
        return span == predicate_span and label == "Target"

    candidates = (tuple(raw_span) for raw_span in spans)
    return [
        span
        for span, label in zip(candidates, span_labels)
        if not _is_predicate_target(span, label)
    ]
def get_gold_boundaries(annotation, predicate_span):
    """Collect the gold argument spans of an annotation as a set of tuples.

    Every child's ``"span"`` is included except the child whose span equals
    ``predicate_span`` and whose label is "Target" (the predicate itself).
    A child's ``"label"`` is only read when its span equals the predicate span.
    """
    gold_spans = set()
    for child in annotation["children"]:
        span = tuple(child["span"])
        is_predicate_target = span == predicate_span and child["label"] == "Target"
        if not is_predicate_target:
            gold_spans.add(span)
    return gold_spans
def eval_boundary_detection(data_file, predictor, run=1, verbose=False):
    """Score argument boundary detection at span and token level.

    Run 1 conditions the boundary predictor on the *predicted* frame
    (``predict_frame``); run 2 uses the gold frame from the annotation. Only
    the first annotation of each non-blank JSON line is evaluated, and the
    predicate's own "Target" span is excluded from gold and predictions.

    Span level: a predicted ``(start, end)`` span is a true positive only on
    an exact match with a gold span. Token level: the sets of token indices
    covered by gold and predicted spans are compared.

    Returns:
        ``(prec, rec, f1, tok_prec, tok_rec, tok_f1)``. Any ratio whose
        denominator is zero (no predictions, no gold spans, or no true
        positives) is reported as 0.0 instead of raising ZeroDivisionError.

    Raises:
        ValueError: if ``run`` is not 1 or 2.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if run not in (1, 2):
        raise ValueError(f"run must be 1 or 2, got {run!r}")

    def _covered_tokens(spans):
        # Expand inclusive (start, stop) spans into the token indices they cover.
        return {
            tok_idx
            for (start, stop) in spans
            for tok_idx in range(start, stop + 1)
        }

    def _prf(tp, fp, fn):
        # Precision/recall/F1 with 0.0 for undefined ratios.
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * ((prec * rec) / (prec + rec)) if prec + rec else 0.0
        return prec, rec, f1

    true_pos = false_pos = false_neg = 0
    true_pos_tok = false_pos_tok = false_neg_tok = 0
    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            if not sent.strip():
                continue  # json.loads would raise on a blank line
            sent_data = json.loads(sent)
            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]
            predicate_span = tuple(annotation["span"])
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]
            boundaries_gold = get_gold_boundaries(annotation, predicate_span)
            boundaries_pred = set(
                predict_boundaries(predictor, tokens, predicate_span, frame)
            )

            # Span level: exact (start, stop) matches only.
            sent_true_pos = len(boundaries_gold & boundaries_pred)
            sent_false_pos = len(boundaries_pred - boundaries_gold)
            sent_false_neg = len(boundaries_gold - boundaries_pred)
            true_pos += sent_true_pos
            false_pos += sent_false_pos
            false_neg += sent_false_neg

            # Token level: overlapping spans get credit for shared tokens.
            boundary_toks_gold = _covered_tokens(boundaries_gold)
            boundary_toks_pred = _covered_tokens(boundaries_pred)
            sent_tok_true_pos = len(boundary_toks_gold & boundary_toks_pred)
            sent_tok_false_pos = len(boundary_toks_pred - boundary_toks_gold)
            sent_tok_false_neg = len(boundary_toks_gold - boundary_toks_pred)
            true_pos_tok += sent_tok_true_pos
            false_pos_tok += sent_tok_false_pos
            false_neg_tok += sent_tok_false_neg

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {boundaries_gold}")
                print(f"\tpredicted: {boundaries_pred}")
                print(f"\ttp={sent_true_pos}\tfp={sent_false_pos}\tfn={sent_false_neg}")
                print(
                    f"\ttp_t={sent_tok_true_pos}\tfp_t={sent_tok_false_pos}\tfn_t={sent_tok_false_neg}"
                )
                print()

    prec, rec, f1_score = _prf(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")
    tok_prec, tok_rec, tok_f1 = _prf(true_pos_tok, false_pos_tok, false_neg_tok)
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
    return prec, rec, f1_score, tok_prec, tok_rec, tok_f1
def predict_arguments(
    predictor: SpanPredictor, tokens, predicate_span, frame, boundaries
):
    """Assign a role label to each given argument span of a predicate.

    The spans are ordered by start token only (``sorted`` is stable, so ties
    keep their input iteration order) and forced as child spans of the
    predicate/frame pair. A span equal to the predicate labelled "Target"
    is dropped.

    Returns:
        A list of ``(span, label)`` pairs in start-token order.
    """
    ordered_spans = sorted(boundaries, key=lambda span: span[0])
    _, role_labels, _ = predictor.force_decode(
        tokens,
        parent_span=predicate_span,
        parent_label=frame,
        child_spans=ordered_spans,
    )
    return [
        (span, role)
        for span, role in zip(ordered_spans, role_labels)
        if not (span == predicate_span and role == "Target")
    ]
def eval_argument_classification(data_file, predictor, run=1, verbose=False):
    """Score argument role labelling at span and token level.

    The runs differ in how much gold information the labeller receives:

    * run 1: predicted frame and predicted argument boundaries;
    * run 2: gold frame and predicted argument boundaries;
    * run 3: gold frame and gold argument boundaries.

    Only the first annotation of each non-blank JSON line is evaluated, and
    the predicate's own "Target" span is excluded. A ``(span, label)`` pair is
    a true positive only if both span and label match a gold argument; at
    token level every ``(token index, label)`` pair is compared.

    Returns:
        ``(prec, rec, f1, tok_prec, tok_rec, tok_f1)``. Any ratio whose
        denominator is zero is reported as 0.0 instead of raising
        ZeroDivisionError.

    Raises:
        ValueError: if ``run`` is not 1, 2 or 3.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if run not in (1, 2, 3):
        raise ValueError(f"run must be 1, 2 or 3, got {run!r}")

    def _token_labels(arguments):
        # Expand inclusive ((start, end), label) arguments into (token, label) pairs.
        return {
            (token, label)
            for ((bnd_start, bnd_end), label) in arguments
            for token in range(bnd_start, bnd_end + 1)
        }

    def _prf(tp, fp, fn):
        # Precision/recall/F1 with 0.0 for undefined ratios.
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * ((prec * rec) / (prec + rec)) if prec + rec else 0.0
        return prec, rec, f1

    true_pos = false_pos = false_neg = 0
    true_pos_tok = false_pos_tok = false_neg_tok = 0
    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            if not sent.strip():
                continue  # json.loads would raise on a blank line
            sent_data = json.loads(sent)
            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]
            predicate_span = tuple(annotation["span"])
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
            # gold or predicted frames?
            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]
            # gold or predicted argument boundaries?
            if run in (1, 2):
                boundaries = set(
                    predict_boundaries(predictor, tokens, predicate_span, frame)
                )
            else:
                boundaries = get_gold_boundaries(annotation, predicate_span)
            pred_arguments = predict_arguments(
                predictor, tokens, predicate_span, frame, boundaries
            )
            gold_arguments = {
                (tuple(c["span"]), c["label"])
                for c in annotation["children"]
                if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
            }
            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {gold_arguments}")
                print(f"\tpredicted: {pred_arguments}")
                print()

            # -- full spans version
            # `boundaries` is a set, so the predicted (span, label) pairs are
            # unique. Set arithmetic therefore gives the same counts as
            # per-item checks, without O(n) list lookups.
            pred_set = set(pred_arguments)
            true_pos += len(gold_arguments & pred_set)
            false_neg += len(gold_arguments - pred_set)
            false_pos += len(pred_set - gold_arguments)

            # -- token based
            tok_gold_labels = _token_labels(gold_arguments)
            tok_pred_labels = _token_labels(pred_arguments)
            true_pos_tok += len(tok_gold_labels & tok_pred_labels)
            false_neg_tok += len(tok_gold_labels - tok_pred_labels)
            false_pos_tok += len(tok_pred_labels - tok_gold_labels)

    prec, rec, f1_score = _prf(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")
    tok_prec, tok_rec, tok_f1 = _prf(true_pos_tok, false_pos_tok, false_neg_tok)
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
    return prec, rec, f1_score, tok_prec, tok_rec, tok_f1
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()