|
4 | 4 | #pragma once |
5 | 5 |
|
6 | 6 | #include "ck/utility/common_header.hpp" |
7 | | -#include "ck/utility/env.hpp" |
8 | 7 | #include "ck/tensor_description/multi_index_transform_helper.hpp" |
9 | 8 | #include "ck/tensor_description/tensor_descriptor.hpp" |
10 | 9 | #include "ck/tensor_description/tensor_descriptor_helper.hpp" |
@@ -607,203 +606,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 |
607 | 606 | c_block_size * sizeof(CShuffleDataType)); |
608 | 607 | } |
609 | 608 |
|
610 | | - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} |
611 | | - __host__ static constexpr bool CheckValidity(const Argument& karg) |
612 | | - { |
613 | | - static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && |
614 | | - (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, |
615 | | - "Invalid tuning param!"); |
616 | | - |
617 | | - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || |
618 | | - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || |
619 | | - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || |
620 | | - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && |
621 | | - !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)) |
622 | | - { |
623 | | - if(!(karg.M % MPerBlock == 0)) |
624 | | - { |
625 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
626 | | - { |
627 | | - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " |
628 | | - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ |
629 | | - << std::endl; |
630 | | - } |
631 | | - return false; |
632 | | - } |
633 | | - } |
634 | | - |
635 | | - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || |
636 | | - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || |
637 | | - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || |
638 | | - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && |
639 | | - (is_same<tensor_layout::gemm::RowMajor, BLayout>::value)) |
640 | | - { |
641 | | - if(!(karg.N % NPerBlock == 0)) |
642 | | - { |
643 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
644 | | - { |
645 | | - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " |
646 | | - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ |
647 | | - << std::endl; |
648 | | - } |
649 | | - return false; |
650 | | - } |
651 | | - } |
652 | | - |
653 | | - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || |
654 | | - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || |
655 | | - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || |
656 | | - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) |
657 | | - { |
658 | | - |
659 | | - auto K_t = karg.KBatch * KPerBlock; |
660 | | - if(!(karg.K % K_t == 0)) |
661 | | - { |
662 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
663 | | - { |
664 | | - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " |
665 | | - << karg.K << " " << __FILE__ << ":" << __LINE__ |
666 | | - << ", in function: " << __func__ << std::endl; |
667 | | - } |
668 | | - return false; |
669 | | - } |
670 | | - } |
671 | | - else |
672 | | - { |
673 | | - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); |
674 | | - auto K_t = karg.KBatch * KReadVec; |
675 | | - auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; |
676 | | - if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) |
677 | | - { |
678 | | - return false; |
679 | | - } |
680 | | - } |
681 | | - |
682 | | - if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) |
683 | | - { |
684 | | - if(karg.K % ABlockTransferSrcScalarPerVector != 0) |
685 | | - { |
686 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
687 | | - { |
688 | | - std::cout << "Arg K (" << karg.K |
689 | | - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" |
690 | | - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" |
691 | | - << __LINE__ << ", in function: " << __func__ << std::endl; |
692 | | - } |
693 | | - return false; |
694 | | - } |
695 | | - } |
696 | | - else |
697 | | - { |
698 | | - if(karg.M % ABlockTransferSrcScalarPerVector != 0) |
699 | | - { |
700 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
701 | | - { |
702 | | - std::cout << "Arg M (" << karg.M |
703 | | - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" |
704 | | - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" |
705 | | - << __LINE__ << ", in function: " << __func__ << std::endl; |
706 | | - } |
707 | | - return false; |
708 | | - } |
709 | | - } |
710 | | - |
711 | | - if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) |
712 | | - { |
713 | | - if(karg.N % BBlockTransferSrcScalarPerVector != 0) |
714 | | - { |
715 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
716 | | - { |
717 | | - std::cout << "Arg N (" << karg.N |
718 | | - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" |
719 | | - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" |
720 | | - << __LINE__ << ", in function: " << __func__ << std::endl; |
721 | | - } |
722 | | - return false; |
723 | | - } |
724 | | - } |
725 | | - else |
726 | | - { |
727 | | - if(karg.K % BBlockTransferSrcScalarPerVector != 0) |
728 | | - { |
729 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
730 | | - { |
731 | | - std::cout << "Arg K (" << karg.K |
732 | | - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" |
733 | | - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" |
734 | | - << __LINE__ << ", in function: " << __func__ << std::endl; |
735 | | - } |
736 | | - return false; |
737 | | - } |
738 | | - } |
739 | | - |
740 | | - if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) |
741 | | - { |
742 | | - if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) |
743 | | - { |
744 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
745 | | - { |
746 | | - std::cout << "Arg N (" << karg.N |
747 | | - << ") value is not a multiple of " |
748 | | - "CShuffleBlockTransferScalarPerVector_NPerBlock (" |
749 | | - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " |
750 | | - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ |
751 | | - << std::endl; |
752 | | - } |
753 | | - return false; |
754 | | - } |
755 | | - } |
756 | | - else |
757 | | - { |
758 | | - if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) |
759 | | - { |
760 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
761 | | - { |
762 | | - std::cout << "Arg M (" << karg.M |
763 | | - << ") value is not a multiple of " |
764 | | - "CShuffleBlockTransferScalarPerVector_NPerBlock (" |
765 | | - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " |
766 | | - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ |
767 | | - << std::endl; |
768 | | - } |
769 | | - return false; |
770 | | - } |
771 | | - } |
772 | | - |
773 | | - if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value || |
774 | | - is_same<remove_cvref_t<CDataType>, float>::value || |
775 | | - is_same<remove_cvref_t<CDataType>, bhalf_t>::value || |
776 | | - is_same<remove_cvref_t<CDataType>, int32_t>::value)) |
777 | | - { |
778 | | - if(!karg.IsReduceAdd()) |
779 | | - { |
780 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
781 | | - { |
782 | | - std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ |
783 | | - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; |
784 | | - } |
785 | | - if(karg.KBatch > 1) |
786 | | - { |
787 | | - return false; |
788 | | - } |
789 | | - } |
790 | | - } |
791 | | - |
792 | | - // check gridwise gemm pipeline |
793 | | - const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); |
794 | | - |
795 | | - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) |
796 | | - { |
797 | | - if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) |
798 | | - { |
799 | | - return false; |
800 | | - } |
801 | | - } |
802 | | - |
803 | | - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) |
804 | | - return true; |
805 | | - } |
806 | | - |
807 | 609 | __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) |
808 | 610 | { |
809 | 611 | const index_t num_loop = K / KPerBlock; |
|
0 commit comments