diff --git a/docs/docs/learn/evaluation/metrics.md b/docs/docs/learn/evaluation/metrics.md index 9e1f742f14..9ba9c131a6 100644 --- a/docs/docs/learn/evaluation/metrics.md +++ b/docs/docs/learn/evaluation/metrics.md @@ -6,17 +6,15 @@ sidebar_position: 5 DSPy is a machine learning framework, so you must think about your **automatic metrics** for evaluation (to track your progress) and optimization (so DSPy can make your programs more effective). - ## What is a metric and how do I define a metric for my task? -A metric is just a function that will take examples from your data and the output of your system and return a score that quantifies how good the output is. What makes outputs from your system good or bad? +A metric is just a function that will take examples from your data and the output of your system and return a score that quantifies how good the output is. What makes outputs from your system good or bad? For simple tasks, this could be just "accuracy" or "exact match" or "F1 score". This may be the case for simple classification or short-form QA tasks. However, for most applications, your system will output long-form outputs. There, your metric should probably be a smaller DSPy program that checks multiple properties of the output (quite possibly using AI feedback from LMs). -Getting this right on the first try is unlikely, but you should start with something simple and iterate. - +Getting this right on the first try is unlikely, but you should start with something simple and iterate. ## Simple metrics @@ -54,6 +52,34 @@ def validate_context_and_answer(example, pred, trace=None): Defining a good metric is an iterative process, so doing some initial evaluations and looking at your data and outputs is key. +## Multi-objective metrics with subscores + +Many real systems must balance more than one objective: quality vs. leakage, answer accuracy vs. latency, etc. DSPy metrics now expose a simple helper called [`dspy.metrics.subscore`](../../api/index.md) that lets you declare named subscores inside an ordinary Python metric. Each `subscore` behaves like a float so you can keep writing intuitive math, while DSPy records the subscore values, metadata, and the expression you returned. + +```python +from dspy.metrics import subscore + +def metric(example, pred, ctx=None): + acc = subscore("accuracy", answer_exact_match(example, pred), bounds=(0, 1)) + bleu = subscore("bleu", bleu_like(example.answer, pred.answer), bounds=(0, 1)) + latency = subscore( + "latency_s", + (ctx.latency_ms or 0) / 1000 if ctx else 0, + maximize=False, + units="s", + ) + return acc**2 + 0.3 * bleu - 0.02 * latency +``` + +When this metric runs during evaluation or optimization, DSPy evaluates the returned expression to obtain the aggregate scalar (preserving backwards compatibility), but also keeps a `Score` object that exposes: + +- `scalar`: the numeric value of the expression (`acc**2 + …`). +- `subscores`: the resolved subscores, e.g. `{"accuracy": 1.0, "bleu": 0.73, "latency_s": 0.42}`. +- `info`: metadata such as the canonical expression string and any per-subscore metadata you provided (bounds, maximize, units, cost, …). + +Optimizers can use those subscores directly for Pareto frontiers or constrained search, and evaluation tables will include additional columns for each subscore. + +Metrics that return subscores typically accept a third argument `ctx`, which contains runtime information (latency, token usage, optional seed). 
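+For a quick look at what gets recorded, here is a minimal sketch (assuming `devset`, `program`, and the `metric` defined above already exist) that prints the aggregate and the per-subscore values after an evaluation run:
+
+```python
+from dspy.evaluate import Evaluate
+
+evaluator = Evaluate(devset=devset, num_threads=8, display_progress=True)
+result = evaluator(program, metric=metric)
+
+for example, prediction, score in result.results:
+    # `score` is a Score object: a scalar aggregate plus the named subscores.
+    print(score.scalar, score.subscores, score.info.get("latency_ms"))
+```
+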
If you omit `subscore`, nothing changes—legacy metrics that return a plain float continue to work as before. ## Evaluation @@ -79,7 +105,6 @@ evaluator = Evaluate(devset=YOUR_DEVSET, num_threads=1, display_progress=True, d evaluator(YOUR_PROGRAM, metric=YOUR_METRIC) ``` - ## Intermediate: Using AI feedback for your metric For most applications, your system will output long-form outputs, so your metric should check multiple dimensions of the output using AI feedback from LMs. @@ -104,7 +129,7 @@ def metric(gold, pred, trace=None): engaging = "Does the assessed text make for a self-contained, engaging tweet?" correct = f"The text should answer `{question}` with `{answer}`. Does the assessed text contain this answer?" - + correct = dspy.Predict(Assess)(assessed_text=tweet, assessment_question=correct) engaging = dspy.Predict(Assess)(assessed_text=tweet, assessment_question=engaging) @@ -117,20 +142,16 @@ def metric(gold, pred, trace=None): When compiling, `trace is not None`, and we want to be strict about judging things, so we will only return `True` if `score >= 2`. Otherwise, we return a score out of 1.0 (i.e., `score / 2.0`). - ## Advanced: Using a DSPy program as your metric If your metric is itself a DSPy program, one of the most powerful ways to iterate is to compile (optimize) your metric itself. That's usually easy because the output of the metric is usually a simple value (e.g., a score out of 5) so the metric's metric is easy to define and optimize by collecting a few examples. - - ### Advanced: Accessing the `trace` When your metric is used during evaluation runs, DSPy will not try to track the steps of your program. But during compiling (optimization), DSPy will trace your LM calls. The trace will contain inputs/outputs to each DSPy predictor and you can leverage that to validate intermediate steps for optimization. - ```python def validate_hops(example, pred, trace=None): hops = [example.question] + [outputs.query for *_, outputs in trace if 'query' in outputs] diff --git a/dspy/evaluate/evaluate.py b/dspy/evaluate/evaluate.py index 50513ce8e2..fc0fde4f57 100644 --- a/dspy/evaluate/evaluate.py +++ b/dspy/evaluate/evaluate.py @@ -1,7 +1,10 @@ import csv +import dataclasses import importlib +import inspect import json import logging +import time import types from typing import TYPE_CHECKING, Any, Callable @@ -11,6 +14,8 @@ import tqdm import dspy +from dspy.metrics import Score +from dspy.metrics._subscores import _begin_collect, _end_collect, finalize_scores from dspy.primitives.prediction import Prediction from dspy.utils.callback import with_callbacks from dspy.utils.parallelizer import ParallelExecutor @@ -45,6 +50,23 @@ def HTML(x: str) -> str: # noqa: N802 logger = logging.getLogger(__name__) +@dataclasses.dataclass +class EvaluationMetricContext: + usage: dict | None = None + latency_ms: float | None = None + seed: int | None = None + + @property + def cache_key(self) -> tuple[Any, ...]: + usage_key = None + if self.usage is not None: + try: + usage_key = json.dumps(self.usage, sort_keys=True, default=repr) + except TypeError: + usage_key = repr(self.usage) + return (self.latency_ms, usage_key, self.seed) + + class EvaluationResult(Prediction): """ A class that represents the result of an evaluation. @@ -109,6 +131,7 @@ def __init__( self.failure_score = failure_score self.save_as_csv = save_as_csv self.save_as_json = save_as_json + self._metric_accepts_ctx_cache: dict[int, bool] = {} if "return_outputs" in kwargs: raise ValueError("`return_outputs` is no longer supported. 
Results are always returned inside the `results` field of the `EvaluationResult` object.") @@ -168,16 +191,31 @@ def __call__( ) def process_item(example): + start_time = time.perf_counter() prediction = program(**example.inputs()) - score = metric(example, prediction) - return prediction, score + latency_ms = (time.perf_counter() - start_time) * 1000.0 + + prediction_obj = _extract_prediction_object(prediction) + usage = prediction_obj.get_lm_usage() if isinstance(prediction_obj, Prediction) else None + ctx = EvaluationMetricContext(usage=usage, latency_ms=latency_ms) + + if isinstance(prediction_obj, Prediction): + prediction_obj.bind_example(example) + + scores = self._execute_metric(metric, example, prediction, ctx) + + if isinstance(prediction_obj, Prediction) and isinstance(scores, Score): + prediction_obj._store_scores(scores, ctx.cache_key) + + return prediction, scores results = executor.execute(process_item, devset) assert len(devset) == len(results) - results = [((dspy.Prediction(), self.failure_score) if r is None else r) for r in results] + results = [((dspy.Prediction(), Score(self.failure_score)) if r is None else r) for r in results] results = [(example, prediction, score) for example, (prediction, score) in zip(devset, results, strict=False)] - ncorrect, ntotal = sum(score for *_, score in results), len(devset) + aggregates = [score.scalar for *_, score in results] + ncorrect, ntotal = sum(aggregates), len(devset) logger.info(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)") @@ -227,19 +265,19 @@ def process_item(example): @staticmethod def _prepare_results_output( - results: list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str + results: list[tuple["dspy.Example", "dspy.Example", Score]], metric_name: str ): return [ ( - merge_dicts(example, prediction) | {metric_name: score} + merge_dicts(example, prediction) | _scores_to_row(score, metric_name) if prediction_is_dictlike(prediction) - else dict(example) | {"prediction": prediction, metric_name: score} + else dict(example) | {"prediction": prediction} | _scores_to_row(score, metric_name) ) for example, prediction, score in results ] def _construct_result_table( - self, results: list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str + self, results: list[tuple["dspy.Example", "dspy.Example", Score]], metric_name: str ) -> "pd.DataFrame": """ Construct a pandas DataFrame from the specified result list. 
@@ -262,6 +300,49 @@ def _construct_result_table( return result_df.rename(columns={"correct": metric_name}) + def _execute_metric( + self, + metric: Callable | None, + example: "dspy.Example", + prediction: Any, + ctx: EvaluationMetricContext, + ) -> Score: + if metric is None: + if isinstance(prediction, Prediction): + scores = prediction.resolve_score(ctx) + if scores is None: + raise ValueError("Prediction does not provide a score and no metric was supplied.") + return scores + raise ValueError("No metric provided for evaluation.") + + token = _begin_collect() + try: + if self._metric_accepts_context(metric): + result = metric(example, prediction, ctx) + else: + result = metric(example, prediction) + finally: + collector = _end_collect(token) + + ctx_info: dict[str, Any] = {} + if ctx.usage is not None: + ctx_info["usage"] = ctx.usage + if ctx.latency_ms is not None: + ctx_info["latency_ms"] = ctx.latency_ms + if ctx.seed is not None: + ctx_info["seed"] = ctx.seed + + return finalize_scores(result, collector, ctx_info=ctx_info) + + def _metric_accepts_context(self, metric: Callable) -> bool: + cache_key = id(metric) + cached = self._metric_accepts_ctx_cache.get(cache_key) + if cached is not None: + return cached + accepts = _callable_accepts_context(metric) + self._metric_accepts_ctx_cache[cache_key] = accepts + return accepts + def _display_result_table(self, result_df: "pd.DataFrame", display_table: bool | int, metric_name: str): """ Display the specified result DataFrame in a table format. @@ -321,6 +402,39 @@ def merge_dicts(d1, d2) -> dict: return merged +def _scores_to_row(scores: Score, metric_name: str) -> dict[str, Any]: + row = {metric_name: scores.scalar} + for subscore_name, value in scores.subscores.items(): + row[f"{metric_name}.{subscore_name}"] = value + expr = scores.info.get("expr") if isinstance(scores.info, dict) else None + if expr is not None: + row[f"{metric_name}.expr"] = expr + return row + + +def _extract_prediction_object(prediction: Any) -> Any: + if isinstance(prediction, Prediction): + return prediction + if isinstance(prediction, tuple) and prediction: + first = prediction[0] + if isinstance(first, Prediction): + return first + return prediction + + +def _callable_accepts_context(metric: Callable) -> bool: + try: + sig = inspect.signature(metric) + except (TypeError, ValueError): + return True + + params = list(sig.parameters.values()) + for param in params: + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + return True + return len(params) >= 3 + + def truncate_cell(content) -> str: """Truncate content of a cell to 25 words.""" words = str(content).split() diff --git a/dspy/metrics/__init__.py b/dspy/metrics/__init__.py new file mode 100644 index 0000000000..039fb93bfb --- /dev/null +++ b/dspy/metrics/__init__.py @@ -0,0 +1,23 @@ +"""Public helpers for metric subscores and score aggregation.""" + +from ._resolver import resolve_metric_score +from ._subscores import ( + Score, + coerce_metric_value, + subscore, + subscore_abs, + subscore_clip, + subscore_max, + subscore_min, +) + +__all__ = [ + "Score", + "subscore", + "subscore_abs", + "subscore_min", + "subscore_max", + "subscore_clip", + "coerce_metric_value", + "resolve_metric_score", +] diff --git a/dspy/metrics/_resolver.py b/dspy/metrics/_resolver.py new file mode 100644 index 0000000000..4ed109e053 --- /dev/null +++ b/dspy/metrics/_resolver.py @@ -0,0 +1,121 @@ +"""Helpers for executing metrics and resolving their scores.""" + +from __future__ import annotations 
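+
+# Typical use (sketch): ``score, meta = resolve_metric_score(metric, example, prediction)``,
+# where ``score`` is a resolved ``Score`` and ``meta`` carries any extra fields the metric
+# returned alongside its score (e.g., feedback fields from a ``dspy.Prediction``).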
+ +import warnings +from typing import Any, Callable + +from ._subscores import ( + Score, + SubscoreExpr, + SubscoreValue, + _begin_collect, + _end_collect, + finalize_scores, +) + + +def resolve_metric_score( + metric: Callable[..., Any], + *args: Any, + context: str = "metric", + warn_on_expression: bool = True, + ctx_info: dict[str, Any] | None = None, +) -> tuple[Score, dict[str, Any]]: + """Run ``metric`` under the subscore collector and normalize the result. + + Parameters + ---------- + metric: + Callable metric to execute. The callable is invoked as ``metric(*args)``. + *args: + Positional arguments to forward to ``metric``. + context: + Human-readable context used in warning messages. + warn_on_expression: + Whether to emit runtime warnings if the metric returns a raw + ``SubscoreExpr``/``SubscoreValue`` that needs to be resolved. + ctx_info: + Optional context dictionary that will be folded into the resulting + ``Score.info`` via :func:`finalize_scores`. + + Returns + ------- + tuple[Score, dict[str, Any]] + The resolved ``Score`` object and any auxiliary metadata returned next + to the score (e.g., extra fields from a ``dspy.Prediction``). + """ + + token = _begin_collect() + try: + raw_output = metric(*args) + finally: + collector = _end_collect(token) + + value, metadata = _unwrap_metric_output( + raw_output, + context=context, + warn_on_expression=warn_on_expression, + ) + + score = finalize_scores(value, collector, ctx_info=ctx_info) + return score, metadata + + +def _unwrap_metric_output( + raw_output: Any, + *, + context: str, + warn_on_expression: bool, +) -> tuple[Any, dict[str, Any]]: + metadata: dict[str, Any] = {} + value = raw_output + + prediction = _as_prediction(raw_output) + if prediction is not None: + if not hasattr(prediction, "score"): + raise ValueError( + f"{context}: metric returned a Prediction without a `score` field." + ) + value = prediction.score + metadata = {k: v for k, v in prediction.items() if k != "score"} + elif isinstance(raw_output, dict) and "score" in raw_output: + metadata = {k: v for k, v in raw_output.items() if k != "score"} + value = raw_output["score"] + else: + attr_score = getattr(raw_output, "score", None) + if attr_score is not None: + value = attr_score + if hasattr(raw_output, "items"): + try: + metadata = {k: v for k, v in raw_output.items() if k != "score"} + except Exception: + metadata = {} + + if value is None: + raise ValueError(f"{context}: metric did not return a score.") + + if warn_on_expression and isinstance(value, SubscoreExpr): + warnings.warn( + f"{context}: metric returned a subscore expression; resolving to a Score.", + RuntimeWarning, + stacklevel=3, + ) + elif warn_on_expression and isinstance(value, SubscoreValue): + warnings.warn( + f"{context}: metric returned a subscore component; resolving to a Score.", + RuntimeWarning, + stacklevel=3, + ) + + return value, metadata + + +def _as_prediction(obj: Any) -> Any | None: + try: + from dspy.primitives.prediction import Prediction # type: ignore + except Exception: # pragma: no cover - defensive import guard + return None + + return obj if isinstance(obj, Prediction) else None + diff --git a/dspy/metrics/_subscores.py b/dspy/metrics/_subscores.py new file mode 100644 index 0000000000..864f86578f --- /dev/null +++ b/dspy/metrics/_subscores.py @@ -0,0 +1,408 @@ +"""Utilities for declaring and aggregating metric subscores. + +This module implements the tiny algebra used to trace metric subscores and +resolve arithmetic expressions that combine them. 
Users interact with the +`subscore` helper, which behaves like a float while carrying metadata and a +stable name. Internally, we track subscores in a ``ContextVar`` so concurrent +evaluations remain isolated. + +The module also exposes a :class:`Score` dataclass that carries the resolved +scalar score, the subscores, and any auxiliary information collected during +evaluation (e.g., canonical expressions, metadata, usage statistics). +""" + +from __future__ import annotations + +import contextvars +import dataclasses +import math +import warnings +from typing import Any, Iterable + +__all__ = [ + "Score", + "subscore", + "subscore_abs", + "subscore_min", + "subscore_max", + "subscore_clip", + "_begin_collect", + "_end_collect", + "finalize_scores", +] + + +@dataclasses.dataclass(frozen=True) +class Score: + """Resolved metric information. + + Attributes + ---------- + scalar: + The scalar score obtained from evaluating the user's arithmetic + expression. For backwards compatibility this behaves exactly like the + previous single metric number. + subscores: + Mapping from subscore name to numeric value. Each entry corresponds to + a call to :func:`subscore` that was used in the returned expression (or + registered during metric execution even if the user coerced it to a + float early). + info: + A free-form dictionary used to carry auxiliary details such as the + canonical expression string, per-subscore metadata, latency, usage, etc. + """ + + scalar: float + subscores: dict[str, float] = dataclasses.field(default_factory=dict) + info: dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self) -> None: # pragma: no cover - trivial conversions + object.__setattr__(self, "scalar", float(self.scalar)) + subscores = {name: float(value) for name, value in self.subscores.items()} + object.__setattr__(self, "subscores", subscores) + + +class SubscoreExpr: + """Tiny expression DAG that keeps arithmetic over subscores traceable.""" + + __slots__ = ("args", "op") + + def __init__(self, op: str, args: tuple[Any, ...]): + self.op = op + self.args = args + + # Arithmetic --------------------------------------------------------- + def __add__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("add", (self, other)) + + def __radd__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("add", (other, self)) + + def __sub__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("sub", (self, other)) + + def __rsub__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("sub", (other, self)) + + def __mul__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("mul", (self, other)) + + def __rmul__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("mul", (other, self)) + + def __truediv__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("div", (self, other)) + + def __rtruediv__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("div", (other, self)) + + def __pow__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("pow", (self, other)) + + def __rpow__(self, other: Any) -> SubscoreExpr: + return SubscoreExpr("pow", (other, self)) + + def __neg__(self) -> SubscoreExpr: + return SubscoreExpr("neg", (self,)) + + def __float__(self) -> float: + return float(self.eval()) + + # Comparisons -------------------------------------------------------- + def _cmp(self, other: Any, op: str) -> bool: + left = self.eval() + right = _eval_node(other) + if op == "lt": + return left < right + if op == "le": + return left <= right + if op == "gt": + return 
left > right + if op == "ge": + return left >= right + raise ValueError(f"Unsupported comparison op: {op}") + + def __lt__(self, other: Any) -> bool: + return self._cmp(other, "lt") + + def __le__(self, other: Any) -> bool: + return self._cmp(other, "le") + + def __gt__(self, other: Any) -> bool: + return self._cmp(other, "gt") + + def __ge__(self, other: Any) -> bool: + return self._cmp(other, "ge") + + def __bool__(self) -> bool: + return bool(self.eval()) + + # Evaluation --------------------------------------------------------- + def eval(self) -> float: + args = [ _eval_node(arg) for arg in self.args ] + op = self.op + if op == "add": + return args[0] + args[1] + if op == "sub": + return args[0] - args[1] + if op == "mul": + return args[0] * args[1] + if op == "div": + return args[0] / args[1] + if op == "pow": + return args[0] ** args[1] + if op == "neg": + return -args[0] + if op == "abs": + return abs(args[0]) + if op == "min": + return min(args) + if op == "max": + return max(args) + if op == "clip": + value, lower, upper = args + if lower is not None and value < lower: + value = lower + if upper is not None and value > upper: + value = upper + return value + raise ValueError(f"Unsupported operator: {op!r}") + + def used_subscores(self) -> dict[str, float]: + subscores: dict[str, float] = {} + _collect_subscores(self, subscores) + return subscores + + def to_repr(self) -> str: + return _repr_node(self) + + def __repr__(self) -> str: # pragma: no cover - debugging helper + return f"SubscoreExpr({self.to_repr()})" + + +class SubscoreValue(SubscoreExpr): + """Leaf node representing a named subscore.""" + + __slots__ = ("meta", "name", "value") + + def __init__(self, name: str, value: float, meta: dict[str, Any]): + self.name = name + self.value = float(value) + self.meta = meta + super().__init__("subscore", (self,)) + + # SubscoreValue inherits operator overloads from SubscoreExpr via SubscoreExpr methods. 
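+    # For example, `subscore("acc", 0.8) * 2` builds a SubscoreExpr whose eval()
+    # returns 1.6, while used_subscores() still reports {"acc": 0.8}, so named
+    # subscores survive ordinary arithmetic on the values users write in metrics.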
+ + def eval(self) -> float: + return self.value + + def used_subscores(self) -> dict[str, float]: + return {self.name: self.value} + + def to_repr(self) -> str: + return self.name + + def __float__(self) -> float: + return self.value + + +@dataclasses.dataclass +class _Collector: + subscores: dict[str, tuple[float, dict[str, Any]]] = dataclasses.field(default_factory=dict) + + +_CONTEXT: contextvars.ContextVar[_Collector | None] = contextvars.ContextVar( + "_dspy_subscores", default=None +) + + +def _begin_collect() -> contextvars.Token: + token = _CONTEXT.set(_Collector()) + return token + + +def _end_collect(token: contextvars.Token) -> _Collector: + collector = _CONTEXT.get() + _CONTEXT.reset(token) + return collector or _Collector() + + +def subscore(name: str, value: float, /, **meta: Any) -> SubscoreValue: + """Register a named subscore for the active metric evaluation.""" + + if not isinstance(name, str) or not name: + raise ValueError("subscore name must be a non-empty string") + numeric = float(value) + collector = _CONTEXT.get() + if collector is not None: + if name in collector.subscores: + raise ValueError(f"Duplicate subscore name: {name}") + collector.subscores[name] = (numeric, dict(meta)) + return SubscoreValue(name, numeric, dict(meta)) + + +def subscore_abs(value: Any) -> SubscoreExpr: + return SubscoreExpr("abs", (_ensure_expr(value),)) + + +def subscore_min(*values: Any) -> SubscoreExpr: + if not values: + raise ValueError("subscore_min requires at least one argument") + return SubscoreExpr("min", tuple(_ensure_expr(v) for v in values)) + + +def subscore_max(*values: Any) -> SubscoreExpr: + if not values: + raise ValueError("subscore_max requires at least one argument") + return SubscoreExpr("max", tuple(_ensure_expr(v) for v in values)) + + +def subscore_clip(value: Any, lower: float | None = None, upper: float | None = None) -> SubscoreExpr: + return SubscoreExpr("clip", (_ensure_expr(value), lower, upper)) + + +def finalize_scores( + result: Any, + collector: _Collector, + *, + ctx_info: dict[str, Any] | None = None, +) -> Score: + """Convert a metric return value into a :class:`Score` object.""" + + ctx_info = dict(ctx_info or {}) + if isinstance(result, Score): + info = dict(result.info) + info.update({k: v for k, v in ctx_info.items() if v is not None}) + return Score(result.scalar, dict(result.subscores), info) + + if isinstance(result, SubscoreValue): + subscores = result.used_subscores() + meta = {name: collector.subscores.get(name, (None, {}))[1] for name in subscores} + info = {"expr": result.to_repr(), "meta": meta} + info.update({k: v for k, v in ctx_info.items() if v is not None}) + _validate_subscores(subscores) + return Score(result.value, subscores, info) + + if isinstance(result, SubscoreExpr): + scalar = float(result.eval()) + subscores = result.used_subscores() + meta = {name: collector.subscores.get(name, (None, {}))[1] for name in subscores} + info = {"expr": result.to_repr(), "meta": meta} + info.update({k: v for k, v in ctx_info.items() if v is not None}) + _validate_subscores(subscores) + _validate_scalar("scalar", scalar) + return Score(scalar, subscores, info) + + scalar = float(result) + subscores = {name: value for name, (value, _) in collector.subscores.items()} + meta = {name: meta for name, (_, meta) in collector.subscores.items()} + info = {"expr": None, "meta": meta} + info.update({k: v for k, v in ctx_info.items() if v is not None}) + _validate_subscores(subscores) + _validate_scalar("scalar", scalar) + return Score(scalar, subscores, 
info) + + +# Helpers ----------------------------------------------------------------- + +def _ensure_expr(value: Any) -> Any: + if isinstance(value, (SubscoreExpr, SubscoreValue)): + return value + return value + + +def _eval_node(node: Any) -> float: + if isinstance(node, SubscoreExpr): + return node.eval() + if isinstance(node, SubscoreValue): + return node.value + return float(node) + + +def _collect_subscores(node: Any, subscores: dict[str, float]) -> None: + if isinstance(node, SubscoreExpr) and node.op != "subscore": + for arg in node.args: + _collect_subscores(arg, subscores) + elif isinstance(node, SubscoreValue): + subscores[node.name] = node.value + + +def _repr_node(node: Any) -> str: + if isinstance(node, SubscoreValue): + return node.to_repr() + if isinstance(node, SubscoreExpr): + op = node.op + args = node.args + if op == "neg": + return f"-({_repr_node(args[0])})" + if op == "abs": + return f"abs({_repr_node(args[0])})" + if op == "min": + return "min(" + ", ".join(_repr_node(a) for a in args) + ")" + if op == "max": + return "max(" + ", ".join(_repr_node(a) for a in args) + ")" + if op == "clip": + value, lower, upper = args + return f"clip({_repr_node(value)}, {lower!r}, {upper!r})" + if op == "subscore": + return _repr_node(args[0]) + symbol = { + "add": "+", + "sub": "-", + "mul": "*", + "div": "/", + "pow": "**", + }[op] + if op == "sub" and len(args) == 1: + return f"-({_repr_node(args[0])})" + left = _repr_node(args[0]) + right = _repr_node(args[1]) + return f"({left}{symbol}{right})" + if isinstance(node, str): + return node + return repr(node) + + +def _validate_subscores(subscores: Iterable[tuple[str, float]] | dict[str, float]) -> None: + if isinstance(subscores, dict): + items = subscores.items() + else: + items = subscores + for name, value in items: + _validate_scalar(name, value) + + +def _validate_scalar(name: str, value: float) -> None: + if not math.isfinite(float(value)): + raise ValueError(f"Non-finite value for {name}: {value!r}") + + +def coerce_metric_value(value: Any, *, context: str = "metric", warn_on_expression: bool = True) -> Any: + """Coerce subscore expressions into scalar values for legacy metric handlers. + + Optimizers that do not capture subscores may invoke metrics that return SubscoreExpr/SubscoreValue. + This helper warns once per call (when desired) and evaluates the expression so callers + can continue working with numeric scores. 
+ """ + + if value is None: + return None + if isinstance(value, Score): + return value.scalar + if isinstance(value, SubscoreExpr): + if warn_on_expression: + warnings.warn( + f"{context}: metric returned a subscore expression; coercing to float and discarding subscores.", + RuntimeWarning, + stacklevel=3, + ) + return float(value) + if isinstance(value, SubscoreValue): + if warn_on_expression: + warnings.warn( + f"{context}: metric returned a subscore component; coercing to float and discarding metadata.", + RuntimeWarning, + stacklevel=3, + ) + return float(value) + return value diff --git a/dspy/primitives/prediction.py b/dspy/primitives/prediction.py index 4f32fe9fce..c61f6f72ff 100644 --- a/dspy/primitives/prediction.py +++ b/dspy/primitives/prediction.py @@ -1,5 +1,13 @@ +from __future__ import annotations + +from typing import Any, Callable + +from dspy.metrics import Score +from dspy.metrics._subscores import _begin_collect, _end_collect, finalize_scores from dspy.primitives.example import Example +ScoreFn = Callable[[Example, "Prediction", Any], Any] + class Prediction(Example): """A prediction object that contains the output of a DSPy module. @@ -23,6 +31,16 @@ def __init__(self, *args, **kwargs): self._completions = None self._lm_usage = None + self._score_value: float | None = None + self._score_fn: ScoreFn | None = None + self._scores_obj: Score | None = None + self._bound_example: Example | None = None + self._resolved_ctx_key: Any = None + + if "score" in self._store: + initial_score = self._store.pop("score") + if initial_score is not None: + self.score = initial_score def get_lm_usage(self): return self._lm_usage @@ -51,70 +69,144 @@ def __str__(self): return self.__repr__() def __float__(self): - if "score" not in self._store: - raise ValueError("Prediction object does not have a 'score' field to convert to float.") - return float(self._store["score"]) + return float(self._require_numeric_score()) def __add__(self, other): if isinstance(other, (float, int)): return self.__float__() + other elif isinstance(other, Prediction): - return self.__float__() + float(other) + return self._require_numeric_score() + other._require_numeric_score() raise TypeError(f"Unsupported type for addition: {type(other)}") def __radd__(self, other): if isinstance(other, (float, int)): return other + self.__float__() elif isinstance(other, Prediction): - return float(other) + self.__float__() + return other._require_numeric_score() + self._require_numeric_score() raise TypeError(f"Unsupported type for addition: {type(other)}") def __truediv__(self, other): if isinstance(other, (float, int)): - return self.__float__() / other + return self._require_numeric_score() / other elif isinstance(other, Prediction): - return self.__float__() / float(other) + return self._require_numeric_score() / other._require_numeric_score() raise TypeError(f"Unsupported type for division: {type(other)}") def __rtruediv__(self, other): if isinstance(other, (float, int)): - return other / self.__float__() + return other / self._require_numeric_score() elif isinstance(other, Prediction): - return float(other) / self.__float__() + return other._require_numeric_score() / self._require_numeric_score() raise TypeError(f"Unsupported type for division: {type(other)}") def __lt__(self, other): if isinstance(other, (float, int)): return self.__float__() < other elif isinstance(other, Prediction): - return self.__float__() < float(other) + return self._require_numeric_score() < other._require_numeric_score() raise 
TypeError(f"Unsupported type for comparison: {type(other)}") def __le__(self, other): if isinstance(other, (float, int)): return self.__float__() <= other elif isinstance(other, Prediction): - return self.__float__() <= float(other) + return self._require_numeric_score() <= other._require_numeric_score() raise TypeError(f"Unsupported type for comparison: {type(other)}") def __gt__(self, other): if isinstance(other, (float, int)): return self.__float__() > other elif isinstance(other, Prediction): - return self.__float__() > float(other) + return self._require_numeric_score() > other._require_numeric_score() raise TypeError(f"Unsupported type for comparison: {type(other)}") def __ge__(self, other): if isinstance(other, (float, int)): return self.__float__() >= other elif isinstance(other, Prediction): - return self.__float__() >= float(other) + return self._require_numeric_score() >= other._require_numeric_score() raise TypeError(f"Unsupported type for comparison: {type(other)}") @property def completions(self): return self._completions + @property + def score(self) -> float | ScoreFn | None: + if self._score_value is not None: + return self._score_value + return self._score_fn + + @score.setter + def score(self, value: float | ScoreFn | None) -> None: + self._scores_obj = None + self._resolved_ctx_key = None + if callable(value): + self._score_fn = value + self._score_value = None + self._store.pop("score", None) + else: + self._score_fn = None + if value is None: + self._score_value = None + self._store.pop("score", None) + else: + self._score_value = float(value) + self._store["score"] = self._score_value + + @property + def scores(self) -> Score | None: + if self._scores_obj is None and self._score_value is not None: + self._scores_obj = Score(self._score_value) + return self._scores_obj + + def bind_example(self, example: Example) -> None: + self._bound_example = example + + def resolve_score(self, ctx: Any | None = None, *, force: bool = False) -> Score | None: + if self._score_fn is None: + return self.scores + + ctx_key = getattr(ctx, "cache_key", None) + if not force and self._scores_obj is not None and self._resolved_ctx_key == ctx_key: + return self._scores_obj + + if self._bound_example is None: + raise RuntimeError("Cannot resolve score: no bound example. 
Call bind_example(ex).") + + token = _begin_collect() + try: + result = self._score_fn(self._bound_example, self, ctx) + finally: + collector = _end_collect(token) + + ctx_info = {} + if ctx is not None: + usage = getattr(ctx, "usage", None) + latency_ms = getattr(ctx, "latency_ms", None) + seed = getattr(ctx, "seed", None) + if usage is not None: + ctx_info["usage"] = usage + if latency_ms is not None: + ctx_info["latency_ms"] = latency_ms + if seed is not None: + ctx_info["seed"] = seed + + scores = finalize_scores(result, collector, ctx_info=ctx_info) + self._store_scores(scores, ctx_key) + return scores + + def _require_numeric_score(self) -> float: + if self._score_value is None: + raise TypeError("Prediction score is not resolved to a numeric value.") + return float(self._score_value) + + def _store_scores(self, scores: Score, ctx_key: Any | None = None) -> None: + self._scores_obj = scores + self._score_value = scores.scalar + self._store["score"] = self._score_value + self._resolved_ctx_key = ctx_key + class Completions: def __init__(self, list_or_dict, signature=None): diff --git a/dspy/propose/utils.py b/dspy/propose/utils.py index 8bd720a23a..6b8ca314bd 100644 --- a/dspy/propose/utils.py +++ b/dspy/propose/utils.py @@ -3,6 +3,7 @@ import re import dspy +from dspy.metrics import Score try: from IPython.core.magics.code import extract_symbols @@ -95,6 +96,16 @@ def create_predictor_level_history_string(base_program, predictor_i, trial_logs, predictor = history_item["program"].predictors()[predictor_i] instruction = get_signature(predictor).instructions score = history_item["score"] + if isinstance(score, Score): + score = score.scalar + elif isinstance(score, dict) and "score" in score: + score = score["score"] + else: + attr_score = getattr(score, "score", None) + if isinstance(attr_score, Score): + score = attr_score.scalar + elif attr_score is not None: + score = attr_score if instruction in instruction_aggregate: instruction_aggregate[instruction]["total_score"] += score diff --git a/dspy/teleprompt/avatar_optimizer.py b/dspy/teleprompt/avatar_optimizer.py index ddba74e5f2..8a2e07fe25 100644 --- a/dspy/teleprompt/avatar_optimizer.py +++ b/dspy/teleprompt/avatar_optimizer.py @@ -7,6 +7,7 @@ from tqdm import tqdm import dspy +from dspy.metrics import resolve_metric_score from dspy.predict.avatar import ActionOutput from dspy.teleprompt.teleprompt import Teleprompter @@ -22,10 +23,11 @@ class EvalResult(BaseModel): class Comparator(dspy.Signature): """After executing the given actions on user inputs using the given instruction, some inputs have yielded good, results, while others have not. I'll provide you the inputs along with their, corresponding evaluation metrics: -Task: -(1) Firstly, identify and contrast the patterns of inputs that have achieved good results with those that have not. -(2) Then, review the computational logic for any inconsistencies in the previous actions. -(3) Lastly, specify the modification in tools used that can lead to improved performance on the negative inputs.""" + Task: + (1) Firstly, identify and contrast the patterns of inputs that have achieved good results with those that have not. + (2) Then, review the computational logic for any inconsistencies in the previous actions. + (3) Lastly, specify the modification in tools used that can lead to improved performance on the negative inputs. 
+ """ instruction: str = dspy.InputField( prefix="Instruction: ", @@ -52,11 +54,12 @@ class Comparator(dspy.Signature): class FeedbackBasedInstruction(dspy.Signature): """There is a task that needs to be completed for which one can use multiple tools to achieve the desired outcome. A group's performance was evaluated on a dataset of inputs, the inputs that did well are positive inputs, and the inputs that did not do well are negative inputs. -You received feedback on how they can better use the tools to improve your performance on the negative inputs. You have been provided with the previous instruction, that they followed to use tools to complete the task, and the feedback on your performance. + You received feedback on how they can better use the tools to improve your performance on the negative inputs. You have been provided with the previous instruction, that they followed to use tools to complete the task, and the feedback on your performance. -Your task is to incorporate the feedback and generate a detailed instruction for the group to follow to improve their performance on the task. + Your task is to incorporate the feedback and generate a detailed instruction for the group to follow to improve their performance on the task. -Make sure that the new instruction talks about how to use the tools effectively and should be no more than 3 paragraphs long. The previous instruction contains general guidelines that you must retain in the new instruction.""" + Make sure that the new instruction talks about how to use the tools effectively and should be no more than 3 paragraphs long. The previous instruction contains general guidelines that you must retain in the new instruction. + """ previous_instruction: str = dspy.InputField( prefix="Previous Instruction: ", @@ -83,7 +86,9 @@ def __init__( max_negative_inputs: int | None = None, optimize_for: str = "max", ): - assert metric is not None, "`metric` argument cannot be None. Please provide a metric function." + assert ( + metric is not None + ), "`metric` argument cannot be None. Please provide a metric function." 
self.metric = metric self.optimize_for = optimize_for @@ -103,7 +108,13 @@ def process_example(self, actor, example, return_outputs): try: prediction = actor(**example.inputs().toDict()) - score = self.metric(example, prediction) + score_obj, _ = resolve_metric_score( + self.metric, + example, + prediction, + context="AvatarOptimizer metric", + ) + score = float(score_obj.scalar) if return_outputs: return example, prediction, score @@ -118,24 +129,30 @@ def process_example(self, actor, example, return_outputs): else: return 0 - - def thread_safe_evaluator(self, devset, actor, return_outputs=False, num_threads=None): + def thread_safe_evaluator( + self, devset, actor, return_outputs=False, num_threads=None + ): total_score = 0 total_examples = len(devset) results = [] num_threads = num_threads or dspy.settings.num_threads with ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(self.process_example, actor, example, return_outputs) for example in devset] - - for future in tqdm(futures, total=total_examples, desc="Processing examples"): + futures = [ + executor.submit(self.process_example, actor, example, return_outputs) + for example in devset + ] + + for future in tqdm( + futures, total=total_examples, desc="Processing examples" + ): result = future.result() if return_outputs: example, prediction, score = result total_score += score results.append((example, prediction, score)) else: - total_score += result + total_score += float(result) avg_metric = total_score / total_examples @@ -144,16 +161,15 @@ def thread_safe_evaluator(self, devset, actor, return_outputs=False, num_threads else: return avg_metric - def _get_pos_neg_results( - self, - actor: dspy.Module, - trainset: list[dspy.Example] + self, actor: dspy.Module, trainset: list[dspy.Example] ) -> tuple[float, list[EvalResult], list[EvalResult]]: pos_inputs = [] neg_inputs = [] - avg_score, results = self.thread_safe_evaluator(trainset, actor, return_outputs=True) + avg_score, results = self.thread_safe_evaluator( + trainset, actor, return_outputs=True + ) print(f"Average Score: {avg_score}") for example, prediction, score in results: @@ -162,7 +178,7 @@ def _get_pos_neg_results( EvalResult( example=example.inputs().toDict(), score=score, - actions=prediction.actions if prediction else None + actions=prediction.actions if prediction else None, ) ) elif score <= self.lower_bound: @@ -170,30 +186,37 @@ def _get_pos_neg_results( EvalResult( example=example.inputs().toDict(), score=score, - actions=prediction.actions if prediction else None + actions=prediction.actions if prediction else None, ) ) if len(pos_inputs) == 0: - raise ValueError("No positive examples found, try lowering the upper_bound or providing more training data") + raise ValueError( + "No positive examples found, try lowering the upper_bound or providing more training data" + ) if len(neg_inputs) == 0: - raise ValueError("No negative examples found, try raising the lower_bound or providing more training data") + raise ValueError( + "No negative examples found, try raising the lower_bound or providing more training data" + ) return (avg_score, pos_inputs, neg_inputs) - def compile(self, student, *, trainset): best_actor = deepcopy(student) best_score = -999 if self.optimize_for == "max" else 999 for i in range(self.max_iters): - print(20*"=") + print(20 * "=") print(f"Iteration {i+1}/{self.max_iters}") - score, pos_inputs, neg_inputs = self._get_pos_neg_results(best_actor, trainset) + score, pos_inputs, neg_inputs = 
self._get_pos_neg_results( + best_actor, trainset + ) print(f"Positive examples: {len(pos_inputs)}") print(f"Negative examples: {len(neg_inputs)}") - print(f"Sampling {self.max_positive_inputs} positive examples and {self.max_negative_inputs} negative examples") + print( + f"Sampling {self.max_positive_inputs} positive examples and {self.max_negative_inputs} negative examples" + ) if self.max_positive_inputs and len(pos_inputs) > self.max_positive_inputs: pos_inputs = sample(pos_inputs, self.max_positive_inputs) @@ -205,18 +228,22 @@ def compile(self, student, *, trainset): instruction=best_actor.actor.signature.instructions, actions=[str(tool) for tool in best_actor.tools], pos_input_with_metrics=pos_inputs, - neg_input_with_metrics=neg_inputs + neg_input_with_metrics=neg_inputs, ).feedback new_instruction = self.feedback_instruction( previous_instruction=best_actor.actor.signature.instructions, - feedback=feedback + feedback=feedback, ).new_instruction print(f"Generated new instruction: {new_instruction}") - if (self.optimize_for == "max" and best_score < score) or (self.optimize_for == "min" and best_score > score): - best_actor.actor.signature = best_actor.actor.signature.with_instructions(new_instruction) + if (self.optimize_for == "max" and best_score < score) or ( + self.optimize_for == "min" and best_score > score + ): + best_actor.actor.signature = ( + best_actor.actor.signature.with_instructions(new_instruction) + ) best_actor.actor_clone = deepcopy(best_actor.actor) best_score = score diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index 31a5736239..1f0b1e7a0d 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -5,6 +5,7 @@ import tqdm import dspy +from dspy.metrics import resolve_metric_score from dspy.teleprompt.teleprompt import Teleprompter from .vanilla import LabeledFewShot @@ -200,11 +201,18 @@ def _bootstrap_one_example(self, example, round_idx=0): predictor.demos = predictor_cache[name] if self.metric: - metric_val = self.metric(example, prediction, trace) - if self.metric_threshold: + score_obj, _ = resolve_metric_score( + self.metric, + example, + prediction, + trace, + context="BootstrapFewShot metric", + ) + metric_val = score_obj.scalar + if self.metric_threshold is not None and metric_val is not None: success = metric_val >= self.metric_threshold else: - success = metric_val + success = bool(metric_val) else: success = True except Exception as e: diff --git a/dspy/teleprompt/gepa/gepa_utils.py b/dspy/teleprompt/gepa/gepa_utils.py index 844afe8b00..5f594c767c 100644 --- a/dspy/teleprompt/gepa/gepa_utils.py +++ b/dspy/teleprompt/gepa/gepa_utils.py @@ -10,6 +10,7 @@ from dspy.adapters.types import History from dspy.adapters.types.base_type import Type from dspy.evaluate import Evaluate +from dspy.metrics import Score from dspy.primitives import Example, Prediction from dspy.teleprompt.bootstrap_trace import TraceData @@ -153,8 +154,16 @@ def evaluate(self, batch, candidate, capture_traces=False): scores.append(self.failure_score) else: score = t["score"] - if hasattr(score, "score"): + if isinstance(score, Score): + score = score.scalar + elif isinstance(score, dict) and "score" in score: score = score["score"] + else: + attr_score = getattr(score, "score", None) + if isinstance(attr_score, Score): + score = attr_score.scalar + elif attr_score is not None: + score = attr_score scores.append(score) return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajs) else: @@ -170,7 +179,21 @@ def 
evaluate(self, batch, candidate, capture_traces=False): res = evaluator(program) outputs = [r[1] for r in res.results] scores = [r[2] for r in res.results] - scores = [s["score"] if hasattr(s, "score") else s for s in scores] + coerced_scores = [] + for s in scores: + if isinstance(s, Score): + coerced_scores.append(s.scalar) + elif isinstance(s, dict) and "score" in s: + coerced_scores.append(s["score"]) + else: + attr_score = getattr(s, "score", None) + if isinstance(attr_score, Score): + coerced_scores.append(attr_score.scalar) + elif attr_score is not None: + coerced_scores.append(attr_score) + else: + coerced_scores.append(s) + scores = coerced_scores return EvaluationBatch(outputs=outputs, scores=scores, trajectories=None) def make_reflective_dataset(self, candidate, eval_batch, components_to_update) -> dict[str, list[ReflectiveExample]]: @@ -192,8 +215,16 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - example = data["example"] prediction = data["prediction"] module_score = data["score"] - if hasattr(module_score, "score"): + if isinstance(module_score, Score): + module_score = module_score.scalar + elif isinstance(module_score, dict) and "score" in module_score: module_score = module_score["score"] + else: + attr_score = getattr(module_score, "score", None) + if isinstance(attr_score, Score): + module_score = attr_score.scalar + elif attr_score is not None: + module_score = attr_score trace_instances = [t for t in trace if t[0].signature.equals(module.signature)] if not self.add_format_failure_as_feedback: diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py index fd5c3e8808..ebd310b4ef 100644 --- a/dspy/teleprompt/simba_utils.py +++ b/dspy/teleprompt/simba_utils.py @@ -7,6 +7,7 @@ import dspy from dspy.adapters.utils import get_field_description_string +from dspy.metrics import resolve_metric_score from dspy.signatures import InputField, OutputField logger = logging.getLogger(__name__) @@ -41,22 +42,17 @@ def wrapped_program(example): logger.warning(e) trace = dspy.settings.trace.copy() - output = None score = 0.0 output_metadata = {} try: - output = metric(example, prediction) - if isinstance(output, (int, float)): - score = output - elif isinstance(output, dspy.Prediction): - if not hasattr(output, "score"): - raise ValueError("When `metric` returns a `dspy.Prediction`, it must contain a `score` field.") - score = output.score - # Extract fields from the output dspy.Prediction, excluding `score`` - output_metadata = { - k: v for k, v in output.items() if k != "score" - } + score_obj, output_metadata = resolve_metric_score( + metric, + example, + prediction, + context="SIMBA metric", + ) + score = float(score_obj.scalar) except Exception as e: logger.warning(e) diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py index c32f5e3ebb..0e56f99b9a 100644 --- a/dspy/utils/parallelizer.py +++ b/dspy/utils/parallelizer.py @@ -10,6 +10,8 @@ import tqdm +from dspy.metrics import Score + logger = logging.getLogger(__name__) @@ -168,7 +170,15 @@ def all_done(): # Update progress if self.compare_results: - vals = [r[-1] for r in results if r is not None] + vals = [] + for r in results: + if r is None: + continue + value = r[-1] + if isinstance(value, Score): + vals.append(value.scalar) + else: + vals.append(value) self._update_progress(pbar, sum(vals), len(vals)) else: self._update_progress( diff --git a/tests/README.md b/tests/README.md index 4cee3eb49c..1e663f8f8b 100644 --- a/tests/README.md +++ b/tests/README.md @@ 
-1,3 +1,3 @@ The tests in this directory are primarily concerned with code correctness and Adapter reliability. -If you're looking for testing the end-to-end quality of DSPy modules and optimizer, refer to [LangProBe](https://github.com/Shangyint/langProBe). \ No newline at end of file +If you're looking for testing the end-to-end quality of DSPy modules and optimizer, refer to [LangProBe](https://github.com/Shangyint/langProBe). diff --git a/tests/evaluate/test_evaluate.py b/tests/evaluate/test_evaluate.py index 211cf25962..9065ee25d9 100644 --- a/tests/evaluate/test_evaluate.py +++ b/tests/evaluate/test_evaluate.py @@ -7,6 +7,7 @@ import dspy from dspy.evaluate.evaluate import Evaluate, EvaluationResult from dspy.evaluate.metrics import answer_exact_match +from dspy.metrics import Score, subscore from dspy.predict import Predict from dspy.utils.callback import BaseCallback from dspy.utils.dummies import DummyLM @@ -50,8 +51,42 @@ def test_evaluate_call(): metric=answer_exact_match, display_progress=False, ) - score = ev(program) - assert score.score == 100.0 + result = ev(program) + assert result.score == 100.0 + assert all(isinstance(score_obj, Score) for *_, score_obj in result.results) + assert all(score_obj.scalar == 1.0 for *_, score_obj in result.results) + + +def test_evaluate_with_subscores(): + dspy.settings.configure( + lm=DummyLM( + { + "What is 1+1?": {"answer": "2"}, + "What is 2+2?": {"answer": "4"}, + } + ) + ) + devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")] + program = Predict("question -> answer") + + def metric_with_axes(example, pred, ctx=None): + acc = subscore("acc", answer_exact_match(example, pred)) + latency = subscore("latency_s", (ctx.latency_ms or 0) / 1000 if ctx else 0, maximize=False, units="s") + return acc - 0.05 * latency + + ev = Evaluate( + devset=devset, + metric=metric_with_axes, + display_progress=False, + ) + + result = ev(program) + assert result.score == pytest.approx(100.0, abs=0.5) + for _, _, scores in result.results: + assert isinstance(scores, Score) + assert "acc" in scores.subscores + assert "latency_s" in scores.subscores + assert scores.info.get("expr") is not None @pytest.mark.extra @@ -67,9 +102,9 @@ def test_construct_result_df(): metric=answer_exact_match, ) results = [ - (devset[0], {"answer": "2"}, 100.0), - (devset[1], {"answer": "4"}, 100.0), - (devset[2], {"answer": "-1"}, 0.0), + (devset[0], {"answer": "2"}, Score(100.0)), + (devset[1], {"answer": "4"}, Score(100.0)), + (devset[2], {"answer": "-1"}, Score(0.0)), ] result_df = ev._construct_result_table(results, answer_exact_match.__name__) pd.testing.assert_frame_equal( @@ -98,6 +133,7 @@ def test_multithread_evaluate_call(): ) result = ev(program) assert result.score == 100.0 + assert all(isinstance(score_obj, Score) for *_, score_obj in result.results) def test_multi_thread_evaluate_call_cancelled(monkeypatch): @@ -259,5 +295,5 @@ def on_evaluate_end( assert callback.end_call_count == 1 def test_evaluation_result_repr(): - result = EvaluationResult(score=100.0, results=[(new_example("What is 1+1?", "2"), {"answer": "2"}, 100.0)]) + result = EvaluationResult(score=100.0, results=[(new_example("What is 1+1?", "2"), {"answer": "2"}, Score(100.0))]) assert repr(result) == "EvaluationResult(score=100.0, results=)" diff --git a/tests/metrics/test_subscores.py b/tests/metrics/test_subscores.py new file mode 100644 index 0000000000..d1639c4f0b --- /dev/null +++ b/tests/metrics/test_subscores.py @@ -0,0 +1,65 @@ +import math + +import pytest + +import 
dspy +from dspy.metrics import subscore +from dspy.metrics._subscores import _begin_collect, _end_collect, finalize_scores + + +def test_subscore_expression_resolution(): + token = _begin_collect() + try: + acc = subscore("acc", 0.8, bounds=(0, 1)) + lat = subscore("latency", 0.2, maximize=False, units="s") + result = acc - 0.5 * lat + finally: + collector = _end_collect(token) + + scores = finalize_scores(result, collector) + assert math.isclose(scores.scalar, 0.8 - 0.1) + assert scores.subscores == {"acc": 0.8, "latency": 0.2} + assert "expr" in scores.info and "latency" in scores.info["meta"] + + +def test_subscore_float_cast_keeps_subscores(): + token = _begin_collect() + try: + acc = subscore("acc", 1.0) + value = float(acc) + finally: + collector = _end_collect(token) + + scores = finalize_scores(value, collector) + assert scores.subscores == {"acc": 1.0} + assert scores.info["expr"] is None + + +def test_subscore_duplicate_name_raises(): + token = _begin_collect() + try: + subscore("acc", 1.0) + with pytest.raises(ValueError): + subscore("acc", 0.5) + finally: + _end_collect(token) + + +def test_prediction_callable_score_resolution(): + example = dspy.Example(question="What is 1+1?", answer="2").with_inputs("question") + prediction = dspy.Prediction(answer="2") + + def score_fn(ex, pred, ctx=None): + assert ex.answer == "2" + return subscore("acc", 1.0) + + prediction.score = score_fn + prediction.bind_example(example) + + with pytest.raises(TypeError): + float(prediction) + + scores = prediction.resolve_score() + assert isinstance(scores.subscores, dict) and scores.subscores["acc"] == 1.0 + assert math.isclose(prediction.score, 1.0) + assert math.isclose(float(prediction), 1.0)
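+
+
+def test_resolve_metric_score_sketch():
+    # Illustrative sketch of the `resolve_metric_score` helper: it runs a metric
+    # under the subscore collector and returns a (Score, metadata) pair. The metric
+    # below is hypothetical and exists only for this example.
+    from dspy.metrics import resolve_metric_score
+
+    def metric(example, pred):
+        return subscore("acc", 1.0) - 0.25
+
+    example = dspy.Example(question="What is 1+1?", answer="2").with_inputs("question")
+    prediction = dspy.Prediction(answer="2")
+
+    # Returning a raw subscore expression triggers a RuntimeWarning before it is
+    # resolved into a Score.
+    with pytest.warns(RuntimeWarning):
+        score, metadata = resolve_metric_score(metric, example, prediction, context="sketch")
+
+    assert math.isclose(score.scalar, 0.75)
+    assert score.subscores == {"acc": 1.0}
+    assert metadata == {}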