diff --git a/.ci/workchains.py b/.ci/workchains.py index 5504813a10..e94f44669c 100644 --- a/.ci/workchains.py +++ b/.ci/workchains.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name +"""Work chain implementations for testing purposes.""" from aiida.common import AttributeDict from aiida.engine import calcfunction, workfunction, WorkChain, ToContext, append_, while_, ExitCode from aiida.engine import BaseRestartWorkChain, process_handler, ProcessHandlerReport @@ -15,7 +16,6 @@ from aiida.orm import Int, List, Str from aiida.plugins import CalculationFactory - ArithmeticAddCalculation = CalculationFactory('arithmetic.add') @@ -54,15 +54,15 @@ def setup(self): def sanity_check_not_too_big(self, node): """My puny brain cannot deal with numbers that I cannot count on my hand.""" if node.is_finished_ok and node.outputs.sum > 10: - return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) + return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) # pylint: disable=no-member @process_handler(priority=460, enabled=False) - def disabled_handler(self, node): + def disabled_handler(self, node): # pylint: disable=unused-argument """By default this is not enabled and so should never be called, irrespective of exit codes of sub process.""" - return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) + return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) # pylint: disable=no-member @process_handler(priority=450, exit_codes=ExitCode(1000, 'Unicorn encountered')) - def a_magic_unicorn_appeared(self, node): + def a_magic_unicorn_appeared(self, node): # pylint: disable=no-self-argument,no-self-use """As we all know unicorns do not exist so we should never have to deal with it.""" raise RuntimeError('this handler should never even have been called') @@ -78,30 +78,24 @@ class NestedWorkChain(WorkChain): """ Nested workchain which creates a workflow where the nesting level is equal to its input. 
""" + @classmethod def define(cls, spec): super().define(spec) spec.input('inp', valid_type=Int) - spec.outline( - cls.do_submit, - cls.finalize - ) + spec.outline(cls.do_submit, cls.finalize) spec.output('output', valid_type=Int, required=True) def do_submit(self): if self.should_submit(): self.report('Submitting nested workchain.') - return ToContext( - workchain=append_(self.submit( - NestedWorkChain, - inp=self.inputs.inp - 1 - )) - ) + return ToContext(workchain=append_(self.submit(NestedWorkChain, inp=self.inputs.inp - 1))) def should_submit(self): return int(self.inputs.inp) > 0 def finalize(self): + """Attach the outputs.""" if self.should_submit(): self.report('Getting sub-workchain output.') sub_workchain = self.ctx.workchain[0] @@ -112,15 +106,13 @@ def finalize(self): class SerializeWorkChain(WorkChain): + """Work chain that serializes inputs.""" + @classmethod def define(cls, spec): super().define(spec) - spec.input( - 'test', - valid_type=Str, - serializer=lambda x: Str(ObjectLoader().identify_object(x)) - ) + spec.input('test', valid_type=Str, serializer=lambda x: Str(ObjectLoader().identify_object(x))) spec.outline(cls.echo) spec.outputs.dynamic = True @@ -130,6 +122,8 @@ def echo(self): class NestedInputNamespace(WorkChain): + """Work chain with nested namespace.""" + @classmethod def define(cls, spec): super().define(spec) @@ -143,6 +137,8 @@ def do_echo(self): class ListEcho(WorkChain): + """Work chain that simply echos a `List` input.""" + @classmethod def define(cls, spec): super().define(spec) @@ -157,6 +153,8 @@ def do_echo(self): class DynamicNonDbInput(WorkChain): + """Work chain with dynamic non_db inputs.""" + @classmethod def define(cls, spec): super().define(spec) @@ -172,6 +170,8 @@ def do_test(self): class DynamicDbInput(WorkChain): + """Work chain with dynamic input namespace.""" + @classmethod def define(cls, spec): super().define(spec) @@ -186,6 +186,8 @@ def do_test(self): class DynamicMixedInput(WorkChain): + """Work chain with dynamic mixed input.""" + @classmethod def define(cls, spec): super().define(spec) @@ -194,6 +196,7 @@ def define(cls, spec): spec.outline(cls.do_test) def do_test(self): + """Run the test.""" input_non_db = self.inputs.namespace.inputs['input_non_db'] input_db = self.inputs.namespace.inputs['input_db'] assert isinstance(input_non_db, int) @@ -206,6 +209,7 @@ class CalcFunctionRunnerWorkChain(WorkChain): """ WorkChain which calls an InlineCalculation in its step. 
""" + @classmethod def define(cls, spec): super().define(spec) @@ -223,6 +227,7 @@ class WorkFunctionRunnerWorkChain(WorkChain): """ WorkChain which calls a workfunction in its step """ + @classmethod def define(cls, spec): super().define(spec) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d3e83f5dd8..e58bc453ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,104 +8,16 @@ repos: - id: mixed-line-ending - id: trailing-whitespace -- repo: https://github.com/PyCQA/pylint - rev: pylint-2.5.2 - hooks: - - id: pylint - language: system - exclude: &exclude_files > - (?x)^( - .ci/workchains.py| - aiida/backends/djsite/queries.py| - aiida/backends/djsite/db/models.py| - aiida/backends/djsite/db/migrations/0001_initial.py| - aiida/backends/djsite/db/migrations/0002_db_state_change.py| - aiida/backends/djsite/db/migrations/0003_add_link_type.py| - aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py| - aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py| - aiida/backends/djsite/db/migrations/0006_delete_dbpath.py| - aiida/backends/djsite/db/migrations/0007_update_linktypes.py| - aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py| - aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py| - aiida/backends/djsite/db/migrations/0010_process_type.py| - aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py| - aiida/backends/djsite/db/migrations/0012_drop_dblock.py| - aiida/backends/djsite/db/migrations/0013_django_1_8.py| - aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py| - aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py| - aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py| - aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py| - aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py| - aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py| - aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py| - aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py| - aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py| - aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py| - aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py| - aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py| - aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py| - aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py| - aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py| - aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py| - aiida/backends/sqlalchemy/models/computer.py| - aiida/backends/sqlalchemy/models/settings.py| - aiida/backends/sqlalchemy/models/node.py| - aiida/backends/utils.py| - aiida/common/datastructures.py| - aiida/engine/daemon/execmanager.py| - aiida/engine/processes/calcjobs/tasks.py| - aiida/orm/querybuilder.py| - aiida/orm/nodes/data/array/bands.py| - aiida/orm/nodes/data/array/projection.py| - aiida/orm/nodes/data/array/xy.py| - aiida/orm/nodes/data/code.py| - aiida/orm/nodes/data/orbital.py| - aiida/orm/nodes/data/remote.py| - aiida/orm/nodes/data/structure.py| - aiida/orm/utils/remote.py| - aiida/parsers/plugins/arithmetic/add.py| - 
aiida/parsers/plugins/templatereplacer/doubler.py| - aiida/parsers/plugins/templatereplacer/__init__.py| - aiida/plugins/entry.py| - aiida/plugins/info.py| - aiida/plugins/registry.py| - aiida/tools/data/array/kpoints/legacy.py| - aiida/tools/data/array/kpoints/seekpath.py| - aiida/tools/data/__init__.py| - aiida/tools/dbexporters/__init__.py| - aiida/tools/dbimporters/baseclasses.py| - aiida/tools/dbimporters/__init__.py| - aiida/tools/dbimporters/plugins/cod.py| - aiida/tools/dbimporters/plugins/icsd.py| - aiida/tools/dbimporters/plugins/__init__.py| - aiida/tools/dbimporters/plugins/mpds.py| - aiida/tools/dbimporters/plugins/mpod.py| - aiida/tools/dbimporters/plugins/nninc.py| - aiida/tools/dbimporters/plugins/oqmd.py| - aiida/tools/dbimporters/plugins/pcod.py| - docs/.*| - examples/.*| - tests/engine/test_work_chain.py| - tests/schedulers/test_direct.py| - tests/schedulers/test_lsf.py| - tests/schedulers/test_pbspro.py| - tests/schedulers/test_sge.py| - tests/schedulers/test_torque.py| - tests/sphinxext/workchain_source/conf.py| - tests/sphinxext/workchain_source_broken/conf.py| - tests/transports/test_all_plugins.py| - tests/transports/test_local.py| - tests/transports/test_ssh.py| - tests/test_dataclasses.py| - )$ - - repo: https://github.com/pre-commit/mirrors-yapf rev: v0.30.0 hooks: - id: yapf name: yapf types: [python] - exclude: *exclude_files + exclude: &exclude_files > + (?x)^( + docs/.*| + )$ args: ['-i'] - repo: https://github.com/pre-commit/mirrors-mypy @@ -120,6 +32,13 @@ repos: )$ - repo: local + + hooks: + - id: pylint + name: pylint + language: system + exclude: *exclude_files + hooks: - id: dm-generate-all name: Update all requirements files diff --git a/.pylintrc b/.pylintrc index 5816e54bd0..eca9d52a8b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -60,7 +60,8 @@ disable=bad-continuation, import-outside-toplevel, cyclic-import, duplicate-code, - too-few-public-methods + too-few-public-methods, + inconsistent-return-statements # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/aiida/backends/djsite/db/migrations/0001_initial.py b/aiida/backends/djsite/db/migrations/0001_initial.py index 15d19b72ac..0ea8397da0 100644 --- a/aiida/backends/djsite/db/migrations/0001_initial.py +++ b/aiida/backends/djsite/db/migrations/0001_initial.py @@ -7,7 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations import django.db.models.deletion import django.utils.timezone @@ -19,6 +20,7 @@ class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('auth', '0001_initial'), @@ -53,8 +55,8 @@ class Migration(migrations.Migration): 'is_active', models.BooleanField( default=True, - help_text= - 'Designates whether this user should be treated as active. Unselect this instead of deleting accounts.' + help_text='Designates whether this user should be treated as active. Unselect this instead of ' + 'deleting accounts.' ) ), ('date_joined', models.DateTimeField(default=django.utils.timezone.now)), @@ -65,8 +67,8 @@ class Migration(migrations.Migration): related_name='user_set', to='auth.Group', blank=True, - help_text= - 'The groups this user belongs to. 
A user will get all permissions granted to each of his/her group.', + help_text='The groups this user belongs to. A user will get all permissions granted to each of ' + 'his/her group.', verbose_name='groups' ) ), diff --git a/aiida/backends/djsite/db/migrations/0002_db_state_change.py b/aiida/backends/djsite/db/migrations/0002_db_state_change.py index cfc0fdcd2c..2ac6d980c4 100644 --- a/aiida/backends/djsite/db/migrations/0002_db_state_change.py +++ b/aiida/backends/djsite/db/migrations/0002_db_state_change.py @@ -7,24 +7,24 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.2' DOWN_REVISION = '1.0.1' -def fix_calc_states(apps, schema_editor): +def fix_calc_states(apps, _): + """Fix calculation states.""" from aiida.orm.utils import load_node # These states should never exist in the database but we'll play it safe # and deal with them if they do DbCalcState = apps.get_model('db', 'DbCalcState') - for calc_state in DbCalcState.objects.filter( - state__in=['UNDETERMINED', 'NOTFOUND']): + for calc_state in DbCalcState.objects.filter(state__in=['UNDETERMINED', 'NOTFOUND']): old_state = calc_state.state calc_state.state = 'FAILED' calc_state.save() @@ -32,11 +32,13 @@ def fix_calc_states(apps, schema_editor): calc = load_node(pk=calc_state.dbnode.pk) calc.logger.warning( 'Job state {} found for calculation {} which should never be in ' - 'the database. Changed state to FAILED.'.format( - old_state, calc_state.dbnode.pk)) + 'the database. 
Changed state to FAILED.'.format(old_state, calc_state.dbnode.pk) + ) class Migration(migrations.Migration): + """Database migration.""" + dependencies = [ ('db', '0001_initial'), ] @@ -47,14 +49,15 @@ class Migration(migrations.Migration): name='state', # The UNDETERMINED and NOTFOUND 'states' were removed as these # don't make sense - field=models.CharField(db_index=True, max_length=25, - choices=[('RETRIEVALFAILED', 'RETRIEVALFAILED'), ('COMPUTED', 'COMPUTED'), - ('RETRIEVING', 'RETRIEVING'), ('WITHSCHEDULER', 'WITHSCHEDULER'), - ('SUBMISSIONFAILED', 'SUBMISSIONFAILED'), ('PARSING', 'PARSING'), - ('FAILED', 'FAILED'), ('FINISHED', 'FINISHED'), - ('TOSUBMIT', 'TOSUBMIT'), ('SUBMITTING', 'SUBMITTING'), - ('IMPORTED', 'IMPORTED'), ('NEW', 'NEW'), - ('PARSINGFAILED', 'PARSINGFAILED')]), + field=models.CharField( + db_index=True, + max_length=25, + choices=[('RETRIEVALFAILED', 'RETRIEVALFAILED'), ('COMPUTED', 'COMPUTED'), ('RETRIEVING', 'RETRIEVING'), + ('WITHSCHEDULER', 'WITHSCHEDULER'), ('SUBMISSIONFAILED', 'SUBMISSIONFAILED'), + ('PARSING', 'PARSING'), ('FAILED', 'FAILED'), + ('FINISHED', 'FINISHED'), ('TOSUBMIT', 'TOSUBMIT'), ('SUBMITTING', 'SUBMITTING'), + ('IMPORTED', 'IMPORTED'), ('NEW', 'NEW'), ('PARSINGFAILED', 'PARSINGFAILED')] + ), preserve_default=True, ), # Fix up any calculation states that had one of the removed states diff --git a/aiida/backends/djsite/db/migrations/0003_add_link_type.py b/aiida/backends/djsite/db/migrations/0003_add_link_type.py index 40117da428..24e32381b7 100644 --- a/aiida/backends/djsite/db/migrations/0003_add_link_type.py +++ b/aiida/backends/djsite/db/migrations/0003_add_link_type.py @@ -7,17 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations import aiida.common.timezone from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.3' DOWN_REVISION = '1.0.2' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0002_db_state_change'), diff --git a/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py b/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py index 5bf78b5bdf..cb53ff3d6e 100644 --- a/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py +++ b/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py @@ -7,18 +7,20 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.4' DOWN_REVISION = '1.0.3' class Migration(migrations.Migration): + """Database migration.""" + dependencies = [ ('db', '0003_add_link_type'), ] @@ -27,20 +29,20 @@ class Migration(migrations.Migration): # Create the index that speeds up the daemon queries # We use the RunSQL command because Django interface # doesn't seem to support partial indexes - migrations.RunSQL(""" + migrations.RunSQL( + """ CREATE INDEX tval_idx_for_daemon ON db_dbattribute (tval) WHERE ("db_dbattribute"."tval" - IN ('COMPUTED', 'WITHSCHEDULER', 
'TOSUBMIT'))"""), + IN ('COMPUTED', 'WITHSCHEDULER', 'TOSUBMIT'))""" + ), # Create an index on UUIDs to speed up loading of nodes # using this field migrations.AlterField( model_name='dbnode', name='uuid', - field=models.CharField(max_length=36,db_index=True, - editable=False, - blank=True), + field=models.CharField(max_length=36, db_index=True, editable=False, blank=True), preserve_default=True, ), upgrade_schema_version(REVISION, DOWN_REVISION) diff --git a/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py b/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py index 71a901f1a0..11c7e99953 100644 --- a/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py +++ b/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py @@ -7,17 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations import aiida.common.timezone from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.5' DOWN_REVISION = '1.0.4' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0004_add_daemon_and_uuid_indices'), diff --git a/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py b/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py index 905c459960..134b52d8c7 100644 --- a/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py +++ b/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.6' DOWN_REVISION = '1.0.5' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0005_add_cmtime_indices'), @@ -35,13 +36,12 @@ class Migration(migrations.Migration): model_name='dbnode', name='children', ), - migrations.DeleteModel( - name='DbPath', - ), - migrations.RunSQL(""" + migrations.DeleteModel(name='DbPath',), + migrations.RunSQL( + """ DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink; DROP FUNCTION IF EXISTS update_tc(); - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0007_update_linktypes.py b/aiida/backends/djsite/db/migrations/0007_update_linktypes.py index c694cb793b..a966516b29 100644 --- a/aiida/backends/djsite/db/migrations/0007_update_linktypes.py +++ b/aiida/backends/djsite/db/migrations/0007_update_linktypes.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.8' DOWN_REVISION = '1.0.7' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0006_delete_dbpath'), @@ -36,7 +37,8 @@ class Migration(migrations.Migration): # - joins a Data (or 
subclass) as output # - is marked as a returnlink. # 2) set for these links the type to 'createlink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='createlink' WHERE db_dblink.id IN ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -46,7 +48,8 @@ class Migration(migrations.Migration): AND db_dbnode_2.type LIKE 'data.%' AND db_dblink_1.type = 'returnlink' ); - """), + """ + ), # Now I am updating the link-types that are null because of either an export and subsequent import # https://github.com/aiidateam/aiida-core/issues/685 # or because the link types don't exist because the links were added before the introduction of link types. @@ -55,10 +58,11 @@ class Migration(migrations.Migration): # The following sql statement: # 1) selects all links that # - joins Data (or subclass) or Code as input - # - joins Calculation (or subclass) as output. This includes WorkCalculation, InlineCalcuation, JobCalculations... + # - joins Calculation (or subclass) as output: includes WorkCalculation, InlineCalcuation, JobCalculations... # - has no type (null) # 2) set for these links the type to 'inputlink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='inputlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -68,7 +72,8 @@ class Migration(migrations.Migration): AND db_dbnode_2.type LIKE 'calculation.%' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """), + """ + ), # # The following sql statement: # 1) selects all links that @@ -76,7 +81,8 @@ class Migration(migrations.Migration): # - joins Data (or subclass) as output. # - has no type (null) # 2) set for these links the type to 'createlink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='createlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -90,14 +96,16 @@ class Migration(migrations.Migration): ) AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """), + """ + ), # The following sql statement: # 1) selects all links that - # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked for. + # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked # - join Data (or subclass) as output. # - has no type (null) # 2) set for these links the type to 'returnlink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='returnlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -107,15 +115,17 @@ class Migration(migrations.Migration): AND db_dbnode_1.type = 'calculation.work.WorkCalculation.' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """), + """ + ), # Now I update links that are CALLS: # The following sql statement: # 1) selects all links that - # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked for. + # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked # - join Calculation (or subclass) as output. Includes JobCalculation and WorkCalculations and all subclasses. 
# - has no type (null) # 2) set for these links the type to 'calllink' - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dblink set type='calllink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -125,7 +135,7 @@ class Migration(migrations.Migration): AND db_dbnode_2.type LIKE 'calculation.%' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py b/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py index 52b3292951..be65bd0bc7 100644 --- a/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py +++ b/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.8' DOWN_REVISION = '1.0.7' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0007_update_linktypes'), @@ -28,16 +29,20 @@ class Migration(migrations.Migration): # we move that value to the extra table # # First we copy the 'hidden' attributes from code.Code. nodes to the db_extra table - migrations.RunSQL(""" + migrations.RunSQL( + """ INSERT INTO db_dbextra (key, datatype, tval, fval, ival, bval, dval, dbnode_id) ( - SELECT db_dbattribute.key, db_dbattribute.datatype, db_dbattribute.tval, db_dbattribute.fval, db_dbattribute.ival, db_dbattribute.bval, db_dbattribute.dval, db_dbattribute.dbnode_id + SELECT db_dbattribute.key, db_dbattribute.datatype, db_dbattribute.tval, db_dbattribute.fval, + db_dbattribute.ival, db_dbattribute.bval, db_dbattribute.dval, db_dbattribute.dbnode_id FROM db_dbattribute JOIN db_dbnode ON db_dbnode.id = db_dbattribute.dbnode_id WHERE db_dbattribute.key = 'hidden' AND db_dbnode.type = 'code.Code.' ); - """), + """ + ), # Secondly, we delete the original entries from the DbAttribute table - migrations.RunSQL(""" + migrations.RunSQL( + """ DELETE FROM db_dbattribute WHERE id in ( SELECT db_dbattribute.id @@ -45,6 +50,7 @@ class Migration(migrations.Migration): JOIN db_dbnode ON db_dbnode.id = db_dbattribute.dbnode_id WHERE db_dbattribute.key = 'hidden' AND db_dbnode.type = 'code.Code.' 
); - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py b/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py index 2ca434e3e4..1a9317d0b1 100644 --- a/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py +++ b/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.9' DOWN_REVISION = '1.0.8' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0008_code_hidden_to_extra'), @@ -26,12 +27,14 @@ class Migration(migrations.Migration): # The base Data types Bool, Float, Int and Str have been moved in the source code, which means that their # module path changes, which determines the plugin type string which is stored in the databse. # The type string now will have a type string prefix that is unique to each sub type. - migrations.RunSQL(""" + migrations.RunSQL( + """ UPDATE db_dbnode SET type = 'data.bool.Bool.' WHERE type = 'data.base.Bool.'; UPDATE db_dbnode SET type = 'data.float.Float.' WHERE type = 'data.base.Float.'; UPDATE db_dbnode SET type = 'data.int.Int.' WHERE type = 'data.base.Int.'; UPDATE db_dbnode SET type = 'data.str.Str.' WHERE type = 'data.base.Str.'; UPDATE db_dbnode SET type = 'data.list.List.' WHERE type = 'data.base.List.'; - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0010_process_type.py b/aiida/backends/djsite/db/migrations/0010_process_type.py index 11fb34bd32..d1c36dc526 100644 --- a/aiida/backends/djsite/db/migrations/0010_process_type.py +++ b/aiida/backends/djsite/db/migrations/0010_process_type.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.10' DOWN_REVISION = '1.0.9' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0009_base_data_plugin_type_string'), @@ -24,9 +25,7 @@ class Migration(migrations.Migration): operations = [ migrations.AddField( - model_name='dbnode', - name='process_type', - field=models.CharField(max_length=255, db_index=True, null=True) + model_name='dbnode', name='process_type', field=models.CharField(max_length=255, db_index=True, null=True) ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py b/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py index 35cdeb715c..d3fcb91e1b 100644 --- a/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py +++ b/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py @@ -7,23 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # 
########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.11' DOWN_REVISION = '1.0.10' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0010_process_type'), ] operations = [ - migrations.RunSQL(""" + migrations.RunSQL( + """ DROP TABLE IF EXISTS kombu_message; DROP TABLE IF EXISTS kombu_queue; DELETE FROM db_dbsetting WHERE key = 'daemon|user'; @@ -33,6 +35,7 @@ class Migration(migrations.Migration): DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|updater'; DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|submitter'; DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|submitter'; - """), + """ + ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0012_drop_dblock.py b/aiida/backends/djsite/db/migrations/0012_drop_dblock.py index b89787f3f1..0c37ec8fd7 100644 --- a/aiida/backends/djsite/db/migrations/0012_drop_dblock.py +++ b/aiida/backends/djsite/db/migrations/0012_drop_dblock.py @@ -7,24 +7,20 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.12' DOWN_REVISION = '1.0.11' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0011_delete_kombu_tables'), ] - operations = [ - migrations.DeleteModel( - name='DbLock', - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] + operations = [migrations.DeleteModel(name='DbLock',), upgrade_schema_version(REVISION, DOWN_REVISION)] diff --git a/aiida/backends/djsite/db/migrations/0013_django_1_8.py b/aiida/backends/djsite/db/migrations/0013_django_1_8.py index 6448e3b924..17d5b3a196 100644 --- a/aiida/backends/djsite/db/migrations/0013_django_1_8.py +++ b/aiida/backends/djsite/db/migrations/0013_django_1_8.py @@ -7,16 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import models, migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version - REVISION = '1.0.13' DOWN_REVISION = '1.0.12' class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0012_drop_dblock'), diff --git a/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py b/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py index 48794a87c5..8d125f2196 100644 --- a/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py +++ b/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py @@ -18,7 +18,7 @@ DOWN_REVISION = '1.0.13' -def verify_node_uuid_uniqueness(apps, schema_editor): +def verify_node_uuid_uniqueness(_, __): """Check whether the database contains nodes with duplicate UUIDS. 
Note that we have to redefine this method from aiida.manage.database.integrity.verify_node_uuid_uniqueness @@ -31,7 +31,7 @@ def verify_node_uuid_uniqueness(apps, schema_editor): verify_uuid_uniqueness(table='db_dbnode') -def reverse_code(apps, schema_editor): +def reverse_code(_, __): pass diff --git a/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py b/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py index 62e8e446b9..f3ac3ca9c3 100644 --- a/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py +++ b/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py @@ -8,9 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,too-few-public-methods -""" -Invalidating node hash - User should rehash nodes for caching -""" +"""Invalidating node hash - User should rehash nodes for caching.""" # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed # pylint: disable=no-name-in-module,import-error diff --git a/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py b/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py index 611a324dc1..d1fe5fe1b2 100644 --- a/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py +++ b/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py @@ -7,7 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version @@ -16,6 +17,7 @@ class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0015_invalidating_node_hash'), @@ -26,6 +28,7 @@ class Migration(migrations.Migration): # To make everything fully consistent, its type string should therefore also start with `data.` migrations.RunSQL( sql="""UPDATE db_dbnode SET type = 'data.code.Code.' WHERE type = 'code.Code.';""", - reverse_sql="""UPDATE db_dbnode SET type = 'code.Code.' WHERE type = 'data.code.Code.';"""), + reverse_sql="""UPDATE db_dbnode SET type = 'code.Code.' 
WHERE type = 'data.code.Code.';""" + ), upgrade_schema_version(REVISION, DOWN_REVISION) ] diff --git a/aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py b/aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py index 4f7bcd904e..eda7694481 100644 --- a/aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py +++ b/aiida/backends/djsite/db/migrations/0017_drop_dbcalcstate.py @@ -7,7 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name +"""Database migration.""" from django.db import migrations from aiida.backends.djsite.db.migrations import upgrade_schema_version @@ -16,6 +17,7 @@ class Migration(migrations.Migration): + """Database migration.""" dependencies = [ ('db', '0016_code_sub_class_of_data'), diff --git a/aiida/backends/djsite/db/models.py b/aiida/backends/djsite/db/models.py index c5c81924ad..01701fd2f4 100644 --- a/aiida/backends/djsite/db/models.py +++ b/aiida/backends/djsite/db/models.py @@ -193,9 +193,7 @@ def __str__(self): return "'{}'={}".format(self.key, self.getvalue()) @classmethod - def set_value( - cls, key, value, with_transaction=True, subspecifier_value=None, other_attribs=None, stop_if_existing=False - ): + def set_value(cls, key, value, other_attribs=None, stop_if_existing=False): """Delete a setting value.""" other_attribs = other_attribs if other_attribs is not None else {} setting = DbSetting.objects.filter(key=key).first() @@ -221,7 +219,7 @@ def get_description(self): return self.description @classmethod - def del_value(cls, key, only_children=False, subspecifier_value=None): + def del_value(cls, key): """Set a setting value.""" setting = DbSetting.objects.filter(key=key).first() @@ -300,7 +298,6 @@ class DbComputer(m.Model): name = m.CharField(max_length=255, unique=True, blank=False) hostname = m.CharField(max_length=255) description = m.TextField(blank=True) - # TODO: next three fields should not be blank... 
scheduler_type = m.CharField(max_length=255) transport_type = m.CharField(max_length=255) metadata = JSONField(default=dict) diff --git a/aiida/backends/djsite/queries.py b/aiida/backends/djsite/queries.py index ff8121c7ab..53ed500305 100644 --- a/aiida/backends/djsite/queries.py +++ b/aiida/backends/djsite/queries.py @@ -110,6 +110,7 @@ def query_group(q_object, args): def get_bands_and_parents_structure(self, args): """Returns bands and closest parent structure.""" + # pylint: disable=too-many-locals from django.db.models import Q from aiida.backends.djsite.db import models from aiida.common.utils import grouper diff --git a/aiida/backends/sqlalchemy/migrations/env.py b/aiida/backends/sqlalchemy/migrations/env.py index c18b73c2f6..e2d5f246c1 100644 --- a/aiida/backends/sqlalchemy/migrations/env.py +++ b/aiida/backends/sqlalchemy/migrations/env.py @@ -55,7 +55,7 @@ def run_migrations_online(): if connectable is None: from aiida.common.exceptions import ConfigurationError - raise ConfigurationError('An initialized connection is expected ' 'for the AiiDA online migrations.') + raise ConfigurationError('An initialized connection is expected for the AiiDA online migrations.') with connectable.connect() as connection: context.configure( diff --git a/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py b/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py index d38f1fe672..66d8f7e0a8 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py +++ b/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Drop the `transport_params` from the `Computer` database model. Revision ID: 07fac78e6209 diff --git a/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py b/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py index 0f3392ce6f..d73fd01407 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py +++ b/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Correct the type string for the base data types Revision ID: 0aebbeab274d @@ -17,7 +18,6 @@ from alembic import op from sqlalchemy.sql import text - # revision identifiers, used by Alembic. revision = '0aebbeab274d' down_revision = '7a6587e16f4c' @@ -26,29 +26,35 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # The base Data types Bool, Float, Int and Str have been moved in the source code, which means that their # module path changes, which determines the plugin type string which is stored in the databse. # The type string now will have a type string prefix that is unique to each sub type. - statement = text(""" + statement = text( + """ UPDATE db_dbnode SET type = 'data.bool.Bool.' WHERE type = 'data.base.Bool.'; UPDATE db_dbnode SET type = 'data.float.Float.' 
WHERE type = 'data.base.Float.'; UPDATE db_dbnode SET type = 'data.int.Int.' WHERE type = 'data.base.Int.'; UPDATE db_dbnode SET type = 'data.str.Str.' WHERE type = 'data.base.Str.'; UPDATE db_dbnode SET type = 'data.list.List.' WHERE type = 'data.base.List.'; - """) + """ + ) conn.execute(statement) def downgrade(): + """Migrations for the downgrade.""" conn = op.get_bind() - statement = text(""" + statement = text( + """ UPDATE db_dbnode SET type = 'data.base.Bool.' WHERE type = 'data.bool.Bool.'; UPDATE db_dbnode SET type = 'data.base.Float.' WHERE type = 'data.float.Float.'; UPDATE db_dbnode SET type = 'data.base.Int.' WHERE type = 'data.int.Int.'; UPDATE db_dbnode SET type = 'data.base.Str.' WHERE type = 'data.str.Str.'; UPDATE db_dbnode SET type = 'data.base.List.' WHERE type = 'data.list.List.'; - """) + """ + ) conn.execute(statement) diff --git a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py b/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py index 835cdff566..70c331faa1 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py +++ b/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Move trajectory symbols from repository array to attribute Revision ID: 12536798d4d3 diff --git a/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py b/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py index 191d0e6935..888bf556be 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py +++ b/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Drop the DbCalcState table Revision ID: 162b99bca4a2 @@ -35,10 +36,10 @@ def downgrade(): sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint( - ['dbnode_id'], ['db_dbnode.id'], - name='db_dbcalcstate_dbnode_id_fkey', - ondelete='CASCADE', - initially='DEFERRED', - deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), - sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key')) + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], + name='db_dbcalcstate_dbnode_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), + sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key') + ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py b/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py index 7f42c2d91a..0e9587e5b3 100644 --- 
a/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py +++ b/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Drop the columns `nodeversion` and `public` from the `DbNode` model. Revision ID: 1830c8430131 diff --git a/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py b/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py index 16ca636185..af91d0e34c 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py +++ b/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Data migration for legacy `JobCalculations`. These old nodes have already been migrated to the correct `CalcJobNode` type in a previous migration, but they can diff --git a/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py b/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py index 7cacdbb518..8d417a4ffc 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py +++ b/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Migrating 'hidden' properties from DbAttribute to DbExtra for code.Code. nodes Revision ID: 35d4ee9a1b0e @@ -17,7 +18,6 @@ from alembic import op from sqlalchemy.sql import text - # revision identifiers, used by Alembic. revision = '35d4ee9a1b0e' down_revision = '89176227b25' @@ -26,14 +26,25 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # Set hidden=True in extras if the attributes contain hidden=True - statement = text("""UPDATE db_dbnode SET extras = jsonb_set(extras, '{"hidden"}', to_jsonb(True)) WHERE type = 'code.Code.' AND attributes @> '{"hidden": true}'""") + statement = text( + """ + UPDATE db_dbnode SET extras = jsonb_set(extras, '{"hidden"}', to_jsonb(True)) + WHERE type = 'code.Code.' AND attributes @> '{"hidden": true}' + """ + ) conn.execute(statement) # Set hidden=False in extras if the attributes contain hidden=False - statement = text("""UPDATE db_dbnode SET extras = jsonb_set(extras, '{"hidden"}', to_jsonb(False)) WHERE type = 'code.Code.' AND attributes @> '{"hidden": false}'""") + statement = text( + """ + UPDATE db_dbnode SET extras = jsonb_set(extras, '{"hidden"}', to_jsonb(False)) + WHERE type = 'code.Code.' 
AND attributes @> '{"hidden": false}' + """ + ) conn.execute(statement) # Delete the hidden key from the attributes @@ -42,14 +53,25 @@ def upgrade(): def downgrade(): + """Migrations for the downgrade.""" conn = op.get_bind() # Set hidden=True in attributes if the extras contain hidden=True - statement = text("""UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(True)) WHERE type = 'code.Code.' AND extras @> '{"hidden": true}'""") + statement = text( + """ + UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(True)) + WHERE type = 'code.Code.' AND extras @> '{"hidden": true}' + """ + ) conn.execute(statement) # Set hidden=False in attributes if the extras contain hidden=False - statement = text("""UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(False)) WHERE type = 'code.Code.' AND extras @> '{"hidden": false}'""") + statement = text( + """ + UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(False)) + WHERE type = 'code.Code.' AND extras @> '{"hidden": false}' + """ + ) conn.execute(statement) # Delete the hidden key from the extras diff --git a/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py b/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py index 76d6e1569d..bc43767eb1 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py +++ b/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Make all uuid columns unique Revision ID: 37f3d4882837 diff --git a/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py b/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py index f42d245587..c710703708 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py +++ b/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Adding indexes and constraints to the dbnode-dbgroup relationship table Revision ID: 59edaf8a8b79 @@ -25,30 +26,28 @@ def upgrade(): + """Migrations for the upgrade.""" # Check if constraint uix_dbnode_id_dbgroup_id of migration 7a6587e16f4c # is there and if yes, drop it insp = Inspector.from_engine(op.get_bind()) for constr in insp.get_unique_constraints('db_dbgroup_dbnodes'): if constr['name'] == 'uix_dbnode_id_dbgroup_id': - op.drop_constraint('uix_dbnode_id_dbgroup_id', - 'db_dbgroup_dbnodes') + op.drop_constraint('uix_dbnode_id_dbgroup_id', 'db_dbgroup_dbnodes') - op.create_index('db_dbgroup_dbnodes_dbnode_id_idx', 'db_dbgroup_dbnodes', - ['dbnode_id']) - op.create_index('db_dbgroup_dbnodes_dbgroup_id_idx', 'db_dbgroup_dbnodes', - ['dbgroup_id']) + op.create_index('db_dbgroup_dbnodes_dbnode_id_idx', 'db_dbgroup_dbnodes', ['dbnode_id']) + 
op.create_index('db_dbgroup_dbnodes_dbgroup_id_idx', 'db_dbgroup_dbnodes', ['dbgroup_id']) op.create_unique_constraint( - 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', - ['dbgroup_id', 'dbnode_id']) + 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', ['dbgroup_id', 'dbnode_id'] + ) def downgrade(): + """Migrations for the downgrade.""" op.drop_index('db_dbgroup_dbnodes_dbnode_id_idx', 'db_dbgroup_dbnodes') op.drop_index('db_dbgroup_dbnodes_dbgroup_id_idx', 'db_dbgroup_dbnodes') - op.drop_constraint('db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', - 'db_dbgroup_dbnodes') + op.drop_constraint('db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes') # Creating the constraint uix_dbnode_id_dbgroup_id that migration # 7a6587e16f4c would add op.create_unique_constraint( - 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', - ['dbgroup_id', 'dbnode_id']) + 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', ['dbgroup_id', 'dbnode_id'] + ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py b/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py index 17661d95e2..354f52b6d3 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py +++ b/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Adding indices on the `input_id`, `output_id` and `type` column of the `DbLink` table Revision ID: 5a49629f0d45 diff --git a/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py b/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py index e988b836f0..4420d84cd6 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py +++ b/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Final data migration for `Nodes` after `aiida.orm.nodes` reorganization was finalized to remove the `node.` prefix Revision ID: 61fc0913fae9 diff --git a/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py b/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py index fb67adf396..81336fb16f 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py +++ b/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Add a unique constraint on the UUID column of the Node model Revision ID: 62fe0d36de90 @@ -31,26 +32,30 @@ def verify_node_uuid_uniqueness(): :raises: IntegrityError if database contains nodes with duplicate UUIDS. 
""" - from alembic import op from sqlalchemy.sql import text from aiida.common.exceptions import IntegrityError query = text( - 'SELECT s.id, s.uuid FROM (SELECT *, COUNT(*) OVER(PARTITION BY uuid) AS c FROM db_dbnode) AS s WHERE c > 1') + 'SELECT s.id, s.uuid FROM (SELECT *, COUNT(*) OVER(PARTITION BY uuid) AS c FROM db_dbnode) AS s WHERE c > 1' + ) conn = op.get_bind() duplicates = conn.execute(query).fetchall() if duplicates: table = 'db_dbnode' - command = '`verdi database integrity detect-duplicate-uuid {table}`'.format(table) - raise IntegrityError('Your table "{}" contains entries with duplicate UUIDS.\nRun {} ' - 'to return to a consistent state'.format(table, command)) + command = '`verdi database integrity detect-duplicate-uuid {table}`'.format(table=table) + raise IntegrityError( + 'Your table "{}" contains entries with duplicate UUIDS.\nRun {} ' + 'to return to a consistent state'.format(table, command) + ) def upgrade(): + """Migrations for the upgrade.""" verify_node_uuid_uniqueness() op.create_unique_constraint('db_dbnode_uuid_key', 'db_dbnode', ['uuid']) def downgrade(): + """Migrations for the downgrade.""" op.drop_constraint('db_dbnode_uuid_key', 'db_dbnode') diff --git a/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py b/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py index f6e32c800e..86160b0e46 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py +++ b/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Data migration for `Data` nodes after it was moved in the `aiida.orm.node` module changing the type string. Revision ID: 6a5c2ea1439d diff --git a/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py b/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py index 79fd26ca75..b29a4b7514 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py +++ b/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Add the process_type column to DbNode Revision ID: 6c629c886f84 @@ -17,7 +18,6 @@ from alembic import op import sqlalchemy as sa - # revision identifiers, used by Alembic. 
revision = '6c629c886f84' down_revision = '0aebbeab274d' @@ -26,11 +26,14 @@ def upgrade(): - op.add_column('db_dbnode', + """Migrations for the upgrade.""" + op.add_column( + 'db_dbnode', sa.Column('process_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), ) op.create_index('ix_db_dbnode_process_type', 'db_dbnode', ['process_type']) def downgrade(): + """Migrations for the downgrade.""" op.drop_column('db_dbnode', 'process_type') diff --git a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py index 98df687a2d..bd0ad4409f 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py +++ b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Deleting dbpath table and triggers Revision ID: 70c7d732f1b2 @@ -16,7 +17,6 @@ """ from alembic import op import sqlalchemy as sa -from sqlalchemy.dialects import postgresql from sqlalchemy.orm.session import Session from aiida.backends.sqlalchemy.utils import install_tc @@ -28,6 +28,7 @@ def upgrade(): + """Migrations for the upgrade.""" op.drop_table('db_dbpath') conn = op.get_bind() conn.execute('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink') @@ -35,17 +36,23 @@ def upgrade(): def downgrade(): - op.create_table('db_dbpath', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], name='db_dbpath_child_id_fkey', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], name='db_dbpath_parent_id_fkey', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') + """Migrations for the downgrade.""" + op.create_table( + 'db_dbpath', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], + name='db_dbpath_child_id_fkey', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], + name='db_dbpath_parent_id_fkey', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') ) # I get the session using the alembic connection # (Keep in mind that alembic uses the AiiDA SQLA diff --git 
a/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py b/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py index cd7ff83347..f3f8087837 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py +++ b/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Add indexes to dbworkflowdata table Revision ID: 89176227b25 @@ -24,10 +25,8 @@ def upgrade(): - op.create_index('ix_db_dbworkflowdata_aiida_obj_id', 'db_dbworkflowdata', - ['aiida_obj_id']) - op.create_index('ix_db_dbworkflowdata_parent_id', 'db_dbworkflowdata', - ['parent_id']) + op.create_index('ix_db_dbworkflowdata_aiida_obj_id', 'db_dbworkflowdata', ['aiida_obj_id']) + op.create_index('ix_db_dbworkflowdata_parent_id', 'db_dbworkflowdata', ['parent_id']) def downgrade(): diff --git a/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py b/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py index 1093abf660..24cd6c8be9 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py +++ b/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Drop the DbLock model Revision ID: a514d673c163 @@ -18,7 +19,6 @@ from sqlalchemy.dialects import postgresql import sqlalchemy as sa - # revision identifiers, used by Alembic. 
revision = 'a514d673c163' down_revision = 'f9a69de76a9a' @@ -31,10 +31,10 @@ def upgrade(): def downgrade(): - op.create_table('db_dblock', - sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') + op.create_table( + 'db_dblock', sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py b/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py index 7342df1e61..71203f6f28 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py +++ b/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Correct the type string for the code class Revision ID: a603da2cc809 @@ -25,6 +26,7 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # The Code class used to be just a sub class of Node but was changed to act like a Data node. @@ -36,6 +38,7 @@ def upgrade(): def downgrade(): + """Migrations for the downgrade.""" conn = op.get_bind() statement = text(""" diff --git a/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py b/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py index 7636241e41..440d41cf20 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py +++ b/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Updating link types - This is a copy of the Django migration script Revision ID: a6048f0ffca8 @@ -25,6 +26,7 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # I am first migrating the wrongly declared returnlinks out of @@ -40,7 +42,8 @@ def upgrade(): # - joins a Data (or subclass) as output # - is marked as a returnlink. 
# 2) set for these links the type to 'createlink' - stmt1 = text(""" + stmt1 = text( + """ UPDATE db_dblink set type='createlink' WHERE db_dblink.id IN ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -50,7 +53,8 @@ def upgrade(): AND db_dbnode_2.type LIKE 'data.%' AND db_dblink_1.type = 'returnlink' ) - """) + """ + ) conn.execute(stmt1) # Now I am updating the link-types that are null because of either an export and subsequent import # https://github.com/aiidateam/aiida-core/issues/685 @@ -63,7 +67,8 @@ def upgrade(): # - joins Calculation (or subclass) as output. This includes WorkCalculation, InlineCalcuation, JobCalculations... # - has no type (null) # 2) set for these links the type to 'inputlink' - stmt2 = text(""" + stmt2 = text( + """ UPDATE db_dblink set type='inputlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -73,7 +78,8 @@ def upgrade(): AND db_dbnode_2.type LIKE 'calculation.%' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ); - """) + """ + ) conn.execute(stmt2) # # The following sql statement: @@ -82,7 +88,8 @@ def upgrade(): # - joins Data (or subclass) as output. # - has no type (null) # 2) set for these links the type to 'createlink' - stmt3 = text(""" + stmt3 = text( + """ UPDATE db_dblink set type='createlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -96,7 +103,8 @@ def upgrade(): ) AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ) - """) + """ + ) conn.execute(stmt3) # The following sql statement: # 1) selects all links that @@ -104,7 +112,8 @@ def upgrade(): # - join Data (or subclass) as output. # - has no type (null) # 2) set for these links the type to 'returnlink' - stmt4 = text(""" + stmt4 = text( + """ UPDATE db_dblink set type='returnlink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -114,7 +123,8 @@ def upgrade(): AND db_dbnode_1.type = 'calculation.work.WorkCalculation.' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ) - """) + """ + ) conn.execute(stmt4) # Now I update links that are CALLS: # The following sql statement: @@ -123,7 +133,8 @@ def upgrade(): # - join Calculation (or subclass) as output. Includes JobCalculation and WorkCalculations and all subclasses. # - has no type (null) # 2) set for these links the type to 'calllink' - stmt5 = text(""" + stmt5 = text( + """ UPDATE db_dblink set type='calllink' where id in ( SELECT db_dblink_1.id FROM db_dbnode AS db_dbnode_1 @@ -133,9 +144,11 @@ def upgrade(): AND db_dbnode_2.type LIKE 'calculation.%' AND ( db_dblink_1.type = null OR db_dblink_1.type = '') ) - """) + """ + ) conn.execute(stmt5) def downgrade(): + """Migrations for the downgrade.""" print('There is no downgrade for the link types') diff --git a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py index 626b561c12..6d71cd55f6 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py +++ b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Migration after the `Group` class became pluginnable and so the group `type_string` changed. 
Revision ID: bf591f31dd12 diff --git a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py index cf5932e79e..682e4371ad 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py +++ b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Delete trajectory symbols array from the repository and the reference in the attributes Revision ID: ce56d84bcc35 @@ -14,8 +15,6 @@ Create Date: 2019-01-21 15:35:07.280805 """ -# pylint: disable=invalid-name - # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed # pylint: disable=no-member,no-name-in-module,import-error diff --git a/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py b/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py index 071e548ce3..87a1aa8fc0 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py +++ b/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Data migration for after `ParameterData` was renamed to `Dict`. Revision ID: d254fdfed416 @@ -14,7 +15,6 @@ Create Date: 2019-02-25 19:29:11.753089 """ -# pylint: disable=invalid-name,no-member,import-error,no-name-in-module from alembic import op from sqlalchemy.sql import text diff --git a/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py b/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py index 143054fe7b..a154d0f019 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py +++ b/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member,import-error,no-name-in-module """Drop various columns from the `DbUser` model. 
These columns were part of the default Django user model @@ -16,7 +17,6 @@ Create Date: 2019-05-28 11:15:33.242602 """ -# pylint: disable=invalid-name,no-member,import-error,no-name-in-module from alembic import op import sqlalchemy as sa diff --git a/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py b/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py index c48ec01f46..ab4b00f560 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py +++ b/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Initial schema Revision ID: e15ef2630a1b @@ -28,233 +29,290 @@ def upgrade(): - op.create_table('db_dbuser', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('email', sa.VARCHAR(length=254), autoincrement=False, nullable=True), - sa.Column('password', sa.VARCHAR(length=128), autoincrement=False, nullable=True), - sa.Column('is_superuser', sa.BOOLEAN(), autoincrement=False, nullable=False), - sa.Column('first_name', sa.VARCHAR(length=254), autoincrement=False, nullable=True), - sa.Column('last_name', sa.VARCHAR(length=254), autoincrement=False, nullable=True), - sa.Column('institution', sa.VARCHAR(length=254), autoincrement=False, nullable=True), - sa.Column('is_staff', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column('is_active', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column('last_login', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('date_joined', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='db_dbuser_pkey'), - postgresql_ignore_search_path=False + """Migrations for the upgrade.""" + op.create_table( + 'db_dbuser', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('email', sa.VARCHAR(length=254), autoincrement=False, nullable=True), + sa.Column('password', sa.VARCHAR(length=128), autoincrement=False, nullable=True), + sa.Column('is_superuser', sa.BOOLEAN(), autoincrement=False, nullable=False), + sa.Column('first_name', sa.VARCHAR(length=254), autoincrement=False, nullable=True), + sa.Column('last_name', sa.VARCHAR(length=254), autoincrement=False, nullable=True), + sa.Column('institution', sa.VARCHAR(length=254), autoincrement=False, nullable=True), + sa.Column('is_staff', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('is_active', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('last_login', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('date_joined', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='db_dbuser_pkey'), + postgresql_ignore_search_path=False ) op.create_index('ix_db_dbuser_email', 'db_dbuser', ['email'], unique=True) - op.create_table('db_dbworkflow', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, 
nullable=True), - sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('lastsyncedversion', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('report', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('module', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('module_class', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('script_path', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('script_md5', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflow_user_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflow_pkey'), - postgresql_ignore_search_path=False + op.create_table( + 'db_dbworkflow', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('lastsyncedversion', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('report', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('module', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('module_class', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('script_path', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('script_md5', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflow_user_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflow_pkey'), + postgresql_ignore_search_path=False ) op.create_index('ix_db_dbworkflow_label', 'db_dbworkflow', ['label']) - op.create_table('db_dbworkflowstep', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('nextcall', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowstep_parent_id_fkey'), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflowstep_user_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_pkey'), - sa.UniqueConstraint('parent_id', 'name', name='db_dbworkflowstep_parent_id_name_key'), - postgresql_ignore_search_path=False + op.create_table( + 'db_dbworkflowstep', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('parent_id', 
sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('nextcall', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowstep_parent_id_fkey'), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflowstep_user_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_pkey'), + sa.UniqueConstraint('parent_id', 'name', name='db_dbworkflowstep_parent_id_name_key'), + postgresql_ignore_search_path=False ) - op.create_table('db_dbcomputer', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('hostname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('enabled', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column('transport_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('scheduler_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('transport_params', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='db_dbcomputer_pkey'), - sa.UniqueConstraint('name', name='db_dbcomputer_name_key') + op.create_table( + 'db_dbcomputer', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('hostname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('enabled', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('transport_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('scheduler_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('transport_params', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='db_dbcomputer_pkey'), + sa.UniqueConstraint('name', name='db_dbcomputer_name_key') ) - op.create_table('db_dbauthinfo', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('aiidauser_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbcomputer_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('auth_params', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('enabled', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['aiidauser_id'], ['db_dbuser.id'], name='db_dbauthinfo_aiidauser_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['dbcomputer_id'], ['db_dbcomputer.id'], name='db_dbauthinfo_dbcomputer_id_fkey', ondelete='CASCADE', 
initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbauthinfo_pkey'), - sa.UniqueConstraint('aiidauser_id', 'dbcomputer_id', name='db_dbauthinfo_aiidauser_id_dbcomputer_id_key') + op.create_table( + 'db_dbauthinfo', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('aiidauser_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('dbcomputer_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('auth_params', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('enabled', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['aiidauser_id'], ['db_dbuser.id'], + name='db_dbauthinfo_aiidauser_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['dbcomputer_id'], ['db_dbcomputer.id'], + name='db_dbauthinfo_dbcomputer_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbauthinfo_pkey'), + sa.UniqueConstraint('aiidauser_id', 'dbcomputer_id', name='db_dbauthinfo_aiidauser_id_dbcomputer_id_key') ) - op.create_table('db_dbgroup', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbgroup_user_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbgroup_pkey'), - sa.UniqueConstraint('name', 'type', name='db_dbgroup_name_type_key') + op.create_table( + 'db_dbgroup', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], + name='db_dbgroup_user_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbgroup_pkey'), + sa.UniqueConstraint('name', 'type', name='db_dbgroup_name_type_key') ) op.create_index('ix_db_dbgroup_name', 'db_dbgroup', ['name']) op.create_index('ix_db_dbgroup_type', 'db_dbgroup', ['type']) - op.create_table('db_dbnode', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('mtime', 
postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('public', sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column('attributes', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('extras', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('dbcomputer_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint(['dbcomputer_id'], ['db_dbcomputer.id'], name='db_dbnode_dbcomputer_id_fkey', ondelete='RESTRICT', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbnode_user_id_fkey', ondelete='RESTRICT', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbnode_pkey'),postgresql_ignore_search_path=False + op.create_table( + 'db_dbnode', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('public', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('attributes', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('extras', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('dbcomputer_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(['dbcomputer_id'], ['db_dbcomputer.id'], + name='db_dbnode_dbcomputer_id_fkey', + ondelete='RESTRICT', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], + name='db_dbnode_user_id_fkey', + ondelete='RESTRICT', + initially='DEFERRED', + deferrable=True), + sa.PrimaryKeyConstraint('id', name='db_dbnode_pkey'), + postgresql_ignore_search_path=False ) op.create_index('ix_db_dbnode_label', 'db_dbnode', ['label']) op.create_index('ix_db_dbnode_type', 'db_dbnode', ['type']) - op.create_table('db_dbgroup_dbnodes', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbgroup_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbgroup_id'], ['db_dbgroup.id'], name='db_dbgroup_dbnodes_dbgroup_id_fkey', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbgroup_dbnodes_dbnode_id_fkey', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbgroup_dbnodes_pkey') + op.create_table( + 'db_dbgroup_dbnodes', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('dbgroup_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbgroup_id'], ['db_dbgroup.id'], + name='db_dbgroup_dbnodes_dbgroup_id_fkey', + initially='DEFERRED', + deferrable=True), + 
sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], + name='db_dbgroup_dbnodes_dbnode_id_fkey', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbgroup_dbnodes_pkey') ) - op.create_table('db_dblock', - sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') + op.create_table( + 'db_dblock', sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') ) - op.create_table('db_dbworkflowdata', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('data_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('value_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('json_value', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('aiida_obj_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['aiida_obj_id'], ['db_dbnode.id'], name='db_dbworkflowdata_aiida_obj_id_fkey'), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowdata_parent_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowdata_pkey'), - sa.UniqueConstraint('parent_id', 'name', 'data_type', name='db_dbworkflowdata_parent_id_name_data_type_key') + op.create_table( + 'db_dbworkflowdata', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('data_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('value_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('json_value', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('aiida_obj_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['aiida_obj_id'], ['db_dbnode.id'], name='db_dbworkflowdata_aiida_obj_id_fkey'), + sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowdata_parent_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowdata_pkey'), + sa.UniqueConstraint('parent_id', 'name', 'data_type', name='db_dbworkflowdata_parent_id_name_data_type_key') ) - op.create_table('db_dblink', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('input_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('output_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, 
nullable=True), - sa.ForeignKeyConstraint(['input_id'], ['db_dbnode.id'], name='db_dblink_input_id_fkey', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['output_id'], ['db_dbnode.id'], name='db_dblink_output_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dblink_pkey'), + op.create_table( + 'db_dblink', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('input_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('output_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['input_id'], ['db_dbnode.id'], + name='db_dblink_input_id_fkey', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['output_id'], ['db_dbnode.id'], + name='db_dblink_output_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), + sa.PrimaryKeyConstraint('id', name='db_dblink_pkey'), ) op.create_index('ix_db_dblink_label', 'db_dblink', ['label']) - op.create_table('db_dbworkflowstep_calculations', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbworkflowstep_calculations_dbnode_id_fkey'), - sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], name='db_dbworkflowstep_calculations_dbworkflowstep_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_calculations_pkey'), - sa.UniqueConstraint('dbworkflowstep_id', 'dbnode_id', name='db_dbworkflowstep_calculations_id_dbnode_id_key') + op.create_table( + 'db_dbworkflowstep_calculations', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbworkflowstep_calculations_dbnode_id_fkey'), + sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], + name='db_dbworkflowstep_calculations_dbworkflowstep_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_calculations_pkey'), + sa.UniqueConstraint('dbworkflowstep_id', 'dbnode_id', name='db_dbworkflowstep_calculations_id_dbnode_id_key') ) - op.create_table('db_dbpath', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], name='db_dbpath_child_id_fkey', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], name='db_dbpath_parent_id_fkey', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') + op.create_table( + 'db_dbpath', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('parent_id', 
sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], + name='db_dbpath_child_id_fkey', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], + name='db_dbpath_parent_id_fkey', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') ) - op.create_table('db_dbcalcstate', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbcalcstate_dbnode_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), - sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key') + op.create_table( + 'db_dbcalcstate', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], + name='db_dbcalcstate_dbnode_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), + sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key') ) op.create_index('ix_db_dbcalcstate_state', 'db_dbcalcstate', ['state']) - op.create_table('db_dbsetting', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('val', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.Column('description', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='db_dbsetting_pkey'), - sa.UniqueConstraint('key', name='db_dbsetting_key_key') + op.create_table( + 'db_dbsetting', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('val', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.Column('description', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='db_dbsetting_pkey'), + sa.UniqueConstraint('key', name='db_dbsetting_key_key') ) op.create_index('ix_db_dbsetting_key', 'db_dbsetting', ['key']) - op.create_table('db_dbcomment', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('ctime', 
postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('content', sa.TEXT(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbcomment_dbnode_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbcomment_user_id_fkey', ondelete='CASCADE', initially='DEFERRED', deferrable=True), - sa.PrimaryKeyConstraint('id', name='db_dbcomment_pkey') + op.create_table( + 'db_dbcomment', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), + sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('content', sa.TEXT(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], + name='db_dbcomment_dbnode_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), + sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], + name='db_dbcomment_user_id_fkey', + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcomment_pkey') ) - op.create_table('db_dblog', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('loggername', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('levelname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('objname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('objpk', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('message', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='db_dblog_pkey') + op.create_table( + 'db_dblog', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('loggername', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('levelname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('objname', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('objpk', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('message', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('metadata', postgresql.JSONB(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='db_dblog_pkey') ) op.create_index('ix_db_dblog_levelname', 'db_dblog', ['levelname']) op.create_index('ix_db_dblog_loggername', 'db_dblog', ['loggername']) op.create_index('ix_db_dblog_objname', 'db_dblog', ['objname']) op.create_index('ix_db_dblog_objpk', 'db_dblog', ['objpk']) - op.create_table('db_dbworkflowstep_sub_workflows', - sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbworkflow_id', sa.INTEGER(), 
autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbworkflow_id'], ['db_dbworkflow.id'], name='db_dbworkflowstep_sub_workflows_dbworkflow_id_fkey'), - sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], name='db_dbworkflowstep_sub_workflows_dbworkflowstep_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_sub_workflows_pkey'), - sa.UniqueConstraint('dbworkflowstep_id', 'dbworkflow_id', name='db_dbworkflowstep_sub_workflows_id_dbworkflow__key') + op.create_table( + 'db_dbworkflowstep_sub_workflows', sa.Column('id', sa.INTEGER(), nullable=False), + sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('dbworkflow_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['dbworkflow_id'], ['db_dbworkflow.id'], + name='db_dbworkflowstep_sub_workflows_dbworkflow_id_fkey'), + sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], + name='db_dbworkflowstep_sub_workflows_dbworkflowstep_id_fkey'), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_sub_workflows_pkey'), + sa.UniqueConstraint( + 'dbworkflowstep_id', 'dbworkflow_id', name='db_dbworkflowstep_sub_workflows_id_dbworkflow__key' + ) ) # I get the session using the alembic connection # (Keep in mind that alembic uses the AiiDA SQLA @@ -264,6 +322,7 @@ def upgrade(): def downgrade(): + """Migrations for the downgrade.""" op.drop_table('db_dbworkflowstep_calculations') op.drop_table('db_dbworkflowstep_sub_workflows') op.drop_table('db_dbworkflowdata') diff --git a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py b/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py index 3b0b4c32f8..3d34308429 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py +++ b/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py @@ -7,8 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,no-member -# pylint: disable=no-name-in-module,import-error +# pylint: disable=invalid-name,no-member,no-name-in-module,import-error """This migration creates UUID column and populates it with distinct UUIDs This migration corresponds to the 0024_dblog_update Django migration. diff --git a/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py b/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py index ff38db86f6..a6543778a4 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py +++ b/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member """Delete the kombu tables that were used by the old Celery based daemon and the obsolete related timestamps Revision ID: f9a69de76a9a @@ -17,7 +18,6 @@ from alembic import op from sqlalchemy.sql import text - # revision identifiers, used by Alembic. 
revision = 'f9a69de76a9a' down_revision = '6c629c886f84' @@ -26,10 +26,12 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() # Drop the kombu tables and delete the old timestamps and user related to the daemon in the DbSetting table - statement = text(""" + statement = text( + """ DROP TABLE IF EXISTS kombu_message; DROP TABLE IF EXISTS kombu_queue; DROP SEQUENCE IF EXISTS message_id_sequence; @@ -41,9 +43,11 @@ def upgrade(): DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|updater'; DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|submitter'; DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|submitter'; - """) + """ + ) conn.execute(statement) def downgrade(): + """Migrations for the downgrade.""" print('There is no downgrade for the deletion of the kombu tables and the daemon timestamps') diff --git a/aiida/backends/sqlalchemy/models/base.py b/aiida/backends/sqlalchemy/models/base.py index 8721455bfd..73a7cba6cf 100644 --- a/aiida/backends/sqlalchemy/models/base.py +++ b/aiida/backends/sqlalchemy/models/base.py @@ -44,7 +44,7 @@ class _SessionProperty: def __get__(self, obj, _type): if not aiida.backends.sqlalchemy.get_scoped_session(): - raise InvalidOperation('You need to call load_dbenv before ' 'accessing the session of SQLALchemy.') + raise InvalidOperation('You need to call load_dbenv before accessing the session of SQLALchemy.') return aiida.backends.sqlalchemy.get_scoped_session() diff --git a/aiida/backends/sqlalchemy/models/computer.py b/aiida/backends/sqlalchemy/models/computer.py index 5b4befde47..e53f052ebf 100644 --- a/aiida/backends/sqlalchemy/models/computer.py +++ b/aiida/backends/sqlalchemy/models/computer.py @@ -9,7 +9,6 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Module to manage computers for the SQLA backend.""" - from sqlalchemy.dialects.postgresql import UUID, JSONB from sqlalchemy.schema import Column from sqlalchemy.types import Integer, String, Text @@ -34,7 +33,6 @@ class DbComputer(Base): def __init__(self, *args, **kwargs): """Provide _metadata and description attributes to the class.""" self._metadata = {} - # TODO SP: it's supposed to be nullable, but there is a NOT constraint inside the DB. self.description = '' # If someone passes metadata in **kwargs we change it to _metadata diff --git a/aiida/backends/sqlalchemy/models/node.py b/aiida/backends/sqlalchemy/models/node.py index b1ccce8eba..48d0cae220 100644 --- a/aiida/backends/sqlalchemy/models/node.py +++ b/aiida/backends/sqlalchemy/models/node.py @@ -53,6 +53,7 @@ class DbNode(Base): Integer, ForeignKey('db_dbuser.id', deferrable=True, initially='DEFERRED', ondelete='restrict'), nullable=False ) + # pylint: disable=fixme # TODO SP: The 'passive_deletes=all' argument here means that SQLAlchemy # won't take care of automatic deleting in the DbLink table. This still # isn't exactly the same behaviour than with Django. 
The solution to @@ -108,7 +109,7 @@ def outputs(self): @property def inputs(self): - return self.inputs_q.all() + return self.inputs_q.all() # pylint: disable=no-member def get_simple_name(self, invalid_result=None): """ diff --git a/aiida/backends/sqlalchemy/models/settings.py b/aiida/backends/sqlalchemy/models/settings.py index d5b328999e..eac5433a28 100644 --- a/aiida/backends/sqlalchemy/models/settings.py +++ b/aiida/backends/sqlalchemy/models/settings.py @@ -9,7 +9,6 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Module to manage node settings for the SQLA backend.""" - from pytz import UTC from sqlalchemy import Column @@ -40,9 +39,7 @@ def __str__(self): return "'{}'={}".format(self.key, self.getvalue()) @classmethod - def set_value( - cls, key, value, with_transaction=True, subspecifier_value=None, other_attribs=None, stop_if_existing=False - ): + def set_value(cls, key, value, other_attribs=None, stop_if_existing=False): """Set a setting value.""" other_attribs = other_attribs if other_attribs is not None else {} setting = sa.get_scoped_session().query(DbSetting).filter_by(key=key).first() diff --git a/aiida/cmdline/commands/cmd_data/cmd_show.py b/aiida/cmdline/commands/cmd_data/cmd_show.py index 8a443e2945..51cdbca9b1 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_show.py +++ b/aiida/cmdline/commands/cmd_data/cmd_show.py @@ -10,7 +10,6 @@ """ This allows to manage showfunctionality to all data types. """ - import click from aiida.cmdline.params.options.multivalue import MultipleValueOption @@ -209,7 +208,7 @@ def _show_xmgrace(exec_name, list_bands): import sys import subprocess import tempfile - from aiida.orm.nodes.data.array.bands import max_num_agr_colors + from aiida.orm.nodes.data.array.bands import MAX_NUM_AGR_COLORS list_files = [] current_band_number = 0 @@ -218,7 +217,7 @@ def _show_xmgrace(exec_name, list_bands): nbnds = bnds.get_bands().shape[1] # pylint: disable=protected-access text, _ = bnds._exportcontent( - 'agr', setnumber_offset=current_band_number, color_number=(iband + 1 % max_num_agr_colors) + 'agr', setnumber_offset=current_band_number, color_number=(iband + 1 % MAX_NUM_AGR_COLORS) ) # write a tempfile tempf = tempfile.NamedTemporaryFile('w+b', suffix='.agr') diff --git a/aiida/cmdline/commands/cmd_data/cmd_structure.py b/aiida/cmdline/commands/cmd_data/cmd_structure.py index c0d205197e..a2e91949bd 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_structure.py +++ b/aiida/cmdline/commands/cmd_data/cmd_structure.py @@ -231,7 +231,7 @@ def import_ase(filename, dry_run): try: import ase.io except ImportError: - echo.echo_critical('You have not installed the package ase. \n' 'You can install it with: pip install ase') + echo.echo_critical('You have not installed the package ase. 
\nYou can install it with: pip install ase') try: asecell = ase.io.read(filename) diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py index 531ed78eec..5c2357b4c5 100644 --- a/aiida/cmdline/utils/common.py +++ b/aiida/cmdline/utils/common.py @@ -455,7 +455,7 @@ def build_entries(ports): click.secho('{:>{width_name}d}: {}'.format(exit_code.status, message, width_name=max_width_name)) -def get_num_workers(): #pylint: disable=inconsistent-return-statements +def get_num_workers(): """ Get the number of active daemon workers from the circus client """ diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py index 61bd147e52..e10a7cca22 100644 --- a/aiida/common/datastructures.py +++ b/aiida/common/datastructures.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to define commonly used data structures.""" - from enum import Enum, IntEnum from .extendeddicts import DefaultFieldsAttributeDict @@ -78,7 +77,7 @@ class CalcInfo(DefaultFieldsAttributeDict): """ _default_fields = ( - 'job_environment', # TODO UNDERSTAND THIS! + 'job_environment', 'email', 'email_on_started', 'email_on_terminated', diff --git a/aiida/common/hashing.py b/aiida/common/hashing.py index 7b66b24ffc..b7d59eef30 100644 --- a/aiida/common/hashing.py +++ b/aiida/common/hashing.py @@ -58,7 +58,7 @@ using_sysrandom = False # pylint: disable=invalid-name -def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz' 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'): +def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'): """ Returns a securely generated random string. diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index 1fc1613a92..1ae7ae9510 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -37,6 +37,7 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= :param calc_info: the calculation info datastructure returned by `CalcJob.presubmit` :param folder: temporary local file system folder containing the inputs written by `CalcJob.prepare_for_submission` """ + # pylint: disable=too-many-locals,too-many-branches,too-many-statements from logging import LoggerAdapter from tempfile import NamedTemporaryFile from aiida.orm import load_node, Code, RemoteData @@ -59,21 +60,23 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= logger = LoggerAdapter(logger=execlogger, extra=logger_extra) if not dry_run and node.has_cached_links(): - raise ValueError('Cannot submit calculation {} because it has cached input links! If you just want to test the ' - 'submission, set `metadata.dry_run` to True in the inputs.'.format(node.pk)) + raise ValueError( + 'Cannot submit calculation {} because it has cached input links! 
If you just want to test the ' + 'submission, set `metadata.dry_run` to True in the inputs.'.format(node.pk) + ) # If we are performing a dry-run, the working directory should actually be a local folder that should already exist if dry_run: workdir = transport.getcwd() else: remote_user = transport.whoami() - # TODO Doc: {username} field - # TODO: if something is changed here, fix also 'verdi computer test' remote_working_directory = computer.get_workdir().format(username=remote_user) if not remote_working_directory.strip(): raise exceptions.ConfigurationError( "[submission of calculation {}] No remote_working_directory configured for computer '{}'".format( - node.pk, computer.name)) + node.pk, computer.name + ) + ) # If it already exists, no exception is raised try: @@ -81,7 +84,9 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= except IOError: logger.debug( '[submission of calculation {}] Unable to chdir in {}, trying to create it'.format( - node.pk, remote_working_directory)) + node.pk, remote_working_directory + ) + ) try: transport.makedirs(remote_working_directory) transport.chdir(remote_working_directory) @@ -89,8 +94,8 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= raise exceptions.ConfigurationError( '[submission of calculation {}] ' 'Unable to create the remote directory {} on ' - "computer '{}': {}".format( - node.pk, remote_working_directory, computer.name, exc)) + "computer '{}': {}".format(node.pk, remote_working_directory, computer.name, exc) + ) # Store remotely with sharding (here is where we choose # the folder structure of remote jobs; then I store this # in the calculation properties using _set_remote_dir @@ -112,8 +117,11 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= path_existing = os.path.join(transport.getcwd(), calc_info.uuid[4:]) path_lost_found = os.path.join(remote_working_directory, REMOTE_WORK_DIRECTORY_LOST_FOUND) path_target = os.path.join(path_lost_found, calc_info.uuid) - logger.warning('tried to create path {} but it already exists, moving the entire folder to {}'.format( - path_existing, path_target)) + logger.warning( + 'tried to create path {} but it already exists, moving the entire folder to {}'.format( + path_existing, path_target + ) + ) # Make sure the lost+found directory exists, then copy the existing folder there and delete the original transport.mkdir(path_lost_found, ignore_existing=True) @@ -136,14 +144,14 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= for code in input_codes: if code.is_local(): # Note: this will possibly overwrite files - for f in code.list_object_names(): + for filename in code.list_object_names(): # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in # combination with the new `Transport.put_object_from_filelike` # Since the content of the node could potentially be binary, we read the raw bytes and pass them on with NamedTemporaryFile(mode='wb+') as handle: - handle.write(code.get_object_content(f, mode='rb')) + handle.write(code.get_object_content(filename, mode='rb')) handle.flush() - transport.put(handle.name, f) + transport.put(handle.name, filename) transport.chmod(code.get_local_executable(), 0o755) # rwxr-xr-x # In a dry_run, the working directory is the raw input folder, which will already contain these resources @@ -168,10 +176,10 @@ def find_data_node(inputs, uuid): :param uuid: UUID of the node to find :return: instance 
of `Node` or `None` if not found """ - from collections import Mapping + from collections.abc import Mapping data_node = None - for link_label, input_node in inputs.items(): + for input_node in inputs.values(): if isinstance(input_node, Mapping): data_node = find_data_node(input_node, uuid) elif isinstance(input_node, Node) and input_node.uuid == uuid: @@ -201,45 +209,64 @@ def find_data_node(inputs, uuid): if remote_copy_list: with open(os.path.join(workdir, '_aiida_remote_copy_list.txt'), 'w') as handle: for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: - handle.write('would have copied {} to {} in working directory on remote {}'.format( - remote_abs_path, dest_rel_path, computer.name)) + handle.write( + 'would have copied {} to {} in working directory on remote {}'.format( + remote_abs_path, dest_rel_path, computer.name + ) + ) if remote_symlink_list: with open(os.path.join(workdir, '_aiida_remote_symlink_list.txt'), 'w') as handle: for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_symlink_list: - handle.write('would have created symlinks from {} to {} in working directory on remote {}'.format( - remote_abs_path, dest_rel_path, computer.name)) + handle.write( + 'would have created symlinks from {} to {} in working directory on remote {}'.format( + remote_abs_path, dest_rel_path, computer.name + ) + ) else: for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_copy_list: if remote_computer_uuid == computer.uuid: - logger.debug('[submission of calculation {}] copying {} remotely, directly on the machine {}'.format( - node.pk, dest_rel_path, computer.name)) + logger.debug( + '[submission of calculation {}] copying {} remotely, directly on the machine {}'.format( + node.pk, dest_rel_path, computer.name + ) + ) try: transport.copy(remote_abs_path, dest_rel_path) except (IOError, OSError): - logger.warning('[submission of calculation {}] Unable to copy remote resource from {} to {}! ' - 'Stopping.'.format(node.pk, remote_abs_path, dest_rel_path)) + logger.warning( + '[submission of calculation {}] Unable to copy remote resource from {} to {}! ' + 'Stopping.'.format(node.pk, remote_abs_path, dest_rel_path) + ) raise else: raise NotImplementedError( '[submission of calculation {}] Remote copy between two different machines is ' - 'not implemented yet'.format(node.pk)) + 'not implemented yet'.format(node.pk) + ) for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_symlink_list: if remote_computer_uuid == computer.uuid: - logger.debug('[submission of calculation {}] copying {} remotely, directly on the machine {}'.format( - node.pk, dest_rel_path, computer.name)) + logger.debug( + '[submission of calculation {}] copying {} remotely, directly on the machine {}'.format( + node.pk, dest_rel_path, computer.name + ) + ) try: transport.symlink(remote_abs_path, dest_rel_path) except (IOError, OSError): - logger.warning('[submission of calculation {}] Unable to create remote symlink from {} to {}! ' - 'Stopping.'.format(node.pk, remote_abs_path, dest_rel_path)) + logger.warning( + '[submission of calculation {}] Unable to create remote symlink from {} to {}! 
' + 'Stopping.'.format(node.pk, remote_abs_path, dest_rel_path) + ) raise else: - raise IOError('It is not possible to create a symlink between two different machines for ' - 'calculation {}'.format(node.pk)) + raise IOError( + 'It is not possible to create a symlink between two different machines for ' + 'calculation {}'.format(node.pk) + ) provenance_exclude_list = calc_info.provenance_exclude_list or [] @@ -250,7 +277,7 @@ def find_data_node(inputs, uuid): # advantage of this explicit copying instead of deleting the files from `provenance_exclude_list` from the sandbox # first before moving the entire remaining content to the node's repository, is that in this way we are guaranteed # not to accidentally move files to the repository that should not go there at all cost. - for root, dirnames, filenames in os.walk(folder.abspath): + for root, _, filenames in os.walk(folder.abspath): for filename in filenames: filepath = os.path.join(root, filename) relpath = os.path.relpath(filepath, folder.abspath) @@ -319,8 +346,9 @@ def retrieve_calculation(calculation, transport, retrieved_temporary_folder): # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the retrieval. link_label = calculation.link_label_retrieved if calculation.get_outgoing(FolderData, link_label_filter=link_label).first(): - execlogger.warning('CalcJobNode<{}> already has a `{}` output folder: skipping retrieval'.format( - calculation.pk, link_label)) + execlogger.warning( + 'CalcJobNode<{}> already has a `{}` output folder: skipping retrieval'.format(calculation.pk, link_label) + ) return # Create the FolderData node into which to store the files that are to be retrieved @@ -351,14 +379,17 @@ def retrieve_calculation(calculation, transport, retrieved_temporary_folder): # Log the files that were retrieved in the temporary folder for filename in os.listdir(retrieved_temporary_folder): - execlogger.debug("[retrieval of calc {}] Retrieved temporary file or folder '{}'".format( - calculation.pk, filename), extra=logger_extra) + execlogger.debug( + "[retrieval of calc {}] Retrieved temporary file or folder '{}'".format(calculation.pk, filename), + extra=logger_extra + ) # Store everything execlogger.debug( '[retrieval of calc {}] ' 'Storing retrieved_files={}'.format(calculation.pk, retrieved_files.pk), - extra=logger_extra) + extra=logger_extra + ) retrieved_files.store() # Make sure that attaching the `retrieved` folder with a link is the last thing we do. 
This gives the biggest chance @@ -421,12 +452,17 @@ def parse_results(process, retrieved_temporary_folder=None): for filename in filenames: files.append('- [F] {}'.format(os.path.join(root, filename))) - execlogger.debug('[parsing of calc {}] ' - 'Content of the retrieved_temporary_folder: \n' - '{}'.format(process.node.pk, '\n'.join(files)), extra=logger_extra) + execlogger.debug( + '[parsing of calc {}] ' + 'Content of the retrieved_temporary_folder: \n' + '{}'.format(process.node.pk, '\n'.join(files)), + extra=logger_extra + ) else: - execlogger.debug('[parsing of calc {}] ' - 'No retrieved_temporary_folder.'.format(process.node.pk), extra=logger_extra) + execlogger.debug( + '[parsing of calc {}] ' + 'No retrieved_temporary_folder.'.format(process.node.pk), extra=logger_extra + ) if parser_class is not None: @@ -459,11 +495,14 @@ def parse_results(process, retrieved_temporary_folder=None): def _retrieve_singlefiles(job, transport, folder, retrieve_file_list, logger_extra=None): + """Retrieve files specified through the singlefile list mechanism.""" singlefile_list = [] for (linkname, subclassname, filename) in retrieve_file_list: - execlogger.debug('[retrieval of calc {}] Trying ' - "to retrieve remote singlefile '{}'".format( - job.pk, filename), extra=logger_extra) + execlogger.debug( + '[retrieval of calc {}] Trying ' + "to retrieve remote singlefile '{}'".format(job.pk, filename), + extra=logger_extra + ) localfilename = os.path.join(folder.abspath, os.path.split(filename)[1]) transport.get(filename, localfilename, ignore_nonexisting=True) singlefile_list.append((linkname, subclassname, localfilename)) @@ -474,16 +513,16 @@ def _retrieve_singlefiles(job, transport, folder, retrieve_file_list, logger_ext # after retrieving from the cluster, I create the objects singlefiles = [] for (linkname, subclassname, filename) in singlefile_list: - SinglefileSubclass = DataFactory(subclassname) - singlefile = SinglefileSubclass(file=filename) + cls = DataFactory(subclassname) + singlefile = cls(file=filename) singlefile.add_incoming(job, link_type=LinkType.CREATE, link_label=linkname) singlefiles.append(singlefile) for fil in singlefiles: execlogger.debug( '[retrieval of calc {}] ' - 'Storing retrieved_singlefile={}'.format(job.pk, fil.pk), - extra=logger_extra) + 'Storing retrieved_singlefile={}'.format(job.pk, fil.pk), extra=logger_extra + ) fil.store() @@ -522,14 +561,11 @@ def retrieve_files_from_list(calculation, transport, folder, retrieve_list): to_append = rem.split(os.path.sep)[-depth:] if depth > 0 else [] local_names.append(os.path.sep.join([tmp_lname] + to_append)) else: - remote_names = [tmp_rname] - to_append = remote_names.split(os.path.sep)[-depth:] if depth > 0 else [] + to_append = tmp_rname.split(os.path.sep)[-depth:] if depth > 0 else [] local_names = [os.path.sep.join([tmp_lname] + to_append)] if depth > 1: # create directories in the folder, if needed for this_local_file in local_names: - new_folder = os.path.join( - folder, - os.path.split(this_local_file)[0]) + new_folder = os.path.join(folder, os.path.split(this_local_file)[0]) if not os.path.exists(new_folder): os.makedirs(new_folder) else: # it is a string @@ -542,5 +578,6 @@ def retrieve_files_from_list(calculation, transport, folder, retrieve_list): for rem, loc in zip(remote_names, local_names): transport.logger.debug( - "[retrieval of calc {}] Trying to retrieve remote item '{}'".format(calculation.pk, rem)) + "[retrieval of calc {}] Trying to retrieve remote item '{}'".format(calculation.pk, rem) + ) 
transport.get(rem, os.path.join(folder, loc), ignore_nonexisting=True) diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index 6f686355fd..bcdc87b844 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -26,7 +26,7 @@ __all__ = ('CalcJob',) -def validate_calc_job(inputs, ctx): # pylint: disable=inconsistent-return-statements,too-many-return-statements +def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statements """Validate the entire set of inputs passed to the `CalcJob` constructor. Reasons that will cause this validation to raise an `InputValidationError`: @@ -88,7 +88,7 @@ def validate_calc_job(inputs, ctx): # pylint: disable=inconsistent-return-state return 'input `metadata.options.resources` is not valid for the {} scheduler: {}'.format(scheduler, exception) -def validate_parser(parser_name, _): # pylint: disable=inconsistent-return-statements +def validate_parser(parser_name, _): """Validate the parser. :return: string with error message in case the inputs are invalid diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 112e3a5058..c0958e02d5 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Transport tasks for calculation jobs.""" import functools import logging import tempfile @@ -87,7 +88,8 @@ def do_upload(): logger.info('scheduled request to upload CalcJob<{}>'.format(node.pk)) ignore_exceptions = (plumpy.CancelledError, PreSubmitException) result = yield exponential_backoff_retry( - do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions) + do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) except PreSubmitException: raise except plumpy.CancelledError: @@ -136,7 +138,8 @@ def do_submit(): try: logger.info('scheduled request to submit CalcJob<{}>'.format(node.pk)) result = yield exponential_backoff_retry( - do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) + do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ) except plumpy.Interruption: pass except Exception: @@ -195,7 +198,8 @@ def do_update(): try: logger.info('scheduled request to update CalcJob<{}>'.format(node.pk)) job_done = yield exponential_backoff_retry( - do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) + do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ) except plumpy.Interruption: raise except Exception: @@ -257,8 +261,9 @@ def do_retrieve(): try: logger.info('scheduled request to retrieve CalcJob<{}>'.format(node.pk)) - result = yield exponential_backoff_retry( - do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) + yield exponential_backoff_retry( + do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + ) except plumpy.Interruption: raise except Exception: @@ -333,7 +338,8 @@ def load_instance_state(self, saved_state, load_context): @coroutine def execute(self): - + 
"""Override the execute coroutine of the base `Waiting` state.""" + # pylint: disable=too-many-branches node = self.process.node transport_queue = self.process.runner.transport command = self.data @@ -393,6 +399,7 @@ def execute(self): @coroutine def _launch_task(self, coro, *args, **kwargs): + """Launch a coroutine as a task, making sure to make it interruptable.""" task_fn = functools.partial(coro, *args, **kwargs) try: self._task = interruptable_task(task_fn) diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 07e535af3d..35b65022dd 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -20,7 +20,7 @@ __all__ = ('BaseRestartWorkChain',) -def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint: disable=inconsistent-return-statements,unused-argument +def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint: disable=unused-argument """Validator for the `handler_overrides` input port of the `BaseRestartWorkChain. The `handler_overrides` should be a dictionary where keys are strings that are the name of a process handler, i.e. a @@ -170,7 +170,7 @@ def run_process(self): return ToContext(children=append_(node)) - def inspect_process(self): # pylint: disable=inconsistent-return-statements,too-many-branches + def inspect_process(self): # pylint: disable=too-many-branches """Analyse the results of the previous process and call the handlers when necessary. If the process is excepted or killed, the work chain will abort. Otherwise any attached handlers will be called @@ -260,7 +260,7 @@ def inspect_process(self): # pylint: disable=inconsistent-return-statements,too # Otherwise the process was successful and no handler returned anything so we consider the work done self.ctx.is_finished = True - def results(self): # pylint: disable=inconsistent-return-statements + def results(self): """Attach the outputs specified in the output specification from the last completed process.""" node = self.ctx.children[self.ctx.iteration - 1] diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 4c78eac4ea..1c3b64777d 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -7,11 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines """ This module defines the classes related to band structures or dispersions in a Brillouin zone, and how to operate on them. """ - from string import Template import numpy @@ -22,6 +22,7 @@ def prepare_header_comment(uuid, plot_info, comment_char='#'): + """Prepare the header.""" from aiida import get_file_header filetext = [] @@ -32,13 +33,10 @@ def prepare_header_comment(uuid, plot_info, comment_char='#'): filetext.append('\t{}\t{}'.format(*plot_info['y'].shape)) filetext.append('') filetext.append('\tlabel\tpoint') - for l in plot_info['raw_labels']: - filetext.append('\t{}\t{:.8f}'.format(l[1], l[0])) - - return '\n'.join('{} {}'.format(comment_char, l) for l in filetext) - + for label in plot_info['raw_labels']: + filetext.append('\t{}\t{:.8f}'.format(label[1], label[0])) -# TODO: set and get bands could have more functionalities: how do I know the number of bands for example? 
+ return '\n'.join('{} {}'.format(comment_char, line) for line in filetext) def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None): @@ -71,14 +69,15 @@ def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None): equal to the lumo (e.g. in semi-metals). """ + # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements,no-else-return + def nint(num): """ Stable rounding function """ - if (num > 0): + if num > 0: return int(num + .5) - else: - return int(num - .5) + return int(num - .5) if fermi_energy and number_electrons: raise ValueError('Specify either the number of electrons or the Fermi energy, but not both') @@ -89,11 +88,9 @@ def nint(num): raise KeyError('Cannot do much of a band analysis without bands') if len(stored_bands.shape) == 3: - # I write the algorithm for the generic case of having both the - # spin up and spin down array - + # I write the algorithm for the generic case of having both the spin up and spin down array # put all spins on one band per kpoint - bands = numpy.concatenate([_ for _ in stored_bands], axis=1) + bands = numpy.concatenate(stored_bands, axis=1) else: bands = stored_bands @@ -114,7 +111,7 @@ def nint(num): # spin up and spin down array # put all spins on one band per kpoint - occupations = numpy.concatenate([_ for _ in stored_occupations], axis=1) + occupations = numpy.concatenate(stored_occupations, axis=1) else: occupations = stored_occupations @@ -124,23 +121,29 @@ def nint(num): # sort the bands by energy, and reorder the occupations accordingly # since after joining the two spins, I might have unsorted stuff bands, occupations = [ - numpy.array(y) for y in zip(*[ - list(zip(*j)) for j in - [sorted(zip(i[0].tolist(), i[1].tolist()), key=lambda x: x[0]) for i in zip(bands, occupations)] - ]) + numpy.array(y) for y in zip( + *[ + list(zip(*j)) for j in [ + sorted(zip(i[0].tolist(), i[1].tolist()), key=lambda x: x[0]) + for i in zip(bands, occupations) + ] + ] + ) ] number_electrons = int(round(sum([sum(i) for i in occupations]) / num_kpoints)) homo_indexes = [numpy.where(numpy.array([nint(_) for _ in x]) > 0)[0][-1] for x in occupations] if len(set(homo_indexes)) > 1: # there must be intersections of valence and conduction bands return False, None - else: - homo = [_[0][_[1]] for _ in zip(bands, homo_indexes)] - try: - lumo = [_[0][_[1] + 1] for _ in zip(bands, homo_indexes)] - except IndexError: - raise ValueError('To understand if it is a metal or insulator, ' - 'need more bands than n_band=number_electrons') + + homo = [_[0][_[1]] for _ in zip(bands, homo_indexes)] + try: + lumo = [_[0][_[1] + 1] for _ in zip(bands, homo_indexes)] + except IndexError: + raise ValueError( + 'To understand if it is a metal or insulator, ' + 'need more bands than n_band=number_electrons' + ) else: bands = numpy.sort(bands) @@ -155,8 +158,10 @@ def nint(num): # gather the energies of the lumo band, for every kpoint lumo = [i[number_electrons // number_electrons_per_band] for i in bands] # take the n+1th level except IndexError: - raise ValueError('To understand if it is a metal or insulator, ' - 'need more bands than n_band=number_electrons') + raise ValueError( + 'To understand if it is a metal or insulator, ' + 'need more bands than n_band=number_electrons' + ) if number_electrons % 2 == 1 and len(stored_bands.shape) == 2: # if #electrons is odd and we have a non spin polarized calculation @@ -167,10 +172,11 @@ def nint(num): gap = min(lumo) - max(homo) if gap == 0.: return False, 0. 
- elif gap < 0.: + + if gap < 0.: return False, None - else: - return True, gap + + return True, gap # analysis on the fermi energy else: @@ -188,23 +194,22 @@ def nint(num): raise ValueError("The Fermi energy is below all band energies, don't know what to do.") # one band is crossed by the fermi energy - if any(i[1] < fermi_energy and fermi_energy < i[0] for i in max_mins): + if any(i[1] < fermi_energy and fermi_energy < i[0] for i in max_mins): # pylint: disable=chained-comparison return False, None # case of semimetals, fermi energy at the crossing of two bands # this will only work if the dirac point is computed! - elif (any(i[0] == fermi_energy for i in max_mins) and any(i[1] == fermi_energy for i in max_mins)): + if (any(i[0] == fermi_energy for i in max_mins) and any(i[1] == fermi_energy for i in max_mins)): return False, 0. - # insulating case - else: - # take the max of the band maxima below the fermi energy - homo = max([i[0] for i in max_mins if i[0] < fermi_energy]) - # take the min of the band minima above the fermi energy - lumo = min([i[1] for i in max_mins if i[1] > fermi_energy]) - gap = lumo - homo - if gap <= 0.: - raise Exception('Something wrong has been implemented. Revise the code!') - return True, gap + + # insulating case, take the max of the band maxima below the fermi energy + homo = max([i[0] for i in max_mins if i[0] < fermi_energy]) + # take the min of the band minima above the fermi energy + lumo = min([i[1] for i in max_mins if i[1] > fermi_energy]) + gap = lumo - homo + if gap <= 0.: + raise Exception('Something wrong has been implemented. Revise the code!') + return True, gap class BandsData(KpointsData): @@ -212,15 +217,6 @@ class BandsData(KpointsData): Class to handle bands data """ - # Associate file extensions to default plotting formats - _custom_export_format_replacements = { - 'dat': 'dat_multicolumn', - 'png': 'mpl_png', - 'pdf': 'mpl_pdf', - 'py': 'mpl_singlefile', - 'gnu': 'gnuplot' - } - def set_kpointsdata(self, kpointsdata): """ Load the kpoints from a kpoint object. @@ -258,6 +254,7 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): Nkpoints x Nbands floats or Nspins x Nkpoints x Nbands; Nkpoints must correspond to the number of kpoints. 
""" + # pylint: disable=too-many-branches try: kpoints = self.get_kpoints() except AttributeError: @@ -266,9 +263,11 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): the_bands = numpy.array(bands) if len(the_bands.shape) not in [2, 3]: - raise ValueError('Bands must be an array of dimension 2' - '([N_kpoints, N_bands]) or of dimension 3 ' - ' ([N_arrays, N_kpoints, N_bands]), found instead {}'.format(len(the_bands.shape))) + raise ValueError( + 'Bands must be an array of dimension 2' + '([N_kpoints, N_bands]) or of dimension 3 ' + ' ([N_arrays, N_kpoints, N_bands]), found instead {}'.format(len(the_bands.shape)) + ) list_of_arrays_to_be_checked = [] @@ -280,8 +279,10 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): if occupations is not None: the_occupations = numpy.array(occupations) if the_occupations.shape != the_bands.shape: - raise ValueError('Shape of occupations {} different from shape' - 'shape of bands {}'.format(the_occupations.shape, the_bands.shape)) + raise ValueError( + 'Shape of occupations {} different from shape' + 'shape of bands {}'.format(the_occupations.shape, the_bands.shape) + ) if not the_bands.dtype.type == numpy.float64: list_of_arrays_to_be_checked.append([the_occupations, 'occupations']) @@ -306,8 +307,10 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): elif isinstance(labels, (tuple, list)) and all([isinstance(_, str) for _ in labels]): the_labels = [str(_) for _ in labels] else: - raise ValidationError('Band labels have an unrecognized type ({})' - 'but should be a string or a list of strings'.format(labels.__class__)) + raise ValidationError( + 'Band labels have an unrecognized type ({})' + 'but should be a string or a list of strings'.format(labels.__class__) + ) if len(the_bands.shape) == 2 and len(the_labels) != 1: raise ValidationError('More array labels than the number of arrays') @@ -405,8 +408,8 @@ def get_bands(self, also_occupations=False, also_labels=False): if len(to_return) == 1: return bands - else: - return to_return + + return to_return def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, get_segments=False, y_origin=0.): """ @@ -432,6 +435,7 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, depending on the type of spin; the length is always equalt to the total number of bands per kpoint). """ + # pylint: disable=too-many-locals,too-many-branches,too-many-statements # load the x and y's of the graph stored_bands = self.get_bands() if len(stored_bands.shape) == 2: @@ -439,7 +443,7 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, band_type_idx = numpy.array([0] * stored_bands.shape[1]) two_band_types = False elif len(stored_bands.shape) == 3: - bands = numpy.concatenate([_ for _ in stored_bands], axis=1) + bands = numpy.concatenate(stored_bands, axis=1) band_type_idx = numpy.array([0] * stored_bands.shape[2] + [1] * stored_bands.shape[2]) two_band_types = True else: @@ -468,8 +472,9 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, # as a result, where there are discontinuities in the path, # I have two consecutive points with the same x coordinate distances = [ - numpy.linalg.norm(kpoints[i] - kpoints[i - 1]) - if not (i in labels_indices and i - 1 in labels_indices) else 0. for i in range(1, len(kpoints)) + numpy.linalg.norm(kpoints[i] - + kpoints[i - 1]) if not (i in labels_indices and i - 1 in labels_indices) else 0. 
+ for i in range(1, len(kpoints)) ] x = [float(sum(distances[:i])) for i in range(len(distances) + 1)] @@ -499,8 +504,8 @@ def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, if labels[0][0] != 0: labels.insert(0, (0, '')) # I add an empty label that points to the last band if the last label does not do it - if labels[-1][0] != len(bands)-1 : - labels.append((len(bands)-1, '')) + if labels[-1][0] != len(bands) - 1: + labels.append((len(bands) - 1, '')) for (position_from, label_from), (position_to, label_to) in zip(labels[:-1], labels[1:]): if position_to - position_from > 1: # Create a new path line only if there are at least two points, @@ -547,6 +552,7 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. """ + # pylint: disable=too-many-locals import os dat_filename = os.path.splitext(main_file_name)[0] + '_data.dat' @@ -561,7 +567,6 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N x = plot_info['x'] labels = plot_info['labels'] - num_labels = len(labels) num_bands = bands.shape[1] # axis limits @@ -573,14 +578,6 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N # first prepare the xy coordinates of the sets raw_data, _ = self._prepare_dat_blocks(plot_info) - ## Manually add the xy coordinates of the vertical lines - not needed! Use gridlines - #new_block = [] - #for l in labels: - # new_block.append("{}\t{}".format(l[0], y_min_lim)) - # new_block.append("{}\t{}".format(l[0], y_max_lim)) - # new_block.append("") - #raw_data += "\n".join(new_block) - batch = [] if comments: batch.append(prepare_header_comment(self.uuid, plot_info, comment_char='#')) @@ -598,9 +595,9 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N batch.append('xaxis tick spec type both') batch.append('xaxis tick spec {}'.format(len(labels))) # set the name of the special points - for i, l in enumerate(labels): - batch.append('xaxis tick major {}, {}'.format(i, l[0])) - batch.append('xaxis ticklabel {}, "{}"'.format(i, l[1])) + for index, label in enumerate(labels): + batch.append('xaxis tick major {}, {}'.format(index, label[0])) + batch.append('xaxis ticklabel {}, "{}"'.format(index, label[1])) batch.append('xaxis tick major color 7') batch.append('xaxis tick major grid on') @@ -614,20 +611,16 @@ def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=N batch.append('xaxis label font 4') # set color and linewidths of bands - for i in range(num_bands): - batch.append('s{} line color 1'.format(i)) - batch.append('s{} linewidth 1'.format(i)) - - ## set color and linewidths of label lines - not needed! use gridlines - #for i in range(num_bands, num_bands + num_labels): - # batch.append("s{} hidden true".format(i)) + for index in range(num_bands): + batch.append('s{} line color 1'.format(index)) + batch.append('s{} linewidth 1'.format(index)) batch_data = '\n'.join(batch) + '\n' extra_files = {dat_filename: raw_data} return batch_data.encode('utf-8'), extra_files - def _prepare_dat_multicolumn(self, main_file_name='', comments=True): + def _prepare_dat_multicolumn(self, main_file_name='', comments=True): # pylint: disable=unused-argument """ Write an N x M matrix. First column is the distance between kpoints, The other columns are the bands. 
Header contains number of kpoints and @@ -651,7 +644,7 @@ def _prepare_dat_multicolumn(self, main_file_name='', comments=True): return ('\n'.join(return_text) + '\n').encode('utf-8'), {} - def _prepare_dat_blocks(self, main_file_name='', comments=True): + def _prepare_dat_blocks(self, main_file_name='', comments=True): # pylint: disable=unused-argument """ Format suitable for gnuplot using blocks. Columns with x and y (path and band energy). Several blocks, separated @@ -669,10 +662,8 @@ def _prepare_dat_blocks(self, main_file_name='', comments=True): if comments: return_text.append(prepare_header_comment(self.uuid, plot_info, comment_char='#')) - the_bands = numpy.transpose(bands) - - for b in the_bands: - for i in zip(x, b): + for band in numpy.transpose(bands): + for i in zip(x, band): line = ['{:.8f}'.format(i[0]), '{:.8f}'.format(i[1])] return_text.append('\t'.join(line)) return_text.append('') @@ -680,17 +671,19 @@ def _prepare_dat_blocks(self, main_file_name='', comments=True): return '\n'.join(return_text).encode('utf-8'), {} - def _matplotlib_get_dict(self, - main_file_name='', - comments=True, - title='', - legend=None, - legend2=None, - y_max_lim=None, - y_min_lim=None, - y_origin=0., - prettify_format=None, - **kwargs): + def _matplotlib_get_dict( + self, + main_file_name='', + comments=True, + title='', + legend=None, + legend2=None, + y_max_lim=None, + y_min_lim=None, + y_origin=0., + prettify_format=None, + **kwargs + ): # pylint: disable=unused-argument """ Prepare the data to send to the python-matplotlib plotting script. @@ -700,7 +693,7 @@ def _matplotlib_get_dict(self, :param setnumber_offset: an offset to be applied to all set numbers (i.e. s0 is replaced by s[offset], s1 by s[offset+1], etc.) :param color_number: the color number for lines, symbols, error bars - and filling (should be less than the parameter max_num_agr_colors + and filling (should be less than the parameter MAX_NUM_AGR_COLORS defined below) :param title: the title :param legend: the legend (applied only to the first of the set) @@ -717,7 +710,7 @@ def _matplotlib_get_dict(self, :param kwargs: additional customization variables; only a subset is accepted, see internal variable 'valid_additional_keywords """ - #import math + # pylint: disable=too-many-arguments,too-many-locals # Only these keywords are accepted in kwargs, and then set into the json valid_additional_keywords = [ @@ -763,7 +756,8 @@ def _matplotlib_get_dict(self, prettify_format=prettify_format, join_symbol=join_symbol, get_segments=True, - y_origin=y_origin) + y_origin=y_origin + ) all_data = {} @@ -777,10 +771,8 @@ def _matplotlib_get_dict(self, tick_pos = [] tick_labels = [] - #all_data['bands'] = the_bands.tolist() all_data['paths'] = plot_info['paths'] all_data['band_type_idx'] = plot_info['band_type_idx'].tolist() - #all_data['x'] = x all_data['tick_pos'] = tick_pos all_data['tick_labels'] = tick_labels @@ -798,17 +790,15 @@ def _matplotlib_get_dict(self, y_min_lim = numpy.array(bands).min() x_min_lim = min(x) # this isn't a numpy array, but a list x_max_lim = max(x) - #ytick_spacing = 10 ** int(math.log10((y_max_lim - y_min_lim))) all_data['x_min_lim'] = x_min_lim all_data['x_max_lim'] = x_max_lim all_data['y_min_lim'] = y_min_lim all_data['y_max_lim'] = y_max_lim - #all_data['ytick_spacing'] = ytick_spacing - for k, v in kwargs.items(): - if k not in valid_additional_keywords: - raise TypeError("_matplotlib_get_dict() got an unexpected keyword argument '{}'".format(k)) - all_data[k] = v + for key, value in kwargs.items(): + if 
key not in valid_additional_keywords: + raise TypeError("_matplotlib_get_dict() got an unexpected keyword argument '{}'".format(key)) + all_data[key] = value return all_data @@ -823,16 +813,16 @@ def _prepare_mpl_singlefile(self, *args, **kwargs): all_data = self._matplotlib_get_dict(*args, **kwargs) - s_header = matplotlib_header_template.substitute() - s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2)) + s_header = MATPLOTLIB_HEADER_TEMPLATE.substitute() + s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) s_body = self._get_mpl_body_template(all_data['paths']) - s_footer = matplotlib_footer_template_show.substitute() + s_footer = MATPLOTLIB_FOOTER_TEMPLATE_SHOW.substitute() - s = s_header + s_import + s_body + s_footer + string = s_header + s_import + s_body + s_footer - return s.encode('utf-8'), {} + return string.encode('utf-8'), {} - def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): + def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg """ Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. @@ -852,16 +842,16 @@ def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): ext_files = {json_fname: json.dumps(all_data, indent=2).encode('utf-8')} - s_header = matplotlib_header_template.substitute() - s_import = matplotlib_import_data_fromfile_template.substitute(json_fname=json_fname) + s_header = MATPLOTLIB_HEADER_TEMPLATE.substitute() + s_import = MATPLOTLIB_IMPORT_DATA_FROMFILE_TEMPLATE.substitute(json_fname=json_fname) s_body = self._get_mpl_body_template(all_data['paths']) - s_footer = matplotlib_footer_template_show.substitute() + s_footer = MATPLOTLIB_FOOTER_TEMPLATE_SHOW.substitute() - s = s_header + s_import + s_body + s_footer + string = s_header + s_import + s_body + s_footer - return s.encode('utf-8'), ext_files + return string.encode('utf-8'), ext_files - def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): + def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg,unused-argument """ Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. @@ -879,8 +869,8 @@ def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): all_data = self._matplotlib_get_dict(*args, **kwargs) # Use the Agg backend - s_header = matplotlib_header_agg_template.substitute() - s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2)) + s_header = MATPLOTLIB_HEADER_AGG_TEMPLATE.substitute() + s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) s_body = self._get_mpl_body_template(all_data['paths']) # I get a temporary file name @@ -890,30 +880,28 @@ def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): escaped_fname = filename.replace('"', '\"') - s_footer = matplotlib_footer_template_exportfile.substitute(fname=escaped_fname, format='pdf') + s_footer = MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE.substitute(fname=escaped_fname, format='pdf') - s = s_header + s_import + s_body + s_footer + string = s_header + s_import + s_body + s_footer # I don't exec it because I might mess up with the matplotlib backend etc. 
# I run instead in a different process, with the same executable # (so it should work properly with virtualenvs) - #exec s - with tempfile.NamedTemporaryFile(mode='w+') as f: - f.write(s) - f.flush() - - subprocess.check_output([sys.executable, f.name]) + with tempfile.NamedTemporaryFile(mode='w+') as handle: + handle.write(string) + handle.flush() + subprocess.check_output([sys.executable, handle.name]) if not os.path.exists(filename): raise RuntimeError('Unable to generate the PDF...') - with open(filename, 'rb', encoding=None) as f: - imgdata = f.read() + with open(filename, 'rb', encoding=None) as handle: + imgdata = handle.read() os.remove(filename) return imgdata, {} - def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): + def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg,unused-argument """ Prepare a python script using matplotlib to plot the bands, with the JSON returned as an independent file. @@ -930,8 +918,8 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): all_data = self._matplotlib_get_dict(*args, **kwargs) # Use the Agg backend - s_header = matplotlib_header_agg_template.substitute() - s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2)) + s_header = MATPLOTLIB_HEADER_AGG_TEMPLATE.substitute() + s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) s_body = self._get_mpl_body_template(all_data['paths']) # I get a temporary file name @@ -941,24 +929,23 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): escaped_fname = filename.replace('"', '\"') - s_footer = matplotlib_footer_template_exportfile_with_dpi.substitute(fname=escaped_fname, format='png', dpi=300) + s_footer = MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI.substitute(fname=escaped_fname, format='png', dpi=300) - s = s_header + s_import + s_body + s_footer + string = s_header + s_import + s_body + s_footer # I don't exec it because I might mess up with the matplotlib backend etc. 
# I run instead in a different process, with the same executable # (so it should work properly with virtualenvs) - with tempfile.NamedTemporaryFile(mode='w+') as f: - f.write(s) - f.flush() - - subprocess.check_output([sys.executable, f.name]) + with tempfile.NamedTemporaryFile(mode='w+') as handle: + handle.write(string) + handle.flush() + subprocess.check_output([sys.executable, handle.name]) if not os.path.exists(filename): raise RuntimeError('Unable to generate the PNG...') - with open(filename, 'rb', encoding=None) as f: - imgdata = f.read() + with open(filename, 'rb', encoding=None) as handle: + imgdata = handle.read() os.remove(filename) return imgdata, {} @@ -969,9 +956,9 @@ def _get_mpl_body_template(paths): :param paths: paths of k-points """ if len(paths) == 1: - s_body = matplotlib_body_template.substitute(plot_code=single_kp) + s_body = MATPLOTLIB_BODY_TEMPLATE.substitute(plot_code=SINGLE_KP) else: - s_body = matplotlib_body_template.substitute(plot_code=multi_kp) + s_body = MATPLOTLIB_BODY_TEMPLATE.substitute(plot_code=MULTI_KP) return s_body def show_mpl(self, **kwargs): @@ -984,14 +971,16 @@ def show_mpl(self, **kwargs): """ exec(*self._exportcontent(fileformat='mpl_singlefile', main_file_name='', **kwargs)) # pylint: disable=exec-used - def _prepare_gnuplot(self, - main_file_name=None, - title='', - comments=True, - prettify_format=None, - y_max_lim=None, - y_min_lim=None, - y_origin=0.): + def _prepare_gnuplot( + self, + main_file_name=None, + title='', + comments=True, + prettify_format=None, + y_max_lim=None, + y_min_lim=None, + y_origin=0. + ): """ Prepare an gnuplot script to plot the bands, with the .dat file returned as an independent file. @@ -1006,6 +995,7 @@ def _prepare_gnuplot(self, :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. """ + # pylint: disable=too-many-arguments,too-many-locals import os main_file_name = main_file_name or 'band.dat' @@ -1016,14 +1006,11 @@ def _prepare_gnuplot(self, prettify_format = 'gnuplot_seekpath' plot_info = self._get_bandplot_data( - cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin) + cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin + ) bands = plot_info['y'] x = plot_info['x'] - labels = plot_info['labels'] - - num_labels = len(labels) - num_bands = bands.shape[1] # axis limits if y_max_lim is None: @@ -1045,7 +1032,8 @@ def _prepare_gnuplot(self, script.append(prepare_header_comment(self.uuid, plot_info=plot_info, comment_char='# ')) script.append('') - script.append(u"""## Uncomment the next two lines to write directly to PDF + script.append( + """## Uncomment the next two lines to write directly to PDF ## Note: You need to have gnuplot installed with pdfcairo support! #set term pdfcairo #set output 'out.pdf' @@ -1060,18 +1048,15 @@ def _prepare_gnuplot(self, #set termopt font "CMU Sans Serif, 12" ## Classical Times New Roman #set termopt font "Times New Roman, 12" -""") +""" + ) # Actual logic script.append('set termopt enhanced') # Properly deals with e.g. 
subscripts script.append('set encoding utf8') # To deal with Greek letters script.append('set xtics ({})'.format(xtics_string)) - script.append('unset key') - - script.append('set yrange [{}:{}]'.format(y_min_lim, y_max_lim)) - script.append('set ylabel "{}"'.format('Dispersion ({})'.format(self.units))) if title: @@ -1084,25 +1069,31 @@ def _prepare_gnuplot(self, script.append('plot "{}" with l lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"'))) else: script.append('set xrange [-1.0:1.0]') - script.append('plot "{}" using ($1-0.25):($2):(0.5):(0) with vectors nohead lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"'))) + script.append( + 'plot "{}" using ($1-0.25):($2):(0.5):(0) with vectors nohead lc rgb "#000000"'.format( + os.path.basename(dat_filename).replace('"', '\"') + ) + ) script_data = '\n'.join(script) + '\n' extra_files = {dat_filename: raw_data} return script_data.encode('utf-8'), extra_files - def _prepare_agr(self, - main_file_name='', - comments=True, - setnumber_offset=0, - color_number=1, - color_number2=2, - legend='', - title='', - y_max_lim=None, - y_min_lim=None, - y_origin=0., - prettify_format=None): + def _prepare_agr( + self, + main_file_name='', + comments=True, + setnumber_offset=0, + color_number=1, + color_number2=2, + legend='', + title='', + y_max_lim=None, + y_min_lim=None, + y_origin=0., + prettify_format=None + ): """ Prepare an xmgrace agr file. @@ -1112,11 +1103,11 @@ def _prepare_agr(self, :param setnumber_offset: an offset to be applied to all set numbers (i.e. s0 is replaced by s[offset], s1 by s[offset+1], etc.) :param color_number: the color number for lines, symbols, error bars - and filling (should be less than the parameter max_num_agr_colors + and filling (should be less than the parameter MAX_NUM_AGR_COLORS defined below) :param color_number2: the color number for lines, symbols, error bars and filling for the second-type spins (should be less than the - parameter max_num_agr_colors defined below) + parameter MAX_NUM_AGR_COLORS defined below) :param legend: the legend (applied only to the first set) :param title: the title :param y_max_lim: the maximum on the y axis (if None, put the @@ -1130,19 +1121,21 @@ def _prepare_agr(self, :param prettify_format: if None, use the default prettify format. Otherwise specify a string with the prettifier to use. """ + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unused-argument if prettify_format is None: # Default. 
Specified like this to allow caller functions to pass 'None' prettify_format = 'agr_seekpath' plot_info = self._get_bandplot_data( - cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin) + cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin + ) import math # load the x and y of every set - if color_number > max_num_agr_colors: - raise ValueError('Color number is too high (should be less than {})'.format(max_num_agr_colors)) - if color_number2 > max_num_agr_colors: - raise ValueError('Color number 2 is too high (should be less than {})'.format(max_num_agr_colors)) + if color_number > MAX_NUM_AGR_COLORS: + raise ValueError('Color number is too high (should be less than {})'.format(MAX_NUM_AGR_COLORS)) + if color_number2 > MAX_NUM_AGR_COLORS: + raise ValueError('Color number 2 is too high (should be less than {})'.format(MAX_NUM_AGR_COLORS)) bands = plot_info['y'] x = plot_info['x'] @@ -1161,22 +1154,22 @@ def _prepare_agr(self, # prepare xticks labels sx1 = '' - for i, l in enumerate(labels): - sx1 += agr_single_xtick_template.substitute( + for i, label in enumerate(labels): + sx1 += AGR_SINGLE_XTICK_TEMPLATE.substitute( index=i, - coord=l[0], - name=l[1], + coord=label[0], + name=label[1], ) - xticks = agr_xticks_template.substitute( + xticks = AGR_XTICKS_TEMPLATE.substitute( num_labels=num_labels, single_xtick_templates=sx1, ) # build the arrays with the xy coordinates all_sets = [] - for b in the_bands: + for band in the_bands: this_set = '' - for i in zip(x, b): + for i in zip(x, band): line = '{:.8f}'.format(i[0]) + '\t' + '{:.8f}'.format(i[1]) + '\n' this_set += line all_sets.append(this_set) @@ -1188,15 +1181,16 @@ def _prepare_agr(self, else: linecolor = color_number2 width = str(2.0) - set_descriptions += agr_set_description_template.substitute( + set_descriptions += AGR_SET_DESCRIPTION_TEMPLATE.substitute( set_number=i + setnumber_offset, linewidth=width, color_number=linecolor, - legend=legend if i == 0 else '') + legend=legend if i == 0 else '' + ) units = self.units - graphs = agr_graph_template.substitute( + graphs = AGR_GRAPH_TEMPLATE.substitute( x_min_lim=x_min_lim, y_min_lim=y_min_lim, x_max_lim=x_max_lim, @@ -1209,19 +1203,21 @@ def _prepare_agr(self, ) sets = [] for i, this_set in enumerate(all_sets): - sets.append(agr_singleset_template.substitute(set_number=i + setnumber_offset, xydata=this_set)) + sets.append(AGR_SINGLESET_TEMPLATE.substitute(set_number=i + setnumber_offset, xydata=this_set)) the_sets = '&\n'.join(sets) - s = agr_template.substitute(graphs=graphs, sets=the_sets) + string = AGR_TEMPLATE.substitute(graphs=graphs, sets=the_sets) if comments: - s = prepare_header_comment(self.uuid, plot_info, comment_char='#') + '\n' + s + string = prepare_header_comment(self.uuid, plot_info, comment_char='#') + '\n' + string - return s.encode('utf-8'), {} + return string.encode('utf-8'), {} def _get_band_segments(self, cartesian): + """Return the band segments.""" plot_info = self._get_bandplot_data( - cartesian=cartesian, prettify_format=None, join_symbol=None, get_segments=True) + cartesian=cartesian, prettify_format=None, join_symbol=None, get_segments=True + ) out_dict = {'label': self.label} @@ -1230,7 +1226,7 @@ def _get_band_segments(self, cartesian): return out_dict - def _prepare_json(self, main_file_name='', comments=True): + def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=unused-argument """ Prepare a json file in a format compatible with the AiiDA band visualizer @@ 
-1249,9 +1245,10 @@ def _prepare_json(self, main_file_name='', comments=True): return json.dumps(json_dict).encode('utf-8'), {} -max_num_agr_colors = 15 +MAX_NUM_AGR_COLORS = 15 -agr_template = Template(""" +AGR_TEMPLATE = Template( + """ # Grace project file # @version 50122 @@ -1376,19 +1373,23 @@ def _prepare_json(self, main_file_name='', comments=True): @r4 line 0, 0, 0, 0 $graphs $sets - """) + """ +) -agr_xticks_template = Template(""" +AGR_XTICKS_TEMPLATE = Template(""" @ xaxis tick spec $num_labels $single_xtick_templates """) -agr_single_xtick_template = Template(""" +AGR_SINGLE_XTICK_TEMPLATE = Template( + """ @ xaxis tick major $index, $coord @ xaxis ticklabel $index, "$name" - """) + """ +) -agr_graph_template = Template(""" +AGR_GRAPH_TEMPLATE = Template( + """ @g0 on @g0 hidden false @g0 type XY @@ -1545,9 +1546,11 @@ def _prepare_json(self, main_file_name='', comments=True): @ frame background color 0 @ frame background pattern 0 $set_descriptions - """) + """ +) -agr_set_description_template = Template(""" +AGR_SET_DESCRIPTION_TEMPLATE = Template( + """ @ s$set_number hidden false @ s$set_number type xy @ s$set_number symbol 0 @@ -1597,9 +1600,10 @@ def _prepare_json(self, main_file_name='', comments=True): @ s$set_number errorbar riser clip length 0.100000 @ s$set_number comment "Cols 1:2" @ s$set_number legend "$legend" - """) + """ +) -agr_singleset_template = Template(""" +AGR_SINGLESET_TEMPLATE = Template(""" @target G0.S$set_number @type xy $xydata @@ -1608,7 +1612,8 @@ def _prepare_json(self, main_file_name='', comments=True): # text.latex.preview=True is needed to have a proper alignment of # tick marks with and without subscripts # see e.g. http://matplotlib.org/1.3.0/examples/pylab_examples/usetex_baseline_test.html -matplotlib_header_agg_template = Template('''# -*- coding: utf-8 -*- +MATPLOTLIB_HEADER_AGG_TEMPLATE = Template( + """# -*- coding: utf-8 -*- import matplotlib matplotlib.use('Agg') @@ -1630,12 +1635,14 @@ def _prepare_json(self, main_file_name='', comments=True): import json print_comment = False -''') +""" +) # text.latex.preview=True is needed to have a proper alignment of # tick marks with and without subscripts # see e.g. 
http://matplotlib.org/1.3.0/examples/pylab_examples/usetex_baseline_test.html -matplotlib_header_template = Template('''# -*- coding: utf-8 -*- +MATPLOTLIB_HEADER_TEMPLATE = Template( + """# -*- coding: utf-8 -*- from matplotlib import rc # Uncomment to change default font @@ -1654,16 +1661,19 @@ def _prepare_json(self, main_file_name='', comments=True): import json print_comment = False -''') +""" +) -matplotlib_import_data_inline_template = Template('''all_data_str = r"""$all_data_json""" +MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE = Template('''all_data_str = r"""$all_data_json""" ''') -matplotlib_import_data_fromfile_template = Template('''with open("$json_fname", encoding='utf8') as f: +MATPLOTLIB_IMPORT_DATA_FROMFILE_TEMPLATE = Template( + """with open("$json_fname", encoding='utf8') as f: all_data_str = f.read() -''') +""" +) -multi_kp = ''' +MULTI_KP = """ for path in paths: if path['length'] <= 1: # Avoid printing empty lines @@ -1690,16 +1700,17 @@ def _prepare_json(self, main_file_name='', comments=True): p.plot(x, band, label=label, **further_plot_options ) -''' +""" -single_kp = ''' +SINGLE_KP = """ path = paths[0] values = path['values'] x = [path['x'] for _ in values] p.scatter(x, values, marker="_") -''' +""" -matplotlib_body_template = Template('''all_data = json.loads(all_data_str) +MATPLOTLIB_BODY_TEMPLATE = Template( + """all_data = json.loads(all_data_str) if not all_data.get('use_latex', False): rc('text', usetex=False) @@ -1779,13 +1790,11 @@ def _prepare_json(self, main_file_name='', comments=True): print(all_data['comment']) except KeyError: pass -''') +""" +) -matplotlib_footer_template_show = Template('''pl.show() -''') +MATPLOTLIB_FOOTER_TEMPLATE_SHOW = Template("""pl.show()""") -matplotlib_footer_template_exportfile = Template('''pl.savefig("$fname", format="$format") -''') +MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE = Template("""pl.savefig("$fname", format="$format")""") -matplotlib_footer_template_exportfile_with_dpi = Template('''pl.savefig("$fname", format="$format", dpi=$dpi) -''') +MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI = Template("""pl.savefig("$fname", format="$format", dpi=$dpi)""") diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py index e02eae0929..68e91c7b40 100644 --- a/aiida/orm/nodes/data/array/kpoints.py +++ b/aiida/orm/nodes/data/array/kpoints.py @@ -12,7 +12,6 @@ lists and meshes of k-points (i.e., points in the reciprocal space of a periodic crystal structure). """ - import numpy from .array import ArrayData diff --git a/aiida/orm/nodes/data/array/projection.py b/aiida/orm/nodes/data/array/projection.py index 30c3a2aed9..9e81fc8c5d 100644 --- a/aiida/orm/nodes/data/array/projection.py +++ b/aiida/orm/nodes/data/array/projection.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Data plugin to represet arrays of projected wavefunction components.""" import copy import numpy as np @@ -27,6 +27,7 @@ class ProjectionData(OrbitalData, ArrayData): s, n, and k. E.g. 
the elements are the projections described as < orbital | Bloch wavefunction (s,n,k) > """ + def _check_projections_bands(self, projection_array): """ Checks to make sure that a reference bandsdata is already set, and that @@ -46,9 +47,7 @@ def _check_projections_bands(self, projection_array): # The [0:2] is so that each array, and not collection of arrays # is used to make the comparison if np.shape(projection_array) != shape_bands: - raise AttributeError('These arrays are not the same shape as' - ' the bands') - return None + raise AttributeError('These arrays are not the same shape as' ' the bands') def set_reference_bandsdata(self, value): """ @@ -65,17 +64,19 @@ def set_reference_bandsdata(self, value): else: try: pk = int(value) - bands = load_node(pk=pk, type=BandsData) + bands = load_node(pk=pk) uuid = bands.uuid except ValueError: uuid = str(value) try: - bands = load_node(uuid=uuid, type=BandsData) + bands = load_node(uuid=uuid) uuid = bands.uuid - except : - raise exceptions.NotExistent('The value passed to ' - 'set_reference_bandsdata was not ' - 'associated to any bandsdata') + except Exception: # pylint: disable=bare-except + raise exceptions.NotExistent( + 'The value passed to ' + 'set_reference_bandsdata was not ' + 'associated to any bandsdata' + ) self.set_attribute('reference_bandsdata_uuid', uuid) @@ -94,12 +95,9 @@ def get_reference_bandsdata(self): except AttributeError: raise AttributeError('BandsData has not been set for this instance') try: - #bands = load_node(uuid=uuid, type=BandsData) - bands = load_node(uuid=uuid) #TODO switch to above once type - # has been implemented for load_node + bands = load_node(uuid=uuid) except exceptions.NotExistent: - raise exceptions.NotExistent('The bands referenced to this class have not been ' - 'found in this database.') + raise exceptions.NotExistent('The bands referenced to this class have not been found in this database.') return bands def _find_orbitals_and_indices(self, **kwargs): @@ -112,16 +110,12 @@ def _find_orbitals_and_indices(self, **kwargs): to the kwargs :return: all_orbitals, list of orbitals to which the indexes correspond """ - # index_and_orbitals = self._get_orbitals_and_index() - index_and_orbitals = [] selected_orbitals = self.get_orbitals(**kwargs) - selected_orb_dicts = [orb.get_orbital_dict() for orb - in selected_orbitals] + selected_orb_dicts = [orb.get_orbital_dict() for orb in selected_orbitals] all_orbitals = self.get_orbitals() all_orb_dicts = [orb.get_orbital_dict() for orb in all_orbitals] - retrieve_indices = [i for i in range(len(all_orb_dicts)) - if all_orb_dicts[i] in selected_orb_dicts] - return retrieve_indices, all_orbitals + retrieve_indices = [i for i in range(len(all_orb_dicts)) if all_orb_dicts[i] in selected_orb_dicts] + return retrieve_indices, all_orbitals def get_pdos(self, **kwargs): """ @@ -134,12 +128,10 @@ def get_pdos(self, **kwargs): """ retrieve_indices, all_orbitals = self._find_orbitals_and_indices(**kwargs) - out_list = [(all_orbitals[i], - self.get_array('pdos_{}'.format( - self._from_index_to_arrayname(i))), - self.get_array('energy_{}'.format( - self._from_index_to_arrayname(i))) ) - for i in retrieve_indices] + out_list = [( + all_orbitals[i], self.get_array('pdos_{}'.format(self._from_index_to_arrayname(i))), + self.get_array('energy_{}'.format(self._from_index_to_arrayname(i))) + ) for i in retrieve_indices] return out_list def get_projections(self, **kwargs): @@ -153,21 +145,26 @@ def get_projections(self, **kwargs): """ retrieve_indices, all_orbitals = 
self._find_orbitals_and_indices(**kwargs) - out_list = [(all_orbitals[i], - self.get_array('proj_{}'.format( - self._from_index_to_arrayname(i)))) + out_list = [(all_orbitals[i], self.get_array('proj_{}'.format(self._from_index_to_arrayname(i)))) for i in retrieve_indices] return out_list - def _from_index_to_arrayname(self, index): + @staticmethod + def _from_index_to_arrayname(index): """ Used internally to determine the array names. """ return 'array_{}'.format(index) - def set_projectiondata(self,list_of_orbitals, list_of_projections=None, - list_of_energy=None, list_of_pdos=None, - tags = None, bands_check=True): + def set_projectiondata( + self, + list_of_orbitals, + list_of_projections=None, + list_of_energy=None, + list_of_pdos=None, + tags=None, + bands_check=True + ): """ Stores the projwfc_array using the projwfc_label, after validating both. @@ -196,6 +193,9 @@ def set_projectiondata(self,list_of_orbitals, list_of_projections=None, been stored and therefore get_reference_bandsdata cannot be called """ + + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + def single_to_list(item): """ Checks if the item is a list or tuple, and converts it to a list @@ -207,8 +207,8 @@ def single_to_list(item): """ if isinstance(item, (list, tuple)): return item - else: - return [item] + + return [item] def array_list_checker(array_list, array_name, orb_length): """ @@ -217,12 +217,13 @@ def array_list_checker(array_list, array_name, orb_length): required_length, raises exception using array_name if there is a failure """ - if not all([isinstance(_,np.ndarray) for _ in array_list]): - raise exceptions.ValidationError('{} was not composed ' - 'entirely of ndarrays'.format(array_name)) + if not all([isinstance(_, np.ndarray) for _ in array_list]): + raise exceptions.ValidationError('{} was not composed entirely of ndarrays'.format(array_name)) if len(array_list) != orb_length: - raise exceptions.ValidationError('{} did not have the same length as the ' - 'list of orbitals'.format(array_name)) + raise exceptions.ValidationError( + '{} did not have the same length as the ' + 'list of orbitals'.format(array_name) + ) ############## list_of_orbitals = single_to_list(list_of_orbitals) @@ -232,22 +233,21 @@ def array_list_checker(array_list, array_name, orb_length): if not list_of_pdos and not list_of_projections: raise exceptions.ValidationError('Must set either pdos or projections') if bool(list_of_energy) != bool(list_of_pdos): - raise exceptions.ValidationError('list_of_pdos and list_of_energy must always ' - 'be set together') + raise exceptions.ValidationError('list_of_pdos and list_of_energy must always be set together') orb_length = len(list_of_orbitals) # verifies and sets the orbital dicts list_of_orbital_dicts = [] - for i in range(len(list_of_orbitals)): + for i, _ in enumerate(list_of_orbitals): this_orbital = list_of_orbitals[i] orbital_dict = this_orbital.get_orbital_dict() try: orbital_type = orbital_dict.pop('_orbital_type') except KeyError: - raise ValidationError('No _orbital_type key found in dictionary: {}'.format(orbital_dict)) - OrbitalClass = OrbitalFactory(orbital_type) - test_orbital = OrbitalClass(**orbital_dict) + raise exceptions.ValidationError('No _orbital_type key found in dictionary: {}'.format(orbital_dict)) + cls = OrbitalFactory(orbital_type) + test_orbital = cls(**orbital_dict) list_of_orbital_dicts.append(test_orbital.get_orbital_dict()) self.set_attribute('orbital_dicts', list_of_orbital_dicts) @@ -255,7 +255,7 @@ def 
array_list_checker(array_list, array_name, orb_length): if list_of_projections: list_of_projections = single_to_list(list_of_projections) array_list_checker(list_of_projections, 'projections', orb_length) - for i in range(len(list_of_projections)): + for i, _ in enumerate(list_of_projections): this_projection = list_of_projections[i] array_name = self._from_index_to_arrayname(i) if bands_check: @@ -268,7 +268,7 @@ def array_list_checker(array_list, array_name, orb_length): list_of_energy = single_to_list(list_of_energy) array_list_checker(list_of_pdos, 'pdos', orb_length) array_list_checker(list_of_energy, 'energy', orb_length) - for i in range(len(list_of_pdos)): + for i, _ in enumerate(list_of_pdos): this_pdos = list_of_pdos[i] this_energy = list_of_energy[i] array_name = self._from_index_to_arrayname(i) @@ -285,15 +285,17 @@ def array_list_checker(array_list, array_name, orb_length): except IndexError: return exceptions.ValidationError('tags must be a list') - if not all([isinstance(_,str) for _ in tags]): + if not all([isinstance(_, str) for _ in tags]): raise exceptions.ValidationError('Tags must set a list of strings') self.set_attribute('tags', tags) - def set_orbitals(self, **kwargs): + def set_orbitals(self, **kwargs): # pylint: disable=arguments-differ """ This method is inherited from OrbitalData, but is blocked here. If used will raise a NotImplementedError """ - raise NotImplementedError('You cannot set orbitals using this class!' - ' This class is for setting orbitals and ' - ' projections only!') + raise NotImplementedError( + 'You cannot set orbitals using this class!' + ' This class is for setting orbitals and ' + ' projections only!' + ) diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py index a3d2674320..ecc0b3ee8f 100644 --- a/aiida/orm/nodes/data/array/xy.py +++ b/aiida/orm/nodes/data/array/xy.py @@ -12,8 +12,6 @@ collections of y-arrays bound to a single x-array, and the methods to operate on them. """ - - import numpy as np from aiida.common.exceptions import InputValidationError, NotExistent from .array import ArrayData @@ -30,8 +28,8 @@ def check_convert_single_to_tuple(item): """ if isinstance(item, (list, tuple)): return item - else: - return [item] + + return [item] class XyData(ArrayData): @@ -40,7 +38,9 @@ class XyData(ArrayData): each other. That is there is one array, the X array, and there are several Y arrays, which can be considered functions of X. """ - def _arrayandname_validator(self, array, name, units): + + @staticmethod + def _arrayandname_validator(array, name, units): """ Validates that the array is an numpy.ndarray and that the name is of type str. Raises InputValidationError if this not the case. @@ -86,8 +86,7 @@ def set_y(self, y_arrays, y_names, y_units): # checks that the input lengths match if len(y_arrays) != len(y_names): - raise InputValidationError('Length of arrays and names do not ' - 'match!') + raise InputValidationError('Length of arrays and names do not match!') if len(y_units) != len(y_names): raise InputValidationError('Length of units does not match!') @@ -100,9 +99,11 @@ def set_y(self, y_arrays, y_names, y_units): for num, (y_array, y_name, y_unit) in enumerate(zip(y_arrays, y_names, y_units)): self._arrayandname_validator(y_array, y_name, y_unit) if np.shape(y_array) != np.shape(x_array): - raise InputValidationError('y_array {} did not have the ' - 'same shape has the x_array!' - ''.format(y_name)) + raise InputValidationError( + 'y_array {} did not have the ' + 'same shape has the x_array!' 
+ ''.format(y_name) + ) self.set_array('y_array_{}'.format(num), y_array) # if the y_arrays pass the initial validation, sets each @@ -147,6 +148,5 @@ def get_y(self): for i in range(len(y_names)): y_arrays += [self.get_array('y_array_{}'.format(i))] except (KeyError, AttributeError): - raise NotExistent('Could not retrieve array associated with y array' - ' {}'.format(y_names[i])) + raise NotExistent('Could not retrieve array associated with y array {}'.format(y_names[i])) return list(zip(y_names, y_arrays, y_units)) diff --git a/aiida/orm/nodes/data/code.py b/aiida/orm/nodes/data/code.py index 4766e15c91..b47105fb0d 100644 --- a/aiida/orm/nodes/data/code.py +++ b/aiida/orm/nodes/data/code.py @@ -7,9 +7,10 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Data plugin represeting an executable code to be wrapped and called through a `CalcJob` plugin.""" import os -from aiida.common.exceptions import ValidationError, EntryPointError, InputValidationError +from aiida.common import exceptions from .data import Data __all__ = ('Code',) @@ -32,6 +33,8 @@ class Code(Data): for the code to be run). """ + # pylint: disable=too-many-public-methods + def __init__(self, remote_computer_exec=None, local_executable=None, input_plugin_name=None, files=None, **kwargs): super().__init__(**kwargs) @@ -132,9 +135,9 @@ def label(self, value): """ if '@' in str(value): msg = "Code labels must not contain the '@' symbol" - raise InputValidationError(msg) + raise exceptions.InputValidationError(msg) - super(Code, self.__class__).label.fset(self, value) + super(Code, self.__class__).label.fset(self, value) # pylint: disable=no-member def relabel(self, new_label, raise_error=True): """Relabel this code. @@ -146,6 +149,7 @@ def relabel(self, new_label, raise_error=True): .. deprecated:: 1.2.0 Will remove raise_error in `v2.0.0`. Use `try/except` instead. 
""" + # pylint: disable=unused-argument suffix = '@{}'.format(self.get_computer_name()) if new_label.endswith(suffix): new_label = new_label[:-len(suffix)] @@ -173,21 +177,21 @@ def get_code_helper(cls, label, machinename=None): from aiida.orm.querybuilder import QueryBuilder from aiida.orm.computers import Computer - qb = QueryBuilder() - qb.append(cls, filters={'label': {'==': label}}, project=['*'], tag='code') + query = QueryBuilder() + query.append(cls, filters={'label': label}, project='*', tag='code') if machinename: - qb.append(Computer, filters={'name': {'==': machinename}}, with_node='code') + query.append(Computer, filters={'name': machinename}, with_node='code') - if qb.count() == 0: + if query.count() == 0: raise NotExistent("'{}' is not a valid code name.".format(label)) - elif qb.count() > 1: - codes = qb.all(flat=True) + elif query.count() > 1: + codes = query.all(flat=True) retstr = ("There are multiple codes with label '{}', having IDs: ".format(label)) retstr += ', '.join(sorted([str(c.pk) for c in codes])) + '.\n' retstr += ('Relabel them (using their ID), or refer to them with their ID.') raise MultipleObjectsError(retstr) else: - return qb.first()[0] + return query.first()[0] @classmethod def get(cls, pk=None, label=None, machinename=None): @@ -200,11 +204,10 @@ def get(cls, pk=None, label=None, machinename=None): :param machinename: the machine name where code is setup :raise aiida.common.NotExistent: if no code identified by the given string is found - :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely - a code + :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely a code :raise aiida.common.InputValidationError: if neither a pk nor a label was passed in """ - from aiida.common.exceptions import (NotExistent, MultipleObjectsError, InputValidationError) + # pylint: disable=arguments-differ from aiida.orm.utils import load_code # first check if code pk is provided @@ -212,17 +215,17 @@ def get(cls, pk=None, label=None, machinename=None): code_int = int(pk) try: return load_code(pk=code_int) - except NotExistent: + except exceptions.NotExistent: raise ValueError('{} is not valid code pk'.format(pk)) - except MultipleObjectsError: - raise MultipleObjectsError("More than one code in the DB with pk='{}'!".format(pk)) + except exceptions.MultipleObjectsError: + raise exceptions.MultipleObjectsError("More than one code in the DB with pk='{}'!".format(pk)) # check if label (and machinename) is provided elif label is not None: return cls.get_code_helper(label, machinename) else: - raise InputValidationError('Pass either pk or code label (and machinename)') + raise exceptions.InputValidationError('Pass either pk or code label (and machinename)') @classmethod def get_from_string(cls, code_string): @@ -247,8 +250,8 @@ def get_from_string(cls, code_string): from aiida.common.exceptions import NotExistent, MultipleObjectsError, InputValidationError try: - label, sep, machinename = code_string.partition('@') - except AttributeError as exception: + label, _, machinename = code_string.partition('@') + except AttributeError: raise InputValidationError('the provided code_string is not of valid string type') try: @@ -270,35 +273,39 @@ def list_for_plugin(cls, plugin, labels=True): otherwise a list of integers with the code PKs. 
""" from aiida.orm.querybuilder import QueryBuilder - qb = QueryBuilder() - qb.append(cls, filters={'attributes.input_plugin': {'==': plugin}}) - valid_codes = qb.all(flat=True) + query = QueryBuilder() + query.append(cls, filters={'attributes.input_plugin': {'==': plugin}}) + valid_codes = query.all(flat=True) if labels: return [c.label for c in valid_codes] - else: - return [c.pk for c in valid_codes] + + return [c.pk for c in valid_codes] def _validate(self): super()._validate() if self.is_local() is None: - raise ValidationError('You did not set whether the code is local or remote') + raise exceptions.ValidationError('You did not set whether the code is local or remote') if self.is_local(): if not self.get_local_executable(): - raise ValidationError('You have to set which file is the local executable ' - 'using the set_exec_filename() method') + raise exceptions.ValidationError( + 'You have to set which file is the local executable ' + 'using the set_exec_filename() method' + ) if self.get_local_executable() not in self.list_object_names(): - raise ValidationError("The local executable '{}' is not in the list of " - 'files of this code'.format(self.get_local_executable())) + raise exceptions.ValidationError( + "The local executable '{}' is not in the list of " + 'files of this code'.format(self.get_local_executable()) + ) else: if self.list_object_names(): - raise ValidationError('The code is remote but it has files inside') + raise exceptions.ValidationError('The code is remote but it has files inside') if not self.get_remote_computer(): - raise ValidationError('You did not specify a remote computer') + raise exceptions.ValidationError('You did not specify a remote computer') if not self.get_remote_exec_path(): - raise ValidationError('You did not specify a remote executable') + raise exceptions.ValidationError('You did not specify a remote executable') def set_prepend_text(self, code): """ @@ -367,9 +374,11 @@ def set_remote_computer_exec(self, remote_computer_exec): from aiida.common.lang import type_check if (not isinstance(remote_computer_exec, (list, tuple)) or len(remote_computer_exec) != 2): - raise ValueError('remote_computer_exec must be a list or tuple ' - 'of length 2, with machine and executable ' - 'name') + raise ValueError( + 'remote_computer_exec must be a list or tuple ' + 'of length 2, with machine and executable ' + 'name' + ) computer, remote_exec_path = tuple(remote_computer_exec) @@ -455,8 +464,8 @@ def get_execname(self): """ if self.is_local(): return './{}'.format(self.get_local_executable()) - else: - return self.get_remote_exec_path() + + return self.get_remote_exec_path() def get_builder(self): """Create and return a new `ProcessBuilder` for the `CalcJob` class of the plugin configured for this code. 
@@ -478,8 +487,8 @@ def get_builder(self): try: process_class = CalculationFactory(plugin_name) - except EntryPointError: - raise EntryPointError('the calculation entry point `{}` could not be loaded'.format(plugin_name)) + except exceptions.EntryPointError: + raise exceptions.EntryPointError('the calculation entry point `{}` could not be loaded'.format(plugin_name)) builder = process_class.get_builder() builder.code = self @@ -532,14 +541,3 @@ def get_full_text_info(self, verbose=False): result.append(['Append text', 'No append text']) return result - - @classmethod - def setup(cls, **kwargs): - from aiida.cmdline.commands.code import CodeInputValidationClass - code = CodeInputValidationClass().set_and_validate_from_code(kwargs) - - try: - code.store() - except ValidationError as exc: - raise ValidationError('Unable to store the computer: {}.'.format(exc)) - return code diff --git a/aiida/orm/nodes/data/orbital.py b/aiida/orm/nodes/data/orbital.py index 195d0b331a..cc97cc2ae2 100644 --- a/aiida/orm/nodes/data/orbital.py +++ b/aiida/orm/nodes/data/orbital.py @@ -7,13 +7,12 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Data plugin to model an atomic orbital.""" import copy -from .data import Data -from .structure import Site as site_class from aiida.common.exceptions import ValidationError, InputValidationError from aiida.plugins import OrbitalFactory +from .data import Data __all__ = ('OrbitalData',) @@ -51,8 +50,7 @@ def get_orbitals(self, **kwargs): filter_dict.update(kwargs) # prevents KeyError from occuring orbital_dicts = [x for x in orbital_dicts if all([y in x for y in filter_dict])] - orbital_dicts = [x for x in orbital_dicts if - all([x[y] == filter_dict[y] for y in filter_dict])] + orbital_dicts = [x for x in orbital_dicts if all([x[y] == filter_dict[y] for y in filter_dict])] list_of_outputs = [] for orbital_dict in orbital_dicts: @@ -61,8 +59,8 @@ def get_orbitals(self, **kwargs): except KeyError: raise ValidationError('No _orbital_type found in: {}'.format(orbital_dict)) - OrbitalClass = OrbitalFactory(orbital_type) - orbital = OrbitalClass(**orbital_dict) + cls = OrbitalFactory(orbital_type) + orbital = cls(**orbital_dict) list_of_outputs.append(orbital) return list_of_outputs @@ -86,6 +84,7 @@ def set_orbitals(self, orbitals): orbital_dicts.append(orbital_dict) self.set_attribute('orbital_dicts', orbital_dicts) + ########################################################################## # Here are some ideas for potential future convenience methods ######################################################################### diff --git a/aiida/orm/nodes/data/remote.py b/aiida/orm/nodes/data/remote.py index 03e6bf6453..4c729e009a 100644 --- a/aiida/orm/nodes/data/remote.py +++ b/aiida/orm/nodes/data/remote.py @@ -7,10 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Data plugin that models a folder on a remote computer.""" import os -from .data import Data from aiida.orm import AuthInfo +from .data import Data __all__ = ('RemoteData',) @@ -61,19 +62,19 @@ def getfile(self, relpath, destpath): :param destpath: The absolute path of where to store the file on the local machine. 
""" authinfo = self.get_authinfo() - t = authinfo.get_transport() - with t: + with authinfo.get_transport() as transport: try: full_path = os.path.join(self.get_remote_path(), relpath) - t.getfile(full_path, destpath) - except IOError as e: - if e.errno == 2: # file not existing - raise IOError('The required remote file {} on {} does not exist or has been deleted.'.format( - full_path, self.computer.name - )) - else: - raise + transport.getfile(full_path, destpath) + except IOError as exception: + if exception.errno == 2: # file does not exist + raise IOError( + 'The required remote file {} on {} does not exist or has been deleted.'.format( + full_path, self.computer.name + ) + ) + raise def listdir(self, relpath='.'): """ @@ -83,32 +84,31 @@ def listdir(self, relpath='.'): :return: a flat list of file/directory names (as strings). """ authinfo = self.get_authinfo() - t = authinfo.get_transport() - with t: + with authinfo.get_transport() as transport: try: full_path = os.path.join(self.get_remote_path(), relpath) - t.chdir(full_path) - except IOError as e: - if e.errno == 2 or e.errno == 20: # directory not existing or not a directory + transport.chdir(full_path) + except IOError as exception: + if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'.format( - full_path, self.computer.name - )) - exc.errno = e.errno + 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. + format(full_path, self.computer.name) + ) + exc.errno = exception.errno raise exc else: raise try: - return t.listdir() - except IOError as e: - if e.errno == 2 or e.errno == 20: # directory not existing or not a directory + return transport.listdir() + except IOError as exception: + if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'.format( - full_path, self.computer.name - )) - exc.errno = e.errno + 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. + format(full_path, self.computer.name) + ) + exc.errno = exception.errno raise exc else: raise @@ -121,32 +121,31 @@ def listdir_withattributes(self, path='.'): :return: a list of dictionaries, where the documentation is in :py:class:Transport.listdir_withattributes. """ authinfo = self.get_authinfo() - t = authinfo.get_transport() - with t: + with authinfo.get_transport() as transport: try: full_path = os.path.join(self.get_remote_path(), path) - t.chdir(full_path) - except IOError as e: - if e.errno == 2 or e.errno == 20: # directory not existing or not a directory + transport.chdir(full_path) + except IOError as exception: + if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'.format( - full_path, self.computer.name - )) - exc.errno = e.errno + 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. 
+ format(full_path, self.computer.name) + ) + exc.errno = exception.errno raise exc else: raise try: - return t.listdir_withattributes() - except IOError as e: - if e.errno == 2 or e.errno == 20: # directory not existing or not a directory + return transport.listdir_withattributes() + except IOError as exception: + if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'.format( - full_path, self.computer.name - )) - exc.errno = e.errno + 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. + format(full_path, self.computer.name) + ) + exc.errno = exception.errno raise exc else: raise diff --git a/aiida/orm/nodes/data/structure.py b/aiida/orm/nodes/data/structure.py index b4a3682c70..4f4e179571 100644 --- a/aiida/orm/nodes/data/structure.py +++ b/aiida/orm/nodes/data/structure.py @@ -7,29 +7,28 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines """ This module defines the classes for structures and all related functions to operate on them. """ - -import itertools import copy -from functools import reduce - +import functools +import itertools -from .data import Data from aiida.common.constants import elements from aiida.common.exceptions import UnsupportedSpeciesError +from .data import Data __all__ = ('StructureData', 'Kind', 'Site') # Threshold used to check if the mass of two different Site objects is the same. -_mass_threshold = 1.e-3 +_MASS_THRESHOLD = 1.e-3 # Threshold to check if the sum is one or not -_sum_threshold = 1.e-6 +_SUM_THRESHOLD = 1.e-6 # Threshold used to check if the cell volume is not zero. -_volume_threshold = 1.e-6 +_VOLUME_THRESHOLD = 1.e-6 _valid_symbols = tuple(i['symbol'] for i in elements.values()) _atomic_masses = {el['symbol']: el['mass'] for el in elements.values()} @@ -51,7 +50,7 @@ def _get_valid_cell(inputcell): except (IndexError, ValueError, TypeError): raise ValueError('Cell must be a list of three vectors, each defined as a list of three coordinates.') - if abs(calc_cell_volume(the_cell)) < _volume_threshold: + if abs(calc_cell_volume(the_cell)) < _VOLUME_THRESHOLD: raise ValueError('The cell volume is zero. Invalid cell.') return the_cell @@ -66,7 +65,7 @@ def get_valid_pbc(inputpbc): """ if isinstance(inputpbc, bool): the_pbc = (inputpbc, inputpbc, inputpbc) - elif (hasattr(inputpbc, '__iter__')): + elif hasattr(inputpbc, '__iter__'): # To manage numpy lists of bools, whose elements are of type numpy.bool_ # and for which isinstance(i,bool) return False... if hasattr(inputpbc, 'tolist'): @@ -93,7 +92,7 @@ def has_ase(): :return: True if the ase module can be imported, False otherwise. """ try: - import ase + import ase # pylint: disable=unused-import except ImportError: return False return True @@ -104,7 +103,7 @@ def has_pymatgen(): :return: True if the pymatgen module can be imported, False otherwise. """ try: - import pymatgen + import pymatgen # pylint: disable=unused-import except ImportError: return False return True @@ -125,7 +124,7 @@ def has_spglib(): :return: True if the spglib module can be imported, False otherwise. 
""" try: - import spglib + import spglib # pylint: disable=unused-import except ImportError: return False return True @@ -143,6 +142,7 @@ def calc_cell_volume(cell): :returns: the cell volume. """ + # pylint: disable=invalid-name # returns the volume of the primitive cell: |a1.(a2xa3)| a1 = cell[0] a2 = cell[1] @@ -243,8 +243,10 @@ def validate_symbols_tuple(symbols_tuple): else: valid = all(is_valid_symbol(sym) for sym in symbols_tuple) if not valid: - raise UnsupportedSpeciesError('At least one element of the symbol list {} has ' - 'not been recognized.'.format(symbols_tuple)) + raise UnsupportedSpeciesError( + 'At least one element of the symbol list {} has ' + 'not been recognized.'.format(symbols_tuple) + ) def is_ase_atoms(ase_atoms): @@ -256,10 +258,7 @@ def is_ase_atoms(ase_atoms): Requires the ability to import ase, by doing 'import ase'. """ - # TODO: Check if we want to try to import ase and do something - # reasonable depending on whether ase is there or not. import ase - return isinstance(ase_atoms, ase.Atoms) @@ -318,8 +317,9 @@ def get_formula_from_symbol_list(_list, separator=''): if isinstance(elem[1], str): list_str.append('{}{}'.format(elem[1], multiplicity_str)) elif elem[0] > 1: - list_str.append('({}){}'.format( - get_formula_from_symbol_list(elem[1], separator=separator), multiplicity_str)) + list_str.append( + '({}){}'.format(get_formula_from_symbol_list(elem[1], separator=separator), multiplicity_str) + ) else: list_str.append('{}{}'.format(get_formula_from_symbol_list(elem[1], separator=separator), multiplicity_str)) @@ -361,15 +361,15 @@ def group_together(_list, group_size, offset): the_list = copy.deepcopy(_list) the_list.reverse() grouped_list = [] - for i in range(offset): + for _ in range(offset): grouped_list.append([the_list.pop()]) while the_list: - l = [] - for i in range(group_size): + sub_list = [] + for _ in range(group_size): if the_list: - l.append(the_list.pop()) - grouped_list.append(l) + sub_list.append(the_list.pop()) + grouped_list.append(sub_list) return grouped_list @@ -402,10 +402,10 @@ def group_together_symbols(_list, group_size): the_symbol_list = copy.deepcopy(_list) has_grouped = False offset = 0 - while (not has_grouped) and (offset < group_size): + while not has_grouped and offset < group_size: grouped_list = group_together(the_symbol_list, group_size, offset) new_symbol_list = group_symbols(grouped_list) - if (len(new_symbol_list) < len(grouped_list)): + if len(new_symbol_list) < len(grouped_list): the_symbol_list = copy.deepcopy(new_symbol_list) the_symbol_list = cleanout_symbol_list(the_symbol_list) has_grouped = True @@ -423,10 +423,9 @@ def group_all_together_symbols(_list): """ has_finished = False group_size = 2 - n = len(_list) the_symbol_list = copy.deepcopy(_list) - while (not has_finished) and (group_size <= n // 2): + while not has_finished and group_size <= len(_list) // 2: # try to group as much as possible by groups of size group_size the_symbol_list, has_grouped = group_together_symbols(the_symbol_list, group_size) has_finished = has_grouped @@ -504,7 +503,7 @@ def get_formula(symbol_list, mode='hill', separator=''): # for hill and count cases, simply count the occurences of each # chemical symbol (with some re-ordering in hill) - elif mode in ['hill', 'hill_compact']: + if mode in ['hill', 'hill_compact']: symbol_set = set(symbol_list) first_symbols = [] if 'C' in symbol_set: @@ -527,11 +526,11 @@ def get_formula(symbol_list, mode='hill', separator=''): the_symbol_list = group_symbols(symbol_list) else: - raise 
ValueError('Mode should be hill, hill_compact, group, ' 'reduce, count or count_compact') + raise ValueError('Mode should be hill, hill_compact, group, reduce, count or count_compact') if mode in ['hill_compact', 'count_compact']: from math import gcd - the_gcd = reduce(gcd, [e[0] for e in the_symbol_list]) + the_gcd = functools.reduce(gcd, [e[0] for e in the_symbol_list]) the_symbol_list = [[e[0] // the_gcd, e[1]] for e in the_symbol_list] return get_formula_from_symbol_list(the_symbol_list, separator=separator) @@ -555,24 +554,24 @@ def get_symbols_string(symbols, weights): """ if len(symbols) == 1 and weights[0] == 1.: return symbols[0] - else: - pieces = [] - for s, w in zip(symbols, weights): - pieces.append('{}{:4.2f}'.format(s, w)) - if has_vacancies(weights): - pieces.append('X{:4.2f}'.format(1. - sum(weights))) - return '{{{}}}'.format(''.join(sorted(pieces))) + + pieces = [] + for symbol, weight in zip(symbols, weights): + pieces.append('{}{:4.2f}'.format(symbol, weight)) + if has_vacancies(weights): + pieces.append('X{:4.2f}'.format(1. - sum(weights))) + return '{{{}}}'.format(''.join(sorted(pieces))) def has_vacancies(weights): """ Returns True if the sum of the weights is less than one. - It uses the internal variable _sum_threshold as a threshold. + It uses the internal variable _SUM_THRESHOLD as a threshold. :param weights: the weights :return: a boolean """ w_sum = sum(weights) - return not (1. - w_sum < _sum_threshold) + return not 1. - w_sum < _SUM_THRESHOLD def symop_ortho_from_fract(cell): @@ -586,6 +585,7 @@ def symop_ortho_from_fract(cell): :param cell: array of cell parameters (three lengths and three angles) """ + # pylint: disable=invalid-name import math import numpy @@ -609,6 +609,7 @@ def symop_fract_from_ortho(cell): :param cell: array of cell parameters (three lengths and three angles) """ + # pylint: disable=invalid-name import math import numpy @@ -681,12 +682,12 @@ def atom_kinds_to_html(atom_kind): # Parse the formula (TODO can be made more robust though never fails if # it takes strings generated with kind.get_symbols_string()) import re - elements = re.findall(r'([A-Z][a-z]*)([0-1][.[0-9]*]?)?', atom_kind) + matched_elements = re.findall(r'([A-Z][a-z]*)([0-1][.[0-9]*]?)?', atom_kind) # Compose the html string html_formula_pieces = [] - for element in elements: + for element in matched_elements: # replace element X by 'vacancy' species = element[0] if element[0] != 'X' else 'vacancy' @@ -709,6 +710,9 @@ class StructureData(Data): boundary conditions (whether they are periodic or not) and other related useful information. 
""" + + # pylint: disable=too-many-public-methods + _set_incompatibilities = [('ase', 'cell'), ('ase', 'pbc'), ('ase', 'pymatgen'), ('ase', 'pymatgen_molecule'), ('ase', 'pymatgen_structure'), ('cell', 'pymatgen'), ('cell', 'pymatgen_molecule'), ('cell', 'pymatgen_structure'), ('pbc', 'pymatgen'), ('pbc', 'pymatgen_molecule'), @@ -716,9 +720,18 @@ class StructureData(Data): ('pymatgen', 'pymatgen_structure'), ('pymatgen_molecule', 'pymatgen_structure')] _dimensionality_label = {0: '', 1: 'length', 2: 'surface', 3: 'volume'} - - def __init__(self, cell=None, pbc=None, ase=None, pymatgen=None, pymatgen_structure=None, pymatgen_molecule=None, **kwargs): - + _internal_kind_tags = None + + def __init__( + self, + cell=None, + pbc=None, + ase=None, + pymatgen=None, + pymatgen_structure=None, + pymatgen_molecule=None, + **kwargs + ): # pylint: disable=too-many-arguments args = { 'cell': cell, 'pbc': pbc, @@ -779,8 +792,7 @@ def get_dimensionality(self): if dim == 0: pass elif dim == 1: - v = cell[pbc] - retdict['value'] = np.linalg.norm(v) + retdict['value'] = np.linalg.norm(cell[pbc]) elif dim == 2: vectors = cell[pbc] retdict['value'] = np.linalg.norm(np.cross(vectors[0], vectors[1])) @@ -830,8 +842,8 @@ def set_pymatgen_molecule(self, mol, margin=5): of earlier versions may cause errors). """ box = [ - max([x.coords.tolist()[0] for x in mol.sites]) - min([x.coords.tolist()[0] for x in mol.sites - ]) + 2 * margin, + max([x.coords.tolist()[0] for x in mol.sites]) - min([x.coords.tolist()[0] for x in mol.sites]) + + 2 * margin, max([x.coords.tolist()[1] for x in mol.sites]) - min([x.coords.tolist()[1] for x in mol.sites]) + 2 * margin, max([x.coords.tolist()[2] for x in mol.sites]) - min([x.coords.tolist()[2] for x in mol.sites]) + 2 * margin @@ -886,8 +898,8 @@ def build_kind_name(species_and_occu): kind_name += '2' return kind_name - else: - return None + + return None self.cell = struct.lattice.matrix.tolist() self.pbc = [True, True, True] @@ -910,7 +922,7 @@ def build_kind_name(species_and_occu): inputs = { 'symbols': [x.symbol for x in species_and_occu.keys()], - 'weights': [x for x in species_and_occu.values()], + 'weights': list(species_and_occu.values()), 'position': site.coords.tolist() } @@ -947,10 +959,12 @@ def _validate(self): from collections import Counter counts = Counter([k.name for k in kinds]) - for c in counts: - if counts[c] != 1: - raise ValidationError("Kind with name '{}' appears {} times " - 'instead of only one'.format(c, counts[c])) + for count in counts: + if counts[count] != 1: + raise ValidationError( + "Kind with name '{}' appears {} times " + 'instead of only one'.format(count, counts[count]) + ) try: # This will try to create the sites objects @@ -960,15 +974,19 @@ def _validate(self): for site in sites: if site.kind_name not in [k.name for k in kinds]: - raise ValidationError('A site has kind {}, but no specie with that name exists' - ''.format(site.kind_name)) + raise ValidationError( + 'A site has kind {}, but no specie with that name exists' + ''.format(site.kind_name) + ) kinds_without_sites = (set(k.name for k in kinds) - set(s.kind_name for s in sites)) if kinds_without_sites: - raise ValidationError('The following kinds are defined, but there ' - 'are no sites with that kind: {}'.format(list(kinds_without_sites))) + raise ValidationError( + 'The following kinds are defined, but there ' + 'are no sites with that kind: {}'.format(list(kinds_without_sites)) + ) - def _prepare_xsf(self, main_file_name=''): + def _prepare_xsf(self, main_file_name=''): # pylint: 
disable=unused-argument """ Write the given structure to a string of format XSF (for XCrySDen). """ @@ -990,19 +1008,20 @@ def _prepare_xsf(self, main_file_name=''): return_string += '%18.10f %18.10f %18.10f\n' % tuple(site.position) return return_string.encode('utf-8'), {} - def _prepare_cif(self, main_file_name=''): + def _prepare_cif(self, main_file_name=''): # pylint: disable=unused-argument """ Write the given structure to a string of format CIF. """ from aiida.orm import CifData cif = CifData(ase=self.get_ase()) - return cif._prepare_cif() + return cif._prepare_cif() # pylint: disable=protected-access - def _prepare_chemdoodle(self, main_file_name=''): + def _prepare_chemdoodle(self, main_file_name=''): # pylint: disable=unused-argument """ Write the given structure to a string of format required by ChemDoodle. """ + # pylint: disable=too-many-locals,invalid-name import numpy as np from itertools import product @@ -1044,7 +1063,6 @@ def _prepare_chemdoodle(self, main_file_name=''): 'x': base_site['position'][0] + shift[0], 'y': base_site['position'][1] + shift[1], 'z': base_site['position'][2] + shift[2], - # 'atomic_elements_html': kind_string 'atomic_elements_html': atom_kinds_to_html(kind_string) }) @@ -1065,7 +1083,7 @@ def _prepare_chemdoodle(self, main_file_name=''): return json.dumps(return_dict).encode('utf-8'), {} - def _prepare_xyz(self, main_file_name=''): + def _prepare_xyz(self, main_file_name=''): # pylint: disable=unused-argument """ Write the given structure to a string of format XYZ. """ @@ -1076,14 +1094,20 @@ def _prepare_xyz(self, main_file_name=''): cell = self.cell return_list = ['{}'.format(len(sites))] - return_list.append('Lattice="{} {} {} {} {} {} {} {} {}" pbc="{} {} {}"'.format( - cell[0][0], cell[0][1], cell[0][2], cell[1][0], cell[1][1], cell[1][2], cell[2][0], cell[2][1], cell[2][2], - self.pbc[0], self.pbc[1], self.pbc[2])) + return_list.append( + 'Lattice="{} {} {} {} {} {} {} {} {}" pbc="{} {} {}"'.format( + cell[0][0], cell[0][1], cell[0][2], cell[1][0], cell[1][1], cell[1][2], cell[2][0], cell[2][1], + cell[2][2], self.pbc[0], self.pbc[1], self.pbc[2] + ) + ) for site in sites: # I checked above that it is not an alloy, therefore I take the # first symbol - return_list.append('{:6s} {:18.10f} {:18.10f} {:18.10f}'.format( - self.get_kind(site.kind_name).symbols[0], site.position[0], site.position[1], site.position[2])) + return_list.append( + '{:6s} {:18.10f} {:18.10f} {:18.10f}'.format( + self.get_kind(site.kind_name).symbols[0], site.position[0], site.position[1], site.position[2] + ) + ) return_string = '\n'.join(return_list) return return_string.encode('utf-8'), {} @@ -1115,8 +1139,8 @@ def _adjust_default_cell(self, vacuum_factor=1.0, vacuum_addition=10.0, pbc=(Fal leading to an unphysical definition of the structure. This method will adjust the cell """ + # pylint: disable=invalid-name import numpy as np - from ase.visualize import view def get_extremas_from_positions(positions): """ @@ -1130,7 +1154,7 @@ def get_extremas_from_positions(positions): # Calculating the minimal cell: positions = np.array([site.position for site in self.sites]) - position_min, position_max = get_extremas_from_positions(positions) + position_min, _ = get_extremas_from_positions(positions) # Translate the structure to the origin, such that the minimal values in each dimension # amount to (0,0,0) @@ -1336,11 +1360,11 @@ def append_kind(self, kind): # If here, no exceptions have been raised, so I add the site. 
self.attributes.setdefault('kinds', []).append(new_kind.get_raw()) - # Note, this is a dict (with integer keys) so it allows for empty - # spots! - if not hasattr(self, '_internal_kind_tags'): + # Note, this is a dict (with integer keys) so it allows for empty spots! + if self._internal_kind_tags is None: self._internal_kind_tags = {} - self._internal_kind_tags[len(self.get_attribute('kinds')) - 1] = kind._internal_tag + + self._internal_kind_tags[len(self.get_attribute('kinds')) - 1] = kind._internal_tag # pylint: disable=protected-access def append_site(self, site): """ @@ -1357,9 +1381,11 @@ def append_site(self, site): new_site = Site(site=site) # So we make a copy - if site.kind_name not in [k.name for k in self.kinds]: - raise ValueError("No kind with name '{}', available kinds are: " - '{}'.format(site.kind_name, [k.name for k in self.kinds])) + if site.kind_name not in [kind.name for kind in self.kinds]: + raise ValueError( + "No kind with name '{}', available kinds are: " + '{}'.format(site.kind_name, [kind.name for kind in self.kinds]) + ) # If here, no exceptions have been raised, so I add the site. self.attributes.setdefault('sites', []).append(new_site.get_raw()) @@ -1396,12 +1422,15 @@ def append_atom(self, **kwargs): .. note :: checks of equality of species are done using the :py:meth:`~aiida.orm.nodes.data.structure.Kind.compare_with` method. """ + # pylint: disable=too-many-branches aseatom = kwargs.pop('ase', None) if aseatom is not None: if kwargs: - raise ValueError("If you pass 'ase' as a parameter to " - 'append_atom, you cannot pass any further' - 'parameter') + raise ValueError( + "If you pass 'ase' as a parameter to " + 'append_atom, you cannot pass any further' + 'parameter' + ) position = aseatom.position kind = Kind(ase=aseatom) else: @@ -1420,13 +1449,13 @@ def append_atom(self, **kwargs): exists_already = False for idx, existing_kind in enumerate(_kinds): try: - existing_kind._internal_tag = self._internal_kind_tags[idx] + existing_kind._internal_tag = self._internal_kind_tags[idx] # pylint: disable=protected-access except KeyError: # self._internal_kind_tags does not contain any info for # the kind in position idx: I don't have to add anything # then, and I continue pass - if (kind.compare_with(existing_kind)[0]): + if kind.compare_with(existing_kind)[0]: kind = existing_kind exists_already = True break @@ -1456,10 +1485,12 @@ def append_atom(self, **kwargs): if is_the_same: kind = old_kind else: - raise ValueError('You are explicitly setting the name ' - "of the kind to '{}', that already " - 'exists, but the two kinds are different!' - ' (first difference: {})'.format(kind.name, firstdiff)) + raise ValueError( + 'You are explicitly setting the name ' + "of the kind to '{}', that already " + 'exists, but the two kinds are different!' 
+ ' (first difference: {})'.format(kind.name, firstdiff) + ) site = Site(kind_name=kind.name, position=position) self.append_site(site) @@ -1528,7 +1559,7 @@ def get_kind(self, kind_name): try: kinds_dict = self._kinds_cache except AttributeError: - self._kinds_cache = {_.name: _ for _ in self.kinds} + self._kinds_cache = {_.name: _ for _ in self.kinds} # pylint: disable=attribute-defined-outside-init kinds_dict = self._kinds_cache else: kinds_dict = {_.name: _ for _ in self.kinds} @@ -1562,9 +1593,11 @@ def cell(self): @cell.setter def cell(self, value): + """Set the cell.""" self.set_cell(value) def set_cell(self, value): + """Set the cell.""" from aiida.common.exceptions import ModificationNotAllowed if self.is_stored: @@ -1611,7 +1644,6 @@ def reset_sites_positions(self, new_positions, conserve_particle=True): raise ModificationNotAllowed() if not conserve_particle: - # TODO: raise NotImplementedError else: @@ -1653,9 +1685,11 @@ def pbc(self): @pbc.setter def pbc(self, value): + """Set the periodic boundary conditions.""" self.set_pbc(value) def set_pbc(self, value): + """Set the periodic boundary conditions.""" from aiida.common.exceptions import ModificationNotAllowed if self.is_stored: @@ -1765,7 +1799,7 @@ def _get_object_phonopyatoms(self): :return: a PhonopyAtoms object """ - from phonopy.structure.atoms import Atoms as PhonopyAtoms + from phonopy.structure.atoms import PhonopyAtoms # pylint: disable=import-error atoms = PhonopyAtoms(symbols=[_.kind_name for _ in self.sites]) # Phonopy internally uses scaled positions, so you must store cell first! @@ -1805,8 +1839,8 @@ def _get_object_pymatgen(self, **kwargs): """ if self.pbc == (True, True, True): return self._get_object_pymatgen_structure(**kwargs) - else: - return self._get_object_pymatgen_molecule(**kwargs) + + return self._get_object_pymatgen_molecule(**kwargs) def _get_object_pymatgen_structure(self, **kwargs): """ @@ -1844,24 +1878,25 @@ def _get_object_pymatgen_structure(self, **kwargs): # case when spins are defined -> no partial occupancy allowed from pymatgen import Specie oxidation_state = 0 # now I always set the oxidation_state to zero - for s in self.sites: - k = self.get_kind(s.kind_name) - if len(k.symbols) != 1 or (len(k.weights) != 1 or sum(k.weights) < 1.): + for site in self.sites: + kind = self.get_kind(site.kind_name) + if len(kind.symbols) != 1 or (len(kind.weights) != 1 or sum(kind.weights) < 1.): raise ValueError('Cannot set partial occupancies and spins at the same time') species.append( Specie( - k.symbols[0], + kind.symbols[0], oxidation_state, - properties={'spin': -1 if k.name.endswith('1') else 1 if k.name.endswith('2') else 0})) + properties={'spin': -1 if kind.name.endswith('1') else 1 if kind.name.endswith('2') else 0} + ) + ) else: # case when no spin are defined - for s in self.sites: - k = self.get_kind(s.kind_name) - species.append({s: w for s, w in zip(k.symbols, k.weights)}) + for site in self.sites: + kind = self.get_kind(site.kind_name) + species.append(dict(zip(kind.symbols, kind.weights))) if any([ - create_automatic_kind_name(self.get_kind(name).symbols, - self.get_kind(name).weights) != name - for name in self.get_site_kindnames() + create_automatic_kind_name(self.get_kind(name).symbols, + self.get_kind(name).weights) != name for name in self.get_site_kindnames() ]): # add "kind_name" as a properties to each site, whenever # the kind_name cannot be automatically obtained from the symbols @@ -1892,11 +1927,11 @@ def _get_object_pymatgen_molecule(self, **kwargs): raise 
ValueError('Unrecognized parameters passed to pymatgen converter: {}'.format(kwargs.keys())) species = [] - for s in self.sites: - k = self.get_kind(s.kind_name) - species.append({s: w for s, w in zip(k.symbols, k.weights)}) + for site in self.sites: + kind = self.get_kind(site.kind_name) + species.append(dict(zip(kind.symbols, kind.weights))) - positions = [list(x.position) for x in self.sites] + positions = [list(site.position) for site in self.sites] return Molecule(species, positions) @@ -1931,6 +1966,7 @@ def __init__(self, **kwargs): :param name: a string that uniquely identifies the kind, and that is used to identify the sites. """ + # pylint: disable=too-many-branches,too-many-statements # Internal variables self._mass = None self._symbols = None @@ -1976,9 +2012,11 @@ def __init__(self, **kwargs): self.name = oldkind.name self._internal_tag = oldkind._internal_tag except AttributeError: - raise ValueError('Error using the Kind object. Are you sure ' - 'it is a Kind object? [Introspection says it is ' - '{}]'.format(str(type(oldkind)))) + raise ValueError( + 'Error using the Kind object. Are you sure ' + 'it is a Kind object? [Introspection says it is ' + '{}]'.format(str(type(oldkind))) + ) elif 'ase' in kwargs: aseatom = kwargs['ase'] @@ -1994,9 +2032,11 @@ def __init__(self, **kwargs): else: self.reset_mass() except AttributeError: - raise ValueError('Error using the aseatom object. Are you sure ' - 'it is a ase.atom.Atom object? [Introspection says it is ' - '{}]'.format(str(type(aseatom)))) + raise ValueError( + 'Error using the aseatom object. Are you sure ' + 'it is a ase.atom.Atom object? [Introspection says it is ' + '{}]'.format(str(type(aseatom))) + ) if aseatom.tag != 0: self.set_automatic_kind_name(tag=aseatom.tag) self._internal_tag = aseatom.tag @@ -2004,9 +2044,11 @@ def __init__(self, **kwargs): self.set_automatic_kind_name() else: if 'symbols' not in kwargs: - raise ValueError("'symbols' need to be " - 'specified (at least) to create a Site object. Otherwise, ' - "pass a raw site using the 'raw' parameter.") + raise ValueError( + "'symbols' need to be " + 'specified (at least) to create a Site object. Otherwise, ' + "pass a raw site using the 'raw' parameter." + ) weights = kwargs.pop('weights', None) self.set_symbols_and_weights(kwargs.pop('symbols'), weights) try: @@ -2051,7 +2093,7 @@ def reset_mass(self): """ w_sum = sum(self._weights) - if abs(w_sum) < _sum_threshold: + if abs(w_sum) < _SUM_THRESHOLD: self._mass = None return @@ -2115,21 +2157,28 @@ def compare_with(self, other_kind): # Check list of symbols for i in range(len(self.symbols)): if self.symbols[i] != other_kind.symbols[i]: - return (False, 'Symbol at position {:d} are different ' - '({} vs. {})'.format(i + 1, self.symbols[i], other_kind.symbols[i])) + return ( + False, 'Symbol at position {:d} are different ' + '({} vs. {})'.format(i + 1, self.symbols[i], other_kind.symbols[i]) + ) # Check weights (assuming length of weights and of symbols have same # length, which should be always true for i in range(len(self.weights)): if self.weights[i] != other_kind.weights[i]: - return (False, 'Weight at position {:d} are different ' - '({} vs. {})'.format(i + 1, self.weights[i], other_kind.weights[i])) + return ( + False, 'Weight at position {:d} are different ' + '({} vs. {})'.format(i + 1, self.weights[i], other_kind.weights[i]) + ) # Check masses - if abs(self.mass - other_kind.mass) > _mass_threshold: + if abs(self.mass - other_kind.mass) > _MASS_THRESHOLD: return (False, 'Masses are different ({} vs. 
{})'.format(self.mass, other_kind.mass)) - if self._internal_tag != other_kind._internal_tag: - return (False, 'Internal tags are different ({} vs. {})' - ''.format(self._internal_tag, other_kind._internal_tag)) + if self._internal_tag != other_kind._internal_tag: # pylint: disable=protected-access + return ( + False, + 'Internal tags are different ({} vs. {})' + ''.format(self._internal_tag, other_kind._internal_tag) # pylint: disable=protected-access + ) # If we got here, the two Site objects are similar enough # to be considered of the same kind @@ -2169,9 +2218,11 @@ def weights(self, value): weights_tuple = _create_weights_tuple(value) if len(weights_tuple) != len(self._symbols): - raise ValueError('Cannot change the number of weights. Use the ' - 'set_symbols_and_weights function instead.') - validate_weights_tuple(weights_tuple, _sum_threshold) + raise ValueError( + 'Cannot change the number of weights. Use the ' + 'set_symbols_and_weights function instead.' + ) + validate_weights_tuple(weights_tuple, _SUM_THRESHOLD) self._weights = weights_tuple @@ -2200,8 +2251,8 @@ def symbol(self): """ if len(self._symbols) == 1: return self._symbols[0] - else: - raise ValueError('This kind has more than one symbol (it is an alloy): {}'.format(self._symbols)) + + raise ValueError('This kind has more than one symbol (it is an alloy): {}'.format(self._symbols)) @property def symbols(self): @@ -2227,8 +2278,10 @@ def symbols(self, value): symbols_tuple = _create_symbols_tuple(value) if len(symbols_tuple) != len(self._weights): - raise ValueError('Cannot change the number of symbols. Use the ' - 'set_symbols_and_weights function instead.') + raise ValueError( + 'Cannot change the number of symbols. Use the ' + 'set_symbols_and_weights function instead.' + ) validate_symbols_tuple(symbols_tuple) self._symbols = symbols_tuple @@ -2244,7 +2297,7 @@ def set_symbols_and_weights(self, symbols, weights): if len(symbols_tuple) != len(weights_tuple): raise ValueError('The number of symbols and weights must coincide.') validate_symbols_tuple(symbols_tuple) - validate_weights_tuple(weights_tuple, _sum_threshold) + validate_weights_tuple(weights_tuple, _SUM_THRESHOLD) self._symbols = symbols_tuple self._weights = weights_tuple @@ -2260,7 +2313,7 @@ def is_alloy(self): def has_vacancies(self): """Return whether the Kind contains vacancies, i.e. when the sum of the weights is less than one. - .. note:: the property uses the internal variable `_sum_threshold` as a threshold. + .. note:: the property uses the internal variable `_SUM_THRESHOLD` as a threshold. :return: boolean, True if the sum of the weights is less than one, False otherwise """ @@ -2345,6 +2398,7 @@ def get_ase(self, kinds): .. note:: If any site is an alloy or has vacancies, a ValueError is raised (from the site.get_ase() routine). 
""" + # pylint: disable=too-many-branches from collections import defaultdict import ase @@ -2371,7 +2425,7 @@ def get_ase(self, kinds): pass tag_list.append(k.symbols[0]) # I use a string as a placeholder - for i in range(len(tag_list)): + for i, _ in enumerate(tag_list): # If it is a string, it is the name of the element, # and I have to generate a new integer for this element # and replace tag_list[i] with this new integer @@ -2388,10 +2442,10 @@ def get_ase(self, kinds): tag_list[i] = new_tag found = False - for k, t in zip(kinds, tag_list): - if k.name == self.kind_name: - kind = k - tag = t + for kind_candidate, tag_candidate in zip(kinds, tag_list): + if kind_candidate.name == self.kind_name: + kind = kind_candidate + tag = tag_candidate found = True break if not found: @@ -2401,7 +2455,7 @@ def get_ase(self, kinds): raise ValueError('Cannot convert to ASE if the kind represents an alloy or it has vacancies.') aseatom = ase.Atom(position=self.position, symbol=str(kind.symbols[0]), mass=kind.mass) if tag is not None: - aseatom.tag = tag + aseatom.tag = tag # pylint: disable=assigning-non-slot return aseatom @property diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index bd2be3dd50..f2dfe39a8d 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines """ The QueryBuilder: A class that allows you to query the AiiDA database, independent from backend. Note that the backend implementation is enforced and handled with a composition model! @@ -18,19 +19,18 @@ An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance when instantiated by the user. """ -# Checking for correct input with the inspect module from inspect import isclass as inspect_isclass import copy import logging import warnings + from sqlalchemy import and_, or_, not_, func as sa_func, select, join from sqlalchemy.types import Integer from sqlalchemy.orm import aliased -from sqlalchemy.sql.expression import cast +from sqlalchemy.sql.expression import cast as type_cast from sqlalchemy.dialects.postgresql import array from aiida.common.exceptions import InputValidationError -# The way I get column as a an attribute to the orm class from aiida.common.links import LinkType from aiida.manage.manager import get_manager from aiida.common.exceptions import ConfigurationError @@ -58,18 +58,19 @@ GROUP_ENTITY_TYPE_PREFIX = 'group.' -def get_querybuilder_classifiers_from_cls(cls, qb): +def get_querybuilder_classifiers_from_cls(cls, query): # pylint: disable=invalid-name """ Return the correct classifiers for the QueryBuilder from an ORM class. :param cls: an AiiDA ORM class or backend ORM class. - :param qb: an instance of the appropriate QueryBuilder backend. + :param query: an instance of the appropriate QueryBuilder backend. :returns: the ORM class as well as a dictionary with additional classifier strings :rtype: cls, dict Note: the ormclass_type_string is currently hardcoded for group, computer etc. 
One could instead use something like aiida.orm.utils.node.get_type_string_from_class(cls.__module__, cls.__name__) """ + # pylint: disable=protected-access,too-many-branches,too-many-statements # Note: Unable to move this import to the top of the module for some reason from aiida.engine import Process from aiida.orm.utils.node import is_valid_node_type_string @@ -79,63 +80,63 @@ def get_querybuilder_classifiers_from_cls(cls, qb): classifiers['process_type_string'] = None # Nodes - if issubclass(cls, qb.Node): + if issubclass(cls, query.Node): # If a backend ORM node (i.e. DbNode) is passed. # Users shouldn't do that, by why not... - classifiers['ormclass_type_string'] = qb.AiidaNode._plugin_type_string + classifiers['ormclass_type_string'] = query.AiidaNode._plugin_type_string ormclass = cls - elif issubclass(cls, qb.AiidaNode): + elif issubclass(cls, query.AiidaNode): classifiers['ormclass_type_string'] = cls._plugin_type_string - ormclass = qb.Node + ormclass = query.Node # Groups: - elif issubclass(cls, qb.Group): + elif issubclass(cls, query.Group): classifiers['ormclass_type_string'] = GROUP_ENTITY_TYPE_PREFIX + cls._type_string ormclass = cls elif issubclass(cls, groups.Group): classifiers['ormclass_type_string'] = GROUP_ENTITY_TYPE_PREFIX + cls._type_string - ormclass = qb.Group + ormclass = query.Group # Computers: - elif issubclass(cls, qb.Computer): + elif issubclass(cls, query.Computer): classifiers['ormclass_type_string'] = 'computer' ormclass = cls elif issubclass(cls, computers.Computer): classifiers['ormclass_type_string'] = 'computer' - ormclass = qb.Computer + ormclass = query.Computer # Users - elif issubclass(cls, qb.User): + elif issubclass(cls, query.User): classifiers['ormclass_type_string'] = 'user' ormclass = cls elif issubclass(cls, users.User): classifiers['ormclass_type_string'] = 'user' - ormclass = qb.User + ormclass = query.User # AuthInfo - elif issubclass(cls, qb.AuthInfo): + elif issubclass(cls, query.AuthInfo): classifiers['ormclass_type_string'] = 'authinfo' ormclass = cls elif issubclass(cls, authinfos.AuthInfo): classifiers['ormclass_type_string'] = 'authinfo' - ormclass = qb.AuthInfo + ormclass = query.AuthInfo # Comment - elif issubclass(cls, qb.Comment): + elif issubclass(cls, query.Comment): classifiers['ormclass_type_string'] = 'comment' ormclass = cls elif issubclass(cls, comments.Comment): classifiers['ormclass_type_string'] = 'comment' - ormclass = qb.Comment + ormclass = query.Comment # Log - elif issubclass(cls, qb.Log): + elif issubclass(cls, query.Log): classifiers['ormclass_type_string'] = 'log' ormclass = cls elif issubclass(cls, logs.Log): classifiers['ormclass_type_string'] = 'log' - ormclass = qb.Log + ormclass = query.Log # Process # This is a special case, since Process is not an ORM class. 
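For context, the caller-facing pattern that these classifier helpers serve is the usual `QueryBuilder.append` chain. A minimal sketch, not part of the diff, assuming a configured profile; the label and computer name are hypothetical.

from aiida import load_profile
from aiida.orm import Code, Computer, QueryBuilder

load_profile()

query = QueryBuilder()
# Each appended class is resolved to a backend ORM class plus classifier strings
# by the helpers above before filters and projections are attached.
query.append(Code, filters={'label': 'pw'}, project='*', tag='code')
query.append(Computer, filters={'name': 'localhost'}, with_node='code')

for code in query.all(flat=True):
    print(code.pk, code.label)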
@@ -143,23 +144,23 @@ def get_querybuilder_classifiers_from_cls(cls, qb): elif issubclass(cls, Process): classifiers['ormclass_type_string'] = cls._node_class._plugin_type_string classifiers['process_type_string'] = cls.build_process_type() - ormclass = qb.Node + ormclass = query.Node else: raise InputValidationError('I do not know what to do with {}'.format(cls)) - if ormclass == qb.Node: + if ormclass == query.Node: is_valid_node_type_string(classifiers['ormclass_type_string'], raise_on_false=True) return ormclass, classifiers -def get_querybuilder_classifiers_from_type(ormclass_type_string, qb): +def get_querybuilder_classifiers_from_type(ormclass_type_string, query): # pylint: disable=invalid-name """ Return the correct classifiers for the QueryBuilder from an ORM type string. :param ormclass_type_string: type string for ORM class - :param qb: an instance of the appropriate QueryBuilder backend. + :param query: an instance of the appropriate QueryBuilder backend. :returns: the ORM class as well as a dictionary with additional classifier strings :rtype: cls, dict @@ -174,18 +175,18 @@ def get_querybuilder_classifiers_from_type(ormclass_type_string, qb): if classifiers['ormclass_type_string'].startswith(GROUP_ENTITY_TYPE_PREFIX): classifiers['ormclass_type_string'] = 'group.core' - ormclass = qb.Group + ormclass = query.Group elif classifiers['ormclass_type_string'] == 'computer': - ormclass = qb.Computer + ormclass = query.Computer elif classifiers['ormclass_type_string'] == 'user': - ormclass = qb.User + ormclass = query.User else: # At this point, we assume it is a node. The only valid type string then is a string # that matches exactly the _plugin_type_string of a node class classifiers['ormclass_type_string'] = ormclass_type_string # no lowercase - ormclass = qb.Node + ormclass = query.Node - if ormclass == qb.Node: + if ormclass == query.Node: is_valid_node_type_string(classifiers['ormclass_type_string'], raise_on_false=True) return ormclass, classifiers @@ -233,7 +234,6 @@ def get_process_type_filter(classifiers, subclassing): from aiida.common.escaping import escape_for_sql_like from aiida.common.warnings import AiidaEntryPointWarning from aiida.engine.processes.process import get_query_string_from_process_type_string - import warnings value = classifiers['process_type_string'] @@ -246,10 +246,16 @@ def get_process_type_filter(classifiers, subclassing): # Note: the process_type_string stored in the database does *not* end in a dot. # In order to avoid that querying for class 'Begin' will also find class 'BeginEnd', # we need to search separately for equality and 'like'. - filters = {'or': [ - {'==': value}, - {'like': escape_for_sql_like(get_query_string_from_process_type_string(value))}, - ]} + filters = { + 'or': [ + { + '==': value + }, + { + 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) + }, + ] + } elif value.startswith('aiida.engine'): # For core process types, a filter is not is needed since each process type has a corresponding # ormclass type that already specifies everything. @@ -259,14 +265,21 @@ def get_process_type_filter(classifiers, subclassing): # Note: Improve this when issue #2475 is addressed filters = {'like': '%'} else: - warnings.warn("Process type '{}' does not correspond to a registered entry. " - 'This risks queries to fail once the location of the process class changes. 
'
-                          "Add an entry point for '{}' to remove this warning.".format(value, value),
-                          AiidaEntryPointWarning)
-            filters = {'or': [
-                {'==': value},
-                {'like': escape_for_sql_like(get_query_string_from_process_type_string(value))},
-            ]}
+            warnings.warn(
+                "Process type '{value}' does not correspond to a registered entry. "
+                'This risks queries failing once the location of the process class changes. '
+                "Add an entry point for '{value}' to remove this warning.".format(value=value), AiidaEntryPointWarning
+            )
+            filters = {
+                'or': [
+                    {
+                        '==': value
+                    },
+                    {
+                        'like': escape_for_sql_like(get_query_string_from_process_type_string(value))
+                    },
+                ]
+            }

     return filters

@@ -314,6 +327,8 @@ class QueryBuilder:
     """

+    # pylint: disable=too-many-instance-attributes,too-many-public-methods
+
     # This tag defines how edges are tagged (labeled) by the QueryBuilder default
     # namely tag of first entity + _EDGE_TAG_DELIM + tag of second entity
     _EDGE_TAG_DELIM = '--'
@@ -362,12 +377,18 @@ def __init__(self, backend=None, **kwargs):
         # A dictionary tag:alias of ormclass
         # redundant but makes life easier
         self.tag_to_alias_map = {}
+        self.tag_to_projected_property_dict = {}

         # A dictionary tag: filter specification for this alias
         self._filters = {}

         # A dictionary tag: projections for this alias
         self._projections = {}
+        self.nr_of_projections = 0
+
+        self._attrkeys_as_in_sql_result = None
+
+        self._query = None

         # A dictionary for classes passed to the tag given to them
         # Everything is specified with unique tags, which are strings.
@@ -385,20 +406,10 @@ def __init__(self, backend=None, **kwargs):
         # is used twice. In that case, the user has to provide a tag!
         self._cls_to_tag_map = {}

-        # Hashing the the internal queryhelp allows me to avoid to build a query again, if i have used
-        # it already.
-        # Example:
-
-        ## User is building a query:
-        # qb = QueryBuilder().append(.....)
-        ## User asks for the first results:
-        # qb.first()
-        ## User asks for all results, of the same query:
-        # qb.all()
-        # In above example, I can reuse the query, and to track whether somethis was changed
-        # I record a hash:
+        # Hashing the internal queryhelp allows me to avoid building the query again
         self._hash = None
-        ## The hash being None implies that the query will be build (Check the code in .get_query
+
+        # The hash being None implies that the query will be built (check the code in .get_query)
         # The user can inject a query, this keyword stores whether this was done.
# Check QueryBuilder.inject_query self._injected = False @@ -414,7 +425,6 @@ def __init__(self, backend=None, **kwargs): for path_spec in path: if isinstance(path_spec, dict): self.append(**path_spec) - # ~ except TypeError as e: elif isinstance(path_spec, str): # Maybe it is just a string, # I assume user means the type @@ -454,10 +464,12 @@ def __init__(self, backend=None, **kwargs): # If kwargs is not empty, there is a problem: if kwargs: valid_keys = ('path', 'filters', 'project', 'limit', 'offset', 'order_by') - raise InputValidationError('Received additional keywords: {}' - '\nwhich I cannot process' - '\nValid keywords are: {}' - ''.format(list(kwargs.keys()), valid_keys)) + raise InputValidationError( + 'Received additional keywords: {}' + '\nwhich I cannot process' + '\nValid keywords are: {}' + ''.format(list(kwargs.keys()), valid_keys) + ) def __str__(self): """ @@ -504,9 +516,9 @@ def _get_ormclass(self, cls, ormclass_type_string): ormclass = None classifiers = [] - for i, c in enumerate(input_info): - new_ormclass, new_classifiers = func(c, self._impl) - if i: + for index, classifier in enumerate(input_info): + new_ormclass, new_classifiers = func(classifier, self._impl) + if index: # This is not my first iteration! # I check consistency with what was specified before if new_ormclass != ormclass: @@ -556,8 +568,8 @@ def get_tag_from_type(classifiers): """ if isinstance(classifiers, list): return '-'.join([t['ormclass_type_string'].rstrip('.').split('.')[-1] or 'node' for t in classifiers]) - else: - return classifiers['ormclass_type_string'].rstrip('.').split('.')[-1] or 'node' + + return classifiers['ormclass_type_string'].rstrip('.').split('.')[-1] or 'node' basetag = get_tag_from_type(classifiers) tags_used = self.tag_to_alias_map.keys() @@ -568,18 +580,20 @@ def get_tag_from_type(classifiers): raise RuntimeError('Cannot find a tag after 100 tries') - def append(self, - cls=None, - entity_type=None, - tag=None, - filters=None, - project=None, - subclassing=True, - edge_tag=None, - edge_filters=None, - edge_project=None, - outerjoin=False, - **kwargs): + def append( + self, + cls=None, + entity_type=None, + tag=None, + filters=None, + project=None, + subclassing=True, + edge_tag=None, + edge_filters=None, + edge_project=None, + outerjoin=False, + **kwargs + ): """ Any iterative procedure to build the path for a graph query needs to invoke this method to append to the path. @@ -638,13 +652,16 @@ def append(self, :return: self :rtype: :class:`aiida.orm.QueryBuilder` """ + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements # INPUT CHECKS ########################## # This function can be called by users, so I am checking the # input now. 
# First of all, let's make sure the specified # the class or the type (not both) if cls and entity_type: - raise InputValidationError('You cannot specify both a class ({}) and a entity_type ({})'.format(cls, entity_type)) + raise InputValidationError( + 'You cannot specify both a class ({}) and a entity_type ({})'.format(cls, entity_type) + ) if not (cls or entity_type): raise InputValidationError('You need to specify at least a class or a entity_type') @@ -652,17 +669,17 @@ def append(self, # Let's check if it is a valid class or type if cls: if isinstance(cls, (tuple, list, set)): - for c in cls: - if not inspect_isclass(c): - raise InputValidationError("{} was passed with kw 'cls', but is not a class".format(c)) + for sub_cls in cls: + if not inspect_isclass(sub_cls): + raise InputValidationError("{} was passed with kw 'cls', but is not a class".format(sub_cls)) else: if not inspect_isclass(cls): raise InputValidationError("{} was passed with kw 'cls', but is not a class".format(cls)) elif entity_type: if isinstance(entity_type, (tuple, list, set)): - for t in entity_type: - if not isinstance(t, str): - raise InputValidationError('{} was passed as entity_type, but is not a string'.format(t)) + for sub_type in entity_type: + if not isinstance(sub_type, str): + raise InputValidationError('{} was passed as entity_type, but is not a string'.format(sub_type)) else: if not isinstance(entity_type, str): raise InputValidationError('{} was passed as entity_type, but is not a string'.format(entity_type)) @@ -673,10 +690,11 @@ def append(self, # Let's get a tag if tag: if self._EDGE_TAG_DELIM in tag: - raise InputValidationError('tag cannot contain {}\n' - 'since this is used as a delimiter for links' - ''.format(self._EDGE_TAG_DELIM)) - tag = tag + raise InputValidationError( + 'tag cannot contain {}\n' + 'since this is used as a delimiter for links' + ''.format(self._EDGE_TAG_DELIM) + ) if tag in self.tag_to_alias_map.keys(): raise InputValidationError('This tag ({}) is already in use'.format(tag)) else: @@ -687,7 +705,6 @@ def append(self, # Now, several things can go wrong along the way, so I need to split into # atomic blocks that I can reverse if something goes wrong. 
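The input checks in this block guard the public append signature documented above. As a rough usage sketch (the node classes and attribute names are illustrative and not taken from this patch), a two-vertex query exercising tag, filters, project and a relationship keyword looks like:

    from aiida.orm import CalcJobNode, Dict, QueryBuilder

    qb = QueryBuilder()
    qb.append(CalcJobNode, tag='calculation', filters={'attributes.exit_status': 0})
    qb.append(Dict, with_incoming='calculation', project=['attributes.energy'], edge_tag='result_link')

Each keyword is handled by one of the blocks that follow: the tag is registered in tag_to_alias_map, filters and projections are stored per tag, and with_incoming is resolved in the joining-specification section.
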
# TAG MAPPING ################################# - # TODO check with duplicate classes # Let's fill the cls_to_tag_map so that one can specify # this vertice in a joining specification later @@ -716,10 +733,10 @@ def append(self, # ALIASING ############################## try: self.tag_to_alias_map[tag] = aliased(ormclass) - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append, cleaning up') - print(' ', e) + print(' ', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) @@ -742,10 +759,10 @@ def append(self, # if the user specified a filter, add it: if filters is not None: self.add_filter(tag, filters) - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append (part filters), cleaning up') - print(' ', e) + print(' ', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag) @@ -757,18 +774,19 @@ def append(self, self._projections[tag] = [] if project is not None: self.add_projection(tag, project) - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append (part projections), cleaning up') - print(' ', e) + print(' ', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) self._filters.pop(tag) self._projections.pop(tag) - raise e + raise exception # JOINING ##################################### + # pylint: disable=too-many-nested-blocks try: # Get the functions that are implemented: spec_to_function_map = [] @@ -784,12 +802,16 @@ def append(self, '{} is not a valid keyword ' 'for joining specification\n' 'Valid keywords are: ' - '{}'.format(key, - spec_to_function_map + ['cls', 'type', 'tag', 'autotag', 'filters', 'project'])) + '{}'.format( + key, spec_to_function_map + ['cls', 'type', 'tag', 'autotag', 'filters', 'project'] + ) + ) elif joining_keyword: - raise InputValidationError('You already specified joining specification {}\n' - 'But you now also want to specify {}' - ''.format(joining_keyword, key)) + raise InputValidationError( + 'You already specified joining specification {}\n' + 'But you now also want to specify {}' + ''.format(joining_keyword, key) + ) else: joining_keyword = key if joining_keyword == 'direction': @@ -804,9 +826,11 @@ def append(self, raise InputValidationError('direction=0 is not valid') joining_value = self._path[-abs(val)]['tag'] except IndexError as exc: - raise InputValidationError('You have specified a non-existent entity with\n' - 'direction={}\n' - '{}\n'.format(joining_value, exc)) + raise InputValidationError( + 'You have specified a non-existent entity with\n' + 'direction={}\n' + '{}\n'.format(joining_value, exc) + ) else: joining_value = self._get_tag_from_specification(val) # the default is that this vertice is 'with_incoming' as the previous one @@ -814,18 +838,17 @@ def append(self, joining_keyword = 'with_incoming' joining_value = self._path[-1]['tag'] - - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append (part joining), cleaning up') - print(' ', e) + print(' ', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) self._filters.pop(tag) self._projections.pop(tag) # There's not more to clean up here! 
- raise e + raise exception # EDGES ################################# if len(self._path) > 0: @@ -856,7 +879,7 @@ def append(self, self._projections[edge_tag] = [] if edge_project is not None: self.add_projection(edge_tag, edge_project) - except Exception as e: + except Exception as exception: if self._debug: print('DEBUG: Exception caught in append (part joining), cleaning up') @@ -872,7 +895,7 @@ def append(self, self._filters.pop(edge_tag, None) self._projections.pop(edge_tag, None) # There's not more to clean up here! - raise e + raise exception # EXTENDING THE PATH ################################# # Note: 'type' being a list is a relict of an earlier implementation @@ -889,7 +912,9 @@ def append(self, joining_keyword=joining_keyword, joining_value=joining_value, outerjoin=outerjoin, - edge_tag=edge_tag)) + edge_tag=edge_tag + ) + ) return self @@ -927,7 +952,7 @@ def order_by(self, order_by): qb.append(Node, tag='node') qb.order_by({'node':[{'id':'desc'}]}) """ - + # pylint: disable=too-many-nested-blocks,too-many-branches self._order_by = [] allowed_keys = ('cast', 'order') possible_orders = ('asc', 'desc') @@ -937,10 +962,12 @@ def order_by(self, order_by): for order_spec in order_by: if not isinstance(order_spec, dict): - raise InputValidationError('Invalid input for order_by statement: {}\n' - 'I am expecting a dictionary ORMClass,' - '[columns to sort]' - ''.format(order_spec)) + raise InputValidationError( + 'Invalid input for order_by statement: {}\n' + 'I am expecting a dictionary ORMClass,' + '[columns to sort]' + ''.format(order_spec) + ) _order_spec = {} for tagspec, items_to_order_by in order_spec.items(): if not isinstance(items_to_order_by, (tuple, list)): @@ -953,9 +980,11 @@ def order_by(self, order_by): elif isinstance(item_to_order_by, dict): pass else: - raise InputValidationError('Cannot deal with input to order_by {}\n' - 'of type{}' - '\n'.format(item_to_order_by, type(item_to_order_by))) + raise InputValidationError( + 'Cannot deal with input to order_by {}\n' + 'of type{}' + '\n'.format(item_to_order_by, type(item_to_order_by)) + ) for entityname, orderspec in item_to_order_by.items(): # if somebody specifies eg {'node':{'id':'asc'}} # tranform to {'node':{'id':{'order':'asc'}}} @@ -965,21 +994,27 @@ def order_by(self, order_by): elif isinstance(orderspec, dict): this_order_spec = orderspec else: - raise InputValidationError('I was expecting a string or a dictionary\n' - 'You provided {} {}\n' - ''.format(type(orderspec), orderspec)) - for key in this_order_spec.keys(): + raise InputValidationError( + 'I was expecting a string or a dictionary\n' + 'You provided {} {}\n' + ''.format(type(orderspec), orderspec) + ) + for key in this_order_spec: if key not in allowed_keys: - raise InputValidationError('The allowed keys for an order specification\n' - 'are {}\n' - '{} is not valid\n' - ''.format(', '.join(allowed_keys), key)) + raise InputValidationError( + 'The allowed keys for an order specification\n' + 'are {}\n' + '{} is not valid\n' + ''.format(', '.join(allowed_keys), key) + ) this_order_spec['order'] = this_order_spec.get('order', 'asc') if this_order_spec['order'] not in possible_orders: - raise InputValidationError('You gave {} as an order parameters,\n' - 'but it is not a valid order parameter\n' - 'Valid orders are: {}\n' - ''.format(this_order_spec['order'], possible_orders)) + raise InputValidationError( + 'You gave {} as an order parameters,\n' + 'but it is not a valid order parameter\n' + 'Valid orders are: {}\n' + 
''.format(this_order_spec['order'], possible_orders) + ) item_to_order_by[entityname] = this_order_spec _order_spec[tag].append(item_to_order_by) @@ -1009,7 +1044,9 @@ def add_filter(self, tagspec, filter_spec): tag = self._get_tag_from_specification(tagspec) self._filters[tag].update(filters) - def _process_filters(self, filters): + @staticmethod + def _process_filters(filters): + """Process filters.""" if not isinstance(filters, dict): raise InputValidationError('Filters have to be passed as dictionaries') @@ -1036,8 +1073,8 @@ def _add_node_type_filter(self, tagspec, classifiers, subclassing): if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers entity_type_filter = {'or': []} - for c in classifiers: - entity_type_filter['or'].append(get_node_type_filter(c, subclassing)) + for classifier in classifiers: + entity_type_filter['or'].append(get_node_type_filter(classifier, subclassing)) else: entity_type_filter = get_node_type_filter(classifiers, subclassing) @@ -1056,9 +1093,9 @@ def _add_process_type_filter(self, tagspec, classifiers, subclassing): if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers process_type_filter = {'or': []} - for c in classifiers: - if c['process_type_string'] is not None: - process_type_filter['or'].append(get_process_type_filter(c, subclassing)) + for classifier in classifiers: + if classifier['process_type_string'] is not None: + process_type_filter['or'].append(get_process_type_filter(classifier, subclassing)) if len(process_type_filter['or']) > 0: self.add_filter(tagspec, {'process_type': process_type_filter}) @@ -1147,12 +1184,14 @@ def add_projection(self, tag_spec, projection_spec): _thisprojection = {projection: {}} else: raise InputValidationError('Cannot deal with projection specification {}\n'.format(projection)) - for p, spec in _thisprojection.items(): + for spec in _thisprojection.values(): if not isinstance(spec, dict): - raise InputValidationError('\nThe value of a key-value pair in a projection\n' - 'has to be a dictionary\n' - 'You gave: {}\n' - ''.format(spec)) + raise InputValidationError( + '\nThe value of a key-value pair in a projection\n' + 'has to be a dictionary\n' + 'You gave: {}\n' + ''.format(spec) + ) for key, val in spec.items(): if key not in self._VALID_PROJECTION_KEYS: @@ -1165,6 +1204,7 @@ def add_projection(self, tag_spec, projection_spec): self._projections[tag] = _projections def _get_projectable_entity(self, alias, column_name, attrpath, **entityspec): + """Return projectable entity for a given alias and column name.""" if attrpath or column_name in ('attributes', 'extras'): entity = self._impl.get_projectable_attribute(alias, column_name, attrpath, **entityspec) else: @@ -1186,10 +1226,12 @@ def _add_to_projections(self, alias, projectable_entity_name, cast=None, func=No if column_name == '*': if func is not None: - raise InputValidationError('Very sorry, but functions on the aliased class\n' - "(You specified '*')\n" - 'will not work!\n' - "I suggest you apply functions on a column, e.g. ('id')\n") + raise InputValidationError( + 'Very sorry, but functions on the aliased class\n' + "(You specified '*')\n" + 'will not work!\n' + "I suggest you apply functions on a column, e.g. 
('id')\n" + ) self._query = self._query.add_entity(alias) else: entity_to_project = self._get_projectable_entity(alias, column_name, attr_key, cast=cast) @@ -1205,11 +1247,8 @@ def _add_to_projections(self, alias, projectable_entity_name, cast=None, func=No raise InputValidationError('\nInvalid function specification {}'.format(func)) self._query = self._query.add_columns(entity_to_project) - def get_table_columns(self, table_alias): - raise NotImplementedError - def _build_projections(self, tag, items_to_project=None): - + """Build the projections for a given tag.""" if items_to_project is None: items_to_project = self._projections.get(tag, []) @@ -1248,16 +1287,19 @@ def _get_tag_from_specification(self, specification): if specification in self.tag_to_alias_map.keys(): tag = specification else: - raise InputValidationError('tag {} is not among my known tags\n' - 'My tags are: {}'.format(specification, self.tag_to_alias_map.keys())) + raise InputValidationError( + 'tag {} is not among my known tags\n' + 'My tags are: {}'.format(specification, self.tag_to_alias_map.keys()) + ) else: if specification in self._cls_to_tag_map.keys(): tag = self._cls_to_tag_map[specification] else: - raise InputValidationError('You specified as a class for which I have to find a tag\n' - 'The classes that I can do this for are:{}\n' - 'The tags I have are: {}'.format(specification, self._cls_to_tag_map.keys(), - self.tag_to_alias_map.keys())) + raise InputValidationError( + 'You specified as a class for which I have to find a tag\n' + 'The classes that I can do this for are:{}\n' + 'The tags I have are: {}'.format(self._cls_to_tag_map.keys(), self.tag_to_alias_map.keys()) + ) return tag def set_debug(self, debug): @@ -1337,7 +1379,7 @@ def _build_filters(self, alias, filter_spec): raise if not isinstance(filter_operation_dict, dict): filter_operation_dict = {'==': filter_operation_dict} - [ + for operator, value in filter_operation_dict.items(): expressions.append( self._impl.get_filter_expr( operator, @@ -1346,8 +1388,9 @@ def _build_filters(self, alias, filter_spec): is_attribute=is_attribute, column=column, column_name=column_name, - alias=alias)) for operator, value in filter_operation_dict.items() - ] + alias=alias + ) + ) return and_(*expressions) @staticmethod @@ -1365,22 +1408,25 @@ def _check_dbentities(entities_cls_joined, entities_cls_to_join, relationship): The relationship between the two entities to make the Exception comprehensible """ + # pylint: disable=protected-access for entity, cls in (entities_cls_joined, entities_cls_to_join): if not issubclass(entity._sa_class_manager.class_, cls): - raise InputValidationError("You are attempting to join {} as '{}' of {}\n" - 'This failed because you passed:\n' - ' - {} as entity joined (expected {})\n' - ' - {} as entity to join (expected {})\n' - '\n'.format( - entities_cls_joined[0].__name__, - relationship, - entities_cls_to_join[0].__name__, - entities_cls_joined[0]._sa_class_manager.class_.__name__, - entities_cls_joined[1].__name__, - entities_cls_to_join[0]._sa_class_manager.class_.__name__, - entities_cls_to_join[1].__name__, - )) + raise InputValidationError( + "You are attempting to join {} as '{}' of {}\n" + 'This failed because you passed:\n' + ' - {} as entity joined (expected {})\n' + ' - {} as entity to join (expected {})\n' + '\n'.format( + entities_cls_joined[0].__name__, + relationship, + entities_cls_to_join[0].__name__, + entities_cls_joined[0]._sa_class_manager.class_.__name__, + entities_cls_joined[1].__name__, + 
entities_cls_to_join[0]._sa_class_manager.class_.__name__, + entities_cls_to_join[1].__name__, + ) + ) def _join_outputs(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1394,9 +1440,12 @@ def _join_outputs(self, joined_entity, entity_to_join, isouterjoin): self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_incoming') aliased_edge = aliased(self._impl.Link) - self._query = self._query.join( - aliased_edge, aliased_edge.input_id == joined_entity.id, isouter=isouterjoin).join( - entity_to_join, aliased_edge.output_id == entity_to_join.id, isouter=isouterjoin) + self._query = self._query.join(aliased_edge, aliased_edge.input_id == joined_entity.id, + isouter=isouterjoin).join( + entity_to_join, + aliased_edge.output_id == entity_to_join.id, + isouter=isouterjoin + ) return aliased_edge def _join_inputs(self, joined_entity, entity_to_join, isouterjoin): @@ -1415,8 +1464,7 @@ def _join_inputs(self, joined_entity, entity_to_join, isouterjoin): self._query = self._query.join( aliased_edge, aliased_edge.output_id == joined_entity.id, - ).join( - entity_to_join, aliased_edge.input_id == entity_to_join.id, isouter=isouterjoin) + ).join(entity_to_join, aliased_edge.input_id == entity_to_join.id, isouter=isouterjoin) return aliased_edge def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin, filter_dict, expand_path=False): @@ -1426,8 +1474,7 @@ def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin :TODO: Pass an option to also show the path, if this is wanted. """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), - 'with_ancestors') + self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_ancestors') link1 = aliased(self._impl.Link) link2 = aliased(self._impl.Link) @@ -1437,7 +1484,7 @@ def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin selection_walk_list = [ link1.input_id.label('ancestor_id'), link1.output_id.label('descendant_id'), - cast(0, Integer).label('depth'), + type_cast(0, Integer).label('depth'), ] if expand_path: selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path')) @@ -1446,13 +1493,15 @@ def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin and_( in_recursive_filters, # I apply filters for speed here link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links - )).cte(recursive=True) + ) + ).cte(recursive=True) aliased_walk = aliased(walk) selection_union_list = [ aliased_walk.c.ancestor_id.label('ancestor_id'), - link2.output_id.label('descendant_id'), (aliased_walk.c.depth + cast(1, Integer)).label('current_depth') + link2.output_id.label('descendant_id'), + (aliased_walk.c.depth + type_cast(1, Integer)).label('current_depth') ] if expand_path: selection_union_list.append((aliased_walk.c.path + array((link2.output_id,))).label('path')) @@ -1464,13 +1513,17 @@ def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin aliased_walk, link2, link2.input_id == aliased_walk.c.descendant_id, - )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))))) # .alias() + ) + ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + ) + ) # .alias() self._query = self._query.join(descendants_recursive, descendants_recursive.c.ancestor_id == joined_entity.id).join( entity_to_join, descendants_recursive.c.descendant_id 
== entity_to_join.id, - isouter=isouterjoin) + isouter=isouterjoin + ) return descendants_recursive.c def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, filter_dict, expand_path=False): @@ -1480,8 +1533,7 @@ def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, :TODO: Pass an option to also show the path, if this is wanted. """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), - 'with_ancestors') + self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_ancestors') link1 = aliased(self._impl.Link) link2 = aliased(self._impl.Link) @@ -1491,21 +1543,21 @@ def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, selection_walk_list = [ link1.input_id.label('ancestor_id'), link1.output_id.label('descendant_id'), - cast(0, Integer).label('depth'), + type_cast(0, Integer).label('depth'), ] if expand_path: selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path')) walk = select(selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where( - and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, - LinkType.INPUT_CALC.value)))).cte(recursive=True) + and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + ).cte(recursive=True) aliased_walk = aliased(walk) selection_union_list = [ link2.input_id.label('ancestor_id'), aliased_walk.c.descendant_id.label('descendant_id'), - (aliased_walk.c.depth + cast(1, Integer)).label('current_depth'), + (aliased_walk.c.depth + type_cast(1, Integer)).label('current_depth'), ] if expand_path: selection_union_list.append((aliased_walk.c.path + array((link2.input_id,))).label('path')) @@ -1517,15 +1569,18 @@ def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, aliased_walk, link2, link2.output_id == aliased_walk.c.ancestor_id, - )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + ) + ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) # I can't follow RETURN or CALL links - )) + ) + ) self._query = self._query.join(ancestors_recursive, ancestors_recursive.c.descendant_id == joined_entity.id).join( entity_to_join, ancestors_recursive.c.ancestor_id == entity_to_join.id, - isouter=isouterjoin) + isouter=isouterjoin + ) return ancestors_recursive.c def _join_group_members(self, joined_entity, entity_to_join, isouterjoin): @@ -1544,7 +1599,8 @@ def _join_group_members(self, joined_entity, entity_to_join, isouterjoin): self._check_dbentities((joined_entity, self._impl.Group), (entity_to_join, self._impl.Node), 'with_group') aliased_group_nodes = aliased(self._impl.table_groups_nodes) self._query = self._query.join(aliased_group_nodes, aliased_group_nodes.c.dbgroup_id == joined_entity.id).join( - entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, isouter=isouterjoin) + entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, isouter=isouterjoin + ) return aliased_group_nodes def _join_groups(self, joined_entity, entity_to_join, isouterjoin): @@ -1560,7 +1616,8 @@ def _join_groups(self, joined_entity, entity_to_join, isouterjoin): self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Group), 'with_node') aliased_group_nodes = aliased(self._impl.table_groups_nodes) self._query = self._query.join(aliased_group_nodes, aliased_group_nodes.c.dbnode_id == joined_entity.id).join( - 
entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, isouter=isouterjoin) + entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, isouter=isouterjoin + ) return aliased_group_nodes def _join_creator_of(self, joined_entity, entity_to_join, isouterjoin): @@ -1587,7 +1644,8 @@ def _join_to_computer_used(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Computer), (entity_to_join, self._impl.Node), 'with_computer') self._query = self._query.join( - entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin) + entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin + ) def _join_computer(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1596,7 +1654,8 @@ def _join_computer(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Computer), 'with_node') self._query = self._query.join( - entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, isouter=isouterjoin) + entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, isouter=isouterjoin + ) def _join_group_user(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1621,7 +1680,8 @@ def _join_node_comment(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Comment), 'with_node') self._query = self._query.join( - entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) + entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin + ) def _join_comment_node(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1630,7 +1690,8 @@ def _join_comment_node(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Comment), (entity_to_join, self._impl.Node), 'with_comment') self._query = self._query.join( - entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) + entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin + ) def _join_node_log(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1639,7 +1700,8 @@ def _join_node_log(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Log), 'with_node') self._query = self._query.join( - entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) + entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin + ) def _join_log_node(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1648,7 +1710,8 @@ def _join_log_node(self, joined_entity, entity_to_join, isouterjoin): """ self._check_dbentities((joined_entity, self._impl.Log), (entity_to_join, self._impl.Node), 'with_log') self._query = self._query.join( - entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) + entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin + ) def _join_user_comment(self, joined_entity, entity_to_join, isouterjoin): """ @@ -1669,8 +1732,8 @@ def _join_comment_user(self, joined_entity, entity_to_join, isouterjoin): def _get_function_map(self): """ Map relationship type keywords to functions - The new mapping (since 1.0.0a5) is a two level dictionary. 
The first level defines the entity which has been passed to
-        the qb.append functon, and the second defines the relationship with respect to a given tag.
+        The new mapping (since 1.0.0a5) is a two-level dictionary. The first level defines the entity which has been
+        passed to the qb.append function, and the second defines the relationship with respect to a given tag.
         """
         mapping = {
             'node': {
@@ -1722,6 +1785,7 @@ def _get_connecting_node(self, index, joining_keyword=None, joining_value=None,
         :param joining_keyword: the relation on which to join
         :param joining_value: the tag of the nodes to be joined
         """
+        # pylint: disable=unused-argument
         # Set the calling entity - to allow for the correct join relation to be set
         entity_type = self._path[index]['entity_type']

@@ -1743,8 +1807,11 @@ def _get_connecting_node(self, index, joining_keyword=None, joining_value=None,
             try:
                 func = self._get_function_map()[calling_entity][joining_keyword]
             except KeyError:
-                raise InputValidationError("'{}' is not a valid joining keyword for a '{}' type entity".format(
-                    joining_keyword, calling_entity))
+                raise InputValidationError(
+                    "'{}' is not a valid joining keyword for a '{}' type entity".format(
+                        joining_keyword, calling_entity
+                    )
+                )

             if isinstance(joining_value, int):
                 returnval = (self._aliased_path[joining_value], func)
@@ -1752,10 +1819,10 @@ def _get_connecting_node(self, index, joining_keyword=None, joining_value=None,
                 try:
                     returnval = self.tag_to_alias_map[self._get_tag_from_specification(joining_value)], func
                 except KeyError:
-                    raise InputValidationError('Key {} is unknown to the types I know about:\n'
-                                               '{}'.format(
-                                                   self._get_tag_from_specification(joining_value),
-                                                   self.tag_to_alias_map.keys()))
+                    raise InputValidationError(
+                        'Key {} is unknown to the types I know about:\n'
+                        '{}'.format(self._get_tag_from_specification(joining_value), self.tag_to_alias_map.keys())
+                    )
             return returnval

     def get_json_compatible_queryhelp(self):
@@ -1820,9 +1887,11 @@ def _build_order(self, alias, entitytag, entityspec):
         column_name = entitytag.split('.')[0]
         attrpath = entitytag.split('.')[1:]
         if attrpath and 'cast' not in entityspec.keys():
-            raise InputValidationError('In order to project ({}), I have to cast the the values,\n'
-                                       'but you have not specified the datatype to cast to\n'
-                                       "You can do this with keyword 'cast'".format(entitytag))
+            raise InputValidationError(
+                'In order to project ({}), I have to cast the values,\n'
+                'but you have not specified the datatype to cast to\n'
+                "You can do this with keyword 'cast'".format(entitytag)
+            )

         entity = self._get_projectable_entity(alias, column_name, attrpath, **entityspec)
         order = entityspec.get('order', 'asc')
@@ -1834,17 +1903,10 @@ def _build(self):
         """
         build the query and return a sqlalchemy.Query instance
         """
-
-        # self.tags_location_dict is a dictionary that
-        # maps the tag to its index in the list
-        # this is basically the mapping between the count
-        # of nodes traversed
-        # and the tag used for that node
-        self.tags_location_dict = {path['tag']: index for index, path in enumerate(self._path)}
+        # pylint: disable=too-many-branches

         # Starting the query by receiving a session
-        # Every subclass needs to have _get_session and give me the
-        # right session
+        # Every subclass needs to have _get_session and give me the right session
         firstalias = self.tag_to_alias_map[self._path[0]['tag']]
         self._query = self._impl.get_session().query(firstalias)

@@ -1869,7 +1931,8 @@ def _build(self):
                    expand_path = ((self._filters[edge_tag].get('path', None) is not None)
or any(['path' in d.keys() for d in self._projections[edge_tag]])) aliased_edge = connection_func( - toconnectwith, alias, isouterjoin=isouterjoin, filter_dict=filter_dict, expand_path=expand_path) + toconnectwith, alias, isouterjoin=isouterjoin, filter_dict=filter_dict, expand_path=expand_path + ) else: aliased_edge = connection_func(toconnectwith, alias, isouterjoin=isouterjoin) if aliased_edge is not None: @@ -1881,9 +1944,10 @@ def _build(self): try: alias = self.tag_to_alias_map[tag] except KeyError: - # TODO Check KeyError before? - raise InputValidationError('You looked for tag {} among the alias list\n' - 'The tags I know are:\n{}'.format(tag, self.tag_to_alias_map.keys())) + raise InputValidationError( + 'You looked for tag {} among the alias list\n' + 'The tags I know are:\n{}'.format(tag, self.tag_to_alias_map.keys()) + ) self._query = self._query.filter(self._build_filters(alias, filter_specs)) ######################### PROJECTIONS ########################## @@ -1891,8 +1955,6 @@ def _build(self): # path was not meant to be projected # attribute of Query instance storing entities to project: - # Will be later set to this list: - entities = [] # Mapping between entities and the tag used/ given by user: self.tag_to_projected_property_dict = {} @@ -1919,17 +1981,19 @@ def _build(self): edge_tag = vertex.get('edge_tag', None) if self._debug: print('DEBUG: Checking projections for edges:') - print(' This is edge {} from {}, {} of {}'.format(edge_tag, vertex.get('tag'), - vertex.get('joining_keyword'), - vertex.get('joining_value'))) + print( + ' This is edge {} from {}, {} of {}'.format( + edge_tag, vertex.get('tag'), vertex.get('joining_keyword'), vertex.get('joining_value') + ) + ) if edge_tag is not None: self._build_projections(edge_tag) # ORDER ################################ for order_spec in self._order_by: - for tag, entities in order_spec.items(): + for tag, entity_list in order_spec.items(): alias = self.tag_to_alias_map[tag] - for entitydict in entities: + for entitydict in entity_list: for entitytag, entityspec in entitydict.items(): self._build_order(alias, entitytag, entityspec) @@ -1943,7 +2007,7 @@ def _build(self): ################ LAST BUT NOT LEAST ############################ # pop the entity that I added to start the query - self._query._entities.pop(0) + self._query._entities.pop(0) # pylint: disable=protected-access # Dirty solution coming up: # Sqlalchemy is by default de-duplicating results if possible. @@ -1953,7 +2017,7 @@ def _build(self): # We also addressed this with sqlachemy: # https://github.com/sqlalchemy/sqlalchemy/issues/4395#event-2002418814 # where the following solution was sanctioned: - self._query._has_mapper_entities = False + self._query._has_mapper_entities = False # pylint: disable=protected-access # We should monitor SQLAlchemy, for when a solution is officially supported by the API! 
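_build assembles the SQLAlchemy query by walking self._path and dispatching to the join implementations defined earlier, including the recursive CTEs for ancestor/descendant traversal. A hedged sketch of a query that exercises that machinery (the PK is made up):

    from aiida.orm import Node, QueryBuilder

    qb = QueryBuilder()
    qb.append(Node, tag='root', filters={'id': 1234})
    # Nodes that have 'root' among their ancestors, i.e. its descendants,
    # found through the recursive CTE join built above.
    qb.append(Node, with_ancestors='root', project=['id', 'uuid'])
    rows = qb.all()
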
# Make a list that helps the projection postprocessing @@ -2115,10 +2179,7 @@ def first(self): if len(result) != len(self._attrkeys_as_in_sql_result): raise Exception('length of query result does not match the number of specified projections') - return [ - self.get_aiida_entity_res(self._impl.get_aiida_res(rowitem)) - for colindex, rowitem in enumerate(result) - ] + return [self.get_aiida_entity_res(self._impl.get_aiida_res(rowitem)) for colindex, rowitem in enumerate(result)] def one(self): """ diff --git a/aiida/orm/utils/remote.py b/aiida/orm/utils/remote.py index e2e0249c58..49e5c3b44f 100644 --- a/aiida/orm/utils/remote.py +++ b/aiida/orm/utils/remote.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Utilities for operations on files on remote computers.""" import os @@ -73,17 +74,17 @@ def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computer if pks: filters_calc['id'] = {'in': pks} - qb = orm.QueryBuilder() - qb.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc) - qb.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer) - qb.append(orm.User, with_node='calc', filters={'email': user.email}) + query = orm.QueryBuilder() + query.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc) + query.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer) + query.append(orm.User, with_node='calc', filters={'email': user.email}) - if qb.count() == 0: + if query.count() == 0: return None path_mapping = {} - for path, computer in qb.all(): + for path, computer in query.all(): if path is not None: path_mapping.setdefault(computer.uuid, []).append(path) diff --git a/aiida/parsers/plugins/arithmetic/add.py b/aiida/parsers/plugins/arithmetic/add.py index 0bb8a7abf8..0856639448 100644 --- a/aiida/parsers/plugins/arithmetic/add.py +++ b/aiida/parsers/plugins/arithmetic/add.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=inconsistent-return-statements # Warning: this implementation is used directly in the documentation as a literal-include, which means that if any part # of this code is changed, the snippets in the file `docs/source/howto/codes.rst` have to be checked for consistency. 
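Before moving on to the parser changes, a short usage sketch for the get_calcjob_remote_paths helper refactored above; it assumes a loaded profile and that older_than is given in days, as verdi calcjob cleanworkdir does:

    from aiida import load_profile
    from aiida.orm.utils.remote import get_calcjob_remote_paths

    load_profile()
    # Remote working directories of calculation jobs older than 30 days,
    # keyed by computer UUID, or None when nothing matches.
    paths_by_computer = get_calcjob_remote_paths(older_than=30)
    if paths_by_computer is not None:
        for computer_uuid, paths in paths_by_computer.items():
            print(computer_uuid, len(paths))
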
"""Parser for an `ArithmeticAddCalculation` job.""" diff --git a/aiida/parsers/plugins/templatereplacer/doubler.py b/aiida/parsers/plugins/templatereplacer/doubler.py index 93700a3c2c..e3c7d90f8d 100644 --- a/aiida/parsers/plugins/templatereplacer/doubler.py +++ b/aiida/parsers/plugins/templatereplacer/doubler.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Parser for the `TemplatereplacerCalculation` calculation job doubling a number.""" import os from aiida.common import exceptions @@ -19,6 +19,7 @@ class TemplatereplacerDoublerParser(Parser): + """Parser for the `TemplatereplacerCalculation` calculation job doubling a number.""" def parse(self, **kwargs): """Parse the contents of the output files retrieved in the `FolderData`.""" @@ -57,8 +58,11 @@ def parse(self, **kwargs): file_path = os.path.join(retrieved_temporary_folder, retrieved_file) if not os.path.isfile(file_path): - self.logger.error('the file {} was not found in the temporary retrieved folder {}'.format( - retrieved_file, retrieved_temporary_folder)) + self.logger.error( + 'the file {} was not found in the temporary retrieved folder {}'.format( + retrieved_file, retrieved_temporary_folder + ) + ) return self.exit_codes.ERROR_READING_TEMPORARY_RETRIEVED_FILE with open(file_path, 'r', encoding='utf8') as handle: diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 1675ac6cb6..39633995d4 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,inconsistent-return-statements,cyclic-import +# pylint: disable=invalid-name,cyclic-import """Definition of factories to load classes from the various plugin groups.""" from inspect import isclass diff --git a/aiida/restapi/run_api.py b/aiida/restapi/run_api.py index 8cc4df95d8..ba6f91f157 100755 --- a/aiida/restapi/run_api.py +++ b/aiida/restapi/run_api.py @@ -8,7 +8,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=inconsistent-return-statements """ It defines the method with all required parameters to run restapi locally. 
""" diff --git a/aiida/schedulers/plugins/sge.py b/aiida/schedulers/plugins/sge.py index 6bf679b735..ade7595c12 100644 --- a/aiida/schedulers/plugins/sge.py +++ b/aiida/schedulers/plugins/sge.py @@ -322,14 +322,14 @@ def _parse_joblist_output(self, retval, stdout, stderr): try: xmldata = xml.dom.minidom.parseString(stdout) except xml.parsers.expat.ExpatError: - self.logger.error('in sge._parse_joblist_output: xml parsing of stdout failed:' '{}'.format(stdout)) - raise SchedulerParsingError('Error during joblist retrieval,' 'xml parsing of stdout failed') + self.logger.error('in sge._parse_joblist_output: xml parsing of stdout failed: {}'.format(stdout)) + raise SchedulerParsingError('Error during joblist retrieval, xml parsing of stdout failed') else: self.logger.error( 'Error in sge._parse_joblist_output: retval={}; ' 'stdout={}; stderr={}'.format(retval, stdout, stderr) ) - raise SchedulerError('Error during joblist retrieval,' 'no stdout produced') + raise SchedulerError('Error during joblist retrieval, no stdout produced') try: first_child = xmldata.firstChild @@ -381,12 +381,12 @@ def _parse_joblist_output(self, retval, stdout, stderr): self.logger.error('Error in sge._parse_joblist_output:' 'no job id is given, stdout={}' \ .format(stdout)) - raise SchedulerError('Error in sge._parse_joblist_output:' 'no job id is given') + raise SchedulerError('Error in sge._parse_joblist_output: no job id is given') except IndexError: self.logger.error("No 'job_number' given for job index {} in " 'job list, stdout={}'.format(jobs.index(job) \ , stdout)) - raise IndexError('Error in sge._parse_joblist_output:' 'no job id is given') + raise IndexError('Error in sge._parse_joblist_output: no job id is given') try: job_element = job.getElementsByTagName('state').pop(0) diff --git a/aiida/tools/data/array/kpoints/__init__.py b/aiida/tools/data/array/kpoints/__init__.py index eb9e2bef44..a3655b15e2 100644 --- a/aiida/tools/data/array/kpoints/__init__.py +++ b/aiida/tools/data/array/kpoints/__init__.py @@ -11,7 +11,6 @@ Various utilities to deal with KpointsData instances or create new ones (e.g. band paths, kpoints from a parsed input text file, ...) 
""" - from aiida.orm import KpointsData, Dict from aiida.tools.data.array.kpoints import legacy from aiida.tools.data.array.kpoints import seekpath @@ -49,15 +48,6 @@ def get_kpoints_path(structure, method='seekpath', **kwargs): if method not in _GET_KPOINTS_PATH_METHODS.keys(): raise ValueError("the method '{}' is not implemented".format(method)) - if method == 'seekpath': - try: - seekpath.check_seekpath_is_installed() - except ImportError: - raise ValueError( - "selected method is 'seekpath' but the package is not installed\n" - "Either install it or pass method='legacy' as input to the function call" - ) - method = _GET_KPOINTS_PATH_METHODS[method] return method(structure, **kwargs) @@ -94,15 +84,6 @@ def get_explicit_kpoints_path(structure, method='seekpath', **kwargs): if method not in _GET_EXPLICIT_KPOINTS_PATH_METHODS.keys(): raise ValueError("the method '{}' is not implemented".format(method)) - if method == 'seekpath': - try: - seekpath.check_seekpath_is_installed() - except ImportError: - raise ValueError( - "selected method is 'seekpath' but the package is not installed\n" - "Either install it or pass method='legacy' as input to the function call" - ) - method = _GET_EXPLICIT_KPOINTS_PATH_METHODS[method] return method(structure, **kwargs) diff --git a/aiida/tools/data/array/kpoints/legacy.py b/aiida/tools/data/array/kpoints/legacy.py index bc8260ff9e..350db26957 100644 --- a/aiida/tools/data/array/kpoints/legacy.py +++ b/aiida/tools/data/array/kpoints/legacy.py @@ -7,10 +7,10 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Tool to automatically determine k-points for a given structure using legacy custom implementation.""" +# pylint: disable=too-many-lines,fixme,invalid-name,too-many-arguments,too-many-locals,eval-used import numpy - _default_epsilon_length = 1e-5 _default_epsilon_angle = 1e-5 @@ -36,6 +36,7 @@ def change_reference(reciprocal_cell, kpoints, to_cartesian=True): # hence, first transpose kpoints, then multiply, finally transpose it back return numpy.transpose(numpy.dot(matrix, numpy.transpose(kpoints))) + def analyze_cell(cell=None, pbc=None): """ A function executed by the __init__ or by set_cell. @@ -50,11 +51,7 @@ def analyze_cell(cell=None, pbc=None): dimension = sum(pbc) if cell is None: - return { - 'reciprocal_cell': None, - 'dimension': dimension, - 'pbc': pbc - } + return {'reciprocal_cell': None, 'dimension': dimension, 'pbc': pbc} the_cell = numpy.array(cell) reciprocal_cell = 2. * numpy.pi * numpy.linalg.inv(the_cell).transpose() @@ -71,7 +68,6 @@ def analyze_cell(cell=None, pbc=None): cosbeta = numpy.dot(a3, a1) / c / a cosgamma = numpy.dot(a1, a2) / a / b - result = { 'a1': a1, 'a2': a2, @@ -93,9 +89,15 @@ def analyze_cell(cell=None, pbc=None): return result -def get_explicit_kpoints_path(value=None, cell=None, pbc=None, kpoint_distance=None, cartesian=False, - epsilon_length=_default_epsilon_length, - epsilon_angle=_default_epsilon_angle): +def get_explicit_kpoints_path( + value=None, + cell=None, + pbc=None, + kpoint_distance=None, + cartesian=False, + epsilon_length=_default_epsilon_length, + epsilon_angle=_default_epsilon_angle +): """ Set a path of kpoints in the Brillouin zone. 
@@ -131,11 +133,8 @@ def get_explicit_kpoints_path(value=None, cell=None, pbc=None, kpoint_distance=N :returns: point_coordinates, path, bravais_info, explicit_kpoints, labels """ - bravais_info = find_bravais_info( - cell=cell, pbc=pbc, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle - ) + # pylint: disable=too-many-branches,too-many-statements + bravais_info = find_bravais_info(cell=cell, pbc=pbc, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle) analysis = analyze_cell(cell, pbc) dimension = analysis['dimension'] @@ -166,18 +165,14 @@ def _is_path_2(path): if not are_three: return False - are_good = all([all([isinstance(b[0], str), - isinstance(b[1], str), - isinstance(b[2], int)]) - for b in path]) + are_good = all([all([isinstance(b[0], str), isinstance(b[1], str), isinstance(b[2], int)]) for b in path]) if not are_good: return False # check that at least two points per segment (beginning and end) points_num = [int(i[2]) for i in path] if any([i < 2 for i in points_num]): - raise ValueError('Must set at least two points per path ' - 'segment') + raise ValueError('Must set at least two points per path segment') except IndexError: return False @@ -218,8 +213,7 @@ def _is_path_4(path): # check that at least two points per segment (beginning and end) points_num = [int(i[4]) for i in path] if any([i < 2 for i in points_num]): - raise ValueError('Must set at least two points per path ' - 'segment') + raise ValueError('Must set at least two points per path segment') for i in path: coord1 = [float(j) for j in i[1]] coord2 = [float(j) for j in i[3]] @@ -232,9 +226,9 @@ def _is_path_4(path): def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): # NOTE: this way of creating intervals ensures equispaced objects # in crystal coordinates of b1,b2,b3 - distances = [numpy.linalg.norm(numpy.array(point_coordinates[i[0]]) - - numpy.array(point_coordinates[i[1]]) - ) for i in path] + distances = [ + numpy.linalg.norm(numpy.array(point_coordinates[i[0]]) - numpy.array(point_coordinates[i[1]])) for i in path + ] if kpoint_distance is None: # Use max_points_per_interval as the default guess for automatically @@ -244,8 +238,7 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): try: points_per_piece = [max(2, max_point_per_interval * i // max_interval) for i in distances] except ValueError: - raise ValueError('The beginning and end of each segment in the ' - 'path should be different.') + raise ValueError('The beginning and end of each segment in the path should be different.') else: points_per_piece = [max(2, int(distance // kpoint_distance)) for distance in distances] @@ -253,8 +246,7 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): if cartesian: if cell is None: - raise ValueError('To use cartesian coordinates, a cell must ' - 'be provided') + raise ValueError('To use cartesian coordinates, a cell must be provided') if kpoint_distance is not None: if kpoint_distance <= 0.: @@ -262,38 +254,32 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): if value is None: if cell is None: - raise ValueError('Cannot set a path not even knowing the ' - 'kpoints or at least the cell') + raise ValueError('Cannot set a path not even knowing the kpoints or at least the cell') point_coordinates, path, bravais_info = get_kpoints_path( - cell=cell, pbc=pbc, cartesian=cartesian, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle) + cell=cell, pbc=pbc, 
cartesian=cartesian, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle + ) num_points = _num_points_from_coordinates(path, point_coordinates, kpoint_distance) elif _is_path_1(value): # in the form [('X','M'),(...),...] if cell is None: - raise ValueError('Cannot set a path not even knowing the ' - 'kpoints or at least the cell') + raise ValueError('Cannot set a path not even knowing the kpoints or at least the cell') path = value point_coordinates, _, bravais_info = get_kpoints_path( - cell=cell, pbc=pbc, cartesian=cartesian, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle) + cell=cell, pbc=pbc, cartesian=cartesian, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle + ) num_points = _num_points_from_coordinates(path, point_coordinates, kpoint_distance) elif _is_path_2(value): # [('G','M',30), (...), ...] if cell is None: - raise ValueError('Cannot set a path not even knowing the ' - 'kpoints or at least the cell') + raise ValueError('Cannot set a path not even knowing the kpoints or at least the cell') path = [(i[0], i[1]) for i in value] point_coordinates, _, bravais_info = get_kpoints_path( - cell=cell, pbc=pbc, cartesian=cartesian, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle) + cell=cell, pbc=pbc, cartesian=cartesian, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle + ) num_points = [i[2] for i in value] elif _is_path_3(value): @@ -307,10 +293,8 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): raise ValueError('Different points cannot have the same label') else: if cartesian: - point_coordinates[piece[0]] = change_reference( - reciprocal_cell, - numpy.array([piece[1]]), - to_cartesian=False)[0] + point_coordinates[ + piece[0]] = change_reference(reciprocal_cell, numpy.array([piece[1]]), to_cartesian=False)[0] else: point_coordinates[piece[0]] = piece[1] if piece[2] in point_coordinates: @@ -318,10 +302,8 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): raise ValueError('Different points cannot have the same label') else: if cartesian: - point_coordinates[piece[2]] = change_reference( - reciprocal_cell, - numpy.array([piece[3]]), - to_cartesian=False)[0] + point_coordinates[ + piece[2]] = change_reference(reciprocal_cell, numpy.array([piece[3]]), to_cartesian=False)[0] else: point_coordinates[piece[2]] = piece[3] @@ -338,10 +320,8 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): raise ValueError('Different points cannot have the same label') else: if cartesian: - point_coordinates[piece[0]] = change_reference( - reciprocal_cell, - numpy.array([piece[1]]), - to_cartesian=False)[0] + point_coordinates[ + piece[0]] = change_reference(reciprocal_cell, numpy.array([piece[1]]), to_cartesian=False)[0] else: point_coordinates[piece[0]] = piece[1] if piece[2] in point_coordinates: @@ -349,10 +329,8 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): raise ValueError('Different points cannot have the same label') else: if cartesian: - point_coordinates[piece[2]] = change_reference( - reciprocal_cell, - numpy.array([piece[3]]), - to_cartesian=False)[0] + point_coordinates[ + piece[2]] = change_reference(reciprocal_cell, numpy.array([piece[3]]), to_cartesian=False)[0] else: point_coordinates[piece[2]] = piece[3] @@ -370,19 +348,19 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): ini_coord = point_coordinates[ini_label] end_coord = point_coordinates[end_label] - 
path_piece = list(zip(numpy.linspace(ini_coord[0], end_coord[0], - num_points[count_piece]), - numpy.linspace(ini_coord[1], end_coord[1], - num_points[count_piece]), - numpy.linspace(ini_coord[2], end_coord[2], - num_points[count_piece]), - )) + path_piece = list( + zip( + numpy.linspace(ini_coord[0], end_coord[0], num_points[count_piece]), + numpy.linspace(ini_coord[1], end_coord[1], num_points[count_piece]), + numpy.linspace(ini_coord[2], end_coord[2], num_points[count_piece]), + ) + ) for count, j in enumerate(path_piece): if all(numpy.array(explicit_kpoints[-1]) == j): continue # avoid duplcates - else: - explicit_kpoints.append(j) + + explicit_kpoints.append(j) # add labels for the first and last point if count == 0: @@ -396,8 +374,7 @@ def _num_points_from_coordinates(path, point_coordinates, kpoint_distance=None): return point_coordinates, path, bravais_info, explicit_kpoints, labels -def find_bravais_info(cell, pbc, epsilon_length=_default_epsilon_length, - epsilon_angle=_default_epsilon_angle): +def find_bravais_info(cell, pbc, epsilon_length=_default_epsilon_length, epsilon_angle=_default_epsilon_angle): """ Finds the Bravais lattice of the cell passed in input to the Kpoint class :note: We assume that the cell given by the cell property is the @@ -425,6 +402,7 @@ def find_bravais_info(cell, pbc, epsilon_length=_default_epsilon_length, the variation of the Bravais lattice) and extra (a dictionary with extra parameters used by the get_kpoints_path method) """ + # pylint: disable=too-many-branches,too-many-statements if cell is None: return None @@ -467,10 +445,8 @@ def a_are_equals(a, b): # 3D case -> 14 possible Bravais lattices # # =========================================# - comparison_length = [l_are_equals(a, b), l_are_equals(b, c), - l_are_equals(c, a)] - comparison_angles = [a_are_equals(cosa, cosb), a_are_equals(cosb, cosc), - a_are_equals(cosc, cosa)] + comparison_length = [l_are_equals(a, b), l_are_equals(b, c), l_are_equals(c, a)] + comparison_angles = [a_are_equals(cosa, cosb), a_are_equals(cosb, cosc), a_are_equals(cosc, cosa)] if comparison_length.count(True) == 3: @@ -479,68 +455,70 @@ def a_are_equals(a, b): orci_b = numpy.linalg.norm(a1 + a3) orci_c = numpy.linalg.norm(a1 + a2) orci_the_a, orci_the_b, orci_the_c = sorted([orci_a, orci_b, orci_c]) - bco1 = - (-orci_the_a ** 2 + orci_the_b ** 2 + orci_the_c ** 2) / (4. * a ** 2) - bco2 = - (orci_the_a ** 2 - orci_the_b ** 2 + orci_the_c ** 2) / (4. * a ** 2) - bco3 = - (orci_the_a ** 2 + orci_the_b ** 2 - orci_the_c ** 2) / (4. * a ** 2) + bco1 = -(-orci_the_a**2 + orci_the_b**2 + orci_the_c**2) / (4. * a**2) + bco2 = -(orci_the_a**2 - orci_the_b**2 + orci_the_c**2) / (4. * a**2) + bco3 = -(orci_the_a**2 + orci_the_b**2 - orci_the_c**2) / (4. 
* a**2) # ======================# # simple cubic lattice # # ======================# if comparison_angles.count(True) == 3 and a_are_equals(cosa, _90): - bravais_info = {'short_name': 'cub', - 'extended_name': 'cubic', - 'index': 1, - 'permutation': [0, 1, 2] - } + bravais_info = {'short_name': 'cub', 'extended_name': 'cubic', 'index': 1, 'permutation': [0, 1, 2]} # =====================# # face centered cubic # # =====================# elif comparison_angles.count(True) == 3 and a_are_equals(cosa, _60): - bravais_info = {'short_name': 'fcc', - 'extended_name': 'face centered cubic', - 'index': 2, - 'permutation': [0, 1, 2] - } + bravais_info = { + 'short_name': 'fcc', + 'extended_name': 'face centered cubic', + 'index': 2, + 'permutation': [0, 1, 2] + } # =====================# # body centered cubic # # =====================# elif comparison_angles.count(True) == 3 and a_are_equals(cosa, -1. / 3.): - bravais_info = {'short_name': 'bcc', - 'extended_name': 'body centered cubic', - 'index': 3, - 'permutation': [0, 1, 2] - } + bravais_info = { + 'short_name': 'bcc', + 'extended_name': 'body centered cubic', + 'index': 3, + 'permutation': [0, 1, 2] + } # ==============# # rhombohedral # # ==============# elif comparison_angles.count(True) == 3: # logical order is important, this check must come after the cubic cases - bravais_info = {'short_name': 'rhl', - 'extended_name': 'rhombohedral', - 'index': 11, - 'permutation': [0, 1, 2] - } + bravais_info = { + 'short_name': 'rhl', + 'extended_name': 'rhombohedral', + 'index': 11, + 'permutation': [0, 1, 2] + } if cosa > 0.: bravais_info['variation'] = 'rhl1' eta = (1. + 4. * cosa) / (2. + 4. * cosa) - bravais_info['extra'] = {'eta': eta, - 'nu': 0.75 - eta / 2., - } + bravais_info['extra'] = { + 'eta': eta, + 'nu': 0.75 - eta / 2., + } else: bravais_info['variation'] = 'rhl2' eta = 1. / (2. * (1. - cosa) / (1. + cosa)) - bravais_info['extra'] = {'eta': eta, - 'nu': 0.75 - eta / 2., - } + bravais_info['extra'] = { + 'eta': eta, + 'nu': 0.75 - eta / 2., + } # ==========================# # body centered tetragonal # # ==========================# elif comparison_angles.count(True) == 1: # two angles are the same - bravais_info = {'short_name': 'bct', - 'extended_name': 'body centered tetragonal', - 'index': 5, - } + bravais_info = { + 'short_name': 'bct', + 'extended_name': 'body centered tetragonal', + 'index': 5, + } if comparison_angles.index(True) == 0: # alfa=beta ref_ang = cosa bravais_info['permutation'] = [0, 1, 2] @@ -552,31 +530,39 @@ def a_are_equals(a, b): bravais_info['permutation'] = [1, 2, 0] if ref_ang >= 0.: - raise ValueError('Problems on the definition of ' - 'body centered tetragonal lattices') - the_c = numpy.sqrt(-4. * ref_ang * (a ** 2)) - the_a = numpy.sqrt(2. * a ** 2 - (the_c ** 2) / 2.) + raise ValueError('Problems on the definition of body centered tetragonal lattices') + the_c = numpy.sqrt(-4. * ref_ang * (a**2)) + the_a = numpy.sqrt(2. * a**2 - (the_c**2) / 2.) if the_c < the_a: bravais_info['variation'] = 'bct1' - bravais_info['extra'] = {'eta': (1. + (the_c / the_a) ** 2) / 4.} + bravais_info['extra'] = {'eta': (1. + (the_c / the_a)**2) / 4.} else: bravais_info['variation'] = 'bct2' - bravais_info['extra'] = {'eta': (1. + (the_a / the_c) ** 2) / 4., - 'csi': ((the_a / the_c) ** 2) / 2., - } + bravais_info['extra'] = { + 'eta': (1. 
+ (the_a / the_c)**2) / 4., + 'csi': ((the_a / the_c)**2) / 2., + } # ============================# # body centered orthorhombic # # ============================# - elif (any([a_are_equals(cosa, bco1), a_are_equals(cosb, bco1), a_are_equals(cosc, bco1)]) and - any([a_are_equals(cosa, bco2), a_are_equals(cosb, bco2), a_are_equals(cosc, bco2)]) and - any([a_are_equals(cosa, bco3), a_are_equals(cosb, bco3), a_are_equals(cosc, bco3)]) - ): - bravais_info = {'short_name': 'orci', - 'extended_name': 'body centered orthorhombic', - 'index': 8, - } + elif ( + any([a_are_equals(cosa, bco1), + a_are_equals(cosb, bco1), + a_are_equals(cosc, bco1)]) and + any([a_are_equals(cosa, bco2), + a_are_equals(cosb, bco2), + a_are_equals(cosc, bco2)]) and + any([a_are_equals(cosa, bco3), + a_are_equals(cosb, bco3), + a_are_equals(cosc, bco3)]) + ): + bravais_info = { + 'short_name': 'orci', + 'extended_name': 'body centered orthorhombic', + 'index': 8, + } if a_are_equals(cosa, bco1) and a_are_equals(cosc, bco3): bravais_info['permutation'] = [0, 1, 2] if a_are_equals(cosa, bco1) and a_are_equals(cosc, bco2): @@ -590,58 +576,64 @@ def a_are_equals(a, b): if a_are_equals(cosa, bco3) and a_are_equals(cosc, bco1): bravais_info['permutation'] = [2, 1, 0] - bravais_info['extra'] = {'csi': (1. + (orci_the_a / orci_the_c) ** 2) / 4., - 'eta': (1. + (orci_the_b / orci_the_c) ** 2) / 4., - 'dlt': (orci_the_b ** 2 - orci_the_a ** 2) / (4. * orci_the_c ** 2), - 'mu': (orci_the_a ** 2 + orci_the_b ** 2) / (4. * orci_the_c ** 2), - } + bravais_info['extra'] = { + 'csi': (1. + (orci_the_a / orci_the_c)**2) / 4., + 'eta': (1. + (orci_the_b / orci_the_c)**2) / 4., + 'dlt': (orci_the_b**2 - orci_the_a**2) / (4. * orci_the_c**2), + 'mu': (orci_the_a**2 + orci_the_b**2) / (4. * orci_the_c**2), + } # if it doesn't fall in the above, is triclinic else: - bravais_info = {'short_name': 'tri', - 'extended_name': 'triclinic', - 'index': 14, - } + bravais_info = { + 'short_name': 'tri', + 'extended_name': 'triclinic', + 'index': 14, + } # the check for triclinic variations is at the end of the method - - elif comparison_length.count(True) == 1: + # ============# # tetragonal # # ============# if comparison_angles.count(True) == 3 and a_are_equals(cosa, _90): - bravais_info = {'short_name': 'tet', - 'extended_name': 'tetragonal', - 'index': 4, - } - if comparison_length[0] == True: + bravais_info = { + 'short_name': 'tet', + 'extended_name': 'tetragonal', + 'index': 4, + } + if comparison_length[0]: bravais_info['permutation'] = [0, 1, 2] - if comparison_length[1] == True: + if comparison_length[1]: bravais_info['permutation'] = [2, 0, 1] - if comparison_length[2] == True: + if comparison_length[2]: bravais_info['permutation'] = [1, 2, 0] # ====================================# # c-centered orthorombic + hexagonal # # ====================================# # alpha/=beta=gamma=pi/2 - elif (comparison_angles.count(True) == 1 and - any([a_are_equals(cosa, _90), a_are_equals(cosb, _90), a_are_equals(cosc, _90)]) - ): + elif ( + comparison_angles.count(True) == 1 and + any([a_are_equals(cosa, _90), a_are_equals(cosb, _90), + a_are_equals(cosc, _90)]) + ): if any([a_are_equals(cosa, _120), a_are_equals(cosb, _120), a_are_equals(cosc, _120)]): - bravais_info = {'short_name': 'hex', - 'extended_name': 'hexagonal', - 'index': 10, - } + bravais_info = { + 'short_name': 'hex', + 'extended_name': 'hexagonal', + 'index': 10, + } else: - bravais_info = {'short_name': 'orcc', - 'extended_name': 'c-centered orthorhombic', - 'index': 9, - } - if 
comparison_length[0] == True: + bravais_info = { + 'short_name': 'orcc', + 'extended_name': 'c-centered orthorhombic', + 'index': 9, + } + if comparison_length[0]: the_a1 = a1 the_a2 = a2 - elif comparison_length[1] == True: + elif comparison_length[1]: the_a1 = a2 the_a2 = a3 else: # comparison_length[2]==True: @@ -649,38 +641,40 @@ def a_are_equals(a, b): the_a2 = a1 the_a = numpy.linalg.norm(the_a1 + the_a2) the_b = numpy.linalg.norm(the_a1 - the_a2) - bravais_info['extra'] = {'csi': (1. + (the_a / the_b) ** 2) / 4., - } + bravais_info['extra'] = { + 'csi': (1. + (the_a / the_b)**2) / 4., + } # TODO : re-check this case, permutations look weird - if comparison_length[0] == True: + if comparison_length[0]: bravais_info['permutation'] = [0, 1, 2] - if comparison_length[1] == True: + if comparison_length[1]: bravais_info['permutation'] = [2, 0, 1] - if comparison_length[2] == True: + if comparison_length[2]: bravais_info['permutation'] = [1, 2, 0] # =======================# # c-centered monoclinic # # =======================# elif comparison_angles.count(True) == 1: - bravais_info = {'short_name': 'mclc', - 'extended_name': 'c-centered monoclinic', - 'index': 13, - } + bravais_info = { + 'short_name': 'mclc', + 'extended_name': 'c-centered monoclinic', + 'index': 13, + } # TODO : re-check this case, permutations look weird - if comparison_length[0] == True: + if comparison_length[0]: bravais_info['permutation'] = [0, 1, 2] the_ka = cosa the_a1 = a1 the_a2 = a2 the_c = c - if comparison_length[1] == True: + if comparison_length[1]: bravais_info['permutation'] = [2, 0, 1] the_ka = cosb the_a1 = a2 the_a2 = a3 the_c = a - if comparison_length[2] == True: + if comparison_length[2]: bravais_info['permutation'] = [1, 2, 0] the_ka = cosc the_a1 = a3 @@ -693,94 +687,99 @@ def a_are_equals(a, b): if a_are_equals(the_ka, _90): # order matters: has to be before the check on mclc1 bravais_info['variation'] = 'mclc2' - csi = (2. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa ** 2)) - psi = 0.75 - the_a ** 2 / (4. * the_b * (1. - the_cosa ** 2)) - bravais_info['extra'] = {'csi': csi, - 'eta': 0.5 + 2. * csi * the_c * the_cosa / the_b, - 'psi': psi, - 'phi': psi + (0.75 - psi) * the_b * the_cosa / the_c, - } + csi = (2. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa**2)) + psi = 0.75 - the_a**2 / (4. * the_b * (1. - the_cosa**2)) + bravais_info['extra'] = { + 'csi': csi, + 'eta': 0.5 + 2. * csi * the_c * the_cosa / the_b, + 'psi': psi, + 'phi': psi + (0.75 - psi) * the_b * the_cosa / the_c, + } elif the_ka < 0.: bravais_info['variation'] = 'mclc1' - csi = (2. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa ** 2)) - psi = 0.75 - the_a ** 2 / (4. * the_b * (1. - the_cosa ** 2)) - bravais_info['extra'] = {'csi': csi, - 'eta': 0.5 + 2. * csi * the_c * the_cosa / the_b, - 'psi': psi, - 'phi': psi + (0.75 - psi) * the_b * the_cosa / the_c, - } + csi = (2. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa**2)) + psi = 0.75 - the_a**2 / (4. * the_b * (1. - the_cosa**2)) + bravais_info['extra'] = { + 'csi': csi, + 'eta': 0.5 + 2. * csi * the_c * the_cosa / the_b, + 'psi': psi, + 'phi': psi + (0.75 - psi) * the_b * the_cosa / the_c, + } else: # if the_ka>0.: - x = the_b * the_cosa / the_c + the_b ** 2 * (1. - the_cosa ** 2) / the_a ** 2 + x = the_b * the_cosa / the_c + the_b**2 * (1. - the_cosa**2) / the_a**2 if a_are_equals(x, 1.): bravais_info['variation'] = 'mclc4' # order matters here too - mu = (1. + (the_b / the_a) ** 2) / 4. - dlt = the_b * the_c * the_cosa / (2. 
* the_a ** 2) - csi = mu - 0.25 + (1. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa ** 2)) + mu = (1. + (the_b / the_a)**2) / 4. + dlt = the_b * the_c * the_cosa / (2. * the_a**2) + csi = mu - 0.25 + (1. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa**2)) eta = 0.5 + 2. * csi * the_c * the_cosa / the_b phi = 1. + eta - 2. * mu psi = eta - 2. * dlt - bravais_info['extra'] = {'mu': mu, - 'dlt': dlt, - 'csi': csi, - 'eta': eta, - 'phi': phi, - 'psi': psi, - } + bravais_info['extra'] = { + 'mu': mu, + 'dlt': dlt, + 'csi': csi, + 'eta': eta, + 'phi': phi, + 'psi': psi, + } elif x < 1.: bravais_info['variation'] = 'mclc3' - mu = (1. + (the_b / the_a) ** 2) / 4. - dlt = the_b * the_c * the_cosa / (2. * the_a ** 2) - csi = mu - 0.25 + (1. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa ** 2)) + mu = (1. + (the_b / the_a)**2) / 4. + dlt = the_b * the_c * the_cosa / (2. * the_a**2) + csi = mu - 0.25 + (1. - the_b * the_cosa / the_c) / (4. * (1. - the_cosa**2)) eta = 0.5 + 2. * csi * the_c * the_cosa / the_b phi = 1. + eta - 2. * mu psi = eta - 2. * dlt - bravais_info['extra'] = {'mu': mu, - 'dlt': dlt, - 'csi': csi, - 'eta': eta, - 'phi': phi, - 'psi': psi, - } + bravais_info['extra'] = { + 'mu': mu, + 'dlt': dlt, + 'csi': csi, + 'eta': eta, + 'phi': phi, + 'psi': psi, + } elif x > 1.: bravais_info['variation'] = 'mclc5' - csi = ((the_b / the_a) ** 2 + (1. - the_b * the_cosa / the_c) / (1. - the_cosa ** 2)) / 4. + csi = ((the_b / the_a)**2 + (1. - the_b * the_cosa / the_c) / (1. - the_cosa**2)) / 4. eta = 0.5 + 2. * csi * the_c * the_cosa / the_b - mu = eta / 2. + the_b ** 2 / 4. / the_a ** 2 - the_b * the_c * the_cosa / 2. / the_a ** 2 + mu = eta / 2. + the_b**2 / 4. / the_a**2 - the_b * the_c * the_cosa / 2. / the_a**2 nu = 2. * mu - csi - omg = (4. * nu - 1. - the_b ** 2 * (1. - the_cosa ** 2) / the_a ** 2) * the_c / ( - 2. * the_b * the_cosa) + omg = (4. * nu - 1. - the_b**2 * + (1. - the_cosa**2) / the_a**2) * the_c / (2. * the_b * the_cosa) dlt = csi * the_c * the_cosa / the_b + omg / 2. - 0.25 - rho = 1. - csi * the_a ** 2 / the_b ** 2 - bravais_info['extra'] = {'mu': mu, - 'dlt': dlt, - 'csi': csi, - 'eta': eta, - 'rho': rho, - } + rho = 1. 
- csi * the_a**2 / the_b**2 + bravais_info['extra'] = { + 'mu': mu, + 'dlt': dlt, + 'csi': csi, + 'eta': eta, + 'rho': rho, + } # if it doesn't fall in the above, is triclinic else: - bravais_info = {'short_name': 'tri', - 'extended_name': 'triclinic', - 'index': 14, - } + bravais_info = { + 'short_name': 'tri', + 'extended_name': 'triclinic', + 'index': 14, + } # the check for triclinic variations is at the end of the method - - else: # if comparison_length.count(True)==0: - fco1 = c ** 2 / numpy.sqrt((a ** 2 + c ** 2) * (b ** 2 + c ** 2)) - fco2 = a ** 2 / numpy.sqrt((a ** 2 + b ** 2) * (a ** 2 + c ** 2)) - fco3 = b ** 2 / numpy.sqrt((a ** 2 + b ** 2) * (b ** 2 + c ** 2)) + fco1 = c**2 / numpy.sqrt((a**2 + c**2) * (b**2 + c**2)) + fco2 = a**2 / numpy.sqrt((a**2 + b**2) * (a**2 + c**2)) + fco3 = b**2 / numpy.sqrt((a**2 + b**2) * (b**2 + c**2)) # ==============# # orthorhombic # # ==============# if comparison_angles.count(True) == 3: - bravais_info = {'short_name': 'orc', - 'extended_name': 'orthorhombic', - 'index': 6, - } + bravais_info = { + 'short_name': 'orc', + 'extended_name': 'orthorhombic', + 'index': 6, + } lens = [a, b, c] ind_a = lens.index(min(lens)) ind_c = lens.index(max(lens)) @@ -799,12 +798,16 @@ def a_are_equals(a, b): # ============# # monoclinic # # ============# - elif (comparison_angles.count(True) == 1 and - any([a_are_equals(cosa, _90), a_are_equals(cosb, _90), a_are_equals(cosc, _90)])): - bravais_info = {'short_name': 'mcl', - 'extended_name': 'monoclinic', - 'index': 12, - } + elif ( + comparison_angles.count(True) == 1 and + any([a_are_equals(cosa, _90), a_are_equals(cosb, _90), + a_are_equals(cosc, _90)]) + ): + bravais_info = { + 'short_name': 'mcl', + 'extended_name': 'monoclinic', + 'index': 12, + } lens = [a, b, c] # find the angle different from 90 # then order (if possible) a 0.: bravais_info['variation'] = 'orcf1' - bravais_info['extra'] = {'csi': (1. + (the_a / the_b) ** 2 - (the_a / the_c) ** 2) / 4., - 'eta': (1. + (the_a / the_b) ** 2 + (the_a / the_c) ** 2) / 4., - } + bravais_info['extra'] = { + 'csi': (1. + (the_a / the_b)**2 - (the_a / the_c)**2) / 4., + 'eta': (1. + (the_a / the_b)**2 + (the_a / the_c)**2) / 4., + } # orcf2 else: bravais_info['variation'] = 'orcf2' - bravais_info['extra'] = {'eta': (1. + (the_a / the_b) ** 2 - (the_a / the_c) ** 2) / 4., - 'dlt': (1. + (the_b / the_a) ** 2 + (the_b / the_c) ** 2) / 4., - 'phi': (1. + (the_c / the_b) ** 2 - (the_c / the_a) ** 2) / 4., - } + bravais_info['extra'] = { + 'eta': (1. + (the_a / the_b)**2 - (the_a / the_c)**2) / 4., + 'dlt': (1. + (the_b / the_a)**2 + (the_b / the_c)**2) / 4., + 'phi': (1. 
+ (the_c / the_b)**2 - (the_c / the_a)**2) / 4., + } else: - bravais_info = {'short_name': 'tri', - 'extended_name': 'triclinic', - 'index': 14, - } + bravais_info = { + 'short_name': 'tri', + 'extended_name': 'triclinic', + 'index': 14, + } # ===========# # triclinic # # ===========# @@ -1015,19 +1031,21 @@ def a_are_equals(a, b): # square lattice # # ================# if comparison_angle_90 and comparison_length: - bravais_info = {'short_name': 'sq', - 'extended_name': 'square', - 'index': 1, - } + bravais_info = { + 'short_name': 'sq', + 'extended_name': 'square', + 'index': 1, + } # =========================# # (primitive) rectangular # # =========================# elif comparison_angle_90: - bravais_info = {'short_name': 'rec', - 'extended_name': 'rectangular', - 'index': 2, - } + bravais_info = { + 'short_name': 'rec', + 'extended_name': 'rectangular', + 'index': 2, + } # set the order such that first_vector < second_vector in norm if lens[0] > lens[1]: in_plane_indexes.reverse() @@ -1037,30 +1055,31 @@ def a_are_equals(a, b): # ===========# # this has to be put before the centered-rectangular case elif (l_are_equals(lens[0], lens[1]) and a_are_equals(cosphi, _120)): - bravais_info = {'short_name': 'hex', - 'extended_name': 'hexagonal', - 'index': 4, - } + bravais_info = { + 'short_name': 'hex', + 'extended_name': 'hexagonal', + 'index': 4, + } # ======================# # centered rectangular # # ======================# - elif (comparison_length and - l_are_equals(numpy.dot(vectors[0] + vectors[1], - vectors[0] - vectors[1]), 0.)): - bravais_info = {'short_name': 'recc', - 'extended_name': 'centered rectangular', - 'index': 3, - } + elif (comparison_length and l_are_equals(numpy.dot(vectors[0] + vectors[1], vectors[0] - vectors[1]), 0.)): + bravais_info = { + 'short_name': 'recc', + 'extended_name': 'centered rectangular', + 'index': 3, + } # =========# # oblique # # =========# else: - bravais_info = {'short_name': 'obl', - 'extended_name': 'oblique', - 'index': 5, - } + bravais_info = { + 'short_name': 'obl', + 'extended_name': 'oblique', + 'index': 5, + } # set the order such that first_vector < second_vector in norm if lens[0] > lens[1]: in_plane_indexes.reverse() @@ -1098,10 +1117,9 @@ def a_are_equals(a, b): return bravais_info - -def get_kpoints_path(cell, pbc=None, cartesian=False, - epsilon_length=_default_epsilon_length, - epsilon_angle=_default_epsilon_angle): +def get_kpoints_path( + cell, pbc=None, cartesian=False, epsilon_length=_default_epsilon_length, epsilon_angle=_default_epsilon_angle +): """ Get the special point and path of a given structure. 
@@ -1136,12 +1154,9 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, :note: We assume that the cell given by the cell property is the primitive unit cell """ + # pylint: disable=too-many-branches,too-many-statements # recognize which bravais lattice we are dealing with - bravais_info = find_bravais_info( - cell=cell, pbc=pbc, - epsilon_length=epsilon_length, - epsilon_angle=epsilon_angle - ) + bravais_info = find_bravais_info(cell=cell, pbc=pbc, epsilon_length=epsilon_length, epsilon_angle=epsilon_angle) analysis = analyze_cell(cell, pbc) dimension = analysis['dimension'] @@ -1153,98 +1168,108 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, # 3D case: 14 Bravais lattices # simple cubic if bravais_info['index'] == 1: - special_points = {'G': [0., 0., 0.], - 'M': [0.5, 0.5, 0.], - 'R': [0.5, 0.5, 0.5], - 'X': [0., 0.5, 0.], - } - path = [('G', 'X'), - ('X', 'M'), - ('M', 'G'), - ('G', 'R'), - ('R', 'X'), - ('M', 'R'), - ] + special_points = { + 'G': [0., 0., 0.], + 'M': [0.5, 0.5, 0.], + 'R': [0.5, 0.5, 0.5], + 'X': [0., 0.5, 0.], + } + path = [ + ('G', 'X'), + ('X', 'M'), + ('M', 'G'), + ('G', 'R'), + ('R', 'X'), + ('M', 'R'), + ] # face centered cubic elif bravais_info['index'] == 2: - special_points = {'G': [0., 0., 0.], - 'K': [3. / 8., 3. / 8., 0.75], - 'L': [0.5, 0.5, 0.5], - 'U': [5. / 8., 0.25, 5. / 8.], - 'W': [0.5, 0.25, 0.75], - 'X': [0.5, 0., 0.5], - } - path = [('G', 'X'), - ('X', 'W'), - ('W', 'K'), - ('K', 'G'), - ('G', 'L'), - ('L', 'U'), - ('U', 'W'), - ('W', 'L'), - ('L', 'K'), - ('U', 'X'), - ] + special_points = { + 'G': [0., 0., 0.], + 'K': [3. / 8., 3. / 8., 0.75], + 'L': [0.5, 0.5, 0.5], + 'U': [5. / 8., 0.25, 5. / 8.], + 'W': [0.5, 0.25, 0.75], + 'X': [0.5, 0., 0.5], + } + path = [ + ('G', 'X'), + ('X', 'W'), + ('W', 'K'), + ('K', 'G'), + ('G', 'L'), + ('L', 'U'), + ('U', 'W'), + ('W', 'L'), + ('L', 'K'), + ('U', 'X'), + ] # body centered cubic elif bravais_info['index'] == 3: - special_points = {'G': [0., 0., 0.], - 'H': [0.5, -0.5, 0.5], - 'P': [0.25, 0.25, 0.25], - 'N': [0., 0., 0.5], - } - path = [('G', 'H'), - ('H', 'N'), - ('N', 'G'), - ('G', 'P'), - ('P', 'H'), - ('P', 'N'), - ] + special_points = { + 'G': [0., 0., 0.], + 'H': [0.5, -0.5, 0.5], + 'P': [0.25, 0.25, 0.25], + 'N': [0., 0., 0.5], + } + path = [ + ('G', 'H'), + ('H', 'N'), + ('N', 'G'), + ('G', 'P'), + ('P', 'H'), + ('P', 'N'), + ] # Tetragonal elif bravais_info['index'] == 4: - special_points = {'G': [0., 0., 0.], - 'A': [0.5, 0.5, 0.5], - 'M': [0.5, 0.5, 0.], - 'R': [0., 0.5, 0.5], - 'X': [0., 0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'X'), - ('X', 'M'), - ('M', 'G'), - ('G', 'Z'), - ('Z', 'R'), - ('R', 'A'), - ('A', 'Z'), - ('X', 'R'), - ('M', 'A'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0.5, 0.5, 0.5], + 'M': [0.5, 0.5, 0.], + 'R': [0., 0.5, 0.5], + 'X': [0., 0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'X'), + ('X', 'M'), + ('M', 'G'), + ('G', 'Z'), + ('Z', 'R'), + ('R', 'A'), + ('A', 'Z'), + ('X', 'R'), + ('M', 'A'), + ] # body centered tetragonal elif bravais_info['index'] == 5: if bravais_info['variation'] == 'bct1': # Body centered tetragonal bct1 eta = bravais_info['extra']['eta'] - special_points = {'G': [0., 0., 0.], - 'M': [-0.5, 0.5, 0.5], - 'N': [0., 0.5, 0.], - 'P': [0.25, 0.25, 0.25], - 'X': [0., 0., 0.5], - 'Z': [eta, eta, -eta], - 'Z1': [-eta, 1. 
- eta, eta], - } - path = [('G', 'X'), - ('X', 'M'), - ('M', 'G'), - ('G', 'Z'), - ('Z', 'P'), - ('P', 'N'), - ('N', 'Z1'), - ('Z1', 'M'), - ('X', 'P'), - ] + special_points = { + 'G': [0., 0., 0.], + 'M': [-0.5, 0.5, 0.5], + 'N': [0., 0.5, 0.], + 'P': [0.25, 0.25, 0.25], + 'X': [0., 0., 0.5], + 'Z': [eta, eta, -eta], + 'Z1': [-eta, 1. - eta, eta], + } + path = [ + ('G', 'X'), + ('X', 'M'), + ('M', 'G'), + ('G', 'Z'), + ('Z', 'P'), + ('P', 'N'), + ('N', 'Z1'), + ('Z1', 'M'), + ('X', 'P'), + ] else: # bct2 # Body centered tetragonal bct2 @@ -1261,127 +1286,136 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, 'Y1': [0.5, 0.5, -csi], 'Z': [0.5, 0.5, -0.5], } - path = [('G', 'X'), - ('X', 'Y'), - ('Y', 'S'), - ('S', 'G'), - ('G', 'Z'), - ('Z', 'S1'), - ('S1', 'N'), - ('N', 'P'), - ('P', 'Y1'), - ('Y1', 'Z'), - ('X', 'P'), - ] + path = [ + ('G', 'X'), + ('X', 'Y'), + ('Y', 'S'), + ('S', 'G'), + ('G', 'Z'), + ('Z', 'S1'), + ('S1', 'N'), + ('N', 'P'), + ('P', 'Y1'), + ('Y1', 'Z'), + ('X', 'P'), + ] # orthorhombic elif bravais_info['index'] == 6: - special_points = {'G': [0., 0., 0.], - 'R': [0.5, 0.5, 0.5], - 'S': [0.5, 0.5, 0.], - 'T': [0., 0.5, 0.5], - 'U': [0.5, 0., 0.5], - 'X': [0.5, 0., 0.], - 'Y': [0., 0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'X'), - ('X', 'S'), - ('S', 'Y'), - ('Y', 'G'), - ('G', 'Z'), - ('Z', 'U'), - ('U', 'R'), - ('R', 'T'), - ('T', 'Z'), - ('Y', 'T'), - ('U', 'X'), - ('S', 'R'), - ] + special_points = { + 'G': [0., 0., 0.], + 'R': [0.5, 0.5, 0.5], + 'S': [0.5, 0.5, 0.], + 'T': [0., 0.5, 0.5], + 'U': [0.5, 0., 0.5], + 'X': [0.5, 0., 0.], + 'Y': [0., 0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'X'), + ('X', 'S'), + ('S', 'Y'), + ('Y', 'G'), + ('G', 'Z'), + ('Z', 'U'), + ('U', 'R'), + ('R', 'T'), + ('T', 'Z'), + ('Y', 'T'), + ('U', 'X'), + ('S', 'R'), + ] # face centered orthorhombic elif bravais_info['index'] == 7: if bravais_info['variation'] == 'orcf1': csi = bravais_info['extra']['csi'] eta = bravais_info['extra']['eta'] - special_points = {'G': [0., 0., 0.], - 'A': [0.5, 0.5 + csi, csi], - 'A1': [0.5, 0.5 - csi, 1. - csi], - 'L': [0.5, 0.5, 0.5], - 'T': [1., 0.5, 0.5], - 'X': [0., eta, eta], - 'X1': [1., 1. - eta, 1. - eta], - 'Y': [0.5, 0., 0.5], - 'Z': [0.5, 0.5, 0.], - } - path = [('G', 'Y'), - ('Y', 'T'), - ('T', 'Z'), - ('Z', 'G'), - ('G', 'X'), - ('X', 'A1'), - ('A1', 'Y'), - ('T', 'X1'), - ('X', 'A'), - ('A', 'Z'), - ('L', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0.5, 0.5 + csi, csi], + 'A1': [0.5, 0.5 - csi, 1. - csi], + 'L': [0.5, 0.5, 0.5], + 'T': [1., 0.5, 0.5], + 'X': [0., eta, eta], + 'X1': [1., 1. - eta, 1. - eta], + 'Y': [0.5, 0., 0.5], + 'Z': [0.5, 0.5, 0.], + } + path = [ + ('G', 'Y'), + ('Y', 'T'), + ('T', 'Z'), + ('Z', 'G'), + ('G', 'X'), + ('X', 'A1'), + ('A1', 'Y'), + ('T', 'X1'), + ('X', 'A'), + ('A', 'Z'), + ('L', 'G'), + ] elif bravais_info['variation'] == 'orcf2': eta = bravais_info['extra']['eta'] dlt = bravais_info['extra']['dlt'] phi = bravais_info['extra']['phi'] - special_points = {'G': [0., 0., 0.], - 'C': [0.5, 0.5 - eta, 1. - eta], - 'C1': [0.5, 0.5 + eta, eta], - 'D': [0.5 - dlt, 0.5, 1. - dlt], - 'D1': [0.5 + dlt, 0.5, dlt], - 'L': [0.5, 0.5, 0.5], - 'H': [1. 
- phi, 0.5 - phi, 0.5], - 'H1': [phi, 0.5 + phi, 0.5], - 'X': [0., 0.5, 0.5], - 'Y': [0.5, 0., 0.5], - 'Z': [0.5, 0.5, 0.], - } - path = [('G', 'Y'), - ('Y', 'C'), - ('C', 'D'), - ('D', 'X'), - ('X', 'G'), - ('G', 'Z'), - ('Z', 'D1'), - ('D1', 'H'), - ('H', 'C'), - ('C1', 'Z'), - ('X', 'H1'), - ('H', 'Y'), - ('L', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'C': [0.5, 0.5 - eta, 1. - eta], + 'C1': [0.5, 0.5 + eta, eta], + 'D': [0.5 - dlt, 0.5, 1. - dlt], + 'D1': [0.5 + dlt, 0.5, dlt], + 'L': [0.5, 0.5, 0.5], + 'H': [1. - phi, 0.5 - phi, 0.5], + 'H1': [phi, 0.5 + phi, 0.5], + 'X': [0., 0.5, 0.5], + 'Y': [0.5, 0., 0.5], + 'Z': [0.5, 0.5, 0.], + } + path = [ + ('G', 'Y'), + ('Y', 'C'), + ('C', 'D'), + ('D', 'X'), + ('X', 'G'), + ('G', 'Z'), + ('Z', 'D1'), + ('D1', 'H'), + ('H', 'C'), + ('C1', 'Z'), + ('X', 'H1'), + ('H', 'Y'), + ('L', 'G'), + ] else: csi = bravais_info['extra']['csi'] eta = bravais_info['extra']['eta'] - special_points = {'G': [0., 0., 0.], - 'A': [0.5, 0.5 + csi, csi], - 'A1': [0.5, 0.5 - csi, 1. - csi], - 'L': [0.5, 0.5, 0.5], - 'T': [1., 0.5, 0.5], - 'X': [0., eta, eta], - 'X1': [1., 1. - eta, 1. - eta], - 'Y': [0.5, 0., 0.5], - 'Z': [0.5, 0.5, 0.], - } - path = [('G', 'Y'), - ('Y', 'T'), - ('T', 'Z'), - ('Z', 'G'), - ('G', 'X'), - ('X', 'A1'), - ('A1', 'Y'), - ('X', 'A'), - ('A', 'Z'), - ('L', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0.5, 0.5 + csi, csi], + 'A1': [0.5, 0.5 - csi, 1. - csi], + 'L': [0.5, 0.5, 0.5], + 'T': [1., 0.5, 0.5], + 'X': [0., eta, eta], + 'X1': [1., 1. - eta, 1. - eta], + 'Y': [0.5, 0., 0.5], + 'Z': [0.5, 0.5, 0.], + } + path = [ + ('G', 'Y'), + ('Y', 'T'), + ('T', 'Z'), + ('Z', 'G'), + ('G', 'X'), + ('X', 'A1'), + ('A1', 'Y'), + ('X', 'A'), + ('A', 'Z'), + ('L', 'G'), + ] # Body centered orthorhombic elif bravais_info['index'] == 8: @@ -1389,168 +1423,180 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, dlt = bravais_info['extra']['dlt'] eta = bravais_info['extra']['eta'] mu = bravais_info['extra']['mu'] - special_points = {'G': [0., 0., 0.], - 'L': [-mu, mu, 0.5 - dlt], - 'L1': [mu, -mu, 0.5 + dlt], - 'L2': [0.5 - dlt, 0.5 + dlt, -mu], - 'R': [0., 0.5, 0.], - 'S': [0.5, 0., 0.], - 'T': [0., 0., 0.5], - 'W': [0.25, 0.25, 0.25], - 'X': [-csi, csi, csi], - 'X1': [csi, 1. - csi, -csi], - 'Y': [eta, -eta, eta], - 'Y1': [1. - eta, eta, -eta], - 'Z': [0.5, 0.5, -0.5], - } - path = [('G', 'X'), - ('X', 'L'), - ('L', 'T'), - ('T', 'W'), - ('W', 'R'), - ('R', 'X1'), - ('X1', 'Z'), - ('Z', 'G'), - ('G', 'Y'), - ('Y', 'S'), - ('S', 'W'), - ('L1', 'Y'), - ('Y1', 'Z'), - ] + special_points = { + 'G': [0., 0., 0.], + 'L': [-mu, mu, 0.5 - dlt], + 'L1': [mu, -mu, 0.5 + dlt], + 'L2': [0.5 - dlt, 0.5 + dlt, -mu], + 'R': [0., 0.5, 0.], + 'S': [0.5, 0., 0.], + 'T': [0., 0., 0.5], + 'W': [0.25, 0.25, 0.25], + 'X': [-csi, csi, csi], + 'X1': [csi, 1. - csi, -csi], + 'Y': [eta, -eta, eta], + 'Y1': [1. - eta, eta, -eta], + 'Z': [0.5, 0.5, -0.5], + } + path = [ + ('G', 'X'), + ('X', 'L'), + ('L', 'T'), + ('T', 'W'), + ('W', 'R'), + ('R', 'X1'), + ('X1', 'Z'), + ('Z', 'G'), + ('G', 'Y'), + ('Y', 'S'), + ('S', 'W'), + ('L1', 'Y'), + ('Y1', 'Z'), + ] # C-centered orthorhombic elif bravais_info['index'] == 9: csi = bravais_info['extra']['csi'] - special_points = {'G': [0., 0., 0.], - 'A': [csi, csi, 0.5], - 'A1': [-csi, 1. - csi, 0.5], - 'R': [0., 0.5, 0.5], - 'S': [0., 0.5, 0.], - 'T': [-0.5, 0.5, 0.5], - 'X': [csi, csi, 0.], - 'X1': [-csi, 1. 
- csi, 0.], - 'Y': [-0.5, 0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'X'), - ('X', 'S'), - ('S', 'R'), - ('R', 'A'), - ('A', 'Z'), - ('Z', 'G'), - ('G', 'Y'), - ('Y', 'X1'), - ('X1', 'A1'), - ('A1', 'T'), - ('T', 'Y'), - ('Z', 'T'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [csi, csi, 0.5], + 'A1': [-csi, 1. - csi, 0.5], + 'R': [0., 0.5, 0.5], + 'S': [0., 0.5, 0.], + 'T': [-0.5, 0.5, 0.5], + 'X': [csi, csi, 0.], + 'X1': [-csi, 1. - csi, 0.], + 'Y': [-0.5, 0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'X'), + ('X', 'S'), + ('S', 'R'), + ('R', 'A'), + ('A', 'Z'), + ('Z', 'G'), + ('G', 'Y'), + ('Y', 'X1'), + ('X1', 'A1'), + ('A1', 'T'), + ('T', 'Y'), + ('Z', 'T'), + ] # Hexagonal elif bravais_info['index'] == 10: - special_points = {'G': [0., 0., 0.], - 'A': [0., 0., 0.5], - 'H': [1. / 3., 1. / 3., 0.5], - 'K': [1. / 3., 1. / 3., 0.], - 'L': [0.5, 0., 0.5], - 'M': [0.5, 0., 0.], - } - path = [('G', 'M'), - ('M', 'K'), - ('K', 'G'), - ('G', 'A'), - ('A', 'L'), - ('L', 'H'), - ('H', 'A'), - ('L', 'M'), - ('K', 'H'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0., 0., 0.5], + 'H': [1. / 3., 1. / 3., 0.5], + 'K': [1. / 3., 1. / 3., 0.], + 'L': [0.5, 0., 0.5], + 'M': [0.5, 0., 0.], + } + path = [ + ('G', 'M'), + ('M', 'K'), + ('K', 'G'), + ('G', 'A'), + ('A', 'L'), + ('L', 'H'), + ('H', 'A'), + ('L', 'M'), + ('K', 'H'), + ] # rhombohedral elif bravais_info['index'] == 11: if bravais_info['variation'] == 'rhl1': eta = bravais_info['extra']['eta'] nu = bravais_info['extra']['nu'] - special_points = {'G': [0., 0., 0.], - 'B': [eta, 0.5, 1. - eta], - 'B1': [0.5, 1. - eta, eta - 1.], - 'F': [0.5, 0.5, 0.], - 'L': [0.5, 0., 0.], - 'L1': [0., 0., -0.5], - 'P': [eta, nu, nu], - 'P1': [1. - nu, 1. - nu, 1. - eta], - 'P2': [nu, nu, eta - 1.], - 'Q': [1. - nu, nu, 0.], - 'X': [nu, 0., -nu], - 'Z': [0.5, 0.5, 0.5], - } - path = [('G', 'L'), - ('L', 'B1'), - ('B', 'Z'), - ('Z', 'G'), - ('G', 'X'), - ('Q', 'F'), - ('F', 'P1'), - ('P1', 'Z'), - ('L', 'P'), - ] + special_points = { + 'G': [0., 0., 0.], + 'B': [eta, 0.5, 1. - eta], + 'B1': [0.5, 1. - eta, eta - 1.], + 'F': [0.5, 0.5, 0.], + 'L': [0.5, 0., 0.], + 'L1': [0., 0., -0.5], + 'P': [eta, nu, nu], + 'P1': [1. - nu, 1. - nu, 1. - eta], + 'P2': [nu, nu, eta - 1.], + 'Q': [1. - nu, nu, 0.], + 'X': [nu, 0., -nu], + 'Z': [0.5, 0.5, 0.5], + } + path = [ + ('G', 'L'), + ('L', 'B1'), + ('B', 'Z'), + ('Z', 'G'), + ('G', 'X'), + ('Q', 'F'), + ('F', 'P1'), + ('P1', 'Z'), + ('L', 'P'), + ] else: # Rhombohedral rhl2 eta = bravais_info['extra']['eta'] nu = bravais_info['extra']['nu'] - special_points = {'G': [0., 0., 0.], - 'F': [0.5, -0.5, 0.], - 'L': [0.5, 0., 0.], - 'P': [1. - nu, -nu, 1. - nu], - 'P1': [nu, nu - 1., nu - 1.], - 'Q': [eta, eta, eta], - 'Q1': [1. - eta, -eta, -eta], - 'Z': [0.5, -0.5, 0.5], - } - path = [('G', 'P'), - ('P', 'Z'), - ('Z', 'Q'), - ('Q', 'G'), - ('G', 'F'), - ('F', 'P1'), - ('P1', 'Q1'), - ('Q1', 'L'), - ('L', 'Z'), - ] + special_points = { + 'G': [0., 0., 0.], + 'F': [0.5, -0.5, 0.], + 'L': [0.5, 0., 0.], + 'P': [1. - nu, -nu, 1. - nu], + 'P1': [nu, nu - 1., nu - 1.], + 'Q': [eta, eta, eta], + 'Q1': [1. 
- eta, -eta, -eta], + 'Z': [0.5, -0.5, 0.5], + } + path = [ + ('G', 'P'), + ('P', 'Z'), + ('Z', 'Q'), + ('Q', 'G'), + ('G', 'F'), + ('F', 'P1'), + ('P1', 'Q1'), + ('Q1', 'L'), + ('L', 'Z'), + ] # monoclinic elif bravais_info['index'] == 12: eta = bravais_info['extra']['eta'] nu = bravais_info['extra']['nu'] - special_points = {'G': [0., 0., 0.], - 'A': [0.5, 0.5, 0.], - 'C': [0., 0.5, 0.5], - 'D': [0.5, 0., 0.5], - 'D1': [0.5, 0., -0.5], - 'E': [0.5, 0.5, 0.5], - 'H': [0., eta, 1. - nu], - 'H1': [0., 1. - eta, nu], - 'H2': [0., eta, -nu], - 'M': [0.5, eta, 1. - nu], - 'M1': [0.5, 1. - eta, nu], - 'M2': [0.5, eta, -nu], - 'X': [0., 0.5, 0.], - 'Y': [0., 0., 0.5], - 'Y1': [0., 0., -0.5], - 'Z': [0.5, 0., 0.], - } - path = [('G', 'Y'), - ('Y', 'H'), - ('H', 'C'), - ('C', 'E'), - ('E', 'M1'), - ('M1', 'A'), - ('A', 'X'), - ('X', 'H1'), - ('M', 'D'), - ('D', 'Z'), - ('Y', 'D'), - ] + special_points = { + 'G': [0., 0., 0.], + 'A': [0.5, 0.5, 0.], + 'C': [0., 0.5, 0.5], + 'D': [0.5, 0., 0.5], + 'D1': [0.5, 0., -0.5], + 'E': [0.5, 0.5, 0.5], + 'H': [0., eta, 1. - nu], + 'H1': [0., 1. - eta, nu], + 'H2': [0., eta, -nu], + 'M': [0.5, eta, 1. - nu], + 'M1': [0.5, 1. - eta, nu], + 'M2': [0.5, eta, -nu], + 'X': [0., 0.5, 0.], + 'Y': [0., 0., 0.5], + 'Y1': [0., 0., -0.5], + 'Z': [0.5, 0., 0.], + } + path = [ + ('G', 'Y'), + ('Y', 'H'), + ('H', 'C'), + ('C', 'E'), + ('E', 'M1'), + ('M1', 'A'), + ('A', 'X'), + ('X', 'H1'), + ('M', 'D'), + ('D', 'Z'), + ('Y', 'D'), + ] elif bravais_info['index'] == 13: if bravais_info['variation'] == 'mclc1': @@ -1558,68 +1604,72 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, eta = bravais_info['extra']['eta'] psi = bravais_info['extra']['psi'] phi = bravais_info['extra']['phi'] - special_points = {'G': [0., 0., 0.], - 'N': [0.5, 0., 0.], - 'N1': [0., -0.5, 0.], - 'F': [1. - csi, 1. - csi, 1. - eta], - 'F1': [csi, csi, eta], - 'F2': [csi, -csi, 1. - eta], - 'F3': [1. - csi, -csi, 1. - eta], - 'I': [phi, 1. - phi, 0.5], - 'I1': [1. - phi, phi - 1., 0.5], - 'L': [0.5, 0.5, 0.5], - 'M': [0.5, 0., 0.5], - 'X': [1. - psi, psi - 1., 0.], - 'X1': [psi, 1. - psi, 0.], - 'X2': [psi - 1., -psi, 0.], - 'Y': [0.5, 0.5, 0.], - 'Y1': [-0.5, -0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'L'), - ('L', 'I'), - ('I1', 'Z'), - ('Z', 'F1'), - ('Y', 'X1'), - ('X', 'G'), - ('G', 'N'), - ('M', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'N': [0.5, 0., 0.], + 'N1': [0., -0.5, 0.], + 'F': [1. - csi, 1. - csi, 1. - eta], + 'F1': [csi, csi, eta], + 'F2': [csi, -csi, 1. - eta], + 'F3': [1. - csi, -csi, 1. - eta], + 'I': [phi, 1. - phi, 0.5], + 'I1': [1. - phi, phi - 1., 0.5], + 'L': [0.5, 0.5, 0.5], + 'M': [0.5, 0., 0.5], + 'X': [1. - psi, psi - 1., 0.], + 'X1': [psi, 1. - psi, 0.], + 'X2': [psi - 1., -psi, 0.], + 'Y': [0.5, 0.5, 0.], + 'Y1': [-0.5, -0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'L'), + ('L', 'I'), + ('I1', 'Z'), + ('Z', 'F1'), + ('Y', 'X1'), + ('X', 'G'), + ('G', 'N'), + ('M', 'G'), + ] elif bravais_info['variation'] == 'mclc2': csi = bravais_info['extra']['csi'] eta = bravais_info['extra']['eta'] psi = bravais_info['extra']['psi'] phi = bravais_info['extra']['phi'] - special_points = {'G': [0., 0., 0.], - 'N': [0.5, 0., 0.], - 'N1': [0., -0.5, 0.], - 'F': [1. - csi, 1. - csi, 1. - eta], - 'F1': [csi, csi, eta], - 'F2': [csi, -csi, 1. - eta], - 'F3': [1. - csi, -csi, 1. - eta], - 'I': [phi, 1. - phi, 0.5], - 'I1': [1. 
- phi, phi - 1., 0.5], - 'L': [0.5, 0.5, 0.5], - 'M': [0.5, 0., 0.5], - 'X': [1. - psi, psi - 1., 0.], - 'X1': [psi, 1. - psi, 0.], - 'X2': [psi - 1., -psi, 0.], - 'Y': [0.5, 0.5, 0.], - 'Y1': [-0.5, -0.5, 0.], - 'Z': [0., 0., 0.5], - } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'L'), - ('L', 'I'), - ('I1', 'Z'), - ('Z', 'F1'), - ('N', 'G'), - ('G', 'M'), - ] + special_points = { + 'G': [0., 0., 0.], + 'N': [0.5, 0., 0.], + 'N1': [0., -0.5, 0.], + 'F': [1. - csi, 1. - csi, 1. - eta], + 'F1': [csi, csi, eta], + 'F2': [csi, -csi, 1. - eta], + 'F3': [1. - csi, -csi, 1. - eta], + 'I': [phi, 1. - phi, 0.5], + 'I1': [1. - phi, phi - 1., 0.5], + 'L': [0.5, 0.5, 0.5], + 'M': [0.5, 0., 0.5], + 'X': [1. - psi, psi - 1., 0.], + 'X1': [psi, 1. - psi, 0.], + 'X2': [psi - 1., -psi, 0.], + 'Y': [0.5, 0.5, 0.], + 'Y1': [-0.5, -0.5, 0.], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'L'), + ('L', 'I'), + ('I1', 'Z'), + ('Z', 'F1'), + ('N', 'G'), + ('G', 'M'), + ] elif bravais_info['variation'] == 'mclc3': mu = bravais_info['extra']['mu'] @@ -1647,18 +1697,19 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, 'Y3': [mu, mu - 1., dlt], 'Z': [0., 0., 0.5], } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'H'), - ('H', 'Z'), - ('Z', 'I'), - ('I', 'F1'), - ('H1', 'Y1'), - ('Y1', 'X'), - ('X', 'F'), - ('G', 'N'), - ('M', 'G'), - ] + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'H'), + ('H', 'Z'), + ('Z', 'I'), + ('I', 'F1'), + ('H1', 'Y1'), + ('Y1', 'X'), + ('X', 'F'), + ('G', 'N'), + ('M', 'G'), + ] elif bravais_info['variation'] == 'mclc4': mu = bravais_info['extra']['mu'] @@ -1667,35 +1718,37 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, eta = bravais_info['extra']['eta'] phi = bravais_info['extra']['phi'] psi = bravais_info['extra']['psi'] - special_points = {'G': [0., 0., 0.], - 'F': [1. - phi, 1 - phi, 1. - psi], - 'F1': [phi, phi - 1., psi], - 'F2': [1. - phi, -phi, 1. - psi], - 'H': [csi, csi, eta], - 'H1': [1. - csi, -csi, 1. - eta], - 'H2': [-csi, -csi, 1. - eta], - 'I': [0.5, -0.5, 0.5], - 'M': [0.5, 0., 0.5], - 'N': [0.5, 0., 0.], - 'N1': [0., -0.5, 0.], - 'X': [0.5, -0.5, 0.], - 'Y': [mu, mu, dlt], - 'Y1': [1. - mu, -mu, -dlt], - 'Y2': [-mu, -mu, -dlt], - 'Y3': [mu, mu - 1., dlt], - 'Z': [0., 0., 0.5], - } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'H'), - ('H', 'Z'), - ('Z', 'I'), - ('H1', 'Y1'), - ('Y1', 'X'), - ('X', 'G'), - ('G', 'N'), - ('M', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'F': [1. - phi, 1 - phi, 1. - psi], + 'F1': [phi, phi - 1., psi], + 'F2': [1. - phi, -phi, 1. - psi], + 'H': [csi, csi, eta], + 'H1': [1. - csi, -csi, 1. - eta], + 'H2': [-csi, -csi, 1. - eta], + 'I': [0.5, -0.5, 0.5], + 'M': [0.5, 0., 0.5], + 'N': [0.5, 0., 0.], + 'N1': [0., -0.5, 0.], + 'X': [0.5, -0.5, 0.], + 'Y': [mu, mu, dlt], + 'Y1': [1. 
- mu, -mu, -dlt], + 'Y2': [-mu, -mu, -dlt], + 'Y3': [mu, mu - 1., dlt], + 'Z': [0., 0., 0.5], + } + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'H'), + ('H', 'Z'), + ('Z', 'I'), + ('H1', 'Y1'), + ('Y1', 'X'), + ('X', 'G'), + ('G', 'N'), + ('M', 'G'), + ] else: csi = bravais_info['extra']['csi'] @@ -1726,85 +1779,94 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, 'Y3': [mu, mu - 1., dlt], 'Z': [0., 0., 0.5], } - path = [('G', 'Y'), - ('Y', 'F'), - ('F', 'L'), - ('L', 'I'), - ('I1', 'Z'), - ('Z', 'H'), - ('H', 'F1'), - ('H1', 'Y1'), - ('Y1', 'X'), - ('X', 'G'), - ('G', 'N'), - ('M', 'G'), - ] + path = [ + ('G', 'Y'), + ('Y', 'F'), + ('F', 'L'), + ('L', 'I'), + ('I1', 'Z'), + ('Z', 'H'), + ('H', 'F1'), + ('H1', 'Y1'), + ('Y1', 'X'), + ('X', 'G'), + ('G', 'N'), + ('M', 'G'), + ] # triclinic elif bravais_info['index'] == 14: if bravais_info['variation'] == 'tri1a' or bravais_info['variation'] == 'tri2a': - special_points = {'G': [0.0, 0.0, 0.0], - 'L': [0.5, 0.5, 0.0], - 'M': [0.0, 0.5, 0.5], - 'N': [0.5, 0.0, 0.5], - 'R': [0.5, 0.5, 0.5], - 'X': [0.5, 0.0, 0.0], - 'Y': [0.0, 0.5, 0.0], - 'Z': [0.0, 0.0, 0.5], - } - path = [('X', 'G'), - ('G', 'Y'), - ('L', 'G'), - ('G', 'Z'), - ('N', 'G'), - ('G', 'M'), - ('R', 'G'), - ] + special_points = { + 'G': [0.0, 0.0, 0.0], + 'L': [0.5, 0.5, 0.0], + 'M': [0.0, 0.5, 0.5], + 'N': [0.5, 0.0, 0.5], + 'R': [0.5, 0.5, 0.5], + 'X': [0.5, 0.0, 0.0], + 'Y': [0.0, 0.5, 0.0], + 'Z': [0.0, 0.0, 0.5], + } + path = [ + ('X', 'G'), + ('G', 'Y'), + ('L', 'G'), + ('G', 'Z'), + ('N', 'G'), + ('G', 'M'), + ('R', 'G'), + ] else: - special_points = {'G': [0.0, 0.0, 0.0], - 'L': [0.5, -0.5, 0.0], - 'M': [0.0, 0.0, 0.5], - 'N': [-0.5, -0.5, 0.5], - 'R': [0.0, -0.5, 0.5], - 'X': [0.0, -0.5, 0.0], - 'Y': [0.5, 0.0, 0.0], - 'Z': [-0.5, 0.0, 0.5], - } - path = [('X', 'G'), - ('G', 'Y'), - ('L', 'G'), - ('G', 'Z'), - ('N', 'G'), - ('G', 'M'), - ('R', 'G'), - ] + special_points = { + 'G': [0.0, 0.0, 0.0], + 'L': [0.5, -0.5, 0.0], + 'M': [0.0, 0.0, 0.5], + 'N': [-0.5, -0.5, 0.5], + 'R': [0.0, -0.5, 0.5], + 'X': [0.0, -0.5, 0.0], + 'Y': [0.5, 0.0, 0.0], + 'Z': [-0.5, 0.0, 0.5], + } + path = [ + ('X', 'G'), + ('G', 'Y'), + ('L', 'G'), + ('G', 'Z'), + ('N', 'G'), + ('G', 'M'), + ('R', 'G'), + ] elif dimension == 2: # 2D case: 5 Bravais lattices if bravais_info['index'] == 1: # square - special_points = {'G': [0., 0., 0.], - 'M': [0.5, 0.5, 0.], - 'X': [0.5, 0., 0.], - } - path = [('G', 'X'), - ('X', 'M'), - ('M', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'M': [0.5, 0.5, 0.], + 'X': [0.5, 0., 0.], + } + path = [ + ('G', 'X'), + ('X', 'M'), + ('M', 'G'), + ] elif bravais_info['index'] == 2: # (primitive) rectangular - special_points = {'G': [0., 0., 0.], - 'X': [0.5, 0., 0.], - 'Y': [0., 0.5, 0.], - 'S': [0.5, 0.5, 0.], - } - path = [('G', 'X'), - ('X', 'S'), - ('S', 'Y'), - ('Y', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'X': [0.5, 0., 0.], + 'Y': [0., 0.5, 0.], + 'S': [0.5, 0.5, 0.], + } + path = [ + ('G', 'X'), + ('X', 'S'), + ('S', 'Y'), + ('Y', 'G'), + ] elif bravais_info['index'] == 3: # centered rectangular (rhombic) @@ -1815,57 +1877,67 @@ def get_kpoints_path(cell, pbc=None, cartesian=False, # coordinates (primitive reciprocal cell) as for the rest. # Ramirez & Bohn gave them initially in (s1=b1+b2, s2=-b1+b2) # coordinates, i.e. using the conventional reciprocal cell. - special_points = {'G': [0., 0., 0.], - 'X': [0.5, 0.5, 0.], - 'Y1': [0.25, 0.75, 0.], - 'Y': [-0.25, 0.25, 0.], # typo in p. 
404 of Ramirez & Bohm (should be Y=(0,1/4)) - 'C': [0., 0.5, 0.], - } - path = [('Y1', 'X'), - ('X', 'G'), - ('G', 'Y'), - ('Y', 'C'), - ] + special_points = { + 'G': [0., 0., 0.], + 'X': [0.5, 0.5, 0.], + 'Y1': [0.25, 0.75, 0.], + 'Y': [-0.25, 0.25, 0.], # typo in p. 404 of Ramirez & Bohm (should be Y=(0,1/4)) + 'C': [0., 0.5, 0.], + } + path = [ + ('Y1', 'X'), + ('X', 'G'), + ('G', 'Y'), + ('Y', 'C'), + ] elif bravais_info['index'] == 4: # hexagonal - special_points = {'G': [0., 0., 0.], - 'M': [0.5, 0., 0.], - 'K': [1. / 3., 1. / 3., 0.], - } - path = [('G', 'M'), - ('M', 'K'), - ('K', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'M': [0.5, 0., 0.], + 'K': [1. / 3., 1. / 3., 0.], + } + path = [ + ('G', 'M'), + ('M', 'K'), + ('K', 'G'), + ] elif bravais_info['index'] == 5: # oblique # NOTE: only end-points are high-symmetry points (not the path # in-between) - special_points = {'G': [0., 0., 0.], - 'X': [0.5, 0., 0.], - 'Y': [0., 0.5, 0.], - 'A': [0.5, 0.5, 0.], - } - path = [('X', 'G'), - ('G', 'Y'), - ('A', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + 'X': [0.5, 0., 0.], + 'Y': [0., 0.5, 0.], + 'A': [0.5, 0.5, 0.], + } + path = [ + ('X', 'G'), + ('G', 'Y'), + ('A', 'G'), + ] elif dimension == 1: # 1D case: 1 Bravais lattice - special_points = {'G': [0., 0., 0.], - 'X': [0.5, 0., 0.], - } - path = [('G', 'X'), - ] + special_points = { + 'G': [0., 0., 0.], + 'X': [0.5, 0., 0.], + } + path = [ + ('G', 'X'), + ] elif dimension == 0: # 0D case: 1 Bravais lattice, only Gamma point, no path - special_points = {'G': [0., 0., 0.], - } - path = [('G', 'G'), - ] + special_points = { + 'G': [0., 0., 0.], + } + path = [ + ('G', 'G'), + ] permutation = bravais_info['permutation'] @@ -1873,23 +1945,19 @@ def permute(x, permutation): # return new_x such that new_x[i]=x[permutation[i]] return [x[int(p)] for p in permutation] - def invpermute(permutation): - # return the inverse of permutation - return [permutation.index(i) for i in range(3)] - the_special_points = {} - for k in special_points.keys(): + for key in special_points: # NOTE: this originally returned the inverse of the permutation, but was later changed to permutation - the_special_points[k] = permute(special_points[k], permutation) + the_special_points[key] = permute(special_points[key], permutation) # output crystal or cartesian if cartesian: the_abs_special_points = {} - for k in the_special_points.keys(): - the_abs_special_points[k] = change_reference( - reciprocal_cell, numpy.array(the_special_points[k]), to_cartesian=True + for key in the_special_points: + the_abs_special_points[key] = change_reference( + reciprocal_cell, numpy.array(the_special_points[key]), to_cartesian=True ) return the_abs_special_points, path, bravais_info - else: - return the_special_points, path, bravais_info + + return the_special_points, path, bravais_info diff --git a/aiida/tools/data/array/kpoints/seekpath.py b/aiida/tools/data/array/kpoints/seekpath.py index 176852fef0..9aa9dc36bb 100644 --- a/aiida/tools/data/array/kpoints/seekpath.py +++ b/aiida/tools/data/array/kpoints/seekpath.py @@ -7,22 +7,12 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tool to automatically determine k-points for a given structure using SeeK-path.""" +import seekpath from aiida.orm import KpointsData, Dict -__all__ = ('check_seekpath_is_installed', 'get_explicit_kpoints_path', 'get_kpoints_path') - - 
-def check_seekpath_is_installed(): - """ - Tries to import the Seekpath module. Raise ImportError if it cannot be imported - - :raises: ImportError - """ - try: - import seekpath - except ImportError: - raise ImportError("Seekpath is not installed, please install with 'pip install seekpath'") +__all__ = ('get_explicit_kpoints_path', 'get_kpoints_path') def get_explicit_kpoints_path(structure, parameters): @@ -65,11 +55,9 @@ def get_explicit_kpoints_path(structure, parameters): - ``conv_structure``: A StructureData with the primitive structure """ + # pylint: disable=too-many-locals from aiida.tools.data.structure import spglib_tuple_to_structure, structure_to_spglib_tuple - check_seekpath_is_installed() - import seekpath - structure_tuple, kind_info, kinds = structure_to_spglib_tuple(structure) result = {} @@ -92,7 +80,6 @@ def get_explicit_kpoints_path(structure, parameters): # Remove reciprocal_primitive_lattice, recalculated by kpoints class rawdict.pop('reciprocal_primitive_lattice') kpoints_abs = rawdict.pop('explicit_kpoints_abs') - kpoints_rel = rawdict.pop('explicit_kpoints_rel') kpoints_labels = rawdict.pop('explicit_kpoints_labels') # set_kpoints expects labels like [[0,'X'],[34,'L'],...], so generate it here skipping empty labels @@ -146,9 +133,6 @@ def get_kpoints_path(structure, parameters): """ from aiida.tools.data.structure import spglib_tuple_to_structure, structure_to_spglib_tuple - check_seekpath_is_installed() - import seekpath - structure_tuple, kind_info, kinds = structure_to_spglib_tuple(structure) result = {} diff --git a/aiida/tools/data/cif.py b/aiida/tools/data/cif.py index 813ffa7fcc..e8e613ff15 100644 --- a/aiida/tools/data/cif.py +++ b/aiida/tools/data/cif.py @@ -203,8 +203,7 @@ def refine_inline(node): # Summary formula has to be calculated from non-reduced set of atoms. cif.values[name]['_chemical_formula_sum'] = \ - StructureData(ase=original_atoms).get_formula(mode='hill', - separator=' ') + StructureData(ase=original_atoms).get_formula(mode='hill', separator=' ') # If the number of reduced atoms multiplies the number of non-reduced # atoms, the new Z value can be calculated. 
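An aside on the seekpath helpers above: with `import seekpath` moved to module level and `check_seekpath_is_installed()` removed, a missing seekpath package now surfaces as an ImportError when `aiida.tools.data.array.kpoints.seekpath` is first imported, rather than lazily inside each function. A minimal usage sketch of the functions that remain in `__all__`; the example structure, the `reference_distance` key and the returned dictionary key are assumptions for illustration, not taken from this hunk:

from aiida.orm import Dict, StructureData
from aiida.tools.data.array.kpoints.seekpath import get_explicit_kpoints_path

# Hypothetical primitive fcc silicon cell, just to have a valid StructureData input.
structure = StructureData(cell=[[2.715, 2.715, 0.], [2.715, 0., 2.715], [0., 2.715, 2.715]])
structure.append_atom(position=(0., 0., 0.), symbols='Si')

# 'reference_distance' is an assumed seekpath option forwarded via the parameters Dict.
parameters = Dict(dict={'reference_distance': 0.025})
result = get_explicit_kpoints_path(structure, parameters)
kpoints = result['explicit_kpoints']  # assumed key: KpointsData carrying the explicit path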
diff --git a/aiida/tools/data/structure/__init__.py b/aiida/tools/data/structure/__init__.py index e75e893825..d308d04373 100644 --- a/aiida/tools/data/structure/__init__.py +++ b/aiida/tools/data/structure/__init__.py @@ -165,7 +165,7 @@ def spglib_tuple_to_structure(structure_tuple, kind_info=None, kinds=None): # p structure.append_kind(k) abs_pos = np.dot(rel_pos, cell) if len(abs_pos) != len(site_kinds): - raise ValueError('The length of the positions array is different from the ' 'length of the element numbers') + raise ValueError('The length of the positions array is different from the length of the element numbers') for kind, pos in zip(site_kinds, abs_pos): structure.append_site(Site(kind_name=kind.name, position=pos)) diff --git a/aiida/tools/dbimporters/__init__.py b/aiida/tools/dbimporters/__init__.py index 7ce146e69d..7cef45c164 100644 --- a/aiida/tools/dbimporters/__init__.py +++ b/aiida/tools/dbimporters/__init__.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Module for plugins to import data from external databases into an AiiDA database.""" from .baseclasses import DbImporter __all__ = ('DbImporter',) diff --git a/aiida/tools/dbimporters/baseclasses.py b/aiida/tools/dbimporters/baseclasses.py index 28f6a52d17..64db13dac1 100644 --- a/aiida/tools/dbimporters/baseclasses.py +++ b/aiida/tools/dbimporters/baseclasses.py @@ -7,17 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Base class implementation for an external database importer.""" +import io class DbImporter: - """ - Base class for database importers. - """ + """Base class implementation for an external database importer.""" def query(self, **kwargs): - """ - Method to query the database. + """Method to query the database. :param id: database-specific entry identificator :param element: element name from periodic table of elements @@ -86,14 +84,14 @@ class DbSearchResults: ``__getitem__``. """ + _return_class = None + def __init__(self, results): self._results = results self._entries = {} class DbSearchResultsIterator: - """ - Iterator for search results - """ + """Iterator for search results.""" def __init__(self, results, increment=1): self._results = results @@ -101,12 +99,13 @@ def __init__(self, results, increment=1): self._increment = increment def __next__(self): + """Return the next entry in the iterator.""" pos = self._position - if pos >= 0 and pos < len(self._results): + if pos >= 0 and pos < len(self._results): # pylint: disable=chained-comparison self._position = self._position + self._increment return self._results[pos] - else: - raise StopIteration() + + raise StopIteration() def __iter__(self): """ @@ -141,7 +140,7 @@ def next(self): """ raise NotImplementedError('not implemented in base class') - def at(self, position): + def at(self, position): # pylint: disable=invalid-name """ Returns ``position``-th result as :py:class:`aiida.tools.dbimporters.baseclasses.DbEntry`. 
@@ -155,7 +154,7 @@ def at(self, position): if position not in self._entries: source_dict = self._get_source_dict(self._results[position]) url = self._get_url(self._results[position]) - self._entries[position] = self._return_class(url, **source_dict) + self._entries[position] = self._return_class(url, **source_dict) # pylint: disable=not-callable return self._entries[position] def _get_source_dict(self, result_dict): @@ -182,8 +181,7 @@ class DbEntry: """ _license = None - def __init__(self, db_name=None, db_uri=None, id=None, - version=None, extras={}, uri=None): + def __init__(self, db_name=None, db_uri=None, id=None, version=None, extras=None, uri=None): # pylint: disable=too-many-arguments,redefined-builtin """ Sets the basic parameters for the database entry: @@ -200,7 +198,7 @@ def __init__(self, db_name=None, db_uri=None, id=None, 'db_uri': db_uri, 'id': id, 'version': version, - 'extras': extras, + 'extras': extras or {}, 'uri': uri, 'source_md5': None, 'license': self._license, @@ -208,11 +206,13 @@ def __init__(self, db_name=None, db_uri=None, id=None, self._contents = None def __repr__(self): - return '{}({})'.format(self.__class__.__name__, - ','.join(['{}={}'.format(k, '"{}"'.format(self.source[k]) - if issubclass(self.source[k].__class__, str) - else self.source[k]) - for k in sorted(self.source.keys())])) + return '{}({})'.format( + self.__class__.__name__, ','.join([ + '{}={}'.format( + k, '"{}"'.format(self.source[k]) if issubclass(self.source[k].__class__, str) else self.source[k] + ) for k in sorted(self.source.keys()) + ]) + ) @property def contents(self): @@ -272,7 +272,7 @@ def get_ase_structure(self): :py:class:`aiida.orm.nodes.data.cif.CifData`. """ from aiida.orm import CifData - return CifData.read_cif(StringIO(self.cif)) + return CifData.read_cif(io.StringIO(self.cif)) def get_cif_node(self, store=False, parse_policy='lazy'): """ @@ -286,10 +286,10 @@ def get_cif_node(self, store=False, parse_policy='lazy'): cifnode = None - with tempfile.NamedTemporaryFile(mode='w+') as f: - f.write(self.cif) - f.flush() - cifnode = CifData(file=f.name, source=self.source, parse_policy=parse_policy) + with tempfile.NamedTemporaryFile(mode='w+') as handle: + handle.write(self.cif) + handle.flush() + cifnode = CifData(file=handle.name, source=self.source, parse_policy=parse_policy) # Maintaining backwards-compatibility. Parameter 'store' should # be removed in the future, as the new node can be stored later. @@ -333,10 +333,10 @@ def get_upf_node(self, store=False): # Prefixing with an ID in order to start file name with the name # of the described element. - with tempfile.NamedTemporaryFile(mode='w+', prefix=self.source['id']) as f: - f.write(self.contents) - f.flush() - upfnode = UpfData(file=f.name, source=self.source) + with tempfile.NamedTemporaryFile(mode='w+', prefix=self.source['id']) as handle: + handle.write(self.contents) + handle.flush() + upfnode = UpfData(file=handle.name, source=self.source) # Maintaining backwards-compatibility. Parameter 'store' should # be removed in the future, as the new node can be stored later. 
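A short illustration of why the `extras={}` default on `DbEntry.__init__` is replaced by `extras=None` plus `extras or {}`: a mutable default argument is created once at function definition time and shared by every call, so entries constructed without an explicit `extras` would silently see each other's mutations. A minimal sketch with a toy class (not the real DbEntry):

class SharedDefault:
    def __init__(self, extras={}):  # anti-pattern: one dict object shared across all instances
        self.extras = extras

first, second = SharedDefault(), SharedDefault()
first.extras['source'] = 'cod'
print(second.extras)  # {'source': 'cod'} -- the second instance sees the first one's mutation

class FreshDefault:
    def __init__(self, extras=None):
        self.extras = extras or {}  # a new dict per instance, mirroring the patched DbEntry

first, second = FreshDefault(), FreshDefault()
first.extras['source'] = 'cod'
print(second.extras)  # {} -- instances no longer share state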
diff --git a/aiida/tools/dbimporters/plugins/__init__.py b/aiida/tools/dbimporters/plugins/__init__.py index 2776a55f97..a98fa2e761 100644 --- a/aiida/tools/dbimporters/plugins/__init__.py +++ b/aiida/tools/dbimporters/plugins/__init__.py @@ -7,3 +7,4 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Module for plugins to import data from external databases into an AiiDA database.""" diff --git a/aiida/tools/dbimporters/plugins/cod.py b/aiida/tools/dbimporters/plugins/cod.py index f549bfb54a..12f35e0c42 100644 --- a/aiida/tools/dbimporters/plugins/cod.py +++ b/aiida/tools/dbimporters/plugins/cod.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - CifEntry) +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the COD database.""" +from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, CifEntry) class CodDbImporter(DbImporter): @@ -23,10 +21,9 @@ def _int_clause(self, key, alias, values): """ Returns SQL query predicate for querying integer fields. """ - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only integers and strings are accepted") return key + ' IN (' + ', '.join(str(int(i)) for i in values) + ')' def _str_exact_clause(self, key, alias, values): @@ -34,13 +31,12 @@ def _str_exact_clause(self, key, alias, values): Returns SQL query predicate for querying string fields. """ clause_parts = [] - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") - if isinstance(e, int): - e = str(e) - clause_parts.append("'" + e + "'") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only integers and strings are accepted") + if isinstance(value, int): + value = str(value) + clause_parts.append("'" + value + "'") return key + ' IN (' + ', '.join(clause_parts) + ')' def _str_exact_or_none_clause(self, key, alias, values): @@ -50,64 +46,58 @@ def _str_exact_or_none_clause(self, key, alias, values): """ if None in values: values_now = [] - for e in values: - if e is not None: - values_now.append(e) - if len(values_now): + for value in values: + if value is not None: + values_now.append(value) + if values_now: clause = self._str_exact_clause(key, alias, values_now) return '{} OR {} IS NULL'.format(clause, key) - else: - return '{} IS NULL'.format(key) - else: - return self._str_exact_clause(key, alias, values) + + return '{} IS NULL'.format(key) + + return self._str_exact_clause(key, alias, values) def _formula_clause(self, key, alias, values): """ Returns SQL query predicate for querying formula fields. 
""" - for e in values: - if not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") - return self._str_exact_clause(key, \ - alias, \ - ['- {} -'.format(f) for f in values]) + for value in values: + if not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only strings are accepted") + return self._str_exact_clause(key, alias, ['- {} -'.format(f) for f in values]) def _str_fuzzy_clause(self, key, alias, values): """ Returns SQL query predicate for fuzzy querying of string fields. """ clause_parts = [] - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") - if isinstance(e, int): - e = str(e) - clause_parts.append(key + " LIKE '%" + e + "%'") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only integers and strings are accepted") + if isinstance(value, int): + value = str(value) + clause_parts.append(key + " LIKE '%" + value + "%'") return ' OR '.join(clause_parts) - def _composition_clause(self, key, alias, values): + def _composition_clause(self, _, alias, values): """ Returns SQL query predicate for querying elements in formula fields. """ clause_parts = [] - for e in values: - if not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") - clause_parts.append("formula REGEXP ' " + e + "[0-9 ]'") + for value in values: + if not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + "' only strings are accepted") + clause_parts.append("formula REGEXP ' " + value + "[0-9 ]'") return ' AND '.join(clause_parts) def _double_clause(self, key, alias, values, precision): """ Returns SQL query predicate for querying double-valued fields. 
""" - for e in values: - if not isinstance(e, int) and not isinstance(e, float): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and floats are accepted") - return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d-precision, d+precision) for d in values) + for value in values: + if not isinstance(value, int) and not isinstance(value, float): + raise ValueError("incorrect value for keyword '" + alias + "' only integers and floats are accepted") + return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d - precision, d + precision) for d in values) length_precision = 0.001 angle_precision = 0.001 @@ -145,46 +135,43 @@ def _pressure_clause(self, key, alias, values): """ return self._double_clause(key, alias, values, self.pressure_precision) - _keywords = {'id': ['file', _int_clause], - 'element': ['element', _composition_clause], - 'number_of_elements': ['nel', _int_clause], - 'mineral_name': ['mineral', _str_fuzzy_clause], - 'chemical_name': ['chemname', _str_fuzzy_clause], - 'formula': ['formula', _formula_clause], - 'volume': ['vol', _volume_clause], - 'spacegroup': ['sg', _str_exact_clause], - 'spacegroup_hall': ['sgHall', _str_exact_clause], - 'a': ['a', _length_clause], - 'b': ['b', _length_clause], - 'c': ['c', _length_clause], - 'alpha': ['alpha', _angle_clause], - 'beta': ['beta', _angle_clause], - 'gamma': ['gamma', _angle_clause], - 'z': ['Z', _int_clause], - 'measurement_temp': ['celltemp', _temperature_clause], - 'diffraction_temp': ['diffrtemp', _temperature_clause], - 'measurement_pressure': - ['cellpressure', _pressure_clause], - 'diffraction_pressure': - ['diffrpressure', _pressure_clause], - 'authors': ['authors', _str_fuzzy_clause], - 'journal': ['journal', _str_fuzzy_clause], - 'title': ['title', _str_fuzzy_clause], - 'year': ['year', _int_clause], - 'journal_volume': ['volume', _int_clause], - 'journal_issue': ['issue', _str_exact_clause], - 'first_page': ['firstpage', _str_exact_clause], - 'last_page': ['lastpage', _str_exact_clause], - 'doi': ['doi', _str_exact_clause], - 'determination_method': ['method', _str_exact_or_none_clause]} + _keywords = { + 'id': ['file', _int_clause], + 'element': ['element', _composition_clause], + 'number_of_elements': ['nel', _int_clause], + 'mineral_name': ['mineral', _str_fuzzy_clause], + 'chemical_name': ['chemname', _str_fuzzy_clause], + 'formula': ['formula', _formula_clause], + 'volume': ['vol', _volume_clause], + 'spacegroup': ['sg', _str_exact_clause], + 'spacegroup_hall': ['sgHall', _str_exact_clause], + 'a': ['a', _length_clause], + 'b': ['b', _length_clause], + 'c': ['c', _length_clause], + 'alpha': ['alpha', _angle_clause], + 'beta': ['beta', _angle_clause], + 'gamma': ['gamma', _angle_clause], + 'z': ['Z', _int_clause], + 'measurement_temp': ['celltemp', _temperature_clause], + 'diffraction_temp': ['diffrtemp', _temperature_clause], + 'measurement_pressure': ['cellpressure', _pressure_clause], + 'diffraction_pressure': ['diffrpressure', _pressure_clause], + 'authors': ['authors', _str_fuzzy_clause], + 'journal': ['journal', _str_fuzzy_clause], + 'title': ['title', _str_fuzzy_clause], + 'year': ['year', _int_clause], + 'journal_volume': ['volume', _int_clause], + 'journal_issue': ['issue', _str_exact_clause], + 'first_page': ['firstpage', _str_exact_clause], + 'last_page': ['lastpage', _str_exact_clause], + 'doi': ['doi', _str_exact_clause], + 'determination_method': ['method', _str_exact_or_none_clause] + } def __init__(self, **kwargs): self._db = None self._cursor = None - 
self._db_parameters = {'host': 'www.crystallography.net', - 'user': 'cod_reader', - 'passwd': '', - 'db': 'cod'} + self._db_parameters = {'host': 'www.crystallography.net', 'user': 'cod_reader', 'passwd': '', 'db': 'cod'} self.setup_db(**kwargs) def query_sql(self, **kwargs): @@ -200,19 +187,12 @@ def query_sql(self, **kwargs): values = kwargs.pop(key) if not isinstance(values, list): values = [values] - sql_parts.append( \ - '(' + self._keywords[key][1](self, \ - self._keywords[key][0], \ - key, \ - values) + \ - ')') - if len(kwargs.keys()) > 0: - raise NotImplementedError( \ - "search keyword(s) '" + \ - "', '".join(kwargs.keys()) + "' " + \ - 'is(are) not implemented for COD') - return 'SELECT file, svnrevision FROM data WHERE ' + \ - ' AND '.join(sql_parts) + sql_parts.append('(' + self._keywords[key][1](self, self._keywords[key][0], key, values) + ')') + + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) + + return 'SELECT file, svnrevision FROM data WHERE ' + ' AND '.join(sql_parts) def query(self, **kwargs): """ @@ -229,8 +209,7 @@ def query(self, **kwargs): self._cursor.execute(query_statement) self._db.commit() for row in self._cursor.fetchall(): - results.append({'id': str(row[0]), - 'svnrevision': str(row[1])}) + results.append({'id': str(row[0]), 'svnrevision': str(row[1])}) finally: self._disconnect_db() @@ -240,15 +219,14 @@ def setup_db(self, **kwargs): """ Changes the database connection details. """ - for key in self._db_parameters.keys(): + for key in self._db_parameters: if key in kwargs.keys(): self._db_parameters[key] = kwargs.pop(key) if len(kwargs.keys()) > 0: - raise NotImplementedError( \ - "unknown database connection parameter(s): '" + \ - "', '".join(kwargs.keys()) + \ - "', available parameters: '" + \ - "', '".join(self._db_parameters.keys()) + "'") + raise NotImplementedError( + "unknown database connection parameter(s): '" + "', '".join(kwargs.keys()) + + "', available parameters: '" + "', '".join(self._db_parameters.keys()) + "'" + ) def get_supported_keywords(self): """ @@ -267,10 +245,12 @@ def _connect_db(self): except ImportError: import pymysql as MySQLdb - self._db = MySQLdb.connect(host=self._db_parameters['host'], - user=self._db_parameters['user'], - passwd=self._db_parameters['passwd'], - db=self._db_parameters['db']) + self._db = MySQLdb.connect( + host=self._db_parameters['host'], + user=self._db_parameters['user'], + passwd=self._db_parameters['passwd'], + db=self._db_parameters['db'] + ) self._cursor = self._db.cursor() def _disconnect_db(self): @@ -280,7 +260,7 @@ def _disconnect_db(self): self._db.close() -class CodSearchResults(DbSearchResults): +class CodSearchResults(DbSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on COD. """ @@ -316,22 +296,22 @@ def _get_url(self, result_dict): if 'svnrevision' in result_dict and \ result_dict['svnrevision'] is not None: return '{}@{}'.format(url, result_dict['svnrevision']) - else: - return url + + return url -class CodEntry(CifEntry): +class CodEntry(CifEntry): # pylint: disable=abstract-method """ Represents an entry from COD. 
""" _license = 'CC0' - def __init__(self, uri, db_name='Crystallography Open Database', - db_uri='http://www.crystallography.net/cod', **kwargs): + def __init__( + self, uri, db_name='Crystallography Open Database', db_uri='http://www.crystallography.net/cod', **kwargs + ): """ Creates an instance of :py:class:`aiida.tools.dbimporters.plugins.cod.CodEntry`, related to the supplied URI. """ - super().__init__(db_name=db_name, db_uri=db_uri, - uri=uri, **kwargs) + super().__init__(db_name=db_name, db_uri=db_uri, uri=uri, **kwargs) diff --git a/aiida/tools/dbimporters/plugins/icsd.py b/aiida/tools/dbimporters/plugins/icsd.py index 6367a522f5..2584a41081 100644 --- a/aiida/tools/dbimporters/plugins/icsd.py +++ b/aiida/tools/dbimporters/plugins/icsd.py @@ -7,29 +7,23 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the CISD database.""" +import io - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - CifEntry) +from aiida.tools.dbimporters.baseclasses import DbImporter, DbSearchResults, CifEntry class IcsdImporterExp(Exception): - pass + """Base class for ICSD exceptions.""" class CifFileErrorExp(IcsdImporterExp): - """ - Raised when the author loop is missing in a CIF file. - """ - pass + """Raised when the author loop is missing in a CIF file.""" class NoResultsWebExp(IcsdImporterExp): - """ - Raised when a webpage query returns no results. - """ - pass + """Raised when a webpage query returns no results.""" class IcsdDbImporter(DbImporter): @@ -82,15 +76,16 @@ class IcsdDbImporter(DbImporter): def __init__(self, **kwargs): - self.db_parameters = {'server': '', - 'urladd': 'index.php?', - 'querydb': True, - 'dl_db': 'icsd', - 'host': '', - 'user': 'dba', - 'passwd': 'sql', - 'db': 'icsd', - 'port': '3306', + self.db_parameters = { + 'server': '', + 'urladd': 'index.php?', + 'querydb': True, + 'dl_db': 'icsd', + 'host': '', + 'user': 'dba', + 'passwd': 'sql', + 'db': 'icsd', + 'port': '3306', } self.setup_db(**kwargs) @@ -103,69 +98,61 @@ def _int_clause(self, key, alias, values): :param values: Corresponding values from query :return: SQL query predicate """ - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only integers and strings are accepted') return '{} IN ({})'.format(key, ', '.join(str(int(i)) for i in values)) def _str_exact_clause(self, key, alias, values): """ Return SQL query predicate for querying string fields. """ - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only integers and strings are accepted') return '{} IN ({})'.format(key, ', '.join("'{}'".format(f) for f in values)) def _formula_clause(self, key, alias, values): """ Return SQL query predicate for querying formula fields. 
""" - for e in values: - if not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") - return self._str_exact_clause(key, \ - alias, \ - [str(f) for f in values]) + for value in values: + if not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only strings are accepted') + return self._str_exact_clause(key, alias, [str(f) for f in values]) def _str_fuzzy_clause(self, key, alias, values): """ Return SQL query predicate for fuzzy querying of string fields. """ - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only integers and strings are accepted') return ' OR '.join("{} LIKE '%{}%'".format(key, s) for s in values) - def _composition_clause(self, key, alias, values): + def _composition_clause(self, key, alias, values): # pylint: disable=unused-argument """ Return SQL query predicate for querying elements in formula fields. """ - for e in values: - if not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") + for value in values: + if not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only strings are accepted') # SUM_FORM in the ICSD always stores a numeral after the element name, # STRUCT_FORM does not, so it's better to use SUM_FORM for the composition query. # The element-numeral pair can be in the beginning of the formula expression (therefore no space before), # or at the end of the formula expression (no space after). # Be aware that one needs to check that space/beginning of line before and ideally also space/end of line # after, because I found that capitalization of the element name is not enforced in these queries. - return ' AND '.join(r'SUM_FORM REGEXP \'(^|\ ){}[0-9\.]+($|\ )\''.format(e) for e in values) + return ' AND '.join(r'SUM_FORM REGEXP \'(^|\ ){}[0-9\.]+($|\ )\''.format(value) for value in values) def _double_clause(self, key, alias, values, precision): """ Return SQL query predicate for querying double-valued fields. 
""" - for e in values: - if not isinstance(e, int) and not isinstance(e, float): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only integers and floats are accepted") - return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d-precision, d+precision) for d in values) + for value in values: + if not isinstance(value, int) and not isinstance(value, float): + raise ValueError("incorrect value for keyword '" + alias + ' only integers and floats are accepted') + return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d - precision, d + precision) for d in values) def _crystal_system_clause(self, key, alias, values): """ @@ -181,117 +168,127 @@ def _crystal_system_clause(self, key, alias, values): 'triclinic': 'TC' } # from icsd accepted crystal systems - for e in values: - if not isinstance(e, int) and not isinstance(e, str): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings are accepted") + for value in values: + if not isinstance(value, int) and not isinstance(value, str): + raise ValueError("incorrect value for keyword '" + alias + ' only strings are accepted') return key + ' IN (' + ', '.join("'" + valid_systems[f.lower()] + "'" for f in values) + ')' def _length_clause(self, key, alias, values): """ Return SQL query predicate for querying lattice vector lengths. """ - return self.double_clause(key, alias, values, self.length_precision) + return self._double_clause(key, alias, values, self.length_precision) def _density_clause(self, key, alias, values): """ Return SQL query predicate for querying density. """ - return self.double_clause(key, alias, values, self.density_precision) + return self._double_clause(key, alias, values, self.density_precision) def _angle_clause(self, key, alias, values): """ Return SQL query predicate for querying lattice angles. """ - return self.double_clause(key, alias, values, self.angle_precision) + return self._double_clause(key, alias, values, self.angle_precision) def _volume_clause(self, key, alias, values): """ Return SQL query predicate for querying unit cell volume. """ - return self.double_clause(key, alias, values, self.volume_precision) + return self._double_clause(key, alias, values, self.volume_precision) def _temperature_clause(self, key, alias, values): """ Return SQL query predicate for querying temperature. """ - return self.double_clause(key, alias, values, self.temperature_precision) + return self._double_clause(key, alias, values, self.temperature_precision) def _pressure_clause(self, key, alias, values): """ Return SQL query predicate for querying pressure. """ - return self.double_clause(key, alias, values, self.pressure_precision) + return self._double_clause(key, alias, values, self.pressure_precision) # for the web query - def _parse_all(k, v): + @staticmethod + def _parse_all(key, values): # pylint: disable=unused-argument """ Convert numbers, strings, lists into strings. - :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ - if type(v) is list: - retval = ' '.join(v) - elif type(v) is int: - retval = str(v) - elif type(v) is str: - retval = v + if isinstance(values, list): + retval = ' '.join(values) + elif isinstance(values, int): + retval = str(values) + elif isinstance(values, str): + retval = values return retval - def _parse_number(k, v): + @staticmethod + def _parse_number(key, values): # pylint: disable=unused-argument """ Convert int into string. 
- :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ - if type(v) is int: - retval = str(v) - elif type(v) is str: - retval = v + if isinstance(values, int): + retval = str(values) + elif isinstance(values, str): + retval = values return retval - def _parse_mineral(k, v): + @staticmethod + def _parse_mineral(key, values): """ Convert mineral_name and chemical_name into right format. - :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ - if k == 'mineral_name': - retval = 'M=' + v - elif k == 'chemical_name': - retval = 'C=' + v + if key == 'mineral_name': + retval = 'M=' + values + elif key == 'chemical_name': + retval = 'C=' + values return retval - def _parse_volume(k, v): + @staticmethod + def _parse_volume(key, values): # pylint: disable=too-many-return-statements """ Convert volume, cell parameter and angle queries into right format. - :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ - if k == 'volume': - return 'v=' + v - elif k == 'a': - return 'a=' + v - elif k == 'b': - return 'b=' + v - elif k == 'c': - return 'c=' + v - elif k == 'alpha': - return 'al=' + v - elif k == 'beta': - return 'be=' + v - elif k == 'gamma': - return 'ga=' + v - - def _parse_system(k, v): + if key == 'volume': + return 'v=' + values + + if key == 'a': + return 'a=' + values + + if key == 'b': + return 'b=' + values + + if key == 'c': + return 'c=' + values + + if key == 'alpha': + return 'al=' + values + + if key == 'beta': + return 'be=' + values + + if key == 'gamma': + return 'ga=' + values + + @staticmethod + def _parse_system(key, values): # pylint: disable=unused-argument """ Return crystal system in the right format. 
- :param k: query parameter - :param v: corresponding values + :param key: query parameter + :param values: corresponding values :return retval: string """ valid_systems = { @@ -304,54 +301,56 @@ def _parse_system(k, v): 'triclinic': 'TC' } - return valid_systems[v.lower()] + return valid_systems[values.lower()] # mysql database - query parameter (alias) : [mysql keyword (key), function to call] - keywords_db = {'id': ['COLL_CODE', _int_clause], - 'element': ['SUM_FORM;', _composition_clause], - 'number_of_elements': ['EL_COUNT', _int_clause], - 'chemical_name': ['CHEM_NAME', _str_fuzzy_clause], - 'formula': ['SUM_FORM', _formula_clause], - 'volume': ['C_VOL', _volume_clause], - 'spacegroup': ['SGR', _str_exact_clause], - 'a': ['A_LEN', _length_clause], - 'b': ['B_LEN', _length_clause], - 'c': ['C_LEN', _length_clause], - 'alpha': ['ALPHA', _angle_clause], - 'beta': ['BETA', _angle_clause], - 'gamma': ['GAMMA', _angle_clause], - 'density': ['DENSITY_CALC', _density_clause], - 'wyckoff': ['WYCK', _str_exact_clause], - 'molar_mass': ['MOL_MASS', _density_clause], - 'pdf_num': ['PDF_NUM', _str_exact_clause], - 'z': ['Z', _int_clause], - 'measurement_temp': ['TEMPERATURE', _temperature_clause], - 'authors': ['AUTHORS_TEXT', _str_fuzzy_clause], - 'journal': ['journal', _str_fuzzy_clause], - 'title': ['AU_TITLE', _str_fuzzy_clause], - 'year': ['MPY', _int_clause], - 'crystal_system': ['CRYST_SYS_CODE', _crystal_system_clause], + keywords_db = { + 'id': ['COLL_CODE', _int_clause], + 'element': ['SUM_FORM;', _composition_clause], + 'number_of_elements': ['EL_COUNT', _int_clause], + 'chemical_name': ['CHEM_NAME', _str_fuzzy_clause], + 'formula': ['SUM_FORM', _formula_clause], + 'volume': ['C_VOL', _volume_clause], + 'spacegroup': ['SGR', _str_exact_clause], + 'a': ['A_LEN', _length_clause], + 'b': ['B_LEN', _length_clause], + 'c': ['C_LEN', _length_clause], + 'alpha': ['ALPHA', _angle_clause], + 'beta': ['BETA', _angle_clause], + 'gamma': ['GAMMA', _angle_clause], + 'density': ['DENSITY_CALC', _density_clause], + 'wyckoff': ['WYCK', _str_exact_clause], + 'molar_mass': ['MOL_MASS', _density_clause], + 'pdf_num': ['PDF_NUM', _str_exact_clause], + 'z': ['Z', _int_clause], + 'measurement_temp': ['TEMPERATURE', _temperature_clause], + 'authors': ['AUTHORS_TEXT', _str_fuzzy_clause], + 'journal': ['journal', _str_fuzzy_clause], + 'title': ['AU_TITLE', _str_fuzzy_clause], + 'year': ['MPY', _int_clause], + 'crystal_system': ['CRYST_SYS_CODE', _crystal_system_clause], } # keywords accepted for the web page query - keywords = {'id': ('authors', _parse_all), - 'authors': ('authors', _parse_all), - 'element': ('elements', _parse_all), - 'number_of_elements': ('elementc', _parse_all), - 'mineral_name': ('mineral', _parse_mineral), - 'chemical_name': ('mineral', _parse_mineral), - 'formula': ('formula', _parse_all), - 'volume': ('volume', _parse_volume), - 'a': ('volume', _parse_volume), - 'b': ('volume', _parse_volume), - 'c': ('volume', _parse_volume), - 'alpha': ('volume', _parse_volume), - 'beta': ('volume', _parse_volume), - 'gamma': ('volume', _parse_volume), - 'spacegroup': ('spaceg', _parse_all), - 'journal': ('journal', _parse_all), - 'title': ('title', _parse_all), - 'year': ('year', _parse_all), - 'crystal_system': ('system', _parse_system), + keywords = { + 'id': ('authors', _parse_all), + 'authors': ('authors', _parse_all), + 'element': ('elements', _parse_all), + 'number_of_elements': ('elementc', _parse_all), + 'mineral_name': ('mineral', _parse_mineral), + 'chemical_name': ('mineral', 
_parse_mineral), + 'formula': ('formula', _parse_all), + 'volume': ('volume', _parse_volume), + 'a': ('volume', _parse_volume), + 'b': ('volume', _parse_volume), + 'c': ('volume', _parse_volume), + 'alpha': ('volume', _parse_volume), + 'beta': ('volume', _parse_volume), + 'gamma': ('volume', _parse_volume), + 'spacegroup': ('spaceg', _parse_all), + 'journal': ('journal', _parse_all), + 'title': ('title', _parse_all), + 'year': ('year', _parse_all), + 'crystal_system': ('system', _parse_system), } def query(self, **kwargs): @@ -361,11 +360,10 @@ def query(self, **kwargs): :param kwargs: A list of ''keyword = [values]'' pairs. """ - if self.db_parameters['querydb']: return self._query_sql_db(**kwargs) - else: - return self._queryweb(**kwargs) + + return self._queryweb(**kwargs) def _query_sql_db(self, **kwargs): """ @@ -377,13 +375,11 @@ def _query_sql_db(self, **kwargs): sql_where_query = [] # second part of sql query - for k, v in kwargs.items(): - if not isinstance(v, list): - v = [v] - sql_where_query.append('({})'.format(self.keywords_db[k][1](self, - self.keywords_db[k][0], - k, v))) - if 'crystal_system' in kwargs.keys(): # to query another table than the main one, add LEFT JOIN in front of WHERE + for key, value in kwargs.items(): + if not isinstance(value, list): + value = [value] + sql_where_query.append('({})'.format(self.keywords_db[key][1](self, self.keywords_db[key][0], key, value))) + if 'crystal_system' in kwargs: # to query another table than the main one, add LEFT JOIN in front of WHERE sql_query = 'LEFT JOIN space_group ON space_group.sgr=icsd.sgr LEFT '\ 'JOIN space_group_number ON '\ 'space_group_number.sgr_num=space_group.sgr_num '\ @@ -395,7 +391,6 @@ def _query_sql_db(self, **kwargs): return IcsdSearchResults(query=sql_query, db_parameters=self.db_parameters) - def _queryweb(self, **kwargs): """ Perform a query on the Icsd web database using ``keyword = value`` pairs, @@ -405,7 +400,7 @@ def _queryweb(self, **kwargs): :return: IcsdSearchResults """ from urllib.parse import urlencode - self.actual_args = { + self.actual_args = { # pylint: disable=attribute-defined-outside-init 'action': 'Search', 'nb_rows': '100', # max is 100 'order_by': 'yearDesc', @@ -414,10 +409,10 @@ def _queryweb(self, **kwargs): 'mineral': '' } - for k, v in kwargs.items(): + for key, value in kwargs.items(): try: - realname = self.keywords[k][0] - newv = self.keywords[k][1](k, v) + realname = self.keywords[key][0] + newv = self.keywords[key][1](key, value) # Because different keys correspond to the same search field. if realname in ['authors', 'volume', 'mineral']: self.actual_args[realname] = self.actual_args[realname] + newv + ' ' @@ -439,8 +434,8 @@ def setup_db(self, **kwargs): :param kwargs: db_parameters for the mysql database connection (host, user, passwd, db, port) """ - for key in self.db_parameters.keys(): - if key in kwargs.keys(): + for key in self.db_parameters: + if key in kwargs: self.db_parameters[key] = kwargs[key] def get_supported_keywords(self): @@ -449,11 +444,11 @@ def get_supported_keywords(self): """ if self.db_parameters['querydb']: return self.keywords_db.keys() - else: - return self.keywords.keys() + + return self.keywords.keys() -class IcsdSearchResults(DbSearchResults): +class IcsdSearchResults(DbSearchResults): # pylint: disable=abstract-method,too-many-instance-attributes """ Result manager for the query performed on ICSD. 
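The `CodDbImporter` and `IcsdDbImporter` hunks above rely on the same dispatch pattern: a class-level table maps each supported keyword to a database column and a clause-building method, and the SQL statement is assembled by joining the generated predicates with AND. A minimal, self-contained sketch of that pattern (the table contents and column names here are made up for illustration):

    class MiniImporter:
        """Toy importer showing the keyword -> [column, clause builder] dispatch table."""

        def _int_clause(self, key, alias, values):
            """Build an SQL ``IN`` predicate for integer-valued keywords."""
            for value in values:
                if not isinstance(value, (int, str)):
                    raise ValueError("incorrect value for keyword '{}': only integers and strings are accepted".format(alias))
            return '{} IN ({})'.format(key, ', '.join(str(int(value)) for value in values))

        # keyword accepted by `query_sql` -> [database column, clause builder]
        _keywords = {'id': ['file', _int_clause], 'year': ['year', _int_clause]}

        def query_sql(self, **kwargs):
            """Assemble the WHERE clause by dispatching every keyword through the table."""
            sql_parts = []
            for keyword, (column, builder) in self._keywords.items():
                if keyword in kwargs:
                    values = kwargs.pop(keyword)
                    if not isinstance(values, list):
                        values = [values]
                    # the table stores plain functions, so `self` is passed explicitly
                    sql_parts.append('({})'.format(builder(self, column, keyword, values)))
            if kwargs:
                raise NotImplementedError('unsupported keyword(s): {}'.format(', '.join(kwargs)))
            return 'SELECT file FROM data WHERE ' + ' AND '.join(sql_parts)


    # MiniImporter().query_sql(id=[1000, 2000], year=1999)
    # -> "SELECT file FROM data WHERE (file IN (1000, 2000)) AND (year IN (1999))"

Because the table stores plain functions rather than bound methods, the builder has to be called with `self` passed explicitly, which is exactly what the importers above do.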
@@ -465,8 +460,8 @@ class IcsdSearchResults(DbSearchResults): db_name = 'Icsd' def __init__(self, query, db_parameters): - - self.db = None + # pylint: disable=super-init-not-called + self.db = None # pylint: disable=invalid-name self.cursor = None self.db_parameters = db_parameters self.query = query @@ -498,9 +493,9 @@ def next(self): if self.number_of_results > self.position: self.position = self.position + 1 return self.at(self.position - 1) - else: - self.position = 0 - raise StopIteration() + + self.position = 0 + raise StopIteration() def at(self, position): """ @@ -515,20 +510,23 @@ def at(self, position): if position not in self.entries: if self.db_parameters['querydb']: - self.entries[position] = IcsdEntry(self.db_parameters['server'] + - self.db_parameters['dl_db'] + self.cif_url.format( - self._results[position]), - db_name=self.db_name, id=self.cif_numbers[position], - version = self.db_version, - extras={'idnum': self._results[position]}) + self.entries[position] = IcsdEntry( + self.db_parameters['server'] + self.db_parameters['dl_db'] + + self.cif_url.format(self._results[position]), + db_name=self.db_name, + id=self.cif_numbers[position], + version=self.db_version, + extras={'idnum': self._results[position]} + ) else: - self.entries[position] = IcsdEntry(self.db_parameters['server'] + - self.db_parameters['dl_db'] + self.cif_url.format( - self._results[position]), - db_name=self.db_name, extras={'idnum': self._results[position]}) + self.entries[position] = IcsdEntry( + self.db_parameters['server'] + self.db_parameters['dl_db'] + + self.cif_url.format(self._results[position]), + db_name=self.db_name, + extras={'idnum': self._results[position]} + ) return self.entries[position] - def query_db_version(self): """ Query the version of the icsd database (last row of RELEASE_TAGS). 
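`IcsdSearchResults` materialises its entries lazily: `next` only advances a cursor, while `at` builds an entry the first time a position is requested and caches it in a dictionary. A small stand-alone sketch of that access pattern (a plain `dict` stands in for the real `IcsdEntry`):

    class LazyResults:
        """Cache query results by position so each entry is materialised at most once."""

        def __init__(self, raw_results):
            self._raw = raw_results  # e.g. a list of result dictionaries
            self._entries = {}       # position -> constructed entry
            self._position = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self._position < len(self._raw):
                self._position += 1
                return self.at(self._position - 1)
            # reset the cursor before stopping so the results can be iterated again
            self._position = 0
            raise StopIteration

        def at(self, position):
            """Return the entry at ``position``, constructing and caching it on first access."""
            if not 0 <= position < len(self._raw):
                raise IndexError('index out of bounds')
            if position not in self._entries:
                # stand-in for building the real entry from the raw result dictionary
                self._entries[position] = dict(self._raw[position])
            return self._entries[position]

Resetting the cursor before raising `StopIteration` mirrors the original `next` implementation, so the same results object can be iterated more than once.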
@@ -554,8 +552,7 @@ def query_db_version(self): raise IcsdImporterExp('Database version not found') else: - raise NotImplementedError('Cannot query the database version with ' - 'a web query.') + raise NotImplementedError('Cannot query the database version with ' 'a web query.') def query_page(self): """ @@ -568,10 +565,9 @@ def query_page(self): if self.db_parameters['querydb']: self._connect_db() - query_statement = '{}{}{} LIMIT {}, 100'.format(self.sql_select_query, - self.sql_from_query, - self.query, - (self.page-1)*100) + query_statement = '{}{}{} LIMIT {}, 100'.format( + self.sql_select_query, self.sql_from_query, self.query, (self.page - 1) * 100 + ) self.cursor.execute(query_statement) self.db.commit() @@ -585,23 +581,21 @@ def query_page(self): self._disconnect_db() - else: - from bs4 import BeautifulSoup + from bs4 import BeautifulSoup # pylint: disable=import-error from urllib.request import urlopen import re - self.html = urlopen(self.db_parameters['server'] + - self.db_parameters['db'] + '/' + - self.query.format(str(self.page))).read() + self.html = urlopen( + self.db_parameters['server'] + self.db_parameters['db'] + '/' + self.query.format(str(self.page)) + ).read() self.soup = BeautifulSoup(self.html) try: if self.number_of_results is None: - self.number_of_results = int(re.findall(r'\d+', - str(self.soup.find_all('i')[-1]))[0]) + self.number_of_results = int(re.findall(r'\d+', str(self.soup.find_all('i')[-1]))[0]) except IndexError: raise NoResultsWebExp @@ -617,11 +611,12 @@ def _connect_db(self): except ImportError: import pymysql as MySQLdb - self.db = MySQLdb.connect(host=self.db_parameters['host'], - user=self.db_parameters['user'], - passwd=self.db_parameters['passwd'], - db=self.db_parameters['db'], - port=int(self.db_parameters['port']) + self.db = MySQLdb.connect( + host=self.db_parameters['host'], + user=self.db_parameters['user'], + passwd=self.db_parameters['passwd'], + db=self.db_parameters['db'], + port=int(self.db_parameters['port']) ) self.cursor = self.db.cursor() @@ -632,7 +627,7 @@ def _disconnect_db(self): self.db.close() -class IcsdEntry(CifEntry): +class IcsdEntry(CifEntry): # pylint: disable=abstract-method """ Represent an entry from Icsd. @@ -651,12 +646,14 @@ def __init__(self, uri, **kwargs): """ super().__init__(**kwargs) self.source = { - 'db_name': kwargs.get('db_name','Icsd'), + 'db_name': kwargs.get('db_name', 'Icsd'), 'db_uri': None, 'id': kwargs.get('id', None), 'version': kwargs.get('version', None), 'uri': uri, - 'extras': {'idnum': kwargs.get('extras', {}).get('idnum', None)}, + 'extras': { + 'idnum': kwargs.get('extras', {}).get('idnum', None) + }, 'license': self._license, } @@ -668,10 +665,11 @@ def contents(self): PyCifRW library (and most other sensible applications), expects UTF-8. Therefore, we decode the original CIF data to unicode and encode it in the UTF-8 format """ + import urllib.request + if self._contents is None: from hashlib import md5 - - self._contents = urlopen(self.source['uri']).read() + self._contents = urllib.request.urlopen(self.source['uri']).read() self._contents = self._contents.decode('iso-8859-1').encode('utf8') self.source['source_md5'] = md5(self._contents).hexdigest() @@ -682,9 +680,8 @@ def get_ase_structure(self): :return: ASE structure corresponding to the cif file. 
""" from aiida.orm import CifData - cif = correct_cif(self.cif) - return CifData.read_cif(StringIO(cif)) + return CifData.read_cif(io.StringIO(cif)) def correct_cif(cif): @@ -709,12 +706,12 @@ def correct_cif(cif): inc = 1 while True: words = lines[author_index + inc].split() - #in case loop is finished -> return cif lines. - #use regular expressions ? + # in case loop is finished -> return cif lines. + # use regular expressions ? if len(words) == 0 or words[0] == 'loop_' or words[0][0] == '_': return '\n'.join(lines) - elif ((words[0][0] == "'" and words[-1][-1] == "'") - or (words[0][0] == '"' and words[-1][-1] == '"')): + + if ((words[0][0] == "'" and words[-1][-1] == "'") or (words[0][0] == '"' and words[-1][-1] == '"')): # if quotes are already there, check next line inc = inc + 1 else: diff --git a/aiida/tools/dbimporters/plugins/materialsproject.py b/aiida/tools/dbimporters/plugins/materialsproject.py index 3da7bf9039..0f09d3b324 100644 --- a/aiida/tools/dbimporters/plugins/materialsproject.py +++ b/aiida/tools/dbimporters/plugins/materialsproject.py @@ -7,13 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Module that contains the class definitions necessary to offer support for -queries to Materials Project.""" - -import os +""""Implementation of `DbImporter` for the Materials Project database.""" import datetime +import os import requests + from pymatgen import MPRester + from aiida.tools.dbimporters.baseclasses import CifEntry, DbImporter, DbSearchResults diff --git a/aiida/tools/dbimporters/plugins/mpds.py b/aiida/tools/dbimporters/plugins/mpds.py index 00984b92b1..a5be35e9c4 100644 --- a/aiida/tools/dbimporters/plugins/mpds.py +++ b/aiida/tools/dbimporters/plugins/mpds.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +""""Implementation of `DbImporter` for the MPDS database.""" import copy import enum import os @@ -62,7 +62,7 @@ def __init__(self, url=None, api_key=None): self.setup_db(url=url, api_key=api_key) self._structures = StructuresCollection(self) - def setup_db(self, url=None, api_key=None, collection=None): + def setup_db(self, url=None, api_key=None, collection=None): # pylint: disable=arguments-differ """ Setup the required parameters for HTTP requests to the REST API @@ -112,7 +112,7 @@ def structures(self): return self._structures @property - def get_supported_keywords(self): + def get_supported_keywords(self): # pylint: disable=invalid-overridden-method """ Returns the list of all supported query keywords @@ -127,7 +127,7 @@ def url(self): """ return self._url - def query(self, query, collection=None): + def query(self, query, collection=None): # pylint: disable=arguments-differ """ Query the database with a given dictionary of query parameters for a given collection @@ -176,6 +176,8 @@ def find(self, query, fmt=DEFAULT_API_FORMAT): :param query: a dictionary with the query parameters """ + # pylint: disable=too-many-branches + if not isinstance(query, dict): raise TypeError('The query argument should be a dictionary') @@ -234,7 +236,8 @@ def get(self, fmt=DEFAULT_API_FORMAT, **kwargs): kwargs['fmt'] = fmt.value return requests.get(url=self.url, params=kwargs, headers={'Key': self.api_key}) - def 
get_response_content(self, response, fmt=DEFAULT_API_FORMAT): + @staticmethod + def get_response_content(response, fmt=DEFAULT_API_FORMAT): """ Analyze the response of an HTTP GET request, verify that the response code is OK and return the json loaded response text @@ -254,10 +257,11 @@ def get_response_content(self, response, fmt=DEFAULT_API_FORMAT): raise ValueError('Got error response: {}'.format(error)) return content - else: - return response.text - def get_id_from_cif(self, cif): + return response.text + + @staticmethod + def get_id_from_cif(cif): """ Extract the entry id from the string formatted cif response of the MPDS API @@ -275,6 +279,7 @@ def get_id_from_cif(self, cif): class StructuresCollection: + """Collection of structures.""" def __init__(self, engine): self._engine = engine @@ -305,11 +310,11 @@ class MpdsEntry(DbEntry): Represents an MPDS database entry """ - def __init__(self, url, **kwargs): + def __init__(self, _, **kwargs): """ Set the class license from the source dictionary """ - license = kwargs.pop('license', None) + license = kwargs.pop('license', None) # pylint: disable=redefined-builtin if license is not None: self._license = license @@ -317,7 +322,7 @@ def __init__(self, url, **kwargs): super().__init__(**kwargs) -class MpdsCifEntry(CifEntry, MpdsEntry): +class MpdsCifEntry(CifEntry, MpdsEntry): # pylint: disable=abstract-method """ An extension of the MpdsEntry class with the CifEntry class, which will treat the contents property through the URI as a cif file @@ -337,10 +342,8 @@ def __init__(self, url, **kwargs): self.cif = cif -class MpdsSearchResults(DbSearchResults): - """ - A collection of MpdsEntry query result entries - """ +class MpdsSearchResults(DbSearchResults): # pylint: disable=abstract-method + """Collection of MpdsEntry query result entries.""" _db_name = 'Materials Platform for Data Science' _db_uri = 'https://mpds.io/' diff --git a/aiida/tools/dbimporters/plugins/mpod.py b/aiida/tools/dbimporters/plugins/mpod.py index 88f67abcda..139dcf963e 100644 --- a/aiida/tools/dbimporters/plugins/mpod.py +++ b/aiida/tools/dbimporters/plugins/mpod.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - CifEntry) +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the MPOD database.""" +from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, CifEntry) class MpodDbImporter(DbImporter): @@ -24,15 +22,16 @@ def _str_clause(self, key, alias, values): Returns part of HTTP GET query for querying string fields. 
""" if not isinstance(values, str) and not isinstance(values, int): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings and integers are accepted") + raise ValueError("incorrect value for keyword '" + alias + "' -- only strings and integers are accepted") return '{}={}'.format(key, values) - _keywords = {'phase_name': ['phase_name', _str_clause], - 'formula': ['formula', _str_clause], - 'element': ['element', None], - 'cod_id': ['cod_code', _str_clause], - 'authors': ['publ_author', _str_clause]} + _keywords = { + 'phase_name': ['phase_name', _str_clause], + 'formula': ['formula', _str_clause], + 'element': ['element', None], + 'cod_id': ['cod_code', _str_clause], + 'authors': ['publ_author', _str_clause] + } def __init__(self, **kwargs): self._query_url = 'http://mpod.cimav.edu.mx/data/search/' @@ -46,8 +45,7 @@ def query_get(self, **kwargs): :return: a list containing strings for HTTP GET statement. """ if 'formula' in kwargs.keys() and 'element' in kwargs.keys(): - raise ValueError('can not query both formula and elements ' - 'in MPOD') + raise ValueError('can not query both formula and elements ' 'in MPOD') elements = [] if 'element' in kwargs.keys(): @@ -56,25 +54,18 @@ def query_get(self, **kwargs): elements = [elements] get_parts = [] - for key in self._keywords.keys(): - if key in kwargs.keys(): + for key in self._keywords: + if key in kwargs: values = kwargs.pop(key) - get_parts.append( - self._keywords[key][1](self, - self._keywords[key][0], - key, - values)) + get_parts.append(self._keywords[key][1](self, self._keywords[key][0], key, values)) - if kwargs.keys(): - raise NotImplementedError("search keyword(s) '" - "', '".join(kwargs.keys()) + "' " - 'is(are) not implemented for MPOD') + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) queries = [] - for e in elements: - queries.append(self._query_url + '?' + - '&'.join(get_parts + - [self._str_clause('formula', 'element', e)])) + for element in elements: + clauses = [self._str_clause('formula', 'element', element)] + queries.append(self._query_url + '?' + '&'.join(get_parts + clauses)) if not queries: queries.append(self._query_url + '?' + '&'.join(get_parts)) @@ -103,18 +94,15 @@ def query(self, **kwargs): return MpodSearchResults([{'id': x} for x in results]) - def setup_db(self, query_url=None, **kwargs): + def setup_db(self, query_url=None, **kwargs): # pylint: disable=arguments-differ """ Changes the database connection details. """ if query_url: self._query_url = query_url - if kwargs.keys(): - raise NotImplementedError( \ - "unknown database connection parameter(s): '" + \ - "', '".join(kwargs.keys()) + \ - "', available parameters: 'query_url'") + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) def get_supported_keywords(self): """ @@ -125,7 +113,7 @@ def get_supported_keywords(self): return self._keywords.keys() -class MpodSearchResults(DbSearchResults): +class MpodSearchResults(DbSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on MPOD. """ @@ -156,7 +144,7 @@ def _get_url(self, result_dict): return self._base_url + result_dict['id'] + '.mpod' -class MpodEntry(CifEntry): +class MpodEntry(CifEntry): # pylint: disable=abstract-method """ Represents an entry from MPOD. """ @@ -167,7 +155,6 @@ def __init__(self, uri, **kwargs): :py:class:`aiida.tools.dbimporters.plugins.mpod.MpodEntry`, related to the supplied URI. 
""" - super().__init__(db_name='Material Properties Open Database', - db_uri='http://mpod.cimav.edu.mx', - uri=uri, - **kwargs) + super().__init__( + db_name='Material Properties Open Database', db_uri='http://mpod.cimav.edu.mx', uri=uri, **kwargs + ) diff --git a/aiida/tools/dbimporters/plugins/nninc.py b/aiida/tools/dbimporters/plugins/nninc.py index 0e7f13bcc3..ce2e724f47 100644 --- a/aiida/tools/dbimporters/plugins/nninc.py +++ b/aiida/tools/dbimporters/plugins/nninc.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - UpfEntry) +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the NNIN/C database.""" +from aiida.tools.dbimporters.baseclasses import DbImporter, DbSearchResults, UpfEntry class NnincDbImporter(DbImporter): @@ -24,14 +22,18 @@ def _str_clause(self, key, alias, values): Returns part of HTTP GET query for querying string fields. """ if not isinstance(values, str): - raise ValueError("incorrect value for keyword '{}' -- only " - 'strings and integers are accepted'.format(alias)) + raise ValueError( + "incorrect value for keyword '{}' -- only " + 'strings and integers are accepted'.format(alias) + ) return '{}={}'.format(key, values) - _keywords = {'xc_approximation': ['frmxcprox', _str_clause], - 'xc_type': ['frmxctype', _str_clause], - 'pseudopotential_class': ['frmspclass', _str_clause], - 'element': ['element', None]} + _keywords = { + 'xc_approximation': ['frmxcprox', _str_clause], + 'xc_type': ['frmxctype', _str_clause], + 'pseudopotential_class': ['frmspclass', _str_clause], + 'element': ['element', None] + } def __init__(self, **kwargs): self._query_url = 'http://nninc.cnf.cornell.edu/dd_search.php' @@ -45,20 +47,14 @@ def query_get(self, **kwargs): :return: a string with HTTP GET statement. """ get_parts = [] - for key in self._keywords.keys(): - if key in kwargs.keys(): + for key in self._keywords: + if key in kwargs: values = kwargs.pop(key) if self._keywords[key][1] is not None: - get_parts.append( - self._keywords[key][1](self, - self._keywords[key][0], - key, - values)) + get_parts.append(self._keywords[key][1](self, self._keywords[key][0], key, values)) - if kwargs.keys(): - raise NotImplementedError("search keyword(s) '" - "', '".join(kwargs.keys()) + \ - "' is(are) not implemented for NNIN/C") + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) return self._query_url + '?' + '&'.join(get_parts) @@ -91,7 +87,7 @@ def query(self, **kwargs): return NnincSearchResults([{'id': x} for x in results]) - def setup_db(self, query_url=None, **kwargs): + def setup_db(self, query_url=None, **kwargs): # pylint: disable=arguments-differ """ Changes the database connection details. """ @@ -113,7 +109,7 @@ def get_supported_keywords(self): return self._keywords.keys() -class NnincSearchResults(DbSearchResults): +class NnincSearchResults(DbSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on NNIN/C Pseudopotential Virtual Vault. @@ -156,7 +152,6 @@ def __init__(self, uri, **kwargs): :py:class:`aiida.tools.dbimporters.plugins.nninc.NnincEntry`, related to the supplied URI. 
""" - super().__init__(db_name='NNIN/C Pseudopotential Virtual Vault', - db_uri='http://nninc.cnf.cornell.edu', - uri=uri, - **kwargs) + super().__init__( + db_name='NNIN/C Pseudopotential Virtual Vault', db_uri='http://nninc.cnf.cornell.edu', uri=uri, **kwargs + ) diff --git a/aiida/tools/dbimporters/plugins/oqmd.py b/aiida/tools/dbimporters/plugins/oqmd.py index b2249ade74..5af78fe127 100644 --- a/aiida/tools/dbimporters/plugins/oqmd.py +++ b/aiida/tools/dbimporters/plugins/oqmd.py @@ -7,11 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - - - -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, - CifEntry) +# pylint: disable=no-self-use +""""Implementation of `DbImporter` for the OQMD database.""" +from aiida.tools.dbimporters.baseclasses import DbImporter, DbSearchResults, CifEntry class OqmdDbImporter(DbImporter): @@ -24,8 +22,7 @@ def _str_clause(self, key, alias, values): Returns part of HTTP GET query for querying string fields. """ if not isinstance(values, str) and not isinstance(values, int): - raise ValueError("incorrect value for keyword '" + alias + \ - "' -- only strings and integers are accepted") + raise ValueError("incorrect value for keyword '" + alias + "' -- only strings and integers are accepted") return '{}={}'.format(key, values) _keywords = {'element': ['element', None]} @@ -65,27 +62,22 @@ def query(self, **kwargs): results = [] for entry in entries: - response = urlopen('{}{}'.format(self._query_url, - entry)).read() - structures = re.findall(r'/materials/export/conventional/cif/(\d+)', - response) + response = urlopen('{}{}'.format(self._query_url, entry)).read() + structures = re.findall(r'/materials/export/conventional/cif/(\d+)', response) for struct in structures: results.append({'id': struct}) return OqmdSearchResults(results) - def setup_db(self, query_url=None, **kwargs): + def setup_db(self, query_url=None, **kwargs): # pylint: disable=arguments-differ """ Changes the database connection details. """ if query_url: self._query_url = query_url - if kwargs.keys(): - raise NotImplementedError( \ - "unknown database connection parameter(s): '" + \ - "', '".join(kwargs.keys()) + \ - "', available parameters: 'query_url'") + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) def get_supported_keywords(self): """ @@ -96,7 +88,7 @@ def get_supported_keywords(self): return self._keywords.keys() -class OqmdSearchResults(DbSearchResults): +class OqmdSearchResults(DbSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on OQMD. """ @@ -127,7 +119,7 @@ def _get_url(self, result_dict): return self._base_url + result_dict['id'] -class OqmdEntry(CifEntry): +class OqmdEntry(CifEntry): # pylint: disable=abstract-method """ Represents an entry from OQMD. """ @@ -138,7 +130,4 @@ def __init__(self, uri, **kwargs): :py:class:`aiida.tools.dbimporters.plugins.oqmd.OqmdEntry`, related to the supplied URI. 
""" - super().__init__(db_name='Open Quantum Materials Database', - db_uri='http://oqmd.org', - uri=uri, - **kwargs) + super().__init__(db_name='Open Quantum Materials Database', db_uri='http://oqmd.org', uri=uri, **kwargs) diff --git a/aiida/tools/dbimporters/plugins/pcod.py b/aiida/tools/dbimporters/plugins/pcod.py index 7a7595c66c..4550ab5634 100644 --- a/aiida/tools/dbimporters/plugins/pcod.py +++ b/aiida/tools/dbimporters/plugins/pcod.py @@ -7,10 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - -from aiida.tools.dbimporters.plugins.cod import (CodDbImporter, - CodSearchResults, CodEntry) - +""""Implementation of `DbImporter` for the PCOD database.""" +from aiida.tools.dbimporters.plugins.cod import CodDbImporter, CodSearchResults, CodEntry class PcodDbImporter(CodDbImporter): @@ -18,50 +16,25 @@ class PcodDbImporter(CodDbImporter): Database importer for Predicted Crystallography Open Database. """ - def _int_clause(self, *args, **kwargs): - return super()._int_clause(*args, **kwargs) - - def _composition_clause(self, *args, **kwargs): - return super()._composition_clause(*args, **kwargs) - - def _formula_clause(self, *args, **kwargs): - return super()._formula_clause(*args, **kwargs) - - def _volume_clause(self, *args, **kwargs): - return super()._volume_clause(*args, **kwargs) - - def _str_exact_clause(self, *args, **kwargs): - return super()._str_exact_clause(*args, **kwargs) - - def _length_clause(self, *args, **kwargs): - return super()._length_clause(*args, **kwargs) - - def _angle_clause(self, *args, **kwargs): - return super()._angle_clause(*args, **kwargs) - - def _str_fuzzy_clause(self, *args, **kwargs): - return super()._str_fuzzy_clause(*args, **kwargs) - - _keywords = {'id': ['file', _int_clause], - 'element': ['element', _composition_clause], - 'number_of_elements': ['nel', _int_clause], - 'formula': ['formula', _formula_clause], - 'volume': ['vol', _volume_clause], - 'spacegroup': ['sg', _str_exact_clause], - 'a': ['a', _length_clause], - 'b': ['b', _length_clause], - 'c': ['c', _length_clause], - 'alpha': ['alpha', _angle_clause], - 'beta': ['beta', _angle_clause], - 'gamma': ['gamma', _angle_clause], - 'text': ['text', _str_fuzzy_clause]} + _keywords = { + 'id': ['file', CodDbImporter._int_clause], + 'element': ['element', CodDbImporter._composition_clause], + 'number_of_elements': ['nel', CodDbImporter._int_clause], + 'formula': ['formula', CodDbImporter._formula_clause], + 'volume': ['vol', CodDbImporter._volume_clause], + 'spacegroup': ['sg', CodDbImporter._str_exact_clause], + 'a': ['a', CodDbImporter._length_clause], + 'b': ['b', CodDbImporter._length_clause], + 'c': ['c', CodDbImporter._length_clause], + 'alpha': ['alpha', CodDbImporter._angle_clause], + 'beta': ['beta', CodDbImporter._angle_clause], + 'gamma': ['gamma', CodDbImporter._angle_clause], + 'text': ['text', CodDbImporter._str_fuzzy_clause] + } def __init__(self, **kwargs): super().__init__(**kwargs) - self._db_parameters = {'host': 'www.crystallography.net', - 'user': 'pcod_reader', - 'passwd': '', - 'db': 'pcod'} + self._db_parameters = {'host': 'www.crystallography.net', 'user': 'pcod_reader', 'passwd': '', 'db': 'pcod'} self.setup_db(**kwargs) def query_sql(self, **kwargs): @@ -72,24 +45,16 @@ def query_sql(self, **kwargs): :return: string containing a SQL statement. 
""" sql_parts = [] - for key in self._keywords.keys(): - if key in kwargs.keys(): + for key in self._keywords: + if key in kwargs: values = kwargs.pop(key) if not isinstance(values, list): values = [values] - sql_parts.append( \ - '(' + self._keywords[key][1](self, \ - self._keywords[key][0], \ - key, \ - values) + \ - ')') - if len(kwargs.keys()) > 0: - raise NotImplementedError( \ - "search keyword(s) '" + \ - "', '".join(kwargs.keys()) + "' " + \ - 'is(are) not implemented for PCOD') - return 'SELECT file FROM data WHERE ' + \ - ' AND '.join(sql_parts) + sql_parts.append('(' + self._keywords[key][1](self, self._keywords[key][0], key, values) + ')') + if kwargs: + raise NotImplementedError('following keyword(s) are not implemented: {}'.format(', '.join(kwargs.keys()))) + + return 'SELECT file FROM data WHERE ' + ' AND '.join(sql_parts) def query(self, **kwargs): """ @@ -113,7 +78,7 @@ def query(self, **kwargs): return PcodSearchResults(results) -class PcodSearchResults(CodSearchResults): +class PcodSearchResults(CodSearchResults): # pylint: disable=abstract-method """ Results of the search, performed on PCOD. """ @@ -129,27 +94,25 @@ def _get_url(self, result_dict): :param result_dict: dictionary, describing an entry in the results. """ - return self._base_url + \ - result_dict['id'][0] + '/' + \ - result_dict['id'][0:3] + '/' + \ - result_dict['id'] + '.cif' + return self._base_url + result_dict['id'][0] + '/' + result_dict['id'][0:3] + '/' + result_dict['id'] + '.cif' -class PcodEntry(CodEntry): +class PcodEntry(CodEntry): # pylint: disable=abstract-method """ Represents an entry from PCOD. """ _license = 'CC0' - def __init__(self, uri, - db_name='Predicted Crystallography Open Database', - db_uri='http://www.crystallography.net/pcod', **kwargs): + def __init__( + self, + uri, + db_name='Predicted Crystallography Open Database', + db_uri='http://www.crystallography.net/pcod', + **kwargs + ): """ Creates an instance of :py:class:`aiida.tools.dbimporters.plugins.pcod.PcodEntry`, related to the supplied URI. 
""" - super().__init__(db_name=db_name, - db_uri=db_uri, - uri=uri, - **kwargs) + super().__init__(db_name=db_name, db_uri=db_uri, uri=uri, **kwargs) diff --git a/aiida/tools/dbimporters/plugins/tcod.py b/aiida/tools/dbimporters/plugins/tcod.py index 70cf74e37e..7abdbd1275 100644 --- a/aiida/tools/dbimporters/plugins/tcod.py +++ b/aiida/tools/dbimporters/plugins/tcod.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Importer implementation for the TCOD.""" +""""Implementation of `DbImporter` for the TCOD database.""" from aiida.tools.dbimporters.plugins.cod import (CodDbImporter, CodSearchResults, CodEntry) diff --git a/aiida/tools/importexport/common/utils.py b/aiida/tools/importexport/common/utils.py index 1c5214f3e5..0aef11888a 100644 --- a/aiida/tools/importexport/common/utils.py +++ b/aiida/tools/importexport/common/utils.py @@ -8,8 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ Utility functions for import/export of AiiDA entities """ -# pylint: disable=inconsistent-return-statements,too-many-branches,too-many-return-statements -# pylint: disable=too-many-nested-blocks,too-many-locals +# pylint: disable=too-many-branches,too-many-return-statements,too-many-nested-blocks,too-many-locals from html.parser import HTMLParser import urllib.request import urllib.parse diff --git a/aiida/tools/importexport/dbimport/utils.py b/aiida/tools/importexport/dbimport/utils.py index faaac81e1c..d25aab0cc0 100644 --- a/aiida/tools/importexport/dbimport/utils.py +++ b/aiida/tools/importexport/dbimport/utils.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ Utility functions for import of AiiDA entities """ -# pylint: disable=inconsistent-return-statements,too-many-branches +# pylint: disable=too-many-branches import os import click diff --git a/aiida/workflows/arithmetic/multiply_add.py b/aiida/workflows/arithmetic/multiply_add.py index 125df06698..d63ab0329e 100644 --- a/aiida/workflows/arithmetic/multiply_add.py +++ b/aiida/workflows/arithmetic/multiply_add.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=inconsistent-return-statements,no-member +# pylint: disable=no-member # start-marker for docs """Implementation of the MultiplyAddWorkChain for testing and demonstration purposes.""" from aiida.orm import Code, Int diff --git a/docs/source/howto/codes.rst b/docs/source/howto/codes.rst index 713b1af9a0..e0247c4b93 100644 --- a/docs/source/howto/codes.rst +++ b/docs/source/howto/codes.rst @@ -206,7 +206,7 @@ The snippet of the previous section on :ref:`parsing the outputs @@ -169,8 +173,10 @@ class TestCommand(unittest.TestCase): + """Test various commands.""" def test_get_joblist_command(self): + """Test the `get_joblist_command`.""" sge = SgeScheduler() # TEST 1: @@ -194,6 +200,7 @@ def test_get_joblist_command(self): self.assertTrue('*' in sge_get_joblist_command) def test_detailed_jobinfo_command(self): + """Test the `get_detailed_jobinfo_command`.""" sge = SgeScheduler() sge_get_djobinfo_command = 
sge._get_detailed_job_info_command('123456') @@ -203,6 +210,7 @@ def test_detailed_jobinfo_command(self): self.assertTrue('-j' in sge_get_djobinfo_command) def test_get_submit_command(self): + """Test the `get_submit_command`.""" sge = SgeScheduler() sge_get_submit_command = sge._get_submit_command('script.sh') @@ -212,6 +220,7 @@ def test_get_submit_command(self): self.assertTrue('script.sh' in sge_get_submit_command) def test_parse_submit_output(self): + """Test the `parse_submit_command`.""" sge = SgeScheduler() # TEST 1: @@ -225,6 +234,7 @@ def test_parse_submit_output(self): logging.disable(logging.NOTSET) def test_parse_joblist_output(self): + """Test the `parse_joblist_command`.""" sge = SgeScheduler() retval = 0 @@ -294,6 +304,7 @@ def test_parse_joblist_output(self): logging.disable(logging.NOTSET) def test_submit_script(self): + """Test the submit script.""" from aiida.schedulers.datastructures import JobTemplate sge = SgeScheduler() diff --git a/tests/schedulers/test_slurm.py b/tests/schedulers/test_slurm.py index 057e26d3f8..47c97e8add 100644 --- a/tests/schedulers/test_slurm.py +++ b/tests/schedulers/test_slurm.py @@ -332,7 +332,3 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc2(self): # pylint: job_tmpl.job_resource = scheduler.create_job_resource( num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=23 ) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/schedulers/test_torque.py b/tests/schedulers/test_torque.py index 7998036bd6..775e358e3e 100644 --- a/tests/schedulers/test_torque.py +++ b/tests/schedulers/test_torque.py @@ -7,11 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +# pylint: disable=invalid-name,protected-access,too-many-lines +"""Tests for the `TorqueScheduler` plugin.""" import unittest import uuid + from aiida.schedulers.datastructures import JobState -from aiida.schedulers.plugins.torque import * +from aiida.schedulers.plugins.torque import TorqueScheduler text_qstat_f_to_test = """Job Id: 68350.mycluster Job_Name = cell-Qnormal @@ -762,6 +764,7 @@ def test_parse_common_joblist_output(self): """ Test whether _parse_joblist can parse the qstat -f output """ + # pylint: disable=too-many-locals s = TorqueScheduler() retval = 0 @@ -810,13 +813,13 @@ def test_parse_common_joblist_output(self): self.assertTrue(j.num_machines == num_machines) self.assertTrue(j.num_cpus == num_cpus) - # TODO : parse the env_vars def test_parse_with_unexpected_newlines(self): """ Test whether _parse_joblist can parse the qstat -f output also when there are unexpected newlines """ + # pylint: disable=too-many-locals s = TorqueScheduler() retval = 0 @@ -831,28 +834,23 @@ def test_parse_with_unexpected_newlines(self): self.assertEqual(job_parsed, job_on_cluster) job_running = 2 - job_running_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING]) + job_running_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.RUNNING]) self.assertEqual(job_running, job_running_parsed) job_held = 1 - job_held_parsed = len([j for j in job_list if j.job_state \ - and j.job_state == JobState.QUEUED_HELD]) + job_held_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED_HELD]) self.assertEqual(job_held, job_held_parsed) job_queued = 5 - job_queued_parsed = len([j for j in 
job_list if j.job_state \ - and j.job_state == JobState.QUEUED]) + job_queued_parsed = len([j for j in job_list if j.job_state and j.job_state == JobState.QUEUED]) self.assertEqual(job_queued, job_queued_parsed) running_users = ['somebody', 'user_556491'] - parsed_running_users = [j.job_owner for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING] + parsed_running_users = [j.job_owner for j in job_list if j.job_state and j.job_state == JobState.RUNNING] self.assertEqual(set(running_users), set(parsed_running_users)) running_jobs = ['555716', '556491'] - parsed_running_jobs = [j.job_id for j in job_list if j.job_state \ - and j.job_state == JobState.RUNNING] + parsed_running_jobs = [j.job_id for j in job_list if j.job_state and j.job_state == JobState.RUNNING] self.assertEqual(set(running_jobs), set(parsed_running_jobs)) for j in job_list: @@ -865,10 +863,10 @@ def test_parse_with_unexpected_newlines(self): self.assertTrue(j.num_machines == num_machines) self.assertTrue(j.num_cpus == num_cpus) - # TODO : parse the env_vars class TestSubmitScript(unittest.TestCase): + """Test the submit script.""" def test_submit_script(self): """ @@ -895,8 +893,7 @@ def test_submit_script(self): self.assertTrue('#PBS -r n' in submit_script_text) self.assertTrue(submit_script_text.startswith('#!/bin/bash')) self.assertTrue('#PBS -l nodes=1:ppn=1,walltime=24:00:00' in submit_script_text) - self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + \ - " < 'aiida.in'" in submit_script_text) + self.assertTrue("'mpirun' '-np' '23' 'pw.x' '-npool' '1'" + " < 'aiida.in'" in submit_script_text) def test_submit_script_with_num_cores_per_machine(self): """ @@ -906,12 +903,13 @@ def test_submit_script_with_num_cores_per_machine(self): from aiida.schedulers.datastructures import JobTemplate from aiida.common.datastructures import CodeInfo, CodeRunMode - s = TorqueScheduler() + scheduler = TorqueScheduler() job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' - job_tmpl.job_resource = s.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24) + job_tmpl.job_resource = scheduler.create_job_resource( + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -920,7 +918,7 @@ def test_submit_script_with_num_cores_per_machine(self): job_tmpl.codes_info = [code_info] job_tmpl.codes_run_mode = CodeRunMode.SERIAL - submit_script_text = s.get_submit_script(job_tmpl) + submit_script_text = scheduler.get_submit_script(job_tmpl) self.assertTrue('#PBS -r n' in submit_script_text) self.assertTrue(submit_script_text.startswith('#!/bin/bash')) @@ -940,7 +938,8 @@ def test_submit_script_with_num_cores_per_mpiproc(self): job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=24) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_mpiproc=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -971,7 +970,8 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc1(self): job_tmpl = JobTemplate() job_tmpl.shebang = '#!/bin/bash' job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=24) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, 
num_cores_per_mpiproc=24 + ) job_tmpl.uuid = str(uuid.uuid4()) job_tmpl.max_wallclock_seconds = 24 * 3600 code_info = CodeInfo() @@ -1001,4 +1001,5 @@ def test_submit_script_with_num_cores_per_machine_and_mpiproc2(self): job_tmpl = JobTemplate() with self.assertRaises(ValueError): job_tmpl.job_resource = scheduler.create_job_resource( - num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=23) + num_machines=1, num_mpiprocs_per_machine=1, num_cores_per_machine=24, num_cores_per_mpiproc=23 + ) diff --git a/tests/sphinxext/workchain_source/conf.py b/tests/sphinxext/workchain_source/conf.py index 1c68124fff..bb3ced18f9 100644 --- a/tests/sphinxext/workchain_source/conf.py +++ b/tests/sphinxext/workchain_source/conf.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=redefined-builtin,invalid-name,missing-module-docstring # # sphinx-aiida-demo documentation build configuration file, created by # sphinx-quickstart on Mon Oct 2 13:04:07 2017. diff --git a/tests/sphinxext/workchain_source_broken/conf.py b/tests/sphinxext/workchain_source_broken/conf.py index 1c68124fff..bb3ced18f9 100644 --- a/tests/sphinxext/workchain_source_broken/conf.py +++ b/tests/sphinxext/workchain_source_broken/conf.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=redefined-builtin,invalid-name,missing-module-docstring # # sphinx-aiida-demo documentation build configuration file, created by # sphinx-quickstart on Mon Oct 2 13:04:07 2017. 
diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index ec908c05ea..75f73c0d21 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -7,8 +7,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines,invalid-name """Tests for specific subclasses of Data.""" - import os import tempfile import unittest @@ -39,7 +39,7 @@ def to_list_of_lists(lofl): :param lofl: an iterable of iterables :return: a list of lists""" - return [[el for el in l] for l in lofl] + return [[el for el in l] for l in lofl] # pylint: disable=unnecessary-comprehension def simplify(string): @@ -96,6 +96,7 @@ class TestCifData(AiidaTestCase): @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_reload_cifdata(self): + """Test `CifData` cycle.""" file_content = 'data_test _cell_length_a 10(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: filename = tmpf.name @@ -157,6 +158,7 @@ def test_reload_cifdata(self): @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_parse_cifdata(self): + """Test parsing a CIF file.""" file_content = 'data_test _cell_length_a 10(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write(file_content) @@ -167,6 +169,7 @@ def test_parse_cifdata(self): @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_change_cifdata_file(self): + """Test changing file for `CifData` before storing.""" file_content_1 = 'data_test _cell_length_a 10(1)' file_content_2 = 'data_test _cell_length_a 11(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -186,6 +189,7 @@ def test_change_cifdata_file(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_get_structure(self): + """Test `CifData.get_structure`.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write( ''' @@ -422,6 +426,7 @@ def test_cif_with_long_line(): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_cif_roundtrip(self): + """Test the `CifData` roundtrip.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write( ''' @@ -455,6 +460,7 @@ def test_cif_roundtrip(self): self.assertEqual(b._prepare_cif(), c._prepare_cif()) # pylint: disable=protected-access def test_symop_string_from_symop_matrix_tr(self): + """Test symmetry operations.""" from aiida.tools.data.cif import symop_string_from_symop_matrix_tr self.assertEqual(symop_string_from_symop_matrix_tr([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), 'x,y,z') @@ -468,6 +474,7 @@ def test_symop_string_from_symop_matrix_tr(self): @unittest.skipIf(not has_ase(), 'Unable to import ase') @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') def test_attached_hydrogens(self): + """Test parsing of file with attached hydrogens.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write( ''' @@ -1024,6 +1031,7 @@ class TestStructureData(AiidaTestCase): """ Tests the creation of StructureData objects (cell and pbc). """ + # pylint: disable=too-many-public-methods from aiida.orm.nodes.data.structure import has_ase, has_spglib from aiida.orm.nodes.data.cif import has_pycifrw @@ -1362,9 +1370,9 @@ def test_kind_5_bis_ase(self): [4, 0, 0], ]) - asecell[0].mass = 12. - asecell[1].mass = 12. - asecell[2].mass = 12. + asecell[0].mass = 12. 
# pylint: disable=assigning-non-slot + asecell[1].mass = 12. # pylint: disable=assigning-non-slot + asecell[2].mass = 12. # pylint: disable=assigning-non-slot s = StructureData(ase=asecell) @@ -1392,9 +1400,9 @@ def test_kind_5_bis_ase_unknown(self): [4, 0, 0], ]) - asecell[0].mass = 12. - asecell[1].mass = 12. - asecell[2].mass = 12. + asecell[0].mass = 12. # pylint: disable=assigning-non-slot + asecell[1].mass = 12. # pylint: disable=assigning-non-slot + asecell[2].mass = 12. # pylint: disable=assigning-non-slot s = StructureData(ase=asecell) @@ -1497,6 +1505,7 @@ def test_kind_8(self): """ Test the ase_refine_cell() function """ + # pylint: disable=too-many-statements from aiida.orm.nodes.data.structure import ase_refine_cell import ase import math @@ -1912,7 +1921,7 @@ def test_ase(self): (0., 0., 0.), (0.5, 0.7, 0.9), )) - a[1].mass = 110.2 + a[1].mass = 110.2 # pylint: disable=assigning-non-slot b = StructureData(ase=a) c = b.get_ase() @@ -1973,8 +1982,8 @@ def test_conversion_of_types_2(self): )) a.set_tags((0, 1, 0, 1)) - a[2].mass = 100. - a[3].mass = 300. + a[2].mass = 100. # pylint: disable=assigning-non-slot + a[3].mass = 300. # pylint: disable=assigning-non-slot b = StructureData(ase=a) # This will give funny names to the kinds, because I am using @@ -2026,9 +2035,9 @@ def test_conversion_of_types_4(self): import ase atoms = ase.Atoms('Fe5') - atoms[2].tag = 1 - atoms[3].tag = 1 - atoms[4].tag = 4 + atoms[2].tag = 1 # pylint: disable=assigning-non-slot + atoms[3].tag = 1 # pylint: disable=assigning-non-slot + atoms[4].tag = 4 # pylint: disable=assigning-non-slot atoms.set_cell([1, 1, 1]) s = StructureData(ase=atoms) kindnames = {k.name for k in s.kinds} @@ -2048,9 +2057,9 @@ def test_conversion_of_types_5(self): import ase atoms = ase.Atoms('Fe5') - atoms[0].tag = 1 - atoms[2].tag = 1 - atoms[3].tag = 4 + atoms[0].tag = 1 # pylint: disable=assigning-non-slot + atoms[2].tag = 1 # pylint: disable=assigning-non-slot + atoms[3].tag = 4 # pylint: disable=assigning-non-slot atoms.set_cell([1, 1, 1]) s = StructureData(ase=atoms) kindnames = {k.name for k in s.kinds} @@ -2517,6 +2526,7 @@ def test_creation(self): Check the methods to add, remove, modify, and get arrays and array shapes. 
""" + # pylint: disable=too-many-statements import numpy # Create a node with two arrays @@ -2635,6 +2645,7 @@ class TestTrajectoryData(AiidaTestCase): def test_creation(self): """Check the methods to set and retrieve a trajectory.""" + # pylint: disable=too-many-statements import numpy # Create a node with two arrays @@ -3238,6 +3249,7 @@ def test_tetra_z_wrapper_legacy(self): class TestSpglibTupleConversion(AiidaTestCase): + """Tests for conversion of Spglib tuples.""" def test_simple_to_aiida(self): """ @@ -3385,9 +3397,11 @@ def test_aiida_roundtrip(self): class TestSeekpathExplicitPath(AiidaTestCase): + """Tests for the `get_explicit_kpoints_path` from SeeK-path.""" @unittest.skipIf(not has_seekpath(), 'No seekpath available') def test_simple(self): + """Test a simple case.""" import numpy as np from aiida.plugins import DataFactory @@ -3575,7 +3589,7 @@ def test_band(self): b.set_bands(input_bands, units='ev') b.set_bands(input_bands, occupations=input_occupations) with self.assertRaises(TypeError): - b.set_bands(occupations=input_occupations, units='ev') + b.set_bands(occupations=input_occupations, units='ev') # pylint: disable=no-value-for-parameter b.set_bands(input_bands, occupations=input_occupations, units='ev') bands, occupations = b.get_bands(also_occupations=True) diff --git a/tests/test_dbimporters.py b/tests/test_dbimporters.py index 2154dca36b..26a47d6a7f 100644 --- a/tests/test_dbimporters.py +++ b/tests/test_dbimporters.py @@ -72,14 +72,12 @@ def test_datatype_checks(self): from aiida.tools.dbimporters.plugins.cod import CodDbImporter codi = CodDbImporter() - messages = ['', - "incorrect value for keyword 'test' -- " + \ - 'only integers and strings are accepted', - "incorrect value for keyword 'test' -- " + \ - 'only strings are accepted', - "incorrect value for keyword 'test' -- " + \ - 'only integers and floats are accepted', - "invalid literal for int() with base 10: 'text'"] + messages = [ + '', "incorrect value for keyword 'test' only integers and strings are accepted", + "incorrect value for keyword 'test' only strings are accepted", + "incorrect value for keyword 'test' only integers and floats are accepted", + "invalid literal for int() with base 10: 'text'" + ] values = [10, 'text', 'text', '10', 1.0 / 3, [1, 2, 3]] methods = [ # pylint: disable=protected-access diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index 59a39dfcc1..9a90aac453 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-lines,fixme """ This module contains a set of unittest test classes that can be loaded from the plugin. 
@@ -41,7 +42,9 @@ def get_all_custom_transports(): thisdir, thisfname = os.path.split(this_full_fname) test_modules = [ - os.path.split(fname)[1][:-3] for fname in os.listdir(thisdir) if fname.endswith('.py') and fname.startswith('test_') + os.path.split(fname)[1][:-3] + for fname in os.listdir(thisdir) + if fname.endswith('.py') and fname.startswith('test_') ] # Remove this module: note that I should be careful because __file__, from @@ -53,13 +56,13 @@ def get_all_custom_transports(): print('Warning, this module ({}) was not found!'.format(thisbasename)) all_custom_transports = {} - for m in test_modules: - module = importlib.import_module('.'.join([modulename, m])) + for module_name in test_modules: + module = importlib.import_module('.'.join([modulename, module_name])) custom_transport = module.__dict__.get('plugin_transport', None) if custom_transport is None: - print('Define the plugin_transport variable inside the {} module!'.format(m)) + print('Define the plugin_transport variable inside the {} module!'.format(module_name)) else: - all_custom_transports[m] = custom_transport + all_custom_transports[module_name] = custom_transport return all_custom_transports @@ -84,10 +87,9 @@ def test_all_plugins(self): for tr_name, custom_transport in all_custom_transports.items(): try: actual_test_method(self, custom_transport) - except Exception as e: + except Exception as exception: # pylint: disable=broad-except import traceback - - exceptions.append((e, traceback.format_exc(), tr_name)) + exceptions.append((exception, traceback.format_exc(), tr_name)) if exceptions: if all(isinstance(exc[0], AssertionError) for exc in exceptions): @@ -97,8 +99,7 @@ def test_all_plugins(self): messages = ['*** At least one test for a subplugin failed. See below ***', ''] for exc in exceptions: - messages.append("*** [For plugin {}]: Exception '{}': {}" .format(exc[2], type(exc[0]).__name__, exc[0])) + messages.append("*** [For plugin {}]: Exception '{}': {}".format(exc[2], type(exc[0]).__name__, exc[0])) messages.append(exc[1]) raise exception_to_raise('\n'.join(messages)) @@ -137,38 +138,38 @@ def test_makedirs(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - self.assertEqual(location, t.getcwd()) - while t.isdir(directory): + self.assertEqual(location, transport.getcwd()) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) # define folder structure dir_tree = os.path.join('1', '2') # I create the tree - t.makedirs(dir_tree) + transport.makedirs(dir_tree) # verify the existence - self.assertTrue(t.isdir('1')) + self.assertTrue(transport.isdir('1')) self.assertTrue(dir_tree) # try to recreate the same folder with self.assertRaises(OSError): - t.makedirs(dir_tree) + transport.makedirs(dir_tree) # recreate but with ignore flag - t.makedirs(dir_tree, True) + transport.makedirs(dir_tree, True) - t.rmdir(dir_tree) - t.rmdir('1') + transport.rmdir(dir_tree) + transport.rmdir('1') - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_rmtree(self, custom_transport): @@ -180,40 +181,40 @@ def test_rmtree(self,
custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - self.assertEqual(location, t.getcwd()) - while t.isdir(directory): + self.assertEqual(location, transport.getcwd()) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) # define folder structure dir_tree = os.path.join('1', '2') # I create the tree - t.makedirs(dir_tree) + transport.makedirs(dir_tree) # remove it - t.rmtree('1') + transport.rmtree('1') # verify the removal - self.assertFalse(t.isdir('1')) + self.assertFalse(transport.isdir('1')) # also tests that it works with a single file # create file local_file_name = 'file.txt' text = 'Viva Verdi\n' - with open(os.path.join(t.getcwd(), local_file_name), 'w', encoding='utf8') as fhandle: + with open(os.path.join(transport.getcwd(), local_file_name), 'w', encoding='utf8') as fhandle: fhandle.write(text) # remove it - t.rmtree(local_file_name) + transport.rmtree(local_file_name) # verify the removal - self.assertFalse(t.isfile(local_file_name)) + self.assertFalse(transport.isfile(local_file_name)) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_listdir(self, custom_transport): @@ -317,9 +318,15 @@ def simplify_attributes(data): 'a2': True, 'a4f': True, 'a': False - }) + } + ) self.assertTrue(simplify_attributes(trans.listdir_withattributes('.', 'a?')), {'as': True, 'a2': True}) - self.assertTrue(simplify_attributes(trans.listdir_withattributes('.', 'a[2-4]*')), {'a2': True, 'a4f': True}) + self.assertTrue( + simplify_attributes(trans.listdir_withattributes('.', 'a[2-4]*')), { + 'a2': True, + 'a4f': True + } + ) for this_dir in list_of_dir: trans.rmdir(this_dir) @@ -332,29 +339,30 @@ def simplify_attributes(data): @run_for_all_plugins def test_dir_creation_deletion(self, custom_transport): + """Test creating and deleting directories.""" # Imports required later import random import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - self.assertEqual(location, t.getcwd()) - while t.isdir(directory): + self.assertEqual(location, transport.getcwd()) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) + transport.mkdir(directory) with self.assertRaises(OSError): # I create twice the same directory - t.mkdir(directory) + transport.mkdir(directory) - t.isdir(directory) - self.assertFalse(t.isfile(directory)) - t.rmdir(directory) + transport.isdir(directory) + self.assertFalse(transport.isfile(directory)) + transport.rmdir(directory) @run_for_all_plugins def test_dir_copy(self, custom_transport): @@ -367,30 +375,30 @@ def test_dir_copy(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = 
transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - while t.isdir(directory): + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) + transport.mkdir(directory) dest_directory = directory + '_copy' - t.copy(directory, dest_directory) + transport.copy(directory, dest_directory) with self.assertRaises(ValueError): - t.copy(directory, '') + transport.copy(directory, '') with self.assertRaises(ValueError): - t.copy('', directory) + transport.copy('', directory) - t.rmdir(directory) - t.rmdir(dest_directory) + transport.rmdir(directory) + transport.rmdir(dest_directory) @run_for_all_plugins - def test_dir_permissions_creation_modification(self, custom_transport): + def test_dir_permissions_creation_modification(self, custom_transport): # pylint: disable=invalid-name """ verify if chmod raises IOError when trying to change bits on a non-existing folder @@ -400,51 +408,51 @@ def test_dir_permissions_creation_modification(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - while t.isdir(directory): + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) # create directory with non default permissions - t.mkdir(directory) + transport.mkdir(directory) # change permissions - t.chmod(directory, 0o777) + transport.chmod(directory, 0o777) # test if the security bits have changed - self.assertEqual(t.get_mode(directory), 0o777) + self.assertEqual(transport.get_mode(directory), 0o777) # change permissions - t.chmod(directory, 0o511) + transport.chmod(directory, 0o511) # test if the security bits have changed - self.assertEqual(t.get_mode(directory), 0o511) + self.assertEqual(transport.get_mode(directory), 0o511) # TODO : bug in paramiko. When changing the directory to very low \ # I cannot set it back to higher permissions - ## TODO: probably here we should then check for - ## the new directory modes. To see if we want a higher - ## level function to ask for the mode, or we just - ## use get_attribute - t.chdir(directory) + # TODO: probably here we should then check for + # the new directory modes. To see if we want a higher + # level function to ask for the mode, or we just + # use get_attribute + transport.chdir(directory) # change permissions of an empty string, non existing folder. 
fake_dir = '' with self.assertRaises(IOError): - t.chmod(fake_dir, 0o777) + transport.chmod(fake_dir, 0o777) fake_dir = 'pippo' with self.assertRaises(IOError): # chmod to a non existing folder - t.chmod(fake_dir, 0o777) + transport.chmod(fake_dir, 0o777) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_dir_reading_permissions(self, custom_transport): @@ -457,37 +465,37 @@ def test_dir_reading_permissions(self, custom_transport): import string import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - t.chdir(location) + transport.chdir(location) - while t.isdir(directory): + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) # create directory with non default permissions - t.mkdir(directory) + transport.mkdir(directory) # change permissions to low ones - t.chmod(directory, 0) + transport.chmod(directory, 0) # test if the security bits have changed - self.assertEqual(t.get_mode(directory), 0) + self.assertEqual(transport.get_mode(directory), 0) - old_cwd = t.getcwd() + old_cwd = transport.getcwd() with self.assertRaises(IOError): - t.chdir(directory) + transport.chdir(directory) - new_cwd = t.getcwd() + new_cwd = transport.getcwd() self.assertEqual(old_cwd, new_cwd) # TODO : the test leaves a directory even if it is successful # The bug is in paramiko. After lowering the permissions, # I cannot restore them to higher values - #t.rmdir(directory) + # transport.rmdir(directory) @run_for_all_plugins def test_isfile_isdir_to_empty_string(self, custom_transport): @@ -497,11 +505,11 @@ def test_isfile_isdir_to_empty_string(self, custom_transport): """ import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) - t.chdir(location) - self.assertFalse(t.isdir('')) - self.assertFalse(t.isfile('')) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) + transport.chdir(location) + self.assertFalse(transport.isdir('')) + self.assertFalse(transport.isfile('')) @run_for_all_plugins def test_isfile_isdir_to_non_existing_string(self, custom_transport): @@ -511,14 +519,14 @@ def test_isfile_isdir_to_non_existing_string(self, custom_transport): """ import os - with custom_transport as t: - location = t.normalize(os.path.join('/', 'tmp')) - t.chdir(location) + with custom_transport as transport: + location = transport.normalize(os.path.join('/', 'tmp')) + transport.chdir(location) fake_folder = 'pippo' - self.assertFalse(t.isfile(fake_folder)) - self.assertFalse(t.isdir(fake_folder)) + self.assertFalse(transport.isfile(fake_folder)) + self.assertFalse(transport.isdir(fake_folder)) with self.assertRaises(IOError): - t.chdir(fake_folder) + transport.chdir(fake_folder) @run_for_all_plugins def test_chdir_to_empty_string(self, custom_transport): @@ -529,11 +537,11 @@ def test_chdir_to_empty_string(self, custom_transport): """ import os - with custom_transport as t: - new_dir = t.normalize(os.path.join('/', 'tmp')) - t.chdir(new_dir) - t.chdir('') - self.assertEqual(new_dir, t.getcwd()) + with custom_transport as transport: + new_dir = transport.normalize(os.path.join('/', 'tmp')) + transport.chdir(new_dir) + transport.chdir('') + self.assertEqual(new_dir, transport.getcwd()) class 
TestPutGetFile(unittest.TestCase): @@ -546,6 +554,7 @@ class TestPutGetFile(unittest.TestCase): @run_for_all_plugins def test_put_and_get(self, custom_transport): + """Test putting and getting files.""" import os import random import string @@ -554,14 +563,14 @@ def test_put_and_get(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) - while t.isdir(directory): + with custom_transport as transport: + transport.chdir(remote_dir) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_dir, directory, 'file.txt') remote_file_name = 'file_remote.txt' @@ -572,12 +581,12 @@ def test_put_and_get(self, custom_transport): fhandle.write(text) # here use full path in src and dst - t.put(local_file_name, remote_file_name) - t.get(remote_file_name, retrieved_file_name) - t.putfile(local_file_name, remote_file_name) - t.getfile(remote_file_name, retrieved_file_name) + transport.put(local_file_name, remote_file_name) + transport.get(remote_file_name, retrieved_file_name) + transport.putfile(local_file_name, remote_file_name) + transport.getfile(remote_file_name, retrieved_file_name) - list_of_files = t.listdir('.') + list_of_files = transport.listdir('.') # it is False because local_file_name has the full path, # while list_of_files has not self.assertFalse(local_file_name in list_of_files) @@ -585,11 +594,11 @@ def test_put_and_get(self, custom_transport): self.assertFalse(retrieved_file_name in list_of_files) os.remove(local_file_name) - t.remove(remote_file_name) + transport.remove(remote_file_name) os.remove(retrieved_file_name) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_put_get_abs_path(self, custom_transport): @@ -604,14 +613,14 @@ def test_put_get_abs_path(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) - while t.isdir(directory): + with custom_transport as transport: + transport.chdir(remote_dir) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) partial_file_name = 'file.txt' local_file_name = os.path.join(local_dir, directory, 'file.txt') @@ -623,36 +632,36 @@ def test_put_get_abs_path(self, custom_transport): # partial_file_name is not an abs path with self.assertRaises(ValueError): - t.put(partial_file_name, remote_file_name) + transport.put(partial_file_name, remote_file_name) with self.assertRaises(ValueError): - t.putfile(partial_file_name, remote_file_name) + transport.putfile(partial_file_name, remote_file_name) # retrieved_file_name does not exist with self.assertRaises(OSError): - t.put(retrieved_file_name, remote_file_name) + transport.put(retrieved_file_name, remote_file_name) with self.assertRaises(OSError): - t.putfile(retrieved_file_name, remote_file_name) + transport.putfile(retrieved_file_name, remote_file_name) # remote_file_name does not exist with self.assertRaises(IOError): - t.get(remote_file_name, retrieved_file_name) + transport.get(remote_file_name, retrieved_file_name) with self.assertRaises(IOError): - 
t.getfile(remote_file_name, retrieved_file_name) + transport.getfile(remote_file_name, retrieved_file_name) - t.put(local_file_name, remote_file_name) - t.putfile(local_file_name, remote_file_name) + transport.put(local_file_name, remote_file_name) + transport.putfile(local_file_name, remote_file_name) # local filename is not an abs path with self.assertRaises(ValueError): - t.get(remote_file_name, 'delete_me.txt') + transport.get(remote_file_name, 'delete_me.txt') with self.assertRaises(ValueError): - t.getfile(remote_file_name, 'delete_me.txt') + transport.getfile(remote_file_name, 'delete_me.txt') - t.remove(remote_file_name) + transport.remove(remote_file_name) os.remove(local_file_name) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_put_get_empty_string(self, custom_transport): @@ -668,14 +677,14 @@ def test_put_get_empty_string(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) - while t.isdir(directory): + with custom_transport as transport: + transport.chdir(remote_dir) + while transport.isdir(directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_dir, directory, 'file_local.txt') remote_file_name = 'file_remote.txt' @@ -688,48 +697,48 @@ def test_put_get_empty_string(self, custom_transport): # localpath is an empty string # ValueError because it is not an abs path with self.assertRaises(ValueError): - t.put('', remote_file_name) + transport.put('', remote_file_name) with self.assertRaises(ValueError): - t.putfile('', remote_file_name) + transport.putfile('', remote_file_name) # remote path is an empty string with self.assertRaises(IOError): - t.put(local_file_name, '') + transport.put(local_file_name, '') with self.assertRaises(IOError): - t.putfile(local_file_name, '') + transport.putfile(local_file_name, '') - t.put(local_file_name, remote_file_name) + transport.put(local_file_name, remote_file_name) # overwrite the remote_file_name - t.putfile(local_file_name, remote_file_name) + transport.putfile(local_file_name, remote_file_name) # remote path is an empty string with self.assertRaises(IOError): - t.get('', retrieved_file_name) + transport.get('', retrieved_file_name) with self.assertRaises(IOError): - t.getfile('', retrieved_file_name) + transport.getfile('', retrieved_file_name) # local path is an empty string # ValueError because it is not an abs path with self.assertRaises(ValueError): - t.get(remote_file_name, '') + transport.get(remote_file_name, '') with self.assertRaises(ValueError): - t.getfile(remote_file_name, '') + transport.getfile(remote_file_name, '') # TODO : get doesn't retrieve empty files. # Is it what we want? 
- t.get(remote_file_name, retrieved_file_name) + transport.get(remote_file_name, retrieved_file_name) # overwrite retrieved_file_name - t.getfile(remote_file_name, retrieved_file_name) + transport.getfile(remote_file_name, retrieved_file_name) os.remove(local_file_name) - t.remove(remote_file_name) + transport.remove(remote_file_name) # If it couldn't end the copy, it leaves what he did on # local file - self.assertTrue('file_retrieved.txt' in t.listdir('.')) + self.assertTrue('file_retrieved.txt' in transport.listdir('.')) os.remove(retrieved_file_name) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) class TestPutGetTree(unittest.TestCase): @@ -742,6 +751,7 @@ class TestPutGetTree(unittest.TestCase): @run_for_all_plugins def test_put_and_get(self, custom_transport): + """Test putting and getting files.""" import os import random import string @@ -750,9 +760,9 @@ def test_put_and_get(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: + with custom_transport as transport: - t.chdir(remote_dir) + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique @@ -765,7 +775,7 @@ def test_put_and_get(self, custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - t.chdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_subfolder, 'file.txt') @@ -776,14 +786,14 @@ def test_put_and_get(self, custom_transport): # here use full path in src and dst for i in range(2): if i == 0: - t.put(local_subfolder, remote_subfolder) - t.get(remote_subfolder, retrieved_subfolder) + transport.put(local_subfolder, remote_subfolder) + transport.get(remote_subfolder, retrieved_subfolder) else: - t.puttree(local_subfolder, remote_subfolder) - t.gettree(remote_subfolder, retrieved_subfolder) + transport.puttree(local_subfolder, remote_subfolder) + transport.gettree(remote_subfolder, retrieved_subfolder) # Here I am mixing the local with the remote fold - list_of_dirs = t.listdir('.') + list_of_dirs = transport.listdir('.') # # it is False because local_file_name has the full path, # # while list_of_files has not self.assertFalse(local_subfolder in list_of_dirs) @@ -792,8 +802,8 @@ def test_put_and_get(self, custom_transport): self.assertTrue('tmp1' in list_of_dirs) self.assertTrue('tmp3' in list_of_dirs) - list_pushed_file = t.listdir('tmp2') - list_retrieved_file = t.listdir('tmp3') + list_pushed_file = transport.listdir('tmp2') + list_retrieved_file = transport.listdir('tmp3') self.assertTrue('file.txt' in list_pushed_file) self.assertTrue('file.txt' in list_retrieved_file) @@ -801,23 +811,25 @@ def test_put_and_get(self, custom_transport): shutil.rmtree(local_subfolder) shutil.rmtree(retrieved_subfolder) - t.rmtree(remote_subfolder) + transport.rmtree(remote_subfolder) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_put_and_get_overwrite(self, custom_transport): - import os, shutil + """Test putting and getting files with overwrites.""" + import os import random + import shutil import string local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number 
until it is unique @@ -830,7 +842,7 @@ def test_put_and_get_overwrite(self, custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - t.chdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_subfolder, 'file.txt') @@ -838,32 +850,33 @@ def test_put_and_get_overwrite(self, custom_transport): with open(local_file_name, 'w', encoding='utf8') as fhandle: fhandle.write(text) - t.put(local_subfolder, remote_subfolder) - t.get(remote_subfolder, retrieved_subfolder) + transport.put(local_subfolder, remote_subfolder) + transport.get(remote_subfolder, retrieved_subfolder) # by defaults rewrite everything - t.put(local_subfolder, remote_subfolder) - t.get(remote_subfolder, retrieved_subfolder) + transport.put(local_subfolder, remote_subfolder) + transport.get(remote_subfolder, retrieved_subfolder) with self.assertRaises(OSError): - t.put(local_subfolder, remote_subfolder, overwrite=False) + transport.put(local_subfolder, remote_subfolder, overwrite=False) with self.assertRaises(OSError): - t.get(remote_subfolder, retrieved_subfolder, overwrite=False) + transport.get(remote_subfolder, retrieved_subfolder, overwrite=False) with self.assertRaises(OSError): - t.puttree(local_subfolder, remote_subfolder, overwrite=False) + transport.puttree(local_subfolder, remote_subfolder, overwrite=False) with self.assertRaises(OSError): - t.gettree(remote_subfolder, retrieved_subfolder, overwrite=False) + transport.gettree(remote_subfolder, retrieved_subfolder, overwrite=False) shutil.rmtree(local_subfolder) shutil.rmtree(retrieved_subfolder) - t.rmtree(remote_subfolder) - # t.rmtree(remote_subfolder) + transport.rmtree(remote_subfolder) + # transport.rmtree(remote_subfolder) # here I am mixing inevitably the local and the remote folder - t.chdir('..') - t.rmtree(directory) + transport.chdir('..') + transport.rmtree(directory) @run_for_all_plugins def test_copy(self, custom_transport): + """Test copying.""" import os import random import string @@ -872,15 +885,15 @@ def test_copy(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_base_dir = os.path.join(local_dir, directory, 'local') os.mkdir(local_base_dir) @@ -895,46 +908,48 @@ def test_copy(self, custom_transport): fhandle.write(text) # first test the copy. Copy of two files matching patterns, into a folder - t.copy(os.path.join('local', '*.txt'), '.') - self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(t.listdir('.'))) - t.remove('a.txt') - t.remove('c.txt') + transport.copy(os.path.join('local', '*.txt'), '.') + self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(transport.listdir('.'))) + transport.remove('a.txt') + transport.remove('c.txt') # second test copy. 
Copy of two folders - t.copy('local', 'prova') - self.assertEqual(set(['prova', 'local']), set(t.listdir('.'))) - self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(t.listdir('prova'))) - t.rmtree('prova') + transport.copy('local', 'prova') + self.assertEqual(set(['prova', 'local']), set(transport.listdir('.'))) + self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(transport.listdir('prova'))) + transport.rmtree('prova') # third test copy. Can copy one file into a new file - t.copy(os.path.join('local', '*.tmp'), 'prova') - self.assertEqual(set(['prova', 'local']), set(t.listdir('.'))) - t.remove('prova') + transport.copy(os.path.join('local', '*.tmp'), 'prova') + self.assertEqual(set(['prova', 'local']), set(transport.listdir('.'))) + transport.remove('prova') # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with self.assertRaises(OSError): - t.copy(os.path.join('local', '*.txt'), 'prova') + transport.copy(os.path.join('local', '*.txt'), 'prova') # fifth test, copying one file into a folder - t.mkdir('prova') - t.copy(os.path.join('local', 'a.txt'), 'prova') - self.assertEqual(set(t.listdir('prova')), set(['a.txt'])) - t.rmtree('prova') + transport.mkdir('prova') + transport.copy(os.path.join('local', 'a.txt'), 'prova') + self.assertEqual(set(transport.listdir('prova')), set(['a.txt'])) + transport.rmtree('prova') # sixth test, copying one file into a file - t.copy(os.path.join('local', 'a.txt'), 'prova') - self.assertTrue(t.isfile('prova')) - t.remove('prova') + transport.copy(os.path.join('local', 'a.txt'), 'prova') + self.assertTrue(transport.isfile('prova')) + transport.remove('prova') # copy of folder into an existing folder - #NOTE: the command cp has a different behavior on Mac vs Ubuntu - #tests performed locally on a Mac may result in a failure. - t.mkdir('prova') - t.copy('local', 'prova') - self.assertEqual(set(['local']), set(t.listdir('prova'))) - self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(t.listdir(os.path.join('prova', 'local')))) - t.rmtree('prova') + # NOTE: the command cp has a different behavior on Mac vs Ubuntu + # tests performed locally on a Mac may result in a failure. 
+ transport.mkdir('prova') + transport.copy('local', 'prova') + self.assertEqual(set(['local']), set(transport.listdir('prova'))) + self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(transport.listdir(os.path.join('prova', 'local')))) + transport.rmtree('prova') # exit - t.chdir('..') - t.rmtree(directory) + transport.chdir('..') + transport.rmtree(directory) @run_for_all_plugins def test_put(self, custom_transport): + """Test putting files.""" + # pylint: disable=too-many-statements # exactly the same tests of copy, just with the put function # and therefore the local path must be absolute import os @@ -945,15 +960,15 @@ def test_put(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_base_dir = os.path.join(local_dir, directory, 'local') os.mkdir(local_base_dir) @@ -967,72 +982,75 @@ def test_put(self, custom_transport): with open(filename, 'w', encoding='utf8') as fhandle: fhandle.write(text) - # first test put. Copy of two files matching patterns, into a folder - t.put(os.path.join(local_base_dir, '*.txt'), '.') - self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(t.listdir('.'))) - t.remove('a.txt') - t.remove('c.txt') + # first test put. Copy of two files matching patterns, into a folder + transport.put(os.path.join(local_base_dir, '*.txt'), '.') + self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(transport.listdir('.'))) + transport.remove('a.txt') + transport.remove('c.txt') # second. Copy of folder into a non existing folder - t.put(local_base_dir, 'prova') - self.assertEqual(set(['prova', 'local']), set(t.listdir('.'))) - self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(t.listdir('prova'))) - t.rmtree('prova') + transport.put(local_base_dir, 'prova') + self.assertEqual(set(['prova', 'local']), set(transport.listdir('.'))) + self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(transport.listdir('prova'))) + transport.rmtree('prova') # third. copy of folder into an existing folder - t.mkdir('prova') - t.put(local_base_dir, 'prova') - self.assertEqual(set(['prova', 'local']), set(t.listdir('.'))) - self.assertEqual(set(['local']), set(t.listdir('prova'))) - self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(t.listdir(os.path.join('prova', 'local')))) - t.rmtree('prova') + transport.mkdir('prova') + transport.put(local_base_dir, 'prova') + self.assertEqual(set(['prova', 'local']), set(transport.listdir('.'))) + self.assertEqual(set(['local']), set(transport.listdir('prova'))) + self.assertEqual(set(['a.txt', 'b.tmp', 'c.txt']), set(transport.listdir(os.path.join('prova', 'local')))) + transport.rmtree('prova') # third test copy.
Can copy one file into a new file - t.put(os.path.join(local_base_dir, '*.tmp'), 'prova') - self.assertEqual(set(['prova', 'local']), set(t.listdir('.'))) - t.remove('prova') + transport.put(os.path.join(local_base_dir, '*.tmp'), 'prova') + self.assertEqual(set(['prova', 'local']), set(transport.listdir('.'))) + transport.remove('prova') # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with self.assertRaises(OSError): - t.put(os.path.join(local_base_dir, '*.txt'), 'prova') + transport.put(os.path.join(local_base_dir, '*.txt'), 'prova') # copy of folder into file with open(os.path.join(local_dir, directory, 'existing.txt'), 'w', encoding='utf8') as fhandle: fhandle.write(text) with self.assertRaises(OSError): - t.put(os.path.join(local_base_dir), 'existing.txt') - t.remove('existing.txt') + transport.put(os.path.join(local_base_dir), 'existing.txt') + transport.remove('existing.txt') # fifth test, copying one file into a folder - t.mkdir('prova') - t.put(os.path.join(local_base_dir, 'a.txt'), 'prova') - self.assertEqual(set(t.listdir('prova')), set(['a.txt'])) - t.rmtree('prova') + transport.mkdir('prova') + transport.put(os.path.join(local_base_dir, 'a.txt'), 'prova') + self.assertEqual(set(transport.listdir('prova')), set(['a.txt'])) + transport.rmtree('prova') # sixth test, copying one file into a file - t.put(os.path.join(local_base_dir, 'a.txt'), 'prova') - self.assertTrue(t.isfile('prova')) - t.remove('prova') + transport.put(os.path.join(local_base_dir, 'a.txt'), 'prova') + self.assertTrue(transport.isfile('prova')) + transport.remove('prova') # exit - t.chdir('..') - t.rmtree(directory) + transport.chdir('..') + transport.rmtree(directory) @run_for_all_plugins def test_get(self, custom_transport): + """Test getting files.""" + # pylint: disable=too-many-statements # exactly the same tests of copy, just with the put function # and therefore the local path must be absolute import os import random - import string, shutil + import shutil + import string local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - t.mkdir(directory) - t.chdir(directory) + transport.mkdir(directory) + transport.chdir(directory) local_base_dir = os.path.join(local_dir, directory, 'local') local_destination = os.path.join(local_dir, directory) @@ -1048,51 +1066,53 @@ def test_get(self, custom_transport): fhandle.write(text) # first test put. Copy of two files matching patterns, into a folder - t.get(os.path.join('local', '*.txt'), local_destination) + transport.get(os.path.join('local', '*.txt'), local_destination) self.assertEqual(set(['a.txt', 'c.txt', 'local']), set(os.listdir(local_destination))) os.remove(os.path.join(local_destination, 'a.txt')) os.remove(os.path.join(local_destination, 'c.txt')) # second. 
Copy of folder into a non existing folder - t.get('local', os.path.join(local_destination, 'prova')) + transport.get('local', os.path.join(local_destination, 'prova')) self.assertEqual(set(['prova', 'local']), set(os.listdir(local_destination))) self.assertEqual( - set(['a.txt', 'b.tmp', 'c.txt']), set(os.listdir(os.path.join(local_destination, 'prova')))) + set(['a.txt', 'b.tmp', 'c.txt']), set(os.listdir(os.path.join(local_destination, 'prova'))) + ) shutil.rmtree(os.path.join(local_destination, 'prova')) # third. copy of folder into an existing folder os.mkdir(os.path.join(local_destination, 'prova')) - t.get('local', os.path.join(local_destination, 'prova')) + transport.get('local', os.path.join(local_destination, 'prova')) self.assertEqual(set(['prova', 'local']), set(os.listdir(local_destination))) self.assertEqual(set(['local']), set(os.listdir(os.path.join(local_destination, 'prova')))) self.assertEqual( - set(['a.txt', 'b.tmp', 'c.txt']), set(os.listdir(os.path.join(local_destination, 'prova', 'local')))) + set(['a.txt', 'b.tmp', 'c.txt']), set(os.listdir(os.path.join(local_destination, 'prova', 'local'))) + ) shutil.rmtree(os.path.join(local_destination, 'prova')) # third test copy. Can copy one file into a new file - t.get(os.path.join('local', '*.tmp'), os.path.join(local_destination, 'prova')) + transport.get(os.path.join('local', '*.tmp'), os.path.join(local_destination, 'prova')) self.assertEqual(set(['prova', 'local']), set(os.listdir(local_destination))) os.remove(os.path.join(local_destination, 'prova')) # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with self.assertRaises(OSError): - t.get(os.path.join('local', '*.txt'), os.path.join(local_destination, 'prova')) + transport.get(os.path.join('local', '*.txt'), os.path.join(local_destination, 'prova')) # copy of folder into file with open(os.path.join(local_destination, 'existing.txt'), 'w', encoding='utf8') as fhandle: fhandle.write(text) with self.assertRaises(OSError): - t.get('local', os.path.join(local_destination, 'existing.txt')) + transport.get('local', os.path.join(local_destination, 'existing.txt')) os.remove(os.path.join(local_destination, 'existing.txt')) # fifth test, copying one file into a folder os.mkdir(os.path.join(local_destination, 'prova')) - t.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) + transport.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) self.assertEqual(set(os.listdir(os.path.join(local_destination, 'prova'))), set(['a.txt'])) shutil.rmtree(os.path.join(local_destination, 'prova')) # sixth test, copying one file into a file - t.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) + transport.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) self.assertTrue(os.path.isfile(os.path.join(local_destination, 'prova'))) os.remove(os.path.join(local_destination, 'prova')) # exit - t.chdir('..') - t.rmtree(directory) + transport.chdir('..') + transport.rmtree(directory) @run_for_all_plugins def test_put_get_abs_path(self, custom_transport): @@ -1107,8 +1127,8 @@ def test_put_get_abs_path(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique @@ -1121,7 +1141,7 @@ def 
test_put_get_abs_path(self, custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - t.chdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_subfolder, 'file.txt') fhandle = open(local_file_name, 'w', encoding='utf8') @@ -1129,44 +1149,44 @@ def test_put_get_abs_path(self, custom_transport): # 'tmp1' is not an abs path with self.assertRaises(ValueError): - t.put('tmp1', remote_subfolder) + transport.put('tmp1', remote_subfolder) with self.assertRaises(ValueError): - t.putfile('tmp1', remote_subfolder) + transport.putfile('tmp1', remote_subfolder) with self.assertRaises(ValueError): - t.puttree('tmp1', remote_subfolder) + transport.puttree('tmp1', remote_subfolder) # 'tmp3' does not exist with self.assertRaises(OSError): - t.put(retrieved_subfolder, remote_subfolder) + transport.put(retrieved_subfolder, remote_subfolder) with self.assertRaises(OSError): - t.putfile(retrieved_subfolder, remote_subfolder) + transport.putfile(retrieved_subfolder, remote_subfolder) with self.assertRaises(OSError): - t.puttree(retrieved_subfolder, remote_subfolder) + transport.puttree(retrieved_subfolder, remote_subfolder) # remote_file_name does not exist with self.assertRaises(IOError): - t.get('non_existing', retrieved_subfolder) + transport.get('non_existing', retrieved_subfolder) with self.assertRaises(IOError): - t.getfile('non_existing', retrieved_subfolder) + transport.getfile('non_existing', retrieved_subfolder) with self.assertRaises(IOError): - t.gettree('non_existing', retrieved_subfolder) + transport.gettree('non_existing', retrieved_subfolder) - t.put(local_subfolder, remote_subfolder) + transport.put(local_subfolder, remote_subfolder) # local filename is not an abs path with self.assertRaises(ValueError): - t.get(remote_subfolder, 'delete_me_tree') + transport.get(remote_subfolder, 'delete_me_tree') with self.assertRaises(ValueError): - t.getfile(remote_subfolder, 'delete_me_tree') + transport.getfile(remote_subfolder, 'delete_me_tree') with self.assertRaises(ValueError): - t.gettree(remote_subfolder, 'delete_me_tree') + transport.gettree(remote_subfolder, 'delete_me_tree') os.remove(os.path.join(local_subfolder, 'file.txt')) os.rmdir(local_subfolder) - t.rmtree(remote_subfolder) + transport.rmtree(remote_subfolder) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins def test_put_get_empty_string(self, custom_transport): @@ -1182,8 +1202,8 @@ def test_put_get_empty_string(self, custom_transport): remote_dir = local_dir directory = 'tmp_try' - with custom_transport as t: - t.chdir(remote_dir) + with custom_transport as transport: + transport.chdir(remote_dir) while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique @@ -1196,7 +1216,7 @@ def test_put_get_empty_string(self, custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - t.chdir(directory) + transport.chdir(directory) local_file_name = os.path.join(local_subfolder, 'file.txt') text = 'Viva Verdi\n' @@ -1206,42 +1226,43 @@ def test_put_get_empty_string(self, custom_transport): # localpath is an empty string # ValueError because it is not an abs path with self.assertRaises(ValueError): - t.puttree('', remote_subfolder) + transport.puttree('', remote_subfolder) # remote path is an empty string with self.assertRaises(IOError): - t.puttree(local_subfolder, '') + 
transport.puttree(local_subfolder, '') - t.puttree(local_subfolder, remote_subfolder) + transport.puttree(local_subfolder, remote_subfolder) # remote path is an empty string with self.assertRaises(IOError): - t.gettree('', retrieved_subfolder) + transport.gettree('', retrieved_subfolder) # local path is an empty string # ValueError because it is not an abs path with self.assertRaises(ValueError): - t.gettree(remote_subfolder, '') + transport.gettree(remote_subfolder, '') # TODO : get doesn't retrieve empty files. # Is it what we want? - t.gettree(remote_subfolder, retrieved_subfolder) + transport.gettree(remote_subfolder, retrieved_subfolder) os.remove(os.path.join(local_subfolder, 'file.txt')) os.rmdir(local_subfolder) - t.remove(os.path.join(remote_subfolder, 'file.txt')) - t.rmdir(remote_subfolder) + transport.remove(os.path.join(remote_subfolder, 'file.txt')) + transport.rmdir(remote_subfolder) # If it couldn't end the copy, it leaves what he did on local file # here I am mixing local with remote - self.assertTrue('file.txt' in t.listdir('tmp3')) + self.assertTrue('file.txt' in transport.listdir('tmp3')) os.remove(os.path.join(retrieved_subfolder, 'file.txt')) os.rmdir(retrieved_subfolder) - t.chdir('..') - t.rmdir(directory) + transport.chdir('..') + transport.rmdir(directory) @run_for_all_plugins - def test_gettree_nested_directory(self, custom_transport): + def test_gettree_nested_directory(self, custom_transport): # pylint: disable=no-self-use + """Test `gettree` for a nested directory.""" import os import tempfile @@ -1278,66 +1299,69 @@ def test_exec_pwd(self, custom_transport): # Start value delete_at_end = False - with custom_transport as t: + with custom_transport as transport: # To compare with: getcwd uses the normalized ('realpath') path - location = t.normalize('/tmp') + location = transport.normalize('/tmp') subfolder = """_'s f"#""" # A folder with characters to escape subfolder_fullpath = os.path.join(location, subfolder) - t.chdir(location) - if not t.isdir(subfolder): + transport.chdir(location) + if not transport.isdir(subfolder): # Since I created the folder, I will remember to # delete it at the end of this test delete_at_end = True - t.mkdir(subfolder) + transport.mkdir(subfolder) - self.assertTrue(t.isdir(subfolder)) - t.chdir(subfolder) + self.assertTrue(transport.isdir(subfolder)) + transport.chdir(subfolder) - self.assertEqual(subfolder_fullpath, t.getcwd()) - retcode, stdout, stderr = t.exec_command_wait('pwd') + self.assertEqual(subfolder_fullpath, transport.getcwd()) + retcode, stdout, stderr = transport.exec_command_wait('pwd') self.assertEqual(retcode, 0) # I have to strip it because 'pwd' returns a trailing \n self.assertEqual(stdout.strip(), subfolder_fullpath) self.assertEqual(stderr, '') if delete_at_end: - t.chdir(location) - t.rmdir(subfolder) + transport.chdir(location) + transport.rmdir(subfolder) @run_for_all_plugins def test_exec_with_stdin_string(self, custom_transport): + """Test command execution with a stdin string.""" test_string = str('some_test String') - with custom_transport as t: - retcode, stdout, stderr = t.exec_command_wait('cat', stdin=test_string) + with custom_transport as transport: + retcode, stdout, stderr = transport.exec_command_wait('cat', stdin=test_string) self.assertEqual(retcode, 0) self.assertEqual(stdout, test_string) self.assertEqual(stderr, '') @run_for_all_plugins def test_exec_with_stdin_unicode(self, custom_transport): + """Test command execution with a unicode stdin string.""" test_string = 'some_test String' - 
with custom_transport as t: - retcode, stdout, stderr = t.exec_command_wait('cat', stdin=test_string) + with custom_transport as transport: + retcode, stdout, stderr = transport.exec_command_wait('cat', stdin=test_string) self.assertEqual(retcode, 0) self.assertEqual(stdout, test_string) self.assertEqual(stderr, '') @run_for_all_plugins def test_exec_with_stdin_filelike(self, custom_transport): - + """Test command execution with a stdin from filelike.""" test_string = 'some_test String' stdin = io.StringIO(test_string) - with custom_transport as t: - retcode, stdout, stderr = t.exec_command_wait('cat', stdin=stdin) + with custom_transport as transport: + retcode, stdout, stderr = transport.exec_command_wait('cat', stdin=stdin) self.assertEqual(retcode, 0) self.assertEqual(stdout, test_string) self.assertEqual(stderr, '') @run_for_all_plugins def test_exec_with_wrong_stdin(self, custom_transport): + """Test command execution with incorrect stdin string.""" # I pass a number - with custom_transport as t: + with custom_transport as transport: with self.assertRaises(ValueError): - t.exec_command_wait('cat', stdin=1) + transport.exec_command_wait('cat', stdin=1) diff --git a/tests/transports/test_local.py b/tests/transports/test_local.py index 712398183c..35b3e247f1 100644 --- a/tests/transports/test_local.py +++ b/tests/transports/test_local.py @@ -7,9 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tests for the `LocalTransport`.""" import unittest -from aiida.transports.plugins.local import * +from aiida.transports.plugins.local import LocalTransport +from aiida.transports.transport import TransportInternalError # This will be used by test_all_plugins @@ -22,10 +24,11 @@ class TestGeneric(unittest.TestCase): """ def test_whoami(self): + """Test the `whoami` command.""" import getpass - with LocalTransport() as t: - self.assertEqual(t.whoami(), getpass.getuser()) + with LocalTransport() as transport: + self.assertEqual(transport.whoami(), getpass.getuser()) class TestBasicConnection(unittest.TestCase): @@ -34,16 +37,14 @@ class TestBasicConnection(unittest.TestCase): """ def test_closed_connection(self): - from aiida.transports.transport import TransportInternalError + """Test running a command on a closed connection.""" with self.assertRaises(TransportInternalError): - t = LocalTransport() - t.listdir() + transport = LocalTransport() + transport.listdir() - def test_basic(self): + @staticmethod + def test_basic(): + """Test constructor.""" with LocalTransport(): pass - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/transports/test_ssh.py b/tests/transports/test_ssh.py index d7a0a7d2a4..2b1e083cf3 100644 --- a/tests/transports/test_ssh.py +++ b/tests/transports/test_ssh.py @@ -7,16 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Test ssh plugin on localhost -""" -import unittest +"""Test the `SshTransport` plugin on localhost.""" import logging +import unittest -import aiida.transports -import aiida.transports.transport import paramiko + from aiida.transports.plugins.ssh import SshTransport +from aiida.transports.transport import TransportInternalError # This will be used by test_all_plugins @@ -29,20 +27,25 @@ class 
TestBasicConnection(unittest.TestCase): """ def test_closed_connection_ssh(self): - with self.assertRaises(aiida.transports.transport.TransportInternalError): - t = SshTransport(machine='localhost') - t._exec_command_internal('ls') + """Test calling command on a closed connection.""" + with self.assertRaises(TransportInternalError): + transport = SshTransport(machine='localhost') + transport._exec_command_internal('ls') # pylint: disable=protected-access def test_closed_connection_sftp(self): - with self.assertRaises(aiida.transports.transport.TransportInternalError): - t = SshTransport(machine='localhost') - t.listdir() - - def test_auto_add_policy(self): + """Test calling sftp command on a closed connection.""" + with self.assertRaises(TransportInternalError): + transport = SshTransport(machine='localhost') + transport.listdir() + + @staticmethod + def test_auto_add_policy(): + """Test the auto add policy.""" with SshTransport(machine='localhost', timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy'): pass def test_no_host_key(self): + """Test if there is no host key.""" # Disable logging to avoid output during test logging.disable(logging.ERROR) @@ -52,7 +55,3 @@ def test_no_host_key(self): # Reset logging level logging.disable(logging.NOTSET) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/utils/processes.py b/tests/utils/processes.py index 7115b4fdc5..2d8d7fd6f9 100644 --- a/tests/utils/processes.py +++ b/tests/utils/processes.py @@ -95,7 +95,7 @@ def define(cls, spec): 123, 'GENERIC_EXIT_CODE', message='This process should not be used as cache.', invalidates_cache=True ) - def run(self): # pylint: disable=inconsistent-return-statements + def run(self): if self.inputs.return_exit_code: return self.exit_codes.GENERIC_EXIT_CODE # pylint: disable=no-member
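
The `get` assertions earlier in this diff encode the copy semantics the plugins must share: copying a folder onto a non-existing destination creates that destination, copying onto an existing destination nests the source inside it, and a glob matching several files cannot land on a single file path. A minimal sketch of the first two cases, assuming `LocalTransport` is available and using made-up directory names (`local`, `prova`):

import os
import tempfile

from aiida.transports.plugins.local import LocalTransport

with tempfile.TemporaryDirectory() as workdir, LocalTransport() as transport:
    transport.chdir(workdir)
    os.mkdir(os.path.join(workdir, 'local'))
    for name in ('a.txt', 'b.tmp', 'c.txt'):
        with open(os.path.join(workdir, 'local', name), 'w', encoding='utf8') as handle:
            handle.write('Viva Verdi\n')

    destination = os.path.join(workdir, 'prova')

    # Destination does not exist yet: `local` is copied *as* `prova`.
    transport.get('local', destination)
    assert set(os.listdir(destination)) == {'a.txt', 'b.tmp', 'c.txt'}

    # Destination now exists: `local` is copied *into* it as `prova/local`.
    transport.get('local', destination)
    assert set(os.listdir(os.path.join(destination, 'local'))) == {'a.txt', 'b.tmp', 'c.txt'}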
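
`test_put_get_abs_path` distinguishes two failure modes on the local side of `put`/`get` (and their `putfile`/`puttree`, `getfile`/`gettree` variants): a relative local path is rejected with ValueError before anything else, while an absolute but missing source fails with OSError/IOError. A sketch of that ordering, again assuming `LocalTransport` and with illustrative path names:

import os
import tempfile

from aiida.transports.plugins.local import LocalTransport

with tempfile.TemporaryDirectory() as workdir, LocalTransport() as transport:
    transport.chdir(workdir)

    # A relative local path is rejected outright ...
    try:
        transport.put('not_an_abs_path', 'remote_target')
    except ValueError:
        pass  # expected: the local side must be an absolute path

    # ... while an absolute but non-existing source fails with OSError instead.
    try:
        transport.put(os.path.join(workdir, 'does_not_exist'), 'remote_target')
    except OSError:
        pass  # expected: the source must exist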
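
`test_exec_pwd` and the stdin tests rely on `exec_command_wait` returning a `(retcode, stdout, stderr)` triple, running in the transport's current working directory, and accepting a string (or file-like object) as stdin. A minimal sketch, assuming a Unix shell that provides `pwd` and `cat`:

from aiida.transports.plugins.local import LocalTransport

with LocalTransport() as transport:
    # As in the test, compare against the normalized ('realpath') location.
    transport.chdir(transport.normalize('/tmp'))

    # The command runs in the transport's current working directory.
    retcode, stdout, stderr = transport.exec_command_wait('pwd')
    assert retcode == 0
    assert stdout.strip() == transport.getcwd()  # 'pwd' appends a trailing newline
    assert stderr == ''

    # stdin may be a plain string (or a file-like object) fed to the command.
    retcode, stdout, stderr = transport.exec_command_wait('cat', stdin='some_test String')
    assert (retcode, stdout, stderr) == (0, 'some_test String', '')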
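
The `LocalTransport` tests hinge on the connection being opened by the context manager: calling a method on an unopened transport raises TransportInternalError, while inside the `with` block commands such as `whoami` work. A sketch of that contract:

import getpass

from aiida.transports.plugins.local import LocalTransport
from aiida.transports.transport import TransportInternalError

# Before the connection is opened, any command raises TransportInternalError.
try:
    LocalTransport().listdir()
except TransportInternalError:
    pass  # expected: the transport must be used as a context manager

# Inside the `with` block the connection is open and commands succeed.
with LocalTransport() as transport:
    assert transport.whoami() == getpass.getuser()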
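
The `SshTransport` tests make the same closed-connection check and additionally assume an SSH server on localhost with the current user's keys authorised; `key_policy='AutoAddPolicy'` tells paramiko to accept unknown host keys automatically. A sketch under those same assumptions (the `whoami` command is only illustrative):

from aiida.transports.plugins.ssh import SshTransport
from aiida.transports.transport import TransportInternalError

# An unopened SSH transport behaves like the local one: commands fail.
try:
    SshTransport(machine='localhost').listdir()
except TransportInternalError:
    pass  # expected

# AutoAddPolicy accepts unknown host keys, so the connection can be opened
# even if localhost is not yet in known_hosts.
with SshTransport(machine='localhost', timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy') as transport:
    retcode, stdout, _ = transport.exec_command_wait('whoami')
    assert retcode == 0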