diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index 8c38e8d1d..68d6eb486 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -391,7 +391,7 @@ def encode_list_column(original, encoded, dtype=None): ) -def pull_apart_list(original): +def pull_apart_list(original, device=None): values = flatten_list_column_values(original) if isinstance(original, pd.Series): offsets = pd.Series([0]).append(original.map(len).cumsum()).reset_index(drop=True) @@ -402,7 +402,7 @@ def pull_apart_list(original): elements = original._column.elements if isinstance(elements, cudf.core.column.lists.ListColumn): offsets = elements.offsets[offsets] - return make_series(values), make_series(offsets) + return make_series(values, device), make_series(offsets, device) def to_arrow(x):