"""Evaluate span-finder models on frame-semantic parsing data.

Three evaluations are run per model, each printing its scores to stdout:

* FD -- frame detection (accuracy of the frame predicted for the predicate),
* BD -- argument boundary detection (span-level and token-level P/R/F),
* AC -- argument classification (span-level and token-level P/R/F).

The input is a JSONL file; each line is read as an object with a ``tokens``
list and an ``annotations`` list, of which only the first annotation is used.
"""

import json
from typing import List, Tuple

import pandas as pd

from sftp import SpanPredictor


def main():
    """Run all evaluations for every configured model on the chosen data file."""
    # data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_dev.jsonl"
    # data_file = "/home/p289731/cloned/lome/preproc/svm_challenge.jsonl"
    data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_test.jsonl"

    models = [
        (
            "lome-en",
            "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
        ),
        (
            "lome-it-best",
            "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz",
        ),
        # (
        #     "lome-it-freeze",
        #     "/data/p289731/cloned/lome/train-evalita-plus-fn-freeze/model.tar.gz",
        # ),
        # (
        #     "lome-it-mono",
        #     "/data/p289731/cloned/lome/train-evalita-it_mono/model.tar.gz",
        # ),
    ]

    for (model_name, model_path) in models:
        print("testing model: ", model_name)
        predictor = SpanPredictor.from_path(model_path)

        print("=== FD (run 1) ===")
        eval_frame_detection(data_file, predictor, model_name=model_name)

        for run in [1, 2]:
            print(f"=== BD (run {run}) ===")
            eval_boundary_detection(data_file, predictor, run=run)

        for run in [1, 2, 3]:
            print(f"=== AC (run {run}) ===")
            eval_argument_classification(data_file, predictor, run=run)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _precision_recall_f1(true_pos, false_pos, false_neg):
    """Return ``(precision, recall, f1)`` for the given counts.

    Any ratio whose denominator is zero (e.g. nothing was predicted, or the
    data file was empty) is reported as 0.0 instead of raising
    ``ZeroDivisionError``.
    """
    prec = _safe_ratio(true_pos, true_pos + false_pos)
    rec = _safe_ratio(true_pos, true_pos + false_neg)
    f1_score = 2 * ((prec * rec) / (prec + rec)) if (prec + rec) else 0.0
    return prec, rec, f1_score


def _span_tokens(spans):
    """Return the set of token indices covered by inclusive ``(start, stop)`` spans."""
    return {
        tok_idx
        for (start, stop) in spans
        for tok_idx in range(start, stop + 1)
    }


def _iter_examples(data_file):
    """Yield ``(sent_id, tokens, annotation, predicate_span, predicate)`` per line.

    ``sent_id`` is the zero-based line index in the file. Blank lines are
    skipped so that e.g. a trailing newline does not break JSON parsing.
    Only the first entry of ``annotations`` is considered for each sentence.
    """
    with open(data_file, encoding="utf-8") as f:
        for sent_id, line in enumerate(f):
            if not line.strip():
                continue
            sent_data = json.loads(line)
            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]
            predicate_span = tuple(annotation["span"])
            # Spans are treated as inclusive on both ends, hence the +1.
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
            yield sent_id, tokens, annotation, predicate_span, predicate


def predict_frame(
    predictor: SpanPredictor, tokens: List[str], predicate_span: Tuple[int, int]
):
    """Return the first label decoded for ``predicate_span``.

    The predicate span is passed to ``force_decode`` as the only child span;
    the first returned label is taken as the frame prediction.
    """
    _, labels, _ = predictor.force_decode(tokens, child_spans=[predicate_span])
    return labels[0]


