Skip to content

Commit

Permalink
Merge pull request #118 from NREL/min_num_atoms
Browse files Browse the repository at this point in the history
Fix LRU cache for ray and update 'min_atoms' option to only build molecules with at least 'min_atoms'
  • Loading branch information
jlaw9 authored Apr 3, 2024
2 parents 9fa89d0 + 76c0688 commit cb36415
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 13 deletions.
4 changes: 2 additions & 2 deletions rlmolecule/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, List

import ray
from lru import LRU
from pylru import lrucache

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -32,7 +32,7 @@ class RayDictCache(DictCache):
@ray.remote
class RayLRUCache(DictCache):
def __init__(self, max_size: int = int(1e5)):
self._dict = LRU(max_size)
self._dict = lrucache(max_size)


@ray.remote
Expand Down
2 changes: 1 addition & 1 deletion rlmolecule/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
Args:
max_atoms (int, optional): Maximum number of heavy atoms. Defaults to 10.
min_atoms (int, optional): minimum number of heavy atoms. Defaults to 4.
min_atoms (int, optional): Minimum number of heavy atoms. Defaults to 4.
atom_additions (Optional[List], optional): potential atom types to consider.
Defaults to ('C', 'N', 'O')
stereoisomers (bool, optional): whether to consider stereoisomers different
Expand Down
23 changes: 17 additions & 6 deletions rlmolecule/molecule_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,24 @@ def _get_children(self) -> Sequence[V]:
return []

next_actions = [self.new(molecule) for molecule in self.builder(self.molecule)]
next_actions.extend(self._get_terminal_actions())

# Only add the terminal state as an action if there are enough atoms
terminal_actions = []
if self.num_atoms >= self.min_num_atoms:
terminal_actions = self._get_terminal_actions()

if self.data.prune_terminal_states:
next_actions = self._prune_next_actions(next_actions)
if self.data.prune_terminal_states:
terminal_actions = self._prune_next_actions(terminal_actions)

if len(next_actions) >= self.max_num_actions:
if len(next_actions) + len(terminal_actions) >= self.max_num_actions:
logger.info(
f"{self} has {len(next_actions) + 1} next actions when the "
f"{self} has {len(next_actions) + 1 + len(terminal_actions)} next actions when the "
f"maximum is {self.max_num_actions}"
)
next_actions = random.sample(next_actions, self.max_num_actions)
next_actions = random.sample(next_actions, self.max_num_actions - len(terminal_actions))

next_actions.extend(terminal_actions)

return next_actions

def _get_terminal_actions(self) -> Sequence[V]:
Expand All @@ -145,6 +151,7 @@ def _get_terminal_actions(self) -> Sequence[V]:
def _prune_next_actions(self, next_actions: Sequence[V]) -> Sequence[V]:
"""Use the ray actor handle in self.data (or a simple set) to find terminal
states that have already been evaluated and remove them from the search tree.
Useful to force the builder to continually make new molecules.
Args:
next_actions (Sequence[V]): A list of MoleculeStates to be pruned
Expand Down Expand Up @@ -242,6 +249,10 @@ def num_atoms(self) -> int:
def max_num_actions(self) -> int:
return self.data.max_num_actions

@property
def min_num_atoms(self) -> int:
return self.builder.min_atoms

def __repr__(self) -> str:
"""
delegates to the SMILES string
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ include_package_data = true
install_requires =
graphenv
tensorflow
lru-dict
pylru
rdkit
nfp

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def ray_init():

@pytest.fixture
def builder() -> MoleculeBuilder:
return MoleculeBuilder(max_atoms=5)
return MoleculeBuilder(max_atoms=5, min_atoms=1)


@pytest.fixture(scope="class")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_molecule_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_prune_terminal_ray(ray_init):
qed_root = QEDState(
rdkit.Chem.MolFromSmiles("C"),
MoleculeData(
MoleculeBuilder(max_atoms=5, cache=True),
MoleculeBuilder(max_atoms=5, min_atoms=1, cache=True),
max_num_actions=20,
prune_terminal_states=True,
),
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_csv_writer(ray_init, caplog):
with TemporaryDirectory() as tempdir:

data = MoleculeData(
MoleculeBuilder(max_atoms=5, cache=True),
MoleculeBuilder(max_atoms=5, min_atoms=1, cache=True),
max_num_actions=20,
prune_terminal_states=True,
log_reward_filepath=Path(tempdir, "test.csv"),
Expand Down

0 comments on commit cb36415

Please sign in to comment.