fix: upsample the phase10 knowledge dataset
When we mix the knowledge dataset with the skills dataset today, we do not account for the potential discrepancy
in size between the generated knowledge data and the skills data. As a result, a model can forget what it
learned during the knowledge phase once it is trained on the mixed data. As a simple workaround, we upsample
the knowledge samples before mixing them in with the generated skills dataset (a rough sketch of the idea
follows the changed-files summary below).

Signed-off-by: Oleg S <[email protected]>
RobotSail committed Nov 14, 2024
1 parent b6f07a8 commit 6809fad
Showing 3 changed files with 21 additions and 2 deletions.
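
Conceptually, the change amounts to something like the sketch below. This is a minimal illustration, not the repository's actual mixing code: it assumes Hugging Face `datasets`, and `upsample_knowledge` and `mix_phase10` are hypothetical helper names.

    # A minimal sketch of the upsampling idea, not the actual datamixing
    # implementation. upsample_knowledge and mix_phase10 are hypothetical names.
    from datasets import Dataset, concatenate_datasets

    def upsample_knowledge(knowledge: Dataset, target_size: int) -> Dataset:
        """Repeat knowledge rows until the dataset holds target_size rows."""
        if len(knowledge) >= target_size:
            return knowledge
        copies = -(-target_size // len(knowledge))  # ceiling division
        repeated = concatenate_datasets([knowledge] * copies)
        return repeated.select(range(target_size))

    def mix_phase10(knowledge: Dataset, skills: Dataset, target_size: int) -> Dataset:
        """Mix the (upsampled) knowledge data into the skills dataset."""
        return concatenate_datasets([upsample_knowledge(knowledge, target_size), skills])

With the default target of 5000 (DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT, added below), a small knowledge set is repeated until it carries meaningful weight in the mixed phase10 data.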
3 changes: 3 additions & 0 deletions src/instructlab/sdg/datamixing.py
@@ -547,6 +547,7 @@ def __init__(
         date_suffix,
         sys_prompt,
         num_procs,
+        upsample_amount: int,
         auxiliary_inst=None,
     ):
         self.data_dirs = data_dirs
@@ -555,6 +556,7 @@ def __init__(
         self.date_suffix = date_suffix
         self.num_procs = num_procs
         self.auxiliary_inst = auxiliary_inst
+        self.upsample_amount = upsample_amount
 
         self.knowledge_recipe = self._load_default_recipe("knowledge.yaml")
         self.skills_recipe = self._load_default_recipe("skills.yaml")
@@ -619,6 +621,7 @@ def collect(
                 skills_phase_data,
                 self.skills_recipe,
                 output_file_leaf_skills,
+                sampling_size=self.upsample_amount,
             )
         else:
             messages = new_generated_data.map(
19 changes: 17 additions & 2 deletions src/instructlab/sdg/generate_data.py
@@ -21,6 +21,7 @@
 from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack
 from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
 from instructlab.sdg.llmblock import (
+    DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
     DEFAULT_MAX_NUM_TOKENS,
     MODEL_FAMILY_MERLINITE,
     MODEL_FAMILY_MIXTRAL,
@@ -254,7 +255,14 @@ def load_pipeline(yaml_basename):
     )
 
 
-def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_prompt):
+def _mixer_init(
+    ctx,
+    output_dir,
+    date_suffix,
+    knowledge_auxiliary_inst,
+    system_prompt,
+    upsample_amount: int,
+):
     data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")]
     data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs())
 
@@ -264,6 +272,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_p
         date_suffix,
         system_prompt,
         ctx.dataset_num_procs,
+        upsample_amount,
         knowledge_auxiliary_inst,
     )
 
@@ -295,6 +304,7 @@ def generate_data(
     batch_size: Optional[int] = None,
     checkpoint_dir: Optional[str] = None,
     max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
+    upsample_amount: Optional[int] = DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT,
 ) -> None:
     """Generate data for training and testing a model.
@@ -372,7 +382,12 @@
     mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)
 
     mixer = _mixer_init(
-        ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst, system_prompt
+        ctx,
+        output_dir,
+        date_suffix,
+        knowledge_pipe.auxiliary_inst,
+        system_prompt,
+        upsample_amount,
     )
 
     if console_output:
1 change: 1 addition & 0 deletions src/instructlab/sdg/llmblock.py
@@ -18,6 +18,7 @@
 logger = logging.getLogger(__name__)
 
 DEFAULT_MAX_NUM_TOKENS = 4096
+DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT = 5000
 
 MODEL_FAMILY_MIXTRAL = "mixtral"
 MODEL_FAMILY_MERLINITE = "merlinite"
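
For intuition on why the 5000 default matters, here is a back-of-the-envelope illustration; the row counts below are hypothetical, and only the default value comes from this commit.

    # Hypothetical numbers: only the 5000 default is from this commit.
    knowledge_rows = 800     # assumed size of the generated knowledge data
    skills_rows = 40_000     # assumed size of the generated skills data

    before = knowledge_rows / (knowledge_rows + skills_rows)
    after = 5000 / (5000 + skills_rows)  # DEFAULT_KNOWLEDGE_UPSAMPLE_AMOUNT
    print(f"knowledge share before: {before:.1%}, after: {after:.1%}")
    # -> knowledge share before: 2.0%, after: 11.1%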
