Skip to content

Commit

Permalink
Merge pull request #442 from Alex-Fischman/sort-names
Browse files Browse the repository at this point in the history
Sort declaration cleanup
  • Loading branch information
Alex-Fischman authored Oct 16, 2024
2 parents 43de12f + 8ae1427 commit 5d637f2
Show file tree
Hide file tree
Showing 19 changed files with 217 additions and 284 deletions.
2 changes: 1 addition & 1 deletion src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ impl EGraph {
let ts = self.timestamp;
let out = &function.schema.output;
match function.decl.default.as_ref() {
None if out.name() == UNIT_SYM.into() => {
None if out.name() == UnitSort.name() => {
function.insert(values, Value::unit(), ts);
Value::unit()
}
Expand Down
4 changes: 2 additions & 2 deletions src/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ pub enum GenericExpr<Head, Leaf> {
}

impl ResolvedExpr {
pub fn output_type(&self, type_info: &TypeInfo) -> ArcSort {
pub fn output_type(&self) -> ArcSort {
match self {
ResolvedExpr::Lit(_, lit) => type_info.infer_literal(lit),
ResolvedExpr::Lit(_, lit) => sort::literal_sort(lit),
ResolvedExpr::Var(_, resolved_var) => resolved_var.sort.clone(),
ResolvedExpr::Call(_, resolved_call, _) => resolved_call.output().clone(),
}
Expand Down
18 changes: 5 additions & 13 deletions src/ast/remove_globals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
//! When a globally-bound primitive value is used in the actions of a rule,
//! we add a new variable to the query bound to the primitive value.
use crate::{
core::ResolvedCall, typechecking::FuncType, FreshGen, GenericAction, GenericActions,
GenericExpr, GenericFact, GenericNCommand, GenericRule, HashMap, ResolvedAction, ResolvedExpr,
ResolvedFact, ResolvedFunctionDecl, ResolvedNCommand, ResolvedVar, Schema, SymbolGen, TypeInfo,
};
use crate::*;
use crate::{core::ResolvedCall, typechecking::FuncType};

struct GlobalRemover<'a> {
fresh: &'a mut SymbolGen,
Expand Down Expand Up @@ -45,13 +42,12 @@ struct GlobalRemover<'a> {
/// ((Add fresh_var_for_x fresh_var_for_x)))
/// ```
pub(crate) fn remove_globals(
type_info: &TypeInfo,
prog: Vec<ResolvedNCommand>,
fresh: &mut SymbolGen,
) -> Vec<ResolvedNCommand> {
let mut remover = GlobalRemover { fresh };
prog.into_iter()
.flat_map(|cmd| remover.remove_globals_cmd(type_info, cmd))
.flat_map(|cmd| remover.remove_globals_cmd(cmd))
.collect()
}

Expand Down Expand Up @@ -91,15 +87,11 @@ fn remove_globals_action(action: ResolvedAction) -> ResolvedAction {
}

impl<'a> GlobalRemover<'a> {
fn remove_globals_cmd(
&mut self,
type_info: &TypeInfo,
cmd: ResolvedNCommand,
) -> Vec<ResolvedNCommand> {
fn remove_globals_cmd(&mut self, cmd: ResolvedNCommand) -> Vec<ResolvedNCommand> {
match cmd {
GenericNCommand::CoreAction(action) => match action {
GenericAction::Let(span, name, expr) => {
let ty = expr.output_type(type_info);
let ty = expr.output_type();

let func_decl = ResolvedFunctionDecl {
name: name.name,
Expand Down
17 changes: 7 additions & 10 deletions src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl Assignment<AtomTerm, ArcSort> {
.collect();
let types: Vec<_> = args
.iter()
.map(|arg| arg.output_type(typeinfo))
.map(|arg| arg.output_type())
.chain(once(
self.get(&AtomTerm::Var(DUMMY_SPAN.clone(), *corresponding_var))
.unwrap()
Expand Down Expand Up @@ -351,8 +351,8 @@ impl Assignment<AtomTerm, ArcSort> {
let rhs = self.annotate_expr(rhs, typeinfo);
let types: Vec<_> = children
.iter()
.map(|child| child.output_type(typeinfo))
.chain(once(rhs.output_type(typeinfo)))
.map(|child| child.output_type())
.chain(once(rhs.output_type()))
.collect();
let resolved_call = ResolvedCall::from_resolution(head, &types, typeinfo);
if !matches!(resolved_call, ResolvedCall::Func(_)) {
Expand All @@ -379,10 +379,7 @@ impl Assignment<AtomTerm, ArcSort> {
.iter()
.map(|child| self.annotate_expr(child, typeinfo))
.collect();
let types: Vec<_> = children
.iter()
.map(|child| child.output_type(typeinfo))
.collect();
let types: Vec<_> = children.iter().map(|child| child.output_type()).collect();
let resolved_call =
ResolvedCall::from_resolution_func_types(head, &types, typeinfo)
.ok_or_else(|| TypeError::UnboundFunction(*head, span.clone()))?;
Expand Down Expand Up @@ -568,7 +565,7 @@ impl CoreAction {
get_literal_and_global_constraints(&[e.clone(), n.clone()], typeinfo)
.chain(once(Constraint::Assign(
n.clone(),
typeinfo.get_sort_nofail::<I64Sort>() as ArcSort,
std::sync::Arc::new(I64Sort) as ArcSort,
)))
.collect(),
)
Expand Down Expand Up @@ -684,8 +681,8 @@ fn get_literal_and_global_constraints<'a>(
AtomTerm::Var(_, _) => None,
// Literal to type constraint
AtomTerm::Literal(_, lit) => {
let typ = type_info.infer_literal(lit);
Some(Constraint::Assign(arg.clone(), typ.clone()))
let typ = crate::sort::literal_sort(lit);
Some(Constraint::Assign(arg.clone(), typ))
}
AtomTerm::Global(_, v) => {
if let Some(typ) = type_info.lookup_global(v) {
Expand Down
12 changes: 5 additions & 7 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
use std::hash::Hasher;
use std::ops::AddAssign;

use crate::HashMap;
use crate::{typechecking::FuncType, *};
use crate::{typechecking::FuncType, HashMap, *};
use typechecking::TypeError;

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -190,10 +189,10 @@ impl<Leaf: Clone> GenericAtomTerm<Leaf> {
}

impl ResolvedAtomTerm {
pub fn output(&self, typeinfo: &TypeInfo) -> ArcSort {
pub fn output(&self) -> ArcSort {
match self {
ResolvedAtomTerm::Var(_, v) => v.sort.clone(),
ResolvedAtomTerm::Literal(_, l) => typeinfo.infer_literal(l),
ResolvedAtomTerm::Literal(_, l) => literal_sort(l),
ResolvedAtomTerm::Global(_, v) => v.sort.clone(),
}
}
Expand Down Expand Up @@ -838,12 +837,11 @@ impl ResolvedRule {
fresh_gen: &mut SymbolGen,
) -> Result<ResolvedCoreRule, TypeError> {
let value_eq = &typeinfo.primitives.get(&Symbol::from("value-eq")).unwrap()[0];
let unit = typeinfo.get_sort_nofail::<UnitSort>();
self.to_canonicalized_core_rule_impl(typeinfo, fresh_gen, |at1, at2| {
ResolvedCall::Primitive(SpecializedPrimitive {
primitive: value_eq.clone(),
input: vec![at1.output(typeinfo), at2.output(typeinfo)],
output: unit.clone(),
input: vec![at1.output(), at2.output()],
output: Arc::new(UnitSort),
})
})
}
Expand Down
40 changes: 16 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use thiserror::Error;
use generic_symbolic_expressions::Sexp;

use ast::*;
pub use typechecking::{TypeInfo, UNIT_SYM};
pub use typechecking::TypeInfo;

use crate::core::{AtomTerm, ResolvedCall};
use actions::Program;
Expand Down Expand Up @@ -678,11 +678,11 @@ impl EGraph {

pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
Literal::Int(i) => i.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::F64(f) => f.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::String(s) => s.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::Unit => ().store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::Bool(b) => b.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::Int(i) => i.store(&I64Sort).unwrap(),
Literal::F64(f) => f.store(&F64Sort).unwrap(),
Literal::String(s) => s.store(&StringSort).unwrap(),
Literal::Unit => ().store(&UnitSort).unwrap(),
Literal::Bool(b) => b.store(&BoolSort).unwrap(),
}
}

Expand Down Expand Up @@ -739,7 +739,7 @@ impl EGraph {
.get(&sym)
// function_to_dag should have checked this
.unwrap();
let out_is_unit = f.schema.output.name() == UNIT_SYM.into();
let out_is_unit = f.schema.output.name() == UnitSort.name();

let mut buf = String::new();
let s = &mut buf;
Expand Down Expand Up @@ -1300,7 +1300,7 @@ impl EGraph {
let mut termdag = TermDag::default();
for expr in exprs {
let value = self.eval_resolved_expr(&expr)?;
let expr_type = expr.output_type(&self.type_info);
let expr_type = expr.output_type();
let term = self.extract(value, &mut termdag, &expr_type).1;
use std::io::Write;
writeln!(f, "{}", termdag.to_string(&term))
Expand Down Expand Up @@ -1367,7 +1367,7 @@ impl EGraph {
let mut exprs: Vec<Expr> = str_buf.iter().map(|&s| parse(s)).collect();

actions.push(
if function_type.is_datatype || function_type.output.name() == UNIT_SYM.into() {
if function_type.is_datatype || function_type.output.name() == UnitSort.name() {
Action::Expr(span.clone(), Expr::Call(span.clone(), func_name, exprs))
} else {
let out = exprs.pop().unwrap();
Expand Down Expand Up @@ -1412,7 +1412,7 @@ impl EGraph {
.type_info
.typecheck_program(&mut self.symbol_gen, &program)?;

let program = remove_globals(&self.type_info, program, &mut self.symbol_gen);
let program = remove_globals(program, &mut self.symbol_gen);

Ok(program)
}
Expand Down Expand Up @@ -1476,11 +1476,6 @@ impl EGraph {
self.type_info.get_sort_by(pred)
}

/// Returns a sort based on the type
pub fn get_sort<S: Sort + Send + Sync>(&self) -> Option<Arc<S>> {
self.type_info.get_sort_by(|_| true)
}

/// Add a user-defined sort
pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> {
self.type_info.add_arcsort(arcsort, DUMMY_SPAN.clone())
Expand Down Expand Up @@ -1601,21 +1596,18 @@ mod tests {
fn test_user_defined_primitive() {
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(
None,
"
(sort IntVec (Vec i64))
",
)
.parse_and_run_program(None, "(sort IntVec (Vec i64))")
.unwrap();
let i64_sort: Arc<I64Sort> = egraph.get_sort().unwrap();

let int_vec_sort: Arc<VecSort> = egraph
.get_sort_by(|s: &Arc<VecSort>| s.element_name() == i64_sort.name())
.get_sort_by(|s: &Arc<VecSort>| s.element_name() == I64Sort.name())
.unwrap();

egraph.add_primitive(InnerProduct {
ele: i64_sort,
ele: I64Sort.into(),
vec: int_vec_sort,
});

egraph
.parse_and_run_program(
None,
Expand Down
14 changes: 5 additions & 9 deletions src/sort/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@ use crate::ast::Literal;
use super::*;

#[derive(Debug)]
pub struct BoolSort {
name: Symbol,
}
pub struct BoolSort;

impl BoolSort {
pub fn new(name: Symbol) -> Self {
Self { name }
}
lazy_static! {
static ref BOOL_SORT_NAME: Symbol = "bool".into();
}

impl Sort for BoolSort {
fn name(&self) -> Symbol {
self.name
*BOOL_SORT_NAME
}

fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
Expand Down Expand Up @@ -44,7 +40,7 @@ impl IntoSort for bool {
type Sort = BoolSort;
fn store(self, sort: &Self::Sort) -> Option<Value> {
Some(Value {
tag: sort.name,
tag: sort.name(),
bits: self as u64,
})
}
Expand Down
14 changes: 5 additions & 9 deletions src/sort/f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@ use crate::ast::Literal;
use ordered_float::OrderedFloat;

#[derive(Debug)]
pub struct F64Sort {
name: Symbol,
}
pub struct F64Sort;

impl F64Sort {
pub fn new(name: Symbol) -> Self {
Self { name }
}
lazy_static! {
static ref F64_SORT_NAME: Symbol = "f64".into();
}

impl Sort for F64Sort {
fn name(&self) -> Symbol {
self.name
*F64_SORT_NAME
}

fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
Expand Down Expand Up @@ -70,7 +66,7 @@ impl IntoSort for f64 {
type Sort = F64Sort;
fn store(self, sort: &Self::Sort) -> Option<Value> {
Some(Value {
tag: sort.name,
tag: sort.name(),
bits: self.to_bits(),
})
}
Expand Down
Loading

0 comments on commit 5d637f2

Please sign in to comment.