diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8c3d6223..6641daf1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,8 +21,8 @@ jobs: matrix: os: ['ubuntu-latest', 'macos-latest'] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] - cratedb-version: ['5.1.1'] - sqla-version: ['1.3.24', '1.4.44'] + cratedb-version: ['5.1.2'] + sqla-version: ['1.3.24', '1.4.45'] # To save resources, only use the most recent Python version on macOS. exclude: - os: 'macos-latest' @@ -33,7 +33,7 @@ jobs: python-version: '3.9' - os: 'macos-latest' python-version: '3.10' - fail-fast: true + fail-fast: false env: CRATEDB_VERSION: ${{ matrix.cratedb-version }} SQLALCHEMY_VERSION: ${{ matrix.sqla-version }} @@ -60,11 +60,14 @@ jobs: # Report about the test matrix slot. echo "Invoking tests with CrateDB ${CRATEDB_VERSION} and SQLAlchemy ${SQLALCHEMY_VERSION}" - # Invoke validation tasks. + # Run linter. flake8 src bin - coverage run bin/test -vv1 + + # Run tests. + export SQLALCHEMY_WARN_20=1 + coverage run bin/test -vvv - # Set the stage for the Codecov step. + # Set the stage for uploading the coverage report. coverage xml # https://github.com/codecov/codecov-action diff --git a/bootstrap.sh b/bootstrap.sh index 60d05f4f..613373a1 100644 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -18,8 +18,8 @@ # Default variables. BUILDOUT_VERSION=${BUILDOUT_VERSION:-2.13.7} -CRATEDB_VERSION=${CRATEDB_VERSION:-5.1.1} -SQLALCHEMY_VERSION=${SQLALCHEMY_VERSION:-1.4.44} +CRATEDB_VERSION=${CRATEDB_VERSION:-5.1.2} +SQLALCHEMY_VERSION=${SQLALCHEMY_VERSION:-1.4.45} function print_header() { diff --git a/docs/by-example/sqlalchemy/advanced-querying.rst b/docs/by-example/sqlalchemy/advanced-querying.rst index 026ef635..863373e4 100644 --- a/docs/by-example/sqlalchemy/advanced-querying.rst +++ b/docs/by-example/sqlalchemy/advanced-querying.rst @@ -21,8 +21,11 @@ Introduction Import the relevant symbols: >>> import sqlalchemy as sa - >>> from sqlalchemy.ext.declarative import declarative_base >>> from sqlalchemy.orm import sessionmaker + >>> try: + ... from sqlalchemy.orm import declarative_base + ... except ImportError: + ... from sqlalchemy.ext.declarative import declarative_base >>> from uuid import uuid4 Establish a connection to the database, see also :ref:`sa:engines_toplevel` @@ -237,8 +240,8 @@ Let's add a task to the ``Todo`` table: Now, let's use ``insert().from_select()`` to archive the task into the ``ArchivedTasks`` table: - >>> sel = select([Todos.id, Todos.content]).where(Todos.status == "done") - >>> ins = insert(ArchivedTasks).from_select(['id','content'], sel) + >>> sel = select(Todos.id, Todos.content).where(Todos.status == "done") + >>> ins = insert(ArchivedTasks).from_select(['id', 'content'], sel) >>> result = session.execute(ins) >>> session.commit() @@ -250,7 +253,7 @@ This will emit the following ``INSERT`` statement to the database: Now, verify that the data is present in the database: >>> _ = connection.execute(sa.text("REFRESH TABLE archived_tasks")) - >>> pprint([str(r) for r in session.execute("SELECT content FROM archived_tasks")]) + >>> pprint([str(r) for r in session.execute(sa.text("SELECT content FROM archived_tasks"))]) ["('Write Tests',)"] diff --git a/docs/by-example/sqlalchemy/crud.rst b/docs/by-example/sqlalchemy/crud.rst index a84404f3..d2840c52 100644 --- a/docs/by-example/sqlalchemy/crud.rst +++ b/docs/by-example/sqlalchemy/crud.rst @@ -27,8 +27,11 @@ Import the relevant symbols: >>> import sqlalchemy as sa >>> from datetime import datetime >>> from sqlalchemy import delete, func, text - >>> from sqlalchemy.ext.declarative import declarative_base >>> from sqlalchemy.orm import sessionmaker + >>> try: + ... from sqlalchemy.orm import declarative_base + ... except ImportError: + ... from sqlalchemy.ext.declarative import declarative_base >>> from crate.client.sqlalchemy.types import ObjectArray Establish a connection to the database, see also :ref:`sa:engines_toplevel` @@ -40,7 +43,7 @@ and :ref:`connect`: Define the ORM schema for the ``Location`` entity using SQLAlchemy's :ref:`sa:orm_declarative_mapping`: - >>> Base = declarative_base(bind=engine) + >>> Base = declarative_base() >>> class Location(Base): ... __tablename__ = 'locations' @@ -74,7 +77,7 @@ Insert a new location: Refresh "locations" table: - >>> _ = connection.execute("REFRESH TABLE locations") + >>> _ = connection.execute(text("REFRESH TABLE locations")) Inserted location is available: @@ -175,7 +178,7 @@ The datetime and date can be set using an update statement: Refresh "locations" table: - >>> _ = connection.execute("REFRESH TABLE locations") + >>> _ = connection.execute(text("REFRESH TABLE locations")) Boolean values get set natively: @@ -196,8 +199,9 @@ And verify that the date and datetime was persisted: Update a record using SQL: - >>> result = connection.execute("update locations set kind='Heimat' where name='Earth'") - >>> result.rowcount + >>> with engine.begin() as conn: + ... result = conn.execute(text("update locations set kind='Heimat' where name='Earth'")) + ... result.rowcount 1 Update multiple records: @@ -211,27 +215,29 @@ Update multiple records: Refresh table: - >>> _ = connection.execute("REFRESH TABLE locations") + >>> _ = connection.execute(text("REFRESH TABLE locations")) Update multiple records using SQL: - >>> result = connection.execute("update locations set flag=true where kind='Update'") - >>> result.rowcount + >>> with engine.begin() as conn: + ... result = conn.execute(text("update locations set flag=true where kind='Update'")) + ... result.rowcount 10 Update all records using SQL, and check that the number of documents affected of an update without ``where-clause`` matches the number of all documents in the table: - >>> result = connection.execute(u"update locations set kind='Überall'") - >>> result.rowcount == connection.execute("select * from locations limit 100").rowcount + >>> with engine.begin() as conn: + ... result = conn.execute(text(u"update locations set kind='Überall'")) + ... result.rowcount == conn.execute(text("select * from locations limit 100")).rowcount True >>> session.commit() Refresh "locations" table: - >>> _ = connection.execute("REFRESH TABLE locations") + >>> _ = connection.execute(text("REFRESH TABLE locations")) Objects can be used within lists, too: @@ -282,7 +288,7 @@ Deleting a record with SQLAlchemy works like this. >>> session.commit() >>> session.flush() - >>> _ = connection.execute("REFRESH TABLE locations") + >>> _ = connection.execute(text("REFRESH TABLE locations")) >>> session.query(Location).count() 23 diff --git a/docs/by-example/sqlalchemy/getting-started.rst b/docs/by-example/sqlalchemy/getting-started.rst index f3fa34cb..c64964dc 100644 --- a/docs/by-example/sqlalchemy/getting-started.rst +++ b/docs/by-example/sqlalchemy/getting-started.rst @@ -28,8 +28,11 @@ Introduction Import the relevant symbols: >>> import sqlalchemy as sa - >>> from sqlalchemy.ext.declarative import declarative_base >>> from sqlalchemy.orm import sessionmaker + >>> try: + ... from sqlalchemy.orm import declarative_base + ... except ImportError: + ... from sqlalchemy.ext.declarative import declarative_base Establish a connection to the database, see also :ref:`sa:engines_toplevel` and :ref:`connect`: diff --git a/docs/by-example/sqlalchemy/inspection-reflection.rst b/docs/by-example/sqlalchemy/inspection-reflection.rst index 1d811a17..bb291157 100644 --- a/docs/by-example/sqlalchemy/inspection-reflection.rst +++ b/docs/by-example/sqlalchemy/inspection-reflection.rst @@ -82,7 +82,6 @@ Create a SQLAlchemy table object: >>> meta = sa.MetaData() >>> table = sa.Table( ... "characters", meta, - ... autoload=True, ... autoload_with=engine) Reflect column data types from the table metadata: diff --git a/docs/by-example/sqlalchemy/working-with-types.rst b/docs/by-example/sqlalchemy/working-with-types.rst index 1016c439..bcddf8f8 100644 --- a/docs/by-example/sqlalchemy/working-with-types.rst +++ b/docs/by-example/sqlalchemy/working-with-types.rst @@ -26,9 +26,12 @@ Import the relevant symbols: >>> from datetime import datetime >>> from geojson import Point, Polygon >>> from sqlalchemy import delete, func, text - >>> from sqlalchemy.ext.declarative import declarative_base >>> from sqlalchemy.orm import sessionmaker >>> from sqlalchemy.sql import operators + >>> try: + ... from sqlalchemy.orm import declarative_base + ... except ImportError: + ... from sqlalchemy.ext.declarative import declarative_base >>> from uuid import uuid4 >>> from crate.client.sqlalchemy.types import Object, ObjectArray >>> from crate.client.sqlalchemy.types import Geopoint, Geoshape @@ -156,7 +159,7 @@ Update nested dictionary Refresh and query "characters" table: - >>> _ = connection.execute("REFRESH TABLE characters") + >>> _ = connection.execute(text("REFRESH TABLE characters")) >>> session.refresh(char_nested) >>> char_nested = session.query(Character).filter_by(id='1234id').one() diff --git a/src/crate/client/sqlalchemy/__init__.py b/src/crate/client/sqlalchemy/__init__.py index a0241e99..52864719 100644 --- a/src/crate/client/sqlalchemy/__init__.py +++ b/src/crate/client/sqlalchemy/__init__.py @@ -19,7 +19,13 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. +from .compat.api13 import monkeypatch_add_exec_driver_sql from .dialect import CrateDialect +from .sa_version import SA_1_4, SA_VERSION + +# SQLAlchemy 1.3 does not have the `exec_driver_sql` method. +if SA_VERSION < SA_1_4: + monkeypatch_add_exec_driver_sql() __all__ = [ CrateDialect, diff --git a/src/crate/client/sqlalchemy/compat/__init__.py b/src/crate/client/sqlalchemy/compat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/crate/client/sqlalchemy/compat/api13.py b/src/crate/client/sqlalchemy/compat/api13.py new file mode 100644 index 00000000..16f53393 --- /dev/null +++ b/src/crate/client/sqlalchemy/compat/api13.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. + +""" +Compatibility module for running a subset of SQLAlchemy 2.0 programs on +SQLAlchemy 1.3. By using monkey-patching, it can do two things: + +1. Add the `exec_driver_sql` method to SA's `Connection` and `Engine`. +2. Amend the `sql.select` function to accept the calling semantics of + the modern variant. + +Reason: `exec_driver_sql` gets used within the CrateDB dialect already, +and the new calling semantics of `sql.select` already get used within +many of the test cases already. Please note that the patch for +`sql.select` is only applied when running the test suite. +""" + +import collections.abc as collections_abc + +from sqlalchemy import exc +from sqlalchemy.sql import Select +from sqlalchemy.sql import select as original_select +from sqlalchemy.util import immutabledict + + +# `_distill_params_20` copied from SA14's `sqlalchemy.engine.{base,util}`. +_no_tuple = () +_no_kw = immutabledict() + + +def _distill_params_20(params): + if params is None: + return _no_tuple, _no_kw + elif isinstance(params, list): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ + if params and not isinstance(params[0], (collections_abc.Mapping, tuple)): + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + + return (params,), _no_kw + elif isinstance( + params, + (tuple, dict, immutabledict), + # only do abc.__instancecheck__ for Mapping after we've checked + # for plain dictionaries and would otherwise raise + ) or isinstance(params, collections_abc.Mapping): + return (params,), _no_kw + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") + + +def exec_driver_sql(self, statement, parameters=None, execution_options=None): + """ + Adapter for `exec_driver_sql`, which is available since SA14, for SA13. + """ + if execution_options is not None: + raise ValueError( + "SA13 backward-compatibility: " + "`exec_driver_sql` does not support `execution_options`" + ) + args_10style, kwargs_10style = _distill_params_20(parameters) + return self.execute(statement, *args_10style, **kwargs_10style) + + +def monkeypatch_add_exec_driver_sql(): + """ + Transparently add SA14's `exec_driver_sql()` method to SA13. + + AttributeError: 'Connection' object has no attribute 'exec_driver_sql' + AttributeError: 'Engine' object has no attribute 'exec_driver_sql' + """ + from sqlalchemy.engine.base import Connection, Engine + + # Add `exec_driver_sql` method to SA's `Connection` and `Engine` classes. + Connection.exec_driver_sql = exec_driver_sql + Engine.exec_driver_sql = exec_driver_sql + + +def select_sa14(*columns, **kw) -> Select: + """ + Adapt SA14/SA20's calling semantics of `sql.select()` to SA13. + + With SA20, `select()` no longer accepts varied constructor arguments, only + the "generative" style of `select()` will be supported. The list of columns + / tables to select from should be passed positionally. + + Derived from https://github.com/sqlalchemy/alembic/blob/b1fad6b6/alembic/util/sqla_compat.py#L557-L558 + + sqlalchemy.exc.ArgumentError: columns argument to select() must be a Python list or other iterable + """ + if isinstance(columns, tuple) and isinstance(columns[0], list): + if "whereclause" in kw: + raise ValueError( + "SA13 backward-compatibility: " + "`whereclause` is both in kwargs and columns tuple" + ) + columns, whereclause = columns + kw["whereclause"] = whereclause + return original_select(columns, **kw) + + +def monkeypatch_amend_select_sa14(): + """ + Make SA13's `sql.select()` transparently accept calling semantics of SA14 + and SA20, by swapping in the newer variant of `select_sa14()`. + + This supports the test suite of `crate-python`, because it already uses the + modern calling semantics. + """ + import sqlalchemy + + sqlalchemy.select = select_sa14 + sqlalchemy.sql.select = select_sa14 + sqlalchemy.sql.expression.select = select_sa14 diff --git a/src/crate/client/sqlalchemy/dialect.py b/src/crate/client/sqlalchemy/dialect.py index 903a803c..80ab2c20 100644 --- a/src/crate/client/sqlalchemy/dialect.py +++ b/src/crate/client/sqlalchemy/dialect.py @@ -228,7 +228,7 @@ def has_table(self, connection, table_name, schema=None): @reflection.cache def get_schema_names(self, connection, **kw): - cursor = connection.execute( + cursor = connection.exec_driver_sql( "select schema_name " "from information_schema.schemata " "order by schema_name asc" @@ -237,21 +237,21 @@ def get_schema_names(self, connection, **kw): @reflection.cache def get_table_names(self, connection, schema=None, **kw): - cursor = connection.execute( + cursor = connection.exec_driver_sql( "SELECT table_name FROM information_schema.tables " "WHERE {0} = ? " "AND table_type = 'BASE TABLE' " "ORDER BY table_name ASC, {0} ASC".format(self.schema_column), - [schema or self.default_schema_name] + (schema or self.default_schema_name, ) ) return [row[0] for row in cursor.fetchall()] @reflection.cache def get_view_names(self, connection, schema=None, **kw): - cursor = connection.execute( + cursor = connection.exec_driver_sql( "SELECT table_name FROM information_schema.views " "ORDER BY table_name ASC, {0} ASC".format(self.schema_column), - [schema or self.default_schema_name] + (schema or self.default_schema_name, ) ) return [row[0] for row in cursor.fetchall()] @@ -262,11 +262,11 @@ def get_columns(self, connection, table_name, schema=None, **kw): "WHERE table_name = ? AND {0} = ? " \ "AND column_name !~ ?" \ .format(self.schema_column) - cursor = connection.execute( + cursor = connection.exec_driver_sql( query, - [table_name, + (table_name, schema or self.default_schema_name, - r"(.*)\[\'(.*)\'\]"] # regex to filter subscript + r"(.*)\[\'(.*)\'\]") # regex to filter subscript ) return [self._create_column_info(row) for row in cursor.fetchall()] @@ -301,9 +301,9 @@ def result_fun(result): rows = result.fetchone() return set(rows[0] if rows else []) - pk_result = engine.execute( + pk_result = engine.exec_driver_sql( query, - [table_name, schema or self.default_schema_name] + (table_name, schema or self.default_schema_name) ) pks = result_fun(pk_result) return {'constrained_columns': pks, diff --git a/src/crate/client/sqlalchemy/tests/__init__.py b/src/crate/client/sqlalchemy/tests/__init__.py index 81f1ba2a..61a2669b 100644 --- a/src/crate/client/sqlalchemy/tests/__init__.py +++ b/src/crate/client/sqlalchemy/tests/__init__.py @@ -1,5 +1,13 @@ # -*- coding: utf-8 -*- +from ..compat.api13 import monkeypatch_amend_select_sa14 +from ..sa_version import SA_1_4, SA_VERSION + +# `sql.select()` of SQLAlchemy 1.3 uses old calling semantics, +# but the test cases already need the modern ones. +if SA_VERSION < SA_1_4: + monkeypatch_amend_select_sa14() + from unittest import TestSuite, makeSuite from .connection_test import SqlAlchemyConnectionTest from .dict_test import SqlAlchemyDictTypeTest diff --git a/src/crate/client/sqlalchemy/tests/array_test.py b/src/crate/client/sqlalchemy/tests/array_test.py index c65a2d9d..6d663327 100644 --- a/src/crate/client/sqlalchemy/tests/array_test.py +++ b/src/crate/client/sqlalchemy/tests/array_test.py @@ -26,7 +26,10 @@ import sqlalchemy as sa from sqlalchemy.sql import operators from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.cursor import Cursor @@ -40,7 +43,7 @@ class SqlAlchemyArrayTypeTest(TestCase): def setUp(self): self.engine = sa.create_engine('crate://') - Base = declarative_base(bind=self.engine) + Base = declarative_base() self.metadata = sa.MetaData() class User(Base): @@ -51,7 +54,7 @@ class User(Base): scores = sa.Column(sa.ARRAY(sa.Integer)) self.User = User - self.session = Session() + self.session = Session(bind=self.engine) def assertSQL(self, expected_str, actual_expr): self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) diff --git a/src/crate/client/sqlalchemy/tests/bulk_test.py b/src/crate/client/sqlalchemy/tests/bulk_test.py index c9d60319..95bc1ddd 100644 --- a/src/crate/client/sqlalchemy/tests/bulk_test.py +++ b/src/crate/client/sqlalchemy/tests/bulk_test.py @@ -24,7 +24,10 @@ import sqlalchemy as sa from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.cursor import Cursor @@ -38,7 +41,7 @@ class SqlAlchemyBulkTest(TestCase): def setUp(self): self.engine = sa.create_engine('crate://') - Base = declarative_base(bind=self.engine) + Base = declarative_base() class Character(Base): __tablename__ = 'characters' @@ -47,7 +50,7 @@ class Character(Base): age = sa.Column(sa.Integer) self.character = Character - self.session = Session() + self.session = Session(bind=self.engine) @patch('crate.client.connection.Cursor', FakeCursor) def test_bulk_save(self): diff --git a/src/crate/client/sqlalchemy/tests/compiler_test.py b/src/crate/client/sqlalchemy/tests/compiler_test.py index c49e14b3..47317db7 100644 --- a/src/crate/client/sqlalchemy/tests/compiler_test.py +++ b/src/crate/client/sqlalchemy/tests/compiler_test.py @@ -24,7 +24,7 @@ from crate.client.sqlalchemy.compiler import crate_before_execute import sqlalchemy as sa -from sqlalchemy.sql import update, text +from sqlalchemy.sql import text, Update from crate.client.sqlalchemy.sa_version import SA_VERSION, SA_1_4 from crate.client.sqlalchemy.types import Craty @@ -40,7 +40,7 @@ def setUp(self): sa.Column('name', sa.String), sa.Column('data', Craty)) - self.update = update(self.mytable, text('where name=:name')) + self.update = Update(self.mytable).where(text('name=:name')) self.values = [{'name': 'crate'}] self.values = (self.values, ) @@ -75,9 +75,8 @@ def test_select_with_offset(self): """ Verify the `CrateCompiler.limit_clause` method, with offset. """ - self.metadata.bind = self.crate_engine selectable = self.mytable.select().offset(5) - statement = str(selectable.compile()) + statement = str(selectable.compile(bind=self.crate_engine)) if SA_VERSION >= SA_1_4: self.assertEqual(statement, "SELECT mytable.name, mytable.data \nFROM mytable\n LIMIT ALL OFFSET ?") else: @@ -87,16 +86,14 @@ def test_select_with_limit(self): """ Verify the `CrateCompiler.limit_clause` method, with limit. """ - self.metadata.bind = self.crate_engine selectable = self.mytable.select().limit(42) - statement = str(selectable.compile()) + statement = str(selectable.compile(bind=self.crate_engine)) self.assertEqual(statement, "SELECT mytable.name, mytable.data \nFROM mytable \n LIMIT ?") def test_select_with_offset_and_limit(self): """ Verify the `CrateCompiler.limit_clause` method, with offset and limit. """ - self.metadata.bind = self.crate_engine selectable = self.mytable.select().offset(5).limit(42) - statement = str(selectable.compile()) + statement = str(selectable.compile(bind=self.crate_engine)) self.assertEqual(statement, "SELECT mytable.name, mytable.data \nFROM mytable \n LIMIT ? OFFSET ?") diff --git a/src/crate/client/sqlalchemy/tests/connection_test.py b/src/crate/client/sqlalchemy/tests/connection_test.py index b1dc5d85..8344adc1 100644 --- a/src/crate/client/sqlalchemy/tests/connection_test.py +++ b/src/crate/client/sqlalchemy/tests/connection_test.py @@ -35,6 +35,8 @@ def test_default_connection(self): conn = engine.raw_connection() self.assertEqual(">", repr(conn.connection)) + conn.close() + engine.dispose() def test_connection_server_uri_http(self): engine = sa.create_engine( @@ -42,6 +44,8 @@ def test_connection_server_uri_http(self): conn = engine.raw_connection() self.assertEqual(">", repr(conn.connection)) + conn.close() + engine.dispose() def test_connection_server_uri_https(self): engine = sa.create_engine( @@ -49,6 +53,8 @@ def test_connection_server_uri_https(self): conn = engine.raw_connection() self.assertEqual(">", repr(conn.connection)) + conn.close() + engine.dispose() def test_connection_server_uri_invalid_port(self): with self.assertRaises(ValueError) as context: @@ -63,6 +69,8 @@ def test_connection_server_uri_https_with_trusted_user(self): repr(conn.connection)) self.assertEqual(conn.connection.client.username, "foo") self.assertEqual(conn.connection.client.password, None) + conn.close() + engine.dispose() def test_connection_server_uri_https_with_credentials(self): engine = sa.create_engine( @@ -72,6 +80,8 @@ def test_connection_server_uri_https_with_credentials(self): repr(conn.connection)) self.assertEqual(conn.connection.client.username, "foo") self.assertEqual(conn.connection.client.password, "bar") + conn.close() + engine.dispose() def test_connection_multiple_server_http(self): engine = sa.create_engine( @@ -84,6 +94,8 @@ def test_connection_multiple_server_http(self): ">", repr(conn.connection)) + conn.close() + engine.dispose() def test_connection_multiple_server_https(self): engine = sa.create_engine( @@ -97,3 +109,5 @@ def test_connection_multiple_server_https(self): ">", repr(conn.connection)) + conn.close() + engine.dispose() diff --git a/src/crate/client/sqlalchemy/tests/create_table_test.py b/src/crate/client/sqlalchemy/tests/create_table_test.py index 595f2aa8..7eca2628 100644 --- a/src/crate/client/sqlalchemy/tests/create_table_test.py +++ b/src/crate/client/sqlalchemy/tests/create_table_test.py @@ -20,7 +20,10 @@ # software solely pursuant to the terms of the relevant commercial agreement. import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.sqlalchemy.types import Object, ObjectArray, Geopoint from crate.client.cursor import Cursor @@ -39,7 +42,7 @@ class SqlAlchemyCreateTableTest(TestCase): def setUp(self): self.engine = sa.create_engine('crate://') - self.Base = declarative_base(bind=self.engine) + self.Base = declarative_base() def test_table_basic_types(self): class User(self.Base): @@ -57,7 +60,7 @@ class User(self.Base): float_col = sa.Column(sa.Float) double_col = sa.Column(sa.DECIMAL) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE users (\n\tstring_col STRING NOT NULL, ' '\n\tunicode_col STRING, \n\ttext_col STRING, \n\tint_col INT, ' @@ -74,7 +77,7 @@ class DummyTable(self.Base): __tablename__ = 'dummy' pk = sa.Column(sa.String, primary_key=True) obj_col = sa.Column(Object) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE dummy (\n\tpk STRING NOT NULL, \n\tobj_col OBJECT, ' '\n\tPRIMARY KEY (pk)\n)\n\n'), @@ -88,7 +91,7 @@ class DummyTable(self.Base): } pk = sa.Column(sa.String, primary_key=True) p = sa.Column(sa.String) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' @@ -102,7 +105,7 @@ class DummyTable(self.Base): __tablename__ = 't' ts = sa.Column(sa.BigInteger, primary_key=True) p = sa.Column(sa.BigInteger, sa.Computed("date_trunc('day', ts)")) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'ts LONG NOT NULL, \n\t' @@ -117,7 +120,7 @@ class DummyTable(self.Base): ts = sa.Column(sa.BigInteger, primary_key=True) p = sa.Column(sa.BigInteger, sa.Computed("date_trunc('day', ts)", persisted=False)) with self.assertRaises(sa.exc.CompileError): - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) def test_table_partitioned_by(self): class DummyTable(self.Base): @@ -128,7 +131,7 @@ class DummyTable(self.Base): } pk = sa.Column(sa.String, primary_key=True) p = sa.Column(sa.String) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' @@ -146,7 +149,7 @@ class DummyTable(self.Base): } pk = sa.Column(sa.String, primary_key=True) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' @@ -163,7 +166,7 @@ class DummyTable(self.Base): } pk = sa.Column(sa.String, primary_key=True) p = sa.Column(sa.String, primary_key=True) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' @@ -178,7 +181,7 @@ class DummyTable(self.Base): pk = sa.Column(sa.String, primary_key=True) tags = sa.Column(ObjectArray) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' @@ -192,7 +195,7 @@ class DummyTable(self.Base): a = sa.Column(sa.Integer, nullable=True) b = sa.Column(sa.Integer, nullable=False) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' @@ -205,7 +208,7 @@ class DummyTable(self.Base): __tablename__ = 't' pk = sa.Column(sa.String, primary_key=True, nullable=True) with self.assertRaises(sa.exc.CompileError): - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) def test_column_crate_index(self): class DummyTable(self.Base): @@ -214,7 +217,7 @@ class DummyTable(self.Base): a = sa.Column(sa.Integer, crate_index=False) b = sa.Column(sa.Integer, crate_index=True) - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' @@ -228,4 +231,4 @@ class DummyTable(self.Base): pk = sa.Column(sa.String, primary_key=True) a = sa.Column(Geopoint, crate_index=False) with self.assertRaises(sa.exc.CompileError): - self.Base.metadata.create_all() + self.Base.metadata.create_all(bind=self.engine) diff --git a/src/crate/client/sqlalchemy/tests/datetime_test.py b/src/crate/client/sqlalchemy/tests/datetime_test.py index 0e85e3ef..07e98ede 100644 --- a/src/crate/client/sqlalchemy/tests/datetime_test.py +++ b/src/crate/client/sqlalchemy/tests/datetime_test.py @@ -27,7 +27,10 @@ import sqlalchemy as sa from sqlalchemy.exc import DBAPIError from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.cursor import Cursor @@ -54,7 +57,7 @@ class SqlAlchemyDateAndDateTimeTest(TestCase): def setUp(self): self.engine = sa.create_engine('crate://') - Base = declarative_base(bind=self.engine) + Base = declarative_base() class Character(Base): __tablename__ = 'characters' @@ -66,7 +69,7 @@ class Character(Base): ('characters_name', None, None, None, None, None, None), ('characters_date', None, None, None, None, None, None) ) - self.session = Session() + self.session = Session(bind=self.engine) self.Character = Character def test_date_can_handle_datetime(self): diff --git a/src/crate/client/sqlalchemy/tests/dialect_test.py b/src/crate/client/sqlalchemy/tests/dialect_test.py index 51922c84..a6669df4 100644 --- a/src/crate/client/sqlalchemy/tests/dialect_test.py +++ b/src/crate/client/sqlalchemy/tests/dialect_test.py @@ -28,8 +28,11 @@ from crate.client.cursor import Cursor from crate.client.sqlalchemy.types import Object from sqlalchemy import inspect -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.testing import eq_, in_ FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) @@ -48,13 +51,14 @@ def setUp(self): FakeCursor.return_value = self.fake_cursor self.engine = sa.create_engine('crate://') + self.executed_statement = None self.connection = self.engine.connect() self.fake_cursor.execute = self.execute_wrapper - self.base = declarative_base(bind=self.engine) + self.base = declarative_base() class Character(self.base): __tablename__ = 'characters' @@ -64,12 +68,10 @@ class Character(self.base): obj = sa.Column(Object) ts = sa.Column(sa.DateTime, onupdate=datetime.utcnow) - self.character = Character - self.session = Session() + self.session = Session(bind=self.engine) def test_primary_keys_2_3_0(self): - meta = self.character.metadata - insp = inspect(meta.bind) + insp = inspect(self.session.bind) self.engine.dialect.server_version_info = (2, 3, 0) self.fake_cursor.rowcount = 3 @@ -84,8 +86,7 @@ def test_primary_keys_2_3_0(self): in_("table_catalog = ?", self.executed_statement) def test_primary_keys_3_0_0(self): - meta = self.character.metadata - insp = inspect(meta.bind) + insp = inspect(self.session.bind) self.engine.dialect.server_version_info = (3, 0, 0) self.fake_cursor.rowcount = 3 @@ -106,7 +107,7 @@ def test_get_table_names(self): ) self.fake_cursor.fetchall = MagicMock(return_value=[["t1"], ["t2"]]) - insp = inspect(self.character.metadata.bind) + insp = inspect(self.session.bind) self.engine.dialect.server_version_info = (2, 0, 0) eq_(insp.get_table_names(schema="doc"), ['t1', 't2']) @@ -119,7 +120,7 @@ def test_get_view_names(self): ) self.fake_cursor.fetchall = MagicMock(return_value=[["v1"], ["v2"]]) - insp = inspect(self.character.metadata.bind) + insp = inspect(self.session.bind) self.engine.dialect.server_version_info = (2, 0, 0) eq_(insp.get_view_names(schema="doc"), ['v1', 'v2']) diff --git a/src/crate/client/sqlalchemy/tests/dict_test.py b/src/crate/client/sqlalchemy/tests/dict_test.py index 7f4464c0..2324591e 100644 --- a/src/crate/client/sqlalchemy/tests/dict_test.py +++ b/src/crate/client/sqlalchemy/tests/dict_test.py @@ -26,7 +26,10 @@ import sqlalchemy as sa from sqlalchemy.sql import select from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.sqlalchemy.types import Craty, ObjectArray from crate.client.cursor import Cursor @@ -46,19 +49,20 @@ def setUp(self): sa.Column('name', sa.String), sa.Column('data', Craty)) - def assertSQL(self, expected_str, actual_expr): + def assertSQL(self, expected_str, selectable): + actual_expr = selectable.compile(bind=self.engine) self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) def test_select_with_dict_column(self): mytable = self.mytable self.assertSQL( "SELECT mytable.data['x'] AS anon_1 FROM mytable", - select([mytable.c.data['x']], bind=self.engine) + select(mytable.c.data['x']) ) def test_select_with_dict_column_where_clause(self): mytable = self.mytable - s = select([mytable.c.data], bind=self.engine).\ + s = select(mytable.c.data).\ where(mytable.c.data['x'] == 1) self.assertSQL( "SELECT mytable.data FROM mytable WHERE mytable.data['x'] = ?", @@ -67,7 +71,7 @@ def test_select_with_dict_column_where_clause(self): def test_select_with_dict_column_nested_where(self): mytable = self.mytable - s = select([mytable.c.name], bind=self.engine) + s = select(mytable.c.name) s = s.where(mytable.c.data['x']['y'] == 1) self.assertSQL( "SELECT mytable.name FROM mytable " + @@ -77,7 +81,7 @@ def test_select_with_dict_column_nested_where(self): def test_select_with_dict_column_where_clause_gt(self): mytable = self.mytable - s = select([mytable.c.data], bind=self.engine).\ + s = select(mytable.c.data).\ where(mytable.c.data['x'] > 1) self.assertSQL( "SELECT mytable.data FROM mytable WHERE mytable.data['x'] > ?", @@ -86,7 +90,7 @@ def test_select_with_dict_column_where_clause_gt(self): def test_select_with_dict_column_where_clause_other_col(self): mytable = self.mytable - s = select([mytable.c.name], bind=self.engine) + s = select(mytable.c.name) s = s.where(mytable.c.data['x'] == mytable.c.name) self.assertSQL( "SELECT mytable.name FROM mytable " + @@ -96,7 +100,7 @@ def test_select_with_dict_column_where_clause_other_col(self): def test_update_with_dict_column(self): mytable = self.mytable - stmt = mytable.update(bind=self.engine).\ + stmt = mytable.update().\ where(mytable.c.name == 'Arthur Dent').\ values({ "data['x']": "Trillian" @@ -114,7 +118,7 @@ def set_up_character_and_cursor(self, return_value=None): ('characters_data', None, None, None, None, None, None) ) fake_cursor.rowcount = 1 - Base = declarative_base(bind=self.engine) + Base = declarative_base() class Character(Base): __tablename__ = 'characters' @@ -123,7 +127,7 @@ class Character(Base): data = sa.Column(Craty) data_list = sa.Column(ObjectArray) - session = Session() + session = Session(bind=self.engine) return session, Character def test_assign_null_to_object_array(self): @@ -266,7 +270,7 @@ def test_partial_dict_update_with_delitem_setitem(self): return_value=[('Trillian', {'x': 1})] ) - session = Session() + session = Session(bind=self.engine) char = Character(name='Trillian') char.data = {'x': 1} session.add(char) @@ -339,14 +343,14 @@ def set_up_character_and_cursor_data_list(self, return_value=None): ) fake_cursor.rowcount = 1 - Base = declarative_base(bind=self.engine) + Base = declarative_base() class Character(Base): __tablename__ = 'characters' name = sa.Column(sa.String, primary_key=True) data_list = sa.Column(ObjectArray) - session = Session() + session = Session(bind=self.engine) return session, Character def _setup_object_array_char(self): diff --git a/src/crate/client/sqlalchemy/tests/function_test.py b/src/crate/client/sqlalchemy/tests/function_test.py index 1b4a1983..072ab43a 100644 --- a/src/crate/client/sqlalchemy/tests/function_test.py +++ b/src/crate/client/sqlalchemy/tests/function_test.py @@ -23,12 +23,15 @@ import sqlalchemy as sa from sqlalchemy.sql.sqltypes import TIMESTAMP -from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base class SqlAlchemyFunctionTest(TestCase): def setUp(self): - Base = declarative_base(bind=sa.create_engine("crate://")) + Base = declarative_base() class Character(Base): __tablename__ = "characters" diff --git a/src/crate/client/sqlalchemy/tests/insert_from_select_test.py b/src/crate/client/sqlalchemy/tests/insert_from_select_test.py index 0c5ba73f..692dfa55 100644 --- a/src/crate/client/sqlalchemy/tests/insert_from_select_test.py +++ b/src/crate/client/sqlalchemy/tests/insert_from_select_test.py @@ -24,9 +24,12 @@ from unittest.mock import patch, MagicMock import sqlalchemy as sa -from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import select, insert +from sqlalchemy.orm import Session +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.cursor import Cursor @@ -44,7 +47,7 @@ def assertSQL(self, expected_str, actual_expr): def setUp(self): self.engine = sa.create_engine('crate://') - Base = declarative_base(bind=self.engine) + Base = declarative_base() class Character(Base): __tablename__ = 'characters' @@ -64,7 +67,7 @@ class CharacterArchive(Base): self.character = Character self.character_archived = CharacterArchive - self.session = Session() + self.session = Session(bind=self.engine) @patch('crate.client.connection.Cursor', FakeCursor) def test_insert_from_select_triggered(self): @@ -72,11 +75,11 @@ def test_insert_from_select_triggered(self): self.session.add(char) self.session.commit() - sel = select([self.character.name, self.character.age]).where(self.character.status == "Archived") + sel = select(self.character.name, self.character.age).where(self.character.status == "Archived") ins = insert(self.character_archived).from_select(['name', 'age'], sel) self.session.execute(ins) self.session.commit() self.assertSQL( "INSERT INTO characters_archive (name, age) SELECT characters.name, characters.age FROM characters WHERE characters.status = ?", - ins + ins.compile(bind=self.engine) ) diff --git a/src/crate/client/sqlalchemy/tests/match_test.py b/src/crate/client/sqlalchemy/tests/match_test.py index 71b79d0d..fdd5b7d0 100644 --- a/src/crate/client/sqlalchemy/tests/match_test.py +++ b/src/crate/client/sqlalchemy/tests/match_test.py @@ -25,7 +25,10 @@ import sqlalchemy as sa from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.sqlalchemy.types import Craty from crate.client.sqlalchemy.predicates import match @@ -52,14 +55,14 @@ def assertSQL(self, expected_str, actual_expr): self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) def set_up_character_and_session(self): - Base = declarative_base(bind=self.engine) + Base = declarative_base() class Character(Base): __tablename__ = 'characters' name = sa.Column(sa.String, primary_key=True) info = sa.Column(Craty) - session = Session() + session = Session(bind=self.engine) return session, Character def test_simple_match(self): diff --git a/src/crate/client/sqlalchemy/tests/update_test.py b/src/crate/client/sqlalchemy/tests/update_test.py index 394163aa..00aeef0a 100644 --- a/src/crate/client/sqlalchemy/tests/update_test.py +++ b/src/crate/client/sqlalchemy/tests/update_test.py @@ -27,7 +27,10 @@ import sqlalchemy as sa from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.cursor import Cursor @@ -42,7 +45,7 @@ class SqlAlchemyUpdateTest(TestCase): def setUp(self): self.engine = sa.create_engine('crate://') - self.base = declarative_base(bind=self.engine) + self.base = declarative_base() class Character(self.base): __tablename__ = 'characters' @@ -53,7 +56,7 @@ class Character(self.base): ts = sa.Column(sa.DateTime, onupdate=datetime.utcnow) self.character = Character - self.session = Session() + self.session = Session(bind=self.engine) @patch('crate.client.connection.Cursor', FakeCursor) def test_onupdate_is_triggered(self): diff --git a/src/crate/testing/layer.py b/src/crate/testing/layer.py index 3c5ed939..5fd6d8fd 100644 --- a/src/crate/testing/layer.py +++ b/src/crate/testing/layer.py @@ -321,12 +321,12 @@ def start(self): sys.stderr.write('\nCrate instance ready.\n') def stop(self): + self.conn_pool.clear() if self.process: self.process.terminate() self.process.communicate(timeout=10) self.process.stdout.close() self.process = None - self.conn_pool.clear() self.monitor.stop() self._clean()