Skip to content

Commit

Permalink
Support user defined operators in SMT
Browse files Browse the repository at this point in the history
  • Loading branch information
pgebal committed Sep 29, 2023
1 parent cc7a14a commit 772ecc0
Show file tree
Hide file tree
Showing 14 changed files with 267 additions and 244 deletions.
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Language Features:

Compiler Features:
* Parser: Remove the experimental error recovery mode (``--error-recovery`` / ``settings.parserErrorRecovery``).
* SMTChecker: Support user-defined operators.
* Yul Optimizer: If ``PUSH0`` is supported, favor zero literals over storing zero values in variables.
* Yul Optimizer: Run the ``Rematerializer`` and ``UnusedPruner`` steps at the end of the default clean-up sequence.

Expand Down
3 changes: 3 additions & 0 deletions libsolidity/ast/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -2111,6 +2111,7 @@ class UnaryOperation: public Expression
Token getOperator() const { return m_operator; }
bool isPrefixOperation() const { return m_isPrefix; }
Expression const& subExpression() const { return *m_subExpression; }
ASTPointer<Expression> const& argument() const { return m_subExpression; }

FunctionType const* userDefinedFunctionType() const;

Expand Down Expand Up @@ -2145,6 +2146,8 @@ class BinaryOperation: public Expression

Expression const& leftExpression() const { return *m_left; }
Expression const& rightExpression() const { return *m_right; }
ASTPointer<Expression> leftArgument() const { return m_left; }
ASTPointer<Expression> rightArgument() const { return m_right; }
Token getOperator() const { return m_operator; }

FunctionType const* userDefinedFunctionType() const;
Expand Down
62 changes: 51 additions & 11 deletions libsolidity/formal/BMC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,15 @@ void BMC::endVisit(UnaryOperation const& _op)
{
SMTEncoder::endVisit(_op);

// User-defined operators are essentially function calls.
if (auto funDef = *_op.annotation().userDefinedFunction)
{
std::vector<ASTPointer<Expression const>> arguments;
arguments.push_back(_op.argument());
inlineFunctionCall(funDef, _op, &_op, _op.userDefinedFunctionType(), arguments);
return;
}

if (
_op.annotation().type->category() == Type::Category::RationalNumber ||
_op.annotation().type->category() == Type::Category::FixedPoint
Expand All @@ -565,6 +574,19 @@ void BMC::endVisit(UnaryOperation const& _op)
);
}

void BMC::endVisit(BinaryOperation const& _op)
{
SMTEncoder::endVisit(_op);

if (auto funDef = *_op.annotation().userDefinedFunction)
{
std::vector<ASTPointer<Expression const>> arguments;
arguments.push_back(_op.leftArgument());
arguments.push_back(_op.rightArgument());
inlineFunctionCall(funDef, _op, &_op, _op.userDefinedFunctionType(), arguments);
}
}

void BMC::endVisit(FunctionCall const& _funCall)
{
auto functionCallKind = *_funCall.annotation().kind;
Expand Down Expand Up @@ -674,15 +696,21 @@ void BMC::visitAddMulMod(FunctionCall const& _funCall)
SMTEncoder::visitAddMulMod(_funCall);
}

