Gosse Minnema
Re-enable LOME
2890e34
import json
import logging
import os
from collections import defaultdict, namedtuple
from typing import *
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from .span_reader import SpanReader
from ..utils import Span
# logging.basicConfig(level=logging.DEBUG)
# for v in logging.Logger.manager.loggerDict.values():
# v.disabled = True
logger = logging.getLogger(__name__)
# (start, end) pair of inclusive token offsets for a located mention.
# The namedtuple typename matches the variable name so instances pickle correctly
# (pickle resolves classes by ``__qualname__``) and their repr is not confused with
# the ``Span`` class imported from ``..utils``.
SpanTuple = namedtuple('SpanTuple', ['start', 'end'])
@DatasetReader.register('better')
class BetterDatasetReader(SpanReader):
    """Dataset reader for BETTER event annotations, registered as ``'better'``.

    Reads a JSON file (or every ``.json`` file in a directory). Each entry provides
    ``segment-text-tok`` tokens and ``annotation-sets`` holding ``{eval_type}-events``
    span sets and events. Each entry becomes an AllenNLP ``Instance`` via
    ``SpanReader.prepare_inputs``.

    Every event becomes a ``Span`` over its anchor tokens. Its agents, patients and,
    optionally, referenced events become child spans labelled ``'agent'``,
    ``'patient'`` and ``'ref-event'``.
    """

    def __init__(
            self,
            eval_type,
            consolidation_strategy='first',
            span_set_type='single',
            max_argument_ss_size=1,
            use_ref_events=False,
            **extra
    ):
        """
        :param eval_type: ``'abstract'`` or ``'basic'``; selects the ``{eval_type}-events``
            annotation set and how mention offsets are obtained.
        :param consolidation_strategy: how the spans of one span set are ordered before
            truncation: ``'first'``, ``'shortest'`` or ``'longest'``.
        :param span_set_type: ``'single'`` keeps exactly one span per span set; any other
            value keeps up to ``max_argument_ss_size`` spans.
        :param max_argument_ss_size: maximum spans kept per span set when
            ``span_set_type`` is not ``'single'``.
        :param use_ref_events: if true, anchors of referenced events are attached as
            ``'ref-event'`` children.
        :param extra: forwarded to ``SpanReader.__init__``.
        """
        super().__init__(**extra)
        self.eval_type = eval_type
        assert self.eval_type in ['abstract', 'basic']
        self.consolidation_strategy = consolidation_strategy
        self.unitary_spans = span_set_type == 'single'
        # event anchors always use only the first consolidated span (see text_to_instance)
        self.max_arg_spans = max_argument_ss_size
        self.use_ref_events = use_ref_events
        # Diagnostic counters; logged and reset at the end of every file in _read_single_file.
        self.n_overlap_arg = 0
        self.n_overlap_trigger = 0
        self.n_skip = 0
        self.n_too_long = 0

    @staticmethod
    def post_process_basic_span(predicted_span, basic_entry):
        """Convert a predicted token span into a BETTER basic-format span dict.

        Offset conventions (easy to get wrong):

        * the predicted ``start_idx`` / ``end_idx`` are inclusive token indices;
        * output ``start-token`` / ``end-token`` are inclusive token indices;
        * output ``start`` / ``end`` are inclusive/exclusive character offsets.

        ``string`` and ``string-tok`` are re-extracted from the entry as a sanity check.
        """
        start_idx = predicted_span['start_idx']  # inclusive
        end_idx = predicted_span['end_idx']  # inclusive
        tok2char = basic_entry['tok2char']
        char_start_idx = tok2char[start_idx][0]  # inclusive
        char_end_idx = tok2char[end_idx][-1] + 1  # exclusive
        return {
            'string': basic_entry['segment-text'][char_start_idx:char_end_idx],
            'start': char_start_idx,
            'end': char_end_idx,
            'start-token': start_idx,
            'end-token': end_idx,
            'string-tok': basic_entry['segment-text-tok'][start_idx:end_idx + 1],
            'label': predicted_span['label'],
            'predicted': True,
        }

    @staticmethod
    def _get_shortest_span(spans):
        """Return ``spans`` ordered shortest ``string`` first; ties keep input order."""
        return sorted(spans, key=lambda span: len(span['string']))

    @staticmethod
    def _get_first_span(spans):
        """Return ``spans`` ordered by character ``start``.

        At equal starts, longer spans come first; remaining ties keep input order.
        """
        return sorted(spans, key=lambda span: (span['start'], -len(span['string'])))

    @staticmethod
    def _get_longest_span(spans):
        """Return ``spans`` ordered longest ``string`` first; ties keep input order."""
        # ``reverse=True`` preserves sort stability, so equal lengths keep their input
        # order, consistent with the other strategies.
        return sorted(spans, key=lambda span: len(span['string']), reverse=True)

    @staticmethod
    def _subfinder(text, pattern):
        """Find every occurrence of the token sequence ``pattern`` in ``text``.

        :returns: list of ``SpanTuple`` with inclusive token bounds, in order of
            occurrence (overlapping matches included). An empty pattern matches nothing.
        """
        # https://stackoverflow.com/a/12576755
        if not pattern:
            return []
        first_token = pattern[0]
        pattern_length = len(pattern)
        return [
            SpanTuple(start=i, end=i + pattern_length - 1)  # inclusive boundaries
            for i in range(len(text) - pattern_length + 1)
            if text[i] == first_token and text[i:i + pattern_length] == pattern
        ]

    def consolidate_span_set(self, spans):
        """Order one span set's spans by the configured strategy and truncate it.

        Keeps a single span when ``span_set_type == 'single'``, otherwise at most
        ``max_argument_ss_size`` spans. An empty span set yields an empty list.

        :raises NotImplementedError: for an unknown ``consolidation_strategy``.
        """
        strategies = {
            'first': BetterDatasetReader._get_first_span,
            'shortest': BetterDatasetReader._get_shortest_span,
            'longest': BetterDatasetReader._get_longest_span,
        }
        if self.consolidation_strategy not in strategies:
            raise NotImplementedError(f"{self.consolidation_strategy} does not exist")
        ordered = strategies[self.consolidation_strategy](spans)
        limit = 1 if self.unitary_spans else self.max_arg_spans
        # TODO add some sanity checks here
        return ordered[:limit]

    def get_mention_spans(self, text: List[str], span_sets: Dict):
        """Map each span-set id to a list of ``SpanTuple`` token spans (inclusive bounds).

        In ``'basic'`` mode, the annotated ``start-token`` / ``end-token`` offsets of every
        consolidated span are used directly.

        In ``'abstract'`` mode, only the first consolidated span is used. Its
        ``string-tok`` is searched for in ``text`` and the first match is kept.

        Span sets that are empty, or whose tokens cannot be found, are omitted from the
        result and counted in ``n_skip``.
        """
        mention_spans = defaultdict(list)
        for span_set_id, span_set in span_sets.items():
            consolidated_spans = self.consolidate_span_set(span_set['spans'])
            if not consolidated_spans:
                self.n_skip += 1
                continue
            if self.eval_type == 'abstract':
                span_indices = BetterDatasetReader._subfinder(
                    text=text, pattern=consolidated_spans[0]['string-tok'])
                if not span_indices:
                    self.n_skip += 1
                    continue
                # Store a list (like the basic branch) so callers can index [0] and iterate.
                mention_spans[span_set_id].append(span_indices[0])
            else:
                # in basic, the token offsets are already in the right (inclusive) form
                for span in consolidated_spans:
                    mention_spans[span_set_id].append(
                        SpanTuple(start=span['start-token'], end=span['end-token']))
        return mention_spans

    def _read_single_file(self, file_path):
        """Yield one ``Instance`` per entry of a JSON file, then log and reset the counters.

        If the file has a top-level ``'entries'`` mapping, entries are read from it and
        ``is_training`` is ``'train' in file_path``. Otherwise the top-level mapping itself
        holds the entries, and they are read with ``is_training=True``.
        """
        with open(file_path, encoding='utf-8') as fp:
            json_content = json.load(fp)
        if 'entries' in json_content:
            entries = json_content['entries']
            is_training = 'train' in file_path
        else:  # TODO why is this split in 2 cases?
            entries = json_content
            is_training = True
        for entry in entries.values():
            yield self.text_to_instance(entry, is_training)
        logger.warning(f'{self.n_overlap_arg} overlapped args detected!')
        logger.warning(f'{self.n_overlap_trigger} overlapped triggers detected!')
        logger.warning(f'{self.n_skip} skipped detected!')
        logger.warning(f'{self.n_too_long} were skipped because they are too long!')
        self.n_overlap_arg = self.n_skip = self.n_too_long = self.n_overlap_trigger = 0

    def _read(self, file_path: str) -> Iterable[Instance]:
        """Read ``file_path``, or every ``.json`` file in it (in sorted order) if it is a directory."""
        if os.path.isdir(file_path):
            for fn in sorted(os.listdir(file_path)):
                if not fn.endswith('.json'):
                    logger.info(f'Skipping {fn}')
                    continue
                logger.info(f'Loading from {fn}')
                yield from self._read_single_file(os.path.join(file_path, fn))
        else:
            yield from self._read_single_file(file_path)

    def text_to_instance(self, entry, is_training=False):
        """Build an ``Instance`` from one BETTER entry.

        Each event contributes a ``Span`` over the first consolidated span of its anchor
        span set. Agents, patients and, if ``use_ref_events``, referenced-event anchors
        are attached to it as child spans.

        Two kinds of overlap are dropped and counted:

        * an event whose anchor overlaps an earlier event's anchor is skipped entirely;
        * within one event, an argument span overlapping an earlier argument is dropped.

        :param entry: one entry containing ``segment-text-tok`` and ``annotation-sets``.
        :param is_training: accepted for interface compatibility; not used here.
        """
        word_tokens = entry['segment-text-tok']
        event_annotations = entry['annotation-sets'][f'{self.eval_type}-events']
        # span sets have been trimmed by consolidate_span_set
        spans = self.get_mention_spans(word_tokens, event_annotations['span-sets'])
        better_events = event_annotations['events']

        # Per-entry argument counters, reported at the end of this method.
        self._local_child_overlap = 0
        self._local_child_total = 0

        # Pass 1: mark events whose anchor overlaps a previously accepted anchor.
        all_trigger_idxs = set()  # token indices covered by accepted event anchors
        skipped_events = set()
        for event_id, event in better_events.items():
            assert event['anchors'] in spans
            # take the first consolidated span for anchors
            anchor_start, anchor_end = spans[event['anchors']][0]
            anchor_idxs = range(anchor_start, anchor_end + 1)
            if any(ix in all_trigger_idxs for ix in anchor_idxs):
                logger.warning(
                    f"Skipped {event_id} with anchor span {event['anchors']}, overlaps a previously found event trigger/anchor")
                self.n_overlap_trigger += 1
                skipped_events.add(event_id)
                continue
            all_trigger_idxs.update(anchor_idxs)  # record the trigger

        def add_children(parent, occupied_idxs, span_id, label):
            # Attach every span of span set `span_id` to `parent` as a child labelled
            # `label`, dropping spans that overlap an argument already attached to it.
            # TODO this is a bad way to do this
            assert span_id in spans
            for arg_start, arg_end in spans[span_id]:
                self._local_child_total += 1
                arg_idxs = range(arg_start, arg_end + 1)
                if any(ix in occupied_idxs for ix in arg_idxs):
                    self.n_overlap_arg += 1
                    self._local_child_overlap += 1
                    continue
                occupied_idxs.update(arg_idxs)
                parent.add_child(Span(arg_start, arg_end, label, False))

        # Pass 2: build the model inputs for the remaining events.
        input_spans = []
        for event_id, event in better_events.items():
            if event_id in skipped_events:
                continue
            local_arg_idxs = set()  # token indices covered by this event's arguments
            anchor_start, anchor_end = spans[event['anchors']][0]
            event_span = Span(anchor_start, anchor_end, event['event-type'], True)
            input_spans.append(event_span)
            for agent in event['agents']:
                add_children(event_span, local_arg_idxs, agent, 'agent')
            for patient in event['patients']:
                add_children(event_span, local_arg_idxs, patient, 'patient')
            if self.use_ref_events:
                for ref_event in event['ref-events']:
                    if ref_event in skipped_events:
                        continue
                    ref_event_anchor_id = better_events[ref_event]['anchors']
                    add_children(event_span, local_arg_idxs, ref_event_anchor_id, 'ref-event')

        fields = self.prepare_inputs(word_tokens, spans=input_spans)
        if self._local_child_overlap > 0:
            logger.warning(
                f"Skipped {self._local_child_overlap} / {self._local_child_total} argument spans due to overlaps")
        return Instance(fields)