From 084783b3ad1dd8feb95396765ea1106b786bddf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ter=C3=A9zia=20Slanin=C3=A1kov=C3=A1?= <445526@mail.muni.cz> Date: Tue, 1 Oct 2024 12:48:02 +0200 Subject: [PATCH] fix model parsing, arg passing --- training/alphafind_training/create_buckets.py | 15 +++++++++++---- training/train_alphafind.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/training/alphafind_training/create_buckets.py b/training/alphafind_training/create_buckets.py index 1b7d5ab..28cea6f 100644 --- a/training/alphafind_training/create_buckets.py +++ b/training/alphafind_training/create_buckets.py @@ -41,16 +41,23 @@ def load_all_embeddings(path): def parse_model_params(model_path): LOG.info(f'Parsing out model params from model path: {model_path}') pattern = r'model-(\w+)--.*?n_classes-(\d+)(?:--.*?dimensionality-(\d+))?' + + if model_path is None: + model = 'MLP' + dimensionality = DEFAULT_DIMENSIONALITY + n_classes = 2 + LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}') + return model, dimensionality, n_classes + match = re.search(pattern, model_path, re.MULTILINE) - # new model format if match and len(match.groups()) == 3: - model = match.group(1) - n_classes = int(match.group(2)) - dimensionality = match.group(3) + model, n_classes, dimensionality = match.groups() dimensionality = int(dimensionality) if dimensionality is not None else DEFAULT_DIMENSIONALITY + n_classes = int(n_classes) else: LOG.info(f'Failed to parse out model params from model path: {model_path}') exit(1) + LOG.info(f'Parsed out model={model}, dimensionality={dimensionality}, n_classes={n_classes}') return model, dimensionality, n_classes diff --git a/training/train_alphafind.py b/training/train_alphafind.py index 90acc1a..16873ae 100644 --- a/training/train_alphafind.py +++ b/training/train_alphafind.py @@ -40,7 +40,7 @@ def train_alphafind(base_dir, data_dir, models_dir): # 5) Create bucket-data mapping to protein IDs create_mapping( - bucket_path=os.path.join(data_dir, "bucket-data"), output_path=os.path.join(data_dir, "bucket-mapping.pkl") + bucket_data_path=os.path.join(data_dir, "bucket-data"), output_path=os.path.join(data_dir, "bucket-mapping.pkl") )