Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_source_segment and support trailing comments in functions #19

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions ast_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _enrich(source: Union[str, bytes], tree: ast.AST) -> None:
if not comment_nodes:
return

tree_intervals = _get_tree_intervals(tree)
tree_intervals = _get_tree_intervals(source, tree)
for c_node in comment_nodes:
c_lineno = c_node.lineno
possible_intervals_for_c_node = [
Expand Down Expand Up @@ -99,9 +99,11 @@ def _enrich(source: Union[str, bytes], tree: ast.AST) -> None:
for left, right in zip(attr[:-1], attr[1:]):
if isinstance(left, Comment) and isinstance(right, Comment):
right.inline = False

target_node.end_lineno = c_node.end_lineno
target_node.end_col_offset = c_node.end_col_offset

def _get_tree_intervals(
source: str,
node: ast.AST,
) -> Dict[Tuple[int, int], Dict[str, Union[List[Tuple[int, int]], ast.AST]]]:
res = {}
Expand All @@ -119,6 +121,12 @@ def _get_tree_intervals(
if hasattr(node, "end_lineno")
else max(attr_intervals)[1]
)
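# Extend the interval over the comment lines that directly follow the node.
# NOTE: indentation is not checked, so comments of an outer scope are absorbed too.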
for line in source.splitlines()[high:]:
if line.strip().startswith("#"):
high += 1
else:
break
res[(low, high)] = {"intervals": attr_intervals, "node": node}
return res

Expand Down Expand Up @@ -173,3 +181,44 @@ def _get_first_not_comment_idx(orelse: list[ast.stmt]) -> int:

def unparse(ast_obj: ast.AST) -> str:
    """Render *ast_obj* by visiting it with a freshly created ``_Unparser``.

    Returns whatever ``_Unparser.visit`` produces for the given node.
    """
    unparser = _Unparser()
    return unparser.visit(ast_obj)


def get_source_segment(source, node, *, padded=False):
    """Get source code segment of the *source* that generated *node*.

    Customized version of ``ast.get_source_segment`` that also includes a
    trailing comment when the last element of ``node.body`` ends after the
    position recorded on *node* itself.

    If some location information (`lineno`, `end_lineno`, `col_offset`,
    or `end_col_offset`) is missing, return None.

    If *padded* is `True`, the first line of a multi-line statement will
    be padded with spaces to match its original position.

    Nodes without a statement-list ``body`` (plain expressions, or nodes such
    as ``Lambda``/``IfExp`` whose ``body`` is a single expression) and nodes
    with an empty body are handled like ``ast.get_source_segment`` does.
    """
    try:
        if node.end_lineno is None or node.end_col_offset is None:
            return None
        lineno = node.lineno - 1
        end_lineno = node.end_lineno - 1
        col_offset = node.col_offset
        end_col_offset = node.end_col_offset
    except AttributeError:
        return None

    # Add trailing comment: only a non-empty list body can carry one.
    body = getattr(node, "body", None)
    if isinstance(body, list) and body:
        tail = body[-1]
        tail_end_lineno = getattr(tail, "end_lineno", None)
        tail_end_col_offset = getattr(tail, "end_col_offset", None)
        if tail_end_lineno is not None and tail_end_col_offset is not None:
            # Compare full (line, column) positions: a column offset is only
            # meaningful together with the line it belongs to.
            tail_end = (tail_end_lineno - 1, tail_end_col_offset)
            if tail_end > (end_lineno, end_col_offset):
                end_lineno, end_col_offset = tail_end

    lines = ast._splitlines_no_ff(source)
    # Offsets are byte offsets into the UTF-8 encoded line, hence encode/decode.
    if end_lineno == lineno:
        return lines[lineno].encode()[col_offset:end_col_offset].decode()

    if padded:
        padding = ast._pad_whitespace(lines[lineno].encode()[:col_offset].decode())
    else:
        padding = ''

    first = padding + lines[lineno].encode()[col_offset:].decode()
    last = lines[end_lineno].encode()[:end_col_offset].decode()
    lines = lines[lineno + 1:end_lineno]

    lines.insert(0, first)
    lines.append(last)
    return ''.join(lines)
53 changes: 52 additions & 1 deletion test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from ast_comments import Comment, parse
from ast_comments import Comment, get_source_segment, parse


def test_single_comment_in_tree():
Expand Down Expand Up @@ -304,3 +304,54 @@ def test_comment_in_multilined_list():
"""
)
assert len(parse(source).body) == 1


def test_function_with_trailing_comment():
    """Function with trailing comments inside."""
    # The body lines are indented so that the comments sit inside 'foo'.
    source = dedent(
        """
        def foo(*args, **kwargs):
            print(args, kwargs) # comment to print
            # A comment
            # comment in function 'foo'
        """
    )
    nodes = parse(source).body
    # Everything belongs to the function, so only one top-level node remains.
    assert len(nodes) == 1
    function_node = nodes[0]
    # body[0] is the print statement; its same-line comment follows as body[1].
    assert function_node.body[1].value == "# comment to print"
    assert function_node.body[1].inline
    # The last comment stands on a line of its own.
    assert function_node.body[-1].value == "# comment in function 'foo'"
    assert not function_node.body[-1].inline


def test_get_source_segment():
    """Check that get_source_segment roundtrips function code."""
    # The body lines are indented so that the comments sit inside 'foo'.
    source = dedent(
        """
        def foo(*args, **kwargs):
            print(args, kwargs) # comment to print
            # A comment
            # comment in function 'foo'
        """
    )
    function_node = parse(source).body[0]
    # strip() drops the leading/trailing newlines of the triple-quoted literal.
    assert source.strip() == get_source_segment(source, function_node)


@pytest.mark.xfail(reason="Skipping extraneous comments doesn't work.")
def test_get_source_segment_outside_comment():
    """Check that get_source_segment skips extraneous comments."""
    function_source = dedent(
        """
        def foo(*args, **kwargs):
            print(args, kwargs) # comment to print
            # A comment
            # comment in function 'foo'
        """
    )
    # A comment at column 0 directly after the function: it is not part of 'foo'.
    source = function_source + "# comment outside function 'foo'\n"
    function_node = parse(source).body[0]
    assert function_node.body[-1].value == "# comment in function 'foo'"
    assert not function_node.body[-1].inline
    # The segment must stop at the last comment inside the function, so it is
    # compared with the function text only, not with the whole source.
    assert function_source.strip() == get_source_segment(source, function_node)