diff --git a/src/parsy/__init__.py b/src/parsy/__init__.py index bb105c5..fc0b921 100644 --- a/src/parsy/__init__.py +++ b/src/parsy/__init__.py @@ -12,15 +12,6 @@ noop = lambda x: x -def line_info_at(stream, index): - if index > len(stream): - raise ValueError("invalid index") - line = stream.count("\n", 0, index) - last_nl = stream.rfind("\n", 0, index) - col = index - (last_nl + 1) - return (line, col) - - class ParseError(RuntimeError): def __init__(self, expected, stream, index): self.expected = expected @@ -28,10 +19,10 @@ def __init__(self, expected, stream, index): self.index = index def line_info(self): - try: - return "{}:{}".format(*line_info_at(self.stream, self.index)) - except (TypeError, AttributeError): # not a str - return str(self.index) + if isinstance(self.stream, str): + return "{}:{}".format(self.index.line, self.index.column) + else: + return str(self.index.offset) def __str__(self): expected_list = sorted(repr(e) for e in self.expected) @@ -42,30 +33,37 @@ def __str__(self): return f"expected one of {', '.join(expected_list)} at {self.line_info()}" -@dataclass +@dataclass(frozen=True) +class Position: + offset: int + line: int + column: int + + +@dataclass(frozen=True) class Result: status: bool - index: int + index: Position value: Any - furthest: int + furthest: Position expected: FrozenSet[str] @staticmethod def success(index, value): - return Result(True, index, value, -1, frozenset()) + return Result(True, index, value, Position(-1, -1, -1), frozenset()) @staticmethod def failure(index, expected): - return Result(False, -1, None, index, frozenset([expected])) + return Result(False, Position(-1, -1, -1), None, index, frozenset([expected])) # collect the furthest failure from self and other def aggregate(self, other): if not other: return self - if self.furthest > other.furthest: + if self.furthest.offset > other.furthest.offset: return self - elif self.furthest == other.furthest: + elif self.furthest.offset == other.furthest.offset: # if we both have the same failure index, we combine the expected messages. return Result(self.status, self.index, self.value, self.furthest, self.expected | other.expected) else: @@ -83,14 +81,14 @@ class Parser: of the failure. """ - def __init__(self, wrapped_fn: Callable[[str | bytes | list, int], Result]): + def __init__(self, wrapped_fn: Callable[[str | bytes | list, Position], Result]): """ Creates a new Parser from a function that takes a stream and returns a Result. """ self.wrapped_fn = wrapped_fn - def __call__(self, stream: str | bytes | list, index: int): + def __call__(self, stream: str | bytes | list, index: Position): return self.wrapped_fn(stream, index) def parse(self, stream: str | bytes | list) -> Any: @@ -104,10 +102,10 @@ def parse_partial(self, stream: str | bytes | list) -> tuple[Any, str | bytes | Returns a tuple of the result and the unparsed remainder, or raises ParseError """ - result = self(stream, 0) + result = self(stream, Position(0, 0, 0) if isinstance(stream, str) else Position(0, -1, -1)) if result.status: - return (result.value, stream[result.index :]) + return (result.value, stream[result.index.offset :]) else: raise ParseError(result.expected, stream, result.furthest) @@ -268,7 +266,6 @@ def until_parser(stream, index): values = [] times = 0 while True: - # try parser first res = other(stream, index) if res.status and times >= min: @@ -497,8 +494,8 @@ def generated(stream, index): return generated -index = Parser(lambda _, index: Result.success(index, index)) -line_info = Parser(lambda stream, index: Result.success(index, line_info_at(stream, index))) +index = Parser(lambda _, index: Result.success(index, index.offset)) +line_info = Parser(lambda _, index: Result.success(index, (index.line, index.column))) def success(value: Any) -> Parser: @@ -516,6 +513,17 @@ def fail(expected: str) -> Parser: return Parser(lambda _, index: Result.failure(index, expected)) +def make_index_update(consumed: str) -> Callable[[Position], Position]: + slen = len(consumed) + line_count = consumed.count("\n") + last_nl = consumed.rfind("\n") + return lambda index: Position( + offset=index.offset + slen, + line=index.line + line_count, + column=slen - (last_nl + 1) if last_nl >= 0 else index.column + slen, + ) + + def string(expected_string: str, transform: Callable[[str], str] = noop) -> Parser: """ Returns a parser that expects the ``expected_string`` and produces @@ -527,11 +535,12 @@ def string(expected_string: str, transform: Callable[[str], str] = noop) -> Pars slen = len(expected_string) transformed_s = transform(expected_string) + index_update = make_index_update(expected_string) @Parser def string_parser(stream, index): - if transform(stream[index : index + slen]) == transformed_s: - return Result.success(index + slen, expected_string) + if transform(stream[index.offset : index.offset + slen]) == transformed_s: + return Result.success(index_update(index), expected_string) else: return Result.failure(index, expected_string) @@ -557,9 +566,14 @@ def regex(exp: str, flags=0, group: int | str | tuple = 0) -> Parser: @Parser def regex_parser(stream, index): - match = exp.match(stream, index) + match = exp.match(stream, index.offset) if match: - return Result.success(match.end(), match.group(*group)) + index = ( + make_index_update(stream[match.start() : match.end()])(index) + if isinstance(stream, str) + else Position(match.end(), -1, -1) + ) + return Result.success(index, match.group(*group)) else: return Result.failure(index, exp.pattern) @@ -576,15 +590,19 @@ def test_item(func: Callable[..., bool], description: str) -> Parser: @Parser def test_item_parser(stream, index): - if index < len(stream): + if index.offset < len(stream): if isinstance(stream, bytes): # Subscripting bytes with `[index]` instead of # `[index:index + 1]` returns an int - item = stream[index : index + 1] + item = stream[index.offset : index.offset + 1] else: - item = stream[index] + item = stream[index.offset] if func(item): - return Result.success(index + 1, item) + if isinstance(stream, str): + index = make_index_update(item)(index) + else: + index = Position(index.offset + 1, index.line, index.column) + return Result.success(index, item) return Result.failure(index, description) return test_item_parser @@ -668,7 +686,7 @@ def eof(stream, index): A parser that only succeeds if the end of the stream has been reached. """ - if index >= len(stream): + if index.offset >= len(stream): return Result.success(index, None) else: return Result.failure(index, "EOF") diff --git a/tests/test_parsy.py b/tests/test_parsy.py index ba08f56..14d8eeb 100644 --- a/tests/test_parsy.py +++ b/tests/test_parsy.py @@ -7,6 +7,7 @@ from parsy import ( ParseError, + Position, alt, any_char, char_from, @@ -18,7 +19,6 @@ index, letter, line_info, - line_info_at, match_item, peek, regex, @@ -225,7 +225,7 @@ def thing(): self.assertEqual(ex.expected, frozenset(["a thing"])) self.assertEqual(ex.stream, "x") - self.assertEqual(ex.index, 0) + self.assertEqual(ex.index, Position(0, 0, 0)) def test_generate_default_desc(self): # We shouldn't give a default desc, the messages from the internal @@ -242,7 +242,7 @@ def thing(): self.assertEqual(ex.expected, frozenset(["b"])) self.assertEqual(ex.stream, "ax") - self.assertEqual(ex.index, 1) + self.assertEqual(ex.index, Position(1, 0, 1)) self.assertIn("expected 'b' at 0:1", str(ex)) @@ -345,7 +345,6 @@ def test_at_most(self): self.assertRaises(ParseError, ab.at_most(2).parse, "ababab") def test_until(self): - until = string("s").until(string("x")) s = "ssssx" @@ -367,7 +366,6 @@ def test_until(self): self.assertEqual(until.parse_partial("xxxx"), ([], "xxxx")) def test_until_with_consume_other(self): - until = string("s").until(string("x"), consume_other=True) self.assertEqual(until.parse("ssssx"), 4 * ["s"] + ["x"]) @@ -379,7 +377,6 @@ def test_until_with_consume_other(self): self.assertRaises(ParseError, until.parse, "xssssxy") def test_until_with_min(self): - until = string("s").until(string("x"), min=3) self.assertEqual(until.parse_partial("sssx"), (3 * ["s"], "x")) @@ -388,7 +385,6 @@ def test_until_with_min(self): self.assertRaises(ParseError, until.parse_partial, "ssx") def test_until_with_max(self): - # until with max until = string("s").until(string("x"), max=3) @@ -398,7 +394,6 @@ def test_until_with_max(self): self.assertRaises(ParseError, until.parse_partial, "ssssx") def test_until_with_min_max(self): - until = string("s").until(string("x"), min=3, max=5) self.assertEqual(until.parse_partial("sssx"), (3 * ["s"], "x")) @@ -647,7 +642,7 @@ def test_test_item(self): ex = err.exception self.assertEqual(str(ex), "expected one of 'EOF', 'START/STOP' at 1") self.assertEqual(ex.expected, {"EOF", "START/STOP"}) - self.assertEqual(ex.index, 1) + self.assertEqual(ex.index, Position(1, -1, -1)) def test_match_item(self): self.assertEqual(match_item(self.START).parse([self.START]), self.START) @@ -675,17 +670,6 @@ def foo(): self.assertEqual(foo.many().parse(["A", "B"]), [("A", 0), ("B", 1)]) -class TestUtils(unittest.TestCase): - def test_line_info_at(self): - text = "abc\ndef" - self.assertEqual(line_info_at(text, 0), (0, 0)) - self.assertEqual(line_info_at(text, 2), (0, 2)) - self.assertEqual(line_info_at(text, 3), (0, 3)) - self.assertEqual(line_info_at(text, 4), (1, 0)) - self.assertEqual(line_info_at(text, 7), (1, 3)) - self.assertRaises(ValueError, lambda: line_info_at(text, 8)) - - class TestForwardDeclaration(unittest.TestCase): def test_forward_declaration_1(self): # This is the example from the docs