Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sequences as a new Suslik data type #42

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
84 changes: 83 additions & 1 deletion src/main/scala/org/tygus/suslik/language/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ object Expressions {
(IntSetType, IntSetType) -> OpSetEq,
(IntervalType, IntervalType) -> OpIntervalEq,
(BoolType, BoolType) -> OpBoolEq,
(IntSequenceType, IntSequenceType) -> OpSequenceEq
)

override def default: BinOp = OpEq
Expand Down Expand Up @@ -132,6 +133,7 @@ object Expressions {
(IntType, IntType) -> OpPlus,
(IntSetType, IntSetType) -> OpUnion,
(IntervalType, IntervalType) -> OpIntervalUnion,
(IntSequenceType, IntSequenceType) -> OpSequenceAppend
)

override def default: BinOp = OpPlus
Expand All @@ -143,6 +145,7 @@ object Expressions {
override def opFromTypes: Map[(SSLType, SSLType), BinOp] = Map(
(IntType, IntType) -> OpMinus,
(IntSetType, IntSetType) -> OpDiff,
(IntSequenceType, IntType) -> OpSequenceRemove,
)

override def default: BinOp = OpMinus
Expand Down Expand Up @@ -298,7 +301,64 @@ object Expressions {
def lType: SSLType = IntervalType
def rType: SSLType = IntervalType
}
object OpSequenceEq extends RelOp with SymmetricOp {
def level: Int = 3
override def pp: String = "=="
def lType: SSLType = IntSequenceType
def rType: SSLType = IntSequenceType
}
object OpSequenceCons extends BinOp {
def level: Int = 4
override def pp: String = "::"
def lType: SSLType = IntType
def rType: SSLType = IntSequenceType
def resType: SSLType = IntSequenceType
}
object OpSequenceAppend extends BinOp with AssociativeOp {
def level: Int = 4
override def pp: String = "++"
def lType: SSLType = IntSequenceType
def rType: SSLType = IntSequenceType
def resType: SSLType = IntSequenceType
}
object OpSequenceRemove extends BinOp {
def level: Int = 4
override def pp: String = "--"
def lType: SSLType = IntSequenceType
def rType: SSLType = IntType
def resType: SSLType = IntSequenceType
}
object OpSequenceAt extends BinOp {
def level: Int = 4
override def pp: String = "@"
def lType: SSLType = IntSequenceType
def rType: SSLType = IntType
def resType: SSLType = IntSequenceType
}
object OpSequenceHead extends UnOp {
override def pp: String = "head"
override def inputType: SSLType = IntSequenceType
override def outputType: SSLType = IntType
}
object OpSequenceTail extends UnOp {
override def pp: String = "tail"
override def inputType: SSLType = IntSequenceType
override def outputType: SSLType = IntSequenceType
}

object OpSequenceLen extends UnOp {
override def pp: String = "len"
override def inputType: SSLType = IntSequenceType
override def outputType: SSLType = IntType
}

object OpSequenceIndexof extends BinOp {
def level: Int = 4
override def pp: String = "!!"
def lType: SSLType = IntSequenceType
def rType: SSLType = IntType
def resType: SSLType = IntType
}

sealed abstract class Expr extends PrettyPrinting with HasExpressions[Expr] with Ordered[Expr] {

Expand Down Expand Up @@ -326,6 +386,9 @@ object Expressions {
case s@SetLiteral(elems) =>
val acc1 = if (p(s)) acc + s.asInstanceOf[R] else acc
elems.foldLeft(acc1)((a,e) => collector(a)(e))
case s@SequenceLiteral(elems) =>
val acc1 = if (p(s)) acc + s.asInstanceOf[R] else acc
elems.foldLeft(acc1)((a, e) => collector(a)(e))
case i@IfThenElse(cond, l, r) =>
val acc1 = if (p(i)) acc + i.asInstanceOf[R] else acc
val acc2 = collector(acc1)(cond)
Expand Down Expand Up @@ -425,6 +488,13 @@ object Expressions {
case Some(g) => e.resolve(g, Some(IntType))
})
} else None
case SequenceLiteral(elems) =>
if (IntSequenceType.conformsTo(target)) {
elems.foldLeft[Option[Gamma]](Some(gamma))((go, e) => go match {
case None => None
case Some(g) => e.resolve(g, Some(IntType))
})
} else None
case IfThenElse(c, t, e) =>
for {
gamma1 <- c.resolve(gamma, Some(BoolType))
Expand All @@ -449,6 +519,7 @@ object Expressions {
case OverloadedBinaryExpr(_, l, r) => 1 + l.size + r.size
case UnaryExpr(_, arg) => 1 + arg.size
case SetLiteral(elems) => 1 + elems.map(_.size).sum
case SequenceLiteral(elems) => 1 + elems.map(_.size).sum
case IfThenElse(cond, l, r) => 1 + cond.size + l.size + r.size
case _ => 1
}
Expand Down Expand Up @@ -478,6 +549,7 @@ object Expressions {
case IfThenElse(c, t, e) =>IfThenElse(c.resolveOverloading(gamma),
t.resolveOverloading(gamma),
e.resolveOverloading(gamma))
case SequenceLiteral(elems) => SequenceLiteral(elems.map(_.resolveOverloading(gamma)))

}
}
Expand Down Expand Up @@ -592,7 +664,11 @@ object Expressions {
def subst(sigma: Subst): Expr = UnaryExpr(op, arg.subst(sigma))
override def substUnknown(sigma: UnknownSubst): Expr = UnaryExpr(op, arg.substUnknown(sigma))
override def level = 5
override def pp: String = s"${op.pp} ${arg.printInContext(this)}"
override def pp: String = op match {
case OpSequenceLen => s"|${arg.printInContext(this)}|"
case _ => s"${op.pp} ${arg.printInContext(this)}"
}

def getType(gamma: Gamma): Option[SSLType] = Some(op.outputType)
}

Expand All @@ -610,6 +686,12 @@ object Expressions {
def getType(gamma: Gamma): Option[SSLType] = left.getType(gamma)
}

case class SequenceLiteral(elems: List[Expr]) extends Expr {
override def pp: String = s"<<${elems.map(_.pp).mkString(",")}>>"
override def subst(sigma: Subst): SequenceLiteral = SequenceLiteral(elems.map(_.subst(sigma)))
def getType(gamma: Gamma): Option[SSLType] = Some(IntSequenceType)
}

/**
* Unknown predicate (to be replaced by a term)
* @param name Predicate name
Expand Down
4 changes: 4 additions & 0 deletions src/main/scala/org/tygus/suslik/language/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,7 @@ case object CardType extends SSLType {
case _ => None
}
}

case object IntSequenceType extends SSLType {
override def pp: String = "intseq"
}
19 changes: 19 additions & 0 deletions src/main/scala/org/tygus/suslik/logic/PureLogicUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ trait PureLogicUtils {
case _:Var => e
case IfThenElse(e1,e2,e3) => IfThenElse(propagate_not(e1),propagate_not(e2), propagate_not(e3))
case SetLiteral(args) => SetLiteral(args.map(propagate_not))
case SequenceLiteral(args) => SequenceLiteral(args.map(propagate_not))
case e => throw SynthesisException(s"Not supported: ${e.pp} (${e.getClass.getName})")
}

Expand All @@ -67,6 +68,7 @@ trait PureLogicUtils {
case _:Var => e
case IfThenElse(e1,e2,e3) => IfThenElse(desugar(e1),desugar(e2), desugar(e3))
case SetLiteral(args) => SetLiteral(args.map(desugar))
case SequenceLiteral(args) => SequenceLiteral(args.map(desugar))
case e => throw SynthesisException(s"Not supported: ${e.pp} (${e.getClass.getName})")
}

Expand Down Expand Up @@ -120,6 +122,23 @@ trait PureLogicUtils {
// case BinaryExpr(OpBoolEq, v1@Var(n1), v2@Var(n2)) => // sort arguments lexicographically
// if (n1 <= n2) BinaryExpr(OpBoolEq, v1, v2) else BinaryExpr(OpBoolEq, v2, v1)
// case BinaryExpr(OpBoolEq, e, v@Var(_)) if !e.isInstanceOf[Var] => BinaryExpr(OpBoolEq, v, simplify(e))

// Sequence Operations

// Sequence Equality
case BinaryExpr(OpSequenceEq, Var(n1), Var(n2)) if n1 == n2 => // remove trivial equality
BoolConst(true)
case BinaryExpr(OpSequenceEq, v1@Var(n1), v2@Var(n2)) => // sort arguments lexicographically
if (n1 <= n2) BinaryExpr(OpSequenceEq, v1, v2) else BinaryExpr(OpSequenceEq, v2, v1)
case BinaryExpr(OpSequenceEq, e, v@Var(_)) if !e.isInstanceOf[Var] => BinaryExpr(OpSequenceEq, v, simplify(e))

// Sequence Append
case BinaryExpr(OpSequenceAppend, left, SequenceLiteral(s)) if s.isEmpty => simplify(left)
case BinaryExpr(OpSequenceAppend, SequenceLiteral(s), right) if s.isEmpty => simplify(right)

// Sequence Cons
//case BinaryExpr(OpSequenceCons, left, SequenceLiteral(s)) if s.isEmpty => SequenceLiteral([simplify(left)])


case BinaryExpr(OpPlus, left, IntConst(i)) if i.toInt == 0 => simplify(left)
case BinaryExpr(OpPlus, IntConst(i), right) if i.toInt == 0 => simplify(right)
Expand Down
137 changes: 130 additions & 7 deletions src/main/scala/org/tygus/suslik/logic/smt/SMTSolving.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ object SMTSolving extends Core

trait SetTerm
trait IntervalTerm
trait SequenceTerm

type SMTBoolTerm = TypedTerm[BoolTerm, Term]
type SMTIntTerm = TypedTerm[IntTerm, Term]
type SMTSetTerm = TypedTerm[SetTerm, Term]
type SMTIntervalTerm = TypedTerm[IntervalTerm, Term]
type SMTSequenceTerm = TypedTerm[SequenceTerm, Term]

def setSort: Sort = SortId(SymbolId(SSymbol("SetInt")))

Expand Down Expand Up @@ -107,6 +109,42 @@ object SMTSolving extends Core
"(define-fun iunion ((s1 Interval) (s2 Interval)) Interval (ite (iempty s1) s2 (iinsert (lower s1) (iinsert (upper s1) s2))))",
)

def sequenceSort: Sort = SortId(SymbolId(SSymbol("SequenceInt")))

def emptySequenceSymbol = SimpleQId(SymbolId(SSymbol("sempty")))

def sequenceConsSymbol = SimpleQId(SymbolId(SSymbol("scons")))

def sequenceAppendSymbol = SimpleQId(SymbolId(SSymbol("sappend")))

def sequenceRemoveSymbol = SimpleQId(SymbolId(SSymbol("sremove")))

def sequenceIndexofSymbol = SimpleQId(SymbolId(SSymbol("sindexof")))

def sequenceLenSymbol = SimpleQId(SymbolId(SSymbol("slen")))

def sequenceAtSymbol = SimpleQId(SymbolId(SSymbol("seat")))

def emptySequenceTerm: Term = QIdTerm(emptySequenceSymbol)

/*def sequencePrelude: List[String] = List(
"(define-sort SequenceInt () (List Int))",
"(define-fun sempty () SequenceInt (as nil SequenceInt))",
"(define-fun scons ((x Int) (xs SequenceInt)) SequenceInt (insert x xs))",
"(define-fun-rec sappend ((xs SequenceInt) (ys SequenceInt)) SequenceInt (match xs ((nil ys) ((insert x xsn) (insert x (sappend xsn ys))))))"
)*/

def sequencePrelude: List[String] = List(
"(define-sort SequenceInt () (Seq Int))",
"(define-fun sempty () SequenceInt (as seq.empty SequenceInt))",
"(define-fun scons ((x Int) (xs SequenceInt)) SequenceInt (seq.++ (seq.unit x) xs))",
"(define-fun sappend ((xs SequenceInt) (ys SequenceInt)) SequenceInt (seq.++ xs ys))",
"(define-fun sremove ((xs SequenceInt) (y Int)) SequenceInt (seq.replace xs (seq.unit y) sempty))",
"(define-fun sindexof ((xs SequenceInt) (y Int)) Int (seq.indexof xs (seq.unit y)))",
"(define-fun slen ((xs SequenceInt)) Int (seq.len xs))",
"(define-fun seat ((xs SequenceInt) (y Int)) SequenceInt (seq.at xs y))"
)

// Commands to be executed before solving starts
def prelude = if (defaultSolver == "CVC4") {
List(
Expand All @@ -127,7 +165,7 @@ object SMTSolving extends Core
"(assert (forall ((b1 Bool) (b2 Bool)) (= (andNot b1 b2) (and b1 (not b2)))))",
"(define-fun difference ((s1 SetInt) (s2 SetInt)) SetInt ((_ map andNot) s1 s2))",
"(declare-datatypes () ((Interval (interval (lower Int) (upper Int)))))"
) ++ intervalPrelude
) ++ sequencePrelude ++ intervalPrelude
} else if (defaultSolver == "Z3 <= 4.7.x") {
// In Z3 4.7.x and below, difference is built in and intersection is called intersect
List(
Expand All @@ -136,7 +174,7 @@ object SMTSolving extends Core
"(define-fun member ((x Int) (s SetInt)) Bool (select s x))",
"(define-fun insert ((x Int) (s SetInt)) SetInt (store s x true))",
"(declare-datatypes () ((Interval (interval (lower Int) (upper Int)))))"
) ++ intervalPrelude
) ++ sequencePrelude ++ intervalPrelude
} else throw SolverUnsupportedExpr(defaultSolver)

