diff --git a/WORKSPACE b/WORKSPACE index d6e54d6ea..e9eea8089 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -63,30 +63,34 @@ load("@rules_jvm_external//:setup.bzl", "rules_jvm_external_setup") rules_jvm_external_setup() load("@rules_jvm_external//:defs.bzl", "maven_install") +load("//:maven_utils.bzl", "maven_artifact_compile_only", "maven_artifact_test_only") ANTLR4_VERSION = "4.13.2" -# Important: there can only be one maven_install rule. Add new maven deps here. maven_install( # keep sorted artifacts = [ "com.google.auto.value:auto-value:1.11.0", "com.google.auto.value:auto-value-annotations:1.11.0", - "com.google.code.findbugs:annotations:3.0.1", - "com.google.errorprone:error_prone_annotations:2.36.0", "com.google.guava:guava:33.3.1-jre", "com.google.guava:guava-testlib:33.3.1-jre", "com.google.protobuf:protobuf-java:4.28.3", "com.google.protobuf:protobuf-java-util:4.28.3", "com.google.re2j:re2j:1.8", - "com.google.testparameterinjector:test-parameter-injector:1.18", - "com.google.truth.extensions:truth-java8-extension:1.4.4", - "com.google.truth.extensions:truth-proto-extension:1.4.4", - "com.google.truth:truth:1.4.4", "org.antlr:antlr4-runtime:" + ANTLR4_VERSION, + "info.picocli:picocli:4.7.6", + "javax.annotation:javax.annotation-api:1.3.2", + "org.freemarker:freemarker:2.3.33", "org.jspecify:jspecify:1.0.0", "org.threeten:threeten-extra:1.8.0", "org.yaml:snakeyaml:2.3", + maven_artifact_test_only("com.google.testparameterinjector", "test-parameter-injector", "1.18"), + maven_artifact_test_only("com.google.truth", "truth", "1.4.4"), + maven_artifact_test_only("com.google.truth.extensions", "truth-java8-extension", "1.4.4"), + maven_artifact_test_only("com.google.truth.extensions", "truth-proto-extension", "1.4.4"), + maven_artifact_test_only("com.google.truth.extensions", "truth-liteproto-extension", "1.4.4"), + maven_artifact_compile_only("com.google.code.findbugs", "annotations", "3.0.1"), + maven_artifact_compile_only("com.google.errorprone", "error_prone_annotations", "2.36.0"), ], repositories = [ "https://maven.google.com", diff --git a/common/internal/BUILD.bazel b/common/internal/BUILD.bazel index 63c8bcbf0..656e7414e 100644 --- a/common/internal/BUILD.bazel +++ b/common/internal/BUILD.bazel @@ -25,6 +25,11 @@ java_library( exports = ["//common/src/main/java/dev/cel/common/internal:dynamic_proto"], ) +java_library( + name = "proto_lite_adapter", + exports = ["//common/src/main/java/dev/cel/common/internal:proto_lite_adapter"], +) + java_library( name = "proto_equality", exports = ["//common/src/main/java/dev/cel/common/internal:proto_equality"], @@ -50,6 +55,11 @@ java_library( exports = ["//common/src/main/java/dev/cel/common/internal:default_instance_message_factory"], ) +java_library( + name = "default_instance_message_lite_factory", + exports = ["//common/src/main/java/dev/cel/common/internal:default_instance_message_lite_factory"], +) + java_library( name = "well_known_proto", exports = ["//common/src/main/java/dev/cel/common/internal:well_known_proto"], @@ -70,7 +80,22 @@ java_library( exports = ["//common/src/main/java/dev/cel/common/internal:cel_descriptor_pools"], ) +java_library( + name = "cel_lite_descriptor_pool", + exports = ["//common/src/main/java/dev/cel/common/internal:cel_lite_descriptor_pool"], +) + java_library( name = "safe_string_formatter", exports = ["//common/src/main/java/dev/cel/common/internal:safe_string_formatter"], ) + +java_library( + name = "proto_java_qualified_names", + exports = ["//common/src/main/java/dev/cel/common/internal:proto_java_qualified_names"], +) + +java_library( + name = "reflection_util", + exports = ["//common/src/main/java/dev/cel/common/internal:reflection_util"], +) diff --git a/common/src/main/java/dev/cel/common/internal/BUILD.bazel b/common/src/main/java/dev/cel/common/internal/BUILD.bazel index bb9d67bff..0e4b0e581 100644 --- a/common/src/main/java/dev/cel/common/internal/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/internal/BUILD.bazel @@ -104,6 +104,20 @@ java_library( tags = [ ], deps = [ + "//common/annotations", + "//common/internal:default_instance_message_lite_factory", + "//common/internal:proto_java_qualified_names", + "@maven//:com_google_protobuf_protobuf_java", + ], +) + +java_library( + name = "default_instance_message_lite_factory", + srcs = ["DefaultInstanceMessageLiteFactory.java"], + tags = [ + ], + deps = [ + ":reflection_util", "//common/annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", @@ -119,18 +133,34 @@ java_library( ], deps = [ ":converter", + ":proto_lite_adapter", ":proto_message_factory", ":well_known_proto", "//:auto_value", "//common:error_codes", - "//common:proto_json_adapter", "//common:runtime_exception", "//common/annotations", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", - "@maven//:org_jspecify_jspecify", + ], +) + +java_library( + name = "proto_lite_adapter", + srcs = ["ProtoLiteAdapter.java"], + tags = [ + ], + deps = [ + ":well_known_proto", + "//common:error_codes", + "//common:proto_json_adapter", + "//common:runtime_exception", + "//common/annotations", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", ], ) @@ -185,6 +215,7 @@ java_library( "//common/annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", + "@maven//:org_jspecify_jspecify", ], ) @@ -230,6 +261,19 @@ java_library( ], ) +java_library( + name = "cel_lite_descriptor_pool", + srcs = ["CelLiteDescriptorPool.java"], + deps = [ + "//common/annotations", + "//common/internal:well_known_proto", + "//protobuf:cel_lite_descriptor", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + ], +) + java_library( name = "safe_string_formatter", srcs = ["SafeStringFormatter.java"], @@ -240,3 +284,23 @@ java_library( "@maven//:com_google_re2j_re2j", ], ) + +java_library( + name = "proto_java_qualified_names", + srcs = ["ProtoJavaQualifiedNames.java"], + tags = [ + ], + deps = [ + "//common/annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + ], +) + +java_library( + name = "reflection_util", + srcs = ["ReflectionUtil.java"], + deps = [ + "//common/annotations", + ], +) diff --git a/common/src/main/java/dev/cel/common/internal/CelLiteDescriptorPool.java b/common/src/main/java/dev/cel/common/internal/CelLiteDescriptorPool.java new file mode 100644 index 000000000..c510bb952 --- /dev/null +++ b/common/src/main/java/dev/cel/common/internal/CelLiteDescriptorPool.java @@ -0,0 +1,192 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.internal; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.MessageLite; +import dev.cel.common.annotations.Internal; +import dev.cel.protobuf.CelLiteDescriptor; +import dev.cel.protobuf.CelLiteDescriptor.FieldDescriptor; +import dev.cel.protobuf.CelLiteDescriptor.MessageDescriptor; +import java.util.Optional; + +/** Descriptor pool for {@link CelLiteDescriptor}s. */ +@Immutable +@Internal +public final class CelLiteDescriptorPool { + private final ImmutableMap protoFqnToMessageInfo; + private final ImmutableMap protoJavaClassNameToMessageInfo; + + public static CelLiteDescriptorPool newInstance(ImmutableSet descriptors) { + return new CelLiteDescriptorPool(descriptors); + } + + public Optional findDescriptorByTypeName(String protoFqn) { + return Optional.ofNullable(protoFqnToMessageInfo.get(protoFqn)); + } + + public Optional findDescriptor(MessageLite msg) { + String className = msg.getClass().getName(); + return Optional.ofNullable(protoJavaClassNameToMessageInfo.get(className)); + } + + private static MessageDescriptor newMessageInfo(WellKnownProto wellKnownProto) { + ImmutableMap.Builder fieldInfoMap = ImmutableMap.builder(); + switch (wellKnownProto) { + case JSON_STRUCT_VALUE: + fieldInfoMap.put( + "fields", + new FieldDescriptor( + "google.protobuf.Struct.fields", + "MESSAGE", + "Fields", + FieldDescriptor.ValueType.MAP.toString(), + FieldDescriptor.Type.MESSAGE.toString(), + String.valueOf(false), + "com.google.protobuf.Struct$FieldsEntry", + "google.protobuf.Struct.FieldsEntry")); + break; + case BOOL_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.BoolValue", + "BOOLEAN", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.BOOL)); + break; + case BYTES_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.BytesValue", + "BYTE_STRING", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.BYTES)); + break; + case DOUBLE_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.DoubleValue", + "DOUBLE", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.DOUBLE)); + break; + case FLOAT_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.FloatValue", + "FLOAT", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.FLOAT)); + break; + case INT32_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.Int32Value", + "INT", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.INT32)); + break; + case INT64_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.Int64Value", + "LONG", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.INT64)); + break; + case STRING_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.StringValue", + "STRING", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.STRING)); + break; + case UINT32_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.UInt32Value", + "INT", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.UINT32)); + break; + case UINT64_VALUE: + fieldInfoMap.put( + "value", + newPrimitiveFieldInfo( + "google.protobuf.UInt64Value", + "LONG", + FieldDescriptor.ValueType.SCALAR, + FieldDescriptor.Type.UINT64)); + break; + case JSON_VALUE: + case JSON_LIST_VALUE: + case DURATION_VALUE: + case TIMESTAMP_VALUE: + // TODO: Complete these + break; + default: + break; + } + + return new MessageDescriptor( + wellKnownProto.typeName(), wellKnownProto.javaClassName(), fieldInfoMap.buildOrThrow()); + } + + private static FieldDescriptor newPrimitiveFieldInfo( + String fullyQualifiedProtoName, + String javaTypeName, + FieldDescriptor.ValueType valueType, + FieldDescriptor.Type protoFieldType) { + return new FieldDescriptor( + fullyQualifiedProtoName + ".value", + javaTypeName, + "Value", + valueType.toString(), + protoFieldType.toString(), + String.valueOf(false), + "", + fullyQualifiedProtoName); + } + + private CelLiteDescriptorPool(ImmutableSet descriptors) { + ImmutableMap.Builder protoFqnMapBuilder = ImmutableMap.builder(); + ImmutableMap.Builder protoJavaClassNameMapBuilder = + ImmutableMap.builder(); + for (WellKnownProto wellKnownProto : WellKnownProto.values()) { + MessageDescriptor wktMessageInfo = newMessageInfo(wellKnownProto); + protoFqnMapBuilder.put(wellKnownProto.typeName(), wktMessageInfo); + protoJavaClassNameMapBuilder.put(wellKnownProto.javaClassName(), wktMessageInfo); + } + + for (CelLiteDescriptor descriptor : descriptors) { + protoFqnMapBuilder.putAll(descriptor.getProtoTypeNamesToDescriptors()); + protoJavaClassNameMapBuilder.putAll(descriptor.getProtoJavaClassNameToDescriptors()); + } + + this.protoFqnToMessageInfo = protoFqnMapBuilder.buildOrThrow(); + this.protoJavaClassNameToMessageInfo = protoJavaClassNameMapBuilder.buildOrThrow(); + } +} diff --git a/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java b/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java index fc703c905..8a07bb218 100644 --- a/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java +++ b/common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java @@ -22,9 +22,26 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import com.google.protobuf.BytesValue; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.Duration; +import com.google.protobuf.Empty; import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.FieldMask; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.ListValue; +import com.google.protobuf.StringValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Timestamp; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.Value; import dev.cel.common.CelDescriptors; import dev.cel.common.annotations.Internal; import java.util.HashMap; @@ -40,14 +57,36 @@ @Immutable @Internal public final class DefaultDescriptorPool implements CelDescriptorPool { - private static final ImmutableMap WELL_KNOWN_TYPE_DESCRIPTORS = + + private static final ImmutableMap WELL_KNOWN_PROTO_TO_DESCRIPTORS = + ImmutableMap.builder() + .put(WellKnownProto.ANY_VALUE, Any.getDescriptor()) + .put(WellKnownProto.BOOL_VALUE, BoolValue.getDescriptor()) + .put(WellKnownProto.BYTES_VALUE, BytesValue.getDescriptor()) + .put(WellKnownProto.DOUBLE_VALUE, DoubleValue.getDescriptor()) + .put(WellKnownProto.DURATION_VALUE, Duration.getDescriptor()) + .put(WellKnownProto.FLOAT_VALUE, FloatValue.getDescriptor()) + .put(WellKnownProto.INT32_VALUE, Int32Value.getDescriptor()) + .put(WellKnownProto.INT64_VALUE, Int64Value.getDescriptor()) + .put(WellKnownProto.STRING_VALUE, StringValue.getDescriptor()) + .put(WellKnownProto.TIMESTAMP_VALUE, Timestamp.getDescriptor()) + .put(WellKnownProto.UINT32_VALUE, UInt32Value.getDescriptor()) + .put(WellKnownProto.UINT64_VALUE, UInt64Value.getDescriptor()) + .put(WellKnownProto.JSON_LIST_VALUE, ListValue.getDescriptor()) + .put(WellKnownProto.JSON_STRUCT_VALUE, Struct.getDescriptor()) + .put(WellKnownProto.JSON_VALUE, Value.getDescriptor()) + .put(WellKnownProto.EMPTY_VALUE, Empty.getDescriptor()) + .put(WellKnownProto.FIELD_MASK_VALUE, FieldMask.getDescriptor()) + .buildOrThrow(); + + private static final ImmutableMap WELL_KNOWN_TYPE_NAME_TO_DESCRIPTORS = stream(WellKnownProto.values()) - .collect(toImmutableMap(WellKnownProto::typeName, WellKnownProto::descriptor)); + .collect(toImmutableMap(WellKnownProto::typeName, WELL_KNOWN_PROTO_TO_DESCRIPTORS::get)); /** A DefaultDescriptorPool instance with just well known types loaded. */ public static final DefaultDescriptorPool INSTANCE = new DefaultDescriptorPool( - WELL_KNOWN_TYPE_DESCRIPTORS, + WELL_KNOWN_TYPE_NAME_TO_DESCRIPTORS, ImmutableMultimap.of(), ExtensionRegistry.getEmptyRegistry()); @@ -67,8 +106,8 @@ public static DefaultDescriptorPool create(CelDescriptors celDescriptors) { public static DefaultDescriptorPool create( CelDescriptors celDescriptors, ExtensionRegistry extensionRegistry) { - Map descriptorMap = new HashMap<>(); // Using a hashmap to allow deduping - stream(WellKnownProto.values()).forEach(d -> descriptorMap.put(d.typeName(), d.descriptor())); + Map descriptorMap = + new HashMap<>(WELL_KNOWN_TYPE_NAME_TO_DESCRIPTORS); // Using a hashmap to allow deduping for (Descriptor descriptor : celDescriptors.messageTypeDescriptors()) { descriptorMap.putIfAbsent(descriptor.getFullName(), descriptor); diff --git a/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageFactory.java b/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageFactory.java index 1147da7ad..b49c3b1b8 100644 --- a/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageFactory.java +++ b/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageFactory.java @@ -1,4 +1,4 @@ -// Copyright 2022 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,23 +14,11 @@ package dev.cel.common.internal; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.CaseFormat; -import com.google.common.base.Joiner; -import com.google.common.base.Strings; -import com.google.common.io.Files; -import com.google.protobuf.DescriptorProtos.FileOptions; import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.EnumDescriptor; -import com.google.protobuf.Descriptors.FileDescriptor; -import com.google.protobuf.Descriptors.ServiceDescriptor; import com.google.protobuf.Message; +import com.google.protobuf.MessageLite; import dev.cel.common.annotations.Internal; -import java.lang.reflect.InvocationTargetException; -import java.util.ArrayDeque; -import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; /** * Singleton factory for creating default messages from a protobuf descriptor. @@ -39,19 +27,13 @@ */ @Internal public final class DefaultInstanceMessageFactory { - - // Controls how many times we should recursively inspect a nested message for building fully - // qualified java class name before aborting. - public static final int SAFE_RECURSE_LIMIT = 50; - - private static final DefaultInstanceMessageFactory instance = new DefaultInstanceMessageFactory(); - - private final Map messageByDescriptorName = - new ConcurrentHashMap<>(); + private static final DefaultInstanceMessageFactory INSTANCE = new DefaultInstanceMessageFactory(); + private static final DefaultInstanceMessageLiteFactory LITE_FACTORY = + DefaultInstanceMessageLiteFactory.getInstance(); /** Gets a single instance of this MessageFactory */ public static DefaultInstanceMessageFactory getInstance() { - return instance; + return INSTANCE; } /** @@ -63,182 +45,29 @@ public static DefaultInstanceMessageFactory getInstance() { * descriptor class isn't loaded in the binary. */ public Optional getPrototype(Descriptor descriptor) { - String descriptorName = descriptor.getFullName(); - LazyGeneratedMessageDefaultInstance lazyDefaultInstance = - messageByDescriptorName.computeIfAbsent( - descriptorName, - (unused) -> - new LazyGeneratedMessageDefaultInstance( - getFullyQualifiedJavaClassName(descriptor))); - - Message defaultInstance = lazyDefaultInstance.getDefaultInstance(); + MessageLite defaultInstance = + LITE_FACTORY + .getPrototype( + descriptor.getFullName(), + ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName(descriptor)) + .orElse(null); if (defaultInstance == null) { return Optional.empty(); } - // Reference equality is intended. We want to make sure the descriptors are equal - // to guarantee types to be hermetic if linked types is disabled. - if (defaultInstance.getDescriptorForType() != descriptor) { - return Optional.empty(); - } - return Optional.of(defaultInstance); - } - - /** - * Retrieves the full Java class name from the given descriptor - * - * @return fully qualified class name. - *

Example 1: dev.cel.expr.Value - *

Example 2: com.google.rpc.context.AttributeContext$Resource (Nested classes) - *

