Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make plugin task first-class citizen #268

Merged
merged 3 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ message RuntimeMetadata {

//+optional It can be used to provide extra information about the runtime (e.g. python, golang... etc.).
string flavor = 3;

//+optional It can be used to provide extra information for the plugin.
PluginMetadata plugin_metadata = 4;
}

message PluginMetadata {
//+optional It can be used to decide use sync plugin or async plugin during runtime.
bool is_sync_plugin = 1;
}

// Task Metadata
Expand Down
22 changes: 22 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.api.v1;

/** A task that is handled by a Flyte backend plugin instead of run as a container. */
public interface PluginTask extends Task {
boolean isSyncPlugin();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.api.v1;

/** A registrar that creates {@link PluginTask} instances. */
public abstract class PluginTaskRegistrar implements Registrar<TaskIdentifier, PluginTask> {}
8 changes: 8 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
/**
* A Task structure that uniquely identifies a task in the system. Tasks are registered as a first
* step in the system.
*
* <p>FIXME: consider offering TaskMetadata instead of having everything in TaskTemplate, see
* https://github.com/flyteorg/flyte/blob/ea72bbd12578d64087221592554fb71c368f8057/flyteidl/protos/flyteidl/core/tasks.proto#L90
*/
@AutoValue
public abstract class TaskTemplate {
Expand Down Expand Up @@ -64,6 +67,9 @@ public abstract class TaskTemplate {
*/
public abstract boolean cacheSerializable();

/** Indicates whether to use sync plugin or async plugin to handle this task. */
public abstract boolean isSyncPlugin();

public abstract Builder toBuilder();

public static Builder builder() {
Expand All @@ -89,6 +95,8 @@ public abstract static class Builder {

public abstract Builder cacheSerializable(boolean cacheSerializable);

public abstract Builder isSyncPlugin(boolean isSyncPlugin);

public abstract TaskTemplate build();
}
}
115 changes: 115 additions & 0 deletions flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit;

import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.flyte.api.v1.PartialTaskIdentifier;

/** A task that is handled by a Flyte backend plugin instead of run as a container. */
public abstract class SdkPluginTask<InputT, OutputT> extends SdkTransform<InputT, OutputT> {

private final SdkType<InputT> inputType;
private final SdkType<OutputT> outputType;

/**
* Called by subclasses passing the {@link SdkType}s for inputs and outputs.
*
* @param inputType type for inputs.
* @param outputType type for outputs.
*/
public SdkPluginTask(SdkType<InputT> inputType, SdkType<OutputT> outputType) {
this.inputType = inputType;
this.outputType = outputType;
}

public abstract String getType();

@Override
public SdkType<InputT> getInputType() {
return inputType;
}

@Override
public SdkType<OutputT> getOutputType() {
return outputType;
}

/** Specifies custom data that can be read by the backend plugin. */
public SdkStruct getCustom() {
return SdkStruct.empty();
}

/**
* Number of retries. Retries will be consumed when the task fails with a recoverable error. The
* number of retries must be less than or equals to 10.
*
* @return number of retries
*/
public int getRetries() {
return 0;
}

/**
* Indicates whether the system should attempt to look up this task's output to avoid duplication
* of work.
*/
public boolean isCached() {
return false;
}

/** Indicates a logical version to apply to this task for the purpose of cache. */
public String getCacheVersion() {
return null;
}

/**
* Indicates whether the system should attempt to execute cached instances in serial to avoid
* duplicate work.
*/
public boolean isCacheSerializable() {
return false;
}

@Override
SdkNode<OutputT> apply(
SdkWorkflowBuilder builder,
String nodeId,
List<String> upstreamNodeIds,
@Nullable SdkNodeMetadata metadata,
Map<String, SdkBindingData<?>> inputs) {
PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(getName()).build();
List<CompilerError> errors =
Compiler.validateApply(nodeId, inputs, getInputType().getVariableMap());

if (!errors.isEmpty()) {
throw new CompilerException(errors);
}

return new SdkTaskNode<>(
builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType);
}

/**
* Signaling whether this task is supposed to be handled by a synchronous backend plugin,
* defaulting to false.
*/
public boolean isSyncPlugin() {
return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit;

import com.google.auto.service.AutoService;
import java.util.HashMap;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.flyte.api.v1.PluginTask;
import org.flyte.api.v1.PluginTaskRegistrar;
import org.flyte.api.v1.RetryStrategy;
import org.flyte.api.v1.Struct;
import org.flyte.api.v1.TaskIdentifier;
import org.flyte.api.v1.TypedInterface;

/**
* Default implementation of a {@link PluginTaskRegistrar} that discovers {@link SdkPluginTask}s
* implementation via {@link ServiceLoader} mechanism. Plugin tasks implementations must use
* {@code @AutoService(SdkPluginTask.class)} or manually add their fully qualifies name to the
* corresponding file.
*
* @see ServiceLoader
*/
@AutoService(PluginTaskRegistrar.class)
public class SdkPluginTaskRegistrar extends PluginTaskRegistrar {
private static final Logger LOG = Logger.getLogger(SdkPluginTaskRegistrar.class.getName());

static {
// enable all levels for the actual handler to pick up
LOG.setLevel(Level.ALL);
}

private static class PluginTaskImpl<InputT, OutputT> implements PluginTask {
private final SdkPluginTask<InputT, OutputT> sdkTask;

private PluginTaskImpl(SdkPluginTask<InputT, OutputT> sdkTask) {
this.sdkTask = sdkTask;
}

@Override
public String getType() {
return sdkTask.getType();
}

@Override
public Struct getCustom() {
return sdkTask.getCustom().struct();
}

@Override
public TypedInterface getInterface() {
return TypedInterface.builder()
.inputs(sdkTask.getInputType().getVariableMap())
.outputs(sdkTask.getOutputType().getVariableMap())
.build();
}

@Override
public RetryStrategy getRetries() {
return RetryStrategy.builder().retries(sdkTask.getRetries()).build();
}

@Override
public boolean isCached() {
return sdkTask.isCached();
}

@Override
public String getCacheVersion() {
return sdkTask.getCacheVersion();
}

@Override
public boolean isCacheSerializable() {
return sdkTask.isCacheSerializable();
}

@Override
public String getName() {
return sdkTask.getName();
}

@Override
public boolean isSyncPlugin() {
return sdkTask.isSyncPlugin();
}
}

/**
* Load {@link SdkPluginTask}s using {@link ServiceLoader}.
*
* @param env env vars in a map that would be used to pick up the project, domain and version for
* the discovered tasks.
* @param classLoader class loader to use when discovering the task using {@link
* ServiceLoader#load(Class, ClassLoader)}
* @return a map of {@link SdkPluginTask}s by its task identifier.
*/
@Override
@SuppressWarnings("rawtypes")
public Map<TaskIdentifier, PluginTask> load(Map<String, String> env, ClassLoader classLoader) {
ServiceLoader<SdkPluginTask> loader = ServiceLoader.load(SdkPluginTask.class, classLoader);

LOG.fine("Discovering SdkPluginTask");

Map<TaskIdentifier, PluginTask> tasks = new HashMap<>();
SdkConfig sdkConfig = SdkConfig.load(env);

for (SdkPluginTask<?, ?> sdkTask : loader) {
String name = sdkTask.getName();
TaskIdentifier taskId =
TaskIdentifier.builder()
.domain(sdkConfig.domain())
.project(sdkConfig.project())
.name(name)
.version(sdkConfig.version())
.build();
LOG.fine(String.format("Discovered [%s]", name));

PluginTask task = new PluginTaskImpl<>(sdkTask);
PluginTask previous = tasks.put(taskId, task);

if (previous != null) {
throw new IllegalArgumentException(
String.format("Discovered a duplicate task [%s] [%s] [%s]", name, task, previous));
}
}

return tasks;
}
}
Loading
Loading