Skip to content

Commit

Permalink
Make it possible to pass device through pull_apart_list (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
karlhigley authored Feb 28, 2022
1 parent b1a3594 commit 35c016c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions merlin/core/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def encode_list_column(original, encoded, dtype=None):
)


def pull_apart_list(original):
def pull_apart_list(original, device=None):
values = flatten_list_column_values(original)
if isinstance(original, pd.Series):
offsets = pd.Series([0]).append(original.map(len).cumsum()).reset_index(drop=True)
Expand All @@ -402,7 +402,7 @@ def pull_apart_list(original):
elements = original._column.elements
if isinstance(elements, cudf.core.column.lists.ListColumn):
offsets = elements.offsets[offsets]
return make_series(values), make_series(offsets)
return make_series(values, device), make_series(offsets, device)


def to_arrow(x):
Expand Down

0 comments on commit 35c016c

Please sign in to comment.