Skip to content

Commit

Permalink
Merge branch 'main' into production-pilot
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Dec 17, 2024
2 parents f25bb99 + 67a087d commit 0af4d58
Show file tree
Hide file tree
Showing 35 changed files with 289 additions and 131 deletions.
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@

sys._BUILDING_SPHINX_DOCS = True


nitpick_ignore_regex = [
["py:class", r"numpy.(u?)int[\d]+"],
["py:class", r"numpy.bool_"],
["py:class", r"typing_extensions(.+)"],
["py:class", r"P\.args"],
["py:class", r"P\.kwargs"],
["py:class", r"lp\.LoopKernel"],
["py:class", r"_dtype_any"],
]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"bidict",
"immutabledict",
"loopy>=2020.2",
"pytools>=2024.1.14",
"pytools>=2024.1.21",
"pymbolic>=2024.2",
"typing_extensions>=4",
]
Expand All @@ -62,6 +62,7 @@ extend-select = [
"NPY", # numpy
"RUF",
"UP",
"TC",
]
extend-ignore = [
"E226",
Expand Down
5 changes: 3 additions & 2 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
THE SOFTWARE.
"""

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from loopy.tools import LoopyKeyBuilder
Expand All @@ -48,12 +47,14 @@
Stack,
)
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper


if TYPE_CHECKING:
from collections.abc import Mapping

from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.loopy import LoopyCall

__doc__ = """
.. currentmodule:: pytato.analysis
Expand Down
16 changes: 8 additions & 8 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def {cls.__name__}_hash(self):
except AttributeError:
pass
h = hash(frozenset({attr_tuple_hash}))
h = hash({attr_tuple_hash})
object.__setattr__(self, "_hash_value", h)
return h
Expand Down Expand Up @@ -409,7 +409,7 @@ def _dataclass_setstate(self, state):

# {{{ assign mapper_method

mm_cls = cast(type[_HasMapperMethod], cls)
mm_cls = cast("type[_HasMapperMethod]", cls)

snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", mm_cls.__name__).lower()
default_mapper_method_name = f"map_{snake_clsname}"
Expand Down Expand Up @@ -832,7 +832,7 @@ def conj(self) -> ArrayOrScalar:

def __abs__(self) -> Array:
import pytato as pt
return cast(Array, pt.abs(self))
return cast("Array", pt.abs(self))

def __pos__(self) -> Array:
return self
Expand Down Expand Up @@ -1783,7 +1783,7 @@ def shape(self) -> ShapeType:
for i_basic_idx in i_basic_indices)

adv_idx_shape = get_shape_after_broadcasting([
cast(Array | Integer, not_none(self.indices[i_idx]))
cast("Array | Integer", not_none(self.indices[i_idx]))
for i_idx in i_adv_indices])

# type-ignored because mypy cannot figure out basic-indices only refer
Expand Down Expand Up @@ -1831,7 +1831,7 @@ def shape(self) -> ShapeType:
for i_basic_idx in i_basic_indices)

adv_idx_shape = get_shape_after_broadcasting([
cast(Array | Integer, not_none(self.indices[i_idx]))
cast("Array | Integer", not_none(self.indices[i_idx]))
for i_idx in i_adv_indices])

