From 05ce4f94fa66ca35a4ed2aed98acc2689ff39b7d Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Thu, 8 Feb 2024 14:49:11 +0100 Subject: [PATCH] feat: runtime agnostic impl --- Cargo.toml | 13 ++-- src/lib.rs | 30 +------- src/problem.rs | 11 +-- src/runtime.rs | 81 +++++++++++++++++++++ src/solver/cache.rs | 10 +-- src/solver/mod.rs | 171 ++++++++++---------------------------------- tests/solver.rs | 22 +++--- 7 files changed, 154 insertions(+), 184 deletions(-) create mode 100644 src/runtime.rs diff --git a/Cargo.toml b/Cargo.toml index d6ad6e3..c69b739 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "resolvo" version = "0.3.0" -authors = ["Adolfo Ochagavía ", "Bas Zalmstra ", "Tim de Jager " ] +authors = ["Adolfo Ochagavía ", "Bas Zalmstra ", "Tim de Jager "] description = "Fast package resolver written in Rust (CDCL based SAT solving)" keywords = ["dependency", "solver", "version"] categories = ["algorithms"] @@ -10,20 +10,25 @@ repository = "https://github.com/mamba-org/resolvo" license = "BSD-3-Clause" edition = "2021" readme = "README.md" +resolver = "2" [dependencies] -itertools = "0.11.0" +itertools = "0.12.1" petgraph = "0.6.4" tracing = "0.1.37" elsa = "1.9.0" bitvec = "1.0.1" serde = { version = "1.0", features = ["derive"], optional = true } futures = { version = "0.3.30", default-features = false, features = ["alloc"] } -tokio = { version = "1.35.1", features = ["rt", "sync"] } +event-listener = "5.0.0" + +tokio = { version = "1.35.1", features = ["rt"], optional = true } +async-std = { version = "1.12.0", default-features = false, features = ["alloc", "default"], optional = true } [dev-dependencies] insta = "1.31.0" indexmap = "2.0.0" proptest = "1.2.0" tracing-test = { version = "0.2.4", features = ["no-env-filter"] } -tokio = { version = "1.35.1", features = ["time"] } +tokio = { version = "1.35.1", features = ["time", "rt"] } +resolvo = { path = ".", features = ["tokio"] } \ No newline at end of file diff --git 
a/src/lib.rs b/src/lib.rs index 6303f6d..e520eda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub(crate) mod internal; mod pool; pub mod problem; pub mod range; +pub mod runtime; mod solvable; mod solver; @@ -33,7 +34,7 @@ use std::{ rc::Rc, }; -/// The solver is based around the fact that for for every package name we are trying to find a +/// The solver is based around the fact that for every package name we are trying to find a /// single variant. Variants are grouped by their respective package name. A package name is /// anything that we can compare and hash for uniqueness checks. /// @@ -45,7 +46,7 @@ pub trait PackageName: Eq + Hash {} impl PackageName for N {} -/// A [`VersionSet`] is describes a set of "versions". The trait defines whether a given version +/// A [`VersionSet`] describes a set of "versions". The trait defines whether a given version /// is part of the set or not. /// /// One could implement [`VersionSet`] for [`std::ops::Range`] where the implementation @@ -67,15 +68,6 @@ pub trait DependencyProvider: Sized { /// Sort the specified solvables based on which solvable to try first. The solver will /// iteratively try to select the highest version. If a conflict is found with the highest /// version the next version is tried. This continues until a solution is found. - /// - /// # Async - /// - /// The returned future will be awaited by a tokio runtime blocking the main thread. You are - /// free to use other runtimes in your implementation, as long as the runtime-specific code runs - /// in threads controlled by that runtime (and _not_ in the main thread). For instance, you can - /// use `async_std::task::spawn` to spawn a new task, use `async_std::io` inside the task to - /// retrieve necessary information from the network, and `await` the returned task handle. 
- #[allow(async_fn_in_trait)] async fn sort_candidates( &self, @@ -85,26 +77,10 @@ pub trait DependencyProvider: Sized { /// Obtains a list of solvables that should be considered when a package with the given name is /// requested. - /// - /// # Async - /// - /// The returned future will be awaited by a tokio runtime blocking the main thread. You are - /// free to use other runtimes in your implementation, as long as the runtime-specific code runs - /// in threads controlled by that runtime (and _not_ in the main thread). For instance, you can - /// use `async_std::task::spawn` to spawn a new task, use `async_std::io` inside the task to - /// retrieve necessary information from the network, and `await` the returned task handle. #[allow(async_fn_in_trait)] async fn get_candidates(&self, name: NameId) -> Option; /// Returns the dependencies for the specified solvable. - /// - /// # Async - /// - /// The returned future will be awaited by a tokio runtime blocking the main thread. You are - /// free to use other runtimes in your implementation, as long as the runtime-specific code runs - /// in threads controlled by that runtime (and _not_ in the main thread). For instance, you can - /// use `async_std::task::spawn` to spawn a new task, use `async_std::io` inside the task to - /// retrieve necessary information from the network, and `await` the returned task handle. 
#[allow(async_fn_in_trait)] async fn get_dependencies(&self, solvable: SolvableId) -> Dependencies; diff --git a/src/problem.rs b/src/problem.rs index 7dc1870..1460a22 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -11,10 +11,10 @@ use petgraph::graph::{DiGraph, EdgeIndex, EdgeReference, NodeIndex}; use petgraph::visit::{Bfs, DfsPostOrder, EdgeRef}; use petgraph::Direction; -use crate::internal::id::StringId; use crate::{ - internal::id::{ClauseId, SolvableId, VersionSetId}, + internal::id::{ClauseId, SolvableId, StringId, VersionSetId}, pool::Pool, + runtime::AsyncRuntime, solver::{clause::Clause, Solver}, DependencyProvider, PackageName, SolvableDisplay, VersionSet, }; @@ -40,9 +40,9 @@ impl Problem { } /// Generates a graph representation of the problem (see [`ProblemGraph`] for details) - pub fn graph>( + pub fn graph, RT: AsyncRuntime>( &self, - solver: &Solver, + solver: &Solver, ) -> ProblemGraph { let mut graph = DiGraph::::default(); let mut nodes: HashMap = HashMap::default(); @@ -158,9 +158,10 @@ impl Problem { N: PackageName + Display, D: DependencyProvider, M: SolvableDisplay, + RT: AsyncRuntime, >( &self, - solver: &'a Solver, + solver: &'a Solver, pool: Rc>, merged_solvable_display: &'a M, ) -> DisplayUnsat<'a, VS, N, M> { diff --git a/src/runtime.rs b/src/runtime.rs new file mode 100644 index 0000000..31b5ea9 --- /dev/null +++ b/src/runtime.rs @@ -0,0 +1,81 @@ +//! Solving in resolvo is a compute-heavy operation. However, while computing, the solver will +//! request additional information from the [`crate::DependencyProvider`], and a dependency provider +//! might want to perform multiple requests concurrently. To that end, the methods of +//! [`crate::DependencyProvider`] are async. The implementer can implement the async +//! operations in any way they choose, including with any runtime they choose. +//! The solver itself, however, is completely single-threaded, yet it still has to await the calls to +//! the dependency provider. 
Using the [`AsyncRuntime`] trait allows the caller of the solver to choose +//! how to await the futures. +//! +//! By default, the solver uses the [`NowOrNeverRuntime`] runtime, which polls any future once. If +//! the future yields (thus requiring an additional poll), the runtime panics. If the methods of +//! [`crate::DependencyProvider`] do not yield (e.g. do not `.await`), this will suffice. +//! +//! Only if the [`crate::DependencyProvider`] implementation yields will you need to provide an +//! [`AsyncRuntime`] to the solver. +//! +//! ## `tokio` +//! +//! The solver can use tokio to await the results of async methods in [`crate::DependencyProvider`]. It +//! will run them concurrently, while blocking the main thread. That means that a single-threaded +//! tokio runtime is usually enough. It is also possible to use a different runtime, as long as you +//! avoid mixing incompatible futures. +//! +//! The [`AsyncRuntime`] trait is implemented both for [`tokio::runtime::Handle`] and for +//! [`tokio::runtime::Runtime`]. +//! +//! ## `async-std` +//! +//! Use the [`AsyncStdRuntime`] struct to block on async methods from the +//! [`crate::DependencyProvider`] using the `async-std` executor. + +use futures::FutureExt; +use std::future::Future; + +/// A trait to wrap an async runtime. +pub trait AsyncRuntime { + /// Runs the given future on the current thread, blocking until it is complete, and yielding its + /// resolved result. + fn block_on(&self, f: F) -> F::Output; +} + +/// The simplest runtime possible: it evaluates and consumes the future, returning the resulting +/// output if the future is ready after the first call to [`Future::poll`]. If the future does +/// yield, the runtime panics. +/// +/// This assumes that the passed-in future never yields. For purely blocking computations this +/// is the preferred method since it also incurs very little overhead and doesn't require the +/// inclusion of a heavy-weight runtime. 
+#[derive(Default, Copy, Clone)] +pub struct NowOrNeverRuntime; + +impl AsyncRuntime for NowOrNeverRuntime { + fn block_on(&self, f: F) -> F::Output { + f.now_or_never() + .expect("can only use non-yielding futures with the NowOrNeverRuntime") + } +} + +#[cfg(feature = "tokio")] +impl AsyncRuntime for tokio::runtime::Handle { + fn block_on(&self, f: F) -> F::Output { + self.block_on(f) + } +} + +#[cfg(feature = "tokio")] +impl AsyncRuntime for tokio::runtime::Runtime { + fn block_on(&self, f: F) -> F::Output { + self.block_on(f) + } +} + +#[cfg(feature = "async-std")] +pub struct AsyncStdRuntime; + +#[cfg(feature = "async-std")] +impl AsyncRuntime for AsyncStdRuntime { + fn block_on(&self, f: F) -> F::Output { + async_std::task::block_on(f) + } +} diff --git a/src/solver/cache.rs b/src/solver/cache.rs index 82a3501..eed36db 100644 --- a/src/solver/cache.rs +++ b/src/solver/cache.rs @@ -10,8 +10,8 @@ use crate::{ }; use bitvec::vec::BitVec; use elsa::FrozenMap; +use event_listener::Event; use std::{any::Any, cell::RefCell, collections::HashMap, marker::PhantomData, rc::Rc}; -use tokio::sync::Notify; /// Keeps a cache of previously computed and/or requested information about solvables and version /// sets. @@ -21,7 +21,7 @@ pub struct SolverCache, package_name_to_candidates: FrozenCopyMap, - package_name_to_candidates_in_flight: RefCell>>, + package_name_to_candidates_in_flight: RefCell>>, /// A mapping of `VersionSetId` to the candidates that match that set. version_set_candidates: FrozenMap>, @@ -99,7 +99,7 @@ impl> SolverCache { // Found an in-flight request, wait for that request to finish and return the computed result. 
- in_flight.notified().await; + in_flight.listen().await; self.package_name_to_candidates .get_copy(&package_name) .expect("after waiting for a request the result should be available") @@ -108,7 +108,7 @@ impl> SolverCache> SolverCache> { +pub struct Solver< + VS: VersionSet, + N: PackageName, + D: DependencyProvider, + RT: AsyncRuntime = NowOrNeverRuntime, +> { /// The [Pool] used by the solver pub pool: Rc>, - pub(crate) async_runtime: tokio::runtime::Runtime, + pub(crate) async_runtime: RT, pub(crate) cache: SolverCache, pub(crate) clauses: RefCell>, @@ -69,22 +75,16 @@ pub struct Solver> root_requirements: Vec, } -impl> Solver { +impl> + Solver +{ /// Create a solver, using the provided pool and async runtime. - /// - /// # Async runtime - /// - /// The solver uses tokio to await the results of async methods in [DependencyProvider]. It will - /// run them concurrently, but blocking the main thread. That means that a single-threaded tokio - /// runtime is usually enough. It is also possible to use a different runtime, as long as you - /// avoid mixing incompatible futures. For details, check out the documentation for the async - /// methods of [DependencyProvider]. - pub fn new(provider: D, async_runtime: tokio::runtime::Runtime) -> Self { + pub fn new(provider: D) -> Self { let pool = provider.pool(); Self { cache: SolverCache::new(provider), pool, - async_runtime, + async_runtime: NowOrNeverRuntime, clauses: RefCell::new(Arena::new()), requires_clauses: Default::default(), watches: WatchMap::new(), @@ -98,19 +98,6 @@ impl> Solver Self { - Self::new( - provider, - tokio::runtime::Builder::new_current_thread() - .build() - .unwrap(), - ) - } } /// The root cause of a solver error. @@ -140,7 +127,29 @@ pub(crate) enum PropagationError { Cancelled(Box), } -impl> Solver { +impl, RT: AsyncRuntime> + Solver +{ + /// Set the runtime of the solver to `runtime`. 
+ pub fn with_runtime(self, runtime: RT2) -> Solver { + Solver { + pool: self.pool, + async_runtime: runtime, + cache: self.cache, + clauses: self.clauses, + requires_clauses: self.requires_clauses, + watches: self.watches, + negative_assertions: self.negative_assertions, + learnt_clauses: self.learnt_clauses, + learnt_why: self.learnt_why, + learnt_clause_ids: self.learnt_clause_ids, + clauses_added_for_package: self.clauses_added_for_package, + clauses_added_for_solvable: self.clauses_added_for_solvable, + decision_tracker: self.decision_tracker, + root_requirements: self.root_requirements, + } + } + /// Solves the provided `jobs` and returns a transaction from the found solution /// /// Returns a [`Problem`] if no solution was found, which provides ways to inspect the causes @@ -207,7 +216,7 @@ impl> Sol NonMatchingCandidates { solvable_id: SolvableId, version_set_id: VersionSetId, - non_matching_candidates: Vec, + non_matching_candidates: &'i [SolvableId], }, Candidates { name_id: NameId, @@ -351,8 +360,7 @@ impl> Sol let non_matching_candidates = self .cache .get_or_cache_non_matching_candidates(version_set_id) - .await? - .to_vec(); + .await?; Ok(TaskResult::NonMatchingCandidates { solvable_id, version_set_id, @@ -497,7 +505,7 @@ impl> Sol ); // Add forbidden clauses for the candidates - for forbidden_candidate in non_matching_candidates { + for &forbidden_candidate in non_matching_candidates { let (clause, conflict) = ClauseState::constrains( solvable_id, forbidden_candidate, @@ -519,107 +527,6 @@ impl> Sol Ok(output) } - // /// Adds all clauses for a specific package name. - // /// - // /// These clauses include: - // /// - // /// 1. making sure that only a single candidate for the package is selected (forbid multiple) - // /// 2. if there is a locked candidate then that candidate is the only selectable candidate. - // /// - // /// If this function is called with the same package name twice, the clauses will only be added - // /// once. 
- // /// - // /// There is no need to propagate after adding these clauses because none of the clauses are - // /// assertions (only a single literal) and we assume that no decision has been made about any - // /// of the solvables involved. This assumption is checked when debug_assertions are enabled. - // /// - // /// If the provider has requested the solving process to be cancelled, the cancellation value - // /// will be returned as an `Err(...)`. - // async fn add_clauses_for_package( - // &self, - // output: &RefCell, - // package_name: NameId, - // ) -> Result<(), Box> { - // let mutex = { - // let mut clauses = self.clauses_added_for_package.borrow_mut(); - // let mutex = clauses - // .entry(package_name) - // .or_insert_with(|| Rc::new(tokio::sync::Mutex::new(false))); - // mutex.clone() - // }; - // - // // This prevents concurrent calls to `add_clauses_for_package` from racing. Only the first - // // call for a given package will go through, and others will wait till it completes. - // let mut clauses_added = mutex.lock().await; - // if *clauses_added { - // return Ok(()); - // } - // - // tracing::trace!( - // "┝━ adding clauses for package '{}'", - // self.pool.resolve_package_name(package_name) - // ); - // - // let package_candidates = self.cache.get_or_cache_candidates(package_name).await?; - // let locked_solvable_id = package_candidates.locked; - // let candidates = &package_candidates.candidates; - // - // // Check the assumption that no decision has been made about any of the solvables. - // for &candidate in candidates { - // debug_assert!( - // self.decision_tracker.assigned_value(candidate).is_none(), - // "a decision has been made about a candidate of a package that was not properly added yet." - // ); - // } - // - // let mut output = output.borrow_mut(); - // - // // Each candidate gets a clause to disallow other candidates. - // for (i, &candidate) in candidates.iter().enumerate() { - // for &other_candidate in &candidates[i + 1..] 
{ - // let clause_id = self - // .clauses - // .borrow_mut() - // .alloc(ClauseState::forbid_multiple(candidate, other_candidate)); - // - // debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); - // output.clauses_to_watch.push(clause_id); - // } - // } - // - // // If there is a locked solvable, forbid other solvables. - // if let Some(locked_solvable_id) = locked_solvable_id { - // for &other_candidate in candidates { - // if other_candidate != locked_solvable_id { - // let clause_id = self - // .clauses - // .borrow_mut() - // .alloc(ClauseState::lock(locked_solvable_id, other_candidate)); - // - // debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); - // output.clauses_to_watch.push(clause_id); - // } - // } - // } - // - // // Add a clause for solvables that are externally excluded. - // for (solvable, reason) in package_candidates.excluded.iter().copied() { - // let clause_id = self - // .clauses - // .borrow_mut() - // .alloc(ClauseState::exclude(solvable, reason)); - // - // // Exclusions are negative assertions, tracked outside of the watcher system - // output.negative_assertions.push((solvable, clause_id)); - // - // // Conflicts should be impossible here - // debug_assert!(self.decision_tracker.assigned_value(solvable) != Some(true)); - // } - // - // *clauses_added = true; - // Ok(()) - // } - /// Run the CDCL algorithm to solve the SAT problem /// /// The CDCL algorithm's job is to find a valid assignment to the variables involved in the diff --git a/tests/solver.rs b/tests/solver.rs index 23c685d..7062dfb 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -404,7 +404,7 @@ fn transaction_to_string(pool: &Pool, solvables: &Vec String { let requirements = provider.requirements(specs); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); match solver.solve(requirements) { Ok(_) => panic!("expected unsat, but a solution was found"), 
Err(UnsolvableOrCancelled::Unsolvable(problem)) => { @@ -436,7 +436,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { let requirements = provider.requirements(specs); let pool = provider.pool(); - let mut solver = Solver::new(provider, runtime); + let mut solver = Solver::new(provider).with_runtime(runtime); match solver.solve(requirements) { Ok(solvables) => transaction_to_string(&pool, &solvables), Err(UnsolvableOrCancelled::Unsolvable(problem)) => { @@ -462,7 +462,7 @@ fn test_unit_propagation_1() { let provider = BundleBoxProvider::from_packages(&[("asdf", 1, vec![])]); let root_requirements = provider.requirements(&["asdf"]); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(root_requirements).unwrap(); assert_eq!(solved.len(), 1); @@ -482,7 +482,7 @@ fn test_unit_propagation_nested() { ]); let requirements = provider.requirements(&["asdf"]); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 2); @@ -509,7 +509,7 @@ fn test_resolve_multiple() { ]); let requirements = provider.requirements(&["asdf", "efgh"]); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 2); @@ -567,7 +567,7 @@ fn test_resolve_with_nonexisting() { ]); let requirements = provider.requirements(&["asdf"]); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); @@ -601,7 +601,7 @@ fn test_resolve_with_nested_deps() { ]); let requirements = provider.requirements(&["apache-airflow"]); let pool 
= provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); @@ -628,7 +628,7 @@ fn test_resolve_with_unknown_deps() { provider.add_package("opentelemetry-api", Pack::new(2), &[], &[]); let requirements = provider.requirements(&["opentelemetry-api"]); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); @@ -673,7 +673,7 @@ fn test_resolve_locked_top_level() { let requirements = provider.requirements(&["asdf"]); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); @@ -694,7 +694,7 @@ fn test_resolve_ignored_locked_top_level() { let requirements = provider.requirements(&["asdf"]); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); @@ -752,7 +752,7 @@ fn test_resolve_cyclic() { BundleBoxProvider::from_packages(&[("a", 2, vec!["b 0..10"]), ("b", 5, vec!["a 2..4"])]); let requirements = provider.requirements(&["a 0..100"]); let pool = provider.pool(); - let mut solver = Solver::new_with_default_runtime(provider); + let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); let result = transaction_to_string(&pool, &solved);