Example 3: com.google.api.expr.cel.internal.testdata$SingleFileProto$SingleFile$Path - * (Nested class with java multiple files disabled) - */ - private String getFullyQualifiedJavaClassName(Descriptor descriptor) { - StringBuilder fullClassName = new StringBuilder(); - fullClassName.append(getJavaPackageName(descriptor)); - - String javaOuterClass = getJavaOuterClassName(descriptor); - if (!Strings.isNullOrEmpty(javaOuterClass)) { - fullClassName.append(javaOuterClass).append("$"); - } - - // Recursively build the target class name in case if the message is nested. - ArrayDeque classNames = new ArrayDeque<>(); - Descriptor d = descriptor; - - int recurseCount = 0; - while (d != null) { - classNames.push(d.getName()); - d = d.getContainingType(); - recurseCount++; - if (recurseCount >= SAFE_RECURSE_LIMIT) { - throw new IllegalStateException( - String.format( - "Recursion limit of %d hit while inspecting descriptor: %s", - SAFE_RECURSE_LIMIT, descriptor.getFullName())); - } - } - - Joiner.on("$").appendTo(fullClassName, classNames); - - return fullClassName.toString(); - } - - /** - * Gets the java package name from the descriptor. See - * https://developers.google.com/protocol-buffers/docs/reference/java-generated#package for rules - * on package name generation - */ - private String getJavaPackageName(Descriptor descriptor) { - FileOptions options = descriptor.getFile().getOptions(); - StringBuilder javaPackageName = new StringBuilder(); - if (options.hasJavaPackage()) { - javaPackageName.append(descriptor.getFile().getOptions().getJavaPackage()).append("."); - } else { - javaPackageName - // CEL-Internal-1 - .append(descriptor.getFile().getPackage()) - .append("."); + if (!(defaultInstance instanceof Message)) { + throw new IllegalArgumentException( + "Expected a full protobuf message, but got: " + defaultInstance.getClass()); } - // CEL-Internal-2 + Message fullMessage = (Message) defaultInstance; - return javaPackageName.toString(); - } - - /** - * Gets a wrapping outer class name from the descriptor. The outer class name differs depending on - * the proto options set. See - * https://developers.google.com/protocol-buffers/docs/reference/java-generated#invocation - */ - private String getJavaOuterClassName(Descriptor descriptor) { - FileOptions options = descriptor.getFile().getOptions(); - - if (options.getJavaMultipleFiles()) { - // If java_multiple_files is enabled, protoc does not generate a wrapper outer class - return ""; - } - - if (options.hasJavaOuterClassname()) { - return options.getJavaOuterClassname(); - } else { - // If an outer class name is not explicitly set, the name is converted into - // Pascal case based on the snake cased file name - // Ex: messages_proto.proto becomes MessagesProto - String protoFileNameWithoutExtension = - Files.getNameWithoutExtension(descriptor.getFile().getFullName()); - String outerClassName = - CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, protoFileNameWithoutExtension); - if (hasConflictingClassName(descriptor.getFile(), outerClassName)) { - outerClassName += "OuterClass"; - } - return outerClassName; - } - } - - private boolean hasConflictingClassName(FileDescriptor file, String name) { - for (EnumDescriptor enumDesc : file.getEnumTypes()) { - if (name.equals(enumDesc.getName())) { - return true; - } - } - for (ServiceDescriptor serviceDesc : file.getServices()) { - if (name.equals(serviceDesc.getName())) { - return true; - } - } - for (Descriptor messageDesc : file.getMessageTypes()) { - if (name.equals(messageDesc.getName())) { - return true; - } - } - return false; - } - - /** A placeholder to lazily load the generated messages' defaultInstances. */ - private static final class LazyGeneratedMessageDefaultInstance { - private final String fullClassName; - private volatile Message defaultInstance = null; - private volatile boolean loaded = false; - - public LazyGeneratedMessageDefaultInstance(String fullClassName) { - this.fullClassName = fullClassName; - } - - public Message getDefaultInstance() { - if (!loaded) { - synchronized (this) { - if (!loaded) { - loadDefaultInstance(); - loaded = true; - } - } - } - return defaultInstance; - } - - private void loadDefaultInstance() { - try { - defaultInstance = - (Message) Class.forName(fullClassName).getMethod("getDefaultInstance").invoke(null); - } catch (IllegalAccessException | InvocationTargetException e) { - throw new LinkageError( - String.format("getDefaultInstance for class: %s failed.", fullClassName), e); - } catch (NoSuchMethodException e) { - throw new LinkageError( - String.format("getDefaultInstance method does not exist in class: %s.", fullClassName), - e); - } catch (ClassNotFoundException e) { - // The class may not exist in some instances (Ex: evaluating a checked expression from a - // cached source). - } + // Reference equality is intended. We want to make sure the descriptors are equal + // to guarantee types to be hermetic if linked types is disabled. + if (fullMessage.getDescriptorForType() != descriptor) { + return Optional.empty(); } - } - - /** Clears the descriptor map. This should not be used outside testing. */ - @VisibleForTesting - void resetDescriptorMapForTesting() { - messageByDescriptorName.clear(); + return Optional.of(fullMessage); } private DefaultInstanceMessageFactory() {} diff --git a/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageLiteFactory.java b/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageLiteFactory.java new file mode 100644 index 000000000..df663e425 --- /dev/null +++ b/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageLiteFactory.java @@ -0,0 +1,111 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.internal; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.MessageLite; +import dev.cel.common.annotations.Internal; +import java.lang.reflect.Method; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Singleton factory for creating default messages from a fully qualified protobuf type name and its + * java class name. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +public final class DefaultInstanceMessageLiteFactory { + + private static final DefaultInstanceMessageLiteFactory INSTANCE = + new DefaultInstanceMessageLiteFactory(); + + private final Map messageByDescriptorName = + new ConcurrentHashMap<>(); + + /** Gets a single instance of this DefaultInstanceMessageLiteFactory */ + public static DefaultInstanceMessageLiteFactory getInstance() { + return INSTANCE; + } + + /** + * Creates a default instance of a protobuf message given a descriptor. This is essentially the + * same as calling FooMessage.getDefaultInstance(), except reflection is leveraged. + * + * @return Default instance of a type. Returns an empty optional if the descriptor used to + * construct the type via reflection is different to the provided descriptor or if the + * descriptor class isn't loaded in the binary. + */ + public Optional getPrototype(String protoFqn, String protoJavaClassFqn) { + LazyGeneratedMessageDefaultInstance lazyDefaultInstance = + messageByDescriptorName.computeIfAbsent( + protoFqn, (unused) -> new LazyGeneratedMessageDefaultInstance(protoJavaClassFqn)); + + MessageLite defaultInstance = lazyDefaultInstance.getDefaultInstance(); + if (defaultInstance == null) { + return Optional.empty(); + } + + return Optional.of(defaultInstance); + } + + /** A placeholder to lazily load the generated messages' defaultInstances. */ + private static final class LazyGeneratedMessageDefaultInstance { + private final String fullClassName; + private volatile MessageLite defaultInstance = null; + private volatile boolean loaded = false; + + public LazyGeneratedMessageDefaultInstance(String fullClassName) { + this.fullClassName = fullClassName; + } + + public MessageLite getDefaultInstance() { + if (!loaded) { + synchronized (this) { + if (!loaded) { + loadDefaultInstance(); + loaded = true; + } + } + } + return defaultInstance; + } + + private void loadDefaultInstance() { + Class clazz; + try { + clazz = Class.forName(fullClassName); + } catch (ClassNotFoundException e) { + // The class may not exist in some instances (Ex: evaluating a checked expression from a + // cached source), or a dynamic descriptor was explicitly used. In this case, a dynamic + // message is returned as the default instance. + return; + } + + Method method = ReflectionUtil.getMethod(clazz, "getDefaultInstance"); + defaultInstance = (MessageLite) ReflectionUtil.invoke(method, null); + } + } + + /** Clears the descriptor map. This should not be used outside testing. */ + @VisibleForTesting + void resetDescriptorMapForTesting() { + messageByDescriptorName.clear(); + } + + private DefaultInstanceMessageLiteFactory() {} +} diff --git a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java index 28da7e5d0..3eed49257 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java @@ -15,42 +15,23 @@ package dev.cel.common.internal; import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.collect.ImmutableMap.toImmutableMap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; import com.google.common.primitives.UnsignedInts; import com.google.common.primitives.UnsignedLong; import com.google.errorprone.annotations.CheckReturnValue; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.Any; -import com.google.protobuf.BoolValue; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.EnumValueDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor; -import com.google.protobuf.DoubleValue; -import com.google.protobuf.Duration; import com.google.protobuf.DynamicMessage; -import com.google.protobuf.FloatValue; -import com.google.protobuf.Int32Value; -import com.google.protobuf.Int64Value; import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.ListValue; import com.google.protobuf.MapEntry; import com.google.protobuf.Message; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.NullValue; -import com.google.protobuf.StringValue; -import com.google.protobuf.Struct; -import com.google.protobuf.Timestamp; -import com.google.protobuf.UInt32Value; -import com.google.protobuf.UInt64Value; -import com.google.protobuf.Value; import dev.cel.common.CelErrorCode; -import dev.cel.common.CelProtoJsonAdapter; import dev.cel.common.CelRuntimeException; import dev.cel.common.annotations.Internal; import java.util.ArrayList; @@ -58,7 +39,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import org.jspecify.annotations.Nullable; /** * The {@code ProtoAdapter} utilities handle conversion between native Java objects which represent @@ -136,12 +116,14 @@ public final class ProtoAdapter { public static final BidiConverter DOUBLE_CONVERTER = BidiConverter.of(Number::doubleValue, Number::floatValue); + private final ProtoLiteAdapter protoLiteAdapter; private final DynamicProto dynamicProto; private final boolean enableUnsignedLongs; public ProtoAdapter(DynamicProto dynamicProto, boolean enableUnsignedLongs) { this.dynamicProto = checkNotNull(dynamicProto); this.enableUnsignedLongs = enableUnsignedLongs; + this.protoLiteAdapter = new ProtoLiteAdapter(enableUnsignedLongs); } /** @@ -157,7 +139,7 @@ public Object adaptProtoToValue(MessageOrBuilder proto) { // If the proto is not a well-known type, then the input Message is what's expected as the // output return value. WellKnownProto wellKnownProto = - WellKnownProto.getByDescriptorName(typeName(proto.getDescriptorForType())); + WellKnownProto.getByTypeName(typeName(proto.getDescriptorForType())); if (wellKnownProto == null) { return proto; } @@ -166,39 +148,8 @@ public Object adaptProtoToValue(MessageOrBuilder proto) { switch (wellKnownProto) { case ANY_VALUE: return unpackAnyProto((Any) proto); - case JSON_VALUE: - return adaptJsonToValue((Value) proto); - case JSON_STRUCT_VALUE: - return adaptJsonStructToValue((Struct) proto); - case JSON_LIST_VALUE: - return adaptJsonListToValue((ListValue) proto); - case BOOL_VALUE: - return ((BoolValue) proto).getValue(); - case BYTES_VALUE: - return ((BytesValue) proto).getValue(); - case DOUBLE_VALUE: - return ((DoubleValue) proto).getValue(); - case FLOAT_VALUE: - return (double) ((FloatValue) proto).getValue(); - case INT32_VALUE: - return (long) ((Int32Value) proto).getValue(); - case INT64_VALUE: - return ((Int64Value) proto).getValue(); - case STRING_VALUE: - return ((StringValue) proto).getValue(); - case UINT32_VALUE: - if (enableUnsignedLongs) { - return UnsignedLong.fromLongBits( - Integer.toUnsignedLong(((UInt32Value) proto).getValue())); - } - return (long) ((UInt32Value) proto).getValue(); - case UINT64_VALUE: - if (enableUnsignedLongs) { - return UnsignedLong.fromLongBits(((UInt64Value) proto).getValue()); - } - return ((UInt64Value) proto).getValue(); default: - return proto; + return protoLiteAdapter.adaptWellKnownProtoToValue(proto, wellKnownProto); } } @@ -314,12 +265,7 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) { case MESSAGE: return BidiConverter.of( this::adaptProtoToValue, - value -> - adaptValueToProto(value, fieldDescriptor.getMessageType().getFullName()) - .orElseThrow( - () -> - new IllegalStateException( - String.format("value not convertible to proto: %s", value)))); + value -> adaptValueToProto(value, fieldDescriptor.getMessageType().getFullName())); default: return BidiConverter.IDENTITY; } @@ -333,213 +279,25 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) { * protoTypeName} will indicate an alternative packaging of the value which needs to be * considered, such as a packing an {@code google.protobuf.StringValue} into a {@code Any} value. */ - @SuppressWarnings("unchecked") - public Optional adaptValueToProto(Object value, String protoTypeName) { - WellKnownProto wellKnownProto = WellKnownProto.getByDescriptorName(protoTypeName); + public Message adaptValueToProto(Object value, String protoTypeName) { + WellKnownProto wellKnownProto = WellKnownProto.getByTypeName(protoTypeName); if (wellKnownProto == null) { if (value instanceof Message) { - return Optional.of((Message) value); + return (Message) value; } - return Optional.empty(); + + throw new IllegalStateException(String.format("value not convertible to proto: %s", value)); } + switch (wellKnownProto) { case ANY_VALUE: - return Optional.ofNullable(adaptValueToAny(value)); - case JSON_VALUE: - try { - return Optional.of(CelProtoJsonAdapter.adaptValueToJsonValue(value)); - } catch (RuntimeException e) { - return Optional.empty(); - } - case JSON_LIST_VALUE: - try { - return Optional.of(CelProtoJsonAdapter.adaptToJsonListValue((Iterable) value)); - } catch (RuntimeException e) { - return Optional.empty(); - } - case JSON_STRUCT_VALUE: - try { - return Optional.of( - CelProtoJsonAdapter.adaptToJsonStructValue((Map) value)); - } catch (RuntimeException e) { - return Optional.empty(); + if (value instanceof Message) { + protoTypeName = ((Message) value).getDescriptorForType().getFullName(); } - case BOOL_VALUE: - if (value instanceof Boolean) { - return Optional.of(BoolValue.of((Boolean) value)); - } - break; - case BYTES_VALUE: - if (value instanceof ByteString) { - return Optional.of(BytesValue.of((ByteString) value)); - } - break; - case DOUBLE_VALUE: - return Optional.ofNullable(adaptValueToDouble(value)); - case DURATION_VALUE: - return Optional.of((Duration) value); - case FLOAT_VALUE: - return Optional.ofNullable(adaptValueToFloat(value)); - case INT32_VALUE: - return Optional.ofNullable(adaptValueToInt32(value)); - case INT64_VALUE: - return Optional.ofNullable(adaptValueToInt64(value)); - case STRING_VALUE: - if (value instanceof String) { - return Optional.of(StringValue.of((String) value)); - } - break; - case TIMESTAMP_VALUE: - return Optional.of((Timestamp) value); - case UINT32_VALUE: - return Optional.ofNullable(adaptValueToUint32(value)); - case UINT64_VALUE: - return Optional.ofNullable(adaptValueToUint64(value)); - } - return Optional.empty(); - } - - // Helper functions which return a {@code null} value if the conversion is not successful. - // This technique was chosen over {@code Optional} for brevity as any call site which might - // care about an Optional return is handled higher up the call stack. - - private @Nullable Message adaptValueToAny(Object value) { - if (value == null || value instanceof NullValue) { - return Any.pack(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()); - } - if (value instanceof Boolean) { - return maybePackAny(value, WellKnownProto.BOOL_VALUE); - } - if (value instanceof ByteString) { - return maybePackAny(value, WellKnownProto.BYTES_VALUE); - } - if (value instanceof Double) { - return maybePackAny(value, WellKnownProto.DOUBLE_VALUE); - } - if (value instanceof Float) { - return maybePackAny(value, WellKnownProto.FLOAT_VALUE); - } - if (value instanceof Integer) { - return maybePackAny(value, WellKnownProto.INT32_VALUE); - } - if (value instanceof Long) { - return maybePackAny(value, WellKnownProto.INT64_VALUE); - } - if (value instanceof Message) { - return Any.pack((Message) value); - } - if (value instanceof Iterable) { - return maybePackAny(value, WellKnownProto.JSON_LIST_VALUE); - } - if (value instanceof Map) { - return maybePackAny(value, WellKnownProto.JSON_STRUCT_VALUE); - } - if (value instanceof String) { - return maybePackAny(value, WellKnownProto.STRING_VALUE); - } - if (value instanceof UnsignedLong) { - return maybePackAny(value, WellKnownProto.UINT64_VALUE); - } - return null; - } - - private @Nullable Any maybePackAny(Object value, WellKnownProto wellKnownProto) { - Optional protoValue = adaptValueToProto(value, wellKnownProto.typeName()); - return protoValue.map(Any::pack).orElse(null); - } - - private @Nullable Message adaptValueToDouble(Object value) { - if (value instanceof Double) { - return DoubleValue.of((Double) value); - } - if (value instanceof Float) { - return DoubleValue.of(((Float) value).doubleValue()); - } - return null; - } - - private @Nullable Message adaptValueToFloat(Object value) { - if (value instanceof Double) { - return FloatValue.of(((Double) value).floatValue()); - } - if (value instanceof Float) { - return FloatValue.of((Float) value); - } - return null; - } - - private @Nullable Message adaptValueToInt32(Object value) { - if (value instanceof Integer) { - return Int32Value.of((Integer) value); - } - if (value instanceof Long) { - return Int32Value.of(intCheckedCast((Long) value)); - } - return null; - } - - private @Nullable Message adaptValueToInt64(Object value) { - if (value instanceof Integer) { - return Int64Value.of(((Integer) value).longValue()); - } - if (value instanceof Long) { - return Int64Value.of((Long) value); - } - return null; - } - - private @Nullable Message adaptValueToUint32(Object value) { - if (value instanceof Integer) { - return UInt32Value.of((Integer) value); - } - if (value instanceof Long) { - try { - return UInt32Value.of(unsignedIntCheckedCast((Long) value)); - } catch (IllegalArgumentException e) { - throw new CelRuntimeException(e, CelErrorCode.NUMERIC_OVERFLOW); - } - } - if (value instanceof UnsignedLong) { - try { - return UInt32Value.of(unsignedIntCheckedCast(((UnsignedLong) value).longValue())); - } catch (IllegalArgumentException e) { - throw new CelRuntimeException(e, CelErrorCode.NUMERIC_OVERFLOW); - } - } - return null; - } - - private @Nullable Message adaptValueToUint64(Object value) { - if (value instanceof Integer) { - return UInt64Value.of(UnsignedInts.toLong((Integer) value)); - } - if (value instanceof Long) { - return UInt64Value.of((Long) value); - } - if (value instanceof UnsignedLong) { - return UInt64Value.of(((UnsignedLong) value).longValue()); - } - return null; - } - - private @Nullable Object adaptJsonToValue(Value value) { - switch (value.getKindCase()) { - case BOOL_VALUE: - return value.getBoolValue(); - case NULL_VALUE: - return value.getNullValue(); - case NUMBER_VALUE: - return value.getNumberValue(); - case STRING_VALUE: - return value.getStringValue(); - case LIST_VALUE: - return adaptJsonListToValue(value.getListValue()); - case STRUCT_VALUE: - return adaptJsonStructToValue(value.getStructValue()); - case KIND_NOT_SET: - return NullValue.NULL_VALUE; + return protoLiteAdapter.adaptValueToAny(value, protoTypeName); + default: + return (Message) protoLiteAdapter.adaptValueToWellKnownProto(value, wellKnownProto); } - return null; } private Object unpackAnyProto(Any anyProto) { @@ -550,17 +308,6 @@ private Object unpackAnyProto(Any anyProto) { } } - private ImmutableList adaptJsonListToValue(ListValue listValue) { - return listValue.getValuesList().stream() - .map(this::adaptJsonToValue) - .collect(ImmutableList.toImmutableList()); - } - - private ImmutableMap adaptJsonStructToValue(Struct struct) { - return struct.getFieldsMap().entrySet().stream() - .collect(toImmutableMap(e -> e.getKey(), e -> adaptJsonToValue(e.getValue()))); - } - /** Returns the default value for a field that can be a proto message */ private static Object getDefaultValueForMaybeMessage(FieldDescriptor descriptor) { if (descriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { @@ -579,7 +326,7 @@ private static boolean isWrapperType(FieldDescriptor fieldDescriptor) { return false; } String fieldTypeName = fieldDescriptor.getMessageType().getFullName(); - WellKnownProto wellKnownProto = WellKnownProto.getByDescriptorName(fieldTypeName); + WellKnownProto wellKnownProto = WellKnownProto.getByTypeName(fieldTypeName); return wellKnownProto != null && wellKnownProto.isWrapperType(); } diff --git a/common/src/main/java/dev/cel/common/internal/ProtoJavaQualifiedNames.java b/common/src/main/java/dev/cel/common/internal/ProtoJavaQualifiedNames.java new file mode 100644 index 000000000..32a73b6f4 --- /dev/null +++ b/common/src/main/java/dev/cel/common/internal/ProtoJavaQualifiedNames.java @@ -0,0 +1,166 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.internal; + +import com.google.common.base.CaseFormat; +import com.google.common.base.Joiner; +import com.google.common.base.Strings; +import com.google.common.io.Files; +import com.google.protobuf.DescriptorProtos.FileOptions; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.EnumDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.Descriptors.GenericDescriptor; +import com.google.protobuf.Descriptors.ServiceDescriptor; +import dev.cel.common.annotations.Internal; +import java.util.ArrayDeque; + +/** + * Helper class for constructing a fully qualified Java class name from a protobuf descriptor. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +public final class ProtoJavaQualifiedNames { + // Controls how many times we should recursively inspect a nested message for building fully + // qualified java class name before aborting. + private static final int SAFE_RECURSE_LIMIT = 50; + + /** + * Retrieves the full Java class name from the given descriptor + * + * @return fully qualified class name. + *

Example 1: dev.cel.expr.Value + *

Example 2: com.google.rpc.context.AttributeContext$Resource (Nested classes) + *

Example 3: com.google.api.expr.cel.internal.testdata$SingleFileProto$SingleFile$Path + * (Nested class with java multiple files disabled) + */ + public static String getFullyQualifiedJavaClassName(Descriptor descriptor) { + return getFullyQualifiedJavaClassNameImpl(descriptor); + } + + public static String getFullyQualifiedJavaClassName(EnumDescriptor descriptor) { + return getFullyQualifiedJavaClassNameImpl(descriptor); + } + + private static String getFullyQualifiedJavaClassNameImpl(GenericDescriptor descriptor) { + StringBuilder fullClassName = new StringBuilder(); + + fullClassName.append(getJavaPackageName(descriptor.getFile())).append("."); + + String javaOuterClass = getJavaOuterClassName(descriptor.getFile()); + if (!Strings.isNullOrEmpty(javaOuterClass)) { + fullClassName.append(javaOuterClass).append("$"); + } + + // Recursively build the target class name in case if the message is nested. + ArrayDeque classNames = new ArrayDeque<>(); + GenericDescriptor d = descriptor; + + int recurseCount = 0; + while (d != null) { + classNames.push(d.getName()); + + if (d instanceof EnumDescriptor) { + d = ((EnumDescriptor) d).getContainingType(); + } else { + d = ((Descriptor) d).getContainingType(); + } + recurseCount++; + if (recurseCount >= SAFE_RECURSE_LIMIT) { + throw new IllegalStateException( + String.format( + "Recursion limit of %d hit while inspecting descriptor: %s", + SAFE_RECURSE_LIMIT, descriptor.getFullName())); + } + } + + Joiner.on("$").appendTo(fullClassName, classNames); + + return fullClassName.toString(); + } + + /** + * Gets the java package name from the descriptor. See + * https://developers.google.com/protocol-buffers/docs/reference/java-generated#package for rules + * on package name generation + */ + public static String getJavaPackageName(FileDescriptor fileDescriptor) { + FileOptions options = fileDescriptor.getFile().getOptions(); + StringBuilder javaPackageName = new StringBuilder(); + if (options.hasJavaPackage()) { + javaPackageName.append(fileDescriptor.getFile().getOptions().getJavaPackage()); + } else { + javaPackageName + // CEL-Internal-1 + .append(fileDescriptor.getPackage()); + } + + // CEL-Internal-2 + + return javaPackageName.toString(); + } + + /** + * Gets a wrapping outer class name from the descriptor. The outer class name differs depending on + * the proto options set. See + * https://developers.google.com/protocol-buffers/docs/reference/java-generated#invocation + */ + private static String getJavaOuterClassName(FileDescriptor descriptor) { + FileOptions options = descriptor.getOptions(); + + if (options.getJavaMultipleFiles()) { + // If java_multiple_files is enabled, protoc does not generate a wrapper outer class + return ""; + } + + if (options.hasJavaOuterClassname()) { + return options.getJavaOuterClassname(); + } else { + // If an outer class name is not explicitly set, the name is converted into + // Pascal case based on the snake cased file name + // Ex: messages_proto.proto becomes MessagesProto + String protoFileNameWithoutExtension = + Files.getNameWithoutExtension(descriptor.getFile().getFullName()); + String outerClassName = + CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, protoFileNameWithoutExtension); + if (hasConflictingClassName(descriptor.getFile(), outerClassName)) { + outerClassName += "OuterClass"; + } + return outerClassName; + } + } + + private static boolean hasConflictingClassName(FileDescriptor file, String name) { + for (EnumDescriptor enumDesc : file.getEnumTypes()) { + if (name.equals(enumDesc.getName())) { + return true; + } + } + for (ServiceDescriptor serviceDesc : file.getServices()) { + if (name.equals(serviceDesc.getName())) { + return true; + } + } + for (Descriptor messageDesc : file.getMessageTypes()) { + if (name.equals(messageDesc.getName())) { + return true; + } + } + return false; + } + + private ProtoJavaQualifiedNames() {} +} diff --git a/common/src/main/java/dev/cel/common/internal/ProtoLiteAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoLiteAdapter.java new file mode 100644 index 000000000..da03cf103 --- /dev/null +++ b/common/src/main/java/dev/cel/common/internal/ProtoLiteAdapter.java @@ -0,0 +1,324 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.internal; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; +import com.google.common.primitives.UnsignedInts; +import com.google.common.primitives.UnsignedLong; +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.Duration; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.ListValue; +import com.google.protobuf.Message; +import com.google.protobuf.MessageLite; +import com.google.protobuf.MessageLiteOrBuilder; +import com.google.protobuf.NullValue; +import com.google.protobuf.StringValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Timestamp; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.Value; +import dev.cel.common.CelErrorCode; +import dev.cel.common.CelProtoJsonAdapter; +import dev.cel.common.CelRuntimeException; +import dev.cel.common.annotations.Internal; +import java.util.Map; +import java.util.Map.Entry; + +/** + * {@code ProtoLiteAdapter} utilities handle conversion between native Java objects which represent + * CEL values and well-known protobuf counterparts. + * + *

This adapter does not leverage descriptors, thus is compatible with lite-variants of protobuf + * messages. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +@Immutable +public final class ProtoLiteAdapter { + + private final boolean enableUnsignedLongs; + + @SuppressWarnings("unchecked") + public MessageLite adaptValueToWellKnownProto(Object value, WellKnownProto wellKnownProto) { + if (wellKnownProto.isWrapperType() && value instanceof MessageLiteOrBuilder) { + // Unwrap well known proto's underlying value (e.g: Int32Value { value: 1 }) + value = adaptWellKnownProtoToValue((MessageLiteOrBuilder) value, wellKnownProto); + } + switch (wellKnownProto) { + case JSON_VALUE: + return CelProtoJsonAdapter.adaptValueToJsonValue(value); + case JSON_STRUCT_VALUE: + return CelProtoJsonAdapter.adaptToJsonStructValue((Map) value); + case JSON_LIST_VALUE: + return CelProtoJsonAdapter.adaptToJsonListValue((Iterable) value); + case BOOL_VALUE: + return BoolValue.of((Boolean) value); + case BYTES_VALUE: + return BytesValue.of((ByteString) value); + case DOUBLE_VALUE: + return adaptValueToDouble(value); + case FLOAT_VALUE: + return adaptValueToFloat(value); + case INT32_VALUE: + return adaptValueToInt32(value); + case INT64_VALUE: + return adaptValueToInt64(value); + case STRING_VALUE: + return StringValue.of((String) value); + case UINT32_VALUE: + return adaptValueToUint32(value); + case UINT64_VALUE: + return adaptValueToUint64(value); + case DURATION_VALUE: + return (Duration) value; + case TIMESTAMP_VALUE: + return (Timestamp) value; + default: + throw new IllegalArgumentException("Unexpceted wellKnownProto kind: " + wellKnownProto); + } + } + + public Any adaptValueToAny(Object value, String typeName) { + if (value instanceof MessageLite) { + return packAnyMessage((MessageLite) value, typeName); + } + + if (value instanceof NullValue) { + return packAnyMessage( + Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(), WellKnownProto.JSON_VALUE); + } + + WellKnownProto wellKnownProto; + + if (value instanceof Boolean) { + wellKnownProto = WellKnownProto.BOOL_VALUE; + } else if (value instanceof ByteString) { + wellKnownProto = WellKnownProto.BYTES_VALUE; + } else if (value instanceof String) { + wellKnownProto = WellKnownProto.STRING_VALUE; + } else if (value instanceof Float) { + wellKnownProto = WellKnownProto.FLOAT_VALUE; + } else if (value instanceof Double) { + wellKnownProto = WellKnownProto.DOUBLE_VALUE; + } else if (value instanceof Long) { + wellKnownProto = WellKnownProto.INT64_VALUE; + } else if (value instanceof UnsignedLong) { + wellKnownProto = WellKnownProto.UINT64_VALUE; + } else if (value instanceof Iterable) { + wellKnownProto = WellKnownProto.JSON_LIST_VALUE; + } else if (value instanceof Map) { + wellKnownProto = WellKnownProto.JSON_STRUCT_VALUE; + } else { + throw new IllegalArgumentException("Unsupported value conversion to any: " + value); + } + + MessageLite wellKnownProtoMsg = adaptValueToWellKnownProto(value, wellKnownProto); + return packAnyMessage(wellKnownProtoMsg, wellKnownProto); + } + + public Object adaptWellKnownProtoToValue( + MessageLiteOrBuilder proto, WellKnownProto wellKnownProto) { + // Exhaustive switch over the conversion and adaptation of well-known protobuf types to Java + // values. + switch (wellKnownProto) { + case JSON_VALUE: + return adaptJsonToValue((Value) proto); + case JSON_STRUCT_VALUE: + return adaptJsonStructToValue((Struct) proto); + case JSON_LIST_VALUE: + return adaptJsonListToValue((ListValue) proto); + case BOOL_VALUE: + return ((BoolValue) proto).getValue(); + case BYTES_VALUE: + return ((BytesValue) proto).getValue(); + case DOUBLE_VALUE: + return ((DoubleValue) proto).getValue(); + case FLOAT_VALUE: + return (double) ((FloatValue) proto).getValue(); + case INT32_VALUE: + return (long) ((Int32Value) proto).getValue(); + case INT64_VALUE: + return ((Int64Value) proto).getValue(); + case STRING_VALUE: + return ((StringValue) proto).getValue(); + case UINT32_VALUE: + if (enableUnsignedLongs) { + return UnsignedLong.fromLongBits( + Integer.toUnsignedLong(((UInt32Value) proto).getValue())); + } + return (long) ((UInt32Value) proto).getValue(); + case UINT64_VALUE: + if (enableUnsignedLongs) { + return UnsignedLong.fromLongBits(((UInt64Value) proto).getValue()); + } + return ((UInt64Value) proto).getValue(); + default: + return proto; + } + } + + private Object adaptJsonToValue(Value value) { + switch (value.getKindCase()) { + case BOOL_VALUE: + return value.getBoolValue(); + case NULL_VALUE: + return value.getNullValue(); + case NUMBER_VALUE: + return value.getNumberValue(); + case STRING_VALUE: + return value.getStringValue(); + case LIST_VALUE: + return adaptJsonListToValue(value.getListValue()); + case STRUCT_VALUE: + return adaptJsonStructToValue(value.getStructValue()); + case KIND_NOT_SET: + return NullValue.NULL_VALUE; + } + throw new IllegalArgumentException("unexpected value kind: " + value.getKindCase()); + } + + private ImmutableList adaptJsonListToValue(ListValue listValue) { + return listValue.getValuesList().stream() + .map(this::adaptJsonToValue) + .collect(ImmutableList.toImmutableList()); + } + + private ImmutableMap adaptJsonStructToValue(Struct struct) { + return struct.getFieldsMap().entrySet().stream() + .collect(toImmutableMap(Entry::getKey, e -> adaptJsonToValue(e.getValue()))); + } + + private Message adaptValueToDouble(Object value) { + if (value instanceof Double) { + return DoubleValue.of((Double) value); + } + if (value instanceof Float) { + return DoubleValue.of(((Float) value).doubleValue()); + } + throw new IllegalArgumentException("Unexpected value type: " + value); + } + + private Message adaptValueToFloat(Object value) { + if (value instanceof Double) { + return FloatValue.of(((Double) value).floatValue()); + } + if (value instanceof Float) { + return FloatValue.of((Float) value); + } + throw new IllegalArgumentException("Unexpected value type: " + value); + } + + private Message adaptValueToInt32(Object value) { + if (value instanceof Integer) { + return Int32Value.of((Integer) value); + } + if (value instanceof Long) { + return Int32Value.of(intCheckedCast((Long) value)); + } + throw new IllegalArgumentException("Unexpected value type: " + value); + } + + private Message adaptValueToInt64(Object value) { + if (value instanceof Integer) { + return Int64Value.of(((Integer) value).longValue()); + } + if (value instanceof Long) { + return Int64Value.of((Long) value); + } + + throw new IllegalArgumentException("Unexpected value type: " + value); + } + + private Message adaptValueToUint32(Object value) { + if (value instanceof Integer) { + return UInt32Value.of((Integer) value); + } + if (value instanceof Long) { + try { + return UInt32Value.of(unsignedIntCheckedCast((Long) value)); + } catch (IllegalArgumentException e) { + throw new CelRuntimeException(e, CelErrorCode.NUMERIC_OVERFLOW); + } + } + if (value instanceof UnsignedLong) { + try { + return UInt32Value.of(unsignedIntCheckedCast(((UnsignedLong) value).longValue())); + } catch (IllegalArgumentException e) { + throw new CelRuntimeException(e, CelErrorCode.NUMERIC_OVERFLOW); + } + } + + throw new IllegalArgumentException("Unexpected value type: " + value); + } + + private Message adaptValueToUint64(Object value) { + if (value instanceof Integer) { + return UInt64Value.of(UnsignedInts.toLong((Integer) value)); + } + if (value instanceof Long) { + return UInt64Value.of((Long) value); + } + if (value instanceof UnsignedLong) { + return UInt64Value.of(((UnsignedLong) value).longValue()); + } + + throw new IllegalArgumentException("Unexpected value type: " + value); + } + + private static int intCheckedCast(long value) { + try { + return Ints.checkedCast(value); + } catch (IllegalArgumentException e) { + throw new CelRuntimeException(e, CelErrorCode.NUMERIC_OVERFLOW); + } + } + + private static int unsignedIntCheckedCast(long value) { + try { + return UnsignedInts.checkedCast(value); + } catch (IllegalArgumentException e) { + throw new CelRuntimeException(e, CelErrorCode.NUMERIC_OVERFLOW); + } + } + + private static Any packAnyMessage(MessageLite msg, WellKnownProto wellKnownProto) { + return packAnyMessage(msg, wellKnownProto.typeName()); + } + + private static Any packAnyMessage(MessageLite msg, String typeUrl) { + return Any.newBuilder() + .setValue(msg.toByteString()) + .setTypeUrl("type.googleapis.com/" + typeUrl) + .build(); + } + + public ProtoLiteAdapter(boolean enableUnsignedLongs) { + this.enableUnsignedLongs = enableUnsignedLongs; + } +} diff --git a/common/src/main/java/dev/cel/common/internal/ReflectionUtil.java b/common/src/main/java/dev/cel/common/internal/ReflectionUtil.java new file mode 100644 index 000000000..8aa4c7f14 --- /dev/null +++ b/common/src/main/java/dev/cel/common/internal/ReflectionUtil.java @@ -0,0 +1,60 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.internal; + +import dev.cel.common.annotations.Internal; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +/** + * Utility class for invoking Java reflection. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +public final class ReflectionUtil { + + public static Method getMethod(String className, String methodName, Class... params) { + try { + return getMethod(Class.forName(className), methodName, params); + } catch (ClassNotFoundException e) { + throw new LinkageError(String.format("Could not find class %s", className), e); + } + } + + public static Method getMethod(Class clazz, String methodName, Class... params) { + try { + return clazz.getMethod(methodName, params); + } catch (NoSuchMethodException e) { + throw new LinkageError( + String.format("method [%s] does not exist in class: [%s].", methodName, clazz.getName()), + e); + } + } + + public static Object invoke(Method method, Object object, Object... params) { + try { + return method.invoke(object, params); + } catch (IllegalArgumentException | InvocationTargetException | IllegalAccessException e) { + throw new LinkageError( + String.format( + "method [%s] invocation failed on class [%s].", + method.getName(), method.getDeclaringClass()), + e); + } + } + + private ReflectionUtil() {} +} diff --git a/common/src/main/java/dev/cel/common/internal/WellKnownProto.java b/common/src/main/java/dev/cel/common/internal/WellKnownProto.java index 14da4396d..09d9a0f84 100644 --- a/common/src/main/java/dev/cel/common/internal/WellKnownProto.java +++ b/common/src/main/java/dev/cel/common/internal/WellKnownProto.java @@ -21,9 +21,10 @@ import com.google.protobuf.Any; import com.google.protobuf.BoolValue; import com.google.protobuf.BytesValue; -import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.DoubleValue; import com.google.protobuf.Duration; +import com.google.protobuf.Empty; +import com.google.protobuf.FieldMask; import com.google.protobuf.FloatValue; import com.google.protobuf.Int32Value; import com.google.protobuf.Int64Value; @@ -36,6 +37,7 @@ import com.google.protobuf.Value; import dev.cel.common.annotations.Internal; import java.util.function.Function; +import org.jspecify.annotations.Nullable; /** * WellKnownProto types used throughout CEL. These types are specially handled to ensure that @@ -44,24 +46,23 @@ */ @Internal public enum WellKnownProto { - JSON_VALUE(Value.getDescriptor()), - JSON_STRUCT_VALUE(Struct.getDescriptor()), - JSON_LIST_VALUE(ListValue.getDescriptor()), - ANY_VALUE(Any.getDescriptor()), - BOOL_VALUE(BoolValue.getDescriptor(), true), - BYTES_VALUE(BytesValue.getDescriptor(), true), - DOUBLE_VALUE(DoubleValue.getDescriptor(), true), - FLOAT_VALUE(FloatValue.getDescriptor(), true), - INT32_VALUE(Int32Value.getDescriptor(), true), - INT64_VALUE(Int64Value.getDescriptor(), true), - STRING_VALUE(StringValue.getDescriptor(), true), - UINT32_VALUE(UInt32Value.getDescriptor(), true), - UINT64_VALUE(UInt64Value.getDescriptor(), true), - DURATION_VALUE(Duration.getDescriptor()), - TIMESTAMP_VALUE(Timestamp.getDescriptor()); - - private final Descriptor descriptor; - private final boolean isWrapperType; + JSON_VALUE("google.protobuf.Value", Value.class.getName()), + JSON_STRUCT_VALUE("google.protobuf.Struct", Struct.class.getName()), + JSON_LIST_VALUE("google.protobuf.ListValue", ListValue.class.getName()), + ANY_VALUE("google.protobuf.Any", Any.class.getName()), + DURATION_VALUE("google.protobuf.Duration", Duration.class.getName()), + TIMESTAMP_VALUE("google.protobuf.Timestamp", Timestamp.class.getName()), + EMPTY_VALUE("google.protobuf.Empty", Empty.class.getName()), + FIELD_MASK_VALUE("google.protobuf.FieldMask", FieldMask.class.getName()), + BOOL_VALUE("google.protobuf.BoolValue", BoolValue.class.getName(), true), + BYTES_VALUE("google.protobuf.BytesValue", BytesValue.class.getName(), true), + DOUBLE_VALUE("google.protobuf.DoubleValue", DoubleValue.class.getName(), true), + FLOAT_VALUE("google.protobuf.FloatValue", FloatValue.class.getName(), true), + INT32_VALUE("google.protobuf.Int32Value", Int32Value.class.getName(), true), + INT64_VALUE("google.protobuf.Int64Value", Int64Value.class.getName(), true), + STRING_VALUE("google.protobuf.StringValue", StringValue.class.getName(), true), + UINT32_VALUE("google.protobuf.UInt32Value", UInt32Value.class.getName(), true), + UINT64_VALUE("google.protobuf.UInt64Value", UInt64Value.class.getName(), true); private static final ImmutableMap WELL_KNOWN_PROTO_MAP; @@ -71,28 +72,42 @@ public enum WellKnownProto { .collect(toImmutableMap(WellKnownProto::typeName, Function.identity())); } - WellKnownProto(Descriptor descriptor) { - this(descriptor, /* isWrapperType= */ false); + private final String wellKnownProtoFullName; + private final String javaClassName; + private final boolean isWrapperType; + + public String typeName() { + return wellKnownProtoFullName; } - WellKnownProto(Descriptor descriptor, boolean isWrapperType) { - this.descriptor = descriptor; - this.isWrapperType = isWrapperType; + public String javaClassName() { + return this.javaClassName; } - public Descriptor descriptor() { - return descriptor; + public static @Nullable WellKnownProto getByTypeName(String typeName) { + return WELL_KNOWN_PROTO_MAP.get(typeName); } - public String typeName() { - return descriptor.getFullName(); + public static boolean isWrapperType(String typeName) { + WellKnownProto wellKnownProto = getByTypeName(typeName); + if (wellKnownProto == null) { + return false; + } + + return wellKnownProto.isWrapperType(); } public boolean isWrapperType() { return isWrapperType; } - public static WellKnownProto getByDescriptorName(String name) { - return WELL_KNOWN_PROTO_MAP.get(name); + WellKnownProto(String wellKnownProtoFullName, String javaClassName) { + this(wellKnownProtoFullName, javaClassName, /* isWrapperType= */ false); + } + + WellKnownProto(String wellKnownProtoFullName, String javaClassName, boolean isWrapperType) { + this.wellKnownProtoFullName = wellKnownProtoFullName; + this.javaClassName = javaClassName; + this.isWrapperType = isWrapperType; } } diff --git a/common/src/main/java/dev/cel/common/values/BUILD.bazel b/common/src/main/java/dev/cel/common/values/BUILD.bazel index c9e233108..e89d1d509 100644 --- a/common/src/main/java/dev/cel/common/values/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/values/BUILD.bazel @@ -98,12 +98,34 @@ java_library( ], ) +java_library( + name = "base_proto_cel_value_converter", + srcs = ["BaseProtoCelValueConverter.java"], + tags = [ + "alt_dep=//common/values:base_proto_cel_value_converter", + "avoid_dep", + ], + deps = [ + ":cel_byte_string", + ":cel_value", + ":values", + "//common:options", + "//common/annotations", + "//common/internal:well_known_proto", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:com_google_protobuf_protobuf_java_util", + ], +) + java_library( name = "proto_message_value", srcs = PROTO_MESSAGE_VALUE_SOURCES, tags = [ ], deps = [ + ":base_proto_cel_value_converter", ":cel_value", ":values", "//:auto_value", @@ -115,11 +137,9 @@ java_library( "//common/types", "//common/types:cel_types", "//common/types:type_providers", - "//common/values:cel_byte_string", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", - "@maven//:com_google_protobuf_protobuf_java_util", "@maven//:org_jspecify_jspecify", ], ) @@ -143,3 +163,51 @@ java_library( "@maven//:com_google_protobuf_protobuf_java", ], ) + +java_library( + name = "proto_message_lite_value", + srcs = [ + "ProtoLiteCelValueConverter.java", + "ProtoMessageLiteValue.java", + ], + tags = [ + ], + deps = [ + ":base_proto_cel_value_converter", + ":cel_value", + ":values", + "//:auto_value", + "//common:options", + "//common/annotations", + "//common/internal:cel_lite_descriptor_pool", + "//common/internal:reflection_util", + "//common/internal:well_known_proto", + "//common/types", + "//protobuf:cel_lite_descriptor", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:org_jspecify_jspecify", + ], +) + +java_library( + name = "proto_message_lite_value_provider", + srcs = ["ProtoMessageLiteValueProvider.java"], + deps = [ + ":cel_value", + ":cel_value_provider", + ":proto_message_lite_value", + "//common:error_codes", + "//common:runtime_exception", + "//common/internal:cel_lite_descriptor_pool", + "//common/internal:default_instance_message_lite_factory", + "//common/internal:proto_lite_adapter", + "//common/internal:reflection_util", + "//common/internal:well_known_proto", + "//protobuf:cel_lite_descriptor", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + ], +) diff --git a/common/src/main/java/dev/cel/common/values/BaseProtoCelValueConverter.java b/common/src/main/java/dev/cel/common/values/BaseProtoCelValueConverter.java new file mode 100644 index 000000000..c22303c58 --- /dev/null +++ b/common/src/main/java/dev/cel/common/values/BaseProtoCelValueConverter.java @@ -0,0 +1,229 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.math.LongMath.checkedAdd; +import static com.google.common.math.LongMath.checkedSubtract; + +import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.MessageLiteOrBuilder; +import com.google.protobuf.StringValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Timestamp; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.Value; +import com.google.protobuf.util.Durations; +import com.google.protobuf.util.Timestamps; +import dev.cel.common.CelOptions; +import dev.cel.common.annotations.Internal; +import dev.cel.common.internal.WellKnownProto; +import java.time.Duration; +import java.time.Instant; + +/** + * {@code BaseProtoCelValueConverter} contains the common logic for converting between native Java + * and protobuf objects to {@link CelValue}. This base class is inherited by {@code + * ProtoCelValueConverter} and {@code ProtoLiteCelValueConverter} to perform the conversion using + * full and lite variants of protobuf messages respectively. + * + *

CEL Library Internals. Do Not Use. + */ +@Immutable +@Internal +public abstract class BaseProtoCelValueConverter extends CelValueConverter { + + /** + * Adapts a {@link CelValue} to a native Java object. The CelValue is adapted into protobuf object + * when an equivalent exists. + */ + @Override + public Object fromCelValueToJavaObject(CelValue celValue) { + Preconditions.checkNotNull(celValue); + + if (celValue instanceof TimestampValue) { + return TimeUtils.toProtoTimestamp(((TimestampValue) celValue).value()); + } else if (celValue instanceof DurationValue) { + return TimeUtils.toProtoDuration(((DurationValue) celValue).value()); + } else if (celValue instanceof BytesValue) { + return ByteString.copyFrom(((BytesValue) celValue).value().toByteArray()); + } else if (celValue.equals(NullValue.NULL_VALUE)) { + return com.google.protobuf.NullValue.NULL_VALUE; + } + + return super.fromCelValueToJavaObject(celValue); + } + + /** + * Adapts a plain old Java Object to a {@link CelValue}. Protobuf semantics take precedence for + * conversion. + */ + @Override + public CelValue fromJavaObjectToCelValue(Object value) { + Preconditions.checkNotNull(value); + + if (value instanceof ByteString) { + return BytesValue.create(CelByteString.of(((ByteString) value).toByteArray())); + } else if (value instanceof com.google.protobuf.NullValue) { + return NullValue.NULL_VALUE; + } + + return super.fromJavaObjectToCelValue(value); + } + + protected final CelValue fromWellKnownProtoToCelValue( + MessageLiteOrBuilder message, WellKnownProto wellKnownProto) { + switch (wellKnownProto) { + case JSON_VALUE: + return adaptJsonValueToCelValue((Value) message); + case JSON_STRUCT_VALUE: + return adaptJsonStructToCelValue((Struct) message); + case JSON_LIST_VALUE: + return adaptJsonListToCelValue((com.google.protobuf.ListValue) message); + case DURATION_VALUE: + return DurationValue.create( + TimeUtils.toJavaDuration((com.google.protobuf.Duration) message)); + case TIMESTAMP_VALUE: + return TimestampValue.create(TimeUtils.toJavaInstant((Timestamp) message)); + case BOOL_VALUE: + return fromJavaPrimitiveToCelValue(((BoolValue) message).getValue()); + case BYTES_VALUE: + return fromJavaPrimitiveToCelValue( + ((com.google.protobuf.BytesValue) message).getValue().toByteArray()); + case DOUBLE_VALUE: + return fromJavaPrimitiveToCelValue(((DoubleValue) message).getValue()); + case FLOAT_VALUE: + return fromJavaPrimitiveToCelValue(((FloatValue) message).getValue()); + case INT32_VALUE: + return fromJavaPrimitiveToCelValue(((Int32Value) message).getValue()); + case INT64_VALUE: + return fromJavaPrimitiveToCelValue(((Int64Value) message).getValue()); + case STRING_VALUE: + return fromJavaPrimitiveToCelValue(((StringValue) message).getValue()); + case UINT32_VALUE: + return UintValue.create( + ((UInt32Value) message).getValue(), celOptions.enableUnsignedLongs()); + case UINT64_VALUE: + return UintValue.create( + ((UInt64Value) message).getValue(), celOptions.enableUnsignedLongs()); + default: + throw new UnsupportedOperationException( + "Unsupported message to CelValue conversion - " + message); + } + } + + private CelValue adaptJsonValueToCelValue(Value value) { + switch (value.getKindCase()) { + case BOOL_VALUE: + return fromJavaPrimitiveToCelValue(value.getBoolValue()); + case NUMBER_VALUE: + return fromJavaPrimitiveToCelValue(value.getNumberValue()); + case STRING_VALUE: + return fromJavaPrimitiveToCelValue(value.getStringValue()); + case LIST_VALUE: + return adaptJsonListToCelValue(value.getListValue()); + case STRUCT_VALUE: + return adaptJsonStructToCelValue(value.getStructValue()); + case NULL_VALUE: + case KIND_NOT_SET: // Fall-through is intended + return NullValue.NULL_VALUE; + } + throw new UnsupportedOperationException( + "Unsupported Json to CelValue conversion: " + value.getKindCase()); + } + + private ListValue adaptJsonListToCelValue(com.google.protobuf.ListValue listValue) { + return ImmutableListValue.create( + listValue.getValuesList().stream() + .map(this::adaptJsonValueToCelValue) + .collect(toImmutableList())); + } + + private MapValue adaptJsonStructToCelValue(Struct struct) { + return ImmutableMapValue.create( + struct.getFieldsMap().entrySet().stream() + .collect( + toImmutableMap( + e -> fromJavaObjectToCelValue(e.getKey()), + e -> adaptJsonValueToCelValue(e.getValue())))); + } + + /** Helper to convert between java.util.time and protobuf duration/timestamp. */ + private static class TimeUtils { + private static final int NANOS_PER_SECOND = 1000000000; + + private static Instant toJavaInstant(Timestamp timestamp) { + timestamp = normalizedTimestamp(timestamp.getSeconds(), timestamp.getNanos()); + return Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos()); + } + + private static Duration toJavaDuration(com.google.protobuf.Duration duration) { + duration = normalizedDuration(duration.getSeconds(), duration.getNanos()); + return java.time.Duration.ofSeconds(duration.getSeconds(), duration.getNanos()); + } + + private static Timestamp toProtoTimestamp(Instant instant) { + return normalizedTimestamp(instant.getEpochSecond(), instant.getNano()); + } + + private static com.google.protobuf.Duration toProtoDuration(Duration duration) { + return normalizedDuration(duration.getSeconds(), duration.getNano()); + } + + private static Timestamp normalizedTimestamp(long seconds, int nanos) { + if (nanos <= -NANOS_PER_SECOND || nanos >= NANOS_PER_SECOND) { + seconds = checkedAdd(seconds, nanos / NANOS_PER_SECOND); + nanos = nanos % NANOS_PER_SECOND; + } + if (nanos < 0) { + nanos = nanos + NANOS_PER_SECOND; // no overflow since nanos is negative (and we're adding) + seconds = checkedSubtract(seconds, 1); + } + Timestamp timestamp = Timestamp.newBuilder().setSeconds(seconds).setNanos(nanos).build(); + return Timestamps.checkValid(timestamp); + } + + private static com.google.protobuf.Duration normalizedDuration(long seconds, int nanos) { + if (nanos <= -NANOS_PER_SECOND || nanos >= NANOS_PER_SECOND) { + seconds = checkedAdd(seconds, nanos / NANOS_PER_SECOND); + nanos %= NANOS_PER_SECOND; + } + if (seconds > 0 && nanos < 0) { + nanos += NANOS_PER_SECOND; // no overflow since nanos is negative (and we're adding) + seconds--; // no overflow since seconds is positive (and we're decrementing) + } + if (seconds < 0 && nanos > 0) { + nanos -= NANOS_PER_SECOND; // no overflow since nanos is positive (and we're subtracting) + seconds++; // no overflow since seconds is negative (and we're incrementing) + } + com.google.protobuf.Duration duration = + com.google.protobuf.Duration.newBuilder().setSeconds(seconds).setNanos(nanos).build(); + return Durations.checkValid(duration); + } + } + + protected BaseProtoCelValueConverter(CelOptions celOptions) { + super(celOptions); + } +} diff --git a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java index 152adbffb..14cd4fa81 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java @@ -14,50 +14,30 @@ package dev.cel.common.values; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static com.google.common.math.LongMath.checkedAdd; -import static com.google.common.math.LongMath.checkedSubtract; - import com.google.common.base.Preconditions; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.Any; -import com.google.protobuf.BoolValue; -import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors.EnumValueDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor; -import com.google.protobuf.DoubleValue; import com.google.protobuf.DynamicMessage; -import com.google.protobuf.FloatValue; -import com.google.protobuf.Int32Value; -import com.google.protobuf.Int64Value; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MapEntry; import com.google.protobuf.Message; import com.google.protobuf.MessageOrBuilder; -import com.google.protobuf.StringValue; -import com.google.protobuf.Struct; -import com.google.protobuf.Timestamp; -import com.google.protobuf.UInt32Value; -import com.google.protobuf.UInt64Value; -import com.google.protobuf.Value; -import com.google.protobuf.util.Durations; -import com.google.protobuf.util.Timestamps; import dev.cel.common.CelOptions; import dev.cel.common.annotations.Internal; import dev.cel.common.internal.CelDescriptorPool; import dev.cel.common.internal.DynamicProto; import dev.cel.common.internal.WellKnownProto; import dev.cel.common.types.CelTypes; -import java.time.Duration; -import java.time.Instant; import java.util.HashMap; import java.util.List; import java.util.Map; /** - * {@code CelValueConverter} handles bidirectional conversion between native Java and protobuf - * objects to {@link CelValue}. + * {@code ProtoCelValueConverter} handles bidirectional conversion between native Java and protobuf + * objects to {@link CelValue}. This converter leverages descriptors, thus requires the full version + * of protobuf implementation. * *

Protobuf semantics take precedence for conversion. For example, CEL's TimestampValue will be * converted into Protobuf's Timestamp instead of java.time.Instant. @@ -66,7 +46,7 @@ */ @Immutable @Internal -public final class ProtoCelValueConverter extends CelValueConverter { +public final class ProtoCelValueConverter extends BaseProtoCelValueConverter { private final CelDescriptorPool celDescriptorPool; private final DynamicProto dynamicProto; @@ -76,27 +56,6 @@ public static ProtoCelValueConverter newInstance( return new ProtoCelValueConverter(celOptions, celDescriptorPool, dynamicProto); } - /** - * Adapts a {@link CelValue} to a native Java object. The CelValue is adapted into protobuf object - * when an equivalent exists. - */ - @Override - public Object fromCelValueToJavaObject(CelValue celValue) { - Preconditions.checkNotNull(celValue); - - if (celValue instanceof TimestampValue) { - return TimeUtils.toProtoTimestamp(((TimestampValue) celValue).value()); - } else if (celValue instanceof DurationValue) { - return TimeUtils.toProtoDuration(((DurationValue) celValue).value()); - } else if (celValue instanceof BytesValue) { - return ByteString.copyFrom(((BytesValue) celValue).value().toByteArray()); - } else if (NullValue.NULL_VALUE.equals(celValue)) { - return com.google.protobuf.NullValue.NULL_VALUE; - } - - return super.fromCelValueToJavaObject(celValue); - } - /** Adapts a Protobuf message into a {@link CelValue}. */ public CelValue fromProtoMessageToCelValue(MessageOrBuilder message) { Preconditions.checkNotNull(message); @@ -107,7 +66,7 @@ public CelValue fromProtoMessageToCelValue(MessageOrBuilder message) { } WellKnownProto wellKnownProto = - WellKnownProto.getByDescriptorName(message.getDescriptorForType().getFullName()); + WellKnownProto.getByTypeName(message.getDescriptorForType().getFullName()); if (wellKnownProto == null) { return ProtoMessageValue.create((Message) message, celDescriptorPool, this); } @@ -122,42 +81,9 @@ public CelValue fromProtoMessageToCelValue(MessageOrBuilder message) { "Unpacking failed for message: " + message.getDescriptorForType().getFullName(), e); } return fromProtoMessageToCelValue(unpackedMessage); - case JSON_VALUE: - return adaptJsonValueToCelValue((Value) message); - case JSON_STRUCT_VALUE: - return adaptJsonStructToCelValue((Struct) message); - case JSON_LIST_VALUE: - return adaptJsonListToCelValue((com.google.protobuf.ListValue) message); - case DURATION_VALUE: - return DurationValue.create( - TimeUtils.toJavaDuration((com.google.protobuf.Duration) message)); - case TIMESTAMP_VALUE: - return TimestampValue.create(TimeUtils.toJavaInstant((Timestamp) message)); - case BOOL_VALUE: - return fromJavaPrimitiveToCelValue(((BoolValue) message).getValue()); - case BYTES_VALUE: - return fromJavaPrimitiveToCelValue( - ((com.google.protobuf.BytesValue) message).getValue().toByteArray()); - case DOUBLE_VALUE: - return fromJavaPrimitiveToCelValue(((DoubleValue) message).getValue()); - case FLOAT_VALUE: - return fromJavaPrimitiveToCelValue(((FloatValue) message).getValue()); - case INT32_VALUE: - return fromJavaPrimitiveToCelValue(((Int32Value) message).getValue()); - case INT64_VALUE: - return fromJavaPrimitiveToCelValue(((Int64Value) message).getValue()); - case STRING_VALUE: - return fromJavaPrimitiveToCelValue(((StringValue) message).getValue()); - case UINT32_VALUE: - return UintValue.create( - ((UInt32Value) message).getValue(), celOptions.enableUnsignedLongs()); - case UINT64_VALUE: - return UintValue.create( - ((UInt64Value) message).getValue(), celOptions.enableUnsignedLongs()); + default: + return super.fromWellKnownProtoToCelValue(message, wellKnownProto); } - - throw new UnsupportedOperationException( - "Unsupported message to CelValue conversion - " + message); } /** @@ -173,10 +99,6 @@ public CelValue fromJavaObjectToCelValue(Object value) { } else if (value instanceof Message.Builder) { Message.Builder msgBuilder = (Message.Builder) value; return fromProtoMessageToCelValue(msgBuilder.build()); - } else if (value instanceof ByteString) { - return BytesValue.create(CelByteString.of(((ByteString) value).toByteArray())); - } else if (value instanceof com.google.protobuf.NullValue) { - return NullValue.NULL_VALUE; } else if (value instanceof EnumValueDescriptor) { // (b/178627883) Strongly typed enum is not supported yet return IntValue.create(((EnumValueDescriptor) value).getNumber()); @@ -237,96 +159,6 @@ public CelValue fromProtoMessageFieldToCelValue( return fromJavaObjectToCelValue(result); } - private CelValue adaptJsonValueToCelValue(Value value) { - switch (value.getKindCase()) { - case BOOL_VALUE: - return fromJavaPrimitiveToCelValue(value.getBoolValue()); - case NUMBER_VALUE: - return fromJavaPrimitiveToCelValue(value.getNumberValue()); - case STRING_VALUE: - return fromJavaPrimitiveToCelValue(value.getStringValue()); - case LIST_VALUE: - return adaptJsonListToCelValue(value.getListValue()); - case STRUCT_VALUE: - return adaptJsonStructToCelValue(value.getStructValue()); - case NULL_VALUE: - case KIND_NOT_SET: // Fall-through is intended - return NullValue.NULL_VALUE; - } - throw new UnsupportedOperationException( - "Unsupported Json to CelValue conversion: " + value.getKindCase()); - } - - private ListValue adaptJsonListToCelValue(com.google.protobuf.ListValue listValue) { - return ImmutableListValue.create( - listValue.getValuesList().stream() - .map(this::adaptJsonValueToCelValue) - .collect(toImmutableList())); - } - - private MapValue adaptJsonStructToCelValue(Struct struct) { - return ImmutableMapValue.create( - struct.getFieldsMap().entrySet().stream() - .collect( - toImmutableMap( - e -> fromJavaObjectToCelValue(e.getKey()), - e -> adaptJsonValueToCelValue(e.getValue())))); - } - - /** Helper to convert between java.util.time and protobuf duration/timestamp. */ - private static class TimeUtils { - private static final int NANOS_PER_SECOND = 1000000000; - - private static Instant toJavaInstant(Timestamp timestamp) { - timestamp = normalizedTimestamp(timestamp.getSeconds(), timestamp.getNanos()); - return Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos()); - } - - private static Duration toJavaDuration(com.google.protobuf.Duration duration) { - duration = normalizedDuration(duration.getSeconds(), duration.getNanos()); - return java.time.Duration.ofSeconds(duration.getSeconds(), duration.getNanos()); - } - - private static Timestamp toProtoTimestamp(Instant instant) { - return normalizedTimestamp(instant.getEpochSecond(), instant.getNano()); - } - - private static com.google.protobuf.Duration toProtoDuration(Duration duration) { - return normalizedDuration(duration.getSeconds(), duration.getNano()); - } - - private static Timestamp normalizedTimestamp(long seconds, int nanos) { - if (nanos <= -NANOS_PER_SECOND || nanos >= NANOS_PER_SECOND) { - seconds = checkedAdd(seconds, nanos / NANOS_PER_SECOND); - nanos = nanos % NANOS_PER_SECOND; - } - if (nanos < 0) { - nanos = nanos + NANOS_PER_SECOND; // no overflow since nanos is negative (and we're adding) - seconds = checkedSubtract(seconds, 1); - } - Timestamp timestamp = Timestamp.newBuilder().setSeconds(seconds).setNanos(nanos).build(); - return Timestamps.checkValid(timestamp); - } - - private static com.google.protobuf.Duration normalizedDuration(long seconds, int nanos) { - if (nanos <= -NANOS_PER_SECOND || nanos >= NANOS_PER_SECOND) { - seconds = checkedAdd(seconds, nanos / NANOS_PER_SECOND); - nanos %= NANOS_PER_SECOND; - } - if (seconds > 0 && nanos < 0) { - nanos += NANOS_PER_SECOND; // no overflow since nanos is negative (and we're adding) - seconds--; // no overflow since seconds is positive (and we're decrementing) - } - if (seconds < 0 && nanos > 0) { - nanos -= NANOS_PER_SECOND; // no overflow since nanos is positive (and we're subtracting) - seconds++; // no overflow since seconds is negative (and we're incrementing) - } - com.google.protobuf.Duration duration = - com.google.protobuf.Duration.newBuilder().setSeconds(seconds).setNanos(nanos).build(); - return Durations.checkValid(duration); - } - } - private ProtoCelValueConverter( CelOptions celOptions, CelDescriptorPool celDescriptorPool, DynamicProto dynamicProto) { super(celOptions); diff --git a/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java new file mode 100644 index 000000000..9863f8a74 --- /dev/null +++ b/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java @@ -0,0 +1,157 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.primitives.UnsignedLong; +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.Internal.EnumLite; +import com.google.protobuf.MessageLite; +import dev.cel.common.CelOptions; +import dev.cel.common.annotations.Internal; +import dev.cel.common.internal.CelLiteDescriptorPool; +import dev.cel.common.internal.ReflectionUtil; +import dev.cel.common.internal.WellKnownProto; +import dev.cel.protobuf.CelLiteDescriptor.FieldDescriptor; +import dev.cel.protobuf.CelLiteDescriptor.MessageDescriptor; +import java.lang.reflect.Method; +import java.util.NoSuchElementException; +import java.util.Optional; + +/** + * {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and + * protobuf objects to {@link CelValue}. This converter is specifically designed for use with + * lite-variants of protobuf messages. + * + *

Protobuf semantics take precedence for conversion. For example, CEL's TimestampValue will be + * converted into Protobuf's Timestamp instead of java.time.Instant. + * + *

CEL Library Internals. Do Not Use. + */ +@Immutable +@Internal +public final class ProtoLiteCelValueConverter extends BaseProtoCelValueConverter { + private final CelLiteDescriptorPool descriptorPool; + + public static ProtoLiteCelValueConverter newInstance( + CelOptions celOptions, CelLiteDescriptorPool celLiteDescriptorPool) { + return new ProtoLiteCelValueConverter(celOptions, celLiteDescriptorPool); + } + + /** Adapts the protobuf message field into {@link CelValue}. */ + public CelValue fromProtoMessageFieldToCelValue(MessageLite msg, FieldDescriptor fieldInfo) { + checkNotNull(msg); + checkNotNull(fieldInfo); + + Method getterMethod = ReflectionUtil.getMethod(msg.getClass(), fieldInfo.getGetterName()); + Object fieldValue = ReflectionUtil.invoke(getterMethod, msg); + + switch (fieldInfo.getProtoFieldType()) { + case UINT32: + fieldValue = UnsignedLong.valueOf((int) fieldValue); + break; + case UINT64: + fieldValue = UnsignedLong.valueOf((long) fieldValue); + break; + default: + break; + } + + return fromJavaObjectToCelValue(fieldValue); + } + + @Override + public CelValue fromJavaObjectToCelValue(Object value) { + checkNotNull(value); + + if (value instanceof MessageLite) { + return fromProtoMessageToCelValue((MessageLite) value); + } else if (value instanceof MessageLite.Builder) { + return fromProtoMessageToCelValue(((MessageLite.Builder) value).build()); + } else if (value instanceof EnumLite) { + // Coerce proto enum values back into int + Method method = ReflectionUtil.getMethod(value.getClass(), "getNumber"); + value = ReflectionUtil.invoke(method, value); + } + + return super.fromJavaObjectToCelValue(value); + } + + public CelValue fromProtoMessageToCelValue(MessageLite msg) { + MessageDescriptor messageInfo = + descriptorPool + .findDescriptor(msg) + .orElseThrow( + () -> + new NoSuchElementException( + "Could not find message info for class: " + msg.getClass())); + WellKnownProto wellKnownProto = + WellKnownProto.getByTypeName(messageInfo.getFullyQualifiedProtoName()); + + if (wellKnownProto == null) { + return ProtoMessageLiteValue.create( + msg, messageInfo.getFullyQualifiedProtoName(), descriptorPool, this); + } + + switch (wellKnownProto) { + case ANY_VALUE: + return unpackAnyMessage((Any) msg); + default: + return super.fromWellKnownProtoToCelValue(msg, wellKnownProto); + } + } + + private CelValue unpackAnyMessage(Any anyMsg) { + String typeUrl = + getTypeNameFromTypeUrl(anyMsg.getTypeUrl()) + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("malformed type URL: %s", anyMsg.getTypeUrl()))); + MessageDescriptor messageInfo = + descriptorPool + .findDescriptorByTypeName(typeUrl) + .orElseThrow( + () -> + new NoSuchElementException( + "Could not find message info for any packed message's type name: " + + anyMsg)); + + Method method = + ReflectionUtil.getMethod( + messageInfo.getFullyQualifiedProtoJavaClassName(), "parseFrom", ByteString.class); + ByteString packedBytes = anyMsg.getValue(); + MessageLite unpackedMsg = (MessageLite) ReflectionUtil.invoke(method, null, packedBytes); + + return fromProtoMessageToCelValue(unpackedMsg); + } + + private static Optional getTypeNameFromTypeUrl(String typeUrl) { + int pos = typeUrl.lastIndexOf('/'); + if (pos != -1) { + return Optional.of(typeUrl.substring(pos + 1)); + } + return Optional.empty(); + } + + private ProtoLiteCelValueConverter( + CelOptions celOptions, CelLiteDescriptorPool celLiteDescriptorPool) { + super(celOptions); + this.descriptorPool = celLiteDescriptorPool; + } +} diff --git a/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java b/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java new file mode 100644 index 000000000..9cc6c26ed --- /dev/null +++ b/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java @@ -0,0 +1,128 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import com.google.auto.value.AutoValue; +import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.MessageLite; +import dev.cel.common.internal.CelLiteDescriptorPool; +import dev.cel.common.internal.ReflectionUtil; +import dev.cel.common.internal.WellKnownProto; +import dev.cel.common.types.StructTypeReference; +import dev.cel.protobuf.CelLiteDescriptor.FieldDescriptor; +import dev.cel.protobuf.CelLiteDescriptor.MessageDescriptor; +import java.lang.reflect.Method; +import java.util.Optional; +import org.jspecify.annotations.Nullable; + +/** ProtoMessageLiteValue is a struct value with protobuf support. */ +@AutoValue +@Immutable +public abstract class ProtoMessageLiteValue extends StructValue { + + @Override + public abstract MessageLite value(); + + @Override + public abstract StructTypeReference celType(); + + abstract CelLiteDescriptorPool descriptorPool(); + + abstract ProtoLiteCelValueConverter protoLiteCelValueConverter(); + + @Override + public boolean isZeroValue() { + return value().getDefaultInstanceForType().equals(value()); + } + + @Override + public CelValue select(StringValue field) { + MessageDescriptor messageInfo = + descriptorPool().findDescriptorByTypeName(celType().name()).get(); + FieldDescriptor fieldInfo = messageInfo.getFieldInfoMap().get(field.value()); + if (fieldInfo.getProtoFieldType().equals(FieldDescriptor.Type.MESSAGE) + && WellKnownProto.isWrapperType(fieldInfo.getFieldProtoTypeName())) { + PresenceTestResult presenceTestResult = presenceTest(field); + // Special semantics for wrapper types per CEL spec. NullValue is returned instead of the + // default value for unset fields. + if (!presenceTestResult.hasPresence()) { + return NullValue.NULL_VALUE; + } + + return presenceTestResult.selectedValue().get(); + } + + return protoLiteCelValueConverter().fromProtoMessageFieldToCelValue(value(), fieldInfo); + } + + @Override + public Optional find(StringValue field) { + PresenceTestResult presenceTestResult = presenceTest(field); + + return presenceTestResult.selectedValue(); + } + + private PresenceTestResult presenceTest(StringValue field) { + MessageDescriptor messageInfo = + descriptorPool().findDescriptorByTypeName(celType().name()).get(); + FieldDescriptor fieldInfo = messageInfo.getFieldInfoMap().get(field.value()); + CelValue selectedValue = null; + boolean presenceTestResult; + if (fieldInfo.getHasHasser()) { + Method hasserMethod = ReflectionUtil.getMethod(value().getClass(), fieldInfo.getHasserName()); + presenceTestResult = (boolean) ReflectionUtil.invoke(hasserMethod, value()); + } else { + // Lists, Maps and Opaque Values + selectedValue = + protoLiteCelValueConverter().fromProtoMessageFieldToCelValue(value(), fieldInfo); + presenceTestResult = !selectedValue.isZeroValue(); + } + + if (!presenceTestResult) { + return PresenceTestResult.create(null); + } + + if (selectedValue == null) { + selectedValue = + protoLiteCelValueConverter().fromProtoMessageFieldToCelValue(value(), fieldInfo); + } + + return PresenceTestResult.create(selectedValue); + } + + @AutoValue + abstract static class PresenceTestResult { + abstract boolean hasPresence(); + + abstract Optional selectedValue(); + + static PresenceTestResult create(@Nullable CelValue presentValue) { + Optional maybePresentValue = Optional.ofNullable(presentValue); + return new AutoValue_ProtoMessageLiteValue_PresenceTestResult( + maybePresentValue.isPresent(), maybePresentValue); + } + } + + public static ProtoMessageLiteValue create( + MessageLite value, + String protoFqn, + CelLiteDescriptorPool descriptorPool, + ProtoLiteCelValueConverter protoLiteCelValueConverter) { + Preconditions.checkNotNull(value); + return new AutoValue_ProtoMessageLiteValue( + value, StructTypeReference.create(protoFqn), descriptorPool, protoLiteCelValueConverter); + } +} diff --git a/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValueProvider.java b/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValueProvider.java new file mode 100644 index 000000000..544463bf9 --- /dev/null +++ b/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValueProvider.java @@ -0,0 +1,213 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Arrays.stream; + +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; +import com.google.common.primitives.UnsignedLong; +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.Any; +import com.google.protobuf.Internal; +import com.google.protobuf.MessageLite; +import dev.cel.common.CelErrorCode; +import dev.cel.common.CelRuntimeException; +import dev.cel.common.internal.CelLiteDescriptorPool; +import dev.cel.common.internal.DefaultInstanceMessageLiteFactory; +import dev.cel.common.internal.ProtoLiteAdapter; +import dev.cel.common.internal.ReflectionUtil; +import dev.cel.common.internal.WellKnownProto; +import dev.cel.protobuf.CelLiteDescriptor.FieldDescriptor; +import dev.cel.protobuf.CelLiteDescriptor.MessageDescriptor; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.WildcardType; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.function.Function; + +/** + * {@code ProtoMessageValueProvider} constructs new instances of protobuf lite-message given its + * fully qualified name and its fields to populate. + * + *

CEL Library Internals. Do Not Use. + */ +@Immutable +public class ProtoMessageLiteValueProvider implements CelValueProvider { + private static final ImmutableMap CLASS_NAME_TO_WELL_KNOWN_PROTO_MAP; + private final ProtoLiteCelValueConverter protoLiteCelValueConverter; + private final CelLiteDescriptorPool descriptorPool; + private final ProtoLiteAdapter protoLiteAdapter; + + static { + CLASS_NAME_TO_WELL_KNOWN_PROTO_MAP = + stream(WellKnownProto.values()) + .collect(toImmutableMap(WellKnownProto::javaClassName, Function.identity())); + } + + @Override + public Optional newValue(String structType, Map fields) { + MessageDescriptor messageInfo = + descriptorPool.findDescriptorByTypeName(structType).orElse(null); + + if (messageInfo == null) { + return Optional.empty(); + } + + MessageLite msg = + DefaultInstanceMessageLiteFactory.getInstance() + .getPrototype( + messageInfo.getFullyQualifiedProtoName(), + messageInfo.getFullyQualifiedProtoJavaClassName()) + .orElse(null); + + if (msg == null) { + return Optional.empty(); + } + + MessageLite.Builder msgBuilder = msg.toBuilder(); + for (Map.Entry entry : fields.entrySet()) { + FieldDescriptor fieldInfo = messageInfo.getFieldInfoMap().get(entry.getKey()); + + Method setterMethod = + ReflectionUtil.getMethod( + msgBuilder.getClass(), fieldInfo.getSetterName(), fieldInfo.getFieldJavaClass()); + Object newFieldValue = + adaptToProtoFieldCompatibleValue( + entry.getValue(), fieldInfo, setterMethod.getParameters()[0]); + msgBuilder = + (MessageLite.Builder) ReflectionUtil.invoke(setterMethod, msgBuilder, newFieldValue); + } + + return Optional.of(protoLiteCelValueConverter.fromProtoMessageToCelValue(msgBuilder.build())); + } + + private Object adaptToProtoFieldCompatibleValue( + Object value, FieldDescriptor fieldInfo, Parameter parameter) { + Class parameterType = parameter.getType(); + if (parameterType.isAssignableFrom(Iterable.class)) { + ParameterizedType listParamType = (ParameterizedType) parameter.getParameterizedType(); + Class listParamActualTypeClass = + getActualTypeClass(listParamType.getActualTypeArguments()[0]); + + List copiedList = new ArrayList<>(); + for (Object element : (Iterable) value) { + copiedList.add( + adaptToProtoFieldCompatibleValueImpl(element, fieldInfo, listParamActualTypeClass)); + } + return copiedList; + } else if (parameterType.isAssignableFrom(Map.class)) { + ParameterizedType mapParamType = (ParameterizedType) parameter.getParameterizedType(); + Class keyActualType = getActualTypeClass(mapParamType.getActualTypeArguments()[0]); + Class valueActualType = getActualTypeClass(mapParamType.getActualTypeArguments()[1]); + + Map copiedMap = new LinkedHashMap<>(); + for (Map.Entry entry : ((Map) value).entrySet()) { + Object adaptedKey = + adaptToProtoFieldCompatibleValueImpl(entry.getKey(), fieldInfo, keyActualType); + Object adaptedValue = + adaptToProtoFieldCompatibleValueImpl(entry.getValue(), fieldInfo, valueActualType); + copiedMap.put(adaptedKey, adaptedValue); + } + return copiedMap; + } + + return adaptToProtoFieldCompatibleValueImpl(value, fieldInfo, parameter.getType()); + } + + private Object adaptToProtoFieldCompatibleValueImpl( + Object value, FieldDescriptor fieldInfo, Class parameterType) { + WellKnownProto wellKnownProto = CLASS_NAME_TO_WELL_KNOWN_PROTO_MAP.get(parameterType.getName()); + if (wellKnownProto != null) { + switch (wellKnownProto) { + case ANY_VALUE: + String typeUrl = fieldInfo.getFieldProtoTypeName(); + if (value instanceof MessageLite) { + MessageLite messageLite = (MessageLite) value; + typeUrl = + descriptorPool + .findDescriptor(messageLite) + .orElseThrow( + () -> + new NoSuchElementException( + "Could not find message info for class: " + messageLite.getClass())) + .getFullyQualifiedProtoName(); + } + return protoLiteAdapter.adaptValueToAny(value, typeUrl); + default: + return protoLiteAdapter.adaptValueToWellKnownProto(value, wellKnownProto); + } + } + + if (value instanceof UnsignedLong) { + value = ((UnsignedLong) value).longValue(); + } + + if (parameterType.equals(int.class) || parameterType.equals(Integer.class)) { + return intCheckedCast((long) value); + } else if (parameterType.equals(float.class) || parameterType.equals(Float.class)) { + return ((Double) value).floatValue(); + } else if (Internal.EnumLite.class.isAssignableFrom(parameterType)) { + // CEL coerces enums into int. We need to adapt it back into an actual proto enum. + Method method = ReflectionUtil.getMethod(parameterType, "forNumber", int.class); + return ReflectionUtil.invoke(method, null, intCheckedCast((long) value)); + } else if (parameterType.equals(Any.class)) { + return protoLiteAdapter.adaptValueToAny(value, fieldInfo.getFullyQualifiedProtoName()); + } + + return value; + } + + private static int intCheckedCast(long value) { + try { + return Ints.checkedCast(value); + } catch (IllegalArgumentException e) { + throw new CelRuntimeException(e, CelErrorCode.NUMERIC_OVERFLOW); + } + } + + private static Class getActualTypeClass(Type paramType) { + if (paramType instanceof WildcardType) { + return (Class) ((WildcardType) paramType).getUpperBounds()[0]; + } + + return (Class) paramType; + } + + public static ProtoMessageLiteValueProvider newInstance( + ProtoLiteCelValueConverter protoLiteCelValueConverter, + ProtoLiteAdapter protoLiteAdapter, + CelLiteDescriptorPool celLiteDescriptorPool) { + return new ProtoMessageLiteValueProvider( + protoLiteCelValueConverter, protoLiteAdapter, celLiteDescriptorPool); + } + + private ProtoMessageLiteValueProvider( + ProtoLiteCelValueConverter protoLiteCelValueConverter, + ProtoLiteAdapter protoLiteAdapter, + CelLiteDescriptorPool celLiteDescriptorPool) { + this.protoLiteCelValueConverter = protoLiteCelValueConverter; + this.descriptorPool = celLiteDescriptorPool; + this.protoLiteAdapter = protoLiteAdapter; + } +} diff --git a/common/src/test/java/dev/cel/common/internal/BUILD.bazel b/common/src/test/java/dev/cel/common/internal/BUILD.bazel index f394508d2..3cfbaa6b4 100644 --- a/common/src/test/java/dev/cel/common/internal/BUILD.bazel +++ b/common/src/test/java/dev/cel/common/internal/BUILD.bazel @@ -21,6 +21,7 @@ java_library( "//common/internal:comparison_functions", "//common/internal:converter", "//common/internal:default_instance_message_factory", + "//common/internal:default_instance_message_lite_factory", "//common/internal:default_message_factory", "//common/internal:dynamic_proto", "//common/internal:errors", diff --git a/common/src/test/java/dev/cel/common/internal/DefaultInstanceMessageFactoryTest.java b/common/src/test/java/dev/cel/common/internal/DefaultInstanceMessageFactoryTest.java index 4fc36f1b5..dc594a8a3 100644 --- a/common/src/test/java/dev/cel/common/internal/DefaultInstanceMessageFactoryTest.java +++ b/common/src/test/java/dev/cel/common/internal/DefaultInstanceMessageFactoryTest.java @@ -1,4 +1,4 @@ -// Copyright 2022 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -44,12 +44,12 @@ import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) -public final class DefaultInstanceMessageFactoryTest { +public class DefaultInstanceMessageFactoryTest { @Before public void setUp() { // Reset the statically initialized descriptor map to get clean test runs. - DefaultInstanceMessageFactory.getInstance().resetDescriptorMapForTesting(); + DefaultInstanceMessageLiteFactory.getInstance().resetDescriptorMapForTesting(); } private enum PrototypeDescriptorTestCase { diff --git a/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java b/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java index 8126ec949..f0bf5dcdc 100644 --- a/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java +++ b/common/src/test/java/dev/cel/common/internal/ProtoAdapterTest.java @@ -15,6 +15,7 @@ package dev.cel.common.internal; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -51,10 +52,6 @@ @RunWith(Enclosed.class) public final class ProtoAdapterTest { - private static final CelOptions LEGACY = CelOptions.DEFAULT; - private static final CelOptions CURRENT = - CelOptions.newBuilder().enableUnsignedLongs(true).build(); - private static final DynamicProto DYNAMIC_PROTO = DynamicProto.create(DefaultMessageFactory.INSTANCE); @@ -66,9 +63,6 @@ public static class BidirectionalConversionTest { @Parameter(1) public Message proto; - @Parameter(2) - public CelOptions options; - @Parameters public static List data() { return Arrays.asList( @@ -76,40 +70,34 @@ public static List data() { { NullValue.NULL_VALUE, Any.pack(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()), - LEGACY }, - {true, BoolValue.of(true), LEGACY}, - {true, Any.pack(BoolValue.of(true)), LEGACY}, - {true, Value.newBuilder().setBoolValue(true).build(), LEGACY}, + {true, BoolValue.of(true)}, + {true, Any.pack(BoolValue.of(true))}, + {true, Value.newBuilder().setBoolValue(true).build()}, { - ByteString.copyFromUtf8("hello"), - BytesValue.of(ByteString.copyFromUtf8("hello")), - LEGACY + ByteString.copyFromUtf8("hello"), BytesValue.of(ByteString.copyFromUtf8("hello")), }, { ByteString.copyFromUtf8("hello"), Any.pack(BytesValue.of(ByteString.copyFromUtf8("hello"))), - LEGACY }, - {1.5D, DoubleValue.of(1.5D), LEGACY}, - {1.5D, Any.pack(DoubleValue.of(1.5D)), LEGACY}, - {1.5D, Value.newBuilder().setNumberValue(1.5D).build(), LEGACY}, + {1.5D, DoubleValue.of(1.5D)}, + {1.5D, Any.pack(DoubleValue.of(1.5D))}, + {1.5D, Value.newBuilder().setNumberValue(1.5D).build()}, { Duration.newBuilder().setSeconds(123).build(), Duration.newBuilder().setSeconds(123).build(), - LEGACY }, { Duration.newBuilder().setSeconds(123).build(), Any.pack(Duration.newBuilder().setSeconds(123).build()), - LEGACY }, - {1L, Int64Value.of(1L), LEGACY}, - {1L, Any.pack(Int64Value.of(1L)), LEGACY}, - {UnsignedLong.valueOf(1L), UInt64Value.of(1L), LEGACY}, - {"hello", StringValue.of("hello"), LEGACY}, - {"hello", Any.pack(StringValue.of("hello")), LEGACY}, - {"hello", Value.newBuilder().setStringValue("hello").build(), LEGACY}, + {1L, Int64Value.of(1L)}, + {1L, Any.pack(Int64Value.of(1L))}, + {UnsignedLong.valueOf(1L), UInt64Value.of(1L)}, + {"hello", StringValue.of("hello")}, + {"hello", Any.pack(StringValue.of("hello"))}, + {"hello", Value.newBuilder().setStringValue("hello").build()}, { Arrays.asList("hello", "world"), Any.pack( @@ -117,7 +105,6 @@ public static List data() { .addValues(Value.newBuilder().setStringValue("hello")) .addValues(Value.newBuilder().setStringValue("world")) .build()), - LEGACY }, { ImmutableMap.of("hello", "world"), @@ -125,7 +112,6 @@ public static List data() { Struct.newBuilder() .putFields("hello", Value.newBuilder().setStringValue("world").build()) .build()), - LEGACY }, { ImmutableMap.of("list_value", ImmutableList.of(false, NullValue.NULL_VALUE)), @@ -139,30 +125,28 @@ public static List data() { .addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE))) .build()) .build(), - LEGACY }, { Timestamp.newBuilder().setSeconds(123).build(), Timestamp.newBuilder().setSeconds(123).build(), - LEGACY }, { Timestamp.newBuilder().setSeconds(123).build(), Any.pack(Timestamp.newBuilder().setSeconds(123).build()), - LEGACY }, // Adaption support for the most current CelOptions. - {UnsignedLong.valueOf(1L), UInt64Value.of(1L), CURRENT}, - {UnsignedLong.valueOf(1L), Any.pack(UInt64Value.of(1L)), CURRENT}, + {UnsignedLong.valueOf(1L), UInt64Value.of(1L)}, + {UnsignedLong.valueOf(1L), Any.pack(UInt64Value.of(1L))}, }); } @Test public void adaptValueToProto_bidirectionalConversion() { DynamicProto dynamicProto = DynamicProto.create(DefaultMessageFactory.INSTANCE); - ProtoAdapter protoAdapter = new ProtoAdapter(dynamicProto, options.enableUnsignedLongs()); + ProtoAdapter protoAdapter = + new ProtoAdapter(dynamicProto, CelOptions.DEFAULT.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(value, proto.getDescriptorForType().getFullName())) - .hasValue(proto); + .isEqualTo(proto); assertThat(protoAdapter.adaptProtoToValue(proto)).isEqualTo(value); } } @@ -179,94 +163,96 @@ public void adaptAnyValue_hermeticTypes_bidirectionalConversion() { typeName.equals(Expr.getDescriptor().getFullName()) ? Optional.of(Expr.newBuilder()) : Optional.empty()), - LEGACY.enableUnsignedLongs()); + CelOptions.DEFAULT.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(expr, Any.getDescriptor().getFullName())) - .hasValue(Any.pack(expr)); + .isEqualTo(Any.pack(expr)); assertThat(protoAdapter.adaptProtoToValue(Any.pack(expr))).isEqualTo(expr); } } @RunWith(JUnit4.class) public static class AsymmetricConversionTest { - @Test - public void adaptValueToProto_asymmetricNullConversion() { - ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); - assertThat(protoAdapter.adaptValueToProto(null, Any.getDescriptor().getFullName())) - .hasValue(Any.pack(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build())); - assertThat( - protoAdapter.adaptProtoToValue( - Any.pack(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()))) - .isEqualTo(NullValue.NULL_VALUE); - } - @Test public void adaptValueToProto_asymmetricFloatConversion() { - ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = + new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(1.5F, Any.getDescriptor().getFullName())) - .hasValue(Any.pack(FloatValue.of(1.5F))); + .isEqualTo(Any.pack(FloatValue.of(1.5F))); assertThat(protoAdapter.adaptProtoToValue(Any.pack(FloatValue.of(1.5F)))).isEqualTo(1.5D); } @Test public void adaptValueToProto_asymmetricDoubleFloatConversion() { - ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = + new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(1.5D, FloatValue.getDescriptor().getFullName())) - .hasValue(FloatValue.of(1.5F)); + .isEqualTo(FloatValue.of(1.5F)); assertThat(protoAdapter.adaptProtoToValue(FloatValue.of(1.5F))).isEqualTo(1.5D); } @Test public void adaptValueToProto_asymmetricFloatDoubleConversion() { - ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); + ProtoAdapter protoAdapter = + new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT.enableUnsignedLongs()); assertThat(protoAdapter.adaptValueToProto(1.5F, DoubleValue.getDescriptor().getFullName())) - .hasValue(DoubleValue.of(1.5D)); + .isEqualTo(DoubleValue.of(1.5D)); } @Test public void adaptValueToProto_asymmetricJsonConversion() { - ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, CURRENT.enableUnsignedLongs()); + ProtoAdapter protoAdapter = + new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT.enableUnsignedLongs()); assertThat( protoAdapter.adaptValueToProto( UnsignedLong.valueOf(1L), Value.getDescriptor().getFullName())) - .hasValue(Value.newBuilder().setNumberValue(1).build()); + .isEqualTo(Value.newBuilder().setNumberValue(1).build()); assertThat( protoAdapter.adaptValueToProto( UnsignedLong.fromLongBits(-1L), Value.getDescriptor().getFullName())) - .hasValue(Value.newBuilder().setStringValue(Long.toUnsignedString(-1L)).build()); + .isEqualTo(Value.newBuilder().setStringValue(Long.toUnsignedString(-1L)).build()); assertThat(protoAdapter.adaptValueToProto(1L, Value.getDescriptor().getFullName())) - .hasValue(Value.newBuilder().setNumberValue(1).build()); + .isEqualTo(Value.newBuilder().setNumberValue(1).build()); assertThat( protoAdapter.adaptValueToProto(Long.MAX_VALUE, Value.getDescriptor().getFullName())) - .hasValue(Value.newBuilder().setStringValue(Long.toString(Long.MAX_VALUE)).build()); + .isEqualTo(Value.newBuilder().setStringValue(Long.toString(Long.MAX_VALUE)).build()); assertThat( protoAdapter.adaptValueToProto( ByteString.copyFromUtf8("foo"), Value.getDescriptor().getFullName())) - .hasValue(Value.newBuilder().setStringValue("Zm9v").build()); + .isEqualTo(Value.newBuilder().setStringValue("Zm9v").build()); } @Test public void adaptValueToProto_unsupportedJsonConversion() { - ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); - assertThat( + ProtoAdapter protoAdapter = + new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT.enableUnsignedLongs()); + + assertThrows( + ClassCastException.class, + () -> protoAdapter.adaptValueToProto( - ImmutableMap.of(1, 1), Any.getDescriptor().getFullName())) - .isEmpty(); + ImmutableMap.of(1, 1), Any.getDescriptor().getFullName())); } @Test public void adaptValueToProto_unsupportedJsonListConversion() { - ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); - assertThat( + ProtoAdapter protoAdapter = + new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT.enableUnsignedLongs()); + + assertThrows( + ClassCastException.class, + () -> protoAdapter.adaptValueToProto( - ImmutableMap.of(1, 1), ListValue.getDescriptor().getFullName())) - .isEmpty(); + ImmutableMap.of(1, 1), ListValue.getDescriptor().getFullName())); } @Test public void adaptValueToProto_unsupportedConversion() { - ProtoAdapter protoAdapter = new ProtoAdapter(DYNAMIC_PROTO, LEGACY.enableUnsignedLongs()); - assertThat(protoAdapter.adaptValueToProto("Hello", Expr.getDescriptor().getFullName())) - .isEmpty(); + ProtoAdapter protoAdapter = + new ProtoAdapter(DYNAMIC_PROTO, CelOptions.DEFAULT.enableUnsignedLongs()); + + assertThrows( + IllegalStateException.class, + () -> protoAdapter.adaptValueToProto("Hello", Expr.getDescriptor().getFullName())); } @Test diff --git a/common/values/BUILD.bazel b/common/values/BUILD.bazel index c7efe7dee..8f46088d8 100644 --- a/common/values/BUILD.bazel +++ b/common/values/BUILD.bazel @@ -21,6 +21,11 @@ java_library( exports = ["//common/src/main/java/dev/cel/common/values"], ) +java_library( + name = "base_proto_cel_value_converter", + exports = ["//common/src/main/java/dev/cel/common/values:base_proto_cel_value_converter"], +) + java_library( name = "proto_message_value_provider", exports = ["//common/src/main/java/dev/cel/common/values:proto_message_value_provider"], @@ -35,3 +40,13 @@ java_library( name = "proto_message_value", exports = ["//common/src/main/java/dev/cel/common/values:proto_message_value"], ) + +java_library( + name = "proto_message_lite_value", + exports = ["//common/src/main/java/dev/cel/common/values:proto_message_lite_value"], +) + +java_library( + name = "proto_message_lite_value_provider", + exports = ["//common/src/main/java/dev/cel/common/values:proto_message_lite_value_provider"], +) diff --git a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java index 5c77bdae7..c4b449b81 100644 --- a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java @@ -62,6 +62,7 @@ import java.util.NoSuchElementException; import org.junit.runners.model.Statement; +/** Conformance test suite for CEL-Java. */ // Qualifying proto2/proto3 TestAllTypes makes it less clear. @SuppressWarnings("UnnecessarilyFullyQualified") public final class ConformanceTest extends Statement { diff --git a/java_lite_proto_cel_library.bzl b/java_lite_proto_cel_library.bzl new file mode 100644 index 000000000..691a6cd78 --- /dev/null +++ b/java_lite_proto_cel_library.bzl @@ -0,0 +1,93 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Starlark rule for generating descriptors that is compatible with Protolite Messages.""" + +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_proto//proto:defs.bzl", "proto_descriptor_set") +load("//publish:cel_version.bzl", "CEL_VERSION") + +def java_lite_proto_cel_library( + name, + descriptor_class_prefix, + deps, + debug = False): + """Generates a CelLiteDescriptor + + Args: + name: name of this target. + descriptor_class_prefix: Prefix name for the generated descriptor java class (ex: 'TestAllTypes' generates 'TestAllTypesCelLiteDescriptor.java'). + deps: Name of the proto_library target. Only a single proto_library is supported at this time. + debug: If true, prints additional information during codegen for debugging purposes. + """ + if not deps: + fail("You must provide a proto_library dependency.") + + if len(deps) > 1: + fail("You must provide only one proto_library dependency.") + + _generate_cel_lite_descriptor_class( + name, + descriptor_class_prefix + "CelLiteDescriptor", + deps[0], + debug, + ) + + descriptor_codegen_deps = [ + "//protobuf:cel_lite_descriptor", + ] + + java_library( + name = name, + srcs = [":" + name + "_cel_lite_descriptor"], + deps = deps + descriptor_codegen_deps, + ) + +def _generate_cel_lite_descriptor_class( + name, + descriptor_class_name, + proto_src, + debug): + outfile = "%s.java" % descriptor_class_name + + transitive_descriptor_set_name = "%s_transitive_descriptor_set" % name + proto_descriptor_set( + name = transitive_descriptor_set_name, + deps = [proto_src], + ) + + direct_descriptor_set_name = proto_src + + debug_flag = "--debug" if debug else "" + + cmd = ( + "$(location //protobuf:cel_lite_descriptor_generator) " + + "--descriptor $(location %s) " % direct_descriptor_set_name + + "--transitive_descriptor_set $(location %s) " % transitive_descriptor_set_name + + "--descriptor_class_name %s " % descriptor_class_name + + "--out $(location %s) " % outfile + + "--version %s " % CEL_VERSION + + debug_flag + ) + + native.genrule( + name = name + "_cel_lite_descriptor", + srcs = [ + transitive_descriptor_set_name, + direct_descriptor_set_name, + ], + cmd = cmd, + outs = [outfile], + tools = ["//protobuf:cel_lite_descriptor_generator"], + ) diff --git a/maven_utils.bzl b/maven_utils.bzl new file mode 100644 index 000000000..3dd92d75b --- /dev/null +++ b/maven_utils.bzl @@ -0,0 +1,35 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for installing maven artifact in Bazel WORKSPACE.""" + +load("@rules_jvm_external//:specs.bzl", "maven") + +def maven_artifact_compile_only(group, artifact, version): + """Installs the maven JAR as a compile-time only dependency (ex: tools, codegen).""" + return maven.artifact( + artifact = artifact, + group = group, + neverlink = True, + version = version, + ) + +def maven_artifact_test_only(group, artifact, version): + """Installs the maven JAR as a test-time only dependency (ex: tools, codegen).""" + return maven.artifact( + artifact = artifact, + group = group, + testonly = True, + version = version, + ) diff --git a/protobuf/BUILD.bazel b/protobuf/BUILD.bazel new file mode 100644 index 000000000..83c80bc6d --- /dev/null +++ b/protobuf/BUILD.bazel @@ -0,0 +1,24 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], # TODO: Expose when ready +) + +java_library( + name = "cel_lite_descriptor", + exports = ["//protobuf/src/main/java/dev/cel/protobuf:cel_lite_descriptor"], +) + +alias( + name = "cel_lite_descriptor_generator", + actual = "//protobuf/src/main/java/dev/cel/protobuf:cel_lite_descriptor_generator", + visibility = ["//visibility:public"], +) + +java_library( + name = "proto_descriptor_collector", + testonly = 1, + visibility = ["//visibility:public"], + exports = ["//protobuf/src/main/java/dev/cel/protobuf:proto_descriptor_collector"], +) diff --git a/protobuf/src/main/java/dev/cel/protobuf/BUILD.bazel b/protobuf/src/main/java/dev/cel/protobuf/BUILD.bazel new file mode 100644 index 000000000..33188dee7 --- /dev/null +++ b/protobuf/src/main/java/dev/cel/protobuf/BUILD.bazel @@ -0,0 +1,75 @@ +load("@rules_java//java:defs.bzl", "java_binary", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//protobuf:__pkg__"], +) + +filegroup( + name = "template_files", + srcs = glob(["templates/*.txt"]), + visibility = ["//visibility:private"], +) + +java_binary( + name = "cel_lite_descriptor_generator", + srcs = ["CelLiteDescriptorGenerator.java"], + main_class = "dev.cel.protobuf.CelLiteDescriptorGenerator", + deps = [ + ":debug_printer", + ":java_file_generator", + ":proto_descriptor_collector", + "//common", + "//common/internal:proto_java_qualified_names", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:info_picocli_picocli", + ], +) + +java_library( + name = "proto_descriptor_collector", + srcs = ["ProtoDescriptorCollector.java"], + deps = [ + ":cel_lite_descriptor", + ":debug_printer", + "//common", + "//common/internal:proto_java_qualified_names", + "//common/internal:well_known_proto", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + ], +) + +java_library( + name = "java_file_generator", + srcs = ["JavaFileGenerator.java"], + resources = [ + ":template_files", + ], + deps = [ + ":cel_lite_descriptor", + "//:auto_value", + "@maven//:com_google_guava_guava", + "@maven//:org_freemarker_freemarker", + ], +) + +java_library( + name = "debug_printer", + srcs = ["DebugPrinter.java"], + deps = [ + "@maven//:info_picocli_picocli", + ], +) + +java_library( + name = "cel_lite_descriptor", + srcs = ["CelLiteDescriptor.java"], + deps = [ + "//common/annotations", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + ], +) diff --git a/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptor.java b/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptor.java new file mode 100644 index 000000000..7de5b9d5b --- /dev/null +++ b/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptor.java @@ -0,0 +1,346 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.protobuf; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.lang.Math.ceil; + +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.ByteString; +import dev.cel.common.annotations.Internal; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Base class for code generated CEL lite descriptors to extend from. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +@Immutable +public abstract class CelLiteDescriptor { + @SuppressWarnings("Immutable") // Copied to unmodifiable map + private final Map protoFqnToDescriptors; + + @SuppressWarnings("Immutable") // Copied to unmodifiable map + private final Map protoJavaClassNameToDescriptors; + + public Map getProtoTypeNamesToDescriptors() { + return protoFqnToDescriptors; + } + + public Map getProtoJavaClassNameToDescriptors() { + return protoJavaClassNameToDescriptors; + } + + /** + * Contains a collection of classes which describe protobuf messagelite types. + * + *

CEL Library Internals. Do Not Use. + */ + @Internal + @Immutable + public static final class MessageDescriptor { + private final String fullyQualifiedProtoName; + private final String fullyQualifiedProtoJavaClassName; + + @SuppressWarnings("Immutable") // Copied to unmodifiable map + private final Map fieldInfoMap; + + public String getFullyQualifiedProtoName() { + return fullyQualifiedProtoName; + } + + public String getFullyQualifiedProtoJavaClassName() { + return fullyQualifiedProtoJavaClassName; + } + + public Map getFieldInfoMap() { + return fieldInfoMap; + } + + public MessageDescriptor( + String fullyQualifiedProtoName, + String fullyQualifiedProtoJavaClassName, + Map fieldInfoMap) { + this.fullyQualifiedProtoName = checkNotNull(fullyQualifiedProtoName); + this.fullyQualifiedProtoJavaClassName = checkNotNull(fullyQualifiedProtoJavaClassName); + // This is a cheap operation. View over the existing map with mutators disabled. + this.fieldInfoMap = checkNotNull(Collections.unmodifiableMap(fieldInfoMap)); + } + } + + /** + * Describes a field of a protobuf messagelite type. + * + *

CEL Library Internals. Do Not Use. + */ + @Internal + @Immutable + public static final class FieldDescriptor { + private final String fullyQualifiedProtoName; + private final JavaType javaType; + private final String methodSuffixName; + private final String fieldJavaClassName; + private final ValueType celFieldValueType; + private final Type protoFieldType; + private final boolean hasHasser; + private final String fieldProtoTypeName; + + /** Enumeration of the value type. */ + public enum ValueType { + SCALAR, + LIST, + MAP + } + + /** + * Enumeration of the java type. + * + *

This is exactly the same as com.google.protobuf.Descriptors#JavaType + */ + public enum JavaType { + INT, + LONG, + FLOAT, + DOUBLE, + BOOLEAN, + STRING, + BYTE_STRING, + ENUM, + MESSAGE + } + + /** + * Enumeration of the protobuf type. + * + *

This is exactly the same as com.google.protobuf.Descriptors#Type + */ + public enum Type { + DOUBLE, + FLOAT, + INT64, + UINT64, + INT32, + FIXED64, + FIXED32, + BOOL, + STRING, + GROUP, + MESSAGE, + BYTES, + UINT32, + ENUM, + SFIXED32, + SFIXED64, + SINT32, + SINT64 + } + + // Lazily-loaded field + @SuppressWarnings("Immutable") + private volatile Class fieldJavaType; + + public Class getFieldJavaClass() { + if (fieldJavaType == null) { + synchronized (this) { + if (fieldJavaType == null) { + fieldJavaType = loadFieldTypeClass(); + } + } + } + return fieldJavaType; + } + + public JavaType getJavaType() { + return javaType; + } + + public String getMethodSuffixName() { + return methodSuffixName; + } + + public String getSetterName() { + String prefix = ""; + switch (celFieldValueType) { + case SCALAR: + prefix = "set"; + break; + case LIST: + prefix = "addAll"; + break; + case MAP: + prefix = "putAll"; + break; + } + return prefix + getMethodSuffixName(); + } + + public String getGetterName() { + String suffix = ""; + switch (celFieldValueType) { + case SCALAR: + break; + case LIST: + suffix = "List"; + break; + case MAP: + suffix = "Map"; + break; + } + return "get" + getMethodSuffixName() + suffix; + } + + public String getHasserName() { + return "has" + getMethodSuffixName(); + } + + public String getFieldJavaClassName() { + return fieldJavaClassName; + } + + public ValueType getCelFieldValueType() { + return celFieldValueType; + } + + public Type getProtoFieldType() { + return protoFieldType; + } + + public boolean getHasHasser() { + return hasHasser && celFieldValueType.equals(ValueType.SCALAR); + } + + public String getFullyQualifiedProtoName() { + return fullyQualifiedProtoName; + } + + public String getFieldProtoTypeName() { + return fieldProtoTypeName; + } + + /** + * Must be public, used for codegen only. Do not use. + * + * @param fullyQualifiedProtoName Fully qualified protobuf type name including the namespace + * (ex: cel.expr.conformance.proto3.TestAllTypes) + * @param javaTypeName Canonical Java type name (ex: Long, Double, Float, Message... see + * Descriptors#JavaType) + * @param methodSuffixName Suffix used to decorate the getters/setters (eg: "foo" in "setFoo" + * and "getFoo") + * @param celFieldValueType Describes whether the field is a scalar, list or a map with respect + * to CEL. + * @param protoFieldType Protobuf Field Type (ex: INT32, SINT32, GROUP, MESSAGE... see + * Descriptors#Type) + * @param hasHasser True if the message has a presence test method (ex: wrappers). + * @param fieldJavaClassName Fully qualified Java class name for the field, including its + * package name. Empty if the field is a primitive. + * @param fieldProtoTypeName Fully qualified protobuf type name for the field. Empty if the + * field is a primitive. + */ + @Internal + public FieldDescriptor( + String fullyQualifiedProtoName, + String javaTypeName, + String methodSuffixName, + String celFieldValueType, // LIST, MAP, SCALAR + String protoFieldType, // INT32, SINT32, GROUP, MESSAGE... (See Descriptors#Type) + String hasHasser, // + String fieldJavaClassName, + String fieldProtoTypeName) { + this.fullyQualifiedProtoName = checkNotNull(fullyQualifiedProtoName); + this.javaType = JavaType.valueOf(javaTypeName); + this.methodSuffixName = checkNotNull(methodSuffixName); + this.fieldJavaClassName = checkNotNull(fieldJavaClassName); + this.celFieldValueType = ValueType.valueOf(checkNotNull(celFieldValueType)); + this.protoFieldType = Type.valueOf(protoFieldType); + this.hasHasser = Boolean.parseBoolean(hasHasser); + this.fieldProtoTypeName = checkNotNull(fieldProtoTypeName); + this.fieldJavaType = getPrimitiveFieldTypeClass(); + } + + @SuppressWarnings("ReturnMissingNullable") // Avoid taking a dependency on jspecify.nullable. + private Class getPrimitiveFieldTypeClass() { + if (celFieldValueType.equals(ValueType.LIST)) { + return Iterable.class; + } else if (celFieldValueType.equals(ValueType.MAP)) { + return Map.class; + } + + switch (javaType) { + case INT: + return int.class; + case LONG: + return long.class; + case FLOAT: + return float.class; + case DOUBLE: + return double.class; + case BOOLEAN: + return boolean.class; + case STRING: + return String.class; + case BYTE_STRING: + return ByteString.class; + default: + // Non-primitives must be lazily loaded during instantiation of the runtime environment, + // where the generated messages are linked into the binary via java_lite_proto_library. + return null; + } + } + + private Class loadFieldTypeClass() { + if (!javaType.equals(JavaType.ENUM) && !javaType.equals(JavaType.MESSAGE)) { + throw new IllegalArgumentException("Unexpected java type name for " + javaType); + } + + try { + return Class.forName(fieldJavaClassName); + } catch (ClassNotFoundException e) { + throw new LinkageError(String.format("Could not find class %s", fieldJavaClassName), e); + } + } + } + + protected CelLiteDescriptor(List messageInfoList) { + Map protoFqnMap = + new HashMap<>(getMapInitialCapacity(messageInfoList.size())); + Map protoJavaClassNameMap = + new HashMap<>(getMapInitialCapacity(messageInfoList.size())); + for (MessageDescriptor msgInfo : messageInfoList) { + protoFqnMap.put(msgInfo.getFullyQualifiedProtoName(), msgInfo); + protoJavaClassNameMap.put(msgInfo.getFullyQualifiedProtoJavaClassName(), msgInfo); + } + + this.protoFqnToDescriptors = Collections.unmodifiableMap(protoFqnMap); + this.protoJavaClassNameToDescriptors = Collections.unmodifiableMap(protoJavaClassNameMap); + } + + /** + * Returns a capacity that is sufficient to keep the map from being resized as long as it grows no + * larger than expectedSize and the load factor is ≥ its default (0.75). + */ + private static int getMapInitialCapacity(int expectedSize) { + if (expectedSize < 3) { + return expectedSize + 1; + } + + // See https://github.com/openjdk/jdk/commit/3e393047e12147a81e2899784b943923fc34da8e. 0.75 is + // used as a load factor. + return (int) ceil(expectedSize / 0.75); + } +} diff --git a/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptorGenerator.java b/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptorGenerator.java new file mode 100644 index 000000000..479c69c63 --- /dev/null +++ b/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptorGenerator.java @@ -0,0 +1,156 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.protobuf; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.io.Files; +import com.google.protobuf.DescriptorProtos.FileDescriptorProto; +import com.google.protobuf.DescriptorProtos.FileDescriptorSet; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; +import dev.cel.common.CelDescriptorUtil; +import dev.cel.common.internal.ProtoJavaQualifiedNames; +import dev.cel.protobuf.JavaFileGenerator.JavaFileGeneratorOption; +import java.io.File; +import java.io.IOException; +import java.util.concurrent.Callable; +import picocli.CommandLine; +import picocli.CommandLine.Model.OptionSpec; +import picocli.CommandLine.Option; + +final class CelLiteDescriptorGenerator implements Callable { + + @Option( + names = {"--out"}, + description = "Outpath for the CelLiteDescriptor") + private String outPath = ""; + + @Option( + names = {"--descriptor"}, + description = + "Path to the descriptor (from proto_library) that the CelLiteDescriptor is to be" + + " generated from") + private String targetDescriptorPath = ""; + + @Option( + names = {"--transitive_descriptor_set"}, + description = "Path to the transitive set of descriptors") + private String transitiveDescriptorSetPath = ""; + + @Option( + names = {"--descriptor_class_name"}, + description = "Class name for the CelLiteDescriptor") + private String descriptorClassName = ""; + + @Option( + names = {"--version"}, + description = "CEL-Java version") + private String version = ""; + + @Option( + names = {"--debug"}, + description = "Prints debug output") + private boolean debug = false; + + private DebugPrinter debugPrinter; + + @Override + public Integer call() throws Exception { + String targetDescriptorProtoPath = extractProtoPath(targetDescriptorPath); + debugPrinter.print("Target descriptor proto path: " + targetDescriptorProtoPath); + + FileDescriptor targetFileDescriptor = null; + ImmutableSet transitiveFileDescriptors = + CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet( + load(transitiveDescriptorSetPath)); + for (FileDescriptor fd : transitiveFileDescriptors) { + if (fd.getFullName().equals(targetDescriptorProtoPath)) { + debugPrinter.print("Transitive Descriptor Path: " + fd.getFullName()); + targetFileDescriptor = fd; + break; + } + } + + if (targetFileDescriptor == null) { + throw new IllegalArgumentException( + String.format( + "Target descriptor %s not found from transitive set of descriptors!", + targetDescriptorProtoPath)); + } + + codegenCelLiteDescriptor(targetFileDescriptor); + + return 0; + } + + private void codegenCelLiteDescriptor(FileDescriptor targetFileDescriptor) throws Exception { + String javaPackageName = ProtoJavaQualifiedNames.getJavaPackageName(targetFileDescriptor); + ProtoDescriptorCollector descriptorCollector = + ProtoDescriptorCollector.newInstance(debugPrinter); + + JavaFileGenerator.createFile( + outPath, + JavaFileGeneratorOption.newBuilder() + .setVersion(version) + .setDescriptorClassName(descriptorClassName) + .setPackageName(javaPackageName) + .setMessageInfoList(descriptorCollector.collectMessageInfo(targetFileDescriptor)) + .build()); + } + + private String extractProtoPath(String descriptorPath) { + FileDescriptorSet fds = load(descriptorPath); + FileDescriptorProto fileDescriptorProto = Iterables.getOnlyElement(fds.getFileList()); + return fileDescriptorProto.getName(); + } + + private FileDescriptorSet load(String descriptorSetPath) { + try { + byte[] descriptorBytes = Files.toByteArray(new File(descriptorSetPath)); + // TODO: Implement ProtoExtensions + return FileDescriptorSet.parseFrom(descriptorBytes, ExtensionRegistry.getEmptyRegistry()); + } catch (IOException e) { + throw new IllegalArgumentException( + "Failed to load FileDescriptorSet from path: " + descriptorSetPath, e); + } + } + + private void printAllFlags(CommandLine cmd) { + debugPrinter.print("Flag values:"); + debugPrinter.print("-------------------------------------------------------------"); + for (OptionSpec option : cmd.getCommandSpec().options()) { + debugPrinter.print(option.longestName() + ": " + option.getValue()); + } + debugPrinter.print("-------------------------------------------------------------"); + } + + private void initializeDebugPrinter() { + this.debugPrinter = DebugPrinter.newInstance(debug); + } + + public static void main(String[] args) { + CelLiteDescriptorGenerator celLiteDescriptorGenerator = new CelLiteDescriptorGenerator(); + CommandLine cmd = new CommandLine(celLiteDescriptorGenerator); + cmd.parseArgs(args); + celLiteDescriptorGenerator.initializeDebugPrinter(); + celLiteDescriptorGenerator.printAllFlags(cmd); + + int exitCode = cmd.execute(args); + System.exit(exitCode); + } + + CelLiteDescriptorGenerator() {} +} diff --git a/protobuf/src/main/java/dev/cel/protobuf/DebugPrinter.java b/protobuf/src/main/java/dev/cel/protobuf/DebugPrinter.java new file mode 100644 index 000000000..34a09ce98 --- /dev/null +++ b/protobuf/src/main/java/dev/cel/protobuf/DebugPrinter.java @@ -0,0 +1,36 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.protobuf; + +import picocli.CommandLine.Help.Ansi; + +final class DebugPrinter { + + private final boolean debug; + + static DebugPrinter newInstance(boolean debug) { + return new DebugPrinter(debug); + } + + void print(String message) { + if (debug) { + System.out.println(Ansi.ON.string("@|cyan [CelLiteDescriptorGenerator] |@" + message)); + } + } + + private DebugPrinter(boolean debug) { + this.debug = debug; + } +} diff --git a/protobuf/src/main/java/dev/cel/protobuf/JavaFileGenerator.java b/protobuf/src/main/java/dev/cel/protobuf/JavaFileGenerator.java new file mode 100644 index 000000000..da569f624 --- /dev/null +++ b/protobuf/src/main/java/dev/cel/protobuf/JavaFileGenerator.java @@ -0,0 +1,96 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.protobuf; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Files; +// CEL-Internal-5 +import dev.cel.protobuf.CelLiteDescriptor.MessageDescriptor; +import freemarker.template.Configuration; +import freemarker.template.DefaultObjectWrapperBuilder; +import freemarker.template.Template; +import freemarker.template.TemplateException; +import freemarker.template.Version; +import java.io.File; +import java.io.IOException; +import java.io.StringWriter; +import java.io.Writer; + +final class JavaFileGenerator { + + private static final String HELPER_CLASS_TEMPLATE_FILE = "cel_lite_descriptor_template.txt"; + + public static void createFile(String filePath, JavaFileGeneratorOption option) + throws IOException, TemplateException { + Version version = Configuration.VERSION_2_3_32; + Configuration cfg = new Configuration(version); + cfg.setClassForTemplateLoading(JavaFileGenerator.class, "templates/"); + cfg.setDefaultEncoding("UTF-8"); + cfg.setBooleanFormat("c"); + cfg.setAPIBuiltinEnabled(true); + DefaultObjectWrapperBuilder wrapperBuilder = new DefaultObjectWrapperBuilder(version); + wrapperBuilder.setExposeFields(true); + cfg.setObjectWrapper(wrapperBuilder.build()); + + Template template = cfg.getTemplate(HELPER_CLASS_TEMPLATE_FILE); + Writer out = new StringWriter(); + + template.process(option.getTemplateMap(), out); + + Files.asCharSink(new File(filePath), UTF_8).write(out.toString()); + } + + @AutoValue + abstract static class JavaFileGeneratorOption { + abstract String packageName(); + + abstract String descriptorClassName(); + + abstract String version(); + + abstract ImmutableList messageInfoList(); + + ImmutableMap getTemplateMap() { + return ImmutableMap.of( + "package_name", packageName(), + "descriptor_class_name", descriptorClassName(), + "version", version(), + "message_info_list", messageInfoList()); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setPackageName(String packageName); + + abstract Builder setDescriptorClassName(String className); + + abstract Builder setVersion(String version); + + abstract Builder setMessageInfoList(ImmutableList messageInfo); + + abstract JavaFileGeneratorOption build(); + } + + static Builder newBuilder() { + return new AutoValue_JavaFileGenerator_JavaFileGeneratorOption.Builder(); + } + } + + private JavaFileGenerator() {} +} diff --git a/protobuf/src/main/java/dev/cel/protobuf/ProtoDescriptorCollector.java b/protobuf/src/main/java/dev/cel/protobuf/ProtoDescriptorCollector.java new file mode 100644 index 000000000..143eaddfe --- /dev/null +++ b/protobuf/src/main/java/dev/cel/protobuf/ProtoDescriptorCollector.java @@ -0,0 +1,132 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.protobuf; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.CaseFormat; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import dev.cel.common.CelDescriptorUtil; +import dev.cel.common.CelDescriptors; +import dev.cel.common.internal.ProtoJavaQualifiedNames; +import dev.cel.common.internal.WellKnownProto; +import dev.cel.protobuf.CelLiteDescriptor.FieldDescriptor; +import dev.cel.protobuf.CelLiteDescriptor.FieldDescriptor.ValueType; +import dev.cel.protobuf.CelLiteDescriptor.MessageDescriptor; + +final class ProtoDescriptorCollector { + + private final DebugPrinter debugPrinter; + + @VisibleForTesting + ImmutableList collectMessageInfo(FileDescriptor targetFileDescriptor) { + ImmutableList.Builder messageInfoListBuilder = ImmutableList.builder(); + CelDescriptors celDescriptors = + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + ImmutableList.of(targetFileDescriptor), /* resolveTypeDependencies= */ false); + ImmutableSet messageTypes = + celDescriptors.messageTypeDescriptors().stream() + .filter(d -> WellKnownProto.getByTypeName(d.getFullName()) == null) + .collect(toImmutableSet()); + + for (Descriptor descriptor : messageTypes) { + ImmutableMap.Builder fieldMap = ImmutableMap.builder(); + for (Descriptors.FieldDescriptor fieldDescriptor : descriptor.getFields()) { + String methodSuffixName = + CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, fieldDescriptor.getName()); + + String javaType = fieldDescriptor.getJavaType().toString(); + String embeddedFieldJavaClassName = ""; + String embeddedFieldProtoTypeName = ""; + switch (javaType) { + case "ENUM": + embeddedFieldJavaClassName = + ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName( + fieldDescriptor.getEnumType()); + embeddedFieldProtoTypeName = fieldDescriptor.getEnumType().getFullName(); + break; + case "MESSAGE": + embeddedFieldJavaClassName = + ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName( + fieldDescriptor.getMessageType()); + embeddedFieldProtoTypeName = fieldDescriptor.getMessageType().getFullName(); + break; + default: + break; + } + + ValueType fieldValueType; + if (fieldDescriptor.isMapField()) { + fieldValueType = ValueType.MAP; + } else if (fieldDescriptor.isRepeated()) { + fieldValueType = ValueType.LIST; + } else { + fieldValueType = ValueType.SCALAR; + } + + fieldMap.put( + fieldDescriptor.getName(), + new FieldDescriptor( + fieldDescriptor.getFullName(), + javaType, + methodSuffixName, + fieldValueType.toString(), + fieldDescriptor.getType().toString(), + String.valueOf(fieldDescriptor.hasPresence()), + embeddedFieldJavaClassName, + embeddedFieldProtoTypeName)); + + debugPrinter.print( + String.format( + "Method suffix name in %s, for field %s: %s", + descriptor.getFullName(), fieldDescriptor.getFullName(), methodSuffixName)); + debugPrinter.print(String.format("FieldType: %s", fieldValueType)); + if (!embeddedFieldJavaClassName.isEmpty()) { + debugPrinter.print( + String.format( + "Java class name for field %s: %s", + fieldDescriptor.getName(), embeddedFieldJavaClassName)); + } + } + + messageInfoListBuilder.add( + new MessageDescriptor( + descriptor.getFullName(), + ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName(descriptor), + fieldMap.buildOrThrow())); + } + + return messageInfoListBuilder.build(); + } + + static ProtoDescriptorCollector newInstance(DebugPrinter debugPrinter) { + return new ProtoDescriptorCollector(debugPrinter); + } + + @VisibleForTesting + ProtoDescriptorCollector() { + this(DebugPrinter.newInstance(false)); + } + + private ProtoDescriptorCollector(DebugPrinter debugPrinter) { + this.debugPrinter = debugPrinter; + } +} diff --git a/protobuf/src/main/java/dev/cel/protobuf/templates/cel_lite_descriptor_template.txt b/protobuf/src/main/java/dev/cel/protobuf/templates/cel_lite_descriptor_template.txt new file mode 100644 index 000000000..c08d23886 --- /dev/null +++ b/protobuf/src/main/java/dev/cel/protobuf/templates/cel_lite_descriptor_template.txt @@ -0,0 +1,70 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * Generated by CEL-Java library. DO NOT EDIT! + * Version: ${version} + */ + +package ${package_name}; + +import dev.cel.protobuf.CelLiteDescriptor; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public final class ${descriptor_class_name} extends CelLiteDescriptor { + + private static final ${descriptor_class_name} DESCRIPTOR = new ${descriptor_class_name}(); + + public static ${descriptor_class_name} getDescriptor() { + return DESCRIPTOR; + } + + private static List newDescriptors() { + List descriptors = new ArrayList<>(${message_info_list?size}); + Map fieldDescriptors; + <#list message_info_list as message_info> + + fieldDescriptors = new HashMap<>(${message_info.fieldInfoMap?size}); + <#list message_info.fieldInfoMap as key, value> + fieldDescriptors.put("${key}", new FieldDescriptor( + "${value.fullyQualifiedProtoName}", + "${value.javaType}", + "${value.methodSuffixName}", + "${value.celFieldValueType}", + "${value.protoFieldType}", + "${value.hasHasser}", + "${value.fieldJavaClassName}", + "${value.fieldProtoTypeName}" + )); + + + descriptors.add( + new MessageDescriptor( + "${message_info.fullyQualifiedProtoName}", + "${message_info.fullyQualifiedProtoJavaClassName}", + Collections.unmodifiableMap(fieldDescriptors)) + ); + + + return Collections.unmodifiableList(descriptors); + } + + private ${descriptor_class_name}() { + super(newDescriptors()); + } +} \ No newline at end of file diff --git a/protobuf/src/test/java/dev/cel/protobuf/BUILD.bazel b/protobuf/src/test/java/dev/cel/protobuf/BUILD.bazel new file mode 100644 index 000000000..b7bf89ddf --- /dev/null +++ b/protobuf/src/test/java/dev/cel/protobuf/BUILD.bazel @@ -0,0 +1,46 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("//:java_lite_proto_cel_library.bzl", "java_lite_proto_cel_library") +load("//:testing.bzl", "junit4_test_suites") + +package(default_applicable_licenses = ["//:license"]) + +java_library( + name = "tests", + testonly = 1, + srcs = glob(["*.java"]), + deps = [ + ":test_all_types_proto3_java_lite_cel_proto", + "//:java_truth", + "//common", + "//common:options", + "//common/types", + "//compiler", + "//compiler:compiler_builder", + "//parser:macro", + "//protobuf:cel_lite_descriptor", + "//protobuf:proto_descriptor_collector", + "//runtime", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:com_google_testparameterinjector_test_parameter_injector", + "@maven//:junit_junit", + ], +) + +java_lite_proto_cel_library( + name = "test_all_types_proto3_java_lite_cel_proto", + descriptor_class_prefix = "TestAllTypes", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto", + ], +) + +junit4_test_suites( + name = "test_suites", + sizes = [ + "small", + ], + src_dir = "src/test/java", + deps = [":tests"], +) diff --git a/protobuf/src/test/java/dev/cel/protobuf/CelLiteDescriptorEvaluationTest.java b/protobuf/src/test/java/dev/cel/protobuf/CelLiteDescriptorEvaluationTest.java new file mode 100644 index 000000000..f4dac7960 --- /dev/null +++ b/protobuf/src/test/java/dev/cel/protobuf/CelLiteDescriptorEvaluationTest.java @@ -0,0 +1,386 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.protobuf; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.UnsignedLong; +import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.NullValue; +import com.google.protobuf.StringValue; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelOptions; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.compiler.CelCompiler; +import dev.cel.compiler.CelCompilerFactory; +import dev.cel.expr.conformance.proto3.NestedTestAllTypes; +import dev.cel.expr.conformance.proto3.TestAllTypes; +import dev.cel.expr.conformance.proto3.TestAllTypes.NestedEnum; +import dev.cel.expr.conformance.proto3.TestAllTypes.NestedMessage; +import dev.cel.expr.conformance.proto3.TestAllTypesCelLiteDescriptor; +import dev.cel.parser.CelStandardMacro; +import dev.cel.runtime.CelRuntime; +import dev.cel.runtime.CelRuntimeFactory; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class CelLiteDescriptorEvaluationTest { + private static final CelCompiler CEL_COMPILER = + CelCompilerFactory.standardCelCompilerBuilder() + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .addVar("content", SimpleType.DYN) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer("cel.expr.conformance.proto3") + .build(); + + private static final CelRuntime CEL_RUNTIME = + CelRuntimeFactory.standardCelRuntimeBuilder() + .setOptions(CelOptions.current().enableCelValue(true).build()) + .addCelLiteDescriptors(TestAllTypesCelLiteDescriptor.getDescriptor()) + .build(); + + @Test + public void messageCreation_emptyMessage() throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile("TestAllTypes{}").getAst(); + + TestAllTypes simpleTest = (TestAllTypes) CEL_RUNTIME.createProgram(ast).eval(); + + assertThat(simpleTest).isEqualTo(TestAllTypes.getDefaultInstance()); + } + + @Test + public void messageCreation_fieldsPopulated() throws Exception { + CelAbstractSyntaxTree ast = + CEL_COMPILER + .compile( + "TestAllTypes{" + + "single_int32: 4," + + "single_int64: 6," + + "single_float: 7.1," + + "single_double: 8.2," + + "single_nested_enum: TestAllTypes.NestedEnum.BAR," + + "repeated_int32: [1,2]," + + "repeated_int64: [3,4]," + + "map_string_int32: {'a': 1}," + + "map_string_int64: {'b': 2}," + + "single_int32_wrapper: google.protobuf.Int32Value{value: 9}," + + "single_int64_wrapper: google.protobuf.Int64Value{value: 10}," + + "single_float_wrapper: 11.1," + + "single_double_wrapper: 12.2," + + "single_uint32_wrapper: google.protobuf.UInt32Value{value: 13u}," + + "single_uint64_wrapper: google.protobuf.UInt64Value{value: 14u}," + + "oneof_type: NestedTestAllTypes {" + + " payload: TestAllTypes {" + + " single_bytes: b'abc'," + + " }" + + " }," + + "}") + .getAst(); + TestAllTypes expectedMessage = + TestAllTypes.newBuilder() + .setSingleInt32(4) + .setSingleInt64(6L) + .setSingleFloat(7.1f) + .setSingleDouble(8.2d) + .setSingleNestedEnum(NestedEnum.BAR) + .addAllRepeatedInt32(Arrays.asList(1, 2)) + .addAllRepeatedInt64(Arrays.asList(3L, 4L)) + .putMapStringInt32("a", 1) + .putMapStringInt64("b", 2) + .setSingleInt32Wrapper(Int32Value.of(9)) + .setSingleInt64Wrapper(Int64Value.of(10L)) + .setSingleFloatWrapper(FloatValue.of(11.1f)) + .setSingleDoubleWrapper(DoubleValue.of(12.2d)) + .setSingleUint32Wrapper(UInt32Value.of(13)) + .setSingleUint64Wrapper(UInt64Value.of(14L)) + .setOneofType( + NestedTestAllTypes.newBuilder() + .setPayload( + TestAllTypes.newBuilder().setSingleBytes(ByteString.copyFromUtf8("abc")))) + .build(); + + TestAllTypes simpleTest = (TestAllTypes) CEL_RUNTIME.createProgram(ast).eval(); + + assertThat(simpleTest).isEqualTo(expectedMessage); + } + + @Test + @TestParameters("{expression: 'msg.single_int32 == 1'}") + @TestParameters("{expression: 'msg.single_int64 == 2'}") + @TestParameters("{expression: 'msg.single_uint32 == 3u'}") + @TestParameters("{expression: 'msg.single_uint64 == 4u'}") + @TestParameters("{expression: 'msg.single_sint32 == 5'}") + @TestParameters("{expression: 'msg.single_sint64 == 6'}") + @TestParameters("{expression: 'msg.single_fixed32 == 7u'}") + @TestParameters("{expression: 'msg.single_fixed64 == 8u'}") + @TestParameters("{expression: 'msg.single_sfixed32 == 9'}") + @TestParameters("{expression: 'msg.single_sfixed64 == 10'}") + @TestParameters("{expression: 'msg.single_float == 1.5'}") + @TestParameters("{expression: 'msg.single_double == 2.5'}") + @TestParameters("{expression: 'msg.single_bool == true'}") + @TestParameters("{expression: 'msg.single_string == \"foo\"'}") + @TestParameters("{expression: 'msg.single_bytes == b\"abc\"'}") + @TestParameters("{expression: 'msg.optional_bool == true'}") + public void fieldSelection_literals(String expression) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expression).getAst(); + TestAllTypes msg = + TestAllTypes.newBuilder() + .setSingleInt32(1) + .setSingleInt64(2L) + .setSingleUint32(3) + .setSingleUint64(4L) + .setSingleSint32(5) + .setSingleSint64(6L) + .setSingleFixed32(7) + .setSingleFixed64(8L) + .setSingleSfixed32(9) + .setSingleSfixed64(10L) + .setSingleFloat(1.5f) + .setSingleDouble(2.5d) + .setSingleBool(true) + .setSingleString("foo") + .setSingleBytes(ByteString.copyFromUtf8("abc")) + .setOptionalBool(true) + .build(); + + boolean result = (boolean) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", msg)); + + assertThat(result).isTrue(); + } + + @Test + @TestParameters("{expression: 'msg.single_uint32'}") + @TestParameters("{expression: 'msg.single_uint64'}") + public void fieldSelection_unsigned(String expression) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expression).getAst(); + TestAllTypes msg = TestAllTypes.newBuilder().setSingleUint32(4).setSingleUint64(4L).build(); + + Object result = CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", msg)); + + assertThat(result).isEqualTo(UnsignedLong.valueOf(4L)); + } + + @Test + @TestParameters("{expression: 'msg.repeated_int32'}") + @TestParameters("{expression: 'msg.repeated_int64'}") + @SuppressWarnings("unchecked") + public void fieldSelection_list(String expression) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expression).getAst(); + TestAllTypes msg = + TestAllTypes.newBuilder() + .addRepeatedInt32(1) + .addRepeatedInt32(2) + .addRepeatedInt64(1L) + .addRepeatedInt64(2L) + .build(); + + List result = + (List) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", msg)); + + assertThat(result).containsExactly(1L, 2L).inOrder(); + } + + @Test + @TestParameters("{expression: 'msg.map_string_int32'}") + @TestParameters("{expression: 'msg.map_string_int64'}") + @SuppressWarnings("unchecked") + public void fieldSelection_map(String expression) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expression).getAst(); + TestAllTypes msg = + TestAllTypes.newBuilder() + .putMapStringInt32("a", 1) + .putMapStringInt32("b", 2) + .putMapStringInt64("a", 1L) + .putMapStringInt64("b", 2L) + .build(); + + Map result = + (Map) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", msg)); + + assertThat(result).containsExactly("a", 1L, "b", 2L); + } + + @Test + @TestParameters("{expression: 'msg.single_int32_wrapper == 1'}") + @TestParameters("{expression: 'msg.single_int64_wrapper == 2'}") + @TestParameters("{expression: 'msg.single_uint32_wrapper == 3u'}") + @TestParameters("{expression: 'msg.single_uint64_wrapper == 4u'}") + @TestParameters("{expression: 'msg.single_float_wrapper == 1.5'}") + @TestParameters("{expression: 'msg.single_double_wrapper == 2.5'}") + @TestParameters("{expression: 'msg.single_bool_wrapper == true'}") + @TestParameters("{expression: 'msg.single_string_wrapper == \"foo\"'}") + @TestParameters("{expression: 'msg.single_bytes_wrapper == b\"abc\"'}") + public void fieldSelection_wrappers(String expression) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expression).getAst(); + TestAllTypes msg = + TestAllTypes.newBuilder() + .setSingleInt32Wrapper(Int32Value.of(1)) + .setSingleInt64Wrapper(Int64Value.of(2L)) + .setSingleUint32Wrapper(UInt32Value.of(3)) + .setSingleUint64Wrapper(UInt64Value.of(4L)) + .setSingleFloatWrapper(FloatValue.of(1.5f)) + .setSingleDoubleWrapper(DoubleValue.of(2.5d)) + .setSingleBoolWrapper(BoolValue.of(true)) + .setSingleStringWrapper(StringValue.of("foo")) + .setSingleBytesWrapper(BytesValue.of(ByteString.copyFromUtf8("abc"))) + .build(); + + boolean result = (boolean) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", msg)); + + assertThat(result).isTrue(); + } + + @Test + @TestParameters("{expression: 'msg.single_int32_wrapper'}") + @TestParameters("{expression: 'msg.single_int64_wrapper'}") + @TestParameters("{expression: 'msg.single_uint32_wrapper'}") + @TestParameters("{expression: 'msg.single_uint64_wrapper'}") + @TestParameters("{expression: 'msg.single_float_wrapper'}") + @TestParameters("{expression: 'msg.single_double_wrapper'}") + @TestParameters("{expression: 'msg.single_bool_wrapper'}") + @TestParameters("{expression: 'msg.single_string_wrapper'}") + @TestParameters("{expression: 'msg.single_bytes_wrapper'}") + public void fieldSelection_wrappersNullability(String expression) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expression).getAst(); + TestAllTypes msg = TestAllTypes.newBuilder().build(); + + Object result = CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", msg)); + + assertThat(result).isEqualTo(NullValue.NULL_VALUE); + } + + @Test + @TestParameters("{expression: 'has(msg.single_int32)'}") + @TestParameters("{expression: 'has(msg.single_int64)'}") + @TestParameters("{expression: 'has(msg.single_int32_wrapper)'}") + @TestParameters("{expression: 'has(msg.single_int64_wrapper)'}") + @TestParameters("{expression: 'has(msg.repeated_int32)'}") + @TestParameters("{expression: 'has(msg.repeated_int64)'}") + @TestParameters("{expression: 'has(msg.repeated_int32_wrapper)'}") + @TestParameters("{expression: 'has(msg.repeated_int64_wrapper)'}") + @TestParameters("{expression: 'has(msg.map_string_int32)'}") + @TestParameters("{expression: 'has(msg.map_string_int64)'}") + @TestParameters("{expression: 'has(msg.map_bool_int32_wrapper)'}") + @TestParameters("{expression: 'has(msg.map_bool_int64_wrapper)'}") + public void presenceTest_evaluatesToFalse(String expression) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expression).getAst(); + TestAllTypes msg = + TestAllTypes.newBuilder() + .setSingleInt32(0) + .addAllRepeatedInt32(ImmutableList.of()) + .addAllRepeatedInt32Wrapper(ImmutableList.of()) + .putAllMapBoolInt32(ImmutableMap.of()) + .putAllMapBoolInt32Wrapper(ImmutableMap.of()) + .build(); + + boolean result = (boolean) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", msg)); + + assertThat(result).isFalse(); + } + + @Test + @TestParameters("{expression: 'has(msg.single_int32)'}") + @TestParameters("{expression: 'has(msg.single_int64)'}") + @TestParameters("{expression: 'has(msg.single_int32_wrapper)'}") + @TestParameters("{expression: 'has(msg.single_int64_wrapper)'}") + @TestParameters("{expression: 'has(msg.repeated_int32)'}") + @TestParameters("{expression: 'has(msg.repeated_int64)'}") + @TestParameters("{expression: 'has(msg.repeated_int32_wrapper)'}") + @TestParameters("{expression: 'has(msg.repeated_int64_wrapper)'}") + @TestParameters("{expression: 'has(msg.map_string_int32)'}") + @TestParameters("{expression: 'has(msg.map_string_int64)'}") + @TestParameters("{expression: 'has(msg.map_string_int32_wrapper)'}") + @TestParameters("{expression: 'has(msg.map_string_int64_wrapper)'}") + public void presenceTest_evaluatesToTrue(String expression) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expression).getAst(); + TestAllTypes msg = + TestAllTypes.newBuilder() + .setSingleInt32(1) + .setSingleInt64(2) + .setSingleInt32Wrapper(Int32Value.of(0)) + .setSingleInt64Wrapper(Int64Value.of(0)) + .addAllRepeatedInt32(ImmutableList.of(1)) + .addAllRepeatedInt64(ImmutableList.of(2L)) + .addAllRepeatedInt32Wrapper(ImmutableList.of(Int32Value.of(0))) + .addAllRepeatedInt64Wrapper(ImmutableList.of(Int64Value.of(0L))) + .putAllMapStringInt32Wrapper(ImmutableMap.of("a", Int32Value.of(1))) + .putAllMapStringInt64Wrapper(ImmutableMap.of("b", Int64Value.of(2L))) + .putMapStringInt32("a", 1) + .putMapStringInt64("b", 2) + .build(); + + boolean result = (boolean) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", msg)); + + assertThat(result).isTrue(); + } + + @Test + public void nestedMessage() throws Exception { + CelAbstractSyntaxTree ast = + CEL_COMPILER + .compile("msg.single_nested_message.bb == 43 && has(msg.single_nested_message)") + .getAst(); + TestAllTypes nestedMessage = + TestAllTypes.newBuilder() + .setSingleNestedMessage(NestedMessage.newBuilder().setBb(43)) + .build(); + + boolean result = + (boolean) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", nestedMessage)); + + assertThat(result).isTrue(); + } + + @Test + public void enumSelection() throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile("msg.single_nested_enum").getAst(); + TestAllTypes nestedMessage = + TestAllTypes.newBuilder().setSingleNestedEnum(NestedEnum.BAR).build(); + + Long result = (Long) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", nestedMessage)); + + assertThat(result).isEqualTo(NestedEnum.BAR.getNumber()); + } + + @Test + public void anyMessage_packUnpack() throws Exception { + CelAbstractSyntaxTree ast = + CEL_COMPILER.compile("TestAllTypes { single_any: content }.single_any").getAst(); + TestAllTypes content = TestAllTypes.newBuilder().setSingleInt64(1L).build(); + + TestAllTypes result = + (TestAllTypes) CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("content", content)); + + assertThat(result).isEqualTo(content); + } +} diff --git a/protobuf/src/test/java/dev/cel/protobuf/ProtoDescriptorCollectorTest.java b/protobuf/src/test/java/dev/cel/protobuf/ProtoDescriptorCollectorTest.java new file mode 100644 index 000000000..619393105 --- /dev/null +++ b/protobuf/src/test/java/dev/cel/protobuf/ProtoDescriptorCollectorTest.java @@ -0,0 +1,50 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.protobuf; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.expr.conformance.proto3.TestAllTypes; +import dev.cel.protobuf.CelLiteDescriptor.MessageDescriptor; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class ProtoDescriptorCollectorTest { + + private static final ProtoDescriptorCollector COLLECTOR = new ProtoDescriptorCollector(); + + @Test + public void collectMessageInfo() throws Exception { + ImmutableList messageInfoList = + COLLECTOR.collectMessageInfo(TestAllTypes.getDescriptor().getFile()); + + assertThat(messageInfoList.size()).isEqualTo(3); + assertThat(messageInfoList.get(0).getFullyQualifiedProtoName()) + .isEqualTo("cel.expr.conformance.proto3.TestAllTypes"); + assertThat(messageInfoList.get(0).getFullyQualifiedProtoJavaClassName()) + .isEqualTo("dev.cel.expr.conformance.proto3.TestAllTypes"); + assertThat(messageInfoList.get(1).getFullyQualifiedProtoName()) + .isEqualTo("cel.expr.conformance.proto3.TestAllTypes.NestedMessage"); + assertThat(messageInfoList.get(1).getFullyQualifiedProtoJavaClassName()) + .isEqualTo("dev.cel.expr.conformance.proto3.TestAllTypes$NestedMessage"); + assertThat(messageInfoList.get(2).getFullyQualifiedProtoName()) + .isEqualTo("cel.expr.conformance.proto3.NestedTestAllTypes"); + assertThat(messageInfoList.get(2).getFullyQualifiedProtoJavaClassName()) + .isEqualTo("dev.cel.expr.conformance.proto3.NestedTestAllTypes"); + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/Activation.java b/runtime/src/main/java/dev/cel/runtime/Activation.java index 0ea70662e..b7757e771 100644 --- a/runtime/src/main/java/dev/cel/runtime/Activation.java +++ b/runtime/src/main/java/dev/cel/runtime/Activation.java @@ -23,7 +23,6 @@ import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Message; import dev.cel.common.CelOptions; -import dev.cel.common.ExprFeatures; import dev.cel.common.annotations.Internal; import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.internal.DynamicProto; @@ -137,24 +136,6 @@ public String toString() { }; } - /** - * Creates an {@code Activation} from a {@code Message} where each field in the message is exposed - * as a top-level variable in the {@code Activation}. - * - *

Unset message fields are published with the default value for the field type. However, an - * unset {@code google.protobuf.Any} value is not a valid CEL value, and will be published as an - * {@code Exception} value on the {@code Activation} just as though an unset {@code Any} would if - * it were accessed during a CEL evaluation. - * - *

Note, this call does not support unsigned integer fields properly and encodes them as long - * values. If {@link ExprFeatures#ENABLE_UNSIGNED_LONGS} is in use, use {@link #fromProto(Message, - * CelOptions)} to ensure that the message fields are properly designated as {@code UnsignedLong} - * values. - */ - public static Activation fromProto(Message message) { - return fromProto(message, CelOptions.LEGACY); - } - /** * Creates an {@code Activation} from a {@code Message} where each field in the message is exposed * as a top-level variable in the {@code Activation}. diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 93e8fe8e9..ec5af8a4b 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -173,13 +173,19 @@ java_library( "//common:options", "//common/annotations", "//common/internal:cel_descriptor_pools", + "//common/internal:cel_lite_descriptor_pool", "//common/internal:comparison_functions", "//common/internal:default_message_factory", "//common/internal:dynamic_proto", + "//common/internal:proto_lite_adapter", "//common/internal:proto_message_factory", "//common/types:cel_types", "//common/values:cel_value_provider", + "//common/values:proto_message_lite_value", + "//common/values:proto_message_lite_value_provider", + "//common/values:proto_message_value", "//common/values:proto_message_value_provider", + "//protobuf:cel_lite_descriptor", "//runtime:interpreter", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", @@ -240,15 +246,12 @@ java_library( deps = [ ":unknown_attributes", "//common:error_codes", - "//common:options", "//common:runtime_exception", "//common/annotations", - "//common/internal:cel_descriptor_pools", - "//common/internal:dynamic_proto", "//common/values", + "//common/values:base_proto_cel_value_converter", "//common/values:cel_value", "//common/values:cel_value_provider", - "//common/values:proto_message_value", "//runtime:interpreter", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java index 4ca72b674..36b9a7cf8 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java @@ -23,6 +23,7 @@ import com.google.protobuf.Message; import dev.cel.common.CelOptions; import dev.cel.common.values.CelValueProvider; +import dev.cel.protobuf.CelLiteDescriptor; import java.util.function.Function; /** Interface for building an instance of CelRuntime */ @@ -78,6 +79,12 @@ public interface CelRuntimeBuilder { @CanIgnoreReturnValue CelRuntimeBuilder addMessageTypes(Iterable descriptors); + @CanIgnoreReturnValue + CelRuntimeBuilder addCelLiteDescriptors(CelLiteDescriptor... descriptors); + + @CanIgnoreReturnValue + CelRuntimeBuilder addCelLiteDescriptors(Iterable descriptors); + /** * Add {@link FileDescriptor}s to the use for type-checking, and for object creation at * interpretation time. diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java index dbfaf98c3..dbfc55057 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java @@ -34,15 +34,21 @@ import dev.cel.common.CelOptions; import dev.cel.common.annotations.Internal; import dev.cel.common.internal.CelDescriptorPool; +import dev.cel.common.internal.CelLiteDescriptorPool; import dev.cel.common.internal.CombinedDescriptorPool; import dev.cel.common.internal.DefaultDescriptorPool; import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.internal.DynamicProto; // CEL-Internal-3 +import dev.cel.common.internal.ProtoLiteAdapter; import dev.cel.common.internal.ProtoMessageFactory; import dev.cel.common.types.CelTypes; import dev.cel.common.values.CelValueProvider; +import dev.cel.common.values.ProtoCelValueConverter; +import dev.cel.common.values.ProtoLiteCelValueConverter; +import dev.cel.common.values.ProtoMessageLiteValueProvider; import dev.cel.common.values.ProtoMessageValueProvider; +import dev.cel.protobuf.CelLiteDescriptor; import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Arithmetic; import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Comparison; import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Conversions; @@ -92,6 +98,7 @@ public static final class Builder implements CelRuntimeBuilder { private final ImmutableSet.Builder fileTypes; private final HashMap customFunctionBindings; private final ImmutableSet.Builder celRuntimeLibraries; + private final ImmutableSet.Builder celLiteDescriptorBuilder; @SuppressWarnings("unused") private CelOptions options; @@ -129,6 +136,17 @@ public CelRuntimeBuilder addMessageTypes(Iterable descriptors) { return addFileTypes(CelDescriptorUtil.getFileDescriptorsForDescriptors(descriptors)); } + @Override + public CelRuntimeBuilder addCelLiteDescriptors(CelLiteDescriptor... descriptors) { + return addCelLiteDescriptors(Arrays.asList(descriptors)); + } + + @Override + public CelRuntimeBuilder addCelLiteDescriptors(Iterable descriptors) { + this.celLiteDescriptorBuilder.addAll(descriptors); + return this; + } + @Override public CelRuntimeBuilder addFileTypes(FileDescriptor... fileDescriptors) { return addFileTypes(Arrays.asList(fileDescriptors)); @@ -273,16 +291,44 @@ public CelRuntimeLegacyImpl build() { RuntimeTypeProvider runtimeTypeProvider; if (options.enableCelValue()) { - CelValueProvider messageValueProvider = - ProtoMessageValueProvider.newInstance(dynamicProto, options); - if (celValueProvider != null) { - messageValueProvider = - new CelValueProvider.CombinedCelValueProvider(celValueProvider, messageValueProvider); + ImmutableSet liteDescriptors = celLiteDescriptorBuilder.build(); + if (liteDescriptors.isEmpty()) { + CelValueProvider messageValueProvider = + ProtoMessageValueProvider.newInstance(dynamicProto, options); + if (celValueProvider != null) { + messageValueProvider = + new CelValueProvider.CombinedCelValueProvider( + celValueProvider, messageValueProvider); + } + + ProtoCelValueConverter protoCelValueConverter = + ProtoCelValueConverter.newInstance(options, celDescriptorPool, dynamicProto); + + runtimeTypeProvider = + new RuntimeTypeProviderLegacyImpl(messageValueProvider, protoCelValueConverter); + } else { + CelLiteDescriptorPool celLiteDescriptorPool = + CelLiteDescriptorPool.newInstance(liteDescriptors); + + // TODO: instantiate these dependencies within ProtoMessageLiteValueProvider. + // For now, they need to be outside to instantiate the RuntimeTypeProviderLegacyImpl + // adapter. + ProtoLiteAdapter protoLiteAdapter = new ProtoLiteAdapter(options.enableUnsignedLongs()); + ProtoLiteCelValueConverter protoLiteCelValueConverter = + ProtoLiteCelValueConverter.newInstance(options, celLiteDescriptorPool); + CelValueProvider messageValueProvider = + ProtoMessageLiteValueProvider.newInstance( + protoLiteCelValueConverter, protoLiteAdapter, celLiteDescriptorPool); + if (celValueProvider != null) { + messageValueProvider = + new CelValueProvider.CombinedCelValueProvider( + celValueProvider, messageValueProvider); + } + + runtimeTypeProvider = + new RuntimeTypeProviderLegacyImpl(messageValueProvider, protoLiteCelValueConverter); } - runtimeTypeProvider = - new RuntimeTypeProviderLegacyImpl( - options, messageValueProvider, celDescriptorPool, dynamicProto); } else { runtimeTypeProvider = new DescriptorMessageProvider(runtimeTypeFactory, options); } @@ -375,6 +421,7 @@ private Builder() { this.fileTypes = ImmutableSet.builder(); this.customFunctionBindings = new HashMap<>(); this.celRuntimeLibraries = ImmutableSet.builder(); + this.celLiteDescriptorBuilder = ImmutableSet.builder(); this.extensionRegistry = ExtensionRegistry.getEmptyRegistry(); this.customTypeFactory = null; } @@ -391,6 +438,7 @@ private Builder(Builder builder) { this.fileTypes = deepCopy(builder.fileTypes); this.celRuntimeLibraries = deepCopy(builder.celRuntimeLibraries); this.customFunctionBindings = new HashMap<>(builder.customFunctionBindings); + this.celLiteDescriptorBuilder = builder.celLiteDescriptorBuilder; } private static ImmutableSet.Builder deepCopy(ImmutableSet.Builder builderToCopy) { diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeTypeProviderLegacyImpl.java b/runtime/src/main/java/dev/cel/runtime/RuntimeTypeProviderLegacyImpl.java index 36bd054f3..537cd0f15 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeTypeProviderLegacyImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeTypeProviderLegacyImpl.java @@ -17,14 +17,11 @@ import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelErrorCode; -import dev.cel.common.CelOptions; import dev.cel.common.CelRuntimeException; import dev.cel.common.annotations.Internal; -import dev.cel.common.internal.CelDescriptorPool; -import dev.cel.common.internal.DynamicProto; +import dev.cel.common.values.BaseProtoCelValueConverter; import dev.cel.common.values.CelValue; import dev.cel.common.values.CelValueProvider; -import dev.cel.common.values.ProtoCelValueConverter; import dev.cel.common.values.SelectableValue; import dev.cel.common.values.StringValue; import java.util.Map; @@ -36,17 +33,13 @@ public final class RuntimeTypeProviderLegacyImpl implements RuntimeTypeProvider { private final CelValueProvider valueProvider; - private final ProtoCelValueConverter protoCelValueConverter; + private final BaseProtoCelValueConverter protoCelValueConverter; @VisibleForTesting public RuntimeTypeProviderLegacyImpl( - CelOptions celOptions, - CelValueProvider valueProvider, - CelDescriptorPool celDescriptorPool, - DynamicProto dynamicProto) { + CelValueProvider valueProvider, BaseProtoCelValueConverter protoCelValueConverter) { this.valueProvider = valueProvider; - this.protoCelValueConverter = - ProtoCelValueConverter.newInstance(celOptions, celDescriptorPool, dynamicProto); + this.protoCelValueConverter = protoCelValueConverter; } @Override diff --git a/runtime/src/test/java/dev/cel/runtime/ActivationTest.java b/runtime/src/test/java/dev/cel/runtime/ActivationTest.java index fc7d4a9ac..c4271104a 100644 --- a/runtime/src/test/java/dev/cel/runtime/ActivationTest.java +++ b/runtime/src/test/java/dev/cel/runtime/ActivationTest.java @@ -164,7 +164,8 @@ public void fromProto_unsignedLongField_signedResult() { TestAllTypes.newBuilder() .setSingleUint32(1) .setSingleUint64(UnsignedLong.MAX_VALUE.longValue()) - .build()); + .build(), + CelOptions.LEGACY); assertThat((Long) activation.resolve("single_uint32")).isEqualTo(1L); assertThat((Long) activation.resolve("single_uint64")).isEqualTo(-1L); } diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 8aaa60ed1..c1c906ee6 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -1,4 +1,5 @@ load("@rules_java//java:defs.bzl", "java_library") +load("//:java_lite_proto_cel_library.bzl", "java_lite_proto_cel_library") load("//:testing.bzl", "junit4_test_suites") package(default_applicable_licenses = ["//:license"]) @@ -10,6 +11,7 @@ java_library( ["*.java"], exclude = [ "CelValueInterpreterTest.java", + "CelLiteDescriptorInterpreterTest.java", "InterpreterTest.java", ], ), @@ -88,6 +90,30 @@ java_library( ], ) +java_library( + name = "cel_lite_descriptor_interpreter_test", + testonly = 1, + srcs = [ + "CelLiteDescriptorInterpreterTest.java", + ], + deps = [ + ":test_all_types_proto3_java_lite_cel_proto", + "//extensions:optional_library", + "//runtime", + "//testing:base_interpreter_test", + "@maven//:com_google_testparameterinjector_test_parameter_injector", + "@maven//:junit_junit", + ], +) + +java_lite_proto_cel_library( + name = "test_all_types_proto3_java_lite_cel_proto", + descriptor_class_prefix = "TestAllTypes", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto", + ], +) + junit4_test_suites( name = "test_suites", shard_count = 4, @@ -97,6 +123,7 @@ junit4_test_suites( ], src_dir = "src/test/java", deps = [ + ":cel_lite_descriptor_interpreter_test", ":cel_value_interpreter_test", ":interpreter_test", ":tests", diff --git a/runtime/src/test/java/dev/cel/runtime/CelLiteDescriptorInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/CelLiteDescriptorInterpreterTest.java new file mode 100644 index 000000000..bf752148a --- /dev/null +++ b/runtime/src/test/java/dev/cel/runtime/CelLiteDescriptorInterpreterTest.java @@ -0,0 +1,48 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime; + +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.expr.conformance.proto3.TestAllTypesCelLiteDescriptor; +import dev.cel.extensions.CelOptionalLibrary; +import dev.cel.testing.BaseInterpreterTest; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class CelLiteDescriptorInterpreterTest extends BaseInterpreterTest { + public CelLiteDescriptorInterpreterTest(@TestParameter InterpreterTestOption testOption) { + super( + testOption.celOptions.toBuilder().enableCelValue(true).build(), + testOption.useNativeCelType, + CelRuntimeFactory.standardCelRuntimeBuilder() + .addCelLiteDescriptors(TestAllTypesCelLiteDescriptor.getDescriptor()) + .addLibraries(CelOptionalLibrary.INSTANCE) + .setOptions(testOption.celOptions.toBuilder().enableCelValue(true).build()) + .build()); + } + + @Override + public void dynamicMessage_adapted() throws Exception { + // Dynamic message is not supported in Protolite + skipBaselineVerification(); + } + + @Override + public void dynamicMessage_dynamicDescriptor() throws Exception { + // Dynamic message is not supported in Protolite + skipBaselineVerification(); + } +} diff --git a/runtime/src/test/java/dev/cel/runtime/DescriptorMessageProviderTest.java b/runtime/src/test/java/dev/cel/runtime/DescriptorMessageProviderTest.java index b9092dd9e..5dcab8c10 100644 --- a/runtime/src/test/java/dev/cel/runtime/DescriptorMessageProviderTest.java +++ b/runtime/src/test/java/dev/cel/runtime/DescriptorMessageProviderTest.java @@ -207,7 +207,8 @@ public void createMessage_wellKnownType_withCustomMessageProvider( return; } - Descriptor wellKnownDescriptor = wellKnownProto.descriptor(); + Descriptor wellKnownDescriptor = + DefaultDescriptorPool.INSTANCE.findDescriptor(wellKnownProto.typeName()).get(); DescriptorMessageProvider messageProvider = new DescriptorMessageProvider( msgName -> diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index a14d359ac..0ef072b3b 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -126,14 +126,21 @@ protected enum InterpreterTestOption { private CelRuntime celRuntime; public BaseInterpreterTest(CelOptions celOptions, boolean useNativeCelType) { - super(useNativeCelType); - this.celOptions = celOptions; - this.celRuntime = + this( + celOptions, + useNativeCelType, CelRuntimeFactory.standardCelRuntimeBuilder() .addLibraries(CelOptionalLibrary.INSTANCE) .addFileTypes(TEST_FILE_DESCRIPTORS) .setOptions(celOptions) - .build(); + .build()); + } + + public BaseInterpreterTest( + CelOptions celOptions, boolean useNativeCelType, CelRuntime celRuntime) { + super(useNativeCelType); + this.celOptions = celOptions; + this.celRuntime = celRuntime; } @Override