Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pallas: Add instructions for Swift examples #1315

Draft
wants to merge 14 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3754,6 +3754,14 @@ final case class LLVMStore[G](
)(val blame: Blame[VerificationFailure])(implicit val o: Origin)
extends LLVMStatement[G] with LLVMStoreImpl[G]

final case class LLVMMemset[G](
dest: Expr[G],
value: Expr[G],
len: Expr[G],
volatile: Expr[G],
)(val blame: Blame[VerificationFailure])(implicit val o: Origin)
extends LLVMStatement[G] with LLVMMemsetImpl[G]

final case class LLVMBranchUnreachable[G]()(
val blame: Blame[UnreachableReachedError]
)(implicit val o: Origin)
Expand All @@ -3767,6 +3775,14 @@ final case class LLVMGetElementPointer[G](
)(implicit val o: Origin)
extends LLVMExpr[G] with LLVMGetElementPointerImpl[G]

final case class LLVMExtractValue[G](
aggregateType: Type[G],
resultType: Type[G],
value: Expr[G],
indices: Seq[Int],
)(implicit val o: Origin)
extends LLVMExpr[G] with LLVMExtractValueImpl[G]

final case class LLVMSignExtend[G](
inputType: Type[G],
outputType: Type[G],
Expand Down Expand Up @@ -3795,6 +3811,33 @@ final case class LLVMFloatExtend[G](
)(implicit val o: Origin)
extends LLVMExpr[G] with LLVMFloatExtendImpl[G]

sealed trait LLVMArithOpWithOverflow[G]
extends LLVMStatement[G] with LLVMArithOpWithOverflowImpl[G]

final case class LLVMAddWithOverflow[G](
target: Expr[G],
left: Expr[G],
right: Expr[G],
signed: Boolean,
)(val blame: Blame[AssignFailed])(implicit val o: Origin)
extends LLVMArithOpWithOverflow[G] with LLVMAddWithOverflowImpl[G]

final case class LLVMSubWithOverflow[G](
target: Expr[G],
left: Expr[G],
right: Expr[G],
signed: Boolean,
)(val blame: Blame[AssignFailed])(implicit val o: Origin)
extends LLVMArithOpWithOverflow[G] with LLVMSubWithOverflowImpl[G]

final case class LLVMMultWithOverflow[G](
target: Expr[G],
left: Expr[G],
right: Expr[G],
signed: Boolean,
)(val blame: Blame[AssignFailed])(implicit val o: Origin)
extends LLVMArithOpWithOverflow[G] with LLVMMultWithOverflowImpl[G]

final class LLVMGlobalSpecification[G](val value: String)(
implicit val o: Origin
) extends GlobalDeclaration[G] with LLVMGlobalSpecificationImpl[G] {
Expand Down
13 changes: 13 additions & 0 deletions src/col/vct/col/ast/lang/llvm/LLVMAddWithOverflowImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package vct.col.ast.lang.llvm

import vct.col.ast.LLVMAddWithOverflow
import vct.col.ast.ops.LLVMAddWithOverflowOps
import vct.col.print._

trait LLVMAddWithOverflowImpl[G] extends LLVMAddWithOverflowOps[G] {
this: LLVMAddWithOverflow[G] =>
override def instRepr: Doc =
if (signed) { Text("sadd.with.overflow") }
else { Text("uadd.with.overflow") }

}
23 changes: 23 additions & 0 deletions src/col/vct/col/ast/lang/llvm/LLVMArithOpWithOverflowImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package vct.col.ast.lang.llvm

import vct.col.ast.{Expr, LLVMArithOpWithOverflow}
import vct.col.origin.{AssignFailed, Blame}
import vct.col.print._

trait LLVMArithOpWithOverflowImpl[G] {
this: LLVMArithOpWithOverflow[G] =>
def target: Expr[G]
def left: Expr[G]
def right: Expr[G]
def signed: Boolean
def blame: Blame[AssignFailed]

def instRepr: Doc

override def layout(implicit ctx: Ctx): Doc = {
Group(
target.show <+> Text("=") <+> instRepr <+> Text("(") <+> left.show <+>
Text(", ") <+> right.show <+> Text(")")
)
}
}
14 changes: 14 additions & 0 deletions src/col/vct/col/ast/lang/llvm/LLVMExtractValueImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package vct.col.ast.lang.llvm

import vct.col.ast.LLVMExtractValue
import vct.col.ast.ops.LLVMExtractValueOps
import vct.col.print._

trait LLVMExtractValueImpl[G] extends LLVMExtractValueOps[G] {
this: LLVMExtractValue[G] =>

override def layout(implicit ctx: Ctx): Doc =
Text("extractValue<") <+> aggregateType <> ">" <+> value.show <> "[" <+>
Doc.args(indices.map(i => Text(i.toString))) <+> "]"
override def t = resultType
}
14 changes: 14 additions & 0 deletions src/col/vct/col/ast/lang/llvm/LLVMMemsetImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package vct.col.ast.lang.llvm

import vct.col.ast.LLVMMemset
import vct.col.ast.ops.LLVMMemsetOps
import vct.col.print.{Text, _}

trait LLVMMemsetImpl[G] extends LLVMMemsetOps[G] {
this: LLVMMemset[G] =>

override def layout(implicit ctx: Ctx): Doc = {
Text("memset(") <+> dest.show <+> Text(", ") <+> value.show <+>
Text(", ") <+> len.show <+> Text(", ") <+> volatile.show <+> Text(")")
}
}
12 changes: 12 additions & 0 deletions src/col/vct/col/ast/lang/llvm/LLVMMultWithOverflowImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package vct.col.ast.lang.llvm

import vct.col.ast.ops.LLVMMultWithOverflowOps
import vct.col.ast.{LLVMMultWithOverflow, Type}
import vct.col.print._

trait LLVMMultWithOverflowImpl[G] extends LLVMMultWithOverflowOps[G] {
this: LLVMMultWithOverflow[G] =>
override def instRepr: Doc =
if (signed) { Text("smul.with.overflow") }
else { Text("umul.with.overflow") }
}
12 changes: 12 additions & 0 deletions src/col/vct/col/ast/lang/llvm/LLVMSubWithOverflowImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package vct.col.ast.lang.llvm

import vct.col.ast.LLVMSubWithOverflow
import vct.col.ast.ops.LLVMSubWithOverflowOps
import vct.col.print._

trait LLVMSubWithOverflowImpl[G] extends LLVMSubWithOverflowOps[G] {
this: LLVMSubWithOverflow[G] =>
override def instRepr: Doc =
if (signed) { Text("ssub.with.overflow") }
else { Text("usub.with.overflow") }
}
5 changes: 5 additions & 0 deletions src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2212,6 +2212,7 @@ abstract class CoercingRewriter[Pre <: Generation]()
case LLVMForall(_, _) => e
case LLVMSepForall(_, _) => e
case LLVMExists(_, _) => e
case LLVMExtractValue(_, _, _, _) => e
case PVLEndpointExpr(_, _) => e
case EndpointExpr(ref, expr) => e
case ChorExpr(expr) => ChorExpr(bool(expr))
Expand Down Expand Up @@ -2330,6 +2331,10 @@ abstract class CoercingRewriter[Pre <: Generation]()
LLVMStore(value, p, ordering)(store.blame)
case fracOf: LLVMFracOf[Pre] => fracOf
case unreachable: LLVMBranchUnreachable[Pre] => unreachable
case add: LLVMAddWithOverflow[Pre] => add
case sub: LLVMSubWithOverflow[Pre] => sub
case mult: LLVMMultWithOverflow[Pre] => mult
case memset: LLVMMemset[Pre] => memset
case ModelDo(model, perm, after, action, impl) =>
ModelDo(model, rat(perm), after, action, impl)
case n @ Notify(obj) => Notify(cls(obj))(n.blame)
Expand Down
9 changes: 8 additions & 1 deletion src/col/vct/col/util/StatementToExpression.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ case class StatementToExpression[Pre <: Generation, Post <: Generation](
}
}
}
case a @ Assert(e) =>
alt match {
case Some(exprAlt) =>
Some(Asserting[Post](rw.dispatch(e), exprAlt)(a.blame)(a.o))
case None =>
throw errorBuilder("Assert may not be the last statement")
}
case _ => None
}
}
Expand Down Expand Up @@ -88,7 +95,7 @@ case class StatementToExpression[Pre <: Generation, Post <: Generation](
else
0
)
case Assign(_, _) => None
case Assert(_) => Some(0)
case _ => None
}

