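"""Span dataset reader.

Retokenizes word-level sentences with a pretrained transformer tokenizer and
converts (possibly nested) span trees into AllenNLP fields, including one BIO
tag sequence per parent span.
"""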
import logging
from abc import ABC
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from allennlp.common.util import END_SYMBOL
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
from allennlp.data.fields import ArrayField, Field, LabelField, ListField, MetadataField, TextField
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer, Token
from ..utils import Span, BIOSmoothing, apply_bio_smoothing
logger = logging.getLogger(__name__)
@DatasetReader.register('span')
class SpanReader(DatasetReader, ABC):
def __init__(
self,
pretrained_model: str,
max_length: int = 512,
ignore_label: bool = False,
debug: bool = False,
**extras
) -> None:
"""
:param pretrained_model: The name of the pretrained model. E.g. xlm-roberta-large
:param max_length: Sequences longer than this limit will be truncated.
:param ignore_label: If True, label on spans will be anonymized.
:param debug: True to turn on debugging mode.
:param span_proposals: Needed for "enumeration" scheme, but not needed for "BIO".
If True, it will try to enumerate candidate spans in the sentence, which will then be fed into
a binary classifier (EnumSpanFinder).
Note: It might take time to propose spans. And better to use SpacyTokenizer if you want to call
constituency parser or dependency parser.
:param maximum_negative_spans: Necessary for EnumSpanFinder.
:param extras: Args to DatasetReader.
"""
super().__init__(**extras)
self.word_indexer = {
'pieces': PretrainedTransformerIndexer(pretrained_model, namespace='pieces')
}
self._pretrained_model_name = pretrained_model
self.debug = debug
self.ignore_label = ignore_label
self._pretrained_tokenizer = PretrainedTransformerTokenizer(pretrained_model)
self.max_length = max_length
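        # Number of overlapping spans dropped by the most recent prepare_inputs() call.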
self.n_span_removed = 0
def retokenize(
self, sentence: List[str], truncate: bool = True
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
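        # intra_word_tokenize() adds the model's special tokens and returns, for each
        # input word, an inclusive (start, end) offset pair into the wordpiece list
        # (or None for words that yield no pieces).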
pieces, offsets = self._pretrained_tokenizer.intra_word_tokenize(sentence)
pieces = list(map(str, pieces))
if truncate:
pieces = pieces[:self.max_length]
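            # Truncation may cut off the closing special token, so force the last
            # piece back to END_SYMBOL.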
pieces[-1] = END_SYMBOL
return pieces, offsets
def prepare_inputs(
self,
sentence: List[str],
spans: Optional[Union[List[Span], Span]] = None,
truncate: bool = True,
label_type: str = 'string',
) -> Dict[str, Field]:
"""
Prepare inputs and auxiliary variables for span model.
:param sentence: A list of tokens. Do not pass in any special tokens, like BOS or EOS.
Necessary for both training and testing.
:param spans: Optional. For training, spans passed in will be considered as positive examples; the spans
that are automatically proposed and not in the positive set will be considered as negative examples.
Necessary for training.
:param truncate: If True, sequence will be truncated if it's longer than `self.max_training_length`
:param label_type: One of [string, list].
:return: Dict of AllenNLP fields. For detailed of explanation of every field, refer to the comments
below. For the shape of every field, check the module doc.
Fields list:
- words
- span_labels
- span_boundary
- parent_indices
- parent_mask
- bio_seqs
- raw_sentence
- raw_spans
- proposed_spans
"""
fields = dict()
pieces, offsets = self.retokenize(sentence, truncate)
fields['tokens'] = TextField(list(map(Token, pieces)), self.word_indexer)
raw_inputs = {'sentence': sentence, "pieces": pieces, 'offsets': offsets}
fields['raw_inputs'] = MetadataField(raw_inputs)
if spans is None:
return fields
vr = spans if isinstance(spans, Span) else Span.virtual_root(spans)
self.n_span_removed = vr.remove_overlapping()
raw_inputs['spans'] = vr
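        # Re-index span boundaries from word positions to wordpiece positions,
        # using the offsets computed by retokenize().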
vr = vr.re_index(offsets)
if truncate:
vr.truncate(self.max_length)
if self.ignore_label:
vr.ignore_labels()
# (start_idx, end_idx) pairs. Left and right inclusive.
# The first span is the Virtual Root node. Shape [span, 2]
span_boundary = list()
# label on span. Shape [span]
span_labels = list()
# parent idx (span indexing space). Shape [span]
span_parent_indices = list()
# True for parents. Shape [span]
parent_mask = [False] * vr.n_nodes
        # Spans flattened in BFS order; index 0 is the virtual root.
flatten_spans = list(vr.bfs())
        for span_idx, span in enumerate(flatten_spans):
if span.is_parent:
parent_mask[span_idx] = True
# 0 is the virtual root
parent_idx = flatten_spans.index(span.parent) if span.parent else 0
span_parent_indices.append(parent_idx)
span_boundary.append(span.boundary)
span_labels.append(span.label)
bio_tag_list: List[List[str]] = list()
bio_configs: List[List[BIOSmoothing]] = list()
# Shape: [#parent, #token, 3]
bio_seqs: List[np.ndarray] = list()
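        # Worked example: with 6 wordpieces and a single child span covering pieces
        # 2-3 (inclusive), the parent's tag sequence is ['O', 'O', 'B', 'I', 'O', 'O'].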
        # Build one BIO sequence for each parent span, in BFS order.
        for _, parent in filter(lambda item: item[1].is_parent, enumerate(flatten_spans)):
bio_tags = ['O'] * len(pieces)
bio_tag_list.append(bio_tags)
bio_smooth: List[BIOSmoothing] = [parent.child_smooth.clone() for _ in pieces]
bio_configs.append(bio_smooth)
for child in parent:
                assert all(bio_tags[bio_idx] == 'O' for bio_idx in range(child.start_idx, child.end_idx + 1)), \
                    'Sibling spans must not overlap.'
                if child.smooth_weight is not None:
                    for i in range(child.start_idx, child.end_idx + 1):
                        bio_smooth[i].weight = child.smooth_weight
bio_tags[child.start_idx] = 'B'
for word_idx in range(child.start_idx + 1, child.end_idx + 1):
bio_tags[word_idx] = 'I'
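            # Turn the hard BIO tags and per-token smoothing configs into a soft
            # [#token, 3] label array (see apply_bio_smoothing in ..utils).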
bio_seqs.append(apply_bio_smoothing(bio_smooth, bio_tags))
        fields['span_boundary'] = ArrayField(
            np.array(span_boundary), padding_value=0, dtype=np.int64
        )
        fields['parent_indices'] = ArrayField(np.array(span_parent_indices), 0, np.int64)
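        # 'string' labels are indexed against the 'span_label' vocabulary namespace;
        # 'list' labels are assumed to already be numeric arrays (e.g. soft distributions).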
if label_type == 'string':
fields['span_labels'] = ListField([LabelField(label, 'span_label') for label in span_labels])
elif label_type == 'list':
fields['span_labels'] = ArrayField(np.array(span_labels))
else:
raise NotImplementedError
        fields['parent_mask'] = ArrayField(np.array(parent_mask), False, np.bool_)
fields['bio_seqs'] = ArrayField(np.stack(bio_seqs))
self._sanity_check(
flatten_spans, pieces, bio_tag_list, parent_mask, span_boundary, span_labels, span_parent_indices
)
return fields
@staticmethod
def _sanity_check(
flatten_spans, words, bio_tag_list, parent_mask, span_boundary, span_labels, parent_indices, verbose=False
):
        # Debug-only consistency check: parent_mask, boundaries, labels, and parent
        # indices must agree with the BFS-flattened span tree and the BIO sequences.
assert len(parent_mask) == len(span_boundary) == len(span_labels) == len(parent_indices)
for (parent_idx, parent_span), bio_tags in zip(
filter(lambda x: x[1].is_parent, enumerate(flatten_spans)), bio_tag_list
):
assert parent_mask[parent_idx]
parent_s, parent_e = span_boundary[parent_idx]
if verbose:
print('Parent: ', span_labels[parent_idx], 'Text: ', ' '.join(words[parent_s:parent_e+1]))
print(f'It contains {len(parent_span)} children.')
for child in parent_span:
child_idx = flatten_spans.index(child)
assert parent_indices[child_idx] == flatten_spans.index(parent_span)
                if verbose:
                    child_s, child_e = span_boundary[child_idx]
                    print('  ', span_labels[child_idx], 'Text: ', ' '.join(words[child_s:child_e + 1]))
            if verbose:
                print('Children derived from BIO tags:')
                for _, (start, end) in bio_tags_to_spans(bio_tags):
                    print(' '.join(words[start:end + 1]))
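

if __name__ == '__main__':
    # Minimal usage sketch of the span-free path: without `spans`, prepare_inputs()
    # returns only the 'tokens' and 'raw_inputs' fields. The checkpoint name below
    # is just an example; SpanReader is declared ABC but defines no abstract methods,
    # so it can be instantiated directly for a quick smoke test.
    reader = SpanReader(pretrained_model='xlm-roberta-base', max_length=64)
    sentence = 'The quick brown fox jumps over the lazy dog'.split()
    fields = reader.prepare_inputs(sentence)
    print(sorted(fields))                            # ['raw_inputs', 'tokens']
    print(fields['raw_inputs'].metadata['pieces'])   # wordpieces incl. special tokens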