Skip to content

Commit

Permalink
Add some types to printing.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Armavica authored and ricardoV94 committed Jun 6, 2024
1 parent 1a92165 commit 086323f
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions pytensor/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,18 +1200,18 @@ def __call__(self, *args):

def pydotprint(
fct,
outfile=None,
compact=True,
format="png",
with_ids=False,
high_contrast=True,
outfile: str | None = None,
compact: bool = True,
format: str = "png",
with_ids: bool = False,
high_contrast: bool = True,
cond_highlight=None,
colorCodes=None,
max_label_size=70,
scan_graphs=False,
var_with_name_simple=False,
print_output_file=True,
return_image=False,
colorCodes: dict | None = None,
max_label_size: int = 70,
scan_graphs: bool = False,
var_with_name_simple: bool = False,
print_output_file: bool = True,
return_image: bool = False,
):
"""Print to a file the graph of a compiled pytensor function's ops. Supports
all pydot output formats, including png and svg.
Expand Down Expand Up @@ -1676,7 +1676,9 @@ def get_tag(self):
return rval


def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None):
def min_informative_str(
obj, indent_level: int = 0, _prev_obs: dict | None = None, _tag_generator=None
) -> str:
"""
Returns a string specifying to the user what obj is
The string will print out as much of the graph as is needed
Expand Down Expand Up @@ -1776,7 +1778,7 @@ def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None
return rval


def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> str:
"""
Returns a string, with no endlines, fully specifying
how a variable is computed. Does not include any memory
Expand Down Expand Up @@ -1832,7 +1834,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
return rval


def position_independent_str(obj):
def position_independent_str(obj) -> str:
if isinstance(obj, Variable):
rval = "pytensor_var"
rval += "{type=" + str(obj.type) + "}"
Expand All @@ -1842,7 +1844,7 @@ def position_independent_str(obj):
return rval


def hex_digest(x):
def hex_digest(x: np.ndarray) -> str:
"""
Returns a short, mostly hexadecimal hash of a numpy ndarray
"""
Expand All @@ -1852,8 +1854,8 @@ def hex_digest(x):
# because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged
# into a tensor
rval = rval + "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval = rval + "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
rval += "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval += "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
return rval


Expand Down

0 comments on commit 086323f

Please sign in to comment.