diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index e6145c8c1fd6..174f085e2fbd 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -647,7 +647,7 @@ def add_column( ray (e.g., num_gpus=1 to request GPUs for the map tasks). """ - def process_batch(batch: "pandas.DataFrame") -> "pandas.DataFrame": + def add_column(batch: "pandas.DataFrame") -> "pandas.DataFrame": batch.loc[:, col] = fn(batch) return batch @@ -655,7 +655,7 @@ def process_batch(batch: "pandas.DataFrame") -> "pandas.DataFrame": raise ValueError("`fn` must be callable, got {}".format(fn)) return self.map_batches( - process_batch, + add_column, batch_format="pandas", # TODO(ekl) we should make this configurable. compute=compute, concurrency=concurrency, @@ -761,11 +761,11 @@ def select_columns( ray (e.g., num_gpus=1 to request GPUs for the map tasks). """ # noqa: E501 - def fn(batch): + def select_columns(batch): return BlockAccessor.for_block(batch).select(columns=cols) return self.map_batches( - fn, + select_columns, batch_format="pandas", zero_copy_batch=True, compute=compute, @@ -1119,7 +1119,7 @@ def random_sample( if seed is not None: random.seed(seed) - def process_batch(batch): + def random_sample(batch): if isinstance(batch, list): return [row for row in batch if random.random() <= fraction] if isinstance(batch, pa.Table): @@ -1135,7 +1135,7 @@ def process_batch(batch): ) raise ValueError(f"Unsupported batch type: {type(batch)}") - return self.map_batches(process_batch, batch_format=None) + return self.map_batches(random_sample, batch_format=None) @ConsumptionAPI def streaming_split(