Skip to content

Commit

Permalink
Add support to Binary input/output type (#291)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Gomez Ferrer <[email protected]>
Co-authored-by: Andres Gomez Ferrer <[email protected]>
  • Loading branch information
andresgomezfrr and andresgomezfrr authored Apr 18, 2024
1 parent c8617fe commit 6c2a1a9
Show file tree
Hide file tree
Showing 19 changed files with 358 additions and 10 deletions.
47 changes: 47 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/Binary.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2020-2021 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.api.v1;

import com.google.auto.value.AutoValue;

/**
* A simple byte array with a tag to help different parts of the system communicate about what is in
* the byte array. It's strongly advisable that consumers of this type define a unique tag and
* validate the tag before parsing the data.
*/
@AutoValue
public abstract class Binary {
public static final String TAG_FIELD = "tag";
public static final String VALUE_FIELD = "value";

public abstract byte[] value();

public abstract String tag();

public static Builder builder() {
return new AutoValue_Binary.Builder();
}

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder value(byte[] value);

public abstract Builder tag(String tag);

public abstract Binary build();
}
}
9 changes: 8 additions & 1 deletion flytekit-api/src/main/java/org/flyte/api/v1/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ public abstract class Scalar {
public enum Kind {
PRIMITIVE,
GENERIC,
BLOB
BLOB,
BINARY
}

public abstract Kind kind();
Expand All @@ -36,6 +37,8 @@ public enum Kind {

public abstract Blob blob();

public abstract Binary binary();

// TODO: add the rest of the cases from src/main/proto/flyteidl/core/literals.proto

public static Scalar ofPrimitive(Primitive primitive) {
Expand All @@ -49,4 +52,8 @@ public static Scalar ofGeneric(Struct generic) {
public static Scalar ofBlob(Blob blob) {
return AutoOneOf_Scalar.blob(blob);
}

public static Scalar ofBinary(Binary binary) {
return AutoOneOf_Scalar.binary(binary);
}
}
3 changes: 2 additions & 1 deletion flytekit-api/src/main/java/org/flyte/api/v1/SimpleType.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ public enum SimpleType {
BOOLEAN,
DATETIME,
DURATION,
STRUCT
STRUCT,
BINARY
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.Variable;
Expand Down Expand Up @@ -172,6 +173,8 @@ private SdkLiteralType<?> toLiteralType(
// feature
// https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype
return SdkLiteralTypes.blobs(BlobType.DEFAULT);
} else if (Binary.class.isAssignableFrom(type)) {
return SdkLiteralTypes.binary();
}
try {
return JacksonSdkLiteralType.of(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
Expand All @@ -34,6 +35,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
Expand Down Expand Up @@ -167,11 +169,56 @@ private static Literal deserialize(JsonParser p, SimpleType simpleType) throws I
Struct generic = readValueAsStruct(p);

return Literal.ofScalar(Scalar.ofGeneric(generic));

case BINARY:
Binary binary = readValueAsBinary(p);

return Literal.ofScalar(Scalar.ofBinary(binary));
}

throw new AssertionError(String.format("Unexpected SimpleType: [%s]", simpleType));
}

private static Binary readValueAsBinary(JsonParser p) throws IOException {
verifyToken(p, JsonToken.START_OBJECT);
p.nextToken();

Binary.Builder binaryBuilder = Binary.builder();

while (p.currentToken() != JsonToken.END_OBJECT) {
verifyToken(p, JsonToken.FIELD_NAME);
String fieldName = p.currentName();
p.nextToken();

switch (fieldName) {
case Binary.TAG_FIELD:
binaryBuilder.tag(p.readValueAs(String.class));
break;
case Binary.VALUE_FIELD:
ByteArrayOutputStream value = new ByteArrayOutputStream();
p.readBinaryValue(value);
binaryBuilder.value(value.toByteArray());
break;
default:
throw new IllegalStateException("Unexpected field [" + fieldName + "]");
}

p.nextToken();
}

Binary binary = binaryBuilder.build();

if (binary.tag() == null) {
throw new IllegalStateException("Missing field [" + Binary.TAG_FIELD + "]");
}

if (binary.value().length == 0) {
throw new IllegalStateException("Missing field [" + Binary.VALUE_FIELD + "]");
}

return binary;
}

private static Struct readValueAsStruct(JsonParser p) throws IOException {
verifyToken(p, JsonToken.START_OBJECT);
p.nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -42,6 +43,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
Expand Down Expand Up @@ -128,12 +130,24 @@ private SdkBindingData<?> transformScalar(
case GENERIC:
return transformGeneric(tree, deserializationContext, scalarKind, type);

case BINARY:
return transformBinary(tree);

default:
throw new UnsupportedOperationException(
"Type contains an unsupported scalar: " + scalarKind);
}
}

private static SdkBindingData<Binary> transformBinary(JsonNode tree) {
JsonNode value = tree.get(VALUE);
String tag = value.get(Binary.TAG_FIELD).asText();
String base64Value = value.get(Binary.VALUE_FIELD).asText();

return SdkBindingDataFactory.of(
Binary.builder().tag(tag).value(Base64.getDecoder().decode(base64Value)).build());
}

private static SdkBindingData<Blob> transformBlob(JsonNode tree) {
JsonNode value = tree.get(VALUE);
String uri = value.get("uri").asText();
Expand Down Expand Up @@ -256,6 +270,8 @@ private SdkLiteralType<?> readLiteralType(JsonNode typeNode) {
return SdkLiteralTypes.durations();
case STRUCT:
return JacksonSdkLiteralType.of(type.getContentType().getRawClass());
case BINARY:
return SdkLiteralTypes.binary();
}
throw new UnsupportedOperationException(
"Type contains a collection/map of an supported literal type: " + kind);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2020-2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit.jackson.serializers;

import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import java.util.Base64;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Scalar.Kind;

public class BinarySerializer extends ScalarSerializer {
public BinarySerializer(
JsonGenerator gen,
String key,
Literal value,
SerializerProvider serializerProvider,
LiteralType literalType) {
super(gen, key, value, serializerProvider, literalType);
}

@Override
void serializeScalar() throws IOException {
gen.writeObject(Kind.BINARY);
gen.writeFieldName(VALUE);
gen.writeStartObject();
gen.writeFieldName(Binary.TAG_FIELD);
gen.writeString(value.scalar().binary().tag());
gen.writeFieldName(Binary.VALUE_FIELD);
gen.writeString(Base64.getEncoder().encodeToString(value.scalar().binary().value()));
gen.writeEndObject();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ private static ScalarSerializer createScalarSerializer(
return new GenericSerializer(gen, key, value, serializerProvider, literalType);
case BLOB:
return new BlobSerializer(gen, key, value, serializerProvider, literalType);
case BINARY:
return new BinarySerializer(gen, key, value, serializerProvider, literalType);
}
throw new AssertionError("Unexpected Literal.Kind: [" + value.scalar().kind() + "]");
}
Expand Down
Loading

0 comments on commit 6c2a1a9

Please sign in to comment.