ljy266987
add lfs
12bfd03
raw
history blame
5.37 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
import io
import json
import struct
import typing as tp
# Format: the 4-byte `ECDC` magic code, followed by a uint8 giving the
# protocol version (currently 0), followed by the size in bytes of the
# JSON metadata as a uint32. These three fields are packed in network
# byte order (`!4sBI`).
# The metadata is then provided as UTF-8 encoded JSON and should contain all
# the information required for decoding. A raw stream of bytes follows
# and should be interpretable using the JSON metadata.
_encodec_header_struct = struct.Struct('!4sBI')
_ENCODEC_MAGIC = b'ECDC'
def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    """Write the ECDC prefix followed by `metadata` encoded as JSON.

    The prefix carries the magic code, the protocol version (0) and the
    length in bytes of the encoded metadata. The file-object is flushed
    once both parts have been written.

    Args:
        fo (IO[bytes]): binary file-object receiving the header.
        metadata (Any): object serializable with `json.dumps`.
    """
    protocol_version = 0
    payload = json.dumps(metadata).encode('utf-8')
    prefix = _encodec_header_struct.pack(
        _ENCODEC_MAGIC, protocol_version, len(payload))
    for part in (prefix, payload):
        fo.write(part)
    fo.flush()
def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
    """Read exactly `size` bytes from `fo`, looping over short reads.

    A single `read` call may return fewer bytes than requested, so reads
    are repeated until the requested amount has been collected.

    Args:
        fo (IO[bytes]): binary file-object to read from.
        size (int): number of bytes to read.

    Returns:
        bytes: the `size` bytes read from `fo`.

    Raises:
        EOFError: if the stream ends before `size` bytes could be read.
    """
    chunks: tp.List[bytes] = []
    # Track the outstanding byte count separately from what was collected:
    # comparing the collected length against a shrinking target would stop
    # the loop early after a short read.
    remaining = size
    while remaining > 0:
        chunk = fo.read(remaining)
        if not chunk:
            raise EOFError("Impossible to read enough data from the stream, "
                           f"{remaining} bytes remaining.")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)
def read_ecdc_header(fo: tp.IO[bytes]):
    """Read the ECDC prefix from `fo` and return the decoded JSON metadata.

    Args:
        fo (IO[bytes]): binary file-object positioned at the start of the
            ECDC header.

    Returns:
        The object obtained by decoding the JSON metadata.

    Raises:
        ValueError: if the magic code is not `ECDC`, or if the protocol
            version is not 0.
        EOFError: raised by `_read_exactly` if the stream ends too early.
    """
    prefix = _read_exactly(fo, _encodec_header_struct.size)
    found_magic, found_version, payload_len = _encodec_header_struct.unpack(prefix)
    if found_magic != _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if found_version != 0:
        raise ValueError("Version not supported.")
    payload_text = _read_exactly(fo, payload_len).decode('utf-8')
    return json.loads(payload_text)
class BitPacker:
    """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.

    Each pushed value is stored above the bits that are still pending, and
    complete bytes are emitted starting from the lowest 8 bits.

    Note that for some bandwidth (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.

    Raises:
        ValueError: if `bits` is not strictly positive.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        if bits < 1:
            raise ValueError(f"bits must be a positive integer, got {bits}.")
        self._current_value = 0
        self._current_bits = 0
        self.bits = bits
        self.fo = fo

    def push(self, value: int):
        """Push a new value to the stream. This will immediately
        write as many uint8 as possible to the underlying file-object.

        Args:
            value (int): value to encode, must satisfy `0 <= value < 2**bits`.

        Raises:
            ValueError: if `value` does not fit in `bits` bits.
        """
        # A value wider than `bits` would spill into the bits reserved for
        # the following values and corrupt the rest of the stream.
        if not 0 <= value < (1 << self.bits):
            raise ValueError(
                f"Value {value} cannot be represented with {self.bits} bits.")
        self._current_value += (value << self._current_bits)
        self._current_bits += self.bits
        while self._current_bits >= 8:
            lower_8bits = self._current_value & 0xff
            self._current_bits -= 8
            self._current_value >>= 8
            self.fo.write(bytes([lower_8bits]))

    def flush(self):
        """Flushes the remaining partial uint8, call this at the end
        of the stream to encode. The underlying file-object is flushed too."""
        if self._current_bits:
            # Fewer than 8 bits are pending here; the unused high bits of
            # this last byte are zero.
            self.fo.write(bytes([self._current_value]))
            self._current_value = 0
            self._current_bits = 0
        self.fo.flush()
class BitUnpacker:
    """BitUnpacker does the opposite of `BitPacker`.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to read the bytes from.

    Raises:
        ValueError: if `bits` is not strictly positive.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        # With a width of 0, `pull` would never need to read from `fo` and
        # would return 0 forever, so the end of the stream is never reported.
        if bits < 1:
            raise ValueError(f"bits must be a positive integer, got {bits}.")
        self.bits = bits
        self.fo = fo
        self._mask = (1 << bits) - 1
        self._current_value = 0
        self._current_bits = 0

    def pull(self) -> tp.Optional[int]:
        """
        Pull a single value from the stream, potentially reading some
        extra bytes from the underlying file-object.
        Returns `None` when reaching the end of the stream.
        """
        # Read one byte at a time until enough bits are buffered for a value.
        while self._current_bits < self.bits:
            buf = self.fo.read(1)
            if not buf:
                return None
            character = buf[0]
            # New bytes are stacked above the bits that are still pending.
            self._current_value += character << self._current_bits
            self._current_bits += 8

        # The value sits in the lowest `bits` bits of the buffer.
        out = self._current_value & self._mask
        self._current_value >>= self.bits
        self._current_bits -= self.bits
        return out
def test():
    """Round-trip random token streams through `BitPacker` and `BitUnpacker`.

    For 4 random (length, bits) configurations, tokens are packed into an
    in-memory buffer, unpacked again, and compared with the originals.
    """
    import torch
    torch.manual_seed(1234)
    for _ in range(4):
        length: int = torch.randint(10, 2_000, (1, )).item()
        bits: int = torch.randint(1, 16, (1, )).item()
        tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist()

        stream = io.BytesIO()
        packer = BitPacker(bits, stream)
        for token in tokens:
            packer.push(token)
        packer.flush()

        stream.seek(0)
        unpacker = BitUnpacker(bits, stream)
        # `pull` returns None at the end of the stream, which serves as the
        # sentinel stopping the iteration.
        rebuilt: tp.List[int] = list(iter(unpacker.pull, None))

        n_rebuilt, n_tokens = len(rebuilt), len(tokens)
        assert n_rebuilt >= n_tokens, (n_rebuilt, n_tokens)
        # The flushing mechanism might lead to "ghost" values at the end of the stream.
        assert n_rebuilt <= n_tokens + 8 // bits, (n_rebuilt, n_tokens, bits)
        for idx, (expected, actual) in enumerate(zip(tokens, rebuilt)):
            assert expected == actual, (idx, expected, actual)
# Run the pack/unpack round-trip self-test when executed as a script.
if __name__ == '__main__':
    test()