diff --git a/cpp/include/resolvo.h b/cpp/include/resolvo.h index 97d00f5..5343ac7 100644 --- a/cpp/include/resolvo.h +++ b/cpp/include/resolvo.h @@ -4,6 +4,7 @@ #include "resolvo_internal.h" namespace resolvo { +using cbindgen_private::ConditionalRequirement; using cbindgen_private::Problem; using cbindgen_private::Requirement; @@ -24,6 +25,23 @@ inline Requirement requirement_union(VersionSetUnionId id) { return cbindgen_private::resolvo_requirement_union(id); } +/** + * Specifies a conditional requirement (dependency) of a single version set. + * A solvable belonging to the version set satisfies the requirement if the condition is true. + */ +inline ConditionalRequirement conditional_requirement_single(VersionSetId id) { + return cbindgen_private::resolvo_conditional_requirement_single(id); +} + +/** + * Specifies a conditional requirement (dependency) of the union (logical OR) of multiple version + * sets. A solvable belonging to any of the version sets contained in the union satisfies the + * requirement if the condition is true. + */ +inline ConditionalRequirement conditional_requirement_union(VersionSetUnionId id) { + return cbindgen_private::resolvo_conditional_requirement_union(id); +} + /** * Called to solve a package problem. * diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index 781e365..a35b576 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -31,6 +31,95 @@ impl From for resolvo::SolvableId { } } +/// A wrapper around an optional version set id. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct FfiOptionVersionSetId { + pub is_some: bool, + pub value: VersionSetId, +} + +impl From> for FfiOptionVersionSetId { + fn from(opt: Option) -> Self { + match opt { + Some(v) => Self { + is_some: true, + value: v.into(), + }, + None => Self { + is_some: false, + value: VersionSetId { id: 0 }, + }, + } + } +} + +impl From for Option { + fn from(ffi: FfiOptionVersionSetId) -> Self { + if ffi.is_some { + Some(ffi.value.into()) + } else { + None + } + } +} + +impl From> for FfiOptionVersionSetId { + fn from(opt: Option) -> Self { + match opt { + Some(v) => Self { + is_some: true, + value: v, + }, + None => Self { + is_some: false, + value: VersionSetId { id: 0 }, + }, + } + } +} + +impl From for Option { + fn from(ffi: FfiOptionVersionSetId) -> Self { + if ffi.is_some { + Some(ffi.value) + } else { + None + } + } +} + +/// Specifies a conditional requirement, where the requirement is only active when the condition is met. +/// First VersionSetId is the condition, second is the requirement. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct ConditionalRequirement { + pub condition: FfiOptionVersionSetId, + pub requirement: Requirement, +} + +impl From for ConditionalRequirement { + fn from(value: resolvo::ConditionalRequirement) -> Self { + Self { + condition: value.condition.into(), + requirement: value.requirement.into(), + } + } +} + +impl From for resolvo::ConditionalRequirement { + fn from(value: ConditionalRequirement) -> Self { + Self { + condition: value.condition.into(), + requirement: value.requirement.into(), + } + } +} + /// Specifies the dependency of a solvable on a set of version sets. /// cbindgen:derive-eq /// cbindgen:derive-neq @@ -162,7 +251,7 @@ pub struct Dependencies { /// A pointer to the first element of a list of requirements. Requirements /// defines which packages should be installed alongside the depending /// package and the constraints applied to the package. - pub requirements: Vector, + pub requirements: Vector, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set @@ -475,7 +564,7 @@ impl<'d> resolvo::DependencyProvider for &'d DependencyProvider { #[repr(C)] pub struct Problem<'a> { - pub requirements: Slice<'a, Requirement>, + pub requirements: Slice<'a, ConditionalRequirement>, pub constraints: Slice<'a, VersionSetId>, pub soft_requirements: Slice<'a, SolvableId>, } @@ -525,6 +614,28 @@ pub extern "C" fn resolvo_solve( } } +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_conditional_requirement_single( + version_set_id: VersionSetId, +) -> ConditionalRequirement { + ConditionalRequirement { + condition: Option::::None.into(), + requirement: Requirement::Single(version_set_id), + } +} + +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_conditional_requirement_union( + version_set_union_id: VersionSetUnionId, +) -> ConditionalRequirement { + ConditionalRequirement { + condition: Option::::None.into(), + requirement: Requirement::Union(version_set_union_id), + } +} + #[no_mangle] #[allow(unused)] pub extern "C" fn resolvo_requirement_single(version_set_id: VersionSetId) -> Requirement { diff --git a/cpp/tests/solve.cpp b/cpp/tests/solve.cpp index 1bb02b7..952e86e 100644 --- a/cpp/tests/solve.cpp +++ b/cpp/tests/solve.cpp @@ -48,16 +48,17 @@ struct PackageDatabase : public resolvo::DependencyProvider { /** * Allocates a new requirement for a single version set. */ - resolvo::Requirement alloc_requirement(std::string_view package, uint32_t version_start, - uint32_t version_end) { + resolvo::ConditionalRequirement alloc_requirement(std::string_view package, + uint32_t version_start, + uint32_t version_end) { auto id = alloc_version_set(package, version_start, version_end); - return resolvo::requirement_single(id); + return resolvo::conditional_requirement_single(id); } /** * Allocates a new requirement for a version set union. */ - resolvo::Requirement alloc_requirement_union( + resolvo::ConditionalRequirement alloc_requirement_union( std::initializer_list> version_sets) { std::vector version_set_union{version_sets.size()}; @@ -69,7 +70,7 @@ struct PackageDatabase : public resolvo::DependencyProvider { auto id = resolvo::VersionSetUnionId{static_cast(version_set_unions.size())}; version_set_unions.push_back(std::move(version_set_union)); - return resolvo::requirement_union(id); + return resolvo::conditional_requirement_union(id); } /** @@ -219,7 +220,8 @@ SCENARIO("Solve") { const auto d_1 = db.alloc_candidate("d", 1, {}); // Construct a problem to be solved by the solver - resolvo::Vector requirements = {db.alloc_requirement("a", 1, 3)}; + resolvo::Vector requirements = { + db.alloc_requirement("a", 1, 3)}; resolvo::Vector constraints = { db.alloc_version_set("b", 1, 3), db.alloc_version_set("c", 1, 3), @@ -263,7 +265,7 @@ SCENARIO("Solve Union") { "f", 1, {{db.alloc_requirement("b", 1, 10)}, {db.alloc_version_set("a", 10, 20)}}); // Construct a problem to be solved by the solver - resolvo::Vector requirements = { + resolvo::Vector requirements = { db.alloc_requirement_union({{"c", 1, 10}, {"d", 1, 10}}), db.alloc_requirement("e", 1, 10), db.alloc_requirement("f", 1, 10), diff --git a/src/conflict.rs b/src/conflict.rs index 3d121b6..3d08547 100644 --- a/src/conflict.rs +++ b/src/conflict.rs @@ -11,14 +11,13 @@ use petgraph::{ Direction, }; -use crate::solver::variable_map::VariableOrigin; use crate::{ internal::{ arena::ArenaId, id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VersionSetId}, }, runtime::AsyncRuntime, - solver::{clause::Clause, Solver}, + solver::{clause::Clause, variable_map::VariableOrigin, Solver}, DependencyProvider, Interner, Requirement, }; @@ -160,6 +159,60 @@ impl Conflict { ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)), ); } + &Clause::Conditional( + package_id, + condition_variable, + condition_version_set_id, + requirement, + ) => { + let solvable = package_id + .as_solvable_or_root(&solver.variable_map) + .expect("only solvables can be excluded"); + let package_node = Self::add_node(&mut graph, &mut nodes, solvable); + + let requirement_candidates = solver + .async_runtime + .block_on(solver.cache.get_or_cache_sorted_candidates( + requirement, + )) + .unwrap_or_else(|_| { + unreachable!( + "The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`" + ) + }); + + if requirement_candidates.is_empty() { + tracing::trace!( + "{package_id:?} conditionally requires {requirement:?}, which has no candidates" + ); + graph.add_edge( + package_node, + unresolved_node, + ConflictEdge::ConditionalRequires( + condition_version_set_id, + requirement, + ), + ); + } else { + tracing::trace!( + "{package_id:?} conditionally requires {requirement:?} if {condition_variable:?}" + ); + + for &candidate_id in requirement_candidates { + let candidate_node = + Self::add_node(&mut graph, &mut nodes, candidate_id.into()); + + graph.add_edge( + package_node, + candidate_node, + ConflictEdge::ConditionalRequires( + condition_version_set_id, + requirement, + ), + ); + } + } + } } } @@ -210,7 +263,7 @@ impl Conflict { } /// A node in the graph representation of a [`Conflict`] -#[derive(Copy, Clone, Eq, PartialEq)] +#[derive(Copy, Clone, Eq, PartialEq, Debug)] pub(crate) enum ConflictNode { /// Node corresponding to a solvable Solvable(SolvableOrRootId), @@ -239,33 +292,41 @@ impl ConflictNode { } /// An edge in the graph representation of a [`Conflict`] -#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd, Debug)] pub(crate) enum ConflictEdge { /// The target node is a candidate for the dependency specified by the /// [`Requirement`] Requires(Requirement), /// The target node is involved in a conflict, caused by `ConflictCause` Conflict(ConflictCause), + /// The target node is a candidate for a conditional dependency + ConditionalRequires(VersionSetId, Requirement), } impl ConflictEdge { - fn try_requires(self) -> Option { + fn try_requires_or_conditional(self) -> Option<(Requirement, Option)> { match self { - ConflictEdge::Requires(match_spec_id) => Some(match_spec_id), + ConflictEdge::Requires(match_spec_id) => Some((match_spec_id, None)), + ConflictEdge::ConditionalRequires(version_set_id, match_spec_id) => { + Some((match_spec_id, Some(version_set_id))) + } ConflictEdge::Conflict(_) => None, } } - fn requires(self) -> Requirement { + fn requires_or_conditional(self) -> (Requirement, Option) { match self { - ConflictEdge::Requires(match_spec_id) => match_spec_id, + ConflictEdge::Requires(match_spec_id) => (match_spec_id, None), + ConflictEdge::ConditionalRequires(version_set_id, match_spec_id) => { + (match_spec_id, Some(version_set_id)) + } ConflictEdge::Conflict(_) => panic!("expected requires edge, found conflict"), } } } /// Conflict causes -#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd, Debug)] pub(crate) enum ConflictCause { /// The solvable is locked Locked(SolvableId), @@ -341,6 +402,11 @@ impl ConflictGraph { ConflictEdge::Requires(_) if target != ConflictNode::UnresolvedDependency => { "black" } + ConflictEdge::ConditionalRequires(_, _) + if target != ConflictNode::UnresolvedDependency => + { + "blue" // This indicates that the requirement has candidates, but the condition is not met + } _ => "red", }; @@ -348,6 +414,13 @@ impl ConflictGraph { ConflictEdge::Requires(requirement) => { requirement.display(interner).to_string() } + ConflictEdge::ConditionalRequires(condition_version_set_id, requirement) => { + format!( + "if {} then {}", + Requirement::from(*condition_version_set_id).display(interner), + requirement.display(interner) + ) + } ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { interner.display_version_set(*version_set_id).to_string() } @@ -493,10 +566,15 @@ impl ConflictGraph { .graph .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { - ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::Requires(req) => ((req, None), e.target()), + ConflictEdge::ConditionalRequires(condition, req) => { + ((req, Some(condition)), e.target()) + } ConflictEdge::Conflict(_) => unreachable!(), }) - .chunk_by(|(&version_set_id, _)| version_set_id); + .collect::>() + .into_iter() + .chunk_by(|((&version_set_id, condition), _)| (version_set_id, *condition)); for (_, mut deps) in &dependencies { if deps.all(|(_, target)| !installable.contains(&target)) { @@ -539,10 +617,15 @@ impl ConflictGraph { .graph .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { - ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::Requires(version_set_id) => ((version_set_id, None), e.target()), + ConflictEdge::ConditionalRequires(condition, version_set_id) => { + ((version_set_id, Some(condition)), e.target()) + } ConflictEdge::Conflict(_) => unreachable!(), }) - .chunk_by(|(&version_set_id, _)| version_set_id); + .collect::>() + .into_iter() + .chunk_by(|((&version_set_id, condition), _)| (version_set_id, *condition)); // Missing if at least one dependency is missing if dependencies @@ -629,42 +712,6 @@ impl Indenter { } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_indenter_without_top_level_indent() { - let indenter = Indenter::new(false); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), ""); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), "└─ "); - } - - #[test] - fn test_indenter_with_multiple_siblings() { - let indenter = Indenter::new(true); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), "└─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); - assert_eq!(indenter.get_indent(), " ├─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), " │ └─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), " │ └─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); - assert_eq!(indenter.get_indent(), " │ ├─ "); - } -} - /// A struct implementing [`fmt::Display`] that generates a user-friendly /// representation of a conflict graph pub struct DisplayUnsat<'i, I: Interner> { @@ -697,11 +744,13 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { top_level_indent: bool, ) -> fmt::Result { pub enum DisplayOp { + ConditionalRequirement((Requirement, VersionSetId), Vec), Requirement(Requirement, Vec), Candidate(NodeIndex), } let graph = &self.graph.graph; + println!("graph {:?}", graph); let installable_nodes = &self.installable_set; let mut reported: HashSet = HashSet::new(); @@ -709,21 +758,26 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { let indenter = Indenter::new(top_level_indent); let mut stack = top_level_edges .iter() - .filter(|e| e.weight().try_requires().is_some()) - .chunk_by(|e| e.weight().requires()) + .filter(|e| e.weight().try_requires_or_conditional().is_some()) + .chunk_by(|e| e.weight().requires_or_conditional()) .into_iter() - .map(|(version_set_id, group)| { + .map(|(version_set_id_with_condition, group)| { let edges: Vec<_> = group.map(|e| e.id()).collect(); - (version_set_id, edges) + (version_set_id_with_condition, edges) }) - .sorted_by_key(|(_version_set_id, edges)| { + .sorted_by_key(|(_version_set_id_with_condition, edges)| { edges .iter() .any(|&edge| installable_nodes.contains(&graph.edge_endpoints(edge).unwrap().1)) }) - .map(|(version_set_id, edges)| { + .map(|((version_set_id, condition), edges)| { ( - DisplayOp::Requirement(version_set_id, edges), + if let Some(condition) = condition { + println!("conditional requirement"); + DisplayOp::ConditionalRequirement((version_set_id, condition), edges) + } else { + DisplayOp::Requirement(version_set_id, edges) + }, indenter.push_level(), ) }) @@ -957,7 +1011,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { writeln!(f, "{indent}{version} would require",)?; let mut requirements = graph .edges(candidate) - .chunk_by(|e| e.weight().requires()) + .chunk_by(|e| e.weight().requires_or_conditional()) .into_iter() .map(|(version_set_id, group)| { let edges: Vec<_> = group.map(|e| e.id()).collect(); @@ -969,9 +1023,16 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { .contains(&graph.edge_endpoints(edge).unwrap().1) }) }) - .map(|(version_set_id, edges)| { + .map(|((version_set_id, condition), edges)| { ( - DisplayOp::Requirement(version_set_id, edges), + if let Some(condition) = condition { + DisplayOp::ConditionalRequirement( + (version_set_id, condition), + edges, + ) + } else { + DisplayOp::Requirement(version_set_id, edges) + }, indenter.push_level(), ) }) @@ -984,6 +1045,132 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { stack.extend(requirements); } } + DisplayOp::ConditionalRequirement((requirement, condition), edges) => { + debug_assert!(!edges.is_empty()); + + let installable = edges.iter().any(|&e| { + let (_, target) = graph.edge_endpoints(e).unwrap(); + installable_nodes.contains(&target) + }); + + let req = requirement.display(self.interner).to_string(); + let condition = self.interner.display_version_set(condition); + + let target_nx = graph.edge_endpoints(edges[0]).unwrap().1; + let missing = + edges.len() == 1 && graph[target_nx] == ConflictNode::UnresolvedDependency; + if missing { + // No candidates for requirement + if top_level { + writeln!(f, "{indent} the condition {condition} is true but no candidates were found for {req}.")?; + } else { + writeln!(f, "{indent}{req}, for which no candidates were found.",)?; + } + } else if installable { + // Package can be installed (only mentioned for top-level requirements) + if top_level { + writeln!( + f, + "{indent}due to the condition {condition}, {req} can be installed with any of the following options:" + )?; + } else { + writeln!(f, "{indent}{req}, which can be installed with any of the following options:")?; + } + + let children: Vec<_> = edges + .iter() + .filter(|&&e| { + installable_nodes.contains(&graph.edge_endpoints(e).unwrap().1) + }) + .map(|&e| { + ( + DisplayOp::Candidate(graph.edge_endpoints(e).unwrap().1), + indenter.push_level(), + ) + }) + .collect(); + + // TODO: this is an utterly ugly hack that should be burnt to ashes + let mut deduplicated_children = Vec::new(); + let mut merged_and_seen = HashSet::new(); + for child in children { + let (DisplayOp::Candidate(child_node), _) = child else { + unreachable!() + }; + let solvable_id = graph[child_node].solvable_or_root(); + let Some(solvable_id) = solvable_id.solvable() else { + continue; + }; + + let merged = self.merged_candidates.get(&solvable_id); + + // Skip merged stuff that we have already seen + if merged_and_seen.contains(&solvable_id) { + continue; + } + + if let Some(merged) = merged { + merged_and_seen.extend(merged.ids.iter().copied()) + } + + deduplicated_children.push(child); + } + + if !deduplicated_children.is_empty() { + deduplicated_children[0].1.set_last(); + } + + stack.extend(deduplicated_children); + } else { + // Package cannot be installed (the conflicting requirement is further down + // the tree) + if top_level { + writeln!(f, "{indent}The condition {condition} is true but {req} cannot be installed because there are no viable options:")?; + } else { + writeln!(f, "{indent}{req}, which cannot be installed because there are no viable options:")?; + } + + let children: Vec<_> = edges + .iter() + .map(|&e| { + ( + DisplayOp::Candidate(graph.edge_endpoints(e).unwrap().1), + indenter.push_level(), + ) + }) + .collect(); + + // TODO: this is an utterly ugly hack that should be burnt to ashes + let mut deduplicated_children = Vec::new(); + let mut merged_and_seen = HashSet::new(); + for child in children { + let (DisplayOp::Candidate(child_node), _) = child else { + unreachable!() + }; + let Some(solvable_id) = graph[child_node].solvable() else { + continue; + }; + let merged = self.merged_candidates.get(&solvable_id); + + // Skip merged stuff that we have already seen + if merged_and_seen.contains(&solvable_id) { + continue; + } + + if let Some(merged) = merged { + merged_and_seen.extend(merged.ids.iter().copied()) + } + + deduplicated_children.push(child); + } + + if !deduplicated_children.is_empty() { + deduplicated_children[0].1.set_last(); + } + + stack.extend(deduplicated_children); + } + } } } @@ -1020,6 +1207,7 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { let conflict = match e.weight() { ConflictEdge::Requires(_) => continue, ConflictEdge::Conflict(conflict) => conflict, + ConflictEdge::ConditionalRequires(_, _) => continue, }; // The only possible conflict at the root level is a Locked conflict @@ -1052,3 +1240,39 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_indenter_without_top_level_indent() { + let indenter = Indenter::new(false); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), ""); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), "└─ "); + } + + #[test] + fn test_indenter_with_multiple_siblings() { + let indenter = Indenter::new(true); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), "└─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); + assert_eq!(indenter.get_indent(), " ├─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), " │ └─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), " │ └─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); + assert_eq!(indenter.get_indent(), " │ ├─ "); + } +} diff --git a/src/internal/id.rs b/src/internal/id.rs index 47fe226..5f87e34 100644 --- a/src/internal/id.rs +++ b/src/internal/id.rs @@ -22,6 +22,23 @@ impl ArenaId for NameId { } } +/// The id associated to an extra +#[repr(transparent)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] +pub struct ExtraId(pub u32); + +impl ArenaId for ExtraId { + fn from_usize(x: usize) -> Self { + Self(x as u32) + } + + fn to_usize(self) -> usize { + self.0 as usize + } +} + /// The id associated with a generic string #[repr(transparent)] #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] @@ -46,6 +63,12 @@ impl ArenaId for StringId { #[cfg_attr(feature = "serde", serde(transparent))] pub struct VersionSetId(pub u32); +impl From<(VersionSetId, Option)> for VersionSetId { + fn from((id, _): (VersionSetId, Option)) -> Self { + id + } +} + impl ArenaId for VersionSetId { fn from_usize(x: usize) -> Self { Self(x as u32) diff --git a/src/lib.rs b/src/lib.rs index 575c678..74eb27e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ pub use internal::{ mapping::Mapping, }; use itertools::Itertools; -pub use requirement::Requirement; +pub use requirement::{ConditionalRequirement, Requirement}; pub use solver::{Problem, Solver, SolverCache, UnsolvableOrCancelled}; /// An object that is used by the solver to query certain properties of @@ -206,7 +206,7 @@ pub struct KnownDependencies { feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty") )] - pub requirements: Vec, + pub requirements: Vec, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set diff --git a/src/requirement.rs b/src/requirement.rs index 244ec48..938575c 100644 --- a/src/requirement.rs +++ b/src/requirement.rs @@ -1,7 +1,92 @@ -use crate::{Interner, VersionSetId, VersionSetUnionId}; +use crate::{Interner, StringId, VersionSetId, VersionSetUnionId}; use itertools::Itertools; use std::fmt::Display; +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum Condition { + VersionSetId(VersionSetId), + Extra(StringId), +} + +/// Specifies a conditional requirement, where the requirement is only active when the condition is met. +#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct ConditionalRequirement { + /// The conditions that must be met for the requirement to be active. + pub conditions: Vec, + /// The requirement that is only active when the condition is met. + pub requirement: Requirement, +} + +impl ConditionalRequirement { + /// Creates a new conditional requirement. + pub fn new(conditions: Vec, requirement: Requirement) -> Self { + Self { + conditions, + requirement, + } + } + /// Returns the version sets that satisfy the requirement. + pub fn requirement_version_sets<'i>( + &'i self, + interner: &'i impl Interner, + ) -> impl Iterator + 'i { + self.requirement.version_sets(interner) + } + + /// Returns the version sets that satisfy the requirement, along with the condition that must be met. + pub fn version_sets_with_condition<'i>( + &'i self, + interner: &'i impl Interner, + ) -> impl Iterator)> + 'i { + self.requirement + .version_sets(interner) + .map(move |vs| (vs, self.conditions.clone())) + } + + /// Returns the condition and requirement. + pub fn into_condition_and_requirement(self) -> (Vec, Requirement) { + (self.conditions, self.requirement) + } +} + +impl From for ConditionalRequirement { + fn from(value: Requirement) -> Self { + Self { + conditions: vec![], + requirement: value, + } + } +} + +impl From for ConditionalRequirement { + fn from(value: VersionSetId) -> Self { + Self { + conditions: vec![], + requirement: value.into(), + } + } +} + +impl From for ConditionalRequirement { + fn from(value: VersionSetUnionId) -> Self { + Self { + conditions: vec![], + requirement: value.into(), + } + } +} + +impl From<(VersionSetId, Vec)> for ConditionalRequirement { + fn from((requirement, conditions): (VersionSetId, Vec)) -> Self { + Self { + conditions, + requirement: requirement.into(), + } + } +} + /// Specifies the dependency of a solvable on a set of version sets. #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] diff --git a/src/snapshot.rs b/src/snapshot.rs index 0b8b6d2..ab6d926 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -220,7 +220,15 @@ impl DependencySnapshot { } } - for &requirement in deps.requirements.iter() { + for &req in deps.requirements.iter() { + let (condition, requirement) = req.into_condition_and_requirement(); + + if let Some(condition) = condition { + if seen.insert(Element::VersionSet(condition)) { + queue.push_back(Element::VersionSet(condition)); + } + } + match requirement { Requirement::Single(version_set) => { if seen.insert(Element::VersionSet(version_set)) { diff --git a/src/solver/clause.rs b/src/solver/clause.rs index f034130..99bb693 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -46,7 +46,7 @@ use crate::{ /// limited set of clauses. There are thousands of clauses for a particular /// dependency resolution problem, and we try to keep the [`Clause`] enum small. /// A naive implementation would store a `Vec`. -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Copy, Debug)] pub(crate) enum Clause { /// An assertion that the root solvable must be installed /// @@ -77,6 +77,12 @@ pub(crate) enum Clause { /// /// In SAT terms: (¬A ∨ ¬B) Constrains(VariableId, VariableId, VersionSetId), + /// In SAT terms: (¬A ∨ ¬C ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, + /// C is the condition, and B1 to B99 represent the possible candidates for + /// the provided [`Requirement`]. + /// We need to store the version set id because in the conflict graph, the version set id + /// is used to identify the condition variable. + Conditional(VariableId, VariableId, VersionSetId, Requirement), /// Forbids the package on the right-hand side /// /// Note that the package on the left-hand side is not part of the clause, @@ -230,6 +236,43 @@ impl Clause { ) } + fn conditional( + parent_id: VariableId, + requirement: Requirement, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Self, Option<[Literal; 2]>, bool) { + assert_ne!(decision_tracker.assigned_value(parent_id), Some(false)); + let mut requirement_candidates = requirement_candidates.into_iter(); + + let requirement_literal = + if decision_tracker.assigned_value(condition_variable) == Some(true) { + // then ~condition is false + requirement_candidates + .find(|&id| decision_tracker.assigned_value(id) != Some(false)) + .map(|id| id.positive()) + } else { + None + }; + + ( + Clause::Conditional( + parent_id, + condition_variable, + condition_version_set_id, + requirement, + ), + Some([ + parent_id.negative(), + requirement_literal.unwrap_or(condition_variable.negative()), + ]), + requirement_literal.is_none() + && decision_tracker.assigned_value(condition_variable) == Some(true), + ) + } + /// Tries to fold over all the literals in the clause. /// /// This function is useful to iterate, find, or filter the literals in a @@ -272,6 +315,17 @@ impl Clause { Clause::Lock(_, s) => [s.negative(), VariableId::root().negative()] .into_iter() .try_fold(init, visit), + Clause::Conditional(package_id, condition_variable, _, requirement) => { + iter::once(package_id.negative()) + .chain(iter::once(condition_variable.negative())) + .chain( + requirements_to_sorted_candidates[&requirement] + .iter() + .flatten() + .map(|&s| s.positive()), + ) + .try_fold(init, visit) + } } } @@ -419,6 +473,35 @@ impl WatchedLiterals { (Self::from_kind_and_initial_watches(watched_literals), kind) } + /// Shorthand method to construct a [Clause::Conditional] without requiring + /// complicated arguments. + /// + /// The returned boolean value is true when adding the clause resulted in a + /// conflict. + pub fn conditional( + package_id: VariableId, + requirement: Requirement, + condition_variable: VariableId, + condition_version_set_id: VersionSetId, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Option, bool, Clause) { + let (kind, watched_literals, conflict) = Clause::conditional( + package_id, + requirement, + condition_variable, + condition_version_set_id, + decision_tracker, + requirement_candidates, + ); + + ( + WatchedLiterals::from_kind_and_initial_watches(watched_literals), + conflict, + kind, + ) + } + fn from_kind_and_initial_watches(watched_literals: Option<[Literal; 2]>) -> Option { let watched_literals = watched_literals?; debug_assert!(watched_literals[0] != watched_literals[1]); @@ -611,6 +694,17 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { other, ) } + Clause::Conditional(package_id, condition_variable, _, requirement) => { + write!( + f, + "Conditional({}({:?}), {}({:?}), {})", + package_id.display(self.variable_map, self.interner), + package_id, + condition_variable.display(self.variable_map, self.interner), + condition_variable, + requirement.display(self.interner), + ) + } } } } @@ -671,17 +765,11 @@ mod test { clause.as_ref().unwrap().watched_literals[0].variable(), parent ); - assert_eq!( - clause.unwrap().watched_literals[1].variable(), - candidate1.into() - ); + assert_eq!(clause.unwrap().watched_literals[1].variable(), candidate1); // No conflict, still one candidate available decisions - .try_add_decision( - Decision::new(candidate1.into(), false, ClauseId::from_usize(0)), - 1, - ) + .try_add_decision(Decision::new(candidate1, false, ClauseId::from_usize(0)), 1) .unwrap(); let (clause, conflict, _kind) = WatchedLiterals::requires( parent, @@ -696,13 +784,13 @@ mod test { ); assert_eq!( clause.as_ref().unwrap().watched_literals[1].variable(), - candidate2.into() + candidate2 ); // Conflict, no candidates available decisions .try_add_decision( - Decision::new(candidate2.into(), false, ClauseId::install_root()), + Decision::new(candidate2, false, ClauseId::install_root()), 1, ) .unwrap(); @@ -719,7 +807,7 @@ mod test { ); assert_eq!( clause.as_ref().unwrap().watched_literals[1].variable(), - candidate1.into() + candidate1 ); // Panic diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 8c0e026..bda9bd4 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -19,6 +19,7 @@ use crate::{ id::{ClauseId, LearntClauseId, NameId, SolvableId, SolvableOrRootId, VariableId}, mapping::Mapping, }, + requirement::ConditionalRequirement, runtime::{AsyncRuntime, NowOrNeverRuntime}, solver::binary_encoding::AtMostOnceTracker, Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, VersionSetId, @@ -36,6 +37,7 @@ mod watch_map; #[derive(Default)] struct AddClauseOutput { new_requires_clauses: Vec<(VariableId, Requirement, ClauseId)>, + new_conditional_clauses: Vec<(VariableId, VariableId, Requirement, ClauseId)>, conflicting_clauses: Vec, negative_assertions: Vec<(VariableId, ClauseId)>, clauses_to_watch: Vec, @@ -51,7 +53,7 @@ struct AddClauseOutput { /// This struct follows the builder pattern and can have its fields set by one /// of the available setter methods. pub struct Problem { - requirements: Vec, + requirements: Vec, constraints: Vec, soft_requirements: S, } @@ -80,7 +82,7 @@ impl> Problem { /// /// Returns the [`Problem`] for further mutation or to pass to /// [`Solver::solve`]. - pub fn requirements(self, requirements: Vec) -> Self { + pub fn requirements(self, requirements: Vec) -> Self { Self { requirements, ..self @@ -150,6 +152,8 @@ pub struct Solver { pub(crate) clauses: Clauses, requires_clauses: IndexMap, ahash::RandomState>, + conditional_clauses: + IndexMap<(VariableId, VariableId), Vec<(Requirement, ClauseId)>, ahash::RandomState>, watches: WatchMap, /// A mapping from requirements to the variables that represent the @@ -172,7 +176,7 @@ pub struct Solver { decision_tracker: DecisionTracker, /// The [`Requirement`]s that must be installed as part of the solution. - root_requirements: Vec, + root_requirements: Vec, /// Additional constraints imposed by the root. root_constraints: Vec, @@ -200,6 +204,7 @@ impl Solver { clauses: Clauses::default(), variable_map: VariableMap::default(), requires_clauses: Default::default(), + conditional_clauses: Default::default(), requirement_to_sorted_candidates: FrozenMap::default(), watches: WatchMap::new(), negative_assertions: Default::default(), @@ -213,7 +218,6 @@ impl Solver { clauses_added_for_solvable: Default::default(), forbidden_clauses_added: Default::default(), name_activity: Default::default(), - activity_add: 1.0, activity_decay: 0.95, } @@ -280,6 +284,7 @@ impl Solver { clauses: self.clauses, variable_map: self.variable_map, requires_clauses: self.requires_clauses, + conditional_clauses: self.conditional_clauses, requirement_to_sorted_candidates: self.requirement_to_sorted_candidates, watches: self.watches, negative_assertions: self.negative_assertions, @@ -660,6 +665,16 @@ impl Solver { .or_default() .push((requirement, clause_id)); } + + for (solvable_id, condition_variable, requirement, clause_id) in + output.new_conditional_clauses + { + self.conditional_clauses + .entry((solvable_id, condition_variable)) + .or_default() + .push((requirement, clause_id)); + } + self.negative_assertions .append(&mut output.negative_assertions); @@ -695,7 +710,7 @@ impl Solver { fn resolve_dependencies(&mut self, mut level: u32) -> Result { loop { // Make a decision. If no decision could be made it means the problem is - // satisfyable. + // satisfiable. let Some((candidate, required_by, clause_id)) = self.decide() else { break; }; @@ -767,8 +782,36 @@ impl Solver { } let mut best_decision: Option = None; - for (&solvable_id, requirements) in self.requires_clauses.iter() { + + // Chain together the requires_clauses and conditional_clauses iterations + let requires_iter = self + .requires_clauses + .iter() + .map(|(&solvable_id, requirements)| { + ( + solvable_id, + None, + requirements + .iter() + .map(|(r, c)| (*r, *c)) + .collect::>(), + ) + }); + + let conditional_iter = + self.conditional_clauses + .iter() + .map(|((solvable_id, condition), clauses)| { + ( + *solvable_id, + Some(*condition), + clauses.iter().map(|(r, c)| (*r, *c)).collect::>(), + ) + }); + + for (solvable_id, condition, requirements) in requires_iter.chain(conditional_iter) { let is_explicit_requirement = solvable_id == VariableId::root(); + if let Some(best_decision) = &best_decision { // If we already have an explicit requirement, there is no need to evaluate // non-explicit requirements. @@ -782,11 +825,25 @@ impl Solver { continue; } - for (deps, clause_id) in requirements.iter() { + // For conditional clauses, check that at least one conditional variable is true + if let Some(condition_variable) = condition { + // Check if any candidate that matches the condition's version set is installed + let condition_met = + self.decision_tracker.assigned_value(condition_variable) == Some(true); + + // If the condition is not met, skip this requirement entirely + if !condition_met { + continue; + } + } + + for (requirement, clause_id) in requirements { let mut candidate = ControlFlow::Break(()); // Get the candidates for the individual version sets. - let version_set_candidates = &self.requirement_to_sorted_candidates[deps]; + let version_set_candidates = &self.requirement_to_sorted_candidates[&requirement]; + + let version_sets = requirement.version_sets(self.provider()); // Iterate over all version sets in the requirement and find the first version // set that we can act on, or if a single candidate (from any version set) makes @@ -795,10 +852,7 @@ impl Solver { // NOTE: We zip the version sets from the requirements and the variables that we // previously cached. This assumes that the order of the version sets is the // same in both collections. - for (version_set, candidates) in deps - .version_sets(self.provider()) - .zip(version_set_candidates) - { + for (version_set, candidates) in version_sets.zip(version_set_candidates) { // Find the first candidate that is not yet assigned a value or find the first // value that makes this clause true. candidate = candidates.iter().try_fold( @@ -875,7 +929,7 @@ impl Solver { candidate_count, package_activity, ))) => { - let decision = (candidate, solvable_id, *clause_id); + let decision = (candidate, solvable_id, clause_id); best_decision = Some(match &best_decision { None => PossibleDecision { is_explicit_requirement, @@ -1519,7 +1573,7 @@ async fn add_clauses_for_solvables( RequirementCandidateVariables, ahash::RandomState, >, - root_requirements: &[Requirement], + root_requirements: &[ConditionalRequirement], root_constraints: &[VersionSetId], ) -> Result> { let mut output = AddClauseOutput::default(); @@ -1534,6 +1588,7 @@ async fn add_clauses_for_solvables( SortedCandidates { solvable_id: SolvableOrRootId, requirement: Requirement, + condition: Option<(SolvableId, VersionSetId)>, candidates: Vec<&'i [SolvableId]>, }, NonMatchingCandidates { @@ -1615,7 +1670,7 @@ async fn add_clauses_for_solvables( None => variable_map.root(), }; - let (requirements, constrains) = match dependencies { + let (conditional_requirements, constrains) = match dependencies { Dependencies::Known(deps) => (deps.requirements, deps.constrains), Dependencies::Unknown(reason) => { // There is no information about the solvable's dependencies, so we add @@ -1637,17 +1692,29 @@ async fn add_clauses_for_solvables( } }; - for version_set_id in requirements + for (version_set_id, condition) in conditional_requirements .iter() - .flat_map(|requirement| requirement.version_sets(cache.provider())) - .chain(constrains.iter().copied()) + .flat_map(|conditional_requirement| { + conditional_requirement.version_sets_with_condition(cache.provider()) + }) + .chain(constrains.iter().map(|&vs| (vs, None))) { let dependency_name = cache.provider().version_set_name(version_set_id); if clauses_added_for_package.insert(dependency_name) { - tracing::trace!( - "┝━ Adding clauses for package '{}'", - cache.provider().display_name(dependency_name), - ); + if let Some(condition) = condition { + let condition_name = cache.provider().version_set_name(condition); + tracing::trace!( + "┝━ Adding conditional clauses for package '{}' with condition '{}' and version set '{}'", + cache.provider().display_name(dependency_name), + cache.provider().display_name(condition_name), + cache.provider().display_version_set(condition), + ); + } else { + tracing::trace!( + "┝━ Adding clauses for package '{}'", + cache.provider().display_name(dependency_name), + ); + } pending_futures.push( async move { @@ -1660,32 +1727,70 @@ async fn add_clauses_for_solvables( } .boxed_local(), ); + + if let Some(condition) = condition { + let condition_name = cache.provider().version_set_name(condition); + if clauses_added_for_package.insert(condition_name) { + pending_futures.push( + async move { + let condition_candidates = + cache.get_or_cache_candidates(condition_name).await?; + Ok(TaskResult::Candidates { + name_id: condition_name, + package_candidates: condition_candidates, + }) + } + .boxed_local(), + ); + } + } } } - for requirement in requirements { + for conditional_requirement in conditional_requirements { // Find all the solvable that match for the given version set - pending_futures.push( - async move { - let candidates = futures::future::try_join_all( - requirement - .version_sets(cache.provider()) - .map(|version_set| { - cache.get_or_cache_sorted_candidates_for_version_set( - version_set, - ) - }), - ) - .await?; - - Ok(TaskResult::SortedCandidates { - solvable_id, - requirement, - candidates, - }) + let version_sets = + conditional_requirement.requirement_version_sets(cache.provider()); + let candidates = + futures::future::try_join_all(version_sets.map(|version_set| { + cache.get_or_cache_sorted_candidates_for_version_set(version_set) + })) + .await?; + + // condition is `VersionSetId` right now but it will become a `Requirement` + // in the next versions of resolvo + if let Some(condition) = conditional_requirement.condition { + let condition_candidates = + cache.get_or_cache_matching_candidates(condition).await?; + + for &condition_candidate in condition_candidates { + let candidates = candidates.clone(); + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: Some((condition_candidate, condition)), + candidates, + }) + } + .boxed_local(), + ); } - .boxed_local(), - ); + } else { + // Add a task result for the condition + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + condition: None, + candidates: candidates.clone(), + }) + } + .boxed_local(), + ); + } } for version_set_id in constrains { @@ -1751,6 +1856,7 @@ async fn add_clauses_for_solvables( TaskResult::SortedCandidates { solvable_id, requirement, + condition, candidates, } => { tracing::trace!( @@ -1820,30 +1926,70 @@ async fn add_clauses_for_solvables( ); } - // Add the requirements clause - let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); - let (watched_literals, conflict, kind) = WatchedLiterals::requires( - variable, - requirement, - version_set_variables.iter().flatten().copied(), - decision_tracker, - ); - let has_watches = watched_literals.is_some(); - let clause_id = clauses.alloc(watched_literals, kind); + if let Some((condition, condition_version_set_id)) = condition { + let condition_variable = variable_map.intern_solvable(condition); - if has_watches { - output.clauses_to_watch.push(clause_id); - } + // Add a condition clause + let (watched_literals, conflict, kind) = WatchedLiterals::conditional( + variable, + requirement, + condition_variable, + condition_version_set_id, + decision_tracker, + version_set_variables.iter().flatten().copied(), + ); + + // Add the conditional clause + let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); + + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output.new_conditional_clauses.push(( + variable, + condition_variable, + requirement, + clause_id, + )); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } + } else { + let (watched_literals, conflict, kind) = WatchedLiterals::requires( + variable, + requirement, + version_set_variables.iter().flatten().copied(), + decision_tracker, + ); + + // Add the requirements clause + let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); + + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } - output - .new_requires_clauses - .push((variable, requirement, clause_id)); + output + .new_requires_clauses + .push((variable, requirement, clause_id)); - if conflict { - output.conflicting_clauses.push(clause_id); - } else if no_candidates { - // Add assertions for unit clauses (i.e. those with no matching candidates) - output.negative_assertions.push((variable, clause_id)); + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } } } TaskResult::NonMatchingCandidates { diff --git a/src/utils/pool.rs b/src/utils/pool.rs index 2a3b6fe..271efdf 100644 --- a/src/utils/pool.rs +++ b/src/utils/pool.rs @@ -6,7 +6,7 @@ use std::{ use crate::internal::{ arena::Arena, frozen_copy_map::FrozenCopyMap, - id::{NameId, SolvableId, StringId, VersionSetId, VersionSetUnionId}, + id::{ExtraId, NameId, SolvableId, StringId, VersionSetId, VersionSetUnionId}, small_vec::SmallVec, }; @@ -43,6 +43,12 @@ pub struct Pool { /// Map from package names to the id of their interned counterpart pub(crate) string_to_ids: FrozenCopyMap, + /// Interned extras + extras: Arena, + + /// Map from package names and their extras to the id of their interned counterpart + pub(crate) extra_to_ids: FrozenCopyMap<(NameId, String), ExtraId, ahash::RandomState>, + /// Interned match specs pub(crate) version_sets: Arena, @@ -62,6 +68,8 @@ impl Default for Pool { package_names: Arena::new(), strings: Arena::new(), string_to_ids: Default::default(), + extras: Arena::new(), + extra_to_ids: Default::default(), version_set_to_id: Default::default(), version_sets: Arena::new(), version_set_unions: Arena::new(), @@ -116,6 +124,34 @@ impl Pool { next_id } + /// Interns an extra into the [`Pool`], returning its [`StringId`]. Extras + /// are deduplicated. If the same extra is inserted twice the same + /// [`StringId`] will be returned. + /// + /// The original extra can be resolved using the + /// [`Self::resolve_extra`] function. + pub fn intern_extra( + &self, + package_id: NameId, + extra_name: impl Into + AsRef, + ) -> ExtraId { + if let Some(id) = self + .extra_to_ids + .get_copy(&(package_id, extra_name.as_ref().to_string())) + { + return id; + } + + let extra = extra_name.into(); + let id = self.extras.alloc((package_id, extra)); + self.extra_to_ids.insert_copy((package_id, extra), id); + id + } + + pub fn resolve_extra(&self, extra_id: ExtraId) -> &(NameId, String) { + &self.extras[extra_id] + } + /// Returns the package name associated with the provided [`NameId`]. /// /// Panics if the package name is not found in the pool. @@ -123,6 +159,13 @@ impl Pool { &self.package_names[name_id] } + /// Returns the extra associated with the provided [`StringId`]. + /// + /// Panics if the extra is not found in the pool. + // pub fn resolve_extra(&self, package_id: NameId, extra_id: StringId) -> &str { + // &self.strings[self.extra_to_ids.get_copy(&(package_id, extra_id)).unwrap()] + // } + /// Returns the [`NameId`] associated with the specified name or `None` if /// the name has not previously been interned using /// [`Self::intern_package_name`]. diff --git a/tests/solver.rs b/tests/solver.rs index de15d8a..792d9dc 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -18,13 +18,13 @@ use std::{ use ahash::HashMap; use indexmap::IndexMap; use insta::assert_snapshot; -use itertools::Itertools; +use itertools::{ExactlyOneError, Itertools}; use resolvo::{ snapshot::{DependencySnapshot, SnapshotProvider}, utils::Pool, - Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, Problem, - Requirement, SolvableId, Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, - VersionSetUnionId, + Candidates, ConditionalRequirement, Dependencies, DependencyProvider, ExtraId, Interner, + KnownDependencies, NameId, Problem, Requirement, SolvableId, Solver, SolverCache, StringId, + UnsolvableOrCancelled, VersionSetId, VersionSetUnionId, }; use tracing_test::traced_test; use version_ranges::Ranges; @@ -113,19 +113,29 @@ impl FromStr for Pack { struct Spec { name: String, versions: Ranges, + condition: Option>, + extras: Vec, } impl Spec { - pub fn new(name: String, versions: Ranges) -> Self { - Self { name, versions } + pub fn new( + name: String, + versions: Ranges, + condition: Option>, + extras: Vec, + ) -> Self { + Self { + name, + versions, + condition, + extras, + } } pub fn parse_union( spec: &str, ) -> impl Iterator::Err>> + '_ { - spec.split('|') - .map(str::trim) - .map(|dep| Spec::from_str(dep)) + spec.split('|').map(str::trim).map(Spec::from_str) } } @@ -133,11 +143,34 @@ impl FromStr for Spec { type Err = (); fn from_str(s: &str) -> Result { - let split = s.split(' ').collect::>(); - let name = split - .first() - .expect("spec does not have a name") - .to_string(); + let split = s.split_once("; if"); + + if split.is_none() { + let split = s.split(' ').collect::>(); + + // Extract feature name from brackets if present + let name_parts: Vec<_> = split[0].split('[').collect(); + let (name, extras) = if name_parts.len() > 1 { + // Has features in brackets + let extras = name_parts[1] + .trim_end_matches(']') + .split(',') + .map(|f| f.trim().to_string()) + .collect::>(); + (name_parts[0].to_string(), extras) + } else { + (name_parts[0].to_string(), vec![]) + }; + + let versions = version_range(split.get(1)); + return Ok(Spec::new(name, versions, None, extras)); + } + + let (spec, condition) = split.unwrap(); + + let condition = Spec::parse_union(condition).next().unwrap().unwrap(); + + let spec = Spec::from_str(spec).unwrap(); fn version_range(s: Option<&&str>) -> Ranges { if let Some(s) = s { @@ -156,9 +189,12 @@ impl FromStr for Spec { } } - let versions = version_range(split.get(1)); - - Ok(Spec::new(name, versions)) + Ok(Spec::new( + spec.name, + spec.versions, + Some(Box::new(condition)), + spec.extras, + )) } } @@ -187,6 +223,7 @@ struct BundleBoxProvider { struct BundleBoxPackageDependencies { dependencies: Vec>, constrains: Vec, + extras: HashMap>>, } impl BundleBoxProvider { @@ -200,11 +237,19 @@ impl BundleBoxProvider { .expect("package missing") } - pub fn requirements>(&self, requirements: &[&str]) -> Vec { + pub fn requirements)>>( + &self, + requirements: &[&str], + ) -> Vec { requirements .iter() .map(|dep| Spec::from_str(dep).unwrap()) - .map(|spec| self.intern_version_set(&spec)) + .map(|spec| { + ( + self.intern_version_set(&spec), + spec.condition.as_ref().map(|c| self.intern_version_set(c)), + ) + }) .map(From::from) .collect() } @@ -236,10 +281,10 @@ impl BundleBoxProvider { .intern_version_set_union(specs.next().unwrap(), specs) } - pub fn from_packages(packages: &[(&str, u32, Vec<&str>)]) -> Self { + pub fn from_packages(packages: &[(&str, u32, Vec<&str>, &[(&str, &[&str])])]) -> Self { let mut result = Self::new(); - for (name, version, deps) in packages { - result.add_package(name, Pack::new(*version), deps, &[]); + for (name, version, deps, extras) in packages { + result.add_package(name, Pack::new(*version), deps, &[], extras); } result } @@ -267,8 +312,9 @@ impl BundleBoxProvider { package_version: Pack, dependencies: &[&str], constrains: &[&str], + extras: &[(&str, &[&str])], ) { - self.pool.intern_package_name(package_name); + let package_id = self.pool.intern_package_name(package_name); let dependencies = dependencies .iter() @@ -276,6 +322,19 @@ impl BundleBoxProvider { .collect::, _>>() .unwrap(); + let extras = extras + .iter() + .map(|(key, values)| { + (self.pool.intern_extra(package_id, key), { + values + .iter() + .map(|dep| Spec::parse_union(dep).collect()) + .collect::, _>>() + .unwrap() + }) + }) + .collect::>(); + let constrains = constrains .iter() .map(|dep| Spec::from_str(dep)) @@ -290,6 +349,7 @@ impl BundleBoxProvider { BundleBoxPackageDependencies { dependencies, constrains, + extras, }, ); } @@ -386,7 +446,7 @@ impl DependencyProvider for BundleBoxProvider { candidates .iter() .copied() - .filter(|s| range.contains(&self.pool.resolve_solvable(*s).record) == !inverse) + .filter(|s| range.contains(&self.pool.resolve_solvable(*s).record) != inverse) .collect() } @@ -502,18 +562,44 @@ impl DependencyProvider for BundleBoxProvider { .intern_version_set(first_name, first.versions.clone()); let requirement = if remaining_req_specs.len() == 0 { - first_version_set.into() + if let Some(condition) = &first.condition { + ConditionalRequirement::new( + Some(self.intern_version_set(condition)), + first_version_set.into(), + ) + } else { + first_version_set.into() + } } else { - let other_version_sets = remaining_req_specs.map(|spec| { - self.pool.intern_version_set( + // Check if all specs have the same condition + let common_condition = first.condition.as_ref().map(|c| self.intern_version_set(c)); + + // Collect version sets for union + let mut version_sets = vec![first_version_set]; + for spec in remaining_req_specs { + // Verify condition matches + if spec.condition.as_ref().map(|c| self.intern_version_set(c)) + != common_condition + { + panic!("All specs in a union must have the same condition"); + } + + version_sets.push(self.pool.intern_version_set( self.pool.intern_package_name(&spec.name), spec.versions.clone(), - ) - }); - - self.pool - .intern_version_set_union(first_version_set, other_version_sets) - .into() + )); + } + + // Create union and wrap in conditional if needed + let union = self + .pool + .intern_version_set_union(version_sets[0], version_sets.into_iter().skip(1)); + + if let Some(condition) = common_condition { + ConditionalRequirement::new(Some(condition), union.into()) + } else { + union.into() + } }; result.requirements.push(requirement); @@ -538,7 +624,7 @@ impl DependencyProvider for BundleBoxProvider { } /// Create a string from a [`Transaction`] -fn transaction_to_string(interner: &impl Interner, solvables: &Vec) -> String { +fn transaction_to_string(interner: &impl Interner, solvables: &[SolvableId]) -> String { use std::fmt::Write; let mut buf = String::new(); for solvable in solvables @@ -590,7 +676,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { let requirements = provider.parse_requirements(specs); let mut solver = Solver::new(provider).with_runtime(runtime); - let problem = Problem::new().requirements(requirements); + let problem = Problem::new().requirements(requirements.into_iter().map(|r| r.into()).collect()); match solver.solve(problem) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), Err(UnsolvableOrCancelled::Unsolvable(conflict)) => { @@ -703,12 +789,12 @@ fn test_resolve_with_concurrent_metadata_fetching() { #[test] fn test_resolve_with_conflict() { let provider = BundleBoxProvider::from_packages(&[ - ("asdf", 4, vec!["conflicting 1"]), - ("asdf", 3, vec!["conflicting 0"]), - ("efgh", 7, vec!["conflicting 0"]), - ("efgh", 6, vec!["conflicting 0"]), - ("conflicting", 1, vec![]), - ("conflicting", 0, vec![]), + ("asdf", 4, vec!["conflicting 1"], &[]), + ("asdf", 3, vec!["conflicting 0"], &[]), + ("efgh", 7, vec!["conflicting 0"], &[]), + ("efgh", 6, vec!["conflicting 0"], &[]), + ("conflicting", 1, vec![], &[]), + ("conflicting", 0, vec![], &[]), ]); let result = solve_snapshot(provider, &["asdf", "efgh"]); insta::assert_snapshot!(result); @@ -719,9 +805,9 @@ fn test_resolve_with_conflict() { #[traced_test] fn test_resolve_with_nonexisting() { let provider = BundleBoxProvider::from_packages(&[ - ("asdf", 4, vec!["b"]), - ("asdf", 3, vec![]), - ("b", 1, vec!["idontexist"]), + ("asdf", 4, vec!["b"], &[]), + ("asdf", 3, vec![], &[]), + ("b", 1, vec!["idontexist"], &[]), ]); let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); @@ -745,18 +831,25 @@ fn test_resolve_with_nested_deps() { "apache-airflow", 3, vec!["opentelemetry-api 2..4", "opentelemetry-exporter-otlp"], + &[], ), ( "apache-airflow", 2, vec!["opentelemetry-api 2..4", "opentelemetry-exporter-otlp"], + &[], ), - ("apache-airflow", 1, vec![]), - ("opentelemetry-api", 3, vec!["opentelemetry-sdk"]), - ("opentelemetry-api", 2, vec![]), - ("opentelemetry-api", 1, vec![]), - ("opentelemetry-exporter-otlp", 1, vec!["opentelemetry-grpc"]), - ("opentelemetry-grpc", 1, vec!["opentelemetry-api 1"]), + ("apache-airflow", 1, vec![], &[]), + ("opentelemetry-api", 3, vec!["opentelemetry-sdk"], &[]), + ("opentelemetry-api", 2, vec![], &[]), + ("opentelemetry-api", 1, vec![], &[]), + ( + "opentelemetry-exporter-otlp", + 1, + vec!["opentelemetry-grpc"], + &[], + ), + ("opentelemetry-grpc", 1, vec!["opentelemetry-api 1"], &[]), ]); let requirements = provider.requirements(&["apache-airflow"]); let mut solver = Solver::new(provider); @@ -781,8 +874,9 @@ fn test_resolve_with_unknown_deps() { Pack::new(3).with_unknown_deps(), &[], &[], + &[], ); - provider.add_package("opentelemetry-api", Pack::new(2), &[], &[]); + provider.add_package("opentelemetry-api", Pack::new(2), &[], &[], &[]); let requirements = provider.requirements(&["opentelemetry-api"]); let mut solver = Solver::new(provider); let problem = Problem::new().requirements(requirements); @@ -809,12 +903,14 @@ fn test_resolve_and_cancel() { Pack::new(3).with_unknown_deps(), &[], &[], + &[], ); provider.add_package( "opentelemetry-api", Pack::new(2).cancel_during_get_dependencies(), &[], &[], + &[], ); let error = solve_unsat(provider, &["opentelemetry-api"]); insta::assert_snapshot!(error); @@ -825,7 +921,7 @@ fn test_resolve_and_cancel() { #[test] fn test_resolve_locked_top_level() { let mut provider = - BundleBoxProvider::from_packages(&[("asdf", 4, vec![]), ("asdf", 3, vec![])]); + BundleBoxProvider::from_packages(&[("asdf", 4, vec![], &[]), ("asdf", 3, vec![], &[])]); provider.set_locked("asdf", 3); let requirements = provider.requirements(&["asdf"]); @@ -845,9 +941,9 @@ fn test_resolve_locked_top_level() { #[test] fn test_resolve_ignored_locked_top_level() { let mut provider = BundleBoxProvider::from_packages(&[ - ("asdf", 4, vec![]), - ("asdf", 3, vec!["fgh"]), - ("fgh", 1, vec![]), + ("asdf", 4, vec![], &[]), + ("asdf", 3, vec!["fgh"], &[]), + ("fgh", 1, vec![], &[]), ]); provider.set_locked("fgh", 1); @@ -869,10 +965,10 @@ fn test_resolve_ignored_locked_top_level() { #[test] fn test_resolve_favor_without_conflict() { let mut provider = BundleBoxProvider::from_packages(&[ - ("a", 1, vec![]), - ("a", 2, vec![]), - ("b", 1, vec![]), - ("b", 2, vec![]), + ("a", 1, vec![], &[]), + ("a", 2, vec![], &[]), + ("b", 1, vec![], &[]), + ("b", 2, vec![], &[]), ]); provider.set_favored("a", 1); provider.set_favored("b", 1); @@ -888,12 +984,12 @@ fn test_resolve_favor_without_conflict() { #[test] fn test_resolve_favor_with_conflict() { let mut provider = BundleBoxProvider::from_packages(&[ - ("a", 1, vec!["c 1"]), - ("a", 2, vec![]), - ("b", 1, vec!["c 1"]), - ("b", 2, vec!["c 2"]), - ("c", 1, vec![]), - ("c", 2, vec![]), + ("a", 1, vec!["c 1"], &[]), + ("a", 2, vec![], &[]), + ("b", 1, vec!["c 1"], &[]), + ("b", 2, vec!["c 2"], &[]), + ("c", 1, vec![], &[]), + ("c", 2, vec![], &[]), ]); provider.set_favored("a", 1); provider.set_favored("b", 1); @@ -909,8 +1005,10 @@ fn test_resolve_favor_with_conflict() { #[test] fn test_resolve_cyclic() { - let provider = - BundleBoxProvider::from_packages(&[("a", 2, vec!["b 0..10"]), ("b", 5, vec!["a 2..4"])]); + let provider = BundleBoxProvider::from_packages(&[ + ("a", 2, vec!["b 0..10"], &[]), + ("b", 5, vec!["a 2..4"], &[]), + ]); let requirements = provider.requirements(&["a 0..100"]); let mut solver = Solver::new(provider); let problem = Problem::new().requirements(requirements); @@ -926,15 +1024,15 @@ fn test_resolve_cyclic() { #[test] fn test_resolve_union_requirements() { let mut provider = BundleBoxProvider::from_packages(&[ - ("a", 1, vec![]), - ("b", 1, vec![]), - ("c", 1, vec!["a"]), - ("d", 1, vec!["b"]), - ("e", 1, vec!["a | b"]), + ("a", 1, vec![], &[]), + ("b", 1, vec![], &[]), + ("c", 1, vec!["a"], &[]), + ("d", 1, vec!["b"], &[]), + ("e", 1, vec!["a | b"], &[]), ]); // Make d conflict with a=1 - provider.add_package("f", 1.into(), &["b"], &["a 2"]); + provider.add_package("f", 1.into(), &["b"], &["a 2"], &["b"]); let result = solve_snapshot(provider, &["c | d", "e", "f"]); assert_snapshot!(result, @r###" @@ -1079,8 +1177,8 @@ fn test_unsat_constrains() { ("b", 42, vec![]), ]); - provider.add_package("c", 10.into(), &[], &["b 0..50"]); - provider.add_package("c", 8.into(), &[], &["b 0..50"]); + provider.add_package("c", 10.into(), &[], &["b 0..50"], &[]); + provider.add_package("c", 8.into(), &[], &["b 0..50"], &[]); let error = solve_unsat(provider, &["a", "c"]); insta::assert_snapshot!(error); } @@ -1095,8 +1193,8 @@ fn test_unsat_constrains_2() { ("b", 2, vec!["c 2"]), ]); - provider.add_package("c", 1.into(), &[], &["a 3"]); - provider.add_package("c", 2.into(), &[], &["a 3"]); + provider.add_package("c", 1.into(), &[], &["a 3"], &[]); + provider.add_package("c", 2.into(), &[], &["a 3"], &[]); let error = solve_unsat(provider, &["a"]); insta::assert_snapshot!(error); } @@ -1270,13 +1368,13 @@ fn test_solve_with_additional_with_constrains() { ("e", 1, vec!["c"]), ]); - provider.add_package("f", 1.into(), &[], &["c 2..3"]); - provider.add_package("g", 1.into(), &[], &["b 2..3"]); - provider.add_package("h", 1.into(), &[], &["b 1..2"]); - provider.add_package("i", 1.into(), &[], &[]); - provider.add_package("j", 1.into(), &["i"], &[]); - provider.add_package("k", 1.into(), &["i"], &[]); - provider.add_package("l", 1.into(), &["j", "k"], &[]); + provider.add_package("f", 1.into(), &[], &["c 2..3"], &[]); + provider.add_package("g", 1.into(), &[], &["b 2..3"], &[]); + provider.add_package("h", 1.into(), &[], &["b 1..2"], &[]); + provider.add_package("i", 1.into(), &[], &[], &[]); + provider.add_package("j", 1.into(), &["i"], &[], &[]); + provider.add_package("k", 1.into(), &["i"], &[], &[]); + provider.add_package("l", 1.into(), &["j", "k"], &[], &[]); let requirements = provider.requirements(&["a 0..10", "e"]); let constraints = provider.requirements(&["b 1..2", "c", "k 2..3"]); @@ -1429,6 +1527,364 @@ fn test_explicit_root_requirements() { "###); } +#[test] +#[traced_test] +fn test_conditional_requirements() { + let mut provider = BundleBoxProvider::new(); + + // Add packages + provider.add_package("a", 1.into(), &["b"], &[], &[]); // a depends on b + provider.add_package("b", 1.into(), &[], &[], &[]); // Simple package b + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + + // Create problem with both regular and conditional requirements + let requirements = provider.requirements(&["a", "c 1; if b 1..2"]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + "###); +} + +#[test] +#[traced_test] +fn test_conditional_requirements_not_met() { + let mut provider = BundleBoxProvider::new(); + provider.add_package("b", 1.into(), &[], &[], &[]); // Add b=1 as a candidate + provider.add_package("b", 2.into(), &[], &[], &[]); // Different version of b + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + provider.add_package("a", 1.into(), &["b 2"], &[], &[]); // a depends on b=2 specifically + + // Create conditional requirement: if b=1 is installed, require c + let requirements = provider.requirements(&[ + "a", // Require package a + "c 1; if b 1", // If b=1 is installed, require c (note the exact version) + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (not b=1), c should not be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + "###); +} + +#[test] +fn test_nested_conditional_dependencies() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[], &[]); // Base package + provider.add_package("b", 1.into(), &[], &[], &[]); // First level conditional + provider.add_package("c", 1.into(), &[], &[], &[]); // Second level conditional + provider.add_package("d", 1.into(), &[], &[], &[]); // Third level conditional + + // Create nested conditional requirements using the parser + let requirements = provider.requirements(&[ + "a", // Require package a + "b 1; if a 1", // If a is installed, require b + "c 1; if b 1", // If b is installed, require c + "d 1; if c 1", // If c is installed, require d + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // All packages should be installed due to chain reaction + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + d=1 + "###); +} + +#[test] +fn test_multiple_conditions_same_package() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[], &[]); + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("c", 1.into(), &[], &[], &[]); + provider.add_package("target", 1.into(), &[], &[], &[]); + + // Create multiple conditions that all require the same package using the parser + let requirements = provider.requirements(&[ + "b", // Only require package b + "target 1; if a 1", // If a is installed, require target + "target 1; if b 1", // If b is installed, require target + "target 1; if c 1", // If c is installed, require target + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // b and target should be installed + insta::assert_snapshot!(result, @r###" + b=1 + target=1 + "###); +} + +#[test] +fn test_circular_conditional_dependencies() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[], &[]); + provider.add_package("b", 1.into(), &[], &[], &[]); + + // Create circular conditional dependencies using the parser + let requirements = provider.requirements(&[ + "a", // Require package a + "b 1; if a 1", // If a is installed, require b + "a 1; if b 1", // If b is installed, require a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Both packages should be installed due to circular dependency + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + "###); +} + +#[test] +fn test_conditional_requirements_multiple_versions() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("b", 3.into(), &[], &[], &[]); + provider.add_package("b", 4.into(), &[], &[], &[]); + provider.add_package("b", 5.into(), &[], &[], &[]); + + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + provider.add_package("a", 1.into(), &["b 4..6"], &[], &[]); // a depends on b versions 4-5 + + // Create conditional requirement: if b=1..3 is installed, require c + let requirements = provider.requirements(&[ + "a", // Require package a + "c 1; if b 1..3", // If b is version 1-2, require c + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=4 is installed (not b 1..3), c should not be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=5 + "###); +} + +#[test] +fn test_conditional_requirements_multiple_versions_met() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("b", 3.into(), &[], &[], &[]); + provider.add_package("b", 4.into(), &[], &[], &[]); + provider.add_package("b", 5.into(), &[], &[], &[]); + + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + provider.add_package("c", 2.into(), &[], &[], &[]); // Version 2 of c + provider.add_package("c", 3.into(), &[], &[], &[]); // Version 3 of c + provider.add_package("a", 1.into(), &["b 1..3", "c 1..3; if b 1..3"], &[], &[]); // a depends on b 1-3 and conditionally on c 1-3 + + let requirements = provider.requirements(&[ + "a", // Require package a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (within b 1..2), c should be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + c=2 + "###); +} + +#[test] +fn test_conditional_requirements_conflict() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("b", 3.into(), &[], &[], &[]); + + // Package c has two versions with different dependencies + provider.add_package("c", 1.into(), &["d 1"], &[], &[]); // c v1 requires d v1 + provider.add_package("c", 2.into(), &["d 2"], &[], &[]); // c v2 requires d v2 + + // Package d has incompatible versions + provider.add_package("d", 1.into(), &[], &[], &[]); + provider.add_package("d", 2.into(), &[], &[], &[]); + + provider.add_package( + "a", + 1.into(), + &["b 1", "c 1; if b 1", "d 2", "c 2; if b 2"], + &[], + &[], + ); + + let requirements = provider.requirements(&[ + "a", // Require package a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + + // This should fail to solve because: + // 1. When b=1 is chosen, it triggers the conditional requirement for c 1 + // 2. c 1 requires d 1, but a requires d 2 + // 3. d 1 and d 2 cannot be installed together + + let solved = solver + .solve(problem) + .map_err(|e| match e { + UnsolvableOrCancelled::Unsolvable(conflict) => { + conflict.display_user_friendly(&solver).to_string() + } + UnsolvableOrCancelled::Cancelled(_) => "kir".to_string(), + }) + .unwrap_err(); + + assert_snapshot!(solved, @r" + The following packages are incompatible + └─ a * cannot be installed because there are no viable options: + └─ a 1 would require + ├─ b >=1, <2, which can be installed with any of the following options: + │ └─ b 1 + ├─ d >=2, <3, which can be installed with any of the following options: + │ └─ d 2 + └─ c >=1, <2, which cannot be installed because there are no viable options: + └─ c 1 would require + └─ d >=1, <2, which cannot be installed because there are no viable options: + └─ d 1, which conflicts with the versions reported above. + "); +} + +/// In this test, the resolver installs the highest available version of b which is b 2 +/// However, the conditional requirement is that if b 1..2 is installed, require c +/// Since b 2 is installed, c should not be installed +#[test] +fn test_conditional_requirements_multiple_versions_not_met() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("b", 3.into(), &[], &[], &[]); + provider.add_package("b", 4.into(), &[], &[], &[]); + provider.add_package("b", 5.into(), &[], &[], &[]); + + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + provider.add_package("c", 2.into(), &[], &[], &[]); // Version 2 of c + provider.add_package("c", 3.into(), &[], &[], &[]); // Version 3 of c + provider.add_package("a", 1.into(), &["b 1..3", "c 1..3; if b 1..2"], &[], &[]); // a depends on b 1-3 and conditionally on c 1-3 + + let requirements = provider.requirements(&[ + "a", // Require package a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (within b 1..2), c should be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + "###); +} + +#[test] +fn test_optional_dependencies() { + let mut provider = BundleBoxProvider::new(); + + // Add package a with base dependency on b and optional dependencies via features + provider.add_package( + "a", + 1.into(), + &["b 1"], + &[], + &[("feat1", &["c"]), ("feat2", &["d"])], + ); + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("c", 1.into(), &[], &[], &[]); + provider.add_package("d", 1.into(), &[], &[], &[]); + + // Request package a with both optional features enabled + let requirements = provider.requirements(&["a[feat2]"]); + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + d=1 + "###); +} + +#[test] +fn test_conditonal_requirements_with_extras() { + let mut provider = BundleBoxProvider::new(); + + // Package a has both optional dependencies (via features) and conditional dependencies + provider.add_package( + "a", + 1.into(), + &["b 1"], + &[], + &[("feat1", &["c"]), ("feat2", &["d"])], + ); + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("c", 1.into(), &[], &[], &[]); + provider.add_package("d", 1.into(), &[], &[], &[]); + provider.add_package("e", 1.into(), &[], &[], &[]); + + // Request package a with feat1 enabled, which will pull in c + // This should trigger the conditional requirement on e + let requirements = provider.requirements(&["a[feat1]", "e 1; if c 1"]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + e=1 + "###); +} + #[cfg(feature = "serde")] fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef) { let file = std::io::BufWriter::new(std::fs::File::create(destination.as_ref()).unwrap()); diff --git a/tools/solve-snapshot/src/main.rs b/tools/solve-snapshot/src/main.rs index 901996c..3629eaf 100644 --- a/tools/solve-snapshot/src/main.rs +++ b/tools/solve-snapshot/src/main.rs @@ -128,7 +128,8 @@ fn main() { let start = Instant::now(); - let problem = Problem::default().requirements(requirements); + let problem = + Problem::default().requirements(requirements.into_iter().map(Into::into).collect()); let mut solver = Solver::new(provider); let mut records = None; let mut error = None;