## 🐛 Bug

Compiling `xm.all_to_all` on an SPMD-sharded tensor crashes with an HLO verifier failure:

```
F ./torch_xla/csrc/runtime/debug_macros.h:20] Non-OK-status: status.status() status: INTERNAL: during context [hlo verifier]: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:594) hlo->operand_count() == split_count
Failed after pipeline-start
*** Begin stack trace ***
  tsl::CurrentStackTrace()
  std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > ConsumeValue<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >(absl::lts_20230125::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >&&)
  torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
  torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, absl::lts_20230125::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
  torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230125::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
  torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230125::Span<std::string const>, bool, bool, bool)
  torch_xla::XLATensor::ApplyPendingGraph()
  torch_xla::XLATensor::GetXlaData()
  torch_xla::XLATensor::ToTensor(bool)
  torch_xla::XLANativeFunctions::_to_copy(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
  at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
  at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
  at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
  at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
  at::native::to(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, bool, c10::optional<c10::MemoryFormat>)
  at::_ops::to_dtype_layout::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, bool, c10::optional<c10::MemoryFormat>)
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  PyEval_EvalCode
  _PyRun_SimpleFileObject
  _PyRun_AnyFileObject
  Py_RunMain
  Py_BytesMain
  __libc_start_main
*** End stack trace ***
```

The failing `RET_CHECK` (`hlo->operand_count() == split_count`) appears to be the verifier rule for the tuple form of HLO `all-to-all`, which expects one operand per participant, so the SPMD lowering seems to emit an operand count that does not match `split_count`.
## To Reproduce

Code to reproduce the behavior:
```python
import os

# PJRT device settings must be set before torch_xla is imported.
os.environ['PJRT_DEVICE'] = 'CPU'
os.environ['CPU_NUM_DEVICES'] = '5'

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
# Provides Mesh / mark_sharding (torch_xla >= 2.1; older releases expose
# this as torch_xla.experimental.xla_sharding).
import torch_xla.distributed.spmd as xs

mesh_shape = (5, 1)  # 5 devices
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))

B, N, C = 1, 5, 30
input_tensor = torch.randn(B, N, C).to(xm.xla_device())
partition_spec = (None, 0, None)  # shard dim 1 (N) across mesh axis 0 ('x')
xs.mark_sharding(input_tensor, mesh, partition_spec)

temp = xm.all_to_all(input_tensor, split_dimension=2, concat_dimension=1,
                     split_count=num_devices)
hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([temp])
print(hlo_text)
```
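For contrast, here is a minimal sketch (not from the original report, and untested) of the same collective driven through the ordinary per-process path rather than SPMD; it assumes the `xmp.spawn` and `torch_xla.runtime.world_size()` APIs of a recent torch_xla release. If this variant compiles, the verifier failure is specific to the SPMD lowering of `all_to_all`.

```python
import os

os.environ['PJRT_DEVICE'] = 'CPU'
os.environ['CPU_NUM_DEVICES'] = '5'

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr


def _mp_fn(index):
    device = xm.xla_device()
    t = torch.randn(1, 5, 30, device=device)
    # Replicated (non-SPMD) collective: split dim 2 (size 30) across the
    # participants and concatenate the pieces along dim 1.
    out = xm.all_to_all(t, split_dimension=2, concat_dimension=1,
                        split_count=xr.world_size())
    xm.mark_step()
    print(index, out.shape)  # expected: [1, 25, 6] with 5 devices


if __name__ == '__main__':
    xmp.spawn(_mp_fn)
```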
## Expected behavior

An `all_to_all` operation in the emitted MLIR.
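For concreteness, the expected lowering would look roughly like the following StableHLO; this is illustrative only, since the exact op spelling, shapes, and attribute syntax are assumptions rather than output from a working build:

```mlir
// Hypothetical expected lowering for the repro above (5 devices,
// input [1, 5, 30], split dim 2, concat dim 1).
%0 = "stablehlo.all_to_all"(%arg0) {
  split_dimension = 2 : i64,
  concat_dimension = 1 : i64,
  split_count = 5 : i64,
  replica_groups = dense<[[0, 1, 2, 3, 4]]> : tensor<1x5xi64>
} : (tensor<1x5x30xf32>) -> tensor<1x25x6xf32>
```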
## Additional context

`reduce_scatter` and similar collectives also produce bugs in the same setup; a hypothetical sketch of that case follows.
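This untested snippet reuses `input_tensor` and `num_devices` from the repro script above; the `xm.reduce_scatter` argument names follow the torch_xla 2.x signature and should be double-checked against the installed version:

```python
# Hypothetical reduce_scatter analogue (untested), appended to the repro
# script above so that input_tensor is already sharded via mark_sharding.
out = xm.reduce_scatter(
    xm.REDUCE_SUM,          # reduction type
    input_tensor,           # sharded [B, N, C] input
    scale=1.0,              # no rescaling of the reduced values
    scatter_dim=1,          # scatter across the N dimension
    shard_count=num_devices,
)
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))
```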