From 8ec7a78aa7684df0d64fce26682d98d7b8de1600 Mon Sep 17 00:00:00 2001 From: Honnix Date: Fri, 17 Nov 2023 15:34:28 +0100 Subject: [PATCH] Minor clean up of struct (#265) Signed-off-by: Hongxin Liang --- .../main/java/org/flyte/api/v1/BlobType.java | 4 + .../org/flyte/examples/AllInputsWorkflow.java | 10 +-- .../flytekit/jackson/VariableMapVisitor.java | 4 +- .../flytekit/jackson/JacksonSdkTypeTest.java | 3 +- .../flytekitscala/SdkLiteralTypesTest.scala | 2 + .../flytekitscala/SdkScalaTypeTest.scala | 78 +++++++++++++++++-- .../SdkBindingDataConverters.scala | 9 ++- .../flyte/flytekitscala/SdkLiteralTypes.scala | 4 + .../flyte/flytekitscala/SdkScalaType.scala | 12 +-- 9 files changed, 94 insertions(+), 32 deletions(-) diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/BlobType.java b/flytekit-api/src/main/java/org/flyte/api/v1/BlobType.java index 9d730d82a..26269ac6e 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/BlobType.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/BlobType.java @@ -21,6 +21,10 @@ /** Defines type behavior for blob objects. */ @AutoValue public abstract class BlobType { + + public static final BlobType DEFAULT = + BlobType.builder().dimensionality(BlobDimensionality.SINGLE).format("").build(); + public enum BlobDimensionality { SINGLE, MULTIPART diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java index 8bd9acc31..c81cf7d5e 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java @@ -27,7 +27,6 @@ import org.flyte.api.v1.Blob; import org.flyte.api.v1.BlobMetadata; import org.flyte.api.v1.BlobType; -import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.examples.AllInputsTask.AutoAllInputsOutput; import org.flyte.examples.AllInputsTask.Nested; import org.flyte.flytekit.SdkBindingData; @@ -66,14 +65,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) SdkBindingDataFactory.of( Blob.builder() .uri("file://test/test.csv") - .metadata( - BlobMetadata.builder() - .type( - BlobType.builder() - .format("") - .dimensionality(BlobDimensionality.SINGLE) - .build()) - .build()) + .metadata(BlobMetadata.builder().type(BlobType.DEFAULT).build()) .build()), SdkBindingDataFactory.of( JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")), diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index c565898be..6b78841cc 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -32,7 +32,6 @@ import java.util.Map; import org.flyte.api.v1.Blob; import org.flyte.api.v1.BlobType; -import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkLiteralType; @@ -172,8 +171,7 @@ private SdkLiteralType toLiteralType( // fixme: create blob type from annotation, or rethink how we could offer the offloaded data // feature // https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype - return SdkLiteralTypes.blobs( - BlobType.builder().format("").dimensionality(BlobDimensionality.SINGLE).build()); + return SdkLiteralTypes.blobs(BlobType.DEFAULT); } try { return JacksonSdkLiteralType.of(type); diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index a1f969961..4ae10036e 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -56,8 +56,7 @@ public class JacksonSdkTypeTest { - private static final BlobType BLOB_TYPE = - BlobType.builder().format("").dimensionality(BlobType.BlobDimensionality.SINGLE).build(); + private static final BlobType BLOB_TYPE = BlobType.DEFAULT; private static final Blob BLOB = Blob.builder() diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala index 7e5ed0f4e..5d169ad54 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala @@ -66,6 +66,8 @@ class TestOfReturnsProperTypeProvider extends ArgumentsProvider { Arguments.of(booleans(), of[Boolean]()), Arguments.of(datetimes(), of[Instant]()), Arguments.of(durations(), of[Duration]()), + Arguments.of(blobs(BlobType.DEFAULT), of[Blob]()), + Arguments.of(generics(), of[ScalarNested]()), Arguments.of(collections(integers()), of[List[Long]]()), Arguments.of(collections(floats()), of[List[Double]]()), Arguments.of(collections(strings()), of[List[String]]()), diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index a3cbaf863..720002000 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -29,6 +29,7 @@ import org.flyte.api.v1.{ Primitive, Scalar, SimpleType, + Struct, Variable } import org.flyte.flytekit.{ @@ -40,17 +41,32 @@ import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} import org.junit.jupiter.api.Test import org.flyte.examples.AllInputsTask.{AutoAllInputsInput, Nested} import org.flyte.flytekit.jackson.JacksonSdkLiteralType -import org.flyte.flytekitscala.SdkLiteralTypes.{collections, maps, strings} +import org.flyte.flytekitscala.SdkLiteralTypes.{ + blobs, + collections, + maps, + strings +} + +// The constructor is reflectedly invoked so it cannot be an inner class +case class ScalarNested(foo: String, bar: String) class SdkScalaTypeTest { + private val blob = Blob.builder + .metadata(BlobMetadata.builder.`type`(BlobType.DEFAULT).build) + .uri("file://test") + .build + case class ScalarInput( string: SdkBindingData[String], integer: SdkBindingData[Long], float: SdkBindingData[Double], boolean: SdkBindingData[Boolean], datetime: SdkBindingData[Instant], - duration: SdkBindingData[Duration] + duration: SdkBindingData[Duration], + blob: SdkBindingData[Blob], + generic: SdkBindingData[ScalarNested] ) case class CollectionInput( @@ -116,7 +132,13 @@ class SdkScalaTypeTest { "float" -> createVar(SimpleType.FLOAT), "boolean" -> createVar(SimpleType.BOOLEAN), "datetime" -> createVar(SimpleType.DATETIME), - "duration" -> createVar(SimpleType.DURATION) + "duration" -> createVar(SimpleType.DURATION), + "blob" -> Variable + .builder() + .literalType(LiteralType.ofBlobType(BlobType.DEFAULT)) + .description("") + .build(), + "generic" -> createVar(SimpleType.STRUCT) ).asJava val output = SdkScalaType[ScalarInput].getVariableMap @@ -149,6 +171,17 @@ class SdkScalaTypeTest { ), "duration" -> Literal.ofScalar( Scalar.ofPrimitive(Primitive.ofDuration(Duration.ofSeconds(123, 456))) + ), + "blob" -> Literal.ofScalar(Scalar.ofBlob(blob)), + "generic" -> Literal.ofScalar( + Scalar.ofGeneric( + Struct.of( + Map( + "foo" -> Struct.Value.ofStringValue("foo"), + "bar" -> Struct.Value.ofStringValue("bar") + ).asJava + ) + ) ) ).asJava @@ -159,7 +192,12 @@ class SdkScalaTypeTest { float = SdkBindingDataFactory.of(42.0), boolean = SdkBindingDataFactory.of(true), datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)), - duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)) + duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)), + blob = SdkBindingDataFactory.of(blob), + generic = SdkBindingDataFactory.of( + SdkLiteralTypes.generics(), + ScalarNested("foo", "bar") + ) ) val output = SdkScalaType[ScalarInput].fromLiteralMap(input) @@ -176,7 +214,12 @@ class SdkScalaTypeTest { float = SdkBindingDataFactory.of(42.0), boolean = SdkBindingDataFactory.of(true), datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)), - duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)) + duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)), + blob = SdkBindingDataFactory.of(blob), + generic = SdkBindingDataFactory.of( + SdkLiteralTypes.generics(), + ScalarNested("foo", "bar") + ) ) val expected = Map( @@ -195,6 +238,17 @@ class SdkScalaTypeTest { ), "duration" -> Literal.ofScalar( Scalar.ofPrimitive(Primitive.ofDuration(Duration.ofSeconds(123, 456))) + ), + "blob" -> Literal.ofScalar(Scalar.ofBlob(blob)), + "generic" -> Literal.ofScalar( + Scalar.ofGeneric( + Struct.of( + Map( + "foo" -> Struct.Value.ofStringValue("foo"), + "bar" -> Struct.Value.ofStringValue("bar") + ).asJava + ) + ) ) ).asJava @@ -227,7 +281,12 @@ class SdkScalaTypeTest { float = SdkBindingDataFactory.of(42.0), boolean = SdkBindingDataFactory.of(true), datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)), - duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)) + duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)), + blob = SdkBindingDataFactory.of(blob), + generic = SdkBindingDataFactory.of( + SdkLiteralTypes.generics(), + ScalarNested("foo", "bar") + ) ) val output = SdkScalaType[ScalarInput].toSdkBindingMap(input) @@ -238,7 +297,12 @@ class SdkScalaTypeTest { "float" -> SdkBindingDataFactory.of(42.0), "boolean" -> SdkBindingDataFactory.of(true), "datetime" -> SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)), - "duration" -> SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)) + "duration" -> SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)), + "blob" -> SdkBindingDataFactory.of(blob), + "generic" -> SdkBindingDataFactory.of( + SdkLiteralTypes.generics[ScalarNested](), + ScalarNested("foo", "bar") + ) ).asJava assertEquals(expected, output) diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala index 0ed989a20..368d09ba4 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala @@ -150,7 +150,10 @@ object SdkBindingDataConverters { SdkScalaLiteralTypes.strings(), jf.Function.identity() ) - case SimpleType.STRUCT => ??? // TODO not yet supported + case SimpleType.STRUCT => + throw new UnsupportedOperationException( + "Converting Scala case class instance to Java object is not supported" + ) case SimpleType.BOOLEAN => TypeCastingResult( SdkScalaLiteralTypes.booleans(), @@ -239,7 +242,9 @@ object SdkBindingDataConverters { jf.Function.identity() ) case SimpleType.STRUCT => - ??? // TODO how to handle? do we support structs already? + throw new UnsupportedOperationException( + "Converting Java object to Scala case class instance is not supported" + ) case SimpleType.BOOLEAN => TypeCastingResult( SdkJavaLiteralTypes.booleans(), diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala index 21fd1597a..6a36f3520 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -66,6 +66,10 @@ object SdkLiteralTypes { datetimes().asInstanceOf[SdkLiteralType[T]] case t if t =:= typeOf[Duration] => durations().asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Blob] => + blobs(BlobType.DEFAULT).asInstanceOf[SdkLiteralType[T]] + case t if t <:< typeOf[Product] && !(t =:= typeOf[Option[_]]) => + generics().asInstanceOf[SdkLiteralType[T]] case t if t =:= typeOf[List[Long]] => collections(integers()).asInstanceOf[SdkLiteralType[T]] diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index 4dd706f5e..a0a7fb107 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -30,8 +30,8 @@ import org.flyte.flytekit.{ import scala.annotation.implicitNotFound import scala.collection.JavaConverters._ -import scala.reflect.{ClassTag, classTag} -import scala.reflect.runtime.universe.{TypeTag, typeOf} +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag /** Type class to map between Flyte `Variable` and `Literal` and Scala case * classes. @@ -245,13 +245,7 @@ object SdkScalaType { // https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype implicit def blobLiteralType: SdkScalaLiteralType[Blob] = DelegateLiteralType( - SdkLiteralTypes.blobs( - BlobType - .builder() - .format("") - .dimensionality(BlobDimensionality.SINGLE) - .build() - ) + SdkLiteralTypes.blobs(BlobType.DEFAULT) ) // TODO we are forced to do this because SdkDataBinding.ofInteger returns a SdkBindingData