Skip to content

Commit

Permalink
Implement date_bin() BQL function
Browse files Browse the repository at this point in the history
  • Loading branch information
dnicolodi committed Nov 11, 2024
1 parent 9f6461a commit 21acc2c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
43 changes: 43 additions & 0 deletions beanquery/query_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,49 @@ def interval(x):
return None


@function([relativedelta, datetime.date, datetime.date], datetime.date)
def date_bin(stride, source, origin):
"""Bin a date into the specified stride aligned with the specified origin.
As an extension to the the SQL standard ``date_bin()`` function this
function also accepts strides containing units of months and years.
"""
if stride.months or stride.years:
if origin + stride <= origin:
# FIXME: this should raise and error: stride must be greater than zero
return None
if source >= origin:
d = n = origin
while True:
n += stride
if n >= source:
return d
d = n
else:
n = origin
while True:
n -= stride
if n <= source:
return n
else:
seconds = stride.days * 86400 + stride.hours * 3600 + stride.minutes * 60 + stride.seconds
if seconds < 0:
# FIXME: this should raise and error: stride must be greater than zero
return None
diff = (source - origin).total_seconds()
modulo = diff % seconds
delta = diff - modulo
result = origin + datetime.timedelta(seconds=delta)
if modulo < 0:
result -= datetime.timedelta(seconds=seconds)
return result


@function([str, datetime.date, datetime.date], datetime.date, name='date_bin')
def date_bin_str(stride, source, origin):
return date_bin(interval(stride), source, origin)


def aggregator(intypes, name=None):
def decorator(cls):
cls.__intypes__ = intypes
Expand Down
8 changes: 8 additions & 0 deletions beanquery/query_execute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,14 @@ def test_interval_ops(self):
self.assertResult('''SELECT interval("1 baz")''', None, relativedelta)
self.assertResult('''SELECT interval("A days")''', None, relativedelta)

def test_date_bin(self):
self.assertResult('''SELECT date_bin(interval('1 year'), 2024-11-10, 2024-06-01)''', datetime.date(2024, 6, 1))
self.assertResult('''SELECT date_bin('1 year', 2024-11-10, 2024-06-01)''', datetime.date(2024, 6, 1))
self.assertResult('''SELECT date_bin('1 year', 2024-11-10, 2025-06-01)''', datetime.date(2024, 6, 1))
self.assertResult('''SELECT date_bin('1 month', 2024-11-10, 2024-06-03)''', datetime.date(2024, 11, 3))
self.assertResult('''SELECT date_bin('3 days', 2024-11-10, 2024-11-02)''', datetime.date(2024, 11, 8))
self.assertResult('''SELECT date_bin('3 days', 2024-11-10, 2024-11-14)''', datetime.date(2024, 11, 8))


class TestBeancountFunctions(QueryBase):
INPUT = textwrap.dedent("""
Expand Down

0 comments on commit 21acc2c

Please sign in to comment.