diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 87b170c94c4bb..52e8783c4f28d 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -2477,34 +2477,16 @@ static ur_result_t SetKernelParamsAndLaunch( return Error; } -ur_result_t enqueueImpCommandBufferKernel( - context Ctx, DeviceImplPtr DeviceImpl, - ur_exp_command_buffer_handle_t CommandBuffer, - const CGExecKernel &CommandGroup, - std::vector &SyncPoints, - ur_exp_command_buffer_sync_point_t *OutSyncPoint, - ur_exp_command_buffer_command_handle_t *OutCommand, - const std::function &getMemAllocationFunc) { - auto ContextImpl = sycl::detail::getSyclObjImpl(Ctx); - const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter(); - - const std::vector> - &AlternativeKernels = CommandGroup.MAlternativeKernels; +namespace { /* getCGKernelInfo: resolves the UR kernel handle, the device image impl and the eliminated-argument mask for CommandGroup and returns them as a tuple in that order. Lookup order: (1) a non-interop MKernelBundle, using the kernel_id obtained from MKernelName; (2) a non-null MSyclKernel; (3) ProgramManager::getOrCreateKernel, in which case the kernel and program handles are appended to UrKernelsToRelease and UrProgramsToRelease. The device image impl is only assigned in case (1) and stays nullptr otherwise. */ +std::tuple, + const KernelArgMask *> +getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl, + DeviceImplPtr DeviceImpl, + std::vector &UrKernelsToRelease, + std::vector &UrProgramsToRelease) { - // UR kernel and program for 'CommandGroup' ur_kernel_handle_t UrKernel = nullptr; - ur_program_handle_t UrProgram = nullptr; - - // Impl objects created when 'CommandGroup' is from a kernel bundle - std::shared_ptr SyclKernelImpl = nullptr; std::shared_ptr DeviceImageImpl = nullptr; - - // List of ur objects to be released after UR call - std::vector UrKernelsToRelease; - std::vector UrProgramsToRelease; - - auto Kernel = CommandGroup.MSyclKernel; - auto KernelBundleImplPtr = CommandGroup.MKernelBundle; const KernelArgMask *EliminatedArgMask = nullptr; // Use kernel_bundle if available unless it is interop. @@ -2512,63 +2494,74 @@ ur_result_t enqueueImpCommandBufferKernel( // in interop kernel bundles (if any) do not have kernel_id // and can therefore not be looked up, but since they are self-contained // they can simply be launched directly. 
- if (KernelBundleImplPtr && !KernelBundleImplPtr->isInterop()) { + if (auto KernelBundleImplPtr = CommandGroup.MKernelBundle; + KernelBundleImplPtr && !KernelBundleImplPtr->isInterop()) { auto KernelName = CommandGroup.MKernelName; kernel_id KernelID = detail::ProgramManager::getInstance().getSYCLKernelID(KernelName); + kernel SyclKernel = KernelBundleImplPtr->get_kernel(KernelID, KernelBundleImplPtr); - SyclKernelImpl = detail::getSyclObjImpl(SyclKernel); + + auto SyclKernelImpl = detail::getSyclObjImpl(SyclKernel); UrKernel = SyclKernelImpl->getHandleRef(); DeviceImageImpl = SyclKernelImpl->getDeviceImage(); - UrProgram = DeviceImageImpl->get_ur_program_ref(); EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - } else if (Kernel != nullptr) { + } else if (auto Kernel = CommandGroup.MSyclKernel; Kernel != nullptr) { UrKernel = Kernel->getHandleRef(); - UrProgram = Kernel->getProgramRef(); EliminatedArgMask = Kernel->getKernelArgMask(); } else { + ur_program_handle_t UrProgram = nullptr; std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) = sycl::detail::ProgramManager::getInstance().getOrCreateKernel( ContextImpl, DeviceImpl, CommandGroup.MKernelName); UrKernelsToRelease.push_back(UrKernel); UrProgramsToRelease.push_back(UrProgram); /* Handles obtained from getOrCreateKernel are only recorded in the caller-provided lists here; nothing in this helper releases them. */ } + return std::make_tuple(UrKernel, DeviceImageImpl, EliminatedArgMask); +} +} // anonymous namespace + +ur_result_t enqueueImpCommandBufferKernel( + context Ctx, DeviceImplPtr DeviceImpl, + ur_exp_command_buffer_handle_t CommandBuffer, + const CGExecKernel &CommandGroup, + std::vector &SyncPoints, + ur_exp_command_buffer_sync_point_t *OutSyncPoint, + ur_exp_command_buffer_command_handle_t *OutCommand, + const std::function &getMemAllocationFunc) { + // List of ur objects to be released after UR call. We don't do anything + // with the ur_program_handle_t objects, but need to update their reference + // count. 
+ std::vector UrKernelsToRelease; + std::vector UrProgramsToRelease; + + ur_kernel_handle_t UrKernel = nullptr; + std::shared_ptr DeviceImageImpl = nullptr; + const KernelArgMask *EliminatedArgMask = nullptr; + + auto ContextImpl = sycl::detail::getSyclObjImpl(Ctx); + std::tie(UrKernel, DeviceImageImpl, EliminatedArgMask) = + getCGKernelInfo(CommandGroup, ContextImpl, DeviceImpl, UrKernelsToRelease, + UrProgramsToRelease); // Build up the list of UR kernel handles that the UR command could be // updated to use. std::vector AltUrKernels; + const std::vector> + &AlternativeKernels = CommandGroup.MAlternativeKernels; for (const auto &AltCGKernelWP : AlternativeKernels) { auto AltCGKernel = AltCGKernelWP.lock(); assert(AltCGKernel != nullptr); ur_kernel_handle_t AltUrKernel = nullptr; - if (auto KernelBundleImplPtr = AltCGKernel->MKernelBundle; - KernelBundleImplPtr && !KernelBundleImplPtr->isInterop()) { - auto KernelName = AltCGKernel->MKernelName; - kernel_id KernelID = - detail::ProgramManager::getInstance().getSYCLKernelID(KernelName); - kernel SyclKernel = - KernelBundleImplPtr->get_kernel(KernelID, KernelBundleImplPtr); - AltUrKernel = detail::getSyclObjImpl(SyclKernel)->getHandleRef(); - } else if (AltCGKernel->MSyclKernel != nullptr) { - AltUrKernel = Kernel->getHandleRef(); - } else { - ur_program_handle_t UrProgram = nullptr; - std::tie(AltUrKernel, std::ignore, std::ignore, UrProgram) = - sycl::detail::ProgramManager::getInstance().getOrCreateKernel( - ContextImpl, DeviceImpl, AltCGKernel->MKernelName); - UrKernelsToRelease.push_back(AltUrKernel); - UrProgramsToRelease.push_back(UrProgram); - } - - if (AltUrKernel != UrKernel) { - // Don't include command-group 'CommandGroup' in the list to pass to UR, - // as this will be used for the primary ur kernel parameter. 
- AltUrKernels.push_back(AltUrKernel); - } + std::tie(AltUrKernel, std::ignore, std::ignore) = + getCGKernelInfo(*AltCGKernel.get(), ContextImpl, DeviceImpl, + UrKernelsToRelease, UrProgramsToRelease); + AltUrKernels.push_back(AltUrKernel); } + const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter(); auto SetFunc = [&Adapter, &UrKernel, &DeviceImageImpl, &Ctx, &getMemAllocationFunc](sycl::detail::ArgDesc &Arg, size_t NextTrueIndex) {