Skip to content

Commit

Permalink
perf: do not recompute the intersection when backtracking (#171)
Browse files Browse the repository at this point in the history
* perf: do not recompute the accumulated_intersection when backtracking

* more correct comment message

* make things dry again
  • Loading branch information
Eh2406 authored Dec 21, 2023
1 parent 526775f commit 8e11a81
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 37 deletions.
3 changes: 1 addition & 2 deletions src/internal/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> State<P, VS, Priority> {
incompat_changed: bool,
decision_level: DecisionLevel,
) {
self.partial_solution
.backtrack(decision_level, &self.incompatibility_store);
self.partial_solution.backtrack(decision_level);
// Remove contradicted incompatibilities that depend on decisions we just backtracked away.
self.contradicted_incompatibilities
.retain(|_, dl| *dl <= decision_level);
Expand Down
65 changes: 30 additions & 35 deletions src/internal/partial_solution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ pub struct DatedDerivation<P: Package, VS: VersionSet> {
global_index: u32,
decision_level: DecisionLevel,
cause: IncompId<P, VS>,
accumulated_intersection: Term<VS>,
}

impl<P: Package, VS: VersionSet> Display for DatedDerivation<P, VS> {
Expand Down Expand Up @@ -207,11 +208,11 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
store: &Arena<Incompatibility<P, VS>>,
) {
use indexmap::map::Entry;
let term = store[cause].get(&package).unwrap().negate();
let dated_derivation = DatedDerivation {
let mut dated_derivation = DatedDerivation {
global_index: self.next_global_index,
decision_level: self.current_decision_level,
cause,
accumulated_intersection: store[cause].get(&package).unwrap().negate(),
};
self.next_global_index += 1;
let pa_last_index = self.package_assignments.len().saturating_sub(1);
Expand All @@ -226,7 +227,8 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
panic!("add_derivation should not be called after a decision")
}
AssignmentsIntersection::Derivations(t) => {
*t = t.intersection(&term);
*t = t.intersection(&dated_derivation.accumulated_intersection);
dated_derivation.accumulated_intersection = t.clone();
if t.is_positive() {
// we can use `swap_indices` to make `changed_this_decision_level` only go down by 1
// but the copying is slower then the larger search
Expand All @@ -238,6 +240,7 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
pa.dated_derivations.push(dated_derivation);
}
Entry::Vacant(v) => {
let term = dated_derivation.accumulated_intersection.clone();
if term.is_positive() {
self.changed_this_decision_level =
std::cmp::min(self.changed_this_decision_level, pa_last_index);
Expand Down Expand Up @@ -297,13 +300,9 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
}

/// Backtrack the partial solution to a given decision level.
pub fn backtrack(
&mut self,
decision_level: DecisionLevel,
store: &Arena<Incompatibility<P, VS>>,
) {
pub fn backtrack(&mut self, decision_level: DecisionLevel) {
self.current_decision_level = decision_level;
self.package_assignments.retain(|p, pa| {
self.package_assignments.retain(|_p, pa| {
if pa.smallest_decision_level > decision_level {
// Remove all entries that have a smallest decision level higher than the backtrack target.
false
Expand All @@ -325,18 +324,14 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
}
debug_assert!(!pa.dated_derivations.is_empty());

let last = pa.dated_derivations.last().unwrap();

// Update highest_decision_level.
pa.highest_decision_level = pa.dated_derivations.last().unwrap().decision_level;

// Recompute the assignments intersection.
pa.assignments_intersection = AssignmentsIntersection::Derivations(
pa.dated_derivations
.iter()
.fold(Term::any(), |acc, dated_derivation| {
let term = store[dated_derivation.cause].get(p).unwrap().negate();
acc.intersection(&term)
}),
);
pa.highest_decision_level = last.decision_level;

// Reset the assignments intersection.
pa.assignments_intersection =
AssignmentsIntersection::Derivations(last.accumulated_intersection.clone());
true
}
});
Expand Down Expand Up @@ -400,7 +395,7 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
incompat: &Incompatibility<P, VS>,
store: &Arena<Incompatibility<P, VS>>,
) -> (P, SatisfierSearch<P, VS>) {
let satisfied_map = Self::find_satisfier(incompat, &self.package_assignments, store);
let satisfied_map = Self::find_satisfier(incompat, &self.package_assignments);
let (satisfier_package, &(satisfier_index, _, satisfier_decision_level)) = satisfied_map
.iter()
.max_by_key(|(_p, (_, global_index, _))| global_index)
Expand Down Expand Up @@ -440,14 +435,13 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P
fn find_satisfier(
incompat: &Incompatibility<P, VS>,
package_assignments: &FnvIndexMap<P, PackageAssignments<P, VS>>,
store: &Arena<Incompatibility<P, VS>>,
) -> SmallMap<P, (usize, u32, DecisionLevel)> {
let mut satisfied = SmallMap::Empty;
for (package, incompat_term) in incompat.iter() {
let pa = package_assignments.get(package).expect("Must exist");
satisfied.insert(
package.clone(),
pa.satisfier(package, incompat_term, Term::any(), store),
pa.satisfier(package, incompat_term, &Term::any()),
);
}
satisfied
Expand Down Expand Up @@ -483,7 +477,7 @@ impl<P: Package, VS: VersionSet, Priority: Ord + Clone> PartialSolution<P, VS, P

satisfied_map.insert(
satisfier_package.clone(),
satisfier_pa.satisfier(satisfier_package, incompat_term, accum_term, store),
satisfier_pa.satisfier(satisfier_package, incompat_term, &accum_term),
);

// Finally, let's identify the decision level of that previous satisfier.
Expand All @@ -504,16 +498,15 @@ impl<P: Package, VS: VersionSet> PackageAssignments<P, VS> {
&self,
package: &P,
incompat_term: &Term<VS>,
start_term: Term<VS>,
store: &Arena<Incompatibility<P, VS>>,
start_term: &Term<VS>,
) -> (usize, u32, DecisionLevel) {
// Term where we accumulate intersections until incompat_term is satisfied.
let mut accum_term = start_term;
// Indicate if we found a satisfier in the list of derivations, otherwise it will be the decision.
for (idx, dated_derivation) in self.dated_derivations.iter().enumerate() {
let this_term = store[dated_derivation.cause].get(package).unwrap().negate();
accum_term = accum_term.intersection(&this_term);
if accum_term.subset_of(incompat_term) {
if dated_derivation
.accumulated_intersection
.intersection(start_term)
.subset_of(incompat_term)
{
// We found the derivation causing satisfaction.
return (
idx,
Expand All @@ -524,13 +517,13 @@ impl<P: Package, VS: VersionSet> PackageAssignments<P, VS> {
}
// If it wasn't found in the derivations,
// it must be the decision which is last (if called in the right context).
match self.assignments_intersection {
match &self.assignments_intersection {
AssignmentsIntersection::Decision((global_index, _, _)) => (
self.dated_derivations.len(),
global_index,
*global_index,
self.highest_decision_level,
),
AssignmentsIntersection::Derivations(_) => {
AssignmentsIntersection::Derivations(accumulated_intersection) => {
unreachable!(
concat!(
"while processing package {}: ",
Expand All @@ -539,7 +532,9 @@ impl<P: Package, VS: VersionSet> PackageAssignments<P, VS> {
"but instead it was a derivation. This shouldn't be possible! ",
"(Maybe your Version ordering is broken?)"
),
package, accum_term, incompat_term
package,
accumulated_intersection.intersection(start_term),
incompat_term
)
}
}
Expand Down

0 comments on commit 8e11a81

Please sign in to comment.