Skip to content

Commit 056d54e

Browse files
authored
fix(MIPROv2): zero shot not taking .compile parameters into account before determining if the program was zero shot (#8909)
* fix(MIPROv2): zero shot not taking .compile parameters into account before determining if the program was zero shot * remove extra logs * Remove log * Fix merge conflict * Remove extra whitespace
1 parent da69f9d commit 056d54e

File tree

1 file changed

+102
-36
lines changed

1 file changed

+102
-36
lines changed

dspy/teleprompt/mipro_optimizer_v2.py

Lines changed: 102 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,15 @@ def compile(
127127
if self.max_errors is not None
128128
else dspy.settings.max_errors
129129
)
130-
zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0)
130+
131+
effective_max_bootstrapped_demos = (
132+
max_bootstrapped_demos if max_bootstrapped_demos is not None else self.max_bootstrapped_demos
133+
)
134+
effective_max_labeled_demos = (
135+
max_labeled_demos if max_labeled_demos is not None else self.max_labeled_demos
136+
)
137+
138+
zeroshot_opt = (effective_max_bootstrapped_demos == 0) and (effective_max_labeled_demos == 0)
131139

132140
# If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value
133141
if self.auto is None and (self.num_candidates is not None and num_trials is None):
@@ -149,22 +157,46 @@ def compile(
149157
seed = seed or self.seed
150158
self._set_random_seeds(seed)
151159

152-
# Update max demos if specified
153-
if max_bootstrapped_demos is not None:
154-
self.max_bootstrapped_demos = max_bootstrapped_demos
155-
if max_labeled_demos is not None:
156-
self.max_labeled_demos = max_labeled_demos
157160

158161
# Set training & validation sets
159162
trainset, valset = self._set_and_validate_datasets(trainset, valset)
160163

164+
num_instruct_candidates = (
165+
self.num_instruct_candidates
166+
if self.num_instruct_candidates is not None
167+
else self.num_candidates
168+
)
169+
num_fewshot_candidates = (
170+
self.num_fewshot_candidates
171+
if self.num_fewshot_candidates is not None
172+
else self.num_candidates
173+
)
174+
161175
# Set hyperparameters based on run mode (if set)
162-
num_trials, valset, minibatch = self._set_hyperparams_from_run_mode(
163-
student, num_trials, minibatch, zeroshot_opt, valset
176+
(
177+
num_trials,
178+
valset,
179+
minibatch,
180+
num_instruct_candidates,
181+
num_fewshot_candidates,
182+
) = self._set_hyperparams_from_run_mode(
183+
student,
184+
num_trials,
185+
minibatch,
186+
zeroshot_opt,
187+
valset,
188+
num_instruct_candidates,
189+
num_fewshot_candidates,
164190
)
165191

166192
if self.auto:
167-
self._print_auto_run_settings(num_trials, minibatch, valset)
193+
self._print_auto_run_settings(
194+
num_trials,
195+
minibatch,
196+
valset,
197+
num_fewshot_candidates,
198+
num_instruct_candidates,
199+
)
168200

169201
if minibatch and minibatch_size > len(valset):
170202
raise ValueError(f"Minibatch size cannot exceed the size of the valset. Valset size: {len(valset)}.")
@@ -183,7 +215,17 @@ def compile(
183215

184216
with dspy.context(lm=self.task_model):
185217
# Step 1: Bootstrap few-shot examples
186-
demo_candidates = self._bootstrap_fewshot_examples(program, trainset, seed, teacher)
218+
demo_candidates = self._bootstrap_fewshot_examples(
219+
program,
220+
trainset,
221+
seed,
222+
teacher,
223+
num_fewshot_candidates=num_fewshot_candidates,
224+
max_bootstrapped_demos=effective_max_bootstrapped_demos,
225+
max_labeled_demos=effective_max_labeled_demos,
226+
max_errors=effective_max_errors,
227+
metric_threshold=self.metric_threshold,
228+
)
187229

188230
# Step 2: Propose instruction candidates
189231
instruction_candidates = self._propose_instructions(
@@ -195,6 +237,7 @@ def compile(
195237
data_aware_proposer,
196238
tip_aware_proposer,
197239
fewshot_aware_proposer,
240+
num_instruct_candidates=num_instruct_candidates,
198241
)
199242

200243
# If zero-shot, discard demos
@@ -234,13 +277,17 @@ def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candida
234277
def _set_hyperparams_from_run_mode(
235278
self,
236279
program: Any,
237-
num_trials: int,
280+
num_trials: int | None,
238281
minibatch: bool,
239282
zeroshot_opt: bool,
240283
valset: list,
241-
) -> tuple[int, list, bool]:
284+
num_instruct_candidates: int | None,
285+
num_fewshot_candidates: int | None,
286+
) -> tuple[int, list, bool, int, int]:
242287
if self.auto is None:
243-
return num_trials, valset, minibatch
288+
if num_instruct_candidates is None or num_fewshot_candidates is None:
289+
raise ValueError("num_candidates must be provided when auto is None.")
290+
return num_trials, valset, minibatch, num_instruct_candidates, num_fewshot_candidates
244291

245292
auto_settings = AUTO_RUN_SETTINGS[self.auto]
246293

@@ -250,12 +297,12 @@ def _set_hyperparams_from_run_mode(
250297
# Set num instruct candidates to 1/2 of N if optimizing with few-shot examples, otherwise set to N
251298
# This is because we've found that it's generally better to spend optimization budget on few-shot examples
252299
# When they are allowed.
253-
self.num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5)
254-
self.num_fewshot_candidates = auto_settings["n"]
300+
num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5)
301+
num_fewshot_candidates = auto_settings["n"]
255302

256303
num_trials = self._set_num_trials_from_num_candidates(program, zeroshot_opt, auto_settings["n"])
257304

258-
return num_trials, valset, minibatch
305+
return num_trials, valset, minibatch, num_instruct_candidates, num_fewshot_candidates
259306

260307
def _set_and_validate_datasets(self, trainset: list, valset: list | None):
261308
if not trainset:
@@ -274,13 +321,20 @@ def _set_and_validate_datasets(self, trainset: list, valset: list | None):
274321

275322
return trainset, valset
276323

277-
def _print_auto_run_settings(self, num_trials: int, minibatch: bool, valset: list):
324+
def _print_auto_run_settings(
325+
self,
326+
num_trials: int,
327+
minibatch: bool,
328+
valset: list,
329+
num_fewshot_candidates: int,
330+
num_instruct_candidates: int,
331+
):
278332
logger.info(
279333
f"\nRUNNING WITH THE FOLLOWING {self.auto.upper()} AUTO RUN SETTINGS:"
280334
f"\nnum_trials: {num_trials}"
281335
f"\nminibatch: {minibatch}"
282-
f"\nnum_fewshot_candidates: {self.num_fewshot_candidates}"
283-
f"\nnum_instruct_candidates: {self.num_instruct_candidates}"
336+
f"\nnum_fewshot_candidates: {num_fewshot_candidates}"
337+
f"\nnum_instruct_candidates: {num_instruct_candidates}"
284338
f"\nvalset size: {len(valset)}\n"
285339
)
286340

@@ -293,18 +347,19 @@ def _estimate_lm_calls(
293347
minibatch_full_eval_steps: int,
294348
valset: list,
295349
program_aware_proposer: bool,
350+
num_instruct_candidates: int,
296351
) -> tuple[str, str]:
297352
num_predictors = len(program.predictors())
298353

299354
# Estimate prompt model calls
300355
estimated_prompt_model_calls = (
301356
10 # Data summarizer calls
302-
+ self.num_instruct_candidates * num_predictors # Candidate generation
357+
+ num_instruct_candidates * num_predictors # Candidate generation
303358
+ (num_predictors + 1 if program_aware_proposer else 0) # Program-aware proposer
304359
)
305360
prompt_model_line = (
306361
f"{YELLOW}- Prompt Generation: {BLUE}{BOLD}10{ENDC}{YELLOW} data summarizer calls + "
307-
f"{BLUE}{BOLD}{self.num_instruct_candidates}{ENDC}{YELLOW} * "
362+
f"{BLUE}{BOLD}{num_instruct_candidates}{ENDC}{YELLOW} * "
308363
f"{BLUE}{BOLD}{num_predictors}{ENDC}{YELLOW} lm calls in program "
309364
f"+ ({BLUE}{BOLD}{num_predictors + 1}{ENDC}{YELLOW}) lm calls in program-aware proposer "
310365
f"= {BLUE}{BOLD}{estimated_prompt_model_calls}{ENDC}{YELLOW} prompt model calls{ENDC}"
@@ -331,38 +386,48 @@ def _estimate_lm_calls(
331386

332387
return prompt_model_line, task_model_line
333388

334-
def _bootstrap_fewshot_examples(self, program: Any, trainset: list, seed: int, teacher: Any) -> list | None:
389+
def _bootstrap_fewshot_examples(
390+
self,
391+
program: Any,
392+
trainset: list,
393+
seed: int,
394+
teacher: Any,
395+
*,
396+
num_fewshot_candidates: int,
397+
max_bootstrapped_demos: int,
398+
max_labeled_demos: int,
399+
max_errors: int | None,
400+
metric_threshold: float | None,
401+
) -> list | None:
335402
logger.info("\n==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==")
336-
if self.max_bootstrapped_demos > 0:
403+
if max_bootstrapped_demos > 0:
337404
logger.info(
338405
"These will be used as few-shot example candidates for our program and for creating instructions.\n"
339406
)
340407
else:
341408
logger.info("These will be used for informing instruction proposal.\n")
342409

343-
logger.info(f"Bootstrapping N={self.num_fewshot_candidates} sets of demonstrations...")
410+
logger.info(f"Bootstrapping N={num_fewshot_candidates} sets of demonstrations...")
344411

345-
zeroshot = self.max_bootstrapped_demos == 0 and self.max_labeled_demos == 0
412+
zeroshot = max_bootstrapped_demos == 0 and max_labeled_demos == 0
346413

347-
# try:
348-
effective_max_errors = (
349-
self.max_errors if self.max_errors is not None else dspy.settings.max_errors
350-
)
414+
if max_errors is None:
415+
max_errors = dspy.settings.max_errors
351416

352417
demo_candidates = create_n_fewshot_demo_sets(
353418
student=program,
354-
num_candidate_sets=self.num_fewshot_candidates,
419+
num_candidate_sets=num_fewshot_candidates,
355420
trainset=trainset,
356-
max_labeled_demos=(LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self.max_labeled_demos),
421+
max_labeled_demos=(LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_labeled_demos),
357422
max_bootstrapped_demos=(
358-
BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self.max_bootstrapped_demos
423+
BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_bootstrapped_demos
359424
),
360425
metric=self.metric,
361-
max_errors=effective_max_errors,
426+
max_errors=max_errors,
362427
teacher=teacher,
363428
teacher_settings=self.teacher_settings,
364429
seed=seed,
365-
metric_threshold=self.metric_threshold,
430+
metric_threshold=metric_threshold,
366431
rng=self.rng,
367432
)
368433
# NOTE: Bootstrapping is essential to MIPRO!
@@ -384,6 +449,7 @@ def _propose_instructions(
384449
data_aware_proposer: bool,
385450
tip_aware_proposer: bool,
386451
fewshot_aware_proposer: bool,
452+
num_instruct_candidates: int,
387453
) -> dict[int, list[str]]:
388454
logger.info("\n==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==")
389455
logger.info(
@@ -408,12 +474,12 @@ def _propose_instructions(
408474
init_temperature=self.init_temperature,
409475
)
410476

411-
logger.info(f"\nProposing N={self.num_instruct_candidates} instructions...\n")
477+
logger.info(f"\nProposing N={num_instruct_candidates} instructions...\n")
412478
instruction_candidates = proposer.propose_instructions_for_program(
413479
trainset=trainset,
414480
program=program,
415481
demo_candidates=demo_candidates,
416-
N=self.num_instruct_candidates,
482+
N=num_instruct_candidates,
417483
trial_logs={},
418484
)
419485

0 commit comments

Comments
 (0)