"""Run a combined frame-semantic parsing pipeline over plain-text files.

Each ``*.txt`` file in the input folder is read as one sentence per line,
tokenized with spaCy, and passed through four ``SpanPredictor`` stages
(target spans, frame labels, role boundaries, role labels). The results are
written per input file as a token-per-line text file and as a JSON dump.
"""

from typing import Any, Dict, Iterator, List, Optional
import dataclasses
import glob
import os
import sys
import json

import spacy
from spacy.language import Language

from sftp import SpanPredictor


# Frame labels that are skipped instead of being treated as a real frame.
# NOTE(review): the original code compared only against the spelling with
# three trailing "@"; sftp appears to use "@@VIRTUAL_ROOT@@" -- confirm which
# spelling the models actually emit and drop the unused variant.
_VIRTUAL_ROOT_LABELS = ("@@VIRTUAL_ROOT@@@", "@@VIRTUAL_ROOT@@")

# Model archives loaded by `main`.
_ZS_MODEL_PATH = "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz"
_EV_MODEL_PATH = "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz"

# Folder (relative to the working directory) that receives the output files.
_DEFAULT_OUTPUT_FOLDER = "../../data-out"


@dataclasses.dataclass
class FrameAnnotation:
    """Tokens of one sentence together with their part-of-speech tags."""

    tokens: List[str] = dataclasses.field(default_factory=list)
    pos: List[str] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Sentence annotation carrying a list of labels for every token."""

    # frame_list[i] holds all labels attached to token i (possibly none).
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    # lu_list[i] is the lexical-unit string for token i, or None.
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self) -> Iterator[str]:
        """Yield one space-separated line per token.

        Each line has the form ``token pos labels lu``, where ``labels`` is
        the token's labels joined with ``|``. An empty label list or a
        missing lexical unit is rendered as ``_``.
        """
        for i, tok in enumerate(self.tokens):
            labels = "|".join(self.frame_list[i]) or "_"
            lexical_unit = self.lu_list[i] or "_"
            yield f"{tok} {self.pos[i]} {labels} {lexical_unit}"


def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """Flatten frame structures into per-token label lists.

    Args:
        sentence: Tokens of the sentence; only its length is used.
        structures: Maps a structure id to a dict with the keys ``"target"``
            (inclusive ``(start, end)`` token span), ``"frame"`` (frame
            name) and ``"roles"`` (list of dicts with ``"boundary"``, an
            inclusive token span, and ``"label"``).

    Returns:
        One list of labels per token. Target tokens receive
        ``T:<frame>@<id>``; role tokens receive ``B:<frame>:<role>@<id>`` on
        the first token of the role span and ``I:...`` on the remaining
        ones. The id is zero-padded to two digits.
    """
    labels: List[List[str]] = [[] for _ in sentence]

    for struct_id, struct in structures.items():
        tgt_span = struct["target"]
        frame = struct["frame"]

        # Span ends are inclusive, hence the "+ 1".
        for i in range(tgt_span[0], tgt_span[1] + 1):
            labels[i].append(f"T:{frame}@{struct_id:02}")

        for role in struct["roles"]:
            role_span = role["boundary"]
            role_label = role["label"]
            for i in range(role_span[0], role_span[1] + 1):
                prefix = "B" if i == role_span[0] else "I"
                labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")

    return labels


def predict_combined(
    spacy_model: Language,
    sentences: List[str],
    tgt_predictor: SpanPredictor,
    frm_predictor: SpanPredictor,
    bnd_predictor: SpanPredictor,
    arg_predictor: SpanPredictor,
) -> List[MultiLabelAnnotation]:
    """Annotate sentences by chaining four span predictors.

    For every sentence the text is stripped and tokenized with
    ``spacy_model``. ``tgt_predictor`` supplies the target spans,
    ``frm_predictor`` a frame label for each target, ``bnd_predictor`` the
    role boundaries for that target/frame pair, and ``arg_predictor`` a
    label for each boundary. Roles labelled ``"Target"`` are dropped.

    Args:
        spacy_model: spaCy pipeline used for tokenization and POS tagging.
        sentences: Raw sentence strings; surrounding whitespace is stripped.
        tgt_predictor: Predictor used to find target spans.
        frm_predictor: Predictor used to label a target span with a frame.
        bnd_predictor: Predictor used to find role boundaries.
        arg_predictor: Predictor used to label the role boundaries.

    Returns:
        One ``MultiLabelAnnotation`` per input sentence, in input order.
        ``lu_list`` is filled with ``None`` for every token.
    """
    annotations_out = []

    for sent_idx, sent in enumerate(sentences):
        sent = sent.strip()
        print(f"Processing sent with idx={sent_idx}: {sent}")

        doc = spacy_model(sent)
        sent_tokens = [t.text for t in doc]

        tgt_spans, _, _ = tgt_predictor.force_decode(sent_tokens)

        frame_structures = {}
        for i, span in enumerate(tgt_spans):
            span = tuple(span)
            _, fr_labels, _ = frm_predictor.force_decode(sent_tokens, child_spans=[span])
            frame = fr_labels[0]
            # A virtual-root label is not a frame; skip this target.
            if frame in _VIRTUAL_ROOT_LABELS:
                continue

            boundaries, _, _ = bnd_predictor.force_decode(
                sent_tokens, parent_span=span, parent_label=frame
            )
            _, arg_labels, _ = arg_predictor.force_decode(
                sent_tokens, parent_span=span, parent_label=frame, child_spans=boundaries
            )

            # Keyed by the target's index so ids stay stable even when
            # earlier targets were skipped.
            frame_structures[i] = {
                "target": span,
                "frame": frame,
                "roles": [
                    {"boundary": bnd, "label": label}
                    for bnd, label in zip(boundaries, arg_labels)
                    if label != "Target"
                ],
            }

        annotations_out.append(MultiLabelAnnotation(
            tokens=sent_tokens,
            pos=[t.pos_ for t in doc],
            frame_list=convert_to_seq_labels(sent_tokens, frame_structures),
            lu_list=[None] * len(sent_tokens),
        ))

    return annotations_out


def main(input_folder, output_folder=_DEFAULT_OUTPUT_FOLDER):
    """Annotate every ``*.txt`` file in ``input_folder``.

    For each input file ``<name>.txt`` two files are written into
    ``output_folder``: ``<name>.combined_zs_ev.tc_bilstm.txt`` (one token per
    line, blank line after each sentence) and
    ``<name>.combined_zs_ev.tc_bilstm.json`` (list of annotation dicts).

    Args:
        input_folder: Folder searched (non-recursively) for ``*.txt`` files,
            each read as one sentence per line.
        output_folder: Folder receiving the output files; created if it does
            not exist. Defaults to ``../../data-out``.
    """
    print("Loading spaCy model ...")
    nlp = spacy.load("it_core_news_md")

    print("Loading predictors ...")
    zs_predictor = SpanPredictor.from_path(_ZS_MODEL_PATH, cuda_device=0)
    ev_predictor = SpanPredictor.from_path(_EV_MODEL_PATH, cuda_device=0)

    # Create the output folder up front so a missing directory does not
    # abort the run after the (expensive) prediction step.
    os.makedirs(output_folder, exist_ok=True)

    print("Reading input files ...")
    # Sorted so that files are processed in a deterministic order.
    for file in sorted(glob.glob(os.path.join(input_folder, "*.txt"))):
        print(file)
        with open(file, encoding="utf-8") as f:
            sentences = list(f)

        annotations = predict_combined(
            nlp, sentences, zs_predictor, ev_predictor, ev_predictor, ev_predictor
        )

        out_name = os.path.splitext(os.path.basename(file))[0]
        out_base = os.path.join(output_folder, f"{out_name}.combined_zs_ev.tc_bilstm")

        # Text mode already translates "\n" to the platform line ending, so
        # os.linesep must not be written here (it would yield "\r\r\n" on
        # Windows).
        with open(f"{out_base}.txt", "w", encoding="utf-8") as f_out:
            for ann in annotations:
                for line in ann.to_txt():
                    f_out.write(line + "\n")
                f_out.write("\n")

        with open(f"{out_base}.json", "w", encoding="utf-8") as f_out:
            json.dump([dataclasses.asdict(ann) for ann in annotations], f_out)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} INPUT_FOLDER")
    main(sys.argv[1])