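"""Tokenization and dictionary utilities: subword post-processing, whitespace
tokenization, multiprocessing-friendly file chunking, and a ``Dictionary``
class mapping symbols to consecutive integer ids."""
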
import os
import re
from collections import Counter
from multiprocessing import Pool
from typing import Iterable, List

import torch


def item(tensor):
    """Convert a one-element tensor (or indexable object) to a Python scalar."""
    if torch.is_tensor(tensor) and tensor.device.type == "xla":
        # Keep this a no-op on XLA devices: return the detached tensor
        # instead of calling .item().
        return tensor.detach()
    if hasattr(tensor, "item"):
        return tensor.item()
    if hasattr(tensor, "__getitem__"):
        return tensor[0]
    return tensor


def post_process(sentence: str, symbol: str):
    """Undo the subword segmentation scheme named by *symbol* to recover a
    plain, detokenized sentence."""
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == "wordpiece":
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == "letter":
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "silence":
        sentence = sentence.replace("<SIL>", "")
        sentence = re.sub(" +", " ", sentence).strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol in {"subword_nmt", "@@ ", "@@"}:
        if symbol == "subword_nmt":
            symbol = "@@ "
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    elif symbol == "none":
        pass
    elif symbol is not None:
        raise NotImplementedError(f"Unknown post_process option: {symbol}")
    return sentence


SPACE_NORMALIZER = re.compile(r"\s+")


def tokenize_line(line):
    """Split a line into tokens on whitespace."""
    line = SPACE_NORMALIZER.sub(" ", line)
    line = line.strip()
    return line.split()


def _safe_readline(fd) -> str:
    """Read a line, backing up one byte at a time if the current offset falls
    in the middle of a multi-byte character."""
    pos = fd.tell()
    while True:
        try:
            return fd.readline()
        except UnicodeDecodeError:
            pos -= 1
            fd.seek(pos)


def find_offsets(filename: str, num_chunks: int) -> List[int]:
    """
    Given a file and a number of chunks, find byte offsets that split the
    file into chunks aligned on full lines.
    """
    with open(filename, "r", encoding="utf-8") as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_chunks
        offsets = [0 for _ in range(num_chunks + 1)]
        for i in range(1, num_chunks):
            f.seek(chunk_size * i)
            _safe_readline(f)
            offsets[i] = f.tell()
        offsets[-1] = size
        return offsets


class ChunkLineIterator:
    """
    Iterator to properly iterate over lines of a file chunk.
    """

    def __init__(self, fd, start_offset: int, end_offset: int):
        self._fd = fd
        self._start_offset = start_offset
        self._end_offset = end_offset

    def __iter__(self) -> Iterable[str]:
        self._fd.seek(self._start_offset)
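        # Read with readline()/_safe_readline() rather than iterating the file
        # object: mixing ``next(f)`` with ``f.tell()`` is not supported and
        # would break the offset checks below.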
        line = _safe_readline(self._fd)
        while line:
            pos = self._fd.tell()
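            # f.tell() does not always report the exact byte position and can
            # occasionally jump to a very large value; requiring the position
            # to be within 2**32 bytes of the end offset guards against
            # breaking the loop on such a spurious reading.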
            if (
                self._end_offset > 0
                and pos > self._end_offset
                and pos < self._end_offset + 2 ** 32
            ):
                break
            yield line
            line = self._fd.readline()


class Chunker:
    """
    Context manager to read a chunk of a file line by line.
    """

    def __init__(self, path: str, start_offset: int, end_offset: int):
        self.path = path
        self.start_offset = start_offset
        self.end_offset = end_offset

    def __enter__(self) -> ChunkLineIterator:
        self.fd = open(self.path, "r", encoding="utf-8")
        return ChunkLineIterator(self.fd, self.start_offset, self.end_offset)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.fd.close()


class Dictionary:
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def get_count(self, idx):
        return self.count[idx]

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
        include_eos=False,
        separator=" ",
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(
                    t,
                    bpe_symbol,
                    escape_unk,
                    extra_symbols_to_ignore,
                    include_eos=include_eos,
                )
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        if not include_eos:
            extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = separator.join(
            token_string(i)
            for i in tensor
            if item(i) not in extra_symbols_to_ignore
        )

        return post_process(sent, bpe_symbol)

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary"""
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
              including special symbols
            - padding_factor can be used to pad the dictionary size to be a
              multiple of 8, which is important on some hardware (e.g., Nvidia
              Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        c = Counter(
            dict(
                sorted(zip(self.symbols[self.nspecial:], self.count[self.nspecial:]))
            )
        )
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                break

        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

        self.pad_to_multiple_(padding_factor)

    def pad_to_multiple_(self, padding_factor):
        """Pad Dictionary size to be a multiple of *padding_factor*."""
        if padding_factor > 1:
            i = 0
            while len(self) % padding_factor != 0:
                symbol = "madeupword{:04d}".format(i)
                self.add_symbol(symbol, n=0)
                i += 1

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                )
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)

        for line in lines[indices_start_line:]:
            try:
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
                )

    def _save(self, f, kv_iterator):
        if isinstance(f, str):
            # Only create parent directories when the path actually has one.
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, "w", encoding="utf-8") as fd:
                return self.save(fd)
        for k, v in kv_iterator:
            print("{} {}".format(k, v), file=f)

    def _get_meta(self):
        return [], []

    def _load_meta(self, lines):
        return 0

    def save(self, f):
        """Stores dictionary into a text file"""
        ex_keys, ex_vals = self._get_meta()
        self._save(
            f,
            zip(
                ex_keys + self.symbols[self.nspecial:],
                ex_vals + self.count[self.nspecial:],
            ),
        )

    def dummy_sentence(self, length):
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t

    def encode_line(
        self,
        line,
        line_tokenizer=tokenize_line,
        add_if_not_exist=True,
        consumer=None,
        append_eos=True,
        reverse_order=False,
    ) -> torch.IntTensor:
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids

    @staticmethod
    def _add_file_to_dictionary_single_worker(
        filename,
        tokenize,
        eos_word,
        start_offset,
        end_offset,
    ):
        counter = Counter()
        with Chunker(filename, start_offset, end_offset) as line_iterator:
            for line in line_iterator:
                for word in tokenize(line):
                    counter.update([word])
                counter.update([eos_word])
        return counter

    @staticmethod
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        def merge_result(counter):
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        local_file = filename
        offsets = find_offsets(local_file, num_workers)
        if num_workers > 1:
            chunks = zip(offsets, offsets[1:])
            pool = Pool(processes=num_workers)
            results = []
            for (start_offset, end_offset) in chunks:
                results.append(
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (
                            local_file,
                            tokenize,
                            dict.eos_word,
                            start_offset,
                            end_offset,
                        ),
                    )
                )
            pool.close()
            pool.join()
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
                )
            )
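

if __name__ == "__main__":
    # Minimal illustrative sketch (assumed usage, not part of the library API):
    # build a small in-memory dictionary, encode a sentence, and convert the
    # resulting ids back to a string.
    d = Dictionary()
    for tok in ["hello", "world"]:
        d.add_symbol(tok)
    ids = d.encode_line("hello world", add_if_not_exist=False)
    print(ids.tolist())  # e.g. [4, 5, 2] -- the trailing id is </s>
    print(d.string(ids))  # -> "hello world"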