Skip to content

[CPU] Fix CNNs compilation time #30679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,27 +282,33 @@ static std::tuple<primitive_desc, size_t> selectPrimitiveDescWithMultipleAttribu
};

PrimitiveDescWithPriority prim_desc_w_priority{dnnl::primitive_desc(), 0, implPriorities.size()};
const bool first_match = implPriorities.front() == impl_desc_type::unknown;

// try all the provided attributes and select the one which results in a primitive desc with the highest priority
for (size_t attrId = 0; attrId < attrs.size(); attrId++) {
const auto& attr = attrs[attrId];

for (size_t priorityId = 0; priorityId < implPriorities.size(); priorityId++) {
const auto preferredImplType = implPriorities[priorityId];
// the only way to fully reset primitive_desc after iterating over the implementations is to re-create it
auto cur_desc = createPrimitiveDescriptor(attr);
const bool found = DnnlExtensionUtils::find_implementation(cur_desc, preferredImplType);

const size_t highestPriority = prim_desc_w_priority.priority;
if (found && priorityId < highestPriority) {
prim_desc_w_priority = {cur_desc, attrId, priorityId};
}
}
auto cur_desc = createPrimitiveDescriptor(attr);

DnnlExtensionUtils::for_each_implementation(
cur_desc,
first_match,
[&](impl_desc_type implType) { // is acceptable implementation
return contains(implPriorities, implType);
},
[&](dnnl::primitive_desc& desc) { // is implementation with highest priority
const impl_desc_type descImplType = parse_impl_name(desc.impl_info_str());
const auto it = std::find(implPriorities.begin(), implPriorities.end(), descImplType);
const size_t priorityId = std::distance(implPriorities.begin(), it);
const size_t highestPriority = prim_desc_w_priority.priority;
if (priorityId < highestPriority) {
auto desc_copy = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get(true)));
prim_desc_w_priority = {desc_copy, attrId, priorityId};
}
});
}

auto prim_desc = prim_desc_w_priority.prim_desc;

return {prim_desc, prim_desc_w_priority.attrId};
return {prim_desc_w_priority.prim_desc, prim_desc_w_priority.attrId};
}

static primitive_desc createPrimitiveDesc(const dnnl::memory::desc& inputDesc,
Expand Down
Loading