Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,48 @@ impl CpModelBuilder {
pub fn solve_with_parameters(&self, params: &proto::SatParameters) -> proto::CpSolverResponse {
ffi::solve_with_parameters(self.proto(), params)
}

/// Solves the model with the given
/// [parameters][proto::SatParameters],
/// a solution handler that is called with feasible solutions [proto::CpSolverResponse],
/// and returns the final [proto::CpSolverResponse].
///
/// The given function will be called on each improving feasible solution found
/// during the search. For a non-optimization problem, if the option
/// [proto::SatParameters::enumerate_all_solutions] to find all
/// solutions was set, then this will be called on each new solution.
///
/// Please note that it does not work in parallel
/// (i. e. parameter [proto::SatParameters::num_search_workers] > 1).
///
/// ```
/// # use std::cell::RefCell;
/// # use std::rc::Rc;
/// # use cp_sat::builder::CpModelBuilder;
/// # use cp_sat::proto::{SatParameters, CpSolverResponse};
/// let mut model = CpModelBuilder::default();
/// // linear constraint will only allow a = 2, a = 3 and a = 4
/// let a = model.new_int_var([(2, 7)]);
/// model.add_linear_constraint([(3, a)], [(0, 13)]);
/// let mut params = SatParameters::default();
/// params.enumerate_all_solutions = Some(true);
///
/// let memory = Rc::new(RefCell::new(Vec::new()));
/// let memory2 = memory.clone();
/// let handler = move |response: CpSolverResponse| {
/// memory2.borrow_mut().push(response);
/// };
///
/// let _response = model.solve_with_parameters_and_handler(&params, handler);
/// assert_eq!(3, memory.borrow().len());
/// ```
pub fn solve_with_parameters_and_handler(
&self,
params: &proto::SatParameters,
handler: impl FnMut(proto::CpSolverResponse) + 'static,
) -> proto::CpSolverResponse {
ffi::solve_with_parameters_and_handler(self.proto(), params, Box::new(handler))
}
}

/// Boolean variable identifier.
Expand Down
60 changes: 60 additions & 0 deletions src/cp_sat_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,66 @@ cp_sat_wrapper_solve(
return out_buf;
}

/**
* Solution handler that is called on every encountered solution.
*
* Arguments:
* - serialized buffer of a CpSolverResponse
* - length of the buffer
* - additional data passed from the outside
*/
typedef void (*solution_handler)(unsigned char*, size_t, void*);

/**
* Similar to cp_sat_wrapper_solve_with_parameters, but with a callback function
* for all encountered solutions.
*
* - handler: called on every solution
* - handler_data: additional data that is provided to the callback
*/
extern "C" unsigned char*
cp_sat_wrapper_solve_with_parameters_and_handler(
unsigned char* model_buf,
size_t model_size,
unsigned char* params_buf,
size_t params_size,
solution_handler handler,
void* handler_data,
size_t* out_size)
{
sat::Model extra_model;
sat::CpModelProto model;
bool res = model.ParseFromArray(model_buf, model_size);
assert(res);

sat::SatParameters params;
res = params.ParseFromArray(params_buf, params_size);
assert(res);

extra_model.Add(sat::NewSatParameters(params));

// local function that serializes the CpSolverResponse for the provided solution handler
auto wrapped_handler = [&](const operations_research::sat::CpSolverResponse& curr_response) {
// serialize CpSolverResponse
size_t response_size = curr_response.ByteSizeLong();
unsigned char* response_buf = (unsigned char*) malloc(response_size);
bool curr_res = curr_response.SerializeToArray(response_buf, response_size);
assert(curr_res);

handler(response_buf, response_size, handler_data);
};
extra_model.Add(sat::NewFeasibleSolutionObserver(wrapped_handler));

sat::CpSolverResponse response = sat::SolveCpModel(model, &extra_model);

*out_size = response.ByteSizeLong();
unsigned char* out_buf = (unsigned char*) malloc(*out_size);
res = response.SerializeToArray(out_buf, *out_size);
assert(res);

return out_buf;
}

