Skip to content

Commit

Permalink
run the drop database awaitable
Browse files Browse the repository at this point in the history
  • Loading branch information
DataDaoDe committed Feb 1, 2025
1 parent a7f84e4 commit d43629c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
7 changes: 6 additions & 1 deletion advanced_alchemy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ async def _create_database_wrapper() -> None:
@bind_key_option
@no_prompt_option
def drop_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
from anyio import run
from rich.prompt import Confirm

from advanced_alchemy.utils.databases import drop_database as _drop_database
Expand All @@ -450,6 +451,10 @@ def drop_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ig
True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to drop database `{db_name}`?[/]")
)
if input_confirmed:
_drop_database(sqlalchemy_config)

async def _drop_database_wrapper() -> None:
await _drop_database(sqlalchemy_config)

run(_drop_database_wrapper)

return database_group
9 changes: 3 additions & 6 deletions advanced_alchemy/utils/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,13 @@ def _disconnect_users_sql(self, version: tuple[int, int] | None, database: str |

def get_adapter(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> Adapter:
dialect_name = config.get_engine().url.get_dialect().name
driver = config.get_engine().url.get_dialect().driver

adapter_class = ADAPTERS.get(dialect_name)

if not adapter_class:
msg = f"No adapter available for {dialect_name}"
raise ValueError(msg)

driver = config.get_engine().url.get_dialect().driver
if driver not in adapter_class.supported_drivers:
msg = f"{dialect_name} adapter does not support the {driver} driver"
raise ValueError(msg)
Expand All @@ -154,17 +153,15 @@ def get_adapter(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding:

async def create_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig, encoding: str = "utf8") -> None:
adapter = get_adapter(config, encoding)
engine = config.get_engine()
if isinstance(engine, AsyncEngine):
if isinstance(config.get_engine(), AsyncEngine):
await adapter.create_async()
else:
adapter.create()


async def drop_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig) -> None:
adapter = get_adapter(config)
engine = config.get_engine()
if isinstance(engine, AsyncEngine):
if isinstance(config.get_engine(), AsyncEngine):
await adapter.drop_async()
else:
adapter.drop()

0 comments on commit d43629c

Please sign in to comment.