Skip to content

Commit

Permalink
Revisit API to allow arbitrary keyword arguments to connect()
Browse files Browse the repository at this point in the history
This allows to easily pass arguments that cannot easily be serialized
as part of the data source string. In particular this allows to remove
some special casing of the beancount data source.
  • Loading branch information
dnicolodi committed Oct 19, 2024
1 parent d0e5a11 commit eaec87b
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 41 deletions.
12 changes: 6 additions & 6 deletions beanquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@
__version__ = '0.1.dev1'


def connect(dsn=None):
return Connection(dsn)
def connect(dsn, **kwargs):
return Connection(dsn, **kwargs)


class Connection:
def __init__(self, dsn=None):
def __init__(self, dsn=None, **kwargs):
self.tables = {'': tables.NullTable()}
self.options = {}
self.errors = []
if dsn is not None:
self.attach(dsn)
self.attach(dsn, **kwargs)

def attach(self, dsn):
def attach(self, dsn, **kwargs):
scheme = urlparse(dsn).scheme
source = importlib.import_module(f'beanquery.sources.{scheme}')
source.attach(self, dsn)
source.attach(self, dsn, **kwargs)

def close(self):
# Required by the DB-API.
Expand Down
26 changes: 10 additions & 16 deletions beanquery/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
__copyright__ = "Copyright (C) 2015-2017 Martin Blais"
__license__ = "GNU GPLv2"

from beanquery import Connection
from beanquery import numberify as numberify_lib
from beanquery.sources.beancount import add_beancount_tables
import beanquery
import beanquery.numberify

def run_query(entries, options_map, query, *format_args, numberify=False):

def run_query(entries, options, query, *args, numberify=False):
"""Compile and execute a query, return the result types and rows.
Args:
entries: A list of entries, as produced by the loader.
options_map: A dict of options, as produced by the loader.
options: A dict of options, as produced by the loader.
query: A string, a single BQL query, optionally containing some new-style
(e.g., {}) formatting specifications.
format_args: A tuple of arguments to be formatted in the query. This is
args: A tuple of arguments to be formatted in the query. This is
just provided as a convenience.
numberify: If true, numberify the results before returning them.
Returns:
Expand All @@ -25,21 +25,15 @@ def run_query(entries, options_map, query, *format_args, numberify=False):
CompilationError: If the statement cannot be compiled.
"""

# Register tables.
ctx = Connection()
add_beancount_tables(ctx, entries, [], options_map)

# Apply formatting to the query.
formatted_query = query.format(*format_args)

# Execute it to obtain the result rows.
curs = ctx.execute(formatted_query)
# Execute the query.
ctx = beanquery.connect('beancount:', entries=entries, errors=[], options=options)
curs = ctx.execute(query.format(*args))
rrows = curs.fetchall()
rtypes = curs.description

# Numberify the results, if requested.
if numberify:
dformat = options_map['dcontext'].build()
rtypes, rrows = numberify_lib.numberify_results(rtypes, rrows, dformat)
rtypes, rrows = beanquery.numberify.numberify_results(rtypes, rrows, dformat)

return rtypes, rrows
15 changes: 5 additions & 10 deletions beanquery/query_execute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,29 @@
from beancount.parser import cmptest
from beancount import loader

import beanquery

from beanquery import CompilationError
from beanquery import query_compile as qc
from beanquery import query_env as qe
from beanquery import query_execute as qx
from beanquery import tables

from beanquery import Connection
from beanquery.sources.beancount import add_beancount_tables


class QueryBase(cmptest.TestCase):
INPUT = ""
maxDiff = 8192

def setUp(self):
entries, errors, options = loader.load_string(textwrap.dedent(self.INPUT))
self.ctx = Connection()
add_beancount_tables(self.ctx, entries, errors, options)
self.ctx = beanquery.connect('beancount:', entries=entries, errors=errors, options=options)

def compile(self, query):
return self.ctx.compile(self.ctx.parse(query))

def check_query(self, input_string, query, expected_types, expected_rows):
entries, errors, options = loader.load_string(input_string)
ctx = Connection()
add_beancount_tables(ctx, entries, errors, options)

ctx = beanquery.connect('beancount:', entries=entries, errors=errors, options=options)
curs = ctx.execute(query)
self.assertEqual(tuple(expected_types), curs.description)
result_rows = curs.fetchall()
Expand Down Expand Up @@ -1399,8 +1395,7 @@ def setUp(self):
super().setUp()
entries, errors, options = loader.load_string(self.data, dedent=True)
self.assertFalse(errors)
self.ctx = Connection()
add_beancount_tables(self.ctx, entries, errors, options)
self.ctx = beanquery.connect('beancount:', entries=entries, errors=errors, options=options)

def execute(self, query):
curs = self.ctx.execute(query)
Expand Down
3 changes: 1 addition & 2 deletions beanquery/shell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from beancount.utils import test_utils

from beanquery import shell
from beanquery.sources.beancount import add_beancount_tables


@functools.lru_cache(None)
Expand Down Expand Up @@ -124,7 +123,7 @@ def run_shell_command(cmd):
with test_utils.capture('stdout') as stdout, test_utils.capture('stderr') as stderr:
shell_obj = shell.BQLShell(None, sys.stdout)
entries, errors, options = load()
add_beancount_tables(shell_obj.context, entries, errors, options)
shell_obj.context.attach('beancount:', entries=entries, errors=errors, options=options)
shell_obj._extract_queries(entries) # pylint: disable=protected-access
shell_obj.onecmd(cmd)
return stdout.getvalue(), stderr.getvalue()
Expand Down
11 changes: 4 additions & 7 deletions beanquery/sources/beancount.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@
TABLES = [query_env.EntriesTable, query_env.PostingsTable]


def add_beancount_tables(context, entries, errors, options):
def attach(context, dsn, entries=None, errors=None, options=None):
filename = urlparse(dsn).path
if filename:
entries, errors, options = loader.load_file(filename)
for table in TABLES:
context.tables[table.name] = table(entries, options)
context.options.update(options)
context.errors.extend(errors)


def attach(context, dsn):
filename = urlparse(dsn).path
entries, errors, options = loader.load_file(filename)
add_beancount_tables(context, entries, errors, options)


class Metadata(dict):
pass

Expand Down

0 comments on commit eaec87b

Please sign in to comment.