From 21acc2c1c15cdfbce2bba1610182d1d71e9e0be8 Mon Sep 17 00:00:00 2001 From: Daniele Nicolodi Date: Mon, 11 Nov 2024 21:28:29 +0100 Subject: [PATCH] Implement date_bin() BQL function --- beanquery/query_env.py | 43 +++++++++++++++++++++++++++++++++ beanquery/query_execute_test.py | 8 ++++++ 2 files changed, 51 insertions(+) diff --git a/beanquery/query_env.py b/beanquery/query_env.py index a9afccb..eae7fe9 100644 --- a/beanquery/query_env.py +++ b/beanquery/query_env.py @@ -681,6 +681,49 @@ def interval(x): return None +@function([relativedelta, datetime.date, datetime.date], datetime.date) +def date_bin(stride, source, origin): + """Bin a date into the specified stride aligned with the specified origin. + + As an extension to the the SQL standard ``date_bin()`` function this + function also accepts strides containing units of months and years. + """ + if stride.months or stride.years: + if origin + stride <= origin: + # FIXME: this should raise and error: stride must be greater than zero + return None + if source >= origin: + d = n = origin + while True: + n += stride + if n >= source: + return d + d = n + else: + n = origin + while True: + n -= stride + if n <= source: + return n + else: + seconds = stride.days * 86400 + stride.hours * 3600 + stride.minutes * 60 + stride.seconds + if seconds < 0: + # FIXME: this should raise and error: stride must be greater than zero + return None + diff = (source - origin).total_seconds() + modulo = diff % seconds + delta = diff - modulo + result = origin + datetime.timedelta(seconds=delta) + if modulo < 0: + result -= datetime.timedelta(seconds=seconds) + return result + + +@function([str, datetime.date, datetime.date], datetime.date, name='date_bin') +def date_bin_str(stride, source, origin): + return date_bin(interval(stride), source, origin) + + def aggregator(intypes, name=None): def decorator(cls): cls.__intypes__ = intypes diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 3ddf1f3..af8ec83 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -401,6 +401,14 @@ def test_interval_ops(self): self.assertResult('''SELECT interval("1 baz")''', None, relativedelta) self.assertResult('''SELECT interval("A days")''', None, relativedelta) + def test_date_bin(self): + self.assertResult('''SELECT date_bin(interval('1 year'), 2024-11-10, 2024-06-01)''', datetime.date(2024, 6, 1)) + self.assertResult('''SELECT date_bin('1 year', 2024-11-10, 2024-06-01)''', datetime.date(2024, 6, 1)) + self.assertResult('''SELECT date_bin('1 year', 2024-11-10, 2025-06-01)''', datetime.date(2024, 6, 1)) + self.assertResult('''SELECT date_bin('1 month', 2024-11-10, 2024-06-03)''', datetime.date(2024, 11, 3)) + self.assertResult('''SELECT date_bin('3 days', 2024-11-10, 2024-11-02)''', datetime.date(2024, 11, 8)) + self.assertResult('''SELECT date_bin('3 days', 2024-11-10, 2024-11-14)''', datetime.date(2024, 11, 8)) + class TestBeancountFunctions(QueryBase): INPUT = textwrap.dedent("""