Skip to content

Commit

Permalink
Minor clean up of struct (#265)
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix authored Nov 17, 2023
1 parent bd369ca commit 8ec7a78
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 32 deletions.
4 changes: 4 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/BlobType.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
/** Defines type behavior for blob objects. */
@AutoValue
public abstract class BlobType {

public static final BlobType DEFAULT =
BlobType.builder().dimensionality(BlobDimensionality.SINGLE).format("").build();

public enum BlobDimensionality {
SINGLE,
MULTIPART
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.examples.AllInputsTask.AutoAllInputsOutput;
import org.flyte.examples.AllInputsTask.Nested;
import org.flyte.flytekit.SdkBindingData;
Expand Down Expand Up @@ -66,14 +65,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
SdkBindingDataFactory.of(
Blob.builder()
.uri("file://test/test.csv")
.metadata(
BlobMetadata.builder()
.type(
BlobType.builder()
.format("")
.dimensionality(BlobDimensionality.SINGLE)
.build())
.build())
.metadata(BlobMetadata.builder().type(BlobType.DEFAULT).build())
.build()),
SdkBindingDataFactory.of(
JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import java.util.Map;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.api.v1.Variable;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkLiteralType;
Expand Down Expand Up @@ -172,8 +171,7 @@ private SdkLiteralType<?> toLiteralType(
// fixme: create blob type from annotation, or rethink how we could offer the offloaded data
// feature
// https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype
return SdkLiteralTypes.blobs(
BlobType.builder().format("").dimensionality(BlobDimensionality.SINGLE).build());
return SdkLiteralTypes.blobs(BlobType.DEFAULT);
}
try {
return JacksonSdkLiteralType.of(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@

public class JacksonSdkTypeTest {

private static final BlobType BLOB_TYPE =
BlobType.builder().format("").dimensionality(BlobType.BlobDimensionality.SINGLE).build();
private static final BlobType BLOB_TYPE = BlobType.DEFAULT;

private static final Blob BLOB =
Blob.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class TestOfReturnsProperTypeProvider extends ArgumentsProvider {
Arguments.of(booleans(), of[Boolean]()),
Arguments.of(datetimes(), of[Instant]()),
Arguments.of(durations(), of[Duration]()),
Arguments.of(blobs(BlobType.DEFAULT), of[Blob]()),
Arguments.of(generics(), of[ScalarNested]()),
Arguments.of(collections(integers()), of[List[Long]]()),
Arguments.of(collections(floats()), of[List[Double]]()),
Arguments.of(collections(strings()), of[List[String]]()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.flyte.api.v1.{
Primitive,
Scalar,
SimpleType,
Struct,
Variable
}
import org.flyte.flytekit.{
Expand All @@ -40,17 +41,32 @@ import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows}
import org.junit.jupiter.api.Test
import org.flyte.examples.AllInputsTask.{AutoAllInputsInput, Nested}
import org.flyte.flytekit.jackson.JacksonSdkLiteralType
import org.flyte.flytekitscala.SdkLiteralTypes.{collections, maps, strings}
import org.flyte.flytekitscala.SdkLiteralTypes.{
blobs,
collections,
maps,
strings
}

// The constructor is reflectedly invoked so it cannot be an inner class
case class ScalarNested(foo: String, bar: String)

class SdkScalaTypeTest {

private val blob = Blob.builder
.metadata(BlobMetadata.builder.`type`(BlobType.DEFAULT).build)
.uri("file://test")
.build

case class ScalarInput(
string: SdkBindingData[String],
integer: SdkBindingData[Long],
float: SdkBindingData[Double],
boolean: SdkBindingData[Boolean],
datetime: SdkBindingData[Instant],
duration: SdkBindingData[Duration]
duration: SdkBindingData[Duration],
blob: SdkBindingData[Blob],
generic: SdkBindingData[ScalarNested]
)

case class CollectionInput(
Expand Down Expand Up @@ -116,7 +132,13 @@ class SdkScalaTypeTest {
"float" -> createVar(SimpleType.FLOAT),
"boolean" -> createVar(SimpleType.BOOLEAN),
"datetime" -> createVar(SimpleType.DATETIME),
"duration" -> createVar(SimpleType.DURATION)
"duration" -> createVar(SimpleType.DURATION),
"blob" -> Variable
.builder()
.literalType(LiteralType.ofBlobType(BlobType.DEFAULT))
.description("")
.build(),
"generic" -> createVar(SimpleType.STRUCT)
).asJava

val output = SdkScalaType[ScalarInput].getVariableMap
Expand Down Expand Up @@ -149,6 +171,17 @@ class SdkScalaTypeTest {
),
"duration" -> Literal.ofScalar(
Scalar.ofPrimitive(Primitive.ofDuration(Duration.ofSeconds(123, 456)))
),
"blob" -> Literal.ofScalar(Scalar.ofBlob(blob)),
"generic" -> Literal.ofScalar(
Scalar.ofGeneric(
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
).asJava
)
)
)
).asJava

Expand All @@ -159,7 +192,12 @@ class SdkScalaTypeTest {
float = SdkBindingDataFactory.of(42.0),
boolean = SdkBindingDataFactory.of(true),
datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)),
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456))
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)),
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
)
)

val output = SdkScalaType[ScalarInput].fromLiteralMap(input)
Expand All @@ -176,7 +214,12 @@ class SdkScalaTypeTest {
float = SdkBindingDataFactory.of(42.0),
boolean = SdkBindingDataFactory.of(true),
datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)),
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456))
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)),
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
)
)

