Skip to content

Commit

Permalink
Delete xla_client.execute_with_python_values.
Browse files Browse the repository at this point in the history
This is not a public API and exists only for testing.

PiperOrigin-RevId: 681453343
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Oct 2, 2024
1 parent 9dc8ec1 commit 8cb57d6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 61 deletions.
34 changes: 0 additions & 34 deletions xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,40 +469,6 @@ def computation_count():
# There are different implementations of Executable for different backends.


def execute_with_python_values(executable, arguments, backend):
"""Execute on one replica with Python values as arguments and output."""

def put(arg):
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

arguments = [put(arg) for arg in arguments]
outputs = executable.execute(arguments)
return [np.asarray(x) for x in outputs]


def execute_with_python_values_replicated(executable, arguments, backend):
"""Execute on many replicas with Python values as arguments and output.
Args:
executable: the program to run.
arguments: a list of lists of Python values indexed by `[replica][arg_num]`
to pass as inputs.
backend: the backend we are targeting.
Returns:
A list of python values, one per replica.
"""
devices = executable.local_devices()

# pylint: disable=g-complex-comprehension
def copy_to_devices(pyvals):
return [backend.buffer_from_pyval(v, d) for v, d in zip(pyvals, devices)]

inputs = [copy_to_devices(pyvals) for pyvals in zip(*arguments)]
outputs = executable.execute_sharded_on_local_devices(inputs)
return [[np.asarray(x) for x in xs] for xs in zip(*outputs)]


class PaddingType(enum.Enum):
VALID = 1
SAME = 2
Expand Down
7 changes: 0 additions & 7 deletions xla/python/xla_client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,6 @@ _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]
def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType:
...

def execute_with_python_values(executable: LoadedExecutable, arguments: Sequence[Any],
backend: Client) -> Sequence[numpy.ndarray]: ...

def execute_with_python_values_replicated(
executable: LoadedExecutable, arguments: Sequence[Sequence[Any]],
backend: Client) -> Sequence[Sequence[numpy.ndarray]]: ...

def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ...

def heap_profile(client: Client) -> bytes:
Expand Down
49 changes: 29 additions & 20 deletions xla/python/xla_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@
xla_client._xla.mlir.xla_computation_to_mlir_module)


def execute_with_python_values(executable, arguments, backend): # pylint: disable=invalid-name
"""Execute on one replica with Python values as arguments and output."""

def put(arg): # pylint: disable=invalid-name
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

arguments = [put(arg) for arg in arguments]
outputs = executable.execute(arguments)
return [np.asarray(x) for x in outputs]


# pylint: disable=invalid-name
def jax_array_convert_to_array(self, dtype=None, copy=None):
del copy
Expand Down Expand Up @@ -164,7 +175,7 @@ def _NewComputation(self, name=None):
def _Execute(self, c, arguments):
compiled_c = self.backend.compile(
xla_computation_to_mlir_module(c.build()))
return xla_client.execute_with_python_values(
return execute_with_python_values(
compiled_c, arguments, backend=self.backend)

def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
Expand Down Expand Up @@ -596,7 +607,7 @@ def testExecuteFromProto(self):
# Load and execute the proto
c = xla_client.XlaComputation(serialized_proto)
m = xla_computation_to_mlir_module(c)
ans, = xla_client.execute_with_python_values(
ans, = execute_with_python_values(
self.backend.compile(m), (), backend=self.backend)
np.testing.assert_equal(ans, np.int32(3))

Expand Down Expand Up @@ -1245,7 +1256,7 @@ def testConvertElementType(self, src_dtype, dst_dtype):
ops.ConvertElementType(
ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

result = xla_client.execute_with_python_values(
result = execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
backend=self.backend)
self.assertLen(result, 1)
Expand Down Expand Up @@ -1275,7 +1286,7 @@ def testBitcastConvertType(self, src_dtype, dst_dtype):
ops.BitcastConvertType(
ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

result = xla_client.execute_with_python_values(
result = execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
backend=self.backend)
self.assertLen(result, 1)
Expand Down Expand Up @@ -1859,7 +1870,7 @@ def testTuple(self):
ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
ops.Constant(c, NumpyArrayBool([True, False, False, True]))
])
result = xla_client.execute_with_python_values(
result = execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
backend=self.backend)
self.assertLen(result, 3)
Expand Down Expand Up @@ -1899,7 +1910,7 @@ def testRngNormal(self):
ops.Constant(c, NumpyArrayF32(1.)),
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
shape))
result = xla_client.execute_with_python_values(
result = execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
backend=self.backend)
# since the result is random, we just check shape and uniqueness
Expand All @@ -1916,7 +1927,7 @@ def testRngUniformF32(self):
ops.Constant(c, NumpyArrayF32(hi)),
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
shape))
result = xla_client.execute_with_python_values(
result = execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
backend=self.backend)
# since the result is random, we just check shape, uniqueness, and range
Expand All @@ -1935,7 +1946,7 @@ def testRngUniformS32(self):
ops.Constant(c, NumpyArrayS32(hi)),
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
shape))
result = xla_client.execute_with_python_values(
result = execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
backend=self.backend)
# since the result is random, we just check shape, integrality, and range
Expand Down Expand Up @@ -1965,7 +1976,7 @@ def testSortKeyVal(self):
values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
c = self._NewComputation()
ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
result = xla_client.execute_with_python_values(
result = execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
backend=self.backend)
self.assertLen(result, 2)
Expand All @@ -1988,7 +1999,7 @@ def testSortCustomComparator(self):
c, (ops.Constant(c, keys), ops.Constant(c, values)),
dimension=1,
comparator=comparator)
result = xla_client.execute_with_python_values(
result = execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
backend=self.backend)
self.assertLen(result, 2)
Expand Down Expand Up @@ -2578,7 +2589,7 @@ def testInfeedS32Values(self):
device.transfer_to_infeed(item)

