Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite for set membership of the powerset of a record where the record has infinite co-domains. #2946

Merged
merged 1 commit into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .unreleased/features/rewrite.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Handle expressions such as S \in SUBSET [ a : Int ] by rewriting the expression into \A r \in S: DOMAIN r = {"a"} /\ r.a \in Int
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,34 @@ class ExprOptimizer(nameGen: UniqueNameGenerator, tracker: TransformationTracker
}
apply(tla.and(domEq +: fieldsEq: _*).as(b))
}

// S ∈ SUBSET { ["a" ↦ x] : x ∈ T }
case memEx @ OperEx(TlaSetOper.in, setRec,
OperEx(TlaSetOper.powerset,
OperEx(TlaSetOper.map, OperEx(TlaFunOper.rec, fieldsAndValues @ _*), varsAndSets @ _*)))
if fieldsAndValues.length == varsAndSets.length =>
val (fields, values) = TlaOper.deinterleave(fieldsAndValues)
val (vars, sets) = TlaOper.deinterleave(varsAndSets)
assert(fields.length == vars.length)
if (values.zip(vars).exists(p => p._1 != p._2)) {
// The set has a more general form: { [f_1 |-> e_1, ..., f_k |-> e_k]: x_1 \in S_1, ..., x_k \in S_k }, where
// e_1, ..., e_k are expressions over x_1, ..., x_k.
// We do not know how to optimize it.
memEx
lemmy marked this conversation as resolved.
Show resolved Hide resolved
} else {
val strSetT = SetT1(StrT1)
val b = BoolT1

val domType = getElemType(setRec)
val r = tla.name(nameGen.newName()).as(domType)

val domEq = tla.eql(tla.dom(r).as(SetT1(domType)), tla.enumSet(fields: _*).as(strSetT)).as(b)

val fieldsEq = fields.zip(values.zip(sets)).map { case (key, (value, set)) =>
tla.in(tla.appFun(r, key).as(value.typeTag.asTlaType1()), set).as(b)
}
apply(tla.forall(r, setRec, tla.and(domEq +: fieldsEq: _*).as(b)).as(b))
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,100 @@ class TestExprOptimizer extends AnyFunSuite with BeforeAndAfterEach {

// An optimization for set membership over sets of records. Note that this is the standard form produced by Keramelizer.
test("""r \in { [a |-> x, b |-> y]: x \in S, y \in T } becomes DOMAIN r = { "a", "b" } /\ r.a \in S /\ r.b \in T""") {
// ... [a |-> x, b |-> y] ...
val recT = RecT1("a" -> IntT1, "b" -> IntT1)
val recSetT = SetT1(recT)
// ... x \in S, y \in T ...
val record =
enumFun(str("a"), name("x").as(intT), str("b"), name("y").as(intT)).as(recT)
// ... S ...
val S = name("S").as(intSetT)
// ... T ...
val T = name("T").as(intSetT)
// { ... }
val recSetT = SetT1(recT)
val recordSet = map(record, name("x").as(intT), S, name("y").as(intT), T).as(recSetT)
// r ...
val r = name("r").as(recT)
// ... \in ...
val input = in(r, recordSet).as(boolT)

// ~~>

// DOMAIN r = { "a", "b" }
val strSetT = SetT1(StrT1)
val domEq = eql(dom(r).as(strSetT), enumSet(str("a"), str("b")).as(strSetT)).as(boolT)
// r.a \in S
val memA = in(appFun(r, str("a")).as(intT), S).as(boolT)
// r.b \in T
val memB = in(appFun(r, str("b")).as(intT), T).as(boolT)
// ... /\ ... /\ ...
val expected = and(domEq, memA, memB).as(boolT)
val output = optimizer.apply(input)

assert(expected == output)
}

// An optimization for set membership of the powerset of a record where the record has infinite co-domains.
test("""S \in SUBSET [a : T] ~~> \A r \in S: DOMAIN r = { "a" } /\ r.a \in T""") {

// ... { [a |-> x] : x \in T } ...
val recT = RecT1("a" -> IntT1)
val record =
enumFun(str("a"), name("x").as(intT)).as(recT)
val T = name("T").as(intSetT)
val recSetT = SetT1(recT)
val recordSet = map(record, name("x").as(intT), T).as(recSetT)

// ... SUBSET ...
val powSetT = powSet(recordSet).as(recSetT)

// S ...
val s = name("S").as(recSetT)

// ... \in ...
val input = in(s, powSetT).as(boolT)
val output = optimizer.apply(input)

// ~~>

// DOMAIN r = { "a" }
val r = name("t_1").as(recT)
val strSetT = SetT1(StrT1)
val domEq = eql(dom(r).as(strSetT), enumSet(str("a")).as(strSetT)).as(boolT)

// r.a \in T
val memA = in(appFun(r, str("a")).as(intT), T).as(boolT)

// ... /\ ...
val conjunct = and(domEq, memA).as(boolT)

// \A ...
val expected = forall(r, s, conjunct).as(boolT)

assert(expected == output)
}

test("""S \in SUBSET [a : T, b : U] ~~> \A r \in S: DOMAIN r = { "a", "b" } /\ r.a \in T /\ r.b \in U""") {
val recT = RecT1("a" -> IntT1, "b" -> IntT1)
val record =
enumFun(str("a"), name("x").as(intT), str("b"), name("y").as(intT)).as(recT)
val T = name("T").as(intSetT)
val U = name("U").as(intSetT)
val recSetT = SetT1(recT)
val recordSet = map(record, name("x").as(intT), T, name("y").as(intT), U).as(recSetT)
val powSetT = powSet(recordSet).as(recSetT)
val s = name("S").as(recSetT)
val input = in(s, powSetT).as(boolT)
val output = optimizer.apply(input)

val r = name("t_1").as(recT)
val strSetT = SetT1(StrT1)
val domEq = eql(dom(r).as(strSetT), enumSet(str("a"), str("b")).as(strSetT)).as(boolT)
val memA = in(appFun(r, str("a")).as(intT), T).as(boolT)
val memB = in(appFun(r, str("b")).as(intT), U).as(boolT)
val conjunct = and(domEq, memA, memB).as(boolT)
val expected = forall(r, s, conjunct).as(boolT)

assert(expected == output)
}

Expand Down
Loading