diff --git a/.scalafmt.conf b/.scalafmt.conf index 6d6fd4e2c..971a38a84 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,3 +1,3 @@ -version=2.5.2 +version=3.7.14 runner.dialect=scala212source3 diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java index 32a5a0686..4039283aa 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java @@ -37,8 +37,20 @@ public AllInputsTask() { JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class)); } + @AutoValue + public abstract static class Nested { + public abstract String hello(); + + public abstract String world(); + + public static Nested create(String hello, String world) { + return new AutoValue_AllInputsTask_Nested(hello, world); + } + } + @AutoValue public abstract static class AutoAllInputsInput { + public abstract SdkBindingData i(); public abstract SdkBindingData f(); @@ -54,6 +66,8 @@ public abstract static class AutoAllInputsInput { @BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART) public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> m(); @@ -70,12 +84,13 @@ public static AutoAllInputsInput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsTask_AutoAllInputsInput( - i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } @@ -97,6 +112,8 @@ public abstract static class AutoAllInputsOutput { @BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART) public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> m(); @@ -113,12 +130,13 @@ public static AutoAllInputsOutput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsTask_AutoAllInputsOutput( - i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } @@ -132,6 +150,7 @@ public AutoAllInputsOutput run(AutoAllInputsInput input) { input.t(), input.d(), input.blob(), + input.generic(), input.l(), input.m(), input.emptyList(), diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java index 30dfe0793..1183f3ce3 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java @@ -29,6 +29,7 @@ import org.flyte.api.v1.BlobType; import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.examples.AllInputsTask.AutoAllInputsOutput; +import org.flyte.examples.AllInputsTask.Nested; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkNode; @@ -36,6 +37,7 @@ import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.BlobTypeDescription; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) @@ -74,6 +76,8 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) .build()) .build()) .build()), + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")), SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")), SdkBindingDataFactory.ofStringMap(Map.of("test", "test")), SdkBindingDataFactory.ofStringCollection(Collections.emptyList()), @@ -89,6 +93,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) outputs.t(), outputs.d(), outputs.blob(), + outputs.generic(), outputs.l(), outputs.m(), outputs.emptyList(), @@ -113,6 +118,8 @@ public abstract static class AllInputsWorkflowOutput { @BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART) public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> m(); @@ -129,12 +136,13 @@ public static AllInputsWorkflow.AllInputsWorkflowOutput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsWorkflow_AllInputsWorkflowOutput( - i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java index 0be5ba34f..969fa64bd 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java @@ -33,6 +33,7 @@ import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkLiteralType; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; /** * Implementation of {@link org.flyte.flytekit.SdkLiteralType} for {@link @@ -102,7 +103,8 @@ public Literal toLiteral(T value) { var tree = OBJECT_MAPPER.valueToTree(value); try { - return OBJECT_MAPPER.treeToValue(tree, Literal.class); + return Literal.ofScalar( + Scalar.ofGeneric(OBJECT_MAPPER.treeToValue(tree, StructWrapper.class).unwrap())); } catch (IOException e) { throw new UncheckedIOException("toLiteral failed for [" + clazz.getName() + "]: " + value, e); } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java index 861a1c640..4ec2d158d 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java @@ -20,8 +20,8 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.module.SimpleDeserializers; import com.fasterxml.jackson.databind.module.SimpleSerializers; -import org.flyte.api.v1.Literal; -import org.flyte.flytekit.jackson.deserializers.LiteralStructDeserializer; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; import org.flyte.flytekit.jackson.serializers.StructSerializer; class SdkLiteralTypeModule extends Module { @@ -43,7 +43,7 @@ public void setupModule(SetupContext context) { context.addSerializers(serializers); var deserializers = new SimpleDeserializers(); - deserializers.addDeserializer(Literal.class, new LiteralStructDeserializer()); + deserializers.addDeserializer(StructWrapper.class, new StructDeserializer()); context.addDeserializers(deserializers); // append with the lowest priority to use as fallback, if builtin annotations aren't present diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java index 17f71c25a..aa25ff45e 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java @@ -20,7 +20,6 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.deser.Deserializers; import com.fasterxml.jackson.databind.module.SimpleSerializers; -import org.flyte.flytekit.jackson.deserializers.LiteralMapDeserializers; import org.flyte.flytekit.jackson.deserializers.SdkBindingDataDeserializers; import org.flyte.flytekit.jackson.serializers.BindingMapSerializers; import org.flyte.flytekit.jackson.serializers.LiteralMapSerializers; @@ -60,7 +59,6 @@ public void setupModule(SetupContext context) { context.addSerializers(serializers); context.addSerializers(new LiteralMapSerializers()); - context.addDeserializers(new LiteralMapDeserializers()); context.addSerializers(new BindingMapSerializers()); context.addDeserializers(sdkbindingDeserializers); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index c100b81c3..b97366859 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -181,9 +181,12 @@ private SdkLiteralType toLiteralType( .dimensionality(annotation.dimensionality()) .build()); } - // TODO: Support structs - throw new UnsupportedOperationException( - String.format("Unsupported type: [%s]", type.getName())); + try { + return JacksonSdkLiteralType.of(type); + } catch (Exception e) { + throw new UnsupportedOperationException( + String.format("Unsupported type: [%s]", type.getName()), e); + } } private static boolean isPrimitiveAssignableFrom(Class fromClass, Class toClass) { diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java deleted file mode 100644 index f3015c3da..000000000 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2020-2023 Flyte Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.flyte.flytekit.jackson.deserializers; - -import com.fasterxml.jackson.databind.BeanDescription; -import com.fasterxml.jackson.databind.DeserializationConfig; -import com.fasterxml.jackson.databind.JavaType; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.deser.Deserializers; -import java.util.Map; -import org.flyte.api.v1.LiteralType; -import org.flyte.flytekit.jackson.JacksonLiteralMap; - -public class LiteralMapDeserializers extends Deserializers.Base { - - @Override - public JsonDeserializer findBeanDeserializer( - JavaType type, DeserializationConfig config, BeanDescription beanDesc) { - if (type.getRawClass().equals(JacksonLiteralMap.class)) { - Map literalTypeMap = type.getValueHandler(); - - return new LiteralMapDeserializer(literalTypeMap); - } - - return null; - } -} diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java index 39f7e1034..bbaca166c 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java @@ -24,8 +24,12 @@ import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.BeanProperty; import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.deser.ContextualDeserializer; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; import java.io.IOException; import java.time.Duration; @@ -46,36 +50,48 @@ import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; +import org.flyte.api.v1.Scalar.Kind; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkLiteralType; import org.flyte.flytekit.SdkLiteralTypes; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; -class SdkBindingDataDeserializer extends StdDeserializer> { +class SdkBindingDataDeserializer extends StdDeserializer> + implements ContextualDeserializer { private static final long serialVersionUID = 0L; + private final JavaType type; + public SdkBindingDataDeserializer() { + this(null); + } + + private SdkBindingDataDeserializer(JavaType type) { super(SdkBindingData.class); + + this.type = type; } @Override public SdkBindingData deserialize( JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { JsonNode tree = jsonParser.readValueAsTree(); - return transform(tree); + return transform(tree, deserializationContext); } - private SdkBindingData transform(JsonNode tree) { + private SdkBindingData transform( + JsonNode tree, DeserializationContext deserializationContext) { Literal.Kind literalKind = Literal.Kind.valueOf(tree.get(LITERAL).asText()); switch (literalKind) { case SCALAR: - return transformScalar(tree); + return transformScalar(tree, deserializationContext); case COLLECTION: - return transformCollection(tree); + return transformCollection(tree, deserializationContext); case MAP: - return transformMap(tree); + return transformMap(tree, deserializationContext); default: throw new UnsupportedOperationException( @@ -83,7 +99,8 @@ private SdkBindingData transform(JsonNode tree) { } } - private static SdkBindingData transformScalar(JsonNode tree) { + private SdkBindingData transformScalar( + JsonNode tree, DeserializationContext deserializationContext) { Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText()); switch (scalarKind) { case PRIMITIVE: @@ -109,6 +126,8 @@ private static SdkBindingData transformScalar(JsonNode tree) { return transformBlob(tree); case GENERIC: + return transformGeneric(tree, deserializationContext, scalarKind); + default: throw new UnsupportedOperationException( "Type contains an unsupported scalar: " + scalarKind); @@ -132,8 +151,28 @@ private static SdkBindingData transformBlob(JsonNode tree) { .build()); } + private SdkBindingData transformGeneric( + JsonNode tree, DeserializationContext deserializationContext, Kind scalarKind) { + JsonParser jsonParser = tree.get(VALUE).traverse(); + try { + jsonParser.nextToken(); + Object object = + deserializationContext + .findNonContextualValueDeserializer(type) + .deserialize(jsonParser, deserializationContext); + @SuppressWarnings("unchecked") + SdkLiteralType jacksonSdkLiteralType = + (SdkLiteralType) JacksonSdkLiteralType.of(type.getRawClass()); + return SdkBindingData.literal(jacksonSdkLiteralType, object); + } catch (IOException e) { + throw new UnsupportedOperationException( + "Type contains an unsupported generic: " + scalarKind, e); + } + } + @SuppressWarnings("unchecked") - private SdkBindingData> transformCollection(JsonNode tree) { + private SdkBindingData> transformCollection( + JsonNode tree, DeserializationContext deserializationContext) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); Iterator elements = tree.get(VALUE).elements(); @@ -143,7 +182,10 @@ private SdkBindingData> transformCollection(JsonNode tree) { case COLLECTION_TYPE: List collection = (List) - streamOf(elements).map(this::transform).map(SdkBindingData::get).collect(toList()); + streamOf(elements) + .map((JsonNode tree1) -> transform(tree1, deserializationContext)) + .map(SdkBindingData::get) + .collect(toList()); return SdkBindingDataFactory.of(literalType, collection); case SCHEMA_TYPE: @@ -155,7 +197,8 @@ private SdkBindingData> transformCollection(JsonNode tree) { } @SuppressWarnings("unchecked") - private SdkBindingData> transformMap(JsonNode tree) { + private SdkBindingData> transformMap( + JsonNode tree, DeserializationContext deserializationContext) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); JsonNode valueNode = tree.get(VALUE); List> entries = @@ -168,7 +211,11 @@ private SdkBindingData> transformMap(JsonNode tree) { case COLLECTION_TYPE: Map bindingDataMap = entries.stream() - .map(entry -> Map.entry(entry.getKey(), (T) transform(entry.getValue()).get())) + .map( + entry -> + Map.entry( + entry.getKey(), + (T) transform(entry.getValue(), deserializationContext).get())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); return SdkBindingDataFactory.of(literalType, bindingDataMap); @@ -220,4 +267,9 @@ private Stream streamOf(Iterator nodes) { return StreamSupport.stream( Spliterators.spliteratorUnknownSize(nodes, Spliterator.ORDERED), false); } + + @Override + public JsonDeserializer createContextual(DeserializationContext ctxt, BeanProperty property) { + return new SdkBindingDataDeserializer(property.getType().containedType(0)); + } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java similarity index 70% rename from flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java rename to flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java index 0c17f55d5..88f673f80 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java @@ -29,23 +29,35 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import org.flyte.api.v1.Literal; -import org.flyte.api.v1.Scalar; import org.flyte.api.v1.Struct; import org.flyte.api.v1.Struct.Value; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; -public class LiteralStructDeserializer extends StdDeserializer { +public class StructDeserializer extends StdDeserializer { private static final long serialVersionUID = -6835948754469626304L; - public LiteralStructDeserializer() { - super(Literal.class); + // we cannot use Struct directly because it is an auto-value class so this deserializer will not + // be used by Jackson + public static class StructWrapper { + + private final Struct struct; + + public StructWrapper(Struct struct) { + this.struct = struct; + } + + public Struct unwrap() { + return struct; + } } - @Override - public Literal deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + public StructDeserializer() { + super(StructWrapper.class); + } - Struct generic = readValueAsStruct(p); - return Literal.ofScalar(Scalar.ofGeneric(generic)); + @Override + public StructWrapper deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + return new StructWrapper(readValueAsStruct(p)); } private static Struct readValueAsStruct(JsonParser p) throws IOException { @@ -67,7 +79,7 @@ private static Struct readValueAsStruct(JsonParser p) throws IOException { return Struct.of(unmodifiableMap(fields)); } - private static Struct.Value readValueAsStructValue(JsonParser p) throws IOException { + private static Value readValueAsStructValue(JsonParser p) throws IOException { switch (p.currentToken()) { case START_ARRAY: p.nextToken(); @@ -75,38 +87,38 @@ private static Struct.Value readValueAsStructValue(JsonParser p) throws IOExcept List valuesList = new ArrayList<>(); while (p.currentToken() != JsonToken.END_ARRAY) { - Struct.Value value = readValueAsStructValue(p); + Value value = readValueAsStructValue(p); p.nextToken(); valuesList.add(value); } - return Struct.Value.ofListValue(unmodifiableList(valuesList)); + return Value.ofListValue(unmodifiableList(valuesList)); case START_OBJECT: Struct struct = readValueAsStruct(p); - return Struct.Value.ofStructValue(struct); + return Value.ofStructValue(struct); case VALUE_STRING: String stringValue = p.readValueAs(String.class); - return Struct.Value.ofStringValue(stringValue); + return Value.ofStringValue(stringValue); case VALUE_NUMBER_FLOAT: case VALUE_NUMBER_INT: Double doubleValue = p.readValueAs(Double.class); - return Struct.Value.ofNumberValue(doubleValue); + return Value.ofNumberValue(doubleValue); case VALUE_NULL: - return Struct.Value.ofNullValue(); + return Value.ofNullValue(); case VALUE_FALSE: - return Struct.Value.ofBoolValue(false); + return Value.ofBoolValue(false); case VALUE_TRUE: - return Struct.Value.ofBoolValue(true); + return Value.ofBoolValue(true); case FIELD_NAME: case NOT_AVAILABLE: diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java index 12ec69e18..5c73535c7 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java @@ -16,20 +16,15 @@ */ package org.flyte.flytekit.jackson.serializers; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.LITERAL; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.STRUCT_TYPE; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.STRUCT_VALUE; +import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.SerializerProvider; import java.io.IOException; -import java.util.Map; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; -import org.flyte.api.v1.Struct; public class GenericSerializer extends ScalarSerializer { public GenericSerializer( @@ -48,85 +43,7 @@ public GenericSerializer( @Override public void serializeScalar() throws IOException { gen.writeObject(Scalar.Kind.GENERIC); - for (Map.Entry entry : value.scalar().generic().fields().entrySet()) { - gen.writeFieldName(entry.getKey()); - serializeStructValue(entry.getValue()); - } - } - - private void serializeStructValue(Struct.Value value) throws IOException { - if (!value.kind().equals(Struct.Value.Kind.LIST_VALUE) - && !value.kind().equals(Struct.Value.Kind.NULL_VALUE)) { - gen.writeStartObject(); - gen.writeFieldName(LITERAL); - gen.writeObject(Literal.Kind.SCALAR); - gen.writeFieldName(SCALAR); - gen.writeObject(Scalar.Kind.GENERIC); - } - - if (isSimpleType(value.kind())) { - gen.writeFieldName(STRUCT_TYPE); - } - switch (value.kind()) { - case BOOL_VALUE: - writeSimpleType( - Struct.Value.Kind.BOOL_VALUE, - value, - (generator, v) -> generator.writeBoolean(v.boolValue())); - return; - - case LIST_VALUE: - throw new RuntimeException("not supported list inside the struct"); - - case NUMBER_VALUE: - writeSimpleType( - Struct.Value.Kind.NUMBER_VALUE, - value, - (generator, v) -> generator.writeNumber(v.numberValue())); - return; - - case STRING_VALUE: - writeSimpleType( - Struct.Value.Kind.STRING_VALUE, - value, - (generator, v) -> generator.writeString(v.stringValue())); - return; - - case STRUCT_VALUE: - value.structValue().fields().forEach((k, v) -> writeStructValue(gen, k, v)); - gen.writeEndObject(); - return; - - case NULL_VALUE: - gen.writeNull(); - } - } - - private void writeStructValue(JsonGenerator gen, String k, Struct.Value v) { - try { - gen.writeFieldName(k); - serializeStructValue(v); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private boolean isSimpleType(Struct.Value.Kind kind) { - return kind.equals(Struct.Value.Kind.BOOL_VALUE) - || kind.equals(Struct.Value.Kind.NUMBER_VALUE) - || kind.equals(Struct.Value.Kind.STRING_VALUE); - } - - private void writeSimpleType( - Struct.Value.Kind kind, Struct.Value structValue, WriteGenericFunction writeTypeFunction) - throws IOException { - gen.writeObject(kind); - gen.writeFieldName(STRUCT_VALUE); - writeTypeFunction.write(gen, structValue); - gen.writeEndObject(); - } - - interface WriteGenericFunction { - void write(JsonGenerator gen, Struct.Value value) throws IOException; + gen.writeFieldName(VALUE); + new StructSerializer().serialize(value.scalar().generic(), gen, serializerProvider); } } diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index 38fd83639..91eace169 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -47,6 +47,8 @@ import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; +import org.flyte.api.v1.Struct; +import org.flyte.api.v1.Struct.Value; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; @@ -68,6 +70,7 @@ public static AutoValueInput createAutoValueInput( Instant t, Duration d, Blob blob, + Nested generic, List l, Map m, List> ll, @@ -82,6 +85,7 @@ public static AutoValueInput createAutoValueInput( SdkBindingDataFactory.of(t), SdkBindingDataFactory.of(d), SdkBindingDataFactory.of(blob), + SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), generic), SdkBindingDataFactory.ofStringCollection(l), SdkBindingDataFactory.ofStringMap(m), SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ll), @@ -103,6 +107,7 @@ public void testVariableMap() { hasEntry("t", createVar(SimpleType.DATETIME)), hasEntry("d", createVar(SimpleType.DURATION)), hasEntry("blob", createVar(LiteralType.ofBlobType(BLOB_TYPE))), + hasEntry("generic", createVar(LiteralType.ofSimpleType(SimpleType.STRUCT))), hasEntry( "l", createVar(LiteralType.ofCollectionType(ofSimpleType(SimpleType.STRING)))), hasEntry( @@ -136,6 +141,7 @@ void testFromLiteralMap() { literalMap.put("t", literalOf(Primitive.ofDatetime(datetime))); literalMap.put("d", literalOf(Primitive.ofDuration(duration))); literalMap.put("blob", literalOf(blob)); + literalMap.put("generic", literalOf(Struct.of(Map.of("hello", Value.ofStringValue("hello"))))); literalMap.put("l", Literal.ofCollection(List.of(literalOf(Primitive.ofStringValue("123"))))); literalMap.put("m", Literal.ofMap(Map.of("marco", literalOf(Primitive.ofStringValue("polo"))))); literalMap.put( @@ -180,6 +186,7 @@ void testFromLiteralMap() { /* t= */ datetime, /* d= */ duration, /* blob= */ blob, + Nested.create("hello"), /* l= */ List.of("123"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -214,6 +221,7 @@ void testToLiteralMap() { /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), /* blob= */ blob, + Nested.create("hello"), /* l= */ List.of("foo"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -295,6 +303,7 @@ public void testToSdkBindingDataMap() { /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), /* blob= */ blob, + Nested.create("hello"), /* l= */ List.of("foo"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -314,6 +323,7 @@ public void testToSdkBindingDataMap() { expected.put("t", input.t()); expected.put("d", input.d()); expected.put("blob", input.blob()); + expected.put("generic", input.generic()); expected.put("l", input.l()); expected.put("m", input.m()); expected.put("ll", input.ll()); @@ -529,6 +539,15 @@ public static AutoValueDeprecatedInput create(long i) { } } + @AutoValue + public abstract static class Nested { + public abstract String hello(); + + public static AutoValue_JacksonSdkTypeTest_Nested create(String hello) { + return new AutoValue_JacksonSdkTypeTest_Nested(hello); + } + } + @AutoValue public abstract static class AutoValueInput { @@ -548,6 +567,8 @@ public abstract static class AutoValueInput { @BlobTypeDescription(format = "", dimensionality = BlobDimensionality.SINGLE) public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> m(); @@ -568,6 +589,7 @@ public static AutoValueInput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData>> ll, @@ -575,7 +597,7 @@ public static AutoValueInput create( SdkBindingData>> ml, SdkBindingData>> mm) { return new AutoValue_JacksonSdkTypeTest_AutoValueInput( - i, f, s, b, t, d, blob, l, m, ll, lm, ml, mm); + i, f, s, b, t, d, blob, generic, l, m, ll, lm, ml, mm); } } @@ -714,4 +736,8 @@ private static Literal literalOf(Primitive primitive) { private static Literal literalOf(Blob blob) { return Literal.ofScalar(Scalar.ofBlob(blob)); } + + private static Literal literalOf(Struct generic) { + return Literal.ofScalar(Scalar.ofGeneric(generic)); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java index ac183ad8d..aa7217ec8 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java @@ -124,6 +124,17 @@ public static SdkBindingData> of(SdkLiteralType elementType, List return SdkBindingData.literal(collections(elementType), collection); } + /** + * Creates a {@code SdkBindingData} for a flyte type with the given value. + * + * @param type the flyte type + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(SdkLiteralType type, T value) { + return SdkBindingData.literal(type, value); + } + /** * Creates a {@code SdkBindingData} for a flyte Blob with the given value. * diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index 1616a78a4..a3cbaf863 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -38,7 +38,8 @@ import org.flyte.flytekit.{ import org.flyte.flytekitscala.SdkBindingDataFactory import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} import org.junit.jupiter.api.Test -import org.flyte.examples.AllInputsTask.AutoAllInputsInput +import org.flyte.examples.AllInputsTask.{AutoAllInputsInput, Nested} +import org.flyte.flytekit.jackson.JacksonSdkLiteralType import org.flyte.flytekitscala.SdkLiteralTypes.{collections, maps, strings} class SdkScalaTypeTest { @@ -409,6 +410,10 @@ class SdkScalaTypeTest { SdkJavaBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), SdkJavaBindingDataFactory.of(Duration.ZERO), SdkJavaBindingDataFactory.of(blob), + SdkJavaBindingDataFactory.of( + JacksonSdkLiteralType.of(classOf[Nested]), + Nested.create("hello", "world") + ), SdkJavaBindingDataFactory.ofStringCollection(List("1", "2", "3").asJava), SdkJavaBindingDataFactory.ofStringMap(Map("a" -> "2", "b" -> "3").asJava), SdkJavaBindingDataFactory.ofStringCollection(List.empty[String].asJava), @@ -425,6 +430,7 @@ class SdkScalaTypeTest { instant: SdkBindingData[Instant], duration: SdkBindingData[Duration], blob: SdkBindingData[Blob], + generic: SdkBindingData[Nested], list: SdkBindingData[List[String]], map: SdkBindingData[Map[String, String]], emptyList: SdkBindingData[List[String]], @@ -439,6 +445,7 @@ class SdkScalaTypeTest { input.t(), input.d(), input.blob(), + input.generic(), toScalaList(input.l()), toScalaMap(input.m()), toScalaList(input.emptyList()), @@ -453,6 +460,10 @@ class SdkScalaTypeTest { SdkBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), SdkBindingDataFactory.of(Duration.ZERO), SdkBindingDataFactory.of(blob), + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(classOf[Nested]), + Nested.create("hello", "world") + ), SdkBindingDataFactory.of(List("1", "2", "3")), SdkBindingDataFactory.of(Map("a" -> "2", "b" -> "3")), SdkBindingDataFactory.ofStringCollection(List.empty[String]), diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala index ad134a296..857238ee4 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala @@ -145,6 +145,18 @@ object SdkBindingDataFactory { def of(value: Blob): SdkBindingData[Blob] = SdkBindingData.literal(SdkLiteralTypes.blobs(value.metadata.`type`), value) + /** Creates a [[SdkBindingData]] for a flyte type with the given value. + * + * @param type + * the flyte type + * @param value + * the simple value for this data + * @return + * the new [[SdkBindingData]] + */ + def of[T](`type`: SdkLiteralType[T], value: T): SdkBindingData[T] = + SdkBindingData.literal(`type`, value) + /** Creates a [[SdkBindingDataFactory]] for a flyte string collection given a * scala [[List]]. * diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java index 0e78c9774..23fee27ad 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java @@ -16,12 +16,15 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; import org.flyte.flytekit.jackson.JacksonSdkType; -// @AutoService(SdkRunnableTask.class) +@AutoService(SdkRunnableTask.class) public class BuildBqReference extends SdkRunnableTask { private static final long serialVersionUID = -489898361071672070L; @@ -35,7 +38,10 @@ public BuildBqReference() { @Override public Output run(Input input) { return Output.create( - BQReference.create(input.project().get(), input.dataset().get(), input.tableName().get())); + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(BQReference.class), + BQReference.create( + input.project().get(), input.dataset().get(), input.tableName().get()))); } @AutoValue @@ -58,11 +64,8 @@ public static Input create( public abstract static class Output { abstract SdkBindingData ref(); - public static Output create(BQReference ref) { - // TODO We need a way to generate SdkBindings of generic autovalues like BQReference - // that would be mapped to sdkStructs. JacksonSdkType of nested autovalues are mapped as - // structs - return null; + public static Output create(SdkBindingData ref) { + return new AutoValue_BuildBqReference_Output(ref); } } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java index 9e82df7ca..2389f4434 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java @@ -16,13 +16,14 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; -// @AutoService(SdkRunnableTask.class) +@AutoService(SdkRunnableTask.class) public class MockLookupBqTask extends SdkRunnableTask { private static final long serialVersionUID = 604843235716487166L; @@ -39,7 +40,7 @@ public abstract static class Input { public static Input create( SdkBindingData ref, SdkBindingData checkIfExists) { - return null; // TODO + return new AutoValue_MockLookupBqTask_Input(ref, checkIfExists); } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java index d1b565e77..5a5c6ccca 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java @@ -16,6 +16,7 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; @@ -23,10 +24,7 @@ import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; -// This workflow relays on SdkBinding that should be serialized -// as Struct. By going to typed inputs and outputs, we have de-scoped the support -// of structs. -// @AutoService(SdkWorkflow.class) +@AutoService(SdkWorkflow.class) public class MockPipelineWorkflow extends SdkWorkflow { public MockPipelineWorkflow() { diff --git a/integration-tests/src/test/java/org/flyte/AdditionalIT.java b/integration-tests/src/test/java/org/flyte/AdditionalIT.java index 3c9914312..00e50c27a 100644 --- a/integration-tests/src/test/java/org/flyte/AdditionalIT.java +++ b/integration-tests/src/test/java/org/flyte/AdditionalIT.java @@ -23,7 +23,6 @@ import flyteidl.core.Literals; import org.flyte.utils.Literal; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -67,7 +66,6 @@ public void testBranchNodeWorkflow(long a, long b, long c, long d, String expect "table-exists,true", "non-existent,false", }) - @Disabled("Not supporting struct with the strongly typed implementation.") public void testStructs(String name, boolean expected) { Literals.LiteralMap output = CLIENT.createExecution(