Skip to content

Commit

Permalink
more annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Carreau committed Nov 4, 2024
1 parent a64a567 commit 12e4b40
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 38 deletions.
19 changes: 13 additions & 6 deletions lib/python/pyflyby/_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import subprocess
import sys

from typing import List, Any, Dict
from typing import List, Any, Dict, Union, Literal


from pyflyby._autoimp import (LoadSymbolError, ScopeStack, auto_eval,
Expand Down Expand Up @@ -2269,7 +2269,9 @@ def disable(self):

def _safe_call(self, function, *args, **kwargs):
on_error = kwargs.pop("on_error", None)
raise_on_error = kwargs.pop("raise_on_error", "if_debug")
raise_on_error: Union[bool, Literal["if_debug"]] = kwargs.pop(
"raise_on_error", "if_debug"
)
if self._errored:
# If we previously errored, then we should already have
# unregistered the hook that led to here. However, in some corner
Expand All @@ -2295,7 +2297,7 @@ def _safe_call(self, function, *args, **kwargs):
logger.error("Error trying to disable: %s: %s",
type(e2).__name__, e2)
# Raise or print traceback in debug mode.
if raise_on_error == True:
if raise_on_error is True:
raise
elif raise_on_error == 'if_debug':
if logger.debug_enabled:
Expand All @@ -2305,7 +2307,7 @@ def _safe_call(self, function, *args, **kwargs):
import traceback
traceback.print_exc()
raise
elif raise_on_error == False:
elif raise_on_error is False:
if logger.debug_enabled:
import traceback
traceback.print_exc()
Expand All @@ -2328,8 +2330,13 @@ def reset_state_new_cell(self):
sorted([k for k,v in autoimported.items() if not v]))
self._autoimported_this_cell = {}

def auto_import(self, arg, namespaces=None,
raise_on_error='if_debug', on_error=None):
def auto_import(
self,
arg,
namespaces=None,
raise_on_error: Union[bool, Literal["if_debug"]] = "if_debug",
on_error=None,
):
if namespaces is None:
namespaces = get_global_namespaces(self._ip)

Expand Down
34 changes: 17 additions & 17 deletions lib/python/pyflyby/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sys
from textwrap import dedent
import types
from typing import Any, List, Optional, Tuple, Union, cast
from typing import Any, List, Optional, Tuple, Union, cast, Literal
import warnings


Expand Down Expand Up @@ -353,7 +353,7 @@ def _annotate_ast_nodes(ast_node: ast.AST) -> AnnotatedAst:


def _annotate_ast_startpos(
ast_node: ast.AST, parent_ast_node, minpos, text, flags
ast_node: ast.AST, parent_ast_node, minpos: FilePos, text: FileText, flags
) -> bool:
r"""
Annotate ``ast_node``. Set ``ast_node.startpos`` to the starting position
Expand Down Expand Up @@ -410,8 +410,8 @@ def _annotate_ast_startpos(
# Walk all nodes/fields of the AST. We implement this as a custom
# depth-first search instead of using ast.walk() or ast.NodeVisitor
# so that we can easily keep track of the preceding node's lineno.
child_minpos = minpos
is_first_child = True
child_minpos: FilePos = minpos
is_first_child: bool = True
leftstr_node = None
for child_node in _iter_child_nodes_in_order(aast_node):
leftstr = _annotate_ast_startpos(
Expand Down Expand Up @@ -668,7 +668,7 @@ def _ast_node_is_in_docstring_position(ast_node):
return False


def infer_compile_mode(arg):
def infer_compile_mode(arg:ast.AST) -> Literal['exec','eval','single']:
"""
Infer the mode needed to compile ``arg``.
Expand All @@ -679,19 +679,18 @@ def infer_compile_mode(arg):
"""
# Infer mode from ast object.
if isinstance(arg, ast.Module):
mode = "exec"
return "exec"
elif isinstance(arg, ast.Expression):
mode = "eval"
return "eval"
elif isinstance(arg, ast.Interactive):
mode = "single"
return "single"
else:
raise TypeError(
"Expected Module/Expression/Interactive ast node; got %s"
% (type(arg).__name__))
return mode


class _DummyAst_Node(object):
class _DummyAst_Node:
pass


Expand Down Expand Up @@ -945,8 +944,9 @@ def from_filename(cls, filename):
return cls.from_text(Filename(filename))

@classmethod
def from_text(cls, text, filename=None, startpos=None, flags=None,
auto_flags=False):
def from_text(
cls, text, filename=None, startpos=None, flags=None, auto_flags: bool = False
):
"""
:type text:
`FileText` or convertible
Expand Down Expand Up @@ -1144,7 +1144,7 @@ def expression_ast_node(self) -> Optional[ast.Expression]:
else:
return None

