diff --git a/pulsar/schema/definition.py b/pulsar/schema/definition.py
index a810a93..d2796d2 100644
--- a/pulsar/schema/definition.py
+++ b/pulsar/schema/definition.py
@@ -228,7 +228,7 @@ def validate_type(self, name, val):
         if val is None and not self._required:
             return self.default()
 
-        if type(val) != self.python_type():
+        if not isinstance(val, self.python_type()):
             raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
         return val
 
@@ -309,7 +309,7 @@ def type(self):
         return 'float'
 
     def python_type(self):
-        return float
+        return float, int
 
     def default(self):
         if self._default is not None:
@@ -323,7 +323,7 @@ def type(self):
         return 'double'
 
     def python_type(self):
-        return float
+        return float, int
 
     def default(self):
         if self._default is not None:
@@ -337,7 +337,7 @@ def type(self):
         return 'bytes'
 
     def python_type(self):
-        return bytes
+        return bytes, str
 
     def default(self):
         if self._default is not None:
@@ -345,13 +345,21 @@ def default(self):
         else:
             return None
 
+    def validate_type(self, name, val):
+        # Run the base checks first (optional None -> default, TypeError
+        # for unrelated types), then promote str to bytes.
+        val = super(Bytes, self).validate_type(name, val)
+        if isinstance(val, str):
+            return val.encode()
+        return val
+
 
 class String(Field):
     def type(self):
         return 'string'
 
     def python_type(self):
-        return str
+        return str, bytes
 
     def validate_type(self, name, val):
         t = type(val)
@@ -359,8 +367,10 @@ def validate_type(self, name, val):
         if val is None and not self._required:
             return self.default()
 
-        if not (t is str or t.__name__ == 'unicode'):
+        if not (isinstance(val, (str, bytes)) or t.__name__ == 'unicode'):
             raise TypeError("Invalid type '%s' for field '%s'. Expected a string" % (t, name))
+        if isinstance(val, bytes):
+            return val.decode()
         return val
 
     def default(self):
@@ -406,7 +416,7 @@ def validate_type(self, name, val):
             else:
                 raise TypeError(
                     "Invalid enum value '%s' for field '%s'. Expected: %s" % (val, name, self.values.keys()))
-        elif type(val) != self.python_type():
+        elif not isinstance(val, self.python_type()):
             raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
         else:
             return val
@@ -450,7 +460,7 @@ def validate_type(self, name, val):
         super(Array, self).validate_type(name, val)
 
         for x in val:
-            if type(x) != self.array_type.python_type():
+            if not isinstance(x, self.array_type.python_type()):
                 raise TypeError('Array field ' + name + ' items should all be of type ' +
                                 _string_representation(self.array_type.type()))
         return val
@@ -493,7 +503,7 @@ def validate_type(self, name, val):
         for k, v in val.items():
             if type(k) != str and not is_unicode(k):
                 raise TypeError('Map keys for field ' + name + ' should all be strings')
-            if type(v) != self.value_type.python_type():
+            if not isinstance(v, self.value_type.python_type()):
                 raise TypeError('Map values for field ' + name + ' should all be of type '
                                 + _string_representation(self.value_type.python_type()))
 
diff --git a/pulsar/schema/schema.py b/pulsar/schema/schema.py
index b50a1fe..6ca73b1 100644
--- a/pulsar/schema/schema.py
+++ b/pulsar/schema/schema.py
@@ -101,6 +101,8 @@ def __init__(self, record_cls):
     def _get_serialized_value(self, o):
         if isinstance(o, enum.Enum):
             return o.value
+        elif isinstance(o, bytes):
+            return o.decode()
         else:
             data = o.__dict__.copy()
             remove_reserved_key(data)
diff --git a/tests/schema_test.py b/tests/schema_test.py
index 3e6e9c6..f86ad14 100755
--- a/tests/schema_test.py
+++ b/tests/schema_test.py
@@ -35,6 +35,12 @@
                     format='%(asctime)s %(levelname)-5s %(message)s')
 
 
+class ExampleRecord(Record):
+    str_field = String()
+    int_field = Integer()
+    float_field = Float()
+    bytes_field = Bytes()
+
 class SchemaTest(TestCase):
 
     serviceUrl = 'pulsar://localhost:6650'
@@ -87,6 +93,35 @@ class Example(Record):
             ]
         })
 
+    def test_type_promotion(self):
+        test_cases = [
+            (20, int, 20),  # No promotion necessary: int => int
+            (20, float, 20.0),  # Promotion: int => float
+            (20.0, float, 20.0),  # No Promotion necessary: float => float
+            ("Test text1", bytes, b"Test text1"),  # Promotion: str => bytes
+            (b"Test text1", str, "Test text1"),  # Promotion: bytes => str
+        ]
+
+        for value_from, type_to, value_to in test_cases:
+            if type_to == int:
+                fieldType = Integer()
+            elif type_to == float:
+                fieldType = Double()
+            elif type_to == str:
+                fieldType = String()
+            elif type_to == bytes:
+                fieldType = Bytes()
+            else:
+                fieldType = String()
+
+            field_value = fieldType.validate_type("test_field", value_from)
+            self.assertEqual(value_to, field_value)
+
+        # Promotion must not bypass the base checks on Bytes fields.
+        self.assertRaises(TypeError, Bytes().validate_type, "test_field", 20)
+        self.assertIsNone(Bytes().validate_type("test_field", None))
+
+
     def test_complex(self):
         class Color(Enum):
             red = 1
@@ -229,7 +264,7 @@ class E3(Record):
             a = Float()
 
         E3(a=1.0)  # Ok
-        self._expectTypeError(lambda: E3(a=1))
+        E3(a=1)  # Ok Type promotion: int -> float
 
         class E4(Record):
             a = Null()
@@ -259,7 +294,7 @@ class E7(Record):
             a = Double()
 
         E7(a=1.0)  # Ok
-        self._expectTypeError(lambda: E3(a=1))
+        E7(a=1)  # Ok Type promotion: int -> double
 
         class Color(Enum):
             red = 1
@@ -1346,5 +1381,37 @@ def verify_messages(msgs: List[pulsar.Message]):
         client.close()
 
+    def test_schema_type_promotion(self):
+        client = pulsar.Client(self.serviceUrl)
+
+        schemas = [("avro", AvroSchema(ExampleRecord)), ("json", JsonSchema(ExampleRecord))]
+
+        for schema_name, schema in schemas:
+            topic = f'test_schema_type_promotion_{schema_name}'
+
+            consumer = client.subscribe(
+                topic=topic,
+                subscription_name=f'my-sub-{schema_name}',
+                schema=schema
+            )
+            producer = client.create_producer(
+                topic=topic,
+                schema=schema
+            )
+            sendValue = ExampleRecord(str_field=b'test', int_field=1, float_field=3, bytes_field='str')
+
+            producer.send(sendValue)
+
+            msg = consumer.receive()
+            msg_value = msg.value()
+            self.assertEqual(msg_value.str_field, sendValue.str_field)
+            self.assertEqual(msg_value.int_field, sendValue.int_field)
+            self.assertEqual(msg_value.float_field, sendValue.float_field)
+            self.assertEqual(msg_value.bytes_field, sendValue.bytes_field)
+            consumer.acknowledge(msg)
+
+        client.close()
+
+
 
 if __name__ == '__main__':
     main()