"""Evaluation script for LOME span-finder models on EVALITA JSONL data.

Runs frame detection (FD), boundary detection (BD), and argument
classification (AC) evaluations for each configured model checkpoint.
"""
import json
from typing import List, Tuple

import pandas as pd

from sftp import SpanPredictor
def main():
    """Evaluate each configured model on FD, BD, and AC tasks."""
    # data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_dev.jsonl"
    # data_file = "/home/p289731/cloned/lome/preproc/svm_challenge.jsonl"
    data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_test.jsonl"

    models = [
        (
            "lome-en",
            "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
        ),
        (
            "lome-it-best",
            "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz",
        ),
        # (
        #     "lome-it-freeze",
        #     "/data/p289731/cloned/lome/train-evalita-plus-fn-freeze/model.tar.gz",
        # ),
        # (
        #     "lome-it-mono",
        #     "/data/p289731/cloned/lome/train-evalita-it_mono/model.tar.gz",
        # ),
    ]

    for model_name, model_path in models:
        print("testing model: ", model_name)
        predictor = SpanPredictor.from_path(model_path)

        # Frame detection has a single setting; BD and AC are run in
        # several gold/predicted input combinations (see the eval functions).
        print("=== FD (run 1) ===")
        eval_frame_detection(data_file, predictor, model_name=model_name)

        for run in (1, 2):
            print(f"=== BD (run {run}) ===")
            eval_boundary_detection(data_file, predictor, run=run)

        for run in (1, 2, 3):
            print(f"=== AC (run {run}) ===")
            eval_argument_classification(data_file, predictor, run=run)
def predict_frame(
    predictor: SpanPredictor, tokens: List[str], predicate_span: Tuple[int, int]
):
    """Force-decode and return the frame label for one predicate span."""
    _, frame_labels, _ = predictor.force_decode(tokens, child_spans=[predicate_span])
    return frame_labels[0]
def eval_frame_detection(data_file, predictor, verbose=False, model_name="_"):
    """Evaluate frame-detection accuracy on a JSONL dataset.

    Each line of *data_file* is a JSON object with "tokens" and
    "annotations"; only the first annotation of each sentence is scored.
    Prints the overall accuracy and writes per-sentence predictions to a
    CSV file named after *model_name* and the detected data section.
    """
    true_pos = 0
    false_pos = 0  # counts misclassifications; there is no negative class here
    out = []
    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)
            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]
            predicate_span = tuple(annotation["span"])
            # span indices are inclusive, hence the +1
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
            frame_gold = annotation["label"]
            frame_pred = predict_frame(predictor, tokens, predicate_span)
            if frame_pred == frame_gold:
                true_pos += 1
            else:
                false_pos += 1
            out.append(
                {
                    "sentence": " ".join(tokens),
                    "predicate": predicate,
                    "frame_gold": frame_gold,
                    "frame_pred": frame_pred,
                }
            )
            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t gold: {frame_gold}")
                print(f"\tpredicted: {frame_pred}")
                print()
    total = true_pos + false_pos
    # Guard against an empty dataset instead of raising ZeroDivisionError.
    acc_score = true_pos / total if total else 0.0
    print("ACC =", acc_score)
    data_sect = (
        "rai"
        if "svm_challenge" in data_file
        else "dev"
        if "dev" in data_file
        else "test"
    )
    df_out = pd.DataFrame(out)
    df_out.to_csv(f"frame_prediction_output_{model_name}_{data_sect}.csv")
def predict_boundaries(predictor: SpanPredictor, tokens, predicate_span, frame):
    """Force-decode argument spans for *frame*, excluding the Target span."""
    spans, span_labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame
    )
    predicted = []
    for span, label in zip(spans, span_labels):
        span = tuple(span)
        # The predicate's own "Target" span is not an argument boundary.
        is_target = span == predicate_span and label == "Target"
        if not is_target:
            predicted.append(span)
    return predicted
def get_gold_boundaries(annotation, predicate_span):
    """Collect the gold argument spans, skipping the predicate's Target span."""
    boundaries = set()
    for child in annotation["children"]:
        span = tuple(child["span"])
        if span == predicate_span and child["label"] == "Target":
            continue
        boundaries.add(span)
    return boundaries
