diff --git a/core/src/main/java/org/infinispan/protostream/WrappedMessage.java b/core/src/main/java/org/infinispan/protostream/WrappedMessage.java index 61d4fe857..27179e028 100644 --- a/core/src/main/java/org/infinispan/protostream/WrappedMessage.java +++ b/core/src/main/java/org/infinispan/protostream/WrappedMessage.java @@ -1,11 +1,5 @@ package org.infinispan.protostream; -import java.io.IOException; -import java.time.Instant; -import java.util.Date; -import java.util.Iterator; -import java.util.Objects; - import org.infinispan.protostream.containers.ElementContainerAdapter; import org.infinispan.protostream.containers.IndexedElementContainerAdapter; import org.infinispan.protostream.containers.IterableElementContainerAdapter; @@ -17,6 +11,12 @@ import org.infinispan.protostream.impl.TagReaderImpl; import org.infinispan.protostream.impl.TagWriterImpl; +import java.io.IOException; +import java.time.Instant; +import java.util.Date; +import java.util.Iterator; +import java.util.Objects; + /** * A wrapper for messages, enums or primitive types that encodes the type of the inner object/value and also helps keep * track of where the message ends. The need for this wrapper stems from two particular design choices in the Protocol @@ -244,27 +244,32 @@ static void write(ImmutableSerializationContext ctx, TagWriter out, Object t) th } private static void writeMessage(ImmutableSerializationContext ctx, TagWriter out, Object t, boolean nulls) throws IOException { + if (tryWritePrimitive(out, t, nulls)) { + return; + } + writeCustomObject(ctx, out, t); + } + + private static boolean tryWritePrimitive(TagWriter out, Object t, boolean nulls) throws IOException { + // primitives or single tag objects (Date) are written by this method. if (t == null) { if (nulls) { out.writeBool(WRAPPED_EMPTY, true); + out.flush(); } - return; + return true; } if (t instanceof String) { out.writeString(WRAPPED_STRING, (String) t); } else if (t instanceof Character) { - out.writeInt32(WRAPPED_CHAR, ((Character) t).charValue()); + out.writeInt32(WRAPPED_CHAR, (Character) t); } else if (t instanceof Byte) { - out.writeInt32(WRAPPED_BYTE, ((Byte) t).byteValue()); + out.writeInt32(WRAPPED_BYTE, (Byte) t); } else if (t instanceof Short) { - out.writeInt32(WRAPPED_SHORT, ((Short) t).shortValue()); + out.writeInt32(WRAPPED_SHORT, (Short) t); } else if (t instanceof Date) { out.writeInt64(WRAPPED_DATE_MILLIS, ((Date) t).getTime()); - } else if (t instanceof Instant) { - Instant instant = (Instant) t; - out.writeInt64(WRAPPED_INSTANT_SECONDS, instant.getEpochSecond()); - out.writeInt32(WRAPPED_INSTANT_NANOS, instant.getNano()); } else if (t instanceof Long) { out.writeInt64(WRAPPED_INT64, (Long) t); } else if (t instanceof Integer) { @@ -278,31 +283,43 @@ private static void writeMessage(ImmutableSerializationContext ctx, TagWriter ou } else if (t instanceof byte[]) { out.writeBytes(WRAPPED_BYTES, (byte[]) t); } else { - // This is either a message type or an enum. Try to lookup a marshaller. - BaseMarshallerDelegate marshallerDelegate = ((SerializationContextImpl) ctx).getMarshallerDelegate(t); - BaseMarshaller marshaller = marshallerDelegate.getMarshaller(); + return false; + } + out.flush(); + return true; + } - if (marshaller instanceof ElementContainerAdapter) { - writeContainer(ctx, out, marshallerDelegate, t); + private static void writeCustomObject(ImmutableSerializationContext ctx, TagWriter out, T t) throws IOException { + if (t instanceof Instant instant) { + out.writeInt64(WRAPPED_INSTANT_SECONDS, instant.getEpochSecond()); + out.writeInt32(WRAPPED_INSTANT_NANOS, instant.getNano()); + out.flush(); + return; + } + // This is either a message type or an enum. Try to lookup a marshaller. + BaseMarshallerDelegate marshallerDelegate = ((SerializationContextImpl) ctx).getMarshallerDelegate(t); + BaseMarshaller marshaller = marshallerDelegate.getMarshaller(); + + if (marshaller instanceof ElementContainerAdapter) { + writeContainer(ctx, out, marshallerDelegate, t); + } else { + // Write the type discriminator, either the fully qualified name or a numeric type id. + String typeName = marshaller.getTypeName(); + int typeId = mapTypeIdOut(typeName, ctx); + if (typeId < 0) { + out.writeString(WRAPPED_TYPE_NAME, typeName); } else { - // Write the type discriminator, either the fully qualified name or a numeric type id. - String typeName = marshaller.getTypeName(); - int typeId = mapTypeIdOut(typeName, ctx); - if (typeId < 0) { - out.writeString(WRAPPED_TYPE_NAME, typeName); - } else { - out.writeUInt32(WRAPPED_TYPE_ID, typeId); - } + out.writeUInt32(WRAPPED_TYPE_ID, typeId); + } - if (t.getClass().isEnum()) { - ((EnumMarshallerDelegate) marshallerDelegate).encode(WRAPPED_ENUM, (Enum) t, out); - } else { - ByteArrayOutputStreamEx buffer = new ByteArrayOutputStreamEx(); - TagWriterImpl nestedCtx = TagWriterImpl.newInstanceNoBuffer(ctx, buffer); - marshallerDelegate.marshall(nestedCtx, null, t); - nestedCtx.flush(); - out.writeBytes(WRAPPED_MESSAGE, buffer.getByteBuffer()); - } + if (t.getClass().isEnum()) { + ((EnumMarshallerDelegate) marshallerDelegate).encode(WRAPPED_ENUM, (Enum) t, out); + } else { + ByteArrayOutputStreamEx buffer = new ByteArrayOutputStreamEx(); + TagWriterImpl nestedCtx = TagWriterImpl.newInstanceNoBuffer(ctx, buffer); + marshallerDelegate.marshall(nestedCtx, null, t); + nestedCtx.flush(); + out.writeBytes(WRAPPED_MESSAGE, buffer.getByteBuffer()); } } out.flush(); @@ -328,11 +345,39 @@ private static void writeContainer(ImmutableSerializationContext ctx, TagWriter nestedCtx.flush(); out.writeBytes(WRAPPED_CONTAINER_MESSAGE, buffer.getByteBuffer()); + if (ctx.getConfiguration().wrapCollectionElements()) { + writeContainerWrappingElements(containerMarshaller, containerSize, container, ctx, out, buffer); + } else { + writeContainerWithoutWrappingElements(containerMarshaller, containerSize, container, ctx, out); + } + } + + private static void writeContainerWrappingElements(BaseMarshaller containerMarshaller, int containerSize, Object container, + ImmutableSerializationContext ctx, TagWriter out, ByteArrayOutputStreamEx buffer) throws IOException { + if (containerMarshaller instanceof IterableElementContainerAdapter) { + Iterator elements = ((IterableElementContainerAdapter) containerMarshaller).getElements(container); + for (int i = 0; i < containerSize; i++) { + writeContainerElementWrapped(ctx, out, buffer, elements.next()); + } + if (elements.hasNext()) { + throw new IllegalStateException("Container number of elements mismatch"); + } + } else if (containerMarshaller instanceof IndexedElementContainerAdapter) { + IndexedElementContainerAdapter adapter = (IndexedElementContainerAdapter) containerMarshaller; + for (int i = 0; i < containerSize; i++) { + writeContainerElementWrapped(ctx, out, buffer, adapter.getElement(container, i)); + } + } else { + throw new IllegalStateException("Unknown container adapter kind : " + containerMarshaller.getJavaClass().getName()); + } + } + + private static void writeContainerWithoutWrappingElements(BaseMarshaller containerMarshaller, int containerSize, Object container, + ImmutableSerializationContext ctx, TagWriter out) throws IOException { if (containerMarshaller instanceof IterableElementContainerAdapter) { Iterator elements = ((IterableElementContainerAdapter) containerMarshaller).getElements(container); for (int i = 0; i < containerSize; i++) { - Object e = elements.next(); - writeMessage(ctx, out, e, true); + writeMessage(ctx, out, elements.next(), true); } if (elements.hasNext()) { throw new IllegalStateException("Container number of elements mismatch"); @@ -340,30 +385,142 @@ private static void writeContainer(ImmutableSerializationContext ctx, TagWriter } else if (containerMarshaller instanceof IndexedElementContainerAdapter) { IndexedElementContainerAdapter adapter = (IndexedElementContainerAdapter) containerMarshaller; for (int i = 0; i < containerSize; i++) { - Object e = adapter.getElement(container, i); - writeMessage(ctx, out, e, true); + writeMessage(ctx, out, adapter.getElement(container, i), true); } } else { throw new IllegalStateException("Unknown container adapter kind : " + containerMarshaller.getJavaClass().getName()); } } + private static void writeContainerElementWrapped(ImmutableSerializationContext ctx, TagWriter out, ByteArrayOutputStreamEx buffer, Object e) throws IOException { + if (tryWritePrimitive(out, e, true)) { + return; + } + buffer.reset(); + TagWriterImpl elementWriter = TagWriterImpl.newInstanceNoBuffer(ctx, buffer); + writeMessage(ctx, elementWriter, e, true); + elementWriter.flush(); + out.writeBytes(WRAPPED_MESSAGE, buffer.getByteBuffer()); + } + static T read(ImmutableSerializationContext ctx, TagReader in) throws IOException { return readMessage(ctx, in, false); } private static T readMessage(ImmutableSerializationContext ctx, TagReader in, boolean nulls) throws IOException { + ValueOrTag primitiveValue = tryReadPrimitive(in, nulls); + if (primitiveValue.hasValue()) { + return primitiveValue.getValue(); + } + + assert primitiveValue.hasTag(); + return readCustomObject(primitiveValue.getTag(), ctx, in); + } + + private static ValueOrTag tryReadPrimitive(TagReader in, boolean nulls) throws IOException { + var tag = in.readTag(); + Object value = null; + switch (tag) { + case WRAPPED_EMPTY << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + if (!nulls) { + throw new IllegalStateException("Encountered a null message but nulls are not accepted"); + } + in.readBool(); // We ignore the actual boolean value! Will be returning null anyway. + break; + } + case WRAPPED_STRING << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: { + value = in.readString(); + break; + } + case WRAPPED_CHAR << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = (char) in.readInt32(); + break; + } + case WRAPPED_SHORT << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = (short) in.readInt32(); + break; + } + case WRAPPED_BYTE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = (byte) in.readInt32(); + break; + } + case WRAPPED_DATE_MILLIS << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = new Date(in.readInt64()); + break; + } + case WRAPPED_BYTES << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: { + value = in.readByteArray(); + break; + } + case WRAPPED_BOOL << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = in.readBool(); + break; + } + case WRAPPED_DOUBLE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64: { + value = in.readDouble(); + break; + } + case WRAPPED_FLOAT << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32: { + value = in.readFloat(); + break; + } + case WRAPPED_FIXED32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32: { + value = in.readFixed32(); + break; + } + case WRAPPED_SFIXED32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32: { + value = in.readSFixed32(); + break; + } + case WRAPPED_FIXED64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64: { + value = in.readFixed64(); + break; + } + case WRAPPED_SFIXED64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64: { + value = in.readSFixed64(); + break; + } + case WRAPPED_INT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = in.readInt64(); + break; + } + case WRAPPED_UINT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = in.readUInt64(); + break; + } + case WRAPPED_SINT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = in.readSInt64(); + break; + } + case WRAPPED_INT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = in.readInt32(); + break; + } + case WRAPPED_UINT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = in.readUInt32(); + break; + } + case WRAPPED_SINT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { + value = in.readSInt32(); + break; + } + case 0: + return new Value<>(null); + default: + return new Tag<>(tag); + } + return new Value<>((T) value); + } + + private static T readCustomObject(int tag, ImmutableSerializationContext ctx, TagReader in) throws IOException { String typeName = null; Integer typeId = null; int enumValue = -1; byte[] messageBytes = null; Object value = null; int fieldCount = 0; - int expectedFieldCount = 1; - - int tag; - out: - while ((tag = in.readTag()) != 0) { + int expectedFieldCount; + do { fieldCount++; switch (tag) { case WRAPPED_CONTAINER_SIZE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: @@ -372,15 +529,7 @@ private static T readMessage(ImmutableSerializationContext ctx, TagReader in case WRAPPED_CONTAINER_MESSAGE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: { expectedFieldCount = 1; value = readContainer(ctx, in, tag); - break out; - } - case WRAPPED_EMPTY << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - if (!nulls) { - throw new IllegalStateException("Encountered a null message but nulls are not accepted"); - } - expectedFieldCount = 1; - in.readBool(); // We ignore the actual boolean value! Will be returning null anyway. - break out; + break; } case WRAPPED_TYPE_NAME << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: { expectedFieldCount = 2; @@ -402,31 +551,6 @@ private static T readMessage(ImmutableSerializationContext ctx, TagReader in messageBytes = in.readByteArray(); break; } - case WRAPPED_STRING << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: { - expectedFieldCount = 1; - value = in.readString(); - break out; - } - case WRAPPED_CHAR << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = (char) in.readInt32(); - break out; - } - case WRAPPED_SHORT << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = (short) in.readInt32(); - break out; - } - case WRAPPED_BYTE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = (byte) in.readInt32(); - break out; - } - case WRAPPED_DATE_MILLIS << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = new Date(in.readInt64()); - break out; - } case WRAPPED_INSTANT_SECONDS << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { expectedFieldCount = 2; long seconds = in.readInt64(); @@ -439,81 +563,11 @@ private static T readMessage(ImmutableSerializationContext ctx, TagReader in value = value == null ? Instant.ofEpochSecond(0, nanos) : Instant.ofEpochSecond(((Instant) value).getEpochSecond(), nanos); break; } - case WRAPPED_BYTES << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: { - expectedFieldCount = 1; - value = in.readByteArray(); - break out; - } - case WRAPPED_BOOL << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = in.readBool(); - break out; - } - case WRAPPED_DOUBLE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64: { - expectedFieldCount = 1; - value = in.readDouble(); - break out; - } - case WRAPPED_FLOAT << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32: { - expectedFieldCount = 1; - value = in.readFloat(); - break out; - } - case WRAPPED_FIXED32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32: { - expectedFieldCount = 1; - value = in.readFixed32(); - break out; - } - case WRAPPED_SFIXED32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32: { - expectedFieldCount = 1; - value = in.readSFixed32(); - break out; - } - case WRAPPED_FIXED64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64: { - expectedFieldCount = 1; - value = in.readFixed64(); - break out; - } - case WRAPPED_SFIXED64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64: { - expectedFieldCount = 1; - value = in.readSFixed64(); - break out; - } - case WRAPPED_INT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = in.readInt64(); - break out; - } - case WRAPPED_UINT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = in.readUInt64(); - break out; - } - case WRAPPED_SINT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = in.readSInt64(); - break out; - } - case WRAPPED_INT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = in.readInt32(); - break out; - } - case WRAPPED_UINT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = in.readUInt32(); - break out; - } - case WRAPPED_SINT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT: { - expectedFieldCount = 1; - value = in.readSInt32(); - break out; - } default: throw new IllegalStateException("Unexpected tag : " + tag + " (Field number : " - + WireType.getTagFieldNumber(tag) + ", Wire type : " + WireType.getTagWireType(tag) + ")"); + + WireType.getTagFieldNumber(tag) + ", Wire type : " + WireType.getTagWireType(tag) + ")"); } - } + } while ((tag = in.readTag()) != 0); if (value == null && typeName == null && typeId == null && messageBytes == null) { return null; @@ -533,14 +587,14 @@ private static T readMessage(ImmutableSerializationContext ctx, TagReader in if (typeId != null) { typeName = ctx.getDescriptorByTypeId(typeId).getFullName(); } - BaseMarshallerDelegate marshallerDelegate = ((SerializationContextImpl) ctx).getMarshallerDelegate(typeName); + BaseMarshallerDelegate marshallerDelegate = ((SerializationContextImpl) ctx).getMarshallerDelegate(typeName); if (messageBytes != null) { // it's a Message type TagReaderImpl nestedInput = TagReaderImpl.newInstance(ctx, messageBytes); - return (T) marshallerDelegate.unmarshall(nestedInput, null); + return marshallerDelegate.unmarshall(nestedInput, null); } else { // it's an Enum - EnumMarshaller marshaller = (EnumMarshaller) marshallerDelegate.getMarshaller(); + EnumMarshaller marshaller = (EnumMarshaller) marshallerDelegate.getMarshaller(); T e = (T) marshaller.decode(enumValue); if (e == null) { // Unknown enum value cause by schema evolution. We cannot handle data loss here so we throw! @@ -609,23 +663,58 @@ private static Object readContainer(ImmutableSerializationContext ctx, TagReader throw new IllegalStateException("The unmarshalled container must not be null"); } - if (containerMarshaller instanceof IterableElementContainerAdapter) { - IterableElementContainerAdapter adapter = (IterableElementContainerAdapter) containerMarshaller; - for (int i = 0; i < containerSize; i++) { - Object e = readMessage(ctx, in, true); - adapter.appendElement(container, e); + if (ctx.getConfiguration().wrapCollectionElements()) { + readContainerWithWrappedElements(containerMarshaller, containerSize, container, ctx, in); + } else { + readContainerWithoutWrappedElements(containerMarshaller, containerSize, container, ctx, in); + } + + return container; + } + + private static void readContainerWithWrappedElements(BaseMarshaller containerMarshaller, int containerSize, + Object container, ImmutableSerializationContext ctx, TagReader in) throws IOException { + if (containerMarshaller instanceof IterableElementContainerAdapter adapter) { + for (int i = 0; i < containerSize; i++) { + adapter.appendElement(container, readContainerElementWrapped(ctx, in)); } - } else if (containerMarshaller instanceof IndexedElementContainerAdapter) { - IndexedElementContainerAdapter adapter = (IndexedElementContainerAdapter) containerMarshaller; - for (int i = 0; i < containerSize; i++) { - Object e = readMessage(ctx, in, true); - adapter.setElement(container, i, e); + } else if (containerMarshaller instanceof IndexedElementContainerAdapter adapter) { + for (int i = 0; i < containerSize; i++) { + adapter.setElement(container, i, readContainerElementWrapped(ctx, in)); } } else { throw new IllegalStateException("Unknown container adapter kind : " + containerMarshaller.getJavaClass().getName()); } + } - return container; + private static void readContainerWithoutWrappedElements(BaseMarshaller containerMarshaller, int containerSize, + Object container, ImmutableSerializationContext ctx, TagReader in) throws IOException { + if (containerMarshaller instanceof IterableElementContainerAdapter adapter) { + for (int i = 0; i < containerSize; i++) { + adapter.appendElement(container, readMessage(ctx, in, true)); + } + } else if (containerMarshaller instanceof IndexedElementContainerAdapter adapter) { + for (int i = 0; i < containerSize; i++) { + adapter.setElement(container, i, readMessage(ctx, in, true)); + } + } else { + throw new IllegalStateException("Unknown container adapter kind : " + containerMarshaller.getJavaClass().getName()); + } + } + + private static E readContainerElementWrapped(ImmutableSerializationContext ctx, TagReader in) throws IOException { + ValueOrTag primitiveValue = tryReadPrimitive(in, true); + if (primitiveValue.hasValue()) { + return primitiveValue.getValue(); + } + assert primitiveValue.hasTag(); + var tag = primitiveValue.getTag(); + if (tag != (WRAPPED_MESSAGE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED)) { + throw new IllegalStateException("Unexpected tag : " + tag + " (Field number : " + + WireType.getTagFieldNumber(tag) + ", Wire type : " + WireType.getTagWireType(tag) + ")"); + } + var elementReader = TagReaderImpl.newInstance(ctx, in.readByteArray()); + return readMessage(ctx, elementReader, true); } /** @@ -691,4 +780,47 @@ public void write(WriteContext ctx, WrappedMessage wrappedMessage) throws IOExce writeMessage(ctx.getSerializationContext(), ctx.getWriter(), wrappedMessage.value, false); } }; + + private interface ValueOrTag { + default boolean hasValue() { + return false; + } + + default boolean hasTag() { + return false; + } + + default T getValue() { + throw new UnsupportedOperationException(); + } + + default int getTag() { + throw new UnsupportedOperationException(); + } + } + + private record Value(T value) implements ValueOrTag { + + @Override + public T getValue() { + return value; + } + + @Override + public boolean hasValue() { + return true; + } + } + + private record Tag(int tag) implements ValueOrTag { + @Override + public boolean hasTag() { + return true; + } + + @Override + public int getTag() { + return tag; + } + } } diff --git a/core/src/main/java/org/infinispan/protostream/config/Configuration.java b/core/src/main/java/org/infinispan/protostream/config/Configuration.java index bcc1c0376..3a04cd46c 100644 --- a/core/src/main/java/org/infinispan/protostream/config/Configuration.java +++ b/core/src/main/java/org/infinispan/protostream/config/Configuration.java @@ -43,6 +43,8 @@ public interface Configuration { */ int maxNestedMessageDepth(); + boolean wrapCollectionElements(); + AnnotationsConfig annotationsConfig(); interface AnnotationsConfig { @@ -82,6 +84,20 @@ interface Builder { Builder maxNestedMessageDepth(int maxNestedMessageDepth); + /** + * Wraps all the elements in a collection or array into a wrapped message. + *

+ * WARNING: enabling this option will change the binary format in an incompatible way. All readers/writers must + * have this option enabled or disabled in order to be able to parse the messages. Use with caution. + *

+ * This option is required to fix a bug (IPROTO-273) where collections (or arrays) of non-primitive classes are + * unable to be read. + * + * @param wrapCollectionElements {@code true} to enable wrap the elements, {@code false} otherwise. + * @return This instance. + */ + Builder wrapCollectionElements(boolean wrapCollectionElements); + AnnotationsConfig.Builder annotationsConfig(); Configuration build(); diff --git a/core/src/main/java/org/infinispan/protostream/config/impl/ConfigurationImpl.java b/core/src/main/java/org/infinispan/protostream/config/impl/ConfigurationImpl.java index 3f64fa2c1..e493cf70d 100644 --- a/core/src/main/java/org/infinispan/protostream/config/impl/ConfigurationImpl.java +++ b/core/src/main/java/org/infinispan/protostream/config/impl/ConfigurationImpl.java @@ -1,12 +1,12 @@ package org.infinispan.protostream.config.impl; -import java.util.HashMap; -import java.util.Map; - import org.infinispan.protostream.config.AnnotationConfiguration; import org.infinispan.protostream.config.Configuration; import org.infinispan.protostream.descriptors.AnnotationElement; +import java.util.HashMap; +import java.util.Map; + /** * @author anistor@redhat.com * @since 2.0 @@ -14,15 +14,15 @@ public final class ConfigurationImpl implements Configuration { private final boolean logOutOfSequenceReads; private final boolean logOutOfSequenceWrites; - private final boolean lenient; private final AnnotationsConfigImpl annotationsConfig; private final int maxNestedMessageDepth; + private final boolean wrapCollectionElements; private ConfigurationImpl(BuilderImpl builder, Map annotations) { this.logOutOfSequenceReads = builder.logOutOfSequenceReads; this.logOutOfSequenceWrites = builder.logOutOfSequenceWrites; - this.lenient = builder.lenient; this.maxNestedMessageDepth = builder.maxNestedMessageDepth; + this.wrapCollectionElements = builder.wrapCollectionElements; this.annotationsConfig = new AnnotationsConfigImpl(annotations, builder.logUndefinedAnnotations); } @@ -41,6 +41,11 @@ public int maxNestedMessageDepth() { return maxNestedMessageDepth; } + @Override + public boolean wrapCollectionElements() { + return wrapCollectionElements; + } + @Override public AnnotationsConfig annotationsConfig() { return annotationsConfig; @@ -85,10 +90,10 @@ public String toString() { public static final class BuilderImpl implements Builder { private boolean logOutOfSequenceReads = true; private boolean logOutOfSequenceWrites = true; - private boolean lenient = true; private int maxNestedMessageDepth = Configuration.DEFAULT_MAX_NESTED_DEPTH; private AnnotationsConfigBuilderImpl annotationsConfigBuilder = null; private Boolean logUndefinedAnnotations; + private boolean wrapCollectionElements; final class AnnotationsConfigBuilderImpl implements AnnotationsConfig.Builder { @@ -138,7 +143,6 @@ public Builder setLogOutOfSequenceWrites(boolean logOutOfSequenceWrites) { @Override public Builder setLenient(boolean lenient) { - this.lenient = lenient; return this; } @@ -148,6 +152,12 @@ public Builder maxNestedMessageDepth(int maxNestedMessageDepth) { return this; } + @Override + public Builder wrapCollectionElements(boolean wrapCollectionElements) { + this.wrapCollectionElements = wrapCollectionElements; + return this; + } + @Override public AnnotationsConfig.Builder annotationsConfig() { if (annotationsConfigBuilder == null) { diff --git a/types/pom.xml b/types/pom.xml index d071d0ead..099e6c565 100644 --- a/types/pom.xml +++ b/types/pom.xml @@ -59,6 +59,37 @@ junit test + + + org.apache.logging.log4j + log4j-api + test + + + + org.apache.logging.log4j + log4j-core + test + + + + org.apache.logging.log4j + log4j-slf4j-impl + test + + + + org.apache.logging.log4j + log4j-jul + test + + + + junit + junit + test + + diff --git a/types/src/test/java/org/infinispan/protostream/types/java/Book.java b/types/src/test/java/org/infinispan/protostream/types/java/Book.java new file mode 100644 index 000000000..77726aa3d --- /dev/null +++ b/types/src/test/java/org/infinispan/protostream/types/java/Book.java @@ -0,0 +1,60 @@ +package org.infinispan.protostream.types.java; + +import org.infinispan.protostream.annotations.ProtoFactory; +import org.infinispan.protostream.annotations.ProtoField; + +import java.util.Objects; + +public class Book implements Comparable { + + @ProtoField(number = 1) + String title; + + @ProtoField(number = 2) + String description; + + @ProtoField(number = 3, defaultValue = "2023") + int publicationYear; + + @ProtoFactory + public Book(String title, String description, int publicationYear) { + this.title = title; + this.description = description; + this.publicationYear = publicationYear; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Book book = (Book) o; + return publicationYear == book.publicationYear && Objects.equals(title, book.title) && Objects.equals(description, book.description); + } + + @Override + public int hashCode() { + return Objects.hash(title, description, publicationYear); + } + + @Override + public String toString() { + return "Book{" + + "title='" + title + '\'' + + ", description='" + description + '\'' + + ", publicationYear=" + publicationYear + + '}'; + } + + @Override + public int compareTo(Book o) { + int cmp = Integer.compare(this.publicationYear, o.publicationYear); + if (cmp != 0) { + return cmp; + } + cmp = title.compareTo(o.title); + if (cmp != 0) { + return cmp; + } + return description.compareTo(o.description); + } +} diff --git a/types/src/test/java/org/infinispan/protostream/types/java/BookSchema.java b/types/src/test/java/org/infinispan/protostream/types/java/BookSchema.java new file mode 100644 index 000000000..3b54bb813 --- /dev/null +++ b/types/src/test/java/org/infinispan/protostream/types/java/BookSchema.java @@ -0,0 +1,14 @@ +package org.infinispan.protostream.types.java; + +import org.infinispan.protostream.GeneratedSchema; +import org.infinispan.protostream.annotations.AutoProtoSchemaBuilder; + +@AutoProtoSchemaBuilder( + includeClasses = { + Book.class + }, + schemaFileName = "book.proto", + schemaFilePath = "proto/", + schemaPackageName = "library") +public interface BookSchema extends GeneratedSchema { +} diff --git a/types/src/test/java/org/infinispan/protostream/types/java/TypesMarshallingTest.java b/types/src/test/java/org/infinispan/protostream/types/java/TypesMarshallingTest.java index 1ba93edff..83fcb46a3 100644 --- a/types/src/test/java/org/infinispan/protostream/types/java/TypesMarshallingTest.java +++ b/types/src/test/java/org/infinispan/protostream/types/java/TypesMarshallingTest.java @@ -4,8 +4,8 @@ import org.infinispan.protostream.ImmutableSerializationContext; import org.infinispan.protostream.ProtobufUtil; import org.infinispan.protostream.SerializationContext; +import org.infinispan.protostream.config.Configuration; import org.infinispan.protostream.impl.Log; -import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -17,62 +17,126 @@ import java.lang.invoke.MethodHandles; import java.math.BigDecimal; import java.math.BigInteger; -import java.util.Arrays; -import java.util.BitSet; -import java.util.UUID; +import java.util.*; import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeTrue; @RunWith(Parameterized.class) public class TypesMarshallingTest { private static final Log log = Log.LogFactory.getLog(MethodHandles.lookup().lookupClass()); - private final MarshallingMethod method; + private final TestConfiguration testConfiguration; private final ImmutableSerializationContext context; - public TypesMarshallingTest(MarshallingMethod method) { - this.method = method; - context = newContext(); + public TypesMarshallingTest(TestConfiguration testConfiguration) { + this.testConfiguration = testConfiguration; + context = newContext(true); } @Parameterized.Parameters public static Object[][] marshallingMethods() { return Arrays.stream(MarshallingMethodType.values()) + .flatMap(t -> switch (t) { + case BYTE_ARRAY, INPUT_STREAM, JSON -> Stream.of(new TestConfiguration(t, false, false, null)); + default -> Stream.of( + new TestConfiguration(t, true, true, null), + new TestConfiguration(t, true, false, ArrayList::new), + new TestConfiguration(t, true, false, HashSet::new), + new TestConfiguration(t, true, false, LinkedHashSet::new), + new TestConfiguration(t, true, false, LinkedList::new), + new TestConfiguration(t, true, false, TreeSet::new)); + }) .map(t -> new Object[]{t}) .toArray(Object[][]::new); } @Test public void testUUID() throws IOException { - method.marshallAndUnmarshallTest(UUID.randomUUID(), context); + testConfiguration.method.marshallAndUnmarshallTest(UUID.randomUUID(), context, false); } @Test public void testBitSet() throws IOException { var bytes = new byte[ThreadLocalRandom.current().nextInt(64)]; ThreadLocalRandom.current().nextBytes(bytes); - method.marshallAndUnmarshallTest(BitSet.valueOf(bytes), context); + testConfiguration.method.marshallAndUnmarshallTest(BitSet.valueOf(bytes), context, false); } @Test public void testBigDecimal() throws IOException { - method.marshallAndUnmarshallTest(BigDecimal.valueOf(ThreadLocalRandom.current().nextDouble(-256, 256)), context); + testConfiguration.method.marshallAndUnmarshallTest(BigDecimal.valueOf(ThreadLocalRandom.current().nextDouble(-256, 256)), context, false); } @Test public void testBigInteger() throws IOException { - method.marshallAndUnmarshallTest(BigInteger.valueOf(ThreadLocalRandom.current().nextInt()), context); + testConfiguration.method.marshallAndUnmarshallTest(BigInteger.valueOf(ThreadLocalRandom.current().nextInt()), context, false); + } + + @Test + public void testContainerWithString() throws IOException { + assumeTrue(testConfiguration.runTest); + if (testConfiguration.isArray) { + testConfiguration.method.marshallAndUnmarshallTest(stringArray(), context, true); + } else { + testConfiguration.method.marshallAndUnmarshallTest(stringCollection(testConfiguration.collectionBuilder), context, false); + } + } + + @Test + public void testContainerWithBooks() throws IOException { + assumeTrue(testConfiguration.runTest); + if (testConfiguration.isArray) { + testConfiguration.method.marshallAndUnmarshallTest(bookArray(), context, true); + } else { + testConfiguration.method.marshallAndUnmarshallTest(bookCollection(testConfiguration.collectionBuilder), context, false); + } + } + + @Test + public void testPrimitiveCollectionCompatibility() throws IOException { + assumeTrue(testConfiguration.method == MarshallingMethodType.WRAPPED_MESSAGE); + var list = new ArrayList<>(List.of("a1", "a2", "a3")); + + var oldCtx = newContext(false); + + // send with oldCtx: simulates previous version + var data = ProtobufUtil.toWrappedByteArray(oldCtx, list, 512); + // read with newCtx: simulates current version + var listCopy = ProtobufUtil.fromWrappedByteArray(context, data); + + assertEquals(list, listCopy); + + // other way around + // send with newCtx: simulates current version + data = ProtobufUtil.toWrappedByteArray(oldCtx, list, 512); + // read with oldCtx: simulates previous version + listCopy = ProtobufUtil.fromWrappedByteArray(context, data); + + assertEquals(list, listCopy); } @FunctionalInterface public interface MarshallingMethod { - void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx) throws IOException; + void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx, boolean isArray) throws IOException; } - private static ImmutableSerializationContext newContext() { - var ctx = ProtobufUtil.newSerializationContext(); + public record TestConfiguration(MarshallingMethod method, boolean runTest, boolean isArray, + Supplier> collectionBuilder) { + + } + + private static ImmutableSerializationContext newContext(boolean wrapCollectionElements) { + var config = Configuration.builder().wrapCollectionElements(wrapCollectionElements).build(); + var ctx = ProtobufUtil.newSerializationContext(config); register(new CommonTypesSchema(), ctx); register(new CommonContainerTypesSchema(), ctx); + register(new BookSchemaImpl(), ctx); return ctx; } @@ -81,40 +145,81 @@ private static void register(GeneratedSchema schema, SerializationContext contex schema.registerSchema(context); } + private static Collection stringCollection(Supplier> supplier) { + var collection = supplier.get(); + collection.add("a"); + collection.add("b"); + collection.add("c"); + return collection; + } + + private static Collection bookCollection(Supplier> supplier) { + var collection = supplier.get(); + collection.add(new Book("Book1", "Description1", 2020)); + collection.add(new Book("Book2", "Description2", 2021)); + collection.add(new Book("Book3", "Description3", 2022)); + return collection; + } + + private static String[] stringArray() { + return new String[]{"a", "b", "c"}; + } + + private static Object[] bookArray() { + // cannot use new Book[] because there is no marshaller for it. + return new Object[]{ + new Book("Book1", "Description1", 2020), + new Book("Book2", "Description2", 2021), + new Book("Book3", "Description3", 2022) + }; + } + enum MarshallingMethodType implements MarshallingMethod { WRAPPED_MESSAGE { @Override - public void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx) throws IOException { + public void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx, boolean isArray) throws IOException { var bytes = ProtobufUtil.toWrappedByteArray(ctx, original, 512); var copy = ProtobufUtil.fromWrappedByteArray(ctx, bytes); log.debugf("Wrapped Message: bytes length=%s, original=%s, copy=%s", bytes.length, original, copy); - Assert.assertEquals(original, copy); + if (isArray) { + assertArrayEquals((Object[]) original, (Object[]) copy); + } else { + assertEquals(original, copy); + } } }, INPUT_STREAM { @Override - public void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx) throws IOException { + public void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx, boolean isArray) throws IOException { var baos = new ByteArrayOutputStream(512); ProtobufUtil.writeTo(ctx, baos, original); var bais = new ByteArrayInputStream(baos.toByteArray()); var copy = ProtobufUtil.readFrom(ctx, bais, original.getClass()); log.debugf("Input Stream: bytes length=%s, original=%s, copy=%s", baos.size(), original, copy); - Assert.assertEquals(original, copy); + if (isArray) { + assertArrayEquals((Object[]) original, (Object[]) copy); + } else { + assertEquals(original, copy); + } } }, BYTE_ARRAY { @Override - public void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx) throws IOException { + public void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx, boolean isArray) throws IOException { var baos = new ByteArrayOutputStream(512); ProtobufUtil.writeTo(ctx, baos, original); var copy = ProtobufUtil.fromByteArray(ctx, baos.toByteArray(), original.getClass()); log.debugf("Byte Array: bytes length=%s, original=%s, copy=%s", baos.size(), original, copy); - Assert.assertEquals(original, copy); + if (isArray) { + assertArrayEquals((Object[]) original, (Object[]) copy); + } else { + assertEquals(original, copy); + } } }, JSON { @Override - public void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx) throws IOException { + public void marshallAndUnmarshallTest(Object original, ImmutableSerializationContext ctx, boolean isArray) throws IOException { var bytes = ProtobufUtil.toWrappedByteArray(ctx, original, 512); var json = ProtobufUtil.toCanonicalJSON(ctx, bytes); @@ -123,7 +228,11 @@ public void marshallAndUnmarshallTest(Object original, ImmutableSerializationCon var copy = ProtobufUtil.fromWrappedByteArray(ctx, jsonBytes); log.debugf("JSON: JSON bytes length=%s, JSON String=%s, original=%s, copy=%s", jsonBytes.length, json, original, copy); - Assert.assertEquals(original, copy); + if (isArray) { + assertArrayEquals((Object[]) original, (Object[]) copy); + } else { + assertEquals(original, copy); + } } } } diff --git a/types/src/test/resources/log4j2.xml b/types/src/test/resources/log4j2.xml new file mode 100644 index 000000000..ff6a83379 --- /dev/null +++ b/types/src/test/resources/log4j2.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + +