diff --git a/src/atmst/blockstore/car_file.py b/src/atmst/blockstore/car_file.py index 6f33bc9..7b36d35 100644 --- a/src/atmst/blockstore/car_file.py +++ b/src/atmst/blockstore/car_file.py @@ -6,20 +6,24 @@ from . import BlockStore # should be equivalent to multiformats.varint.decode(), but not extremely slow for no reason. -def parse_varint(stream: BinaryIO): +def decode_varint(stream: BinaryIO): n = 0 - shift = 0 - while True: + for shift in range(0, 63, 7): val = stream.read(1) if not val: - raise ValueError("eof") # match varint.decode() + raise ValueError("unexpected end of varint input") val = val[0] n |= (val & 0x7f) << shift if not val & 0x80: + if shift and not val: + raise ValueError("varint not minimally encoded") return n shift += 7 + raise ValueError("varint too long") def encode_varint(n: int) -> bytes: + if not 0 <= n < 2**63: + raise ValueError("integer out of encodable varint range") res = [] while n > 0x7f: res.append(0x80 | (n & 0x7f)) @@ -47,7 +51,7 @@ def __init__(self, file: BinaryIO, validate_hashes: bool=True) -> None: file.seek(0) # parse out CAR header - header_len = parse_varint(file) + header_len = decode_varint(file) header = file.read(header_len) if len(header) != header_len: raise EOFError("not enough CAR header bytes") @@ -62,7 +66,7 @@ def __init__(self, file: BinaryIO, validate_hashes: bool=True) -> None: self.block_offsets = {} while True: try: - length = parse_varint(file) + length = decode_varint(file) except ValueError: break # EOF start = file.tell() diff --git a/tests/test_varint.py b/tests/test_varint.py new file mode 100644 index 0000000..effe9d2 --- /dev/null +++ b/tests/test_varint.py @@ -0,0 +1,28 @@ +import unittest +import io + +from atmst.blockstore.car_file import decode_varint, encode_varint + +class MSTDiffTestCase(unittest.TestCase): + def test_varint_encode(self): + self.assertEqual(encode_varint(0), b"\x00") + self.assertEqual(encode_varint(1), b"\x01") + self.assertEqual(encode_varint(127), b"\x7f") + self.assertEqual(encode_varint(128), b"\x80\x01") + self.assertEqual(encode_varint(2**63-1), b'\xff\xff\xff\xff\xff\xff\xff\xff\x7f') + self.assertRaises(ValueError, encode_varint, 2**63) + self.assertRaises(ValueError, encode_varint, -1) + + def test_varint_decode(self): + self.assertEqual(decode_varint(io.BytesIO(b"\x00")), 0) + self.assertEqual(decode_varint(io.BytesIO(b"\x01")), 1) + self.assertEqual(decode_varint(io.BytesIO(b"\x7f")), 127) + self.assertEqual(decode_varint(io.BytesIO(b"\x80\x01")), 128) + self.assertEqual(decode_varint(io.BytesIO(b'\xff\xff\xff\xff\xff\xff\xff\xff\x7f')), 2**63-1) + self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7f')) # too big + self.assertRaises(ValueError, decode_varint, io.BytesIO(b"")) # too short + self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff')) # truncated + self.assertRaises(ValueError, decode_varint, io.BytesIO(b"\x80\x00")) # too minimally encoded + +if __name__ == '__main__': + unittest.main(module="tests.test_varint")