diff --git a/src/instructlab/sdg/datamixing.py b/src/instructlab/sdg/datamixing.py
index 5172fdfb..de38dfef 100644
--- a/src/instructlab/sdg/datamixing.py
+++ b/src/instructlab/sdg/datamixing.py
@@ -547,6 +547,7 @@ def __init__(
         date_suffix,
         sys_prompt,
         num_procs,
+        upsample_amount: int,
         auxiliary_inst=None,
     ):
         self.data_dirs = data_dirs
@@ -555,6 +556,7 @@ def __init__(
         self.date_suffix = date_suffix
         self.num_procs = num_procs
         self.auxiliary_inst = auxiliary_inst
+        self.upsample_amount = upsample_amount
 
         self.knowledge_recipe = self._load_default_recipe("knowledge.yaml")
         self.skills_recipe = self._load_default_recipe("skills.yaml")
@@ -619,6 +621,7 @@ def collect(
                 skills_phase_data,
                 self.skills_recipe,
                 output_file_leaf_skills,
+                sampling_size=self.upsample_amount,
             )
         else:
             messages = new_generated_data.map(
diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index cf65ae14..90cc15ab 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -21,6 +21,7 @@
 from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack
 from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
 from instructlab.sdg.llmblock import (
+    DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
     DEFAULT_MAX_NUM_TOKENS,
     MODEL_FAMILY_MERLINITE,
     MODEL_FAMILY_MIXTRAL,
@@ -254,7 +255,14 @@ def load_pipeline(yaml_basename):
     )
 
 
-def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_prompt):
+def _mixer_init(
+    ctx,
+    output_dir,
+    date_suffix,
+    knowledge_auxiliary_inst,
+    system_prompt,
+    upsample_amount: int,
+):
     data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")]
     data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs())
 
@@ -264,6 +272,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_p
         date_suffix,
         system_prompt,
         ctx.dataset_num_procs,
+        upsample_amount,
         knowledge_auxiliary_inst,
     )
 
@@ -295,6 +304,7 @@ def generate_data(
     batch_size: Optional[int] = None,
     checkpoint_dir: Optional[str] = None,
     max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
+    upsample_amount: Optional[int] = DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
 ) -> None:
     """Generate data for training and testing a model.
 
@@ -372,7 +382,12 @@ def generate_data(
     mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)
 
     mixer = _mixer_init(
-        ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst, system_prompt
+        ctx,
+        output_dir,
+        date_suffix,
+        knowledge_pipe.auxiliary_inst,
+        system_prompt,
+        upsample_amount,
     )
 
     if console_output:
diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py
index 0e9a5f22..80cd2437 100644
--- a/src/instructlab/sdg/llmblock.py
+++ b/src/instructlab/sdg/llmblock.py
@@ -18,6 +18,7 @@
 logger = logging.getLogger(__name__)
 
 DEFAULT_MAX_NUM_TOKENS = 4096
+DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT = 5000
 
 MODEL_FAMILY_MIXTRAL = "mixtral"
 MODEL_FAMILY_MERLINITE = "merlinite"