Skip to content

Commit

Permalink
typ{o,e} fixes (#680)
Browse files Browse the repository at this point in the history
* docstr/comment typos

* uDAG arg validation: raise `ValueError`s instead of `TileDBCloudError`s

* type nit
  • Loading branch information
ryan-williams authored Nov 14, 2024
1 parent be74861 commit e904d0b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 34 deletions.
52 changes: 23 additions & 29 deletions src/tiledb/cloud/dag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from .. import array
from .. import client
from .. import rest_api
from .. import tiledb_cloud_error as tce
from .. import udf
from .._common import functions
from .._common import futures
Expand All @@ -46,7 +45,7 @@
from . import visualization as viz
from .mode import Mode

Status = st.Status # Re-export for compabitility.
Status = st.Status # Re-export for compatibility.
_T = TypeVar("_T")
# Special string included in server errors when there is a problem loading
# stored parameters.
Expand Down Expand Up @@ -214,7 +213,7 @@ def _check_resources_and_mode(self) -> None:
if self.mode == Mode.BATCH:
if self._resource_class:
if resources_set:
raise tce.TileDBCloudError(
raise ValueError(
"Only one of `resources` and `resource_class`"
" may be set when running a task graph node."
)
Expand All @@ -225,20 +224,18 @@ def _check_resources_and_mode(self) -> None:
# (written in bytes) was probably a user error.
if resources_set and "memory" in self._resources:
if re.match(r"^[0-9]{1,7}$", self._resources["memory"]):
raise tce.TileDBCloudError(
raise ValueError(
"The `memory` key in `resources` is missing a"
" unit suffix. Did you forget to append 'Mi' or 'Gi'?"
)
elif resources_set:
raise tce.TileDBCloudError(
raise ValueError(
"Cannot set resources for REALTIME task graphs,"
' please use "resource_class" to set a predefined option'
' for "standard" or "large"'
)
elif self.mode is Mode.LOCAL and self._resource_class:
raise tce.TileDBCloudError(
"Resource class cannot be set for locally-executed nodes."
)
raise ValueError("Resource class cannot be set for locally-executed nodes.")

def _find_deps(self) -> None:
"""Finds Nodes this depends on and adds them to our dependency list."""
Expand All @@ -257,7 +254,7 @@ def _find_deps(self) -> None:
try:
dep.future.result(0)
except Exception as e:
raise tce.TileDBCloudError(
raise ValueError(
"Nodes from a previous DAG may only be used as inputs"
" in a subsequent DAG if they are already complete."
) from e
Expand Down Expand Up @@ -956,7 +953,7 @@ def _add_prewrapped_node(

if self.mode == Mode.BATCH:
if kwargs.get("mode") is not None and kwargs.get("mode") != Mode.BATCH:
raise tce.TileDBCloudError(
raise ValueError(
"BATCH mode DAG can only execute BATCH mode Nodes."
)
kwargs["mode"] = Mode.BATCH
Expand All @@ -981,8 +978,8 @@ def submit_array_udf(self, func: Callable, *args: Any, **kwargs: Any):
"""Submit a function that will be executed in the cloud serverlessly.
:param func: Function to execute in UDF task.
:param *args: Postional arguments to pass into Node instantation.
:param **kwargs: Keyword args to pass into Node instantiation.
:param args: Positional arguments to pass into Node instantiation.
:param kwargs: Keyword args to pass into Node instantiation.
:return: Node that is created.
"""

Expand All @@ -998,8 +995,8 @@ def submit_local(self, func: Callable, *args: Any, **kwargs):
"""Submit a function that will run locally.
:param func: Function to execute in UDF task.
:param *args: Postional arguments to pass into Node instantation.
:param **kwargs: Keyword args to pass into Node instantiation.
:param args: Positional arguments to pass into Node instantiation.
:param kwargs: Keyword args to pass into Node instantiation.
:return: Node that is created
"""

Expand All @@ -1010,8 +1007,8 @@ def submit_udf(self, func: Callable, *args, **kwargs):
"""Submit a function that will be executed in the cloud serverlessly.
:param func: Function to execute in UDF task.
:param *args: Postional arguments to pass into Node instantation.
:param **kwargs: Keyword args to pass into Node instantiation.
:param args: Positional arguments to pass into Node instantiation.
:param kwargs: Keyword args to pass into Node instantiation.
:return: Node that is created.
"""

Expand Down Expand Up @@ -1062,17 +1059,15 @@ def submit_udf_stage(
```
:param func: Function to execute in UDF task.
:param *args: Postional arguments to pass into Node instantation.
:param args: Positional arguments to pass into Node instantiation.
:param expand_node_output: Node that we want to expand the output of.
The output of the node should be a JSON encoded list.
:param **kwargs: Keyword args to pass into Node instantiation.
:param kwargs: Keyword args to pass into Node instantiation.
:return: Node that is created.
"""

if "local_mode" in kwargs or self.mode != Mode.BATCH:
raise tce.TileDBCloudError(
"Stage nodes are only supported for BATCH mode DAGs."
)
raise ValueError("Stage nodes are only supported for BATCH mode DAGs.")

return self._add_prewrapped_node(
udf.exec_base,
Expand All @@ -1083,17 +1078,18 @@ def submit_udf_stage(
**kwargs,
)

def submit_sql(self, *args: Any, **kwargs: Any) -> Node:
def submit_sql(self, sql: str, *args: Any, **kwargs: Any) -> Node:
"""Submit a sql query to run serverlessly in the cloud.
:param sql: Query to execute.
:param *args: Postional arguments to pass into Node instantation.
:param **kwargs: Keyword args to pass into Node instantiation.
:param args: Positional arguments to pass into Node instantiation.
:param kwargs: Keyword args to pass into Node instantiation.
:return: Node that is created
"""

return self._add_prewrapped_node(
_sql_exec.exec_base,
sql,
*args,
_internal_accepts_stored_params=False,
_fallback_name="SQL query",
Expand Down Expand Up @@ -1259,9 +1255,7 @@ def compute(self) -> None:
if self.mode == Mode.REALTIME:
roots = self._find_root_nodes()
if len(roots) == 0:
raise tce.TileDBCloudError(
"DAG is circular, there are no root nodes"
)
raise ValueError("DAG is circular, there are no root nodes")
self._status = Status.RUNNING

for node in roots:
Expand Down Expand Up @@ -1661,7 +1655,7 @@ def _build_batch_taskgraph(self):
if name not in _SKIP_BATCH_UDF_KWARGS
}

all_args = types.Arguments(node_args, filtered_node_kwargs)
all_args = types.Arguments(tuple(node_args), filtered_node_kwargs)
encoder = _BatchArgEncoder(input_is_expanded=bool(node._expand_node_output))
kwargs["arguments"] = encoder.encode_arguments(all_args)

Expand Down Expand Up @@ -2001,7 +1995,7 @@ def maybe_replace(self, arg) -> Optional[visitor.Replacement]:
def _topo_sort(
lst: Sequence[rest_api.TaskGraphNodeMetadata],
) -> Sequence[rest_api.TaskGraphNodeMetadata]:
"""Topologically sorts the list of node metadatas."""
"""Topologically sorts the list of node metadata."""
by_uuid: Dict[str, rest_api.TaskGraphNodeMetadata] = {
node.client_node_uuid: node for node in lst
}
Expand Down
10 changes: 5 additions & 5 deletions tests/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,14 @@ def _remote_result(self, node: dag_dag.Node) -> results.RemoteResult:

def test_resource_checks(self):
grf = dag.DAG()
with self.assertRaises(tce.TileDBCloudError):
with self.assertRaises(ValueError):
grf.submit(repr, None, resources={"x": "y"})
with self.assertRaises(tce.TileDBCloudError):
with self.assertRaises(ValueError):
grf.submit_local(repr, None, resources={"x": "y"})
with self.assertRaises(tce.TileDBCloudError):
with self.assertRaises(ValueError):
grf.submit_local(repr, None, resource_class="hello")
grf = dag.DAG(mode=Mode.BATCH)
with self.assertRaises(tce.TileDBCloudError):
with self.assertRaises(ValueError):
grf.submit(repr, None, resources={"memory": "100"})

def test_kwargs(self):
Expand Down Expand Up @@ -522,7 +522,7 @@ def test_two_dags_bad(self):
d1 = dag.DAG()
n1 = d1.submit(repr, "whatever")
d2 = dag.DAG()
with self.assertRaises(tce.TileDBCloudError):
with self.assertRaises(ValueError):
d2.submit(repr, n1)

def test_retry(self):
Expand Down

0 comments on commit e904d0b

Please sign in to comment.