val expected = Map(
Expand All @@ -195,6 +238,17 @@ class SdkScalaTypeTest {
),
"duration" -> Literal.ofScalar(
Scalar.ofPrimitive(Primitive.ofDuration(Duration.ofSeconds(123, 456)))
),
"blob" -> Literal.ofScalar(Scalar.ofBlob(blob)),
"generic" -> Literal.ofScalar(
Scalar.ofGeneric(
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
).asJava
)
)
)
).asJava

Expand Down Expand Up @@ -227,7 +281,12 @@ class SdkScalaTypeTest {
float = SdkBindingDataFactory.of(42.0),
boolean = SdkBindingDataFactory.of(true),
datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)),
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456))
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)),
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
)
)

val output = SdkScalaType[ScalarInput].toSdkBindingMap(input)
Expand All @@ -238,7 +297,12 @@ class SdkScalaTypeTest {
"float" -> SdkBindingDataFactory.of(42.0),
"boolean" -> SdkBindingDataFactory.of(true),
"datetime" -> SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)),
"duration" -> SdkBindingDataFactory.of(Duration.ofSeconds(123, 456))
"duration" -> SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)),
"blob" -> SdkBindingDataFactory.of(blob),
"generic" -> SdkBindingDataFactory.of(
SdkLiteralTypes.generics[ScalarNested](),
ScalarNested("foo", "bar")
)
).asJava

assertEquals(expected, output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ object SdkBindingDataConverters {
SdkScalaLiteralTypes.strings(),
jf.Function.identity()
)
case SimpleType.STRUCT => ??? // TODO not yet supported
case SimpleType.STRUCT =>
throw new UnsupportedOperationException(
"Converting Scala case class instance to Java object is not supported"
)
case SimpleType.BOOLEAN =>
TypeCastingResult(
SdkScalaLiteralTypes.booleans(),
Expand Down Expand Up @@ -239,7 +242,9 @@ object SdkBindingDataConverters {
jf.Function.identity()
)
case SimpleType.STRUCT =>
??? // TODO how to handle? do we support structs already?
throw new UnsupportedOperationException(
"Converting Java object to Scala case class instance is not supported"
)
case SimpleType.BOOLEAN =>
TypeCastingResult(
SdkJavaLiteralTypes.booleans(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ object SdkLiteralTypes {
datetimes().asInstanceOf[SdkLiteralType[T]]
case t if t =:= typeOf[Duration] =>
durations().asInstanceOf[SdkLiteralType[T]]
case t if t =:= typeOf[Blob] =>
blobs(BlobType.DEFAULT).asInstanceOf[SdkLiteralType[T]]
case t if t <:< typeOf[Product] && !(t =:= typeOf[Option[_]]) =>
generics().asInstanceOf[SdkLiteralType[T]]

case t if t =:= typeOf[List[Long]] =>
collections(integers()).asInstanceOf[SdkLiteralType[T]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ import org.flyte.flytekit.{

import scala.annotation.implicitNotFound
import scala.collection.JavaConverters._
import scala.reflect.{ClassTag, classTag}
import scala.reflect.runtime.universe.{TypeTag, typeOf}
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

/** Type class to map between Flyte `Variable` and `Literal` and Scala case
* classes.
Expand Down Expand Up @@ -245,13 +245,7 @@ object SdkScalaType {
// https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype
implicit def blobLiteralType: SdkScalaLiteralType[Blob] =
DelegateLiteralType(
SdkLiteralTypes.blobs(
BlobType
.builder()
.format("")
.dimensionality(BlobDimensionality.SINGLE)
.build()
)
SdkLiteralTypes.blobs(BlobType.DEFAULT)
)

// TODO we are forced to do this because SdkDataBinding.ofInteger returns a SdkBindingData<java.util.Long>
Expand Down

0 comments on commit 8ec7a78

Please sign in to comment.