# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Merge a domain SentencePiece model (and optionally Baichuan / jieba vocabularies) into a LLaMA tokenizer, then save the merged tokenizer in SentencePiece and Hugging Face formats.
"""
import os

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from transformers import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
import sentencepiece as spm
import argparse


def is_chinese(uchar):
    """判断一个unicode是否是汉字"""
    return '\u4e00' <= uchar <= '\u9fa5'


def is_chinese_string(string):
    """判断是否全为汉字"""
    return all(is_chinese(c) for c in string)


def load_baichuan_vocab(vocab_file):
    words = set()
    with open(vocab_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                words.add(line.strip().split()[0])
    return words
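
# Note: the Baichuan vocab file is read as plain text, one entry per line; only the
# first whitespace-separated field on each line is kept as the token, and any extra
# columns (e.g. an id or score) are ignored.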


def load_jieba_vocab(jieba_vocab_file):
    # Read jieba vocab and sort by freq
    with open(jieba_vocab_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
        word_freqs = [line.strip().split() for line in lines]
        word_freqs.sort(key=lambda x: int(x[1]), reverse=True)
    return word_freqs
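
# Note: the jieba word-frequency file is expected to contain one "<word> <frequency>"
# pair per line (whitespace-separated); entries are sorted by frequency, descending.
# The lines below are purely illustrative sample content, not from any real file:
#   我们 123456
#   北京 4567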


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
    parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
    parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
    parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
    parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
    parser.add_argument('--jieba_word_size', default=20000, type=int)

    args = parser.parse_args()
    print(args)

    # load
    llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
    chinese_sp_model = spm.SentencePieceProcessor()
    chinese_sp_model.Load(args.domain_sp_model_file)

    llama_spm = sp_pb2_model.ModelProto()
    llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
    chinese_spm = sp_pb2_model.ModelProto()
    chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())

    # print number of tokens
    print('LLaMA tokenizer vocab size:', len(llama_tokenizer), 'domain sp model vocab size:', len(chinese_sp_model))
    print(llama_tokenizer.all_special_tokens)
    print(llama_tokenizer.all_special_ids)
    print(llama_tokenizer.special_tokens_map)

    # Add Chinese tokens to LLaMA tokenizer
    llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)

    print(f"Before merging, LLaMA tokenizer pieces: {len(llama_spm_tokens_set)}")
    added_set = set()
    for p in chinese_spm.pieces:
        piece = p.piece
        if piece not in llama_spm_tokens_set:
            # print('piece', piece)
            new_p = sp_pb2_model.ModelProto().SentencePiece()
            new_p.piece = piece
            new_p.score = 0
            llama_spm.pieces.append(new_p)
            added_set.add(piece)
    print(f"[add domain tokens]New model pieces: {len(llama_spm.pieces)}")

    vocab = load_baichuan_vocab(args.baichuan_vocab_file)
    print('baichuan vocab len:', len(vocab))
    baichuan_vocab_set = set([i for i in vocab if is_chinese_string(i)])
    print('baichuan chinese vocab size:', len(baichuan_vocab_set))
    print('baichuan vocab head:', list(baichuan_vocab_set)[:10])
    for piece in baichuan_vocab_set:
        if piece not in llama_spm_tokens_set and piece not in added_set:
            # print('baichuan piece', piece)
            new_p = sp_pb2_model.ModelProto().SentencePiece()
            new_p.piece = piece
            new_p.score = 0
            llama_spm.pieces.append(new_p)
            added_set.add(piece)
    print(f"[add baichuan tokens]New model pieces: {len(llama_spm.pieces)}")

    if args.add_jieba:
        word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
        top_words = word_freqs[:args.jieba_word_size]
        print('jieba top10 freq words:', top_words[:10])
        jieba_vocab_set = set([i[0] for i in top_words if i])
        print('jieba_vocab_set size:', len(jieba_vocab_set))
        print('jieba_vocab head:', list(jieba_vocab_set)[:3])
        for piece in jieba_vocab_set:
            if piece not in llama_spm_tokens_set and piece not in added_set:
                # print('jieba piece', piece)
                new_p = sp_pb2_model.ModelProto().SentencePiece()
                new_p.piece = piece
                new_p.score = 0
                llama_spm.pieces.append(new_p)
        print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")

    # Save
    output_sp_dir = 'merged_tokenizer_sp'
    output_hf_dir = 'merged_tokenizer_hf'  # the path to save Chinese-LLaMA tokenizer
    os.makedirs(output_sp_dir, exist_ok=True)
    with open(output_sp_dir + '/chinese_llama.model', 'wb') as f:
        f.write(llama_spm.SerializeToString())
    tokenizer = LlamaTokenizer(vocab_file=output_sp_dir + '/chinese_llama.model')

    tokenizer.save_pretrained(output_hf_dir)
    print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")

    # Test
    llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
    chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
    print(chinese_llama_tokenizer.all_special_tokens)
    print(chinese_llama_tokenizer.all_special_ids)
    print(chinese_llama_tokenizer.special_tokens_map)
    print('old len:', len(llama_tokenizer), ' new len:', len(chinese_llama_tokenizer))
    text = '''this is a test, hello world. thisisatesthelloworld, 
慕容复来到河边,姑苏慕容氏在外面丢了人。
1号店一周岁了,我们一古脑儿买了10斤零食。
巴塞罗那足球俱乐部简称巴萨(Barça),是一家位于西班牙加泰罗尼亚巴塞罗那的足球俱乐部,于1899年由瑞士企业家胡安·甘伯所创立,世界球坛顶级足球俱乐部之一。俱乐部主场可容纳接近十万名观众,是全欧洲最大及世界第二大的足球场。
白日依山尽,黄河入海流。欲穷千里目,更上一层楼。'''
    print("Test text:\n", text)
    print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
    print(f"Tokenized by Chinese-LLaMA tokenizer:{chinese_llama_tokenizer.tokenize(text)}")


if __name__ == '__main__':
    main()