Expand Down
4 changes: 4 additions & 0 deletions src/llvm/include/Transform/Instruction/BinaryOpTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,9 @@ void transformBinaryOp(llvm::Instruction &llvmInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor);

void transformBitwiseXor(llvm::Instruction &llvmInstruction,
col::Assign &assignment,
pallas::FunctionCursor &funcCursor);

} // namespace llvm2col
#endif // PALLAS_BINARYOPTRANSFORM_H
61 changes: 61 additions & 0 deletions src/llvm/include/Transform/Instruction/IntrinsicsTransform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#ifndef PALLAS_INTRINSICSTRANSFORM_H
#define PALLAS_INTRINSICSTRANSFORM_H
#include "Passes/Function/FunctionBodyTransformer.h"

#include <llvm/IR/Instructions.h>

namespace llvm2col {
namespace col = vct::col::ast;

void transformIntrinsic(llvm::CallInst &callInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor);

/**
* Transform call to the @llvm.trap()-intrinsic
*/
void transformTrap(llvm::CallInst &callInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor);

/**
* Transform call to the @llvm.expect.XXX-intrinsic
*/
void transformExpect(llvm::CallInst &callInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor);

/**
* Transform call to the @llvm.sadd.with.overflow or
* @llvm.uadd.with.overflow-intrinsic
*/
void transformAddWithOverflow(llvm::CallInst &callInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor, bool sign);

/**
* Transform call to the @llvm.ssub.with.overflow or
* @llvm.usub.with.overflow-intrinsic
*/
void transformSubWithOverflow(llvm::CallInst &callInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor, bool sign);

/**
* Transform call to the @llvm.smult.with.overflow or
* @llvm.umult.with.overflow-intrinsic
*/
void transformMultWithOverflow(llvm::CallInst &callInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor, bool sign);

/**
* Transform call to the @llvm.memset.*-intrinsic
*/
void transformMemset(llvm::CallInst &callInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor);

} // namespace llvm2col

