Skip to content

Commit

Permalink
Merge branch 'main' into nd60/new-solver-interface
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasdewally authored Feb 9, 2024
2 parents efb93c8 + 1278204 commit 874b6d1
Show file tree
Hide file tree
Showing 12 changed files with 400 additions and 112 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
coverage

conjure_oxide/tests/**/*.generated.*
conjure_oxide/tests/**/*.generated-parse.*
conjure_oxide/tests/**/*.generated-rewrite.*

*.profraw

## Rust
debug/
Expand Down
1 change: 1 addition & 0 deletions conjure_oxide/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod utils;
pub use conjure_core::ast; // re-export core::ast as conjure_oxide::ast
pub use conjure_core::ast::Model; // rexport core::ast::Model as conjure_oxide::Model
pub use conjure_core::solvers::Solver;
pub use rules::eval_constant;

pub use error::Error;

Expand Down
16 changes: 9 additions & 7 deletions conjure_oxide/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;

use serde_json::Value;

use crate::ast::{DecisionVariable, Domain, Expression, Model, Name, Range};
use crate::ast::{Constant, DecisionVariable, Domain, Expression, Model, Name, Range};
use crate::error::{Error, Result};
use serde_json::Value as JsonValue;

Expand Down Expand Up @@ -246,12 +246,14 @@ fn parse_vec_op(

fn parse_constant(constant: &serde_json::Map<String, Value>) -> Option<Expression> {
match &constant["Constant"] {
Value::Object(int) if int.contains_key("ConstantInt") => Some(Expression::ConstantInt(
int["ConstantInt"].as_array()?[1]
.as_i64()?
.try_into()
.unwrap(),
)),
Value::Object(int) if int.contains_key("ConstantInt") => {
Some(Expression::Constant(Constant::Int(
int["ConstantInt"].as_array()?[1]
.as_i64()?
.try_into()
.unwrap(),
)))
}
otherwise => panic!("Unhandled parse_constant {:#?}", otherwise),
}
}
26 changes: 10 additions & 16 deletions conjure_oxide/src/rules/base.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
use conjure_core::ast::Expression;
use conjure_core::{ast::Expression as Expr, rule::RuleApplicationError};
use conjure_core::{ast::Constant as Const, ast::Expression as Expr, rule::RuleApplicationError};
use conjure_rules::register_rule;

/*****************************************************************************/
/* This file contains basic rules for simplifying expressions */
/*****************************************************************************/

// #[register_rule]
// fn identity(expr: &Expr) -> Result<Expr, RuleApplicationError> {
// Ok(expr.clone())
// }

/**
* Remove nothings from expressions:
* ```text
Expand Down Expand Up @@ -43,15 +37,15 @@ fn remove_nothings(expr: &Expr) -> Result<Expr, RuleApplicationError> {
}

match expr {
Expr::And(_) | Expr::Or(_) | Expression::Sum(_) => match expr.sub_expressions() {
Expr::And(_) | Expr::Or(_) | Expr::Sum(_) => match expr.sub_expressions() {
None => Err(RuleApplicationError::RuleNotApplicable),
Some(sub) => {
let new_sub = remove_nothings(sub)?;
let new_expr = expr.with_sub_expressions(new_sub);
Ok(new_expr)
}
},
Expression::SumEq(_, _) | Expression::SumLeq(_, _) | Expression::SumGeq(_, _) => {
Expr::SumEq(_, _) | Expr::SumLeq(_, _) | Expr::SumGeq(_, _) => {
match expr.sub_expressions() {
None => Err(RuleApplicationError::RuleNotApplicable),
Some(sub) => {
Expand Down Expand Up @@ -108,7 +102,7 @@ fn sum_constants(expr: &Expr) -> Result<Expr, RuleApplicationError> {
let mut changed = false;
for e in exprs {
match e {
Expr::ConstantInt(i) => {
Expr::Constant(Const::Int(i)) => {
sum += i;
changed = true;
}
Expand All @@ -118,7 +112,7 @@ fn sum_constants(expr: &Expr) -> Result<Expr, RuleApplicationError> {
if !changed {
return Err(RuleApplicationError::RuleNotApplicable);
}
new_exprs.push(Expr::ConstantInt(sum));
new_exprs.push(Expr::Constant(Const::Int(sum)));
Ok(Expr::Sum(new_exprs)) // Let other rules handle only one Expr being contained in the sum
}
_ => Err(RuleApplicationError::RuleNotApplicable),
Expand Down Expand Up @@ -308,10 +302,10 @@ fn remove_constants_from_or(expr: &Expr) -> Result<Expr, RuleApplicationError> {
let mut changed = false;
for e in exprs {
match e {
Expr::ConstantBool(val) => {
Expr::Constant(Const::Bool(val)) => {
if *val {
// If we find a true, the whole expression is true
return Ok(Expr::ConstantBool(true));
return Ok(Expr::Constant(Const::Bool(true)));
} else {
// If we find a false, we can ignore it
changed = true;
Expand Down Expand Up @@ -344,10 +338,10 @@ fn remove_constants_from_and(expr: &Expr) -> Result<Expr, RuleApplicationError>
let mut changed = false;
for e in exprs {
match e {
Expr::ConstantBool(val) => {
Expr::Constant(Const::Bool(val)) => {
if !*val {
// If we find a false, the whole expression is false
return Ok(Expr::ConstantBool(false));
return Ok(Expr::Constant(Const::Bool(false)));
} else {
// If we find a true, we can ignore it
changed = true;
Expand Down Expand Up @@ -376,7 +370,7 @@ fn remove_constants_from_and(expr: &Expr) -> Result<Expr, RuleApplicationError>
fn evaluate_constant_not(expr: &Expr) -> Result<Expr, RuleApplicationError> {
match expr {
Expr::Not(contents) => match contents.as_ref() {
Expr::ConstantBool(val) => Ok(Expr::ConstantBool(!val)),
Expr::Constant(Const::Bool(val)) => Ok(Expr::Constant(Const::Bool(!val))),
_ => Err(RuleApplicationError::RuleNotApplicable),
},
_ => Err(RuleApplicationError::RuleNotApplicable),
Expand Down
112 changes: 112 additions & 0 deletions conjure_oxide/src/rules/constant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use conjure_core::{ast::Constant as Const, ast::Expression as Expr, rule::RuleApplicationError};
use conjure_rules::register_rule;

#[register_rule]
fn apply_eval_constant(expr: &Expr) -> Result<Expr, RuleApplicationError> {
if expr.is_constant() {
return Err(RuleApplicationError::RuleNotApplicable);
}
let res = eval_constant(expr)
.map(Expr::Constant)
.ok_or(RuleApplicationError::RuleNotApplicable);
res
}

/// Simplify an expression to a constant if possible
/// Returns:
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
/// `Some(Const)` if the expression can be simplified to a constant
pub fn eval_constant(expr: &Expr) -> Option<Const> {
match expr {
Expr::Constant(c) => Some(c.clone()),
Expr::Reference(_) => None,

Expr::Eq(a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
.or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
.map(Const::Bool),
Expr::Neq(a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Const::Bool),
Expr::Lt(a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Const::Bool),
Expr::Gt(a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Const::Bool),
Expr::Leq(a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Const::Bool),
Expr::Geq(a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Const::Bool),

Expr::Not(expr) => un_op::<bool, bool>(|e| !e, expr).map(Const::Bool),

Expr::And(exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Const::Bool),
Expr::Or(exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Const::Bool),

Expr::Sum(exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Const::Int),

Expr::Ineq(a, b, c) => {
tern_op::<i32, bool>(|a, b, c| a <= (b + c), a, b, c).map(Const::Bool)
}

Expr::SumGeq(exprs, a) => {
flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() >= a, exprs, a).map(Const::Bool)
}
Expr::SumLeq(exprs, a) => {
flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() <= a, exprs, a).map(Const::Bool)
}
_ => {
println!("WARNING: Unimplemented constant eval: {:?}", expr);
None
}
}
}

fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
where
T: TryFrom<Const>,
{
let a = unwrap_expr::<T>(a)?;
Some(f(a))
}

fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
where
T: TryFrom<Const>,
{
let a = unwrap_expr::<T>(a)?;
let b = unwrap_expr::<T>(b)?;
Some(f(a, b))
}

fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
where
T: TryFrom<Const>,
{
let a = unwrap_expr::<T>(a)?;
let b = unwrap_expr::<T>(b)?;
let c = unwrap_expr::<T>(c)?;
Some(f(a, b, c))
}

fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &Vec<Expr>) -> Option<A>
where
T: TryFrom<Const>,
{
let a = a
.iter()
.map(unwrap_expr)
.into_iter()
.collect::<Option<Vec<T>>>()?;
Some(f(a))
}

fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &Vec<Expr>, b: &Expr) -> Option<A>
where
T: TryFrom<Const>,
{
let a = a
.iter()
.map(unwrap_expr)
.into_iter()
.collect::<Option<Vec<T>>>()?;
let b = unwrap_expr::<T>(b)?;
Some(f(a, b))
}

fn unwrap_expr<T: TryFrom<Const>>(expr: &Expr) -> Option<T> {
let c = eval_constant(expr)?;
TryInto::<T>::try_into(c).ok()
}
14 changes: 9 additions & 5 deletions conjure_oxide/src/rules/minion.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use conjure_core::{ast::Expression as Expr, rule::RuleApplicationError};
use conjure_core::{ast::Constant as Const, ast::Expression as Expr, rule::RuleApplicationError};
use conjure_rules::register_rule;

/************************************************************************/
/* Rules for translating to Minion-supported constraints */
/************************************************************************/

fn is_nested_sum(exprs: &Vec<Expr>) -> bool {
for e in exprs {
if let Expr::Sum(_) = e {
Expand Down Expand Up @@ -116,7 +120,7 @@ fn lt_to_ineq(expr: &Expr) -> Result<Expr, RuleApplicationError> {
Expr::Lt(a, b) => Ok(Expr::Ineq(
a.clone(),
b.clone(),
Box::new(Expr::ConstantInt(-1)),
Box::new(Expr::Constant(Const::Int(-1))),
)),
_ => Err(RuleApplicationError::RuleNotApplicable),
}
Expand All @@ -135,7 +139,7 @@ fn gt_to_ineq(expr: &Expr) -> Result<Expr, RuleApplicationError> {
Expr::Gt(a, b) => Ok(Expr::Ineq(
b.clone(),
a.clone(),
Box::new(Expr::ConstantInt(-1)),
Box::new(Expr::Constant(Const::Int(-1))),
)),
_ => Err(RuleApplicationError::RuleNotApplicable),
}
Expand All @@ -154,7 +158,7 @@ fn geq_to_ineq(expr: &Expr) -> Result<Expr, RuleApplicationError> {
Expr::Geq(a, b) => Ok(Expr::Ineq(
b.clone(),
a.clone(),
Box::new(Expr::ConstantInt(0)),
Box::new(Expr::Constant(Const::Int(0))),
)),
_ => Err(RuleApplicationError::RuleNotApplicable),
}
Expand All @@ -173,7 +177,7 @@ fn leq_to_ineq(expr: &Expr) -> Result<Expr, RuleApplicationError> {
Expr::Leq(a, b) => Ok(Expr::Ineq(
a.clone(),
b.clone(),
Box::new(Expr::ConstantInt(0)),
Box::new(Expr::Constant(Const::Int(0))),
)),
_ => Err(RuleApplicationError::RuleNotApplicable),
}
Expand Down
3 changes: 3 additions & 0 deletions conjure_oxide/src/rules/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
mod base;
mod cnf;
mod constant;
mod minion;

pub use constant::eval_constant;
9 changes: 5 additions & 4 deletions conjure_oxide/src/solvers/minion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use super::{FromConjureModel, SolverError};
use crate::Solver;

use crate::ast::{
DecisionVariable, Domain as ConjureDomain, Expression as ConjureExpression,
Model as ConjureModel, Name as ConjureName, Range as ConjureRange,
Constant as ConjureConstant, DecisionVariable, Domain as ConjureDomain,
Expression as ConjureExpression, Model as ConjureModel, Name as ConjureName,
Range as ConjureRange,
};
pub use minion_rs::ast::Model as MinionModel;
use minion_rs::ast::{
Expand Down Expand Up @@ -198,7 +199,7 @@ fn must_be_ref(e: ConjureExpression) -> Result<String, SolverError> {

fn must_be_const(e: ConjureExpression) -> Result<i32, SolverError> {
match e {
ConjureExpression::ConstantInt(n) => Ok(n),
ConjureExpression::Constant(ConjureConstant::Int(n)) => Ok(n),
x => Err(SolverError::InvalidInstance(
SOLVER,
format!("expected a constant, but got `{0:?}`", x),
Expand Down Expand Up @@ -238,7 +239,7 @@ mod tests {
let x = ConjureExpression::Reference(ConjureName::UserName("x".to_owned()));
let y = ConjureExpression::Reference(ConjureName::UserName("y".to_owned()));
let z = ConjureExpression::Reference(ConjureName::UserName("z".to_owned()));
let four = ConjureExpression::ConstantInt(4);
let four = ConjureExpression::Constant(ConjureConstant::Int(4));

let geq = ConjureExpression::SumGeq(
vec![x.to_owned(), y.to_owned(), z.to_owned()],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
]
},
{
"ConstantInt": 4
"Constant": {
"Int": 4
}
}
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
}
],
{
"ConstantInt": 4
"Constant": {
"Int": 4
}
}
]
},
Expand All @@ -45,7 +47,9 @@
}
],
{
"ConstantInt": 4
"Constant": {
"Int": 4
}
}
]
},
Expand All @@ -62,7 +66,9 @@
}
},
{
"ConstantInt": 0
"Constant": {
"Int": 0
}
}
]
}
Expand Down
Loading

0 comments on commit 874b6d1

Please sign in to comment.