From e82a73e2144fa2cec472039b368a2f78d85cd9a8 Mon Sep 17 00:00:00 2001 From: Filya Geikyan Date: Thu, 26 Sep 2024 18:54:34 +0400 Subject: [PATCH] final safe --- torchtitan/config_manager.py | 5 + torchtitan/datasets/hf_datasets.py | 8 +- torchtitan/utils/dataset_utils.py | 4 +- torchtitan/utils/safe.py | 465 ++++++++++++++++++++++++++ torchtitan/utils/text_format_utils.py | 8 +- train.py | 2 + train_configs/debug_model.toml | 1 + 7 files changed, 487 insertions(+), 6 deletions(-) create mode 100644 torchtitan/utils/safe.py diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 14ef3a4e..4aa564b6 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -270,6 +270,11 @@ def __init__(self): default=True, action="store_true", help="Whether to apply loss parallel when sequence parallel is enabled", + ) + self.parser.add_argument( + "--training.representation_type", + default="SMILES", + help="The representation type of the molecule for training the model.", ) self.parser.add_argument( "--experimental.enable_async_tensor_parallel", diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 6840f469..582503c7 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -87,6 +87,7 @@ def __init__( dataset_path: Optional[str], data_processing_style: str, tokenizer: Tokenizer, + representation_type: str = "SMILES", seq_len: int = 2048, world_size: int = 1, rank: int = 0, @@ -135,6 +136,7 @@ def __init__( self.infinite = infinite self.rank = rank self.world_size = world_size + self.representation_type = representation_type # for non sync communication between ranks if not self.infinite and store: @@ -142,7 +144,6 @@ def __init__( else: self.store = None - # variables for checkpointing self._sample_idx = 0 self._all_tokens: List[int] = [] @@ -172,7 +173,7 @@ def __iter__(self): for sample_json in self._get_data_iter(): if self._some_rank_finished(): break - 
sample_text = self.data_processing_fn(sample_json, self.rng) + sample_text = self.data_processing_fn(sample_json, self.rng, self.representation_type) sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) self._all_tokens.extend(sample_tokens) self._sample_idx += 1 @@ -255,6 +256,7 @@ def build_hf_data_loader( seq_len: int, world_size, rank, + representation_type, infinite: bool = True, pin_memory: bool = False, num_workers: int = 2, @@ -268,7 +270,7 @@ def build_hf_data_loader( data_completion_store = None hf_ds = HuggingFaceDataset( - dataset_name, dataset_path, data_processing_style, tokenizer, seq_len, world_size, rank, infinite, special_mode,store = data_completion_store + dataset_name, dataset_path, data_processing_style, tokenizer, representation_type, seq_len, world_size, rank, infinite, special_mode,store = data_completion_store ) return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers) diff --git a/torchtitan/utils/dataset_utils.py b/torchtitan/utils/dataset_utils.py index 40ef7aae..c397ee4c 100644 --- a/torchtitan/utils/dataset_utils.py +++ b/torchtitan/utils/dataset_utils.py @@ -30,12 +30,12 @@ def load_jsonl_line(jsonl_line): raise ValueError(f"Error decoding JSON: {e}") -def chemlactica_style_data_processing(sample_json, rng): +def chemlactica_style_data_processing(sample_json, rng, representation_type): try: sample_json = json.loads(sample_json["text"]) compound = delete_empty_tags(sample_json) sample_json = generate_formatted_string( - compound, rng + compound, rng, representation_type ) except Exception as e: print(e) diff --git a/torchtitan/utils/safe.py b/torchtitan/utils/safe.py new file mode 100644 index 00000000..6f51df09 --- /dev/null +++ b/torchtitan/utils/safe.py @@ -0,0 +1,465 @@ +import itertools +import re +from collections import Counter +from contextlib import suppress +from typing import Callable, List, Optional, Union + +import datamol as dm +import numpy as np 
+from rdkit import Chem +from rdkit.Chem import BRICS + +class SAFEDecodeError(Exception): + """Raised when a string cannot be decoded with the given encoding.""" + pass + +class SAFEEncodeError(Exception): + """Raised when a molecule cannot be encoded using SAFE.""" + pass + + +class SAFEFragmentationError(Exception): + """Raised when the slicing algorithm returns empty bonds.""" + pass + + +class SAFEConverter: + """Molecule line notation conversion from SMILES to SAFE + + A SAFE representation is a string based representation of a molecule decomposition into fragment components, + separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by itself, + unless explicitly corrected to add missing hydrogens. + + !!! note "Slicing algorithms" + + By default SAFE strings are generated using `BRICS`, however, the following alternatives are supported: + + * [Hussain-Rea (`hr`)](https://pubs.acs.org/doi/10.1021/ci900450m) + * [RECAP (`recap`)](https://pubmed.ncbi.nlm.nih.gov/9611787/) + * [RDKit's MMPA (`mmpa`)](https://www.rdkit.org/docs/source/rdkit.Chem.rdMMPA.html) + * Any possible attachment points (`attach`) + + Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms + corresponding to the bonds to break.
+ + """ + + SUPPORTED_SLICERS = ["hr", "rotatable", "recap", "mmpa", "attach", "brics"] + __SLICE_SMARTS = { + "hr": ["[*]!@-[*]"], # any non ring single bond + "recap": [ + "[$([C;!$(C([#7])[#7])](=!@[O]))]!@[$([#7;+0;!D1])]", + "[$(C=!@O)]!@[$([O;+0])]", + "[$([N;!D1;+0;!$(N-C=[#7,#8,#15,#16])](-!@[*]))]-!@[$([*])]", + "[$(C(=!@O)([#7;+0;D2,D3])!@[#7;+0;D2,D3])]!@[$([#7;+0;D2,D3])]", + "[$([O;+0](-!@[#6!$(C=O)])-!@[#6!$(C=O)])]-!@[$([#6!$(C=O)])]", + "C=!@C", + "[N;+1;D4]!@[#6]", + "[$([n;+0])]-!@C", + "[$([O]=[C]-@[N;+0])]-!@[$([C])]", + "c-!@c", + "[$([#7;+0;D2,D3])]-!@[$([S](=[O])=[O])]", + ], + "mmpa": ["[#6+0;!$(*=,#[!#6])]!@!=!#[*]"], # classical mmpa slicing smarts + "attach": ["[*]!@[*]"], # any potential attachment point, including hydrogens when explicit + "rotatable": ["[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"], + } + + def __init__( + self, + slicer: Optional[Union[str, List[str], Callable]] = "brics", + require_hs: Optional[bool] = None, + use_original_opener_for_attach: bool = True, + ignore_stereo: bool = False, + ): + """Constructor for the SAFE converter + + Args: + slicer: slicer algorithm to use for encoding. + Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS) + or a custom callable that returns the bond ids that can be sliced. + require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. + `attach` slicer requires adding hydrogens. + use_original_opener_for_attach: whether to use the original branch opener digit when adding back + mapping number to attachment points, or use simple enumeration. + ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. 
+ + """ + self.slicer = slicer + if isinstance(slicer, str) and slicer.lower() in self.SUPPORTED_SLICERS: + self.slicer = self.__SLICE_SMARTS.get(slicer.lower(), slicer) + if self.slicer != "brics" and isinstance(self.slicer, str): + self.slicer = [self.slicer] + if isinstance(self.slicer, (list, tuple)): + self.slicer = [dm.from_smarts(x) for x in self.slicer] + if any(x is None for x in self.slicer): + raise ValueError(f"Slicer: {slicer} cannot be valid") + self.require_hs = require_hs or (slicer == "attach") + self.use_original_opener_for_attach = use_original_opener_for_attach + self.ignore_stereo = ignore_stereo + + @staticmethod + def randomize(mol: dm.Mol, rng: Optional[int] = None): + """Randomize the position of the atoms in a mol. + + Args: + mol: molecules to randomize + rng: optional seed to use + """ + if isinstance(rng, int): + rng = np.random.default_rng(rng) + if mol.GetNumAtoms() == 0: + return mol + atom_indices = list(range(mol.GetNumAtoms())) + atom_indices = rng.permutation(atom_indices).tolist() + return Chem.RenumberAtoms(mol, atom_indices) + + @classmethod + def _find_branch_number(cls, inp: str): + """Find the branch number and ring closure in the SMILES representation using regexp + + Args: + inp: input smiles + """ + inp = re.sub(r"\[.*?\]", "", inp) # noqa + matching_groups = re.findall(r"((?<=%)\d{2})|((? 
0: + mol = Chem.FragmentOnBonds( + mol, + bonds, + dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))], + ) + # here we need to be clever and disable rooted atom as the atom with mapping + + frags = list(Chem.GetMolFrags(mol, asMols=True)) + if randomize: + frags = rng.permutation(frags).tolist() + elif canonical: + frags = sorted( + frags, + key=lambda x: x.GetNumAtoms(), + reverse=True, + ) + + frags_str = [] + for frag in frags: + non_map_atom_idxs = [ + atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0 + ] + frags_str.append( + Chem.MolToSmiles( + frag, + isomericSmiles=True, + canonical=True, # needs to always be true + rootedAtAtom=non_map_atom_idxs[0], + ) + ) + + scaffold_str = ".".join(frags_str) + # EN: fix for https://github.com/datamol-io/safe/issues/37 + # we were using the wrong branch number count which did not take into account + # possible change in digit utilization after bond slicing + scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers + + # don't capture atom mapping in the scaffold + attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str)) + if canonical: + attach_pos = sorted(attach_pos) + starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1 + for attach in attach_pos: + val = str(starting_num) if starting_num < 10 else f"%{starting_num}" + # we cannot have anything of the form "\([@=-#-$/\]*\d+\)" + attach_regexp = re.compile(r"(" + re.escape(attach) + r")") + scaffold_str = attach_regexp.sub(val, scaffold_str) + starting_num += 1 + + # now we need to remove all the parenthesis around digit only number + wrong_attach = re.compile(r"\(([\%\d]*)\)") + scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str) + # furthermore, we autoapply rdkit-compatible digit standardization. 
+ if rdkit_safe: + pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)" + replacement = r"\g<1>\g<2>" + scaffold_str = re.sub(pattern, replacement, scaffold_str) + if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp): + print( + "Warning: Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation" + ) + return scaffold_str + + +def encode( + inp: Union[str, dm.Mol], + canonical: bool = True, + randomize: Optional[bool] = False, + seed: Optional[int] = None, + slicer: Optional[Union[List[str], str, Callable]] = None, + require_hs: Optional[bool] = None, + constraints: Optional[List[dm.Mol]] = None, + ignore_stereo: Optional[bool] = False, +): + """ + Convert input smiles to SAFE representation + + Args: + inp: input smiles + canonical: whether to return canonical SAFE string. Defaults to True + randomize: whether to randomize the safe string encoding. Will be ignored if canonical is provided + seed: optional seed to use when allowing randomization of the SAFE encoding. + slicer: slicer algorithm to use for encoding. Defaults to "brics". + require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. + constraints: List of molecules or pattern to preserve during the SAFE construction. + ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. 
+ """ + if slicer is None: + slicer = "brics" + with dm.without_rdkit_log(): + safe_obj = SAFEConverter(slicer=slicer, require_hs=require_hs, ignore_stereo=ignore_stereo) + try: + encoded = safe_obj.encoder( + inp, + canonical=canonical, + randomize=randomize, + constraints=constraints, + seed=seed, + ) + except SAFEFragmentationError as e: + raise e + except Exception as e: + raise SAFEEncodeError(f"Failed to encode {inp} with {slicer}") from e + return encoded + + +def decode( + safe_str: str, + as_mol: bool = False, + canonical: bool = False, + fix: bool = True, + remove_added_hs: bool = True, + remove_dummies: bool = True, + ignore_errors: bool = False, +): + """Convert input SAFE representation to smiles + Args: + safe_str: input SAFE representation to decode as a valid molecule or smiles + as_mol: whether to return a molecule object or a smiles string + canonical: whether to return a canonical smiles or a randomized smiles + fix: whether to fix the SAFE representation to take into account non-connected attachment points + remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string. 
+ remove_dummies: whether to remove dummy atoms from the SAFE representation + ignore_errors: whether to ignore error and return None on decoding failure or raise an error + + """ + with dm.without_rdkit_log(): + safe_obj = SAFEConverter() + try: + decoded = safe_obj.decoder( + safe_str, + as_mol=as_mol, + canonical=canonical, + fix=fix, + remove_dummies=remove_dummies, + remove_added_hs=remove_added_hs, + ) + + except Exception as e: + if ignore_errors: + return None + raise SAFEDecodeError(f"Failed to decode {safe_str}") from e + return decoded + +def main(): + smiles = "O=C(C#CCN1CCCCC1)Nc1ccc2ncnc(Nc3cccc(Br)c3)c2c1" + safe_string = encode(smiles) + print("SAFE representation:", safe_string) + print("SMILES representation:", decode(safe_string)) + +if __name__ == "__main__": # run the round-trip demo only when executed as a script + main() \ No newline at end of file diff --git a/torchtitan/utils/text_format_utils.py b/torchtitan/utils/text_format_utils.py index 65a0c15e..9514ecf8 100644 --- a/torchtitan/utils/text_format_utils.py +++ b/torchtitan/utils/text_format_utils.py @@ -1,5 +1,6 @@ # Adapted from https://github.com/YerevaNN/ChemLactica/blob/main/chemlactica/utils/text_format_utils.py # All rights reserved +from torchtitan.utils.safe import encode SPECIAL_TAGS = { "SMILES": {"start": "[START_SMILES]", "end": "[END_SMILES]"}, @@ -81,14 +82,19 @@ def delete_empty_tags(compound_json): return compound_json -def generate_formatted_string(compound_json, rng): +def generate_formatted_string(compound_json, rng, representation_type = "SMILES"): key_value_pairs = [] key = "SMILES" value = compound_json.get(key, "") + + if representation_type == "SAFE" and value: # nothing to encode when the compound has no SMILES + value = encode(value) + if rng.integers(2) == 0: if value: key_value_pairs.append(format_key_value(key, value, rng)) del compound_json[key] + keys = list(compound_json.keys()) rng.shuffle(keys) diff --git a/train.py b/train.py index b7ee0d23..797fca61 100644 --- a/train.py +++ b/train.py @@ -93,6 +93,7 @@ def main(job_config: JobConfig): tokenizer =
build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader + representation_type = job_config.training.representation_type data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -102,6 +103,7 @@ def main(job_config: JobConfig): job_config.training.seq_len, dp_degree, dp_rank, + representation_type, pin_memory = job_config.dataloader.pin_memory, num_workers = job_config.dataloader.num_workers, special_mode = job_config.dataloader.special_mode, diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 2829d098..7aa7caa1 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -44,6 +44,7 @@ tensor_parallel_degree = 1 compile = true dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K) data_process_style="chemlactica_style" +representation_type="SAFE" [experimental] pipeline_parallel_degree = 1