#endif // PALLAS_INTRINSICSTRANSFORM_H
4 changes: 4 additions & 0 deletions src/llvm/include/Transform/Instruction/OtherOpTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ void transformICmp(llvm::ICmpInst &icmpInstruction,
void transformCmpExpr(llvm::CmpInst &cmpInstruction, auto &colCompareExpr,
pallas::FunctionCursor &funcCursor);

void transformExtractValueInst(llvm::ExtractValueInst &llvmInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor);

void transformCallExpr(llvm::CallInst &callInstruction,
col::LlvmBasicBlock &colBlock,
pallas::FunctionCursor &funcCursor);
Expand Down
8 changes: 8 additions & 0 deletions src/llvm/include/Transform/LoopContractTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ bool addInvariantToContract(llvm::MDNode &invMD, llvm::Loop &llvmLoop,
llvm::MDNode &contractLoc,
pallas::FunctionCursor &functionCursor);

/**
* Build argument-expression that dereferences the value of the given alloca.
*/
void buildArgExprFromAlloca(col::LlvmFunctionInvocation &wrapperCall,
unsigned int argIdx, llvm::AllocaInst &llvmAlloca,
llvm::Function &llvmWFunc, llvm::MDNode &srcLoc,
pallas::FunctionCursor &functionCursor);

} // namespace llvm2col

#endif // PALLAS_LOOPCONTRACTTRANSFORM_H
15 changes: 15 additions & 0 deletions src/llvm/include/Transform/Transform.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef PALLAS_TRANSFORM_H
#define PALLAS_TRANSFORM_H

