@@ -127,7 +127,15 @@ def compile(
127127 if self .max_errors is not None
128128 else dspy .settings .max_errors
129129 )
130- zeroshot_opt = (self .max_bootstrapped_demos == 0 ) and (self .max_labeled_demos == 0 )
130+
131+ effective_max_bootstrapped_demos = (
132+ max_bootstrapped_demos if max_bootstrapped_demos is not None else self .max_bootstrapped_demos
133+ )
134+ effective_max_labeled_demos = (
135+ max_labeled_demos if max_labeled_demos is not None else self .max_labeled_demos
136+ )
137+
138+ zeroshot_opt = (effective_max_bootstrapped_demos == 0 ) and (effective_max_labeled_demos == 0 )
131139
132140 # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value
133141 if self .auto is None and (self .num_candidates is not None and num_trials is None ):
@@ -149,22 +157,46 @@ def compile(
149157 seed = seed or self .seed
150158 self ._set_random_seeds (seed )
151159
152- # Update max demos if specified
153- if max_bootstrapped_demos is not None :
154- self .max_bootstrapped_demos = max_bootstrapped_demos
155- if max_labeled_demos is not None :
156- self .max_labeled_demos = max_labeled_demos
157160
158161 # Set training & validation sets
159162 trainset , valset = self ._set_and_validate_datasets (trainset , valset )
160163
164+ num_instruct_candidates = (
165+ self .num_instruct_candidates
166+ if self .num_instruct_candidates is not None
167+ else self .num_candidates
168+ )
169+ num_fewshot_candidates = (
170+ self .num_fewshot_candidates
171+ if self .num_fewshot_candidates is not None
172+ else self .num_candidates
173+ )
174+
161175 # Set hyperparameters based on run mode (if set)
162- num_trials , valset , minibatch = self ._set_hyperparams_from_run_mode (
163- student , num_trials , minibatch , zeroshot_opt , valset
176+ (
177+ num_trials ,
178+ valset ,
179+ minibatch ,
180+ num_instruct_candidates ,
181+ num_fewshot_candidates ,
182+ ) = self ._set_hyperparams_from_run_mode (
183+ student ,
184+ num_trials ,
185+ minibatch ,
186+ zeroshot_opt ,
187+ valset ,
188+ num_instruct_candidates ,
189+ num_fewshot_candidates ,
164190 )
165191
166192 if self .auto :
167- self ._print_auto_run_settings (num_trials , minibatch , valset )
193+ self ._print_auto_run_settings (
194+ num_trials ,
195+ minibatch ,
196+ valset ,
197+ num_fewshot_candidates ,
198+ num_instruct_candidates ,
199+ )
168200
169201 if minibatch and minibatch_size > len (valset ):
170202 raise ValueError (f"Minibatch size cannot exceed the size of the valset. Valset size: { len (valset )} ." )
@@ -183,7 +215,17 @@ def compile(
183215
184216 with dspy .context (lm = self .task_model ):
185217 # Step 1: Bootstrap few-shot examples
186- demo_candidates = self ._bootstrap_fewshot_examples (program , trainset , seed , teacher )
218+ demo_candidates = self ._bootstrap_fewshot_examples (
219+ program ,
220+ trainset ,
221+ seed ,
222+ teacher ,
223+ num_fewshot_candidates = num_fewshot_candidates ,
224+ max_bootstrapped_demos = effective_max_bootstrapped_demos ,
225+ max_labeled_demos = effective_max_labeled_demos ,
226+ max_errors = effective_max_errors ,
227+ metric_threshold = self .metric_threshold ,
228+ )
187229
188230 # Step 2: Propose instruction candidates
189231 instruction_candidates = self ._propose_instructions (
@@ -195,6 +237,7 @@ def compile(
195237 data_aware_proposer ,
196238 tip_aware_proposer ,
197239 fewshot_aware_proposer ,
240+ num_instruct_candidates = num_instruct_candidates ,
198241 )
199242
200243 # If zero-shot, discard demos
@@ -234,13 +277,17 @@ def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candida
234277 def _set_hyperparams_from_run_mode (
235278 self ,
236279 program : Any ,
237- num_trials : int ,
280+ num_trials : int | None ,
238281 minibatch : bool ,
239282 zeroshot_opt : bool ,
240283 valset : list ,
241- ) -> tuple [int , list , bool ]:
284+ num_instruct_candidates : int | None ,
285+ num_fewshot_candidates : int | None ,
286+ ) -> tuple [int , list , bool , int , int ]:
242287 if self .auto is None :
243- return num_trials , valset , minibatch
288+ if num_instruct_candidates is None or num_fewshot_candidates is None :
289+ raise ValueError ("num_candidates must be provided when auto is None." )
290+ return num_trials , valset , minibatch , num_instruct_candidates , num_fewshot_candidates
244291
245292 auto_settings = AUTO_RUN_SETTINGS [self .auto ]
246293
@@ -250,12 +297,12 @@ def _set_hyperparams_from_run_mode(
250297 # Set num instruct candidates to 1/2 of N if optimizing with few-shot examples, otherwise set to N
251298 # This is because we've found that it's generally better to spend optimization budget on few-shot examples
252299 # When they are allowed.
253- self . num_instruct_candidates = auto_settings ["n" ] if zeroshot_opt else int (auto_settings ["n" ] * 0.5 )
254- self . num_fewshot_candidates = auto_settings ["n" ]
300+ num_instruct_candidates = auto_settings ["n" ] if zeroshot_opt else int (auto_settings ["n" ] * 0.5 )
301+ num_fewshot_candidates = auto_settings ["n" ]
255302
256303 num_trials = self ._set_num_trials_from_num_candidates (program , zeroshot_opt , auto_settings ["n" ])
257304
258- return num_trials , valset , minibatch
305+ return num_trials , valset , minibatch , num_instruct_candidates , num_fewshot_candidates
259306
260307 def _set_and_validate_datasets (self , trainset : list , valset : list | None ):
261308 if not trainset :
@@ -274,13 +321,20 @@ def _set_and_validate_datasets(self, trainset: list, valset: list | None):
274321
275322 return trainset , valset
276323
277- def _print_auto_run_settings (self , num_trials : int , minibatch : bool , valset : list ):
324+ def _print_auto_run_settings (
325+ self ,
326+ num_trials : int ,
327+ minibatch : bool ,
328+ valset : list ,
329+ num_fewshot_candidates : int ,
330+ num_instruct_candidates : int ,
331+ ):
278332 logger .info (
279333 f"\n RUNNING WITH THE FOLLOWING { self .auto .upper ()} AUTO RUN SETTINGS:"
280334 f"\n num_trials: { num_trials } "
281335 f"\n minibatch: { minibatch } "
282- f"\n num_fewshot_candidates: { self . num_fewshot_candidates } "
283- f"\n num_instruct_candidates: { self . num_instruct_candidates } "
336+ f"\n num_fewshot_candidates: { num_fewshot_candidates } "
337+ f"\n num_instruct_candidates: { num_instruct_candidates } "
284338 f"\n valset size: { len (valset )} \n "
285339 )
286340
@@ -293,18 +347,19 @@ def _estimate_lm_calls(
293347 minibatch_full_eval_steps : int ,
294348 valset : list ,
295349 program_aware_proposer : bool ,
350+ num_instruct_candidates : int ,
296351 ) -> tuple [str , str ]:
297352 num_predictors = len (program .predictors ())
298353
299354 # Estimate prompt model calls
300355 estimated_prompt_model_calls = (
301356 10 # Data summarizer calls
302- + self . num_instruct_candidates * num_predictors # Candidate generation
357+ + num_instruct_candidates * num_predictors # Candidate generation
303358 + (num_predictors + 1 if program_aware_proposer else 0 ) # Program-aware proposer
304359 )
305360 prompt_model_line = (
306361 f"{ YELLOW } - Prompt Generation: { BLUE } { BOLD } 10{ ENDC } { YELLOW } data summarizer calls + "
307- f"{ BLUE } { BOLD } { self . num_instruct_candidates } { ENDC } { YELLOW } * "
362+ f"{ BLUE } { BOLD } { num_instruct_candidates } { ENDC } { YELLOW } * "
308363 f"{ BLUE } { BOLD } { num_predictors } { ENDC } { YELLOW } lm calls in program "
309364 f"+ ({ BLUE } { BOLD } { num_predictors + 1 } { ENDC } { YELLOW } ) lm calls in program-aware proposer "
310365 f"= { BLUE } { BOLD } { estimated_prompt_model_calls } { ENDC } { YELLOW } prompt model calls{ ENDC } "
@@ -331,38 +386,48 @@ def _estimate_lm_calls(
331386
332387 return prompt_model_line , task_model_line
333388
334- def _bootstrap_fewshot_examples (self , program : Any , trainset : list , seed : int , teacher : Any ) -> list | None :
389+ def _bootstrap_fewshot_examples (
390+ self ,
391+ program : Any ,
392+ trainset : list ,
393+ seed : int ,
394+ teacher : Any ,
395+ * ,
396+ num_fewshot_candidates : int ,
397+ max_bootstrapped_demos : int ,
398+ max_labeled_demos : int ,
399+ max_errors : int | None ,
400+ metric_threshold : float | None ,
401+ ) -> list | None :
335402 logger .info ("\n ==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==" )
336- if self . max_bootstrapped_demos > 0 :
403+ if max_bootstrapped_demos > 0 :
337404 logger .info (
338405 "These will be used as few-shot example candidates for our program and for creating instructions.\n "
339406 )
340407 else :
341408 logger .info ("These will be used for informing instruction proposal.\n " )
342409
343- logger .info (f"Bootstrapping N={ self . num_fewshot_candidates } sets of demonstrations..." )
410+ logger .info (f"Bootstrapping N={ num_fewshot_candidates } sets of demonstrations..." )
344411
345- zeroshot = self . max_bootstrapped_demos == 0 and self . max_labeled_demos == 0
412+ zeroshot = max_bootstrapped_demos == 0 and max_labeled_demos == 0
346413
347- # try:
348- effective_max_errors = (
349- self .max_errors if self .max_errors is not None else dspy .settings .max_errors
350- )
414+ if max_errors is None :
415+ max_errors = dspy .settings .max_errors
351416
352417 demo_candidates = create_n_fewshot_demo_sets (
353418 student = program ,
354- num_candidate_sets = self . num_fewshot_candidates ,
419+ num_candidate_sets = num_fewshot_candidates ,
355420 trainset = trainset ,
356- max_labeled_demos = (LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self . max_labeled_demos ),
421+ max_labeled_demos = (LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_labeled_demos ),
357422 max_bootstrapped_demos = (
358- BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self . max_bootstrapped_demos
423+ BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_bootstrapped_demos
359424 ),
360425 metric = self .metric ,
361- max_errors = effective_max_errors ,
426+ max_errors = max_errors ,
362427 teacher = teacher ,
363428 teacher_settings = self .teacher_settings ,
364429 seed = seed ,
365- metric_threshold = self . metric_threshold ,
430+ metric_threshold = metric_threshold ,
366431 rng = self .rng ,
367432 )
368433 # NOTE: Bootstrapping is essential to MIPRO!
@@ -384,6 +449,7 @@ def _propose_instructions(
384449 data_aware_proposer : bool ,
385450 tip_aware_proposer : bool ,
386451 fewshot_aware_proposer : bool ,
452+ num_instruct_candidates : int ,
387453 ) -> dict [int , list [str ]]:
388454 logger .info ("\n ==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==" )
389455 logger .info (
@@ -408,12 +474,12 @@ def _propose_instructions(
408474 init_temperature = self .init_temperature ,
409475 )
410476
411- logger .info (f"\n Proposing N={ self . num_instruct_candidates } instructions...\n " )
477+ logger .info (f"\n Proposing N={ num_instruct_candidates } instructions...\n " )
412478 instruction_candidates = proposer .propose_instructions_for_program (
413479 trainset = trainset ,
414480 program = program ,
415481 demo_candidates = demo_candidates ,
416- N = self . num_instruct_candidates ,
482+ N = num_instruct_candidates ,
417483 trial_logs = {},
418484 )
419485
0 commit comments