Skip to content

Commit

Permalink
Bring back Blob support
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 7, 2023
1 parent c892ceb commit 9ab0009
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 158 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2020-2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit.jackson;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/** Applied to a blob property to annotate its type. */
@Target({ElementType.FIELD, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface BlobTypeDescription {
/**
* Describes the blob's format.
*
* @return format, not {@code null}
*/
String format();

/**
* Describes the blob's dimensionality.
*
* @return dimensionality, not {@code null}
*/
String dimensionality();
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.api.v1.Variable;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkLiteralType;
Expand Down Expand Up @@ -63,11 +66,7 @@ public void property(BeanProperty prop) {
String propName = prop.getName();
AnnotatedMember member = prop.getMember();
SdkLiteralType<?> literalType =
toLiteralType(
handledType,
/*rootLevel=*/ true,
propName,
member.getMember().getDeclaringClass().getName());
toLiteralType(handledType, /* rootLevel= */ true, propName, member);

String description = getDescription(member);

Expand Down Expand Up @@ -132,18 +131,17 @@ private String getDescription(AnnotatedMember member) {

@SuppressWarnings("AlreadyChecked")
private SdkLiteralType<?> toLiteralType(
JavaType javaType, boolean rootLevel, String propName, String declaringClassName) {
JavaType javaType, boolean rootLevel, String propName, AnnotatedMember member) {
Class<?> type = javaType.getRawClass();

if (SdkBindingData.class.isAssignableFrom(type)) {
return toLiteralType(
javaType.getBindings().getBoundType(0), false, propName, declaringClassName);
return toLiteralType(javaType.getBindings().getBoundType(0), false, propName, member);
} else if (rootLevel) {
throw new UnsupportedOperationException(
String.format(
"Field '%s' from class '%s' is declared as '%s' and it is not matching any of the supported types. "
+ "Please make sure your variable declared type is wrapped in 'SdkBindingData<>'.",
propName, declaringClassName, type));
propName, member.getMember().getDeclaringClass().getName(), type));
} else if (isPrimitiveAssignableFrom(Long.class, type)) {
return SdkLiteralTypes.integers();
} else if (isPrimitiveAssignableFrom(Double.class, type)) {
Expand All @@ -159,8 +157,7 @@ private SdkLiteralType<?> toLiteralType(
} else if (List.class.isAssignableFrom(type)) {
JavaType elementType = javaType.getBindings().getBoundType(0);

return SdkLiteralTypes.collections(
toLiteralType(elementType, false, propName, declaringClassName));
return SdkLiteralTypes.collections(toLiteralType(elementType, false, propName, member));
} else if (Map.class.isAssignableFrom(type)) {
JavaType keyType = javaType.getBindings().getBoundType(0);
JavaType valueType = javaType.getBindings().getBoundType(1);
Expand All @@ -170,9 +167,22 @@ private SdkLiteralType<?> toLiteralType(
"Only Map<String, ?> is supported, got [" + javaType.getGenericSignature() + "]");
}

return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, declaringClassName));
return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, member));
} else if (Blob.class.isAssignableFrom(type)) {
BlobTypeDescription annotation = member.getAnnotation(BlobTypeDescription.class);
if (annotation == null) {
throw new UnsupportedOperationException(
String.format(
"Field '%s' from class '%s' is declared as '%s' and it must be annotated",
propName, member.getMember().getDeclaringClass().getName(), type));
}
return SdkLiteralTypes.blobs(
BlobType.builder()
.format(annotation.format())
.dimensionality(BlobDimensionality.valueOf(annotation.dimensionality()))
.build());
}
// TODO: Support blobs and structs
// TODO: Support structs
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import java.io.IOException;
import java.io.Serializable;
import java.time.Duration;
import java.time.Instant;
import java.util.Iterator;
Expand All @@ -39,6 +38,10 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Primitive;
Expand Down Expand Up @@ -80,7 +83,7 @@ private SdkBindingData<?> transform(JsonNode tree) {
}
}

private static SdkBindingData<? extends Serializable> transformScalar(JsonNode tree) {
private static SdkBindingData<?> transformScalar(JsonNode tree) {
Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText());
switch (scalarKind) {
case PRIMITIVE:
Expand All @@ -102,8 +105,27 @@ private static SdkBindingData<? extends Serializable> transformScalar(JsonNode t
throw new UnsupportedOperationException(
"Type contains an unsupported primitive: " + primitiveKind);

case GENERIC:
case BLOB:
JsonNode value = tree.get(VALUE);
String uri = value.get("uri").asText();
JsonNode type = value.get("metadata").get("type");
String format = type.get("format").asText();
BlobDimensionality dimensionality =
BlobDimensionality.valueOf(type.get("dimensionality").asText());
return SdkBindingDataFactory.of(
Blob.builder()
.uri(uri)
.metadata(
BlobMetadata.builder()
.type(
BlobType.builder()
.format(format)
.dimensionality(dimensionality)
.build())
.build())
.build());
case GENERIC:

default:
throw new UnsupportedOperationException(
"Type contains an unsupported scalar: " + scalarKind);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
*/
package org.flyte.flytekit.jackson.serializers;

import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR;
import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.Scalar.Kind;

public class BlobSerializer extends ScalarSerializer {
public BlobSerializer(
Expand All @@ -38,8 +38,8 @@ public BlobSerializer(

@Override
void serializeScalar() throws IOException {
gen.writeFieldName(SCALAR);
gen.writeObject(Scalar.Kind.BLOB);
gen.writeObject(Kind.BLOB);
gen.writeFieldName(VALUE);
serializerProvider
.findValueSerializer(Blob.class)
.serialize(value.scalar().blob(), gen, serializerProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ public class SdkBindingDataSerializationProtocol {
public static final String TYPE = "type";
public static final String KIND = "kind";
public static final String PRIMITIVE = "primitive";
public static final String BLOB = "blob";
}
Loading

0 comments on commit 9ab0009

Please sign in to comment.