diff --git a/clients/gtest/matrix_transform_gtest.cpp b/clients/gtest/matrix_transform_gtest.cpp index 995acd7e83..c3483c8084 100644 --- a/clients/gtest/matrix_transform_gtest.cpp +++ b/clients/gtest/matrix_transform_gtest.cpp @@ -1151,6 +1151,244 @@ TEST(MatrixTransformTest, ScalarsOnDevice) transB); } +TEST(MatrixTransformTest, MultipleDevices) +{ + int numDevices{}; + int curDevice{}; + auto hipErr = hipGetDeviceCount(&numDevices); + EXPECT_EQ(hipErr, hipSuccess); + hipErr = hipGetDevice(&curDevice); + EXPECT_EQ(hipErr, hipSuccess); + // acquire at most 2 devices + numDevices = std::min(numDevices, 2); + + for (int deviceId = 0; deviceId < numDevices; ++deviceId) + { + hipErr = hipSetDevice(deviceId); + EXPECT_EQ(hipErr, hipSuccess); + int64_t m = 1024; + int64_t n = 1024; + int32_t batchSize = 1; + auto datatype = HIP_R_32F; + auto scaleDatatype = HIP_R_32F; + auto opA = HIPBLAS_OP_N; + auto opB = HIPBLAS_OP_N; + auto orderA = HIPBLASLT_ORDER_ROW; + auto orderB = HIPBLASLT_ORDER_ROW; + auto orderC = HIPBLASLT_ORDER_COL; + float alpha = 1; + float beta = 1; + int64_t batchStride = m * n; + std::pair shapeA; + std::pair shapeB; + shapeA.first = opA == HIPBLAS_OP_T ? n : m; + shapeA.second = opA == HIPBLAS_OP_T ? m : n; + shapeB.first = opB == HIPBLAS_OP_T ? n : m; + shapeB.second = opB == HIPBLAS_OP_T ? m : n; + uint32_t ldA = (orderA == HIPBLASLT_ORDER_ROW) + ? getLeadingDimSize(shapeA.first, shapeA.second) + : getLeadingDimSize(shapeA.first, shapeA.second); + uint32_t ldB = (orderB == HIPBLASLT_ORDER_ROW) + ? getLeadingDimSize(shapeB.first, shapeB.second) + : getLeadingDimSize(shapeB.first, shapeB.second); + uint32_t ldC = (orderC == HIPBLASLT_ORDER_ROW) ? getLeadingDimSize(m, n) + : getLeadingDimSize(m, n); + + auto inputs = makeMatrixTransformIOPtr(datatype, m, n, batchSize); + void* dA = inputs->getBuf(0); + void* dB = inputs->getBuf(1); + void* dC = inputs->getBuf(2); + + hipblasLtMatrixTransformDesc_t desc; + auto hipblasLtErr = hipblasLtMatrixTransformDescCreate(&desc, scaleDatatype); + hipblasLtPointerMode_t pMode = HIPBLASLT_POINTER_MODE_HOST; + hipblasLtErr = hipblasLtMatrixTransformDescSetAttribute( + desc, + hipblasLtMatrixTransformDescAttributes_t::HIPBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, + &pMode, + sizeof(pMode)); + + ASSERT_EQ(hipblasLtErr, HIPBLAS_STATUS_SUCCESS); + + hipblasLtErr = hipblasLtMatrixTransformDescSetAttribute( + desc, HIPBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opA, sizeof(opA)); + hipblasLtErr = hipblasLtMatrixTransformDescSetAttribute( + desc, HIPBLASLT_MATRIX_TRANSFORM_DESC_TRANSB, &opB, sizeof(opB)); + hipblasLtMatrixLayout_t layoutA, layoutB, layoutC; + hipblasLtErr + = hipblasLtMatrixLayoutCreate(&layoutA, datatype, shapeA.first, shapeA.second, ldA); + hipblasLtErr + = hipblasLtMatrixLayoutCreate(&layoutB, datatype, shapeB.first, shapeB.second, ldB); + hipblasLtErr = hipblasLtMatrixLayoutCreate(&layoutC, datatype, m, n, ldC); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutA, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ORDER, + &orderA, + sizeof(orderA)); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutB, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ORDER, + &orderB, + sizeof(orderB)); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutC, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ORDER, + &orderC, + sizeof(orderC)); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutA, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchSize, + sizeof(batchSize)); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutB, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchSize, + sizeof(batchSize)); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutC, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchSize, + sizeof(batchSize)); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutA, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batchStride, + sizeof(batchStride)); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutB, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batchStride, + sizeof(batchStride)); + hipblasLtErr = hipblasLtMatrixLayoutSetAttribute( + layoutC, + hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batchStride, + sizeof(batchStride)); + hipblasLtHandle_t handle{}; + hipblasLtErr = hipblasLtCreate(&handle); + hipblasLtErr = hipblasLtMatrixTransform( + handle, desc, &alpha, dA, layoutA, &beta, dB, layoutB, dC, layoutC, nullptr); + ASSERT_EQ(hipblasLtErr, HIPBLAS_STATUS_SUCCESS); + ASSERT_EQ(hipDeviceSynchronize(), hipSuccess); + auto rowMajA = (orderA == HIPBLASLT_ORDER_ROW); + auto rowMajB = (orderB == HIPBLASLT_ORDER_ROW); + auto rowMajC = (orderC == HIPBLASLT_ORDER_ROW); + auto transA = (opA == HIPBLAS_OP_T); + auto transB = (opB == HIPBLAS_OP_T); + + if(datatype == HIP_R_32F) + { + validation(dC, + dA, + dB, + alpha, + beta, + m, + n, + ldA, + ldB, + ldC, + batchSize, + batchStride, + rowMajA, + rowMajB, + rowMajC, + transA, + transB); + } + else if(datatype == HIP_R_16F) + { + validation(dC, + dA, + dB, + alpha, + beta, + m, + n, + ldA, + ldB, + ldC, + batchSize, + batchStride, + rowMajA, + rowMajB, + rowMajC, + transA, + transB); + } + else if(datatype == HIP_R_16BF) + { + validation(dC, + dA, + dB, + alpha, + beta, + m, + n, + ldA, + ldB, + ldC, + batchSize, + batchStride, + rowMajA, + rowMajB, + rowMajC, + transA, + transB); + } + else if(datatype == HIP_R_8I) + { + validation(dC, + dA, + dB, + alpha, + beta, + m, + n, + ldA, + ldB, + ldC, + batchSize, + batchStride, + rowMajA, + rowMajB, + rowMajC, + transA, + transB); + } + else if(datatype == HIP_R_32I) + { + validation(dC, + dA, + dB, + alpha, + beta, + m, + n, + ldA, + ldB, + ldC, + batchSize, + batchStride, + rowMajA, + rowMajB, + rowMajC, + transA, + transB); + } + + hipblasLtErr = hipblasLtMatrixTransformDescDestroy(desc); + hipblasLtErr = hipblasLtDestroy(handle); + hipblasLtErr = hipblasLtMatrixLayoutDestroy(layoutA); + hipblasLtErr = hipblasLtMatrixLayoutDestroy(layoutB); + hipblasLtErr = hipblasLtMatrixLayoutDestroy(layoutC); + } + + hipErr = hipSetDevice(curDevice); + EXPECT_EQ(hipErr, hipSuccess); +} + INSTANTIATE_TEST_SUITE_P( AllCombinations, MatrixTransformTest, diff --git a/library/src/amd_detail/rocblaslt/src/rocblaslt_transform.cpp b/library/src/amd_detail/rocblaslt/src/rocblaslt_transform.cpp index c0503af656..be6e6f3dd0 100644 --- a/library/src/amd_detail/rocblaslt/src/rocblaslt_transform.cpp +++ b/library/src/amd_detail/rocblaslt/src/rocblaslt_transform.cpp @@ -29,6 +29,7 @@ #include "rocblaslt-types.h" #include "rocblaslt.h" #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include #include +#include namespace { @@ -65,23 +67,35 @@ namespace TensileLite::hip::SolutionAdapter& transformAdapter() { - static auto& adapter = []() -> TensileLite::hip::SolutionAdapter& { - static TensileLite::hip::SolutionAdapter adp; - auto coPath = transformCodeObjectPath(); - const std::string coFolder = dirname(&coPath[0]); + using AdapterPtr = std::unique_ptr; + static auto& adapter = []() -> std::vector& { + static std::vector adapters; + int numDevices{}; + HIP_CHECK_EXC(hipGetDeviceCount(&numDevices)); + for(int i = 0; i < numDevices; ++i) + { + adapters.emplace_back(new TensileLite::hip::SolutionAdapter); + } + auto coPath = transformCodeObjectPath(); + const std::string coFolder = dirname(&coPath[0]); try { - (void)adp.initializeLazyLoading("", coFolder); + for(auto& adp : adapters) + { + (void)adp->initializeLazyLoading("", coFolder); + } } catch(const std::runtime_error& e) { rocblaslt_log_error( "transformCodeObject", "TransformCodeObjectPath", coFolder.c_str()); } - return adp; + return adapters; }(); - return adapter; + int deviceId{}; + HIP_CHECK_EXC(hipGetDevice(&deviceId)); + return *adapter.at(deviceId); } rocblaslt_matrix_layout dummyMatrixLayout() @@ -159,7 +173,7 @@ namespace betaPtr = dummyScalarPtr(); } - const ScaleType *nullScalePtr = nullptr; + const ScaleType* nullScalePtr = nullptr; kArgs.appendAligned("c", c); kArgs.appendAligned("a", a);