
Merge pull request #484 from eshwarprasadS/RHELAI-2756-batch-after-every-block

Adding Batching After Every Block
mergify[bot] authored Jan 21, 2025
2 parents b66528c + 0bb9304 commit 8191c2a
Showing 2 changed files with 98 additions and 14 deletions.
47 changes: 33 additions & 14 deletions src/instructlab/sdg/pipeline.py
Expand Up @@ -190,7 +190,6 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:
def _generate_single(self, dataset) -> Dataset:
"""Generate a single dataset by running the pipeline steps."""
for block_prop in self.chained_blocks:
# Initialize arguments for error handling to None
block, block_name, block_type = None, None, None
try:
# Parse and instantiate the block
@@ -201,8 +200,39 @@ def _generate_single(self, dataset) -> Dataset:
                drop_duplicates_cols = block_prop.get("drop_duplicates", False)
                block = block_type(self.ctx, self, block_name, **block_config)
                logger.info("Running block: %s", block_name)
                # Execute the block and wrap errors with the block name/type
                dataset = block.generate(dataset)

                # Check if batching is enabled
                if not self.ctx.batching_enabled:
                    logger.info(
                        "Batching disabled; processing block '%s' single-threaded.",
                        block_name,
                    )
                    dataset = block.generate(dataset)
                else:
                    # Split the dataset into batches
                    input_splits = self._split_dataset(dataset)
                    # Process each batch in sequence
                    output_splits = [
                        block.generate(input_split) for input_split in input_splits
                    ]
                    # Combine the processed splits back into a single dataset
                    dataset = concatenate_datasets(output_splits)

                # If the dataset is empty after processing, terminate early
                if len(dataset) == 0:
                    return dataset

                # Remove unnecessary columns if specified
                drop_columns_in_ds = [
                    e for e in drop_columns if e in dataset.column_names
                ]
                if drop_columns_in_ds:
                    dataset = dataset.remove_columns(drop_columns_in_ds)

                # Drop duplicates if specified
                if drop_duplicates_cols:
                    dataset = self._drop_duplicates(dataset, cols=drop_duplicates_cols)

            except Exception as err:
                raise PipelineBlockError(
                    exception=err,
@@ -211,17 +241,6 @@ def _generate_single(self, dataset) -> Dataset:
                    block_type=block_type,
                ) from err

            # If at any point we end up with an empty data set, the pipeline has failed
            if len(dataset) == 0:
                return dataset

            drop_columns_in_ds = [e for e in drop_columns if e in dataset.column_names]
            if drop_columns:
                dataset = dataset.remove_columns(drop_columns_in_ds)

            if drop_duplicates_cols:
                dataset = self._drop_duplicates(dataset, cols=drop_duplicates_cols)

        return dataset

    def _drop_duplicates(self, dataset, cols):
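
The batching branch above relies on two pieces that do not appear in this diff: the `self._split_dataset(...)` helper and `concatenate_datasets`, which is imported from the `datasets` library. Purely as an illustration, a splitting helper along the following lines would yield chunks of at most `batch_size` rows; the actual helper in pipeline.py may be implemented differently:

import math

from datasets import Dataset


def split_dataset(dataset: Dataset, batch_size: int) -> list[Dataset]:
    """Slice a Dataset into consecutive chunks of at most batch_size rows (illustrative sketch only)."""
    total = len(dataset)
    num_batches = math.ceil(total / batch_size)
    return [
        dataset.select(range(i * batch_size, min((i + 1) * batch_size, total)))
        for i in range(num_batches)
    ]

With a helper of this shape, every block sees at most `batch_size` rows per call, and `concatenate_datasets` stitches the per-batch outputs back together before the next block runs.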
65 changes: 65 additions & 0 deletions tests/test_pipeline.py
@@ -102,6 +102,71 @@ def generate(self, dataset):
    assert res.to_list() == [{"foo": i * 2} for i in range(10)]


def test_pipeline_batching_after_each_block(sample_dataset, threaded_ctx):
    """Test that batching occurs after each block in the pipeline."""

    class MockBlockOne:
        def __init__(self, ctx, pipeline, block_name, **block_config):
            self.ctx = ctx  # Save the context for use in generate if needed

        def generate(self, dataset):
            # Assert that the dataset entering Block 1 is properly batched
            assert (
                len(dataset) <= self.ctx.batch_size
            ), f"Dataset size {len(dataset)} entering block 1 exceeds batch size {self.ctx.batch_size}"
            # Simulate dataset explosion in Block 1

            exploded_data = []
            for _ in range(10):  # Repeat each entry 10 times
                exploded_data.extend(dataset)

            # Create a new Dataset from the exploded data
            output = Dataset.from_list(exploded_data)

            return output

    class MockBlockTwo:
        def __init__(self, ctx, pipeline, block_name, **block_config):
            self.ctx = ctx  # Save the context for use in generate if needed

        def generate(self, dataset):
            # Assert that the dataset entering Block 2 is properly batched (this will fail if batching is not done after each block)
            assert (
                len(dataset) <= self.ctx.batch_size
            ), f"Dataset size {len(dataset)} entering block 2 exceeds batch size {self.ctx.batch_size}"
            return dataset

    # Define the pipeline configuration with two blocks
    pipe_cfg = [
        {
            "name": "block-one",
            "type": "block_one",
            "config": {},
        },
        {
            "name": "block-two",
            "type": "block_two",
            "config": {},
        },
    ]

    # Patch block types to use the mock implementations
    with block_types({"block_one": MockBlockOne, "block_two": MockBlockTwo}):
        # Run the pipeline
        result = Pipeline(threaded_ctx, "", pipe_cfg).generate(sample_dataset)
        # Assertions for the final output dataset:
        # 1. Check the final dataset length is the expected value
        expected_len = (
            len(sample_dataset) * 10
        )  # Since Block 1 multiplies the dataset by 10
        assert (
            len(result) == expected_len
        ), f"Expected dataset length {expected_len}, but got {len(result)}"

        # 2. Check the dataset features: ensure the feature structure is consistent with the input
        assert "foo" in result[0], "Feature 'foo' not found in the final dataset"
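
The test exercises the new per-block batching through the `sample_dataset` and `threaded_ctx` fixtures and the `block_types` patching helper, none of which are shown in this diff. Purely as a hypothetical sketch of fixtures that would satisfy the assertions above (the real fixtures in tests/test_pipeline.py will differ):

from unittest import mock

import pytest
from datasets import Dataset


@pytest.fixture
def sample_dataset():
    # Hypothetical fixture: ten trivial rows; the test only requires a "foo" column.
    return Dataset.from_list([{"foo": i} for i in range(10)])


@pytest.fixture
def threaded_ctx():
    # Hypothetical stand-in context with batching enabled and a small batch size,
    # so the batch-size assertions inside the mock blocks are actually meaningful.
    ctx = mock.MagicMock()
    ctx.batching_enabled = True
    ctx.batch_size = 6
    return ctx

Here `block_types` is assumed to temporarily register the mock classes under the given type names for the duration of the `with` block.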

## Pipeline Error Handling ##


