diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 4aa564b6..14ef3a4e 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -270,11 +270,6 @@ def __init__(self): default=True, action="store_true", help="Whether to apply loss parallel when sequence parallel is enabled", - ) - self.parser.add_argument( - "--training.representation_type", - default="SMILES", - help="The representation type of the molecule for training the model.", ) self.parser.add_argument( "--experimental.enable_async_tensor_parallel", diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 1509695f..6840f469 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -87,7 +87,6 @@ def __init__( dataset_path: Optional[str], data_processing_style: str, tokenizer: Tokenizer, - representation_type: str = "SMILES", seq_len: int = 2048, world_size: int = 1, rank: int = 0, @@ -134,7 +133,6 @@ def __init__( self._tokenizer = tokenizer self.seq_len = seq_len self.infinite = infinite - self.representation_type = representation_type self.rank = rank self.world_size = world_size @@ -144,6 +142,7 @@ def __init__( else: self.store = None + # variables for checkpointing self._sample_idx = 0 self._all_tokens: List[int] = [] @@ -256,7 +255,6 @@ def build_hf_data_loader( seq_len: int, world_size, rank, - representation_type, infinite: bool = True, pin_memory: bool = False, num_workers: int = 2, diff --git a/torchtitan/utils/dataset_utils.py b/torchtitan/utils/dataset_utils.py index c397ee4c..40ef7aae 100644 --- a/torchtitan/utils/dataset_utils.py +++ b/torchtitan/utils/dataset_utils.py @@ -30,12 +30,12 @@ def load_jsonl_line(jsonl_line): raise ValueError(f"Error decoding JSON: {e}") -def chemlactica_style_data_processing(sample_json, rng, representation_type): +def chemlactica_style_data_processing(sample_json, rng): try: sample_json = json.loads(sample_json["text"]) compound = delete_empty_tags(sample_json) sample_json = generate_formatted_string( - compound, rng, representation_type + compound, rng ) except Exception as e: print(e) diff --git a/torchtitan/utils/safe.py b/torchtitan/utils/safe.py deleted file mode 100644 index 6f51df09..00000000 --- a/torchtitan/utils/safe.py +++ /dev/null @@ -1,465 +0,0 @@ -import itertools -import re -from collections import Counter -from contextlib import suppress -from typing import Callable, List, Optional, Union - -import datamol as dm -import numpy as np -from rdkit import Chem -from rdkit.Chem import BRICS - -class SAFEDecodeError(Exception): - """Raised when a string cannot be decoded with the given encoding.""" - pass - -class SAFEEncodeError(Exception): - """Raised when a molecule cannot be encoded using SAFE.""" - pass - - -class SAFEFragmentationError(Exception): - """Raised when a the slicing algorithm return empty bonds.""" - pass - - -class SAFEConverter: - """Molecule line notation conversion from SMILES to SAFE - - A SAFE representation is a string based representation of a molecule decomposition into fragment components, - separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by themselves, - unless explicitely correct to add missing hydrogens. - - !!! note "Slicing algorithms" - - By default SAFE strings are generated using `BRICS`, however, the following alternative are supported: - - * [Hussain-Rea (`hr`)](https://pubs.acs.org/doi/10.1021/ci900450m) - * [RECAP (`recap`)](https://pubmed.ncbi.nlm.nih.gov/9611787/) - * [RDKit's MMPA (`mmpa`)](https://www.rdkit.org/docs/source/rdkit.Chem.rdMMPA.html) - * Any possible attachment points (`attach`) - - Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms - corresponding to the bonds to break. - - """ - - SUPPORTED_SLICERS = ["hr", "rotatable", "recap", "mmpa", "attach", "brics"] - __SLICE_SMARTS = { - "hr": ["[*]!@-[*]"], # any non ring single bond - "recap": [ - "[$([C;!$(C([#7])[#7])](=!@[O]))]!@[$([#7;+0;!D1])]", - "[$(C=!@O)]!@[$([O;+0])]", - "[$([N;!D1;+0;!$(N-C=[#7,#8,#15,#16])](-!@[*]))]-!@[$([*])]", - "[$(C(=!@O)([#7;+0;D2,D3])!@[#7;+0;D2,D3])]!@[$([#7;+0;D2,D3])]", - "[$([O;+0](-!@[#6!$(C=O)])-!@[#6!$(C=O)])]-!@[$([#6!$(C=O)])]", - "C=!@C", - "[N;+1;D4]!@[#6]", - "[$([n;+0])]-!@C", - "[$([O]=[C]-@[N;+0])]-!@[$([C])]", - "c-!@c", - "[$([#7;+0;D2,D3])]-!@[$([S](=[O])=[O])]", - ], - "mmpa": ["[#6+0;!$(*=,#[!#6])]!@!=!#[*]"], # classical mmpa slicing smarts - "attach": ["[*]!@[*]"], # any potential attachment point, including hydrogens when explicit - "rotatable": ["[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"], - } - - def __init__( - self, - slicer: Optional[Union[str, List[str], Callable]] = "brics", - require_hs: Optional[bool] = None, - use_original_opener_for_attach: bool = True, - ignore_stereo: bool = False, - ): - """Constructor for the SAFE converter - - Args: - slicer: slicer algorithm to use for encoding. - Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS) - or a custom callable that returns the bond ids that can be sliced. - require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. - `attach` slicer requires adding hydrogens. - use_original_opener_for_attach: whether to use the original branch opener digit when adding back - mapping number to attachment points, or use simple enumeration. - ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. - - """ - self.slicer = slicer - if isinstance(slicer, str) and slicer.lower() in self.SUPPORTED_SLICERS: - self.slicer = self.__SLICE_SMARTS.get(slicer.lower(), slicer) - if self.slicer != "brics" and isinstance(self.slicer, str): - self.slicer = [self.slicer] - if isinstance(self.slicer, (list, tuple)): - self.slicer = [dm.from_smarts(x) for x in self.slicer] - if any(x is None for x in self.slicer): - raise ValueError(f"Slicer: {slicer} cannot be valid") - self.require_hs = require_hs or (slicer == "attach") - self.use_original_opener_for_attach = use_original_opener_for_attach - self.ignore_stereo = ignore_stereo - - @staticmethod - def randomize(mol: dm.Mol, rng: Optional[int] = None): - """Randomize the position of the atoms in a mol. - - Args: - mol: molecules to randomize - rng: optional seed to use - """ - if isinstance(rng, int): - rng = np.random.default_rng(rng) - if mol.GetNumAtoms() == 0: - return mol - atom_indices = list(range(mol.GetNumAtoms())) - atom_indices = rng.permutation(atom_indices).tolist() - return Chem.RenumberAtoms(mol, atom_indices) - - @classmethod - def _find_branch_number(cls, inp: str): - """Find the branch number and ring closure in the SMILES representation using regexp - - Args: - inp: input smiles - """ - inp = re.sub(r"\[.*?\]", "", inp) # noqa - matching_groups = re.findall(r"((?<=%)\d{2})|((? 0: - mol = Chem.FragmentOnBonds( - mol, - bonds, - dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))], - ) - # here we need to be clever and disable rooted atom as the atom with mapping - - frags = list(Chem.GetMolFrags(mol, asMols=True)) - if randomize: - frags = rng.permutation(frags).tolist() - elif canonical: - frags = sorted( - frags, - key=lambda x: x.GetNumAtoms(), - reverse=True, - ) - - frags_str = [] - for frag in frags: - non_map_atom_idxs = [ - atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0 - ] - frags_str.append( - Chem.MolToSmiles( - frag, - isomericSmiles=True, - canonical=True, # needs to always be true - rootedAtAtom=non_map_atom_idxs[0], - ) - ) - - scaffold_str = ".".join(frags_str) - # EN: fix for https://github.com/datamol-io/safe/issues/37 - # we were using the wrong branch number count which did not take into account - # possible change in digit utilization after bond slicing - scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers - - # don't capture atom mapping in the scaffold - attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str)) - if canonical: - attach_pos = sorted(attach_pos) - starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1 - for attach in attach_pos: - val = str(starting_num) if starting_num < 10 else f"%{starting_num}" - # we cannot have anything of the form "\([@=-#-$/\]*\d+\)" - attach_regexp = re.compile(r"(" + re.escape(attach) + r")") - scaffold_str = attach_regexp.sub(val, scaffold_str) - starting_num += 1 - - # now we need to remove all the parenthesis around digit only number - wrong_attach = re.compile(r"\(([\%\d]*)\)") - scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str) - # furthermore, we autoapply rdkit-compatible digit standardization. - if rdkit_safe: - pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)" - replacement = r"\g<1>\g<2>" - scaffold_str = re.sub(pattern, replacement, scaffold_str) - if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp): - print( - "Warning: Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation" - ) - return scaffold_str - - -def encode( - inp: Union[str, dm.Mol], - canonical: bool = True, - randomize: Optional[bool] = False, - seed: Optional[int] = None, - slicer: Optional[Union[List[str], str, Callable]] = None, - require_hs: Optional[bool] = None, - constraints: Optional[List[dm.Mol]] = None, - ignore_stereo: Optional[bool] = False, -): - """ - Convert input smiles to SAFE representation - - Args: - inp: input smiles - canonical: whether to return canonical SAFE string. Defaults to True - randomize: whether to randomize the safe string encoding. Will be ignored if canonical is provided - seed: optional seed to use when allowing randomization of the SAFE encoding. - slicer: slicer algorithm to use for encoding. Defaults to "brics". - require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. - constraints: List of molecules or pattern to preserve during the SAFE construction. - ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. - """ - if slicer is None: - slicer = "brics" - with dm.without_rdkit_log(): - safe_obj = SAFEConverter(slicer=slicer, require_hs=require_hs, ignore_stereo=ignore_stereo) - try: - encoded = safe_obj.encoder( - inp, - canonical=canonical, - randomize=randomize, - constraints=constraints, - seed=seed, - ) - except SAFEFragmentationError as e: - raise e - except Exception as e: - raise SAFEEncodeError(f"Failed to encode {inp} with {slicer}") from e - return encoded - - -def decode( - safe_str: str, - as_mol: bool = False, - canonical: bool = False, - fix: bool = True, - remove_added_hs: bool = True, - remove_dummies: bool = True, - ignore_errors: bool = False, -): - """Convert input SAFE representation to smiles - Args: - safe_str: input SAFE representation to decode as a valid molecule or smiles - as_mol: whether to return a molecule object or a smiles string - canonical: whether to return a canonical smiles or a randomized smiles - fix: whether to fix the SAFE representation to take into account non-connected attachment points - remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string. - remove_dummies: whether to remove dummy atoms from the SAFE representation - ignore_errors: whether to ignore error and return None on decoding failure or raise an error - - """ - with dm.without_rdkit_log(): - safe_obj = SAFEConverter() - try: - decoded = safe_obj.decoder( - safe_str, - as_mol=as_mol, - canonical=canonical, - fix=fix, - remove_dummies=remove_dummies, - remove_added_hs=remove_added_hs, - ) - - except Exception as e: - if ignore_errors: - return None - raise SAFEDecodeError(f"Failed to decode {safe_str}") from e - return decoded - -def main(): - smiles = "O=C(C#CCN1CCCCC1)Nc1ccc2ncnc(Nc3cccc(Br)c3)c2c1" - safe_string = encode(smiles) - print("SAFE representation:", safe_string) - print("SMILES representation:", decode(safe_string)) - -if __name__ == "main": - main() \ No newline at end of file diff --git a/torchtitan/utils/text_format_utils.py b/torchtitan/utils/text_format_utils.py index 9514ecf8..65a0c15e 100644 --- a/torchtitan/utils/text_format_utils.py +++ b/torchtitan/utils/text_format_utils.py @@ -1,6 +1,5 @@ # Adapted from https://github.com/YerevaNN/ChemLactica/blob/main/chemlactica/utils/text_format_utils.py # All rights reserved -from torchtitan.utils.safe import encode SPECIAL_TAGS = { "SMILES": {"start": "[START_SMILES]", "end": "[END_SMILES]"}, @@ -82,19 +81,14 @@ def delete_empty_tags(compound_json): return compound_json -def generate_formatted_string(compound_json, rng, representation_type = "SMILES"): +def generate_formatted_string(compound_json, rng): key_value_pairs = [] key = "SMILES" value = compound_json.get(key, "") - - if representation_type == "SAFE": - value = encode(value) - if rng.integers(2) == 0: if value: key_value_pairs.append(format_key_value(key, value, rng)) del compound_json[key] - keys = list(compound_json.keys()) rng.shuffle(keys) diff --git a/train.py b/train.py index 797fca61..b7ee0d23 100644 --- a/train.py +++ b/train.py @@ -93,7 +93,6 @@ def main(job_config: JobConfig): tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader - representation_type = job_config.training.representation_type data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -103,7 +102,6 @@ def main(job_config: JobConfig): job_config.training.seq_len, dp_degree, dp_rank, - representation_type, pin_memory = job_config.dataloader.pin_memory, num_workers = job_config.dataloader.num_workers, special_mode = job_config.dataloader.special_mode, diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 7aa7caa1..2829d098 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -44,7 +44,6 @@ tensor_parallel_degree = 1 compile = true dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K) data_process_style="chemlactica_style" -representation_type="SAFE" [experimental] pipeline_parallel_degree = 1