diff --git a/CHANGELOG.md b/CHANGELOG.md index ac9ed0c5b..2af9cae56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - Consistent behavior for `max_samples` across sandbox and non-sandbox evals (both now apply `max_samples` per task, formerly evals with sandboxes applied `max_samples` globally). - Bash tool: add `--login` option so that e.g. .bashrc is read before executing the command. - Google/Vertex: Support for `logprobs` and other new 1.5 (002 series) options. +- Handle exponents in numeric normalisation for match, include, and answer scorers. - hf_dataset: added `cached` argument to control whether to use a previously cached version of the dataset if available (defaults to `True`). - hf_dataset: added `revision` option to load a specific branch or commit SHA (when using `revision` datasets are always revalidated on Hugging Face, i.e. `cached` is ignored). - Log viewer: display sample ids rather than indexes. diff --git a/src/inspect_ai/_util/text.py b/src/inspect_ai/_util/text.py index 8a2a79ade..0afaabfae 100644 --- a/src/inspect_ai/_util/text.py +++ b/src/inspect_ai/_util/text.py @@ -62,3 +62,49 @@ def truncate_string_to_bytes(input: str, max_bytes: int) -> TruncatedOutput | No except Exception as ex: logger.warning(f"Unexpected error occurred truncating string: {ex}") return None + + +def str_to_float(s: str) -> float: + """Convert a str to float, including handling exponent characters. + + The Python isnumeric() function returns True for strings that include exponents + (e.g. 5²) however the float() function doesn't handle exponents. This function + will correctly handle these exponents when converting from str to float. + + Args: + s (str): String to convert to float + + Returns: + float: Converted value + + Raises: + ValueError: If the string is not a valid numeric value. + """ + # handle empty input + if not s: + raise ValueError("Input string is empty.") + + superscript_map = str.maketrans("⁰¹²³⁴⁵⁶⁷⁸⁹", "0123456789") + superscript_chars = "⁰¹²³⁴⁵⁶⁷⁸⁹" + + base_part = "" + exponent_part = "" + for idx, char in enumerate(s): + if char in superscript_chars: + base_part = s[:idx] + exponent_part = s[idx:] + break + else: + base_part = s + + # handle empty base (e.g., '²') + base = float(base_part) if base_part else 1.0 + + # handle exponent part + if exponent_part: + exponent_str = exponent_part.translate(superscript_map) + exponent = int(exponent_str) + else: + exponent = 1 # Default exponent is 1 if no superscript is present + + return base**exponent diff --git a/src/inspect_ai/scorer/_common.py b/src/inspect_ai/scorer/_common.py index 50fe816cd..668c0574d 100644 --- a/src/inspect_ai/scorer/_common.py +++ b/src/inspect_ai/scorer/_common.py @@ -1,6 +1,10 @@ from typing import Callable, Literal -from inspect_ai._util.text import strip_numeric_punctuation, strip_punctuation +from inspect_ai._util.text import ( + str_to_float, + strip_numeric_punctuation, + strip_punctuation, +) from inspect_ai.solver._task_state import TaskState from ._metric import CORRECT, INCORRECT, Score @@ -96,7 +100,7 @@ def first_number_normalized(words: list[str]) -> str: def normalize_number(number: str, precision: int = 5) -> str: if number.replace(".", "").isnumeric(): - num = float(number) + num = str_to_float(number) return format(num, f".{precision}g") else: return number diff --git a/tests/util/test_str_to_float.py b/tests/util/test_str_to_float.py new file mode 100644 index 000000000..467d759a9 --- /dev/null +++ b/tests/util/test_str_to_float.py @@ -0,0 +1,64 @@ +import pytest + +from inspect_ai._util.text import str_to_float + + +def test_str_to_float_basic(): + assert str_to_float("1²") == 1.0 + assert str_to_float("2³") == 8.0 + assert str_to_float("5⁴") == 625.0 + assert str_to_float("10⁰") == 1.0 + assert str_to_float("3") == 3.0 + + +def test_str_to_float_decimal_base(): + assert str_to_float("2.5²") == 2.5**2 + assert str_to_float("0.1³") == 0.1**3 + + +def test_str_to_float_negative_base(): + assert str_to_float("-2²") == (-2) ** 2 + assert str_to_float("-2³") == (-2) ** 3 + + +def test_str_to_float_multi_digit_exponent(): + assert str_to_float("2⁴⁵") == 2**45 + assert str_to_float("3⁰⁰⁰") == 3**0 # Exponent is 0 + + +def test_str_to_float_no_exponent(): + assert str_to_float("7") == 7.0 + assert str_to_float("0") == 0.0 + + +def test_str_to_float_no_base(): + # When the base is missing, default to 1.0 + assert str_to_float("⁵") == 1.0**5 + assert str_to_float("⁰") == 1.0**0 + + +def test_str_to_float_zero_exponent(): + assert str_to_float("5⁰") == 1.0 + assert str_to_float("0⁰") == 1.0 # 0^0 is considered 1 in this context + + +def test_str_to_float_invalid_input(): + with pytest.raises(ValueError): + str_to_float("abc") + with pytest.raises(ValueError): + str_to_float("") + with pytest.raises(ValueError): + str_to_float("2^3") + with pytest.raises(ValueError): + str_to_float("⁺²") # Unsupported superscript characters + + +def test_str_to_float_edge_cases(): + # Exponent with unsupported characters + with pytest.raises(ValueError): + str_to_float("2⁻³") + # Base with unsupported characters + with pytest.raises(ValueError): + str_to_float("a²") + # Superscript after decimal point + assert str_to_float("2.5⁴") == 2.5**4