Skip to content

Commit

Permalink
Merge pull request #295 from superbobry:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675199892
  • Loading branch information
The jax_triton Authors committed Sep 16, 2024
2 parents 973e106 + a72a159 commit 4ebaed0
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 40 deletions.
3 changes: 1 addition & 2 deletions examples/block_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import functools

from typing import Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -113,7 +112,7 @@ def mha(q, k, v, *,
sm_scale: float = 1.0,
block_q: int = 128,
block_k: int = 128,
num_warps: Optional[int] = None,
num_warps: int | None = None,
num_stages: int = 1,
grid=None,
):
Expand Down
3 changes: 1 addition & 2 deletions examples/pallas/blocksparse_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import functools
import timeit

from typing import Tuple

import jax.numpy as jnp
from jax import random
Expand Down Expand Up @@ -63,7 +62,7 @@ class BlockELL:
blocks: jnp.ndarray # float32[n_rows, n_blocks, *block_size]
blocks_per_row: jnp.ndarray # int32[n_rows, n_blocks]
indices: jnp.ndarray # int32[n_rows, max_num_blocks_per_row, 2]
shape: Tuple[int, int] # (n_rows * block_size[0], n_cols * block_size[1])
shape: tuple[int, int] # (n_rows * block_size[0], n_cols * block_size[1])

ndim: int = property(lambda self: len(self.shape))
num_blocks = property(lambda self: self.blocks.shape[0])
Expand Down
4 changes: 2 additions & 2 deletions examples/pallas/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def body(k, acc_refs):
accs = for_loop.for_loop(num_k_blocks, body, [acc_i, acc_f, acc_o, acc_g])
bs = [pl.load(b_ref, (idx_n,))
for b_ref in [b_hi_ref, b_hf_ref, b_hg_ref, b_ho_ref]]
acc_i, acc_f, acc_g, acc_o = [acc + b for acc, b in zip(accs, bs)]
acc_i, acc_f, acc_g, acc_o = (acc + b for acc, b in zip(accs, bs))
i_gate, f_gate, o_gate = (
jax.nn.sigmoid(acc_i), jax.nn.sigmoid(acc_f), jax.nn.sigmoid(acc_o))
cell = jnp.tanh(acc_g)
Expand Down Expand Up @@ -124,7 +124,7 @@ def lstm_cell_reference(weights, x, h, c):
xs = [jnp.dot(x, w) for w in ws]
hs = [jnp.dot(h, u) for u in us]
accs = [x + h for x, h in zip(xs, hs)]
acc_i, acc_f, acc_g, acc_o = [acc + b[None] for acc, b in zip(accs, bs)]
acc_i, acc_f, acc_g, acc_o = (acc + b[None] for acc, b in zip(accs, bs))
i_gate, f_gate, o_gate = (
jax.nn.sigmoid(acc_i), jax.nn.sigmoid(acc_f), jax.nn.sigmoid(acc_o))
cell = jnp.tanh(acc_g)
Expand Down
4 changes: 2 additions & 2 deletions jax_triton/experimental/fusion/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import os

from typing import Any, Tuple
from typing import Any

import jax
from jax import lax
Expand Down Expand Up @@ -204,7 +204,7 @@ def make_elementwise(shape, dtype, *args):
class MatmulElementwise(jax_rewrite.JaxExpression):
x: jax_rewrite.JaxExpression
y: jax_rewrite.JaxExpression
elem_ops: Tuple[core.Primitive]
elem_ops: tuple[core.Primitive]

def match(self, expr, bindings, succeed):
if not isinstance(expr, MatmulElementwise):
Expand Down
25 changes: 13 additions & 12 deletions jax_triton/experimental/fusion/jaxpr_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import dataclasses
import itertools as it

from typing import Any, Callable, List, Tuple, Union
from typing import Any
from collections.abc import Callable

from jax._src import core as jax_core
import jax.numpy as jnp
Expand All @@ -35,7 +36,7 @@
class Node(matcher.Pattern, metaclass=abc.ABCMeta):

@abc.abstractproperty
def parents(self) -> List[Node]:
def parents(self) -> list[Node]:
...


Expand All @@ -51,9 +52,9 @@ def map_parents(self, fn: Callable[[Node], Node]) -> Node:
class Eqn(Node):
primitive: jax_core.Primitive
params: jr.Params
invars: List[Node]
shape: Union[Tuple[int, ...], List[Tuple[int, ...]]]
dtype: Union[jnp.dtype, List[jnp.dtype]]
invars: list[Node]
shape: tuple[int, ...] | list[tuple[int, ...]]
dtype: jnp.dtype | list[jnp.dtype]

@property
def parents(self):
Expand All @@ -77,7 +78,7 @@ def match(self, expr, bindings, succeed):

@dataclasses.dataclass(frozen=True, eq=False)
class JaxprVar(Node):
shape: Tuple[int, ...]
shape: tuple[int, ...]
dtype: jnp.dtype

