diff --git a/.ci/Jenkinsfile b/.ci/Jenkinsfile index 6c6bc86baa..74c6a2078f 100644 --- a/.ci/Jenkinsfile +++ b/.ci/Jenkinsfile @@ -76,6 +76,7 @@ pipeline { } stage('Build') { steps { + sh 'pip install --upgrade --user pip' sh 'pip install --user .[all]' // To be able to do ssh localhost sh 'ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa' diff --git a/.ci/test_plugin_testcase.py b/.ci/test_plugin_testcase.py index 838f03f2f4..0a8df61324 100644 --- a/.ci/test_plugin_testcase.py +++ b/.ci/test_plugin_testcase.py @@ -55,7 +55,7 @@ def get_computer(cls, temp_dir): from aiida import orm computer = orm.Computer( - name='localhost', + label='localhost', hostname='localhost', description='my computer', transport_type='local', @@ -80,7 +80,7 @@ def test_computer_loaded(self): work after resetting the DB. """ from aiida import orm - self.assertEqual(orm.Computer.objects.get(name='localhost').uuid, self.computer.uuid) + self.assertEqual(orm.Computer.objects.get(label='localhost').uuid, self.computer.uuid) def test_tear_down(self): """ diff --git a/.ci/workchains.py b/.ci/workchains.py index 5504813a10..e94f44669c 100644 --- a/.ci/workchains.py +++ b/.ci/workchains.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name +"""Work chain implementations for testing purposes.""" from aiida.common import AttributeDict from aiida.engine import calcfunction, workfunction, WorkChain, ToContext, append_, while_, ExitCode from aiida.engine import BaseRestartWorkChain, process_handler, ProcessHandlerReport @@ -15,7 +16,6 @@ from aiida.orm import Int, List, Str from aiida.plugins import CalculationFactory - ArithmeticAddCalculation = CalculationFactory('arithmetic.add') @@ -54,15 +54,15 @@ def setup(self): def sanity_check_not_too_big(self, node): """My puny brain cannot deal with numbers that I cannot count on my hand.""" if node.is_finished_ok and node.outputs.sum > 10: - return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) + return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) # pylint: disable=no-member @process_handler(priority=460, enabled=False) - def disabled_handler(self, node): + def disabled_handler(self, node): # pylint: disable=unused-argument """By default this is not enabled and so should never be called, irrespective of exit codes of sub process.""" - return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) + return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) # pylint: disable=no-member @process_handler(priority=450, exit_codes=ExitCode(1000, 'Unicorn encountered')) - def a_magic_unicorn_appeared(self, node): + def a_magic_unicorn_appeared(self, node): # pylint: disable=no-self-argument,no-self-use """As we all know unicorns do not exist so we should never have to deal with it.""" raise RuntimeError('this handler should never even have been called') @@ -78,30 +78,24 @@ class NestedWorkChain(WorkChain): """ Nested workchain which creates a workflow where the nesting level is equal to its input. 
""" + @classmethod def define(cls, spec): super().define(spec) spec.input('inp', valid_type=Int) - spec.outline( - cls.do_submit, - cls.finalize - ) + spec.outline(cls.do_submit, cls.finalize) spec.output('output', valid_type=Int, required=True) def do_submit(self): if self.should_submit(): self.report('Submitting nested workchain.') - return ToContext( - workchain=append_(self.submit( - NestedWorkChain, - inp=self.inputs.inp - 1 - )) - ) + return ToContext(workchain=append_(self.submit(NestedWorkChain, inp=self.inputs.inp - 1))) def should_submit(self): return int(self.inputs.inp) > 0 def finalize(self): + """Attach the outputs.""" if self.should_submit(): self.report('Getting sub-workchain output.') sub_workchain = self.ctx.workchain[0] @@ -112,15 +106,13 @@ def finalize(self): class SerializeWorkChain(WorkChain): + """Work chain that serializes inputs.""" + @classmethod def define(cls, spec): super().define(spec) - spec.input( - 'test', - valid_type=Str, - serializer=lambda x: Str(ObjectLoader().identify_object(x)) - ) + spec.input('test', valid_type=Str, serializer=lambda x: Str(ObjectLoader().identify_object(x))) spec.outline(cls.echo) spec.outputs.dynamic = True @@ -130,6 +122,8 @@ def echo(self): class NestedInputNamespace(WorkChain): + """Work chain with nested namespace.""" + @classmethod def define(cls, spec): super().define(spec) @@ -143,6 +137,8 @@ def do_echo(self): class ListEcho(WorkChain): + """Work chain that simply echos a `List` input.""" + @classmethod def define(cls, spec): super().define(spec) @@ -157,6 +153,8 @@ def do_echo(self): class DynamicNonDbInput(WorkChain): + """Work chain with dynamic non_db inputs.""" + @classmethod def define(cls, spec): super().define(spec) @@ -172,6 +170,8 @@ def do_test(self): class DynamicDbInput(WorkChain): + """Work chain with dynamic input namespace.""" + @classmethod def define(cls, spec): super().define(spec) @@ -186,6 +186,8 @@ def do_test(self): class DynamicMixedInput(WorkChain): + """Work chain with dynamic mixed input.""" + @classmethod def define(cls, spec): super().define(spec) @@ -194,6 +196,7 @@ def define(cls, spec): spec.outline(cls.do_test) def do_test(self): + """Run the test.""" input_non_db = self.inputs.namespace.inputs['input_non_db'] input_db = self.inputs.namespace.inputs['input_db'] assert isinstance(input_non_db, int) @@ -206,6 +209,7 @@ class CalcFunctionRunnerWorkChain(WorkChain): """ WorkChain which calls an InlineCalculation in its step. 
""" + @classmethod def define(cls, spec): super().define(spec) @@ -223,6 +227,7 @@ class WorkFunctionRunnerWorkChain(WorkChain): """ WorkChain which calls a workfunction in its step """ + @classmethod def define(cls, spec): super().define(spec) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1fe8637a62..d7fa7b9339 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -13,9 +13,7 @@ jobs: keys: - cache-pip - - run: | - pip install numpy==1.17.4 - pip install --user .[docs,tests] + - run: pip install --user .[docs,tests] - save_cache: key: cache-pip diff --git a/.github/config/localhost-config.yaml b/.github/config/localhost-config.yaml index e6a7912c85..ef0ca22365 100644 --- a/.github/config/localhost-config.yaml +++ b/.github/config/localhost-config.yaml @@ -1,2 +1,3 @@ --- +use_login_shell: true safe_interval: 0 diff --git a/.github/config/profile.yaml b/.github/config/profile.yaml index e58ab2821d..009e3ed0ff 100644 --- a/.github/config/profile.yaml +++ b/.github/config/profile.yaml @@ -11,4 +11,10 @@ db_port: 5432 db_name: PLACEHOLDER_DATABASE_NAME db_username: postgres db_password: '' +broker_protocol: amqp +broker_username: guest +broker_password: guest +broker_host: 127.0.0.1 +broker_port: 5672 +broker_virtual_host: '' repository: PLACEHOLDER_REPOSITORY diff --git a/.github/workflows/benchmark-config.json b/.github/workflows/benchmark-config.json new file mode 100644 index 0000000000..cc698e011b --- /dev/null +++ b/.github/workflows/benchmark-config.json @@ -0,0 +1,39 @@ +{ + "suites": { + "pytest-benchmarks:ubuntu-18.04,django": { + "header": "Performance Benchmarks (Ubuntu-18.04, Django)", + "description": "Performance benchmark tests, generated using pytest-benchmark." + }, + "pytest-benchmarks:ubuntu-18.04,sqlalchemy": { + "header": "Performance Benchmarks (Ubuntu-18.04, SQLAlchemy)", + "description": "Performance benchmark tests, generated using pytest-benchmark." 
+ } + }, + "groups": { + "node": { + "header": "Single Node", + "description": "Comparison of basic node interactions, such as storage and deletion from the database.", + "single_chart": true, + "xAxis": "id", + "backgroundFill": false, + "yAxisFormat": "logarithmic" + }, + "engine": { + "header": "Processes", + "description": "Comparison of Processes, executed via both local and daemon runners.", + "single_chart": true, + "xAxis": "id", + "backgroundFill": false, + "legendAlign": "start", + "yAxisFormat": "logarithmic" + }, + "import-export": { + "header": "Import-Export", + "description": "Comparison of import/export of provenance trees.", + "single_chart": true, + "xAxis": "id", + "backgroundFill": false, + "yAxisFormat": "logarithmic" + } + } +} diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 0000000000..450b33d1a8 --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,156 @@ +name: Performance benchmarks + +on: + push: + branches: [develop] + paths-ignore: ['docs/**'] + pull_request: + branches: [develop] + +jobs: + + run-and-upload: + + if: ${{ github.event_name == 'push' }} + + strategy: + fail-fast: false + matrix: + os: [ubuntu-18.04] + postgres: [12.3] + rabbitmq: [3.8.3] + backend: ['django', 'sqlalchemy'] + + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + + services: + postgres: + image: "postgres:${{ matrix.postgres }}" + env: + POSTGRES_DB: test_${{ matrix.backend }} + POSTGRES_PASSWORD: '' + POSTGRES_HOST_AUTH_METHOD: trust + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + rabbitmq: + image: "rabbitmq:${{ matrix.rabbitmq }}" + ports: + - 5672:5672 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install python dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/requirements-py-3.8.txt + pip install --no-deps -e . 
+ reentry scan + pip freeze + + - name: Run benchmarks + env: + AIIDA_TEST_BACKEND: ${{ matrix.backend }} + run: pytest --benchmark-only --benchmark-json benchmark.json + + - name: Store benchmark result + uses: aiidateam/github-action-benchmark@v3 + with: + benchmark-data-dir-path: "dev/bench/${{ matrix.os }}/${{ matrix.backend }}" + name: "pytest-benchmarks:${{ matrix.os }},${{ matrix.backend }}" + metadata: "postgres:${{ matrix.postgres }}, rabbitmq:${{ matrix.rabbitmq }}" + output-file-path: benchmark.json + render-json-path: .github/workflows/benchmark-config.json + commit-msg-append: "[ci skip]" + github-token: ${{ secrets.GITHUB_TOKEN }} + auto-push: true + # Show alert with commit comment on detecting possible performance regression + alert-threshold: '200%' + comment-on-alert: true + fail-on-alert: false + alert-comment-cc-users: '@chrisjsewell,@giovannipizzi' + + run-on-comment: + + if: ${{ github.event_name == 'pull_request' }} + + strategy: + matrix: + os: [ubuntu-18.04] + postgres: [12.3] + rabbitmq: [3.8.3] + backend: ['django'] + + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + + services: + postgres: + image: "postgres:${{ matrix.postgres }}" + env: + POSTGRES_DB: test_${{ matrix.backend }} + POSTGRES_PASSWORD: '' + POSTGRES_HOST_AUTH_METHOD: trust + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + rabbitmq: + image: "rabbitmq:${{ matrix.rabbitmq }}" + ports: + - 5672:5672 + + steps: + # v2 was checking out the wrong commit! https://github.com/actions/checkout/issues/299 + - uses: actions/checkout@v1 + + - name: get commit message + run: echo ::set-env name=commitmsg::$(git log --format=%B -n 1 "${{ github.event.after }}") + + - if: contains( env.commitmsg , '[run bench]' ) + name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - if: contains( env.commitmsg , '[run bench]' ) + name: Install python dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/requirements-py-3.8.txt + pip install --no-deps -e . 
+ reentry scan + pip freeze + + - if: contains( env.commitmsg , '[run bench]' ) + name: Run benchmarks + env: + AIIDA_TEST_BACKEND: ${{ matrix.backend }} + run: pytest --benchmark-only --benchmark-json benchmark.json + + - if: contains( env.commitmsg , '[run bench]' ) + name: Compare benchmark results + uses: aiidateam/github-action-benchmark@v3 + with: + output-file-path: benchmark.json + name: "pytest-benchmarks:${{ matrix.os }},${{ matrix.backend }}" + benchmark-data-dir-path: "dev/bench/${{ matrix.os }}/${{ matrix.backend }}" + metadata: "postgres:${{ matrix.postgres }}, rabbitmq:${{ matrix.rabbitmq }}" + github-token: ${{ secrets.GITHUB_TOKEN }} + auto-push: false + # Show alert with commit comment on detecting possible performance regression + alert-threshold: '200%' + comment-always: true + fail-on-alert: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci-code.yml similarity index 79% rename from .github/workflows/ci.yml rename to .github/workflows/ci-code.yml index c835e4d406..bc047ab216 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci-code.yml @@ -1,38 +1,14 @@ name: continuous-integration -on: [push, pull_request] +on: + push: + branches-ignore: [gh-pages] + pull_request: + branches-ignore: [gh-pages] + paths-ignore: ['docs/**'] jobs: - pre-commit: - - runs-on: ubuntu-latest - timeout-minutes: 30 - - steps: - - uses: actions/checkout@v2 - - - name: Set up Python 3.7 - uses: actions/setup-python@v1 - with: - python-version: 3.7 - - - name: Install system dependencies - run: | - sudo rm -f /etc/apt/sources.list.d/dotnetdev.list /etc/apt/sources.list.d/microsoft-prod.list - sudo apt update - sudo apt install libkrb5-dev ruby ruby-dev - - - name: Install python dependencies - run: | - pip install numpy==1.17.4 - pip install -e .[all] - pip freeze - - - name: Run pre-commit - run: - pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) - check-requirements: runs-on: ubuntu-latest @@ -42,12 +18,12 @@ jobs: - uses: actions/checkout@v2 - name: Set up Python 3.8 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: 3.8 - name: Install dm-script dependencies - run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 toml + run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 tomlkit - name: Check requirements files id: check_reqs @@ -99,7 +75,7 @@ jobs: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} @@ -120,8 +96,8 @@ jobs: - name: Install aiida-core run: | - pip install -r requirements/requirements-py-${{ matrix.python-version }}.txt - pip install --no-deps -e . + pip install --use-feature=2020-resolver -r requirements/requirements-py-${{ matrix.python-version }}.txt + pip install --use-feature=2020-resolver --no-deps -e . reentry scan pip freeze @@ -144,7 +120,7 @@ jobs: name: aiida-pytests-py3.5-${{ matrix.backend }} flags: ${{ matrix.backend }} file: ./coverage.xml - fail_ci_if_error: true + fail_ci_if_error: false # don't fail job, if coverage upload fails verdi: @@ -154,15 +130,13 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 - uses: actions/setup-python@v1 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Install python dependencies - run: | - pip install numpy==1.17.4 - pip install -e . + run: pip install -e . 
- name: Run verdi run: | diff --git a/.github/workflows/ci-style.yml b/.github/workflows/ci-style.yml new file mode 100644 index 0000000000..b8c2b75720 --- /dev/null +++ b/.github/workflows/ci-style.yml @@ -0,0 +1,37 @@ +name: continuous-integration + +on: + push: + branches-ignore: [gh-pages] + pull_request: + branches-ignore: [gh-pages] + +jobs: + + pre-commit: + + runs-on: ubuntu-latest + timeout-minutes: 30 + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install system dependencies + run: | + sudo rm -f /etc/apt/sources.list.d/dotnetdev.list /etc/apt/sources.list.d/microsoft-prod.list + sudo apt update + sudo apt install libkrb5-dev ruby ruby-dev + + - name: Install python dependencies + run: | + pip install -e .[all] + pip freeze + + - name: Run pre-commit + run: + pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 44381803fd..d55fcfa698 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -21,13 +21,12 @@ jobs: - uses: actions/checkout@v2 - name: Set up Python 3.7 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: 3.7 - name: Install python dependencies run: | - pip install numpy==1.17.4 pip install transifex-client sphinx-intl pip install -e .[docs,tests] diff --git a/.github/workflows/setup.sh b/.github/workflows/setup.sh index 75cb35c612..6ff5c4c6e0 100755 --- a/.github/workflows/setup.sh +++ b/.github/workflows/setup.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -ev -ssh-keygen -q -t rsa -b 4096 -m PEM -N "" -f "${HOME}/.ssh/id_rsa" +ssh-keygen -q -t rsa -b 4096 -N "" -f "${HOME}/.ssh/id_rsa" ssh-keygen -y -f "${HOME}/.ssh/id_rsa" >> "${HOME}/.ssh/authorized_keys" ssh-keyscan -H localhost >> "${HOME}/.ssh/known_hosts" diff --git a/.github/workflows/test-install.yml b/.github/workflows/test-install.yml index 41902e5a11..b1d2603af7 100644 --- a/.github/workflows/test-install.yml +++ b/.github/workflows/test-install.yml @@ -9,6 +9,7 @@ on: - 'pyproject.toml' - 'util/dependency_management.py' - '.github/workflows/test-install.yml' + branches-ignore: [gh-pages] schedule: - cron: '30 02 * * *' # nightly build @@ -17,6 +18,7 @@ jobs: validate-dependency-specification: # Note: The specification is also validated by the pre-commit hook. 
+ if: github.repository == 'aiidateam/aiida-core' runs-on: ubuntu-latest timeout-minutes: 5 @@ -24,40 +26,58 @@ jobs: - uses: actions/checkout@v2 - name: Set up Python 3.8 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: 3.8 - name: Install dm-script dependencies - run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 toml + run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 tomlkit - name: Validate run: python ./utils/dependency_management.py validate-all install-with-pip: + if: github.repository == 'aiidateam/aiida-core' runs-on: ubuntu-latest timeout-minutes: 5 + continue-on-error: ${{ contains(matrix.pip-feature-flag, '2020-resolver') }} + strategy: + fail-fast: false + matrix: + pip-feature-flag: [ '', '--use-feature=2020-resolver' ] + extras: [ '', '[atomic_tools,docs,notebook,rest,tests]' ] + steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 - uses: actions/setup-python@v1 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Pip install + id: pip_install + continue-on-error: ${{ contains(matrix.pip-feature-flag, '2020-resolver') }} run: | - python -m pip install -e . + python -m pip --version + python -m pip install -e .${{ matrix.extras }} ${{ matrix.pip-feature-flag }} python -m pip freeze - name: Test importing aiida + if: steps.pip_install.outcome == 'success' run: python -c "import aiida" + - name: Warn about pip 2020 resolver issues. + if: steps.pip_install.outcome == 'failure' && contains(matrix.pip-feature-flag, '2020-resolver') + run: | + echo "::warning ::Encountered issues with the pip 2020-resolver." + install-with-conda: + if: github.repository == 'aiidateam/aiida-core' runs-on: ubuntu-latest name: install-with-conda @@ -70,7 +90,7 @@ jobs: uses: s-weigand/setup-conda@v1 with: update-conda: true - python-version: 3.7 + python-version: 3.8 - run: conda --version - run: python --version - run: which python @@ -121,7 +141,7 @@ jobs: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} @@ -183,12 +203,12 @@ jobs: - uses: actions/checkout@v2 - name: Set up Python 3.8 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: 3.8 - name: Install dm-script dependencies - run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 toml + run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 tomlkit - name: Check consistency of requirements/ files id: check_reqs diff --git a/.github/workflows/verdi.sh b/.github/workflows/verdi.sh index f2871bae96..103c1f54e5 100755 --- a/.github/workflows/verdi.sh +++ b/.github/workflows/verdi.sh @@ -34,3 +34,6 @@ while true; do fi done + +$VERDI devel check-load-time +$VERDI devel check-undesired-imports diff --git a/.gitignore b/.gitignore index 82fa56577f..624d61e194 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ *.egg-info .eggs .vscode +.tox # files created by coverage .cache @@ -32,4 +33,4 @@ pip-wheel-metadata # Docs docs/build -docs/source/apidoc +docs/source/reference/apidoc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d3e83f5dd8..7ceff0877c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,104 +8,16 @@ repos: - id: mixed-line-ending - id: trailing-whitespace -- repo: https://github.com/PyCQA/pylint - rev: pylint-2.5.2 - hooks: - - id: pylint - language: system - exclude: &exclude_files > - 
(?x)^( - .ci/workchains.py| - aiida/backends/djsite/queries.py| - aiida/backends/djsite/db/models.py| - aiida/backends/djsite/db/migrations/0001_initial.py| - aiida/backends/djsite/db/migrations/0002_db_state_change.py| - aiida/backends/djsite/db/migrations/0003_add_link_type.py| - aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py| - aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py| - aiida/backends/djsite/db/migrations/0006_delete_dbpath.py| - aiida/backends/djsite/db/migrations/0007_update_linktypes.py| - aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py| - aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py| - aiida/backends/djsite/db/migrations/0010_process_type.py| - aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py| - aiida/backends/djsite/db/migrations/0012_drop_dblock.py| - aiida/backends/djsite/db/migrations/0013_django_1_8.py| - aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py| - aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py| - aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py| - aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py| - aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py| - aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py| - aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py| - aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py| - aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py| - aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py| - aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py| - aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py| - aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py| - aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py| - aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py| - aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py| - aiida/backends/sqlalchemy/models/computer.py| - aiida/backends/sqlalchemy/models/settings.py| - aiida/backends/sqlalchemy/models/node.py| - aiida/backends/utils.py| - aiida/common/datastructures.py| - aiida/engine/daemon/execmanager.py| - aiida/engine/processes/calcjobs/tasks.py| - aiida/orm/querybuilder.py| - aiida/orm/nodes/data/array/bands.py| - aiida/orm/nodes/data/array/projection.py| - aiida/orm/nodes/data/array/xy.py| - aiida/orm/nodes/data/code.py| - aiida/orm/nodes/data/orbital.py| - aiida/orm/nodes/data/remote.py| - aiida/orm/nodes/data/structure.py| - aiida/orm/utils/remote.py| - aiida/parsers/plugins/arithmetic/add.py| - aiida/parsers/plugins/templatereplacer/doubler.py| - aiida/parsers/plugins/templatereplacer/__init__.py| - aiida/plugins/entry.py| - aiida/plugins/info.py| - aiida/plugins/registry.py| - aiida/tools/data/array/kpoints/legacy.py| - aiida/tools/data/array/kpoints/seekpath.py| - aiida/tools/data/__init__.py| - aiida/tools/dbexporters/__init__.py| - aiida/tools/dbimporters/baseclasses.py| - aiida/tools/dbimporters/__init__.py| - aiida/tools/dbimporters/plugins/cod.py| - aiida/tools/dbimporters/plugins/icsd.py| - aiida/tools/dbimporters/plugins/__init__.py| - aiida/tools/dbimporters/plugins/mpds.py| - 
aiida/tools/dbimporters/plugins/mpod.py| - aiida/tools/dbimporters/plugins/nninc.py| - aiida/tools/dbimporters/plugins/oqmd.py| - aiida/tools/dbimporters/plugins/pcod.py| - docs/.*| - examples/.*| - tests/engine/test_work_chain.py| - tests/schedulers/test_direct.py| - tests/schedulers/test_lsf.py| - tests/schedulers/test_pbspro.py| - tests/schedulers/test_sge.py| - tests/schedulers/test_torque.py| - tests/sphinxext/workchain_source/conf.py| - tests/sphinxext/workchain_source_broken/conf.py| - tests/transports/test_all_plugins.py| - tests/transports/test_local.py| - tests/transports/test_ssh.py| - tests/test_dataclasses.py| - )$ - - repo: https://github.com/pre-commit/mirrors-yapf rev: v0.30.0 hooks: - id: yapf name: yapf types: [python] - exclude: *exclude_files + exclude: &exclude_files > + (?x)^( + docs/.*| + )$ args: ['-i'] - repo: https://github.com/pre-commit/mirrors-mypy @@ -120,7 +32,15 @@ repos: )$ - repo: local + hooks: + - id: pylint + name: pylint + entry: pylint + types: [python] + language: system + exclude: *exclude_files + - id: dm-generate-all name: Update all requirements files entry: python ./utils/dependency_management.py generate-all diff --git a/.pylintrc b/.pylintrc index 5816e54bd0..eca9d52a8b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -60,7 +60,8 @@ disable=bad-continuation, import-outside-toplevel, cyclic-import, duplicate-code, - too-few-public-methods + too-few-public-methods, + inconsistent-return-statements # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/.readthedocs.yml b/.readthedocs.yml index 1bf9dbfc90..ffc1ec2a59 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -20,3 +20,7 @@ python: sphinx: builder: html fail_on_warning: true + +search: + ranking: + reference/apidoc/*: -7 diff --git a/CHANGELOG.md b/CHANGELOG.md index a5db4880f4..f705c0a3ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,51 @@ # Changelog +## v1.4.0 + +### Improvements +- Add defaults for configure options of the `SshTransport` plugin [[#4223]](https://github.com/aiidateam/aiida-core/pull/4223) +- `verdi status`: distinguish database schema version incompatible [[#4319]](https://github.com/aiidateam/aiida-core/pull/4319) +- `SlurmScheduler`: implement `parse_output` to detect OOM and OOW [[#3931]](https://github.com/aiidateam/aiida-core/pull/3931) + +### Features +- Make the RabbitMQ connection parameters configurable [[#4341]](https://github.com/aiidateam/aiida-core/pull/4341) +- Add infrastructure to parse scheduler output for `CalcJobs` [[#3906]](https://github.com/aiidateam/aiida-core/pull/3906) +- Add support for "peer" authentication with PostgreSQL [[#4255]](https://github.com/aiidateam/aiida-core/pull/4255) +- Add the `--paused` flag to `verdi process list` [[#4213]](https://github.com/aiidateam/aiida-core/pull/4213) +- Make the loglevel of the daemonizer configurable [[#4276]](https://github.com/aiidateam/aiida-core/pull/4276) +- `Transport`: add option to not use a login shell for all commands [[#4271]](https://github.com/aiidateam/aiida-core/pull/4271) +- Implement `skip_orm` option for SqlAlchemy `Group.remove_nodes` [[#4214]](https://github.com/aiidateam/aiida-core/pull/4214) +- `Dict`: allow setting attributes through setitem and `AttributeManager` [[#4351]](https://github.com/aiidateam/aiida-core/pull/4351) +- `CalcJob`: allow nested target paths for `local_copy_list` [[#4373]](https://github.com/aiidateam/aiida-core/pull/4373) +- `verdi 
export migrate`: add `--in-place` flag to migrate archive in place [[#4220]](https://github.com/aiidateam/aiida-core/pull/4220) + +### Bug fixes +- `verdi`: make `--prepend-text` and `--append-text` options properly interactive [[#4318]](https://github.com/aiidateam/aiida-core/pull/4318) +- `verdi computer test`: fix failing result in harmless `stderr` responses [[#4316]](https://github.com/aiidateam/aiida-core/pull/4316) +- `QueryBuilder`: Accept empty string for `entity_type` in `append` method [[#4299]](https://github.com/aiidateam/aiida-core/pull/4299) +- `verdi status`: do not except when no profile is configured [[#4253]](https://github.com/aiidateam/aiida-core/pull/4253) +- `ArithmeticAddParser`: attach output before checking for negative value [[#4267]](https://github.com/aiidateam/aiida-core/pull/4267) +- `CalcJob`: fix bug in `retrieve_list` affecting entries without wildcards [[#4275]](https://github.com/aiidateam/aiida-core/pull/4275) +- `TemplateReplacerCalculation`: make `files` namespace dynamic [[#4348]](https://github.com/aiidateam/aiida-core/pull/4348) + +### Developers +- Rename folder `test.fixtures` to `test.static` [[#4219]](https://github.com/aiidateam/aiida-core/pull/4219) +- Remove all files from the pre-commit exclude list [[#4196]](https://github.com/aiidateam/aiida-core/pull/4196) +- ORM: move attributes/extras methods of frontend and backend nodes to mixins [[#4376]](https://github.com/aiidateam/aiida-core/pull/4376) + +### Dependencies +- Dependencies: update minimum requirement `paramiko~=2.7` [[#4222]](https://github.com/aiidateam/aiida-core/pull/4222) +- Depedencies: remove upper limit and allow `numpy~=1.17` [[#4378]](https://github.com/aiidateam/aiida-core/pull/4378) + +### Deprecations +- Deprecate getter and setter methods of `Computer` properties [[#4252]](https://github.com/aiidateam/aiida-core/pull/4252) +- Deprecate methods that refer to a computer's label as name [[#4309]](https://github.com/aiidateam/aiida-core/pull/4309) + +### Changes +- `BaseRestartWorkChain`: do not run `process_handler` when `exit_codes=[]` [[#4380]](https://github.com/aiidateam/aiida-core/pull/4380) +- `SlurmScheduler`: always raise for non-zero exit code [[#4332]](https://github.com/aiidateam/aiida-core/pull/4332) +- Remove superfluous `ERROR_NO_RETRIEVED_FOLDER` from `CalcJob` subclasses [[#3906]](https://github.com/aiidateam/aiida-core/pull/3906) + ## v1.3.1 diff --git a/Dockerfile b/Dockerfile index 7404a78bdc..849a1bd5ff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM aiidateam/aiida-prerequisites:latest +FROM aiidateam/aiida-prerequisites:0.2.1 USER root diff --git a/README.md b/README.md index 1012679fb1..f04634d30b 100644 --- a/README.md +++ b/README.md @@ -51,13 +51,10 @@ If you are experiencing problems with your AiiDA installation, please refer to t ## How to cite -If you use AiiDA in your research, please consider citing the AiiDA paper: +If you use AiiDA in your research, please consider citing the following publications: -> Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari, -> and Boris Kozinsky, *AiiDA: automated interactive infrastructure and -> database for computational science*, Comp. Mat. Sci 111, 218-230 -> (2016); ; -> . + * **AiiDA >= 1.0**: S. P. 
Huber *et al.*, *AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and data provenance*, Scientific Data **7**, 300 (2020); DOI: [10.1038/s41597-020-00638-4](https://doi.org/10.1038/s41597-020-00638-4) + * **AiiDA < 1.0**: Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari,and Boris Kozinsky, *AiiDA: automated interactive infrastructure and database for computational science*, Comp. Mat. Sci **111**, 218-230 (2016); DOI: [10.1016/j.commatsci.2015.09.013](https://doi.org/10.1016/j.commatsci.2015.09.013) ## License diff --git a/aiida/__init__.py b/aiida/__init__.py index 16b527a91a..96ea2b0ce7 100644 --- a/aiida/__init__.py +++ b/aiida/__init__.py @@ -31,7 +31,7 @@ 'For further information please visit http://www.aiida.net/. All rights reserved.' ) __license__ = 'MIT license, see LICENSE.txt file.' -__version__ = '1.3.1' +__version__ = '1.4.0' __authors__ = 'The AiiDA team.' __paper__ = ( 'G. Pizzi, A. Cepellotti, R. Sabatini, N. Marzari, and B. Kozinsky,' diff --git a/aiida/backends/djsite/db/migrations/0001_initial.py b/aiida/backends/djsite/db/migrations/0001_initial.py index 15d19b72ac..0ea8397da0 100644 --- a/aiida/backends/djsite/db/migrations/0001_initial.py +++ b/aiida/backends/djsite/db/migrations/0001_initial.py @@ -7,7 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations import django.db.models.deletion import django.utils.timezone @@ -19,6 +20,7 @@ class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('auth', '0001_initial'), @@ -53,8 +55,8 @@ class Migration(migrations.Migration): 'is_active', models.BooleanField( default=True, - help_text= - 'Designates whether this user should be treated as active. Unselect this instead of deleting accounts.' + help_text='Designates whether this user should be treated as active. Unselect this instead of ' + 'deleting accounts.' ) ), ('date_joined', models.DateTimeField(default=django.utils.timezone.now)), @@ -65,8 +67,8 @@ class Migration(migrations.Migration): related_name='user_set', to='auth.Group', blank=True, - help_text= - 'The groups this user belongs to. A user will get all permissions granted to each of his/her group.', + help_text='The groups this user belongs to. 
A user will get all permissions granted to each of ' + 'his/her group.', verbose_name='groups' ) ), diff --git a/aiida/backends/djsite/db/migrations/0002_db_state_change.py b/aiida/backends/djsite/db/migrations/0002_db_state_change.py index cfc0fdcd2c..2ac6d980c4 100644 --- a/aiida/backends/djsite/db/migrations/0002_db_state_change.py +++ b/aiida/backends/djsite/db/migrations/0002_db_state_change.py @@ -7,24 +7,24 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.2' DOWN_REVISION = '1.0.1' -def fix_calc_states(apps, schema_editor): +def fix_calc_states(apps, _): + """Fix calculation states.""" from aiida.orm.utils import load_node # These states should never exist in the database but we'll play it safe # and deal with them if they do DbCalcState = apps.get_model('db', 'DbCalcState') - for calc_state in DbCalcState.objects.filter( - state__in=['UNDETERMINED', 'NOTFOUND']): + for calc_state in DbCalcState.objects.filter(state__in=['UNDETERMINED', 'NOTFOUND']): old_state = calc_state.state calc_state.state = 'FAILED' calc_state.save() @@ -32,11 +32,13 @@ def fix_calc_states(apps, schema_editor): calc = load_node(pk=calc_state.dbnode.pk) calc.logger.warning( 'Job state {} found for calculation {} which should never be in ' - 'the database. Changed state to FAILED.'.format( - old_state, calc_state.dbnode.pk)) + 'the database. Changed state to FAILED.'.format(old_state, calc_state.dbnode.pk) + ) class Migration(migrations.Migration): + """Database migration.""" + dependencies = [ ('db', '0001_initial'), ] @@ -47,14 +49,15 @@ class Migration(migrations.Migration): name='state', # The UNDETERMINED and NOTFOUND 'states' were removed as these # don't make sense - field=models.CharField(db_index=True, max_length=25, - choices=[('RETRIEVALFAILED', 'RETRIEVALFAILED'), ('COMPUTED', 'COMPUTED'), - ('RETRIEVING', 'RETRIEVING'), ('WITHSCHEDULER', 'WITHSCHEDULER'), - ('SUBMISSIONFAILED', 'SUBMISSIONFAILED'), ('PARSING', 'PARSING'), - ('FAILED', 'FAILED'), ('FINISHED', 'FINISHED'), - ('TOSUBMIT', 'TOSUBMIT'), ('SUBMITTING', 'SUBMITTING'), - ('IMPORTED', 'IMPORTED'), ('NEW', 'NEW'), - ('PARSINGFAILED', 'PARSINGFAILED')]), + field=models.CharField( + db_index=True, + max_length=25, + choices=[('RETRIEVALFAILED', 'RETRIEVALFAILED'), ('COMPUTED', 'COMPUTED'), ('RETRIEVING', 'RETRIEVING'), + ('WITHSCHEDULER', 'WITHSCHEDULER'), ('SUBMISSIONFAILED', 'SUBMISSIONFAILED'), + ('PARSING', 'PARSING'), ('FAILED', 'FAILED'), + ('FINISHED', 'FINISHED'), ('TOSUBMIT', 'TOSUBMIT'), ('SUBMITTING', 'SUBMITTING'), + ('IMPORTED', 'IMPORTED'), ('NEW', 'NEW'), ('PARSINGFAILED', 'PARSINGFAILED')] + ), preserve_default=True, ), # Fix up any calculation states that had one of the removed states diff --git a/aiida/backends/djsite/db/migrations/0003_add_link_type.py b/aiida/backends/djsite/db/migrations/0003_add_link_type.py index 40117da428..24e32381b7 100644 --- a/aiida/backends/djsite/db/migrations/0003_add_link_type.py +++ b/aiida/backends/djsite/db/migrations/0003_add_link_type.py @@ -7,17 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # 
########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations import aiida.common.timezone from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.3' DOWN_REVISION = '1.0.2' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0002_db_state_change'), diff --git a/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py b/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py index 5bf78b5bdf..cb53ff3d6e 100644 --- a/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py +++ b/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py @@ -7,18 +7,20 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.4' DOWN_REVISION = '1.0.3' class Migration(migrations.Migration): + """Database migration.""" + dependencies = [ ('db', '0003_add_link_type'), ] @@ -27,20 +29,20 @@ class Migration(migrations.Migration): # Create the index that speeds up the daemon queries # We use the RunSQL command because Django interface # doesn't seem to support partial indexes - migrations.RunSQL(""" + migrations.RunSQL( + """ CREATE INDEX tval_idx_for_daemon ON db_dbattribute (tval) WHERE ("db_dbattribute"."tval" - IN ('COMPUTED', 'WITHSCHEDULER', 'TOSUBMIT'))"""), + IN ('COMPUTED', 'WITHSCHEDULER', 'TOSUBMIT'))""" + ), # Create an index on UUIDs to speed up loading of nodes # using this field migrations.AlterField( model_name='dbnode', name='uuid', - field=models.CharField(max_length=36,db_index=True, - editable=False, - blank=True), + field=models.CharField(max_length=36, db_index=True, editable=False, blank=True), preserve_default=True, ), upgrade_schema_version(REVISION, DOWN_REVISION) diff --git a/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py b/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py index 71a901f1a0..11c7e99953 100644 --- a/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py +++ b/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py @@ -7,17 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations import aiida.common.timezone from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.5' DOWN_REVISION = '1.0.4' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0004_add_daemon_and_uuid_indices'), diff --git a/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py b/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py index 905c459960..134b52d8c7 100644 --- a/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py +++ b/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit 
http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.6' DOWN_REVISION = '1.0.5' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0005_add_cmtime_indices'), @@ -35,13 +36,12 @@ class Migration(migrations.Migration): model_name='dbnode', name='children', ), - migrations.DeleteModel( - name='DbPath', - ), - migrations.RunSQL(""" + migrations.DeleteModel(name='DbPath',), + migrations.RunSQL( + """ DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink; DROP FUNCTION IF EXISTS update_tc(); - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0007_update_linktypes.py b/aiida/backends/djsite/db/migrations/0007_update_linktypes.py index c694cb793b..a966516b29 100644 --- a/aiida/backends/djsite/db/migrations/0007_update_linktypes.py +++ b/aiida/backends/djsite/db/migrations/0007_update_linktypes.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.8' DOWN_REVISION = '1.0.7' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0006_delete_dbpath'), @@ -36,7 +37,8 @@ class Migration(migrations.Migration): # - joins a Data (or subclass) as output # - is marked as a returnlink. # 2) set for these links the type to 'createlink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='createlink' WHERE db_dblink.id IN ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -46,7 +48,8 @@ class Migration(migrations.Migration): AND db_dbnode_2.type LIKE 'data.%' AND db_dblink_1.type = 'returnlink' ); - """), + """ + ), # Now I am updating the link-types that are null because of either an export and subsequent import # https://github.com/aiidateam/aiida-core/issues/685 # or because the link types don't exist because the links were added before the introduction of link types. @@ -55,10 +58,11 @@ class Migration(migrations.Migration): # The following sql statement: # 1) selects all links that # - joins Data (or subclass) or Code as input - # - joins Calculation (or subclass) as output. This includes WorkCalculation, InlineCalcuation, JobCalculations... + # - joins Calculation (or subclass) as output: includes WorkCalculation, InlineCalcuation, JobCalculations... # - has no type (null) # 2) set for these links the type to 'inputlink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='inputlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -68,7 +72,8 @@ class Migration(migrations.Migration): AND db_dbnode_2.type LIKE 'calculation.%' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """), + """ + ), # # The following sql statement: # 1) selects all links that @@ -76,7 +81,8 @@ class Migration(migrations.Migration): # - joins Data (or subclass) as output. 
# - has no type (null) # 2) set for these links the type to 'createlink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='createlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -90,14 +96,16 @@ class Migration(migrations.Migration): ) AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """), + """ + ), # The following sql statement: # 1) selects all links that - # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked for. + # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked # - join Data (or subclass) as output. # - has no type (null) # 2) set for these links the type to 'returnlink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='returnlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -107,15 +115,17 @@ class Migration(migrations.Migration): AND db_dbnode_1.type = 'calculation.work.WorkCalculation.' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """), + """ + ), # Now I update links that are CALLS: # The following sql statement: # 1) selects all links that - # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked for. + # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked # - join Calculation (or subclass) as output. Includes JobCalculation and WorkCalculations and all subclasses. # - has no type (null) # 2) set for these links the type to 'calllink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='calllink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -125,7 +135,7 @@ class Migration(migrations.Migration): AND db_dbnode_2.type LIKE 'calculation.%' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py b/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py index 52b3292951..be65bd0bc7 100644 --- a/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py +++ b/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.8' DOWN_REVISION = '1.0.7' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0007_update_linktypes'), @@ -28,16 +29,20 @@ class Migration(migrations.Migration): # we move that value to the extra table # # First we copy the 'hidden' attributes from code.Code. 
nodes to the db_extra table - migrations.RunSQL(""" + migrations.RunSQL( + """ INSERT INTO db_dbextra (key, datatype, tval, fval, ival, bval, dval, dbnode_id) ( - SELECT db_dbattribute.key, db_dbattribute.datatype, db_dbattribute.tval, db_dbattribute.fval, db_dbattribute.ival, db_dbattribute.bval, db_dbattribute.dval, db_dbattribute.dbnode_id + SELECT db_dbattribute.key, db_dbattribute.datatype, db_dbattribute.tval, db_dbattribute.fval, + db_dbattribute.ival, db_dbattribute.bval, db_dbattribute.dval, db_dbattribute.dbnode_id FROM db_dbattribute JOIN db_dbnode ON db_dbnode.id = db_dbattribute.dbnode_id WHERE db_dbattribute.key = 'hidden' AND db_dbnode.type = 'code.Code.' ); - """), + """ + ), # Secondly, we delete the original entries from the DbAttribute table - migrations.RunSQL(""" + migrations.RunSQL( + """ DELETE FROM db_dbattribute WHERE id in ( SELECT db_dbattribute.id @@ -45,6 +50,7 @@ class Migration(migrations.Migration): JOIN db_dbnode ON db_dbnode.id = db_dbattribute.dbnode_id WHERE db_dbattribute.key = 'hidden' AND db_dbnode.type = 'code.Code.' ); - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py b/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py index 2ca434e3e4..1a9317d0b1 100644 --- a/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py +++ b/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.9' DOWN_REVISION = '1.0.8' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0008_code_hidden_to_extra'), @@ -26,12 +27,14 @@ class Migration(migrations.Migration): # The base Data types Bool, Float, Int and Str have been moved in the source code, which means that their # module path changes, which determines the plugin type string which is stored in the databse. # The type string now will have a type string prefix that is unique to each sub type. - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dbnode SET type = 'data.bool.Bool.' WHERE type = 'data.base.Bool.'; UPDATE db_dbnode SET type = 'data.float.Float.' WHERE type = 'data.base.Float.'; UPDATE db_dbnode SET type = 'data.int.Int.' WHERE type = 'data.base.Int.'; UPDATE db_dbnode SET type = 'data.str.Str.' WHERE type = 'data.base.Str.'; UPDATE db_dbnode SET type = 'data.list.List.' 
WHERE type = 'data.base.List.'; - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0010_process_type.py b/aiida/backends/djsite/db/migrations/0010_process_type.py index 11fb34bd32..d1c36dc526 100644 --- a/aiida/backends/djsite/db/migrations/0010_process_type.py +++ b/aiida/backends/djsite/db/migrations/0010_process_type.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.10' DOWN_REVISION = '1.0.9' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0009_base_data_plugin_type_string'), @@ -24,9 +25,7 @@ class Migration(migrations.Migration): operations = [ migrations.AddField( - model_name='dbnode', - name='process_type', - field=models.CharField(max_length=255, db_index=True, null=True) + model_name='dbnode', name='process_type', field=models.CharField(max_length=255, db_index=True, null=True) ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py b/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py index 35cdeb715c..d3fcb91e1b 100644 --- a/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py +++ b/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py @@ -7,23 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.11' DOWN_REVISION = '1.0.10' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0010_process_type'), ] operations = [ - migrations.RunSQL(""" + migrations.RunSQL( + """ DROP TABLE IF EXISTS kombu_message; DROP TABLE IF EXISTS kombu_queue; DELETE FROM db_dbsetting WHERE key = 'daemon|user'; @@ -33,6 +35,7 @@ class Migration(migrations.Migration): DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|updater'; DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|submitter'; DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|submitter'; - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0012_drop_dblock.py b/aiida/backends/djsite/db/migrations/0012_drop_dblock.py index b89787f3f1..0c37ec8fd7 100644 --- a/aiida/backends/djsite/db/migrations/0012_drop_dblock.py +++ b/aiida/backends/djsite/db/migrations/0012_drop_dblock.py @@ -7,24 +7,20 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.12' DOWN_REVISION = '1.0.11' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', 
'0011_delete_kombu_tables'), ] - operations = [ - migrations.DeleteModel( - name='DbLock', - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] + operations = [migrations.DeleteModel(name='DbLock',), upgrade_schema_version(REVISION, DOWN_REVISION)] diff --git a/aiida/backends/djsite/db/migrations/0013_django_1_8.py b/aiida/backends/djsite/db/migrations/0013_django_1_8.py index 6448e3b924..17d5b3a196 100644 --- a/aiida/backends/djsite/db/migrations/0013_django_1_8.py +++ b/aiida/backends/djsite/db/migrations/0013_django_1_8.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.13' DOWN_REVISION = '1.0.12' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0012_drop_dblock'), diff --git a/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py b/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py index 48794a87c5..8d125f2196 100644 --- a/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py +++ b/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py @@ -18,7 +18,7 @@ DOWN_REVISION = '1.0.13' -def verify_node_uuid_uniqueness(apps, schema_editor): +def verify_node_uuid_uniqueness(_, __): """Check whether the database contains nodes with duplicate UUIDS. Note that we have to redefine this method from aiida.manage.database.integrity.verify_node_uuid_uniqueness @@ -31,7 +31,7 @@ def verify_node_uuid_uniqueness(apps, schema_editor): verify_uuid_uniqueness(table='db_dbnode') -def reverse_code(apps, schema_editor): +def reverse_code(_, __): pass diff --git a/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py b/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py index 62e8e446b9..f3ac3ca9c3 100644 --- a/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py +++ b/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py @@ -8,9 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,too-few-public-methods -""" -Invalidating node hash - User should rehash nodes for caching -""" +"""Invalidating node hash - User should rehash nodes for caching.""" # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed # pylint: disable=no-name-in-module,import-error diff --git a/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py b/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py index 611a324dc1..d1fe5fe1b2 100644 --- a/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py +++ b/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py @@ -7,7 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version @@ -16,6 +17,7 @@ class Migration(migrations.Migration): + """Database 
migration.""" dependencies = [ ('db', '0015_invalidating_node_hash'), @@ -26,6 +28,7 @@ class Migration(migrations.Migration): # To make everything fully consistent, its type string should therefore also start with `data.` migrations.RunSQL( sql="""UPDATE db_dbnode SET type = 'data.code.Code.' WHERE type = 'code.Code.';""", - reverse_sql="""UPDATE db_dbnode SET type = 'code.Code.' WHERE type = 'data.code.Code.';"""), + reverse_sql="""UPDATE db_dbnode SET type = 'code.Code.' WHERE type = 'data.code.Code.';""" + ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py b/aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py index 4f7bcd904e..eda7694481 100644 --- a/aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py +++ b/aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py @@ -7,7 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version @@ -16,6 +17,7 @@ class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0016_code_sub_class_of_data'), diff --git a/aiida/backends/djsite/db/migrations/0045_dbgroup_extras.py b/aiida/backends/djsite/db/migrations/0045_dbgroup_extras.py new file mode 100644 index 0000000000..8f6216ecb2 --- /dev/null +++ b/aiida/backends/djsite/db/migrations/0045_dbgroup_extras.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Migration to add the `extras` JSONB column to the `DbGroup` model.""" +# pylint: disable=invalid-name +import django.contrib.postgres.fields.jsonb +from django.db import migrations +from aiida.backends.djsite.db.migrations import upgrade_schema_version + +REVISION = '1.0.45' +DOWN_REVISION = '1.0.44' + + +class Migration(migrations.Migration): + """Migrate to add the extras column to the dbgroup table.""" + dependencies = [ + ('db', '0044_dbgroup_type_string'), + ] + + operations = [ + migrations.AddField( + model_name='dbgroup', + name='extras', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict, null=False), + ), + upgrade_schema_version(REVISION, DOWN_REVISION), + ] diff --git a/aiida/backends/djsite/db/migrations/__init__.py b/aiida/backends/djsite/db/migrations/__init__.py index 922c0f22fb..b361f04156 100644 --- a/aiida/backends/djsite/db/migrations/__init__.py +++ b/aiida/backends/djsite/db/migrations/__init__.py @@ -21,7 +21,7 @@ class DeserializationException(AiidaException): pass -LATEST_MIGRATION = '0044_dbgroup_type_string' +LATEST_MIGRATION = '0045_dbgroup_extras' def _update_schema_version(version, apps, _): @@ -273,7 +273,7 @@ def _deserialize_attribute(mainitem, subitems, sep, original_class=None, origina try: return json.loads(mainitem['tval']) except ValueError: - raise DeserializationException('Error in the content of the json field') + raise DeserializationException('Error in the content of the json field') from ValueError else: raise DeserializationException("The type field '{}' is not recognized".format(mainitem['datatype'])) @@ -426,7 +426,8 @@ def get_value_for_node(self, dbnode, key): try: attr = cls.objects.get(dbnode=dbnode_node, key=key) except ObjectDoesNotExist: - raise AttributeError('{} with key {} for node {} not found in db'.format(cls.__name__, key, dbnode.pk)) + raise AttributeError('{} with key {} for node {} not found in db'.format(cls.__name__, key, dbnode.pk)) \ + from ObjectDoesNotExist return self.getvalue(attr) @@ -674,7 +675,7 @@ def set_value( 'another entry already exists and the creation would ' 'violate an uniqueness constraint.\nFurther details: ' '{}'.format(cls.__name__, exc) - ) + ) from exc raise @staticmethod @@ -806,7 +807,7 @@ def create_value(self, key, value, subspecifier_value=None, other_attribs=None): raise ValueError( 'Unable to store the value: it must be either a basic datatype, or json-serializable: {}'. 
format(value) - ) + ) from TypeError new_entry.datatype = 'json' new_entry.tval = jsondata diff --git a/aiida/backends/djsite/db/models.py b/aiida/backends/djsite/db/models.py index c5c81924ad..bb2b5b380b 100644 --- a/aiida/backends/djsite/db/models.py +++ b/aiida/backends/djsite/db/models.py @@ -193,9 +193,7 @@ def __str__(self): return "'{}'={}".format(self.key, self.getvalue()) @classmethod - def set_value( - cls, key, value, with_transaction=True, subspecifier_value=None, other_attribs=None, stop_if_existing=False - ): + def set_value(cls, key, value, other_attribs=None, stop_if_existing=False): """Delete a setting value.""" other_attribs = other_attribs if other_attribs is not None else {} setting = DbSetting.objects.filter(key=key).first() @@ -221,7 +219,7 @@ def get_description(self): return self.description @classmethod - def del_value(cls, key, only_children=False, subspecifier_value=None): + def del_value(cls, key): """Set a setting value.""" setting = DbSetting.objects.filter(key=key).first() @@ -256,6 +254,8 @@ class DbGroup(m.Model): # On user deletion, remove his/her groups too (not the calcuations, only # the groups user = m.ForeignKey(DbUser, on_delete=m.CASCADE, related_name='dbgroups') + # JSON Extras + extras = JSONField(default=dict, null=False) class Meta: # pylint: disable=too-few-public-methods unique_together = (('label', 'type_string'),) @@ -300,7 +300,6 @@ class DbComputer(m.Model): name = m.CharField(max_length=255, unique=True, blank=False) hostname = m.CharField(max_length=255) description = m.TextField(blank=True) - # TODO: next three fields should not be blank... scheduler_type = m.CharField(max_length=255) transport_type = m.CharField(max_length=255) metadata = JSONField(default=dict) diff --git a/aiida/backends/djsite/manager.py b/aiida/backends/djsite/manager.py index 51a796af3a..c80ea0389d 100644 --- a/aiida/backends/djsite/manager.py +++ b/aiida/backends/djsite/manager.py @@ -181,7 +181,7 @@ def set(self, key, value, description=None): :param description: optional setting description """ from aiida.backends.djsite.db.models import DbSetting - from aiida.orm.utils.node import validate_attribute_extra_key + from aiida.orm.implementation.utils import validate_attribute_extra_key self.validate_table_existence() validate_attribute_extra_key(key) @@ -205,4 +205,4 @@ def delete(self, key): try: DbSetting.del_value(key=key) except KeyError: - raise NotExistent('setting `{}` does not exist'.format(key)) + raise NotExistent('setting `{}` does not exist'.format(key)) from KeyError diff --git a/aiida/backends/djsite/queries.py b/aiida/backends/djsite/queries.py index ff8121c7ab..53ed500305 100644 --- a/aiida/backends/djsite/queries.py +++ b/aiida/backends/djsite/queries.py @@ -110,6 +110,7 @@ def query_group(q_object, args): def get_bands_and_parents_structure(self, args): """Returns bands and closest parent structure.""" + # pylint: disable=too-many-locals from django.db.models import Q from aiida.backends.djsite.db import models from aiida.common.utils import grouper diff --git a/aiida/backends/general/abstractqueries.py b/aiida/backends/general/abstractqueries.py index 1851eca2c6..6e3b56812e 100644 --- a/aiida/backends/general/abstractqueries.py +++ b/aiida/backends/general/abstractqueries.py @@ -17,7 +17,7 @@ class AbstractQueryManager(abc.ABC): def __init__(self, backend): """ :param backend: The AiiDA backend - :type backend: :class:`aiida.orm.implementation.sql.SqlBackend` + :type backend: :class:`aiida.orm.implementation.sql.backends.SqlBackend` """ 
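Several hunks in this diff replace bare raises inside except blocks with explicit exception chaining via raise ... from. A minimal, self-contained sketch of the pattern, assuming a plain dict of settings and reusing aiida's NotExistent; the helper name and lookup are illustrative and not part of the diff:

    from aiida.common.exceptions import NotExistent


    def get_setting(settings, key):
        """Look up a setting, chaining the original error for easier debugging."""
        try:
            return settings[key]
        except KeyError as exc:
            # Re-raising `from exc` attaches the original traceback as __cause__;
            # the hunks above chain from the exception class instead, which is also valid.
            raise NotExistent('setting `{}` does not exist'.format(key)) from exc

Chaining from the bound instance (exc) keeps the original error and its message attached to the report, which is why it is the more common spelling of this idiom.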
self._backend = backend diff --git a/aiida/backends/sqlalchemy/manager.py b/aiida/backends/sqlalchemy/manager.py index d74da1affe..ab69457ee1 100644 --- a/aiida/backends/sqlalchemy/manager.py +++ b/aiida/backends/sqlalchemy/manager.py @@ -179,7 +179,7 @@ def get(self, key): try: setting = get_scoped_session().query(DbSetting).filter_by(key=key).one() except NoResultFound: - raise NotExistent('setting `{}` does not exist'.format(key)) + raise NotExistent('setting `{}` does not exist'.format(key)) from NoResultFound return Setting(key, setting.getvalue(), setting.description, setting.time) @@ -191,7 +191,7 @@ def set(self, key, value, description=None): :param description: optional setting description """ from aiida.backends.sqlalchemy.models.settings import DbSetting - from aiida.orm.utils.node import validate_attribute_extra_key + from aiida.orm.implementation.utils import validate_attribute_extra_key self.validate_table_existence() validate_attribute_extra_key(key) @@ -215,4 +215,4 @@ def delete(self, key): setting = get_scoped_session().query(DbSetting).filter_by(key=key).one() setting.delete() except NoResultFound: - raise NotExistent('setting `{}` does not exist'.format(key)) + raise NotExistent('setting `{}` does not exist'.format(key)) from NoResultFound diff --git a/aiida/backends/sqlalchemy/migrations/env.py b/aiida/backends/sqlalchemy/migrations/env.py index c18b73c2f6..e2d5f246c1 100644 --- a/aiida/backends/sqlalchemy/migrations/env.py +++ b/aiida/backends/sqlalchemy/migrations/env.py @@ -55,7 +55,7 @@ def run_migrations_online(): if connectable is None: from aiida.common.exceptions import ConfigurationError - raise ConfigurationError('An initialized connection is expected ' 'for the AiiDA online migrations.') + raise ConfigurationError('An initialized connection is expected for the AiiDA online migrations.') with connectable.connect() as connection: context.configure( diff --git a/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py b/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py index d38f1fe672..66d8f7e0a8 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py +++ b/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Drop the `transport_params` from the `Computer` database model. 
Revision ID: 07fac78e6209 diff --git a/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py b/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py index 0f3392ce6f..d73fd01407 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py +++ b/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Correct the type string for the base data types Revision ID: 0aebbeab274d @@ -17,7 +18,6 @@ from alembic import op from sqlalchemy.sql import text - # revision identifiers, used by Alembic. revision = '0aebbeab274d' down_revision = '7a6587e16f4c' @@ -26,29 +26,35 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # The base Data types Bool, Float, Int and Str have been moved in the source code, which means that their # module path changes, which determines the plugin type string which is stored in the databse. # The type string now will have a type string prefix that is unique to each sub type. - statement = text(""" + statement = text( + """ UPDATE db_dbnode SET type = 'data.bool.Bool.' WHERE type = 'data.base.Bool.'; UPDATE db_dbnode SET type = 'data.float.Float.' WHERE type = 'data.base.Float.'; UPDATE db_dbnode SET type = 'data.int.Int.' WHERE type = 'data.base.Int.'; UPDATE db_dbnode SET type = 'data.str.Str.' WHERE type = 'data.base.Str.'; UPDATE db_dbnode SET type = 'data.list.List.' WHERE type = 'data.base.List.'; - """) + """ + ) conn.execute(statement) def downgrade(): + """Migrations for the downgrade.""" conn = op.get_bind() - statement = text(""" + statement = text( + """ UPDATE db_dbnode SET type = 'data.base.Bool.' WHERE type = 'data.bool.Bool.'; UPDATE db_dbnode SET type = 'data.base.Float.' WHERE type = 'data.float.Float.'; UPDATE db_dbnode SET type = 'data.base.Int.' WHERE type = 'data.int.Int.'; UPDATE db_dbnode SET type = 'data.base.Str.' WHERE type = 'data.str.Str.'; UPDATE db_dbnode SET type = 'data.base.List.' WHERE type = 'data.list.List.'; - """) + """ + ) conn.execute(statement) diff --git a/aiida/backends/sqlalchemy/migrations/versions/0edcdd5a30f0_dbgroup_extras.py b/aiida/backends/sqlalchemy/migrations/versions/0edcdd5a30f0_dbgroup_extras.py new file mode 100644 index 0000000000..9b823afcf9 --- /dev/null +++ b/aiida/backends/sqlalchemy/migrations/versions/0edcdd5a30f0_dbgroup_extras.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=no-member,invalid-name +"""Migration to add the `extras` JSONB column to the `DbGroup` model. 
+ +Revision ID: 0edcdd5a30f0 +Revises: bf591f31dd12 +Create Date: 2019-04-03 14:38:50.585639 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql import text + +# revision identifiers, used by Alembic. +revision = '0edcdd5a30f0' +down_revision = 'bf591f31dd12' +branch_labels = None +depends_on = None + + +def upgrade(): + """Upgrade: Add the extras column to the 'db_dbgroup' table""" + op.add_column('db_dbgroup', sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()))) + op.execute(text(""" + UPDATE db_dbgroup + SET extras='{}' + """)) + op.alter_column('db_dbgroup', 'extras', nullable=False) + + +def downgrade(): + """Downgrade: Drop the extras column from the 'db_dbgroup' table""" + op.drop_column('db_dbgroup', 'extras') diff --git a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py b/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py index 835cdff566..70c331faa1 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py +++ b/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Move trajectory symbols from repository array to attribute Revision ID: 12536798d4d3 diff --git a/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py b/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py index 191d0e6935..888bf556be 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py +++ b/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Drop the DbCalcState table Revision ID: 162b99bca4a2 @@ -35,10 +36,10 @@ def downgrade(): sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint( - ['dbnode_id'], ['db_dbnode.id'], - name='db_dbcalcstate_dbnode_id_fkey', - ondelete='CASCADE', - initially='DEFERRED', - deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), - sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key')) + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], + name='db_dbcalcstate_dbnode_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), + sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key') + ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py b/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py index 7f42c2d91a..0e9587e5b3 100644 --- 
a/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py +++ b/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Drop the columns `nodeversion` and `public` from the `DbNode` model. Revision ID: 1830c8430131 diff --git a/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py b/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py index 16ca636185..af91d0e34c 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py +++ b/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Data migration for legacy `JobCalculations`. These old nodes have already been migrated to the correct `CalcJobNode` type in a previous migration, but they can diff --git a/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py b/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py index 7cacdbb518..8d417a4ffc 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py +++ b/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Migrating 'hidden' properties from DbAttribute to DbExtra for code.Code. nodes Revision ID: 35d4ee9a1b0e @@ -17,7 +18,6 @@ from alembic import op from sqlalchemy.sql import text - # revision identifiers, used by Alembic. revision = '35d4ee9a1b0e' down_revision = '89176227b25' @@ -26,14 +26,25 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # Set hidden=True in extras if the attributes contain hidden=True - statement = text("""UPDATE db_dbnode SET extras = jsonb_set(extras, '{"hidden"}', to_jsonb(True)) WHERE type = 'code.Code.' AND attributes @> '{"hidden": true}'""") + statement = text( + """ + UPDATE db_dbnode SET extras = jsonb_set(extras, '{"hidden"}', to_jsonb(True)) + WHERE type = 'code.Code.' AND attributes @> '{"hidden": true}' + """ + ) conn.execute(statement) # Set hidden=False in extras if the attributes contain hidden=False - statement = text("""UPDATE db_dbnode SET extras = jsonb_set(extras, '{"hidden"}', to_jsonb(False)) WHERE type = 'code.Code.' AND attributes @> '{"hidden": false}'""") + statement = text( + """ + UPDATE db_dbnode SET extras = jsonb_set(extras, '{"hidden"}', to_jsonb(False)) + WHERE type = 'code.Code.' 
AND attributes @> '{"hidden": false}' + """ + ) conn.execute(statement) # Delete the hidden key from the attributes @@ -42,14 +53,25 @@ def upgrade(): def downgrade(): + """Migrations for the downgrade.""" conn = op.get_bind() # Set hidden=True in attributes if the extras contain hidden=True - statement = text("""UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(True)) WHERE type = 'code.Code.' AND extras @> '{"hidden": true}'""") + statement = text( + """ + UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(True)) + WHERE type = 'code.Code.' AND extras @> '{"hidden": true}' + """ + ) conn.execute(statement) # Set hidden=False in attributes if the extras contain hidden=False - statement = text("""UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(False)) WHERE type = 'code.Code.' AND extras @> '{"hidden": false}'""") + statement = text( + """ + UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(False)) + WHERE type = 'code.Code.' AND extras @> '{"hidden": false}' + """ + ) conn.execute(statement) # Delete the hidden key from the extras diff --git a/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py b/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py index 76d6e1569d..bc43767eb1 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py +++ b/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Make all uuid columns unique Revision ID: 37f3d4882837 diff --git a/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py b/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py index f42d245587..c710703708 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py +++ b/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Adding indexes and constraints to the dbnode-dbgroup relationship table Revision ID: 59edaf8a8b79 @@ -25,30 +26,28 @@ def upgrade(): + """Migrations for the upgrade.""" # Check if constraint uix_dbnode_id_dbgroup_id of migration 7a6587e16f4c # is there and if yes, drop it insp = Inspector.from_engine(op.get_bind()) for constr in insp.get_unique_constraints('db_dbgroup_dbnodes'): if constr['name'] == 'uix_dbnode_id_dbgroup_id': - op.drop_constraint('uix_dbnode_id_dbgroup_id', - 'db_dbgroup_dbnodes') + op.drop_constraint('uix_dbnode_id_dbgroup_id', 'db_dbgroup_dbnodes') - op.create_index('db_dbgroup_dbnodes_dbnode_id_idx', 'db_dbgroup_dbnodes', - ['dbnode_id']) - op.create_index('db_dbgroup_dbnodes_dbgroup_id_idx', 'db_dbgroup_dbnodes', - ['dbgroup_id']) + op.create_index('db_dbgroup_dbnodes_dbnode_id_idx', 'db_dbgroup_dbnodes', ['dbnode_id']) + 
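The 59edaf8a8b79 hunk here first inspects the live schema and only drops the legacy uix_dbnode_id_dbgroup_id constraint if an earlier migration actually created it. A short sketch isolating that guard, assuming it runs inside an Alembic migration so op.get_bind() is available (the names are copied from the hunk):

    from alembic import op
    from sqlalchemy.engine.reflection import Inspector


    def drop_legacy_constraint_if_present():
        """Drop the old unique constraint only when the database actually has it."""
        inspector = Inspector.from_engine(op.get_bind())
        for constraint in inspector.get_unique_constraints('db_dbgroup_dbnodes'):
            if constraint['name'] == 'uix_dbnode_id_dbgroup_id':
                op.drop_constraint('uix_dbnode_id_dbgroup_id', 'db_dbgroup_dbnodes')

Checking first keeps the migration usable on databases that never received the optional 7a6587e16f4c constraint.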
op.create_index('db_dbgroup_dbnodes_dbgroup_id_idx', 'db_dbgroup_dbnodes', ['dbgroup_id']) op.create_unique_constraint( - 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', - ['dbgroup_id', 'dbnode_id']) + 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', ['dbgroup_id', 'dbnode_id'] + ) def downgrade(): + """Migrations for the downgrade.""" op.drop_index('db_dbgroup_dbnodes_dbnode_id_idx', 'db_dbgroup_dbnodes') op.drop_index('db_dbgroup_dbnodes_dbgroup_id_idx', 'db_dbgroup_dbnodes') - op.drop_constraint('db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', - 'db_dbgroup_dbnodes') + op.drop_constraint('db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes') # Creating the constraint uix_dbnode_id_dbgroup_id that migration # 7a6587e16f4c would add op.create_unique_constraint( - 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', - ['dbgroup_id', 'dbnode_id']) + 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', ['dbgroup_id', 'dbnode_id'] + ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py b/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py index 17661d95e2..354f52b6d3 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py +++ b/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Adding indices on the `input_id`, `output_id` and `type` column of the `DbLink` table Revision ID: 5a49629f0d45 diff --git a/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py b/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py index e988b836f0..4420d84cd6 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py +++ b/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Final data migration for `Nodes` after `aiida.orm.nodes` reorganization was finalized to remove the `node.` prefix Revision ID: 61fc0913fae9 diff --git a/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py b/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py index fb67adf396..81336fb16f 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py +++ b/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Add a unique constraint on the UUID column of the Node model Revision ID: 62fe0d36de90 @@ -31,26 +32,30 @@ def verify_node_uuid_uniqueness(): :raises: IntegrityError if database contains nodes with duplicate UUIDS. 
""" - from alembic import op from sqlalchemy.sql import text from aiida.common.exceptions import IntegrityError query = text( - 'SELECT s.id, s.uuid FROM (SELECT *, COUNT(*) OVER(PARTITION BY uuid) AS c FROM db_dbnode) AS s WHERE c > 1') + 'SELECT s.id, s.uuid FROM (SELECT *, COUNT(*) OVER(PARTITION BY uuid) AS c FROM db_dbnode) AS s WHERE c > 1' + ) conn = op.get_bind() duplicates = conn.execute(query).fetchall() if duplicates: table = 'db_dbnode' - command = '`verdi database integrity detect-duplicate-uuid {table}`'.format(table) - raise IntegrityError('Your table "{}" contains entries with duplicate UUIDS.\nRun {} ' - 'to return to a consistent state'.format(table, command)) + command = '`verdi database integrity detect-duplicate-uuid {table}`'.format(table=table) + raise IntegrityError( + 'Your table "{}" contains entries with duplicate UUIDS.\nRun {} ' + 'to return to a consistent state'.format(table, command) + ) def upgrade(): + """Migrations for the upgrade.""" verify_node_uuid_uniqueness() op.create_unique_constraint('db_dbnode_uuid_key', 'db_dbnode', ['uuid']) def downgrade(): + """Migrations for the downgrade.""" op.drop_constraint('db_dbnode_uuid_key', 'db_dbnode') diff --git a/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py b/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py index f6e32c800e..86160b0e46 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py +++ b/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Data migration for `Data` nodes after it was moved in the `aiida.orm.node` module changing the type string. Revision ID: 6a5c2ea1439d diff --git a/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py b/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py index 79fd26ca75..b29a4b7514 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py +++ b/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Add the process_type column to DbNode Revision ID: 6c629c886f84 @@ -17,7 +18,6 @@ from alembic import op import sqlalchemy as sa - # revision identifiers, used by Alembic. 
revision = '6c629c886f84' down_revision = '0aebbeab274d' @@ -26,11 +26,14 @@ def upgrade(): - op.add_column('db_dbnode', + """Migrations for the upgrade.""" + op.add_column( + 'db_dbnode', sa.Column('process_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), ) op.create_index('ix_db_dbnode_process_type', 'db_dbnode', ['process_type']) def downgrade(): + """Migrations for the downgrade.""" op.drop_column('db_dbnode', 'process_type') diff --git a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py index 98df687a2d..bd0ad4409f 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py +++ b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Deleting dbpath table and triggers Revision ID: 70c7d732f1b2 @@ -16,7 +17,6 @@ """ from alembic import op import sqlalchemy as sa -from sqlalchemy.dialects import postgresql from sqlalchemy.orm.session import Session from aiida.backends.sqlalchemy.utils import install_tc @@ -28,6 +28,7 @@ def upgrade(): + """Migrations for the upgrade.""" op.drop_table('db_dbpath') conn = op.get_bind() conn.execute('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink') @@ -35,17 +36,23 @@ def upgrade(): def downgrade(): - op.create_table('db_dbpath', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], name='db_dbpath_child_id_fkey', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], name='db_dbpath_parent_id_fkey', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') + """Migrations for the downgrade.""" + op.create_table( + 'db_dbpath', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], + name='db_dbpath_child_id_fkey', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], + name='db_dbpath_parent_id_fkey', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') ) # I get the session using the alembic connection # (Keep in mind that alembic uses the AiiDA SQLA diff --git 
a/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py b/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py index cd7ff83347..f3f8087837 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py +++ b/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Add indexes to dbworkflowdata table Revision ID: 89176227b25 @@ -24,10 +25,8 @@ def upgrade(): - op.create_index('ix_db_dbworkflowdata_aiida_obj_id', 'db_dbworkflowdata', - ['aiida_obj_id']) - op.create_index('ix_db_dbworkflowdata_parent_id', 'db_dbworkflowdata', - ['parent_id']) + op.create_index('ix_db_dbworkflowdata_aiida_obj_id', 'db_dbworkflowdata', ['aiida_obj_id']) + op.create_index('ix_db_dbworkflowdata_parent_id', 'db_dbworkflowdata', ['parent_id']) def downgrade(): diff --git a/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py b/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py index 1093abf660..24cd6c8be9 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py +++ b/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Drop the DbLock model Revision ID: a514d673c163 @@ -18,7 +19,6 @@ from sqlalchemy.dialects import postgresql import sqlalchemy as sa - # revision identifiers, used by Alembic. 
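Each SQLAlchemy version file in this directory pins its position in the Alembic history through the revision/down_revision pair; the new 0edcdd5a30f0 group-extras migration above chains onto bf591f31dd12 in exactly this way. A bare skeleton of such a file, with purely illustrative identifiers and table name:

    """Example Alembic version file (illustrative identifiers only)."""
    from alembic import op
    import sqlalchemy as sa

    # revision identifiers, used by Alembic.
    revision = 'aaaaaaaaaaaa'       # identifier of this migration
    down_revision = 'bbbbbbbbbbbb'  # the migration this one builds on
    branch_labels = None
    depends_on = None


    def upgrade():
        """Apply the schema change."""
        op.add_column('db_dbexample', sa.Column('note', sa.TEXT(), nullable=True))


    def downgrade():
        """Revert the schema change."""
        op.drop_column('db_dbexample', 'note')

Alembic walks this chain to decide which upgrade()/downgrade() functions to run when moving a profile between schema versions.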
revision = 'a514d673c163' down_revision = 'f9a69de76a9a' @@ -31,10 +31,10 @@ def upgrade(): def downgrade(): - op.create_table('db_dblock', - sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') + op.create_table( + 'db_dblock', sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py b/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py index 7342df1e61..71203f6f28 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py +++ b/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Correct the type string for the code class Revision ID: a603da2cc809 @@ -25,6 +26,7 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # The Code class used to be just a sub class of Node but was changed to act like a Data node. @@ -36,6 +38,7 @@ def upgrade(): def downgrade(): + """Migrations for the downgrade.""" conn = op.get_bind() statement = text(""" diff --git a/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py b/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py index 7636241e41..440d41cf20 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py +++ b/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Updating link types - This is a copy of the Django migration script Revision ID: a6048f0ffca8 @@ -25,6 +26,7 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # I am first migrating the wrongly declared returnlinks out of @@ -40,7 +42,8 @@ def upgrade(): # - joins a Data (or subclass) as output # - is marked as a returnlink. 
# 2) set for these links the type to 'createlink' - stmt1 = text(""" + stmt1 = text( + """ UPDATE db_dblink set type='createlink' WHERE db_dblink.id IN ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -50,7 +53,8 @@ def upgrade(): AND db_dbnode_2.type LIKE 'data.%' AND db_dblink_1.type = 'returnlink' ) - """) + """ + ) conn.execute(stmt1) # Now I am updating the link-types that are null because of either an export and subsequent import # https://github.com/aiidateam/aiida-core/issues/685 @@ -63,7 +67,8 @@ def upgrade(): # - joins Calculation (or subclass) as output. This includes WorkCalculation, InlineCalcuation, JobCalculations... # - has no type (null) # 2) set for these links the type to 'inputlink' - stmt2 = text(""" + stmt2 = text( + """ UPDATE db_dblink set type='inputlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -73,7 +78,8 @@ def upgrade(): AND db_dbnode_2.type LIKE 'calculation.%' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """) + """ + ) conn.execute(stmt2) # # The following sql statement: @@ -82,7 +88,8 @@ def upgrade(): # - joins Data (or subclass) as output. # - has no type (null) # 2) set for these links the type to 'createlink' - stmt3 = text(""" + stmt3 = text( + """ UPDATE db_dblink set type='createlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -96,7 +103,8 @@ def upgrade(): ) AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ) - """) + """ + ) conn.execute(stmt3) # The following sql statement: # 1) selects all links that @@ -104,7 +112,8 @@ def upgrade(): # - join Data (or subclass) as output. # - has no type (null) # 2) set for these links the type to 'returnlink' - stmt4 = text(""" + stmt4 = text( + """ UPDATE db_dblink set type='returnlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -114,7 +123,8 @@ def upgrade(): AND db_dbnode_1.type = 'calculation.work.WorkCalculation.' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ) - """) + """ + ) conn.execute(stmt4) # Now I update links that are CALLS: # The following sql statement: @@ -123,7 +133,8 @@ def upgrade(): # - join Calculation (or subclass) as output. Includes JobCalculation and WorkCalculations and all subclasses. # - has no type (null) # 2) set for these links the type to 'calllink' - stmt5 = text(""" + stmt5 = text( + """ UPDATE db_dblink set type='calllink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -133,9 +144,11 @@ def upgrade(): AND db_dbnode_2.type LIKE 'calculation.%' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ) - """) + """ + ) conn.execute(stmt5) def downgrade(): + """Migrations for the downgrade.""" print('There is no downgrade for the link types') diff --git a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py index 626b561c12..6d71cd55f6 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py +++ b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Migration after the `Group` class became pluginnable and so the group `type_string` changed. 
Revision ID: bf591f31dd12 diff --git a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py index cf5932e79e..682e4371ad 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py +++ b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Delete trajectory symbols array from the repository and the reference in the attributes Revision ID: ce56d84bcc35 @@ -14,8 +15,6 @@ Create Date: 2019-01-21 15:35:07.280805 """ -# pylint: disable=invalid-name - # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed # pylint: disable=no-member,no-name-in-module,import-error diff --git a/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py b/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py index 071e548ce3..87a1aa8fc0 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py +++ b/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Data migration for after `ParameterData` was renamed to `Dict`. Revision ID: d254fdfed416 @@ -14,7 +15,6 @@ Create Date: 2019-02-25 19:29:11.753089 """ -# pylint: disable=invalid-name,no-member,import-error,no-name-in-module from alembic import op from sqlalchemy.sql import text diff --git a/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py b/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py index 143054fe7b..a154d0f019 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py +++ b/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member,import-error,no-name-in-module """Drop various columns from the `DbUser` model. 
These columns were part of the default Django user model @@ -16,7 +17,6 @@ Create Date: 2019-05-28 11:15:33.242602 """ -# pylint: disable=invalid-name,no-member,import-error,no-name-in-module from alembic import op import sqlalchemy as sa diff --git a/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py b/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py index c48ec01f46..ab4b00f560 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py +++ b/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Initial schema Revision ID: e15ef2630a1b @@ -28,233 +29,290 @@ def upgrade(): - op.create_table('db_dbuser', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('email', sa.VARCHAR(length=254), autoincrement=False, nullable=True), - sa.Column('password', sa.VARCHAR(length=128), autoincrement=False, nullable=True), - sa.Column('is_superuser', sa.BOOLEAN(), autoincrement=False, nullable=False), - sa.Column('first_name', sa.VARCHAR(length=254), autoincrement=False, nullable=True), - sa.Column('last_name', sa.VARCHAR(length=254), autoincrement=False, nullable=True), - sa.Column('institution', sa.VARCHAR(length=254), autoincrement=False, nullable=True), - sa.Column('is_staff', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column('is_active', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column('last_login', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('date_joined', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='db_dbuser_pkey'), - postgresql_ignore_search_path=False + """Migrations for the upgrade.""" + op.create_table( + 'db_dbuser', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('email', sa.VARCHAR(length=254), autoincrement=False, nullable=True), + sa.Column('password', sa.VARCHAR(length=128), autoincrement=False, nullable=True), + sa.Column('is_superuser', sa.BOOLEAN(), autoincrement=False, nullable=False), + sa.Column('first_name', sa.VARCHAR(length=254), autoincrement=False, nullable=True), + sa.Column('last_name', sa.VARCHAR(length=254), autoincrement=False, nullable=True), + sa.Column('institution', sa.VARCHAR(length=254), autoincrement=False, nullable=True), + sa.Column('is_staff', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('is_active', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('last_login', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('date_joined', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='db_dbuser_pkey'), + postgresql_ignore_search_path=False ) op.create_index('ix_db_dbuser_email', 'db_dbuser', ['email'], unique=True) - op.create_table('db_dbworkflow', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, 
nullable=True), - sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('lastsyncedversion', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('report', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('module', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('module_class', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('script_path', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('script_md5', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflow_user_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflow_pkey'), - postgresql_ignore_search_path=False + op.create_table( + 'db_dbworkflow', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('lastsyncedversion', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('report', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('module', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('module_class', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('script_path', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('script_md5', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflow_user_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflow_pkey'), + postgresql_ignore_search_path=False ) op.create_index('ix_db_dbworkflow_label', 'db_dbworkflow', ['label']) - op.create_table('db_dbworkflowstep', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('nextcall', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowstep_parent_id_fkey'), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflowstep_user_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_pkey'), - sa.UniqueConstraint('parent_id', 'name', name='db_dbworkflowstep_parent_id_name_key'), - postgresql_ignore_search_path=False + op.create_table( + 'db_dbworkflowstep', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('parent_id', 
sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('nextcall', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowstep_parent_id_fkey'), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflowstep_user_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_pkey'), + sa.UniqueConstraint('parent_id', 'name', name='db_dbworkflowstep_parent_id_name_key'), + postgresql_ignore_search_path=False ) - op.create_table('db_dbcomputer', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('hostname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('enabled', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column('transport_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('scheduler_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('transport_params', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='db_dbcomputer_pkey'), - sa.UniqueConstraint('name', name='db_dbcomputer_name_key') + op.create_table( + 'db_dbcomputer', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('hostname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('enabled', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('transport_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('scheduler_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('transport_params', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='db_dbcomputer_pkey'), + sa.UniqueConstraint('name', name='db_dbcomputer_name_key') ) - op.create_table('db_dbauthinfo', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('aiidauser_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbcomputer_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('auth_params', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('enabled', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['aiidauser_id'], ['db_dbuser.id'], name='db_dbauthinfo_aiidauser_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['dbcomputer_id'], ['db_dbcomputer.id'], name='db_dbauthinfo_dbcomputer_id_fkey', ondelete='CASCADE', 
initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbauthinfo_pkey'), - sa.UniqueConstraint('aiidauser_id', 'dbcomputer_id', name='db_dbauthinfo_aiidauser_id_dbcomputer_id_key') + op.create_table( + 'db_dbauthinfo', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('aiidauser_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('dbcomputer_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('auth_params', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('enabled', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['aiidauser_id'], ['db_dbuser.id'], + name='db_dbauthinfo_aiidauser_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['dbcomputer_id'], ['db_dbcomputer.id'], + name='db_dbauthinfo_dbcomputer_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbauthinfo_pkey'), + sa.UniqueConstraint('aiidauser_id', 'dbcomputer_id', name='db_dbauthinfo_aiidauser_id_dbcomputer_id_key') ) - op.create_table('db_dbgroup', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbgroup_user_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbgroup_pkey'), - sa.UniqueConstraint('name', 'type', name='db_dbgroup_name_type_key') + op.create_table( + 'db_dbgroup', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], + name='db_dbgroup_user_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbgroup_pkey'), + sa.UniqueConstraint('name', 'type', name='db_dbgroup_name_type_key') ) op.create_index('ix_db_dbgroup_name', 'db_dbgroup', ['name']) op.create_index('ix_db_dbgroup_type', 'db_dbgroup', ['type']) - op.create_table('db_dbnode', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('mtime', 
postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('public', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column('attributes', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('extras', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('dbcomputer_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint(['dbcomputer_id'], ['db_dbcomputer.id'], name='db_dbnode_dbcomputer_id_fkey', ondelete='RESTRICT', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbnode_user_id_fkey', ondelete='RESTRICT', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbnode_pkey'),postgresql_ignore_search_path=False + op.create_table( + 'db_dbnode', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('public', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('attributes', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('extras', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('dbcomputer_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(['dbcomputer_id'], ['db_dbcomputer.id'], + name='db_dbnode_dbcomputer_id_fkey', + ondelete='RESTRICT', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], + name='db_dbnode_user_id_fkey', + ondelete='RESTRICT', + initially='DEFERRED', + deferrable=True), + sa.PrimaryKeyConstraint('id', name='db_dbnode_pkey'), + postgresql_ignore_search_path=False ) op.create_index('ix_db_dbnode_label', 'db_dbnode', ['label']) op.create_index('ix_db_dbnode_type', 'db_dbnode', ['type']) - op.create_table('db_dbgroup_dbnodes', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbgroup_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbgroup_id'], ['db_dbgroup.id'], name='db_dbgroup_dbnodes_dbgroup_id_fkey', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbgroup_dbnodes_dbnode_id_fkey', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbgroup_dbnodes_pkey') + op.create_table( + 'db_dbgroup_dbnodes', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('dbgroup_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbgroup_id'], ['db_dbgroup.id'], + name='db_dbgroup_dbnodes_dbgroup_id_fkey', + initially='DEFERRED', + deferrable=True), + 
sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], + name='db_dbgroup_dbnodes_dbnode_id_fkey', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbgroup_dbnodes_pkey') ) - op.create_table('db_dblock', - sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') + op.create_table( + 'db_dblock', sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') ) - op.create_table('db_dbworkflowdata', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('data_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('value_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('json_value', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('aiida_obj_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['aiida_obj_id'], ['db_dbnode.id'], name='db_dbworkflowdata_aiida_obj_id_fkey'), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowdata_parent_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowdata_pkey'), - sa.UniqueConstraint('parent_id', 'name', 'data_type', name='db_dbworkflowdata_parent_id_name_data_type_key') + op.create_table( + 'db_dbworkflowdata', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('data_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('value_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('json_value', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('aiida_obj_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['aiida_obj_id'], ['db_dbnode.id'], name='db_dbworkflowdata_aiida_obj_id_fkey'), + sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowdata_parent_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowdata_pkey'), + sa.UniqueConstraint('parent_id', 'name', 'data_type', name='db_dbworkflowdata_parent_id_name_data_type_key') ) - op.create_table('db_dblink', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('input_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('output_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, 
nullable=True), - sa.ForeignKeyConstraint(['input_id'], ['db_dbnode.id'], name='db_dblink_input_id_fkey', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['output_id'], ['db_dbnode.id'], name='db_dblink_output_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dblink_pkey'), + op.create_table( + 'db_dblink', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('input_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('output_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['input_id'], ['db_dbnode.id'], + name='db_dblink_input_id_fkey', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['output_id'], ['db_dbnode.id'], + name='db_dblink_output_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), + sa.PrimaryKeyConstraint('id', name='db_dblink_pkey'), ) op.create_index('ix_db_dblink_label', 'db_dblink', ['label']) - op.create_table('db_dbworkflowstep_calculations', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbworkflowstep_calculations_dbnode_id_fkey'), - sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], name='db_dbworkflowstep_calculations_dbworkflowstep_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_calculations_pkey'), - sa.UniqueConstraint('dbworkflowstep_id', 'dbnode_id', name='db_dbworkflowstep_calculations_id_dbnode_id_key') + op.create_table( + 'db_dbworkflowstep_calculations', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbworkflowstep_calculations_dbnode_id_fkey'), + sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], + name='db_dbworkflowstep_calculations_dbworkflowstep_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_calculations_pkey'), + sa.UniqueConstraint('dbworkflowstep_id', 'dbnode_id', name='db_dbworkflowstep_calculations_id_dbnode_id_key') ) - op.create_table('db_dbpath', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], name='db_dbpath_child_id_fkey', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], name='db_dbpath_parent_id_fkey', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') + op.create_table( + 'db_dbpath', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('parent_id', 
sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], + name='db_dbpath_child_id_fkey', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], + name='db_dbpath_parent_id_fkey', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') ) - op.create_table('db_dbcalcstate', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbcalcstate_dbnode_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), - sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key') + op.create_table( + 'db_dbcalcstate', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], + name='db_dbcalcstate_dbnode_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), + sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key') ) op.create_index('ix_db_dbcalcstate_state', 'db_dbcalcstate', ['state']) - op.create_table('db_dbsetting', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('val', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('description', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='db_dbsetting_pkey'), - sa.UniqueConstraint('key', name='db_dbsetting_key_key') + op.create_table( + 'db_dbsetting', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('val', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('description', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='db_dbsetting_pkey'), + sa.UniqueConstraint('key', name='db_dbsetting_key_key') ) op.create_index('ix_db_dbsetting_key', 'db_dbsetting', ['key']) - op.create_table('db_dbcomment', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('ctime', 
postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('content', sa.TEXT(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbcomment_dbnode_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbcomment_user_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbcomment_pkey') + op.create_table( + 'db_dbcomment', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('content', sa.TEXT(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], + name='db_dbcomment_dbnode_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], + name='db_dbcomment_user_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcomment_pkey') ) - op.create_table('db_dblog', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('loggername', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('levelname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('objname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('objpk', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('message', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='db_dblog_pkey') + op.create_table( + 'db_dblog', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('loggername', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('levelname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('objname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('objpk', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('message', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='db_dblog_pkey') ) op.create_index('ix_db_dblog_levelname', 'db_dblog', ['levelname']) op.create_index('ix_db_dblog_loggername', 'db_dblog', ['loggername']) op.create_index('ix_db_dblog_objname', 'db_dblog', ['objname']) op.create_index('ix_db_dblog_objpk', 'db_dblog', ['objpk']) - op.create_table('db_dbworkflowstep_sub_workflows', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbworkflow_id', sa.INTEGER(), 
autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbworkflow_id'], ['db_dbworkflow.id'], name='db_dbworkflowstep_sub_workflows_dbworkflow_id_fkey'), - sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], name='db_dbworkflowstep_sub_workflows_dbworkflowstep_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_sub_workflows_pkey'), - sa.UniqueConstraint('dbworkflowstep_id', 'dbworkflow_id', name='db_dbworkflowstep_sub_workflows_id_dbworkflow__key') + op.create_table( + 'db_dbworkflowstep_sub_workflows', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('dbworkflow_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbworkflow_id'], ['db_dbworkflow.id'], + name='db_dbworkflowstep_sub_workflows_dbworkflow_id_fkey'), + sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], + name='db_dbworkflowstep_sub_workflows_dbworkflowstep_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_sub_workflows_pkey'), + sa.UniqueConstraint( + 'dbworkflowstep_id', 'dbworkflow_id', name='db_dbworkflowstep_sub_workflows_id_dbworkflow__key' + ) ) # I get the session using the alembic connection # (Keep in mind that alembic uses the AiiDA SQLA @@ -264,6 +322,7 @@ def upgrade(): def downgrade(): + """Migrations for the downgrade.""" op.drop_table('db_dbworkflowstep_calculations') op.drop_table('db_dbworkflowstep_sub_workflows') op.drop_table('db_dbworkflowdata') diff --git a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py b/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py index 3b0b4c32f8..3d34308429 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py +++ b/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,no-member -# pylint: disable=no-name-in-module,import-error +# pylint: disable=invalid-name,no-member,no-name-in-module,import-error """This migration creates UUID column and populates it with distinct UUIDs This migration corresponds to the 0024_dblog_update Django migration. diff --git a/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py b/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py index ff38db86f6..a6543778a4 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py +++ b/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Delete the kombu tables that were used by the old Celery based daemon and the obsolete related timestamps Revision ID: f9a69de76a9a @@ -17,7 +18,6 @@ from alembic import op from sqlalchemy.sql import text - # revision identifiers, used by Alembic. 
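For orientation, the migration modules touched above and below all follow the standard Alembic revision layout: module-level revision identifiers plus an upgrade/downgrade pair built on the alembic op API (or, as in the delete_kombu_tables revision that follows, raw SQL executed through op.get_bind()). The sketch below is purely illustrative; the revision hashes and the table name are invented and are not part of this changeset.

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic (hypothetical values for illustration).
revision = '0123456789ab'
down_revision = 'ba9876543210'


def upgrade():
    """Migrations for the upgrade: create an example table."""
    op.create_table(
        'db_example',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('label', sa.String(length=255), nullable=True),
        sa.PrimaryKeyConstraint('id', name='db_example_pkey'),
    )


def downgrade():
    """Migrations for the downgrade: drop the example table again."""
    op.drop_table('db_example')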
revision = 'f9a69de76a9a' down_revision = '6c629c886f84' @@ -26,10 +26,12 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # Drop the kombu tables and delete the old timestamps and user related to the daemon in the DbSetting table - statement = text(""" + statement = text( + """ DROP TABLE IF EXISTS kombu_message; DROP TABLE IF EXISTS kombu_queue; DROP SEQUENCE IF EXISTS message_id_sequence; @@ -41,9 +43,11 @@ def upgrade(): DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|updater'; DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|submitter'; DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|submitter'; - """) + """ + ) conn.execute(statement) def downgrade(): + """Migrations for the downgrade.""" print('There is no downgrade for the deletion of the kombu tables and the daemon timestamps') diff --git a/aiida/backends/sqlalchemy/models/base.py b/aiida/backends/sqlalchemy/models/base.py index 8721455bfd..73a7cba6cf 100644 --- a/aiida/backends/sqlalchemy/models/base.py +++ b/aiida/backends/sqlalchemy/models/base.py @@ -44,7 +44,7 @@ class _SessionProperty: def __get__(self, obj, _type): if not aiida.backends.sqlalchemy.get_scoped_session(): - raise InvalidOperation('You need to call load_dbenv before ' 'accessing the session of SQLALchemy.') + raise InvalidOperation('You need to call load_dbenv before accessing the session of SQLALchemy.') return aiida.backends.sqlalchemy.get_scoped_session() diff --git a/aiida/backends/sqlalchemy/models/computer.py b/aiida/backends/sqlalchemy/models/computer.py index 5b4befde47..e53f052ebf 100644 --- a/aiida/backends/sqlalchemy/models/computer.py +++ b/aiida/backends/sqlalchemy/models/computer.py @@ -9,7 +9,6 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Module to manage computers for the SQLA backend.""" - from sqlalchemy.dialects.postgresql import UUID, JSONB from sqlalchemy.schema import Column from sqlalchemy.types import Integer, String, Text @@ -34,7 +33,6 @@ class DbComputer(Base): def __init__(self, *args, **kwargs): """Provide _metadata and description attributes to the class.""" self._metadata = {} - # TODO SP: it's supposed to be nullable, but there is a NOT constraint inside the DB. 
self.description = '' # If someone passes metadata in **kwargs we change it to _metadata diff --git a/aiida/backends/sqlalchemy/models/group.py b/aiida/backends/sqlalchemy/models/group.py index 95126f1c4d..22d422968f 100644 --- a/aiida/backends/sqlalchemy/models/group.py +++ b/aiida/backends/sqlalchemy/models/group.py @@ -15,7 +15,7 @@ from sqlalchemy.schema import Column, Table, UniqueConstraint, Index from sqlalchemy.types import Integer, String, DateTime, Text -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql import UUID, JSONB from aiida.common import timezone from aiida.common.utils import get_new_uuid @@ -47,6 +47,8 @@ class DbGroup(Base): time = Column(DateTime(timezone=True), default=timezone.now) description = Column(Text, nullable=True) + extras = Column(JSONB, default=dict, nullable=False) + user_id = Column(Integer, ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED')) user = relationship('DbUser', backref=backref('dbgroups', cascade='merge')) diff --git a/aiida/backends/sqlalchemy/models/node.py b/aiida/backends/sqlalchemy/models/node.py index b1ccce8eba..48d0cae220 100644 --- a/aiida/backends/sqlalchemy/models/node.py +++ b/aiida/backends/sqlalchemy/models/node.py @@ -53,6 +53,7 @@ class DbNode(Base): Integer, ForeignKey('db_dbuser.id', deferrable=True, initially='DEFERRED', ondelete='restrict'), nullable=False ) + # pylint: disable=fixme # TODO SP: The 'passive_deletes=all' argument here means that SQLAlchemy # won't take care of automatic deleting in the DbLink table. This still # isn't exactly the same behaviour than with Django. The solution to @@ -108,7 +109,7 @@ def outputs(self): @property def inputs(self): - return self.inputs_q.all() + return self.inputs_q.all() # pylint: disable=no-member def get_simple_name(self, invalid_result=None): """ diff --git a/aiida/backends/sqlalchemy/models/settings.py b/aiida/backends/sqlalchemy/models/settings.py index d5b328999e..eac5433a28 100644 --- a/aiida/backends/sqlalchemy/models/settings.py +++ b/aiida/backends/sqlalchemy/models/settings.py @@ -9,7 +9,6 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Module to manage node settings for the SQLA backend.""" - from pytz import UTC from sqlalchemy import Column @@ -40,9 +39,7 @@ def __str__(self): return "'{}'={}".format(self.key, self.getvalue()) @classmethod - def set_value( - cls, key, value, with_transaction=True, subspecifier_value=None, other_attribs=None, stop_if_existing=False - ): + def set_value(cls, key, value, other_attribs=None, stop_if_existing=False): """Set a setting value.""" other_attribs = other_attribs if other_attribs is not None else {} setting = sa.get_scoped_session().query(DbSetting).filter_by(key=key).first() diff --git a/aiida/backends/testimplbase.py b/aiida/backends/testimplbase.py index 32f22e0333..83603e9c42 100644 --- a/aiida/backends/testimplbase.py +++ b/aiida/backends/testimplbase.py @@ -69,10 +69,10 @@ def create_user(self): def create_computer(self): """This method creates and stores a computer.""" self.computer = orm.Computer( - name='localhost', + label='localhost', hostname='localhost', transport_type='local', - scheduler_type='pbspro', + scheduler_type='direct', workdir='/tmp/aiida', backend=self.backend ).store() diff --git a/aiida/backends/utils.py b/aiida/backends/utils.py index ebbe03912c..ebea81ab48 100644 --- a/aiida/backends/utils.py +++ b/aiida/backends/utils.py @@ -24,12 +24,16 @@ def 
create_sqlalchemy_engine(profile, **kwargs): from sqlalchemy import create_engine from aiida.common import json + # The hostname may be `None`, which is a valid value in the case of peer authentication for example. In this case + # it should be converted to an empty string, because otherwise the `None` will be converted to string literal "None" + hostname = profile.database_hostname or '' separator = ':' if profile.database_port else '' + engine_url = 'postgresql://{user}:{password}@{hostname}{separator}{port}/{name}'.format( separator=separator, user=profile.database_username, password=profile.database_password, - hostname=profile.database_hostname, + hostname=hostname, port=profile.database_port, name=profile.database_name ) diff --git a/aiida/calculations/arithmetic/add.py b/aiida/calculations/arithmetic/add.py index 54f68d8200..6fccfe8d49 100644 --- a/aiida/calculations/arithmetic/add.py +++ b/aiida/calculations/arithmetic/add.py @@ -24,15 +24,15 @@ def define(cls, spec: CalcJobProcessSpec): :param spec: the calculation job process spec to define. """ super().define(spec) + spec.input('x', valid_type=(orm.Int, orm.Float), help='The left operand.') + spec.input('y', valid_type=(orm.Int, orm.Float), help='The right operand.') + spec.output('sum', valid_type=(orm.Int, orm.Float), help='The sum of the left and right operand.') + # set default options (optional) spec.inputs['metadata']['options']['parser_name'].default = 'arithmetic.add' spec.inputs['metadata']['options']['input_filename'].default = 'aiida.in' spec.inputs['metadata']['options']['output_filename'].default = 'aiida.out' spec.inputs['metadata']['options']['resources'].default = {'num_machines': 1, 'num_mpiprocs_per_machine': 1} - spec.input('x', valid_type=(orm.Int, orm.Float), help='The left operand.') - spec.input('y', valid_type=(orm.Int, orm.Float), help='The right operand.') - spec.output('sum', valid_type=(orm.Int, orm.Float), help='The sum of the left and right operand.') # start exit codes - marker for docs - spec.exit_code(300, 'ERROR_NO_RETRIEVED_FOLDER', message='The retrieved output node does not exist.') spec.exit_code(310, 'ERROR_READING_OUTPUT_FILE', message='The output file could not be read.') spec.exit_code(320, 'ERROR_INVALID_OUTPUT', message='The output file contains invalid output.') spec.exit_code(410, 'ERROR_NEGATIVE_NUMBER', message='The sum of the operands is a negative number.') diff --git a/aiida/calculations/templatereplacer.py b/aiida/calculations/templatereplacer.py index aa98da7321..f84e715b10 100644 --- a/aiida/calculations/templatereplacer.py +++ b/aiida/calculations/templatereplacer.py @@ -68,13 +68,11 @@ def define(cls, spec): help='A template for the input file.') spec.input('parameters', valid_type=orm.Dict, required=False, help='Parameters used to replace placeholders in the template.') - spec.input_namespace('files', valid_type=(orm.RemoteData, orm.SinglefileData), required=False) + spec.input_namespace('files', valid_type=(orm.RemoteData, orm.SinglefileData), required=False, dynamic=True) spec.output('output_parameters', valid_type=orm.Dict, required=True) spec.default_output_node = 'output_parameters' - spec.exit_code(100, 'ERROR_NO_RETRIEVED_FOLDER', - message='The retrieved folder data node could not be accessed.') spec.exit_code(101, 'ERROR_NO_TEMPORARY_RETRIEVED_FOLDER', message='The temporary retrieved folder data node could not be accessed.') spec.exit_code(105, 'ERROR_NO_OUTPUT_FILE_NAME_DEFINED', @@ -86,6 +84,7 @@ def define(cls, spec): spec.exit_code(120, 
'ERROR_INVALID_OUTPUT', message='The output file contains invalid output.') + def prepare_for_submission(self, folder): """ This is the routine to be called when you want to create the input files and related stuff with a plugin. diff --git a/aiida/cmdline/commands/cmd_calcjob.py b/aiida/cmdline/commands/cmd_calcjob.py index 2ee966fc3f..7a59444558 100644 --- a/aiida/cmdline/commands/cmd_calcjob.py +++ b/aiida/cmdline/commands/cmd_calcjob.py @@ -271,4 +271,4 @@ def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force): clean_remote(transport, path) counter += 1 - echo.echo_success('{} remote folders cleaned on {}'.format(counter, computer.name)) + echo.echo_success('{} remote folders cleaned on {}'.format(counter, computer.label)) diff --git a/aiida/cmdline/commands/cmd_code.py b/aiida/cmdline/commands/cmd_code.py index 4a77b61462..80cfb943e4 100644 --- a/aiida/cmdline/commands/cmd_code.py +++ b/aiida/cmdline/commands/cmd_code.py @@ -8,8 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi code` command.""" - from functools import partial + import click import tabulate @@ -18,7 +18,6 @@ from aiida.cmdline.params.options.commands import code as options_code from aiida.cmdline.utils import echo from aiida.cmdline.utils.decorators import with_dbenv -from aiida.cmdline.utils.multi_line_input import ensure_scripts from aiida.common.exceptions import InputValidationError @@ -45,7 +44,7 @@ def get_default(key, ctx): def get_computer_name(ctx): - return getattr(ctx.code_builder, 'computer').name + return getattr(ctx.code_builder, 'computer').label def get_on_computer(ctx): @@ -69,8 +68,8 @@ def set_code_builder(ctx, param, value): @options_code.REMOTE_ABS_PATH() @options_code.FOLDER() @options_code.REL_PATH() -@options.PREPEND_TEXT() -@options.APPEND_TEXT() +@options_code.PREPEND_TEXT() +@options_code.APPEND_TEXT() @options.NON_INTERACTIVE() @options.CONFIG_FILE() @with_dbenv() @@ -79,15 +78,6 @@ def setup_code(non_interactive, **kwargs): from aiida.common.exceptions import ValidationError from aiida.orm.utils.builders.code import CodeBuilder - if not non_interactive: - try: - pre, post = ensure_scripts(kwargs.pop('prepend_text', ''), kwargs.pop('append_text', ''), kwargs) - except InputValidationError as exception: - raise click.BadParameter('invalid prepend and or append text: {}'.format(exception)) - - kwargs['prepend_text'] = pre - kwargs['append_text'] = post - if kwargs.pop('on_computer'): kwargs['code_type'] = CodeBuilder.CodeType.ON_COMPUTER else: @@ -119,8 +109,8 @@ def setup_code(non_interactive, **kwargs): @options_code.REMOTE_ABS_PATH(contextual_default=partial(get_default, 'remote_abs_path')) @options_code.FOLDER(contextual_default=partial(get_default, 'code_folder')) @options_code.REL_PATH(contextual_default=partial(get_default, 'code_rel_path')) -@options.PREPEND_TEXT(cls=options.ContextualDefaultOption, contextual_default=partial(get_default, 'prepend_text')) -@options.APPEND_TEXT(cls=options.ContextualDefaultOption, contextual_default=partial(get_default, 'append_text')) +@options_code.PREPEND_TEXT(contextual_default=partial(get_default, 'prepend_text')) +@options_code.APPEND_TEXT(contextual_default=partial(get_default, 'append_text')) @options.NON_INTERACTIVE() @click.option('--hide-original', is_flag=True, default=False, help='Hide the code being copied.') @click.pass_context @@ -130,15 +120,6 @@ def code_duplicate(ctx, code, non_interactive, **kwargs): from 
aiida.common.exceptions import ValidationError from aiida.orm.utils.builders.code import CodeBuilder - if not non_interactive: - try: - pre, post = ensure_scripts(kwargs.pop('prepend_text', ''), kwargs.pop('append_text', ''), kwargs) - except InputValidationError as exception: - raise click.BadParameter('invalid prepend and or append text: {}'.format(exception)) - - kwargs['prepend_text'] = pre - kwargs['append_text'] = post - if kwargs.pop('on_computer'): kwargs['code_type'] = CodeBuilder.CodeType.ON_COMPUTER else: @@ -168,7 +149,36 @@ def code_duplicate(ctx, code, non_interactive, **kwargs): @with_dbenv() def show(code, verbose): """Display detailed information for a code.""" - click.echo(tabulate.tabulate(code.get_full_text_info(verbose))) + from aiida.repository import FileType + + table = [] + table.append(['PK', code.pk]) + table.append(['UUID', code.uuid]) + table.append(['Label', code.label]) + table.append(['Description', code.description]) + table.append(['Default plugin', code.get_input_plugin_name()]) + + if code.is_local(): + table.append(['Type', 'local']) + table.append(['Exec name', code.get_execname()]) + table.append(['List of files/folders:', '']) + for obj in code.list_objects(): + if obj.file_type == FileType.DIRECTORY: + table.append(['directory', obj.name]) + else: + table.append(['file', obj.name]) + else: + table.append(['Type', 'remote']) + table.append(['Remote machine', code.get_remote_computer().label]) + table.append(['Remote absolute path', code.get_remote_exec_path()]) + + table.append(['Prepend text', code.get_prepend_text()]) + table.append(['Append text', code.get_append_text()]) + + if verbose: + table.append(['Calculations', len(code.get_outgoing().all())]) + + click.echo(tabulate.tabulate(table)) @verdi_code.command() @@ -225,7 +235,7 @@ def relabel(code, label): try: code.relabel(label) except InputValidationError as exception: - echo.echo_critical('invalid code name: {}'.format(exception)) + echo.echo_critical('invalid code label: {}'.format(exception)) else: echo.echo_success('Code<{}> relabeled from {} to {}'.format(code.pk, old_label, code.full_label)) @@ -249,7 +259,7 @@ def code_list(computer, input_plugin, all_entries, all_users, show_owner): qb_computer_filters = dict() if computer is not None: - qb_computer_filters['name'] = computer.name + qb_computer_filters['name'] = computer.label qb_code_filters = dict() if input_plugin is not None: diff --git a/aiida/cmdline/commands/cmd_completioncommand.py b/aiida/cmdline/commands/cmd_completioncommand.py index a9b3ba2817..dbf5e7b359 100644 --- a/aiida/cmdline/commands/cmd_completioncommand.py +++ b/aiida/cmdline/commands/cmd_completioncommand.py @@ -17,13 +17,11 @@ @verdi.command('completioncommand') def verdi_completioncommand(): - """ - Return the code to activate bash completion. - - :note: this command is mainly for back-compatibility. - You should rather use:; + """Return the code to activate bash completion. - eval "$(_VERDI_COMPLETE=source verdi)" + \b + This command is mainly for back-compatibility. 
+ You should rather use: eval "$(_VERDI_COMPLETE=source verdi)" """ from click_completion import get_auto_shell, get_code click.echo(get_code(shell=get_auto_shell())) diff --git a/aiida/cmdline/commands/cmd_computer.py b/aiida/cmdline/commands/cmd_computer.py index b32b31dd93..b7a2e09c79 100644 --- a/aiida/cmdline/commands/cmd_computer.py +++ b/aiida/cmdline/commands/cmd_computer.py @@ -9,18 +9,17 @@ ########################################################################### # pylint: disable=invalid-name,too-many-statements,too-many-branches """`verdi computer` command.""" - from functools import partial import click +import tabulate from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import options, arguments from aiida.cmdline.params.options.commands import computer as options_computer from aiida.cmdline.utils import echo -from aiida.cmdline.utils.decorators import with_dbenv -from aiida.cmdline.utils.multi_line_input import ensure_scripts -from aiida.common.exceptions import ValidationError, InputValidationError +from aiida.cmdline.utils.decorators import with_dbenv, deprecated_command +from aiida.common.exceptions import ValidationError from aiida.plugins.entry_point import get_entry_points from aiida.transports import cli as transport_cli @@ -76,35 +75,20 @@ def _computer_test_no_unexpected_output(transport, scheduler, authinfo): # pyli if retval != 0: return False, 'The command `echo -n` returned a non-zero return code ({})'.format(retval) - if stdout: - return False, u""" -There is some spurious output in the standard output, -that we report below between the === signs: -========================================================= + template = """ +We detected some spurious output in the {} when connecting to the computer, as shown between the bars +===================================================================================================== {} -========================================================= -Please check that you don't have code producing output in -your ~/.bash_profile (or ~/.bashrc). If you don't want to -remove the code, but just to disable it for non-interactive -shells, see comments in issue #1980 on GitHub: -https://github.com/aiidateam/aiida-core/issues/1890 -(and in the AiiDA documentation, linked from that issue) -""".format(stdout) +===================================================================================================== +Please check that you don't have code producing output in your ~/.bash_profile, ~/.bashrc or similar. +If you don't want to remove the code, but just to disable it for non-interactive shells, see comments +in this troubleshooting section of the online documentation: https://bit.ly/2FCRDc5 +""" + if stdout: + return False, template.format('stdout', stdout) if stderr: - return u""" -There is some spurious output in the stderr, -that we report below between the === signs: -========================================================= -{} -========================================================= -Please check that you don't have code producing output in -your ~/.bash_profile (or ~/.bashrc). 
If you don't want to -remove the code, but just to disable it for non-interactive -shells, see comments in issue #1980 on GitHub: -https://github.com/aiidateam/aiida-core/issues/1890 -(and in the AiiDA documentation, linked from that issue) -""" + return False, template.format('stderr', stderr) return True, None @@ -220,8 +204,8 @@ def set_computer_builder(ctx, param, value): @options_computer.WORKDIR() @options_computer.MPI_RUN_COMMAND() @options_computer.MPI_PROCS_PER_MACHINE() -@options.PREPEND_TEXT() -@options.APPEND_TEXT() +@options_computer.PREPEND_TEXT() +@options_computer.APPEND_TEXT() @options.NON_INTERACTIVE() @options.CONFIG_FILE() @click.pass_context @@ -237,15 +221,6 @@ def computer_setup(ctx, non_interactive, **kwargs): 'computer starting from the settings of {c}.'.format(c=kwargs['label']) ) - if not non_interactive: - try: - pre, post = ensure_scripts(kwargs.pop('prepend_text', ''), kwargs.pop('append_text', ''), kwargs) - except InputValidationError as exception: - raise click.BadParameter('invalid prepend and or append text: {}'.format(exception)) - - kwargs['prepend_text'] = pre - kwargs['append_text'] = post - kwargs['transport'] = kwargs['transport'].name kwargs['scheduler'] = kwargs['scheduler'].name @@ -260,10 +235,10 @@ def computer_setup(ctx, non_interactive, **kwargs): except ValidationError as err: echo.echo_critical('unable to store the computer: {}. Exiting...'.format(err)) else: - echo.echo_success('Computer<{}> {} created'.format(computer.pk, computer.name)) + echo.echo_success('Computer<{}> {} created'.format(computer.pk, computer.label)) echo.echo_info('Note: before the computer can be used, it has to be configured with the command:') - echo.echo_info(' verdi computer configure {} {}'.format(computer.get_transport_type(), computer.name)) + echo.echo_info(' verdi computer configure {} {}'.format(computer.transport_type, computer.label)) @verdi_computer.command('duplicate') @@ -277,12 +252,8 @@ def computer_setup(ctx, non_interactive, **kwargs): @options_computer.WORKDIR(contextual_default=partial(get_parameter_default, 'work_dir')) @options_computer.MPI_RUN_COMMAND(contextual_default=partial(get_parameter_default, 'mpirun_command')) @options_computer.MPI_PROCS_PER_MACHINE(contextual_default=partial(get_parameter_default, 'mpiprocs_per_machine')) -@options.PREPEND_TEXT( - cls=options.ContextualDefaultOption, contextual_default=partial(get_parameter_default, 'prepend_text') -) -@options.APPEND_TEXT( - cls=options.ContextualDefaultOption, contextual_default=partial(get_parameter_default, 'append_text') -) +@options_computer.PREPEND_TEXT(contextual_default=partial(get_parameter_default, 'prepend_text')) +@options_computer.APPEND_TEXT(contextual_default=partial(get_parameter_default, 'append_text')) @options.NON_INTERACTIVE() @click.pass_context @with_dbenv() @@ -294,15 +265,6 @@ def computer_duplicate(ctx, computer, non_interactive, **kwargs): if kwargs['label'] in get_computer_names(): echo.echo_critical('A computer called {} already exists'.format(kwargs['label'])) - if not non_interactive: - try: - pre, post = ensure_scripts(kwargs.pop('prepend_text', ''), kwargs.pop('append_text', ''), kwargs) - except InputValidationError as exception: - raise click.BadParameter('invalid prepend and or append text: {}'.format(exception)) - - kwargs['prepend_text'] = pre - kwargs['append_text'] = post - kwargs['transport'] = kwargs['transport'].name kwargs['scheduler'] = kwargs['scheduler'].name @@ -316,20 +278,20 @@ def computer_duplicate(ctx, computer, non_interactive, 
**kwargs): except (ComputerBuilder.ComputerValidationError, ValidationError) as e: echo.echo_critical('{}: {}'.format(type(e).__name__, e)) else: - echo.echo_success('stored computer {}<{}>'.format(computer.name, computer.pk)) + echo.echo_success('stored computer {}<{}>'.format(computer.label, computer.pk)) try: computer.store() except ValidationError as err: echo.echo_critical('unable to store the computer: {}. Exiting...'.format(err)) else: - echo.echo_success('Computer<{}> {} created'.format(computer.pk, computer.name)) + echo.echo_success('Computer<{}> {} created'.format(computer.pk, computer.label)) is_configured = computer.is_user_configured(orm.User.objects.get_default()) if not is_configured: echo.echo_info('Note: before the computer can be used, it has to be configured with the command:') - echo.echo_info(' verdi computer configure {} {}'.format(computer.get_transport_type(), computer.name)) + echo.echo_info(' verdi computer configure {} {}'.format(computer.transport_type, computer.label)) @verdi_computer.command('enable') @@ -344,15 +306,15 @@ def computer_enable(computer, user): authinfo = computer.get_authinfo(user) except NotExistent: echo.echo_critical( - "User with email '{}' is not configured for computer '{}' yet.".format(user.email, computer.name) + "User with email '{}' is not configured for computer '{}' yet.".format(user.email, computer.label) ) if not authinfo.enabled: authinfo.enabled = True - echo.echo_info("Computer '{}' enabled for user {}.".format(computer.name, user.get_full_name())) + echo.echo_info("Computer '{}' enabled for user {}.".format(computer.label, user.get_full_name())) else: echo.echo_info( - "Computer '{}' was already enabled for user {} {}.".format(computer.name, user.first_name, user.last_name) + "Computer '{}' was already enabled for user {} {}.".format(computer.label, user.first_name, user.last_name) ) @@ -370,15 +332,17 @@ def computer_disable(computer, user): authinfo = computer.get_authinfo(user) except NotExistent: echo.echo_critical( - "User with email '{}' is not configured for computer '{}' yet.".format(user.email, computer.name) + "User with email '{}' is not configured for computer '{}' yet.".format(user.email, computer.label) ) if authinfo.enabled: authinfo.enabled = False - echo.echo_info("Computer '{}' disabled for user {}.".format(computer.name, user.get_full_name())) + echo.echo_info("Computer '{}' disabled for user {}.".format(computer.label, user.get_full_name())) else: echo.echo_info( - "Computer '{}' was already disabled for user {} {}.".format(computer.name, user.first_name, user.last_name) + "Computer '{}' was already disabled for user {} {}.".format( + computer.label, user.first_name, user.last_name + ) ) @@ -400,7 +364,7 @@ def computer_list(all_entries, raw): if not computers: echo.echo_info("No computers configured yet. 
Use 'verdi computer setup'") - sort = lambda computer: computer.name + sort = lambda computer: computer.label highlight = lambda comp: comp.is_user_configured(user) and comp.is_user_enabled(user) hide = lambda comp: not (comp.is_user_configured(user) and comp.is_user_enabled(user)) and not all_entries echo.echo_formatted_list(computers, ['name'], sort=sort, highlight=highlight, hide=hide) @@ -411,36 +375,57 @@ def computer_list(all_entries, raw): @with_dbenv() def computer_show(computer): """Show detailed information for a computer.""" - echo.echo(computer.full_text_info) + table = [] + table.append(['Label', computer.label]) + table.append(['PK', computer.pk]) + table.append(['UUID', computer.uuid]) + table.append(['Description', computer.description]) + table.append(['Hostname', computer.hostname]) + table.append(['Transport type', computer.transport_type]) + table.append(['Scheduler type', computer.scheduler_type]) + table.append(['Work directory', computer.get_workdir()]) + table.append(['Shebang', computer.get_shebang()]) + table.append(['Mpirun command', ' '.join(computer.get_mpirun_command())]) + table.append(['Prepend text', computer.get_prepend_text()]) + table.append(['Append text', computer.get_append_text()]) + echo.echo(tabulate.tabulate(table)) @verdi_computer.command('rename') @arguments.COMPUTER() @arguments.LABEL('NEW_NAME') +@deprecated_command("This command has been deprecated. Please use 'verdi computer relabel' instead.") +@click.pass_context @with_dbenv() -def computer_rename(computer, new_name): +def computer_rename(ctx, computer, new_name): """Rename a computer.""" + ctx.invoke(computer_relabel, computer=computer, label=new_name) + + +@verdi_computer.command('relabel') +@arguments.COMPUTER() +@arguments.LABEL('LABEL') +@with_dbenv() +def computer_relabel(computer, label): + """Relabel a computer.""" from aiida.common.exceptions import UniquenessError - old_name = computer.get_name() + old_label = computer.label - if old_name == new_name: - echo.echo_critical('The old and new names are the same.') + if old_label == label: + echo.echo_critical('The old and new labels are the same.') try: - computer.set_name(new_name) + computer.label = label computer.store() except ValidationError as error: echo.echo_critical('Invalid input! {}'.format(error)) except UniquenessError as error: echo.echo_critical( - 'Uniqueness error encountered! Probably a ' - "computer with name '{}' already exists" - ''.format(new_name) + "Uniqueness error encountered! Probably a computer with label '{}' already exists: {}".format(label, error) ) - echo.echo_critical('(Message was: {})'.format(error)) - echo.echo_success("Computer '{}' renamed to '{}'".format(old_name, new_name)) + echo.echo_success("Computer '{}' relabeled to '{}'".format(old_label, label)) @verdi_computer.command('test') @@ -449,12 +434,7 @@ def computer_rename(computer, new_name): help='Test the connection for a given AiiDA user, specified by' 'their email address. 
If not specified, uses the current default user.', ) -@click.option( - '-t', - '--print-traceback', - is_flag=True, - help='Print the full traceback in case an exception is raised', -) +@options.PRINT_TRACEBACK() @arguments.COMPUTER() @with_dbenv() def computer_test(user, print_traceback, computer): @@ -472,15 +452,15 @@ def computer_test(user, print_traceback, computer): if user is None: user = orm.User.objects.get_default() - echo.echo_info('Testing computer<{}> for user<{}>...'.format(computer.name, user.email)) + echo.echo_info('Testing computer<{}> for user<{}>...'.format(computer.label, user.email)) try: authinfo = computer.get_authinfo(user) except NotExistent: - echo.echo_critical('Computer<{}> is not yet configured for user<{}>'.format(computer.name, user.email)) + echo.echo_critical('Computer<{}> is not yet configured for user<{}>'.format(computer.label, user.email)) if not authinfo.enabled: - echo.echo_warning('Computer<{}> is disabled for user<{}>'.format(computer.name, user.email)) + echo.echo_warning('Computer<{}> is disabled for user<{}>'.format(computer.label, user.email)) click.confirm('Do you really want to test it?', abort=True) scheduler = authinfo.computer.get_scheduler() @@ -568,14 +548,14 @@ def computer_delete(computer): from aiida.common.exceptions import InvalidOperation from aiida import orm - compname = computer.name + label = computer.label try: orm.Computer.objects.delete(computer.id) except InvalidOperation as error: echo.echo_critical(str(error)) - echo.echo_success("Computer '{}' deleted.".format(compname)) + echo.echo_success("Computer '{}' deleted.".format(label)) @verdi_computer.group('configure') @@ -594,12 +574,11 @@ def computer_configure(): @arguments.COMPUTER() def computer_config_show(computer, user, defaults, as_option_string): """Show the current configuration for a computer.""" - import tabulate from aiida.common.escaping import escape_for_bash transport_cls = computer.get_transport_class() option_list = [ - param for param in transport_cli.create_configure_cmd(computer.get_transport_type()).params + param for param in transport_cli.create_configure_cmd(computer.transport_type).params if isinstance(param, click.core.Option) ] option_list = [option for option in option_list if option.name in transport_cls.get_valid_auth_params()] diff --git a/aiida/cmdline/commands/cmd_data/cmd_remote.py b/aiida/cmdline/commands/cmd_data/cmd_remote.py index b3645ae693..6b992697a7 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_remote.py +++ b/aiida/cmdline/commands/cmd_data/cmd_remote.py @@ -85,6 +85,6 @@ def remote_cat(datum, path): def remote_show(datum): """Show information for a RemoteData object.""" click.echo('- Remote computer name:') - click.echo(' {}'.format(datum.get_computer_name())) + click.echo(' {}'.format(datum.computer.label)) click.echo('- Remote folder full path:') click.echo(' {}'.format(datum.get_remote_path())) diff --git a/aiida/cmdline/commands/cmd_data/cmd_show.py b/aiida/cmdline/commands/cmd_data/cmd_show.py index 8a443e2945..51cdbca9b1 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_show.py +++ b/aiida/cmdline/commands/cmd_data/cmd_show.py @@ -10,7 +10,6 @@ """ This allows to manage showfunctionality to all data types. 
""" - import click from aiida.cmdline.params.options.multivalue import MultipleValueOption @@ -209,7 +208,7 @@ def _show_xmgrace(exec_name, list_bands): import sys import subprocess import tempfile - from aiida.orm.nodes.data.array.bands import max_num_agr_colors + from aiida.orm.nodes.data.array.bands import MAX_NUM_AGR_COLORS list_files = [] current_band_number = 0 @@ -218,7 +217,7 @@ def _show_xmgrace(exec_name, list_bands): nbnds = bnds.get_bands().shape[1] # pylint: disable=protected-access text, _ = bnds._exportcontent( - 'agr', setnumber_offset=current_band_number, color_number=(iband + 1 % max_num_agr_colors) + 'agr', setnumber_offset=current_band_number, color_number=(iband + 1 % MAX_NUM_AGR_COLORS) ) # write a tempfile tempf = tempfile.NamedTemporaryFile('w+b', suffix='.agr') diff --git a/aiida/cmdline/commands/cmd_data/cmd_structure.py b/aiida/cmdline/commands/cmd_data/cmd_structure.py index c0d205197e..a2e91949bd 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_structure.py +++ b/aiida/cmdline/commands/cmd_data/cmd_structure.py @@ -231,7 +231,7 @@ def import_ase(filename, dry_run): try: import ase.io except ImportError: - echo.echo_critical('You have not installed the package ase. \n' 'You can install it with: pip install ase') + echo.echo_critical('You have not installed the package ase. \nYou can install it with: pip install ase') try: asecell = ase.io.read(filename) diff --git a/aiida/cmdline/commands/cmd_devel.py b/aiida/cmdline/commands/cmd_devel.py index fb3287312f..40d9972ad2 100644 --- a/aiida/cmdline/commands/cmd_devel.py +++ b/aiida/cmdline/commands/cmd_devel.py @@ -48,6 +48,24 @@ def devel_check_load_time(): echo.echo_success('no issues detected') +@verdi_devel.command('check-undesired-imports') +def devel_check_undesired_imports(): + """Check that verdi does not import python modules it shouldn't. + + Note: The blacklist was taken from the list of packages in the 'atomic_tools' extra but can be extended. + """ + loaded_modules = 0 + + for modulename in ['seekpath', 'CifFile', 'ase', 'pymatgen', 'spglib', 'pymysql']: + if modulename in sys.modules: + echo.echo_warning('Detected loaded module "{}"'.format(modulename)) + loaded_modules += 1 + + if loaded_modules > 0: + echo.echo_critical('Detected {} unwanted modules'.format(loaded_modules)) + echo.echo_success('no issues detected') + + @verdi_devel.command('run_daemon') @decorators.with_dbenv() def devel_run_daemon(): diff --git a/aiida/cmdline/commands/cmd_export.py b/aiida/cmdline/commands/cmd_export.py index 9baff6ad25..21596007c4 100644 --- a/aiida/cmdline/commands/cmd_export.py +++ b/aiida/cmdline/commands/cmd_export.py @@ -11,6 +11,7 @@ """`verdi export` command.""" import os +import tempfile import click import tabulate @@ -141,9 +142,10 @@ def create( @verdi_export.command('migrate') @arguments.INPUT_FILE() -@arguments.OUTPUT_FILE() +@arguments.OUTPUT_FILE(required=False) @options.ARCHIVE_FORMAT() @options.FORCE(help='overwrite output file if it already exists') +@click.option('-i', '--in-place', is_flag=True, help='Migrate the archive in place, overwriting the original file.') @options.SILENT() @click.option( '-v', @@ -151,9 +153,12 @@ def create( type=click.STRING, required=False, metavar='VERSION', - help='Specify an exact archive version to migrate to. By default the most recent version is taken.' + # Note: Adding aiida.tools.EXPORT_VERSION as a default value explicitly would result in a slow import of + # aiida.tools and, as a consequence, aiida.orm. 
As long as this is the case, better determine the latest export + # version inside the function when needed. + help='Archive format version to migrate to (defaults to latest version).', ) -def migrate(input_file, output_file, force, silent, archive_format, version): +def migrate(input_file, output_file, force, silent, in_place, archive_format, version): # pylint: disable=too-many-locals,too-many-statements,too-many-branches """Migrate an export archive to a more recent format version.""" import tarfile @@ -161,11 +166,21 @@ def migrate(input_file, output_file, force, silent, archive_format, version): from aiida.common import json from aiida.common.folders import SandboxFolder - from aiida.tools.importexport import EXPORT_VERSION, migration, extract_zip, extract_tar, ArchiveMigrationError + from aiida.tools.importexport import migration, extract_zip, extract_tar, ArchiveMigrationError, EXPORT_VERSION if version is None: version = EXPORT_VERSION + if in_place: + if output_file: + echo.echo_critical('output file specified together with --in-place flag') + tempdir = tempfile.TemporaryDirectory() + output_file = os.path.join(tempdir.name, 'archive.aiida') + elif not output_file: + echo.echo_critical( + 'no output file specified. Please add --in-place flag if you would like to migrate in place.' + ) + if os.path.exists(output_file) and not force: echo.echo_critical('the output file already exists') @@ -187,6 +202,10 @@ def migrate(input_file, output_file, force, silent, archive_format, version): echo.echo_critical('export archive does not contain the required file {}'.format(fhandle.filename)) old_version = migration.verify_metadata_version(metadata) + if version <= old_version: + echo.echo_success('nothing to be done - archive already at version {} >= {}'.format(old_version, version)) + return + try: new_version = migration.migrate_recursively(metadata, data, folder, version) except ArchiveMigrationError as exception: @@ -212,5 +231,9 @@ def migrate(input_file, output_file, force, silent, archive_format, version): with tarfile.open(output_file, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as archive: archive.add(folder.abspath, arcname='') + if in_place: + os.rename(output_file, input_file) + tempdir.cleanup() + if not silent: echo.echo_success('migrated the archive from version {} to {}'.format(old_version, new_version)) diff --git a/aiida/cmdline/commands/cmd_node.py b/aiida/cmdline/commands/cmd_node.py index 867660d60e..8c0e4b12ab 100644 --- a/aiida/cmdline/commands/cmd_node.py +++ b/aiida/cmdline/commands/cmd_node.py @@ -84,7 +84,7 @@ def repo_dump(node, output_directory): The output directory should not exist. If it does, the command will abort. """ - from aiida.orm.utils.repository import FileType + from aiida.repository import FileType output_directory = pathlib.Path(output_directory) @@ -98,18 +98,18 @@ def _copy_tree(key, output_dir): # pylint: disable=too-many-branches Recursively copy the content at the ``key`` path in the given node to the ``output_dir``. """ - for file in node.list_objects(key=key): + for file in node.list_objects(key): # Not using os.path.join here, because this is the "path" # in the AiiDA node, not an actual OS - level path. 
file_key = file.name if not key else key + '/' + file.name - if file.type == FileType.DIRECTORY: + if file.file_type == FileType.DIRECTORY: new_out_dir = output_dir / file.name assert not new_out_dir.exists() new_out_dir.mkdir() _copy_tree(key=file_key, output_dir=new_out_dir) else: - assert file.type == FileType.FILE + assert file.file_type == FileType.FILE out_file_path = output_dir / file.name assert not out_file_path.exists() with node.open(file_key, 'rb') as in_file: diff --git a/aiida/cmdline/commands/cmd_process.py b/aiida/cmdline/commands/cmd_process.py index d7dc5f0d00..16fd660f60 100644 --- a/aiida/cmdline/commands/cmd_process.py +++ b/aiida/cmdline/commands/cmd_process.py @@ -9,7 +9,6 @@ ########################################################################### # pylint: disable=too-many-arguments """`verdi process` command.""" - import click from kiwipy import communications @@ -37,6 +36,7 @@ def verdi_process(): @options.ALL(help='Show all entries, regardless of their process state.') @options.PROCESS_STATE() @options.PROCESS_LABEL() +@options.PAUSED() @options.EXIT_STATUS() @options.FAILED() @options.PAST_DAYS() @@ -44,8 +44,8 @@ def verdi_process(): @options.RAW() @decorators.with_dbenv() def process_list( - all_entries, group, process_state, process_label, exit_status, failed, past_days, limit, project, raw, order_by, - order_dir + all_entries, group, process_state, process_label, paused, exit_status, failed, past_days, limit, project, raw, + order_by, order_dir ): """Show a list of running or terminated processes. @@ -61,7 +61,7 @@ def process_list( relationships['with_node'] = group builder = CalculationQueryBuilder() - filters = builder.get_filters(all_entries, process_state, process_label, exit_status, failed) + filters = builder.get_filters(all_entries, process_state, process_label, paused, exit_status, failed) query_set = builder.get_query_set( relationships=relationships, filters=filters, order_by={order_by: order_dir}, past_days=past_days, limit=limit ) diff --git a/aiida/cmdline/commands/cmd_setup.py b/aiida/cmdline/commands/cmd_setup.py index fbfbf8b23c..3fb5159b13 100644 --- a/aiida/cmdline/commands/cmd_setup.py +++ b/aiida/cmdline/commands/cmd_setup.py @@ -33,11 +33,18 @@ @options_setup.SETUP_DATABASE_NAME() @options_setup.SETUP_DATABASE_USERNAME() @options_setup.SETUP_DATABASE_PASSWORD() +@options_setup.SETUP_BROKER_PROTOCOL() +@options_setup.SETUP_BROKER_USERNAME() +@options_setup.SETUP_BROKER_PASSWORD() +@options_setup.SETUP_BROKER_HOST() +@options_setup.SETUP_BROKER_PORT() +@options_setup.SETUP_BROKER_VIRTUAL_HOST() @options_setup.SETUP_REPOSITORY_URI() @options.CONFIG_FILE() def setup( non_interactive, profile, email, first_name, last_name, institution, db_engine, db_backend, db_host, db_port, - db_name, db_username, db_password, repository + db_name, db_username, db_password, broker_protocol, broker_username, broker_password, broker_host, broker_port, + broker_virtual_host, repository ): """Setup a new profile.""" # pylint: disable=too-many-arguments,too-many-locals,unused-argument @@ -51,6 +58,12 @@ def setup( profile.database_hostname = db_host profile.database_username = db_username profile.database_password = db_password + profile.broker_protocol = broker_protocol + profile.broker_username = broker_username + profile.broker_password = broker_password + profile.broker_host = broker_host + profile.broker_port = broker_port + profile.broker_virtual_host = broker_virtual_host profile.repository_uri = 'file://' + repository config = get_config() @@ -113,12 
+126,19 @@ def setup( @options_setup.QUICKSETUP_SUPERUSER_DATABASE_NAME() @options_setup.QUICKSETUP_SUPERUSER_DATABASE_USERNAME() @options_setup.QUICKSETUP_SUPERUSER_DATABASE_PASSWORD() +@options_setup.QUICKSETUP_BROKER_PROTOCOL() +@options_setup.QUICKSETUP_BROKER_USERNAME() +@options_setup.QUICKSETUP_BROKER_PASSWORD() +@options_setup.QUICKSETUP_BROKER_HOST() +@options_setup.QUICKSETUP_BROKER_PORT() +@options_setup.QUICKSETUP_BROKER_VIRTUAL_HOST() @options_setup.QUICKSETUP_REPOSITORY_URI() @options.CONFIG_FILE() @click.pass_context def quicksetup( ctx, non_interactive, profile, email, first_name, last_name, institution, db_engine, db_backend, db_host, db_port, - db_name, db_username, db_password, su_db_name, su_db_username, su_db_password, repository + db_name, db_username, db_password, su_db_name, su_db_username, su_db_password, broker_protocol, broker_username, + broker_password, broker_host, broker_port, broker_virtual_host, repository ): """Setup a new profile in a fully automated fashion.""" # pylint: disable=too-many-arguments,too-many-locals @@ -166,6 +186,12 @@ def quicksetup( 'db_port': postgres.port_for_psycopg2, 'db_username': db_username, 'db_password': db_password, + 'broker_protocol': broker_protocol, + 'broker_username': broker_username, + 'broker_password': broker_password, + 'broker_host': broker_host, + 'broker_port': broker_port, + 'broker_virtual_host': broker_virtual_host, 'repository': repository, } ctx.invoke(setup, **setup_parameters) diff --git a/aiida/cmdline/commands/cmd_shell.py b/aiida/cmdline/commands/cmd_shell.py index 5927c7d749..bbbe5807dd 100644 --- a/aiida/cmdline/commands/cmd_shell.py +++ b/aiida/cmdline/commands/cmd_shell.py @@ -19,7 +19,7 @@ @verdi.command('shell') @decorators.with_dbenv() -@click.option('--plain', is_flag=True, help='Use a plain Python shell.)') +@click.option('--plain', is_flag=True, help='Use a plain Python shell.') @click.option( '--no-startup', is_flag=True, diff --git a/aiida/cmdline/commands/cmd_status.py b/aiida/cmdline/commands/cmd_status.py index e1b9eabe4d..e021e85d4b 100644 --- a/aiida/cmdline/commands/cmd_status.py +++ b/aiida/cmdline/commands/cmd_status.py @@ -8,13 +8,16 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi status` command.""" +import enum import sys -import enum import click from aiida.cmdline.commands.cmd_verdi import verdi +from aiida.cmdline.params import options +from aiida.cmdline.utils import echo from aiida.common.log import override_log_level +from aiida.common.exceptions import IncompatibleDatabaseSchema from ..utils.echo import ExitCode @@ -47,13 +50,13 @@ class ServiceStatus(enum.IntEnum): @verdi.command('status') +@options.PRINT_TRACEBACK() @click.option('--no-rmq', is_flag=True, help='Do not check RabbitMQ status') -def verdi_status(no_rmq): +def verdi_status(print_traceback, no_rmq): """Print status of AiiDA services.""" - # pylint: disable=broad-except,too-many-statements + # pylint: disable=broad-except,too-many-statements,too-many-branches from aiida.cmdline.utils.daemon import get_daemon_status, delete_stale_pid_file from aiida.common.utils import Capturing - from aiida.manage.external.rmq import get_rmq_url from aiida.manage.manager import get_manager from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER @@ -64,19 +67,25 @@ def verdi_status(no_rmq): manager = get_manager() profile = manager.get_profile() + if profile is None: + print_status(ServiceStatus.WARNING, 'profile', 
'no profile configured yet') + echo.echo_info('Configure a profile by running `verdi quicksetup` or `verdi setup`.') + return + try: profile = manager.get_profile() print_status(ServiceStatus.UP, 'profile', 'On profile {}'.format(profile.name)) except Exception as exc: - print_status(ServiceStatus.ERROR, 'profile', 'Unable to read AiiDA profile', exception=exc) + message = 'Unable to read AiiDA profile' + print_status(ServiceStatus.ERROR, 'profile', message, exception=exc, print_traceback=print_traceback) sys.exit(ExitCode.CRITICAL) # stop here - without a profile we cannot access anything # Getting the repository - repo_folder = 'undefined' try: repo_folder = profile.repository_path except Exception as exc: - print_status(ServiceStatus.ERROR, 'repository', 'Error with repo folder', exception=exc) + message = 'Error with repository folder' + print_status(ServiceStatus.ERROR, 'repository', message, exception=exc, print_traceback=print_traceback) exit_code = ExitCode.CRITICAL else: print_status(ServiceStatus.UP, 'repository', repo_folder) @@ -87,8 +96,13 @@ def verdi_status(no_rmq): with override_log_level(): # temporarily suppress noisy logging backend = manager.get_backend() backend.cursor() - except Exception: - print_status(ServiceStatus.DOWN, 'postgres', 'Unable to connect as {}@{}:{}'.format(*database_data)) + except IncompatibleDatabaseSchema: + message = 'Database schema version is incompatible with the code: run `verdi database migrate`.' + print_status(ServiceStatus.DOWN, 'postgres', message) + exit_code = ExitCode.CRITICAL + except Exception as exc: + message = 'Unable to connect as {}@{}:{}'.format(*database_data) + print_status(ServiceStatus.DOWN, 'postgres', message, exception=exc, print_traceback=print_traceback) exit_code = ExitCode.CRITICAL else: print_status(ServiceStatus.UP, 'postgres', 'Connected as {}@{}:{}'.format(*database_data)) @@ -101,10 +115,11 @@ def verdi_status(no_rmq): comm = manager.create_communicator(with_orm=False) comm.stop() except Exception as exc: - print_status(ServiceStatus.ERROR, 'rabbitmq', 'Unable to connect to rabbitmq', exception=exc) + message = 'Unable to connect to rabbitmq with URL: {}'.format(profile.get_rmq_url()) + print_status(ServiceStatus.ERROR, 'rabbitmq', message, exception=exc, print_traceback=print_traceback) exit_code = ExitCode.CRITICAL else: - print_status(ServiceStatus.UP, 'rabbitmq', 'Connected to {}'.format(get_rmq_url())) + print_status(ServiceStatus.UP, 'rabbitmq', 'Connected as {}'.format(profile.get_rmq_url())) # Getting the daemon status try: @@ -117,17 +132,17 @@ def verdi_status(no_rmq): print_status(ServiceStatus.UP, 'daemon', daemon_status) else: print_status(ServiceStatus.WARNING, 'daemon', daemon_status) - exit_code = ExitCode.SUCCESS # A daemon that is not running is not a failure except Exception as exc: - print_status(ServiceStatus.ERROR, 'daemon', 'Error getting daemon status', exception=exc) + message = 'Error getting daemon status' + print_status(ServiceStatus.ERROR, 'daemon', message, exception=exc, print_traceback=print_traceback) exit_code = ExitCode.CRITICAL # Note: click does not forward return values to the exit code, see https://github.com/pallets/click/issues/747 sys.exit(exit_code) -def print_status(status, service, msg='', exception=None): +def print_status(status, service, msg='', exception=None, print_traceback=False): """Print status message. Includes colored indicator. 
@@ -139,5 +154,10 @@ def print_status(status, service, msg='', exception=None): symbol = STATUS_SYMBOLS[status] click.secho(' {} '.format(symbol['string']), fg=symbol['color'], nl=False) click.secho('{:12s} {}'.format(service + ':', msg)) + if exception is not None: - click.echo(exception, err=True) + echo.echo_error('{}: {}'.format(type(exception).__name__, exception)) + + if print_traceback: + import traceback + traceback.print_exc() diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py index 9e2f46d616..2fc187189d 100644 --- a/aiida/cmdline/params/options/__init__.py +++ b/aiida/cmdline/params/options/__init__.py @@ -8,12 +8,11 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with pre-defined reusable commandline options that can be used as `click` decorators.""" - import click -# Note: importing from aiida.manage.postgres leads to circular imports from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA +from aiida.manage.external.rmq import BROKER_DEFAULTS from ...utils import defaults, echo from .. import types from .multivalue import MultipleValueOption @@ -30,7 +29,7 @@ 'DESCRIPTION', 'INPUT_PLUGIN', 'CALC_JOB_STATE', 'PROCESS_STATE', 'PROCESS_LABEL', 'TYPE_STRING', 'EXIT_STATUS', 'FAILED', 'LIMIT', 'PROJECT', 'ORDER_BY', 'PAST_DAYS', 'OLDER_THAN', 'ALL', 'ALL_STATES', 'ALL_USERS', 'GROUP_CLEAR', 'RAW', 'HOSTNAME', 'TRANSPORT', 'SCHEDULER', 'USER', 'PORT', 'FREQUENCY', 'VERBOSE', 'TIMEOUT', - 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'DEBUG' + 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'DEBUG', 'PRINT_TRACEBACK' ) TRAVERSAL_RULE_HELP_STRING = { @@ -205,7 +204,11 @@ def decorator(command): ) NON_INTERACTIVE = OverridableOption( - '-n', '--non-interactive', is_flag=True, is_eager=True, help='Non-interactive mode: never prompt for input.' + '-n', + '--non-interactive', + is_flag=True, + is_eager=True, + help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' ) DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') @@ -269,6 +272,55 @@ def decorator(command): DB_NAME = OverridableOption('--db-name', type=types.NonEmptyStringParamType(), help='Database name.') +BROKER_PROTOCOL = OverridableOption( + '--broker-protocol', + type=click.Choice(('amqp', 'amqps')), + default=BROKER_DEFAULTS.protocol, + show_default=True, + help='Protocol to use for the message broker.' +) + +BROKER_USERNAME = OverridableOption( + '--broker-username', + type=types.NonEmptyStringParamType(), + default=BROKER_DEFAULTS.username, + show_default=True, + help='Username to use for authentication with the message broker.' +) + +BROKER_PASSWORD = OverridableOption( + '--broker-password', + type=types.NonEmptyStringParamType(), + default=BROKER_DEFAULTS.password, + show_default=True, + help='Password to use for authentication with the message broker.', + hide_input=True, +) + +BROKER_HOST = OverridableOption( + '--broker-host', + type=types.HostnameType(), + default=BROKER_DEFAULTS.host, + show_default=True, + help='Hostname for the message broker.' 
+) + +BROKER_PORT = OverridableOption( + '--broker-port', + type=click.INT, + default=BROKER_DEFAULTS.port, + show_default=True, + help='Port for the message broker.', +) + +BROKER_VIRTUAL_HOST = OverridableOption( + '--broker-virtual-host', + type=types.HostnameType(), + default=BROKER_DEFAULTS.virtual_host, + show_default=True, + help='Name of the virtual host for the message broker. Forward slashes need to be encoded' +) + REPOSITORY_PATH = OverridableOption( '--repository', type=click.Path(file_okay=False), help='Absolute path to the file repository.' ) @@ -324,6 +376,8 @@ def decorator(command): help='Only include entries with this process state.' ) +PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') + PROCESS_LABEL = OverridableOption( '-L', '--process-label', @@ -428,11 +482,19 @@ def decorator(command): HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') TRANSPORT = OverridableOption( - '-T', '--transport', type=types.PluginParamType(group='transports'), required=True, help='Transport type.' + '-T', + '--transport', + type=types.PluginParamType(group='transports'), + required=True, + help="A transport plugin (as listed in 'verdi plugin list aiida.transports')." ) SCHEDULER = OverridableOption( - '-S', '--scheduler', type=types.PluginParamType(group='schedulers'), required=True, help='Scheduler type.' + '-S', + '--scheduler', + type=types.PluginParamType(group='schedulers'), + required=True, + help="A scheduler plugin (as listed in 'verdi plugin list aiida.schedulers')." ) USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') @@ -526,3 +588,10 @@ def decorator(command): DEBUG = OverridableOption( '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True ) + +PRINT_TRACEBACK = OverridableOption( + '-t', + '--print-traceback', + is_flag=True, + help='Print the full traceback in case an exception is raised.', +) diff --git a/aiida/cmdline/params/options/commands/code.py b/aiida/cmdline/params/options/commands/code.py index 00f1cace38..39de20ad4e 100644 --- a/aiida/cmdline/params/options/commands/code.py +++ b/aiida/cmdline/params/options/commands/code.py @@ -8,11 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Reusable command line interface options for Code commands.""" - import click from aiida.cmdline.params import options, types -from aiida.cmdline.params.options.interactive import InteractiveOption +from aiida.cmdline.params.options.interactive import InteractiveOption, TemplateInteractiveOption from aiida.cmdline.params.options.overridable import OverridableOption @@ -30,7 +29,8 @@ def is_not_on_computer(ctx): default=True, cls=InteractiveOption, prompt='Installed on target computer?', - help='Whether the code is installed on the target computer or should be copied each time from a local path.' + help='Whether the code is installed on the target computer, or should be copied to the target computer each time ' + 'from a local path.' 
) REMOTE_ABS_PATH = OverridableOption( @@ -40,7 +40,7 @@ def is_not_on_computer(ctx): prompt_fn=is_on_computer, type=types.AbsolutePathParamType(dir_okay=False), cls=InteractiveOption, - help=('[if --on-computer]: the absolute path to the executable on the remote machine.') + help='[if --on-computer]: Absolute path to the executable on the target computer.' ) FOLDER = OverridableOption( @@ -50,7 +50,8 @@ def is_not_on_computer(ctx): prompt_fn=is_not_on_computer, type=click.Path(file_okay=False, exists=True, readable=True), cls=InteractiveOption, - help=('[if --store-in-db]: directory containing the executable and all other files necessary for running it.') + help='[if --store-in-db]: Absolute path to directory containing the executable and all other files necessary for ' + 'running it (to be copied to target computer).' ) REL_PATH = OverridableOption( @@ -60,19 +61,26 @@ def is_not_on_computer(ctx): prompt_fn=is_not_on_computer, type=click.Path(dir_okay=False), cls=InteractiveOption, - help=('[if --store-in-db]: relative path of the executable inside the code-folder.') + help='[if --store-in-db]: Relative path of the executable inside the code-folder.' ) -LABEL = options.LABEL.clone(prompt='Label', cls=InteractiveOption, help='A label to refer to this code.') +LABEL = options.LABEL.clone( + prompt='Label', + cls=InteractiveOption, + help="This label can be used to identify the code (using 'label@computerlabel'), as long as labels are unique per " + 'computer.' +) DESCRIPTION = options.DESCRIPTION.clone( - prompt='Description', cls=InteractiveOption, help='A human-readable description of this code.' + prompt='Description', + cls=InteractiveOption, + help='A human-readable description of this code, ideally including version and compilation environment.' ) INPUT_PLUGIN = options.INPUT_PLUGIN.clone( prompt='Default calculation input plugin', cls=InteractiveOption, - help='Default calculation plugin to use for this code.' + help="Entry point name of the default calculation plugin (as listed in 'verdi plugin list aiida.calculations')." ) COMPUTER = options.COMPUTER.clone( @@ -80,5 +88,31 @@ def is_not_on_computer(ctx): cls=InteractiveOption, required_fn=is_on_computer, prompt_fn=is_on_computer, - help='Name of the computer, on which the code resides.' + help='Name of the computer, on which the code is installed.' +) + +PREPEND_TEXT = OverridableOption( + '--prepend-text', + cls=TemplateInteractiveOption, + prompt='Prepend script', + type=click.STRING, + default='', + help='Bash commands that should be prepended to the executable call in all submit scripts for this code.', + extension='.bash', + header='PREPEND_TEXT: if there is any bash commands that should be prepended to the executable call in all ' + 'submit scripts for this code, type that between the equal signs below and save the file.', + footer='All lines that start with `#=` will be ignored.' +) + +APPEND_TEXT = OverridableOption( + '--append-text', + cls=TemplateInteractiveOption, + prompt='Append script', + type=click.STRING, + default='', + help='Bash commands that should be appended to the executable call in all submit scripts for this code.', + extension='.bash', + header='APPEND_TEXT: if there is any bash commands that should be appended to the executable call in all ' + 'submit scripts for this code, type that between the equal signs below and save the file.', + footer='All lines that start with `#=` will be ignored.' 
) diff --git a/aiida/cmdline/params/options/commands/computer.py b/aiida/cmdline/params/options/commands/computer.py index 5a39fc99b1..0a98049b4f 100644 --- a/aiida/cmdline/params/options/commands/computer.py +++ b/aiida/cmdline/params/options/commands/computer.py @@ -8,11 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Reusable command line interface options for Computer commands.""" - import click from aiida.cmdline.params import options, types -from aiida.cmdline.params.options.interactive import InteractiveOption +from aiida.cmdline.params.options.interactive import InteractiveOption, TemplateInteractiveOption from aiida.cmdline.params.options.overridable import OverridableOption @@ -45,13 +44,19 @@ def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-na return job_resource_cls.accepts_default_mpiprocs_per_machine() -LABEL = options.LABEL.clone(prompt='Computer label', cls=InteractiveOption, required=True) +LABEL = options.LABEL.clone( + prompt='Computer label', + cls=InteractiveOption, + required=True, + help='Unique, human-readable label for this computer.' +) HOSTNAME = options.HOSTNAME.clone( prompt='Hostname', cls=InteractiveOption, required=True, - help='The fully qualified hostname of the computer; use "localhost" for local transports.', + help='The fully qualified hostname of the computer (e.g. daint.cscs.ch). ' + 'Use "localhost" when setting up the computer that AiiDA is running on.', ) DESCRIPTION = options.DESCRIPTION.clone( @@ -67,7 +72,7 @@ def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-na prompt='Shebang line (first line of each script, starting with #!)', default='#!/bin/bash', cls=InteractiveOption, - help='This line specifies the first line of the submission script for this computer.', + help='Specify the first line of the submission script for this computer (only the bash shell is supported).', type=types.ShebangParamType() ) @@ -78,9 +83,8 @@ def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-na default='/scratch/{username}/aiida/', cls=InteractiveOption, help='The absolute path of the directory on the computer where AiiDA will ' - 'run the calculations (typically, the scratch of the computer). You ' - 'can use the {username} replacement, that will be replaced by your ' - 'username on the remote computer.' + 'run the calculations (often a "scratch" directory).' + 'The {username} string will be replaced by your username on the remote computer.' ) MPI_RUN_COMMAND = OverridableOption( @@ -89,10 +93,8 @@ def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-na prompt='Mpirun command', default='mpirun -np {tot_num_mpiprocs}', cls=InteractiveOption, - help='The mpirun command needed on the cluster to run parallel MPI ' - 'programs. You can use the {tot_num_mpiprocs} replacement, that will be ' - 'replaced by the total number of cpus, or the other scheduler-dependent ' - 'replacement fields (see the scheduler docs for more information).', + help='The mpirun command needed on the cluster to run parallel MPI programs. The {tot_num_mpiprocs} string will be ' + 'replaced by the total number of cpus. 
See the scheduler docs for further scheduler-dependent template variables.', type=types.MpirunCommandParamType() ) @@ -103,7 +105,32 @@ def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-na prompt_fn=should_call_default_mpiprocs_per_machine, required_fn=False, type=click.INT, - help='Enter here the default number of MPI processes per machine (node) that ' - 'should be used if nothing is otherwise specified. Pass the digit 0 ' - 'if you do not want to provide a default value.', + help='The default number of MPI processes that should be executed per machine (node), if not otherwise specified.' + 'Use 0 to specify no default value.', +) + +PREPEND_TEXT = OverridableOption( + '--prepend-text', + cls=TemplateInteractiveOption, + prompt='Prepend script', + type=click.STRING, + default='', + help='Bash commands that should be prepended to the executable call in all submit scripts for this computer.', + extension='.bash', + header='PREPEND_TEXT: if there is any bash commands that should be prepended to the executable call in all ' + 'submit scripts for this computer, type that between the equal signs below and save the file.', + footer='All lines that start with `#=` will be ignored.' +) + +APPEND_TEXT = OverridableOption( + '--append-text', + cls=TemplateInteractiveOption, + prompt='Append script', + type=click.STRING, + default='', + help='Bash commands that should be appended to the executable call in all submit scripts for this computer.', + extension='.bash', + header='APPEND_TEXT: if there is any bash commands that should be appended to the executable call in all ' + 'submit scripts for this computer, type that between the equal signs below and save the file.', + footer='All lines that start with `#=` will be ignored.' ) diff --git a/aiida/cmdline/params/options/commands/setup.py b/aiida/cmdline/params/options/commands/setup.py index 3fffab2102..fb57a0cb6c 100644 --- a/aiida/cmdline/params/options/commands/setup.py +++ b/aiida/cmdline/params/options/commands/setup.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Reusable command line interface options for the setup commands.""" - import functools import getpass import hashlib @@ -229,6 +228,18 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume default=DEFAULT_DBINFO['password'], ) +QUICKSETUP_BROKER_PROTOCOL = options.BROKER_PROTOCOL + +QUICKSETUP_BROKER_USERNAME = options.BROKER_USERNAME + +QUICKSETUP_BROKER_PASSWORD = options.BROKER_PASSWORD + +QUICKSETUP_BROKER_HOST = options.BROKER_HOST + +QUICKSETUP_BROKER_PORT = options.BROKER_PORT + +QUICKSETUP_BROKER_VIRTUAL_HOST = options.BROKER_VIRTUAL_HOST + QUICKSETUP_REPOSITORY_URI = options.REPOSITORY_PATH.clone( callback=get_quicksetup_repository_uri # Cannot use `default` because `ctx` is needed to determine the default ) @@ -278,6 +289,48 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume cls=options.interactive.InteractiveOption ) +SETUP_BROKER_PROTOCOL = QUICKSETUP_BROKER_PROTOCOL.clone( + prompt='Broker protocol', + required=True, + contextual_default=functools.partial(get_profile_attribute_default, ('broker_protocol', None)), + cls=options.interactive.InteractiveOption +) + +SETUP_BROKER_USERNAME = QUICKSETUP_BROKER_USERNAME.clone( + prompt='Broker username', + required=True, + contextual_default=functools.partial(get_profile_attribute_default, ('broker_username', None)), + 
cls=options.interactive.InteractiveOption +) + +SETUP_BROKER_PASSWORD = QUICKSETUP_BROKER_PASSWORD.clone( + prompt='Broker password', + required=True, + contextual_default=functools.partial(get_profile_attribute_default, ('broker_password', None)), + cls=options.interactive.InteractiveOption +) + +SETUP_BROKER_HOST = QUICKSETUP_BROKER_HOST.clone( + prompt='Broker host', + required=True, + contextual_default=functools.partial(get_profile_attribute_default, ('broker_host', None)), + cls=options.interactive.InteractiveOption +) + +SETUP_BROKER_PORT = QUICKSETUP_BROKER_PORT.clone( + prompt='Broker port', + required=True, + contextual_default=functools.partial(get_profile_attribute_default, ('broker_port', None)), + cls=options.interactive.InteractiveOption +) + +SETUP_BROKER_VIRTUAL_HOST = QUICKSETUP_BROKER_VIRTUAL_HOST.clone( + prompt='Broker virtual host name', + required=True, + contextual_default=functools.partial(get_profile_attribute_default, ('broker_virtual_host', None)), + cls=options.interactive.InteractiveOption +) + SETUP_REPOSITORY_URI = QUICKSETUP_REPOSITORY_URI.clone( prompt='Repository directory', callback=None, # Unset the `callback` to define the default, which is instead done by the `contextual_default` diff --git a/aiida/cmdline/params/options/interactive.py b/aiida/cmdline/params/options/interactive.py index 8006d1b6fb..38f50f5e27 100644 --- a/aiida/cmdline/params/options/interactive.py +++ b/aiida/cmdline/params/options/interactive.py @@ -12,7 +12,6 @@ :synopsis: Tools and an option class for interactive parameter entry with additional features such as help lookup. """ - import click from aiida.cmdline.utils import echo @@ -284,6 +283,34 @@ def prompt_callback(self, ctx, param, value): return self.after_callback(ctx, param, value) +class TemplateInteractiveOption(InteractiveOption): + """Sub class of ``InteractiveOption`` that uses template file for input instead of simple inline prompt. + + This is useful for options that need to be able to specify multiline string values. + """ + + def __init__(self, param_decls=None, **kwargs): + """Define the configuration for the multiline template in the keyword arguments. + + :param template: name of the template to use from the ``aiida.cmdline.templates`` directory. + Default is the 'multiline.tpl' template. + :param header: string to put in the header of the template. + :param footer: string to put in the footer of the template. + :param extension: file extension to give to the template file. + """ + self.template = kwargs.pop('template', 'multiline.tpl') + self.header = kwargs.pop('header', '') + self.footer = kwargs.pop('footer', '') + self.extension = kwargs.pop('extension', '') + super().__init__(param_decls=param_decls, **kwargs) + + def prompt_func(self, ctx): + """Replace the basic prompt with a method that opens a template file in an editor.""" + from aiida.cmdline.utils.multi_line_input import edit_multiline_template + kwargs = {'value': self._get_default(ctx) or '', 'header': self.header, 'footer': self.footer} + return edit_multiline_template(self.template, extension=self.extension, **kwargs) + + def opt_prompter(ctx, cmd, givenkwargs, oldvalues=None): """ Prompt interactively for the value of an option of the command with context ``ctx``. 
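(Editor's aside, not part of the patch.) A minimal sketch of how an option built on the new TemplateInteractiveOption is intended to be used, mirroring the keyword arguments of the PREPEND_TEXT option added above. The standalone `demo` command and the prompt wording are assumptions for illustration only; in practice the option is attached to the `verdi code setup` and `verdi computer setup` commands, and whether it actually prompts is governed by AiiDA's interactive-option machinery (e.g. the --non-interactive flag), which is outside this sketch.

import click

from aiida.cmdline.params.options.interactive import TemplateInteractiveOption


@click.command()
@click.option(
    '--prepend-text',
    cls=TemplateInteractiveOption,
    prompt='Prepend script',
    type=click.STRING,
    default='',
    extension='.bash',
    header='Type the bash commands that should be prepended to the executable call.',
    footer='All lines that start with `#=` will be ignored.',
)
def demo(prepend_text):
    """Hypothetical command: when prompted, the default 'multiline.tpl' template (added below) is
    rendered with the header, footer and current value, opened in the user's editor via click.edit,
    and the edited text is returned with all '#=' comment lines stripped."""
    click.echo(prepend_text)
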
diff --git a/aiida/cmdline/templates/multiline.tpl b/aiida/cmdline/templates/multiline.tpl new file mode 100644 index 0000000000..beaee1ead3 --- /dev/null +++ b/aiida/cmdline/templates/multiline.tpl @@ -0,0 +1,9 @@ +{% if header %}#={{'='*72}}=# +#= {{header|wordwrap(71, wrapstring='\n#= ')}} +#={{'='*72}}=# +{% endif %} +{{value}} +{% if footer %}#={{'='*72}}=# +#= {{footer|wordwrap(71, wrapstring='\n#= ')}} +#={{'='*72}}=# +{% endif %} diff --git a/aiida/cmdline/templates/multival.tpl b/aiida/cmdline/templates/multival.tpl deleted file mode 100644 index eaeb4704dc..0000000000 --- a/aiida/cmdline/templates/multival.tpl +++ /dev/null @@ -1,10 +0,0 @@ -{% for title, defaults in docs %} -#={{'='*50}}=# -#= {{title|center(48)}} =# -#={{'='*50}}=# - -{{defaults}} - -{% endfor %} -#={{'='*50}}=# -{{'#= ' + helpmsg|wordwrap(50, wrapstring='\n#= ')}} diff --git a/aiida/cmdline/templates/prepost.bash.tpl b/aiida/cmdline/templates/prepost.bash.tpl deleted file mode 100644 index a0c174bad8..0000000000 --- a/aiida/cmdline/templates/prepost.bash.tpl +++ /dev/null @@ -1,19 +0,0 @@ -#={{'='*50}}=# -#= {{'Pre execution script'|center(48)}} =# -#={{'='*50}}=# - -{{default_pre}} - -{{separator}} -{{default_post}} - -#={{'='*50}}=# -{{('#= Lines starting with "#=" will be ignored! Pre and post execution scripts are executed on ' - 'the remote computer before and after execution of the code(s), respectively. AiiDA expects ' - 'valid bash code.')|wordwrap(50, wrapstring='\n#= ')}} -#= -#={{'='*50}}=# -#= {{'Summary of config so far'|center(48)}} =# -#={{'='*50}}=# -{% for k, v in summary.items() %}#= {{k.ljust(20)}}: {{v}} -{% endfor %} diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py index 531ed78eec..afaf288ffc 100644 --- a/aiida/cmdline/utils/common.py +++ b/aiida/cmdline/utils/common.py @@ -123,7 +123,7 @@ def get_node_summary(node): pass else: if computer is not None: - table.append(['computer', '[{}] {}'.format(node.computer.pk, node.computer.name)]) + table.append(['computer', '[{}] {}'.format(node.computer.pk, node.computer.label)]) return tabulate(table, headers=table_headers) @@ -455,7 +455,7 @@ def build_entries(ports): click.secho('{:>{width_name}d}: {}'.format(exit_code.status, message, width_name=max_width_name)) -def get_num_workers(): #pylint: disable=inconsistent-return-statements +def get_num_workers(): """ Get the number of active daemon workers from the circus client """ diff --git a/aiida/cmdline/utils/multi_line_input.py b/aiida/cmdline/utils/multi_line_input.py index 32632fa616..97c3e440de 100644 --- a/aiida/cmdline/utils/multi_line_input.py +++ b/aiida/cmdline/utils/multi_line_input.py @@ -7,77 +7,31 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -utilities for getting multi line input from the commandline -""" -import click -from aiida.common.exceptions import InputValidationError - - -def ensure_scripts(pre, post, summary): - """ - A function to check if the prepend and append scripts were specified, and - if needed ask to edit them. 
+"""Utilities for getting multi line input from the commandline.""" +import re - :param pre: prepend-text - :param post: append-text - :param summary: summary for click template - :return: - """ - if not pre or not post: - return edit_pre_post(pre, post, summary) +import click - return pre, post +def edit_multiline_template(template_name, comment_marker='#=', extension=None, **kwargs): + """Open a template file for editing in a text editor. -def edit_pre_post(pre=None, post=None, summary=None): - """ - use click to call up an editor to write or edit pre / post - execution scripts for both codes and computers + :param template: name of the template to use from the ``aiida.cmdline.templates`` directory. + :param comment_marker: the set of symbols that mark a comment line that should be stripped from the final value + :param extension: the file extension to give to the rendered template file. + :param kwargs: keywords that will be passed to the template rendering engine. + :return: the final string value entered in the editor with all comment lines stripped. """ from aiida.cmdline.utils.templates import env - template = env.get_template('prepost.bash.tpl') - summary = summary or {} - summary = {k: v for k, v in summary.items() if v} - - # Define a separator that will be splitting pre- and post- execution - # parts of the submission script - separator = '#====================================================#\n' \ - '#= Post execution script =#\n' \ - '#= I am acting as a separator, do not modify me!!! =#\n' \ - '#====================================================#\n' + template = env.get_template(template_name) + rendered = template.render(**kwargs) + content = click.edit(rendered, extension=extension) - content = template.render(default_pre=pre or '', separator=separator, default_post=post or '', summary=summary) - mlinput = click.edit(content, extension='.bash') - if mlinput: - import re + if content: + # Remove all comments, which are all lines that start with the comment marker + value = re.sub(r'(^' + re.escape(comment_marker) + '.*$\n)+', '', content, flags=re.M).strip() - # Splitting the text in pre- and post- halfs - try: - pre, post = mlinput.split(separator) - except ValueError as err: - if str(err) == 'need more than 1 value to unpack': - raise InputValidationError( - 'Looks like you modified the ' - 'separator that should NOT be modified. Please be ' - 'careful!' - ) - elif str(err) == 'too many values to unpack': - raise InputValidationError( - 'Looks like you have more than one ' - 'separator, while only one is needed ' - '(and allowed). Please be careful!' 
- ) - else: - raise err - - # Removing all the comments starting from '#=' in both pre- and post- - # parts - pre = re.sub(r'(^#=.*$\n)+', '', pre, flags=re.M).strip() - post = re.sub(r'(^#=.*$\n)+', '', post, flags=re.M).strip() - else: - pre, post = (pre or '', post or '') - return pre, post + return value def edit_comment(old_cmt=''): @@ -89,7 +43,6 @@ def edit_comment(old_cmt=''): content = template.render(old_comment=old_cmt) mlinput = click.edit(content, extension='.txt') if mlinput: - import re regex = r'^(?!#=)(.*)$' cmt = '\n'.join(re.findall(regex, mlinput, flags=re.M)) cmt = cmt.strip('\n') diff --git a/aiida/cmdline/utils/query/calculation.py b/aiida/cmdline/utils/query/calculation.py index 6cdc7d8c57..b2026baebf 100644 --- a/aiida/cmdline/utils/query/calculation.py +++ b/aiida/cmdline/utils/query/calculation.py @@ -47,6 +47,7 @@ def get_filters( all_entries=False, process_state=None, process_label=None, + paused=False, exit_status=None, failed=False, node_types=None @@ -58,6 +59,7 @@ def get_filters( :param all_entries: boolean to negate filtering for process state :param process_state: filter for this process state attribute :param process_label: filter for this process label attribute + :param paused: boolean, if True, filter for processes that are paused :param exit_status: filter for this exit status :param failed: boolean to filter only failed processes :return: dictionary of filters suitable for a QueryBuilder.append() call @@ -68,6 +70,7 @@ def get_filters( exit_status_attribute = self.mapper.get_attribute('exit_status') process_label_attribute = self.mapper.get_attribute('process_label') process_state_attribute = self.mapper.get_attribute('process_state') + paused_attribute = self.mapper.get_attribute('paused') filters = {} @@ -85,6 +88,9 @@ def get_filters( else: filters[process_label_attribute] = process_label + if paused: + filters[paused_attribute] = True + if failed: filters[process_state_attribute] = {'==': ProcessState.FINISHED.value} filters[exit_status_attribute] = {'>': 0} diff --git a/aiida/cmdline/utils/query/mapping.py b/aiida/cmdline/utils/query/mapping.py index ffed7028ad..f469c75dfb 100644 --- a/aiida/cmdline/utils/query/mapping.py +++ b/aiida/cmdline/utils/query/mapping.py @@ -12,8 +12,7 @@ class ProjectionMapper: - """ - Class to map projection names from the CLI to entity labels, attributes and formatters. + """Class to map projection names from the CLI to entity labels, attributes and formatters. The command line interface will often have to display database entities and their attributes. 
The names of the attributes exposed on the CLI do not always match one-to-one with the attributes in the ORM and often @@ -28,7 +27,6 @@ class ProjectionMapper: _valid_projections = [] def __init__(self, projection_labels=None, projection_attributes=None, projection_formatters=None): - # pylint: disable=unused-variable,undefined-variable if not self._valid_projections: raise NotImplementedError('no valid projections were specified by the sub class') @@ -108,7 +106,6 @@ def __init__(self, projections, projection_labels=None, projection_attributes=No 'exit_status': exit_status_key, } - # pylint: disable=line-too-long default_formatters = { 'ctime': lambda value: formatting.format_relative_time(value['ctime']), diff --git a/aiida/cmdline/utils/repository.py b/aiida/cmdline/utils/repository.py index 0732f7f873..161f56845f 100644 --- a/aiida/cmdline/utils/repository.py +++ b/aiida/cmdline/utils/repository.py @@ -18,9 +18,9 @@ def list_repository_contents(node, path, color): :param path: directory path :raises FileNotFoundError: if the `path` does not exist in the repository of the given node """ - from aiida.orm.utils.repository import FileType + from aiida.repository import FileType for entry in node.list_objects(path): - bold = bool(entry.type == FileType.DIRECTORY) - fg = 'blue' if color and entry.type == FileType.DIRECTORY else None + bold = bool(entry.file_type == FileType.DIRECTORY) + fg = 'blue' if color and entry.file_type == FileType.DIRECTORY else None click.secho(entry.name, bold=bold, fg=fg) diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py index 61bd147e52..e10a7cca22 100644 --- a/aiida/common/datastructures.py +++ b/aiida/common/datastructures.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to define commonly used data structures.""" - from enum import Enum, IntEnum from .extendeddicts import DefaultFieldsAttributeDict @@ -78,7 +77,7 @@ class CalcInfo(DefaultFieldsAttributeDict): """ _default_fields = ( - 'job_environment', # TODO UNDERSTAND THIS! + 'job_environment', 'email', 'email_on_started', 'email_on_terminated', diff --git a/aiida/common/hashing.py b/aiida/common/hashing.py index 7b66b24ffc..b7d59eef30 100644 --- a/aiida/common/hashing.py +++ b/aiida/common/hashing.py @@ -58,7 +58,7 @@ using_sysrandom = False # pylint: disable=invalid-name -def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz' 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'): +def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'): """ Returns a securely generated random string. 
diff --git a/aiida/engine/daemon/client.py b/aiida/engine/daemon/client.py index 7cafc5f2ef..e239a51728 100644 --- a/aiida/engine/daemon/client.py +++ b/aiida/engine/daemon/client.py @@ -17,7 +17,7 @@ import socket import tempfile -from aiida.manage.configuration import get_config +from aiida.manage.configuration import get_config, get_config_option VERDI_BIN = shutil.which('verdi') # Recent versions of virtualenv create the environment variable VIRTUAL_ENV @@ -63,7 +63,6 @@ class DaemonClient: # pylint: disable=too-many-public-methods DAEMON_ERROR_TIMEOUT = 'daemon-error-timeout' _DAEMON_NAME = 'aiida-{name}' - _DEFAULT_LOGLEVEL = 'INFO' _ENDPOINT_PROTOCOL = ControllerProtocol.IPC def __init__(self, profile): @@ -103,7 +102,7 @@ def cmd_string(self): @property def loglevel(self): - return self._DEFAULT_LOGLEVEL + return get_config_option('logging.circus_loglevel') @property def virtualenv(self): diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index 1fc1613a92..d823ea194f 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -14,9 +14,9 @@ plugin-specific operations. """ import os +import shutil from aiida.common import AIIDA_LOGGER, exceptions -from aiida.common.datastructures import CalcJobState from aiida.common.folders import SandboxFolder from aiida.common.links import LinkType from aiida.orm import FolderData, Node @@ -37,6 +37,7 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= :param calc_info: the calculation info datastructure returned by `CalcJob.presubmit` :param folder: temporary local file system folder containing the inputs written by `CalcJob.prepare_for_submission` """ + # pylint: disable=too-many-locals,too-many-branches,too-many-statements from logging import LoggerAdapter from tempfile import NamedTemporaryFile from aiida.orm import load_node, Code, RemoteData @@ -59,21 +60,23 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= logger = LoggerAdapter(logger=execlogger, extra=logger_extra) if not dry_run and node.has_cached_links(): - raise ValueError('Cannot submit calculation {} because it has cached input links! If you just want to test the ' - 'submission, set `metadata.dry_run` to True in the inputs.'.format(node.pk)) + raise ValueError( + 'Cannot submit calculation {} because it has cached input links! 
If you just want to test the ' + 'submission, set `metadata.dry_run` to True in the inputs.'.format(node.pk) + ) # If we are performing a dry-run, the working directory should actually be a local folder that should already exist if dry_run: workdir = transport.getcwd() else: remote_user = transport.whoami() - # TODO Doc: {username} field - # TODO: if something is changed here, fix also 'verdi computer test' remote_working_directory = computer.get_workdir().format(username=remote_user) if not remote_working_directory.strip(): raise exceptions.ConfigurationError( "[submission of calculation {}] No remote_working_directory configured for computer '{}'".format( - node.pk, computer.name)) + node.pk, computer.label + ) + ) # If it already exists, no exception is raised try: @@ -81,7 +84,9 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= except IOError: logger.debug( '[submission of calculation {}] Unable to chdir in {}, trying to create it'.format( - node.pk, remote_working_directory)) + node.pk, remote_working_directory + ) + ) try: transport.makedirs(remote_working_directory) transport.chdir(remote_working_directory) @@ -89,8 +94,8 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= raise exceptions.ConfigurationError( '[submission of calculation {}] ' 'Unable to create the remote directory {} on ' - "computer '{}': {}".format( - node.pk, remote_working_directory, computer.name, exc)) + "computer '{}': {}".format(node.pk, remote_working_directory, computer.label, exc) + ) # Store remotely with sharding (here is where we choose # the folder structure of remote jobs; then I store this # in the calculation properties using _set_remote_dir @@ -112,8 +117,11 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= path_existing = os.path.join(transport.getcwd(), calc_info.uuid[4:]) path_lost_found = os.path.join(remote_working_directory, REMOTE_WORK_DIRECTORY_LOST_FOUND) path_target = os.path.join(path_lost_found, calc_info.uuid) - logger.warning('tried to create path {} but it already exists, moving the entire folder to {}'.format( - path_existing, path_target)) + logger.warning( + 'tried to create path {} but it already exists, moving the entire folder to {}'.format( + path_existing, path_target + ) + ) # Make sure the lost+found directory exists, then copy the existing folder there and delete the original transport.mkdir(path_lost_found, ignore_existing=True) @@ -136,27 +144,22 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= for code in input_codes: if code.is_local(): # Note: this will possibly overwrite files - for f in code.list_object_names(): + for filename in code.list_object_names(): # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in # combination with the new `Transport.put_object_from_filelike` # Since the content of the node could potentially be binary, we read the raw bytes and pass them on with NamedTemporaryFile(mode='wb+') as handle: - handle.write(code.get_object_content(f, mode='rb')) + handle.write(code.get_object_content(filename, mode='rb')) handle.flush() - transport.put(handle.name, f) + transport.put(handle.name, filename) transport.chmod(code.get_local_executable(), 0o755) # rwxr-xr-x - # In a dry_run, the working directory is the raw input folder, which will already contain these resources - if not dry_run: - for filename in folder.get_content_list(): - logger.debug('[submission of calculation 
{}] copying file/folder {}...'.format(node.pk, filename)) - transport.put(folder.get_abs_path(filename), filename) - # local_copy_list is a list of tuples, each with (uuid, dest_rel_path) # NOTE: validation of these lists are done inside calculation.presubmit() local_copy_list = calc_info.local_copy_list or [] remote_copy_list = calc_info.remote_copy_list or [] remote_symlink_list = calc_info.remote_symlink_list or [] + provenance_exclude_list = calc_info.provenance_exclude_list or [] for uuid, filename, target in local_copy_list: logger.debug('[submission of calculation {}] copying local file/folder to {}'.format(node.uuid, target)) @@ -168,10 +171,10 @@ def find_data_node(inputs, uuid): :param uuid: UUID of the node to find :return: instance of `Node` or `None` if not found """ - from collections import Mapping + from collections.abc import Mapping data_node = None - for link_label, input_node in inputs.items(): + for input_node in inputs.values(): if isinstance(input_node, Mapping): data_node = find_data_node(input_node, uuid) elif isinstance(input_node, Node) and input_node.uuid == uuid: @@ -189,59 +192,80 @@ def find_data_node(inputs, uuid): if data_node is None: logger.warning('failed to load Node<{}> specified in the `local_copy_list`'.format(uuid)) else: - # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in - # combination with the new `Transport.put_object_from_filelike` - # Since the content of the node could potentially be binary, we read the raw bytes and pass them on - with NamedTemporaryFile(mode='wb+') as handle: - handle.write(data_node.get_object_content(filename, mode='rb')) - handle.flush() - transport.put(handle.name, target) - - if dry_run: - if remote_copy_list: - with open(os.path.join(workdir, '_aiida_remote_copy_list.txt'), 'w') as handle: - for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: - handle.write('would have copied {} to {} in working directory on remote {}'.format( - remote_abs_path, dest_rel_path, computer.name)) + dirname = os.path.dirname(target) + if dirname: + os.makedirs(os.path.join(folder.abspath, dirname), exist_ok=True) + with folder.open(target, 'wb') as handle: + with data_node.open(filename, 'rb') as source: + shutil.copyfileobj(source, handle) + provenance_exclude_list.append(target) - if remote_symlink_list: - with open(os.path.join(workdir, '_aiida_remote_symlink_list.txt'), 'w') as handle: - for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_symlink_list: - handle.write('would have created symlinks from {} to {} in working directory on remote {}'.format( - remote_abs_path, dest_rel_path, computer.name)) - - else: + # In a dry_run, the working directory is the raw input folder, which will already contain these resources + if not dry_run: + for filename in folder.get_content_list(): + logger.debug('[submission of calculation {}] copying file/folder {}...'.format(node.pk, filename)) + transport.put(folder.get_abs_path(filename), filename) for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_copy_list: if remote_computer_uuid == computer.uuid: - logger.debug('[submission of calculation {}] copying {} remotely, directly on the machine {}'.format( - node.pk, dest_rel_path, computer.name)) + logger.debug( + '[submission of calculation {}] copying {} remotely, directly on the machine {}'.format( + node.pk, dest_rel_path, computer.label + ) + ) try: transport.copy(remote_abs_path, dest_rel_path) except (IOError, OSError): - 
logger.warning('[submission of calculation {}] Unable to copy remote resource from {} to {}! ' - 'Stopping.'.format(node.pk, remote_abs_path, dest_rel_path)) + logger.warning( + '[submission of calculation {}] Unable to copy remote resource from {} to {}! ' + 'Stopping.'.format(node.pk, remote_abs_path, dest_rel_path) + ) raise else: raise NotImplementedError( '[submission of calculation {}] Remote copy between two different machines is ' - 'not implemented yet'.format(node.pk)) + 'not implemented yet'.format(node.pk) + ) for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_symlink_list: if remote_computer_uuid == computer.uuid: - logger.debug('[submission of calculation {}] copying {} remotely, directly on the machine {}'.format( - node.pk, dest_rel_path, computer.name)) + logger.debug( + '[submission of calculation {}] copying {} remotely, directly on the machine {}'.format( + node.pk, dest_rel_path, computer.label + ) + ) try: transport.symlink(remote_abs_path, dest_rel_path) except (IOError, OSError): - logger.warning('[submission of calculation {}] Unable to create remote symlink from {} to {}! ' - 'Stopping.'.format(node.pk, remote_abs_path, dest_rel_path)) + logger.warning( + '[submission of calculation {}] Unable to create remote symlink from {} to {}! ' + 'Stopping.'.format(node.pk, remote_abs_path, dest_rel_path) + ) raise else: - raise IOError('It is not possible to create a symlink between two different machines for ' - 'calculation {}'.format(node.pk)) + raise IOError( + 'It is not possible to create a symlink between two different machines for ' + 'calculation {}'.format(node.pk) + ) + else: - provenance_exclude_list = calc_info.provenance_exclude_list or [] + if remote_copy_list: + with open(os.path.join(workdir, '_aiida_remote_copy_list.txt'), 'w') as handle: + for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: + handle.write( + 'would have copied {} to {} in working directory on remote {}'.format( + remote_abs_path, dest_rel_path, computer.label + ) + ) + + if remote_symlink_list: + with open(os.path.join(workdir, '_aiida_remote_symlink_list.txt'), 'w') as handle: + for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_symlink_list: + handle.write( + 'would have created symlinks from {} to {} in working directory on remote {}'.format( + remote_abs_path, dest_rel_path, computer.label + ) + ) # Loop recursively over content of the sandbox folder copying all that are not in `provenance_exclude_list`. Note # that directories are not created explicitly. The `node.put_object_from_filelike` call will create intermediate @@ -250,13 +274,13 @@ def find_data_node(inputs, uuid): # advantage of this explicit copying instead of deleting the files from `provenance_exclude_list` from the sandbox # first before moving the entire remaining content to the node's repository, is that in this way we are guaranteed # not to accidentally move files to the repository that should not go there at all cost. 
- for root, dirnames, filenames in os.walk(folder.abspath): + for root, _, filenames in os.walk(folder.abspath): for filename in filenames: filepath = os.path.join(root, filename) relpath = os.path.relpath(filepath, folder.abspath) if relpath not in provenance_exclude_list: with open(filepath, 'rb') as handle: - node.put_object_from_filelike(handle, relpath, 'wb', force=True) + node._repository.put_object_from_filelike(handle, relpath, 'wb', force=True) # pylint: disable=protected-access if not dry_run: # Make sure that attaching the `remote_folder` with a link is the last thing we do. This gives the biggest @@ -319,8 +343,9 @@ def retrieve_calculation(calculation, transport, retrieved_temporary_folder): # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the retrieval. link_label = calculation.link_label_retrieved if calculation.get_outgoing(FolderData, link_label_filter=link_label).first(): - execlogger.warning('CalcJobNode<{}> already has a `{}` output folder: skipping retrieval'.format( - calculation.pk, link_label)) + execlogger.warning( + 'CalcJobNode<{}> already has a `{}` output folder: skipping retrieval'.format(calculation.pk, link_label) + ) return # Create the FolderData node into which to store the files that are to be retrieved @@ -351,14 +376,17 @@ def retrieve_calculation(calculation, transport, retrieved_temporary_folder): # Log the files that were retrieved in the temporary folder for filename in os.listdir(retrieved_temporary_folder): - execlogger.debug("[retrieval of calc {}] Retrieved temporary file or folder '{}'".format( - calculation.pk, filename), extra=logger_extra) + execlogger.debug( + "[retrieval of calc {}] Retrieved temporary file or folder '{}'".format(calculation.pk, filename), + extra=logger_extra + ) # Store everything execlogger.debug( '[retrieval of calc {}] ' 'Storing retrieved_files={}'.format(calculation.pk, retrieved_files.pk), - extra=logger_extra) + extra=logger_extra + ) retrieved_files.store() # Make sure that attaching the `retrieved` folder with a link is the last thing we do. 
This gives the biggest chance @@ -398,72 +426,15 @@ def kill_calculation(calculation, transport): return True -def parse_results(process, retrieved_temporary_folder=None): - """ - Parse the results for a given CalcJobNode (job) - - :returns: integer exit code, where 0 indicates success and non-zero failure - """ - from aiida.engine import ExitCode - - assert process.node.get_state() == CalcJobState.PARSING, \ - 'job should be in the PARSING state when calling this function yet it is {}'.format(process.node.get_state()) - - parser_class = process.node.get_parser_class() - exit_code = ExitCode() - logger_extra = get_dblogger_extra(process.node) - - if retrieved_temporary_folder: - files = [] - for root, directories, filenames in os.walk(retrieved_temporary_folder): - for directory in directories: - files.append('- [D] {}'.format(os.path.join(root, directory))) - for filename in filenames: - files.append('- [F] {}'.format(os.path.join(root, filename))) - - execlogger.debug('[parsing of calc {}] ' - 'Content of the retrieved_temporary_folder: \n' - '{}'.format(process.node.pk, '\n'.join(files)), extra=logger_extra) - else: - execlogger.debug('[parsing of calc {}] ' - 'No retrieved_temporary_folder.'.format(process.node.pk), extra=logger_extra) - - if parser_class is not None: - - parser = parser_class(process.node) - parse_kwargs = parser.get_outputs_for_parsing() - - if retrieved_temporary_folder: - parse_kwargs['retrieved_temporary_folder'] = retrieved_temporary_folder - - exit_code = parser.parse(**parse_kwargs) - - if exit_code is None: - exit_code = ExitCode(0) - - if not isinstance(exit_code, ExitCode): - raise ValueError('parse should return an `ExitCode` or None, and not {}'.format(type(exit_code))) - - if exit_code.status: - parser.logger.error('parser returned exit code<{}>: {}'.format(exit_code.status, exit_code.message)) - - for link_label, node in parser.outputs.items(): - try: - process.out(link_label, node) - except ValueError as exception: - parser.logger.error('invalid value {} specified with label {}: {}'.format(node, link_label, exception)) - exit_code = process.exit_codes.ERROR_INVALID_OUTPUT - break - - return exit_code - - def _retrieve_singlefiles(job, transport, folder, retrieve_file_list, logger_extra=None): + """Retrieve files specified through the singlefile list mechanism.""" singlefile_list = [] for (linkname, subclassname, filename) in retrieve_file_list: - execlogger.debug('[retrieval of calc {}] Trying ' - "to retrieve remote singlefile '{}'".format( - job.pk, filename), extra=logger_extra) + execlogger.debug( + '[retrieval of calc {}] Trying ' + "to retrieve remote singlefile '{}'".format(job.pk, filename), + extra=logger_extra + ) localfilename = os.path.join(folder.abspath, os.path.split(filename)[1]) transport.get(filename, localfilename, ignore_nonexisting=True) singlefile_list.append((linkname, subclassname, localfilename)) @@ -474,16 +445,16 @@ def _retrieve_singlefiles(job, transport, folder, retrieve_file_list, logger_ext # after retrieving from the cluster, I create the objects singlefiles = [] for (linkname, subclassname, filename) in singlefile_list: - SinglefileSubclass = DataFactory(subclassname) - singlefile = SinglefileSubclass(file=filename) + cls = DataFactory(subclassname) + singlefile = cls(file=filename) singlefile.add_incoming(job, link_type=LinkType.CREATE, link_label=linkname) singlefiles.append(singlefile) for fil in singlefiles: execlogger.debug( '[retrieval of calc {}] ' - 'Storing retrieved_singlefile={}'.format(job.pk, fil.pk), - 
extra=logger_extra) + 'Storing retrieved_singlefile={}'.format(job.pk, fil.pk), extra=logger_extra + ) fil.store() @@ -507,12 +478,12 @@ def retrieve_files_from_list(calculation, transport, folder, retrieve_list): treated as the work directory of the folder and the depth integer determines upto what level of the original remotepath nesting the files will be copied. - :param transport: the Transport instance - :param folder: an absolute path to a folder to copy files in - :param retrieve_list: the list of files to retrieve + :param transport: the Transport instance. + :param folder: an absolute path to a folder that contains the files to copy. + :param retrieve_list: the list of files to retrieve. """ for item in retrieve_list: - if isinstance(item, list): + if isinstance(item, (list, tuple)): tmp_rname, tmp_lname, depth = item # if there are more than one file I do something differently if transport.has_magic(tmp_rname): @@ -523,13 +494,11 @@ def retrieve_files_from_list(calculation, transport, folder, retrieve_list): local_names.append(os.path.sep.join([tmp_lname] + to_append)) else: remote_names = [tmp_rname] - to_append = remote_names.split(os.path.sep)[-depth:] if depth > 0 else [] + to_append = tmp_rname.split(os.path.sep)[-depth:] if depth > 0 else [] local_names = [os.path.sep.join([tmp_lname] + to_append)] if depth > 1: # create directories in the folder, if needed for this_local_file in local_names: - new_folder = os.path.join( - folder, - os.path.split(this_local_file)[0]) + new_folder = os.path.join(folder, os.path.split(this_local_file)[0]) if not os.path.exists(new_folder): os.makedirs(new_folder) else: # it is a string @@ -542,5 +511,6 @@ def retrieve_files_from_list(calculation, transport, folder, retrieve_list): for rem, loc in zip(remote_names, local_names): transport.logger.debug( - "[retrieval of calc {}] Trying to retrieve remote item '{}'".format(calculation.pk, rem)) + "[retrieval of calc {}] Trying to retrieve remote item '{}'".format(calculation.pk, rem) + ) transport.get(rem, os.path.join(folder, loc), ignore_nonexisting=True) diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index 6f686355fd..39dd934868 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -19,6 +19,7 @@ from aiida.common.lang import override, classproperty from aiida.common.links import LinkType +from ..exit_code import ExitCode from ..process import Process, ProcessState from ..process_spec import CalcJobProcessSpec from .tasks import Waiting, UPLOAD_COMMAND @@ -26,7 +27,7 @@ __all__ = ('CalcJob',) -def validate_calc_job(inputs, ctx): # pylint: disable=inconsistent-return-statements,too-many-return-statements +def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statements """Validate the entire set of inputs passed to the `CalcJob` constructor. 
Reasons that will cause this validation to raise an `InputValidationError`: @@ -81,14 +82,15 @@ def validate_calc_job(inputs, ctx): # pylint: disable=inconsistent-return-state except KeyError: return 'input `metadata.options.resources` is required but is not specified' + scheduler.preprocess_resources(resources, computer.get_default_mpiprocs_per_machine()) + try: - scheduler.preprocess_resources(resources, computer.get_default_mpiprocs_per_machine()) scheduler.validate_resources(**resources) - except (ValueError, TypeError) as exception: - return 'input `metadata.options.resources` is not valid for the {} scheduler: {}'.format(scheduler, exception) + except ValueError as exception: + return 'input `metadata.options.resources` is not valid for the `{}` scheduler: {}'.format(scheduler, exception) -def validate_parser(parser_name, _): # pylint: disable=inconsistent-return-statements +def validate_parser(parser_name, _): """Validate the parser. :return: string with error message in case the inputs are invalid @@ -191,6 +193,14 @@ def define(cls, spec: CalcJobProcessSpec): help='Files that are retrieved by the daemon will be stored in this node. By default the stdout and stderr ' 'of the scheduler will be added, but one can add more by specifying them in `CalcInfo.retrieve_list`.') + # Errors caused or returned by the scheduler + spec.exit_code(100, 'ERROR_NO_RETRIEVED_FOLDER', + message='The process did not have the required `retrieved` output.') + spec.exit_code(110, 'ERROR_SCHEDULER_OUT_OF_MEMORY', + message='The job ran out of memory.') + spec.exit_code(120, 'ERROR_SCHEDULER_OUT_OF_WALLTIME', + message='The job ran out of walltime.') + @classproperty def spec_options(cls): # pylint: disable=no-self-argument """Return the metadata options port namespace of the process specification of this process. @@ -280,22 +290,118 @@ def parse(self, retrieved_temporary_folder=None): This is called once it's finished waiting for the calculation to be finished and the data has been retrieved. """ import shutil - from aiida.engine.daemon import execmanager try: - exit_code = execmanager.parse_results(self, retrieved_temporary_folder) + retrieved = self.node.outputs.retrieved + except exceptions.NotExistent: + return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER # pylint: disable=no-member + + # Call the scheduler output parser + exit_code_scheduler = self.parse_scheduler_output(retrieved) + + if exit_code_scheduler is not None and exit_code_scheduler.status > 0: + # If an exit code is returned by the scheduler output parser, we log it and set it on the node. This will + # allow the actual `Parser` implementation, if defined in the inputs, to inspect it and decide to keep it, + # or override it with a more specific exit code, if applicable. 
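An illustrative sketch of how a plugin parser might act on the exit status that the scheduler parser pre-sets on the node; the parser class and its specific exit code label are hypothetical, while ``self.node.exit_status``, the ``ERROR_SCHEDULER_OUT_OF_MEMORY`` status ``110`` and the convention that returning ``None`` keeps the scheduler's verdict follow from the surrounding code::

    from aiida.parsers import Parser

    class SketchParser(Parser):
        """Hypothetical parser that refines a generic scheduler error into a plugin-specific exit code."""

        def parse(self, **kwargs):
            # At this point the engine has already set the scheduler parser's exit status on the node.
            if self.node.exit_status == 110:  # ERROR_SCHEDULER_OUT_OF_MEMORY defined on the CalcJob spec
                # Hypothetical, more specific exit code that the calculation plugin defines on its process spec.
                return self.exit_codes.ERROR_OUT_OF_MEMORY_IN_SCF
            # Returning None keeps whatever exit code the scheduler parser already set, if any.
            return None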
+ args = (exit_code_scheduler.status, exit_code_scheduler.message) + self.logger.warning('scheduler parser returned exit code<{}>: {}'.format(*args)) + self.node.set_exit_status(exit_code_scheduler.status) + self.node.set_exit_message(exit_code_scheduler.message) + + # Call the retrieved output parser + try: + exit_code_retrieved = self.parse_retrieved_output(retrieved_temporary_folder) finally: - # Delete the temporary folder - try: - shutil.rmtree(retrieved_temporary_folder) - except OSError as exception: - if exception.errno != 2: - raise + shutil.rmtree(retrieved_temporary_folder, ignore_errors=True) + + if exit_code_retrieved is not None and exit_code_retrieved.status > 0: + args = (exit_code_retrieved.status, exit_code_retrieved.message) + self.logger.warning('output parser returned exit code<{}>: {}'.format(*args)) + + # The final exit code is that of the scheduler, unless the output parser returned one + if exit_code_retrieved is not None: + exit_code = exit_code_retrieved + else: + exit_code = exit_code_scheduler # Finally link up the outputs and we're done for entry in self.node.get_outgoing(): self.out(entry.link_label, entry.node) + return exit_code or ExitCode(0) + + def parse_scheduler_output(self, retrieved): + """Parse the output of the scheduler if that functionality has been implemented for the plugin.""" + scheduler = self.node.computer.get_scheduler() + filename_stderr = self.node.get_option('scheduler_stderr') + filename_stdout = self.node.get_option('scheduler_stdout') + + detailed_job_info = self.node.get_detailed_job_info() + + if detailed_job_info is None: + self.logger.info('could not parse scheduler output: the `detailed_job_info` attribute is missing') + elif detailed_job_info.get('retval', 0) != 0: + self.logger.info('could not parse scheduler output: return value of `detailed_job_info` is non-zero') + detailed_job_info = None + + try: + scheduler_stderr = retrieved.get_object_content(filename_stderr) + except FileNotFoundError: + scheduler_stderr = None + self.logger.warning('could not parse scheduler output: the `{}` file is missing'.format(filename_stderr)) + + try: + scheduler_stdout = retrieved.get_object_content(filename_stdout) + except FileNotFoundError: + scheduler_stdout = None + self.logger.warning('could not parse scheduler output: the `{}` file is missing'.format(filename_stdout)) + + # Only attempt to call the scheduler parser if all three sources of information are available + if any(entry is None for entry in [detailed_job_info, scheduler_stderr, scheduler_stdout]): + return + + try: + exit_code = scheduler.parse_output(detailed_job_info, scheduler_stdout, scheduler_stderr) + except exceptions.FeatureNotAvailable: + self.logger.info('`{}` does not implement scheduler output parsing'.format(scheduler.__class__.__name__)) + return + except Exception as exception: # pylint: disable=broad-except + self.logger.error('the `parse_output` method of the scheduler excepted: {}'.format(exception)) + return + + if exit_code is not None and not isinstance(exit_code, ExitCode): + args = (scheduler.__class__.__name__, type(exit_code)) + raise ValueError('`{}.parse_output` returned neither an `ExitCode` nor None, but: {}'.format(*args)) + + return exit_code + + def parse_retrieved_output(self, retrieved_temporary_folder=None): + """Parse the retrieved data by calling the parser plugin if it was defined in the inputs.""" + parser_class = self.node.get_parser_class() + + if parser_class is None: + return + + parser = parser_class(self.node) + parse_kwargs =
parser.get_outputs_for_parsing() + + if retrieved_temporary_folder: + parse_kwargs['retrieved_temporary_folder'] = retrieved_temporary_folder + + exit_code = parser.parse(**parse_kwargs) + + for link_label, node in parser.outputs.items(): + try: + self.out(link_label, node) + except ValueError as exception: + self.logger.error('invalid value {} specified with label {}: {}'.format(node, link_label, exception)) + exit_code = self.exit_codes.ERROR_INVALID_OUTPUT # pylint: disable=no-member + break + + if exit_code is not None and not isinstance(exit_code, ExitCode): + args = (parser_class.__name__, type(exit_code)) + raise ValueError('`{}.parse` returned neither an `ExitCode` nor None, but: {}'.format(*args)) + return exit_code def presubmit(self, folder): @@ -329,7 +435,7 @@ def presubmit(self, folder): for code in codes: if not code.can_run_on(computer): raise InputValidationError('The selected code {} for calculation {} cannot run on computer {}'.format( - code.pk, self.node.pk, computer.name)) + code.pk, self.node.pk, computer.label)) if code.is_local() and code.get_local_executable() in folder.get_content_list(): raise PluginInternalError('The plugin created a file {} that is also the executable name!'.format( diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 112e3a5058..c0958e02d5 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Transport tasks for calculation jobs.""" import functools import logging import tempfile @@ -87,7 +88,8 @@ def do_upload(): logger.info('scheduled request to upload CalcJob<{}>'.format(node.pk)) ignore_exceptions = (plumpy.CancelledError, PreSubmitException) result = yield exponential_backoff_retry( - do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions) + do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) except PreSubmitException: raise except plumpy.CancelledError: @@ -136,7 +138,8 @@ def do_submit(): try: logger.info('scheduled request to submit CalcJob<{}>'.format(node.pk)) result = yield exponential_backoff_retry( - do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) + do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ) except plumpy.Interruption: pass except Exception: @@ -195,7 +198,8 @@ def do_update(): try: logger.info('scheduled request to update CalcJob<{}>'.format(node.pk)) job_done = yield exponential_backoff_retry( - do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) + do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ) except plumpy.Interruption: raise except Exception: @@ -257,8 +261,9 @@ def do_retrieve(): try: logger.info('scheduled request to retrieve CalcJob<{}>'.format(node.pk)) - result = yield exponential_backoff_retry( - do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) + yield exponential_backoff_retry( + do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ) except plumpy.Interruption: raise except 
Exception: @@ -333,7 +338,8 @@ def load_instance_state(self, saved_state, load_context): @coroutine def execute(self): - + """Override the execute coroutine of the base `Waiting` state.""" + # pylint: disable=too-many-branches node = self.process.node transport_queue = self.process.runner.transport command = self.data @@ -393,6 +399,7 @@ def execute(self): @coroutine def _launch_task(self, coro, *args, **kwargs): + """Launch a coroutine as a task, making sure to make it interruptable.""" task_fn = functools.partial(coro, *args, **kwargs) try: self._task = interruptable_task(task_fn) diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 07e535af3d..35b65022dd 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -20,7 +20,7 @@ __all__ = ('BaseRestartWorkChain',) -def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint: disable=inconsistent-return-statements,unused-argument +def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint: disable=unused-argument """Validator for the `handler_overrides` input port of the `BaseRestartWorkChain. The `handler_overrides` should be a dictionary where keys are strings that are the name of a process handler, i.e. a @@ -170,7 +170,7 @@ def run_process(self): return ToContext(children=append_(node)) - def inspect_process(self): # pylint: disable=inconsistent-return-statements,too-many-branches + def inspect_process(self): # pylint: disable=too-many-branches """Analyse the results of the previous process and call the handlers when necessary. If the process is excepted or killed, the work chain will abort. Otherwise any attached handlers will be called @@ -260,7 +260,7 @@ def inspect_process(self): # pylint: disable=inconsistent-return-statements,too # Otherwise the process was successful and no handler returned anything so we consider the work done self.ctx.is_finished = True - def results(self): # pylint: disable=inconsistent-return-statements + def results(self): """Attach the outputs specified in the output specification from the last completed process.""" node = self.ctx.children[self.ctx.iteration - 1] diff --git a/aiida/engine/processes/workchains/utils.py b/aiida/engine/processes/workchains/utils.py index 45f2158e8b..94fd826e61 100644 --- a/aiida/engine/processes/workchains/utils.py +++ b/aiida/engine/processes/workchains/utils.py @@ -99,7 +99,7 @@ def wrapper(wrapped, instance, args, kwargs): # When the handler will be called by the `BaseRestartWorkChain` it will pass the node as the only argument node = args[0] - if exit_codes and node.exit_status not in [exit_code.status for exit_code in exit_codes]: + if exit_codes is not None and node.exit_status not in [exit_code.status for exit_code in exit_codes]: result = None else: result = wrapped(*args, **kwargs) diff --git a/aiida/manage/backup/backup_general.py b/aiida/manage/backup/backup_general.py index f75d7f41f1..1ec59796ee 100644 --- a/aiida/manage/backup/backup_general.py +++ b/aiida/manage/backup/backup_general.py @@ -15,7 +15,7 @@ from aiida.orm import Node from aiida.manage.backup.backup_base import AbstractBackup, BackupError from aiida.common.folders import RepositoryFolder -from aiida.orm.utils.repository import Repository +from aiida.orm.utils._repository import Repository class Backup(AbstractBackup): diff --git a/aiida/manage/configuration/migrations/migrations.py b/aiida/manage/configuration/migrations/migrations.py index 
85560c6325..7d094e74d3 100644 --- a/aiida/manage/configuration/migrations/migrations.py +++ b/aiida/manage/configuration/migrations/migrations.py @@ -15,7 +15,7 @@ # If the configuration file format is changed, the current version number should be upped and a migration added. # When the configuration file format is changed in a backwards-incompatible way, the oldest compatible version should # be set to the new current version. -CURRENT_CONFIG_VERSION = 3 +CURRENT_CONFIG_VERSION = 4 OLDEST_COMPATIBLE_CONFIG_VERSION = 3 @@ -42,8 +42,7 @@ def apply(self, config): def _1_add_profile_uuid(config): - """ - This adds the required values for a new default profile + """Add the required values for a new default profile. * PROFILE_UUID @@ -58,10 +57,11 @@ def _1_add_profile_uuid(config): def _2_simplify_default_profiles(config): - """ - The concept of a different 'process' for a profile has been removed and as such the - default profiles key in the configuration no longer needs a value per process ('verdi', 'daemon'). - We remove the dictionary 'default_profiles' and replace it with a simple value 'default_profile'. + """Replace process specific default profiles with single default profile key. + + The concept of a different 'process' for a profile has been removed and as such the default profiles key in the + configuration no longer needs a value per process ('verdi', 'daemon'). We remove the dictionary 'default_profiles' + and replace it with a simple value 'default_profile'. """ from aiida.manage.configuration import PROFILE @@ -77,9 +77,31 @@ def _2_simplify_default_profiles(config): return config +def _3_add_message_broker(config): + """Add the configuration for the message broker, which was not configurable up to now.""" + from aiida.manage.external.rmq import BROKER_DEFAULTS + + defaults = [ + ('broker_protocol', BROKER_DEFAULTS.protocol), + ('broker_username', BROKER_DEFAULTS.username), + ('broker_password', BROKER_DEFAULTS.password), + ('broker_host', BROKER_DEFAULTS.host), + ('broker_port', BROKER_DEFAULTS.port), + ('broker_virtual_host', BROKER_DEFAULTS.virtual_host), + ] + + for profile in config.get('profiles', {}).values(): + for key, default in defaults: + if key not in profile: + profile[key] = default + + return config + + # Maps the initial config version to the ConfigMigration which updates it. 
_MIGRATION_LOOKUP = { 0: ConfigMigration(migrate_function=lambda x: x, version=1, version_oldest_compatible=0), 1: ConfigMigration(migrate_function=_1_add_profile_uuid, version=2, version_oldest_compatible=0), - 2: ConfigMigration(migrate_function=_2_simplify_default_profiles, version=3, version_oldest_compatible=3) + 2: ConfigMigration(migrate_function=_2_simplify_default_profiles, version=3, version_oldest_compatible=3), + 3: ConfigMigration(migrate_function=_3_add_message_broker, version=4, version_oldest_compatible=3) } diff --git a/aiida/manage/configuration/profile.py b/aiida/manage/configuration/profile.py index ac5ec4e164..4cba1f0c7a 100644 --- a/aiida/manage/configuration/profile.py +++ b/aiida/manage/configuration/profile.py @@ -41,6 +41,13 @@ class Profile: # pylint: disable=too-many-public-methods KEY_DATABASE_HOSTNAME = 'AIIDADB_HOST' KEY_DATABASE_USERNAME = 'AIIDADB_USER' KEY_DATABASE_PASSWORD = 'AIIDADB_PASS' # noqa + KEY_BROKER_PROTOCOL = 'broker_protocol' + KEY_BROKER_USERNAME = 'broker_username' + KEY_BROKER_PASSWORD = 'broker_password' # noqa + KEY_BROKER_HOST = 'broker_host' + KEY_BROKER_PORT = 'broker_port' + KEY_BROKER_VIRTUAL_HOST = 'broker_virtual_host' + KEY_BROKER_PARAMETERS = 'broker_parameters' KEY_REPOSITORY_URI = 'AIIDADB_REPOSITORY_URI' # A mapping of valid attributes to the key under which they are stored in the configuration dictionary @@ -55,6 +62,13 @@ class Profile: # pylint: disable=too-many-public-methods KEY_DATABASE_HOSTNAME: 'database_hostname', KEY_DATABASE_USERNAME: 'database_username', KEY_DATABASE_PASSWORD: 'database_password', + KEY_BROKER_PROTOCOL: 'broker_protocol', + KEY_BROKER_USERNAME: 'broker_username', + KEY_BROKER_PASSWORD: 'broker_password', + KEY_BROKER_HOST: 'broker_host', + KEY_BROKER_PORT: 'broker_port', + KEY_BROKER_VIRTUAL_HOST: 'broker_virtual_host', + KEY_BROKER_PARAMETERS: 'broker_parameters', KEY_REPOSITORY_URI: 'repository_uri', } @@ -175,6 +189,62 @@ def database_password(self): def database_password(self, value): self._attributes[self.KEY_DATABASE_PASSWORD] = value + @property + def broker_protocol(self): + return self._attributes[self.KEY_BROKER_PROTOCOL] + + @broker_protocol.setter + def broker_protocol(self, value): + self._attributes[self.KEY_BROKER_PROTOCOL] = value + + @property + def broker_host(self): + return self._attributes[self.KEY_BROKER_HOST] + + @broker_host.setter + def broker_host(self, value): + self._attributes[self.KEY_BROKER_HOST] = value + + @property + def broker_port(self): + return self._attributes[self.KEY_BROKER_PORT] + + @broker_port.setter + def broker_port(self, value): + self._attributes[self.KEY_BROKER_PORT] = value + + @property + def broker_username(self): + return self._attributes[self.KEY_BROKER_USERNAME] + + @broker_username.setter + def broker_username(self, value): + self._attributes[self.KEY_BROKER_USERNAME] = value + + @property + def broker_password(self): + return self._attributes[self.KEY_BROKER_PASSWORD] + + @broker_password.setter + def broker_password(self, value): + self._attributes[self.KEY_BROKER_PASSWORD] = value + + @property + def broker_virtual_host(self): + return self._attributes[self.KEY_BROKER_VIRTUAL_HOST] + + @broker_virtual_host.setter + def broker_virtual_host(self, value): + self._attributes[self.KEY_BROKER_VIRTUAL_HOST] = value + + @property + def broker_parameters(self): + return self._attributes.get(self.KEY_BROKER_PARAMETERS, {}) + + @broker_parameters.setter + def broker_parameters(self, value): + self._attributes[self.KEY_BROKER_PARAMETERS] = value + 
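As a self-contained illustration of what the new broker settings amount to, the following sketch composes the same kind of connection URL from the defaults introduced by this migration; the helper name is ours and only mirrors the ``urlunparse`` based construction added to ``aiida/manage/external/rmq.py`` further down::

    from urllib.parse import urlencode, urlunparse

    def sketch_rmq_url(protocol='amqp', username='guest', password='guest',
                       host='127.0.0.1', port=5672, virtual_host='', **parameters):
        """Compose an AMQP connection URL from broker settings (illustrative helper only)."""
        netloc = '{}:{}@{}:{}'.format(username, password, host, port)
        # A non-empty virtual host becomes the URL path and must start with a forward slash.
        path = '/' + virtual_host if virtual_host and not virtual_host.startswith('/') else virtual_host
        return urlunparse((protocol, netloc, path, '', urlencode(parameters), ''))

    print(sketch_rmq_url())                                     # amqp://guest:guest@127.0.0.1:5672
    print(sketch_rmq_url(virtual_host='aiida', heartbeat=600))  # amqp://guest:guest@127.0.0.1:5672/aiida?heartbeat=600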
@property def repository_uri(self): return self._attributes[self.KEY_REPOSITORY_URI] @@ -268,6 +338,18 @@ def _parse_repository_uri(self): return parts.scheme, os.path.expanduser(parts.path) + def get_rmq_url(self): + from aiida.manage.external.rmq import get_rmq_url + return get_rmq_url( + protocol=self.broker_protocol, + username=self.broker_username, + password=self.broker_password, + host=self.broker_host, + port=self.broker_port, + virtual_host=self.broker_virtual_host, + **self.broker_parameters + ) + def configure_repository(self): """Validates the configured repository and in the case of a file system repo makes sure the folder exists.""" import errno diff --git a/aiida/manage/database/integrity/duplicate_uuid.py b/aiida/manage/database/integrity/duplicate_uuid.py index 763f617f52..a17b2edde3 100644 --- a/aiida/manage/database/integrity/duplicate_uuid.py +++ b/aiida/manage/database/integrity/duplicate_uuid.py @@ -71,7 +71,7 @@ def deduplicate_uuids(table=None, dry_run=True): from collections import defaultdict from aiida.common.utils import get_new_uuid - from aiida.orm.utils.repository import Repository + from aiida.orm.utils._repository import Repository if table not in TABLES_UUID_DEDUPLICATION: raise ValueError('invalid table {}: choose from {}'.format(table, ', '.join(TABLES_UUID_DEDUPLICATION))) diff --git a/aiida/manage/external/rmq.py b/aiida/manage/external/rmq.py index 12512b0cec..51ffe1272c 100644 --- a/aiida/manage/external/rmq.py +++ b/aiida/manage/external/rmq.py @@ -13,10 +13,12 @@ import logging from tornado import gen -import plumpy from kiwipy import communications, Future +import plumpy + +from aiida.common.extendeddicts import AttributeDict -__all__ = ('RemoteException', 'CommunicationTimeout', 'DeliveryFailed', 'ProcessLauncher') +__all__ = ('RemoteException', 'CommunicationTimeout', 'DeliveryFailed', 'ProcessLauncher', 'BROKER_DEFAULTS') LOGGER = logging.getLogger(__name__) @@ -24,40 +26,79 @@ DeliveryFailed = plumpy.DeliveryFailed CommunicationTimeout = communications.TimeoutError # pylint: disable=invalid-name -# GP: Using here 127.0.0.1 instead of localhost because on some computers -# localhost resolves first to IPv6 with address ::1 and if RMQ is not -# running on IPv6 one gets an annoying warning. When moving this to -# a user-configurable variable, make sure users are aware of this and -# know how to avoid warnings. For more info see -# https://github.com/aiidateam/aiida-core/issues/1142 -_RMQ_URL = 'amqp://127.0.0.1' -_RMQ_HEARTBEAT_TIMEOUT = 600 # Maximum that can be set by client, with default RabbitMQ server configuration _LAUNCH_QUEUE = 'process.queue' _MESSAGE_EXCHANGE = 'messages' _TASK_EXCHANGE = 'tasks' - -def get_rmq_url(heartbeat_timeout=None): +BROKER_DEFAULTS = AttributeDict({ + 'protocol': 'amqp', + 'username': 'guest', + 'password': 'guest', + 'host': '127.0.0.1', + 'port': 5672, + 'virtual_host': '', + 'heartbeat': 600, +}) + +BROKER_VALID_PARAMETERS = [ + 'heartbeat', # heartbeat timeout in seconds + 'cafile', # string containing path to ca certificate file + 'capath', # string containing path to ca certificates + 'cadata', # base64 encoded ca certificate data + 'keyfile', # string containing path to key file + 'certfile', # string containing path to certificate file + 'no_verify_ssl', # boolean disables certificates validation +] + + +def get_rmq_url(protocol=None, username=None, password=None, host=None, port=None, virtual_host=None, **kwargs): + """Return the URL to connect to RabbitMQ. + + .. 
note:: + + The default of the ``host`` is set to ``127.0.0.1`` instead of ``localhost`` because on some computers localhost + resolves first to IPv6 with address ::1 and if RMQ is not running on IPv6 one gets an annoying warning. For more + info see: https://github.com/aiidateam/aiida-core/issues/1142 + + :param protocol: the protocol to use, `amqp` or `amqps`. + :param username: the username for authentication. + :param password: the password for authentication. + :param host: the hostname of the RabbitMQ server. + :param port: the port of the RabbitMQ server. + :param virtual_host: the virtual host to connect to. + :returns: the connection URL string. """ - Get the URL to connect to RabbitMQ + from urllib.parse import urlencode, urlunparse - :param heartbeat_timeout: the interval in seconds for the heartbeat timeout - :returns: the connection URL string - """ - url = _RMQ_URL + invalid = set(kwargs.keys()).difference(BROKER_VALID_PARAMETERS) + if invalid: + raise ValueError('invalid URL parameters specified in the keyword arguments: {}'.format(', '.join(invalid))) + + if 'heartbeat' not in kwargs: + kwargs['heartbeat'] = BROKER_DEFAULTS.heartbeat - if heartbeat_timeout is None: - heartbeat_timeout = _RMQ_HEARTBEAT_TIMEOUT + scheme = protocol or BROKER_DEFAULTS.protocol + netloc = '{username}:{password}@{host}:{port}'.format( + username=username or BROKER_DEFAULTS.username, + password=password or BROKER_DEFAULTS.password, + host=host or BROKER_DEFAULTS.host, + port=port or BROKER_DEFAULTS.port, + ) + path = virtual_host or BROKER_DEFAULTS.virtual_host + parameters = '' + query = urlencode(kwargs) + fragment = '' - if heartbeat_timeout is not None: - url += '?heartbeat={}'.format(heartbeat_timeout) + # The virtual host is optional but if it is specified it needs to start with a forward slash. If the virtual host + # itself contains forward slashes, they need to be encoded. + if path and not path.startswith('/'): + path = '/' + path - return url + return urlunparse((scheme, netloc, path, parameters, query, fragment)) def get_launch_queue_name(prefix=None): - """ - Return the launch queue name with an optional prefix + """Return the launch queue name with an optional prefix. :returns: launch queue name """ @@ -68,8 +109,7 @@ def get_launch_queue_name(prefix=None): def get_message_exchange_name(prefix): - """ - Return the message exchange name for a given prefix + """Return the message exchange name for a given prefix. :returns: message exchange name """ @@ -77,8 +117,7 @@ def get_message_exchange_name(prefix): def get_task_exchange_name(prefix): - """ - Return the task exchange name for a given prefix + """Return the task exchange name for a given prefix. :returns: task exchange name """ @@ -86,8 +125,9 @@ def get_task_exchange_name(prefix): def _store_inputs(inputs): - """ - Try to store the values in the input dictionary. For nested dictionaries, the values are stored by recursively. + """Try to store the values in the input dictionary. + + For nested dictionaries, the values are stored by recursively. 
""" for node in inputs.values(): try: diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 4f0cb904ca..8914555c15 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -182,7 +182,7 @@ def create_communicator(self, task_prefetch_count=None, with_orm=True): if task_prefetch_count is None: task_prefetch_count = self.get_config().get_option('daemon.worker_process_slots', profile.name) - url = rmq.get_rmq_url() + url = profile.get_rmq_url() prefix = profile.rmq_prefix # This needs to be here, because the verdi commands will call this function and when called in unit tests the diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py index cc4e2e5fbf..c5486326a0 100644 --- a/aiida/manage/tests/__init__.py +++ b/aiida/manage/tests/__init__.py @@ -272,6 +272,12 @@ def profile_dictionary(self): 'database_name': self.profile_info.get('database_name'), 'database_username': self.profile_info.get('database_username'), 'database_password': self.profile_info.get('database_password'), + 'broker_protocol': self.profile_info.get('broker_protocol'), + 'broker_username': self.profile_info.get('broker_username'), + 'broker_password': self.profile_info.get('broker_password'), + 'broker_host': self.profile_info.get('broker_host'), + 'broker_port': self.profile_info.get('broker_port'), + 'broker_virtual_host': self.profile_info.get('broker_virtual_host'), 'repository_uri': 'file://' + self.repo, } return dictionary diff --git a/aiida/manage/tests/pytest_fixtures.py b/aiida/manage/tests/pytest_fixtures.py index 310b07e944..00a151e66a 100644 --- a/aiida/manage/tests/pytest_fixtures.py +++ b/aiida/manage/tests/pytest_fixtures.py @@ -94,15 +94,15 @@ def test_1(aiida_localhost): from aiida.orm import Computer from aiida.common.exceptions import NotExistent - name = 'localhost-test' + label = 'localhost-test' try: - computer = Computer.objects.get(name=name) + computer = Computer.objects.get(label=label) except NotExistent: computer = Computer( - name=name, + label=label, description='localhost computer set up by test manager', - hostname=name, + hostname=label, workdir=temp_dir, transport_type='local', scheduler_type='direct' diff --git a/aiida/orm/authinfos.py b/aiida/orm/authinfos.py index d1bf0e8dcb..d3bd980f66 100644 --- a/aiida/orm/authinfos.py +++ b/aiida/orm/authinfos.py @@ -50,9 +50,9 @@ def __init__(self, computer, user, backend=None): def __str__(self): if self.enabled: - return 'AuthInfo for {} on {}'.format(self.user.email, self.computer.name) + return 'AuthInfo for {} on {}'.format(self.user.email, self.computer.label) - return 'AuthInfo for {} on {} [DISABLED]'.format(self.user.email, self.computer.name) + return 'AuthInfo for {} on {} [DISABLED]'.format(self.user.email, self.computer.label) @property def enabled(self): @@ -138,7 +138,7 @@ def get_transport(self): :rtype: :class:`aiida.transports.Transport` """ computer = self.computer - transport_type = computer.get_transport_type() + transport_type = computer.transport_type try: transport_class = TransportFactory(transport_type) diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py index 2832733fa0..3662edf71a 100644 --- a/aiida/orm/computers.py +++ b/aiida/orm/computers.py @@ -8,15 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for Computer entities""" - import logging import os import warnings -from aiida import transports, schedulers from aiida.common import exceptions from 
aiida.common.warnings import AiidaDeprecationWarning from aiida.manage.manager import get_manager +from aiida.orm.implementation import Backend from aiida.plugins import SchedulerFactory, TransportFactory from . import entities @@ -51,33 +50,72 @@ class Computer(entities.Entity): class Collection(entities.Collection): """The collection of Computer entries.""" + def get(self, **filters): + """Get a single collection entry that matches the filter criteria. + + :param filters: the filters identifying the object to get + :type filters: dict + + :return: the entry + """ + if 'name' in filters: + warnings.warn('keyword `name` is deprecated, use `label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + + # This switch needs to be here until we fully remove `name` and replace it with `label` even on the backend + # entities and database models. + if 'label' in filters: + filters['name'] = filters.pop('label') + + return super().get(**filters) + def list_names(self): - """Return a list with all the names of the computers in the DB.""" + """Return a list with all the names of the computers in the DB. + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use `list_labels` instead. + """ + return self._backend.computers.list_names() + + def list_labels(self): + """Return a list with all the labels of the computers in the DB.""" return self._backend.computers.list_names() def delete(self, id): # pylint: disable=redefined-builtin,invalid-name """Delete the computer with the given id""" return self._backend.computers.delete(id) - def __init__( - self, name, hostname, description='', transport_type='', scheduler_type='', workdir=None, backend=None - ): + def __init__( # pylint: disable=too-many-arguments + self, + label: str = None, + hostname: str = None, + description: str = '', + transport_type: str = '', + scheduler_type: str = '', + workdir: str = None, + backend: Backend = None, + name: str = None + ) -> 'Computer': """Construct a new computer - :type name: str - :type hostname: str - :type description: str - :type transport_type: str - :type scheduler_type: str - :type workdir: str - :type backend: :class:`aiida.orm.implementation.Backend` - - :rtype: :class:`aiida.orm.Computer` + .. deprecated:: 1.4.0 + The `name` keyword will be removed in `v2.0.0`, use `label` instead. """ - # pylint: disable=too-many-arguments + + # This needs to be here because `label` needed to get a default, since it was replacing `name` and during the + # deprecation period, it needs to be automatically set to whatever `name` is passed. As a knock-on effect, since + # a keyword argument cannot precede a normal argument, `hostname` also needed to become a keyword argument, + # forcing us to set a default, which we set to `None`. We raise the same exception that Python would normally + # raise if a required positional argument is not specified.
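A short usage sketch of the transitional behaviour described in the comment above, assuming a configured and loaded profile; only the keyword names are taken from the new signature::

    from aiida import load_profile, orm

    load_profile()

    # Preferred from 1.4.0 onwards: construct the computer with `label`.
    computer = orm.Computer(label='localhost', hostname='localhost', transport_type='local', scheduler_type='direct')

    # Passing the deprecated `name` keyword still works but emits an AiidaDeprecationWarning.
    legacy = orm.Computer(name='localhost-legacy', hostname='localhost', transport_type='local', scheduler_type='direct')

    # Omitting `hostname` raises: TypeError("missing 1 required positional argument: 'hostname'")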
+ if hostname is None: + raise TypeError("missing 1 required positional argument: 'hostname'") + + if name is not None: + warnings.warn('keyword `name` is deprecated, use `label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + label = name + backend = backend or get_manager().get_backend() model = backend.computers.create( - name=name, + name=label, hostname=hostname, description=description, transport_type=transport_type, @@ -91,23 +129,27 @@ def __repr__(self): return '<{}: {}>'.format(self.__class__.__name__, str(self)) def __str__(self): - return '{} ({}), pk: {}'.format(self.name, self.hostname, self.pk) + return '{} ({}), pk: {}'.format(self.label, self.hostname, self.pk) @property def full_text_info(self): """ Return a (multiline) string with a human-readable detailed information on this computer. + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`. + :rtype: str """ + warnings.warn('this property is deprecated', AiidaDeprecationWarning) # pylint: disable=no-member ret_lines = [] - ret_lines.append('Computer name: {}'.format(self.name)) + ret_lines.append('Computer name: {}'.format(self.label)) ret_lines.append(' * PK: {}'.format(self.pk)) ret_lines.append(' * UUID: {}'.format(self.uuid)) ret_lines.append(' * Description: {}'.format(self.description)) ret_lines.append(' * Hostname: {}'.format(self.hostname)) - ret_lines.append(' * Transport type: {}'.format(self.get_transport_type())) - ret_lines.append(' * Scheduler type: {}'.format(self.get_scheduler_type())) + ret_lines.append(' * Transport type: {}'.format(self.transport_type)) + ret_lines.append(' * Scheduler type: {}'.format(self.scheduler_type)) ret_lines.append(' * Work directory: {}'.format(self.get_workdir())) ret_lines.append(' * Shebang: {}'.format(self.get_shebang())) ret_lines.append(' * mpirun command: {}'.format(' '.join(self.get_mpirun_command()))) @@ -137,8 +179,6 @@ def full_text_info(self): def logger(self): return self._logger - # region validation - @classmethod def _name_validator(cls, name): """ @@ -167,7 +207,8 @@ def _transport_type_validator(cls, transport_type): """ Validates the transport string. """ - if transport_type not in transports.Transport.get_valid_transports(): + from aiida.plugins.entry_point import get_entry_point_names + if transport_type not in get_entry_point_names('aiida.transports'): raise exceptions.ValidationError('The specified transport is not a valid one') @classmethod @@ -175,7 +216,8 @@ def _scheduler_type_validator(cls, scheduler_type): """ Validates the transport string. """ - if scheduler_type not in schedulers.Scheduler.get_valid_schedulers(): + from aiida.plugins.entry_point import get_entry_point_names + if scheduler_type not in get_entry_point_names('aiida.schedulers'): raise exceptions.ValidationError('The specified scheduler is not a valid one') @classmethod @@ -245,13 +287,13 @@ def validate(self): For the base class, this is always valid. Subclasses will reimplement this. In the subclass, always call the super().validate() method first! 
""" - if not self.get_name().strip(): + if not self.label.strip(): raise exceptions.ValidationError('No name specified') - self._hostname_validator(self.get_hostname()) - self._description_validator(self.get_description()) - self._transport_type_validator(self.get_transport_type()) - self._scheduler_type_validator(self.get_scheduler_type()) + self._hostname_validator(self.hostname) + self._description_validator(self.description) + self._transport_type_validator(self.transport_type) + self._scheduler_type_validator(self.scheduler_type) self._workdir_validator(self.get_workdir()) try: @@ -272,14 +314,10 @@ def _default_mpiprocs_per_machine_validator(cls, def_cpus_per_machine): if not isinstance(def_cpus_per_machine, int) or def_cpus_per_machine <= 0: raise exceptions.ValidationError( - 'Invalid value for default_mpiprocs_per_machine, ' - 'must be a positive integer, or an empty ' - 'string if you do not want to provide a ' - 'default value.' + 'Invalid value for default_mpiprocs_per_machine, must be a positive integer, or an empty string if you ' + 'do not want to provide a default value.' ) - # endregion - def copy(self): """ Return a copy of the current object to work with, not stored yet. @@ -297,49 +335,100 @@ def store(self): return super().store() @property - def name(self): + def label(self) -> str: + """Return the computer label. + + :return: the label. + """ return self._backend_entity.name + @label.setter + def label(self, value: str): + """Set the computer label. + + :param value: the label to set. + """ + self._backend_entity.set_name(value) + @property - def label(self): + def description(self) -> str: + """Return the computer computer. + + :return: the description. """ - The computer label + return self._backend_entity.description + + @description.setter + def description(self, value: str): + """Set the computer description. + + :param value: the description to set. """ - return self.name + self._backend_entity.set_description(value) - @label.setter - def label(self, value): + @property + def hostname(self) -> str: + """Return the computer hostname. + + :return: the hostname. """ - Set the computer label (i.e., name) + return self._backend_entity.hostname + + @hostname.setter + def hostname(self, value: str): + """Set the computer hostname. + + :param value: the hostname to set. """ - self.set_name(value) + self._backend_entity.set_hostname(value) @property - def description(self): + def scheduler_type(self) -> str: + """Return the computer scheduler type. + + :return: the scheduler type. """ - Get a description of the computer + return self._backend_entity.get_scheduler_type() - :return: the description - :rtype: str + @scheduler_type.setter + def scheduler_type(self, value: str): + """Set the computer scheduler type. + + :param value: the scheduler type to set. """ - return self._backend_entity.description + self._backend_entity.set_scheduler_type(value) @property - def hostname(self): - return self._backend_entity.hostname + def transport_type(self) -> str: + """Return the computer transport type. - def get_metadata(self): - return self._backend_entity.get_metadata() + :return: the transport_type. + """ + return self._backend_entity.get_transport_type() - def set_metadata(self, metadata): + @transport_type.setter + def transport_type(self, value: str): + """Set the computer transport type. + + :param value: the transport_type to set. """ - Set the metadata. + self._backend_entity.set_transport_type(value) - .. 
note: You still need to call the .store() method to actually save - data to the database! (The store method can be called multiple - times, differently from AiiDA Node objects). + @property + def metadata(self) -> str: + """Return the computer metadata. + + :return: the metadata. + """ + return self._backend_entity.get_metadata() + + @metadata.setter + def metadata(self, value: str): + """Set the computer metadata. + + :param value: the metadata to set. """ - self._backend_entity.set_metadata(metadata) + self._backend_entity.set_metadata(value) def delete_property(self, name, raise_exception=True): """ @@ -351,10 +440,10 @@ def delete_property(self, name, raise_exception=True): :param raise_exception: if True raise if the property does not exist, otherwise return None :type raise_exception: bool """ - olddata = self.get_metadata() + olddata = self.metadata try: del olddata[name] - self.set_metadata(olddata) + self.metadata = olddata except KeyError: if raise_exception: raise AttributeError("'{}' property not found".format(name)) @@ -366,9 +455,9 @@ def set_property(self, name, value): :param name: the property name :param value: the new value """ - metadata = self.get_metadata() or {} + metadata = self.metadata or {} metadata[name] = value - self.set_metadata(metadata) + self.metadata = metadata def get_property(self, name, *args): """ @@ -383,7 +472,7 @@ def get_property(self, name, *args): """ if len(args) > 1: raise TypeError('get_property expected at most 2 arguments') - olddata = self.get_metadata() + olddata = self.metadata try: return olddata[name] except KeyError: @@ -462,31 +551,6 @@ def set_minimum_job_poll_interval(self, interval): """ self.set_property(self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL, interval) - def get_transport(self, user=None): - """ - Return a Transport class, configured with all correct parameters. - The Transport is closed (meaning that if you want to run any operation with - it, you have to open it first (i.e., e.g. for a SSH transport, you have - to open a connection). To do this you can call ``transports.open()``, or simply - run within a ``with`` statement:: - - transport = Computer.get_transport() - with transport: - print(transports.whoami()) - - :param user: if None, try to obtain a transport for the default user. - Otherwise, pass a valid User. - - :return: a (closed) Transport, already configured with the connection - parameters to the supercomputer, as configured with ``verdi computer configure`` - for the user specified as a parameter ``user``. - """ - from . import authinfos # pylint: disable=cyclic-import - - user = user or users.User.objects(self.backend).get_default() - authinfo = authinfos.AuthInfo.objects(self.backend).get(dbcomputer=self, aiidauser=user) - return authinfo.get_transport() - def get_workdir(self): """ Get the working directory for this computer @@ -509,47 +573,9 @@ def set_shebang(self, val): raise ValueError('{} is invalid. Input has to be a string'.format(val)) if not val.startswith('#!'): raise ValueError('{} is invalid. 
A shebang line has to start with #!'.format(val)) - metadata = self.get_metadata() + metadata = self.metadata metadata['shebang'] = val - self.set_metadata(metadata) - - def get_name(self): - return self._backend_entity.get_name() - - def set_name(self, val): - self._backend_entity.set_name(val) - - def get_hostname(self): - """ - Get this computer hostname - :rtype: str - """ - return self._backend_entity.get_hostname() - - def set_hostname(self, val): - """ - Set the hostname of this computer - :param val: The new hostname - :type val: str - """ - self._backend_entity.set_hostname(val) - - def get_description(self): - """ - Get the description for this computer - - :return: the description - :rtype: str - """ - - def set_description(self, val): - """ - Set the description for this computer - - :param val: the new description - :type val: str - """ - self._backend_entity.set_description(val) + self.metadata = metadata def get_authinfo(self, user): """ @@ -592,43 +618,33 @@ def is_user_enabled(self, user): authinfo = self.get_authinfo(user) return authinfo.enabled except exceptions.NotExistent: - # Return False if the user is not configured (in a sense, - # it is disabled for that user) + # Return False if the user is not configured (in a sense, it is disabled for that user) return False - def get_scheduler_type(self): - """ - Get the scheduler type for this computer - - :return: the scheduler type - :rtype: str - """ - return self._backend_entity.get_scheduler_type() - - def set_scheduler_type(self, scheduler_type): - """ - :param scheduler_type: the new scheduler type + def get_transport(self, user=None): """ - self._scheduler_type_validator(scheduler_type) - self._backend_entity.set_scheduler_type(scheduler_type) + Return a Transport class, configured with all correct parameters. + The Transport is closed (meaning that if you want to run any operation with + it, you have to open it first (i.e., e.g. for a SSH transport, you have + to open a connection). To do this you can call ``transports.open()``, or simply + run within a ``with`` statement:: - def get_transport_type(self): - """ - Get the current transport type for this computer + transport = Computer.get_transport() + with transport: + print(transports.whoami()) - :return: the transport type - :rtype: str - """ - return self._backend_entity.get_transport_type() + :param user: if None, try to obtain a transport for the default user. + Otherwise, pass a valid User. - def set_transport_type(self, transport_type): + :return: a (closed) Transport, already configured with the connection + parameters to the supercomputer, as configured with ``verdi computer configure`` + for the user specified as a parameter ``user``. """ - Set the transport type for this computer + from . 
import authinfos # pylint: disable=cyclic-import - :param transport_type: the new transport type - :type transport_type: str - """ - self._backend_entity.set_transport_type(transport_type) + user = user or users.User.objects(self.backend).get_default() + authinfo = authinfos.AuthInfo.objects(self.backend).get(dbcomputer=self, aiidauser=user) + return authinfo.get_transport() def get_transport_class(self): """ @@ -637,12 +653,10 @@ def get_transport_class(self): :return: the transport class """ try: - return TransportFactory(self.get_transport_type()) + return TransportFactory(self.transport_type) except exceptions.EntryPointError as exception: raise exceptions.ConfigurationError( - 'No transport found for {} [type {}], message: {}'.format( - self.name, self.get_transport_type(), exception - ) + 'No transport found for {} [type {}], message: {}'.format(self.label, self.transport_type, exception) ) def get_scheduler(self): @@ -653,14 +667,12 @@ def get_scheduler(self): :rtype: :class:`aiida.schedulers.Scheduler` """ try: - scheduler_class = SchedulerFactory(self.get_scheduler_type()) + scheduler_class = SchedulerFactory(self.scheduler_type) # I call the init without any parameter return scheduler_class() except exceptions.EntryPointError as exception: raise exceptions.ConfigurationError( - 'No scheduler found for {} [type {}], message: {}'.format( - self.name, self.get_scheduler_type(), exception - ) + 'No scheduler found for {} [type {}], message: {}'.format(self.label, self.scheduler_type, exception) ) def configure(self, user=None, **kwargs): @@ -720,6 +732,158 @@ def get_configuration(self, user=None): return config + @property + def name(self): + """Return the computer name. + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `label` property instead. + """ + warnings.warn('this property is deprecated, use the `label` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.label + + def get_name(self): + """Return the computer name. + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `label` property instead. + """ + warnings.warn('this property is deprecated, use the `label` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.label + + def set_name(self, val): + """Set the computer name. + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `label` property instead. + """ + warnings.warn('this method is deprecated, use the `label` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + self.label = val + + def get_hostname(self): + """Get this computer hostname + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `hostname` property instead. + + :rtype: str + """ + warnings.warn('this method is deprecated, use the `hostname` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.hostname + + def set_hostname(self, val): + """ + Set the hostname of this computer + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `hostname` property instead. + + :param val: The new hostname + :type val: str + """ + warnings.warn('this method is deprecated, use the `hostname` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + self.hostname = val + + def get_description(self): + """ + Get the description for this computer + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `description` property instead. 
+ + :return: the description + :rtype: str + """ + warnings.warn('this method is deprecated, use the `description` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.description + + def set_description(self, val): + """ + Set the description for this computer + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `description` property instead. + + :param val: the new description + :type val: str + """ + warnings.warn('this method is deprecated, use the `description` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + self.description = val + + def get_scheduler_type(self): + """ + Get the scheduler type for this computer + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `scheduler_type` property instead. + + :return: the scheduler type + :rtype: str + """ + warnings.warn('this method is deprecated, use the `scheduler_type` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.scheduler_type + + def set_scheduler_type(self, scheduler_type): + """ + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `scheduler_type` property instead. + + :param scheduler_type: the new scheduler type + """ + warnings.warn('this method is deprecated, use the `scheduler_type` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + self._scheduler_type_validator(scheduler_type) + self.scheduler_type = scheduler_type + + def get_transport_type(self): + """ + Get the current transport type for this computer + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `transport_type` property instead. + + :return: the transport type + :rtype: str + """ + warnings.warn('this method is deprecated, use the `transport_type` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.transport_type + + def set_transport_type(self, transport_type): + """ + Set the transport type for this computer + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `transport_type` property instead. + + :param transport_type: the new transport type + :type transport_type: str + """ + warnings.warn('this method is deprecated, use the `transport_type` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + self.transport_type = transport_type + + def get_metadata(self): + """ + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `metadata` property instead. + + """ + warnings.warn('this method is deprecated, use the `metadata` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.metadata + + def set_metadata(self, metadata): + """ + Set the metadata. + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `metadata` property instead. + + .. note: You still need to call the .store() method to actually save + data to the database! (The store method can be called multiple + times, differently from AiiDA Node objects). 
+ """ + warnings.warn('this method is deprecated, use the `metadata` property instead', AiidaDeprecationWarning) # pylint: disable=no-member + self.metadata = metadata + @staticmethod def get_schema(): """ diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index 509ddf441c..6cd0dbf3cd 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -8,19 +8,22 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for all common top level AiiDA entity classes and methods""" - import typing +import abc +import copy from plumpy.base.utils import super_check, call_with_super_check -from aiida.common import datastructures +from aiida.common import datastructures, exceptions from aiida.common.lang import classproperty, type_check from aiida.manage.manager import get_manager -__all__ = ('Entity', 'Collection') +__all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin') EntityType = typing.TypeVar('EntityType') # pylint: disable=invalid-name +_NO_DEFAULT = tuple() + class Collection(typing.Generic[EntityType]): """Container class that represents the collection of objects of a particular type.""" @@ -207,9 +210,9 @@ def from_backend_entity(cls, backend_entity): :return: an AiiDA entity instance """ - from . import implementation + from .implementation.entities import BackendEntity - type_check(backend_entity, implementation.BackendEntity) + type_check(backend_entity, BackendEntity) entity = cls.__new__(cls) entity.init_from_backend(backend_entity) call_with_super_check(entity.initialize) @@ -218,7 +221,7 @@ def from_backend_entity(cls, backend_entity): def __init__(self, backend_entity): """ :param backend_entity: the backend model supporting this entity - :type backend_entity: :class:`aiida.orm.implementation.BackendEntity` + :type backend_entity: :class:`aiida.orm.implementation.entities.BackendEntity` """ self._backend_entity = backend_entity call_with_super_check(self.initialize) @@ -226,7 +229,7 @@ def __init__(self, backend_entity): def init_from_backend(self, backend_entity): """ :param backend_entity: the backend model supporting this entity - :type backend_entity: :class:`aiida.orm.implementation.BackendEntity` + :type backend_entity: :class:`aiida.orm.implementation.entities.BackendEntity` """ self._backend_entity = backend_entity @@ -298,3 +301,298 @@ def backend_entity(self): :return: the class model """ return self._backend_entity + + +class EntityAttributesMixin(abc.ABC): + """Mixin class that adds all methods for the attributes column to an entity.""" + + @property + def attributes(self): + """Return the complete attributes dictionary. + + .. warning:: While the entity is unstored, this will return references of the attributes on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attributes will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you + only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the + getters `get_attribute` and `get_attribute_many` instead. 
+ + :return: the attributes as a dictionary + """ + attributes = self.backend_entity.attributes + + if self.is_stored: + attributes = copy.deepcopy(attributes) + + return attributes + + def get_attribute(self, key, default=_NO_DEFAULT): + """Return the value of an attribute. + + .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, + meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attribute will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. + + :param key: name of the attribute + :param default: return this value instead of raising if the attribute does not exist + :return: the value of the attribute + :raises AttributeError: if the attribute does not exist and no default is specified + """ + try: + attribute = self.backend_entity.get_attribute(key) + except AttributeError: + if default is _NO_DEFAULT: + raise + attribute = default + + if self.is_stored: + attribute = copy.deepcopy(attribute) + + return attribute + + def get_attribute_many(self, keys): + """Return the values of multiple attributes. + + .. warning:: While the entity is unstored, this will return references of the attributes on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attributes will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you + only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the + getters `get_attribute` and `get_attribute_many` instead. + + :param keys: a list of attribute names + :return: a list of attribute values + :raises AttributeError: if at least one attribute does not exist + """ + attributes = self.backend_entity.get_attribute_many(keys) + + if self.is_stored: + attributes = copy.deepcopy(attributes) + + return attributes + + def set_attribute(self, key, value): + """Set an attribute to the given value. + + :param key: name of the attribute + :param value: value of the attribute + :raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + if self.is_stored: + raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') + + self.backend_entity.set_attribute(key, value) + + def set_attribute_many(self, attributes): + """Set multiple attributes. + + .. note:: This will override any existing attributes that are present in the new dictionary. + + :param attributes: a dictionary with the attributes to set + :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + if self.is_stored: + raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') + + self.backend_entity.set_attribute_many(attributes) + + def reset_attributes(self, attributes): + """Reset the attributes. + + .. note:: This will completely clear any existing attributes and replace them with the new dictionary. 
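A short illustration of the reference-versus-deep-copy semantics and the default handling described in the warnings above, assuming a configured profile and that node classes pick up this mixin as in the rest of this changeset:

    from aiida import orm

    node = orm.Data()                                   # unstored entity
    node.set_attribute('options', {'max_steps': 10})

    options = node.get_attribute('options')
    options['tag'] = 'draft'                            # unstored: mutates the model directly
    assert node.get_attribute('options')['tag'] == 'draft'

    node.get_attribute('missing', default=None)         # returns None instead of raising

    node.store()
    options = node.get_attribute('options')             # stored: this is a deep copy
    options['tag'] = 'final'                            # does not touch the database model
    assert node.get_attribute('options')['tag'] == 'draft'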
+ + :param attributes: a dictionary with the attributes to set + :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + if self.is_stored: + raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') + + self.backend_entity.reset_attributes(attributes) + + def delete_attribute(self, key): + """Delete an attribute. + + :param key: name of the attribute + :raises AttributeError: if the attribute does not exist + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + if self.is_stored: + raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') + + self.backend_entity.delete_attribute(key) + + def delete_attribute_many(self, keys): + """Delete multiple attributes. + + :param keys: names of the attributes to delete + :raises AttributeError: if at least one of the attribute does not exist + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + if self.is_stored: + raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') + + self.backend_entity.delete_attribute_many(keys) + + def clear_attributes(self): + """Delete all attributes.""" + if self.is_stored: + raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') + + self.backend_entity.clear_attributes() + + def attributes_items(self): + """Return an iterator over the attributes. + + :return: an iterator with attribute key value pairs + """ + return self.backend_entity.attributes_items() + + def attributes_keys(self): + """Return an iterator over the attribute keys. + + :return: an iterator with attribute keys + """ + return self.backend_entity.attributes_keys() + + +class EntityExtrasMixin(abc.ABC): + """Mixin class that adds all methods for the extras column to an entity.""" + + @property + def extras(self): + """Return the complete extras dictionary. + + .. warning:: While the entity is unstored, this will return references of the extras on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + extras will be a deep copy and mutations of the database extras will have to go through the appropriate set + methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys + or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and + `get_extra_many` instead. + + :return: the extras as a dictionary + """ + extras = self.backend_entity.extras + + if self.is_stored: + extras = copy.deepcopy(extras) + + return extras + + def get_extra(self, key, default=_NO_DEFAULT): + """Return the value of an extra. + + .. warning:: While the entity is unstored, this will return a reference of the extra on the database model, + meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + extra will be a deep copy and mutations of the database extras will have to go through the appropriate set + methods. 
+ + :param key: name of the extra + :param default: return this value instead of raising if the attribute does not exist + :return: the value of the extra + :raises AttributeError: if the extra does not exist and no default is specified + """ + try: + extra = self.backend_entity.get_extra(key) + except AttributeError: + if default is _NO_DEFAULT: + raise + extra = default + + if self.is_stored: + extra = copy.deepcopy(extra) + + return extra + + def get_extra_many(self, keys): + """Return the values of multiple extras. + + .. warning:: While the entity is unstored, this will return references of the extras on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + extras will be a deep copy and mutations of the database extras will have to go through the appropriate set + methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys + or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and + `get_extra_many` instead. + + :param keys: a list of extra names + :return: a list of extra values + :raises AttributeError: if at least one extra does not exist + """ + extras = self.backend_entity.get_extra_many(keys) + + if self.is_stored: + extras = copy.deepcopy(extras) + + return extras + + def set_extra(self, key, value): + """Set an extra to the given value. + + :param key: name of the extra + :param value: value of the extra + :raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods + """ + self.backend_entity.set_extra(key, value) + + def set_extra_many(self, extras): + """Set multiple extras. + + .. note:: This will override any existing extras that are present in the new dictionary. + + :param extras: a dictionary with the extras to set + :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods + """ + self.backend_entity.set_extra_many(extras) + + def reset_extras(self, extras): + """Reset the extras. + + .. note:: This will completely clear any existing extras and replace them with the new dictionary. + + :param extras: a dictionary with the extras to set + :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods + """ + self.backend_entity.reset_extras(extras) + + def delete_extra(self, key): + """Delete an extra. + + :param key: name of the extra + :raises AttributeError: if the extra does not exist + """ + self.backend_entity.delete_extra(key) + + def delete_extra_many(self, keys): + """Delete multiple extras. + + :param keys: names of the extras to delete + :raises AttributeError: if at least one of the extra does not exist + """ + self.backend_entity.delete_extra_many(keys) + + def clear_extras(self): + """Delete all extras.""" + self.backend_entity.clear_extras() + + def extras_items(self): + """Return an iterator over the extras. + + :return: an iterator with extra key value pairs + """ + return self.backend_entity.extras_items() + + def extras_keys(self): + """Return an iterator over the extra keys. 
+ + :return: an iterator with extra keys + """ + return self.backend_entity.extras_keys() diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index 903bca9f7c..ea410de4ca 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" AiiDA Group entites""" +"""AiiDA Group entites""" from abc import ABCMeta from enum import Enum import warnings @@ -77,7 +77,7 @@ class GroupTypeString(Enum): USER = 'user' -class Group(entities.Entity, metaclass=GroupMeta): +class Group(entities.Entity, entities.EntityExtrasMixin, metaclass=GroupMeta): """An AiiDA ORM implementation of group of nodes.""" class Collection(entities.Collection): diff --git a/aiida/orm/implementation/__init__.py b/aiida/orm/implementation/__init__.py index 67e6d210d5..8e2f177b1d 100644 --- a/aiida/orm/implementation/__init__.py +++ b/aiida/orm/implementation/__init__.py @@ -9,7 +9,6 @@ ########################################################################### """Module with the implementations of the various backend entities for various database backends.""" # pylint: disable=wildcard-import,undefined-variable - from .authinfos import * from .backends import * from .comments import * diff --git a/aiida/orm/implementation/authinfos.py b/aiida/orm/implementation/authinfos.py index 66b209eb89..a9bc86e0f6 100644 --- a/aiida/orm/implementation/authinfos.py +++ b/aiida/orm/implementation/authinfos.py @@ -11,12 +11,12 @@ import abc -from . import backends +from .entities import BackendEntity, BackendCollection __all__ = ('BackendAuthInfo', 'BackendAuthInfoCollection') -class BackendAuthInfo(backends.BackendEntity): +class BackendAuthInfo(BackendEntity): """Backend implementation for the `AuthInfo` ORM class.""" METADATA_WORKDIR = 'workdir' @@ -78,7 +78,7 @@ def set_metadata(self, metadata): """ -class BackendAuthInfoCollection(backends.BackendCollection[BackendAuthInfo]): +class BackendAuthInfoCollection(BackendCollection[BackendAuthInfo]): """The collection of backend `AuthInfo` entries.""" ENTITY_CLASS = BackendAuthInfo diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 6b1c4025be..f0dfd50fe2 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -9,11 +9,8 @@ ########################################################################### """Generic backend related objects""" import abc -import typing -__all__ = ('Backend', 'BackendEntity', 'BackendCollection', 'EntityType') - -EntityType = typing.TypeVar('EntityType') # pylint: disable=invalid-name +__all__ = ('Backend',) class Backend(abc.ABC): @@ -120,97 +117,3 @@ def get_session(self): :return: an instance of :class:`sqlalchemy.orm.session.Session` """ - - -class BackendEntity(abc.ABC): - """An first-class entity in the backend""" - - def __init__(self, backend): - self._backend = backend - self._dbmodel = None - - @property - def backend(self): - """Return the backend this entity belongs to - - :return: the backend instance - """ - return self._backend - - @property - def dbmodel(self): - return self._dbmodel - - @abc.abstractproperty - def id(self): # pylint: disable=invalid-name - """Return the id for this entity. - - This is unique only amongst entities of this type for a particular backend. 
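Because `Group` now also inherits `EntityExtrasMixin` (see the `groups.py` hunk above), extras become available on groups with the same API as on nodes and, unlike attributes, remain settable after storing. A small usage sketch with a hypothetical label and values:

    from aiida import orm

    group = orm.Group(label='my-test-group').store()
    group.set_extra('project', 'demo')
    group.set_extra_many({'owner': 'alice', 'priority': 1})
    print(group.extras)                 # {'project': 'demo', 'owner': 'alice', 'priority': 1}
    group.delete_extra('priority')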
- - :return: the entity id - """ - - @property - def pk(self): - """Return the id for this entity. - - This is unique only amongst entities of this type for a particular backend. - - :return: the entity id - """ - return self.id - - @abc.abstractmethod - def store(self): - """Store this entity in the backend. - - Whether it is possible to call store more than once is delegated to the object itself - """ - - @abc.abstractproperty - def is_stored(self): - """Return whether the entity is stored. - - :return: True if stored, False otherwise - :rtype: bool - """ - - -class BackendCollection(typing.Generic[EntityType]): - """Container class that represents a collection of entries of a particular backend entity.""" - - ENTITY_CLASS = None # type: EntityType - - def __init__(self, backend): - """ - :param backend: the backend this collection belongs to - :type backend: :class:`aiida.orm.implementation.Backend` - """ - assert issubclass(self.ENTITY_CLASS, BackendEntity), 'Must set the ENTRY_CLASS class variable to an entity type' - self._backend = backend - - def from_dbmodel(self, dbmodel): - """ - Create an entity from the backend dbmodel - - :param dbmodel: the dbmodel to create the entity from - :return: the entity instance - """ - return self.ENTITY_CLASS.from_dbmodel(dbmodel, self.backend) - - @property - def backend(self): - """ - Return the backend. - - :rtype: :class:`aiida.orm.implementation.Backend` - """ - return self._backend - - def create(self, **kwargs): - """ - Create new a entry and set the attributes to those specified in the keyword arguments - - :return: the newly created entry of type ENTITY_CLASS - """ - return self.ENTITY_CLASS(backend=self._backend, **kwargs) # pylint: disable=not-callable diff --git a/aiida/orm/implementation/comments.py b/aiida/orm/implementation/comments.py index d33e296ad1..57f92111f6 100644 --- a/aiida/orm/implementation/comments.py +++ b/aiida/orm/implementation/comments.py @@ -11,12 +11,12 @@ import abc -from . import backends +from .entities import BackendEntity, BackendCollection __all__ = ('BackendComment', 'BackendCommentCollection') -class BackendComment(backends.BackendEntity): +class BackendComment(BackendEntity): """Base class for a node comment.""" @property @@ -56,13 +56,13 @@ def set_content(self, value): pass -class BackendCommentCollection(backends.BackendCollection[BackendComment]): +class BackendCommentCollection(BackendCollection[BackendComment]): """The collection of Comment entries.""" ENTITY_CLASS = BackendComment @abc.abstractmethod - def create(self, node, user, content=None, **kwargs): + def create(self, node, user, content=None, **kwargs): # pylint: disable=arguments-differ """ Create a Comment for a given node and user diff --git a/aiida/orm/implementation/computers.py b/aiida/orm/implementation/computers.py index b90cd41681..fe06565b74 100644 --- a/aiida/orm/implementation/computers.py +++ b/aiida/orm/implementation/computers.py @@ -12,12 +12,12 @@ import abc import logging -from . import backends +from .entities import BackendEntity, BackendCollection __all__ = ('BackendComputer', 'BackendComputerCollection') -class BackendComputer(backends.BackendEntity): +class BackendComputer(BackendEntity): """ Base class to map a node in the DB + its permanent repository counterpart. 
@@ -117,7 +117,7 @@ def set_transport_type(self, transport_type): pass -class BackendComputerCollection(backends.BackendCollection[BackendComputer]): +class BackendComputerCollection(BackendCollection[BackendComputer]): """The collection of Computer entries.""" ENTITY_CLASS = BackendComputer diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index fbf9bb0ad5..6de13e3f02 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -16,7 +16,7 @@ from aiida.backends.djsite.queries import DjangoQueryManager from aiida.backends.djsite.manager import DjangoBackendManager -from ..sql import SqlBackend +from ..sql.backends import SqlBackend from . import authinfos from . import comments from . import computers diff --git a/aiida/orm/implementation/django/entities.py b/aiida/orm/implementation/django/entities.py index 2bdfeff83d..4e17518edc 100644 --- a/aiida/orm/implementation/django/entities.py +++ b/aiida/orm/implementation/django/entities.py @@ -76,15 +76,6 @@ def dbmodel(self): def id(self): # pylint: disable=invalid-name return self._dbmodel.pk - @property - def pk(self): - """ - Get the principal key for this entry - - :return: the principal key - """ - return self._dbmodel.id - @property def is_stored(self): """ diff --git a/aiida/orm/implementation/django/nodes.py b/aiida/orm/implementation/django/nodes.py index fc393e2114..a37666cb70 100644 --- a/aiida/orm/implementation/django/nodes.py +++ b/aiida/orm/implementation/django/nodes.py @@ -17,7 +17,7 @@ from aiida.backends.djsite.db import models from aiida.common import exceptions from aiida.common.lang import type_check -from aiida.orm.utils.node import clean_value +from aiida.orm.implementation.utils import clean_value from .. import BackendNode, BackendNodeCollection from . import entities @@ -146,299 +146,6 @@ def user(self, user): type_check(user, DjangoUser) self._dbmodel.user = user.dbmodel - @property - def attributes(self): - """Return the complete attributes dictionary. - - .. warning:: While the node is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :return: the attributes as a dictionary - """ - return self.dbmodel.attributes - - def get_attribute(self, key): - """Return the value of an attribute. - - .. warning:: While the node is unstored, this will return a reference of the attribute on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attribute will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. 
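The backend base classes move out of `backends.py` into the new `entities.py` module, while the `Backend` class itself stays where it was. External code that imported the relocated classes from the old path could bridge both layouts with a small compatibility shim (a sketch, not part of this diff):

    try:
        # new location introduced by this change
        from aiida.orm.implementation.entities import BackendEntity, BackendCollection
    except ImportError:
        # old location, for aiida-core versions predating this refactoring
        from aiida.orm.implementation.backends import BackendEntity, BackendCollection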
- - :param key: name of the attribute - :return: the value of the attribute - :raises AttributeError: if the attribute does not exist - """ - try: - return self._dbmodel.attributes[key] - except KeyError as exception: - raise AttributeError('attribute `{}` does not exist'.format(exception)) - - def get_attribute_many(self, keys): - """Return the values of multiple attributes. - - .. warning:: While the node is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :param keys: a list of attribute names - :return: a list of attribute values - :raises AttributeError: if at least one attribute does not exist - """ - try: - return [self.get_attribute(key) for key in keys] - except KeyError as exception: - raise AttributeError('attribute `{}` does not exist'.format(exception)) - - def set_attribute(self, key, value): - """Set an attribute to the given value. - - :param key: name of the attribute - :param value: value of the attribute - """ - if self.is_stored: - value = clean_value(value) - - self._dbmodel.attributes[key] = value - self._flush_if_stored({'attributes'}) - - def set_attribute_many(self, attributes): - """Set multiple attributes. - - .. note:: This will override any existing attributes that are present in the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - if self.is_stored: - attributes = {key: clean_value(value) for key, value in attributes.items()} - - for key, value in attributes.items(): - # We need to use `self.dbmodel` without the underscore, because otherwise the second iteration will refetch - # what is in the database and we lose the initial changes. - self.dbmodel.attributes[key] = value - self._flush_if_stored({'attributes'}) - - def reset_attributes(self, attributes): - """Reset the attributes. - - .. note:: This will completely clear any existing attributes and replace them with the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - if self.is_stored: - attributes = clean_value(attributes) - - self.dbmodel.attributes = attributes - self._flush_if_stored({'attributes'}) - - def delete_attribute(self, key): - """Delete an attribute. - - :param key: name of the attribute - :raises AttributeError: if the attribute does not exist - """ - try: - self._dbmodel.attributes.pop(key) - except KeyError as exception: - raise AttributeError('attribute `{}` does not exist'.format(exception)) - else: - self._flush_if_stored({'attributes'}) - - def delete_attribute_many(self, keys): - """Delete multiple attributes. 
- - :param keys: names of the attributes to delete - :raises AttributeError: if at least one of the attribute does not exist - """ - non_existing_keys = [key for key in keys if key not in self._dbmodel.attributes] - - if non_existing_keys: - raise AttributeError('attributes `{}` do not exist'.format(', '.join(non_existing_keys))) - - for key in keys: - self.dbmodel.attributes.pop(key) - - self._flush_if_stored({'attributes'}) - - def clear_attributes(self): - """Delete all attributes.""" - self._dbmodel.attributes = {} - self._flush_if_stored({'attributes'}) - - def attributes_items(self): - """Return an iterator over the attributes. - - :return: an iterator with attribute key value pairs - """ - for key, value in self._dbmodel.attributes.items(): - yield key, value - - def attributes_keys(self): - """Return an iterator over the attribute keys. - - :return: an iterator with attribute keys - """ - for key in self._dbmodel.attributes: - yield key - - @property - def extras(self): - """Return the complete extras dictionary. - - .. warning:: While the node is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extras - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. - - :return: the extras as a dictionary - """ - return self.dbmodel.extras - - def get_extra(self, key): - """Return the value of an extra. - - .. warning:: While the node is unstored, this will return a reference of the extra on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extra - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. - - :param key: name of the extra - :return: the value of the extra - :raises AttributeError: if the extra does not exist - """ - try: - return self._dbmodel.extras[key] - except KeyError as exception: - raise AttributeError('extra `{}` does not exist'.format(exception)) - - def get_extra_many(self, keys): - """Return the values of multiple extras. - - .. warning:: While the node is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extras - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. 
- - :param keys: a list of extra names - :return: a list of extra values - :raises AttributeError: if at least one extra does not exist - """ - try: - return [self.get_extra(key) for key in keys] - except KeyError as exception: - raise AttributeError('extra `{}` does not exist'.format(exception)) - - def set_extra(self, key, value): - """Set an extra to the given value. - - :param key: name of the extra - :param value: value of the extra - """ - if self.is_stored: - value = clean_value(value) - - self._dbmodel.extras[key] = value - self._flush_if_stored({'extras'}) - - def set_extra_many(self, extras): - """Set multiple extras. - - .. note:: This will override any existing extras that are present in the new dictionary. - - :param extras: a dictionary with the extras to set - """ - if self.is_stored: - extras = {key: clean_value(value) for key, value in extras.items()} - - for key, value in extras.items(): - self.dbmodel.extras[key] = value - - self._flush_if_stored({'extras'}) - - def reset_extras(self, extras): - """Reset the extras. - - .. note:: This will completely clear any existing extras and replace them with the new dictionary. - - :param extras: a dictionary with the extras to set - """ - if self.is_stored: - extras = clean_value(extras) - - self.dbmodel.extras = extras - self._flush_if_stored({'extras'}) - - def delete_extra(self, key): - """Delete an extra. - - :param key: name of the extra - :raises AttributeError: if the extra does not exist - """ - try: - self._dbmodel.extras.pop(key) - except KeyError as exception: - raise AttributeError('extra `{}` does not exist'.format(exception)) - else: - self._flush_if_stored({'extras'}) - - def delete_extra_many(self, keys): - """Delete multiple extras. - - :param keys: names of the extras to delete - :raises AttributeError: if at least one of the extra does not exist - """ - non_existing_keys = [key for key in keys if key not in self._dbmodel.extras] - - if non_existing_keys: - raise AttributeError('extras `{}` do not exist'.format(', '.join(non_existing_keys))) - - for key in keys: - self.dbmodel.extras.pop(key) - - self._flush_if_stored({'extras'}) - - def clear_extras(self): - """Delete all extras.""" - self._dbmodel.extras = {} - self._flush_if_stored({'extras'}) - - def extras_items(self): - """Return an iterator over the extras. - - :return: an iterator with extra key value pairs - """ - for key, value in self._dbmodel.extras.items(): - yield key, value - - def extras_keys(self): - """Return an iterator over the extra keys. - - :return: an iterator with extra keys - """ - for key in self._dbmodel.extras: - yield key - - def _flush_if_stored(self, fields=None): - if self._dbmodel.is_saved(): - self._dbmodel._flush(fields) # pylint: disable=protected-access - def add_incoming(self, source, link_type, link_label): """Add a link of the given type from a given node to ourself. 
@@ -475,7 +182,7 @@ def _add_link(self, source, link_type, link_label): transaction.savepoint_commit(savepoint_id) except IntegrityError as exception: transaction.savepoint_rollback(savepoint_id) - raise exceptions.UniquenessError('failed to create the link: {}'.format(exception)) + raise exceptions.UniquenessError('failed to create the link: {}'.format(exception)) from exception def clean_values(self): self._dbmodel.attributes = clean_value(self._dbmodel.attributes) @@ -523,7 +230,7 @@ def get(self, pk): try: return self.ENTITY_CLASS.from_dbmodel(models.DbNode.objects.get(pk=pk), self.backend) except ObjectDoesNotExist: - raise exceptions.NotExistent("Node with pk '{}' not found".format(pk)) + raise exceptions.NotExistent("Node with pk '{}' not found".format(pk)) from ObjectDoesNotExist def delete(self, pk): """Remove a Node entry from the collection with the given id @@ -533,4 +240,4 @@ def delete(self, pk): try: models.DbNode.objects.filter(pk=pk).delete() # pylint: disable=no-member except ObjectDoesNotExist: - raise exceptions.NotExistent("Node with pk '{}' not found".format(pk)) + raise exceptions.NotExistent("Node with pk '{}' not found".format(pk)) from ObjectDoesNotExist diff --git a/aiida/orm/implementation/django/users.py b/aiida/orm/implementation/django/users.py index db3f663a2d..7940728b04 100644 --- a/aiida/orm/implementation/django/users.py +++ b/aiida/orm/implementation/django/users.py @@ -70,13 +70,14 @@ class DjangoUserCollection(BackendUserCollection): ENTITY_CLASS = DjangoUser - def create(self, email, first_name='', last_name='', institution=''): + def create(self, email, first_name='', last_name='', institution=''): # pylint: disable=arguments-differ """ Create a user with the provided email address :return: A new user object :rtype: :class:`aiida.orm.implementation.django.users.DjangoUser` """ + # pylint: disable=abstract-class-instantiated return DjangoUser(self.backend, email, first_name, last_name, institution) def find(self, email=None, id=None): # pylint: disable=redefined-builtin, invalid-name diff --git a/aiida/orm/implementation/django/utils.py b/aiida/orm/implementation/django/utils.py index 9df7ad8071..58664c5ac9 100644 --- a/aiida/orm/implementation/django/utils.py +++ b/aiida/orm/implementation/django/utils.py @@ -114,9 +114,9 @@ def _is_model_field(self, name): def _flush(self, fields=None): """Flush the fields of the model to the database. - .. note:: If the wrapped model is not actually save in the database yet, this method is a no-op. + .. note:: If the wrapped model is not actually saved in the database yet, this method is a no-op. - :param fields: the model fields whose currently value to flush to the database + :param fields: the model fields whose current value to flush to the database """ if self.is_saved(): try: diff --git a/aiida/orm/implementation/entities.py b/aiida/orm/implementation/entities.py new file mode 100644 index 0000000000..a8cdfadafe --- /dev/null +++ b/aiida/orm/implementation/entities.py @@ -0,0 +1,452 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
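Several hunks above replace bare re-raises with explicit exception chaining. Chaining from the caught instance attaches the original error to `__cause__` and keeps its traceback; the `NotExistent` hunks chain from the `ObjectDoesNotExist` class instead, which is legal but carries no captured instance. A standalone illustration of the pattern:

    def load(pk, registry):
        try:
            return registry[pk]
        except KeyError as exc:
            # chaining from the caught instance preserves the original context
            raise LookupError('entry with pk {} not found'.format(pk)) from exc

    try:
        load(42, {})
    except LookupError as error:
        print(error, '| caused by:', repr(error.__cause__))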
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Classes and methods for backend non-specific entities""" +import abc +import typing + +from aiida.orm.implementation.utils import clean_value, validate_attribute_extra_key + +__all__ = ( + 'BackendEntity', 'BackendCollection', 'EntityType', 'BackendEntityAttributesMixin', 'BackendEntityExtrasMixin' +) + +EntityType = typing.TypeVar('EntityType') # pylint: disable=invalid-name + + +class BackendEntity(abc.ABC): + """An first-class entity in the backend""" + + def __init__(self, backend): + self._backend = backend + self._dbmodel = None + + @property + def backend(self): + """Return the backend this entity belongs to + + :return: the backend instance + """ + return self._backend + + @property + def dbmodel(self): + return self._dbmodel + + @abc.abstractproperty + def id(self): # pylint: disable=invalid-name + """Return the id for this entity. + + This is unique only amongst entities of this type for a particular backend. + + :return: the entity id + """ + + @property + def pk(self): + """Return the id for this entity. + + This is unique only amongst entities of this type for a particular backend. + + :return: the entity id + """ + return self.id + + @abc.abstractmethod + def store(self): + """Store this entity in the backend. + + Whether it is possible to call store more than once is delegated to the object itself + """ + + @abc.abstractproperty + def is_stored(self): + """Return whether the entity is stored. + + :return: True if stored, False otherwise + :rtype: bool + """ + + def _flush_if_stored(self, fields): + if self._dbmodel.is_saved(): + self._dbmodel._flush(fields) # pylint: disable=protected-access + + +class BackendCollection(typing.Generic[EntityType]): + """Container class that represents a collection of entries of a particular backend entity.""" + + ENTITY_CLASS = None # type: EntityType + + def __init__(self, backend): + """ + :param backend: the backend this collection belongs to + :type backend: :class:`aiida.orm.implementation.Backend` + """ + assert issubclass(self.ENTITY_CLASS, BackendEntity), 'Must set the ENTRY_CLASS class variable to an entity type' + self._backend = backend + + def from_dbmodel(self, dbmodel): + """ + Create an entity from the backend dbmodel + + :param dbmodel: the dbmodel to create the entity from + :return: the entity instance + """ + return self.ENTITY_CLASS.from_dbmodel(dbmodel, self.backend) + + @property + def backend(self): + """ + Return the backend. + + :rtype: :class:`aiida.orm.implementation.Backend` + """ + return self._backend + + def create(self, **kwargs): + """ + Create new a entry and set the attributes to those specified in the keyword arguments + + :return: the newly created entry of type ENTITY_CLASS + """ + return self.ENTITY_CLASS(backend=self._backend, **kwargs) # pylint: disable=not-callable + + +class BackendEntityAttributesMixin(abc.ABC): + """Mixin class that adds all methods for the attributes column to a backend entity""" + + @property + def attributes(self): + """Return the complete attributes dictionary. + + .. warning:: While the entity is unstored, this will return references of the attributes on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. 
a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attributes will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you + only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the + getters `get_attribute` and `get_attribute_many` instead. + + :return: the attributes as a dictionary + """ + return self._dbmodel.attributes + + def get_attribute(self, key): + """Return the value of an attribute. + + .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, + meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attribute will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. + + :param key: name of the attribute + :return: the value of the attribute + :raises AttributeError: if the attribute does not exist + """ + try: + return self._dbmodel.attributes[key] + except KeyError as exception: + raise AttributeError('attribute `{}` does not exist'.format(exception)) from exception + + def get_attribute_many(self, keys): + """Return the values of multiple attributes. + + .. warning:: While the entity is unstored, this will return references of the attributes on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attributes will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you + only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the + getters `get_attribute` and `get_attribute_many` instead. + + :param keys: a list of attribute names + :return: a list of attribute values + :raises AttributeError: if at least one attribute does not exist + """ + try: + return [self.get_attribute(key) for key in keys] + except KeyError as exception: + raise AttributeError('attribute `{}` does not exist'.format(exception)) from exception + + def set_attribute(self, key, value): + """Set an attribute to the given value. + + :param key: name of the attribute + :param value: value of the attribute + """ + validate_attribute_extra_key(key) + + if self.is_stored: + value = clean_value(value) + + self._dbmodel.attributes[key] = value + self._flush_if_stored({'attributes'}) + + def set_attribute_many(self, attributes): + """Set multiple attributes. + + .. note:: This will override any existing attributes that are present in the new dictionary. + + :param attributes: a dictionary with the attributes to set + """ + for key in attributes: + validate_attribute_extra_key(key) + + if self.is_stored: + attributes = {key: clean_value(value) for key, value in attributes.items()} + + for key, value in attributes.items(): + # We need to use `self.dbmodel` without the underscore, because otherwise the second iteration will refetch + # what is in the database and we lose the initial changes. 
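The backend setters now pass every key through `validate_attribute_extra_key` before writing (and through `clean_value` once the entity is stored). Judging from the front-end docstrings in this diff, keys containing periods are rejected with a `ValidationError`; a rough illustration, where the exact error message is an assumption:

    from aiida.common import exceptions
    from aiida.orm.implementation.utils import validate_attribute_extra_key

    validate_attribute_extra_key('energy')               # a plain string key passes silently

    try:
        validate_attribute_extra_key('energy.total')      # '.' is reserved as a key separator
    except exceptions.ValidationError as exc:
        print('rejected:', exc)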
+ self.dbmodel.attributes[key] = value + self._flush_if_stored({'attributes'}) + + def reset_attributes(self, attributes): + """Reset the attributes. + + .. note:: This will completely clear any existing attributes and replace them with the new dictionary. + + :param attributes: a dictionary with the attributes to set + """ + for key in attributes: + validate_attribute_extra_key(key) + + if self.is_stored: + attributes = clean_value(attributes) + + self.dbmodel.attributes = attributes + self._flush_if_stored({'attributes'}) + + def delete_attribute(self, key): + """Delete an attribute. + + :param key: name of the attribute + :raises AttributeError: if the attribute does not exist + """ + try: + self._dbmodel.attributes.pop(key) + except KeyError as exception: + raise AttributeError('attribute `{}` does not exist'.format(exception)) from exception + else: + self._flush_if_stored({'attributes'}) + + def delete_attribute_many(self, keys): + """Delete multiple attributes. + + :param keys: names of the attributes to delete + :raises AttributeError: if at least one of the attribute does not exist + """ + non_existing_keys = [key for key in keys if key not in self._dbmodel.attributes] + + if non_existing_keys: + raise AttributeError('attributes `{}` do not exist'.format(', '.join(non_existing_keys))) + + for key in keys: + self.dbmodel.attributes.pop(key) + + self._flush_if_stored({'attributes'}) + + def clear_attributes(self): + """Delete all attributes.""" + self._dbmodel.attributes = {} + self._flush_if_stored({'attributes'}) + + def attributes_items(self): + """Return an iterator over the attributes. + + :return: an iterator with attribute key value pairs + """ + for key, value in self._dbmodel.attributes.items(): + yield key, value + + def attributes_keys(self): + """Return an iterator over the attribute keys. + + :return: an iterator with attribute keys + """ + for key in self._dbmodel.attributes.keys(): + yield key + + @abc.abstractproperty + def is_stored(self): + """Return whether the entity is stored. + + :return: True if stored, False otherwise + :rtype: bool + """ + + @abc.abstractmethod + def _flush_if_stored(self, fields): + """Flush the fields""" + + +class BackendEntityExtrasMixin(abc.ABC): + """Mixin class that adds all methods for the extras column to a backend entity""" + + @property + def extras(self): + """Return the complete extras dictionary. + + .. warning:: While the entity is unstored, this will return references of the extras on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + extras will be a deep copy and mutations of the database extras will have to go through the appropriate set + methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys + or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and + `get_extra_many` instead. + + :return: the extras as a dictionary + """ + return self._dbmodel.extras + + def get_extra(self, key): + """Return the value of an extra. + + .. warning:: While the entity is unstored, this will return a reference of the extra on the database model, + meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. 
As soon as the entity is stored, the returned + extra will be a deep copy and mutations of the database extras will have to go through the appropriate set + methods. + + :param key: name of the extra + :return: the value of the extra + :raises AttributeError: if the extra does not exist + """ + try: + return self._dbmodel.extras[key] + except KeyError as exception: + raise AttributeError('extra `{}` does not exist'.format(exception)) from exception + + def get_extra_many(self, keys): + """Return the values of multiple extras. + + .. warning:: While the entity is unstored, this will return references of the extras on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + extras will be a deep copy and mutations of the database extras will have to go through the appropriate set + methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys + or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and + `get_extra_many` instead. + + :param keys: a list of extra names + :return: a list of extra values + :raises AttributeError: if at least one extra does not exist + """ + return [self.get_extra(key) for key in keys] + + def set_extra(self, key, value): + """Set an extra to the given value. + + :param key: name of the extra + :param value: value of the extra + """ + validate_attribute_extra_key(key) + + if self.is_stored: + value = clean_value(value) + + self._dbmodel.extras[key] = value + self._flush_if_stored({'extras'}) + + def set_extra_many(self, extras): + """Set multiple extras. + + .. note:: This will override any existing extras that are present in the new dictionary. + + :param extras: a dictionary with the extras to set + """ + for key in extras: + validate_attribute_extra_key(key) + + if self.is_stored: + extras = {key: clean_value(value) for key, value in extras.items()} + + for key, value in extras.items(): + self.dbmodel.extras[key] = value + + self._flush_if_stored({'extras'}) + + def reset_extras(self, extras): + """Reset the extras. + + .. note:: This will completely clear any existing extras and replace them with the new dictionary. + + :param extras: a dictionary with the extras to set + """ + for key in extras: + validate_attribute_extra_key(key) + + if self.is_stored: + extras = clean_value(extras) + + self.dbmodel.extras = extras + self._flush_if_stored({'extras'}) + + def delete_extra(self, key): + """Delete an extra. + + :param key: name of the extra + :raises AttributeError: if the extra does not exist + """ + try: + self._dbmodel.extras.pop(key) + except KeyError as exception: + raise AttributeError('extra `{}` does not exist'.format(exception)) from exception + else: + self._flush_if_stored({'extras'}) + + def delete_extra_many(self, keys): + """Delete multiple extras. 
+ + :param keys: names of the extras to delete + :raises AttributeError: if at least one of the extra does not exist + """ + non_existing_keys = [key for key in keys if key not in self._dbmodel.extras] + + if non_existing_keys: + raise AttributeError('extras `{}` do not exist'.format(', '.join(non_existing_keys))) + + for key in keys: + self.dbmodel.extras.pop(key) + + self._flush_if_stored({'extras'}) + + def clear_extras(self): + """Delete all extras.""" + self._dbmodel.extras = {} + self._flush_if_stored({'extras'}) + + def extras_items(self): + """Return an iterator over the extras. + + :return: an iterator with extra key value pairs + """ + for key, value in self._dbmodel.extras.items(): + yield key, value + + def extras_keys(self): + """Return an iterator over the extra keys. + + :return: an iterator with extra keys + """ + for key in self._dbmodel.extras.keys(): + yield key + + @abc.abstractproperty + def is_stored(self): + """Return whether the entity is stored. + + :return: True if stored, False otherwise + :rtype: bool + """ + + @abc.abstractmethod + def _flush_if_stored(self, fields): + """Flush the fields""" diff --git a/aiida/orm/implementation/groups.py b/aiida/orm/implementation/groups.py index f39314060f..b4e7fbf2b9 100644 --- a/aiida/orm/implementation/groups.py +++ b/aiida/orm/implementation/groups.py @@ -12,14 +12,14 @@ import abc from aiida.common import exceptions +from .entities import BackendEntity, BackendCollection, BackendEntityExtrasMixin -from . import backends from .nodes import BackendNode __all__ = ('BackendGroup', 'BackendGroupCollection') -class BackendGroup(backends.BackendEntity): +class BackendGroup(BackendEntity, BackendEntityExtrasMixin): """ An AiiDA ORM implementation of group of nodes. """ @@ -101,7 +101,7 @@ def get_or_create(cls, *args, **kwargs): :return: (group, created) where group is the group (new or existing, in any case already stored) and created is a boolean saying """ - res = cls.query(name=kwargs.get('name')) + res = cls.query(name=kwargs.get('name')) # pylint: disable=no-member if not res: return cls.create(*args, **kwargs), True @@ -193,7 +193,7 @@ def __str__(self): return '"{}" [user-defined], of user {}'.format(self.label, self.user.email) -class BackendGroupCollection(backends.BackendCollection[BackendGroup]): +class BackendGroupCollection(BackendCollection[BackendGroup]): """The collection of Group entries.""" ENTITY_CLASS = BackendGroup diff --git a/aiida/orm/implementation/logs.py b/aiida/orm/implementation/logs.py index 5924d0d228..b59fa52313 100644 --- a/aiida/orm/implementation/logs.py +++ b/aiida/orm/implementation/logs.py @@ -8,15 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Backend group module""" - import abc -from . import backends +from .entities import BackendEntity, BackendCollection __all__ = ('BackendLog', 'BackendLogCollection') -class BackendLog(backends.BackendEntity): +class BackendLog(BackendEntity): """ Backend Log interface """ @@ -85,7 +84,7 @@ def metadata(self): """ -class BackendLogCollection(backends.BackendCollection[BackendLog]): +class BackendLogCollection(BackendCollection[BackendLog]): """The collection of Log entries.""" ENTITY_CLASS = BackendLog diff --git a/aiida/orm/implementation/nodes.py b/aiida/orm/implementation/nodes.py index 67f89fc392..0d7f8eb867 100644 --- a/aiida/orm/implementation/nodes.py +++ b/aiida/orm/implementation/nodes.py @@ -11,12 +11,12 @@ import abc -from . 
import backends +from .entities import BackendEntity, BackendCollection, BackendEntityAttributesMixin, BackendEntityExtrasMixin __all__ = ('BackendNode', 'BackendNodeCollection') -class BackendNode(backends.BackendEntity): +class BackendNode(BackendEntity, BackendEntityExtrasMixin, BackendEntityAttributesMixin, metaclass=abc.ABCMeta): """Wrapper around a `DbNode` instance to set and retrieve data independent of the database implementation.""" # pylint: disable=too-many-public-methods @@ -144,220 +144,6 @@ def mtime(self): """ return self._dbmodel.mtime - @abc.abstractproperty - def attributes(self): - """Return the complete attributes dictionary. - - .. warning:: While the node is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :return: the attributes as a dictionary - """ - - @abc.abstractmethod - def get_attribute(self, key): - """Return the value of an attribute. - - .. warning:: While the node is unstored, this will return a reference of the attribute on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attribute will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. - - :param key: name of the attribute - :return: the value of the attribute - :raises AttributeError: if the attribute does not exist - """ - - @abc.abstractmethod - def get_attribute_many(self, keys): - """Return the values of multiple attributes. - - .. warning:: While the node is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :param keys: a list of attribute names - :return: a list of attribute values - :raises AttributeError: if at least one attribute does not exist - """ - - @abc.abstractmethod - def set_attribute(self, key, value): - """Set an attribute to the given value. - - :param key: name of the attribute - :param value: value of the attribute - """ - - @abc.abstractmethod - def set_attribute_many(self, attributes): - """Set multiple attributes. - - .. note:: This will override any existing attributes that are present in the new dictionary. 
- - :param attributes: a dictionary with the attributes to set - """ - - @abc.abstractmethod - def reset_attributes(self, attributes): - """Reset the attributes. - - .. note:: This will completely clear any existing attributes and replace them with the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - - @abc.abstractmethod - def delete_attribute(self, key): - """Delete an attribute. - - :param key: name of the attribute - :raises AttributeError: if the attribute does not exist - """ - - @abc.abstractmethod - def delete_attribute_many(self, keys): - """Delete multiple attributes. - - :param keys: names of the attributes to delete - :raises AttributeError: if at least one of the attribute does not exist - """ - - @abc.abstractmethod - def clear_attributes(self): - """Delete all attributes.""" - - @abc.abstractmethod - def attributes_items(self): - """Return an iterator over the attributes. - - :return: an iterator with attribute key value pairs - """ - - @abc.abstractmethod - def attributes_keys(self): - """Return an iterator over the attribute keys. - - :return: an iterator with attribute keys - """ - - @abc.abstractproperty - def extras(self): - """Return the complete extras dictionary. - - .. warning:: While the node is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extras - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. - - :return: the extras as a dictionary - """ - - @abc.abstractmethod - def get_extra(self, key): - """Return the value of an extra. - - .. warning:: While the node is unstored, this will return a reference of the extra on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extra - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. - - :param key: name of the extra - :return: the value of the extra - :raises AttributeError: if the extra does not exist - """ - - @abc.abstractmethod - def get_extra_many(self, keys): - """Return the values of multiple extras. - - .. warning:: While the node is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extras - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. 
- - :param keys: a list of extra names - :return: a list of extra values - :raises AttributeError: if at least one extra does not exist - """ - - @abc.abstractmethod - def set_extra(self, key, value): - """Set an extra to the given value. - - :param key: name of the extra - :param value: value of the extra - """ - - @abc.abstractmethod - def set_extra_many(self, extras): - """Set multiple extras. - - .. note:: This will override any existing extras that are present in the new dictionary. - - :param extras: a dictionary with the extras to set - """ - - @abc.abstractmethod - def reset_extras(self, extras): - """Reset the extras. - - .. note:: This will completely clear any existing extras and replace them with the new dictionary. - - :param extras: a dictionary with the extras to set - """ - - @abc.abstractmethod - def delete_extra(self, key): - """Delete an extra. - - :param key: name of the extra - :raises AttributeError: if the extra does not exist - """ - - @abc.abstractmethod - def delete_extra_many(self, keys): - """Delete multiple extras. - - :param keys: names of the extras to delete - :raises AttributeError: if at least one of the extra does not exist - """ - - @abc.abstractmethod - def clear_extras(self): - """Delete all extras.""" - - @abc.abstractmethod - def extras_items(self): - """Return an iterator over the extras. - - :return: an iterator with extra key value pairs - """ - - @abc.abstractmethod - def extras_keys(self): - """Return an iterator over the extra keys. - - :return: an iterator with extra keys - """ - @abc.abstractmethod def add_incoming(self, source, link_type, link_label): """Add a link of the given type from a given node to ourself. @@ -371,7 +157,7 @@ def add_incoming(self, source, link_type, link_label): """ @abc.abstractmethod - def store(self, links=None, with_transaction=True, clean=True): + def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ """Store the node in the database. 
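With the abstract attribute and extra declarations deleted here, `BackendNode` (and hence the concrete Django and SQLAlchemy node classes) obtains these methods from the mixins introduced earlier in this diff. A quick sanity check one could run against an installation carrying this change, assuming nothing re-overrides the methods:

    from aiida.orm.implementation.nodes import BackendNode
    from aiida.orm.implementation.entities import (
        BackendEntityAttributesMixin,
        BackendEntityExtrasMixin,
    )

    assert issubclass(BackendNode, BackendEntityAttributesMixin)
    assert issubclass(BackendNode, BackendEntityExtrasMixin)
    # the per-backend copies were removed above, so lookups resolve to the mixin versions
    assert BackendNode.set_attribute is BackendEntityAttributesMixin.set_attribute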
:param links: optional links to add before storing @@ -380,7 +166,7 @@ def store(self, links=None, with_transaction=True, clean=True): """ -class BackendNodeCollection(backends.BackendCollection[BackendNode]): +class BackendNodeCollection(BackendCollection[BackendNode]): """The collection of `BackendNode` entries.""" # pylint: disable=too-few-public-methods diff --git a/aiida/orm/implementation/sql/__init__.py b/aiida/orm/implementation/sql/__init__.py index 91890bd3d5..3cea3705ad 100644 --- a/aiida/orm/implementation/sql/__init__.py +++ b/aiida/orm/implementation/sql/__init__.py @@ -12,9 +12,3 @@ All SQL backends with an ORM should subclass from the classes in this module """ - -# pylint: disable=wildcard-import - -from .backends import * - -__all__ = (backends.__all__) diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index f0db5ed627..2bb21f22af 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Generic backend related objects""" - import abc import typing @@ -37,7 +36,7 @@ def get_backend_entity(self, model): :param model: the ORM model instance to promote to a backend instance :return: the backend entity corresponding to the given model - :rtype: :class:`aiida.orm.implementation.BackendEntity` + :rtype: :class:`aiida.orm.implementation.entities.BackendEntity` """ @abc.abstractmethod diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 224d4933df..fa4ba06941 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -14,7 +14,7 @@ from aiida.backends.sqlalchemy.queries import SqlaQueryManager from aiida.backends.sqlalchemy.manager import SqlaBackendManager -from ..sql import SqlBackend +from ..sql.backends import SqlBackend from . import authinfos from . import comments from . 
import computers diff --git a/aiida/orm/implementation/sqlalchemy/entities.py b/aiida/orm/implementation/sqlalchemy/entities.py index 5dcbc0139c..67e3043c63 100644 --- a/aiida/orm/implementation/sqlalchemy/entities.py +++ b/aiida/orm/implementation/sqlalchemy/entities.py @@ -84,15 +84,6 @@ def id(self): # pylint: disable=redefined-builtin, invalid-name """ return self._dbmodel.id - @property - def pk(self): - """ - Get the principal key for this entry - - :return: the principal key - """ - return self._dbmodel.id - @property def is_stored(self): """ diff --git a/aiida/orm/implementation/sqlalchemy/groups.py b/aiida/orm/implementation/sqlalchemy/groups.py index 5302d4f0d6..0864c7d3b1 100644 --- a/aiida/orm/implementation/sqlalchemy/groups.py +++ b/aiida/orm/implementation/sqlalchemy/groups.py @@ -70,7 +70,8 @@ def label(self, label): try: self._dbmodel.save() except Exception: - raise UniquenessError('a group of the same type with the label {} already exists'.format(label)) + raise UniquenessError('a group of the same type with the label {} already exists'.format(label)) \ + from Exception @property def description(self): @@ -109,7 +110,7 @@ def __int__(self): if not self.is_stored: return None - return self._dbnode.id + return self._dbnode.id # pylint: disable=no-member @property def is_stored(self): @@ -228,37 +229,56 @@ def check_node(given_node): # Commit everything as up till now we've just flushed session.commit() - def remove_nodes(self, nodes): + def remove_nodes(self, nodes, **kwargs): """Remove a node or a set of nodes from the group. :note: all the nodes *and* the group itself have to be stored. :param nodes: a list of `BackendNode` instance to be added to this group + :param kwargs: + skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL + DELETE statement to the group-node relationship table in order to improve speed. 
""" + from sqlalchemy import and_ + from aiida.backends.sqlalchemy import get_scoped_session + from aiida.backends.sqlalchemy.models.base import Base from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode super().remove_nodes(nodes) # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database dbnodes = self._dbmodel.dbnodes + skip_orm = kwargs.get('skip_orm', False) - list_nodes = [] - - for node in nodes: + def check_node(node): if not isinstance(node, SqlaNode): raise TypeError('invalid type {}, has to be {}'.format(type(node), SqlaNode)) if node.id is None: raise ValueError('At least one of the provided nodes is unstored, stopping...') - # If we don't check first, SqlA might issue a DELETE statement for an unexisting key, resulting in an error - if node.dbmodel in dbnodes: - list_nodes.append(node.dbmodel) + list_nodes = [] - for node in list_nodes: - dbnodes.remove(node) + with utils.disable_expire_on_commit(get_scoped_session()) as session: + if not skip_orm: + for node in nodes: + check_node(node) + + # Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error + if node.dbmodel in dbnodes: + list_nodes.append(node.dbmodel) + + for node in list_nodes: + dbnodes.remove(node) + else: + table = Base.metadata.tables['db_dbgroup_dbnodes'] + for node in nodes: + check_node(node) + clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id) + statement = table.delete().where(clause) + session.execute(statement) - sa.get_scoped_session().commit() + session.commit() class SqlaGroupCollection(BackendGroupCollection): diff --git a/aiida/orm/implementation/sqlalchemy/nodes.py b/aiida/orm/implementation/sqlalchemy/nodes.py index 6f0372646e..8b857c746d 100644 --- a/aiida/orm/implementation/sqlalchemy/nodes.py +++ b/aiida/orm/implementation/sqlalchemy/nodes.py @@ -18,7 +18,7 @@ from aiida.backends.sqlalchemy.models import node as models from aiida.common import exceptions from aiida.common.lang import type_check -from aiida.orm.utils.node import clean_value +from aiida.orm.implementation.utils import clean_value from .. import BackendNode, BackendNodeCollection from . import entities @@ -148,307 +148,6 @@ def user(self, user): type_check(user, SqlaUser) self._dbmodel.user = user.dbmodel - @property - def attributes(self): - """Return the complete attributes dictionary. - - .. warning:: While the node is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :return: the attributes as a dictionary - """ - return self._dbmodel.attributes - - def get_attribute(self, key): - """Return the value of an attribute. - - .. warning:: While the node is unstored, this will return a reference of the attribute on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. 
As soon as the node is stored, the returned - attribute will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. - - :param key: name of the attribute - :return: the value of the attribute - :raises AttributeError: if the attribute does not exist - """ - try: - return self._dbmodel.attributes[key] - except KeyError as exception: - raise AttributeError('attribute `{}` does not exist'.format(exception)) - - def get_attribute_many(self, keys): - """Return the values of multiple attributes. - - .. warning:: While the node is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :param keys: a list of attribute names - :return: a list of attribute values - :raises AttributeError: if at least one attribute does not exist - """ - try: - return [self.get_attribute(key) for key in keys] - except KeyError as exception: - raise AttributeError('attribute `{}` does not exist'.format(exception)) - - def set_attribute(self, key, value): - """Set an attribute to the given value. - - :param key: name of the attribute - :param value: value of the attribute - """ - if self.is_stored: - value = clean_value(value) - - self._dbmodel.attributes[key] = value - self._flag_field('attributes') - self._flush_if_stored() - - def set_attribute_many(self, attributes): - """Set multiple attributes. - - .. note:: This will override any existing attributes that are present in the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - if self.is_stored: - attributes = {key: clean_value(value) for key, value in attributes.items()} - - for key, value in attributes.items(): - self.dbmodel.attributes[key] = value - - self._flag_field('attributes') - self._flush_if_stored() - - def reset_attributes(self, attributes): - """Reset the attributes. - - .. note:: This will completely clear any existing attributes and replace them with the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - if self.is_stored: - attributes = clean_value(attributes) - - self.dbmodel.attributes = attributes - self._flag_field('attributes') - self._flush_if_stored() - - def delete_attribute(self, key): - """Delete an attribute. - - :param key: name of the attribute - :raises AttributeError: if the attribute does not exist - """ - try: - self._dbmodel.attributes.pop(key) - except KeyError as exception: - raise AttributeError('attribute `{}` does not exist'.format(exception)) - else: - self._flag_field('attributes') - self._flush_if_stored() - - def delete_attribute_many(self, keys): - """Delete multiple attributes. 
- - :param keys: names of the attributes to delete - :raises AttributeError: if at least one of the attribute does not exist - """ - non_existing_keys = [key for key in keys if key not in self._dbmodel.attributes] - - if non_existing_keys: - raise AttributeError('attributes `{}` do not exist'.format(', '.join(non_existing_keys))) - - for key in keys: - self.dbmodel.attributes.pop(key) - - self._flag_field('attributes') - self._flush_if_stored() - - def clear_attributes(self): - """Delete all attributes.""" - self._dbmodel.attributes = {} - - def attributes_items(self): - """Return an iterator over the attributes. - - :return: an iterator with attribute key value pairs - """ - for key, value in self._dbmodel.attributes.items(): - yield key, value - - def attributes_keys(self): - """Return an iterator over the attribute keys. - - :return: an iterator with attribute keys - """ - for key in self._dbmodel.attributes.keys(): - yield key - - @property - def extras(self): - """Return the complete extras dictionary. - - .. warning:: While the node is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extras - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. - - :return: the extras as a dictionary - """ - return self._dbmodel.extras - - def get_extra(self, key): - """Return the value of an extra. - - .. warning:: While the node is unstored, this will return a reference of the extra on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extra - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. - - :param key: name of the extra - :return: the value of the extra - :raises AttributeError: if the extra does not exist - """ - try: - return self._dbmodel.extras[key] - except KeyError as exception: - raise AttributeError('extra `{}` does not exist'.format(exception)) - - def get_extra_many(self, keys): - """Return the values of multiple extras. - - .. warning:: While the node is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extras - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. 
- - :param keys: a list of extra names - :return: a list of extra values - :raises AttributeError: if at least one extra does not exist - """ - try: - return [self.get_extra(key) for key in keys] - except KeyError as exception: - raise AttributeError('extra `{}` does not exist'.format(exception)) - - def set_extra(self, key, value): - """Set an extra to the given value. - - :param key: name of the extra - :param value: value of the extra - """ - if self.is_stored: - value = clean_value(value) - - self._dbmodel.extras[key] = value - self._flag_field('extras') - self._flush_if_stored() - - def set_extra_many(self, extras): - """Set multiple extras. - - .. note:: This will override any existing extras that are present in the new dictionary. - - :param extras: a dictionary with the extras to set - """ - if self.is_stored: - extras = {key: clean_value(value) for key, value in extras.items()} - - for key, value in extras.items(): - self.dbmodel.extras[key] = value - - self._flag_field('extras') - self._flush_if_stored() - - def reset_extras(self, extras): - """Reset the extras. - - .. note:: This will completely clear any existing extras and replace them with the new dictionary. - - :param extras: a dictionary with the extras to set - """ - self.dbmodel.extras = extras - self._flag_field('extras') - self._flush_if_stored() - - def delete_extra(self, key): - """Delete an extra. - - :param key: name of the extra - :raises AttributeError: if the extra does not exist - """ - try: - self._dbmodel.extras.pop(key) - except KeyError as exception: - raise AttributeError('extra `{}` does not exist'.format(exception)) - else: - self._flag_field('extras') - self._flush_if_stored() - - def delete_extra_many(self, keys): - """Delete multiple extras. - - :param keys: names of the extras to delete - :raises AttributeError: if at least one of the extra does not exist - """ - non_existing_keys = [key for key in keys if key not in self._dbmodel.extras] - - if non_existing_keys: - raise AttributeError('extras `{}` do not exist'.format(', '.join(non_existing_keys))) - - for key in keys: - self.dbmodel.extras.pop(key) - - self._flag_field('extras') - self._flush_if_stored() - - def clear_extras(self): - """Delete all extras.""" - self._dbmodel.extras = {} - - def extras_items(self): - """Return an iterator over the extras. - - :return: an iterator with extra key value pairs - """ - for key, value in self._dbmodel.extras.items(): - yield key, value - - def extras_keys(self): - """Return an iterator over the extra keys. - - :return: an iterator with extra keys - """ - for key in self._dbmodel.extras.keys(): - yield key - - def _flag_field(self, field): - from aiida.backends.sqlalchemy.utils import flag_modified - flag_modified(self._dbmodel, field) - - def _flush_if_stored(self): - if self._dbmodel.is_saved(): - self._dbmodel.save() - def add_incoming(self, source, link_type, link_label): """Add a link of the given type from a given node to ourself. 
@@ -487,7 +186,7 @@ def _add_link(self, source, link_type, link_label): link = DbLink(input_id=source.id, output_id=self.id, label=link_label, type=link_type.value) session.add(link) except SQLAlchemyError as exception: - raise exceptions.UniquenessError('failed to create the link: {}'.format(exception)) + raise exceptions.UniquenessError('failed to create the link: {}'.format(exception)) from exception def clean_values(self): self._dbmodel.attributes = clean_value(self._dbmodel.attributes) @@ -536,7 +235,7 @@ def get(self, pk): try: return self.ENTITY_CLASS.from_dbmodel(session.query(models.DbNode).filter_by(id=pk).one(), self.backend) except NoResultFound: - raise exceptions.NotExistent("Node with pk '{}' not found".format(pk)) + raise exceptions.NotExistent("Node with pk '{}' not found".format(pk)) from NoResultFound def delete(self, pk): """Remove a Node entry from the collection with the given id @@ -549,4 +248,4 @@ def delete(self, pk): session.query(models.DbNode).filter_by(id=pk).one().delete() session.commit() except NoResultFound: - raise exceptions.NotExistent("Node with pk '{}' not found".format(pk)) + raise exceptions.NotExistent("Node with pk '{}' not found".format(pk)) from NoResultFound diff --git a/aiida/orm/implementation/sqlalchemy/users.py b/aiida/orm/implementation/sqlalchemy/users.py index cf97968f39..55b4ed18ce 100644 --- a/aiida/orm/implementation/sqlalchemy/users.py +++ b/aiida/orm/implementation/sqlalchemy/users.py @@ -66,13 +66,14 @@ class SqlaUserCollection(BackendUserCollection): ENTITY_CLASS = SqlaUser - def create(self, email, first_name='', last_name='', institution=''): + def create(self, email, first_name='', last_name='', institution=''): # pylint: disable=arguments-differ """ Create a user with the provided email address :return: A new user object :rtype: :class:`aiida.orm.User` """ + # pylint: disable=abstract-class-instantiated return SqlaUser(self.backend, email, first_name, last_name, institution) def find(self, email=None, id=None): # pylint: disable=redefined-builtin,invalid-name diff --git a/aiida/orm/implementation/sqlalchemy/utils.py b/aiida/orm/implementation/sqlalchemy/utils.py index 1ca086aa25..d265517bf2 100644 --- a/aiida/orm/implementation/sqlalchemy/utils.py +++ b/aiida/orm/implementation/sqlalchemy/utils.py @@ -120,9 +120,9 @@ def _is_model_field(self, field): def _flush(self, fields=()): """Flush the fields of the model to the database. - .. note:: If the wrapped model is not actually save in the database yet, this method is a no-op. + .. note:: If the wrapped model is not actually saved in the database yet, this method is a no-op. - :param fields: the model fields whose currently value to flush to the database + :param fields: the model fields whose current value to flush to the database """ if self.is_saved(): for field in fields: diff --git a/aiida/orm/implementation/users.py b/aiida/orm/implementation/users.py index a2ec8dd52b..7cae760bd5 100644 --- a/aiida/orm/implementation/users.py +++ b/aiida/orm/implementation/users.py @@ -10,12 +10,12 @@ """Backend user""" import abc -from . import backends +from .entities import BackendEntity, BackendCollection __all__ = ('BackendUser', 'BackendUserCollection') -class BackendUser(backends.BackendEntity): +class BackendUser(BackendEntity): """ This is the base class for User information in AiiDA. An implementing backend needs to provide a concrete version. 
@@ -106,7 +106,7 @@ def institution(self, val): """ -class BackendUserCollection(backends.BackendCollection[BackendUser]): +class BackendUserCollection(BackendCollection[BackendUser]): # pylint: disable=too-few-public-methods ENTITY_CLASS = BackendUser diff --git a/aiida/orm/implementation/utils.py b/aiida/orm/implementation/utils.py new file mode 100644 index 0000000000..3641eff954 --- /dev/null +++ b/aiida/orm/implementation/utils.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Utility methods for backend non-specific implementations.""" +import math +import numbers + +from collections.abc import Iterable, Mapping + +from aiida.common import exceptions +from aiida.common.constants import AIIDA_FLOAT_PRECISION + +# This separator character is reserved to indicate nested fields in node attribute and extras dictionaries and +# therefore is not allowed in individual attribute or extra keys. +FIELD_SEPARATOR = '.' + +__all__ = ('validate_attribute_extra_key', 'clean_value') + + +def validate_attribute_extra_key(key): + """Validate the key for an entity attribute or extra. + + :raise aiida.common.ValidationError: if the key is not a string or contains reserved separator character + """ + if not key or not isinstance(key, str): + raise exceptions.ValidationError('key for attributes or extras should be a string') + + if FIELD_SEPARATOR in key: + raise exceptions.ValidationError( + 'key for attributes or extras cannot contain the character `{}`'.format(FIELD_SEPARATOR) + ) + + +def clean_value(value): + """ + Get value from input and (recursively) replace, if needed, all occurrences + of BaseType AiiDA data nodes with their value, and List with a standard list. + It also makes a deep copy of everything + The purpose of this function is to convert data to a type which can be serialized and deserialized + for storage in the DB without its value changing. + + Note however that there is no logic to avoid infinite loops when the + user passes some perverse recursive dictionary or list. + In any case, however, this would not be storable by AiiDA... + + :param value: A value to be set as an attribute or an extra + :return: a "cleaned" value, potentially identical to value, but with + values replaced where needed. + """ + # Must be imported in here to avoid recursive imports + from aiida.orm import BaseType + + def clean_builtin(val): + """ + A function to clean build-in python values (`BaseType`). + + It mainly checks that we don't store NaN or Inf. 
+ """ + # This is a whitelist of all the things we understand currently + if val is None or isinstance(val, (bool, str)): + return val + + # This fixes #2773 - in python3, ``numpy.int64(-1)`` cannot be json-serialized + # Note that `numbers.Integral` also match booleans but they are already returned above + if isinstance(val, numbers.Integral): + return int(val) + + if isinstance(val, numbers.Real) and (math.isnan(val) or math.isinf(val)): + # see https://www.postgresql.org/docs/current/static/datatype-json.html#JSON-TYPE-MAPPING-TABLE + raise exceptions.ValidationError('nan and inf/-inf can not be serialized to the database') + + # This is for float-like types, like ``numpy.float128`` that are not json-serializable + # Note that `numbers.Real` also match booleans but they are already returned above + if isinstance(val, numbers.Real): + string_representation = '{{:.{}g}}'.format(AIIDA_FLOAT_PRECISION).format(val) + new_val = float(string_representation) + if 'e' in string_representation and new_val.is_integer(): + # This is indeed often quite unexpected, because it is going to change the type of the data + # from float to int. But anyway clean_value is changing some types, and we are also bound to what + # our current backends do. + # Currently, in both Django and SQLA (with JSONB attributes), if we store 1.e1, ..., 1.e14, 1.e15, + # they will be stored as floats; instead 1.e16, 1.e17, ... will all be stored as integer anyway, + # even if we don't run this clean_value step. + # So, for consistency, it's better if we do the conversion ourselves here, and we do it for a bit + # smaller numbers than python+[SQL+JSONB] would do (the AiiDA float precision is here 14), so the + # results are consistent, and the hashing will work also after a round trip as expected. + return int(new_val) + return new_val + + # Anything else we do not understand and we refuse + raise exceptions.ValidationError('type `{}` is not supported as it is not json-serializable'.format(type(val))) + + if isinstance(value, BaseType): + return clean_builtin(value.value) + + if isinstance(value, Mapping): + # Check dictionary before iterables + return {k: clean_value(v) for k, v in value.items()} + + if (isinstance(value, Iterable) and not isinstance(value, str)): + # list, tuple, ... but not a string + # This should also properly take care of dealing with the + # basedatatypes.List object + return [clean_value(v) for v in value] + + # If I don't know what to do I just return the value + # itself - it's not super robust, but relies on duck typing + # (e.g. if there is something that behaves like an integer + # but is not an integer, I still accept it) + + return clean_builtin(value) diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 4c78eac4ea..1c3b64777d 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -7,11 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines """ This module defines the classes related to band structures or dispersions in a Brillouin zone, and how to operate on them. 
""" - from string import Template import numpy @@ -22,6 +22,7 @@ def prepare_header_comment(uuid, plot_info, comment_char='#'): + """Prepare the header.""" from aiida import get_file_header filetext = [] @@ -32,13 +33,10 @@ def prepare_header_comment(uuid, plot_info, comment_char='#'): filetext.append('\t{}\t{}'.format(*plot_info['y'].shape)) filetext.append('') filetext.append('\tlabel\tpoint') - for l in plot_info['raw_labels']: - filetext.append('\t{}\t{:.8f}'.format(l[1], l[0])) - - return '\n'.join('{} {}'.format(comment_char, l) for l in filetext) - + for label in plot_info['raw_labels']: + filetext.append('\t{}\t{:.8f}'.format(label[1], label[0])) -# TODO: set and get bands could have more functionalities: how do I know the number of bands for example? + return '\n'.join('{} {}'.format(comment_char, line) for line in filetext) def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None): @@ -71,14 +69,15 @@ def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None): equal to the lumo (e.g. in semi-metals). """ + # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements,no-else-return + def nint(num): """ Stable rounding function """ - if (num > 0): + if num > 0: return int(num + .5) - else: - return int(num - .5) + return int(num - .5) if fermi_energy and number_electrons: raise ValueError('Specify either the number of electrons or the Fermi energy, but not both') @@ -89,11 +88,9 @@ def nint(num): raise KeyError('Cannot do much of a band analysis without bands') if len(stored_bands.shape) == 3: - # I write the algorithm for the generic case of having both the - # spin up and spin down array - + # I write the algorithm for the generic case of having both the spin up and spin down array # put all spins on one band per kpoint - bands = numpy.concatenate([_ for _ in stored_bands], axis=1) + bands = numpy.concatenate(stored_bands, axis=1) else: bands = stored_bands @@ -114,7 +111,7 @@ def nint(num): # spin up and spin down array # put all spins on one band per kpoint - occupations = numpy.concatenate([_ for _ in stored_occupations], axis=1) + occupations = numpy.concatenate(stored_occupations, axis=1) else: occupations = stored_occupations @@ -124,23 +121,29 @@ def nint(num): # sort the bands by energy, and reorder the occupations accordingly # since after joining the two spins, I might have unsorted stuff bands, occupations = [ - numpy.array(y) for y in zip(*[ - list(zip(*j)) for j in - [sorted(zip(i[0].tolist(), i[1].tolist()), key=lambda x: x[0]) for i in zip(bands, occupations)] - ]) + numpy.array(y) for y in zip( + *[ + list(zip(*j)) for j in [ + sorted(zip(i[0].tolist(), i[1].tolist()), key=lambda x: x[0]) + for i in zip(bands, occupations) + ] + ] + ) ] number_electrons = int(round(sum([sum(i) for i in occupations]) / num_kpoints)) homo_indexes = [numpy.where(numpy.array([nint(_) for _ in x]) > 0)[0][-1] for x in occupations] if len(set(homo_indexes)) > 1: # there must be intersections of valence and conduction bands return False, None - else: - homo = [_[0][_[1]] for _ in zip(bands, homo_indexes)] - try: - lumo = [_[0][_[1] + 1] for _ in zip(bands, homo_indexes)] - except IndexError: - raise ValueError('To understand if it is a metal or insulator, ' - 'need more bands than n_band=number_electrons') + + homo = [_[0][_[1]] for _ in zip(bands, homo_indexes)] + try: + lumo = [_[0][_[1] + 1] for _ in zip(bands, homo_indexes)] + except IndexError: + raise ValueError( + 'To understand if it is a metal or insulator, ' + 'need more bands than 
n_band=number_electrons' + ) else: bands = numpy.sort(bands) @@ -155,8 +158,10 @@ def nint(num): # gather the energies of the lumo band, for every kpoint lumo = [i[number_electrons // number_electrons_per_band] for i in bands] # take the n+1th level except IndexError: - raise ValueError('To understand if it is a metal or insulator, ' - 'need more bands than n_band=number_electrons') + raise ValueError( + 'To understand if it is a metal or insulator, ' + 'need more bands than n_band=number_electrons' + ) if number_electrons % 2 == 1 and len(stored_bands.shape) == 2: # if #electrons is odd and we have a non spin polarized calculation @@ -167,10 +172,11 @@ def nint(num): gap = min(lumo) - max(homo) if gap == 0.: return False, 0. - elif gap < 0.: + + if gap < 0.: return False, None - else: - return True, gap + + return True, gap # analysis on the fermi energy else: @@ -188,23 +194,22 @@ def nint(num): raise ValueError("The Fermi energy is below all band energies, don't know what to do.") # one band is crossed by the fermi energy - if any(i[1] < fermi_energy and fermi_energy < i[0] for i in max_mins): + if any(i[1] < fermi_energy and fermi_energy < i[0] for i in max_mins): # pylint: disable=chained-comparison return False, None # case of semimetals, fermi energy at the crossing of two bands # this will only work if the dirac point is computed! - elif (any(i[0] == fermi_energy for i in max_mins) and any(i[1] == fermi_energy for i in max_mins)): + if (any(i[0] == fermi_energy for i in max_mins) and any(i[1] == fermi_energy for i in max_mins)): return False, 0. - # insulating case - else: - # take the max of the band maxima below the fermi energy - homo = max([i[0] for i in max_mins if i[0] < fermi_energy]) - # take the min of the band minima above the fermi energy - lumo = min([i[1] for i in max_mins if i[1] > fermi_energy]) - gap = lumo - homo - if gap <= 0.: - raise Exception('Something wrong has been implemented. Revise the code!') - return True, gap + + # insulating case, take the max of the band maxima below the fermi energy + homo = max([i[0] for i in max_mins if i[0] < fermi_energy]) + # take the min of the band minima above the fermi energy + lumo = min([i[1] for i in max_mins if i[1] > fermi_energy]) + gap = lumo - homo + if gap <= 0.: + raise Exception('Something wrong has been implemented. Revise the code!') + return True, gap class BandsData(KpointsData): @@ -212,15 +217,6 @@ class BandsData(KpointsData): Class to handle bands data """ - # Associate file extensions to default plotting formats - _custom_export_format_replacements = { - 'dat': 'dat_multicolumn', - 'png': 'mpl_png', - 'pdf': 'mpl_pdf', - 'py': 'mpl_singlefile', - 'gnu': 'gnuplot' - } - def set_kpointsdata(self, kpointsdata): """ Load the kpoints from a kpoint object. @@ -258,6 +254,7 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): Nkpoints x Nbands floats or Nspins x Nkpoints x Nbands; Nkpoints must correspond to the number of kpoints. 
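An illustrative call of `find_bandgap` as refactored above (a sketch, not part of the diff); the two-k-point, two-band node built here is an assumption chosen so the occupation-free branch with `number_electrons` is exercised:

import numpy
from aiida.orm import BandsData, KpointsData
from aiida.orm.nodes.data.array.bands import find_bandgap

kpoints = KpointsData()
kpoints.set_kpoints(numpy.array([[0., 0., 0.], [0.5, 0., 0.]]))

bands = BandsData()
bands.set_kpointsdata(kpoints)
# Two bands on two k-points: one fully occupied, one empty (a trivial insulator).
bands.set_bands(numpy.array([[-5.0, 2.0], [-4.5, 2.5]]), units='eV')

is_insulator, gap = find_bandgap(bands, number_electrons=2)
# Expected: is_insulator is True and gap == min(lumo) - max(homo) = 2.0 - (-4.5) = 6.5
print(is_insulator, gap)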
""" + # pylint: disable=too-many-branches try: kpoints = self.get_kpoints() except AttributeError: @@ -266,9 +263,11 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): the_bands = numpy.array(bands) if len(the_bands.shape) not in [2, 3]: - raise ValueError('Bands must be an array of dimension 2' - '([N_kpoints, N_bands]) or of dimension 3 ' - ' ([N_arrays, N_kpoints, N_bands]), found instead {}'.format(len(the_bands.shape))) + raise ValueError( + 'Bands must be an array of dimension 2' + '([N_kpoints, N_bands]) or of dimension 3 ' + ' ([N_arrays, N_kpoints, N_bands]), found instead {}'.format(len(the_bands.shape)) + ) list_of_arrays_to_be_checked = [] @@ -280,8 +279,10 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): if occupations is not None: the_occupations = numpy.array(occupations) if the_occupations.shape != the_bands.shape: - raise ValueError('Shape of occupations {} different from shape' - 'shape of bands {}'.format(the_occupations.shape, the_bands.shape)) + raise ValueError( + 'Shape of occupations {} different from shape' + 'shape of bands {}'.format(the_occupations.shape, the_bands.shape) + ) if not the_bands.dtype.type == numpy.float64: list_of_arrays_to_be_checked.append([the_occupations, 'occupations']) @@ -306,8 +307,10 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): elif isinstance(labels, (tuple, list)) and all([isinstance(_, str) for _ in labels]): the_labels = [str(_) for _ in labels] else: - raise ValidationError('Band labels have an unrecognized type ({})' - 'but should be a string or a list of strings'.format(labels.__class__)) + raise ValidationError( + 'Band labels have an unrecognized type ({})' + 'but should be a string or a list of strings'.format(labels.__class__) + ) if len(the_bands.shape) == 2 and len(the_labels) != 1: raise ValidationError('More array labels than the number of arrays') @@ -405,8 +408,8 @@ def get_bands(self, also_occupations=False, also_labels=False): if len(to_return) == 1: return bands - else: - return to_return + + return to_return def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, get_segments=False, y_origin=0.): """ @@ -432,6 +435,7 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, depending on the type of spin; the length is always equalt to the total number of bands per kpoint). """ + # pylint: disable=too-many-locals,too-many-branches,too-many-statements # load the x and y's of the graph stored_bands = self.get_bands() if len(stored_bands.shape) == 2: @@ -439,7 +443,7 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, band_type_idx = numpy.array([0] * stored_bands.shape[1]) two_band_types = False elif len(stored_bands.shape) == 3: - bands = numpy.concatenate([_ for _ in stored_bands], axis=1) + bands = numpy.concatenate(stored_bands, axis=1) band_type_idx = numpy.array([0] * stored_bands.shape[2] + [1] * stored_bands.shape[2]) two_band_types = True else: @@ -468,8 +472,9 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, # as a result, where there are discontinuities in the path, # I have two consecutive points with the same x coordinate distances = [ - numpy.linalg.norm(kpoints[i] - kpoints[i - 1]) - if not (i in labels_indices and i - 1 in labels_indices) else 0. for i in range(1, len(kpoints)) + numpy.linalg.norm(kpoints[i] - + kpoints[i - 1]) if not (i in labels_indices and i - 1 in labels_indices) else 0. 
+ for i in range(1, len(kpoints)) ] x = [float(sum(distances[:i])) for i in range(len(distances) + 1)] @@ -499,8 +504,8 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, if labels[0][0] != 0: labels.insert(0, (0, '')) # I add an empty label that points to the last band if the last label does not do it - if labels[-1][0] != len(bands)-1 : - labels.append((len(bands)-1, '')) + if labels[-1][0] != len(bands) - 1: + labels.append((len(bands) - 1, '')) for (position_from, label_from), (position_to, label_to) in zip(labels[:-1], labels[1:]): if position_to - position_from > 1: # Create a new path line only if there are at least two points, @@ -547,6 +552,7 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. """ + # pylint: disable=too-many-locals import os dat_filename = os.path.splitext(main_file_name)[0] + '_data.dat' @@ -561,7 +567,6 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N x = plot_info['x'] labels = plot_info['labels'] - num_labels = len(labels) num_bands = bands.shape[1] # axis limits @@ -573,14 +578,6 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N # first prepare the xy coordinates of the sets raw_data, _ = self._prepare_dat_blocks(plot_info) - ## Manually add the xy coordinates of the vertical lines - not needed! Use gridlines - #new_block = [] - #for l in labels: - # new_block.append("{}\t{}".format(l[0], y_min_lim)) - # new_block.append("{}\t{}".format(l[0], y_max_lim)) - # new_block.append("") - #raw_data += "\n".join(new_block) - batch = [] if comments: batch.append(prepare_header_comment(self.uuid, plot_info, comment_char='#')) @@ -598,9 +595,9 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N batch.append('xaxis tick spec type both') batch.append('xaxis tick spec {}'.format(len(labels))) # set the name of the special points - for i, l in enumerate(labels): - batch.append('xaxis tick major {}, {}'.format(i, l[0])) - batch.append('xaxis ticklabel {}, "{}"'.format(i, l[1])) + for index, label in enumerate(labels): + batch.append('xaxis tick major {}, {}'.format(index, label[0])) + batch.append('xaxis ticklabel {}, "{}"'.format(index, label[1])) batch.append('xaxis tick major color 7') batch.append('xaxis tick major grid on') @@ -614,20 +611,16 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N batch.append('xaxis label font 4') # set color and linewidths of bands - for i in range(num_bands): - batch.append('s{} line color 1'.format(i)) - batch.append('s{} linewidth 1'.format(i)) - - ## set color and linewidths of label lines - not needed! use gridlines - #for i in range(num_bands, num_bands + num_labels): - # batch.append("s{} hidden true".format(i)) + for index in range(num_bands): + batch.append('s{} line color 1'.format(index)) + batch.append('s{} linewidth 1'.format(index)) batch_data = '\n'.join(batch) + '\n' extra_files = {dat_filename: raw_data} return batch_data.encode('utf-8'), extra_files - def _prepare_dat_multicolumn(self, main_file_name='', comments=True): + def _prepare_dat_multicolumn(self, main_file_name='', comments=True): # pylint: disable=unused-argument """ Write an N x M matrix. First column is the distance between kpoints, The other columns are the bands. 
Header contains number of kpoints and @@ -651,7 +644,7 @@ def _prepare_dat_multicolumn(self, main_file_name='', comments=True): return ('\n'.join(return_text) + '\n').encode('utf-8'), {} - def _prepare_dat_blocks(self, main_file_name='', comments=True): + def _prepare_dat_blocks(self, main_file_name='', comments=True): # pylint: disable=unused-argument """ Format suitable for gnuplot using blocks. Columns with x and y (path and band energy). Several blocks, separated @@ -669,10 +662,8 @@ def _prepare_dat_blocks(self, main_file_name='', comments=True): if comments: return_text.append(prepare_header_comment(self.uuid, plot_info, comment_char='#')) - the_bands = numpy.transpose(bands) - - for b in the_bands: - for i in zip(x, b): + for band in numpy.transpose(bands): + for i in zip(x, band): line = ['{:.8f}'.format(i[0]), '{:.8f}'.format(i[1])] return_text.append('\t'.join(line)) return_text.append('') @@ -680,17 +671,19 @@ def _prepare_dat_blocks(self, main_file_name='', comments=True): return '\n'.join(return_text).encode('utf-8'), {} - def _matplotlib_get_dict(self, - main_file_name='', - comments=True, - title='', - legend=None, - legend2=None, - y_max_lim=None, - y_min_lim=None, - y_origin=0., - prettify_format=None, - **kwargs): + def _matplotlib_get_dict( + self, + main_file_name='', + comments=True, + title='', + legend=None, + legend2=None, + y_max_lim=None, + y_min_lim=None, + y_origin=0., + prettify_format=None, + **kwargs + ): # pylint: disable=unused-argument """ Prepare the data to send to the python-matplotlib plotting script. @@ -700,7 +693,7 @@ def _matplotlib_get_dict(self, :param setnumber_offset: an offset to be applied to all set numbers (i.e. s0 is replaced by s[offset], s1 by s[offset+1], etc.) :param color_number: the color number for lines, symbols, error bars - and filling (should be less than the parameter max_num_agr_colors + and filling (should be less than the parameter MAX_NUM_AGR_COLORS defined below) :param title: the title :param legend: the legend (applied only to the first of the set) @@ -717,7 +710,7 @@ def _matplotlib_get_dict(self, :param kwargs: additional customization variables; only a subset is accepted, see internal variable 'valid_additional_keywords """ - #import math + # pylint: disable=too-many-arguments,too-many-locals # Only these keywords are accepted in kwargs, and then set into the json valid_additional_keywords = [ @@ -763,7 +756,8 @@ def _matplotlib_get_dict(self, prettify_format=prettify_format, join_symbol=join_symbol, get_segments=True, - y_origin=y_origin) + y_origin=y_origin + ) all_data = {} @@ -777,10 +771,8 @@ def _matplotlib_get_dict(self, tick_pos = [] tick_labels = [] - #all_data['bands'] = the_bands.tolist() all_data['paths'] = plot_info['paths'] all_data['band_type_idx'] = plot_info['band_type_idx'].tolist() - #all_data['x'] = x all_data['tick_pos'] = tick_pos all_data['tick_labels'] = tick_labels @@ -798,17 +790,15 @@ def _matplotlib_get_dict(self, y_min_lim = numpy.array(bands).min() x_min_lim = min(x) # this isn't a numpy array, but a list x_max_lim = max(x) - #ytick_spacing = 10 ** int(math.log10((y_max_lim - y_min_lim))) all_data['x_min_lim'] = x_min_lim all_data['x_max_lim'] = x_max_lim all_data['y_min_lim'] = y_min_lim all_data['y_max_lim'] = y_max_lim - #all_data['ytick_spacing'] = ytick_spacing - for k, v in kwargs.items(): - if k not in valid_additional_keywords: - raise TypeError("_matplotlib_get_dict() got an unexpected keyword argument '{}'".format(k)) - all_data[k] = v + for key, value in kwargs.items(): + if 
key not in valid_additional_keywords: + raise TypeError("_matplotlib_get_dict() got an unexpected keyword argument '{}'".format(key)) + all_data[key] = value return all_data @@ -823,16 +813,16 @@ def _prepare_mpl_singlefile(self, *args, **kwargs): all_data = self._matplotlib_get_dict(*args, **kwargs) - s_header = matplotlib_header_template.substitute() - s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2)) + s_header = MATPLOTLIB_HEADER_TEMPLATE.substitute() + s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) s_body = self._get_mpl_body_template(all_data['paths']) - s_footer = matplotlib_footer_template_show.substitute() + s_footer = MATPLOTLIB_FOOTER_TEMPLATE_SHOW.substitute() - s = s_header + s_import + s_body + s_footer + string = s_header + s_import + s_body + s_footer - return s.encode('utf-8'), {} + return string.encode('utf-8'), {} - def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): + def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg """ Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. @@ -852,16 +842,16 @@ def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): ext_files = {json_fname: json.dumps(all_data, indent=2).encode('utf-8')} - s_header = matplotlib_header_template.substitute() - s_import = matplotlib_import_data_fromfile_template.substitute(json_fname=json_fname) + s_header = MATPLOTLIB_HEADER_TEMPLATE.substitute() + s_import = MATPLOTLIB_IMPORT_DATA_FROMFILE_TEMPLATE.substitute(json_fname=json_fname) s_body = self._get_mpl_body_template(all_data['paths']) - s_footer = matplotlib_footer_template_show.substitute() + s_footer = MATPLOTLIB_FOOTER_TEMPLATE_SHOW.substitute() - s = s_header + s_import + s_body + s_footer + string = s_header + s_import + s_body + s_footer - return s.encode('utf-8'), ext_files + return string.encode('utf-8'), ext_files - def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): + def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg,unused-argument """ Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. @@ -879,8 +869,8 @@ def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): all_data = self._matplotlib_get_dict(*args, **kwargs) # Use the Agg backend - s_header = matplotlib_header_agg_template.substitute() - s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2)) + s_header = MATPLOTLIB_HEADER_AGG_TEMPLATE.substitute() + s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) s_body = self._get_mpl_body_template(all_data['paths']) # I get a temporary file name @@ -890,30 +880,28 @@ def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): escaped_fname = filename.replace('"', '\"') - s_footer = matplotlib_footer_template_exportfile.substitute(fname=escaped_fname, format='pdf') + s_footer = MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE.substitute(fname=escaped_fname, format='pdf') - s = s_header + s_import + s_body + s_footer + string = s_header + s_import + s_body + s_footer # I don't exec it because I might mess up with the matplotlib backend etc. 
# I run instead in a different process, with the same executable # (so it should work properly with virtualenvs) - #exec s - with tempfile.NamedTemporaryFile(mode='w+') as f: - f.write(s) - f.flush() - - subprocess.check_output([sys.executable, f.name]) + with tempfile.NamedTemporaryFile(mode='w+') as handle: + handle.write(string) + handle.flush() + subprocess.check_output([sys.executable, handle.name]) if not os.path.exists(filename): raise RuntimeError('Unable to generate the PDF...') - with open(filename, 'rb', encoding=None) as f: - imgdata = f.read() + with open(filename, 'rb', encoding=None) as handle: + imgdata = handle.read() os.remove(filename) return imgdata, {} - def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): + def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg,unused-argument """ Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. @@ -930,8 +918,8 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): all_data = self._matplotlib_get_dict(*args, **kwargs) # Use the Agg backend - s_header = matplotlib_header_agg_template.substitute() - s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2)) + s_header = MATPLOTLIB_HEADER_AGG_TEMPLATE.substitute() + s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) s_body = self._get_mpl_body_template(all_data['paths']) # I get a temporary file name @@ -941,24 +929,23 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): escaped_fname = filename.replace('"', '\"') - s_footer = matplotlib_footer_template_exportfile_with_dpi.substitute(fname=escaped_fname, format='png', dpi=300) + s_footer = MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI.substitute(fname=escaped_fname, format='png', dpi=300) - s = s_header + s_import + s_body + s_footer + string = s_header + s_import + s_body + s_footer # I don't exec it because I might mess up with the matplotlib backend etc. 
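For orientation (not part of the diff), a sketch of how these `_prepare_*` exporters are typically reached through the data export API; the pk and file names are illustrative, and the PDF variant assumes matplotlib is installed since it is rendered in a subprocess as shown above:

from aiida.orm import load_node

bands = load_node(1234)  # assumed pk of an existing, stored BandsData node

# `_exportcontent` returns the main file content plus any auxiliary files.
content, extra_files = bands._exportcontent(fileformat='mpl_pdf')
with open('bands.pdf', 'wb') as handle:
    handle.write(content)

# The public `export` wrapper writes the main file and its auxiliary files
# (e.g. the gnuplot data file) in one call.
bands.export('bands.gnu', fileformat='gnuplot')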
# I run instead in a different process, with the same executable # (so it should work properly with virtualenvs) - with tempfile.NamedTemporaryFile(mode='w+') as f: - f.write(s) - f.flush() - - subprocess.check_output([sys.executable, f.name]) + with tempfile.NamedTemporaryFile(mode='w+') as handle: + handle.write(string) + handle.flush() + subprocess.check_output([sys.executable, handle.name]) if not os.path.exists(filename): raise RuntimeError('Unable to generate the PNG...') - with open(filename, 'rb', encoding=None) as f: - imgdata = f.read() + with open(filename, 'rb', encoding=None) as handle: + imgdata = handle.read() os.remove(filename) return imgdata, {} @@ -969,9 +956,9 @@ def _get_mpl_body_template(paths): :param paths: paths of k-points """ if len(paths) == 1: - s_body = matplotlib_body_template.substitute(plot_code=single_kp) + s_body = MATPLOTLIB_BODY_TEMPLATE.substitute(plot_code=SINGLE_KP) else: - s_body = matplotlib_body_template.substitute(plot_code=multi_kp) + s_body = MATPLOTLIB_BODY_TEMPLATE.substitute(plot_code=MULTI_KP) return s_body def show_mpl(self, **kwargs): @@ -984,14 +971,16 @@ def show_mpl(self, **kwargs): """ exec(*self._exportcontent(fileformat='mpl_singlefile', main_file_name='', **kwargs)) # pylint: disable=exec-used - def _prepare_gnuplot(self, - main_file_name=None, - title='', - comments=True, - prettify_format=None, - y_max_lim=None, - y_min_lim=None, - y_origin=0.): + def _prepare_gnuplot( + self, + main_file_name=None, + title='', + comments=True, + prettify_format=None, + y_max_lim=None, + y_min_lim=None, + y_origin=0. + ): """ Prepare an gnuplot script to plot the bands, with the .dat file returned as an independent file. @@ -1006,6 +995,7 @@ def _prepare_gnuplot(self, :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. """ + # pylint: disable=too-many-arguments,too-many-locals import os main_file_name = main_file_name or 'band.dat' @@ -1016,14 +1006,11 @@ def _prepare_gnuplot(self, prettify_format = 'gnuplot_seekpath' plot_info = self._get_bandplot_data( - cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin) + cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin + ) bands = plot_info['y'] x = plot_info['x'] - labels = plot_info['labels'] - - num_labels = len(labels) - num_bands = bands.shape[1] # axis limits if y_max_lim is None: @@ -1045,7 +1032,8 @@ def _prepare_gnuplot(self, script.append(prepare_header_comment(self.uuid, plot_info=plot_info, comment_char='# ')) script.append('') - script.append(u"""## Uncomment the next two lines to write directly to PDF + script.append( + """## Uncomment the next two lines to write directly to PDF ## Note: You need to have gnuplot installed with pdfcairo support! #set term pdfcairo #set output 'out.pdf' @@ -1060,18 +1048,15 @@ def _prepare_gnuplot(self, #set termopt font "CMU Sans Serif, 12" ## Classical Times New Roman #set termopt font "Times New Roman, 12" -""") +""" + ) # Actual logic script.append('set termopt enhanced') # Properly deals with e.g. 
subscripts script.append('set encoding utf8') # To deal with Greek letters script.append('set xtics ({})'.format(xtics_string)) - script.append('unset key') - - script.append('set yrange [{}:{}]'.format(y_min_lim, y_max_lim)) - script.append('set ylabel "{}"'.format('Dispersion ({})'.format(self.units))) if title: @@ -1084,25 +1069,31 @@ def _prepare_gnuplot(self, script.append('plot "{}" with l lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"'))) else: script.append('set xrange [-1.0:1.0]') - script.append('plot "{}" using ($1-0.25):($2):(0.5):(0) with vectors nohead lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"'))) + script.append( + 'plot "{}" using ($1-0.25):($2):(0.5):(0) with vectors nohead lc rgb "#000000"'.format( + os.path.basename(dat_filename).replace('"', '\"') + ) + ) script_data = '\n'.join(script) + '\n' extra_files = {dat_filename: raw_data} return script_data.encode('utf-8'), extra_files - def _prepare_agr(self, - main_file_name='', - comments=True, - setnumber_offset=0, - color_number=1, - color_number2=2, - legend='', - title='', - y_max_lim=None, - y_min_lim=None, - y_origin=0., - prettify_format=None): + def _prepare_agr( + self, + main_file_name='', + comments=True, + setnumber_offset=0, + color_number=1, + color_number2=2, + legend='', + title='', + y_max_lim=None, + y_min_lim=None, + y_origin=0., + prettify_format=None + ): """ Prepare an xmgrace agr file. @@ -1112,11 +1103,11 @@ def _prepare_agr(self, :param setnumber_offset: an offset to be applied to all set numbers (i.e. s0 is replaced by s[offset], s1 by s[offset+1], etc.) :param color_number: the color number for lines, symbols, error bars - and filling (should be less than the parameter max_num_agr_colors + and filling (should be less than the parameter MAX_NUM_AGR_COLORS defined below) :param color_number2: the color number for lines, symbols, error bars and filling for the second-type spins (should be less than the - parameter max_num_agr_colors defined below) + parameter MAX_NUM_AGR_COLORS defined below) :param legend: the legend (applied only to the first set) :param title: the title :param y_max_lim: the maximum on the y axis (if None, put the @@ -1130,19 +1121,21 @@ def _prepare_agr(self, :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. """ + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unused-argument if prettify_format is None: # Default. 
Specified like this to allow caller functions to pass 'None' prettify_format = 'agr_seekpath' plot_info = self._get_bandplot_data( - cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin) + cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin + ) import math # load the x and y of every set - if color_number > max_num_agr_colors: - raise ValueError('Color number is too high (should be less than {})'.format(max_num_agr_colors)) - if color_number2 > max_num_agr_colors: - raise ValueError('Color number 2 is too high (should be less than {})'.format(max_num_agr_colors)) + if color_number > MAX_NUM_AGR_COLORS: + raise ValueError('Color number is too high (should be less than {})'.format(MAX_NUM_AGR_COLORS)) + if color_number2 > MAX_NUM_AGR_COLORS: + raise ValueError('Color number 2 is too high (should be less than {})'.format(MAX_NUM_AGR_COLORS)) bands = plot_info['y'] x = plot_info['x'] @@ -1161,22 +1154,22 @@ def _prepare_agr(self, # prepare xticks labels sx1 = '' - for i, l in enumerate(labels): - sx1 += agr_single_xtick_template.substitute( + for i, label in enumerate(labels): + sx1 += AGR_SINGLE_XTICK_TEMPLATE.substitute( index=i, - coord=l[0], - name=l[1], + coord=label[0], + name=label[1], ) - xticks = agr_xticks_template.substitute( + xticks = AGR_XTICKS_TEMPLATE.substitute( num_labels=num_labels, single_xtick_templates=sx1, ) # build the arrays with the xy coordinates all_sets = [] - for b in the_bands: + for band in the_bands: this_set = '' - for i in zip(x, b): + for i in zip(x, band): line = '{:.8f}'.format(i[0]) + '\t' + '{:.8f}'.format(i[1]) + '\n' this_set += line all_sets.append(this_set) @@ -1188,15 +1181,16 @@ def _prepare_agr(self, else: linecolor = color_number2 width = str(2.0) - set_descriptions += agr_set_description_template.substitute( + set_descriptions += AGR_SET_DESCRIPTION_TEMPLATE.substitute( set_number=i + setnumber_offset, linewidth=width, color_number=linecolor, - legend=legend if i == 0 else '') + legend=legend if i == 0 else '' + ) units = self.units - graphs = agr_graph_template.substitute( + graphs = AGR_GRAPH_TEMPLATE.substitute( x_min_lim=x_min_lim, y_min_lim=y_min_lim, x_max_lim=x_max_lim, @@ -1209,19 +1203,21 @@ def _prepare_agr(self, ) sets = [] for i, this_set in enumerate(all_sets): - sets.append(agr_singleset_template.substitute(set_number=i + setnumber_offset, xydata=this_set)) + sets.append(AGR_SINGLESET_TEMPLATE.substitute(set_number=i + setnumber_offset, xydata=this_set)) the_sets = '&\n'.join(sets) - s = agr_template.substitute(graphs=graphs, sets=the_sets) + string = AGR_TEMPLATE.substitute(graphs=graphs, sets=the_sets) if comments: - s = prepare_header_comment(self.uuid, plot_info, comment_char='#') + '\n' + s + string = prepare_header_comment(self.uuid, plot_info, comment_char='#') + '\n' + string - return s.encode('utf-8'), {} + return string.encode('utf-8'), {} def _get_band_segments(self, cartesian): + """Return the band segments.""" plot_info = self._get_bandplot_data( - cartesian=cartesian, prettify_format=None, join_symbol=None, get_segments=True) + cartesian=cartesian, prettify_format=None, join_symbol=None, get_segments=True + ) out_dict = {'label': self.label} @@ -1230,7 +1226,7 @@ def _get_band_segments(self, cartesian): return out_dict - def _prepare_json(self, main_file_name='', comments=True): + def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=unused-argument """ Prepare a json file in a format compatible with the AiiDA band visualizer @@ 
-1249,9 +1245,10 @@ def _prepare_json(self, main_file_name='', comments=True): return json.dumps(json_dict).encode('utf-8'), {} -max_num_agr_colors = 15 +MAX_NUM_AGR_COLORS = 15 -agr_template = Template(""" +AGR_TEMPLATE = Template( + """ # Grace project file # @version 50122 @@ -1376,19 +1373,23 @@ def _prepare_json(self, main_file_name='', comments=True): @r4 line 0, 0, 0, 0 $graphs $sets - """) + """ +) -agr_xticks_template = Template(""" +AGR_XTICKS_TEMPLATE = Template(""" @ xaxis tick spec $num_labels $single_xtick_templates """) -agr_single_xtick_template = Template(""" +AGR_SINGLE_XTICK_TEMPLATE = Template( + """ @ xaxis tick major $index, $coord @ xaxis ticklabel $index, "$name" - """) + """ +) -agr_graph_template = Template(""" +AGR_GRAPH_TEMPLATE = Template( + """ @g0 on @g0 hidden false @g0 type XY @@ -1545,9 +1546,11 @@ def _prepare_json(self, main_file_name='', comments=True): @ frame background color 0 @ frame background pattern 0 $set_descriptions - """) + """ +) -agr_set_description_template = Template(""" +AGR_SET_DESCRIPTION_TEMPLATE = Template( + """ @ s$set_number hidden false @ s$set_number type xy @ s$set_number symbol 0 @@ -1597,9 +1600,10 @@ def _prepare_json(self, main_file_name='', comments=True): @ s$set_number errorbar riser clip length 0.100000 @ s$set_number comment "Cols 1:2" @ s$set_number legend "$legend" - """) + """ +) -agr_singleset_template = Template(""" +AGR_SINGLESET_TEMPLATE = Template(""" @target G0.S$set_number @type xy $xydata @@ -1608,7 +1612,8 @@ def _prepare_json(self, main_file_name='', comments=True): # text.latex.preview=True is needed to have a proper alignment of # tick marks with and without subscripts # see e.g. http://matplotlib.org/1.3.0/examples/pylab_examples/usetex_baseline_test.html -matplotlib_header_agg_template = Template('''# -*- coding: utf-8 -*- +MATPLOTLIB_HEADER_AGG_TEMPLATE = Template( + """# -*- coding: utf-8 -*- import matplotlib matplotlib.use('Agg') @@ -1630,12 +1635,14 @@ def _prepare_json(self, main_file_name='', comments=True): import json print_comment = False -''') +""" +) # text.latex.preview=True is needed to have a proper alignment of # tick marks with and without subscripts # see e.g. 
http://matplotlib.org/1.3.0/examples/pylab_examples/usetex_baseline_test.html -matplotlib_header_template = Template('''# -*- coding: utf-8 -*- +MATPLOTLIB_HEADER_TEMPLATE = Template( + """# -*- coding: utf-8 -*- from matplotlib import rc # Uncomment to change default font @@ -1654,16 +1661,19 @@ def _prepare_json(self, main_file_name='', comments=True): import json print_comment = False -''') +""" +) -matplotlib_import_data_inline_template = Template('''all_data_str = r"""$all_data_json""" +MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE = Template('''all_data_str = r"""$all_data_json""" ''') -matplotlib_import_data_fromfile_template = Template('''with open("$json_fname", encoding='utf8') as f: +MATPLOTLIB_IMPORT_DATA_FROMFILE_TEMPLATE = Template( + """with open("$json_fname", encoding='utf8') as f: all_data_str = f.read() -''') +""" +) -multi_kp = ''' +MULTI_KP = """ for path in paths: if path['length'] <= 1: # Avoid printing empty lines @@ -1690,16 +1700,17 @@ def _prepare_json(self, main_file_name='', comments=True): p.plot(x, band, label=label, **further_plot_options ) -''' +""" -single_kp = ''' +SINGLE_KP = """ path = paths[0] values = path['values'] x = [path['x'] for _ in values] p.scatter(x, values, marker="_") -''' +""" -matplotlib_body_template = Template('''all_data = json.loads(all_data_str) +MATPLOTLIB_BODY_TEMPLATE = Template( + """all_data = json.loads(all_data_str) if not all_data.get('use_latex', False): rc('text', usetex=False) @@ -1779,13 +1790,11 @@ def _prepare_json(self, main_file_name='', comments=True): print(all_data['comment']) except KeyError: pass -''') +""" +) -matplotlib_footer_template_show = Template('''pl.show() -''') +MATPLOTLIB_FOOTER_TEMPLATE_SHOW = Template("""pl.show()""") -matplotlib_footer_template_exportfile = Template('''pl.savefig("$fname", format="$format") -''') +MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE = Template("""pl.savefig("$fname", format="$format")""") -matplotlib_footer_template_exportfile_with_dpi = Template('''pl.savefig("$fname", format="$format", dpi=$dpi) -''') +MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI = Template("""pl.savefig("$fname", format="$format", dpi=$dpi)""") diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py index e02eae0929..68e91c7b40 100644 --- a/aiida/orm/nodes/data/array/kpoints.py +++ b/aiida/orm/nodes/data/array/kpoints.py @@ -12,7 +12,6 @@ lists and meshes of k-points (i.e., points in the reciprocal space of a periodic crystal structure). """ - import numpy from .array import ArrayData diff --git a/aiida/orm/nodes/data/array/projection.py b/aiida/orm/nodes/data/array/projection.py index 30c3a2aed9..9e81fc8c5d 100644 --- a/aiida/orm/nodes/data/array/projection.py +++ b/aiida/orm/nodes/data/array/projection.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Data plugin to represet arrays of projected wavefunction components.""" import copy import numpy as np @@ -27,6 +27,7 @@ class ProjectionData(OrbitalData, ArrayData): s, n, and k. E.g. 
the elements are the projections described as < orbital | Bloch wavefunction (s,n,k) > """ + def _check_projections_bands(self, projection_array): """ Checks to make sure that a reference bandsdata is already set, and that @@ -46,9 +47,7 @@ def _check_projections_bands(self, projection_array): # The [0:2] is so that each array, and not collection of arrays # is used to make the comparison if np.shape(projection_array) != shape_bands: - raise AttributeError('These arrays are not the same shape as' - ' the bands') - return None + raise AttributeError('These arrays are not the same shape as' ' the bands') def set_reference_bandsdata(self, value): """ @@ -65,17 +64,19 @@ def set_reference_bandsdata(self, value): else: try: pk = int(value) - bands = load_node(pk=pk, type=BandsData) + bands = load_node(pk=pk) uuid = bands.uuid except ValueError: uuid = str(value) try: - bands = load_node(uuid=uuid, type=BandsData) + bands = load_node(uuid=uuid) uuid = bands.uuid - except : - raise exceptions.NotExistent('The value passed to ' - 'set_reference_bandsdata was not ' - 'associated to any bandsdata') + except Exception: # pylint: disable=bare-except + raise exceptions.NotExistent( + 'The value passed to ' + 'set_reference_bandsdata was not ' + 'associated to any bandsdata' + ) self.set_attribute('reference_bandsdata_uuid', uuid) @@ -94,12 +95,9 @@ def get_reference_bandsdata(self): except AttributeError: raise AttributeError('BandsData has not been set for this instance') try: - #bands = load_node(uuid=uuid, type=BandsData) - bands = load_node(uuid=uuid) #TODO switch to above once type - # has been implemented for load_node + bands = load_node(uuid=uuid) except exceptions.NotExistent: - raise exceptions.NotExistent('The bands referenced to this class have not been ' - 'found in this database.') + raise exceptions.NotExistent('The bands referenced to this class have not been found in this database.') return bands def _find_orbitals_and_indices(self, **kwargs): @@ -112,16 +110,12 @@ def _find_orbitals_and_indices(self, **kwargs): to the kwargs :return: all_orbitals, list of orbitals to which the indexes correspond """ - # index_and_orbitals = self._get_orbitals_and_index() - index_and_orbitals = [] selected_orbitals = self.get_orbitals(**kwargs) - selected_orb_dicts = [orb.get_orbital_dict() for orb - in selected_orbitals] + selected_orb_dicts = [orb.get_orbital_dict() for orb in selected_orbitals] all_orbitals = self.get_orbitals() all_orb_dicts = [orb.get_orbital_dict() for orb in all_orbitals] - retrieve_indices = [i for i in range(len(all_orb_dicts)) - if all_orb_dicts[i] in selected_orb_dicts] - return retrieve_indices, all_orbitals + retrieve_indices = [i for i in range(len(all_orb_dicts)) if all_orb_dicts[i] in selected_orb_dicts] + return retrieve_indices, all_orbitals def get_pdos(self, **kwargs): """ @@ -134,12 +128,10 @@ def get_pdos(self, **kwargs): """ retrieve_indices, all_orbitals = self._find_orbitals_and_indices(**kwargs) - out_list = [(all_orbitals[i], - self.get_array('pdos_{}'.format( - self._from_index_to_arrayname(i))), - self.get_array('energy_{}'.format( - self._from_index_to_arrayname(i))) ) - for i in retrieve_indices] + out_list = [( + all_orbitals[i], self.get_array('pdos_{}'.format(self._from_index_to_arrayname(i))), + self.get_array('energy_{}'.format(self._from_index_to_arrayname(i))) + ) for i in retrieve_indices] return out_list def get_projections(self, **kwargs): @@ -153,21 +145,26 @@ def get_projections(self, **kwargs): """ retrieve_indices, all_orbitals = 
self._find_orbitals_and_indices(**kwargs) - out_list = [(all_orbitals[i], - self.get_array('proj_{}'.format( - self._from_index_to_arrayname(i)))) + out_list = [(all_orbitals[i], self.get_array('proj_{}'.format(self._from_index_to_arrayname(i)))) for i in retrieve_indices] return out_list - def _from_index_to_arrayname(self, index): + @staticmethod + def _from_index_to_arrayname(index): """ Used internally to determine the array names. """ return 'array_{}'.format(index) - def set_projectiondata(self,list_of_orbitals, list_of_projections=None, - list_of_energy=None, list_of_pdos=None, - tags = None, bands_check=True): + def set_projectiondata( + self, + list_of_orbitals, + list_of_projections=None, + list_of_energy=None, + list_of_pdos=None, + tags=None, + bands_check=True + ): """ Stores the projwfc_array using the projwfc_label, after validating both. @@ -196,6 +193,9 @@ def set_projectiondata(self,list_of_orbitals, list_of_projections=None, been stored and therefore get_reference_bandsdata cannot be called """ + + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + def single_to_list(item): """ Checks if the item is a list or tuple, and converts it to a list @@ -207,8 +207,8 @@ def single_to_list(item): """ if isinstance(item, (list, tuple)): return item - else: - return [item] + + return [item] def array_list_checker(array_list, array_name, orb_length): """ @@ -217,12 +217,13 @@ def array_list_checker(array_list, array_name, orb_length): required_length, raises exception using array_name if there is a failure """ - if not all([isinstance(_,np.ndarray) for _ in array_list]): - raise exceptions.ValidationError('{} was not composed ' - 'entirely of ndarrays'.format(array_name)) + if not all([isinstance(_, np.ndarray) for _ in array_list]): + raise exceptions.ValidationError('{} was not composed entirely of ndarrays'.format(array_name)) if len(array_list) != orb_length: - raise exceptions.ValidationError('{} did not have the same length as the ' - 'list of orbitals'.format(array_name)) + raise exceptions.ValidationError( + '{} did not have the same length as the ' + 'list of orbitals'.format(array_name) + ) ############## list_of_orbitals = single_to_list(list_of_orbitals) @@ -232,22 +233,21 @@ def array_list_checker(array_list, array_name, orb_length): if not list_of_pdos and not list_of_projections: raise exceptions.ValidationError('Must set either pdos or projections') if bool(list_of_energy) != bool(list_of_pdos): - raise exceptions.ValidationError('list_of_pdos and list_of_energy must always ' - 'be set together') + raise exceptions.ValidationError('list_of_pdos and list_of_energy must always be set together') orb_length = len(list_of_orbitals) # verifies and sets the orbital dicts list_of_orbital_dicts = [] - for i in range(len(list_of_orbitals)): + for i, _ in enumerate(list_of_orbitals): this_orbital = list_of_orbitals[i] orbital_dict = this_orbital.get_orbital_dict() try: orbital_type = orbital_dict.pop('_orbital_type') except KeyError: - raise ValidationError('No _orbital_type key found in dictionary: {}'.format(orbital_dict)) - OrbitalClass = OrbitalFactory(orbital_type) - test_orbital = OrbitalClass(**orbital_dict) + raise exceptions.ValidationError('No _orbital_type key found in dictionary: {}'.format(orbital_dict)) + cls = OrbitalFactory(orbital_type) + test_orbital = cls(**orbital_dict) list_of_orbital_dicts.append(test_orbital.get_orbital_dict()) self.set_attribute('orbital_dicts', list_of_orbital_dicts) @@ -255,7 +255,7 @@ def 
array_list_checker(array_list, array_name, orb_length): if list_of_projections: list_of_projections = single_to_list(list_of_projections) array_list_checker(list_of_projections, 'projections', orb_length) - for i in range(len(list_of_projections)): + for i, _ in enumerate(list_of_projections): this_projection = list_of_projections[i] array_name = self._from_index_to_arrayname(i) if bands_check: @@ -268,7 +268,7 @@ def array_list_checker(array_list, array_name, orb_length): list_of_energy = single_to_list(list_of_energy) array_list_checker(list_of_pdos, 'pdos', orb_length) array_list_checker(list_of_energy, 'energy', orb_length) - for i in range(len(list_of_pdos)): + for i, _ in enumerate(list_of_pdos): this_pdos = list_of_pdos[i] this_energy = list_of_energy[i] array_name = self._from_index_to_arrayname(i) @@ -285,15 +285,17 @@ def array_list_checker(array_list, array_name, orb_length): except IndexError: return exceptions.ValidationError('tags must be a list') - if not all([isinstance(_,str) for _ in tags]): + if not all([isinstance(_, str) for _ in tags]): raise exceptions.ValidationError('Tags must set a list of strings') self.set_attribute('tags', tags) - def set_orbitals(self, **kwargs): + def set_orbitals(self, **kwargs): # pylint: disable=arguments-differ """ This method is inherited from OrbitalData, but is blocked here. If used will raise a NotImplementedError """ - raise NotImplementedError('You cannot set orbitals using this class!' - ' This class is for setting orbitals and ' - ' projections only!') + raise NotImplementedError( + 'You cannot set orbitals using this class!' + ' This class is for setting orbitals and ' + ' projections only!' + ) diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py index a3d2674320..ecc0b3ee8f 100644 --- a/aiida/orm/nodes/data/array/xy.py +++ b/aiida/orm/nodes/data/array/xy.py @@ -12,8 +12,6 @@ collections of y-arrays bound to a single x-array, and the methods to operate on them. """ - - import numpy as np from aiida.common.exceptions import InputValidationError, NotExistent from .array import ArrayData @@ -30,8 +28,8 @@ def check_convert_single_to_tuple(item): """ if isinstance(item, (list, tuple)): return item - else: - return [item] + + return [item] class XyData(ArrayData): @@ -40,7 +38,9 @@ class XyData(ArrayData): each other. That is there is one array, the X array, and there are several Y arrays, which can be considered functions of X. """ - def _arrayandname_validator(self, array, name, units): + + @staticmethod + def _arrayandname_validator(array, name, units): """ Validates that the array is an numpy.ndarray and that the name is of type str. Raises InputValidationError if this not the case. @@ -86,8 +86,7 @@ def set_y(self, y_arrays, y_names, y_units): # checks that the input lengths match if len(y_arrays) != len(y_names): - raise InputValidationError('Length of arrays and names do not ' - 'match!') + raise InputValidationError('Length of arrays and names do not match!') if len(y_units) != len(y_names): raise InputValidationError('Length of units does not match!') @@ -100,9 +99,11 @@ def set_y(self, y_arrays, y_names, y_units): for num, (y_array, y_name, y_unit) in enumerate(zip(y_arrays, y_names, y_units)): self._arrayandname_validator(y_array, y_name, y_unit) if np.shape(y_array) != np.shape(x_array): - raise InputValidationError('y_array {} did not have the ' - 'same shape has the x_array!' - ''.format(y_name)) + raise InputValidationError( + 'y_array {} did not have the ' + 'same shape has the x_array!' 
+ ''.format(y_name) + ) self.set_array('y_array_{}'.format(num), y_array) # if the y_arrays pass the initial validation, sets each @@ -147,6 +148,5 @@ def get_y(self): for i in range(len(y_names)): y_arrays += [self.get_array('y_array_{}'.format(i))] except (KeyError, AttributeError): - raise NotExistent('Could not retrieve array associated with y array' - ' {}'.format(y_names[i])) + raise NotExistent('Could not retrieve array associated with y array {}'.format(y_names[i])) return list(zip(y_names, y_arrays, y_units)) diff --git a/aiida/orm/nodes/data/code.py b/aiida/orm/nodes/data/code.py index 4766e15c91..d39ec9602f 100644 --- a/aiida/orm/nodes/data/code.py +++ b/aiida/orm/nodes/data/code.py @@ -7,9 +7,12 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Data plugin represeting an executable code to be wrapped and called through a `CalcJob` plugin.""" import os +import warnings -from aiida.common.exceptions import ValidationError, EntryPointError, InputValidationError +from aiida.common import exceptions +from aiida.common.warnings import AiidaDeprecationWarning from .data import Data __all__ = ('Code',) @@ -32,6 +35,8 @@ class Code(Data): for the code to be run). """ + # pylint: disable=too-many-public-methods + def __init__(self, remote_computer_exec=None, local_executable=None, input_plugin_name=None, files=None, **kwargs): super().__init__(**kwargs) @@ -92,21 +97,20 @@ def set_files(self, files): def __str__(self): local_str = 'Local' if self.is_local() else 'Remote' - computer_str = self.get_computer_name() + computer_str = self.computer.label return "{} code '{}' on {}, pk: {}, uuid: {}".format(local_str, self.label, computer_str, self.pk, self.uuid) def get_computer_name(self): - """Get name of this code's computer.""" + """Get label of this code's computer. - if self.is_local(): - computer_str = 'repository' - else: - if self.computer is not None: - computer_str = self.computer.name - else: - computer_str = '[unknown]' + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `self.get_computer_label()` method instead. + """ + return self.get_computer_label() - return computer_str + def get_computer_label(self): + """Get label of this code's computer.""" + return 'repository' if self.is_local() else self.computer.label @property def full_label(self): @@ -114,7 +118,7 @@ def full_label(self): Returns label of the form @. """ - return '{}@{}'.format(self.label, self.get_computer_name()) + return '{}@{}'.format(self.label, self.get_computer_label()) @property def label(self): @@ -132,9 +136,9 @@ def label(self, value): """ if '@' in str(value): msg = "Code labels must not contain the '@' symbol" - raise InputValidationError(msg) + raise exceptions.InputValidationError(msg) - super(Code, self.__class__).label.fset(self, value) + super(Code, self.__class__).label.fset(self, value) # pylint: disable=no-member def relabel(self, new_label, raise_error=True): """Relabel this code. @@ -146,7 +150,8 @@ def relabel(self, new_label, raise_error=True): .. deprecated:: 1.2.0 Will remove raise_error in `v2.0.0`. Use `try/except` instead. 
""" - suffix = '@{}'.format(self.get_computer_name()) + # pylint: disable=unused-argument + suffix = '@{}'.format(self.computer.label) if new_label.endswith(suffix): new_label = new_label[:-len(suffix)] @@ -173,21 +178,21 @@ def get_code_helper(cls, label, machinename=None): from aiida.orm.querybuilder import QueryBuilder from aiida.orm.computers import Computer - qb = QueryBuilder() - qb.append(cls, filters={'label': {'==': label}}, project=['*'], tag='code') + query = QueryBuilder() + query.append(cls, filters={'label': label}, project='*', tag='code') if machinename: - qb.append(Computer, filters={'name': {'==': machinename}}, with_node='code') + query.append(Computer, filters={'name': machinename}, with_node='code') - if qb.count() == 0: + if query.count() == 0: raise NotExistent("'{}' is not a valid code name.".format(label)) - elif qb.count() > 1: - codes = qb.all(flat=True) + elif query.count() > 1: + codes = query.all(flat=True) retstr = ("There are multiple codes with label '{}', having IDs: ".format(label)) retstr += ', '.join(sorted([str(c.pk) for c in codes])) + '.\n' retstr += ('Relabel them (using their ID), or refer to them with their ID.') raise MultipleObjectsError(retstr) else: - return qb.first()[0] + return query.first()[0] @classmethod def get(cls, pk=None, label=None, machinename=None): @@ -200,11 +205,10 @@ def get(cls, pk=None, label=None, machinename=None): :param machinename: the machine name where code is setup :raise aiida.common.NotExistent: if no code identified by the given string is found - :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely - a code + :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely a code :raise aiida.common.InputValidationError: if neither a pk nor a label was passed in """ - from aiida.common.exceptions import (NotExistent, MultipleObjectsError, InputValidationError) + # pylint: disable=arguments-differ from aiida.orm.utils import load_code # first check if code pk is provided @@ -212,17 +216,17 @@ def get(cls, pk=None, label=None, machinename=None): code_int = int(pk) try: return load_code(pk=code_int) - except NotExistent: + except exceptions.NotExistent: raise ValueError('{} is not valid code pk'.format(pk)) - except MultipleObjectsError: - raise MultipleObjectsError("More than one code in the DB with pk='{}'!".format(pk)) + except exceptions.MultipleObjectsError: + raise exceptions.MultipleObjectsError("More than one code in the DB with pk='{}'!".format(pk)) # check if label (and machinename) is provided elif label is not None: return cls.get_code_helper(label, machinename) else: - raise InputValidationError('Pass either pk or code label (and machinename)') + raise exceptions.InputValidationError('Pass either pk or code label (and machinename)') @classmethod def get_from_string(cls, code_string): @@ -247,8 +251,8 @@ def get_from_string(cls, code_string): from aiida.common.exceptions import NotExistent, MultipleObjectsError, InputValidationError try: - label, sep, machinename = code_string.partition('@') - except AttributeError as exception: + label, _, machinename = code_string.partition('@') + except AttributeError: raise InputValidationError('the provided code_string is not of valid string type') try: @@ -270,35 +274,39 @@ def list_for_plugin(cls, plugin, labels=True): otherwise a list of integers with the code PKs. 
""" from aiida.orm.querybuilder import QueryBuilder - qb = QueryBuilder() - qb.append(cls, filters={'attributes.input_plugin': {'==': plugin}}) - valid_codes = qb.all(flat=True) + query = QueryBuilder() + query.append(cls, filters={'attributes.input_plugin': {'==': plugin}}) + valid_codes = query.all(flat=True) if labels: return [c.label for c in valid_codes] - else: - return [c.pk for c in valid_codes] + + return [c.pk for c in valid_codes] def _validate(self): super()._validate() if self.is_local() is None: - raise ValidationError('You did not set whether the code is local or remote') + raise exceptions.ValidationError('You did not set whether the code is local or remote') if self.is_local(): if not self.get_local_executable(): - raise ValidationError('You have to set which file is the local executable ' - 'using the set_exec_filename() method') + raise exceptions.ValidationError( + 'You have to set which file is the local executable ' + 'using the set_exec_filename() method' + ) if self.get_local_executable() not in self.list_object_names(): - raise ValidationError("The local executable '{}' is not in the list of " - 'files of this code'.format(self.get_local_executable())) + raise exceptions.ValidationError( + "The local executable '{}' is not in the list of " + 'files of this code'.format(self.get_local_executable()) + ) else: if self.list_object_names(): - raise ValidationError('The code is remote but it has files inside') + raise exceptions.ValidationError('The code is remote but it has files inside') if not self.get_remote_computer(): - raise ValidationError('You did not specify a remote computer') + raise exceptions.ValidationError('You did not specify a remote computer') if not self.get_remote_exec_path(): - raise ValidationError('You did not specify a remote executable') + raise exceptions.ValidationError('You did not specify a remote executable') def set_prepend_text(self, code): """ @@ -367,9 +375,11 @@ def set_remote_computer_exec(self, remote_computer_exec): from aiida.common.lang import type_check if (not isinstance(remote_computer_exec, (list, tuple)) or len(remote_computer_exec) != 2): - raise ValueError('remote_computer_exec must be a list or tuple ' - 'of length 2, with machine and executable ' - 'name') + raise ValueError( + 'remote_computer_exec must be a list or tuple ' + 'of length 2, with machine and executable ' + 'name' + ) computer, remote_exec_path = tuple(remote_computer_exec) @@ -455,8 +465,8 @@ def get_execname(self): """ if self.is_local(): return './{}'.format(self.get_local_executable()) - else: - return self.get_remote_exec_path() + + return self.get_remote_exec_path() def get_builder(self): """Create and return a new `ProcessBuilder` for the `CalcJob` class of the plugin configured for this code. @@ -478,8 +488,8 @@ def get_builder(self): try: process_class = CalculationFactory(plugin_name) - except EntryPointError: - raise EntryPointError('the calculation entry point `{}` could not be loaded'.format(plugin_name)) + except exceptions.EntryPointError: + raise exceptions.EntryPointError('the calculation entry point `{}` could not be loaded'.format(plugin_name)) builder = process_class.get_builder() builder.code = self @@ -489,9 +499,13 @@ def get_builder(self): def get_full_text_info(self, verbose=False): """Return a list of lists with a human-readable detailed information on this code. + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`. 
+ :return: list of lists where each entry consists of two elements: a key and a value """ - from aiida.orm.utils.repository import FileType + warnings.warn('this property is deprecated', AiidaDeprecationWarning) # pylint: disable=no-member + from aiida.repository import FileType result = [] result.append(['PK', self.pk]) @@ -508,13 +522,13 @@ def get_full_text_info(self, verbose=False): result.append(['Exec name', self.get_execname()]) result.append(['List of files/folders:', '']) for obj in self.list_objects(): - if obj.type == FileType.DIRECTORY: + if obj.file_type == FileType.DIRECTORY: result.append(['directory', obj.name]) else: result.append(['file', obj.name]) else: result.append(['Type', 'remote']) - result.append(['Remote machine', self.get_remote_computer().name]) + result.append(['Remote machine', self.get_remote_computer().label]) result.append(['Remote absolute path', self.get_remote_exec_path()]) if self.get_prepend_text().strip(): @@ -532,14 +546,3 @@ def get_full_text_info(self, verbose=False): result.append(['Append text', 'No append text']) return result - - @classmethod - def setup(cls, **kwargs): - from aiida.cmdline.commands.code import CodeInputValidationClass - code = CodeInputValidationClass().set_and_validate_from_code(kwargs) - - try: - code.store() - except ValidationError as exc: - raise ValidationError('Unable to store the computer: {}.'.format(exc)) - return code diff --git a/aiida/orm/nodes/data/dict.py b/aiida/orm/nodes/data/dict.py index 9006e606b1..07e12de69e 100644 --- a/aiida/orm/nodes/data/dict.py +++ b/aiida/orm/nodes/data/dict.py @@ -19,7 +19,32 @@ class Dict(Data): - """`Data` sub class to represent a dictionary.""" + """`Data` sub class to represent a dictionary. + + The dictionary contents of a `Dict` node are stored in the database as attributes. The dictionary + can be initialized through the `dict` argument in the constructor. After construction, values can + be retrieved and updated through the item getters and setters, respectively: + + node['key'] = 'value' + + Alternatively, the `dict` property returns an instance of the `AttributeManager` that can be used + to get and set values through attribute notation: + + node.dict.key = 'value' + + Note that trying to set dictionary values directly on the node, e.g. `node.key = value`, will not + work as intended. It will merely set the `key` attribute on the node instance, but will not be + stored in the database. As soon as the node goes out of scope, the value will be lost. + + It is also relevant to note here the difference in something being an "attribute of a node" (in + the sense that it is stored in the "attribute" column of the database when the node is stored) + and something being an "attribute of a python object" (in the sense of being able to modify and + access it as if it was a property of the variable, e.g. `node.key = value`). This is true of all + types of nodes, but it becomes more relevant for `Dict` nodes where one is constantly manipulating + these attributes. + + Finally, all dictionary mutations will be forbidden once the node is stored. + """ def __init__(self, **kwargs): """Store a dictionary as a `Node` instance. @@ -37,7 +62,10 @@ def __init__(self, **kwargs): self.set_dict(dictionary) def __getitem__(self, key): - return self.get_dict()[key] + return self.get_attribute(key) + + def __setitem__(self, key, value): + self.set_attribute(key, value) def set_dict(self, dictionary): """ Replace the current dictionary with another one. 
diff --git a/aiida/orm/nodes/data/orbital.py b/aiida/orm/nodes/data/orbital.py index 195d0b331a..cc97cc2ae2 100644 --- a/aiida/orm/nodes/data/orbital.py +++ b/aiida/orm/nodes/data/orbital.py @@ -7,13 +7,12 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Data plugin to model an atomic orbital.""" import copy -from .data import Data -from .structure import Site as site_class from aiida.common.exceptions import ValidationError, InputValidationError from aiida.plugins import OrbitalFactory +from .data import Data __all__ = ('OrbitalData',) @@ -51,8 +50,7 @@ def get_orbitals(self, **kwargs): filter_dict.update(kwargs) # prevents KeyError from occuring orbital_dicts = [x for x in orbital_dicts if all([y in x for y in filter_dict])] - orbital_dicts = [x for x in orbital_dicts if - all([x[y] == filter_dict[y] for y in filter_dict])] + orbital_dicts = [x for x in orbital_dicts if all([x[y] == filter_dict[y] for y in filter_dict])] list_of_outputs = [] for orbital_dict in orbital_dicts: @@ -61,8 +59,8 @@ def get_orbitals(self, **kwargs): except KeyError: raise ValidationError('No _orbital_type found in: {}'.format(orbital_dict)) - OrbitalClass = OrbitalFactory(orbital_type) - orbital = OrbitalClass(**orbital_dict) + cls = OrbitalFactory(orbital_type) + orbital = cls(**orbital_dict) list_of_outputs.append(orbital) return list_of_outputs @@ -86,6 +84,7 @@ def set_orbitals(self, orbitals): orbital_dicts.append(orbital_dict) self.set_attribute('orbital_dicts', orbital_dicts) + ########################################################################## # Here are some ideas for potential future convenience methods ######################################################################### diff --git a/aiida/orm/nodes/data/remote.py b/aiida/orm/nodes/data/remote.py index 03e6bf6453..ba8b8e52e0 100644 --- a/aiida/orm/nodes/data/remote.py +++ b/aiida/orm/nodes/data/remote.py @@ -7,10 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Data plugin that models a folder on a remote computer.""" import os -from .data import Data from aiida.orm import AuthInfo +from .data import Data __all__ = ('RemoteData',) @@ -28,7 +29,12 @@ def __init__(self, remote_path=None, **kwargs): self.set_remote_path(remote_path) def get_computer_name(self): - return self.computer.name + """Get label of this node's computer. + + .. deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use the `self.computer.label` property instead. + """ + return self.computer.label # pylint: disable=no-member def get_remote_path(self): return self.get_attribute('remote_path') @@ -61,19 +67,20 @@ def getfile(self, relpath, destpath): :param destpath: The absolute path of where to store the file on the local machine. 
""" authinfo = self.get_authinfo() - t = authinfo.get_transport() - with t: + with authinfo.get_transport() as transport: try: full_path = os.path.join(self.get_remote_path(), relpath) - t.getfile(full_path, destpath) - except IOError as e: - if e.errno == 2: # file not existing - raise IOError('The required remote file {} on {} does not exist or has been deleted.'.format( - full_path, self.computer.name - )) - else: - raise + transport.getfile(full_path, destpath) + except IOError as exception: + if exception.errno == 2: # file does not exist + raise IOError( + 'The required remote file {} on {} does not exist or has been deleted.'.format( + full_path, + self.computer.label # pylint: disable=no-member + ) + ) + raise def listdir(self, relpath='.'): """ @@ -83,32 +90,31 @@ def listdir(self, relpath='.'): :return: a flat list of file/directory names (as strings). """ authinfo = self.get_authinfo() - t = authinfo.get_transport() - with t: + with authinfo.get_transport() as transport: try: full_path = os.path.join(self.get_remote_path(), relpath) - t.chdir(full_path) - except IOError as e: - if e.errno == 2 or e.errno == 20: # directory not existing or not a directory + transport.chdir(full_path) + except IOError as exception: + if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'.format( - full_path, self.computer.name - )) - exc.errno = e.errno + 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. + format(full_path, self.computer.label) # pylint: disable=no-member + ) + exc.errno = exception.errno raise exc else: raise try: - return t.listdir() - except IOError as e: - if e.errno == 2 or e.errno == 20: # directory not existing or not a directory + return transport.listdir() + except IOError as exception: + if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'.format( - full_path, self.computer.name - )) - exc.errno = e.errno + 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. + format(full_path, self.computer.label) # pylint: disable=no-member + ) + exc.errno = exception.errno raise exc else: raise @@ -121,32 +127,31 @@ def listdir_withattributes(self, path='.'): :return: a list of dictionaries, where the documentation is in :py:class:Transport.listdir_withattributes. """ authinfo = self.get_authinfo() - t = authinfo.get_transport() - with t: + with authinfo.get_transport() as transport: try: full_path = os.path.join(self.get_remote_path(), path) - t.chdir(full_path) - except IOError as e: - if e.errno == 2 or e.errno == 20: # directory not existing or not a directory + transport.chdir(full_path) + except IOError as exception: + if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'.format( - full_path, self.computer.name - )) - exc.errno = e.errno + 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. 
+ format(full_path, self.computer.label) # pylint: disable=no-member + ) + exc.errno = exception.errno raise exc else: raise try: - return t.listdir_withattributes() - except IOError as e: - if e.errno == 2 or e.errno == 20: # directory not existing or not a directory + return transport.listdir_withattributes() + except IOError as exception: + if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'.format( - full_path, self.computer.name - )) - exc.errno = e.errno + 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. + format(full_path, self.computer.label) # pylint: disable=no-member + ) + exc.errno = exception.errno raise exc else: raise diff --git a/aiida/orm/nodes/data/singlefile.py b/aiida/orm/nodes/data/singlefile.py index 9c7c8eb61e..ef6040c52c 100644 --- a/aiida/orm/nodes/data/singlefile.py +++ b/aiida/orm/nodes/data/singlefile.py @@ -56,17 +56,34 @@ def filename(self): """ return self.get_attribute('filename') - def open(self, key=None, mode='r'): + def open(self, path=None, mode='r', key=None): """Return an open file handle to the content of this data node. + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. + + .. deprecated:: 1.4.0 + Starting from `v2.0.0` this will raise if not used in a context manager. + + :param path: the relative path of the object within the repository. :param key: optional key within the repository, by default is the `filename` set in the attributes :param mode: the mode with which to open the file handle (default: read mode) :return: a file handle """ - if key is None: - key = self.filename - - return self._repository.open(key, mode=mode) + from ..node import WarnWhenNotEntered + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key + + if path is None: + path = self.filename + + return WarnWhenNotEntered(self._repository.open(path, mode=mode), repr(self)) def get_content(self): """Return the content of the single file stored for this data node. diff --git a/aiida/orm/nodes/data/structure.py b/aiida/orm/nodes/data/structure.py index b4a3682c70..4f4e179571 100644 --- a/aiida/orm/nodes/data/structure.py +++ b/aiida/orm/nodes/data/structure.py @@ -7,29 +7,28 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines """ This module defines the classes for structures and all related functions to operate on them. """ - -import itertools import copy -from functools import reduce - +import functools +import itertools -from .data import Data from aiida.common.constants import elements from aiida.common.exceptions import UnsupportedSpeciesError +from .data import Data __all__ = ('StructureData', 'Kind', 'Site') # Threshold used to check if the mass of two different Site objects is the same. -_mass_threshold = 1.e-3 +_MASS_THRESHOLD = 1.e-3 # Threshold to check if the sum is one or not -_sum_threshold = 1.e-6 +_SUM_THRESHOLD = 1.e-6 # Threshold used to check if the cell volume is not zero. 
-_volume_threshold = 1.e-6 +_VOLUME_THRESHOLD = 1.e-6 _valid_symbols = tuple(i['symbol'] for i in elements.values()) _atomic_masses = {el['symbol']: el['mass'] for el in elements.values()} @@ -51,7 +50,7 @@ def _get_valid_cell(inputcell): except (IndexError, ValueError, TypeError): raise ValueError('Cell must be a list of three vectors, each defined as a list of three coordinates.') - if abs(calc_cell_volume(the_cell)) < _volume_threshold: + if abs(calc_cell_volume(the_cell)) < _VOLUME_THRESHOLD: raise ValueError('The cell volume is zero. Invalid cell.') return the_cell @@ -66,7 +65,7 @@ def get_valid_pbc(inputpbc): """ if isinstance(inputpbc, bool): the_pbc = (inputpbc, inputpbc, inputpbc) - elif (hasattr(inputpbc, '__iter__')): + elif hasattr(inputpbc, '__iter__'): # To manage numpy lists of bools, whose elements are of type numpy.bool_ # and for which isinstance(i,bool) return False... if hasattr(inputpbc, 'tolist'): @@ -93,7 +92,7 @@ def has_ase(): :return: True if the ase module can be imported, False otherwise. """ try: - import ase + import ase # pylint: disable=unused-import except ImportError: return False return True @@ -104,7 +103,7 @@ def has_pymatgen(): :return: True if the pymatgen module can be imported, False otherwise. """ try: - import pymatgen + import pymatgen # pylint: disable=unused-import except ImportError: return False return True @@ -125,7 +124,7 @@ def has_spglib(): :return: True if the spglib module can be imported, False otherwise. """ try: - import spglib + import spglib # pylint: disable=unused-import except ImportError: return False return True @@ -143,6 +142,7 @@ def calc_cell_volume(cell): :returns: the cell volume. """ + # pylint: disable=invalid-name # returns the volume of the primitive cell: |a1.(a2xa3)| a1 = cell[0] a2 = cell[1] @@ -243,8 +243,10 @@ def validate_symbols_tuple(symbols_tuple): else: valid = all(is_valid_symbol(sym) for sym in symbols_tuple) if not valid: - raise UnsupportedSpeciesError('At least one element of the symbol list {} has ' - 'not been recognized.'.format(symbols_tuple)) + raise UnsupportedSpeciesError( + 'At least one element of the symbol list {} has ' + 'not been recognized.'.format(symbols_tuple) + ) def is_ase_atoms(ase_atoms): @@ -256,10 +258,7 @@ def is_ase_atoms(ase_atoms): Requires the ability to import ase, by doing 'import ase'. """ - # TODO: Check if we want to try to import ase and do something - # reasonable depending on whether ase is there or not. 
import ase - return isinstance(ase_atoms, ase.Atoms) @@ -318,8 +317,9 @@ def get_formula_from_symbol_list(_list, separator=''): if isinstance(elem[1], str): list_str.append('{}{}'.format(elem[1], multiplicity_str)) elif elem[0] > 1: - list_str.append('({}){}'.format( - get_formula_from_symbol_list(elem[1], separator=separator), multiplicity_str)) + list_str.append( + '({}){}'.format(get_formula_from_symbol_list(elem[1], separator=separator), multiplicity_str) + ) else: list_str.append('{}{}'.format(get_formula_from_symbol_list(elem[1], separator=separator), multiplicity_str)) @@ -361,15 +361,15 @@ def group_together(_list, group_size, offset): the_list = copy.deepcopy(_list) the_list.reverse() grouped_list = [] - for i in range(offset): + for _ in range(offset): grouped_list.append([the_list.pop()]) while the_list: - l = [] - for i in range(group_size): + sub_list = [] + for _ in range(group_size): if the_list: - l.append(the_list.pop()) - grouped_list.append(l) + sub_list.append(the_list.pop()) + grouped_list.append(sub_list) return grouped_list @@ -402,10 +402,10 @@ def group_together_symbols(_list, group_size): the_symbol_list = copy.deepcopy(_list) has_grouped = False offset = 0 - while (not has_grouped) and (offset < group_size): + while not has_grouped and offset < group_size: grouped_list = group_together(the_symbol_list, group_size, offset) new_symbol_list = group_symbols(grouped_list) - if (len(new_symbol_list) < len(grouped_list)): + if len(new_symbol_list) < len(grouped_list): the_symbol_list = copy.deepcopy(new_symbol_list) the_symbol_list = cleanout_symbol_list(the_symbol_list) has_grouped = True @@ -423,10 +423,9 @@ def group_all_together_symbols(_list): """ has_finished = False group_size = 2 - n = len(_list) the_symbol_list = copy.deepcopy(_list) - while (not has_finished) and (group_size <= n // 2): + while not has_finished and group_size <= len(_list) // 2: # try to group as much as possible by groups of size group_size the_symbol_list, has_grouped = group_together_symbols(the_symbol_list, group_size) has_finished = has_grouped @@ -504,7 +503,7 @@ def get_formula(symbol_list, mode='hill', separator=''): # for hill and count cases, simply count the occurences of each # chemical symbol (with some re-ordering in hill) - elif mode in ['hill', 'hill_compact']: + if mode in ['hill', 'hill_compact']: symbol_set = set(symbol_list) first_symbols = [] if 'C' in symbol_set: @@ -527,11 +526,11 @@ def get_formula(symbol_list, mode='hill', separator=''): the_symbol_list = group_symbols(symbol_list) else: - raise ValueError('Mode should be hill, hill_compact, group, ' 'reduce, count or count_compact') + raise ValueError('Mode should be hill, hill_compact, group, reduce, count or count_compact') if mode in ['hill_compact', 'count_compact']: from math import gcd - the_gcd = reduce(gcd, [e[0] for e in the_symbol_list]) + the_gcd = functools.reduce(gcd, [e[0] for e in the_symbol_list]) the_symbol_list = [[e[0] // the_gcd, e[1]] for e in the_symbol_list] return get_formula_from_symbol_list(the_symbol_list, separator=separator) @@ -555,24 +554,24 @@ def get_symbols_string(symbols, weights): """ if len(symbols) == 1 and weights[0] == 1.: return symbols[0] - else: - pieces = [] - for s, w in zip(symbols, weights): - pieces.append('{}{:4.2f}'.format(s, w)) - if has_vacancies(weights): - pieces.append('X{:4.2f}'.format(1. 
- sum(weights))) - return '{{{}}}'.format(''.join(sorted(pieces))) + + pieces = [] + for symbol, weight in zip(symbols, weights): + pieces.append('{}{:4.2f}'.format(symbol, weight)) + if has_vacancies(weights): + pieces.append('X{:4.2f}'.format(1. - sum(weights))) + return '{{{}}}'.format(''.join(sorted(pieces))) def has_vacancies(weights): """ Returns True if the sum of the weights is less than one. - It uses the internal variable _sum_threshold as a threshold. + It uses the internal variable _SUM_THRESHOLD as a threshold. :param weights: the weights :return: a boolean """ w_sum = sum(weights) - return not (1. - w_sum < _sum_threshold) + return not 1. - w_sum < _SUM_THRESHOLD def symop_ortho_from_fract(cell): @@ -586,6 +585,7 @@ def symop_ortho_from_fract(cell): :param cell: array of cell parameters (three lengths and three angles) """ + # pylint: disable=invalid-name import math import numpy @@ -609,6 +609,7 @@ def symop_fract_from_ortho(cell): :param cell: array of cell parameters (three lengths and three angles) """ + # pylint: disable=invalid-name import math import numpy @@ -681,12 +682,12 @@ def atom_kinds_to_html(atom_kind): # Parse the formula (TODO can be made more robust though never fails if # it takes strings generated with kind.get_symbols_string()) import re - elements = re.findall(r'([A-Z][a-z]*)([0-1][.[0-9]*]?)?', atom_kind) + matched_elements = re.findall(r'([A-Z][a-z]*)([0-1][.[0-9]*]?)?', atom_kind) # Compose the html string html_formula_pieces = [] - for element in elements: + for element in matched_elements: # replace element X by 'vacancy' species = element[0] if element[0] != 'X' else 'vacancy' @@ -709,6 +710,9 @@ class StructureData(Data): boundary conditions (whether they are periodic or not) and other related useful information. """ + + # pylint: disable=too-many-public-methods + _set_incompatibilities = [('ase', 'cell'), ('ase', 'pbc'), ('ase', 'pymatgen'), ('ase', 'pymatgen_molecule'), ('ase', 'pymatgen_structure'), ('cell', 'pymatgen'), ('cell', 'pymatgen_molecule'), ('cell', 'pymatgen_structure'), ('pbc', 'pymatgen'), ('pbc', 'pymatgen_molecule'), @@ -716,9 +720,18 @@ class StructureData(Data): ('pymatgen', 'pymatgen_structure'), ('pymatgen_molecule', 'pymatgen_structure')] _dimensionality_label = {0: '', 1: 'length', 2: 'surface', 3: 'volume'} - - def __init__(self, cell=None, pbc=None, ase=None, pymatgen=None, pymatgen_structure=None, pymatgen_molecule=None, **kwargs): - + _internal_kind_tags = None + + def __init__( + self, + cell=None, + pbc=None, + ase=None, + pymatgen=None, + pymatgen_structure=None, + pymatgen_molecule=None, + **kwargs + ): # pylint: disable=too-many-arguments args = { 'cell': cell, 'pbc': pbc, @@ -779,8 +792,7 @@ def get_dimensionality(self): if dim == 0: pass elif dim == 1: - v = cell[pbc] - retdict['value'] = np.linalg.norm(v) + retdict['value'] = np.linalg.norm(cell[pbc]) elif dim == 2: vectors = cell[pbc] retdict['value'] = np.linalg.norm(np.cross(vectors[0], vectors[1])) @@ -830,8 +842,8 @@ def set_pymatgen_molecule(self, mol, margin=5): of earlier versions may cause errors). 
""" box = [ - max([x.coords.tolist()[0] for x in mol.sites]) - min([x.coords.tolist()[0] for x in mol.sites - ]) + 2 * margin, + max([x.coords.tolist()[0] for x in mol.sites]) - min([x.coords.tolist()[0] for x in mol.sites]) + + 2 * margin, max([x.coords.tolist()[1] for x in mol.sites]) - min([x.coords.tolist()[1] for x in mol.sites]) + 2 * margin, max([x.coords.tolist()[2] for x in mol.sites]) - min([x.coords.tolist()[2] for x in mol.sites]) + 2 * margin @@ -886,8 +898,8 @@ def build_kind_name(species_and_occu): kind_name += '2' return kind_name - else: - return None + + return None self.cell = struct.lattice.matrix.tolist() self.pbc = [True, True, True] @@ -910,7 +922,7 @@ def build_kind_name(species_and_occu): inputs = { 'symbols': [x.symbol for x in species_and_occu.keys()], - 'weights': [x for x in species_and_occu.values()], + 'weights': list(species_and_occu.values()), 'position': site.coords.tolist() } @@ -947,10 +959,12 @@ def _validate(self): from collections import Counter counts = Counter([k.name for k in kinds]) - for c in counts: - if counts[c] != 1: - raise ValidationError("Kind with name '{}' appears {} times " - 'instead of only one'.format(c, counts[c])) + for count in counts: + if counts[count] != 1: + raise ValidationError( + "Kind with name '{}' appears {} times " + 'instead of only one'.format(count, counts[count]) + ) try: # This will try to create the sites objects @@ -960,15 +974,19 @@ def _validate(self): for site in sites: if site.kind_name not in [k.name for k in kinds]: - raise ValidationError('A site has kind {}, but no specie with that name exists' - ''.format(site.kind_name)) + raise ValidationError( + 'A site has kind {}, but no specie with that name exists' + ''.format(site.kind_name) + ) kinds_without_sites = (set(k.name for k in kinds) - set(s.kind_name for s in sites)) if kinds_without_sites: - raise ValidationError('The following kinds are defined, but there ' - 'are no sites with that kind: {}'.format(list(kinds_without_sites))) + raise ValidationError( + 'The following kinds are defined, but there ' + 'are no sites with that kind: {}'.format(list(kinds_without_sites)) + ) - def _prepare_xsf(self, main_file_name=''): + def _prepare_xsf(self, main_file_name=''): # pylint: disable=unused-argument """ Write the given structure to a string of format XSF (for XCrySDen). """ @@ -990,19 +1008,20 @@ def _prepare_xsf(self, main_file_name=''): return_string += '%18.10f %18.10f %18.10f\n' % tuple(site.position) return return_string.encode('utf-8'), {} - def _prepare_cif(self, main_file_name=''): + def _prepare_cif(self, main_file_name=''): # pylint: disable=unused-argument """ Write the given structure to a string of format CIF. """ from aiida.orm import CifData cif = CifData(ase=self.get_ase()) - return cif._prepare_cif() + return cif._prepare_cif() # pylint: disable=protected-access - def _prepare_chemdoodle(self, main_file_name=''): + def _prepare_chemdoodle(self, main_file_name=''): # pylint: disable=unused-argument """ Write the given structure to a string of format required by ChemDoodle. 
""" + # pylint: disable=too-many-locals,invalid-name import numpy as np from itertools import product @@ -1044,7 +1063,6 @@ def _prepare_chemdoodle(self, main_file_name=''): 'x': base_site['position'][0] + shift[0], 'y': base_site['position'][1] + shift[1], 'z': base_site['position'][2] + shift[2], - # 'atomic_elements_html': kind_string 'atomic_elements_html': atom_kinds_to_html(kind_string) }) @@ -1065,7 +1083,7 @@ def _prepare_chemdoodle(self, main_file_name=''): return json.dumps(return_dict).encode('utf-8'), {} - def _prepare_xyz(self, main_file_name=''): + def _prepare_xyz(self, main_file_name=''): # pylint: disable=unused-argument """ Write the given structure to a string of format XYZ. """ @@ -1076,14 +1094,20 @@ def _prepare_xyz(self, main_file_name=''): cell = self.cell return_list = ['{}'.format(len(sites))] - return_list.append('Lattice="{} {} {} {} {} {} {} {} {}" pbc="{} {} {}"'.format( - cell[0][0], cell[0][1], cell[0][2], cell[1][0], cell[1][1], cell[1][2], cell[2][0], cell[2][1], cell[2][2], - self.pbc[0], self.pbc[1], self.pbc[2])) + return_list.append( + 'Lattice="{} {} {} {} {} {} {} {} {}" pbc="{} {} {}"'.format( + cell[0][0], cell[0][1], cell[0][2], cell[1][0], cell[1][1], cell[1][2], cell[2][0], cell[2][1], + cell[2][2], self.pbc[0], self.pbc[1], self.pbc[2] + ) + ) for site in sites: # I checked above that it is not an alloy, therefore I take the # first symbol - return_list.append('{:6s} {:18.10f} {:18.10f} {:18.10f}'.format( - self.get_kind(site.kind_name).symbols[0], site.position[0], site.position[1], site.position[2])) + return_list.append( + '{:6s} {:18.10f} {:18.10f} {:18.10f}'.format( + self.get_kind(site.kind_name).symbols[0], site.position[0], site.position[1], site.position[2] + ) + ) return_string = '\n'.join(return_list) return return_string.encode('utf-8'), {} @@ -1115,8 +1139,8 @@ def _adjust_default_cell(self, vacuum_factor=1.0, vacuum_addition=10.0, pbc=(Fal leading to an unphysical definition of the structure. This method will adjust the cell """ + # pylint: disable=invalid-name import numpy as np - from ase.visualize import view def get_extremas_from_positions(positions): """ @@ -1130,7 +1154,7 @@ def get_extremas_from_positions(positions): # Calculating the minimal cell: positions = np.array([site.position for site in self.sites]) - position_min, position_max = get_extremas_from_positions(positions) + position_min, _ = get_extremas_from_positions(positions) # Translate the structure to the origin, such that the minimal values in each dimension # amount to (0,0,0) @@ -1336,11 +1360,11 @@ def append_kind(self, kind): # If here, no exceptions have been raised, so I add the site. self.attributes.setdefault('kinds', []).append(new_kind.get_raw()) - # Note, this is a dict (with integer keys) so it allows for empty - # spots! - if not hasattr(self, '_internal_kind_tags'): + # Note, this is a dict (with integer keys) so it allows for empty spots! 
+ if self._internal_kind_tags is None: self._internal_kind_tags = {} - self._internal_kind_tags[len(self.get_attribute('kinds')) - 1] = kind._internal_tag + + self._internal_kind_tags[len(self.get_attribute('kinds')) - 1] = kind._internal_tag # pylint: disable=protected-access def append_site(self, site): """ @@ -1357,9 +1381,11 @@ def append_site(self, site): new_site = Site(site=site) # So we make a copy - if site.kind_name not in [k.name for k in self.kinds]: - raise ValueError("No kind with name '{}', available kinds are: " - '{}'.format(site.kind_name, [k.name for k in self.kinds])) + if site.kind_name not in [kind.name for kind in self.kinds]: + raise ValueError( + "No kind with name '{}', available kinds are: " + '{}'.format(site.kind_name, [kind.name for kind in self.kinds]) + ) # If here, no exceptions have been raised, so I add the site. self.attributes.setdefault('sites', []).append(new_site.get_raw()) @@ -1396,12 +1422,15 @@ def append_atom(self, **kwargs): .. note :: checks of equality of species are done using the :py:meth:`~aiida.orm.nodes.data.structure.Kind.compare_with` method. """ + # pylint: disable=too-many-branches aseatom = kwargs.pop('ase', None) if aseatom is not None: if kwargs: - raise ValueError("If you pass 'ase' as a parameter to " - 'append_atom, you cannot pass any further' - 'parameter') + raise ValueError( + "If you pass 'ase' as a parameter to " + 'append_atom, you cannot pass any further' + 'parameter' + ) position = aseatom.position kind = Kind(ase=aseatom) else: @@ -1420,13 +1449,13 @@ def append_atom(self, **kwargs): exists_already = False for idx, existing_kind in enumerate(_kinds): try: - existing_kind._internal_tag = self._internal_kind_tags[idx] + existing_kind._internal_tag = self._internal_kind_tags[idx] # pylint: disable=protected-access except KeyError: # self._internal_kind_tags does not contain any info for # the kind in position idx: I don't have to add anything # then, and I continue pass - if (kind.compare_with(existing_kind)[0]): + if kind.compare_with(existing_kind)[0]: kind = existing_kind exists_already = True break @@ -1456,10 +1485,12 @@ def append_atom(self, **kwargs): if is_the_same: kind = old_kind else: - raise ValueError('You are explicitly setting the name ' - "of the kind to '{}', that already " - 'exists, but the two kinds are different!' - ' (first difference: {})'.format(kind.name, firstdiff)) + raise ValueError( + 'You are explicitly setting the name ' + "of the kind to '{}', that already " + 'exists, but the two kinds are different!' 
+ ' (first difference: {})'.format(kind.name, firstdiff) + ) site = Site(kind_name=kind.name, position=position) self.append_site(site) @@ -1528,7 +1559,7 @@ def get_kind(self, kind_name): try: kinds_dict = self._kinds_cache except AttributeError: - self._kinds_cache = {_.name: _ for _ in self.kinds} + self._kinds_cache = {_.name: _ for _ in self.kinds} # pylint: disable=attribute-defined-outside-init kinds_dict = self._kinds_cache else: kinds_dict = {_.name: _ for _ in self.kinds} @@ -1562,9 +1593,11 @@ def cell(self): @cell.setter def cell(self, value): + """Set the cell.""" self.set_cell(value) def set_cell(self, value): + """Set the cell.""" from aiida.common.exceptions import ModificationNotAllowed if self.is_stored: @@ -1611,7 +1644,6 @@ def reset_sites_positions(self, new_positions, conserve_particle=True): raise ModificationNotAllowed() if not conserve_particle: - # TODO: raise NotImplementedError else: @@ -1653,9 +1685,11 @@ def pbc(self): @pbc.setter def pbc(self, value): + """Set the periodic boundary conditions.""" self.set_pbc(value) def set_pbc(self, value): + """Set the periodic boundary conditions.""" from aiida.common.exceptions import ModificationNotAllowed if self.is_stored: @@ -1765,7 +1799,7 @@ def _get_object_phonopyatoms(self): :return: a PhonopyAtoms object """ - from phonopy.structure.atoms import Atoms as PhonopyAtoms + from phonopy.structure.atoms import PhonopyAtoms # pylint: disable=import-error atoms = PhonopyAtoms(symbols=[_.kind_name for _ in self.sites]) # Phonopy internally uses scaled positions, so you must store cell first! @@ -1805,8 +1839,8 @@ def _get_object_pymatgen(self, **kwargs): """ if self.pbc == (True, True, True): return self._get_object_pymatgen_structure(**kwargs) - else: - return self._get_object_pymatgen_molecule(**kwargs) + + return self._get_object_pymatgen_molecule(**kwargs) def _get_object_pymatgen_structure(self, **kwargs): """ @@ -1844,24 +1878,25 @@ def _get_object_pymatgen_structure(self, **kwargs): # case when spins are defined -> no partial occupancy allowed from pymatgen import Specie oxidation_state = 0 # now I always set the oxidation_state to zero - for s in self.sites: - k = self.get_kind(s.kind_name) - if len(k.symbols) != 1 or (len(k.weights) != 1 or sum(k.weights) < 1.): + for site in self.sites: + kind = self.get_kind(site.kind_name) + if len(kind.symbols) != 1 or (len(kind.weights) != 1 or sum(kind.weights) < 1.): raise ValueError('Cannot set partial occupancies and spins at the same time') species.append( Specie( - k.symbols[0], + kind.symbols[0], oxidation_state, - properties={'spin': -1 if k.name.endswith('1') else 1 if k.name.endswith('2') else 0})) + properties={'spin': -1 if kind.name.endswith('1') else 1 if kind.name.endswith('2') else 0} + ) + ) else: # case when no spin are defined - for s in self.sites: - k = self.get_kind(s.kind_name) - species.append({s: w for s, w in zip(k.symbols, k.weights)}) + for site in self.sites: + kind = self.get_kind(site.kind_name) + species.append(dict(zip(kind.symbols, kind.weights))) if any([ - create_automatic_kind_name(self.get_kind(name).symbols, - self.get_kind(name).weights) != name - for name in self.get_site_kindnames() + create_automatic_kind_name(self.get_kind(name).symbols, + self.get_kind(name).weights) != name for name in self.get_site_kindnames() ]): # add "kind_name" as a properties to each site, whenever # the kind_name cannot be automatically obtained from the symbols @@ -1892,11 +1927,11 @@ def _get_object_pymatgen_molecule(self, **kwargs): raise 
ValueError('Unrecognized parameters passed to pymatgen converter: {}'.format(kwargs.keys())) species = [] - for s in self.sites: - k = self.get_kind(s.kind_name) - species.append({s: w for s, w in zip(k.symbols, k.weights)}) + for site in self.sites: + kind = self.get_kind(site.kind_name) + species.append(dict(zip(kind.symbols, kind.weights))) - positions = [list(x.position) for x in self.sites] + positions = [list(site.position) for site in self.sites] return Molecule(species, positions) @@ -1931,6 +1966,7 @@ def __init__(self, **kwargs): :param name: a string that uniquely identifies the kind, and that is used to identify the sites. """ + # pylint: disable=too-many-branches,too-many-statements # Internal variables self._mass = None self._symbols = None @@ -1976,9 +2012,11 @@ def __init__(self, **kwargs): self.name = oldkind.name self._internal_tag = oldkind._internal_tag except AttributeError: - raise ValueError('Error using the Kind object. Are you sure ' - 'it is a Kind object? [Introspection says it is ' - '{}]'.format(str(type(oldkind)))) + raise ValueError( + 'Error using the Kind object. Are you sure ' + 'it is a Kind object? [Introspection says it is ' + '{}]'.format(str(type(oldkind))) + ) elif 'ase' in kwargs: aseatom = kwargs['ase'] @@ -1994,9 +2032,11 @@ def __init__(self, **kwargs): else: self.reset_mass() except AttributeError: - raise ValueError('Error using the aseatom object. Are you sure ' - 'it is a ase.atom.Atom object? [Introspection says it is ' - '{}]'.format(str(type(aseatom)))) + raise ValueError( + 'Error using the aseatom object. Are you sure ' + 'it is a ase.atom.Atom object? [Introspection says it is ' + '{}]'.format(str(type(aseatom))) + ) if aseatom.tag != 0: self.set_automatic_kind_name(tag=aseatom.tag) self._internal_tag = aseatom.tag @@ -2004,9 +2044,11 @@ def __init__(self, **kwargs): self.set_automatic_kind_name() else: if 'symbols' not in kwargs: - raise ValueError("'symbols' need to be " - 'specified (at least) to create a Site object. Otherwise, ' - "pass a raw site using the 'raw' parameter.") + raise ValueError( + "'symbols' need to be " + 'specified (at least) to create a Site object. Otherwise, ' + "pass a raw site using the 'raw' parameter." + ) weights = kwargs.pop('weights', None) self.set_symbols_and_weights(kwargs.pop('symbols'), weights) try: @@ -2051,7 +2093,7 @@ def reset_mass(self): """ w_sum = sum(self._weights) - if abs(w_sum) < _sum_threshold: + if abs(w_sum) < _SUM_THRESHOLD: self._mass = None return @@ -2115,21 +2157,28 @@ def compare_with(self, other_kind): # Check list of symbols for i in range(len(self.symbols)): if self.symbols[i] != other_kind.symbols[i]: - return (False, 'Symbol at position {:d} are different ' - '({} vs. {})'.format(i + 1, self.symbols[i], other_kind.symbols[i])) + return ( + False, 'Symbol at position {:d} are different ' + '({} vs. {})'.format(i + 1, self.symbols[i], other_kind.symbols[i]) + ) # Check weights (assuming length of weights and of symbols have same # length, which should be always true for i in range(len(self.weights)): if self.weights[i] != other_kind.weights[i]: - return (False, 'Weight at position {:d} are different ' - '({} vs. {})'.format(i + 1, self.weights[i], other_kind.weights[i])) + return ( + False, 'Weight at position {:d} are different ' + '({} vs. {})'.format(i + 1, self.weights[i], other_kind.weights[i]) + ) # Check masses - if abs(self.mass - other_kind.mass) > _mass_threshold: + if abs(self.mass - other_kind.mass) > _MASS_THRESHOLD: return (False, 'Masses are different ({} vs. 
{})'.format(self.mass, other_kind.mass)) - if self._internal_tag != other_kind._internal_tag: - return (False, 'Internal tags are different ({} vs. {})' - ''.format(self._internal_tag, other_kind._internal_tag)) + if self._internal_tag != other_kind._internal_tag: # pylint: disable=protected-access + return ( + False, + 'Internal tags are different ({} vs. {})' + ''.format(self._internal_tag, other_kind._internal_tag) # pylint: disable=protected-access + ) # If we got here, the two Site objects are similar enough # to be considered of the same kind @@ -2169,9 +2218,11 @@ def weights(self, value): weights_tuple = _create_weights_tuple(value) if len(weights_tuple) != len(self._symbols): - raise ValueError('Cannot change the number of weights. Use the ' - 'set_symbols_and_weights function instead.') - validate_weights_tuple(weights_tuple, _sum_threshold) + raise ValueError( + 'Cannot change the number of weights. Use the ' + 'set_symbols_and_weights function instead.' + ) + validate_weights_tuple(weights_tuple, _SUM_THRESHOLD) self._weights = weights_tuple @@ -2200,8 +2251,8 @@ def symbol(self): """ if len(self._symbols) == 1: return self._symbols[0] - else: - raise ValueError('This kind has more than one symbol (it is an alloy): {}'.format(self._symbols)) + + raise ValueError('This kind has more than one symbol (it is an alloy): {}'.format(self._symbols)) @property def symbols(self): @@ -2227,8 +2278,10 @@ def symbols(self, value): symbols_tuple = _create_symbols_tuple(value) if len(symbols_tuple) != len(self._weights): - raise ValueError('Cannot change the number of symbols. Use the ' - 'set_symbols_and_weights function instead.') + raise ValueError( + 'Cannot change the number of symbols. Use the ' + 'set_symbols_and_weights function instead.' + ) validate_symbols_tuple(symbols_tuple) self._symbols = symbols_tuple @@ -2244,7 +2297,7 @@ def set_symbols_and_weights(self, symbols, weights): if len(symbols_tuple) != len(weights_tuple): raise ValueError('The number of symbols and weights must coincide.') validate_symbols_tuple(symbols_tuple) - validate_weights_tuple(weights_tuple, _sum_threshold) + validate_weights_tuple(weights_tuple, _SUM_THRESHOLD) self._symbols = symbols_tuple self._weights = weights_tuple @@ -2260,7 +2313,7 @@ def is_alloy(self): def has_vacancies(self): """Return whether the Kind contains vacancies, i.e. when the sum of the weights is less than one. - .. note:: the property uses the internal variable `_sum_threshold` as a threshold. + .. note:: the property uses the internal variable `_SUM_THRESHOLD` as a threshold. :return: boolean, True if the sum of the weights is less than one, False otherwise """ @@ -2345,6 +2398,7 @@ def get_ase(self, kinds): .. note:: If any site is an alloy or has vacancies, a ValueError is raised (from the site.get_ase() routine). 
""" + # pylint: disable=too-many-branches from collections import defaultdict import ase @@ -2371,7 +2425,7 @@ def get_ase(self, kinds): pass tag_list.append(k.symbols[0]) # I use a string as a placeholder - for i in range(len(tag_list)): + for i, _ in enumerate(tag_list): # If it is a string, it is the name of the element, # and I have to generate a new integer for this element # and replace tag_list[i] with this new integer @@ -2388,10 +2442,10 @@ def get_ase(self, kinds): tag_list[i] = new_tag found = False - for k, t in zip(kinds, tag_list): - if k.name == self.kind_name: - kind = k - tag = t + for kind_candidate, tag_candidate in zip(kinds, tag_list): + if kind_candidate.name == self.kind_name: + kind = kind_candidate + tag = tag_candidate found = True break if not found: @@ -2401,7 +2455,7 @@ def get_ase(self, kinds): raise ValueError('Cannot convert to ASE if the kind represents an alloy or it has vacancies.') aseatom = ase.Atom(position=self.position, symbol=str(kind.symbols[0]), mass=kind.mass) if tag is not None: - aseatom.tag = tag + aseatom.tag = tag # pylint: disable=assigning-non-slot return aseatom @property diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index 5c4f11a4c7..b2fe9e4ae5 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -7,9 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines,too-many-arguments """Package for node ORM classes.""" -import copy import importlib import warnings @@ -21,13 +20,13 @@ from aiida.common.warnings import AiidaDeprecationWarning from aiida.manage.manager import get_manager from aiida.orm.utils.links import LinkManager, LinkTriple -from aiida.orm.utils.repository import Repository -from aiida.orm.utils.node import AbstractNodeMeta, validate_attribute_extra_key +from aiida.orm.utils._repository import Repository +from aiida.orm.utils.node import AbstractNodeMeta from aiida.orm import autogroup from ..comments import Comment from ..computers import Computer -from ..entities import Entity +from ..entities import Entity, EntityExtrasMixin, EntityAttributesMixin from ..entities import Collection as EntityCollection from ..querybuilder import QueryBuilder from ..users import User @@ -37,7 +36,47 @@ _NO_DEFAULT = tuple() -class Node(Entity, metaclass=AbstractNodeMeta): +class WarnWhenNotEntered: + """Temporary wrapper to warn when `Node.open` is called outside of a context manager.""" + + def __init__(self, fileobj, name): + self._fileobj = fileobj + self._name = name + self._was_entered = False + + def _warn_if_not_entered(self, method): + """Fire a warning if the object wrapper has not yet been entered.""" + if not self._was_entered: + msg = '`{}` used without context manager for {}. 
This will raise starting from `aiida-core==2.0.0`'.format( + method, self._name + ) + warnings.warn(msg, AiidaDeprecationWarning) # pylint: disable=no-member + + def __enter__(self): + self._was_entered = True + return self._fileobj.__enter__() + + def __exit__(self, *args): + self._fileobj.__exit__(*args) + + def __getattr__(self, key): + if key == '_fileobj': + return self._fileobj + return getattr(self._fileobj, key) + + def __del__(self): + self._warn_if_not_entered('del') + + def read(self, *args, **kwargs): + self._warn_if_not_entered('read') + return self._fileobj.read(*args, **kwargs) + + def close(self, *args, **kwargs): + self._warn_if_not_entered('close') + return self._fileobj.close(*args, **kwargs) + + +class Node(Entity, EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractNodeMeta): """ Base class for all nodes in AiiDA. @@ -329,368 +368,190 @@ def mtime(self): """ return self.backend_entity.mtime - @property - def attributes(self): - """Return the complete attributes dictionary. - - .. warning:: While the node is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :return: the attributes as a dictionary - """ - attributes = self.backend_entity.attributes - - if self.is_stored: - attributes = copy.deepcopy(attributes) - - return attributes - - def get_attribute(self, key, default=_NO_DEFAULT): - """Return the value of an attribute. - - .. warning:: While the node is unstored, this will return a reference of the attribute on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attribute will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. - - :param key: name of the attribute - :param default: return this value instead of raising if the attribute does not exist - :return: the value of the attribute - :raises AttributeError: if the attribute does not exist and no default is specified - """ - try: - attribute = self.backend_entity.get_attribute(key) - except AttributeError: - if default is _NO_DEFAULT: - raise - attribute = default - - if self.is_stored: - attribute = copy.deepcopy(attribute) - - return attribute - - def get_attribute_many(self, keys): - """Return the values of multiple attributes. - - .. warning:: While the node is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. 
If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :param keys: a list of attribute names - :return: a list of attribute values - :raises AttributeError: if at least one attribute does not exist - """ - attributes = self.backend_entity.get_attribute_many(keys) - - if self.is_stored: - attributes = copy.deepcopy(attributes) - - return attributes - - def set_attribute(self, key, value): - """Set an attribute to the given value. - - :param key: name of the attribute - :param value: value of the attribute - :raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods - :raise aiida.common.ModificationNotAllowed: if the node is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored node are immutable') - - validate_attribute_extra_key(key) - self.backend_entity.set_attribute(key, value) - - def set_attribute_many(self, attributes): - """Set multiple attributes. - - .. note:: This will override any existing attributes that are present in the new dictionary. - - :param attributes: a dictionary with the attributes to set - :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods - :raise aiida.common.ModificationNotAllowed: if the node is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored node are immutable') - - for key in attributes: - validate_attribute_extra_key(key) - - self.backend_entity.set_attribute_many(attributes) - - def reset_attributes(self, attributes): - """Reset the attributes. - - .. note:: This will completely clear any existing attributes and replace them with the new dictionary. - - :param attributes: a dictionary with the attributes to set - :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods - :raise aiida.common.ModificationNotAllowed: if the node is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored node are immutable') - - for key in attributes: - validate_attribute_extra_key(key) - - self.backend_entity.reset_attributes(attributes) - - def delete_attribute(self, key): - """Delete an attribute. - - :param key: name of the attribute - :raises AttributeError: if the attribute does not exist - :raise aiida.common.ModificationNotAllowed: if the node is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored node are immutable') - - self.backend_entity.delete_attribute(key) - - def delete_attribute_many(self, keys): - """Delete multiple attributes. - - :param keys: names of the attributes to delete - :raises AttributeError: if at least one of the attribute does not exist - :raise aiida.common.ModificationNotAllowed: if the node is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored node are immutable') - - self.backend_entity.delete_attribute_many(keys) - - def clear_attributes(self): - """Delete all attributes.""" - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored node are immutable') - - self.backend_entity.clear_attributes() - - def attributes_items(self): - """Return an iterator over the attributes. 
- - :return: an iterator with attribute key value pairs - """ - return self.backend_entity.attributes_items() - - def attributes_keys(self): - """Return an iterator over the attribute keys. + def list_objects(self, path=None, key=None): + """Return a list of the objects contained in this repository, optionally in the given sub directory. - :return: an iterator with attribute keys - """ - return self.backend_entity.attributes_keys() + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - @property - def extras(self): - """Return the complete extras dictionary. - - .. warning:: While the node is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extras - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. - - :return: the extras as a dictionary + :param path: the relative path of the object within the repository. + :param key: fully qualified identifier for the object within the repository + :return: a list of `File` named tuples representing the objects present in directory with the given path + :raises FileNotFoundError: if the `path` does not exist in the repository of this node """ - extras = self.backend_entity.extras - - if self.is_stored: - extras = copy.deepcopy(extras) + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key - return extras + return self._repository.list_objects(path) - def get_extra(self, key, default=_NO_DEFAULT): - """Return the value of an extra. + def list_object_names(self, path=None, key=None): + """Return a list of the object names contained in this repository, optionally in the given sub directory. - .. warning:: While the node is unstored, this will return a reference of the extra on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extra - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - :param key: name of the extra - :param default: return this value instead of raising if the attribute does not exist - :return: the value of the extra - :raises AttributeError: if the extra does not exist and no default is specified + :param path: the relative path of the object within the repository. 
+ :param key: fully qualified identifier for the object within the repository + :return: a list of `File` named tuples representing the objects present in directory with the given path """ - try: - extra = self.backend_entity.get_extra(key) - except AttributeError: - if default is _NO_DEFAULT: - raise - extra = default + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key - if self.is_stored: - extra = copy.deepcopy(extra) + return self._repository.list_object_names(path) - return extra + def open(self, path=None, mode='r', key=None): + """Open a file handle to the object with the given path. - def get_extra_many(self, keys): - """Return the values of multiple extras. + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - .. warning:: While the node is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the node is stored, the returned extras - will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. + .. deprecated:: 1.4.0 + Starting from `v2.0.0` this will raise if not used in a context manager. - :param keys: a list of extra names - :return: a list of extra values - :raises AttributeError: if at least one extra does not exist + :param path: the relative path of the object within the repository. + :param key: fully qualified identifier for the object within the repository + :param mode: the mode under which to open the handle """ - extras = self.backend_entity.get_extra_many(keys) - - if self.is_stored: - extras = copy.deepcopy(extras) + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key - return extras + if path is None: + raise TypeError("open() missing 1 required positional argument: 'path'") - def set_extra(self, key, value): - """Set an extra to the given value. + if mode not in ['r', 'rb']: + warnings.warn("from v2.0 only the modes 'r' and 'rb' will be accepted", AiidaDeprecationWarning) # pylint: disable=no-member - :param key: name of the extra - :param value: value of the extra - :raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods - """ - validate_attribute_extra_key(key) - self.backend_entity.set_extra(key, value) + return WarnWhenNotEntered(self._repository.open(path, mode), repr(self)) - def set_extra_many(self, extras): - """Set multiple extras. + def get_object(self, path=None, key=None): + """Return the object with the given path. - .. note:: This will override any existing extras that are present in the new dictionary. + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. 
- :param extras: a dictionary with the extras to set - :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods + :param path: the relative path of the object within the repository. + :param key: fully qualified identifier for the object within the repository + :return: a `File` named tuple """ - for key in extras: - validate_attribute_extra_key(key) - - self.backend_entity.set_extra_many(extras) + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key - def reset_extras(self, extras): - """Reset the extras. + if path is None: + raise TypeError("get_object() missing 1 required positional argument: 'path'") - .. note:: This will completely clear any existing extras and replace them with the new dictionary. - - :param extras: a dictionary with the extras to set - :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods - """ - for key in extras: - validate_attribute_extra_key(key) + return self._repository.get_object(path) - self.backend_entity.reset_extras(extras) + def get_object_content(self, path=None, mode='r', key=None): + """Return the content of an object with the given path. - def delete_extra(self, key): - """Delete an extra. + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - :param key: name of the extra - :raises AttributeError: if the extra does not exist + :param path: the relative path of the object within the repository. + :param key: fully qualified identifier for the object within the repository """ - self.backend_entity.delete_extra(key) - - def delete_extra_many(self, keys): - """Delete multiple extras. + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key - :param keys: names of the extras to delete - :raises AttributeError: if at least one of the extra does not exist - """ - self.backend_entity.delete_extra_many(keys) + if path is None: + raise TypeError("get_object_content() missing 1 required positional argument: 'path'") - def clear_extras(self): - """Delete all extras.""" - self.backend_entity.clear_extras() + if mode not in ['r', 'rb']: + warnings.warn("from v2.0 only the modes 'r' and 'rb' will be accepted", AiidaDeprecationWarning) # pylint: disable=no-member - def extras_items(self): - """Return an iterator over the extras. + return self._repository.get_object_content(path, mode) - :return: an iterator with extra key value pairs - """ - return self.backend_entity.extras_items() + def put_object_from_tree(self, filepath, path=None, contents_only=True, force=False, key=None): + """Store a new object under `path` with the contents of the directory located at `filepath` on this file system. - def extras_keys(self): - """Return an iterator over the extra keys. + .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. + This check can be avoided by using the `force` flag, but this should be used with extreme caution! - :return: an iterator with extra keys - """ - return self.backend_entity.extras_keys() + ..
deprecated:: 1.4.0 + First positional argument `path` has been deprecated and renamed to `filepath`. - def list_objects(self, key=None): - """Return a list of the objects contained in this repository, optionally in the given sub directory. + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - :param key: fully qualified identifier for the object within the repository - :return: a list of `File` named tuples representing the objects present in directory with the given key - :raises FileNotFoundError: if the `path` does not exist in the repository of this node - """ - return self._repository.list_objects(key) + .. deprecated:: 1.4.0 + Keyword `force` is deprecated and will be removed in `v2.0.0`. - def list_object_names(self, key=None): - """Return a list of the object names contained in this repository, optionally in the given sub directory. + .. deprecated:: 1.4.0 + Keyword `contents_only` is deprecated and will be removed in `v2.0.0`. + :param filepath: absolute path of directory whose contents to copy to the repository + :param path: the relative path of the object within the repository. :param key: fully qualified identifier for the object within the repository - :return: a list of `File` named tuples representing the objects present in directory with the given key + :param contents_only: boolean, if True, omit the top level directory of the path and only copy its contents. + :param force: boolean, if True, will skip the mutability check + :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` """ - return self._repository.list_object_names(key) + if force: + warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - def open(self, key, mode='r'): - """Open a file handle to an object stored under the given key. + if contents_only is False: + warnings.warn( + 'the `contents_only` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning + ) # pylint: disable=no-member - :param key: fully qualified identifier for the object within the repository - :param mode: the mode under which to open the handle - """ - return self._repository.open(key, mode) - - def get_object(self, key): - """Return the object identified by key. + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key - :param key: fully qualified identifier for the object within the repository - :return: a `File` named tuple representing the object located at key - """ - return self._repository.get_object(key) + self._repository.put_object_from_tree(filepath, path, contents_only, force) - def get_object_content(self, key, mode='r'): - """Return the content of a object identified by key. - - :param key: fully qualified identifier for the object within the repository - """ - return self._repository.get_object_content(key, mode) - - def put_object_from_tree(self, path, key=None, contents_only=True, force=False): - """Store a new object under `key` with the contents of the directory located at `path` on this file system. + def put_object_from_file(self, filepath, path=None, mode=None, encoding=None, force=False, key=None): + """Store a new object under `path` with contents of the file located at `filepath` on this file system. 
.. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. This check can be avoided by using the `force` flag, but this should be used with extreme caution! - :param path: absolute path of directory whose contents to copy to the repository - :param key: fully qualified identifier for the object within the repository - :param contents_only: boolean, if True, omit the top level directory of the path and only copy its contents. - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - self._repository.put_object_from_tree(path, key, contents_only, force) + .. deprecated:: 1.4.0 + First positional argument `path` has been deprecated and renamed to `filepath`. - def put_object_from_file(self, path, key, mode=None, encoding=None, force=False): - """Store a new object under `key` with contents of the file located at `path` on this file system. + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! + .. deprecated:: 1.4.0 + Keyword `force` is deprecated and will be removed in `v2.0.0`. - :param path: absolute path of file whose contents to copy to the repository + :param filepath: absolute path of file whose contents to copy to the repository + :param path: the relative path where to store the object in the repository. :param key: fully qualified identifier for the object within the repository :param mode: the file mode with which the object will be written Deprecated: will be removed in `v2.0.0` @@ -703,42 +564,100 @@ def put_object_from_file(self, path, key, mode=None, encoding=None, force=False) # order to detect when they were being passed such that the deprecation warning can be emitted. The defaults did # not make sense and so ignoring them is justified, since the side-effect of this function, a file being copied, # will continue working the same. + if force: + warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member + if mode is not None: - warnings.warn('the `mode` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning) # pylint: disable=no-member + warnings.warn('the `mode` argument is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member if encoding is not None: warnings.warn( # pylint: disable=no-member 'the `encoding` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning ) - self._repository.put_object_from_file(path, key, mode, encoding, force) + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key - def put_object_from_filelike(self, handle, key, mode='w', encoding='utf8', force=False): - """Store a new object under `key` with contents of filelike object `handle`. 
+ if path is None: + raise TypeError("put_object_from_file() missing 1 required positional argument: 'path'") + + self._repository.put_object_from_file(filepath, path, mode, encoding, force) + + def put_object_from_filelike(self, handle, path=None, mode='w', encoding='utf8', force=False, key=None): + """Store a new object under `path` with contents of filelike object `handle`. .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. This check can be avoided by using the `force` flag, but this should be used with extreme caution! + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. + + .. deprecated:: 1.4.0 + Keyword `force` is deprecated and will be removed in `v2.0.0`. + :param handle: filelike object with the content to be stored + :param path: the relative path where to store the object in the repository. :param key: fully qualified identifier for the object within the repository :param mode: the file mode with which the object will be written :param encoding: the file encoding with which the object will be written :param force: boolean, if True, will skip the mutability check :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` """ - self._repository.put_object_from_filelike(handle, key, mode, encoding, force) + if force: + warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - def delete_object(self, key, force=False): + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key + + if path is None: + raise TypeError("put_object_from_filelike() missing 1 required positional argument: 'path'") + + self._repository.put_object_from_filelike(handle, path, mode, encoding, force) + + def delete_object(self, path=None, force=False, key=None): """Delete the object from the repository. .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. This check can be avoided by using the `force` flag, but this should be used with extreme caution! + .. deprecated:: 1.4.0 + Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. + + .. deprecated:: 1.4.0 + Keyword `force` is deprecated and will be removed in `v2.0.0`. + :param key: fully qualified identifier for the object within the repository :param force: boolean, if True, will skip the mutability check :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` """ - self._repository.delete_object(key, force) + if force: + warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member + + if key is not None: + if path is not None: + raise ValueError('cannot specify both `path` and `key`.') + warnings.warn( + 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member + path = key + + if path is None: + raise TypeError("delete_object() missing 1 required positional argument: 'path'") + + self._repository.delete_object(path, force) def add_comment(self, content, user=None): """Add a new comment. 
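The node.py changes above deprecate the `key` keyword of the repository-access methods in favour of `path`, and wrap the handle returned by `Node.open` in `WarnWhenNotEntered` so that using it outside a `with` block emits an `AiidaDeprecationWarning`. A minimal usage sketch of the migration, assuming a stored node that happens to contain an object named 'data.txt' in its repository (the pk 1234 and the file name are hypothetical):

from aiida import load_profile, orm

load_profile()
node = orm.load_node(1234)  # hypothetical pk of a stored node with 'data.txt' in its repository

# Deprecated after this change: the `key` keyword still works but emits AiidaDeprecationWarning.
content = node.get_object_content(key='data.txt')

# Also deprecated: using the handle without a context manager; the WarnWhenNotEntered wrapper
# warns on `read`/`close` and, per the warning message, this will raise from aiida-core 2.0.
handle = node.open('data.txt')
content = handle.read()
handle.close()

# Recommended form: pass the relative `path` positionally and use a context manager.
with node.open('data.txt', mode='r') as handle:
    content = handle.read()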
diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index bd2be3dd50..5ac0c2f682 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines """ The QueryBuilder: A class that allows you to query the AiiDA database, independent from backend. Note that the backend implementation is enforced and handled with a composition model! @@ -18,19 +19,18 @@ An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance when instantiated by the user. """ -# Checking for correct input with the inspect module from inspect import isclass as inspect_isclass import copy import logging import warnings + from sqlalchemy import and_, or_, not_, func as sa_func, select, join from sqlalchemy.types import Integer from sqlalchemy.orm import aliased -from sqlalchemy.sql.expression import cast +from sqlalchemy.sql.expression import cast as type_cast from sqlalchemy.dialects.postgresql import array from aiida.common.exceptions import InputValidationError -# The way I get column as a an attribute to the orm class from aiida.common.links import LinkType from aiida.manage.manager import get_manager from aiida.common.exceptions import ConfigurationError @@ -58,18 +58,19 @@ GROUP_ENTITY_TYPE_PREFIX = 'group.' -def get_querybuilder_classifiers_from_cls(cls, qb): +def get_querybuilder_classifiers_from_cls(cls, query): # pylint: disable=invalid-name """ Return the correct classifiers for the QueryBuilder from an ORM class. :param cls: an AiiDA ORM class or backend ORM class. - :param qb: an instance of the appropriate QueryBuilder backend. + :param query: an instance of the appropriate QueryBuilder backend. :returns: the ORM class as well as a dictionary with additional classifier strings :rtype: cls, dict Note: the ormclass_type_string is currently hardcoded for group, computer etc. One could instead use something like aiida.orm.utils.node.get_type_string_from_class(cls.__module__, cls.__name__) """ + # pylint: disable=protected-access,too-many-branches,too-many-statements # Note: Unable to move this import to the top of the module for some reason from aiida.engine import Process from aiida.orm.utils.node import is_valid_node_type_string @@ -79,63 +80,63 @@ def get_querybuilder_classifiers_from_cls(cls, qb): classifiers['process_type_string'] = None # Nodes - if issubclass(cls, qb.Node): + if issubclass(cls, query.Node): # If a backend ORM node (i.e. DbNode) is passed. # Users shouldn't do that, by why not... 
- classifiers['ormclass_type_string'] = qb.AiidaNode._plugin_type_string + classifiers['ormclass_type_string'] = query.AiidaNode._plugin_type_string ormclass = cls - elif issubclass(cls, qb.AiidaNode): + elif issubclass(cls, query.AiidaNode): classifiers['ormclass_type_string'] = cls._plugin_type_string - ormclass = qb.Node + ormclass = query.Node # Groups: - elif issubclass(cls, qb.Group): + elif issubclass(cls, query.Group): classifiers['ormclass_type_string'] = GROUP_ENTITY_TYPE_PREFIX + cls._type_string ormclass = cls elif issubclass(cls, groups.Group): classifiers['ormclass_type_string'] = GROUP_ENTITY_TYPE_PREFIX + cls._type_string - ormclass = qb.Group + ormclass = query.Group # Computers: - elif issubclass(cls, qb.Computer): + elif issubclass(cls, query.Computer): classifiers['ormclass_type_string'] = 'computer' ormclass = cls elif issubclass(cls, computers.Computer): classifiers['ormclass_type_string'] = 'computer' - ormclass = qb.Computer + ormclass = query.Computer # Users - elif issubclass(cls, qb.User): + elif issubclass(cls, query.User): classifiers['ormclass_type_string'] = 'user' ormclass = cls elif issubclass(cls, users.User): classifiers['ormclass_type_string'] = 'user' - ormclass = qb.User + ormclass = query.User # AuthInfo - elif issubclass(cls, qb.AuthInfo): + elif issubclass(cls, query.AuthInfo): classifiers['ormclass_type_string'] = 'authinfo' ormclass = cls elif issubclass(cls, authinfos.AuthInfo): classifiers['ormclass_type_string'] = 'authinfo' - ormclass = qb.AuthInfo + ormclass = query.AuthInfo # Comment - elif issubclass(cls, qb.Comment): + elif issubclass(cls, query.Comment): classifiers['ormclass_type_string'] = 'comment' ormclass = cls elif issubclass(cls, comments.Comment): classifiers['ormclass_type_string'] = 'comment' - ormclass = qb.Comment + ormclass = query.Comment # Log - elif issubclass(cls, qb.Log): + elif issubclass(cls, query.Log): classifiers['ormclass_type_string'] = 'log' ormclass = cls elif issubclass(cls, logs.Log): classifiers['ormclass_type_string'] = 'log' - ormclass = qb.Log + ormclass = query.Log # Process # This is a special case, since Process is not an ORM class. @@ -143,23 +144,23 @@ def get_querybuilder_classifiers_from_cls(cls, qb): elif issubclass(cls, Process): classifiers['ormclass_type_string'] = cls._node_class._plugin_type_string classifiers['process_type_string'] = cls.build_process_type() - ormclass = qb.Node + ormclass = query.Node else: raise InputValidationError('I do not know what to do with {}'.format(cls)) - if ormclass == qb.Node: + if ormclass == query.Node: is_valid_node_type_string(classifiers['ormclass_type_string'], raise_on_false=True) return ormclass, classifiers -def get_querybuilder_classifiers_from_type(ormclass_type_string, qb): +def get_querybuilder_classifiers_from_type(ormclass_type_string, query): # pylint: disable=invalid-name """ Return the correct classifiers for the QueryBuilder from an ORM type string. :param ormclass_type_string: type string for ORM class - :param qb: an instance of the appropriate QueryBuilder backend. + :param query: an instance of the appropriate QueryBuilder backend. 
:returns: the ORM class as well as a dictionary with additional classifier strings :rtype: cls, dict @@ -174,18 +175,18 @@ def get_querybuilder_classifiers_from_type(ormclass_type_string, qb): if classifiers['ormclass_type_string'].startswith(GROUP_ENTITY_TYPE_PREFIX): classifiers['ormclass_type_string'] = 'group.core' - ormclass = qb.Group + ormclass = query.Group elif classifiers['ormclass_type_string'] == 'computer': - ormclass = qb.Computer + ormclass = query.Computer elif classifiers['ormclass_type_string'] == 'user': - ormclass = qb.User + ormclass = query.User else: # At this point, we assume it is a node. The only valid type string then is a string # that matches exactly the _plugin_type_string of a node class classifiers['ormclass_type_string'] = ormclass_type_string # no lowercase - ormclass = qb.Node + ormclass = query.Node - if ormclass == qb.Node: + if ormclass == query.Node: is_valid_node_type_string(classifiers['ormclass_type_string'], raise_on_false=True) return ormclass, classifiers @@ -233,7 +234,6 @@ def get_process_type_filter(classifiers, subclassing): from aiida.common.escaping import escape_for_sql_like from aiida.common.warnings import AiidaEntryPointWarning from aiida.engine.processes.process import get_query_string_from_process_type_string - import warnings value = classifiers['process_type_string'] @@ -246,10 +246,16 @@ def get_process_type_filter(classifiers, subclassing): # Note: the process_type_string stored in the database does *not* end in a dot. # In order to avoid that querying for class 'Begin' will also find class 'BeginEnd', # we need to search separately for equality and 'like'. - filters = {'or': [ - {'==': value}, - {'like': escape_for_sql_like(get_query_string_from_process_type_string(value))}, - ]} + filters = { + 'or': [ + { + '==': value + }, + { + 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) + }, + ] + } elif value.startswith('aiida.engine'): # For core process types, a filter is not needed since each process type has a corresponding # ormclass type that already specifies everything. @@ -259,14 +265,21 @@ def get_process_type_filter(classifiers, subclassing): # Note: Improve this when issue #2475 is addressed filters = {'like': '%'} else: - warnings.warn("Process type '{}' does not correspond to a registered entry. " - 'This risks queries to fail once the location of the process class changes. ' - "Add an entry point for '{}' to remove this warning.".format(value, value), - AiidaEntryPointWarning) - filters = {'or': [ - {'==': value}, - {'like': escape_for_sql_like(get_query_string_from_process_type_string(value))}, - ]} + warnings.warn( + "Process type '{value}' does not correspond to a registered entry. " + 'This risks queries to fail once the location of the process class changes. 
' + "Add an entry point for '{value}' to remove this warning.".format(value=value), AiidaEntryPointWarning + ) + filters = { + 'or': [ + { + '==': value + }, + { + 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) + }, + ] + } return filters @@ -314,6 +327,8 @@ class QueryBuilder: """ + # pylint: disable=too-many-instance-attributes,too-many-public-methods + # This tag defines how edges are tagged (labeled) by the QueryBuilder default # namely tag of first entity + _EDGE_TAG_DELIM + tag of second entity _EDGE_TAG_DELIM = '--' @@ -362,12 +377,18 @@ def __init__(self, backend=None, **kwargs): # A dictionary tag:alias of ormclass # redundant but makes life easier self.tag_to_alias_map = {} + self.tag_to_projected_property_dict = {} # A dictionary tag: filter specification for this alias self._filters = {} # A dictionary tag: projections for this alias self._projections = {} + self.nr_of_projections = 0 + + self._attrkeys_as_in_sql_result = None + + self._query = None # A dictionary for classes passed to the tag given to them # Everything is specified with unique tags, which are strings. @@ -385,20 +406,10 @@ def __init__(self, backend=None, **kwargs): # is used twice. In that case, the user has to provide a tag! self._cls_to_tag_map = {} - # Hashing the the internal queryhelp allows me to avoid to build a query again, if i have used - # it already. - # Example: - - ## User is building a query: - # qb = QueryBuilder().append(.....) - ## User asks for the first results: - # qb.first() - ## User asks for all results, of the same query: - # qb.all() - # In above example, I can reuse the query, and to track whether somethis was changed - # I record a hash: + # Hashing the internal queryhelp allows me to avoid building a query again self._hash = None - ## The hash being None implies that the query will be build (Check the code in .get_query + + # The hash being None implies that the query will be built (Check the code in .get_query) # The user can inject a query, this keyword stores whether this was done. # Check QueryBuilder.inject_query self._injected = False @@ -414,7 +425,6 @@ def __init__(self, backend=None, **kwargs): for path_spec in path: if isinstance(path_spec, dict): self.append(**path_spec) - # ~ except TypeError as e: elif isinstance(path_spec, str): # Maybe it is just a string, # I assume user means the type @@ -454,10 +464,12 @@ def __init__(self, backend=None, **kwargs): # If kwargs is not empty, there is a problem: if kwargs: valid_keys = ('path', 'filters', 'project', 'limit', 'offset', 'order_by') - raise InputValidationError('Received additional keywords: {}' - '\nwhich I cannot process' - '\nValid keywords are: {}' - ''.format(list(kwargs.keys()), valid_keys)) + raise InputValidationError( + 'Received additional keywords: {}' + '\nwhich I cannot process' + '\nValid keywords are: {}' + ''.format(list(kwargs.keys()), valid_keys) + ) def __str__(self): """ @@ -504,9 +516,9 @@ def _get_ormclass(self, cls, ormclass_type_string): ormclass = None classifiers = [] - for i, c in enumerate(input_info): - new_ormclass, new_classifiers = func(c, self._impl) - if i: + for index, classifier in enumerate(input_info): + new_ormclass, new_classifiers = func(classifier, self._impl) + if index: # This is not my first iteration! 
# I check consistency with what was specified before if new_ormclass != ormclass: @@ -556,8 +568,8 @@ def get_tag_from_type(classifiers): """ if isinstance(classifiers, list): return '-'.join([t['ormclass_type_string'].rstrip('.').split('.')[-1] or 'node' for t in classifiers]) - else: - return classifiers['ormclass_type_string'].rstrip('.').split('.')[-1] or 'node' + + return classifiers['ormclass_type_string'].rstrip('.').split('.')[-1] or 'node' basetag = get_tag_from_type(classifiers) tags_used = self.tag_to_alias_map.keys() @@ -568,18 +580,20 @@ def get_tag_from_type(classifiers): raise RuntimeError('Cannot find a tag after 100 tries') - def append(self, - cls=None, - entity_type=None, - tag=None, - filters=None, - project=None, - subclassing=True, - edge_tag=None, - edge_filters=None, - edge_project=None, - outerjoin=False, - **kwargs): + def append( + self, + cls=None, + entity_type=None, + tag=None, + filters=None, + project=None, + subclassing=True, + edge_tag=None, + edge_filters=None, + edge_project=None, + outerjoin=False, + **kwargs + ): """ Any iterative procedure to build the path for a graph query needs to invoke this method to append to the path. @@ -638,31 +652,35 @@ def append(self, :return: self :rtype: :class:`aiida.orm.QueryBuilder` """ + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements # INPUT CHECKS ########################## # This function can be called by users, so I am checking the # input now. # First of all, let's make sure the specified # the class or the type (not both) - if cls and entity_type: - raise InputValidationError('You cannot specify both a class ({}) and a entity_type ({})'.format(cls, entity_type)) - if not (cls or entity_type): + if cls is not None and entity_type is not None: + raise InputValidationError( + 'You cannot specify both a class ({}) and a entity_type ({})'.format(cls, entity_type) + ) + + if cls is None and entity_type is None: raise InputValidationError('You need to specify at least a class or a entity_type') # Let's check if it is a valid class or type if cls: if isinstance(cls, (tuple, list, set)): - for c in cls: - if not inspect_isclass(c): - raise InputValidationError("{} was passed with kw 'cls', but is not a class".format(c)) + for sub_cls in cls: + if not inspect_isclass(sub_cls): + raise InputValidationError("{} was passed with kw 'cls', but is not a class".format(sub_cls)) else: if not inspect_isclass(cls): raise InputValidationError("{} was passed with kw 'cls', but is not a class".format(cls)) - elif entity_type: + elif entity_type is not None: if isinstance(entity_type, (tuple, list, set)): - for t in entity_type: - if not isinstance(t, str): - raise InputValidationError('{} was passed as entity_type, but is not a string'.format(t)) + for sub_type in entity_type: + if not isinstance(sub_type, str): + raise InputValidationError('{} was passed as entity_type, but is not a string'.format(sub_type)) else: if not isinstance(entity_type, str): raise InputValidationError('{} was passed as entity_type, but is not a string'.format(entity_type)) @@ -673,10 +691,11 @@ def append(self, # Let's get a tag if tag: if self._EDGE_TAG_DELIM in tag: - raise InputValidationError('tag cannot contain {}\n' - 'since this is used as a delimiter for links' - ''.format(self._EDGE_TAG_DELIM)) - tag = tag + raise InputValidationError( + 'tag cannot contain {}\n' + 'since this is used as a delimiter for links' + ''.format(self._EDGE_TAG_DELIM) + ) if tag in self.tag_to_alias_map.keys(): raise 
InputValidationError('This tag ({}) is already in use'.format(tag)) else: @@ -687,7 +706,6 @@ def append(self, # Now, several things can go wrong along the way, so I need to split into # atomic blocks that I can reverse if something goes wrong. # TAG MAPPING ################################# - # TODO check with duplicate classes # Let's fill the cls_to_tag_map so that one can specify # this vertice in a joining specification later @@ -716,10 +734,10 @@ def append(self, # ALIASING ############################## try: self.tag_to_alias_map[tag] = aliased(ormclass) - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append, cleaning up') - print(' ', e) + print(' ', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) @@ -742,10 +760,10 @@ def append(self, # if the user specified a filter, add it: if filters is not None: self.add_filter(tag, filters) - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append (part filters), cleaning up') - print(' ', e) + print(' ', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag) @@ -757,18 +775,19 @@ def append(self, self._projections[tag] = [] if project is not None: self.add_projection(tag, project) - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append (part projections), cleaning up') - print(' ', e) + print(' ', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) self._filters.pop(tag) self._projections.pop(tag) - raise e + raise exception # JOINING ##################################### + # pylint: disable=too-many-nested-blocks try: # Get the functions that are implemented: spec_to_function_map = [] @@ -784,12 +803,16 @@ def append(self, '{} is not a valid keyword ' 'for joining specification\n' 'Valid keywords are: ' - '{}'.format(key, - spec_to_function_map + ['cls', 'type', 'tag', 'autotag', 'filters', 'project'])) + '{}'.format( + key, spec_to_function_map + ['cls', 'type', 'tag', 'autotag', 'filters', 'project'] + ) + ) elif joining_keyword: - raise InputValidationError('You already specified joining specification {}\n' - 'But you now also want to specify {}' - ''.format(joining_keyword, key)) + raise InputValidationError( + 'You already specified joining specification {}\n' + 'But you now also want to specify {}' + ''.format(joining_keyword, key) + ) else: joining_keyword = key if joining_keyword == 'direction': @@ -804,9 +827,11 @@ def append(self, raise InputValidationError('direction=0 is not valid') joining_value = self._path[-abs(val)]['tag'] except IndexError as exc: - raise InputValidationError('You have specified a non-existent entity with\n' - 'direction={}\n' - '{}\n'.format(joining_value, exc)) + raise InputValidationError( + 'You have specified a non-existent entity with\n' + 'direction={}\n' + '{}\n'.format(joining_value, exc) + ) else: joining_value = self._get_tag_from_specification(val) # the default is that this vertice is 'with_incoming' as the previous one @@ -814,18 +839,17 @@ def append(self, joining_keyword = 'with_incoming' joining_value = self._path[-1]['tag'] - - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append (part joining), cleaning up') - print(' ', e) + print(' ', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) 
self.tag_to_alias_map.pop(tag, None) self._filters.pop(tag) self._projections.pop(tag) # There's not more to clean up here! - raise e + raise exception # EDGES ################################# if len(self._path) > 0: @@ -856,7 +880,7 @@ def append(self, self._projections[edge_tag] = [] if edge_project is not None: self.add_projection(edge_tag, edge_project) - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append (part joining), cleaning up') @@ -872,7 +896,7 @@ def append(self, self._filters.pop(edge_tag, None) self._projections.pop(edge_tag, None) # There's not more to clean up here! - raise e + raise exception # EXTENDING THE PATH ################################# # Note: 'type' being a list is a relict of an earlier implementation @@ -889,7 +913,9 @@ def append(self, joining_keyword=joining_keyword, joining_value=joining_value, outerjoin=outerjoin, - edge_tag=edge_tag)) + edge_tag=edge_tag + ) + ) return self @@ -927,7 +953,7 @@ def order_by(self, order_by): qb.append(Node, tag='node') qb.order_by({'node':[{'id':'desc'}]}) """ - + # pylint: disable=too-many-nested-blocks,too-many-branches self._order_by = [] allowed_keys = ('cast', 'order') possible_orders = ('asc', 'desc') @@ -937,10 +963,12 @@ def order_by(self, order_by): for order_spec in order_by: if not isinstance(order_spec, dict): - raise InputValidationError('Invalid input for order_by statement: {}\n' - 'I am expecting a dictionary ORMClass,' - '[columns to sort]' - ''.format(order_spec)) + raise InputValidationError( + 'Invalid input for order_by statement: {}\n' + 'I am expecting a dictionary ORMClass,' + '[columns to sort]' + ''.format(order_spec) + ) _order_spec = {} for tagspec, items_to_order_by in order_spec.items(): if not isinstance(items_to_order_by, (tuple, list)): @@ -953,9 +981,11 @@ def order_by(self, order_by): elif isinstance(item_to_order_by, dict): pass else: - raise InputValidationError('Cannot deal with input to order_by {}\n' - 'of type{}' - '\n'.format(item_to_order_by, type(item_to_order_by))) + raise InputValidationError( + 'Cannot deal with input to order_by {}\n' + 'of type{}' + '\n'.format(item_to_order_by, type(item_to_order_by)) + ) for entityname, orderspec in item_to_order_by.items(): # if somebody specifies eg {'node':{'id':'asc'}} # tranform to {'node':{'id':{'order':'asc'}}} @@ -965,21 +995,27 @@ def order_by(self, order_by): elif isinstance(orderspec, dict): this_order_spec = orderspec else: - raise InputValidationError('I was expecting a string or a dictionary\n' - 'You provided {} {}\n' - ''.format(type(orderspec), orderspec)) - for key in this_order_spec.keys(): + raise InputValidationError( + 'I was expecting a string or a dictionary\n' + 'You provided {} {}\n' + ''.format(type(orderspec), orderspec) + ) + for key in this_order_spec: if key not in allowed_keys: - raise InputValidationError('The allowed keys for an order specification\n' - 'are {}\n' - '{} is not valid\n' - ''.format(', '.join(allowed_keys), key)) + raise InputValidationError( + 'The allowed keys for an order specification\n' + 'are {}\n' + '{} is not valid\n' + ''.format(', '.join(allowed_keys), key) + ) this_order_spec['order'] = this_order_spec.get('order', 'asc') if this_order_spec['order'] not in possible_orders: - raise InputValidationError('You gave {} as an order parameters,\n' - 'but it is not a valid order parameter\n' - 'Valid orders are: {}\n' - ''.format(this_order_spec['order'], possible_orders)) + raise InputValidationError( + 'You gave {} as an 
order parameters,\n' + 'but it is not a valid order parameter\n' + 'Valid orders are: {}\n' + ''.format(this_order_spec['order'], possible_orders) + ) item_to_order_by[entityname] = this_order_spec _order_spec[tag].append(item_to_order_by) @@ -1009,7 +1045,9 @@ def add_filter(self, tagspec, filter_spec): tag = self._get_tag_from_specification(tagspec) self._filters[tag].update(filters) - def _process_filters(self, filters): + @staticmethod + def _process_filters(filters): + """Process filters.""" if not isinstance(filters, dict): raise InputValidationError('Filters have to be passed as dictionaries') @@ -1036,8 +1074,8 @@ def _add_node_type_filter(self, tagspec, classifiers, subclassing): if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers entity_type_filter = {'or': []} - for c in classifiers: - entity_type_filter['or'].append(get_node_type_filter(c, subclassing)) + for classifier in classifiers: + entity_type_filter['or'].append(get_node_type_filter(classifier, subclassing)) else: entity_type_filter = get_node_type_filter(classifiers, subclassing) @@ -1056,9 +1094,9 @@ def _add_process_type_filter(self, tagspec, classifiers, subclassing): if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers process_type_filter = {'or': []} - for c in classifiers: - if c['process_type_string'] is not None: - process_type_filter['or'].append(get_process_type_filter(c, subclassing)) + for classifier in classifiers: + if classifier['process_type_string'] is not None: + process_type_filter['or'].append(get_process_type_filter(classifier, subclassing)) if len(process_type_filter['or']) > 0: self.add_filter(tagspec, {'process_type': process_type_filter}) @@ -1147,12 +1185,14 @@ def add_projection(self, tag_spec, projection_spec): _thisprojection = {projection: {}} else: raise InputValidationError('Cannot deal with projection specification {}\n'.format(projection)) - for p, spec in _thisprojection.items(): + for spec in _thisprojection.values(): if not isinstance(spec, dict): - raise InputValidationError('\nThe value of a key-value pair in a projection\n' - 'has to be a dictionary\n' - 'You gave: {}\n' - ''.format(spec)) + raise InputValidationError( + '\nThe value of a key-value pair in a projection\n' + 'has to be a dictionary\n' + 'You gave: {}\n' + ''.format(spec) + ) for key, val in spec.items(): if key not in self._VALID_PROJECTION_KEYS: @@ -1165,6 +1205,7 @@ def add_projection(self, tag_spec, projection_spec): self._projections[tag] = _projections def _get_projectable_entity(self, alias, column_name, attrpath, **entityspec): + """Return projectable entity for a given alias and column name.""" if attrpath or column_name in ('attributes', 'extras'): entity = self._impl.get_projectable_attribute(alias, column_name, attrpath, **entityspec) else: @@ -1186,10 +1227,12 @@ def _add_to_projections(self, alias, projectable_entity_name, cast=None, func=No if column_name == '*': if func is not None: - raise InputValidationError('Very sorry, but functions on the aliased class\n' - "(You specified '*')\n" - 'will not work!\n' - "I suggest you apply functions on a column, e.g. ('id')\n") + raise InputValidationError( + 'Very sorry, but functions on the aliased class\n' + "(You specified '*')\n" + 'will not work!\n' + "I suggest you apply functions on a column, e.g. 
('id')\n" + ) self._query = self._query.add_entity(alias) else: entity_to_project = self._get_projectable_entity(alias, column_name, attr_key, cast=cast) @@ -1205,11 +1248,8 @@ def _add_to_projections(self, alias, projectable_entity_name, cast=None, func=No raise InputValidationError('\nInvalid function specification {}'.format(func)) self._query = self._query.add_columns(entity_to_project) - def get_table_columns(self, table_alias): - raise NotImplementedError - def _build_projections(self, tag, items_to_project=None): - + """Build the projections for a given tag.""" if items_to_project is None: items_to_project = self._projections.get(tag, []) @@ -1248,16 +1288,19 @@ def _get_tag_from_specification(self, specification): if specification in self.tag_to_alias_map.keys(): tag = specification else: - raise InputValidationError('tag {} is not among my known tags\n' - 'My tags are: {}'.format(specification, self.tag_to_alias_map.keys())) + raise InputValidationError( + 'tag {} is not among my known tags\n' + 'My tags are: {}'.format(specification, self.tag_to_alias_map.keys()) + ) else: if specification in self._cls_to_tag_map.keys(): tag = self._cls_to_tag_map[specification] else: - raise InputValidationError('You specified as a class for which I have to find a tag\n' - 'The classes that I can do this for are:{}\n' - 'The tags I have are: {}'.format(specification, self._cls_to_tag_map.keys(), - self.tag_to_alias_map.keys())) + raise InputValidationError( + 'You specified as a class for which I have to find a tag\n' + 'The classes that I can do this for are:{}\n' + 'The tags I have are: {}'.format(self._cls_to_tag_map.keys(), self.tag_to_alias_map.keys()) + ) return tag def set_debug(self, debug): @@ -1337,7 +1380,7 @@ def _build_filters(self, alias, filter_spec): raise if not isinstance(filter_operation_dict, dict): filter_operation_dict = {'==': filter_operation_dict} - [ + for operator, value in filter_operation_dict.items(): expressions.append( self._impl.get_filter_expr( operator, @@ -1346,8 +1389,9 @@ def _build_filters(self, alias, filter_spec): is_attribute=is_attribute, column=column, column_name=column_name, - alias=alias)) for operator, value in filter_operation_dict.items() - ] + alias=alias + ) + ) return and_(*expressions) @staticmethod @@ -1365,22 +1409,25 @@ def _check_dbentities(entities_cls_joined, entities_cls_to_join, relationship): The relationship between the two entities to make the Exception comprehensible """ + # pylint: disable=protected-access for entity, cls in (entities_cls_joined, entities_cls_to_join): if not issubclass(entity._sa_class_manager.class_, cls): - raise InputValidationError("You are attempting to join {} as '{}' of {}\n" - 'This failed because you passed:\n' - ' - {} as entity joined (expected {})\n' - ' - {} as entity to join (expected {})\n' - '\n'.format( - entities_cls_joined[0].__name__, - relationship, - entities_cls_to_join[0].__name__, - entities_cls_joined[0]._sa_class_manager.class_.__name__, - entities_cls_joined[1].__name__, - entities_cls_to_join[0]._sa_class_manager.class_.__name__, - entities_cls_to_join[1].__name__, - )) + raise InputValidationError( + "You are attempting to join {} as '{}' of {}\n" + 'This failed because you passed:\n' + ' - {} as entity joined (expected {})\n' + ' - {} as entity to join (expected {})\n' + '\n'.format( + entities_cls_joined[0].__name__, + relationship, + entities_cls_to_join[0].__name__, + entities_cls_joined[0]._sa_class_manager.class_.__name__, + entities_cls_joined[1].__name__, + 
entities_cls_to_join[0]._sa_class_manager.class_.__name__, + entities_cls_to_join[1].__name__, + ) + ) def _join_outputs(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1394,9 +1441,12 @@ def _join_outputs(self, joined_entity, entity_to_join, isouterjoin): self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_incoming') aliased_edge = aliased(self._impl.Link) - self._query = self._query.join( - aliased_edge, aliased_edge.input_id == joined_entity.id, isouter=isouterjoin).join( - entity_to_join, aliased_edge.output_id == entity_to_join.id, isouter=isouterjoin) + self._query = self._query.join(aliased_edge, aliased_edge.input_id == joined_entity.id, + isouter=isouterjoin).join( + entity_to_join, + aliased_edge.output_id == entity_to_join.id, + isouter=isouterjoin + ) return aliased_edge def _join_inputs(self, joined_entity, entity_to_join, isouterjoin): @@ -1415,8 +1465,7 @@ def _join_inputs(self, joined_entity, entity_to_join, isouterjoin): self._query = self._query.join( aliased_edge, aliased_edge.output_id == joined_entity.id, - ).join( - entity_to_join, aliased_edge.input_id == entity_to_join.id, isouter=isouterjoin) + ).join(entity_to_join, aliased_edge.input_id == entity_to_join.id, isouter=isouterjoin) return aliased_edge def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin, filter_dict, expand_path=False): @@ -1426,8 +1475,7 @@ def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin :TODO: Pass an option to also show the path, if this is wanted. """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), - 'with_ancestors') + self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_ancestors') link1 = aliased(self._impl.Link) link2 = aliased(self._impl.Link) @@ -1437,7 +1485,7 @@ def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin selection_walk_list = [ link1.input_id.label('ancestor_id'), link1.output_id.label('descendant_id'), - cast(0, Integer).label('depth'), + type_cast(0, Integer).label('depth'), ] if expand_path: selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path')) @@ -1446,13 +1494,15 @@ def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin and_( in_recursive_filters, # I apply filters for speed here link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links - )).cte(recursive=True) + ) + ).cte(recursive=True) aliased_walk = aliased(walk) selection_union_list = [ aliased_walk.c.ancestor_id.label('ancestor_id'), - link2.output_id.label('descendant_id'), (aliased_walk.c.depth + cast(1, Integer)).label('current_depth') + link2.output_id.label('descendant_id'), + (aliased_walk.c.depth + type_cast(1, Integer)).label('current_depth') ] if expand_path: selection_union_list.append((aliased_walk.c.path + array((link2.output_id,))).label('path')) @@ -1464,13 +1514,17 @@ def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin aliased_walk, link2, link2.input_id == aliased_walk.c.descendant_id, - )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))))) # .alias() + ) + ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + ) + ) # .alias() self._query = self._query.join(descendants_recursive, descendants_recursive.c.ancestor_id == joined_entity.id).join( entity_to_join, descendants_recursive.c.descendant_id 
== entity_to_join.id, - isouter=isouterjoin) + isouter=isouterjoin + ) return descendants_recursive.c def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, filter_dict, expand_path=False): @@ -1480,8 +1534,7 @@ def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, :TODO: Pass an option to also show the path, if this is wanted. """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), - 'with_ancestors') + self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_ancestors') link1 = aliased(self._impl.Link) link2 = aliased(self._impl.Link) @@ -1491,21 +1544,21 @@ def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, selection_walk_list = [ link1.input_id.label('ancestor_id'), link1.output_id.label('descendant_id'), - cast(0, Integer).label('depth'), + type_cast(0, Integer).label('depth'), ] if expand_path: selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path')) walk = select(selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where( - and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, - LinkType.INPUT_CALC.value)))).cte(recursive=True) + and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + ).cte(recursive=True) aliased_walk = aliased(walk) selection_union_list = [ link2.input_id.label('ancestor_id'), aliased_walk.c.descendant_id.label('descendant_id'), - (aliased_walk.c.depth + cast(1, Integer)).label('current_depth'), + (aliased_walk.c.depth + type_cast(1, Integer)).label('current_depth'), ] if expand_path: selection_union_list.append((aliased_walk.c.path + array((link2.input_id,))).label('path')) @@ -1517,15 +1570,18 @@ def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, aliased_walk, link2, link2.output_id == aliased_walk.c.ancestor_id, - )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + ) + ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) # I can't follow RETURN or CALL links - )) + ) + ) self._query = self._query.join(ancestors_recursive, ancestors_recursive.c.descendant_id == joined_entity.id).join( entity_to_join, ancestors_recursive.c.ancestor_id == entity_to_join.id, - isouter=isouterjoin) + isouter=isouterjoin + ) return ancestors_recursive.c def _join_group_members(self, joined_entity, entity_to_join, isouterjoin): @@ -1544,7 +1600,8 @@ def _join_group_members(self, joined_entity, entity_to_join, isouterjoin): self._check_dbentities((joined_entity, self._impl.Group), (entity_to_join, self._impl.Node), 'with_group') aliased_group_nodes = aliased(self._impl.table_groups_nodes) self._query = self._query.join(aliased_group_nodes, aliased_group_nodes.c.dbgroup_id == joined_entity.id).join( - entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, isouter=isouterjoin) + entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, isouter=isouterjoin + ) return aliased_group_nodes def _join_groups(self, joined_entity, entity_to_join, isouterjoin): @@ -1560,7 +1617,8 @@ def _join_groups(self, joined_entity, entity_to_join, isouterjoin): self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Group), 'with_node') aliased_group_nodes = aliased(self._impl.table_groups_nodes) self._query = self._query.join(aliased_group_nodes, aliased_group_nodes.c.dbnode_id == joined_entity.id).join( - 
entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, isouter=isouterjoin) + entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, isouter=isouterjoin + ) return aliased_group_nodes def _join_creator_of(self, joined_entity, entity_to_join, isouterjoin): @@ -1587,7 +1645,8 @@ def _join_to_computer_used(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Computer), (entity_to_join, self._impl.Node), 'with_computer') self._query = self._query.join( - entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin) + entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin + ) def _join_computer(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1596,7 +1655,8 @@ def _join_computer(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Computer), 'with_node') self._query = self._query.join( - entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, isouter=isouterjoin) + entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, isouter=isouterjoin + ) def _join_group_user(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1621,7 +1681,8 @@ def _join_node_comment(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Comment), 'with_node') self._query = self._query.join( - entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) + entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin + ) def _join_comment_node(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1630,7 +1691,8 @@ def _join_comment_node(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Comment), (entity_to_join, self._impl.Node), 'with_comment') self._query = self._query.join( - entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) + entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin + ) def _join_node_log(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1639,7 +1701,8 @@ def _join_node_log(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Log), 'with_node') self._query = self._query.join( - entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) + entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin + ) def _join_log_node(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1648,7 +1711,8 @@ def _join_log_node(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Log), (entity_to_join, self._impl.Node), 'with_log') self._query = self._query.join( - entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) + entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin + ) def _join_user_comment(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1669,8 +1733,8 @@ def _join_comment_user(self, joined_entity, entity_to_join, isouterjoin): def _get_function_map(self): """ Map relationship type keywords to functions - The new mapping (since 1.0.0a5) is a two level dictionary. 
The first level defines the entity which has been passed to - the qb.append functon, and the second defines the relationship with respect to a given tag. + The new mapping (since 1.0.0a5) is a two level dictionary. The first level defines the entity which has been + passed to the qb.append functon, and the second defines the relationship with respect to a given tag. """ mapping = { 'node': { @@ -1722,6 +1786,7 @@ def _get_connecting_node(self, index, joining_keyword=None, joining_value=None, :param joining_keyword: the relation on which to join :param joining_value: the tag of the nodes to be joined """ + # pylint: disable=unused-argument # Set the calling entity - to allow for the correct join relation to be set entity_type = self._path[index]['entity_type'] @@ -1743,8 +1808,11 @@ def _get_connecting_node(self, index, joining_keyword=None, joining_value=None, try: func = self._get_function_map()[calling_entity][joining_keyword] except KeyError: - raise InputValidationError("'{}' is not a valid joining keyword for a '{}' type entity".format( - joining_keyword, calling_entity)) + raise InputValidationError( + "'{}' is not a valid joining keyword for a '{}' type entity".format( + joining_keyword, calling_entity + ) + ) if isinstance(joining_value, int): returnval = (self._aliased_path[joining_value], func) @@ -1752,10 +1820,10 @@ def _get_connecting_node(self, index, joining_keyword=None, joining_value=None, try: returnval = self.tag_to_alias_map[self._get_tag_from_specification(joining_value)], func except KeyError: - raise InputValidationError('Key {} is unknown to the types I know about:\n' - '{}'.format( - self._get_tag_from_specification(joining_value), - self.tag_to_alias_map.keys())) + raise InputValidationError( + 'Key {} is unknown to the types I know about:\n' + '{}'.format(self._get_tag_from_specification(joining_value), self.tag_to_alias_map.keys()) + ) return returnval def get_json_compatible_queryhelp(self): @@ -1820,9 +1888,11 @@ def _build_order(self, alias, entitytag, entityspec): column_name = entitytag.split('.')[0] attrpath = entitytag.split('.')[1:] if attrpath and 'cast' not in entityspec.keys(): - raise InputValidationError('In order to project ({}), I have to cast the the values,\n' - 'but you have not specified the datatype to cast to\n' - "You can do this with keyword 'cast'".format(entitytag)) + raise InputValidationError( + 'In order to project ({}), I have to cast the the values,\n' + 'but you have not specified the datatype to cast to\n' + "You can do this with keyword 'cast'".format(entitytag) + ) entity = self._get_projectable_entity(alias, column_name, attrpath, **entityspec) order = entityspec.get('order', 'asc') @@ -1834,17 +1904,10 @@ def _build(self): """ build the query and return a sqlalchemy.Query instance """ - - # self.tags_location_dict is a dictionary that - # maps the tag to its index in the list - # this is basically the mapping between the count - # of nodes traversed - # and the tag used for that node - self.tags_location_dict = {path['tag']: index for index, path in enumerate(self._path)} + # pylint: disable=too-many-branches # Starting the query by receiving a session - # Every subclass needs to have _get_session and give me the - # right session + # Every subclass needs to have _get_session and give me the right session firstalias = self.tag_to_alias_map[self._path[0]['tag']] self._query = self._impl.get_session().query(firstalias) @@ -1869,7 +1932,8 @@ def _build(self): expand_path = ((self._filters[edge_tag].get('path', None) is not None) 
or any(['path' in d.keys() for d in self._projections[edge_tag]])) aliased_edge = connection_func( - toconnectwith, alias, isouterjoin=isouterjoin, filter_dict=filter_dict, expand_path=expand_path) + toconnectwith, alias, isouterjoin=isouterjoin, filter_dict=filter_dict, expand_path=expand_path + ) else: aliased_edge = connection_func(toconnectwith, alias, isouterjoin=isouterjoin) if aliased_edge is not None: @@ -1881,9 +1945,10 @@ def _build(self): try: alias = self.tag_to_alias_map[tag] except KeyError: - # TODO Check KeyError before? - raise InputValidationError('You looked for tag {} among the alias list\n' - 'The tags I know are:\n{}'.format(tag, self.tag_to_alias_map.keys())) + raise InputValidationError( + 'You looked for tag {} among the alias list\n' + 'The tags I know are:\n{}'.format(tag, self.tag_to_alias_map.keys()) + ) self._query = self._query.filter(self._build_filters(alias, filter_specs)) ######################### PROJECTIONS ########################## @@ -1891,8 +1956,6 @@ def _build(self): # path was not meant to be projected # attribute of Query instance storing entities to project: - # Will be later set to this list: - entities = [] # Mapping between entities and the tag used/ given by user: self.tag_to_projected_property_dict = {} @@ -1919,17 +1982,19 @@ def _build(self): edge_tag = vertex.get('edge_tag', None) if self._debug: print('DEBUG: Checking projections for edges:') - print(' This is edge {} from {}, {} of {}'.format(edge_tag, vertex.get('tag'), - vertex.get('joining_keyword'), - vertex.get('joining_value'))) + print( + ' This is edge {} from {}, {} of {}'.format( + edge_tag, vertex.get('tag'), vertex.get('joining_keyword'), vertex.get('joining_value') + ) + ) if edge_tag is not None: self._build_projections(edge_tag) # ORDER ################################ for order_spec in self._order_by: - for tag, entities in order_spec.items(): + for tag, entity_list in order_spec.items(): alias = self.tag_to_alias_map[tag] - for entitydict in entities: + for entitydict in entity_list: for entitytag, entityspec in entitydict.items(): self._build_order(alias, entitytag, entityspec) @@ -1943,7 +2008,7 @@ def _build(self): ################ LAST BUT NOT LEAST ############################ # pop the entity that I added to start the query - self._query._entities.pop(0) + self._query._entities.pop(0) # pylint: disable=protected-access # Dirty solution coming up: # Sqlalchemy is by default de-duplicating results if possible. @@ -1953,7 +2018,7 @@ def _build(self): # We also addressed this with sqlachemy: # https://github.com/sqlalchemy/sqlalchemy/issues/4395#event-2002418814 # where the following solution was sanctioned: - self._query._has_mapper_entities = False + self._query._has_mapper_entities = False # pylint: disable=protected-access # We should monitor SQLAlchemy, for when a solution is officially supported by the API! 
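For orientation, here is a minimal, hedged usage sketch of the public interface that `_build` ultimately serves (it assumes a loaded AiiDA profile; the tag, filter and projection values are purely illustrative):

from aiida import load_profile
from aiida.orm import Node, QueryBuilder

load_profile()  # assumes an existing, configured AiiDA profile

qb = QueryBuilder()
# Append a vertex with an explicit tag, a filter and projections.
qb.append(Node, tag='node', filters={'id': {'>': 0}}, project=['id', 'uuid'])
# Order the projected results; the long form with an explicit 'order' key is also accepted.
qb.order_by({'node': [{'id': {'order': 'desc'}}]})

print(qb.count())          # number of matching rows
print(qb.first())          # projections of the first row
for pk, uuid in qb.all():  # iterate over all projected rows
    print(pk, uuid)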
# Make a list that helps the projection postprocessing @@ -2115,10 +2180,7 @@ def first(self): if len(result) != len(self._attrkeys_as_in_sql_result): raise Exception('length of query result does not match the number of specified projections') - return [ - self.get_aiida_entity_res(self._impl.get_aiida_res(rowitem)) - for colindex, rowitem in enumerate(result) - ] + return [self.get_aiida_entity_res(self._impl.get_aiida_res(rowitem)) for colindex, rowitem in enumerate(result)] def one(self): """ diff --git a/aiida/orm/utils/_repository.py b/aiida/orm/utils/_repository.py new file mode 100644 index 0000000000..aca52a0c08 --- /dev/null +++ b/aiida/orm/utils/_repository.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Class that represents the repository of a `Node` instance. + +.. deprecated:: 1.4.0 + This module has been deprecated and will be removed in `v2.0.0`. + +""" +import os +import warnings + +from aiida.common import exceptions +from aiida.common.folders import RepositoryFolder, SandboxFolder +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.repository import File, FileType + + +class Repository: + """Class that represents the repository of a `Node` instance. + + .. deprecated:: 1.4.0 + This class has been deprecated and will be removed in `v2.0.0`. + """ + + # Name to be used for the Repository section + _section_name = 'node' + + def __init__(self, uuid, is_stored, base_path=None): + self._is_stored = is_stored + self._base_path = base_path + self._temp_folder = None + self._repo_folder = RepositoryFolder(section=self._section_name, uuid=uuid) + + def __del__(self): + """Clean the sandboxfolder if it was instantiated.""" + if getattr(self, '_temp_folder', None) is not None: + self._temp_folder.erase() + + def validate_mutability(self): + """Raise if the repository is immutable. + + :raises aiida.common.ModificationNotAllowed: if repository is marked as immutable because the corresponding node + is stored + """ + if self._is_stored: + raise exceptions.ModificationNotAllowed('cannot modify the repository after the node has been stored') + + @staticmethod + def validate_object_key(key): + """Validate the key of an object. + + :param key: an object key in the repository + :raises ValueError: if the key is not a valid object key + """ + if key and os.path.isabs(key): + raise ValueError('the key must be a relative path') + + def list_objects(self, key=None): + """Return a list of the objects contained in this repository, optionally in the given sub directory. 
+ + :param key: fully qualified identifier for the object within the repository + :return: a list of `File` named tuples representing the objects present in directory with the given key + """ + folder = self._get_base_folder() + + if key: + folder = folder.get_subfolder(key) + + objects = [] + + for filename in folder.get_content_list(): + if os.path.isdir(os.path.join(folder.abspath, filename)): + objects.append(File(filename, FileType.DIRECTORY)) + else: + objects.append(File(filename, FileType.FILE)) + + return sorted(objects, key=lambda x: x.name) + + def list_object_names(self, key=None): + """Return a list of the object names contained in this repository, optionally in the given sub directory. + + :param key: fully qualified identifier for the object within the repository + :return: a list of `File` named tuples representing the objects present in directory with the given key + """ + return [entry.name for entry in self.list_objects(key)] + + def open(self, key, mode='r'): + """Open a file handle to an object stored under the given key. + + :param key: fully qualified identifier for the object within the repository + :param mode: the mode under which to open the handle + """ + return open(self._get_base_folder().get_abs_path(key), mode=mode) + + def get_object(self, key): + """Return the object identified by key. + + :param key: fully qualified identifier for the object within the repository + :return: a `File` named tuple representing the object located at key + :raises IOError: if no object with the given key exists + """ + self.validate_object_key(key) + + try: + directory, filename = key.rsplit(os.sep, 1) + except ValueError: + directory, filename = None, key + + folder = self._get_base_folder() + + if directory: + folder = folder.get_subfolder(directory) + + filepath = os.path.join(folder.abspath, filename) + + if os.path.isdir(filepath): + return File(filename, FileType.DIRECTORY) + + if os.path.isfile(filepath): + return File(filename, FileType.FILE) + + raise IOError('object {} does not exist'.format(key)) + + def get_object_content(self, key, mode='r'): + """Return the content of a object identified by key. + + :param key: fully qualified identifier for the object within the repository + :param mode: the mode under which to open the handle + """ + with self.open(key, mode=mode) as handle: + return handle.read() + + def put_object_from_tree(self, path, key=None, contents_only=True, force=False): + """Store a new object under `key` with the contents of the directory located at `path` on this file system. + + .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. + This check can be avoided by using the `force` flag, but this should be used with extreme caution! + + :param path: absolute path of directory whose contents to copy to the repository + :param key: fully qualified identifier for the object within the repository + :param contents_only: boolean, if True, omit the top level directory of the path and only copy its contents. 
+ :param force: boolean, if True, will skip the mutability check + :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` + """ + if not force: + self.validate_mutability() + + self.validate_object_key(key) + + if not os.path.isabs(path): + raise ValueError('the `path` must be an absolute path') + + folder = self._get_base_folder() + + if key: + folder = folder.get_subfolder(key, create=True) + + if contents_only: + for entry in os.listdir(path): + folder.insert_path(os.path.join(path, entry)) + else: + folder.insert_path(path) + + def put_object_from_file(self, path, key, mode=None, encoding=None, force=False): + """Store a new object under `key` with contents of the file located at `path` on this file system. + + .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. + This check can be avoided by using the `force` flag, but this should be used with extreme caution! + + :param path: absolute path of file whose contents to copy to the repository + :param key: fully qualified identifier for the object within the repository + :param mode: the file mode with which the object will be written + Deprecated: will be removed in `v2.0.0` + :param encoding: the file encoding with which the object will be written + Deprecated: will be removed in `v2.0.0` + :param force: boolean, if True, will skip the mutability check + :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` + """ + # pylint: disable=unused-argument,no-member + # Note that the defaults of `mode` and `encoding` had to be change to `None` from `w` and `utf-8` resptively, in + # order to detect when they were being passed such that the deprecation warning can be emitted. The defaults did + # not make sense and so ignoring them is justified, since the side-effect of this function, a file being copied, + # will continue working the same. + if mode is not None: + warnings.warn('the `mode` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning) + + if encoding is not None: + warnings.warn( + 'the `encoding` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning + ) + + if not force: + self.validate_mutability() + + self.validate_object_key(key) + + with open(path, mode='rb') as handle: + self.put_object_from_filelike(handle, key, mode='wb', encoding=None) + + def put_object_from_filelike(self, handle, key, mode='w', encoding='utf8', force=False): + """Store a new object under `key` with contents of filelike object `handle`. + + .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. + This check can be avoided by using the `force` flag, but this should be used with extreme caution! 
+ + :param handle: filelike object with the content to be stored + :param key: fully qualified identifier for the object within the repository + :param mode: the file mode with which the object will be written + :param encoding: the file encoding with which the object will be written + :param force: boolean, if True, will skip the mutability check + :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` + """ + if not force: + self.validate_mutability() + + self.validate_object_key(key) + + folder = self._get_base_folder() + + while os.sep in key: + basepath, key = key.split(os.sep, 1) + folder = folder.get_subfolder(basepath, create=True) + + folder.create_file_from_filelike(handle, key, mode=mode, encoding=encoding) + + def delete_object(self, key, force=False): + """Delete the object from the repository. + + .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. + This check can be avoided by using the `force` flag, but this should be used with extreme caution! + + :param key: fully qualified identifier for the object within the repository + :param force: boolean, if True, will skip the mutability check + :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` + """ + if not force: + self.validate_mutability() + + self.validate_object_key(key) + + self._get_base_folder().remove_path(key) + + def erase(self, force=False): + """Delete the repository folder. + + .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. + This check can be avoided by using the `force` flag, but this should be used with extreme caution! + + :param force: boolean, if True, will skip the mutability check + :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` + """ + if not force: + self.validate_mutability() + + self._get_base_folder().erase() + + def store(self): + """Store the contents of the sandbox folder into the repository folder.""" + if self._is_stored: + raise exceptions.ModificationNotAllowed('repository is already stored') + + self._repo_folder.replace_with_folder(self._get_temp_folder().abspath, move=True, overwrite=True) + self._is_stored = True + + def restore(self): + """Move the contents from the repository folder back into the sandbox folder.""" + if not self._is_stored: + raise exceptions.ModificationNotAllowed('repository is not yet stored') + + self._temp_folder.replace_with_folder(self._repo_folder.abspath, move=True, overwrite=True) + self._is_stored = False + + def _get_base_folder(self): + """Return the base sub folder in the repository. + + :return: a Folder object. + """ + if self._is_stored: + folder = self._repo_folder + else: + folder = self._get_temp_folder() + + if self._base_path is not None: + folder = folder.get_subfolder(self._base_path, reset_limit=True) + folder.create() + + return folder + + def _get_temp_folder(self): + """Return the temporary sandbox folder. + + :return: a SandboxFolder object mapping the node in the repository. 
+ """ + if self._temp_folder is None: + self._temp_folder = SandboxFolder() + + return self._temp_folder diff --git a/aiida/orm/utils/builders/computer.py b/aiida/orm/utils/builders/computer.py index 0de26f6ff3..a9acbafee5 100644 --- a/aiida/orm/utils/builders/computer.py +++ b/aiida/orm/utils/builders/computer.py @@ -36,9 +36,9 @@ def get_computer_spec(computer): spec = {} spec['label'] = computer.label spec['description'] = computer.description - spec['hostname'] = computer.get_hostname() - spec['scheduler'] = computer.get_scheduler_type() - spec['transport'] = computer.get_transport_type() + spec['hostname'] = computer.hostname + spec['scheduler'] = computer.scheduler_type + spec['transport'] = computer.transport_type spec['prepend_text'] = computer.get_prepend_text() spec['append_text'] = computer.get_append_text() spec['work_dir'] = computer.get_workdir() @@ -70,11 +70,11 @@ def new(self): passed_keys = set(self._computer_spec.keys()) used = set() - computer = Computer(name=self._get_and_count('label', used), hostname=self._get_and_count('hostname', used)) + computer = Computer(label=self._get_and_count('label', used), hostname=self._get_and_count('hostname', used)) - computer.set_description(self._get_and_count('description', used)) - computer.set_scheduler_type(self._get_and_count('scheduler', used)) - computer.set_transport_type(self._get_and_count('transport', used)) + computer.description = self._get_and_count('description', used) + computer.scheduler_type = self._get_and_count('scheduler', used) + computer.transport_type = self._get_and_count('transport', used) computer.set_prepend_text(self._get_and_count('prepend_text', used)) computer.set_append_text(self._get_and_count('append_text', used)) computer.set_workdir(self._get_and_count('work_dir', used)) diff --git a/aiida/orm/utils/managers.py b/aiida/orm/utils/managers.py index d00b69994d..db504ab6bc 100644 --- a/aiida/orm/utils/managers.py +++ b/aiida/orm/utils/managers.py @@ -128,7 +128,10 @@ def __init__(self, node): :param node: the node object. """ # Possibly add checks here - self._node = node + # We cannot set `self._node` because it would go through the __setattr__ method + # which uses said _node by calling `self._node.set_attribute(name, value)`. + # Instead, we need to manually set it through the `self.__dict__` property. + self.__dict__['_node'] = node def __dir__(self): """ @@ -160,6 +163,9 @@ def __getattr__(self, name): """ return self._node.get_attribute(name) + def __setattr__(self, name, value): + self._node.set_attribute(name, value) + def __getitem__(self, name): """ Interface to get to dictionary values as a dictionary. diff --git a/aiida/orm/utils/node.py b/aiida/orm/utils/node.py index 0432964467..4103698035 100644 --- a/aiida/orm/utils/node.py +++ b/aiida/orm/utils/node.py @@ -10,22 +10,17 @@ """Utilities to operate on `Node` classes.""" from abc import ABCMeta import logging -import math -import numbers + import warnings -from collections.abc import Iterable, Mapping from aiida.common import exceptions from aiida.common.utils import strip_prefix -from aiida.common.constants import AIIDA_FLOAT_PRECISION - -# This separator character is reserved to indicate nested fields in node attribute and extras dictionaries and -# therefore is not allowed in individual attribute or extra keys. -FIELD_SEPARATOR = '.' 
__all__ = ( - 'load_node_class', 'get_type_string_from_class', 'get_query_type_from_type_string', 'AbstractNodeMeta', - 'validate_attribute_extra_key', 'clean_value' + 'load_node_class', + 'get_type_string_from_class', + 'get_query_type_from_type_string', + 'AbstractNodeMeta', ) @@ -51,7 +46,7 @@ def load_node_class(type_string): try: base_path = type_string.rsplit('.', 2)[0] except ValueError: - raise exceptions.EntryPointError + raise exceptions.EntryPointError from ValueError # This exception needs to be there to make migrations work that rely on the old type string starting with `node.` # Since now the type strings no longer have that prefix, we simply strip it and continue with the normal logic. @@ -160,100 +155,6 @@ def get_query_type_from_type_string(type_string): return type_string -def validate_attribute_extra_key(key): - """Validate the key for a node attribute or extra. - - :raise aiida.common.ValidationError: if the key is not a string or contains reserved separator character - """ - if not key or not isinstance(key, str): - raise exceptions.ValidationError('key for attributes or extras should be a string') - - if FIELD_SEPARATOR in key: - raise exceptions.ValidationError( - 'key for attributes or extras cannot contain the character `{}`'.format(FIELD_SEPARATOR) - ) - - -def clean_value(value): - """ - Get value from input and (recursively) replace, if needed, all occurrences - of BaseType AiiDA data nodes with their value, and List with a standard list. - It also makes a deep copy of everything - The purpose of this function is to convert data to a type which can be serialized and deserialized - for storage in the DB without its value changing. - - Note however that there is no logic to avoid infinite loops when the - user passes some perverse recursive dictionary or list. - In any case, however, this would not be storable by AiiDA... - - :param value: A value to be set as an attribute or an extra - :return: a "cleaned" value, potentially identical to value, but with - values replaced where needed. - """ - # Must be imported in here to avoid recursive imports - from aiida.orm import BaseType - - def clean_builtin(val): - """ - A function to clean build-in python values (`BaseType`). - - It mainly checks that we don't store NaN or Inf. - """ - # This is a whitelist of all the things we understand currently - if val is None or isinstance(val, (bool, str)): - return val - - # This fixes #2773 - in python3, ``numpy.int64(-1)`` cannot be json-serialized - # Note that `numbers.Integral` also match booleans but they are already returned above - if isinstance(val, numbers.Integral): - return int(val) - - if isinstance(val, numbers.Real) and (math.isnan(val) or math.isinf(val)): - # see https://www.postgresql.org/docs/current/static/datatype-json.html#JSON-TYPE-MAPPING-TABLE - raise exceptions.ValidationError('nan and inf/-inf can not be serialized to the database') - - # This is for float-like types, like ``numpy.float128`` that are not json-serializable - # Note that `numbers.Real` also match booleans but they are already returned above - if isinstance(val, numbers.Real): - string_representation = '{{:.{}g}}'.format(AIIDA_FLOAT_PRECISION).format(val) - new_val = float(string_representation) - if 'e' in string_representation and new_val.is_integer(): - # This is indeed often quite unexpected, because it is going to change the type of the data - # from float to int. But anyway clean_value is changing some types, and we are also bound to what - # our current backends do. 
- # Currently, in both Django and SQLA (with JSONB attributes), if we store 1.e1, ..., 1.e14, 1.e15, - # they will be stored as floats; instead 1.e16, 1.e17, ... will all be stored as integer anyway, - # even if we don't run this clean_value step. - # So, for consistency, it's better if we do the conversion ourselves here, and we do it for a bit - # smaller numbers than python+[SQL+JSONB] would do (the AiiDA float precision is here 14), so the - # results are consistent, and the hashing will work also after a round trip as expected. - return int(new_val) - return new_val - - # Anything else we do not understand and we refuse - raise exceptions.ValidationError('type `{}` is not supported as it is not json-serializable'.format(type(val))) - - if isinstance(value, BaseType): - return clean_builtin(value.value) - - if isinstance(value, Mapping): - # Check dictionary before iterables - return {k: clean_value(v) for k, v in value.items()} - - if (isinstance(value, Iterable) and not isinstance(value, str)): - # list, tuple, ... but not a string - # This should also properly take care of dealing with the - # basedatatypes.List object - return [clean_value(v) for v in value] - - # If I don't know what to do I just return the value - # itself - it's not super robust, but relies on duck typing - # (e.g. if there is something that behaves like an integer - # but is not an integer, I still accept it) - - return clean_builtin(value) - - class AbstractNodeMeta(ABCMeta): # pylint: disable=too-few-public-methods """Some python black magic to set correctly the logger also in subclasses.""" diff --git a/aiida/orm/utils/remote.py b/aiida/orm/utils/remote.py index e2e0249c58..49e5c3b44f 100644 --- a/aiida/orm/utils/remote.py +++ b/aiida/orm/utils/remote.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Utilities for operations on files on remote computers.""" import os @@ -73,17 +74,17 @@ def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computer if pks: filters_calc['id'] = {'in': pks} - qb = orm.QueryBuilder() - qb.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc) - qb.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer) - qb.append(orm.User, with_node='calc', filters={'email': user.email}) + query = orm.QueryBuilder() + query.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc) + query.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer) + query.append(orm.User, with_node='calc', filters={'email': user.email}) - if qb.count() == 0: + if query.count() == 0: return None path_mapping = {} - for path, computer in qb.all(): + for path, computer in query.all(): if path is not None: path_mapping.setdefault(computer.uuid, []).append(path) diff --git a/aiida/orm/utils/repository.py b/aiida/orm/utils/repository.py index a8aa5155b2..0b15b17af5 100644 --- a/aiida/orm/utils/repository.py +++ b/aiida/orm/utils/repository.py @@ -7,301 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Class that represents the repository of a `Node` instance.""" - -import collections -import enum -import os 
+# pylint: disable=unused-import +"""Module shadowing original in order to print deprecation warning only when external code uses it.""" +import warnings from aiida.common import exceptions from aiida.common.folders import RepositoryFolder, SandboxFolder +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.repository import File, FileType +from ._repository import Repository as _Repository +warnings.warn( + 'this module is deprecated and will be removed in `v2.0.0`. ' + '`File` and `FileType` should be imported from `aiida.repository`.', AiidaDeprecationWarning +) -class FileType(enum.Enum): - - DIRECTORY = 0 - FILE = 1 - - -File = collections.namedtuple('File', ['name', 'type']) - - -class Repository: - """Class that represents the repository of a `Node` instance.""" - - # Name to be used for the Repository section - _section_name = 'node' - - def __init__(self, uuid, is_stored, base_path=None): - self._is_stored = is_stored - self._base_path = base_path - self._temp_folder = None - self._repo_folder = RepositoryFolder(section=self._section_name, uuid=uuid) - - def __del__(self): - """Clean the sandboxfolder if it was instantiated.""" - if getattr(self, '_temp_folder', None) is not None: - self._temp_folder.erase() - - def validate_mutability(self): - """Raise if the repository is immutable. - - :raises aiida.common.ModificationNotAllowed: if repository is marked as immutable because the corresponding node - is stored - """ - if self._is_stored: - raise exceptions.ModificationNotAllowed('cannot modify the repository after the node has been stored') - - @staticmethod - def validate_object_key(key): - """Validate the key of an object. - - :param key: an object key in the repository - :raises ValueError: if the key is not a valid object key - """ - if key and os.path.isabs(key): - raise ValueError('the key must be a relative path') - - def list_objects(self, key=None): - """Return a list of the objects contained in this repository, optionally in the given sub directory. - - :param key: fully qualified identifier for the object within the repository - :return: a list of `File` named tuples representing the objects present in directory with the given key - """ - folder = self._get_base_folder() - - if key: - folder = folder.get_subfolder(key) - - objects = [] - - for filename in folder.get_content_list(): - if os.path.isdir(os.path.join(folder.abspath, filename)): - objects.append(File(filename, FileType.DIRECTORY)) - else: - objects.append(File(filename, FileType.FILE)) - - return sorted(objects, key=lambda x: x.name) - - def list_object_names(self, key=None): - """Return a list of the object names contained in this repository, optionally in the given sub directory. - - :param key: fully qualified identifier for the object within the repository - :return: a list of `File` named tuples representing the objects present in directory with the given key - """ - return [entry.name for entry in self.list_objects(key)] - - def open(self, key, mode='r'): - """Open a file handle to an object stored under the given key. - - :param key: fully qualified identifier for the object within the repository - :param mode: the mode under which to open the handle - """ - return open(self._get_base_folder().get_abs_path(key), mode=mode) - - def get_object(self, key): - """Return the object identified by key. 
- - :param key: fully qualified identifier for the object within the repository - :return: a `File` named tuple representing the object located at key - :raises IOError: if no object with the given key exists - """ - self.validate_object_key(key) - - try: - directory, filename = key.rsplit(os.sep, 1) - except ValueError: - directory, filename = None, key - - folder = self._get_base_folder() - - if directory: - folder = folder.get_subfolder(directory) - - filepath = os.path.join(folder.abspath, filename) - - if os.path.isdir(filepath): - return File(filename, FileType.DIRECTORY) - - if os.path.isfile(filepath): - return File(filename, FileType.FILE) - - raise IOError('object {} does not exist'.format(key)) - - def get_object_content(self, key, mode='r'): - """Return the content of a object identified by key. - - :param key: fully qualified identifier for the object within the repository - :param mode: the mode under which to open the handle - """ - with self.open(key, mode=mode) as handle: - return handle.read() - - def put_object_from_tree(self, path, key=None, contents_only=True, force=False): - """Store a new object under `key` with the contents of the directory located at `path` on this file system. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - :param path: absolute path of directory whose contents to copy to the repository - :param key: fully qualified identifier for the object within the repository - :param contents_only: boolean, if True, omit the top level directory of the path and only copy its contents. - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - if not force: - self.validate_mutability() - - self.validate_object_key(key) - - if not os.path.isabs(path): - raise ValueError('the `path` must be an absolute path') - - folder = self._get_base_folder() - - if key: - folder = folder.get_subfolder(key, create=True) - - if contents_only: - for entry in os.listdir(path): - folder.insert_path(os.path.join(path, entry)) - else: - folder.insert_path(path) - - def put_object_from_file(self, path, key, mode=None, encoding=None, force=False): - """Store a new object under `key` with contents of the file located at `path` on this file system. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! 
- - :param path: absolute path of file whose contents to copy to the repository - :param key: fully qualified identifier for the object within the repository - :param mode: the file mode with which the object will be written - Deprecated: will be removed in `v2.0.0` - :param encoding: the file encoding with which the object will be written - Deprecated: will be removed in `v2.0.0` - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - # pylint: disable=unused-argument,no-member - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - - # Note that the defaults of `mode` and `encoding` had to be change to `None` from `w` and `utf-8` resptively, in - # order to detect when they were being passed such that the deprecation warning can be emitted. The defaults did - # not make sense and so ignoring them is justified, since the side-effect of this function, a file being copied, - # will continue working the same. - if mode is not None: - warnings.warn('the `mode` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning) - - if encoding is not None: - warnings.warn( - 'the `encoding` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning - ) - - if not force: - self.validate_mutability() - - self.validate_object_key(key) - - with open(path, mode='rb') as handle: - self.put_object_from_filelike(handle, key, mode='wb', encoding=None) - - def put_object_from_filelike(self, handle, key, mode='w', encoding='utf8', force=False): - """Store a new object under `key` with contents of filelike object `handle`. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - :param handle: filelike object with the content to be stored - :param key: fully qualified identifier for the object within the repository - :param mode: the file mode with which the object will be written - :param encoding: the file encoding with which the object will be written - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - if not force: - self.validate_mutability() - - self.validate_object_key(key) - - folder = self._get_base_folder() - - while os.sep in key: - basepath, key = key.split(os.sep, 1) - folder = folder.get_subfolder(basepath, create=True) - - folder.create_file_from_filelike(handle, key, mode=mode, encoding=encoding) - - def delete_object(self, key, force=False): - """Delete the object from the repository. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - :param key: fully qualified identifier for the object within the repository - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - if not force: - self.validate_mutability() - - self.validate_object_key(key) - - self._get_base_folder().remove_path(key) - - def erase(self, force=False): - """Delete the repository folder. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. 
- This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - if not force: - self.validate_mutability() - - self._get_base_folder().erase() - - def store(self): - """Store the contents of the sandbox folder into the repository folder.""" - if self._is_stored: - raise exceptions.ModificationNotAllowed('repository is already stored') - - self._repo_folder.replace_with_folder(self._get_temp_folder().abspath, move=True, overwrite=True) - self._is_stored = True - - def restore(self): - """Move the contents from the repository folder back into the sandbox folder.""" - if not self._is_stored: - raise exceptions.ModificationNotAllowed('repository is not yet stored') - - self._temp_folder.replace_with_folder(self._repo_folder.abspath, move=True, overwrite=True) - self._is_stored = False - - def _get_base_folder(self): - """Return the base sub folder in the repository. - - :return: a Folder object. - """ - if self._is_stored: - folder = self._repo_folder - else: - folder = self._get_temp_folder() - - if self._base_path is not None: - folder = folder.get_subfolder(self._base_path, reset_limit=True) - folder.create() - - return folder - - def _get_temp_folder(self): - """Return the temporary sandbox folder. - :return: a SandboxFolder object mapping the node in the repository. - """ - if self._temp_folder is None: - self._temp_folder = SandboxFolder() +class Repository(_Repository): + """Class shadowing original class in order to print deprecation warning when external code uses it.""" - return self._temp_folder + def __init__(self, *args, **kwargs): + warnings.warn('This class has been deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member""" + super().__init__(*args, **kwargs) diff --git a/aiida/parsers/plugins/arithmetic/add.py b/aiida/parsers/plugins/arithmetic/add.py index 67266a83b4..f043f3fc63 100644 --- a/aiida/parsers/plugins/arithmetic/add.py +++ b/aiida/parsers/plugins/arithmetic/add.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=inconsistent-return-statements # Warning: this implementation is used directly in the documentation as a literal-include, which means that if any part # of this code is changed, the snippets in the file `docs/source/howto/codes.rst` have to be checked for consistency. 
"""Parser for an `ArithmeticAddCalculation` job.""" -from aiida.orm import Int from aiida.parsers.parser import Parser @@ -20,13 +18,10 @@ class ArithmeticAddParser(Parser): def parse(self, **kwargs): """Parse the contents of the output files stored in the `retrieved` output node.""" - try: - output_folder = self.retrieved - except AttributeError: - return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER + from aiida.orm import Int try: - with output_folder.open(self.node.get_option('output_filename'), 'r') as handle: + with self.retrieved.open(self.node.get_option('output_filename'), 'r') as handle: result = int(handle.read()) except OSError: return self.exit_codes.ERROR_READING_OUTPUT_FILE @@ -37,3 +32,18 @@ def parse(self, **kwargs): if result < 0: return self.exit_codes.ERROR_NEGATIVE_NUMBER + + +class SimpleArithmeticAddParser(Parser): + """Simple parser for an `ArithmeticAddCalculation` job (for demonstration purposes only).""" + + def parse(self, **kwargs): + """Parse the contents of the output files stored in the `retrieved` output node.""" + from aiida.orm import Int + + output_folder = self.retrieved + + with output_folder.open(self.node.get_option('output_filename'), 'r') as handle: + result = int(handle.read()) + + self.out('sum', Int(result)) diff --git a/aiida/parsers/plugins/templatereplacer/doubler.py b/aiida/parsers/plugins/templatereplacer/doubler.py index 93700a3c2c..9b66f22d51 100644 --- a/aiida/parsers/plugins/templatereplacer/doubler.py +++ b/aiida/parsers/plugins/templatereplacer/doubler.py @@ -7,28 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Parser for the `TemplatereplacerCalculation` calculation job doubling a number.""" import os -from aiida.common import exceptions from aiida.orm import Dict from aiida.parsers.parser import Parser -from aiida.plugins import CalculationFactory - -TemplatereplacerCalculation = CalculationFactory('templatereplacer') class TemplatereplacerDoublerParser(Parser): + """Parser for the `TemplatereplacerCalculation` calculation job doubling a number.""" def parse(self, **kwargs): """Parse the contents of the output files retrieved in the `FolderData`.""" + output_folder = self.retrieved template = self.node.inputs.template.get_dict() - try: - output_folder = self.retrieved - except exceptions.NotExistent: - return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER - try: output_file = template['output_file_name'] except KeyError: @@ -57,8 +50,11 @@ def parse(self, **kwargs): file_path = os.path.join(retrieved_temporary_folder, retrieved_file) if not os.path.isfile(file_path): - self.logger.error('the file {} was not found in the temporary retrieved folder {}'.format( - retrieved_file, retrieved_temporary_folder)) + self.logger.error( + 'the file {} was not found in the temporary retrieved folder {}'.format( + retrieved_file, retrieved_temporary_folder + ) + ) return self.exit_codes.ERROR_READING_TEMPORARY_RETRIEVED_FILE with open(file_path, 'r', encoding='utf8') as handle: @@ -73,8 +69,7 @@ def parse(self, **kwargs): @staticmethod def parse_stdout(filelike): - """ - Parse the sum from the output of the ArithmeticAddcalculation written to standard out + """Parse the sum from the output of the ArithmeticAddcalculation written to standard out. 
:param filelike: filelike object containing the output :returns: the sum diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 1675ac6cb6..39633995d4 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,inconsistent-return-statements,cyclic-import +# pylint: disable=invalid-name,cyclic-import """Definition of factories to load classes from the various plugin groups.""" from inspect import isclass diff --git a/aiida/repository/__init__.py b/aiida/repository/__init__.py new file mode 100644 index 0000000000..1ccf31a99e --- /dev/null +++ b/aiida/repository/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Module with resources dealing with the file repository.""" +# pylint: disable=undefined-variable +from .common import * + +__all__ = (common.__all__) diff --git a/aiida/repository/common.py b/aiida/repository/common.py new file mode 100644 index 0000000000..f9dee05b0c --- /dev/null +++ b/aiida/repository/common.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=redefined-builtin +"""Module with resources common to the repository.""" +import enum +import warnings + +from aiida.common.warnings import AiidaDeprecationWarning + +__all__ = ('File', 'FileType') + + +class FileType(enum.Enum): + """Enumeration to represent the type of a file object.""" + + DIRECTORY = 0 + FILE = 1 + + +class File: + """Data class representing a file object.""" + + def __init__(self, name: str = '', file_type: FileType = FileType.DIRECTORY, type=None): + """ + + .. deprecated:: 1.4.0 + The argument `type` has been deprecated and will be removed in `v2.0.0`, use `file_type` instead. + """ + if type is not None: + warnings.warn( + 'argument `type` is deprecated and will be removed in `v2.0.0`. Use `file_type` instead.', + AiidaDeprecationWarning + ) # pylint: disable=no-member""" + file_type = type + + if not isinstance(name, str): + raise TypeError('name should be a string.') + + if not isinstance(file_type, FileType): + raise TypeError('file_type should be an instance of `FileType`.') + + self._name = name + self._file_type = file_type + + @property + def name(self) -> str: + """Return the name of the file object.""" + return self._name + + @property + def type(self) -> FileType: + """Return the file type of the file object. + + .. 
deprecated:: 1.4.0 + Will be removed in `v2.0.0`, use `file_type` instead. + """ + warnings.warn('property is deprecated, use `file_type` instead', AiidaDeprecationWarning) # pylint: disable=no-member""" + return self.file_type + + @property + def file_type(self) -> FileType: + """Return the file type of the file object.""" + return self._file_type + + def __iter__(self): + """Iterate over the properties.""" + warnings.warn( + '`File` has changed from named tuple into class and from `v2.0.0` will no longer be iterable', + AiidaDeprecationWarning + ) + yield self.name + yield self.file_type + + def __eq__(self, other): + return self.file_type == other.file_type and self.name == other.name diff --git a/aiida/restapi/run_api.py b/aiida/restapi/run_api.py index 8cc4df95d8..ba6f91f157 100755 --- a/aiida/restapi/run_api.py +++ b/aiida/restapi/run_api.py @@ -8,7 +8,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=inconsistent-return-statements """ It defines the method with all required parameters to run restapi locally. """ diff --git a/aiida/restapi/translator/nodes/node.py b/aiida/restapi/translator/nodes/node.py index 58b14b3384..1a77aa8762 100644 --- a/aiida/restapi/translator/nodes/node.py +++ b/aiida/restapi/translator/nodes/node.py @@ -472,7 +472,7 @@ def get_repo_list(node, filename=''): raise RestInputValidationError('{} is not a directory in this repository'.format(filename)) response = [] for fobj in flist: - response.append({'name': fobj.name, 'type': fobj.type.name}) + response.append({'name': fobj.name, 'type': fobj.file_type.name}) return response @staticmethod diff --git a/aiida/schedulers/datastructures.py b/aiida/schedulers/datastructures.py index 631db31164..757dd816ac 100644 --- a/aiida/schedulers/datastructures.py +++ b/aiida/schedulers/datastructures.py @@ -112,12 +112,14 @@ def is_greater_equal_one(parameter): # Validate that all fields are valid integers if they are specified, otherwise initialize them to `None` for parameter in list(cls._default_fields) + ['tot_num_mpiprocs']: - try: - setattr(resources, parameter, int(kwargs.pop(parameter))) - except KeyError: + value = kwargs.pop(parameter, None) + if value is None: setattr(resources, parameter, None) - except ValueError: - raise ValueError('`{}` must be an integer when specified'.format(parameter)) + else: + try: + setattr(resources, parameter, int(value)) + except ValueError: + raise ValueError('`{}` must be an integer when specified'.format(parameter)) if kwargs: raise ValueError('these parameters were not recognized: {}'.format(', '.join(list(kwargs.keys())))) @@ -198,7 +200,7 @@ def validate_resources(cls, **kwargs): try: resources.tot_num_mpiprocs = int(kwargs.pop('tot_num_mpiprocs')) - except (KeyError, ValueError): + except (KeyError, TypeError, ValueError): raise ValueError('`tot_num_mpiprocs` must be specified and must be an integer') if resources.tot_num_mpiprocs < 1: diff --git a/aiida/schedulers/plugins/sge.py b/aiida/schedulers/plugins/sge.py index 6bf679b735..ade7595c12 100644 --- a/aiida/schedulers/plugins/sge.py +++ b/aiida/schedulers/plugins/sge.py @@ -322,14 +322,14 @@ def _parse_joblist_output(self, retval, stdout, stderr): try: xmldata = xml.dom.minidom.parseString(stdout) except xml.parsers.expat.ExpatError: - self.logger.error('in sge._parse_joblist_output: xml parsing of stdout failed:' '{}'.format(stdout)) - raise 
SchedulerParsingError('Error during joblist retrieval,' 'xml parsing of stdout failed') + self.logger.error('in sge._parse_joblist_output: xml parsing of stdout failed: {}'.format(stdout)) + raise SchedulerParsingError('Error during joblist retrieval, xml parsing of stdout failed') else: self.logger.error( 'Error in sge._parse_joblist_output: retval={}; ' 'stdout={}; stderr={}'.format(retval, stdout, stderr) ) - raise SchedulerError('Error during joblist retrieval,' 'no stdout produced') + raise SchedulerError('Error during joblist retrieval, no stdout produced') try: first_child = xmldata.firstChild @@ -381,12 +381,12 @@ def _parse_joblist_output(self, retval, stdout, stderr): self.logger.error('Error in sge._parse_joblist_output:' 'no job id is given, stdout={}' \ .format(stdout)) - raise SchedulerError('Error in sge._parse_joblist_output:' 'no job id is given') + raise SchedulerError('Error in sge._parse_joblist_output: no job id is given') except IndexError: self.logger.error("No 'job_number' given for job index {} in " 'job list, stdout={}'.format(jobs.index(job) \ , stdout)) - raise IndexError('Error in sge._parse_joblist_output:' 'no job id is given') + raise IndexError('Error in sge._parse_joblist_output: no job id is given') try: job_element = job.getElementsByTagName('state').pop(0) diff --git a/aiida/schedulers/plugins/slurm.py b/aiida/schedulers/plugins/slurm.py index 70645b0bbc..f420ddba54 100644 --- a/aiida/schedulers/plugins/slurm.py +++ b/aiida/schedulers/plugins/slurm.py @@ -14,6 +14,7 @@ import re from aiida.common.escaping import escape_for_bash +from aiida.common.lang import type_check from aiida.schedulers import Scheduler, SchedulerError from aiida.schedulers.datastructures import (JobInfo, JobState, NodeNumberJobResource) @@ -214,6 +215,22 @@ def _get_joblist_command(self, jobs=None, user=None): if not isinstance(jobs, (tuple, list)): raise TypeError("If provided, the 'jobs' variable must be a string or a list of strings") joblist = jobs + + # Trick: When asking for a single job, append the same job once more. + # This helps provide a reliable way of knowing whether the squeue command failed (if its exit code is + # non-zero, _parse_joblist_output assumes that an error has occurred and raises an exception). + # When asking for a single job, squeue also returns a non-zero exit code if the corresponding job is + # no longer in the queue (stderr: "slurm_load_jobs error: Invalid job id specified"), which typically + # happens once in the life time of an AiiDA job, + # However, when providing two or more jobids via `squeue --jobs=123,234`, squeue stops caring whether + # the jobs are still in the queue and returns exit code zero irrespectively (allowing AiiDA to rely on the + # exit code for detection of real issues). + # Duplicating job ids has no other effect on the output. + # Verified on slurm versions 17.11.2, 19.05.3-2 and 20.02.2. 
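+            # For example, for a hypothetical single job id '123', the command built below becomes `squeue --jobs=123,123` instead of `squeue --jobs=123`.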
+ # See also https://github.com/aiidateam/aiida-core/issues/4326 + if len(joblist) == 1: + joblist += [joblist[0]] + command.append('--jobs={}'.format(','.join(joblist))) comm = ' '.join(command) @@ -482,21 +499,19 @@ def _parse_joblist_output(self, retval, stdout, stderr): # pylint: disable=too-many-branches,too-many-statements num_fields = len(self.fields) - # I don't raise because if I pass a list of jobs, - # I get a non-zero status - # if one of the job is not in the list anymore - # retval should be zero - # if retval != 0: - # self.logger.warning("Error in _parse_joblist_output: retval={}; " - # "stdout={}; stderr={}".format(retval, stdout, stderr)) - - # issue a warning if there is any stderr output and - # there is no line containing "Invalid job id specified", that happens - # when I ask for specific calculations, and they are all finished - if stderr.strip() and 'Invalid job id specified' not in stderr: - self.logger.warning("Warning in _parse_joblist_output, non-empty stderr='{}'".format(stderr.strip())) - if retval != 0: - raise SchedulerError('Error during squeue parsing (_parse_joblist_output function)') + # See discussion in _get_joblist_command on how we ensure that AiiDA can expect exit code 0 here. + if retval != 0: + raise SchedulerError( + """squeue returned exit code {} (_parse_joblist_output function) +stdout='{}' +stderr='{}'""".format(retval, stdout.strip(), stderr.strip()) + ) + if stderr.strip(): + self.logger.warning( + "squeue returned exit code 0 (_parse_joblist_output function) but non-empty stderr='{}'".format( + stderr.strip() + ) + ) # will contain raw data parsed from output: only lines with the # separator, and already split in fields @@ -738,3 +753,54 @@ def _parse_kill_output(self, retval, stdout, stderr): ) return True + + def parse_output(self, detailed_job_info, stdout, stderr): # pylint: disable=inconsistent-return-statements + """Parse the output of the scheduler. + + :param detailed_job_info: dictionary with the output returned by the `Scheduler.get_detailed_job_info` command. + This should contain the keys `retval`, `stdout` and `stderr` corresponding to the return value, stdout and + stderr returned by the accounting command executed for a specific job id. + :param stdout: string with the output written by the scheduler to stdout + :param stderr: string with the output written by the scheduler to stderr + :return: None or an instance of `aiida.engine.processes.exit_code.ExitCode` + :raises TypeError or ValueError: if the passed arguments have incorrect type or value + """ + from aiida.engine import CalcJob + + type_check(detailed_job_info, dict) + + try: + detailed_stdout = detailed_job_info['stdout'] + except KeyError: + raise ValueError('the `detailed_job_info` does not contain the required key `stdout`.') + + type_check(detailed_stdout, str) + + # The format of the detailed job info should be a multiline string, where the first line is the header, with + # the labels of the projected attributes. The following line should be the values of those attributes for the + # entire job. Any additional lines correspond to those values for any additional tasks that were run. + lines = detailed_stdout.splitlines() + + try: + master = lines[1] + except IndexError: + raise ValueError('the `detailed_job_info.stdout` contained less than two lines.') + + attributes = master.split('|') + + # Pop the last element if it is empty. 
This happens if the `master` string just finishes with a pipe + if not attributes[-1]: + attributes.pop() + + if len(self._detailed_job_info_fields) != len(attributes): + raise ValueError( + 'second line in `detailed_job_info.stdout` does not have the same number of fields as the scheduler attribute `_detailed_job_info_fields`' + ) + + data = dict(zip(self._detailed_job_info_fields, attributes)) + + if data['State'] == 'OUT_OF_MEMORY': + return CalcJob.exit_codes.ERROR_SCHEDULER_OUT_OF_MEMORY # pylint: disable=no-member + + if data['State'] == 'TIMEOUT': + return CalcJob.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME # pylint: disable=no-member diff --git a/aiida/schedulers/scheduler.py b/aiida/schedulers/scheduler.py index a22ffeea04..193b75bb61 100644 --- a/aiida/schedulers/scheduler.py +++ b/aiida/schedulers/scheduler.py @@ -41,6 +41,9 @@ class Scheduler(metaclass=abc.ABCMeta): # The class to be used for the job resource. _job_resource_class = None + def __str__(self): + return self.__class__.__name__ + @classmethod def preprocess_resources(cls, resources, default_mpiprocs_per_machine=None): """Pre process the resources. @@ -411,3 +414,15 @@ def _parse_kill_output(self, retval, stdout, stderr): :return: True if everything seems ok, False otherwise. """ + + def parse_output(self, detailed_job_info, stdout, stderr): + """Parse the output of the scheduler. + + :param detailed_job_info: dictionary with the output returned by the `Scheduler.get_detailed_job_info` command. + This should contain the keys `retval`, `stdout` and `stderr` corresponding to the return value, stdout and + stderr returned by the accounting command executed for a specific job id. + :param stdout: string with the output written by the scheduler to stdout + :param stderr: string with the output written by the scheduler to stderr + :return: None or an instance of `aiida.engine.processes.exit_code.ExitCode` + """ + raise exceptions.FeatureNotAvailable('output parsing is not available for `{}`'.format(self.__class__.__name__)) diff --git a/aiida/tools/data/array/kpoints/__init__.py b/aiida/tools/data/array/kpoints/__init__.py index eb9e2bef44..21330687a4 100644 --- a/aiida/tools/data/array/kpoints/__init__.py +++ b/aiida/tools/data/array/kpoints/__init__.py @@ -11,10 +11,7 @@ Various utilities to deal with KpointsData instances or create new ones (e.g. band paths, kpoints from a parsed input text file, ...)
""" - from aiida.orm import KpointsData, Dict -from aiida.tools.data.array.kpoints import legacy -from aiida.tools.data.array.kpoints import seekpath __all__ = ('get_kpoints_path', 'get_explicit_kpoints_path') @@ -49,15 +46,6 @@ def get_kpoints_path(structure, method='seekpath', **kwargs): if method not in _GET_KPOINTS_PATH_METHODS.keys(): raise ValueError("the method '{}' is not implemented".format(method)) - if method == 'seekpath': - try: - seekpath.check_seekpath_is_installed() - except ImportError: - raise ValueError( - "selected method is 'seekpath' but the package is not installed\n" - "Either install it or pass method='legacy' as input to the function call" - ) - method = _GET_KPOINTS_PATH_METHODS[method] return method(structure, **kwargs) @@ -94,15 +82,6 @@ def get_explicit_kpoints_path(structure, method='seekpath', **kwargs): if method not in _GET_EXPLICIT_KPOINTS_PATH_METHODS.keys(): raise ValueError("the method '{}' is not implemented".format(method)) - if method == 'seekpath': - try: - seekpath.check_seekpath_is_installed() - except ImportError: - raise ValueError( - "selected method is 'seekpath' but the package is not installed\n" - "Either install it or pass method='legacy' as input to the function call" - ) - method = _GET_EXPLICIT_KPOINTS_PATH_METHODS[method] return method(structure, **kwargs) @@ -131,6 +110,8 @@ def _seekpath_get_kpoints_path(structure, **kwargs): :param symprec: the symmetry precision used internally by SPGLIB :param angle_tolerance: the angle_tolerance used internally by SPGLIB """ + from aiida.tools.data.array.kpoints import seekpath + assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' recognized_args = ['with_time_reversal', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] @@ -169,6 +150,8 @@ def _seekpath_get_explicit_kpoints_path(structure, **kwargs): :param symprec: the symmetry precision used internally by SPGLIB :param angle_tolerance: the angle_tolerance used internally by SPGLIB """ + from aiida.tools.data.array.kpoints import seekpath + assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' recognized_args = ['with_time_reversal', 'reference_distance', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] @@ -189,6 +172,8 @@ def _legacy_get_kpoints_path(structure, **kwargs): :param epsilon_length: threshold on lengths comparison, used to get the bravais lattice info :param epsilon_angle: threshold on angles comparison, used to get the bravais lattice info """ + from aiida.tools.data.array.kpoints import legacy + args_recognized = ['cartesian', 'epsilon_length', 'epsilon_angle'] args_unknown = set(kwargs).difference(args_recognized) @@ -219,6 +204,8 @@ def _legacy_get_explicit_kpoints_path(structure, **kwargs): :param float epsilon_length: threshold on lengths comparison, used to get the bravais lattice info :param float epsilon_angle: threshold on angles comparison, used to get the bravais lattice info """ + from aiida.tools.data.array.kpoints import legacy + args_recognized = ['value', 'kpoint_distance', 'cartesian', 'epsilon_length', 'epsilon_angle'] args_unknown = set(kwargs).difference(args_recognized) diff --git a/aiida/tools/data/array/kpoints/legacy.py b/aiida/tools/data/array/kpoints/legacy.py index bc8260ff9e..d217e83282 100644 --- a/aiida/tools/data/array/kpoints/legacy.py +++ b/aiida/tools/data/array/kpoints/legacy.py @@ -7,10 +7,10 @@ # For further information on the license, see the LICENSE.txt file # # For further 
information please visit http://www.aiida.net # ########################################################################### - +"""Tool to automatically determine k-points for a given structure using legacy custom implementation.""" +# pylint: disable=too-many-lines,fixme,invalid-name,too-many-arguments,too-many-locals,eval-used import numpy - _default_epsilon_length = 1e-5 _default_epsilon_angle = 1e-5 @@ -36,6 +36,7 @@ def change_reference(reciprocal_cell, kpoints, to_cartesian=True): # hence, first transpose kpoints, then multiply, finally transpose it back return numpy.transpose(numpy.dot(matrix, numpy.transpose(kpoints))) + def analyze_cell(cell=None, pbc=None): """ A function executed by the __init__ or by set_cell. @@ -50,11 +51,7 @@ def analyze_cell(cell=None, pbc=None): dimension = sum(pbc) if cell is None: - return { - 'reciprocal_cell': None, - 'dimension': dimension, - 'pbc': pbc - } + return {'reciprocal_cell': None, 'dimension': dimension, 'pbc': pbc} the_cell = numpy.array(cell) reciprocal_cell = 2. * numpy.pi * numpy.linalg.inv(the_cell).transpose() @@ -71,7 +68,6 @@ def analyze_cell(cell=None, pbc=None): cosbeta = numpy.dot(a3, a1) / c / a cosgamma = numpy.dot(a1, a2) / a / b - result = { 'a1': a1, 'a2': a2, @@ -93,9 +89,15 @@ def analyze_cell(cell=None, pbc=None): return result -def get_explicit_kpoints_path(value=None, cell=None, pbc=None, kpoint_distance=None, cartesian=False, - epsilon_length=_default_epsilon_length, - epsilon_angle=_default_epsilon_angle): +def get_explicit_kpoints_path( + value=None, + cell=None, + pbc=None, + kpoint_distance=None, + cartesian=False, + epsilon_length=_default_epsilon_length, + epsilon_angle=_default_epsilon_angle +): """ Set a path of kpoints in the Brillouin zone. @@ -131,11 +133,8 @@ def get_explicit_kpoints_path(value=None, cell=None, pbc=None, kpoint_distance=N :returns: point_coordinates, path, bravais_info, explicit_kpoints, labels """ - bravais_info = find_bravais_info( - cell=cell, pbc=pbc, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle - ) + # pylint: disable=too-many-branches,too-many-statements + bravais_info = find_bravais_info(cell=cell, pbc=pbc, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle) analysis = analyze_cell(cell, pbc) dimension = analysis['dimension'] @@ -166,18 +165,14 @@ def _is_path_2(path): if not are_three: return False - are_good = all([all([isinstance(b[0], str), - isinstance(b[1], str), - isinstance(b[2], int)]) - for b in path]) + are_good = all([all([isinstance(b[0], str), isinstance(b[1], str), isinstance(b[2], int)]) for b in path]) if not are_good: return False # check that at least two points per segment (beginning and end) points_num = [int(i[2]) for i in path] if any([i < 2 for i in points_num]): - raise ValueError('Must set at least two points per path ' - 'segment') + raise ValueError('Must set at least two points per path segment') except IndexError: return False @@ -218,8 +213,7 @@ def _is_path_4(path): # check that at least two points per segment (beginning and end) points_num = [int(i[4]) for i in path] if any([i < 2 for i in points_num]): - raise ValueError('Must set at least two points per path ' - 'segment') + raise ValueError('Must set at least two points per path segment') for i in path: coord1 = [float(j) for j in i[1]] coord2 = [float(j) for j in i[3]] @@ -232,9 +226,9 @@ def _is_path_4(path): def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): # NOTE: this way of creating intervals ensures equispaced objects # in crystal 
coordinates of b1,b2,b3 - distances = [numpy.linalg.norm(numpy.array(point_coordinates[i[0]]) - - numpy.array(point_coordinates[i[1]]) - ) for i in path] + distances = [ + numpy.linalg.norm(numpy.array(point_coordinates[i[0]]) - numpy.array(point_coordinates[i[1]])) for i in path + ] if kpoint_distance is None: # Use max_points_per_interval as the default guess for automatically @@ -244,8 +238,7 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): try: points_per_piece = [max(2, max_point_per_interval * i // max_interval) for i in distances] except ValueError: - raise ValueError('The beginning and end of each segment in the ' - 'path should be different.') + raise ValueError('The beginning and end of each segment in the path should be different.') else: points_per_piece = [max(2, int(distance // kpoint_distance)) for distance in distances] @@ -253,8 +246,7 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): if cartesian: if cell is None: - raise ValueError('To use cartesian coordinates, a cell must ' - 'be provided') + raise ValueError('To use cartesian coordinates, a cell must be provided') if kpoint_distance is not None: if kpoint_distance <= 0.: @@ -262,38 +254,32 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): if value is None: if cell is None: - raise ValueError('Cannot set a path not even knowing the ' - 'kpoints or at least the cell') + raise ValueError('Cannot set a path not even knowing the kpoints or at least the cell') point_coordinates, path, bravais_info = get_kpoints_path( - cell=cell, pbc=pbc, cartesian=cartesian, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle) + cell=cell, pbc=pbc, cartesian=cartesian, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle + ) num_points = _num_points_from_coordinates(path, point_coordinates, kpoint_distance) elif _is_path_1(value): # in the form [('X','M'),(...),...] if cell is None: - raise ValueError('Cannot set a path not even knowing the ' - 'kpoints or at least the cell') + raise ValueError('Cannot set a path not even knowing the kpoints or at least the cell') path = value point_coordinates, _, bravais_info = get_kpoints_path( - cell=cell, pbc=pbc, cartesian=cartesian, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle) + cell=cell, pbc=pbc, cartesian=cartesian, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle + ) num_points = _num_points_from_coordinates(path, point_coordinates, kpoint_distance) elif _is_path_2(value): # [('G','M',30), (...), ...] 
if cell is None: - raise ValueError('Cannot set a path not even knowing the ' - 'kpoints or at least the cell') + raise ValueError('Cannot set a path not even knowing the kpoints or at least the cell') path = [(i[0], i[1]) for i in value] point_coordinates, _, bravais_info = get_kpoints_path( - cell=cell, pbc=pbc, cartesian=cartesian, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle) + cell=cell, pbc=pbc, cartesian=cartesian, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle + ) num_points = [i[2] for i in value] elif _is_path_3(value): @@ -307,10 +293,8 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): raise ValueError('Different points cannot have the same label') else: if cartesian: - point_coordinates[piece[0]] = change_reference( - reciprocal_cell, - numpy.array([piece[1]]), - to_cartesian=False)[0] + point_coordinates[ + piece[0]] = change_reference(reciprocal_cell, numpy.array([piece[1]]), to_cartesian=False)[0] else: point_coordinates[piece[0]] = piece[1] if piece[2] in point_coordinates: @@ -318,10 +302,8 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): raise ValueError('Different points cannot have the same label') else: if cartesian: - point_coordinates[piece[2]] = change_reference( - reciprocal_cell, - numpy.array([piece[3]]), - to_cartesian=False)[0] + point_coordinates[ + piece[2]] = change_reference(reciprocal_cell, numpy.array([piece[3]]), to_cartesian=False)[0] else: point_coordinates[piece[2]] = piece[3] @@ -338,10 +320,8 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): raise ValueError('Different points cannot have the same label') else: if cartesian: - point_coordinates[piece[0]] = change_reference( - reciprocal_cell, - numpy.array([piece[1]]), - to_cartesian=False)[0] + point_coordinates[ + piece[0]] = change_reference(reciprocal_cell, numpy.array([piece[1]]), to_cartesian=False)[0] else: point_coordinates[piece[0]] = piece[1] if piece[2] in point_coordinates: @@ -349,10 +329,8 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): raise ValueError('Different points cannot have the same label') else: if cartesian: - point_coordinates[piece[2]] = change_reference( - reciprocal_cell, - numpy.array([piece[3]]), - to_cartesian=False)[0] + point_coordinates[ + piece[2]] = change_reference(reciprocal_cell, numpy.array([piece[3]]), to_cartesian=False)[0] else: point_coordinates[piece[2]] = piece[3] @@ -364,25 +342,29 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): explicit_kpoints = [tuple(point_coordinates[path[0][0]])] labels = [(0, path[0][0])] + assert all([_.is_integer() for _ in num_points if isinstance(_, (float, numpy.float64))] + ), 'Could not determine number of points as a whole number. 
num_points={}'.format(num_points) + num_points = [int(_) for _ in num_points] + for count_piece, i in enumerate(path): ini_label = i[0] end_label = i[1] ini_coord = point_coordinates[ini_label] end_coord = point_coordinates[end_label] - path_piece = list(zip(numpy.linspace(ini_coord[0], end_coord[0], - num_points[count_piece]), - numpy.linspace(ini_coord[1], end_coord[1], - num_points[count_piece]), - numpy.linspace(ini_coord[2], end_coord[2], - num_points[count_piece]), - )) + path_piece = list( + zip( + numpy.linspace(ini_coord[0], end_coord[0], num_points[count_piece]), + numpy.linspace(ini_coord[1], end_coord[1], num_points[count_piece]), + numpy.linspace(ini_coord[2], end_coord[2], num_points[count_piece]), + ) + ) for count, j in enumerate(path_piece): if all(numpy.array(explicit_kpoints[-1]) == j): continue # avoid duplcates - else: - explicit_kpoints.append(j) + + explicit_kpoints.append(j) # add labels for the first and last point if count == 0: @@ -396,8 +378,7 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): return point_coordinates, path, bravais_info, explicit_kpoints, labels -def find_bravais_info(cell, pbc, epsilon_length=_default_epsilon_length, - epsilon_angle=_default_epsilon_angle): +def find_bravais_info(cell, pbc, epsilon_length=_default_epsilon_length, epsilon_angle=_default_epsilon_angle): """ Finds the Bravais lattice of the cell passed in input to the Kpoint class :note: We assume that the cell given by the cell property is the @@ -425,6 +406,7 @@ def find_bravais_info(cell, pbc, epsilon_length=_default_epsilon_length, the variation of the Bravais lattice) and extra (a dictionary with extra parameters used by the get_kpoints_path method) """ + # pylint: disable=too-many-branches,too-many-statements if cell is None: return None @@ -467,10 +449,8 @@ def a_are_equals(a, b): # 3D case -> 14 possible Bravais lattices # # =========================================# - comparison_length = [l_are_equals(a, b), l_are_equals(b, c), - l_are_equals(c, a)] - comparison_angles = [a_are_equals(cosa, cosb), a_are_equals(cosb, cosc), - a_are_equals(cosc, cosa)] + comparison_length = [l_are_equals(a, b), l_are_equals(b, c), l_are_equals(c, a)] + comparison_angles = [a_are_equals(cosa, cosb), a_are_equals(cosb, cosc), a_are_equals(cosc, cosa)] if comparison_length.count(True) == 3: @@ -479,68 +459,70 @@ def a_are_equals(a, b): orci_b = numpy.linalg.norm(a1 + a3) orci_c = numpy.linalg.norm(a1 + a2) orci_the_a, orci_the_b, orci_the_c = sorted([orci_a, orci_b, orci_c]) - bco1 = - (-orci_the_a ** 2 + orci_the_b ** 2 + orci_the_c ** 2) / (4. * a ** 2) - bco2 = - (orci_the_a ** 2 - orci_the_b ** 2 + orci_the_c ** 2) / (4. * a ** 2) - bco3 = - (orci_the_a ** 2 + orci_the_b ** 2 - orci_the_c ** 2) / (4. * a ** 2) + bco1 = -(-orci_the_a**2 + orci_the_b**2 + orci_the_c**2) / (4. * a**2) + bco2 = -(orci_the_a**2 - orci_the_b**2 + orci_the_c**2) / (4. * a**2) + bco3 = -(orci_the_a**2 + orci_the_b**2 - orci_the_c**2) / (4. 
* a**2) # ======================# # simple cubic lattice # # ======================# if comparison_angles.count(True) == 3 and a_are_equals(cosa, _90): - bravais_info = {'short_name': 'cub', - 'extended_name': 'cubic', - 'index': 1, - 'permutation': [0, 1, 2] - } + bravais_info = {'short_name': 'cub', 'extended_name': 'cubic', 'index': 1, 'permutation': [0, 1, 2]} # =====================# # face centered cubic # # =====================# elif comparison_angles.count(True) == 3 and a_are_equals(cosa, _60): - bravais_info = {'short_name': 'fcc', - 'extended_name': 'face centered cubic', - 'index': 2, - 'permutation': [0, 1, 2] - } + bravais_info = { + 'short_name': 'fcc', + 'extended_name': 'face centered cubic', + 'index': 2, + 'permutation': [0, 1, 2] + } # =====================# # body centered cubic # # =====================# elif comparison_angles.count(True) == 3 and a_are_equals(cosa, -1. / 3.): - bravais_info = {'short_name': 'bcc', - 'extended_name': 'body centered cubic', - 'index': 3, - 'permutation': [0, 1, 2] - } + bravais_info = { + 'short_name': 'bcc', + 'extended_name': 'body centered cubic', + 'index': 3, + 'permutation': [0, 1, 2] + } # ==============# # rhombohedral # # ==============# elif comparison_angles.count(True) == 3: # logical order is important, this check must come after the cubic cases - bravais_info = {'short_name': 'rhl', - 'extended_name': 'rhombohedral', - 'index': 11, - 'permutation': [0, 1, 2] - } + bravais_info = { + 'short_name': 'rhl', + 'extended_name': 'rhombohedral', + 'index': 11, + 'permutation': [0, 1, 2] + } if cosa > 0.: bravais_info['variation'] = 'rhl1' eta = (1. + 4. * cosa) / (2. + 4. * cosa) - bravais_info['extra'] = {'eta': eta, - 'nu': 0.75 - eta / 2., - } + bravais_info['extra'] = { + 'eta': eta, + 'nu': 0.75 - eta / 2., + } else: bravais_info['variation'] = 'rhl2' eta = 1. / (2. * (1. - cosa) / (1. + cosa)) - bravais_info['extra'] = {'eta': eta, - 'nu': 0.75 - eta / 2., - } + bravais_info['extra'] = { + 'eta': eta, + 'nu': 0.75 - eta / 2., + } # ==========================# # body centered tetragonal # # ==========================# elif comparison_angles.count(True) == 1: # two angles are the same - bravais_info = {'short_name': 'bct', - 'extended_name': 'body centered tetragonal', - 'index': 5, - } + bravais_info = { + 'short_name': 'bct', + 'extended_name': 'body centered tetragonal', + 'index': 5, + } if comparison_angles.index(True) == 0: # alfa=beta ref_ang = cosa bravais_info['permutation'] = [0, 1, 2] @@ -552,31 +534,39 @@ def a_are_equals(a, b): bravais_info['permutation'] = [1, 2, 0] if ref_ang >= 0.: - raise ValueError('Problems on the definition of ' - 'body centered tetragonal lattices') - the_c = numpy.sqrt(-4. * ref_ang * (a ** 2)) - the_a = numpy.sqrt(2. * a ** 2 - (the_c ** 2) / 2.) + raise ValueError('Problems on the definition of body centered tetragonal lattices') + the_c = numpy.sqrt(-4. * ref_ang * (a**2)) + the_a = numpy.sqrt(2. * a**2 - (the_c**2) / 2.) if the_c < the_a: bravais_info['variation'] = 'bct1' - bravais_info['extra'] = {'eta': (1. + (the_c / the_a) ** 2) / 4.} + bravais_info['extra'] = {'eta': (1. + (the_c / the_a)**2) / 4.} else: bravais_info['variation'] = 'bct2' - bravais_info['extra'] = {'eta': (1. + (the_a / the_c) ** 2) / 4., - 'csi': ((the_a / the_c) ** 2) / 2., - } + bravais_info['extra'] = { + 'eta': (1. 
+ (the_a / the_c)**2) / 4., + 'csi': ((the_a / the_c)**2) / 2., + } # ============================# # body centered orthorhombic # # ============================# - elif (any([a_are_equals(cosa, bco1), a_are_equals(cosb, bco1), a_are_equals(cosc, bco1)]) and - any([a_are_equals(cosa, bco2), a_are_equals(cosb, bco2), a_are_equals(cosc, bco2)]) and - any([a_are_equals(cosa, bco3), a_are_equals(cosb, bco3), a_are_equals(cosc, bco3)]) - ): - bravais_info = {'short_name': 'orci', - 'extended_name': 'body centered orthorhombic', - 'index': 8, - } + elif ( + any([a_are_equals(cosa, bco1), + a_are_equals(cosb, bco1), + a_are_equals(cosc, bco1)]) and + any([a_are_equals(cosa, bco2), + a_are_equals(cosb, bco2), + a_are_equals(cosc, bco2)]) and + any([a_are_equals(cosa, bco3), + a_are_equals(cosb, bco3), + a_are_equals(cosc, bco3)]) + ): + bravais_info = { + 'short_name': 'orci', + 'extended_name': 'body centered orthorhombic', + 'index': 8, + } if a_are_equals(cosa, bco1) and a_are_equals(cosc, bco3): bravais_info['permutation'] = [0, 1, 2] if a_are_equals(cosa, bco1) and a_are_equals(cosc, bco2): @@ -590,58 +580,64 @@ def a_are_equals(a, b): if a_are_equals(cosa, bco3) and a_are_equals(cosc, bco1): bravais_info['permutation'] = [2, 1, 0] - bravais_info['extra'] = {'csi': (1. + (orci_the_a / orci_the_c) ** 2) / 4., - 'eta': (1. + (orci_the_b / orci_the_c) ** 2) / 4., - 'dlt': (orci_the_b ** 2 - orci_the_a ** 2) / (4. * orci_the_c ** 2), - 'mu': (orci_the_a ** 2 + orci_the_b ** 2) / (4. * orci_the_c ** 2), - } + bravais_info['extra'] = { + 'csi': (1. + (orci_the_a / orci_the_c)**2) / 4., + 'eta': (1. + (orci_the_b / orci_the_c)**2) / 4., + 'dlt': (orci_the_b**2 - orci_the_a**2) / (4. * orci_the_c**2), + 'mu': (orci_the_a**2 + orci_the_b**2) / (4. * orci_the_c**2), + } # if it doesn't fall in the above, is triclinic else: - bravais_info = {'short_name': 'tri', - 'extended_name': 'triclinic', - 'index': 14, - } + bravais_info = { + 'short_name': 'tri', + 'extended_name': 'triclinic', + 'index': 14, + } # the check for triclinic variations is at the end of the method - - elif comparison_length.count(True) == 1: + # ============# # tetragonal # # ============# if comparison_angles.count(True) == 3 and a_are_equals(cosa, _90): - bravais_info = {'short_name': 'tet', - 'extended_name': 'tetragonal', - 'index': 4, - } - if comparison_length[0] == True: + bravais_info = { + 'short_name': 'tet', + 'extended_name': 'tetragonal', + 'index': 4, + } + if comparison_length[0]: bravais_info['permutation'] = [0, 1, 2] - if comparison_length[1] == True: + if comparison_length[1]: bravais_info['permutation'] = [2, 0, 1] - if comparison_length[2] == True: + if comparison_length[2]: bravais_info['permutation'] = [1, 2, 0] # ====================================# # c-centered orthorombic + hexagonal # # ====================================# # alpha/=beta=gamma=pi/2 - elif (comparison_angles.count(True) == 1 and - any([a_are_equals(cosa, _90), a_are_equals(cosb, _90), a_are_equals(cosc, _90)]) - ): + elif ( + comparison_angles.count(True) == 1 and + any([a_are_equals(cosa, _90), a_are_equals(cosb, _90), + a_are_equals(cosc, _90)]) + ): if any([a_are_equals(cosa, _120), a_are_equals(cosb, _120), a_are_equals(cosc, _120)]): - bravais_info = {'short_name': 'hex', - 'extended_name': 'hexagonal', - 'index': 10, - } + bravais_info = { + 'short_name': 'hex', + 'extended_name': 'hexagonal', + 'index': 10, + } else: - bravais_info = {'short_name': 'orcc', - 'extended_name': 'c-centered orthorhombic', - 'index': 9, - } - if 
comparison_length[0] == True: + bravais_info = { + 'short_name': 'orcc', + 'extended_name': 'c-centered orthorhombic', + 'index': 9, + } + if comparison_length[0]: the_a1 = a1 the_a2 = a2 - elif comparison_length[1] == True: + elif comparison_length[1]: the_a1 = a2 the_a2 = a3 else: # comparison_length[2]==True: @@ -649,38 +645,40 @@ def a_are_equals(a, b): the_a2 = a1 the_a = numpy.linalg.norm(the_a1 + the_a2) the_b = numpy.linalg.norm(the_a1 - the_a2) - bravais_info['extra'] = {'csi': (1. + (the_a / the_b) ** 2) / 4., - } + bravais_info['extra'] = { + 'csi': (1. + (the_a / the_b)**2) / 4., + } # TODO : re-check this case, permutations look weird - if comparison_length[0] == True: + if comparison_length[0]: bravais_info['permutation'] = [0, 1, 2] - if comparison_length[1] == True: + if comparison_length[1]: bravais_info['permutation'] = [2, 0, 1] - if comparison_length[2] == True: + if comparison_length[2]: bravais_info['permutation'] = [1, 2, 0] # =======================# # c-centered monoclinic # # =======================# elif comparison_angles.count(True) == 1: - bravais_info = {'short_name': 'mclc', - 'extended_name': 'c-centered monoclinic', - 'index': 13, - } + bravais_info = { + 'short_name': 'mclc', + 'extended_name': 'c-centered monoclinic', + 'index': 13, + } # TODO : re-check this case, permutations look weird - if comparison_length[0] == True: + if comparison_length[0]: bravais_info['permutation'] = [0, 1, 2] the_ka = cosa the_a1 = a1 the_a2 = a2 the_c = c - if comparison_length[1] == True: + if comparison_length[1]: bravais_info['permutation'] = [2, 0, 1] the_ka = cosb the_a1 = a2 the_a2 = a3 the_c = a - if comparison_length[2] == True: + if comparison_length[2]: bravais_info['permutation'] = [1, 2, 0] the_ka = cosc the_a1 = a3 @@ -693,94 +691,99 @@ def a_are_equals(a, b): if a_are_equals(the_ka, _90): # order matters: has to be before the check on mclc1 bravais_info['variation'] = 'mclc2' - csi = (2. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa ** 2)) - psi = 0.75 - the_a ** 2 / (4. * the_b * (1. - the_cosa ** 2)) - bravais_info['extra'] = {'csi': csi, - 'eta': 0.5 + 2. * csi * the_c * the_cosa / the_b, - 'psi': psi, - 'phi': psi + (0.75 - psi) * the_b * the_cosa / the_c, - } + csi = (2. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa**2)) + psi = 0.75 - the_a**2 / (4. * the_b * (1. - the_cosa**2)) + bravais_info['extra'] = { + 'csi': csi, + 'eta': 0.5 + 2. * csi * the_c * the_cosa / the_b, + 'psi': psi, + 'phi': psi + (0.75 - psi) * the_b * the_cosa / the_c, + } elif the_ka < 0.: bravais_info['variation'] = 'mclc1' - csi = (2. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa ** 2)) - psi = 0.75 - the_a ** 2 / (4. * the_b * (1. - the_cosa ** 2)) - bravais_info['extra'] = {'csi': csi, - 'eta': 0.5 + 2. * csi * the_c * the_cosa / the_b, - 'psi': psi, - 'phi': psi + (0.75 - psi) * the_b * the_cosa / the_c, - } + csi = (2. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa**2)) + psi = 0.75 - the_a**2 / (4. * the_b * (1. - the_cosa**2)) + bravais_info['extra'] = { + 'csi': csi, + 'eta': 0.5 + 2. * csi * the_c * the_cosa / the_b, + 'psi': psi, + 'phi': psi + (0.75 - psi) * the_b * the_cosa / the_c, + } else: # if the_ka>0.: - x = the_b * the_cosa / the_c + the_b ** 2 * (1. - the_cosa ** 2) / the_a ** 2 + x = the_b * the_cosa / the_c + the_b**2 * (1. - the_cosa**2) / the_a**2 if a_are_equals(x, 1.): bravais_info['variation'] = 'mclc4' # order matters here too - mu = (1. + (the_b / the_a) ** 2) / 4. - dlt = the_b * the_c * the_cosa / (2. 
* the_a ** 2) - csi = mu - 0.25 + (1. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa ** 2)) + mu = (1. + (the_b / the_a)**2) / 4. + dlt = the_b * the_c * the_cosa / (2. * the_a**2) + csi = mu - 0.25 + (1. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa**2)) eta = 0.5 + 2. * csi * the_c * the_cosa / the_b phi = 1. + eta - 2. * mu psi = eta - 2. * dlt - bravais_info['extra'] = {'mu': mu, - 'dlt': dlt, - 'csi': csi, - 'eta': eta, - 'phi': phi, - 'psi': psi, - } + bravais_info['extra'] = { + 'mu': mu, + 'dlt': dlt, + 'csi': csi, + 'eta': eta, + 'phi': phi, + 'psi': psi, + } elif x < 1.: bravais_info['variation'] = 'mclc3' - mu = (1. + (the_b / the_a) ** 2) / 4. - dlt = the_b * the_c * the_cosa / (2. * the_a ** 2) - csi = mu - 0.25 + (1. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa ** 2)) + mu = (1. + (the_b / the_a)**2) / 4. + dlt = the_b * the_c * the_cosa / (2. * the_a**2) + csi = mu - 0.25 + (1. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa**2)) eta = 0.5 + 2. * csi * the_c * the_cosa / the_b phi = 1. + eta - 2. * mu psi = eta - 2. * dlt - bravais_info['extra'] = {'mu': mu, - 'dlt': dlt, - 'csi': csi, - 'eta': eta, - 'phi': phi, - 'psi': psi, - } + bravais_info['extra'] = { + 'mu': mu, + 'dlt': dlt, + 'csi': csi, + 'eta': eta, + 'phi': phi, + 'psi': psi, + } elif x > 1.: bravais_info['variation'] = 'mclc5' - csi = ((the_b / the_a) ** 2 + (1. - the_b * the_cosa / the_c) / (1. - the_cosa ** 2)) / 4. + csi = ((the_b / the_a)**2 + (1. - the_b * the_cosa / the_c) / (1. - the_cosa**2)) / 4. eta = 0.5 + 2. * csi * the_c * the_cosa / the_b - mu = eta / 2. + the_b ** 2 / 4. / the_a ** 2 - the_b * the_c * the_cosa / 2. / the_a ** 2 + mu = eta / 2. + the_b**2 / 4. / the_a**2 - the_b * the_c * the_cosa / 2. / the_a**2 nu = 2. * mu - csi - omg = (4. * nu - 1. - the_b ** 2 * (1. - the_cosa ** 2) / the_a ** 2) * the_c / ( - 2. * the_b * the_cosa) + omg = (4. * nu - 1. - the_b**2 * + (1. - the_cosa**2) / the_a**2) * the_c / (2. * the_b * the_cosa) dlt = csi * the_c * the_cosa / the_b + omg / 2. - 0.25 - rho = 1. - csi * the_a ** 2 / the_b ** 2 - bravais_info['extra'] = {'mu': mu, - 'dlt': dlt, - 'csi': csi, - 'eta': eta, - 'rho': rho, - } + rho = 1. 
- csi * the_a**2 / the_b**2 + bravais_info['extra'] = { + 'mu': mu, + 'dlt': dlt, + 'csi': csi, + 'eta': eta, + 'rho': rho, + } # if it doesn't fall in the above, is triclinic else: - bravais_info = {'short_name': 'tri', - 'extended_name': 'triclinic', - 'index': 14, - } + bravais_info = { + 'short_name': 'tri', + 'extended_name': 'triclinic', + 'index': 14, + } # the check for triclinic variations is at the end of the method - - else: # if comparison_length.count(True)==0: - fco1 = c ** 2 / numpy.sqrt((a ** 2 + c ** 2) * (b ** 2 + c ** 2)) - fco2 = a ** 2 / numpy.sqrt((a ** 2 + b ** 2) * (a ** 2 + c ** 2)) - fco3 = b ** 2 / numpy.sqrt((a ** 2 + b ** 2) * (b ** 2 + c ** 2)) + fco1 = c**2 / numpy.sqrt((a**2 + c**2) * (b**2 + c**2)) + fco2 = a**2 / numpy.sqrt((a**2 + b**2) * (a**2 + c**2)) + fco3 = b**2 / numpy.sqrt((a**2 + b**2) * (b**2 + c**2)) # ==============# # orthorhombic # # ==============# if comparison_angles.count(True) == 3: - bravais_info = {'short_name': 'orc', - 'extended_name': 'orthorhombic', - 'index': 6, - } + bravais_info = { + 'short_name': 'orc', + 'extended_name': 'orthorhombic', + 'index': 6, + } lens = [a, b, c] ind_a = lens.index(min(lens)) ind_c = lens.index(max(lens)) @@ -799,12 +802,16 @@ def a_are_equals(a, b): # ============# # monoclinic # # ============# - elif (comparison_angles.count(True) == 1 and - any([a_are_equals(cosa, _90), a_are_equals(cosb, _90), a_are_equals(cosc, _90)])): - bravais_info = {'short_name': 'mcl', - 'extended_name': 'monoclinic', - 'index': 12, - } + elif ( + comparison_angles.count(True) == 1 and + any([a_are_equals(cosa, _90), a_are_equals(cosb, _90), + a_are_equals(cosc, _90)]) + ): + bravais_info = { + 'short_name': 'mcl', + 'extended_name': 'monoclinic', + 'index': 12, + } lens = [a, b, c] # find the angle different from 90 # then order (if possible) a 0.: bravais_info['variation'] = 'orcf1' - bravais_info['extra'] = {'csi': (1. + (the_a / the_b) ** 2 - (the_a / the_c) ** 2) / 4., - 'eta': (1. + (the_a / the_b) ** 2 + (the_a / the_c) ** 2) / 4., - } + bravais_info['extra'] = { + 'csi': (1. + (the_a / the_b)**2 - (the_a / the_c)**2) / 4., + 'eta': (1. + (the_a / the_b)**2 + (the_a / the_c)**2) / 4., + } # orcf2 else: bravais_info['variation'] = 'orcf2' - bravais_info['extra'] = {'eta': (1. + (the_a / the_b) ** 2 - (the_a / the_c) ** 2) / 4., - 'dlt': (1. + (the_b / the_a) ** 2 + (the_b / the_c) ** 2) / 4., - 'phi': (1. + (the_c / the_b) ** 2 - (the_c / the_a) ** 2) / 4., - } + bravais_info['extra'] = { + 'eta': (1. + (the_a / the_b)**2 - (the_a / the_c)**2) / 4., + 'dlt': (1. + (the_b / the_a)**2 + (the_b / the_c)**2) / 4., + 'phi': (1. 
+ (the_c / the_b)**2 - (the_c / the_a)**2) / 4., + } else: - bravais_info = {'short_name': 'tri', - 'extended_name': 'triclinic', - 'index': 14, - } + bravais_info = { + 'short_name': 'tri', + 'extended_name': 'triclinic', + 'index': 14, + } # ===========# # triclinic # # ===========# @@ -1015,19 +1035,21 @@ def a_are_equals(a, b): # square lattice # # ================# if comparison_angle_90 and comparison_length: - bravais_info = {'short_name': 'sq', - 'extended_name': 'square', - 'index': 1, - } + bravais_info = { + 'short_name': 'sq', + 'extended_name': 'square', + 'index': 1, + } # =========================# # (primitive) rectangular # # =========================# elif comparison_angle_90: - bravais_info = {'short_name': 'rec', - 'extended_name': 'rectangular', - 'index': 2, - } + bravais_info = { + 'short_name': 'rec', + 'extended_name': 'rectangular', + 'index': 2, + } # set the order such that first_vector < second_vector in norm if lens[0] > lens[1]: in_plane_indexes.reverse() @@ -1037,30 +1059,31 @@ def a_are_equals(a, b): # ===========# # this has to be put before the centered-rectangular case elif (l_are_equals(lens[0], lens[1]) and a_are_equals(cosphi, _120)): - bravais_info = {'short_name': 'hex', - 'extended_name': 'hexagonal', - 'index': 4, - } + bravais_info = { + 'short_name': 'hex', + 'extended_name': 'hexagonal', + 'index': 4, + } # ======================# # centered rectangular # # ======================# - elif (comparison_length and - l_are_equals(numpy.dot(vectors[0] + vectors[1], - vectors[0] - vectors[1]), 0.)): - bravais_info = {'short_name': 'recc', - 'extended_name': 'centered rectangular', - 'index': 3, - } + elif (comparison_length and l_are_equals(numpy.dot(vectors[0] + vectors[1], vectors[0] - vectors[1]), 0.)): + bravais_info = { + 'short_name': 'recc', + 'extended_name': 'centered rectangular', + 'index': 3, + } # =========# # oblique # # =========# else: - bravais_info = {'short_name': 'obl', - 'extended_name': 'oblique', - 'index': 5, - } + bravais_info = { + 'short_name': 'obl', + 'extended_name': 'oblique', + 'index': 5, + } # set the order such that first_vector < second_vector in norm if lens[0] > lens[1]: in_plane_indexes.reverse() @@ -1098,10 +1121,9 @@ def a_are_equals(a, b): return bravais_info - -def get_kpoints_path(cell, pbc=None, cartesian=False, - epsilon_length=_default_epsilon_length, - epsilon_angle=_default_epsilon_angle): +def get_kpoints_path( + cell, pbc=None, cartesian=False, epsilon_length=_default_epsilon_length, epsilon_angle=_default_epsilon_angle +): """ Get the special point and path of a given structure. 
@@ -1136,12 +1158,9 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, :note: We assume that the cell given by the cell property is the primitive unit cell """ + # pylint: disable=too-many-branches,too-many-statements # recognize which bravais lattice we are dealing with - bravais_info = find_bravais_info( - cell=cell, pbc=pbc, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle - ) + bravais_info = find_bravais_info(cell=cell, pbc=pbc, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle) analysis = analyze_cell(cell, pbc) dimension = analysis['dimension'] @@ -1153,98 +1172,108 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, # 3D case: 14 Bravais lattices # simple cubic if bravais_info['index'] == 1: - special_points = {'G': [0., 0., 0.], - 'M': [0.5, 0.5, 0.], - 'R': [0.5, 0.5, 0.5], - 'X': [0., 0.5, 0.], - } - path = [('G', 'X'), - ('X', 'M'), - ('M', 'G'), - ('G', 'R'), - ('R', 'X'), - ('M', 'R'), - ] + special_points = { + 'G': [0., 0., 0.], + 'M': [0.5, 0.5, 0.], + 'R': [0.5, 0.5, 0.5], + 'X': [0., 0.5, 0.], + } + path = [ + ('G', 'X'), + ('X', 'M'), + ('M', 'G'), + ('G', 'R'), + ('R', 'X'), + ('M', 'R'), + ] # face centered cubic elif bravais_info['index'] == 2: - special_points = {'G': [0., 0., 0.], - 'K': [3. / 8., 3. / 8., 0.75], - 'L': [0.5, 0.5, 0.5], - 'U': [5. / 8., 0.25, 5. / 8.], - 'W': [0.5, 0.25, 0.75], - 'X': [0.5, 0., 0.5], - } - path = [('G', 'X'), - ('X', 'W'), - ('W', 'K'), - ('K', 'G'), - ('G', 'L'), - ('L', 'U'), - ('U', 'W'), - ('W', 'L'), - ('L', 'K'), - ('U', 'X'), - ] + special_points = { + 'G': [0., 0., 0.], + 'K': [3. / 8., 3. / 8., 0.75], + 'L': [0.5, 0.5, 0.5], + 'U': [5. / 8., 0.25, 5. / 8.], + 'W': [0.5, 0.25, 0.75], + 'X': [0.5, 0., 0.5], + } + path = [ + ('G', 'X'), + ('X', 'W'), + ('W', 'K'), + ('K', 'G'), + ('G', 'L'), + ('L', 'U'), + ('U', 'W'), + ('W', 'L'), + ('L', 'K'), + ('U', 'X'), + ] # body centered cubic elif bravais_info['index'] == 3: - special_points = {'G': [0., 0., 0.], - 'H': [0.5, -0.5, 0.5], - 'P': [0.25, 0.25, 0.25], - 'N': [0., 0., 0.5], - } - path = [('G', 'H'), - ('H', 'N'), - ('N', 'G'), - ('G', 'P'), - ('P', 'H'), - ('P', 'N'), - ] + special_points = { + 'G': [0., 0., 0.], + 'H': [0.5, -0.5, 0.5], + 'P': [0.25, 0.25, 0.25], + 'N': [0., 0., 0.5], + } + path = [ + ('G', 'H'), + ('H', 'N'), + ('N', 'G'), + ('G', 'P'), + ('P', 'H'), + ('P', 'N'), + ] # Tetragonal elif bravais_info['index'] == 4: - special_points = {'G': [0., 0., 0.], - 'A': [0.5, 0.5, 0.5], - 'M': [0.5, 0.5, 0.], - 'R': [0., 0.5, 0.5], - 'X': [0., 0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'X'), - ('X', 'M'), - ('M', 'G'), - ('G', 'Z'), - ('Z', 'R'), - ('R', 'A'), - ('A', 'Z'), - ('X', 'R'), - ('M', 'A'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0.5, 0.5, 0.5], + 'M': [0.5, 0.5, 0.], + 'R': [0., 0.5, 0.5], + 'X': [0., 0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'X'), + ('X', 'M'), + ('M', 'G'), + ('G', 'Z'), + ('Z', 'R'), + ('R', 'A'), + ('A', 'Z'), + ('X', 'R'), + ('M', 'A'), + ] # body centered tetragonal elif bravais_info['index'] == 5: if bravais_info['variation'] == 'bct1': # Body centered tetragonal bct1 eta = bravais_info['extra']['eta'] - special_points = {'G': [0., 0., 0.], - 'M': [-0.5, 0.5, 0.5], - 'N': [0., 0.5, 0.], - 'P': [0.25, 0.25, 0.25], - 'X': [0., 0., 0.5], - 'Z': [eta, eta, -eta], - 'Z1': [-eta, 1. 
- eta, eta], - } - path = [('G', 'X'), - ('X', 'M'), - ('M', 'G'), - ('G', 'Z'), - ('Z', 'P'), - ('P', 'N'), - ('N', 'Z1'), - ('Z1', 'M'), - ('X', 'P'), - ] + special_points = { + 'G': [0., 0., 0.], + 'M': [-0.5, 0.5, 0.5], + 'N': [0., 0.5, 0.], + 'P': [0.25, 0.25, 0.25], + 'X': [0., 0., 0.5], + 'Z': [eta, eta, -eta], + 'Z1': [-eta, 1. - eta, eta], + } + path = [ + ('G', 'X'), + ('X', 'M'), + ('M', 'G'), + ('G', 'Z'), + ('Z', 'P'), + ('P', 'N'), + ('N', 'Z1'), + ('Z1', 'M'), + ('X', 'P'), + ] else: # bct2 # Body centered tetragonal bct2 @@ -1261,127 +1290,136 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, 'Y1': [0.5, 0.5, -csi], 'Z': [0.5, 0.5, -0.5], } - path = [('G', 'X'), - ('X', 'Y'), - ('Y', 'S'), - ('S', 'G'), - ('G', 'Z'), - ('Z', 'S1'), - ('S1', 'N'), - ('N', 'P'), - ('P', 'Y1'), - ('Y1', 'Z'), - ('X', 'P'), - ] + path = [ + ('G', 'X'), + ('X', 'Y'), + ('Y', 'S'), + ('S', 'G'), + ('G', 'Z'), + ('Z', 'S1'), + ('S1', 'N'), + ('N', 'P'), + ('P', 'Y1'), + ('Y1', 'Z'), + ('X', 'P'), + ] # orthorhombic elif bravais_info['index'] == 6: - special_points = {'G': [0., 0., 0.], - 'R': [0.5, 0.5, 0.5], - 'S': [0.5, 0.5, 0.], - 'T': [0., 0.5, 0.5], - 'U': [0.5, 0., 0.5], - 'X': [0.5, 0., 0.], - 'Y': [0., 0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'X'), - ('X', 'S'), - ('S', 'Y'), - ('Y', 'G'), - ('G', 'Z'), - ('Z', 'U'), - ('U', 'R'), - ('R', 'T'), - ('T', 'Z'), - ('Y', 'T'), - ('U', 'X'), - ('S', 'R'), - ] + special_points = { + 'G': [0., 0., 0.], + 'R': [0.5, 0.5, 0.5], + 'S': [0.5, 0.5, 0.], + 'T': [0., 0.5, 0.5], + 'U': [0.5, 0., 0.5], + 'X': [0.5, 0., 0.], + 'Y': [0., 0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'X'), + ('X', 'S'), + ('S', 'Y'), + ('Y', 'G'), + ('G', 'Z'), + ('Z', 'U'), + ('U', 'R'), + ('R', 'T'), + ('T', 'Z'), + ('Y', 'T'), + ('U', 'X'), + ('S', 'R'), + ] # face centered orthorhombic elif bravais_info['index'] == 7: if bravais_info['variation'] == 'orcf1': csi = bravais_info['extra']['csi'] eta = bravais_info['extra']['eta'] - special_points = {'G': [0., 0., 0.], - 'A': [0.5, 0.5 + csi, csi], - 'A1': [0.5, 0.5 - csi, 1. - csi], - 'L': [0.5, 0.5, 0.5], - 'T': [1., 0.5, 0.5], - 'X': [0., eta, eta], - 'X1': [1., 1. - eta, 1. - eta], - 'Y': [0.5, 0., 0.5], - 'Z': [0.5, 0.5, 0.], - } - path = [('G', 'Y'), - ('Y', 'T'), - ('T', 'Z'), - ('Z', 'G'), - ('G', 'X'), - ('X', 'A1'), - ('A1', 'Y'), - ('T', 'X1'), - ('X', 'A'), - ('A', 'Z'), - ('L', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0.5, 0.5 + csi, csi], + 'A1': [0.5, 0.5 - csi, 1. - csi], + 'L': [0.5, 0.5, 0.5], + 'T': [1., 0.5, 0.5], + 'X': [0., eta, eta], + 'X1': [1., 1. - eta, 1. - eta], + 'Y': [0.5, 0., 0.5], + 'Z': [0.5, 0.5, 0.], + } + path = [ + ('G', 'Y'), + ('Y', 'T'), + ('T', 'Z'), + ('Z', 'G'), + ('G', 'X'), + ('X', 'A1'), + ('A1', 'Y'), + ('T', 'X1'), + ('X', 'A'), + ('A', 'Z'), + ('L', 'G'), + ] elif bravais_info['variation'] == 'orcf2': eta = bravais_info['extra']['eta'] dlt = bravais_info['extra']['dlt'] phi = bravais_info['extra']['phi'] - special_points = {'G': [0., 0., 0.], - 'C': [0.5, 0.5 - eta, 1. - eta], - 'C1': [0.5, 0.5 + eta, eta], - 'D': [0.5 - dlt, 0.5, 1. - dlt], - 'D1': [0.5 + dlt, 0.5, dlt], - 'L': [0.5, 0.5, 0.5], - 'H': [1. 
- phi, 0.5 - phi, 0.5], - 'H1': [phi, 0.5 + phi, 0.5], - 'X': [0., 0.5, 0.5], - 'Y': [0.5, 0., 0.5], - 'Z': [0.5, 0.5, 0.], - } - path = [('G', 'Y'), - ('Y', 'C'), - ('C', 'D'), - ('D', 'X'), - ('X', 'G'), - ('G', 'Z'), - ('Z', 'D1'), - ('D1', 'H'), - ('H', 'C'), - ('C1', 'Z'), - ('X', 'H1'), - ('H', 'Y'), - ('L', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'C': [0.5, 0.5 - eta, 1. - eta], + 'C1': [0.5, 0.5 + eta, eta], + 'D': [0.5 - dlt, 0.5, 1. - dlt], + 'D1': [0.5 + dlt, 0.5, dlt], + 'L': [0.5, 0.5, 0.5], + 'H': [1. - phi, 0.5 - phi, 0.5], + 'H1': [phi, 0.5 + phi, 0.5], + 'X': [0., 0.5, 0.5], + 'Y': [0.5, 0., 0.5], + 'Z': [0.5, 0.5, 0.], + } + path = [ + ('G', 'Y'), + ('Y', 'C'), + ('C', 'D'), + ('D', 'X'), + ('X', 'G'), + ('G', 'Z'), + ('Z', 'D1'), + ('D1', 'H'), + ('H', 'C'), + ('C1', 'Z'), + ('X', 'H1'), + ('H', 'Y'), + ('L', 'G'), + ] else: csi = bravais_info['extra']['csi'] eta = bravais_info['extra']['eta'] - special_points = {'G': [0., 0., 0.], - 'A': [0.5, 0.5 + csi, csi], - 'A1': [0.5, 0.5 - csi, 1. - csi], - 'L': [0.5, 0.5, 0.5], - 'T': [1., 0.5, 0.5], - 'X': [0., eta, eta], - 'X1': [1., 1. - eta, 1. - eta], - 'Y': [0.5, 0., 0.5], - 'Z': [0.5, 0.5, 0.], - } - path = [('G', 'Y'), - ('Y', 'T'), - ('T', 'Z'), - ('Z', 'G'), - ('G', 'X'), - ('X', 'A1'), - ('A1', 'Y'), - ('X', 'A'), - ('A', 'Z'), - ('L', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0.5, 0.5 + csi, csi], + 'A1': [0.5, 0.5 - csi, 1. - csi], + 'L': [0.5, 0.5, 0.5], + 'T': [1., 0.5, 0.5], + 'X': [0., eta, eta], + 'X1': [1., 1. - eta, 1. - eta], + 'Y': [0.5, 0., 0.5], + 'Z': [0.5, 0.5, 0.], + } + path = [ + ('G', 'Y'), + ('Y', 'T'), + ('T', 'Z'), + ('Z', 'G'), + ('G', 'X'), + ('X', 'A1'), + ('A1', 'Y'), + ('X', 'A'), + ('A', 'Z'), + ('L', 'G'), + ] # Body centered orthorhombic elif bravais_info['index'] == 8: @@ -1389,168 +1427,180 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, dlt = bravais_info['extra']['dlt'] eta = bravais_info['extra']['eta'] mu = bravais_info['extra']['mu'] - special_points = {'G': [0., 0., 0.], - 'L': [-mu, mu, 0.5 - dlt], - 'L1': [mu, -mu, 0.5 + dlt], - 'L2': [0.5 - dlt, 0.5 + dlt, -mu], - 'R': [0., 0.5, 0.], - 'S': [0.5, 0., 0.], - 'T': [0., 0., 0.5], - 'W': [0.25, 0.25, 0.25], - 'X': [-csi, csi, csi], - 'X1': [csi, 1. - csi, -csi], - 'Y': [eta, -eta, eta], - 'Y1': [1. - eta, eta, -eta], - 'Z': [0.5, 0.5, -0.5], - } - path = [('G', 'X'), - ('X', 'L'), - ('L', 'T'), - ('T', 'W'), - ('W', 'R'), - ('R', 'X1'), - ('X1', 'Z'), - ('Z', 'G'), - ('G', 'Y'), - ('Y', 'S'), - ('S', 'W'), - ('L1', 'Y'), - ('Y1', 'Z'), - ] + special_points = { + 'G': [0., 0., 0.], + 'L': [-mu, mu, 0.5 - dlt], + 'L1': [mu, -mu, 0.5 + dlt], + 'L2': [0.5 - dlt, 0.5 + dlt, -mu], + 'R': [0., 0.5, 0.], + 'S': [0.5, 0., 0.], + 'T': [0., 0., 0.5], + 'W': [0.25, 0.25, 0.25], + 'X': [-csi, csi, csi], + 'X1': [csi, 1. - csi, -csi], + 'Y': [eta, -eta, eta], + 'Y1': [1. - eta, eta, -eta], + 'Z': [0.5, 0.5, -0.5], + } + path = [ + ('G', 'X'), + ('X', 'L'), + ('L', 'T'), + ('T', 'W'), + ('W', 'R'), + ('R', 'X1'), + ('X1', 'Z'), + ('Z', 'G'), + ('G', 'Y'), + ('Y', 'S'), + ('S', 'W'), + ('L1', 'Y'), + ('Y1', 'Z'), + ] # C-centered orthorhombic elif bravais_info['index'] == 9: csi = bravais_info['extra']['csi'] - special_points = {'G': [0., 0., 0.], - 'A': [csi, csi, 0.5], - 'A1': [-csi, 1. - csi, 0.5], - 'R': [0., 0.5, 0.5], - 'S': [0., 0.5, 0.], - 'T': [-0.5, 0.5, 0.5], - 'X': [csi, csi, 0.], - 'X1': [-csi, 1. 
- csi, 0.], - 'Y': [-0.5, 0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'X'), - ('X', 'S'), - ('S', 'R'), - ('R', 'A'), - ('A', 'Z'), - ('Z', 'G'), - ('G', 'Y'), - ('Y', 'X1'), - ('X1', 'A1'), - ('A1', 'T'), - ('T', 'Y'), - ('Z', 'T'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [csi, csi, 0.5], + 'A1': [-csi, 1. - csi, 0.5], + 'R': [0., 0.5, 0.5], + 'S': [0., 0.5, 0.], + 'T': [-0.5, 0.5, 0.5], + 'X': [csi, csi, 0.], + 'X1': [-csi, 1. - csi, 0.], + 'Y': [-0.5, 0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'X'), + ('X', 'S'), + ('S', 'R'), + ('R', 'A'), + ('A', 'Z'), + ('Z', 'G'), + ('G', 'Y'), + ('Y', 'X1'), + ('X1', 'A1'), + ('A1', 'T'), + ('T', 'Y'), + ('Z', 'T'), + ] # Hexagonal elif bravais_info['index'] == 10: - special_points = {'G': [0., 0., 0.], - 'A': [0., 0., 0.5], - 'H': [1. / 3., 1. / 3., 0.5], - 'K': [1. / 3., 1. / 3., 0.], - 'L': [0.5, 0., 0.5], - 'M': [0.5, 0., 0.], - } - path = [('G', 'M'), - ('M', 'K'), - ('K', 'G'), - ('G', 'A'), - ('A', 'L'), - ('L', 'H'), - ('H', 'A'), - ('L', 'M'), - ('K', 'H'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0., 0., 0.5], + 'H': [1. / 3., 1. / 3., 0.5], + 'K': [1. / 3., 1. / 3., 0.], + 'L': [0.5, 0., 0.5], + 'M': [0.5, 0., 0.], + } + path = [ + ('G', 'M'), + ('M', 'K'), + ('K', 'G'), + ('G', 'A'), + ('A', 'L'), + ('L', 'H'), + ('H', 'A'), + ('L', 'M'), + ('K', 'H'), + ] # rhombohedral elif bravais_info['index'] == 11: if bravais_info['variation'] == 'rhl1': eta = bravais_info['extra']['eta'] nu = bravais_info['extra']['nu'] - special_points = {'G': [0., 0., 0.], - 'B': [eta, 0.5, 1. - eta], - 'B1': [0.5, 1. - eta, eta - 1.], - 'F': [0.5, 0.5, 0.], - 'L': [0.5, 0., 0.], - 'L1': [0., 0., -0.5], - 'P': [eta, nu, nu], - 'P1': [1. - nu, 1. - nu, 1. - eta], - 'P2': [nu, nu, eta - 1.], - 'Q': [1. - nu, nu, 0.], - 'X': [nu, 0., -nu], - 'Z': [0.5, 0.5, 0.5], - } - path = [('G', 'L'), - ('L', 'B1'), - ('B', 'Z'), - ('Z', 'G'), - ('G', 'X'), - ('Q', 'F'), - ('F', 'P1'), - ('P1', 'Z'), - ('L', 'P'), - ] + special_points = { + 'G': [0., 0., 0.], + 'B': [eta, 0.5, 1. - eta], + 'B1': [0.5, 1. - eta, eta - 1.], + 'F': [0.5, 0.5, 0.], + 'L': [0.5, 0., 0.], + 'L1': [0., 0., -0.5], + 'P': [eta, nu, nu], + 'P1': [1. - nu, 1. - nu, 1. - eta], + 'P2': [nu, nu, eta - 1.], + 'Q': [1. - nu, nu, 0.], + 'X': [nu, 0., -nu], + 'Z': [0.5, 0.5, 0.5], + } + path = [ + ('G', 'L'), + ('L', 'B1'), + ('B', 'Z'), + ('Z', 'G'), + ('G', 'X'), + ('Q', 'F'), + ('F', 'P1'), + ('P1', 'Z'), + ('L', 'P'), + ] else: # Rhombohedral rhl2 eta = bravais_info['extra']['eta'] nu = bravais_info['extra']['nu'] - special_points = {'G': [0., 0., 0.], - 'F': [0.5, -0.5, 0.], - 'L': [0.5, 0., 0.], - 'P': [1. - nu, -nu, 1. - nu], - 'P1': [nu, nu - 1., nu - 1.], - 'Q': [eta, eta, eta], - 'Q1': [1. - eta, -eta, -eta], - 'Z': [0.5, -0.5, 0.5], - } - path = [('G', 'P'), - ('P', 'Z'), - ('Z', 'Q'), - ('Q', 'G'), - ('G', 'F'), - ('F', 'P1'), - ('P1', 'Q1'), - ('Q1', 'L'), - ('L', 'Z'), - ] + special_points = { + 'G': [0., 0., 0.], + 'F': [0.5, -0.5, 0.], + 'L': [0.5, 0., 0.], + 'P': [1. - nu, -nu, 1. - nu], + 'P1': [nu, nu - 1., nu - 1.], + 'Q': [eta, eta, eta], + 'Q1': [1. 
- eta, -eta, -eta], + 'Z': [0.5, -0.5, 0.5], + } + path = [ + ('G', 'P'), + ('P', 'Z'), + ('Z', 'Q'), + ('Q', 'G'), + ('G', 'F'), + ('F', 'P1'), + ('P1', 'Q1'), + ('Q1', 'L'), + ('L', 'Z'), + ] # monoclinic elif bravais_info['index'] == 12: eta = bravais_info['extra']['eta'] nu = bravais_info['extra']['nu'] - special_points = {'G': [0., 0., 0.], - 'A': [0.5, 0.5, 0.], - 'C': [0., 0.5, 0.5], - 'D': [0.5, 0., 0.5], - 'D1': [0.5, 0., -0.5], - 'E': [0.5, 0.5, 0.5], - 'H': [0., eta, 1. - nu], - 'H1': [0., 1. - eta, nu], - 'H2': [0., eta, -nu], - 'M': [0.5, eta, 1. - nu], - 'M1': [0.5, 1. - eta, nu], - 'M2': [0.5, eta, -nu], - 'X': [0., 0.5, 0.], - 'Y': [0., 0., 0.5], - 'Y1': [0., 0., -0.5], - 'Z': [0.5, 0., 0.], - } - path = [('G', 'Y'), - ('Y', 'H'), - ('H', 'C'), - ('C', 'E'), - ('E', 'M1'), - ('M1', 'A'), - ('A', 'X'), - ('X', 'H1'), - ('M', 'D'), - ('D', 'Z'), - ('Y', 'D'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0.5, 0.5, 0.], + 'C': [0., 0.5, 0.5], + 'D': [0.5, 0., 0.5], + 'D1': [0.5, 0., -0.5], + 'E': [0.5, 0.5, 0.5], + 'H': [0., eta, 1. - nu], + 'H1': [0., 1. - eta, nu], + 'H2': [0., eta, -nu], + 'M': [0.5, eta, 1. - nu], + 'M1': [0.5, 1. - eta, nu], + 'M2': [0.5, eta, -nu], + 'X': [0., 0.5, 0.], + 'Y': [0., 0., 0.5], + 'Y1': [0., 0., -0.5], + 'Z': [0.5, 0., 0.], + } + path = [ + ('G', 'Y'), + ('Y', 'H'), + ('H', 'C'), + ('C', 'E'), + ('E', 'M1'), + ('M1', 'A'), + ('A', 'X'), + ('X', 'H1'), + ('M', 'D'), + ('D', 'Z'), + ('Y', 'D'), + ] elif bravais_info['index'] == 13: if bravais_info['variation'] == 'mclc1': @@ -1558,68 +1608,72 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, eta = bravais_info['extra']['eta'] psi = bravais_info['extra']['psi'] phi = bravais_info['extra']['phi'] - special_points = {'G': [0., 0., 0.], - 'N': [0.5, 0., 0.], - 'N1': [0., -0.5, 0.], - 'F': [1. - csi, 1. - csi, 1. - eta], - 'F1': [csi, csi, eta], - 'F2': [csi, -csi, 1. - eta], - 'F3': [1. - csi, -csi, 1. - eta], - 'I': [phi, 1. - phi, 0.5], - 'I1': [1. - phi, phi - 1., 0.5], - 'L': [0.5, 0.5, 0.5], - 'M': [0.5, 0., 0.5], - 'X': [1. - psi, psi - 1., 0.], - 'X1': [psi, 1. - psi, 0.], - 'X2': [psi - 1., -psi, 0.], - 'Y': [0.5, 0.5, 0.], - 'Y1': [-0.5, -0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'L'), - ('L', 'I'), - ('I1', 'Z'), - ('Z', 'F1'), - ('Y', 'X1'), - ('X', 'G'), - ('G', 'N'), - ('M', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'N': [0.5, 0., 0.], + 'N1': [0., -0.5, 0.], + 'F': [1. - csi, 1. - csi, 1. - eta], + 'F1': [csi, csi, eta], + 'F2': [csi, -csi, 1. - eta], + 'F3': [1. - csi, -csi, 1. - eta], + 'I': [phi, 1. - phi, 0.5], + 'I1': [1. - phi, phi - 1., 0.5], + 'L': [0.5, 0.5, 0.5], + 'M': [0.5, 0., 0.5], + 'X': [1. - psi, psi - 1., 0.], + 'X1': [psi, 1. - psi, 0.], + 'X2': [psi - 1., -psi, 0.], + 'Y': [0.5, 0.5, 0.], + 'Y1': [-0.5, -0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'L'), + ('L', 'I'), + ('I1', 'Z'), + ('Z', 'F1'), + ('Y', 'X1'), + ('X', 'G'), + ('G', 'N'), + ('M', 'G'), + ] elif bravais_info['variation'] == 'mclc2': csi = bravais_info['extra']['csi'] eta = bravais_info['extra']['eta'] psi = bravais_info['extra']['psi'] phi = bravais_info['extra']['phi'] - special_points = {'G': [0., 0., 0.], - 'N': [0.5, 0., 0.], - 'N1': [0., -0.5, 0.], - 'F': [1. - csi, 1. - csi, 1. - eta], - 'F1': [csi, csi, eta], - 'F2': [csi, -csi, 1. - eta], - 'F3': [1. - csi, -csi, 1. - eta], - 'I': [phi, 1. - phi, 0.5], - 'I1': [1. 
- phi, phi - 1., 0.5], - 'L': [0.5, 0.5, 0.5], - 'M': [0.5, 0., 0.5], - 'X': [1. - psi, psi - 1., 0.], - 'X1': [psi, 1. - psi, 0.], - 'X2': [psi - 1., -psi, 0.], - 'Y': [0.5, 0.5, 0.], - 'Y1': [-0.5, -0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'L'), - ('L', 'I'), - ('I1', 'Z'), - ('Z', 'F1'), - ('N', 'G'), - ('G', 'M'), - ] + special_points = { + 'G': [0., 0., 0.], + 'N': [0.5, 0., 0.], + 'N1': [0., -0.5, 0.], + 'F': [1. - csi, 1. - csi, 1. - eta], + 'F1': [csi, csi, eta], + 'F2': [csi, -csi, 1. - eta], + 'F3': [1. - csi, -csi, 1. - eta], + 'I': [phi, 1. - phi, 0.5], + 'I1': [1. - phi, phi - 1., 0.5], + 'L': [0.5, 0.5, 0.5], + 'M': [0.5, 0., 0.5], + 'X': [1. - psi, psi - 1., 0.], + 'X1': [psi, 1. - psi, 0.], + 'X2': [psi - 1., -psi, 0.], + 'Y': [0.5, 0.5, 0.], + 'Y1': [-0.5, -0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'L'), + ('L', 'I'), + ('I1', 'Z'), + ('Z', 'F1'), + ('N', 'G'), + ('G', 'M'), + ] elif bravais_info['variation'] == 'mclc3': mu = bravais_info['extra']['mu'] @@ -1647,18 +1701,19 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, 'Y3': [mu, mu - 1., dlt], 'Z': [0., 0., 0.5], } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'H'), - ('H', 'Z'), - ('Z', 'I'), - ('I', 'F1'), - ('H1', 'Y1'), - ('Y1', 'X'), - ('X', 'F'), - ('G', 'N'), - ('M', 'G'), - ] + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'H'), + ('H', 'Z'), + ('Z', 'I'), + ('I', 'F1'), + ('H1', 'Y1'), + ('Y1', 'X'), + ('X', 'F'), + ('G', 'N'), + ('M', 'G'), + ] elif bravais_info['variation'] == 'mclc4': mu = bravais_info['extra']['mu'] @@ -1667,35 +1722,37 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, eta = bravais_info['extra']['eta'] phi = bravais_info['extra']['phi'] psi = bravais_info['extra']['psi'] - special_points = {'G': [0., 0., 0.], - 'F': [1. - phi, 1 - phi, 1. - psi], - 'F1': [phi, phi - 1., psi], - 'F2': [1. - phi, -phi, 1. - psi], - 'H': [csi, csi, eta], - 'H1': [1. - csi, -csi, 1. - eta], - 'H2': [-csi, -csi, 1. - eta], - 'I': [0.5, -0.5, 0.5], - 'M': [0.5, 0., 0.5], - 'N': [0.5, 0., 0.], - 'N1': [0., -0.5, 0.], - 'X': [0.5, -0.5, 0.], - 'Y': [mu, mu, dlt], - 'Y1': [1. - mu, -mu, -dlt], - 'Y2': [-mu, -mu, -dlt], - 'Y3': [mu, mu - 1., dlt], - 'Z': [0., 0., 0.5], - } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'H'), - ('H', 'Z'), - ('Z', 'I'), - ('H1', 'Y1'), - ('Y1', 'X'), - ('X', 'G'), - ('G', 'N'), - ('M', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'F': [1. - phi, 1 - phi, 1. - psi], + 'F1': [phi, phi - 1., psi], + 'F2': [1. - phi, -phi, 1. - psi], + 'H': [csi, csi, eta], + 'H1': [1. - csi, -csi, 1. - eta], + 'H2': [-csi, -csi, 1. - eta], + 'I': [0.5, -0.5, 0.5], + 'M': [0.5, 0., 0.5], + 'N': [0.5, 0., 0.], + 'N1': [0., -0.5, 0.], + 'X': [0.5, -0.5, 0.], + 'Y': [mu, mu, dlt], + 'Y1': [1. 
- mu, -mu, -dlt], + 'Y2': [-mu, -mu, -dlt], + 'Y3': [mu, mu - 1., dlt], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'H'), + ('H', 'Z'), + ('Z', 'I'), + ('H1', 'Y1'), + ('Y1', 'X'), + ('X', 'G'), + ('G', 'N'), + ('M', 'G'), + ] else: csi = bravais_info['extra']['csi'] @@ -1726,85 +1783,94 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, 'Y3': [mu, mu - 1., dlt], 'Z': [0., 0., 0.5], } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'L'), - ('L', 'I'), - ('I1', 'Z'), - ('Z', 'H'), - ('H', 'F1'), - ('H1', 'Y1'), - ('Y1', 'X'), - ('X', 'G'), - ('G', 'N'), - ('M', 'G'), - ] + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'L'), + ('L', 'I'), + ('I1', 'Z'), + ('Z', 'H'), + ('H', 'F1'), + ('H1', 'Y1'), + ('Y1', 'X'), + ('X', 'G'), + ('G', 'N'), + ('M', 'G'), + ] # triclinic elif bravais_info['index'] == 14: if bravais_info['variation'] == 'tri1a' or bravais_info['variation'] == 'tri2a': - special_points = {'G': [0.0, 0.0, 0.0], - 'L': [0.5, 0.5, 0.0], - 'M': [0.0, 0.5, 0.5], - 'N': [0.5, 0.0, 0.5], - 'R': [0.5, 0.5, 0.5], - 'X': [0.5, 0.0, 0.0], - 'Y': [0.0, 0.5, 0.0], - 'Z': [0.0, 0.0, 0.5], - } - path = [('X', 'G'), - ('G', 'Y'), - ('L', 'G'), - ('G', 'Z'), - ('N', 'G'), - ('G', 'M'), - ('R', 'G'), - ] + special_points = { + 'G': [0.0, 0.0, 0.0], + 'L': [0.5, 0.5, 0.0], + 'M': [0.0, 0.5, 0.5], + 'N': [0.5, 0.0, 0.5], + 'R': [0.5, 0.5, 0.5], + 'X': [0.5, 0.0, 0.0], + 'Y': [0.0, 0.5, 0.0], + 'Z': [0.0, 0.0, 0.5], + } + path = [ + ('X', 'G'), + ('G', 'Y'), + ('L', 'G'), + ('G', 'Z'), + ('N', 'G'), + ('G', 'M'), + ('R', 'G'), + ] else: - special_points = {'G': [0.0, 0.0, 0.0], - 'L': [0.5, -0.5, 0.0], - 'M': [0.0, 0.0, 0.5], - 'N': [-0.5, -0.5, 0.5], - 'R': [0.0, -0.5, 0.5], - 'X': [0.0, -0.5, 0.0], - 'Y': [0.5, 0.0, 0.0], - 'Z': [-0.5, 0.0, 0.5], - } - path = [('X', 'G'), - ('G', 'Y'), - ('L', 'G'), - ('G', 'Z'), - ('N', 'G'), - ('G', 'M'), - ('R', 'G'), - ] + special_points = { + 'G': [0.0, 0.0, 0.0], + 'L': [0.5, -0.5, 0.0], + 'M': [0.0, 0.0, 0.5], + 'N': [-0.5, -0.5, 0.5], + 'R': [0.0, -0.5, 0.5], + 'X': [0.0, -0.5, 0.0], + 'Y': [0.5, 0.0, 0.0], + 'Z': [-0.5, 0.0, 0.5], + } + path = [ + ('X', 'G'), + ('G', 'Y'), + ('L', 'G'), + ('G', 'Z'), + ('N', 'G'), + ('G', 'M'), + ('R', 'G'), + ] elif dimension == 2: # 2D case: 5 Bravais lattices if bravais_info['index'] == 1: # square - special_points = {'G': [0., 0., 0.], - 'M': [0.5, 0.5, 0.], - 'X': [0.5, 0., 0.], - } - path = [('G', 'X'), - ('X', 'M'), - ('M', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'M': [0.5, 0.5, 0.], + 'X': [0.5, 0., 0.], + } + path = [ + ('G', 'X'), + ('X', 'M'), + ('M', 'G'), + ] elif bravais_info['index'] == 2: # (primitive) rectangular - special_points = {'G': [0., 0., 0.], - 'X': [0.5, 0., 0.], - 'Y': [0., 0.5, 0.], - 'S': [0.5, 0.5, 0.], - } - path = [('G', 'X'), - ('X', 'S'), - ('S', 'Y'), - ('Y', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'X': [0.5, 0., 0.], + 'Y': [0., 0.5, 0.], + 'S': [0.5, 0.5, 0.], + } + path = [ + ('G', 'X'), + ('X', 'S'), + ('S', 'Y'), + ('Y', 'G'), + ] elif bravais_info['index'] == 3: # centered rectangular (rhombic) @@ -1815,57 +1881,67 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, # coordinates (primitive reciprocal cell) as for the rest. # Ramirez & Bohn gave them initially in (s1=b1+b2, s2=-b1+b2) # coordinates, i.e. using the conventional reciprocal cell. - special_points = {'G': [0., 0., 0.], - 'X': [0.5, 0.5, 0.], - 'Y1': [0.25, 0.75, 0.], - 'Y': [-0.25, 0.25, 0.], # typo in p. 
404 of Ramirez & Bohm (should be Y=(0,1/4)) - 'C': [0., 0.5, 0.], - } - path = [('Y1', 'X'), - ('X', 'G'), - ('G', 'Y'), - ('Y', 'C'), - ] + special_points = { + 'G': [0., 0., 0.], + 'X': [0.5, 0.5, 0.], + 'Y1': [0.25, 0.75, 0.], + 'Y': [-0.25, 0.25, 0.], # typo in p. 404 of Ramirez & Bohm (should be Y=(0,1/4)) + 'C': [0., 0.5, 0.], + } + path = [ + ('Y1', 'X'), + ('X', 'G'), + ('G', 'Y'), + ('Y', 'C'), + ] elif bravais_info['index'] == 4: # hexagonal - special_points = {'G': [0., 0., 0.], - 'M': [0.5, 0., 0.], - 'K': [1. / 3., 1. / 3., 0.], - } - path = [('G', 'M'), - ('M', 'K'), - ('K', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'M': [0.5, 0., 0.], + 'K': [1. / 3., 1. / 3., 0.], + } + path = [ + ('G', 'M'), + ('M', 'K'), + ('K', 'G'), + ] elif bravais_info['index'] == 5: # oblique # NOTE: only end-points are high-symmetry points (not the path # in-between) - special_points = {'G': [0., 0., 0.], - 'X': [0.5, 0., 0.], - 'Y': [0., 0.5, 0.], - 'A': [0.5, 0.5, 0.], - } - path = [('X', 'G'), - ('G', 'Y'), - ('A', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'X': [0.5, 0., 0.], + 'Y': [0., 0.5, 0.], + 'A': [0.5, 0.5, 0.], + } + path = [ + ('X', 'G'), + ('G', 'Y'), + ('A', 'G'), + ] elif dimension == 1: # 1D case: 1 Bravais lattice - special_points = {'G': [0., 0., 0.], - 'X': [0.5, 0., 0.], - } - path = [('G', 'X'), - ] + special_points = { + 'G': [0., 0., 0.], + 'X': [0.5, 0., 0.], + } + path = [ + ('G', 'X'), + ] elif dimension == 0: # 0D case: 1 Bravais lattice, only Gamma point, no path - special_points = {'G': [0., 0., 0.], - } - path = [('G', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + } + path = [ + ('G', 'G'), + ] permutation = bravais_info['permutation'] @@ -1873,23 +1949,19 @@ def permute(x, permutation): # return new_x such that new_x[i]=x[permutation[i]] return [x[int(p)] for p in permutation] - def invpermute(permutation): - # return the inverse of permutation - return [permutation.index(i) for i in range(3)] - the_special_points = {} - for k in special_points.keys(): + for key in special_points: # NOTE: this originally returned the inverse of the permutation, but was later changed to permutation - the_special_points[k] = permute(special_points[k], permutation) + the_special_points[key] = permute(special_points[key], permutation) # output crystal or cartesian if cartesian: the_abs_special_points = {} - for k in the_special_points.keys(): - the_abs_special_points[k] = change_reference( - reciprocal_cell, numpy.array(the_special_points[k]), to_cartesian=True + for key in the_special_points: + the_abs_special_points[key] = change_reference( + reciprocal_cell, numpy.array(the_special_points[key]), to_cartesian=True ) return the_abs_special_points, path, bravais_info - else: - return the_special_points, path, bravais_info + + return the_special_points, path, bravais_info diff --git a/aiida/tools/data/array/kpoints/seekpath.py b/aiida/tools/data/array/kpoints/seekpath.py index 176852fef0..9aa9dc36bb 100644 --- a/aiida/tools/data/array/kpoints/seekpath.py +++ b/aiida/tools/data/array/kpoints/seekpath.py @@ -7,22 +7,12 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tool to automatically determine k-points for a given structure using SeeK-path.""" +import seekpath from aiida.orm import KpointsData, Dict -__all__ = ('check_seekpath_is_installed', 'get_explicit_kpoints_path', 'get_kpoints_path') - - 
-def check_seekpath_is_installed(): - """ - Tries to import the Seekpath module. Raise ImportError if it cannot be imported - - :raises: ImportError - """ - try: - import seekpath - except ImportError: - raise ImportError("Seekpath is not installed, please install with 'pip install seekpath'") +__all__ = ('get_explicit_kpoints_path', 'get_kpoints_path') def get_explicit_kpoints_path(structure, parameters): @@ -65,11 +55,9 @@ def get_explicit_kpoints_path(structure, parameters): - ``conv_structure``: A StructureData with the primitive structure """ + # pylint: disable=too-many-locals from aiida.tools.data.structure import spglib_tuple_to_structure, structure_to_spglib_tuple - check_seekpath_is_installed() - import seekpath - structure_tuple, kind_info, kinds = structure_to_spglib_tuple(structure) result = {} @@ -92,7 +80,6 @@ def get_explicit_kpoints_path(structure, parameters): # Remove reciprocal_primitive_lattice, recalculated by kpoints class rawdict.pop('reciprocal_primitive_lattice') kpoints_abs = rawdict.pop('explicit_kpoints_abs') - kpoints_rel = rawdict.pop('explicit_kpoints_rel') kpoints_labels = rawdict.pop('explicit_kpoints_labels') # set_kpoints expects labels like [[0,'X'],[34,'L'],...], so generate it here skipping empty labels @@ -146,9 +133,6 @@ def get_kpoints_path(structure, parameters): """ from aiida.tools.data.structure import spglib_tuple_to_structure, structure_to_spglib_tuple - check_seekpath_is_installed() - import seekpath - structure_tuple, kind_info, kinds = structure_to_spglib_tuple(structure) result = {} diff --git a/aiida/tools/data/cif.py b/aiida/tools/data/cif.py index 813ffa7fcc..a2eb83c69f 100644 --- a/aiida/tools/data/cif.py +++ b/aiida/tools/data/cif.py @@ -12,7 +12,7 @@ from aiida.engine import calcfunction from aiida.orm import CifData -from aiida.orm.utils.node import clean_value +from aiida.orm.implementation.utils import clean_value class InvalidOccupationsError(Exception): @@ -141,12 +141,12 @@ def _get_aiida_structure_pymatgen_inline(cif, **kwargs): structures = parser.get_structures(**parameters) except ValueError: # If it still fails, the occupancies were not the reason for failure - raise ValueError('pymatgen failed to provide a structure from the cif file') + raise ValueError('pymatgen failed to provide a structure from the cif file') from ValueError else: # If it now succeeds, non-unity occupancies were the culprit raise InvalidOccupationsError( 'detected atomic sites with an occupation number larger than the occupation tolerance' - ) + ) from ValueError return {'structure': StructureData(pymatgen_structure=structures[0])} @@ -203,8 +203,7 @@ def refine_inline(node): # Summary formula has to be calculated from non-reduced set of atoms. cif.values[name]['_chemical_formula_sum'] = \ - StructureData(ase=original_atoms).get_formula(mode='hill', - separator=' ') + StructureData(ase=original_atoms).get_formula(mode='hill', separator=' ') # If the number of reduced atoms multiplies the number of non-reduced # atoms, the new Z value can be calculated. 
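
With seekpath now imported at module level, a missing installation raises ImportError as soon as the module is imported, instead of via the removed check_seekpath_is_installed helper. A minimal sketch of the seekpath-based helper above, assuming `parameters` is an orm.Dict of seekpath keyword arguments and using a placeholder structure pk:

    from aiida import orm
    from aiida.tools.data.array.kpoints.seekpath import get_explicit_kpoints_path

    structure = orm.load_node(1234)                             # hypothetical StructureData node
    parameters = orm.Dict(dict={'reference_distance': 0.025})   # assumed seekpath keyword arguments

    result = get_explicit_kpoints_path(structure, parameters)
    print(sorted(result))  # e.g. includes 'conv_structure', as described in the docstring above
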
diff --git a/aiida/tools/data/structure/__init__.py b/aiida/tools/data/structure/__init__.py index e75e893825..d308d04373 100644 --- a/aiida/tools/data/structure/__init__.py +++ b/aiida/tools/data/structure/__init__.py @@ -165,7 +165,7 @@ def spglib_tuple_to_structure(structure_tuple, kind_info=None, kinds=None): # p structure.append_kind(k) abs_pos = np.dot(rel_pos, cell) if len(abs_pos) != len(site_kinds): - raise ValueError('The length of the positions array is different from the ' 'length of the element numbers') + raise ValueError('The length of the positions array is different from the length of the element numbers') for kind, pos in zip(site_kinds, abs_pos): structure.append_site(Site(kind_name=kind.name, position=pos)) diff --git a/aiida/tools/dbimporters/__init__.py b/aiida/tools/dbimporters/__init__.py index 7ce146e69d..7cef45c164 100644 --- a/aiida/tools/dbimporters/__init__.py +++ b/aiida/tools/dbimporters/__init__.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Module for plugins to import data from external databases into an AiiDA database.""" from .baseclasses import DbImporter __all__ = ('DbImporter',) diff --git a/aiida/tools/dbimporters/baseclasses.py b/aiida/tools/dbimporters/baseclasses.py index 28f6a52d17..64db13dac1 100644 --- a/aiida/tools/dbimporters/baseclasses.py +++ b/aiida/tools/dbimporters/baseclasses.py @@ -7,17 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Base class implementation for an external database importer.""" +import io class DbImporter: - """ - Base class for database importers. - """ + """Base class implementation for an external database importer.""" def query(self, **kwargs): - """ - Method to query the database. + """Method to query the database. :param id: database-specific entry identificator :param element: element name from periodic table of elements @@ -86,14 +84,14 @@ class DbSearchResults: ``__getitem__``. """ + _return_class = None + def __init__(self, results): self._results = results self._entries = {} class DbSearchResultsIterator: - """ - Iterator for search results - """ + """Iterator for search results.""" def __init__(self, results, increment=1): self._results = results @@ -101,12 +99,13 @@ def __init__(self, results, increment=1): self._increment = increment def __next__(self): + """Return the next entry in the iterator.""" pos = self._position - if pos >= 0 and pos < len(self._results): + if pos >= 0 and pos < len(self._results): # pylint: disable=chained-comparison self._position = self._position + self._increment return self._results[pos] - else: - raise StopIteration() + + raise StopIteration() def __iter__(self): """ @@ -141,7 +140,7 @@ def next(self): """ raise NotImplementedError('not implemented in base class') - def at(self, position): + def at(self, position): # pylint: disable=invalid-name """ Returns ``position``-th result as :py:class:`aiida.tools.dbimporters.baseclasses.DbEntry`. 
@@ -155,7 +154,7 @@ def at(self, position): if position not in self._entries: source_dict = self._get_source_dict(self._results[position]) url = self._get_url(self._results[position]) - self._entries[position] = self._return_class(url, **source_dict) + self._entries[position] = self._return_class(url, **source_dict) # pylint: disable=not-callable return self._entries[position] def _get_source_dict(self, result_dict): @@ -182,8 +181,7 @@ class DbEntry: """ _license = None - def __init__(self, db_name=None, db_uri=None, id=None, - version=None, extras={}, uri=None): + def __init__(self, db_name=None, db_uri=None, id=None, version=None, extras=None, uri=None): # pylint: disable=too-many-arguments,redefined-builtin """ Sets the basic parameters for the database entry: @@ -200,7 +198,7 @@ def __init__(self, db_name=None, db_uri=None, id=None, 'db_uri': db_uri, 'id': id, 'version': version, - 'extras': extras, + 'extras': extras or {}, 'uri': uri, 'source_md5': None, 'license': self._license, @@ -208,11 +206,13 @@ def __init__(self, db_name=None, db_uri=None, id=None, self._contents = None def __repr__(self): - return '{}({})'.format(self.__class__.__name__, - ','.join(['{}={}'.format(k, '"{}"'.format(self.source[k]) - if issubclass(self.source[k].__class__, str) - else self.source[k]) - for k in sorted(self.source.keys())])) + return '{}({})'.format( + self.__class__.__name__, ','.join([ + '{}={}'.format( + k, '"{}"'.format(self.source[k]) if issubclass(self.source[k].__class__, str) else self.source[k] + ) for k in sorted(self.source.keys()) + ]) + ) @property def contents(self): @@ -272,7 +272,7 @@ def get_ase_structure(self): :py:class:`aiida.orm.nodes.data.cif.CifData`. """ from aiida.orm import CifData - return CifData.read_cif(StringIO(self.cif)) + return CifData.read_cif(io.StringIO(self.cif)) def get_cif_node(self, store=False, parse_policy='lazy'): """ @@ -286,10 +286,10 @@ def get_cif_node(self, store=False, parse_policy='lazy'): cifnode = None - with tempfile.NamedTemporaryFile(mode='w+') as f: - f.write(self.cif) - f.flush() - cifnode = CifData(file=f.name, source=self.source, parse_policy=parse_policy) + with tempfile.NamedTemporaryFile(mode='w+') as handle: + handle.write(self.cif) + handle.flush() + cifnode = CifData(file=handle.name, source=self.source, parse_policy=parse_policy) # Maintaining backwards-compatibility. Parameter 'store' should # be removed in the future, as the new node can be stored later. @@ -333,10 +333,10 @@ def get_upf_node(self, store=False): # Prefixing with an ID in order to start file name with the name # of the described element. - with tempfile.NamedTemporaryFile(mode='w+', prefix=self.source['id']) as f: - f.write(self.contents) - f.flush() - upfnode = UpfData(file=f.name, source=self.source) + with tempfile.NamedTemporaryFile(mode='w+', prefix=self.source['id']) as handle: + handle.write(self.contents) + handle.flush() + upfnode = UpfData(file=handle.name, source=self.source) # Maintaining backwards-compatibility. Parameter 'store' should # be removed in the future, as the new node can be stored later. 
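
The search-result classes above are consumed through a concrete importer; a minimal sketch using the COD importer from the plugin below, assuming network access to the COD server and that MySQLdb or pymysql is installed:

    from aiida.tools.dbimporters.plugins.cod import CodDbImporter

    importer = CodDbImporter()
    results = importer.query(element=['Si'], number_of_elements=1)   # placeholder query

    first = results.at(0)          # a DbEntry subclass, built lazily by at()
    print(first.source['id'])      # database-specific identifier
    cif = first.get_cif_node()     # unstored CifData built from the downloaded CIF text
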
diff --git a/aiida/tools/dbimporters/plugins/__init__.py b/aiida/tools/dbimporters/plugins/__init__.py index 2776a55f97..a98fa2e761 100644 --- a/aiida/tools/dbimporters/plugins/__init__.py +++ b/aiida/tools/dbimporters/plugins/__init__.py @@ -7,3 +7,4 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Module for plugins to import data from external databases into an AiiDA database.""" diff --git a/aiida/tools/dbimporters/plugins/cod.py b/aiida/tools/dbimporters/plugins/cod.py index f549bfb54a..12f35e0c42 100644 --- a/aiida/tools/dbimporters/plugins/cod.py +++ b/aiida/tools/dbimporters/plugins/cod.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - CifEntry) +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the COD database.""" +from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, CifEntry) class CodDbImporter(DbImporter): @@ -23,10 +21,9 @@ def _int_clause(self, key, alias, values): """ Returns SQL query predicate for querying integer fields. """ - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only integers and strings are accepted") return key + ' IN (' + ', '.join(str(int(i)) for i in values) + ')' def _str_exact_clause(self, key, alias, values): @@ -34,13 +31,12 @@ def _str_exact_clause(self, key, alias, values): Returns SQL query predicate for querying string fields. """ clause_parts = [] - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") - if isinstance(e, int): - e = str(e) - clause_parts.append("'" + e + "'") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only integers and strings are accepted") + if isinstance(value, int): + value = str(value) + clause_parts.append("'" + value + "'") return key + ' IN (' + ', '.join(clause_parts) + ')' def _str_exact_or_none_clause(self, key, alias, values): @@ -50,64 +46,58 @@ def _str_exact_or_none_clause(self, key, alias, values): """ if None in values: values_now = [] - for e in values: - if e is not None: - values_now.append(e) - if len(values_now): + for value in values: + if value is not None: + values_now.append(value) + if values_now: clause = self._str_exact_clause(key, alias, values_now) return '{} OR {} IS NULL'.format(clause, key) - else: - return '{} IS NULL'.format(key) - else: - return self._str_exact_clause(key, alias, values) + + return '{} IS NULL'.format(key) + + return self._str_exact_clause(key, alias, values) def _formula_clause(self, key, alias, values): """ Returns SQL query predicate for querying formula fields. 
""" - for e in values: - if not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") - return self._str_exact_clause(key, \ - alias, \ - ['- {} -'.format(f) for f in values]) + for value in values: + if not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only strings are accepted") + return self._str_exact_clause(key, alias, ['- {} -'.format(f) for f in values]) def _str_fuzzy_clause(self, key, alias, values): """ Returns SQL query predicate for fuzzy querying of string fields. """ clause_parts = [] - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") - if isinstance(e, int): - e = str(e) - clause_parts.append(key + " LIKE '%" + e + "%'") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only integers and strings are accepted") + if isinstance(value, int): + value = str(value) + clause_parts.append(key + " LIKE '%" + value + "%'") return ' OR '.join(clause_parts) - def _composition_clause(self, key, alias, values): + def _composition_clause(self, _, alias, values): """ Returns SQL query predicate for querying elements in formula fields. """ clause_parts = [] - for e in values: - if not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") - clause_parts.append("formula REGEXP ' " + e + "[0-9 ]'") + for value in values: + if not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only strings are accepted") + clause_parts.append("formula REGEXP ' " + value + "[0-9 ]'") return ' AND '.join(clause_parts) def _double_clause(self, key, alias, values, precision): """ Returns SQL query predicate for querying double-valued fields. 
""" - for e in values: - if not isinstance(e, int) and not isinstance(e, float): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and floats are accepted") - return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d-precision, d+precision) for d in values) + for value in values: + if not isinstance(value, int) and not isinstance(value, float): + raise ValueError("incorrect value for keyword '" + alias + "' only integers and floats are accepted") + return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d - precision, d + precision) for d in values) length_precision = 0.001 angle_precision = 0.001 @@ -145,46 +135,43 @@ def _pressure_clause(self, key, alias, values): """ return self._double_clause(key, alias, values, self.pressure_precision) - _keywords = {'id': ['file', _int_clause], - 'element': ['element', _composition_clause], - 'number_of_elements': ['nel', _int_clause], - 'mineral_name': ['mineral', _str_fuzzy_clause], - 'chemical_name': ['chemname', _str_fuzzy_clause], - 'formula': ['formula', _formula_clause], - 'volume': ['vol', _volume_clause], - 'spacegroup': ['sg', _str_exact_clause], - 'spacegroup_hall': ['sgHall', _str_exact_clause], - 'a': ['a', _length_clause], - 'b': ['b', _length_clause], - 'c': ['c', _length_clause], - 'alpha': ['alpha', _angle_clause], - 'beta': ['beta', _angle_clause], - 'gamma': ['gamma', _angle_clause], - 'z': ['Z', _int_clause], - 'measurement_temp': ['celltemp', _temperature_clause], - 'diffraction_temp': ['diffrtemp', _temperature_clause], - 'measurement_pressure': - ['cellpressure', _pressure_clause], - 'diffraction_pressure': - ['diffrpressure', _pressure_clause], - 'authors': ['authors', _str_fuzzy_clause], - 'journal': ['journal', _str_fuzzy_clause], - 'title': ['title', _str_fuzzy_clause], - 'year': ['year', _int_clause], - 'journal_volume': ['volume', _int_clause], - 'journal_issue': ['issue', _str_exact_clause], - 'first_page': ['firstpage', _str_exact_clause], - 'last_page': ['lastpage', _str_exact_clause], - 'doi': ['doi', _str_exact_clause], - 'determination_method': ['method', _str_exact_or_none_clause]} + _keywords = { + 'id': ['file', _int_clause], + 'element': ['element', _composition_clause], + 'number_of_elements': ['nel', _int_clause], + 'mineral_name': ['mineral', _str_fuzzy_clause], + 'chemical_name': ['chemname', _str_fuzzy_clause], + 'formula': ['formula', _formula_clause], + 'volume': ['vol', _volume_clause], + 'spacegroup': ['sg', _str_exact_clause], + 'spacegroup_hall': ['sgHall', _str_exact_clause], + 'a': ['a', _length_clause], + 'b': ['b', _length_clause], + 'c': ['c', _length_clause], + 'alpha': ['alpha', _angle_clause], + 'beta': ['beta', _angle_clause], + 'gamma': ['gamma', _angle_clause], + 'z': ['Z', _int_clause], + 'measurement_temp': ['celltemp', _temperature_clause], + 'diffraction_temp': ['diffrtemp', _temperature_clause], + 'measurement_pressure': ['cellpressure', _pressure_clause], + 'diffraction_pressure': ['diffrpressure', _pressure_clause], + 'authors': ['authors', _str_fuzzy_clause], + 'journal': ['journal', _str_fuzzy_clause], + 'title': ['title', _str_fuzzy_clause], + 'year': ['year', _int_clause], + 'journal_volume': ['volume', _int_clause], + 'journal_issue': ['issue', _str_exact_clause], + 'first_page': ['firstpage', _str_exact_clause], + 'last_page': ['lastpage', _str_exact_clause], + 'doi': ['doi', _str_exact_clause], + 'determination_method': ['method', _str_exact_or_none_clause] + } def __init__(self, **kwargs): self._db = None self._cursor = None - 
self._db_parameters = {'host': 'www.crystallography.net', - 'user': 'cod_reader', - 'passwd': '', - 'db': 'cod'} + self._db_parameters = {'host': 'www.crystallography.net', 'user': 'cod_reader', 'passwd': '', 'db': 'cod'} self.setup_db(**kwargs) def query_sql(self, **kwargs): @@ -200,19 +187,12 @@ def query_sql(self, **kwargs): values = kwargs.pop(key) if not isinstance(values, list): values = [values] - sql_parts.append( \ - '(' + self._keywords[key][1](self, \ - self._keywords[key][0], \ - key, \ - values) + \ - ')') - if len(kwargs.keys()) > 0: - raise NotImplementedError( \ - "search keyword(s) '" + \ - "', '".join(kwargs.keys()) + "' " + \ - 'is(are) not implemented for COD') - return 'SELECT file, svnrevision FROM data WHERE ' + \ - ' AND '.join(sql_parts) + sql_parts.append('(' + self._keywords[key][1](self, self._keywords[key][0], key, values) + ')') + + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) + + return 'SELECT file, svnrevision FROM data WHERE ' + ' AND '.join(sql_parts) def query(self, **kwargs): """ @@ -229,8 +209,7 @@ def query(self, **kwargs): self._cursor.execute(query_statement) self._db.commit() for row in self._cursor.fetchall(): - results.append({'id': str(row[0]), - 'svnrevision': str(row[1])}) + results.append({'id': str(row[0]), 'svnrevision': str(row[1])}) finally: self._disconnect_db() @@ -240,15 +219,14 @@ def setup_db(self, **kwargs): """ Changes the database connection details. """ - for key in self._db_parameters.keys(): + for key in self._db_parameters: if key in kwargs.keys(): self._db_parameters[key] = kwargs.pop(key) if len(kwargs.keys()) > 0: - raise NotImplementedError( \ - "unknown database connection parameter(s): '" + \ - "', '".join(kwargs.keys()) + \ - "', available parameters: '" + \ - "', '".join(self._db_parameters.keys()) + "'") + raise NotImplementedError( + "unknown database connection parameter(s): '" + "', '".join(kwargs.keys()) + + "', available parameters: '" + "', '".join(self._db_parameters.keys()) + "'" + ) def get_supported_keywords(self): """ @@ -267,10 +245,12 @@ def _connect_db(self): except ImportError: import pymysql as MySQLdb - self._db = MySQLdb.connect(host=self._db_parameters['host'], - user=self._db_parameters['user'], - passwd=self._db_parameters['passwd'], - db=self._db_parameters['db']) + self._db = MySQLdb.connect( + host=self._db_parameters['host'], + user=self._db_parameters['user'], + passwd=self._db_parameters['passwd'], + db=self._db_parameters['db'] + ) self._cursor = self._db.cursor() def _disconnect_db(self): @@ -280,7 +260,7 @@ def _disconnect_db(self): self._db.close() -class CodSearchResults(DbSearchResults): +class CodSearchResults(DbSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on COD. """ @@ -316,22 +296,22 @@ def _get_url(self, result_dict): if 'svnrevision' in result_dict and \ result_dict['svnrevision'] is not None: return '{}@{}'.format(url, result_dict['svnrevision']) - else: - return url + + return url -class CodEntry(CifEntry): +class CodEntry(CifEntry): # pylint: disable=abstract-method """ Represents an entry from COD. 
""" _license = 'CC0' - def __init__(self, uri, db_name='Crystallography Open Database', - db_uri='http://www.crystallography.net/cod', **kwargs): + def __init__( + self, uri, db_name='Crystallography Open Database', db_uri='http://www.crystallography.net/cod', **kwargs + ): """ Creates an instance of :py:class:`aiida.tools.dbimporters.plugins.cod.CodEntry`, related to the supplied URI. """ - super().__init__(db_name=db_name, db_uri=db_uri, - uri=uri, **kwargs) + super().__init__(db_name=db_name, db_uri=db_uri, uri=uri, **kwargs) diff --git a/aiida/tools/dbimporters/plugins/icsd.py b/aiida/tools/dbimporters/plugins/icsd.py index 6367a522f5..2584a41081 100644 --- a/aiida/tools/dbimporters/plugins/icsd.py +++ b/aiida/tools/dbimporters/plugins/icsd.py @@ -7,29 +7,23 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the CISD database.""" +import io - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - CifEntry) +from aiida.tools.dbimporters.baseclasses import DbImporter, DbSearchResults, CifEntry class IcsdImporterExp(Exception): - pass + """Base class for ICSD exceptions.""" class CifFileErrorExp(IcsdImporterExp): - """ - Raised when the author loop is missing in a CIF file. - """ - pass + """Raised when the author loop is missing in a CIF file.""" class NoResultsWebExp(IcsdImporterExp): - """ - Raised when a webpage query returns no results. - """ - pass + """Raised when a webpage query returns no results.""" class IcsdDbImporter(DbImporter): @@ -82,15 +76,16 @@ class IcsdDbImporter(DbImporter): def __init__(self, **kwargs): - self.db_parameters = {'server': '', - 'urladd': 'index.php?', - 'querydb': True, - 'dl_db': 'icsd', - 'host': '', - 'user': 'dba', - 'passwd': 'sql', - 'db': 'icsd', - 'port': '3306', + self.db_parameters = { + 'server': '', + 'urladd': 'index.php?', + 'querydb': True, + 'dl_db': 'icsd', + 'host': '', + 'user': 'dba', + 'passwd': 'sql', + 'db': 'icsd', + 'port': '3306', } self.setup_db(**kwargs) @@ -103,69 +98,61 @@ def _int_clause(self, key, alias, values): :param values: Corresponding values from query :return: SQL query predicate """ - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only integers and strings are accepted') return '{} IN ({})'.format(key, ', '.join(str(int(i)) for i in values)) def _str_exact_clause(self, key, alias, values): """ Return SQL query predicate for querying string fields. """ - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only integers and strings are accepted') return '{} IN ({})'.format(key, ', '.join("'{}'".format(f) for f in values)) def _formula_clause(self, key, alias, values): """ Return SQL query predicate for querying formula fields. 
""" - for e in values: - if not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") - return self._str_exact_clause(key, \ - alias, \ - [str(f) for f in values]) + for value in values: + if not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only strings are accepted') + return self._str_exact_clause(key, alias, [str(f) for f in values]) def _str_fuzzy_clause(self, key, alias, values): """ Return SQL query predicate for fuzzy querying of string fields. """ - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only integers and strings are accepted') return ' OR '.join("{} LIKE '%{}%'".format(key, s) for s in values) - def _composition_clause(self, key, alias, values): + def _composition_clause(self, key, alias, values): # pylint: disable=unused-argument """ Return SQL query predicate for querying elements in formula fields. """ - for e in values: - if not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") + for value in values: + if not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only strings are accepted') # SUM_FORM in the ICSD always stores a numeral after the element name, # STRUCT_FORM does not, so it's better to use SUM_FORM for the composition query. # The element-numeral pair can be in the beginning of the formula expression (therefore no space before), # or at the end of the formula expression (no space after). # Be aware that one needs to check that space/beginning of line before and ideally also space/end of line # after, because I found that capitalization of the element name is not enforced in these queries. - return ' AND '.join(r'SUM_FORM REGEXP \'(^|\ ){}[0-9\.]+($|\ )\''.format(e) for e in values) + return ' AND '.join(r'SUM_FORM REGEXP \'(^|\ ){}[0-9\.]+($|\ )\''.format(value) for value in values) def _double_clause(self, key, alias, values, precision): """ Return SQL query predicate for querying double-valued fields. 
""" - for e in values: - if not isinstance(e, int) and not isinstance(e, float): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and floats are accepted") - return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d-precision, d+precision) for d in values) + for value in values: + if not isinstance(value, int) and not isinstance(value, float): + raise ValueError("incorrect value for keyword '" + alias + ' only integers and floats are accepted') + return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d - precision, d + precision) for d in values) def _crystal_system_clause(self, key, alias, values): """ @@ -181,117 +168,127 @@ def _crystal_system_clause(self, key, alias, values): 'triclinic': 'TC' } # from icsd accepted crystal systems - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only strings are accepted') return key + ' IN (' + ', '.join("'" + valid_systems[f.lower()] + "'" for f in values) + ')' def _length_clause(self, key, alias, values): """ Return SQL query predicate for querying lattice vector lengths. """ - return self.double_clause(key, alias, values, self.length_precision) + return self._double_clause(key, alias, values, self.length_precision) def _density_clause(self, key, alias, values): """ Return SQL query predicate for querying density. """ - return self.double_clause(key, alias, values, self.density_precision) + return self._double_clause(key, alias, values, self.density_precision) def _angle_clause(self, key, alias, values): """ Return SQL query predicate for querying lattice angles. """ - return self.double_clause(key, alias, values, self.angle_precision) + return self._double_clause(key, alias, values, self.angle_precision) def _volume_clause(self, key, alias, values): """ Return SQL query predicate for querying unit cell volume. """ - return self.double_clause(key, alias, values, self.volume_precision) + return self._double_clause(key, alias, values, self.volume_precision) def _temperature_clause(self, key, alias, values): """ Return SQL query predicate for querying temperature. """ - return self.double_clause(key, alias, values, self.temperature_precision) + return self._double_clause(key, alias, values, self.temperature_precision) def _pressure_clause(self, key, alias, values): """ Return SQL query predicate for querying pressure. """ - return self.double_clause(key, alias, values, self.pressure_precision) + return self._double_clause(key, alias, values, self.pressure_precision) # for the web query - def _parse_all(k, v): + @staticmethod + def _parse_all(key, values): # pylint: disable=unused-argument """ Convert numbers, strings, lists into strings. - :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ - if type(v) is list: - retval = ' '.join(v) - elif type(v) is int: - retval = str(v) - elif type(v) is str: - retval = v + if isinstance(values, list): + retval = ' '.join(values) + elif isinstance(values, int): + retval = str(values) + elif isinstance(values, str): + retval = values return retval - def _parse_number(k, v): + @staticmethod + def _parse_number(key, values): # pylint: disable=unused-argument """ Convert int into string. 
- :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ - if type(v) is int: - retval = str(v) - elif type(v) is str: - retval = v + if isinstance(values, int): + retval = str(values) + elif isinstance(values, str): + retval = values return retval - def _parse_mineral(k, v): + @staticmethod + def _parse_mineral(key, values): """ Convert mineral_name and chemical_name into right format. - :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ - if k == 'mineral_name': - retval = 'M=' + v - elif k == 'chemical_name': - retval = 'C=' + v + if key == 'mineral_name': + retval = 'M=' + values + elif key == 'chemical_name': + retval = 'C=' + values return retval - def _parse_volume(k, v): + @staticmethod + def _parse_volume(key, values): # pylint: disable=too-many-return-statements """ Convert volume, cell parameter and angle queries into right format. - :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ - if k == 'volume': - return 'v=' + v - elif k == 'a': - return 'a=' + v - elif k == 'b': - return 'b=' + v - elif k == 'c': - return 'c=' + v - elif k == 'alpha': - return 'al=' + v - elif k == 'beta': - return 'be=' + v - elif k == 'gamma': - return 'ga=' + v - - def _parse_system(k, v): + if key == 'volume': + return 'v=' + values + + if key == 'a': + return 'a=' + values + + if key == 'b': + return 'b=' + values + + if key == 'c': + return 'c=' + values + + if key == 'alpha': + return 'al=' + values + + if key == 'beta': + return 'be=' + values + + if key == 'gamma': + return 'ga=' + values + + @staticmethod + def _parse_system(key, values): # pylint: disable=unused-argument """ Return crystal system in the right format. 
- :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ valid_systems = { @@ -304,54 +301,56 @@ def _parse_system(k, v): 'triclinic': 'TC' } - return valid_systems[v.lower()] + return valid_systems[values.lower()] # mysql database - query parameter (alias) : [mysql keyword (key), function to call] - keywords_db = {'id': ['COLL_CODE', _int_clause], - 'element': ['SUM_FORM;', _composition_clause], - 'number_of_elements': ['EL_COUNT', _int_clause], - 'chemical_name': ['CHEM_NAME', _str_fuzzy_clause], - 'formula': ['SUM_FORM', _formula_clause], - 'volume': ['C_VOL', _volume_clause], - 'spacegroup': ['SGR', _str_exact_clause], - 'a': ['A_LEN', _length_clause], - 'b': ['B_LEN', _length_clause], - 'c': ['C_LEN', _length_clause], - 'alpha': ['ALPHA', _angle_clause], - 'beta': ['BETA', _angle_clause], - 'gamma': ['GAMMA', _angle_clause], - 'density': ['DENSITY_CALC', _density_clause], - 'wyckoff': ['WYCK', _str_exact_clause], - 'molar_mass': ['MOL_MASS', _density_clause], - 'pdf_num': ['PDF_NUM', _str_exact_clause], - 'z': ['Z', _int_clause], - 'measurement_temp': ['TEMPERATURE', _temperature_clause], - 'authors': ['AUTHORS_TEXT', _str_fuzzy_clause], - 'journal': ['journal', _str_fuzzy_clause], - 'title': ['AU_TITLE', _str_fuzzy_clause], - 'year': ['MPY', _int_clause], - 'crystal_system': ['CRYST_SYS_CODE', _crystal_system_clause], + keywords_db = { + 'id': ['COLL_CODE', _int_clause], + 'element': ['SUM_FORM;', _composition_clause], + 'number_of_elements': ['EL_COUNT', _int_clause], + 'chemical_name': ['CHEM_NAME', _str_fuzzy_clause], + 'formula': ['SUM_FORM', _formula_clause], + 'volume': ['C_VOL', _volume_clause], + 'spacegroup': ['SGR', _str_exact_clause], + 'a': ['A_LEN', _length_clause], + 'b': ['B_LEN', _length_clause], + 'c': ['C_LEN', _length_clause], + 'alpha': ['ALPHA', _angle_clause], + 'beta': ['BETA', _angle_clause], + 'gamma': ['GAMMA', _angle_clause], + 'density': ['DENSITY_CALC', _density_clause], + 'wyckoff': ['WYCK', _str_exact_clause], + 'molar_mass': ['MOL_MASS', _density_clause], + 'pdf_num': ['PDF_NUM', _str_exact_clause], + 'z': ['Z', _int_clause], + 'measurement_temp': ['TEMPERATURE', _temperature_clause], + 'authors': ['AUTHORS_TEXT', _str_fuzzy_clause], + 'journal': ['journal', _str_fuzzy_clause], + 'title': ['AU_TITLE', _str_fuzzy_clause], + 'year': ['MPY', _int_clause], + 'crystal_system': ['CRYST_SYS_CODE', _crystal_system_clause], } # keywords accepted for the web page query - keywords = {'id': ('authors', _parse_all), - 'authors': ('authors', _parse_all), - 'element': ('elements', _parse_all), - 'number_of_elements': ('elementc', _parse_all), - 'mineral_name': ('mineral', _parse_mineral), - 'chemical_name': ('mineral', _parse_mineral), - 'formula': ('formula', _parse_all), - 'volume': ('volume', _parse_volume), - 'a': ('volume', _parse_volume), - 'b': ('volume', _parse_volume), - 'c': ('volume', _parse_volume), - 'alpha': ('volume', _parse_volume), - 'beta': ('volume', _parse_volume), - 'gamma': ('volume', _parse_volume), - 'spacegroup': ('spaceg', _parse_all), - 'journal': ('journal', _parse_all), - 'title': ('title', _parse_all), - 'year': ('year', _parse_all), - 'crystal_system': ('system', _parse_system), + keywords = { + 'id': ('authors', _parse_all), + 'authors': ('authors', _parse_all), + 'element': ('elements', _parse_all), + 'number_of_elements': ('elementc', _parse_all), + 'mineral_name': ('mineral', _parse_mineral), + 'chemical_name': ('mineral', 
_parse_mineral), + 'formula': ('formula', _parse_all), + 'volume': ('volume', _parse_volume), + 'a': ('volume', _parse_volume), + 'b': ('volume', _parse_volume), + 'c': ('volume', _parse_volume), + 'alpha': ('volume', _parse_volume), + 'beta': ('volume', _parse_volume), + 'gamma': ('volume', _parse_volume), + 'spacegroup': ('spaceg', _parse_all), + 'journal': ('journal', _parse_all), + 'title': ('title', _parse_all), + 'year': ('year', _parse_all), + 'crystal_system': ('system', _parse_system), } def query(self, **kwargs): @@ -361,11 +360,10 @@ def query(self, **kwargs): :param kwargs: A list of ''keyword = [values]'' pairs. """ - if self.db_parameters['querydb']: return self._query_sql_db(**kwargs) - else: - return self._queryweb(**kwargs) + + return self._queryweb(**kwargs) def _query_sql_db(self, **kwargs): """ @@ -377,13 +375,11 @@ def _query_sql_db(self, **kwargs): sql_where_query = [] # second part of sql query - for k, v in kwargs.items(): - if not isinstance(v, list): - v = [v] - sql_where_query.append('({})'.format(self.keywords_db[k][1](self, - self.keywords_db[k][0], - k, v))) - if 'crystal_system' in kwargs.keys(): # to query another table than the main one, add LEFT JOIN in front of WHERE + for key, value in kwargs.items(): + if not isinstance(value, list): + value = [value] + sql_where_query.append('({})'.format(self.keywords_db[key][1](self, self.keywords_db[key][0], key, value))) + if 'crystal_system' in kwargs: # to query another table than the main one, add LEFT JOIN in front of WHERE sql_query = 'LEFT JOIN space_group ON space_group.sgr=icsd.sgr LEFT '\ 'JOIN space_group_number ON '\ 'space_group_number.sgr_num=space_group.sgr_num '\ @@ -395,7 +391,6 @@ def _query_sql_db(self, **kwargs): return IcsdSearchResults(query=sql_query, db_parameters=self.db_parameters) - def _queryweb(self, **kwargs): """ Perform a query on the Icsd web database using ``keyword = value`` pairs, @@ -405,7 +400,7 @@ def _queryweb(self, **kwargs): :return: IcsdSearchResults """ from urllib.parse import urlencode - self.actual_args = { + self.actual_args = { # pylint: disable=attribute-defined-outside-init 'action': 'Search', 'nb_rows': '100', # max is 100 'order_by': 'yearDesc', @@ -414,10 +409,10 @@ def _queryweb(self, **kwargs): 'mineral': '' } - for k, v in kwargs.items(): + for key, value in kwargs.items(): try: - realname = self.keywords[k][0] - newv = self.keywords[k][1](k, v) + realname = self.keywords[key][0] + newv = self.keywords[key][1](key, value) # Because different keys correspond to the same search field. if realname in ['authors', 'volume', 'mineral']: self.actual_args[realname] = self.actual_args[realname] + newv + ' ' @@ -439,8 +434,8 @@ def setup_db(self, **kwargs): :param kwargs: db_parameters for the mysql database connection (host, user, passwd, db, port) """ - for key in self.db_parameters.keys(): - if key in kwargs.keys(): + for key in self.db_parameters: + if key in kwargs: self.db_parameters[key] = kwargs[key] def get_supported_keywords(self): @@ -449,11 +444,11 @@ def get_supported_keywords(self): """ if self.db_parameters['querydb']: return self.keywords_db.keys() - else: - return self.keywords.keys() + + return self.keywords.keys() -class IcsdSearchResults(DbSearchResults): +class IcsdSearchResults(DbSearchResults): # pylint: disable=abstract-method,too-many-instance-attributes """ Result manager for the query performed on ICSD. 
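The web-form parsers above are small pure functions of a (key, value) pair, so their behaviour can be checked in isolation. A short sketch, calling the (now static) private helpers on the class purely for illustration::

    from aiida.tools.dbimporters.plugins.icsd import IcsdDbImporter

    # Each parser converts a query keyword and its value into the token expected by the ICSD web form.
    assert IcsdDbImporter._parse_mineral('mineral_name', 'quartz') == 'M=quartz'
    assert IcsdDbImporter._parse_mineral('chemical_name', 'silica') == 'C=silica'
    assert IcsdDbImporter._parse_volume('alpha', '90') == 'al=90'
    assert IcsdDbImporter._parse_system('crystal_system', 'Triclinic') == 'TC'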
@@ -465,8 +460,8 @@ class IcsdSearchResults(DbSearchResults): db_name = 'Icsd' def __init__(self, query, db_parameters): - - self.db = None + # pylint: disable=super-init-not-called + self.db = None # pylint: disable=invalid-name self.cursor = None self.db_parameters = db_parameters self.query = query @@ -498,9 +493,9 @@ def next(self): if self.number_of_results > self.position: self.position = self.position + 1 return self.at(self.position - 1) - else: - self.position = 0 - raise StopIteration() + + self.position = 0 + raise StopIteration() def at(self, position): """ @@ -515,20 +510,23 @@ def at(self, position): if position not in self.entries: if self.db_parameters['querydb']: - self.entries[position] = IcsdEntry(self.db_parameters['server'] + - self.db_parameters['dl_db'] + self.cif_url.format( - self._results[position]), - db_name=self.db_name, id=self.cif_numbers[position], - version = self.db_version, - extras={'idnum': self._results[position]}) + self.entries[position] = IcsdEntry( + self.db_parameters['server'] + self.db_parameters['dl_db'] + + self.cif_url.format(self._results[position]), + db_name=self.db_name, + id=self.cif_numbers[position], + version=self.db_version, + extras={'idnum': self._results[position]} + ) else: - self.entries[position] = IcsdEntry(self.db_parameters['server'] + - self.db_parameters['dl_db'] + self.cif_url.format( - self._results[position]), - db_name=self.db_name, extras={'idnum': self._results[position]}) + self.entries[position] = IcsdEntry( + self.db_parameters['server'] + self.db_parameters['dl_db'] + + self.cif_url.format(self._results[position]), + db_name=self.db_name, + extras={'idnum': self._results[position]} + ) return self.entries[position] - def query_db_version(self): """ Query the version of the icsd database (last row of RELEASE_TAGS). 
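For context, a typical use of the importer together with the result manager shown above. This is only a sketch: the connection values are placeholders for a locally mirrored ICSD SQL database, and ``get_cif_node`` is a helper inherited from the dbimporters base classes, not part of this diff::

    from aiida.tools.dbimporters.plugins.icsd import IcsdDbImporter

    importer = IcsdDbImporter(server='http://localhost/', host='127.0.0.1')  # placeholder server/host
    importer.setup_db(user='dba', passwd='sql', db='icsd', port=3306)        # placeholder credentials

    results = importer.query(element=['Ti', 'O'], number_of_elements=2)      # IcsdSearchResults
    first = results.at(0)   # entries are built lazily and cached per position (see `at` above)
    for entry in results:   # iteration is provided by the DbSearchResults base class
        cif_node = entry.get_cif_node()
        break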
@@ -554,8 +552,7 @@ def query_db_version(self): raise IcsdImporterExp('Database version not found') else: - raise NotImplementedError('Cannot query the database version with ' - 'a web query.') + raise NotImplementedError('Cannot query the database version with ' 'a web query.') def query_page(self): """ @@ -568,10 +565,9 @@ def query_page(self): if self.db_parameters['querydb']: self._connect_db() - query_statement = '{}{}{} LIMIT {}, 100'.format(self.sql_select_query, - self.sql_from_query, - self.query, - (self.page-1)*100) + query_statement = '{}{}{} LIMIT {}, 100'.format( + self.sql_select_query, self.sql_from_query, self.query, (self.page - 1) * 100 + ) self.cursor.execute(query_statement) self.db.commit() @@ -585,23 +581,21 @@ def query_page(self): self._disconnect_db() - else: - from bs4 import BeautifulSoup + from bs4 import BeautifulSoup # pylint: disable=import-error from urllib.request import urlopen import re - self.html = urlopen(self.db_parameters['server'] + - self.db_parameters['db'] + '/' + - self.query.format(str(self.page))).read() + self.html = urlopen( + self.db_parameters['server'] + self.db_parameters['db'] + '/' + self.query.format(str(self.page)) + ).read() self.soup = BeautifulSoup(self.html) try: if self.number_of_results is None: - self.number_of_results = int(re.findall(r'\d+', - str(self.soup.find_all('i')[-1]))[0]) + self.number_of_results = int(re.findall(r'\d+', str(self.soup.find_all('i')[-1]))[0]) except IndexError: raise NoResultsWebExp @@ -617,11 +611,12 @@ def _connect_db(self): except ImportError: import pymysql as MySQLdb - self.db = MySQLdb.connect(host=self.db_parameters['host'], - user=self.db_parameters['user'], - passwd=self.db_parameters['passwd'], - db=self.db_parameters['db'], - port=int(self.db_parameters['port']) + self.db = MySQLdb.connect( + host=self.db_parameters['host'], + user=self.db_parameters['user'], + passwd=self.db_parameters['passwd'], + db=self.db_parameters['db'], + port=int(self.db_parameters['port']) ) self.cursor = self.db.cursor() @@ -632,7 +627,7 @@ def _disconnect_db(self): self.db.close() -class IcsdEntry(CifEntry): +class IcsdEntry(CifEntry): # pylint: disable=abstract-method """ Represent an entry from Icsd. @@ -651,12 +646,14 @@ def __init__(self, uri, **kwargs): """ super().__init__(**kwargs) self.source = { - 'db_name': kwargs.get('db_name','Icsd'), + 'db_name': kwargs.get('db_name', 'Icsd'), 'db_uri': None, 'id': kwargs.get('id', None), 'version': kwargs.get('version', None), 'uri': uri, - 'extras': {'idnum': kwargs.get('extras', {}).get('idnum', None)}, + 'extras': { + 'idnum': kwargs.get('extras', {}).get('idnum', None) + }, 'license': self._license, } @@ -668,10 +665,11 @@ def contents(self): PyCifRW library (and most other sensible applications), expects UTF-8. Therefore, we decode the original CIF data to unicode and encode it in the UTF-8 format """ + import urllib.request + if self._contents is None: from hashlib import md5 - - self._contents = urlopen(self.source['uri']).read() + self._contents = urllib.request.urlopen(self.source['uri']).read() self._contents = self._contents.decode('iso-8859-1').encode('utf8') self.source['source_md5'] = md5(self._contents).hexdigest() @@ -682,9 +680,8 @@ def get_ase_structure(self): :return: ASE structure corresponding to the cif file. 
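The ``IcsdEntry.contents`` property above normalizes the encoding of the downloaded CIF before hashing it. The same transformation in isolation (a standalone sketch; the function name is not part of AiiDA)::

    import urllib.request
    from hashlib import md5

    def fetch_normalized_cif(uri):
        """Download a CIF and re-encode it from ISO-8859-1, as served by ICSD, to UTF-8."""
        raw = urllib.request.urlopen(uri).read()
        contents = raw.decode('iso-8859-1').encode('utf8')
        return contents, md5(contents).hexdigest()

Note also that the SQL path of ``query_page`` pages through results with ``LIMIT (page - 1) * 100, 100``, so page 3 translates to ``LIMIT 200, 100``.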
""" from aiida.orm import CifData - cif = correct_cif(self.cif) - return CifData.read_cif(StringIO(cif)) + return CifData.read_cif(io.StringIO(cif)) def correct_cif(cif): @@ -709,12 +706,12 @@ def correct_cif(cif): inc = 1 while True: words = lines[author_index + inc].split() - #in case loop is finished -> return cif lines. - #use regular expressions ? + # in case loop is finished -> return cif lines. + # use regular expressions ? if len(words) == 0 or words[0] == 'loop_' or words[0][0] == '_': return '\n'.join(lines) - elif ((words[0][0] == "'" and words[-1][-1] == "'") - or (words[0][0] == '"' and words[-1][-1] == '"')): + + if ((words[0][0] == "'" and words[-1][-1] == "'") or (words[0][0] == '"' and words[-1][-1] == '"')): # if quotes are already there, check next line inc = inc + 1 else: diff --git a/aiida/tools/dbimporters/plugins/materialsproject.py b/aiida/tools/dbimporters/plugins/materialsproject.py index 3da7bf9039..0f09d3b324 100644 --- a/aiida/tools/dbimporters/plugins/materialsproject.py +++ b/aiida/tools/dbimporters/plugins/materialsproject.py @@ -7,13 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Module that contains the class definitions necessary to offer support for -queries to Materials Project.""" - -import os +""""Implementation of `DbImporter` for the Materials Project database.""" import datetime +import os import requests + from pymatgen import MPRester + from aiida.tools.dbimporters.baseclasses import CifEntry, DbImporter, DbSearchResults diff --git a/aiida/tools/dbimporters/plugins/mpds.py b/aiida/tools/dbimporters/plugins/mpds.py index 00984b92b1..a5be35e9c4 100644 --- a/aiida/tools/dbimporters/plugins/mpds.py +++ b/aiida/tools/dbimporters/plugins/mpds.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +""""Implementation of `DbImporter` for the MPDS database.""" import copy import enum import os @@ -62,7 +62,7 @@ def __init__(self, url=None, api_key=None): self.setup_db(url=url, api_key=api_key) self._structures = StructuresCollection(self) - def setup_db(self, url=None, api_key=None, collection=None): + def setup_db(self, url=None, api_key=None, collection=None): # pylint: disable=arguments-differ """ Setup the required parameters for HTTP requests to the REST API @@ -112,7 +112,7 @@ def structures(self): return self._structures @property - def get_supported_keywords(self): + def get_supported_keywords(self): # pylint: disable=invalid-overridden-method """ Returns the list of all supported query keywords @@ -127,7 +127,7 @@ def url(self): """ return self._url - def query(self, query, collection=None): + def query(self, query, collection=None): # pylint: disable=arguments-differ """ Query the database with a given dictionary of query parameters for a given collection @@ -176,6 +176,8 @@ def find(self, query, fmt=DEFAULT_API_FORMAT): :param query: a dictionary with the query parameters """ + # pylint: disable=too-many-branches + if not isinstance(query, dict): raise TypeError('The query argument should be a dictionary') @@ -234,7 +236,8 @@ def get(self, fmt=DEFAULT_API_FORMAT, **kwargs): kwargs['fmt'] = fmt.value return requests.get(url=self.url, params=kwargs, headers={'Key': self.api_key}) - def 
get_response_content(self, response, fmt=DEFAULT_API_FORMAT): + @staticmethod + def get_response_content(response, fmt=DEFAULT_API_FORMAT): """ Analyze the response of an HTTP GET request, verify that the response code is OK and return the json loaded response text @@ -254,10 +257,11 @@ def get_response_content(self, response, fmt=DEFAULT_API_FORMAT): raise ValueError('Got error response: {}'.format(error)) return content - else: - return response.text - def get_id_from_cif(self, cif): + return response.text + + @staticmethod + def get_id_from_cif(cif): """ Extract the entry id from the string formatted cif response of the MPDS API @@ -275,6 +279,7 @@ def get_id_from_cif(self, cif): class StructuresCollection: + """Collection of structures.""" def __init__(self, engine): self._engine = engine @@ -305,11 +310,11 @@ class MpdsEntry(DbEntry): Represents an MPDS database entry """ - def __init__(self, url, **kwargs): + def __init__(self, _, **kwargs): """ Set the class license from the source dictionary """ - license = kwargs.pop('license', None) + license = kwargs.pop('license', None) # pylint: disable=redefined-builtin if license is not None: self._license = license @@ -317,7 +322,7 @@ def __init__(self, url, **kwargs): super().__init__(**kwargs) -class MpdsCifEntry(CifEntry, MpdsEntry): +class MpdsCifEntry(CifEntry, MpdsEntry): # pylint: disable=abstract-method """ An extension of the MpdsEntry class with the CifEntry class, which will treat the contents property through the URI as a cif file @@ -337,10 +342,8 @@ def __init__(self, url, **kwargs): self.cif = cif -class MpdsSearchResults(DbSearchResults): - """ - A collection of MpdsEntry query result entries - """ +class MpdsSearchResults(DbSearchResults): # pylint: disable=abstract-method + """Collection of MpdsEntry query result entries.""" _db_name = 'Materials Platform for Data Science' _db_uri = 'https://mpds.io/' diff --git a/aiida/tools/dbimporters/plugins/mpod.py b/aiida/tools/dbimporters/plugins/mpod.py index 88f67abcda..139dcf963e 100644 --- a/aiida/tools/dbimporters/plugins/mpod.py +++ b/aiida/tools/dbimporters/plugins/mpod.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - CifEntry) +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the MPOD database.""" +from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, CifEntry) class MpodDbImporter(DbImporter): @@ -24,15 +22,16 @@ def _str_clause(self, key, alias, values): Returns part of HTTP GET query for querying string fields. 
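A hedged usage sketch for the MPDS importer as refactored above. The API key is a placeholder and the query keys are only illustrative of the MPDS REST API; neither is defined by this diff::

    from aiida.tools.dbimporters.plugins.mpds import MpdsDbImporter

    importer = MpdsDbImporter(api_key='YOUR-MPDS-API-KEY')  # placeholder key
    # `query` expects a dictionary of MPDS search fields; requests are sent with the
    # `Key` header and the responses are paginated through `find`/`get` as shown above.
    results = importer.query({'elements': 'Ti-O', 'classes': 'binary'})  # illustrative fields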
""" if not isinstance(values, str) and not isinstance(values, int): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings and integers are accepted") + raise ValueError("incorrect value for keyword '" + alias + "' -- only strings and integers are accepted") return '{}={}'.format(key, values) - _keywords = {'phase_name': ['phase_name', _str_clause], - 'formula': ['formula', _str_clause], - 'element': ['element', None], - 'cod_id': ['cod_code', _str_clause], - 'authors': ['publ_author', _str_clause]} + _keywords = { + 'phase_name': ['phase_name', _str_clause], + 'formula': ['formula', _str_clause], + 'element': ['element', None], + 'cod_id': ['cod_code', _str_clause], + 'authors': ['publ_author', _str_clause] + } def __init__(self, **kwargs): self._query_url = 'http://mpod.cimav.edu.mx/data/search/' @@ -46,8 +45,7 @@ def query_get(self, **kwargs): :return: a list containing strings for HTTP GET statement. """ if 'formula' in kwargs.keys() and 'element' in kwargs.keys(): - raise ValueError('can not query both formula and elements ' - 'in MPOD') + raise ValueError('can not query both formula and elements ' 'in MPOD') elements = [] if 'element' in kwargs.keys(): @@ -56,25 +54,18 @@ def query_get(self, **kwargs): elements = [elements] get_parts = [] - for key in self._keywords.keys(): - if key in kwargs.keys(): + for key in self._keywords: + if key in kwargs: values = kwargs.pop(key) - get_parts.append( - self._keywords[key][1](self, - self._keywords[key][0], - key, - values)) + get_parts.append(self._keywords[key][1](self, self._keywords[key][0], key, values)) - if kwargs.keys(): - raise NotImplementedError("search keyword(s) '" - "', '".join(kwargs.keys()) + "' " - 'is(are) not implemented for MPOD') + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) queries = [] - for e in elements: - queries.append(self._query_url + '?' + - '&'.join(get_parts + - [self._str_clause('formula', 'element', e)])) + for element in elements: + clauses = [self._str_clause('formula', 'element', element)] + queries.append(self._query_url + '?' + '&'.join(get_parts + clauses)) if not queries: queries.append(self._query_url + '?' + '&'.join(get_parts)) @@ -103,18 +94,15 @@ def query(self, **kwargs): return MpodSearchResults([{'id': x} for x in results]) - def setup_db(self, query_url=None, **kwargs): + def setup_db(self, query_url=None, **kwargs): # pylint: disable=arguments-differ """ Changes the database connection details. """ if query_url: self._query_url = query_url - if kwargs.keys(): - raise NotImplementedError( \ - "unknown database connection parameter(s): '" + \ - "', '".join(kwargs.keys()) + \ - "', available parameters: 'query_url'") + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) def get_supported_keywords(self): """ @@ -125,7 +113,7 @@ def get_supported_keywords(self): return self._keywords.keys() -class MpodSearchResults(DbSearchResults): +class MpodSearchResults(DbSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on MPOD. """ @@ -156,7 +144,7 @@ def _get_url(self, result_dict): return self._base_url + result_dict['id'] + '.mpod' -class MpodEntry(CifEntry): +class MpodEntry(CifEntry): # pylint: disable=abstract-method """ Represents an entry from MPOD. """ @@ -167,7 +155,6 @@ def __init__(self, uri, **kwargs): :py:class:`aiida.tools.dbimporters.plugins.mpod.MpodEntry`, related to the supplied URI. 
""" - super().__init__(db_name='Material Properties Open Database', - db_uri='http://mpod.cimav.edu.mx', - uri=uri, - **kwargs) + super().__init__( + db_name='Material Properties Open Database', db_uri='http://mpod.cimav.edu.mx', uri=uri, **kwargs + ) diff --git a/aiida/tools/dbimporters/plugins/nninc.py b/aiida/tools/dbimporters/plugins/nninc.py index 0e7f13bcc3..ce2e724f47 100644 --- a/aiida/tools/dbimporters/plugins/nninc.py +++ b/aiida/tools/dbimporters/plugins/nninc.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - UpfEntry) +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the NNIN/C database.""" +from aiida.tools.dbimporters.baseclasses import DbImporter, DbSearchResults, UpfEntry class NnincDbImporter(DbImporter): @@ -24,14 +22,18 @@ def _str_clause(self, key, alias, values): Returns part of HTTP GET query for querying string fields. """ if not isinstance(values, str): - raise ValueError("incorrect value for keyword '{}' -- only " - 'strings and integers are accepted'.format(alias)) + raise ValueError( + "incorrect value for keyword '{}' -- only " + 'strings and integers are accepted'.format(alias) + ) return '{}={}'.format(key, values) - _keywords = {'xc_approximation': ['frmxcprox', _str_clause], - 'xc_type': ['frmxctype', _str_clause], - 'pseudopotential_class': ['frmspclass', _str_clause], - 'element': ['element', None]} + _keywords = { + 'xc_approximation': ['frmxcprox', _str_clause], + 'xc_type': ['frmxctype', _str_clause], + 'pseudopotential_class': ['frmspclass', _str_clause], + 'element': ['element', None] + } def __init__(self, **kwargs): self._query_url = 'http://nninc.cnf.cornell.edu/dd_search.php' @@ -45,20 +47,14 @@ def query_get(self, **kwargs): :return: a string with HTTP GET statement. """ get_parts = [] - for key in self._keywords.keys(): - if key in kwargs.keys(): + for key in self._keywords: + if key in kwargs: values = kwargs.pop(key) if self._keywords[key][1] is not None: - get_parts.append( - self._keywords[key][1](self, - self._keywords[key][0], - key, - values)) + get_parts.append(self._keywords[key][1](self, self._keywords[key][0], key, values)) - if kwargs.keys(): - raise NotImplementedError("search keyword(s) '" - "', '".join(kwargs.keys()) + \ - "' is(are) not implemented for NNIN/C") + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) return self._query_url + '?' + '&'.join(get_parts) @@ -91,7 +87,7 @@ def query(self, **kwargs): return NnincSearchResults([{'id': x} for x in results]) - def setup_db(self, query_url=None, **kwargs): + def setup_db(self, query_url=None, **kwargs): # pylint: disable=arguments-differ """ Changes the database connection details. """ @@ -113,7 +109,7 @@ def get_supported_keywords(self): return self._keywords.keys() -class NnincSearchResults(DbSearchResults): +class NnincSearchResults(DbSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on NNIN/C Pseudopotential Virtual Vault. @@ -156,7 +152,6 @@ def __init__(self, uri, **kwargs): :py:class:`aiida.tools.dbimporters.plugins.nninc.NnincEntry`, related to the supplied URI. 
""" - super().__init__(db_name='NNIN/C Pseudopotential Virtual Vault', - db_uri='http://nninc.cnf.cornell.edu', - uri=uri, - **kwargs) + super().__init__( + db_name='NNIN/C Pseudopotential Virtual Vault', db_uri='http://nninc.cnf.cornell.edu', uri=uri, **kwargs + ) diff --git a/aiida/tools/dbimporters/plugins/oqmd.py b/aiida/tools/dbimporters/plugins/oqmd.py index b2249ade74..5af78fe127 100644 --- a/aiida/tools/dbimporters/plugins/oqmd.py +++ b/aiida/tools/dbimporters/plugins/oqmd.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - CifEntry) +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the OQMD database.""" +from aiida.tools.dbimporters.baseclasses import DbImporter, DbSearchResults, CifEntry class OqmdDbImporter(DbImporter): @@ -24,8 +22,7 @@ def _str_clause(self, key, alias, values): Returns part of HTTP GET query for querying string fields. """ if not isinstance(values, str) and not isinstance(values, int): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings and integers are accepted") + raise ValueError("incorrect value for keyword '" + alias + "' -- only strings and integers are accepted") return '{}={}'.format(key, values) _keywords = {'element': ['element', None]} @@ -65,27 +62,22 @@ def query(self, **kwargs): results = [] for entry in entries: - response = urlopen('{}{}'.format(self._query_url, - entry)).read() - structures = re.findall(r'/materials/export/conventional/cif/(\d+)', - response) + response = urlopen('{}{}'.format(self._query_url, entry)).read() + structures = re.findall(r'/materials/export/conventional/cif/(\d+)', response) for struct in structures: results.append({'id': struct}) return OqmdSearchResults(results) - def setup_db(self, query_url=None, **kwargs): + def setup_db(self, query_url=None, **kwargs): # pylint: disable=arguments-differ """ Changes the database connection details. """ if query_url: self._query_url = query_url - if kwargs.keys(): - raise NotImplementedError( \ - "unknown database connection parameter(s): '" + \ - "', '".join(kwargs.keys()) + \ - "', available parameters: 'query_url'") + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) def get_supported_keywords(self): """ @@ -96,7 +88,7 @@ def get_supported_keywords(self): return self._keywords.keys() -class OqmdSearchResults(DbSearchResults): +class OqmdSearchResults(DbSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on OQMD. """ @@ -127,7 +119,7 @@ def _get_url(self, result_dict): return self._base_url + result_dict['id'] -class OqmdEntry(CifEntry): +class OqmdEntry(CifEntry): # pylint: disable=abstract-method """ Represents an entry from OQMD. """ @@ -138,7 +130,4 @@ def __init__(self, uri, **kwargs): :py:class:`aiida.tools.dbimporters.plugins.oqmd.OqmdEntry`, related to the supplied URI. 
""" - super().__init__(db_name='Open Quantum Materials Database', - db_uri='http://oqmd.org', - uri=uri, - **kwargs) + super().__init__(db_name='Open Quantum Materials Database', db_uri='http://oqmd.org', uri=uri, **kwargs) diff --git a/aiida/tools/dbimporters/plugins/pcod.py b/aiida/tools/dbimporters/plugins/pcod.py index 7a7595c66c..4550ab5634 100644 --- a/aiida/tools/dbimporters/plugins/pcod.py +++ b/aiida/tools/dbimporters/plugins/pcod.py @@ -7,10 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - -from aiida.tools.dbimporters.plugins.cod import (CodDbImporter, - CodSearchResults, CodEntry) - +""""Implementation of `DbImporter` for the PCOD database.""" +from aiida.tools.dbimporters.plugins.cod import CodDbImporter, CodSearchResults, CodEntry class PcodDbImporter(CodDbImporter): @@ -18,50 +16,25 @@ class PcodDbImporter(CodDbImporter): Database importer for Predicted Crystallography Open Database. """ - def _int_clause(self, *args, **kwargs): - return super()._int_clause(*args, **kwargs) - - def _composition_clause(self, *args, **kwargs): - return super()._composition_clause(*args, **kwargs) - - def _formula_clause(self, *args, **kwargs): - return super()._formula_clause(*args, **kwargs) - - def _volume_clause(self, *args, **kwargs): - return super()._volume_clause(*args, **kwargs) - - def _str_exact_clause(self, *args, **kwargs): - return super()._str_exact_clause(*args, **kwargs) - - def _length_clause(self, *args, **kwargs): - return super()._length_clause(*args, **kwargs) - - def _angle_clause(self, *args, **kwargs): - return super()._angle_clause(*args, **kwargs) - - def _str_fuzzy_clause(self, *args, **kwargs): - return super()._str_fuzzy_clause(*args, **kwargs) - - _keywords = {'id': ['file', _int_clause], - 'element': ['element', _composition_clause], - 'number_of_elements': ['nel', _int_clause], - 'formula': ['formula', _formula_clause], - 'volume': ['vol', _volume_clause], - 'spacegroup': ['sg', _str_exact_clause], - 'a': ['a', _length_clause], - 'b': ['b', _length_clause], - 'c': ['c', _length_clause], - 'alpha': ['alpha', _angle_clause], - 'beta': ['beta', _angle_clause], - 'gamma': ['gamma', _angle_clause], - 'text': ['text', _str_fuzzy_clause]} + _keywords = { + 'id': ['file', CodDbImporter._int_clause], + 'element': ['element', CodDbImporter._composition_clause], + 'number_of_elements': ['nel', CodDbImporter._int_clause], + 'formula': ['formula', CodDbImporter._formula_clause], + 'volume': ['vol', CodDbImporter._volume_clause], + 'spacegroup': ['sg', CodDbImporter._str_exact_clause], + 'a': ['a', CodDbImporter._length_clause], + 'b': ['b', CodDbImporter._length_clause], + 'c': ['c', CodDbImporter._length_clause], + 'alpha': ['alpha', CodDbImporter._angle_clause], + 'beta': ['beta', CodDbImporter._angle_clause], + 'gamma': ['gamma', CodDbImporter._angle_clause], + 'text': ['text', CodDbImporter._str_fuzzy_clause] + } def __init__(self, **kwargs): super().__init__(**kwargs) - self._db_parameters = {'host': 'www.crystallography.net', - 'user': 'pcod_reader', - 'passwd': '', - 'db': 'pcod'} + self._db_parameters = {'host': 'www.crystallography.net', 'user': 'pcod_reader', 'passwd': '', 'db': 'pcod'} self.setup_db(**kwargs) def query_sql(self, **kwargs): @@ -72,24 +45,16 @@ def query_sql(self, **kwargs): :return: string containing a SQL statement. 
""" sql_parts = [] - for key in self._keywords.keys(): - if key in kwargs.keys(): + for key in self._keywords: + if key in kwargs: values = kwargs.pop(key) if not isinstance(values, list): values = [values] - sql_parts.append( \ - '(' + self._keywords[key][1](self, \ - self._keywords[key][0], \ - key, \ - values) + \ - ')') - if len(kwargs.keys()) > 0: - raise NotImplementedError( \ - "search keyword(s) '" + \ - "', '".join(kwargs.keys()) + "' " + \ - 'is(are) not implemented for PCOD') - return 'SELECT file FROM data WHERE ' + \ - ' AND '.join(sql_parts) + sql_parts.append('(' + self._keywords[key][1](self, self._keywords[key][0], key, values) + ')') + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) + + return 'SELECT file FROM data WHERE ' + ' AND '.join(sql_parts) def query(self, **kwargs): """ @@ -113,7 +78,7 @@ def query(self, **kwargs): return PcodSearchResults(results) -class PcodSearchResults(CodSearchResults): +class PcodSearchResults(CodSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on PCOD. """ @@ -129,27 +94,25 @@ def _get_url(self, result_dict): :param result_dict: dictionary, describing an entry in the results. """ - return self._base_url + \ - result_dict['id'][0] + '/' + \ - result_dict['id'][0:3] + '/' + \ - result_dict['id'] + '.cif' + return self._base_url + result_dict['id'][0] + '/' + result_dict['id'][0:3] + '/' + result_dict['id'] + '.cif' -class PcodEntry(CodEntry): +class PcodEntry(CodEntry): # pylint: disable=abstract-method """ Represents an entry from PCOD. """ _license = 'CC0' - def __init__(self, uri, - db_name='Predicted Crystallography Open Database', - db_uri='http://www.crystallography.net/pcod', **kwargs): + def __init__( + self, + uri, + db_name='Predicted Crystallography Open Database', + db_uri='http://www.crystallography.net/pcod', + **kwargs + ): """ Creates an instance of :py:class:`aiida.tools.dbimporters.plugins.pcod.PcodEntry`, related to the supplied URI. 
""" - super().__init__(db_name=db_name, - db_uri=db_uri, - uri=uri, - **kwargs) + super().__init__(db_name=db_name, db_uri=db_uri, uri=uri, **kwargs) diff --git a/aiida/tools/dbimporters/plugins/tcod.py b/aiida/tools/dbimporters/plugins/tcod.py index 70cf74e37e..7abdbd1275 100644 --- a/aiida/tools/dbimporters/plugins/tcod.py +++ b/aiida/tools/dbimporters/plugins/tcod.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Importer implementation for the TCOD.""" +""""Implementation of `DbImporter` for the TCOD database.""" from aiida.tools.dbimporters.plugins.cod import (CodDbImporter, CodSearchResults, CodEntry) diff --git a/aiida/tools/importexport/common/utils.py b/aiida/tools/importexport/common/utils.py index 1c5214f3e5..0aef11888a 100644 --- a/aiida/tools/importexport/common/utils.py +++ b/aiida/tools/importexport/common/utils.py @@ -8,8 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ Utility functions for import/export of AiiDA entities """ -# pylint: disable=inconsistent-return-statements,too-many-branches,too-many-return-statements -# pylint: disable=too-many-nested-blocks,too-many-locals +# pylint: disable=too-many-branches,too-many-return-statements,too-many-nested-blocks,too-many-locals from html.parser import HTMLParser import urllib.request import urllib.parse diff --git a/aiida/tools/importexport/dbexport/__init__.py b/aiida/tools/importexport/dbexport/__init__.py index ad2a43371f..d6646024c6 100644 --- a/aiida/tools/importexport/dbexport/__init__.py +++ b/aiida/tools/importexport/dbexport/__init__.py @@ -20,7 +20,7 @@ from aiida.common.folders import RepositoryFolder, SandboxFolder, Folder from aiida.common.lang import type_check from aiida.common.log import override_log_formatter, LOG_LEVEL_REPORT -from aiida.orm.utils.repository import Repository +from aiida.orm.utils._repository import Repository from aiida.tools.importexport.common import exceptions, get_progress_bar, close_progress_bar from aiida.tools.importexport.common.config import EXPORT_VERSION, NODES_EXPORT_SUBFOLDER diff --git a/aiida/tools/importexport/dbimport/backends/django/__init__.py b/aiida/tools/importexport/dbimport/backends/django/__init__.py index 877b5f7119..7d1854bb4a 100644 --- a/aiida/tools/importexport/dbimport/backends/django/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/django/__init__.py @@ -23,7 +23,7 @@ from aiida.common.log import override_log_formatter from aiida.common.utils import grouper, get_object_from_string from aiida.manage.configuration import get_config_option -from aiida.orm.utils.repository import Repository +from aiida.orm.utils._repository import Repository from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.tools.importexport.common import exceptions, get_progress_bar, close_progress_bar diff --git a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py index 134e4a4a60..ecf1f3429e 100644 --- a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py @@ -24,7 +24,7 @@ from aiida.common.utils import get_object_from_string from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.orm.utils.links import link_triple_exists, 
validate_link -from aiida.orm.utils.repository import Repository +from aiida.orm.utils._repository import Repository from aiida.tools.importexport.common import exceptions, get_progress_bar, close_progress_bar from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip diff --git a/aiida/tools/importexport/dbimport/utils.py b/aiida/tools/importexport/dbimport/utils.py index faaac81e1c..d25aab0cc0 100644 --- a/aiida/tools/importexport/dbimport/utils.py +++ b/aiida/tools/importexport/dbimport/utils.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ Utility functions for import of AiiDA entities """ -# pylint: disable=inconsistent-return-statements,too-many-branches +# pylint: disable=too-many-branches import os import click diff --git a/aiida/tools/importexport/migration/__init__.py b/aiida/tools/importexport/migration/__init__.py index 402147ff7b..fa23050277 100644 --- a/aiida/tools/importexport/migration/__init__.py +++ b/aiida/tools/importexport/migration/__init__.py @@ -52,8 +52,8 @@ def migrate_recursively(metadata, data, folder, version=EXPORT_VERSION): try: if old_version == version: - raise ArchiveMigrationError('Your export file is already at the version {}'.format(version)) - elif old_version > version: + return old_version + if old_version > version: raise ArchiveMigrationError('Backward migrations are not supported') elif old_version in MIGRATE_FUNCTIONS: MIGRATE_FUNCTIONS[old_version](metadata, data, folder) diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py index 70b6df7231..362d069e22 100644 --- a/aiida/tools/visualization/graph.py +++ b/aiida/tools/visualization/graph.py @@ -221,11 +221,11 @@ def default_node_sublabels(node): elif class_node_type == 'data.bool.Bool.': sublabel = '{}'.format(node.get_attribute('value', '')) elif class_node_type == 'data.code.Code.': - sublabel = '{}@{}'.format(os.path.basename(node.get_execname()), node.get_computer_name()) + sublabel = '{}@{}'.format(os.path.basename(node.get_execname()), node.computer.label) elif class_node_type == 'data.singlefile.SinglefileData.': sublabel = node.filename elif class_node_type == 'data.remote.RemoteData.': - sublabel = '@{}'.format(node.get_computer_name()) + sublabel = '@{}'.format(node.computer.label) elif class_node_type == 'data.structure.StructureData.': sublabel = node.get_formula() elif class_node_type == 'data.cif.CifData.': diff --git a/aiida/transports/cli.py b/aiida/transports/cli.py index f6d5b22e04..eba3ac0c76 100644 --- a/aiida/transports/cli.py +++ b/aiida/transports/cli.py @@ -26,10 +26,10 @@ # pylint: disable=unused-argument def match_comp_transport(ctx, param, computer, transport_type): """Check the computer argument against the transport type.""" - if computer.get_transport_type() != transport_type: + if computer.transport_type != transport_type: echo.echo_critical( 'Computer {} has transport of type "{}", not {}!'.format( - computer.name, computer.get_transport_type(), transport_type + computer.label, computer.transport_type, transport_type ) ) return computer @@ -42,12 +42,12 @@ def configure_computer_main(computer, user, **kwargs): user = user or orm.User.objects.get_default() - echo.echo_info('Configuring computer {} for user {}.'.format(computer.name, user.email)) + echo.echo_info('Configuring computer {} for user {}.'.format(computer.label, user.email)) if user.email != get_manager().get_profile().default_user: 
echo.echo_info('Configuring different user, defaults may not be appropriate.') computer.configure(user=user, **kwargs) - echo.echo_success('{} successfully configured for {}'.format(computer.name, user.email)) + echo.echo_success('{} successfully configured for {}'.format(computer.label, user.email)) def common_params(command_func): @@ -71,25 +71,33 @@ def transport_option_default(name, computer): return default -def interactive_default(transport_type, key, also_noninteractive=False): - """Create a contextual_default value callback for an auth_param key.""" +def interactive_default(key, also_non_interactive=False): + """Create a contextual_default value callback for an auth_param key. + + :param key: the name of the option. + :param also_non_interactive: indicates whether this option should provide a default also in non-interactive mode. If + False, the option will raise `MissingParameter` if no explicit value is specified when the command is called in + non-interactive mode. + """ @with_dbenv() def get_default(ctx): """Determine the default value from the context.""" from aiida import orm + if not also_non_interactive and ctx.params['non_interactive']: + raise click.MissingParameter() + user = ctx.params['user'] or orm.User.objects.get_default() computer = ctx.params['computer'] + try: authinfo = orm.AuthInfo.objects.get(dbcomputer_id=computer.id, aiidauser_id=user.id) except NotExistent: authinfo = orm.AuthInfo(computer=computer, user=user) - non_interactive = ctx.params['non_interactive'] - old_authparams = authinfo.get_auth_params() - if not also_noninteractive and non_interactive: - raise click.MissingParameter() - suggestion = old_authparams.get(key) + + auth_params = authinfo.get_auth_params() + suggestion = auth_params.get(key) suggestion = suggestion or transport_option_default(key, computer) return suggestion @@ -99,25 +107,25 @@ def get_default(ctx): def create_option(name, spec): """Create a click option from a name and partial specs as used in transport auth_options.""" from copy import deepcopy + spec = deepcopy(spec) name_dashed = name.replace('_', '-') option_name = '--{}'.format(name_dashed) existing_option = spec.pop('option', None) + if spec.pop('switch', False): option_name = '--{name}/--no-{name}'.format(name=name_dashed) - kwargs = {} - if 'default' in spec: - kwargs['show_default'] = True - else: - kwargs['contextual_default'] = interactive_default( - 'ssh', name, also_noninteractive=spec.pop('non_interactive_default', False) - ) + kwargs = {'cls': InteractiveOption, 'show_default': True} + + non_interactive_default = spec.pop('non_interactive_default', False) + kwargs['contextual_default'] = interactive_default(name, also_non_interactive=non_interactive_default) - kwargs['cls'] = InteractiveOption kwargs.update(spec) + if existing_option: return existing_option(**kwargs) + return click.option(option_name, **kwargs) diff --git a/aiida/transports/plugins/local.py b/aiida/transports/plugins/local.py index efc3628fdb..1a563e1f21 100644 --- a/aiida/transports/plugins/local.py +++ b/aiida/transports/plugins/local.py @@ -42,7 +42,6 @@ class LocalTransport(Transport): ``unset PYTHONPATH`` if you plan on running calculations that use Python. 
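For reference, this is roughly what ``create_option`` now produces for an auth_param spec. The ``username`` spec is the one added to ``SshTransport._valid_connect_options`` further below, and the expansion is a paraphrase of the code above rather than its literal output::

    spec = {
        'prompt': 'User name',
        'help': 'Login user name on the remote machine.',
        'non_interactive_default': True,
    }
    # create_option('username', spec) roughly expands to:
    #   click.option(
    #       '--username',
    #       cls=InteractiveOption,
    #       show_default=True,
    #       contextual_default=interactive_default('username', also_non_interactive=True),
    #       prompt='User name',
    #       help='Login user name on the remote machine.',
    #   )
    # In non-interactive mode the contextual default raises click.MissingParameter
    # unless the spec sets `non_interactive_default`.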
""" - # There are no valid parameters for the local transport _valid_auth_options = [] # There is no real limit on how fast you can safely connect to a localhost, unlike often the case with SSH transport @@ -744,7 +743,9 @@ def _exec_command_internal(self, command, **kwargs): # pylint: disable=unused-a # Note: The outer shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. - command = 'bash -l -c ' + escape_for_bash(command) + bash_commmand = self._bash_command_str + '-c ' + + command = bash_commmand + escape_for_bash(command) proc = subprocess.Popen( command, @@ -803,11 +804,8 @@ def gotocomputer_command(self, remotedir): :param str remotedir: the full path of the remote directory """ - script = ' ; '.join([ - 'if [ -d {escaped_remotedir} ]', 'then cd {escaped_remotedir}', 'bash', "else echo ' ** The directory'", - "echo ' ** {remotedir}'", "echo ' ** seems to have been deleted, I logout...'", 'fi' - ]).format(escaped_remotedir="'{}'".format(remotedir), remotedir=remotedir) - cmd = 'bash -c "{}"'.format(script) + connect_string = self._gotocomputer_string(remotedir) + cmd = 'bash -c {}'.format(connect_string) return cmd def rename(self, oldpath, newpath): diff --git a/aiida/transports/plugins/ssh.py b/aiida/transports/plugins/ssh.py index 0ca09c3b47..0ffd7b06e1 100644 --- a/aiida/transports/plugins/ssh.py +++ b/aiida/transports/plugins/ssh.py @@ -68,21 +68,27 @@ class SshTransport(Transport): # pylint: disable=too-many-public-methods # I disable 'password' and 'pkey' to avoid these data to get logged in the # aiida log file. _valid_connect_options = [ - ('username', { - 'prompt': 'User name', - 'help': 'user name for the computer', - 'non_interactive_default': True - }), - ('port', { - 'option': options.PORT, - 'prompt': 'port Nr', - 'non_interactive_default': True - }), + ( + 'username', { + 'prompt': 'User name', + 'help': 'Login user name on the remote machine.', + 'non_interactive_default': True + } + ), + ( + 'port', + { + 'option': options.PORT, + 'prompt': 'Port number', + 'non_interactive_default': True, + }, + ), ( 'look_for_keys', { + 'default': True, 'switch': True, 'prompt': 'Look for keys', - 'help': 'switch automatic key file discovery on / off', + 'help': 'Automatically look for private keys in the ~/.ssh folder.', 'non_interactive_default': True } ), @@ -90,7 +96,7 @@ class SshTransport(Transport): # pylint: disable=too-many-public-methods 'key_filename', { 'type': AbsolutePathOrEmptyParamType(dir_okay=False, exists=True), 'prompt': 'SSH key file', - 'help': 'Manually pass a key file if default path is not set in ssh config', + 'help': 'Absolute path to your private SSH key. Leave empty to use the path set in the SSH config.', 'non_interactive_default': True } ), @@ -98,62 +104,70 @@ class SshTransport(Transport): # pylint: disable=too-many-public-methods 'timeout', { 'type': int, 'prompt': 'Connection timeout in s', - 'help': 'time in seconds to wait for connection before giving up', + 'help': 'Time in seconds to wait for connection before giving up. Leave empty to use default value.', 'non_interactive_default': True } ), ( 'allow_agent', { + 'default': False, 'switch': True, 'prompt': 'Allow ssh agent', - 'help': 'switch to allow or disallow ssh agent', + 'help': 'Switch to allow or disallow using an SSH agent.', 'non_interactive_default': True } ), ( 'proxy_command', { 'prompt': 'SSH proxy command', - 'help': 'SSH proxy command', + 'help': 'SSH proxy command for tunneling through a proxy server.' 
+ ' Leave empty to parse the proxy command from the SSH config file.', 'non_interactive_default': True } ), # Managed 'manually' in connect ( 'compress', { + 'default': True, 'switch': True, 'prompt': 'Compress file transfers', - 'help': 'switch file transfer compression on / off', + 'help': 'Turn file transfer compression on or off.', 'non_interactive_default': True } ), ( 'gss_auth', { + 'default': False, 'type': bool, 'prompt': 'GSS auth', - 'help': 'GSS auth for kerberos', + 'help': 'Enable when using GSS kerberos token to connect.', 'non_interactive_default': True } ), ( 'gss_kex', { + 'default': False, 'type': bool, 'prompt': 'GSS kex', - 'help': 'GSS kex for kerberos', + 'help': 'GSS kex for kerberos, if not configured in SSH config file.', 'non_interactive_default': True } ), ( 'gss_deleg_creds', { + 'default': False, 'type': bool, 'prompt': 'GSS deleg_creds', - 'help': 'GSS deleg_creds for kerberos', + 'help': 'GSS deleg_creds for kerberos, if not configured in SSH config file.', + 'non_interactive_default': True + } + ), + ( + 'gss_host', { + 'prompt': 'GSS host', + 'help': 'GSS host for kerberos, if not configured in SSH config file.', 'non_interactive_default': True } ), - ('gss_host', { - 'prompt': 'GSS host', - 'help': 'GSS host for kerberos', - 'non_interactive_default': True - }), # for Kerberos support through python-gssapi ] @@ -172,22 +186,29 @@ class SshTransport(Transport): # pylint: disable=too-many-public-methods _valid_auth_options = _valid_connect_options + [ ( 'load_system_host_keys', { + 'default': True, 'switch': True, 'prompt': 'Load system host keys', - 'help': 'switch loading system host keys on / off', + 'help': 'Load system host keys from default SSH location.', 'non_interactive_default': True } ), ( 'key_policy', { + 'default': 'RejectPolicy', 'type': click.Choice(['RejectPolicy', 'WarningPolicy', 'AutoAddPolicy']), 'prompt': 'Key policy', - 'help': 'SSH key policy', + 'help': 'SSH key policy if host is not known.', 'non_interactive_default': True } ) ] + # Max size of log message to print in _exec_command_internal. + # Unlimited by default, but can be cropped by a subclass + # if too large commands are sent, clogging the outputs or logs + _MAX_EXEC_COMMAND_LOG_SIZE = None + @classmethod def _get_username_suggestion_string(cls, computer): """ @@ -421,14 +442,29 @@ def open(self): ) raise - # Open also a SFTPClient - self._sftp = self._client.open_sftp() - # Set the current directory to a explicit path, and not to None - self._sftp.chdir(self._sftp.normalize('.')) + # Open also a File transport client. SFTP by default, pure SSH in ssh_only + self.open_file_transport() + + return self + + def open_file_transport(self): + """ + Open the SFTP channel, and handle error by directing customer to try another transport + """ + from aiida.common.exceptions import InvalidOperation + from paramiko.ssh_exception import SSHException + try: + self._sftp = self._client.open_sftp() + except SSHException: + raise InvalidOperation( + 'Error in ssh transport plugin. This may be due to the remote computer not supporting SFTP. ' + 'Try setting it up with the aiida.transports:ssh_only transport from the aiida-sshonly plugin instead.' + ) self._is_open = True - return self + # Set the current directory to a explicit path, and not to None + self._sftp.chdir(self._sftp.normalize('.')) def close(self): """ @@ -505,7 +541,7 @@ def chdir(self, path): # Note: I don't store the result of the function; if I have no # read permissions, this will raise an exception. 
try: - self.sftp.stat('.') + self.stat('.') except IOError as exc: if 'Permission denied' in str(exc): self.chdir(old_path) @@ -517,6 +553,35 @@ def normalize(self, path='.'): """ return self.sftp.normalize(path) + def stat(self, path): + """ + Retrieve information about a file on the remote system. The return + value is an object whose attributes correspond to the attributes of + Python's ``stat`` structure as returned by ``os.stat``, except that it + contains fewer fields. + The fields supported are: ``st_mode``, ``st_size``, ``st_uid``, + ``st_gid``, ``st_atime``, and ``st_mtime``. + + :param str path: the filename to stat + + :return: a `paramiko.sftp_attr.SFTPAttributes` object containing + attributes about the given file. + """ + return self.sftp.stat(path) + + def lstat(self, path): + """ + Retrieve information about a file on the remote system, without + following symbolic links (shortcuts). This otherwise behaves exactly + the same as `stat`. + + :param str path: the filename to stat + + :return: a `paramiko.sftp_attr.SFTPAttributes` object containing + attributes about the given file. + """ + return self.sftp.lstat(path) + def getcwd(self): """ Return the current working directory for this SFTP session, as @@ -647,7 +712,7 @@ def isdir(self, path): if not path: return False try: - return S_ISDIR(self.sftp.stat(path).st_mode) + return S_ISDIR(self.stat(path).st_mode) except IOError as exc: if getattr(exc, 'errno', None) == 2: # errno=2 means path does not exist: I return False @@ -826,7 +891,7 @@ def puttree(self, localpath, remotepath, callback=None, dereference=True, overwr this_basename = os.path.relpath(path=this_source[0], start=localpath) try: - self.sftp.stat(os.path.join(remotepath, this_basename)) + self.stat(os.path.join(remotepath, this_basename)) except IOError as exc: import errno if exc.errno == errno.ENOENT: # Missing file @@ -994,7 +1059,7 @@ def get_attribute(self, path): """ from aiida.transports.util import FileAttribute - paramiko_attr = self.sftp.lstat(path) + paramiko_attr = self.lstat(path) aiida_attr = FileAttribute() # map the paramiko class into the aiida one # note that paramiko object contains more informations than the aiida @@ -1167,11 +1232,11 @@ def isfile(self, path): try: self.logger.debug( "stat for path '{}' ('{}'): {} [{}]".format( - path, self.sftp.normalize(path), self.sftp.stat(path), - self.sftp.stat(path).st_mode + path, self.normalize(path), self.stat(path), + self.stat(path).st_mode ) ) - return S_ISREG(self.sftp.stat(path).st_mode) + return S_ISREG(self.stat(path).st_mode) except IOError as exc: if getattr(exc, 'errno', None) == 2: # errno=2 means path does not exist: I return False @@ -1212,11 +1277,13 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1): # else: command_to_execute = command - self.logger.debug('Command to be executed: {}'.format(command_to_execute)) + self.logger.debug('Command to be executed: {}'.format(command_to_execute[:self._MAX_EXEC_COMMAND_LOG_SIZE])) # Note: The default shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. 
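The new ``stat``/``lstat`` wrappers keep callers off the private ``self.sftp`` attribute. A usage sketch, assuming an already opened ``SshTransport`` instance named ``transport`` and placeholder remote paths::

    # `stat` follows symbolic links, `lstat` does not; both return paramiko SFTPAttributes.
    attributes = transport.stat('/remote/path/data.txt')
    print(attributes.st_size, attributes.st_mtime)

    link_attributes = transport.lstat('/remote/path/symlink')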
- channel.exec_command('bash -l -c ' + escape_for_bash(command_to_execute)) + bash_commmand = self._bash_command_str + '-c ' + + channel.exec_command(bash_commmand + escape_for_bash(command_to_execute)) stdin = channel.makefile('wb', bufsize) stdout = channel.makefile('rb', bufsize) @@ -1282,21 +1349,23 @@ def gotocomputer_command(self, remotedir): further_params.append('-i {}'.format(escape_for_bash(self._connect_args['key_filename']))) further_params_str = ' '.join(further_params) - # I use triple strings because I both have single and double quotes, but I still want everything in - # a single line - connect_string = ( - """ssh -t {machine} {further_params} "if [ -d {escaped_remotedir} ] ;""" - """ then cd {escaped_remotedir} ; bash -l ; else echo ' ** The directory' ; """ - """echo ' ** {remotedir}' ; echo ' ** seems to have been deleted, I logout...' ; fi" """.format( - further_params=further_params_str, - machine=self._machine, - escaped_remotedir="'{}'".format(remotedir), - remotedir=remotedir - ) + + connect_string = self._gotocomputer_string(remotedir) + cmd = 'ssh -t {machine} {further_params} {connect_string}'.format( + further_params=further_params_str, + machine=self._machine, + connect_string=connect_string, ) + return cmd - # print connect_string - return connect_string + def _symlink(self, source, dest): + """ + Wrap SFTP symlink call without breaking API + + :param source: source of link + :param dest: link to create + """ + self.sftp.symlink(source, dest) def symlink(self, remotesource, remotedestination): """ @@ -1319,9 +1388,9 @@ def symlink(self, remotesource, remotedestination): for this_source in self.glob(source): # create the name of the link: take the last part of the path this_dest = os.path.join(remotedestination, os.path.split(this_source)[-1]) - self.sftp.symlink(this_source, this_dest) + self._symlink(this_source, this_dest) else: - self.sftp.symlink(source, dest) + self._symlink(source, dest) def path_exists(self, path): """ @@ -1329,7 +1398,7 @@ def path_exists(self, path): """ import errno try: - self.sftp.stat(path) + self.stat(path) except IOError as exc: if exc.errno == errno.ENOENT: return False diff --git a/aiida/transports/transport.py b/aiida/transports/transport.py index 87130aedd3..b98c11b480 100644 --- a/aiida/transports/transport.py +++ b/aiida/transports/transport.py @@ -47,21 +47,50 @@ class Transport(abc.ABC): _valid_auth_params = None _MAGIC_CHECK = re.compile('[*?[]') _valid_auth_options = [] - _common_auth_options = [( - 'safe_interval', { - 'type': float, - 'prompt': 'Connection cooldown time (s)', - 'help': 'Minimum time interval in seconds between consecutive connection openings', - 'callback': validate_positive_number - } - )] + _common_auth_options = [ + ( + 'use_login_shell', { + 'default': + True, + 'switch': + True, + 'prompt': + 'Use login shell when executing command', + 'help': + ' Not using a login shell can help suppress potential' + ' spurious text output that can prevent AiiDA from parsing the output of commands,' + ' but may result in some startup files (.profile) not being sourced.', + 'non_interactive_default': + True + } + ), + ( + 'safe_interval', { + 'type': float, + 'prompt': 'Connection cooldown time (s)', + 'help': 'Minimum time interval in seconds between opening new connections.', + 'callback': validate_positive_number + } + ), + ] def __init__(self, *args, **kwargs): # pylint: disable=unused-argument """ __init__ method of the Transport base class. 
+ + :param safe_interval: (optional, default self._DEFAULT_SAFE_OPEN_INTERVAL) + Minimum time interval in seconds between opening new connections. + :param use_login_shell: (optional, default True) + if False, do not use a login shell when executing command """ from aiida.common import AIIDA_LOGGER self._safe_open_interval = kwargs.pop('safe_interval', self._DEFAULT_SAFE_OPEN_INTERVAL) + self._use_login_shell = kwargs.pop('use_login_shell', True) + if self._use_login_shell: + self._bash_command_str = 'bash -l ' + else: + self._bash_command_str = 'bash ' + self._logger = AIIDA_LOGGER.getChild('transport').getChild(self.__class__.__name__) self._logger_extra = None self._is_open = False @@ -159,11 +188,17 @@ def get_short_doc(cls): @classmethod def get_valid_transports(cls): + """Return the list of registered transport entry points. + + .. deprecated:: 1.4.0 + + Will be removed in `2.0.0`, use `aiida.plugins.entry_point.get_entry_point_names` instead """ - :return: a list of existing plugin names - """ + import warnings + from aiida.common.warnings import AiidaDeprecationWarning from aiida.plugins.entry_point import get_entry_point_names - + message = 'method is deprecated, use `aiida.plugins.entry_point.get_entry_point_names` instead' + warnings.warn(message, AiidaDeprecationWarning) # pylint: disable=no-member return get_entry_point_names('aiida.transports') @classmethod @@ -193,6 +228,13 @@ def _get_safe_interval_suggestion_string(cls, computer): # pylint: disable=unus """ return cls._DEFAULT_SAFE_OPEN_INTERVAL + @classmethod + def _get_use_login_shell_suggestion_string(cls, computer): # pylint: disable=unused-argument + """ + Return a suggestion for the specific field. + """ + return 'True' + @property def logger(self): """ @@ -756,6 +798,18 @@ def glob0(self, dirname, basename): def has_magic(self, string): return self._MAGIC_CHECK.search(string) is not None + def _gotocomputer_string(self, remotedir): + """command executed when goto computer.""" + connect_string = ( + """ "if [ -d {escaped_remotedir} ] ;""" + """ then cd {escaped_remotedir} ; {bash_command} ; else echo ' ** The directory' ; """ + """echo ' ** {remotedir}' ; echo ' ** seems to have been deleted, I logout...' ; fi" """.format( + bash_command=self._bash_command_str, escaped_remotedir="'{}'".format(remotedir), remotedir=remotedir + ) + ) + + return connect_string + class TransportInternalError(InternalError): """ diff --git a/aiida/workflows/arithmetic/multiply_add.py b/aiida/workflows/arithmetic/multiply_add.py index 125df06698..d63ab0329e 100644 --- a/aiida/workflows/arithmetic/multiply_add.py +++ b/aiida/workflows/arithmetic/multiply_add.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=inconsistent-return-statements,no-member +# pylint: disable=no-member # start-marker for docs """Implementation of the MultiplyAddWorkChain for testing and demonstration purposes.""" from aiida.orm import Code, Int diff --git a/docs/source/developer_guide/core/caching.rst b/docs/source/developer_guide/core/caching.rst index 9853889b8b..3885222ae6 100644 --- a/docs/source/developer_guide/core/caching.rst +++ b/docs/source/developer_guide/core/caching.rst @@ -1,8 +1,8 @@ Caching: implementation details +++++++++++++++++++++++++++++++ -This section covers some details of the caching mechanism which are not discussed in the :ref:`user guide `. 
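The net effect of the new ``use_login_shell`` flag on executed commands, sketched with a hypothetical helper that mirrors the string handling in ``Transport.__init__`` and the ``_exec_command_internal`` implementations above::

    from aiida.common.escaping import escape_for_bash

    def wrapped_command(command, use_login_shell=True):
        """Build the shell invocation the way the transports now do."""
        bash_command_str = 'bash -l ' if use_login_shell else 'bash '
        return bash_command_str + '-c ' + escape_for_bash(command)

    # wrapped_command('echo $HOME')                        -> "bash -l -c 'echo $HOME'"
    # wrapped_command('echo $HOME', use_login_shell=False) -> "bash -c 'echo $HOME'"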
-If you are developing plugins and want to modify the caching behavior of your classes, we recommend you read :ref:`this section ` first. +This section covers some details of the caching mechanism which are not discussed in the :ref:`user guide `. +If you are developing plugins and want to modify the caching behavior of your classes, we recommend you read :ref:`this section ` first. .. _devel_controlling_hashing: diff --git a/docs/source/developer_guide/core/internals.rst b/docs/source/developer_guide/core/internals.rst index 2c97e5a450..ce32443512 100644 --- a/docs/source/developer_guide/core/internals.rst +++ b/docs/source/developer_guide/core/internals.rst @@ -17,18 +17,18 @@ This means that as soon as a node is stored, any attempt to alter its attributes Certain subclasses of nodes need to adapt this behavior however, as for example in the case of the :py:class:`~aiida.orm.nodes.process.process.ProcessNode` class (see `calculation updatable attributes`_), but since the immutability of stored nodes is a core concept of AiiDA, this behavior is nonetheless enforced on the node level. This guarantees that any subclasses of the Node class will respect this behavior unless it is explicitly overriden. -Node methods +Entity methods ****************** -- :py:meth:`~aiida.orm.utils.node.clean_value` takes a value and returns an object which can be serialized for storage in the database. +- :py:meth:`~aiida.orm.implementation.utils.clean_value` takes a value and returns an object which can be serialized for storage in the database. Such an object must be able to be subsequently deserialized without changing value. If a simple datatype is passed (integer, float, etc.), a check is performed to see if it has a value of ``nan`` or ``inf``, as these cannot be stored. Otherwise, if a list, tuple, dictionary, etc., is passed, this check is performed for each value it contains. This is done recursively, automatically handling the case of nested objects. It is important to note that iterable type objects are converted to lists during this process, and mappings are converted to normal dictionaries. For efficiency reasons, the cleaning of attribute values is delayed to the last moment possible. - This means that for an unstored node, new attributes are not cleaned but simply set in the cache of the underlying database model. - When the node is then stored, all attributes are cleaned in one fell swoop and if successful the values are flushed to the database. - Once a node is stored, there no longer is such a cache and so the attribute values are cleaned straight away for each call. + This means that for an unstored entity, new attributes are not cleaned but simply set in the cache of the underlying database model. + When the entity is then stored, all attributes are cleaned in one fell swoop and if successful the values are flushed to the database. + Once an entity is stored, there no longer is such a cache and so the attribute values are cleaned straight away for each call. The same mechanism holds for the cleaning of the values of extras. diff --git a/docs/source/get_started/codes.rst b/docs/source/get_started/codes.rst deleted file mode 100644 index 7e54477773..0000000000 --- a/docs/source/get_started/codes.rst +++ /dev/null @@ -1,181 +0,0 @@ -.. _setup_code: - -************ -Setup a code -************ - -Once you have at least one computer configured, you can configure the codes. 
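A short illustration of the ``clean_value`` behaviour described in ``internals.rst`` above. This is a sketch based only on the documented behaviour (the exact exception raised for non-storable values is not specified here):

.. code-block:: python

    from aiida.orm.implementation.utils import clean_value

    # Iterables are converted to lists, recursively; mappings remain dictionaries.
    clean_value({'kpoints': (4, 4, 4), 'cutoffs': [30.0, 240.0]})
    # -> {'kpoints': [4, 4, 4], 'cutoffs': [30.0, 240.0]}

    # Values such as nan or inf cannot be stored and are rejected.
    clean_value(float('nan'))  # expected to raise a validation error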
-In AiiDA, for full reproducibility of each calculation, we store each code in -the database, and attach to each calculation a given code. This has the further -advantage to make very easy to query for all calculations that were run with -a given code (for instance because I am looking for phonon calculations, or -because I discovered that a specific version had a bug and I want to rerun -the calculations). - -In AiiDA, we distinguish two types of codes: **remote** codes and **local** codes, -where the distinction between the two is described here below. - -Remote codes ------------- -With remote codes we denote codes that are installed/compiled -on the remote computer. Indeed, this is very often the case for codes installed -in supercomputers for high-performance computing applications, because the -code is typically installed and optimized on the supercomputer. - -In AiiDA, a remote code is identified by two mandatory pieces of information: - -* A computer on which the code is (that must be a previously configured computer); -* The absolute path of the code executable on the remote computer. - -Local codes ------------ -With local codes we denote codes for which the code is not -already present on the remote machine, and must be copied for every submission. -This is the case if you have for instance a small, machine-independent Python -script that you did not copy previously in all your clusters. - -In AiiDA, a local code can be set up by specifying: - -* A folder, containing all files to be copied over at every submission -* The name of executable file among the files inside the folder specified above - -Setting up a code ------------------ - -The:: - - verdi code - -command allows to manage codes in AiiDA. - -To setup a new code, you execute:: - - verdi code setup - -and you will be guided through a process to setup your code. - - -.. tip:: The setup will ask you a few pieces of information. At every prompt, you can - type the ``?`` character and press ```` to get a more detailed - explanation of what is being asked. - -You will be asked for: - -* **Label**: A label to refer to this code. Note: this label is not enforced - to be unique. However, if you try to keep it unique, at least within - the same computer, you can use it later - to refer and use to your code. Otherwise, you need to remember its ``ID`` or ``UUID``. - -* **Description**: A human-readable description of this code (for instance "Quantum - Espresso v.5.0.2 with 5.0.3 patches, pw.x code, compiled with openmpi"). - -* **Default calculation input plugin**: A string that identifies the default input plugin to - be used to generate new calculations to use with this code. - This string has to be a valid string recognized by the ``CalculationFactory`` - function. To get the list of all available Calculation plugin strings, - use the ``verdi plugin list aiida.calculations`` command. - -* **Installed on target computer**: either True (for local codes) or False (for remote - codes). For the meaning of the distinction, see above. Depending - on your choice, you will be asked for: - - * REMOTE CODES: - - * **Remote computer name**: The computer name on which the code resides, - as configured and stored in the AiiDA database. - - * **Remote absolute path**: The (full) absolute path of the code executable - on the remote machine, *including the name of the executable*. 
- - * LOCAL CODES: - - * **Local directory containing the code**: The absolute path where the executable and all other - files needed to run the code are stored; these will be copied over to - the remote computers for every submitted calculation. - * **Relative path of executable inside code folder**: The relative path of the executable - file inside the folder entered in the previous step. - - -At the end of these steps, you will be prompted to edit a script, -and you will have the opportunity to include ``bash`` commands that will -be executed *before* running the submission script (after the -'pre execution script' lines) and *after* running the submission script -(after the 'Post execution script' separator). -This is intended for code-dependent settings, for instance to load modules or set variables -that are needed by the code. For example:: - - module load intelmpi - - -At the end, you will get a confirmation command, and also the ID of the code in the -database (the ``pk``, i.e. the principal key, and the ``uuid``). - -In a manner analogous to a computer setup, it is also possible to provide some (or all) the information -described above via a configuration file using :: - - verdi code setup --config code.yml - -where ``code.yml`` is a configuration file in the -`YAML format `_. - -This file contains the information in a series of key:value pairs: - -.. code-block:: yaml - - --- - label: "qe-6.3-pw" - description: "quantum_espresso v6.3" - input_plugin: "quantumespresso.pw" - on_computer: true - remote_abs_path: "/path/to/code/pw.x" - computer: "localhost" - prepend_text: | - module load module1 - module load module2 - append_text: " " - -.. tip:: The keys mirror the available options of the command, which you can print using: :: - - verdi code setup --help - - Note the syntax differences: remove the ``--`` prefix - and replace ``-`` within the keys by the underscore ``_``. - - -.. note:: Codes are a subclass of the :py:class:`Node ` class, - and as such you can attach any set of attributes to the code. These can - be extremely useful for querying: for instance, you can attach the version - of the code as an attribute, or the code family (for instance: "pw.x code of - Quantum Espresso") to later query for all runs done with a pw.x code and - version more recent than 5.0.0, for instance. However, in the - present AiiDA version you cannot add attributes from the command line using - ``verdi``, but you have to do it using Python code. - -.. note:: You can change the label of a code by using the following command:: - - verdi code relabel "ID" "new-label" - - (Without the quotation marks!) "ID" can either be the numeric ID (PK) of - the code (preferentially), or possibly its label (or ``label@computername``), - if this string uniquely identifies a code. - - You can also list all available codes (and their relative IDs) with:: - - verdi code list - - which also accepts some flags to filter only codes on a - given computer, only codes using a specific plugin, etc.; use the ``-h`` - command line option to see the documentation of all possible options. - - You can then get the information of a specific code with:: - - verdi code show "ID" - - Finally, to delete a code use:: - - verdi code delete "ID" - - (only if it wasn't used by any calculation, otherwise an exception - is raised). - -And now, you are ready to launch your calculations! 
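As a complement to the ``verdi code setup`` wizard described above, a remote code can also be created from the Python API. This is only a sketch, assuming a computer labelled ``localhost`` has already been set up and configured; the label and description below are arbitrary examples:

.. code-block:: python

    from aiida import orm

    computer = orm.load_computer('localhost')  # previously configured computer
    bash_code = orm.Code(remote_computer_exec=(computer, '/bin/bash'))
    bash_code.label = 'bash'
    bash_code.description = 'GNU bash, used here as a trivial remote code'
    bash_code.store()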
diff --git a/docs/source/get_started/computers.rst b/docs/source/get_started/computers.rst deleted file mode 100644 index 748adbcb72..0000000000 --- a/docs/source/get_started/computers.rst +++ /dev/null @@ -1,399 +0,0 @@ -.. _setup_computer: - -**************** -Setup a computer -**************** - -A computer in AiiDA denotes any computational resource (with a batch job scheduler) on which you will run your calculations. -Computers typically are clusters or supercomputers. - -Remote computer requirements -============================ - -Requirements for a computer are: - -* It must run a Unix-like operating system -* It must have ``bash`` installed -* It should have a batch scheduler installed (see :doc:`here <../scheduler/index>` - for a list of supported batch schedulers) -* It must be accessible from the machine that runs AiiDA using one of the - available transports (see below). - -.. note:: - AiiDA will use ``bash`` on the remote computer, regardless of the default shell. - Please ensure that your remote ``bash`` configuration does not load a different shell. - -The first step is to choose the transport to connect to the computer. Typically, -you will want to use the SSH transport, apart from a few special cases where -SSH connection is not possible (e.g., because you cannot setup a password-less -connection to the computer). In this case, you can install AiiDA directly on -the remote cluster, and use the ``local`` transport (in this way, commands to -submit the jobs are simply executed on the AiiDA machine, and files are simply -copied on the disk instead of opening an SFTP connection). - -If you plan to use the ``local`` transport, you can skip to the next section. - -If you plan to use the ``SSH`` transport, you have to configure a password-less -login from your user to the cluster. To do so type first (only if you do not -already have some keys in your local ``~/.ssh`` directory - i.e. files like -``id_rsa.pub``):: - - ssh-keygen -t rsa -b 4096 -m PEM - -.. note:: The ``-m PEM`` flag is necessary in newer versions of OpenSSL that switched to a different key format by default. - As of 2019-08, the paramiko library used by AiiDA `only supports the PEM format `_. - -Then copy your keys to the remote computer (in ~/.ssh/authorized_keys) with:: - - ssh-copy-id YOURUSERNAME@YOURCLUSTERADDRESS - -replacing ``YOURUSERNAME`` and ``YOURCLUSTERADDRESS`` by respectively your username -and cluster address. Finally add the following lines to ~/.ssh/config (leaving an empty -line before and after):: - - Host YOURCLUSTERADDRESS - User YOURUSERNAME - IdentityFile YOURRSAKEY - -replacing ``YOURRSAKEY`` by the path to the rsa private key you want to use -(it should look like ``~/.ssh/id_rsa``). - -.. note:: In principle you don't have to put the ``IdentityFile`` line if you have - only one rsa key in your ``~/.ssh`` folder. - -Before proceeding to setup the computer, be sure that you are able to -connect to your cluster using:: - - ssh YOURCLUSTERADDRESS - -without the need to type a password. Moreover, make also sure you can connect -via ``sftp`` (needed to copy files). The following command:: - - sftp YOURCLUSTERADDRESS - -should show you a prompt without errors (possibly with a message saying -``Connected to YOURCLUSTERADDRESS``). - -.. note:: If the ``ssh`` command works, but the ``sftp`` command does not - (e.g. it just prints ``Connection closed``), a possible reason can be - that there is a line in your ``~/.bashrc`` (on the cluster) that either produces text output - or an error. 
Remove/comment it until no output or error is produced: this - should make ``sftp`` work again. - -Finally, try also:: - - ssh YOURCLUSTERADDRESS QUEUE_VISUALIZATION_COMMAND - -replacing ``QUEUE_VISUALIZATION_COMMAND`` by the scheduler command that prints on screen the -status of the queue on the cluster (i.e. ``qstat`` for PBSpro scheduler, ``squeue`` for SLURM, etc.). -It should print a snapshot of the queue status, without any errors. - -.. note:: If there are errors with the previous command, then - edit your ~/.bashrc file in the remote computer and add a line at the beginning - that adds the path to the scheduler commands, typically (here for - PBSpro):: - - export PATH=$PATH:/opt/pbs/default/bin - - Or, alternatively, find the path to the executables (like using ``which qsub``). - -.. note:: If you need your remote .bashrc to be sourced before you execute the code - (for instance to change the PATH), make sure the .bashrc file **does not** contain - lines like:: - - [ -z "$PS1" ] && return - - or:: - - case $- in - *i*) ;; - *) return;; - esac - - in the beginning (these would prevent the bashrc to be executed when you ssh - to the remote computer). You can check that e.g. the PATH variable is correctly - set upon ssh, by typing (in your local computer):: - - ssh YOURCLUSTERADDRESS 'echo $PATH' - - -.. note:: If you need to ssh to a computer *A* first, from which you can then - connect to computer *B* you wanted to connect to, you can use the - ``proxy_command`` feature of ssh, that we also support in - AiiDA. For more information, see :ref:`ssh_proxycommand`. - - -.. _computer_setup: - -Computer setup and configuration -================================ -The configuration of computers happens in two steps. - -.. note:: The commands use some ``readline`` extensions to provide default - answers, that require an advanced terminal. Therefore, run the commands from - a standard terminal, and not from embedded terminals as the ones included in - text editors, unless you know what you are doing. For instance, the - terminal embedded in ``emacs`` is known to give problems. - -1. **Setup of the computer**, using the:: - - verdi computer setup - - command. This command allows to create a new computer instance in the DB. - - .. tip:: The code will ask you a few pieces of information. At every prompt, you can - type the ``?`` character and press ```` to get a more detailed - explanation of what is being asked. - - .. tip:: You can press ``+C`` at any moment to abort the setup process. - Nothing will be stored in the DB. - - Here is a list of what is asked, together with an explanation. - - * **Computer label**: the (user-friendly) label of the new computer instance - which is about to be created in the DB (the label is used for instance when - you have to pick a computer to launch a calculation on it). Labels must - be unique. This command should be thought as a AiiDA-wise configuration of - computer, independent of the AiiDA user that will actually use it. - - * **Fully-qualified hostname**: the fully-qualified hostname of the computer - to which you want to connect (i.e., with all the dots: ``bellatrix.epfl.ch``, - and not just ``bellatrix``). Type ``localhost`` for the local transport. 
- - * **Description**: A human-readable description of this computer; this is - useful if you have a lot of computers and you want to add some text to - distinguish them (e.g.: "cluster of computers at EPFL, installed in 2012, - 2 GB of RAM per CPU") - - * **Enabled**: either True or False; if False, the computer is disabled - and calculations associated with it will not be submitted. This allows to - disable temporarily a computer if it is giving problems or it is down for - maintenance, without the need to delete it from the DB. - - * **Transport plugin**: The type of the transport to be used. A list of valid - transport types can be obtained typing ``?`` - - * **Scheduler plugin**: The name of the plugin to be used to manage the - job scheduler on the computer. A list of valid - scheduler plugins can be obtained typing ``?``. See - :doc:`here <../scheduler/index>` for a documentation of scheduler plugins - in AiiDA. - - * **shebang line** This is the first line in the beginning of the submission script. - The default is ``#!/bin/bash``. You can change this in order, for example, to add options, - such as the ``-l`` flag. Note that AiiDA only supports bash at this point! - - * **Work directory on the computer**: The absolute path of the directory on the - remote computer where AiiDA will run the calculations - (often, it is the scratch of the computer). You can (should) use the - ``{username}`` replacement, that will be replaced by your username on the - remote computer automatically: this allows the same computer to be used - by different users, without the need to setup a different computer for - each one. Example:: - - /scratch/{username}/aiida_work/ - - * **Mpirun command**: The ``mpirun`` command needed on the cluster to run parallel MPI - programs. You can (should) use the ``{tot_num_mpiprocs}`` replacement, - that will be replaced by the total number of cpus, or the other - scheduler-dependent fields (see the :doc:`scheduler docs <../scheduler/index>` - for more information). Some examples:: - - mpirun -np {tot_num_mpiprocs} - aprun -n {tot_num_mpiprocs} - poe - - * **Default number of CPUs per machine**: The number of MPI processes per machine that - should be executed if it is not otherwise specified. Use ``0`` to specify no default value. - - At the end, the command will open your default editor on a file containing a summary - of the configuration up to this point, and the possibility to add ``bash`` - commands that will be executed either *before* the actual execution of the job - (under 'pre-execution script') or *after* the script submission (under 'Post execution script'). - These additional lines need may set up the environment on the computer, - for example loading modules or exporting environment variables, for example:: - - export NEWVAR=1 - source some/file - - .. note:: Don't specify settings here that are specific to a code, calculation or scheduler -- - you can set further pre-execution commands at the ``Code`` and ``CalcJob`` level. - - When you are done editing, save and quit (e.g. ``:wq`` in ``vim``). - The computer has now been created in the database but you still need to *configure* access to it - using your credentials. - - In order to avoid having to retype the setup information the next time round, it is also possible provide some (or all) of the information - described above via a configuration file using:: - - verdi computer setup --config computer.yml - - where ``computer.yml`` is a configuration file in the - `YAML format `_. 
- This file contains the information in a series of key:value pairs: - - .. code-block:: yaml - - --- - label: "localhost" - hostname: "localhost" - transport: local - scheduler: "direct" - work_dir: "/home/max/.aiida_run" - mpirun_command: "mpirun -np {tot_num_mpiprocs}" - mpiprocs_per_machine: "2" - prepend_text: | - module load mymodule - export NEWVAR=1 - - .. tip:: The list of the keys that can be used is available from the options flags of the command: :: - - verdi computer setup --help - - Note the syntax differences: remove the ``--`` prefix - and replace ``-`` within the keys by the underscore ``_``. - - - -2. **Configuration of the computer**, using the:: - - verdi computer configure TRANSPORTTYPE COMPUTERNAME - - command, with the appropriate transport type (``ssh`` or ``local``) and computer label. - - The configuration allows to access more detailed configurations, that are - often user-dependent and depend on the specific transport. - - The command will try to provide automatically default answers, - that can be selected by pressing enter. - - For ``local`` transport, the only information required is the minimum - time interval between conections to the computer. - - For ``ssh`` transport, the following will be asked: - - * **User name**: your username on the remote machine - * **port Nr**: the port to connect to (the default SSH port is 22) - * **Look_for_keys**: automatically look for the private key in ``~/.ssh``. - Default: False. - * **SSH key file**: the absolute path to your private SSH key. You can leave - it empty to use the default SSH key, if you set ``look_for_keys`` to True. - * **Connection timeout**: A timeout in seconds if there is no response (e.g., the - machine is down. You can leave it empty to use the default value.) - * **Allow_ssh agent**: If True, it will try to use an SSH agent. - * **SSH proxy_command**: Leave empty if you do not need a proxy command (i.e., - if you can directly connect to the machine). If you instead need to connect - to an intermediate computer first, you need to provide here the - command for the proxy: see documentation :ref:`here ` - for how to use this option, and in particular the notes - :ref:`here ` for the format of this field. - * **Compress file transfer**: True to compress the traffic (recommended) - * **GSS auth**: yes when using Kerberos token to connect - * **GSS kex**: yes when using Kerberos token to connect, in some cases - (depending on your ``.ssh/config`` file) - * **GSS deleg_creds**: yes when using Kerberos token to connect, in - some cases (depending on your ``.ssh/config`` file) - * **GSS host**: hostname when using Kerberos token to connect (defaults - to the remote computer hostname) - * **Load system host keys**: True to load the known hosts keys from the - default SSH location (recommended) - * **key policy**: What is the policy in case the host is not known. - It is a string among the following: - - * ``RejectPolicy`` (default, recommended): reject the connection if the - host is not known. - * ``WarningPolicy`` (*not* recommended): issue a warning if the - host is not known. - * ``AutoAddPolicy`` (*not* recommended): automatically add the host key - at the first connection to the host. - * **Connection cooldown time (s)**: The minimum time interval between consecutive - connection openings to the remote machine. - -After setup and configuration have been completed, your computer is ready to go! - -.. 
note:: If the cluster you are using requires authentication through a Kerberos - token (that you need to obtain before using ssh), you typically need to install - ``libffi`` (``sudo apt-get install libffi-dev`` under Ubuntu), and make sure you install - the ``ssh_kerberos`` :ref:`optional dependencies` during the installation process of AiiDA. - Then, if your ``.ssh/config`` file is configured properly (in particular includes - all the necessary ``GSSAPI`` options), ``verdi computer configure`` will - contain already the correct suggestions for all the gss options needed to support Kerberos. - -.. note:: To check if you set up the computer correctly, - execute:: - - verdi computer test COMPUTERNAME - - that will run a few tests (file copy, file retrieval, check of the jobs in - the scheduler queue) to verify that everything works as expected. - -.. note:: If you are not sure if your computer is already set up, use the command:: - - verdi computer list - - to get a list of existing computers, and:: - - verdi computer show COMPUTERNAME - - to get detailed information on the specific computer named ``COMPUTERNAME``. - You have also the:: - - verdi computer rename OLDCOMPUTERNAME NEWCOMPUTERNAME - - and:: - - verdi computer delete COMPUTERNAME - - commands, to rename a computer or remove it from the database. - -.. note:: You can delete computers **only if** no entry in the database is linked to - them (as for instance Calculations, or RemoteData objects). Otherwise, you - will get an error message. - -.. note:: It is possible to **disable** a computer. - - Doing so will prevent AiiDA - from connecting to the given computer to check the state of calculations or - to submit new calculations. This is particularly useful if, for instance, - the computer is under maintenance but you still want to use AiiDA with - other computers, or submit the calculations in the AiiDA database anyway. - - The relevant commands are:: - - verdi computer enable COMPUTERNAME - verdi computer disable COMPUTERNAME - - Note that the above commands will disable the computer for all AiiDA users. - - -On not bombarding the remote computer with requests ---------------------------------------------------- - -Some machine (particularly at supercomputing centres) may not tolerate opening -connections and executing scheduler commands with a high frequency. To limit this -AiiDA currently has two settings: - - * The transport safe open interval, and, - * the minimum job poll interval - -Neither of these can ever be violated. AiiDA will not try to update the jobs list -on a remote machine until the job poll interval has elapsed since the last update -(the first update will be immediate) at which point it will request a transport. -Because of this the maximum possible time before a job update could be the sum of -the two intervals, however this is unlikely to happen in practice. - -The transport open interval is currently hardcoded by the transport plugin; -typically for SSH it's longer than for local transport. - -The job poll interval can be set programmatically on the corresponding ``Computer`` -object in verdi shell:: - - load_computer('localhost').set_minimum_job_poll_interval(30.0) - - -would set the transport interval on a computer called 'localhost' to 30 seconds. - -.. note:: All of these intervals apply *per worker*, meaning that a daemon with - multiple workers will not necessarily, overall, respect these limits. 
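To illustrate the minimum job poll interval mentioned above, in a ``verdi shell`` (the value of 30 seconds is just an example):

.. code-block:: python

    computer = load_computer('localhost')
    computer.get_minimum_job_poll_interval()      # default: 10 seconds
    computer.set_minimum_job_poll_interval(30.0)  # query the scheduler at most every 30 s per worker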
- For the time being there is no way around this and if these limits must be - respected then do not run with more than one worker. diff --git a/docs/source/howto/codes.rst b/docs/source/howto/codes.rst deleted file mode 100644 index 713b1af9a0..0000000000 --- a/docs/source/howto/codes.rst +++ /dev/null @@ -1,355 +0,0 @@ -.. _how-to:codes: - -************************* -How to run external codes -************************* - -To run an external code with AiiDA, you will need to use an appropriate :ref:`calculation plugin `. -This plugin must contain the instructions necessary for the engine to be able to: - -1. Prepare the required input files inside of the folder in which the code will be executed -2. Run the code with the correct set of command line parameters - -The following subsections will not only take you through the process of :ref:`creating the calculation plugin` and then using these to actually :ref:`run the code`. -It will also show examples on how to implement tools that are commonly coupled with the running of a calculation, such as :ref:`the parsing of outputs`. - -.. todo:: - - Add to preceding sentence: :ref:`the communication with external machines` and the interaction with its :ref:`scheduling software`. - -Some general guidelines to keep in mind are: - - * | **Check existing resources.** - | Before starting to write a plugin, check on the `aiida plugin registry `_ whether a plugin for your code is already available. - If it is, there is maybe no need to write your own, and you can skip straight ahead to :ref:`running the code`. - * | **Start simple.** - | Make use of existing classes like :py:class:`~aiida.orm.nodes.data.dict.Dict`, :py:class:`~aiida.orm.nodes.data.singlefile.SinglefileData`, ... - Write only what is necessary to pass information from and to AiiDA. - * | **Don't break data provenance.** - | Store *at least* what is needed for full reproducibility. - * | **Expose the full functionality.** - | Standardization is good but don't artificially limit the power of a code you are wrapping - or your users will get frustrated. - If the code can do it, there should be *some* way to do it with your plugin. - * | **Don't rely on AiiDA internals.** - Functionality at deeper nesting levels is not considered part of the public API and may change between minor AiiDA releases, breaking your plugin. - * | **Parse what you want to query for.** - | Make a list of which information to: - - #. parse into the database for querying (:py:class:`~aiida.orm.nodes.data.dict.Dict`, ...) - #. store in the file repository for safe-keeping (:py:class:`~aiida.orm.nodes.data.singlefile.SinglefileData`, ...) - #. leave on the computer where the calculation ran (:py:class:`~aiida.orm.nodes.data.remote.RemoteData`, ...) - -To demonstrate how to create a plugin for an external code, we will use the trivial example of using the `bash` shell (``/bin/bash``) to sum two numbers by running the command: ``echo $(( numx + numy ))``. -Here, the `bash` binary will be effectively acting as our |Code| executable, the input (``aiida.in``) will then be a file containing the command with the numbers provided by the user replaced, and the output (``aiida.out``) will be caught through the standard output. -The final recipe to run this code will then be: - -.. code-block:: bash - - /bin/bash < aiida.in > aiida.out - -.. 
_how-to:codes:interfacing: - -Interfacing external codes -========================== - -To provide AiiDA with the set of instructions, required to run a code, one should subclass the |CalcJob| class and implement the following two key methods: - - #. :py:meth:`~aiida.engine.processes.calcjobs.calcjob.CalcJob.define` - #. :py:meth:`~aiida.engine.processes.calcjobs.calcjob.CalcJob.prepare_for_submission` - -We will now show how each of these can be implemented. - -Defining the specifications ---------------------------- - -The |define| method is where one specifies the different inputs that the caller of the |CalcJob| will have to provide in order to run the code, as well as the outputs that will be produced (exit codes will be :ref:`discussed later`). -This is done through an instance of :py:class:`~aiida.engine.processes.process_spec.CalcJobProcessSpec`, which, as can be seen in the snippet below, is passed as the |spec| argument to the |define| method. -For the code that adds up two numbers, we will need to define those numbers as inputs (let's call them ``x`` and ``y`` to label them) and the result as an output (``sum``). -The snippet below shows one potential implementation, as it is included in ``aiida-core``: - -.. literalinclude:: ../../../aiida/calculations/arithmetic/add.py - :language: python - :pyobject: ArithmeticAddCalculation.define - :dedent: 4 - -The first line of the |define| implementation calls the method of the parent class |CalcJob|. -This step is crucial as it will define `inputs` and `outputs` that are common to all |CalcJob|'s and failing to do so will leave the implementation broken. -After the super call, we modify the default values for some of these inputs that are defined by the base class. -Inputs that have already been defined can be accessed from the |spec| through the :py:attr:`~plumpy.process_spec.ProcessSpec.inputs` attribute, which behaves like a normal dictionary. - -After modifying the existing inputs, we define the inputs that are specific to this code. -For this purpose we use the :py:meth:`~plumpy.process_spec.ProcessSpec.input` method, which does not modify the existing `inputs`, accessed through :py:attr:`~plumpy.process_spec.ProcessSpec.inputs`, but defines new ones that will be specific to this implementation. -You can also see that the definitions do not involve the assignment of a value, but only the passing of parameters to the method: a label to identify it, their valid types (in this case nodes of type |Int|) and a description. -Finally, note that there is no return statement: this method does not need to return anything, since all modifications are made directly into the received |spec| object. -You can check the Topics section about :ref:`defining processes ` if you want more information about setting up your `inputs` and `outputs` (covering validation, dynamic number of inputs, etc.). - -Preparing for submission ------------------------- - -The :py:meth:`~aiida.engine.processes.calcjobs.calcjob.CalcJob.prepare_for_submission` method is used for two purposes. -Firstly, it should create the input files, based on the input nodes passed to the calculation, in the format that the external code will expect. -Secondly, the method should create and return a :py:class:`~aiida.common.datastructures.CalcInfo` instance that contains various instructions for the engine on how the code should be run. -An example implementation, as shipped with ``aiida-core`` can be seen in the following snippet: - -.. 
literalinclude:: ../../../aiida/calculations/arithmetic/add.py - :language: python - :pyobject: ArithmeticAddCalculation.prepare_for_submission - :dedent: 4 - -Note that, unlike the |define| method, this one is implemented from scratch and so there is no super call. -The external code that we are running with this |CalcJob| is ``bash`` and so to sum the input numbers ``x`` and ``y``, we should write a bash input file that performs the summation, for example ``echo $((x + y))``, where one of course has to replace ``x`` and ``y`` with the actual numbers. -You can see how the snippet uses the ``folder`` argument, which is a |Folder| instance that represents a temporary folder on disk, to write the input file with the bash summation. -It uses Python's string interpolation to replace the ``x`` and ``y`` placeholders with the actual values that were passed as input, ``self.inputs.x`` and ``self.inputs.y``, respectively. - -.. note:: - - When the |prepare_for_submission| is called, the inputs that have been passed will have been validated against the specification defined in the |define| method and they can be accessed through the :py:attr:`~plumpy.processes.Process.inputs` attribute. - This means that if a particular input is required according to the spec, you can safely assume that it will have been set and you do not need to check explicitly for its existence. - -All the files that are copied into the sandbox ``folder`` will be automatically copied by the engine to the scratch directory where the code will be run. -In this case we only create one input file, but you can create as many as you need, including subfolders if required. - -.. note:: - - The input files written to the ``folder`` sandbox, will also be permanently stored in the file repository of the calculation node for the purpose of additional provenance guarantees. - See the section on :ref:`excluding files from provenance` to learn how to prevent certain input files from being stored explicitly. - -After having written the necessary input files, one should create the |CodeInfo| object, which can be used to instruct the engine on how to run the code. -We assign the ``code_uuid`` attribute to the ``uuid`` of the ``Code`` node that was passed as an input, which can be retrieved through ``self.inputs.code``. -This is necessary such that the engine can retrieve the required information from the |Code| node, such as the full path of the executable. -Note that we didn't explicitly define this |Code| input in the |define| method, but this is one of the inputs defined in the base |CalcJob| class: - -.. code-block:: python - - spec.input('code', valid_type=orm.Code, help='The `Code` to use for this job.') - -After defining the UUID of the code node that the engine should use, we define the filenames where the stdin and stdout file descriptors should be redirected to. -These values are taken from the inputs, which are part of the ``metadata.options`` namespace, for some of whose inputs we overrode the default values in the specification definition in the previous section. -Note that instead of retrieving them through ``self.inputs.metadata['options']['input_filename']``, one can use the shortcut ``self.options.input_filename`` as we do here. -Based on this definition of the |CodeInfo|, the engine will create a run script that looks like the following: - -.. 
code-block:: bash - - #!/bin/bash - - '[executable path in code node]' < '[input_filename]' > '[output_filename]' - -The |CodeInfo| should be attached to the ``codes_info`` attribute of a |CalcInfo| object. -A calculation can potentially run more than one code, so the |CodeInfo| object should be assigned as a list. -Finally, we define the ``retrieve_list`` attribute, which is a list of filenames that the engine should retrieve from the running directory once the calculation job has finished. -The engine will store these files in a :py:class:`~aiida.orm.nodes.data.folder.FolderData` node that will be attached as an output node to the calculation with the label ``retrieved``. -There are :ref:`other file lists available` that allow you to easily customize how to move files to and from the remote working directory in order to prevent the creation of unnecessary copies. - -This was a minimal example of how to implement the |CalcJob| class to interface AiiDA with an external code. -For more detailed information and advanced functionality on the |CalcJob| class, refer to the Topics section on :ref:`defining calculations `. - -.. _how-to:codes:parsing: - -Parsing the outputs -=================== - -The parsing of the output files generated by a |CalcJob| is optional and can be used to store (part of) their information as AiiDA nodes, which makes the data queryable and therefore easier to access and analyze. -To enable |CalcJob| output file parsing, one should subclass the |Parser| class and implement the :py:meth:`~aiida.parsers.parser.Parser.parse` method. -The following is an example implementation, as shipped with ``aiida-core``, to parse the outputs of the :py:class:`~aiida.calculations.arithmetic.add.ArithmeticAddCalculation` discussed in the previous section: - -.. literalinclude:: ../../../aiida/parsers/plugins/arithmetic/add.py - :language: python - :pyobject: ArithmeticAddParser - -The output files generated by the completed calculation can be accessed from the ``retrieved`` output folder, which can be accessed through the :py:attr:`~aiida.parsers.parser.Parser.retrieved` property. -It is an instance of :py:class:`~aiida.orm.nodes.data.folder.FolderData` and so provides, among other things, the :py:meth:`~aiida.orm.nodes.node.Node.open` method to open any file it contains. -In this example implementation, we use it to open the output file, whose filename we get through the :py:meth:`~aiida.orm.nodes.process.calculation.calcjob.CalcJobNode.get_option` method of the corresponding calculation node, which we obtain through the :py:attr:`~aiida.parsers.parser.Parser.node` property of the ``Parser``. -We read the content of the file and cast it to an integer, which should contain the sum that was produced by the ``bash`` code. -We catch any exceptions that might be thrown, for example when the file cannot be read, or if its content cannot be interpreted as an integer, and return an exit code. -This method of dealing with potential errors of external codes is discussed in the section on :ref:`handling parsing errors`. - -To attach the parsed sum as an output, use the :py:meth:`~aiida.parsers.parser.Parser.out` method. -The first argument is the name of the output, which will be used as the label for the link that connects the calculation and data node, and the second is the node that should be recorded as an output. -Note that the type of the output should match the type that is specified by the process specification of the corresponding |CalcJob|. 
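Putting the parsing steps just described together, a minimal parser for the addition code could look roughly as follows. This is a sketch mirroring the shipped ``ArithmeticAddParser`` included above via ``literalinclude``; the class name is arbitrary and the exit codes are assumed to be the ones defined on the spec of the corresponding |CalcJob|:

.. code-block:: python

    from aiida.orm import Int
    from aiida.parsers.parser import Parser


    class SimpleAddParser(Parser):
        """Read the sum from the retrieved output file and attach it as the ``sum`` output."""

        def parse(self, **kwargs):
            output_filename = self.node.get_option('output_filename')

            try:
                with self.retrieved.open(output_filename, 'r') as handle:
                    result = int(handle.read())
            except OSError:
                return self.exit_codes.ERROR_READING_OUTPUT_FILE
            except ValueError:
                return self.exit_codes.ERROR_INVALID_OUTPUT

            self.out('sum', Int(result))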
-If any of the registered outputs do not match the specification, the calculation will be marked as failed. - -To trigger the parsing using a |Parser| after a |CalcJob| has finished (such as the one described in the :ref:`previous section `) it should be defined in the ``metadata.options.parser_name`` input. -If a particular parser should always be used by default for a given |CalcJob|, it can be defined as the default in the |define| method, for example: - -.. code-block:: python - - @classmethod - def define(cls, spec): - ... - spec.inputs['metadata']['options']['parser_name'].default = 'arithmetic.add' - -The default can be overridden through the inputs when launching the calculation job. -Note, that one should not pass the |Parser| class itself, but rather the corresponding entry point name under which it is registered as a plugin. -In other words, in order to use a |Parser| you will need to register it as explained in the how-to section on :ref:`registering plugins `. - - -.. _how-to:codes:parsing:errors: - -Handling parsing errors ------------------------ - -So far we have not spent too much attention on dealing with potential errors that might arise when running external codes. -However, for many codes, there are lots of ways in which it can fail to execute nominally and produced the correct output. -A |Parser| is the solution to detect these errors and report them to the caller through :ref:`exit codes`. -These exit codes can be defined through the |spec| of the |CalcJob| that is used for that code, just as the inputs and output are defined. -For example, the :py:class:`~aiida.calculations.arithmetic.add.ArithmeticAddCalculation` introduced in :ref:`"Interfacing external codes"`, defines the following exit codes: - -.. literalinclude:: ../../../aiida/calculations/arithmetic/add.py - :language: python - :start-after: start exit codes - :end-before: end exit codes - :dedent: 8 - -Each ``exit_code`` defines an exit status (a positive integer), a label that can be used to reference the code in the |parse| method (through the ``self.exit_codes`` property, as seen below), and a message that provides a more detailed description of the problem. -To use these in the |parse| method, you just need to return the corresponding exit code which instructs the engine to store it on the node of the calculation that is being parsed. -The snippet of the previous section on :ref:`parsing the outputs` already showed two problems that are detected and are communicated by returning the corresponding the exit code: - -.. literalinclude:: ../../../aiida/parsers/plugins/arithmetic/add.py - :language: python - :lines: 28-34 - :dedent: 8 - -If the ``read()`` call fails to read the output file, for example because the calculation failed to run entirely and did not write anything, it will raise an ``OSError``, which the parser catches and returns the ``ERROR_READING_OUTPUT_FILE`` exit code. -Alternatively, if the file *could* be read, but it's content cannot be interpreted as an integer, the parser returns ``ERROR_INVALID_OUTPUT``. -The Topics section on :ref:`defining processes ` provides additional information on how to use exit codes. - -.. todo:: - - .. _how-to:codes:computers: - - title: Configuring remote computers - - `#4123`_ - -.. 
_how-to:codes:run: - -Running external codes -====================== - -To run an external code with AiiDA, you will need to use an appropriate :ref:`calculation plugin ` that knows how to transform the input nodes into the input files that the code expects, copy everything in the code's machine, run the calculation and retrieve the results. -You can check the `plugin registry `_ to see if a plugin already exists for the code that you would like to run. -If that is not the case, you can :ref:`develop your own `. -After you have installed the plugin, you can start running the code through AiiDA. -To check which calculation plugins you have currently installed, run: - -.. code-block:: bash - - $ verdi plugin list aiida.calculations - -As an example, we will show how to use the ``arithmetic.add`` plugin, which is a pre-installed plugin that uses the `bash shell `_ to sum two integers. -You can access it with the ``CalculationFactory``: - -.. code-block:: python - - from aiida.plugins import CalculationFactory - calculation_class = CalculationFactory('arithmetic.add') - -Next, we provide the inputs for the code when running the calculation. -Use ``verdi plugin`` to determine what inputs a specific plugin expects: - -.. code-block:: bash - - $ verdi plugin list aiida.calculations arithmetic.add - ... - Inputs: - code: required Code The `Code` to use for this job. - x: required Int, Float The left operand. - y: required Int, Float The right operand. - ... - -You will see that 3 inputs nodes are required: two containing the values to add up (``x``, ``y``) and one containing information about the specific code to execute (``code``). -If you already have these nodes in your database, you can get them by :ref:`querying for them ` or using ``orm.load_node()``. -Otherwise, you will need to create them as shown below (note that you `will` need to already have the ``localhost`` computer configured, as explained in the :ref:`previous how-to`): - -.. code-block:: python - - from aiida import orm - bash_binary = orm.Code(remote_computer_exec=[localhost, '/bin/bash']) - number_x = orm.Int(17) - number_y = orm.Int(11) - -To provide these as inputs to the calculations, we will now use the ``builder`` object that we can get from the class: - -.. code-block:: python - - calculation_builder = calculation_class.get_builder() - calculation_builder.code = bash_binary - calculation_builder.x = number_x - calculation_builder.y = number_y - -Now everything is in place and ready to perform the calculation, which can be done in two different ways. -The first one is blocking and will return a dictionary containing all the output nodes (keyed after their label, so in this case these should be: "remote_folder", "retrieved" and "sum") that you can safely inspect and work with: - -.. code-block:: python - - from aiida.engine import run - output_dict = run(calculation_builder) - sum_result = output_dict['sum'] - -The second one is non blocking, as you will be submitting it to the daemon and control is immediately returned to the interpreter. -The return value in this case is the calculation node that is stored in the database. - -.. code-block:: python - - from aiida.engine import submit - calculation = submit(calculation_builder) - -Note that, although you have access to the node, the underlying calculation `process` is not guaranteed to have finished when you get back control in the interpreter. -You can use the verdi command line interface to :ref:`monitor` these processes: - -.. 
code-block:: bash - - $ verdi process list - -Performing a dry-run --------------------- - -Additionally, you might want to check and verify your inputs before actually running or submitting a calculation. -You can do so by specifying to use a ``dry_run``, which will create all the input files in a local directory (``submit_test/[date]-0000[x]``) so you can inspect them before actually launching the calculation: - -.. code-block:: python - - calculation_builder.metadata.dry_run = True - calculation_builder.metadata.store_provenance = False - run(calculation_builder) - -.. todo:: - - .. _how-to:codes:caching: - - title: Using caching to save computational resources - - `#3988`_ - - - .. _how-to:codes:scheduler: - - title: Adding support for a custom scheduler - - `#3989`_ - - - .. _how-to:codes:transport: - - title: Adding support for a custom transport - - `#3990`_ - - -.. |Int| replace:: :py:class:`~aiida.orm.nodes.data.int.Int` -.. |Code| replace:: :py:class:`~aiida.orm.nodes.data.Code` -.. |Parser| replace:: :py:class:`~aiida.parsers.parser.Parser` -.. |parse| replace:: :py:class:`~aiida.parsers.parser.Parser.parse` -.. |folder| replace:: :py:class:`~aiida.common.folders.Folder` -.. |folder.open| replace:: :py:class:`~aiida.common.folders.Folder.open` -.. |CalcJob| replace:: :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` -.. |CalcInfo| replace:: :py:class:`~aiida.common.CalcInfo` -.. |CodeInfo| replace:: :py:class:`~aiida.common.CodeInfo` -.. |spec| replace:: ``spec`` -.. |define| replace:: :py:class:`~aiida.engine.processes.calcjobs.CalcJob.define` -.. |prepare_for_submission| replace:: :py:class:`~aiida.engine.processes.calcjobs.CalcJob.prepare_for_submission` - -.. _#3988: https://github.com/aiidateam/aiida-core/issues/3988 -.. _#3989: https://github.com/aiidateam/aiida-core/issues/3989 -.. _#3990: https://github.com/aiidateam/aiida-core/issues/3990 -.. _#4123: https://github.com/aiidateam/aiida-core/issues/4123 diff --git a/docs/source/howto/data.rst b/docs/source/howto/data.rst index acf6eac3ca..a7e3ecbfaf 100644 --- a/docs/source/howto/data.rst +++ b/docs/source/howto/data.rst @@ -246,7 +246,7 @@ However, storing large amounts of data within the database comes at the cost of Therefore, big data (think large files), whose content does not necessarily need to be queried for, is better stored in the file repository. A data type may safely use both the database and file repository in parallel for individual properties. Properties stored in the database are stored as *attributes* of the node. -The node class has various methods to set these attributes, such as :py:meth:`~aiida.orm.nodes.node.Node.set_attribute` and :py:meth:`~aiida.orm.nodes.node.Node.set_attribute_many`. +The node class has various methods to set these attributes, such as :py:meth:`~aiida.orm.entities.EntityAttributesMixin.set_attribute` and :py:meth:`~aiida.orm.entities.EntityAttributesMixin.set_attribute_many`. .. _how-to:data:find: @@ -979,7 +979,7 @@ A subset of data in AiiDA is mutable also after storing a node, and is used as a This data can be safely deleted at any time. This includes, notably: -* *Node extras*: These can be deleted using :py:meth:`~aiida.orm.nodes.node.Node.delete_extra` and :py:meth:`~aiida.orm.nodes.node.Node.delete_extra_many`. +* *Node extras*: These can be deleted using :py:meth:`~aiida.orm.entities.EntityExtrasMixin.delete_extra` and :py:meth:`~aiida.orm.entities.EntityExtrasMixin.delete_extra_many` methods. 
* *Node comments*: These can be removed using :py:meth:`~aiida.orm.nodes.node.Node.remove_comment`. * *Groups*: These can be deleted using :py:meth:`Group.objects.delete() `. This command will only delete the group, not the nodes contained in the group. diff --git a/docs/source/working_with_aiida/include/images/caching.png b/docs/source/howto/include/images/caching.png similarity index 100% rename from docs/source/working_with_aiida/include/images/caching.png rename to docs/source/howto/include/images/caching.png diff --git a/docs/source/working_with_aiida/include/images/caching.svg b/docs/source/howto/include/images/caching.svg similarity index 100% rename from docs/source/working_with_aiida/include/images/caching.svg rename to docs/source/howto/include/images/caching.svg diff --git a/docs/source/howto/index.rst b/docs/source/howto/index.rst index 9cca705cd1..98b861a52e 100644 --- a/docs/source/howto/index.rst +++ b/docs/source/howto/index.rst @@ -5,7 +5,9 @@ How-To Guides .. toctree:: :maxdepth: 1 - codes + run_codes + ssh + plugin_codes workflows data visualising_graphs/visualising_graphs diff --git a/docs/source/howto/installation.rst b/docs/source/howto/installation.rst index f50cdf98b5..27cff5ed7f 100644 --- a/docs/source/howto/installation.rst +++ b/docs/source/howto/installation.rst @@ -201,8 +201,8 @@ For example, the directory structure in your home folder ``~/`` might look like . ├── .aiida └── project_a -    ├── .aiida -    └── subfolder + ├── .aiida + └── subfolder If you leave the ``AIIDA_PATH`` variable unset, the default location ``~/.aiida`` will be used. However, if you set: @@ -343,7 +343,7 @@ Updating from 0.x.* to 1.* -------------------------- - `Additional instructions on how to migrate from 0.12.x versions `_. - `Additional instructions on how to migrate from versions 0.4 -- 0.11 `_. -- For a list of breaking changes between the 0.x and the 1.x series of AiiDA, check `this page `_. +- For a list of breaking changes between the 0.x and the 1.x series of AiiDA, `see here `_. .. _how-to:installation:backup: @@ -474,144 +474,6 @@ In order to restore a backup, you will need to: After supplying your database password, the database should be restored. Note that, if you installed the database on Ubuntu as a system service, you need to type ``sudo su - postgres`` to become the ``postgres`` UNIX user. -.. _how-to:installation:running-on-supercomputers: - -Running on supercomputers -========================= - -.. _how-to:installation:running-on-supercomputers:ssh-agent: - -Using passphrase-protected SSH keys via a ssh-agent ---------------------------------------------------- - -In order to connect to a remote computer using the ``SSH`` transport, AiiDA needs a password-less login: for this reason, it is necessary to configure an authentication key pair. - -Using a passphrase to encrypt the private key is not mandatory, however it is highly recommended. -In some cases it is indispensable because it is requested by the computer center managing the remote cluster. -To this purpose, the use of a tool like ``ssh-agent`` becomes essential, so that the private-key passphrase only needs to be supplied once (note that the key needs to be provided again after a reboot of your AiiDA machine). - -Starting the ssh-agent -^^^^^^^^^^^^^^^^^^^^^^ - -In the majority of modern Linux systems for desktops/laptops, the ``ssh-agent`` automatically starts during login. -In some cases (e.g. virtual machines, or old distributions) it is needed to start it manually instead. 
-If you are unsure, just run the command ``ssh-add``: if it displays the error ``Could not open a connection to your authentication agent``, then you need to start the agent manually as described below. - -.. dropdown:: Start the ``ssh-agent`` manually (and reuse it across shells) - - If you have no ``ssh-agent`` running, you can start a new one with the command: - - .. code:: bash - - eval `ssh-agent` - - However, this command will start a new agent that will be visible **only in your current shell**. - - In order to use the same agent instance in every future opened shell, and most importantly to make this accessible to the AiiDA daemon, you need to make sure that the environment variables of ``ssh-agent`` are reused by *all* shells. - - To make the ssh-agent persistent, downlod the script :download:`load-singlesshagent.sh ` and put it in a directory dedicated to the storage of your scripts (in our example will be ``~/bin``). - - .. note:: - - You need to use this script only if a "global" ssh-agent is not available by default on your computer. - A global agent is available, for instance, on recent versions of Mac OS X and of Ubuntu Linux. - - Then edit the file ``~/.bashrc`` and add the following lines: - - .. code:: bash - - if [ -f ~/bin/load-singlesshagent.sh ]; then - . ~/bin/load-singlesshagent.sh - fi - - To check that it works, perform the following steps: - - * Open a new shell, so that the ``~/.bashrc`` file is sourced. - * Run the command ``ssh-add`` as described in the following section. - * Logout from the current shell. - * Open a new shell. - * Check that you are able to connect to the remote computer without typing the passphrase. - -Adding the passphrase of your key(s) to the agent -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To provide the passphrase of your private key to the the agent use the command: - -.. code:: bash - - ssh-add - -If you changed the default position or the default name of the private key, or you want to provide the passphrase only for a specific key, you need specify the path to the SSH key file as a parameter to ``ssh-add``. - -The private key and the relative passphrase are now recorded in an instance of the agent. - -.. note:: - - The passphase is stored in the agent only until the next reboot. - If you shut down or restart the AiiDA machine, before starting the AiiDA deamon remember to run the ``ssh-add`` command again. - -Configure AiiDA -^^^^^^^^^^^^^^^ - -In order to use the agent in AiiDA, you need to first make sure that you can connect to the computer via SSH without explicitly specifying a passphrase. -Make sure that this is the case also in newly opened bash shells. - -Then, when configuring the corresponding AiiDA computer (via ``verdi computer configure``), make sure to specify ``true`` to the question ``Allow ssh agent``. -If you already configured the computer and just want to adapt the computer configuration, just rerun - -.. code:: bash - - verdi computer configure ssh COMPUTERNAME - -After the configuration, you should verify that AiiDA can connect to the computer with: - -.. code:: bash - - verdi computer test COMPUTERNAME - - -.. 
_how-to:installation:running-on-supercomputers:avoiding-overloads: - -Avoiding overloads ------------------- - -If you submit to a supercomputer shared by many users (e.g., in a supercomputer center), be careful not to overload the supercomputer with too many jobs: - - * limit the number of jobs in the queue (the exact number depends on the supercomputer: discuss this with your supercomputer administrators, and you can redirect them to :ref:`this page` that may contain useful information for them). - While in the future `this might be dealt with by AiiDA automatically `_, - you are responsible for this at the moment. - This can be achieved for instance by submitting only a maximum number of workflows to AiiDA, and submitting new ones only when the previous ones complete. - - * Tune the parameters that AiiDA uses to avoid overloading the supercomputer with connections or batch requests. - For SSH transports, the default is 30 seconds, which means that when each worker opens a SSH connection to a computer, it will reuse it as long as there are tasks to execute and then close it. - Opening a new connection will not happen before 30 seconds has passed from the opening of the previous one. - - We stress that this is *per daemon worker*, so that if you have 10 workers, your supercomputer will on average see 10 connections every 30 seconds. - Therefore, if you are using many workers and you mostly have long-running jobs, you can set a longer time (e.g., 120 seconds) by reconfiguring the computer with ``verdi computer configure ssh `` and changing the value - of the *Connection cooldown time* or, alternatively, by running: - - .. code-block:: bash - - verdi computer configure ssh --non-interactive --safe-interval - - * In addition to the connection cooldown time described above, AiiDA also limits the frequency for retrieving the job queue from the scheduler (``squeue``, ``qstat``, ...), as this can also impact the performance of the scheduler. - For a given computer, you can increase how many seconds must pass between requests. - First load the computer in a shell with ``computer = load_computer()``. - You can check the current value in seconds (by default, 10) with ``computer.get_minimum_job_poll_interval()``. - You can then set it to a higher value using: - - .. code-block:: python - - computer.set_minimum_job_poll_interval() - -.. _how-to:installation:running-on-supercomputers:for_cluster_admins: - -Optimising the SLURM scheduler configuration --------------------------------------------- - -If too many jobs are submitted at the same time to the queue, SLURM might have trouble in dealing with new submissions. -If you are a cluster administrator, you might be interested in `some tips available in the AiiDA wiki `_, suggested by sysadmins at the Swiss Supercomputer Centre `CSCS `_ (or you can redirect your admin to this page if your cluster is experiencing slowness related to a large number of submitted jobs). - .. _how-to:installation:multi-user: Managing multiple users @@ -625,3 +487,4 @@ Data can be shared between instances using :ref:`AiiDA's export and import funct Sharing (subsets of) the AiiDA graph can be done as often as needed. .. _#4122: https://github.com/aiidateam/aiida-core/issues/4122 +.. |Computer| replace:: :py:class:`~aiida.orm.Computer` diff --git a/docs/source/howto/plugin_codes.rst b/docs/source/howto/plugin_codes.rst new file mode 100644 index 0000000000..98a2293984 --- /dev/null +++ b/docs/source/howto/plugin_codes.rst @@ -0,0 +1,466 @@ +.. 
_how-to:plugin-codes: + +****************************************** +How to write a plugin for an external code +****************************************** + +.. tip:: + + Before starting to write a new plugin, check the `aiida plugin registry `_. + If a plugin for your code is already available, you can skip straight to :ref:`how-to:run-codes`. + +To run an external code with AiiDA, you need a corresponding *calculation* plugin, which tells AiiDA how to: + +1. Prepare the required input files. +2. Run the code with the correct command line parameters. + +Finally, you will probably want a *parser* plugin, which tells AiiDA how to: + +3. Parse the output of the code. + +This how-to takes you through the process of :ref:`creating a calculation plugin` for a simple executable that sums two numbers, using it to :ref:`run the code`, and :ref:`writing a parser ` for its outputs. + +In the following, as an example, our |Code| will be the `bash` executable, and our "input file" will be a `bash` script ``aiida.in`` that sums two numbers and prints the result: + +.. code-block:: bash + + echo $(( numx + numy )) + +We will run this as: + +.. code-block:: bash + + /bin/bash < aiida.in > aiida.out + +thus writing the sum of the two numbers ``numx`` and ``numy`` (provided by the user) to the output file ``aiida.out``. + + +.. todo:: + + Add to preceding sentence: :ref:`the communication with external machines` and the interaction with its :ref:`scheduling software`. + +.. _how-to:plugin-codes:interfacing: + + +Interfacing external codes +========================== + +Start by creating a file ``calculations.py`` and subclass the |CalcJob| class: + +.. code-block:: python + + from aiida import orm + from aiida.common.datastructures import CalcInfo, CodeInfo + from aiida.common.folders import Folder + from aiida.engine import CalcJob, CalcJobProcessSpec + + + class ArithmeticAddCalculation(CalcJob): + """`CalcJob` implementation to add two numbers using bash for testing and demonstration purposes.""" + + +In the following, we will tell AiiDA how to run our code by implementing two key methods: + + #. :py:meth:`~aiida.engine.processes.calcjobs.calcjob.CalcJob.define` + #. :py:meth:`~aiida.engine.processes.calcjobs.calcjob.CalcJob.prepare_for_submission` + +Defining the spec +----------------- + +The |define| method tells AiiDA which inputs the |CalcJob| expects and which outputs it produces (exit codes will be :ref:`discussed later`). +This is done through an instance of the :py:class:`~aiida.engine.processes.process_spec.CalcJobProcessSpec` class, which is passed as the |spec| argument to the |define| method. +For example: + +.. literalinclude:: ../../../aiida/calculations/arithmetic/add.py + :language: python + :pyobject: ArithmeticAddCalculation.define + +The first line of the method calls the |define| method of the |CalcJob| parent class. +This necessary step defines the `inputs` and `outputs` that are common to all |CalcJob|'s. + +Next, we use the :py:meth:`~plumpy.process_spec.ProcessSpec.input` method in order to define our two input numbers ``x`` and ``y`` (we support integers and floating point numbers), and we use :py:meth:`~plumpy.process_spec.ProcessSpec.output` to define the only output of the calculation with the label ``sum``. +AiiDA will attach the outputs defined here to a (successfully) finished calculation using the link label provided. + +.. note:: + This holds for *required* outputs (the default behaviour). + Use ``required=False`` in order to mark an output as optional. + +.. 
tip:: + + For the input parameters and input files of more complex simulation codes, consider using :py:class:`~aiida.orm.nodes.data.dict.Dict` (python dictionary) and :py:class:`~aiida.orm.nodes.data.singlefile.SinglefileData` (file wrapper) input nodes. + +Finally, we set a couple of default ``options``, such as the name of the parser (which we will implement later), the name of input and output files, and the computational resources to use for such a calculation. +These ``options`` have already been defined on the |spec| by the ``super().define(spec)`` call, and they can be accessed through the :py:attr:`~plumpy.process_spec.ProcessSpec.inputs` attribute, which behaves like a dictionary. + +.. note:: + + One more important input required by any |CalcJob| is which external executable to use. + External executables are represented by |Code| instances that contain information about the computer they reside on, their path in the file system and more. + + They are passed to a |CalcJob| via the ``code`` input, which is defined in the |CalcJob| base class, so you don't have to: + + .. code-block:: python + + spec.input('code', valid_type=orm.Code, help='The `Code` to use for this job.') + + + +There is no ``return`` statement in ``define``: the ``define`` method directly modifies the |spec| object it receives. +For more details on setting up your `inputs` and `outputs` (covering validation, dynamic number of inputs, etc.) see the :ref:`Defining Processes ` topic. + +Preparing for submission +------------------------ + +The :py:meth:`~aiida.engine.processes.calcjobs.calcjob.CalcJob.prepare_for_submission` method has two jobs: +Creating the input files in the format the external code expects and returning a :py:class:`~aiida.common.datastructures.CalcInfo` object that contains instructions for the AiiDA engine on how the code should be run. +For example: + +.. literalinclude:: ../../../aiida/calculations/arithmetic/add.py + :language: python + :pyobject: ArithmeticAddCalculation.prepare_for_submission + +.. note:: Unlike the |define| method, the ``prepare_for_submission`` method is implemented from scratch and so there is no super call. + +The first step is writing the simple bash script mentioned in the beginning: summing the numbers ``x`` and ``y``, using Python's string interpolation to replace the ``x`` and ``y`` placeholders with the actual values ``self.inputs.x`` and ``self.inputs.y`` that were provided as inputs by the caller. + +All inputs provided to the calculation are validated against the ``spec`` *before* |prepare_for_submission| is called. +Therefore, when accessing the :py:attr:`~plumpy.processes.Process.inputs` attribute, you can safely assume that all required inputs have been set and that all inputs have a valid type. + +The ``folder`` argument (a |Folder| instance) allows us to write the input file to a sandbox folder, whose contents will be transferred to the compute resource where the actual calculation takes place. +In this example, we only create a single input file, but you can create as many as you need, including subfolders if required. + +.. note:: + + By default, the contents of the sandbox ``folder`` are also stored permanently in the file repository of the calculation node for additional provenance guarantees. + There are cases (e.g. license issues, file size) where you may want to change this behavior and :ref:`exclude files from being stored`. + +After having written the necessary input files, we let AiiDA know how to run the code via the |CodeInfo| object. 
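+
+Condensed to its essentials, the method shown above boils down to something like the following sketch (not the verbatim implementation), assuming the ``x``/``y`` inputs and the ``input_filename``/``output_filename`` options defined in the spec:
+
+.. code-block:: python
+
+    def prepare_for_submission(self, folder):
+        """Write the input file to the sandbox ``folder`` and return the instructions for the engine."""
+        # Write the bash script that sums the two input numbers.
+        with folder.open(self.options.input_filename, 'w', encoding='utf8') as handle:
+            handle.write('echo $(({x} + {y}))\n'.format(x=self.inputs.x.value, y=self.inputs.y.value))
+
+        # Instructions for running the executable: feed the input file via stdin, capture stdout in the output file.
+        codeinfo = CodeInfo()
+        codeinfo.code_uuid = self.inputs.code.uuid
+        codeinfo.stdin_name = self.options.input_filename
+        codeinfo.stdout_name = self.options.output_filename
+
+        # Instructions for the engine: which codes to run and which files to retrieve once the job has finished.
+        calcinfo = CalcInfo()
+        calcinfo.codes_info = [codeinfo]
+        calcinfo.retrieve_list = [self.options.output_filename]
+
+        return calcinfo
+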
+ +First, we forward the ``uuid`` of the |Code| instance passed by the user via the generic ``code`` input mentioned previously (in this example, the ``code`` will represent a ``bash`` executable). + +Second, let's recall how we want our executable to be run: + +.. code-block:: bash + + #!/bin/bash + + '[executable path in code node]' < '[input_filename]' > '[output_filename]' + +We want to pass our input file to the executable via standard input, and record standard output of the executable in the output file -- this is done using the ``stdin_name`` and ``stdout_name`` attributes of the |CodeInfo|. + +.. tip:: + + Many executables don't read from standard input but instead require the path to an input file to be passed via command line parameters (potentially including further configuration options). + In that case, use the |CodeInfo| ``cmdline_params`` attribute: + + .. code-block:: python + + codeinfo.cmdline_params = ['--input', self.inputs.input_filename] + +.. tip:: + + ``self.options.input_filename`` is just a shorthand for ``self.inputs.metadata['options']['input_filename']``. + +Finally, we pass the |CodeInfo| to a |CalcInfo| object (one calculation job can involve more than one executable, so ``codes_info`` is a list). +We define the ``retrieve_list`` of filenames that the engine should retrieve from the directory where the job ran after it has finished. +The engine will store these files in a |FolderData| node that will be attached as an output node to the calculation with the label ``retrieved``. +There are :ref:`other file lists available` that allow you to easily customize how to move files to and from the remote working directory in order to prevent the creation of unnecessary copies. + +This was an example of how to implement the |CalcJob| class to interface AiiDA with an external code. +For more details on the |CalcJob| class, refer to the Topics section on :ref:`defining calculations `. + +.. _how-to:plugin-codes:parsing: + +Parsing the outputs +=================== + +Parsing the output files produced by a code into AiiDA nodes is optional, but it can make your data queryable and therefore easier to access and analyze. + +To create a parser plugin, subclass the |Parser| class (for example in a file called ``parsers.py``) and implement its :py:meth:`~aiida.parsers.parser.Parser.parse` method. +The following is an example of a simple implementation: + +.. literalinclude:: ../../../aiida/parsers/plugins/arithmetic/add.py + :language: python + :pyobject: SimpleArithmeticAddParser + +Before the ``parse()`` method is called, two important attributes are set on the |Parser| instance: + + 1. ``self.retrieved``: An instance of |FolderData|, which points to the folder containing all output files that the |CalcJob| instructed to retrieve, and provides the means to :py:meth:`~aiida.orm.nodes.node.Node.open` any file it contains. + + 2. ``self.node``: The :py:class:`~aiida.orm.nodes.process.calculation.calcjob.CalcJobNode` representing the finished calculation, which, among other things, provides access to all of its inputs (``self.node.inputs``). + +The :py:meth:`~aiida.orm.nodes.process.calculation.calcjob.CalcJobNode.get_option` convenience method is used to get the filename of the output file. +Its content is cast to an integer, since the output file should contain the sum produced by the ``aiida.in`` bash script. + +Finally, the :py:meth:`~aiida.parsers.parser.Parser.out` method is used to link the parsed sum as an output of the calculation. 
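+
+In outline, such a minimal parser (in the spirit of the ``SimpleArithmeticAddParser`` included above) might look like the following sketch, assuming the ``sum`` output and the ``output_filename`` option defined by the calculation:
+
+.. code-block:: python
+
+    from aiida.orm import Int
+    from aiida.parsers.parser import Parser
+
+
+    class SimpleArithmeticAddParser(Parser):
+        """Minimal parser that reads the sum produced by the ``aiida.in`` script."""
+
+        def parse(self, **kwargs):
+            """Read the output file, cast its content to an integer and register it as the ``sum`` output."""
+            output_filename = self.node.get_option('output_filename')
+
+            # Read the content of the retrieved output file.
+            with self.retrieved.open(output_filename, 'r') as handle:
+                result = int(handle.read())
+
+            # Attach the parsed value as an output of the calculation.
+            self.out('sum', Int(result))
+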
+The first argument is the name of the output, which will be used as the label for the link that connects the calculation and data node, and the second is the node that should be recorded as an output. +Note that the type of the output should match the type that is specified by the process specification of the corresponding |CalcJob|. +If any of the registered outputs do not match the specification, the calculation will be marked as failed. + +In order to request automatic parsing of a |CalcJob| (once it has finished), users can set the ``metadata.options.parser_name`` input when launching the job. +If a particular parser should be used by default, the |CalcJob| ``define`` method can set a default value for the parser name as was done in the :ref:`previous section `: + +.. code-block:: python + + @classmethod + def define(cls, spec): + ... + spec.inputs['metadata']['options']['parser_name'].default = 'arithmetic.add' + +Note, that the default is not set to the |Parser| class itself, but the *entry point string* under which the parser class is registered. +How to register a parser class through an entry point is explained in the how-to section on :ref:`registering plugins `. + + +.. _how-to:plugin-codes:parsing:errors: + +Handling parsing errors +----------------------- + +So far, we have not spent much attention on dealing with potential errors that can arise when running external codes. +However, there are lots of ways in which codes can fail to execute nominally. +A |Parser| can play an important role in detecting and communicating such errors, where :ref:`workflows ` can then decide how to proceed, e.g., by modifying input parameters and resubmitting the calculation. + +Parsers communicate errors through :ref:`exit codes`, which are defined in the |spec| of the |CalcJob| they parse. +The :py:class:`~aiida.calculations.arithmetic.add.ArithmeticAddCalculation` example, defines the following exit codes: + +.. literalinclude:: ../../../aiida/calculations/arithmetic/add.py + :language: python + :start-after: start exit codes + :end-before: end exit codes + :dedent: 8 + +Each ``exit_code`` defines: + + * an exit status (a positive integer), + * a label that can be used to reference the code in the |parse| method (through the ``self.exit_codes`` property, as shown below), and + * a message that provides a more detailed description of the problem. + +In order to inform AiiDA about a failed calculation, simply return from the ``parse`` method the exit code that corresponds to the detected issue. +Here is a more complete version of the example |Parser| presented in the previous section: + +.. literalinclude:: ../../../aiida/parsers/plugins/arithmetic/add.py + :language: python + :pyobject: ArithmeticAddParser + +It checks: + + 1. Whether a retrieved folder is present. + 2. Whether the output file can be read (whether ``open()`` or ``read()`` will throw an ``OSError``). + 3. Whether the output file contains an integer. + 4. Whether the sum is negative. + +AiiDA stores the exit code returned by the |parse| method on the calculation node that is being parsed, from where it can then be inspected further down the line. +The Topics section on :ref:`defining processes ` provides more details on exit codes. +Note that scheduler plugins can also implement parsing of the output generated by the job scheduler and in the case of problems can set an exit code. 
+The Topics section on :ref:`scheduler exit codes ` explains how they can be inspected inside an output parser and how they can optionally be overridden. + + +.. todo:: + + .. _how-to:plugin-codes:computers: + + title: Configuring remote computers + + `#4123`_ + +.. _how-to:plugin-codes:entry-points: + +Registering entry points +======================== + +:ref:`Entry points ` are the preferred method of registering new calculation, parser and other plugins with AiiDA. + +With your ``calculations.py`` and ``parsers.py`` files at hand, let's register entry points for the plugins they contain: + + * Move your two scripts into a subfolder ``aiida_add``: + + .. code-block:: console + + mkdir aiida_add + mv calculations.py parsers.py aiida_add/ + + You have just created an ``aiida_add`` Python *package*! + + * Write a minimalistic ``setup.py`` script for your new package: + + .. code-block:: python + + from setuptools import setup + + setup( + name='aiida-add', + packages=['aiida_add'], + entry_points={ + 'aiida.calculations': ["add = aiida_add.calculations:ArithmeticAddCalculation"], + 'aiida.parsers': ["add = aiida_add.parsers:ArithmeticAddParser"], + } + ) + + .. note:: + Strictly speaking, ``aiida-add`` is the name of the *distribution*, while ``aiida_add`` is the name of the *package*. + The aiida-core documentation uses the term *package* a bit more loosely. + + + * Install your new ``aiida-add`` plugin package: + + .. code-block:: console + + pip install -e . + reentry scan + + +After this, you should see your plugins listed: + + .. code-block:: console + + $ verdi plugin list aiida.calculations + $ verdi plugin list aiida.calculations add + $ verdi plugin list aiida.parsers + + + +.. _how-to:plugin-codes:run: + +Running a calculation +===================== + +With the entry points set up, you are ready to launch your first calculation with the new plugin: + + + * If you haven't already done so, :ref:`set up your computer`. + In the following we assume it to be the localhost: + + .. code-block:: console + + $ verdi computer setup -L localhost -H localhost -T local -S direct -w `echo $PWD/work` -n + $ verdi computer configure local localhost --safe-interval 5 -n + + * Write a ``launch.py`` script: + + .. code-block:: python + + from aiida import orm, engine + from aiida.common.exceptions import NotExistent + + # Setting up inputs + computer = orm.load_computer('localhost') + try: + code = load_code('add@localhost') + except NotExistent: + # Setting up code via python API (or use "verdi code setup") + code = orm.Code(label='add', remote_computer_exec=[computer, '/bin/bash'], input_plugin_name='add') + + builder = code.get_builder() + builder.x = Int(4) + builder.y = Int(5) + builder.metadata.options.withmpi = False + builder.metadata.options.resources = { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1, + } + + # Running the calculation & parsing results + output_dict, node = engine.run_get_node(builder) + print("Parsing completed. Result: {}".format(output_dict['sum'].value)) + + .. note:: + + ``output_dict`` is a dictionary containing all the output nodes keyed after their label. + In this case: "remote_folder", "retrieved" and "sum". + + + * Launch the calculation: + + .. code-block:: console + + $ verdi run launch.py + + + If everything goes well, this should print the results of your calculation, something like: + + .. code-block:: console + + $ verdi run launch.py + Parsing completed. Result: 9 + +.. 
tip:: + + If you encountered a parsing error, it can be helpful to make a :ref:`topics:calculations:usage:calcjobs:dry_run`, which allows you to inspect the input folder generated by AiiDA before any calculation is launched. + + +Finally instead of running your calculation in the current shell, you can submit your calculation to the AiiDA daemon: + + * (Re)start the daemon to update its Python environment: + + .. code-block:: console + + $ verdi daemon restart --reset + + * Update your launch script to use: + + .. code-block:: python + + # Submitting the calculation + node = engine.submit(builder) + print("Submitted calculation {}".format(node)) + + .. note:: + + ``node`` is the |CalcJobNode| representing the state of the underlying calculation process (which may not be finished yet). + + + * Launch the calculation: + + .. code-block:: console + + $ verdi run launch.py + + This should print the UUID and the PK of the submitted calculation. + +You can use the verdi command line interface to :ref:`monitor` this processes: + +.. code-block:: bash + + $ verdi process list + + +This marks the end of this how-to. + +The |CalcJob| and |Parser| plugins are still rather basic and the ``aiida-add`` plugin package is missing a number of useful features, such as package metadata, documentation, tests, CI, etc. +Continue with :ref:`how-to:plugins` in order to learn how to quickly create a feature-rich new plugin package from scratch. + + +.. todo:: + + .. _how-to:plugin-codes:scheduler: + + title: Adding support for a custom scheduler + + `#3989`_ + + + .. _how-to:plugin-codes:transport: + + title: Adding support for a custom transport + + `#3990`_ + + +.. |Int| replace:: :py:class:`~aiida.orm.nodes.data.int.Int` +.. |Code| replace:: :py:class:`~aiida.orm.nodes.data.Code` +.. |Parser| replace:: :py:class:`~aiida.parsers.parser.Parser` +.. |parse| replace:: :py:class:`~aiida.parsers.parser.Parser.parse` +.. |folder| replace:: :py:class:`~aiida.common.folders.Folder` +.. |folder.open| replace:: :py:class:`~aiida.common.folders.Folder.open` +.. |CalcJob| replace:: :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` +.. |CalcJobNode| replace:: :py:class:`~aiida.orm.CalcJobNode` +.. |CalcInfo| replace:: :py:class:`~aiida.common.CalcInfo` +.. |CodeInfo| replace:: :py:class:`~aiida.common.CodeInfo` +.. |FolderData| replace:: :py:class:`~aiida.orm.nodes.data.folder.FolderData` +.. |spec| replace:: ``spec`` +.. |define| replace:: :py:class:`~aiida.engine.processes.calcjobs.CalcJob.define` +.. |prepare_for_submission| replace:: :py:class:`~aiida.engine.processes.calcjobs.CalcJob.prepare_for_submission` + +.. _#3989: https://github.com/aiidateam/aiida-core/issues/3989 +.. _#3990: https://github.com/aiidateam/aiida-core/issues/3990 +.. _#4123: https://github.com/aiidateam/aiida-core/issues/4123 diff --git a/docs/source/howto/plugins.rst b/docs/source/howto/plugins.rst index 7261e0a768..4dfab8d03d 100644 --- a/docs/source/howto/plugins.rst +++ b/docs/source/howto/plugins.rst @@ -5,11 +5,11 @@ How to package plugins ********************** This section focuses on how to *package* AiiDA extensions (plugins) so that they can be tested, published and eventually reused by others. -For guides on writing specific extensions, see :ref:`how-to:codes:interfacing` and :ref:`how-to:data:plugin`. +For guides on writing specific extensions, see :ref:`how-to:plugin-codes:interfacing` and :ref:`how-to:data:plugin`. .. 
todo:: - For guides on writing specific extensions, see :ref:`how-to:codes:interfacing`, -ref-'how-to:codes:scheduler', -ref-'how-to:codes:transport' or :ref:`how-to:data:plugin`. + For guides on writing specific extensions, see :ref:`how-to:plugin-codes:interfacing`, :ref:'how-to:plugin-codes:scheduler', :ref:'how-to:plugin-codes:transport' or :ref:`how-to:data:plugin`. .. _how-to:plugins:bundle: diff --git a/docs/source/howto/run_codes.rst b/docs/source/howto/run_codes.rst new file mode 100644 index 0000000000..984ccfd3eb --- /dev/null +++ b/docs/source/howto/run_codes.rst @@ -0,0 +1,535 @@ +.. _how-to:run-codes: + +************************* +How to run external codes +************************* + +This how-to walks you through the steps of setting up a (possibly remote) compute resource, setting up a code on that computer and submitting a calculation through AiiDA (similar to the :ref:`introductory tutorial `, but in more detail). + +To run an external code with AiiDA, you need an appropriate :ref:`calculation plugin `. +In the following, we assume that a plugin for your code is already available from the `aiida plugin registry `_ and installed on your machine, e.g. using ``pip install aiida-quantumespresso``. +If a plugin for your code is not yet available, see :ref:`how-to:plugin-codes`. + +Throughout the process you will be prompted for information on the computer and code. +In these prompts: + + * Type ``?`` followed by ```` to get help on what is being asked at any prompt. + * Press ``+C`` at any moment to abort the setup process. + Your AiiDA database will remain unmodified. + +.. note:: + + The ``verdi`` commands use ``readline`` extensions to provide default answers, which require an advanced terminal. + Use a standard terminal -- terminals embedded in some text editors (such as ``emacs``) have been known to cause problems. + +.. _how-to:run-codes:computer: + +How to set up a computer +======================== + +A |Computer| in AiiDA denotes a computational resource on which you will run your calculations. +It can either be: + + 1. the machine where AiiDA is installed or + 2. any machine that is accessible via `SSH `_ from the machine where AiiDA is installed (possibly :ref:`via a proxy server`). + +The second option allows managing multiple remote compute resources (including HPC clusters and cloud services) from the same AiiDA installation and moving computational jobs between them. + +.. note:: + + The second option requires access through an SSH keypair. + If your compute resource demands two-factor authentication, you may need to install AiiDA directly on the compute resource instead. + + +Computer requirements +--------------------- + +Each computer must satisfy the following requirements: + +* It runs a Unix-like operating system (Linux distros and MacOS should work fine) +* It has ``bash`` installed +* (optional) It has batch scheduler installed (see the :ref:`list of supported schedulers `) + +If you are configuring a remote computer, start by :ref:`configuring password-less SSH access ` to it. + +.. note:: + + AiiDA will use ``bash`` on the remote computer, regardless of the default shell. + Please ensure that your remote ``bash`` configuration does not load a different shell. + + +.. _how-to:run-codes:computer:setup: + +Computer setup +-------------- + +The configuration of computers happens in two steps: setting up the public metadata associated with the |Computer| in AiiDA provenance graphs, and configuring private connection details. 
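+
+If you prefer to script the first step, the same public metadata can also be stored through the Python API; a minimal sketch (with placeholder values for an SSH/SLURM machine) could look like this:
+
+.. code-block:: python
+
+    from aiida import orm
+
+    # Placeholder values: adapt the label, hostname, scheduler and work directory to your machine.
+    computer = orm.Computer(
+        label='my-cluster',
+        hostname='cluster.example.com',
+        description='My HPC cluster',
+        transport_type='ssh',
+        scheduler_type='slurm',
+        workdir='/scratch/{username}/aiida_run/',
+    ).store()
+
+The rest of this section goes through the same steps with the ``verdi`` command line interface.
+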
+ +Start by creating a new computer instance in the database: + +.. code-block:: console + + $ verdi computer setup + +At the end, the command will open your default editor on a file containing a summary of the configuration up to this point. +You can add ``bash`` commands that will be executed + + * *before* the actual execution of the job (under 'Pre-execution script'), and + * *after* the script submission (under 'Post execution script'). + +Use these additional lines to perform any further set up of the environment on the computer, for example loading modules or exporting environment variables: + +.. code-block:: bash + + export NEWVAR=1 + source some/file + +.. note:: + + Don't specify settings here that are specific to a code or calculation: you can set further pre-execution commands at the ``Code`` and even ``CalcJob`` level. + +When you are done editing, save and quit. +The computer has now been created in the database but you still need to *configure* access to it using your credentials. + +.. tip:: + In order to avoid having to retype the setup information the next time around, you can provide some (or all) of the information via a configuration file: + + .. code-block:: console + + $ verdi computer setup --config computer.yml + + where ``computer.yml`` is a configuration file in the `YAML format `__. + This file contains the information in a series of key-value pairs: + + .. code-block:: yaml + + --- + label: "localhost" + hostname: "localhost" + transport: local + scheduler: "direct" + work_dir: "/home/max/.aiida_run" + mpirun_command: "mpirun -np {tot_num_mpiprocs}" + mpiprocs_per_machine: "2" + prepend_text: | + module load mymodule + export NEWVAR=1 + + The list of the keys for the ``yaml`` file is given by the options of the ``computer setup`` command: + + .. code-block:: console + + $ verdi computer setup --help + + Note: remove the ``--`` prefix and replace ``-`` within the keys with an underscore ``_``. + +.. _how-to:run-codes:computer:configuration: + +Computer connection configuration +--------------------------------- + +The second step configures private connection details using: + +.. code-block:: console + + $ verdi computer configure TRANSPORTTYPE COMPUTERLABEL + +Replace ``COMPUTERLABEL`` with the computer label chosen during the setup and replace ``TRANSPORTTYPE`` with the name of chosen transport type, i.e., ``local`` for the localhost computer and ``ssh`` for any remote computer. + +After the setup and configuration have been completed, let's check that everything is working properly: + +.. code-block:: console + + $ verdi computer test COMPUTERNAME + +This command will perform various tests to make sure that AiiDA can connect to the computer, create new files in the scratch directory, retrieve files and query the job scheduler. + +.. _how-to:run-codes:computer:connection: + +Mitigating connection overloads +---------------------------------- + +Some compute resources, particularly large supercomputing centres, may not tolerate submitting too many jobs at once, executing scheduler commands too frequently or opening too many SSH connections. + + * Limit the number of jobs in the queue. + + Set a limit for the maximum number of workflows to submit, and only submit new ones once previous workflows start to complete. + The supported number of jobs depends on the supercomputer configuration which may be documented as part of the center's user documentation. + The supercomputer administrators may also find the information found on `this page `_ useful. 
+ + * Increase the time interval between polling the job queue. + + The time interval (in seconds) can be set through the Python API by loading the corresponding |Computer| node, e.g. in the ``verdi shell``: + + .. code-block:: python + + load_computer('fidis').set_minimum_job_poll_interval(30.0) + + * Increase the connection cooldown time. + + This is the minimum time (in seconds) to wait between opening a new connection. + Modify it for an existing computer using: + + .. code-block:: bash + + verdi computer configure ssh --non-interactive --safe-interval + +.. important:: + + The two intervals apply *per daemon worker*, i.e. doubling the number of workers may end up putting twice the load on the remote computer. + +Managing your computers +----------------------- + +Fully configured computers can be listed with: + +.. code-block:: console + + $ verdi computer list + +To get detailed information on the specific computer named ``COMPUTERLABEL``: + +.. code-block:: console + + $ verdi computer show COMPUTERLABEL + +To rename a computer or remove it from the database: + +.. code-block:: console + + $ verdi computer rename OLDCOMPUTERLABEL NEWCOMPUTERLABEL + $ verdi computer delete COMPUTERLABEL + +.. note:: + + Before deleting a |Computer|, you will need to delete *all* nodes linked to it (e.g. any ``CalcJob`` and ``RemoteData`` nodes). + Otherwise, AiiDA will prevent you from doing so in order to preserve provenance. + +If a remote machine is under maintenance (or no longer operational), you may want to **disable** the corresponding |Computer|. +Doing so will prevent AiiDA from connecting to the given computer to check the state of calculations or to submit new calculations. + +.. code-block:: console + + $ verdi computer disable COMPUTERLABEL + $ verdi computer enable COMPUTERLABEL + +.. _how-to:run-codes:code: + +How to setup a code +=================== + +Once your computer is configured, you can set up codes on it. + +AiiDA stores a set of metadata for each code, which is attached automatically to each calculation using it. +Besides being important for reproducibility, this also makes it easy to query for all calculations that were run with a given code (for instance, if a specific version is found to contain a bug). + +.. _how-to:run-codes:code:setup: + +Setting up a code +----------------- + +The ``verdi code`` CLI is the access point for managing codes in AiiDA. +To setup a new code, execute: + +.. code-block:: console + + $ verdi code setup + +and you will be guided through a process to setup your code. + +.. admonition:: On remote and local codes + :class: tip title-icon-lightbulb + + In most cases, it is advisable to install the executables to be used by AiiDA on the target machine *before* submitting calculations using them in order to take advantage of the compilers and libraries present on the target machine. + This setup is referred to as *remote* codes (``Installed on target computer?: True``). + + Occasionally, you may need to run small, reasonably machine-independent scripts (e.g. Python or bash), and copying them manually to a number of different target computers can be tedious. + For this use case, AiiDA provides *local* codes (``Installed on target computer?: False``). + Local codes are stored in the AiiDA file repository and copied to the target computer for every execution. + + Do *not* use local codes as a way of encapsulating the environment of complex executables. 
+ Containers are a much better solution to this problem, and we are working on adding native support for containers in AiiDA. + + +At the end of these steps, you will be prompted to edit a script, where you can include ``bash`` commands that will be executed + + * *before* running the submission script (after the 'Pre execution script' lines), and + * *after* running the submission script (after the 'Post execution script' separator). + +Use this for instance to load modules or set variables that are needed by the code, such as: + +.. code-block:: bash + + module load intelmpi + +At the end, you receive a confirmation, with the *PK* and the *UUID* of your new code. + +.. admonition:: Using configuration files + :class: tip title-icon-lightbulb + + Analogous to a :ref:`computer setup `, some (or all) the information described above can be provided via a configuration file: + + .. code-block:: console + + $ verdi code setup --config code.yml + + where ``code.yml`` is a configuration file in the `YAML format `_. + + This file contains the information in a series of key:value pairs: + + .. code-block:: yaml + + --- + label: "qe-6.3-pw" + description: "quantum_espresso v6.3" + input_plugin: "quantumespresso.pw" + on_computer: true + remote_abs_path: "/path/to/code/pw.x" + computer: "localhost" + prepend_text: | + module load module1 + module load module2 + append_text: " " + + The list of the keys for the ``yaml`` file is given by the available options of the ``code setup`` command: + + .. code-block:: console + + $ verdi code setup --help + + Note: remove the ``--`` prefix and replace ``-`` within the keys with an underscore ``_``. + +Managing codes +-------------- + +You can change the label of a code by using the following command: + +.. code-block:: console + + $ verdi code relabel "new-label" + +where can be the numeric *PK*, the *UUID* or the label of the code (either ``label`` or ``label@computername``) if the label is unique. + +You can also list all available codes and their identifiers with: + +.. code-block:: console + + $ verdi code list + +which also accepts flags to filter only codes on a given computer, or only codes using a specific plugin, etc. (use the ``-h`` option). + +You can get the information of a specific code with: + +.. code-block:: console + + $ verdi code show + +Finally, to delete a code use: + +.. code-block:: console + + $ verdi code delete + +(only if it wasn't used by any calculation, otherwise an exception is raised). + +.. note:: + + Codes are a subclass of :py:class:`Node ` and, as such, you can attach ``extras`` to a code, for example: + + .. code-block:: python + + load_code('').set_extra('version', '6.1') + load_code('').set_extra('family', 'cp2k') + + These can be useful for querying, for instance in order to find all runs done with the CP2K code of version 6.1 or later. + +.. _how-to:run-codes:submit: + +How to submit a calculation +=========================== + +After :ref:`setting up your computer ` and :ref:`setting up your code `, you are ready to launch your calculations! + + * Make sure the daemon is running: + + .. code-block:: bash + + verdi daemon status + + * Figure out which inputs your |CalcJob| plugin needs, e.g. using: + + .. code-block:: bash + + verdi plugin list aiida.calculations arithmetic.add + + * Write a ``submit.py`` script: + + .. 
code-block:: python + + from aiida.engine import submit + + code = load_code('add@localhost') + builder = code.get_builder() + builder.x = Int(4) + builder.y = Int(5) + builder.metadata.options.withmpi = False + builder.metadata.options.resources = { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1, + + } + builder.metadata.description = "My first calculation." + + print(submit(builder)) + + Of course, the code label and builder inputs need to be adapted to your code and calculation. + + * Submit your calculation to the AiiDA daemon: + + .. code-block:: bash + + verdi run submit.py + +After this, use ``verdi process list`` to monitor the status of the calculations. + +See :ref:`topics:processes:usage:launching` and :ref:`topics:processes:usage:monitoring` for more details. + + + +.. _how-to:run-codes:caching: + +How to save computational resources using caching +================================================= + +There are numerous reasons why you might need to re-run calculations you have already run before. +Maybe you run a great number of complex workflows in high-throughput that each may repeat the same calculation, or you may have to restart an entire workflow that failed somewhere half-way through. +Since AiiDA stores the full provenance of each calculation, it can detect whether a calculation has been run before and, instead of running it again, simply reuse its outputs, thereby saving valuable computational resources. +This is what we mean by **caching** in AiiDA. + +.. _how-to:run-codes:caching:enable: + +How to enable caching +--------------------- + +Caching is **not enabled by default**. +The reason is that it is designed to work in an unobtrusive way and simply save time and valuable computational resources. +However, this design is a double-egded sword, in that a user that might not be aware of this functionality, can be caught off guard by the results of their calculations. + +.. important:: + + The caching mechanism comes with some limitations and caveats that are important to understand. + Refer to the :ref:`topics:provenance:caching:limitations` section for more details. + +In order to enable caching for your profile (here called ``aiida_profile``), place the following ``cache_config.yml`` file in your ``.aiida`` configuration folder: + +.. code-block:: yaml + + aiida_profile: + default: True + +From this point onwards, when you launch a new calculation, AiiDA will compare its hash (depending both on the type of calculation and its inputs, see :ref:`topics:provenance:caching:hashing`) against other calculations already present in your database. +If another calculation with the same hash is found, AiiDA will reuse its results without repeating the actual calculation. + +In order to ensure that the provenance graph with and without caching is the same, AiiDA creates both a new calculation node and a copy of the output data nodes as shown in :numref:`fig_caching`. + +.. _fig_caching: +.. figure:: include/images/caching.png + :align: center + :height: 350px + + When reusing the results of a calculation **C** for a new calculation **C'**, AiiDA simply makes a copy of the result nodes and links them up as usual. + +.. note:: + + AiiDA uses the *hashes* of the input nodes **D1** and **D2** when searching the calculation cache. + That is to say, if the input of **C'** were new nodes **D1'** and **D2'** with the same content (hash) as **D1**, **D2**, the cache would trigger as well. + +.. 
_how-to:run-codes:caching:configure: + +How to configure caching +------------------------ + +The caching mechanism can be configured on a process class level, meaning the rules will automatically be applied to all instances of the given class, or on a per-instance level, meaning it can be controlled for individual process instances when they are launch. + +Class level +........... + +Besides an on/off switch per profile, the ``.aiida/cache_config.yml`` provides control over caching at the level of specific calculations using their corresponding entry point strings (see the output of ``verdi plugin list aiida.calculations``): + +.. code-block:: yaml + + aiida_profile: + default: False + enabled: + - aiida.calculations:quantumespresso.pw + disabled: + - aiida.calculations:templatereplacer + +In this example, where ``aiida_profile`` is the name of the profile, caching is disabled by default, but explicitly enabled for calculaions of the ``PwCalculation`` class, identified by its corresponding ``aiida.calculations:quantumespresso.pw`` entry point string. +It also shows how to disable caching for particular calculations (which has no effect here due to the profile-wide default). + +For calculations which do not have an entry point, you need to specify the fully qualified Python name instead. +For example, the ``seekpath_structure_analysis`` calcfunction defined in ``aiida_quantumespresso.workflows.functions.seekpath_structure_analysis`` is labeled as ``aiida_quantumespresso.workflows.functions.seekpath_structure_analysis.seekpath_structure_analysis``. +From an existing :class:`~aiida.orm.nodes.process.calculation.CalculationNode`, you can get the identifier string through the ``process_type`` attribute. + +The caching configuration also accepts ``*`` wildcards. +For example, the following configuration enables caching for all calculation entry points defined by ``aiida-quantumespresso``, and the ``seekpath_structure_analysis`` calcfunction. +Note that the ``*.seekpath_structure_analysis`` entry needs to be quoted, because it starts with ``*`` which is a special character in YAML. + +.. code-block:: yaml + + aiida_profile: + default: False + enabled: + - aiida.calculations:quantumespresso.* + - '*.seekpath_structure_analysis' + +Any entry with a wildcard is overridden by a more specific entry. +The following configuration enables caching for all ``aiida.calculation`` entry points, except those of ``aiida-quantumespresso``: + +.. code-block:: yaml + + aiida_profile: + default: False + enabled: + - aiida.calculations:* + disabled: + - aiida.calculations:quantumespresso.* + + +Instance level +.............. + +Caching can be enabled or disabled on a case-by-case basis by using the :class:`~aiida.manage.caching.enable_caching` or :class:`~aiida.manage.caching.disable_caching` context manager, respectively, regardless of the profile settings: + +.. code-block:: python + + from aiida.engine import run + from aiida.manage.caching import enable_caching + with enable_caching(identifier='aiida.calculations:templatereplacer'): + run(...) + +.. warning:: + + This affects only the current Python interpreter and won't change the behavior of the daemon workers. + This means that this technique is only useful when using :py:class:`~aiida.engine.run`, and **not** with :py:class:`~aiida.engine.submit`. + +If you suspect a node is being reused in error (e.g. during development), you can also manually *prevent* a specific node from being reused: + +#. Load one of the nodes you suspect to be a clone. 
+ Check that :meth:`~aiida.orm.nodes.Node.get_cache_source` returns a UUID. + If it returns `None`, the node was not cloned. + +#. Clear the hashes of all nodes that are considered identical to this node: + + .. code:: python + + for node in node.get_all_same_nodes(): + node.clear_hash() + +#. Run your calculation again. + The node in question should no longer be reused. + + +.. |Computer| replace:: :py:class:`~aiida.orm.Computer` +.. |CalcJob| replace:: :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` diff --git a/docs/source/howto/ssh.rst b/docs/source/howto/ssh.rst new file mode 100644 index 0000000000..bf2f4cc65e --- /dev/null +++ b/docs/source/howto/ssh.rst @@ -0,0 +1,273 @@ +.. _how-to:ssh: + +**************************** +How to setup SSH connections +**************************** + +AiiDA communicates with remote computers via the SSH protocol. +There are two ways of setting up an SSH connection for AiiDA: + + 1. Using a passwordless SSH key (easier, less safe) + 2. Using a password-protected SSH key through ``ssh-agent`` (one more step, safer) + +.. _how-to:ssh:passwordless: + +Using a passwordless SSH key +============================ + + +There are numerous tutorials on the web, see e.g. `here `_. +Very briefly, first create a new private/public keypair (``aiida``/``aiida.pub``), leaving passphrase emtpy: + +.. code-block:: console + + $ ssh-keygen -t rsa -b 4096 -f ~/.ssh/aiida + +Copy the public key to the remote machine, normally this will add the public key to the rmote machine's ``~/.ssh/authorized_keys``: + +.. code-block:: console + + $ ssh-copy-id -i ~/.ssh/aiida YOURUSERNAME@YOURCLUSTERADDRESS + +Add the following lines to your ``~/.ssh/config`` file (or create it, if it does not exist): + +.. code-block:: bash + + Host YOURCLUSTERADDRESS + User YOURUSERNAME + IdentityFile ~/.ssh/aiida + +.. note:: + + If your cluster needs you to connect to another computer *PROXY* first, you can use the ``proxy_command`` feature of ssh, see :ref:`how-to:ssh:proxy`. + +You should now be able to access the remote computer (without the need to type a password) *via*: + +.. code-block:: console + + $ ssh YOURCLUSTERADDRESS + # this connection is used to copy files + $ sftp YOURCLUSTERADDRESS + +.. admonition:: Connection closed failures + :class: attention title-icon-troubleshoot + + + If the ``ssh`` command works, but the ``sftp`` command prints ``Connection closed``, there may be a line in the ``~/.bashrc`` file **on the cluster** that either produces text output or an error. + Remove/comment lines from this file until no output or error is produced: this should make ``sftp`` work again. + +Finally, if you are planning to use a batch scheduler on the remote computer, try also: + +.. code-block:: console + + $ ssh YOURCLUSTERADDRESS QUEUE_VISUALIZATION_COMMAND + +replacing ``QUEUE_VISUALIZATION_COMMAND`` by ``squeue`` (SLURM), ``qstat`` (PBSpro) or the equivalent command of your scheduler and check that it prints a list of the job queue without errors. + +.. admonition:: Scheduler errors? + :class: attention title-icon-troubleshoot + + If the previous command errors with ``command not found``, while the same ``QUEUE_VISUALIZATION_COMMAND`` works fine after you've logged in via SSH, it may be that a guard in the ``.bashrc`` file on the cluster prevents necessary modules from being loaded. + + Look for lines like: + + .. code-block:: bash + + [ -z "$PS1" ] && return + + or: + + .. 
code-block:: bash + + case $- in + *i*) ;; + *) return;; + esac + + which will prevent any instructions that follow from being executed. + + You can either move relevant instructions before these lines or delete the guards entirely. + If you are wondering whether the ``PATH`` environment variable is set correctly, you can check its value using: + + .. code-block:: bash + + $ ssh YOURCLUSTERADDRESS 'echo $PATH' + +.. _how-to:ssh:passphrase: + +Using passphrase-protected keys *via* an ssh-agent +================================================== + + +Tools like ``ssh-agent`` (available on most Linux distros and MacOS) allow you to enter the passphrase of a protected key *once* and provide access to the decrypted key for as long as the agent is running. +This allows you to use a passphrase-protected key (required by some HPC centres), while making the decrypted key available to AiiDA for automatic SSH operations. + +Creating the key +^^^^^^^^^^^^^^^^ + +Start by following the instructions above for :ref:`how-to:ssh:passwordless`, the only difference being that you enter a passphrase when creating the key (and when logging in to the remote computer). + +Adding the key to the agent +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Now provide the passphrase for your private key to the agent: + +.. code:: bash + + ssh-add ~/.ssh/aiida + +The private key and the relative passphrase are now recorded in an instance of the agent. + +.. note:: + + The passphase is stored in the agent only until the next reboot. + If you shut down or restart the AiiDA machine, before starting the AiiDA deamon remember to run the ``ssh-add`` command again. + +Starting the ssh-agent +^^^^^^^^^^^^^^^^^^^^^^ + +On most modern Linux installations, the ``ssh-agent`` starts automatically at login (e.g. Ubuntu 16.04 and later or MacOS 10.5 and later). +If you received an error ``Could not open a connection to your authentication agent``, you will need to start the agent manually instead. + +Check whether you can start an ``ssh-agent`` **in your current shell**: + +.. code:: bash + + eval `ssh-agent` + +In order to reuse the same agent instance everywhere (including the AiiDA daemon), the environment variables of ``ssh-agent`` need to be reused by *all* shells. +Download the script :download:`load-singlesshagent.sh ` and place it e.g. in ``~/bin``. +Then add the following lines to your ``~/.bashrc`` file: + +.. code:: bash + + if [ -f ~/bin/load-singlesshagent.sh ]; then + . ~/bin/load-singlesshagent.sh + fi + +To check that it works: + +* Open a new shell (``~/.bashrc`` file is sourced). +* Run ``ssh-add``. +* Close the shell. +* Open a new shell and try logging in to the remote computer. + +Try logging in to the remote computer; it should no longer require a passphrase. + +The key and its corresponding passphrase are now stored by the agent until it is stopped. +After a reboot, remember to run ``ssh-add ~/.ssh/aiida`` again before starting the AiiDA daemon. + +Integrating the ssh-agent with keychain on OSX +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +On OSX Sierra and later, the native ``ssh-add`` client allows passphrases to be stored persistently in the `OSX keychain `__. +Store the passphrase in the keychain using the OSX-specific ``-k`` argument: + +.. code:: bash + + ssh-add -k ~/.ssh/aiida + +To instruct ssh to look in the OSX keychain for key passphrases, add the following lines to ``~/.ssh/config``: + +.. 
code:: bash + + Host * + UseKeychain yes + +AiiDA configuration +^^^^^^^^^^^^^^^^^^^ + +When :ref:`configuring the computer in AiiDA `, simply make sure that ``Allow ssh agent`` is set to ``true`` (default). + +.. _how-to:ssh:proxy: + +Connecting to a remote computer *via* a proxy server +==================================================== + +Some compute clusters require you to connect to an intermediate server *PROXY*, from which you can then connect to the cluster *TARGET* on which you run your calculations. +This section explains how to use the ``proxy_command`` feature of ``ssh`` in order to make this jump automatically. + +.. tip:: + + This method can also be used to automatically tunnel into virtual private networks, if you have an account on a proxy/jumphost server with access to the network. + +Requirements +^^^^^^^^^^^^ + +The ``netcat`` tool needs to be present on the *PROXY* server (executable may be named ``netcat`` or ``nc``). +``netcat`` simply takes the standard input and redirects it to a given TCP port. + +.. dropdown:: Installing netcat + + If neither ``netcat`` or ``nc`` are available, you will need to install it on your own. + You can download a `netcat distribution `_, unzip the downloaded package, ``cd`` into the folder and execute something like: + + .. code-block:: console + + $ ./configure --prefix=. + $ make + $ make install + + This usually creates a subfolder ``bin``, containing the ``netcat`` and ``nc`` executables. + Write down the full path to ``nc`` which we will need later. + + + +SSH configuration +^^^^^^^^^^^^^^^^^ + +Edit the ``~/.ssh/config`` file on the computer on which you installed AiiDA (or create it if missing) and add the following lines:: + + Host SHORTNAME_TARGET + Hostname FULLHOSTNAME_TARGET + User USER_TARGET + IdentityFile ~/.ssh/aiida + ProxyCommand ssh USER_PROXY@FULLHOSTNAME_PROXY ABSPATH_NETCAT %h %p + +replacing the ``..._TARGET`` and ``..._PROXY`` variables with the host/user names of the respective servers, and replacing ``ABSPATH_NETCAT`` with the result of ``which netcat`` (or ``which nc``). + +.. note:: + + If desired/necessary for your netcat implementation, hide warnings and errors that may occur during the proxying/tunneling by redirecting stdout and stderr, e.g. by appending ``2> /dev/null`` to the ``ProxyCommand``. + + +This should allow you to directly connect to the *TARGET* server using + +.. code-block:: console + + $ ssh SHORTNAME_TARGET + +For a *passwordless* connection, you need to follow the instructions :ref:`how-to:ssh:passwordless` *twice*: once for the connection from your computer to the *PROXY* server, and once for the connection from the *PROXY* server to the *TARGET* server. + + +.. warning:: + + There are occasionally ``netcat`` implementations, which keep running after you close your SSH connection, resulting in a growing number of open SSH connections between the *PROXY* server and the *TARGET* server. + If you suspect an issue, it may be worth connecting to the *PROXY* server and checking how many ``netcat`` processes are running, e.g. via: + + .. code-block:: console + + $ ps -aux | grep netcat + +AiiDA configuration +^^^^^^^^^^^^^^^^^^^ + +When :ref:`configuring the computer in AiiDA `, AiiDA will automatically parse the required information from your ``~/.ssh/config`` file. + +.. dropdown:: Specifying the proxy_command manually + + If, for any reason, you need to specify the ``proxy_command`` option of ``verdi computer configure ssh`` manually, please note the following: + + 1. 
Don't use placeholders ``%h`` and ``%p`` (AiiDA replaces them only when parsing from the ``~/.ssh/config`` file) but provide the actual hostname and port. + 2. Don't include stdout/stderr redirection (AiiDA strips it automatically, but only when parsing from the ``~/.ssh/config`` file). + + +Using kerberos tokens +===================== + +If the remote machine requires authentication through a Kerberos token (that you need to obtain before using ssh), you typically need to + + * install ``libffi`` (``sudo apt-get install libffi-dev`` under Ubuntu) + * install the ``ssh_kerberos`` extra during the installation of aiida-core (see :ref:`intro:install:aiida-core`). + +If you provide all necessary ``GSSAPI`` options in your ``~/.ssh/config`` file, ``verdi computer configure`` should already pick up the appropriate values for all the gss-related options. diff --git a/docs/source/index.rst b/docs/source/index.rst index 55c6ea8b0a..bfc29c1d21 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -110,13 +110,11 @@ How to cite If you use AiiDA for your research, please cite the following work: -.. highlights:: **AiiDA 1.0:** Sebastiaan. P. Huber, Spyros Zoupanos, Martin Uhrin, Leopold Talirz, Leonid Kahle, Rico Häuselmann, Dominik Gresch, Tiziano Müller, Aliaksandr V. Yakutovich, Casper W. Andersen, Francisco F. Ramirez, Carl S. Adorf, Fernando Gargiulo, Snehal Kumbhar, Elsa Passaro, Conrad Johnston, Andrius Merkys, Andrea Cepellotti, Nicolas Mounet, Nicola Marzari, Boris Kozinsky, Giovanni Pizzi, *AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and data provenance*, `arXiv:2003.12476 (2020) `_; - `http://www.aiida.net `_. +.. highlights:: **AiiDA >= 1.0:** Sebastiaan. P. Huber, Spyros Zoupanos, Martin Uhrin, Leopold Talirz, Leonid Kahle, Rico Häuselmann, Dominik Gresch, Tiziano Müller, Aliaksandr V. Yakutovich, Casper W. Andersen, Francisco F. Ramirez, Carl S. Adorf, Fernando Gargiulo, Snehal Kumbhar, Elsa Passaro, Conrad Johnston, Andrius Merkys, Andrea Cepellotti, Nicolas Mounet, Nicola Marzari, Boris Kozinsky, Giovanni Pizzi, *AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and data provenance*, Scientific Data **7**, 300 (2020); DOI: [10.1038/s41597-020-00638-4](https://doi.org/10.1038/s41597-020-00638-4) -.. highlights:: **AiiDA 0.x:** Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari, +.. highlights:: **AiiDA < 1.0:** Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari, and Boris Kozinsky, *AiiDA: automated interactive infrastructure and database - for computational science*, Comp. Mat. Sci 111, 218-230 (2016); - https://doi.org/10.1016/j.commatsci.2015.09.013; http://www.aiida.net. + for computational science*, Comp. Mat. 
Sci 111, 218-230 (2016); DOI: `10.1016/j.commatsci.2015.09.013 <https://doi.org/10.1016/j.commatsci.2015.09.013>`_ **************** diff --git a/docs/source/intro/get_started.rst b/docs/source/intro/get_started.rst index 134879daef..0bf71e3386 100644 --- a/docs/source/intro/get_started.rst +++ b/docs/source/intro/get_started.rst @@ -182,7 +182,7 @@ Finally, to check that all services are running as expected use: ✓ profile: On profile me ✓ repository: /home/ubuntu/.aiida/repository/me ✓ postgres: Connected as aiida_qs_ubuntu_c6a4f69d255fbe9cdb7385dcdcf3c050@localhost:5432 - ✓ rabbitmq: Connected to amqp://127.0.0.1?heartbeat=600 + ✓ rabbitmq: Connected as amqp://127.0.0.1?heartbeat=600 ✓ daemon: Daemon is running as PID 16430 since 2020-04-29 12:17:31 Awesome! You now have a fully operational installation from which to take the next steps! diff --git a/docs/source/intro/installation.rst b/docs/source/intro/installation.rst index 039f206f81..f45b6f43b3 100644 --- a/docs/source/intro/installation.rst +++ b/docs/source/intro/installation.rst @@ -18,6 +18,13 @@ AiiDA is designed to run on `Unix `_ operati * `postgresql`_ (Database software, version 9.4 or higher) * `RabbitMQ`_ (A message broker necessary for AiiDA to communicate between processes) +.. admonition:: PostgreSQL and RabbitMQ as services + :class: tip title-icon-tip + + PostgreSQL and RabbitMQ can also be configured to run as services on other machines. + When setting up a profile, after having installed AiiDA, you can specify how to connect to these machines. + With this setup, it is not necessary to install PostgreSQL or RabbitMQ on the machine where AiiDA is installed. + Depending on your set up, there are a few optional dependencies: * `git`_ (Version control system used for AiiDA development) @@ -361,21 +368,24 @@ Using virtual environments ========================== AiiDA depends on a number of third party python packages, and usually on specific versions of those packages. -In order not to interfere with third party packages needed by other software on your system, we **strongly** recommend isolating AiiDA in a virtual python environment. +In order to not interfere with third party packages needed by other software on your system, we **strongly** recommend isolating AiiDA in a virtual Python environment, for example, by means of one of the methods described below. .. admonition:: Additional Information :class: seealso title-icon-read-more A very good tutorial on Python environments is provided by `realpython.com `__. -`venv `__ is module included directly with python for creating virtual environments. +venv +---- + +The `venv `__ module for creating virtual environments ships directly with Python. To create a virtual environment, in a given directory, run: .. code-block:: console $ python3 -m venv /path/to/new/virtual/environment/aiida -The command to activate the environment is shell specific (see `the documentation `__. +The command to activate the environment is shell specific (see `the documentation `__). With bash the following command is used: .. code-block:: console @@ -391,19 +401,22 @@ To leave or deactivate the environment, simply run: .. admonition:: Update install tools :class: tip title-icon-tip - You may need to install ``pip`` and ``setuptools`` in your virtual environment in case the system or user version of these tools is old + You may need to update ``pip`` and ``setuptools`` in your virtual environment, in case the system or user version of these tools is old. ..
code-block:: console (aiida) $ pip install -U setuptools pip -If you have `Conda`_ installed then you can directly create a new environment with ``aiida-core`` and (optionally) Postgres and RabbitMQ installed. +Conda +----- + +If you have `Conda`_ installed then you can directly create a new environment with ``aiida-core`` and (optionally) the Postgres and RabbitMQ services installed. .. code-block:: console $ conda create -n aiida -c conda-forge python=3.7 aiida-core aiida-core.services pip - $ conda activate - $ conda deactivate aiida + $ conda activate aiida + $ conda deactivate .. _intro:install:aiida-core: @@ -462,7 +475,7 @@ In order to install any of these package groups, simply append them as a comma s .. code-block:: console - $ pip install -e aiida-core[atomic_tools,docs] + $ pip install -e "aiida-core[atomic_tools,docs]" .. admonition:: Kerberos on Ubuntu :class: note title-icon-troubleshoot @@ -501,7 +514,8 @@ Most users should use the interactive quicksetup: which leads through the installation process and takes care of creating the corresponding AiiDA database. -For maximum control and customizability, one can use ``verdi setup`` and set up the database manually as explained below. +For maximum customizability, one can use ``verdi setup``, that provides fine-grained control in configuring how AiiDA should connect to the required services. +This is useful, for example, if PostgreSQL and or RabbitMQ are not installed and configured with default settings, or are run on a different machine from AiiDA itself. .. admonition:: Don't forget to backup your data! :class: tip title-icon-tip @@ -650,6 +664,33 @@ During the ``verdi setup`` phase, use ``!`` to leave host empty and specify your $ AiiDA Database user: $ AiiDA Database password: "" + +RabbitMQ configuration +...................... + +In most normal setups, RabbitMQ will be installed and run as a service on the same machine that hosts AiiDA itself. +In that case, using the default configuration proposed during a profile setup will work just fine. +However, when the installation of RabbitMQ is not standard, for example it runs on a different port, or even runs on a completely different machine, all relevant connection details can be configured with ``verdi setup``. + +The following parameters can be configured: + ++--------------+---------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------+ +| Parameter | Option | Default | Explanation | ++==============+===========================+===============+=========================================================================================================================+ +| Protocol | ``--broker-protocol`` | ``amqp`` | The protocol to use, can be either ``amqp`` or ``amqps`` for SSL enabled connections. | ++--------------+---------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------+ +| Username | ``--broker-username`` | ``guest`` | The username with which to connect. The ``guest`` account is available and usable with a default RabbitMQ installation. | ++--------------+---------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------+ +| Password | ``--broker-password`` | ``guest`` | The password with which to connect. 
The ``guest`` account is available and usable with a default RabbitMQ installation. | ++--------------+---------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------+ +| Host | ``--broker-host`` | ``127.0.0.1`` | The hostname of the RabbitMQ server. | ++--------------+---------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------+ +| Port | ``--broker-port`` | ``5672`` | The port to which the server listens. | ++--------------+---------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------+ +| Virtual host | ``--broker-virtual-host`` | ``''`` | Optional virtual host. If defined, needs to start with a forward slash. | ++--------------+---------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------+ + + verdi setup ........... @@ -727,8 +768,8 @@ Use the ``verdi status`` command to check that all services are up and running: ✓ profile: On profile quicksetup ✓ repository: /repo/aiida_dev/quicksetup - ✓ postgres: Connected to aiida@localhost:5432 - ✓ rabbitmq: Connected to amqp://127.0.0.1?heartbeat=600 + ✓ postgres: Connected as aiida@localhost:5432 + ✓ rabbitmq: Connected as amqp://127.0.0.1?heartbeat=600 ✓ daemon: Daemon is running as PID 2809 since 2019-03-15 16:27:52 In the example output, all service have a green check mark and so should be running as expected. @@ -857,7 +898,7 @@ The profile is created under the ``aiida`` username, so to execute commands use: ✓ profile: On profile default ✓ repository: /home/aiida/.aiida/repository/default ✓ postgres: Connected as aiida_qs_aiida_477d3dfc78a2042156110cb00ae3618f@localhost:5432 - ✓ rabbitmq: Connected to amqp://127.0.0.1?heartbeat=600 + ✓ rabbitmq: Connected as amqp://127.0.0.1?heartbeat=600 ✓ daemon: Daemon is running as PID 1795 since 2020-05-20 02:54:00 Or to enter into the container interactively: diff --git a/docs/source/intro/troubleshooting.rst b/docs/source/intro/troubleshooting.rst index ecc93ad576..c6a23d5960 100644 --- a/docs/source/intro/troubleshooting.rst +++ b/docs/source/intro/troubleshooting.rst @@ -12,8 +12,8 @@ If you experience any problems, first check that all services are up and running ✓ profile: On profile django ✓ repository: /repo/aiida_dev/django - ✓ postgres: Connected to aiida@localhost:5432 - ✓ rabbitmq: Connected to amqp://127.0.0.1?heartbeat=600 + ✓ postgres: Connected as aiida@localhost:5432 + ✓ rabbitmq: Connected as amqp://127.0.0.1?heartbeat=600 ✓ daemon: Daemon is running as PID 2809 since 2019-03-15 16:27:52 In the example output, all service have a green check mark and so should be running as expected. @@ -269,6 +269,13 @@ To test if a the computer does not produce spurious output, run (after configuri which checks and, in case of problems, suggests how to solve the problem. +.. note:: + + If the methods explained above do not work, you can configure AiiDA to not use a login shell when connecting to your computer, which may prevent the spurious output from being printed: + During ``verdi computer configure``, set ``-no-use-login-shell`` or when asked to use a login shell, set it to ``False``. 
+ Note, however, that this may result in a slightly different environment, since `certain startup files are only sourced for login shells `_. + + .. _StackExchange thread: https://apple.stackexchange.com/questions/51036/what-is-the-difference-between-bash-profile-and-bashrc diff --git a/docs/source/intro/tutorial.rst b/docs/source/intro/tutorial.rst index 7e620f743c..ffda0d8c00 100644 --- a/docs/source/intro/tutorial.rst +++ b/docs/source/intro/tutorial.rst @@ -228,6 +228,8 @@ It should look something like the graph shown in :numref:`fig_calcfun_graph`. .. note:: Remember that the PK of the ``CalcJob`` can be different for your database. +.. _tutorial:basic:calcjob: + CalcJobs ======== @@ -235,7 +237,12 @@ When running calculations that require an external code or run on a remote machi For this purpose, AiiDA provides the ``CalcJob`` process class. To run a ``CalcJob``, you need to set up two things: a ``code`` that is going to implement the desired calculation and a ``computer`` for the calculation to run on. -If you're running this tutorial in the Quantum Mobile VM or on Binder, these have been pre-configured for you. If you're running on your own machine, you can follow the instructions in the panel below: + +If you're running this tutorial in the Quantum Mobile VM or on Binder, these have been pre-configured for you. If you're running on your own machine, you can follow the instructions in the panel below. + +.. seealso:: + + More details for how to :ref:`run external codes `. .. dropdown:: Install localhost computer and code @@ -615,23 +622,23 @@ We have also compiled useful how-to guides that are especially relevant for the Working with external codes Existing calculation plugins, for interfacing with external codes, are available on the `aiida plugin registry `_. - If none meet your needs, then the :ref:`external codes how-to ` can show you how to create your own calculation plugin. + If none meet your needs, then the :ref:`external codes how-to ` can show you how to create your own calculation plugin. Tuning performance To optimise the performance of AiiDA for running many concurrent computations see the :ref:`tuning performance how-to `. Saving computational resources - AiiDA can cache and reuse the outputs of identical computations, as described in the :ref:`caching how-to `. + AiiDA can cache and reuse the outputs of identical computations, as described in the :ref:`caching how-to `. .. dropdown:: Run computations on High Performance Computers Connecting to supercomputers - To setup up a computer which can communicate with a HPC over SSH, see the :ref:`running on supercomputers how-to `, or add a :ref:`custom transport `. + To setup up a computer which can communicate with a high-performance computer over SSH, see the :ref:`how-to for running external codes `, or add a :ref:`custom transport `. AiiDA has pre-written scheduler plugins to work with LSF, PBSPro, SGE, Slurm and Torque. Working with external codes Existing calculation plugins, for interfacing with external codes, are available on the `aiida plugin registry `_. - If none meet your needs, then the :ref:`external codes how-to ` can show you how to create your own calculation plugin. + If none meet your needs, then the :ref:`external codes how-to ` can show you how to create your own calculation plugin. Exploring your data Once you have run multiple computations, the :ref:`find and query data how-to ` can show you how to efficiently explore your data. The data lineage can also be visualised as a :ref:`provenance graph `. 
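To make the data-exploration pointer above concrete, the following is a minimal sketch (not part of the original how-to) using the public ``QueryBuilder`` API; it assumes a loaded AiiDA profile, and the choice of ``Int`` nodes is only an example.

.. code-block:: python

    from aiida.orm import Int, QueryBuilder

    qb = QueryBuilder()
    qb.append(Int)                 # match all stored Int nodes in the profile

    print(qb.count())              # number of matching nodes
    for (node,) in qb.iterall():   # iterate over the results one by one
        print(node.pk, node.value)

The linked how-to explains how to narrow such queries down with filters and how to traverse the provenance graph from one entity type to the next.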
@@ -646,4 +653,4 @@ We have also compiled useful how-to guides that are especially relevant for the .. todo:: - Add to "Connecting to supercomputers": , or you can add a :ref:`custom scheduler `. + Add to "Connecting to supercomputers": , or you can add a :ref:`custom scheduler `. diff --git a/docs/source/reference/command_line.rst b/docs/source/reference/command_line.rst index a47b4bcece..d8da7d4c8a 100644 --- a/docs/source/reference/command_line.rst +++ b/docs/source/reference/command_line.rst @@ -15,7 +15,7 @@ Below is a list with all available subcommands. ``verdi calcjob`` ----------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -39,7 +39,7 @@ Below is a list with all available subcommands. ``verdi code`` -------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -64,7 +64,7 @@ Below is a list with all available subcommands. ``verdi comment`` ----------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -85,16 +85,14 @@ Below is a list with all available subcommands. ``verdi completioncommand`` --------------------------- -:: +.. code:: console Usage: [OPTIONS] Return the code to activate bash completion. - :note: this command is mainly for back-compatibility. You should - rather use:; - - eval "$(_VERDI_COMPLETE=source verdi)" + This command is mainly for back-compatibility. + You should rather use: eval "$(_VERDI_COMPLETE=source verdi)" Options: --help Show this message and exit. @@ -105,7 +103,7 @@ Below is a list with all available subcommands. ``verdi computer`` ------------------ -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -121,6 +119,7 @@ Below is a list with all available subcommands. duplicate Duplicate a computer allowing to change some parameters. enable Enable the computer for the given user. list List all available computers. + relabel Relabel a computer. rename Rename a computer. setup Create a new computer. show Show detailed information for a computer. @@ -132,7 +131,7 @@ Below is a list with all available subcommands. ``verdi config`` ---------------- -:: +.. code:: console Usage: [OPTIONS] OPTION_NAME OPTION_VALUE @@ -149,7 +148,7 @@ Below is a list with all available subcommands. ``verdi daemon`` ---------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -173,7 +172,7 @@ Below is a list with all available subcommands. ``verdi data`` -------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -188,7 +187,7 @@ Below is a list with all available subcommands. ``verdi database`` ------------------ -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -207,7 +206,7 @@ Below is a list with all available subcommands. ``verdi devel`` --------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -217,11 +216,12 @@ Below is a list with all available subcommands. --help Show this message and exit. Commands: - check-load-time Check for common indicators that slowdown `verdi`. - configure-backup Configure backup of the repository folder. - run_daemon Run a daemon instance in the current interpreter. - tests Run the unittest suite or parts of it. - validate-plugins Validate all plugins by checking they can be loaded. + check-load-time Check for common indicators that slowdown `verdi`. + check-undesired-imports Check that verdi does not import python modules it shouldn't. + configure-backup Configure backup of the repository folder. + run_daemon Run a daemon instance in the current interpreter. + tests Run the unittest suite or parts of it. 
+ validate-plugins Validate all plugins by checking they can be loaded. .. _reference:command-line:verdi-export: @@ -229,7 +229,7 @@ Below is a list with all available subcommands. ``verdi export`` ---------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -249,7 +249,7 @@ Below is a list with all available subcommands. ``verdi graph`` --------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -267,7 +267,7 @@ Below is a list with all available subcommands. ``verdi group`` --------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -294,7 +294,7 @@ Below is a list with all available subcommands. ``verdi help`` -------------- -:: +.. code:: console Usage: [OPTIONS] [COMMAND] @@ -309,56 +309,50 @@ Below is a list with all available subcommands. ``verdi import`` ---------------- -:: +.. code:: console Usage: [OPTIONS] [--] [ARCHIVES]... Import data from an AiiDA archive file. - The archive can be specified by its relative or absolute file path, or its - HTTP URL. + The archive can be specified by its relative or absolute file path, or its HTTP URL. Options: - -w, --webpages TEXT... Discover all URL targets pointing to files - with the .aiida extension for these HTTP - addresses. Automatically discovered archive - URLs will be downloadeded and added to - ARCHIVES for importing + -w, --webpages TEXT... Discover all URL targets pointing to files with the + .aiida extension for these HTTP addresses. Automatically + discovered archive URLs will be downloadeded and added + to ARCHIVES for importing - -G, --group GROUP Specify group to which all the import nodes - will be added. If such a group does not - exist, it will be created automatically. + -G, --group GROUP Specify group to which all the import nodes will be + added. If such a group does not exist, it will be + created automatically. -e, --extras-mode-existing [keep_existing|update_existing|mirror|none|ask] - Specify which extras from the export archive - should be imported for nodes that are - already contained in the database: ask: - import all extras and prompt what to do for - existing extras. keep_existing: import all - extras and keep original value of existing - extras. update_existing: import all extras - and overwrite value of existing extras. - mirror: import all extras and remove any - existing extras that are not present in the - archive. none: do not import any extras. + Specify which extras from the export archive should be + imported for nodes that are already contained in the + database: ask: import all extras and prompt what to do + for existing extras. keep_existing: import all extras + and keep original value of existing extras. + update_existing: import all extras and overwrite value + of existing extras. mirror: import all extras and remove + any existing extras that are not present in the archive. + none: do not import any extras. -n, --extras-mode-new [import|none] - Specify whether to import extras of new - nodes: import: import extras. none: do not - import extras. + Specify whether to import extras of new nodes: import: + import extras. none: do not import extras. --comment-mode [newest|overwrite] - Specify the way to import Comments with - identical UUIDs: newest: Only the newest - Comments (based on mtime) - (default).overwrite: Replace existing - Comments with those from the import file. 
+ Specify the way to import Comments with identical UUIDs: + newest: Only the newest Comments (based on mtime) + (default).overwrite: Replace existing Comments with + those from the import file. - --migration / --no-migration Force migration of export file archives, if - needed. [default: True] + --migration / --no-migration Force migration of export file archives, if needed. + [default: True] - -n, --non-interactive Non-interactive mode: never prompt for - input. + -n, --non-interactive In non-interactive mode, the CLI never prompts but + simply uses default values for options that define one. --help Show this message and exit. @@ -368,7 +362,7 @@ Below is a list with all available subcommands. ``verdi node`` -------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -396,7 +390,7 @@ Below is a list with all available subcommands. ``verdi plugin`` ---------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -414,7 +408,7 @@ Below is a list with all available subcommands. ``verdi process`` ----------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -440,7 +434,7 @@ Below is a list with all available subcommands. ``verdi profile`` ----------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... @@ -461,21 +455,20 @@ Below is a list with all available subcommands. ``verdi quicksetup`` -------------------- -:: +.. code:: console Usage: [OPTIONS] Setup a new profile in a fully automated fashion. Options: - -n, --non-interactive Non-interactive mode: never prompt for - input. + -n, --non-interactive In non-interactive mode, the CLI never prompts but + simply uses default values for options that define one. --profile PROFILE The name of the new profile. [required] - --email EMAIL Email address associated with the data you - generate. The email address is exported - along with the data, when sharing it. - [required] + --email EMAIL Email address associated with the data you generate. The + email address is exported along with the data, when + sharing it. [required] --first-name NONEMPTYSTRING First name of the user. [required] --last-name NONEMPTYSTRING Last name of the user. [required] @@ -491,16 +484,28 @@ Below is a list with all available subcommands. --db-name NONEMPTYSTRING Name of the database to create. --db-username NONEMPTYSTRING Name of the database user to create. --db-password TEXT Password of the database user. - --su-db-name TEXT Name of the template database to connect to - as the database superuser. + --su-db-name TEXT Name of the template database to connect to as the + database superuser. --su-db-username TEXT User name of the database super user. - --su-db-password TEXT Password to connect as the database - superuser. + --su-db-password TEXT Password to connect as the database superuser. + --broker-protocol [amqp|amqps] Protocol to use for the message broker. [default: amqp] + --broker-username NONEMPTYSTRING + Username to use for authentication with the message + broker. [default: guest] + + --broker-password NONEMPTYSTRING + Password to use for authentication with the message + broker. [default: guest] + + --broker-host HOSTNAME Hostname for the message broker. [default: 127.0.0.1] + --broker-port INTEGER Port for the message broker. [default: 5672] + --broker-virtual-host HOSTNAME Name of the virtual host for the message broker. Forward + slashes need to be encoded [default: ] --repository DIRECTORY Absolute path to the file repository. 
- --config FILEORURL Load option values from configuration file - in yaml format (local path or URL). + --config FILEORURL Load option values from configuration file in yaml + format (local path or URL). --help Show this message and exit. @@ -510,18 +515,18 @@ Below is a list with all available subcommands. ``verdi rehash`` ---------------- -:: +.. code:: console Usage: [OPTIONS] [NODES]... Recompute the hash for nodes in the database. - The set of nodes that will be rehashed can be filtered by their identifier - and/or based on their class. + The set of nodes that will be rehashed can be filtered by their identifier and/or + based on their class. Options: - -e, --entry-point PLUGIN Only include nodes that are class or sub class of - the class identified by this entry point. + -e, --entry-point PLUGIN Only include nodes that are class or sub class of the class + identified by this entry point. -f, --force Do not ask for confirmation. --help Show this message and exit. @@ -532,7 +537,7 @@ Below is a list with all available subcommands. ``verdi restapi`` ----------------- -:: +.. code:: console Usage: [OPTIONS] @@ -546,8 +551,8 @@ Below is a list with all available subcommands. -H, --hostname HOSTNAME Hostname. -P, --port INTEGER Port number. -c, --config-dir PATH Path to the configuration directory - --wsgi-profile Whether to enable WSGI profiler middleware for - finding bottlenecks + --wsgi-profile Whether to enable WSGI profiler middleware for finding + bottlenecks --hookup / --no-hookup Hookup app to flask server --help Show this message and exit. @@ -558,7 +563,7 @@ Below is a list with all available subcommands. ``verdi run`` ------------- -:: +.. code:: console Usage: [OPTIONS] [--] SCRIPTNAME [VARARGS]... @@ -567,19 +572,19 @@ Below is a list with all available subcommands. Options: --auto-group Enables the autogrouping -l, --auto-group-label-prefix TEXT - Specify the prefix of the label of the auto - group (numbers might be automatically - appended to generate unique names per run). + Specify the prefix of the label of the auto group + (numbers might be automatically appended to generate + unique names per run). - -n, --group-name TEXT Specify the name of the auto group - [DEPRECATED, USE --auto-group-label-prefix - instead]. This also enables auto-grouping. + -n, --group-name TEXT Specify the name of the auto group [DEPRECATED, USE + --auto-group-label-prefix instead]. This also enables + auto-grouping. - -e, --exclude TEXT Exclude these classes from auto grouping - (use full entrypoint strings). + -e, --exclude TEXT Exclude these classes from auto grouping (use full + entrypoint strings). - -i, --include TEXT Include these classes from auto grouping - (use full entrypoint strings or "all"). + -i, --include TEXT Include these classes from auto grouping (use full + entrypoint strings or "all"). --help Show this message and exit. @@ -589,21 +594,20 @@ Below is a list with all available subcommands. ``verdi setup`` --------------- -:: +.. code:: console Usage: [OPTIONS] Setup a new profile. Options: - -n, --non-interactive Non-interactive mode: never prompt for - input. + -n, --non-interactive In non-interactive mode, the CLI never prompts but + simply uses default values for options that define one. --profile PROFILE The name of the new profile. [required] - --email EMAIL Email address associated with the data you - generate. The email address is exported - along with the data, when sharing it. - [required] + --email EMAIL Email address associated with the data you generate. 
The + email address is exported along with the data, when + sharing it. [required] --first-name NONEMPTYSTRING First name of the user. [required] --last-name NONEMPTYSTRING Last name of the user. [required] @@ -617,13 +621,29 @@ Below is a list with all available subcommands. --db-port INTEGER Database server port. --db-name NONEMPTYSTRING Name of the database to create. [required] - --db-username NONEMPTYSTRING Name of the database user to create. - [required] - + --db-username NONEMPTYSTRING Name of the database user to create. [required] --db-password TEXT Password of the database user. [required] + --broker-protocol [amqp|amqps] Protocol to use for the message broker. [default: amqp; + required] + + --broker-username NONEMPTYSTRING + Username to use for authentication with the message + broker. [default: guest; required] + + --broker-password NONEMPTYSTRING + Password to use for authentication with the message + broker. [default: guest; required] + + --broker-host HOSTNAME Hostname for the message broker. [default: 127.0.0.1; + required] + + --broker-port INTEGER Port for the message broker. [default: 5672; required] + --broker-virtual-host HOSTNAME Name of the virtual host for the message broker. Forward + slashes need to be encoded [default: ; required] + --repository DIRECTORY Absolute path to the file repository. - --config FILEORURL Load option values from configuration file - in yaml format (local path or URL). + --config FILEORURL Load option values from configuration file in yaml + format (local path or URL). --help Show this message and exit. @@ -633,22 +653,19 @@ Below is a list with all available subcommands. ``verdi shell`` --------------- -:: +.. code:: console Usage: [OPTIONS] Start a python shell with preloaded AiiDA environment. Options: - --plain Use a plain Python shell.) - --no-startup When using plain Python, ignore the - PYTHONSTARTUP environment variable and - ~/.pythonrc.py script. + --plain Use a plain Python shell. + --no-startup When using plain Python, ignore the PYTHONSTARTUP + environment variable and ~/.pythonrc.py script. -i, --interface [ipython|bpython] - Specify an interactive interpreter - interface. - + Specify an interactive interpreter interface. --help Show this message and exit. @@ -657,15 +674,16 @@ Below is a list with all available subcommands. ``verdi status`` ---------------- -:: +.. code:: console Usage: [OPTIONS] Print status of AiiDA services. Options: - --no-rmq Do not check RabbitMQ status - --help Show this message and exit. + -t, --print-traceback Print the full traceback in case an exception is raised. + --no-rmq Do not check RabbitMQ status + --help Show this message and exit. .. _reference:command-line:verdi-user: @@ -673,7 +691,7 @@ Below is a list with all available subcommands. ``verdi user`` -------------- -:: +.. code:: console Usage: [OPTIONS] COMMAND [ARGS]... diff --git a/docs/source/scheduler/index.rst b/docs/source/scheduler/index.rst deleted file mode 100644 index 1b68d7c468..0000000000 --- a/docs/source/scheduler/index.rst +++ /dev/null @@ -1,195 +0,0 @@ -.. _my-reference-to-scheduler: - -Supported schedulers -++++++++++++++++++++ - -The list below describes the supported *schedulers*, i.e. the batch job schedulers that manage the job queues and execution on any given computer. - -PBSPro ------- -The `PBSPro`_ scheduler is supported (and it has been tested with version 12.1). - -All the main features are supported with this scheduler. 
- -The :ref:`JobResource ` class to be used when setting the job resources is the :ref:`NodeNumberJobResource`. - -.. _PBSPro: http://www.pbsworks.com/Product.aspx?id=1 - -SLURM ------ - -The `SLURM`_ scheduler is supported (and it has been tested with version 2.5.4). - -All the main features are supported with this scheduler. - -The :ref:`JobResource ` class to be used when setting the job resources is the :ref:`NodeNumberJobResource`. - -.. _SLURM: https://slurm.schedmd.com/ - -SGE ---- - -The `SGE`_ scheduler (Sun Grid Engine, now called Oracle Grid Engine) -is supported (and it has been tested with version GE 6.2u3), -together with some of the main variants/forks. - -All the main features are supported with this scheduler. - -The :ref:`JobResource ` class to be used when setting the job resources is the :ref:`ParEnvJobResource`. - -.. _SGE: https://en.wikipedia.org/wiki/Oracle_Grid_Engine - -LSF ---- - -The IBM `LSF`_ scheduler is supported and has been tested with version 9.1.3 -on the CERN `lxplus` cluster. - -.. _LSF: https://www-01.ibm.com/support/knowledgecenter/SSETD4_9.1.3/lsf_welcome.html - -Torque ------- - -`Torque`_ (based on OpenPBS) is supported (and it has been tested with Torque v.2.4.16 from Ubuntu). - -All the main features are supported with this scheduler. - -The :ref:`JobResource ` class to be used when setting the job resources is the :ref:`NodeNumberJobResource`. - -.. _Torque: http://www.adaptivecomputing.com/products/open-source/torque/ - - - -Direct execution (bypassing schedulers) ---------------------------------------- - -The direct scheduler, to be used mainly for debugging, is an implementation of a scheduler plugin that does not require a real scheduler installed, but instead directly executes a command, puts it in the background, and checks for its process ID (PID) to discover if the execution is completed. - -.. warning:: - The direct execution mode is very fragile. Currently, it spawns a separate Bash shell to execute a job and track each shell by process ID (PID). This poses following problems: - - * PID numeration is reset during reboots; - * PID numeration is different from machine to machine, thus direct execution is *not* possible in multi-machine clusters, redirecting each SSH login to a different node in round-robin fashion; - * there is no real queueing, hence, all calculation started will be run in parallel. - -.. warning:: - Direct execution bypasses schedulers, so it should be used with care in order not to disturb the functioning of machines. - -All the main features are supported with this scheduler. - -The :ref:`JobResource ` class to be used when setting the job resources is the :ref:`NodeNumberJobResource` - - -.. _job_resources: - -Job resources -+++++++++++++ - -When asking a scheduler to allocate some nodes/machines for a given job, we have to specify some job resources, such as the number of required nodes or the numbers of MPI processes per node. - -Unfortunately, the way of specifying this information is different on different clusters. In AiiDA, this is implemented in different subclasses of the :py:class:`aiida.schedulers.datastructures.JobResource` class. The subclass that should be used is given by the scheduler, as described in the previous section. - -The interfaces of these subclasses are not all exactly the same. Instead, specifying the resources is similar to writing a scheduler script. All classes define at least one method, :meth:`get_tot_num_mpiprocs `, that returns the total number of MPI processes requested. 
- -In the following, the different :class:`JobResource ` subclasses are described: - -.. contents :: - :local: - -.. note:: - you can manually load a `specific` :class:`JobResource ` subclass by directly importing it, e..g. - :: - - from aiida.schedulers.datastructures import NodeNumberJobResource - - However, in general, you will pass the fields to set directly in the ``metadata.options`` input dictionary of the :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob`. - For instance:: - - from aiida.orm import load_code - - # This example assumes that the computer is configured to use a scheduler with job resources of type :py:class:`~aiida.schedulers.datastructures.NodeNumberJobResource` - inputs = { - 'code': load_code('somecode@localhost'), # The configured code to be used, which also defines the computer - 'metadata': { - 'options': { - 'resources', {'num_machines': 4, 'num_mpiprocs_per_machine': 16} - } - } - } - - -.. _NodeNumberJobResource: - -NodeNumberJobResource (PBS-like) --------------------------------- -This is the way of specifying the job resources in PBS and SLURM. The class is :py:class:`~aiida.schedulers.datastructures.NodeNumberJobResource`. - -Once an instance of the class is obtained, you have the following fields that you can set: - -* ``res.num_machines``: specify the number of machines (also called nodes) on which the code should run -* ``res.num_mpiprocs_per_machine``: number of MPI processes to use on each machine -* ``res.tot_num_mpiprocs``: the total number of MPI processes that this job is requesting -* ``res.num_cores_per_machine``: specify the number of cores to use on each machine -* ``res.num_cores_per_mpiproc``: specify the number of cores to run each MPI process - -Note that you need to specify only two among the first three fields above, but they have to be defined upon construction, for instance:: - - res = NodeNumberJobResource(num_machines=4, num_mpiprocs_per_machine=16) - -asks the scheduler to allocate 4 machines, with 16 MPI processes on each machine. This will automatically ask for a total of ``4*16=64`` total number of MPI processes. - -.. note:: - If you specify res.num_machines, res.num_mpiprocs_per_machine, and res.tot_num_mpiprocs fields (not recommended), make sure that they satisfy:: - - res.num_machines * res.num_mpiprocs_per_machine = res.tot_num_mpiprocs - - Moreover, if you specify ``res.tot_num_mpiprocs``, make sure that this is a multiple of ``res.num_machines`` and/or ``res.num_mpiprocs_per_machine``. - -.. note:: - When creating a new computer, you will be asked for a ``default_mpiprocs_per_machine``. If you specify it, then you can avoid to specify ``num_mpiprocs_per_machine`` when creating the resources for that computer, and the default number will be used. - - Of course, all the requirements between ``num_machines``, ``num_mpiprocs_per_machine`` and ``tot_num_mpiprocs`` still apply. - - Moreover, you can explicitly specify ``num_mpiprocs_per_machine`` if you want to use a value different from the default one. - - -The num_cores_per_machine and num_cores_per_mpiproc fields are optional. If you specify num_mpiprocs_per_machine and num_cores_per_machine fields, make sure that:: - - res.num_cores_per_mpiproc * res.num_mpiprocs_per_machine = res.num_cores_per_machine - -If you want to specifiy single value in num_mpiprocs_per_machine and num_cores_per_machine, please make sure that res.num_cores_per_machine is multiple of res.num_cores_per_mpiproc and/or res.num_mpiprocs_per_machine. - -.. 
note:: - In PBSPro, the num_mpiprocs_per_machine and num_cores_per_machine fields are used for mpiprocs and ppn respectively. - -.. note:: - In Torque, the num_mpiprocs_per_machine field is used for ppn unless the num_mpiprocs_per_machine is specified. - -.. _ParEnvJobResource: - -ParEnvJobResource (SGE-like) ----------------------------- -In SGE and similar schedulers, one has to specify a *parallel environment* and the *total number of CPUs* requested. The class is :py:class:`~aiida.schedulers.datastructures.ParEnvJobResource`. - -Once an instance of the class is obtained, you have the following fields that you can set: - -* ``res.parallel_env``: specify the parallel environment in which you want to run your job (a string) -* ``res.tot_num_mpiprocs``: the total number of MPI processes that this job is requesting - -Remember to always specify both fields. No checks are done on the consistency between the specified parallel environment and the total number of MPI processes requested (for instance, some parallel environments may have been configured by your cluster administrator to run on a single machine). It is your responsibility to make sure that the information is valid, otherwise the submission will fail. - -Some examples: - -* setting the fields directly in the class constructor:: - - res = ParEnvJobResource(parallel_env='mpi', tot_num_mpiprocs=64) - -* even better, you will pass the fields to set directly in the ``metadata.options`` input dictionary of the :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob`.:: - - inputs = { - 'metadata': { - 'options': { - resources', {'parallel_env': 'mpi', 'tot_num_mpiprocs': 64} - } - } - } diff --git a/docs/source/topics/calculations/concepts.rst b/docs/source/topics/calculations/concepts.rst index f9abe1fe6c..a6abbd83d4 100644 --- a/docs/source/topics/calculations/concepts.rst +++ b/docs/source/topics/calculations/concepts.rst @@ -110,7 +110,7 @@ When you want to run this 'code' through AiiDA, you need to tell *how* AiiDA sho The :py:class:`~aiida.calculations.arithmetic.add.ArithmeticAddCalculation` is a calculation job implementation that forms an interface to accomplish exactly that for the example bash script. A ``CalcJob`` implementation for a specific code, often referred to as a calculation plugin, essentially instructs the engine how it should be run. This includes how the necessary input files should be created based on the inputs that it receives, how the code executable should be called and what files should be retrieved when the calculation is complete. -Note the files should be 'retrieved' because calculation jobs can be run not just on the localhost, but on any :ref:`computer that is configured in AiiDA`, including remote machines accessible over for example SSH. +Note the files should be 'retrieved' because calculation jobs can be run not just on the localhost, but on any :ref:`computer that is configured in AiiDA`, including remote machines accessible over for example SSH. Since a ``CalcJob`` is a process just like the :ref:`calculation functions` described before, they can be run in an identical way. 
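As a brief illustration of the statement above that a ``CalcJob`` is run like any other process, here is a hedged sketch using the ``arithmetic.add`` calculation plugin that ships with ``aiida-core``; the code label ``add@localhost`` is a placeholder for whatever code you have configured yourself.

.. code-block:: python

    from aiida.engine import run
    from aiida.orm import Int, load_code
    from aiida.plugins import CalculationFactory

    ArithmeticAddCalculation = CalculationFactory('arithmetic.add')

    # 'add@localhost' is a placeholder label; load a code you have configured.
    code = load_code('add@localhost')

    # Running a CalcJob uses the same `run` call as any other process.
    results = run(ArithmeticAddCalculation, code=code, x=Int(1), y=Int(2))
    print(results['sum'])  # the parsed sum, attached by the parser as an Int node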
diff --git a/docs/source/topics/calculations/include/snippets/calcjobs/arithmetic_add_parser.py b/docs/source/topics/calculations/include/snippets/calcjobs/arithmetic_add_parser.py index 3e3773060e..fb0a7dee32 100644 --- a/docs/source/topics/calculations/include/snippets/calcjobs/arithmetic_add_parser.py +++ b/docs/source/topics/calculations/include/snippets/calcjobs/arithmetic_add_parser.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -from aiida.common import exceptions from aiida.orm import Int from aiida.parsers.parser import Parser @@ -8,10 +7,7 @@ class ArithmeticAddParser(Parser): def parse(self, **kwargs): """Parse the contents of the output files retrieved in the `FolderData`.""" - try: - output_folder = self.retrieved - except exceptions.NotExistent: - return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER + output_folder = self.retrieved try: with output_folder.open(self.node.get_option('output_filename'), 'r') as handle: diff --git a/docs/source/topics/calculations/index.rst b/docs/source/topics/calculations/index.rst index 615af4c223..ce08ca20b3 100644 --- a/docs/source/topics/calculations/index.rst +++ b/docs/source/topics/calculations/index.rst @@ -5,7 +5,7 @@ Calculations ************ This topic section provides detailed information on the concept of calculations in AiiDA and an extensive guide on how to work with them. -An introductory guide to working with calculations can be found in :ref:`"How to run external codes"`. +An introductory guide to working with calculations can be found in :ref:`"How to run external codes"`. .. toctree:: :maxdepth: 2 diff --git a/docs/source/topics/calculations/usage.rst b/docs/source/topics/calculations/usage.rst index 57e03f2f1c..5226424c3d 100644 --- a/docs/source/topics/calculations/usage.rst +++ b/docs/source/topics/calculations/usage.rst @@ -466,45 +466,32 @@ The advantage of adding the raw output data in different form as output nodes, i This allows one to query for calculations that produced specific outputs with a certain value, which becomes a very powerful approach for post-processing and analyses of big databases. The ``retrieved`` attribute of the parser will return the ``FolderData`` node that should have been attached by the engine containing all the retrieved files, as specified using the :ref:`retrieve list` in the :ref:`preparation step of the calculation job`. -If this node has not been attached for whatever reason, this call will throw an :py:class:`~aiida.common.exceptions.NotExistent` exception. -This is why we wrap the ``self.retrieved`` call in a try-catch block: +This retrieved folder can be used to open and read the contents of the files it contains. +In this example, there should be a single output file that was written by redirecting the standard output of the bash script that added the two integers. +The parser opens this file, reads its content and tries to parse the sum from it: .. literalinclude:: include/snippets/calcjobs/arithmetic_add_parser.py :language: python - :lines: 10-13 + :lines: 12-16 :linenos: - :lineno-start: 10 + :lineno-start: 12 -If the exception is thrown, it means the retrieved files are not available and something must have has gone terribly awry with the calculation. -In this case, there is nothing to do for the parser and so we return an exit code. -Specific exit codes can be referenced by their label, such as ``ERROR_NO_RETRIEVED_FOLDER`` in this example, through the ``self.exit_codes`` property. 
+Note that this parsing action is wrapped in a try-except block to catch the exceptions that would be thrown if the output file could not be read. +If the exception would not be caught, the engine will catch the exception instead and set the process state of the corresponding calculation to ``Excepted``. +Note that this will happen for any uncaught exception that is thrown during parsing. +Instead, we catch these exceptions and return an exit code that is retrieved by referencing it by its label, such as ``ERROR_READING_OUTPUT_FILE`` in this example, through the ``self.exit_codes`` property. This call will retrieve the corresponding exit code defined on the ``CalcJob`` that we are currently parsing. Returning this exit code from the parser will stop the parsing immediately and will instruct the engine to set its exit status and exit message on the node of this calculation job. -This should scenario should however never occur, but it is just here as a safety. -If the exception would not be caught, the engine will catch the exception instead and set the process state of the corresponding calculation to ``Excepted``. -Note that this will happen for any exception that occurs during parsing. - -Assuming that everything went according to plan during the retrieval, we now have access to those retrieved files and can start to parse them. -In this example, there should be a single output file that was written by redirecting the standard output of the bash script that added the two integers. -The parser opens this file, reads its content and tries to parse the sum from it: - -.. literalinclude:: include/snippets/calcjobs/arithmetic_add_parser.py - :language: python - :lines: 15-19 - :linenos: - :lineno-start: 15 -Note that again we wrap this parsing action in a try-except block. -If the file cannot be found or cannot be read, we return the appropriate exit code. The ``parse_stdout`` method is just a small utility function to separate the actual parsing of the data from the main parser code. In this case, the parsing is so simple that we might have as well kept it in the main method, but this is just to illustrate that you are completely free to organize the code within the ``parse`` method for clarity. If we manage to parse the sum, produced by the calculation, we wrap it in the appropriate :py:class:`~aiida.orm.nodes.data.int.Int` data node class, and register it as an output through the ``out`` method: .. literalinclude:: include/snippets/calcjobs/arithmetic_add_parser.py :language: python - :lines: 24-24 + :lines: 21-21 :linenos: - :lineno-start: 24 + :lineno-start: 21 Note that if we encountered no problems, we do not have to return anything. The engine will interpret this as the calculation having finished successfully. @@ -518,3 +505,50 @@ However, we can give you some guidelines: If you were to store all this data in the database, it would become unnecessarily bloated, because the chances you would have to query for this data are unlikely. Instead these array type data nodes store the bulk of their content in the repository. This way you still keep the data and therewith the provenance of your calculations, while keeping your database lean and fast! + + +.. _topics:calculations:usage:calcjobs:scheduler-errors: + +Scheduler errors +---------------- + +Besides the output parsers, the scheduler plugins can also provide parsing of the output generated by the job scheduler, by implementing the :meth:`~aiida.schedulers.scheduler.Scheduler.parse_output` method. 
+If the scheduler plugin has implemented this method, the output generated by the scheduler, written to the stdout and stderr file descriptors as well as the output of the detailed job info command, is parsed. +If the parser detects a known problem, such as an out-of-memory (OOM) error, the corresponding exit code will already be set on the calculation job node. +The output parser, if defined in the inputs, can inspect the exit status on the node and decide to keep it or override it with a different, potentially more useful, exit code. + +.. code:: python + + class SomeParser(Parser): + + def parse(self, **kwargs): + """Parse the contents of the output files retrieved in the `FolderData`.""" + if self.node.exit_status is not None: + # If an exit status is already set on the node, that means the + # scheduler plugin detected a problem. + return + +Note that in the example given above, the parser returns immediately if it detects that the scheduler detected a problem. +Since it returns `None`, the exit code of the scheduler will be kept and will be the final exit code of the calculation job. +However, the parser does not have to immediately return. +It can still try to parse some of the retrieved output, if there is any. +If it finds a more specific problem than the generic scheduler error, it can always return an exit code of itself to override it. +The parser can even return ``ExitCode(0)`` to have the calculation marked as successfully finished, despite the scheduler having determined that there was a problem. +The following table summarizes the possible scenarios of the scheduler parser and output parser returning an exit code and what the final resulting exit code will be that is set on the node: + ++------------------------------------------------------------------------------------+-----------------------+-----------------------+-----------------------+ +| **Scenario** | **Scheduler result** | **Retrieved result** | **Final result** | ++====================================================================================+=======================+=======================+=======================+ +| Neither parser found any problem. | ``None`` | ``None`` | ``ExitCode(0)`` | ++------------------------------------------------------------------------------------+-----------------------+-----------------------+-----------------------+ +| Scheduler parser found an issue, | ``ExitCode(100)`` | ``None`` | ``ExitCode(100)`` | +| but output parser does not override. | | | | ++------------------------------------------------------------------------------------+-----------------------+-----------------------+-----------------------+ +| Only output parser found a problem. | ``None`` | ``ExitCode(400)`` | ``ExitCode(400)`` | ++------------------------------------------------------------------------------------+-----------------------+-----------------------+-----------------------+ +| Scheduler parser found an issue, but the output parser overrides with a more | ``ExitCode(100)`` | ``ExitCode(400)`` | ``ExitCode(400)`` | +| specific error code. | | | | ++------------------------------------------------------------------------------------+-----------------------+-----------------------+-----------------------+ +| Scheduler found issue but output parser overrides saying that despite that the | ``ExitCode(100)`` | ``ExitCode(0)`` | ``ExitCode(0)`` | +| calculation should be considered finished successfully. 
| | | | ++------------------------------------------------------------------------------------+-----------------------+-----------------------+-----------------------+ diff --git a/docs/source/topics/index.rst b/docs/source/topics/index.rst index b5f68c1985..a71b1c8eea 100644 --- a/docs/source/topics/index.rst +++ b/docs/source/topics/index.rst @@ -12,6 +12,7 @@ Topics provenance/index database plugins + schedulers .. todo:: diff --git a/docs/source/topics/plugins.rst b/docs/source/topics/plugins.rst index 3d4e5a5599..8ff4ff8548 100644 --- a/docs/source/topics/plugins.rst +++ b/docs/source/topics/plugins.rst @@ -34,6 +34,33 @@ If you find yourself in a situation where you feel like you need to do any of th .. _core: https://github.com/aiidateam/aiida-core .. _registry: https://github.com/aiidateam/aiida-registry +.. _topics:plugins:guidelines: + +Guidelines for plugin design +============================ + +CalcJob & Parser plugins +------------------------ + +The following guidelines are useful to keep in mind when wrapping external codes: + + * | **Start simple.** + | Make use of existing classes like :py:class:`~aiida.orm.nodes.data.dict.Dict`, :py:class:`~aiida.orm.nodes.data.singlefile.SinglefileData`, ... + Write only what is necessary to pass information from and to AiiDA. + * | **Don't break data provenance.** + | Store *at least* what is needed for full reproducibility. + * | **Expose the full functionality.** + | Standardization is good but don't artificially limit the power of a code you are wrapping - or your users will get frustrated. + If the code can do it, there should be *some* way to do it with your plugin. + * | **Don't rely on AiiDA internals.** + Functionality at deeper nesting levels is not considered part of the public API and may change between minor AiiDA releases, breaking your plugin. + * | **Parse what you want to query for.** + | Make a list of which information to: + + #. parse into the database for querying (:py:class:`~aiida.orm.nodes.data.dict.Dict`, ...) + #. store in the file repository for safe-keeping (:py:class:`~aiida.orm.nodes.data.singlefile.SinglefileData`, ...) + #. leave on the computer where the calculation ran (:py:class:`~aiida.orm.nodes.data.remote.RemoteData`, ...) + .. _topics:plugins:entrypoints: diff --git a/docs/source/topics/provenance/caching.rst b/docs/source/topics/provenance/caching.rst new file mode 100644 index 0000000000..00d977eded --- /dev/null +++ b/docs/source/topics/provenance/caching.rst @@ -0,0 +1,75 @@ +.. _topics:provenance:caching: + +=================== +Caching and hashing +=================== + +.. _topics:provenance:caching:hashing: + +How are nodes hashed +-------------------- + +*Hashing* is turned on by default, i.e., all nodes in AiiDA are hashed. +This means that even when you enable caching once you have already completed a number of calculations, those calculations can still be used retro-actively by the caching mechanism since their hashes have been computed. + +The hash of a ``Data`` node is computed from: + +* all attributes of the node, except the ``_updatable_attributes`` and ``_hash_ignored_attributes`` +* the ``__version__`` of the package which defined the node class +* the content of the repository folder of the node +* the UUID of the computer, if the node is associated with one + +The hash of a :class:`~aiida.orm.ProcessNode` includes, on top of this, the hashes of all of its input ``Data`` nodes. 
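A small sketch may help make the ingredients listed above tangible; assuming a loaded profile, two stored ``Int`` nodes with identical content end up with identical hashes:

.. code-block:: python

    from aiida.orm import Int

    a = Int(5).store()
    b = Int(5).store()

    # Same attributes, same class version and same (empty) repository content:
    # the hashes coincide, so either node could act as a cache source for the other.
    assert a.get_hash() == b.get_hash()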
+ +Once a node is stored in the database, its hash is stored in the ``_aiida_hash`` extra, and this extra is used to find matching nodes. +If a node of the same class with the same hash already exists in the database, this is considered a cache match. +Use the :meth:`~aiida.orm.nodes.Node.get_hash` method to check the hash of any node. +In order to figure out why a calculation is *not* being reused, the :meth:`~aiida.orm.nodes.Node._get_objects_to_hash` method may be useful: + +.. code-block:: ipython + + In [5]: node = load_node(1234) + + In [6]: node.get_hash() + Out[6]: '62eca804967c9428bdbc11c692b7b27a59bde258d9971668e19ccf13a5685eb8' + + In [7]: node._get_objects_to_hash() + Out[7]: + [ + '1.0.0', + { + 'resources': {'num_machines': 2, 'default_mpiprocs_per_machine': 28}, + 'parser_name': 'cp2k', + 'linkname_retrieved': 'retrieved' + }, + , + '6850dc88-0949-482e-bba6-8b11205aec11', + { + 'code': 'f6bd65b9ca3a5f0cf7d299d9cfc3f403d32e361aa9bb8aaa5822472790eae432', + 'parameters': '2c20fdc49672c3505cebabacfb9b1258e71e7baae5940a80d25837bee0032b59', + 'structure': 'c0f1c1d1bbcfc7746dcf7d0d675904c62a5b1759d37db77b564948fa5a788769', + 'parent_calc_folder': 'e375178ceeffcde086546d3ddbce513e0527b5fa99993091b2837201ad96569c' + } + ] + + +.. _topics:provenance:caching:limitations: + +Limitations +----------- + +#. Workflow nodes are not cached. + In the current design this follows from the requirement that the provenance graph be independent of whether caching is enabled or not: + + * **Calculation nodes:** Calculation nodes can have data inputs and create new data nodes as outputs. + In order to make it look as if a cloned calculation produced its own outputs, the output nodes are copied and linked as well. + * **Workflow nodes:** Workflows differ from calculations in that they can *return* an input node or an output node created by a calculation. + Since caching does not care about the *identity* of input nodes but only their *content*, it is not straightforward to figure out which node to return in a cached workflow. + + This limitation has typically no significant impact since the runtime of AiiDA work chains is commonly dominated by expensive calculations. + +#. The caching mechanism for calculations *should* trigger only when the inputs and the calculation to be performed are exactly the same. + While AiiDA's hashes include the version of the Python package containing the calculation/data classes, it cannot detect cases where the underlying Python code was changed without increasing the version number. + Another scenario that can lead to an erroneous cache hit is if the parser and calculation are not implemented as part of the same Python package, because the calculation nodes store only the name, but not the version of the used parser. + +#. Finally, while caching saves unnecessary computations, it does not save disk space: the output nodes of the cached calculation are full copies of the original outputs. diff --git a/docs/source/topics/provenance/index.rst b/docs/source/topics/provenance/index.rst index c07892d78f..9b0965f713 100644 --- a/docs/source/topics/provenance/index.rst +++ b/docs/source/topics/provenance/index.rst @@ -13,3 +13,4 @@ The :ref:`consistency` section details the rules concepts implementation consistency + caching diff --git a/docs/source/topics/schedulers.rst b/docs/source/topics/schedulers.rst new file mode 100644 index 0000000000..d99b338f0c --- /dev/null +++ b/docs/source/topics/schedulers.rst @@ -0,0 +1,220 @@ +.. 
_topics:schedulers: + +==================== +Batch Job Schedulers +==================== + +Batch job schedulers manage the job queues and execution on a compute resource. +AiiDA ships with plugins for a range of schedulers, and this section describes the interface of these plugins. + +See :ref:`this how-to ` for adding support for custom schedulers. + +PBSPro +------ + +The `PBSPro`_ scheduler is supported (tested: version 12.1). + +All the main features are supported with this scheduler. + +Use the :ref:`topics:schedulers:job_resources:node` when setting job resources. + +.. _PBSPro: http://www.pbsworks.com/Product.aspx?id=1 + +SLURM +----- + +The `SLURM`_ scheduler is supported (tested: version 2.5.4). + +All the main features are supported with this scheduler. + +Use the :ref:`topics:schedulers:job_resources:node` when setting job resources. + +.. _SLURM: https://slurm.schedmd.com/ + +SGE +--- + +The `SGE`_ scheduler (Sun Grid Engine, now called Oracle Grid Engine) and some of its main variants/forks are supported (tested: version GE 6.2u3). + +All the main features are supported with this scheduler. + +Use the :ref:`topics:schedulers:job_resources:par` when setting job resources. + +.. _SGE: https://en.wikipedia.org/wiki/Oracle_Grid_Engine + +LSF +--- + +The IBM `LSF`_ scheduler is supported (tested: version 9.1.3 on the CERN `lxplus` cluster). + +.. _LSF: https://www-01.ibm.com/support/knowledgecenter/SSETD4_9.1.3/lsf_welcome.html + +Torque +------ + +`Torque`_ (based on OpenPBS) is supported (tested: version 2.4.16 from Ubuntu). + +All the main features are supported with this scheduler. + +Use the :ref:`topics:schedulers:job_resources:node` when setting job resources. + +.. _Torque: http://www.adaptivecomputing.com/products/open-source/torque/ + + + +Direct execution (bypassing schedulers) +--------------------------------------- + +The ``direct`` scheduler plugin simply executes the command in a new bash shell, puts it in the background and checks for its process ID (PID) to determine when the execution is completed. + +Its main purpose is debugging on the local machine. +Use a proper batch scheduler for any production calculations. + +.. warning:: + + Compared to a proper batch scheduler, direct execution mode is fragile. + In particular: + + * There is no queueing, i.e. all calculations run in parallel. + * PID numbering is reset at every reboot. + +.. warning:: + + Do *not* use the direct scheduler for running on a supercomputer. + The job will end up running on the login node (which is typically forbidden), and if your centre has multiple login nodes, AiiDA may get confused if subsequent SSH connections end up at a different login node (causing AiiDA to erroneously conclude that the job has completed). + +All the main features are supported with this scheduler. + +Use the :ref:`topics:schedulers:job_resources:node` when setting job resources. + + +.. _topics:schedulers:job_resources: + +Job resources +------------- + +Unsurprisingly, different schedulers have different ways of specifying the resources for a job (such as the number of required nodes or the number of MPI processes per node). + +In AiiDA, these differences are accounted for by subclasses of the |JobResource| class. +The sections above list which subclass to use with a given scheduler.
+ +All subclasses define at least the :py:meth:`~aiida.schedulers.datastructures.JobResource.get_tot_num_mpiprocs` method that returns the total number of MPI processes requested but otherwise have slightly different interfaces described in the following. + +.. note:: + + You can manually load a `specific` |JobResource| subclass by directly importing it, e.g. + + .. code-block:: python + + from aiida.schedulers.datastructures import NodeNumberJobResource + + In practice, however, the appropriate class will be inferred from the scheduler configured for the relevant AiiDA computer, and you can simply set the relevant fields in the ``metadata.options`` input dictionary of the |CalcJob|. + + For a scheduler with job resources of type |NodeNumberJobResource|, this could be: + + .. code-block:: python + + from aiida.orm import load_code + + inputs = { + 'code': load_code('somecode@localhost'), # The configured code to be used, which also defines the computer + 'metadata': { + 'options': { + 'resources': {'num_machines': 4, 'num_mpiprocs_per_machine': 16} + } + } + } + + +.. _topics:schedulers:job_resources:node: + +NodeNumberJobResource (PBS-like) +................................ + +The |NodeNumberJobResource| class is used for specifying job resources in PBS and SLURM. + +The class has the following attributes: + +* ``res.num_machines``: the number of machines (also called nodes) on which the code should run +* ``res.num_mpiprocs_per_machine``: number of MPI processes to use on each machine +* ``res.tot_num_mpiprocs``: the total number of MPI processes that this job requests +* ``res.num_cores_per_machine``: the number of cores to use on each machine +* ``res.num_cores_per_mpiproc``: the number of cores to run each MPI process on + +You need to specify only two among the first three fields above, but they have to be defined upon construction. +We suggest using the first two, for instance: + +.. code-block:: python + + res = NodeNumberJobResource(num_machines=4, num_mpiprocs_per_machine=16) + +asks the scheduler to allocate 4 machines, with 16 MPI processes on each machine. +This automatically requests a total of ``4*16=64`` MPI processes. + +.. note:: + + When creating a new computer, you will be asked for a ``default_mpiprocs_per_machine``. + If specified, it will automatically be used as the default value for ``num_mpiprocs_per_machine`` whenever creating the resources for that computer. + +.. note:: + + If you prefer using ``res.tot_num_mpiprocs`` instead, make sure it is a multiple of ``res.num_machines`` and/or ``res.num_mpiprocs_per_machine``. + + The first three fields are related by the equation: + + .. code-block:: python + + res.num_machines * res.num_mpiprocs_per_machine = res.tot_num_mpiprocs + + +The ``num_cores_per_machine`` and ``num_cores_per_mpiproc`` fields are optional and must satisfy the equation: + +.. code-block:: python + + res.num_cores_per_mpiproc * res.num_mpiprocs_per_machine = res.num_cores_per_machine + + +.. note:: + + In PBSPro, the ``num_mpiprocs_per_machine`` and ``num_cores_per_machine`` fields are used for mpiprocs and ppn respectively. + + In Torque, the ``num_mpiprocs_per_machine`` field is used for ppn unless ``num_cores_per_machine`` is specified. + +.. _topics:schedulers:job_resources:par: + +ParEnvJobResource (SGE-like) +............................
+ +The :py:class:`~aiida.schedulers.datastructures.ParEnvJobResource` class is used for specifying the resources of SGE and similar schedulers, which require specifying a *parallel environment* and the *total number of CPUs* requested. + +The class has the following attributes: + +* ``res.parallel_env``: the parallel environment in which you want to run your job (a string) +* ``res.tot_num_mpiprocs``: the total number of MPI processes that this job requests + +Both attributes are required. +No checks are done on the consistency between the specified parallel environment and the total number of MPI processes requested (for instance, some parallel environments may have been configured by your cluster administrator to run on a single machine). +It is your responsibility to make sure that the information is valid, otherwise the submission will fail. + +The fields can be set directly in the class constructor: + +.. code-block:: python + + res = ParEnvJobResource(parallel_env='mpi', tot_num_mpiprocs=64) + +or through the ``metadata.options`` input dictionary of the |CalcJob|: + +.. code-block:: python + + inputs = { + 'metadata': { + 'options': { + 'resources': {'parallel_env': 'mpi', 'tot_num_mpiprocs': 64} + } + } + } + + +.. |NodeNumberJobResource| replace:: :py:class:`~aiida.schedulers.datastructures.NodeNumberJobResource` +.. |JobResource| replace:: :py:class:`~aiida.schedulers.datastructures.JobResource` +.. |CalcJob| replace:: :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` diff --git a/docs/source/topics/workflows/usage.rst b/docs/source/topics/workflows/usage.rst index 8b52065e63..c83f863117 100644 --- a/docs/source/topics/workflows/usage.rst +++ b/docs/source/topics/workflows/usage.rst @@ -447,7 +447,7 @@ If a non-zero integer value is detected, the engine will interpret this as an ex In addition, the integer return value will be set as the ``exit_status`` of the work chain, which combined with the ``Finished`` process state will denote that the work chain is considered to be ``Failed``, as explained in the section on the :ref:`process state `. This is useful because it allows a workflow designer to easily exit from a work chain and use the return value to communicate programmatically the reason for the work chain stopping. -We assume that you have read the `section on how to define exit code `_ through the process specification of the work chain. +We assume that you have read the :ref:`section on how to define exit codes ` through the process specification of the work chain. Consider the following example work chain that defines such an exit code: .. code:: python @@ -485,7 +485,7 @@ Returning this exit code, which will be an instance of the :py:class:`~aiida.eng The ``message`` attribute of an ``ExitCode`` can also be a string that contains placeholders. This is useful when the exit code's message is generic enough to apply to a host of situations, but one would just like to parameterize the exit message. -To concretize the template message of an exit code, simply call the :meth:`~aiida.engine.processes.exit_code.ExitCode.format` method and pass the parameters as keyword arguments:: +To concretize the template message of an exit code, simply call the :meth:`~aiida.engine.processes.exit_code.ExitCode.format` method and pass the parameters as keyword arguments: .. 
code:: python @@ -493,7 +493,7 @@ To concretize the template message of an exit code, simply call the :meth:`~aiid exit_code_concrete = exit_code_template.format(parameter='some_specific_key') This concept can also be applied within the scope of a process. -In the process spec, we can declare a generic exit code whose exact message should depend on one or multiple parameters:: +In the process spec, we can declare a generic exit code whose exact message should depend on one or multiple parameters: .. code:: python diff --git a/docs/source/working_with_aiida/caching.rst b/docs/source/working_with_aiida/caching.rst deleted file mode 100644 index 4e2f864352..0000000000 --- a/docs/source/working_with_aiida/caching.rst +++ /dev/null @@ -1,181 +0,0 @@ -.. _caching: - -******* -Caching -******* - -Enabling caching ----------------- - -There are numerous reasons why you may need to re-run calculations you’ve already done before. -Since AiiDA stores the full provenance of each calculation, it can detect whether a calculation has been run before and reuse its outputs without wasting computational resources. -This is what we mean by **caching** in AiiDA. - -Caching is **not enabled by default**. -In order to enable caching for your AiiDA profile (here called ``aiida2``), place the following ``cache_config.yml`` file in your ``.aiida`` configuration folder: - -.. code:: yaml - - aiida2: - default: True - -From this point onwards, when you launch a new calculation, AiiDA will compare its hash (depending both on the type of calculation and its inputs, see :ref:`caching_matches`) against other calculations already present in your database. -If another calculation with the same hash is found, AiiDA will reuse its results without repeating the actual calculation. - -In order to ensure that the provenance graph with and without caching is the same, AiiDA creates both a new calculation node and a copy of the output data nodes as shown in :numref:`fig_caching`. - -.. _fig_caching: -.. figure:: include/images/caching.png - :align: center - :height: 350px - - When reusing the results of a calculation **C** for a new calculation **C'**, AiiDA simply makes a copy of the result nodes and links them up as usual. - -.. note:: - - AiiDA uses the *hashes* of the input nodes **D1** and **D2** when searching the calculation cache. - I.e. if the input of **C'** were new nodes **D1'** and **D2'** with the same content (hash) as **D1**, **D2**, the cache would trigger as well. - - -.. note:: Caching is **not** implemented at the WorkChain/workfunction level (see :ref:`caching_limitations` for details). - - -.. _caching_matches: - -How are nodes hashed? ---------------------- - -*Hashing* is turned on by default, i.e. all nodes in AiiDA are hashed (see also :ref:`devel_controlling_hashing`). -The hash of a ``Data`` node is computed from: - -* all attributes of the node, except the ``_updatable_attributes`` and ``_hash_ignored_attributes`` -* the ``__version__`` of the package which defined the node class -* the content of the repository folder of the node -* the UUID of the computer, if the node is associated with one - -The hash of a :class:`~aiida.orm.ProcessNode` includes, on top of this, the hashes of all of its input ``Data`` nodes. - -Once a node is stored in the database, its hash is stored in the ``_aiida_hash`` extra, and this extra is used to find matching nodes. -If a node of the same class with the same hash already exists in the database, this is considered a cache match. 
- -Use the :meth:`~aiida.orm.nodes.Node.get_hash` method to check the hash of any node. - -In order to figure out why a calculation is *not* being reused, the :meth:`~aiida.orm.nodes.Node._get_objects_to_hash` method may be useful: - -.. ipython:: - :verbatim: - - In [5]: calc=load_node(1234) - - In [6]: calc.get_hash() - Out[6]: '62eca804967c9428bdbc11c692b7b27a59bde258d9971668e19ccf13a5685eb8' - - In [7]: calc._get_objects_to_hash() - Out[7]: - ['1.0.0b4', - {'resources': {'num_machines': 2, 'default_mpiprocs_per_machine': 28}, - 'parser_name': 'cp2k', - 'linkname_retrieved': 'retrieved'}, - , - '6850dc88-0949-482e-bba6-8b11205aec11', - {'code': 'f6bd65b9ca3a5f0cf7d299d9cfc3f403d32e361aa9bb8aaa5822472790eae432', - 'parameters': '2c20fdc49672c3505cebabacfb9b1258e71e7baae5940a80d25837bee0032b59', - 'structure': 'c0f1c1d1bbcfc7746dcf7d0d675904c62a5b1759d37db77b564948fa5a788769', - 'parent_calc_folder': 'e375178ceeffcde086546d3ddbce513e0527b5fa99993091b2837201ad96569c'}] - - -Configuration -------------- - -Class level -........... - -Besides an on/off switch per profile, the ``.aiida/cache_config.yml`` provides control over caching at the level of specific calculations using their corresponding entry point strings (see the output of ``verdi plugin list aiida.calculations``): - -.. code:: yaml - - profile-name: - default: False - enabled: - - aiida.calculations:quantumespresso.pw - disabled: - - aiida.calculations:templatereplacer - -In this example, caching is disabled by default, but explicitly enabled for calculaions of the ``PwCalculation`` class, identified by the ``aiida.calculations:quantumespresso.pw`` entry point string. -It also shows how to disable caching for particular calculations (which has no effect here due to the profile-wide default). - -For calculations which do not have an entry point, you need to specify the fully qualified Python name instead. For example, the ``seekpath_structure_analysis`` calcfunction defined in ``aiida_quantumespresso.workflows.functions.seekpath_structure_analysis`` is labelled as ``aiida_quantumespresso.workflows.functions.seekpath_structure_analysis.seekpath_structure_analysis``. From an existing :class:`~aiida.orm.nodes.process.calculation.CalculationNode`, you can get the identifier string through the ``process_type`` attribute. - -The caching configuration also accepts ``*`` wildcards. For example, the following configuration enables caching for all calculation entry points defined by ``aiida-quantumespresso``, and the ``seekpath_structure_analysis`` calcfunction. Note that the ``*.seekpath_structure_analysis`` entry needs to be quoted, because it starts with ``*`` which is a special character in YAML. - -.. code:: yaml - - profile-name: - default: False - enabled: - - aiida.calculations:quantumespresso.* - - '*.seekpath_structure_analysis' - -You can even override a wildcard with a more specific entry. The following configuration enables caching for all ``aiida.calculation`` entry points, except those of ``aiida-quantumespresso``: - -.. code:: yaml - - profile-name: - default: False - enabled: - - aiida.calculations:* - disabled: - - aiida.calculations:quantumespresso.* - - -Instance level -.............. - -Even when caching is turned off for a given calculation type, you can enable it on a case-by-case basis by using the :class:`~aiida.manage.caching.enable_caching` context manager for testing purposes: - -.. 
code:: python - - from aiida.engine import run - from aiida.manage.caching import enable_caching - with enable_caching(identifier='aiida.calculations:templatereplacer'): - run(...) - -.. warning:: - - This affects only the current python interpreter and won't change the behavior of the daemon workers. - This means that this technique is only useful when using :py:class:`~aiida.engine.run`, and **not** with :py:class:`~aiida.engine.submit`. - -If you suspect a node is being reused in error (e.g. during development), you can also manually *prevent* a specific node from being reused: - -1. Load one of the nodes you suspect to be a clone. - Check that :meth:`~aiida.orm.nodes.Node.get_cache_source` returns a UUID. - If it returns `None`, the node was not cloned. -2. Clear the hashes of all nodes that are considered identical to this node: - - .. code:: python - - for n in node.get_all_same_nodes(): - n.clear_hash() -3. Run your calculation again. The node in question should no longer be reused. - - -.. _caching_limitations: - -Limitations ------------ - -#. Workflow nodes are not cached. In the current design this follows from the requirement that the provenance graph be independent of whether caching is enabled or not: - - * **Calculation nodes:** Calculation nodes can have data inputs and create new data nodes as outputs. - In order to make it look as if a cloned calculation produced its own outputs, the output nodes are copied and linked as well. - * **Workflow nodes:** Workflows differ from calculations in that they can *return* an input node or an output node created by a calculation. - Since caching does not care about the *identity* of input nodes but only their *content*, it is not straightforward to figure out which node to return in a cached workflow. - - For the moment, this limitation is acceptable since the runtime of AiiDA WorkChains is usually dominated by expensive calculations, which are covered by the current caching mechanism. - -#. The caching mechanism for calculations *should* trigger only when the inputs and the calculation to be performed are exactly the same. - While AiiDA's hashes include the version of the python package containing the calculation/data classes, it cannot detect cases where the underlying python code was changed without increasing the version number. - Another edge case would be if the parser lives in a different python package than the calculation (calculation nodes store the name of the parser used but not the version of the package containing the parser). - -Finally, while caching saves unnecessary computations, it does not save disk space: The output nodes of the cached calculation are full copies of the original outputs. -The plan is to add data deduplication as a global feature at the repository and database level (independent of caching). 
diff --git a/docs/source/working_with_aiida/cookbook.rst b/docs/source/working_with_aiida/cookbook.rst index 844cfe187b..1e8ca01b5d 100644 --- a/docs/source/working_with_aiida/cookbook.rst +++ b/docs/source/working_with_aiida/cookbook.rst @@ -28,7 +28,7 @@ you can use a modification of the following script:: """ from aiida import orm - computer = Computer.get(name='deneb') + computer = Computer.get(label='deneb') transport = computer.get_transport() scheduler = computer.get_scheduler() scheduler.set_transport(transport) @@ -114,7 +114,7 @@ Here is, as an example, an useful utility function:: manager = get_manager() profile = manager.get_profile() return AuthInfo.objects.get( - dbcomputer_id=Computer.get(name=computername).id, + dbcomputer_id=Computer.get(label=computername).id, aiidauser_id=User.get(email=profile.default_user).id ) diff --git a/docs/source/working_with_aiida/index.rst b/docs/source/working_with_aiida/index.rst index 3febd0037c..d00e6f2723 100644 --- a/docs/source/working_with_aiida/index.rst +++ b/docs/source/working_with_aiida/index.rst @@ -94,15 +94,6 @@ Import and Export ../import_export/main ../import_export/external_dbs -======= -Caching -======= - -.. toctree:: - :maxdepth: 4 - - caching.rst - ========== Schedulers diff --git a/docs/source/working_with_aiida/tips/ssh_proxycommand.rst b/docs/source/working_with_aiida/tips/ssh_proxycommand.rst deleted file mode 100644 index 39b2b16d9a..0000000000 --- a/docs/source/working_with_aiida/tips/ssh_proxycommand.rst +++ /dev/null @@ -1,130 +0,0 @@ -.. _ssh_proxycommand: - -####################################### -Using the proxy_command option with ssh -####################################### - -This page explains how to use the ``proxy_command`` feature of ``ssh``. This feature -is needed when you want to connect to a computer ``B``, but you are not allowed to -connect directly to it; instead, you have to connect to computer ``A`` first, and then -perform a further connection from ``A`` to ``B``. - - -Requirements -++++++++++++ -The idea is that you ask ``ssh`` to connect to computer ``B`` by using -a proxy to create a sort of tunnel. One way to perform such an -operation is to use ``netcat``, a tool that simply takes the standard input and -redirects it to a given TCP port. - -Therefore, a requirement is to install ``netcat`` on computer A. -You can already check if the ``netcat`` or ``nc`` command is available -on you computer, since some distributions include it (if it is already -installed, the output of the command:: - - which netcat - -or:: - - which nc - -will return the absolute path to the executable). - -If this is not the case, you will need to install it on your own. -Typically, it will be sufficient to look for a netcat distribution on -the web, unzip the downloaded package, ``cd`` into the folder and -execute something like:: - - ./configure --prefix=. - make - make install - -This usually creates a subfolder ``bin``, containing the ``netcat`` -and ``nc`` executables. -Write down the full path to ``nc`` that we will need later. - - -ssh/config -++++++++++ -You can now test the proxy command with ``ssh``. 
Edit the -``~/.ssh/config`` file on the computer on which you installed AiiDA -(or create it if missing) and add the following lines:: - - Host FULLHOSTNAME_B - Hostname FULLHOSTNAME_B - User USER_B - ProxyCommand ssh USER_A@FULLHOSTNAME_A ABSPATH_NETCAT %h %p - -where you have to replace: - -* ``FULLHOSTNAMEA`` and ``FULLHOSTNAMEB`` with - the fully-qualified hostnames of computer ``A`` and ``B`` (remembering that ``B`` - is the computer you want to actually connect to, and ``A`` is the - intermediate computer to which you have direct access) -* ``USER_A`` and ``USER_B`` are the usernames on the two machines (that - can possibly be the same). -* ``ABSPATH_NETCAT`` is the absolute path to the ``nc`` executable - that you obtained in the previous step. - -Remember also to configure passwordless ssh connections using ssh keys -both from your computer to ``A``, and from ``A`` to ``B``. - -Once you add this lines and save the file, try to execute:: - - ssh FULLHOSTNAME_B - -which should allow you to directly connect to ``B``. - - -WARNING -+++++++ - -There are several versions of netcat available on the web. -We found at least one case in which the executable wasn't working -properly. -At the end of the connection, the ``netcat`` executable might still be -running: as a result, you may rapidly -leave the cluster with hundreds of opened ``ssh`` connections, one for -every time you connect to the cluster ``B``. -Therefore, check on both computers ``A`` and ``B`` that the number of -processes ``netcat`` and ``ssh`` are disappearing if you close the -connection. -To check if such processes are running, you can execute:: - - ps -aux | grep - -Remember that a cluster might have more than one login node, and the ``ssh`` -connection will randomly connect to any of them. - - -AiiDA config -++++++++++++ -If the above steps work, setup and configure now the computer as -explained :ref:`here `. - -If you properly set up the ``~/.ssh/config`` file in the previous -step, AiiDA should properly parse the information in the file and -provide the correct default value for the ``proxy_command`` during the -``verdi computer configure`` step. - -.. _ssh_proxycommand_notes: - -Some notes on the ``proxy_command`` option ------------------------------------------- - -* In the ``~/.ssh/config`` file, you can leave the ``%h`` and ``%p`` - placeholders, that are then automatically replaced by ssh with the hostname - and the port of the machine ``B`` when creating the proxy. - However, in the AiiDA ``proxy_command`` option, you need to put the - actual hostname and port. If you start from a properly configured - ``~/.ssh/config`` file, AiiDA will already replace these - placeholders with the correct values. However, if you input the ``proxy_command`` - value manually, remember to write the - hostname and the port and not ``%h`` and ``%p``. -* In the ``~/.ssh/config`` file, you can also insert stdout and stderr - redirection, e.g. ``2> /dev/null`` to hide any error that may occur - during the proxying/tunneling. However, you should only give AiiDA - the actual command to be executed, without any redirection. Again, - AiiDA will remove the redirection when it automatically reads the - ``~/.ssh/config`` file, but be careful if entering manually the - content in this field. 
diff --git a/environment.yml b/environment.yml index 7ac0513ff7..7915dfd4be 100644 --- a/environment.yml +++ b/environment.yml @@ -19,8 +19,8 @@ dependencies: - ipython~=7.0 - jinja2~=2.10 - kiwipy[rmq]~=0.5.5 -- numpy<1.18,~=1.17 -- paramiko~=2.6 +- numpy~=1.17 +- paramiko~=2.7 - pika~=1.1 - plumpy~=0.15.0 - pgsu~=0.1.0 diff --git a/pyproject.toml b/pyproject.toml index db666b6ad6..e535691248 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,45 @@ [build-system] -requires = [ "setuptools>=40.8.0", "wheel", "reentry~=1.3", "fastentrypoints~=0.12",] +requires = ["setuptools>=40.8.0,<50", "wheel", "reentry~=1.3", "fastentrypoints~=0.12"] build-backend = "setuptools.build_meta:__legacy__" + +[tool.tox] +# To use tox, see https://tox.readthedocs.io +# Simply pip or conda install tox +# If you use conda, you may also want to install tox-conda +# then run `tox` or `tox -e py37 -- {pytest args}` + +# To ensure rebuild of the tox environment, +# either simple delete the .tox folder or use `tox -r` + +legacy_tox_ini = """ +[tox] +envlist = py37-django + +[testenv:py{35,36,37,38}-{django,sqla}] +deps = + py35: -rrequirements/requirements-py-3.5.txt + py36: -rrequirements/requirements-py-3.6.txt + py37: -rrequirements/requirements-py-3.7.txt + py38: -rrequirements/requirements-py-3.8.txt +setenv = + django: AIIDA_TEST_BACKEND = django + sqla: AIIDA_TEST_BACKEND = sqlalchemy +commands = pytest {posargs} + +[testenv:py{36,37,38}-docs-{clean,update}] +deps = + py36: -rrequirements/requirements-py-3.6.txt + py37: -rrequirements/requirements-py-3.7.txt + py38: -rrequirements/requirements-py-3.8.txt +setenv = + update: RUN_APIDOC = False +changedir = docs +whitelist_externals = make +commands = + clean: make clean + make debug + +[testenv:py{36,37,38}-pre-commit] +extras = all +commands = pre-commit run {posargs} +""" diff --git a/pytest.ini b/pytest.ini index c8f247074f..757395ca17 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,5 @@ [pytest] +addopts = --benchmark-skip testpaths = tests filterwarnings = ignore::DeprecationWarning:babel: diff --git a/requirements/requirements-py-3.5.txt b/requirements/requirements-py-3.5.txt index ad5465c54b..9b58b33841 100644 --- a/requirements/requirements-py-3.5.txt +++ b/requirements/requirements-py-3.5.txt @@ -97,7 +97,9 @@ PyNaCl==1.3.0 pyparsing==2.4.6 pyrsistent==0.15.7 pytest==5.4.2 +pytest-benchmark==3.2.3 pytest-cov==2.8.1 +pytest-rerunfailures==9.0 pytest-timeout==1.3.4 python-dateutil==2.8.1 python-editor==1.0.4 diff --git a/requirements/requirements-py-3.6.txt b/requirements/requirements-py-3.6.txt index 14397ece42..b4962959a7 100644 --- a/requirements/requirements-py-3.6.txt +++ b/requirements/requirements-py-3.6.txt @@ -96,7 +96,9 @@ PyNaCl==1.3.0 pyparsing==2.4.6 pyrsistent==0.15.7 pytest==5.4.2 +pytest-benchmark==3.2.3 pytest-cov==2.8.1 +pytest-rerunfailures==9.0 pytest-timeout==1.3.4 python-dateutil==2.8.1 python-editor==1.0.4 diff --git a/requirements/requirements-py-3.7.txt b/requirements/requirements-py-3.7.txt index 56b5b43aec..e027ff75ca 100644 --- a/requirements/requirements-py-3.7.txt +++ b/requirements/requirements-py-3.7.txt @@ -95,7 +95,9 @@ PyNaCl==1.3.0 pyparsing==2.4.6 pyrsistent==0.15.7 pytest==5.4.2 +pytest-benchmark==3.2.3 pytest-cov==2.8.1 +pytest-rerunfailures==9.0 pytest-timeout==1.3.4 python-dateutil==2.8.1 python-editor==1.0.4 diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index c6d5e28ce6..cff1f7a7a9 100644 --- a/requirements/requirements-py-3.8.txt +++ 
b/requirements/requirements-py-3.8.txt @@ -94,7 +94,9 @@ PyNaCl==1.3.0 pyparsing==2.4.6 pyrsistent==0.15.7 pytest==5.4.2 +pytest-benchmark==3.2.3 pytest-cov==2.8.1 +pytest-rerunfailures==9.0 pytest-timeout==1.3.4 python-dateutil==2.8.1 python-editor==1.0.4 diff --git a/setup.json b/setup.json index 69f1885d4d..6b7eed43d0 100644 --- a/setup.json +++ b/setup.json @@ -1,6 +1,6 @@ { "name": "aiida-core", - "version": "1.3.1", + "version": "1.4.0", "url": "http://www.aiida.net/", "license": "MIT License", "author": "The AiiDA team", @@ -34,8 +34,8 @@ "ipython~=7.0", "jinja2~=2.10", "kiwipy[rmq]~=0.5.5", - "numpy~=1.17,<1.18", - "paramiko~=2.6", + "numpy~=1.17", + "paramiko~=2.7", "pika~=1.1", "plumpy~=0.15.0", "pgsu~=0.1.0", @@ -95,7 +95,7 @@ "pre-commit~=2.2", "pylint~=2.5.0", "pylint-django~=2.0", - "toml~=0.10.0" + "tomlkit~=0.7.0" ], "tests": [ "aiida-export-migration-tests==0.9.0", @@ -104,6 +104,8 @@ "pytest~=5.4", "pytest-timeout~=1.3", "pytest-cov~=2.7", + "pytest-rerunfailures~=9.0", + "pytest-benchmark~=3.2", "coverage<5.0", "sqlalchemy-diff~=0.1.3" ], diff --git a/tests/backends/aiida_django/migrations/test_migrations_0045_dbgroup_extras.py b/tests/backends/aiida_django/migrations/test_migrations_0045_dbgroup_extras.py new file mode 100644 index 0000000000..f9c1686ff1 --- /dev/null +++ b/tests/backends/aiida_django/migrations/test_migrations_0045_dbgroup_extras.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=import-error,no-name-in-module,invalid-name +"""Test migration to add the `extras` JSONB column to the `DbGroup` model.""" + +from .test_migrations_common import TestMigrations + + +class TestGroupExtrasMigration(TestMigrations): + """Test migration to add the `extras` JSONB column to the `DbGroup` model.""" + + migrate_from = '0044_dbgroup_type_string' + migrate_to = '0045_dbgroup_extras' + + def setUpBeforeMigration(self): + DbGroup = self.apps.get_model('db', 'DbGroup') + + group = DbGroup(label='01', user_id=self.default_user.id, type_string='user') + group.save() + self.group_pk = group.pk + + def test_extras(self): + """Test that the model now has an extras column with empty dictionary as default.""" + DbGroup = self.apps.get_model('db', 'DbGroup') + + group = DbGroup.objects.get(pk=self.group_pk) + self.assertEqual(group.extras, {}) diff --git a/tests/backends/aiida_django/test_generic.py b/tests/backends/aiida_django/test_generic.py deleted file mode 100644 index 0ddc506df6..0000000000 --- a/tests/backends/aiida_django/test_generic.py +++ /dev/null @@ -1,134 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Generic tests that need the use of the DB -""" - -from aiida import orm -from aiida.backends.testbase import AiidaTestCase -from aiida.orm import Data - - -class TestComputer(AiidaTestCase): - """Test the Computer class.""" - - def test_deletion(self): - """Test computer deletion.""" - from aiida.orm import CalcJobNode, Computer - from aiida.common.exceptions import InvalidOperation - - newcomputer = Computer( - name='testdeletioncomputer', - hostname='localhost', - transport_type='local', - scheduler_type='pbspro', - workdir='/tmp/aiida' - ).store() - - # This should be possible, because nothing is using this computer - orm.Computer.objects.delete(newcomputer.id) - - calc = CalcJobNode(computer=self.computer) - calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) - calc.store() - - # This should fail, because there is at least a calculation - # using this computer (the one created just above) - with self.assertRaises(InvalidOperation): - orm.Computer.objects.delete(self.computer.id) # pylint: disable=no-member - - -class TestGroupsDjango(AiidaTestCase): - """Test groups.""" - - # Tests that are specific to the Django backend - def test_query(self): - """ - Test if queries are working - """ - from aiida.common.exceptions import NotExistent, MultipleObjectsError - - backend = self.backend - - default_user = backend.users.create('{}@aiida.net'.format(self.id())).store() - - g_1 = backend.groups.create(label='testquery1', user=default_user).store() - self.addCleanup(lambda: backend.groups.delete(g_1.id)) - g_2 = backend.groups.create(label='testquery2', user=default_user).store() - self.addCleanup(lambda: backend.groups.delete(g_2.id)) - - n_1 = Data().store().backend_entity - n_2 = Data().store().backend_entity - n_3 = Data().store().backend_entity - n_4 = Data().store().backend_entity - - g_1.add_nodes([n_1, n_2]) - g_2.add_nodes([n_1, n_3]) - - newuser = backend.users.create(email='test@email.xx') - g_3 = backend.groups.create(label='testquery3', user=newuser).store() - self.addCleanup(lambda: backend.groups.delete(g_3.id)) - - # I should find it - g_1copy = backend.groups.get(uuid=g_1.uuid) - self.assertEqual(g_1.pk, g_1copy.pk) - - # NOTE: Here we pass type_string='' to all query and get calls in the groups collection because - # otherwise run the risk that we will pick up autogroups as well when really we're just interested - # the the ones that we created in this test - # Try queries - res = backend.groups.query(nodes=n_4, type_string='') - self.assertListEqual([_.pk for _ in res], []) - - res = backend.groups.query(nodes=n_1, type_string='') - self.assertEqual([_.pk for _ in res], [_.pk for _ in [g_1, g_2]]) - - res = backend.groups.query(nodes=n_2, type_string='') - self.assertEqual([_.pk for _ in res], [_.pk for _ in [g_1]]) - - # I try to use 'get' with zero or multiple results - with self.assertRaises(NotExistent): - backend.groups.get(nodes=n_4, type_string='') - with self.assertRaises(MultipleObjectsError): - backend.groups.get(nodes=n_1, type_string='') - - self.assertEqual(backend.groups.get(nodes=n_2, type_string='').pk, g_1.pk) - - # Query by user - res = backend.groups.query(user=newuser, type_string='') - self.assertEqual(set(_.pk for _ in res), 
set(_.pk for _ in [g_3])) - - # Same query, but using a string (the username=email) instead of - # a DbUser object - res = backend.groups.query(user=newuser.email, type_string='') - self.assertEqual(set(_.pk for _ in res), set(_.pk for _ in [g_3])) - - res = backend.groups.query(user=default_user, type_string='') - self.assertEqual(set(_.pk for _ in res), set(_.pk for _ in [g_1, g_2])) - - def test_creation_from_dbgroup(self): - """Test creation of a group from another group.""" - backend = self.backend - - node = Data().store() - - default_user = backend.users.create('{}@aiida.net'.format(self.id())).store() - - grp = backend.groups.create(label='testgroup_from_dbgroup', user=default_user).store() - self.addCleanup(lambda: backend.groups.delete(grp.id)) - - grp.store() - grp.add_nodes([node.backend_entity]) - - dbgroup = grp.dbmodel - gcopy = backend.groups.from_dbmodel(dbgroup) - - self.assertEqual(grp.pk, gcopy.pk) - self.assertEqual(grp.uuid, gcopy.uuid) diff --git a/tests/backends/aiida_django/test_manager.py b/tests/backends/aiida_django/test_manager.py new file mode 100644 index 0000000000..16f0c8c838 --- /dev/null +++ b/tests/backends/aiida_django/test_manager.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the django backend manager.""" + +from aiida.backends.djsite.manager import DjangoSettingsManager +from aiida.backends.testbase import AiidaTestCase +from aiida.common import exceptions + + +class TestDjangoSettingsManager(AiidaTestCase): + """Test the DjangoSettingsManager class and its methods.""" + + def setUp(self): + super().setUp() + self.settings_manager = DjangoSettingsManager() + + def test_set_get(self): + """Test the get and set methods.""" + temp_key = 'temp_setting' + temp_value = 'Valuable value' + temp_description = 'Temporary value for testing' + + self.settings_manager.set(temp_key, temp_value, temp_description) + self.assertEqual(self.settings_manager.get(temp_key).value, temp_value) + self.assertEqual(self.settings_manager.get(temp_key).description, temp_description) + + non_existent_key = 'I_dont_exist' + + with self.assertRaises(exceptions.NotExistent): + self.settings_manager.get(non_existent_key) + + def test_delete(self): + """Test the delete method.""" + temp_key = 'temp_setting' + temp_value = 'Valuable value' + + self.settings_manager.set(temp_key, temp_value) + self.settings_manager.delete(temp_key) + + non_existent_key = 'I_dont_exist' + + with self.assertRaises(exceptions.NotExistent): + self.settings_manager.delete(non_existent_key) diff --git a/tests/backends/aiida_sqlalchemy/test_generic.py b/tests/backends/aiida_sqlalchemy/test_generic.py deleted file mode 100644 index cba65a7abe..0000000000 --- a/tests/backends/aiida_sqlalchemy/test_generic.py +++ /dev/null @@ -1,166 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Generic tests that need the be specific to sqlalchemy.""" - -from aiida.backends.testbase import AiidaTestCase -from aiida.orm import Data - - -class TestComputer(AiidaTestCase): - """Test the Computer class.""" - - def test_deletion(self): - """Test computer deletion.""" - from aiida.orm import CalcJobNode - from aiida.common.exceptions import InvalidOperation - import aiida.backends.sqlalchemy - - newcomputer = self.backend.computers.create( - name='testdeletioncomputer', hostname='localhost', transport_type='local', scheduler_type='pbspro' - ) - newcomputer.store() - - # This should be possible, because nothing is using this computer - self.backend.computers.delete(newcomputer.id) - - node = CalcJobNode() - node.computer = self.computer - node.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) - node.store() - - session = aiida.backends.sqlalchemy.get_scoped_session() - - # This should fail, because there is at least a calculation - # using this computer (the one created just above) - try: - session.begin_nested() - with self.assertRaises(InvalidOperation): - self.backend.computers.delete(self.computer.id) # pylint: disable=no-member - finally: - session.rollback() - - -class TestGroupsSqla(AiidaTestCase): - """Test group queries for sqlalchemy backend.""" - - def setUp(self): - from aiida.orm.implementation import sqlalchemy as sqla - super().setUp() - self.assertIsInstance(self.backend, sqla.backend.SqlaBackend) - - def test_query(self): - """Test if queries are working.""" - from aiida.common.exceptions import NotExistent, MultipleObjectsError - - backend = self.backend - - simple_user = backend.users.create('simple@ton.com') - - g_1 = backend.groups.create(label='testquery1', user=simple_user).store() - self.addCleanup(lambda: backend.groups.delete(g_1.id)) - g_2 = backend.groups.create(label='testquery2', user=simple_user).store() - self.addCleanup(lambda: backend.groups.delete(g_2.id)) - - n_1 = Data().store().backend_entity - n_2 = Data().store().backend_entity - n_3 = Data().store().backend_entity - n_4 = Data().store().backend_entity - - g_1.add_nodes([n_1, n_2]) - g_2.add_nodes([n_1, n_3]) - - # NOTE: Here we pass type_string to query and get calls so that these calls don't - # find the autogroups (otherwise the assertions will fail) - newuser = backend.users.create(email='test@email.xx') - g_3 = backend.groups.create(label='testquery3', user=newuser).store() - - # I should find it - g_1copy = backend.groups.get(uuid=g_1.uuid) - self.assertEqual(g_1.pk, g_1copy.pk) - - # Try queries - res = backend.groups.query(nodes=n_4, type_string='') - self.assertEqual([_.pk for _ in res], []) - - res = backend.groups.query(nodes=n_1, type_string='') - self.assertEqual([_.pk for _ in res], [_.pk for _ in [g_1, g_2]]) - - res = backend.groups.query(nodes=n_2, type_string='') - self.assertEqual([_.pk for _ in res], [_.pk for _ in [g_1]]) - - # I try to use 'get' with zero or multiple results - with self.assertRaises(NotExistent): - backend.groups.get(nodes=n_4, type_string='') - with self.assertRaises(MultipleObjectsError): - backend.groups.get(nodes=n_1, type_string='') - - self.assertEqual(backend.groups.get(nodes=n_2, type_string='').pk, g_1.pk) - - # Query by user - res 
= backend.groups.query(user=newuser, type_string='') - self.assertSetEqual(set(_.pk for _ in res), set(_.pk for _ in [g_3])) - - # Same query, but using a string (the username=email) instead of - # a DbUser object - res = backend.groups.query(user=newuser, type_string='') - self.assertSetEqual(set(_.pk for _ in res), set(_.pk for _ in [g_3])) - - res = backend.groups.query(user=simple_user, type_string='') - - self.assertSetEqual(set(_.pk for _ in res), set(_.pk for _ in [g_1, g_2])) - - -class TestGroupNoOrmSQLA(AiidaTestCase): - """These tests check that the group node addition works ok when the skip_orm=True flag is used.""" - - def test_group_general(self): - """General tests to verify that the group addition with the skip_orm=True flag - work properly.""" - backend = self.backend - - node_01 = Data().store().backend_entity - node_02 = Data().store().backend_entity - node_03 = Data().store().backend_entity - node_04 = Data().store().backend_entity - node_05 = Data().store().backend_entity - nodes = [node_01, node_02, node_03, node_04, node_05] - - simple_user = backend.users.create('simple1@ton.com') - group = backend.groups.create(label='test_adding_nodes', user=simple_user).store() - # Single node in a list - group.add_nodes([node_01], skip_orm=True) - # List of nodes - group.add_nodes([node_02, node_03], skip_orm=True) - # Tuple of nodes - group.add_nodes((node_04, node_05), skip_orm=True) - - # Check - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) - - # Try to add a node that is already present: there should be no problem - group.add_nodes([node_01], skip_orm=True) - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) - - def test_group_batch_size(self): - """Test that the group addition in batches works as expected.""" - from aiida.orm.groups import Group - - # Create 100 nodes - nodes = [] - for _ in range(100): - nodes.append(Data().store().backend_entity) - - # Add nodes to groups using different batch size. Check in the end the - # correct addition. 
- batch_sizes = (1, 3, 10, 1000) - for batch_size in batch_sizes: - group = Group(label='test_batches_' + str(batch_size)).store() - group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size) - self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 01ce327f31..1cbe698bc8 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -1708,3 +1708,42 @@ def test_group_string_update(self): self.assertEqual(group_autorun.type_string, 'core.auto') finally: session.close() + + +class TestGroupExtrasMigration(TestMigrationsSQLA): + """Test migration to add the `extras` JSONB column to the `DbGroup` model.""" + + migrate_from = 'bf591f31dd12' # bf591f31dd12_dbgroup_type_string.py + migrate_to = '0edcdd5a30f0' # 0edcdd5a30f0_dbgroup_extras.py + + def setUpBeforeMigration(self): + """Create a DbGroup.""" + DbGroup = self.get_current_table('db_dbgroup') # pylint: disable=invalid-name + DbUser = self.get_current_table('db_dbuser') # pylint: disable=invalid-name + + with self.get_session() as session: + try: + default_user = DbUser(email='{}@aiida.net'.format(self.id())) + session.add(default_user) + session.commit() + + group = DbGroup(label='01', user_id=default_user.id, type_string='user') + session.add(group) + session.commit() + + # Store values for later tests + self.group_pk = group.id + + finally: + session.close() + + def test_group_string_update(self): + """Test that the model now has an extras column with empty dictionary as default.""" + DbGroup = self.get_current_table('db_dbgroup') # pylint: disable=invalid-name + + with self.get_session() as session: + try: + group = session.query(DbGroup).filter(DbGroup.id == self.group_pk).one() + self.assertEqual(group.extras, {}) + finally: + session.close() diff --git a/tests/backends/aiida_sqlalchemy/test_nodes.py b/tests/backends/aiida_sqlalchemy/test_nodes.py index 4b68b61fd6..2ace41401d 100644 --- a/tests/backends/aiida_sqlalchemy/test_nodes.py +++ b/tests/backends/aiida_sqlalchemy/test_nodes.py @@ -120,14 +120,14 @@ def test_multiple_node_creation(self): # Query the session before commit res = session.query(DbNode.uuid).filter(DbNode.uuid == node_uuid).all() - self.assertEqual(len(res), 0, 'There should not be any nodes with this' 'UUID in the session/DB.') + self.assertEqual(len(res), 0, 'There should not be any nodes with this UUID in the session/DB.') # Commit the transaction session.commit() # Check again that the node is not in the DB res = session.query(DbNode.uuid).filter(DbNode.uuid == node_uuid).all() - self.assertEqual(len(res), 0, 'There should not be any nodes with this' 'UUID in the session/DB.') + self.assertEqual(len(res), 0, 'There should not be any nodes with this UUID in the session/DB.') # Get the automatic user dbuser = orm.User.objects.get_default().backend_entity.dbmodel @@ -138,11 +138,11 @@ def test_multiple_node_creation(self): # Query the session before commit res = session.query(DbNode.uuid).filter(DbNode.uuid == node_uuid).all() - self.assertEqual(len(res), 1, 'There should be a node in the session/DB with the ' 'UUID {}'.format(node_uuid)) + self.assertEqual(len(res), 1, 'There should be a node in the session/DB with the UUID {}'.format(node_uuid)) # Commit the transaction session.commit() # Check again that the node is in the db res = session.query(DbNode.uuid).filter(DbNode.uuid == 
node_uuid).all() - self.assertEqual(len(res), 1, 'There should be a node in the session/DB with the ' 'UUID {}'.format(node_uuid)) + self.assertEqual(len(res), 1, 'There should be a node in the session/DB with the UUID {}'.format(node_uuid)) diff --git a/tests/backends/aiida_sqlalchemy/test_utils.py b/tests/backends/aiida_sqlalchemy/test_utils.py index b91f222a2d..63484bb567 100644 --- a/tests/backends/aiida_sqlalchemy/test_utils.py +++ b/tests/backends/aiida_sqlalchemy/test_utils.py @@ -103,6 +103,6 @@ def create_database(url, encoding='utf8'): engine.execute(text) else: - raise Exception('Only PostgreSQL with the psycopg2 driver is ' 'supported.') + raise Exception('Only PostgreSQL with the psycopg2 driver is supported.') finally: engine.dispose() diff --git a/tests/benchmark/test_engine.py b/tests/benchmark/test_engine.py new file mode 100644 index 0000000000..1f5d6a038b --- /dev/null +++ b/tests/benchmark/test_engine.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=unused-argument,redefined-outer-name +"""Performance benchmark tests for local processes. + +The purpose of these tests is to benchmark and compare processes, +which are executed *via* both a local runner and the daemon. +""" +import datetime + +from tornado import gen +import pytest + +from aiida.engine import run_get_node, submit, ToContext, while_, WorkChain +from aiida.manage.manager import get_manager +from aiida.orm import Code, Int +from aiida.plugins.factories import CalculationFactory + +ArithmeticAddCalculation = CalculationFactory('arithmetic.add') + + +class WorkchainLoop(WorkChain): + """A basic Workchain to run a looped step n times.""" + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('iterations', required=True) + spec.input('code', required=False) + spec.outline(cls.init_loop, while_(cls.terminate_loop)(cls.run_task)) + + def init_loop(self): + self.ctx.iter = self.inputs.iterations.value + self.ctx.counter = 0 + + def terminate_loop(self): + if self.ctx.counter >= self.ctx.iter: + return False + self.ctx.counter += 1 + return True + + def run_task(self): + pass + + +class WorkchainLoopWcSerial(WorkchainLoop): + """A WorkChain that submits another WorkChain n times in different steps.""" + + def run_task(self): + future = self.submit(WorkchainLoop, iterations=Int(1)) + return ToContext(**{'wkchain' + str(self.ctx.counter): future}) + + +class WorkchainLoopWcThreaded(WorkchainLoop): + """A WorkChain that submits another WorkChain n times in the same step.""" + + def init_loop(self): + super().init_loop() + self.ctx.iter = 1 + + def run_task(self): + + context = { + 'wkchain' + str(i): self.submit(WorkchainLoop, iterations=Int(1)) + for i in range(self.inputs.iterations.value) + } + return ToContext(**context) + + +class WorkchainLoopCalcSerial(WorkchainLoop): + """A WorkChain that submits a CalcJob n times in different steps.""" + + def run_task(self): + inputs = { + 'x': Int(1), + 'y': Int(2), + 'code': self.inputs.code, + } + future = self.submit(ArithmeticAddCalculation, **inputs) + return 
ToContext(addition=future) + + +class WorkchainLoopCalcThreaded(WorkchainLoop): + """A WorkChain that submits a CalcJob n times in the same step.""" + + def init_loop(self): + super().init_loop() + self.ctx.iter = 1 + + def run_task(self): + futures = {} + for i in range(self.inputs.iterations.value): + inputs = { + 'x': Int(1), + 'y': Int(2), + 'code': self.inputs.code, + } + futures['addition' + str(i)] = self.submit(ArithmeticAddCalculation, **inputs) + return ToContext(**futures) + + +WORKCHAINS = { + 'basic-loop': (WorkchainLoop, 4, 0), + 'serial-wc-loop': (WorkchainLoopWcSerial, 4, 4), + 'threaded-wc-loop': (WorkchainLoopWcThreaded, 4, 4), + 'serial-calcjob-loop': (WorkchainLoopCalcSerial, 4, 4), + 'threaded-calcjob-loop': (WorkchainLoopCalcThreaded, 4, 4), +} + + +@pytest.mark.parametrize('workchain,iterations,outgoing', WORKCHAINS.values(), ids=WORKCHAINS.keys()) +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group='engine') +def test_workchain_local(benchmark, aiida_localhost, workchain, iterations, outgoing): + """Benchmark Workchains, executed in the local runner.""" + code = Code(input_plugin_name='arithmetic.add', remote_computer_exec=[aiida_localhost, '/bin/true']) + + def _run(): + return run_get_node(workchain, iterations=Int(iterations), code=code) + + result = benchmark.pedantic(_run, iterations=1, rounds=10, warmup_rounds=1) + + assert result.node.is_finished_ok, (result.node.exit_status, result.node.exit_message) + assert len(result.node.get_outgoing().all()) == outgoing + + +@gen.coroutine +def with_timeout(what, timeout=60): + """Coroutine return with timeout.""" + raise gen.Return((yield gen.with_timeout(datetime.timedelta(seconds=timeout), what))) + + +@gen.coroutine +def wait_for_process(runner, calc_node, timeout=60): + """Coroutine block with timeout.""" + future = runner.get_process_future(calc_node.pk) + raise gen.Return((yield with_timeout(future, timeout))) + + +@pytest.fixture() +def submit_get_node(): + """A test fixture for running a process *via* submission to the daemon, + and blocking until it is complete. + + Adapted from tests/engine/test_rmq.py + """ + manager = get_manager() + runner = manager.get_runner() + # The daemon runner needs to share a common event loop, + # otherwise the local runner will never send the message while the daemon is running listening to intercept. 
+ daemon_runner = manager.create_daemon_runner(loop=runner.loop) + + def _submit(_process, timeout=60, **kwargs): + + @gen.coroutine + def _do_submit(): + node = submit(_process, **kwargs) + yield wait_for_process(runner, node) + return node + + result = runner.loop.run_sync(_do_submit, timeout=timeout) + + return result + + yield _submit + + daemon_runner.close() + + +@pytest.mark.parametrize('workchain,iterations,outgoing', WORKCHAINS.values(), ids=WORKCHAINS.keys()) +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group='engine') +def test_workchain_daemon(benchmark, submit_get_node, aiida_localhost, workchain, iterations, outgoing): + """Benchmark Workchains, executed in the via a daemon runner.""" + code = Code(input_plugin_name='arithmetic.add', remote_computer_exec=[aiida_localhost, '/bin/true']) + + def _run(): + return submit_get_node(workchain, iterations=Int(iterations), code=code) + + result = benchmark.pedantic(_run, iterations=1, rounds=10, warmup_rounds=1) + + assert result.is_finished_ok, (result.exit_status, result.exit_message) + assert len(result.get_outgoing().all()) == outgoing diff --git a/tests/benchmark/test_importexport.py b/tests/benchmark/test_importexport.py new file mode 100644 index 0000000000..d81f8a6e9c --- /dev/null +++ b/tests/benchmark/test_importexport.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=unused-argument,protected-access +"""Performance benchmark tests for import/export utilities. + +The purpose of these tests is to benchmark and compare importing and exporting +parts of the database. 
+""" +from io import StringIO + +import pytest + +from aiida.common.links import LinkType +from aiida.engine import ProcessState +from aiida.orm import CalcFunctionNode, Dict, load_node +from aiida.tools.importexport import import_data, export + + +def recursive_provenance(in_node, depth, breadth, num_objects=0): + """Recursively build a provenance tree.""" + if not in_node.is_stored: + in_node.store() + if depth < 1: + return + depth -= 1 + for _ in range(breadth): + calcfunc = CalcFunctionNode() + calcfunc.set_process_state(ProcessState.FINISHED) + calcfunc.set_exit_status(0) + calcfunc.add_incoming(in_node, link_type=LinkType.INPUT_CALC, link_label='input') + calcfunc.store() + + out_node = Dict(dict={str(i): i for i in range(10)}) + for idx in range(num_objects): + out_node.put_object_from_filelike(StringIO('a' * 10000), 'key' + str(idx)) + out_node.add_incoming(calcfunc, link_type=LinkType.CREATE, link_label='output') + out_node.store() + + calcfunc.seal() + + recursive_provenance(out_node, depth, breadth, num_objects) + + +def get_export_kwargs(**kwargs): + """Return default export keyword arguments.""" + obj = { + 'silent': True, + 'input_calc_forward': True, + 'input_work_forward': True, + 'create_backward': True, + 'return_backward': True, + 'call_calc_backward': True, + 'call_work_backward': True, + 'include_comments': True, + 'include_logs': True, + 'overwrite': True, + 'use_compression': True + } + obj.update(kwargs) + return obj + + +TREE = {'no-objects': (4, 3, 0), 'with-objects': (4, 3, 2)} + + +@pytest.mark.parametrize('depth,breadth,num_objects', TREE.values(), ids=TREE.keys()) +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group='import-export') +def test_export(benchmark, tmp_path, depth, breadth, num_objects): + """Benchmark exporting a provenance graph.""" + root_node = Dict() + recursive_provenance(root_node, depth=depth, breadth=breadth, num_objects=num_objects) + out_path = tmp_path / 'test.aiida' + kwargs = get_export_kwargs(filename=str(out_path)) + + def _setup(): + if out_path.exists(): + out_path.unlink() + + def _run(): + export([root_node], **kwargs) + + benchmark.pedantic(_run, setup=_setup, iterations=1, rounds=12, warmup_rounds=1) + assert out_path.exists() + + +@pytest.mark.parametrize('depth,breadth,num_objects', TREE.values(), ids=TREE.keys()) +@pytest.mark.benchmark(group='import-export') +def test_import(aiida_profile, benchmark, tmp_path, depth, breadth, num_objects): + """Benchmark importing a provenance graph.""" + aiida_profile.reset_db() + root_node = Dict() + recursive_provenance(root_node, depth=depth, breadth=breadth, num_objects=num_objects) + root_uuid = root_node.uuid + out_path = tmp_path / 'test.aiida' + kwargs = get_export_kwargs(filename=str(out_path)) + export([root_node], **kwargs) + + def _setup(): + aiida_profile.reset_db() + + def _run(): + import_data(str(out_path), silent=True) + + benchmark.pedantic(_run, setup=_setup, iterations=1, rounds=12, warmup_rounds=1) + load_node(root_uuid) diff --git a/tests/benchmark/test_nodes.py b/tests/benchmark/test_nodes.py new file mode 100644 index 0000000000..ad0b1e6a85 --- /dev/null +++ b/tests/benchmark/test_nodes.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=unused-argument,protected-access +"""Performance benchmark tests for single nodes. + +The purpose of these tests is to benchmark and compare basic node interactions, +such as storage and deletion from the database and repository. +""" +from io import StringIO + +import pytest + +from aiida.common import NotExistent +from aiida.orm import Data, load_node + +GROUP_NAME = 'node' + + +def get_data_node(store=True): + """A function to create a simple data node.""" + data = Data() + data.set_attribute_many({str(i): i for i in range(10)}) + if store: + data.store() + return (), {'node': data} + + +def get_data_node_and_object(store=True): + """A function to create a simple data node, with an object.""" + data = Data() + data.set_attribute_many({str(i): i for i in range(10)}) + data.put_object_from_filelike(StringIO('a' * 10000), 'key') + if store: + data.store() + return (), {'node': data} + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group=GROUP_NAME, min_rounds=100) +def test_store_backend(benchmark): + """Benchmark for creating and storing a node directly, + via the backend storage mechanism. + """ + + def _run(): + data = Data() + data.set_attribute_many({str(i): i for i in range(10)}) + data._backend_entity.store(clean=False) + return data + + node = benchmark(_run) + assert node.is_stored, node + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group=GROUP_NAME, min_rounds=100) +def test_store(benchmark): + """Benchmark for creating and storing a node, + via the full ORM mechanism. + """ + _, node_dict = benchmark(get_data_node) + assert node_dict['node'].is_stored, node_dict + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group=GROUP_NAME, min_rounds=100) +def test_store_with_object(benchmark): + """Benchmark for creating and storing a node, + including an object to be stored in the repository. + """ + _, node_dict = benchmark(get_data_node_and_object) + assert node_dict['node'].is_stored, node_dict + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group=GROUP_NAME) +def test_delete_backend(benchmark): + """Benchmark for deleting a stored node directly, + via the backend deletion mechanism. + """ + + def _run(node): + pk = node.pk + Data.objects._backend.nodes.delete(pk) # pylint: disable=no-member + return pk + + pk = benchmark.pedantic(_run, setup=get_data_node, iterations=1, rounds=100, warmup_rounds=1) + with pytest.raises(NotExistent): + load_node(pk) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group=GROUP_NAME) +def test_delete(benchmark): + """Benchmark for deleting a node, + via the full ORM mechanism. 
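A note on the ``setup`` arguments used by the deletion benchmarks in this module: pytest-benchmark's ``pedantic`` mode calls the ``setup`` callable once per round, outside the timed region, and expects it to return an ``(args, kwargs)`` tuple that is then passed to the benchmarked function. That is why ``get_data_node`` and ``get_data_node_and_object`` return ``(), {'node': data}``. A minimal sketch of the same contract, with hypothetical names and assuming pytest-benchmark is installed:

def make_payload():
    """Untimed per-round setup; its return value is forwarded to the benchmarked function."""
    payload = {'values': list(range(1000))}
    return (), {'payload': payload}  # (positional args, keyword args) for the target


def consume(payload):
    return sum(payload['values'])


def test_consume(benchmark):
    # as in the tests of this module, a single iteration per round is combined with `setup`
    result = benchmark.pedantic(consume, setup=make_payload, iterations=1, rounds=100)
    assert result == sum(range(1000))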
+ """ + + def _run(node): + pk = node.pk + Data.objects.delete(pk) # pylint: disable=no-member + return pk + + pk = benchmark.pedantic(_run, setup=get_data_node, iterations=1, rounds=100, warmup_rounds=1) + with pytest.raises(NotExistent): + load_node(pk) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.benchmark(group=GROUP_NAME) +def test_delete_with_object(benchmark): + """Benchmark for deleting a node, + including an object stored in the repository + """ + + def _run(node): + pk = node.pk + Data.objects.delete(pk) # pylint: disable=no-member + return pk + + pk = benchmark.pedantic(_run, setup=get_data_node_and_object, iterations=1, rounds=100, warmup_rounds=1) + with pytest.raises(NotExistent): + load_node(pk) diff --git a/tests/calculations/test_templatereplacer.py b/tests/calculations/test_templatereplacer.py new file mode 100644 index 0000000000..ff700ff169 --- /dev/null +++ b/tests/calculations/test_templatereplacer.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the `TemplatereplacerCalculation` plugin.""" +import io +import pytest + +from aiida import orm +from aiida.common import datastructures + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_base_template(fixture_sandbox, aiida_localhost, generate_calc_job): + """Test a base template that emulates the arithmetic add.""" + + entry_point_name = 'templatereplacer' + inputs = { + 'code': + orm.Code(remote_computer_exec=(aiida_localhost, '/bin/bash')), + 'metadata': { + 'options': { + 'resources': { + 'num_machines': 1, + 'tot_num_mpiprocs': 1 + } + } + }, + 'template': + orm.Dict( + dict={ + 'input_file_template': 'echo $(({x} + {y}))', + 'input_file_name': 'input.txt', + 'cmdline_params': ['input.txt'], + 'output_file_name': 'output.txt', + } + ), + 'parameters': + orm.Dict(dict={ + 'x': 1, + 'y': 2 + }), + } + + # Check the attributes of the resulting `CalcInfo` + calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs) + assert isinstance(calc_info, datastructures.CalcInfo) + assert sorted(calc_info.retrieve_list) == sorted([inputs['template']['output_file_name']]) + + # Check the integrity of the `codes_info` + codes_info = calc_info.codes_info + assert isinstance(codes_info, list) + assert len(codes_info) == 1 + + # Check the attributes of the resulting `CodeInfo` + code_info = codes_info[0] + assert isinstance(code_info, datastructures.CodeInfo) + assert code_info.code_uuid == inputs['code'].uuid + assert code_info.stdout_name == inputs['template']['output_file_name'] + assert sorted(code_info.cmdline_params) == sorted(inputs['template']['cmdline_params']) + + # Check the content of the generated script + with fixture_sandbox.open(inputs['template']['input_file_name']) as handle: + input_written = handle.read() + assert input_written == 'echo $(({} + {}))'.format(inputs['parameters']['x'], inputs['parameters']['y']) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_file_usage(fixture_sandbox, aiida_localhost, generate_calc_job): + """Test a base template that uses two 
files.""" + + file1_node = orm.SinglefileData(io.BytesIO(b'Content of file 1')) + file2_node = orm.SinglefileData(io.BytesIO(b'Content of file 2')) + + # Check that the files are correctly copied to the copy list + entry_point_name = 'templatereplacer' + inputs = { + 'code': orm.Code(remote_computer_exec=(aiida_localhost, '/bin/bash')), + 'metadata': { + 'options': { + 'resources': { + 'num_machines': 1, + 'tot_num_mpiprocs': 1 + } + } + }, + 'template': orm.Dict(dict={ + 'files_to_copy': [('filenode1', 'file1.txt'), ('filenode2', 'file2.txt')], + }), + 'files': { + 'filenode1': file1_node, + 'filenode2': file2_node + } + } + + calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs) + reference_copy_list = [] + for node_idname, target_path in inputs['template']['files_to_copy']: + file_node = inputs['files'][node_idname] + reference_copy_list.append((file_node.uuid, file_node.filename, target_path)) + + assert sorted(calc_info.local_copy_list) == sorted(reference_copy_list) diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py index b8072735c2..3ee3a833f2 100644 --- a/tests/cmdline/commands/test_calcjob.py +++ b/tests/cmdline/commands/test_calcjob.py @@ -37,7 +37,7 @@ def setUpClass(cls, *args, **kwargs): from aiida.engine import ProcessState cls.computer = orm.Computer( - name='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' + label='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' ).store() cls.code = orm.Code(remote_computer_exec=(cls.computer, '/bin/true')).store() diff --git a/tests/cmdline/commands/test_code.py b/tests/cmdline/commands/test_code.py index c09899fd2b..bda2992d40 100644 --- a/tests/cmdline/commands/test_code.py +++ b/tests/cmdline/commands/test_code.py @@ -30,7 +30,7 @@ class TestVerdiCodeSetup(AiidaTestCase): def setUpClass(cls, *args, **kwargs): super().setUpClass(*args, **kwargs) cls.computer = orm.Computer( - name='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' + label='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' ).store() def setUp(self): @@ -50,12 +50,12 @@ def test_noninteractive_remote(self): label = 'noninteractive_remote' options = [ '--non-interactive', '--label={}'.format(label), '--description=description', - '--input-plugin=arithmetic.add', '--on-computer', '--computer={}'.format(self.computer.name), + '--input-plugin=arithmetic.add', '--on-computer', '--computer={}'.format(self.computer.label), '--remote-abs-path=/remote/abs/path' ] result = self.cli_runner.invoke(setup_code, options) self.assertClickResultNoException(result) - self.assertIsInstance(orm.Code.get_from_string('{}@{}'.format(label, self.computer.name)), orm.Code) + self.assertIsInstance(orm.Code.get_from_string('{}@{}'.format(label, self.computer.label)), orm.Code) def test_noninteractive_upload(self): """Test non-interactive code setup.""" @@ -89,7 +89,7 @@ def test_from_config(self): # local file label = 'noninteractive_config' with tempfile.NamedTemporaryFile('w') as handle: - handle.write(config_file_template.format(label=label, computer=self.computer.name)) + handle.write(config_file_template.format(label=label, computer=self.computer.label)) handle.flush() result = self.cli_runner.invoke( setup_code, @@ -103,7 +103,7 @@ def test_from_config(self): fake_url = 'https://my.url.com' with mock.patch( 'urllib.request.urlopen', - 
return_value=config_file_template.format(label=label, computer=self.computer.name) + return_value=config_file_template.format(label=label, computer=self.computer.label) ): result = self.cli_runner.invoke(setup_code, ['--non-interactive', '--config', fake_url]) @@ -120,7 +120,7 @@ class TestVerdiCodeCommands(AiidaTestCase): def setUpClass(cls, *args, **kwargs): super().setUpClass(*args, **kwargs) cls.computer = orm.Computer( - name='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' + label='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' ).store() def setUp(self): @@ -189,7 +189,7 @@ def test_code_list(self): code.label = 'code2' code.store() - options = ['-A', '-a', '-o', '--input-plugin=arithmetic.add', '--computer={}'.format(self.computer.name)] + options = ['-A', '-a', '-o', '--input-plugin=arithmetic.add', '--computer={}'.format(self.computer.label)] result = self.cli_runner.invoke(code_list, options) self.assertIsNone(result.exception, result.output) self.assertTrue(str(self.code.pk) in result.output, 'PK of first code should be included') @@ -243,10 +243,10 @@ def test_code_list_no_codes_error_message(self): def test_interactive_remote(clear_database_before_test, aiida_localhost, non_interactive_editor): """Test interactive remote code setup.""" label = 'interactive_remote' - user_input = '\n'.join([label, 'description', 'arithmetic.add', 'yes', aiida_localhost.name, '/remote/abs/path']) + user_input = '\n'.join([label, 'description', 'arithmetic.add', 'yes', aiida_localhost.label, '/remote/abs/path']) result = CliRunner().invoke(setup_code, input=user_input) assert result.exception is None - assert isinstance(orm.Code.get_from_string('{}@{}'.format(label, aiida_localhost.name)), orm.Code) + assert isinstance(orm.Code.get_from_string('{}@{}'.format(label, aiida_localhost.label)), orm.Code) @pytest.mark.parametrize('non_interactive_editor', ('sleep 1; vim -cwq',), indirect=True) @@ -267,10 +267,10 @@ def test_mixed(clear_database_before_test, aiida_localhost, non_interactive_edit from aiida.orm import Code label = 'mixed_remote' options = ['--description=description', '--on-computer', '--remote-abs-path=/remote/abs/path'] - user_input = '\n'.join([label, 'arithmetic.add', aiida_localhost.name]) + user_input = '\n'.join([label, 'arithmetic.add', aiida_localhost.label]) result = CliRunner().invoke(setup_code, options, input=user_input) assert result.exception is None - assert isinstance(Code.get_from_string('{}@{}'.format(label, aiida_localhost.name)), Code) + assert isinstance(Code.get_from_string('{}@{}'.format(label, aiida_localhost.label)), Code) @pytest.mark.parametrize('non_interactive_editor', ('sleep 1; vim -cwq',), indirect=True) diff --git a/tests/cmdline/commands/test_computer.py b/tests/cmdline/commands/test_computer.py index 668728c21f..975b9bec40 100644 --- a/tests/cmdline/commands/test_computer.py +++ b/tests/cmdline/commands/test_computer.py @@ -19,7 +19,7 @@ from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands.cmd_computer import computer_setup -from aiida.cmdline.commands.cmd_computer import computer_show, computer_list, computer_rename, computer_delete +from aiida.cmdline.commands.cmd_computer import computer_show, computer_list, computer_relabel, computer_delete from aiida.cmdline.commands.cmd_computer import computer_test, computer_configure, computer_duplicate @@ -131,7 +131,7 @@ def test_mixed(self): options_dict = 
generate_setup_options_dict(replace_args={'label': label}) options_dict_full = options_dict.copy() - options_dict.pop('non-interactive', 'None') + options_dict.pop('non-interactive', None) non_interactive_options_dict = {} non_interactive_options_dict['prepend-text'] = options_dict.pop('prepend-text') @@ -146,13 +146,13 @@ def test_mixed(self): result = self.cli_runner.invoke(computer_setup, options, input=user_input) self.assertIsNone(result.exception, msg='There was an unexpected exception. Output: {}'.format(result.output)) - new_computer = orm.Computer.objects.get(name=label) + new_computer = orm.Computer.objects.get(label=label) self.assertIsInstance(new_computer, orm.Computer) self.assertEqual(new_computer.description, options_dict_full['description']) self.assertEqual(new_computer.hostname, options_dict_full['hostname']) - self.assertEqual(new_computer.get_transport_type(), options_dict_full['transport']) - self.assertEqual(new_computer.get_scheduler_type(), options_dict_full['scheduler']) + self.assertEqual(new_computer.transport_type, options_dict_full['transport']) + self.assertEqual(new_computer.scheduler_type, options_dict_full['scheduler']) self.assertEqual(new_computer.get_mpirun_command(), options_dict_full['mpirun-command'].split()) self.assertEqual(new_computer.get_shebang(), options_dict_full['shebang']) self.assertEqual(new_computer.get_workdir(), options_dict_full['work-dir']) @@ -173,13 +173,13 @@ def test_noninteractive(self): result = self.cli_runner.invoke(computer_setup, options) self.assertIsNone(result.exception, result.output[-1000:]) - new_computer = orm.Computer.objects.get(name=options_dict['label']) + new_computer = orm.Computer.objects.get(label=options_dict['label']) self.assertIsInstance(new_computer, orm.Computer) self.assertEqual(new_computer.description, options_dict['description']) self.assertEqual(new_computer.hostname, options_dict['hostname']) - self.assertEqual(new_computer.get_transport_type(), options_dict['transport']) - self.assertEqual(new_computer.get_scheduler_type(), options_dict['scheduler']) + self.assertEqual(new_computer.transport_type, options_dict['transport']) + self.assertEqual(new_computer.scheduler_type, options_dict['scheduler']) self.assertEqual(new_computer.get_mpirun_command(), options_dict['mpirun-command'].split()) self.assertEqual(new_computer.get_shebang(), options_dict['shebang']) self.assertEqual(new_computer.get_workdir(), options_dict['work-dir']) @@ -203,7 +203,7 @@ def test_noninteractive_optional_default_mpiprocs(self): # pylint: disable=inva self.assertIsNone(result.exception, result.output[-1000:]) - new_computer = orm.Computer.objects.get(name=options_dict['label']) + new_computer = orm.Computer.objects.get(label=options_dict['label']) self.assertIsInstance(new_computer, orm.Computer) self.assertIsNone(new_computer.get_default_mpiprocs_per_machine()) @@ -218,7 +218,7 @@ def test_noninteractive_optional_default_mpiprocs_2(self): # pylint: disable=in self.assertIsNone(result.exception, result.output[-1000:]) - new_computer = orm.Computer.objects.get(name=options_dict['label']) + new_computer = orm.Computer.objects.get(label=options_dict['label']) self.assertIsInstance(new_computer, orm.Computer) self.assertIsNone(new_computer.get_default_mpiprocs_per_machine()) @@ -298,7 +298,7 @@ def test_noninteractive_from_config(self): result = self.cli_runner.invoke(computer_setup, options) self.assertClickResultNoException(result) - self.assertIsInstance(orm.Computer.objects.get(name=label), orm.Computer) + 
self.assertIsInstance(orm.Computer.objects.get(label=label), orm.Computer) class TestVerdiComputerConfigure(AiidaTestCase): @@ -370,9 +370,16 @@ def test_local_interactive(self): comp = self.comp_builder.new() comp.store() - result = self.cli_runner.invoke(computer_configure, ['local', comp.label], input='\n', catch_exceptions=False) + command_input = ('{use_login_shell}\n{safe_interval}\n').format(use_login_shell='False', safe_interval='1.0') + result = self.cli_runner.invoke( + computer_configure, ['local', comp.label], input=command_input, catch_exceptions=False + ) self.assertTrue(comp.is_user_configured(self.user), msg=result.output) + new_auth_params = comp.get_authinfo(self.user).get_auth_params() + self.assertEqual(new_auth_params['use_login_shell'], False) + self.assertEqual(new_auth_params['safe_interval'], 1.0) + def test_ssh_interactive(self): """ Check that the interactive prompt is accepting the correct values. @@ -411,6 +418,7 @@ def test_ssh_interactive(self): self.assertEqual(new_auth_params['port'], port) self.assertEqual(new_auth_params['look_for_keys'], look_for_keys) self.assertEqual(new_auth_params['key_filename'], key_filename) + self.assertEqual(new_auth_params['use_login_shell'], True) def test_local_from_config(self): """Test configuring a computer from a config file""" @@ -521,7 +529,7 @@ def setUpClass(cls, *args, **kwargs): super().setUpClass(*args, **kwargs) cls.computer_name = 'comp_cli_test_computer' cls.comp = orm.Computer( - name=cls.computer_name, + label=cls.computer_name, hostname='localhost', transport_type='local', scheduler_type='direct', @@ -597,55 +605,50 @@ def test_computer_show(self): # Exceptions should arise self.assertIsNotNone(result.exception) - def test_computer_rename(self): + def test_computer_relabel(self): """ - Test if 'verdi computer rename' command works + Test if 'verdi computer relabel' command works """ from aiida.common.exceptions import NotExistent # See if the command complains about not getting an invalid computer - options = ['not_existent_computer_name'] - result = self.cli_runner.invoke(computer_rename, options) - # Exception should be raised + options = ['not_existent_computer_label'] + result = self.cli_runner.invoke(computer_relabel, options) self.assertIsNotNone(result.exception) - # See if the command complains about not getting both names + # See if the command complains about not getting both labels options = ['comp_cli_test_computer'] - result = self.cli_runner.invoke(computer_rename, options) - # Exception should be raised + result = self.cli_runner.invoke(computer_relabel, options) self.assertIsNotNone(result.exception) - # The new name must be different to the old one + # The new label must be different to the old one options = ['comp_cli_test_computer', 'comp_cli_test_computer'] - result = self.cli_runner.invoke(computer_rename, options) - # Exception should be raised + result = self.cli_runner.invoke(computer_relabel, options) self.assertIsNotNone(result.exception) - # Change a computer name successully. - options = ['comp_cli_test_computer', 'renamed_test_computer'] - result = self.cli_runner.invoke(computer_rename, options) - # Exception should be not be raised + # Change a computer label successully. 
+ options = ['comp_cli_test_computer', 'relabeled_test_computer'] + result = self.cli_runner.invoke(computer_relabel, options) self.assertIsNone(result.exception, result.output) - # Check that the name really was changed - # The old name should not be available + # Check that the label really was changed + # The old label should not be available with self.assertRaises(NotExistent): - orm.Computer.objects.get(name='comp_cli_test_computer') - # The new name should be avilable - orm.Computer.objects.get(name='renamed_test_computer') + orm.Computer.objects.get(label='comp_cli_test_computer') + # The new label should be available + orm.Computer.objects.get(label='relabeled_test_computer') - # Now change the name back - options = ['renamed_test_computer', 'comp_cli_test_computer'] - result = self.cli_runner.invoke(computer_rename, options) - # Exception should be not be raised + # Now change the label back + options = ['relabeled_test_computer', 'comp_cli_test_computer'] + result = self.cli_runner.invoke(computer_relabel, options) self.assertIsNone(result.exception, result.output) - # Check that the name really was changed - # The old name should not be available + # Check that the label really was changed + # The old label should not be available with self.assertRaises(NotExistent): - orm.Computer.objects.get(name='renamed_test_computer') - # The new name should be avilable - orm.Computer.objects.get(name='comp_cli_test_computer') + orm.Computer.objects.get(label='relabeled_test_computer') + # The new label should be available + orm.Computer.objects.get(label='comp_cli_test_computer') def test_computer_delete(self): """ @@ -655,7 +658,7 @@ def test_computer_delete(self): # Setup a computer to delete during the test orm.Computer( - name='computer_for_test_delete', + label='computer_for_test_delete', hostname='localhost', transport_type='local', scheduler_type='direct', @@ -675,7 +678,7 @@ def test_computer_delete(self): self.assertClickResultNoException(result) # Check that the computer really was deleted with self.assertRaises(NotExistent): - orm.Computer.objects.get(name='computer_for_test_delete') + orm.Computer.objects.get(label='computer_for_test_delete') def test_computer_duplicate_interactive(self): """Test 'verdi computer duplicate' in interactive mode.""" @@ -688,11 +691,11 @@ def test_computer_duplicate_interactive(self): ) self.assertIsNone(result.exception, result.output) - new_computer = orm.Computer.objects.get(name=label) + new_computer = orm.Computer.objects.get(label=label) self.assertEqual(self.comp.description, new_computer.description) - self.assertEqual(self.comp.get_hostname(), new_computer.get_hostname()) - self.assertEqual(self.comp.get_transport_type(), new_computer.get_transport_type()) - self.assertEqual(self.comp.get_scheduler_type(), new_computer.get_scheduler_type()) + self.assertEqual(self.comp.hostname, new_computer.hostname) + self.assertEqual(self.comp.transport_type, new_computer.transport_type) + self.assertEqual(self.comp.scheduler_type, new_computer.scheduler_type) self.assertEqual(self.comp.get_shebang(), new_computer.get_shebang()) self.assertEqual(self.comp.get_workdir(), new_computer.get_workdir()) self.assertEqual(self.comp.get_mpirun_command(), new_computer.get_mpirun_command()) @@ -709,11 +712,11 @@ def test_computer_duplicate_non_interactive(self): ) self.assertIsNone(result.exception, result.output) - new_computer = orm.Computer.objects.get(name=label) + new_computer = orm.Computer.objects.get(label=label) self.assertEqual(self.comp.description, 
new_computer.description) - self.assertEqual(self.comp.get_hostname(), new_computer.get_hostname()) - self.assertEqual(self.comp.get_transport_type(), new_computer.get_transport_type()) - self.assertEqual(self.comp.get_scheduler_type(), new_computer.get_scheduler_type()) + self.assertEqual(self.comp.hostname, new_computer.hostname) + self.assertEqual(self.comp.transport_type, new_computer.transport_type) + self.assertEqual(self.comp.scheduler_type, new_computer.scheduler_type) self.assertEqual(self.comp.get_shebang(), new_computer.get_shebang()) self.assertEqual(self.comp.get_workdir(), new_computer.get_workdir()) self.assertEqual(self.comp.get_mpirun_command(), new_computer.get_mpirun_command()) @@ -736,13 +739,13 @@ def test_interactive(clear_database_before_test, aiida_localhost, non_interactiv result = CliRunner().invoke(computer_setup, input=user_input) assert result.exception is None, 'There was an unexpected exception. Output: {}'.format(result.output) - new_computer = orm.Computer.objects.get(name=label) + new_computer = orm.Computer.objects.get(label=label) assert isinstance(new_computer, orm.Computer) assert new_computer.description == options_dict['description'] assert new_computer.hostname == options_dict['hostname'] - assert new_computer.get_transport_type() == options_dict['transport'] - assert new_computer.get_scheduler_type() == options_dict['scheduler'] + assert new_computer.transport_type == options_dict['transport'] + assert new_computer.scheduler_type == options_dict['scheduler'] assert new_computer.get_mpirun_command() == options_dict['mpirun-command'].split() assert new_computer.get_shebang() == options_dict['shebang'] assert new_computer.get_workdir() == options_dict['work-dir'] @@ -750,3 +753,39 @@ def test_interactive(clear_database_before_test, aiida_localhost, non_interactiv # For now I'm not writing anything in them assert new_computer.get_prepend_text() == '' assert new_computer.get_append_text() == '' + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_computer_test_stderr(run_cli_command, aiida_localhost, monkeypatch): + """Test `verdi computer test` where tested command returns non-empty stderr.""" + from aiida.transports.plugins.local import LocalTransport + + aiida_localhost.configure() + stderr = 'spurious output in standard error' + + def exec_command_wait(self, command, **kwargs): + return 0, '', stderr + + monkeypatch.setattr(LocalTransport, 'exec_command_wait', exec_command_wait) + + result = run_cli_command(computer_test, [aiida_localhost.label]) + assert 'Warning: 1 out of 5 tests failed' in result.output + assert stderr in result.output + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_computer_test_stdout(run_cli_command, aiida_localhost, monkeypatch): + """Test `verdi computer test` where tested command returns non-empty stdout.""" + from aiida.transports.plugins.local import LocalTransport + + aiida_localhost.configure() + stdout = 'spurious output in standard output' + + def exec_command_wait(self, command, **kwargs): + return 0, stdout, '' + + monkeypatch.setattr(LocalTransport, 'exec_command_wait', exec_command_wait) + + result = run_cli_command(computer_test, [aiida_localhost.label]) + assert 'Warning: 1 out of 5 tests failed' in result.output + assert stdout in result.output diff --git a/tests/cmdline/commands/test_daemon.py b/tests/cmdline/commands/test_daemon.py index 8d7d85cf67..40cb3c7117 100644 --- a/tests/cmdline/commands/test_daemon.py +++ b/tests/cmdline/commands/test_daemon.py @@ -86,7 +86,8 @@ 
def test_daemon_restart(self): finally: self.daemon_client.stop_daemon(wait=True) - @pytest.mark.skip(reason='Test fails non-deterministically; see issue #3051.') + # Tracked in issue #3051 + @pytest.mark.flaky(reruns=2) def test_daemon_start_number(self): """Test `verdi daemon start` with a specific number of workers.""" @@ -111,7 +112,8 @@ def test_daemon_start_number(self): finally: self.daemon_client.stop_daemon(wait=True) - @pytest.mark.skip(reason='Test fails non-deterministically; see issue #3051.') + # Tracked in issue #3051 + @pytest.mark.flaky(reruns=2) def test_daemon_start_number_config(self): """Test `verdi daemon start` with `daemon.default_workers` config option being set.""" number = 3 diff --git a/tests/cmdline/commands/test_data.py b/tests/cmdline/commands/test_data.py index e4d507235a..7d9609e327 100644 --- a/tests/cmdline/commands/test_data.py +++ b/tests/cmdline/commands/test_data.py @@ -27,6 +27,7 @@ from aiida.engine import calcfunction from aiida.orm.nodes.data.cif import has_pycifrw from aiida.orm import Group, ArrayData, BandsData, KpointsData, CifData, Dict, RemoteData, StructureData, TrajectoryData +from tests.static import STATIC_DIR class DummyVerdiDataExportable: @@ -341,7 +342,7 @@ def test_bandsexport(self): options = [str(self.ids[DummyVerdiDataListable.NODE_ID_STR])] res = self.cli_runner.invoke(cmd_bands.bands_export, options, catch_exceptions=False) self.assertEqual(res.exit_code, 0, 'The command did not finish correctly') - self.assertIn(b'[1.0, 3.0]', res.stdout_bytes, 'The string [1.0, 3.0] was not found in the bands' 'export') + self.assertIn(b'[1.0, 3.0]', res.stdout_bytes, 'The string [1.0, 3.0] was not found in the bands export') def test_bandsexport_single_kp(self): """ @@ -532,7 +533,7 @@ def create_trajectory_data(): def setUpClass(cls): # pylint: disable=arguments-differ super().setUpClass() orm.Computer( - name='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' + label='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' ).store() cls.ids = cls.create_trajectory_data() @@ -613,7 +614,7 @@ def create_structure_data(): def setUpClass(cls): # pylint: disable=arguments-differ super().setUpClass() orm.Computer( - name='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' + label='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' ).store() cls.ids = cls.create_structure_data() @@ -795,7 +796,7 @@ def setUpClass(cls): # pylint: disable=arguments-differ """Setup class to test CifData.""" super().setUpClass() orm.Computer( - name='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' + label='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' ).store() cls.ids = cls.create_cif_data() @@ -897,7 +898,7 @@ def setUpClass(cls): # pylint: disable=arguments-differ super().setUpClass() def setUp(self): - self.filepath_pseudos = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, 'fixtures', 'pseudos') + self.filepath_pseudos = os.path.join(STATIC_DIR, 'pseudos') self.cli_runner = CliRunner() def upload_family(self): diff --git a/tests/cmdline/commands/test_export.py b/tests/cmdline/commands/test_export.py index 2683f745df..9a404bc3f3 100644 --- a/tests/cmdline/commands/test_export.py +++ b/tests/cmdline/commands/test_export.py @@ -49,7 +49,7 @@ def setUpClass(cls, 
*args, **kwargs): from aiida import orm cls.computer = orm.Computer( - name='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' + label='comp', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' ).store() cls.code = orm.Code(remote_computer_exec=(cls.computer, '/bin/true')).store() @@ -181,18 +181,6 @@ def test_migrate_version_specific(self): finally: delete_temporary_file(filename_output) - def test_migrate_versions_recent(self): - """Migrating an archive with the current version should exit with non-zero status.""" - filename_input = get_archive_file(self.newest_archive, filepath=self.fixture_archive) - filename_output = next(tempfile._get_candidate_names()) # pylint: disable=protected-access - - try: - options = [filename_input, filename_output] - result = self.cli_runner.invoke(cmd_export.migrate, options) - self.assertIsNotNone(result.exception) - finally: - delete_temporary_file(filename_output) - def test_migrate_force(self): """Test that passing the -f/--force option will overwrite the output file even if it exists.""" filename_input = get_archive_file(self.penultimate_archive, filepath=self.fixture_archive) @@ -213,6 +201,39 @@ def test_migrate_force(self): self.assertTrue(os.path.isfile(filename_output)) self.assertEqual(zipfile.ZipFile(filename_output).testzip(), None) + def test_migrate_in_place(self): + """Test that passing the -i/--in-place option will overwrite the passed file.""" + archive = 'export_v0.1_simple.aiida' + target_version = '0.2' + filename_input = get_archive_file(archive, filepath=self.fixture_archive) + filename_tmp = next(tempfile._get_candidate_names()) # pylint: disable=protected-access + + try: + # copy file (don't want to overwrite test data) + shutil.copy(filename_input, filename_tmp) + + # specifying both output and in-place should except + options = [filename_tmp, '--in-place', '--output-file', 'test.aiida'] + result = self.cli_runner.invoke(cmd_export.migrate, options) + self.assertIsNotNone(result.exception, result.output) + + # specifying neither output nor in-place should except + options = [filename_tmp] + result = self.cli_runner.invoke(cmd_export.migrate, options) + self.assertIsNotNone(result.exception, result.output) + + # check that in-place migration produces a valid archive in place of the old file + options = [filename_tmp, '--in-place', '--version', target_version] + result = self.cli_runner.invoke(cmd_export.migrate, options) + self.assertIsNone(result.exception, result.output) + self.assertTrue(os.path.isfile(filename_tmp)) + # check that files in zip file are ok + self.assertEqual(zipfile.ZipFile(filename_tmp).testzip(), None) + with Archive(filename_tmp) as archive_object: + self.assertEqual(archive_object.version_format, target_version) + finally: + os.remove(filename_tmp) + def test_migrate_silent(self): """Test that the captured output is an empty string when the -s/--silent option is passed.""" filename_input = get_archive_file(self.penultimate_archive, filepath=self.fixture_archive) diff --git a/tests/cmdline/commands/test_node.py b/tests/cmdline/commands/test_node.py index 733529eb85..c91fc6dc00 100644 --- a/tests/cmdline/commands/test_node.py +++ b/tests/cmdline/commands/test_node.py @@ -61,8 +61,8 @@ def setUpClass(cls, *args, **kwargs): cls.content_file2 = 'the minister of silly walks' cls.key_file1 = 'some/nested/folder/filename.txt' cls.key_file2 = 'some_other_file.txt' - 
folder_node.put_object_from_filelike(io.StringIO(cls.content_file1), key=cls.key_file1) - folder_node.put_object_from_filelike(io.StringIO(cls.content_file2), key=cls.key_file2) + folder_node.put_object_from_filelike(io.StringIO(cls.content_file1), cls.key_file1) + folder_node.put_object_from_filelike(io.StringIO(cls.content_file2), cls.key_file2) folder_node.store() cls.folder_node = folder_node diff --git a/tests/cmdline/commands/test_process.py b/tests/cmdline/commands/test_process.py index 642aac19c3..287c82e8e3 100644 --- a/tests/cmdline/commands/test_process.py +++ b/tests/cmdline/commands/test_process.py @@ -162,6 +162,10 @@ def setUpClass(cls, *args, **kwargs): if state == ProcessState.FINISHED: calc.set_exit_status(1) + # Set the waiting work chain as paused as well + if state == ProcessState.WAITING: + calc.pause() + calc.store() cls.calcs.append(calc) @@ -194,13 +198,13 @@ def test_list(self): flag_value = 'asc' result = self.cli_runner.invoke(cmd_process.process_list, ['-r', '-O', 'id', flag, flag_value]) self.assertIsNone(result.exception, result.output) - result_num_asc = [l.split()[0] for l in get_result_lines(result)] + result_num_asc = [line.split()[0] for line in get_result_lines(result)] self.assertEqual(len(result_num_asc), 6) flag_value = 'desc' result = self.cli_runner.invoke(cmd_process.process_list, ['-r', '-O', 'id', flag, flag_value]) self.assertIsNone(result.exception, result.output) - result_num_desc = [l.split()[0] for l in get_result_lines(result)] + result_num_desc = [line.split()[0] for line in get_result_lines(result)] self.assertEqual(len(result_num_desc), 6) self.assertEqual(result_num_asc, list(reversed(result_num_desc))) @@ -262,6 +266,12 @@ def test_list(self): for line in get_result_lines(result): self.assertIn(self.process_label, line.strip()) + # There should be exactly one paused + for flag in ['--paused']: + result = self.cli_runner.invoke(cmd_process.process_list, ['-r', flag]) + self.assertClickResultNoException(result) + self.assertEqual(len(get_result_lines(result)), 1) + def test_process_show(self): """Test verdi process show""" # We must choose a Node we can store diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 4ed690bb20..595ed2e131 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -194,7 +194,7 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals ArithmeticAdd = CalculationFactory('arithmetic.add') computer = Computer( - name='localhost-example-{}'.format(sys.argv[1]), + label='localhost-example-{}'.format(sys.argv[1]), hostname='localhost', description='my computer', transport_type='local', diff --git a/tests/cmdline/commands/test_status.py b/tests/cmdline/commands/test_status.py index 83868196f6..4818be3d39 100644 --- a/tests/cmdline/commands/test_status.py +++ b/tests/cmdline/commands/test_status.py @@ -8,12 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi status`.""" +import pytest + from aiida.cmdline.commands import cmd_status from aiida.cmdline.utils.echo import ExitCode def test_status(run_cli_command): - """Test running verdi status.""" + """Test `verdi status`.""" options = [] result = run_cli_command(cmd_status.verdi_status, options) @@ -25,8 +27,16 @@ def test_status(run_cli_command): assert string in result.output +@pytest.mark.usefixtures('create_empty_config_instance') +def 
test_status_no_profile(run_cli_command): + """Test `verdi status` when there is no profile.""" + options = [] + result = run_cli_command(cmd_status.verdi_status, options) + assert 'no profile configured yet' in result.output + + def test_status_no_rmq(run_cli_command): - """Test running verdi status, with no rmq check""" + """Test `verdi status` without a check for RabbitMQ.""" options = ['--no-rmq'] result = run_cli_command(cmd_status.verdi_status, options) @@ -35,3 +45,18 @@ def test_status_no_rmq(run_cli_command): for string in ['config', 'profile', 'postgres', 'daemon']: assert string in result.output + + +def test_database_incompatible(run_cli_command, monkeypatch): + """Test `verdi status` when database schema version is incompatible with that of the code.""" + from aiida.manage.manager import get_manager + + def get_backend(): + from aiida.common.exceptions import IncompatibleDatabaseSchema + raise IncompatibleDatabaseSchema() + + monkeypatch.setattr(get_manager(), 'get_backend', get_backend) + + result = run_cli_command(cmd_status.verdi_status, raises=True) + assert 'Database schema version is incompatible with the code: run `verdi database migrate`.' in result.output + assert result.exit_code is ExitCode.CRITICAL diff --git a/tests/cmdline/params/options/test_interactive.py b/tests/cmdline/params/options/test_interactive.py index 04899ab19a..bca18a7c7a 100644 --- a/tests/cmdline/params/options/test_interactive.py +++ b/tests/cmdline/params/options/test_interactive.py @@ -281,7 +281,7 @@ def test_non_interactive(self): def test_non_interactive_default(self): """ scenario: InteractiveOption, invoked with only --non-interactive - behaviour: fail + behaviour: success """ cmd = self.simple_command(default='default') runner = CliRunner() diff --git a/tests/cmdline/params/types/test_code.py b/tests/cmdline/params/types/test_code.py index a2464f64d7..a3fd64e954 100644 --- a/tests/cmdline/params/types/test_code.py +++ b/tests/cmdline/params/types/test_code.py @@ -69,7 +69,7 @@ def test_get_by_label(setup_codes, parameter_type): def test_get_by_fullname(setup_codes, parameter_type): """Verify that using the LABEL@machinename will retrieve the correct entity.""" entity_01, entity_02, entity_03 = setup_codes - identifier = '{}@{}'.format(entity_01.label, entity_01.computer.name) + identifier = '{}@{}'.format(entity_01.label, entity_01.computer.label) result = parameter_type.convert(identifier, None, None) assert result.uuid == entity_01.uuid diff --git a/tests/cmdline/params/types/test_computer.py b/tests/cmdline/params/types/test_computer.py index b6523a52c8..f365e9684f 100644 --- a/tests/cmdline/params/types/test_computer.py +++ b/tests/cmdline/params/types/test_computer.py @@ -36,9 +36,9 @@ def setUpClass(cls, *args, **kwargs): } cls.param = ComputerParamType() - cls.entity_01 = orm.Computer(name='computer_01', **kwargs).store() - cls.entity_02 = orm.Computer(name=str(cls.entity_01.pk), **kwargs).store() - cls.entity_03 = orm.Computer(name=str(cls.entity_01.uuid), **kwargs).store() + cls.entity_01 = orm.Computer(label='computer_01', **kwargs).store() + cls.entity_02 = orm.Computer(label=str(cls.entity_01.pk), **kwargs).store() + cls.entity_03 = orm.Computer(label=str(cls.entity_01.uuid), **kwargs).store() def test_get_by_id(self): """ @@ -60,7 +60,7 @@ def test_get_by_label(self): """ Verify that using the LABEL will retrieve the correct entity """ - identifier = '{}'.format(self.entity_01.name) + identifier = '{}'.format(self.entity_01.label) result = self.param.convert(identifier, None, 
None) self.assertEqual(result.uuid, self.entity_01.uuid) @@ -71,11 +71,11 @@ def test_ambiguous_label_pk(self): Verify that using an ambiguous identifier gives precedence to the ID interpretation Appending the special ambiguity breaker character will force the identifier to be treated as a LABEL """ - identifier = '{}'.format(self.entity_02.name) + identifier = '{}'.format(self.entity_02.label) result = self.param.convert(identifier, None, None) self.assertEqual(result.uuid, self.entity_01.uuid) - identifier = '{}{}'.format(self.entity_02.name, OrmEntityLoader.label_ambiguity_breaker) + identifier = '{}{}'.format(self.entity_02.label, OrmEntityLoader.label_ambiguity_breaker) result = self.param.convert(identifier, None, None) self.assertEqual(result.uuid, self.entity_02.uuid) @@ -86,10 +86,10 @@ def test_ambiguous_label_uuid(self): Verify that using an ambiguous identifier gives precedence to the UUID interpretation Appending the special ambiguity breaker character will force the identifier to be treated as a LABEL """ - identifier = '{}'.format(self.entity_03.name) + identifier = '{}'.format(self.entity_03.label) result = self.param.convert(identifier, None, None) self.assertEqual(result.uuid, self.entity_01.uuid) - identifier = '{}{}'.format(self.entity_03.name, OrmEntityLoader.label_ambiguity_breaker) + identifier = '{}{}'.format(self.entity_03.label, OrmEntityLoader.label_ambiguity_breaker) result = self.param.convert(identifier, None, None) self.assertEqual(result.uuid, self.entity_03.uuid) diff --git a/tests/cmdline/utils/test_common.py b/tests/cmdline/utils/test_common.py index c755f34345..0bd6480cd9 100644 --- a/tests/cmdline/utils/test_common.py +++ b/tests/cmdline/utils/test_common.py @@ -21,7 +21,7 @@ def test_get_node_summary(self): """Test the `get_node_summary` utility.""" from aiida.cmdline.utils.common import get_node_summary - computer_label = self.computer.name # pylint: disable=no-member + computer_label = self.computer.label # pylint: disable=no-member code = orm.Code( input_plugin_name='arithmetic.add', diff --git a/tests/cmdline/utils/test_multiline.py b/tests/cmdline/utils/test_multiline.py index 8731972f30..fb7cc168f7 100644 --- a/tests/cmdline/utils/test_multiline.py +++ b/tests/cmdline/utils/test_multiline.py @@ -11,41 +11,11 @@ """Unit tests for editing pre and post bash scripts, comments, etc.""" import pytest -from aiida.cmdline.utils.multi_line_input import edit_pre_post, edit_comment +from aiida.cmdline.utils.multi_line_input import edit_comment COMMAND = 'sleep 1 ; vim -c "g!/^#=/s/$/Test" -cwq' # Appends `Test` to every line NOT starting with `#=` -@pytest.mark.parametrize('non_interactive_editor', (COMMAND,), indirect=True) -def test_pre_post(non_interactive_editor): - result = edit_pre_post(summary={'Param 1': 'Value 1', 'Param 2': 'Value 1'}) - assert result[0] == 'Test\nTest\nTest' - assert result[1] == 'Test\nTest\nTest' - - -@pytest.mark.parametrize('non_interactive_editor', (COMMAND,), indirect=True) -def test_edit_pre_post(non_interactive_editor): - result = edit_pre_post(pre='OldPre', post='OldPost') - assert result[0] == 'Test\nOldPreTest\nTest' - assert result[1] == 'Test\nOldPostTest\nTest' - - -@pytest.mark.parametrize('non_interactive_editor', (COMMAND,), indirect=True) -def test_edit_pre_post_comment(non_interactive_editor): - """Test that lines starting with '#=' are ignored and are not ignored if they start with any other character.""" - result = edit_pre_post(pre='OldPre\n#=Delete me', post='OldPost #=Dont delete me') - assert result[0] == 
'Test\nOldPreTest\nTest' - assert result[1] == 'Test\nOldPost #=Dont delete meTest\nTest' - - -@pytest.mark.parametrize('non_interactive_editor', (COMMAND,), indirect=True) -def test_edit_pre_bash_comment(non_interactive_editor): - """Test that bash comments starting with '#' are NOT deleted.""" - result = edit_pre_post(pre='OldPre\n# Dont delete me', post='OldPost # Dont delete me') - assert result[0] == 'Test\nOldPreTest\n# Dont delete meTest\nTest' - assert result[1] == 'Test\nOldPost # Dont delete meTest\nTest' - - @pytest.mark.parametrize('non_interactive_editor', (COMMAND,), indirect=True) def test_new_comment(non_interactive_editor): new_comment = edit_comment() diff --git a/tests/cmdline/utils/test_repository.py b/tests/cmdline/utils/test_repository.py index e65823ab80..0fae8bf272 100644 --- a/tests/cmdline/utils/test_repository.py +++ b/tests/cmdline/utils/test_repository.py @@ -28,8 +28,8 @@ def runner(): def folder_data(): """Create a `FolderData` instance with basic file and directory structure.""" node = FolderData() - node.put_object_from_filelike(io.StringIO(''), key='nested/file.txt') - node.put_object_from_filelike(io.StringIO(''), key='file.txt') + node.put_object_from_filelike(io.StringIO(''), 'nested/file.txt') + node.put_object_from_filelike(io.StringIO(''), 'file.txt') return node diff --git a/tests/common/test_serialize.py b/tests/common/test_serialize.py index 1a83dea592..720456678f 100644 --- a/tests/common/test_serialize.py +++ b/tests/common/test_serialize.py @@ -75,7 +75,7 @@ def test_serialize_computer_round_trip(self): # pylint: disable=no-member self.assertEqual(computer.uuid, deserialized.uuid) - self.assertEqual(computer.name, deserialized.name) + self.assertEqual(computer.label, deserialized.label) def test_serialize_unstored_node(self): """Test that you can't serialize an unstored node""" diff --git a/tests/conftest.py b/tests/conftest.py index f68803aa24..fa9876bdff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,8 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=redefined-outer-name """Configuration file for pytest tests.""" -import pytest # pylint: disable=unused-import +import os + +import pytest + +from aiida.manage.configuration import Config, Profile, get_config pytest_plugins = ['aiida.manage.tests.pytest_fixtures'] # pylint: disable=invalid-name @@ -24,7 +29,6 @@ def non_interactive_editor(request): :param request: the command to set for the editor that is to be called """ - import os from unittest.mock import patch from click._termui_impl import Editor @@ -32,7 +36,6 @@ def non_interactive_editor(request): os.environ['VISUAL'] = request.param def edit_file(self, filename): - import os import subprocess import click @@ -74,21 +77,23 @@ def generate_calc_job(): to it, into which the raw input files will have been written. 
""" - def _generate_calc_job(folder, entry_point_name, inputs=None): + def _generate_calc_job(folder, entry_point_name, inputs=None, return_process=False): """Fixture to generate a mock `CalcInfo` for testing calculation jobs.""" from aiida.engine.utils import instantiate_process from aiida.manage.manager import get_manager from aiida.plugins import CalculationFactory + inputs = inputs or {} manager = get_manager() runner = manager.get_runner() process_class = CalculationFactory(entry_point_name) process = instantiate_process(runner, process_class, **inputs) - calc_info = process.prepare_for_submission(folder) + if return_process: + return process - return calc_info + return process.prepare_for_submission(folder) return _generate_calc_job @@ -144,3 +149,113 @@ def _generate_calculation_node(process_state=ProcessState.FINISHED, exit_status= return node return _generate_calculation_node + + +@pytest.fixture +def create_empty_config_instance(tmp_path) -> Config: + """Create a temporary configuration instance. + + This creates a temporary directory with a clean `.aiida` folder and basic configuration file. The currently loaded + configuration and profile are stored in memory and are automatically restored at the end of this context manager. + + :return: a new empty config instance. + """ + from aiida.common.utils import Capturing + from aiida.manage import configuration + from aiida.manage.configuration import settings, load_profile, reset_profile + + # Store the current configuration instance and config directory path + current_config = configuration.CONFIG + current_config_path = current_config.dirpath + current_profile_name = configuration.PROFILE.name + + reset_profile() + configuration.CONFIG = None + + # Create a temporary folder, set it as the current config directory path and reset the loaded configuration + settings.AIIDA_CONFIG_FOLDER = str(tmp_path) + + # Create the instance base directory structure, the config file and a dummy profile + settings.create_instance_directories() + + # The constructor of `Config` called by `load_config` will print warning messages about migrating it + with Capturing(): + configuration.CONFIG = configuration.load_config(create=True) + + yield get_config() + + # Reset the config folder path and the config instance. Note this will always be executed after the yield no + # matter what happened in the test that used this fixture. + reset_profile() + settings.AIIDA_CONFIG_FOLDER = current_config_path + configuration.CONFIG = current_config + load_profile(current_profile_name) + + +@pytest.fixture +def create_profile() -> Profile: + """Create a new profile instance. + + :return: the profile instance. 
+ """ + + def _create_profile(name, **kwargs): + + repository_dirpath = kwargs.pop('repository_dirpath', get_config().dirpath) + + profile_dictionary = { + 'default_user': kwargs.pop('default_user', 'dummy@localhost'), + 'database_engine': kwargs.pop('database_engine', 'postgresql_psycopg2'), + 'database_backend': kwargs.pop('database_backend', 'django'), + 'database_hostname': kwargs.pop('database_hostname', 'localhost'), + 'database_port': kwargs.pop('database_port', 5432), + 'database_name': kwargs.pop('database_name', name), + 'database_username': kwargs.pop('database_username', 'user'), + 'database_password': kwargs.pop('database_password', 'pass'), + 'repository_uri': 'file:///' + os.path.join(repository_dirpath, 'repository_' + name), + } + + return Profile(name, profile_dictionary) + + return _create_profile + + +@pytest.fixture +def backend(): + """Get the ``Backend`` instance of the currently loaded profile.""" + from aiida.manage.manager import get_manager + return get_manager().get_backend() + + +@pytest.fixture +def skip_if_not_django(backend): + """Fixture that will skip any test that uses it when a profile is loaded with any backend other than Django.""" + from aiida.orm.implementation.django.backend import DjangoBackend + if not isinstance(backend, DjangoBackend): + pytest.skip('this test should only be run for the Django backend.') + + +@pytest.fixture +def skip_if_not_sqlalchemy(backend): + """Fixture that will skip any test that uses it when a profile is loaded with any backend other than SqlAlchemy.""" + from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend + if not isinstance(backend, SqlaBackend): + pytest.skip('this test should only be run for the SqlAlchemy backend.') + + +@pytest.fixture(scope='function') +def override_logging(): + """Temporarily set the AiiDA and database log levels to ``DEBUG`` and reconfigure logging.""" + from aiida.common.log import configure_logging + + config = get_config() + + try: + config.set_option('logging.aiida_loglevel', 'DEBUG') + config.set_option('logging.db_loglevel', 'DEBUG') + configure_logging(with_orm=True) + yield + finally: + config.unset_option('logging.aiida_loglevel') + config.unset_option('logging.db_loglevel') + configure_logging(with_orm=True) diff --git a/tests/engine/daemon/test_execmanager.py b/tests/engine/daemon/test_execmanager.py new file mode 100644 index 0000000000..dbc9aec95c --- /dev/null +++ b/tests/engine/daemon/test_execmanager.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code.
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the :mod:`aiida.engine.daemon.execmanager` module.""" +import os +import pytest + +from aiida.engine.daemon import execmanager +from aiida.transports.plugins.local import LocalTransport + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_retrieve_files_from_list(tmp_path_factory, generate_calculation_node): + """Test the `retrieve_files_from_list` function.""" + node = generate_calculation_node() + + retrieve_list = [ + 'file_a.txt', + ('sub/folder', 'sub/folder', 0), + ] + + source = tmp_path_factory.mktemp('source') + target = tmp_path_factory.mktemp('target') + + content_a = b'content_a' + content_b = b'content_b' + + with open(str(source / 'file_a.txt'), 'wb') as handle: + handle.write(content_a) + handle.flush() + + os.makedirs(str(source / 'sub' / 'folder')) + + with open(str(source / 'sub' / 'folder' / 'file_b.txt'), 'wb') as handle: + handle.write(content_b) + handle.flush() + + with LocalTransport() as transport: + transport.chdir(str(source)) + execmanager.retrieve_files_from_list(node, transport, str(target), retrieve_list) + + assert sorted(os.listdir(str(target))) == sorted(['file_a.txt', 'sub']) + assert os.listdir(str(target / 'sub')) == ['folder'] + assert os.listdir(str(target / 'sub' / 'folder')) == ['file_b.txt'] + + with open(str(target / 'sub' / 'folder' / 'file_b.txt'), 'rb') as handle: + assert handle.read() == content_b + + with open(str(target / 'file_a.txt'), 'rb') as handle: + assert handle.read() == content_a diff --git a/tests/engine/processes/text_exit_code.py b/tests/engine/processes/test_exit_code.py similarity index 100% rename from tests/engine/processes/text_exit_code.py rename to tests/engine/processes/test_exit_code.py diff --git a/tests/engine/processes/workchains/test_restart.py b/tests/engine/processes/workchains/test_restart.py index 5c28a06585..dbea7970b4 100644 --- a/tests/engine/processes/workchains/test_restart.py +++ b/tests/engine/processes/workchains/test_restart.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `aiida.engine.processes.workchains.restart` module.""" -# pylint: disable=invalid-name,inconsistent-return-statements,no-self-use,no-member +# pylint: disable=invalid-name,no-self-use,no-member import pytest from aiida import engine diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index 51a787235e..efcc28537b 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -217,3 +217,22 @@ def disabled_handler(self, node): pass assert not SomeWorkChain.disabled_handler.enabled # pylint: disable=no-member + + def test_empty_exit_codes_list(self): + """A `process_handler` with an empty `exit_codes` list should not run.""" + + class SomeWorkChain(BaseRestartWorkChain): + _process_class = ArithmeticAddCalculation + + @process_handler(exit_codes=[]) + def should_not_run(self, node): + raise ValueError('This should not run.') + + child = ProcessNode() + child.set_process_state(ProcessState.FINISHED) + + process = SomeWorkChain() + process.setup() + process.ctx.iteration = 1 + 
process.ctx.children = [child] + process.inspect_process() diff --git a/tests/engine/test_calc_job.py b/tests/engine/test_calc_job.py index 38ef789484..94d18b6efe 100644 --- a/tests/engine/test_calc_job.py +++ b/tests/engine/test_calc_job.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-public-methods,redefined-outer-name """Test for the `CalcJob` process sub class.""" from copy import deepcopy from functools import partial @@ -17,8 +18,8 @@ from aiida import orm from aiida.backends.testbase import AiidaTestCase -from aiida.common import exceptions -from aiida.engine import launch, CalcJob, Process +from aiida.common import exceptions, LinkType, CalcJobState +from aiida.engine import launch, CalcJob, Process, ExitCode from aiida.engine.processes.ports import PortNamespace from aiida.plugins import CalculationFactory @@ -81,6 +82,23 @@ def setUpClass(cls, *args, **kwargs): cls.local_code = orm.Code(local_executable='bash', files=['/bin/bash']).store() cls.inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'metadata': {'options': {}}} + def instantiate_process(self, state=CalcJobState.PARSING): + """Instantiate a process with default inputs and return the `Process` instance.""" + from aiida.engine.utils import instantiate_process + from aiida.manage.manager import get_manager + + inputs = deepcopy(self.inputs) + inputs['code'] = self.remote_code + + manager = get_manager() + runner = manager.get_runner() + + process_class = CalculationFactory('arithmetic.add') + process = instantiate_process(runner, process_class, **inputs) + process.node.set_state(state) + + return process + def setUp(self): super().setUp() self.assertIsNone(Process.current()) @@ -318,3 +336,225 @@ def test_provenance_exclude_list(self): self.assertIn('base', node.list_object_names()) self.assertEqual(sorted(['b']), sorted(node.list_object_names(os.path.join('base')))) self.assertEqual(['two'], node.list_object_names(os.path.join('base', 'b'))) + + def test_parse_no_retrieved_folder(self): + """Test the `CalcJob.parse` method when there is no retrieved folder.""" + process = self.instantiate_process() + exit_code = process.parse() + assert exit_code == process.exit_codes.ERROR_NO_RETRIEVED_FOLDER + + def test_parse_retrieved_folder(self): + """Test the `CalcJob.parse` method when there is a retrieved folder.""" + process = self.instantiate_process() + retrieved = orm.FolderData().store() + retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) + exit_code = process.parse() + + # The following exit code is specific to the `ArithmeticAddCalculation` we are testing here and is returned + # because the retrieved folder does not contain the output file it expects + assert exit_code == process.exit_codes.ERROR_READING_OUTPUT_FILE + + +@pytest.fixture +def process(aiida_local_code_factory): + """Instantiate a process with default inputs and return the `Process` instance.""" + from aiida.engine.utils import instantiate_process + from aiida.manage.manager import get_manager + + inputs = { + 'code': aiida_local_code_factory('arithmetic.add', '/bin/bash'), + 'x': orm.Int(1), + 'y': orm.Int(2), + 'metadata': { + 'options': {} + } + } + + manager = get_manager() + runner = manager.get_runner() + + process_class = CalculationFactory('arithmetic.add') + process = instantiate_process(runner, process_class, 
**inputs) + process.node.set_state(CalcJobState.PARSING) + + return process + + +@pytest.mark.usefixtures('clear_database_before_test', 'override_logging') +def test_parse_insufficient_data(process): + """Test the scheduler output parsing logic in `CalcJob.parse`. + + Here we check explicitly that the parsing does not except even if the required information is not available. + """ + retrieved = orm.FolderData().store() + retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) + process.parse() + + filename_stderr = process.node.get_option('scheduler_stderr') + filename_stdout = process.node.get_option('scheduler_stdout') + + # The scheduler parsing requires three resources of information, the `detailed_job_info` dictionary which is + # stored as an attribute on the calculation job node and the output of the stdout and stderr which are both + # stored in the repository. In this test, we haven't created these on purpose. This should not except the + # process but should log a warning, so here we check that those expected warnings are attached to the node + logs = [log.message for log in orm.Log.objects.get_logs_for(process.node)] + expected_logs = [ + 'could not parse scheduler output: the `detailed_job_info` attribute is missing', + 'could not parse scheduler output: the `{}` file is missing'.format(filename_stderr), + 'could not parse scheduler output: the `{}` file is missing'.format(filename_stdout) + ] + + for log in expected_logs: + assert log in logs + + +@pytest.mark.usefixtures('clear_database_before_test', 'override_logging') +def test_parse_non_zero_retval(process): + """Test the scheduler output parsing logic in `CalcJob.parse`. + + This is testing the case where the `detailed_job_info` is incomplete because the call failed. This is checked + through the return value that is stored within the attribute dictionary. + """ + retrieved = orm.FolderData().store() + retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) + + process.node.set_attribute('detailed_job_info', {'retval': 1, 'stderr': 'accounting disabled', 'stdout': ''}) + process.parse() + + logs = [log.message for log in orm.Log.objects.get_logs_for(process.node)] + assert 'could not parse scheduler output: return value of `detailed_job_info` is non-zero' in logs + + +@pytest.mark.usefixtures('clear_database_before_test', 'override_logging') +def test_parse_not_implemented(process): + """Test the scheduler output parsing logic in `CalcJob.parse`. + + Here we check explicitly that the parsing does not except even if the scheduler does not implement the method. + """ + retrieved = orm.FolderData().store() + retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) + + process.node.set_attribute('detailed_job_info', {}) + + filename_stderr = process.node.get_option('scheduler_stderr') + filename_stdout = process.node.get_option('scheduler_stdout') + + with retrieved.open(filename_stderr, 'w') as handle: + handle.write('\n') + + with retrieved.open(filename_stdout, 'w') as handle: + handle.write('\n') + + process.parse() + + # The `DirectScheduler` at this point in time does not implement the `parse_output` method. Instead of raising + # a warning message should be logged. We verify here that said message is present. 
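# Illustrative sketch, not part of this patch: what a scheduler plugin could return from the
# parsing hook that the tests above and below exercise. The `parse_output` signature
# (detailed_job_info, stdout, stderr) is assumed here, and the exit status 100 and the
# `oom-kill` marker are purely illustrative.
from aiida.engine import ExitCode


def parse_output_sketch(self, detailed_job_info, stdout, stderr):  # pylint: disable=unused-argument
    """Hypothetical scheduler output parser."""
    if detailed_job_info.get('retval', 0) != 0:
        return None  # accounting info is incomplete: leave the verdict to the output parser

    if 'oom-kill' in stderr:
        return ExitCode(100, 'the job was killed by the scheduler (out of memory)')

    return None  # no problem detected


# In a test this could be attached the same way `test_parse_scheduler_excepted` below
# monkeypatches the method:
#
#     monkeypatch.setattr(DirectScheduler, 'parse_output', parse_output_sketch)
#     process.parse()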
+ logs = [log.message for log in orm.Log.objects.get_logs_for(process.node)] + expected_logs = ['`DirectScheduler` does not implement scheduler output parsing'] + + for log in expected_logs: + assert log in logs + + +@pytest.mark.usefixtures('clear_database_before_test', 'override_logging') +def test_parse_scheduler_excepted(process, monkeypatch): + """Test the scheduler output parsing logic in `CalcJob.parse`. + + Here we check explicitly the case where the `Scheduler.parse_output` method excepts + """ + from aiida.schedulers.plugins.direct import DirectScheduler + + retrieved = orm.FolderData().store() + retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) + + process.node.set_attribute('detailed_job_info', {}) + + filename_stderr = process.node.get_option('scheduler_stderr') + filename_stdout = process.node.get_option('scheduler_stdout') + + with retrieved.open(filename_stderr, 'w') as handle: + handle.write('\n') + + with retrieved.open(filename_stdout, 'w') as handle: + handle.write('\n') + + msg = 'crash' + + def raise_exception(*args, **kwargs): + raise RuntimeError(msg) + + # Monkeypatch the `DirectScheduler.parse_output` to raise an exception + monkeypatch.setattr(DirectScheduler, 'parse_output', raise_exception) + process.parse() + logs = [log.message for log in orm.Log.objects.get_logs_for(process.node)] + expected_logs = ['the `parse_output` method of the scheduler excepted: {}'.format(msg)] + + for log in expected_logs: + assert log in logs + + +@pytest.mark.parametrize(('exit_status_scheduler', 'exit_status_retrieved', 'final'), ( + (None, None, 0), + (100, None, 100), + (None, 400, 400), + (100, 400, 400), + (100, 0, 0), +)) +@pytest.mark.usefixtures('clear_database_before_test') +def test_parse_exit_code_priority( + exit_status_scheduler, + exit_status_retrieved, + final, + generate_calc_job, + fixture_sandbox, + aiida_local_code_factory, + monkeypatch, +): # pylint: disable=too-many-arguments + """Test the logic around exit codes in the `CalcJob.parse` method. + + The `parse` method will first call the `Scheduler.parse_output` method, which if implemented by the relevant + scheduler plugin, will parse the scheduler output and potentially return an exit code. Next, the output parser + plugin is called if defined in the inputs that can also optionally return an exit code. This test is designed + to make sure the right logic is implemented in terms of which exit code should be dominant. + + Scheduler result | Retrieved result | Final result | Scenario + -----------------|------------------|-----------------|----------------------------------------- + `None` | `None` | `ExitCode(0)` | Neither parser found any problem + `ExitCode(100)` | `None` | `ExitCode(100)` | Scheduler found issue, output parser does not override + `None` | `ExitCode(400)` | `ExitCode(400)` | Only output parser found a problem + `ExitCode(100)` | `ExitCode(400)` | `ExitCode(400)` | Scheduler found issue, but output parser overrides + | | | with a more specific error code + `ExitCode(100)` | `ExitCode(0)` | `ExitCode(0)` | Scheduler found issue but output parser overrides saying + | | | that despite that the calculation should be considered + | | | finished successfully. + + To test this, we just need to test the `CalcJob.parse` method and the easiest way is to simply mock the scheduler + parser and output parser calls called `parse_scheduler_output` and `parse_retrieved_output`, respectively. 
We will + just mock them by a simple method that returns `None` or an `ExitCode`. We then check that the final exit code + returned by `CalcJob.parse` is the one we expect according to the table above. + """ + from aiida.orm import Int + + def parse_scheduler_output(_, __): + if exit_status_scheduler is not None: + return ExitCode(exit_status_scheduler) + + def parse_retrieved_output(_, __): + if exit_status_retrieved is not None: + return ExitCode(exit_status_retrieved) + + monkeypatch.setattr(CalcJob, 'parse_scheduler_output', parse_scheduler_output) + monkeypatch.setattr(CalcJob, 'parse_retrieved_output', parse_retrieved_output) + + inputs = { + 'code': aiida_local_code_factory('arithmetic.add', '/bin/bash'), + 'x': Int(1), + 'y': Int(2), + } + process = generate_calc_job(fixture_sandbox, 'arithmetic.add', inputs, return_process=True) + retrieved = orm.FolderData().store() + retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) + + result = process.parse() + assert isinstance(result, ExitCode) + assert result.status == final diff --git a/tests/engine/test_launch.py b/tests/engine/test_launch.py index d7ec21bf93..d259ee5121 100644 --- a/tests/engine/test_launch.py +++ b/tests/engine/test_launch.py @@ -31,7 +31,9 @@ def define(cls, spec): def prepare_for_submission(self, folder): from aiida.common.datastructures import CalcInfo, CodeInfo - local_copy_list = [(self.inputs.single_file.uuid, self.inputs.single_file.filename, 'single_file')] + # Use nested path for the target filename, where the directory does not exist, to check that the engine will + # create intermediate directories as needed. Regression test for #4350 + local_copy_list = [(self.inputs.single_file.uuid, self.inputs.single_file.filename, 'path/single_file')] for name, node in self.inputs.files.items(): local_copy_list.append((node.uuid, node.filename, name)) @@ -286,5 +288,5 @@ def test_calcjob_dry_run_no_provenance(self): _, node = launch.run_get_node(FileCalcJob, **inputs) self.assertIn('folder', node.dry_run_info) - for filename in ['single_file', 'file_one', 'file_two']: + for filename in ['path', 'file_one', 'file_two']: self.assertIn(filename, os.listdir(node.dry_run_info['folder'])) diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py index 0dded70fe7..f66ce82656 100644 --- a/tests/engine/test_work_chain.py +++ b/tests/engine/test_work_chain.py @@ -7,12 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines,missing-function-docstring,invalid-name,missing-class-docstring,no-self-use +"""Tests for the `WorkChain` class.""" import inspect import unittest import plumpy -import pytest from tornado import gen +import pytest from aiida import orm from aiida.backends.testbase import AiidaTestCase @@ -55,7 +57,7 @@ def run_until_waiting(proc): in_waiting.set_result(True) else: - def on_waiting(waiting_proc): + def on_waiting(_): in_waiting.set_result(True) proc.remove_process_listener(listener) @@ -80,6 +82,8 @@ def run_and_check_success(process_class, **kwargs): class Wf(WorkChain): + """"Dummy work chain implementation with various steps and logical constructs in the outline.""" + # Keep track of which steps were completed by the workflow finished_steps = {} @@ -90,10 +94,10 @@ def define(cls, spec): spec.input('n', default=lambda: Int(3)) spec.outputs.dynamic = True 
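# Condensed sketch, not the actual `CalcJob.parse` implementation, of the precedence rule
# that `test_parse_exit_code_priority` above pins down: the output parser always has the
# last word, and the scheduler verdict only counts when the output parser stays silent.
# The helper name `resolve_exit_code` is illustrative.
from aiida.engine import ExitCode


def resolve_exit_code(exit_code_scheduler, exit_code_retrieved):
    """Return the final exit code according to the priority table documented above."""
    if exit_code_retrieved is not None:
        return exit_code_retrieved  # output parser result wins, even if it is `ExitCode(0)`

    if exit_code_scheduler is not None:
        return exit_code_scheduler  # only the scheduler found a problem

    return ExitCode(0)  # neither parser found any problem


assert resolve_exit_code(None, None).status == 0
assert resolve_exit_code(ExitCode(100), None).status == 100
assert resolve_exit_code(None, ExitCode(400)).status == 400
assert resolve_exit_code(ExitCode(100), ExitCode(0)).status == 0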
spec.outline( - cls.s1, - if_(cls.isA)(cls.s2).elif_(cls.isB)(cls.s3).else_(cls.s4), - cls.s5, - while_(cls.ltN)(cls.s6), + cls.step1, + if_(cls.is_a)(cls.step2).elif_(cls.is_b)(cls.step3).else_(cls.step4), + cls.step5, + while_(cls.larger_then_n)(cls.step6,), ) def on_create(self): @@ -101,40 +105,40 @@ def on_create(self): # Reset the finished step self.finished_steps = { k: False for k in [ - self.s1.__name__, self.s2.__name__, self.s3.__name__, self.s4.__name__, self.s5.__name__, - self.s6.__name__, self.isA.__name__, self.isB.__name__, self.ltN.__name__ + self.step1.__name__, self.step2.__name__, self.step3.__name__, self.step4.__name__, self.step5.__name__, + self.step6.__name__, self.is_a.__name__, self.is_b.__name__, self.larger_then_n.__name__ ] } - def s1(self): + def step1(self): self._set_finished(inspect.stack()[0][3]) - def s2(self): + def step2(self): self._set_finished(inspect.stack()[0][3]) - def s3(self): + def step3(self): self._set_finished(inspect.stack()[0][3]) - def s4(self): + def step4(self): self._set_finished(inspect.stack()[0][3]) - def s5(self): + def step5(self): self.ctx.counter = 0 self._set_finished(inspect.stack()[0][3]) - def s6(self): + def step6(self): self.ctx.counter = self.ctx.counter + 1 self._set_finished(inspect.stack()[0][3]) - def isA(self): + def is_a(self): self._set_finished(inspect.stack()[0][3]) return self.inputs.value.value == 'A' - def isB(self): + def is_b(self): self._set_finished(inspect.stack()[0][3]) return self.inputs.value.value == 'B' - def ltN(self): + def larger_then_n(self): keep_looping = self.ctx.counter < self.inputs.n.value if not keep_looping: self._set_finished(inspect.stack()[0][3]) @@ -145,6 +149,8 @@ def _set_finished(self, function_name): class PotentialFailureWorkChain(WorkChain): + """Work chain that can finish with a non-zero exit code.""" + EXIT_STATUS = 1 EXIT_MESSAGE = 'Well you did ask for it' OUTPUT_LABEL = 'optional_output' @@ -168,18 +174,17 @@ def failure(self): # Returning either 0 or ExitCode with non-zero status should terminate the workchain if self.inputs.through_exit_code.value is False: return self.EXIT_STATUS - else: - return self.exit_codes.EXIT_STATUS - else: - # Returning 0 or ExitCode with zero status should *not* terminate the workchain - if self.inputs.through_exit_code.value is False: - return 0 - else: - return ExitCode() + + return self.exit_codes.EXIT_STATUS # pylint: disable=no-member + + # Returning 0 or ExitCode with zero status should *not* terminate the workchain + if self.inputs.through_exit_code.value is False: + return 0 + + return ExitCode() def success(self): self.out(self.OUTPUT_LABEL, Int(self.OUTPUT_VALUE).store()) - return class TestExitStatus(AiidaTestCase): @@ -190,7 +195,7 @@ class TestExitStatus(AiidaTestCase): """ def test_failing_workchain_through_integer(self): - result, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(False)) + _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(False)) self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) self.assertEqual(node.exit_message, None) self.assertEqual(node.is_finished, True) @@ -199,7 +204,7 @@ def test_failing_workchain_through_integer(self): self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) def test_failing_workchain_through_exit_code(self): - result, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(False), through_exit_code=Bool(True)) + _, node = launch.run.get_node(PotentialFailureWorkChain, 
success=Bool(False), through_exit_code=Bool(True)) self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) self.assertEqual(node.exit_message, PotentialFailureWorkChain.EXIT_MESSAGE) self.assertEqual(node.is_finished, True) @@ -208,27 +213,31 @@ def test_failing_workchain_through_exit_code(self): self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) def test_successful_workchain_through_integer(self): - result, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True)) + _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True)) self.assertEqual(node.exit_status, 0) self.assertEqual(node.is_finished, True) self.assertEqual(node.is_finished_ok, True) self.assertEqual(node.is_failed, False) self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) - self.assertEqual(node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL), - PotentialFailureWorkChain.OUTPUT_VALUE) + self.assertEqual( + node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL), + PotentialFailureWorkChain.OUTPUT_VALUE + ) def test_successful_workchain_through_exit_code(self): - result, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True), through_exit_code=Bool(True)) + _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True), through_exit_code=Bool(True)) self.assertEqual(node.exit_status, 0) self.assertEqual(node.is_finished, True) self.assertEqual(node.is_finished_ok, True) self.assertEqual(node.is_failed, False) self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) - self.assertEqual(node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL), - PotentialFailureWorkChain.OUTPUT_VALUE) + self.assertEqual( + node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL), + PotentialFailureWorkChain.OUTPUT_VALUE + ) def test_return_out_of_outline(self): - result, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True), through_return=Bool(True)) + _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True), through_return=Bool(True)) self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) self.assertEqual(node.is_finished, True) self.assertEqual(node.is_finished_ok, False) @@ -268,7 +277,7 @@ def test_attributes(self): del wc.ctx.new_attr with self.assertRaises(AttributeError): - wc.ctx.new_attr + wc.ctx.new_attr # pylint: disable=pointless-statement def test_dict(self): wc = IfTest() @@ -277,11 +286,13 @@ def test_dict(self): del wc.ctx['new_attr'] with self.assertRaises(KeyError): - wc.ctx['new_attr'] + wc.ctx['new_attr'] # pylint: disable=pointless-statement class TestWorkchain(AiidaTestCase): + # pylint: disable=too-many-public-methods + def setUp(self): super().setUp() self.assertIsNone(Process.current()) @@ -317,26 +328,26 @@ def test_run(self): launch.run(Wf, value=A, n=three) # Check the steps that should have been run for step, finished in Wf.finished_steps.items(): - if step not in ['s3', 's4', 'isB']: + if step not in ['step3', 'step4', 'is_b']: self.assertTrue(finished, 'Step {} was not called by workflow'.format(step)) # Try the elif(..) 
part finished_steps = launch.run(Wf, value=B, n=three) # Check the steps that should have been run for step, finished in finished_steps.items(): - if step not in ['isA', 's2', 's4']: + if step not in ['is_a', 'step2', 'step4']: self.assertTrue(finished, 'Step {} was not called by workflow'.format(step)) # Try the else... part finished_steps = launch.run(Wf, value=C, n=three) # Check the steps that should have been run for step, finished in finished_steps.items(): - if step not in ['isA', 's2', 'isB', 's3']: + if step not in ['is_a', 'step2', 'is_b', 'step3']: self.assertTrue(finished, 'Step {} was not called by workflow'.format(step)) def test_incorrect_outline(self): - class Wf(WorkChain): + class IncorrectOutline(WorkChain): @classmethod def define(cls, spec): @@ -345,7 +356,7 @@ def define(cls, spec): spec.outline(5) with self.assertRaises(TypeError): - Wf.spec() + IncorrectOutline.spec() def test_define_not_calling_super(self): """A `WorkChain` that does not call super in `define` classmethod should raise.""" @@ -377,11 +388,11 @@ def illegal(self): self.out('not_allowed', orm.Int(2)) with self.assertRaises(ValueError): - result = launch.run(IllegalWorkChain) + launch.run(IllegalWorkChain) def test_same_input_node(self): - class Wf(WorkChain): + class SimpleWorkChain(WorkChain): @classmethod def define(cls, spec): @@ -396,7 +407,7 @@ def check_a_b(self): assert 'b' in self.inputs x = Int(1) - run_and_check_success(Wf, a=x, b=x) + run_and_check_success(SimpleWorkChain, a=x, b=x) def test_context(self): A = Str('a').store() @@ -426,7 +437,7 @@ def define(cls, spec): def result(self): self.out('res', B) - class Wf(WorkChain): + class OverrideContextWorkChain(WorkChain): @classmethod def define(cls, spec): @@ -447,7 +458,7 @@ def s3(self): test_case.assertEqual(self.ctx.r1.outputs.res, B) test_case.assertEqual(self.ctx.r2.outputs.res, B) - run_and_check_success(Wf) + run_and_check_success(OverrideContextWorkChain) def test_unstored_nodes_in_context(self): @@ -495,21 +506,21 @@ def test_checkpointing(self): finished_steps = self._run_with_checkpoints(Wf, inputs={'value': A, 'n': three}) # Check the steps that should have been run for step, finished in finished_steps.items(): - if step not in ['s3', 's4', 'isB']: + if step not in ['step3', 'step4', 'is_b']: self.assertTrue(finished, 'Step {} was not called by workflow'.format(step)) # Try the elif(..) part finished_steps = self._run_with_checkpoints(Wf, inputs={'value': B, 'n': three}) # Check the steps that should have been run for step, finished in finished_steps.items(): - if step not in ['isA', 's2', 's4']: + if step not in ['is_a', 'step2', 'step4']: self.assertTrue(finished, 'Step {} was not called by workflow'.format(step)) # Try the else... 
part finished_steps = self._run_with_checkpoints(Wf, inputs={'value': C, 'n': three}) # Check the steps that should have been run for step, finished in finished_steps.items(): - if step not in ['isA', 's2', 'isB', 's3']: + if step not in ['is_a', 'step2', 'is_b', 'step3']: self.assertTrue(finished, 'Step {} was not called by workflow'.format(step)) def test_return(self): @@ -519,12 +530,12 @@ class WcWithReturn(WorkChain): @classmethod def define(cls, spec): super().define(spec) - spec.outline(cls.s1, if_(cls.isA)(return_), cls.after) + spec.outline(cls.s1, if_(cls.is_a)(return_), cls.after) def s1(self): pass - def isA(self): + def is_a(self): return True def after(self): @@ -547,12 +558,12 @@ class MainWorkChain(WorkChain): @classmethod def define(cls, spec): super().define(spec) - spec.outline(cls.launch) + spec.outline(cls.do_launch) - def launch(self): + def do_launch(self): # "Call" a calculation function simply by running it inputs = {'metadata': {'call_link_label': label_calcfunction}} - calculation_function(**inputs) + calculation_function(**inputs) # pylint: disable=unexpected-keyword-arg # Call a sub work chain inputs = {'metadata': {'call_link_label': label_workchain}} @@ -564,7 +575,8 @@ class SubWorkChain(WorkChain): process = run_and_check_success(MainWorkChain) # Verify that the `CALL` link of the calculation function is there with the correct label - link_triple = process.node.get_outgoing(link_type=LinkType.CALL_CALC, link_label_filter=label_calcfunction).one() + link_triple = process.node.get_outgoing(link_type=LinkType.CALL_CALC, + link_label_filter=label_calcfunction).one() self.assertIsInstance(link_triple.node, orm.CalcFunctionNode) # Verify that the `CALL` link of the work chain is there with the correct label @@ -643,7 +655,7 @@ def define(cls, spec): def do_run(self): pks = [] - for i in range(2): + for _ in range(2): node = self.submit(SubWorkChain) pks.append(node.pk) self.to_context(subwc=node) @@ -696,7 +708,7 @@ def run_async(workchain): self.assertTrue(workchain.ctx.s1) self.assertTrue(workchain.ctx.s2) - runner.loop.run_sync(lambda: run_async(wc)) + runner.loop.run_sync(lambda: run_async(wc)) # pylint: disable=unnecessary-lambda def test_report_dbloghandler(self): """ @@ -717,7 +729,6 @@ def define(cls, spec): def run(self): orm.Log.objects.delete_all() self.report('Testing the report function') - return def check(self): logs = self._backend.logs.find() @@ -821,18 +832,19 @@ def run(self): self.assertEqual(wc.exit_codes.SOME_EXIT_CODE.status, status) with self.assertRaises(AttributeError): - wc.exit_codes.NON_EXISTENT_ERROR + wc.exit_codes.NON_EXISTENT_ERROR # pylint: disable=pointless-statement - self.assertEqual(ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.status, status) - self.assertEqual(ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.message, message) + self.assertEqual(ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.status, status) # pylint: disable=no-member + self.assertEqual(ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.message, message) # pylint: disable=no-member - self.assertEqual(ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].status, status) - self.assertEqual(ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].message, message) + self.assertEqual(ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].status, status) # pylint: disable=unsubscriptable-object + self.assertEqual(ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].message, message) # pylint: disable=unsubscriptable-object - self.assertEqual(ExitCodeWorkChain.exit_codes[label].status, status) - 
self.assertEqual(ExitCodeWorkChain.exit_codes[label].message, message) + self.assertEqual(ExitCodeWorkChain.exit_codes[label].status, status) # pylint: disable=unsubscriptable-object + self.assertEqual(ExitCodeWorkChain.exit_codes[label].message, message) # pylint: disable=unsubscriptable-object - def _run_with_checkpoints(self, wf_class, inputs=None): + @staticmethod + def _run_with_checkpoints(wf_class, inputs=None): if inputs is None: inputs = {} proc = run_and_check_success(wf_class, **inputs) @@ -884,7 +896,7 @@ def run_async(): yield process.future() runner.schedule(process) - runner.loop.run_sync(lambda: run_async()) + runner.loop.run_sync(lambda: run_async()) # pylint: disable=unnecessary-lambda self.assertEqual(process.node.is_finished_ok, False) self.assertEqual(process.node.is_excepted, True) @@ -910,7 +922,7 @@ def run_async(): launch.run(process) runner.schedule(process) - runner.loop.run_sync(lambda: run_async()) + runner.loop.run_sync(lambda: run_async()) # pylint: disable=unnecessary-lambda self.assertEqual(process.node.is_finished_ok, False) self.assertEqual(process.node.is_excepted, False) @@ -996,7 +1008,7 @@ def run_async(): yield process.future() runner.schedule(process) - runner.loop.run_sync(lambda: run_async()) + runner.loop.run_sync(lambda: run_async()) # pylint: disable=unnecessary-lambda child = process.node.get_outgoing(link_type=LinkType.CALL_WORK).first().node self.assertEqual(child.is_finished_ok, False) @@ -1027,7 +1039,7 @@ def test_immutable_input(self): """ test_class = self - class Wf(WorkChain): + class FrozenDictWorkChain(WorkChain): @classmethod def define(cls, spec): @@ -1055,7 +1067,7 @@ def step_two(self): test_class.assertNotIn('c', self.inputs) test_class.assertEqual(self.inputs['a'].value, 1) - run_and_check_success(Wf, a=Int(1), b=Int(2)) + run_and_check_success(FrozenDictWorkChain, a=Int(1), b=Int(2)) def test_immutable_input_groups(self): """ @@ -1063,7 +1075,7 @@ def test_immutable_input_groups(self): """ test_class = self - class Wf(WorkChain): + class ImmutableGroups(WorkChain): @classmethod def define(cls, spec): @@ -1090,7 +1102,7 @@ def step_two(self): test_class.assertNotIn('four', self.inputs.subspace) test_class.assertEqual(self.inputs.subspace['one'].value, 1) - run_and_check_success(Wf, subspace={'one': Int(1), 'two': Int(2)}) + run_and_check_success(ImmutableGroups, subspace={'one': Int(1), 'two': Int(2)}) class SerializeWorkChain(WorkChain): @@ -1126,13 +1138,15 @@ def tearDown(self): super().tearDown() self.assertIsNone(Process.current()) - def test_serialize(self): + @staticmethod + def test_serialize(): """ Test a simple serialization of a class to its identifier. """ run_and_check_success(SerializeWorkChain, test=Int, reference=Str(ObjectLoader().identify_object(Int))) - def test_serialize_builder(self): + @staticmethod + def test_serialize_builder(): """ Test serailization when using a builder. 
""" @@ -1155,7 +1169,8 @@ def define(cls, spec): def do_run(self): return ToContext( - child=self.submit(ParentExposeWorkChain, **self.exposed_inputs(ParentExposeWorkChain, namespace='sub.sub'))) + child=self.submit(ParentExposeWorkChain, **self.exposed_inputs(ParentExposeWorkChain, namespace='sub.sub')) + ) def finalize(self): self.out_many(self.exposed_outputs(self.ctx.child, ParentExposeWorkChain, namespace='sub.sub')) @@ -1195,11 +1210,14 @@ def start_children(self): child_1 = self.submit( ChildExposeWorkChain, a=self.exposed_inputs(ChildExposeWorkChain)['a'], - **self.exposed_inputs(ChildExposeWorkChain, namespace='sub_1', agglomerate=False)) - child_2 = self.submit(ChildExposeWorkChain, **self.exposed_inputs( - ChildExposeWorkChain, - namespace='sub_2.sub_3', - )) + **self.exposed_inputs(ChildExposeWorkChain, namespace='sub_1', agglomerate=False) + ) + child_2 = self.submit( + ChildExposeWorkChain, **self.exposed_inputs( + ChildExposeWorkChain, + namespace='sub_2.sub_3', + ) + ) return ToContext(child_1=child_1, child_2=child_2) def finalize(self): @@ -1266,9 +1284,10 @@ def test_expose(self): 'c': Bool(False) } } - }) + } + ) - @unittest.skip('Functionality of `WorkChain.exposed_outputs` is broken.') + @unittest.skip('Functionality of `Process.exposed_outputs` is broken for nested namespaces, see issue #3533.') def test_nested_expose(self): res = launch.run( GrandParentExposeWorkChain, @@ -1285,7 +1304,9 @@ def test_nested_expose(self): 'c': Bool(False) } }, - ))) + ) + ) + ) self.assertEqual( res, { 'sub': { @@ -1303,7 +1324,8 @@ def test_nested_expose(self): } } } - }) + } + ) @pytest.mark.filterwarnings('ignore::UserWarning') def test_issue_1741_expose_inputs(self): @@ -1365,7 +1387,8 @@ def illegal_submit(self): from aiida.engine import submit submit(TestWorkChainMisc.PointlessWorkChain) - def test_run_pointless_workchain(self): + @staticmethod + def test_run_pointless_workchain(): """Running the pointless workchain should not incur any exceptions""" launch.run(TestWorkChainMisc.PointlessWorkChain) @@ -1415,13 +1438,14 @@ def test_unique_default_inputs(self): nodes. """ inputs = {'child_one': {}, 'child_two': {}} - result, node = launch.run.get_node(TestDefaultUniqueness.Parent, **inputs) + _, node = launch.run.get_node(TestDefaultUniqueness.Parent, **inputs) nodes = node.get_incoming().all_nodes() - uuids = set([n.uuid for n in nodes]) + uuids = {n.uuid for n in nodes} # Trying to load one of the inputs through the UUID should fail, # as both `child_one.a` and `child_two.a` should have the same UUID. 
node = load_node(uuid=node.get_incoming().get_node_by_label('child_one__a').uuid) self.assertEqual( - len(uuids), len(nodes), 'Only {} unique UUIDS for {} input nodes'.format(len(uuids), len(nodes))) + len(uuids), len(nodes), 'Only {} unique UUIDS for {} input nodes'.format(len(uuids), len(nodes)) + ) diff --git a/tests/manage/backup/test_backup_script.py b/tests/manage/backup/test_backup_script.py index 29fea2394d..6d780c31ed 100644 --- a/tests/manage/backup/test_backup_script.py +++ b/tests/manage/backup/test_backup_script.py @@ -16,6 +16,7 @@ import tempfile from dateutil.parser import parse +import pytest from aiida.backends.testbase import AiidaTestCase from aiida.common import utils, json @@ -281,6 +282,8 @@ def setUpClass(cls, *args, **kwargs): super().setUpClass(*args, **kwargs) cls._bs_instance = backup_setup.BackupSetup() + # Tracked in issue #2134 + @pytest.mark.flaky(reruns=2) def test_integration(self): """Test integration""" from aiida.common.utils import Capturing diff --git a/tests/manage/configuration/migrations/test_migrations.py b/tests/manage/configuration/migrations/test_migrations.py index 28b8b1e2d1..4251b3116a 100644 --- a/tests/manage/configuration/migrations/test_migrations.py +++ b/tests/manage/configuration/migrations/test_migrations.py @@ -63,3 +63,10 @@ def test_2_3_migration(self): config_reference = self.load_config_sample('reference/3.json') config_migrated = _MIGRATION_LOOKUP[2].apply(config_initial) self.assertEqual(config_migrated, config_reference) + + def test_3_4_migration(self): + """Test the step between config versions 3 and 4.""" + config_initial = self.load_config_sample('input/3.json') + config_reference = self.load_config_sample('reference/4.json') + config_migrated = _MIGRATION_LOOKUP[3].apply(config_initial) + self.assertEqual(config_migrated, config_reference) diff --git a/tests/manage/configuration/migrations/test_samples/input/3.json b/tests/manage/configuration/migrations/test_samples/input/3.json new file mode 100644 index 0000000000..7a5a731750 --- /dev/null +++ b/tests/manage/configuration/migrations/test_samples/input/3.json @@ -0,0 +1 @@ +{"CONFIG_VERSION": {"CURRENT": 3, "OLDEST_COMPATIBLE": 3}, "default_profile": "default", "profiles": {"default": {"PROFILE_UUID": "00000000000000000000000000000000", "AIIDADB_ENGINE": "postgresql_psycopg2", "AIIDADB_PASS": "some_random_password", "AIIDADB_NAME": "aiidadb_qs_some_user", "AIIDADB_HOST": "localhost", "AIIDADB_BACKEND": "django", "AIIDADB_PORT": "5432", "default_user_email": "email@aiida.net", "AIIDADB_REPOSITORY_URI": "file:////home/some_user/.aiida/repository-quicksetup/", "AIIDADB_USER": "aiida_qs_greschd"}}} diff --git a/tests/manage/configuration/migrations/test_samples/reference/4.json b/tests/manage/configuration/migrations/test_samples/reference/4.json new file mode 100644 index 0000000000..0f3df0ec40 --- /dev/null +++ b/tests/manage/configuration/migrations/test_samples/reference/4.json @@ -0,0 +1 @@ +{"CONFIG_VERSION": {"CURRENT": 4, "OLDEST_COMPATIBLE": 3}, "default_profile": "default", "profiles": {"default": {"PROFILE_UUID": "00000000000000000000000000000000", "AIIDADB_ENGINE": "postgresql_psycopg2", "AIIDADB_PASS": "some_random_password", "AIIDADB_NAME": "aiidadb_qs_some_user", "AIIDADB_HOST": "localhost", "AIIDADB_BACKEND": "django", "AIIDADB_PORT": "5432", "default_user_email": "email@aiida.net", "AIIDADB_REPOSITORY_URI": "file:////home/some_user/.aiida/repository-quicksetup/", "AIIDADB_USER": "aiida_qs_greschd", "broker_protocol": "amqp", "broker_username": "guest", 
"broker_password": "guest", "broker_host": "127.0.0.1", "broker_port": 5672, "broker_virtual_host": ""}}} diff --git a/tests/manage/configuration/migrations/test_samples/reference/final.json b/tests/manage/configuration/migrations/test_samples/reference/final.json index 7a5a731750..0f3df0ec40 100644 --- a/tests/manage/configuration/migrations/test_samples/reference/final.json +++ b/tests/manage/configuration/migrations/test_samples/reference/final.json @@ -1 +1 @@ -{"CONFIG_VERSION": {"CURRENT": 3, "OLDEST_COMPATIBLE": 3}, "default_profile": "default", "profiles": {"default": {"PROFILE_UUID": "00000000000000000000000000000000", "AIIDADB_ENGINE": "postgresql_psycopg2", "AIIDADB_PASS": "some_random_password", "AIIDADB_NAME": "aiidadb_qs_some_user", "AIIDADB_HOST": "localhost", "AIIDADB_BACKEND": "django", "AIIDADB_PORT": "5432", "default_user_email": "email@aiida.net", "AIIDADB_REPOSITORY_URI": "file:////home/some_user/.aiida/repository-quicksetup/", "AIIDADB_USER": "aiida_qs_greschd"}}} +{"CONFIG_VERSION": {"CURRENT": 4, "OLDEST_COMPATIBLE": 3}, "default_profile": "default", "profiles": {"default": {"PROFILE_UUID": "00000000000000000000000000000000", "AIIDADB_ENGINE": "postgresql_psycopg2", "AIIDADB_PASS": "some_random_password", "AIIDADB_NAME": "aiidadb_qs_some_user", "AIIDADB_HOST": "localhost", "AIIDADB_BACKEND": "django", "AIIDADB_PORT": "5432", "default_user_email": "email@aiida.net", "AIIDADB_REPOSITORY_URI": "file:////home/some_user/.aiida/repository-quicksetup/", "AIIDADB_USER": "aiida_qs_greschd", "broker_protocol": "amqp", "broker_username": "guest", "broker_password": "guest", "broker_host": "127.0.0.1", "broker_port": 5672, "broker_virtual_host": ""}}} diff --git a/tests/manage/external/test_rmq.py b/tests/manage/external/test_rmq.py new file mode 100644 index 0000000000..5c27e6962f --- /dev/null +++ b/tests/manage/external/test_rmq.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the `aiida.manage.external.rmq` module.""" +import pytest + +from aiida.manage.external import rmq + + +@pytest.mark.parametrize(('args', 'kwargs', 'expected'), ( + ((), {}, 'amqp://guest:guest@127.0.0.1:5672?'), + ((), {'heartbeat': 1}, 'amqp://guest:guest@127.0.0.1:5672?'), + ((), {'invalid_parameters': 1}, ValueError), + ((), {'cafile': 'file', 'cadata': 'ab'}, 'amqp://guest:guest@127.0.0.1:5672?'), + (('amqps', 'jojo', 'rabbit', '192.168.0.1', 6783), {}, 'amqps://jojo:rabbit@192.168.0.1:6783?'), +)) # yapf: disable +def test_get_rmq_url(args, kwargs, expected): + """Test the `get_rmq_url` method. + + It is not possible to use a complete hardcoded URL to compare to the return value of `get_rmq_url` because the order + of the query parameters are arbitrary. Therefore, we just compare the rest of the URL and make sure that all query + parameters are present in the expected form separately. 
+ """ + if isinstance(expected, str): + url = rmq.get_rmq_url(*args, **kwargs) + assert url.startswith(expected) + for key, value in kwargs.items(): + assert '{}={}'.format(key, value) in url + else: + with pytest.raises(expected): + rmq.get_rmq_url(*args, **kwargs) diff --git a/tests/orm/data/test_code.py b/tests/orm/data/test_code.py index 8a334962e2..9426608f8c 100644 --- a/tests/orm/data/test_code.py +++ b/tests/orm/data/test_code.py @@ -47,7 +47,7 @@ def test_get_full_text_info(create_codes): assert ['List of files/folders:', ''] in full_text_info else: assert ['Type', 'remote'] in full_text_info - assert ['Remote machine', code.computer.name] in full_text_info + assert ['Remote machine', code.computer.label] in full_text_info assert ['Remote absolute path', code.get_remote_exec_path()] in full_text_info for code in create_codes: diff --git a/tests/orm/data/test_data.py b/tests/orm/data/test_data.py index 8fe588943a..bfd8f391b7 100644 --- a/tests/orm/data/test_data.py +++ b/tests/orm/data/test_data.py @@ -11,9 +11,11 @@ import os import numpy +import pytest from aiida import orm from aiida.backends.testbase import AiidaTestCase +from tests.static import STATIC_DIR class TestData(AiidaTestCase): @@ -22,18 +24,17 @@ class TestData(AiidaTestCase): @staticmethod def generate_class_instance(data_class): """Generate a dummy `Data` instance for the given sub class.""" - dirpath_fixtures = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, 'fixtures')) if data_class is orm.CifData: - instance = data_class(file=os.path.join(dirpath_fixtures, 'data', 'Si.cif')) + instance = data_class(file=os.path.join(STATIC_DIR, 'data', 'Si.cif')) return instance if data_class is orm.UpfData: - filename = os.path.join(dirpath_fixtures, 'pseudos', 'Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF') + filename = os.path.join(STATIC_DIR, 'pseudos', 'Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF') instance = data_class(file=filename) return instance if data_class is orm.StructureData: - instance = orm.CifData(file=os.path.join(dirpath_fixtures, 'data', 'Si.cif')).get_structure() + instance = orm.CifData(file=os.path.join(STATIC_DIR, 'data', 'Si.cif')).get_structure() return instance if data_class is orm.BandsData: @@ -55,9 +56,7 @@ def generate_class_instance(data_class): return instance if data_class is orm.UpfData: - filepath_base = os.path.abspath( - os.path.join(__file__, os.pardir, os.pardir, os.pardir, 'fixtures', 'pseudos') - ) + filepath_base = os.path.abspath(os.path.join(STATIC_DIR, 'pseudos')) filepath_carbon = os.path.join(filepath_base, 'C_pbe_v1.2.uspp.F.UPF') instance = data_class(file=filepath_carbon) return instance @@ -67,6 +66,8 @@ def generate_class_instance(data_class): 'for this data class, add a generator of a dummy instance here'.format(data_class) ) + # Tracked in issue #4281 + @pytest.mark.flaky(reruns=2) def test_data_exporters(self): """Verify that the return value of the export methods of all `Data` sub classes have the correct type. diff --git a/tests/orm/data/test_dict.py b/tests/orm/data/test_dict.py index 4172f04741..87416dd3fb 100644 --- a/tests/orm/data/test_dict.py +++ b/tests/orm/data/test_dict.py @@ -39,3 +39,14 @@ def test_get_item(self): """Test the `__getitem__` method.""" self.assertEqual(self.node['value'], self.dictionary['value']) self.assertEqual(self.node['nested'], self.dictionary['nested']) + + def test_set_item(self): + """Test the methods for setting the item. 
+ + * `__setitem__` directly on the node + * `__setattr__` through the `AttributeManager` returned by the `dict` property + """ + self.node['value'] = 2 + self.assertEqual(self.node['value'], 2) + self.node.dict.value = 3 + self.assertEqual(self.node['value'], 3) diff --git a/tests/orm/data/test_upf.py b/tests/orm/data/test_upf.py index 02922bc60f..70094f46ed 100644 --- a/tests/orm/data/test_upf.py +++ b/tests/orm/data/test_upf.py @@ -22,6 +22,7 @@ from aiida.backends.testbase import AiidaTestCase from aiida.common.exceptions import ParsingError from aiida.orm.nodes.data.upf import parse_upf +from tests.static import STATIC_DIR def isnumeric(vector): @@ -80,7 +81,7 @@ class TestUpfParser(AiidaTestCase): @classmethod def setUpClass(cls, *args, **kwargs): super().setUpClass(*args, **kwargs) - filepath_base = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, os.pardir, 'fixtures', 'pseudos')) + filepath_base = os.path.abspath(os.path.join(STATIC_DIR, 'pseudos')) cls.filepath_barium = os.path.join(filepath_base, 'Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF') cls.filepath_oxygen = os.path.join(filepath_base, 'O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF') cls.filepath_carbon = os.path.join(filepath_base, 'C_pbe_v1.2.uspp.F.UPF') @@ -325,7 +326,7 @@ def test_upf1_to_json_carbon(self): """Test UPF check Oxygen UPF1 pp conversion""" # pylint: disable=protected-access json_string, _ = self.pseudo_carbon._prepare_json() - filepath_base = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, os.pardir, 'fixtures', 'pseudos')) + filepath_base = os.path.abspath(os.path.join(STATIC_DIR, 'pseudos')) reference_dict = json.load(open(os.path.join(filepath_base, 'C.json'), 'r')) pp_dict = json.loads(json_string.decode('utf-8')) # remove path information @@ -337,7 +338,7 @@ def test_upf2_to_json_barium(self): """Test UPF check Bariium UPF1 pp conversion""" # pylint: disable=protected-access json_string, _ = self.pseudo_barium._prepare_json() - filepath_base = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, os.pardir, 'fixtures', 'pseudos')) + filepath_base = os.path.abspath(os.path.join(STATIC_DIR, 'pseudos')) reference_dict = json.load(open(os.path.join(filepath_base, 'Ba.json'), 'r')) pp_dict = json.loads(json_string.decode('utf-8')) # remove path information diff --git a/tests/orm/implementation/test_groups.py b/tests/orm/implementation/test_groups.py new file mode 100644 index 0000000000..3e19e88cc0 --- /dev/null +++ b/tests/orm/implementation/test_groups.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Unit tests for the BackendGroup and BackendGroupCollection classes.""" +import pytest + +from aiida import orm + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_query(backend): + """Test if queries are working.""" + from aiida.common.exceptions import NotExistent, MultipleObjectsError + + default_user = backend.users.create('simple@ton.com') + + g_1 = backend.groups.create(label='testquery1', user=default_user).store() + g_2 = backend.groups.create(label='testquery2', user=default_user).store() + + n_1 = orm.Data().store().backend_entity + n_2 = orm.Data().store().backend_entity + n_3 = orm.Data().store().backend_entity + n_4 = orm.Data().store().backend_entity + + g_1.add_nodes([n_1, n_2]) + g_2.add_nodes([n_1, n_3]) + + newuser = backend.users.create(email='test@email.xx') + g_3 = backend.groups.create(label='testquery3', user=newuser).store() + + # I should find it + g_1copy = backend.groups.get(uuid=g_1.uuid) + assert g_1.pk == g_1copy.pk + + # Try queries + res = backend.groups.query(nodes=n_4) + assert [_.pk for _ in res] == [] + + res = backend.groups.query(nodes=n_1) + assert [_.pk for _ in res] == [_.pk for _ in [g_1, g_2]] + + res = backend.groups.query(nodes=n_2) + assert [_.pk for _ in res] == [_.pk for _ in [g_1]] + + # I try to use 'get' with zero or multiple results + with pytest.raises(NotExistent): + backend.groups.get(nodes=n_4) + with pytest.raises(MultipleObjectsError): + backend.groups.get(nodes=n_1) + + assert backend.groups.get(nodes=n_2).pk == g_1.pk + + # Query by user + res = backend.groups.query(user=newuser) + assert set(_.pk for _ in res) == set(_.pk for _ in [g_3]) + + # Same query, but using a string (the username=email) instead of a DbUser object + res = backend.groups.query(user=newuser) + assert set(_.pk for _ in res) == set(_.pk for _ in [g_3]) + + res = backend.groups.query(user=default_user) + + assert set(_.pk for _ in res) == set(_.pk for _ in [g_1, g_2]) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_creation_from_dbgroup(backend): + """Test creation of a group from another group.""" + node = orm.Data().store() + + default_user = backend.users.create('test@aiida.net').store() + group = backend.groups.create(label='testgroup_from_dbgroup', user=default_user).store() + + group.store() + group.add_nodes([node.backend_entity]) + + dbgroup = group.dbmodel + gcopy = backend.groups.from_dbmodel(dbgroup) + + assert group.pk == gcopy.pk + assert group.uuid == gcopy.uuid + + +@pytest.mark.usefixtures('clear_database_before_test', 'skip_if_not_sqlalchemy') +def test_add_nodes_skip_orm(): + """Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag.""" + group = orm.Group(label='test_adding_nodes').store().backend_entity + + node_01 = orm.Data().store().backend_entity + node_02 = orm.Data().store().backend_entity + node_03 = orm.Data().store().backend_entity + node_04 = orm.Data().store().backend_entity + node_05 = orm.Data().store().backend_entity + nodes = [node_01, node_02, node_03, node_04, node_05] + + group.add_nodes([node_01], skip_orm=True) + group.add_nodes([node_02, node_03], skip_orm=True) + group.add_nodes((node_04, node_05), skip_orm=True) + + assert set(_.pk for _ in nodes) == set(_.pk for _ in 
group.nodes) + + # Try to add a node that is already present: there should be no problem + group.add_nodes([node_01], skip_orm=True) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) + + +@pytest.mark.usefixtures('clear_database_before_test', 'skip_if_not_sqlalchemy') +def test_add_nodes_skip_orm_batch(): + """Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag and batches.""" + nodes = [orm.Data().store().backend_entity for _ in range(100)] + + # Add nodes to groups using different batch size. Check in the end the correct addition. + batch_sizes = (1, 3, 10, 1000) + for batch_size in batch_sizes: + group = orm.Group(label='test_batches_' + str(batch_size)).store() + group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) + + +@pytest.mark.usefixtures('clear_database_before_test', 'skip_if_not_sqlalchemy') +def test_remove_nodes_bulk(): + """Test node removal with `skip_orm=True`.""" + group = orm.Group(label='test_removing_nodes').store().backend_entity + + node_01 = orm.Data().store().backend_entity + node_02 = orm.Data().store().backend_entity + node_03 = orm.Data().store().backend_entity + node_04 = orm.Data().store().backend_entity + nodes = [node_01, node_02, node_03] + + group.add_nodes(nodes) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) + + # Remove a node that is not in the group: nothing should happen + group.remove_nodes([node_04], skip_orm=True) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) + + # Remove one Node + nodes.remove(node_03) + group.remove_nodes([node_03], skip_orm=True) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) + + # Remove a list of Nodes and check + nodes.remove(node_01) + nodes.remove(node_02) + group.remove_nodes([node_01, node_02], skip_orm=True) + assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) diff --git a/tests/orm/implementation/test_utils.py b/tests/orm/implementation/test_utils.py new file mode 100644 index 0000000000..d04ca4e3b6 --- /dev/null +++ b/tests/orm/implementation/test_utils.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
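# Illustrative sketch, not part of this patch, of the bulk pattern exercised above: on the
# SqlAlchemy backend the group-node relations can be added through the backend entity with
# `skip_orm=True`, and `batch_size` limits how many rows go into a single insert. The group
# label and the number of nodes are arbitrary example values.
import pytest

from aiida import orm


@pytest.mark.usefixtures('clear_database_before_test', 'skip_if_not_sqlalchemy')
def test_bulk_add_sketch():
    """Sketch: add many nodes to a group in batches, bypassing the ORM."""
    nodes = [orm.Data().store().backend_entity for _ in range(50)]

    group = orm.Group(label='bulk-example').store()
    group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=10)

    assert {node.pk for node in nodes} == {node.pk for node in group.nodes}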
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Unit tests for the backend non-specific utility methods.""" +import math + +from aiida.backends.testbase import AiidaTestCase +from aiida.common import exceptions +from aiida.orm.implementation.utils import validate_attribute_extra_key, clean_value, FIELD_SEPARATOR + + +class TestOrmImplementationUtils(AiidaTestCase): + """Test the utility methods in aiida.orm.implementation.utils""" + + def test_invalid_attribute_extra_key(self): + """Test supplying an invalid key to the `validate_attribute_extra_key` method.""" + non_string_key = 5 + field_separator_key = 'invalid' + FIELD_SEPARATOR + 'key' + + with self.assertRaises(exceptions.ValidationError): + validate_attribute_extra_key(non_string_key) + + with self.assertRaises(exceptions.ValidationError): + validate_attribute_extra_key(field_separator_key) + + def test_invalid_value(self): + """Test supplying nan and inf values to the `clean_value` method.""" + nan_value = math.nan + inf_value = math.inf + + with self.assertRaises(exceptions.ValidationError): + clean_value(nan_value) + + with self.assertRaises(exceptions.ValidationError): + clean_value(inf_value) diff --git a/tests/orm/node/test_node.py b/tests/orm/node/test_node.py index dd53f83c89..6b66d9df3f 100644 --- a/tests/orm/node/test_node.py +++ b/tests/orm/node/test_node.py @@ -335,6 +335,20 @@ def test_delete_extra(self): def test_delete_extra_many(self): """Test the `Node.delete_extra_many` method.""" + self.node.set_extra('valid_key', 'value') + self.assertEqual(self.node.get_extra('valid_key'), 'value') + self.node.delete_extra('valid_key') + + with self.assertRaises(AttributeError): + self.node.delete_extra('valid_key') + + # Repeat with stored group + self.node.set_extra('valid_key', 'value') + self.node.store() + + self.node.delete_extra('valid_key') + with self.assertRaises(AttributeError): + load_node(self.node.pk).get_extra('valid_key') def test_clear_extras(self): """Test the `Node.clear_extras` method.""" diff --git a/tests/orm/test_computers.py b/tests/orm/test_computers.py index 2a14c3abe1..e8af574fb8 100644 --- a/tests/orm/test_computers.py +++ b/tests/orm/test_computers.py @@ -24,7 +24,7 @@ def test_get_transport(self): import tempfile new_comp = orm.Computer( - name='bbb', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' + label='bbb', hostname='localhost', transport_type='local', scheduler_type='direct', workdir='/tmp/aiida' ).store() # Configure the computer - no parameters for local transport @@ -43,7 +43,7 @@ def test_get_transport(self): def test_delete(self): """Test the deletion of a `Computer` instance.""" new_comp = orm.Computer( - name='aaa', hostname='aaa', transport_type='local', scheduler_type='pbspro', workdir='/tmp/aiida' + label='aaa', hostname='aaa', transport_type='local', scheduler_type='pbspro', workdir='/tmp/aiida' ).store() comp_pk = new_comp.pk diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index e598983697..e2833967c9 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -420,3 +420,182 @@ def test_query_with_group(): loaded = builder.one()[0] assert loaded.pk == group.pk + + +class TestGroupExtras(AiidaTestCase): + """Test the property and methods of group extras.""" + + def 
setUp(self): + super().setUp() + for group in orm.Group.objects.all(): + orm.Group.objects.delete(group.id) + self.group = orm.Group('test_extras') + + def test_extras(self): + """Test the `Group.extras` property.""" + original_extra = {'nested': {'a': 1}} + + self.group.set_extra('key', original_extra) + group_extras = self.group.extras + self.assertEqual(group_extras['key'], original_extra) + group_extras['key']['nested']['a'] = 2 + + self.assertEqual(original_extra['nested']['a'], 2) + + # Now store the group and verify that `extras` then returns a deep copy + self.group.store() + group_extras = self.group.extras + + # We change the returned group extras but the original extra should remain unchanged + group_extras['key']['nested']['a'] = 3 + self.assertEqual(original_extra['nested']['a'], 2) + + def test_get_extra(self): + """Test the `Group.get_extra` method.""" + original_extra = {'nested': {'a': 1}} + + self.group.set_extra('key', original_extra) + group_extra = self.group.get_extra('key') + self.assertEqual(group_extra, original_extra) + group_extra['nested']['a'] = 2 + + self.assertEqual(original_extra['nested']['a'], 2) + + default = 'default' + self.assertEqual(self.group.get_extra('not_existing', default=default), default) + with self.assertRaises(AttributeError): + self.group.get_extra('not_existing') + + # Now store the group and verify that `get_extra` then returns a deep copy + self.group.store() + group_extra = self.group.get_extra('key') + + # We change the returned group extras but the original extra should remain unchanged + group_extra['nested']['a'] = 3 + self.assertEqual(original_extra['nested']['a'], 2) + + default = 'default' + self.assertEqual(self.group.get_extra('not_existing', default=default), default) + with self.assertRaises(AttributeError): + self.group.get_extra('not_existing') + + def test_get_extra_many(self): + """Test the `Group.get_extra_many` method.""" + original_extra = {'nested': {'a': 1}} + + self.group.set_extra('key', original_extra) + group_extra = self.group.get_extra_many(['key'])[0] + self.assertEqual(group_extra, original_extra) + group_extra['nested']['a'] = 2 + + self.assertEqual(original_extra['nested']['a'], 2) + + # Now store the group and verify that `get_extra` then returns a deep copy + self.group.store() + group_extra = self.group.get_extra_many(['key'])[0] + + # We change the returned group extras but the original extra should remain unchanged + group_extra['nested']['a'] = 3 + self.assertEqual(original_extra['nested']['a'], 2) + + def test_set_extra(self): + """Test the `Group.set_extra` method.""" + with self.assertRaises(exceptions.ValidationError): + self.group.set_extra('illegal.key', 'value') + + self.group.set_extra('valid_key', 'value') + self.group.store() + + self.group.set_extra('valid_key', 'changed') + self.assertEqual(orm.load_group(self.group.pk).get_extra('valid_key'), 'changed') + + def test_set_extra_many(self): + """Test the `Group.set_extra` method.""" + with self.assertRaises(exceptions.ValidationError): + self.group.set_extra_many({'illegal.key': 'value', 'valid_key': 'value'}) + + self.group.set_extra_many({'valid_key': 'value'}) + self.group.store() + + self.group.set_extra_many({'valid_key': 'changed'}) + self.assertEqual(orm.load_group(self.group.pk).get_extra('valid_key'), 'changed') + + def test_reset_extra(self): + """Test the `Group.reset_extra` method.""" + extras_before = {'extra_one': 'value', 'extra_two': 'value'} + extras_after = {'extra_three': 'value', 'extra_four': 'value'} + extras_illegal = 
{'extra.illegal': 'value', 'extra_four': 'value'} + + self.group.set_extra_many(extras_before) + self.assertEqual(self.group.extras, extras_before) + self.group.reset_extras(extras_after) + self.assertEqual(self.group.extras, extras_after) + + with self.assertRaises(exceptions.ValidationError): + self.group.reset_extras(extras_illegal) + + self.group.store() + + self.group.reset_extras(extras_after) + self.assertEqual(orm.load_group(self.group.pk).extras, extras_after) + + def test_delete_extra(self): + """Test the `Group.delete_extra` method.""" + self.group.set_extra('valid_key', 'value') + self.assertEqual(self.group.get_extra('valid_key'), 'value') + self.group.delete_extra('valid_key') + + with self.assertRaises(AttributeError): + self.group.delete_extra('valid_key') + + # Repeat with stored group + self.group.set_extra('valid_key', 'value') + self.group.store() + + self.group.delete_extra('valid_key') + with self.assertRaises(AttributeError): + orm.load_group(self.group.pk).get_extra('valid_key') + + def test_delete_extra_many(self): + """Test the `Group.delete_extra_many` method.""" + extras_valid = {'extra_one': 'value', 'extra_two': 'value'} + valid_keys = ['extra_one', 'extra_two'] + invalid_keys = ['extra_one', 'invalid_key'] + + self.group.set_extra_many(extras_valid) + self.assertEqual(self.group.extras, extras_valid) + + with self.assertRaises(AttributeError): + self.group.delete_extra_many(invalid_keys) + + self.group.store() + + self.group.delete_extra_many(valid_keys) + self.assertEqual(orm.load_group(self.group.pk).extras, {}) + + def test_clear_extras(self): + """Test the `Group.clear_extras` method.""" + extras = {'extra_one': 'value', 'extra_two': 'value'} + self.group.set_extra_many(extras) + self.assertEqual(self.group.extras, extras) + + self.group.clear_extras() + self.assertEqual(self.group.extras, {}) + + # Repeat for stored group + self.group.store() + + self.group.clear_extras() + self.assertEqual(orm.load_group(self.group.pk).extras, {}) + + def test_extras_items(self): + """Test the `Group.extras_items` generator.""" + extras = {'extra_one': 'value', 'extra_two': 'value'} + self.group.set_extra_many(extras) + self.assertEqual(dict(self.group.extras_items()), extras) + + def test_extras_keys(self): + """Test the `Group.extras_keys` generator.""" + extras = {'extra_one': 'value', 'extra_two': 'value'} + self.group.set_extra_many(extras) + self.assertEqual(set(self.group.extras_keys()), set(extras)) diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index b3264223d7..11fbcbdf24 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -10,6 +10,7 @@ # pylint: disable=invalid-name,missing-docstring,too-many-lines """Tests for the QueryBuilder.""" import warnings +import pytest from aiida import orm from aiida.backends.testbase import AiidaTestCase @@ -130,6 +131,8 @@ def test_get_group_type_filter(self): self.assertEqual(get_group_type_filter(classifiers, False), {'==': 'pseudo.family'}) self.assertEqual(get_group_type_filter(classifiers, True), {'like': 'pseudo.family%'}) + # Tracked in issue #4281 + @pytest.mark.flaky(reruns=2) def test_process_query(self): """ Test querying for a process class. 
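The `@pytest.mark.flaky(reruns=2)` marker added above is provided by the pytest-rerunfailures plugin. Purely as an illustrative sketch (the counter and test name below are made up for this note and are not part of the change set), a test carrying this marker is re-run in the same process until it passes or the reruns are exhausted:

    import pytest

    ATTEMPTS = {'count': 0}

    @pytest.mark.flaky(reruns=2)  # rerun up to two more times before reporting a failure
    def test_intermittent():
        # Module state survives reruns, so this fails on the first attempt and passes on a rerun.
        ATTEMPTS['count'] += 1
        assert ATTEMPTS['count'] > 1
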
@@ -737,10 +740,16 @@ def test_queryhelp(self): self.assertEqual(qb.count(), 1) def test_recreate_from_queryhelp(self): - """Test recreating a QueryBuilder from the Query Help""" + """Test recreating a QueryBuilder from the Query Help + + We test appending a Data node and a Process node for variety, as well + as a generic Node specifically because it translates to `entity_type` + as an empty string (which can potentially cause problems). + """ import copy qb1 = orm.QueryBuilder() + qb1.append(orm.Node) qb1.append(orm.Data) qb1.append(orm.CalcJobNode) @@ -1034,14 +1043,14 @@ def test_joins3_user_group(self): qb = orm.QueryBuilder() qb.append(orm.User, tag='user', filters={'id': {'==': user.id}}) qb.append(orm.Group, with_user='user', filters={'id': {'==': group.id}}) - self.assertEqual(qb.count(), 1, 'The expected group that belongs to ' 'the selected user was not found.') + self.assertEqual(qb.count(), 1, 'The expected group that belongs to the selected user was not found.') # Search for the user that owns a group qb = orm.QueryBuilder() qb.append(orm.Group, tag='group', filters={'id': {'==': group.id}}) qb.append(orm.User, with_group='group', filters={'id': {'==': user.id}}) - self.assertEqual(qb.count(), 1, 'The expected user that owns the ' 'selected group was not found.') + self.assertEqual(qb.count(), 1, 'The expected user that owns the selected group was not found.') def test_joins_group_node(self): """ @@ -1423,13 +1432,13 @@ def test_statistics_default_class(self): # pylint: disable=no-member expected_dict = { 'description': self.computer.description, - 'scheduler_type': self.computer.get_scheduler_type(), + 'scheduler_type': self.computer.scheduler_type, 'hostname': self.computer.hostname, 'uuid': self.computer.uuid, - 'name': self.computer.name, - 'transport_type': self.computer.get_transport_type(), + 'name': self.computer.label, + 'transport_type': self.computer.transport_type, 'id': self.computer.id, - 'metadata': self.computer.get_metadata(), + 'metadata': self.computer.metadata, } qb = orm.QueryBuilder() diff --git a/tests/orm/utils/test_repository.py b/tests/orm/utils/test_repository.py index a490459edb..f422fdca81 100644 --- a/tests/orm/utils/test_repository.py +++ b/tests/orm/utils/test_repository.py @@ -8,15 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `Repository` utility class.""" - import os import shutil import tempfile from aiida.backends.testbase import AiidaTestCase -from aiida.orm import Node, Data -from aiida.orm.utils.repository import File, FileType from aiida.common.exceptions import ModificationNotAllowed +from aiida.orm import Node, Data +from aiida.repository import File, FileType class TestRepository(AiidaTestCase): diff --git a/tests/parsers/test_parser.py b/tests/parsers/test_parser.py index 54388d51f8..8e1ac45ccb 100644 --- a/tests/parsers/test_parser.py +++ b/tests/parsers/test_parser.py @@ -17,6 +17,7 @@ from aiida.engine import CalcJob from aiida.parsers import Parser from aiida.plugins import CalculationFactory, ParserFactory +from aiida.parsers.plugins.arithmetic.add import SimpleArithmeticAddParser # for demonstration purposes only ArithmeticAddCalculation = CalculationFactory('arithmetic.add') # pylint: disable=invalid-name ArithmeticAddParser = ParserFactory('arithmetic.add') # pylint: disable=invalid-name @@ -106,12 +107,13 @@ def test_parse_from_node(self): retrieved.store() retrieved.add_incoming(node, 
link_type=LinkType.CREATE, link_label='retrieved') - result, calcfunction = ArithmeticAddParser.parse_from_node(node) + for cls in [ArithmeticAddParser, SimpleArithmeticAddParser]: + result, calcfunction = cls.parse_from_node(node) - self.assertIsInstance(result['sum'], orm.Int) - self.assertEqual(result['sum'].value, summed) - self.assertIsInstance(calcfunction, orm.CalcFunctionNode) - self.assertEqual(calcfunction.exit_status, 0) + self.assertIsInstance(result['sum'], orm.Int) + self.assertEqual(result['sum'].value, summed) + self.assertIsInstance(calcfunction, orm.CalcFunctionNode) + self.assertEqual(calcfunction.exit_status, 0) # Verify that the `retrieved_temporary_folder` keyword can be passed, there is no validation though result, calcfunction = ArithmeticAddParser.parse_from_node(node, retrieved_temporary_folder='/some/path') diff --git a/tests/restapi/test_routes.py b/tests/restapi/test_routes.py index db30d11ce9..30dc4af061 100644 --- a/tests/restapi/test_routes.py +++ b/tests/restapi/test_routes.py @@ -90,7 +90,7 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma handle.write(aiida_in) handle.flush() handle.seek(0) - calc.put_object_from_filelike(handle, key='calcjob_inputs/aiida.in', force=True) + calc.put_object_from_filelike(handle, 'calcjob_inputs/aiida.in', force=True) calc.store() # create log message for calcjob @@ -118,7 +118,7 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma handle.write(aiida_out) handle.flush() handle.seek(0) - retrieved_outputs.put_object_from_filelike(handle, key='calcjob_outputs/aiida.out', force=True) + retrieved_outputs.put_object_from_filelike(handle, 'calcjob_outputs/aiida.out', force=True) retrieved_outputs.store() retrieved_outputs.add_incoming(calc, link_type=LinkType.CREATE, link_label='retrieved') @@ -129,22 +129,22 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma calc1.store() dummy_computers = [{ - 'name': 'test1', + 'label': 'test1', 'hostname': 'test1.epfl.ch', 'transport_type': 'ssh', 'scheduler_type': 'pbspro', }, { - 'name': 'test2', + 'label': 'test2', 'hostname': 'test2.epfl.ch', 'transport_type': 'ssh', 'scheduler_type': 'torque', }, { - 'name': 'test3', + 'label': 'test3', 'hostname': 'test3.epfl.ch', 'transport_type': 'local', 'scheduler_type': 'slurm', }, { - 'name': 'test4', + 'label': 'test4', 'hostname': 'test4.epfl.ch', 'transport_type': 'ssh', 'scheduler_type': 'slurm', diff --git a/tests/restapi/test_threaded_restapi.py b/tests/restapi/test_threaded_restapi.py index 7a530061da..41109a6336 100644 --- a/tests/restapi/test_threaded_restapi.py +++ b/tests/restapi/test_threaded_restapi.py @@ -62,6 +62,8 @@ def test_run_threaded_server(restapi_server, server_url, aiida_localhost): pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown') +# Tracked in issue #4281 +@pytest.mark.flaky(reruns=2) @pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool') def test_run_without_close_session(restapi_server, server_url, aiida_localhost, capfd): """Run AiiDA REST API threaded in a separate thread and perform many sequential requests""" diff --git a/tests/schedulers/test_direct.py b/tests/schedulers/test_direct.py index 613c723ad2..8f8c2c295d 100644 --- a/tests/schedulers/test_direct.py +++ b/tests/schedulers/test_direct.py @@ -7,8 +7,10 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # 
########################################################################### - +# pylint: disable=invalid-name,protected-access +"""Tests for the `DirectScheduler` plugin.""" import unittest + from aiida.schedulers.plugins.direct import DirectScheduler from aiida.schedulers import SchedulerError @@ -84,7 +86,3 @@ def test_parse_linux_joblist_output(self): job_ids = [job.job_id for job in result] self.assertIn('11383', job_ids) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/schedulers/test_lsf.py b/tests/schedulers/test_lsf.py index 7c5c574177..9eb6564f2a 100644 --- a/tests/schedulers/test_lsf.py +++ b/tests/schedulers/test_lsf.py @@ -7,21 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - -import unittest +# pylint: disable=invalid-name,protected-access +"""Tests for the `LsfScheduler` plugin.""" import logging +import unittest import uuid -from aiida.schedulers.plugins.lsf import * +from aiida.schedulers.datastructures import JobState +from aiida.schedulers.scheduler import SchedulerError +from aiida.schedulers.plugins.lsf import LsfScheduler BJOBS_STDOUT_TO_TEST = '764213236|EXIT|TERM_RUNLIMIT: job killed after reaching LSF run time limit' \ - '|b681e480bd|inewton|1|-|b681e480bd|test|Feb 2 00:46|Feb 2 00:45|-|Feb 2 00:44|aiida-1033269\n' \ - '764220165|PEND|-|-|inewton|-|-|-|8nm|-|-|-|Feb 2 01:46|aiida-1033444\n' \ + '|b681e480bd|inewton|1|-|b681e480bd|test|Feb 2 00:46|Feb 2 00:45|-|Feb 2 00:44' \ + '|aiida-1033269\n' '764220165|PEND|-|-|inewton|-|-|-|8nm|-|-|-|Feb 2 01:46|aiida-1033444\n' \ '764220167|PEND|-|-|fchopin|-|-|-|test|-|-|-|Feb 2 01:53 L|aiida-1033449\n' \ - '764254593|RUN|-|lxbsu2710|inewton|1|-|lxbsu2710|test|Feb 2 07:40|Feb 2 07:39|-|Feb 2 07:39|test\n' \ - '764255172|RUN|-|b68ac74822|inewton|1|-|b68ac74822|test|Feb 2 07:48 L|Feb 2 07:47|15.00% L|Feb 2 07:47|test\n' \ - '764245175|RUN|-|b68ac74822|dbowie|1|-|b68ac74822|test|Jan 1 05:07|Dec 31 23:48 L|25.00%|Dec 31 23:40|test\n' \ - '764399747|DONE|-|p05496706j68144|inewton|1|-|p05496706j68144|test|Feb 2 14:56 L|Feb 2 14:54|38.33% L|Feb 2 14:54|test' + '764254593|RUN|-|lxbsu2710|inewton|1|-|lxbsu2710|test|Feb 2 07:40|Feb 2 07:39|-|Feb 2 07:39|'\ + 'test\n764255172|RUN|-|b68ac74822|inewton|1|-|b68ac74822|test|Feb 2 07:48 L|Feb 2 07:47| ' \ + '15.00% L|Feb 2 07:47|test\n764245175|RUN|-|b68ac74822|dbowie|1|-|b68ac74822|test|' \ + 'Jan 1 05:07|Dec 31 23:48 L|25.00%|Dec 31 23:40|test\n 764399747|DONE|-|p05496706j68144|' \ + 'inewton|1|-|p05496706j68144|test|Feb 2 14:56 L|Feb 2 14:54|38.33% L|Feb 2 14:54|test' BJOBS_STDERR_TO_TEST = 'Job <864220165> is not found' SUBMIT_STDOUT_TO_TEST = 'Job <764254593> is submitted to queue .' 
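Each record in the BJOBS_STDOUT_TO_TEST fixture above is a single pipe-separated line. As rough orientation only (this is not the LsfScheduler parser, merely a toy split of one record copied from the fixture):

    # One PEND record taken verbatim from the fixture; fields are '|'-separated.
    record = '764220165|PEND|-|-|inewton|-|-|-|8nm|-|-|-|Feb 2 01:46|aiida-1033444'
    fields = record.split('|')
    job_id, state = fields[0], fields[1]
    assert (job_id, state) == ('764220165', 'PEND')
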
@@ -38,6 +42,7 @@ def test_parse_common_joblist_output(self): """ Test whether _parse_joblist can parse the bjobs output """ + # pylint: disable=too-many-locals,too-many-statements import datetime scheduler = LsfScheduler() @@ -79,8 +84,7 @@ def test_parse_common_joblist_output(self): self.assertEqual(job_done_annotation, job_done_annotation_parsed) job_running = 3 - job_running_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING]) + job_running_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.RUNNING]) self.assertEqual(job_running, job_running_parsed) running_users = ['inewton', 'inewton', 'dbowie'] @@ -104,13 +108,14 @@ def test_parse_common_joblist_output(self): self.assertEqual([j.wallclock_time_seconds for j in job_list if j.job_id == '764245175'][0], 4785) current_year = datetime.datetime.now().year self.assertEqual([j.submission_time for j in job_list if j.job_id == '764245175'][0], - datetime.datetime(current_year, 12, 31, 23, 40)) + datetime.datetime(current_year, 12, 31, 23, 40)) # Important to enable again logs! logging.disable(logging.NOTSET) class TestSubmitScript(unittest.TestCase): + """Tests for the submit script.""" def test_submit_script(self): """ @@ -140,8 +145,7 @@ def test_submit_script(self): self.assertTrue('#BSUB -W 24:00' in submit_script_text) self.assertTrue('#BSUB -n 2' in submit_script_text) - self.assertTrue("'mpirun' '-np' '2' 'pw.x' '-npool' '1'" + \ - " < 'aiida.in'" in submit_script_text) + self.assertTrue("'mpirun' '-np' '2' 'pw.x' '-npool' '1'" + " < 'aiida.in'" in submit_script_text) def test_submit_script_with_num_machines(self): """ @@ -160,6 +164,7 @@ def test_submit_script_with_num_machines(self): class TestParserSubmit(unittest.TestCase): + """Test the parsing of the submit response.""" def test_submit_output(self): """ @@ -174,6 +179,7 @@ def test_submit_output(self): class TestParserBkill(unittest.TestCase): + """Test the parsing of the kill response.""" def test_kill_output(self): """ @@ -185,7 +191,3 @@ def test_kill_output(self): stderr = '' self.assertTrue(scheduler._parse_kill_output(retval, stdout, stderr)) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/schedulers/test_pbspro.py b/tests/schedulers/test_pbspro.py index 60271735d4..a75718ae18 100644 --- a/tests/schedulers/test_pbspro.py +++ b/tests/schedulers/test_pbspro.py @@ -7,10 +7,12 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name,protected-access,too-many-lines +"""Tests for the `PbsProScheduler` plugin.""" import unittest import uuid -from aiida.schedulers.plugins.pbspro import * + +from aiida.schedulers.plugins.pbspro import PbsproScheduler from aiida.schedulers.datastructures import JobState text_qstat_f_to_test = """Job Id: 68350.mycluster @@ -337,8 +339,8 @@ """ -## This contains in the 10-th job unexpected newlines -## in the sched_hint field. Still, it should parse correctly. +# This contains in the 10-th job unexpected newlines +# in the sched_hint field. Still, it should parse correctly. 
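The fixture that follows is a block of indented 'key = value' attribute lines per job, as produced by qstat -f. Purely for orientation, a toy reduction of such a block to a dictionary could look like the sketch below; the actual PbsproScheduler parser is considerably more involved and also copes with the stray newlines mentioned in the comment above:

    # Illustrative only: collect 'key = value' attribute lines into a dict.
    sample = 'Job Id: 549159\n    Job_Name = somejob\n    Job_Owner = user_549159'
    attributes = {}
    for line in sample.splitlines()[1:]:  # skip the 'Job Id:' header line
        key, _, value = line.strip().partition(' = ')
        attributes[key] = value
    assert attributes == {'Job_Name': 'somejob', 'Job_Owner': 'user_549159'}
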
text_qstat_f_to_test_with_unexpected_newlines = """Job Id: 549159 Job_Name = somejob Job_Owner = user_549159 @@ -762,6 +764,7 @@ def test_parse_common_joblist_output(self): """ Test whether _parse_joblist can parse the qstat -f output """ + # pylint: disable=too-many-locals scheduler = PbsproScheduler() retval = 0 @@ -776,28 +779,23 @@ def test_parse_common_joblist_output(self): self.assertEqual(job_parsed, job_on_cluster) job_running = 2 - job_running_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING]) + job_running_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.RUNNING]) self.assertEqual(job_running, job_running_parsed) job_held = 2 - job_held_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.QUEUED_HELD]) + job_held_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED_HELD]) self.assertEqual(job_held, job_held_parsed) job_queued = 2 - job_queued_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.QUEUED]) + job_queued_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED]) self.assertEqual(job_queued, job_queued_parsed) running_users = ['user02', 'user3'] - parsed_running_users = [j.job_owner for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING] + parsed_running_users = [j.job_owner for j in job_list if j.job_state and j.job_state == JobState.RUNNING] self.assertEqual(set(running_users), set(parsed_running_users)) running_jobs = ['69301.mycluster', '74164.mycluster'] - parsed_running_jobs = [j.job_id for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING] + parsed_running_jobs = [j.job_id for j in job_list if j.job_state and j.job_state == JobState.RUNNING] self.assertEqual(set(running_jobs), set(parsed_running_jobs)) for j in job_list: @@ -810,13 +808,13 @@ def test_parse_common_joblist_output(self): self.assertTrue(j.num_machines == num_machines) self.assertTrue(j.num_cpus == num_cpus) - # TODO : parse the env_vars def test_parse_with_unexpected_newlines(self): """ Test whether _parse_joblist can parse the qstat -f output also when there are unexpected newlines """ + # pylint: disable=too-many-locals scheduler = PbsproScheduler() retval = 0 @@ -831,28 +829,23 @@ def test_parse_with_unexpected_newlines(self): self.assertEqual(job_parsed, job_on_cluster) job_running = 2 - job_running_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING]) + job_running_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.RUNNING]) self.assertEqual(job_running, job_running_parsed) job_held = 1 - job_held_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.QUEUED_HELD]) + job_held_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED_HELD]) self.assertEqual(job_held, job_held_parsed) job_queued = 5 - job_queued_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.QUEUED]) + job_queued_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED]) self.assertEqual(job_queued, job_queued_parsed) running_users = ['somebody', 'user_556491'] - parsed_running_users = [j.job_owner for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING] + parsed_running_users = [j.job_owner for j in job_list if j.job_state and j.job_state == JobState.RUNNING] self.assertEqual(set(running_users), set(parsed_running_users)) 
running_jobs = ['555716', '556491'] - parsed_running_jobs = [j.job_id for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING] + parsed_running_jobs = [j.job_id for j in job_list if j.job_state and j.job_state == JobState.RUNNING] self.assertEqual(set(running_jobs), set(parsed_running_jobs)) for j in job_list: @@ -865,10 +858,9 @@ def test_parse_with_unexpected_newlines(self): self.assertTrue(j.num_machines == num_machines) self.assertTrue(j.num_cpus == num_cpus) - # TODO : parse the env_vars -# TODO: WHEN WE USE THE CORRECT ERROR MANAGEMENT, REIMPLEMENT THIS TEST +# TODO: WHEN WE USE THE CORRECT ERROR MANAGEMENT, REIMPLEMENT THIS TEST # pylint: disable=fixme # def test_parse_with_error_retval(self): # """ # The qstat -f command has received a retval != 0 @@ -898,6 +890,7 @@ def test_parse_with_unexpected_newlines(self): class TestSubmitScript(unittest.TestCase): + """Test the submit script.""" def test_submit_script(self): """ @@ -968,7 +961,8 @@ def test_submit_script_with_num_cores_per_machine(self): job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=2, num_cores_per_machine=24) + num_machines=1, num_mpiprocs_per_machine=2, num_cores_per_machine=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -1000,7 +994,8 @@ def test_submit_script_with_num_cores_per_mpiproc(self): job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=24) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -1034,7 +1029,8 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc1(self): job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=24) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -1066,4 +1062,5 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc2(self): job_tmpl = JobTemplate() with self.assertRaises(ValueError): job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=23) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=23 + ) diff --git a/tests/schedulers/test_sge.py b/tests/schedulers/test_sge.py index a78ceed045..c0957d57b7 100644 --- a/tests/schedulers/test_sge.py +++ b/tests/schedulers/test_sge.py @@ -7,10 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name,protected-access +"""Tests for the `SgeScheduler` plugin.""" import unittest import logging -from aiida.schedulers.plugins.sge import * + +from aiida.schedulers.datastructures import JobState +from aiida.schedulers.plugins.sge import SgeScheduler +from aiida.schedulers.scheduler import SchedulerError, SchedulerParsingError text_qstat_ext_urg_xml_test = """ @@ -169,8 
+173,10 @@ class TestCommand(unittest.TestCase): + """Test various commands.""" def test_get_joblist_command(self): + """Test the `get_joblist_command`.""" sge = SgeScheduler() # TEST 1: @@ -194,6 +200,7 @@ def test_get_joblist_command(self): self.assertTrue('*' in sge_get_joblist_command) def test_detailed_jobinfo_command(self): + """Test the `get_detailed_jobinfo_command`.""" sge = SgeScheduler() sge_get_djobinfo_command = sge._get_detailed_job_info_command('123456') @@ -203,6 +210,7 @@ def test_detailed_jobinfo_command(self): self.assertTrue('-j' in sge_get_djobinfo_command) def test_get_submit_command(self): + """Test the `get_submit_command`.""" sge = SgeScheduler() sge_get_submit_command = sge._get_submit_command('script.sh') @@ -212,6 +220,7 @@ def test_get_submit_command(self): self.assertTrue('script.sh' in sge_get_submit_command) def test_parse_submit_output(self): + """Test the `parse_submit_command`.""" sge = SgeScheduler() # TEST 1: @@ -225,6 +234,7 @@ def test_parse_submit_output(self): logging.disable(logging.NOTSET) def test_parse_joblist_output(self): + """Test the `parse_joblist_command`.""" sge = SgeScheduler() retval = 0 @@ -294,6 +304,7 @@ def test_parse_joblist_output(self): logging.disable(logging.NOTSET) def test_submit_script(self): + """Test the submit script.""" from aiida.schedulers.datastructures import JobTemplate sge = SgeScheduler() diff --git a/tests/schedulers/test_slurm.py b/tests/schedulers/test_slurm.py index 057e26d3f8..0970344a39 100644 --- a/tests/schedulers/test_slurm.py +++ b/tests/schedulers/test_slurm.py @@ -13,7 +13,10 @@ import uuid import datetime +import pytest + from aiida.schedulers.plugins.slurm import SlurmScheduler, JobState +from aiida.schedulers import SchedulerError # pylint: disable=line-too-long # job_id, state_raw, annotation, executing_host, username, number_nodes, number_cpus, allocated_machines, partition, time_limit, time_used, dispatch_time, job_name, submission_time @@ -42,7 +45,7 @@ class TestParserSqueue(unittest.TestCase): def test_parse_common_joblist_output(self): """ - Test whether _parse_joblist can parse the qstat -f output + Test whether _parse_joblist_output can parse the squeue output """ scheduler = SlurmScheduler() @@ -98,6 +101,19 @@ def test_parse_common_joblist_output(self): # # self.assertTrue( j.num_machines==num_machines ) # self.assertTrue( j.num_mpiprocs==num_mpiprocs ) + def test_parse_failed_squeue_output(self): + """ + Test that _parse_joblist_output reacts as expected to failures. 
+ """ + scheduler = SlurmScheduler() + + # non-zero return value should raise + with self.assertRaises(SchedulerError): + _ = scheduler._parse_joblist_output(1, TEXT_SQUEUE_TO_TEST, '') # pylint: disable=protected-access + + # non-empty stderr should be logged + with self.assertLogs(scheduler.logger, 'WARNING'): + _ = scheduler._parse_joblist_output(0, TEXT_SQUEUE_TO_TEST, 'error message') # pylint: disable=protected-access class TestTimes(unittest.TestCase): @@ -184,8 +200,7 @@ def test_submit_script(self): self.assertTrue('#SBATCH --time=1-00:00:00' in submit_script_text) self.assertTrue('#SBATCH --nodes=1' in submit_script_text) - self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + \ - " < 'aiida.in'" in submit_script_text) + self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + " < 'aiida.in'" in submit_script_text) def test_submit_script_bad_shebang(self): """Test that first line of submit script is as expected.""" @@ -243,8 +258,7 @@ def test_submit_script_with_num_cores_per_machine(self): # pylint: disable=inva self.assertTrue('#SBATCH --ntasks-per-node=2' in submit_script_text) self.assertTrue('#SBATCH --cpus-per-task=12' in submit_script_text) - self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + \ - " < 'aiida.in'" in submit_script_text) + self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + " < 'aiida.in'" in submit_script_text) def test_submit_script_with_num_cores_per_mpiproc(self): # pylint: disable=invalid-name """ @@ -276,8 +290,7 @@ def test_submit_script_with_num_cores_per_mpiproc(self): # pylint: disable=inva self.assertTrue('#SBATCH --ntasks-per-node=1' in submit_script_text) self.assertTrue('#SBATCH --cpus-per-task=24' in submit_script_text) - self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + \ - " < 'aiida.in'" in submit_script_text) + self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + " < 'aiida.in'" in submit_script_text) def test_submit_script_with_num_cores_per_machine_and_mpiproc1(self): # pylint: disable=invalid-name """ @@ -312,8 +325,7 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc1(self): # pylint: self.assertTrue('#SBATCH --ntasks-per-node=1' in submit_script_text) self.assertTrue('#SBATCH --cpus-per-task=24' in submit_script_text) - self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + \ - " < 'aiida.in'" in submit_script_text) + self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + " < 'aiida.in'" in submit_script_text) def test_submit_script_with_num_cores_per_machine_and_mpiproc2(self): # pylint: disable=invalid-name """ @@ -334,5 +346,64 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc2(self): # pylint: ) -if __name__ == '__main__': - unittest.main() +class TestJoblistCommand(unittest.TestCase): + """ + Tests of the issued squeue command. 
+ """ + + def test_joblist_single(self): + """Test that asking for a single job results in duplication of the list.""" + scheduler = SlurmScheduler() + + command = scheduler._get_joblist_command(jobs=['123']) # pylint: disable=protected-access + self.assertIn('123,123', command) + + def test_joblist_multi(self): + """Test that asking for multiple jobs does not result in duplications.""" + scheduler = SlurmScheduler() + + command = scheduler._get_joblist_command(jobs=['123', '456']) # pylint: disable=protected-access + self.assertIn('123,456', command) + self.assertNotIn('456,456', command) + + +def test_parse_out_of_memory(): + """Test that for job that failed due to OOM `parse_output` return the `ERROR_SCHEDULER_OUT_OF_MEMORY` code.""" + from aiida.engine import CalcJob + + scheduler = SlurmScheduler() + stdout = '' + stderr = '' + detailed_job_info = { + 'retval': 0, + 'stderr': '', + 'stdout': """|||||||||||||||||||||||||||||||||||||||||||||||||| + |||||||||||||||||||||||||||||||||||||||||OUT_OF_MEMORY|||||||||""" + } # yapf: disable + + exit_code = scheduler.parse_output(detailed_job_info, stdout, stderr) + assert exit_code == CalcJob.exit_codes.ERROR_SCHEDULER_OUT_OF_MEMORY # pylint: disable=no-member + + +@pytest.mark.parametrize('detailed_job_info, expected', [ + ('string', TypeError), # Not a dictionary + ({'stderr': ''}, ValueError), # Key `stdout` missing + ({'stdout': None}, TypeError), # `stdout` is not a string + ({'stdout': ''}, ValueError), # `stdout` does not contain at least two lines + ({'stdout': 'Header\nValue'}, ValueError), # `stdout` second line contains too few elements separated by pipe +]) # yapf: disable +def test_parse_output_invalid(detailed_job_info, expected): + """Test `SlurmScheduler.parse_output` for various invalid arguments.""" + scheduler = SlurmScheduler() + + with pytest.raises(expected): + scheduler.parse_output(detailed_job_info, '', '') + + +def test_parse_output_valid(): + """Test `SlurmScheduler.parse_output` for valid arguments.""" + number_of_fields = len(SlurmScheduler._detailed_job_info_fields) # pylint: disable=protected-access + detailed_job_info = {'stdout': 'Header\n{}'.format('|' * number_of_fields)} + scheduler = SlurmScheduler() + + assert scheduler.parse_output(detailed_job_info, '', '') is None diff --git a/tests/schedulers/test_torque.py b/tests/schedulers/test_torque.py index 7998036bd6..775e358e3e 100644 --- a/tests/schedulers/test_torque.py +++ b/tests/schedulers/test_torque.py @@ -7,11 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name,protected-access,too-many-lines +"""Tests for the `TorqueScheduler` plugin.""" import unittest import uuid + from aiida.schedulers.datastructures import JobState -from aiida.schedulers.plugins.torque import * +from aiida.schedulers.plugins.torque import TorqueScheduler text_qstat_f_to_test = """Job Id: 68350.mycluster Job_Name = cell-Qnormal @@ -762,6 +764,7 @@ def test_parse_common_joblist_output(self): """ Test whether _parse_joblist can parse the qstat -f output """ + # pylint: disable=too-many-locals s = TorqueScheduler() retval = 0 @@ -810,13 +813,13 @@ def test_parse_common_joblist_output(self): self.assertTrue(j.num_machines == num_machines) self.assertTrue(j.num_cpus == num_cpus) - # TODO : parse the env_vars def test_parse_with_unexpected_newlines(self): """ Test whether _parse_joblist can 
parse the qstat -f output also when there are unexpected newlines """ + # pylint: disable=too-many-locals s = TorqueScheduler() retval = 0 @@ -831,28 +834,23 @@ def test_parse_with_unexpected_newlines(self): self.assertEqual(job_parsed, job_on_cluster) job_running = 2 - job_running_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING]) + job_running_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.RUNNING]) self.assertEqual(job_running, job_running_parsed) job_held = 1 - job_held_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.QUEUED_HELD]) + job_held_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED_HELD]) self.assertEqual(job_held, job_held_parsed) job_queued = 5 - job_queued_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.QUEUED]) + job_queued_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED]) self.assertEqual(job_queued, job_queued_parsed) running_users = ['somebody', 'user_556491'] - parsed_running_users = [j.job_owner for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING] + parsed_running_users = [j.job_owner for j in job_list if j.job_state and j.job_state == JobState.RUNNING] self.assertEqual(set(running_users), set(parsed_running_users)) running_jobs = ['555716', '556491'] - parsed_running_jobs = [j.job_id for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING] + parsed_running_jobs = [j.job_id for j in job_list if j.job_state and j.job_state == JobState.RUNNING] self.assertEqual(set(running_jobs), set(parsed_running_jobs)) for j in job_list: @@ -865,10 +863,10 @@ def test_parse_with_unexpected_newlines(self): self.assertTrue(j.num_machines == num_machines) self.assertTrue(j.num_cpus == num_cpus) - # TODO : parse the env_vars class TestSubmitScript(unittest.TestCase): + """Test the submit script.""" def test_submit_script(self): """ @@ -895,8 +893,7 @@ def test_submit_script(self): self.assertTrue('#PBS -r n' in submit_script_text) self.assertTrue(submit_script_text.startswith('#!/bin/bash')) self.assertTrue('#PBS -l nodes=1:ppn=1,walltime=24:00:00' in submit_script_text) - self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + \ - " < 'aiida.in'" in submit_script_text) + self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + " < 'aiida.in'" in submit_script_text) def test_submit_script_with_num_cores_per_machine(self): """ @@ -906,12 +903,13 @@ def test_submit_script_with_num_cores_per_machine(self): from aiida.schedulers.datastructures import JobTemplate from aiida.common.datastructures import CodeInfo, CodeRunMode - s = TorqueScheduler() + scheduler = TorqueScheduler() job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' - job_tmpl.job_resource = s.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24) + job_tmpl.job_resource = scheduler.create_job_resource( + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -920,7 +918,7 @@ def test_submit_script_with_num_cores_per_machine(self): job_tmpl.codes_info = [code_info] job_tmpl.codes_run_mode = CodeRunMode.SERIAL - submit_script_text = s.get_submit_script(job_tmpl) + submit_script_text = scheduler.get_submit_script(job_tmpl) self.assertTrue('#PBS -r n' in submit_script_text) 
self.assertTrue(submit_script_text.startswith('#!/bin/bash')) @@ -940,7 +938,8 @@ def test_submit_script_with_num_cores_per_mpiproc(self): job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=24) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -971,7 +970,8 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc1(self): job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=24) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -1001,4 +1001,5 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc2(self): job_tmpl = JobTemplate() with self.assertRaises(ValueError): job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=23) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=23 + ) diff --git a/tests/sphinxext/workchain_source/conf.py b/tests/sphinxext/workchain_source/conf.py index 1c68124fff..bb3ced18f9 100644 --- a/tests/sphinxext/workchain_source/conf.py +++ b/tests/sphinxext/workchain_source/conf.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=redefined-builtin,invalid-name,missing-module-docstring # # sphinx-aiida-demo documentation build configuration file, created by # sphinx-quickstart on Mon Oct 2 13:04:07 2017. diff --git a/tests/sphinxext/workchain_source_broken/conf.py b/tests/sphinxext/workchain_source_broken/conf.py index 1c68124fff..bb3ced18f9 100644 --- a/tests/sphinxext/workchain_source_broken/conf.py +++ b/tests/sphinxext/workchain_source_broken/conf.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=redefined-builtin,invalid-name,missing-module-docstring # # sphinx-aiida-demo documentation build configuration file, created by # sphinx-quickstart on Mon Oct 2 13:04:07 2017. diff --git a/tests/static/__init__.py b/tests/static/__init__.py new file mode 100644 index 0000000000..d1b1daa282 --- /dev/null +++ b/tests/static/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Collection of static test data.""" + +import os + +STATIC_DIR = os.path.dirname(__file__) diff --git a/tests/fixtures/calcjob/arithmetic.add.aiida b/tests/static/calcjob/arithmetic.add.aiida similarity index 100% rename from tests/fixtures/calcjob/arithmetic.add.aiida rename to tests/static/calcjob/arithmetic.add.aiida diff --git a/tests/fixtures/calcjob/arithmetic.add_old.aiida b/tests/static/calcjob/arithmetic.add_old.aiida similarity index 100% rename from tests/fixtures/calcjob/arithmetic.add_old.aiida rename to tests/static/calcjob/arithmetic.add_old.aiida diff --git a/tests/fixtures/data/Si.cif b/tests/static/data/Si.cif similarity index 100% rename from tests/fixtures/data/Si.cif rename to tests/static/data/Si.cif diff --git a/tests/fixtures/export/compare/django.aiida b/tests/static/export/compare/django.aiida similarity index 100% rename from tests/fixtures/export/compare/django.aiida rename to tests/static/export/compare/django.aiida diff --git a/tests/fixtures/export/compare/sqlalchemy.aiida b/tests/static/export/compare/sqlalchemy.aiida similarity index 100% rename from tests/fixtures/export/compare/sqlalchemy.aiida rename to tests/static/export/compare/sqlalchemy.aiida diff --git a/tests/fixtures/export/migrate/empty.aiida b/tests/static/export/migrate/empty.aiida similarity index 100% rename from tests/fixtures/export/migrate/empty.aiida rename to tests/static/export/migrate/empty.aiida diff --git a/tests/fixtures/export/migrate/export_v0.1_simple.aiida b/tests/static/export/migrate/export_v0.1_simple.aiida similarity index 100% rename from tests/fixtures/export/migrate/export_v0.1_simple.aiida rename to tests/static/export/migrate/export_v0.1_simple.aiida diff --git a/tests/fixtures/export/migrate/export_v0.2_simple.aiida b/tests/static/export/migrate/export_v0.2_simple.aiida similarity index 100% rename from tests/fixtures/export/migrate/export_v0.2_simple.aiida rename to tests/static/export/migrate/export_v0.2_simple.aiida diff --git a/tests/fixtures/export/migrate/export_v0.3_simple.aiida b/tests/static/export/migrate/export_v0.3_simple.aiida similarity index 100% rename from tests/fixtures/export/migrate/export_v0.3_simple.aiida rename to tests/static/export/migrate/export_v0.3_simple.aiida diff --git a/tests/fixtures/export/migrate/export_v0.4_simple.aiida b/tests/static/export/migrate/export_v0.4_simple.aiida similarity index 100% rename from tests/fixtures/export/migrate/export_v0.4_simple.aiida rename to tests/static/export/migrate/export_v0.4_simple.aiida diff --git a/tests/fixtures/export/migrate/export_v0.5_simple.aiida b/tests/static/export/migrate/export_v0.5_simple.aiida similarity index 100% rename from tests/fixtures/export/migrate/export_v0.5_simple.aiida rename to tests/static/export/migrate/export_v0.5_simple.aiida diff --git a/tests/fixtures/export/migrate/export_v0.6_simple.aiida b/tests/static/export/migrate/export_v0.6_simple.aiida similarity index 100% rename from tests/fixtures/export/migrate/export_v0.6_simple.aiida rename to tests/static/export/migrate/export_v0.6_simple.aiida diff --git a/tests/fixtures/export/migrate/export_v0.7_simple.aiida b/tests/static/export/migrate/export_v0.7_simple.aiida similarity index 100% rename from 
tests/fixtures/export/migrate/export_v0.7_simple.aiida rename to tests/static/export/migrate/export_v0.7_simple.aiida diff --git a/tests/fixtures/export/migrate/export_v0.8_simple.aiida b/tests/static/export/migrate/export_v0.8_simple.aiida similarity index 100% rename from tests/fixtures/export/migrate/export_v0.8_simple.aiida rename to tests/static/export/migrate/export_v0.8_simple.aiida diff --git a/tests/fixtures/export/migrate/export_v0.9_simple.aiida b/tests/static/export/migrate/export_v0.9_simple.aiida similarity index 100% rename from tests/fixtures/export/migrate/export_v0.9_simple.aiida rename to tests/static/export/migrate/export_v0.9_simple.aiida diff --git a/tests/fixtures/graphs/graph1.aiida b/tests/static/graphs/graph1.aiida similarity index 100% rename from tests/fixtures/graphs/graph1.aiida rename to tests/static/graphs/graph1.aiida diff --git a/tests/fixtures/pseudos/Ba.json b/tests/static/pseudos/Ba.json similarity index 100% rename from tests/fixtures/pseudos/Ba.json rename to tests/static/pseudos/Ba.json diff --git a/tests/fixtures/pseudos/Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF b/tests/static/pseudos/Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF similarity index 100% rename from tests/fixtures/pseudos/Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF rename to tests/static/pseudos/Ba.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF diff --git a/tests/fixtures/pseudos/C.json b/tests/static/pseudos/C.json similarity index 100% rename from tests/fixtures/pseudos/C.json rename to tests/static/pseudos/C.json diff --git a/tests/fixtures/pseudos/C_pbe_v1.2.uspp.F.UPF b/tests/static/pseudos/C_pbe_v1.2.uspp.F.UPF similarity index 100% rename from tests/fixtures/pseudos/C_pbe_v1.2.uspp.F.UPF rename to tests/static/pseudos/C_pbe_v1.2.uspp.F.UPF diff --git a/tests/fixtures/pseudos/O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF b/tests/static/pseudos/O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF similarity index 100% rename from tests/fixtures/pseudos/O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF rename to tests/static/pseudos/O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF diff --git a/tests/fixtures/pseudos/Ti.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF b/tests/static/pseudos/Ti.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF similarity index 100% rename from tests/fixtures/pseudos/Ti.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF rename to tests/static/pseudos/Ti.pbesol-spn-rrkjus_psl.0.2.3-tot-pslib030.UPF diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index ec908c05ea..75f73c0d21 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -7,8 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines,invalid-name """Tests for specific subclasses of Data.""" - import os import tempfile import unittest @@ -39,7 +39,7 @@ def to_list_of_lists(lofl): :param lofl: an iterable of iterables :return: a list of lists""" - return [[el for el in l] for l in lofl] + return [[el for el in l] for l in lofl] # pylint: disable=unnecessary-comprehension def simplify(string): @@ -96,6 +96,7 @@ class TestCifData(AiidaTestCase): @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_reload_cifdata(self): + """Test `CifData` cycle.""" file_content = 'data_test _cell_length_a 10(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: filename = tmpf.name @@ 
-157,6 +158,7 @@ def test_reload_cifdata(self): @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_parse_cifdata(self): + """Test parsing a CIF file.""" file_content = 'data_test _cell_length_a 10(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write(file_content) @@ -167,6 +169,7 @@ def test_parse_cifdata(self): @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_change_cifdata_file(self): + """Test changing file for `CifData` before storing.""" file_content_1 = 'data_test _cell_length_a 10(1)' file_content_2 = 'data_test _cell_length_a 11(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -186,6 +189,7 @@ def test_change_cifdata_file(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_get_structure(self): + """Test `CifData.get_structure`.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write( ''' @@ -422,6 +426,7 @@ def test_cif_with_long_line(): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_cif_roundtrip(self): + """Test the `CifData` roundtrip.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write( ''' @@ -455,6 +460,7 @@ def test_cif_roundtrip(self): self.assertEqual(b._prepare_cif(), c._prepare_cif()) # pylint: disable=protected-access def test_symop_string_from_symop_matrix_tr(self): + """Test symmetry operations.""" from aiida.tools.data.cif import symop_string_from_symop_matrix_tr self.assertEqual(symop_string_from_symop_matrix_tr([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), 'x,y,z') @@ -468,6 +474,7 @@ def test_symop_string_from_symop_matrix_tr(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_attached_hydrogens(self): + """Test parsing of file with attached hydrogens.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write( ''' @@ -1024,6 +1031,7 @@ class TestStructureData(AiidaTestCase): """ Tests the creation of StructureData objects (cell and pbc). """ + # pylint: disable=too-many-public-methods from aiida.orm.nodes.data.structure import has_ase, has_spglib from aiida.orm.nodes.data.cif import has_pycifrw @@ -1362,9 +1370,9 @@ def test_kind_5_bis_ase(self): [4, 0, 0], ]) - asecell[0].mass = 12. - asecell[1].mass = 12. - asecell[2].mass = 12. + asecell[0].mass = 12. # pylint: disable=assigning-non-slot + asecell[1].mass = 12. # pylint: disable=assigning-non-slot + asecell[2].mass = 12. # pylint: disable=assigning-non-slot s = StructureData(ase=asecell) @@ -1392,9 +1400,9 @@ def test_kind_5_bis_ase_unknown(self): [4, 0, 0], ]) - asecell[0].mass = 12. - asecell[1].mass = 12. - asecell[2].mass = 12. + asecell[0].mass = 12. # pylint: disable=assigning-non-slot + asecell[1].mass = 12. # pylint: disable=assigning-non-slot + asecell[2].mass = 12. # pylint: disable=assigning-non-slot s = StructureData(ase=asecell) @@ -1497,6 +1505,7 @@ def test_kind_8(self): """ Test the ase_refine_cell() function """ + # pylint: disable=too-many-statements from aiida.orm.nodes.data.structure import ase_refine_cell import ase import math @@ -1912,7 +1921,7 @@ def test_ase(self): (0., 0., 0.), (0.5, 0.7, 0.9), )) - a[1].mass = 110.2 + a[1].mass = 110.2 # pylint: disable=assigning-non-slot b = StructureData(ase=a) c = b.get_ase() @@ -1973,8 +1982,8 @@ def test_conversion_of_types_2(self): )) a.set_tags((0, 1, 0, 1)) - a[2].mass = 100. - a[3].mass = 300. 
+ a[2].mass = 100. # pylint: disable=assigning-non-slot + a[3].mass = 300. # pylint: disable=assigning-non-slot b = StructureData(ase=a) # This will give funny names to the kinds, because I am using @@ -2026,9 +2035,9 @@ def test_conversion_of_types_4(self): import ase atoms = ase.Atoms('Fe5') - atoms[2].tag = 1 - atoms[3].tag = 1 - atoms[4].tag = 4 + atoms[2].tag = 1 # pylint: disable=assigning-non-slot + atoms[3].tag = 1 # pylint: disable=assigning-non-slot + atoms[4].tag = 4 # pylint: disable=assigning-non-slot atoms.set_cell([1, 1, 1]) s = StructureData(ase=atoms) kindnames = {k.name for k in s.kinds} @@ -2048,9 +2057,9 @@ def test_conversion_of_types_5(self): import ase atoms = ase.Atoms('Fe5') - atoms[0].tag = 1 - atoms[2].tag = 1 - atoms[3].tag = 4 + atoms[0].tag = 1 # pylint: disable=assigning-non-slot + atoms[2].tag = 1 # pylint: disable=assigning-non-slot + atoms[3].tag = 4 # pylint: disable=assigning-non-slot atoms.set_cell([1, 1, 1]) s = StructureData(ase=atoms) kindnames = {k.name for k in s.kinds} @@ -2517,6 +2526,7 @@ def test_creation(self): Check the methods to add, remove, modify, and get arrays and array shapes. """ + # pylint: disable=too-many-statements import numpy # Create a node with two arrays @@ -2635,6 +2645,7 @@ class TestTrajectoryData(AiidaTestCase): def test_creation(self): """Check the methods to set and retrieve a trajectory.""" + # pylint: disable=too-many-statements import numpy # Create a node with two arrays @@ -3238,6 +3249,7 @@ def test_tetra_z_wrapper_legacy(self): class TestSpglibTupleConversion(AiidaTestCase): + """Tests for conversion of Spglib tuples.""" def test_simple_to_aiida(self): """ @@ -3385,9 +3397,11 @@ def test_aiida_roundtrip(self): class TestSeekpathExplicitPath(AiidaTestCase): + """Tests for the `get_explicit_kpoints_path` from SeeK-path.""" @unittest.skipIf(not has_seekpath(), 'No seekpath available') def test_simple(self): + """Test a simple case.""" import numpy as np from aiida.plugins import DataFactory @@ -3575,7 +3589,7 @@ def test_band(self): b.set_bands(input_bands, units='ev') b.set_bands(input_bands, occupations=input_occupations) with self.assertRaises(TypeError): - b.set_bands(occupations=input_occupations, units='ev') + b.set_bands(occupations=input_occupations, units='ev') # pylint: disable=no-value-for-parameter b.set_bands(input_bands, occupations=input_occupations, units='ev') bands, occupations = b.get_bands(also_occupations=True) diff --git a/tests/test_dbimporters.py b/tests/test_dbimporters.py index 2154dca36b..086ba53b31 100644 --- a/tests/test_dbimporters.py +++ b/tests/test_dbimporters.py @@ -11,6 +11,7 @@ import unittest from aiida.backends.testbase import AiidaTestCase +from tests.static import STATIC_DIR class TestCodDbImporter(AiidaTestCase): @@ -72,14 +73,12 @@ def test_datatype_checks(self): from aiida.tools.dbimporters.plugins.cod import CodDbImporter codi = CodDbImporter() - messages = ['', - "incorrect value for keyword 'test' -- " + \ - 'only integers and strings are accepted', - "incorrect value for keyword 'test' -- " + \ - 'only strings are accepted', - "incorrect value for keyword 'test' -- " + \ - 'only integers and floats are accepted', - "invalid literal for int() with base 10: 'text'"] + messages = [ + '', "incorrect value for keyword 'test' only integers and strings are accepted", + "incorrect value for keyword 'test' only strings are accepted", + "incorrect value for keyword 'test' only integers and floats are accepted", + "invalid literal for int() with base 10: 'text'" + ] values = [10, 
'text', 'text', '10', 1.0 / 3, [1, 2, 3]] methods = [ # pylint: disable=protected-access @@ -258,8 +257,7 @@ def test_upfentry_creation(self): results = NnincSearchResults([{'id': upf}]) entry = results.at(0) - path_root = os.path.split(__file__)[0] - path_pseudos = os.path.join(path_root, 'fixtures', 'pseudos') + path_pseudos = os.path.join(STATIC_DIR, 'pseudos') with open(os.path.join(path_pseudos, '{}.UPF'.format(upf)), 'r', encoding='utf8') as fpntr: entry._contents = fpntr.read() # pylint: disable=protected-access diff --git a/tests/test_generic.py b/tests/test_generic.py index 0c35d5e0ec..0a05834d67 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -81,7 +81,7 @@ def test_remote(self): self.assertTrue(code.can_run_on(self.computer)) othercomputer = orm.Computer( - name='another_localhost', + label='another_localhost', hostname='localhost', transport_type='local', scheduler_type='pbspro', diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 4e85fb7d62..0f7dc23835 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1393,7 +1393,7 @@ def test_valid_links(self): d2 = SinglefileData(file=handle).store() unsavedcomputer = orm.Computer( - name='localhost2', hostname='localhost', scheduler_type='direct', transport_type='local' + label='localhost2', hostname='localhost', scheduler_type='direct', transport_type='local' ) with self.assertRaises(ValueError): diff --git a/tests/tools/importexport/migration/test_migration.py b/tests/tools/importexport/migration/test_migration.py index 82c08b8e1d..4c470dd1ab 100644 --- a/tests/tools/importexport/migration/test_migration.py +++ b/tests/tools/importexport/migration/test_migration.py @@ -114,9 +114,7 @@ def test_migrate_recursively_specific_version(self): with self.assertRaises(ArchiveMigrationError): migrate_recursively(archive.meta_data, archive.data, None, version='0.2') - # Same version will also raise - with self.assertRaises(ArchiveMigrationError): - migrate_recursively(archive.meta_data, archive.data, None, version='0.3') + migrate_recursively(archive.meta_data, archive.data, None, version='0.3') migrated_version = '0.5' version = migrate_recursively(archive.meta_data, archive.data, None, version=migrated_version) @@ -186,17 +184,11 @@ def test_wrong_versions(self): ) def test_migrate_newest_version(self): - """Test that an exception is raised when an export file with the newest export version is migrated.""" + """Test that migrating the latest version runs without complaints.""" metadata = {'export_version': newest_version} - with self.assertRaises(ArchiveMigrationError): - new_version = migrate_recursively(metadata, {}, None) - - self.assertIsNone( - new_version, - msg='migrate_recursively should not return anything, ' - "hence the 'return' should be None, but instead it is {}".format(new_version) - ) + new_version = migrate_recursively(metadata, {}, None) + self.assertEqual(new_version, newest_version) @with_temp_dir def test_v02_to_newest(self, temp_dir): diff --git a/tests/tools/importexport/migration/test_migrations.py b/tests/tools/importexport/migration/test_migration_array.py similarity index 97% rename from tests/tools/importexport/migration/test_migrations.py rename to tests/tools/importexport/migration/test_migration_array.py index fc2546d259..76106666b5 100644 --- a/tests/tools/importexport/migration/test_migrations.py +++ b/tests/tools/importexport/migration/test_migration_array.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # 
########################################################################### # pylint: disable=redefined-outer-name -"""Test the export archive migrations on the archives included in `tests/fixtures/export/migrate`.""" +"""Test migrating all export archives included in `tests/static/export/migrate`.""" import copy import pytest diff --git a/tests/tools/importexport/orm/test_computers.py b/tests/tools/importexport/orm/test_computers.py index e32ed98bb6..059cbe9184 100644 --- a/tests/tools/importexport/orm/test_computers.py +++ b/tests/tools/importexport/orm/test_computers.py @@ -60,7 +60,7 @@ def test_same_computer_import(self, temp_dir): calc2.seal() # Store locally the computer name - comp_name = str(comp.name) + comp_name = str(comp.label) comp_uuid = str(comp.uuid) # Export the first job calculation @@ -149,14 +149,14 @@ def test_same_computer_different_name_import(self, temp_dir): calc1.seal() # Store locally the computer name - comp1_name = str(comp1.name) + comp1_name = str(comp1.label) # Export the first job calculation filename1 = os.path.join(temp_dir, 'export1.aiida') export([calc1], filename=filename1, silent=True) # Rename the computer - comp1.set_name(comp1_name + '_updated') + comp1.label = comp1_name + '_updated' # Store a second calculation calc2_label = 'calc2' @@ -223,7 +223,7 @@ def test_different_computer_same_name_import(self, temp_dir): # Set the computer name comp1_name = 'localhost_1' - self.computer.set_name(comp1_name) + self.computer.label = comp1_name # Store a calculation calc1_label = 'calc1' @@ -243,7 +243,7 @@ def test_different_computer_same_name_import(self, temp_dir): self.insert_data() # Set the computer name to the same name as before - self.computer.set_name(comp1_name) + self.computer.label = comp1_name # Store a second calculation calc2_label = 'calc2' @@ -263,7 +263,7 @@ def test_different_computer_same_name_import(self, temp_dir): self.insert_data() # Set the computer name to the same name as before - self.computer.set_name(comp1_name) + self.computer.label = comp1_name # Store a third calculation calc3_label = 'calc3' @@ -319,8 +319,8 @@ def test_import_of_computer_json_params(self, temp_dir): # Set the computer name comp1_name = 'localhost_1' comp1_metadata = {'workdir': '/tmp/aiida'} - self.computer.set_name(comp1_name) - self.computer.set_metadata(comp1_metadata) + self.computer.label = comp1_name + self.computer.metadata = comp1_metadata # Store a calculation calc1_label = 'calc1' @@ -368,7 +368,7 @@ def test_import_of_django_sqla_export_file(self): builder = orm.QueryBuilder() builder.append( orm.Computer, project=['metadata'], tag='comp', filters={'name': { - '!==': self.computer.name + '!==': self.computer.label }} ) self.assertEqual(builder.count(), 1, 'Expected only one computer') diff --git a/tests/tools/importexport/test_specific_import.py b/tests/tools/importexport/test_specific_import.py index 9ee506250f..df709fd9fb 100644 --- a/tests/tools/importexport/test_specific_import.py +++ b/tests/tools/importexport/test_specific_import.py @@ -13,14 +13,12 @@ import shutil import tempfile -import unittest - import numpy as np from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common.folders import RepositoryFolder -from aiida.orm.utils.repository import Repository +from aiida.orm.utils._repository import Repository from aiida.tools.importexport import import_data, export from aiida.tools.importexport.common import exceptions @@ -268,7 +266,6 @@ def test_missing_node_repo_folder_import(self, temp_dir): 'Unable to 
find the repository folder for Node with UUID={}'.format(node_uuid), str(exc.exception) ) - @unittest.skip('Reenable when issue #3199 is solve (PR #3242): Fix `extract_tree`') @with_temp_dir def test_empty_repo_folder_export(self, temp_dir): """Check a Node's empty repository folder is exported properly""" diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index 59a39dfcc1..9a90aac453 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines,fixme """ This module contains a set of unittest test classes that can be loaded from the plugin. @@ -41,7 +42,9 @@ def get_all_custom_transports(): thisdir, thisfname = os.path.split(this_full_fname) test_modules = [ - os.path.split(fname)[1][:-3] for fname in os.listdir(thisdir) if fname.endswith('.py') and fname.startswith('test_') + os.path.split(fname)[1][:-3] + for fname in os.listdir(thisdir) + if fname.endswith('.py') and fname.startswith('test_') ] # Remove this module: note that I should be careful because __file__, from @@ -53,13 +56,13 @@ def get_all_custom_transports(): print('Warning, this module ({}) was not found!'.format(thisbasename)) all_custom_transports = {} - for m in test_modules: - module = importlib.import_module('.'.join([modulename, m])) + for module in test_modules: + module = importlib.import_module('.'.join([modulename, module])) custom_transport = module.__dict__.get('plugin_transport', None) if custom_transport is None: - print('Define the plugin_transport variable inside the {} module!'.format(m)) + print('Define the plugin_transport variable inside the {} module!'.format(module)) else: - all_custom_transports[m] = custom_transport + all_custom_transports[module] = custom_transport return all_custom_transports @@ -84,10 +87,9 @@ def test_all_plugins(self): for tr_name, custom_transport in all_custom_transports.items(): try: actual_test_method(self, custom_transport) - except Exception as e: + except Exception as exception: # pylint: disable=broad-except import traceback - - exceptions.append((e, traceback.format_exc(), tr_name)) + exceptions.append((exception, traceback.format_exc(), tr_name)) if exceptions: if all(isinstance(exc[0], AssertionError) for exc in exceptions): @@ -97,8 +99,7 @@ def test_all_plugins(self): messages = ['*** At least one test for a subplugin failed. 
See below ***', ''] for exc in exceptions: - messages.append("*** [For plugin {}]: Exception '{}': {}" - .format(exc[2], type(exc[0]).__name__, exc[0])) + messages.append("*** [For plugin {}]: Exception '{}': {}".format(exc[2], type(exc[0]).__name__, exc[0])) messages.append(exc[1]) raise exception_to_raise('\n'.join(messages)) @@ -137,38 +138,38 @@ def test_makedirs(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - self.assertEqual(location, t.getcwd()) - while t.isdir(directory): + self.assertEqual(location, transport.getcwd()) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) # define folder structure dir_tree = os.path.join('1', '2') # I create the tree - t.makedirs(dir_tree) + transport.makedirs(dir_tree) # verify the existence - self.assertTrue(t.isdir('1')) + self.assertTrue(transport.isdir('1')) self.assertTrue(dir_tree) # try to recreate the same folder with self.assertRaises(OSError): - t.makedirs(dir_tree) + transport.makedirs(dir_tree) # recreate but with ignore flag - t.makedirs(dir_tree, True) + transport.makedirs(dir_tree, True) - t.rmdir(dir_tree) - t.rmdir('1') + transport.rmdir(dir_tree) + transport.rmdir('1') - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_rmtree(self, custom_transport): @@ -180,40 +181,40 @@ def test_rmtree(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - self.assertEqual(location, t.getcwd()) - while t.isdir(directory): + self.assertEqual(location, transport.getcwd()) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) # define folder structure dir_tree = os.path.join('1', '2') # I create the tree - t.makedirs(dir_tree) + transport.makedirs(dir_tree) # remove it - t.rmtree('1') + transport.rmtree('1') # verify the removal - self.assertFalse(t.isdir('1')) + self.assertFalse(transport.isdir('1')) # also tests that it works with a single file # create file local_file_name = 'file.txt' text = 'Viva Verdi\n' - with open(os.path.join(t.getcwd(), local_file_name), 'w', encoding='utf8') as fhandle: + with open(os.path.join(transport.getcwd(), local_file_name), 'w', encoding='utf8') as fhandle: fhandle.write(text) # remove it - t.rmtree(local_file_name) + transport.rmtree(local_file_name) # verify the removal - self.assertFalse(t.isfile(local_file_name)) + self.assertFalse(transport.isfile(local_file_name)) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_listdir(self, custom_transport): @@ -317,9 +318,15 @@ def simplify_attributes(data): 'a2': True, 'a4f': True, 'a': False - }) + } + ) 
self.assertTrue(simplify_attributes(trans.listdir_withattributes('.', 'a?')), {'as': True, 'a2': True}) - self.assertTrue(simplify_attributes(trans.listdir_withattributes('.', 'a[2-4]*')), {'a2': True, 'a4f': True}) + self.assertTrue( + simplify_attributes(trans.listdir_withattributes('.', 'a[2-4]*')), { + 'a2': True, + 'a4f': True + } + ) for this_dir in list_of_dir: trans.rmdir(this_dir) @@ -332,29 +339,30 @@ def simplify_attributes(data): @run_for_all_plugins def test_dir_creation_deletion(self, custom_transport): + """Test creating and deleting directories.""" # Imports required later import random import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - self.assertEqual(location, t.getcwd()) - while t.isdir(directory): + self.assertEqual(location, transport.getcwd()) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) + transport.mkdir(directory) with self.assertRaises(OSError): # I create twice the same directory - t.mkdir(directory) + transport.mkdir(directory) - t.isdir(directory) - self.assertFalse(t.isfile(directory)) - t.rmdir(directory) + transport.isdir(directory) + self.assertFalse(transport.isfile(directory)) + transport.rmdir(directory) @run_for_all_plugins def test_dir_copy(self, custom_transport): @@ -367,30 +375,30 @@ def test_dir_copy(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - while t.isdir(directory): + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) + transport.mkdir(directory) dest_directory = directory + '_copy' - t.copy(directory, dest_directory) + transport.copy(directory, dest_directory) with self.assertRaises(ValueError): - t.copy(directory, '') + transport.copy(directory, '') with self.assertRaises(ValueError): - t.copy('', directory) + transport.copy('', directory) - t.rmdir(directory) - t.rmdir(dest_directory) + transport.rmdir(directory) + transport.rmdir(dest_directory) @run_for_all_plugins - def test_dir_permissions_creation_modification(self, custom_transport): + def test_dir_permissions_creation_modification(self, custom_transport): # pylint: disable=invalid-name """ verify if chmod raises IOError when trying to change bits on a non-existing folder @@ -400,51 +408,51 @@ def test_dir_permissions_creation_modification(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - while t.isdir(directory): + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) # create directory with non default permissions - t.mkdir(directory) + transport.mkdir(directory) # change permissions - t.chmod(directory, 
0o777) + transport.chmod(directory, 0o777) # test if the security bits have changed - self.assertEqual(t.get_mode(directory), 0o777) + self.assertEqual(transport.get_mode(directory), 0o777) # change permissions - t.chmod(directory, 0o511) + transport.chmod(directory, 0o511) # test if the security bits have changed - self.assertEqual(t.get_mode(directory), 0o511) + self.assertEqual(transport.get_mode(directory), 0o511) # TODO : bug in paramiko. When changing the directory to very low \ # I cannot set it back to higher permissions - ## TODO: probably here we should then check for - ## the new directory modes. To see if we want a higher - ## level function to ask for the mode, or we just - ## use get_attribute - t.chdir(directory) + # TODO: probably here we should then check for + # the new directory modes. To see if we want a higher + # level function to ask for the mode, or we just + # use get_attribute + transport.chdir(directory) # change permissions of an empty string, non existing folder. fake_dir = '' with self.assertRaises(IOError): - t.chmod(fake_dir, 0o777) + transport.chmod(fake_dir, 0o777) fake_dir = 'pippo' with self.assertRaises(IOError): # chmod to a non existing folder - t.chmod(fake_dir, 0o777) + transport.chmod(fake_dir, 0o777) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_dir_reading_permissions(self, custom_transport): @@ -457,37 +465,37 @@ def test_dir_reading_permissions(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - while t.isdir(directory): + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) # create directory with non default permissions - t.mkdir(directory) + transport.mkdir(directory) # change permissions to low ones - t.chmod(directory, 0) + transport.chmod(directory, 0) # test if the security bits have changed - self.assertEqual(t.get_mode(directory), 0) + self.assertEqual(transport.get_mode(directory), 0) - old_cwd = t.getcwd() + old_cwd = transport.getcwd() with self.assertRaises(IOError): - t.chdir(directory) + transport.chdir(directory) - new_cwd = t.getcwd() + new_cwd = transport.getcwd() self.assertEqual(old_cwd, new_cwd) # TODO : the test leaves a directory even if it is successful # The bug is in paramiko. 
After lowering the permissions, # I cannot restore them to higher values - #t.rmdir(directory) + # transport.rmdir(directory) @run_for_all_plugins def test_isfile_isdir_to_empty_string(self, custom_transport): @@ -497,11 +505,11 @@ def test_isfile_isdir_to_empty_string(self, custom_transport): """ import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) - t.chdir(location) - self.assertFalse(t.isdir('')) - self.assertFalse(t.isfile('')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) + transport.chdir(location) + self.assertFalse(transport.isdir('')) + self.assertFalse(transport.isfile('')) @run_for_all_plugins def test_isfile_isdir_to_non_existing_string(self, custom_transport): @@ -511,14 +519,14 @@ def test_isfile_isdir_to_non_existing_string(self, custom_transport): """ import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) - t.chdir(location) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) + transport.chdir(location) fake_folder = 'pippo' - self.assertFalse(t.isfile(fake_folder)) - self.assertFalse(t.isdir(fake_folder)) + self.assertFalse(transport.isfile(fake_folder)) + self.assertFalse(transport.isdir(fake_folder)) with self.assertRaises(IOError): - t.chdir(fake_folder) + transport.chdir(fake_folder) @run_for_all_plugins def test_chdir_to_empty_string(self, custom_transport): @@ -529,11 +537,11 @@ def test_chdir_to_empty_string(self, custom_transport): """ import os - with custom_transport as t: - new_dir = t.normalize(os.path.join('/', 'tmp')) - t.chdir(new_dir) - t.chdir('') - self.assertEqual(new_dir, t.getcwd()) + with custom_transport as transport: + new_dir = transport.normalize(os.path.join('/', 'tmp')) + transport.chdir(new_dir) + transport.chdir('') + self.assertEqual(new_dir, transport.getcwd()) class TestPutGetFile(unittest.TestCase): @@ -546,6 +554,7 @@ class TestPutGetFile(unittest.TestCase): @run_for_all_plugins def test_put_and_get(self, custom_transport): + """Test putting and getting files.""" import os import random import string @@ -554,14 +563,14 @@ def test_put_and_get(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) - while t.isdir(directory): + with custom_transport as transport: + transport.chdir(remote_dir) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_dir, directory, 'file.txt') remote_file_name = 'file_remote.txt' @@ -572,12 +581,12 @@ def test_put_and_get(self, custom_transport): fhandle.write(text) # here use full path in src and dst - t.put(local_file_name, remote_file_name) - t.get(remote_file_name, retrieved_file_name) - t.putfile(local_file_name, remote_file_name) - t.getfile(remote_file_name, retrieved_file_name) + transport.put(local_file_name, remote_file_name) + transport.get(remote_file_name, retrieved_file_name) + transport.putfile(local_file_name, remote_file_name) + transport.getfile(remote_file_name, retrieved_file_name) - list_of_files = t.listdir('.') + list_of_files = transport.listdir('.') # it is False because local_file_name has the full path, # while list_of_files has not self.assertFalse(local_file_name in list_of_files) @@ -585,11 +594,11 @@ def 
test_put_and_get(self, custom_transport): self.assertFalse(retrieved_file_name in list_of_files) os.remove(local_file_name) - t.remove(remote_file_name) + transport.remove(remote_file_name) os.remove(retrieved_file_name) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_put_get_abs_path(self, custom_transport): @@ -604,14 +613,14 @@ def test_put_get_abs_path(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) - while t.isdir(directory): + with custom_transport as transport: + transport.chdir(remote_dir) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) partial_file_name = 'file.txt' local_file_name = os.path.join(local_dir, directory, 'file.txt') @@ -623,36 +632,36 @@ def test_put_get_abs_path(self, custom_transport): # partial_file_name is not an abs path with self.assertRaises(ValueError): - t.put(partial_file_name, remote_file_name) + transport.put(partial_file_name, remote_file_name) with self.assertRaises(ValueError): - t.putfile(partial_file_name, remote_file_name) + transport.putfile(partial_file_name, remote_file_name) # retrieved_file_name does not exist with self.assertRaises(OSError): - t.put(retrieved_file_name, remote_file_name) + transport.put(retrieved_file_name, remote_file_name) with self.assertRaises(OSError): - t.putfile(retrieved_file_name, remote_file_name) + transport.putfile(retrieved_file_name, remote_file_name) # remote_file_name does not exist with self.assertRaises(IOError): - t.get(remote_file_name, retrieved_file_name) + transport.get(remote_file_name, retrieved_file_name) with self.assertRaises(IOError): - t.getfile(remote_file_name, retrieved_file_name) + transport.getfile(remote_file_name, retrieved_file_name) - t.put(local_file_name, remote_file_name) - t.putfile(local_file_name, remote_file_name) + transport.put(local_file_name, remote_file_name) + transport.putfile(local_file_name, remote_file_name) # local filename is not an abs path with self.assertRaises(ValueError): - t.get(remote_file_name, 'delete_me.txt') + transport.get(remote_file_name, 'delete_me.txt') with self.assertRaises(ValueError): - t.getfile(remote_file_name, 'delete_me.txt') + transport.getfile(remote_file_name, 'delete_me.txt') - t.remove(remote_file_name) + transport.remove(remote_file_name) os.remove(local_file_name) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_put_get_empty_string(self, custom_transport): @@ -668,14 +677,14 @@ def test_put_get_empty_string(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) - while t.isdir(directory): + with custom_transport as transport: + transport.chdir(remote_dir) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_dir, directory, 'file_local.txt') remote_file_name = 'file_remote.txt' @@ -688,48 +697,48 @@ def test_put_get_empty_string(self, custom_transport): # localpath is an empty string # ValueError because it is not an abs 
path with self.assertRaises(ValueError): - t.put('', remote_file_name) + transport.put('', remote_file_name) with self.assertRaises(ValueError): - t.putfile('', remote_file_name) + transport.putfile('', remote_file_name) # remote path is an empty string with self.assertRaises(IOError): - t.put(local_file_name, '') + transport.put(local_file_name, '') with self.assertRaises(IOError): - t.putfile(local_file_name, '') + transport.putfile(local_file_name, '') - t.put(local_file_name, remote_file_name) + transport.put(local_file_name, remote_file_name) # overwrite the remote_file_name - t.putfile(local_file_name, remote_file_name) + transport.putfile(local_file_name, remote_file_name) # remote path is an empty string with self.assertRaises(IOError): - t.get('', retrieved_file_name) + transport.get('', retrieved_file_name) with self.assertRaises(IOError): - t.getfile('', retrieved_file_name) + transport.getfile('', retrieved_file_name) # local path is an empty string # ValueError because it is not an abs path with self.assertRaises(ValueError): - t.get(remote_file_name, '') + transport.get(remote_file_name, '') with self.assertRaises(ValueError): - t.getfile(remote_file_name, '') + transport.getfile(remote_file_name, '') # TODO : get doesn't retrieve empty files. # Is it what we want? - t.get(remote_file_name, retrieved_file_name) + transport.get(remote_file_name, retrieved_file_name) # overwrite retrieved_file_name - t.getfile(remote_file_name, retrieved_file_name) + transport.getfile(remote_file_name, retrieved_file_name) os.remove(local_file_name) - t.remove(remote_file_name) + transport.remove(remote_file_name) # If it couldn't end the copy, it leaves what he did on # local file - self.assertTrue('file_retrieved.txt' in t.listdir('.')) + self.assertTrue('file_retrieved.txt' in transport.listdir('.')) os.remove(retrieved_file_name) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) class TestPutGetTree(unittest.TestCase): @@ -742,6 +751,7 @@ class TestPutGetTree(unittest.TestCase): @run_for_all_plugins def test_put_and_get(self, custom_transport): + """Test putting and getting files.""" import os import random import string @@ -750,9 +760,9 @@ def test_put_and_get(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: + with custom_transport as transport: - t.chdir(remote_dir) + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique @@ -765,7 +775,7 @@ def test_put_and_get(self, custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - t.chdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_subfolder, 'file.txt') @@ -776,14 +786,14 @@ def test_put_and_get(self, custom_transport): # here use full path in src and dst for i in range(2): if i == 0: - t.put(local_subfolder, remote_subfolder) - t.get(remote_subfolder, retrieved_subfolder) + transport.put(local_subfolder, remote_subfolder) + transport.get(remote_subfolder, retrieved_subfolder) else: - t.puttree(local_subfolder, remote_subfolder) - t.gettree(remote_subfolder, retrieved_subfolder) + transport.puttree(local_subfolder, remote_subfolder) + transport.gettree(remote_subfolder, retrieved_subfolder) # Here I am mixing the local with the remote fold - list_of_dirs = t.listdir('.') + list_of_dirs = transport.listdir('.') # # it is False because local_file_name has the full path, # 
# while list_of_files has not self.assertFalse(local_subfolder in list_of_dirs) @@ -792,8 +802,8 @@ def test_put_and_get(self, custom_transport): self.assertTrue('tmp1' in list_of_dirs) self.assertTrue('tmp3' in list_of_dirs) - list_pushed_file = t.listdir('tmp2') - list_retrieved_file = t.listdir('tmp3') + list_pushed_file = transport.listdir('tmp2') + list_retrieved_file = transport.listdir('tmp3') self.assertTrue('file.txt' in list_pushed_file) self.assertTrue('file.txt' in list_retrieved_file) @@ -801,23 +811,25 @@ def test_put_and_get(self, custom_transport): shutil.rmtree(local_subfolder) shutil.rmtree(retrieved_subfolder) - t.rmtree(remote_subfolder) + transport.rmtree(remote_subfolder) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_put_and_get_overwrite(self, custom_transport): - import os, shutil + """Test putting and getting files with overwrites.""" + import os import random + import shutil import string local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique @@ -830,7 +842,7 @@ def test_put_and_get_overwrite(self, custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - t.chdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_subfolder, 'file.txt') @@ -838,32 +850,33 @@ def test_put_and_get_overwrite(self, custom_transport): with open(local_file_name, 'w', encoding='utf8') as fhandle: fhandle.write(text) - t.put(local_subfolder, remote_subfolder) - t.get(remote_subfolder, retrieved_subfolder) + transport.put(local_subfolder, remote_subfolder) + transport.get(remote_subfolder, retrieved_subfolder) # by defaults rewrite everything - t.put(local_subfolder, remote_subfolder) - t.get(remote_subfolder, retrieved_subfolder) + transport.put(local_subfolder, remote_subfolder) + transport.get(remote_subfolder, retrieved_subfolder) with self.assertRaises(OSError): - t.put(local_subfolder, remote_subfolder, overwrite=False) + transport.put(local_subfolder, remote_subfolder, overwrite=False) with self.assertRaises(OSError): - t.get(remote_subfolder, retrieved_subfolder, overwrite=False) + transport.get(remote_subfolder, retrieved_subfolder, overwrite=False) with self.assertRaises(OSError): - t.puttree(local_subfolder, remote_subfolder, overwrite=False) + transport.puttree(local_subfolder, remote_subfolder, overwrite=False) with self.assertRaises(OSError): - t.gettree(remote_subfolder, retrieved_subfolder, overwrite=False) + transport.gettree(remote_subfolder, retrieved_subfolder, overwrite=False) shutil.rmtree(local_subfolder) shutil.rmtree(retrieved_subfolder) - t.rmtree(remote_subfolder) - # t.rmtree(remote_subfolder) + transport.rmtree(remote_subfolder) + # transport.rmtree(remote_subfolder) # here I am mixing inevitably the local and the remote folder - t.chdir('..') - t.rmtree(directory) + transport.chdir('..') + transport.rmtree(directory) @run_for_all_plugins def test_copy(self, custom_transport): + """Test copying.""" import os import random import string @@ -872,15 +885,15 @@ def test_copy(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + 
transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_base_dir = os.path.join(local_dir, directory, 'local') os.mkdir(local_base_dir) @@ -895,46 +908,48 @@ def test_copy(self, custom_transport): fhandle.write(text) # first test the copy. Copy of two files matching patterns, into a folder - t.copy(os.path.join('local', '*.txt'), '.') - self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(t.listdir('.'))) - t.remove('a.txt') - t.remove('c.txt') + transport.copy(os.path.join('local', '*.txt'), '.') + self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(transport.listdir('.'))) + transport.remove('a.txt') + transport.remove('c.txt') # second test copy. Copy of two folders - t.copy('local', 'prova') - self.assertEqual(set(['prova', 'local']), set(t.listdir('.'))) - self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(t.listdir('prova'))) - t.rmtree('prova') + transport.copy('local', 'prova') + self.assertEqual(set(['prova', 'local']), set(transport.listdir('.'))) + self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(transport.listdir('prova'))) + transport.rmtree('prova') # third test copy. Can copy one file into a new file - t.copy(os.path.join('local', '*.tmp'), 'prova') - self.assertEqual(set(['prova', 'local']), set(t.listdir('.'))) - t.remove('prova') + transport.copy(os.path.join('local', '*.tmp'), 'prova') + self.assertEqual(set(['prova', 'local']), set(transport.listdir('.'))) + transport.remove('prova') # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with self.assertRaises(OSError): - t.copy(os.path.join('local', '*.txt'), 'prova') + transport.copy(os.path.join('local', '*.txt'), 'prova') # fifth test, copying one file into a folder - t.mkdir('prova') - t.copy(os.path.join('local', 'a.txt'), 'prova') - self.assertEqual(set(t.listdir('prova')), set(['a.txt'])) - t.rmtree('prova') + transport.mkdir('prova') + transport.copy(os.path.join('local', 'a.txt'), 'prova') + self.assertEqual(set(transport.listdir('prova')), set(['a.txt'])) + transport.rmtree('prova') # sixth test, copying one file into a file - t.copy(os.path.join('local', 'a.txt'), 'prova') - self.assertTrue(t.isfile('prova')) - t.remove('prova') + transport.copy(os.path.join('local', 'a.txt'), 'prova') + self.assertTrue(transport.isfile('prova')) + transport.remove('prova') # copy of folder into an existing folder - #NOTE: the command cp has a different behavior on Mac vs Ubuntu - #tests performed locally on a Mac may result in a failure. - t.mkdir('prova') - t.copy('local', 'prova') - self.assertEqual(set(['local']), set(t.listdir('prova'))) - self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(t.listdir(os.path.join('prova', 'local')))) - t.rmtree('prova') + # NOTE: the command cp has a different behavior on Mac vs Ubuntu + # tests performed locally on a Mac may result in a failure. 
+ transport.mkdir('prova')
+ transport.copy('local', 'prova')
+ self.assertEqual(set(['local']), set(transport.listdir('prova')))
+ self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(transport.listdir(os.path.join('prova', 'local'))))
+ transport.rmtree('prova')
# exit
- t.chdir('..')
- t.rmtree(directory)
+ transport.chdir('..')
+ transport.rmtree(directory)
@run_for_all_plugins
def test_put(self, custom_transport):
+ """Test putting files."""
+ # pylint: disable=too-many-statements
# exactly the same tests of copy, just with the put function
# and therefore the local path must be absolute
import os
@@ -945,15 +960,15 @@ def test_put(self, custom_transport):
remote_dir = local_dir
directory = 'tmp_try'
- with custom_transport as t:
- t.chdir(remote_dir)
+ with custom_transport as transport:
+ transport.chdir(remote_dir)
while os.path.exists(os.path.join(local_dir, directory)):
# I append a random letter/number until it is unique
directory += random.choice(string.ascii_uppercase + string.digits)
- t.mkdir(directory)
- t.chdir(directory)
+ transport.mkdir(directory)
+ transport.chdir(directory)
local_base_dir = os.path.join(local_dir, directory, 'local')
os.mkdir(local_base_dir)
@@ -967,72 +982,75 @@ def test_put(self, custom_transport):
with open(filename, 'w', encoding='utf8') as fhandle:
fhandle.write(text)
- # first test put. Copy of two files matching patterns, into a folder
- t.put(os.path.join(local_base_dir, '*.txt'), '.')
- self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(t.listdir('.')))
- t.remove('a.txt')
- t.remove('c.txt')
+ # first test put. Copy of two files matching patterns, into a folder
+ transport.put(os.path.join(local_base_dir, '*.txt'), '.')
+ self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(transport.listdir('.')))
+ transport.remove('a.txt')
+ transport.remove('c.txt')
# second. Copy of folder into a non existing folder
- t.put(local_base_dir, 'prova')
- self.assertEqual(set(['prova', 'local']), set(t.listdir('.')))
- self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(t.listdir('prova')))
- t.rmtree('prova')
+ transport.put(local_base_dir, 'prova')
+ self.assertEqual(set(['prova', 'local']), set(transport.listdir('.')))
+ self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(transport.listdir('prova')))
+ transport.rmtree('prova')
# third. copy of folder into an existing folder
- t.mkdir('prova')
- t.put(local_base_dir, 'prova')
- self.assertEqual(set(['prova', 'local']), set(t.listdir('.')))
- self.assertEqual(set(['local']), set(t.listdir('prova')))
- self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(t.listdir(os.path.join('prova', 'local'))))
- t.rmtree('prova')
+ transport.mkdir('prova')
+ transport.put(local_base_dir, 'prova')
+ self.assertEqual(set(['prova', 'local']), set(transport.listdir('.')))
+ self.assertEqual(set(['local']), set(transport.listdir('prova')))
+ self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(transport.listdir(os.path.join('prova', 'local'))))
+ transport.rmtree('prova')
# third test copy. 
Can copy one file into a new file - t.put(os.path.join(local_base_dir, '*.tmp'), 'prova') - self.assertEqual(set(['prova', 'local']), set(t.listdir('.'))) - t.remove('prova') + transport.put(os.path.join(local_base_dir, '*.tmp'), 'prova') + self.assertEqual(set(['prova', 'local']), set(transport.listdir('.'))) + transport.remove('prova') # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with self.assertRaises(OSError): - t.put(os.path.join(local_base_dir, '*.txt'), 'prova') + transport.put(os.path.join(local_base_dir, '*.txt'), 'prova') # copy of folder into file with open(os.path.join(local_dir, directory, 'existing.txt'), 'w', encoding='utf8') as fhandle: fhandle.write(text) with self.assertRaises(OSError): - t.put(os.path.join(local_base_dir), 'existing.txt') - t.remove('existing.txt') + transport.put(os.path.join(local_base_dir), 'existing.txt') + transport.remove('existing.txt') # fifth test, copying one file into a folder - t.mkdir('prova') - t.put(os.path.join(local_base_dir, 'a.txt'), 'prova') - self.assertEqual(set(t.listdir('prova')), set(['a.txt'])) - t.rmtree('prova') + transport.mkdir('prova') + transport.put(os.path.join(local_base_dir, 'a.txt'), 'prova') + self.assertEqual(set(transport.listdir('prova')), set(['a.txt'])) + transport.rmtree('prova') # sixth test, copying one file into a file - t.put(os.path.join(local_base_dir, 'a.txt'), 'prova') - self.assertTrue(t.isfile('prova')) - t.remove('prova') + transport.put(os.path.join(local_base_dir, 'a.txt'), 'prova') + self.assertTrue(transport.isfile('prova')) + transport.remove('prova') # exit - t.chdir('..') - t.rmtree(directory) + transport.chdir('..') + transport.rmtree(directory) @run_for_all_plugins def test_get(self, custom_transport): + """Test getting files.""" + # pylint: disable=too-many-statements # exactly the same tests of copy, just with the put function # and therefore the local path must be absolute import os import random - import string, shutil + import shutil + import string local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_base_dir = os.path.join(local_dir, directory, 'local') local_destination = os.path.join(local_dir, directory) @@ -1048,51 +1066,53 @@ def test_get(self, custom_transport): fhandle.write(text) # first test put. Copy of two files matching patterns, into a folder - t.get(os.path.join('local', '*.txt'), local_destination) + transport.get(os.path.join('local', '*.txt'), local_destination) self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(os.listdir(local_destination))) os.remove(os.path.join(local_destination, 'a.txt')) os.remove(os.path.join(local_destination, 'c.txt')) # second. 
Copy of folder into a non existing folder - t.get('local', os.path.join(local_destination, 'prova')) + transport.get('local', os.path.join(local_destination, 'prova')) self.assertEqual(set(['prova', 'local']), set(os.listdir(local_destination))) self.assertEqual( - set(['a.txt', 'b.tmp', 'c.txt']), set(os.listdir(os.path.join(local_destination, 'prova')))) + set(['a.txt', 'b.tmp', 'c.txt']), set(os.listdir(os.path.join(local_destination, 'prova'))) + ) shutil.rmtree(os.path.join(local_destination, 'prova')) # third. copy of folder into an existing folder os.mkdir(os.path.join(local_destination, 'prova')) - t.get('local', os.path.join(local_destination, 'prova')) + transport.get('local', os.path.join(local_destination, 'prova')) self.assertEqual(set(['prova', 'local']), set(os.listdir(local_destination))) self.assertEqual(set(['local']), set(os.listdir(os.path.join(local_destination, 'prova')))) self.assertEqual( - set(['a.txt', 'b.tmp', 'c.txt']), set(os.listdir(os.path.join(local_destination, 'prova', 'local')))) + set(['a.txt', 'b.tmp', 'c.txt']), set(os.listdir(os.path.join(local_destination, 'prova', 'local'))) + ) shutil.rmtree(os.path.join(local_destination, 'prova')) # third test copy. Can copy one file into a new file - t.get(os.path.join('local', '*.tmp'), os.path.join(local_destination, 'prova')) + transport.get(os.path.join('local', '*.tmp'), os.path.join(local_destination, 'prova')) self.assertEqual(set(['prova', 'local']), set(os.listdir(local_destination))) os.remove(os.path.join(local_destination, 'prova')) # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with self.assertRaises(OSError): - t.get(os.path.join('local', '*.txt'), os.path.join(local_destination, 'prova')) + transport.get(os.path.join('local', '*.txt'), os.path.join(local_destination, 'prova')) # copy of folder into file with open(os.path.join(local_destination, 'existing.txt'), 'w', encoding='utf8') as fhandle: fhandle.write(text) with self.assertRaises(OSError): - t.get('local', os.path.join(local_destination, 'existing.txt')) + transport.get('local', os.path.join(local_destination, 'existing.txt')) os.remove(os.path.join(local_destination, 'existing.txt')) # fifth test, copying one file into a folder os.mkdir(os.path.join(local_destination, 'prova')) - t.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) + transport.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) self.assertEqual(set(os.listdir(os.path.join(local_destination, 'prova'))), set(['a.txt'])) shutil.rmtree(os.path.join(local_destination, 'prova')) # sixth test, copying one file into a file - t.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) + transport.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) self.assertTrue(os.path.isfile(os.path.join(local_destination, 'prova'))) os.remove(os.path.join(local_destination, 'prova')) # exit - t.chdir('..') - t.rmtree(directory) + transport.chdir('..') + transport.rmtree(directory) @run_for_all_plugins def test_put_get_abs_path(self, custom_transport): @@ -1107,8 +1127,8 @@ def test_put_get_abs_path(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique @@ -1121,7 +1141,7 @@ def 
test_put_get_abs_path(self, custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - t.chdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_subfolder, 'file.txt') fhandle = open(local_file_name, 'w', encoding='utf8') @@ -1129,44 +1149,44 @@ def test_put_get_abs_path(self, custom_transport): # 'tmp1' is not an abs path with self.assertRaises(ValueError): - t.put('tmp1', remote_subfolder) + transport.put('tmp1', remote_subfolder) with self.assertRaises(ValueError): - t.putfile('tmp1', remote_subfolder) + transport.putfile('tmp1', remote_subfolder) with self.assertRaises(ValueError): - t.puttree('tmp1', remote_subfolder) + transport.puttree('tmp1', remote_subfolder) # 'tmp3' does not exist with self.assertRaises(OSError): - t.put(retrieved_subfolder, remote_subfolder) + transport.put(retrieved_subfolder, remote_subfolder) with self.assertRaises(OSError): - t.putfile(retrieved_subfolder, remote_subfolder) + transport.putfile(retrieved_subfolder, remote_subfolder) with self.assertRaises(OSError): - t.puttree(retrieved_subfolder, remote_subfolder) + transport.puttree(retrieved_subfolder, remote_subfolder) # remote_file_name does not exist with self.assertRaises(IOError): - t.get('non_existing', retrieved_subfolder) + transport.get('non_existing', retrieved_subfolder) with self.assertRaises(IOError): - t.getfile('non_existing', retrieved_subfolder) + transport.getfile('non_existing', retrieved_subfolder) with self.assertRaises(IOError): - t.gettree('non_existing', retrieved_subfolder) + transport.gettree('non_existing', retrieved_subfolder) - t.put(local_subfolder, remote_subfolder) + transport.put(local_subfolder, remote_subfolder) # local filename is not an abs path with self.assertRaises(ValueError): - t.get(remote_subfolder, 'delete_me_tree') + transport.get(remote_subfolder, 'delete_me_tree') with self.assertRaises(ValueError): - t.getfile(remote_subfolder, 'delete_me_tree') + transport.getfile(remote_subfolder, 'delete_me_tree') with self.assertRaises(ValueError): - t.gettree(remote_subfolder, 'delete_me_tree') + transport.gettree(remote_subfolder, 'delete_me_tree') os.remove(os.path.join(local_subfolder, 'file.txt')) os.rmdir(local_subfolder) - t.rmtree(remote_subfolder) + transport.rmtree(remote_subfolder) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_put_get_empty_string(self, custom_transport): @@ -1182,8 +1202,8 @@ def test_put_get_empty_string(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique @@ -1196,7 +1216,7 @@ def test_put_get_empty_string(self, custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - t.chdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_subfolder, 'file.txt') text = 'Viva Verdi\n' @@ -1206,42 +1226,43 @@ def test_put_get_empty_string(self, custom_transport): # localpath is an empty string # ValueError because it is not an abs path with self.assertRaises(ValueError): - t.puttree('', remote_subfolder) + transport.puttree('', remote_subfolder) # remote path is an empty string with self.assertRaises(IOError): - t.puttree(local_subfolder, '') + 
transport.puttree(local_subfolder, '') - t.puttree(local_subfolder, remote_subfolder) + transport.puttree(local_subfolder, remote_subfolder) # remote path is an empty string with self.assertRaises(IOError): - t.gettree('', retrieved_subfolder) + transport.gettree('', retrieved_subfolder) # local path is an empty string # ValueError because it is not an abs path with self.assertRaises(ValueError): - t.gettree(remote_subfolder, '') + transport.gettree(remote_subfolder, '') # TODO : get doesn't retrieve empty files. # Is it what we want? - t.gettree(remote_subfolder, retrieved_subfolder) + transport.gettree(remote_subfolder, retrieved_subfolder) os.remove(os.path.join(local_subfolder, 'file.txt')) os.rmdir(local_subfolder) - t.remove(os.path.join(remote_subfolder, 'file.txt')) - t.rmdir(remote_subfolder) + transport.remove(os.path.join(remote_subfolder, 'file.txt')) + transport.rmdir(remote_subfolder) # If it couldn't end the copy, it leaves what he did on local file # here I am mixing local with remote - self.assertTrue('file.txt' in t.listdir('tmp3')) + self.assertTrue('file.txt' in transport.listdir('tmp3')) os.remove(os.path.join(retrieved_subfolder, 'file.txt')) os.rmdir(retrieved_subfolder) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins - def test_gettree_nested_directory(self, custom_transport): + def test_gettree_nested_directory(self, custom_transport): # pylint: disable=no-self-use + """Test `gettree` for a nested directory.""" import os import tempfile @@ -1278,66 +1299,69 @@ def test_exec_pwd(self, custom_transport): # Start value delete_at_end = False - with custom_transport as t: + with custom_transport as transport: # To compare with: getcwd uses the normalized ('realpath') path - location = t.normalize('/tmp') + location = transport.normalize('/tmp') subfolder = """_'s f"#""" # A folder with characters to escape subfolder_fullpath = os.path.join(location, subfolder) - t.chdir(location) - if not t.isdir(subfolder): + transport.chdir(location) + if not transport.isdir(subfolder): # Since I created the folder, I will remember to # delete it at the end of this test delete_at_end = True - t.mkdir(subfolder) + transport.mkdir(subfolder) - self.assertTrue(t.isdir(subfolder)) - t.chdir(subfolder) + self.assertTrue(transport.isdir(subfolder)) + transport.chdir(subfolder) - self.assertEqual(subfolder_fullpath, t.getcwd()) - retcode, stdout, stderr = t.exec_command_wait('pwd') + self.assertEqual(subfolder_fullpath, transport.getcwd()) + retcode, stdout, stderr = transport.exec_command_wait('pwd') self.assertEqual(retcode, 0) # I have to strip it because 'pwd' returns a trailing \n self.assertEqual(stdout.strip(), subfolder_fullpath) self.assertEqual(stderr, '') if delete_at_end: - t.chdir(location) - t.rmdir(subfolder) + transport.chdir(location) + transport.rmdir(subfolder) @run_for_all_plugins def test_exec_with_stdin_string(self, custom_transport): + """Test command execution with a stdin string.""" test_string = str('some_test String') - with custom_transport as t: - retcode, stdout, stderr = t.exec_command_wait('cat', stdin=test_string) + with custom_transport as transport: + retcode, stdout, stderr = transport.exec_command_wait('cat', stdin=test_string) self.assertEqual(retcode, 0) self.assertEqual(stdout, test_string) self.assertEqual(stderr, '') @run_for_all_plugins def test_exec_with_stdin_unicode(self, custom_transport): + """Test command execution with a unicode stdin string.""" test_string = 'some_test String' - 
with custom_transport as t: - retcode, stdout, stderr = t.exec_command_wait('cat', stdin=test_string) + with custom_transport as transport: + retcode, stdout, stderr = transport.exec_command_wait('cat', stdin=test_string) self.assertEqual(retcode, 0) self.assertEqual(stdout, test_string) self.assertEqual(stderr, '') @run_for_all_plugins def test_exec_with_stdin_filelike(self, custom_transport): - + """Test command execution with a stdin from filelike.""" test_string = 'some_test String' stdin = io.StringIO(test_string) - with custom_transport as t: - retcode, stdout, stderr = t.exec_command_wait('cat', stdin=stdin) + with custom_transport as transport: + retcode, stdout, stderr = transport.exec_command_wait('cat', stdin=stdin) self.assertEqual(retcode, 0) self.assertEqual(stdout, test_string) self.assertEqual(stderr, '') @run_for_all_plugins def test_exec_with_wrong_stdin(self, custom_transport): + """Test command execution with incorrect stdin string.""" # I pass a number - with custom_transport as t: + with custom_transport as transport: with self.assertRaises(ValueError): - t.exec_command_wait('cat', stdin=1) + transport.exec_command_wait('cat', stdin=1) diff --git a/tests/transports/test_local.py b/tests/transports/test_local.py index 712398183c..fa88e00468 100644 --- a/tests/transports/test_local.py +++ b/tests/transports/test_local.py @@ -7,9 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tests for the `LocalTransport`.""" import unittest -from aiida.transports.plugins.local import * +from aiida.transports.plugins.local import LocalTransport +from aiida.transports.transport import TransportInternalError # This will be used by test_all_plugins @@ -22,10 +24,11 @@ class TestGeneric(unittest.TestCase): """ def test_whoami(self): + """Test the `whoami` command.""" import getpass - with LocalTransport() as t: - self.assertEqual(t.whoami(), getpass.getuser()) + with LocalTransport() as transport: + self.assertEqual(transport.whoami(), getpass.getuser()) class TestBasicConnection(unittest.TestCase): @@ -34,16 +37,27 @@ class TestBasicConnection(unittest.TestCase): """ def test_closed_connection(self): - from aiida.transports.transport import TransportInternalError + """Test running a command on a closed connection.""" with self.assertRaises(TransportInternalError): - t = LocalTransport() - t.listdir() + transport = LocalTransport() + transport.listdir() - def test_basic(self): + @staticmethod + def test_basic(): + """Test constructor.""" with LocalTransport(): pass -if __name__ == '__main__': - unittest.main() +def test_gotocomputer(): + """Test gotocomputer""" + with LocalTransport() as transport: + cmd_str = transport.gotocomputer_command('/remote_dir/') + + expected_str = ( + """bash -c "if [ -d '/remote_dir/' ] ;""" + """ then cd '/remote_dir/' ; bash -l ; else echo ' ** The directory' ; """ + """echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' 
; fi" """ + ) + assert cmd_str == expected_str diff --git a/tests/transports/test_ssh.py b/tests/transports/test_ssh.py index d7a0a7d2a4..6c8c713622 100644 --- a/tests/transports/test_ssh.py +++ b/tests/transports/test_ssh.py @@ -7,16 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Test ssh plugin on localhost -""" -import unittest +"""Test the `SshTransport` plugin on localhost.""" import logging +import unittest -import aiida.transports -import aiida.transports.transport import paramiko + from aiida.transports.plugins.ssh import SshTransport +from aiida.transports.transport import TransportInternalError # This will be used by test_all_plugins @@ -29,20 +27,25 @@ class TestBasicConnection(unittest.TestCase): """ def test_closed_connection_ssh(self): - with self.assertRaises(aiida.transports.transport.TransportInternalError): - t = SshTransport(machine='localhost') - t._exec_command_internal('ls') + """Test calling command on a closed connection.""" + with self.assertRaises(TransportInternalError): + transport = SshTransport(machine='localhost') + transport._exec_command_internal('ls') # pylint: disable=protected-access def test_closed_connection_sftp(self): - with self.assertRaises(aiida.transports.transport.TransportInternalError): - t = SshTransport(machine='localhost') - t.listdir() + """Test calling sftp command on a closed connection.""" + with self.assertRaises(TransportInternalError): + transport = SshTransport(machine='localhost') + transport.listdir() - def test_auto_add_policy(self): + @staticmethod + def test_auto_add_policy(): + """Test the auto add policy.""" with SshTransport(machine='localhost', timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy'): pass def test_no_host_key(self): + """Test if there is no host key.""" # Disable logging to avoid output during test logging.disable(logging.ERROR) @@ -54,5 +57,14 @@ def test_no_host_key(self): logging.disable(logging.NOTSET) -if __name__ == '__main__': - unittest.main() +def test_gotocomputer(): + """Test gotocomputer""" + with SshTransport(machine='localhost', timeout=30, use_login_shell=False, key_policy='AutoAddPolicy') as transport: + cmd_str = transport.gotocomputer_command('/remote_dir/') + + expected_str = ( + """ssh -t localhost "if [ -d '/remote_dir/' ] ;""" + """ then cd '/remote_dir/' ; bash ; else echo ' ** The directory' ; """ + """echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' 
; fi" """ + ) + assert cmd_str == expected_str diff --git a/tests/utils/archives.py b/tests/utils/archives.py index a9c0deb125..f9964263ef 100644 --- a/tests/utils/archives.py +++ b/tests/utils/archives.py @@ -17,6 +17,7 @@ from aiida.common.exceptions import NotExistent from aiida.tools.importexport.common.archive import extract_tar, extract_zip from aiida.common.folders import SandboxFolder +from tests.static import STATIC_DIR def get_archive_file(archive, filepath=None, external_module=None): @@ -24,7 +25,7 @@ def get_archive_file(archive, filepath=None, external_module=None): The expected path for these files: - tests.fixtures.filepath + tests.static.filepath :param archive: the relative filename of the archive :param filepath: str of directories of where to find archive (starting "/"s are irrelevant) @@ -55,11 +56,8 @@ def get_archive_file(archive, filepath=None, external_module=None): dirpath_archive = os.path.join(external_path, dirpath_archive) else: - # Add absolute path to local repo's fixtures - dirpath_current = os.path.dirname(os.path.realpath(__file__)) - dirpath_migrate = os.path.join(dirpath_current, os.pardir, 'fixtures') - - dirpath_archive = os.path.join(dirpath_migrate, dirpath_archive) + # Add absolute path to local repo's static + dirpath_archive = os.path.join(STATIC_DIR, dirpath_archive) if not os.path.isfile(dirpath_archive): dirpath_parent = os.path.dirname(dirpath_archive) diff --git a/tests/utils/processes.py b/tests/utils/processes.py index 7115b4fdc5..2d8d7fd6f9 100644 --- a/tests/utils/processes.py +++ b/tests/utils/processes.py @@ -95,7 +95,7 @@ def define(cls, spec): 123, 'GENERIC_EXIT_CODE', message='This process should not be used as cache.', invalidates_cache=True ) - def run(self): # pylint: disable=inconsistent-return-statements + def run(self): if self.inputs.return_exit_code: return self.exit_codes.GENERIC_EXIT_CODE # pylint: disable=no-member diff --git a/utils/dependency_management.py b/utils/dependency_management.py index dbd79fb499..5a04142ca7 100755 --- a/utils/dependency_management.py +++ b/utils/dependency_management.py @@ -21,7 +21,7 @@ import click import yaml -import toml +import tomlkit as toml ROOT = Path(__file__).resolve().parent.parent # repository root @@ -161,13 +161,23 @@ def generate_environment_yml(): @cli.command() -def generate_pyproject_toml(): - """Generate 'pyproject.toml' file.""" +def update_pyproject_toml(): + """Generate a 'pyproject.toml' file, or update an existing one. + + This function generates/updates the ``build-system`` section, + to be consistent with the 'setup.json' file. 
+ """ + + # read the current file + toml_path = ROOT / 'pyproject.toml' + if toml_path.exists(): + pyproject = toml.loads(toml_path.read_text(encoding='utf8')) + else: + pyproject = {} # Read the requirements from 'setup.json' setup_cfg = _load_setup_cfg() install_requirements = [Requirement.parse(r) for r in setup_cfg['install_requires']] - for requirement in install_requirements: if requirement.name == 'reentry': reentry_requirement = requirement @@ -175,15 +185,17 @@ def generate_pyproject_toml(): else: raise DependencySpecificationError("Failed to find reentry requirement in 'setup.json'.") - pyproject = { - 'build-system': { - 'requires': ['setuptools>=40.8.0', 'wheel', - str(reentry_requirement), 'fastentrypoints~=0.12'], - 'build-backend': 'setuptools.build_meta:__legacy__', - } - } - with open(ROOT / 'pyproject.toml', 'w') as file: - toml.dump(pyproject, file) + # update the build-system key + pyproject.setdefault('build-system', {}) + pyproject['build-system'].update({ + 'requires': ['setuptools>=40.8.0,<50', 'wheel', + str(reentry_requirement), 'fastentrypoints~=0.12'], + 'build-backend': + 'setuptools.build_meta:__legacy__', + }) + + # write the new file + toml_path.write_text(toml.dumps(pyproject), encoding='utf8') @cli.command() @@ -191,7 +203,7 @@ def generate_pyproject_toml(): def generate_all(ctx): """Generate all dependent requirement files.""" ctx.invoke(generate_environment_yml) - ctx.invoke(generate_pyproject_toml) + ctx.invoke(update_pyproject_toml) @cli.command('validate-environment-yml', help="Validate 'environment.yml'.") @@ -281,18 +293,15 @@ def validate_pyproject_toml(): else: raise DependencySpecificationError("Failed to find reentry requirement in 'setup.json'.") - try: - with open(ROOT / 'pyproject.toml') as file: - pyproject = toml.load(file) - pyproject_requires = [Requirement.parse(r) for r in pyproject['build-system']['requires']] + pyproject_file = ROOT / 'pyproject.toml' + if not pyproject_file.exists(): + raise DependencySpecificationError("The 'pyproject.toml' file is missing!") - if reentry_requirement not in pyproject_requires: - raise DependencySpecificationError( - "Missing requirement '{}' in 'pyproject.toml'.".format(reentry_requirement) - ) + pyproject = toml.loads(pyproject_file.read_text(encoding='utf8')) + pyproject_requires = [Requirement.parse(r) for r in pyproject['build-system']['requires']] - except FileNotFoundError: - raise DependencySpecificationError("The 'pyproject.toml' file is missing!") + if reentry_requirement not in pyproject_requires: + raise DependencySpecificationError("Missing requirement '{}' in 'pyproject.toml'.".format(reentry_requirement)) click.secho('Pyproject.toml dependency specification is consistent.', fg='green') diff --git a/utils/validate_consistency.py b/utils/validate_consistency.py index f592aecb76..696f87366f 100644 --- a/utils/validate_consistency.py +++ b/utils/validate_consistency.py @@ -17,11 +17,11 @@ * reentry dependency in pyproject.toml """ - +import collections +import json import os import sys -import json -from collections import OrderedDict + import click FILENAME_TOML = 'pyproject.toml' @@ -35,7 +35,7 @@ def get_setup_json(): """Return the `setup.json` as a python dictionary """ with open(FILEPATH_SETUP_JSON, 'r') as fil: - return json.load(fil, object_pairs_hook=OrderedDict) + return json.load(fil, object_pairs_hook=collections.OrderedDict) def write_setup_json(data): @@ -143,8 +143,10 @@ def validate_verdi_documentation(): from click import Context from aiida.cmdline.commands.cmd_verdi import 
verdi + width = 90 # The maximum width of the formatted help strings in characters + # Set the `verdi data` command to isolated mode such that external plugin commands are not discovered - ctx = Context(verdi) + ctx = Context(verdi, terminal_width=width) command = verdi.get_command(ctx, 'data') command.set_exclude_external_plugins(True) @@ -159,7 +161,7 @@ def validate_verdi_documentation(): block = ['{}\n{}\n{}\n\n'.format(header, '=' * len(header), message)] for name, command in sorted(verdi.commands.items()): - ctx = click.Context(command) + ctx = click.Context(command, terminal_width=width) header_label = '.. _reference:command-line:verdi-{name:}:'.format(name=name) header_string = '``verdi {name:}``'.format(name=name) @@ -168,7 +170,7 @@ def validate_verdi_documentation(): block.append(header_label + '\n\n') block.append(header_string + '\n') block.append(header_underline + '\n\n') - block.append('::\n\n') # Mark the beginning of a literal block + block.append('.. code:: console\n\n') # Mark the beginning of a literal block for line in ctx.get_help().split('\n'): if line: block.append(' {}\n'.format(line))
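The `update_pyproject_toml` rewrite above swaps the plain `toml` package for `tomlkit` so that an existing 'pyproject.toml' can be updated in place: only the `build-system` table is rewritten, and any other sections already present in the file survive the read/write round trip. Below is a minimal sketch of that read-update-write pattern, mirroring the `tomlkit` calls used in the diff; the file location and the requirement list are illustrative assumptions, not the repository's actual values.

    import pathlib

    import tomlkit

    path = pathlib.Path('pyproject.toml')  # assumed to sit in the current directory for this sketch

    # Load the existing document if there is one; tomlkit preserves comments,
    # formatting and unrelated tables when the document is dumped again.
    pyproject = tomlkit.loads(path.read_text(encoding='utf8')) if path.exists() else {}

    # Touch only the 'build-system' table, leaving everything else as it was.
    pyproject.setdefault('build-system', {})
    pyproject['build-system'].update({
        'requires': ['setuptools>=40.8.0,<50', 'wheel'],  # illustrative requirements only
        'build-backend': 'setuptools.build_meta:__legacy__',
    })

    path.write_text(tomlkit.dumps(pyproject), encoding='utf8')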
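The `terminal_width=width` arguments added in `utils/validate_consistency.py` pin the width that click uses when rendering help text, so the generated reference documentation does not depend on the terminal the script happens to run in. A small sketch of the effect follows; the `hello` command is a made-up example, not part of aiida-core.

    import click

    @click.command()
    @click.option('--name', default='world', help='Who to greet in the generated output.')
    def hello(name):
        """Print a friendly greeting."""
        click.echo('Hello {}!'.format(name))

    # Render the help page at a fixed width of 90 characters, independent of the real terminal.
    ctx = click.Context(hello, terminal_width=90)
    print(ctx.get_help())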