void BMC::inlineFunctionCall(FunctionCall const& _funCall)
void BMC::inlineFunctionCall(
FunctionDefinition const* _funDef,
Expression const& _callStackExpr,
Expression const* _calledExpr,
FunctionType const* _funType,
std::vector<ASTPointer<Expression const>> _arguments
)
{
solAssert(shouldInlineFunctionCall(_funCall, currentScopeContract(), m_currentContract), "");
auto funDef = functionCallToDefinition(_funCall, currentScopeContract(), m_currentContract);
solAssert(funDef, "");
solAssert(_funDef, "");
solAssert(_funType, "");
solAssert(_calledExpr, "");

if (visitedFunction(funDef))
if (visitedFunction(_funDef))
{
auto const& returnParams = funDef->returnParameters();
auto const& returnParams = _funDef->returnParameters();
for (auto param: returnParams)
{
m_context.newValue(*param);
Expand All @@ -691,19 +719,31 @@ void BMC::inlineFunctionCall(FunctionCall const& _funCall)
}
else
{
initializeFunctionCallParameters(*funDef, symbolicArguments(_funCall, m_currentContract));
initializeFunctionCallParameters(*_funDef, symbolicArguments(_funDef, _calledExpr, _funType, _arguments));

// The reason why we need to pushCallStack here instead of visit(FunctionDefinition)
// is that there we don't have `_funCall`.
pushCallStack({funDef, &_funCall});
// is that there we don't have `_callStackExpr`.
pushCallStack({_funDef, &_callStackExpr});
pushPathCondition(currentPathConditions());
auto oldChecked = std::exchange(m_checked, true);
funDef->accept(*this);
_funDef->accept(*this);
m_checked = oldChecked;
popPathCondition();
}

createReturnedExpressions(_funCall, m_currentContract);
createReturnedExpressions(_funDef, _callStackExpr);
}

void BMC::inlineFunctionCall(FunctionCall const& _funCall)
{
solAssert(shouldInlineFunctionCall(_funCall, currentScopeContract(), m_currentContract), "");

auto funDef = functionCallToDefinition(_funCall, currentScopeContract(), m_currentContract);
Expression const* calledExpr = &_funCall.expression();
auto funType = dynamic_cast<FunctionType const*>(calledExpr->annotation().type);
std::vector<ASTPointer<Expression const>> arguments = _funCall.sortedArguments();

inlineFunctionCall(funDef, _funCall, calledExpr, funType, arguments);
}

void BMC::internalOrExternalFunctionCall(FunctionCall const& _funCall)
Expand Down
8 changes: 8 additions & 0 deletions libsolidity/formal/BMC.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class BMC: public SMTEncoder
bool visit(WhileStatement const& _node) override;
bool visit(ForStatement const& _node) override;
void endVisit(UnaryOperation const& _node) override;
void endVisit(BinaryOperation const& _node) override;
void endVisit(FunctionCall const& _node) override;
void endVisit(Return const& _node) override;
bool visit(TryStatement const& _node) override;
Expand All @@ -113,6 +114,13 @@ class BMC: public SMTEncoder
/// Visits the FunctionDefinition of the called function
/// if available and inlines the return value.
void inlineFunctionCall(FunctionCall const& _funCall);
void inlineFunctionCall(
FunctionDefinition const* _funDef,
Expression const& _callStackExpr,
Expression const* _calledExpr,
FunctionType const* _funType,
std::vector<ASTPointer<Expression const>> _arguments
);
/// Inlines if the function call is internal or external to `this`.
/// Erases knowledge about state variables if external.
void internalOrExternalFunctionCall(FunctionCall const& _funCall);
Expand Down
137 changes: 104 additions & 33 deletions libsolidity/formal/CHC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,35 @@ void CHC::endVisit(ForStatement const& _for)
m_scopes.pop_back();
}

void CHC::endVisit(UnaryOperation const& _op)
{
SMTEncoder::endVisit(_op);

if (auto funDef = *_op.annotation().userDefinedFunction)
{
std::vector<ASTPointer<Expression const>> arguments;
arguments.push_back(_op.argument());
internalFunctionCall(funDef, &_op, _op.userDefinedFunctionType(), arguments, state().thisAddress());

createReturnedExpressions(funDef, _op);
}
}

void CHC::endVisit(BinaryOperation const& _op)
{
SMTEncoder::endVisit(_op);

if (auto funDef = *_op.annotation().userDefinedFunction)
{
std::vector<ASTPointer<Expression const>> arguments;
arguments.push_back(_op.leftArgument());
arguments.push_back(_op.rightArgument());
internalFunctionCall(funDef, &_op, _op.userDefinedFunctionType(), arguments, state().thisAddress());

createReturnedExpressions(funDef, _op);
}
}

void CHC::endVisit(FunctionCall const& _funCall)
{
auto functionCallKind = *_funCall.annotation().kind;
Expand Down Expand Up @@ -593,8 +622,8 @@ void CHC::endVisit(FunctionCall const& _funCall)
break;
}


createReturnedExpressions(_funCall, m_currentContract);
auto funDef = functionCallToDefinition(_funCall, currentScopeContract(), m_currentContract);
createReturnedExpressions(funDef, _funCall);
}

void CHC::endVisit(Break const& _break)
Expand Down Expand Up @@ -820,20 +849,67 @@ void CHC::visitDeployment(FunctionCall const& _funCall)
defineExpr(_funCall, newAddr);
}

