cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h (8 additions, 8 deletions)

@@ -707,13 +707,13 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture

 #ifdef USING_OSS_CUTLASS_MOE_GEMM
     mGemmProfilerBackend.init(mMoERunner, GemmProfilerBackend::GemmToProfile::Undefined, typeToDtypeID<DataType>(),
-        typeToDtypeID<WeightType>(), typeToDtypeID<OutputType>(), mNumExperts, mK, mHiddenSize, mInterSize,
-        mGroupSize, mActType, mUseBias, mUseLora, /*min_latency_mode=*/false,
+        typeToDtypeID<WeightType>(), typeToDtypeID<OutputType>(), mNumExperts, mK, mHiddenSize, mHiddenSize,
+        mInterSize, mGroupSize, mActType, mUseBias, mUseLora, /*min_latency_mode=*/false,
         /*need_weights=*/false, parallelism_config, /*enable_alltoall=*/false);
 #else
     mGemmProfilerBackend.init(mMoERunner, GemmProfilerBackend::GemmToProfile::Undefined, typeToDtypeID<DataType>(),
-        typeToDtypeID<WeightType>(), typeToDtypeID<OutputType>(), mNumExperts, mK, mHiddenSize, mInterSize,
-        mGroupSize, mActType, mUseBias, mUseLora, /*min_latency_mode=*/false,
+        typeToDtypeID<WeightType>(), typeToDtypeID<OutputType>(), mNumExperts, mK, mHiddenSize, mHiddenSize,
+        mInterSize, mGroupSize, mActType, mUseBias, mUseLora, /*min_latency_mode=*/false,
         /*need_weights=*/false, parallelism_config);
 #endif

@@ -989,7 +989,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
             ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
             mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-            mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
+            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
             mFinalOutput + mFinalOutputSize * mBufferIndex,
             mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
             /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
@@ -1001,10 +1001,10 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
             ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
             mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-            mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
+            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
             mFinalOutput + mFinalOutputSize * mBufferIndex,
-            mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config, mUseLora,
-            mLoraParams[mBufferIndex],
+            mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
+            /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
             /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
 #endif
             break;
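In both benchmark hunks above, `mHiddenSize` is now passed twice before `mInterSize`. A hedged reading of the call sites alone: the runner and profiler gain a distinct output hidden size next to the input hidden size, which is what the `shape_override` plumbing in the epilogue hunks below would consume when the stored output is narrower than the problem shape. The benchmark passes the same value for both, so its behavior is unchanged. A minimal sketch of that reading; the names are illustrative, not the real `init`/`runMoe` signature:

```cpp
// Illustrative only: assumed meaning of the doubled hidden-size argument.
struct MoeGemmShape {
    int hidden_in;   // width of the input activations (first mHiddenSize)
    int hidden_out;  // width of the stored output rows (second mHiddenSize)
    int inter;       // FFN intermediate size (mInterSize)
};

// The benchmark keeps both widths equal, so nothing is sliced there:
constexpr MoeGemmShape benchmark_shape{4096, 4096, 14336};
static_assert(benchmark_shape.hidden_in == benchmark_shape.hidden_out,
              "benchmark passes mHiddenSize for both widths");

int main() {}
```

The remaining hunks below are the CUTLASS epilogue side of the same change.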
@@ -67,25 +67,29 @@ struct Sm90ScatterPtrArray {
   using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{}));

   using ElementIndex = int32_t;
-  // TODO: more generic treatment, or pass StrideIndex via template param?
-  using StrideIndex = conditional_t<cutlass::gemm::detail::is_mn_major<StrideOutput>(), Stride<_0,_1,_0>, Stride<_1,_0,_0>>;
+
+  static constexpr bool MajorMode = cutlass::gemm::detail::is_major<0,StrideOutput>() ? 0 : 1;
+
+  using StrideIndex = decltype(replace<1-MajorMode>(Stride<_0,_0,_0>{}, Int<1>{}));

   struct SharedStorage {};

   struct Arguments {
-    ElementOutput* ptr_out = nullptr;
-    StrideOutput dOut = {};
-    ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
-    int index_modulo{};                     // modulo used to transform the index before store
-    bool use_reduction = true;
+    ElementOutput* ptr_out{};               // output tensor pointer
+    StrideOutput dOut = {};                 // output tensor stride
+    ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
+    int index_modulo{};                     // modulo used to transform the index before store
+    int shape_override = -1;                // override value for contiguous output tensor mode
+    bool use_reduction = true;              // use reduction or regular store
   };

   struct Params {
-    ElementOutput* ptr_out = nullptr;
-    StrideOutput dOut = {};
-    ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
-    cutlass::FastDivmod index_divmod{};     // modulo used to transform the index before store
-    bool use_reduction = true;
+    ElementOutput* ptr_out{};               // output tensor pointer
+    StrideOutput dOut = {};                 // output tensor stride
+    ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
+    cutlass::FastDivmod index_divmod{};     // modulo used to transform the index before store
+    int shape_override = -1;                // override value for contiguous output tensor mode
+    bool use_reduction = true;              // use reduction or regular store
   };

   template <class ProblemShape>
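The new `StrideIndex` is derived instead of spelled out: `MajorMode` records which of the first two output modes is unit-stride, and `replace<1-MajorMode>` plants the `_1` in the other mode. A standalone compile-time check of the `cute::replace` mechanics, assuming CuTe as shipped with recent CUTLASS; the two asserts mirror the branches of the old `conditional_t`:

```cpp
#include <cute/tensor.hpp>
#include <type_traits>

using namespace cute;

// MajorMode == 0 (mode 0 of the output is unit-stride): the _1 lands in
// mode 1, reproducing the old Stride<_0,_1,_0> branch.
static_assert(std::is_same_v<decltype(replace<1>(Stride<_0,_0,_0>{}, Int<1>{})),
                             Stride<_0,_1,_0>>);

// MajorMode == 1 (mode 1 is unit-stride): the _1 lands in mode 0,
// reproducing the old Stride<_1,_0,_0> branch.
static_assert(std::is_same_v<decltype(replace<0>(Stride<_0,_0,_0>{}, Int<1>{})),
                             Stride<_1,_0,_0>>);

int main() {}
```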
@@ -96,6 +100,7 @@ struct Sm90ScatterPtrArray {
         args.dOut,
         args.ptr_index,
         cutlass::FastDivmod(args.index_modulo),
+        args.shape_override,
         args.use_reduction
     };
   }
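`to_underlying_arguments` folds the plain `index_modulo` into a `cutlass::FastDivmod`, precomputing a multiply/shift form once on the host so the per-element wrap in the store path avoids an integer divide. A minimal host-side sketch, assuming `FastDivmod` from mainline CUTLASS's `cutlass/fast_math.h`:

```cpp
#include <cutlass/fast_math.h>
#include <cstdio>

int main() {
    int index_modulo = 12;                          // as passed in Arguments
    cutlass::FastDivmod index_divmod(index_modulo); // multiplier/shift set up once

    int scatter_index = 31;
    int q, r;
    index_divmod(q, r, scatter_index); // q = 31 / 12 = 2, r = 31 % 12 = 7

    std::printf("index %d wraps to %d (quotient %d)\n", scatter_index, r, q);
    return 0;
}
```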
@@ -329,6 +334,14 @@ struct Sm90ScatterPtrArray {
     Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)
     Tensor tRG_cD   = thread_r2g.partition_D(cD_epi);   // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)

+    auto residue_cD = args.residue_cD;
+
+    // If shape_override is set, adjust residue_cD to change predication.
+    // This is used to support fused slicing (where the output tensor is smaller than problem shape)
+    if (params_ptr->shape_override >= 0) {
+      get<MajorMode>(residue_cD) += params_ptr->shape_override - get<MajorMode>(args.problem_shape_mnkl);
+    }
+
     auto args_tuple = make_tuple(
         cute::move(tC_rOut),
         tiled_r2s,
@@ ... @@
         tiled_r2g_stg,
         params_ptr->use_reduction,
         args.thread_idx,
-        args.residue_cD);
+        residue_cD);

     return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple));
   }
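The added block narrows the out-of-bounds predicate instead of the problem shape itself: shifting `residue_cD` along the major mode masks off stores past the overridden extent. A worked example of the arithmetic in plain C++, with illustrative values:

```cpp
#include <cstdio>

int main() {
    int problem_extent = 128; // get<MajorMode>(problem_shape_mnkl)
    int residue        = 128; // get<MajorMode>(residue_cD) for an unpadded tile
    int shape_override = 96;  // fused slicing: output is only 96 wide

    if (shape_override >= 0) {
        residue += shape_override - problem_extent; // 128 + (96 - 128) = 96
    }

    for (int coord : {0, 95, 96, 127}) {
        std::printf("coord %3d -> %s\n", coord,
                    coord < residue ? "stored" : "predicated off");
    }
    return 0;
}
```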
@@ -512,11 +525,12 @@ struct FusionCallbacks<
   // using ScatterArguments = typename Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>::Arguments;
   // ScatterArguments scatter{};

-  ElementOutput* ptr_out = nullptr;
-  StrideOutput dOut = {};
-  int const* const* ptr_index{}; // per-group pointer to the scatter index
-  int index_modulo{};            // modulo used to transform the index before store
-  bool use_reduction = true;
+  ElementOutput* ptr_out{};      // output tensor pointer
+  StrideOutput dOut{};           // output tensor stride
+  int const* const* ptr_index{}; // per-group pointer to the scatter index
+  int index_modulo{};            // modulo used to transform the index before store
+  int shape_override = -1;       // override value for contiguous output tensor mode
+  bool use_reduction = true;     // use reduction or regular store

   operator typename Impl::Arguments() const {
     return
@@ ... @@
           {} // binary args: multiply
         }, // end binary op
         //scatter // unary args: reduce
-        { ptr_out, dOut, ptr_index, index_modulo, use_reduction }
+        { ptr_out, dOut, ptr_index, index_modulo, shape_override, use_reduction }
       }; // end unary op
   }
 };
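For callers, the visible change is one extra slot, `shape_override`, in the brace-initializer above (plus the matching field on `TmaWarpSpecializedGroupedGemmInput` below). A hedged usage sketch; `ScatterArgs` and `make_sliced_args` are hypothetical stand-ins for the real nested `Arguments` type:

```cpp
// Hypothetical mirror of the scatter-related FusionCallbacks fields above.
struct ScatterArgs {
    float* ptr_out = nullptr;              // output tensor pointer
    int const* const* ptr_index = nullptr; // per-group scatter indices
    int index_modulo = 0;                  // wrap indices before the store
    int shape_override = -1;               // -1: full extent; >= 0: sliced width
    bool use_reduction = true;             // atomic reduction vs. plain store
};

ScatterArgs make_sliced_args(float* out, int const* const* idx,
                             int num_rows, int sliced_width) {
    ScatterArgs a;
    a.ptr_out        = out;
    a.ptr_index      = idx;
    a.index_modulo   = num_rows;     // scatter index taken modulo the row count
    a.shape_override = sliced_width; // opt in to the fused-slicing store path
    a.use_reduction  = false;        // illustrative choice: plain stores
    return a;
}

int main() {
    static float out[96 * 4];
    static int const idx0[] = {0, 1, 2, 3};
    static int const* idx_ptrs[] = {idx0};
    ScatterArgs a = make_sliced_args(out, idx_ptrs, 4, 96);
    return a.shape_override >= 0 ? 0 : 1;
}
```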
@@ -158,6 +158,7 @@ struct TmaWarpSpecializedGroupedGemmInput

     int const** ptr_source_token_index = nullptr;
     int num_rows_in_final_output = 0;
+    int shape_override = -1;

     bool use_reduction = true;
 };