Skip to content

Commit

Permalink
feat: 'alter' via cli should work now!
Browse files Browse the repository at this point in the history
  • Loading branch information
robinvandernoord committed Jul 31, 2023
1 parent fb4d826 commit a734a8c
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 20 deletions.
10 changes: 9 additions & 1 deletion pytest_examples/magic_post.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# this is the changed file
from pydal import DAL

my_method = SomeClass()['first']

db = DAL("sqlite://my_real_database")

db.define_table(
"person",
Field(
Expand All @@ -16,4 +19,9 @@

db.define_table("empty")

db_type = "sqlite"

db.define_table("new_table",
Field("new_field")
)

db_type = "psql"
7 changes: 6 additions & 1 deletion pytest_examples/magic_pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@
default=my_method.new_uuid()
),
Field("birthday", "datetime", default=datetime.datetime.utcnow()),
Field("removed")
)

db.define_table("old_table",
Field("old_field")
)

db.define_table("empty")

db_type = "sqlite"
db_type = "psql"
33 changes: 28 additions & 5 deletions src/pydal2sql/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Annotated, Optional

import typer
from configuraptor import Singleton
from rich import print
from typer import Argument
from typing_extensions import Never
Expand All @@ -16,7 +17,7 @@
find_git_root,
get_absolute_path_info,
get_file_for_version,
handle_cli_create,
handle_cli,
)
from .typer_support import (
DEFAULT_VERBOSITY,
Expand Down Expand Up @@ -101,7 +102,8 @@ def create(

text = get_file_for_version(file_absolute_path, file_version, prompt_description="table definition")

return handle_cli_create(
return handle_cli(
"",
text,
db_type=db_type.value if db_type else None,
tables=config.tables,
Expand All @@ -117,6 +119,12 @@ def alter(
filename_before: OptionalArgument[str] = None,
filename_after: OptionalArgument[str] = None,
db_type: DB_Types = None,
tables: Annotated[
Optional[list[str]],
typer.Option("--table", "--tables", "-t", help="One or more table names, default is all tables."),
] = None,
magic: Optional[bool] = None,
noop: Optional[bool] = None,
) -> bool:
"""
Todo: docs
Expand All @@ -137,6 +145,8 @@ def alter(
"""
git_root = find_git_root() or Path(os.getcwd())

config = state.update_config(magic=magic, noop=noop, tables=tables)

before, after = extract_file_versions_and_paths(filename_before, filename_after)

version_before, filename_before = before
Expand All @@ -150,7 +160,8 @@ def alter(
if not (before_exists and after_exists):
message = ""
message += "" if before_exists else f"Path {filename_before} does not exist! "
message += "" if after_exists else f"Path {filename_after} does not exist!"
if filename_before != filename_after:
message += "" if after_exists else f"Path {filename_after} does not exist!"
raise ValueError(message)

code_before = get_file_for_version(
Expand All @@ -167,8 +178,15 @@ def alter(
if code_before == code_after:
raise ValueError("Both contain the same code!")

print(len(code_before), len(code_after), db_type)
return True
return handle_cli(
code_before,
code_after,
db_type=db_type.value if db_type else None,
tables=config.tables,
verbose=state.verbosity > Verbosity.normal,
noop=config.noop,
magic=config.magic,
)


"""
Expand Down Expand Up @@ -226,6 +244,11 @@ def main(
version: display current version?
"""
if state.config:
# if a config already exists, it's outdated, so we clear it.
# only really applicable in Pytest scenarios where multiple commands are executed after eachother
Singleton.clear(state.config)

state.load_config(config_file=config, verbosity=verbosity)

if show_config:
Expand Down
87 changes: 78 additions & 9 deletions src/pydal2sql/cli_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from git.repo import Repo

from .helpers import flatten
from .magic import find_missing_variables, generate_magic_code
from .magic import (
find_defined_variables,
find_missing_variables,
generate_magic_code,
remove_specific_variables,
)


def has_stdin_data() -> bool: # pragma: no cover
Expand Down Expand Up @@ -153,7 +158,12 @@ def extract_file_version_and_path(
def extract_file_versions_and_paths(
filename_before: Optional[str], filename_after: Optional[str]
) -> tuple[tuple[str, str | None], tuple[str, str | None]]:
version_before, filepath_before = extract_file_version_and_path(filename_before, default_version="latest")
version_before, filepath_before = extract_file_version_and_path(
filename_before,
default_version="current"
if filename_after and filename_before and filename_after != filename_before
else "latest",
)
version_after, filepath_after = extract_file_version_and_path(filename_after, default_version="current")

if not (filepath_before or filepath_after):
Expand Down Expand Up @@ -192,8 +202,37 @@ def get_absolute_path_info(filename: Optional[str], version: str, git_root: Opti
return exists, absolute_path


def handle_cli_create(
code: str,
def ensure_no_migrate_on_real_db(
code: str, db_names: typing.Iterable[str] = ("db", "database"), fix: typing.Optional[bool] = False
) -> str:
variables = find_defined_variables(code)

found_variables = set()

for db_name in db_names:
if db_name in variables:
if fix:
code = remove_specific_variables(code, db_names)
else:
found_variables.add(db_name)

if found_variables:
if len(found_variables) == 1:
var = next(iter(found_variables))
message = f"Variable {var} defined in code! "
else: # pragma: no cover
var = ", ".join(found_variables)
message = f"Variables {var} defined in code! "
raise ValueError(
f"{message} Please remove this or use --magic to prevent performing actual migrations on your database."
)

return code


def handle_cli(
code_before: str,
code_after: str,
db_type: Optional[str] = None,
tables: Optional[list[str] | list[list[str]]] = None,
verbose: Optional[bool] = False,
Expand All @@ -203,6 +242,8 @@ def handle_cli_create(
"""
Handle user input.
"""
# todo: prefix (e.g. public.)

to_execute = string.Template(
textwrap.dedent(
"""
Expand All @@ -219,19 +260,46 @@ def handle_cli_create(
$extra
$code
$code_before
db_old = db
db_new = db = database = DAL(None, migrate=False)
$code_after
if not tables:
tables = set(db_old._tables + db_new._tables)
if not tables:
tables = db._tables
print("No tables found!", file=sys.stderr)
print("Please use `db.define_table` or `database.define_table`, \
or if you really need to use an alias like my_db.define_tables, \
add `my_db = db` at the top of the file or pass `--db-name mydb`.")
for table in tables:
print(generate_sql(db[table], db_type=db_type))
print('--', table)
if table in db_old and table in db_new:
print(generate_sql(db_old[table], db_new[table], db_type=db_type))
elif table in db_old:
print(f'DROP TABLE {table};')
else:
print(generate_sql(db_new[table], db_type=db_type))
"""
)
)

code_before = ensure_no_migrate_on_real_db(code_before, fix=magic)
code_after = ensure_no_migrate_on_real_db(code_after, fix=magic)

generated_code = to_execute.substitute(
{"tables": flatten(tables or []), "db_type": db_type or "", "code": textwrap.dedent(code), "extra": ""}
{
"tables": flatten(tables or []),
"db_type": db_type or "",
"code_before": textwrap.dedent(code_before),
"code_after": textwrap.dedent(code_after),
"extra": "",
}
)
if verbose or noop:
rich.print(generated_code, file=sys.stderr)
Expand All @@ -255,7 +323,8 @@ def handle_cli_create(
"tables": flatten(tables or []),
"db_type": db_type or "",
"extra": extra_code,
"code": textwrap.dedent(code),
"code_before": textwrap.dedent(code_before),
"code_after": textwrap.dedent(code_after),
}
)

Expand Down
4 changes: 4 additions & 0 deletions src/pydal2sql/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ def generate_alter_statement(
fake_migrate=True,
)

if not sql_log.exists():
# no changes!
return ""

with sql_log.open() as f:
for line in f:
if not line.startswith(("ALTER", "UPDATE")):
Expand Down
55 changes: 55 additions & 0 deletions src/pydal2sql/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,61 @@ def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], N
traverse_ast(child, variable_collector)


def find_defined_variables(code_str: str) -> set[str]:
tree: ast.Module = ast.parse(code_str)

variables: set[str] = set()

def collect_definitions(node: ast.AST) -> None:
if isinstance(node, ast.Assign):
node_targets = typing.cast(list[ast.Name], node.targets)

variables.update(target.id for target in node_targets)

traverse_ast(tree, collect_definitions)
return variables


def remove_specific_variables(code: str, to_remove: typing.Iterable[str] = ("db", "database")) -> str:
# Parse the code into an Abstract Syntax Tree (AST) - by ChatGPT
tree = ast.parse(code)

# Function to check if a variable name is 'db' or 'database'
def should_remove(var_name: str) -> bool:
return var_name in to_remove

# Function to recursively traverse the AST and remove definitions of 'db' or 'database'
def remove_db_and_database_defs_rec(node: ast.AST) -> typing.Optional[ast.AST]:
if isinstance(node, ast.Assign):
# Check if the assignment targets contain 'db' or 'database'
new_targets = [
target for target in node.targets if not (isinstance(target, ast.Name) and should_remove(target.id))
]
node.targets = new_targets

elif isinstance(node, (ast.FunctionDef, ast.ClassDef)) and should_remove(node.name):
# Check if function or class name is 'db' or 'database'
return None

for child_node in ast.iter_child_nodes(node):
# Recursively process child nodes
new_child_node = remove_db_and_database_defs_rec(child_node)
if new_child_node is None and hasattr(node, "body"):
# If the child node was removed, remove it from the parent's body
node.body.remove(child_node)

return node

# Traverse the AST to remove 'db' and 'database' definitions
new_tree = remove_db_and_database_defs_rec(tree)

if not new_tree: # pragma: no cover
return ""

# Generate the modified code from the new AST
return ast.unparse(new_tree)


def find_variables(code_str: str) -> tuple[set[str], set[str]]:
"""
Look through the source code in code_str and try to detect using ast parsing which variables are undefined.
Expand Down
18 changes: 15 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@ def test_cli_create():
assert result.exit_code == 1
assert "could not be found" in result.stderr

result = runner.invoke(app, ["create", "magic.py"])
result = runner.invoke(app, ["create", "magic.py@latest"])
assert result.exit_code == 1
assert not result.stdout
assert "missing some variables" in result.stderr

result = runner.invoke(app, ["create", "magic.py@current"])
assert result.exit_code == 1
assert not result.stdout
assert "db defined in code!" in result.stderr

result = runner.invoke(app, ["create", "magic.py", "--magic"])
assert result.exit_code == 0
assert not result.stderr
Expand All @@ -42,9 +47,11 @@ def test_cli_create():

def test_cli_alter():
with mock_git():
result = runner.invoke(app, ["alter", "missing.py"])
result = runner.invoke(app, ["alter", "missing.py", "missing2.py"])
assert result.exit_code == 1
assert "does not exist" in result.stderr
assert "missing.py" in result.stderr
assert "missing2.py" in result.stderr

Path("empty.py").touch()

Expand All @@ -56,8 +63,13 @@ def test_cli_alter():
assert result.exit_code == 1
assert "contain the same code" in result.stderr

result = runner.invoke(app, ["alter", "magic.py"])
result = runner.invoke(app, ["alter", "magic.py", "--magic"])
print('++ stdout', result.stdout)
print('++ stderr', result.stderr)
assert result.exit_code == 0
assert result.stdout



def test_cli_version():
result = runner.invoke(app, ["--version"])
Expand Down
Loading

0 comments on commit a734a8c

Please sign in to comment.