diff --git a/pytest_examples/magic_post.py b/pytest_examples/magic_post.py index c843f29..2947afa 100644 --- a/pytest_examples/magic_post.py +++ b/pytest_examples/magic_post.py @@ -1,7 +1,10 @@ # this is the changed file +from pydal import DAL my_method = SomeClass()['first'] +db = DAL("sqlite://my_real_database") + db.define_table( "person", Field( @@ -16,4 +19,9 @@ db.define_table("empty") -db_type = "sqlite" + +db.define_table("new_table", + Field("new_field") + ) + +db_type = "psql" diff --git a/pytest_examples/magic_pre.py b/pytest_examples/magic_pre.py index 254ab57..bf09169 100644 --- a/pytest_examples/magic_pre.py +++ b/pytest_examples/magic_pre.py @@ -11,8 +11,13 @@ default=my_method.new_uuid() ), Field("birthday", "datetime", default=datetime.datetime.utcnow()), + Field("removed") ) +db.define_table("old_table", + Field("old_field") + ) + db.define_table("empty") -db_type = "sqlite" \ No newline at end of file +db_type = "psql" diff --git a/src/pydal2sql/cli.py b/src/pydal2sql/cli.py index 43b70bc..5c61397 100644 --- a/src/pydal2sql/cli.py +++ b/src/pydal2sql/cli.py @@ -5,6 +5,7 @@ from typing import Annotated, Optional import typer +from configuraptor import Singleton from rich import print from typer import Argument from typing_extensions import Never @@ -16,7 +17,7 @@ find_git_root, get_absolute_path_info, get_file_for_version, - handle_cli_create, + handle_cli, ) from .typer_support import ( DEFAULT_VERBOSITY, @@ -101,7 +102,8 @@ def create( text = get_file_for_version(file_absolute_path, file_version, prompt_description="table definition") - return handle_cli_create( + return handle_cli( + "", text, db_type=db_type.value if db_type else None, tables=config.tables, @@ -117,6 +119,12 @@ def alter( filename_before: OptionalArgument[str] = None, filename_after: OptionalArgument[str] = None, db_type: DB_Types = None, + tables: Annotated[ + Optional[list[str]], + typer.Option("--table", "--tables", "-t", help="One or more table names, default is all tables."), + ] = None, + magic: Optional[bool] = None, + noop: Optional[bool] = None, ) -> bool: """ Todo: docs @@ -137,6 +145,8 @@ def alter( """ git_root = find_git_root() or Path(os.getcwd()) + config = state.update_config(magic=magic, noop=noop, tables=tables) + before, after = extract_file_versions_and_paths(filename_before, filename_after) version_before, filename_before = before @@ -150,7 +160,8 @@ def alter( if not (before_exists and after_exists): message = "" message += "" if before_exists else f"Path {filename_before} does not exist! " - message += "" if after_exists else f"Path {filename_after} does not exist!" + if filename_before != filename_after: + message += "" if after_exists else f"Path {filename_after} does not exist!" raise ValueError(message) code_before = get_file_for_version( @@ -167,8 +178,15 @@ def alter( if code_before == code_after: raise ValueError("Both contain the same code!") - print(len(code_before), len(code_after), db_type) - return True + return handle_cli( + code_before, + code_after, + db_type=db_type.value if db_type else None, + tables=config.tables, + verbose=state.verbosity > Verbosity.normal, + noop=config.noop, + magic=config.magic, + ) """ @@ -226,6 +244,11 @@ def main( version: display current version? """ + if state.config: + # if a config already exists, it's outdated, so we clear it. + # only really applicable in Pytest scenarios where multiple commands are executed after eachother + Singleton.clear(state.config) + state.load_config(config_file=config, verbosity=verbosity) if show_config: diff --git a/src/pydal2sql/cli_support.py b/src/pydal2sql/cli_support.py index 96761cc..791bed6 100644 --- a/src/pydal2sql/cli_support.py +++ b/src/pydal2sql/cli_support.py @@ -19,7 +19,12 @@ from git.repo import Repo from .helpers import flatten -from .magic import find_missing_variables, generate_magic_code +from .magic import ( + find_defined_variables, + find_missing_variables, + generate_magic_code, + remove_specific_variables, +) def has_stdin_data() -> bool: # pragma: no cover @@ -153,7 +158,12 @@ def extract_file_version_and_path( def extract_file_versions_and_paths( filename_before: Optional[str], filename_after: Optional[str] ) -> tuple[tuple[str, str | None], tuple[str, str | None]]: - version_before, filepath_before = extract_file_version_and_path(filename_before, default_version="latest") + version_before, filepath_before = extract_file_version_and_path( + filename_before, + default_version="current" + if filename_after and filename_before and filename_after != filename_before + else "latest", + ) version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current") if not (filepath_before or filepath_after): @@ -192,8 +202,37 @@ def get_absolute_path_info(filename: Optional[str], version: str, git_root: Opti return exists, absolute_path -def handle_cli_create( - code: str, +def ensure_no_migrate_on_real_db( + code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False +) -> str: + variables = find_defined_variables(code) + + found_variables = set() + + for db_name in db_names: + if db_name in variables: + if fix: + code = remove_specific_variables(code, db_names) + else: + found_variables.add(db_name) + + if found_variables: + if len(found_variables) == 1: + var = next(iter(found_variables)) + message = f"Variable {var} defined in code! " + else: # pragma: no cover + var = ", ".join(found_variables) + message = f"Variables {var} defined in code! " + raise ValueError( + f"{message} Please remove this or use --magic to prevent performing actual migrations on your database." + ) + + return code + + +def handle_cli( + code_before: str, + code_after: str, db_type: Optional[str] = None, tables: Optional[list[str] | list[list[str]]] = None, verbose: Optional[bool] = False, @@ -203,6 +242,8 @@ def handle_cli_create( """ Handle user input. """ + # todo: prefix (e.g. public.) + to_execute = string.Template( textwrap.dedent( """ @@ -219,19 +260,46 @@ def handle_cli_create( $extra - $code + $code_before + + db_old = db + db_new = db = database = DAL(None, migrate=False) + + $code_after + + if not tables: + tables = set(db_old._tables + db_new._tables) if not tables: - tables = db._tables + print("No tables found!", file=sys.stderr) + print("Please use `db.define_table` or `database.define_table`, \ + or if you really need to use an alias like my_db.define_tables, \ + add `my_db = db` at the top of the file or pass `--db-name mydb`.") + for table in tables: - print(generate_sql(db[table], db_type=db_type)) + print('--', table) + if table in db_old and table in db_new: + print(generate_sql(db_old[table], db_new[table], db_type=db_type)) + elif table in db_old: + print(f'DROP TABLE {table};') + else: + print(generate_sql(db_new[table], db_type=db_type)) """ ) ) + code_before = ensure_no_migrate_on_real_db(code_before, fix=magic) + code_after = ensure_no_migrate_on_real_db(code_after, fix=magic) + generated_code = to_execute.substitute( - {"tables": flatten(tables or []), "db_type": db_type or "", "code": textwrap.dedent(code), "extra": ""} + { + "tables": flatten(tables or []), + "db_type": db_type or "", + "code_before": textwrap.dedent(code_before), + "code_after": textwrap.dedent(code_after), + "extra": "", + } ) if verbose or noop: rich.print(generated_code, file=sys.stderr) @@ -255,7 +323,8 @@ def handle_cli_create( "tables": flatten(tables or []), "db_type": db_type or "", "extra": extra_code, - "code": textwrap.dedent(code), + "code_before": textwrap.dedent(code_before), + "code_after": textwrap.dedent(code_after), } ) diff --git a/src/pydal2sql/core.py b/src/pydal2sql/core.py index 16ef09a..f9f2535 100644 --- a/src/pydal2sql/core.py +++ b/src/pydal2sql/core.py @@ -170,6 +170,10 @@ def generate_alter_statement( fake_migrate=True, ) + if not sql_log.exists(): + # no changes! + return "" + with sql_log.open() as f: for line in f: if not line.startswith(("ALTER", "UPDATE")): diff --git a/src/pydal2sql/magic.py b/src/pydal2sql/magic.py index 3c63746..431d080 100644 --- a/src/pydal2sql/magic.py +++ b/src/pydal2sql/magic.py @@ -20,6 +20,61 @@ def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], N traverse_ast(child, variable_collector) +def find_defined_variables(code_str: str) -> set[str]: + tree: ast.Module = ast.parse(code_str) + + variables: set[str] = set() + + def collect_definitions(node: ast.AST) -> None: + if isinstance(node, ast.Assign): + node_targets = typing.cast(list[ast.Name], node.targets) + + variables.update(target.id for target in node_targets) + + traverse_ast(tree, collect_definitions) + return variables + + +def remove_specific_variables(code: str, to_remove: typing.Iterable[str] = ("db", "database")) -> str: + # Parse the code into an Abstract Syntax Tree (AST) - by ChatGPT + tree = ast.parse(code) + + # Function to check if a variable name is 'db' or 'database' + def should_remove(var_name: str) -> bool: + return var_name in to_remove + + # Function to recursively traverse the AST and remove definitions of 'db' or 'database' + def remove_db_and_database_defs_rec(node: ast.AST) -> typing.Optional[ast.AST]: + if isinstance(node, ast.Assign): + # Check if the assignment targets contain 'db' or 'database' + new_targets = [ + target for target in node.targets if not (isinstance(target, ast.Name) and should_remove(target.id)) + ] + node.targets = new_targets + + elif isinstance(node, (ast.FunctionDef, ast.ClassDef)) and should_remove(node.name): + # Check if function or class name is 'db' or 'database' + return None + + for child_node in ast.iter_child_nodes(node): + # Recursively process child nodes + new_child_node = remove_db_and_database_defs_rec(child_node) + if new_child_node is None and hasattr(node, "body"): + # If the child node was removed, remove it from the parent's body + node.body.remove(child_node) + + return node + + # Traverse the AST to remove 'db' and 'database' definitions + new_tree = remove_db_and_database_defs_rec(tree) + + if not new_tree: # pragma: no cover + return "" + + # Generate the modified code from the new AST + return ast.unparse(new_tree) + + def find_variables(code_str: str) -> tuple[set[str], set[str]]: """ Look through the source code in code_str and try to detect using ast parsing which variables are undefined. diff --git a/tests/test_cli.py b/tests/test_cli.py index 31ac9ee..779922f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -18,11 +18,16 @@ def test_cli_create(): assert result.exit_code == 1 assert "could not be found" in result.stderr - result = runner.invoke(app, ["create", "magic.py"]) + result = runner.invoke(app, ["create", "magic.py@latest"]) assert result.exit_code == 1 assert not result.stdout assert "missing some variables" in result.stderr + result = runner.invoke(app, ["create", "magic.py@current"]) + assert result.exit_code == 1 + assert not result.stdout + assert "db defined in code!" in result.stderr + result = runner.invoke(app, ["create", "magic.py", "--magic"]) assert result.exit_code == 0 assert not result.stderr @@ -42,9 +47,11 @@ def test_cli_create(): def test_cli_alter(): with mock_git(): - result = runner.invoke(app, ["alter", "missing.py"]) + result = runner.invoke(app, ["alter", "missing.py", "missing2.py"]) assert result.exit_code == 1 assert "does not exist" in result.stderr + assert "missing.py" in result.stderr + assert "missing2.py" in result.stderr Path("empty.py").touch() @@ -56,8 +63,13 @@ def test_cli_alter(): assert result.exit_code == 1 assert "contain the same code" in result.stderr - result = runner.invoke(app, ["alter", "magic.py"]) + result = runner.invoke(app, ["alter", "magic.py", "--magic"]) + print('++ stdout', result.stdout) + print('++ stderr', result.stderr) assert result.exit_code == 0 + assert result.stdout + + def test_cli_version(): result = runner.invoke(app, ["--version"]) diff --git a/tests/test_magic.py b/tests/test_magic.py index c37c34f..93437f7 100644 --- a/tests/test_magic.py +++ b/tests/test_magic.py @@ -1,4 +1,6 @@ -from src.pydal2sql.magic import find_missing_variables, generate_magic_code +import textwrap + +from src.pydal2sql.magic import find_missing_variables, generate_magic_code, remove_specific_variables def test_find_missing(): @@ -41,3 +43,18 @@ def test_fix_missing(): assert "empty = Empty()" in code assert "bla = empty" in code + +def test_remove_specific_variables(): + code = textwrap.dedent(""" + db = 1 + def database(): + return True + + my_database = 'exists' + print('hi') + """) + new_code = remove_specific_variables(code) + assert "print('hi')" in new_code + assert 'db' not in new_code + assert 'def database' not in new_code + assert 'my_database' in new_code