Skip to content

Commit

Permalink
Support struct
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 11, 2023
1 parent b246564 commit 4de3436
Show file tree
Hide file tree
Showing 19 changed files with 216 additions and 186 deletions.
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
version=2.5.2
version=3.7.14
runner.dialect=scala212source3

Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,20 @@ public AllInputsTask() {
JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class));
}

@AutoValue
public abstract static class Nested {
public abstract String hello();

public abstract String world();

public static Nested create(String hello, String world) {
return new AutoValue_AllInputsTask_Nested(hello, world);
}
}

@AutoValue
public abstract static class AutoAllInputsInput {

public abstract SdkBindingData<Long> i();

public abstract SdkBindingData<Double> f();
Expand All @@ -51,6 +63,8 @@ public abstract static class AutoAllInputsInput {

public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

public abstract SdkBindingData<Map<String, String>> m();
Expand All @@ -67,12 +81,13 @@ public static AutoAllInputsInput create(
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsTask_AutoAllInputsInput(
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}

Expand All @@ -93,6 +108,8 @@ public abstract static class AutoAllInputsOutput {

public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

public abstract SdkBindingData<Map<String, String>> m();
Expand All @@ -109,12 +126,13 @@ public static AutoAllInputsOutput create(
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsTask_AutoAllInputsOutput(
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}

Expand All @@ -128,6 +146,7 @@ public AutoAllInputsOutput run(AutoAllInputsInput input) {
input.t(),
input.d(),
input.blob(),
input.generic(),
input.l(),
input.m(),
input.emptyList(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.examples.AllInputsTask.AutoAllInputsOutput;
import org.flyte.examples.AllInputsTask.Nested;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkBindingDataFactory;
import org.flyte.flytekit.SdkNode;
import org.flyte.flytekit.SdkTypes;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.flytekit.jackson.JacksonSdkLiteralType;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
Expand Down Expand Up @@ -73,6 +75,8 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
.build())
.build())
.build()),
SdkBindingDataFactory.of(
JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")),
SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")),
SdkBindingDataFactory.ofStringMap(Map.of("test", "test")),
SdkBindingDataFactory.ofStringCollection(Collections.emptyList()),
Expand All @@ -88,6 +92,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
outputs.t(),
outputs.d(),
outputs.blob(),
outputs.generic(),
outputs.l(),
outputs.m(),
outputs.emptyList(),
Expand All @@ -111,6 +116,8 @@ public abstract static class AllInputsWorkflowOutput {

public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

public abstract SdkBindingData<Map<String, String>> m();
Expand All @@ -127,12 +134,13 @@ public static AllInputsWorkflow.AllInputsWorkflowOutput create(
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsWorkflow_AllInputsWorkflowOutput(
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.SimpleType;
import org.flyte.flytekit.SdkLiteralType;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper;

/**
* Implementation of {@link org.flyte.flytekit.SdkLiteralType} for {@link
Expand Down Expand Up @@ -102,7 +103,8 @@ public Literal toLiteral(T value) {
var tree = OBJECT_MAPPER.valueToTree(value);

try {
return OBJECT_MAPPER.treeToValue(tree, Literal.class);
return Literal.ofScalar(
Scalar.ofGeneric(OBJECT_MAPPER.treeToValue(tree, StructWrapper.class).unwrap()));
} catch (IOException e) {
throw new UncheckedIOException("toLiteral failed for [" + clazz.getName() + "]: " + value, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.module.SimpleDeserializers;
import com.fasterxml.jackson.databind.module.SimpleSerializers;
import org.flyte.api.v1.Literal;
import org.flyte.flytekit.jackson.deserializers.LiteralStructDeserializer;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper;
import org.flyte.flytekit.jackson.serializers.StructSerializer;

class SdkLiteralTypeModule extends Module {
Expand All @@ -43,7 +43,7 @@ public void setupModule(SetupContext context) {
context.addSerializers(serializers);

var deserializers = new SimpleDeserializers();
deserializers.addDeserializer(Literal.class, new LiteralStructDeserializer());
deserializers.addDeserializer(StructWrapper.class, new StructDeserializer());
context.addDeserializers(deserializers);

// append with the lowest priority to use as fallback, if builtin annotations aren't present
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.deser.Deserializers;
import com.fasterxml.jackson.databind.module.SimpleSerializers;
import org.flyte.flytekit.jackson.deserializers.LiteralMapDeserializers;
import org.flyte.flytekit.jackson.deserializers.SdkBindingDataDeserializers;
import org.flyte.flytekit.jackson.serializers.BindingMapSerializers;
import org.flyte.flytekit.jackson.serializers.LiteralMapSerializers;
Expand Down Expand Up @@ -60,7 +59,6 @@ public void setupModule(SetupContext context) {
context.addSerializers(serializers);

context.addSerializers(new LiteralMapSerializers());
context.addDeserializers(new LiteralMapDeserializers());
context.addSerializers(new BindingMapSerializers());
context.addDeserializers(sdkbindingDeserializers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,12 @@ private SdkLiteralType<?> toLiteralType(
return SdkLiteralTypes.blobs(
BlobType.builder().format("").dimensionality(BlobDimensionality.SINGLE).build());
}
// TODO: Support structs
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()));
try {
return JacksonSdkLiteralType.of(type);
} catch (Exception e) {
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()), e);
}
}

private static boolean isPrimitiveAssignableFrom(Class<?> fromClass, Class<?> toClass) {
Expand Down

This file was deleted.

Loading

0 comments on commit 4de3436

Please sign in to comment.