File size: 7,942 Bytes
d8e916a
 
 
 
ae09b33
 
 
d6202fc
 
 
 
 
3e10608
9b6af1b
 
e4dff56
 
ae09b33
 
c3eef0f
 
e4dff56
 
 
 
 
 
 
 
 
 
 
d8e9d34
e4dff56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6202fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
---
license: llama3.1
base_model:
- THUDM/LongCite-llama3.1-8b
datasets:
- THUDM/LongCite-45k
pipeline_tag: text-generation
tags:
- imatrix
- importance matrix
- gguf
- llama.cpp
---
GGUF version of LongCite. You need to add the following tokens as stop tokens: `[128000, 128007, 128009]` or `["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"]`.

By default — and it seems to be working so far — the EOS token is 128007 (`end_header_id`). This works for both citation mode and naive question-answer mode.

No chat template is provided, as the model requires Python pre-processing (before the input is sent to the LLM) and post-processing of its output.

iMatrix generated using [this dataset](https://github.com/ggerganov/llama.cpp/discussions/5263#discussioncomment-9552049)

Example code
```python

from nltk.tokenize import PunktSentenceTokenizer
import re

class LongCiteModel:
    """Prompting/parsing helpers for the LongCite citation workflow.

    The context is split into sentence chunks tagged ``<C{i}>``; the model is
    expected to answer with ``<statement>...<cite>[s-e]...</cite></statement>``
    markup, which ``postprocess`` parses back into structured citations.
    """

    @staticmethod
    def text_split_by_punctuation(original_text: str, return_dict: bool = False):
        """Split *original_text* into sentences.

        Tokenizes with NLTK's PunktSentenceTokenizer, then additionally splits
        on Chinese sentence-ending punctuation. If that still yields a single
        sentence, falls back to splitting on blank-line paragraphs.

        When *return_dict* is False, returns a list of sentence strings;
        otherwise returns a list of dicts with keys 'c_idx', 'content',
        'start_idx', 'end_idx' (character offsets into *original_text*).
        """
        # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text)  # separate period without space
        text = original_text
        custom_sent_tokenizer = PunktSentenceTokenizer()
        punctuations = r"([。;!?])"  # For Chinese support

        separated = custom_sent_tokenizer.tokenize(text)
        # The capturing group makes re.split keep each punctuation mark as its
        # own list element; the next loop glues it back onto its sentence.
        separated = sum([re.split(punctuations, s) for s in separated], [])
        # Put the punctuations back to the sentence
        for i in range(1, len(separated)):
            if re.match(punctuations, separated[i]):
                separated[i-1] += separated[i]
                separated[i] = ''

        separated = [s for s in separated if s != ""]
        if len(separated) == 1:
            # No sentence boundaries found; fall back to paragraph splitting.
            separated = original_text.split('\n\n')
        separated = [s.strip() for s in separated if s.strip() != ""]
        if not return_dict:
            return separated
        else:
            # Re-locate each (stripped) sentence in the original text so the
            # caller gets exact character offsets. `pos` advances monotonically
            # so repeated sentences resolve to successive occurrences.
            pos = 0
            res = []
            for i, sent in enumerate(separated):
                st = original_text.find(sent, pos)
                assert st != -1, sent
                ed = st + len(sent)
                res.append(
                    {
                        'c_idx': i,
                        'content': sent,
                        'start_idx': st,
                        'end_idx': ed,
                    }
                )
                pos = ed
            return res

    @staticmethod
    def get_prompt(context: str, question: str):
        """Build the chunk-annotated citation prompt.

        Splits *context* into sentences, extends each sentence's end to the
        start of the next one (so chunks tile the whole document, including
        inter-sentence whitespace), and prefixes each chunk with ``<C{i}>``.

        NOTE: the entries of ``sents`` are rewritten in place to dicts keyed
        'content', 'start', 'end', 'c_idx' — callers receive this mutated form.

        Returns ``(prompt, sents, splited_context)``.
        """
        sents = LongCiteModel.text_split_by_punctuation(context, return_dict=True)
        splited_context = ""
        for i, s in enumerate(sents):
            st, ed = s['start_idx'], s['end_idx']
            assert s['content'] == context[st:ed], s
            # Extend this chunk to the start of the next sentence (or to the
            # end of the context for the last one).
            ed = sents[i+1]['start_idx'] if i < len(sents)-1 else len(context)
            sents[i] = {
                'content': context[st:ed],
                'start': st,
                'end': ed,
                'c_idx': s['c_idx'],
            }
            splited_context += f"<C{i}>"+context[st:ed]
        prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C_{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
        return prompt, sents, splited_context
    
    @staticmethod
    def get_citations(statement, sents):
        """Extract and merge ``[s-e]`` citation spans from one statement body.

        *sents* is the chunk list produced by ``get_prompt``. Span indices are
        clamped to the valid range; a span starting right after the previous
        merged span is coalesced with it. Returns
        ``(statement_without_cite_tags, merged_citations)``, keeping at most
        the first three merged spans.
        """
        c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
        spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
        statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
        merged_citations = []
        for i, s in enumerate(spans):
            try:
                st, ed = [int(x) for x in s.split('-')]
                if st > len(sents) - 1 or ed < st:
                    # Out-of-range or inverted span: skip it silently.
                    continue
                st, ed = max(0, st), min(ed, len(sents)-1)
                assert st <= ed, str(c_texts) + '\t' + str(len(sents))
                if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
                    # This span is adjacent to the previous one: extend it.
                    merged_citations[-1].update({
                        "end_sentence_idx": ed,
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed+1]]),
                    })
                else:
                    merged_citations.append({
                        "start_sentence_idx": st,
                        "end_sentence_idx": ed,
                        "start_char_idx":  sents[st]['start'],
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[st:ed+1]]),
                    })
            except:
                # Log the offending citation text for debugging, then re-raise.
                print(c_texts, len(sents), statement)
                raise
        return statement, merged_citations[:3]
    
    @staticmethod
    def postprocess(answer, sents, splited_context):
        """Parse the raw model *answer* into statements with citations.

        Scans for <statement>...</statement> regions. Text between statements
        longer than 5 chars (stripped) is kept as an uncited statement
        (citation ``[]``); shorter filler gets citation ``None``. Tagged
        statements have their <cite> spans parsed via ``get_citations``.

        Returns a dict with keys 'answer' (re-normalized markup),
        'statements_with_citations' (entries whose citation is not None),
        'splited_context', and 'all_statements'.
        """
        res = []
        pos = 0
        new_answer = ""
        while True:
            st = answer.find("<statement>", pos)
            if st == -1:
                # No more statements; treat the remaining tail as free text.
                st = len(answer)
            ed = answer.find("</statement>", st)
            # Free text between the previous statement and this one.
            statement = answer[pos:st]
            if len(statement.strip()) > 5:
                res.append({
                    "statement": statement,
                    "citation": []
                })
                new_answer += f"<statement>{statement}<cite></cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement
            
            if ed == -1:
                break

            # The tagged statement body; citations are parsed out of it.
            statement = answer[st+len("<statement>"):ed]
            if len(statement.strip()) > 0:
                statement, citations = LongCiteModel.get_citations(statement, sents)
                res.append({
                    "statement": statement,
                    "citation": citations
                })
                c_str = ''.join(['[{}-{}]'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
                new_answer += f"<statement>{statement}<cite>{c_str}</cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement
            pos = ed + len("</statement>")
        return {
            "answer": new_answer.strip(),
            "statements_with_citations": [x for x in res if x['citation'] is not None],
            "splited_context": splited_context.strip(),
            "all_statements": res,
        }

    @staticmethod
    def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
        """Truncate *prompt* to *max_input_length* tokens by dropping the middle.

        Keeps the first and last ``max_input_length // 2`` tokens. Requires a
        *tokenizer* with HF-style ``encode``/``decode`` when a limit is given;
        returns *prompt* unchanged when *max_input_length* is None.
        """
        if max_input_length is None:
            return prompt
        else:
            assert tokenizer is not None
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokenized_prompt) > max_input_length:
                half = int(max_input_length/2)
                # Head + tail are decoded separately and concatenated.
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            return prompt
    



if __name__ == "__main__":

    # Minimal end-to-end example: build the chunk-annotated prompt, pretend
    # to run the model, then parse its output back into cited statements.
    context = '''
    your context
    '''
    query = "your user question here"

    prompt, sents, splited_context = LongCiteModel.get_prompt(context, query)
    print('Prompt:', prompt)

    # add the Llama 3 tags to the prompt
    max_input_length = 4096  # token budget, for use with truncate_from_middle
    output = "..."  # what the llm returned
    result = LongCiteModel.postprocess(output, sents, splited_context)




```