From 33962a735f67074855c0ee371cbec81711ebae96 Mon Sep 17 00:00:00 2001 From: ConorFWild <41680328+ConorFWild@users.noreply.github.com> Date: Mon, 23 Oct 2023 12:08:23 +0100 Subject: [PATCH] New crystalform and assemblies documentation --- scripts/collate.py | 32 ++++++++++++++++++++++++++++++++ xchemalign/aligner.py | 36 ++++++++++++++++++------------------ xchemalign/collator.py | 4 ++-- 3 files changed, 52 insertions(+), 20 deletions(-) create mode 100644 scripts/collate.py diff --git a/scripts/collate.py b/scripts/collate.py new file mode 100644 index 0000000..2e65058 --- /dev/null +++ b/scripts/collate.py @@ -0,0 +1,32 @@ +import argparse + +from xchemalign import utils +from xchemalign.collator import Collator + + +def main(): + parser = argparse.ArgumentParser(description="collator") + + parser.add_argument("-c", "--config-file", default="config.yaml", help="Configuration file") + parser.add_argument("-l", "--log-file", help="File to write logs to") + parser.add_argument("--log-level", type=int, default=0, help="Logging level") + parser.add_argument("-v", "--validate", action="store_true", help="Only perform validation") + + args = parser.parse_args() + logger = utils.Logger(logfile=args.log_file, level=args.log_level) + logger.info("collator: ", args) + + c = Collator(args.config_file, logger=logger) + + meta, num_errors, num_warnings = c.validate() + + if not args.validate: + if meta is None or num_errors: + print("There are errors, cannot continue") + exit(1) + else: + c.run(meta) + + +if __name__ == "__main__": + main() diff --git a/xchemalign/aligner.py b/xchemalign/aligner.py index 16ee32d..88cf722 100644 --- a/xchemalign/aligner.py +++ b/xchemalign/aligner.py @@ -190,10 +190,6 @@ def __init__(self, version_dir, metadata, xtalforms, assemblies, logger=None): self.xtalforms_file = Path(xtalforms) else: self.xtalforms_file = self.base_dir / Constants.XTALFORMS_FILENAME # e.g. path/to/xtalforms.yaml - if assemblies: - self.assemblies_file = Path(assemblies) - else: - self.assemblies_file = self.base_dir / Constants.ASSEMBLIES_FILENAME # e.g. path/to/assemblies.yaml if logger: self.logger = logger else: @@ -234,7 +230,7 @@ def run(self): def _write_output(self, collator_dict, aligner_dict): # keep a copy of the xtaforms and assemblies configs self._copy_file_to_version_dir(self.xtalforms_file) - self._copy_file_to_version_dir(self.assemblies_file) + # self._copy_file_to_version_dir(self.assemblies_file) collator_dict[Constants.META_XTALFORMS] = aligner_dict[Constants.META_XTALFORMS] collator_dict[Constants.META_CONFORMER_SITES] = aligner_dict[Constants.META_CONFORMER_SITES] @@ -313,7 +309,7 @@ def _perform_alignments(self, meta): # Load the fs model for the new output dir fs_model = dt.FSModel.from_dir(output_path) fs_model.xtalforms = self.xtalforms_file - fs_model.assemblies = self.assemblies_file + # fs_model.assemblies = self.assemblies_file if source_fs_model: fs_model.alignments = source_fs_model.alignments fs_model.reference_alignments = source_fs_model.reference_alignments @@ -331,16 +327,21 @@ def _perform_alignments(self, meta): datasets, reference_datasets, new_datasets = get_datasets_from_crystals(crystals, self.base_dir) # Get assemblies - if source_fs_model: - assemblies: dict[str, dt.Assembly] = _load_assemblies(source_fs_model.assemblies, self.assemblies_file) - else: - assemblies = _load_assemblies(fs_model.assemblies, self.assemblies_file) - - # Get xtalforms - if source_fs_model: - xtalforms: dict[str, dt.XtalForm] = _load_xtalforms(source_fs_model.xtalforms, self.xtalforms_file) - else: - xtalforms = _load_xtalforms(fs_model.xtalforms, self.xtalforms_file) + # if source_fs_model: + # assemblies: dict[str, dt.Assembly] = _load_assemblies(source_fs_model.assemblies, self.assemblies_file) + # else: + # assemblies = _load_assemblies(fs_model.assemblies, self.assemblies_file) + # + # # Get xtalforms + # if source_fs_model: + # xtalforms: dict[str, dt.XtalForm] = _load_xtalforms(source_fs_model.xtalforms, self.xtalforms_file) + # else: + # xtalforms = _load_xtalforms(fs_model.xtalforms, self.xtalforms_file) + xtalforms, assemblies = _load_xtalforms_and_assemblies( + source_fs_model.xtalforms, + source_fs_model.assemblies, + self.xtalforms_file + ) # Get the dataset assignments if source_fs_model: @@ -745,7 +746,6 @@ def main(): parser.add_argument("-d", "--version-dir", required=True, help="Path to version dir") parser.add_argument("-m", "--metadata_file", default=Constants.METADATA_XTAL_FILENAME, help="Metadata YAML file") parser.add_argument("-x", "--xtalforms", help="Crystal forms YAML file") - parser.add_argument("-a", "--assemblies", help="Assemblies YAML file") parser.add_argument("-l", "--log-file", help="File to write logs to") parser.add_argument("--log-level", type=int, default=0, help="Logging level") parser.add_argument("--validate", action="store_true", help="Only perform validation") @@ -755,7 +755,7 @@ def main(): logger = utils.Logger(logfile=args.log_file, level=args.log_level) - a = Aligner(args.version_dir, args.metadata_file, args.xtalforms, args.assemblies, logger=logger) + a = Aligner(args.version_dir, args.metadata_file, args.xtalforms, logger=logger) num_errors, num_warnings = a.validate() if not args.validate: diff --git a/xchemalign/collator.py b/xchemalign/collator.py index 951e6dd..9a63414 100644 --- a/xchemalign/collator.py +++ b/xchemalign/collator.py @@ -26,8 +26,8 @@ from rdkit import Chem -from . import utils, dbreader -from .utils import Constants +from xchemalign import utils, dbreader +from xchemalign.utils import Constants def generate_xtal_dir(input_path: Path, xtal_name: str):