def eval_frame_detection(data_file, predictor, verbose=False, model_name="_"):
    """Evaluate frame prediction accuracy and dump per-sentence output to CSV.

    Prints ``ACC = <score>`` and writes
    ``frame_prediction_output_{model_name}_{section}.csv`` to the current
    directory, where the section is ``rai``, ``dev`` or ``test`` depending on
    substrings found in ``data_file``.

    Args:
        data_file: path to the JSONL evaluation file.
        predictor: loaded span predictor.
        verbose: if true, print gold and predicted frame for every sentence.
        model_name: tag used in the output CSV file name.
    """
    # NOTE: "false_pos" counts every sentence whose frame was predicted wrongly.
    true_pos = 0
    false_pos = 0
    out = []

    for sent_id, tokens, annotation, predicate_span, predicate in _iter_examples(
        data_file
    ):
        frame_gold = annotation["label"]
        frame_pred = predict_frame(predictor, tokens, predicate_span)

        if frame_pred == frame_gold:
            true_pos += 1
        else:
            false_pos += 1

        out.append({
            "sentence": " ".join(tokens),
            "predicate": predicate,
            "frame_gold": frame_gold,
            "frame_pred": frame_pred
        })

        if verbose:
            print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
            print(f"\tpredicate: {predicate}")
            print(f"\t gold: {frame_gold}")
            print(f"\tpredicted: {frame_pred}")
            print()

    # Guarded so that an empty data file reports 0.0 instead of crashing.
    acc_score = _safe_ratio(true_pos, true_pos + false_pos)
    print("ACC =", acc_score)

    data_sect = (
        "rai"
        if "svm_challenge" in data_file
        else "dev" if "dev" in data_file else "test"
    )
    df_out = pd.DataFrame(out)
    df_out.to_csv(f"frame_prediction_output_{model_name}_{data_sect}.csv")


def predict_boundaries(predictor: SpanPredictor, tokens, predicate_span, frame):
    """Return the decoded argument spans (as tuples) for a predicate and frame.

    The span equal to ``predicate_span`` with label ``"Target"`` is left out,
    so only argument spans remain.
    """
    boundaries, labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame
    )
    out = []
    for bnd, lab in zip(boundaries, labels):
        bnd = tuple(bnd)
        if bnd == predicate_span and lab == "Target":
            continue
        out.append(bnd)
    return out


def get_gold_boundaries(annotation, predicate_span):
    """Return the set of gold child spans, excluding the predicate's ``Target``."""
    return {
        tuple(c["span"])
        for c in annotation["children"]
        if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
    }


def eval_boundary_detection(data_file, predictor, run=1, verbose=False):
    """Evaluate argument boundary detection at span and token level.

    Args:
        data_file: path to the JSONL evaluation file.
        predictor: loaded span predictor.
        run: 1 to condition on the predicted frame, 2 to use the gold frame.
        verbose: if true, print per-sentence spans and counts.

    Prints span-level ``P/R/F`` and token-level ``Pt/Rt/Ft``.
    """
    assert run in [1, 2]

    true_pos = 0
    false_pos = 0
    false_neg = 0

    true_pos_tok = 0
    false_pos_tok = 0
    false_neg_tok = 0

    for sent_id, tokens, annotation, predicate_span, predicate in _iter_examples(
        data_file
    ):
        if run == 1:
            frame = predict_frame(predictor, tokens, predicate_span)
        else:
            frame = annotation["label"]

        boundaries_gold = get_gold_boundaries(annotation, predicate_span)
        boundaries_pred = set(
            predict_boundaries(predictor, tokens, predicate_span, frame)
        )

        # Span level: a span only counts when both ends match exactly.
        sent_true_pos = len(boundaries_gold & boundaries_pred)
        sent_false_pos = len(boundaries_pred - boundaries_gold)
        sent_false_neg = len(boundaries_gold - boundaries_pred)
        true_pos += sent_true_pos
        false_pos += sent_false_pos
        false_neg += sent_false_neg

        # Token level: compare the sets of token indices covered by any span.
        boundary_toks_gold = _span_tokens(boundaries_gold)
        boundary_toks_pred = _span_tokens(boundaries_pred)
        sent_tok_true_pos = len(boundary_toks_gold & boundary_toks_pred)
        sent_tok_false_pos = len(boundary_toks_pred - boundary_toks_gold)
        sent_tok_false_neg = len(boundary_toks_gold - boundary_toks_pred)
        true_pos_tok += sent_tok_true_pos
        false_pos_tok += sent_tok_false_pos
        false_neg_tok += sent_tok_false_neg

        if verbose:
            print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
            print(f"\tpredicate: {predicate}")
            print(f"\t frame: {frame}")
            print(f"\t gold: {boundaries_gold}")
            print(f"\tpredicted: {boundaries_pred}")
            print(f"\ttp={sent_true_pos}\tfp={sent_false_pos}\tfn={sent_false_neg}")
            print(
                f"\ttp_t={sent_tok_true_pos}\tfp_t={sent_tok_false_pos}\tfn_t={sent_tok_false_neg}"
            )
            print()

    prec, rec, f1_score = _precision_recall_f1(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")

    tok_prec, tok_rec, tok_f1 = _precision_recall_f1(
        true_pos_tok, false_pos_tok, false_neg_tok
    )
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")


def predict_arguments(
    predictor: SpanPredictor, tokens, predicate_span, frame, boundaries
):
    """Return ``(span, label)`` pairs for the given argument boundaries.

    The boundaries are sorted by start index and handed to ``force_decode`` as
    child spans; returned labels are paired with the sorted spans by position
    (assumes labels come back in the same order as ``child_spans`` -- TODO
    confirm against the predictor). A span equal to ``predicate_span`` that is
    labelled ``"Target"`` is dropped.
    """
    boundaries = list(sorted(boundaries, key=lambda t: t[0]))
    _, labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame, child_spans=boundaries
    )
    out = []
    for bnd, lab in zip(boundaries, labels):
        if bnd == predicate_span and lab == "Target":
            continue
        out.append((bnd, lab))
    return out


