---
license: llama3.1
base_model:
  - THUDM/LongCite-llama3.1-8b
datasets:
  - THUDM/LongCite-45k
pipeline_tag: text-generation
tags:
  - imatrix
  - importance matrix
  - gguf
  - llama.cpp
---

GGUF version of LongCite. You need to add the following tokens as stop tokens: `[128000, 128007, 128009]`, i.e. `["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"]`.

By default the EOS token is 128007 (`<|end_header_id|>`), and so far this seems to work, both in citation mode and in naive question-answer mode.
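
For example, with llama-cpp-python (a minimal sketch; the GGUF filename is a placeholder, and any llama.cpp-based runtime that accepts string stop sequences works the same way):

```python
from llama_cpp import Llama

# Placeholder filename; point this at the quant you downloaded.
llm = Llama(model_path="longcite-llama3.1-8b.Q4_K_M.gguf", n_ctx=8192)
completion = llm(
    "your formatted prompt here",
    max_tokens=512,
    stop=["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"],  # the stop tokens listed above
)
print(completion["choices"][0]["text"])
```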

No chat template is provided, because the model requires Python pre-processing (before the prompt is sent to the LLM) and post-processing of the output.

The importance matrix (imatrix) was generated using this dataset.

Example code:

```python
from nltk.tokenize import PunktSentenceTokenizer
import re

class LongCiteModel:
    @staticmethod
    def text_split_by_punctuation(original_text, return_dict=False):
        """Split text into sentence-level chunks, keeping CJK punctuation attached."""
        # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text)  # separate periods without a space
        text = original_text
        custom_sent_tokenizer = PunktSentenceTokenizer()
        punctuations = r"([。;!?])"  # extra sentence-ending punctuation, for Chinese support

        separated = custom_sent_tokenizer.tokenize(text)
        separated = sum([re.split(punctuations, s) for s in separated], [])
        # Put the punctuations back to the sentence
        for i in range(1, len(separated)):
            if re.match(punctuations, separated[i]):
                separated[i-1] += separated[i]
                separated[i] = ''

        separated = [s for s in separated if s != ""]
        if len(separated) == 1:
            separated = original_text.split('\n\n')
        separated = [s.strip() for s in separated if s.strip() != ""]
        if not return_dict:
            return separated
        else:
            pos = 0
            res = []
            for i, sent in enumerate(separated):
                st = original_text.find(sent, pos)
                assert st != -1, sent
                ed = st + len(sent)
                res.append(
                    {
                        'c_idx': i,
                        'content': sent,
                        'start_idx': st,
                        'end_idx': ed,
                    }
                )
                pos = ed
            return res 

    @staticmethod
    def get_prompt(context, question):
        sents = LongCiteModel.text_split_by_punctuation(context, return_dict=True)
        splited_context = ""
        for i, s in enumerate(sents):
            st, ed = s['start_idx'], s['end_idx']
            assert s['content'] == context[st:ed], s
            # extend each chunk up to the start of the next one so no characters are dropped
            ed = sents[i+1]['start_idx'] if i < len(sents)-1 else len(context)
            sents[i] = {
                'content': context[st:ed],
                'start': st,
                'end': ed,
                'c_idx': s['c_idx'],
            }
            splited_context += f"<C{i}>"+context[st:ed]
        prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C_{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
        return prompt, sents, splited_context
    
    @staticmethod
    def get_citations(statement, sents):
        c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
        spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
        statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
        merged_citations = []
        for i, s in enumerate(spans):
            try:
                st, ed = [int(x) for x in s.split('-')]
                if st > len(sents) - 1 or ed < st:
                    continue
                st, ed = max(0, st), min(ed, len(sents)-1)
                assert st <= ed, str(c_texts) + '\t' + str(len(sents))
                # merge with the previous citation span if this one starts right after it
                if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
                    merged_citations[-1].update({
                        "end_sentence_idx": ed,
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed+1]]),
                    })
                else:
                    merged_citations.append({
                        "start_sentence_idx": st,
                        "end_sentence_idx": ed,
                        "start_char_idx":  sents[st]['start'],
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[st:ed+1]]),
                    })
            except Exception:
                print(c_texts, len(sents), statement)
                raise
        return statement, merged_citations[:3]  # keep at most the first 3 merged spans
    
    @staticmethod
    def postprocess(answer, sents, splited_context):
        res = []
        pos = 0
        new_answer = ""
        while True:
            st = answer.find("<statement>", pos)
            if st == -1:
                st = len(answer)
            ed = answer.find("</statement>", st)
            statement = answer[pos:st]  # plain text before the next <statement> tag (or the tail)
            if len(statement.strip()) > 5:
                res.append({
                    "statement": statement,
                    "citation": []
                })
                new_answer += f"<statement>{statement}<cite></cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement
            
            if ed == -1:
                break

            statement = answer[st+len("<statement>"):ed]
            if len(statement.strip()) > 0:
                statement, citations = LongCiteModel.get_citations(statement, sents)
                res.append({
                    "statement": statement,
                    "citation": citations
                })
                c_str = ''.join(['[{}-{}]'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
                new_answer += f"<statement>{statement}<cite>{c_str}</cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement
            pos = ed + len("</statement>")
        return {
            "answer": new_answer.strip(),
            "statements_with_citations": [x for x in res if x['citation'] is not None],
            "splited_context": splited_context.strip(),
            "all_statements": res,
        }

    @staticmethod
    def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
        if max_input_length is None:
            return prompt
        else:
            assert tokenizer is not None
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokenized_prompt) > max_input_length:
                half = int(max_input_length/2)
                # keep the first and last `half` tokens and drop the middle
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            return prompt
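    # Usage sketch (assumption: the Hugging Face tokenizer of the base model):
    #   from transformers import AutoTokenizer
    #   tok = AutoTokenizer.from_pretrained("THUDM/LongCite-llama3.1-8b")
    #   prompt = LongCiteModel.truncate_from_middle(prompt, 4096, tok)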
    



if __name__ == "__main__":

    context = '''
    your context
    '''
    query = "your user question here"
    prompt, sents, splited_context = LongCiteModel.get_prompt(context, query)
    print('Prompt:', prompt)
    # wrap the prompt with the Llama 3 chat tags before sending it to the model, as sketched below
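    # Assumption: the standard Llama 3.1 chat format; verify against the base
    # model's tokenizer config before relying on this.
    formatted_prompt = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )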
    max_input_length = 4096  # pass to LongCiteModel.truncate_from_middle with a tokenizer if needed
    output = "..."  # the raw completion text returned by the LLM
    result = LongCiteModel.postprocess(output, sents, splited_context)
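
    # Inspect the structured result (keys are defined in postprocess above):
    print(result['answer'])
    for item in result['statements_with_citations']:
        print(item['statement'])
        for c in item['citation']:
            print('  cites sentences {}-{}'.format(c['start_sentence_idx'], c['end_sentence_idx']))
```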