diff --git a/pyproject.toml b/pyproject.toml index db692e0..fe4eabf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databudgie" -version = "2.8.1" +version = "2.8.2" packages = [ { include = "databudgie", from = "src" }, ] diff --git a/src/databudgie/cli/base.py b/src/databudgie/cli/base.py index 6b8e31b..673a354 100644 --- a/src/databudgie/cli/base.py +++ b/src/databudgie/cli/base.py @@ -1,4 +1,6 @@ -from typing import Callable, Optional, Union +from __future__ import annotations + +from typing import Callable import click import sqlalchemy @@ -6,7 +8,7 @@ import sqlalchemy.orm import strapp.click -from databudgie.config import BackupConfig, RestoreConfig, RootConfig +from databudgie.config import BackupConfig, Connection, RestoreConfig, RootConfig from databudgie.output import Console version = getattr(sqlalchemy, "__version__", "") @@ -19,17 +21,22 @@ create_url = sqlalchemy.engine.url.URL -def _create_postgres_session(config: Union[BackupConfig, RestoreConfig], connection_name: Optional[str] = None): - if connection_name: - connection = config.connections.get(connection_name) - if connection is None: - raise click.UsageError(f"Connection '{connection_name}' not found") - else: - connection = config.connection - if connection is None: - raise click.UsageError("No config found for 'url' field. Either a 'connection' or a 'url' is required") - +def _create_postgres_session(config: BackupConfig | RestoreConfig): + connection: str | Connection = config.connection + if isinstance(config.connection, str): + if config.connection in config.connections: + # `connection` can either be a connection's name, or a literal connection string. We need + # to only overwrite the `connection` value if there is a correspondingly named connection + connection = config.connections[config.connection] + else: + raise click.UsageError( + f"'{config.connection}' did not resolve to a connection. 'url' or 'connection' must " + "resolve to a named connection or connection details." + ) + + assert isinstance(connection, Connection) url = connection.url + if isinstance(url, dict): url_obj = create_url(**url) else: @@ -51,18 +58,18 @@ def restore_config(root_config: RootConfig, console: Console): return root_config.restore -def backup_db(backup_config: BackupConfig, connection_name: Optional[str] = None): - return _create_postgres_session(backup_config, connection_name) +def backup_db(backup_config: BackupConfig): + return _create_postgres_session(backup_config) -def restore_db(restore_config: RestoreConfig, connection_name: Optional[str] = None): - return _create_postgres_session(restore_config, connection_name) +def restore_db(restore_config: RestoreConfig): + return _create_postgres_session(restore_config) def backup_manifest(backup_config: BackupConfig, backup_db): from databudgie.manifest.manager import BackupManifest - table_name: Optional[str] = backup_config.manifest + table_name: str | None = backup_config.manifest if table_name: return BackupManifest(backup_db, table_name) return None @@ -71,7 +78,7 @@ def backup_manifest(backup_config: BackupConfig, backup_db): def restore_manifest(restore_config: RestoreConfig, restore_db): from databudgie.manifest.manager import RestoreManifest - table_name: Optional[str] = restore_config.manifest + table_name: str | None = restore_config.manifest if table_name: return RestoreManifest(restore_db, table_name) return None diff --git a/src/databudgie/cli/commands.py b/src/databudgie/cli/commands.py index e8fc638..2ee13a4 100644 --- a/src/databudgie/cli/commands.py +++ b/src/databudgie/cli/commands.py @@ -115,6 +115,7 @@ def cli( location=location, adapter=adapter, strict=strict, + connection=conn, ) try: @@ -131,7 +132,6 @@ def cli( root_config=root_config, verbosity=verbose, console=Console(verbosity=verbose), - connection_name=conn, dry_run=bool(dry_run), stats=stats if stats is not None else dry_run, ) diff --git a/src/databudgie/cli/config.py b/src/databudgie/cli/config.py index 18073c0..c37e93b 100644 --- a/src/databudgie/cli/config.py +++ b/src/databudgie/cli/config.py @@ -41,6 +41,7 @@ class CliConfig(DatabudgieConfig): location: str | None = None adapter: str | None = None strict: bool | None = None + connection: str | None = None def to_dict(self) -> dict: config = asdict(self) diff --git a/src/databudgie/config.py b/src/databudgie/config.py index 4341195..0199589 100644 --- a/src/databudgie/config.py +++ b/src/databudgie/config.py @@ -128,10 +128,10 @@ def from_stack(cls, stack: ConfigStack): @dataclass class TableParentConfig(typing.Generic[T], Config): tables: list[T] - connections: dict[str, Connection] + connections: dict[str, Connection | str] ddl: DDLConfig - connection: Connection | None = None + connection: Connection | str = "default" manifest: str | None = None s3: S3Config | None = None @@ -145,7 +145,7 @@ def get_child_class(cls): @classmethod def from_stack(cls, stack: ConfigStack): - connection = Connection.from_raw(stack.get("url") or stack.get("connection"), name="default") + connection = Connection.from_raw(stack.get("url") or stack.get("connection"), name="default") or "default" root_location = stack.get("root_location") tables_config: list = normalize_table_config(stack.get("tables", [])) @@ -182,11 +182,14 @@ class Connection(Config): url: str | dict @classmethod - def from_raw(cls, raw: str | dict | None, *, name: str | None = None): + def from_raw(cls, raw: str | dict | None, *, name: str | None = None) -> Connection | str | None: if raw is None: return None if isinstance(raw, str): + if "://" not in raw: + return raw + if name is None: raise ConfigError(f"Connection '{raw}' requires a name") return cls(name="default", url=raw) @@ -202,7 +205,7 @@ def from_raw(cls, raw: str | dict | None, *, name: str | None = None): return cls(name=name or "default", url=url) @classmethod - def from_collection(cls, collection: list | dict | None) -> dict[str, Connection]: + def from_collection(cls, collection: list | dict | None) -> dict[str, Connection | str]: if collection is None: return {} @@ -211,7 +214,10 @@ def from_collection(cls, collection: list | dict | None) -> dict[str, Connection names = set() for c in collection: connection = Connection.from_raw(c, name=c.get("name")) - assert connection is not None + if not isinstance(connection, Connection): + raise ConfigError( + "Connections must be a database connection string or a mapping of individual connection fields." + ) if connection.name in names: raise ConfigError(f"Detected more than one connection with the same name: {connection.name}") @@ -220,7 +226,12 @@ def from_collection(cls, collection: list | dict | None) -> dict[str, Connection return {c.name: c for c in connections} - return {k: Connection.from_raw(c, name=k) for k, c in collection.items()} + result = {} + for k, c in collection.items(): + connection = Connection.from_raw(c, name=k) + if connection is not None: + result[k] = connection + return result @dataclass diff --git a/tests/cli/test_base.py b/tests/cli/test_base.py index 6d092fb..eb69b8e 100644 --- a/tests/cli/test_base.py +++ b/tests/cli/test_base.py @@ -1,3 +1,5 @@ +import click +import pytest from pytest_mock_resources import create_postgres_fixture from sqlalchemy import text @@ -29,3 +31,25 @@ def test_create_postgres_session_url_components(pg_engine): config = BackupConfig.from_stack(ConfigStack({"url": url_parts})) session = _create_postgres_session(config) session.execute(text("select 1")) + + +def test_connection_selection(pg_engine): + url = pg_engine.pmr_credentials.as_sqlalchemy_url() + + try: + url_str = url.render_as_string(hide_password=False) + except AttributeError: + url_str = str(url) + + assert url_str.startswith("postgresql+psycopg2://") + config = BackupConfig.from_stack(ConfigStack({"connections": {"foo": url_str}, "connection": "foo"})) + + session = _create_postgres_session(config) + session.execute(text("select 1")) + + +def test_missing_connection(): + config = BackupConfig.from_stack(ConfigStack({"connections": {}, "connection": "foo"})) + + with pytest.raises(click.UsageError): + _create_postgres_session(config) diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 9552d75..4a7c192 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -28,7 +28,7 @@ def run_command(command, assert_exit_code=0): @pytest.mark.parametrize("command", ("backup", "restore")) def test_no_default_file_warns_of_no_url(command): result = run_command(command, assert_exit_code=2) - assert "No config found for 'url' field" in result.output + assert "did not resolve to a connection" in result.output class TestConfigCommand: @@ -43,7 +43,7 @@ def test_cli_args_pass_through_to_config(self): for part in ["backup", "restore"]: config_part = config[part] - assert config_part["connection"] == {"name": "default", "url": "foo"} + assert config_part["connection"] == "foo" assert config_part["ddl"]["enabled"] is True assert config_part["tables"][0]["location"] == "bar" assert config_part["tables"][0]["name"] == "baz" diff --git a/tests/test_config.py b/tests/test_config.py index aee2379..fa98732 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,4 @@ -from databudgie.config import ConfigStack, RootConfig +from databudgie.config import ConfigStack, Connection, RootConfig def test_only_leaf_values(): @@ -119,7 +119,7 @@ def test_root_level_tables(): } ) - assert config.backup.connection.url == "backup_url" + assert config.backup.connection == "backup_url" assert config.backup.tables[0].name == "root_table_1" assert config.backup.tables[0].query == "root_query" assert config.backup.tables[0].location == "root_location" @@ -127,7 +127,7 @@ def test_root_level_tables(): assert config.backup.tables[1].query == "root_query" assert config.backup.tables[1].location == "root_location" - assert config.restore.connection.url == "restore_url" + assert config.restore.connection == "restore_url" assert config.restore.tables[0].name == "root_table_1" assert config.restore.tables[0].strategy == "root_strategy" assert config.restore.tables[0].location == "root_location" @@ -151,7 +151,7 @@ def test_tables_as_just_strings(): } ) - assert config.backup.connection.url == "root_url" + assert config.backup.connection == "root_url" assert config.backup.tables[0].name == "root_table_1" assert config.backup.tables[0].query == "root_query" assert config.backup.tables[0].location == "root_location" @@ -174,7 +174,7 @@ def test_tables_mixed_str_dict(): } ) - assert config.backup.connection.url == "root_url" + assert config.backup.connection == "root_url" assert config.backup.tables[0].name == "table_1" assert config.backup.tables[0].query == "root_query" assert config.backup.tables[0].location == "root_location" @@ -216,8 +216,8 @@ def test_configs_stack(): ) config = RootConfig.from_stack(config_stack) - assert config.backup.connection.url == "root_url" - assert config.restore.connection.url == "restore url" + assert config.backup.connection == "root_url" + assert config.restore.connection == "restore url" assert config.backup.tables[0].name == "1" assert config.backup.tables[0].query == "bar" @@ -296,3 +296,25 @@ def test_parent_ddl_enabled(): assert config.backup.tables[0].ddl is False assert config.restore.tables[0].name == "backup_table_1" assert config.restore.tables[0].ddl is False + + +def test_connection_strings(): + """Assert connection strings are parsed correctly.""" + + config = RootConfig.from_dict({"connection": None}) + assert config.backup.connection == "default" + + config = RootConfig.from_dict({"connection": "example"}) + assert config.backup.connection == "example" + + config = RootConfig.from_dict({"connection": "dialect://foo"}) + assert isinstance(config.backup.connection, Connection) + assert config.backup.connection.url == "dialect://foo" + + config = RootConfig.from_dict({"connection": {"url": "foo"}}) + assert isinstance(config.backup.connection, Connection) + assert config.backup.connection.url == "foo" + + config = RootConfig.from_dict({"connection": {"dialect": "postgres"}}) + assert isinstance(config.backup.connection, Connection) + assert config.backup.connection.url == {"dialect": "postgres"}