Skip to content

Commit

Permalink
feat: Route --conn through config so that config-level "connection:" …
Browse files Browse the repository at this point in the history
…field works in the same way.
  • Loading branch information
DanCardin committed Jun 8, 2023
1 parent 4e04900 commit f238c30
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "databudgie"
version = "2.8.1"
version = "2.8.2"
packages = [
{ include = "databudgie", from = "src" },
]
Expand Down
43 changes: 25 additions & 18 deletions src/databudgie/cli/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Callable, Optional, Union
from __future__ import annotations

from typing import Callable

import click
import sqlalchemy
import sqlalchemy.engine.url
import sqlalchemy.orm
import strapp.click

from databudgie.config import BackupConfig, RestoreConfig, RootConfig
from databudgie.config import BackupConfig, Connection, RestoreConfig, RootConfig
from databudgie.output import Console

version = getattr(sqlalchemy, "__version__", "")
Expand All @@ -19,17 +21,22 @@
create_url = sqlalchemy.engine.url.URL


def _create_postgres_session(config: Union[BackupConfig, RestoreConfig], connection_name: Optional[str] = None):
if connection_name:
connection = config.connections.get(connection_name)
if connection is None:
raise click.UsageError(f"Connection '{connection_name}' not found")
else:
connection = config.connection
if connection is None:
raise click.UsageError("No config found for 'url' field. Either a 'connection' or a 'url' is required")

def _create_postgres_session(config: BackupConfig | RestoreConfig):
connection: str | Connection = config.connection
if isinstance(config.connection, str):
if config.connection in config.connections:
# `connection` can either be a connection's name, or a literal connection string. We need
# to only overwrite the `connection` value if there is a correspondingly named connection
connection = config.connections[config.connection]
else:
raise click.UsageError(
f"'{config.connection}' did not resolve to a connection. 'url' or 'connection' must "
"resolve to a named connection or connection details."
)

assert isinstance(connection, Connection)
url = connection.url

if isinstance(url, dict):
url_obj = create_url(**url)
else:
Expand All @@ -51,18 +58,18 @@ def restore_config(root_config: RootConfig, console: Console):
return root_config.restore


def backup_db(backup_config: BackupConfig, connection_name: Optional[str] = None):
return _create_postgres_session(backup_config, connection_name)
def backup_db(backup_config: BackupConfig):
return _create_postgres_session(backup_config)


def restore_db(restore_config: RestoreConfig, connection_name: Optional[str] = None):
return _create_postgres_session(restore_config, connection_name)
def restore_db(restore_config: RestoreConfig):
return _create_postgres_session(restore_config)


def backup_manifest(backup_config: BackupConfig, backup_db):
from databudgie.manifest.manager import BackupManifest

table_name: Optional[str] = backup_config.manifest
table_name: str | None = backup_config.manifest
if table_name:
return BackupManifest(backup_db, table_name)
return None
Expand All @@ -71,7 +78,7 @@ def backup_manifest(backup_config: BackupConfig, backup_db):
def restore_manifest(restore_config: RestoreConfig, restore_db):
from databudgie.manifest.manager import RestoreManifest

table_name: Optional[str] = restore_config.manifest
table_name: str | None = restore_config.manifest
if table_name:
return RestoreManifest(restore_db, table_name)
return None
Expand Down
2 changes: 1 addition & 1 deletion src/databudgie/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def cli(
location=location,
adapter=adapter,
strict=strict,
connection=conn,
)

try:
Expand All @@ -131,7 +132,6 @@ def cli(
root_config=root_config,
verbosity=verbose,
console=Console(verbosity=verbose),
connection_name=conn,
dry_run=bool(dry_run),
stats=stats if stats is not None else dry_run,
)
Expand Down
1 change: 1 addition & 0 deletions src/databudgie/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class CliConfig(DatabudgieConfig):
location: str | None = None
adapter: str | None = None
strict: bool | None = None
connection: str | None = None

def to_dict(self) -> dict:
config = asdict(self)
Expand Down
25 changes: 18 additions & 7 deletions src/databudgie/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ def from_stack(cls, stack: ConfigStack):
@dataclass
class TableParentConfig(typing.Generic[T], Config):
tables: list[T]
connections: dict[str, Connection]
connections: dict[str, Connection | str]
ddl: DDLConfig

connection: Connection | None = None
connection: Connection | str = "default"
manifest: str | None = None

s3: S3Config | None = None
Expand All @@ -145,7 +145,7 @@ def get_child_class(cls):

@classmethod
def from_stack(cls, stack: ConfigStack):
connection = Connection.from_raw(stack.get("url") or stack.get("connection"), name="default")
connection = Connection.from_raw(stack.get("url") or stack.get("connection"), name="default") or "default"
root_location = stack.get("root_location")

tables_config: list = normalize_table_config(stack.get("tables", []))
Expand Down Expand Up @@ -182,11 +182,14 @@ class Connection(Config):
url: str | dict

@classmethod
def from_raw(cls, raw: str | dict | None, *, name: str | None = None):
def from_raw(cls, raw: str | dict | None, *, name: str | None = None) -> Connection | str | None:
if raw is None:
return None

if isinstance(raw, str):
if "://" not in raw:
return raw

if name is None:
raise ConfigError(f"Connection '{raw}' requires a name")
return cls(name="default", url=raw)
Expand All @@ -202,7 +205,7 @@ def from_raw(cls, raw: str | dict | None, *, name: str | None = None):
return cls(name=name or "default", url=url)

@classmethod
def from_collection(cls, collection: list | dict | None) -> dict[str, Connection]:
def from_collection(cls, collection: list | dict | None) -> dict[str, Connection | str]:
if collection is None:
return {}

