Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track line and column data incrementally when parsing strings #71

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 54 additions & 36 deletions src/parsy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,17 @@
noop = lambda x: x


def line_info_at(stream, index):
if index > len(stream):
raise ValueError("invalid index")
line = stream.count("\n", 0, index)
last_nl = stream.rfind("\n", 0, index)
col = index - (last_nl + 1)
return (line, col)


class ParseError(RuntimeError):
def __init__(self, expected, stream, index):
self.expected = expected
self.stream = stream
self.index = index

def line_info(self):
try:
return "{}:{}".format(*line_info_at(self.stream, self.index))
except (TypeError, AttributeError): # not a str
return str(self.index)
if isinstance(self.stream, str):
return "{}:{}".format(self.index.line, self.index.column)
else:
return str(self.index.offset)

def __str__(self):
expected_list = sorted(repr(e) for e in self.expected)
Expand All @@ -42,30 +33,37 @@ def __str__(self):
return f"expected one of {', '.join(expected_list)} at {self.line_info()}"


@dataclass
@dataclass(frozen=True)
class Position:
offset: int
line: int
column: int


@dataclass(frozen=True)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made Result and Position frozen dataclasses, because it seems there's no reason to treat them mutably, and they ought to be hashable if they can be.

class Result:
status: bool
index: int
index: Position
value: Any
furthest: int
furthest: Position
expected: FrozenSet[str]

@staticmethod
def success(index, value):
return Result(True, index, value, -1, frozenset())
return Result(True, index, value, Position(-1, -1, -1), frozenset())

@staticmethod
def failure(index, expected):
return Result(False, -1, None, index, frozenset([expected]))
return Result(False, Position(-1, -1, -1), None, index, frozenset([expected]))

# collect the furthest failure from self and other
def aggregate(self, other):
if not other:
return self

if self.furthest > other.furthest:
if self.furthest.offset > other.furthest.offset:
return self
elif self.furthest == other.furthest:
elif self.furthest.offset == other.furthest.offset:
# if we both have the same failure index, we combine the expected messages.
return Result(self.status, self.index, self.value, self.furthest, self.expected | other.expected)
else:
Expand All @@ -83,14 +81,14 @@ class Parser:
of the failure.
"""

def __init__(self, wrapped_fn: Callable[[str | bytes | list, int], Result]):
def __init__(self, wrapped_fn: Callable[[str | bytes | list, Position], Result]):
"""
Creates a new Parser from a function that takes a stream
and returns a Result.
"""
self.wrapped_fn = wrapped_fn

def __call__(self, stream: str | bytes | list, index: int):
def __call__(self, stream: str | bytes | list, index: Position):
return self.wrapped_fn(stream, index)

def parse(self, stream: str | bytes | list) -> Any:
Expand All @@ -104,10 +102,10 @@ def parse_partial(self, stream: str | bytes | list) -> tuple[Any, str | bytes |
Returns a tuple of the result and the unparsed remainder,
or raises ParseError
"""
result = self(stream, 0)
result = self(stream, Position(0, 0, 0) if isinstance(stream, str) else Position(0, -1, -1))

if result.status:
return (result.value, stream[result.index :])
return (result.value, stream[result.index.offset :])
else:
raise ParseError(result.expected, stream, result.furthest)

Expand Down Expand Up @@ -268,7 +266,6 @@ def until_parser(stream, index):
values = []
times = 0
while True:

# try parser first
res = other(stream, index)
if res.status and times >= min:
Expand Down Expand Up @@ -497,8 +494,8 @@ def generated(stream, index):
return generated


index = Parser(lambda _, index: Result.success(index, index))
line_info = Parser(lambda stream, index: Result.success(index, line_info_at(stream, index)))
index = Parser(lambda _, index: Result.success(index, index.offset))
line_info = Parser(lambda _, index: Result.success(index, (index.line, index.column)))


def success(value: Any) -> Parser:
Expand All @@ -516,6 +513,17 @@ def fail(expected: str) -> Parser:
return Parser(lambda _, index: Result.failure(index, expected))


def make_index_update(consumed: str) -> Callable[[Position], Position]:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is curried to avoid recomputing the count and rfind methods every single time someone uses the string parser

slen = len(consumed)
line_count = consumed.count("\n")
last_nl = consumed.rfind("\n")
return lambda index: Position(
offset=index.offset + slen,
line=index.line + line_count,
column=slen - (last_nl + 1) if last_nl >= 0 else index.column + slen,
)


