import json
import logging
import os
from collections import defaultdict, namedtuple
from typing import Dict, Iterable, List

from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance

from .span_reader import SpanReader
from ..utils import Span

# logging.basicConfig(level=logging.DEBUG)
# for v in logging.Logger.manager.loggerDict.values():
#     v.disabled = True
logger = logging.getLogger(__name__)

# named 'SpanTuple' (not 'Span') to avoid confusion with the Span class from ..utils
SpanTuple = namedtuple('SpanTuple', ['start', 'end'])


class BetterDatasetReader(SpanReader):
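    """Reads documents in the BETTER event format and produces AllenNLP ``Instance``s.

    Each JSON entry carries tokenized segment text plus an ``abstract-events`` or
    ``basic-events`` annotation set; events are converted into anchor ``Span``s
    with agent/patient (and optionally ref-event) child spans.
    """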

    def __init__(
            self,
            eval_type,
            consolidation_strategy='first',
            span_set_type='single',
            max_argument_ss_size=1,
            use_ref_events=False,
            **extra
    ):
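        """
        :param eval_type: either 'abstract' or 'basic'
        :param consolidation_strategy: how to order spans within a span set:
            'first' (earliest start), 'shortest', or 'longest'
        :param span_set_type: 'single' keeps only one span per span set
        :param max_argument_ss_size: max spans kept per argument span set
            (only relevant when span_set_type is not 'single')
        :param use_ref_events: if True, add referenced events' anchors as
            'ref-event' children
        :param extra: passed through to SpanReader
        """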
        super().__init__(**extra)
        self.eval_type = eval_type
        assert self.eval_type in ['abstract', 'basic']

        self.consolidation_strategy = consolidation_strategy
        self.unitary_spans = span_set_type == 'single'
        # event anchors are always singleton spans
        self.max_arg_spans = max_argument_ss_size
        self.use_ref_events = use_ref_events

        # diagnostic counters, reset after each file is read
        self.n_overlap_arg = 0
        self.n_overlap_trigger = 0
        self.n_skip = 0
        self.n_too_long = 0

    @staticmethod
    def post_process_basic_span(predicted_span, basic_entry):
        """Convert predicted token offsets back to character offsets, and recover
        the span text as a sanity check.

        Offset conventions:
          - the span finder outputs inclusive token indices
          - character offsets are inclusive-exclusive
          - token offsets are inclusive-inclusive
        """
        start_idx = predicted_span['start_idx']  # inclusive
        end_idx = predicted_span['end_idx']  # inclusive

        char_start_idx = basic_entry['tok2char'][start_idx][0]  # inclusive
        char_end_idx = basic_entry['tok2char'][end_idx][-1] + 1  # exclusive

        span_text = basic_entry['segment-text'][char_start_idx:char_end_idx]
        span_text_tok = basic_entry['segment-text-tok'][start_idx:end_idx + 1]

        span = {
            'string': span_text,
            'start': char_start_idx,
            'end': char_end_idx,
            'start-token': start_idx,
            'end-token': end_idx,
            'string-tok': span_text_tok,
            'label': predicted_span['label'],
            'predicted': True,
        }
        return span

    @staticmethod
    def _get_shortest_span(spans):
        # sort spans shortest-first by surface string length
        return sorted(spans, key=lambda span: len(span['string']))

    @staticmethod
    def _get_first_span(spans):
        # sort spans by start offset, breaking ties in favor of the longest span
        return sorted(spans, key=lambda span: (span['start'], -len(span['string'])))

    @staticmethod
    def _get_longest_span(spans):
        # sort spans longest-first by surface string length
        return sorted(spans, key=lambda span: len(span['string']), reverse=True)

    @staticmethod
    def _subfinder(text, pattern):
        # find every occurrence of the token list `pattern` in the token list `text`
        # https://stackoverflow.com/a/12576755
        matches = []
        if not pattern:
            return matches
        pattern_length = len(pattern)
        for i, token in enumerate(text):
            if token == pattern[0] and text[i:i + pattern_length] == pattern:
                matches.append(SpanTuple(start=i, end=i + pattern_length - 1))  # inclusive boundaries
        return matches

    def consolidate_span_set(self, spans):
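        """Order the spans of one span set by self.consolidation_strategy and
        truncate to a single span (or to self.max_arg_spans spans)."""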
        if self.consolidation_strategy == 'first':
            spans = BetterDatasetReader._get_first_span(spans)
        elif self.consolidation_strategy == 'shortest':
            spans = BetterDatasetReader._get_shortest_span(spans)
        elif self.consolidation_strategy == 'longest':
            spans = BetterDatasetReader._get_longest_span(spans)
        else:
            raise NotImplementedError(f"Unknown consolidation strategy: {self.consolidation_strategy}")

        if self.unitary_spans:
            spans = [spans[0]]
        else:
            spans = spans[:self.max_arg_spans]

        # TODO add some sanity checks here
        return spans

    def get_mention_spans(self, text: List[str], span_sets: Dict):
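        """Map each span set id to a list of SpanTuples with token offsets.

        For 'abstract' annotations the consolidated span is located by searching
        for its tokens in `text`; for 'basic', the provided token offsets are
        used directly.
        """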
        mention_spans = defaultdict(list)
        for span_set_id, span_set in span_sets.items():
            spans = span_set['spans']
            consolidated_spans = self.consolidate_span_set(spans)

            if self.eval_type == 'abstract':
                span = consolidated_spans[0]
                span_tokens = span['string-tok']

                span_indices = BetterDatasetReader._subfinder(text=text, pattern=span_tokens)
                if len(span_indices) == 0:
                    # span text not found in the segment; skip this span set
                    continue
                # keep only the first match; appending keeps the value a list,
                # consistent with the 'basic' branch below
                mention_spans[span_set_id].append(span_indices[0])
            else:
                # in basic, we already have token offsets in the right form,
                # so we use those token offsets directly
                for span in consolidated_spans:
                    mention_spans[span_set_id].append(SpanTuple(start=span['start-token'], end=span['end-token']))
        return mention_spans

    def _read_single_file(self, file_path):
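        """Yield one Instance per entry in a single BETTER JSON file, then log
        and reset the overlap/skip counters."""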
        with open(file_path) as fp:
            json_content = json.load(fp)
            if 'entries' in json_content:
                for doc_name, entry in json_content['entries'].items():
                    # treat files with 'train' in their name as training data
                    instance = self.text_to_instance(entry, 'train' in file_path)
                    yield instance
            else:  # TODO why is this split into 2 cases?
                for doc_name, entry in json_content.items():
                    instance = self.text_to_instance(entry, True)
                    yield instance

        logger.warning(f'{self.n_overlap_arg} overlapping args detected!')
        logger.warning(f'{self.n_overlap_trigger} overlapping triggers detected!')
        logger.warning(f'{self.n_skip} skipped instances detected!')
        logger.warning(f'{self.n_too_long} instances were skipped because they were too long!')
        self.n_overlap_arg = self.n_skip = self.n_too_long = self.n_overlap_trigger = 0

    def _read(self, file_path: str) -> Iterable[Instance]:
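        """Read instances from a single JSON file, or from every .json file in a
        directory."""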
        if os.path.isdir(file_path):
            for fn in os.listdir(file_path):
                if not fn.endswith('.json'):
                    logger.info(f'Skipping {fn}')
                    continue
                logger.info(f'Loading from {fn}')
                yield from self._read_single_file(os.path.join(file_path, fn))
        else:
            yield from self._read_single_file(file_path)

    def text_to_instance(self, entry, is_training=False):
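        """Build an Instance from one BETTER entry: one Span per event anchor,
        with agent/patient (and optionally ref-event) children, skipping events
        and arguments whose spans overlap previously added ones.

        Note: `is_training` is currently unused in this method.
        """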
        word_tokens = entry['segment-text-tok']

        # span sets, consolidated according to self.consolidation_strategy
        spans = self.get_mention_spans(
            word_tokens, entry['annotation-sets'][f'{self.eval_type}-events']['span-sets']
        )

        # index of every token that is part of an event trigger/anchor span
        all_trigger_idxs = set()

        # actual inputs to the model
        input_spans = []

        self._local_child_overlap = 0
        self._local_child_total = 0

        better_events = entry['annotation-sets'][f'{self.eval_type}-events']['events']

        skipped_events = set()
        # find events whose anchors overlap another event's anchor; they are skipped below
        for event_id, event in better_events.items():
            assert event['anchors'] in spans
            # take the first consolidated span for anchors
            anchor_start, anchor_end = spans[event['anchors']][0]
            if any(ix in all_trigger_idxs for ix in range(anchor_start, anchor_end + 1)):
                logger.warning(
                    f"Skipped {event_id} with anchor span {event['anchors']}, "
                    f"overlaps a previously found event trigger/anchor")
                self.n_overlap_trigger += 1
                skipped_events.add(event_id)
                continue

            all_trigger_idxs.update(range(anchor_start, anchor_end + 1))  # record the trigger

        for event_id, event in better_events.items():
            if event_id in skipped_events:
                continue

            # argument token indices for just this event
            local_arg_idxs = set()

            # take the first consolidated span for anchors
            anchor_start, anchor_end = spans[event['anchors']][0]

            event_span = Span(anchor_start, anchor_end, event['event-type'], True)
            input_spans.append(event_span)

            def add_a_child(span_id, label):
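                # Attach every consolidated span in `span_id` as a child of the
                # current event span, skipping any that overlap an argument
                # already added to this event.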
                # TODO this is a bad way to do this
                assert span_id in spans
                for child_span in spans[span_id]:
                    self._local_child_total += 1
                    arg_start, arg_end = child_span

                    if any(ix in local_arg_idxs for ix in range(arg_start, arg_end + 1)):
                        # this argument overlaps a previously found argument; skip it
                        self.n_overlap_arg += 1
                        self._local_child_overlap += 1
                        continue

                    local_arg_idxs.update(range(arg_start, arg_end + 1))
                    event_span.add_child(Span(arg_start, arg_end, label, False))

            for agent in event['agents']:
                add_a_child(agent, 'agent')
            for patient in event['patients']:
                add_a_child(patient, 'patient')

            if self.use_ref_events:
                for ref_event in event['ref-events']:
                    if ref_event in skipped_events:
                        continue
                    ref_event_anchor_id = better_events[ref_event]['anchors']
                    add_a_child(ref_event_anchor_id, 'ref-event')
        fields = self.prepare_inputs(word_tokens, spans=input_spans)

        if self._local_child_overlap > 0:
            logger.warning(
                f"Skipped {self._local_child_overlap} / {self._local_child_total} argument spans due to overlaps")
        return Instance(fields)
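

# A minimal usage sketch (hypothetical file path; any SpanReader-specific
# constructor arguments would be passed through **extra and are omitted here):
#
#     reader = BetterDatasetReader(eval_type='basic', consolidation_strategy='first')
#     for instance in reader.read('data/better/basic-train.json'):
#         ...  # feed instances to an AllenNLP model or data loader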