Skip to content

Commit

Permalink
make varint encode/decode robust, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Feb 25, 2024
1 parent d7d7b8e commit ee7a5e3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/atmst/blockstore/car_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@
from . import BlockStore

# should be equivalent to multiformats.varint.decode(), but not extremely slow for no reason.
def parse_varint(stream: BinaryIO):
def decode_varint(stream: BinaryIO):
n = 0
shift = 0
while True:
for shift in range(0, 63, 7):
val = stream.read(1)
if not val:
raise ValueError("eof") # match varint.decode()
raise ValueError("unexpected end of varint input")
val = val[0]
n |= (val & 0x7f) << shift
if not val & 0x80:
if shift and not val:
raise ValueError("varint not minimally encoded")
return n
shift += 7
raise ValueError("varint too long")

def encode_varint(n: int) -> bytes:
if not 0 <= n < 2**63:
raise ValueError("integer out of encodable varint range")
res = []
while n > 0x7f:
res.append(0x80 | (n & 0x7f))
Expand Down Expand Up @@ -47,7 +51,7 @@ def __init__(self, file: BinaryIO, validate_hashes: bool=True) -> None:
file.seek(0)

# parse out CAR header
header_len = parse_varint(file)
header_len = decode_varint(file)
header = file.read(header_len)
if len(header) != header_len:
raise EOFError("not enough CAR header bytes")
Expand All @@ -62,7 +66,7 @@ def __init__(self, file: BinaryIO, validate_hashes: bool=True) -> None:
self.block_offsets = {}
while True:
try:
length = parse_varint(file)
length = decode_varint(file)
except ValueError:
break # EOF
start = file.tell()
Expand Down
28 changes: 28 additions & 0 deletions tests/test_varint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
import io

from atmst.blockstore.car_file import decode_varint, encode_varint

class MSTDiffTestCase(unittest.TestCase):
def test_varint_encode(self):
self.assertEqual(encode_varint(0), b"\x00")
self.assertEqual(encode_varint(1), b"\x01")
self.assertEqual(encode_varint(127), b"\x7f")
self.assertEqual(encode_varint(128), b"\x80\x01")
self.assertEqual(encode_varint(2**63-1), b'\xff\xff\xff\xff\xff\xff\xff\xff\x7f')
self.assertRaises(ValueError, encode_varint, 2**63)
self.assertRaises(ValueError, encode_varint, -1)

def test_varint_decode(self):
self.assertEqual(decode_varint(io.BytesIO(b"\x00")), 0)
self.assertEqual(decode_varint(io.BytesIO(b"\x01")), 1)
self.assertEqual(decode_varint(io.BytesIO(b"\x7f")), 127)
self.assertEqual(decode_varint(io.BytesIO(b"\x80\x01")), 128)
self.assertEqual(decode_varint(io.BytesIO(b'\xff\xff\xff\xff\xff\xff\xff\xff\x7f')), 2**63-1)
self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7f')) # too big
self.assertRaises(ValueError, decode_varint, io.BytesIO(b"")) # too short
self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff')) # truncated
self.assertRaises(ValueError, decode_varint, io.BytesIO(b"\x80\x00")) # too minimally encoded

if __name__ == '__main__':
unittest.main(module="tests.test_varint")

0 comments on commit ee7a5e3

Please sign in to comment.