def string(expected_string: str, transform: Callable[[str], str] = noop) -> Parser:
"""
Returns a parser that expects the ``expected_string`` and produces
Expand All @@ -527,11 +535,12 @@ def string(expected_string: str, transform: Callable[[str], str] = noop) -> Pars

slen = len(expected_string)
transformed_s = transform(expected_string)
index_update = make_index_update(expected_string)

@Parser
def string_parser(stream, index):
if transform(stream[index : index + slen]) == transformed_s:
return Result.success(index + slen, expected_string)
if transform(stream[index.offset : index.offset + slen]) == transformed_s:
return Result.success(index_update(index), expected_string)
else:
return Result.failure(index, expected_string)

Expand All @@ -557,9 +566,14 @@ def regex(exp: str, flags=0, group: int | str | tuple = 0) -> Parser:

@Parser
def regex_parser(stream, index):
match = exp.match(stream, index)
match = exp.match(stream, index.offset)
if match:
return Result.success(match.end(), match.group(*group))
index = (
make_index_update(stream[match.start() : match.end()])(index)
if isinstance(stream, str)
else Position(match.end(), -1, -1)
)
return Result.success(index, match.group(*group))
else:
return Result.failure(index, exp.pattern)

Expand All @@ -576,15 +590,19 @@ def test_item(func: Callable[..., bool], description: str) -> Parser:

@Parser
def test_item_parser(stream, index):
if index < len(stream):
if index.offset < len(stream):
if isinstance(stream, bytes):
# Subscripting bytes with `[index]` instead of
# `[index:index + 1]` returns an int
item = stream[index : index + 1]
item = stream[index.offset : index.offset + 1]
else:
item = stream[index]
item = stream[index.offset]
if func(item):
return Result.success(index + 1, item)
if isinstance(stream, str):
index = make_index_update(item)(index)
else:
index = Position(index.offset + 1, index.line, index.column)
return Result.success(index, item)
return Result.failure(index, description)

return test_item_parser
Expand Down Expand Up @@ -668,7 +686,7 @@ def eof(stream, index):
A parser that only succeeds if the end of the stream has been reached.
"""

if index >= len(stream):
if index.offset >= len(stream):
return Result.success(index, None)
else:
return Result.failure(index, "EOF")
Expand Down
24 changes: 4 additions & 20 deletions tests/test_parsy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from parsy import (
ParseError,
Position,
alt,
any_char,
char_from,
Expand All @@ -18,7 +19,6 @@
index,
letter,
line_info,
line_info_at,
match_item,
peek,
regex,
Expand Down Expand Up @@ -225,7 +225,7 @@ def thing():

self.assertEqual(ex.expected, frozenset(["a thing"]))
self.assertEqual(ex.stream, "x")
self.assertEqual(ex.index, 0)
self.assertEqual(ex.index, Position(0, 0, 0))

def test_generate_default_desc(self):
# We shouldn't give a default desc, the messages from the internal
Expand All @@ -242,7 +242,7 @@ def thing():

self.assertEqual(ex.expected, frozenset(["b"]))
self.assertEqual(ex.stream, "ax")
self.assertEqual(ex.index, 1)
self.assertEqual(ex.index, Position(1, 0, 1))

self.assertIn("expected 'b' at 0:1", str(ex))

Expand Down Expand Up @@ -345,7 +345,6 @@ def test_at_most(self):
self.assertRaises(ParseError, ab.at_most(2).parse, "ababab")

def test_until(self):

until = string("s").until(string("x"))

s = "ssssx"
Expand All @@ -367,7 +366,6 @@ def test_until(self):
self.assertEqual(until.parse_partial("xxxx"), ([], "xxxx"))

def test_until_with_consume_other(self):

until = string("s").until(string("x"), consume_other=True)

self.assertEqual(until.parse("ssssx"), 4 * ["s"] + ["x"])
Expand All @@ -379,7 +377,6 @@ def test_until_with_consume_other(self):
self.assertRaises(ParseError, until.parse, "xssssxy")

def test_until_with_min(self):

until = string("s").until(string("x"), min=3)

self.assertEqual(until.parse_partial("sssx"), (3 * ["s"], "x"))
Expand All @@ -388,7 +385,6 @@ def test_until_with_min(self):
self.assertRaises(ParseError, until.parse_partial, "ssx")

def test_until_with_max(self):

# until with max
until = string("s").until(string("x"), max=3)

Expand All @@ -398,7 +394,6 @@ def test_until_with_max(self):
self.assertRaises(ParseError, until.parse_partial, "ssssx")

def test_until_with_min_max(self):

until = string("s").until(string("x"), min=3, max=5)

self.assertEqual(until.parse_partial("sssx"), (3 * ["s"], "x"))
Expand Down Expand Up @@ -647,7 +642,7 @@ def test_test_item(self):
ex = err.exception
self.assertEqual(str(ex), "expected one of 'EOF', 'START/STOP' at 1")
self.assertEqual(ex.expected, {"EOF", "START/STOP"})
self.assertEqual(ex.index, 1)
self.assertEqual(ex.index, Position(1, -1, -1))

def test_match_item(self):
self.assertEqual(match_item(self.START).parse([self.START]), self.START)
Expand Down Expand Up @@ -675,17 +670,6 @@ def foo():
self.assertEqual(foo.many().parse(["A", "B"]), [("A", 0), ("B", 1)])


class TestUtils(unittest.TestCase):
def test_line_info_at(self):
text = "abc\ndef"
self.assertEqual(line_info_at(text, 0), (0, 0))
self.assertEqual(line_info_at(text, 2), (0, 2))
self.assertEqual(line_info_at(text, 3), (0, 3))
self.assertEqual(line_info_at(text, 4), (1, 0))
self.assertEqual(line_info_at(text, 7), (1, 3))
self.assertRaises(ValueError, lambda: line_info_at(text, 8))


class TestForwardDeclaration(unittest.TestCase):
def test_forward_declaration_1(self):
# This is the example from the docs
Expand Down