Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inner_split and outer_split to save typing #3951

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

wujingyue
Copy link
Collaborator

No description provided.

@wujingyue
Copy link
Collaborator Author

!test

Copy link

Description

  • Replace split with outer_split in tests

  • Add inner_split and outer_split methods in TensorView


Changes walkthrough 📝

Relevant files
Enhancement
test_multidevice_sharding.cpp
Replace `split` with `outer_split`                                             

tests/cpp/test_multidevice_sharding.cpp

  • Replaced tv->split(axis, factor, /*inner_split=*/false) with
    tv->outer_split(axis, factor)
  • +11/-11 
    interface_nodes.h
    Add `inner_split` and `outer_split` methods                           

    csrc/ir/interface_nodes.h

    • Added inner_split and outer_split template methods
    +10/-0   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Performance Impact

    Ensure that the introduction of outer_split instead of split does not negatively impact performance. Verify that the performance metrics are consistent with the expected improvements.

    tv->outer_split(1, d);
    Template Usage

    Verify that the template usage in inner_split and outer_split is appropriate and does not introduce unnecessary complexity or potential issues with type safety.

    template <typename FactorType>
    TensorView* inner_split(int64_t axis, FactorType factor) {
      return split(axis, factor, /*inner_split=*/true);
    }
    
    template <typename FactorType>
    TensorView* outer_split(int64_t axis, FactorType factor) {
      return split(axis, factor, /*inner_split=*/false);
    }

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Labels
    None yet
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    1 participant