fix: upsample the phase10 knowledge dataset
When we mix the knowledge dataset with the skills data today, we do not account for the potential
discrepancy in size between the generated knowledge data and the skills data. This can lead to the
model forgetting the data it was trained on during the knowledge phase. As a workaround, we
upsample the knowledge samples before mixing them into the generated skills dataset.
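To illustrate the imbalance and the fix, here is a minimal, self-contained sketch using the `datasets` library. The sizes, the `messages` column, and the deterministic repetition are all hypothetical; the commit's actual helper samples randomly with replacement (see `upsample_dataset` below).

```python
# Hypothetical sizes: a small knowledge set next to a much larger skills set.
from datasets import Dataset, concatenate_datasets

knowledge = Dataset.from_dict({"messages": [f"k{i}" for i in range(100)]})
skills = Dataset.from_dict({"messages": [f"s{i}" for i in range(10_000)]})

# Naive mix: knowledge is ~1% of the result and easily forgotten in training.
naive_mix = concatenate_datasets([knowledge, skills])

# Upsampled mix: repeat knowledge rows up to a floor before mixing.
floor = 5_000
upsampled = knowledge.select([i % len(knowledge) for i in range(floor)])
mixed = concatenate_datasets([upsampled, skills])

print(len(naive_mix), len(mixed))  # 10100 15000
```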

Signed-off-by: Oleg S <[email protected]>
RobotSail committed Nov 13, 2024
1 parent b6f07a8 commit 842b43f
Showing 3 changed files with 46 additions and 2 deletions.
28 changes: 28 additions & 0 deletions src/instructlab/sdg/datamixing.py
@@ -528,6 +528,27 @@ def _convert_to_leaf_node_messages(sample: dict, sys_prompt: str):
     return sample


+def upsample_dataset(ds: Dataset, num_samples: int) -> Dataset:
+    """
+    Given a `Dataset`, upsample it such that the resulting dataset
+    is equal to or greater than `num_samples` in length.
+
+    Args:
+        ds (Dataset): The dataset to upsample.
+        num_samples (int): The minimum number of samples in the resulting dataset.
+
+    Returns:
+        Dataset: The resulting dataset.
+    """
+    if len(ds) >= num_samples:
+        return ds
+
+    pd_ds = ds.to_pandas()
+    pd_ds = pd_ds.sample(num_samples, replace=True)
+    new_ds = Dataset.from_pandas(pd_ds, preserve_index=False)
+    return new_ds
+
+
 class DataMixer:
     # pylint: disable=too-many-instance-attributes
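A short usage sketch of the new helper, assuming it is importable from `instructlab.sdg.datamixing`; the tiny dataset is made up:

```python
from datasets import Dataset

from instructlab.sdg.datamixing import upsample_dataset

tiny = Dataset.from_dict({"id": list(range(10))})

# Below the target size: rows are re-sampled with replacement up to num_samples.
print(len(upsample_dataset(tiny, num_samples=50)))  # 50

# At or above the target size: the dataset is returned unchanged.
print(len(upsample_dataset(tiny, num_samples=5)))   # 10
```

Note that `preserve_index=False` keeps the shuffled pandas index from being carried into the upsampled dataset as an extra `__index_level_0__` column.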

@@ -547,6 +568,7 @@ def __init__(
         date_suffix,
         sys_prompt,
         num_procs,
+        upsample_amount: int,
         auxiliary_inst=None,
     ):
         self.data_dirs = data_dirs
@@ -555,6 +577,7 @@ def __init__(
         self.date_suffix = date_suffix
         self.num_procs = num_procs
         self.auxiliary_inst = auxiliary_inst
+        self.upsample_amount = upsample_amount

         self.knowledge_recipe = self._load_default_recipe("knowledge.yaml")
         self.skills_recipe = self._load_default_recipe("skills.yaml")
@@ -612,6 +635,11 @@ def collect(
         skills_phase_data = _create_phase10_ds(
             new_generated_data, self.auxiliary_inst, use_legacy_pretraining_format
         )
+        # XXX(osilkin): To keep the skills dataset from drowning out the knowledge
+        # samples, we upsample the knowledge data here to avoid catastrophic forgetting
+        skills_phase_data = upsample_dataset(
+            skills_phase_data, self.upsample_amount
+        )
         output_file_leaf_skills = (
             f"node_datasets_{self.date_suffix}/{leaf_node_path}_p10.jsonl"
         )

19 changes: 17 additions & 2 deletions src/instructlab/sdg/generate_data.py
@@ -22,6 +22,7 @@
 from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
 from instructlab.sdg.llmblock import (
+    DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
     DEFAULT_MAX_NUM_TOKENS,
     MODEL_FAMILY_MERLINITE,
     MODEL_FAMILY_MIXTRAL,
 )
@@ -254,7 +255,14 @@ def load_pipeline(yaml_basename):
     )


-def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_prompt):
+def _mixer_init(
+    ctx,
+    output_dir,
+    date_suffix,
+    knowledge_auxiliary_inst,
+    system_prompt,
+    upsample_amount: int,
+):
     data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")]
     data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs())

@@ -264,6 +272,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_prompt):
         date_suffix,
         system_prompt,
         ctx.dataset_num_procs,
+        upsample_amount,
         knowledge_auxiliary_inst,
     )

@@ -295,6 +304,7 @@ def generate_data(
     batch_size: Optional[int] = None,
     checkpoint_dir: Optional[str] = None,
     max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
+    upsample_amount: Optional[int] = DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
 ) -> None:
     """Generate data for training and testing a model.
@@ -372,7 +382,12 @@
     mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)

     mixer = _mixer_init(
-        ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst, system_prompt
+        ctx,
+        output_dir,
+        date_suffix,
+        knowledge_pipe.auxiliary_inst,
+        system_prompt,
+        upsample_amount,
     )

     if console_output:

1 change: 1 addition & 0 deletions src/instructlab/sdg/llmblock.py
@@ -18,6 +18,7 @@
 logger = logging.getLogger(__name__)

 DEFAULT_MAX_NUM_TOKENS = 4096
+DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT = 5000

 MODEL_FAMILY_MIXTRAL = "mixtral"
 MODEL_FAMILY_MERLINITE = "merlinite"
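To see what the new default means in practice, here is a small sketch of the floor arithmetic; the leaf-node sizes are invented for illustration:

```python
# A knowledge set below the floor grows to exactly the floor; anything at or
# above it passes through unchanged (mirrors upsample_dataset above).
DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT = 5000

for generated in (800, 5_000, 12_000):
    mixed_in = max(generated, DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT)
    print(f"{generated:>6} knowledge samples -> {mixed_in:>6} after upsampling")
```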