private def checkSat(term: SMTBoolTerm): Boolean =
Expand Down Expand Up @@ -238,6 +276,66 @@ object SMTSolving extends Core
case _ => throw SMTUnsupportedExpr(e)
}

private def convertSequenceExpr(e: Expr): SMTSequenceTerm = e match {
case Var(name) => new VarTerm[SequenceTerm](name, sequenceSort)
case SequenceLiteral(elems) => {
val emptyTerm = new TypedTerm[SequenceTerm, Term](Set.empty, emptySequenceTerm)
makeSequenceCons(emptyTerm, elems)
}
case BinaryExpr(OpSequenceCons, left, right) => {
val l = convertIntExpr(left)
val r = convertSequenceExpr(right)

new TypedTerm[SequenceTerm, Term](l.typeDefs ++ r.typeDefs,
QIdAndTermsTerm(sequenceConsSymbol, List(l.termDef, r.termDef)))
}

case BinaryExpr(OpSequenceAppend, left, right) => {
val l = convertSequenceExpr(left)
val r = convertSequenceExpr(right)

new TypedTerm[SequenceTerm, Term](l.typeDefs ++ r.typeDefs,
QIdAndTermsTerm(sequenceAppendSymbol, List(l.termDef, r.termDef)))
}

case BinaryExpr(OpSequenceRemove, left, right) => {
var l = convertSequenceExpr(left)
var r = convertIntExpr(right)

new TypedTerm[SequenceTerm, Term](l.typeDefs ++ r.typeDefs,
QIdAndTermsTerm(sequenceRemoveSymbol, List(l.termDef, r.termDef)))
}
case BinaryExpr(OpSequenceAt, left, right) => {
val l = convertSequenceExpr(left)
var r = convertIntExpr(right)
new TypedTerm[SequenceTerm, Term](l.typeDefs ++ r.typeDefs, QIdAndTermsTerm(sequenceAtSymbol, List(l.termDef, r.termDef)))
}
case IfThenElse(cond, left, right) => {
val c = convertBoolExpr(cond)
val l = convertSequenceExpr(left)
val r = convertSequenceExpr(right)
c.ite(l, r)
}
case _ => throw SMTUnsupportedExpr(e)
}

