Skip to content

Commit e7f77b5

Browse files
committed
fix more pyright issues
1 parent 33c5d9b commit e7f77b5

File tree

6 files changed

+39
-124
lines changed

6 files changed

+39
-124
lines changed

.basedpyright/baseline.json

Lines changed: 0 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -22633,14 +22633,6 @@
2263322633
"lineCount": 1
2263422634
}
2263522635
},
22636-
{
22637-
"code": "reportUnusedFunction",
22638-
"range": {
22639-
"startColumn": 4,
22640-
"endColumn": 8,
22641-
"lineCount": 1
22642-
}
22643-
},
2264422636
{
2264522637
"code": "reportMissingParameterType",
2264622638
"range": {
@@ -24307,22 +24299,6 @@
2430724299
"lineCount": 1
2430824300
}
2430924301
},
24310-
{
24311-
"code": "reportAttributeAccessIssue",
24312-
"range": {
24313-
"startColumn": 25,
24314-
"endColumn": 29,
24315-
"lineCount": 1
24316-
}
24317-
},
24318-
{
24319-
"code": "reportAttributeAccessIssue",
24320-
"range": {
24321-
"startColumn": 25,
24322-
"endColumn": 29,
24323-
"lineCount": 1
24324-
}
24325-
},
2432624302
{
2432724303
"code": "reportAttributeAccessIssue",
2432824304
"range": {
@@ -24347,14 +24323,6 @@
2434724323
"lineCount": 1
2434824324
}
2434924325
},
24350-
{
24351-
"code": "reportAttributeAccessIssue",
24352-
"range": {
24353-
"startColumn": 43,
24354-
"endColumn": 47,
24355-
"lineCount": 1
24356-
}
24357-
},
2435824326
{
2435924327
"code": "reportArgumentType",
2436024328
"range": {
@@ -24363,22 +24331,6 @@
2436324331
"lineCount": 1
2436424332
}
2436524333
},
24366-
{
24367-
"code": "reportAttributeAccessIssue",
24368-
"range": {
24369-
"startColumn": 37,
24370-
"endColumn": 41,
24371-
"lineCount": 1
24372-
}
24373-
},
24374-
{
24375-
"code": "reportAttributeAccessIssue",
24376-
"range": {
24377-
"startColumn": 42,
24378-
"endColumn": 50,
24379-
"lineCount": 1
24380-
}
24381-
},
2438224334
{
2438324335
"code": "reportArgumentType",
2438424336
"range": {
@@ -24387,22 +24339,6 @@
2438724339
"lineCount": 1
2438824340
}
2438924341
},
24390-
{
24391-
"code": "reportAttributeAccessIssue",
24392-
"range": {
24393-
"startColumn": 38,
24394-
"endColumn": 42,
24395-
"lineCount": 1
24396-
}
24397-
},
24398-
{
24399-
"code": "reportAttributeAccessIssue",
24400-
"range": {
24401-
"startColumn": 43,
24402-
"endColumn": 51,
24403-
"lineCount": 1
24404-
}
24405-
},
2440624342
{
2440724343
"code": "reportMissingParameterType",
2440824344
"range": {
@@ -24615,22 +24551,6 @@
2461524551
"lineCount": 1
2461624552
}
2461724553
},
24618-
{
24619-
"code": "reportUnusedFunction",
24620-
"range": {
24621-
"startColumn": 4,
24622-
"endColumn": 28,
24623-
"lineCount": 1
24624-
}
24625-
},
24626-
{
24627-
"code": "reportUnusedFunction",
24628-
"range": {
24629-
"startColumn": 4,
24630-
"endColumn": 30,
24631-
"lineCount": 1
24632-
}
24633-
},
2463424554
{
2463524555
"code": "reportMissingParameterType",
2463624556
"range": {
@@ -24663,14 +24583,6 @@
2466324583
"lineCount": 1
2466424584
}
2466524585
},
24666-
{
24667-
"code": "reportUnusedFunction",
24668-
"range": {
24669-
"startColumn": 4,
24670-
"endColumn": 23,
24671-
"lineCount": 1
24672-
}
24673-
},
2467424586
{
2467524587
"code": "reportGeneralTypeIssues",
2467624588
"range": {
@@ -24735,14 +24647,6 @@
2473524647
"lineCount": 1
2473624648
}
2473724649
},
24738-
{
24739-
"code": "reportUnusedFunction",
24740-
"range": {
24741-
"startColumn": 4,
24742-
"endColumn": 26,
24743-
"lineCount": 1
24744-
}
24745-
},
2474624650
{
2474724651
"code": "reportMissingParameterType",
2474824652
"range": {

arraycontext/container/traversal.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,7 @@ def unflatten(
777777
checks are skipped.
778778
"""
779779
# NOTE: https://github.com/python/mypy/issues/7057
780-
offset = 0
780+
offset: int = 0
781781
common_dtype = None
782782

783783
def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
@@ -791,7 +791,8 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
791791
# {{{ validate subary
792792

793793
if (
794-
isinstance(template_subary_c.size, Integer)
794+
isinstance(offset, Integer)
795+
and isinstance(template_subary_c.size, Integer)
795796
and isinstance(ary.size, Integer)
796797
and (offset + template_subary_c.size) > ary.size):
797798
raise ValueError("'template' and 'ary' sizes do not match: "
@@ -816,6 +817,12 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
816817

817818
# {{{ reshape
818819

820+
if not isinstance(template_subary_c.size, Integer):
821+
raise NotImplementedError(
822+
"unflatten is not implemented for arrays with array-valued "
823+
"size.") from None
824+
825+
# FIXME: Not sure how to make the slicing part work for Array-valued sizes
819826
flat_subary = ary[offset:offset + template_subary_c.size]
820827
try:
821828
subary = actx.np.reshape(flat_subary,

arraycontext/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
582582
def outline(self,
583583
f: Callable[..., Any],
584584
*,
585-
id: Hashable | None = None) -> Callable[..., Any]:
585+
id: Hashable | None = None) -> Callable[..., Any]: # pyright: ignore[reportUnusedParameter]
586586
"""
587587
Returns a drop-in-replacement for *f*. The behavior of the returned
588588
callable is specific to the derived class.

arraycontext/impl/pytato/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def outline(self,
236236
f: Callable[..., Any],
237237
*,
238238
id: Hashable | None = None,
239-
tags: frozenset[Tag] = frozenset()
239+
tags: frozenset[Tag] = frozenset() # pyright: ignore[reportCallInDefaultInitializer]
240240
) -> Callable[..., Any]:
241241
from pytato.tags import FunctionIdentifier
242242

@@ -962,6 +962,7 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
962962
from .compile import LazilyJAXCompilingFunctionCaller
963963
return LazilyJAXCompilingFunctionCaller(self, f)
964964

965+
@override
965966
def transform_dag(self, dag: pytato.DictOfNamedArrays
966967
) -> pytato.DictOfNamedArrays:
967968
import pytato as pt

arraycontext/impl/pytato/outline.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,14 @@
4545
from arraycontext.context import (
4646
Array,
4747
ArrayOrContainer,
48-
ArrayOrContainerTc,
4948
ArrayT,
5049
)
5150
from arraycontext.impl.pytato import _BasePytatoArrayContext
5251

5352

5453
def _get_arg_id_to_arg(args: tuple[object, ...],
5554
kwargs: Mapping[str, object]
56-
) -> immutabledict[tuple[object, ...], object]:
55+
) -> immutabledict[tuple[object, ...], pt.Array]:
5756
"""
5857
Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id
5958
to argument values. See
@@ -104,7 +103,7 @@ def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str:
104103

105104

106105
def _get_arg_id_to_placeholder(
107-
arg_id_to_arg: Mapping[tuple[object, ...], object],
106+
arg_id_to_arg: Mapping[tuple[object, ...], pt.Array],
108107
prefix: str | None = None) -> immutabledict[tuple[object, ...], pt.Placeholder]:
109108
"""
110109
Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
@@ -122,25 +121,25 @@ def _get_arg_id_to_placeholder(
122121

123122
def _call_with_placeholders(
124123
f: Callable[..., object],
125-
args: tuple[object],
124+
args: tuple[object, ...],
126125
kwargs: Mapping[str, object],
127126
arg_id_to_placeholder: Mapping[tuple[object, ...], pt.Placeholder]) -> object:
128127
"""
129128
Construct placeholders analogous to *args* and *kwargs* and call *f*.
130129
"""
131130
def get_placeholder_replacement(
132-
arg: ArrayOrContainerTc | Scalar | None, key: tuple[object, ...]
133-
) -> ArrayOrContainerTc | Scalar | None:
131+
arg: ArrayOrContainer | Scalar | None, key: tuple[object, ...]
132+
) -> ArrayOrContainer | Scalar | None:
134133
if arg is None:
135134
return None
136135
elif np.isscalar(arg):
137136
return cast(Scalar, arg)
138137
elif isinstance(arg, pt.Array):
139-
return cast(ArrayOrContainerTc, arg_id_to_placeholder[key])
138+
return arg_id_to_placeholder[key]
140139
elif is_array_container_type(arg.__class__):
141-
def _rec_to_placeholder(keys: tuple[object, ...], ary: ArrayT) -> ArrayT:
142-
result = get_placeholder_replacement(ary, key + keys)
143-
return cast(ArrayT, result)
140+
def _rec_to_placeholder(
141+
keys: tuple[object, ...], ary: Array) -> Array:
142+
return cast("Array", get_placeholder_replacement(ary, key + keys))
144143

145144
return rec_keyed_map_array_container(_rec_to_placeholder, arg)
146145
else:
@@ -176,7 +175,7 @@ def _unpack_container(key: tuple[object, ...], ary: ArrayT) -> ArrayT:
176175

177176
def _pack_output(
178177
output_template: ArrayOrContainer,
179-
unpacked_output: Array | immutabledict[str, Array]
178+
unpacked_output: pt.Array | immutabledict[str, pt.Array]
180179
) -> ArrayOrContainer:
181180
"""
182181
Pack *unpacked_output* into array containers according to *output_template*.
@@ -187,9 +186,9 @@ def _pack_output(
187186
elif is_array_container_type(output_template.__class__):
188187
assert isinstance(unpacked_output, immutabledict)
189188

190-
def _pack_into_container(key: tuple[object, ...], ary: Array) -> Array:
189+
def _pack_into_container(key: tuple[object, ...], ary: Array) -> Array: # pyright: ignore[reportUnusedParameter]
191190
key_str = _get_output_arg_id_str(key)
192-
return unpacked_output[key_str]
191+
return unpacked_output[key_str] # type: ignore[index]
193192

194193
return rec_keyed_map_array_container(_pack_into_container, output_template)
195194
else:
@@ -262,9 +261,6 @@ def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer:
262261
call_site_output = func_def(**call_bindings)
263262

264263
assert isinstance(call_site_output, pt.Array | immutabledict)
265-
# FIXME: pt.Array is not an actx Array
266-
return _pack_output(cast("Array | immutabledict[str, Array]", output),
267-
cast("Array | immutabledict[str, Array]", call_site_output))
268-
264+
return _pack_output(output, call_site_output)
269265

270266
# vim: foldmethod=marker

test/test_arraycontext.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
"""
2525

2626
import logging
27+
from collections.abc import Callable
2728
from dataclasses import dataclass
2829
from functools import partial
30+
from typing import TypeAlias
2931

3032
import numpy as np
3133
import pytest
@@ -34,6 +36,7 @@
3436
from pytools.tag import Tag
3537

3638
from arraycontext import (
39+
ArrayContext,
3740
BcastUntilActxArray,
3841
EagerJAXArrayContext,
3942
NumpyArrayContext,
@@ -58,6 +61,9 @@
5861
logger = logging.getLogger(__name__)
5962

6063

64+
ArrayContextFactory: TypeAlias = Callable[[], ArrayContext]
65+
66+
6167
# {{{ array context fixture
6268

6369
class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext):
@@ -1166,15 +1172,16 @@ def my_rhs(scale, vel):
11661172
np.testing.assert_allclose(result.v, 3.14*v_x)
11671173

11681174

1169-
def test_actx_compile_with_outlined_function(actx_factory):
1175+
def test_actx_compile_with_outlined_function(actx_factory: ArrayContextFactory):
11701176
actx = actx_factory()
11711177
rng = np.random.default_rng()
11721178

11731179
@actx.outline
1174-
def outlined_scale_and_orthogonalize(alpha, vel):
1180+
def outlined_scale_and_orthogonalize(alpha: float, vel: Velocity2D) -> Velocity2D:
11751181
return scale_and_orthogonalize(alpha, vel)
11761182

1177-
def multi_scale_and_orthogonalize(alpha, vel1, vel2):
1183+
def multi_scale_and_orthogonalize(
1184+
alpha: float, vel1: Velocity2D, vel2: Velocity2D) -> np.ndarray:
11781185
return make_obj_array([
11791186
outlined_scale_and_orthogonalize(alpha, vel1),
11801187
outlined_scale_and_orthogonalize(alpha, vel2)])
@@ -1193,10 +1200,10 @@ def multi_scale_and_orthogonalize(alpha, vel1, vel2):
11931200

11941201
result1 = actx.to_numpy(scaled_speed1)
11951202
result2 = actx.to_numpy(scaled_speed2)
1196-
np.testing.assert_allclose(result1.u, -3.14*v1_y)
1197-
np.testing.assert_allclose(result1.v, 3.14*v1_x)
1198-
np.testing.assert_allclose(result2.u, -3.14*v2_y)
1199-
np.testing.assert_allclose(result2.v, 3.14*v2_x)
1203+
np.testing.assert_allclose(result1.u, -3.14*v1_y) # pyright: ignore[reportAttributeAccessIssue]
1204+
np.testing.assert_allclose(result1.v, 3.14*v1_x) # pyright: ignore[reportAttributeAccessIssue]
1205+
np.testing.assert_allclose(result2.u, -3.14*v2_y) # pyright: ignore[reportAttributeAccessIssue]
1206+
np.testing.assert_allclose(result2.v, 3.14*v2_x) # pyright: ignore[reportAttributeAccessIssue]
12001207

12011208
# }}}
12021209

0 commit comments

Comments
 (0)