diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/BlobTypeDescription.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/BlobTypeDescription.java new file mode 100644 index 000000000..5806df7c2 --- /dev/null +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/BlobTypeDescription.java @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit.jackson; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** Applied to a blob property to annotate its type. */ +@Target({ElementType.FIELD, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +public @interface BlobTypeDescription { + /** + * Describes the blob's format. + * + * @return format, not {@code null} + */ + String format(); + + /** + * Describes the blob's dimensionality. + * + * @return dimensionality, not {@code null} + */ + String dimensionality(); +} diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index 183e00444..a59d24027 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -30,6 +30,9 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobType; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkLiteralType; @@ -63,11 +66,7 @@ public void property(BeanProperty prop) { String propName = prop.getName(); AnnotatedMember member = prop.getMember(); SdkLiteralType literalType = - toLiteralType( - handledType, - /*rootLevel=*/ true, - propName, - member.getMember().getDeclaringClass().getName()); + toLiteralType(handledType, /* rootLevel= */ true, propName, member); String description = getDescription(member); @@ -132,18 +131,17 @@ private String getDescription(AnnotatedMember member) { @SuppressWarnings("AlreadyChecked") private SdkLiteralType toLiteralType( - JavaType javaType, boolean rootLevel, String propName, String declaringClassName) { + JavaType javaType, boolean rootLevel, String propName, AnnotatedMember member) { Class type = javaType.getRawClass(); if (SdkBindingData.class.isAssignableFrom(type)) { - return toLiteralType( - javaType.getBindings().getBoundType(0), false, propName, declaringClassName); + return toLiteralType(javaType.getBindings().getBoundType(0), false, propName, member); } else if (rootLevel) { throw new UnsupportedOperationException( String.format( "Field '%s' from class '%s' is declared as '%s' and it is not matching any of the supported types. " + "Please make sure your variable declared type is wrapped in 'SdkBindingData<>'.", - propName, declaringClassName, type)); + propName, member.getMember().getDeclaringClass().getName(), type)); } else if (isPrimitiveAssignableFrom(Long.class, type)) { return SdkLiteralTypes.integers(); } else if (isPrimitiveAssignableFrom(Double.class, type)) { @@ -159,8 +157,7 @@ private SdkLiteralType toLiteralType( } else if (List.class.isAssignableFrom(type)) { JavaType elementType = javaType.getBindings().getBoundType(0); - return SdkLiteralTypes.collections( - toLiteralType(elementType, false, propName, declaringClassName)); + return SdkLiteralTypes.collections(toLiteralType(elementType, false, propName, member)); } else if (Map.class.isAssignableFrom(type)) { JavaType keyType = javaType.getBindings().getBoundType(0); JavaType valueType = javaType.getBindings().getBoundType(1); @@ -170,9 +167,22 @@ private SdkLiteralType toLiteralType( "Only Map is supported, got [" + javaType.getGenericSignature() + "]"); } - return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, declaringClassName)); + return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, member)); + } else if (Blob.class.isAssignableFrom(type)) { + BlobTypeDescription annotation = member.getAnnotation(BlobTypeDescription.class); + if (annotation == null) { + throw new UnsupportedOperationException( + String.format( + "Field '%s' from class '%s' is declared as '%s' and it must be annotated", + propName, member.getMember().getDeclaringClass().getName(), type)); + } + return SdkLiteralTypes.blobs( + BlobType.builder() + .format(annotation.format()) + .dimensionality(BlobDimensionality.valueOf(annotation.dimensionality())) + .build()); } - // TODO: Support blobs and structs + // TODO: Support structs throw new UnsupportedOperationException( String.format("Unsupported type: [%s]", type.getName())); } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java index e99acdd8a..7a84c1e90 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java @@ -28,7 +28,6 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; import java.io.IOException; -import java.io.Serializable; import java.time.Duration; import java.time.Instant; import java.util.Iterator; @@ -39,6 +38,10 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.stream.StreamSupport; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobMetadata; +import org.flyte.api.v1.BlobType; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; @@ -80,7 +83,7 @@ private SdkBindingData transform(JsonNode tree) { } } - private static SdkBindingData transformScalar(JsonNode tree) { + private static SdkBindingData transformScalar(JsonNode tree) { Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText()); switch (scalarKind) { case PRIMITIVE: @@ -102,8 +105,27 @@ private static SdkBindingData transformScalar(JsonNode t throw new UnsupportedOperationException( "Type contains an unsupported primitive: " + primitiveKind); - case GENERIC: case BLOB: + JsonNode value = tree.get(VALUE); + String uri = value.get("uri").asText(); + JsonNode type = value.get("metadata").get("type"); + String format = type.get("format").asText(); + BlobDimensionality dimensionality = + BlobDimensionality.valueOf(type.get("dimensionality").asText()); + return SdkBindingDataFactory.of( + Blob.builder() + .uri(uri) + .metadata( + BlobMetadata.builder() + .type( + BlobType.builder() + .format(format) + .dimensionality(dimensionality) + .build()) + .build()) + .build()); + case GENERIC: + default: throw new UnsupportedOperationException( "Type contains an unsupported scalar: " + scalarKind); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java index 282376109..7862b6f26 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java @@ -16,7 +16,7 @@ */ package org.flyte.flytekit.jackson.serializers; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR; +import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.SerializerProvider; @@ -24,7 +24,7 @@ import org.flyte.api.v1.Blob; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; -import org.flyte.api.v1.Scalar; +import org.flyte.api.v1.Scalar.Kind; public class BlobSerializer extends ScalarSerializer { public BlobSerializer( @@ -38,8 +38,8 @@ public BlobSerializer( @Override void serializeScalar() throws IOException { - gen.writeFieldName(SCALAR); - gen.writeObject(Scalar.Kind.BLOB); + gen.writeObject(Kind.BLOB); + gen.writeFieldName(VALUE); serializerProvider .findValueSerializer(Blob.class) .serialize(value.scalar().blob(), gen, serializerProvider); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java index 86af1b5fc..4267bc532 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java @@ -25,4 +25,5 @@ public class SdkBindingDataSerializationProtocol { public static final String TYPE = "type"; public static final String KIND = "kind"; public static final String PRIMITIVE = "primitive"; + public static final String BLOB = "blob"; } diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index b4b6ce995..b3ed6d2ca 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -38,6 +38,8 @@ import java.util.Map; import java.util.Objects; import javax.annotation.Nullable; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobMetadata; import org.flyte.api.v1.BlobType; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; @@ -64,7 +66,7 @@ public static AutoValueInput createAutoValueInput( boolean b, Instant t, Duration d, - // Blob blob, + Blob blob, List l, Map m, List> ll, @@ -78,6 +80,7 @@ public static AutoValueInput createAutoValueInput( SdkBindingDataFactory.of(b), SdkBindingDataFactory.of(t), SdkBindingDataFactory.of(d), + SdkBindingDataFactory.of(blob), SdkBindingDataFactory.ofStringCollection(l), SdkBindingDataFactory.ofStringMap(m), SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ll), @@ -119,11 +122,11 @@ public void testVariableMap() { void testFromLiteralMap() { Instant datetime = Instant.ofEpochSecond(12, 34); Duration duration = Duration.ofSeconds(56, 78); - // Blob blob = - // Blob.builder() - // .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) - // .uri("file://test") - // .build(); + Blob blob = + Blob.builder() + .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) + .uri("file://test") + .build(); Map literalMap = new HashMap<>(); literalMap.put("i", literalOf(Primitive.ofIntegerValue(123L))); literalMap.put("f", literalOf(Primitive.ofFloatValue(123.0))); @@ -131,7 +134,7 @@ void testFromLiteralMap() { literalMap.put("b", literalOf(Primitive.ofBooleanValue(true))); literalMap.put("t", literalOf(Primitive.ofDatetime(datetime))); literalMap.put("d", literalOf(Primitive.ofDuration(duration))); - // literalMap.put("blob", literalOf(blob)); + literalMap.put("blob", literalOf(blob)); literalMap.put("l", Literal.ofCollection(List.of(literalOf(Primitive.ofStringValue("123"))))); literalMap.put("m", Literal.ofMap(Map.of("marco", literalOf(Primitive.ofStringValue("polo"))))); literalMap.put( @@ -159,9 +162,9 @@ void testFromLiteralMap() { Literal.ofMap( Map.of( "math", - Literal.ofMap( - Map.of("pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), - "pokemon", Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))); + Literal.ofMap(Map.of("pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), + "pokemon", + Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))); AutoValueInput input = JacksonSdkType.of(AutoValueInput.class).fromLiteralMap(literalMap); @@ -175,7 +178,7 @@ void testFromLiteralMap() { /* b= */ true, /* t= */ datetime, /* d= */ duration, - /// * blob= */ blob, + /* blob= */ blob, /* l= */ List.of("123"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -194,11 +197,11 @@ private static Literal stringLiteralOf(String string) { @Test void testToLiteralMap() { - // Blob blob = - // Blob.builder() - // .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) - // .uri("file://test") - // .build(); + Blob blob = + Blob.builder() + .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) + .uri("file://test") + .build(); Map literalMap = JacksonSdkType.of(AutoValueInput.class) .toLiteralMap( @@ -209,7 +212,7 @@ void testToLiteralMap() { /* b= */ false, /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), - /// * blob= */ blob, + /* blob= */ blob, /* l= */ List.of("foo"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -271,13 +274,17 @@ void testToLiteralMap() { Map.of( "pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), "pokemon", - Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))) - // hasEntry("blob", literalOf(blob)) - ))); + Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))), + hasEntry("blob", literalOf(blob))))); } @Test public void testToSdkBindingDataMap() { + Blob blob = + Blob.builder() + .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) + .uri("file://test") + .build(); AutoValueInput input = createAutoValueInput( /* i= */ 42L, @@ -286,7 +293,7 @@ public void testToSdkBindingDataMap() { /* b= */ false, /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), - /// * blob= */ blob, + /* blob= */ blob, /* l= */ List.of("foo"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -305,6 +312,7 @@ public void testToSdkBindingDataMap() { expected.put("b", input.b()); expected.put("t", input.t()); expected.put("d", input.d()); + expected.put("blob", input.blob()); expected.put("l", input.l()); expected.put("m", input.m()); expected.put("ll", input.ll()); @@ -536,8 +544,8 @@ public abstract static class AutoValueInput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + @BlobTypeDescription(format = "", dimensionality = "SINGLE") + public abstract SdkBindingData blob(); public abstract SdkBindingData> l(); @@ -558,7 +566,7 @@ public static AutoValueInput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, - // Blob blob, + SdkBindingData blob, SdkBindingData> l, SdkBindingData> m, SdkBindingData>> ll, @@ -566,7 +574,7 @@ public static AutoValueInput create( SdkBindingData>> ml, SdkBindingData>> mm) { return new AutoValue_JacksonSdkTypeTest_AutoValueInput( - i, f, s, b, t, d, l, m, ll, lm, ml, mm); + i, f, s, b, t, d, blob, l, m, ll, lm, ml, mm); } } @@ -701,4 +709,8 @@ private static Variable createVar(LiteralType literalType, String description) { private static Literal literalOf(Primitive primitive) { return Literal.ofScalar(Scalar.ofPrimitive(primitive)); } + + private static Literal literalOf(Blob blob) { + return Literal.ofScalar(Scalar.ofBlob(blob)); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java b/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java index 1c18ee674..f9fb9162b 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java @@ -18,6 +18,7 @@ import java.time.Duration; import java.time.Instant; +import org.flyte.api.v1.Blob; import org.flyte.api.v1.Literal; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; @@ -49,6 +50,10 @@ static Literal ofDuration(Duration value) { return ofPrimitive(Primitive.ofDuration(value)); } + static Literal ofBlob(Blob value) { + return Literal.ofScalar(Scalar.ofBlob(value)); + } + private static Literal ofPrimitive(Primitive primitive) { return Literal.ofScalar(Scalar.ofPrimitive(primitive)); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java index 757e5bf10..f8c9d8d47 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java @@ -26,6 +26,7 @@ import java.time.ZoneOffset; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; /** A utility class for creating {@link SdkBindingData} objects for different types. */ public final class SdkBindingDataFactory { @@ -123,6 +124,10 @@ public static SdkBindingData> of(SdkLiteralType elementType, List return SdkBindingData.literal(collections(elementType), collection); } + public static SdkBindingData of(SdkLiteralType type, T value) { + return SdkBindingData.literal(type, value); + } + /** * Creates a {@code SdkBindingData} for a flyte collection of string given a java {@code * List}. @@ -296,4 +301,8 @@ public static SdkBindingData> ofBindingMap( return SdkBindingData.bindingMap(valuesType, valueMap); } + + public static SdkBindingData of(Blob value) { + return SdkBindingData.literal(SdkLiteralTypes.blobs(value.metadata().type()), value); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java index f09073ea9..6587399d2 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java @@ -25,6 +25,8 @@ import java.util.Map; import java.util.Map.Entry; import org.flyte.api.v1.BindingData; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobType; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; @@ -73,6 +75,8 @@ public static SdkLiteralType of(Class clazz) { return (SdkLiteralType) datetimes(); } else if (clazz.equals(Duration.class)) { return (SdkLiteralType) durations(); + } else if (clazz.equals(BlobType.class)) { + return null; } throw new IllegalArgumentException("Unsupported type: " + clazz); } @@ -181,6 +185,15 @@ public static SdkLiteralType> maps(SdkLiteralType mapValue return new MapSdkLiteralType<>(mapValueType); } + /** + * Returns a {@link SdkLiteralType} for blobs. + * + * @return the {@link SdkLiteralType} + */ + public static SdkLiteralType blobs(BlobType blobType) { + return new BlobSdkLiteralType(blobType); + } + private static class IntegerSdkLiteralType extends PrimitiveSdkLiteralType { private static final IntegerSdkLiteralType INSTANCE = new IntegerSdkLiteralType(); @@ -205,6 +218,39 @@ public String toString() { } } + private static class BlobSdkLiteralType extends SdkLiteralType { + private final BlobType blobType; + + public BlobSdkLiteralType(BlobType blobType) { + this.blobType = blobType; + } + + @Override + public LiteralType getLiteralType() { + return LiteralType.ofBlobType(blobType); + } + + @Override + public Literal toLiteral(Blob value) { + return Literals.ofBlob(value); + } + + @Override + public Blob fromLiteral(Literal literal) { + return literal.scalar().blob(); + } + + @Override + public BindingData toBindingData(Blob value) { + return BindingData.ofScalar(Scalar.ofBlob(value)); + } + + @Override + public String toString() { + return "blobs"; + } + } + private static class FloatSdkLiteralType extends PrimitiveSdkLiteralType { private static final FloatSdkLiteralType INSTANCE = new FloatSdkLiteralType(); diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 278035e50..2e843bf0c 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -51,6 +51,11 @@ auto-service-annotations provided + + org.flyte + flytekit-api + provided +