---
license: llama3.1
base_model:
- THUDM/LongCite-llama3.1-8b
datasets:
- THUDM/LongCite-45k
pipeline_tag: text-generation
tags:
- imatrix
- importance matrix
- gguf
- llama.cpp
---
This is a GGUF version of LongCite. You need to add the following stop tokens: `[128000, 128007, 128009]`, i.e. `["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"]`.

By default the EOS token is 128007 (`<|end_header_id|>`), and so far this seems to work for both the citation mode and the plain question-answer mode. A minimal sketch of passing these stops through llama-cpp-python is shown below.
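Hypothetical example, assuming the llama-cpp-python bindings and a local quantized file; the filename `LongCite-llama3.1-8b-Q4_K_M.gguf`, the context size, and the sampling settings are placeholders, not something this repo prescribes:

```python
# Sketch only: the filename, context size and sampling settings are assumptions.
from llama_cpp import Llama

llm = Llama(model_path="LongCite-llama3.1-8b-Q4_K_M.gguf", n_ctx=8192)
out = llm(
    "your Llama-3-formatted prompt here",
    max_tokens=512,
    stop=["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"],  # the stop tokens listed above
)
print(out["choices"][0]["text"])
```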
No chat template is provided, because the model requires Python pre-processing of the context (before it is sent to the LLM) and post-processing of the output; see the example code below.
The importance matrix (iMatrix) was generated using [this dataset](https://github.com/ggerganov/llama.cpp/discussions/5263#discussioncomment-9552049).
Example pre-processing and post-processing code:
```python
import re

from nltk.tokenize import PunktSentenceTokenizer


class LongCiteModel:
    @staticmethod
    def text_split_by_punctuation(original_text, return_dict=False):
        # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text)  # separate period without space
        text = original_text
        custom_sent_tokenizer = PunktSentenceTokenizer()
        punctuations = r"([。;!?])"  # For Chinese support
        separated = custom_sent_tokenizer.tokenize(text)
        separated = sum([re.split(punctuations, s) for s in separated], [])
        # Put the punctuations back to the sentence
        for i in range(1, len(separated)):
            if re.match(punctuations, separated[i]):
                separated[i - 1] += separated[i]
                separated[i] = ''
        separated = [s for s in separated if s != ""]
        if len(separated) == 1:
            separated = original_text.split('\n\n')
        separated = [s.strip() for s in separated if s.strip() != ""]
        if not return_dict:
            return separated
        else:
            pos = 0
            res = []
            for i, sent in enumerate(separated):
                st = original_text.find(sent, pos)
                assert st != -1, sent
                ed = st + len(sent)
                res.append(
                    {
                        'c_idx': i,
                        'content': sent,
                        'start_idx': st,
                        'end_idx': ed,
                    }
                )
                pos = ed
            return res

    @staticmethod
    def get_prompt(context, question):
        sents = LongCiteModel.text_split_by_punctuation(context, return_dict=True)
        splited_context = ""
        for i, s in enumerate(sents):
            st, ed = s['start_idx'], s['end_idx']
            assert s['content'] == context[st:ed], s
            # Extend each chunk to the start of the next one so the chunks cover the whole context
            ed = sents[i + 1]['start_idx'] if i < len(sents) - 1 else len(context)
            sents[i] = {
                'content': context[st:ed],
                'start': st,
                'end': ed,
                'c_idx': s['c_idx'],
            }
            # Prefix every chunk with a <C{i}> marker that the model can cite
            splited_context += f"<C{i}>" + context[st:ed]
        prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C_{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
        return prompt, sents, splited_context

    @staticmethod
    def get_citations(statement, sents):
        # Parse the "[s-e]" spans out of the <cite>...</cite> tags, then strip the tags from the statement
        c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
        spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
        statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
        merged_citations = []
        for i, s in enumerate(spans):
            try:
                st, ed = [int(x) for x in s.split('-')]
                if st > len(sents) - 1 or ed < st:
                    continue
                st, ed = max(0, st), min(ed, len(sents) - 1)
                assert st <= ed, str(c_texts) + '\t' + str(len(sents))
                if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
                    # Merge spans that directly continue the previous citation
                    merged_citations[-1].update({
                        "end_sentence_idx": ed,
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed + 1]]),
                    })
                else:
                    merged_citations.append({
                        "start_sentence_idx": st,
                        "end_sentence_idx": ed,
                        "start_char_idx": sents[st]['start'],
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[st:ed + 1]]),
                    })
            except Exception:
                print(c_texts, len(sents), statement)
                raise
        # Keep at most the first three merged citation spans
        return statement, merged_citations[:3]

    @staticmethod
    def postprocess(answer, sents, splited_context):
        res = []
        pos = 0
        new_answer = ""
        while True:
            # Text before the next <statement> tag (or the remainder of the answer)
            st = answer.find("<statement>", pos)
            if st == -1:
                st = len(answer)
            ed = answer.find("</statement>", st)
            statement = answer[pos:st]
            if len(statement.strip()) > 5:
                res.append({
                    "statement": statement,
                    "citation": []
                })
                new_answer += f"<statement>{statement}<cite></cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement

            if ed == -1:
                break

            # Text inside the <statement>...</statement> pair, with its citations parsed out
            statement = answer[st + len("<statement>"):ed]
            if len(statement.strip()) > 0:
                statement, citations = LongCiteModel.get_citations(statement, sents)
                res.append({
                    "statement": statement,
                    "citation": citations
                })
                c_str = ''.join(['[{}-{}]'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
                new_answer += f"<statement>{statement}<cite>{c_str}</cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement
            pos = ed + len("</statement>")
        return {
            "answer": new_answer.strip(),
            "statements_with_citations": [x for x in res if x['citation'] is not None],
            "splited_context": splited_context.strip(),
            "all_statements": res,
        }

    @staticmethod
    def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
        if max_input_length is None:
            return prompt
        else:
            assert tokenizer is not None
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokenized_prompt) > max_input_length:
                half = int(max_input_length / 2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            return prompt

if __name__ == "__main__":
    context = '''
your context
'''
    query = "your user question here"
    prompt, sents, splited_context = LongCiteModel.get_prompt(context, query)
    print('Prompt:', prompt)
    # Add the Llama 3 chat tags to the prompt before sending it to the model
    # (see the llama-cpp-python sketch below); optionally truncate it with
    # LongCiteModel.truncate_from_middle(prompt, max_input_length, tokenizer).
    max_input_length = 4096
    output = "..."  # what the LLM returned
    result = LongCiteModel.postprocess(output, sents, splited_context)
```
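The `__main__` block above stops where the model call would happen. One possible end-to-end wiring with llama-cpp-python is sketched below; it assumes the `LongCiteModel` class from the example above is in scope, and the GGUF filename, context size, and sampling settings are placeholders rather than recommendations.

```python
# End-to-end sketch, not an official recipe: filename and settings are assumptions.
from llama_cpp import Llama


def run_longcite(context, query, model_path="LongCite-llama3.1-8b-Q4_K_M.gguf"):
    prompt, sents, splited_context = LongCiteModel.get_prompt(context, query)
    # Wrap the prompt in the standard Llama 3 chat tags by hand,
    # since no chat template ships with this GGUF.
    chat_prompt = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    llm = Llama(model_path=model_path, n_ctx=32768, verbose=False)
    out = llm(
        chat_prompt,
        max_tokens=1024,
        stop=["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"],
    )
    output = out["choices"][0]["text"]
    return LongCiteModel.postprocess(output, sents, splited_context)


result = run_longcite("your context", "your user question here")
print(result["answer"])
print(result["statements_with_citations"])
```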