Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: do not recompute the intersection when backtracking #171

Merged
merged 3 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/internal/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> State<P, VS, Priority> {
incompat_changed: bool,
decision_level: DecisionLevel,
) {
self.partial_solution
.backtrack(decision_level, &self.incompatibility_store);
self.partial_solution.backtrack(decision_level);
// Remove contradicted incompatibilities that depend on decisions we just backtracked away.
self.contradicted_incompatibilities
.retain(|_, dl| *dl <= decision_level);
Expand Down
65 changes: 30 additions & 35 deletions src/internal/partial_solution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ pub struct DatedDerivation<P: Package, VS: VersionSet> {
global_index: u32,
decision_level: DecisionLevel,
cause: IncompId<P, VS>,
accumulated_intersection: Term<VS>,
}

impl<P: Package, VS: VersionSet> Display for DatedDerivation<P, VS> {
Expand Down Expand Up @@ -207,11 +208,11 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
store: &Arena<Incompatibility<P, VS>>,
) {
use indexmap::map::Entry;
let term = store[cause].get(&package).unwrap().negate();
let dated_derivation = DatedDerivation {
let mut dated_derivation = DatedDerivation {
global_index: self.next_global_index,
decision_level: self.current_decision_level,
cause,
accumulated_intersection: store[cause].get(&package).unwrap().negate(),
};
self.next_global_index += 1;
let pa_last_index = self.package_assignments.len().saturating_sub(1);
Expand All @@ -226,7 +227,8 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
panic!("add_derivation should not be called after a decision")
}
AssignmentsIntersection::Derivations(t) => {
*t = t.intersection(&term);
*t = t.intersection(&dated_derivation.accumulated_intersection);
dated_derivation.accumulated_intersection = t.clone();
if t.is_positive() {
// we can use `swap_indices` to make `changed_this_decision_level` only go down by 1
// but the copying is slower then the larger search
Expand All @@ -238,6 +240,7 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
pa.dated_derivations.push(dated_derivation);
}
Entry::Vacant(v) => {
let term = dated_derivation.accumulated_intersection.clone();
if term.is_positive() {
self.changed_this_decision_level =
std::cmp::min(self.changed_this_decision_level, pa_last_index);
Expand Down Expand Up @@ -297,13 +300,9 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
}

/// Backtrack the partial solution to a given decision level.
pub fn backtrack(
&mut self,
decision_level: DecisionLevel,
store: &Arena<Incompatibility<P, VS>>,
) {
pub fn backtrack(&mut self, decision_level: DecisionLevel) {
self.current_decision_level = decision_level;
self.package_assignments.retain(|p, pa| {
self.package_assignments.retain(|_p, pa| {
if pa.smallest_decision_level > decision_level {
// Remove all entries that have a smallest decision level higher than the backtrack target.
false
Expand All @@ -325,18 +324,14 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
}
debug_assert!(!pa.dated_derivations.is_empty());

let last = pa.dated_derivations.last().unwrap();

// Update highest_decision_level.
pa.highest_decision_level = pa.dated_derivations.last().unwrap().decision_level;

// Recompute the assignments intersection.
pa.assignments_intersection = AssignmentsIntersection::Derivations(
pa.dated_derivations
.iter()
.fold(Term::any(), |acc, dated_derivation| {
let term = store[dated_derivation.cause].get(p).unwrap().negate();
acc.intersection(&term)
}),
);
pa.highest_decision_level = last.decision_level;

// Reset the assignments intersection.
pa.assignments_intersection =
AssignmentsIntersection::Derivations(last.accumulated_intersection.clone());
true
}
});
Expand Down Expand Up @@ -400,7 +395,7 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
incompat: &Incompatibility<P, VS>,
store: &Arena<Incompatibility<P, VS>>,
) -> (P, SatisfierSearch<P, VS>) {
let satisfied_map = Self::find_satisfier(incompat, &self.package_assignments, store);
let satisfied_map = Self::find_satisfier(incompat, &self.package_assignments);
let (satisfier_package, &(satisfier_index, _, satisfier_decision_level)) = satisfied_map
.iter()
.max_by_key(|(_p, (_, global_index, _))| global_index)
Expand Down Expand Up @@ -440,14 +435,13 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
fn find_satisfier(
incompat: &Incompatibility<P, VS>,
package_assignments: &FnvIndexMap<P, PackageAssignments<P, VS>>,
store: &Arena<Incompatibility<P, VS>>,
) -> SmallMap<P, (usize, u32, DecisionLevel)> {
let mut satisfied = SmallMap::Empty;
for (package, incompat_term) in incompat.iter() {
let pa = package_assignments.get(package).expect("Must exist");
satisfied.insert(
package.clone(),
pa.satisfier(package, incompat_term, Term::any(), store),
pa.satisfier(package, incompat_term, &Term::any()),
);
}
satisfied
Expand Down Expand Up @@ -483,7 +477,7 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P

satisfied_map.insert(
satisfier_package.clone(),
satisfier_pa.satisfier(satisfier_package, incompat_term, accum_term, store),
satisfier_pa.satisfier(satisfier_package, incompat_term, &accum_term),
);

// Finally, let's identify the decision level of that previous satisfier.
Expand All @@ -504,16 +498,15 @@ impl<P: Package, VS: VersionSet> PackageAssignments<P, VS> {
&self,
package: &P,
incompat_term: &Term<VS>,
start_term: Term<VS>,
store: &Arena<Incompatibility<P, VS>>,
start_term: &Term<VS>,
) -> (usize, u32, DecisionLevel) {
// Term where we accumulate intersections until incompat_term is satisfied.
let mut accum_term = start_term;
// Indicate if we found a satisfier in the list of derivations, otherwise it will be the decision.
for (idx, dated_derivation) in self.dated_derivations.iter().enumerate() {
let this_term = store[dated_derivation.cause].get(package).unwrap().negate();
accum_term = accum_term.intersection(&this_term);
if accum_term.subset_of(incompat_term) {
if dated_derivation
.accumulated_intersection
.intersection(start_term)
.subset_of(incompat_term)
{
// We found the derivation causing satisfaction.
return (
idx,
Expand All @@ -524,13 +517,13 @@ impl<P: Package, VS: VersionSet> PackageAssignments<P, VS> {
}
// If it wasn't found in the derivations,
// it must be the decision which is last (if called in the right context).
match self.assignments_intersection {
match &self.assignments_intersection {
AssignmentsIntersection::Decision((global_index, _, _)) => (
self.dated_derivations.len(),
global_index,
*global_index,
self.highest_decision_level,
),
AssignmentsIntersection::Derivations(_) => {
AssignmentsIntersection::Derivations(accumulated_intersection) => {
unreachable!(
concat!(
"while processing package {}: ",
Expand All @@ -539,7 +532,9 @@ impl<P: Package, VS: VersionSet> PackageAssignments<P, VS> {
"but instead it was a derivation. This shouldn't be possible! ",
"(Maybe your Version ordering is broken?)"
),
package, accum_term, incompat_term
package,
accumulated_intersection.intersection(start_term),
incompat_term
)
}
}
Expand Down