Skip to content

Commit

Permalink
updated pysat solver wrapper, added solving algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
alpavlenko committed Dec 7, 2023
1 parent 6575d58 commit 68de1d6
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 107 deletions.
2 changes: 2 additions & 0 deletions core/impl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .combine import *
from .solving import *
from .optimize import *
from .combine_t import *
from .growing_t import *

cores = {
Solving.slug: Solving,
Combine.slug: Combine,
Optimize.slug: Optimize,
CombineT.slug: CombineT,
Expand Down
47 changes: 47 additions & 0 deletions core/impl/solving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Dict, Any
from time import time as now

from ..abc import Core

from space.model import Backdoor
from lib_satprob.solver import Report
from lib_satprob.derived import get_derived_by
from lib_satprob.variables import Supplements


class Solving(Core):
slug = 'core:solving'

def launch(self, *backdoors: Backdoor) -> Report:
stamp, formula, = now(), self.problem.encoding.get_formula()
assumptions_set, constraints_set, all_stats = set(), set(), {}

def add_supplements(_supplements: Supplements):
_assumptions, _constraints = _supplements
assumptions_set.update(set(_assumptions))
for _clause in map(tuple, _constraints):
constraints_set.add(_clause)

def get_report(_status, _stats, _model) -> Report:
_time = _stats.get('time', 0.) + now() - stamp
return Report(_status, {'time': _time}, _model)

with self.problem.solver.get_instance(formula) as solver:
for backdoor, easy, hard in [(bd, [], []) for bd in backdoors]:
for supplements in backdoor.enumerate():
status, stats, model, _ = solver.propagate(supplements)
(easy if status is False else hard).append(supplements)
if status is True: return get_report(status, {}, model)

if len(hard) == 0: return get_report(False, {}, None)
add_supplements(
hard[0] if len(hard) == 1 else get_derived_by(easy)
)

assumptions = list(assumptions_set)
constraints = [list(c) for c in constraints_set]
report = solver.solve((assumptions, constraints))
return get_report(report.status, report.stats, report.model)

def __config__(self) -> Dict[str, Any]:
return {}
21 changes: 14 additions & 7 deletions lib_satprob/derived.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List

from lib_satprob.encoding import Clauses
from lib_satprob.variables import Supplements

try:
Expand All @@ -18,25 +19,31 @@ def to_dnf_clause(cube: Supplements) -> 'AndOp':
return And(*map(to_dnf_var, cube[0]))


def get_derived_by(easy: List[Supplements]) -> Supplements:
def _get_derived_by(easy: List[Supplements]) -> Clauses:
dnf = Or(*map(to_dnf_clause, easy))
(min_dnf,) = espresso_exprs(dnf)
min_cnf = (~min_dnf).to_cnf()

one_lit, constraints = [], []
lit_map, _, cnf = min_cnf.encode_cnf()

def map_lit(lit: int) -> int:
var = lit_map[abs(lit)].indices[0]
return var if lit > 0 else -var

clauses = [map(map_lit, cl) for cl in cnf]
for clause in map(lambda x: sorted(x, key=abs), clauses):
(one_lit if len(clause) == 1 else constraints).append(clause)
# noinspection PyTypeChecker
return [sorted(map(map_lit, cl), key=abs) for cl in cnf]


def get_derived_by(easy: List[Supplements]) -> Supplements:
one_lit, constraints = [], []
for clause in _get_derived_by(easy):
(one_lit if len(clause) == 1
else constraints).append(clause)

return [clause[0] for clause in one_lit], constraints
return [cl[0] for cl in one_lit], constraints


__all__ = [
'get_derived_by'
'get_derived_by',
'_get_derived_by'
]
4 changes: 3 additions & 1 deletion lib_satprob/encoding/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@
'WCNFPlus',
# types
'Clause',
'Clauses'
'Clauses',
# utility
'wcnf_to_cnf'
]
84 changes: 47 additions & 37 deletions lib_satprob/encoding/impl/pysat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pysat import formula
from typing import Any, List, Dict, Optional
from pysat import formula as fml
from typing import Any, List, Dict, Optional, Union