#include <llvm/IR/DerivedTypes.h>
#include <llvm/Support/Casting.h>

#include "Origin/OriginProvider.h"
#include "Passes/Function/FunctionBodyTransformer.h"

Expand Down Expand Up @@ -70,6 +73,18 @@ void transformBinExpr(llvm::Instruction &llvmInstruction,
llvm2col::transformAndSetExpr(funcCursor, llvmInstruction,
*llvmInstruction.getOperand(1), *rExpr);
}
template <class ColBinBitExpr>
void transformBitwiseBinExpr(llvm::Instruction &llvmInstruction,
ColBinBitExpr &colBinBitExpr,
pallas::FunctionCursor &funcCursor) {
llvm::IntegerType *ty =
llvm::cast<llvm::IntegerType>(llvmInstruction.getType());
transformBinExpr(llvmInstruction, colBinBitExpr, funcCursor);
colBinBitExpr.set_allocated_blame(new col::Blame());
colBinBitExpr.set_bits(ty->getBitWidth());
colBinBitExpr.set_signed_(true);
// TODO: Figure out what to put for signed
}

template <class IDNode> int64_t setColNodeId(IDNode &idNode) {
auto id = reinterpret_cast<int64_t>(idNode);
Expand Down
6 changes: 6 additions & 0 deletions src/llvm/include/Util/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ const std::string VC_PREFIX = "VC.";
const std::string METADATA_PURE_KEYWORD = VC_PREFIX + "pure";
const std::string METADATA_CONTRACT_KEYWORD = VC_PREFIX + "contract";
const std::string METADATA_GLOBAL_KEYWORD = VC_PREFIX + "global";

// Other constants
const std::string SWIFT_ENTRY_SECTION = "swift5_entry";
const std::string SWIFT_FATAL_ERROR =
"$ss10fatalError_4file4lines5NeverOSSyXK_s12StaticStringVSutFfA_SSycfu_";

} // namespace pallas::constants

#endif // PALLAS_CONSTANTS_H
11 changes: 11 additions & 0 deletions src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "Transform/BlockTransform.h"
#include "Transform/Transform.h"
#include "Util/BlockUtils.h"
#include "Util/Constants.h"
#include "Util/Exceptions.h"
#include <llvm/Support/raw_ostream.h>

Expand Down Expand Up @@ -203,6 +204,16 @@ FunctionBodyTransformerPass::run(Function &F, FunctionAnalysisManager &FAM) {
funcCursor.addVariableMapEntry(
A, FAM.getResult<FunctionDeclarer>(F).getFuncArgMapEntry(A));
}

// Skip the body of the fatalError-function from Swift
// (As it uses currently unsupported instructions). We generate an
// requires false; - contract instead.
if (F.getName().str() == constants::SWIFT_FATAL_ERROR) {
ErrorReporter::addWarning(SOURCE_LOC,
"Skipping body of swift fatalError", F);
return PreservedAnalyses::all();
}

// start recursive block code gen with basic block
llvm::BasicBlock &entryBlock = F.getEntryBlock();
llvm2col::transformLLVMBlock(entryBlock, funcCursor);
Expand Down
12 changes: 12 additions & 0 deletions src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ FunctionContractDeclarerPass::run(Function &F, FunctionAnalysisManager &FAM) {
.mutable_vcllvm_function_contract();
colContract->set_allocated_blame(new col::Blame());
colContract->set_name(F.getName());

// Add a "requires false"-contract to Swift's fatalError-function
// (Since we currently do not support all instructions that it uses).
if (F.getName().str() == constants::SWIFT_FATAL_ERROR) {
ErrorReporter::addWarning(SOURCE_LOC,
"Generating contract forswift fatalError", F);
colContract->set_value("requires false;");
colContract->set_allocated_origin(
llvm2col::generateFunctionContractOrigin(F, "requires false;"));
return PreservedAnalyses::all();
}

// check if contract keyword is present
if (!F.hasMetadata(pallas::constants::METADATA_CONTRACT_KEYWORD)) {
// set contract to a tautology
Expand Down
Loading