[pallas] Move the hardware_generation query in the code path that needs it

In many cases, this change allows us to lower and export Pallas calls
even on machines that do not have TPUs.

PiperOrigin-RevId: 641841079
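
For illustration only (a minimal sketch mirroring the test added below; the kernel, function names, and shapes are made up, not part of this commit), the workflow this enables on a TPU-less machine looks like:

import jax
from jax.experimental import export
from jax.experimental import pallas as pl
import numpy as np

def copy_kernel(x_ref, o_ref):
  # Trivial Pallas kernel: copy the input block to the output.
  o_ref[...] = x_ref[...]

@jax.jit
def copy(x: jax.Array) -> jax.Array:
  return pl.pallas_call(
      copy_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

# Lowering and exporting for TPU can now succeed even when jax.devices()
# reports no TPUs, since the device_kind query runs only in the Python
# lowering pipeline path.
exp = export.export(copy, lowering_platforms=["tpu"])(np.arange(8))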
gnecula authored and jax authors committed Jun 10, 2024
1 parent af95803 commit 2ade7e7
Showing 3 changed files with 94 additions and 5 deletions.
14 changes: 9 additions & 5 deletions jax/_src/tpu_custom_call.py
@@ -384,16 +384,12 @@ def as_tpu_kernel(
 ) -> Callable[..., Any]:
   """Turns an MLIR Mosaic kernel into a JAX-compatible function."""
   # We use jax.jit to make sure we hit the fast compilation cache.
-  some_tpu = jax.devices(backend)[0]
-  device_kind = some_tpu.device_kind
-  if not device_kind.startswith("TPU v"):
-    raise ValueError(f"Unrecognized TPU device kind: {device_kind}.")
+
   if vmem_limit_bytes is not None and not isinstance(vmem_limit_bytes, int):
     raise ValueError(
         "vmem_limit_bytes must be an int: provided with a"
         f" {type(vmem_limit_bytes)}."
     )
-  hardware_generation = int(device_kind[len("TPU v")])
   has_communication, has_custom_barrier = tpu.private_has_communication(
       module.operation
   )
@@ -405,6 +401,14 @@ def as_tpu_kernel(
         module.operation.get_asm(binary=True, enable_debug_info=True)
     )
     if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value:
+      some_tpu = jax.devices(backend)[0]
+      device_kind = some_tpu.device_kind
+      if not device_kind.startswith("TPU v"):
+        raise ValueError(
+            f"Unrecognized TPU device kind: {device_kind}. "
+            "tpu_custom_call cannot be lowered on a machine without TPUs "
+            "when mosaic_use_python_pipeline=True.")
+      hardware_generation = int(device_kind[len("TPU v")])
       module = _lower_tpu_kernel(module, hardware_generation)
       needs_hlo_passes = False
       needs_layout_passes = False
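For reference, the hardware-generation parsing that moved here simply reads the digit after the "TPU v" prefix of the device-kind string; a small illustration (the device_kind value is assumed, not from the commit):

device_kind = "TPU v4"  # illustrative; the real value is jax.devices(backend)[0].device_kind
assert device_kind.startswith("TPU v")
hardware_generation = int(device_kind[len("TPU v")])  # the character right after "TPU v" -> 4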
29 changes: 29 additions & 0 deletions tests/pallas/BUILD
@@ -330,3 +330,32 @@ jax_test(
"//jax:pallas_gpu", # build_cleaner: keep
],
)

jax_test(
name = "export_pallas_test",
srcs = ["export_pallas_test.py"],
config_tags_overrides = {
"gpu_a100_x32": {
"ondemand": False, # Include in presubmit.
},
},
disable_configs = [
"gpu",
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
"gpu_pjrt_c_api",
],
enable_configs = [
"gpu_a100_x32",
],
tags = [],
deps = [
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_tpu", # build_cleaner: keep
"//jax/experimental/export",
],
)
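
Assuming a standard Bazel checkout of JAX (the exact target names the jax_test macro generates may vary by configuration), the new test can presumably be run with something like:

bazel test //tests/pallas:export_pallas_test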
56 changes: 56 additions & 0 deletions tests/pallas/export_pallas_test.py
@@ -0,0 +1,56 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test exporting Pallas kernels."""

from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.experimental import export
# Import mosaic for flag definitions
from jax.experimental import mosaic as _ # noqa: F401
from jax.experimental import pallas as pl
import numpy as np


jax.config.parse_flags_with_absl()


class ExportTest(jtu.JaxTestCase):

  def test_cross_platform(self):
    def add_vectors_kernel(x_ref, y_ref, o_ref):
      x, y = x_ref[...], y_ref[...]
      o_ref[...] = x + y

    @jax.jit
    def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
      return pl.pallas_call(add_vectors_kernel,
                            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                            )(x, y)

    a = np.arange(8)
    exp = export.export(
        add_vectors,
        # TODO(necula): Make this test work on GPU also
        lowering_platforms=["tpu"],
    )(a, a)

    if jtu.device_under_test() == "tpu":
      res = export.call(exp)(a, a)
      self.assertAllClose(res, a + a)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
