Skip to content

Commit

Permalink
[Data] Make internal UDF names more descriptive (#44985)
Browse files Browse the repository at this point in the history
APIs like `select_columns` call `map_batches` under-the-hood and use
functions with non-descriptives names. For example, if you call
`select_columns`, you'd see something like this in the progress bar:

```
ReadRange->MapBatches(fn)
```

To prevent confusion (e.g., what is `fn`?), this PR makes the function
names more descriptive.
  • Loading branch information
bveeramani authored Apr 27, 2024
1 parent 10f7f2a commit ff62312
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,15 +647,15 @@ def add_column(
ray (e.g., num_gpus=1 to request GPUs for the map tasks).
"""

def process_batch(batch: "pandas.DataFrame") -> "pandas.DataFrame":
def add_column(batch: "pandas.DataFrame") -> "pandas.DataFrame":
batch.loc[:, col] = fn(batch)
return batch

if not callable(fn):
raise ValueError("`fn` must be callable, got {}".format(fn))

return self.map_batches(
process_batch,
add_column,
batch_format="pandas", # TODO(ekl) we should make this configurable.
compute=compute,
concurrency=concurrency,
Expand Down Expand Up @@ -761,11 +761,11 @@ def select_columns(
ray (e.g., num_gpus=1 to request GPUs for the map tasks).
""" # noqa: E501

def fn(batch):
def select_columns(batch):
return BlockAccessor.for_block(batch).select(columns=cols)

return self.map_batches(
fn,
select_columns,
batch_format="pandas",
zero_copy_batch=True,
compute=compute,
Expand Down Expand Up @@ -1119,7 +1119,7 @@ def random_sample(
if seed is not None:
random.seed(seed)

def process_batch(batch):
def random_sample(batch):
if isinstance(batch, list):
return [row for row in batch if random.random() <= fraction]
if isinstance(batch, pa.Table):
Expand All @@ -1135,7 +1135,7 @@ def process_batch(batch):
)
raise ValueError(f"Unsupported batch type: {type(batch)}")

return self.map_batches(process_batch, batch_format=None)
return self.map_batches(random_sample, batch_format=None)

@ConsumptionAPI
def streaming_split(
Expand Down

0 comments on commit ff62312

Please sign in to comment.