from ..encoding import Encoding
from .._readers import PySatReader, CNFReader, \
Expand All @@ -9,6 +9,25 @@
Clauses = List[Clause]


def wcnf_to_cnf(
wcnf: fml.WCNF,
only_hard: bool = True
) -> Union[fml.CNF, fml.CNFPlus]:
if isinstance(wcnf, fml.WCNFPlus):
cnf = fml.CNFPlus()
cnf.atmosts = wcnf.atms
else:
cnf = fml.CNF()

cnf.nv = wcnf.nv
cnf.clauses = wcnf.hard
cnf.comments = wcnf.comments
if not only_hard:
cnf.clauses += wcnf.soft

return cnf


class PySatEnc(Encoding):
slug = 'encoding:pysat'

Expand Down Expand Up @@ -58,24 +77,19 @@ def _get_formula_flags(self) -> str:
return 'h' if self.extract_hard \
else super()._get_formula_flags()

def _get_formula(self) -> formula.CNF:
def _get_formula(self) -> fml.CNF:
if self.from_clauses:
return formula.CNF(
return fml.CNF(
from_clauses=self.from_clauses
)

_formula = super()._get_formula()
if isinstance(_formula, formula.WCNF):
cnf = formula.CNF()
cnf.nv = _formula.nv
cnf.clauses = _formula.hard
cnf.comments = _formula.comments
if not self.extract_hard:
cnf.clauses += _formula.soft

return cnf
elif isinstance(_formula, formula.CNF):
return _formula
formula = super()._get_formula()
if isinstance(formula, fml.WCNF):
return wcnf_to_cnf(
formula, self.extract_hard
)
elif isinstance(formula, fml.CNF):
return formula

def weighted(self) -> 'WCNF':
return WCNF().set_reader(self._reader)
Expand Down Expand Up @@ -104,20 +118,14 @@ def _get_formula_flags(self) -> str:
return 'h' if self.extract_hard \
else super()._get_formula_flags()

def _get_formula(self) -> formula.CNFPlus:
_formula = super()._get_formula()
if isinstance(_formula, formula.WCNFPlus):
cnf = formula.CNFPlus()
cnf.nv = _formula.nv
cnf.atmosts = _formula.atms
cnf.clauses = _formula.hard
cnf.comments = _formula.comments
if not self.extract_hard:
cnf.clauses += _formula.soft

return cnf
elif isinstance(_formula, formula.CNFPlus):
return _formula
def _get_formula(self) -> fml.CNFPlus:
formula = super()._get_formula()
if isinstance(formula, fml.WCNFPlus):
return wcnf_to_cnf(
formula, self.extract_hard
)
elif isinstance(formula, fml.CNFPlus):
return formula

def weighted(self) -> 'WCNFPlus':
return WCNFPlus().set_reader(self._reader)
Expand All @@ -139,11 +147,11 @@ def __init__(
from_file, from_string, comment_lead
) if from_file or from_string else None)

def _get_formula(self) -> formula.WCNF:
def _get_formula(self) -> fml.WCNF:
_formula = super()._get_formula()
if isinstance(_formula, formula.CNF):
if isinstance(_formula, fml.CNF):
return _formula.weighted()
elif isinstance(_formula, formula.WCNF):
elif isinstance(_formula, fml.WCNF):
return _formula

def from_hard(self) -> 'CNF':
Expand Down Expand Up @@ -171,11 +179,11 @@ def __init__(
from_file, from_string, comment_lead
) if from_file or from_string else None)

def _get_formula(self) -> formula.WCNFPlus:
def _get_formula(self) -> fml.WCNFPlus:
_formula = super()._get_formula()
if isinstance(_formula, formula.CNFPlus):
if isinstance(_formula, fml.CNFPlus):
return _formula.weighted()
elif isinstance(_formula, formula.WCNFPlus):
elif isinstance(_formula, fml.WCNFPlus):
return _formula

def from_hard(self) -> 'CNFPlus':
Expand All @@ -197,5 +205,7 @@ def __copy__(self) -> 'WCNFPlus':
'WCNFPlus',
# types
'Clause',
'Clauses'
'Clauses',
# utility
'wcnf_to_cnf'
]
Loading

0 comments on commit 68de1d6

Please sign in to comment.