diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java index da50d076b..32a5a0686 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java @@ -22,8 +22,11 @@ import java.time.Instant; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; +import org.flyte.flytekit.jackson.BlobTypeDescription; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkRunnableTask.class) @@ -48,8 +51,8 @@ public abstract static class AutoAllInputsInput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + @BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART) + public abstract SdkBindingData blob(); public abstract SdkBindingData> l(); @@ -66,13 +69,13 @@ public static AutoAllInputsInput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, - // Blob blob, + SdkBindingData blob, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsTask_AutoAllInputsInput( - i, f, s, b, t, d, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); } } @@ -91,8 +94,8 @@ public abstract static class AutoAllInputsOutput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + @BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART) + public abstract SdkBindingData blob(); public abstract SdkBindingData> l(); @@ -109,12 +112,13 @@ public static AutoAllInputsOutput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, + SdkBindingData blob, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsTask_AutoAllInputsOutput( - i, f, s, b, t, d, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); } } @@ -127,6 +131,7 @@ public AutoAllInputsOutput run(AutoAllInputsInput input) { input.b(), input.t(), input.d(), + input.blob(), input.l(), input.m(), input.emptyList(), diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java index 26e888172..30dfe0793 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java @@ -24,6 +24,10 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobMetadata; +import org.flyte.api.v1.BlobType; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.examples.AllInputsTask.AutoAllInputsOutput; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; @@ -31,6 +35,7 @@ import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; +import org.flyte.flytekit.jackson.BlobTypeDescription; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) @@ -57,6 +62,18 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) SdkBindingDataFactory.of(true), SdkBindingDataFactory.of(someInstant), SdkBindingDataFactory.of(Duration.ofDays(1L)), + SdkBindingDataFactory.of( + Blob.builder() + .uri("file://test/test.csv") + .metadata( + BlobMetadata.builder() + .type( + BlobType.builder() + .format("csv") + .dimensionality(BlobDimensionality.MULTIPART) + .build()) + .build()) + .build()), SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")), SdkBindingDataFactory.ofStringMap(Map.of("test", "test")), SdkBindingDataFactory.ofStringCollection(Collections.emptyList()), @@ -71,6 +88,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) outputs.b(), outputs.t(), outputs.d(), + outputs.blob(), outputs.l(), outputs.m(), outputs.emptyList(), @@ -92,8 +110,8 @@ public abstract static class AllInputsWorkflowOutput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + @BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART) + public abstract SdkBindingData blob(); public abstract SdkBindingData> l(); @@ -110,12 +128,13 @@ public static AllInputsWorkflow.AllInputsWorkflowOutput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, + SdkBindingData blob, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsWorkflow_AllInputsWorkflowOutput( - i, f, s, b, t, d, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); } } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/BlobTypeDescription.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/BlobTypeDescription.java new file mode 100644 index 000000000..0fe5c791f --- /dev/null +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/BlobTypeDescription.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit.jackson; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.flyte.api.v1.BlobType.BlobDimensionality; + +/** Applied to a blob property to annotate its type. */ +@Target({ElementType.FIELD, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +public @interface BlobTypeDescription { + /** + * Describes the blob's format. + * + * @return format, not {@code null} + */ + String format(); + + /** + * Describes the blob's dimensionality. + * + * @return dimensionality, not {@code null} + */ + BlobDimensionality dimensionality(); +} diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index 183e00444..c100b81c3 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -30,6 +30,8 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobType; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkLiteralType; @@ -63,11 +65,7 @@ public void property(BeanProperty prop) { String propName = prop.getName(); AnnotatedMember member = prop.getMember(); SdkLiteralType literalType = - toLiteralType( - handledType, - /*rootLevel=*/ true, - propName, - member.getMember().getDeclaringClass().getName()); + toLiteralType(handledType, /* rootLevel= */ true, propName, member); String description = getDescription(member); @@ -132,18 +130,17 @@ private String getDescription(AnnotatedMember member) { @SuppressWarnings("AlreadyChecked") private SdkLiteralType toLiteralType( - JavaType javaType, boolean rootLevel, String propName, String declaringClassName) { + JavaType javaType, boolean rootLevel, String propName, AnnotatedMember member) { Class type = javaType.getRawClass(); if (SdkBindingData.class.isAssignableFrom(type)) { - return toLiteralType( - javaType.getBindings().getBoundType(0), false, propName, declaringClassName); + return toLiteralType(javaType.getBindings().getBoundType(0), false, propName, member); } else if (rootLevel) { throw new UnsupportedOperationException( String.format( "Field '%s' from class '%s' is declared as '%s' and it is not matching any of the supported types. " + "Please make sure your variable declared type is wrapped in 'SdkBindingData<>'.", - propName, declaringClassName, type)); + propName, member.getMember().getDeclaringClass().getName(), type)); } else if (isPrimitiveAssignableFrom(Long.class, type)) { return SdkLiteralTypes.integers(); } else if (isPrimitiveAssignableFrom(Double.class, type)) { @@ -159,8 +156,7 @@ private SdkLiteralType toLiteralType( } else if (List.class.isAssignableFrom(type)) { JavaType elementType = javaType.getBindings().getBoundType(0); - return SdkLiteralTypes.collections( - toLiteralType(elementType, false, propName, declaringClassName)); + return SdkLiteralTypes.collections(toLiteralType(elementType, false, propName, member)); } else if (Map.class.isAssignableFrom(type)) { JavaType keyType = javaType.getBindings().getBoundType(0); JavaType valueType = javaType.getBindings().getBoundType(1); @@ -170,9 +166,22 @@ private SdkLiteralType toLiteralType( "Only Map is supported, got [" + javaType.getGenericSignature() + "]"); } - return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, declaringClassName)); + return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, member)); + } else if (Blob.class.isAssignableFrom(type)) { + BlobTypeDescription annotation = member.getAnnotation(BlobTypeDescription.class); + if (annotation == null) { + throw new UnsupportedOperationException( + String.format( + "Field '%s' from class '%s' is declared as '%s' and it must be annotated", + propName, member.getMember().getDeclaringClass().getName(), type)); + } + return SdkLiteralTypes.blobs( + BlobType.builder() + .format(annotation.format()) + .dimensionality(annotation.dimensionality()) + .build()); } - // TODO: Support blobs and structs + // TODO: Support structs throw new UnsupportedOperationException( String.format("Unsupported type: [%s]", type.getName())); } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java index e99acdd8a..39f7e1034 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java @@ -28,7 +28,6 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; import java.io.IOException; -import java.io.Serializable; import java.time.Duration; import java.time.Instant; import java.util.Iterator; @@ -39,6 +38,10 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.stream.StreamSupport; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobMetadata; +import org.flyte.api.v1.BlobType; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; @@ -80,7 +83,7 @@ private SdkBindingData transform(JsonNode tree) { } } - private static SdkBindingData transformScalar(JsonNode tree) { + private static SdkBindingData transformScalar(JsonNode tree) { Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText()); switch (scalarKind) { case PRIMITIVE: @@ -102,14 +105,33 @@ private static SdkBindingData transformScalar(JsonNode t throw new UnsupportedOperationException( "Type contains an unsupported primitive: " + primitiveKind); - case GENERIC: case BLOB: + return transformBlob(tree); + + case GENERIC: default: throw new UnsupportedOperationException( "Type contains an unsupported scalar: " + scalarKind); } } + private static SdkBindingData transformBlob(JsonNode tree) { + JsonNode value = tree.get(VALUE); + String uri = value.get("uri").asText(); + JsonNode type = value.get("metadata").get("type"); + String format = type.get("format").asText(); + BlobDimensionality dimensionality = + BlobDimensionality.valueOf(type.get("dimensionality").asText()); + return SdkBindingDataFactory.of( + Blob.builder() + .uri(uri) + .metadata( + BlobMetadata.builder() + .type(BlobType.builder().format(format).dimensionality(dimensionality).build()) + .build()) + .build()); + } + @SuppressWarnings("unchecked") private SdkBindingData> transformCollection(JsonNode tree) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java index 282376109..7862b6f26 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java @@ -16,7 +16,7 @@ */ package org.flyte.flytekit.jackson.serializers; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR; +import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.SerializerProvider; @@ -24,7 +24,7 @@ import org.flyte.api.v1.Blob; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; -import org.flyte.api.v1.Scalar; +import org.flyte.api.v1.Scalar.Kind; public class BlobSerializer extends ScalarSerializer { public BlobSerializer( @@ -38,8 +38,8 @@ public BlobSerializer( @Override void serializeScalar() throws IOException { - gen.writeFieldName(SCALAR); - gen.writeObject(Scalar.Kind.BLOB); + gen.writeObject(Kind.BLOB); + gen.writeFieldName(VALUE); serializerProvider .findValueSerializer(Blob.class) .serialize(value.scalar().blob(), gen, serializerProvider); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java index 86af1b5fc..4267bc532 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java @@ -25,4 +25,5 @@ public class SdkBindingDataSerializationProtocol { public static final String TYPE = "type"; public static final String KIND = "kind"; public static final String PRIMITIVE = "primitive"; + public static final String BLOB = "blob"; } diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index b4b6ce995..38fd83639 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -38,7 +38,10 @@ import java.util.Map; import java.util.Objects; import javax.annotation.Nullable; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobMetadata; import org.flyte.api.v1.BlobType; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; @@ -64,7 +67,7 @@ public static AutoValueInput createAutoValueInput( boolean b, Instant t, Duration d, - // Blob blob, + Blob blob, List l, Map m, List> ll, @@ -78,6 +81,7 @@ public static AutoValueInput createAutoValueInput( SdkBindingDataFactory.of(b), SdkBindingDataFactory.of(t), SdkBindingDataFactory.of(d), + SdkBindingDataFactory.of(blob), SdkBindingDataFactory.ofStringCollection(l), SdkBindingDataFactory.ofStringMap(m), SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ll), @@ -98,7 +102,7 @@ public void testVariableMap() { hasEntry("b", createVar(SimpleType.BOOLEAN)), hasEntry("t", createVar(SimpleType.DATETIME)), hasEntry("d", createVar(SimpleType.DURATION)), - // hasEntry("blob", createVar(LiteralType.ofBlobType(BLOB_TYPE))), + hasEntry("blob", createVar(LiteralType.ofBlobType(BLOB_TYPE))), hasEntry( "l", createVar(LiteralType.ofCollectionType(ofSimpleType(SimpleType.STRING)))), hasEntry( @@ -119,11 +123,11 @@ public void testVariableMap() { void testFromLiteralMap() { Instant datetime = Instant.ofEpochSecond(12, 34); Duration duration = Duration.ofSeconds(56, 78); - // Blob blob = - // Blob.builder() - // .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) - // .uri("file://test") - // .build(); + Blob blob = + Blob.builder() + .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) + .uri("file://test") + .build(); Map literalMap = new HashMap<>(); literalMap.put("i", literalOf(Primitive.ofIntegerValue(123L))); literalMap.put("f", literalOf(Primitive.ofFloatValue(123.0))); @@ -131,7 +135,7 @@ void testFromLiteralMap() { literalMap.put("b", literalOf(Primitive.ofBooleanValue(true))); literalMap.put("t", literalOf(Primitive.ofDatetime(datetime))); literalMap.put("d", literalOf(Primitive.ofDuration(duration))); - // literalMap.put("blob", literalOf(blob)); + literalMap.put("blob", literalOf(blob)); literalMap.put("l", Literal.ofCollection(List.of(literalOf(Primitive.ofStringValue("123"))))); literalMap.put("m", Literal.ofMap(Map.of("marco", literalOf(Primitive.ofStringValue("polo"))))); literalMap.put( @@ -159,9 +163,9 @@ void testFromLiteralMap() { Literal.ofMap( Map.of( "math", - Literal.ofMap( - Map.of("pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), - "pokemon", Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))); + Literal.ofMap(Map.of("pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), + "pokemon", + Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))); AutoValueInput input = JacksonSdkType.of(AutoValueInput.class).fromLiteralMap(literalMap); @@ -175,7 +179,7 @@ void testFromLiteralMap() { /* b= */ true, /* t= */ datetime, /* d= */ duration, - /// * blob= */ blob, + /* blob= */ blob, /* l= */ List.of("123"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -194,11 +198,11 @@ private static Literal stringLiteralOf(String string) { @Test void testToLiteralMap() { - // Blob blob = - // Blob.builder() - // .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) - // .uri("file://test") - // .build(); + Blob blob = + Blob.builder() + .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) + .uri("file://test") + .build(); Map literalMap = JacksonSdkType.of(AutoValueInput.class) .toLiteralMap( @@ -209,7 +213,7 @@ void testToLiteralMap() { /* b= */ false, /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), - /// * blob= */ blob, + /* blob= */ blob, /* l= */ List.of("foo"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -271,13 +275,17 @@ void testToLiteralMap() { Map.of( "pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), "pokemon", - Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))) - // hasEntry("blob", literalOf(blob)) - ))); + Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))), + hasEntry("blob", literalOf(blob))))); } @Test public void testToSdkBindingDataMap() { + Blob blob = + Blob.builder() + .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) + .uri("file://test") + .build(); AutoValueInput input = createAutoValueInput( /* i= */ 42L, @@ -286,7 +294,7 @@ public void testToSdkBindingDataMap() { /* b= */ false, /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), - /// * blob= */ blob, + /* blob= */ blob, /* l= */ List.of("foo"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -305,6 +313,7 @@ public void testToSdkBindingDataMap() { expected.put("b", input.b()); expected.put("t", input.t()); expected.put("d", input.d()); + expected.put("blob", input.blob()); expected.put("l", input.l()); expected.put("m", input.m()); expected.put("ll", input.ll()); @@ -536,8 +545,8 @@ public abstract static class AutoValueInput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + @BlobTypeDescription(format = "", dimensionality = BlobDimensionality.SINGLE) + public abstract SdkBindingData blob(); public abstract SdkBindingData> l(); @@ -558,7 +567,7 @@ public static AutoValueInput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, - // Blob blob, + SdkBindingData blob, SdkBindingData> l, SdkBindingData> m, SdkBindingData>> ll, @@ -566,7 +575,7 @@ public static AutoValueInput create( SdkBindingData>> ml, SdkBindingData>> mm) { return new AutoValue_JacksonSdkTypeTest_AutoValueInput( - i, f, s, b, t, d, l, m, ll, lm, ml, mm); + i, f, s, b, t, d, blob, l, m, ll, lm, ml, mm); } } @@ -701,4 +710,8 @@ private static Variable createVar(LiteralType literalType, String description) { private static Literal literalOf(Primitive primitive) { return Literal.ofScalar(Scalar.ofPrimitive(primitive)); } + + private static Literal literalOf(Blob blob) { + return Literal.ofScalar(Scalar.ofBlob(blob)); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java b/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java index 1c18ee674..f9fb9162b 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java @@ -18,6 +18,7 @@ import java.time.Duration; import java.time.Instant; +import org.flyte.api.v1.Blob; import org.flyte.api.v1.Literal; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; @@ -49,6 +50,10 @@ static Literal ofDuration(Duration value) { return ofPrimitive(Primitive.ofDuration(value)); } + static Literal ofBlob(Blob value) { + return Literal.ofScalar(Scalar.ofBlob(value)); + } + private static Literal ofPrimitive(Primitive primitive) { return Literal.ofScalar(Scalar.ofPrimitive(primitive)); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java index 757e5bf10..f8c9d8d47 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java @@ -26,6 +26,7 @@ import java.time.ZoneOffset; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; /** A utility class for creating {@link SdkBindingData} objects for different types. */ public final class SdkBindingDataFactory { @@ -123,6 +124,10 @@ public static SdkBindingData> of(SdkLiteralType elementType, List return SdkBindingData.literal(collections(elementType), collection); } + public static SdkBindingData of(SdkLiteralType type, T value) { + return SdkBindingData.literal(type, value); + } + /** * Creates a {@code SdkBindingData} for a flyte collection of string given a java {@code * List}. @@ -296,4 +301,8 @@ public static SdkBindingData> ofBindingMap( return SdkBindingData.bindingMap(valuesType, valueMap); } + + public static SdkBindingData of(Blob value) { + return SdkBindingData.literal(SdkLiteralTypes.blobs(value.metadata().type()), value); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java index f09073ea9..6587399d2 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java @@ -25,6 +25,8 @@ import java.util.Map; import java.util.Map.Entry; import org.flyte.api.v1.BindingData; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobType; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; @@ -73,6 +75,8 @@ public static SdkLiteralType of(Class clazz) { return (SdkLiteralType) datetimes(); } else if (clazz.equals(Duration.class)) { return (SdkLiteralType) durations(); + } else if (clazz.equals(BlobType.class)) { + return null; } throw new IllegalArgumentException("Unsupported type: " + clazz); } @@ -181,6 +185,15 @@ public static SdkLiteralType> maps(SdkLiteralType mapValue return new MapSdkLiteralType<>(mapValueType); } + /** + * Returns a {@link SdkLiteralType} for blobs. + * + * @return the {@link SdkLiteralType} + */ + public static SdkLiteralType blobs(BlobType blobType) { + return new BlobSdkLiteralType(blobType); + } + private static class IntegerSdkLiteralType extends PrimitiveSdkLiteralType { private static final IntegerSdkLiteralType INSTANCE = new IntegerSdkLiteralType(); @@ -205,6 +218,39 @@ public String toString() { } } + private static class BlobSdkLiteralType extends SdkLiteralType { + private final BlobType blobType; + + public BlobSdkLiteralType(BlobType blobType) { + this.blobType = blobType; + } + + @Override + public LiteralType getLiteralType() { + return LiteralType.ofBlobType(blobType); + } + + @Override + public Literal toLiteral(Blob value) { + return Literals.ofBlob(value); + } + + @Override + public Blob fromLiteral(Literal literal) { + return literal.scalar().blob(); + } + + @Override + public BindingData toBindingData(Blob value) { + return BindingData.ofScalar(Scalar.ofBlob(value)); + } + + @Override + public String toString() { + return "blobs"; + } + } + private static class FloatSdkLiteralType extends PrimitiveSdkLiteralType { private static final FloatSdkLiteralType INSTANCE = new FloatSdkLiteralType(); diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala index 2c2203c06..7e5ed0f4e 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala @@ -16,6 +16,8 @@ */ package org.flyte.flytekitscala +import org.flyte.api.v1.{Blob, BlobType} +import org.flyte.api.v1.BlobType.BlobDimensionality import org.flyte.flytekit.SdkLiteralType import org.flyte.flytekitscala.SdkLiteralTypes.{of, _} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index 377833815..0ef82f3b7 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -16,9 +16,14 @@ */ package org.flyte.flytekitscala +import org.flyte.api.v1.BlobType.BlobDimensionality + import java.time.{Duration, Instant} import scala.jdk.CollectionConverters._ import org.flyte.api.v1.{ + Blob, + BlobMetadata, + BlobType, Literal, LiteralType, Primitive, @@ -386,6 +391,24 @@ class SdkScalaTypeTest { SdkJavaBindingDataFactory.of(true), SdkJavaBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), SdkJavaBindingDataFactory.of(Duration.ZERO), + SdkJavaBindingDataFactory.of( + Blob + .builder() + .uri("file://test/test.csv") + .metadata( + BlobMetadata + .builder() + .`type`( + BlobType + .builder() + .format("csv") + .dimensionality(BlobDimensionality.MULTIPART) + .build() + ) + .build() + ) + .build() + ), SdkJavaBindingDataFactory.ofStringCollection(List("1", "2", "3").asJava), SdkJavaBindingDataFactory.ofStringMap(Map("a" -> "2", "b" -> "3").asJava), SdkJavaBindingDataFactory.ofStringCollection(List.empty[String].asJava), diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala index 2c4923d55..ddb20199b 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala @@ -172,7 +172,6 @@ object SdkBindingDataConverters { jf.Function.identity() ) } - case LiteralType.Kind.BLOB_TYPE => ??? // TODO not yet supported case LiteralType.Kind.SCHEMA_TYPE => ??? // TODO not yet supported case LiteralType.Kind.COLLECTION_TYPE => val TypeCastingResult(convertedElementType, convFunction) = toScalaType( @@ -257,7 +256,6 @@ object SdkBindingDataConverters { jf.Function.identity() ) } - case LiteralType.Kind.BLOB_TYPE => ??? // TODO do we support blob? case LiteralType.Kind.SCHEMA_TYPE => ??? // TODO do we support schema type? case LiteralType.Kind.COLLECTION_TYPE => diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala index 596639929..1fc5060e5 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -202,6 +202,14 @@ object SdkLiteralTypes { */ def durations(): SdkLiteralType[Duration] = SdkJavaLiteralTypes.durations() + /** Returns a [[SdkLiteralType]] for blob. + * + * @return + * the [[SdkLiteralType]] + */ + def blobs(blobType: BlobType): SdkLiteralType[Blob] = + SdkJavaLiteralTypes.blobs(blobType) + /** Returns a [[SdkLiteralType]] for flyte collections. * * @param elementType diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index 19e190348..bd75cb68a 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -19,6 +19,7 @@ package org.flyte.flytekitscala import java.time.{Duration, Instant} import java.{util => ju} import magnolia.{CaseClass, Magnolia, Param, SealedTrait} +import org.flyte.api.v1.BlobType.BlobDimensionality import org.flyte.api.v1._ import org.flyte.flytekit.{ SdkBindingData, @@ -230,6 +231,18 @@ object SdkScalaType { implicit def durationLiteralType: SdkScalaLiteralType[Duration] = DelegateLiteralType(SdkLiteralTypes.durations()) + // fixme: create blob type from annotation + implicit def blobLiteralType: SdkScalaLiteralType[Blob] = + DelegateLiteralType( + SdkLiteralTypes.blobs( + BlobType + .builder() + .format("") + .dimensionality(BlobDimensionality.SINGLE) + .build() + ) + ) + // TODO we are forced to do this because SdkDataBinding.ofInteger returns a SdkBindingData // This makes Scala dev mad when they are forced to use the java types instead of scala types // We need to think what to do, maybe move the factory methods out of SdkDataBinding into their own class diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 278035e50..2e843bf0c 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -51,6 +51,11 @@ auto-service-annotations provided + + org.flyte + flytekit-api + provided +