Skip to content

Commit

Permalink
✨ Simpler API to access run parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-flam committed Aug 6, 2020
1 parent 294531f commit 7a4f798
Show file tree
Hide file tree
Showing 18 changed files with 332 additions and 302 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ log.lsf.txt
log.txt
ssl/
commits/
user/

image.batches.yaml
iter.batches.yaml
Expand Down
4 changes: 2 additions & 2 deletions qaboard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .check_for_updates import check_for_updates
check_for_updates()

from .config import on_windows, on_linux, on_lsf, on_vdi, is_ci
from .config import config, merge
from .config import on_windows, on_linux, on_lsf, on_vdi, is_ci, config
from .utils import merge
from .conventions import slugify
from .qa import qa
23 changes: 3 additions & 20 deletions qaboard/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import sys
import datetime
from getpass import getuser
from itertools import chain
from pathlib import Path, PurePosixPath
from typing import Dict, Any, Tuple, List, Optional

import yaml
import click

from .utils import getenvs
from .utils import merge, getenvs
from .git import git_head, git_show
from .conventions import slugify, get_commit_dirs, location_from_spec
from .iterators import flatten
Expand Down Expand Up @@ -68,22 +67,6 @@ def find_configs(path : Path) -> List[Tuple[Dict, Path]]:
dim=True, err=True)


def merge(src: Dict, dest: Dict) -> Dict:
"""Deep merge QA-Board configuration files"""
# https://stackoverflow.com/questions/20656135/python-deep-merge-dictionary-data
if src:
for key, value in src.items():
if isinstance(value, dict):
node = dest.setdefault(key, {})
merge(value, node)
elif value:
# "super" is a reserved keyword
if isinstance(value, list) and "super" in value:
value = list(chain.from_iterable([[e] if e != "super" else dest.get(key, []) for e in value]))
dest[key] = value
return dest


# take care not to mutate the root config, as its project.name is the git repo name
config : Dict[str, Any] = {}
for c in qatools_configs:
Expand Down Expand Up @@ -341,9 +324,9 @@ def mkdir(path: Path):
default_input_type = config_inputs_types.get('default', 'default')


def get_default_configuration(input_settings):
def get_default_configuration(input_settings) -> str:
from .conventions import serialize_config
default_configuration = input_settings.get('configurations', input_settings.get('configuration', []))
default_configuration = input_settings.get('configs', input_settings.get('configurations', input_settings.get('configuration', [])))
default_configuration = list(flatten(default_configuration))
return serialize_config(default_configuration)

Expand Down
12 changes: 7 additions & 5 deletions qaboard/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def iter_batch(batch: Dict, default_run_context: RunContext, qatools_config, def
return

run_context = copy(default_run_context)
run_context.configurations = batch.get('configurations', batch.get('configuration', run_context.configurations))
run_context.configurations = batch.get('configs', batch.get('configurations', batch.get('configuration', run_context.configurations)))
run_context.configurations = list(flatten(run_context.configurations))
if 'platform' in batch:
run_context.platform = batch['platform']
Expand All @@ -261,7 +261,7 @@ def iter_batch(batch: Dict, default_run_context: RunContext, qatools_config, def
from sklearn.model_selection import ParameterGrid
for matrix in ParameterGrid(batch['matrix']):
batch_ = copy(batch)
for key in ['matrix', 'configuration', 'configurations', 'platform']:
for key in ['matrix', 'configuration', 'configurations', 'configs', 'platform']:
if key in batch:
del batch_[key]
matrix_run_context = copy(run_context)
Expand All @@ -271,8 +271,10 @@ def iter_batch(batch: Dict, default_run_context: RunContext, qatools_config, def
matrix_run_context.configurations = matrix['configuration']
if 'configurations' in matrix:
matrix_run_context.configurations = matrix['configurations']
if 'configs' in matrix:
matrix_run_context.configurations = matrix['configs']
for param, value in matrix.items():
if param in ['configuration', 'configurations', 'platform']:
if param in ['configuration', 'configurations', 'configs', 'platform']:
continue
matrix_run_context.configurations = deep_interpolate(matrix_run_context.configurations, '${matrix.%s}' % param, value)
yield from iter_batch(batch_, matrix_run_context, qatools_config, default_inputs_settings, debug)
Expand Down Expand Up @@ -316,10 +318,10 @@ def iter_batch(batch: Dict, default_run_context: RunContext, qatools_config, def
for k in ['type', 'database', runner, 'platform', 'glob', 'globs', 'use_parent_folder']:
if k in location_configurations:
del location_configurations[k]
if location_configurations and 'configurations' not in location_configurations and 'configurations' not in location_configurations:
if location_configurations and 'configs' not in location_configurations and 'configurations' not in location_configurations and 'configurations' not in location_configurations:
location_run_context.configurations = [*location_run_context.configurations, location_configurations]
else:
patch_config = location_configurations.get('configurations', location_configurations.get('configuration', []))
patch_config = location_configurations.get('configs', location_configurations.get('configurations', location_configurations.get('configuration', [])))
location_run_context.configurations = [*location_run_context.configurations, *patch_config]
elif isinstance(location_configurations, list):
location_run_context.configurations = [*location_run_context.configurations, *list(flatten(location_configurations))]
Expand Down
Loading

0 comments on commit 7a4f798

Please sign in to comment.