Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Support schema field type promotion #159

Merged
merged 3 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions pulsar/schema/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def validate_type(self, name, val):
if val is None and not self._required:
return self.default()

if type(val) != self.python_type():
if not isinstance(val, self.python_type()):
raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
return val

Expand Down Expand Up @@ -309,7 +309,7 @@ def type(self):
return 'float'

def python_type(self):
return float
return float, int

def default(self):
if self._default is not None:
Expand All @@ -323,7 +323,7 @@ def type(self):
return 'double'

def python_type(self):
return float
return float, int

def default(self):
if self._default is not None:
Expand All @@ -337,30 +337,37 @@ def type(self):
return 'bytes'

def python_type(self):
return bytes
return bytes, str

def default(self):
if self._default is not None:
return self._default
else:
return None

def validate_type(self, name, val):
    """Validate a value assigned to a bytes field and normalise it to ``bytes``.

    Args:
        name: Field name, used only in the error message.
        val: Candidate value; ``bytes`` is accepted as-is and ``str`` is
            promoted to ``bytes``.

    Returns:
        The field default when ``val`` is ``None`` and the field is not
        required; otherwise the value as ``bytes``.

    Raises:
        TypeError: If ``val`` is neither ``bytes`` nor ``str`` (including
            ``None`` for a required field).
    """
    # Same optional-field handling as the other validate_type
    # implementations in this file: a missing optional value resolves
    # to the field default.
    if val is None and not self._required:
        return self.default()

    if isinstance(val, str):
        # Type promotion: str => bytes, using str.encode()'s default (UTF-8).
        return val.encode()

    # Without this check any object (e.g. an int) would be accepted
    # silently and only fail later during serialization.
    if not isinstance(val, bytes):
        raise TypeError("Invalid type '%s' for field '%s'. Expected bytes or a string" % (type(val), name))
    return val


class String(Field):
def type(self):
return 'string'

def python_type(self):
return str
return str, bytes

def validate_type(self, name, val):
t = type(val)

if val is None and not self._required:
return self.default()

if not (t is str or t.__name__ == 'unicode'):
if not (isinstance(val, (str, bytes)) or t.__name__ == 'unicode'):
raise TypeError("Invalid type '%s' for field '%s'. Expected a string" % (t, name))
if isinstance(val, bytes):
return val.decode()
return val

def default(self):
Expand Down Expand Up @@ -406,7 +413,7 @@ def validate_type(self, name, val):
else:
raise TypeError(
"Invalid enum value '%s' for field '%s'. Expected: %s" % (val, name, self.values.keys()))
elif type(val) != self.python_type():
elif not isinstance(val, self.python_type()):
raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
else:
return val
Expand Down Expand Up @@ -450,7 +457,7 @@ def validate_type(self, name, val):
super(Array, self).validate_type(name, val)

for x in val:
if type(x) != self.array_type.python_type():
if not isinstance(x, self.array_type.python_type()):
raise TypeError('Array field ' + name + ' items should all be of type ' +
_string_representation(self.array_type.type()))
return val
Expand Down Expand Up @@ -493,7 +500,7 @@ def validate_type(self, name, val):
for k, v in val.items():
if type(k) != str and not is_unicode(k):
raise TypeError('Map keys for field ' + name + ' should all be strings')
if type(v) != self.value_type.python_type():
if not isinstance(v, self.value_type.python_type()):
raise TypeError('Map values for field ' + name + ' should all be of type '
+ _string_representation(self.value_type.python_type()))

Expand Down
2 changes: 2 additions & 0 deletions pulsar/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def __init__(self, record_cls):
def _get_serialized_value(self, o):
if isinstance(o, enum.Enum):
return o.value
elif isinstance(o, bytes):
return o.decode()
else:
data = o.__dict__.copy()
remove_reserved_key(data)
Expand Down
67 changes: 65 additions & 2 deletions tests/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
format='%(asctime)s %(levelname)-5s %(message)s')


# Record with one string, one integer, one float and one bytes field.
# test_schema_type_promotion constructs it with bytes for str_field,
# an int for float_field and a str for bytes_field to exercise promotion.
class ExampleRecord(Record):
str_field = String()
int_field = Integer()
float_field = Float()
bytes_field = Bytes()