def match(self, expr, bindings, succeed):
Expand Down Expand Up @@ -131,7 +132,7 @@ def from_literal(cls, var: jax_core.Literal) -> Literal:
@dataclasses.dataclass(eq=False)
class Part(Node):
index: int
shape: Tuple[int, ...]
shape: tuple[int, ...]
dtype: jnp.dtype
parent: Node

Expand All @@ -153,9 +154,9 @@ def map_parents(self, fn):

@dataclasses.dataclass(eq=True)
class JaxprGraph(matcher.Pattern):
constvars: List[Node]
invars: List[Node]
outvars: List[Node]
constvars: list[Node]
invars: list[Node]
outvars: list[Node]

def get_nodes(self):
nodes = set(self.outvars)
Expand All @@ -167,7 +168,7 @@ def get_nodes(self):
queue.append(p)
return nodes

def get_children(self, node) -> List[Node]:
def get_children(self, node) -> list[Node]:
nodes = self.get_nodes()
return [n for n in nodes if node in n.parents]

Expand Down Expand Up @@ -274,7 +275,7 @@ def to_jaxpr(self) -> jax_core.Jaxpr:
outvars = [env[n] for n in self.outvars]
return jax_core.Jaxpr(constvars, invars, outvars, eqns, jax_core.no_effects)

def toposort(self) -> List[Node]:
def toposort(self) -> list[Node]:
node_stack = list(self.outvars)
child_counts = {}
while node_stack:
Expand Down
4 changes: 2 additions & 2 deletions jax_triton/experimental/fusion/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Contains lowering passes for jaxprs to pallas."""
import functools

from typing import Any, Dict
from typing import Any

import jax
from jax import api_util
Expand Down Expand Up @@ -317,7 +317,7 @@ def read(v: core.Atom) -> Any:
def write(v: Var, val: Any) -> None:
env[v] = val

env: Dict[Var, Any] = {}
env: dict[Var, Any] = {}
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
Expand Down
35 changes: 18 additions & 17 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import pprint
import tempfile
import types
from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
from typing import Any, Protocol, Union
from collections.abc import Callable, Sequence
import zlib
from functools import partial

Expand Down Expand Up @@ -102,11 +103,11 @@
jnp.dtype("bool"): "B",
}

Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]
Grid = Union[int, tuple[int], tuple[int, int], tuple[int, int, int]]
GridOrLambda = Union[Grid, Callable[[dict[str, Any]], Grid]]


def normalize_grid(grid: GridOrLambda, metaparams) -> Tuple[int, int, int]:
def normalize_grid(grid: GridOrLambda, metaparams) -> tuple[int, int, int]:
if callable(grid):
grid = grid(metaparams)
if isinstance(grid, int):
Expand Down Expand Up @@ -186,8 +187,8 @@ class CompilationResult:
name: str
shared_mem_bytes: int
cluster_dims: tuple
ttgir: Optional[str]
llir: Optional[str]
ttgir: str | None
llir: str | None

def compile_ttir_inplace(
ttir,
Expand Down Expand Up @@ -375,7 +376,7 @@ def get_or_create_triton_kernel(
enable_fp_fusion,
metaparams,
dump: bool,
) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
) -> tuple[triton_kernel_call_lib.TritonKernel, Any]:
if num_warps is None:
num_warps = 4
if num_stages is None:
Expand Down Expand Up @@ -730,7 +731,7 @@ def prune_configs(configs, named_args, **kwargs):
class ShapeDtype(Protocol):

@property
def shape(self) -> Tuple[int, ...]:
def shape(self) -> tuple[int, ...]:
...

@property
Expand All @@ -739,21 +740,21 @@ def dtype(self) -> np.dtype:


def triton_call(
*args: Union[jax.Array, bool, int, float, np.float32],
*args: jax.Array | bool | int | float | np.float32,
kernel: triton.JITFunction,
out_shape: Union[ShapeDtype, Sequence[ShapeDtype]],
out_shape: ShapeDtype | Sequence[ShapeDtype],
grid: GridOrLambda,
name: str = "",
custom_call_target_name: str = "triton_kernel_call",
num_warps: Optional[int] = None,
num_stages: Optional[int] = None,
num_warps: int | None = None,
num_stages: int | None = None,
num_ctas: int = 1, # TODO(giorgioa): Add support for dimensions tuple.
compute_capability: Optional[int] = None,
compute_capability: int | None = None,
enable_fp_fusion: bool = True,
input_output_aliases: Optional[Dict[int, int]] = None,
zeroed_outputs: Union[
Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]
] = (),
input_output_aliases: dict[int, int] | None = None,
zeroed_outputs: (
Sequence[int] | Callable[[dict[str, Any]], Sequence[int]]
) = (),
debug: bool = False,
serialized_metadata: bytes = b"",
**metaparams: Any,
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies = [
"absl-py>=1.4.0",
"jax>=0.4.31",
"triton>=3.0",
"setuptools", # triton seems to need this when installing itself.
]

[project.optional-dependencies]
Expand Down

0 comments on commit 4ebaed0

Please sign in to comment.