Skip to content

Commit

Permalink
Implement IIFEs in the Rust backend (dafny-lang#4382)
Browse files Browse the repository at this point in the history
Implement IIFEs in the Rust backend

<small>By submitting this pull request, I confirm that my contribution
is made under the terms of the [MIT
license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).</small>

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/dafny-lang/dafny/pull/4382).
* dafny-lang#4422
* __->__ dafny-lang#4382
  • Loading branch information
shadaj authored Aug 15, 2023
1 parent 5fcad67 commit 09f64e9
Show file tree
Hide file tree
Showing 12 changed files with 925 additions and 685 deletions.
7 changes: 4 additions & 3 deletions Source/DafnyCore/Compilers/CSharp/Compiler-Csharp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2869,13 +2869,14 @@ protected override ConcreteSyntaxTree EmitBetaRedex(List<string> boundVars, List
return result;
}

protected override void EmitDestructor(string source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Contract.Assert(dtor == coreDtor.CorrespondingFormals[0]); // any other destructor is a ghost
wr.Write(source);
source(wr);
} else {
wr.Write($"{source}.{DestructorGetterName(dtor, ctor, formalNonGhostIndex)}");
source(wr);
wr.Write($".{DestructorGetterName(dtor, ctor, formalNonGhostIndex)}");
}
}

Expand Down
14 changes: 10 additions & 4 deletions Source/DafnyCore/Compilers/Cplusplus/Compiler-cpp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1951,9 +1951,11 @@ protected override void EmitConstructorCheck(string source, DatatypeCtor ctor, C
wr.Write("is_{1}({0})", source, DatatypeSubStructName(ctor));
}

