Topic/fault tolerance #34

Open · wants to merge 7 commits into main

Changes from 1 commit
214 changes: 183 additions & 31 deletions ray_provider/decorators/ray_decorators.py
@@ -1,21 +1,35 @@
import copy
import logging
import functools
from typing import Callable, Optional
from typing import Dict, Optional, Callable, List, Any

from airflow.decorators.base import task_decorator_factory
from airflow.decorators.python import _PythonDecoratedOperator
from airflow.models.xcom_arg import XComArg
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.operators.python import PythonOperator
from airflow.utils.db import provide_session
from airflow.utils.session import provide_session
from airflow.exceptions import AirflowException
import ray
from airflow.operators.python import task
from ray_provider.hooks.ray_client import RayClientHook
from ray_provider.xcom.ray_backend import RayBackend, get_or_create_kv_store

from ray_provider.xcom.ray_backend import RayBackend, get_or_create_kv_store, KVStore


log = logging.getLogger(__name__)


def ray_wrapped(f, ray_conn_id="ray_default", eager=False):

@functools.wraps(f)
def wrapper(*args, **kwargs) -> "ray.ObjectRef":
log.info("[wrapper] Got executor.")

executor = get_or_create_kv_store(
identifier=RayBackend.store_identifier, allow_new=True
)

log.info(f"[wrapper] Launching task (with {args}, {kwargs}).")
ret_str = executor.execute(f, args=args, kwargs=kwargs, eager=eager)
log.info("[wrapper] Remote task finished")
@@ -31,42 +45,180 @@ def ray_task(
eager: bool = False,
):
"""Wraps a function to be executed on the Ray cluster.

The return values of the function will be cached on the Ray object store.
Downstream tasks must also be Ray tasks, as their dependencies will be
fetched from the object store. The RayBackend will need to be set up in your
Dockerfile to use this decorator.

Use as a task decorator: ::

from ray_provider.decorators import ray_task

def ray_example_dag():

@ray_task(ray_conn_id="ray_conn_id")
def sum_cols(df: pd.DataFrame) -> pd.DataFrame:
return pd.DataFrame(df.sum()).T

:param python_callable: Function to be invoked on the Ray cluster.
:type python_callable: Optional[Callable]
:param ray_conn_id: Airflow connection ID for the connection to Ray.
:type ray_conn_id: str
:param ray_worker_pool: The pool that controls the
number of parallel clients created to access the Ray cluster.
:type ray_worker_pool: Optional[str]
:param eager: Whether to run the function on the
coordinator process (on the Ray cluster) or to
send it to a remote task. You should normally
leave this set to False.
:type eager: Optional[bool]
"""

@provide_session
def on_retry_callback(context, session=None):
"""When a task is set to retry, store the output of its upstream tasks.
"""

# List upstream task ids
upstream_tasks = RayPythonOperator._upstream_tasks(
context.get('ti').task_id, context.get('dag'))

# Retrieve upstream object ids from xcom
upstream_objects = [RayPythonOperator._retrieve_obj_id_from_xcom(
task_id, context.get('dag').dag_id) for task_id in upstream_tasks]

# Retrieve the KV Actor
actor_ray_kv_store = KVStore("ray_kv_store").get_actor("ray_kv_store")

# Write to GCS
for dag_id, task_id, obj_ref in upstream_objects:
actor_ray_kv_store.gcs_dump.remote(dag_id, task_id, obj_ref)

@functools.wraps(python_callable)
def wrapper(f):

return task(
ray_wrapped(f, ray_conn_id, eager=eager),
pool=ray_worker_pool,
on_retry_callback=on_retry_callback
)

return wrapper


class RayPythonOperator(PythonOperator):

def __init__(self, *,
python_callable: Callable,
op_args: Optional[List] = None,
op_kwargs: Optional[Dict] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[List[str]] = None,
**kwargs) -> None:

Collaborator:
Oh, this is basically the implementation of the task?

Collaborator:
Can you leave a docstring here?

Contributor Author (@pgzmnk, Aug 10, 2021):
Previously, the ray decorator used the PythonOperator.

Assigning the recovered objects as Task attributes requires modifying the pre_execute method of the PythonOperator.

RayPythonOperator subclasses PythonOperator; it overrides pre_execute and adds the logic that enables attribute assignment.

  • I will add docstrings explaining the implementation.
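
A rough sketch of the pattern described above (the full implementation follows in the diff):

class RayPythonOperator(PythonOperator):
    def pre_execute(self, context, session=None):
        # On retry: recover upstream ObjectRefs from the Ray KV actor, push
        # them back into XCom, and re-render this task's XComArg op_args
        # and op_kwargs before execute() runs.
        ...

    def execute(self, context):
        # Fail fast if recovery did not succeed; otherwise defer to
        # PythonOperator.execute().
        ...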


# Store the task's XComArg arguments so they can be restored and re-rendered
# on retry. Use the constructor arguments here, since self.op_args and
# self.op_kwargs are only populated by PythonOperator.__init__ below.
if all(isinstance(arg, XComArg) for arg in (op_args or [])):
self.ray_xcomarg_op_args = op_args
self.ray_xcomarg_op_kwargs = op_kwargs

# Indicate whether upstream task arguments were retrieved
self.upstream_not_retrieved = False

super().__init__(python_callable=python_callable,
op_args=op_args,
op_kwargs=op_kwargs,
templates_dict=templates_dict,
templates_exts=templates_exts,
**kwargs
)

def execute(self, context: Dict):

# Fail task if object retrieval fails
if self.upstream_not_retrieved:
raise AirflowException('Failed to retrieve upstream object.')

return super(RayPythonOperator, self).execute(context)

@provide_session
def pre_execute(self, context, session=None):
ti = context.get('ti')
task = ti.task

if ti._try_number <= 1 or ti.state != 'up_for_retry':
return

# Retrieve the KV Actor
actor_ray_kv_store = KVStore("ray_kv_store").get_actor("ray_kv_store")

# List upstream task ids
upstream_tasks = self._upstream_tasks(task.task_id, task.dag)

# Retrieve upstream object ids from xcom
upstream_objects = [self._retrieve_obj_id_from_xcom(
task_id, task.dag.dag_id) for task_id in upstream_tasks]

# Retrieve object refs from Ray kv store
recovered_obj_refs = ray.get(
actor_ray_kv_store.recover_objects.remote(upstream_objects))

# Set recovered objects as current Task's XComArgs
for task_id, obj_ref in recovered_obj_refs.items():

# Flag if object retrieval failed
if obj_ref == -404:
self.upstream_not_retrieved = True

if 'ObjectRef' in str(obj_ref):
RayBackend.set(
key='return_value',
value=str(obj_ref),
execution_date=ti.execution_date,
task_id=task_id,
dag_id=task.dag.dag_id,
session=session
)

# Reassign XComArg objects
self.op_args = self.ray_xcomarg_op_args
self.op_kwargs = self.ray_xcomarg_op_kwargs

# Render XComArg object with newly assigned values
self.render_template_fields(context)

# Write to `rendered_task_instance_fields` table
RenderedTaskInstanceFields.write(
RenderedTaskInstanceFields(ti=ti, render_templates=False))
RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id)

@staticmethod
@provide_session
def _retrieve_obj_id_from_xcom(task_id, dag_id, session=None):
# To-do: incorporate execution id in filter
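# Illustrative return shapes (hypothetical ids): the XCom value is the string
# form of a Ray ObjectRef when the upstream task has already run, e.g.
#   ("demo", "transform_data", "ObjectRef(<hex id>)")
# and None when no matching XCom row exists:
#   ("demo", "transform_data", None)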

obj_ref_key = session.query(RayBackend).filter(
RayBackend.key == 'return_value',
RayBackend.task_id == task_id,
RayBackend.dag_id == dag_id) \
.order_by(RayBackend.timestamp.desc()).first()

return (dag_id, task_id, obj_ref_key.value if bool(obj_ref_key) else None)

@staticmethod
def _upstream_tasks(task_id, dag):
"""Recursively list the ids of all tasks upstream of ``task_id``."""

def _recurse_upstream_tasks(task_id, dag):
r = [task_id]
for child in dag.get_task(task_id)._upstream_task_ids:
r.extend(_recurse_upstream_tasks(child, dag))
return r

upstream_tasks = set(_recurse_upstream_tasks(task_id, dag))
upstream_tasks.remove(task_id)
return upstream_tasks
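
# Example: in the demo DAG added later in this PR, load_data1 and load_data2
# feed transform_data, which feeds divide_by_zero, so:
#
#   _upstream_tasks("divide_by_zero", dag)
#   -> {"load_data1", "load_data2", "transform_data"}
#
# The recursion collects every ancestor (set() collapses duplicates when a
# task is reachable through multiple paths) before the task itself is removed.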


class _RayDecoratedOperator(_PythonDecoratedOperator, RayPythonOperator):
pass


class _RayTaskDecorator:
def __call__(
self, python_callable: Optional[Callable] = None, multiple_outputs: Optional[bool] = None, **kwargs
):
"""
Python task decorator. Wraps a function into an Airflow operator.
Accepts kwargs for the underlying operator. This decorator can be reused in a single DAG.
:param python_callable: Function to decorate
:type python_callable: Optional[Callable]
:param multiple_outputs: If set, the function's return value will be
unrolled into multiple XCom values. Lists/tuples unroll to XCom values
with the index as the key; dicts unroll to XCom values with the dict keys as XCom keys.
Defaults to False.
:type multiple_outputs: bool
"""

return task_decorator_factory(
python_callable=python_callable,
multiple_outputs=multiple_outputs,
decorated_operator_class=_RayDecoratedOperator,
**kwargs,
)


task = _RayTaskDecorator()
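
A minimal sketch of how the task decorator exported above could be used; the DAG and task names are illustrative. With multiple_outputs=True, a dict return value is unrolled into one XCom entry per key:

from datetime import datetime

from airflow.decorators import dag
from ray_provider.decorators.ray_decorators import task


@dag(schedule_interval=None, start_date=datetime(2020, 1, 1))
def multiple_outputs_sketch():

    @task(multiple_outputs=True)
    def split_dataset():
        # Each dict key becomes its own XCom entry ("train" and "test").
        return {"train": [1, 2, 3], "test": [4, 5]}

    split_dataset()


sketch_dag = multiple_outputs_sketch()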
16 changes: 0 additions & 16 deletions ray_provider/example_dags/__init__.py

This file was deleted.

53 changes: 53 additions & 0 deletions ray_provider/example_dags/demo_fault_tolerance.py
@@ -0,0 +1,53 @@
from airflow.decorators import dag, task
from ray_provider.decorators.ray_decorators import ray_task
from ray_provider.xcom.ray_backend import RayBackend


from datetime import datetime


default_args = {
"owner": "airflow",
"on_success_callback": RayBackend.on_success_callback,
"on_failure_callback": RayBackend.on_failure_callback,
"retries": 1,
"retry_delay": 0,
}

task_args = {
"ray_conn_id": "ray_cluster_connection",
}


@dag(
default_args=default_args,
schedule_interval=None,
start_date=datetime(2020, 1, 1, 0, 0, 0),
tags=['demo']
)
def demo():

@ray_task(**task_args)
def load_data1():
return 1

@ray_task(**task_args)
def load_data2():
return 2

@ray_task(**task_args)
def transform_data(data, data2):
return data * data2 * 100

# Upstream outputs are saved to GCS when this task retries
@ray_task(**task_args)
def divide_by_zero(data):
return data/0

the_data1 = load_data1()
the_data2 = load_data2()
the_transformed_data = transform_data(the_data1, the_data2)
divide_by_zero_output = divide_by_zero(the_transformed_data)


demo_dag = demo()