Skip to content

Commit

Permalink
fix list requires: must not capture one-for-all done_deps
Browse files Browse the repository at this point in the history
  • Loading branch information
m.kindritskiy committed Sep 10, 2024
1 parent de9bbe4 commit 2ae90a6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
15 changes: 12 additions & 3 deletions hiku/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Union,
Dict,
List,
Set,
NoReturn,
Optional,
DefaultDict,
Expand Down Expand Up @@ -722,10 +723,14 @@ def process_node(
)
if graph_link.requires:
if isinstance(graph_link.requires, list):
done_deps = set()
done_link_deps: Set = set()

def add_done_dep_callback(
dep: Dep, req: Any, graph_link: Link, schedule: Callable
done_deps: Set,
dep: Dep,
req: Any,
graph_link: Link,
schedule: Callable,
) -> None:
def done_cb() -> None:
done_deps.add(req)
Expand All @@ -736,7 +741,11 @@ def done_cb() -> None:

for req in graph_link.requires:
add_done_dep_callback(
to_dep[to_func[req]], req, graph_link, schedule
done_link_deps,
to_dep[to_func[req]],
req,
graph_link,
schedule,
)
else:
dep = to_dep[to_func[graph_link.requires]]
Expand Down
30 changes: 23 additions & 7 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_links():

def test_links_requires_list():
db = {
"song": {100: {"name": "fuel", "artist_id": 1, "album_id": 10}},
"song": {100: {"id": 100, "name": "fuel", "artist_id": 1, "album_id": 10}},
"artist": {
1: {"name": "Metallica"},
},
Expand All @@ -224,17 +224,22 @@ def get_fields(song_id):

return [list(get_fields(song_id)) for song_id in song_ids]

def song_info_fields(fields, ids):
def get_fields(id_):
album_id = id_["album_id"]
artist_id = id_["artist_id"]
def song_info_fields(fields, requires):
def get_fields(require):
if "id" in require:
song = db["song"][require["id"]]
album_id = song["album_id"]
artist_id = song["artist_id"]
else:
album_id = require["album_id"]
artist_id = require["artist_id"]
for f in fields:
if f.name == "album_name":
yield db["album"][album_id]["name"]
elif f.name == "artist_name":
yield db["artist"][artist_id]["name"]

return [list(get_fields(id_)) for id_ in ids]
return [list(get_fields(require)) for require in requires]

graph = Graph(
[
Expand All @@ -258,6 +263,12 @@ def get_fields(id_):
link_song_info,
requires=["album_id", "artist_id"],
),
Link(
"infoV2",
TypeRef["SongInfo"],
link_song_info,
requires=["id"],
),
],
),
Root(
Expand All @@ -274,6 +285,10 @@ def get_fields(id_):
Q.info[
Q.album_name,
Q.artist_name,
],
Q.infoV2[
Q.album_name,
Q.artist_name,
]
]
]
Expand All @@ -283,7 +298,8 @@ def get_fields(id_):
result,
{
"song": {
"info": {"album_name": "Reload", "artist_name": "Metallica"}
"info": {"album_name": "Reload", "artist_name": "Metallica"},
"infoV2": {"album_name": "Reload", "artist_name": "Metallica"},
}
},
)
Expand Down

0 comments on commit 2ae90a6

Please sign in to comment.