protected override void EmitDestructor(string source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (ctor.EnclosingDatatype is TupleTypeDecl) {
wr.Write("({0}).template get<{1}>()", source, formalNonGhostIndex);
wr.Write("(");
source(wr);
wr.Write(").template get<{0}>()", formalNonGhostIndex);
} else {
var dtorName = FormalName(dtor, formalNonGhostIndex);
if (dtor.Type is UserDefinedType udt && udt.ResolvedClass == ctor.EnclosingDatatype) {
Expand All @@ -1962,9 +1964,13 @@ protected override void EmitDestructor(string source, Formal dtor, int formalNon
}

if (ctor.EnclosingDatatype.Ctors.Count > 1) {
wr.Write("(({0}).dtor_{1}())", source, dtorName);
wr.Write("((");
source(wr);
wr.Write(").dtor_{0}())", dtorName);
} else {
wr.Write("(({0}).{1})", source, dtorName);
wr.Write("((");
source(wr);
wr.Write(").{0})", dtorName);
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion Source/DafnyCore/Compilers/Dafny/AST.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ module {:extern "DAST"} DAST {
Select(expr: Expression, field: string, onDatatype: bool) |
SelectFn(expr: Expression, field: string, onDatatype: bool, isStatic: bool, arity: nat) |
TupleSelect(expr: Expression, index: nat) |
Call(on: Expression, name: string, typeArgs: seq<Type>, args: seq<Expression>) |
Call(on: Expression, name: Ident, typeArgs: seq<Type>, args: seq<Expression>) |
Lambda(params: seq<Formal>, body: seq<Statement>) |
IIFE(name: Ident, typ: Type, value: Expression, iifeBody: Expression) |
Apply(expr: Expression, args: seq<Expression>) |
TypeTest(on: Expression, dType: seq<Ident>, variant: string) |
InitializationValue(typ: Type)
Expand Down
78 changes: 78 additions & 0 deletions Source/DafnyCore/Compilers/Dafny/ASTBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,12 @@ LambdaExprBuilder Lambda(List<DAST.Formal> formals) {
return ret;
}

IIFEExprBuilder IIFE(string name, DAST.Type tpe) {
var ret = new IIFEExprBuilder(name, tpe);
AddBuildable(ret);
return ret;
}

protected static void RecursivelyBuild(List<object> body, List<DAST.Expression> builtExprs) {
foreach (var maybeBuilt in body) {
if (maybeBuilt is DAST.Expression built) {
Expand Down Expand Up @@ -910,3 +916,75 @@ public DAST.Expression Build() {
);
}
}

class IIFEExprBuilder : ExprContainer, BuildableExpr {
readonly string name;
readonly DAST.Type tpe;

object body = null;
public object value = null;

public IIFEExprBuilder(string name, DAST.Type tpe) {
this.name = name;
this.tpe = tpe;
}

public IIFEExprRhs RhsBuilder() {
return new IIFEExprRhs(this);
}

public void AddExpr(DAST.Expression item) {
if (body != null) {
throw new InvalidOperationException();
} else {
body = item;
}
}

public void AddBuildable(BuildableExpr item) {
if (body != null) {
throw new InvalidOperationException();
} else {
body = item;
}
}

public DAST.Expression Build() {
var builtBody = new List<DAST.Expression>();
ExprContainer.RecursivelyBuild(new List<object> { body }, builtBody);

var builtValue = new List<DAST.Expression>();
ExprContainer.RecursivelyBuild(new List<object> { value }, builtValue);

return (DAST.Expression)DAST.Expression.create_IIFE(
Sequence<Rune>.UnicodeFromString(name),
tpe,
builtValue[0],
builtBody[0]
);
}
}

class IIFEExprRhs : ExprContainer {
readonly IIFEExprBuilder parent;

public IIFEExprRhs(IIFEExprBuilder parent) {
this.parent = parent;
}

public void AddExpr(DAST.Expression item) {
if (parent.value != null) {
throw new InvalidOperationException();
} else {
parent.value = item;
}
}

public void AddBuildable(BuildableExpr item) {
if (parent.value != null) {
throw new InvalidOperationException();
} else {
parent.value = item;
}
}
}
24 changes: 18 additions & 6 deletions Source/DafnyCore/Compilers/Dafny/Compiler-dafny.cs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ private DAST.Type GenType(Type typ) {
if (xType is BoolType) {
return (DAST.Type)DAST.Type.create_Primitive(DAST.Primitive.create_Bool());
} else if (xType is IntType) {
return (DAST.Type)DAST.Type.create_Passthrough(Sequence<Rune>.UnicodeFromString("i32"));
return (DAST.Type)DAST.Type.create_Passthrough(Sequence<Rune>.UnicodeFromString("i64"));
} else if (xType is RealType) {
return (DAST.Type)DAST.Type.create_Passthrough(Sequence<Rune>.UnicodeFromString("f32"));
} else if (xType.IsStringType) {
Expand Down Expand Up @@ -1064,6 +1064,9 @@ public override void EmitExpr(Expression expr, bool inLetExprBody, ConcreteSynta
var origBuilder = currentBuilder;
base.EmitExpr(expr, inLetExprBody, actualWr, wStmts);
currentBuilder = origBuilder;
} else if (expr is IdentifierExpr) {
// we don't need to create a copy of the identifier, that's language specific
base.EmitExpr(expr, false, actualWr, wStmts);
} else {
base.EmitExpr(expr, inLetExprBody, actualWr, wStmts);
}
Expand Down Expand Up @@ -1257,22 +1260,25 @@ protected override ConcreteSyntaxTree EmitBetaRedex(List<string> boundVars, List
throw new NotImplementedException();
}

protected override void EmitDestructor(string source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor,
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor,
List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (wr is BuilderSyntaxTree<ExprContainer> builder) {
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Contract.Assert(dtor == coreDtor.CorrespondingFormals[0]); // any other destructor is a ghost
EmitIdentifier(source, wr);
source(wr);
} else {
var buf = new ExprBuffer(null);
source(new BuilderSyntaxTree<ExprContainer>(buf));
var sourceAST = buf.Finish();
if (ctor.EnclosingDatatype is TupleTypeDecl) {
builder.Builder.AddExpr((DAST.Expression)DAST.Expression.create_TupleSelect(
(DAST.Expression)DAST.Expression.create_Ident(Sequence<Rune>.UnicodeFromString(source)),
sourceAST,
int.Parse(dtor.NameForCompilation)
));
} else {
builder.Builder.AddExpr((DAST.Expression)DAST.Expression.create_Select(
(DAST.Expression)DAST.Expression.create_Ident(Sequence<Rune>.UnicodeFromString(source)),
sourceAST,
Sequence<Rune>.UnicodeFromString(dtor.CompileName),
true
));
Expand Down Expand Up @@ -1301,7 +1307,13 @@ protected override ConcreteSyntaxTree CreateLambda(List<Type> inTypes, IToken to

protected override void CreateIIFE(string bvName, Type bvType, IToken bvTok, Type bodyType, IToken bodyTok,
ConcreteSyntaxTree wr, ref ConcreteSyntaxTree wStmts, out ConcreteSyntaxTree wrRhs, out ConcreteSyntaxTree wrBody) {
throw new NotImplementedException();
if (wr is BuilderSyntaxTree<ExprContainer> builder) {
var iife = builder.Builder.IIFE(bvName, GenType(bvType));
wrRhs = new BuilderSyntaxTree<ExprContainer>(iife.RhsBuilder());
wrBody = new BuilderSyntaxTree<ExprContainer>(iife);
} else {
throw new InvalidOperationException("Invalid context for IIFE: " + wr.GetType());
}
}

protected override ConcreteSyntaxTree CreateIIFE0(Type resultType, IToken resultTok, ConcreteSyntaxTree wr,
Expand Down
11 changes: 7 additions & 4 deletions Source/DafnyCore/Compilers/GoLang/Compiler-go.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3076,18 +3076,21 @@ protected override void EmitConstructorCheck(string source, DatatypeCtor ctor, C
wr.Write("{0}.{1}()", source, FormatDatatypeConstructorCheckName(ctor.GetCompileName(Options)));
}

protected override void EmitDestructor(string source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Contract.Assert(dtor == coreDtor.CorrespondingFormals[0]); // any other destructor is a ghost
wr.Write(source);
source(wr);
} else if (ctor.EnclosingDatatype is TupleTypeDecl tupleTypeDecl) {
Contract.Assert(tupleTypeDecl.NonGhostDims != 1); // such a tuple is an erasable-wrapper type, handled above
wr.Write("(*({0}).IndexInt({1})).({2})", source, formalNonGhostIndex, TypeName(typeArgs[formalNonGhostIndex], wr, Token.NoToken));
wr.Write("(*(");
source(wr);
wr.Write(").IndexInt({0})).({1})", formalNonGhostIndex, TypeName(typeArgs[formalNonGhostIndex], wr, Token.NoToken));
} else {
var dtorName = DatatypeFieldName(dtor, formalNonGhostIndex);
wr = EmitCoercionIfNecessary(from: dtor.Type, to: bvType, tok: dtor.tok, wr: wr);
wr.Write("{0}.Get_().({1}).{2}", source, TypeName_Constructor(ctor, wr), dtorName);
source(wr);
wr.Write(".Get_().({0}).{1}", TypeName_Constructor(ctor, wr), dtorName);
}
}

Expand Down
8 changes: 5 additions & 3 deletions Source/DafnyCore/Compilers/Java/Compiler-java.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3293,11 +3293,11 @@ protected override IClassWriter CreateTrait(string name, bool isExtern, List<Typ
return new ClassWriter(this, instanceMemberWriter, ctorBodyWriter, staticMemberWriter);
}

protected override void EmitDestructor(string source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Contract.Assert(dtor == coreDtor.CorrespondingFormals[0]); // any other destructor is a ghost
wr.Write(source);
source(wr);
return;
}
string dtorName;
Expand All @@ -3308,7 +3308,9 @@ protected override void EmitDestructor(string source, Formal dtor, int formalNon
} else {
dtorName = FieldName(dtor, formalNonGhostIndex);
}
wr.Write("(({0}){1}{2}).{3}", DtCtorName(ctor, typeArgs, wr), source, ctor.EnclosingDatatype is CoDatatypeDecl ? ".Get()" : "", dtorName);
wr.Write("(({0})", DtCtorName(ctor, typeArgs, wr));
source(wr);
wr.Write("{0}).{1}", ctor.EnclosingDatatype is CoDatatypeDecl ? ".Get()" : "", dtorName);
}

private void CreateLambdaFunctionInterface(int i, ConcreteSyntaxTree outputWr) {
Expand Down
12 changes: 8 additions & 4 deletions Source/DafnyCore/Compilers/JavaScript/Compiler-js.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1902,17 +1902,21 @@ protected override ConcreteSyntaxTree EmitBetaRedex(List<string> boundVars, List
return w;
}

protected override void EmitDestructor(string source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Contract.Assert(dtor == coreDtor.CorrespondingFormals[0]); // any other destructor is a ghost
wr.Write(source);
source(wr);
} else if (ctor.EnclosingDatatype is TupleTypeDecl tupleTypeDecl) {
Contract.Assert(tupleTypeDecl.NonGhostDims != 1); // such a tuple is an erasable-wrapper type, handled above
wr.Write("({0})[{1}]", source, formalNonGhostIndex);
wr.Write("(");
source(wr);
wr.Write(")[{0}]", formalNonGhostIndex);
} else {
var dtorName = FormalName(dtor, formalNonGhostIndex);
wr.Write("({0}){1}.{2}", source, ctor.EnclosingDatatype is CoDatatypeDecl ? "._D()" : "", dtorName);
wr.Write("(");
source(wr);
wr.Write("){0}.{1}", ctor.EnclosingDatatype is CoDatatypeDecl ? "._D()" : "", dtorName);
}
}

Expand Down
4 changes: 2 additions & 2 deletions Source/DafnyCore/Compilers/Python/Compiler-python.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1412,9 +1412,9 @@ protected override ConcreteSyntaxTree EmitBetaRedex(List<string> boundVars, List
return EmitReturnExpr(wrBody);
}

protected override void EmitDestructor(string source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor,
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor,
List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
wr.Write(source);
source(wr);
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Contract.Assert(dtor == coreDtor.CorrespondingFormals[0]); // any other destructor is a ghost
Expand Down
35 changes: 26 additions & 9 deletions Source/DafnyCore/Compilers/Rust/Dafny-compiler-rust.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ module {:extern "DCOMP"} DCOMP {
s := s + "}\n";
s := s + "}\n";
s := s + "impl ::dafny_runtime::DafnyPrint for r#" + c.name + " {\n";
s := s + "fn fmt_print(&self, f: &mut ::std::fmt::Formatter, in_seq: bool) -> ::std::fmt::Result {\n";
s := s + "::dafny_runtime::DafnyPrint::fmt_print(&self.0, f, in_seq)\n";
s := s + "fn fmt_print(&self, __fmt_print_formatter: &mut ::std::fmt::Formatter, in_seq: bool) -> ::std::fmt::Result {\n";
s := s + "::dafny_runtime::DafnyPrint::fmt_print(&self.0, __fmt_print_formatter, in_seq)\n";
s := s + "}\n";
s := s + "}";

// inherit common traits
var ops := [("std::ops::Add", "add"), ("std::ops::Sub", "sub"), ("std::ops::Mul", "mul"), ("std::ops::Div", "div")];
var ops := [("::std::ops::Add", "add"), ("::std::ops::Sub", "sub"), ("::std::ops::Mul", "mul"), ("::std::ops::Div", "div")];
var i := 0;
while i < |ops| {
var (traitName, methodName) := ops[i];
Expand All @@ -130,6 +130,13 @@ module {:extern "DCOMP"} DCOMP {
s := s + "}\n";
i := i + 1;
}

s := s + "impl ::std::cmp::PartialOrd<r#" + c.name + "> for r#" + c.name;
s := s + " where " + underlyingType + ": ::std::cmp::PartialOrd<" + underlyingType + "> {\n";
s := s + "fn partial_cmp(&self, other: &r#" + c.name + ") -> ::std::option::Option<::std::cmp::Ordering> {\n";
s := s + "self.0.partial_cmp(&other.0)\n";
s := s + "}\n";
s := s + "}\n";
}

static method GenDatatype(c: Datatype) returns (s: string) {
Expand Down Expand Up @@ -245,32 +252,32 @@ module {:extern "DCOMP"} DCOMP {

var enumBody := "#[derive(PartialEq)]\npub enum r#" + c.name + typeParams + " {\n" + ctors + "\n}" + "\n" + "impl " + constrainedTypeParams + " r#" + c.name + typeParams + " {\n" + implBody + "\n}";

var printImpl := "impl " + constrainedTypeParams + " ::dafny_runtime::DafnyPrint for r#" + c.name + typeParams + " {\n" + "fn fmt_print(&self, f: &mut ::std::fmt::Formatter, _in_seq: bool) -> std::fmt::Result {\n" + "match self {\n";
var printImpl := "impl " + constrainedTypeParams + " ::dafny_runtime::DafnyPrint for r#" + c.name + typeParams + " {\n" + "fn fmt_print(&self, __fmt_print_formatter: &mut ::std::fmt::Formatter, _in_seq: bool) -> std::fmt::Result {\n" + "match self {\n";
i := 0;
while i < |c.ctors| {
var ctor := c.ctors[i];
var ctorMatch := "r#" + ctor.name + " { ";

var modulePrefix := if c.enclosingModule.id == "_module" then "" else c.enclosingModule.id + ".";
var printRhs := "write!(f, \"" + modulePrefix + c.name + "." + ctor.name + (if ctor.hasAnyArgs then "(\")?;" else "\")?;");
var printRhs := "write!(__fmt_print_formatter, \"" + modulePrefix + c.name + "." + ctor.name + (if ctor.hasAnyArgs then "(\")?;" else "\")?;");

var j := 0;
while j < |ctor.args| {
var formal := ctor.args[j];
ctorMatch := ctorMatch + formal.name + ", ";

if (j > 0) {
printRhs := printRhs + "\nwrite!(f, \", \")?;";
printRhs := printRhs + "\nwrite!(__fmt_print_formatter, \", \")?;";
}
printRhs := printRhs + "\n::dafny_runtime::DafnyPrint::fmt_print(" + formal.name + ", f, false)?;";
printRhs := printRhs + "\n::dafny_runtime::DafnyPrint::fmt_print(" + formal.name + ", __fmt_print_formatter, false)?;";

j := j + 1;
}

ctorMatch := ctorMatch + "}";

if (ctor.hasAnyArgs) {
printRhs := printRhs + "\nwrite!(f, \")\")?;";
printRhs := printRhs + "\nwrite!(__fmt_print_formatter, \")\")?;";
}

printRhs := printRhs + "\nOk(())";
Expand Down Expand Up @@ -1113,7 +1120,7 @@ module {:extern "DCOMP"} DCOMP {
}
}

s := enclosingString + "r#" + name + typeArgString + "(" + argString + ")";
s := enclosingString + "r#" + name.id + typeArgString + "(" + argString + ")";
isOwned := true;
}
case Lambda(params, body) => {
Expand Down Expand Up @@ -1154,6 +1161,16 @@ module {:extern "DCOMP"} DCOMP {
s := "::dafny_runtime::FunctionWrapper({\n" + allReadCloned + "Box::new(move |" + paramsString + "| {\n" + recursiveGen + "\n})})";
isOwned := true;
}
case IIFE(name, tpe, value, iifeBody) => {
var valueGen, valueOwned, recIdents := GenExpr(value, params, false);
readIdents := recIdents;
var valueTypeGen := GenType(tpe, false, true);
var bodyGen, bodyOwned, bodyIdents := GenExpr(iifeBody, params + if valueOwned then [] else [name.id], mustOwn);
readIdents := readIdents + bodyIdents;

s := "{\nlet r#" + name.id + ": " + (if valueOwned then "" else "&") + valueTypeGen + " = " + valueGen + ";\n" + bodyGen + "\n}";
isOwned := bodyOwned;
}
case Apply(func, args) => {
var funcString, _, recIdents := GenExpr(func, params, false);
readIdents := recIdents;
Expand Down
Loading

0 comments on commit 09f64e9

Please sign in to comment.