import argparse
import json
import os

import transformers

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()
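
    # Several tokens besides [UNK] collide on id 3 in the input vocab;
    # fix_vocab reassigns each of those duplicates a unique, unused id.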
    def fix_vocab(vocab):
        # 51960 is the id of the mask token here; ids above it are assumed unused.
        mask_id = 51960
        unused = mask_id + 1
        remapped = []
        fixed_vocab = {}
        for key, value in vocab.items():
            if value == 3 and key != "[UNK]":
                if key == "ĠĊ":
                    # "ĠĊ" (byte-level BPE for space + newline) is pinned just below the mask id.
                    fixed_vocab[key] = mask_id - 1
                else:
                    # Queue the remaining duplicates for ids past the mask token.
                    remapped.append((key, unused))
                    unused += 1
            else:
                fixed_vocab[key] = value

        # Append the queued tokens so they land at the tail of the vocab.
        for key, value in remapped:
            fixed_vocab[key] = value

        return fixed_vocab

    with open(os.path.join(args.input_path, "vocab.json"), "r", encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)

    fixed_vocab = fix_vocab(vocab)

    with open(os.path.join(args.output_path, "vocab.json"), "w", encoding="utf-8") as vocab_file:
        json.dump(fixed_vocab, vocab_file, ensure_ascii=False, indent=None)
        print(file=vocab_file)  # terminate the file with a trailing newline
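
    # Reload the tokenizer from the output directory (which is expected to already
    # contain the remaining tokenizer files) and regenerate tokenizer.json so the
    # fast-tokenizer state matches the corrected vocab.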
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.output_path)
    tokenizer._tokenizer.save(os.path.join(args.output_path, "tokenizer.json"))