
LinearOp with multiple sharded dimensions #3770

Open
samnordmann opened this issue Jan 28, 2025 · 2 comments

samnordmann (Collaborator) commented on Jan 28, 2025

The problem

Context: adding the HostIr overlap algorithm to the transformer's forward MLP layer.

Let us consider the inputs of a linear layer:

x [O, DID{D}, B * S / (D * O) , E]
w0[DID{D}, 4 * E / D, E]
b0[DID{D}, 4 * E / D]

Compared to before, we added to x the axis "O", which tiles x's batch*sequence axis and corresponds to the "Stream" parallelization.

We would like to define the output as follows:

linear0 = linear(x, w0, b0) [Stream{O}, DID{D}, D, B * S / (D * O), 4 * E / D]

where, on linear0:

  • axis(1) "DID{D}" comes from w0's and b0's axis(0). Producing this axis doesn't require any resharding.
  • axis(2) "D" comes from x's axis(1), after it has been allgathered.
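
For concreteness, here is a minimal sketch of the intended setup written with the C++ test API (not the actual repro; the concrete sizes are taken from the IR dump below, and whether ParallelType::Stream should be set on x's "O" axis is itself part of the open question):

// Sketch only. Sizes from the IR dump below: O = 2, D = 8,
// B*S/(D*O) = 16, E = 768, so 4*E/D = 384.
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

TensorView* x = makeContigTensor(4);   // [O, DID{D}, B*S/(D*O), E]
TensorView* w0 = makeContigTensor(3);  // [DID{D}, 4*E/D, E]
TensorView* b0 = makeContigTensor(2);  // [DID{D}, 4*E/D]
fusion->addInput(x);
fusion->addInput(w0);
fusion->addInput(b0);

auto mesh = DeviceMesh::createForNumDevices(8); // D = 8 devices
for (TensorView* tv : {x, w0, b0}) {
  tv->setDeviceMesh(mesh);
}
x->axis(0)->parallelize(ParallelType::Stream); // assumption: the "O" tile axis carries Stream
x->axis(1)->parallelize(ParallelType::DIDx);
w0->axis(0)->parallelize(ParallelType::DIDx);
b0->axis(0)->parallelize(ParallelType::DIDx);

// Desired output: [Stream{O}, DID{D}, D, B*S/(D*O), 4*E/D]
TensorView* linear0 = linear(x, w0, b0);
fusion->addOutput(linear0);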

However, currently, the obtained shape of linear0 is [D, O, D, B * S / (D * O), 4 * E / D], i.e., the axes are not ordered the way we want. More precisely, we get:

T5_g___bfloat[iStream13{8}, iS14{2}, ideviceIdx.x15{8}, iS16{16}, iS17{384}, rS18{768}] (DeviceMesh{0 1 2 3 4 5 6 7})
   = linear(T0_g___bfloat[iS0{2}, ideviceIdx.x1{8}, iS2{16}, iS3{768}] (DeviceMesh{0 1 2 3 4 5 6 7}),
            T1_g___bfloat[ideviceIdx.x4{8}, iS5{384}, iS6{768}] (DeviceMesh{0 1 2 3 4 5 6 7})  ,
          T2_g___bfloat[ideviceIdx.x7{8}, iS8{384}] (DeviceMesh{0 1 2 3 4 5 6 7})  )

The problem is that "linear" only accepts a 2D w0 and a 1D b0. Currently, this is bypassed by manually handling the outermost sharded axis, but that workaround doesn't cover the case we are considering here. This is probably also related to this comment: https://github.com/NVIDIA/Fuser/blob/main/csrc/ops/composite.cpp#L174.

What is the best approach to achieve this goal? Possible solutions I can think of:

  1. Patch the linear op to accept more general shapes, with some restrictions, possibly requiring the user to properly broadcast the necessary dimensions (like in at::matmul).
  2. Reshape the tensor, collapsing all the outer dimensions to make it 2D, then re-expand it afterwards (a rough sketch follows after this list). However, I'm afraid that doing so would lose the DID parallelization.
  3. Manually add a "set" operation to represent allgather-ing x. I'd rather not, because the distributed matmul is treated as a single resharding op by the host IR lowering that creates the pipeline algorithm (though we could of course change that behavior).
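
To illustrate the concern in option 2, here is a rough sketch of what the collapse might look like (assuming the flatten and linear ops from csrc/ops; this is only to show where the DID axis gets merged, not a proposed implementation):

// Rough sketch of option 2: collapse w0/b0's outer DID axis so that linear
// sees a 2D weight and a 1D bias.
TensorView* w0_2d = flatten(w0, /*start_dim=*/0, /*end_dim=*/1); // [D * 4*E/D, E] = [4*E, E]
TensorView* b0_1d = flatten(b0, /*start_dim=*/0, /*end_dim=*/1); // [4*E]
TensorView* out = linear(x, w0_2d, b0_1d);
// The flatten merges the DIDx-parallelized IterDomain into a plain iteration
// domain, and the output would then have to be re-split into [..., D, 4*E/D]
// and re-parallelized on DIDx, which is where I'm afraid the DID
// parallelization gets lost.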

Wdyt?

With a Matmul

I am not sure this is completely related, but IMO it can help the discussion. I tried to emulate this problem by replacing the LinearOp with a MatmulOp. This op is more convenient since at::matmul is more flexible and accepts inputs of any dimensionality, as long as the user manually broadcasts the necessary dimensions so that the inputs match. I wrote the following test:

TEST_F(MultiDeviceTest, DoubleShardings) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* x = makeContigTensor(3); //[DID{D}, M, K]
  TensorView* w = makeContigTensor(3); // [DID{D'}, K, N]
  TensorView* xb = unsqueeze(x, {1}); // [DID{D}, b1, M, K]
  TensorView* wb = unsqueeze(w, {0}); // [b1, DID{D'}, K, N]
  TensorView* y = matmul(xb, wb); // [D, DID{D'}, M, N]
  fusion->addInput(x);
  fusion->addInput(w);
  fusion->addOutput(y);

  auto mesh = DeviceMesh::createForNumDevices(communicator_->size());
  x->setDeviceMesh(mesh);
  w->setDeviceMesh(mesh);
  xb->setDeviceMesh(mesh);
  wb->setDeviceMesh(mesh);
  w->axis(0)->parallelize(ParallelType::DIDx);
  wb->axis(1)->parallelize(ParallelType::DIDx);
  x->axis(0)->parallelize(ParallelType::DIDx);
  xb->axis(0)->parallelize(ParallelType::DIDx);
  y->axis(1)->parallelize(ParallelType::DIDx);

  constexpr int64_t kLowerBound = 0;
  constexpr int64_t kUpperBound = 10;
  auto x_tensor = at::randint(
      kLowerBound, kUpperBound, {communicator_->size(), 2, 3}, tensor_options);
  auto w_tensor = at::randint(
      kLowerBound, kUpperBound, {communicator_->size(), 3, 5}, tensor_options);
  auto sharded_x_tensor = shardTensor(w_tensor, x);
  auto sharded_w_tensor = shardTensor(w_tensor, w);
  at::Tensor expected_y_tensor = at::matmul(x_tensor.unsqueeze(1), sharded_w_tensor.unsqueeze(0));

  FusionExecutorCache executor_cache(std::move(fusion));
  std::vector<c10::IValue> inputs({sharded_x_tensor, sharded_w_tensor});
  std::vector<at::Tensor> outputs = executor_cache.runFusionWithInputs(inputs);

  testValidate(
      executor_cache.fusion(),
      outputs,
      inputs,
      {shardTensor(expected_y_tensor, 0, mesh)},
      __LINE__,
      __FILE__);
}

But I am getting an error:

C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/Fuser/csrc/expr_evaluator.cpp":438, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. When trying to propagate constant tensor sizes through the graph a conflict was found with 2 different sizes across dimensions that are expected to match.
  For Producer TV: T3_l_float[bS10{1}, ideviceIdx.x11{i4}, iS12{i5}, iS13{i6}] (DeviceMesh{0 1}) id: iS12{i5} found size: 3
  For Consumer TV: T4_g_float[iS14{i0}, ideviceIdx.x15{i4}, iS16{i2}, iS17{i6}, rS18{i3}] id: rS18{i3} found size: 5
  With producer-consumer relationship through the expression: T4_g_float[iS14{i0}, ideviceIdx.x15{i4}, iS16{i2}, iS17{i6}, rS18{i3}]
   = matmul(T2_l_float[ideviceIdx.x6{i0}, bS7{1}, iS8{i2}, iS9{i3}] (DeviceMesh{0 1}),
            T3_l_float[bS10{1}, ideviceIdx.x11{i4}, iS12{i5}, iS13{i6}] (DeviceMesh{0 1}))

I feel this indicates something is broken in the logic, but I'd be happy to hear your thoughts.

@wujingyue @naoyam @cowanmeg @Priya2698

wujingyue (Collaborator) commented:

> The problem is that "linear" only accepts a 2D w0 and a 1D b0.

I think this has been fixed by #3073. No?

wujingyue (Collaborator) commented:

Anyhow, extending LinearOp to support more shapes makes sense; #3073 was already along those lines. I'm unsure why it failed for you.

Fixing #2563 will be the ultimate fix. This way, we don't need to extend the "logical" definition of LinearOp. #3650 from @Priya2698 may have already fixed some aspects, but I wouldn't be surprised if it doesn't work out of the box for the new Stream parallel type.
