Skip to content

Commit

Permalink
Support struct
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 8, 2023
1 parent 3280083 commit 9f36a57
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.SimpleType;
import org.flyte.flytekit.SdkLiteralType;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper;

/**
* Implementation of {@link org.flyte.flytekit.SdkLiteralType} for {@link
Expand Down Expand Up @@ -102,7 +103,8 @@ public Literal toLiteral(T value) {
var tree = OBJECT_MAPPER.valueToTree(value);

try {
return OBJECT_MAPPER.treeToValue(tree, Literal.class);
return Literal.ofScalar(
Scalar.ofGeneric(OBJECT_MAPPER.treeToValue(tree, StructWrapper.class).unwrap()));
} catch (IOException e) {
throw new UncheckedIOException("toLiteral failed for [" + clazz.getName() + "]: " + value, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.module.SimpleDeserializers;
import com.fasterxml.jackson.databind.module.SimpleSerializers;
import org.flyte.api.v1.Literal;
import org.flyte.flytekit.jackson.deserializers.LiteralStructDeserializer;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper;
import org.flyte.flytekit.jackson.serializers.StructSerializer;

class SdkLiteralTypeModule extends Module {
Expand All @@ -43,7 +43,7 @@ public void setupModule(SetupContext context) {
context.addSerializers(serializers);

var deserializers = new SimpleDeserializers();
deserializers.addDeserializer(Literal.class, new LiteralStructDeserializer());
deserializers.addDeserializer(StructWrapper.class, new StructDeserializer());
context.addDeserializers(deserializers);

// append with the lowest priority to use as fallback, if builtin annotations aren't present
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.deser.Deserializers;
import com.fasterxml.jackson.databind.module.SimpleSerializers;
import org.flyte.flytekit.jackson.deserializers.LiteralMapDeserializers;
import org.flyte.flytekit.jackson.deserializers.SdkBindingDataDeserializers;
import org.flyte.flytekit.jackson.serializers.BindingMapSerializers;
import org.flyte.flytekit.jackson.serializers.LiteralMapSerializers;
Expand Down Expand Up @@ -60,7 +59,6 @@ public void setupModule(SetupContext context) {
context.addSerializers(serializers);

context.addSerializers(new LiteralMapSerializers());
context.addDeserializers(new LiteralMapDeserializers());
context.addSerializers(new BindingMapSerializers());
context.addDeserializers(sdkbindingDeserializers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,12 @@ private SdkLiteralType<?> toLiteralType(
.dimensionality(annotation.dimensionality())
.build());
}
// TODO: Support structs
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()));
try {
return JacksonSdkLiteralType.of(type);
} catch (Exception e) {
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()), e);
}
}

private static boolean isPrimitiveAssignableFrom(Class<?> fromClass, Class<?> toClass) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.BeanProperty;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.ContextualDeserializer;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import java.io.IOException;
import java.time.Duration;
Expand All @@ -46,44 +50,57 @@
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Primitive;
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.Scalar.Kind;
import org.flyte.api.v1.SimpleType;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkBindingDataFactory;
import org.flyte.flytekit.SdkLiteralType;
import org.flyte.flytekit.SdkLiteralTypes;
import org.flyte.flytekit.jackson.JacksonSdkLiteralType;

class SdkBindingDataDeserializer extends StdDeserializer<SdkBindingData<?>> {
class SdkBindingDataDeserializer extends StdDeserializer<SdkBindingData<?>>
implements ContextualDeserializer {
private static final long serialVersionUID = 0L;

private final JavaType type;

public SdkBindingDataDeserializer() {
this(null);
}

private SdkBindingDataDeserializer(JavaType type) {
super(SdkBindingData.class);

this.type = type;
}

@Override
public SdkBindingData<?> deserialize(
JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException {
JsonNode tree = jsonParser.readValueAsTree();
return transform(tree);
return transform(tree, deserializationContext);
}

private SdkBindingData<?> transform(JsonNode tree) {
private SdkBindingData<?> transform(
JsonNode tree, DeserializationContext deserializationContext) {
Literal.Kind literalKind = Literal.Kind.valueOf(tree.get(LITERAL).asText());
switch (literalKind) {
case SCALAR:
return transformScalar(tree);
return transformScalar(tree, deserializationContext);
case COLLECTION:
return transformCollection(tree);
return transformCollection(tree, deserializationContext);

case MAP:
return transformMap(tree);
return transformMap(tree, deserializationContext);

default:
throw new UnsupportedOperationException(
String.format("Not supported literal type %s", literalKind.name()));
}
}

private static SdkBindingData<?> transformScalar(JsonNode tree) {
private SdkBindingData<?> transformScalar(
JsonNode tree, DeserializationContext deserializationContext) {
Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText());
switch (scalarKind) {
case PRIMITIVE:
Expand All @@ -109,6 +126,8 @@ private static SdkBindingData<?> transformScalar(JsonNode tree) {
return transformBlob(tree);

case GENERIC:
return transformGeneric(tree, deserializationContext, scalarKind);

default:
throw new UnsupportedOperationException(
"Type contains an unsupported scalar: " + scalarKind);
Expand All @@ -132,8 +151,28 @@ private static SdkBindingData<Blob> transformBlob(JsonNode tree) {
.build());
}

private SdkBindingData<Object> transformGeneric(
JsonNode tree, DeserializationContext deserializationContext, Kind scalarKind) {
JsonParser jsonParser = tree.get(VALUE).traverse();
try {
jsonParser.nextToken();
Object object =
deserializationContext
.findNonContextualValueDeserializer(type)
.deserialize(jsonParser, deserializationContext);
@SuppressWarnings("unchecked")
SdkLiteralType<Object> jacksonSdkLiteralType =
(SdkLiteralType<Object>) JacksonSdkLiteralType.of(type.getRawClass());
return SdkBindingData.literal(jacksonSdkLiteralType, object);
} catch (IOException e) {
throw new UnsupportedOperationException(
"Type contains an unsupported generic: " + scalarKind, e);
}
}

@SuppressWarnings("unchecked")
private <T> SdkBindingData<List<T>> transformCollection(JsonNode tree) {
private <T> SdkBindingData<List<T>> transformCollection(
JsonNode tree, DeserializationContext deserializationContext) {
SdkLiteralType<T> literalType = (SdkLiteralType<T>) readLiteralType(tree.get(TYPE));
Iterator<JsonNode> elements = tree.get(VALUE).elements();

Expand All @@ -143,7 +182,10 @@ private <T> SdkBindingData<List<T>> transformCollection(JsonNode tree) {
case COLLECTION_TYPE:
List<T> collection =
(List<T>)
streamOf(elements).map(this::transform).map(SdkBindingData::get).collect(toList());
streamOf(elements)
.map((JsonNode tree1) -> transform(tree1, deserializationContext))
.map(SdkBindingData::get)
.collect(toList());
return SdkBindingDataFactory.of(literalType, collection);

case SCHEMA_TYPE:
Expand All @@ -155,7 +197,8 @@ private <T> SdkBindingData<List<T>> transformCollection(JsonNode tree) {
}

@SuppressWarnings("unchecked")
private <T> SdkBindingData<Map<String, T>> transformMap(JsonNode tree) {
private <T> SdkBindingData<Map<String, T>> transformMap(
JsonNode tree, DeserializationContext deserializationContext) {
SdkLiteralType<T> literalType = (SdkLiteralType<T>) readLiteralType(tree.get(TYPE));
JsonNode valueNode = tree.get(VALUE);
List<Map.Entry<String, JsonNode>> entries =
Expand All @@ -168,7 +211,11 @@ private <T> SdkBindingData<Map<String, T>> transformMap(JsonNode tree) {
case COLLECTION_TYPE:
Map<String, T> bindingDataMap =
entries.stream()
.map(entry -> Map.entry(entry.getKey(), (T) transform(entry.getValue()).get()))
.map(
entry ->
Map.entry(
entry.getKey(),
(T) transform(entry.getValue(), deserializationContext).get()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
return SdkBindingDataFactory.of(literalType, bindingDataMap);

Expand Down Expand Up @@ -220,4 +267,9 @@ private <T> Stream<T> streamOf(Iterator<T> nodes) {
return StreamSupport.stream(
Spliterators.spliteratorUnknownSize(nodes, Spliterator.ORDERED), false);
}

@Override
public JsonDeserializer<?> createContextual(DeserializationContext ctxt, BeanProperty property) {
return new SdkBindingDataDeserializer(property.getType().containedType(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,35 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.Struct;
import org.flyte.api.v1.Struct.Value;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper;

public class LiteralStructDeserializer extends StdDeserializer<Literal> {
public class StructDeserializer extends StdDeserializer<StructWrapper> {
private static final long serialVersionUID = -6835948754469626304L;

public LiteralStructDeserializer() {
super(Literal.class);
// we cannot use Struct directly because it is an auto-value class so this deserializer will not
// be used by Jackson
public static class StructWrapper {

private final Struct struct;

public StructWrapper(Struct struct) {
this.struct = struct;
}

public Struct unwrap() {
return struct;
}
}

@Override
public Literal deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
public StructDeserializer() {
super(StructWrapper.class);
}

Struct generic = readValueAsStruct(p);
return Literal.ofScalar(Scalar.ofGeneric(generic));
@Override
public StructWrapper deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
return new StructWrapper(readValueAsStruct(p));
}

private static Struct readValueAsStruct(JsonParser p) throws IOException {
Expand All @@ -67,46 +79,46 @@ private static Struct readValueAsStruct(JsonParser p) throws IOException {
return Struct.of(unmodifiableMap(fields));
}

private static Struct.Value readValueAsStructValue(JsonParser p) throws IOException {
private static Value readValueAsStructValue(JsonParser p) throws IOException {
switch (p.currentToken()) {
case START_ARRAY:
p.nextToken();

List<Value> valuesList = new ArrayList<>();

while (p.currentToken() != JsonToken.END_ARRAY) {
Struct.Value value = readValueAsStructValue(p);
Value value = readValueAsStructValue(p);
p.nextToken();

valuesList.add(value);
}

return Struct.Value.ofListValue(unmodifiableList(valuesList));
return Value.ofListValue(unmodifiableList(valuesList));

case START_OBJECT:
Struct struct = readValueAsStruct(p);

return Struct.Value.ofStructValue(struct);
return Value.ofStructValue(struct);

case VALUE_STRING:
String stringValue = p.readValueAs(String.class);

return Struct.Value.ofStringValue(stringValue);
return Value.ofStringValue(stringValue);

case VALUE_NUMBER_FLOAT:
case VALUE_NUMBER_INT:
Double doubleValue = p.readValueAs(Double.class);

return Struct.Value.ofNumberValue(doubleValue);
return Value.ofNumberValue(doubleValue);

case VALUE_NULL:
return Struct.Value.ofNullValue();
return Value.ofNullValue();

case VALUE_FALSE:
return Struct.Value.ofBoolValue(false);
return Value.ofBoolValue(false);

case VALUE_TRUE:
return Struct.Value.ofBoolValue(true);
return Value.ofBoolValue(true);

case FIELD_NAME:
case NOT_AVAILABLE:
Expand Down
Loading

0 comments on commit 9f36a57

Please sign in to comment.