Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register DAG #692

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion src/tiledb/cloud/dag/dag.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Directed acyclic graphs as TileDB task graphs."""

import collections
import datetime
import itertools
Expand Down Expand Up @@ -25,6 +27,8 @@
Union,
)

import tiledb

from .. import array
from .. import client
from .. import rest_api
Expand All @@ -41,6 +45,7 @@
from ..rest_api import models
from ..sql import _execution as _sql_exec
from ..taskgraphs import _results as _tg_results
from ..taskgraphs import registration
from . import status as st
from . import visualization as viz
from .mode import Mode
Expand Down Expand Up @@ -618,6 +623,62 @@ def _to_log_metadata(self) -> rest_api.TaskGraphNodeMetadata:
),
)

def _registration_name(
    self, existing: Set[str], fallback_name: Optional[str] = None
) -> str:
    """Generates the unique name to be used when building the graph.

    If the node has a ``name``, that name is returned unchanged and
    ``existing`` is not modified. Otherwise a name is generated that is not
    contained in ``existing``; the generated name is added to ``existing``
    (the set is mutated in place) so subsequent Nodes don't reuse it.

    Generated names are tried in this order:

    1. the fallback name itself;
    2. ``"<fallback> (<suffix>)"`` where ``<suffix>`` is the last 2, 4, ...,
       12 characters of the node's ID rendered as a string;
    3. the same scheme applied to freshly generated random UUIDs, repeated
       until an unused name is found.

    :param existing: Existing set of names to avoid using. Updated in place
        when a name is generated.
    :param fallback_name: A string to use to generate a display name if the
        node is unnamed. Defaults to the name of the node's class.
    :return: Registration name for node.
    """

    fallback_name = fallback_name or type(self).__name__

    if self.name is not None:
        # A Node which already has a name does not need to have one generated.
        return self.name

    if fallback_name not in existing:
        # ``set.add`` returns None, so reserve the name first and then
        # return the string itself.
        existing.add(fallback_name)
        return fallback_name

    id_to_use = self.id
    while True:
        id_str = str(id_to_use)
        for chars in range(2, 13, 2):
            # Try to generate unique names with increasingly large slices
            # of the ID string.
            candidate = f"{fallback_name} ({id_str[-chars:]})"
            if candidate not in existing:
                existing.add(candidate)
                return candidate
        # At this point every alternate name we could generate from this ID,
        # from "name (xx)" to "name (xxxxxxxxxxxx)", has been taken.
        # Just throw in a new ID to start from.
        id_to_use = uuid.uuid4()

def to_registration_json(self, existing_names: Set[str]) -> Dict[str, Any]:
    """Builds the mapping that describes this node when registering a graph.

    The result carries two entries: ``client_node_id``, the node's ID as a
    string, and ``name``, the value returned by ``_registration_name`` for
    the given set of names. This is the form of the Node used to represent
    it in the ``RegisteredTaskGraph`` object, i.e. a
    ``RegisteredTaskGraphNode``.

    :param existing_names: The set of names that have already been used,
        so that we don't generate a duplicate node name. It is handed to
        ``_registration_name`` as-is.
    :return: Mapping of Node for registration.
    """

    payload: Dict[str, Any] = {}
    payload["client_node_id"] = str(self.id)
    payload["name"] = self._registration_name(existing_names)
    return payload


class DAG:
"""Low-level API for creating and managing direct acyclic graphs
Expand Down Expand Up @@ -1801,7 +1862,56 @@ def _update_status(self) -> None:
futures.execute_callbacks(nd, cbs)
with self._lifecycle_condition:
self._set_status(Status.FAILED)
raise # Bail out and fail loudly.
raise # Bail out and fail loudly

