diff --git a/rlmolecule/actors.py b/rlmolecule/actors.py index 55cfd8c..abc1126 100644 --- a/rlmolecule/actors.py +++ b/rlmolecule/actors.py @@ -4,7 +4,7 @@ from typing import Any, List import ray -from lru import LRU +from pylru import lrucache logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ class RayDictCache(DictCache): @ray.remote class RayLRUCache(DictCache): def __init__(self, max_size: int = int(1e5)): - self._dict = LRU(max_size) + self._dict = lrucache(max_size) @ray.remote diff --git a/rlmolecule/builder.py b/rlmolecule/builder.py index 740b615..b9491cb 100644 --- a/rlmolecule/builder.py +++ b/rlmolecule/builder.py @@ -53,7 +53,7 @@ def __init__( Args: max_atoms (int, optional): Maximum number of heavy atoms. Defaults to 10. - min_atoms (int, optional): minimum number of heavy atoms. Defaults to 4. + min_atoms (int, optional): Minimum number of heavy atoms. Defaults to 4. atom_additions (Optional[List], optional): potential atom types to consider. Defaults to ('C', 'N', 'O') stereoisomers (bool, optional): whether to consider stereoisomers different diff --git a/rlmolecule/molecule_state.py b/rlmolecule/molecule_state.py index 6e80ebe..a4c1e46 100644 --- a/rlmolecule/molecule_state.py +++ b/rlmolecule/molecule_state.py @@ -118,18 +118,24 @@ def _get_children(self) -> Sequence[V]: return [] next_actions = [self.new(molecule) for molecule in self.builder(self.molecule)] - next_actions.extend(self._get_terminal_actions()) + + # Only add the terminal state as an action if there are enough atoms + terminal_actions = [] + if self.num_atoms >= self.min_num_atoms: + terminal_actions = self._get_terminal_actions() - if self.data.prune_terminal_states: - next_actions = self._prune_next_actions(next_actions) + if self.data.prune_terminal_states: + terminal_actions = self._prune_next_actions(terminal_actions) - if len(next_actions) >= self.max_num_actions: + if len(next_actions) + len(terminal_actions) >= self.max_num_actions: logger.info( - f"{self} has {len(next_actions) + 1} next actions when the " + f"{self} has {len(next_actions) + 1 + len(terminal_actions)} next actions when the " f"maximum is {self.max_num_actions}" ) - next_actions = random.sample(next_actions, self.max_num_actions) + next_actions = random.sample(next_actions, self.max_num_actions - len(terminal_actions)) + next_actions.extend(terminal_actions) + return next_actions def _get_terminal_actions(self) -> Sequence[V]: @@ -145,6 +151,7 @@ def _get_terminal_actions(self) -> Sequence[V]: def _prune_next_actions(self, next_actions: Sequence[V]) -> Sequence[V]: """Use the ray actor handle in self.data (or a simple set) to find terminal states that have already been evaluated and remove them from the search tree. + Useful to force the builder to continually make new molecules. Args: next_actions (Sequence[V]): A list of MoleculeStates to be pruned @@ -242,6 +249,10 @@ def num_atoms(self) -> int: def max_num_actions(self) -> int: return self.data.max_num_actions + @property + def min_num_atoms(self) -> int: + return self.builder.min_atoms + def __repr__(self) -> str: """ delegates to the SMILES string diff --git a/setup.cfg b/setup.cfg index 8969cb5..0893ed8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,7 +15,7 @@ include_package_data = true install_requires = graphenv tensorflow - lru-dict + pylru rdkit nfp diff --git a/tests/conftest.py b/tests/conftest.py index 0a6fd94..47ba9c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ def ray_init(): @pytest.fixture def builder() -> MoleculeBuilder: - return MoleculeBuilder(max_atoms=5) + return MoleculeBuilder(max_atoms=5, min_atoms=1) @pytest.fixture(scope="class") diff --git a/tests/test_molecule_state.py b/tests/test_molecule_state.py index a50a036..d53f64f 100644 --- a/tests/test_molecule_state.py +++ b/tests/test_molecule_state.py @@ -61,7 +61,7 @@ def test_prune_terminal_ray(ray_init): qed_root = QEDState( rdkit.Chem.MolFromSmiles("C"), MoleculeData( - MoleculeBuilder(max_atoms=5, cache=True), + MoleculeBuilder(max_atoms=5, min_atoms=1, cache=True), max_num_actions=20, prune_terminal_states=True, ), @@ -95,7 +95,7 @@ def test_csv_writer(ray_init, caplog): with TemporaryDirectory() as tempdir: data = MoleculeData( - MoleculeBuilder(max_atoms=5, cache=True), + MoleculeBuilder(max_atoms=5, min_atoms=1, cache=True), max_num_actions=20, prune_terminal_states=True, log_reward_filepath=Path(tempdir, "test.csv"),