From 8cb57d6a33dfd628fa9a750931b5292e84e6435b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 2 Oct 2024 07:58:17 -0700 Subject: [PATCH] Delete xla_client.execute_with_python_values. This is not a public API and exists only for testing. PiperOrigin-RevId: 681453343 --- xla/python/xla_client.py | 34 ------------------------ xla/python/xla_client.pyi | 7 ----- xla/python/xla_client_test.py | 49 +++++++++++++++++++++-------------- 3 files changed, 29 insertions(+), 61 deletions(-) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 5885155dcd887..5cc12efa93709 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -469,40 +469,6 @@ def computation_count(): # There are different implementations of Executable for different backends. -def execute_with_python_values(executable, arguments, backend): - """Execute on one replica with Python values as arguments and output.""" - - def put(arg): - return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) - - arguments = [put(arg) for arg in arguments] - outputs = executable.execute(arguments) - return [np.asarray(x) for x in outputs] - - -def execute_with_python_values_replicated(executable, arguments, backend): - """Execute on many replicas with Python values as arguments and output. - - Args: - executable: the program to run. - arguments: a list of lists of Python values indexed by `[replica][arg_num]` - to pass as inputs. - backend: the backend we are targeting. - - Returns: - A list of python values, one per replica. - """ - devices = executable.local_devices() - - # pylint: disable=g-complex-comprehension - def copy_to_devices(pyvals): - return [backend.buffer_from_pyval(v, d) for v, d in zip(pyvals, devices)] - - inputs = [copy_to_devices(pyvals) for pyvals in zip(*arguments)] - outputs = executable.execute_sharded_on_local_devices(inputs) - return [[np.asarray(x) for x in xs] for xs in zip(*outputs)] - - class PaddingType(enum.Enum): VALID = 1 SAME = 2 diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index 8731080c99b52..898a632ab340a 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -71,13 +71,6 @@ _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType: ... -def execute_with_python_values(executable: LoadedExecutable, arguments: Sequence[Any], - backend: Client) -> Sequence[numpy.ndarray]: ... - -def execute_with_python_values_replicated( - executable: LoadedExecutable, arguments: Sequence[Sequence[Any]], - backend: Client) -> Sequence[Sequence[numpy.ndarray]]: ... - def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ... def heap_profile(client: Client) -> bytes: diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 85e451d11e706..edd582e405c26 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -64,6 +64,17 @@ xla_client._xla.mlir.xla_computation_to_mlir_module) +def execute_with_python_values(executable, arguments, backend): # pylint: disable=invalid-name + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): # pylint: disable=invalid-name + return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) + + arguments = [put(arg) for arg in arguments] + outputs = executable.execute(arguments) + return [np.asarray(x) for x in outputs] + + # pylint: disable=invalid-name def jax_array_convert_to_array(self, dtype=None, copy=None): del copy @@ -164,7 +175,7 @@ def _NewComputation(self, name=None): def _Execute(self, c, arguments): compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build())) - return xla_client.execute_with_python_values( + return execute_with_python_values( compiled_c, arguments, backend=self.backend) def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): @@ -596,7 +607,7 @@ def testExecuteFromProto(self): # Load and execute the proto c = xla_client.XlaComputation(serialized_proto) m = xla_computation_to_mlir_module(c) - ans, = xla_client.execute_with_python_values( + ans, = execute_with_python_values( self.backend.compile(m), (), backend=self.backend) np.testing.assert_equal(ans, np.int32(3)) @@ -1245,7 +1256,7 @@ def testConvertElementType(self, src_dtype, dst_dtype): ops.ConvertElementType( ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 1) @@ -1275,7 +1286,7 @@ def testBitcastConvertType(self, src_dtype, dst_dtype): ops.BitcastConvertType( ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 1) @@ -1859,7 +1870,7 @@ def testTuple(self): ops.Constant(c, NumpyArrayF32([1.0, 2.0])), ops.Constant(c, NumpyArrayBool([True, False, False, True])) ]) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 3) @@ -1899,7 +1910,7 @@ def testRngNormal(self): ops.Constant(c, NumpyArrayF32(1.)), shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, shape)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) # since the result is random, we just check shape and uniqueness @@ -1916,7 +1927,7 @@ def testRngUniformF32(self): ops.Constant(c, NumpyArrayF32(hi)), shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, shape)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) # since the result is random, we just check shape, uniqueness, and range @@ -1935,7 +1946,7 @@ def testRngUniformS32(self): ops.Constant(c, NumpyArrayS32(hi)), shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, shape)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) # since the result is random, we just check shape, integrality, and range @@ -1965,7 +1976,7 @@ def testSortKeyVal(self): values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) c = self._NewComputation() ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 2) @@ -1988,7 +1999,7 @@ def testSortCustomComparator(self): c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=1, comparator=comparator) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 2) @@ -2578,7 +2589,7 @@ def testInfeedS32Values(self): device.transfer_to_infeed(item) for item in to_infeed: - result, = xla_client.execute_with_python_values( + result, = execute_with_python_values( compiled_c, (), backend=self.backend) self.assertEqual(result, item) @@ -2597,7 +2608,7 @@ def testInfeedTuple(self): device = self.backend.local_devices()[0] device.transfer_to_infeed(to_infeed) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( compiled_c, (), backend=self.backend) self.assertLen(result, 2) np.testing.assert_equal(result[0], to_infeed[0]) @@ -2741,7 +2752,7 @@ def testInvokeWithWrongElementType(self): c.clear_op_metadata() def TestFun(): - return xla_client.execute_with_python_values( + return execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), [self.f32_scalar_2], self.backend) @@ -2763,7 +2774,7 @@ def testComputationRootDifferentFromLastOp(self): arg = NumpyArrayF32(1.0) compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build(result))) - ans, = xla_client.execute_with_python_values( + ans, = execute_with_python_values( compiled_c, [arg], backend=self.backend) np.testing.assert_allclose(ans, 4.14) @@ -2787,7 +2798,7 @@ def testSetSharding(self): arg = NumpyArrayF32(1.0) compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build(result))) - ans, = xla_client.execute_with_python_values( + ans, = execute_with_python_values( compiled_c, [arg], backend=self.backend) np.testing.assert_allclose(ans, 4.14) @@ -3128,7 +3139,7 @@ def testHloProgramViaIfrtProgram(self): ) compiled_c = self.backend.compile_ifrt_program(program, options) - results = xla_client.execute_with_python_values( + results = execute_with_python_values( compiled_c, arguments=(), backend=self.backend ) @@ -3154,10 +3165,8 @@ def testExecutableSerialization(self): serialized = self.backend.serialize_executable(executable) deserialized = self.backend.deserialize_executable(serialized, options) - expected, = xla_client.execute_with_python_values(executable, (), - self.backend) - actual, = xla_client.execute_with_python_values(deserialized, (), - self.backend) + expected, = execute_with_python_values(executable, (), self.backend) + actual, = execute_with_python_values(deserialized, (), self.backend) self.assertTrue(np.all(actual == expected)) def testCompileOptionsSerialization(self):