private def makeSequenceCons(sequenceTerm: SMTSequenceTerm, elems: List[Expr]): SMTSequenceTerm = {
if (elems.isEmpty) {
sequenceTerm
} else {
val eTerms: List[SMTIntTerm] = elems.map(convertIntExpr)
if (defaultSolver == "CVC4") {
throw SolverUnsupportedExpr(defaultSolver)
} else if (defaultSolver == "Z3" || defaultSolver == "Z3 <= 4.7.x") {
def makeInsertOne(eTerm: SMTIntTerm, sequenceTerm: SMTSequenceTerm): SMTSequenceTerm =
new TypedTerm[SequenceTerm, Term](sequenceTerm.typeDefs ++ eTerm.typeDefs,
QIdAndTermsTerm(sequenceConsSymbol, List(eTerm.termDef, sequenceTerm.termDef)))

eTerms.foldRight(sequenceTerm)(makeInsertOne)
} else throw SolverUnsupportedExpr(defaultSolver)
}
}

private def convertBoolExpr(e: Expr): SMTBoolTerm = e match {
case Var(name) => Bools(name)
case BoolConst(true) => True()
Expand Down Expand Up @@ -314,6 +412,12 @@ object SMTSolving extends Core
val r = convertBoolExpr(right)
c.ite(l, r)
}
case BinaryExpr(OpSequenceEq, left, right) => {
val l = convertSequenceExpr(left)
val r = convertSequenceExpr(right)

l === r
}
case Unknown(_, _, _) => True() // Treat unknown predicates as true
case _ => throw SMTUnsupportedExpr(e)
}
Expand All @@ -330,13 +434,32 @@ object SMTSolving extends Core
val s = convertIntervalExpr(e)
new TypedTerm[IntTerm, Term](s.typeDefs, QIdAndTermsTerm(intervalUpperSymbol, List(s.termDef)))
}
case UnaryExpr(OpSequenceLen, e) => {
val s = convertSequenceExpr(e)
new TypedTerm[IntTerm, Term](s.typeDefs, QIdAndTermsTerm(sequenceLenSymbol, List(s.termDef)))
}
case BinaryExpr(op, left, right) => {
val l = convertIntExpr(left)
val r = convertIntExpr(right)
op match {
case OpPlus => l + r
case OpMinus => l - r
case OpMultiply => l * r
case OpPlus => {
val l = convertIntExpr(left)
val r = convertIntExpr(right)
l + r
}
case OpMinus => {
val l = convertIntExpr(left)
val r = convertIntExpr(right)
l - r
}
case OpMultiply => {
val l = convertIntExpr(left)
val r = convertIntExpr(right)
l * r
}
case OpSequenceIndexof => {
val l = convertSequenceExpr(left)
val r = convertIntExpr(right)
new TypedTerm[IntTerm, Term](l.typeDefs ++ r.typeDefs, QIdAndTermsTerm(sequenceIndexofSymbol, List(l.termDef, r.termDef)))
}
case _ => throw SMTUnsupportedExpr(e)
}
}
Expand Down
Loading