From e4e8250b47eca460ddb0b1135e798908e66dccad Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Thu, 13 Jun 2024 10:18:09 -0500 Subject: [PATCH 01/25] Serialized all scalar data types for message_serializer with all tests passing. --- .../_internal/parameter/message_serializer.py | 78 +++++++++++++++++++ tests/unit/test_message_serializer.py | 39 ++++++++++ 2 files changed, 117 insertions(+) create mode 100644 ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py create mode 100644 tests/unit/test_message_serializer.py diff --git a/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py b/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py new file mode 100644 index 000000000..17cb4b368 --- /dev/null +++ b/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, Sequence +from uuid import uuid4 +from google.protobuf import descriptor_pb2, descriptor_pool, message_factory + +from ni_measurement_plugin_sdk._internal.parameter.metadata import ParameterMetadata +from tests.unit.test_serializer import _get_test_parameter_by_id as currentParameter + +def test() -> None: + cur_values = [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString", + ] + + # Serialize parameter_values using ParameterMetaData + encoded_value_with_message = SerializeWithMessageInstance( + parameter_metadata_dict=currentParameter(cur_values), + parameter_values=cur_values) + + print() + print(f"Serialized value: {encoded_value_with_message}") + + +def SerializeWithMessageInstance( + parameter_metadata_dict: Dict[int, ParameterMetadata], + parameter_values: Sequence[Any] +) -> bytes: + + # Creates a protobuf file to put descriptor stuff in + pool = descriptor_pool.Default() + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + original_guid = uuid4() + new_guid = 'msg' + ''.join(filter(str.isalnum, str(original_guid)))[:16] + file_descriptor_proto.name = str(new_guid) + file_descriptor_proto.package = str(new_guid) + + # Create a DescriptorProto for the message + message_proto = file_descriptor_proto.message_type.add() + message_proto.name = str(new_guid) + + # Creates a new field for each parameter_value and defines it + for i, value in enumerate(parameter_values, start=1): + field_descriptor = message_proto.field.add() + parameter_metadata = parameter_metadata_dict[i] + + field_descriptor.number = i + field_descriptor.name = f"field_{i}" + field_descriptor.type = parameter_metadata.type + field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + + # TODO: Learn how nested messages encode + # TODO: Fix on encoding lists, uses varint and not length-delimited + + # Add fields to message and assign the message to a variable + pool.Add(file_descriptor_proto) + message_descriptor = pool.FindMessageTypeByName(str(new_guid) + '.' + str(new_guid)) + DynamicMessage = message_factory.GetMessageClass(message_descriptor) + message_instance = DynamicMessage() + + #set fields to values and then serialize them + for i, value in enumerate(parameter_values, start=1): + field_name = f"field_{i}" + setattr(message_instance, field_name, value) + + serialized_value = message_instance.SerializeToString() + return serialized_value + + +def main(**kwargs: Any) -> None: + test() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/unit/test_message_serializer.py b/tests/unit/test_message_serializer.py new file mode 100644 index 000000000..54bc7df12 --- /dev/null +++ b/tests/unit/test_message_serializer.py @@ -0,0 +1,39 @@ +import pytest + +from ni_measurement_plugin_sdk._internal.parameter.message_serializer import SerializeWithMessageInstance as new_serializer +from tests.unit.test_serializer import _get_test_parameter_by_id as currentParameter +from ni_measurement_plugin_sdk._internal.parameter import serializer + +@pytest.mark.parametrize( + "test_values", + [ + [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString" + ], + ] +) + +def test___serializer___serialize_parameter___successful_serialization(test_values): + default_values = test_values + parameter = currentParameter(default_values) + + new_serialize = new_serializer( + parameter_metadata_dict=parameter, + parameter_values=test_values) + current_serialize = serializer.serialize_parameters( + parameter_metadata_dict=parameter, + parameter_values=test_values) + + print() + print(f"Current Serializer: {current_serialize}") + print() + print(f"New Serializer: {new_serialize}") + + assert new_serialize == current_serialize \ No newline at end of file From 54bf7df23acf09040fc161a0e52203b5559e539e Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Fri, 14 Jun 2024 09:20:20 -0500 Subject: [PATCH 02/25] Added array data types to message_serializer --- .../_internal/parameter/message_serializer.py | 28 +++++++++++++++---- tests/unit/test_message_serializer.py | 20 +++++++++++-- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py b/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py index 17cb4b368..5f1e41cc4 100644 --- a/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py +++ b/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py @@ -15,6 +15,14 @@ def test() -> None: 2, True, "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], ] # Serialize parameter_values using ParameterMetaData @@ -23,12 +31,12 @@ def test() -> None: parameter_values=cur_values) print() - print(f"Serialized value: {encoded_value_with_message}") + print(f"New Serialized value: {encoded_value_with_message}") def SerializeWithMessageInstance( parameter_metadata_dict: Dict[int, ParameterMetadata], - parameter_values: Sequence[Any] + parameter_values: Sequence[Any], ) -> bytes: # Creates a protobuf file to put descriptor stuff in @@ -51,10 +59,15 @@ def SerializeWithMessageInstance( field_descriptor.number = i field_descriptor.name = f"field_{i}" field_descriptor.type = parameter_metadata.type - field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + # if a value is an array then it's labled as repeated and packed in the field + if parameter_metadata.repeated: + field_descriptor.options.packed = True + field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED + else: + field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + # field_descriptor.options.packed = False # TODO: Learn how nested messages encode - # TODO: Fix on encoding lists, uses varint and not length-delimited # Add fields to message and assign the message to a variable pool.Add(file_descriptor_proto) @@ -65,12 +78,15 @@ def SerializeWithMessageInstance( #set fields to values and then serialize them for i, value in enumerate(parameter_values, start=1): field_name = f"field_{i}" - setattr(message_instance, field_name, value) + if isinstance(value, list): + repeated_field = getattr(message_instance, field_name) + repeated_field.extend(value) + else: + setattr(message_instance, field_name, value) serialized_value = message_instance.SerializeToString() return serialized_value - def main(**kwargs: Any) -> None: test() diff --git a/tests/unit/test_message_serializer.py b/tests/unit/test_message_serializer.py index 54bc7df12..3ceba5f7c 100644 --- a/tests/unit/test_message_serializer.py +++ b/tests/unit/test_message_serializer.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize( "test_values", [ - [ + [ 2.0, 19.2, 3, @@ -15,7 +15,15 @@ 2, 2, True, - "TestString" + "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], ], ] ) @@ -36,4 +44,10 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu print() print(f"New Serializer: {new_serialize}") - assert new_serialize == current_serialize \ No newline at end of file + assert new_serialize == current_serialize + +def main() -> None: + test___serializer___serialize_parameter___successful_serialization(0) + +if __name__ == "__main__": + main() \ No newline at end of file From e55d6d0c425e84be3079d5aafc3e27a536957ab9 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Tue, 18 Jun 2024 09:17:46 -0500 Subject: [PATCH 03/25] Added enums and sub functions to message_serializer --- .../_internal/parameter/message_serializer.py | 123 ++++++++++++++---- tests/unit/test_message_serializer.py | 29 ++++- 2 files changed, 127 insertions(+), 25 deletions(-) diff --git a/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py b/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py index 5f1e41cc4..d550719de 100644 --- a/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py +++ b/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py @@ -2,9 +2,14 @@ from uuid import uuid4 from google.protobuf import descriptor_pb2, descriptor_pool, message_factory +# metadata from ni_measurement_plugin_sdk._internal.parameter.metadata import ParameterMetadata from tests.unit.test_serializer import _get_test_parameter_by_id as currentParameter +# enums and default values +from enum import Enum, IntEnum +from ni_measurement_plugin_sdk._internal.parameter.serialization_strategy import get_type_default + def test() -> None: cur_values = [ 2.0, @@ -23,6 +28,10 @@ def test() -> None: [0, 1, 399], [True, False, True], ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], ] # Serialize parameter_values using ParameterMetaData @@ -36,7 +45,7 @@ def test() -> None: def SerializeWithMessageInstance( parameter_metadata_dict: Dict[int, ParameterMetadata], - parameter_values: Sequence[Any], + parameter_values: Sequence[Any] ) -> bytes: # Creates a protobuf file to put descriptor stuff in @@ -51,22 +60,29 @@ def SerializeWithMessageInstance( message_proto = file_descriptor_proto.message_type.add() message_proto.name = str(new_guid) - # Creates a new field for each parameter_value and defines it - for i, value in enumerate(parameter_values, start=1): - field_descriptor = message_proto.field.add() + # Define fields in message + for i, parameter in enumerate(parameter_values, start=1): parameter_metadata = parameter_metadata_dict[i] + is_enum = parameter_metadata.type == descriptor_pb2.FieldDescriptorProto.TYPE_ENUM - field_descriptor.number = i - field_descriptor.name = f"field_{i}" - field_descriptor.type = parameter_metadata.type - # if a value is an array then it's labled as repeated and packed in the field - if parameter_metadata.repeated: - field_descriptor.options.packed = True - field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED - else: - field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL - # field_descriptor.options.packed = False - + # if value is an enum and a list, set parameter to the 1st element in value + if is_enum and isinstance(parameter, list): + parameter = parameter[0] + # if value is not an enum and doesn't equal to it's default value, define fields + if not is_enum or parameter.value != get_type_default(parameter_metadata.type, parameter_metadata.repeated): + field_descriptor = message_proto.field.add() + define_fields( + field_descriptor=field_descriptor, + metadata=parameter_metadata_dict, + i=i + ) + if is_enum: + define_enums( + file_descriptor=file_descriptor_proto, + param=parameter, + field_descriptor=field_descriptor + ) + # TODO: Learn how nested messages encode # Add fields to message and assign the message to a variable @@ -76,17 +92,80 @@ def SerializeWithMessageInstance( message_instance = DynamicMessage() #set fields to values and then serialize them - for i, value in enumerate(parameter_values, start=1): - field_name = f"field_{i}" - if isinstance(value, list): - repeated_field = getattr(message_instance, field_name) - repeated_field.extend(value) - else: - setattr(message_instance, field_name, value) + for i, parameter in enumerate(parameter_values, start=1): + try: + field_name = f"field_{i}" + parameter = get_enum_values(param=parameter) + if isinstance(parameter, list): + repeated_field = getattr(message_instance, field_name) + repeated_field.extend(parameter) + else: + setattr(message_instance, field_name, parameter) + except: + # goes here if enum is equal to it's default value + i += 1 serialized_value = message_instance.SerializeToString() return serialized_value +def get_enum_values(param): + # if param is a list of enums, return values of them in a list + # or param is an enum, returns the value of it + # else it doesn nothing to param + if isinstance(param, list) and isinstance(param[0], Enum): + return [x.value for x in param] + elif isinstance(param, Enum): + return param.value + return param + +def define_enums(file_descriptor, param, field_descriptor): + # if there are no enums or param is a different enum from ones defined before, creates a new enum + if file_descriptor.enum_type == [] or param.__class__.__name__ not in [enum.name for enum in file_descriptor.enum_type]: + # Define a enum class + enum_descriptor = file_descriptor.enum_type.add() + enum_descriptor.name = param.__class__.__name__ + field_descriptor.type_name = enum_descriptor.name + + # Add constants to enum class + for name, number in param.__class__.__members__.items(): + enum_value_descriptor = enum_descriptor.value.add() + enum_value_descriptor.name = name + enum_value_descriptor.number = number.value + else: + field_descriptor.type_name = param.__class__.__name__ + +def define_fields(field_descriptor, metadata, i): + parameter_metadata = metadata[i] + + field_descriptor.number = i + field_descriptor.name = f"field_{i}" + field_descriptor.type = parameter_metadata.type + # if a value is an array then it's labled as repeated and packed + if parameter_metadata.repeated: + field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED + field_descriptor.options.packed = True + else: + field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + + + +class DifferentColor(Enum): + """Non-primary colors used for testing enum-typed config and output.""" + + PURPLE = 0 + ORANGE = 1 + TEAL = 2 + BROWN = 3 + + +class Countries(IntEnum): + """Countries enum used for testing enum-typed config and output.""" + + AMERICA = 0 + TAIWAN = 1 + AUSTRALIA = 2 + CANADA = 3 + def main(**kwargs: Any) -> None: test() diff --git a/tests/unit/test_message_serializer.py b/tests/unit/test_message_serializer.py index 3ceba5f7c..10270528a 100644 --- a/tests/unit/test_message_serializer.py +++ b/tests/unit/test_message_serializer.py @@ -4,6 +4,25 @@ from tests.unit.test_serializer import _get_test_parameter_by_id as currentParameter from ni_measurement_plugin_sdk._internal.parameter import serializer +from enum import Enum, IntEnum + +class DifferentColor(Enum): + """Non-primary colors used for testing enum-typed config and output.""" + + PURPLE = 0 + ORANGE = 1 + TEAL = 2 + BROWN = 3 + + +class Countries(IntEnum): + """Countries enum used for testing enum-typed config and output.""" + + AMERICA = 0 + TAIWAN = 1 + AUSTRALIA = 2 + CANADA = 3 + @pytest.mark.parametrize( "test_values", [ @@ -24,17 +43,21 @@ [0, 1, 399], [True, False, True], ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], ], ] ) def test___serializer___serialize_parameter___successful_serialization(test_values): - default_values = test_values - parameter = currentParameter(default_values) + parameter = currentParameter(test_values) new_serialize = new_serializer( parameter_metadata_dict=parameter, - parameter_values=test_values) + parameter_values=test_values, + current_encoded_value=0) current_serialize = serializer.serialize_parameters( parameter_metadata_dict=parameter, parameter_values=test_values) From 3131ec3f4b54f43a34ca05e6c274d34ce94ee091 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Thu, 20 Jun 2024 14:13:26 -0500 Subject: [PATCH 04/25] Added messages and refactored message_serializer --- .../_internal/parameter/message_serializer.py | 151 +++++++++++------- tests/unit/test_message_serializer.py | 88 +++++++--- 2 files changed, 152 insertions(+), 87 deletions(-) diff --git a/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py b/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py index d550719de..95fd06331 100644 --- a/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py +++ b/ni_measurement_plugin_sdk/_internal/parameter/message_serializer.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Sequence from uuid import uuid4 from google.protobuf import descriptor_pb2, descriptor_pool, message_factory +from google.protobuf.descriptor_pb2 import FieldDescriptorProto as field # metadata from ni_measurement_plugin_sdk._internal.parameter.metadata import ParameterMetadata @@ -8,6 +9,7 @@ # enums and default values from enum import Enum, IntEnum +from ni_measurement_plugin_sdk._internal.stubs.ni.protobuf.types import xydata_pb2 from ni_measurement_plugin_sdk._internal.parameter.serialization_strategy import get_type_default def test() -> None: @@ -32,20 +34,22 @@ def test() -> None: [DifferentColor.TEAL, DifferentColor.BROWN], Countries.AUSTRALIA, [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array ] # Serialize parameter_values using ParameterMetaData - encoded_value_with_message = SerializeWithMessageInstance( + message_serializer = SerializeWithMessageInstance( parameter_metadata_dict=currentParameter(cur_values), parameter_values=cur_values) print() - print(f"New Serialized value: {encoded_value_with_message}") + print(f"Message Serialized value: {message_serializer}") def SerializeWithMessageInstance( parameter_metadata_dict: Dict[int, ParameterMetadata], - parameter_values: Sequence[Any] + parameter_values: Sequence[Any], ) -> bytes: # Creates a protobuf file to put descriptor stuff in @@ -60,55 +64,61 @@ def SerializeWithMessageInstance( message_proto = file_descriptor_proto.message_type.add() message_proto.name = str(new_guid) - # Define fields in message + # Initialize the message with fields defined for i, parameter in enumerate(parameter_values, start=1): parameter_metadata = parameter_metadata_dict[i] - is_enum = parameter_metadata.type == descriptor_pb2.FieldDescriptorProto.TYPE_ENUM - - # if value is an enum and a list, set parameter to the 1st element in value - if is_enum and isinstance(parameter, list): - parameter = parameter[0] - # if value is not an enum and doesn't equal to it's default value, define fields - if not is_enum or parameter.value != get_type_default(parameter_metadata.type, parameter_metadata.repeated): - field_descriptor = message_proto.field.add() - define_fields( - field_descriptor=field_descriptor, - metadata=parameter_metadata_dict, - i=i - ) - if is_enum: - define_enums( - file_descriptor=file_descriptor_proto, - param=parameter, - field_descriptor=field_descriptor - ) - - # TODO: Learn how nested messages encode - - # Add fields to message and assign the message to a variable + is_enum = parameter_metadata.type == field.TYPE_ENUM + # Define fields + field_descriptor = _define_fields( + message_proto=message_proto, + parameter_metadata=parameter_metadata, + i=i, + param=parameter, + is_enum=is_enum) + # define enums if it's an enum and there's a field + if is_enum and field_descriptor is not None: + _define_enums( + file_descriptor=file_descriptor_proto, + param=parameter, + field_descriptor=field_descriptor) + + # Get message and add fields to it pool.Add(file_descriptor_proto) message_descriptor = pool.FindMessageTypeByName(str(new_guid) + '.' + str(new_guid)) - DynamicMessage = message_factory.GetMessageClass(message_descriptor) - message_instance = DynamicMessage() + message_instance = message_factory.GetMessageClass(message_descriptor)() - #set fields to values and then serialize them + #assign values to fields for i, parameter in enumerate(parameter_values, start=1): - try: - field_name = f"field_{i}" - parameter = get_enum_values(param=parameter) - if isinstance(parameter, list): - repeated_field = getattr(message_instance, field_name) - repeated_field.extend(parameter) + field_name = f"field_{i}" + parameter_metadata = parameter_metadata_dict[i] + parameter = _get_enum_values(param=parameter) + try: + if parameter_metadata.repeated: + getattr(message_instance, field_name).extend(parameter) + elif parameter_metadata.type == field.TYPE_MESSAGE: + getattr(message_instance, field_name).CopyFrom(parameter) else: setattr(message_instance, field_name, parameter) except: - # goes here if enum is equal to it's default value - i += 1 - - serialized_value = message_instance.SerializeToString() - return serialized_value - -def get_enum_values(param): + i += 1 # no field: parameter is None or equal to default value + return message_instance.SerializeToString() + +def _equal_to_default_value(metadata, param, is_enum): + default_value = get_type_default( + metadata.type, + metadata.repeated) + # gets value of enum + if is_enum: + if metadata.repeated: + param = param[0].value + else: + param = param.value + # return true if param is None or eqaul to default value + if param == default_value or param == None: + return True + return False + +def _get_enum_values(param): # if param is a list of enums, return values of them in a list # or param is an enum, returns the value of it # else it doesn nothing to param @@ -118,7 +128,10 @@ def get_enum_values(param): return param.value return param -def define_enums(file_descriptor, param, field_descriptor): +def _define_enums(file_descriptor, param, field_descriptor): + # if param is a list, then it sets param to 1st element in list + if isinstance(param, list): + param = param[0] # if there are no enums or param is a different enum from ones defined before, creates a new enum if file_descriptor.enum_type == [] or param.__class__.__name__ not in [enum.name for enum in file_descriptor.enum_type]: # Define a enum class @@ -134,20 +147,27 @@ def define_enums(file_descriptor, param, field_descriptor): else: field_descriptor.type_name = param.__class__.__name__ -def define_fields(field_descriptor, metadata, i): - parameter_metadata = metadata[i] - - field_descriptor.number = i - field_descriptor.name = f"field_{i}" - field_descriptor.type = parameter_metadata.type - # if a value is an array then it's labled as repeated and packed - if parameter_metadata.repeated: - field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED - field_descriptor.options.packed = True - else: - field_descriptor.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL - - +def _define_fields(message_proto, parameter_metadata, i, param, is_enum): + # exits if param is None or eqaul to default value + if not _equal_to_default_value( + metadata=parameter_metadata, + param=param, + is_enum=is_enum): + field_descriptor = message_proto.field.add() + + field_descriptor.number = i + field_descriptor.name = f"field_{i}" + field_descriptor.type = parameter_metadata.type + # if a value is an array then it's labled as repeated and packed + if parameter_metadata.repeated: + field_descriptor.label = field.LABEL_REPEATED + field_descriptor.options.packed = True + else: + field_descriptor.label = field.LABEL_OPTIONAL + # if a value is a message then assign type name to it's full name + if parameter_metadata.type == field.TYPE_MESSAGE: + field_descriptor.type_name = parameter_metadata.message_type + return field_descriptor class DifferentColor(Enum): """Non-primary colors used for testing enum-typed config and output.""" @@ -166,8 +186,17 @@ class Countries(IntEnum): AUSTRALIA = 2 CANADA = 3 -def main(**kwargs: Any) -> None: - test() +double_xy_data = xydata_pb2.DoubleXYData() +double_xy_data.x_data.append(4) +double_xy_data.y_data.append(6) + +double_xy_data2 = xydata_pb2.DoubleXYData() +double_xy_data2.x_data.append(8) +double_xy_data2.y_data.append(10) + +double_xy_data_array = [double_xy_data, double_xy_data2] +# This should match the number of fields in bigmessage.proto. +BIG_MESSAGE_SIZE = 100 if __name__ == "__main__": - main() \ No newline at end of file + test() \ No newline at end of file diff --git a/tests/unit/test_message_serializer.py b/tests/unit/test_message_serializer.py index 10270528a..19ed66194 100644 --- a/tests/unit/test_message_serializer.py +++ b/tests/unit/test_message_serializer.py @@ -5,6 +5,7 @@ from ni_measurement_plugin_sdk._internal.parameter import serializer from enum import Enum, IntEnum +from ni_measurement_plugin_sdk._internal.stubs.ni.protobuf.types import xydata_pb2 class DifferentColor(Enum): """Non-primary colors used for testing enum-typed config and output.""" @@ -23,32 +24,68 @@ class Countries(IntEnum): AUSTRALIA = 2 CANADA = 3 +double_xy_data = xydata_pb2.DoubleXYData() +double_xy_data.x_data.append(4) +double_xy_data.y_data.append(6) + +double_xy_data2 = xydata_pb2.DoubleXYData() +double_xy_data2.x_data.append(8) +double_xy_data2.y_data.append(10) + +double_xy_data_array = [double_xy_data, double_xy_data2] + @pytest.mark.parametrize( "test_values", - [ - [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - ], - ] + [ + [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array + ], + [ + -0.9999, + -0.9999, + -13, + 1, + 1000, + 2, + True, + "////", + [5.5, -13.3, 1, 0.0, -99.9999], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ] + ] ) def test___serializer___serialize_parameter___successful_serialization(test_values): @@ -56,8 +93,7 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu new_serialize = new_serializer( parameter_metadata_dict=parameter, - parameter_values=test_values, - current_encoded_value=0) + parameter_values=test_values,) current_serialize = serializer.serialize_parameters( parameter_metadata_dict=parameter, parameter_values=test_values) From 777038381a9f28c80179d17ad0cea715e60a718a Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Fri, 21 Jun 2024 09:59:57 -0500 Subject: [PATCH 05/25] Modified message_serializer to pass all the tests --- .../_internal/parameter/message_serializer.py | 42 +++++++++++-------- tests/unit/test_message_serializer.py | 16 +++---- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py b/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py index 95fd06331..a3f412b4b 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py @@ -4,13 +4,15 @@ from google.protobuf.descriptor_pb2 import FieldDescriptorProto as field # metadata -from ni_measurement_plugin_sdk._internal.parameter.metadata import ParameterMetadata +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ParameterMetadata from tests.unit.test_serializer import _get_test_parameter_by_id as currentParameter # enums and default values from enum import Enum, IntEnum -from ni_measurement_plugin_sdk._internal.stubs.ni.protobuf.types import xydata_pb2 -from ni_measurement_plugin_sdk._internal.parameter.serialization_strategy import get_type_default +from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 +from ni_measurement_plugin_sdk_service._internal.parameter.serialization_strategy import get_type_default, _TYPE_DEFAULT_MAPPING + +from tests.utilities.stubs.loopback.types_pb2 import ProtobufColor def test() -> None: cur_values = [ @@ -39,7 +41,7 @@ def test() -> None: ] # Serialize parameter_values using ParameterMetaData - message_serializer = SerializeWithMessageInstance( + message_serializer = serialize_parameters( parameter_metadata_dict=currentParameter(cur_values), parameter_values=cur_values) @@ -47,7 +49,7 @@ def test() -> None: print(f"Message Serialized value: {message_serializer}") -def SerializeWithMessageInstance( +def serialize_parameters( parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_values: Sequence[Any], ) -> bytes: @@ -67,16 +69,16 @@ def SerializeWithMessageInstance( # Initialize the message with fields defined for i, parameter in enumerate(parameter_values, start=1): parameter_metadata = parameter_metadata_dict[i] - is_enum = parameter_metadata.type == field.TYPE_ENUM + is_python_enum = isinstance(parameter, Enum) # Define fields field_descriptor = _define_fields( message_proto=message_proto, parameter_metadata=parameter_metadata, i=i, param=parameter, - is_enum=is_enum) - # define enums if it's an enum and there's a field - if is_enum and field_descriptor is not None: + is_python_enum=is_python_enum) + # define enums if it's a regular or a protobuf enum and there's a field + if parameter_metadata.type == field.TYPE_ENUM and field_descriptor is not None: _define_enums( file_descriptor=file_descriptor_proto, param=parameter, @@ -103,12 +105,12 @@ def SerializeWithMessageInstance( i += 1 # no field: parameter is None or equal to default value return message_instance.SerializeToString() -def _equal_to_default_value(metadata, param, is_enum): +def _equal_to_default_value(metadata, param, is_python_enum): default_value = get_type_default( metadata.type, metadata.repeated) - # gets value of enum - if is_enum: + # gets value from a regular python enum + if is_python_enum: if metadata.repeated: param = param[0].value else: @@ -119,6 +121,8 @@ def _equal_to_default_value(metadata, param, is_enum): return False def _get_enum_values(param): + if param == []: + return param # if param is a list of enums, return values of them in a list # or param is an enum, returns the value of it # else it doesn nothing to param @@ -132,27 +136,29 @@ def _define_enums(file_descriptor, param, field_descriptor): # if param is a list, then it sets param to 1st element in list if isinstance(param, list): param = param[0] - # if there are no enums or param is a different enum from ones defined before, creates a new enum - if file_descriptor.enum_type == [] or param.__class__.__name__ not in [enum.name for enum in file_descriptor.enum_type]: + # if there are no enums/param is a different enum and is a python enum, defines a enum field + if param.__class__.__name__ not in [enum.name for enum in file_descriptor.enum_type] and isinstance(param, Enum): # Define a enum class enum_descriptor = file_descriptor.enum_type.add() enum_descriptor.name = param.__class__.__name__ - field_descriptor.type_name = enum_descriptor.name # Add constants to enum class for name, number in param.__class__.__members__.items(): enum_value_descriptor = enum_descriptor.value.add() enum_value_descriptor.name = name enum_value_descriptor.number = number.value - else: + # checks enum if it's protobuf or python + try: + field_descriptor.type_name = ProtobufColor.DESCRIPTOR.full_name + except: field_descriptor.type_name = param.__class__.__name__ -def _define_fields(message_proto, parameter_metadata, i, param, is_enum): +def _define_fields(message_proto, parameter_metadata, i, param, is_python_enum): # exits if param is None or eqaul to default value if not _equal_to_default_value( metadata=parameter_metadata, param=param, - is_enum=is_enum): + is_python_enum=is_python_enum): field_descriptor = message_proto.field.add() field_descriptor.number = i diff --git a/tests/unit/test_message_serializer.py b/tests/unit/test_message_serializer.py index 19ed66194..f111ed762 100644 --- a/tests/unit/test_message_serializer.py +++ b/tests/unit/test_message_serializer.py @@ -1,11 +1,11 @@ import pytest -from ni_measurement_plugin_sdk._internal.parameter.message_serializer import SerializeWithMessageInstance as new_serializer +from ni_measurement_plugin_sdk_service._internal.parameter.message_serializer import serialize_parameters as new_serializer from tests.unit.test_serializer import _get_test_parameter_by_id as currentParameter -from ni_measurement_plugin_sdk._internal.parameter import serializer +from ni_measurement_plugin_sdk_service._internal.parameter import serializer from enum import Enum, IntEnum -from ni_measurement_plugin_sdk._internal.stubs.ni.protobuf.types import xydata_pb2 +from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 class DifferentColor(Enum): """Non-primary colors used for testing enum-typed config and output.""" @@ -101,12 +101,6 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu print() print(f"Current Serializer: {current_serialize}") print() - print(f"New Serializer: {new_serialize}") + print(f"Message Serializer: {new_serialize}") - assert new_serialize == current_serialize - -def main() -> None: - test___serializer___serialize_parameter___successful_serialization(0) - -if __name__ == "__main__": - main() \ No newline at end of file + assert new_serialize == current_serialize \ No newline at end of file From ea6934d34fc68597632ea08e3f25080a9f1aae82 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 24 Jun 2024 09:48:46 -0500 Subject: [PATCH 06/25] Switched current serializer with message_serializer and removed sub functions calling it --- .../_internal/grpc_servicer.py | 19 +- .../_internal/parameter/message_serializer.py | 175 ++++++++---------- .../_internal/parameter/metadata.py | 15 +- .../parameter/serialization_strategy.py | 111 +---------- .../_internal/parameter/serializer.py | 79 +------- tests/unit/test_serialization_strategy.py | 33 +--- tests/unit/test_serializer.py | 8 +- 7 files changed, 112 insertions(+), 328 deletions(-) diff --git a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py index ae29db78e..34dd34614 100644 --- a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py +++ b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py @@ -13,8 +13,13 @@ import grpc from google.protobuf import any_pb2 -from ni_measurement_plugin_sdk_service._internal.parameter import serializer -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ParameterMetadata +from ni_measurement_plugin_sdk_service._internal.parameter import ( + message_serializer, + serializer, +) +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v1 import ( measurement_service_pb2 as v1_measurement_service_pb2, measurement_service_pb2_grpc as v1_measurement_service_pb2_grpc, @@ -133,7 +138,7 @@ def _get_mapping_by_parameter_name( def _serialize_outputs(output_metadata: Dict[int, ParameterMetadata], outputs: Any) -> any_pb2.Any: if isinstance(outputs, collections.abc.Sequence): - return any_pb2.Any(value=serializer.serialize_parameters(output_metadata, outputs)) + return any_pb2.Any(value=message_serializer.serialize_parameters(output_metadata, outputs)) elif outputs is None: raise ValueError(f"Measurement function returned None") else: @@ -193,8 +198,8 @@ def GetMetadata( # noqa: N802 - function name should be lowercase ) measurement_signature.configuration_parameters.append(configuration_parameter) - measurement_signature.configuration_defaults.value = serializer.serialize_default_values( - self._configuration_metadata + measurement_signature.configuration_defaults.value = ( + message_serializer.serialize_default_values(self._configuration_metadata) ) for field_number, output_metadata in self._output_metadata.items(): @@ -301,8 +306,8 @@ def GetMetadata( # noqa: N802 - function name should be lowercase ) measurement_signature.configuration_parameters.append(configuration_parameter) - measurement_signature.configuration_defaults.value = serializer.serialize_default_values( - self._configuration_metadata + measurement_signature.configuration_defaults.value = ( + message_serializer.serialize_default_values(self._configuration_metadata) ) for field_number, output_metadata in self._output_metadata.items(): diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py b/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py index a3f412b4b..7036df5e2 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py @@ -1,64 +1,27 @@ +# enums and default values +from enum import Enum from typing import Any, Dict, Sequence from uuid import uuid4 -from google.protobuf import descriptor_pb2, descriptor_pool, message_factory -from google.protobuf.descriptor_pb2 import FieldDescriptorProto as field - -# metadata -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ParameterMetadata -from tests.unit.test_serializer import _get_test_parameter_by_id as currentParameter -# enums and default values -from enum import Enum, IntEnum -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 -from ni_measurement_plugin_sdk_service._internal.parameter.serialization_strategy import get_type_default, _TYPE_DEFAULT_MAPPING +from google.protobuf import descriptor_pb2, descriptor_pool, message_factory, type_pb2 +from google.protobuf.descriptor_pb2 import FieldDescriptorProto +# metadata +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) from tests.utilities.stubs.loopback.types_pb2 import ProtobufColor -def test() -> None: - cur_values = [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array - ] - - # Serialize parameter_values using ParameterMetaData - message_serializer = serialize_parameters( - parameter_metadata_dict=currentParameter(cur_values), - parameter_values=cur_values) - - print() - print(f"Message Serialized value: {message_serializer}") - def serialize_parameters( parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_values: Sequence[Any], ) -> bytes: - - # Creates a protobuf file to put descriptor stuff in + """Test doc string?""" pool = descriptor_pool.Default() file_descriptor_proto = descriptor_pb2.FileDescriptorProto() original_guid = uuid4() - new_guid = 'msg' + ''.join(filter(str.isalnum, str(original_guid)))[:16] + new_guid = "msg" + "".join(filter(str.isalnum, str(original_guid)))[:16] file_descriptor_proto.name = str(new_guid) file_descriptor_proto.package = str(new_guid) @@ -76,20 +39,25 @@ def serialize_parameters( parameter_metadata=parameter_metadata, i=i, param=parameter, - is_python_enum=is_python_enum) + is_python_enum=is_python_enum, + ) # define enums if it's a regular or a protobuf enum and there's a field - if parameter_metadata.type == field.TYPE_ENUM and field_descriptor is not None: + if ( + parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM + and field_descriptor is not None + ): _define_enums( file_descriptor=file_descriptor_proto, param=parameter, - field_descriptor=field_descriptor) + field_descriptor=field_descriptor, + ) # Get message and add fields to it pool.Add(file_descriptor_proto) - message_descriptor = pool.FindMessageTypeByName(str(new_guid) + '.' + str(new_guid)) + message_descriptor = pool.FindMessageTypeByName(str(new_guid) + "." + str(new_guid)) message_instance = message_factory.GetMessageClass(message_descriptor)() - #assign values to fields + # assign values to fields for i, parameter in enumerate(parameter_values, start=1): field_name = f"field_{i}" parameter_metadata = parameter_metadata_dict[i] @@ -97,18 +65,32 @@ def serialize_parameters( try: if parameter_metadata.repeated: getattr(message_instance, field_name).extend(parameter) - elif parameter_metadata.type == field.TYPE_MESSAGE: + elif parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE: getattr(message_instance, field_name).CopyFrom(parameter) else: setattr(message_instance, field_name, parameter) - except: - i += 1 # no field: parameter is None or equal to default value + except AttributeError: + i += 1 # no field: parameter is None or equal to default value return message_instance.SerializeToString() + +def serialize_default_values(parameter_metadata_dict: Dict[int, ParameterMetadata]) -> bytes: + """Serialize the Default values in the Metadata. + + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. + + Returns: + bytes: Serialized byte string containing default values. + """ + default_value_parameter_array = [ + parameter.default_value for parameter in parameter_metadata_dict.values() + ] + return serialize_parameters(parameter_metadata_dict, default_value_parameter_array) + + def _equal_to_default_value(metadata, param, is_python_enum): - default_value = get_type_default( - metadata.type, - metadata.repeated) + default_value = get_type_default(metadata.type, metadata.repeated) # gets value from a regular python enum if is_python_enum: if metadata.repeated: @@ -116,15 +98,16 @@ def _equal_to_default_value(metadata, param, is_python_enum): else: param = param.value # return true if param is None or eqaul to default value - if param == default_value or param == None: + if param == default_value or param is None: return True return False + def _get_enum_values(param): if param == []: return param - # if param is a list of enums, return values of them in a list - # or param is an enum, returns the value of it + # if param is a list of enums, return values of them in a list + # or param is an enum, returns the value of it # else it doesn nothing to param if isinstance(param, list) and isinstance(param[0], Enum): return [x.value for x in param] @@ -132,12 +115,15 @@ def _get_enum_values(param): return param.value return param + def _define_enums(file_descriptor, param, field_descriptor): # if param is a list, then it sets param to 1st element in list if isinstance(param, list): param = param[0] # if there are no enums/param is a different enum and is a python enum, defines a enum field - if param.__class__.__name__ not in [enum.name for enum in file_descriptor.enum_type] and isinstance(param, Enum): + if param.__class__.__name__ not in [ + enum.name for enum in file_descriptor.enum_type + ] and isinstance(param, Enum): # Define a enum class enum_descriptor = file_descriptor.enum_type.add() enum_descriptor.name = param.__class__.__name__ @@ -150,15 +136,16 @@ def _define_enums(file_descriptor, param, field_descriptor): # checks enum if it's protobuf or python try: field_descriptor.type_name = ProtobufColor.DESCRIPTOR.full_name - except: - field_descriptor.type_name = param.__class__.__name__ + # TODO: Add error type to except thing + except TypeError: + field_descriptor.type_name = param.__class__.__name__ + def _define_fields(message_proto, parameter_metadata, i, param, is_python_enum): # exits if param is None or eqaul to default value if not _equal_to_default_value( - metadata=parameter_metadata, - param=param, - is_python_enum=is_python_enum): + metadata=parameter_metadata, param=param, is_python_enum=is_python_enum + ): field_descriptor = message_proto.field.add() field_descriptor.number = i @@ -166,43 +153,29 @@ def _define_fields(message_proto, parameter_metadata, i, param, is_python_enum): field_descriptor.type = parameter_metadata.type # if a value is an array then it's labled as repeated and packed if parameter_metadata.repeated: - field_descriptor.label = field.LABEL_REPEATED - field_descriptor.options.packed = True + field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED + field_descriptor.options.packed = True else: - field_descriptor.label = field.LABEL_OPTIONAL + field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL # if a value is a message then assign type name to it's full name - if parameter_metadata.type == field.TYPE_MESSAGE: + if parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE: field_descriptor.type_name = parameter_metadata.message_type return field_descriptor -class DifferentColor(Enum): - """Non-primary colors used for testing enum-typed config and output.""" - - PURPLE = 0 - ORANGE = 1 - TEAL = 2 - BROWN = 3 - - -class Countries(IntEnum): - """Countries enum used for testing enum-typed config and output.""" - - AMERICA = 0 - TAIWAN = 1 - AUSTRALIA = 2 - CANADA = 3 - -double_xy_data = xydata_pb2.DoubleXYData() -double_xy_data.x_data.append(4) -double_xy_data.y_data.append(6) - -double_xy_data2 = xydata_pb2.DoubleXYData() -double_xy_data2.x_data.append(8) -double_xy_data2.y_data.append(10) - -double_xy_data_array = [double_xy_data, double_xy_data2] -# This should match the number of fields in bigmessage.proto. -BIG_MESSAGE_SIZE = 100 -if __name__ == "__main__": - test() \ No newline at end of file +def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: + """Get the default value for the give type.""" + _type_default_mapping = { + type_pb2.Field.TYPE_FLOAT: float(), + type_pb2.Field.TYPE_DOUBLE: float(), + type_pb2.Field.TYPE_INT32: int(), + type_pb2.Field.TYPE_INT64: int(), + type_pb2.Field.TYPE_UINT32: int(), + type_pb2.Field.TYPE_UINT64: int(), + type_pb2.Field.TYPE_BOOL: bool(), + type_pb2.Field.TYPE_STRING: str(), + type_pb2.Field.TYPE_ENUM: int(), + } + if repeated: + return list() + return _type_default_mapping.get(type) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index f6aff41c8..5884313b9 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -10,7 +10,6 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter import serialization_strategy from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization @@ -50,22 +49,20 @@ def validate_default_value_type(parameter_metadata: ParameterMetadata) -> None: Raises: TypeError: If default value does not match the Datatype. """ + from ni_measurement_plugin_sdk_service._internal.parameter.message_serializer import ( + get_type_default, + ) + default_value = parameter_metadata.default_value if default_value is None: return None - expected_type = type( - serialization_strategy.get_type_default( - parameter_metadata.type, parameter_metadata.repeated - ) - ) + expected_type = type(get_type_default(parameter_metadata.type, parameter_metadata.repeated)) display_name = parameter_metadata.display_name enum_values_annotation = get_enum_values_annotation(parameter_metadata) if parameter_metadata.repeated: - expected_element_type = type( - serialization_strategy.get_type_default(parameter_metadata.type, False) - ) + expected_element_type = type(get_type_default(parameter_metadata.type, False)) _validate_default_value_type_for_repeated_type( default_value, expected_type, diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py b/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py index 88fc5cd84..d534c9109 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py @@ -5,55 +5,19 @@ from typing import Any, Optional, cast from google.protobuf import type_pb2 -from google.protobuf.internal import decoder, encoder +from google.protobuf.internal import decoder from google.protobuf.message import Message from ni_measurement_plugin_sdk_service._internal.parameter import _message from ni_measurement_plugin_sdk_service._internal.parameter._serializer_types import ( Decoder, DecoderConstructor, - Encoder, - EncoderConstructor, Key, PartialDecoderConstructor, - PartialEncoderConstructor, ) -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 - - -def _scalar_encoder(encoder: EncoderConstructor) -> PartialEncoderConstructor: - """Constructs a scalar encoder constructor. - - Takes a field index and returns an Encoder. - - This class returns the Encoder with is_repeated set to False - and is_packed set to False. - """ - - def scalar_encoder(field_index: int) -> Encoder: - is_repeated = False - is_packed = False - return encoder(field_index, is_repeated, is_packed) - - return scalar_encoder - - -def _vector_encoder( - encoder: EncoderConstructor, is_packed: bool = True -) -> PartialEncoderConstructor: - """Constructs a vector (array) encoder constructor. - - Takes a field index and returns an Encoder. - - This class returns the Encoder with is_repeated set to True - and is_packed defaulting to True. - """ - - def vector_encoder(field_index: int) -> Encoder: - is_repeated = True - return encoder(field_index, is_repeated, is_packed) - - return vector_encoder +from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( + xydata_pb2, +) def _scalar_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor: @@ -117,26 +81,6 @@ def message_decoder(field_index: int, key: Key) -> Decoder: return message_decoder -# Cast works around this issue in typeshed -# https://github.com/python/typeshed/issues/10695 -FloatEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.FloatEncoder)) -DoubleEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.DoubleEncoder)) -IntEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.Int32Encoder)) -UIntEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.UInt32Encoder)) -BoolEncoder = _scalar_encoder(encoder.BoolEncoder) -StringEncoder = _scalar_encoder(encoder.StringEncoder) -MessageEncoder = _scalar_encoder(cast(EncoderConstructor, _message._message_encoder_constructor)) - -FloatArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.FloatEncoder)) -DoubleArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.DoubleEncoder)) -IntArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.Int32Encoder)) -UIntArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.UInt32Encoder)) -BoolArrayEncoder = _vector_encoder(encoder.BoolEncoder) -StringArrayEncoder = _vector_encoder(encoder.StringEncoder, is_packed=False) -MessageArrayEncoder = _vector_encoder( - cast(EncoderConstructor, _message._message_encoder_constructor) -) - # Cast works around this issue in typeshed # https://github.com/python/typeshed/issues/10697 FloatDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.FloatDecoder)) @@ -163,20 +107,6 @@ def message_decoder(field_index: int, key: Key) -> Decoder: _message._message_decoder_constructor, is_repeated=True ) - -_FIELD_TYPE_TO_ENCODER_MAPPING = { - type_pb2.Field.TYPE_FLOAT: (FloatEncoder, FloatArrayEncoder), - type_pb2.Field.TYPE_DOUBLE: (DoubleEncoder, DoubleArrayEncoder), - type_pb2.Field.TYPE_INT32: (IntEncoder, IntArrayEncoder), - type_pb2.Field.TYPE_INT64: (IntEncoder, IntArrayEncoder), - type_pb2.Field.TYPE_UINT32: (UIntEncoder, UIntArrayEncoder), - type_pb2.Field.TYPE_UINT64: (UIntEncoder, UIntArrayEncoder), - type_pb2.Field.TYPE_BOOL: (BoolEncoder, BoolArrayEncoder), - type_pb2.Field.TYPE_STRING: (StringEncoder, StringArrayEncoder), - type_pb2.Field.TYPE_ENUM: (IntEncoder, IntArrayEncoder), - type_pb2.Field.TYPE_MESSAGE: (MessageEncoder, MessageArrayEncoder), -} - _FIELD_TYPE_TO_DECODER_MAPPING = { type_pb2.Field.TYPE_FLOAT: (FloatDecoder, FloatArrayDecoder), type_pb2.Field.TYPE_DOUBLE: (DoubleDecoder, DoubleArrayDecoder), @@ -189,18 +119,6 @@ def message_decoder(field_index: int, key: Key) -> Decoder: type_pb2.Field.TYPE_ENUM: (Int32Decoder, Int32ArrayDecoder), } -_TYPE_DEFAULT_MAPPING = { - type_pb2.Field.TYPE_FLOAT: float(), - type_pb2.Field.TYPE_DOUBLE: float(), - type_pb2.Field.TYPE_INT32: int(), - type_pb2.Field.TYPE_INT64: int(), - type_pb2.Field.TYPE_UINT32: int(), - type_pb2.Field.TYPE_UINT64: int(), - type_pb2.Field.TYPE_BOOL: bool(), - type_pb2.Field.TYPE_STRING: str(), - type_pb2.Field.TYPE_ENUM: int(), -} - _MESSAGE_TYPE_TO_DECODER = { xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataDecoder, } @@ -210,19 +128,6 @@ def message_decoder(field_index: int, key: Key) -> Decoder: } -def get_encoder(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> PartialEncoderConstructor: - """Get the appropriate partial encoder constructor for the specified type. - - A scalar or vector constructor is returned based on the 'repeated' parameter. - """ - if type not in _FIELD_TYPE_TO_ENCODER_MAPPING: - raise ValueError(f"Error can not encode type '{type}'") - scalar, array = _FIELD_TYPE_TO_ENCODER_MAPPING[type] - if repeated: - return array - return scalar - - def get_decoder( type: type_pb2.Field.Kind.ValueType, repeated: bool, message_type: str = "" ) -> PartialDecoderConstructor: @@ -241,11 +146,3 @@ def get_decoder( return decoder else: raise ValueError(f"Error can not decode type '{type}'") - - -def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: - """Get the default value for the give type.""" - if repeated: - return list() - type_default_value = _TYPE_DEFAULT_MAPPING.get(type) - return type_default_value diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py b/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py index 71ef39568..87b537b06 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py @@ -1,8 +1,6 @@ """Parameter Serializer.""" -from enum import Enum -from io import BytesIO -from typing import Any, Dict, Sequence, cast +from typing import Any, Dict, cast from google.protobuf.descriptor import FieldDescriptor from google.protobuf.internal.decoder import ( # type: ignore[attr-defined] @@ -10,15 +8,14 @@ ) from google.protobuf.message import Message -from ni_measurement_plugin_sdk_service._annotations import ( - TYPE_SPECIALIZATION_KEY, -) from ni_measurement_plugin_sdk_service._internal.parameter import serialization_strategy +from ni_measurement_plugin_sdk_service._internal.parameter.message_serializer import ( + get_type_default, +) from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, get_enum_values_annotation, ) -from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization _GRPC_WIRE_TYPE_BIT_WIDTH = 3 @@ -52,60 +49,6 @@ def deserialize_parameters( return overlapping_parameter_by_id -def serialize_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], - parameter_values: Sequence[Any], -) -> bytes: - """Serialize the parameter values in same order based on the metadata_dict. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - - parameter_value (Sequence[Any]): Parameter values to serialize. - - Returns: - bytes: Serialized byte string containing parameter values. - """ - serialize_buffer = BytesIO() # inner_encoder updates the serialize_buffer - for i, parameter in enumerate(parameter_values): - parameter_metadata = parameter_metadata_dict[i + 1] - encoder = serialization_strategy.get_encoder( - parameter_metadata.type, - parameter_metadata.repeated, - ) - type_default_value = serialization_strategy.get_type_default( - parameter_metadata.type, - parameter_metadata.repeated, - ) - # Convert enum parameters to their underlying value if necessary. - if ( - parameter_metadata.annotations.get(TYPE_SPECIALIZATION_KEY) - == TypeSpecialization.Enum.value - ): - parameter = _get_enum_value(parameter, parameter_metadata.repeated) - # Skipping serialization if the value is None or if its a type default value. - if parameter is not None and parameter != type_default_value: - inner_encoder = encoder(i + 1) - inner_encoder(serialize_buffer.write, parameter, False) - return serialize_buffer.getvalue() - - -def serialize_default_values(parameter_metadata_dict: Dict[int, ParameterMetadata]) -> bytes: - """Serialize the Default values in the Metadata. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. - - Returns: - bytes: Serialized byte string containing default values. - """ - default_value_parameter_array = list() - default_value_parameter_array = [ - parameter.default_value for parameter in parameter_metadata_dict.values() - ] - return serialize_parameters(parameter_metadata_dict, default_value_parameter_array) - - def _get_overlapping_parameters( parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_bytes: bytes ) -> Dict[int, Any]: @@ -169,9 +112,7 @@ def _get_missing_parameters( enum_type = _get_enum_type(value) missing_parameters[key] = enum_type(0) else: - missing_parameters[key] = serialization_strategy.get_type_default( - value.type, value.repeated - ) + missing_parameters[key] = get_type_default(value.type, value.repeated) return missing_parameters @@ -203,16 +144,6 @@ def _deserialize_enum_parameters( parameter_by_id[id] = enum_type(value) -def _get_enum_value(parameter: Any, repeated: bool) -> Any: - if repeated: - if len(parameter) > 0 and isinstance(parameter[0], Enum): - return [x.value for x in parameter] - else: - if isinstance(parameter, Enum): - return parameter.value - return parameter - - def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: if parameter_metadata.repeated and len(parameter_metadata.default_value) > 0: return type(parameter_metadata.default_value[0]) diff --git a/tests/unit/test_serialization_strategy.py b/tests/unit/test_serialization_strategy.py index 41b2c4a81..e3cea86af 100644 --- a/tests/unit/test_serialization_strategy.py +++ b/tests/unit/test_serialization_strategy.py @@ -3,32 +3,13 @@ import pytest from google.protobuf import type_pb2 -from ni_measurement_plugin_sdk_service._internal.parameter import serialization_strategy -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 - - -@pytest.mark.parametrize( - "type,is_repeated,expected_encoder", - [ - (type_pb2.Field.TYPE_FLOAT, False, serialization_strategy.FloatEncoder), - (type_pb2.Field.TYPE_DOUBLE, False, serialization_strategy.DoubleEncoder), - (type_pb2.Field.TYPE_INT32, False, serialization_strategy.IntEncoder), - (type_pb2.Field.TYPE_INT64, False, serialization_strategy.IntEncoder), - (type_pb2.Field.TYPE_UINT32, False, serialization_strategy.UIntEncoder), - (type_pb2.Field.TYPE_UINT64, False, serialization_strategy.UIntEncoder), - (type_pb2.Field.TYPE_BOOL, False, serialization_strategy.BoolEncoder), - (type_pb2.Field.TYPE_STRING, False, serialization_strategy.StringEncoder), - (type_pb2.Field.TYPE_ENUM, False, serialization_strategy.IntEncoder), - (type_pb2.Field.TYPE_MESSAGE, False, serialization_strategy.MessageEncoder), - (type_pb2.Field.TYPE_MESSAGE, True, serialization_strategy.MessageArrayEncoder), - ], +from ni_measurement_plugin_sdk_service._internal.parameter import ( + serialization_strategy, + message_serializer, +) +from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( + xydata_pb2, ) -def test___serialization_strategy___get_encoder___returns_expected_encoder( - type, is_repeated, expected_encoder -): - encoder = serialization_strategy.get_encoder(type, is_repeated) - - assert encoder == expected_encoder @pytest.mark.parametrize( @@ -84,6 +65,6 @@ def test___serialization_strategy___get_decoder___returns_expected_decoder( def test___serialization_strategy___get_default_value___returns_type_defaults( type, is_repeated, expected_default_value ): - default_value = serialization_strategy.get_type_default(type, is_repeated) + default_value = message_serializer.get_type_default(type, is_repeated) assert default_value == expected_default_value diff --git a/tests/unit/test_serializer.py b/tests/unit/test_serializer.py index bb3655939..227fa47fa 100644 --- a/tests/unit/test_serializer.py +++ b/tests/unit/test_serializer.py @@ -10,7 +10,7 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter import serializer +from ni_measurement_plugin_sdk_service._internal.parameter import serializer, message_serializer from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, TypeSpecialization, @@ -110,7 +110,7 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu parameter = _get_test_parameter_by_id(default_values) # Custom Serialization - custom_serialized_bytes = serializer.serialize_parameters(parameter, test_values) + custom_serialized_bytes = message_serializer.serialize_parameters(parameter, test_values) _validate_serialized_bytes(custom_serialized_bytes, test_values) @@ -172,7 +172,7 @@ def test___serializer___serialize_default_parameter___successful_serialization(d parameter = _get_test_parameter_by_id(default_values) # Custom Serialization - custom_serialized_bytes = serializer.serialize_default_values(parameter) + custom_serialized_bytes = message_serializer.serialize_default_values(parameter) _validate_serialized_bytes(custom_serialized_bytes, default_values) @@ -278,7 +278,7 @@ def test___big_message___serialize_parameters___returns_serialized_data() -> Non values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] expected_message = _get_big_message(values) - serialized_data = serializer.serialize_parameters(parameter_metadata_by_id, values) + serialized_data = message_serializer.serialize_parameters(parameter_metadata_by_id, values) message = BigMessage.FromString(serialized_data) assert message.ListFields() == pytest.approx(expected_message.ListFields()) From dd8e33a2717fd6b1bc711652b6936494c940c818 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 24 Jun 2024 09:53:33 -0500 Subject: [PATCH 07/25] Replaced test_message_serialzier with test_serializer --- tests/unit/test_message_serializer.py | 106 -------------------------- 1 file changed, 106 deletions(-) delete mode 100644 tests/unit/test_message_serializer.py diff --git a/tests/unit/test_message_serializer.py b/tests/unit/test_message_serializer.py deleted file mode 100644 index f111ed762..000000000 --- a/tests/unit/test_message_serializer.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest - -from ni_measurement_plugin_sdk_service._internal.parameter.message_serializer import serialize_parameters as new_serializer -from tests.unit.test_serializer import _get_test_parameter_by_id as currentParameter -from ni_measurement_plugin_sdk_service._internal.parameter import serializer - -from enum import Enum, IntEnum -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 - -class DifferentColor(Enum): - """Non-primary colors used for testing enum-typed config and output.""" - - PURPLE = 0 - ORANGE = 1 - TEAL = 2 - BROWN = 3 - - -class Countries(IntEnum): - """Countries enum used for testing enum-typed config and output.""" - - AMERICA = 0 - TAIWAN = 1 - AUSTRALIA = 2 - CANADA = 3 - -double_xy_data = xydata_pb2.DoubleXYData() -double_xy_data.x_data.append(4) -double_xy_data.y_data.append(6) - -double_xy_data2 = xydata_pb2.DoubleXYData() -double_xy_data2.x_data.append(8) -double_xy_data2.y_data.append(10) - -double_xy_data_array = [double_xy_data, double_xy_data2] - -@pytest.mark.parametrize( - "test_values", - [ - [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array - ], - [ - -0.9999, - -0.9999, - -13, - 1, - 1000, - 2, - True, - "////", - [5.5, -13.3, 1, 0.0, -99.9999], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ] - ] -) - -def test___serializer___serialize_parameter___successful_serialization(test_values): - parameter = currentParameter(test_values) - - new_serialize = new_serializer( - parameter_metadata_dict=parameter, - parameter_values=test_values,) - current_serialize = serializer.serialize_parameters( - parameter_metadata_dict=parameter, - parameter_values=test_values) - - print() - print(f"Current Serializer: {current_serialize}") - print() - print(f"Message Serializer: {new_serialize}") - - assert new_serialize == current_serialize \ No newline at end of file From a936584e5756c4aa010c0d8e0ccdb6b733ce0036 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 24 Jun 2024 10:37:06 -0500 Subject: [PATCH 08/25] Fixed 'Mypy statis analysis' in message_serializer --- .../_internal/parameter/message_serializer.py | 93 +++++++++++++++---- 1 file changed, 75 insertions(+), 18 deletions(-) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py b/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py index 7036df5e2..4e4c65045 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py @@ -17,7 +17,16 @@ def serialize_parameters( parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_values: Sequence[Any], ) -> bytes: - """Test doc string?""" + """Serialize the parameter values in same order based on the metadata_dict. + + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. + + parameter_value (Sequence[Any]): Parameter values to serialize. + + Returns: + bytes: Serialized byte string containing parameter values. + """ pool = descriptor_pool.Default() file_descriptor_proto = descriptor_pb2.FileDescriptorProto() original_guid = uuid4() @@ -36,8 +45,8 @@ def serialize_parameters( # Define fields field_descriptor = _define_fields( message_proto=message_proto, - parameter_metadata=parameter_metadata, - i=i, + metadata=parameter_metadata, + index=i, param=parameter, is_python_enum=is_python_enum, ) @@ -89,7 +98,19 @@ def serialize_default_values(parameter_metadata_dict: Dict[int, ParameterMetadat return serialize_parameters(parameter_metadata_dict, default_value_parameter_array) -def _equal_to_default_value(metadata, param, is_python_enum): +def _equal_to_default_value(metadata: ParameterMetadata, param: Any, is_python_enum: bool) -> bool: + """Determine if 'param' is equal to it's default value. + + Args: + metadata (ParameterMetadata): Metadata of 'param'. + + param (Any): A value/parameter of parameter_values. + + is_python_enum (boolean): True if 'param' is a enum from python's libraries. + + Returns: + boolean: True if 'param' is equal it's default value or is None. + """ default_value = get_type_default(metadata.type, metadata.repeated) # gets value from a regular python enum if is_python_enum: @@ -97,13 +118,21 @@ def _equal_to_default_value(metadata, param, is_python_enum): param = param[0].value else: param = param.value - # return true if param is None or eqaul to default value if param == default_value or param is None: return True return False -def _get_enum_values(param): +def _get_enum_values(param: Any) -> Any: + """Get's value of an enum. + + Args: + param (Any): A value/parameter of parameter_values. + + Returns: + Any: An enum value or a list of enums or the 'param'. + + """ if param == []: return param # if param is a list of enums, return values of them in a list @@ -116,7 +145,20 @@ def _get_enum_values(param): return param -def _define_enums(file_descriptor, param, field_descriptor): +def _define_enums( + file_descriptor: descriptor_pb2.FileDescriptorProto, + param: Any, + field_descriptor: FieldDescriptorProto, +) -> None: + """Implement a enum class in 'file_descriptor'. + + Args: + file_descriptor (FileDescriptorProto): Descriptor of a proto file. + + param (Any): A value/parameter of parameter_values. + + field_descriptor (FieldDescriptorProto): Descriptor of a field. + """ # if param is a list, then it sets param to 1st element in list if isinstance(param, list): param = param[0] @@ -136,30 +178,45 @@ def _define_enums(file_descriptor, param, field_descriptor): # checks enum if it's protobuf or python try: field_descriptor.type_name = ProtobufColor.DESCRIPTOR.full_name - # TODO: Add error type to except thing except TypeError: field_descriptor.type_name = param.__class__.__name__ -def _define_fields(message_proto, parameter_metadata, i, param, is_python_enum): +def _define_fields( + message_proto: Any, metadata: ParameterMetadata, index: int, param: Any, is_python_enum: bool +) -> Any: + """Implement a field in 'message_proto'. + + Args: + message_proto (message_type): A message instance in 'file_descriptor_proto'. + + metadata (ParameterMetadata): Metadata of 'param'. + + index (int): 'param' index in parameter_values + + param (Any): A value/parameter of parameter_values. + + is_python_enum (boolean): True if 'param' is a enum from python's libraries. + + Returns: + Any: field_descriptor of 'param' or None if 'param' is not equal_to_default_value. + """ # exits if param is None or eqaul to default value - if not _equal_to_default_value( - metadata=parameter_metadata, param=param, is_python_enum=is_python_enum - ): + if not _equal_to_default_value(metadata=metadata, param=param, is_python_enum=is_python_enum): field_descriptor = message_proto.field.add() - field_descriptor.number = i - field_descriptor.name = f"field_{i}" - field_descriptor.type = parameter_metadata.type + field_descriptor.number = index + field_descriptor.name = f"field_{index}" + field_descriptor.type = metadata.type # if a value is an array then it's labled as repeated and packed - if parameter_metadata.repeated: + if metadata.repeated: field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED field_descriptor.options.packed = True else: field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL # if a value is a message then assign type name to it's full name - if parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE: - field_descriptor.type_name = parameter_metadata.message_type + if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: + field_descriptor.type_name = metadata.message_type return field_descriptor From e24da77554d61286f52eb3c0e3f8986215e0c8dc Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 24 Jun 2024 15:39:14 -0500 Subject: [PATCH 09/25] Changed file names corresponding to it's functionality --- .../_internal/grpc_servicer.py | 19 +++++----- .../parameter/{serializer.py => decoder.py} | 11 +++--- ...zation_strategy.py => decoder_strategy.py} | 18 ++++++++++ .../{message_serializer.py => encoder.py} | 36 ++++++------------- .../_internal/parameter/metadata.py | 5 +-- ...n_strategy.py => test_decoder_strategy.py} | 29 ++++++++------- tests/unit/test_serializer.py | 14 ++++---- 7 files changed, 63 insertions(+), 69 deletions(-) rename ni_measurement_plugin_sdk_service/_internal/parameter/{serializer.py => decoder.py} (95%) rename ni_measurement_plugin_sdk_service/_internal/parameter/{serialization_strategy.py => decoder_strategy.py} (90%) rename ni_measurement_plugin_sdk_service/_internal/parameter/{message_serializer.py => encoder.py} (88%) rename tests/unit/{test_serialization_strategy.py => test_decoder_strategy.py} (60%) diff --git a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py index 34dd34614..f84d9a361 100644 --- a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py +++ b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py @@ -13,10 +13,7 @@ import grpc from google.protobuf import any_pb2 -from ni_measurement_plugin_sdk_service._internal.parameter import ( - message_serializer, - serializer, -) +from ni_measurement_plugin_sdk_service._internal.parameter import decoder, encoder from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) @@ -138,7 +135,7 @@ def _get_mapping_by_parameter_name( def _serialize_outputs(output_metadata: Dict[int, ParameterMetadata], outputs: Any) -> any_pb2.Any: if isinstance(outputs, collections.abc.Sequence): - return any_pb2.Any(value=message_serializer.serialize_parameters(output_metadata, outputs)) + return any_pb2.Any(value=encoder.serialize_parameters(output_metadata, outputs)) elif outputs is None: raise ValueError(f"Measurement function returned None") else: @@ -198,8 +195,8 @@ def GetMetadata( # noqa: N802 - function name should be lowercase ) measurement_signature.configuration_parameters.append(configuration_parameter) - measurement_signature.configuration_defaults.value = ( - message_serializer.serialize_default_values(self._configuration_metadata) + measurement_signature.configuration_defaults.value = encoder.serialize_default_values( + self._configuration_metadata ) for field_number, output_metadata in self._output_metadata.items(): @@ -229,7 +226,7 @@ def Measure( # noqa: N802 - function name should be lowercase self, request: v1_measurement_service_pb2.MeasureRequest, context: grpc.ServicerContext ) -> v1_measurement_service_pb2.MeasureResponse: """RPC API that executes the registered measurement method.""" - mapping_by_id = serializer.deserialize_parameters( + mapping_by_id = decoder.deserialize_parameters( self._configuration_metadata, request.configuration_parameters.value ) mapping_by_variable_name = _get_mapping_by_parameter_name( @@ -306,8 +303,8 @@ def GetMetadata( # noqa: N802 - function name should be lowercase ) measurement_signature.configuration_parameters.append(configuration_parameter) - measurement_signature.configuration_defaults.value = ( - message_serializer.serialize_default_values(self._configuration_metadata) + measurement_signature.configuration_defaults.value = encoder.serialize_default_values( + self._configuration_metadata ) for field_number, output_metadata in self._output_metadata.items(): @@ -339,7 +336,7 @@ def Measure( # noqa: N802 - function name should be lowercase self, request: v2_measurement_service_pb2.MeasureRequest, context: grpc.ServicerContext ) -> Generator[v2_measurement_service_pb2.MeasureResponse, None, None]: """RPC API that executes the registered measurement method.""" - mapping_by_id = serializer.deserialize_parameters( + mapping_by_id = decoder.deserialize_parameters( self._configuration_metadata, request.configuration_parameters.value ) mapping_by_variable_name = _get_mapping_by_parameter_name( diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py similarity index 95% rename from ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py rename to ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py index 87b537b06..c38d4d7ca 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py @@ -8,10 +8,7 @@ ) from google.protobuf.message import Message -from ni_measurement_plugin_sdk_service._internal.parameter import serialization_strategy -from ni_measurement_plugin_sdk_service._internal.parameter.message_serializer import ( - get_type_default, -) +from ni_measurement_plugin_sdk_service._internal.parameter import decoder_strategy from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, get_enum_values_annotation, @@ -77,7 +74,7 @@ def _get_overlapping_parameters( f"Error occurred while reading the parameter - given protobuf index '{field_index}' is invalid." ) field_metadata = parameter_metadata_dict[field_index] - decoder = serialization_strategy.get_decoder( + decoder = decoder_strategy.get_decoder( field_metadata.type, field_metadata.repeated, field_metadata.message_type ) inner_decoder = decoder(field_index, cast(FieldDescriptor, field_index)) @@ -112,7 +109,9 @@ def _get_missing_parameters( enum_type = _get_enum_type(value) missing_parameters[key] = enum_type(0) else: - missing_parameters[key] = get_type_default(value.type, value.repeated) + missing_parameters[key] = decoder_strategy.get_type_default( + value.type, value.repeated + ) return missing_parameters diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py similarity index 90% rename from ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py rename to ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py index d534c9109..b4923d4bb 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py @@ -146,3 +146,21 @@ def get_decoder( return decoder else: raise ValueError(f"Error can not decode type '{type}'") + + +def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: + """Get the default value for the give type.""" + _type_default_mapping = { + type_pb2.Field.TYPE_FLOAT: float(), + type_pb2.Field.TYPE_DOUBLE: float(), + type_pb2.Field.TYPE_INT32: int(), + type_pb2.Field.TYPE_INT64: int(), + type_pb2.Field.TYPE_UINT32: int(), + type_pb2.Field.TYPE_UINT64: int(), + type_pb2.Field.TYPE_BOOL: bool(), + type_pb2.Field.TYPE_STRING: str(), + type_pb2.Field.TYPE_ENUM: int(), + } + if repeated: + return list() + return _type_default_mapping.get(type) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py similarity index 88% rename from ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py rename to ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py index 4e4c65045..7ad676d52 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/message_serializer.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -3,10 +3,10 @@ from typing import Any, Dict, Sequence from uuid import uuid4 -from google.protobuf import descriptor_pb2, descriptor_pool, message_factory, type_pb2 +from google.protobuf import descriptor_pb2, descriptor_pool, message_factory from google.protobuf.descriptor_pb2 import FieldDescriptorProto -# metadata +from ni_measurement_plugin_sdk_service._internal.parameter.decoder_strategy import get_type_default from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) @@ -29,14 +29,14 @@ def serialize_parameters( """ pool = descriptor_pool.Default() file_descriptor_proto = descriptor_pb2.FileDescriptorProto() - original_guid = uuid4() - new_guid = "msg" + "".join(filter(str.isalnum, str(original_guid)))[:16] - file_descriptor_proto.name = str(new_guid) - file_descriptor_proto.package = str(new_guid) + original_guid = str(uuid4()) + unique_descriptor_name = "msg" + "".join(filter(str.isalnum, original_guid))[:16] + file_descriptor_proto.name = str(unique_descriptor_name) + file_descriptor_proto.package = str(unique_descriptor_name) # Create a DescriptorProto for the message message_proto = file_descriptor_proto.message_type.add() - message_proto.name = str(new_guid) + message_proto.name = str(unique_descriptor_name) # Initialize the message with fields defined for i, parameter in enumerate(parameter_values, start=1): @@ -63,7 +63,9 @@ def serialize_parameters( # Get message and add fields to it pool.Add(file_descriptor_proto) - message_descriptor = pool.FindMessageTypeByName(str(new_guid) + "." + str(new_guid)) + message_descriptor = pool.FindMessageTypeByName( + f"{unique_descriptor_name}.{unique_descriptor_name}" + ) message_instance = message_factory.GetMessageClass(message_descriptor)() # assign values to fields @@ -218,21 +220,3 @@ def _define_fields( if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: field_descriptor.type_name = metadata.message_type return field_descriptor - - -def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: - """Get the default value for the give type.""" - _type_default_mapping = { - type_pb2.Field.TYPE_FLOAT: float(), - type_pb2.Field.TYPE_DOUBLE: float(), - type_pb2.Field.TYPE_INT32: int(), - type_pb2.Field.TYPE_INT64: int(), - type_pb2.Field.TYPE_UINT32: int(), - type_pb2.Field.TYPE_UINT64: int(), - type_pb2.Field.TYPE_BOOL: bool(), - type_pb2.Field.TYPE_STRING: str(), - type_pb2.Field.TYPE_ENUM: int(), - } - if repeated: - return list() - return _type_default_mapping.get(type) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index 5884313b9..64ee5ebe4 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -10,6 +10,7 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) +from ni_measurement_plugin_sdk_service._internal.parameter.decoder_strategy import get_type_default from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization @@ -49,10 +50,6 @@ def validate_default_value_type(parameter_metadata: ParameterMetadata) -> None: Raises: TypeError: If default value does not match the Datatype. """ - from ni_measurement_plugin_sdk_service._internal.parameter.message_serializer import ( - get_type_default, - ) - default_value = parameter_metadata.default_value if default_value is None: return None diff --git a/tests/unit/test_serialization_strategy.py b/tests/unit/test_decoder_strategy.py similarity index 60% rename from tests/unit/test_serialization_strategy.py rename to tests/unit/test_decoder_strategy.py index e3cea86af..f1d92a277 100644 --- a/tests/unit/test_serialization_strategy.py +++ b/tests/unit/test_decoder_strategy.py @@ -4,8 +4,7 @@ from google.protobuf import type_pb2 from ni_measurement_plugin_sdk_service._internal.parameter import ( - serialization_strategy, - message_serializer, + decoder_strategy, ) from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( xydata_pb2, @@ -15,33 +14,33 @@ @pytest.mark.parametrize( "type,is_repeated,message_type,expected_decoder", [ - (type_pb2.Field.TYPE_FLOAT, False, "", serialization_strategy.FloatDecoder), - (type_pb2.Field.TYPE_DOUBLE, False, "", serialization_strategy.DoubleDecoder), - (type_pb2.Field.TYPE_INT32, False, "", serialization_strategy.Int32Decoder), - (type_pb2.Field.TYPE_INT64, False, "", serialization_strategy.Int64Decoder), - (type_pb2.Field.TYPE_UINT32, False, "", serialization_strategy.UInt32Decoder), - (type_pb2.Field.TYPE_UINT64, False, "", serialization_strategy.UInt64Decoder), - (type_pb2.Field.TYPE_BOOL, False, "", serialization_strategy.BoolDecoder), - (type_pb2.Field.TYPE_STRING, False, "", serialization_strategy.StringDecoder), - (type_pb2.Field.TYPE_ENUM, False, "", serialization_strategy.Int32Decoder), + (type_pb2.Field.TYPE_FLOAT, False, "", decoder_strategy.FloatDecoder), + (type_pb2.Field.TYPE_DOUBLE, False, "", decoder_strategy.DoubleDecoder), + (type_pb2.Field.TYPE_INT32, False, "", decoder_strategy.Int32Decoder), + (type_pb2.Field.TYPE_INT64, False, "", decoder_strategy.Int64Decoder), + (type_pb2.Field.TYPE_UINT32, False, "", decoder_strategy.UInt32Decoder), + (type_pb2.Field.TYPE_UINT64, False, "", decoder_strategy.UInt64Decoder), + (type_pb2.Field.TYPE_BOOL, False, "", decoder_strategy.BoolDecoder), + (type_pb2.Field.TYPE_STRING, False, "", decoder_strategy.StringDecoder), + (type_pb2.Field.TYPE_ENUM, False, "", decoder_strategy.Int32Decoder), ( type_pb2.Field.TYPE_MESSAGE, False, xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, - serialization_strategy.XYDataDecoder, + decoder_strategy.XYDataDecoder, ), ( type_pb2.Field.TYPE_MESSAGE, True, xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, - serialization_strategy.XYDataArrayDecoder, + decoder_strategy.XYDataArrayDecoder, ), ], ) def test___serialization_strategy___get_decoder___returns_expected_decoder( type, is_repeated, message_type, expected_decoder ): - decoder = serialization_strategy.get_decoder(type, is_repeated, message_type) + decoder = decoder_strategy.get_decoder(type, is_repeated, message_type) assert decoder == expected_decoder @@ -65,6 +64,6 @@ def test___serialization_strategy___get_decoder___returns_expected_decoder( def test___serialization_strategy___get_default_value___returns_type_defaults( type, is_repeated, expected_default_value ): - default_value = message_serializer.get_type_default(type, is_repeated) + default_value = decoder_strategy.get_type_default(type, is_repeated) assert default_value == expected_default_value diff --git a/tests/unit/test_serializer.py b/tests/unit/test_serializer.py index 227fa47fa..46e49c7b6 100644 --- a/tests/unit/test_serializer.py +++ b/tests/unit/test_serializer.py @@ -10,7 +10,7 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter import serializer, message_serializer +from ni_measurement_plugin_sdk_service._internal.parameter import decoder, encoder from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, TypeSpecialization, @@ -110,7 +110,7 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu parameter = _get_test_parameter_by_id(default_values) # Custom Serialization - custom_serialized_bytes = message_serializer.serialize_parameters(parameter, test_values) + custom_serialized_bytes = encoder.serialize_parameters(parameter, test_values) _validate_serialized_bytes(custom_serialized_bytes, test_values) @@ -172,7 +172,7 @@ def test___serializer___serialize_default_parameter___successful_serialization(d parameter = _get_test_parameter_by_id(default_values) # Custom Serialization - custom_serialized_bytes = message_serializer.serialize_default_values(parameter) + custom_serialized_bytes = encoder.serialize_default_values(parameter) _validate_serialized_bytes(custom_serialized_bytes, default_values) @@ -210,7 +210,7 @@ def test___serializer___deserialize_parameter___successful_deserialization(value parameter = _get_test_parameter_by_id(values) grpc_serialized_data = _get_grpc_serialized_data(values) - parameter_value_by_id = serializer.deserialize_parameters(parameter, grpc_serialized_data) + parameter_value_by_id = decoder.deserialize_parameters(parameter, grpc_serialized_data) assert list(parameter_value_by_id.values()) == values @@ -243,7 +243,7 @@ def test___empty_buffer___deserialize_parameters___returns_zero_or_empty(): double_xy_data_array, ] parameter = _get_test_parameter_by_id(nonzero_defaults) - parameter_value_by_id = serializer.deserialize_parameters(parameter, bytes()) + parameter_value_by_id = decoder.deserialize_parameters(parameter, bytes()) for key, value in parameter_value_by_id.items(): parameter_metadata = parameter[key] @@ -266,7 +266,7 @@ def test___big_message___deserialize_parameters___returns_parameter_value_by_id( serialized_data = message.SerializeToString() expected_parameter_value_by_id = {i + 1: value for (i, value) in enumerate(values)} - parameter_value_by_id = serializer.deserialize_parameters( + parameter_value_by_id = decoder.deserialize_parameters( parameter_metadata_by_id, serialized_data ) @@ -278,7 +278,7 @@ def test___big_message___serialize_parameters___returns_serialized_data() -> Non values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] expected_message = _get_big_message(values) - serialized_data = message_serializer.serialize_parameters(parameter_metadata_by_id, values) + serialized_data = encoder.serialize_parameters(parameter_metadata_by_id, values) message = BigMessage.FromString(serialized_data) assert message.ListFields() == pytest.approx(expected_message.ListFields()) From 371e90d75aba7e189974e566e797867e48e393a8 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 24 Jun 2024 15:48:05 -0500 Subject: [PATCH 10/25] Fixed naming issue --- .../_internal/parameter/encoder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py index 7ad676d52..c7377abca 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -30,13 +30,13 @@ def serialize_parameters( pool = descriptor_pool.Default() file_descriptor_proto = descriptor_pb2.FileDescriptorProto() original_guid = str(uuid4()) - unique_descriptor_name = "msg" + "".join(filter(str.isalnum, original_guid))[:16] - file_descriptor_proto.name = str(unique_descriptor_name) - file_descriptor_proto.package = str(unique_descriptor_name) + unique_descriptor_name = str("msg" + "".join(filter(str.isalnum, original_guid))[:16]) + file_descriptor_proto.name = unique_descriptor_name + file_descriptor_proto.package = unique_descriptor_name # Create a DescriptorProto for the message message_proto = file_descriptor_proto.message_type.add() - message_proto.name = str(unique_descriptor_name) + message_proto.name = unique_descriptor_name # Initialize the message with fields defined for i, parameter in enumerate(parameter_values, start=1): From 96f07ce4f8346d19c68bb01158846d7c64c19928 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 24 Jun 2024 15:55:34 -0500 Subject: [PATCH 11/25] Changed 'test_serializer' to 'test_decoder' --- tests/unit/{test_serializer.py => test_decoder.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/unit/{test_serializer.py => test_decoder.py} (100%) diff --git a/tests/unit/test_serializer.py b/tests/unit/test_decoder.py similarity index 100% rename from tests/unit/test_serializer.py rename to tests/unit/test_decoder.py From ce4e3e129599c809c5e6ed8fe1b5cd1ed070a468 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 1 Jul 2024 11:37:28 -0500 Subject: [PATCH 12/25] Implemented encoder to reuse message types, renamed and reorder files --- .../_internal/grpc_servicer.py | 25 +- ..._serializer_types.py => _decoder_types.py} | 3 - .../_internal/parameter/_message.py | 2 +- .../_internal/parameter/decoder_strategy.py | 26 +- .../_internal/parameter/encoder.py | 178 ++++++------- .../_internal/service_manager.py | 11 +- tests/unit/test_decoder.py | 144 +---------- tests/unit/test_decoder_strategy.py | 4 +- tests/unit/test_encoder.py | 241 ++++++++++++++++++ 9 files changed, 365 insertions(+), 269 deletions(-) rename ni_measurement_plugin_sdk_service/_internal/parameter/{_serializer_types.py => _decoder_types.py} (78%) create mode 100644 tests/unit/test_encoder.py diff --git a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py index f84d9a361..8a282b404 100644 --- a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py +++ b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py @@ -25,7 +25,10 @@ measurement_service_pb2 as v2_measurement_service_pb2, measurement_service_pb2_grpc as v2_measurement_service_pb2_grpc, ) -from ni_measurement_plugin_sdk_service.measurement.info import MeasurementInfo +from ni_measurement_plugin_sdk_service.measurement.info import ( + MeasurementInfo, + ServiceInfo, +) from ni_measurement_plugin_sdk_service.session_management import PinMapContext @@ -133,9 +136,13 @@ def _get_mapping_by_parameter_name( return mapping_by_variable_name -def _serialize_outputs(output_metadata: Dict[int, ParameterMetadata], outputs: Any) -> any_pb2.Any: +def _serialize_outputs( + output_metadata: Dict[int, ParameterMetadata], outputs: Any, service_info: ServiceInfo +) -> any_pb2.Any: if isinstance(outputs, collections.abc.Sequence): - return any_pb2.Any(value=encoder.serialize_parameters(output_metadata, outputs)) + return any_pb2.Any( + value=encoder.serialize_parameters(output_metadata, outputs, service_info) + ) elif outputs is None: raise ValueError(f"Measurement function returned None") else: @@ -163,6 +170,7 @@ def __init__( output_parameter_list: List[ParameterMetadata], measure_function: Callable, owner: object, + service_info: ServiceInfo, ) -> None: """Initialize the measurement v1 servicer.""" super().__init__() @@ -171,6 +179,7 @@ def __init__( self._measurement_info = measurement_info self._measure_function = measure_function self._owner = weakref.ref(owner) if owner is not None else None # avoid reference cycle + self._service_info = service_info def GetMetadata( # noqa: N802 - function name should be lowercase self, request: v1_measurement_service_pb2.GetMetadataRequest, context: grpc.ServicerContext @@ -196,7 +205,7 @@ def GetMetadata( # noqa: N802 - function name should be lowercase measurement_signature.configuration_parameters.append(configuration_parameter) measurement_signature.configuration_defaults.value = encoder.serialize_default_values( - self._configuration_metadata + self._configuration_metadata, self._service_info ) for field_number, output_metadata in self._output_metadata.items(): @@ -256,7 +265,7 @@ def Measure( # noqa: N802 - function name should be lowercase def _serialize_response(self, outputs: Any) -> v1_measurement_service_pb2.MeasureResponse: return v1_measurement_service_pb2.MeasureResponse( - outputs=_serialize_outputs(self._output_metadata, outputs) + outputs=_serialize_outputs(self._output_metadata, outputs, self._service_info) ) @@ -270,6 +279,7 @@ def __init__( output_parameter_list: List[ParameterMetadata], measure_function: Callable, owner: object, + service_info: ServiceInfo, ) -> None: """Initialize the measurement v2 servicer.""" super().__init__() @@ -278,6 +288,7 @@ def __init__( self._measurement_info = measurement_info self._measure_function = measure_function self._owner = weakref.ref(owner) if owner is not None else None # avoid reference cycle + self._service_info = service_info def GetMetadata( # noqa: N802 - function name should be lowercase self, request: v2_measurement_service_pb2.GetMetadataRequest, context: grpc.ServicerContext @@ -304,7 +315,7 @@ def GetMetadata( # noqa: N802 - function name should be lowercase measurement_signature.configuration_parameters.append(configuration_parameter) measurement_signature.configuration_defaults.value = encoder.serialize_default_values( - self._configuration_metadata + self._configuration_metadata, self._service_info ) for field_number, output_metadata in self._output_metadata.items(): @@ -365,5 +376,5 @@ def Measure( # noqa: N802 - function name should be lowercase def _serialize_response(self, outputs: Any) -> v2_measurement_service_pb2.MeasureResponse: return v2_measurement_service_pb2.MeasureResponse( - outputs=_serialize_outputs(self._output_metadata, outputs) + outputs=_serialize_outputs(self._output_metadata, outputs, self._service_info) ) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py b/ni_measurement_plugin_sdk_service/_internal/parameter/_decoder_types.py similarity index 78% rename from ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py rename to ni_measurement_plugin_sdk_service/_internal/parameter/_decoder_types.py index a34d66352..e3dcd1d87 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/_decoder_types.py @@ -16,9 +16,6 @@ Key: TypeAlias = FieldDescriptor WriteFunction: TypeAlias = Callable[[bytes], int] -Encoder: TypeAlias = Callable[[WriteFunction, bytes, bool], int] -PartialEncoderConstructor: TypeAlias = Callable[[int], Encoder] -EncoderConstructor: TypeAlias = Callable[[int, bool, bool], Encoder] Decoder: TypeAlias = Callable[[memoryview, int, int, Message, Dict[Key, Any]], int] PartialDecoderConstructor: TypeAlias = Callable[[int, Key], Decoder] diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py b/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py index dcc04b146..83e2363e0 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py @@ -15,7 +15,7 @@ from google.protobuf.internal import encoder, wire_format from google.protobuf.message import Message -from ni_measurement_plugin_sdk_service._internal.parameter._serializer_types import ( +from ni_measurement_plugin_sdk_service._internal.parameter._decoder_types import ( Decoder, Key, NewDefault, diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py index b4923d4bb..f7babb159 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py @@ -9,7 +9,7 @@ from google.protobuf.message import Message from ni_measurement_plugin_sdk_service._internal.parameter import _message -from ni_measurement_plugin_sdk_service._internal.parameter._serializer_types import ( +from ni_measurement_plugin_sdk_service._internal.parameter._decoder_types import ( Decoder, DecoderConstructor, Key, @@ -148,19 +148,21 @@ def get_decoder( raise ValueError(f"Error can not decode type '{type}'") +_type_default_mapping = { + type_pb2.Field.TYPE_FLOAT: float(), + type_pb2.Field.TYPE_DOUBLE: float(), + type_pb2.Field.TYPE_INT32: int(), + type_pb2.Field.TYPE_INT64: int(), + type_pb2.Field.TYPE_UINT32: int(), + type_pb2.Field.TYPE_UINT64: int(), + type_pb2.Field.TYPE_BOOL: bool(), + type_pb2.Field.TYPE_STRING: str(), + type_pb2.Field.TYPE_ENUM: int(), +} + + def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: """Get the default value for the give type.""" - _type_default_mapping = { - type_pb2.Field.TYPE_FLOAT: float(), - type_pb2.Field.TYPE_DOUBLE: float(), - type_pb2.Field.TYPE_INT32: int(), - type_pb2.Field.TYPE_INT64: int(), - type_pb2.Field.TYPE_UINT32: int(), - type_pb2.Field.TYPE_UINT64: int(), - type_pb2.Field.TYPE_BOOL: bool(), - type_pb2.Field.TYPE_STRING: str(), - type_pb2.Field.TYPE_ENUM: int(), - } if repeated: return list() return _type_default_mapping.get(type) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py index c7377abca..2ddcfd4e2 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -1,21 +1,24 @@ # enums and default values from enum import Enum from typing import Any, Dict, Sequence -from uuid import uuid4 from google.protobuf import descriptor_pb2, descriptor_pool, message_factory from google.protobuf.descriptor_pb2 import FieldDescriptorProto -from ni_measurement_plugin_sdk_service._internal.parameter.decoder_strategy import get_type_default +from ni_measurement_plugin_sdk_service._internal.parameter.decoder_strategy import ( + get_type_default, +) from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) +from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo from tests.utilities.stubs.loopback.types_pb2 import ProtobufColor def serialize_parameters( parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_values: Sequence[Any], + service_info: ServiceInfo, ) -> bytes: """Serialize the parameter values in same order based on the metadata_dict. @@ -28,101 +31,90 @@ def serialize_parameters( bytes: Serialized byte string containing parameter values. """ pool = descriptor_pool.Default() - file_descriptor_proto = descriptor_pb2.FileDescriptorProto() - original_guid = str(uuid4()) - unique_descriptor_name = str("msg" + "".join(filter(str.isalnum, original_guid))[:16]) - file_descriptor_proto.name = unique_descriptor_name - file_descriptor_proto.package = unique_descriptor_name - - # Create a DescriptorProto for the message - message_proto = file_descriptor_proto.message_type.add() - message_proto.name = unique_descriptor_name - - # Initialize the message with fields defined - for i, parameter in enumerate(parameter_values, start=1): - parameter_metadata = parameter_metadata_dict[i] - is_python_enum = isinstance(parameter, Enum) - # Define fields - field_descriptor = _define_fields( - message_proto=message_proto, - metadata=parameter_metadata, - index=i, - param=parameter, - is_python_enum=is_python_enum, + message_name = "".join(char for char in service_info.service_class if char.isalpha()) + # Tries to find a message type in pool with message_name else it creates one + try: + message_proto = pool.FindMessageTypeByName(f"{message_name}.{message_name}") + except KeyError: + message_proto = _create_message_type( + parameter_values, parameter_metadata_dict, message_name, pool ) - # define enums if it's a regular or a protobuf enum and there's a field - if ( - parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM - and field_descriptor is not None - ): - _define_enums( - file_descriptor=file_descriptor_proto, - param=parameter, - field_descriptor=field_descriptor, - ) + message_instance = message_factory.GetMessageClass(message_proto)() - # Get message and add fields to it - pool.Add(file_descriptor_proto) - message_descriptor = pool.FindMessageTypeByName( - f"{unique_descriptor_name}.{unique_descriptor_name}" - ) - message_instance = message_factory.GetMessageClass(message_descriptor)() - - # assign values to fields for i, parameter in enumerate(parameter_values, start=1): field_name = f"field_{i}" parameter_metadata = parameter_metadata_dict[i] parameter = _get_enum_values(param=parameter) - try: + type_default_value = get_type_default(parameter_metadata.type, parameter_metadata.repeated) + + # Doesn't assign default values or None values to fields + if parameter != type_default_value and parameter is not None: if parameter_metadata.repeated: getattr(message_instance, field_name).extend(parameter) elif parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE: getattr(message_instance, field_name).CopyFrom(parameter) else: setattr(message_instance, field_name, parameter) - except AttributeError: - i += 1 # no field: parameter is None or equal to default value return message_instance.SerializeToString() -def serialize_default_values(parameter_metadata_dict: Dict[int, ParameterMetadata]) -> bytes: - """Serialize the Default values in the Metadata. +def _create_message_type( + parameter_values: Sequence[Any], + parameter_metadata_dict: Dict[int, ParameterMetadata], + message_name: str, + pool: descriptor_pool.DescriptorPool, +) -> Any: + """Creates a message descriptor with the fields defined in a file descriptor proto. Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. + parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. + + parameter_value (Sequence[Any]): Parameter values to serialize. + + message_name (str): Service class name. Returns: - bytes: Serialized byte string containing default values. + Any: A message class based on a defined message_descriptor """ - default_value_parameter_array = [ - parameter.default_value for parameter in parameter_metadata_dict.values() - ] - return serialize_parameters(parameter_metadata_dict, default_value_parameter_array) - + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.name = message_name + file_descriptor.package = message_name + message_proto = file_descriptor.message_type.add() + message_proto.name = message_name -def _equal_to_default_value(metadata: ParameterMetadata, param: Any, is_python_enum: bool) -> bool: - """Determine if 'param' is equal to it's default value. + # Initialize the message with fields defined + for i, parameter in enumerate(parameter_values, start=1): + parameter_metadata = parameter_metadata_dict[i] + field_descriptor = _create_fields( + message_proto=message_proto, metadata=parameter_metadata, index=i + ) + if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM: + _create_enum_type( + file_descriptor=file_descriptor, + param=parameter, + field_descriptor=field_descriptor, + ) + pool.Add(file_descriptor) + return pool.FindMessageTypeByName(f"{file_descriptor.package}.{message_proto.name}") - Args: - metadata (ParameterMetadata): Metadata of 'param'. - param (Any): A value/parameter of parameter_values. +def serialize_default_values( + parameter_metadata_dict: Dict[int, ParameterMetadata], service_info: ServiceInfo +) -> bytes: + """Serialize the Default values in the Metadata. - is_python_enum (boolean): True if 'param' is a enum from python's libraries. + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. Returns: - boolean: True if 'param' is equal it's default value or is None. + bytes: Serialized byte string containing default values. """ - default_value = get_type_default(metadata.type, metadata.repeated) - # gets value from a regular python enum - if is_python_enum: - if metadata.repeated: - param = param[0].value - else: - param = param.value - if param == default_value or param is None: - return True - return False + default_value_parameter_array = [ + parameter.default_value for parameter in parameter_metadata_dict.values() + ] + return serialize_parameters( + parameter_metadata_dict, default_value_parameter_array, service_info + ) def _get_enum_values(param: Any) -> Any: @@ -133,13 +125,9 @@ def _get_enum_values(param: Any) -> Any: Returns: Any: An enum value or a list of enums or the 'param'. - """ if param == []: return param - # if param is a list of enums, return values of them in a list - # or param is an enum, returns the value of it - # else it doesn nothing to param if isinstance(param, list) and isinstance(param[0], Enum): return [x.value for x in param] elif isinstance(param, Enum): @@ -147,7 +135,7 @@ def _get_enum_values(param: Any) -> Any: return param -def _define_enums( +def _create_enum_type( file_descriptor: descriptor_pb2.FileDescriptorProto, param: Any, field_descriptor: FieldDescriptorProto, @@ -161,18 +149,15 @@ def _define_enums( field_descriptor (FieldDescriptorProto): Descriptor of a field. """ - # if param is a list, then it sets param to 1st element in list if isinstance(param, list): param = param[0] # if there are no enums/param is a different enum and is a python enum, defines a enum field if param.__class__.__name__ not in [ enum.name for enum in file_descriptor.enum_type ] and isinstance(param, Enum): - # Define a enum class enum_descriptor = file_descriptor.enum_type.add() enum_descriptor.name = param.__class__.__name__ - # Add constants to enum class for name, number in param.__class__.__members__.items(): enum_value_descriptor = enum_descriptor.value.add() enum_value_descriptor.name = name @@ -184,13 +169,11 @@ def _define_enums( field_descriptor.type_name = param.__class__.__name__ -def _define_fields( - message_proto: Any, metadata: ParameterMetadata, index: int, param: Any, is_python_enum: bool -) -> Any: +def _create_fields(message_proto: Any, metadata: ParameterMetadata, index: int) -> Any: """Implement a field in 'message_proto'. Args: - message_proto (message_type): A message instance in 'file_descriptor_proto'. + message_proto (message_type): A message instance in '_FILE_DESCRIPTOR_PROTO'. metadata (ParameterMetadata): Metadata of 'param'. @@ -203,20 +186,17 @@ def _define_fields( Returns: Any: field_descriptor of 'param' or None if 'param' is not equal_to_default_value. """ - # exits if param is None or eqaul to default value - if not _equal_to_default_value(metadata=metadata, param=param, is_python_enum=is_python_enum): - field_descriptor = message_proto.field.add() - - field_descriptor.number = index - field_descriptor.name = f"field_{index}" - field_descriptor.type = metadata.type - # if a value is an array then it's labled as repeated and packed - if metadata.repeated: - field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED - field_descriptor.options.packed = True - else: - field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL - # if a value is a message then assign type name to it's full name - if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: - field_descriptor.type_name = metadata.message_type - return field_descriptor + field_descriptor = message_proto.field.add() + field_descriptor.number = index + field_descriptor.name = f"field_{index}" + field_descriptor.type = metadata.type + + if metadata.repeated: + field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED + field_descriptor.options.packed = True + else: + field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL + + if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: + field_descriptor.type_name = metadata.message_type + return field_descriptor diff --git a/ni_measurement_plugin_sdk_service/_internal/service_manager.py b/ni_measurement_plugin_sdk_service/_internal/service_manager.py index e6e5deebe..4ea6814e1 100644 --- a/ni_measurement_plugin_sdk_service/_internal/service_manager.py +++ b/ni_measurement_plugin_sdk_service/_internal/service_manager.py @@ -9,7 +9,9 @@ MeasurementServiceServicerV1, MeasurementServiceServicerV2, ) -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ParameterMetadata +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v1 import ( measurement_service_pb2_grpc as v1_measurement_service_pb2_grpc, ) @@ -18,7 +20,10 @@ ) from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient, ServiceLocation from ni_measurement_plugin_sdk_service.grpc.loggers import ServerLogger -from ni_measurement_plugin_sdk_service.measurement.info import MeasurementInfo, ServiceInfo +from ni_measurement_plugin_sdk_service.measurement.info import ( + MeasurementInfo, + ServiceInfo, +) _logger = logging.getLogger(__name__) _V1_INTERFACE = "ni.measurementlink.measurement.v1.MeasurementService" @@ -102,6 +107,7 @@ def start( output_parameter_list, measure_function, owner, + service_info, ) v1_measurement_service_pb2_grpc.add_MeasurementServiceServicer_to_server( servicer_v1, self._server @@ -113,6 +119,7 @@ def start( output_parameter_list, measure_function, owner, + service_info, ) v2_measurement_service_pb2_grpc.add_MeasurementServiceServicer_to_server( servicer_v2, self._server diff --git a/tests/unit/test_decoder.py b/tests/unit/test_decoder.py index 46e49c7b6..36b74fa9c 100644 --- a/tests/unit/test_decoder.py +++ b/tests/unit/test_decoder.py @@ -10,7 +10,7 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter import decoder, encoder +from ni_measurement_plugin_sdk_service._internal.parameter import decoder from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, TypeSpecialization, @@ -52,131 +52,6 @@ class Countries(IntEnum): BIG_MESSAGE_SIZE = 100 -@pytest.mark.parametrize( - "test_values", - [ - [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ], - [ - -0.9999, - -0.9999, - -13, - 1, - 1000, - 2, - True, - "////", - [5.5, -13.3, 1, 0.0, -99.9999], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ], - ], -) -def test___serializer___serialize_parameter___successful_serialization(test_values): - default_values = test_values - parameter = _get_test_parameter_by_id(default_values) - - # Custom Serialization - custom_serialized_bytes = encoder.serialize_parameters(parameter, test_values) - - _validate_serialized_bytes(custom_serialized_bytes, test_values) - - -@pytest.mark.parametrize( - "default_values", - [ - [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ], - [ - -0.9999, - -0.9999, - -13, - 1, - 1000, - 2, - True, - "////", - [5.5, -13.3, 1, 0.0, -99.9999], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ], - ], -) -def test___serializer___serialize_default_parameter___successful_serialization(default_values): - parameter = _get_test_parameter_by_id(default_values) - - # Custom Serialization - custom_serialized_bytes = encoder.serialize_default_values(parameter) - - _validate_serialized_bytes(custom_serialized_bytes, default_values) - - @pytest.mark.parametrize( "values", [ @@ -273,23 +148,6 @@ def test___big_message___deserialize_parameters___returns_parameter_value_by_id( assert parameter_value_by_id == pytest.approx(expected_parameter_value_by_id) -def test___big_message___serialize_parameters___returns_serialized_data() -> None: - parameter_metadata_by_id = _get_big_message_metadata_by_id() - values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] - expected_message = _get_big_message(values) - - serialized_data = encoder.serialize_parameters(parameter_metadata_by_id, values) - - message = BigMessage.FromString(serialized_data) - assert message.ListFields() == pytest.approx(expected_message.ListFields()) - - -def _validate_serialized_bytes(custom_serialized_bytes, values): - # Serialization using gRPC Any - grpc_serialized_data = _get_grpc_serialized_data(values) - assert grpc_serialized_data == custom_serialized_bytes - - def _get_grpc_serialized_data(values): grpc_parameter = _get_test_grpc_message(values) parameter_any = any_pb2.Any() diff --git a/tests/unit/test_decoder_strategy.py b/tests/unit/test_decoder_strategy.py index f1d92a277..fef477250 100644 --- a/tests/unit/test_decoder_strategy.py +++ b/tests/unit/test_decoder_strategy.py @@ -37,7 +37,7 @@ ), ], ) -def test___serialization_strategy___get_decoder___returns_expected_decoder( +def test___decoder_strategy___get_decoder___returns_expected_decoder( type, is_repeated, message_type, expected_decoder ): decoder = decoder_strategy.get_decoder(type, is_repeated, message_type) @@ -61,7 +61,7 @@ def test___serialization_strategy___get_decoder___returns_expected_decoder( (type_pb2.Field.TYPE_MESSAGE, True, []), ], ) -def test___serialization_strategy___get_default_value___returns_type_defaults( +def test___decoder_strategy___get_default_value___returns_type_defaults( type, is_repeated, expected_default_value ): default_value = decoder_strategy.get_type_default(type, is_repeated) diff --git a/tests/unit/test_encoder.py b/tests/unit/test_encoder.py new file mode 100644 index 000000000..424eaa3f3 --- /dev/null +++ b/tests/unit/test_encoder.py @@ -0,0 +1,241 @@ +"""Contains tests to validate serializer.py.""" + +from enum import Enum, IntEnum + +import pytest +from google.protobuf import descriptor_pool + +from ni_measurement_plugin_sdk_service._internal.parameter import encoder +from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( + xydata_pb2, +) +from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo +from tests.unit.test_decoder import ( + _get_big_message, + _get_big_message_metadata_by_id, + _get_grpc_serialized_data, + _get_test_parameter_by_id, +) +from tests.utilities.stubs.serialization.bigmessage_pb2 import BigMessage + + +class DifferentColor(Enum): + """Non-primary colors used for testing enum-typed config and output.""" + + PURPLE = 0 + ORANGE = 1 + TEAL = 2 + BROWN = 3 + + +class Countries(IntEnum): + """Countries enum used for testing enum-typed config and output.""" + + AMERICA = 0 + TAIWAN = 1 + AUSTRALIA = 2 + CANADA = 3 + + +double_xy_data = xydata_pb2.DoubleXYData() +double_xy_data.x_data.append(4) +double_xy_data.y_data.append(6) + +double_xy_data2 = xydata_pb2.DoubleXYData() +double_xy_data2.x_data.append(8) +double_xy_data2.y_data.append(10) + +double_xy_data_array = [double_xy_data, double_xy_data2] + +# This should match the number of fields in bigmessage.proto. +BIG_MESSAGE_SIZE = 100 + + +@pytest.mark.parametrize( + "test_values", + [ + [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + [ + -0.9999, + -0.9999, + -13, + 1, + 1000, + 2, + True, + "", + [5.5, -13.3, 1, 0.0, -99.9999], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + ], +) +def test___serializer___serialize_parameter___successful_serialization(test_values): + default_values = test_values + parameter = _get_test_parameter_by_id(default_values) + + # Custom Serialization + custom_serialized_bytes = encoder.serialize_parameters( + parameter, + test_values, + service_info=ServiceInfo(service_class="serialize_parameters", description_url=""), + ) + + _validate_serialized_bytes(custom_serialized_bytes, test_values) + + +@pytest.mark.parametrize( + "default_values", + [ + [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + [ + -0.9999, + -0.9999, + -13, + 1, + 1000, + 2, + False, + "////", + [5.5, -13.3, 1, 0.0, -99.9999], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + ], +) +def test___serializer___serialize_default_parameter___successful_serialization(default_values): + parameter = _get_test_parameter_by_id(default_values) + + # Custom Serialization + custom_serialized_bytes = encoder.serialize_default_values( + parameter, service_info=ServiceInfo(service_class="default_serialize", description_url="") + ) + + _validate_serialized_bytes(custom_serialized_bytes, default_values) + + +def test___big_message___serialize_parameters___returns_serialized_data() -> None: + parameter_metadata_by_id = _get_big_message_metadata_by_id() + values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] + expected_message = _get_big_message(values) + + serialized_data = encoder.serialize_parameters( + parameter_metadata_by_id, + values, + service_info=ServiceInfo(service_class="big_message", description_url=""), + ) + + message = BigMessage.FromString(serialized_data) + assert message.ListFields() == pytest.approx(expected_message.ListFields()) + + +@pytest.mark.parametrize( + "test_values", + [ + [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + ], +) +def test___serialize_parameter_multiple_times___returns_one_message_type(test_values): + for i in range(100): + test___serializer___serialize_parameter___successful_serialization(test_values) + pool = descriptor_pool.Default() + file_descriptor = pool.FindFileByName("serializeparameters") + message_dict = file_descriptor.message_types_by_name + assert len(message_dict) == 1 + + +def _validate_serialized_bytes(custom_serialized_bytes, values): + # Serialization using gRPC Any + grpc_serialized_data = _get_grpc_serialized_data(values) + assert grpc_serialized_data == custom_serialized_bytes From 7ebd83fb8003fb32dfe10aaca1366bea34fba901 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Tue, 2 Jul 2024 08:13:51 -0500 Subject: [PATCH 13/25] Fixed docstrings and reordered encoder. --- .../_internal/parameter/encoder.py | 58 ++++++++++--------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py index 2ddcfd4e2..0988028cd 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -27,6 +27,8 @@ def serialize_parameters( parameter_value (Sequence[Any]): Parameter values to serialize. + Service_info (ServiceInfo): Unique service name. + Returns: bytes: Serialized byte string containing parameter values. """ @@ -58,6 +60,27 @@ def serialize_parameters( return message_instance.SerializeToString() +def serialize_default_values( + parameter_metadata_dict: Dict[int, ParameterMetadata], service_info: ServiceInfo +) -> bytes: + """Serialize the Default values in the Metadata. + + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. + + Service_info (ServiceInfo): Unique service name. + + Returns: + bytes: Serialized byte string containing default values. + """ + default_value_parameter_array = [ + parameter.default_value for parameter in parameter_metadata_dict.values() + ] + return serialize_parameters( + parameter_metadata_dict, default_value_parameter_array, service_info + ) + + def _create_message_type( parameter_values: Sequence[Any], parameter_metadata_dict: Dict[int, ParameterMetadata], @@ -73,8 +96,10 @@ def _create_message_type( message_name (str): Service class name. + pool (descriptor_pool.DescriptorPool): Descriptor pool holding file descriptors. + Returns: - Any: A message class based on a defined message_descriptor + Any: A message descriptor based on a defined message_descriptor """ file_descriptor = descriptor_pb2.FileDescriptorProto() file_descriptor.name = message_name @@ -85,7 +110,7 @@ def _create_message_type( # Initialize the message with fields defined for i, parameter in enumerate(parameter_values, start=1): parameter_metadata = parameter_metadata_dict[i] - field_descriptor = _create_fields( + field_descriptor = _create_field( message_proto=message_proto, metadata=parameter_metadata, index=i ) if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM: @@ -98,25 +123,6 @@ def _create_message_type( return pool.FindMessageTypeByName(f"{file_descriptor.package}.{message_proto.name}") -def serialize_default_values( - parameter_metadata_dict: Dict[int, ParameterMetadata], service_info: ServiceInfo -) -> bytes: - """Serialize the Default values in the Metadata. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. - - Returns: - bytes: Serialized byte string containing default values. - """ - default_value_parameter_array = [ - parameter.default_value for parameter in parameter_metadata_dict.values() - ] - return serialize_parameters( - parameter_metadata_dict, default_value_parameter_array, service_info - ) - - def _get_enum_values(param: Any) -> Any: """Get's value of an enum. @@ -169,7 +175,9 @@ def _create_enum_type( field_descriptor.type_name = param.__class__.__name__ -def _create_fields(message_proto: Any, metadata: ParameterMetadata, index: int) -> Any: +def _create_field( + message_proto: Any, metadata: ParameterMetadata, index: int +) -> FieldDescriptorProto: """Implement a field in 'message_proto'. Args: @@ -179,12 +187,8 @@ def _create_fields(message_proto: Any, metadata: ParameterMetadata, index: int) index (int): 'param' index in parameter_values - param (Any): A value/parameter of parameter_values. - - is_python_enum (boolean): True if 'param' is a enum from python's libraries. - Returns: - Any: field_descriptor of 'param' or None if 'param' is not equal_to_default_value. + Any: field_descriptor of 'param'. """ field_descriptor = message_proto.field.add() field_descriptor.number = index From 48d880360daa388db8b839de311b4f962d85fe27 Mon Sep 17 00:00:00 2001 From: LazeringDeath <94755334+LazeringDeath@users.noreply.github.com> Date: Wed, 10 Jul 2024 12:19:31 -0500 Subject: [PATCH 14/25] [DRAFT] Message decoder (#780) Implemented decoder and moved helper functions to serialization_strategy. --- .../_internal/grpc_servicer.py | 4 +- .../_internal/parameter/decoder.py | 173 ++++------------ .../_internal/parameter/decoder_strategy.py | 168 --------------- .../_internal/parameter/default_value.py | 22 ++ .../_internal/parameter/encoder.py | 149 ++------------ .../_internal/parameter/metadata.py | 2 +- .../parameter/serialization_strategy.py | 194 ++++++++++++++++++ tests/unit/test_decoder.py | 20 +- tests/unit/test_decoder_strategy.py | 69 ------- tests/unit/test_encoder.py | 2 +- tests/unit/test_serialization_strategy.py | 30 +++ 11 files changed, 323 insertions(+), 510 deletions(-) delete mode 100644 ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py create mode 100644 ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py create mode 100644 ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py delete mode 100644 tests/unit/test_decoder_strategy.py create mode 100644 tests/unit/test_serialization_strategy.py diff --git a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py index 8a282b404..f1e93e226 100644 --- a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py +++ b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py @@ -236,7 +236,7 @@ def Measure( # noqa: N802 - function name should be lowercase ) -> v1_measurement_service_pb2.MeasureResponse: """RPC API that executes the registered measurement method.""" mapping_by_id = decoder.deserialize_parameters( - self._configuration_metadata, request.configuration_parameters.value + self._configuration_metadata, request.configuration_parameters.value, self._service_info ) mapping_by_variable_name = _get_mapping_by_parameter_name( mapping_by_id, self._measure_function @@ -348,7 +348,7 @@ def Measure( # noqa: N802 - function name should be lowercase ) -> Generator[v2_measurement_service_pb2.MeasureResponse, None, None]: """RPC API that executes the registered measurement method.""" mapping_by_id = decoder.deserialize_parameters( - self._configuration_metadata, request.configuration_parameters.value + self._configuration_metadata, request.configuration_parameters.value, self._service_info ) mapping_by_variable_name = _get_mapping_by_parameter_name( mapping_by_id, self._measure_function diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py index c38d4d7ca..d71e78dcf 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py @@ -1,24 +1,25 @@ """Parameter Serializer.""" -from typing import Any, Dict, cast +from typing import Any, Dict -from google.protobuf.descriptor import FieldDescriptor -from google.protobuf.internal.decoder import ( # type: ignore[attr-defined] - _DecodeSignedVarint32, -) -from google.protobuf.message import Message +from google.protobuf import descriptor_pool, message_factory +from google.protobuf.descriptor_pb2 import FieldDescriptorProto -from ni_measurement_plugin_sdk_service._internal.parameter import decoder_strategy from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, - get_enum_values_annotation, ) - -_GRPC_WIRE_TYPE_BIT_WIDTH = 3 +from ni_measurement_plugin_sdk_service._internal.parameter.serialization_strategy import ( + _get_enum_type, + create_message_type, + deserialize_enum_parameter, +) +from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo def deserialize_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_bytes: bytes + parameter_metadata_dict: Dict[int, ParameterMetadata], + parameter_bytes: bytes, + service_info: ServiceInfo, ) -> Dict[int, Any]: """Deserialize the bytes of the parameter based on the metadata. @@ -27,124 +28,40 @@ def deserialize_parameters( parameter_bytes (bytes): Byte string to deserialize. - Returns: - Dict[int, Any]: Deserialized parameters by ID - """ - # Getting overlapping parameters - overlapping_parameter_by_id = _get_overlapping_parameters( - parameter_metadata_dict, parameter_bytes - ) - - # Deserialization enum parameters to their user-defined type - _deserialize_enum_parameters(parameter_metadata_dict, overlapping_parameter_by_id) - - # Adding missing parameters with type defaults - missing_parameters = _get_missing_parameters( - parameter_metadata_dict, overlapping_parameter_by_id - ) - overlapping_parameter_by_id.update(missing_parameters) - return overlapping_parameter_by_id - - -def _get_overlapping_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_bytes: bytes -) -> Dict[int, Any]: - """Get the parameters present in both `parameter_metadata_dict` and `parameter_bytes`. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - - parameter_bytes (bytes): bytes of Parameter that need to be deserialized. - - Raises: - Exception: If the protobuf filed index is invalid. + Service_info (ServiceInfo): Unique service name. Returns: - Dict[int, Any]: Overlapping Parameters by ID. + Dict[int, Any]: Deserialized parameters by ID """ - # inner_decoder update the overlapping_parameters - overlapping_parameters_by_id: Dict[int, Any] = {} - position = 0 - parameter_bytes_memory_view = memoryview(parameter_bytes) - while position < len(parameter_bytes): - (tag, position) = _DecodeSignedVarint32(parameter_bytes_memory_view, position) - field_index = tag >> _GRPC_WIRE_TYPE_BIT_WIDTH - if field_index not in parameter_metadata_dict: - raise Exception( - f"Error occurred while reading the parameter - given protobuf index '{field_index}' is invalid." + pool = descriptor_pool.Default() + service_name = "".join(char for char in service_info.service_class if char.isalpha()) + message_name = service_name + "DESERIALIZE" + try: + message_proto = pool.FindMessageTypeByName(f"{message_name}.{message_name}") + except KeyError: + message_proto = create_message_type(parameter_metadata_dict, message_name, pool) + message_instance = message_factory.GetMessageClass(message_proto)() + + parameter_values = {} + message_instance.ParseFromString(parameter_bytes) + for i in message_proto.fields_by_number.keys(): + field_name = f"field_{i}" + parameter_metadata = parameter_metadata_dict[i] + value = getattr(message_instance, field_name) + + if ( + parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM + and _get_enum_type(parameter_metadata) is not int + ): + parameter_values[i] = deserialize_enum_parameter( + parameter_metadata, message_instance, field_name ) - field_metadata = parameter_metadata_dict[field_index] - decoder = decoder_strategy.get_decoder( - field_metadata.type, field_metadata.repeated, field_metadata.message_type - ) - inner_decoder = decoder(field_index, cast(FieldDescriptor, field_index)) - position = inner_decoder( - parameter_bytes_memory_view, - position, - len(parameter_bytes), - cast(Message, None), # unused - See serialization_strategy._vector_decoder._new_default - cast(Dict[FieldDescriptor, Any], overlapping_parameters_by_id), - ) - return overlapping_parameters_by_id - - -def _get_missing_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_by_id: Dict[int, Any] -) -> Dict[int, Any]: - """Get the Parameters defined in `parameter_metadata_dict` but not in `parameter_by_id`. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by id. - - parameter_by_id (Dict[int, Any]): Parameters by ID to compare the metadata with. - - Returns: - Dict[int, Any]: Missing parameter(as type defaults) by ID. - """ - missing_parameters = {} - for key, value in parameter_metadata_dict.items(): - if key not in parameter_by_id: - enum_annotations = get_enum_values_annotation(value) - if enum_annotations and not value.repeated: - enum_type = _get_enum_type(value) - missing_parameters[key] = enum_type(0) - else: - missing_parameters[key] = decoder_strategy.get_type_default( - value.type, value.repeated - ) - return missing_parameters - - -def _deserialize_enum_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_by_id: Dict[int, Any] -) -> None: - """Converts all enums in `parameter_by_id` to the user defined enum type. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by id. - - parameter_by_id (Dict[int, Any]): Parameters by ID to compare the metadata with. - """ - for id, value in parameter_by_id.items(): - parameter_metadata = parameter_metadata_dict[id] - if get_enum_values_annotation(parameter_metadata): - enum_type = _get_enum_type(parameter_metadata) - is_protobuf_enum = enum_type is int - if parameter_metadata.repeated: - for index, member_value in enumerate(value): - if is_protobuf_enum: - parameter_by_id[id][index] = member_value - else: - parameter_by_id[id][index] = enum_type(member_value) - else: - if is_protobuf_enum: - parameter_by_id[id] = value - else: - parameter_by_id[id] = enum_type(value) - - -def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: - if parameter_metadata.repeated and len(parameter_metadata.default_value) > 0: - return type(parameter_metadata.default_value[0]) - else: - return type(parameter_metadata.default_value) + elif ( + parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE + and not parameter_metadata.repeated + and value.ByteSize() == 0 + ): + parameter_values[i] = None + else: + parameter_values[i] = value + return parameter_values diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py deleted file mode 100644 index f7babb159..000000000 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/decoder_strategy.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Serialization Strategy.""" - -from __future__ import annotations - -from typing import Any, Optional, cast - -from google.protobuf import type_pb2 -from google.protobuf.internal import decoder -from google.protobuf.message import Message - -from ni_measurement_plugin_sdk_service._internal.parameter import _message -from ni_measurement_plugin_sdk_service._internal.parameter._decoder_types import ( - Decoder, - DecoderConstructor, - Key, - PartialDecoderConstructor, -) -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( - xydata_pb2, -) - - -def _scalar_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor: - """Constructs a scalar decoder constructor. - - Takes a field index and a key and returns a Decoder. - - This class returns the Decoder with is_repeated set to False - and is_packed set to False. - """ - - def _unsupported_new_default(message: Optional[Message]) -> Any: - raise NotImplementedError( - "This function should not be called. Verify that you are using up-to-date and compatible versions of the ni-measurement-plugin-sdk-service and protobuf packages." - ) - - def scalar_decoder(field_index: int, key: Key) -> Decoder: - is_repeated = False - is_packed = False - return decoder(field_index, is_repeated, is_packed, key, _unsupported_new_default) - - return scalar_decoder - - -def _vector_decoder( - decoder: DecoderConstructor, is_packed: bool = True -) -> PartialDecoderConstructor: - """Constructs a vector (array) decoder constructor. - - Takes a field index and a key and returns a Decoder. - - This class returns the Decoder with is_repeated set to True - and is_packed defaulting to True. - """ - - def _new_default(unused_message: Optional[Message] = None) -> Any: - return [] - - def vector_decoder(field_index: int, key: Key) -> Decoder: - is_repeated = True - return decoder(field_index, is_repeated, is_packed, key, _new_default) - - return vector_decoder - - -def _double_xy_data_decoder( - decoder: DecoderConstructor, is_repeated: bool -) -> PartialDecoderConstructor: - """Constructs a DoubleXYData decoder constructor. - - Takes a field index and a key and returns a Decoder for DoubleXYData. - """ - - def _new_default(unused_message: Optional[Message] = None) -> Any: - return xydata_pb2.DoubleXYData() - - def message_decoder(field_index: int, key: Key) -> Decoder: - is_packed = True - return decoder(field_index, is_repeated, is_packed, key, _new_default) - - return message_decoder - - -# Cast works around this issue in typeshed -# https://github.com/python/typeshed/issues/10697 -FloatDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.FloatDecoder)) -DoubleDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.DoubleDecoder)) -Int32Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.Int32Decoder)) -UInt32Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.UInt32Decoder)) -Int64Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.Int64Decoder)) -UInt64Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.UInt64Decoder)) -BoolDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.BoolDecoder)) -StringDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.StringDecoder)) -XYDataDecoder = _double_xy_data_decoder(_message._message_decoder_constructor, is_repeated=False) - -FloatArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.FloatDecoder)) -DoubleArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.DoubleDecoder)) -Int32ArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.Int32Decoder)) -UInt32ArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.UInt32Decoder)) -Int64ArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.Int64Decoder)) -UInt64ArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.UInt64Decoder)) -BoolArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.BoolDecoder)) -StringArrayDecoder = _vector_decoder( - cast(DecoderConstructor, decoder.StringDecoder), is_packed=False -) -XYDataArrayDecoder = _double_xy_data_decoder( - _message._message_decoder_constructor, is_repeated=True -) - -_FIELD_TYPE_TO_DECODER_MAPPING = { - type_pb2.Field.TYPE_FLOAT: (FloatDecoder, FloatArrayDecoder), - type_pb2.Field.TYPE_DOUBLE: (DoubleDecoder, DoubleArrayDecoder), - type_pb2.Field.TYPE_INT32: (Int32Decoder, Int32ArrayDecoder), - type_pb2.Field.TYPE_INT64: (Int64Decoder, Int64ArrayDecoder), - type_pb2.Field.TYPE_UINT32: (UInt32Decoder, UInt32ArrayDecoder), - type_pb2.Field.TYPE_UINT64: (UInt64Decoder, UInt64ArrayDecoder), - type_pb2.Field.TYPE_BOOL: (BoolDecoder, BoolArrayDecoder), - type_pb2.Field.TYPE_STRING: (StringDecoder, StringArrayDecoder), - type_pb2.Field.TYPE_ENUM: (Int32Decoder, Int32ArrayDecoder), -} - -_MESSAGE_TYPE_TO_DECODER = { - xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataDecoder, -} - -_ARRAY_MESSAGE_TYPE_TO_DECODER = { - xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataArrayDecoder, -} - - -def get_decoder( - type: type_pb2.Field.Kind.ValueType, repeated: bool, message_type: str = "" -) -> PartialDecoderConstructor: - """Get the appropriate partial decoder constructor for the specified type.""" - decoder_mapping = _FIELD_TYPE_TO_DECODER_MAPPING.get(type) - if decoder_mapping is not None: - scalar_decoder, array_decoder = decoder_mapping - return array_decoder if repeated else scalar_decoder - elif type == type_pb2.Field.Kind.TYPE_MESSAGE: - if repeated: - decoder = _ARRAY_MESSAGE_TYPE_TO_DECODER.get(message_type) - else: - decoder = _MESSAGE_TYPE_TO_DECODER.get(message_type) - if decoder is None: - raise ValueError(f"Unknown message type '{message_type}'") - return decoder - else: - raise ValueError(f"Error can not decode type '{type}'") - - -_type_default_mapping = { - type_pb2.Field.TYPE_FLOAT: float(), - type_pb2.Field.TYPE_DOUBLE: float(), - type_pb2.Field.TYPE_INT32: int(), - type_pb2.Field.TYPE_INT64: int(), - type_pb2.Field.TYPE_UINT32: int(), - type_pb2.Field.TYPE_UINT64: int(), - type_pb2.Field.TYPE_BOOL: bool(), - type_pb2.Field.TYPE_STRING: str(), - type_pb2.Field.TYPE_ENUM: int(), -} - - -def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: - """Get the default value for the give type.""" - if repeated: - return list() - return _type_default_mapping.get(type) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py b/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py new file mode 100644 index 000000000..168f4db35 --- /dev/null +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py @@ -0,0 +1,22 @@ +from typing import Any + +from google.protobuf import type_pb2 + +_type_default_mapping = { + type_pb2.Field.TYPE_FLOAT: float(), + type_pb2.Field.TYPE_DOUBLE: float(), + type_pb2.Field.TYPE_INT32: int(), + type_pb2.Field.TYPE_INT64: int(), + type_pb2.Field.TYPE_UINT32: int(), + type_pb2.Field.TYPE_UINT64: int(), + type_pb2.Field.TYPE_BOOL: bool(), + type_pb2.Field.TYPE_STRING: str(), + type_pb2.Field.TYPE_ENUM: int(), +} + + +def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: + """Get the default value for the give type.""" + if repeated: + return list() + return _type_default_mapping.get(type) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py index 0988028cd..99647eb75 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -1,18 +1,21 @@ -# enums and default values -from enum import Enum +"""Parameter Serializer.""" + from typing import Any, Dict, Sequence -from google.protobuf import descriptor_pb2, descriptor_pool, message_factory +from google.protobuf import descriptor_pool, message_factory from google.protobuf.descriptor_pb2 import FieldDescriptorProto -from ni_measurement_plugin_sdk_service._internal.parameter.decoder_strategy import ( +from ni_measurement_plugin_sdk_service._internal.parameter.default_value import ( get_type_default, ) from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) +from ni_measurement_plugin_sdk_service._internal.parameter.serialization_strategy import ( + create_message_type, + get_enum_values, +) from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo -from tests.utilities.stubs.loopback.types_pb2 import ProtobufColor def serialize_parameters( @@ -25,7 +28,7 @@ def serialize_parameters( Args: parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - parameter_value (Sequence[Any]): Parameter values to serialize. + parameter_values (Sequence[Any]): Parameter values to serialize. Service_info (ServiceInfo): Unique service name. @@ -33,20 +36,19 @@ def serialize_parameters( bytes: Serialized byte string containing parameter values. """ pool = descriptor_pool.Default() - message_name = "".join(char for char in service_info.service_class if char.isalpha()) + service_name = "".join(char for char in service_info.service_class if char.isalpha()) + message_name = service_name + "SERIALIZE" # Tries to find a message type in pool with message_name else it creates one try: message_proto = pool.FindMessageTypeByName(f"{message_name}.{message_name}") except KeyError: - message_proto = _create_message_type( - parameter_values, parameter_metadata_dict, message_name, pool - ) + message_proto = create_message_type(parameter_metadata_dict, message_name, pool) message_instance = message_factory.GetMessageClass(message_proto)() for i, parameter in enumerate(parameter_values, start=1): field_name = f"field_{i}" parameter_metadata = parameter_metadata_dict[i] - parameter = _get_enum_values(param=parameter) + parameter = get_enum_values(param=parameter) type_default_value = get_type_default(parameter_metadata.type, parameter_metadata.repeated) # Doesn't assign default values or None values to fields @@ -79,128 +81,3 @@ def serialize_default_values( return serialize_parameters( parameter_metadata_dict, default_value_parameter_array, service_info ) - - -def _create_message_type( - parameter_values: Sequence[Any], - parameter_metadata_dict: Dict[int, ParameterMetadata], - message_name: str, - pool: descriptor_pool.DescriptorPool, -) -> Any: - """Creates a message descriptor with the fields defined in a file descriptor proto. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - - parameter_value (Sequence[Any]): Parameter values to serialize. - - message_name (str): Service class name. - - pool (descriptor_pool.DescriptorPool): Descriptor pool holding file descriptors. - - Returns: - Any: A message descriptor based on a defined message_descriptor - """ - file_descriptor = descriptor_pb2.FileDescriptorProto() - file_descriptor.name = message_name - file_descriptor.package = message_name - message_proto = file_descriptor.message_type.add() - message_proto.name = message_name - - # Initialize the message with fields defined - for i, parameter in enumerate(parameter_values, start=1): - parameter_metadata = parameter_metadata_dict[i] - field_descriptor = _create_field( - message_proto=message_proto, metadata=parameter_metadata, index=i - ) - if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM: - _create_enum_type( - file_descriptor=file_descriptor, - param=parameter, - field_descriptor=field_descriptor, - ) - pool.Add(file_descriptor) - return pool.FindMessageTypeByName(f"{file_descriptor.package}.{message_proto.name}") - - -def _get_enum_values(param: Any) -> Any: - """Get's value of an enum. - - Args: - param (Any): A value/parameter of parameter_values. - - Returns: - Any: An enum value or a list of enums or the 'param'. - """ - if param == []: - return param - if isinstance(param, list) and isinstance(param[0], Enum): - return [x.value for x in param] - elif isinstance(param, Enum): - return param.value - return param - - -def _create_enum_type( - file_descriptor: descriptor_pb2.FileDescriptorProto, - param: Any, - field_descriptor: FieldDescriptorProto, -) -> None: - """Implement a enum class in 'file_descriptor'. - - Args: - file_descriptor (FileDescriptorProto): Descriptor of a proto file. - - param (Any): A value/parameter of parameter_values. - - field_descriptor (FieldDescriptorProto): Descriptor of a field. - """ - if isinstance(param, list): - param = param[0] - # if there are no enums/param is a different enum and is a python enum, defines a enum field - if param.__class__.__name__ not in [ - enum.name for enum in file_descriptor.enum_type - ] and isinstance(param, Enum): - enum_descriptor = file_descriptor.enum_type.add() - enum_descriptor.name = param.__class__.__name__ - - for name, number in param.__class__.__members__.items(): - enum_value_descriptor = enum_descriptor.value.add() - enum_value_descriptor.name = name - enum_value_descriptor.number = number.value - # checks enum if it's protobuf or python - try: - field_descriptor.type_name = ProtobufColor.DESCRIPTOR.full_name - except TypeError: - field_descriptor.type_name = param.__class__.__name__ - - -def _create_field( - message_proto: Any, metadata: ParameterMetadata, index: int -) -> FieldDescriptorProto: - """Implement a field in 'message_proto'. - - Args: - message_proto (message_type): A message instance in '_FILE_DESCRIPTOR_PROTO'. - - metadata (ParameterMetadata): Metadata of 'param'. - - index (int): 'param' index in parameter_values - - Returns: - Any: field_descriptor of 'param'. - """ - field_descriptor = message_proto.field.add() - field_descriptor.number = index - field_descriptor.name = f"field_{index}" - field_descriptor.type = metadata.type - - if metadata.repeated: - field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED - field_descriptor.options.packed = True - else: - field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL - - if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: - field_descriptor.type_name = metadata.message_type - return field_descriptor diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index 64ee5ebe4..54b59bbd0 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -10,7 +10,7 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter.decoder_strategy import get_type_default +from ni_measurement_plugin_sdk_service._internal.parameter.default_value import get_type_default from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py b/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py new file mode 100644 index 000000000..027be49c5 --- /dev/null +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py @@ -0,0 +1,194 @@ +"""Serialization Strategy.""" + +import json +from enum import Enum +from typing import Any, Dict + +from google.protobuf import descriptor_pb2, descriptor_pool +from google.protobuf.descriptor_pb2 import FieldDescriptorProto + +from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) +from tests.utilities.stubs.loopback.types_pb2 import ProtobufColor + + +def _create_enum_type( + file_descriptor: descriptor_pb2.FileDescriptorProto, + parameter_metadata: ParameterMetadata, + field_descriptor: FieldDescriptorProto, +) -> None: + """Implement a enum class in 'file_descriptor'. + + Args: + file_descriptor (FileDescriptorProto): Descriptor of a proto file. + + parmeter_metadata (ParameterMetadata): Metadata of current field. + + field_descriptor (FieldDescriptorProto): Descriptor of a field. + + Returns: + None: Only creates a enum class in file_descriptor. + """ + enum_dict = json.loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) + is_protobuf = _get_enum_type(parameter_metadata) is int + + if parameter_metadata.repeated: + enum_type_name = parameter_metadata.default_value[0].__class__.__name__ + else: + enum_type_name = parameter_metadata.default_value.__class__.__name__ + + if ( + enum_type_name not in [enum_type.name for enum_type in file_descriptor.enum_type] + and not is_protobuf + ): + enum_descriptor = file_descriptor.enum_type.add() + enum_descriptor.name = enum_type_name + for name, number in enum_dict.items(): + enum_value_descriptor = enum_descriptor.value.add() + enum_value_descriptor.name = name + enum_value_descriptor.number = number + + if is_protobuf: + field_descriptor.type_name = ProtobufColor.DESCRIPTOR.full_name + else: + field_descriptor.type_name = enum_type_name + + +def _create_field( + message_proto: Any, metadata: ParameterMetadata, index: int +) -> FieldDescriptorProto: + """Implement a field in 'message_proto'. + + Args: + message_proto (message_type): A message instance in '_FILE_DESCRIPTOR_PROTO'. + + metadata (ParameterMetadata): Metadata of 'param'. + + index (int): 'param' index in parameter_values + + Returns: + Any: field_descriptor of 'param'. + """ + field_descriptor = message_proto.field.add() + field_descriptor.number = index + field_descriptor.name = f"field_{index}" + field_descriptor.type = metadata.type + + if metadata.repeated: + field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED + field_descriptor.options.packed = True + else: + field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL + + if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: + field_descriptor.type_name = metadata.message_type + return field_descriptor + + +def _get_enum_field(enum_dict: Dict[Any, int], enum_type: Any, field_value: int) -> Any: + """Get enum type and value from 'field_value'. + + Args: + enum_dict (Dict[Any, int]): List enum class of 'field_value'. + + enum_type (Any): 'field_value' enum class name. + + field_value (int): Default value of current field. + + Returns: + Any: Enum type of 'field_value' from 'enum_dict' with the enum value. + """ + for name in enum_dict.keys(): + enum_value = getattr(enum_type, name) + if field_value == enum_value.value: + return enum_value + + +def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: + if parameter_metadata.repeated and len(parameter_metadata.default_value) > 0: + return type(parameter_metadata.default_value[0]) + else: + return type(parameter_metadata.default_value) + + +def create_message_type( + parameter_metadata_dict: Dict[int, ParameterMetadata], + message_name: str, + pool: descriptor_pool.DescriptorPool, +) -> Any: + """Creates a message descriptor with the fields defined in a file descriptor proto. + + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. + + message_name (str): Service class name. + + pool (descriptor_pool.DescriptorPool): Descriptor pool holding file descriptors. + + Returns: + Any: A message descriptor based on a defined message_descriptor + """ + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.name = message_name + file_descriptor.package = message_name + message_proto = file_descriptor.message_type.add() + message_proto.name = message_name + + # Initialize the message with fields defined + for i in parameter_metadata_dict.keys(): + parameter_metadata = parameter_metadata_dict[i] + field_descriptor = _create_field( + message_proto=message_proto, metadata=parameter_metadata, index=i + ) + if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM: + _create_enum_type( + file_descriptor=file_descriptor, + parameter_metadata=parameter_metadata, + field_descriptor=field_descriptor, + ) + pool.Add(file_descriptor) + return pool.FindMessageTypeByName(f"{file_descriptor.package}.{message_proto.name}") + + +def get_enum_values(param: Any) -> Any: + """Get's value of an enum. + + Args: + param (Any): A value/parameter of parameter_values. + + Returns: + Any: An enum value or a list of enums or the 'param'. + """ + if param == []: + return param + if isinstance(param, list) and isinstance(param[0], Enum): + return [x.value for x in param] + elif isinstance(param, Enum): + return param.value + return param + + +def deserialize_enum_parameter( + parameter_metadata: ParameterMetadata, message_instance: Any, field_name: str +) -> Any: + """Convert all enums into the user defined enum type. + + Args: + parameter_metadata (ParameterMetadata): Metadata of current enum value. + + message_instance (Any): Message class of all intialized fields. + + field_name (str): Name of current field. + + Returns: + Any: Enum type or a list of enum types. + """ + enum_dict = json.loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) + field_value = getattr(message_instance, field_name) + enum_type = _get_enum_type(parameter_metadata) + if parameter_metadata.repeated: + return [_get_enum_field(enum_dict, enum_type, value) for value in field_value] + else: + return _get_enum_field(enum_dict, enum_type, field_value) diff --git a/tests/unit/test_decoder.py b/tests/unit/test_decoder.py index 36b74fa9c..f361dfaaa 100644 --- a/tests/unit/test_decoder.py +++ b/tests/unit/test_decoder.py @@ -15,7 +15,10 @@ ParameterMetadata, TypeSpecialization, ) -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 +from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( + xydata_pb2, +) +from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo from tests.utilities.stubs.serialization import test_pb2 from tests.utilities.stubs.serialization.bigmessage_pb2 import BigMessage @@ -85,8 +88,11 @@ def test___serializer___deserialize_parameter___successful_deserialization(value parameter = _get_test_parameter_by_id(values) grpc_serialized_data = _get_grpc_serialized_data(values) - parameter_value_by_id = decoder.deserialize_parameters(parameter, grpc_serialized_data) - + parameter_value_by_id = decoder.deserialize_parameters( + parameter, + grpc_serialized_data, + ServiceInfo(service_class="deserializer", description_url=""), + ) assert list(parameter_value_by_id.values()) == values @@ -118,7 +124,9 @@ def test___empty_buffer___deserialize_parameters___returns_zero_or_empty(): double_xy_data_array, ] parameter = _get_test_parameter_by_id(nonzero_defaults) - parameter_value_by_id = decoder.deserialize_parameters(parameter, bytes()) + parameter_value_by_id = decoder.deserialize_parameters( + parameter, bytes(), ServiceInfo(service_class="empty_buffer", description_url="") + ) for key, value in parameter_value_by_id.items(): parameter_metadata = parameter[key] @@ -142,7 +150,9 @@ def test___big_message___deserialize_parameters___returns_parameter_value_by_id( expected_parameter_value_by_id = {i + 1: value for (i, value) in enumerate(values)} parameter_value_by_id = decoder.deserialize_parameters( - parameter_metadata_by_id, serialized_data + parameter_metadata_by_id, + serialized_data, + ServiceInfo(service_class="big_message", description_url=""), ) assert parameter_value_by_id == pytest.approx(expected_parameter_value_by_id) diff --git a/tests/unit/test_decoder_strategy.py b/tests/unit/test_decoder_strategy.py deleted file mode 100644 index fef477250..000000000 --- a/tests/unit/test_decoder_strategy.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Contains tests to validate the serializationstrategy.py. """ - -import pytest -from google.protobuf import type_pb2 - -from ni_measurement_plugin_sdk_service._internal.parameter import ( - decoder_strategy, -) -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( - xydata_pb2, -) - - -@pytest.mark.parametrize( - "type,is_repeated,message_type,expected_decoder", - [ - (type_pb2.Field.TYPE_FLOAT, False, "", decoder_strategy.FloatDecoder), - (type_pb2.Field.TYPE_DOUBLE, False, "", decoder_strategy.DoubleDecoder), - (type_pb2.Field.TYPE_INT32, False, "", decoder_strategy.Int32Decoder), - (type_pb2.Field.TYPE_INT64, False, "", decoder_strategy.Int64Decoder), - (type_pb2.Field.TYPE_UINT32, False, "", decoder_strategy.UInt32Decoder), - (type_pb2.Field.TYPE_UINT64, False, "", decoder_strategy.UInt64Decoder), - (type_pb2.Field.TYPE_BOOL, False, "", decoder_strategy.BoolDecoder), - (type_pb2.Field.TYPE_STRING, False, "", decoder_strategy.StringDecoder), - (type_pb2.Field.TYPE_ENUM, False, "", decoder_strategy.Int32Decoder), - ( - type_pb2.Field.TYPE_MESSAGE, - False, - xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, - decoder_strategy.XYDataDecoder, - ), - ( - type_pb2.Field.TYPE_MESSAGE, - True, - xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, - decoder_strategy.XYDataArrayDecoder, - ), - ], -) -def test___decoder_strategy___get_decoder___returns_expected_decoder( - type, is_repeated, message_type, expected_decoder -): - decoder = decoder_strategy.get_decoder(type, is_repeated, message_type) - - assert decoder == expected_decoder - - -@pytest.mark.parametrize( - "type,is_repeated,expected_default_value", - [ - (type_pb2.Field.TYPE_FLOAT, False, 0.0), - (type_pb2.Field.TYPE_DOUBLE, False, 0.0), - (type_pb2.Field.TYPE_INT32, False, 0), - (type_pb2.Field.TYPE_INT64, False, 0), - (type_pb2.Field.TYPE_UINT32, False, 0), - (type_pb2.Field.TYPE_UINT64, False, 0), - (type_pb2.Field.TYPE_BOOL, False, False), - (type_pb2.Field.TYPE_STRING, False, ""), - (type_pb2.Field.TYPE_ENUM, False, 0), - (type_pb2.Field.TYPE_MESSAGE, False, None), - (type_pb2.Field.TYPE_MESSAGE, True, []), - ], -) -def test___decoder_strategy___get_default_value___returns_type_defaults( - type, is_repeated, expected_default_value -): - default_value = decoder_strategy.get_type_default(type, is_repeated) - - assert default_value == expected_default_value diff --git a/tests/unit/test_encoder.py b/tests/unit/test_encoder.py index 424eaa3f3..c0ef11b5c 100644 --- a/tests/unit/test_encoder.py +++ b/tests/unit/test_encoder.py @@ -230,7 +230,7 @@ def test___serialize_parameter_multiple_times___returns_one_message_type(test_va for i in range(100): test___serializer___serialize_parameter___successful_serialization(test_values) pool = descriptor_pool.Default() - file_descriptor = pool.FindFileByName("serializeparameters") + file_descriptor = pool.FindFileByName("serializeparametersSERIALIZE") message_dict = file_descriptor.message_types_by_name assert len(message_dict) == 1 diff --git a/tests/unit/test_serialization_strategy.py b/tests/unit/test_serialization_strategy.py new file mode 100644 index 000000000..42678a6ac --- /dev/null +++ b/tests/unit/test_serialization_strategy.py @@ -0,0 +1,30 @@ +"""Contains tests to validate the serializationstrategy.py. """ + +import pytest +from google.protobuf import type_pb2 + +from ni_measurement_plugin_sdk_service._internal.parameter import ( + default_value, +) + + +@pytest.mark.parametrize( + "type,is_repeated,expected_default_value", + [ + (type_pb2.Field.TYPE_FLOAT, False, 0.0), + (type_pb2.Field.TYPE_DOUBLE, False, 0.0), + (type_pb2.Field.TYPE_INT32, False, 0), + (type_pb2.Field.TYPE_INT64, False, 0), + (type_pb2.Field.TYPE_UINT32, False, 0), + (type_pb2.Field.TYPE_UINT64, False, 0), + (type_pb2.Field.TYPE_BOOL, False, False), + (type_pb2.Field.TYPE_STRING, False, ""), + (type_pb2.Field.TYPE_ENUM, False, 0), + (type_pb2.Field.TYPE_MESSAGE, False, None), + (type_pb2.Field.TYPE_MESSAGE, True, []), + ], +) +def test___get_default_value___returns_type_defaults(type, is_repeated, expected_default_value): + test_default_value = default_value.get_type_default(type, is_repeated) + + assert test_default_value == expected_default_value From 239eb36bc0ea70156a330d83d724f6b0e489cc97 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 15 Jul 2024 13:02:32 -0500 Subject: [PATCH 15/25] Creates 2 messages per service, renamed message/fields, and reordered helper functions. --- .../_internal/grpc_servicer.py | 32 ++- .../_internal/parameter/decoder.py | 64 ++++-- .../_internal/parameter/encoder.py | 46 +++-- .../_internal/parameter/metadata.py | 13 +- .../parameter/serialization_descriptors.py | 177 ++++++++++++++++ .../parameter/serialization_strategy.py | 194 ------------------ .../_internal/service_manager.py | 10 + tests/unit/test_decoder.py | 34 ++- tests/unit/test_encoder.py | 37 +++- 9 files changed, 352 insertions(+), 255 deletions(-) create mode 100644 ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py delete mode 100644 ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py diff --git a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py index f1e93e226..e75293223 100644 --- a/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py +++ b/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py @@ -137,11 +137,11 @@ def _get_mapping_by_parameter_name( def _serialize_outputs( - output_metadata: Dict[int, ParameterMetadata], outputs: Any, service_info: ServiceInfo + output_metadata: Dict[int, ParameterMetadata], outputs: Any, service_name: str ) -> any_pb2.Any: if isinstance(outputs, collections.abc.Sequence): return any_pb2.Any( - value=encoder.serialize_parameters(output_metadata, outputs, service_info) + value=encoder.serialize_parameters(output_metadata, outputs, service_name) ) elif outputs is None: raise ValueError(f"Measurement function returned None") @@ -205,7 +205,7 @@ def GetMetadata( # noqa: N802 - function name should be lowercase measurement_signature.configuration_parameters.append(configuration_parameter) measurement_signature.configuration_defaults.value = encoder.serialize_default_values( - self._configuration_metadata, self._service_info + self._configuration_metadata, self._get_service_name() + ".Inputs" ) for field_number, output_metadata in self._output_metadata.items(): @@ -236,7 +236,9 @@ def Measure( # noqa: N802 - function name should be lowercase ) -> v1_measurement_service_pb2.MeasureResponse: """RPC API that executes the registered measurement method.""" mapping_by_id = decoder.deserialize_parameters( - self._configuration_metadata, request.configuration_parameters.value, self._service_info + self._configuration_metadata, + request.configuration_parameters.value, + self._get_service_name() + ".Inputs", ) mapping_by_variable_name = _get_mapping_by_parameter_name( mapping_by_id, self._measure_function @@ -265,9 +267,15 @@ def Measure( # noqa: N802 - function name should be lowercase def _serialize_response(self, outputs: Any) -> v1_measurement_service_pb2.MeasureResponse: return v1_measurement_service_pb2.MeasureResponse( - outputs=_serialize_outputs(self._output_metadata, outputs, self._service_info) + outputs=_serialize_outputs( + self._output_metadata, outputs, self._get_service_name() + ".Outputs" + ) ) + def _get_service_name(self) -> str: + service_name = "".join(char for char in self._service_info.service_class if char.isalpha()) + return service_name + class MeasurementServiceServicerV2(v2_measurement_service_pb2_grpc.MeasurementServiceServicer): """Measurement v2 servicer.""" @@ -315,7 +323,7 @@ def GetMetadata( # noqa: N802 - function name should be lowercase measurement_signature.configuration_parameters.append(configuration_parameter) measurement_signature.configuration_defaults.value = encoder.serialize_default_values( - self._configuration_metadata, self._service_info + self._configuration_metadata, self._get_service_name() + ".Inputs" ) for field_number, output_metadata in self._output_metadata.items(): @@ -348,7 +356,9 @@ def Measure( # noqa: N802 - function name should be lowercase ) -> Generator[v2_measurement_service_pb2.MeasureResponse, None, None]: """RPC API that executes the registered measurement method.""" mapping_by_id = decoder.deserialize_parameters( - self._configuration_metadata, request.configuration_parameters.value, self._service_info + self._configuration_metadata, + request.configuration_parameters.value, + self._get_service_name() + ".Inputs", ) mapping_by_variable_name = _get_mapping_by_parameter_name( mapping_by_id, self._measure_function @@ -376,5 +386,11 @@ def Measure( # noqa: N802 - function name should be lowercase def _serialize_response(self, outputs: Any) -> v2_measurement_service_pb2.MeasureResponse: return v2_measurement_service_pb2.MeasureResponse( - outputs=_serialize_outputs(self._output_metadata, outputs, self._service_info) + outputs=_serialize_outputs( + self._output_metadata, outputs, self._get_service_name() + ".Outputs" + ) ) + + def _get_service_name(self) -> str: + service_name = "".join(char for char in self._service_info.service_class if char.isalpha()) + return service_name diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py index d71e78dcf..fe5692605 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py @@ -1,25 +1,24 @@ """Parameter Serializer.""" +from json import loads from typing import Any, Dict from google.protobuf import descriptor_pool, message_factory from google.protobuf.descriptor_pb2 import FieldDescriptorProto +from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) -from ni_measurement_plugin_sdk_service._internal.parameter.serialization_strategy import ( +from ni_measurement_plugin_sdk_service._internal.parameter.serialization_descriptors import ( _get_enum_type, - create_message_type, - deserialize_enum_parameter, ) -from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo def deserialize_parameters( parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_bytes: bytes, - service_info: ServiceInfo, + service_name: str, ) -> Dict[int, Any]: """Deserialize the bytes of the parameter based on the metadata. @@ -28,34 +27,27 @@ def deserialize_parameters( parameter_bytes (bytes): Byte string to deserialize. - Service_info (ServiceInfo): Unique service name. + service_name (str): Unique service name. Returns: Dict[int, Any]: Deserialized parameters by ID """ pool = descriptor_pool.Default() - service_name = "".join(char for char in service_info.service_class if char.isalpha()) - message_name = service_name + "DESERIALIZE" - try: - message_proto = pool.FindMessageTypeByName(f"{message_name}.{message_name}") - except KeyError: - message_proto = create_message_type(parameter_metadata_dict, message_name, pool) + message_proto = pool.FindMessageTypeByName(service_name) message_instance = message_factory.GetMessageClass(message_proto)() - parameter_values = {} + message_instance.ParseFromString(parameter_bytes) for i in message_proto.fields_by_number.keys(): - field_name = f"field_{i}" parameter_metadata = parameter_metadata_dict[i] + field_name = parameter_metadata.sanitized_display_name() value = getattr(message_instance, field_name) if ( parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM and _get_enum_type(parameter_metadata) is not int ): - parameter_values[i] = deserialize_enum_parameter( - parameter_metadata, message_instance, field_name - ) + parameter_values[i] = _deserialize_enum_parameter(parameter_metadata, value) elif ( parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE and not parameter_metadata.repeated @@ -65,3 +57,41 @@ def deserialize_parameters( else: parameter_values[i] = value return parameter_values + + +def _deserialize_enum_parameter(parameter_metadata: ParameterMetadata, field_value: Any) -> Any: + """Convert all enums into the user defined enum type. + + Args: + parameter_metadata (ParameterMetadata): Metadata of current enum value. + + field_value (Any): Value of current field. + + Returns: + Any: Enum type or a list of enum types. + """ + enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) + enum_type = _get_enum_type(parameter_metadata) + if parameter_metadata.repeated: + return [_get_enum_field(enum_dict, enum_type, value) for value in field_value] + else: + return _get_enum_field(enum_dict, enum_type, field_value) + + +def _get_enum_field(enum_dict: Dict[Any, int], enum_type: Any, field_value: int) -> Any: + """Get enum type and value from 'field_value'. + + Args: + enum_dict (Dict[Any, int]): List enum class of 'field_value'. + + enum_type (Any): 'field_value' enum class name. + + field_value (int): Default value of current field. + + Returns: + Any: Enum type of 'field_value' from 'enum_dict' with the enum value. + """ + for name in enum_dict.keys(): + enum_value = getattr(enum_type, name) + if field_value == enum_value.value: + return enum_value diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py index 99647eb75..97c6ef010 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -1,5 +1,6 @@ """Parameter Serializer.""" +from enum import Enum from typing import Any, Dict, Sequence from google.protobuf import descriptor_pool, message_factory @@ -11,17 +12,12 @@ from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) -from ni_measurement_plugin_sdk_service._internal.parameter.serialization_strategy import ( - create_message_type, - get_enum_values, -) -from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo def serialize_parameters( parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_values: Sequence[Any], - service_info: ServiceInfo, + service_name: str, ) -> bytes: """Serialize the parameter values in same order based on the metadata_dict. @@ -30,25 +26,19 @@ def serialize_parameters( parameter_values (Sequence[Any]): Parameter values to serialize. - Service_info (ServiceInfo): Unique service name. + service_name (str): Unique service name. Returns: bytes: Serialized byte string containing parameter values. """ pool = descriptor_pool.Default() - service_name = "".join(char for char in service_info.service_class if char.isalpha()) - message_name = service_name + "SERIALIZE" - # Tries to find a message type in pool with message_name else it creates one - try: - message_proto = pool.FindMessageTypeByName(f"{message_name}.{message_name}") - except KeyError: - message_proto = create_message_type(parameter_metadata_dict, message_name, pool) + message_proto = pool.FindMessageTypeByName(service_name) message_instance = message_factory.GetMessageClass(message_proto)() for i, parameter in enumerate(parameter_values, start=1): - field_name = f"field_{i}" parameter_metadata = parameter_metadata_dict[i] - parameter = get_enum_values(param=parameter) + field_name = parameter_metadata.sanitized_display_name() + parameter = _get_enum_values(param=parameter) type_default_value = get_type_default(parameter_metadata.type, parameter_metadata.repeated) # Doesn't assign default values or None values to fields @@ -63,14 +53,14 @@ def serialize_parameters( def serialize_default_values( - parameter_metadata_dict: Dict[int, ParameterMetadata], service_info: ServiceInfo + parameter_metadata_dict: Dict[int, ParameterMetadata], service_name: str ) -> bytes: """Serialize the Default values in the Metadata. Args: parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. - Service_info (ServiceInfo): Unique service name. + service_name (str): Unique service name. Returns: bytes: Serialized byte string containing default values. @@ -79,5 +69,23 @@ def serialize_default_values( parameter.default_value for parameter in parameter_metadata_dict.values() ] return serialize_parameters( - parameter_metadata_dict, default_value_parameter_array, service_info + parameter_metadata_dict, default_value_parameter_array, service_name ) + + +def _get_enum_values(param: Any) -> Any: + """Get's value of an enum. + + Args: + param (Any): A value/parameter of parameter_values. + + Returns: + Any: An enum value or a list of enums or the 'param'. + """ + if param == []: + return param + if isinstance(param, list) and isinstance(param[0], Enum): + return [x.value for x in param] + elif isinstance(param, Enum): + return param.value + return param diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index 54b59bbd0..8227a6e88 100644 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -10,7 +10,9 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter.default_value import get_type_default +from ni_measurement_plugin_sdk_service._internal.parameter.default_value import ( + get_type_default, +) from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization @@ -40,6 +42,15 @@ class ParameterMetadata(NamedTuple): Required when 'type' is Kind.TypeMessage. Ignored for any other 'type'. """ + def sanitized_display_name(self) -> str: + """Parameter display name of alpha/numerical characters. + + Returns: + str: Alpha/numerical characters of 'display_name'. + + """ + return "".join(char for char in self.display_name if char.isalnum()) + def validate_default_value_type(parameter_metadata: ParameterMetadata) -> None: """Validate and raise exception if the default value does not match the type info. diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py new file mode 100644 index 000000000..bd06ef796 --- /dev/null +++ b/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -0,0 +1,177 @@ +"""Serialization Strategy.""" + +from json import loads +from typing import Any, List + +from google.protobuf import descriptor_pb2, descriptor_pool +from google.protobuf.descriptor_pb2 import FieldDescriptorProto + +from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) +from tests.utilities.stubs.loopback.types_pb2 import ProtobufColor + + +def _get_output_enum_type( + metadata_enum_list: List[str], + file_descriptor: descriptor_pb2.FileDescriptorProto, +) -> Any: + """Get's matching enum class from 'file_descriptor'. + + Args: + metadata_enum_list (List[str]): Enum names from metadata.annotations. + + file_descriptor: Descriptor of proto file. + + Returns: + Any: Matching enum class in a str type or None when enum is protobuf. + """ + for enum_type in file_descriptor.enum_type: + enum_names = [enum_value.name for enum_value in enum_type.value] + if sorted(metadata_enum_list) == sorted(enum_names): + return enum_type.name + return None + + +def _create_enum_type( + file_descriptor: descriptor_pb2.FileDescriptorProto, + parameter_metadata: ParameterMetadata, + field_descriptor: FieldDescriptorProto, +) -> None: + """Implement a enum class in 'file_descriptor'. + + Args: + file_descriptor (FileDescriptorProto): Descriptor of a proto file. + + parmeter_metadata (ParameterMetadata): Metadata of current field. + + field_descriptor (FieldDescriptorProto): Descriptor of a field. + + Returns: + None: Only creates a enum class in file_descriptor. + """ + enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) + if parameter_metadata.default_value is None: + enum_type_name = _get_output_enum_type( + metadata_enum_list=[enum for enum in enum_dict.keys()], file_descriptor=file_descriptor + ) + else: + enum_type_name = _get_enum_type(parameter_metadata).__name__ + + if enum_type_name == "int" or enum_type_name is None: + field_descriptor.type_name = ProtobufColor.DESCRIPTOR.full_name + elif enum_type_name not in [enum_type.name for enum_type in file_descriptor.enum_type]: + enum_descriptor = file_descriptor.enum_type.add() + enum_descriptor.name = enum_type_name + for name, number in enum_dict.items(): + enum_value_descriptor = enum_descriptor.value.add() + enum_value_descriptor.name = name + enum_value_descriptor.number = number + field_descriptor.type_name = enum_descriptor.name + else: + field_descriptor.type_name = enum_type_name + + +def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: + if parameter_metadata.repeated and len(parameter_metadata.default_value) > 0: + return type(parameter_metadata.default_value[0]) + else: + return type(parameter_metadata.default_value) + + +def _create_field( + message_proto: Any, metadata: ParameterMetadata, index: int +) -> FieldDescriptorProto: + """Implement a field in 'message_proto'. + + Args: + message_proto (message_type): A message instance in a file descriptor proto. + + metadata (ParameterMetadata): Metadata of 'param'. + + index (int): 'param' index in parameter_values + + Returns: + FieldDescriptorProto: field_descriptor of 'param'. + """ + field_descriptor = message_proto.field.add() + field_descriptor.number = index + field_descriptor.name = metadata.sanitized_display_name() + field_descriptor.type = metadata.type + + if metadata.repeated: + field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED + field_descriptor.options.packed = True + else: + field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL + + if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: + field_descriptor.type_name = metadata.message_type + return field_descriptor + + +def create_file_descriptor( + service_name: str, + output_metadata: List[ParameterMetadata], + input_metadata: List[ParameterMetadata], + pool: descriptor_pool.DescriptorPool, +) -> None: + """Creates two message types in one file descriptor proto. + + Args: + service_class_name (str): Unique service name. + + output_metadata (List[ParameterMetadata]): Metadata of output parameters. + + input_metadata (List[ParameterMetadata]): Metadata of input parameters. + + pool (DescriptorPool): Descriptor pool holding file descriptors and enum classes. + + Returns: + None: Only creates a file and two message descriptors. + """ + service_name = "".join(char for char in service_name if char.isalpha()) + try: + pool.FindFileByName(service_name) + except KeyError: + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.name = service_name + file_descriptor.package = service_name + + _create_message_type(input_metadata, "Inputs", file_descriptor) + _create_message_type(output_metadata, "Outputs", file_descriptor) + pool.Add(file_descriptor) + + +def _create_message_type( + parameter_metadata: List[ParameterMetadata], + message_name: str, + file_descriptor: descriptor_pb2.FileDescriptorProto, +) -> None: + """Creates a message descriptor with the fields defined in a file descriptor proto. + + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. + + message_name (str): Service class name. + + file_descriptor (descriptor_pb2.FileDescriptorProto): Descriptor of a proto file. + + Returns: + None: Only creates a message_type in 'file_descriptor'. + """ + message_proto = file_descriptor.message_type.add() + message_proto.name = message_name + + # Initialize the message with fields defined + for i, metadata in enumerate(parameter_metadata): + field_descriptor = _create_field( + message_proto=message_proto, metadata=metadata, index=i + 1 + ) + if metadata.type == FieldDescriptorProto.TYPE_ENUM: + _create_enum_type( + file_descriptor=file_descriptor, + parameter_metadata=metadata, + field_descriptor=field_descriptor, + ) diff --git a/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py b/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py deleted file mode 100644 index 027be49c5..000000000 --- a/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Serialization Strategy.""" - -import json -from enum import Enum -from typing import Any, Dict - -from google.protobuf import descriptor_pb2, descriptor_pool -from google.protobuf.descriptor_pb2 import FieldDescriptorProto - -from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( - ParameterMetadata, -) -from tests.utilities.stubs.loopback.types_pb2 import ProtobufColor - - -def _create_enum_type( - file_descriptor: descriptor_pb2.FileDescriptorProto, - parameter_metadata: ParameterMetadata, - field_descriptor: FieldDescriptorProto, -) -> None: - """Implement a enum class in 'file_descriptor'. - - Args: - file_descriptor (FileDescriptorProto): Descriptor of a proto file. - - parmeter_metadata (ParameterMetadata): Metadata of current field. - - field_descriptor (FieldDescriptorProto): Descriptor of a field. - - Returns: - None: Only creates a enum class in file_descriptor. - """ - enum_dict = json.loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) - is_protobuf = _get_enum_type(parameter_metadata) is int - - if parameter_metadata.repeated: - enum_type_name = parameter_metadata.default_value[0].__class__.__name__ - else: - enum_type_name = parameter_metadata.default_value.__class__.__name__ - - if ( - enum_type_name not in [enum_type.name for enum_type in file_descriptor.enum_type] - and not is_protobuf - ): - enum_descriptor = file_descriptor.enum_type.add() - enum_descriptor.name = enum_type_name - for name, number in enum_dict.items(): - enum_value_descriptor = enum_descriptor.value.add() - enum_value_descriptor.name = name - enum_value_descriptor.number = number - - if is_protobuf: - field_descriptor.type_name = ProtobufColor.DESCRIPTOR.full_name - else: - field_descriptor.type_name = enum_type_name - - -def _create_field( - message_proto: Any, metadata: ParameterMetadata, index: int -) -> FieldDescriptorProto: - """Implement a field in 'message_proto'. - - Args: - message_proto (message_type): A message instance in '_FILE_DESCRIPTOR_PROTO'. - - metadata (ParameterMetadata): Metadata of 'param'. - - index (int): 'param' index in parameter_values - - Returns: - Any: field_descriptor of 'param'. - """ - field_descriptor = message_proto.field.add() - field_descriptor.number = index - field_descriptor.name = f"field_{index}" - field_descriptor.type = metadata.type - - if metadata.repeated: - field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED - field_descriptor.options.packed = True - else: - field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL - - if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: - field_descriptor.type_name = metadata.message_type - return field_descriptor - - -def _get_enum_field(enum_dict: Dict[Any, int], enum_type: Any, field_value: int) -> Any: - """Get enum type and value from 'field_value'. - - Args: - enum_dict (Dict[Any, int]): List enum class of 'field_value'. - - enum_type (Any): 'field_value' enum class name. - - field_value (int): Default value of current field. - - Returns: - Any: Enum type of 'field_value' from 'enum_dict' with the enum value. - """ - for name in enum_dict.keys(): - enum_value = getattr(enum_type, name) - if field_value == enum_value.value: - return enum_value - - -def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: - if parameter_metadata.repeated and len(parameter_metadata.default_value) > 0: - return type(parameter_metadata.default_value[0]) - else: - return type(parameter_metadata.default_value) - - -def create_message_type( - parameter_metadata_dict: Dict[int, ParameterMetadata], - message_name: str, - pool: descriptor_pool.DescriptorPool, -) -> Any: - """Creates a message descriptor with the fields defined in a file descriptor proto. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - - message_name (str): Service class name. - - pool (descriptor_pool.DescriptorPool): Descriptor pool holding file descriptors. - - Returns: - Any: A message descriptor based on a defined message_descriptor - """ - file_descriptor = descriptor_pb2.FileDescriptorProto() - file_descriptor.name = message_name - file_descriptor.package = message_name - message_proto = file_descriptor.message_type.add() - message_proto.name = message_name - - # Initialize the message with fields defined - for i in parameter_metadata_dict.keys(): - parameter_metadata = parameter_metadata_dict[i] - field_descriptor = _create_field( - message_proto=message_proto, metadata=parameter_metadata, index=i - ) - if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM: - _create_enum_type( - file_descriptor=file_descriptor, - parameter_metadata=parameter_metadata, - field_descriptor=field_descriptor, - ) - pool.Add(file_descriptor) - return pool.FindMessageTypeByName(f"{file_descriptor.package}.{message_proto.name}") - - -def get_enum_values(param: Any) -> Any: - """Get's value of an enum. - - Args: - param (Any): A value/parameter of parameter_values. - - Returns: - Any: An enum value or a list of enums or the 'param'. - """ - if param == []: - return param - if isinstance(param, list) and isinstance(param[0], Enum): - return [x.value for x in param] - elif isinstance(param, Enum): - return param.value - return param - - -def deserialize_enum_parameter( - parameter_metadata: ParameterMetadata, message_instance: Any, field_name: str -) -> Any: - """Convert all enums into the user defined enum type. - - Args: - parameter_metadata (ParameterMetadata): Metadata of current enum value. - - message_instance (Any): Message class of all intialized fields. - - field_name (str): Name of current field. - - Returns: - Any: Enum type or a list of enum types. - """ - enum_dict = json.loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) - field_value = getattr(message_instance, field_name) - enum_type = _get_enum_type(parameter_metadata) - if parameter_metadata.repeated: - return [_get_enum_field(enum_dict, enum_type, value) for value in field_value] - else: - return _get_enum_field(enum_dict, enum_type, field_value) diff --git a/ni_measurement_plugin_sdk_service/_internal/service_manager.py b/ni_measurement_plugin_sdk_service/_internal/service_manager.py index 4ea6814e1..775c37542 100644 --- a/ni_measurement_plugin_sdk_service/_internal/service_manager.py +++ b/ni_measurement_plugin_sdk_service/_internal/service_manager.py @@ -3,6 +3,7 @@ import grpc from deprecation import deprecated +from google.protobuf import descriptor_pool from grpc.framework.foundation import logging_pool from ni_measurement_plugin_sdk_service._internal.grpc_servicer import ( @@ -12,6 +13,9 @@ from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) +from ni_measurement_plugin_sdk_service._internal.parameter.serialization_descriptors import ( + create_file_descriptor, +) from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v1 import ( measurement_service_pb2_grpc as v1_measurement_service_pb2_grpc, ) @@ -99,6 +103,12 @@ def start( ("grpc.max_send_message_length", -1), ], ) + create_file_descriptor( + service_name=service_info.service_class, + output_metadata=output_parameter_list, + input_metadata=configuration_parameter_list, + pool=descriptor_pool.Default(), + ) for interface in service_info.provided_interfaces: if interface == _V1_INTERFACE: servicer_v1 = MeasurementServiceServicerV1( diff --git a/tests/unit/test_decoder.py b/tests/unit/test_decoder.py index f361dfaaa..304368e60 100644 --- a/tests/unit/test_decoder.py +++ b/tests/unit/test_decoder.py @@ -1,16 +1,19 @@ """Contains tests to validate serializer.py.""" from enum import Enum, IntEnum -from typing import Dict, Sequence +from typing import Any, Dict, Sequence import pytest -from google.protobuf import any_pb2, type_pb2 +from google.protobuf import any_pb2, descriptor_pb2, descriptor_pool, type_pb2 from ni_measurement_plugin_sdk_service._annotations import ( ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter import decoder +from ni_measurement_plugin_sdk_service._internal.parameter import ( + decoder, + serialization_descriptors, +) from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, TypeSpecialization, @@ -18,7 +21,6 @@ from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( xydata_pb2, ) -from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo from tests.utilities.stubs.serialization import test_pb2 from tests.utilities.stubs.serialization.bigmessage_pb2 import BigMessage @@ -87,11 +89,12 @@ class Countries(IntEnum): def test___serializer___deserialize_parameter___successful_deserialization(values): parameter = _get_test_parameter_by_id(values) grpc_serialized_data = _get_grpc_serialized_data(values) + service_name = _test_create_file_descriptor(list(parameter.values()), "deserializeparameter") parameter_value_by_id = decoder.deserialize_parameters( parameter, grpc_serialized_data, - ServiceInfo(service_class="deserializer", description_url=""), + service_name=service_name, ) assert list(parameter_value_by_id.values()) == values @@ -124,8 +127,9 @@ def test___empty_buffer___deserialize_parameters___returns_zero_or_empty(): double_xy_data_array, ] parameter = _get_test_parameter_by_id(nonzero_defaults) + service_name = _test_create_file_descriptor(list(parameter.values()), "emptybuffer") parameter_value_by_id = decoder.deserialize_parameters( - parameter, bytes(), ServiceInfo(service_class="empty_buffer", description_url="") + parameter, bytes(), service_name=service_name ) for key, value in parameter_value_by_id.items(): @@ -148,11 +152,14 @@ def test___big_message___deserialize_parameters___returns_parameter_value_by_id( message = _get_big_message(values) serialized_data = message.SerializeToString() expected_parameter_value_by_id = {i + 1: value for (i, value) in enumerate(values)} + service_name = _test_create_file_descriptor( + list(parameter_metadata_by_id.values()), "bigmessage" + ) parameter_value_by_id = decoder.deserialize_parameters( parameter_metadata_by_id, serialized_data, - ServiceInfo(service_class="big_message", description_url=""), + service_name=service_name, ) assert parameter_value_by_id == pytest.approx(expected_parameter_value_by_id) @@ -385,3 +392,16 @@ def _get_big_message_metadata_by_id() -> Dict[int, ParameterMetadata]: def _get_big_message(values: Sequence[float]) -> BigMessage: assert len(values) == BIG_MESSAGE_SIZE return BigMessage(**{f"field{i + 1}": value for (i, value) in enumerate(values)}) + + +def _test_create_file_descriptor(metadata: Dict[int, Any], file_name: str) -> str: + pool = descriptor_pool.Default() + try: + pool.FindMessageTypeByName(f"{file_name}.test") + except KeyError: + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.name = file_name + file_descriptor.package = file_name + serialization_descriptors._create_message_type(metadata, "test", file_descriptor) + pool.Add(file_descriptor) + return file_name + ".test" diff --git a/tests/unit/test_encoder.py b/tests/unit/test_encoder.py index c0ef11b5c..c5f341ee1 100644 --- a/tests/unit/test_encoder.py +++ b/tests/unit/test_encoder.py @@ -1,15 +1,18 @@ """Contains tests to validate serializer.py.""" from enum import Enum, IntEnum +from typing import Any, Dict import pytest -from google.protobuf import descriptor_pool +from google.protobuf import descriptor_pb2, descriptor_pool -from ni_measurement_plugin_sdk_service._internal.parameter import encoder +from ni_measurement_plugin_sdk_service._internal.parameter import ( + encoder, + serialization_descriptors, +) from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( xydata_pb2, ) -from ni_measurement_plugin_sdk_service.measurement.info import ServiceInfo from tests.unit.test_decoder import ( _get_big_message, _get_big_message_metadata_by_id, @@ -107,12 +110,13 @@ class Countries(IntEnum): def test___serializer___serialize_parameter___successful_serialization(test_values): default_values = test_values parameter = _get_test_parameter_by_id(default_values) + service_name = _test_create_file_descriptor(list(parameter.values()), "serializeparameter") # Custom Serialization custom_serialized_bytes = encoder.serialize_parameters( parameter, test_values, - service_info=ServiceInfo(service_class="serialize_parameters", description_url=""), + service_name=service_name, ) _validate_serialized_bytes(custom_serialized_bytes, test_values) @@ -173,11 +177,10 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu ) def test___serializer___serialize_default_parameter___successful_serialization(default_values): parameter = _get_test_parameter_by_id(default_values) + service_name = _test_create_file_descriptor(list(parameter.values()), "defaultserialize") # Custom Serialization - custom_serialized_bytes = encoder.serialize_default_values( - parameter, service_info=ServiceInfo(service_class="default_serialize", description_url="") - ) + custom_serialized_bytes = encoder.serialize_default_values(parameter, service_name=service_name) _validate_serialized_bytes(custom_serialized_bytes, default_values) @@ -186,11 +189,14 @@ def test___big_message___serialize_parameters___returns_serialized_data() -> Non parameter_metadata_by_id = _get_big_message_metadata_by_id() values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] expected_message = _get_big_message(values) + service_name = _test_create_file_descriptor( + list(parameter_metadata_by_id.values()), "bigmessage" + ) serialized_data = encoder.serialize_parameters( parameter_metadata_by_id, values, - service_info=ServiceInfo(service_class="big_message", description_url=""), + service_name=service_name, ) message = BigMessage.FromString(serialized_data) @@ -230,7 +236,7 @@ def test___serialize_parameter_multiple_times___returns_one_message_type(test_va for i in range(100): test___serializer___serialize_parameter___successful_serialization(test_values) pool = descriptor_pool.Default() - file_descriptor = pool.FindFileByName("serializeparametersSERIALIZE") + file_descriptor = pool.FindFileByName("serializeparameter") message_dict = file_descriptor.message_types_by_name assert len(message_dict) == 1 @@ -239,3 +245,16 @@ def _validate_serialized_bytes(custom_serialized_bytes, values): # Serialization using gRPC Any grpc_serialized_data = _get_grpc_serialized_data(values) assert grpc_serialized_data == custom_serialized_bytes + + +def _test_create_file_descriptor(metadata: Dict[int, Any], file_name: str) -> str: + pool = descriptor_pool.Default() + try: + pool.FindMessageTypeByName(f"{file_name}.test") + except KeyError: + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.name = file_name + file_descriptor.package = file_name + serialization_descriptors._create_message_type(metadata, "test", file_descriptor) + pool.Add(file_descriptor) + return file_name + ".test" From 55726f69779388695cd8179f17243c1c45868535 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 15 Jul 2024 14:08:22 -0500 Subject: [PATCH 16/25] Deleted serialization_strategy with default_value --- .../_internal/parameter/_serializer_types.py | 23 ------------------- ...tion_strategy.py => test_default_value.py} | 0 2 files changed, 23 deletions(-) delete mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py rename packages/service/tests/unit/{test_serialization_strategy.py => test_default_value.py} (100%) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py deleted file mode 100644 index e3dcd1d87..000000000 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import sys -import typing -from typing import Any, Callable, Dict - -from google.protobuf.descriptor import FieldDescriptor -from google.protobuf.message import Message - -if typing.TYPE_CHECKING: - if sys.version_info >= (3, 10): - from typing import TypeAlias - else: - from typing_extensions import TypeAlias - - -Key: TypeAlias = FieldDescriptor -WriteFunction: TypeAlias = Callable[[bytes], int] - -Decoder: TypeAlias = Callable[[memoryview, int, int, Message, Dict[Key, Any]], int] -PartialDecoderConstructor: TypeAlias = Callable[[int, Key], Decoder] -NewDefault: TypeAlias = Callable[[Message], Message] -DecoderConstructor: TypeAlias = Callable[[int, bool, bool, Key, NewDefault], Decoder] diff --git a/packages/service/tests/unit/test_serialization_strategy.py b/packages/service/tests/unit/test_default_value.py similarity index 100% rename from packages/service/tests/unit/test_serialization_strategy.py rename to packages/service/tests/unit/test_default_value.py From 7e910f57e1ba04729947360c2db49bc93d6116d6 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 15 Jul 2024 14:10:52 -0500 Subject: [PATCH 17/25] Deleted test_serializer --- .../service/tests/unit/test_serializer.py | 407 ------------------ 1 file changed, 407 deletions(-) delete mode 100644 packages/service/tests/unit/test_serializer.py diff --git a/packages/service/tests/unit/test_serializer.py b/packages/service/tests/unit/test_serializer.py deleted file mode 100644 index 304368e60..000000000 --- a/packages/service/tests/unit/test_serializer.py +++ /dev/null @@ -1,407 +0,0 @@ -"""Contains tests to validate serializer.py.""" - -from enum import Enum, IntEnum -from typing import Any, Dict, Sequence - -import pytest -from google.protobuf import any_pb2, descriptor_pb2, descriptor_pool, type_pb2 - -from ni_measurement_plugin_sdk_service._annotations import ( - ENUM_VALUES_KEY, - TYPE_SPECIALIZATION_KEY, -) -from ni_measurement_plugin_sdk_service._internal.parameter import ( - decoder, - serialization_descriptors, -) -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( - ParameterMetadata, - TypeSpecialization, -) -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( - xydata_pb2, -) -from tests.utilities.stubs.serialization import test_pb2 -from tests.utilities.stubs.serialization.bigmessage_pb2 import BigMessage - - -class DifferentColor(Enum): - """Non-primary colors used for testing enum-typed config and output.""" - - PURPLE = 0 - ORANGE = 1 - TEAL = 2 - BROWN = 3 - - -class Countries(IntEnum): - """Countries enum used for testing enum-typed config and output.""" - - AMERICA = 0 - TAIWAN = 1 - AUSTRALIA = 2 - CANADA = 3 - - -double_xy_data = xydata_pb2.DoubleXYData() -double_xy_data.x_data.append(4) -double_xy_data.y_data.append(6) - -double_xy_data2 = xydata_pb2.DoubleXYData() -double_xy_data2.x_data.append(8) -double_xy_data2.y_data.append(10) - -double_xy_data_array = [double_xy_data, double_xy_data2] - -# This should match the number of fields in bigmessage.proto. -BIG_MESSAGE_SIZE = 100 - - -@pytest.mark.parametrize( - "values", - [ - [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1.0], - [5.5, 3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1", "String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ] - ], -) -def test___serializer___deserialize_parameter___successful_deserialization(values): - parameter = _get_test_parameter_by_id(values) - grpc_serialized_data = _get_grpc_serialized_data(values) - service_name = _test_create_file_descriptor(list(parameter.values()), "deserializeparameter") - - parameter_value_by_id = decoder.deserialize_parameters( - parameter, - grpc_serialized_data, - service_name=service_name, - ) - assert list(parameter_value_by_id.values()) == values - - -def test___empty_buffer___deserialize_parameters___returns_zero_or_empty(): - # Note that we set nonzero defaults to validate that we are getting zero-values - # as opposed to simply getting the defaults. - nonzero_defaults = [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1.0], - [5.5, 3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1", "String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ] - parameter = _get_test_parameter_by_id(nonzero_defaults) - service_name = _test_create_file_descriptor(list(parameter.values()), "emptybuffer") - parameter_value_by_id = decoder.deserialize_parameters( - parameter, bytes(), service_name=service_name - ) - - for key, value in parameter_value_by_id.items(): - parameter_metadata = parameter[key] - if parameter_metadata.repeated: - assert value == list() - elif parameter_metadata.type == type_pb2.Field.TYPE_ENUM: - assert value.value == 0 - elif parameter_metadata.type == type_pb2.Field.TYPE_STRING: - assert value == "" - elif parameter_metadata.type == type_pb2.Field.TYPE_MESSAGE: - assert value is None - else: - assert value == 0 - - -def test___big_message___deserialize_parameters___returns_parameter_value_by_id() -> None: - parameter_metadata_by_id = _get_big_message_metadata_by_id() - values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] - message = _get_big_message(values) - serialized_data = message.SerializeToString() - expected_parameter_value_by_id = {i + 1: value for (i, value) in enumerate(values)} - service_name = _test_create_file_descriptor( - list(parameter_metadata_by_id.values()), "bigmessage" - ) - - parameter_value_by_id = decoder.deserialize_parameters( - parameter_metadata_by_id, - serialized_data, - service_name=service_name, - ) - - assert parameter_value_by_id == pytest.approx(expected_parameter_value_by_id) - - -def _get_grpc_serialized_data(values): - grpc_parameter = _get_test_grpc_message(values) - parameter_any = any_pb2.Any() - parameter_any.Pack(grpc_parameter) - grpc_serialized_data = parameter_any.value - return grpc_serialized_data - - -def _get_test_parameter_by_id(default_values): - parameter_by_id = { - 1: ParameterMetadata( - display_name="float_data", - type=type_pb2.Field.TYPE_FLOAT, - repeated=False, - default_value=default_values[0], - annotations={}, - ), - 2: ParameterMetadata( - display_name="double_data", - type=type_pb2.Field.TYPE_DOUBLE, - repeated=False, - default_value=default_values[1], - annotations={}, - ), - 3: ParameterMetadata( - display_name="int32_data", - type=type_pb2.Field.TYPE_INT32, - repeated=False, - default_value=default_values[2], - annotations={}, - ), - 4: ParameterMetadata( - display_name="uint32_data", - type=type_pb2.Field.TYPE_INT64, - repeated=False, - default_value=default_values[3], - annotations={}, - ), - 5: ParameterMetadata( - display_name="int64_data", - type=type_pb2.Field.TYPE_UINT32, - repeated=False, - default_value=default_values[4], - annotations={}, - ), - 6: ParameterMetadata( - display_name="uint64_data", - type=type_pb2.Field.TYPE_UINT64, - repeated=False, - default_value=default_values[5], - annotations={}, - ), - 7: ParameterMetadata( - display_name="bool_data", - type=type_pb2.Field.TYPE_BOOL, - repeated=False, - default_value=default_values[6], - annotations={}, - ), - 8: ParameterMetadata( - display_name="string_data", - type=type_pb2.Field.TYPE_STRING, - repeated=False, - default_value=default_values[7], - annotations={}, - ), - 9: ParameterMetadata( - display_name="double_array_data", - type=type_pb2.Field.TYPE_DOUBLE, - repeated=True, - default_value=default_values[8], - annotations={}, - ), - 10: ParameterMetadata( - display_name="float_array_data", - type=type_pb2.Field.TYPE_FLOAT, - repeated=True, - default_value=default_values[9], - annotations={}, - ), - 11: ParameterMetadata( - display_name="int32_array_data", - type=type_pb2.Field.TYPE_INT32, - repeated=True, - default_value=default_values[10], - annotations={}, - ), - 12: ParameterMetadata( - display_name="uint32_array_data", - type=type_pb2.Field.TYPE_UINT32, - repeated=True, - default_value=default_values[11], - annotations={}, - ), - 13: ParameterMetadata( - display_name="int64_array_data", - type=type_pb2.Field.TYPE_INT64, - repeated=True, - default_value=default_values[12], - annotations={}, - ), - 14: ParameterMetadata( - display_name="uint64_array_data", - type=type_pb2.Field.TYPE_UINT64, - repeated=True, - default_value=default_values[13], - annotations={}, - ), - 15: ParameterMetadata( - display_name="bool_array_data", - type=type_pb2.Field.TYPE_BOOL, - repeated=True, - default_value=default_values[14], - annotations={}, - ), - 16: ParameterMetadata( - display_name="string_array_data", - type=type_pb2.Field.TYPE_STRING, - repeated=True, - default_value=default_values[15], - annotations={}, - ), - 17: ParameterMetadata( - display_name="enum_data", - type=type_pb2.Field.TYPE_ENUM, - repeated=False, - default_value=default_values[16], - annotations={ - TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, - ENUM_VALUES_KEY: '{"PURPLE": 0, "ORANGE": 1, "TEAL": 2, "BROWN": 3}', - }, - ), - 18: ParameterMetadata( - display_name="enum_array_data", - type=type_pb2.Field.TYPE_ENUM, - repeated=True, - default_value=default_values[17], - annotations={ - TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, - ENUM_VALUES_KEY: '{"PURPLE": 0, "ORANGE": 1, "TEAL": 2, "BROWN": 3}', - }, - ), - 19: ParameterMetadata( - display_name="int_enum_data", - type=type_pb2.Field.TYPE_ENUM, - repeated=False, - default_value=default_values[18], - annotations={ - TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, - ENUM_VALUES_KEY: '{"AMERICA": 0, "TAIWAN": 1, "AUSTRALIA": 2, "CANADA": 3}', - }, - ), - 20: ParameterMetadata( - display_name="int_enum_array_data", - type=type_pb2.Field.TYPE_ENUM, - repeated=True, - default_value=default_values[19], - annotations={ - TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, - ENUM_VALUES_KEY: '{"AMERICA": 0, "TAIWAN": 1, "AUSTRALIA": 2, "CANADA": 3}', - }, - ), - 21: ParameterMetadata( - display_name="xy_data", - type=type_pb2.Field.TYPE_MESSAGE, - repeated=False, - default_value=default_values[20], - annotations={}, - message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, - ), - 22: ParameterMetadata( - display_name="xy_data_array", - type=type_pb2.Field.TYPE_MESSAGE, - repeated=True, - default_value=default_values[21], - annotations={}, - message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, - ), - } - return parameter_by_id - - -def _get_test_grpc_message(test_values): - parameter = test_pb2.MeasurementParameter() - parameter.float_data = test_values[0] - parameter.double_data = test_values[1] - parameter.int32_data = test_values[2] - parameter.uint32_data = test_values[3] - parameter.int64_data = test_values[4] - parameter.uint64_data = test_values[5] - parameter.bool_data = test_values[6] - parameter.string_data = test_values[7] - parameter.double_array_data.extend(test_values[8]) - parameter.float_array_data.extend(test_values[9]) - parameter.int32_array_data.extend(test_values[10]) - parameter.uint32_array_data.extend(test_values[11]) - parameter.int64_array_data.extend(test_values[12]) - parameter.uint64_array_data.extend(test_values[13]) - parameter.bool_array_data.extend(test_values[14]) - parameter.string_array_data.extend(test_values[15]) - parameter.enum_data = test_values[16].value - parameter.enum_array_data.extend(list(map(lambda x: x.value, test_values[17]))) - parameter.int_enum_data = test_values[18].value - parameter.int_enum_array_data.extend(list(map(lambda x: x.value, test_values[19]))) - parameter.xy_data.x_data.append(test_values[20].x_data[0]) - parameter.xy_data.y_data.append(test_values[20].y_data[0]) - parameter.xy_data_array.extend(test_values[21]) - return parameter - - -def _get_big_message_metadata_by_id() -> Dict[int, ParameterMetadata]: - return { - i - + 1: ParameterMetadata( - display_name=f"field{i + 1}", - type=type_pb2.Field.TYPE_DOUBLE, - repeated=False, - default_value=-1.0, - annotations={}, - ) - for i in range(BIG_MESSAGE_SIZE) - } - - -def _get_big_message(values: Sequence[float]) -> BigMessage: - assert len(values) == BIG_MESSAGE_SIZE - return BigMessage(**{f"field{i + 1}": value for (i, value) in enumerate(values)}) - - -def _test_create_file_descriptor(metadata: Dict[int, Any], file_name: str) -> str: - pool = descriptor_pool.Default() - try: - pool.FindMessageTypeByName(f"{file_name}.test") - except KeyError: - file_descriptor = descriptor_pb2.FileDescriptorProto() - file_descriptor.name = file_name - file_descriptor.package = file_name - serialization_descriptors._create_message_type(metadata, "test", file_descriptor) - pool.Add(file_descriptor) - return file_name + ".test" From 76d4cd8722d9d17ef1b0ec39ccca503196f1539f Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 15 Jul 2024 14:16:20 -0500 Subject: [PATCH 18/25] Deleted _message.py --- .../_internal/parameter/_message.py | 180 ------------------ 1 file changed, 180 deletions(-) delete mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py deleted file mode 100644 index 83e2363e0..000000000 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py +++ /dev/null @@ -1,180 +0,0 @@ -import struct -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from google.protobuf.internal import encoder, wire_format -from google.protobuf.message import Message - -from ni_measurement_plugin_sdk_service._internal.parameter._decoder_types import ( - Decoder, - Key, - NewDefault, - WriteFunction, -) - - -def _message_encoder_constructor( - field_index: int, is_repeated: bool, is_packed: bool -) -> Callable[[WriteFunction, Union[Message, List[Message]], bool], int]: - """Mimics google.protobuf.internal.MessageEncoder. - - This function was forked in order to call SerializeToString instead of _InternalSerialize. - - _InternalSerialize is only defined for the pure-Python protobuf implementation. Our child - messages (like DoubleXYData) are defined in .proto files, so they use whichever protobuf - implementation that google.protobuf.internal.api_implementation chooses (usually upb). - """ - tag = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED) - encode_varint = _varint_encoder() - - if is_repeated: - - def _encode_repeated_message( - write: WriteFunction, value: Union[Message, List[Message]], deterministic: bool - ) -> int: - bytes_written = 0 - for element in cast(List[Message], value): - write(tag) - bytes = element.SerializeToString() - encode_varint(write, len(bytes), deterministic) - bytes_written += write(bytes) - return bytes_written - - return _encode_repeated_message - else: - - def _encode_message( - write: WriteFunction, value: Union[Message, List[Message]], deterministic: bool - ) -> int: - write(tag) - bytes = cast(Message, value).SerializeToString() - encode_varint(write, len(bytes), deterministic) - return write(bytes) - - return _encode_message - - -def _varint_encoder() -> Callable[[WriteFunction, int, Optional[bool]], int]: - """Return an encoder for a basic varint value (does not include tag). - - From google.protobuf.internal.encoder.py _VarintEncoder - """ - local_int2byte = struct.Struct(">B").pack - - def encode_varint( - write: WriteFunction, value: int, unused_deterministic: Optional[bool] = None - ) -> int: - bits = value & 0x7F - value >>= 7 - while value: - write(local_int2byte(0x80 | bits)) - bits = value & 0x7F - value >>= 7 - return write(local_int2byte(bits)) - - return encode_varint - - -def _message_decoder_constructor( - field_index: int, is_repeated: bool, is_packed: bool, key: Key, new_default: NewDefault -) -> Decoder: - """Mimics google.protobuf.internal.MessageDecoder. - - This function was forked in order to call ParseFromString instead of _InternalParse. - - _InternalParse is only defined for the pure-Python protobuf implementation. Our child messages - (like DoubleXYData) are defined in .proto files, so they use whichever protobuf implementation - that google.protobuf.internal.api_implementation chooses (usually upb). - """ - if is_repeated: - tag_bytes = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED) - tag_len = len(tag_bytes) - - def _decode_repeated_message( - buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any] - ) -> int: - decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int) - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, []) - while 1: - parsed_value = new_default(message) - # Read length. - (size, pos) = decode_varint(buffer, pos) - new_pos = pos + size - if new_pos > end: - raise ValueError("Error decoding a message. Message is truncated.") - parsed_bytes = parsed_value.ParseFromString(buffer[pos:new_pos]) - if parsed_bytes != size: - raise ValueError("Parsed incorrect number of bytes.") - value.append(parsed_value) - # Predict that the next tag is another copy of the same repeated field. - pos = new_pos + tag_len - if buffer[new_pos:pos] != tag_bytes or new_pos == end: - # Prediction failed. Return. - return new_pos - - return _decode_repeated_message - else: - - def _decode_message( - buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any] - ) -> int: - decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int) - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, new_default(message)) - # Read length. - (size, pos) = decode_varint(buffer, pos) - new_pos = pos + size - if new_pos > end: - raise ValueError("Error decoding a message. Message is truncated.") - parsed_bytes = value.ParseFromString(buffer[pos:new_pos]) - if parsed_bytes != size: - raise ValueError("Parsed incorrect number of bytes.") - return new_pos - - return _decode_message - - -T = TypeVar("T", bound="int") - - -def _varint_decoder(mask: int, result_type: Type[T]) -> Callable[[memoryview, int], Tuple[T, int]]: - """Return an encoder for a basic varint value (does not include tag). - - Decoded values will be bitwise-anded with the given mask before being - returned, e.g. to limit them to 32 bits. The returned decoder does not - take the usual "end" parameter -- the caller is expected to do bounds checking - after the fact (often the caller can defer such checking until later). The - decoder returns a (value, new_pos) pair. - - From google.protobuf.internal.decoder.py _VarintDecoder - """ - - def decode_varint(buffer: memoryview, pos: int) -> Tuple[T, int]: - result = 0 - shift = 0 - while 1: - b = buffer[pos] - result |= (b & 0x7F) << shift - pos += 1 - if not (b & 0x80): - result &= mask - result = result_type(result) - return (result, pos) - shift += 7 - if shift >= 64: - raise ValueError("Too many bytes when decoding varint: {shift}") - - return decode_varint From e68913db22332d6665bae39c18a1497cf1152e99 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Mon, 15 Jul 2024 15:52:15 -0500 Subject: [PATCH 19/25] Fixed type errors in test encoder/decoder and docstring --- .../_internal/parameter/serialization_descriptors.py | 2 +- packages/service/tests/unit/test_decoder.py | 4 ++-- packages/service/tests/unit/test_encoder.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py index bd06ef796..c96124c10 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -1,4 +1,4 @@ -"""Serialization Strategy.""" +"""Serialization Descriptors.""" from json import loads from typing import Any, List diff --git a/packages/service/tests/unit/test_decoder.py b/packages/service/tests/unit/test_decoder.py index 304368e60..edbfc9087 100644 --- a/packages/service/tests/unit/test_decoder.py +++ b/packages/service/tests/unit/test_decoder.py @@ -1,7 +1,7 @@ """Contains tests to validate serializer.py.""" from enum import Enum, IntEnum -from typing import Any, Dict, Sequence +from typing import List, Dict, Sequence import pytest from google.protobuf import any_pb2, descriptor_pb2, descriptor_pool, type_pb2 @@ -394,7 +394,7 @@ def _get_big_message(values: Sequence[float]) -> BigMessage: return BigMessage(**{f"field{i + 1}": value for (i, value) in enumerate(values)}) -def _test_create_file_descriptor(metadata: Dict[int, Any], file_name: str) -> str: +def _test_create_file_descriptor(metadata: List[ParameterMetadata], file_name: str) -> str: pool = descriptor_pool.Default() try: pool.FindMessageTypeByName(f"{file_name}.test") diff --git a/packages/service/tests/unit/test_encoder.py b/packages/service/tests/unit/test_encoder.py index c5f341ee1..861161f98 100644 --- a/packages/service/tests/unit/test_encoder.py +++ b/packages/service/tests/unit/test_encoder.py @@ -1,13 +1,14 @@ """Contains tests to validate serializer.py.""" from enum import Enum, IntEnum -from typing import Any, Dict +from typing import List import pytest from google.protobuf import descriptor_pb2, descriptor_pool from ni_measurement_plugin_sdk_service._internal.parameter import ( encoder, + metadata, serialization_descriptors, ) from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( @@ -247,7 +248,7 @@ def _validate_serialized_bytes(custom_serialized_bytes, values): assert grpc_serialized_data == custom_serialized_bytes -def _test_create_file_descriptor(metadata: Dict[int, Any], file_name: str) -> str: +def _test_create_file_descriptor(metadata: List[metadata.ParameterMetadata], file_name: str) -> str: pool = descriptor_pool.Default() try: pool.FindMessageTypeByName(f"{file_name}.test") From 661482d622c66a8ee781601fb2eab4614d7244bd Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Thu, 18 Jul 2024 12:16:10 -0500 Subject: [PATCH 20/25] Renamed and cleaned helper functions, added initalize() in metadata. --- .../_internal/grpc_servicer.py | 25 ++- .../_internal/parameter/decoder.py | 57 ++----- .../_internal/parameter/default_value.py | 4 +- .../_internal/parameter/encoder.py | 11 +- .../_internal/parameter/metadata.py | 28 +++- .../parameter/serialization_descriptors.py | 145 +++++------------- .../measurement/service.py | 4 +- packages/service/tests/unit/test_decoder.py | 57 +++---- packages/service/tests/unit/test_encoder.py | 8 +- 9 files changed, 127 insertions(+), 212 deletions(-) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py index e75293223..ac78cbe3c 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py @@ -205,7 +205,7 @@ def GetMetadata( # noqa: N802 - function name should be lowercase measurement_signature.configuration_parameters.append(configuration_parameter) measurement_signature.configuration_defaults.value = encoder.serialize_default_values( - self._configuration_metadata, self._get_service_name() + ".Inputs" + self._configuration_metadata, self._service_info.service_class + ".Configurations" ) for field_number, output_metadata in self._output_metadata.items(): @@ -238,7 +238,7 @@ def Measure( # noqa: N802 - function name should be lowercase mapping_by_id = decoder.deserialize_parameters( self._configuration_metadata, request.configuration_parameters.value, - self._get_service_name() + ".Inputs", + self._service_info.service_class + ".Configurations", ) mapping_by_variable_name = _get_mapping_by_parameter_name( mapping_by_id, self._measure_function @@ -265,17 +265,16 @@ def Measure( # noqa: N802 - function name should be lowercase measurement_service_context.get().mark_complete() measurement_service_context.reset(token) - def _serialize_response(self, outputs: Any) -> v1_measurement_service_pb2.MeasureResponse: + def _serialize_response( + self, + outputs: Any, + ) -> v1_measurement_service_pb2.MeasureResponse: return v1_measurement_service_pb2.MeasureResponse( outputs=_serialize_outputs( - self._output_metadata, outputs, self._get_service_name() + ".Outputs" + self._output_metadata, outputs, self._service_info.service_class + ".Outputs" ) ) - def _get_service_name(self) -> str: - service_name = "".join(char for char in self._service_info.service_class if char.isalpha()) - return service_name - class MeasurementServiceServicerV2(v2_measurement_service_pb2_grpc.MeasurementServiceServicer): """Measurement v2 servicer.""" @@ -323,7 +322,7 @@ def GetMetadata( # noqa: N802 - function name should be lowercase measurement_signature.configuration_parameters.append(configuration_parameter) measurement_signature.configuration_defaults.value = encoder.serialize_default_values( - self._configuration_metadata, self._get_service_name() + ".Inputs" + self._configuration_metadata, self._service_info.service_class + ".Configurations" ) for field_number, output_metadata in self._output_metadata.items(): @@ -358,7 +357,7 @@ def Measure( # noqa: N802 - function name should be lowercase mapping_by_id = decoder.deserialize_parameters( self._configuration_metadata, request.configuration_parameters.value, - self._get_service_name() + ".Inputs", + self._service_info.service_class + ".Configurations", ) mapping_by_variable_name = _get_mapping_by_parameter_name( mapping_by_id, self._measure_function @@ -387,10 +386,6 @@ def Measure( # noqa: N802 - function name should be lowercase def _serialize_response(self, outputs: Any) -> v2_measurement_service_pb2.MeasureResponse: return v2_measurement_service_pb2.MeasureResponse( outputs=_serialize_outputs( - self._output_metadata, outputs, self._get_service_name() + ".Outputs" + self._output_metadata, outputs, self._service_info.service_class + ".Outputs" ) ) - - def _get_service_name(self) -> str: - service_name = "".join(char for char in self._service_info.service_class if char.isalpha()) - return service_name diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py index fe5692605..d0b2e5c81 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py @@ -1,12 +1,10 @@ """Parameter Serializer.""" -from json import loads -from typing import Any, Dict +from typing import Any, Dict, List, Union from google.protobuf import descriptor_pool, message_factory from google.protobuf.descriptor_pb2 import FieldDescriptorProto -from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) @@ -30,7 +28,7 @@ def deserialize_parameters( service_name (str): Unique service name. Returns: - Dict[int, Any]: Deserialized parameters by ID + Dict[int, Any]: Deserialized parameters by ID. """ pool = descriptor_pool.Default() message_proto = pool.FindMessageTypeByName(service_name) @@ -40,14 +38,14 @@ def deserialize_parameters( message_instance.ParseFromString(parameter_bytes) for i in message_proto.fields_by_number.keys(): parameter_metadata = parameter_metadata_dict[i] - field_name = parameter_metadata.sanitized_display_name() + field_name = parameter_metadata.field_name + enum_type = _get_enum_type(parameter_metadata) value = getattr(message_instance, field_name) - if ( - parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM - and _get_enum_type(parameter_metadata) is not int - ): - parameter_values[i] = _deserialize_enum_parameter(parameter_metadata, value) + if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM and enum_type is not int: + parameter_values[i] = _deserialize_enum_parameter( + parameter_metadata.repeated, value, enum_type + ) elif ( parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE and not parameter_metadata.repeated @@ -59,39 +57,14 @@ def deserialize_parameters( return parameter_values -def _deserialize_enum_parameter(parameter_metadata: ParameterMetadata, field_value: Any) -> Any: +def _deserialize_enum_parameter( + repeated: bool, field_value: Union[List[int], int], enum_type: Any +) -> Union[List[Any], Any]: """Convert all enums into the user defined enum type. - Args: - parameter_metadata (ParameterMetadata): Metadata of current enum value. - - field_value (Any): Value of current field. - - Returns: - Any: Enum type or a list of enum types. - """ - enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) - enum_type = _get_enum_type(parameter_metadata) - if parameter_metadata.repeated: - return [_get_enum_field(enum_dict, enum_type, value) for value in field_value] - else: - return _get_enum_field(enum_dict, enum_type, field_value) - - -def _get_enum_field(enum_dict: Dict[Any, int], enum_type: Any, field_value: int) -> Any: - """Get enum type and value from 'field_value'. - - Args: - enum_dict (Dict[Any, int]): List enum class of 'field_value'. - - enum_type (Any): 'field_value' enum class name. - - field_value (int): Default value of current field. - Returns: - Any: Enum type of 'field_value' from 'enum_dict' with the enum value. + Union[List[Any], Any]: Enum type or a list of enum types. """ - for name in enum_dict.keys(): - enum_value = getattr(enum_type, name) - if field_value == enum_value.value: - return enum_value + if repeated: + return [enum_type(value) for value in field_value] + return enum_type(field_value) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py index 168f4db35..8a54edf3a 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py @@ -2,7 +2,7 @@ from google.protobuf import type_pb2 -_type_default_mapping = { +_TYPE_DEFULAT_MAPPING = { type_pb2.Field.TYPE_FLOAT: float(), type_pb2.Field.TYPE_DOUBLE: float(), type_pb2.Field.TYPE_INT32: int(), @@ -19,4 +19,4 @@ def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any """Get the default value for the give type.""" if repeated: return list() - return _type_default_mapping.get(type) + return _TYPE_DEFULAT_MAPPING.get(type) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py index 97c6ef010..16e246e9d 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -37,7 +37,7 @@ def serialize_parameters( for i, parameter in enumerate(parameter_values, start=1): parameter_metadata = parameter_metadata_dict[i] - field_name = parameter_metadata.sanitized_display_name() + field_name = parameter_metadata.field_name parameter = _get_enum_values(param=parameter) type_default_value = get_type_default(parameter_metadata.type, parameter_metadata.repeated) @@ -74,14 +74,7 @@ def serialize_default_values( def _get_enum_values(param: Any) -> Any: - """Get's value of an enum. - - Args: - param (Any): A value/parameter of parameter_values. - - Returns: - Any: An enum value or a list of enums or the 'param'. - """ + """Get's value of an enum.""" if param == []: return param if isinstance(param, list) and isinstance(param[0], Enum): diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index 8227a6e88..652fa0504 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -42,14 +42,30 @@ class ParameterMetadata(NamedTuple): Required when 'type' is Kind.TypeMessage. Ignored for any other 'type'. """ - def sanitized_display_name(self) -> str: - """Parameter display name of alpha/numerical characters. + field_name: str = "" + """display_name in snake_case format.""" - Returns: - str: Alpha/numerical characters of 'display_name'. - """ - return "".join(char for char in self.display_name if char.isalnum()) +def initialize( + display_name: str, + type: type_pb2.Field.Kind.ValueType, + repeated: bool, + default_value: Any, + annotations: Dict[str, str], + message_type: str = "", +) -> ParameterMetadata: + """Initialize ParameterMetadata with field_name.""" + underscore_display_name = display_name.replace(" ", "_") + if all(char.isalnum() or char == "_" for char in underscore_display_name): + field_name = underscore_display_name + else: + field_name = "".join( + char for char in underscore_display_name if char.isalnum() or char == "_" + ) + + return ParameterMetadata( + display_name, type, repeated, default_value, annotations, message_type, field_name + ) def validate_default_value_type(parameter_metadata: ParameterMetadata) -> None: diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py index c96124c10..ec7f10328 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -1,76 +1,40 @@ """Serialization Descriptors.""" from json import loads -from typing import Any, List +from typing import List from google.protobuf import descriptor_pb2, descriptor_pool -from google.protobuf.descriptor_pb2 import FieldDescriptorProto +from google.protobuf.descriptor_pb2 import FieldDescriptorProto, DescriptorProto from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) -from tests.utilities.stubs.loopback.types_pb2 import ProtobufColor -def _get_output_enum_type( - metadata_enum_list: List[str], - file_descriptor: descriptor_pb2.FileDescriptorProto, -) -> Any: - """Get's matching enum class from 'file_descriptor'. - - Args: - metadata_enum_list (List[str]): Enum names from metadata.annotations. - - file_descriptor: Descriptor of proto file. - - Returns: - Any: Matching enum class in a str type or None when enum is protobuf. - """ - for enum_type in file_descriptor.enum_type: - enum_names = [enum_value.name for enum_value in enum_type.value] - if sorted(metadata_enum_list) == sorted(enum_names): - return enum_type.name - return None - - -def _create_enum_type( +def _create_enum_type_class( file_descriptor: descriptor_pb2.FileDescriptorProto, parameter_metadata: ParameterMetadata, field_descriptor: FieldDescriptorProto, ) -> None: - """Implement a enum class in 'file_descriptor'. - - Args: - file_descriptor (FileDescriptorProto): Descriptor of a proto file. - - parmeter_metadata (ParameterMetadata): Metadata of current field. - - field_descriptor (FieldDescriptorProto): Descriptor of a field. - - Returns: - None: Only creates a enum class in file_descriptor. - """ + """Implement a enum class in 'file_descriptor'.""" enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) - if parameter_metadata.default_value is None: - enum_type_name = _get_output_enum_type( - metadata_enum_list=[enum for enum in enum_dict.keys()], file_descriptor=file_descriptor - ) - else: - enum_type_name = _get_enum_type(parameter_metadata).__name__ - - if enum_type_name == "int" or enum_type_name is None: - field_descriptor.type_name = ProtobufColor.DESCRIPTOR.full_name - elif enum_type_name not in [enum_type.name for enum_type in file_descriptor.enum_type]: + enum_type_name = _get_enum_type(parameter_metadata).__name__ + + # if enum is a protobuf then enum_type_name is 1st letter of each enum name + # e.g. {"NONE": 0, "RED": 1, "GREEN": 2} -> NRG + if enum_type_name == "int" or enum_type_name == "NoneType": + enum_field_names = list(enum_dict.keys())[:] + enum_type_name = "".join(name[0] for name in enum_field_names) + + if enum_type_name not in [enum_type.name for enum_type in file_descriptor.enum_type]: enum_descriptor = file_descriptor.enum_type.add() enum_descriptor.name = enum_type_name for name, number in enum_dict.items(): enum_value_descriptor = enum_descriptor.value.add() - enum_value_descriptor.name = name + enum_value_descriptor.name = f"{enum_type_name}_{name}" enum_value_descriptor.number = number - field_descriptor.type_name = enum_descriptor.name - else: - field_descriptor.type_name = enum_type_name + field_descriptor.type_name = enum_type_name def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: @@ -81,23 +45,12 @@ def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: def _create_field( - message_proto: Any, metadata: ParameterMetadata, index: int + message_proto: DescriptorProto, metadata: ParameterMetadata, index: int ) -> FieldDescriptorProto: - """Implement a field in 'message_proto'. - - Args: - message_proto (message_type): A message instance in a file descriptor proto. - - metadata (ParameterMetadata): Metadata of 'param'. - - index (int): 'param' index in parameter_values - - Returns: - FieldDescriptorProto: field_descriptor of 'param'. - """ + """Implement a field in 'message_proto'.""" field_descriptor = message_proto.field.add() field_descriptor.number = index - field_descriptor.name = metadata.sanitized_display_name() + field_descriptor.name = metadata.field_name field_descriptor.type = metadata.type if metadata.repeated: @@ -111,6 +64,28 @@ def _create_field( return field_descriptor +def _create_message_type( + parameter_metadata: List[ParameterMetadata], + message_name: str, + file_descriptor: descriptor_pb2.FileDescriptorProto, +) -> None: + """Creates a message type with fields intialized in 'file_descriptor'.""" + message_proto = file_descriptor.message_type.add() + message_proto.name = message_name + + # Initialize the message with fields defined + for i, metadata in enumerate(parameter_metadata): + field_descriptor = _create_field( + message_proto=message_proto, metadata=metadata, index=i + 1 + ) + if metadata.type == FieldDescriptorProto.TYPE_ENUM: + _create_enum_type_class( + file_descriptor=file_descriptor, + parameter_metadata=metadata, + field_descriptor=field_descriptor, + ) + + def create_file_descriptor( service_name: str, output_metadata: List[ParameterMetadata], @@ -127,51 +102,13 @@ def create_file_descriptor( input_metadata (List[ParameterMetadata]): Metadata of input parameters. pool (DescriptorPool): Descriptor pool holding file descriptors and enum classes. - - Returns: - None: Only creates a file and two message descriptors. """ - service_name = "".join(char for char in service_name if char.isalpha()) try: pool.FindFileByName(service_name) except KeyError: file_descriptor = descriptor_pb2.FileDescriptorProto() file_descriptor.name = service_name file_descriptor.package = service_name - - _create_message_type(input_metadata, "Inputs", file_descriptor) + _create_message_type(input_metadata, "Configurations", file_descriptor) _create_message_type(output_metadata, "Outputs", file_descriptor) pool.Add(file_descriptor) - - -def _create_message_type( - parameter_metadata: List[ParameterMetadata], - message_name: str, - file_descriptor: descriptor_pb2.FileDescriptorProto, -) -> None: - """Creates a message descriptor with the fields defined in a file descriptor proto. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - - message_name (str): Service class name. - - file_descriptor (descriptor_pb2.FileDescriptorProto): Descriptor of a proto file. - - Returns: - None: Only creates a message_type in 'file_descriptor'. - """ - message_proto = file_descriptor.message_type.add() - message_proto.name = message_name - - # Initialize the message with fields defined - for i, metadata in enumerate(parameter_metadata): - field_descriptor = _create_field( - message_proto=message_proto, metadata=metadata, index=i + 1 - ) - if metadata.type == FieldDescriptorProto.TYPE_ENUM: - _create_enum_type( - file_descriptor=file_descriptor, - parameter_metadata=metadata, - field_descriptor=field_descriptor, - ) diff --git a/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py b/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py index ab8718563..d1a445da1 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py +++ b/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py @@ -415,7 +415,7 @@ def configuration( annotations = self._make_annotations_dict( data_type_info.type_specialization, instrument_type=instrument_type, enum_type=enum_type ) - parameter = parameter_metadata.ParameterMetadata( + parameter = parameter_metadata.initialize( display_name, data_type_info.grpc_field_type, data_type_info.repeated, @@ -475,7 +475,7 @@ def output( annotations = self._make_annotations_dict( data_type_info.type_specialization, enum_type=enum_type ) - parameter = parameter_metadata.ParameterMetadata( + parameter = parameter_metadata.initialize( display_name, data_type_info.grpc_field_type, data_type_info.repeated, diff --git a/packages/service/tests/unit/test_decoder.py b/packages/service/tests/unit/test_decoder.py index edbfc9087..23aee9ea3 100644 --- a/packages/service/tests/unit/test_decoder.py +++ b/packages/service/tests/unit/test_decoder.py @@ -1,7 +1,7 @@ """Contains tests to validate serializer.py.""" from enum import Enum, IntEnum -from typing import List, Dict, Sequence +from typing import Dict, List, Sequence import pytest from google.protobuf import any_pb2, descriptor_pb2, descriptor_pool, type_pb2 @@ -17,6 +17,7 @@ from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, TypeSpecialization, + initialize, ) from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( xydata_pb2, @@ -89,7 +90,7 @@ class Countries(IntEnum): def test___serializer___deserialize_parameter___successful_deserialization(values): parameter = _get_test_parameter_by_id(values) grpc_serialized_data = _get_grpc_serialized_data(values) - service_name = _test_create_file_descriptor(list(parameter.values()), "deserializeparameter") + service_name = _test_create_file_descriptor(list(parameter.values()), "deserialize_parameter") parameter_value_by_id = decoder.deserialize_parameters( parameter, @@ -127,7 +128,7 @@ def test___empty_buffer___deserialize_parameters___returns_zero_or_empty(): double_xy_data_array, ] parameter = _get_test_parameter_by_id(nonzero_defaults) - service_name = _test_create_file_descriptor(list(parameter.values()), "emptybuffer") + service_name = _test_create_file_descriptor(list(parameter.values()), "empty_buffer") parameter_value_by_id = decoder.deserialize_parameters( parameter, bytes(), service_name=service_name ) @@ -153,7 +154,7 @@ def test___big_message___deserialize_parameters___returns_parameter_value_by_id( serialized_data = message.SerializeToString() expected_parameter_value_by_id = {i + 1: value for (i, value) in enumerate(values)} service_name = _test_create_file_descriptor( - list(parameter_metadata_by_id.values()), "bigmessage" + list(parameter_metadata_by_id.values()), "big_message" ) parameter_value_by_id = decoder.deserialize_parameters( @@ -175,119 +176,119 @@ def _get_grpc_serialized_data(values): def _get_test_parameter_by_id(default_values): parameter_by_id = { - 1: ParameterMetadata( - display_name="float_data", + 1: initialize( + display_name="float_data!", type=type_pb2.Field.TYPE_FLOAT, repeated=False, default_value=default_values[0], annotations={}, ), - 2: ParameterMetadata( + 2: initialize( display_name="double_data", type=type_pb2.Field.TYPE_DOUBLE, repeated=False, default_value=default_values[1], annotations={}, ), - 3: ParameterMetadata( + 3: initialize( display_name="int32_data", type=type_pb2.Field.TYPE_INT32, repeated=False, default_value=default_values[2], annotations={}, ), - 4: ParameterMetadata( + 4: initialize( display_name="uint32_data", type=type_pb2.Field.TYPE_INT64, repeated=False, default_value=default_values[3], annotations={}, ), - 5: ParameterMetadata( + 5: initialize( display_name="int64_data", type=type_pb2.Field.TYPE_UINT32, repeated=False, default_value=default_values[4], annotations={}, ), - 6: ParameterMetadata( + 6: initialize( display_name="uint64_data", type=type_pb2.Field.TYPE_UINT64, repeated=False, default_value=default_values[5], annotations={}, ), - 7: ParameterMetadata( + 7: initialize( display_name="bool_data", type=type_pb2.Field.TYPE_BOOL, repeated=False, default_value=default_values[6], annotations={}, ), - 8: ParameterMetadata( + 8: initialize( display_name="string_data", type=type_pb2.Field.TYPE_STRING, repeated=False, default_value=default_values[7], annotations={}, ), - 9: ParameterMetadata( + 9: initialize( display_name="double_array_data", type=type_pb2.Field.TYPE_DOUBLE, repeated=True, default_value=default_values[8], annotations={}, ), - 10: ParameterMetadata( + 10: initialize( display_name="float_array_data", type=type_pb2.Field.TYPE_FLOAT, repeated=True, default_value=default_values[9], annotations={}, ), - 11: ParameterMetadata( + 11: initialize( display_name="int32_array_data", type=type_pb2.Field.TYPE_INT32, repeated=True, default_value=default_values[10], annotations={}, ), - 12: ParameterMetadata( + 12: initialize( display_name="uint32_array_data", type=type_pb2.Field.TYPE_UINT32, repeated=True, default_value=default_values[11], annotations={}, ), - 13: ParameterMetadata( + 13: initialize( display_name="int64_array_data", type=type_pb2.Field.TYPE_INT64, repeated=True, default_value=default_values[12], annotations={}, ), - 14: ParameterMetadata( + 14: initialize( display_name="uint64_array_data", type=type_pb2.Field.TYPE_UINT64, repeated=True, default_value=default_values[13], annotations={}, ), - 15: ParameterMetadata( + 15: initialize( display_name="bool_array_data", type=type_pb2.Field.TYPE_BOOL, repeated=True, default_value=default_values[14], annotations={}, ), - 16: ParameterMetadata( + 16: initialize( display_name="string_array_data", type=type_pb2.Field.TYPE_STRING, repeated=True, default_value=default_values[15], annotations={}, ), - 17: ParameterMetadata( + 17: initialize( display_name="enum_data", type=type_pb2.Field.TYPE_ENUM, repeated=False, @@ -297,7 +298,7 @@ def _get_test_parameter_by_id(default_values): ENUM_VALUES_KEY: '{"PURPLE": 0, "ORANGE": 1, "TEAL": 2, "BROWN": 3}', }, ), - 18: ParameterMetadata( + 18: initialize( display_name="enum_array_data", type=type_pb2.Field.TYPE_ENUM, repeated=True, @@ -307,7 +308,7 @@ def _get_test_parameter_by_id(default_values): ENUM_VALUES_KEY: '{"PURPLE": 0, "ORANGE": 1, "TEAL": 2, "BROWN": 3}', }, ), - 19: ParameterMetadata( + 19: initialize( display_name="int_enum_data", type=type_pb2.Field.TYPE_ENUM, repeated=False, @@ -317,7 +318,7 @@ def _get_test_parameter_by_id(default_values): ENUM_VALUES_KEY: '{"AMERICA": 0, "TAIWAN": 1, "AUSTRALIA": 2, "CANADA": 3}', }, ), - 20: ParameterMetadata( + 20: initialize( display_name="int_enum_array_data", type=type_pb2.Field.TYPE_ENUM, repeated=True, @@ -327,7 +328,7 @@ def _get_test_parameter_by_id(default_values): ENUM_VALUES_KEY: '{"AMERICA": 0, "TAIWAN": 1, "AUSTRALIA": 2, "CANADA": 3}', }, ), - 21: ParameterMetadata( + 21: initialize( display_name="xy_data", type=type_pb2.Field.TYPE_MESSAGE, repeated=False, @@ -335,7 +336,7 @@ def _get_test_parameter_by_id(default_values): annotations={}, message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, ), - 22: ParameterMetadata( + 22: initialize( display_name="xy_data_array", type=type_pb2.Field.TYPE_MESSAGE, repeated=True, @@ -378,7 +379,7 @@ def _get_test_grpc_message(test_values): def _get_big_message_metadata_by_id() -> Dict[int, ParameterMetadata]: return { i - + 1: ParameterMetadata( + + 1: initialize( display_name=f"field{i + 1}", type=type_pb2.Field.TYPE_DOUBLE, repeated=False, diff --git a/packages/service/tests/unit/test_encoder.py b/packages/service/tests/unit/test_encoder.py index 861161f98..5ab159d25 100644 --- a/packages/service/tests/unit/test_encoder.py +++ b/packages/service/tests/unit/test_encoder.py @@ -111,7 +111,7 @@ class Countries(IntEnum): def test___serializer___serialize_parameter___successful_serialization(test_values): default_values = test_values parameter = _get_test_parameter_by_id(default_values) - service_name = _test_create_file_descriptor(list(parameter.values()), "serializeparameter") + service_name = _test_create_file_descriptor(list(parameter.values()), "serialize_parameter") # Custom Serialization custom_serialized_bytes = encoder.serialize_parameters( @@ -178,7 +178,7 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu ) def test___serializer___serialize_default_parameter___successful_serialization(default_values): parameter = _get_test_parameter_by_id(default_values) - service_name = _test_create_file_descriptor(list(parameter.values()), "defaultserialize") + service_name = _test_create_file_descriptor(list(parameter.values()), "default_serialize") # Custom Serialization custom_serialized_bytes = encoder.serialize_default_values(parameter, service_name=service_name) @@ -191,7 +191,7 @@ def test___big_message___serialize_parameters___returns_serialized_data() -> Non values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] expected_message = _get_big_message(values) service_name = _test_create_file_descriptor( - list(parameter_metadata_by_id.values()), "bigmessage" + list(parameter_metadata_by_id.values()), "big_message" ) serialized_data = encoder.serialize_parameters( @@ -237,7 +237,7 @@ def test___serialize_parameter_multiple_times___returns_one_message_type(test_va for i in range(100): test___serializer___serialize_parameter___successful_serialization(test_values) pool = descriptor_pool.Default() - file_descriptor = pool.FindFileByName("serializeparameter") + file_descriptor = pool.FindFileByName("serialize_parameter") message_dict = file_descriptor.message_types_by_name assert len(message_dict) == 1 From 019cef97fed271ec3b2f73458812ec7be8f3850d Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Thu, 18 Jul 2024 15:51:08 -0500 Subject: [PATCH 21/25] Fixed sytleguide and type assignment. --- .../_internal/parameter/_get_type.py | 36 +++++++++++++++++++ .../_internal/parameter/decoder.py | 2 +- .../_internal/parameter/default_value.py | 22 ------------ .../_internal/parameter/encoder.py | 2 +- .../_internal/parameter/metadata.py | 4 +-- .../parameter/serialization_descriptors.py | 21 ++++------- .../service/tests/unit/test_default_value.py | 4 +-- 7 files changed, 49 insertions(+), 42 deletions(-) create mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py delete mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py new file mode 100644 index 000000000..c216738bb --- /dev/null +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py @@ -0,0 +1,36 @@ +from typing import Any + +from google.protobuf.descriptor_pb2 import FieldDescriptorProto +from google.protobuf.type_pb2 import Field + +_TYPE_DEFULAT_MAPPING = { + Field.TYPE_FLOAT: float(), + Field.TYPE_DOUBLE: float(), + Field.TYPE_INT32: int(), + Field.TYPE_INT64: int(), + Field.TYPE_UINT32: int(), + Field.TYPE_UINT64: int(), + Field.TYPE_BOOL: bool(), + Field.TYPE_STRING: str(), + Field.TYPE_ENUM: int(), +} + +_TYPE_FIELD_MAPPING = { + Field.TYPE_FLOAT: FieldDescriptorProto.TYPE_FLOAT, + Field.TYPE_DOUBLE: FieldDescriptorProto.TYPE_DOUBLE, + Field.TYPE_INT32: FieldDescriptorProto.TYPE_INT32, + Field.TYPE_INT64: FieldDescriptorProto.TYPE_INT64, + Field.TYPE_UINT32: FieldDescriptorProto.TYPE_UINT32, + Field.TYPE_UINT64: FieldDescriptorProto.TYPE_UINT64, + Field.TYPE_BOOL: FieldDescriptorProto.TYPE_BOOL, + Field.TYPE_STRING: FieldDescriptorProto.TYPE_STRING, + Field.TYPE_ENUM: FieldDescriptorProto.TYPE_ENUM, + Field.TYPE_MESSAGE: FieldDescriptorProto.TYPE_MESSAGE, +} + + +def get_type_default(type: Field.Kind.ValueType, repeated: bool) -> Any: + """Get the default value for the give type.""" + if repeated: + return list() + return _TYPE_DEFULAT_MAPPING.get(type) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py index d0b2e5c81..91959544d 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py @@ -58,7 +58,7 @@ def deserialize_parameters( def _deserialize_enum_parameter( - repeated: bool, field_value: Union[List[int], int], enum_type: Any + repeated: bool, field_value: Any, enum_type: Any ) -> Union[List[Any], Any]: """Convert all enums into the user defined enum type. diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py deleted file mode 100644 index 8a54edf3a..000000000 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/default_value.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any - -from google.protobuf import type_pb2 - -_TYPE_DEFULAT_MAPPING = { - type_pb2.Field.TYPE_FLOAT: float(), - type_pb2.Field.TYPE_DOUBLE: float(), - type_pb2.Field.TYPE_INT32: int(), - type_pb2.Field.TYPE_INT64: int(), - type_pb2.Field.TYPE_UINT32: int(), - type_pb2.Field.TYPE_UINT64: int(), - type_pb2.Field.TYPE_BOOL: bool(), - type_pb2.Field.TYPE_STRING: str(), - type_pb2.Field.TYPE_ENUM: int(), -} - - -def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: - """Get the default value for the give type.""" - if repeated: - return list() - return _TYPE_DEFULAT_MAPPING.get(type) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py index 16e246e9d..0fd12ead9 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -6,7 +6,7 @@ from google.protobuf import descriptor_pool, message_factory from google.protobuf.descriptor_pb2 import FieldDescriptorProto -from ni_measurement_plugin_sdk_service._internal.parameter.default_value import ( +from ni_measurement_plugin_sdk_service._internal.parameter._get_type import ( get_type_default, ) from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index 652fa0504..c9db44531 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -10,7 +10,7 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter.default_value import ( +from ni_measurement_plugin_sdk_service._internal.parameter._get_type import ( get_type_default, ) from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization @@ -60,7 +60,7 @@ def initialize( field_name = underscore_display_name else: field_name = "".join( - char for char in underscore_display_name if char.isalnum() or char == "_" + char for char in underscore_display_name if char.isalnum() or char == "_" ) return ParameterMetadata( diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py index ec7f10328..0a57fd1cf 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -4,9 +4,12 @@ from typing import List from google.protobuf import descriptor_pb2, descriptor_pool -from google.protobuf.descriptor_pb2 import FieldDescriptorProto, DescriptorProto +from google.protobuf.descriptor_pb2 import DescriptorProto, FieldDescriptorProto from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY +from ni_measurement_plugin_sdk_service._internal.parameter._get_type import ( + _TYPE_FIELD_MAPPING, +) from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) @@ -20,7 +23,7 @@ def _create_enum_type_class( """Implement a enum class in 'file_descriptor'.""" enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) enum_type_name = _get_enum_type(parameter_metadata).__name__ - + # if enum is a protobuf then enum_type_name is 1st letter of each enum name # e.g. {"NONE": 0, "RED": 1, "GREEN": 2} -> NRG if enum_type_name == "int" or enum_type_name == "NoneType": @@ -51,7 +54,7 @@ def _create_field( field_descriptor = message_proto.field.add() field_descriptor.number = index field_descriptor.name = metadata.field_name - field_descriptor.type = metadata.type + field_descriptor.type = _TYPE_FIELD_MAPPING[metadata.type] if metadata.repeated: field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED @@ -92,17 +95,7 @@ def create_file_descriptor( input_metadata: List[ParameterMetadata], pool: descriptor_pool.DescriptorPool, ) -> None: - """Creates two message types in one file descriptor proto. - - Args: - service_class_name (str): Unique service name. - - output_metadata (List[ParameterMetadata]): Metadata of output parameters. - - input_metadata (List[ParameterMetadata]): Metadata of input parameters. - - pool (DescriptorPool): Descriptor pool holding file descriptors and enum classes. - """ + """Creates two message types in one file descriptor proto.""" try: pool.FindFileByName(service_name) except KeyError: diff --git a/packages/service/tests/unit/test_default_value.py b/packages/service/tests/unit/test_default_value.py index 42678a6ac..1be45f278 100644 --- a/packages/service/tests/unit/test_default_value.py +++ b/packages/service/tests/unit/test_default_value.py @@ -4,7 +4,7 @@ from google.protobuf import type_pb2 from ni_measurement_plugin_sdk_service._internal.parameter import ( - default_value, + _get_type, ) @@ -25,6 +25,6 @@ ], ) def test___get_default_value___returns_type_defaults(type, is_repeated, expected_default_value): - test_default_value = default_value.get_type_default(type, is_repeated) + test_default_value = _get_type.get_type_default(type, is_repeated) assert test_default_value == expected_default_value From f351927596204cd0e4acfc4c6dba60f712fe6a64 Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Tue, 23 Jul 2024 10:24:29 -0500 Subject: [PATCH 22/25] Pass enum type in initialize() and moved it in ParameterMetadata --- .../_internal/parameter/_get_type.py | 6 +-- .../_internal/parameter/decoder.py | 28 +++++----- .../_internal/parameter/metadata.py | 50 ++++++++++-------- .../parameter/serialization_descriptors.py | 30 +++++------ .../measurement/service.py | 6 ++- packages/service/tests/unit/test_decoder.py | 51 ++++++++++--------- 6 files changed, 91 insertions(+), 80 deletions(-) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py index c216738bb..916680889 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py @@ -3,7 +3,7 @@ from google.protobuf.descriptor_pb2 import FieldDescriptorProto from google.protobuf.type_pb2 import Field -_TYPE_DEFULAT_MAPPING = { +_TYPE_DEFAULT_MAPPING = { Field.TYPE_FLOAT: float(), Field.TYPE_DOUBLE: float(), Field.TYPE_INT32: int(), @@ -15,7 +15,7 @@ Field.TYPE_ENUM: int(), } -_TYPE_FIELD_MAPPING = { +TYPE_FIELD_MAPPING = { Field.TYPE_FLOAT: FieldDescriptorProto.TYPE_FLOAT, Field.TYPE_DOUBLE: FieldDescriptorProto.TYPE_DOUBLE, Field.TYPE_INT32: FieldDescriptorProto.TYPE_INT32, @@ -33,4 +33,4 @@ def get_type_default(type: Field.Kind.ValueType, repeated: bool) -> Any: """Get the default value for the give type.""" if repeated: return list() - return _TYPE_DEFULAT_MAPPING.get(type) + return _TYPE_DEFAULT_MAPPING.get(type) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py index 91959544d..24d49eead 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py @@ -1,6 +1,6 @@ """Parameter Serializer.""" -from typing import Any, Dict, List, Union +from typing import Any, Dict from google.protobuf import descriptor_pool, message_factory from google.protobuf.descriptor_pb2 import FieldDescriptorProto @@ -8,9 +8,6 @@ from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) -from ni_measurement_plugin_sdk_service._internal.parameter.serialization_descriptors import ( - _get_enum_type, -) def deserialize_parameters( @@ -39,13 +36,10 @@ def deserialize_parameters( for i in message_proto.fields_by_number.keys(): parameter_metadata = parameter_metadata_dict[i] field_name = parameter_metadata.field_name - enum_type = _get_enum_type(parameter_metadata) value = getattr(message_instance, field_name) - if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM and enum_type is not int: - parameter_values[i] = _deserialize_enum_parameter( - parameter_metadata.repeated, value, enum_type - ) + if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM: + parameter_values[i] = _deserialize_enum_parameter(value, parameter_metadata) elif ( parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE and not parameter_metadata.repeated @@ -57,14 +51,18 @@ def deserialize_parameters( return parameter_values -def _deserialize_enum_parameter( - repeated: bool, field_value: Any, enum_type: Any -) -> Union[List[Any], Any]: +def _deserialize_enum_parameter(field_value: Any, metadata: ParameterMetadata) -> Any: """Convert all enums into the user defined enum type. Returns: Union[List[Any], Any]: Enum type or a list of enum types. """ - if repeated: - return [enum_type(value) for value in field_value] - return enum_type(field_value) + try: + # ValueType is defined when field_value is a protobuf enum + metadata.enum_type.ValueType + return field_value + except AttributeError: + enum_type = metadata.enum_type + if metadata.repeated: + return [enum_type(value) for value in field_value] + return enum_type(field_value) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index c9db44531..2b747605e 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -45,28 +45,38 @@ class ParameterMetadata(NamedTuple): field_name: str = "" """display_name in snake_case format.""" - -def initialize( - display_name: str, - type: type_pb2.Field.Kind.ValueType, - repeated: bool, - default_value: Any, - annotations: Dict[str, str], - message_type: str = "", -) -> ParameterMetadata: - """Initialize ParameterMetadata with field_name.""" - underscore_display_name = display_name.replace(" ", "_") - if all(char.isalnum() or char == "_" for char in underscore_display_name): - field_name = underscore_display_name - else: - field_name = "".join( - char for char in underscore_display_name if char.isalnum() or char == "_" + enum_type: Any = None + """Enum type of parameter""" + + @staticmethod + def initialize( + display_name: str, + type: type_pb2.Field.Kind.ValueType, + repeated: bool, + default_value: Any, + annotations: Dict[str, str], + message_type: str = "", + enum_type: Any = None, + ) -> "ParameterMetadata": + """Initialize ParameterMetadata with field_name.""" + underscore_display_name = display_name.replace(" ", "_") + if all(char.isalnum() or char == "_" for char in underscore_display_name): + field_name = underscore_display_name + else: + field_name = "".join( + char for char in underscore_display_name if char.isalnum() or char == "_" + ) + return ParameterMetadata( + display_name, + type, + repeated, + default_value, + annotations, + message_type, + field_name, + enum_type, ) - return ParameterMetadata( - display_name, type, repeated, default_value, annotations, message_type, field_name - ) - def validate_default_value_type(parameter_metadata: ParameterMetadata) -> None: """Validate and raise exception if the default value does not match the type info. diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py index 0a57fd1cf..945014bd0 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -1,5 +1,6 @@ """Serialization Descriptors.""" +from enum import EnumType from json import loads from typing import List @@ -8,7 +9,7 @@ from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY from ni_measurement_plugin_sdk_service._internal.parameter._get_type import ( - _TYPE_FIELD_MAPPING, + TYPE_FIELD_MAPPING, ) from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, @@ -22,15 +23,19 @@ def _create_enum_type_class( ) -> None: """Implement a enum class in 'file_descriptor'.""" enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) - enum_type_name = _get_enum_type(parameter_metadata).__name__ - # if enum is a protobuf then enum_type_name is 1st letter of each enum name - # e.g. {"NONE": 0, "RED": 1, "GREEN": 2} -> NRG - if enum_type_name == "int" or enum_type_name == "NoneType": - enum_field_names = list(enum_dict.keys())[:] - enum_type_name = "".join(name[0] for name in enum_field_names) + # True if enum is protobuf + if not isinstance(parameter_metadata.enum_type, EnumType): + try: + enum_type_name = parameter_metadata.enum_type.DESCRIPTOR.name + except AttributeError: + # Uses field name if DESCRIPTOR.name isn't defined + name_sections = parameter_metadata.field_name.split("_") + enum_type_name = "".join(section.capitalize() for section in name_sections) + else: + enum_type_name = parameter_metadata.enum_type.__name__ - if enum_type_name not in [enum_type.name for enum_type in file_descriptor.enum_type]: + if enum_type_name not in [file_enum.name for file_enum in file_descriptor.enum_type]: enum_descriptor = file_descriptor.enum_type.add() enum_descriptor.name = enum_type_name for name, number in enum_dict.items(): @@ -40,13 +45,6 @@ def _create_enum_type_class( field_descriptor.type_name = enum_type_name -def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: - if parameter_metadata.repeated and len(parameter_metadata.default_value) > 0: - return type(parameter_metadata.default_value[0]) - else: - return type(parameter_metadata.default_value) - - def _create_field( message_proto: DescriptorProto, metadata: ParameterMetadata, index: int ) -> FieldDescriptorProto: @@ -54,7 +52,7 @@ def _create_field( field_descriptor = message_proto.field.add() field_descriptor.number = index field_descriptor.name = metadata.field_name - field_descriptor.type = _TYPE_FIELD_MAPPING[metadata.type] + field_descriptor.type = TYPE_FIELD_MAPPING[metadata.type] if metadata.repeated: field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED diff --git a/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py b/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py index d1a445da1..ff78e6b1f 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py +++ b/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py @@ -415,13 +415,14 @@ def configuration( annotations = self._make_annotations_dict( data_type_info.type_specialization, instrument_type=instrument_type, enum_type=enum_type ) - parameter = parameter_metadata.initialize( + parameter = parameter_metadata.ParameterMetadata.initialize( display_name, data_type_info.grpc_field_type, data_type_info.repeated, default_value, annotations, data_type_info.message_type, + enum_type, ) parameter_metadata.validate_default_value_type(parameter) self._configuration_parameter_list.append(parameter) @@ -475,13 +476,14 @@ def output( annotations = self._make_annotations_dict( data_type_info.type_specialization, enum_type=enum_type ) - parameter = parameter_metadata.initialize( + parameter = parameter_metadata.ParameterMetadata.initialize( display_name, data_type_info.grpc_field_type, data_type_info.repeated, None, annotations, data_type_info.message_type, + enum_type, ) self._output_parameter_list.append(parameter) diff --git a/packages/service/tests/unit/test_decoder.py b/packages/service/tests/unit/test_decoder.py index 23aee9ea3..9a69ae8e9 100644 --- a/packages/service/tests/unit/test_decoder.py +++ b/packages/service/tests/unit/test_decoder.py @@ -17,7 +17,6 @@ from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, TypeSpecialization, - initialize, ) from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( xydata_pb2, @@ -176,119 +175,119 @@ def _get_grpc_serialized_data(values): def _get_test_parameter_by_id(default_values): parameter_by_id = { - 1: initialize( + 1: ParameterMetadata.initialize( display_name="float_data!", type=type_pb2.Field.TYPE_FLOAT, repeated=False, default_value=default_values[0], annotations={}, ), - 2: initialize( + 2: ParameterMetadata.initialize( display_name="double_data", type=type_pb2.Field.TYPE_DOUBLE, repeated=False, default_value=default_values[1], annotations={}, ), - 3: initialize( + 3: ParameterMetadata.initialize( display_name="int32_data", type=type_pb2.Field.TYPE_INT32, repeated=False, default_value=default_values[2], annotations={}, ), - 4: initialize( + 4: ParameterMetadata.initialize( display_name="uint32_data", type=type_pb2.Field.TYPE_INT64, repeated=False, default_value=default_values[3], annotations={}, ), - 5: initialize( + 5: ParameterMetadata.initialize( display_name="int64_data", type=type_pb2.Field.TYPE_UINT32, repeated=False, default_value=default_values[4], annotations={}, ), - 6: initialize( + 6: ParameterMetadata.initialize( display_name="uint64_data", type=type_pb2.Field.TYPE_UINT64, repeated=False, default_value=default_values[5], annotations={}, ), - 7: initialize( + 7: ParameterMetadata.initialize( display_name="bool_data", type=type_pb2.Field.TYPE_BOOL, repeated=False, default_value=default_values[6], annotations={}, ), - 8: initialize( + 8: ParameterMetadata.initialize( display_name="string_data", type=type_pb2.Field.TYPE_STRING, repeated=False, default_value=default_values[7], annotations={}, ), - 9: initialize( + 9: ParameterMetadata.initialize( display_name="double_array_data", type=type_pb2.Field.TYPE_DOUBLE, repeated=True, default_value=default_values[8], annotations={}, ), - 10: initialize( + 10: ParameterMetadata.initialize( display_name="float_array_data", type=type_pb2.Field.TYPE_FLOAT, repeated=True, default_value=default_values[9], annotations={}, ), - 11: initialize( + 11: ParameterMetadata.initialize( display_name="int32_array_data", type=type_pb2.Field.TYPE_INT32, repeated=True, default_value=default_values[10], annotations={}, ), - 12: initialize( + 12: ParameterMetadata.initialize( display_name="uint32_array_data", type=type_pb2.Field.TYPE_UINT32, repeated=True, default_value=default_values[11], annotations={}, ), - 13: initialize( + 13: ParameterMetadata.initialize( display_name="int64_array_data", type=type_pb2.Field.TYPE_INT64, repeated=True, default_value=default_values[12], annotations={}, ), - 14: initialize( + 14: ParameterMetadata.initialize( display_name="uint64_array_data", type=type_pb2.Field.TYPE_UINT64, repeated=True, default_value=default_values[13], annotations={}, ), - 15: initialize( + 15: ParameterMetadata.initialize( display_name="bool_array_data", type=type_pb2.Field.TYPE_BOOL, repeated=True, default_value=default_values[14], annotations={}, ), - 16: initialize( + 16: ParameterMetadata.initialize( display_name="string_array_data", type=type_pb2.Field.TYPE_STRING, repeated=True, default_value=default_values[15], annotations={}, ), - 17: initialize( + 17: ParameterMetadata.initialize( display_name="enum_data", type=type_pb2.Field.TYPE_ENUM, repeated=False, @@ -297,8 +296,9 @@ def _get_test_parameter_by_id(default_values): TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, ENUM_VALUES_KEY: '{"PURPLE": 0, "ORANGE": 1, "TEAL": 2, "BROWN": 3}', }, + enum_type=DifferentColor, ), - 18: initialize( + 18: ParameterMetadata.initialize( display_name="enum_array_data", type=type_pb2.Field.TYPE_ENUM, repeated=True, @@ -307,8 +307,9 @@ def _get_test_parameter_by_id(default_values): TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, ENUM_VALUES_KEY: '{"PURPLE": 0, "ORANGE": 1, "TEAL": 2, "BROWN": 3}', }, + enum_type=DifferentColor, ), - 19: initialize( + 19: ParameterMetadata.initialize( display_name="int_enum_data", type=type_pb2.Field.TYPE_ENUM, repeated=False, @@ -317,8 +318,9 @@ def _get_test_parameter_by_id(default_values): TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, ENUM_VALUES_KEY: '{"AMERICA": 0, "TAIWAN": 1, "AUSTRALIA": 2, "CANADA": 3}', }, + enum_type=Countries, ), - 20: initialize( + 20: ParameterMetadata.initialize( display_name="int_enum_array_data", type=type_pb2.Field.TYPE_ENUM, repeated=True, @@ -327,8 +329,9 @@ def _get_test_parameter_by_id(default_values): TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, ENUM_VALUES_KEY: '{"AMERICA": 0, "TAIWAN": 1, "AUSTRALIA": 2, "CANADA": 3}', }, + enum_type=Countries, ), - 21: initialize( + 21: ParameterMetadata.initialize( display_name="xy_data", type=type_pb2.Field.TYPE_MESSAGE, repeated=False, @@ -336,7 +339,7 @@ def _get_test_parameter_by_id(default_values): annotations={}, message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, ), - 22: initialize( + 22: ParameterMetadata.initialize( display_name="xy_data_array", type=type_pb2.Field.TYPE_MESSAGE, repeated=True, @@ -379,7 +382,7 @@ def _get_test_grpc_message(test_values): def _get_big_message_metadata_by_id() -> Dict[int, ParameterMetadata]: return { i - + 1: initialize( + + 1: ParameterMetadata.initialize( display_name=f"field{i + 1}", type=type_pb2.Field.TYPE_DOUBLE, repeated=False, From ff0850cfb7def4808ee394c5ec00e384dd7ae1ae Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Tue, 23 Jul 2024 11:05:08 -0500 Subject: [PATCH 23/25] Changed EnumType to Enum --- .../_internal/parameter/serialization_descriptors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py index 945014bd0..84a871143 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -1,6 +1,6 @@ """Serialization Descriptors.""" -from enum import EnumType +from enum import Enum from json import loads from typing import List @@ -25,7 +25,7 @@ def _create_enum_type_class( enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) # True if enum is protobuf - if not isinstance(parameter_metadata.enum_type, EnumType): + if not isinstance(parameter_metadata.enum_type, Enum): try: enum_type_name = parameter_metadata.enum_type.DESCRIPTOR.name except AttributeError: From ce5425defc1bc532ae161593bc544b49a2ff122c Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Tue, 23 Jul 2024 11:45:08 -0500 Subject: [PATCH 24/25] Changed isinstance() to type() in _create_enum_type_class --- .../_internal/parameter/serialization_descriptors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py index 84a871143..95b754e5a 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -25,7 +25,7 @@ def _create_enum_type_class( enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) # True if enum is protobuf - if not isinstance(parameter_metadata.enum_type, Enum): + if not type(parameter_metadata.enum_type) is type(Enum): try: enum_type_name = parameter_metadata.enum_type.DESCRIPTOR.name except AttributeError: From e76e00a6862e308b4b4977f4b6c3daee310d905e Mon Sep 17 00:00:00 2001 From: Tyler Nguyen Date: Tue, 23 Jul 2024 15:59:22 -0500 Subject: [PATCH 25/25] Add correct type hint to enum_type passing in initialize() and added helper functions for enums --- .../_internal/parameter/decoder.py | 24 ++++---- .../_internal/parameter/metadata.py | 14 ++++- .../parameter/serialization_descriptors.py | 61 ++++++++++++------- 3 files changed, 61 insertions(+), 38 deletions(-) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py index 24d49eead..b11c7cb24 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py @@ -8,6 +8,9 @@ from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, ) +from ni_measurement_plugin_sdk_service._internal.parameter.serialization_descriptors import ( + is_protobuf, +) def deserialize_parameters( @@ -52,17 +55,12 @@ def deserialize_parameters( def _deserialize_enum_parameter(field_value: Any, metadata: ParameterMetadata) -> Any: - """Convert all enums into the user defined enum type. - - Returns: - Union[List[Any], Any]: Enum type or a list of enum types. - """ - try: - # ValueType is defined when field_value is a protobuf enum - metadata.enum_type.ValueType + """Convert enum into their user defined enum type.""" + enum_type = metadata.enum_type + if is_protobuf(enum_type): return field_value - except AttributeError: - enum_type = metadata.enum_type - if metadata.repeated: - return [enum_type(value) for value in field_value] - return enum_type(field_value) + + assert enum_type is not None + if metadata.repeated: + return [enum_type(value) for value in field_value] + return enum_type(field_value) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index 2b747605e..1abdadeaf 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -1,8 +1,10 @@ """Contains classes that represents metadata.""" +from __future__ import annotations + import json from enum import Enum -from typing import Any, Dict, Iterable, NamedTuple +from typing import Any, Dict, Iterable, NamedTuple, Union, Type, Optional, TYPE_CHECKING from google.protobuf import type_pb2 @@ -16,6 +18,12 @@ from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization +if TYPE_CHECKING: + from google.protobuf.internal.enum_type_wrapper import _EnumTypeWrapper + + SupportedEnumType = Union[Type[Enum], _EnumTypeWrapper] + + class ParameterMetadata(NamedTuple): """Class that represents the metadata of parameters.""" @@ -45,7 +53,7 @@ class ParameterMetadata(NamedTuple): field_name: str = "" """display_name in snake_case format.""" - enum_type: Any = None + enum_type: Optional[SupportedEnumType] = None """Enum type of parameter""" @staticmethod @@ -56,7 +64,7 @@ def initialize( default_value: Any, annotations: Dict[str, str], message_type: str = "", - enum_type: Any = None, + enum_type: Optional[SupportedEnumType] = None, ) -> "ParameterMetadata": """Initialize ParameterMetadata with field_name.""" underscore_display_name = display_name.replace(" ", "_") diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py index 95b754e5a..9800f1199 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -1,11 +1,17 @@ """Serialization Descriptors.""" -from enum import Enum +from __future__ import annotations + +from enum import Enum, EnumMeta from json import loads -from typing import List +from typing import TYPE_CHECKING, List, Type, Union, Optional -from google.protobuf import descriptor_pb2, descriptor_pool -from google.protobuf.descriptor_pb2 import DescriptorProto, FieldDescriptorProto +from google.protobuf.descriptor_pb2 import ( + DescriptorProto, + FieldDescriptorProto, + FileDescriptorProto, +) +from google.protobuf.descriptor_pool import DescriptorPool from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY from ni_measurement_plugin_sdk_service._internal.parameter._get_type import ( @@ -15,25 +21,36 @@ ParameterMetadata, ) +if TYPE_CHECKING: + from google.protobuf.internal.enum_type_wrapper import _EnumTypeWrapper + + SupportedEnumType = Union[Type[Enum], _EnumTypeWrapper] + + +def is_protobuf(enum_type: Optional[SupportedEnumType]) -> bool: + """Finds if 'enum_type' is a protobuf or a python enum.""" + return hasattr(enum_type, "ValueType") + + +def _get_enum_type_name(metadata: ParameterMetadata) -> str: + """Get's enum type name from a 'parameter_metadata'.""" + enum_type = metadata.enum_type + if enum_type is None: + raise ValueError("Enum type cannot be None in ParameterMetadata.") + + if is_protobuf(enum_type) and not isinstance(enum_type, EnumMeta): + return enum_type.DESCRIPTOR.name + return enum_type.__name__ + def _create_enum_type_class( - file_descriptor: descriptor_pb2.FileDescriptorProto, - parameter_metadata: ParameterMetadata, + file_descriptor: FileDescriptorProto, + metadata: ParameterMetadata, field_descriptor: FieldDescriptorProto, ) -> None: """Implement a enum class in 'file_descriptor'.""" - enum_dict = loads(parameter_metadata.annotations[ENUM_VALUES_KEY]) - - # True if enum is protobuf - if not type(parameter_metadata.enum_type) is type(Enum): - try: - enum_type_name = parameter_metadata.enum_type.DESCRIPTOR.name - except AttributeError: - # Uses field name if DESCRIPTOR.name isn't defined - name_sections = parameter_metadata.field_name.split("_") - enum_type_name = "".join(section.capitalize() for section in name_sections) - else: - enum_type_name = parameter_metadata.enum_type.__name__ + enum_dict = loads(metadata.annotations[ENUM_VALUES_KEY]) + enum_type_name = _get_enum_type_name(metadata) if enum_type_name not in [file_enum.name for file_enum in file_descriptor.enum_type]: enum_descriptor = file_descriptor.enum_type.add() @@ -68,7 +85,7 @@ def _create_field( def _create_message_type( parameter_metadata: List[ParameterMetadata], message_name: str, - file_descriptor: descriptor_pb2.FileDescriptorProto, + file_descriptor: FileDescriptorProto, ) -> None: """Creates a message type with fields intialized in 'file_descriptor'.""" message_proto = file_descriptor.message_type.add() @@ -82,7 +99,7 @@ def _create_message_type( if metadata.type == FieldDescriptorProto.TYPE_ENUM: _create_enum_type_class( file_descriptor=file_descriptor, - parameter_metadata=metadata, + metadata=metadata, field_descriptor=field_descriptor, ) @@ -91,13 +108,13 @@ def create_file_descriptor( service_name: str, output_metadata: List[ParameterMetadata], input_metadata: List[ParameterMetadata], - pool: descriptor_pool.DescriptorPool, + pool: DescriptorPool, ) -> None: """Creates two message types in one file descriptor proto.""" try: pool.FindFileByName(service_name) except KeyError: - file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor = FileDescriptorProto() file_descriptor.name = service_name file_descriptor.package = service_name _create_message_type(input_metadata, "Configurations", file_descriptor)