From 842b43ff9bea45e17b068e1fefdbc89d145ffb95 Mon Sep 17 00:00:00 2001
From: Oleg S <97077423+RobotSail@users.noreply.github.com>
Date: Wed, 13 Nov 2024 17:20:53 -0500
Subject: [PATCH] fix: upsample the phase10 knowledge dataset

When we mix the knowledge dataset with skills today, we do not account for
the potential discrepancy in size between the generated knowledge data and
the generated skills data. This can lead to the model forgetting the data it
was trained on during the knowledge phase. As a workaround, we upsample the
knowledge samples before mixing them into the generated skills dataset.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
---
 src/instructlab/sdg/datamixing.py    | 28 ++++++++++++++++++++++++++++
 src/instructlab/sdg/generate_data.py | 19 +++++++++++++++++--
 src/instructlab/sdg/llmblock.py      |  1 +
 3 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/src/instructlab/sdg/datamixing.py b/src/instructlab/sdg/datamixing.py
index 5172fdfb..8621aa73 100644
--- a/src/instructlab/sdg/datamixing.py
+++ b/src/instructlab/sdg/datamixing.py
@@ -528,6 +528,27 @@ def _convert_to_leaf_node_messages(sample: dict, sys_prompt: str):
     return sample
 
 
+def upsample_dataset(ds: Dataset, num_samples: int) -> Dataset:
+    """
+    Given a `Dataset`, upsample it so that the resulting dataset
+    contains at least `num_samples` samples.
+
+    Args:
+        ds (Dataset): The dataset to upsample.
+        num_samples (int): The minimum number of samples the resulting dataset should contain.
+
+    Returns:
+        Dataset: The upsampled dataset.
+    """
+    if len(ds) >= num_samples:
+        return ds
+
+    pd_ds = ds.to_pandas()
+    pd_ds = pd_ds.sample(num_samples, replace=True)
+    new_ds = Dataset.from_pandas(pd_ds)
+    return new_ds
+
+
 class DataMixer:
     # pylint: disable=too-many-instance-attributes
 
@@ -547,6 +568,7 @@ def __init__(
         date_suffix,
         sys_prompt,
         num_procs,
+        upsample_amount: int,
         auxiliary_inst=None,
     ):
         self.data_dirs = data_dirs
@@ -555,6 +577,7 @@ def __init__(
         self.date_suffix = date_suffix
         self.num_procs = num_procs
         self.auxiliary_inst = auxiliary_inst
+        self.upsample_amount = upsample_amount
 
         self.knowledge_recipe = self._load_default_recipe("knowledge.yaml")
         self.skills_recipe = self._load_default_recipe("skills.yaml")
@@ -612,6 +635,11 @@ def collect(
             skills_phase_data = _create_phase10_ds(
                 new_generated_data, self.auxiliary_inst, use_legacy_pretraining_format
             )
+            # XXX(osilkin): To keep the skills dataset from drowning out the knowledge
+            # samples, we upsample the knowledge data here to avoid catastrophic forgetting
+            skills_phase_data = upsample_dataset(
+                skills_phase_data, self.upsample_amount
+            )
             output_file_leaf_skills = (
                 f"node_datasets_{self.date_suffix}/{leaf_node_path}_p10.jsonl"
             )
diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index cf65ae14..d90132ea 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -22,6 +22,7 @@
 from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
 from instructlab.sdg.llmblock import (
     DEFAULT_MAX_NUM_TOKENS,
+    DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
     MODEL_FAMILY_MERLINITE,
     MODEL_FAMILY_MIXTRAL,
 )
@@ -254,7 +255,14 @@ def load_pipeline(yaml_basename):
     )
 
 
-def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_prompt):
+def _mixer_init(
+    ctx,
+    output_dir,
+    date_suffix,
+    knowledge_auxiliary_inst,
+    system_prompt,
+    upsample_amount: int,
+):
     data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")]
     data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs())
 
@@ -264,6 +272,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_p
         date_suffix,
         system_prompt,
         ctx.dataset_num_procs,
+        upsample_amount,
         knowledge_auxiliary_inst,
     )
 
@@ -295,6 +304,7 @@ def generate_data(
     batch_size: Optional[int] = None,
     checkpoint_dir: Optional[str] = None,
     max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
+    upsample_amount: Optional[int] = DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
 ) -> None:
     """Generate data for training and testing a model.
 
@@ -372,7 +382,12 @@ def generate_data(
         mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)
 
     mixer = _mixer_init(
-        ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst, system_prompt
+        ctx,
+        output_dir,
+        date_suffix,
+        knowledge_pipe.auxiliary_inst,
+        system_prompt,
+        upsample_amount,
     )
 
     if console_output:
diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py
index 0e9a5f22..80cd2437 100644
--- a/src/instructlab/sdg/llmblock.py
+++ b/src/instructlab/sdg/llmblock.py
@@ -18,6 +18,7 @@
 logger = logging.getLogger(__name__)
 
 DEFAULT_MAX_NUM_TOKENS = 4096
+DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT = 5000
 
 MODEL_FAMILY_MIXTRAL = "mixtral"
 MODEL_FAMILY_MERLINITE = "merlinite"
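
Not part of the patch itself: below is a minimal, self-contained sketch of the
sampling-with-replacement approach the new upsample_dataset() helper takes. The
toy column names, the fixed random_state, and the reset_index() call are
illustrative assumptions added for the example, not code from this change.

# Minimal sketch of the upsampling approach (assumes the `datasets` and
# `pandas` packages are installed; column names below are made up).
from datasets import Dataset


def upsample_dataset(ds: Dataset, num_samples: int) -> Dataset:
    # Mirrors the helper added to datamixing.py: if the dataset is already
    # large enough, leave it alone; otherwise sample rows with replacement
    # until it holds `num_samples` rows.
    if len(ds) >= num_samples:
        return ds
    pd_ds = ds.to_pandas()
    # replace=True lets rows repeat; reset_index keeps from_pandas from
    # storing the sampled (non-sequential) index as an extra column.
    pd_ds = pd_ds.sample(num_samples, replace=True, random_state=0)
    pd_ds = pd_ds.reset_index(drop=True)
    return Dataset.from_pandas(pd_ds)


knowledge = Dataset.from_dict({"id": [1, 2, 3], "messages": ["a", "b", "c"]})
upsampled = upsample_dataset(knowledge, 10)
print(len(knowledge), "->", len(upsampled))  # 3 -> 10

With the default DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT of 5000, a leaf node whose
phase-10 knowledge dataset comes out of generation with only a few hundred rows
is repeated (with replacement) up to 5000 rows before being mixed into the
skills recipe; leaf nodes already at or above the threshold are left untouched.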