diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask index 508e6cb51..201af9a92 100644 --- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask +++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask @@ -4,3 +4,4 @@ org.flyte.examples.flytekitscala.GreetTask org.flyte.examples.flytekitscala.AddQuestionTask org.flyte.examples.flytekitscala.NoInputsTask org.flyte.examples.flytekitscala.NestedIOTask +org.flyte.examples.flytekitscala.NestedIOTaskNoop diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala index df5c3b438..ae9e19aca 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala @@ -73,14 +73,19 @@ class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry { 6.toDouble, "hello", List("1", "2"), - List(NestedNested(7.toDouble, NestedNestedNested("world"))), + List(NestedNested(7.toDouble, Some(NestedNestedNested("world")))), Map("1" -> "1", "2" -> "2"), - Map("foo" -> NestedNested(7.toDouble, NestedNestedNested("world"))), + Map( + "foo" -> NestedNested( + 7.toDouble, + Some(NestedNestedNested("world")) + ) + ), Some(false), None, Some(List("3", "4")), Some(Map("3" -> "3", "4" -> "4")), - NestedNested(7.toDouble, NestedNestedNested("world")) + NestedNested(7.toDouble, Some(NestedNestedNested("world"))) ) ) ) diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala index ef4d61245..6f6c165a3 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala @@ -24,7 +24,7 @@ import org.flyte.flytekitscala.{ } case class NestedNestedNested(string: String) -case class NestedNested(double: Double, nested: NestedNestedNested) +case class NestedNested(double: Double, nested: Option[NestedNestedNested]) case class Nested( boolean: Boolean, byte: Byte, @@ -57,9 +57,6 @@ case class NestedIOTaskOutput( generic: SdkBindingData[Nested] ) -/** Example Flyte task that takes a name as the input and outputs a simple - * greeting message. - */ class NestedIOTask extends SdkRunnableTask[ NestedIOTaskInput, @@ -69,17 +66,21 @@ class NestedIOTask SdkScalaType[NestedIOTaskOutput] ) { - /** Defines task behavior. This task takes a name as the input, wraps it in a - * welcome message, and outputs the message. - * - * @param input - * the name of the person to be greeted - * @return - * the welcome message - */ override def run(input: NestedIOTaskInput): NestedIOTaskOutput = NestedIOTaskOutput( input.name, input.generic ) } + +class NestedIOTaskNoop + extends SdkRunnableTask[ + NestedIOTaskOutput, + NestedIOTaskOutput + ]( + SdkScalaType[NestedIOTaskOutput], + SdkScalaType[NestedIOTaskOutput] + ) { + + override def run(input: NestedIOTaskOutput): NestedIOTaskOutput = input +} diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala index dfe996650..bd9268738 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala @@ -32,6 +32,7 @@ class NestedIOWorkflow builder: SdkScalaWorkflowBuilder, input: NestedIOTaskInput ): Unit = { - builder.apply(new NestedIOTask(), input) + val output = builder.apply(new NestedIOTask(), input) + builder.apply(new NestedIOTaskNoop(), output.getOutputs) } } diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index 720002000..2424c8237 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -49,7 +49,12 @@ import org.flyte.flytekitscala.SdkLiteralTypes.{ } // The constructor is reflectedly invoked so it cannot be an inner class -case class ScalarNested(foo: String, bar: String) +case class ScalarNested( + foo: String, + bar: Option[String], + nestedNested: Option[ScalarNestedNested] +) +case class ScalarNestedNested(foo: String, bar: Option[String]) class SdkScalaTypeTest { @@ -178,7 +183,15 @@ class SdkScalaTypeTest { Struct.of( Map( "foo" -> Struct.Value.ofStringValue("foo"), - "bar" -> Struct.Value.ofStringValue("bar") + "bar" -> Struct.Value.ofNullValue(), + "nestedNested" -> Struct.Value.ofStructValue( + Struct.of( + Map( + "foo" -> Struct.Value.ofStringValue("foo"), + "bar" -> Struct.Value.ofStringValue("bar") + ).asJava + ) + ) ).asJava ) ) @@ -196,7 +209,11 @@ class SdkScalaTypeTest { blob = SdkBindingDataFactory.of(blob), generic = SdkBindingDataFactory.of( SdkLiteralTypes.generics(), - ScalarNested("foo", "bar") + ScalarNested( + "foo", + None, + Some(ScalarNestedNested("foo", Some("bar"))) + ) ) ) @@ -218,7 +235,11 @@ class SdkScalaTypeTest { blob = SdkBindingDataFactory.of(blob), generic = SdkBindingDataFactory.of( SdkLiteralTypes.generics(), - ScalarNested("foo", "bar") + ScalarNested( + "foo", + Some("bar"), + Some(ScalarNestedNested("foo", Some("bar"))) + ) ) ) @@ -245,7 +266,15 @@ class SdkScalaTypeTest { Struct.of( Map( "foo" -> Struct.Value.ofStringValue("foo"), - "bar" -> Struct.Value.ofStringValue("bar") + "bar" -> Struct.Value.ofStringValue("bar"), + "nestedNested" -> Struct.Value.ofStructValue( + Struct.of( + Map( + "foo" -> Struct.Value.ofStringValue("foo"), + "bar" -> Struct.Value.ofStringValue("bar") + ).asJava + ) + ) ).asJava ) ) @@ -285,7 +314,11 @@ class SdkScalaTypeTest { blob = SdkBindingDataFactory.of(blob), generic = SdkBindingDataFactory.of( SdkLiteralTypes.generics(), - ScalarNested("foo", "bar") + ScalarNested( + "foo", + Some("bar"), + Some(ScalarNestedNested("foo", Some("bar"))) + ) ) ) @@ -301,7 +334,11 @@ class SdkScalaTypeTest { "blob" -> SdkBindingDataFactory.of(blob), "generic" -> SdkBindingDataFactory.of( SdkLiteralTypes.generics[ScalarNested](), - ScalarNested("foo", "bar") + ScalarNested( + "foo", + Some("bar"), + Some(ScalarNestedNested("foo", Some("bar"))) + ) ) ).asJava diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala index 6a36f3520..517ec24dc 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -297,41 +297,39 @@ object SdkLiteralTypes { ): S = { val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader) - def valueToParamValue(value: Any, param: Symbol): Any = { - def valueToParamValue0(value: Any, param: Symbol): Any = { - if (param.typeSignature =:= typeOf[Byte]) { - value.asInstanceOf[Double].toByte - } else if (param.typeSignature =:= typeOf[Short]) { - value.asInstanceOf[Double].toShort - } else if (param.typeSignature =:= typeOf[Int]) { - value.asInstanceOf[Double].toInt - } else if (param.typeSignature =:= typeOf[Long]) { - value.asInstanceOf[Double].toLong - } else if (param.typeSignature =:= typeOf[Float]) { - value.asInstanceOf[Double].toFloat - } else if (param.typeSignature <:< typeOf[Product]) { - val typeTag = createTypeTag(param.typeSignature) - val classTag = ClassTag( - typeTag.mirror.runtimeClass(param.typeSignature) - ) - mapToProduct(value.asInstanceOf[Map[String, Any]])( - typeTag, - classTag - ) + def valueToParamValue(value: Any, tpe: Type): Any = { + if (tpe =:= typeOf[Byte]) { + value.asInstanceOf[Double].toByte + } else if (tpe =:= typeOf[Short]) { + value.asInstanceOf[Double].toShort + } else if (tpe =:= typeOf[Int]) { + value.asInstanceOf[Double].toInt + } else if (tpe =:= typeOf[Long]) { + value.asInstanceOf[Double].toLong + } else if (tpe =:= typeOf[Float]) { + value.asInstanceOf[Double].toFloat + } else if (tpe <:< typeOf[Option[Any]]) { // this has to be before Product check because Option is a Product + if (value == None) { // None is used to represent Struct.Value.Kind.NULL_VALUE when converting struct to map + None } else { - value - } - } - - if (param.typeSignature <:< typeOf[Option[Any]]) { - Some( - valueToParamValue0( - value, - param.typeSignature.dealias.typeArgs.head.typeSymbol + Some( + valueToParamValue( + value, + tpe.dealias.typeArgs.head + ) ) + } + } else if (tpe <:< typeOf[Product]) { + val typeTag = createTypeTag(tpe) + val classTag = ClassTag( + typeTag.mirror.runtimeClass(tpe) + ) + mapToProduct(value.asInstanceOf[Map[String, Any]])( + typeTag, + classTag ) } else { - valueToParamValue0(value, param) + value } } @@ -371,7 +369,7 @@ object SdkLiteralTypes { s"Map is missing required parameter named $paramName" ) ) - valueToParamValue(value, param) + valueToParamValue(value, param.typeSignature.dealias) }) constructorMirror(constructorArgs: _*).asInstanceOf[S]