-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pallas] Move the hardware_generation query in the code path that nee…
…ds it This change allows us to lower and export Pallas calls even on machines that do not have TPUs, in many cases. PiperOrigin-RevId: 641841079
- Loading branch information
Showing
3 changed files
with
94 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright 2023 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Test exporting Pallas kernels.""" | ||
|
||
from absl.testing import absltest | ||
import jax | ||
from jax._src import test_util as jtu | ||
from jax.experimental import export | ||
# Import mosaic for flag definitions | ||
from jax.experimental import mosaic as _ # noqa: F401 | ||
from jax.experimental import pallas as pl | ||
import numpy as np | ||
|
||
|
||
jax.config.parse_flags_with_absl() | ||
|
||
|
||
class ExportTest(jtu.JaxTestCase): | ||
|
||
def test_cross_platform(self): | ||
def add_vectors_kernel(x_ref, y_ref, o_ref): | ||
x, y = x_ref[...], y_ref[...] | ||
o_ref[...] = x + y | ||
|
||
@jax.jit | ||
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: | ||
return pl.pallas_call(add_vectors_kernel, | ||
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) | ||
)(x, y) | ||
|
||
a = np.arange(8) | ||
exp = export.export( | ||
add_vectors, | ||
# TODO(necula): Make this test work on GPU also | ||
lowering_platforms=["tpu"], | ||
)(a, a) | ||
|
||
if jtu.device_under_test() == "tpu": | ||
res = export.call(exp)(a, a) | ||
self.assertAllClose(res, a + a) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main(testLoader=jtu.JaxTestLoader()) |