Run unit tests across Python versions 3.10-3.12. (#326)
Also updated enough of the project to pass these new test
configurations:

* Avoid starred expressions in indexes for Python < 3.11 (see the sketch after this list)
(https://stackoverflow.com/a/77080166,
https://stackoverflow.com/q/77580216)
* Import `Self` from `typing_extensions` for Python < 3.11
(https://stackoverflow.com/a/77247460)
* Limit use of `logger.getLevelNamesMapping` to Python >= 3.11
(https://docs.python.org/3/library/logging.html#logging.getLevelNamesMapping,
see also
https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks)
* Stop using `unittest` in files that don't need it, fixing exit code 5
in Python >= 3.12 for "no tests were run"
(https://docs.python.org/3/library/unittest.html#unittest.main)
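
Sketch for the first bullet, using made-up variable names rather than code from this repository: Python 3.10 rejects a starred expression inside a subscript, which Python 3.11 accepts under PEP 646, so the affected subscripts now pass an explicit tuple instead. The same rewrite applies to the typing subscripts such as `Memory[(*shape, ...)]` in the diffs below.

```python
import numpy as np

arr = np.zeros((2, 3, 4))
idx = (1, 2)

# value = arr[*idx, 0]  # SyntaxError on Python 3.10; only valid on 3.11+ (PEP 646)
value = arr[(*idx, 0)]  # portable spelling used throughout this commit
print(value)            # -> 0.0
```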

---------

Co-authored-by: Marius Brehler <[email protected]>
ScottTodd and marbre authored Dec 11, 2024
1 parent d509242 commit bc35630
Showing 15 changed files with 60 additions and 45 deletions.
25 changes: 23 additions & 2 deletions .github/workflows/ci.yaml
@@ -22,11 +22,12 @@ concurrency:

jobs:
test:
name: "Unit Tests and Type Checking"
name: "${{ matrix.os }} :: ${{ matrix.version }} :: Unit Tests and Type Checking"
strategy:
fail-fast: false
matrix:
version: [3.11]
# Support for Python 3.13 depends on https://github.com/pytorch/pytorch/issues/130249
version: ["3.10", "3.11", "3.12"]
os: [ubuntu-22.04]
runs-on: ${{matrix.os}}
env:
@@ -71,3 +72,23 @@ jobs:
if: ${{ !cancelled() }}
run: |
mypy
# Depends on all other jobs to provide an aggregate job status.
ci_summary:
if: always()
runs-on: ubuntu-20.04
needs:
- test
steps:
- name: Getting failed jobs
run: |
echo '${{ toJson(needs) }}'
FAILED_JOBS="$(echo '${{ toJson(needs) }}' \
| jq --raw-output \
'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \
)"
echo "failed-jobs=${FAILED_JOBS}" >> $GITHUB_OUTPUT
if [[ "${FAILED_JOBS}" != "" ]]; then
echo "The following jobs failed: ${FAILED_JOBS}"
exit 1
fi
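
As a rough sketch of what the `ci_summary` step above computes: the `jq` filter keeps only the entries of `needs` whose result is neither "success" nor "skipped" and joins their names. An equivalent in Python, using a made-up sample payload (at runtime the payload comes from `toJson(needs)`):

```python
# Hypothetical sample of the JSON that `toJson(needs)` would produce.
needs = {
    "test": {"result": "failure", "outputs": {}},
}

# Same selection the jq expression performs: drop successful/skipped jobs, join the rest.
failed_jobs = ",".join(
    name
    for name, job in needs.items()
    if job["result"] not in ("success", "skipped")
)
print(failed_jobs)  # -> "test"; empty string when everything passed
```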
10 changes: 6 additions & 4 deletions iree/turbine/kernel/compiler/kernel_codegen.py
@@ -279,10 +279,12 @@ def only_write_dependencies(node):
# Create new Memory type with the correct usage
memory_type = self.bindings[index].kernel_buffer_type
self.bindings[index].kernel_buffer_type = Memory[
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
(
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
)
]
return

2 changes: 1 addition & 1 deletion iree/turbine/kernel/lang/wave_types.py
@@ -4,7 +4,6 @@
ClassVar,
Iterable,
Optional,
Self,
Type,
TypeAlias,
TypeVar,
@@ -17,6 +16,7 @@

from sympy import Symbol
from sympy.core.expr import Expr
from typing_extensions import Self

from itertools import chain

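The `Self` change above follows the usual compatibility pattern: `typing.Self` only exists on Python 3.11+, while `typing_extensions.Self` also works on 3.10 (and defers to `typing` on newer interpreters). A minimal sketch with a hypothetical class, not code from this repository:

```python
from typing_extensions import Self  # works on 3.10; matches typing.Self on 3.11+


class Builder:
    """Hypothetical fluent builder, used only to illustrate the Self annotation."""

    def __init__(self) -> None:
        self.parts: list[str] = []

    def add(self, part: str) -> Self:
        # Annotating the return as Self preserves the subclass type in chained calls.
        self.parts.append(part)
        return self


print(Builder().add("a").add("b").parts)  # -> ['a', 'b']
```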
20 changes: 10 additions & 10 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -9,12 +9,12 @@
Any,
Callable,
Optional,
Self,
Sequence,
Type,
TypeVar,
final,
)
from typing_extensions import Self
import torch.fx as fx

from ..lang.wave_types import Memory, Register, IndexMapping
@@ -792,7 +792,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

@property
def type(self) -> "Memory":
return Memory[*self.shape, self.address_space, self.dtype]
return Memory[(*self.shape, self.address_space, self.dtype)]


@define_op("shared_memory_barrier")
@@ -855,7 +855,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
return list(self.shape)

def infer_type(self):
self.type = Register[*self.shape, self.dtype]
self.type = Register[(*self.shape, self.dtype)]


@define_op("mma")
@@ -960,7 +960,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
dtype = self.memory_type.dtype
self.type = Register[*self.indexing_dims, dtype]
self.type = Register[(*self.indexing_dims, dtype)]

@property
def memory_type(self) -> "Memory":
@@ -1168,7 +1168,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
def infer_type(self):
address_space = self.memory_type.address_space
dtype = self.memory_type.dtype
self.type = Memory[*self.indexing_dims, address_space, dtype]
self.type = Memory[(*self.indexing_dims, address_space, dtype)]

@property
def memory_type(self) -> "Memory":
@@ -1304,7 +1304,7 @@ def infer_type(self):
dst_shape = list(src_type.symbolic_shape)
dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
dst_shape.remove(dim_to_remove)
dst_type = Register[*dst_shape, src_type.dtype]
dst_type = Register[(*dst_shape, src_type.dtype)]
self.type = dst_type


@@ -1354,7 +1354,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
src_dtype = get_custom(self.arg).type.dtype
self.type = Register[*self.target_shape, src_dtype]
self.type = Register[(*self.target_shape, src_dtype)]


@define_interface_op("max")
@@ -1406,7 +1406,7 @@ def infer_type(self):
else:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
dst_type = Register[(*reduced_dims, src_type.dtype)]
self.type = dst_type

@property
@@ -1465,7 +1465,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
src_shape = get_custom(self.arg).type.symbolic_shape
self.type = Register[*src_shape, self.dtype]
self.type = Register[(*src_shape, self.dtype)]


@define_op("permute")
@@ -1488,7 +1488,7 @@ def infer_type(self):
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
self.type = Register[*self.target_shape, src_type.dtype]
self.type = Register[(*self.target_shape, src_type.dtype)]

def transform_index(
self, index: dict[IndexSymbol, IndexSequence]
6 changes: 3 additions & 3 deletions iree/turbine/kernel/wave/scheduling/schedule.py
@@ -105,15 +105,15 @@ def schedule_reduction(
# is not dynamic.
max_induction_variable = int(max_induction_variable)
if max_induction_variable <= scheduler.num_stages - 1:
logger.warn(
logger.warning(
"Not enough iterations to pipeline the loop. Skipping pipelining."
)
return {}
else:
# Otherwise, we need to rely on assumptions provided by the author.
assumptions = get_assumptions(constraints)
if not assumptions:
logger.warn(
logger.warning(
"No assumptions provided to determine if the loop can be pipelined. Skipping pipelining."
)
return {}
@@ -122,7 +122,7 @@
constraints, max_induction_variable > scheduler.num_stages - 1
)
if not result:
logger.warn(
logger.warning(
"Not enough iterations to pipeline the loop. Skipping pipelining."
)
return {}
18 changes: 11 additions & 7 deletions iree/turbine/support/debugging.py
@@ -11,6 +11,7 @@
import logging
import re
import os
import sys
import torch
import numpy as np

@@ -54,7 +55,7 @@ class DebugFlags:
def set(self, part: str):
m = re.match(SETTING_PART_PATTERN, part)
if not m:
logger.warn("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part)
logger.warning("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part)
return
name = m.group(2)
value = m.group(4)
@@ -64,19 +65,22 @@ def set(self, part: str):
logical_sense = m.group(1) != "-"

if name == "log_level":
log_level_mapping = logging.getLevelNamesMapping()
try:
self.log_level = log_level_mapping[value.upper()]
except KeyError:
logger.warn("Log level '%s' unknown (ignored)", value)
if sys.version_info >= (3, 11):
log_level_mapping = logging.getLevelNamesMapping() # Added in 3.11
try:
self.log_level = log_level_mapping[value.upper()]
except KeyError:
logger.warning("Log level '%s' unknown (ignored)", value)
else:
logger.warning("'log_level' flag requires Python >= 3.11")
elif name == "asserts":
self.asserts = logical_sense
global NDEBUG
NDEBUG = not logical_sense
elif name == "runtime_trace_dir":
self.runtime_trace_dir = value
else:
logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name)
logger.warning("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name)

@staticmethod
def parse(settings: str) -> "DebugFlags":
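`logging.getLevelNamesMapping()` was only added in Python 3.11, hence the version gate above. If the `log_level` flag ever needs to work on 3.10 as well, one hedged alternative (not what this commit does) is `logging.getLevelName`, which returns the numeric level for a known level name on all supported versions:

```python
import logging


def parse_log_level(value: str) -> int | None:
    """Hypothetical helper: map a level name to its number, or None if unknown."""
    level = logging.getLevelName(value.upper())
    # For unknown names, getLevelName returns the string "Level <value>", not an int.
    return level if isinstance(level, int) else None


print(parse_log_level("debug"))  # -> 10
print(parse_log_level("bogus"))  # -> None
```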
8 changes: 4 additions & 4 deletions iree/turbine/tools/interpreter.py
@@ -151,7 +151,7 @@ def callback(self, op: Operation) -> None:
offset = [0 for _ in range(len(load_indices))]
for i in range(*result_shape):
ind = [int(x) + y for x, y in zip(load_indices, offset)]
value[i] = memref[*ind]
value[i] = memref[(*ind,)]
offset[-1] += 1
case vector_d.ExtractStridedSliceOp:
vector = self.symbol_table[op.vector]
@@ -168,7 +168,7 @@ def callback(self, op: Operation) -> None:
offset = [0 for _ in range(len(store_indices))]
for i in range(*result_shape):
memref[
*[int(x) + y for x, y in zip(store_indices, offset)]
(*[int(x) + y for x, y in zip(store_indices, offset)],)
] = vector[i]
offset[-1] += 1
case vector_d.MaskedStoreOp:
@@ -185,7 +185,7 @@ def callback(self, op: Operation) -> None:
for i in range(*result_shape):
if mask[i]:
ind = [int(x) + y for x, y in zip(store_indices, offset)]
memref[*ind] = vector[i]
memref[(*ind,)] = vector[i]

offset[-1] += 1
case vector_d.ConstantMaskOp:
@@ -313,7 +313,7 @@ def interpret_ndrange(
):
for wg in np.ndindex(*workgroup_count):
for t in np.ndindex(*workgroup_size):
Interpreter([*wg], [*t]).interpret(asm)
Interpreter([(*wg,)], [(*t,)]).interpret(asm)


if __name__ == "__main__":
3 changes: 0 additions & 3 deletions lit_tests/kernel/wave/barriers.py
@@ -1,8 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
from typing import Callable
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -262,4 +260,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
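
For context on this and the following lit test edits: these files are driven by FileCheck and define no `unittest.TestCase` classes, so on Python 3.12+ the removed `unittest.main()` call finds nothing to run and exits with code 5 ("no tests were run"), which is the failure this commit removes. A minimal standalone repro sketch, not a file from this repository:

```python
# repro.py -- hypothetical file containing no TestCase subclasses.
import unittest

if __name__ == "__main__":
    # Python <= 3.11: prints "Ran 0 tests" and exits 0.
    # Python >= 3.12: exits with code 5 (no tests were run).
    unittest.main()
```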
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/expansion.py
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -1109,4 +1108,3 @@ def test_chained_gemm_32x32x8():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/index_sequence_analysis.py
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -350,4 +349,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
3 changes: 0 additions & 3 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -1,8 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
from typing import Callable
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -267,4 +265,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/promotion.py
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -216,4 +215,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/scheduling.py
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -246,4 +245,3 @@ def test_gemm_pipelined():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ lit==18.1.7
mypy==1.8.0
ml_dtypes==0.5.0
setuptools
typing_extensions
wheel

# It is expected that you have installed a PyTorch version/variant specific
1 change: 1 addition & 0 deletions setup.py
@@ -110,6 +110,7 @@ def initialize_options(self):
"torch>=2.3.0",
f"Jinja2{get_version_spec('Jinja2')}",
f"ml_dtypes{get_version_spec('ml_dtypes')}",
f"typing_extensions{get_version_spec('typing_extensions')}",
],
extras_require={
"testing": [
