Skip to content

Commit

Permalink
updated tests, removed problem instance dependency from space module
Browse files Browse the repository at this point in the history
  • Loading branch information
alpavlenko committed Sep 27, 2023
1 parent 6dd4580 commit c2c293d
Show file tree
Hide file tree
Showing 28 changed files with 555 additions and 474 deletions.
2 changes: 1 addition & 1 deletion core/impl/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, space: Space, logger: Logger, problem: Problem,
def launch(self, *args, **kwargs) -> PointSet:
start_stamp = now()
with self.logger:
initial = self.space.get_initial(self.problem)
initial = self.space.get_initial()
self.logger.meta(initial, self.comparator)
# todo: search root estimation in cache
point, handles = self.estimate(initial).result(), []
Expand Down
6 changes: 3 additions & 3 deletions function/impl/function_gad.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def gad_supplements(args: WorkerArgs, problem: Problem,

def gad_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult:
space, budget, measure, problem, bytemask = payload
searchable, timestamp = space.unpack(problem, bytemask), now()
searchable, timestamp = space.unpack(bytemask), now()

# limit = measure.get_limit(budget)
times, times2, values, values2 = {}, {}, {}, {}
Expand Down Expand Up @@ -75,9 +75,9 @@ def get_worker_fn(self) -> WorkerCallable:
def calculate(self, searchable: Searchable, results: Results) -> Estimation:
times, values, statuses, stats = aggregate_results(results)
time_sum, value_sum = sum(times.values()), sum(values.values())
power, value = searchable.power(), value_sum if stats.count else float(
'inf')

power = searchable.power()
value = value_sum if stats.count else float('inf')
if stats.count > 0 and stats.count != power:
value = float(value_sum) / stats.count * power

Expand Down
2 changes: 1 addition & 1 deletion function/impl/function_ibs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def ibs_supplements(args: WorkerArgs, problem: Problem,

def ibs_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult:
space, budget, measure, problem, bytemask = payload
backdoor, timestamp = space.unpack(problem, bytemask), now()
backdoor, timestamp = space.unpack(bytemask), now()

limit = measure.get_limit(budget)
times, times2, values, values2 = {}, {}, {}, {}
Expand Down
2 changes: 1 addition & 1 deletion function/impl/function_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def ips_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult:
space, budget, measure, problem, bytemask = payload
searchable, timestamp = space.unpack(problem, bytemask), now()
searchable, timestamp = space.unpack(bytemask), now()

times, times2, values, values2 = {}, {}, {}, {}
formula, statuses = problem.encoding.get_formula(), {}
Expand Down
12 changes: 9 additions & 3 deletions function/impl/function_rho.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def rho_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult:
space, budget, measure, problem, bytemask = payload
searchable, timestamp = space.unpack(problem, bytemask), now()
searchable, timestamp = space.unpack(bytemask), now()

times, times2, values, values2 = {}, {}, {}, {}
formula, statuses = problem.encoding.get_formula(), {}
Expand All @@ -35,9 +35,11 @@ def rho_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult:
class RhoFunction(GuessAndDetermine):
slug = 'function:rho'

def __init__(self, measure: Measure, penalty_power: float):
def __init__(self, measure: Measure, penalty_power: float,
only_solved: bool = False):
super().__init__(AutoBudget(), measure)
self.penalty_power = penalty_power
self.only_solved = only_solved

def get_worker_fn(self) -> WorkerCallable:
return rho_worker_fn
Expand All @@ -47,8 +49,12 @@ def calculate(self, searchable: Searchable, results: Results) -> Estimation:
time_sum, value_sum = sum(times.values()), sum(values.values())
power, value = searchable.power(), float('inf')

solved = statuses.get(Status.SOLVED, 0) + (
statuses.get(Status.RESOLVED, 0)
if not self.only_solved else 0
)
if stats.count > 0 and self.penalty_power > power:
rho_value = float(statuses.get(Status.RESOLVED, 0)) / stats.count
rho_value = float(solved) / stats.count
penalty_value = (1. - rho_value) * self.penalty_power
value = rho_value * power + penalty_value

Expand Down
2 changes: 1 addition & 1 deletion function/impl/function_rho_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def tau_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult:
space, budget, measure, problem, bytemask = payload
searchable, timestamp = space.unpack(problem, bytemask), now()
searchable, timestamp = space.unpack(bytemask), now()

limit = measure.get_limit(budget)
times, times2, values, values2 = {}, {}, {}, {}
Expand Down
7 changes: 5 additions & 2 deletions pysatmc/encoding/encoding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Any
from typing import Dict, List, Any, TypeVar


class Formula:
Expand All @@ -9,8 +9,11 @@ def __iter__(self) -> List[Any]:
raise NotImplementedError


TFormula = TypeVar('TFormula', bound='Formula')


class Encoding:
def get_formula(self, copy: bool = True) -> Formula:
def get_formula(self, copy: bool = True) -> TFormula:
raise NotImplementedError

def __config__(self) -> Dict[str, Any]:
Expand Down
22 changes: 17 additions & 5 deletions pysatmc/encoding/impl/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,31 @@ def __init__(
self,
from_file: str = None,
from_string: str = None,
extract_hard: bool = False,
from_clauses: Clauses = None,
comment_lead: List[str] = ('c',)
):
self.from_file = from_file
self.from_string = from_string
self.from_clauses = from_clauses
self.extract_hard = extract_hard
self.comment_lead = comment_lead

def get_formula(self, copy: bool = True) -> formula.CNF:
if self.from_file is not None:
if self.from_file not in cnf_data:
_formula = formula.CNF(
from_file=self.from_file,
comment_lead=self.comment_lead
)
if not self.extract_hard:
_formula = formula.CNF(
from_file=self.from_file,
comment_lead=self.comment_lead
)
else:
_formula = formula.CNF(
from_clauses=formula.WCNF(
from_file=self.from_file,
comment_lead=self.comment_lead
).hard
)
cnf_data[self.from_file] = _formula
return cnf_data[self.from_file].copy() if \
copy else cnf_data[self.from_file]
Expand All @@ -45,6 +55,7 @@ def __copy__(self):
return CNF(
from_file=self.from_file,
from_string=self.from_string,
extract_hard=self.extract_hard,
comment_lead=self.comment_lead,
)

Expand All @@ -53,6 +64,7 @@ def __config__(self) -> Dict[str, Any]:
'slug': self.slug,
'from_file': self.from_file,
'from_string': self.from_string,
'extract_hard': self.extract_hard,
'from_clauses': self.from_clauses,
'comment_lead': self.comment_lead
}
Expand All @@ -73,7 +85,7 @@ def __init__(
comment_lead=comment_lead
)

def get_formula(self) -> formula.CNFPlus:
def get_formula(self, copy: bool = True) -> formula.CNFPlus:
if self.from_file is not None:
if self.from_file not in cnf_data:
_formula = formula.CNFPlus(
Expand Down
16 changes: 13 additions & 3 deletions pysatmc/encoding/impl/wcnf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pysat import formula
from typing import Any, List, Dict

from .cnf import CNF
from ..encoding import Encoding

wcnf_data = {}
Expand Down Expand Up @@ -35,7 +36,15 @@ def get_formula(self, copy: bool = True) -> formula.WCNF:
comment_lead=self.comment_lead
)

def __copy__(self):
def from_hard(self) -> CNF:
return CNF(
from_file=self.from_file,
from_string=self.from_string,
extract_hard=True,
comment_lead=self.comment_lead,
)

def __copy__(self) -> 'WCNF':
return WCNF(
from_file=self.from_file,
from_string=self.from_string,
Expand All @@ -54,15 +63,16 @@ def __config__(self) -> Dict[str, Any]:
class WCNFPlus(WCNF):
slug = 'encoding:wcnf+'

def get_formula(self) -> formula.WCNFPlus:
def get_formula(self, copy: bool = True) -> formula.WCNFPlus:
if self.from_file is not None:
if self.from_file not in wcnf_data:
_formula = formula.WCNFPlus(
from_file=self.from_file,
comment_lead=self.comment_lead
)
wcnf_data[self.from_file] = _formula
return wcnf_data[self.from_file].copy()
return wcnf_data[self.from_file].copy() if \
copy else wcnf_data[self.from_file]

return formula.WCNFPlus(
from_string=self.from_string,
Expand Down
2 changes: 1 addition & 1 deletion requirements-mpi.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
numpy~=1.25.2
numpy~=1.24.4
mpi4py~=3.1.4
python-sat~=0.1.8.dev9
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
numpy~=1.25.2
numpy~=1.24.4
python-sat~=0.1.8.dev9
11 changes: 5 additions & 6 deletions space/abc/space.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, Any, Optional

from pysatmc.problem import Problem
from typings.searchable import Vector, ByteVector, Searchable


Expand All @@ -10,16 +9,16 @@ class Space:
def __init__(self, by_vector: Optional[Vector] = None):
self.by_vector = by_vector

def get_initial(self, problem: Problem) -> Searchable:
def get_initial(self) -> Searchable:
raise NotImplementedError

def _get_searchable(self, problem: Problem) -> Searchable:
def _get_searchable(self) -> Searchable:
raise NotImplementedError

# noinspection PyProtectedMember
def unpack(self, problem: Problem, byte_vec: ByteVector) -> Searchable:
searchable = self._get_searchable(problem)
return searchable._set_vector(Searchable.unpack(byte_vec))
def unpack(self, byte_vec: ByteVector) -> Searchable:
vector = Searchable.unpack(byte_vec)
return self._get_searchable()._set_vector(vector)

def __config__(self) -> Dict[str, Any]:
raise NotImplementedError
Expand Down
15 changes: 9 additions & 6 deletions space/impl/backdoor_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ..abc import Space
from ..model import Backdoor

from pysatmc.problem import Problem
from typings.searchable import Vector
from pysatmc.variables import Variables

Expand All @@ -22,8 +21,8 @@ def __init__(
self.variables = variables

# noinspection PyProtectedMember
def get_initial(self, problem: Problem) -> Backdoor:
backdoor = self._get_searchable(problem)
def get_initial(self) -> Backdoor:
backdoor = self._get_searchable()
if self.by_string is not None:
var_names = self.by_string.split()
backdoor._set_vector([
Expand All @@ -34,12 +33,16 @@ def get_initial(self, problem: Problem) -> Backdoor:
backdoor._set_vector(self.by_vector)
return backdoor

def _get_searchable(self, problem: Problem) -> Backdoor:
def _get_searchable(self) -> Backdoor:
return Backdoor(variables=self.variables)

def __config__(self) -> Dict[str, Any]:
# todo: add realisation
pass
return {
'slug': self.slug,
'by_string': self.by_string,
'by_vector': self.by_vector,
'variables': self.variables.__config__(),
}


__all__ = [
Expand Down
13 changes: 8 additions & 5 deletions space/impl/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ def __init__(
self.indexes = indexes

# noinspection PyProtectedMember
def get_initial(self, problem: Problem) -> Interval:
interval = self._get_searchable(problem)
def get_initial(self) -> Interval:
interval = self._get_searchable()
if self.by_vector is not None:
interval._set_vector(self.by_vector)
return interval

def _get_searchable(self, problem: Problem) -> Interval:
def _get_searchable(self) -> Interval:
return Interval(indexes=self.indexes)

def __config__(self) -> Dict[str, Any]:
# todo: add realisation
pass
return {
'slug': self.slug,
'indexes': self.indexes.__config__(),
'by_vector': self.by_vector,
}


__all__ = [
Expand Down
34 changes: 34 additions & 0 deletions tests/test_core/test_comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest

from space.model import Backdoor
from pysatmc.variables import Range

from core.model.point import Point
from core.module.comparator import MinValueMaxSize


class TestComparator(unittest.TestCase):
def test_min_value_max_size(self):
comparator = MinValueMaxSize()
backdoor = Backdoor(Range(start=1, length=8))
self.assertGreater(
Point(backdoor.make_copy([]), comparator).set(value=1000),
Point(backdoor.make_copy([0, 0, 1]), comparator).set(value=900),
)
self.assertGreater(
Point(backdoor.make_copy([0, 0, 1]), comparator).set(value=1000),
Point(backdoor.make_copy([1, 1, 1]), comparator).set(value=1000),
)
self.assertEqual(
Point(backdoor.make_copy([0, 1, 1]), comparator).set(value=1000),
Point(backdoor.make_copy([1, 0, 1]), comparator).set(value=1000),
)
self.assertLess(
Point(backdoor.make_copy([0, 1, 1]), comparator).set(value=1000),
Point(backdoor.make_copy([0, 0, 1]), comparator).set(value=1000),
)
self.assertLess(
Point(backdoor.make_copy([0, 1, 1]), comparator).set(value=1000),
Point(backdoor.make_copy([1, 0, 1]), comparator).set(
value=float('inf')),
)
34 changes: 34 additions & 0 deletions tests/test_core/test_limitation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest

from core.module.limitation import WallTime, Iteration


class TestLimitation(unittest.TestCase):
def test_iteration(self):
limitation = Iteration(value=1000)
self.assertEqual(limitation.exhausted(), False)
self.assertEqual(limitation.left('time'), None)
self.assertEqual(limitation.left('iteration'), 1000)

limitation.increase('iteration', 990)
self.assertEqual(limitation.left('iteration'), 10)

limitation.set('iteration', 1234)
self.assertEqual(limitation.left('iteration'), 0)
self.assertEqual(limitation.get('iteration'), 1234)
self.assertEqual(limitation.exhausted(), True)

def test_wall_time(self):
limitation = WallTime(from_string='02:13:45')
self.assertEqual(limitation.exhausted(), False)
self.assertEqual(limitation.left('time'), 8025)
self.assertEqual(limitation.left('iteration'), None)

limitation.increase('time', 654)
self.assertEqual(limitation.get('time'), 654)
self.assertEqual(limitation.left('time'), 7371)

limitation.set('time', 10345)
self.assertEqual(limitation.left('time'), 0)
self.assertEqual(limitation.get('time'), 10345)
self.assertEqual(limitation.exhausted(), True)
File renamed without changes.
Loading

0 comments on commit c2c293d

Please sign in to comment.