from time import time
from typing import List, Dict, Any, Tuple
from collections import defaultdict

from concrete import (
    Token, TokenList, TextSpan, MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence,
    Communication, EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata
)
from concrete.util import create_comm, AnalyticUUIDGeneratorFactory
from concrete.validate import validate_communication

from ..utils import Span


def _process_sentence(sent, comm_sent, aug, char_idx_offset: int):
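    """
    Convert one annotated sentence into Concrete objects: fill in the token
    list of `comm_sent`, then return (situation mentions, entity mentions,
    entities). Argument mentions are de-duplicated by token span, and each
    distinct argument mention is wrapped in its own singleton Entity.
    `char_idx_offset` is the character offset of this sentence within the
    full document text.
    """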
    token_list = list()
    for tok_idx, (start_idx, end_idx) in enumerate(sent['tokenization']):
        token_list.append(Token(
            tokenIndex=tok_idx,
            text=sent['sentence'][start_idx:end_idx + 1],
            textSpan=TextSpan(
                start=start_idx + char_idx_offset,
                ending=end_idx + char_idx_offset + 1  # TextSpan endings are exclusive
            ),
        ))
    comm_sent.tokenization.tokenList = TokenList(tokenList=token_list)

    sm_list, em_dict, entity_list = list(), dict(), list()

    annotation = sent['annotations'] if isinstance(sent['annotations'], Span) else Span.from_json(sent['annotations'])
    for event in annotation:
        char_start_idx = sent['tokenization'][event.start_idx][0]
        char_end_idx = sent['tokenization'][event.end_idx][1]
        sm = SituationMention(
            uuid=next(aug),
            text=sent['sentence'][char_start_idx: char_end_idx + 1],
            situationType='EVENT',
            situationKind=event.label,
            argumentList=list(),
            tokens=TokenRefSequence(
                tokenIndexList=list(range(event.start_idx, event.end_idx + 1)),
                tokenizationId=comm_sent.tokenization.uuid
            ),
        )

        for arg in event:
            em = em_dict.get((arg.start_idx, arg.end_idx + 1))
            if em is None:
                char_start_idx = sent['tokenization'][arg.start_idx][0]
                char_end_idx = sent['tokenization'][arg.end_idx][1]
                em = EntityMention(next(aug), TokenRefSequence(
                    tokenIndexList=list(range(arg.start_idx, arg.end_idx + 1)),
                    tokenizationId=comm_sent.tokenization.uuid,
                ), text=sent['sentence'][char_start_idx: char_end_idx + 1])
                entity_list.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid]))
                em_dict[(arg.start_idx, arg.end_idx + 1)] = em
            sm.argumentList.append(MentionArgument(
                role=arg.label,
                entityMentionId=em.uuid,
            ))

        sm_list.append(sm)

    return sm_list, list(em_dict.values()), entity_list


def concrete_doc(
        sentences: List[Dict[str, Any]],
        doc_name: str = 'document',
) -> Communication:
    """
    Data format: A list of sentences. Each sentence should be a dict of the following format:
    {
        "sentence": String.
        "tokenization": A list of Tuple[int, int] for start and end indices. Both inclusive.
        "annotations": A list of event dict, or Span object.
    }
    If it is dict, its format should be:

        Each event should be a dict of the following format:
        {
            "span": [start_idx, end_idx]: Integer. Both inclusive.
            "label": String.
            "children": A list of arguments.
        }
        Each argument should be a dict of the following format:
        {
            "span": [start_idx, end_idx]: Integer. Both inclusive.
            "label": String.
        }

    Note the "indices" above all refer to the indices of tokens, instead of characters.
    """
    comm = create_comm(
        doc_name,
        '\n'.join([sent['sentence'] for sent in sentences]),
    )
    aug = AnalyticUUIDGeneratorFactory(comm).create()
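    # One set each for situation mentions (events), entity mentions (arguments),
    # and entities; the entity set records which entity mention set it covers.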
    situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
    comm.situationMentionSetList = [situation_mention_set]
    entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
    comm.entityMentionSetList = [entity_mention_set]
    entity_set = EntitySet(
        next(aug), AnnotationMetadata('O(0) Coref Parser', time()), list(), None, entity_mention_set.uuid
    )
    comm.entitySetList = [entity_set]
    assert len(sentences) == len(comm.sectionList[0].sentenceList)

    char_idx_offset = 0
    for sent, comm_sent in zip(sentences, comm.sectionList[0].sentenceList):
        sm_list, em_list, entity_list = _process_sentence(sent, comm_sent, aug, char_idx_offset)
        entity_set.entityList.extend(entity_list)
        situation_mention_set.mentionList.extend(sm_list)
        entity_mention_set.mentionList.extend(em_list)
        char_idx_offset += len(sent['sentence']) + 1  # +1 for the '\n' joining sentences

    validate_communication(comm)
    return comm
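
# A minimal usage sketch (hypothetical example data; assumes Span.from_json
# accepts the event-dict format documented in concrete_doc):
#
#     comm = concrete_doc(
#         sentences=[{
#             "sentence": "John met Mary .",
#             "tokenization": [(0, 3), (5, 7), (9, 12), (14, 14)],
#             "annotations": [{
#                 "span": [1, 1],
#                 "label": "Meet",
#                 "children": [
#                     {"span": [0, 0], "label": "Participant"},
#                     {"span": [2, 2], "label": "Participant"},
#                 ],
#             }],
#         }],
#         doc_name="example",
#     )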


def concrete_doc_tokenized(
        sentences: List[List[str]],
        spans: List[Span],
        doc_name: str = "document",
) -> Communication:
    """
    Similar to concrete_doc, but with tokenized words and spans.
    """
    inputs = list()
    for sent, vr in zip(sentences, spans):
        cur_start = 0
        tokenization = list()
        for token in sent:
            tokenization.append((cur_start, cur_start + len(token) - 1))
            cur_start += len(token) + 1
        inputs.append({
            "sentence": " ".join(sent),
            "tokenization": tokenization,
            "annotations": vr
        })
    return concrete_doc(inputs, doc_name)
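
# Usage sketch (hypothetical): `vr` is a virtual-root Span whose children are
# event Spans, e.g. as built by Span.virtual_root in collect_concrete_srl below:
#
#     comm = concrete_doc_tokenized(
#         sentences=[["John", "met", "Mary", "."]],
#         spans=[vr],
#     )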


def collect_concrete_srl(comm: Communication) -> List[Tuple[List[str], Span]]:
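    """
    Inverse of concrete_doc_tokenized: extract one (tokens, virtual-root Span)
    pair per sentence from a Communication. The Communication must have been
    read with references added, since each mention is mapped back to its
    sentence via `mention.tokens.tokenization.sentence`.
    """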
    # Mapping from <sentence uuid> to [<ConcreteSentence>, <Associated situation mentions>]
    sentences = defaultdict(lambda: [None, list()])
    for sec in comm.sectionList:
        for sen in sec.sentenceList:
            sentences[sen.uuid.uuidString][0] = sen
    # Assume there's only ONE situation mention set
    assert len(comm.situationMentionSetList) == 1
    # Assign each situation mention to the corresponding sentence
    for men in comm.situationMentionSetList[0].mentionList:
        if men.tokens is None:  # For ACE relations, which are not token-anchored
            continue
        sentences[men.tokens.tokenization.sentence.uuid.uuidString][1].append(men)
    ret = list()
    for sen, mention_list in sentences.values():
        tokens = [t.text for t in sen.tokenization.tokenList.tokenList]
        spans = list()
        for mention in mention_list:
            mention_tokens = sorted(mention.tokens.tokenIndexList)
            event = Span(mention_tokens[0], mention_tokens[-1], mention.situationKind, True)
            for men_arg in mention.argumentList:
                arg_tokens = sorted(men_arg.entityMention.tokens.tokenIndexList)
                event.add_child(Span(arg_tokens[0], arg_tokens[-1], men_arg.role, False))
            spans.append(event)
        vr = Span.virtual_root(spans)
        ret.append((tokens, vr))
    return ret
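
# Round-trip sketch (hypothetical file name; read_communication_from_file adds
# the object references that collect_concrete_srl relies on):
#
#     from concrete.util import read_communication_from_file
#
#     comm = read_communication_from_file('document.concrete')
#     for tokens, vr in collect_concrete_srl(comm):
#         print(tokens, [(ev.start_idx, ev.end_idx, ev.label) for ev in vr])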