Skip to content

Commit

Permalink
handle exponents when converting str to float (#891)
Browse files Browse the repository at this point in the history
* handle exponents when converting str to float

* correct isnumeric

* Update CHANGELOG.md

---------

Co-authored-by: Charles Teague <[email protected]>
  • Loading branch information
jjallaire and dragonstyle authored Nov 25, 2024
1 parent 9b5ab2d commit 0a0532a
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Consistent behavior for `max_samples` across sandbox and non-sandbox evals (both now apply `max_samples` per task, formerly evals with sandboxes applied `max_samples` globally).
- Bash tool: add `--login` option so that e.g. .bashrc is read before executing the command.
- Google/Vertex: Support for `logprobs` and other new 1.5 (002 series) options.
- Handle exponents in numeric normalisation for match, include, and answer scorers.
- hf_dataset: added `cached` argument to control whether to use a previously cached version of the dataset if available (defaults to `True`).
- hf_dataset: added `revision` option to load a specific branch or commit SHA (when using `revision` datasets are always revalidated on Hugging Face, i.e. `cached` is ignored).
- Log viewer: display sample ids rather than indexes.
Expand Down
46 changes: 46 additions & 0 deletions src/inspect_ai/_util/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,49 @@ def truncate_string_to_bytes(input: str, max_bytes: int) -> TruncatedOutput | No
except Exception as ex:
logger.warning(f"Unexpected error occurred truncating string: {ex}")
return None


def str_to_float(s: str) -> float:
    """Convert a str to float, including handling superscript exponents.

    The Python isnumeric() function returns True for strings that include
    superscript exponents (e.g. 5²), however the float() function doesn't
    handle them. This function splits the string at the first superscript
    digit, parses what precedes it as the base and the superscript digits as
    an integer exponent, and returns base ** exponent.

    Args:
        s (str): String to convert to float

    Returns:
        float: Converted value. A string with no superscript digits is
            converted with float(); a string with no base (e.g. '²') uses a
            base of 1.0.

    Raises:
        ValueError: If the string is empty, the base is not a valid float, or
            anything other than superscript digits (ignoring trailing
            whitespace) follows the first superscript digit.
        OverflowError: If base ** exponent is too large for a float.
    """
    # handle empty input
    if not s:
        raise ValueError("Input string is empty.")

    # single source of truth for the superscript alphabet
    superscript_chars = "⁰¹²³⁴⁵⁶⁷⁸⁹"
    superscript_map = str.maketrans(superscript_chars, "0123456789")

    # split at the first superscript digit (or treat the whole string as base)
    split_at = next(
        (idx for idx, char in enumerate(s) if char in superscript_chars),
        len(s),
    )
    base_part = s[:split_at]
    # trailing whitespace is tolerated, mirroring what float()/int() accept
    exponent_part = s[split_at:].rstrip()

    # int() happily parses ASCII digits, underscores and other Unicode digits,
    # so without this check inputs such as '2²3' or '2²_3' would be read as
    # 2**23 instead of being rejected
    if any(char not in superscript_chars for char in exponent_part):
        raise ValueError(f"Invalid exponent in numeric string: {s!r}")

    # handle empty base (e.g., '²')
    base = float(base_part) if base_part else 1.0

    # default exponent is 1 if no superscript is present
    exponent = int(exponent_part.translate(superscript_map)) if exponent_part else 1

    return base**exponent
8 changes: 6 additions & 2 deletions src/inspect_ai/scorer/_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Callable, Literal

from inspect_ai._util.text import strip_numeric_punctuation, strip_punctuation
from inspect_ai._util.text import (
str_to_float,
strip_numeric_punctuation,
strip_punctuation,
)
from inspect_ai.solver._task_state import TaskState

from ._metric import CORRECT, INCORRECT, Score
Expand Down Expand Up @@ -96,7 +100,7 @@ def first_number_normalized(words: list[str]) -> str:

def normalize_number(number: str, precision: int = 5) -> str:
    """Normalise a numeric string to a fixed number of significant digits.

    Args:
        number (str): Candidate numeric string.
        precision (int): Significant digits used for the 'g' format.

    Returns:
        str: The value formatted as f".{precision}g" when the string can be
            converted to a float, otherwise the input unchanged.
    """
    if not number.replace(".", "").isnumeric():
        return number

    try:
        # str_to_float (rather than float) so superscript exponents like '5²'
        # are evaluated instead of raising
        num = str_to_float(number)
    except (ValueError, OverflowError):
        # isnumeric() is more permissive than the parser: strings such as
        # '1.2.3' or '½' pass the check above but are not convertible, and a
        # huge power can overflow. Treat these as non-numeric text.
        return number

    return format(num, f".{precision}g")
64 changes: 64 additions & 0 deletions tests/util/test_str_to_float.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest

from inspect_ai._util.text import str_to_float


def test_str_to_float_basic():
    """Integer bases with a single superscript digit, plus a plain integer."""
    expected_by_input = {
        "1²": 1.0,
        "2³": 8.0,
        "5⁴": 625.0,
        "10⁰": 1.0,
        "3": 3.0,
    }
    for text, expected in expected_by_input.items():
        assert str_to_float(text) == expected


def test_str_to_float_decimal_base():
    """A decimal base is raised to the superscript power."""
    squared = str_to_float("2.5²")
    cubed = str_to_float("0.1³")

    assert squared == 2.5**2
    assert cubed == 0.1**3


def test_str_to_float_negative_base():
    """The minus sign is part of the base: '-2²' is compared with (-2) ** 2."""
    for text, power in (("-2²", 2), ("-2³", 3)):
        assert str_to_float(text) == (-2) ** power


def test_str_to_float_multi_digit_exponent():
    """Several superscript digits in a row form one multi-digit exponent."""
    two_digit = str_to_float("2⁴⁵")
    assert two_digit == 2**45

    # Repeated zeros still mean an exponent of 0
    all_zeros = str_to_float("3⁰⁰⁰")
    assert all_zeros == 3**0


def test_str_to_float_no_exponent():
    """Strings without any superscript convert like a plain float."""
    for text, expected in [("7", 7.0), ("0", 0.0)]:
        assert str_to_float(text) == expected


def test_str_to_float_no_base():
    """A bare superscript is evaluated against a default base of 1.0."""
    for text, power in (("⁵", 5), ("⁰", 0)):
        assert str_to_float(text) == 1.0**power


def test_str_to_float_zero_exponent():
    """Anything to the power zero is 1.0, including a zero base."""
    nonzero_base = str_to_float("5⁰")
    zero_base = str_to_float("0⁰")

    assert nonzero_base == 1.0
    # 0^0 is considered 1 in this context
    assert zero_base == 1.0


def test_str_to_float_invalid_input():
    """Non-numeric, empty and unsupported-notation strings raise ValueError."""
    rejected = [
        "abc",
        "",
        "2^3",
        "⁺²",  # unsupported superscript characters
    ]
    for text in rejected:
        with pytest.raises(ValueError):
            str_to_float(text)


def test_str_to_float_edge_cases():
    """Unsupported characters are rejected; a decimal base with a superscript works."""
    unsupported = (
        "2⁻³",  # exponent with unsupported characters
        "a²",  # base with unsupported characters
    )
    for text in unsupported:
        with pytest.raises(ValueError):
            str_to_float(text)

    # Superscript directly after a decimal base
    result = str_to_float("2.5⁴")
    assert result == 2.5**4

0 comments on commit 0a0532a

Please sign in to comment.