void CHC::internalFunctionCall(
FunctionDefinition const* _funDef,
Expression const* _calledExpr,
FunctionType const* _funType,
std::vector<ASTPointer<Expression const>> _arguments,
smtutil::Expression _contractAddressValue
)
{
solAssert(m_currentContract, "");
solAssert(_calledExpr, "");
solAssert(_funType, "");

if (_funDef)
{
if (m_currentFunction && !m_currentFunction->isConstructor())
m_callGraph[m_currentFunction].insert(_funDef);
else
m_callGraph[m_currentContract].insert(_funDef);
}

m_context.addAssertion(predicate(_funDef, _calledExpr, _funType, _arguments, _contractAddressValue));

solAssert(m_errorDest, "");
connectBlocks(
m_currentBlock,
predicate(*m_errorDest),
errorFlag().currentValue() > 0 && currentPathConditions()
);
m_context.addAssertion(smtutil::Expression::implies(currentPathConditions(), errorFlag().currentValue() == 0));
m_context.addAssertion(errorFlag().increaseIndex() == 0);
}

void CHC::internalFunctionCall(FunctionCall const& _funCall)
{
solAssert(m_currentContract, "");

auto function = functionCallToDefinition(_funCall, currentScopeContract(), m_currentContract);
if (function)
auto funDef = functionCallToDefinition(_funCall, currentScopeContract(), m_currentContract);
if (funDef)
{
if (m_currentFunction && !m_currentFunction->isConstructor())
m_callGraph[m_currentFunction].insert(function);
m_callGraph[m_currentFunction].insert(funDef);
else
m_callGraph[m_currentContract].insert(function);
m_callGraph[m_currentContract].insert(funDef);
}

m_context.addAssertion(predicate(_funCall));
Expression const* calledExpr = &_funCall.expression();
auto funType = dynamic_cast<FunctionType const*>(calledExpr->annotation().type);
std::vector<ASTPointer<Expression const>> arguments = _funCall.sortedArguments();

auto contractAddressValue = [this](FunctionCall const& _f) {
auto [callExpr, callOptions] = functionCallExpression(_f);

FunctionType const& funType = dynamic_cast<FunctionType const&>(*callExpr->annotation().type);
if (funType.kind() == FunctionType::Kind::Internal)
return state().thisAddress();
if (MemberAccess const* callBase = dynamic_cast<MemberAccess const*>(callExpr))
return expr(callBase->expression());
solAssert(false, "Unreachable!");
};

m_context.addAssertion(predicate(funDef, calledExpr, funType, arguments, contractAddressValue(_funCall)));

solAssert(m_errorDest, "");
connectBlocks(
Expand Down Expand Up @@ -1028,7 +1104,7 @@ void CHC::externalFunctionCallToTrustedCode(FunctionCall const& _funCall)
state().readStateVars(*function->annotation().contract, contractAddressValue(_funCall));
}

smtutil::Expression pred = predicate(_funCall);
smtutil::Expression pred = predicate(function, callExpr, &funType, _funCall.sortedArguments(), calledAddress);

auto txConstraints = state().txTypeConstraints() && state().txFunctionConstraints(*function);
m_context.addAssertion(pred && txConstraints);
Expand Down Expand Up @@ -1733,40 +1809,35 @@ smtutil::Expression CHC::predicate(Predicate const& _block)
solAssert(false, "");
}