def eval_argument_classification(data_file, predictor, run=1, verbose=False):
    """Evaluate argument labelling at span and token level.

    Args:
        data_file: path to the JSONL evaluation file.
        predictor: loaded span predictor.
        run: 1 = predicted frame and predicted boundaries,
            2 = gold frame and predicted boundaries,
            3 = gold frame and gold boundaries.
        verbose: if true, print gold and predicted arguments per sentence.

    Prints span-level ``P/R/F`` and token-level ``Pt/Rt/Ft``.
    """
    assert run in [1, 2, 3]

    true_pos = 0
    false_pos = 0
    false_neg = 0

    true_pos_tok = 0
    false_pos_tok = 0
    false_neg_tok = 0

    for sent_id, tokens, annotation, predicate_span, predicate in _iter_examples(
        data_file
    ):
        # gold or predicted frames?
        if run == 1:
            frame = predict_frame(predictor, tokens, predicate_span)
        else:
            frame = annotation["label"]

        # gold or predicted argument boundaries?
        if run in [1, 2]:
            boundaries = set(
                predict_boundaries(predictor, tokens, predicate_span, frame)
            )
        else:
            boundaries = get_gold_boundaries(annotation, predicate_span)

        pred_arguments = predict_arguments(
            predictor, tokens, predicate_span, frame, boundaries
        )
        gold_arguments = {
            (tuple(c["span"]), c["label"])
            for c in annotation["children"]
            if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
        }

        if verbose:
            print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
            print(f"\tpredicate: {predicate}")
            print(f"\t frame: {frame}")
            print(f"\t gold: {gold_arguments}")
            print(f"\tpredicted: {pred_arguments}")
            print()

        # -- full spans version
        # Predicted spans originate from a set, so the (span, label) pairs are
        # unique and set arithmetic yields the same counts as per-item checks.
        pred_argument_set = set(pred_arguments)
        # true positive: found the span and labeled it correctly
        true_pos += len(gold_arguments & pred_argument_set)
        # false negative: missed this argument
        false_neg += len(gold_arguments - pred_argument_set)
        # all predictions that are not true positives are false positives
        false_pos += len(pred_argument_set - gold_arguments)

        # -- token based
        tok_gold_labels = {
            (token, label)
            for ((bnd_start, bnd_end), label) in gold_arguments
            for token in range(bnd_start, bnd_end + 1)
        }
        tok_pred_labels = {
            (token, label)
            for ((bnd_start, bnd_end), label) in pred_arguments
            for token in range(bnd_start, bnd_end + 1)
        }
        true_pos_tok += len(tok_gold_labels & tok_pred_labels)
        false_neg_tok += len(tok_gold_labels - tok_pred_labels)
        false_pos_tok += len(tok_pred_labels - tok_gold_labels)

    prec, rec, f1_score = _precision_recall_f1(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")

    tok_prec, tok_rec, tok_f1 = _precision_recall_f1(
        true_pos_tok, false_pos_tok, false_neg_tok
    )
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")


if __name__ == "__main__":
    main()