def eval_boundary_detection(data_file, predictor, run=1, verbose=False):
    """Evaluate argument boundary detection on a JSONL dataset.

    run=1: frames are predicted by the model first (pipeline setting).
    run=2: gold frames are used (oracle setting).

    Prints span-level and token-level precision / recall / F1.
    """
    assert run in [1, 2]

    def _prf(tp, fp, fn):
        # Avoid ZeroDivisionError on an empty dataset or when nothing is
        # predicted / gold-annotated; report 0.0 instead of crashing.
        prec = tp / (tp + fp) if (tp + fp) else 0.0
        rec = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * ((prec * rec) / (prec + rec)) if (prec + rec) else 0.0
        return prec, rec, f1

    true_pos = 0
    false_pos = 0
    false_neg = 0
    true_pos_tok = 0
    false_pos_tok = 0
    false_neg_tok = 0
    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)
            tokens = sent_data["tokens"]
            # only the first annotation per sentence is evaluated
            annotation = sent_data["annotations"][0]
            predicate_span = tuple(annotation["span"])
            # span indices are inclusive, hence the +1
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]
            boundaries_gold = get_gold_boundaries(annotation, predicate_span)
            boundaries_pred = set(
                predict_boundaries(predictor, tokens, predicate_span, frame)
            )

            # span-level counts: exact boundary match
            sent_true_pos = len(boundaries_gold & boundaries_pred)
            sent_false_pos = len(boundaries_pred - boundaries_gold)
            sent_false_neg = len(boundaries_gold - boundaries_pred)
            true_pos += sent_true_pos
            false_pos += sent_false_pos
            false_neg += sent_false_neg

            # token-level counts: credit per covered token (inclusive spans)
            boundary_toks_gold = {
                tok_idx
                for (start, stop) in boundaries_gold
                for tok_idx in range(start, stop + 1)
            }
            boundary_toks_pred = {
                tok_idx
                for (start, stop) in boundaries_pred
                for tok_idx in range(start, stop + 1)
            }
            sent_tok_true_pos = len(boundary_toks_gold & boundary_toks_pred)
            sent_tok_false_pos = len(boundary_toks_pred - boundary_toks_gold)
            sent_tok_false_neg = len(boundary_toks_gold - boundary_toks_pred)
            true_pos_tok += sent_tok_true_pos
            false_pos_tok += sent_tok_false_pos
            false_neg_tok += sent_tok_false_neg

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {boundaries_gold}")
                print(f"\tpredicted: {boundaries_pred}")
                print(f"\ttp={sent_true_pos}\tfp={sent_false_pos}\tfn={sent_false_neg}")
                print(
                    f"\ttp_t={sent_tok_true_pos}\tfp_t={sent_tok_false_pos}\tfn_t={sent_tok_false_neg}"
                )
                print()

    prec, rec, f1_score = _prf(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")
    tok_prec, tok_rec, tok_f1 = _prf(true_pos_tok, false_pos_tok, false_neg_tok)
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
def predict_arguments(
    predictor: SpanPredictor, tokens, predicate_span, frame, boundaries
):
    """Force-decode role labels for the given argument boundaries.

    Returns (boundary, label) pairs, excluding the predicate's Target span.
    """
    # force_decode expects spans in sentence order
    ordered = sorted(boundaries, key=lambda span: span[0])
    _, role_labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame, child_spans=ordered
    )
    labeled = [
        (span, role)
        for span, role in zip(ordered, role_labels)
        if not (span == predicate_span and role == "Target")
    ]
    return labeled
def eval_argument_classification(data_file, predictor, run=1, verbose=False):
    """Evaluate argument classification (role labeling) on a JSONL dataset.

    run=1: predicted frames + predicted boundaries (full pipeline).
    run=2: gold frames + predicted boundaries.
    run=3: gold frames + gold boundaries (oracle).

    Prints span-level and token-level precision / recall / F1.
    """
    assert run in [1, 2, 3]

    def _prf(tp, fp, fn):
        # Avoid ZeroDivisionError on an empty dataset or when nothing is
        # predicted / gold-annotated; report 0.0 instead of crashing.
        prec = tp / (tp + fp) if (tp + fp) else 0.0
        rec = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * ((prec * rec) / (prec + rec)) if (prec + rec) else 0.0
        return prec, rec, f1

    true_pos = 0
    false_pos = 0
    false_neg = 0
    true_pos_tok = 0
    false_pos_tok = 0
    false_neg_tok = 0
    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)
            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]
            predicate_span = tuple(annotation["span"])
            # span indices are inclusive, hence the +1
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
            # gold or predicted frames?
            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]
            # gold or predicted argument boundaries?
            if run in [1, 2]:
                boundaries = set(
                    predict_boundaries(predictor, tokens, predicate_span, frame)
                )
            else:
                boundaries = get_gold_boundaries(annotation, predicate_span)
            pred_arguments = predict_arguments(
                predictor, tokens, predicate_span, frame, boundaries
            )
            gold_arguments = {
                (tuple(c["span"]), c["label"])
                for c in annotation["children"]
                if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
            }
            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {gold_arguments}")
                print(f"\tpredicted: {pred_arguments}")
                print()

            # -- full spans version: (boundary, label) must match exactly
            for g_bnd, g_label in gold_arguments:
                # true positive: found the span and labeled it correctly
                if (g_bnd, g_label) in pred_arguments:
                    true_pos += 1
                # false negative: missed this argument
                else:
                    false_neg += 1
            for p_bnd, p_label in pred_arguments:
                # all predictions that are not true positives are false positives
                if (p_bnd, p_label) not in gold_arguments:
                    false_pos += 1

            # -- token based: credit per (token, label) pair (inclusive spans)
            tok_gold_labels = {
                (token, label)
                for ((bnd_start, bnd_end), label) in gold_arguments
                for token in range(bnd_start, bnd_end + 1)
            }
            tok_pred_labels = {
                (token, label)
                for ((bnd_start, bnd_end), label) in pred_arguments
                for token in range(bnd_start, bnd_end + 1)
            }
            for g_tok, g_tok_label in tok_gold_labels:
                if (g_tok, g_tok_label) in tok_pred_labels:
                    true_pos_tok += 1
                else:
                    false_neg_tok += 1
            for p_tok, p_tok_label in tok_pred_labels:
                if (p_tok, p_tok_label) not in tok_gold_labels:
                    false_pos_tok += 1

    prec, rec, f1_score = _prf(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")
    tok_prec, tok_rec, tok_f1 = _prf(true_pos_tok, false_pos_tok, false_neg_tok)
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
# Script entry point: run the full evaluation when executed directly.
if __name__ == "__main__":
    main()