
[PyTorch] make GroupedLinear inp support collection of torch.Tensor #1120

Closed

Conversation

@BeingGod (Contributor) commented on Aug 19, 2024

Description

For the FP8 GroupedMLP linear_fc1, to make sure each tensor's shape is aligned to a multiple of 16, we currently split the activation, pad each per-expert tensor, and then concatenate the list of tensors back together to form the GroupedLinear inp argument.

e.g.

class TEGroupedMLP(MegatronModule):
...
    def forward(self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor):
        def _pad_tensor(inp: torch.Tensor):
            if inp.shape[0] % 16 == 0:
                return inp
            pad_len = (inp.shape[0] + 15) // 16 * 16 - inp.shape[0]
            pad_tensor = torch.zeros(pad_len, inp.shape[1], dtype=inp.dtype, device=inp.device)
            return torch.cat((inp, pad_tensor), dim=0)

        tokens_per_expert = tokens_per_expert.tolist()
        if self.config.fp8:
            splited_permuted_local_hidden_states = torch.split(
                permuted_local_hidden_states.view(-1, permuted_local_hidden_states.shape[-1]), tokens_per_expert)
            splited_pad_permuted_local_hidden_states = [_pad_tensor(hidden_states) for hidden_states in splited_permuted_local_hidden_states]
            orig_tokens_per_expert = tokens_per_expert
            tokens_per_expert = [hidden_states.shape[0] for hidden_states in splited_pad_permuted_local_hidden_states]
            permuted_local_hidden_states = torch.cat(splited_pad_permuted_local_hidden_states, dim=0)

        intermediate_parallel, bias_parallel = self.linear_fc1(
            permuted_local_hidden_states, tokens_per_expert
        )
...

Inside _GroupedLinear, inp is then split again according to m_splits, so the concatenation in the caller and the split in _GroupedLinear are redundant. We would like _GroupedLinear to accept inp as a collection of tensors (e.g. List[torch.Tensor] or Tuple[torch.Tensor]) so that two cat kernel calls can be removed (one in forward and one in backward); a sketch of the intended call site is shown below.
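A minimal sketch of the intended call site, assuming GroupedLinear accepted a list of tensors (the list-based signature is hypothetical; it is exactly what this PR proposes to add):

split_hidden_states = torch.split(
    permuted_local_hidden_states.view(-1, permuted_local_hidden_states.shape[-1]),
    tokens_per_expert)
padded_hidden_states = [_pad_tensor(h) for h in split_hidden_states]
tokens_per_expert = [h.shape[0] for h in padded_hidden_states]
# No torch.cat here: the list is passed directly, and GroupedLinear skips its
# internal torch.split, saving one cat kernel in forward and one in backward.
intermediate_parallel, bias_parallel = self.linear_fc1(
    padded_hidden_states, tokens_per_expert)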

Profiling:
(profiler screenshot)

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Make GroupedLinear inp support collection of torch.Tensor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 (Collaborator) commented on Aug 19, 2024

Hi, I think calling _pad_tensor num_local_experts times introduces non-negligible overhead (corresponding to the blue boxes, IIUC).
(profiler screenshot)

So, our plan is to pad all the inputs using one concat.

    def _pad_tensor_for_fp8(self, hidden_states, tokens_per_expert):
        """Padding tensor shapes to multiples of 16."""
        padded_tokens_per_expert = [
            (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
        for hidden_state, actual_num_tokens, padded_num_tokens in zip(
            hidden_states, tokens_per_expert, padded_tokens_per_expert
        ):
            padded_hidden_states.append(hidden_state)
            if padded_num_tokens > actual_num_tokens:
                pad_tensor = torch.zeros(
                    padded_num_tokens - actual_num_tokens,
                    hidden_state.shape[1],
                    dtype=hidden_state.dtype,
                    device=hidden_state.device,
                )
                padded_hidden_states.append(pad_tensor)
        padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
        return padded_hidden_states, padded_tokens_per_expert

    def forward(
        self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward of TEGroupedMLP

        Args:
            permuted_local_hidden_states (torch.Tensor): The permuted input hidden states of the
            local experts.
            tokens_per_expert (torch.Tensor): The number of tokens per expert.

        Return:
            output (torch.Tensor): The output of the local experts.
        """
        tokens_per_expert = tokens_per_expert.tolist()
        if self.config.fp8:
            actual_tokens_per_expert = tokens_per_expert
            permuted_local_hidden_states, tokens_per_expert = self._pad_tensor_for_fp8(
                permuted_local_hidden_states, tokens_per_expert
            )
        intermediate_parallel, bias_parallel = self.linear_fc1(
            permuted_local_hidden_states, tokens_per_expert
        )

@BeingGod (Contributor, Author) commented on Aug 19, 2024

Yes, I noticed that. My plan was to implement a fused kernel to handle the multiple _pad_tensor calls, but I think your plan is more elegant than mine.

By the way, have you considered implementing a fused kernel to eliminate the aten::zero_ calls? For example, we could pad each tensor inside the kernel and write the result back directly, so the output tensor would not need to be zero-initialized first.
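As a rough PyTorch-level illustration of the idea (the helper name is hypothetical and this is not the fused kernel itself): only the pad rows are zero-filled, so no full-size zero-initialization is launched, and a real fused CUDA kernel would additionally collapse the per-expert copy/zero launches into a single one.

def pad_without_full_zero(hidden_states, tokens_per_expert, align=16):
    # Allocate the padded output without zero-initializing the whole buffer.
    padded_tokens = [(n + align - 1) // align * align for n in tokens_per_expert]
    out = torch.empty(sum(padded_tokens), hidden_states.shape[1],
                      dtype=hidden_states.dtype, device=hidden_states.device)
    src, dst = 0, 0
    for n, pn in zip(tokens_per_expert, padded_tokens):
        out[dst:dst + n].copy_(hidden_states[src:src + n])  # copy the real tokens
        if pn > n:
            out[dst + n:dst + pn].zero_()  # zero only the pad rows
        src += n
        dst += pn
    return out, padded_tokens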

Thanks a lot!

@yaox12 (Collaborator) commented on Aug 19, 2024

> By the way, have you considered implementing a fused kernel to eliminate the aten::zero_ calls? For example, we could pad each tensor inside the kernel and write the result back directly, so the output tensor would not need to be zero-initialized first.

Not yet, but it's a good idea. Maybe I can give it a try.

I'm working on enabling FP8 support in MCore's TEGroupedMLP, and the above code is part of that. Perhaps I can add the optimized padding kernel there.

@yaox12 (Collaborator) commented on Aug 21, 2024

Hi @BeingGod, do you still need this feature with the padding method I pasted above?

@BeingGod (Contributor, Author) commented on Aug 21, 2024

Hi @yaox12, it seems our work overlaps somewhat. I'm trying to fuse padding + cast + transpose, and I have already implemented a fused kernel. Below is my analysis:

Before workflow: (diagram)

After workflow: (diagram)
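For reference, a plain-PyTorch sketch of what the fused padding + cast + transpose kernel is meant to compute, written here as separate ops (function and parameter names are illustrative, FP8 scaling handling is omitted, and torch.float8_e4m3fn assumes a recent PyTorch build):

import torch
import torch.nn.functional as F

def pad_cast_transpose_reference(x, padded_rows, fp8_dtype=torch.float8_e4m3fn):
    # Zero-pad rows so the leading dimension reaches the 16-aligned size.
    x = F.pad(x, (0, 0, 0, padded_rows - x.shape[0]))
    # Cast the padded tensor and its transpose; here these are separate kernels,
    # whereas the proposed fused kernel produces both outputs in a single pass.
    return x.to(fp8_dtype), x.t().contiguous().to(fp8_dtype)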

Can you help me review this solution?

@yaox12 (Collaborator) commented on Aug 21, 2024

@BeingGod Looks reasonable to me.

BeingGod closed this on Aug 22, 2024
BeingGod reopened this on Aug 22, 2024
@BeingGod (Contributor, Author) commented

Hi @yaox12, can you help me review PR #1129?

BeingGod closed this on Aug 27, 2024