diff --git a/molexpress/datasets/encoders.py b/molexpress/datasets/encoders.py index d40d5b0..a3253a6 100644 --- a/molexpress/datasets/encoders.py +++ b/molexpress/datasets/encoders.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Union + import numpy as np from molexpress import types @@ -23,27 +25,32 @@ def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) for residue in residues: residue = chem_ops.get_molecule(residue) residue_graph = { - **self.node_encoder(residue), + **self.node_encoder(residue), **self.edge_encoder(residue) } residue_graphs.append(residue_graph) residue_sizes.append(residue.GetNumAtoms()) disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs) disjoint_peptide_graph["residue_size"] = np.array(residue_sizes) - return disjoint_peptide_graph - + disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32") + return disjoint_peptide_graph + @staticmethod - def _collate_fn( - data: list[tuple[types.MolecularGraph, np.ndarray]], + def collate_fn( + data: list[Union[types.MolecularGraph, tuple[types.MolecularGraph, np.ndarray]]], ) -> tuple[types.MolecularGraph, np.ndarray]: - """TODO: Not sure where to implement this collate function. - Temporarily putting it here. - - Procedure: - Merges list of graphs into a single disjoint graph. """ + Merge list of graphs into a single disjoint graph. - disjoint_peptide_graphs, y = list(zip(*data)) + Data can be a list of MolecularGraphs or a list of tuples where the first element is a + MolecularGraph and the second element is a label. + + """ + if isinstance(data[0], tuple): + disjoint_peptide_graphs, y = list(zip(*data)) + else: + disjoint_peptide_graphs = data + y = None disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs( disjoint_peptide_graphs @@ -54,7 +61,11 @@ def _collate_fn( disjoint_peptide_batch_graph["residue_size"] = np.concatenate([ g["residue_size"] for g in disjoint_peptide_graphs ]).astype("int32") - return disjoint_peptide_batch_graph, np.stack(y) + + if y is None: + return disjoint_peptide_batch_graph + else: + return disjoint_peptide_batch_graph, np.stack(y) @staticmethod def _merge_molecular_graphs( diff --git a/molexpress/layers/residue_readout.py b/molexpress/layers/residue_readout.py index c0cad78..ddc4d80 100644 --- a/molexpress/layers/residue_readout.py +++ b/molexpress/layers/residue_readout.py @@ -22,8 +22,8 @@ def build(self, input_shape: dict[str, tuple[int, ...]]) -> None: raise ValueError("Cannot perform readout: 'residue_size' not found.") def call(self, inputs: types.MolecularGraph) -> types.Array: - peptide_size = keras.ops.cast(inputs['peptide_size'], 'int32') - residue_size = keras.ops.cast(inputs['residue_size'], 'int32') + peptide_size = keras.ops.cast(inputs["peptide_size"], "int32") + residue_size = keras.ops.cast(inputs["residue_size"], "int32") n_residues = keras.ops.shape(residue_size)[0] segment_ids = keras.ops.repeat(range(n_residues), residue_size) residue_state = self._readout_fn( @@ -34,25 +34,21 @@ def call(self, inputs: types.MolecularGraph) -> types.Array: ) # Make shape known residue_state = keras.ops.reshape( - residue_state, - ( - keras.ops.shape(residue_size)[0], - keras.ops.shape(inputs['node_state'])[-1] - ) + residue_state, + (keras.ops.shape(residue_size)[0], keras.ops.shape(inputs["node_state"])[-1]), ) - + if keras.ops.shape(peptide_size)[0] == 1: # Single peptide in batch return residue_state[None] - + # Split and stack (with padding in the second dim) # Resulting shape: (n_peptides, n_residues, n_features) - residues = keras.ops.split(residue_state, peptide_size[:-1]) + residues = keras.ops.split(residue_state, keras.ops.cumsum(peptide_size)[:-1]) max_residue_size = keras.ops.max([len(r) for r in residues]) - return keras.ops.stack([ - keras.ops.pad(r, [(0, max_residue_size-keras.ops.shape(r)[0]), (0, 0)]) - for r in residues - ]) - - - + return keras.ops.stack( + [ + keras.ops.pad(r, [(0, max_residue_size - keras.ops.shape(r)[0]), (0, 0)]) + for r in residues + ] + )