expose max_num_tokens as configurable
max-num-tokens is a convenient way to make an SDG run shorter or longer.
Locally I have been modifying the pipeline YAML from 2048 to 512, which ends up simply generating less data.
Exposing this through the CLI would let power users run different kinds of SDG runs.

Signed-off-by: Charlie Doern <[email protected]>
cdoern committed Nov 7, 2024
1 parent 4c82c05 commit ba32b79
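
As a rough illustration of the usage this change enables (a hypothetical call, not part of the commit; apart from pipeline and max_num_tokens, the argument names and values below are assumed placeholders, and the real generate_data signature takes more parameters than shown):

    from openai import OpenAI

    from instructlab.sdg.generate_data import generate_data

    # Assumed: a local OpenAI-compatible endpoint serving the teacher model.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    generate_data(
        client=client,
        taxonomy="taxonomy",      # assumed placeholder path
        taxonomy_base="empty",    # assumed placeholder base ref
        output_dir="generated",   # assumed placeholder path
        pipeline="full",
        max_num_tokens=512,       # cap each generated sample at 512 tokens instead of 4096
    )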
Showing 4 changed files with 40 additions and 5 deletions.

src/instructlab/sdg/generate_data.py (4 additions, 0 deletions)

@@ -183,6 +183,7 @@ def _context_init(
     save_freq: int,
     batch_num_workers: Optional[int],
     batch_size: Optional[int],
+    max_num_tokens: Optional[int] = 4096,
 ):
     extra_kwargs = {}
     if batch_size is not None:
@@ -196,6 +197,7 @@ def _context_init(
         num_instructions_to_generate=num_instructions_to_generate,
         checkpoint_dir=checkpoint_dir,
         save_freq=save_freq,
+        max_num_tokens=max_num_tokens,
         **extra_kwargs,
     )

@@ -281,6 +283,7 @@ def generate_data(
     pipeline: Optional[str] = "simple",
     batch_size: Optional[int] = None,
     checkpoint_dir: Optional[str] = None,
+    max_num_tokens: Optional[int] = 4096,
 ) -> None:
     """Generate data for training and testing a model.
@@ -343,6 +346,7 @@ def generate_data(
         1,  # save_freq
         batch_size=batch_size,
         batch_num_workers=num_cpus,
+        max_num_tokens=max_num_tokens,
     )

     knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(

src/instructlab/sdg/llmblock.py (20 additions, 3 deletions)

@@ -16,6 +16,10 @@

 logger = logging.getLogger(__name__)

+DEFAULT_MAX_NUM_TOKENS = 4096
+
+DEFAULT_NUM_SAMPLES = 30
+
 MODEL_FAMILY_MIXTRAL = "mixtral"
 MODEL_FAMILY_MERLINITE = "merlinite"

@@ -62,6 +66,7 @@ def __init__(
         ctx,
         pipe,
         block_name,
+        max_num_tokens,
         config_path,
         output_cols,
         model_prompt=None,
@@ -81,8 +86,18 @@ def __init__(
         self.parser_name = parser_kwargs.get("parser_name", None)
         self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
         self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
+        # max_num_tokens should only be applicable to knowledge blocks;
+        # gen_knowledge is the full/simple pipeline's knowledge generation block
+        if block_name != "gen_knowledge":
+            logger.debug(
+                f"Not applying max_num_tokens to block {block_name}. This is only applicable for gen_knowledge."
+            )
+            max_num_tokens = DEFAULT_MAX_NUM_TOKENS
         self.gen_kwargs = self._gen_kwargs(
-            gen_kwargs, model=self.ctx.model_id, temperature=0, max_tokens=4096
+            gen_kwargs,
+            model=self.ctx.model_id,
+            temperature=0,
+            max_tokens=max_num_tokens,
         )
         # Whether the LLM server supports a list of input prompts
         # and supports the n parameter to generate n outputs per input
@@ -150,10 +165,10 @@ def _gen_kwargs(self, gen_kwargs, **defaults):
             and gen_kwargs["n"] == "scaled"
         ):
             gen_kwargs["n"] = self.ctx.num_instructions_to_generate
-        if "max_tokens" in gen_kwargs:
-            gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
         if "temperature" in gen_kwargs:
             gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
+        if "max_tokens" in gen_kwargs:
+            gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
         return gen_kwargs

     def _generate(self, samples) -> list:
@@ -259,6 +274,7 @@ def __init__(
         ctx,
         pipe,
         block_name,
+        max_num_tokens,
         config_paths,
         output_cols,
         selector_column_name,
@@ -271,6 +287,7 @@
             ctx,
             pipe,
             block_name,
+            max_num_tokens,
             config_paths[0][0],
             output_cols,
             model_prompt=model_prompt,
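
To make the effect of the new gen_knowledge guard concrete, here is a small standalone sketch (not part of the diff) that mirrors the logic added to LLMBlock.__init__; the skill block name in the last line is hypothetical:

    DEFAULT_MAX_NUM_TOKENS = 4096

    def effective_max_tokens(block_name: str, max_num_tokens: int) -> int:
        # Mirrors the guard above: only the gen_knowledge block honors the
        # user-supplied cap; every other block falls back to the default.
        if block_name != "gen_knowledge":
            return DEFAULT_MAX_NUM_TOKENS
        return max_num_tokens

    assert effective_max_tokens("gen_knowledge", 512) == 512
    assert effective_max_tokens("gen_skill_freeform", 512) == 4096  # hypothetical block name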

src/instructlab/sdg/pipeline.py (13 additions, 2 deletions)

@@ -48,6 +48,7 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
         central executor pool.
     dataset_num_procs: The number of processes to use when performing parallel
         map operations on individual datasets.
+    max_num_tokens: the maximum number of tokens to generate per sample.
     """

     # The default batch size of 8 has been determined as a good default for
@@ -65,6 +66,7 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
     dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
     checkpoint_dir: Optional[str] = None
     save_freq: Optional[int] = 1
+    max_num_tokens: Optional[int] = 4096
     batch_size: int = DEFAULT_BATCH_SIZE
     batch_num_workers: Optional[int] = None

@@ -191,11 +193,20 @@ def _generate_single(self, dataset) -> Dataset:
                 block_name = block_prop["name"]
                 block_type = _lookup_block_type(block_prop["type"])
                 block_config = block_prop["config"]
+                max_num_tokens = self.ctx.max_num_tokens
                 drop_columns = block_prop.get("drop_columns", [])
                 drop_duplicates_cols = block_prop.get("drop_duplicates", False)
-                block = block_type(self.ctx, self, block_name, **block_config)
+                if block_type in (llmblock.LLMBlock, llmblock.ConditionalLLMBlock):
+                    block = block_type(
+                        self.ctx,
+                        self,
+                        block_name,
+                        max_num_tokens,
+                        **block_config,
+                    )
+                else:
+                    block = block_type(self.ctx, self, block_name, **block_config)
                 logger.info("Running block: %s", block_name)

                 # Execute the block and wrap errors with the block name/type
                 dataset = block.generate(dataset)
             except Exception as err:
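
For reference, a minimal sketch of constructing a context with the new field (only num_instructions_to_generate and max_num_tokens appear in this diff; the client, model_family, and model_id values are assumed placeholders):

    from openai import OpenAI

    from instructlab.sdg.pipeline import PipelineContext

    ctx = PipelineContext(
        client=OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY"),  # assumed endpoint
        model_family="mixtral",                                               # assumed
        model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",                      # assumed
        num_instructions_to_generate=30,
        max_num_tokens=512,  # forwarded by _generate_single to LLM blocks as max_tokens
    )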

tests/test_llmblock.py (3 additions, 0 deletions)

@@ -37,6 +37,7 @@ def test_model_prompt_empty_string(self, mock_load_config):
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
             block_name="test_block",
+            max_num_tokens=2048,
             config_path="",
             output_cols=[],
             model_prompt="",
@@ -57,6 +58,7 @@ def test_model_prompt_none(self, mock_load_config):
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
             block_name="test_block",
+            max_num_tokens=2048,
             config_path="",
             output_cols=[],
             model_prompt=None,  # Or simply omit model_prompt as it defaults to None
@@ -76,6 +78,7 @@
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
             block_name="test_block",
+            max_num_tokens=2048,
             config_path="",
             output_cols=[],
             model_prompt="FOO {prompt} BAR",
