diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 656c8d7e..ad524fb2 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -88,7 +88,7 @@ jobs: run: pip install -e .[tests] - name: Run pytest - run: pytest -sv --cov=plumpy test + run: pytest -s --cov=plumpy tests - name: Create xml coverage run: coverage xml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f324ec7f..8d813780 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: continuous-integration +name: ci on: [push, pull_request] @@ -9,10 +9,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: '3.8' + python-version: '3.12' - name: Install Python dependencies run: pip install -e .[pre-commit] @@ -26,6 +26,7 @@ jobs: strategy: matrix: python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + fail-fast: false services: rabbitmq: @@ -42,10 +43,10 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install python dependencies - run: pip install -e .[tests] + run: pip install .[tests] - name: Run pytest - run: pytest -sv --cov=plumpy test + run: pytest -s --cov=plumpy tests/ - name: Create xml coverage run: coverage xml diff --git a/.github/workflows/validate_release_tag.py b/.github/workflows/validate_release_tag.py index bdd35537..4caf68b8 100644 --- a/.github/workflows/validate_release_tag.py +++ b/.github/workflows/validate_release_tag.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Validate that the version in the tag label matches the version of the package.""" + import argparse import ast from pathlib import Path @@ -17,8 +18,11 @@ def get_version_from_module(content: str) -> str: try: return next( - ast.literal_eval(statement.value) for statement in module.body if isinstance(statement, ast.Assign) - for target in statement.targets if isinstance(target, ast.Name) and target.id == '__version__' + ast.literal_eval(statement.value) + for statement in module.body + if isinstance(statement, ast.Assign) + for target in statement.targets + if isinstance(target, ast.Name) and target.id == '__version__' ) except StopIteration as exception: raise IOError('Unable to find the `__version__` attribute in the module.') from exception diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cae9888f..970f9d71 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,46 +1,35 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: double-quote-string-fixer - - id: end-of-file-fixer - - id: fix-encoding-pragma - - id: mixed-line-ending - - id: trailing-whitespace + - id: double-quote-string-fixer + - id: end-of-file-fixer + - id: fix-encoding-pragma + - id: mixed-line-ending + - id: trailing-whitespace -- repo: https://github.com/ikamensh/flynt/ - rev: '0.77' + - repo: https://github.com/ikamensh/flynt/ + rev: 1.0.1 hooks: - - id: flynt + - id: flynt + args: [--line-length=120, --fail-on-change] -- repo: https://github.com/pycqa/isort - rev: '5.12.0' + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.0 hooks: - - id: isort + - id: ruff-format + exclude: &exclude_ruff > + (?x)^( + tests/.*| + )$ -- repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 - hooks: - - id: yapf - name: yapf - types: [python] - args: ['-i'] - additional_dependencies: ['toml'] - -- repo: https://github.com/PyCQA/pylint - rev: v2.15.8 - hooks: - - id: pylint - language: system - exclude: > - (?x)^( - docs/source/conf.py| - test/.*| - )$ + - id: ruff + exclude: *exclude_ruff + args: [--fix, --exit-non-zero-on-fix, --show-fixes] -- repo: local + - repo: local hooks: - - id: mypy + - id: mypy name: mypy entry: mypy args: [--config-file=pyproject.toml] @@ -49,6 +38,6 @@ repos: require_serial: true pass_filenames: true files: >- - (?x)^( - src/.*py| - )$ + (?x)^( + src/.*py| + )$ diff --git a/docs/source/conf.py b/docs/source/conf.py index a1c6f26e..b1a2a019 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -8,11 +8,9 @@ import filecmp import os -from pathlib import Path import shutil -import subprocess -import sys import tempfile +from pathlib import Path import plumpy @@ -32,8 +30,12 @@ master_doc = 'index' language = None extensions = [ - 'myst_nb', 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx', - 'IPython.sphinxext.ipython_console_highlighting' + 'myst_nb', + 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.viewcode', + 'sphinx.ext.intersphinx', + 'IPython.sphinxext.ipython_console_highlighting', ] # List of patterns, relative to source directory, that match files and @@ -46,14 +48,14 @@ intersphinx_mapping = { 'python': ('https://docs.python.org/3.8', None), - 'kiwipy': ('https://kiwipy.readthedocs.io/en/latest/', None) + 'kiwipy': ('https://kiwipy.readthedocs.io/en/latest/', None), } myst_enable_extensions = ['colon_fence', 'deflist', 'html_image', 'smartquotes', 'substitution'] myst_url_schemes = ('http', 'https', 'mailto') myst_substitutions = { 'rabbitmq': '[RabbitMQ](https://www.rabbitmq.com/)', - 'kiwipy': '[kiwipy](https://kiwipy.readthedocs.io)' + 'kiwipy': '[kiwipy](https://kiwipy.readthedocs.io)', } jupyter_execute_notebooks = 'cache' execution_show_tb = 'READTHEDOCS' in os.environ @@ -84,7 +86,7 @@ 'use_issues_button': True, 'path_to_docs': 'docs', 'use_edit_page_button': True, - 'extra_navbar': '' + 'extra_navbar': '', } # API Documentation @@ -112,9 +114,17 @@ def run_apidoc(app): # this ensures that document rebuilds are not triggered every time (due to change in file mtime) with tempfile.TemporaryDirectory() as tmpdirname: options = [ - '-o', tmpdirname, - str(package_dir), '--private', '--force', '--module-first', '--separate', '--no-toc', '--maxdepth', '4', - '-q' + '-o', + tmpdirname, + str(package_dir), + '--private', + '--force', + '--module-first', + '--separate', + '--no-toc', + '--maxdepth', + '4', + '-q', ] os.environ['SPHINX_APIDOC_OPTIONS'] = 'members,special-members,private-members,undoc-members,show-inheritance' diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index 90194728..c1fdb3b2 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -34,10 +34,11 @@ "outputs": [], "source": [ "import asyncio\n", - "from pprint import pprint\n", "import time\n", + "from pprint import pprint\n", "\n", "import kiwipy\n", + "\n", "import plumpy\n", "\n", "# this is required because jupyter is already running an event loop\n", @@ -116,16 +117,16 @@ ], "source": [ "class SimpleProcess(plumpy.Process):\n", - "\n", " def run(self):\n", " print(self.state.name)\n", - " \n", + "\n", + "\n", "process = SimpleProcess()\n", "print(process.state.name)\n", "process.execute()\n", "print(process.state.name)\n", - "print(\"Success\", process.is_successful)\n", - "print(\"Result\", process.result())" + "print('Success', process.is_successful)\n", + "print('Result', process.result())" ] }, { @@ -204,17 +205,16 @@ ], "source": [ "class SpecProcess(plumpy.Process):\n", - " \n", " @classmethod\n", " def define(cls, spec: plumpy.ProcessSpec):\n", " super().define(spec)\n", " spec.input('input1', valid_type=str, help='A help string')\n", " spec.output('output1')\n", - " \n", + "\n", " spec.input_namespace('input2')\n", " spec.input('input2.input2a')\n", " spec.input('input2.input2b', default='default')\n", - " \n", + "\n", " spec.output_namespace('output2')\n", " spec.output('output2.output2a')\n", " spec.output('output2.output2b')\n", @@ -223,12 +223,10 @@ " self.out('output1', self.inputs.input1)\n", " self.out('output2.output2a', self.inputs.input2.input2a)\n", " self.out('output2.output2b', self.inputs.input2.input2b)\n", - " \n", + "\n", + "\n", "pprint(SpecProcess.spec().get_description())\n", - "process = SpecProcess(inputs={\n", - " 'input1': 'my input',\n", - " 'input2': {'input2a': 'other input'}\n", - "})\n", + "process = SpecProcess(inputs={'input1': 'my input', 'input2': {'input2a': 'other input'}})\n", "process.execute()\n", "process.outputs" ] @@ -276,20 +274,20 @@ ], "source": [ "class ContinueProcess(plumpy.Process):\n", - "\n", " def run(self):\n", - " print(\"running\")\n", + " print('running')\n", " return plumpy.Continue(self.continue_fn)\n", - " \n", + "\n", " def continue_fn(self):\n", - " print(\"continuing\")\n", + " print('continuing')\n", " # message is stored in the process status\n", - " return plumpy.Kill(\"I was killed\")\n", - " \n", + " return plumpy.Kill('I was killed')\n", + "\n", + "\n", "process = ContinueProcess()\n", "try:\n", " process.execute()\n", - "except plumpy.KilledError as error:\n", + "except plumpy.KilledError:\n", " pass\n", "\n", "print(process.state)\n", @@ -330,7 +328,6 @@ ], "source": [ "class WaitListener(plumpy.ProcessListener):\n", - "\n", " def on_process_running(self, process):\n", " print(process.state.name)\n", "\n", @@ -338,14 +335,15 @@ " print(process.state.name)\n", " process.resume()\n", "\n", - "class WaitProcess(plumpy.Process):\n", "\n", + "class WaitProcess(plumpy.Process):\n", " def run(self):\n", " return plumpy.Wait(self.resume_fn)\n", - " \n", + "\n", " def resume_fn(self):\n", " return plumpy.Stop(None, True)\n", "\n", + "\n", "process = WaitProcess()\n", "print(process.state.name)\n", "\n", @@ -394,33 +392,32 @@ ], "source": [ "async def async_fn():\n", - " print(\"async_fn start\")\n", - " await asyncio.sleep(.01)\n", - " print(\"async_fn end\")\n", + " print('async_fn start')\n", + " await asyncio.sleep(0.01)\n", + " print('async_fn end')\n", + "\n", "\n", "class NamedProcess(plumpy.Process):\n", - " \n", " @classmethod\n", " def define(cls, spec: plumpy.ProcessSpec):\n", " super().define(spec)\n", " spec.input('name')\n", "\n", " def run(self):\n", - " print(self.inputs.name, \"run\")\n", + " print(self.inputs.name, 'run')\n", " return plumpy.Continue(self.continue_fn)\n", "\n", " def continue_fn(self):\n", - " print(self.inputs.name, \"continued\")\n", + " print(self.inputs.name, 'continued')\n", + "\n", + "\n", + "process1 = NamedProcess({'name': 'process1'})\n", + "process2 = NamedProcess({'name': 'process2'})\n", "\n", - "process1 = NamedProcess({\"name\": \"process1\"})\n", - "process2 = NamedProcess({\"name\": \"process2\"})\n", "\n", "async def execute():\n", - " await asyncio.gather(\n", - " async_fn(),\n", - " process1.step_until_terminated(),\n", - " process2.step_until_terminated()\n", - " )\n", + " await asyncio.gather(async_fn(), process1.step_until_terminated(), process2.step_until_terminated())\n", + "\n", "\n", "plumpy.get_event_loop().run_until_complete(execute())" ] @@ -468,31 +465,33 @@ ], "source": [ "class SimpleProcess(plumpy.Process):\n", - " \n", " def run(self):\n", " print(self.get_name())\n", - " \n", - "class PauseProcess(plumpy.Process):\n", "\n", + "\n", + "class PauseProcess(plumpy.Process):\n", " def run(self):\n", - " print(f\"{self.get_name()}: pausing\")\n", + " print(f'{self.get_name()}: pausing')\n", " self.pause()\n", - " print(f\"{self.get_name()}: continue step\")\n", + " print(f'{self.get_name()}: continue step')\n", " return plumpy.Continue(self.next_step)\n", - " \n", + "\n", " def next_step(self):\n", - " print(f\"{self.get_name()}: next step\")\n", + " print(f'{self.get_name()}: next step')\n", + "\n", "\n", "pause_proc = PauseProcess()\n", "simple_proc = SimpleProcess()\n", "\n", + "\n", "async def play(proc):\n", " while True:\n", " if proc.paused:\n", - " print(f\"{proc.get_name()}: playing (state={proc.state.name})\")\n", + " print(f'{proc.get_name()}: playing (state={proc.state.name})')\n", " proc.play()\n", " break\n", "\n", + "\n", "async def execute():\n", " return await asyncio.gather(\n", " pause_proc.step_until_terminated(),\n", @@ -500,6 +499,7 @@ " play(pause_proc),\n", " )\n", "\n", + "\n", "outputs = plumpy.get_event_loop().run_until_complete(execute())" ] }, @@ -555,7 +555,8 @@ "\n", " def step2(self):\n", " print('step2')\n", - " \n", + "\n", + "\n", "workchain = SimpleWorkChain()\n", "output = workchain.execute()" ] @@ -601,11 +602,7 @@ " super().define(spec)\n", " spec.input('run', valid_type=bool)\n", "\n", - " spec.outline(\n", - " plumpy.if_(cls.if_step)(\n", - " cls.conditional_step\n", - " )\n", - " )\n", + " spec.outline(plumpy.if_(cls.if_step)(cls.conditional_step))\n", "\n", " def if_step(self):\n", " print(' if')\n", @@ -613,12 +610,13 @@ "\n", " def conditional_step(self):\n", " print(' conditional')\n", - " \n", - "workchain = IfWorkChain({\"run\": False})\n", + "\n", + "\n", + "workchain = IfWorkChain({'run': False})\n", "print('execute False')\n", "output = workchain.execute()\n", "\n", - "workchain = IfWorkChain({\"run\": True})\n", + "workchain = IfWorkChain({'run': True})\n", "print('execute True')\n", "output = workchain.execute()" ] @@ -666,23 +664,19 @@ " super().define(spec)\n", " spec.input('steps', valid_type=int, default=3)\n", "\n", - " spec.outline(\n", - " cls.init_step,\n", - " plumpy.while_(cls.while_step)(\n", - " cls.conditional_step\n", - " )\n", - " )\n", - " \n", + " spec.outline(cls.init_step, plumpy.while_(cls.while_step)(cls.conditional_step))\n", + "\n", " def init_step(self):\n", " self.ctx.iterator = 0\n", "\n", " def while_step(self):\n", " self.ctx.iterator += 1\n", - " return (self.ctx.iterator <= self.inputs.steps)\n", + " return self.ctx.iterator <= self.inputs.steps\n", "\n", " def conditional_step(self):\n", " print('step', self.ctx.iterator)\n", - " \n", + "\n", + "\n", "workchain = WhileWorkChain()\n", "output = workchain.execute()" ] @@ -714,13 +708,12 @@ "outputs": [], "source": [ "async def awaitable_func(msg):\n", - " await asyncio.sleep(.01)\n", + " await asyncio.sleep(0.01)\n", " print(msg)\n", " return True\n", - " \n", "\n", - "class InternalProcess(plumpy.Process):\n", "\n", + "class InternalProcess(plumpy.Process):\n", " @classmethod\n", " def define(cls, spec):\n", " super().define(spec)\n", @@ -733,7 +726,6 @@ "\n", "\n", "class InterstepWorkChain(plumpy.WorkChain):\n", - "\n", " @classmethod\n", " def define(cls, spec):\n", " super().define(spec)\n", @@ -745,31 +737,24 @@ " cls.step2,\n", " cls.step3,\n", " )\n", - " \n", + "\n", " def step1(self):\n", " print(self.inputs.name, 'step1')\n", "\n", " def step2(self):\n", " print(self.inputs.name, 'step2')\n", - " time.sleep(.01)\n", - " \n", + " time.sleep(0.01)\n", + "\n", " if self.inputs.awaitable:\n", " self.to_context(\n", - " awaitable=asyncio.ensure_future(\n", - " awaitable_func(f'{self.inputs.name} step2 awaitable'),\n", - " loop=self.loop\n", - " )\n", + " awaitable=asyncio.ensure_future(awaitable_func(f'{self.inputs.name} step2 awaitable'), loop=self.loop)\n", " )\n", " if self.inputs.process:\n", - " self.to_context(\n", - " process=self.launch(\n", - " InternalProcess, \n", - " inputs={'name': f'{self.inputs.name} step2 process'})\n", - " )\n", + " self.to_context(process=self.launch(InternalProcess, inputs={'name': f'{self.inputs.name} step2 process'}))\n", "\n", " def step3(self):\n", " print(self.inputs.name, 'step3')\n", - " print(f\" ctx={self.ctx}\")" + " print(f' ctx={self.ctx}')" ] }, { @@ -803,11 +788,10 @@ "wkchain1 = InterstepWorkChain({'name': 'wkchain1'})\n", "wkchain2 = InterstepWorkChain({'name': 'wkchain2'})\n", "\n", + "\n", "async def execute():\n", - " return await asyncio.gather(\n", - " wkchain1.step_until_terminated(),\n", - " wkchain2.step_until_terminated()\n", - " )\n", + " return await asyncio.gather(wkchain1.step_until_terminated(), wkchain2.step_until_terminated())\n", + "\n", "\n", "output = plumpy.get_event_loop().run_until_complete(execute())" ] @@ -847,11 +831,10 @@ "wkchain1 = InterstepWorkChain({'name': 'wkchain1', 'process': True})\n", "wkchain2 = InterstepWorkChain({'name': 'wkchain2', 'process': True})\n", "\n", + "\n", "async def execute():\n", - " return await asyncio.gather(\n", - " wkchain1.step_until_terminated(),\n", - " wkchain2.step_until_terminated()\n", - " )\n", + " return await asyncio.gather(wkchain1.step_until_terminated(), wkchain2.step_until_terminated())\n", + "\n", "\n", "output = plumpy.get_event_loop().run_until_complete(execute())" ] @@ -882,11 +865,10 @@ "wkchain1 = InterstepWorkChain({'name': 'wkchain1', 'awaitable': True})\n", "wkchain2 = InterstepWorkChain({'name': 'wkchain2', 'awaitable': True})\n", "\n", + "\n", "async def execute():\n", - " return await asyncio.gather(\n", - " wkchain1.step_until_terminated(),\n", - " wkchain2.step_until_terminated()\n", - " )\n", + " return await asyncio.gather(wkchain1.step_until_terminated(), wkchain2.step_until_terminated())\n", + "\n", "\n", "output = plumpy.get_event_loop().run_until_complete(execute())" ] @@ -926,11 +908,10 @@ "wkchain1 = InterstepWorkChain({'name': 'wkchain1', 'process': True, 'awaitable': True})\n", "wkchain2 = InterstepWorkChain({'name': 'wkchain2', 'process': True, 'awaitable': True})\n", "\n", + "\n", "async def execute():\n", - " return await asyncio.gather(\n", - " wkchain1.step_until_terminated(),\n", - " wkchain2.step_until_terminated()\n", - " )\n", + " return await asyncio.gather(wkchain1.step_until_terminated(), wkchain2.step_until_terminated())\n", + "\n", "\n", "output = plumpy.get_event_loop().run_until_complete(execute())" ] @@ -972,8 +953,8 @@ "source": [ "persister = plumpy.InMemoryPersister()\n", "\n", - "class PersistWorkChain(plumpy.WorkChain):\n", "\n", + "class PersistWorkChain(plumpy.WorkChain):\n", " @classmethod\n", " def define(cls, spec):\n", " super().define(spec)\n", @@ -982,10 +963,10 @@ " cls.step2,\n", " cls.step3,\n", " )\n", - " \n", + "\n", " def __repr__(self):\n", - " return f\"PersistWorkChain(ctx={self.ctx})\"\n", - " \n", + " return f'PersistWorkChain(ctx={self.ctx})'\n", + "\n", " def init_step(self):\n", " self.ctx.step = 1\n", " persister.save_checkpoint(self, 'init')\n", @@ -997,7 +978,8 @@ " def step3(self):\n", " self.ctx.step += 1\n", " persister.save_checkpoint(self, 'step3')\n", - " \n", + "\n", + "\n", "workchain = PersistWorkChain()\n", "workchain.execute()\n", "workchain" @@ -1129,9 +1111,11 @@ "source": [ "communicator = kiwipy.LocalCommunicator()\n", "\n", + "\n", "class SimpleProcess(plumpy.Process):\n", " pass\n", "\n", + "\n", "process = SimpleProcess(communicator=communicator)\n", "\n", "pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())" @@ -1161,43 +1145,42 @@ ], "source": [ "class ControlledWorkChain(plumpy.WorkChain):\n", - "\n", " @classmethod\n", " def define(cls, spec):\n", " super().define(spec)\n", " spec.input('steps', valid_type=int, default=10)\n", " spec.output('result', valid_type=int)\n", "\n", - " spec.outline(\n", - " cls.init_step,\n", - " plumpy.while_(cls.while_step)(cls.loop_step),\n", - " cls.final_step\n", - " )\n", - " \n", + " spec.outline(cls.init_step, plumpy.while_(cls.while_step)(cls.loop_step), cls.final_step)\n", + "\n", " def init_step(self):\n", " self.ctx.iterator = 0\n", "\n", " def while_step(self):\n", - " return (self.ctx.iterator <= self.inputs.steps)\n", - " \n", + " return self.ctx.iterator <= self.inputs.steps\n", + "\n", " def loop_step(self):\n", " self.ctx.iterator += 1\n", "\n", " def final_step(self):\n", " self.out('result', self.ctx.iterator)\n", "\n", + "\n", "loop_communicator = plumpy.wrap_communicator(kiwipy.LocalCommunicator())\n", "loop_communicator.add_task_subscriber(plumpy.ProcessLauncher())\n", "controller = plumpy.RemoteProcessController(loop_communicator)\n", "\n", "wkchain = ControlledWorkChain(communicator=loop_communicator)\n", - " \n", + "\n", + "\n", "async def run_wait():\n", " return await controller.launch_process(ControlledWorkChain)\n", "\n", + "\n", "async def run_nowait():\n", " return await controller.launch_process(ControlledWorkChain, nowait=True)\n", "\n", + "\n", "print(plumpy.get_event_loop().run_until_complete(run_wait()))\n", "print(plumpy.get_event_loop().run_until_complete(run_nowait()))" ] diff --git a/examples/process_helloworld.py b/examples/process_helloworld.py index cf043eba..db2eff0f 100644 --- a/examples/process_helloworld.py +++ b/examples/process_helloworld.py @@ -3,7 +3,6 @@ class HelloWorld(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) diff --git a/examples/process_wait_and_resume.py b/examples/process_wait_and_resume.py index 03e8b57a..d4aa20b4 100644 --- a/examples/process_wait_and_resume.py +++ b/examples/process_wait_and_resume.py @@ -5,7 +5,6 @@ class WaitForResumeProc(plumpy.Process): - def run(self): print(f'Now I am running: {self.state}') return plumpy.Wait(self.after_resume_and_exec) @@ -15,12 +14,10 @@ def after_resume_and_exec(self): kwargs = { - 'connection_params': { - 'url': 'amqp://guest:guest@127.0.0.1:5672/' - }, + 'connection_params': {'url': 'amqp://guest:guest@127.0.0.1:5672/'}, 'message_exchange': 'WaitForResume.uuid-0', 'task_exchange': 'WaitForResume.uuid-0', - 'task_queue': 'WaitForResume.uuid-0' + 'task_queue': 'WaitForResume.uuid-0', } if __name__ == '__main__': diff --git a/examples/workchain_simple.py b/examples/workchain_simple.py index 078de3ca..aa189d3b 100644 --- a/examples/workchain_simple.py +++ b/examples/workchain_simple.py @@ -3,7 +3,6 @@ class AddAndMulWF(plumpy.WorkChain): - @classmethod def define(cls, spec): super().define(spec) diff --git a/pyproject.toml b/pyproject.toml index 78c34807..9d996d5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,16 +53,15 @@ docs = [ 'importlib-metadata~=4.12.0', ] pre-commit = [ - 'mypy==1.3.0', + 'mypy==1.13.0', 'pre-commit~=2.2', - 'pylint==2.15.8', 'types-pyyaml' ] tests = [ 'ipykernel==6.12.1', - 'pytest==6.2.5', - 'pytest-asyncio==0.16.0', - 'pytest-cov==3.0.0', + 'pytest~=7.0', + 'pytest-asyncio~=0.12,<0.17', + 'pytest-cov~=4.1', 'pytest-notebook>=0.8.0', 'shortuuid==1.0.8', 'importlib-resources~=5.2', @@ -75,18 +74,43 @@ name = 'plumpy' exclude = [ 'docs/', 'examples/', - 'test/', + 'tests/', ] [tool.flynt] line-length = 120 fail-on-change = true -[tool.isort] -force_sort_within_sections = true -include_trailing_comma = true -line_length = 120 -multi_line_output = 3 +[tool.ruff] +line-length = 120 + +[tool.ruff.format] +quote-style = 'single' + +[tool.ruff.lint] +ignore = [ + 'F403', # Star imports unable to detect undefined names + 'F405', # Import may be undefined or defined from star imports + 'PLR0911', # Too many return statements + 'PLR0912', # Too many branches + 'PLR0913', # Too many arguments in function definition + 'PLR0915', # Too many statements + 'PLR2004', # Magic value used in comparison + 'RUF005', # Consider iterable unpacking instead of concatenation + 'RUF012' # Mutable class attributes should be annotated with `typing.ClassVar` +] +select = [ + 'E', # pydocstyle + 'W', # pydocstyle + 'F', # pyflakes + 'I', # isort + 'N', # pep8-naming + 'PLC', # pylint-convention + 'PLE', # pylint-error + 'PLR', # pylint-refactor + 'PLW', # pylint-warning + 'RUF' # ruff +] [tool.mypy] show_error_codes = true diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index ea88f872..6f94b5bf 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- # mypy: disable-error-code=name-defined -# pylint: disable=undefined-variable __version__ = '0.22.3' import logging @@ -21,9 +20,20 @@ from .workchains import * __all__ = ( - events.__all__ + exceptions.__all__ + processes.__all__ + utils.__all__ + futures.__all__ + mixins.__all__ + - persistence.__all__ + communications.__all__ + process_comms.__all__ + process_listener.__all__ + - workchains.__all__ + loaders.__all__ + ports.__all__ + process_states.__all__ + events.__all__ + + exceptions.__all__ + + processes.__all__ + + utils.__all__ + + futures.__all__ + + mixins.__all__ + + persistence.__all__ + + communications.__all__ + + process_comms.__all__ + + process_listener.__all__ + + workchains.__all__ + + loaders.__all__ + + ports.__all__ + + process_states.__all__ ) @@ -32,7 +42,6 @@ # https://docs.python.org/3.1/library/logging.html#library-config # for more details class NullHandler(logging.Handler): - def emit(self, record: logging.LogRecord) -> None: pass diff --git a/src/plumpy/base/__init__.py b/src/plumpy/base/__init__.py index 79450590..a4e3132e 100644 --- a/src/plumpy/base/__init__.py +++ b/src/plumpy/base/__init__.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -# pylint: disable=undefined-variable -# type: ignore from .state_machine import * from .utils import * -__all__ = (state_machine.__all__ + utils.__all__) +__all__ = state_machine.__all__ + utils.__all__ # type: ignore[name-defined] diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index b62825e1..d99d0705 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The state machine for processes""" + import enum import functools import inspect @@ -13,24 +14,24 @@ from .utils import call_with_super_check, super_check -__all__ = ['StateMachine', 'StateMachineMeta', 'event', 'TransitionFailed'] +__all__ = ['StateMachine', 'StateMachineMeta', 'TransitionFailed', 'event'] _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] # pylint: disable=invalid-name -EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] # pylint: disable=invalid-name +LABEL_TYPE = Union[None, enum.Enum, str] +EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] class StateMachineError(Exception): """Base class for state machine errors""" -class StateEntryFailed(Exception): +class StateEntryFailed(Exception): # noqa: N818 """ Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg + def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -42,20 +43,16 @@ class InvalidStateError(Exception): class EventError(StateMachineError): - def __init__(self, evt: str, msg: str): super().__init__(msg) self.event = evt -class TransitionFailed(Exception): +class TransitionFailed(Exception): # noqa: N818 """A state transition failed""" def __init__( - self, - initial_state: 'State', - final_state: Optional['State'] = None, - traceback_str: Optional[str] = None + self, initial_state: 'State', final_state: Optional['State'] = None, traceback_str: Optional[str] = None ) -> None: self.initial_state = initial_state self.final_state = final_state @@ -71,7 +68,7 @@ def _format_msg(self) -> str: def event( from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', - to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*' + to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" if from_states != '*': @@ -102,8 +99,8 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: raise EventError(evt_label, 'Machine did not transition') raise EventError( - evt_label, 'Event produced invalid state transition from ' - f'{initial.LABEL} to {self._state.LABEL}' + evt_label, + 'Event produced invalid state transition from ' f'{initial.LABEL} to {self._state.LABEL}', ) return result @@ -126,7 +123,7 @@ class State: def is_terminal(cls) -> bool: return not cls.ALLOWED - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): # pylint: disable=unused-argument + def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): """ :param state_machine: The process this state belongs to """ @@ -138,12 +135,12 @@ def __str__(self) -> str: @property def label(self) -> LABEL_TYPE: - """ Convenience property to get the state label """ + """Convenience property to get the state label""" return self.LABEL @super_check def enter(self) -> None: - """ Entering the state """ + """Entering the state""" def execute(self) -> Optional['State']: """ @@ -153,7 +150,7 @@ def execute(self) -> Optional['State']: @super_check def exit(self) -> None: - """ Exiting the state """ + """Exiting the state""" if self.is_terminal(): raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') @@ -175,13 +172,13 @@ class StateEventHook(enum.Enum): procedure. The callback will be passed a state instance whose meaning will differ depending on the hook as commented below. """ + ENTERING_STATE: int = 0 # State passed will be the state that is being entered ENTERED_STATE: int = 1 # State passed will be the last state that we entered from EXITING_STATE: int = 2 # State passed will be the next state that will be entered (or None for terminal) class StateMachineMeta(type): - def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': """ Create the state machine and enter the initial state. @@ -220,13 +217,13 @@ def get_states(cls) -> Sequence[Type[State]]: def initial_state_label(cls) -> LABEL_TYPE: cls.__ensure_built() assert cls.STATES is not None - return cls.STATES[0].LABEL # pylint: disable=unsubscriptable-object + return cls.STATES[0].LABEL @classmethod def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None - return cls._STATES_MAP[label] # pylint: disable=unsubscriptable-object + return cls._STATES_MAP[label] @classmethod def __ensure_built(cls) -> None: @@ -238,15 +235,15 @@ def __ensure_built(cls) -> None: pass cls.STATES = cls.get_states() - assert isinstance(cls.STATES, Iterable) # pylint: disable=isinstance-second-argument-not-valid-type + assert isinstance(cls.STATES, Iterable) # Build the states map cls._STATES_MAP = {} - for state_cls in cls.STATES: # pylint: disable=not-an-iterable + for state_cls in cls.STATES: assert issubclass(state_cls, State) label = state_cls.LABEL - assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" # pylint: disable=unsupported-membership-test - cls._STATES_MAP[label] = state_cls # pylint: disable=unsupported-assignment-operation + assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" + cls._STATES_MAP[label] = state_cls # should class initialise sealed = False? cls.sealed = True # type: ignore @@ -301,11 +298,10 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: @super_check def on_terminated(self) -> None: - """ Called when a terminal state is entered """ + """Called when a terminal state is entered""" def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: - assert not self._transitioning, \ - 'Cannot call transition_to when already transitioning state' + assert not self._transitioning, 'Cannot call transition_to when already transitioning state' initial_state_label = self._state.LABEL if self._state is not None else None label = None @@ -331,7 +327,7 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A if self._state is not None and self._state.is_terminal(): call_with_super_check(self.on_terminated) - except Exception: # pylint: disable=broad-except + except Exception: self._transitioning = False if self._transition_failing: raise @@ -360,12 +356,12 @@ def set_debug(self, enabled: bool) -> None: def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: try: - return self.get_states_map()[state_label](self, *args, **kwargs) # pylint: disable=unsubscriptable-object + return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: raise ValueError(f'{state_label} is not a valid state') def _exit_current_state(self, next_state: State) -> None: - """ Exit the given state """ + """Exit the given state""" # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state @@ -401,6 +397,6 @@ def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State return state try: - return self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object + return self.get_states_map()[cast(Hashable, state)] except KeyError: raise ValueError(f'{state} is not a valid state') diff --git a/src/plumpy/base/utils.py b/src/plumpy/base/utils.py index 232c5d26..8c35b903 100644 --- a/src/plumpy/base/utils.py +++ b/src/plumpy/base/utils.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from typing import Any, Callable -__all__ = ['super_check', 'call_with_super_check'] +__all__ = ['call_with_super_check', 'super_check'] def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]: diff --git a/src/plumpy/communications.py b/src/plumpy/communications.py index 51dff60d..1d7e775b 100644 --- a/src/plumpy/communications.py +++ b/src/plumpy/communications.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Module for general kiwipy communication methods""" + import asyncio import functools from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional @@ -10,7 +11,12 @@ from .utils import ensure_coroutine __all__ = [ - 'Communicator', 'RemoteException', 'DeliveryFailed', 'TaskRejected', 'plum_to_kiwi_future', 'wrap_communicator' + 'Communicator', + 'DeliveryFailed', + 'RemoteException', + 'TaskRejected', + 'plum_to_kiwi_future', + 'wrap_communicator', ] RemoteException = kiwipy.RemoteException @@ -20,7 +26,7 @@ if TYPE_CHECKING: # identifiers for subscribers - ID_TYPE = Hashable # pylint: disable=invalid-name + ID_TYPE = Hashable Subscriber = Callable[..., Any] # RPC subscriber params: communicator, msg RpcSubscriber = Callable[[kiwipy.Communicator, Any], Any] @@ -55,8 +61,9 @@ def on_done(_plum_future: futures.Future) -> None: return kiwi_future -def convert_to_comm(callback: 'Subscriber', - loop: Optional[asyncio.AbstractEventLoop] = None) -> Callable[..., kiwipy.Future]: +def convert_to_comm( + callback: 'Subscriber', loop: Optional[asyncio.AbstractEventLoop] = None +) -> Callable[..., kiwipy.Future]: """ Take a callback function and converted it to one that will schedule a callback on the given even loop and return a kiwi future representing the future outcome @@ -67,7 +74,6 @@ def convert_to_comm(callback: 'Subscriber', :return: a new callback function that returns a future """ if isinstance(callback, kiwipy.BroadcastFilter): - # if the broadcast is filtered for this callback, # we don't want to go through the (costly) process # of setting up async tasks and callbacks @@ -75,16 +81,15 @@ def convert_to_comm(callback: 'Subscriber', def _passthrough(*args: Any, **kwargs: Any) -> bool: sender = kwargs.get('sender', args[1]) subject = kwargs.get('subject', args[2]) - return callback.is_filtered(sender, subject) # type: ignore[attr-defined] + return callback.is_filtered(sender, subject) else: - def _passthrough(*args: Any, **kwargs: Any) -> bool: # pylint: disable=unused-argument + def _passthrough(*args: Any, **kwargs: Any) -> bool: return False coro = ensure_coroutine(callback) def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> kiwipy.Future: - if _passthrough(*args, **kwargs): kiwi_future = kiwipy.Future() kiwi_future.set_result(None) @@ -170,7 +175,7 @@ def broadcast_send( body: Optional[Any], sender: Optional[str] = None, subject: Optional[str] = None, - correlation_id: Optional['ID_TYPE'] = None + correlation_id: Optional['ID_TYPE'] = None, ) -> futures.Future: return self._communicator.broadcast_send(body, sender, subject, correlation_id) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 3a342321..47ad4956 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -7,14 +7,13 @@ if TYPE_CHECKING: from typing import Set, Type - from .process_listener import ProcessListener # pylint: disable=cyclic-import + from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @persistence.auto_persist('_listeners', '_listener_type') class EventHelper(persistence.Savable): - def __init__(self, listener_type: 'Type[ProcessListener]'): assert listener_type is not None, 'Must provide valid listener type' @@ -50,5 +49,5 @@ def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: A for listener in list(self.listeners): try: getattr(listener, event_function.__name__)(*args, **kwargs) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) diff --git a/src/plumpy/events.py b/src/plumpy/events.py index 60a5306e..3de81987 100644 --- a/src/plumpy/events.py +++ b/src/plumpy/events.py @@ -1,18 +1,24 @@ # -*- coding: utf-8 -*- """Event and loop related classes and functions""" + import asyncio import sys from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence __all__ = [ - 'new_event_loop', 'set_event_loop', 'get_event_loop', 'run_until_complete', 'set_event_loop_policy', - 'reset_event_loop_policy', 'PlumpyEventLoopPolicy' + 'PlumpyEventLoopPolicy', + 'get_event_loop', + 'new_event_loop', + 'reset_event_loop_policy', + 'run_until_complete', + 'set_event_loop', + 'set_event_loop_policy', ] if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process -get_event_loop = asyncio.get_event_loop # pylint: disable=invalid-name +get_event_loop = asyncio.get_event_loop def set_event_loop(*args: Any, **kwargs: Any) -> None: @@ -51,12 +57,10 @@ def reset_event_loop_policy() -> None: """Reset the event loop policy to the default.""" loop = get_event_loop() - # pylint: disable=protected-access cls = loop.__class__ del cls._check_running # type: ignore del cls._nest_patched # type: ignore - # pylint: enable=protected-access asyncio.set_event_loop_policy(None) @@ -69,7 +73,7 @@ def run_until_complete(future: asyncio.Future, loop: Optional[asyncio.AbstractEv class ProcessCallback: """Object returned by callback registration methods.""" - __slots__ = ('_callback', '_args', '_kwargs', '_process', '_cancelled', '__weakref__') + __slots__ = ('__weakref__', '_args', '_callback', '_cancelled', '_kwargs', '_process') def __init__( self, process: 'Process', callback: Callable[..., Any], args: Sequence[Any], kwargs: Dict[str, Any] @@ -93,7 +97,7 @@ async def run(self) -> None: if not self._cancelled: try: await self._callback(*self._args, **self._kwargs) - except Exception: # pylint: disable=broad-except + except Exception: exc_info = sys.exc_info() self._process.callback_excepted(self._callback, exc_info[1], exc_info[2]) finally: diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 40d3e12d..70b5aa2d 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from typing import Optional -__all__ = ['KilledError', 'UnsuccessfulResult', 'InvalidStateError', 'PersistenceError', 'ClosedError'] +__all__ = ['ClosedError', 'InvalidStateError', 'KilledError', 'PersistenceError', 'UnsuccessfulResult'] class KilledError(Exception): diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 365b8008..161244cd 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -2,12 +2,13 @@ """ Module containing future related methods and classes """ + import asyncio from typing import Any, Callable, Coroutine, Optional import kiwipy -__all__ = ['Future', 'gather', 'chain', 'copy_future', 'CancelledError', 'create_task'] +__all__ = ['CancelledError', 'Future', 'chain', 'copy_future', 'create_task', 'gather'] CancelledError = kiwipy.CancelledError @@ -16,11 +17,11 @@ class InvalidStateError(Exception): """Exception for when a future or action is in an invalid state""" -copy_future = kiwipy.copy_future # pylint: disable=invalid-name -chain = kiwipy.chain # pylint: disable=invalid-name -gather = asyncio.gather # pylint: disable=invalid-name +copy_future = kiwipy.copy_future +chain = kiwipy.chain +gather = asyncio.gather -Future = asyncio.Future # pylint: disable=invalid-name +Future = asyncio.Future class CancellableAction(Future): @@ -35,7 +36,7 @@ def __init__(self, action: Callable[..., Any], cookie: Any = None): @property def cookie(self) -> Any: - """ A cookie that can be used to correlate the actions with something """ + """A cookie that can be used to correlate the actions with something""" return self._cookie def run(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/plumpy/lang.py b/src/plumpy/lang.py index 6d9290af..450927d6 100644 --- a/src/plumpy/lang.py +++ b/src/plumpy/lang.py @@ -2,13 +2,13 @@ """ Python language utilities and tools. """ + import functools import inspect from typing import Any, Callable def protected(check: bool = False) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - def wrap(func: Callable[..., Any]) -> Callable[..., Any]: if isinstance(func, property): raise RuntimeError('Protected must go after @property decorator') @@ -31,7 +31,7 @@ def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Callable[..., Any]: return func(self, *args, **kwargs) else: - wrapped_fn = func + wrapped_fn = func # type: ignore[assignment] return wrapped_fn @@ -60,15 +60,14 @@ def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Callable[..., Any]: return func(self, *args, **kwargs) else: - wrapped_fn = func + wrapped_fn = func # type: ignore[assignment] return wrapped_fn return wrap -class __NULL: # pylint: disable=invalid-name - +class __NULL: # noqa: N801 def __eq__(self, other: Any) -> bool: return isinstance(other, self.__class__) diff --git a/src/plumpy/loaders.py b/src/plumpy/loaders.py index 59f33f64..a01f9b60 100644 --- a/src/plumpy/loaders.py +++ b/src/plumpy/loaders.py @@ -3,7 +3,7 @@ import importlib from typing import Any, Optional -__all__ = ['ObjectLoader', 'DefaultObjectLoader', 'set_object_loader', 'get_object_loader'] +__all__ = ['DefaultObjectLoader', 'ObjectLoader', 'get_object_loader', 'set_object_loader'] class ObjectLoader(metaclass=abc.ABCMeta): @@ -74,7 +74,7 @@ def get_object_loader() -> ObjectLoader: :return: A class loader :rtype: :class:`ObjectLoader` """ - global OBJECT_LOADER + global OBJECT_LOADER # noqa: PLW0603 if OBJECT_LOADER is None: OBJECT_LOADER = DefaultObjectLoader() return OBJECT_LOADER @@ -88,5 +88,5 @@ def set_object_loader(loader: Optional[ObjectLoader]) -> None: :type loader: :class:`ObjectLoader` :return: """ - global OBJECT_LOADER + global OBJECT_LOADER # noqa: PLW0603 OBJECT_LOADER = loader diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index a8dcca1e..10142eb7 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -12,6 +12,7 @@ class ContextMixin(persistence.Savable): Add a context to a Process. The contents of the context will be saved in the instance state unlike standard instance variables. """ + CONTEXT: str = '_context' def __init__(self, *args: Any, **kwargs: Any): diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 7a15b1cc..ba755bc5 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -18,18 +18,24 @@ from .utils import PID_TYPE, SAVED_STATE_TYPE __all__ = [ - 'Bundle', 'Persister', 'PicklePersister', 'auto_persist', 'Savable', 'SavableFuture', 'LoadSaveContext', - 'PersistedCheckpoint', 'InMemoryPersister' + 'Bundle', + 'InMemoryPersister', + 'LoadSaveContext', + 'PersistedCheckpoint', + 'Persister', + 'PicklePersister', + 'Savable', + 'SavableFuture', + 'auto_persist', ] PersistedCheckpoint = collections.namedtuple('PersistedCheckpoint', ['pid', 'tag']) if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process class Bundle(dict): - def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False): """ Create a bundle from a savable. Optionally keep information about the @@ -77,7 +83,6 @@ def _bundle_constructor(loader: yaml.Loader, data: Any) -> Generator[Bundle, Non class Persister(metaclass=abc.ABCMeta): - @abc.abstractmethod def save_checkpoint(self, process: 'Process', tag: Optional[str] = None) -> None: """ @@ -301,7 +306,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: class InMemoryPersister(Persister): - """ Mainly to be used in testing/debugging """ + """Mainly to be used in testing/debugging""" def __init__(self, loader: Optional[loaders.ObjectLoader] = None) -> None: super().__init__() @@ -340,13 +345,11 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='Type[Savable]') # type: ignore[name-defined] # pylint: disable=invalid-name +SavableClsType = TypeVar('SavableClsType', bound='type[Savable]') def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - # pylint: disable=protected-access if savable._auto_persist is None: savable._auto_persist = set() else: @@ -390,7 +393,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: self._values = dict(**kwargs) self.loader = loader @@ -408,7 +410,7 @@ def __contains__(self, item: Any) -> bool: return self._values.__contains__(item) def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': - """ Add additional information to the context by making a copy with the new values """ + """Add additional information to the context by making a copy with the new values""" extended = self._values.copy() extended.update(kwargs) loader = extended.pop('loader', self.loader) @@ -485,7 +487,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio self.load_members(self._auto_persist, saved_state, load_context) @super_check - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: # pylint: disable=unused-argument + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() if self._auto_persist is not None: self.save_members(self._auto_persist, out_state) @@ -527,10 +529,7 @@ def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> N out_state[member] = value def load_members( - self, - members: Iterable[str], - saved_state: SAVED_STATE_TYPE, - load_context: Optional[LoadSaveContext] = None + self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None ) -> None: for member in members: setattr(self, member, self._get_value(saved_state, member, load_context)) @@ -580,8 +579,9 @@ def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: # endregion - def _get_value(self, saved_state: SAVED_STATE_TYPE, name: str, - load_context: Optional[LoadSaveContext]) -> Union[MethodType, 'Savable']: + def _get_value( + self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] + ) -> Union[MethodType, 'Savable']: value = saved_state[name] typ = Savable._get_meta_type(saved_state, name) @@ -626,10 +626,10 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa state = saved_state['_state'] - if state == asyncio.futures._PENDING: # type: ignore # pylint: disable=protected-access + if state == asyncio.futures._PENDING: # type: ignore obj = cls(loop=loop) - if state == asyncio.futures._FINISHED: # type: ignore # pylint: disable=protected-access + if state == asyncio.futures._FINISHED: # type: ignore obj = cls(loop=loop) result = saved_state['_result'] @@ -639,14 +639,13 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa except KeyError: obj.set_result(result) - if state == asyncio.futures._CANCELLED: # type: ignore # pylint: disable=protected-access + if state == asyncio.futures._CANCELLED: # type: ignore obj = cls(loop=loop) obj.cancel() return obj def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - # pylint: disable=attribute-defined-outside-init super().load_instance_state(saved_state, load_context) if self._callbacks: # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index fc5f138f..cfbd92d5 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -1,16 +1,17 @@ # -*- coding: utf-8 -*- """Module for process ports""" + import collections import copy import inspect import json import logging -from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Type, Union, cast import warnings +from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Type, Union, cast from plumpy.utils import AttributesFrozendict, is_mutable_property, type_check -__all__ = ['UNSPECIFIED', 'PortValidationError', 'PortNamespace', 'Port', 'InputPort', 'OutputPort'] +__all__ = ['UNSPECIFIED', 'InputPort', 'OutputPort', 'Port', 'PortNamespace', 'PortValidationError'] _LOGGER = logging.getLogger(__name__) UNSPECIFIED = () @@ -19,7 +20,7 @@ This has been deprecated and the new signature is `validator(value, port)` where the `port` argument will be the port instance to which the validator has been assigned.""" -VALIDATOR_TYPE = Callable[[Any, 'Port'], Optional[str]] # pylint: disable=invalid-name +VALIDATOR_TYPE = Callable[[Any, 'Port'], Optional[str]] class PortValidationError(Exception): @@ -66,9 +67,9 @@ def __init__( self, name: str, valid_type: Optional[Type[Any]] = None, - help: Optional[str] = None, # pylint: disable=redefined-builtin + help: Optional[str] = None, required: bool = True, - validator: Optional[VALIDATOR_TYPE] = None + validator: Optional[VALIDATOR_TYPE] = None, ) -> None: self._name = name self._valid_type = valid_type @@ -134,7 +135,7 @@ def help(self) -> Optional[str]: return self._help @help.setter - def help(self, help: Optional[str]) -> None: # pylint: disable=redefined-builtin + def help(self, help: Optional[str]) -> None: """Set the help string for this port :param help: the help string @@ -198,9 +199,9 @@ def validate(self, value: Any, breadcrumbs: Sequence[str] = ()) -> Optional[Port spec = inspect.getfullargspec(self.validator) if len(spec[0]) == 1: warnings.warn(VALIDATOR_SIGNATURE_DEPRECATION_WARNING.format(self.validator.__name__)) - result = self.validator(value) # type: ignore # pylint: disable=not-callable + result = self.validator(value) # type: ignore else: - result = self.validator(value, self) # pylint: disable=not-callable + result = self.validator(value, self) if result is not None: assert isinstance(result, str), 'Validator returned non string type' validation_error = result @@ -233,17 +234,17 @@ def __init__( self, name: str, valid_type: Optional[Type[Any]] = None, - help: Optional[str] = None, # pylint: disable=redefined-builtin + help: Optional[str] = None, default: Any = UNSPECIFIED, required: bool = True, - validator: Optional[VALIDATOR_TYPE] = None - ) -> None: # pylint: disable=too-many-arguments + validator: Optional[VALIDATOR_TYPE] = None, + ) -> None: super().__init__( name, valid_type=valid_type, help=help, required=InputPort.required_override(required, default), - validator=validator + validator=validator, ) if required is not InputPort.required_override(required, default): @@ -252,7 +253,6 @@ def __init__( ) if default is not UNSPECIFIED: - # Only validate the default value if it is not a callable. If it is a callable its return value will always # be validated when the port is validated upon process construction, if the default is was actually used. if not callable(default): @@ -304,14 +304,14 @@ class PortNamespace(collections.abc.MutableMapping, Port): def __init__( self, name: str = '', # Note this was set to None, but that would fail if you tried to compute breadcrumbs - help: Optional[str] = None, # pylint: disable=redefined-builtin + help: Optional[str] = None, required: bool = True, validator: Optional[VALIDATOR_TYPE] = None, valid_type: Optional[Type[Any]] = None, default: Any = UNSPECIFIED, dynamic: bool = False, - populate_defaults: bool = True - ) -> None: # pylint: disable=too-many-arguments + populate_defaults: bool = True, + ) -> None: """Construct a port namespace. :param name: the name of the namespace @@ -396,7 +396,7 @@ def valid_type(self, valid_type: Optional[Type[Any]]) -> None: if valid_type is not None: self.dynamic = True - super(PortNamespace, self.__class__).valid_type.fset(self, valid_type) # type: ignore # pylint: disable=no-member + super(PortNamespace, self.__class__).valid_type.fset(self, valid_type) # type: ignore @property def populate_defaults(self) -> bool: @@ -459,7 +459,7 @@ def get_port(self, name: str, create_dynamically: bool = False) -> Union[Port, ' valid_type=self.valid_type, default=self.default, dynamic=self.dynamic, - populate_defaults=self.populate_defaults + populate_defaults=self.populate_defaults, ) if namespace: @@ -495,7 +495,6 @@ def create_port_namespace(self, name: str, **kwargs: Any) -> 'PortNamespace': # If this is True, the (sub) port namespace does not yet exist, so we create it if port_name not in self: - # If there still is a `namespace`, we create a sub namespace, *without* the constructor arguments if namespace: self[port_name] = self.__class__(port_name) @@ -515,7 +514,7 @@ def absorb( port_namespace: 'PortNamespace', exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[Dict[str, Any]] = None + namespace_options: Optional[Dict[str, Any]] = None, ) -> List[str]: """Absorb another PortNamespace instance into oneself, including all its mutable properties and ports. @@ -531,7 +530,7 @@ def absorb( :param namespace_options: a dictionary with mutable PortNamespace property values to override :return: list of the names of the ports that were absorbed """ - # pylint: disable=too-many-branches + if not isinstance(port_namespace, PortNamespace): raise ValueError('port_namespace has to be an instance of PortNamespace') @@ -559,14 +558,12 @@ def absorb( absorbed_ports = [] for port_name, port in port_namespace.items(): - # If the current port name occurs in the exclude list, simply skip it entirely, there is no need to consider # any of the nested ports it might have, even if it is a port namespace if exclude and port_name in exclude: continue if isinstance(port, PortNamespace): - # If the name does not appear at the start of any of the include rules we continue: if include and not any(rule.startswith(port_name) for rule in include): continue @@ -580,7 +577,7 @@ def absorb( # absorb call that will properly consider the include and exclude rules self[port_name] = copy.copy(port) portnamespace = cast(PortNamespace, self[port_name]) - portnamespace._ports = {} # pylint: disable=protected-access + portnamespace._ports = {} portnamespace.absorb(port, sub_exclude, sub_include) else: # If include rules are specified but the port name does not appear, simply skip it @@ -615,10 +612,8 @@ def project(self, port_values: MutableMapping[str, Any]) -> MutableMapping[str, return result - def validate( # pylint: disable=arguments-differ - self, - port_values: Optional[Mapping[str, Any]] = None, - breadcrumbs: Sequence[str] = () + def validate( + self, port_values: Optional[Mapping[str, Any]] = None, breadcrumbs: Sequence[str] = () ) -> Optional[PortValidationError]: """ Validate the namespace port itself and subsequently all the port_values it contains @@ -627,7 +622,7 @@ def validate( # pylint: disable=arguments-differ :param breadcrumbs: a tuple of the path to having reached this point in validation :return: None or tuple containing 0: error string 1: tuple of breadcrumb strings to where the validation failed """ - # pylint: disable=arguments-renamed + breadcrumbs_local = (*breadcrumbs, self.name) message: Optional[str] @@ -665,12 +660,13 @@ def validate( # pylint: disable=arguments-differ spec = inspect.getfullargspec(self.validator) if len(spec[0]) == 1: warnings.warn(VALIDATOR_SIGNATURE_DEPRECATION_WARNING.format(self.validator.__name__)) - message = self.validator(port_values_clone) # type: ignore # pylint: disable=not-callable + message = self.validator(port_values_clone) # type: ignore else: - message = self.validator(port_values_clone, self) # pylint: disable=not-callable + message = self.validator(port_values_clone, self) if message is not None: - assert isinstance(message, str), \ - f"Validator returned something other than None or str: '{type(message)}'" + assert isinstance( + message, str + ), f"Validator returned something other than None or str: '{type(message)}'" return PortValidationError(message, breadcrumbs_to_port(breadcrumbs_local)) return None @@ -682,14 +678,12 @@ def pre_process(self, port_values: MutableMapping[str, Any]) -> AttributesFrozen :return: an AttributesFrozenDict with pre-processed port value mapping, complemented with port default values """ for name, port in self.items(): - # If the port was not specified in the inputs values and the port is a namespace with the property # `populate_defaults=False`, we skip the pre-processing and do not populate defaults. if name not in port_values and isinstance(port, PortNamespace) and not port.populate_defaults: continue if name not in port_values: - if port.has_default(): default = port.default if callable(default): @@ -712,8 +706,9 @@ def pre_process(self, port_values: MutableMapping[str, Any]) -> AttributesFrozen return AttributesFrozendict(port_values) - def validate_ports(self, port_values: MutableMapping[str, Any], - breadcrumbs: Sequence[str]) -> Optional[PortValidationError]: + def validate_ports( + self, port_values: MutableMapping[str, Any], breadcrumbs: Sequence[str] + ) -> Optional[PortValidationError]: """ Validate port values with respect to the explicitly defined ports of the port namespace. Ports values that are matched to an actual Port will be popped from the dictionary @@ -791,7 +786,7 @@ def strip_namespace(namespace: str, separator: str, rules: Optional[Sequence[str for rule in rules: if rule.startswith(prefix): - stripped.append(rule[len(prefix):]) + stripped.append(rule[len(prefix) :]) return stripped diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index c66e8431..293c680b 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" + import asyncio import copy import logging @@ -11,19 +12,19 @@ from .utils import PID_TYPE __all__ = [ + 'KILL_MSG', 'PAUSE_MSG', 'PLAY_MSG', - 'KILL_MSG', 'STATUS_MSG', 'ProcessLauncher', + 'RemoteProcessController', + 'RemoteProcessThreadController', 'create_continue_body', 'create_launch_body', - 'RemoteProcessThreadController', - 'RemoteProcessController', ] if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process ProcessResult = Any ProcessStatus = Any @@ -34,7 +35,7 @@ class Intent: """Intent constants for a process message""" - # pylint: disable=too-few-public-methods + PLAY: str = 'play' PAUSE: str = 'pause' KILL: str = 'kill' @@ -71,7 +72,7 @@ def create_launch_body( init_kwargs: Optional[Dict[str, Any]] = None, persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, - nowait: bool = True + nowait: bool = True, ) -> Dict[str, Any]: """ Create a message body for the launch action @@ -95,8 +96,8 @@ def create_launch_body( PERSIST_KEY: persist, NOWAIT_KEY: nowait, ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs - } + KWARGS_KEY: init_kwargs, + }, } return msg_body @@ -119,7 +120,7 @@ def create_create_body( init_args: Optional[Sequence[Any]] = None, init_kwargs: Optional[Dict[str, Any]] = None, persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None + loader: Optional[loaders.ObjectLoader] = None, ) -> Dict[str, Any]: """ Create a message body to create a new process @@ -140,8 +141,8 @@ def create_create_body( PROCESS_CLASS_KEY: loader.identify_object(process_class), PERSIST_KEY: persist, ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs - } + KWARGS_KEY: init_kwargs, + }, } return msg_body @@ -216,11 +217,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro return result async def continue_process( - self, - pid: 'PID_TYPE', - tag: Optional[str] = None, - nowait: bool = False, - no_reply: bool = False + self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Optional['ProcessResult']: """ Continue the process @@ -249,7 +246,7 @@ async def launch_process( persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> 'ProcessResult': """ Launch a process given the class and constructor arguments @@ -263,7 +260,7 @@ async def launch_process( :param no_reply: if True, this call will be fire-and-forget, i.e. no return value :return: the result of launching the process """ - # pylint: disable=too-many-arguments + message = create_launch_body(process_class, init_args, init_kwargs, persist, loader, nowait) launch_future = self._communicator.task_send(message, no_reply=no_reply) future = await asyncio.wrap_future(launch_future) @@ -281,7 +278,7 @@ async def execute_process( init_kwargs: Optional[Dict[str, Any]] = None, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> 'ProcessResult': """ Execute a process. This call will first send a create task and then a continue task over @@ -296,7 +293,7 @@ async def execute_process( :param no_reply: if True, this call will be fire-and-forget, i.e. no return value :return: the result of executing the process """ - # pylint: disable=too-many-arguments + message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader) create_future = self._communicator.task_send(message) @@ -399,11 +396,7 @@ def kill_all(self, msg: Optional[Any]) -> None: self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( - self, - pid: 'PID_TYPE', - tag: Optional[str] = None, - nowait: bool = False, - no_reply: bool = False + self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Union[None, PID_TYPE, ProcessResult]: message = create_continue_body(pid=pid, tag=tag, nowait=nowait) return self.task_send(message, no_reply=no_reply) @@ -416,9 +409,8 @@ def launch_process( persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> Union[None, PID_TYPE, ProcessResult]: - # pylint: disable=too-many-arguments """ Launch the process @@ -441,7 +433,7 @@ def execute_process( init_kwargs: Optional[Dict[str, Any]] = None, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> Union[None, PID_TYPE, ProcessResult]: """ Execute a process. This call will first send a create task and then a continue task over @@ -456,7 +448,7 @@ def execute_process( :param no_reply: if True, this call will be fire-and-forget, i.e. no return value :return: the result of executing the process """ - # pylint: disable=too-many-arguments + message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader) execute_future = kiwipy.Future() @@ -512,7 +504,7 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, persister: Optional[persistence.Persister] = None, load_context: Optional[persistence.LoadSaveContext] = None, - loader: Optional[loaders.ObjectLoader] = None + loader: Optional[loaders.ObjectLoader] = None, ) -> None: self._loop = loop self._persister = persister @@ -573,7 +565,8 @@ async def _launch( self._persister.save_checkpoint(proc) if nowait: - asyncio.ensure_future(proc.step_until_terminated()) + # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails + asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 return proc.pid await proc.step_until_terminated() @@ -581,11 +574,7 @@ async def _launch( return proc.future().result() async def _continue( - self, - _communicator: kiwipy.Communicator, - pid: 'PID_TYPE', - nowait: bool, - tag: Optional[str] = None + self, _communicator: kiwipy.Communicator, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None ) -> Union[PID_TYPE, ProcessResult]: """ Continue the process @@ -604,7 +593,8 @@ async def _continue( proc = cast('Process', saved_state.unbundle(self._load_context)) if nowait: - asyncio.ensure_future(proc.step_until_terminated()) + # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails + asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 return proc.pid await proc.step_until_terminated() diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 110394a2..8e1acf94 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -8,12 +8,11 @@ __all__ = ['ProcessListener'] if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process @persistence.auto_persist('_params') class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): - # region Persistence methods def __init__(self) -> None: diff --git a/src/plumpy/process_spec.py b/src/plumpy/process_spec.py index c82d59ee..00f2f3cc 100644 --- a/src/plumpy/process_spec.py +++ b/src/plumpy/process_spec.py @@ -7,9 +7,9 @@ from .ports import InputPort, OutputPort, Port, PortNamespace if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process -EXPOSED_TYPE = Dict[Optional[str], Dict[Type['Process'], Sequence[str]]] # pylint: disable=invalid-name +EXPOSED_TYPE = Dict[Optional[str], Dict[Type['Process'], Sequence[str]]] class ProcessSpec: @@ -22,6 +22,7 @@ class ProcessSpec: Every Process class has one of these. """ + NAME_INPUTS_PORT_NAMESPACE: str = 'inputs' NAME_OUTPUTS_PORT_NAMESPACE: str = 'outputs' PORT_NAMESPACE_TYPE = PortNamespace @@ -184,7 +185,7 @@ def expose_inputs( namespace: Optional[str] = None, exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[dict] = None + namespace_options: Optional[dict] = None, ) -> None: """ This method allows one to automatically add the inputs from another Process to this ProcessSpec. @@ -215,7 +216,7 @@ def expose_outputs( namespace: Optional[str] = None, exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[dict] = None + namespace_options: Optional[dict] = None, ) -> None: """ This method allows one to automatically add the ouputs from another Process to this ProcessSpec. @@ -249,8 +250,8 @@ def _expose_ports( namespace: Optional[str], exclude: Optional[Sequence[str]], include: Optional[Sequence[str]], - namespace_options: Optional[dict] = None - ) -> None: # pylint: disable=too-many-arguments + namespace_options: Optional[dict] = None, + ) -> None: """ Expose ports from a source PortNamespace of the ProcessSpec of a Process class into the destination PortNamespace of this ProcessSpec. If the namespace is specified, the ports will be exposed in that sub diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 3407412d..7ae6e9bd 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from enum import Enum import sys import traceback +from enum import Enum from types import TracebackType from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast @@ -22,28 +22,28 @@ from .utils import SAVED_STATE_TYPE __all__ = [ - 'ProcessState', + 'Continue', 'Created', - 'Running', - 'Waiting', - 'Finished', 'Excepted', - 'Killed', + 'Finished', + 'Interruption', # Commands 'Kill', - 'Stop', - 'Wait', - 'Continue', - 'Interruption', 'KillInterruption', + 'Killed', 'PauseInterruption', + 'ProcessState', + 'Running', + 'Stop', + 'Wait', + 'Waiting', ] if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process -class Interruption(Exception): +class Interruption(Exception): # noqa: N818 pass @@ -64,7 +64,6 @@ class Command(persistence.Savable): @auto_persist('msg') class Kill(Command): - def __init__(self, msg: Optional[Any] = None): super().__init__() self.msg = msg @@ -76,7 +75,6 @@ class Pause(Command): @auto_persist('msg', 'data') class Wait(Command): - def __init__( self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None ): @@ -88,7 +86,6 @@ def __init__( @auto_persist('result') class Stop(Command): - def __init__(self, result: Any, successful: bool) -> None: super().__init__() self.result = result @@ -127,6 +124,7 @@ class ProcessState(Enum): """ The possible states that a :class:`~plumpy.processes.Process` can be in. """ + CREATED: str = 'created' RUNNING: str = 'running' WAITING: str = 'waiting' @@ -137,7 +135,6 @@ class ProcessState(Enum): @auto_persist('in_state') class State(state_machine.State, persistence.Savable): - @property def process(self) -> state_machine.StateMachine: """ @@ -149,7 +146,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process - def interrupt(self, reason: Any) -> None: # pylint: disable=unused-argument + def interrupt(self, reason: Any) -> None: pass @@ -183,7 +180,11 @@ def execute(self) -> state_machine.State: class Running(State): LABEL = ProcessState.RUNNING ALLOWED = { - ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED + ProcessState.RUNNING, + ProcessState.WAITING, + ProcessState.FINISHED, + ProcessState.KILLED, + ProcessState.EXCEPTED, } RUN_FN = 'run_fn' # The key used to store the function to run @@ -217,7 +218,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> State: # type: ignore # pylint: disable=invalid-overridden-method + async def execute(self) -> State: # type: ignore if self._command is not None: command = self._command else: @@ -230,7 +231,7 @@ async def execute(self) -> State: # type: ignore # pylint: disable=invalid-over except Interruption: # Let this bubble up to the caller raise - except Exception: # pylint: disable=broad-except + except Exception: excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) return cast(State, excepted) else: @@ -267,7 +268,11 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: class Waiting(State): LABEL = ProcessState.WAITING ALLOWED = { - ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, ProcessState.EXCEPTED, ProcessState.FINISHED + ProcessState.RUNNING, + ProcessState.WAITING, + ProcessState.KILLED, + ProcessState.EXCEPTED, + ProcessState.FINISHED, } DONE_CALLBACK = 'DONE_CALLBACK' @@ -285,7 +290,7 @@ def __init__( process: 'Process', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - data: Optional[Any] = None + data: Optional[Any] = None, ) -> None: super().__init__(process) self.done_callback = done_callback @@ -311,7 +316,7 @@ def interrupt(self, reason: Any) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> State: # type: ignore # pylint: disable=invalid-overridden-method + async def execute(self) -> State: # type: ignore try: result = await self._waiting_future except Interruption: @@ -370,9 +375,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = \ - tblib.Traceback.from_string(saved_state[self.TRACEBACK], - strict=False) + self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False) except KeyError: self.traceback = None else: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index b6e14ad9..ba7967d3 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The main Process module""" + import abc import asyncio import contextlib @@ -10,6 +11,8 @@ import re import sys import time +import uuid +import warnings from types import TracebackType from typing import ( Any, @@ -26,17 +29,15 @@ Union, cast, ) -import uuid -import warnings try: from aiocontextvars import ContextVar except ModuleNotFoundError: from contextvars import ContextVar -from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed import kiwipy import yaml +from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils from .base import state_machine @@ -47,9 +48,7 @@ from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected -# pylint: disable=too-many-lines - -__all__ = ['Process', 'ProcessSpec', 'BundleKeys', 'TransitionFailed'] +__all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) PROCESS_STACK = ContextVar('process stack', default=[]) @@ -62,7 +61,7 @@ class BundleKeys: See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`. """ - # pylint: disable=too-few-public-methods + INPUTS_RAW = 'INPUTS_RAW' INPUTS_PARSED = 'INPUTS_PARSED' OUTPUTS = 'OUTPUTS' @@ -75,7 +74,7 @@ class ProcessStateMachineMeta(abc.ABCMeta, state_machine.StateMachineMeta): # Make ProcessStateMachineMeta instances (classes) YAML - able yaml.representer.Representer.add_representer( ProcessStateMachineMeta, - yaml.representer.Representer.represent_name # type: ignore[arg-type] + yaml.representer.Representer.represent_name, # type: ignore[arg-type] ) @@ -84,7 +83,6 @@ def ensure_not_closed(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - # pylint: disable=protected-access if self._closed: raise exceptions.ClosedError('Process is closed') return func(self, *args, **kwargs) @@ -133,8 +131,6 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe executed. """ - # pylint: disable=too-many-instance-attributes,too-many-public-methods - # Static class stuff ###################### _spec_class = ProcessSpec # Default placeholders, will be populated in init() @@ -167,7 +163,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: state_classes = cls.get_state_classes() return ( state_classes[process_states.ProcessState.CREATED], - *[state for state in state_classes.values() if state.LABEL != process_states.ProcessState.CREATED] + *[state for state in state_classes.values() if state.LABEL != process_states.ProcessState.CREATED], ) @classmethod @@ -179,7 +175,7 @@ def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: process_states.ProcessState.WAITING: process_states.Waiting, process_states.ProcessState.FINISHED: process_states.Finished, process_states.ProcessState.EXCEPTED: process_states.Excepted, - process_states.ProcessState.KILLED: process_states.Killed + process_states.ProcessState.KILLED: process_states.Killed, } @classmethod @@ -256,7 +252,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None + communicator: Optional[kiwipy.Communicator] = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. @@ -278,8 +274,9 @@ def __init__( self._setup_event_hooks() self._status: Optional[str] = None # May hold a current status message - self._pre_paused_status: Optional[ - str] = None # Save status when a pause message replaces it, such that it can be restored + self._pre_paused_status: Optional[str] = ( + None # Save status when a pause message replaces it, such that it can be restored + ) self._paused = None # Input/output @@ -331,12 +328,13 @@ def try_killing(future: futures.Future) -> None: def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { - state_machine.StateEventHook.ENTERING_STATE: - lambda _s, _h, state: self.on_entering(cast(process_states.State, state)), - state_machine.StateEventHook.ENTERED_STATE: - lambda _s, _h, from_state: self.on_entered(cast(Optional[process_states.State], from_state)), - state_machine.StateEventHook.EXITING_STATE: - lambda _s, _h, _state: self.on_exiting() + state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( + cast(process_states.State, state) + ), + state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( + cast(Optional[process_states.State], from_state) + ), + state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } for hook, callback in event_hooks.items(): self.add_state_event_callback(hook, callback) @@ -356,7 +354,7 @@ def pid(self) -> Optional[PID_TYPE]: @property def uuid(self) -> Optional[uuid.UUID]: - """Return the UUID of the process """ + """Return the UUID of the process""" return self._uuid @property @@ -421,7 +419,7 @@ def launch( process_class: Type['Process'], inputs: Optional[dict] = None, pid: Optional[PID_TYPE] = None, - logger: Optional[logging.Logger] = None + logger: Optional[logging.Logger] = None, ) -> 'Process': """Start running the nested process. @@ -507,7 +505,7 @@ def done(self) -> bool: .. deprecated:: 0.18.6 Use the `has_terminated` method instead """ - warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) # pylint: disable=no-member + warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) return self._state.is_terminal() # endregion @@ -663,7 +661,7 @@ def add_process_listener(self, listener: ProcessListener) -> None: the specific state condition. """ - assert (listener != self), 'Cannot listen to yourself!' # type: ignore + assert listener != self, 'Cannot listen to yourself!' # type: ignore self._event_helper.add_listener(listener) def remove_process_listener(self, listener: ProcessListener) -> None: @@ -886,7 +884,7 @@ def on_close(self) -> None: for cleanup in self._cleanups or []: try: cleanup() - except Exception: # pylint: disable=broad-except + except Exception: self.logger.exception('Process<%s>: Exception calling cleanup method %s', self.pid, cleanup) self._cleanups = None finally: @@ -926,15 +924,16 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An # Didn't match any known intents raise RuntimeError('Unknown intent') - def broadcast_receive(self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, - correlation_id: Any) -> Optional[kiwipy.Future]: + def broadcast_receive( + self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any + ) -> Optional[kiwipy.Future]: """ Coroutine called when the process receives a message from the communicator :param _comm: the communicator that sent the message :param msg: the message """ - # pylint: disable=unused-argument + self.logger.debug( "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body ) @@ -1044,7 +1043,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: - """ Carry out the pause procedure, optionally transitioning to the next state first""" + """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: self.transition_to(next_state) @@ -1091,7 +1090,7 @@ def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) self._interrupt_action = new_action def _set_interrupt_action_from_exception(self, interrupt_exception: process_states.Interruption) -> None: - """ Set an interrupt action from the corresponding interrupt exception """ + """Set an interrupt action from the corresponding interrupt exception""" action = self._create_interrupt_action(interrupt_exception) self._set_interrupt_action(action) @@ -1233,9 +1232,9 @@ async def step(self) -> None: else: self._set_interrupt_action_from_exception(exception) - except KeyboardInterrupt: # pylint: disable=try-except-raise + except KeyboardInterrupt: raise - except Exception: # pylint: disable=broad-except + except Exception: # Overwrite the next state to go to excepted directly next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) self._set_interrupt_action(None) @@ -1285,7 +1284,7 @@ def out(self, output_port: str, value: Any) -> None: if namespace: port_namespace = cast( ports.PortNamespace, - self.spec().outputs.get_port(namespace_separator.join(namespace), create_dynamically=True) + self.spec().outputs.get_port(namespace_separator.join(namespace), create_dynamically=True), ) else: port_namespace = self.spec().outputs @@ -1341,9 +1340,11 @@ def get_status_info(self, out_status_info: dict) -> None: :param out_status_info: the old status """ - out_status_info.update({ - 'ctime': self.creation_time, - 'paused': self.paused, - 'process_string': str(self), - 'state': str(self.state), - }) + out_status_info.update( + { + 'ctime': self.creation_time, + 'paused': self.paused, + 'process_string': str(self), + 'state': str(self.state), + } + ) diff --git a/src/plumpy/settings.py b/src/plumpy/settings.py index 8a136dea..e863311c 100644 --- a/src/plumpy/settings.py +++ b/src/plumpy/settings.py @@ -1,4 +1,3 @@ # -*- coding: utf-8 -*- -# pylint: disable=invalid-name check_protected: bool = False check_override: bool = False diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index 4eab8efe..36d76bbd 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -1,30 +1,36 @@ # -*- coding: utf-8 -*- import asyncio -from collections import deque -from collections.abc import Mapping import functools import importlib import inspect import logging import types -from typing import Set # pylint: disable=unused-import -from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, List, MutableMapping, Optional, Tuple, Type +from collections import deque +from collections.abc import Mapping +from typing import ( + Any, + Callable, + Hashable, + Iterator, + List, + MutableMapping, + Optional, + Tuple, + Type, +) from . import lang from .settings import check_override, check_protected -if TYPE_CHECKING: - from .process_listener import ProcessListener # pylint: disable=cyclic-import - __all__ = ['AttributesDict'] -protected = lang.protected(check=check_protected) # pylint: disable=invalid-name -override = lang.override(check=check_override) # pylint: disable=invalid-name +protected = lang.protected(check=check_protected) +override = lang.override(check=check_override) _LOGGER = logging.getLogger(__name__) -SAVED_STATE_TYPE = MutableMapping[str, Any] # pylint: disable=invalid-name -PID_TYPE = Hashable # pylint: disable=invalid-name +SAVED_STATE_TYPE = MutableMapping[str, Any] +PID_TYPE = Hashable class Frozendict(Mapping): @@ -67,7 +73,6 @@ def __hash__(self) -> int: class AttributesFrozendict(Frozendict): - def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self._initialised: bool = True @@ -130,7 +135,7 @@ def load_function(name: str, instance: Optional[Any] = None) -> Callable[..., An obj = load_object(name) if inspect.ismethod(obj): if instance is not None: - return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] # pylint: disable=unnecessary-dunder-call + return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] return obj diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 90e35482..748a44d7 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,28 +1,43 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import abc import asyncio import collections import inspect import logging import re -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + Hashable, + List, + Mapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) import kiwipy from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE -__all__ = ['WorkChain', 'if_', 'while_', 'return_', 'ToContext', 'WorkChainSpec'] +__all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] ToContext = dict -PREDICATE_TYPE = Callable[['WorkChain'], bool] # pylint: disable=invalid-name -WC_COMMAND_TYPE = Callable[['WorkChain'], Any] # pylint: disable=invalid-name -EXIT_CODE_TYPE = int # pylint: disable=invalid-name +PREDICATE_TYPE = Callable[['WorkChain'], bool] +WC_COMMAND_TYPE = Callable[['WorkChain'], Any] +EXIT_CODE_TYPE = int class WorkChainSpec(processes.ProcessSpec): - def __init__(self) -> None: super().__init__() self._outline: Optional[Union['_Instruction', '_FunctionCall']] = None @@ -55,21 +70,20 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): - """ Overwrite the waiting state""" + """Overwrite the waiting state""" def __init__( self, process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None + awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: super().__init__(process, done_callback, msg, awaiting) self._awaiting: Dict[asyncio.Future, str] = {} for awaitable, key in (awaiting or {}).items(): - if isinstance(awaitable, processes.Process): - awaitable = awaitable.future() - self._awaiting[awaitable] = key + resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable + self._awaiting[resolved_awaitable] = key def enter(self) -> None: super().enter() @@ -85,7 +99,7 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: self.process.ctx[key] = awaitable.result() # type: ignore - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: self._waiting_future.set_exception(exception) else: if not self._awaiting: @@ -97,6 +111,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): A WorkChain is a series of instructions carried out with the ability to save state in between. """ + _spec_class = WorkChainSpec _STEPPER_STATE = 'stepper_state' _CONTEXT = 'CONTEXT' @@ -113,7 +128,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None + communicator: Optional[kiwipy.Communicator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) self._stepper: Optional[Stepper] = None @@ -152,9 +167,9 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None to the corresponding key in the context of the workchain """ for key, awaitable in kwargs.items(): - if isinstance(awaitable, processes.Process): - awaitable = awaitable.future() - self._awaitables[awaitable] = key + resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable + + self._awaitables[resolved_awaitable] = key def run(self) -> Any: return self._do_step() @@ -169,7 +184,6 @@ def _do_step(self) -> Any: finished, return_value = True, exception.exit_code if not finished and (return_value is None or isinstance(return_value, ToContext)): - if isinstance(return_value, ToContext): self.to_context(**return_value) @@ -182,7 +196,6 @@ def _do_step(self) -> Any: class Stepper(persistence.Savable, metaclass=abc.ABCMeta): - def __init__(self, workchain: 'WorkChain') -> None: self._workchain = workchain @@ -210,11 +223,11 @@ class _Instruction(metaclass=abc.ABCMeta): @abc.abstractmethod def create_stepper(self, workchain: 'WorkChain') -> Stepper: - """ Create a new stepper for this instruction """ + """Create a new stepper for this instruction""" @abc.abstractmethod def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> Stepper: - """ Recreate a stepper from a previously saved state """ + """Recreate a stepper from a previously saved state""" def __str__(self) -> str: return str(self.get_description()) @@ -229,7 +242,6 @@ def get_description(self) -> Any: class _FunctionStepper(Stepper): - def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): super().__init__(workchain) self._fn = fn @@ -250,7 +262,6 @@ def __str__(self) -> str: class _FunctionCall(_Instruction): - def __init__(self, func: WC_COMMAND_TYPE) -> None: try: args = inspect.getfullargspec(func)[0] @@ -282,7 +293,6 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') class _BlockStepper(Stepper): - def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: super().__init__(workchain) self._block = block @@ -333,14 +343,15 @@ class _Block(_Instruction, collections.abc.Sequence): def __init__(self, instructions: Sequence[Union[_Instruction, WC_COMMAND_TYPE]]) -> None: # Build up the list of commands - comms = [] + comms: MutableSequence[_Instruction | _FunctionCall] = [] for instruction in instructions: if not isinstance(instruction, _Instruction): # Assume it's a function call - instruction = _FunctionCall(instruction) + comms.append(_FunctionCall(instruction)) + else: + comms.append(instruction) - comms.append(instruction) - self._instruction: List[Union[_Instruction, _FunctionCall]] = comms + self._instruction: MutableSequence[_Instruction | _FunctionCall] = comms def __getitem__(self, index: int) -> Union[_Instruction, _FunctionCall]: # type: ignore return self._instruction[index] @@ -392,10 +403,12 @@ def is_true(self, workflow: 'WorkChain') -> bool: if not hasattr(result, '__bool__'): import warnings + warnings.warn( f'The conditional predicate `{self._predicate.__name__}` returned `{result}` which is not boolean-like.' ' The return value should be `True` or `False` or implement the `__bool__` method. This behavior is ' - 'deprecated and will soon start raising an exception.', UserWarning + 'deprecated and will soon start raising an exception.', + UserWarning, ) return result @@ -411,7 +424,6 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') class _IfStepper(Stepper): - def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: super().__init__(workchain) self._if_instruction = if_instruction @@ -467,7 +479,6 @@ def __str__(self) -> str: class _If(_Instruction, collections.abc.Sequence): - def __init__(self, condition: PREDICATE_TYPE) -> None: super().__init__() self._ifs: List[_Conditional] = [_Conditional(self, condition, label=if_.__name__)] @@ -520,7 +531,6 @@ def get_description(self) -> Mapping[str, Any]: class _WhileStepper(Stepper): - def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: super().__init__(workchain) self._while_instruction = while_instruction @@ -563,7 +573,6 @@ def __str__(self) -> str: class _While(_Conditional, _Instruction, collections.abc.Sequence): - def __init__(self, predicate: PREDICATE_TYPE) -> None: super().__init__(self, predicate, label=while_.__name__) @@ -586,14 +595,12 @@ def get_description(self) -> Dict[str, Any]: class _PropagateReturn(BaseException): - def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: super().__init__() self.exit_code = exit_code class _ReturnStepper(Stepper): - def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: super().__init__(workchain) self._return_instruction = return_instruction @@ -603,7 +610,7 @@ def step(self) -> Tuple[bool, Any]: Raise a _PropagateReturn exception where the value is the exit code set in the _Return instruction upon instantiation """ - raise _PropagateReturn(self._return_instruction._exit_code) # pylint: disable=protected-access + raise _PropagateReturn(self._return_instruction._exit_code) class _Return(_Instruction): @@ -670,7 +677,7 @@ def while_(condition: PREDICATE_TYPE) -> _While: return _While(condition) -return_ = _Return() # pylint: disable=invalid-name +return_ = _Return() """ A global singleton that contains a Return instruction that allows to exit out of the workchain outline directly with None as exit code diff --git a/test/__init__.py b/tests/__init__.py similarity index 100% rename from test/__init__.py rename to tests/__init__.py diff --git a/test/base/__init__.py b/tests/base/__init__.py similarity index 100% rename from test/base/__init__.py rename to tests/base/__init__.py diff --git a/test/base/test_statemachine.py b/tests/base/test_statemachine.py similarity index 90% rename from test/base/test_statemachine.py rename to tests/base/test_statemachine.py index 72fed261..5b4b73d8 100644 --- a/test/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -25,7 +25,7 @@ def __init__(self, player, track): super().__init__(player) self.track = track self._last_time = None - self._played = 0. + self._played = 0.0 def __str__(self): if self.in_state: @@ -55,8 +55,7 @@ class Paused(state_machine.State): TRANSITIONS = {STOP: STOPPED} def __init__(self, player, playing_state): - assert isinstance(playing_state, Playing), \ - 'Must provide the playing state to pause' + assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) self.playing_state = playing_state @@ -65,7 +64,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track) + self.state_machine.transition_to(Playing, track=track) else: self.state_machine.transition_to(self.playing_state) @@ -81,7 +80,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track) + self.state_machine.transition_to(Playing, track=track) class CdPlayer(state_machine.StateMachine): @@ -108,7 +107,7 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, self._state) + self.transition_to(Paused, playing_state=self._state) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) @@ -117,14 +116,13 @@ def stop(self): class TestStateMachine(unittest.TestCase): - def test_basic(self): cd_player = CdPlayer() self.assertEqual(cd_player.state, STOPPED) cd_player.play('Eminem - The Real Slim Shady') self.assertEqual(cd_player.state, PLAYING) - time.sleep(1.) + time.sleep(1.0) cd_player.pause() self.assertEqual(cd_player.state, PAUSED) diff --git a/test/base/test_utils.py b/tests/base/test_utils.py similarity index 99% rename from test/base/test_utils.py rename to tests/base/test_utils.py index 9aa0237b..d62b1422 100644 --- a/test/base/test_utils.py +++ b/tests/base/test_utils.py @@ -5,7 +5,6 @@ class Root: - @utils.super_check def method(self): pass @@ -15,19 +14,16 @@ def do(self): class DoCall(Root): - def method(self): super().method() class DontCall(Root): - def method(self): pass class TestSuperCheckMixin(unittest.TestCase): - def test_do_call(self): DoCall().do() @@ -36,9 +32,7 @@ def test_dont_call(self): DontCall().do() def dont_call_middle(self): - class ThirdChild(DontCall): - def method(self): super().method() diff --git a/test/conftest.py b/tests/conftest.py similarity index 99% rename from test/conftest.py rename to tests/conftest.py index 43555586..c70088fa 100644 --- a/test/conftest.py +++ b/tests/conftest.py @@ -5,4 +5,5 @@ @pytest.fixture(scope='session') def set_event_loop_policy(): from plumpy import set_event_loop_policy + set_event_loop_policy() diff --git a/test/notebooks/get_event_loop.ipynb b/tests/notebooks/get_event_loop.ipynb similarity index 90% rename from test/notebooks/get_event_loop.ipynb rename to tests/notebooks/get_event_loop.ipynb index 6aa4fbd5..860ca3d2 100644 --- a/test/notebooks/get_event_loop.ipynb +++ b/tests/notebooks/get_event_loop.ipynb @@ -7,7 +7,9 @@ "outputs": [], "source": [ "import asyncio\n", - "from plumpy import set_event_loop_policy, PlumpyEventLoopPolicy\n", + "\n", + "from plumpy import PlumpyEventLoopPolicy, set_event_loop_policy\n", + "\n", "set_event_loop_policy()\n", "assert isinstance(asyncio.get_event_loop_policy(), PlumpyEventLoopPolicy)\n", "assert hasattr(asyncio.get_event_loop(), '_nest_patched')" diff --git a/test/persistence/__init__.py b/tests/persistence/__init__.py similarity index 100% rename from test/persistence/__init__.py rename to tests/persistence/__init__.py diff --git a/test/persistence/test_inmemory.py b/tests/persistence/test_inmemory.py similarity index 88% rename from test/persistence/test_inmemory.py rename to tests/persistence/test_inmemory.py index bc03f88b..b0db46e7 100644 --- a/test/persistence/test_inmemory.py +++ b/tests/persistence/test_inmemory.py @@ -1,31 +1,26 @@ # -*- coding: utf-8 -*- -import asyncio -from test.utils import ProcessWithCheckpoint import unittest +from ..utils import ProcessWithCheckpoint + import plumpy +import plumpy -class TestInMemoryPersister(unittest.TestCase): +class TestInMemoryPersister(unittest.TestCase): def test_save_load_roundtrip(self): """ Test the plumpy.PicklePersister by taking a dummpy process, saving a checkpoint and recreating it from the same checkpoint """ - loop = asyncio.get_event_loop() process = ProcessWithCheckpoint() persister = plumpy.InMemoryPersister() persister.save_checkpoint(process) - bundle = persister.load_checkpoint(process.pid) - load_context = plumpy.LoadSaveContext(loop=loop) - recreated = bundle.unbundle(load_context) - def test_get_checkpoints_without_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -43,8 +38,7 @@ def test_get_checkpoints_without_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_checkpoints_with_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() tag_a = 'tag_a' @@ -64,15 +58,12 @@ def test_get_checkpoints_with_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() checkpoint_a1 = plumpy.PersistedCheckpoint(process_a.pid, '1') checkpoint_a2 = plumpy.PersistedCheckpoint(process_a.pid, '2') - checkpoint_b1 = plumpy.PersistedCheckpoint(process_b.pid, '1') - checkpoint_b2 = plumpy.PersistedCheckpoint(process_b.pid, '2') checkpoints = [checkpoint_a1, checkpoint_a2] @@ -87,15 +78,12 @@ def test_get_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() checkpoint_a1 = plumpy.PersistedCheckpoint(process_a.pid, '1') checkpoint_a2 = plumpy.PersistedCheckpoint(process_a.pid, '2') - checkpoint_b1 = plumpy.PersistedCheckpoint(process_b.pid, '1') - checkpoint_b2 = plumpy.PersistedCheckpoint(process_b.pid, '2') persister = plumpy.InMemoryPersister() persister.save_checkpoint(process_a, tag='1') @@ -116,8 +104,7 @@ def test_delete_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_checkpoint(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() diff --git a/test/persistence/test_pickle.py b/tests/persistence/test_pickle.py similarity index 89% rename from test/persistence/test_pickle.py rename to tests/persistence/test_pickle.py index 19e4f52a..dd68b4fd 100644 --- a/test/persistence/test_pickle.py +++ b/tests/persistence/test_pickle.py @@ -1,37 +1,29 @@ # -*- coding: utf-8 -*- -import asyncio import tempfile import unittest if getattr(tempfile, 'TemporaryDirectory', None) is None: from backports import tempfile -from test.utils import ProcessWithCheckpoint +from ..utils import ProcessWithCheckpoint import plumpy class TestPicklePersister(unittest.TestCase): - def test_save_load_roundtrip(self): """ Test the plumpy.PicklePersister by taking a dummpy process, saving a checkpoint and recreating it from the same checkpoint """ - loop = asyncio.get_event_loop() process = ProcessWithCheckpoint() with tempfile.TemporaryDirectory() as directory: persister = plumpy.PicklePersister(directory) persister.save_checkpoint(process) - bundle = persister.load_checkpoint(process.pid) - load_context = plumpy.LoadSaveContext(loop=loop) - recreated = bundle.unbundle(load_context) - def test_get_checkpoints_without_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -50,8 +42,7 @@ def test_get_checkpoints_without_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_checkpoints_with_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() tag_a = 'tag_a' @@ -72,15 +63,12 @@ def test_get_checkpoints_with_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() checkpoint_a1 = plumpy.PersistedCheckpoint(process_a.pid, '1') checkpoint_a2 = plumpy.PersistedCheckpoint(process_a.pid, '2') - checkpoint_b1 = plumpy.PersistedCheckpoint(process_b.pid, '1') - checkpoint_b2 = plumpy.PersistedCheckpoint(process_b.pid, '2') checkpoints = [checkpoint_a1, checkpoint_a2] @@ -96,15 +84,12 @@ def test_get_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() checkpoint_a1 = plumpy.PersistedCheckpoint(process_a.pid, '1') checkpoint_a2 = plumpy.PersistedCheckpoint(process_a.pid, '2') - checkpoint_b1 = plumpy.PersistedCheckpoint(process_b.pid, '1') - checkpoint_b2 = plumpy.PersistedCheckpoint(process_b.pid, '2') with tempfile.TemporaryDirectory() as directory: persister = plumpy.PicklePersister(directory) @@ -126,8 +111,7 @@ def test_delete_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_checkpoint(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() diff --git a/test/rmq/__init__.py b/tests/rmq/__init__.py similarity index 100% rename from test/rmq/__init__.py rename to tests/rmq/__init__.py diff --git a/test/rmq/docker-compose.yml b/tests/rmq/docker-compose.yml similarity index 100% rename from test/rmq/docker-compose.yml rename to tests/rmq/docker-compose.yml diff --git a/test/rmq/test_communicator.py b/tests/rmq/test_communicator.py similarity index 90% rename from test/rmq/test_communicator.py rename to tests/rmq/test_communicator.py index 5cedd38d..3f2570d8 100644 --- a/test/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -1,15 +1,16 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.rmq.communicator` module.""" + import asyncio import functools import shutil import tempfile import uuid -from kiwipy import BroadcastFilter, rmq import pytest import shortuuid import yaml +from kiwipy import BroadcastFilter, rmq import plumpy from plumpy import communications, process_comms @@ -38,7 +39,7 @@ def loop_communicator(): message_exchange=message_exchange, task_exchange=task_exchange, task_queue=task_queue, - decoder=functools.partial(yaml.load, Loader=yaml.Loader) + decoder=functools.partial(yaml.load, Loader=yaml.Loader), ) loop = asyncio.get_event_loop() @@ -61,7 +62,7 @@ class TestLoopCommunicator: @pytest.mark.asyncio async def test_broadcast(self, loop_communicator): - BROADCAST = {'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} + BROADCAST = {'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} # noqa: N806 broadcast_future = plumpy.Future() loop = asyncio.get_event_loop() @@ -69,12 +70,9 @@ async def test_broadcast(self, loop_communicator): def get_broadcast(_comm, body, sender, subject, correlation_id): assert loop is asyncio.get_event_loop() - broadcast_future.set_result({ - 'body': body, - 'sender': sender, - 'subject': subject, - 'correlation_id': correlation_id - }) + broadcast_future.set_result( + {'body': body, 'sender': sender, 'subject': subject, 'correlation_id': correlation_id} + ) loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send(**BROADCAST) @@ -84,11 +82,8 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_broadcast_filter(self, loop_communicator): - broadcast_future = plumpy.Future() - loop = asyncio.get_event_loop() - def ignore_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_exception(AssertionError('broadcast received')) @@ -98,12 +93,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): loop_communicator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other')) loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send( - **{ - 'body': 'present', - 'sender': 'Martin', - 'subject': 'sup', - 'correlation_id': 420 - } + **{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} ) result = await broadcast_future @@ -111,7 +101,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_rpc(self, loop_communicator): - MSG = 'rpc this' + MSG = 'rpc this' # noqa: N806 rpc_future = plumpy.Future() loop = asyncio.get_event_loop() @@ -128,7 +118,7 @@ def get_rpc(_comm, msg): @pytest.mark.asyncio async def test_task(self, loop_communicator): - TASK = 'task this' + TASK = 'task this' # noqa: N806 task_future = plumpy.Future() loop = asyncio.get_event_loop() @@ -145,7 +135,6 @@ def get_task(_comm, msg): class TestTaskActions: - @pytest.mark.asyncio async def test_launch(self, loop_communicator, async_controller, persister): # Let the process run to the end @@ -157,7 +146,7 @@ async def test_launch(self, loop_communicator, async_controller, persister): @pytest.mark.asyncio async def test_launch_nowait(self, loop_communicator, async_controller, persister): - """ Testing launching but don't wait, just get the pid """ + """Testing launching but don't wait, just get the pid""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) pid = await async_controller.launch_process(utils.DummyProcess, nowait=True) @@ -165,7 +154,7 @@ async def test_launch_nowait(self, loop_communicator, async_controller, persiste @pytest.mark.asyncio async def test_execute_action(self, loop_communicator, async_controller, persister): - """ Test the process execute action """ + """Test the process execute action""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) result = await async_controller.execute_process(utils.DummyProcessWithOutput) @@ -173,7 +162,7 @@ async def test_execute_action(self, loop_communicator, async_controller, persist @pytest.mark.asyncio async def test_execute_action_nowait(self, loop_communicator, async_controller, persister): - """ Test the process execute action """ + """Test the process execute action""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) pid = await async_controller.execute_process(utils.DummyProcessWithOutput, nowait=True) @@ -197,7 +186,7 @@ async def test_launch_many(self, loop_communicator, async_controller, persister) @pytest.mark.asyncio async def test_continue(self, loop_communicator, async_controller, persister): - """ Test continuing a saved process """ + """Test continuing a saved process""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) process = utils.DummyProcessWithOutput() diff --git a/test/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py similarity index 97% rename from test/rmq/test_process_comms.py rename to tests/rmq/test_process_comms.py index 6afccf46..7223b888 100644 --- a/test/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- import asyncio +import copy import kiwipy -from kiwipy import rmq import pytest import shortuuid +from kiwipy import rmq import plumpy -from plumpy import process_comms import plumpy.communications +from plumpy import process_comms from .. import utils @@ -43,7 +44,6 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): class TestRemoteProcessController: - @pytest.mark.asyncio async def test_pause(self, thread_communicator, async_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) @@ -122,7 +122,6 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: - @pytest.mark.asyncio async def test_pause(self, thread_communicator, sync_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) @@ -197,7 +196,10 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - sync_controller.kill_all('bang bang, I shot you down') + msg = copy.copy(process_comms.KILL_MSG) + msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + + sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) diff --git a/test/test_communications.py b/tests/test_communications.py similarity index 100% rename from test/test_communications.py rename to tests/test_communications.py index f82036bd..f7e04255 100644 --- a/test/test_communications.py +++ b/tests/test_communications.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.communications` module.""" -from kiwipy import CommunicatorHelper + import pytest +from kiwipy import CommunicatorHelper from plumpy.communications import LoopCommunicator @@ -14,7 +15,6 @@ def __call__(self): class Communicator(CommunicatorHelper): - def task_send(self, task, no_reply=False): pass diff --git a/test/test_events.py b/tests/test_events.py similarity index 99% rename from test/test_events.py rename to tests/test_events.py index e6260f1d..964bd6f7 100644 --- a/test/test_events.py +++ b/tests/test_events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.events` module.""" + import asyncio import pathlib diff --git a/test/test_expose.py b/tests/test_expose.py similarity index 84% rename from test/test_expose.py rename to tests/test_expose.py index 1a495727..0f6f8087 100644 --- a/test/test_expose.py +++ b/tests/test_expose.py @@ -1,56 +1,50 @@ # -*- coding: utf-8 -*- -from test.utils import NewLoopProcess import unittest +from .utils import NewLoopProcess + from plumpy.ports import PortNamespace from plumpy.process_spec import ProcessSpec from plumpy.processes import Process -class TestExposeProcess(unittest.TestCase): +def validator_function(input, port): + pass - def setUp(self): - super().setUp() - def validator_function(input, port): - pass +class BaseNamespaceProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('top') + spec.input('namespace.sub_one') + spec.input('namespace.sub_two') + spec.inputs['namespace'].valid_type = (int, float) + spec.inputs['namespace'].validator = validator_function - class BaseNamespaceProcess(NewLoopProcess): - @classmethod - def define(cls, spec): - super().define(spec) - spec.input('top') - spec.input('namespace.sub_one') - spec.input('namespace.sub_two') - spec.inputs['namespace'].valid_type = (int, float) - spec.inputs['namespace'].validator = validator_function +class BaseProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('a', valid_type=str, default='a') + spec.input('b', valid_type=str, default='b') + spec.inputs.dynamic = True + spec.inputs.valid_type = str - class BaseProcess(NewLoopProcess): - @classmethod - def define(cls, spec): - super().define(spec) - spec.input('a', valid_type=str, default='a') - spec.input('b', valid_type=str, default='b') - spec.inputs.dynamic = True - spec.inputs.valid_type = str +class ExposeProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.expose_inputs(BaseProcess, namespace='base.name.space') + spec.input('c', valid_type=int, default=1) + spec.input('d', valid_type=int, default=2) + spec.inputs.dynamic = True + spec.inputs.valid_type = int - class ExposeProcess(NewLoopProcess): - - @classmethod - def define(cls, spec): - super().define(spec) - spec.expose_inputs(BaseProcess, namespace='base.name.space') - spec.input('c', valid_type=int, default=1) - spec.input('d', valid_type=int, default=2) - spec.inputs.dynamic = True - spec.inputs.valid_type = int - - self.BaseNamespaceProcess = BaseNamespaceProcess - self.BaseProcess = BaseProcess - self.ExposeProcess = ExposeProcess +class TestExposeProcess(unittest.TestCase): def check_ports(self, process, namespace, expected_port_names): """Check the port namespace of a given process inputs spec for existence of set of expected port names.""" port_namespace = process.spec().inputs @@ -68,24 +62,21 @@ def check_namespace_properties(self, process_left, namespace_left, process_right port_namespace_left = process_left.spec().inputs.get_port(namespace_left) port_namespace_right = process_right.spec().inputs.get_port(namespace_right) - # Pop the ports in stored in the `_ports` attribute - port_namespace_left.__dict__.pop('_ports', None) - port_namespace_right.__dict__.pop('_ports', None) + left_dict = {k: v for k, v in port_namespace_left.__dict__.items() if k != '_ports'} + right_dict = {k: v for k, v in port_namespace_right.__dict__.items() if k != '_ports'} - self.assertEqual(port_namespace_left.__dict__, port_namespace_right.__dict__) + self.assertEqual(left_dict, right_dict) def test_expose_dynamic(self): """Test that exposing a dynamic namespace remains dynamic.""" class Lower(Process): - @classmethod def define(cls, spec): super(Lower, cls).define(spec) spec.input_namespace('foo', dynamic=True) class Upper(Process): - @classmethod def define(cls, spec): super(Upper, cls).define(spec) @@ -96,7 +87,7 @@ def define(cls, spec): def test_expose_nested_namespace(self): """Test that expose_inputs can create nested namespaces while maintaining own ports.""" - inputs = self.ExposeProcess.spec().inputs + inputs = ExposeProcess.spec().inputs # Verify that the nested namespaces are present self.assertTrue('base' in inputs) @@ -116,7 +107,7 @@ def test_expose_nested_namespace(self): def test_expose_ports(self): """Test that the exposed ports are present and properly deepcopied.""" - exposed_inputs = self.ExposeProcess.spec().inputs.get_port('base.name.space') + exposed_inputs = ExposeProcess.spec().inputs.get_port('base.name.space') self.assertEqual(len(exposed_inputs), 2) self.assertTrue('a' in exposed_inputs) @@ -125,32 +116,30 @@ def test_expose_ports(self): self.assertEqual(exposed_inputs['b'].default, 'b') # Change the default of base process port and verify they don't change the exposed port - self.BaseProcess.spec().inputs['a'].default = 'c' - self.assertEqual(self.BaseProcess.spec().inputs['a'].default, 'c') + BaseProcess.spec().inputs['a'].default = 'c' + self.assertEqual(BaseProcess.spec().inputs['a'].default, 'c') self.assertEqual(exposed_inputs['a'].default, 'a') def test_expose_attributes(self): """Test that the attributes of the exposed PortNamespace are maintained and properly deepcopied.""" - inputs = self.ExposeProcess.spec().inputs - exposed_inputs = self.ExposeProcess.spec().inputs.get_port('base.name.space') + inputs = ExposeProcess.spec().inputs + exposed_inputs = ExposeProcess.spec().inputs.get_port('base.name.space') - self.assertEqual(str, self.BaseProcess.spec().inputs.valid_type) + self.assertEqual(str, BaseProcess.spec().inputs.valid_type) self.assertEqual(str, exposed_inputs.valid_type) self.assertEqual(int, inputs.valid_type) # Now change the valid type of the BaseProcess inputs and verify it does not affect ExposeProcess - self.BaseProcess.spec().inputs.valid_type = float + BaseProcess.spec().inputs.valid_type = float - self.assertEqual(self.BaseProcess.spec().inputs.valid_type, float) + self.assertEqual(BaseProcess.spec().inputs.valid_type, float) self.assertEqual(exposed_inputs.valid_type, str) self.assertEqual(inputs.valid_type, int) def test_expose_exclude(self): """Test that the exclude argument of exposed_inputs works correctly and excludes ports from being absorbed.""" - BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -165,10 +154,8 @@ def define(cls, spec): def test_expose_include(self): """Test that the include argument of exposed_inputs works correctly and includes only specified ports.""" - BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -183,10 +170,8 @@ def define(cls, spec): def test_expose_exclude_include_mutually_exclusive(self): """Test that passing both exclude and include raises.""" - BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -208,7 +193,7 @@ def validator_function(input, port): # Define child process with all mutable properties of the inputs PortNamespace to a non-default value # This way we can check if the defaults of the ParentProcessSpec will be properly overridden - ChildProcessSpec = ProcessSpec() + ChildProcessSpec = ProcessSpec() # noqa: N806 ChildProcessSpec.input('a', valid_type=int) ChildProcessSpec.input('b', valid_type=str) ChildProcessSpec.inputs.validator = validator_function @@ -218,7 +203,7 @@ def validator_function(input, port): ChildProcessSpec.inputs.default = True ChildProcessSpec.inputs.help = 'testing' - ParentProcessSpec = ProcessSpec() + ParentProcessSpec = ProcessSpec() # noqa: N806 ParentProcessSpec.input('c', valid_type=float) ParentProcessSpec._expose_ports( process_class=None, @@ -228,7 +213,7 @@ def validator_function(input, port): namespace=None, exclude=(), include=None, - namespace_options={} + namespace_options={}, ) # Verify that all the ports are there @@ -256,7 +241,7 @@ def validator_function(input, port): # Define child process with all mutable properties of the inputs PortNamespace to a non-default value # This way we can check if the defaults of the ParentProcessSpec will be properly overridden - ChildProcessSpec = ProcessSpec() + ChildProcessSpec = ProcessSpec() # noqa: N806 ChildProcessSpec.input('a', valid_type=int) ChildProcessSpec.input('b', valid_type=str) ChildProcessSpec.inputs.validator = validator_function @@ -266,7 +251,7 @@ def validator_function(input, port): ChildProcessSpec.inputs.default = True ChildProcessSpec.inputs.help = 'testing' - ParentProcessSpec = ProcessSpec() + ParentProcessSpec = ProcessSpec() # noqa: N806 ParentProcessSpec.input('c', valid_type=float) ParentProcessSpec._expose_ports( process_class=None, @@ -283,7 +268,7 @@ def validator_function(input, port): 'dynamic': False, 'default': None, 'help': None, - } + }, ) # Verify that all the ports are there @@ -310,7 +295,7 @@ def validator_function(input, port): # Define child process with all mutable properties of the inputs PortNamespace to a non-default value # This way we can check if the defaults of the ParentProcessSpec will be properly overridden - ChildProcessSpec = ProcessSpec() + ChildProcessSpec = ProcessSpec() # noqa: N806 ChildProcessSpec.input('a', valid_type=int) ChildProcessSpec.input('b', valid_type=str) ChildProcessSpec.inputs.validator = validator_function @@ -320,7 +305,7 @@ def validator_function(input, port): ChildProcessSpec.inputs.default = True ChildProcessSpec.inputs.help = 'testing' - ParentProcessSpec = ProcessSpec() + ParentProcessSpec = ProcessSpec() # noqa: N806 ParentProcessSpec.input('c', valid_type=float) ParentProcessSpec._expose_ports( process_class=None, @@ -330,7 +315,7 @@ def validator_function(input, port): namespace='namespace', exclude=(), include=None, - namespace_options={} + namespace_options={}, ) # Verify that all the ports are there @@ -351,8 +336,8 @@ def test_expose_ports_namespace_options_non_existent(self): Verify that passing non-supported PortNamespace mutable properties in namespace_options will raise a ValueError """ - ChildProcessSpec = ProcessSpec() - ParentProcessSpec = ProcessSpec() + ChildProcessSpec = ProcessSpec() # noqa: N806 + ParentProcessSpec = ProcessSpec() # noqa: N806 with self.assertRaises(ValueError): ParentProcessSpec._expose_ports( @@ -365,15 +350,13 @@ def test_expose_ports_namespace_options_non_existent(self): include=None, namespace_options={ 'non_existent': None, - } + }, ) def test_expose_nested_include_top_level(self): """Test the include rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -384,10 +367,8 @@ def define(cls, spec): def test_expose_nested_include_namespace(self): """Test the include rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -400,10 +381,8 @@ def define(cls, spec): def test_expose_nested_include_namespace_sub(self): """Test the include rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -416,10 +395,8 @@ def define(cls, spec): def test_expose_nested_include_combination(self): """Test the include rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -432,10 +409,8 @@ def define(cls, spec): def test_expose_nested_exclude_top_level(self): """Test the exclude rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -448,10 +423,8 @@ def define(cls, spec): def test_expose_nested_exclude_namespace(self): """Test the exclude rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -462,10 +435,8 @@ def define(cls, spec): def test_expose_nested_exclude_namespace_sub(self): """Test the exclude rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -478,10 +449,8 @@ def define(cls, spec): def test_expose_nested_exclude_combination(self): """Test the exclude rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -504,7 +473,6 @@ def test_expose_exclude_port_with_validator(self): """ class BaseProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -520,10 +488,10 @@ def validator(cls, value, ctx): return None if not isinstance(value['a'], str): - return f'value for input `a` should be a str, but got: {type(value["a"])}' + a_type = type(value['a']) + return f'value for input `a` should be a str, but got: {a_type}' class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/test_lang.py b/tests/test_lang.py similarity index 97% rename from test/test_lang.py rename to tests/test_lang.py index 13136530..a55af31a 100644 --- a/test/test_lang.py +++ b/tests/test_lang.py @@ -5,7 +5,6 @@ class A: - def __init__(self): self._a = None @@ -22,27 +21,24 @@ def protected_property(self): def protected_fn_nocheck(self): return self._a - def testA(self): + def testA(self): # noqa: N802 self.protected_fn() self.protected_property class B(A): - - def testB(self): + def testB(self): # noqa: N802 self.protected_fn() self.protected_property class C(B): - - def testC(self): + def testC(self): # noqa: N802 self.protected_fn() self.protected_property class TestProtected(TestCase): - def test_free_function(self): with self.assertRaises(RuntimeError): @@ -79,7 +75,6 @@ def test_incorrect_usage(self): with self.assertRaises(RuntimeError): class TestWrongDecoratorOrder: - @protected(check=True) @property def a(self): @@ -87,13 +82,11 @@ def a(self): class Superclass: - def test(self): pass class TestOverride(TestCase): - def test_free_function(self): with self.assertRaises(RuntimeError): @@ -102,9 +95,7 @@ def some_func(): pass def test_correct_usage(self): - class Derived(Superclass): - @override(check=True) def test(self): return True @@ -115,7 +106,6 @@ class Middle(Superclass): pass class Next(Middle): - @override(check=True) def test(self): return True @@ -123,9 +113,7 @@ def test(self): self.assertTrue(Next().test()) def test_incorrect_usage(self): - class Derived: - @override(check=True) def test(self): pass @@ -136,7 +124,6 @@ def test(self): with self.assertRaises(RuntimeError): class TestWrongDecoratorOrder(Superclass): - @override(check=True) @property def test(self): diff --git a/test/test_loaders.py b/tests/test_loaders.py similarity index 97% rename from test/test_loaders.py rename to tests/test_loaders.py index 3058b77c..a1813f09 100644 --- a/test/test_loaders.py +++ b/tests/test_loaders.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.loaders` module.""" + import pytest import plumpy @@ -7,6 +8,7 @@ class DummyClass: """Dummy class for testing.""" + pass @@ -38,11 +40,12 @@ def test_default_object_roundtrip(): @pytest.mark.parametrize( - 'identifier, match', ( + 'identifier, match', + ( ('plumpy.non_existing_module.SomeClass', r'identifier `.*` has an invalid format.'), ('plumpy.non_existing_module:SomeClass', r'module `.*` from identifier `.*` could not be loaded.'), ('plumpy.loaders:NonExistingClass', r'object `.*` form identifier `.*` could not be loaded.'), - ) + ), ) def test_default_object_loader_load_object_except(identifier, match): """Test the :meth:`plumpy.DefaultObjectLoader.load_object` when it is expected to raise.""" diff --git a/test/test_persistence.py b/tests/test_persistence.py similarity index 97% rename from test/test_persistence.py rename to tests/test_persistence.py index 2c9cf4f9..78724aa0 100644 --- a/test/test_persistence.py +++ b/tests/test_persistence.py @@ -15,7 +15,6 @@ class SaveEmpty(plumpy.Savable): @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): - def __init__(self): self.test = 'sup yp' self.test_method = self.m @@ -26,13 +25,11 @@ def m(): @plumpy.auto_persist('test') class Save(plumpy.Savable): - def __init__(self): self.test = Save1() class TestSavable(unittest.TestCase): - def test_empty_savable(self): self._save_round_trip(SaveEmpty()) @@ -79,9 +76,8 @@ def _save_round_trip_with_loader(self, savable): class TestBundle(unittest.TestCase): - def test_bundle_load_context(self): - """ Check that the loop from the load context is used """ + """Check that the loop from the load context is used""" loop1 = asyncio.get_event_loop() proc = utils.DummyProcess(loop=loop1) bundle = plumpy.Bundle(proc) diff --git a/test/test_port.py b/tests/test_port.py similarity index 99% rename from test/test_port.py rename to tests/test_port.py index ab9b51a6..da483e81 100644 --- a/test/test_port.py +++ b/tests/test_port.py @@ -7,7 +7,6 @@ class TestPort(TestCase): - def test_required(self): spec = Port('required_value', required=True) @@ -21,7 +20,6 @@ def test_validate(self): self.assertIsNotNone(spec.validate('a')) def test_validator(self): - def validate(value, port): assert isinstance(port, Port) if not isinstance(value, int): @@ -45,7 +43,6 @@ def validate(value, port): class TestInputPort(TestCase): - def test_default(self): """Test the default value property for the InputPort.""" port = InputPort('test', default=5) @@ -81,12 +78,14 @@ def test_lambda_default(self): # Testing that passing an actual lambda as a value is alos possible port = InputPort('test', valid_type=(types.FunctionType, int), default=lambda: 5) - some_lambda = lambda: 'string' + + def some_lambda(): + return 'string' + self.assertIsNone(port.validate(some_lambda)) class TestOutputPort(TestCase): - def test_default(self): """ Test the default value property for the InputPort @@ -108,7 +107,6 @@ def validator(value, port): class TestPortNamespace(TestCase): - BASE_PORT_NAME = 'port' BASE_PORT_NAMESPACE_NAME = 'port' @@ -299,7 +297,7 @@ def test_port_namespace_validate(self): # Check the breadcrumbs are correct self.assertEqual( validation_error.port, - self.port_namespace.NAMESPACE_SEPARATOR.join((self.BASE_PORT_NAMESPACE_NAME, 'sub', 'space', 'output')) + self.port_namespace.NAMESPACE_SEPARATOR.join((self.BASE_PORT_NAMESPACE_NAME, 'sub', 'space', 'output')), ) def test_port_namespace_required(self): @@ -371,7 +369,9 @@ def test_port_namespace_lambda_defaults(self): self.assertIsNone(port_namespace.validate(inputs)) # When passing a lambda directly as the value, it should NOT be evaluated during pre_processing - some_lambda = lambda: 5 + def some_lambda(): + return 5 + inputs = port_namespace.pre_process({'lambda_default': some_lambda}) self.assertEqual(inputs['lambda_default'], some_lambda) self.assertIsNone(port_namespace.validate(inputs)) diff --git a/test/test_process_comms.py b/tests/test_process_comms.py similarity index 87% rename from test/test_process_comms.py rename to tests/test_process_comms.py index 6d3d335c..c59737ac 100644 --- a/test/test_process_comms.py +++ b/tests/test_process_comms.py @@ -1,23 +1,17 @@ # -*- coding: utf-8 -*- -import asyncio -from test import utils -import unittest - -from kiwipy import rmq import pytest +from tests import utils import plumpy -from plumpy import communications, process_comms +from plumpy import process_comms class Process(plumpy.Process): - def run(self): pass class CustomObjectLoader(plumpy.DefaultObjectLoader): - def load_object(self, identifier): if identifier == 'jimmy': return Process @@ -49,7 +43,7 @@ async def test_continue(): @pytest.mark.asyncio async def test_loader_is_used(): - """ Make sure that the provided class loader is used by the process launcher """ + """Make sure that the provided class loader is used by the process launcher""" loader = CustomObjectLoader() proc = Process() persister = plumpy.InMemoryPersister(loader=loader) diff --git a/test/test_process_spec.py b/tests/test_process_spec.py similarity index 99% rename from test/test_process_spec.py rename to tests/test_process_spec.py index 443f7a64..3be8c1d2 100644 --- a/test/test_process_spec.py +++ b/tests/test_process_spec.py @@ -10,7 +10,6 @@ class StrSubtype(str): class TestProcessSpec(TestCase): - def setUp(self): self.spec = ProcessSpec() @@ -18,7 +17,6 @@ def test_get_port_namespace_base(self): """ Get the root, inputs and outputs port namespaces of the ProcessSpec """ - ports = self.spec.ports input_ports = self.spec.inputs output_ports = self.spec.outputs diff --git a/test/test_processes.py b/tests/test_processes.py similarity index 97% rename from test/test_processes.py rename to tests/test_processes.py index 0cb4161b..faea9eae 100644 --- a/test/test_processes.py +++ b/tests/test_processes.py @@ -1,20 +1,22 @@ # -*- coding: utf-8 -*- """Process tests""" + import asyncio +import copy import enum -from test import utils import unittest import kiwipy import pytest +from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState +from plumpy.process_comms import KILL_MSG, MESSAGE_KEY from plumpy.utils import AttributesFrozendict class ForgetToCallParent(plumpy.Process): - def __init__(self, forget_on): super().__init__() self.forget_on = forget_on @@ -42,9 +44,7 @@ def on_kill(self, msg): @pytest.mark.asyncio async def test_process_scope(): - class ProcessTaskInterleave(plumpy.Process): - async def task(self, steps: list): steps.append(f'[{self.pid}] started') assert plumpy.Process.current() is self @@ -64,7 +64,6 @@ async def task(self, steps: list): class TestProcess(unittest.TestCase): - def test_spec(self): """ Check that the references to specs are doing the right thing... @@ -82,12 +81,10 @@ class Proc(utils.DummyProcess): self.assertIs(p.spec(), Proc.spec()) def test_dynamic_inputs(self): - class NoDynamic(Process): pass class WithDynamic(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -100,9 +97,7 @@ def define(cls, spec): proc.execute() def test_inputs(self): - class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -122,7 +117,6 @@ def test_raw_inputs(self): """ class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -138,9 +132,7 @@ def define(cls, spec): self.assertDictEqual(dict(process.raw_inputs), {'a': 5, 'nested': {'a': 'value'}}) def test_inputs_default(self): - class Proc(utils.DummyProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -199,7 +191,6 @@ def test_inputs_default_that_evaluate_to_false(self): for def_val in (True, False, 0, 1): class Proc(utils.DummyProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -214,7 +205,6 @@ def test_nested_namespace_defaults(self): """Process with a default in a nested namespace should be created, even if top level namespace not supplied.""" class SomeProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -229,7 +219,6 @@ def test_raise_in_define(self): """Process which raises in its 'define' method. Check that the spec is not set.""" class BrokenProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -293,12 +282,11 @@ def test_run_kill(self): proc.execute() def test_get_description(self): - class ProcWithoutSpec(Process): pass class ProcWithSpec(Process): - """ Process with a spec and a docstring """ + """Process with a spec and a docstring""" @classmethod def define(cls, spec): @@ -324,9 +312,7 @@ def define(cls, spec): self.assertIsInstance(desc_with_spec['description'], str) def test_logging(self): - class LoggerTester(Process): - def run(self, **kwargs): self.logger.info('Test') @@ -335,11 +321,13 @@ def run(self, **kwargs): proc.execute() def test_kill(self): - proc = utils.DummyProcess() + proc: Process = utils.DummyProcess() - proc.kill('Farewell!') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'Farewell!' + proc.kill(msg) self.assertTrue(proc.killed()) - self.assertEqual(proc.killed_msg(), 'Farewell!') + self.assertEqual(proc.killed_msg(), msg) self.assertEqual(proc.state, ProcessState.KILLED) def test_wait_continue(self): @@ -438,12 +426,13 @@ async def async_test(): self.assertEqual(proc.state, ProcessState.FINISHED) def test_kill_in_run(self): - class KillProcess(Process): after_kill = False def run(self, **kwargs): - self.kill('killed') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'killed' + self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state self.after_kill = True @@ -456,9 +445,7 @@ def run(self, **kwargs): self.assertEqual(proc.state, ProcessState.KILLED) def test_kill_when_paused_in_run(self): - class PauseProcess(Process): - def run(self, **kwargs): self.pause() self.kill() @@ -510,9 +497,7 @@ def test_run_multiple(self): self.assertDictEqual(proc_class.EXPECTED_OUTPUTS, result) def test_invalid_output(self): - class InvalidOutput(plumpy.Process): - def run(self): self.out('invalid', 5) @@ -536,7 +521,6 @@ def test_unsuccessful_result(self): ERROR_CODE = 256 class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -550,11 +534,10 @@ def run(self): self.assertEqual(proc.result(), ERROR_CODE) def test_pause_in_process(self): - """ Test that we can pause and cancel that by playing within the process """ + """Test that we can pause and cancel that by playing within the process""" test_case = self class TestPausePlay(plumpy.Process): - def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -574,12 +557,11 @@ def run(self): self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) def test_pause_play_in_process(self): - """ Test that we can pause and play that by playing within the process """ + """Test that we can pause and play that by playing within the process""" test_case = self class TestPausePlay(plumpy.Process): - def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -596,7 +578,6 @@ def test_process_stack(self): test_case = self class StackTest(plumpy.Process): - def run(self): test_case.assertIs(self, Process.current()) @@ -613,7 +594,6 @@ def test_nested(process): expect_true.append(process == Process.current()) class StackTest(plumpy.Process): - def run(self): # TODO: unexpected behaviour here # if assert error happend here not raise @@ -623,7 +603,6 @@ def run(self): test_nested(self) class ParentProcess(plumpy.Process): - def run(self): expect_true.append(self == Process.current()) StackTest().execute() @@ -646,21 +625,17 @@ def test_process_nested(self): """ class StackTest(plumpy.Process): - def run(self): pass class ParentProcess(plumpy.Process): - def run(self): StackTest().execute() ParentProcess().execute() def test_call_soon(self): - class CallSoon(plumpy.Process): - def run(self): self.call_soon(self.do_except) @@ -680,7 +655,6 @@ def test_exception_during_on_entered(self): """Test that an exception raised during ``on_entered`` will cause the process to be excepted.""" class RaisingProcess(Process): - def on_entered(self, from_state): if from_state is not None and from_state.label == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') @@ -696,9 +670,7 @@ def on_entered(self, from_state): assert str(process.exception()) == 'exception during on_entered' def test_exception_during_run(self): - class RaisingProcess(Process): - def run(self): raise RuntimeError('exception during run') @@ -862,7 +834,7 @@ async def async_test(): loop.run_until_complete(async_test()) def test_wait_save_continue(self): - """ Test that process saved while in WAITING state restarts correctly when loaded """ + """Test that process saved while in WAITING state restarts correctly when loaded""" loop = asyncio.get_event_loop() proc = utils.WaitForSignalProcess() @@ -905,7 +877,6 @@ def _check_round_trip(self, proc1): class TestProcessNamespace(unittest.TestCase): - def test_namespaced_process(self): """ Test that inputs in nested namespaces are properly validated and the returned @@ -913,7 +884,6 @@ def test_namespaced_process(self): """ class NameSpacedProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -938,7 +908,6 @@ def test_namespaced_process_inputs(self): """ class NameSpacedProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -964,7 +933,6 @@ def test_namespaced_process_dynamic(self): namespace = 'name.space' class DummyDynamicProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -991,14 +959,12 @@ def test_namespaced_process_outputs(self): namespace_nested = f'{namespace}.nested' class OutputMode(enum.Enum): - NONE = 0 DYNAMIC_PORT_NAMESPACE = 1 SINGLE_REQUIRED_PORT = 2 BOTH_SINGLE_AND_NAMESPACE = 3 class DummyDynamicProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -1057,7 +1023,6 @@ def run(self): class TestProcessEvents(unittest.TestCase): - def test_basic_events(self): proc = utils.DummyProcessWithOutput() events_tester = utils.ProcessListenerTester( @@ -1077,11 +1042,14 @@ def test_killed(self): def test_excepted(self): proc = utils.ExceptionProcess() - events_tester = utils.ProcessListenerTester(proc, ( - 'excepted', - 'running', - 'output_emitted', - )) + events_tester = utils.ProcessListenerTester( + proc, + ( + 'excepted', + 'running', + 'output_emitted', + ), + ) with self.assertRaises(RuntimeError): proc.execute() proc.result() @@ -1120,7 +1088,6 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): class _RestartProcess(utils.WaitForSignalProcess): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/test_utils.py b/tests/test_utils.py similarity index 99% rename from test/test_utils.py rename to tests/test_utils.py index 546261f2..9567db7a 100644 --- a/test/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,6 @@ class TestAttributesFrozendict: - def test_getitem(self): d = AttributesFrozendict({'a': 5}) assert d['a'] == 5 @@ -40,7 +39,6 @@ async def async_fct(): class TestEnsureCoroutine: - def test_sync_func(self): coro = ensure_coroutine(fct) assert inspect.iscoroutinefunction(coro) @@ -50,9 +48,7 @@ def test_async_func(self): assert coro is async_fct def test_callable_class(self): - class AsyncDummy: - async def __call__(self): pass @@ -60,9 +56,7 @@ async def __call__(self): assert coro is AsyncDummy def test_callable_object(self): - class AsyncDummy: - async def __call__(self): pass diff --git a/test/test_waiting_process.py b/tests/test_waiting_process.py similarity index 99% rename from test/test_waiting_process.py rename to tests/test_waiting_process.py index 87d39192..90427554 100644 --- a/test/test_waiting_process.py +++ b/tests/test_waiting_process.py @@ -9,7 +9,6 @@ class TestWaitingProcess(unittest.TestCase): - def test_instance_state(self): proc = utils.ThreeSteps() wl = utils.ProcessSaver(proc) diff --git a/test/test_workchains.py b/tests/test_workchains.py similarity index 93% rename from test/test_workchains.py rename to tests/test_workchains.py index c698aff9..08c7317a 100644 --- a/test/test_workchains.py +++ b/tests/test_workchains.py @@ -33,9 +33,17 @@ def on_create(self): super().on_create() # Reset the finished step self.finished_steps = { - k: False for k in [ - self.s1.__name__, self.s2.__name__, self.s3.__name__, self.s4.__name__, self.s5.__name__, - self.s6.__name__, self.isA.__name__, self.isB.__name__, self.ltN.__name__ + k: False + for k in [ + self.s1.__name__, + self.s2.__name__, + self.s3.__name__, + self.s4.__name__, + self.s5.__name__, + self.s6.__name__, + self.isA.__name__, + self.isB.__name__, + self.ltN.__name__, ] } @@ -59,15 +67,15 @@ def s6(self): self.ctx.counter = self.ctx.counter + 1 self._set_finished(inspect.stack()[0][3]) - def isA(self): + def isA(self): # noqa: N802 self._set_finished(inspect.stack()[0][3]) return self.inputs.value == 'A' - def isB(self): + def isB(self): # noqa: N802 self._set_finished(inspect.stack()[0][3]) return self.inputs.value == 'B' - def ltN(self): + def ltN(self): # noqa: N802 keep_looping = self.ctx.counter < self.inputs.n if not keep_looping: self._set_finished(inspect.stack()[0][3]) @@ -78,7 +86,6 @@ def _set_finished(self, function_name): class IfTest(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -101,7 +108,6 @@ def step2(self): class DummyWc(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -112,7 +118,6 @@ def do_nothing(self): class TestContext(unittest.TestCase): - def test_attributes(self): wc = DummyWc() wc.ctx.new_attr = 5 @@ -136,9 +141,9 @@ class TestWorkchain(unittest.TestCase): maxDiff = None def test_run(self): - A = 'A' - B = 'B' - C = 'C' + A = 'A' # noqa: N806 + B = 'B' # noqa: N806 + C = 'C' # noqa: N806 three = 3 # Try the if(..) part @@ -163,9 +168,7 @@ def test_run(self): self.assertTrue(finished, f'Step {step} was not called by workflow') def test_incorrect_outline(self): - class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -176,9 +179,7 @@ def define(cls, spec): Wf.spec() def test_same_input_node(self): - class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -195,11 +196,10 @@ def check_a_b(self): Wf(inputs=dict(a=x, b=x)).execute() def test_context(self): - A = 'a' - B = 'b' + A = 'a' # noqa: N806 + B = 'b' # noqa: N806 class ReturnA(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -209,7 +209,6 @@ def run(self): self.out('res', A) class ReturnB(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -219,7 +218,6 @@ def run(self): self.out('res', B) class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -257,9 +255,9 @@ def test_malformed_outline(self): spec.outline(lambda x, y: 5) def test_checkpointing(self): - A = 'A' - B = 'B' - C = 'C' + A = 'A' # noqa: N806 + B = 'B' # noqa: N806 + C = 'C' # noqa: N806 three = 3 # Try the if(..) part @@ -288,13 +286,11 @@ def test_listener_persistence(self): process_finished_count = 0 class TestListener(plumpy.ProcessListener): - def on_process_finished(self, process, output): nonlocal process_finished_count process_finished_count += 1 class SimpleWorkChain(plumpy.WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -315,7 +311,8 @@ def step2(self): workchain = SimpleWorkChain() workchain.add_process_listener(TestListener()) - output = workchain.execute() + + workchain.execute() self.assertEqual(process_finished_count, 1) @@ -324,7 +321,6 @@ def step2(self): self.assertEqual(process_finished_count, 2) def test_return_in_outline(self): - class WcWithReturn(WorkChain): FAILED_CODE = 1 @@ -360,9 +356,7 @@ def default(self): workchain.execute() def test_return_in_step(self): - class WcWithReturn(WorkChain): - FAILED_CODE = 1 @classmethod @@ -393,9 +387,7 @@ def after(self): workchain.execute() def test_tocontext_schedule_workchain(self): - class MainWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -409,7 +401,6 @@ def check(self): assert self.ctx.subwc.out.value == 5 class SubWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -446,14 +437,13 @@ async def async_test(): self.assertTrue(workchain.ctx.s2) loop = asyncio.get_event_loop() - loop.create_task(workchain.step_until_terminated()) + loop.create_task(workchain.step_until_terminated()) # noqa: RUF006 loop.run_until_complete(async_test()) def test_to_context(self): val = 5 class SimpleWc(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -463,7 +453,6 @@ def run(self): self.out('_return', val) class Workchain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -484,7 +473,6 @@ def test_output_namespace(self): """Test running a workchain with nested outputs.""" class TestWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -501,7 +489,6 @@ def test_exception_tocontext(self): my_exception = RuntimeError('Should not be reached') class Workchain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -528,7 +515,6 @@ def test_stepper_info(self): """Check status information provided by steppers""" class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -539,7 +525,13 @@ def define(cls, spec): cls.chill, cls.chill, ), - if_(cls.do_step)(cls.chill,).elif_(cls.do_step)(cls.chill,).else_(cls.chill), + if_(cls.do_step)( + cls.chill, + ) + .elif_(cls.do_step)( + cls.chill, + ) + .else_(cls.chill), ) def check_n(self): @@ -560,7 +552,6 @@ def do_step(self): return False class StatusCollector(ProcessListener): - def __init__(self): self.stepper_strings = [] @@ -574,9 +565,15 @@ def on_process_running(self, process): wf.execute() stepper_strings = [ - '0:check_n', '1:while_(do_step)', '1:while_(do_step)(1:chill)', '1:while_(do_step)', - '1:while_(do_step)(1:chill)', '1:while_(do_step)', '1:while_(do_step)(1:chill)', '1:while_(do_step)', - '2:if_(do_step)' + '0:check_n', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '2:if_(do_step)', ] self.assertListEqual(collector.stepper_strings, stepper_strings) @@ -593,7 +590,6 @@ def test_immutable_input(self): test_class = self class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -630,7 +626,6 @@ def test_immutable_input_namespace(self): test_class = self class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/utils.py b/tests/utils.py similarity index 93% rename from test/utils.py rename to tests/utils.py index 1f7408f6..f2a58dfc 100644 --- a/test/utils.py +++ b/tests/utils.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- """Utilities for tests""" + import asyncio import collections -from collections.abc import Mapping +import copy import unittest - -import kiwipy.rmq -import shortuuid +from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils +from plumpy.process_comms import KILL_MSG, MESSAGE_KEY Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -24,7 +24,9 @@ class DummyProcess(processes.Process): """ EXPECTED_STATE_SEQUENCE = [ - process_states.ProcessState.CREATED, process_states.ProcessState.RUNNING, process_states.ProcessState.FINISHED + process_states.ProcessState.CREATED, + process_states.ProcessState.RUNNING, + process_states.ProcessState.FINISHED, ] EXPECTED_OUTPUTS = {} @@ -58,14 +60,12 @@ def run(self, **kwargs): class KeyboardInterruptProc(processes.Process): - @utils.override def run(self): raise KeyboardInterrupt() class ProcessWithCheckpoint(processes.Process): - @utils.override def run(self): return process_states.Continue(self.last_step) @@ -75,7 +75,6 @@ def last_step(self): class WaitForSignalProcess(processes.Process): - @utils.override def run(self): return process_states.Wait(self.last_step) @@ -85,14 +84,15 @@ def last_step(self): class KillProcess(processes.Process): - @utils.override def run(self): - return process_states.Kill('killed') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'killed' + return process_states.Kill(msg=msg) class MissingOutputProcess(processes.Process): - """ A process that does not generate a required output """ + """A process that does not generate a required output""" @classmethod def define(cls, spec): @@ -101,7 +101,6 @@ def define(cls, spec): class NewLoopProcess(processes.Process): - def __init__(self, *args, **kwargs): kwargs['loop'] = plumpy.new_event_loop() super().__init__(*args, **kwargs) @@ -118,8 +117,7 @@ def called(cls, event): cls.called_events.append(event) def __init__(self, *args, **kwargs): - assert isinstance(self, processes.Process), \ - 'Mixin has to be used with a type derived from a Process' + assert isinstance(self, processes.Process), 'Mixin has to be used with a type derived from a Process' super().__init__(*args, **kwargs) self.__class__.called_events = [] @@ -165,7 +163,6 @@ def on_terminate(self): class ProcessEventsTester(EventsTesterMixin, processes.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -193,7 +190,6 @@ def last_step(self): class TwoCheckpointNoFinish(ProcessEventsTester): - def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) @@ -203,21 +199,18 @@ def middle_step(self): class ExceptionProcess(ProcessEventsTester): - def run(self): self.out('test', 5) raise RuntimeError('Great scott!') class ThreeStepsThenException(ThreeSteps): - @utils.override def last_step(self): raise RuntimeError('Great scott!') class ProcessListenerTester(plumpy.ProcessListener): - def __init__(self, process, expected_events): process.add_process_listener(self) self.expected_events = set(expected_events) @@ -249,7 +242,6 @@ def on_process_killed(self, process, msg): class Saver: - def __init__(self): self.snapshots = [] self.outputs = [] @@ -357,7 +349,11 @@ def on_process_killed(self, process, msg): TEST_PROCESSES = [DummyProcess, DummyProcessWithOutput, DummyProcessWithDynamicOutput, ThreeSteps] TEST_WAITING_PROCESSES = [ - ProcessWithCheckpoint, TwoCheckpointNoFinish, ExceptionProcess, ProcessEventsTester, ThreeStepsThenException + ProcessWithCheckpoint, + TwoCheckpointNoFinish, + ExceptionProcess, + ProcessEventsTester, + ThreeStepsThenException, ] TEST_EXCEPTION_PROCESSES = [ExceptionProcess, ThreeStepsThenException, MissingOutputProcess] @@ -402,7 +398,7 @@ def check_process_against_snapshots(loop, proc_class, snapshots): saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], - exclude={'exception', '_listeners'} + exclude={'exception', '_listeners'}, ) j += 1 @@ -438,9 +434,8 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): compare_value(bundle1, bundle2, list(v1), list(v2), exclude) elif isinstance(v1, set) and isinstance(v2, set): raise NotImplementedError('Comparison between sets not implemented') - else: - if v1 != v2: - raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}') + elif v1 != v2: + raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}') class TestPersister(persistence.Persister): @@ -449,7 +444,7 @@ class TestPersister(persistence.Persister): """ def save_checkpoint(self, process, tag=None): - """ Create the checkpoint bundle """ + """Create the checkpoint bundle""" persistence.Bundle(process) def load_checkpoint(self, pid, tag=None): @@ -469,7 +464,7 @@ def delete_process_checkpoints(self, pid): def run_until_waiting(proc): - """ Set up a future that will be resolved on entering the WAITING state """ + """Set up a future that will be resolved on entering the WAITING state""" from plumpy import ProcessState listener = plumpy.ProcessListener() @@ -490,7 +485,7 @@ def on_waiting(_waiting_proc): def run_until_paused(proc): - """ Set up a future that will be resolved when the process is paused """ + """Set up a future that will be resolved when the process is paused""" listener = plumpy.ProcessListener() paused = plumpy.Future()