def _tdb_to_json(self, override_name: Optional[str] = None) -> Dict[str, Any]:
    """Serializes this DAG into the mapping used to register it.

    Each node contributes the mapping returned by its
    ``to_registration_json`` plus a ``depends_on`` list holding the string
    IDs of its parents. Nodes appear in the order produced by
    ``_topo_sort_nodes`` (presumably a topological order; confirm against
    that helper).

    :param override_name: Name to use for the graph instead of ``self.name``.
    :return: Mapping of DAG tree to json for submission to REST.
    """

    ordered = _topo_sort_nodes(self.nodes)

    # Names already claimed by named nodes; unnamed nodes add the names
    # generated for them to this set as they are converted.
    taken = set(self.nodes_by_name)

    payloads = []
    for node in ordered:
        payloads.append(node.to_registration_json(taken))

    for node, payload in zip(ordered, payloads):
        payload["depends_on"] = [str(parent.id) for parent in node.parents.values()]

    graph_name = override_name or self.name
    return {"name": graph_name, "nodes": payloads}

def register(
    self,
    override_name: Optional[str] = None,
) -> str:
    """Register DAG to TileDB.

    The graph is registered through ``registration.register`` under this
    DAG's namespace, and the resulting ``tiledb://<namespace>/<name>`` object
    is then opened for writing so its ``dataset_type`` metadata can be set
    to ``"registered_task_graph"``.

    :param override_name: Name to register DAG as. Uses self.name as default.
    :return: Registered task graph location, formatted as
        ``"<namespace>/<name>"``.
    :raises ValueError: If neither ``override_name`` nor ``self.name``
        provides a non-empty name.
    """

    tg_name = override_name or self.name

    if not tg_name:
        raise ValueError(
            "Must specify registration name to DAG.name or override_name."
        )

    registration.register(
        graph=self,
        name=tg_name,
        namespace=self.namespace,
    )

    # Tag the freshly registered object so it can be identified as a
    # registered task graph from its metadata.
    with tiledb.open(f"tiledb://{self.namespace}/{tg_name}", "w") as A:
        A.meta["dataset_type"] = "registered_task_graph"

    return f"{self.namespace}/{tg_name}"


def list_logs(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import uuid
from concurrent import futures
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch

import numpy as np
import pandas as pd
Expand All @@ -26,6 +28,7 @@
from tiledb.cloud._results import results
from tiledb.cloud._results import stored_params as sp
from tiledb.cloud._vendor import cloudpickle as tdbcp
from tiledb.cloud.client import default_user
from tiledb.cloud.dag import Mode
from tiledb.cloud.dag import dag as dag_dag
from tiledb.cloud.rest_api import models
Expand Down Expand Up @@ -1350,6 +1353,33 @@ def _b64(x: bytes) -> str:
return str(base64.b64encode(x), encoding="ascii")


@pytest.fixture
def dag_fixture():
    """Provide a DAG named ``dag-test-fixture`` in the default user's namespace."""

    username = default_user().username
    yield dag.DAG(name="dag-test-fixture", namespace=username)


@patch("tiledb.cloud.dag.dag.tiledb.open")
@patch("tiledb.cloud.dag.dag.registration.register")
def test_dag_register(
    mock_register: MagicMock, mock_open: MagicMock, dag_fixture: dag.DAG
) -> None:
    """Test DAG.register.

    Both the registration call and ``tiledb.open`` are mocked so no request
    leaves the test: ``DAG.register`` opens the registered object to write
    metadata right after registering it.
    """

    namespace = dag_fixture.namespace

    # verify DAG.name used to register; register returns "<namespace>/<name>"
    registered_name1 = dag_fixture.register()
    assert registered_name1 == f"{namespace}/{dag_fixture.name}"
    mock_register.assert_called_with(
        graph=dag_fixture, name=dag_fixture.name, namespace=namespace
    )

    # verify override name
    registered_name2 = dag_fixture.register(override_name="override-name")
    assert registered_name2 == f"{namespace}/override-name"
    mock_register.assert_called_with(
        graph=dag_fixture, name="override-name", namespace=namespace
    )
    mock_open.assert_called_with(f"tiledb://{namespace}/override-name", "w")

    # verify catch if no name set to DAG.name or override_name
    mock_register.reset_mock()
    dag_fixture.name = None
    with pytest.raises(ValueError):
        dag_fixture.register()
    # The name check happens before anything is registered.
    mock_register.assert_not_called()


# This is the base64 of the Arrow data returned by this query:
# set @a = 1;
# select @a a, ? param1; -- params: [1.1]
Expand Down
Loading