Skip to content

Commit

Permalink
[BUGFIX] Fix recursion limit when applying large reduce functions (#851)
Browse files Browse the repository at this point in the history
Python would think a recursive error occurred, but in reality, the stack got too large due to poor optimization when executing array reduce functions. Thanks Melissa from Discord for the bug report!
jmbannon authored Dec 21, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent b8fb119 commit 873ecc0
Showing 3 changed files with 61 additions and 15 deletions.
25 changes: 15 additions & 10 deletions src/ytdl_sub/script/types/function.py
Original file line number Diff line number Diff line change
@@ -232,20 +232,25 @@ def _resolve_lambda_reduce_function(
if len(lambda_array.value) == 1:
return lambda_array.value[0]

reduced = self._instantiate_lambda(
lambda_function_name=lambda_function_name,
args=[lambda_array.value[0], lambda_array.value[1]],
reduced: Resolvable = self._resolve_argument_type(
arg=self._instantiate_lambda(
lambda_function_name=lambda_function_name,
args=[lambda_array.value[0], lambda_array.value[1]],
),
resolved_variables=resolved_variables,
custom_functions=custom_functions,
)
for idx in range(2, len(lambda_array.value)):
reduced = self._instantiate_lambda(
lambda_function_name=lambda_function_name, args=[reduced, lambda_array.value[idx]]
reduced = self._resolve_argument_type(
arg=self._instantiate_lambda(
lambda_function_name=lambda_function_name,
args=[reduced, lambda_array.value[idx]],
),
resolved_variables=resolved_variables,
custom_functions=custom_functions,
)

return self._resolve_argument_type(
arg=reduced,
resolved_variables=resolved_variables,
custom_functions=custom_functions,
)
return reduced

def resolve(
self,
12 changes: 7 additions & 5 deletions src/ytdl_sub/utils/scriptable.py
Original file line number Diff line number Diff line change
@@ -19,12 +19,14 @@ class Scriptable(ABC):
Shared class between Entry and Overrides to manage their underlying Script.
"""

def __init__(self):
self.script = Script(
ScriptUtils.add_sanitized_variables(
dict(copy.deepcopy(VARIABLE_SCRIPTS), **copy.deepcopy(CUSTOM_FUNCTION_SCRIPTS))
)
_BASE_SCRIPT: Script = Script(
ScriptUtils.add_sanitized_variables(
dict(copy.deepcopy(VARIABLE_SCRIPTS), **copy.deepcopy(CUSTOM_FUNCTION_SCRIPTS))
)
)

def __init__(self):
self.script = copy.deepcopy(Scriptable._BASE_SCRIPT)
self.unresolvable: Set[str] = copy.deepcopy(UNRESOLVED_VARIABLES)

def update_script(self) -> None:
39 changes: 39 additions & 0 deletions tests/unit/script/functions/test_array_functions.py
Original file line number Diff line number Diff line change
@@ -58,6 +58,45 @@ def test_array_reduce(self):
output = single_variable_output("{%array_reduce([1, 2, 3, 4], %add)}")
assert output == 10

def test_array_reduce_complex(self):
output = (
Script(
{
"%custom_get": """{
%if(
%bool(siblings_array),
%array_apply_fixed(
siblings_array,
%string($0),
%map_get
)
[]
)
}""",
"siblings_array": """{
[
{'upload_date': '20200101'},
{'upload_date': '19940101'}
]
}""",
"upload_date": "20230101",
"output": """{
%array_reduce(
%if_passthrough(
%custom_get('upload_date'),
[ upload_date ]
),
%max
)
}""",
}
)
.resolve(update=True)
.get("output")
.native
)
assert output == "20200101"

def test_array_enumerate(self):
output = (
Script(

0 comments on commit 873ecc0

Please sign in to comment.