Skip to content

Commit

Permalink
Merge pull request #492 from bbrowning/run-pipeline-cli
Browse files Browse the repository at this point in the history
Stub in a run_pipeline CLI and add example usage
  • Loading branch information
mergify[bot] authored Jan 29, 2025
2 parents 9f68566 + 70da871 commit 597e372
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 2 deletions.
18 changes: 18 additions & 0 deletions docs/examples/blocks/iterblock/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# IterBlock

`IterBlock` is used to run multiple iterations of another `Block`,
such as to call an `LLMBlock` multiple times generating a new sample
from each iteration.

A simple example of its usage is shown in
[pipeline.yaml](pipeline.yaml), where we use it to call the
`DuplicateColumnsBlock` 5 times for every input sample, which results
in us generating 5 output samples for each input sample, with the
specified column duplicated in each output sample.

Assuming you have SDG installed, you can run that example with a
command like:

```shell
python -m instructlab.sdg.cli.run_pipeline --pipeline pipeline.yaml --input input.jsonl --output output.jsonl
```
1 change: 1 addition & 0 deletions docs/examples/blocks/iterblock/input.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"foo": "bar"}
9 changes: 9 additions & 0 deletions docs/examples/blocks/iterblock/pipeline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Pipeline for the IterBlock example (see README.md in this directory).
version: "1.0"
blocks:
  # IterBlock runs another block a fixed number of times per input sample
  - name: iterate_block_example
    type: IterBlock
    config:
      # Run the inner block 5 times, yielding 5 output samples per input
      num_iters: 5
      block_type: DuplicateColumnsBlock
      columns_map:
        # Duplicate the "foo" column into a new "baz" column
        foo: baz
90 changes: 90 additions & 0 deletions src/instructlab/sdg/cli/run_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from pathlib import Path
import os

# Third Party
from datasets import Dataset
import openai

# First Party
from instructlab.sdg.pipeline import Pipeline, PipelineContext
from instructlab.sdg.utils.json import jldump, jlload
from instructlab.sdg.utils.logging import setup_logger

if __name__ == "__main__":
    # Standard
    import argparse

    parser = argparse.ArgumentParser(
        description="Run a synthetic data generation pipeline."
    )

    # Required args
    parser.add_argument(
        "--pipeline",
        type=str,
        required=True,
        help="Path to the yaml file of the pipeline to execute.",
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input jsonl file containing samples used for data generation.",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Path to write the generated samples to, in jsonl format.",
    )

    # Optional args - these all have defaults or may be omitted
    parser.add_argument(
        "--endpoint-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL endpoint of an OpenAI-compatible API server running your teacher model.",
    )
    parser.add_argument(
        "--model-family",
        type=str,
        default="mixtral",
        # Enforce the closed set of values promised in the help text,
        # instead of silently accepting an unknown family and failing later.
        choices=["granite", "merlinite", "mistral", "mixtral"],
        help="Model family of your teacher model. Valid values are granite, merlinite, mistral, or mixtral.",
    )
    parser.add_argument(
        "--model-id",
        type=str,
        help="The id of the teacher model to use, as recognized by your OpenAI-compatible API server.",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default="EMPTY",
        help="API key for the OpenAI-compatible API endpoint",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default=os.getenv("LOG_LEVEL", "INFO"),
        help="Logging level",
    )

    args = parser.parse_args()
    setup_logger(args.log_level)
    client = openai.OpenAI(base_url=args.endpoint_url, api_key=args.api_key)
    # TODO: Remove num_instructions_to_generate hardcode of 30 here,
    # but first we need to remove it as a required parameter of the
    # PipelineContext generally.
    #
    # https://github.com/instructlab/sdg/issues/491
    pipeline_context = PipelineContext(client, args.model_family, args.model_id, 30)
    pipeline_path = Path(args.pipeline).absolute()
    pipeline = Pipeline.from_file(pipeline_context, pipeline_path)
    # Load the input jsonl samples, run them through the pipeline, and
    # write the generated samples back out as jsonl.
    input_path = Path(args.input).absolute()
    input_ds = Dataset.from_list(jlload(str(input_path)))
    output_ds = pipeline.generate(input_ds)
    output_path = Path(args.output).absolute()
    jldump(output_ds, str(output_path))
17 changes: 17 additions & 0 deletions src/instructlab/sdg/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import logging


def setup_logger(level="DEBUG"):
    """
    Configure root logging for SDG command-line entrypoints.

    ONLY to be used when running CLI commands in SDG directly. DO NOT
    call this from regular library code, and only call it from
    __main__ entrypoints in the instructlab.sdg.cli package.
    """
    log_format = "%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s"
    logging.basicConfig(level=level, format=log_format)
40 changes: 38 additions & 2 deletions tests/functional/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import pathlib
from pathlib import Path
import shlex
import shutil
import subprocess
import sys

# Third Party
from docling.document_converter import DocumentConverter

# First Party
from instructlab.sdg.utils.json import jlload


def test_example_mixing(tmp_path: pathlib.Path, examples_path: pathlib.Path):
def test_example_mixing(tmp_path: Path, examples_path: Path):
example_copy_path = tmp_path.joinpath("mix_datasets")
shutil.copytree(examples_path.joinpath("mix_datasets"), example_copy_path)
script = example_copy_path.joinpath("example_mixing.py")
Expand Down Expand Up @@ -38,3 +42,35 @@ def test_example_mixing(tmp_path: pathlib.Path, examples_path: pathlib.Path):
from_ds_2.append(sample)
assert len(from_ds_1) == 10
assert len(from_ds_2) == 1


def _test_example_run_pipeline(example_dir):
    """Run the run_pipeline shell command shown in an example's README.md."""
    doc_converter = DocumentConverter()
    readme_path = example_dir.joinpath("README.md")
    assert readme_path.exists()
    conv_result = doc_converter.convert(readme_path)
    # Scan the README's code blocks for the run_pipeline command; if
    # several match, the last one seen wins.
    command_to_run = ""
    for doc_item, _level in conv_result.document.iterate_items():
        if doc_item.label == "code" and "run_pipeline" in doc_item.text:
            command_to_run = doc_item.text
    assert command_to_run
    # Split the command into shell arguments and swap in the proper
    # Python interpreter - ie the one from the tox environment
    shell_args = shlex.split(command_to_run)
    shell_args[0] = sys.executable
    # Execute from the example's subdirectory so relative paths resolve
    subprocess.check_call(shell_args, text=True, cwd=example_dir)


def test_example_iterblock(tmp_path: Path, examples_path: Path):
    """End-to-end check of the IterBlock example via its README command."""
    src_dir = examples_path.joinpath("blocks", "iterblock")
    shutil.copytree(src_dir, tmp_path, dirs_exist_ok=True)
    _test_example_run_pipeline(tmp_path)
    output_jsonl = tmp_path.joinpath("output.jsonl")
    assert output_jsonl.exists()
    samples = jlload(output_jsonl)
    # The example pipeline iterates DuplicateColumnsBlock 5 times per
    # input sample, copying the "foo" column into "baz" each time.
    assert len(samples) == 5
    assert samples[4]["baz"] == "bar"

0 comments on commit 597e372

Please sign in to comment.