From 2d5eb3fcf6b5f482ef85f7ee6770a31d5130df4f Mon Sep 17 00:00:00 2001 From: x-hgg-x <39058530+x-hgg-x@users.noreply.github.com> Date: Thu, 28 Nov 2024 00:09:25 +0100 Subject: [PATCH] Add a provider method to register a conflict --- src/internal/core.rs | 8 +++++++- src/provider.rs | 33 +++++++++++++++++++++++++++++++-- src/solver.rs | 10 ++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/internal/core.rs b/src/internal/core.rs index 59594bc0..f81d12db 100644 --- a/src/internal/core.rs +++ b/src/internal/core.rs @@ -97,7 +97,7 @@ impl State { &mut self, package_id: PackageId, package_store: &PackageArena, - dependency_provider: &DP, + dependency_provider: &mut DP, ) -> Result<(), DerivationTree> { self.unit_propagation_buffer.clear(); self.unit_propagation_buffer.push(package_id); @@ -147,6 +147,12 @@ impl State { } } if let Some(incompat_id) = conflict_id { + dependency_provider.register_conflict( + self.incompatibility_store[incompat_id] + .iter() + .map(|(pid, _)| pid), + package_store, + ); let (package_almost, root_cause) = self .conflict_resolution(incompat_id, package_store, dependency_provider) .map_err(|terminal_incompat_id| { diff --git a/src/provider.rs b/src/provider.rs index c713ce0a..89f8dddd 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -69,6 +69,7 @@ impl VersionRanges for Ranges { pub struct OfflineDependencyProvider { #[allow(clippy::type_complexity)] dependencies: Map)>>, + conflicts: Map, } #[cfg(feature = "serde")] @@ -112,6 +113,7 @@ where .into_iter() .map(|(p, versions)| (p, versions.into_iter().collect())) .collect(), + conflicts: Map::default(), }) } } @@ -121,6 +123,7 @@ impl OfflineDependency pub fn new() -> Self { Self { dependencies: Map::default(), + conflicts: Map::default(), } } @@ -255,9 +258,20 @@ impl DependencyProvide &mut self, package_id: PackageId, set: VersionSet, - _: &PackageArena, + package_store: &PackageArena, ) -> Self::Priority { - Reverse(((set.count() as u64) << 32) + package_id.get() as u64) + let version_count = set.count(); + if version_count == 0 { + return Reverse(0); + } + let pkg = match package_store.pkg(package_id).unwrap() { + PackageVersionWrapper::Pkg(p) => p.pkg(), + PackageVersionWrapper::VirtualPkg(p) => p.pkg(), + PackageVersionWrapper::VirtualDep(p) => p.pkg(), + }; + let conflict_count = self.conflicts.get(pkg).copied().unwrap_or_default(); + + Reverse(((u32::MAX as u64).saturating_sub(conflict_count) << 6) + version_count as u64) } fn get_dependencies( @@ -398,4 +412,19 @@ impl DependencyProvide } } } + + fn register_conflict( + &mut self, + package_ids: impl Iterator, + package_store: &PackageArena, + ) { + for package_id in package_ids { + let pkg = match package_store.pkg(package_id).unwrap() { + PackageVersionWrapper::Pkg(p) => p.pkg(), + PackageVersionWrapper::VirtualPkg(p) => p.pkg(), + PackageVersionWrapper::VirtualDep(p) => p.pkg(), + }; + *self.conflicts.entry(pkg.clone()).or_default() += 1; + } + } } diff --git a/src/solver.rs b/src/solver.rs index 8e38c09c..ab9cab65 100644 --- a/src/solver.rs +++ b/src/solver.rs @@ -312,4 +312,14 @@ pub trait DependencyProvider { package: &'a Self::P, version_set: VersionSet, ) -> impl Display + 'a; + + /// Register a conflict for the given packages. + fn register_conflict( + &mut self, + package_ids: impl Iterator, + package_store: &PackageArena, + ) { + let _ = package_ids; + let _ = package_store; + } }