smtutil::Expression CHC::predicate(FunctionCall const& _funCall)
smtutil::Expression CHC::predicate(
FunctionDefinition const* _funDef,
Expression const* _calledExpr,
FunctionType const* _funType,
std::vector<ASTPointer<Expression const>> _arguments,
smtutil::Expression _contractAddressValue
)
{
FunctionType const& funType = dynamic_cast<FunctionType const&>(*_funCall.expression().annotation().type);
auto kind = funType.kind();
solAssert(_calledExpr, "");
solAssert(_funType, "");
auto kind = _funType->kind();
solAssert(kind == FunctionType::Kind::Internal || kind == FunctionType::Kind::External || kind == FunctionType::Kind::BareStaticCall, "");

solAssert(m_currentContract, "");
auto function = functionCallToDefinition(_funCall, currentScopeContract(), m_currentContract);
if (!function)
if (!_funDef)
return smtutil::Expression(true);

auto contractAddressValue = [this](FunctionCall const& _f) {
auto [callExpr, callOptions] = functionCallExpression(_f);

FunctionType const& funType = dynamic_cast<FunctionType const&>(*callExpr->annotation().type);
if (funType.kind() == FunctionType::Kind::Internal)
return state().thisAddress();
if (MemberAccess const* callBase = dynamic_cast<MemberAccess const*>(callExpr))
return expr(callBase->expression());
solAssert(false, "Unreachable!");
};
errorFlag().increaseIndex();
std::vector<smtutil::Expression> args{errorFlag().currentValue(), contractAddressValue(_funCall), state().abi(), state().crypto(), state().tx(), state().state()};

auto const* contract = function->annotation().contract;
std::vector<smtutil::Expression> args{errorFlag().currentValue(), _contractAddressValue, state().abi(), state().crypto(), state().tx(), state().state()};

auto const* contract = _funDef->annotation().contract;
auto const& hierarchy = m_currentContract->annotation().linearizedBaseContracts;
solAssert(kind != FunctionType::Kind::Internal || function->isFree() || (contract && contract->isLibrary()) || util::contains(hierarchy, contract), "");
solAssert(kind != FunctionType::Kind::Internal || _funDef->isFree() || (contract && contract->isLibrary()) || util::contains(hierarchy, contract), "");

if (kind == FunctionType::Kind::Internal)
contract = m_currentContract;

args += currentStateVariables(*contract);
args += symbolicArguments(_funCall, contract);
if (!usesStaticCall(_funCall))
args += symbolicArguments(_funDef, _calledExpr, _funType, _arguments);
if (!((_funDef && (_funDef->stateMutability() == StateMutability::Pure || _funDef->stateMutability() == StateMutability::View)) || kind == FunctionType::Kind::BareStaticCall))
{
state().newState();
for (auto const& var: stateVariablesIncludingInheritedAndPrivate(*contract))
Expand All @@ -1775,7 +1846,7 @@ smtutil::Expression CHC::predicate(FunctionCall const& _funCall)
args += std::vector<smtutil::Expression>{state().state()};
args += currentStateVariables(*contract);

for (auto var: function->parameters() + function->returnParameters())
for (auto var: _funDef->parameters() + _funDef->returnParameters())
{
if (m_context.knownVariable(*var))
m_context.variable(*var)->increaseIndex();
Expand All @@ -1784,10 +1855,10 @@ smtutil::Expression CHC::predicate(FunctionCall const& _funCall)
args.push_back(currentValue(*var));
}

Predicate const& summary = *m_summaries.at(contract).at(function);
Predicate const& summary = *m_summaries.at(contract).at(_funDef);
auto from = smt::function(summary, contract, m_context);
Predicate const& callPredicate = *createSummaryBlock(
*function,
*_funDef,
*contract,
kind == FunctionType::Kind::Internal ? PredicateType::InternalCall : PredicateType::ExternalCallTrusted
);
Expand Down
17 changes: 16 additions & 1 deletion libsolidity/formal/CHC.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class CHC: public SMTEncoder
bool visit(ForStatement const&) override;
void endVisit(ForStatement const&) override;
void endVisit(FunctionCall const& _node) override;
void endVisit(BinaryOperation const& _op) override;
void endVisit(UnaryOperation const& _op) override;
void endVisit(Break const& _node) override;
void endVisit(Continue const& _node) override;
void endVisit(IndexRangeAccess const& _node) override;
Expand All @@ -127,6 +129,13 @@ class CHC: public SMTEncoder
void visitAddMulMod(FunctionCall const& _funCall) override;
void visitDeployment(FunctionCall const& _funCall);
void internalFunctionCall(FunctionCall const& _funCall);
void internalFunctionCall(
FunctionDefinition const* _funDef,
Expression const* _calledExpr,
FunctionType const* _funType,
std::vector<ASTPointer<Expression const>> _arguments,
smtutil::Expression _contractAddressValue
);
void externalFunctionCall(FunctionCall const& _funCall);
void externalFunctionCallToTrustedCode(FunctionCall const& _funCall);
void addNondetCalls(ContractDefinition const& _contract);
Expand Down Expand Up @@ -246,7 +255,13 @@ class CHC: public SMTEncoder
/// @returns a predicate application after checking the predicate's type.
smtutil::Expression predicate(Predicate const& _block);
/// @returns the summary predicate for the called function.
smtutil::Expression predicate(FunctionCall const& _funCall);
smtutil::Expression predicate(
FunctionDefinition const* _funDef,
Expression const* _calledExpr,
FunctionType const* _funType,
std::vector<ASTPointer<Expression const>> _arguments,
smtutil::Expression _contractAddressValue
);
/// @returns a predicate that defines a contract initializer for _contract in the context of _contractContext.
smtutil::Expression initializer(ContractDefinition const& _contract, ContractDefinition const& _contractContext);
/// @returns a predicate that defines a constructor summary.
Expand Down
Loading

0 comments on commit 772ecc0

Please sign in to comment.