-
Notifications
You must be signed in to change notification settings - Fork 2
/
smiles_augmenter.py
89 lines (75 loc) · 2.94 KB
/
smiles_augmenter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import logging
import random
from typing import Callable, List
from .miscellaneous import apply_to_any_smiles, apply_to_smiles_groups
# Module-level logger named after this module. The NullHandler is a no-op
# handler: attaching it means this module does not produce log output by
# itself; the importing application decides how (and whether) to handle it.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class SmilesAugmenter:
    """
    Class to augment any kind of SMILES string with the help of randomization
    and shuffling.

    Randomness comes from the global ``random`` module, so results can be made
    reproducible by seeding it (``random.seed(...)``) before augmenting.
    """

    def __init__(
        self,
        augmentation_fn: Callable[[str], str],
        augmentation_probability: float = 1.0,
        shuffle: bool = True,
        ignore_exceptions: bool = True,
    ):
        """
        Args:
            augmentation_fn: Function for augmenting the individual SMILES strings,
                such as the functions provided in smiles_randomization.py.
            augmentation_probability: Probability with which to augment individual
                SMILES strings.
            shuffle: Whether to shuffle the order of the compounds.
            ignore_exceptions: Whether to ignore the error (and return the
                original string) when an augmentation fails. If False, exceptions
                will be propagated.
        """
        self.augmentation_fn = augmentation_fn
        self.augmentation_probability = augmentation_probability
        self.shuffle = shuffle
        self.ignore_exceptions = ignore_exceptions

    def augment(self, smiles: str, number_augmentations: int) -> List[str]:
        """
        Augment one SMILES string (of any kind).

        Args:
            smiles: SMILES string to augment.
            number_augmentations: how many times to do the augmentation.

        Returns:
            A list with one augmented string per requested augmentation (empty
            when number_augmentations is zero or negative).

        Raises:
            Exception: whatever the augmentation function raises, but only if
                ignore_exceptions is False.
        """
        # Augmentation of the individual compound SMILES.
        # NOTE(review): relies on apply_to_any_smiles applying the callback to
        # the individual compounds of the string - confirm in miscellaneous.py.
        augmented = [
            apply_to_any_smiles(
                smiles, self._augment_with_probability, force_multicomponent=True
            )
            for _ in range(number_augmentations)
        ]

        if not self.shuffle:
            return augmented

        # Shuffle the order of the compounds.
        # NOTE(review): presumably apply_to_smiles_groups hands each group of
        # compounds to the callback as a list - confirm in miscellaneous.py.
        return [
            apply_to_smiles_groups(s, SmilesAugmenter._shuffle) for s in augmented
        ]

    def _augment_with_probability(self, smiles: str) -> str:
        """
        Augment a SMILES, with the probability given by the member variable.

        Args:
            smiles: SMILES string to (maybe) augment.

        Returns:
            The output of the augmentation function, or the unchanged input if
            no augmentation was drawn or if the augmentation failed while
            ignore_exceptions is True.

        Raises:
            Exception: whatever the augmentation function raises, but only if
                ignore_exceptions is False.
        """
        # Note: no need to draw a random number if the augmentation probability
        # is 1.0 (the global random state is then left untouched).
        # random.random() lies in [0.0, 1.0), so the strict "<" guarantees that
        # a probability of 0.0 never triggers an augmentation.
        should_augment = (
            self.augmentation_probability == 1.0
            or random.random() < self.augmentation_probability
        )
        if not should_augment:
            # no augmentation
            return smiles

        try:
            return self.augmentation_fn(smiles)
        except Exception as e:
            if not self.ignore_exceptions:
                raise
            # Best effort: report the failure and fall back to the input.
            logger.warning("Augmentation failed for %s: %s", smiles, e)
            return smiles

    @staticmethod
    def _shuffle(smiles_list: List[str]) -> List[str]:
        """
        Return a shuffled copy of the given list; the input is not modified.

        Args:
            smiles_list: strings to shuffle.

        Returns:
            New list with the same elements in random order.
        """
        shuffled = list(smiles_list)
        random.shuffle(shuffled)
        return shuffled