Skip to content

Commit

Permalink
Fix get_device_memory_ids (#1305)
Browse files Browse the repository at this point in the history
A recent change to the way `StringColumn`s are implemented in cudf threw up that we were never correctly determining the number of device buffers belonging to cudf columns if they had children (e.g. list and struct columns) or masks (any nullable column). Handle those cases and update the test.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)
  - Mads R. B. Kristensen (https://github.com/madsbk)

URL: #1305
  • Loading branch information
wence- authored Jan 18, 2024
1 parent 10f1dee commit 34e7404
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 4 additions & 0 deletions dask_cuda/get_device_memory_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def get_device_memory_objects_cudf_index(obj):
def get_device_memory_objects_cudf_multiindex(obj):
return dispatch(obj._columns)

@dispatch.register(cudf.core.column.ColumnBase)
def get_device_memory_objects_cudf_column(obj):
return dispatch(obj.data) + dispatch(obj.children) + dispatch(obj.mask)


@sizeof.register_lazy("cupy")
def register_cupy(): # NB: this overwrites dask.sizeof.register_cupy()
Expand Down
15 changes: 13 additions & 2 deletions dask_cuda/tests/test_proxify_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,24 @@ def test_dataframes_share_dev_mem(root_dir):
def test_cudf_get_device_memory_objects():
cudf = pytest.importorskip("cudf")
objects = [
cudf.DataFrame({"a": range(10), "b": range(10)}, index=reversed(range(10))),
cudf.DataFrame(
{"a": [0, 1, 2, 3, None, 5, 6, 7, 8, 9], "b": range(10)},
index=reversed(range(10)),
),
cudf.MultiIndex(
levels=[[1, 2], ["blue", "red"]], codes=[[0, 0, 1, 1], [1, 0, 1, 0]]
),
]
res = get_device_memory_ids(objects)
assert len(res) == 4, "We expect four buffer objects"
# Buffers are:
# 1. int data for objects[0].a
# 2. mask data for objects[0].a
# 3. int data for objects[0].b
# 4. int data for objects[0].index
# 5. int data for objects[1].levels[0]
# 6. char data for objects[1].levels[1]
# 7. offset data for objects[1].levels[1]
assert len(res) == 7, "We expect seven buffer objects"


def test_externals(root_dir):
Expand Down

0 comments on commit 34e7404

Please sign in to comment.