diff --git a/README.md b/README.md index 84f0f78..5545d84 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ Example: >>> bdecode(bencode([1, 2, b'a', {b'd': 3}])) [1, 2, b'a', {b'd': 3}] +The default ``bencode``/``bdecode`` functions just operate on +bytestrings. Use ``bencode_utf8`` / ``bdecode_utf8`` to +serialize/deserialize all plain strings as UTF-8 bytestrings. +Note that for performance reasons, all dictionary keys still have to be +bytestrings. + License ======= fastbencode is available under the GNU GPL, version 2 or later. diff --git a/fastbencode/_bencode_py.py b/fastbencode/_bencode_py.py index 279d545..6d33f2e 100644 --- a/fastbencode/_bencode_py.py +++ b/fastbencode/_bencode_py.py @@ -115,55 +115,72 @@ def __init__(self, s) -> None: self.bencoded = s -def encode_bencached(x, r): - r.append(x.bencoded) +class BEncoder: + def __init__(self, bytestring_encoding=None): + self.bytestring_encoding = bytestring_encoding + self.encode_func: Dict[Type, Callable[[object, List[bytes]], None]] = { + Bencached: self.encode_bencached, + int: self.encode_int, + bytes: self.encode_bytes, + list: self.encode_list, + tuple: self.encode_list, + dict: self.encode_dict, + bool: self.encode_bool, + str: self.encode_str, + } + + def encode_bencached(self, x, r): + r.append(x.bencoded) -def encode_bool(x, r): - encode_int(int(x), r) + def encode_bool(self, x, r): + self.encode_int(int(x), r) -def encode_int(x, r): - r.extend((b'i', int_to_bytes(x), b'e')) + def encode_int(self, x, r): + r.extend((b'i', int_to_bytes(x), b'e')) -def encode_bytes(x, r): - r.extend((int_to_bytes(len(x)), b':', x)) + def encode_bytes(self, x, r): + r.extend((int_to_bytes(len(x)), b':', x)) -def encode_list(x, r): - r.append(b'l') - for i in x: - encode_func[type(i)](i, r) - r.append(b'e') + def encode_list(self, x, r): + r.append(b'l') + for i in x: + self.encode(i, r) + r.append(b'e') -def encode_dict(x, r): - r.append(b'd') - ilist = sorted(x.items()) - for k, v in ilist: - r.extend((int_to_bytes(len(k)), b':', k)) - encode_func[type(v)](v, r) - r.append(b'e') + def encode_dict(self, x, r): + r.append(b'd') + ilist = sorted(x.items()) + for k, v in ilist: + r.extend((int_to_bytes(len(k)), b':', k)) + self.encode(v, r) + r.append(b'e') + def encode_str(self, x, r): + if self.bytestring_encoding is None: + raise TypeError("string found but no encoding specified. " + "Use bencode_utf8 rather bencode?") + return self.encode_bytes(x.encode(self.bytestring_encoding), r) -encode_func: Dict[Type, Callable[[object, List[bytes]], None]] = {} -encode_func[type(Bencached(0))] = encode_bencached -encode_func[int] = encode_int + def encode(self, x, r): + self.encode_func[type(x)](x, r) def int_to_bytes(n): return b'%d' % n - -encode_func[bytes] = encode_bytes -encode_func[list] = encode_list -encode_func[tuple] = encode_list -encode_func[dict] = encode_dict -encode_func[bool] = encode_bool - - def bencode(x): r = [] - encode_func[type(x)](x, r) + encoder = BEncoder() + encoder.encode(x, r) + return b''.join(r) + +def bencode_utf8(x): + r = [] + encoder = BEncoder(bytestring_encoding='utf-8') + encoder.encode(x, r) return b''.join(r) diff --git a/fastbencode/_bencode_pyx.pyx b/fastbencode/_bencode_pyx.pyx index 025bca4..32a9ef3 100644 --- a/fastbencode/_bencode_pyx.pyx +++ b/fastbencode/_bencode_pyx.pyx @@ -49,6 +49,7 @@ from cpython.mem cimport ( from cpython.unicode cimport ( PyUnicode_FromEncodedObject, PyUnicode_FromStringAndSize, + PyUnicode_Check, ) from cpython.tuple cimport ( PyTuple_CheckExact, @@ -282,8 +283,9 @@ cdef class Encoder: cdef readonly int size cdef readonly char *buffer cdef readonly int maxsize + cdef readonly object _bytestring_encoding - def __init__(self, int maxsize=INITSIZE): + def __init__(self, int maxsize=INITSIZE, str bytestring_encoding=None): """Initialize encoder engine @param maxsize: initial size of internal char buffer """ @@ -301,6 +303,8 @@ cdef class Encoder: self.maxsize = maxsize self.tail = p + self._bytestring_encoding = bytestring_encoding + def __dealloc__(self): PyMem_Free(self.buffer) self.buffer = NULL @@ -369,6 +373,12 @@ cdef class Encoder: E_UPDATE_TAIL(self, n + x_len) return 1 + cdef int _encode_string(self, x) except 0: + if self._bytestring_encoding is None: + raise TypeError("string found but no encoding specified. " + "Use bencode_utf8 rather bencode?") + return self._encode_bytes(x.encode(self._bytestring_encoding)) + cdef int _encode_list(self, x) except 0: self._ensure_buffer(1) self.tail[0] = c'l' @@ -413,6 +423,8 @@ cdef class Encoder: self._encode_dict(x) elif PyBool_Check(x): self._encode_int(int(x)) + elif PyUnicode_Check(x): + self._encode_string(x) elif isinstance(x, Bencached): self._append_string(x.bencoded) else: @@ -422,7 +434,17 @@ cdef class Encoder: def bencode(x): - """Encode Python object x to string""" + """Encode Python object x to bytestring""" encoder = Encoder() encoder.process(x) return encoder.to_bytes() + + +def bencode_utf8(x): + """Encode Python object x to bytestring. + + Encode any strings as UTF8 + """ + encoder = Encoder(bytestring_encoding='utf-8') + encoder.process(x) + return encoder.to_bytes() diff --git a/fastbencode/tests/test_bencode.py b/fastbencode/tests/test_bencode.py index 7073f42..ce897eb 100644 --- a/fastbencode/tests/test_bencode.py +++ b/fastbencode/tests/test_bencode.py @@ -454,3 +454,32 @@ def test_invalid_dict(self): def test_bool(self): self._check(b'i1e', True) self._check(b'i0e', False) + + +class TestBencodeEncodeUtf8(TestCase): + + module = None + + def _check(self, expected, source): + self.assertEqual(expected, self.module.bencode_utf8(source)) + + def test_string(self): + self._check(b'0:', '') + self._check(b'3:abc', 'abc') + self._check(b'10:1234567890', '1234567890') + + def test_list(self): + self._check(b'le', []) + self._check(b'li1ei2ei3ee', [1, 2, 3]) + self._check(b'll5:Alice3:Bobeli2ei3eee', [['Alice', 'Bob'], [2, 3]]) + + def test_list_as_tuple(self): + self._check(b'le', ()) + self._check(b'li1ei2ei3ee', (1, 2, 3)) + self._check(b'll5:Alice3:Bobeli2ei3eee', (('Alice', 'Bob'), (2, 3))) + + def test_dict(self): + self._check(b'de', {}) + self._check(b'd3:agei25e4:eyes4:bluee', {b'age': 25, b'eyes': 'blue'}) + self._check(b'd8:spam.mp3d6:author5:Alice6:lengthi100000eee', + {b'spam.mp3': {b'author': b'Alice', b'length': 100000}})