extern "C" char*
cp_sat_wrapper_cp_model_stats(unsigned char* model_buf, size_t model_size) {
sat::CpModelProto model;
Expand Down
70 changes: 70 additions & 0 deletions src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::proto;
use libc::c_char;
use prost::Message;
use std::ffi::CStr;
use std::ffi::c_void;

extern "C" {
fn cp_sat_wrapper_solve(
Expand All @@ -16,6 +17,15 @@ extern "C" {
params_size: usize,
out_size: &mut usize,
) -> *mut u8;
fn cp_sat_wrapper_solve_with_parameters_and_handler(
model_buf: *const u8,
model_size: usize,
params_buf: *const u8,
params_size: usize,
handler_caller: extern "C" fn(*const u8, usize, *mut c_void),
handler: *mut c_void,
out_size: &mut usize,
) -> *mut u8;
fn cp_sat_wrapper_cp_model_stats(model_buf: *const u8, model_size: usize) -> *mut c_char;
fn cp_sat_wrapper_cp_solver_response_stats(
response_buf: *const u8,
Expand Down Expand Up @@ -72,6 +82,66 @@ pub fn solve_with_parameters(
response
}

/// User provided solution handler that is called with feasible solutions.
pub type SolutionHandler = Box<dyn FnMut(proto::CpSolverResponse)>;

/// Solves the given [CpModelProto][crate::proto::CpModelProto] with
/// the given parameters,
/// and calls the [SolutionHandler] on each improving feasible solution found
/// during the search. For a non-optimization problem, if the option
/// [proto::SatParameters.enumerate_all_solutions] to find all
/// solutions was set, then this will be called on each new solution.
///
/// Please note that it does not work in parallel
/// (i. e. parameter [proto::SatParameters::num_search_workers] > 1).
pub fn solve_with_parameters_and_handler(
model: &proto::CpModelProto,
params: &proto::SatParameters,
mut handler: SolutionHandler,
) -> proto::CpSolverResponse {
let mut model_buf = Vec::default();
model.encode(&mut model_buf).unwrap();
let mut params_buf = Vec::default();
params.encode(&mut params_buf).unwrap();

let mut out_size = 0;
let res = unsafe {
cp_sat_wrapper_solve_with_parameters_and_handler(
model_buf.as_ptr(),
model_buf.len(),
params_buf.as_ptr(),
params_buf.len(),
solution_handler_caller,
&mut handler as *mut _ as *mut c_void,
&mut out_size,
)
};
let out_slice = unsafe { std::slice::from_raw_parts(res, out_size) };
let response = proto::CpSolverResponse::decode(out_slice).unwrap();
unsafe { libc::free(res as _) };
response
}

/// Callback that is called from cpp code and transforms a buffered response to a
/// [proto::CpSolverResponse] that can be used by a [SolutionHandler].
///
/// # Arguments
/// - `response_buf` and `response_size`: buffer and size of a [proto::CpSolverResponse]
/// - `handler`: a user provided solution handler [SolutionHandler] that accepts a
/// [proto::CpSolverResponse]
extern "C" fn solution_handler_caller(response_buf: *const u8, response_size: usize, handler: *mut c_void) {
let response_slice = unsafe {
std::slice::from_raw_parts(response_buf, response_size)
};
let response = proto::CpSolverResponse::decode(response_slice).unwrap();
unsafe { libc::free(response_buf as _) };

unsafe {
let tmp = handler as *mut SolutionHandler;
(*tmp)(response);
}
}

/// Returns a string with some statistics on the given
/// [CpModelProto][crate::proto::CpModelProto].
pub fn cp_model_stats(model: &proto::CpModelProto) -> String {
Expand Down
63 changes: 63 additions & 0 deletions tests/solution_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::cell::RefCell;
use std::collections::HashSet;
use std::rc::Rc;
use cp_sat::builder::CpModelBuilder;
use cp_sat::proto::{SatParameters, CpSolverResponse};

/// In a non-optimization problem all feasible solutions should be found.
#[test]
fn enumeration_solution_handler() {
let mut model = CpModelBuilder::default();
// linear constraint will only allow a = 2, a = 3 and a = 4
let a = model.new_int_var([(2, 7)]);
model.add_linear_constraint([(3, a)], [(0, 13)]);
let mut params = SatParameters::default();
params.enumerate_all_solutions = Some(true);

let memory = Rc::new(RefCell::new(Vec::new()));
let memory2 = memory.clone();
let handler = move |response: CpSolverResponse| {
memory2.borrow_mut().push(response);
};

let _response = model.solve_with_parameters_and_handler(&params, handler);

assert_eq!(3, memory.borrow().len());

let expected = HashSet::from([2, 3, 4]);
let actual = memory
.borrow()
.iter()
.map(|response| a.solution_value(response))
.collect::<HashSet::<i64>>();

assert_eq!(expected, actual);
}

/// In an optimization problem at least one feasible solution should be found.
#[test]
fn optimization_solution_handler() {
let mut model = CpModelBuilder::default();
// linear constraint will only allow a = 2, a = 3 and a = 4
let a = model.new_int_var([(2, 7)]);
model.add_linear_constraint([(3, a)], [(0, 13)]);
model.minimize(a);
let mut params = SatParameters::default();
params.enumerate_all_solutions = Some(true);

let memory = Rc::new(RefCell::new(Vec::new()));
let memory2 = memory.clone();
let handler = move |response: CpSolverResponse| {
memory2.borrow_mut().push(response);
};

let response = model.solve_with_parameters_and_handler(&params, handler);

assert_eq!(2, a.solution_value(&response));

// At least one feasible solution is encountered.
// As we do not know how often the solution improves, or whether the first
// feasible solution is already the optimal one, we cannot expect more than one
// improvement.
assert!(memory.borrow().len() >= 1);
}