Skip to content

Commit 1ac689a

Browse files
committed
update generate_data API calls
1 parent 09fdffb commit 1ac689a

File tree

5 files changed

+15
-14
lines changed

5 files changed

+15
-14
lines changed

scripts/test_freeform_skills.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
ds = Dataset.from_list(samples)
5151

52-
skills_flow = SynthSkillsFlow(client, "mixtral", teacher_model).get_flow()
52+
skills_flow = SynthSkillsFlow(client, "mixtral", teacher_model, 30).get_flow()
5353
skills_pipe = Pipeline(skills_flow)
5454

5555
sdg = SDG([skills_pipe])

scripts/test_grounded_skills.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797

9898
ds = Dataset.from_list(samples)
9999

100-
skills_flow = SynthGroundedSkillsFlow(client, "mixtral", teacher_model).get_flow()
100+
skills_flow = SynthGroundedSkillsFlow(client, "mixtral", teacher_model, 30).get_flow()
101101
skills_pipe = Pipeline(skills_flow)
102102

103103
sdg = SDG([skills_pipe])

scripts/test_knowledge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838

3939
ds = Dataset.from_list(samples)
4040

41-
mmlu_flow = MMLUBenchFlow(client, "mixtral", teacher_model).get_flow()
42-
knowledge_flow = SynthKnowledgeFlow(client, "mixtral", teacher_model).get_flow()
41+
mmlu_flow = MMLUBenchFlow(client, "mixtral", teacher_model, 30).get_flow()
42+
knowledge_flow = SynthKnowledgeFlow(client, "mixtral", teacher_model, 30).get_flow()
4343
knowledge_pipe = Pipeline(knowledge_flow)
4444
mmlu_pipe = Pipeline(mmlu_flow)
4545

src/instructlab/sdg/default_flows.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ def _get_model_prompt(model_family):
2929

3030

3131
class Flow(ABC):
32-
def __init__(self, client, model_family, model_id, batched=True) -> None:
32+
def __init__(self, client, model_family, model_id, num_instructions_to_generate, batched=True) -> None:
3333
self.client = client
3434
self.model_family = model_family
3535
self.model_id = model_id
36+
self.num_instructions_to_generate = num_instructions_to_generate
3637
self.batched = batched
3738

3839
@abstractmethod
@@ -60,7 +61,7 @@ def get_flow(self) -> list:
6061
"gen_kwargs": {
6162
"max_tokens": 2048,
6263
"temperature": 0.7,
63-
"n": 1
64+
"n": self.num_instructions_to_generate
6465
},
6566
"drop_duplicates": ["output"],
6667
}
@@ -280,7 +281,7 @@ def get_flow(self) -> list:
280281
"output_cols": ["question"],
281282
"batch_kwargs": {
282283
"num_procs": 8,
283-
"num_samples": 30,
284+
"num_samples": self.num_instructions_to_generate,
284285
"batched": self.batched,
285286
},
286287
},
@@ -375,16 +376,16 @@ def get_flow(self) -> list:
375376
"model_prompt": _get_model_prompt(self.model_family),
376377
"output_cols": ["context"],
377378
"batch_kwargs": {
378-
"num_samples": 30,
379379
"num_procs": 8,
380380
"batched": self.batched,
381381
}
382382
},
383383
"gen_kwargs": {
384384
"temperature": 0.7,
385385
"max_tokens": 2048,
386-
"n": 10
386+
"n": self.num_instructions_to_generate
387387
},
388+
"drop_duplicates": ["context"],
388389
},
389390
{
390391
"block_type": LLMBlock,
@@ -396,6 +397,7 @@ def get_flow(self) -> list:
396397
"model_prompt": _get_model_prompt(self.model_family),
397398
"output_cols": ["question"],
398399
"batch_kwargs": {
400+
"num_samples": 3,
399401
"num_procs": 8,
400402
"batched": self.batched,
401403
},
@@ -414,7 +416,6 @@ def get_flow(self) -> list:
414416
"batch_kwargs": {
415417
"num_procs": 8,
416418
"batched": self.batched,
417-
"num_samples": 10,
418419
},
419420
},
420421
},

src/instructlab/sdg/generate_data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _gen_test_data(
124124
outfile.write("\n")
125125

126126

127-
def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
127+
def _sdg_init(pipeline, client, model_family, model_name, num_instructions_to_generate, batched):
128128
knowledge_flow_types = []
129129
freeform_skill_flow_types = []
130130
grounded_skill_flow_types = []
@@ -144,7 +144,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
144144
[
145145
Pipeline(
146146
flow_type(
147-
client, model_family, model_name, num_iters, batched
147+
client, model_family, model_name, num_instructions_to_generate, batched
148148
).get_flow()
149149
)
150150
for flow_type in knowledge_flow_types
@@ -154,7 +154,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
154154
[
155155
Pipeline(
156156
flow_type(
157-
client, model_family, model_name, num_iters, batched
157+
client, model_family, model_name, num_instructions_to_generate, batched
158158
).get_flow()
159159
)
160160
for flow_type in freeform_skill_flow_types
@@ -164,7 +164,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
164164
[
165165
Pipeline(
166166
flow_type(
167-
client, model_family, model_name, num_iters, batched
167+
client, model_family, model_name, num_instructions_to_generate, batched
168168
).get_flow()
169169
)
170170
for flow_type in grounded_skill_flow_types

0 commit comments

Comments
 (0)