Skip to content

Commit

Permalink
Validate node values (#3715)
Browse files Browse the repository at this point in the history
* Added extra inputs/outputs variables validation at the Node level

Signed-off-by: Elena Khaustova <[email protected]>

* Fixing potential typo

Signed-off-by: Elena Khaustova <[email protected]>

* Added release note

Signed-off-by: Elena Khaustova <[email protected]>

* Retrigger the CI

Signed-off-by: Elena Khaustova <[email protected]>

---------

Signed-off-by: Elena Khaustova <[email protected]>
  • Loading branch information
ElenaKhaustova authored Mar 18, 2024
1 parent 0fe2f17 commit 32b87be
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
3 changes: 2 additions & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Upcoming Release 0.19.4

## Major features and improvements
* Improved the error message shown when a wrong value is passed to a node.
* Cookiecutter errors are shown in short format without the `--verbose` flag.
* Kedro commands now work from any subdirectory within a Kedro project.
* Kedro CLI now provides a better error message when project commands are run outside of a project i.e. `kedro run`
* Kedro CLI now provides a better error message when project commands are run outside of a project, e.g. `kedro run`.

## Bug fixes and other changes
* Updated `kedro pipeline create` and `kedro pipeline delete` to read the base environment from the project settings.
Expand Down
23 changes: 21 additions & 2 deletions kedro/pipeline/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__( # noqa: PLR0913
function. When dict[str, str] is provided, variable names
will be mapped to function argument names.
outputs: The name or the list of the names of variables used
as outputs to the function. The number of names should match
as outputs of the function. The number of names should match
the number of outputs returned by the provided function.
When dict[str, str] is provided, variable names will be mapped
to the named outputs the function returns.
Expand All @@ -67,7 +67,6 @@ def __init__( # noqa: PLR0913
and/or fullstops.
"""

if not callable(func):
raise ValueError(
_node_error_message(
Expand All @@ -83,6 +82,16 @@ def __init__( # noqa: PLR0913
)
)

for _input in _to_list(inputs):
if not isinstance(_input, str):
raise ValueError(
_node_error_message(
f"names of variables used as inputs to the function "
f"must be of 'String' type, but {_input} from {inputs} "
f"is '{type(_input)}'."
)
)

if outputs and not isinstance(outputs, (list, dict, str)):
raise ValueError(
_node_error_message(
Expand All @@ -91,6 +100,16 @@ def __init__( # noqa: PLR0913
)
)

for _output in _to_list(outputs):
if not isinstance(_output, str):
raise ValueError(
_node_error_message(
f"names of variables used as outputs of the function "
f"must be of 'String' type, but {_output} from {outputs} "
f"is '{type(_output)}'."
)
)

if not inputs and not outputs:
raise ValueError(
_node_error_message("it must have some 'inputs' or 'outputs'.")
Expand Down
2 changes: 1 addition & 1 deletion tests/ipython/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def dummy_node_empty_input():
return node(
func=dummy_function,
inputs=["", ""],
outputs=[None],
outputs=None,
name="dummy_node_empty_input",
)

Expand Down
13 changes: 13 additions & 0 deletions tests/pipeline/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,14 @@ def duplicate_output_list_node():
return identity, "A", ["A", "A"]


def bad_input_variable_name():
    """Return node arguments whose ``inputs`` mapping holds a non-string name.

    Returns:
        A ``(func, inputs, outputs)`` tuple. ``inputs["a"]`` is the integer
        ``1`` instead of a string, while every entry in ``outputs`` is a
        string.
    """
    # Parentheses make it explicit that the lambda is the first tuple element
    # rather than the lambda body swallowing the rest of the expression.
    return (lambda x: None), {"a": 1, "b": "B"}, {"a": "A", "b": "B"}


def bad_output_variable_name():
    """Return node arguments whose ``outputs`` mapping holds a non-string name.

    Returns:
        A ``(func, inputs, outputs)`` tuple. ``outputs["b"]`` is the integer
        ``2`` instead of a string, while every entry in ``inputs`` is a
        string.
    """
    # Parentheses make it explicit that the lambda is the first tuple element
    # rather than the lambda body swallowing the rest of the expression.
    return (lambda x: None), {"a": "A", "b": "B"}, {"a": "A", "b": 2}


@pytest.mark.parametrize(
"func, expected",
[
Expand All @@ -275,6 +283,11 @@ def duplicate_output_list_node():
r"\(\[A\]\) -> \[A;A\] due to "
r"duplicate output\(s\) {\'A\'}.",
),
(bad_input_variable_name, "names of variables used as inputs to the function "),
(
bad_output_variable_name,
"names of variables used as outputs of the function ",
),
],
)
def test_bad_node(func, expected):
Expand Down

0 comments on commit 32b87be

Please sign in to comment.