# type-ignored because mypy cannot figure out basic-indices only refer slices
Expand Down Expand Up @@ -2065,7 +2065,7 @@ def matmul(x1: Array, x2: Array) -> Array:
if x1.ndim == x2.ndim == 1:
return pt.sum(x1 * x2)
elif x1.ndim == 1:
return cast(Array, pt.dot(x1, x2))
return cast("Array", pt.dot(x1, x2))
elif x2.ndim == 1:
x1_indices = index_names[:x1.ndim]
return pt.einsum(f"{x1_indices}, {x1_indices[-1]} -> {x1_indices[:-1]}",
Expand Down Expand Up @@ -2398,7 +2398,7 @@ def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN,
else:
fill_value = conv_dtype.type(fill_value)

return IndexLambda(expr=cast(ArithmeticExpression, fill_value),
return IndexLambda(expr=cast("ArithmeticExpression", fill_value),
shape=shape, dtype=conv_dtype,
bindings=immutabledict(),
tags=_get_default_tags(),
Expand Down Expand Up @@ -2541,7 +2541,7 @@ def arange(*args: Any, **kwargs: Any) -> Array:
from math import ceil
# np.real() suppresses "ComplexWarning: Casting complex values to real
# discards the imaginary part":
size = max(0, int(ceil((np.real(stop)-np.real(start))/np.real(step))))
size = max(0, ceil((np.real(stop)-np.real(start))/np.real(step)))

from pymbolic.primitives import Variable
return IndexLambda(expr=start + Variable("_0") * step,
Expand Down
11 changes: 7 additions & 4 deletions pytato/cmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,13 @@
# }}}


from typing import cast
from typing import TYPE_CHECKING, cast

import numpy as np
from immutabledict import immutabledict

import pymbolic.primitives as prim
from pymbolic import Scalar, var
from pymbolic.typing import Expression

from pytato.array import (
Array,
Expand All @@ -78,6 +77,10 @@
from pytato.scalar_expr import SCALAR_CLASSES


if TYPE_CHECKING:
from pymbolic.typing import Expression


def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...],
func_name: str,
ret_dtype: _dtype_any | None = None,
Expand All @@ -88,7 +91,7 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...],
np_func_name = func_name

np_func = getattr(np, np_func_name)
return cast(ArrayOrScalar, np_func(*inputs))
return cast("ArrayOrScalar", np_func(*inputs))

if not inputs:
raise ValueError("at least one argument must be present")
Expand Down Expand Up @@ -233,7 +236,7 @@ def imag(x: ArrayOrScalar) -> ArrayOrScalar:
result_dtype = np.empty(0, dtype=x_dtype).real.dtype
else:
if np.isscalar(x):
return cast(Scalar, x_dtype.type(0))
return cast("Scalar", x_dtype.type(0))
else:
assert isinstance(x, Array)
import pytato as pt
Expand Down
12 changes: 8 additions & 4 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
"""

import dataclasses
from collections.abc import Mapping
from typing import Any
from typing import TYPE_CHECKING, Any

from immutabledict import immutabledict

Expand All @@ -57,10 +56,8 @@
SizeParam,
make_dict_of_named_arrays,
)
from pytato.function import NamedCallResult
from pytato.loopy import LoopyCall
from pytato.scalar_expr import IntegralScalarExpression, is_integral_scalar_expression
from pytato.target import Target
from pytato.transform import (
ArrayOrNames,
CachedWalkMapper,
Expand All @@ -70,6 +67,13 @@
from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin


if TYPE_CHECKING:
from collections.abc import Mapping

from pytato.function import NamedCallResult
from pytato.target import Target


SymbolicIndex: TypeAlias = tuple[IntegralScalarExpression, ...]


Expand Down
22 changes: 12 additions & 10 deletions pytato/distributed/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,30 @@
"""

import logging
from collections.abc import Hashable, Mapping
from typing import TYPE_CHECKING, Any

import numpy as np

from pytato.array import make_dict_of_named_arrays
from pytato.distributed.nodes import DistributedRecv, DistributedSend
from pytato.distributed.partition import (
DistributedGraphPart,
DistributedGraphPartition,
PartId,
)
from pytato.scalar_expr import INT_CLASSES
from pytato.target import BoundProgram


logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from collections.abc import Hashable, Mapping

import mpi4py.MPI

from pytato.distributed.nodes import DistributedRecv, DistributedSend
from pytato.distributed.partition import (
DistributedGraphPart,
DistributedGraphPartition,
PartId,
)
from pytato.target import BoundProgram


# {{{ generate_code_for_partition

Expand Down Expand Up @@ -88,7 +90,7 @@ def _post_receive(mpi_communicator: mpi4py.MPI.Comm,

assert isinstance(recv.comm_tag, int)
# mypy is right here, size params in 'recv.shape' must be evaluated
buf = np.empty(recv.shape, dtype=recv.dtype) # type: ignore[arg-type]
buf = np.empty(recv.shape, dtype=recv.dtype) # type: ignore[type-var]

return mpi_communicator.Irecv(
buf=buf, source=recv.src_rank, tag=recv.comm_tag), buf
Expand Down Expand Up @@ -134,7 +136,7 @@ def execute_distributed_partition(
context: dict[str, Any] = input_args.copy()

pids_to_execute = set(partition.parts)
pids_executed = set()
pids_executed: set[PartId] = set()
recv_names_completed = set()
send_requests = []

Expand Down
21 changes: 11 additions & 10 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@
DistributedSend,
DistributedSendRefHolder,
)
from pytato.function import FunctionDefinition, NamedCallResult
from pytato.scalar_expr import SCALAR_CLASSES
from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, CopyMapper


if TYPE_CHECKING:
import mpi4py.MPI

from pytato.function import FunctionDefinition, NamedCallResult


@dataclasses.dataclass(frozen=True)
class CommunicationOpIdentifier:
Expand Down Expand Up @@ -350,7 +351,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
if name is not None:
return self._get_placeholder_for(name, expr)

return cast(ArrayOrNames, super().rec(expr))
return cast("ArrayOrNames", super().rec(expr))

def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder:
placeholder = self.partition_input_name_to_placeholder.get(name)
Expand Down Expand Up @@ -832,7 +833,7 @@ def find_distributed_partition(
# Production
# comm_batches = comm_batches_or_exc
comm_batches = cast(
Sequence[Set[CommunicationOpIdentifier]],
"Sequence[Set[CommunicationOpIdentifier]]",
comm_batches_or_exc)

# }}}
Expand Down Expand Up @@ -933,7 +934,7 @@ def find_distributed_partition(
ary: max(
(comm_id_to_part_id[
_recv_to_comm_id(local_rank,
cast(DistributedRecv, recvd_ary))]
cast("DistributedRecv", recvd_ary))]
for recvd_ary in recvd_array_dep_mapper(ary)),
default=-1)
for ary in mso_arrays
Expand Down Expand Up @@ -1004,18 +1005,18 @@ def get_materialized_predecessors(ary: Array) -> dict[Array, None]:

# }}}

# Don't be tempted to put outputs in _array_names; the mapping from output array
# Don't be tempted to put outputs in array_names; the mapping from output array
# to name may not be unique
_array_name_gen = UniqueNameGenerator(forced_prefix="_pt_dist_")
_array_names: dict[Array, str] = {}
array_name_gen = UniqueNameGenerator(forced_prefix="_pt_dist_")
array_names: dict[Array, str] = {}

def gen_array_name(ary: Array) -> str:
name = _array_names.get(ary)
name = array_names.get(ary)
if name is not None:
return name
else:
name = _array_name_gen()
_array_names[ary] = name
name = array_name_gen()
array_names[ary] = name
return name

recvd_ary_to_name: dict[Array, str] = {
Expand Down
9 changes: 5 additions & 4 deletions pytato/distributed/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@

import dataclasses
import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import numpy as np

from pymbolic.mapper.optimize import optimize_mapper

from pytato.array import (
Expand All @@ -51,7 +48,6 @@
ShapeType,
make_dict_of_named_arrays,
)
from pytato.distributed.nodes import CommTagType, DistributedRecv
from pytato.distributed.partition import (
CommunicationOpIdentifier,
DistributedGraphPartition,
Expand All @@ -64,7 +60,12 @@


if TYPE_CHECKING:
from collections.abc import Sequence

import mpi4py.MPI
import numpy as np

from pytato.distributed.nodes import CommTagType, DistributedRecv


# {{{ data structures
Expand Down
5 changes: 3 additions & 2 deletions pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
THE SOFTWARE.
"""

from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from pytools import memoize_method
Expand All @@ -50,11 +49,13 @@
SizeParam,
Stack,
)
from pytato.function import Call, FunctionDefinition, NamedCallResult


if TYPE_CHECKING:
from collections.abc import Callable

from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall, LoopyCallResult

__doc__ = """
Expand Down
Loading

0 comments on commit 0af4d58

Please sign in to comment.