def parse(self, mode=None) -> Union[ast.Expression, ast.Module]:
def parse(self, mode: Optional[str] = None) -> Union[ast.Expression, ast.Module]:
"""
Parse the source text into an AST.
Expand All @@ -1165,7 +1165,7 @@ def parse(self, mode=None) -> Union[ast.Expression, ast.Module]:
return self.expression_ast_node
else:
raise SyntaxError
elif mode == None:
elif mode is None:
if self.expression_ast_node:
return self.expression_ast_node
else:
Expand All @@ -1176,17 +1176,17 @@ def parse(self, mode=None) -> Union[ast.Expression, ast.Module]:
else:
raise ValueError("parse(): invalid mode=%r" % (mode,))

def compile(self, mode=None):
def compile(self, mode: Optional[str] = None):
"""
Parse into AST and compile AST into code.
:rtype:
``CodeType``
"""
ast_node = self.parse(mode=mode)
mode = infer_compile_mode(ast_node)
c_mode = infer_compile_mode(ast_node)
filename = str(self.filename or "<unknown>")
return compile(ast_node, filename, mode)
return compile(ast_node, filename, c_mode)

@cached_property
def statements(self) -> Tuple[PythonStatement, ...]:
Expand Down
26 changes: 13 additions & 13 deletions lib/python/pyflyby/_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,19 +286,6 @@
from pyflyby._util import cmp
from shlex import quote as shquote

usage = """
py --- command-line python multitool with automatic importing
$ py [--file] filename.py arg1 arg2 Execute file
$ py [--apply] function arg1 arg2 Call function
$ py [--eval] 'function(arg1, arg2)' Evaluate code
$ py [--module] modname arg1 arg2 Run a module
$ py --debug file/code... args... Debug code
$ py --debug PID Attach debugger to PID
$ py IPython shell
""".strip()

# TODO: add --tidy-imports, etc

Expand Down Expand Up @@ -377,6 +364,19 @@
from pyflyby._parse import PythonBlock
from pyflyby._util import indent, prefixes

usage = """
py --- command-line python multitool with automatic importing
$ py [--file] filename.py arg1 arg2 Execute file
$ py [--apply] function arg1 arg2 Call function
$ py [--eval] 'function(arg1, arg2)' Evaluate code
$ py [--module] modname arg1 arg2 Run a module
$ py --debug file/code... args... Debug code
$ py --debug PID Attach debugger to PID
$ py IPython shell
""".strip()

# Default compiler flags (feature flags) used for all user code. We include
# "print_function" here, but we also use auto_flags=True, which means
Expand Down
3 changes: 1 addition & 2 deletions lib/python/pyflyby/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@



from contextlib import contextmanager
from contextlib import contextmanager, ExitStack
import inspect
import os
import sys
Expand Down Expand Up @@ -451,7 +451,6 @@ def cmp(a, b):


# Create a context manager with an arbitrary number of contexts.
from contextlib import ExitStack
@contextmanager
def nested(*mgrs):
with ExitStack() as stack:
Expand Down

0 comments on commit 12e4b40

Please sign in to comment.