From 726008f39444435e01c9d93036e6d245c409553d Mon Sep 17 00:00:00 2001 From: Jesse Bannon Date: Sat, 26 Oct 2024 20:54:07 -0700 Subject: [PATCH] [BUGFIX] Custom function ordering --- src/ytdl_sub/script/script.py | 11 +++++++---- .../unit/script/types/test_custom_function.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/ytdl_sub/script/script.py b/src/ytdl_sub/script/script.py index f7c20d882..ab5063183 100644 --- a/src/ytdl_sub/script/script.py +++ b/src/ytdl_sub/script/script.py @@ -490,15 +490,18 @@ def add(self, variables: Dict[str, str], unresolvable: Optional[Set[str]] = None name: definition for name, definition in variables.items() if not _is_function(name) } + custom_function_names = set(self._functions.keys()) | functions_to_add.keys() + variable_names = ( + set(self._variables.keys()) | variables_to_add.keys() | (unresolvable or set()) + ) + for definitions in [functions_to_add, variables_to_add]: for name, definition in definitions.items(): parsed = parse( text=definition, name=name, - custom_function_names=set(self._functions.keys()), - variable_names=set(self._variables.keys()) - .union(variables.keys()) - .union(unresolvable or set()), + custom_function_names=custom_function_names, + variable_names=variable_names, ) if parsed.maybe_resolvable is None: diff --git a/tests/unit/script/types/test_custom_function.py b/tests/unit/script/types/test_custom_function.py index e7b76e808..61096cc3d 100644 --- a/tests/unit/script/types/test_custom_function.py +++ b/tests/unit/script/types/test_custom_function.py @@ -22,6 +22,24 @@ def test_custom_function_use_input_param_multiple_times(self): } ).resolve() == ScriptOutput({"output": Integer(9)}) + def test_custom_functions_any_order_via_add(self): + assert Script({}).add( + { + "%custom_cubed": "{%mul(%custom_square($0),$0)}", + "%custom_square": "{%mul($0, $0)}", + "output": "{%custom_cubed(3)}", + } + ).resolve() == ScriptOutput({"output": Integer(27)}) + + def test_custom_functions_any_order_via_init(self): + assert Script( + { + "%custom_cubed": "{%mul(%custom_square($0),$0)}", + "%custom_square": "{%mul($0, $0)}", + "output": "{%custom_cubed(3)}", + } + ).resolve() == ScriptOutput({"output": Integer(27)}) + def test_custom_function_cycle(self): with pytest.raises( CycleDetected, match=re.escape("The custom function %cycle_func cannot call itself.")