for item in to_infeed:
result, = xla_client.execute_with_python_values(
result, = execute_with_python_values(
compiled_c, (), backend=self.backend)
self.assertEqual(result, item)

Expand All @@ -2597,7 +2608,7 @@ def testInfeedTuple(self):
device = self.backend.local_devices()[0]
device.transfer_to_infeed(to_infeed)

result = xla_client.execute_with_python_values(
result = execute_with_python_values(
compiled_c, (), backend=self.backend)
self.assertLen(result, 2)
np.testing.assert_equal(result[0], to_infeed[0])
Expand Down Expand Up @@ -2741,7 +2752,7 @@ def testInvokeWithWrongElementType(self):
c.clear_op_metadata()

def TestFun():
return xla_client.execute_with_python_values(
return execute_with_python_values(
self.backend.compile(xla_computation_to_mlir_module(c.build())),
[self.f32_scalar_2], self.backend)

Expand All @@ -2763,7 +2774,7 @@ def testComputationRootDifferentFromLastOp(self):
arg = NumpyArrayF32(1.0)
compiled_c = self.backend.compile(
xla_computation_to_mlir_module(c.build(result)))
ans, = xla_client.execute_with_python_values(
ans, = execute_with_python_values(
compiled_c, [arg], backend=self.backend)
np.testing.assert_allclose(ans, 4.14)

Expand All @@ -2787,7 +2798,7 @@ def testSetSharding(self):
arg = NumpyArrayF32(1.0)
compiled_c = self.backend.compile(
xla_computation_to_mlir_module(c.build(result)))
ans, = xla_client.execute_with_python_values(
ans, = execute_with_python_values(
compiled_c, [arg], backend=self.backend)
np.testing.assert_allclose(ans, 4.14)

Expand Down Expand Up @@ -3128,7 +3139,7 @@ def testHloProgramViaIfrtProgram(self):
)

compiled_c = self.backend.compile_ifrt_program(program, options)
results = xla_client.execute_with_python_values(
results = execute_with_python_values(
compiled_c, arguments=(), backend=self.backend
)

Expand All @@ -3154,10 +3165,8 @@ def testExecutableSerialization(self):
serialized = self.backend.serialize_executable(executable)
deserialized = self.backend.deserialize_executable(serialized, options)

expected, = xla_client.execute_with_python_values(executable, (),
self.backend)
actual, = xla_client.execute_with_python_values(deserialized, (),
self.backend)
expected, = execute_with_python_values(executable, (), self.backend)
actual, = execute_with_python_values(deserialized, (), self.backend)
self.assertTrue(np.all(actual == expected))

def testCompileOptionsSerialization(self):
Expand Down

0 comments on commit 8cb57d6

Please sign in to comment.