diff --git a/README.md b/README.md index 04a4100..58ed4e1 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ conda create env -f environment.yml pip install clipzyme ``` +3. Download the ESM-2 checkpoint `esm2_t33_650M_UR50D`. The `esm_dir` argument should point to this directory. # Screening with CLIPZyme ## Using CLIPZyme's screening set @@ -89,7 +90,7 @@ from clipzyme import ReactionDataset #------------------------- reaction_dataset = ReactionDataset( dataset_file_path = "files/new_data.csv", - esm_dir = "/path/to/esm2_t33_650M_UR50D.pt", + esm_dir = "/path/to/esm2_dir", protein_cache_dir = "/path/to/protein_cache", ) @@ -130,8 +131,9 @@ for batch in reaction_dataset: "save_predictions": [true], # whether to save the reaction-enzyme pair scores "use_as_protein_encoder": [true], # whether to use the model as a protein encoder only "use_as_reaction_encoder": [true], # whether to use the model as a reaction encoder only - "protein_cache_dir": ["/path/to/protein_cache"], # where to save the protein cache - "gpus": [8], # number of gpus to use + "esm_dir": ["/data/esm/checkpoints"], # path to ESM-2 checkpoints + "gpus": [8], # number of gpus to use + "protein_cache_dir": ["/path/to/protein_cache"], # where to save the protein cache [optional] ... } ``` @@ -169,8 +171,8 @@ We obtain the data from the following sources: Our processed data is available at [here](`https://doi.org/10.5281/zenodo.5555555`). It consists of the following files: - `enzymemap.json`: contains the EnzymeMap dataset. - `terpene_synthases.json`: contains the Terpene Synthases dataset. -- `enzymemap_screening.p`: contains the screening set. -- `sequenceid2sequence.p`: contains the mapping form sequence ID to amino acids. +- `clipzyme_screening_set.p`: contains the screening set as a dict of UniProt IDs and precomputed protein embeddings. +- `uniprot2sequence.p`: contains the mapping from sequence ID to amino acids.
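The README changes above switch `esm_dir` from the checkpoint `.pt` file to the directory that contains it. A minimal sketch of the updated screening call, assuming the `esm2_t33_650M_UR50D` checkpoint has been downloaded into `/path/to/esm2_dir` and that `files/new_data.csv` follows the expected input format; all paths are placeholders:

```python
from clipzyme import ReactionDataset

reaction_dataset = ReactionDataset(
    dataset_file_path="files/new_data.csv",       # CSV describing the reactions/enzymes to screen
    esm_dir="/path/to/esm2_dir",                  # directory holding esm2_t33_650M_UR50D.pt (no longer the .pt file itself)
    protein_cache_dir="/path/to/protein_cache",   # where processed protein data is cached
)

for batch in reaction_dataset:
    ...  # pass batches to the CLIPZyme model as shown later in the README
```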
## Training and evaluation diff --git a/clipzyme/datasets/enzymemap.py b/clipzyme/datasets/enzymemap.py index 76ee97a..e81c7dd 100644 --- a/clipzyme/datasets/enzymemap.py +++ b/clipzyme/datasets/enzymemap.py @@ -20,9 +20,7 @@ from clipzyme.utils.registry import register_object from clipzyme.datasets.abstract import AbstractDataset -from clipzyme.utils.smiles import ( - generate_scaffold, -) + from clipzyme.utils.protein_utils import ( read_structure_file, filter_resolution, @@ -32,7 +30,6 @@ ) from clipzyme.utils.pyg import from_mapped_smiles from clipzyme.utils.wln_processing import get_bond_changes -from clipzyme.models.wln import WLDN_Cache ESM_MODEL2HIDDEN_DIM = { "esm2_t48_15B_UR50D": 5120, @@ -72,8 +69,6 @@ def __init__(self, args, split_group) -> None: self.batch_converter = alphabet.get_batch_converter() super(EnzymeMap, EnzymeMap).__init__(self, args, split_group) self.metadata_json = None # overwrite for memory - if args.load_wln_cache_in_dataset: - self.cache = WLDN_Cache(args.cache_path) def init_class(self, args: argparse.ArgumentParser, split_group: str) -> None: """Perform Class-Specific init methods @@ -88,31 +83,15 @@ def init_class(self, args: argparse.ArgumentParser, split_group: str) -> None: self.valid_ec2uniprot = defaultdict(set) - self.ec2uniprot = pickle.load( - open( - "/home/datasets/EnzymeMap/ec2uniprot.p", - "rb", - ) - ) - self.uniprot2sequence = pickle.load( - open( - "/home/datasets/EnzymeMap/uniprot2sequence.p", - "rb", - ) - ) + self.ec2uniprot = pickle.load(open("files/ec2uniprot.p", "rb")) + self.uniprot2sequence = pickle.load(open("files/uniprot2sequence.p", "rb")) self.uniprot2sequence_len = { k: 0 if v is None else len(v) for k, v in self.uniprot2sequence.items() } - self.uniprot2cluster = pickle.load( - open( - args.uniprot2cluster_path, - "rb", - ) - ) # products to remove based on smiles or pattern - remove_patterns_path = "/home/datasets/ECReact/patterns.txt" - remove_molecules_path = "/home/datasets/ECReact/molecules.txt" + remove_patterns_path = "files/ecreact/patterns.txt" + remove_molecules_path = "files/ecreact/molecules.txt" self.remove_patterns = [] self.remove_molecules = [] @@ -137,17 +116,6 @@ def init_class(self, args: argparse.ArgumentParser, split_group: str) -> None: ) ) - if ( - not hasattr(self.args, "use_all_sequences") - or not self.args.use_all_sequences - ): - self.uniprot2split = pickle.load( - open( - "/home/datasets/EnzymeMap/mmseq_splits_precomputed.p", - "rb", - ) - ) - def create_dataset( self, split_group: Literal["train", "dev", "test"] ) -> List[dict]: @@ -377,13 +345,6 @@ def __getitem__(self, index): for k, v in self.args.ec_levels.items(): item[f"ec{k}"] = v.get(".".join(split_ec[: int(k)]), -1) - if self.args.use_pesto_scores: - scores = self.get_pesto_scores(item["protein_id"]) - if (scores is None) or (scores.shape[0] != len(item["sequence"])): - # make all zeros of length sequence - scores = torch.zeros(len(item["sequence"])) - item["sequence_annotation"] = scores - return item except Exception as e: @@ -412,314 +373,37 @@ def assign_splits(self, metadata_json, split_probs, seed=0) -> None: # set seed np.random.seed(seed) - # assign groups - if self.args.split_type in [ - "mmseqs", - "sequence", - "ec", - "product", - "mmseqs_precomputed", - ]: - if ( - self.args.split_type == "mmseqs" - or self.args.split_type == "mmseqs_precomputed" - ): - samples = [ - self.uniprot2cluster[reaction["uniprot_id"]] - for reaction in metadata_json - ] - # samples = list(self.uniprot2cluster.values()) - - if self.args.split_type 
== "sequence": - # split based on uniprot_id - samples = [ - u - for reaction in metadata_json - for u in self.ec2uniprot.get(reaction["ec"], []) - ] - if "protein_id" in metadata_json[0]: - samples += [r["protein_id"] for r in metadata_json] - - elif self.args.split_type == "ec": - # split based on ec number - samples = [reaction["ec"] for reaction in metadata_json] - - # option to change level of ec categorization based on which to split - samples = [ - ".".join(e.split(".")[: self.args.ec_level + 1]) for e in samples - ] - - elif self.args.split_type == "product": - # split by reaction product (splits share no products) - samples = [".".join(s["products"]) for s in metadata_json] - - sample2count = Counter(samples) - samples = sorted(list(set(samples))) - np.random.shuffle(samples) - samples_cumsum = np.cumsum([sample2count[s] for s in samples]) - # Find the indices for each quantile - split_indices = [ - np.searchsorted( - samples_cumsum, np.round(q, 3) * samples_cumsum[-1], side="right" - ) - for q in np.cumsum(split_probs) - ] - split_indices[-1] = len(samples) - split_indices = np.concatenate([[0], split_indices]) - for i in range(len(split_indices) - 1): - self.to_split.update( - { - sample: ["train", "dev", "test"][i] - for sample in samples[split_indices[i] : split_indices[i + 1]] - } - ) - - elif self.args.split_type == "rule_id": - # rule id - rules = [reaction["rule_id"] for reaction in metadata_json] - rule2count = Counter(rules) - samples = sorted(list(set(rules))) - np.random.shuffle(samples) - samples_cumsum = np.cumsum([rule2count[s] for s in samples]) - # Find the indices for each quantile - split_indices = [ - np.searchsorted( - samples_cumsum, np.round(q, 3) * samples_cumsum[-1], side="right" - ) - for q in np.cumsum(split_probs) - ] - split_indices[-1] = len(samples) - split_indices = np.concatenate([[0], split_indices]) - for i in range(len(split_indices) - 1): - self.to_split.update( - { - sample: ["train", "dev", "test"][i] - for sample in samples[split_indices[i] : split_indices[i + 1]] - } - ) - - elif self.args.split_type == "ec_hold_out": - unique_products = set( - [ - ".".join(sample["products"]) - for sample in self.metadata_json - if sample["ec"].split(".")[0] != str(self.args.held_out_ec_num) - ] + # rule id + rules = [reaction["rule_id"] for reaction in metadata_json] + rule2count = Counter(rules) + samples = sorted(list(set(rules))) + np.random.shuffle(samples) + samples_cumsum = np.cumsum([rule2count[s] for s in samples]) + # Find the indices for each quantile + split_indices = [ + np.searchsorted( + samples_cumsum, np.round(q, 3) * samples_cumsum[-1], side="right" ) - # ! 
ENSURE REPRODUCIBLE SETS FOR SAME SEED - unique_products = sorted(list(unique_products)) - np.random.shuffle(unique_products) - - dev_probs = split_probs[1] / (split_probs[0] + split_probs[1]) - train_probs = split_probs[0] / (split_probs[0] + split_probs[1]) - if not self.args.split_multiproduct_samples: - products2split = { - p: np.random.choice(["train", "dev"], p=[train_probs, dev_probs]) - for p in unique_products + for q in np.cumsum(split_probs) + ] + split_indices[-1] = len(samples) + split_indices = np.concatenate([[0], split_indices]) + for i in range(len(split_indices) - 1): + self.to_split.update( + { + sample: ["train", "dev", "test"][i] + for sample in samples[split_indices[i] : split_indices[i + 1]] } - else: - products2split = {} - for p_list in unique_products: - for p in p_list.split("."): - products2split[p] = np.random.choice( - ["train", "dev"], p=[train_probs, dev_probs] - ) - - for sample in self.metadata_json: - ec = sample["ec"] - rkey = ( - "mapped_reactants" - if "mapped_reactants" in self.metadata_json[0] - else "reactants" - ) - pkey = ( - "mapped_products" - if "mapped_products" in self.metadata_json[0] - else "products" - ) - reactants = sorted([s for s in sample[rkey] if s != "[H+]"]) - products = sorted([s for s in sample[pkey] if s != "[H+]"]) - products = [p for p in products if p not in reactants] - - if self.args.topk_byproducts_to_remove is not None: - products = [p for p in products if p not in self.common_byproducts] - - reaction_string = "{}>>{}".format( - ".".join(reactants), ".".join(products) - ) - - if self.args.version == "1": - alluniprots = self.ec2uniprot.get(ec, []) - protein_refs = [] - elif self.args.version == "2": - protein_refs = eval(sample["protein_refs"]) - alluniprots = protein_refs - if (len(alluniprots) == 0) and self.args.sample_uniprot_per_ec: - alluniprots = self.ec2uniprot.get(ec, []) - - if ( - self.args.create_sample_per_sequence - or self.args.sample_uniprot_per_ec - ): - valid_uniprots = [] - for uniprot in alluniprots: - if self.args.split_multiproduct_samples: - for product_id, p in enumerate(products): - psample = copy.deepcopy(sample) - psample["products"] = [p] - # psample["sample_id"] += f"_{product_id}" - preaction_string = "{}>>{}".format( - ".".join(psample["reactants"]), p - ) - # uniprot = psample["uniprot_id"] - punique_sample_content = f"{preaction_string}{uniprot}{psample.get('organism', '')}" - phashed_sample_content = hashlib.sha256( - punique_sample_content.encode("utf-8") - ).hexdigest() - psample["hash_sample_id"] = phashed_sample_content - if str(self.args.held_out_ec_num) == ec: - self.to_split[psample["hash_sample_id"]] = "test" - else: - self.to_split[psample["hash_sample_id"]] = ( - products2split[p] - ) - - else: - unique_sample_content = f"{reaction_string}{uniprot}{sample.get('organism', '')}" - hashed_sample_content = hashlib.sha256( - unique_sample_content.encode("utf-8") - ).hexdigest() - sample["hash_sample_id"] = hashed_sample_content - if sample["ec"].split(".")[0] == str( - self.args.held_out_ec_num - ): - self.to_split[sample["hash_sample_id"]] = "test" - else: - self.to_split[sample["hash_sample_id"]] = ( - products2split[".".join(sample["products"])] - ) - - # random splitting - elif self.args.split_type == "random": - for sample in self.metadata_json: - reaction_string = ( - ".".join(sample["reactants"]) + ">>" + ".".join(sample["products"]) - ) - self.to_split.update( - { - reaction_string: np.random.choice( - ["train", "dev", "test"], p=split_probs - ) - } - ) - elif 
self.args.split_type == "scaffold": - # split based on scaffold - self.scaffold_split(metadata_json, split_probs, seed) - else: - raise ValueError("Split type not supported") - - def scaffold_split(self, meta: List[dict], split_probs: List[float], seed): - scaffold_to_indices = defaultdict(list) - for m_i, m in enumerate(meta): - scaffold = generate_scaffold(m["smiles"]) - scaffold_to_indices[scaffold].append(m_i) - - # Split - train_size, val_size, test_size = ( - split_probs[0] * len(meta), - split_probs[1] * len(meta), - split_probs[2] * len(meta), - ) - train, val, test = [], [], [] - train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0 - - # Seed randomness - random = Random(seed) - - if self.args.scaffold_balanced: # Put stuff that's bigger than half the val/test size into train, rest just order randomly - index_sets = list(scaffold_to_indices.values()) - big_index_sets = [] - small_index_sets = [] - for index_set in index_sets: - if len(index_set) > val_size / 2 or len(index_set) > test_size / 2: - big_index_sets.append(index_set) - else: - small_index_sets.append(index_set) - random.seed(seed) - random.shuffle(big_index_sets) - random.shuffle(small_index_sets) - index_sets = big_index_sets + small_index_sets - else: # Sort from largest to smallest scaffold sets - index_sets = sorted( - list(scaffold_to_indices.values()), - key=lambda index_set: len(index_set), - reverse=True, ) - for index_set in index_sets: - if len(train) + len(index_set) <= train_size: - train += index_set - train_scaffold_count += 1 - elif len(val) + len(index_set) <= val_size: - val += index_set - val_scaffold_count += 1 - else: - test += index_set - test_scaffold_count += 1 - - for idx_list, split in [(train, "train"), (val, "dev"), (test, "test")]: - for idx in idx_list: - meta[idx]["split"] = split - if ( - meta[idx]["smiles"] in self.to_split - and self.to_split[meta[idx]["smiles"]] != split - ): - raise Exception("Smile exists in to_split but with different split") - self.to_split[meta[idx]["smiles"]] = split - def get_split_group_dataset( self, processed_dataset, split_group: Literal["train", "dev", "test"] ) -> List[dict]: dataset = [] for sample in processed_dataset: # check right split - if self.args.split_type == "ec": - split_ec = sample[f"ec{self.args.ec_level + 1}"] - if self.to_split[split_ec] != split_group: - continue - - elif self.args.split_type == "rule_id": - if self.to_split[sample["rule_id"]] != split_group: - continue - - elif self.args.split_type == "mmseqs": - cluster = self.uniprot2cluster.get(sample["protein_id"], None) - if (cluster is None) or (self.to_split[cluster] != split_group): - continue - elif ( - self.args.split_type == "mmseqs_precomputed" - or self.args.split_type == "scaffold" - ): - if sample["split"] != split_group: - continue - elif self.args.split_type in ["product"]: - products = ".".join(sample["products"]) - if self.to_split[products] != split_group: - continue - - elif self.args.split_type == "sequence": - uniprot = sample["protein_id"] - if self.to_split[uniprot] != split_group: - continue - - elif self.args.split_type == "ec_hold_out": - sample_id = sample["hash_sample_id"] - if self.to_split[sample_id] != split_group: - continue - - elif sample["split"] is not None: - if sample["split"] != split_group: - continue + if self.to_split[sample["rule_id"]] != split_group: + continue dataset.append(sample) return dataset @@ -839,23 +523,11 @@ def add_args(parser) -> None: Args: parser (argparse.ArgumentParser): argument parser """ - 
super(EnzymeMap, EnzymeMap).add_args(parser) - parser.add_argument( - "--held_out_ec_num", - type=int, - default=None, - help="EC number to hold out", - ) - parser.add_argument( - "--uniprot2cluster_path", - type=str, - default="/home/datasets/EnzymeMap/mmseq_clusters_updated.p", - help="path to uniprot2cluster pickle", - ) + AbstractDataset.add_args(parser) parser.add_argument( "--esm_dir", type=str, - default="/home/snapshots/metabolomics/esm2/checkpoints/esm2_t33_650M_UR50D.pt", + default="/home/esm2/checkpoints/esm2_t33_650M_UR50D.pt", help="directory to load esm model from", ) parser.add_argument( @@ -949,18 +621,6 @@ def add_args(parser) -> None: default=None, help="remove common byproducts", ) - parser.add_argument( - "--use_pesto_scores", - action="store_true", - default=False, - help="use pesto scores", - ) - parser.add_argument( - "--pesto_scores_directory", - type=str, - default="/home/datasets/ECReact/pesto_ligands", - help="load pesto scores from directory predictions", - ) parser.add_argument( "--create_sample_per_sequence", action="store_true", @@ -985,12 +645,6 @@ def add_args(parser) -> None: default=-1, help="minimum threshold to use for filtering reactions based on quality score", ) - parser.add_argument( - "--load_wln_cache_in_dataset", - action="store_true", - default=False, - help="load cache for wln in getitem", - ) parser.add_argument( "--split_multiproduct_samples", action="store_true", @@ -1003,30 +657,12 @@ def add_args(parser) -> None: default=False, help="encode node and edge features of molecule as one-hot", ) - parser.add_argument( - "--scaffold_balanced", - action="store_true", - default=False, - help="balance the scaffold sets", - ) parser.add_argument( "--version", type=str, default="1", help="enzyme map version number", ) - parser.add_argument( - "--convert_graph_to_smiles", - action="store_true", - default=False, - help="for sequence based methods", - ) - parser.add_argument( - "--reaction_to_products_dir", - type=str, - default=None, - help="cache for post process step", - ) parser.add_argument( "--remove_duplicate_reactions", action="store_true", @@ -1063,21 +699,6 @@ def SUMMARY_STATEMENT(self) -> None: @register_object("enzymemap_reaction_graph", "dataset") class EnzymeMapGraph(EnzymeMap): def post_process(self, args): - def make_reaction_to_products(): - reaction_to_products = defaultdict(set) - if args.create_sample_per_sequence: - key = lambda sample: f"{sample['ec']}{'.'.join(sample['reactants'])}" - else: - key = lambda sample: ".".join(sample["reactants"]) - for sample in tqdm(self.dataset, desc="post-processing", ncols=100): - reaction_to_products[key(sample)].add( - ( - ".".join(sample["products"]), - stringify_sets(sorted(sample["bond_changes"])), - ) - ) - return reaction_to_products - # set ec levels to id for use in modeling ecs = set(d["ec"] for d in self.dataset) ecs = [e.split(".") for e in ecs] @@ -1423,13 +1044,6 @@ def __getitem__(self, index): if self.args.load_wln_cache_in_dataset: item["product_candidates"] = self.cache.get(rowid) - if self.args.use_pesto_scores: - scores = self.get_pesto_scores(item["protein_id"]) - if (scores is None) or (scores.shape[0] != len(item["sequence"])): - # make all zeros of length sequence - scores = torch.zeros(len(item["sequence"])) - item["sequence_annotation"] = scores - if self.args.use_protein_graphs: if self.args.cache_path: try: diff --git a/clipzyme/lightning/clipzyme.py b/clipzyme/lightning/clipzyme.py index c6be093..6cf6edd 100644 --- a/clipzyme/lightning/clipzyme.py +++ 
b/clipzyme/lightning/clipzyme.py @@ -277,7 +277,7 @@ def extract_protein_features( self, batch: dict = None, cif_path: Union[str, List[str]] = None, - esm_path: str = None, + esm_dir: str = None, ) -> torch.Tensor: """ Extract protein features from model. @@ -296,12 +296,15 @@ def extract_protein_features( if cif_path is not None: assert ( - esm_path is not None - ), "If manually extracting protein embedding, then `esm_path` must be provided" + esm_dir is not None + ), "If manually extracting protein embeddings, then `esm_dir` must be provided" if isinstance(cif_path, str): cif_path = [cif_path] protein_graphs = [ - create_protein_graph(cif_path=cpath, esm_path=esm_path) + create_protein_graph( + cif_path=cpath, + esm_path=os.path.join(esm_dir, "esm2_t33_650M_UR50D.pt"), + ) for cpath in cif_path ] batch = default_collate([{"graph": g} for g in protein_graphs]) diff --git a/scripts/screen.py b/scripts/screen.py index e665d96..3ee744e 100644 --- a/scripts/screen.py +++ b/scripts/screen.py @@ -126,6 +126,11 @@ def cast_type(val): default=32, help="Batch size for training [default: 128]", ) + parser.add_argument( + "--precision", + default="bf16", + help="Precision to use for evaluation [default: bf16]", + ) parser.add_argument( "--num_workers", type=int, @@ -182,6 +187,11 @@ def cast_type(val): default="logs/test", help="Where to save the arguments of the run", ) + parser.add_argument( + "--experiment_name", + type=str, + help="Set automatically by dispatcher.py or to the current time in main.py; keep without a default", + ) args = parser.parse_args() eval(args)
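For reference, the `esm_path` → `esm_dir` rename in `extract_protein_features` (clipzyme/lightning/clipzyme.py above) means callers now pass the checkpoint directory and the method joins in `esm2_t33_650M_UR50D.pt` itself. A hedged usage sketch; how `model` (the lightning module from that file) is constructed or loaded is outside this diff, and the paths are placeholders:

```python
# `model` is assumed to be an already-loaded instance of the lightning module
# defined in clipzyme/lightning/clipzyme.py; its construction is not part of this diff.
protein_features = model.extract_protein_features(
    cif_path=["/path/to/enzyme_1.cif", "/path/to/enzyme_2.cif"],  # a single path string also works
    esm_dir="/path/to/esm2_dir",  # directory containing esm2_t33_650M_UR50D.pt
)
# Per the signature above, the call returns a torch.Tensor of protein features.
```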
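And a note on the split logic kept in `assign_splits` (clipzyme/datasets/enzymemap.py above): all reactions sharing a `rule_id` land in the same split, and the train/dev/test boundaries are found by searching the cumulative per-rule sample counts for the requested split fractions. A self-contained sketch of that quantile assignment, using made-up rule ids and split probabilities:

```python
import numpy as np
from collections import Counter

# Toy stand-ins for the dataset's rule ids and split_probs
rules = ["r1", "r1", "r1", "r2", "r2", "r3", "r4", "r4", "r4", "r4"]
split_probs = [0.8, 0.1, 0.1]

rule2count = Counter(rules)
samples = sorted(set(rules))
np.random.seed(0)
np.random.shuffle(samples)

# Cumulative number of reactions covered while walking the shuffled rule ids
samples_cumsum = np.cumsum([rule2count[s] for s in samples])

# First index where each cumulative split fraction is exceeded
split_indices = [
    np.searchsorted(samples_cumsum, np.round(q, 3) * samples_cumsum[-1], side="right")
    for q in np.cumsum(split_probs)
]
split_indices[-1] = len(samples)
split_indices = np.concatenate([[0], split_indices])

to_split = {}
for i in range(len(split_indices) - 1):
    for rule in samples[split_indices[i] : split_indices[i + 1]]:
        to_split[rule] = ["train", "dev", "test"][i]

print(to_split)  # every rule id (and hence all of its reactions) falls in exactly one split
```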