Skip to content

Commit

Permalink
Support struct
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 9, 2023
1 parent 8d360ba commit 7032727
Show file tree
Hide file tree
Showing 19 changed files with 212 additions and 190 deletions.
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
version=2.5.2
version=3.7.14
runner.dialect=scala212source3

Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,20 @@ public AllInputsTask() {
JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class));
}

@AutoValue
public abstract static class Nested {
public abstract String hello();

public abstract String world();

public static Nested create(String hello, String world) {
return new AutoValue_AllInputsTask_Nested(hello, world);
}
}

@AutoValue
public abstract static class AutoAllInputsInput {

public abstract SdkBindingData<Long> i();

public abstract SdkBindingData<Double> f();
Expand All @@ -54,6 +66,8 @@ public abstract static class AutoAllInputsInput {
@BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART)
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

public abstract SdkBindingData<Map<String, String>> m();
Expand All @@ -70,12 +84,13 @@ public static AutoAllInputsInput create(
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsTask_AutoAllInputsInput(
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}

Expand All @@ -97,6 +112,8 @@ public abstract static class AutoAllInputsOutput {
@BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART)
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

public abstract SdkBindingData<Map<String, String>> m();
Expand All @@ -113,12 +130,13 @@ public static AutoAllInputsOutput create(
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsTask_AutoAllInputsOutput(
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}

Expand All @@ -132,6 +150,7 @@ public AutoAllInputsOutput run(AutoAllInputsInput input) {
input.t(),
input.d(),
input.blob(),
input.generic(),
input.l(),
input.m(),
input.emptyList(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.examples.AllInputsTask.AutoAllInputsOutput;
import org.flyte.examples.AllInputsTask.Nested;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkBindingDataFactory;
import org.flyte.flytekit.SdkNode;
import org.flyte.flytekit.SdkTypes;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.flytekit.jackson.BlobTypeDescription;
import org.flyte.flytekit.jackson.JacksonSdkLiteralType;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
Expand Down Expand Up @@ -74,6 +76,8 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
.build())
.build())
.build()),
SdkBindingDataFactory.of(
JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")),
SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")),
SdkBindingDataFactory.ofStringMap(Map.of("test", "test")),
SdkBindingDataFactory.ofStringCollection(Collections.emptyList()),
Expand All @@ -89,6 +93,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
outputs.t(),
outputs.d(),
outputs.blob(),
outputs.generic(),
outputs.l(),
outputs.m(),
outputs.emptyList(),
Expand All @@ -113,6 +118,8 @@ public abstract static class AllInputsWorkflowOutput {
@BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART)
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

public abstract SdkBindingData<Map<String, String>> m();
Expand All @@ -129,12 +136,13 @@ public static AllInputsWorkflow.AllInputsWorkflowOutput create(
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsWorkflow_AllInputsWorkflowOutput(
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.SimpleType;
import org.flyte.flytekit.SdkLiteralType;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper;

/**
* Implementation of {@link org.flyte.flytekit.SdkLiteralType} for {@link
Expand Down Expand Up @@ -102,7 +103,8 @@ public Literal toLiteral(T value) {
var tree = OBJECT_MAPPER.valueToTree(value);

try {
return OBJECT_MAPPER.treeToValue(tree, Literal.class);
return Literal.ofScalar(
Scalar.ofGeneric(OBJECT_MAPPER.treeToValue(tree, StructWrapper.class).unwrap()));
} catch (IOException e) {
throw new UncheckedIOException("toLiteral failed for [" + clazz.getName() + "]: " + value, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.module.SimpleDeserializers;
import com.fasterxml.jackson.databind.module.SimpleSerializers;
import org.flyte.api.v1.Literal;
import org.flyte.flytekit.jackson.deserializers.LiteralStructDeserializer;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper;
import org.flyte.flytekit.jackson.serializers.StructSerializer;

class SdkLiteralTypeModule extends Module {
Expand All @@ -43,7 +43,7 @@ public void setupModule(SetupContext context) {
context.addSerializers(serializers);

var deserializers = new SimpleDeserializers();
deserializers.addDeserializer(Literal.class, new LiteralStructDeserializer());
deserializers.addDeserializer(StructWrapper.class, new StructDeserializer());
context.addDeserializers(deserializers);

// append with the lowest priority to use as fallback, if builtin annotations aren't present
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.deser.Deserializers;
import com.fasterxml.jackson.databind.module.SimpleSerializers;
import org.flyte.flytekit.jackson.deserializers.LiteralMapDeserializers;
import org.flyte.flytekit.jackson.deserializers.SdkBindingDataDeserializers;
import org.flyte.flytekit.jackson.serializers.BindingMapSerializers;
import org.flyte.flytekit.jackson.serializers.LiteralMapSerializers;
Expand Down Expand Up @@ -60,7 +59,6 @@ public void setupModule(SetupContext context) {
context.addSerializers(serializers);

context.addSerializers(new LiteralMapSerializers());
context.addDeserializers(new LiteralMapDeserializers());
context.addSerializers(new BindingMapSerializers());
context.addDeserializers(sdkbindingDeserializers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,12 @@ private SdkLiteralType<?> toLiteralType(
.dimensionality(annotation.dimensionality())
.build());
}
// TODO: Support structs
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()));
try {
return JacksonSdkLiteralType.of(type);
} catch (Exception e) {
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()), e);
}
}

private static boolean isPrimitiveAssignableFrom(Class<?> fromClass, Class<?> toClass) {
Expand Down

This file was deleted.

Loading

0 comments on commit 7032727

Please sign in to comment.