Skip to content

Commit

Permalink
Add support for column property in version models
Browse files Browse the repository at this point in the history
Copy the column_property() defined in the parent model to the
version model
  • Loading branch information
AbdealiLoKo committed Aug 30, 2022
1 parent 38cef84 commit 4348533
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 21 deletions.
23 changes: 23 additions & 0 deletions sqlalchemy_continuum/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from functools import wraps

import sqlalchemy as sa
from sqlalchemy_continuum.expression_reflector import VersionExpressionReflector
from sqlalchemy_continuum.utils import is_table_column
from sqlalchemy_utils.functions import get_declarative_base

from .dialects.postgresql import create_versioning_trigger_listeners
Expand Down Expand Up @@ -189,6 +191,7 @@ def configure_versioned_classes(self):
self.build_relationships(pending_classes_copies)
self.enable_active_history(pending_classes_copies)
self.create_column_aliases(pending_classes_copies)
self.create_column_properties(pending_classes_copies)

def enable_active_history(self, version_classes):
"""
Expand Down Expand Up @@ -220,3 +223,23 @@ def create_column_aliases(self, version_classes):
continue

version_class_mapper.add_property(key, sa.orm.column_property(version_class_column))

def create_column_properties(self, version_classes):
"""
Create equivalent column_property() on the version class (as it is on the parent model)
This does not handle the simple column aliases - just expressions
"""
for cls in version_classes:
model_mapper = sa.inspect(cls)
version_class = self.manager.version_class_map.get(cls)
if not version_class:
continue

version_class_mapper = sa.inspect(version_class)
reflector = VersionExpressionReflector()
for key, column in model_mapper.columns.items():
if is_table_column(column): # We ignore simple table columns
continue
version_column = reflector(column)
version_class_mapper.add_property(key, sa.orm.column_property(version_column))
39 changes: 27 additions & 12 deletions sqlalchemy_continuum/expression_reflector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@


class VersionExpressionReflector(sa.sql.visitors.ReplacingCloningVisitor):
def __init__(self, parent, relationship):
self.parent = parent
self.relationship = relationship

"""Take an expression and convert the columns to the version_table's columns"""
def replace(self, column):
if not isinstance(column, sa.Column):
return
Expand All @@ -18,16 +15,34 @@ def replace(self, column):
reflected_column = column
else:
reflected_column = table.c[column.name]
if (
column in self.relationship.local_columns and
table == self.parent.__table__
):
reflected_column = bindparam(
column.key,
getattr(self.parent, column.key)
)

return reflected_column

def __call__(self, expr):
return self.traverse(expr)


class RelationshipPrimaryJoinReflector(VersionExpressionReflector):
"""
Takes a relationship and modifies it to handle the primaryjoin of the relationship
"""
def __init__(self, parent, relationship):
self.parent = parent
self.relationship = relationship

def replace(self, column):
reflected_column = super().replace(column)
if reflected_column is None:
return

if (
column in self.relationship.local_columns and
reflected_column.table == self.parent.__table__
):
# Keep the columns from the self.parent.__table__ as is
reflected_column = bindparam(
column.key,
getattr(self.parent, column.key)
)

return reflected_column
12 changes: 6 additions & 6 deletions sqlalchemy_continuum/relationship_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sqlalchemy as sa

from .exc import ClassNotVersioned
from .expression_reflector import VersionExpressionReflector
from .expression_reflector import RelationshipPrimaryJoinReflector
from .operation import Operation
from .table_builder import TableBuilder
from .utils import adapt_columns, version_class, option
Expand Down Expand Up @@ -46,7 +46,7 @@ def one_to_many_subquery(self, obj):

def many_to_one_subquery(self, obj):
tx_column = option(obj, 'transaction_column_name')
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)

return getattr(self.remote_cls, tx_column) == (
sa.select(
Expand Down Expand Up @@ -93,7 +93,7 @@ def criteria(self, obj):
elif direction.name == 'MANYTOONE':
return self.many_to_one_criteria(obj)
else:
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)
return reflector(self.property.primaryjoin)

def many_to_many_criteria(self, obj):
Expand Down Expand Up @@ -171,7 +171,7 @@ def many_to_one_criteria(self, obj):
AND operation_type != 2
"""
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)
return sa.and_(
reflector(self.property.primaryjoin),
self.many_to_one_subquery(obj),
Expand Down Expand Up @@ -209,7 +209,7 @@ def one_to_many_criteria(self, obj):
)
"""
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)
return sa.and_(
reflector(self.property.primaryjoin),
self.one_to_many_subquery(obj),
Expand Down Expand Up @@ -263,7 +263,7 @@ def association_subquery(self, obj):
tx_column = option(obj, 'transaction_column_name')
join_column = self.property.primaryjoin.right.name
object_join_column = self.property.primaryjoin.left.name
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)

association_table_alias = self.association_version_table.alias()
association_cols = [
Expand Down
7 changes: 4 additions & 3 deletions sqlalchemy_continuum/version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sqlalchemy as sa

from .reverter import Reverter
from .utils import get_versioning_manager, is_internal_column, parent_class
from .utils import get_versioning_manager, is_internal_column, is_table_column, parent_class


class VersionClassBase(object):
Expand Down Expand Up @@ -52,8 +52,9 @@ def changeset(self):
previous_version = self.previous
data = {}

for key in sa.inspect(self.__class__).columns.keys():
if is_internal_column(self, key):
for key, column in sa.inspect(self.__class__).columns.items():
if is_internal_column(self, key) or not is_table_column(column):
# Ignore internal columns and column_property() which are expressions
continue
if not previous_version:
old = None
Expand Down
13 changes: 13 additions & 0 deletions tests/builders/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,16 @@ def test_builds_relationship(self):

def test_parent_has_access_to_versioning_manager(self):
assert self.Article.__versioning_manager__


def test_column_properties(self):
article = self.Article()
article.name = u'Name'
article.content = u'Content'
article.description = u'Desc'
self.session.add(article)
self.session.commit()

article_version = article.versions[0]
assert article.fulltext_content == article.name + article.content + article.description
assert article.fulltext_content == article_version.fulltext_content

0 comments on commit 4348533

Please sign in to comment.