Skip to content

Commit

Permalink
Handle Option correctly (#266)
Browse files Browse the repository at this point in the history
* Handle Option correctly

Signed-off-by: Hongxin Liang <[email protected]>

* Clearer comment

Signed-off-by: Hongxin Liang <[email protected]>

* IT

Signed-off-by: Hongxin Liang <[email protected]>

* Early return of None

Signed-off-by: Hongxin Liang <[email protected]>

* Enrich IT workflow

Signed-off-by: Hongxin Liang <[email protected]>

* Resource

Signed-off-by: Hongxin Liang <[email protected]>

---------

Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix authored Nov 23, 2023
1 parent 8ec7a78 commit 21c9e2e
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ org.flyte.examples.flytekitscala.GreetTask
org.flyte.examples.flytekitscala.AddQuestionTask
org.flyte.examples.flytekitscala.NoInputsTask
org.flyte.examples.flytekitscala.NestedIOTask
org.flyte.examples.flytekitscala.NestedIOTaskNoop
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry {
6.toDouble,
"hello",
List("1", "2"),
List(NestedNested(7.toDouble, NestedNestedNested("world"))),
List(NestedNested(7.toDouble, Some(NestedNestedNested("world")))),
Map("1" -> "1", "2" -> "2"),
Map("foo" -> NestedNested(7.toDouble, NestedNestedNested("world"))),
Map(
"foo" -> NestedNested(
7.toDouble,
Some(NestedNestedNested("world"))
)
),
Some(false),
None,
Some(List("3", "4")),
Some(Map("3" -> "3", "4" -> "4")),
NestedNested(7.toDouble, NestedNestedNested("world"))
NestedNested(7.toDouble, Some(NestedNestedNested("world")))
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.flyte.flytekitscala.{
}

case class NestedNestedNested(string: String)
case class NestedNested(double: Double, nested: NestedNestedNested)
case class NestedNested(double: Double, nested: Option[NestedNestedNested])
case class Nested(
boolean: Boolean,
byte: Byte,
Expand Down Expand Up @@ -57,9 +57,6 @@ case class NestedIOTaskOutput(
generic: SdkBindingData[Nested]
)

/** Example Flyte task that takes a name as the input and outputs a simple
* greeting message.
*/
class NestedIOTask
extends SdkRunnableTask[
NestedIOTaskInput,
Expand All @@ -69,17 +66,21 @@ class NestedIOTask
SdkScalaType[NestedIOTaskOutput]
) {

/** Defines task behavior. This task takes a name as the input, wraps it in a
* welcome message, and outputs the message.
*
* @param input
* the name of the person to be greeted
* @return
* the welcome message
*/
override def run(input: NestedIOTaskInput): NestedIOTaskOutput =
NestedIOTaskOutput(
input.name,
input.generic
)
}

class NestedIOTaskNoop
extends SdkRunnableTask[
NestedIOTaskOutput,
NestedIOTaskOutput
](
SdkScalaType[NestedIOTaskOutput],
SdkScalaType[NestedIOTaskOutput]
) {

override def run(input: NestedIOTaskOutput): NestedIOTaskOutput = input
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class NestedIOWorkflow
builder: SdkScalaWorkflowBuilder,
input: NestedIOTaskInput
): Unit = {
builder.apply(new NestedIOTask(), input)
val output = builder.apply(new NestedIOTask(), input)
builder.apply(new NestedIOTaskNoop(), output.getOutputs)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ import org.flyte.flytekitscala.SdkLiteralTypes.{
}

// The constructor is reflectedly invoked so it cannot be an inner class
case class ScalarNested(foo: String, bar: String)
case class ScalarNested(
foo: String,
bar: Option[String],
nestedNested: Option[ScalarNestedNested]
)
case class ScalarNestedNested(foo: String, bar: Option[String])

class SdkScalaTypeTest {

Expand Down Expand Up @@ -178,7 +183,15 @@ class SdkScalaTypeTest {
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
"bar" -> Struct.Value.ofNullValue(),
"nestedNested" -> Struct.Value.ofStructValue(
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
).asJava
)
)
).asJava
)
)
Expand All @@ -196,7 +209,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
None,
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -218,7 +235,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -245,7 +266,15 @@ class SdkScalaTypeTest {
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
"bar" -> Struct.Value.ofStringValue("bar"),
"nestedNested" -> Struct.Value.ofStructValue(
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
).asJava
)
)
).asJava
)
)
Expand Down Expand Up @@ -285,7 +314,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -301,7 +334,11 @@ class SdkScalaTypeTest {
"blob" -> SdkBindingDataFactory.of(blob),
"generic" -> SdkBindingDataFactory.of(
SdkLiteralTypes.generics[ScalarNested](),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
).asJava

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,41 +297,39 @@ object SdkLiteralTypes {
): S = {
val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader)

def valueToParamValue(value: Any, param: Symbol): Any = {
def valueToParamValue0(value: Any, param: Symbol): Any = {
if (param.typeSignature =:= typeOf[Byte]) {
value.asInstanceOf[Double].toByte
} else if (param.typeSignature =:= typeOf[Short]) {
value.asInstanceOf[Double].toShort
} else if (param.typeSignature =:= typeOf[Int]) {
value.asInstanceOf[Double].toInt
} else if (param.typeSignature =:= typeOf[Long]) {
value.asInstanceOf[Double].toLong
} else if (param.typeSignature =:= typeOf[Float]) {
value.asInstanceOf[Double].toFloat
} else if (param.typeSignature <:< typeOf[Product]) {
val typeTag = createTypeTag(param.typeSignature)
val classTag = ClassTag(
typeTag.mirror.runtimeClass(param.typeSignature)
)
mapToProduct(value.asInstanceOf[Map[String, Any]])(
typeTag,
classTag
)
def valueToParamValue(value: Any, tpe: Type): Any = {
if (tpe =:= typeOf[Byte]) {
value.asInstanceOf[Double].toByte
} else if (tpe =:= typeOf[Short]) {
value.asInstanceOf[Double].toShort
} else if (tpe =:= typeOf[Int]) {
value.asInstanceOf[Double].toInt
} else if (tpe =:= typeOf[Long]) {
value.asInstanceOf[Double].toLong
} else if (tpe =:= typeOf[Float]) {
value.asInstanceOf[Double].toFloat
} else if (tpe <:< typeOf[Option[Any]]) { // this has to be before Product check because Option is a Product
if (value == None) { // None is used to represent Struct.Value.Kind.NULL_VALUE when converting struct to map
None
} else {
value
}
}

if (param.typeSignature <:< typeOf[Option[Any]]) {
Some(
valueToParamValue0(
value,
param.typeSignature.dealias.typeArgs.head.typeSymbol
Some(
valueToParamValue(
value,
tpe.dealias.typeArgs.head
)
)
}
} else if (tpe <:< typeOf[Product]) {
val typeTag = createTypeTag(tpe)
val classTag = ClassTag(
typeTag.mirror.runtimeClass(tpe)
)
mapToProduct(value.asInstanceOf[Map[String, Any]])(
typeTag,
classTag
)
} else {
valueToParamValue0(value, param)
value
}
}

Expand Down Expand Up @@ -371,7 +369,7 @@ object SdkLiteralTypes {
s"Map is missing required parameter named $paramName"
)
)
valueToParamValue(value, param)
valueToParamValue(value, param.typeSignature.dealias)
})

constructorMirror(constructorArgs: _*).asInstanceOf[S]
Expand Down

0 comments on commit 21c9e2e

Please sign in to comment.