diff --git a/sycl/source/detail/device_image_impl.hpp b/sycl/source/detail/device_image_impl.hpp index 0da6ea9bf9216..5922477c766e5 100644 --- a/sycl/source/detail/device_image_impl.hpp +++ b/sycl/source/detail/device_image_impl.hpp @@ -1341,6 +1341,18 @@ class device_image_impl std::unique_ptr MMergedImageStorage = nullptr; }; +using device_images_iterator = + variadic_iterator::const_iterator, + std::set::const_iterator>; +class device_images_range : public iterator_range { +private: + using Base = iterator_range; + +public: + using Base::Base; +}; + } // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/source/detail/helpers.hpp b/sycl/source/detail/helpers.hpp index aa86458d582af..a5612cd2211e5 100644 --- a/sycl/source/detail/helpers.hpp +++ b/sycl/source/detail/helpers.hpp @@ -91,6 +91,8 @@ template class variadic_iterator { }, It); } + + pointer operator->() { return &this->operator*(); } }; // Non-owning! diff --git a/sycl/source/detail/kernel_bundle_impl.hpp b/sycl/source/detail/kernel_bundle_impl.hpp index 4000bafaf96b2..2f6f5235602b8 100644 --- a/sycl/source/detail/kernel_bundle_impl.hpp +++ b/sycl/source/detail/kernel_bundle_impl.hpp @@ -249,8 +249,7 @@ class kernel_bundle_impl // images with specialization constants in separation. // TODO: Remove when spec const overwriting issue has been fixed in L0. std::vector ImagesWithSpecConsts; - std::unordered_set> - ImagesWithSpecConstsSet; + std::unordered_set ImagesWithSpecConstsSet; for (const kernel_bundle &ObjectBundle : ObjectBundles) { for (const DevImgPlainWithDeps &DeviceImageWithDeps : @@ -265,36 +264,32 @@ class kernel_bundle_impl ImagesWithSpecConsts.push_back(&DeviceImageWithDeps); for (const device_image_plain &DevImg : DeviceImageWithDeps) - ImagesWithSpecConstsSet.insert(getSyclObjImpl(DevImg)); + ImagesWithSpecConstsSet.insert(&*getSyclObjImpl(DevImg)); } } // Collect all unique images. std::vector DevImages; { - std::set> DevImagesSet; + std::set DevImagesSet; std::unordered_set SeenBinImgs; for (const kernel_bundle &ObjectBundle : ObjectBundles) { - for (const device_image_plain &DevImg : - getSyclObjImpl(ObjectBundle)->MUniqueDeviceImages) { - auto &DevImgImpl = getSyclObjImpl(DevImg); - const RTDeviceBinaryImage *BinImg = DevImgImpl->get_bin_image_ref(); + for (device_image_impl &DevImg : + getSyclObjImpl(ObjectBundle)->device_images()) { + const RTDeviceBinaryImage *BinImg = DevImg.get_bin_image_ref(); // We have duplicate images if either the underlying binary image has // been seen before or the device image implementation is in the // image set already. - if ((BinImg && SeenBinImgs.find(BinImg) != SeenBinImgs.end()) || - ImagesWithSpecConstsSet.find(DevImgImpl) != - ImagesWithSpecConstsSet.end()) + if ((BinImg && SeenBinImgs.count(BinImg)) || + ImagesWithSpecConstsSet.count(&DevImg)) continue; SeenBinImgs.insert(BinImg); - DevImagesSet.insert(DevImgImpl); + DevImagesSet.insert(&DevImg); } } - DevImages.reserve(DevImagesSet.size()); - for (auto It = DevImagesSet.begin(); It != DevImagesSet.end();) - DevImages.push_back(createSyclObjFromImpl( - std::move(DevImagesSet.extract(It++).value()))); + DevImages = device_images_range{DevImagesSet} + .to>(); } // Check for conflicting kernels in RTC kernel bundles. @@ -504,8 +499,7 @@ class kernel_bundle_impl if (get_bundle_state() == bundle_state::input) { // Copy spec constants values from the device images. - auto MergeSpecConstants = [this](const device_image_plain &Img) { - detail::device_image_impl &ImgImpl = *getSyclObjImpl(Img); + for (detail::device_image_impl &ImgImpl : device_images()) { const std::map> &SpecConsts = ImgImpl.get_spec_const_data_ref(); @@ -521,8 +515,7 @@ class kernel_bundle_impl SpecConst.second.back().CompositeOffset + SpecConst.second.back().Size); } - }; - std::for_each(begin(), end(), MergeSpecConstants); + } } for (const detail::KernelBundleImplPtr &Bundle : Bundles) { @@ -629,11 +622,11 @@ class kernel_bundle_impl std::vector NewDevImgs; std::vector> NewBinReso; - for (device_image_plain &DevImg : MUniqueDeviceImages) { + for (device_image_impl &DevImg : device_images()) { std::vector> NewDevImgImpls = - getSyclObjImpl(DevImg)->buildFromSource( - Devices, BuildOptions, LogPtr, RegisteredKernelNames, NewBinReso); - NewDevImgs.reserve(NewDevImgImpls.size()); + DevImg.buildFromSource(Devices, BuildOptions, LogPtr, + RegisteredKernelNames, NewBinReso); + NewDevImgs.reserve(NewDevImgs.size() + NewDevImgImpls.size()); for (std::shared_ptr &DevImgImpl : NewDevImgImpls) NewDevImgs.emplace_back(std::move(DevImgImpl)); } @@ -652,12 +645,11 @@ class kernel_bundle_impl std::vector NewDevImgs; std::vector> NewBinReso; - for (device_image_plain &DevImg : MUniqueDeviceImages) { + for (device_image_impl &DevImg : device_images()) { std::vector> NewDevImgImpls = - getSyclObjImpl(DevImg)->compileFromSource( - Devices, CompileOptions, LogPtr, RegisteredKernelNames, - NewBinReso); - NewDevImgs.reserve(NewDevImgImpls.size()); + DevImg.compileFromSource(Devices, CompileOptions, LogPtr, + RegisteredKernelNames, NewBinReso); + NewDevImgs.reserve(NewDevImgs.size() + NewDevImgImpls.size()); for (std::shared_ptr &DevImgImpl : NewDevImgImpls) NewDevImgs.emplace_back(std::move(DevImgImpl)); } @@ -667,10 +659,9 @@ class kernel_bundle_impl public: bool ext_oneapi_has_kernel(const std::string &Name) const { - return std::any_of(begin(), end(), - [&Name](const device_image_plain &DevImg) { - return getSyclObjImpl(DevImg)->hasKernelName(Name); - }); + return any_of(device_images(), [&Name](device_image_impl &DevImg) { + return DevImg.hasKernelName(Name); + }); } kernel ext_oneapi_get_kernel(const std::string &Name) const { @@ -686,10 +677,9 @@ class kernel_bundle_impl // kernels. In this case, all these bundles should be found and the // resulting kernel object should be able to map devices to their // respective backend kernel objects. - for (const device_image_plain &DevImg : MUniqueDeviceImages) { - device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg); + for (device_image_impl &DevImg : device_images()) { if (std::shared_ptr PotentialKernelImpl = - DevImgImpl.tryGetExtensionKernel(Name, MContext, *this)) + DevImg.tryGetExtensionKernel(Name, MContext, *this)) return detail::createSyclObjFromImpl( std::move(PotentialKernelImpl)); } @@ -705,25 +695,24 @@ class kernel_bundle_impl "files and kernel_bundles successfully built from " "kernel_bundle."); - auto It = - std::find_if(begin(), end(), [&Name](const device_image_plain &DevImg) { - return getSyclObjImpl(DevImg)->hasKernelName(Name); - }); - if (It == end()) + auto It = std::find_if(device_images().begin(), device_images().end(), + [&Name](device_image_impl &DevImg) { + return DevImg.hasKernelName(Name); + }); + if (It == device_images().end()) throw sycl::exception(make_error_code(errc::invalid), "kernel '" + Name + "' not found in kernel_bundle"); - return getSyclObjImpl(*It)->adjustKernelName(Name); + return It->adjustKernelName(Name); } bool ext_oneapi_has_device_global(const std::string &Name) const { std::string MangledName = mangleDeviceGlobalName(Name); return (MDeviceGlobals.size() && MDeviceGlobals.tryGetEntryLockless(MangledName)) || - std::any_of(begin(), end(), - [&MangledName](const device_image_plain &DeviceImage) { - return getSyclObjImpl(DeviceImage) - ->hasDeviceGlobalName(MangledName); + std::any_of(device_images().begin(), device_images().end(), + [&MangledName](device_image_impl &DeviceImage) { + return DeviceImage.hasDeviceGlobalName(MangledName); }); } @@ -774,16 +763,13 @@ class kernel_bundle_impl std::vector get_kernel_ids() const { // Collect kernel ids from all device images, then remove duplicates std::vector Result; - for (const device_image_plain &DeviceImage : MUniqueDeviceImages) { - detail::device_image_impl &DevImgImpl = *getSyclObjImpl(DeviceImage); - + for (device_image_impl &DevImg : device_images()) { // RTC kernel bundles shouldn't have user-facing kernel ids, return an // empty vector when the bundle contains RTC kernels. - if (DevImgImpl.getRTCInfo()) + if (DevImg.getRTCInfo()) continue; - auto KernelIDs = DevImgImpl.get_kernel_ids(); - + auto KernelIDs = DevImg.get_kernel_ids(); Result.insert(Result.end(), KernelIDs.begin(), KernelIDs.end()); } std::sort(Result.begin(), Result.end(), LessByNameComp{}); @@ -803,51 +789,43 @@ class kernel_bundle_impl } bool has_kernel(const kernel_id &KernelID) const noexcept { - return std::any_of(begin(), end(), - [&KernelID](const device_image_plain &DeviceImage) { - return DeviceImage.has_kernel(KernelID); - }); + return any_of(device_images(), [&KernelID](device_image_impl &DeviceImage) { + return DeviceImage.has_kernel(KernelID); + }); } bool has_kernel(const kernel_id &KernelID, const device &Dev) const noexcept { - return std::any_of( - begin(), end(), - [&KernelID, &Dev](const device_image_plain &DeviceImage) { - return DeviceImage.has_kernel(KernelID, Dev); - }); + return any_of(device_images(), + [&KernelID, &Dev](device_image_impl &DeviceImage) { + return DeviceImage.has_kernel(KernelID, Dev); + }); } bool contains_specialization_constants() const noexcept { - return std::any_of( - begin(), end(), [](const device_image_plain &DeviceImage) { - return getSyclObjImpl(DeviceImage)->has_specialization_constants(); - }); + return any_of(device_images(), [](device_image_impl &DeviceImage) { + return DeviceImage.has_specialization_constants(); + }); } bool native_specialization_constant() const noexcept { return contains_specialization_constants() && - std::all_of(begin(), end(), - [](const device_image_plain &DeviceImage) { - return getSyclObjImpl(DeviceImage) - ->all_specialization_constant_native(); - }); + all_of(device_images(), [](device_image_impl &DeviceImage) { + return DeviceImage.all_specialization_constant_native(); + }); } bool has_specialization_constant(const char *SpecName) const noexcept { - return std::any_of(begin(), end(), - [SpecName](const device_image_plain &DeviceImage) { - return getSyclObjImpl(DeviceImage) - ->has_specialization_constant(SpecName); - }); + return any_of(device_images(), [SpecName](device_image_impl &DeviceImage) { + return DeviceImage.has_specialization_constant(SpecName); + }); } void set_specialization_constant_raw_value(const char *SpecName, const void *Value, size_t Size) noexcept { if (has_specialization_constant(SpecName)) - for (const device_image_plain &DeviceImage : MUniqueDeviceImages) - getSyclObjImpl(DeviceImage) - ->set_specialization_constant_raw_value(SpecName, Value); + for (device_image_impl &DeviceImage : device_images()) + DeviceImage.set_specialization_constant_raw_value(SpecName, Value); else { std::vector &Val = MSpecConstValues[std::string{SpecName}]; Val.resize(Size); @@ -857,10 +835,9 @@ class kernel_bundle_impl void get_specialization_constant_raw_value(const char *SpecName, void *ValueRet) const noexcept { - for (const device_image_plain &DeviceImage : MUniqueDeviceImages) - if (getSyclObjImpl(DeviceImage)->has_specialization_constant(SpecName)) { - getSyclObjImpl(DeviceImage) - ->get_specialization_constant_raw_value(SpecName, ValueRet); + for (device_image_impl &DeviceImage : device_images()) + if (DeviceImage.has_specialization_constant(SpecName)) { + DeviceImage.get_specialization_constant_raw_value(SpecName, ValueRet); return; } @@ -879,19 +856,21 @@ class kernel_bundle_impl } bool is_specialization_constant_set(const char *SpecName) const noexcept { - bool SetInDevImg = std::any_of( - begin(), end(), [SpecName](const device_image_plain &DeviceImage) { - return getSyclObjImpl(DeviceImage) - ->is_specialization_constant_set(SpecName); + bool SetInDevImg = + any_of(device_images(), [SpecName](device_image_impl &DeviceImage) { + return DeviceImage.is_specialization_constant_set(SpecName); }); return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0; } + // Don't use these two for code under `source/detail`, they are only needed to + // communicate across DSO boundary. const device_image_plain *begin() const { return MUniqueDeviceImages.data(); } - const device_image_plain *end() const { return MUniqueDeviceImages.data() + MUniqueDeviceImages.size(); } + // ...use that instead. + device_images_range device_images() const { return MUniqueDeviceImages; } size_t size() const noexcept { return MUniqueDeviceImages.size(); } @@ -931,28 +910,26 @@ class kernel_bundle_impl } bool hasSourceBasedImages() const noexcept { - return std::any_of(begin(), end(), [](const device_image_plain &DevImg) { - return getSyclObjImpl(DevImg)->getOriginMask() & - ImageOriginKernelCompiler; + return any_of(device_images(), [](device_image_impl &DevImg) { + return DevImg.getOriginMask() & ImageOriginKernelCompiler; }); } bool hasSYCLBINImages() const noexcept { - return std::any_of(begin(), end(), [](const device_image_plain &DevImg) { - return getSyclObjImpl(DevImg)->getOriginMask() & ImageOriginSYCLBIN; + return any_of(device_images(), [](device_image_impl &DevImg) { + return DevImg.getOriginMask() & ImageOriginSYCLBIN; }); } bool hasSYCLOfflineImages() const noexcept { - return std::any_of(begin(), end(), [](const device_image_plain &DevImg) { - return getSyclObjImpl(DevImg)->getOriginMask() & ImageOriginSYCLOffline; + return any_of(device_images(), [](device_image_impl &DevImg) { + return DevImg.getOriginMask() & ImageOriginSYCLOffline; }); } bool allSourceBasedImages() const noexcept { - return std::all_of(begin(), end(), [](const device_image_plain &DevImg) { - return getSyclObjImpl(DevImg)->getOriginMask() & - ImageOriginKernelCompiler; + return all_of(device_images(), [](device_image_impl &DevImg) { + return DevImg.getOriginMask() & ImageOriginKernelCompiler; }); } @@ -1026,10 +1003,9 @@ class kernel_bundle_impl // TODO: For source-based kernels, it may be faster to keep a map between // {kernel_name, device} and their corresponding image. // First look through the kernels registered in source-based images. - for (const device_image_plain &DevImg : MUniqueDeviceImages) { - device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg); + for (device_image_impl &DevImg : device_images()) { if (std::shared_ptr SourceBasedKernel = - DevImgImpl.tryGetExtensionKernel(Name, MContext, *this)) + DevImg.tryGetExtensionKernel(Name, MContext, *this)) return SourceBasedKernel; } @@ -1065,9 +1041,9 @@ class kernel_bundle_impl MDeviceGlobals.tryGetEntryLockless(MangledName)) return Entry; - for (const device_image_plain &DevImg : MUniqueDeviceImages) + for (device_image_impl &DevImg : device_images()) if (DeviceGlobalMapEntry *Entry = - getSyclObjImpl(DevImg)->tryGetDeviceGlobalEntry(MangledName)) + DevImg.tryGetDeviceGlobalEntry(MangledName)) return Entry; throw sycl::exception(make_error_code(errc::invalid), @@ -1083,11 +1059,9 @@ class kernel_bundle_impl void populateDeviceGlobalsForSYCLBIN() { // This should only be called from ctors, so lockless initialization is // safe. - for (const device_image_plain &DevImg : MUniqueDeviceImages) { - const auto &DevImgImpl = getSyclObjImpl(DevImg); - if (DevImgImpl->getOriginMask() & ImageOriginSYCLBIN) - if (const RTDeviceBinaryImage *DevBinImg = - DevImgImpl->get_bin_image_ref()) + for (device_image_impl &DevImg : device_images()) { + if (DevImg.getOriginMask() & ImageOriginSYCLBIN) + if (const RTDeviceBinaryImage *DevBinImg = DevImg.get_bin_image_ref()) MDeviceGlobals.initializeEntriesLockless(DevBinImg); } }