Skip to content

Commit

Permalink
fix: use current() instead of approved_up_to_transaction() in codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
nboyse committed Nov 13, 2023
1 parent 7f4df66 commit 894a04e
Show file tree
Hide file tree
Showing 42 changed files with 138 additions and 241 deletions.
10 changes: 4 additions & 6 deletions additional_codes/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,11 @@ def clean(self):
def save(self, commit=True):
instance = super().save(commit=False)

tx = WorkBasket.get_current_transaction(self.request)

highest_sid = (
models.AdditionalCode.objects.approved_up_to_transaction(tx).aggregate(
Max("sid"),
)["sid__max"]
) or 0
models.AdditionalCode.objects.current().aggregate(
Max("sid"),
)["sid__max"]
) or 0
instance.sid = highest_sid + 1

if commit:
Expand Down
9 changes: 3 additions & 6 deletions additional_codes/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ class AdditionalCodeViewSet(viewsets.ReadOnlyModelViewSet):
filter_backends = [AdditionalCodeFilterBackend]

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return (
AdditionalCode.objects.approved_up_to_transaction(tx)
AdditionalCode.objects.current()
.select_related("type")
.prefetch_related("descriptions")
)
Expand All @@ -60,8 +59,7 @@ class AdditionalCodeMixin:
model: Type[TrackedModel] = AdditionalCode

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return AdditionalCode.objects.approved_up_to_transaction(tx).select_related(
return AdditionalCode.objects.current().select_related(
"type",
)

Expand Down Expand Up @@ -180,8 +178,7 @@ class AdditionalCodeDescriptionMixin:
model: Type[TrackedModel] = AdditionalCodeDescription

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return AdditionalCodeDescription.objects.approved_up_to_transaction(tx)
return AdditionalCodeDescription.objects.current()


class AdditionalCodeDescriptionCreate(
Expand Down
8 changes: 3 additions & 5 deletions certificates/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def __init__(self, *args, **kwargs):

def filter_certificates_for_sid(self, sid):
certificate_type = self.cleaned_data["certificate_type"]
tx = WorkBasket.get_current_transaction(self.request)
return models.Certificate.objects.approved_up_to_transaction(tx).filter(
return models.Certificate.objects.current().filter(
sid=sid,
certificate_type=certificate_type,
)
Expand All @@ -64,14 +63,13 @@ def next_sid(self, instance):
form's save() method (with its commit param set either to True or
False).
"""
current_transaction = WorkBasket.get_current_transaction(self.request)
# Filter certificate by type and find the highest sid, using regex to
# ignore legacy, non-numeric identifiers
return get_next_id(
models.Certificate.objects.filter(
models.Certificate.objects.current().filter(
sid__regex=r"^[0-9]*$",
certificate_type__sid=instance.certificate_type.sid,
).approved_up_to_transaction(current_transaction),
),
instance._meta.get_field("sid"),
max_len=3,
)
Expand Down
9 changes: 3 additions & 6 deletions certificates/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ class CertificatesViewSet(viewsets.ReadOnlyModelViewSet):
permission_classes = [permissions.IsAuthenticated]

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return (
models.Certificate.objects.approved_up_to_transaction(tx)
models.Certificate.objects.current()
.select_related("certificate_type")
.prefetch_related("descriptions")
)
Expand All @@ -53,8 +52,7 @@ class CertificateMixin:
model: Type[TrackedModel] = models.Certificate

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return models.Certificate.objects.approved_up_to_transaction(tx).select_related(
return models.Certificate.objects.current().select_related(
"certificate_type",
)

Expand Down Expand Up @@ -171,8 +169,7 @@ class CertificateDescriptionMixin:
model: Type[TrackedModel] = models.CertificateDescription

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return models.CertificateDescription.objects.approved_up_to_transaction(tx)
return models.CertificateDescription.objects.current()


class CertificateCreateDescriptionMixin:
Expand Down
31 changes: 10 additions & 21 deletions commodities/business_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@ def __init__(self, transaction=None):
self.logger = logging.getLogger(type(self).__name__)

def parent_spans_child(self, parent, child) -> bool:
parent_validity = parent.indented_goods_nomenclature.version_at(
self.transaction,
).valid_between
child_validity = child.indented_goods_nomenclature.version_at(
self.transaction,
).valid_between
parent_validity = parent.indented_goods_nomenclature.version_at().valid_between
child_validity = child.indented_goods_nomenclature.version_at().valid_between
return validity_range_contains_range(parent_validity, child_validity)

def parents_span_childs_future(self, parents, child):
Expand All @@ -59,17 +55,13 @@ def parents_span_childs_future(self, parents, child):
parents_validity = []
for parent in parents:
parents_validity.append(
parent.indented_goods_nomenclature.version_at(
self.transaction,
).valid_between,
parent.indented_goods_nomenclature.version_at().valid_between,
)

# sort by start date so any gaps will be obvious
parents_validity.sort(key=lambda daterange: daterange.lower)

child_validity = child.indented_goods_nomenclature.version_at(
self.transaction,
).valid_between
child_validity = child.indented_goods_nomenclature.version_at().valid_between

if (
not child_validity.upper_inf
Expand Down Expand Up @@ -108,7 +100,7 @@ def validate(self, indent):
from commodities.models.dc import get_chapter_collection

try:
good = indent.indented_goods_nomenclature.version_at(self.transaction)
good = indent.indented_goods_nomenclature.version_at()
except TrackedModel.DoesNotExist:
self.logger.warning(
"Goods nomenclature %s no longer exists at transaction %s "
Expand Down Expand Up @@ -166,10 +158,9 @@ def validate(self, good):

if not (
good.code.is_chapter
or GoodsNomenclatureOrigin.objects.filter(
or GoodsNomenclatureOrigin.objects.current().filter(
new_goods_nomenclature__sid=good.sid,
)
.approved_up_to_transaction(good.transaction)
.exists()
):
raise self.violation(
Expand Down Expand Up @@ -252,9 +243,9 @@ class NIG11(ValidityStartDateRules):
def get_objects(self, good):
GoodsNomenclatureIndent = good.indents.model

return GoodsNomenclatureIndent.objects.filter(
return GoodsNomenclatureIndent.objects.current().filter(
indented_goods_nomenclature__sid=good.sid,
).approved_up_to_transaction(self.transaction)
)


class NIG12(DescriptionsRules):
Expand Down Expand Up @@ -305,7 +296,7 @@ def validate(self, association):
goods_nomenclature__sid=association.goods_nomenclature.sid,
valid_between__overlap=association.valid_between,
)
.approved_up_to_transaction(association.transaction)
.current()
.exclude(
id=association.pk,
)
Expand Down Expand Up @@ -351,9 +342,7 @@ def has_violation(self, good):
goods_nomenclature__sid=good.sid,
additional_code__isnull=False,
)
.approved_up_to_transaction(
self.transaction,
)
.current()
.exists()
)

Expand Down
2 changes: 1 addition & 1 deletion commodities/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def init_fields(self):
].help_text = "Leave empty if the footnote is needed for an unlimited time"
self.fields[
"associated_footnote"
].queryset = Footnote.objects.approved_up_to_transaction(self.tx).filter(
].queryset = Footnote.objects.current().filter(
footnote_type__application_code__in=[1, 2],
)
self.fields[
Expand Down
8 changes: 3 additions & 5 deletions commodities/models/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def get_dependent_measures(
measure_qs = Measure.objects.filter(goods_sid_query)

if self.moment.clock_type.is_transaction_clock:
measure_qs = measure_qs.approved_up_to_transaction(self.moment.transaction)
measure_qs = measure_qs.current()
else:
measure_qs = measure_qs.latest_approved()

Expand Down Expand Up @@ -823,7 +823,7 @@ def get_snapshot(
date=snapshot_date,
)

commodities = self._get_snapshot_commodities(transaction, snapshot_date)
commodities = self._get_snapshot_commodities(snapshot_date)

return CommodityTreeSnapshot(
moment=moment,
Expand All @@ -832,7 +832,6 @@ def get_snapshot(

def _get_snapshot_commodities(
self,
transaction: Transaction,
snapshot_date: date,
) -> List[Commodity]:
"""
Expand All @@ -853,8 +852,7 @@ def _get_snapshot_commodities(
that match the latest_version goods.
"""
item_ids = {c.item_id for c in self.commodities if c.obj}
goods = GoodsNomenclature.objects.approved_up_to_transaction(
transaction,
goods = GoodsNomenclature.objects.current(
).filter(
item_id__in=item_ids,
valid_between__contains=snapshot_date,
Expand Down
6 changes: 3 additions & 3 deletions commodities/models/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def get_url(self):
return reverse("commodity-ui-detail", kwargs={"sid": self.sid})

def get_dependent_measures(self, transaction=None):
return self.measures.model.objects.filter(
return self.measures.model.objects.current().filter(
goods_nomenclature__sid=self.sid,
).approved_up_to_transaction(transaction)
)

@property
def is_taric_code(self) -> bool:
Expand Down Expand Up @@ -200,7 +200,7 @@ def get_good_indents(
) -> QuerySet:
"""Return the related goods indents based on approval status."""
good = self.indented_goods_nomenclature
return good.indents.approved_up_to_transaction(
return good.indents.current(
as_of_transaction or self.transaction,
)

Expand Down
7 changes: 2 additions & 5 deletions commodities/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_commodity_list_queryset():
good_1 = factories.SimpleGoodsNomenclatureFactory.create(item_id="1010000000")
good_2 = factories.SimpleGoodsNomenclatureFactory.create(item_id="1000000000")
tx = Transaction.objects.last()
commodity_count = GoodsNomenclature.objects.approved_up_to_transaction(tx).count()
commodity_count = GoodsNomenclature.objects.current().count()
with override_current_transaction(tx):
qs = view.get_queryset()

Expand Down Expand Up @@ -520,11 +520,8 @@ def test_commodity_footnote_update_success(valid_user_client, date_ranges):
"end_date": "",
}
response = valid_user_client.post(url, data)
tx = Transaction.objects.last()
updated_association = (
FootnoteAssociationGoodsNomenclature.objects.approved_up_to_transaction(
tx,
).first()
FootnoteAssociationGoodsNomenclature.objects.current().first()
)
assert response.status_code == 302
assert response.url == updated_association.get_url("confirm-update")
Expand Down
9 changes: 2 additions & 7 deletions commodities/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def get_queryset(self):
"""
tx = WorkBasket.get_current_transaction(self.request)
return (
GoodsNomenclature.objects.approved_up_to_transaction(
tx,
)
GoodsNomenclature.objects.current()
.prefetch_related("descriptions")
.as_at_and_beyond(date.today())
.filter(suffix=80)
Expand All @@ -69,10 +67,7 @@ class FootnoteAssociationMixin:
model = FootnoteAssociationGoodsNomenclature

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return FootnoteAssociationGoodsNomenclature.objects.approved_up_to_transaction(
tx,
)
return self.model.objects.current()


class CommodityList(CommodityMixin, WithPaginationListView):
Expand Down
10 changes: 4 additions & 6 deletions common/business_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_linked_models(
related_instances = [getattr(model, field.name)]
for instance in related_instances:
try:
yield instance.version_at(transaction)
yield instance.version_at()
except TrackedModel.DoesNotExist:
# `related_instances` will contain all instances, even
# deleted ones, and `version_at` will return
Expand Down Expand Up @@ -278,8 +278,7 @@ def validate(self, model):

if (
type(model)
.objects.filter(**query)
.approved_up_to_transaction(self.transaction)
.objects.current().filter(**query)
.exclude(version_group=model.version_group)
.exists()
):
Expand All @@ -305,8 +304,7 @@ def validate(self, model):
query["valid_between__overlap"] = model.valid_between

if (
model.__class__.objects.filter(**query)
.approved_up_to_transaction(self.transaction)
model.__class__.objects.current().filter(**query)
.exclude(version_group=model.version_group)
.exists()
):
Expand Down Expand Up @@ -573,7 +571,7 @@ def validate(self, exclusion):
Membership = geo_group._meta.get_field("members").related_model

if (
not Membership.objects.approved_up_to_transaction(self.transaction)
not Membership.objects.current()
.filter(
geo_group__sid=geo_group.sid,
member__sid=exclusion.excluded_geographical_area.sid,
Expand Down
3 changes: 2 additions & 1 deletion common/models/tracked_qs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def current(self) -> TrackedModelQuerySet:
)

def approved_up_to_transaction(self, transaction=None) -> TrackedModelQuerySet:
"""Get the approved versions of the model being queried, unless there
"""This function is called using the current() function instead of directly calling it on model queries.
Get the approved versions of the model being queried, unless there
exists a version of the model in a draft state within a transaction
preceding (and including) the given transaction in the workbasket of the
given transaction."""
Expand Down
Loading

0 comments on commit 894a04e

Please sign in to comment.