Expand All @@ -211,7 +214,10 @@ def from_collection(cls, collection: list | dict | None) -> dict[str, Connection
names = set()
for c in collection:
connection = Connection.from_raw(c, name=c.get("name"))
assert connection is not None
if not isinstance(connection, Connection):
raise ConfigError(
"Connections must be a database connection string or a mapping of individual connection fields."
)

if connection.name in names:
raise ConfigError(f"Detected more than one connection with the same name: {connection.name}")
Expand All @@ -220,7 +226,12 @@ def from_collection(cls, collection: list | dict | None) -> dict[str, Connection

return {c.name: c for c in connections}

return {k: Connection.from_raw(c, name=k) for k, c in collection.items()}
result = {}
for k, c in collection.items():
connection = Connection.from_raw(c, name=k)
if connection is not None:
result[k] = connection
return result


@dataclass
Expand Down
24 changes: 24 additions & 0 deletions tests/cli/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import click
import pytest
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import text

Expand Down Expand Up @@ -29,3 +31,25 @@ def test_create_postgres_session_url_components(pg_engine):
config = BackupConfig.from_stack(ConfigStack({"url": url_parts}))
session = _create_postgres_session(config)
session.execute(text("select 1"))


def test_connection_selection(pg_engine):
url = pg_engine.pmr_credentials.as_sqlalchemy_url()

try:
url_str = url.render_as_string(hide_password=False)
except AttributeError:
url_str = str(url)

assert url_str.startswith("postgresql+psycopg2://")
config = BackupConfig.from_stack(ConfigStack({"connections": {"foo": url_str}, "connection": "foo"}))

session = _create_postgres_session(config)
session.execute(text("select 1"))


def test_missing_connection():
config = BackupConfig.from_stack(ConfigStack({"connections": {}, "connection": "foo"}))

with pytest.raises(click.UsageError):
_create_postgres_session(config)
4 changes: 2 additions & 2 deletions tests/cli/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def run_command(command, assert_exit_code=0):
@pytest.mark.parametrize("command", ("backup", "restore"))
def test_no_default_file_warns_of_no_url(command):
result = run_command(command, assert_exit_code=2)
assert "No config found for 'url' field" in result.output
assert "did not resolve to a connection" in result.output


class TestConfigCommand:
Expand All @@ -43,7 +43,7 @@ def test_cli_args_pass_through_to_config(self):

for part in ["backup", "restore"]:
config_part = config[part]
assert config_part["connection"] == {"name": "default", "url": "foo"}
assert config_part["connection"] == "foo"
assert config_part["ddl"]["enabled"] is True
assert config_part["tables"][0]["location"] == "bar"
assert config_part["tables"][0]["name"] == "baz"
Expand Down
36 changes: 29 additions & 7 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from databudgie.config import ConfigStack, RootConfig
from databudgie.config import ConfigStack, Connection, RootConfig


def test_only_leaf_values():
Expand Down Expand Up @@ -119,15 +119,15 @@ def test_root_level_tables():
}
)

assert config.backup.connection.url == "backup_url"
assert config.backup.connection == "backup_url"
assert config.backup.tables[0].name == "root_table_1"
assert config.backup.tables[0].query == "root_query"
assert config.backup.tables[0].location == "root_location"
assert config.backup.tables[1].name == "root_table_2"
assert config.backup.tables[1].query == "root_query"
assert config.backup.tables[1].location == "root_location"

assert config.restore.connection.url == "restore_url"
assert config.restore.connection == "restore_url"
assert config.restore.tables[0].name == "root_table_1"
assert config.restore.tables[0].strategy == "root_strategy"
assert config.restore.tables[0].location == "root_location"
Expand All @@ -151,7 +151,7 @@ def test_tables_as_just_strings():
}
)

assert config.backup.connection.url == "root_url"
assert config.backup.connection == "root_url"
assert config.backup.tables[0].name == "root_table_1"
assert config.backup.tables[0].query == "root_query"
assert config.backup.tables[0].location == "root_location"
Expand All @@ -174,7 +174,7 @@ def test_tables_mixed_str_dict():
}
)

assert config.backup.connection.url == "root_url"
assert config.backup.connection == "root_url"
assert config.backup.tables[0].name == "table_1"
assert config.backup.tables[0].query == "root_query"
assert config.backup.tables[0].location == "root_location"
Expand Down Expand Up @@ -216,8 +216,8 @@ def test_configs_stack():
)
config = RootConfig.from_stack(config_stack)

assert config.backup.connection.url == "root_url"
assert config.restore.connection.url == "restore url"
assert config.backup.connection == "root_url"
assert config.restore.connection == "restore url"

assert config.backup.tables[0].name == "1"
assert config.backup.tables[0].query == "bar"
Expand Down Expand Up @@ -296,3 +296,25 @@ def test_parent_ddl_enabled():
assert config.backup.tables[0].ddl is False
assert config.restore.tables[0].name == "backup_table_1"
assert config.restore.tables[0].ddl is False


def test_connection_strings():
"""Assert connection strings are parsed correctly."""

config = RootConfig.from_dict({"connection": None})
assert config.backup.connection == "default"

config = RootConfig.from_dict({"connection": "example"})
assert config.backup.connection == "example"

config = RootConfig.from_dict({"connection": "dialect://foo"})
assert isinstance(config.backup.connection, Connection)
assert config.backup.connection.url == "dialect://foo"

config = RootConfig.from_dict({"connection": {"url": "foo"}})
assert isinstance(config.backup.connection, Connection)
assert config.backup.connection.url == "foo"

config = RootConfig.from_dict({"connection": {"dialect": "postgres"}})
assert isinstance(config.backup.connection, Connection)
assert config.backup.connection.url == {"dialect": "postgres"}

0 comments on commit f238c30

Please sign in to comment.