Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some access levels in MainSolver and add some checks to public … #689

Merged
merged 3 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/api/MainSolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

namespace opensmt { bool stop; }

MainSolver::MainSolver(Logic& logic, SMTConfig& conf, std::string name)
MainSolver::MainSolver(Logic & logic, SMTConfig & conf, std::string name)
:
theory(createTheory(logic, conf)),
term_mapper(new TermMapper(logic)),
Expand Down Expand Up @@ -103,10 +103,6 @@ bool MainSolver::pop()
return false;
}

PartitionManager & MainSolver::getPartitionManager() { return pmanager; }

sstat MainSolver::getStatus() const { return status; }

void MainSolver::insertFormula(PTRef fla)
{
if (logic.getSortRef(fla) != logic.getSort_bool()) {
Expand Down Expand Up @@ -214,6 +210,7 @@ PTRef MainSolver::rewriteMaxArity(PTRef root)
}

std::unique_ptr<Model> MainSolver::getModel() {
if (!config.produce_models()) { throw OsmtApiException("Producing models is not enabled"); }
if (status != s_True) { throw OsmtApiException("Model cannot be created if solver is not in SAT state"); }

ModelBuilder modelBuilder {logic};
Expand Down Expand Up @@ -242,6 +239,7 @@ lbool MainSolver::getTermValue(PTRef tr) const {
}

std::unique_ptr<InterpolationContext> MainSolver::getInterpolationContext() {
if (!config.produce_inter()) { throw OsmtApiException("Producing interpolants is not enabled"); }
if (status != s_False) { throw OsmtApiException("Interpolation context cannot be created if solver is not in UNSAT state"); }
return std::make_unique<InterpolationContext>(
config, *theory, *term_mapper, getSMTSolver().getResolutionProof(), pmanager
Expand Down
178 changes: 91 additions & 87 deletions src/api/MainSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
class Logic;

class sstat {
char value;
public:
public:
explicit sstat(int v) : value(v) {}
bool operator == (sstat s) const { return value == s.value; }
bool operator != (sstat s) const { return value != s.value; }
Expand All @@ -39,6 +38,8 @@ class sstat {
}
char getValue() const { return value; }
friend sstat toSstat(int v);
private:
char value;
};

inline sstat toSstat(int v) {return sstat(v); }
Expand All @@ -50,31 +51,76 @@ const sstat s_Error = toSstat( 2);


class MainSolver {
protected: /** Helper classes to deal with assertion stack, preprocessing and substitutions **/
public:
MainSolver(Logic & logic, SMTConfig & conf, std::string name);

MainSolver(std::unique_ptr<Theory> th, std::unique_ptr<TermMapper> tm, std::unique_ptr<THandler> thd,
std::unique_ptr<SimpSMTSolver> ss, Logic & logic, SMTConfig & conf, std::string name);

virtual ~MainSolver() = default;
MainSolver (const MainSolver&) = delete;
MainSolver& operator = (const MainSolver&) = delete;
MainSolver (MainSolver&&) = default;
MainSolver& operator = (MainSolver&&) = delete;

SMTConfig & getConfig() const { return config; }
Logic & getLogic() const { return logic; }

SimpSMTSolver & getSMTSolver() { return *smt_solver; }
SimpSMTSolver const & getSMTSolver() const { return *smt_solver; }

THandler & getTHandler() { return *thandler; }
THandler const & getTHandler() const { return *thandler; }

void push();
bool pop();
void insertFormula(PTRef fla);

void initialize();

virtual sstat check(); // A wrapper for solve which simplifies the loaded formulas and initializes the solvers
sstat solve();
// Simplify frames (not yet simplified) until all are simplified or the instance is detected unsatisfiable.
sstat simplifyFormulas();

void printFramesAsQuery() const;
[[nodiscard]] sstat getStatus() const { return status; }

// Values
lbool getTermValue (PTRef tr) const;

// Returns model of the last query (must be in satisfiable state)
std::unique_ptr<Model> getModel();

// Prints proof of the last query (must be in unsatisfiable state)
void printResolutionProofSMT2() const;

// Returns interpolation context for the last query (must be in UNSAT state)
std::unique_ptr<InterpolationContext> getInterpolationContext();

void stop() { smt_solver->stop = true; }

static std::unique_ptr<Theory> createTheory(Logic & logic, SMTConfig & config);
protected:
using FrameId = uint32_t;

struct PushFrame {
private:
FrameId id;

public:
FrameId getId() const { return id; }
int size() const { return formulas.size(); }
void push(PTRef tr) { formulas.push(tr); }
PTRef operator[](int i) const { return formulas[i]; }
vec<PTRef> formulas;
bool unsat; // If true then the stack of frames with this frame at top is UNSAT
bool unsat{false}; // If true then the stack of frames with this frame at top is UNSAT

PushFrame(PushFrame const &) = delete;
PushFrame(PushFrame &&) = default;
explicit PushFrame(uint32_t id) : id(id), unsat(false) {}
explicit PushFrame(uint32_t id) : id(id) {}
private:
FrameId id;
};

class AssertionStack {
private:
std::vector<PushFrame> frames;
uint32_t frameId = 0;

public:
[[nodiscard]] PushFrame const & last() const {
assert(not frames.empty());
Expand Down Expand Up @@ -104,11 +150,12 @@ class MainSolver {
assert(frameCount() > 0);
last().push(fla);
}
private:
std::vector<PushFrame> frames;
uint32_t frameId = 0;
};

class Substitutions {
std::vector<Logic::SubstMap> perFrameSubst;

public:
void push() { perFrameSubst.emplace_back(); }
void pop() { perFrameSubst.pop_back(); }
Expand All @@ -127,15 +174,25 @@ class MainSolver {
}
return allSubst;
}
private:
std::vector<Logic::SubstMap> perFrameSubst;
};
/** Actual MainSolver members **/
protected:
AssertionStack frames;
Substitutions substitutions;
vec<PTRef> frameTerms;
std::size_t firstNotSimplifiedFrame = 0;
unsigned int insertedFormulasCount = 0;
sstat status = s_Undef; // The status of the last solver call

struct SubstitutionResult {
Logic::SubstMap usedSubstitution;
PTRef result {PTRef_Undef};
};

Theory & getTheory() { return *theory; }
Theory const & getTheory() const { return *theory; }
TermMapper & getTermMapper() const { return *term_mapper;}
PartitionManager & getPartitionManager() { return pmanager; }

[[nodiscard]] PTRef currentRootInstance() const;

void printFramesAsQuery(std::ostream & s) const;

static std::unique_ptr<SimpSMTSolver> createInnerSolver(SMTConfig& config, THandler& thandler);

PTRef newFrameTerm(FrameId frameId) {
assert(frameId != 0);
Expand Down Expand Up @@ -163,88 +220,35 @@ class MainSolver {

sstat giveToSolver(PTRef root, FrameId push_id);

struct SubstitutionResult {
Logic::SubstMap usedSubstitution;
PTRef result {PTRef_Undef};
};

PTRef applyLearntSubstitutions(PTRef fla);

PTRef substitutionPass(PTRef fla, PreprocessingContext const& context);

SubstitutionResult computeSubstitutions(PTRef fla);

public:
AssertionStack frames;

sstat status = s_Undef; // The status of the last solver call

private:

std::unique_ptr<Theory> theory;
std::unique_ptr<TermMapper> term_mapper;
std::unique_ptr<THandler> thandler;
std::unique_ptr<SimpSMTSolver> smt_solver;
Logic& logic;
Logic & logic;
PartitionManager pmanager;
SMTConfig& config;
SMTConfig & config;
Tseitin ts;

opensmt::OSMTTimeVal query_timer; // How much time we spend solving.
std::string solver_name; // Name for the solver
int check_called = 0; // A counter on how many times check was called.

sstat solve();

[[nodiscard]] PTRef currentRootInstance() const;

void printFramesAsQuery(std::ostream & s) const;

static std::unique_ptr<SimpSMTSolver> createInnerSolver(SMTConfig& config, THandler& thandler);

MainSolver(Logic& logic, SMTConfig& conf, std::string name);

MainSolver(std::unique_ptr<Theory> th, std::unique_ptr<TermMapper> tm, std::unique_ptr<THandler> thd,
std::unique_ptr<SimpSMTSolver> ss, Logic & logic, SMTConfig & conf, std::string name);

virtual ~MainSolver() = default;
MainSolver (const MainSolver&) = delete;
MainSolver& operator = (const MainSolver&) = delete;
MainSolver (MainSolver&&) = default;
MainSolver& operator = (MainSolver&&) = delete;

SMTConfig& getConfig() { return config; }
SimpSMTSolver & getSMTSolver() { return *smt_solver; }
SimpSMTSolver const & getSMTSolver() const { return *smt_solver; }

THandler &getTHandler() { return *thandler; }
Logic &getLogic() { return logic; }
Theory &getTheory() { return *theory; }
const Theory &getTheory() const { return *theory; }
PartitionManager & getPartitionManager();

void push();
bool pop();
void insertFormula(PTRef fla);

void initialize();

virtual sstat check(); // A wrapper for solve which simplifies the loaded formulas and initializes the solvers
// Simplify frames (not yet simplified) until all are simplified or the instance is detected unsatisfiable.
sstat simplifyFormulas();

void printFramesAsQuery() const;
[[nodiscard]] sstat getStatus() const;

// Values
lbool getTermValue (PTRef tr) const;

// Returns model of the last query (must be in satisfiable state)
std::unique_ptr<Model> getModel();

// Prints proof of the last query (must be in unsatisfiable state)
void printResolutionProofSMT2() const;

void stop() { smt_solver->stop = true; }

// Returns interpolation context for the last query (must be in UNSAT state)
std::unique_ptr<InterpolationContext> getInterpolationContext();

static std::unique_ptr<Theory> createTheory(Logic & logic, SMTConfig & config);
Substitutions substitutions;
vec<PTRef> frameTerms;
std::size_t firstNotSimplifiedFrame = 0;
unsigned int insertedFormulasCount = 0;
};

bool MainSolver::trackPartitions() const
Expand Down
15 changes: 9 additions & 6 deletions src/parallel/MainSplitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ void MainSplitter::notifyResult(sstat const & result)
}

sstat MainSplitter::check() {
if (getChannel().isSolverInParallelMode() and not config.sat_solver_limit()) {
if (getChannel().isSolverInParallelMode() and not getConfig().sat_solver_limit()) {
//push frames size should match with length of the solver branch
if (frames.frameCount() !=
static_cast<std::size_t>(getSplitter().getSolverBranch().size() + 1))
Expand All @@ -31,7 +31,7 @@ sstat MainSplitter::check() {
}

sstat MainSplitter::solve_(vec<FrameId> const & enabledFrames) {
if (getChannel().isSolverInParallelMode() and not config.sat_solver_limit()) {
if (getChannel().isSolverInParallelMode() and not getConfig().sat_solver_limit()) {
vec<opensmt::pair<int, int>> const & solverBranch = getSplitter().getSolverBranch();
if (enabledFrames.size() > solverBranch.size() + 1) {
throw PTPLib::common::Exception(__FILE__, __LINE__,
Expand All @@ -58,7 +58,7 @@ sstat MainSplitter::solve() {
};

void MainSplitter::writeSplits(std::string const & baseName) const {
assert(config.sat_split_type() != spt_none);
assert(getConfig().sat_split_type() != spt_none);
auto const & splits = getSplitter().getSplits();

auto splitStrings = getPartitionClauses();
Expand Down Expand Up @@ -94,6 +94,7 @@ std::unique_ptr<SimpSMTSolver> MainSplitter::createInnerSolver(SMTConfig & confi
bool MainSplitter::verifyPartitions(vec<PTRef> const & partitions) const {
bool ok = true;
std::string error;
auto & logic = getLogic();
VerificationUtils verifier(logic);
for (int i = 0; i < partitions.size(); i++) {
for (int j = i + 1; j < partitions.size(); j++) {
Expand All @@ -108,7 +109,7 @@ bool MainSplitter::verifyPartitions(vec<PTRef> const & partitions) const {
for (PTRef tr : partitions) {
partitionCoverageQuery.push(logic.mkNot(tr));
}
if (partitions.size() == config.sat_split_num()) {
if (partitions.size() == getConfig().sat_split_num()) {
// The partitions need to cover the full search space, i.e., the conjunction of the negated partitions must be unsatisfiable
if (not verifier.impliesInternal(logic.mkAnd(partitionCoverageQuery), logic.getTerm_false())) {
error += "[Non-covering partitioning: " + logic.pp(logic.mkAnd(partitionCoverageQuery)) + " is satisfiable] ";
Expand Down Expand Up @@ -137,10 +138,11 @@ bool MainSplitter::verifyPartitions(vec<PTRef> const & partitions) const {
std::vector<std::string> MainSplitter::getPartitionClauses() const {
assert(not isSplitTypeNone());
auto const & splits = getSplitter().getSplits();
auto & logic = getLogic();
vec<PTRef> partitionsTr;
partitionsTr.capacity(splits.size());
for (auto const &split : splits) {
auto conj_vec = addToConjunction(split.splitToPtAsgns(*thandler));
auto conj_vec = addToConjunction(split.splitToPtAsgns(getTHandler()));
partitionsTr.push(logic.mkAnd(conj_vec));
}

Expand All @@ -157,6 +159,7 @@ std::vector<std::string> MainSplitter::getPartitionClauses() const {

vec<PTRef> MainSplitter::addToConjunction(std::vector<vec<PtAsgn>> const & in) const {
vec<PTRef> out;
auto & logic = getLogic();
for (const auto & constr : in) {
vec<PTRef> disj_vec;
for (const auto & pta : constr) {
Expand All @@ -165,4 +168,4 @@ vec<PTRef> MainSplitter::addToConjunction(std::vector<vec<PtAsgn>> const & in) c
out.push(logic.mkOr(std::move(disj_vec)));
}
return out;
}
}
12 changes: 6 additions & 6 deletions src/parallel/MainSplitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
class MainSplitter : public MainSolver {

private:
inline bool isSplitTypeScatter() const & { return dynamic_cast<Splitter&>(*smt_solver).isSplitTypeScatter(); }
inline bool isSplitTypeScatter() const { return dynamic_cast<Splitter const &>(getSMTSolver()).isSplitTypeScatter(); }

inline bool isSplitTypeNone() const & { return dynamic_cast<Splitter&>(*smt_solver).isSplitTypeNone(); }
inline bool isSplitTypeNone() const { return dynamic_cast<Splitter const &>(getSMTSolver()).isSplitTypeNone(); }

inline PTPLib::net::Channel<PTPLib::net::SMTS_Event, PTPLib::net::Lemma> & getChannel() const { return getSplitter().getChannel(); }

inline ScatterSplitter & getScatterSplitter() { return dynamic_cast<ScatterSplitter&>(getSMTSolver()); }
inline ScatterSplitter const & getScatterSplitter() const { return dynamic_cast<ScatterSplitter const &>(getSMTSolver()); }
inline ScatterSplitter & getScatterSplitter() { return dynamic_cast<ScatterSplitter &>(getSMTSolver()); }

inline Splitter & getSplitter() const { return dynamic_cast<Splitter&>(*smt_solver); }
inline Splitter const & getSplitter() const { return dynamic_cast<Splitter const &>(getSMTSolver()); }
inline Splitter & getSplitter() { return dynamic_cast<Splitter &>(getSMTSolver()); }

void notifyResult(sstat const & result);

Expand Down Expand Up @@ -58,8 +60,6 @@ class MainSplitter : public MainSolver {
void writeSplits(std::string const &) const;

static std::unique_ptr<SimpSMTSolver> createInnerSolver(SMTConfig &, THandler &, PTPLib::net::Channel<PTPLib::net::SMTS_Event, PTPLib::net::Lemma> &);

inline TermMapper& getTermMapper() const { return *term_mapper;}
};


Expand Down
Loading