Airflow agent (#1725)
---------

Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
pingsutw and wild-endeavor authored Oct 11, 2023
1 parent 048aa10 commit 54e68e0
Showing 14 changed files with 1,410 additions and 39 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
@@ -94,6 +94,7 @@ jobs:
        python-version: ["3.8", "3.11"]
        plugin-names:
          # Please maintain an alphabetical order in the following list
+          - flytekit-airflow
          - flytekit-aws-athena
          - flytekit-aws-batch
          - flytekit-aws-sagemaker
4 changes: 3 additions & 1 deletion Dockerfile
@@ -2,7 +2,7 @@ ARG PYTHON_VERSION
FROM python:${PYTHON_VERSION}-slim-buster

MAINTAINER Flyte Team <[email protected]>
-LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit
+LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

WORKDIR /root
ENV PYTHONPATH /root
@@ -20,6 +20,8 @@ RUN pip install -U flytekit==$VERSION \

RUN useradd -u 1000 flytekit
RUN chown flytekit: /root
+# Some packages create config files under /home by default, so make sure it is writable
+RUN chown flytekit: /home
USER flytekit

ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE"
5 changes: 4 additions & 1 deletion Dockerfile.agent
@@ -5,6 +5,9 @@ LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

ARG VERSION
RUN pip install prometheus-client
-RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION
+
+# Airflow plugin's dependencies
+RUN pip install apache-airflow
+RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION flytekitplugins-airflow==$VERSION

CMD pyflyte serve --port 8000
30 changes: 13 additions & 17 deletions flytekit/extend/backend/base_agent.py
@@ -20,14 +20,14 @@
    State,
)
from flyteidl.core.tasks_pb2 import TaskTemplate
-from rich.progress import Progress

import flytekit
from flytekit import FlyteContext, logger
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.system import FlyteAgentNotFound
+from flytekit.exceptions.user import FlyteUserException
from flytekit.models.literals import LiteralMap


@@ -176,7 +176,7 @@ def execute(self, **kwargs) -> typing.Any:
        res = asyncio.run(self._get(resource_meta=res.resource_meta))

        if res.resource.state != SUCCEEDED:
-            raise Exception(f"Failed to run the task {self._entity.name}")
+            raise FlyteUserException(f"Failed to run the task {self._entity.name}")

        return LiteralMap.from_flyte_idl(res.resource.outputs)

@@ -205,21 +205,17 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse:
        state = RUNNING
        grpc_ctx = _get_grpc_context()

-        progress = Progress(transient=True)
-        task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None)
-        with progress:
-            while not is_terminal_state(state):
-                progress.start_task(task)
-                time.sleep(1)
-                if self._agent.asynchronous:
-                    res = await self._agent.async_get(grpc_ctx, resource_meta)
-                    if self._is_canceled:
-                        await self._is_canceled
-                        sys.exit(1)
-                else:
-                    res = self._agent.get(grpc_ctx, resource_meta)
-                state = res.resource.state
-                logger.info(f"Task state: {state}")
+        while not is_terminal_state(state):
+            time.sleep(1)
+            if self._agent.asynchronous:
+                res = await self._agent.async_get(grpc_ctx, resource_meta)
+                if self._is_canceled:
+                    await self._is_canceled
+                    sys.exit(1)
+            else:
+                res = self._agent.get(grpc_ctx, resource_meta)
+            state = res.resource.state
+            logger.info(f"Task state: {state}")
        return res

    def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any:
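The rewritten `_get` drops the `rich` progress bar in favor of a plain poll-and-sleep loop. For illustration, a minimal standalone sketch of the same pattern — the agent class and the state strings here are stand-ins, not flytekit APIs:

```python
import time

# Terminal states mirrored informally from the agent protocol; illustrative only.
TERMINAL_STATES = {"SUCCEEDED", "PERMANENT_FAILURE", "RETRYABLE_FAILURE"}


class StubAgent:
    """Toy agent whose job reaches a terminal state after three polls."""

    def __init__(self) -> None:
        self._polls = 0

    def get(self, resource_meta: bytes) -> str:
        self._polls += 1
        return "SUCCEEDED" if self._polls >= 3 else "RUNNING"


def poll_until_terminal(agent: StubAgent, resource_meta: bytes, interval: float = 1.0) -> str:
    # Same shape as the rewritten _get: sleep, query the agent, log the state,
    # and stop once the state is terminal.
    state = "RUNNING"
    while state not in TERMINAL_STATES:
        time.sleep(interval)
        state = agent.get(resource_meta)
        print(f"Task state: {state}")
    return state


if __name__ == "__main__":
    assert poll_until_terminal(StubAgent(), b"meta", interval=0.01) == "SUCCEEDED"
```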
50 changes: 30 additions & 20 deletions flytekit/types/pickle/pickle.py
@@ -4,7 +4,7 @@

import cloudpickle

-from flytekit.core.context_manager import FlyteContext
+from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
@@ -50,6 +50,33 @@ def python_type(cls) -> typing.Type:

        return _SpecificFormatClass

    @classmethod
    def to_pickle(cls, python_val: typing.Any) -> str:
        ctx = FlyteContextManager.current_context()
        local_dir = ctx.file_access.get_random_local_directory()
        os.makedirs(local_dir, exist_ok=True)
        local_path = ctx.file_access.get_random_local_path()
        uri = os.path.join(local_dir, local_path)
        with open(uri, "w+b") as outfile:
            cloudpickle.dump(python_val, outfile)

        remote_path = ctx.file_access.get_random_remote_path(uri)
        ctx.file_access.put_data(uri, remote_path, is_multipart=False)
        return remote_path

    @classmethod
    def from_pickle(cls, uri: str) -> typing.Any:
        ctx = FlyteContextManager.current_context()
        # Deserialize the pickle and return the data it contains, downloading
        # the pickle file to the local filesystem first if the URI is remote.
        if ctx.file_access.is_remote(uri):
            local_path = ctx.file_access.get_random_local_path()
            ctx.file_access.get_data(uri, local_path, False)
            uri = local_path
        with open(uri, "rb") as infile:
            data = cloudpickle.load(infile)
        return data


class FlytePickleTransformer(TypeTransformer[FlytePickle]):
    PYTHON_PICKLE_FORMAT = "PythonPickle"
@@ -63,15 +90,7 @@ def assert_type(self, t: Type[T], v: T):

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        uri = lv.scalar.blob.uri
-        # Deserialize the pickle, and return data in the pickle,
-        # and download pickle file to local first if file is not in the local file systems.
-        if ctx.file_access.is_remote(uri):
-            local_path = ctx.file_access.get_random_local_path()
-            ctx.file_access.get_data(uri, local_path, False)
-            uri = local_path
-        with open(uri, "rb") as infile:
-            data = cloudpickle.load(infile)
-        return data
+        return FlytePickle.from_pickle(uri)

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        if python_val is None:
@@ -81,16 +100,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
                format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )
-        # Dump the task output into pickle
-        local_dir = ctx.file_access.get_random_local_directory()
-        os.makedirs(local_dir, exist_ok=True)
-        local_path = ctx.file_access.get_random_local_path()
-        uri = os.path.join(local_dir, local_path)
-        with open(uri, "w+b") as outfile:
-            cloudpickle.dump(python_val, outfile)
-
-        remote_path = ctx.file_access.get_random_remote_path(uri)
-        ctx.file_access.put_data(uri, remote_path, is_multipart=False)
+        remote_path = FlytePickle.to_pickle(python_val)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]:
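With this refactor, the pickle upload/download logic lives on `FlytePickle` itself, so it can be reused outside the transformer. A minimal round-trip sketch, assuming a Flyte context with working blob storage; the payload is illustrative:

```python
from flytekit.types.pickle.pickle import FlytePickle

# Any cloudpickle-serializable value works; this dict is just an example.
payload = {"model": "demo", "weights": [0.1, 0.2, 0.3]}

# to_pickle dumps the value to a random local path, uploads it to the
# configured blob store, and returns the remote URI.
uri = FlytePickle.to_pickle(payload)

# from_pickle downloads the blob first if the URI is remote, then unpickles.
restored = FlytePickle.from_pickle(uri)
assert restored == payload
```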
33 changes: 33 additions & 0 deletions plugins/flytekit-airflow/README.md
@@ -0,0 +1,33 @@
# Flytekit Airflow Plugin
The Airflow plugin allows you to run Airflow tasks inside Flyte workflows seamlessly, without changing any code.

- Compile Airflow tasks to Flyte tasks
- Use Airflow sensors/operators in Flyte workflows
- Run Airflow tasks locally without standing up an Airflow cluster

## Example
```python
from airflow.sensors.filesystem import FileSensor
from flytekit import task, workflow

@task()
def t1():
    print("flyte")


@workflow
def wf():
    sensor = FileSensor(task_id="id", filepath="/tmp/1234")
    sensor >> t1()


if __name__ == '__main__':
    wf()
```


To install the plugin, run the following command:

```bash
pip install flytekitplugins-airflow
```
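
Operators are written the same way as sensors. Here is a sketch with `BashOperator`, assuming it is intercepted like any other Airflow operator (only `FileSensor` appears in this commit's examples):

```python
from airflow.operators.bash import BashOperator
from flytekit import workflow


@workflow
def bash_wf():
    # Compiled into a Flyte task at registration time; the agent runs the
    # operator's execute() rather than an Airflow scheduler.
    BashOperator(task_id="hello", bash_command="echo hello")
```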
16 changes: 16 additions & 0 deletions plugins/flytekit-airflow/flytekitplugins/airflow/__init__.py
@@ -0,0 +1,16 @@
"""
.. currentmodule:: flytekitplugins.airflow
This package contains things that are useful when extending Flytekit.
.. autosummary::
:template: custom.rst
:toctree: generated/
AirflowConfig
AirflowTask
AirflowAgent
"""

from .agent import AirflowAgent
from .task import AirflowConfig, AirflowTask
109 changes: 109 additions & 0 deletions plugins/flytekit-airflow/flytekitplugins/airflow/agent.py
@@ -0,0 +1,109 @@
import importlib
from dataclasses import dataclass
from typing import Optional

import cloudpickle
import grpc
import jsonpickle
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocDeleteClusterOperator,
    DataprocJobBaseOperator,
    JobStatus,
)
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.context import Context
from flyteidl.admin.agent_pb2 import (
    PERMANENT_FAILURE,
    RUNNING,
    SUCCEEDED,
    CreateTaskResponse,
    DeleteTaskResponse,
    GetTaskResponse,
    Resource,
)
from flytekitplugins.airflow.task import AirflowConfig
from google.cloud.exceptions import NotFound

from flytekit import FlyteContext, FlyteContextManager, logger
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class ResourceMetadata:
    job_id: str
    airflow_config: AirflowConfig


def _get_airflow_task(ctx: FlyteContext, airflow_config: AirflowConfig):
    task_module = importlib.import_module(name=airflow_config.task_module)
    task_def = getattr(task_module, airflow_config.task_name)
    task_config = airflow_config.task_config

    # Set the GET_ORIGINAL_TASK attribute to True so that task_def will return the original
    # airflow task instead of the Flyte task.
    ctx.user_space_params.builder().add_attr("GET_ORIGINAL_TASK", True).build()
    if issubclass(task_def, DataprocJobBaseOperator):
        return task_def(**task_config, asynchronous=True)
    return task_def(**task_config)


class AirflowAgent(AgentBase):
    def __init__(self):
        super().__init__(task_type="airflow", asynchronous=False)

    def create(
        self,
        context: grpc.ServicerContext,
        output_prefix: str,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
    ) -> CreateTaskResponse:
        airflow_config = jsonpickle.decode(task_template.custom.get("task_config_pkl"))
        resource_meta = ResourceMetadata(job_id="", airflow_config=airflow_config)

        ctx = FlyteContextManager.current_context()
        airflow_task = _get_airflow_task(ctx, airflow_config)
        if isinstance(airflow_task, DataprocJobBaseOperator):
            airflow_task.execute(context=Context())
            resource_meta.job_id = ctx.user_space_params.xcom_data["value"]["resource"]

        return CreateTaskResponse(resource_meta=cloudpickle.dumps(resource_meta))

    def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
        meta = cloudpickle.loads(resource_meta)
        airflow_config = meta.airflow_config
        job_id = meta.job_id
        task = _get_airflow_task(FlyteContextManager.current_context(), meta.airflow_config)
        cur_state = RUNNING

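        # Branch on the operator type: sensors are polled via poke(); Dataproc
        # jobs report state through the job hook; cluster deletion treats
        # NotFound as already-done; any other operator runs synchronously.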
        if issubclass(type(task), BaseSensorOperator):
            if task.poke(context=Context()):
                cur_state = SUCCEEDED
        elif issubclass(type(task), DataprocJobBaseOperator):
            job = task.hook.get_job(
                job_id=job_id,
                region=airflow_config.task_config["region"],
                project_id=airflow_config.task_config["project_id"],
            )
            if job.status.state == JobStatus.State.DONE:
                cur_state = SUCCEEDED
            elif job.status.state in (JobStatus.State.ERROR, JobStatus.State.CANCELLED):
                cur_state = PERMANENT_FAILURE
        elif isinstance(task, DataprocDeleteClusterOperator):
            try:
                task.execute(context=Context())
            except NotFound:
                logger.info("Cluster already deleted.")
            cur_state = SUCCEEDED
        else:
            task.execute(context=Context())
            cur_state = SUCCEEDED
        return GetTaskResponse(resource=Resource(state=cur_state, outputs=None))

    def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
        return DeleteTaskResponse()


AgentRegistry.register(AirflowAgent())
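
For reference, `create` above decodes a jsonpickle payload that the task side stores under the `task_config_pkl` key of the task template's custom field. A sketch of that round trip; the keyword arguments to `AirflowConfig` are inferred from the attributes the agent reads (`task_module`, `task_name`, `task_config`), not confirmed against `task.py`:

```python
import jsonpickle

from flytekitplugins.airflow.task import AirflowConfig

# Field names mirror what the agent reads off the config object.
config = AirflowConfig(
    task_module="airflow.sensors.filesystem",
    task_name="FileSensor",
    task_config={"task_id": "id", "filepath": "/tmp/1234"},
)

# What the task template's custom field would carry...
custom = {"task_config_pkl": jsonpickle.encode(config)}

# ...and what AirflowAgent.create() recovers from it.
decoded = jsonpickle.decode(custom["task_config_pkl"])
assert decoded.task_name == "FileSensor"
```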