class SchemaTest(TestCase):

serviceUrl = 'pulsar://localhost:6650'
Expand Down Expand Up @@ -87,6 +93,31 @@ class Example(Record):
]
})

def test_type_promotion(self):
    """Check that field validation accepts and promotes compatible types.

    Each case is ``(input value, target python type, expected value)``;
    the target type selects which schema field class validates the input.
    """
    # Dispatch table instead of an if/elif chain. An unknown target type
    # now fails loudly with KeyError rather than silently falling back
    # to String().
    field_for_type = {
        int: Integer,
        float: Double,
        str: String,
        bytes: Bytes,
    }

    test_cases = [
        (20, int, 20),                         # No promotion necessary: int => int
        (20, float, 20.0),                     # Promotion: int => float
        (20.0, float, 20.0),                   # No promotion necessary: float => float
        ("Test text1", bytes, b"Test text1"),  # Promotion: str => bytes
        (b"Test text1", str, "Test text1"),    # Promotion: bytes => str
    ]

    for value_from, type_to, value_to in test_cases:
        # subTest reports which case failed and lets the rest still run.
        with self.subTest(value_from=value_from, type_to=type_to.__name__):
            field = field_for_type[type_to]()
            field_value = field.validate_type("test_field", value_from)
            # NOTE(review): assertEqual compares by value, so the
            # int => float case passes whenever 20 == 20.0, whether or not
            # the field actually converts the int -- confirm if a strict
            # type check is wanted.
            self.assertEqual(value_to, field_value)


def test_complex(self):
class Color(Enum):
red = 1
Expand Down Expand Up @@ -229,7 +260,7 @@ class E3(Record):
a = Float()

E3(a=1.0) # Ok
self._expectTypeError(lambda: E3(a=1))
E3(a=1) # Ok Type promotion: int -> float

class E4(Record):
a = Null()
Expand Down Expand Up @@ -259,7 +290,7 @@ class E7(Record):
a = Double()

E7(a=1.0) # Ok
self._expectTypeError(lambda: E3(a=1))
E7(a=1) # Ok Type promotion: int -> double

class Color(Enum):
red = 1
Expand Down Expand Up @@ -1346,5 +1377,37 @@ def verify_messages(msgs: List[pulsar.Message]):

client.close()

def test_schema_type_promotion(self):
    """Round-trip a record with promotable values through Avro and JSON.

    The record is built with bytes for the string field, an int for the
    float field and a str for the bytes field. For each schema it is sent
    to a dedicated topic, received back, and every field of the received
    value is compared with the corresponding field of the sent record.
    """
    client = pulsar.Client(self.serviceUrl)

    # try/finally so the client is closed even when a send, receive or
    # assertion fails part-way through the loop.
    try:
        schemas = [("avro", AvroSchema(ExampleRecord)), ("json", JsonSchema(ExampleRecord))]

        for schema_name, schema in schemas:
            topic = f'test_schema_type_promotion_{schema_name}'

            # Subscribe before producing so the message is retained for
            # this subscription.
            consumer = client.subscribe(
                topic=topic,
                subscription_name=f'my-sub-{schema_name}',
                schema=schema
            )
            producer = client.create_producer(
                topic=topic,
                schema=schema
            )
            send_value = ExampleRecord(str_field=b'test', int_field=1, float_field=3, bytes_field='str')

            producer.send(send_value)

            msg = consumer.receive()
            msg_value = msg.value()
            self.assertEqual(msg_value.str_field, send_value.str_field)
            self.assertEqual(msg_value.int_field, send_value.int_field)
            self.assertEqual(msg_value.float_field, send_value.float_field)
            self.assertEqual(msg_value.bytes_field, send_value.bytes_field)
            consumer.acknowledge(msg)
    finally:
        client.close()


# Script entry point: call main() only when this file is executed directly.
# NOTE(review): main is imported outside the visible excerpt -- presumably
# unittest's main; confirm.
if __name__ == '__main__':
main()