Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revisit API to allow arbitrary keyword arguments to connect() #202

Merged
merged 1 commit into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions beanquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@
__version__ = '0.1.dev1'


def connect(dsn=None):
return Connection(dsn)
def connect(dsn, **kwargs):
return Connection(dsn, **kwargs)


class Connection:
def __init__(self, dsn=None):
def __init__(self, dsn=None, **kwargs):
self.tables = {'': tables.NullTable()}
self.options = {}
self.errors = []
if dsn is not None:
self.attach(dsn)
self.attach(dsn, **kwargs)

def attach(self, dsn):
def attach(self, dsn, **kwargs):
scheme = urlparse(dsn).scheme
source = importlib.import_module(f'beanquery.sources.{scheme}')
source.attach(self, dsn)
source.attach(self, dsn, **kwargs)

def close(self):
# Required by the DB-API.
Expand Down
28 changes: 11 additions & 17 deletions beanquery/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
__copyright__ = "Copyright (C) 2015-2017 Martin Blais"
__license__ = "GNU GPLv2"

from beanquery import Connection
from beanquery import numberify as numberify_lib
from beanquery.sources.beancount import add_beancount_tables
import beanquery
import beanquery.numberify

def run_query(entries, options_map, query, *format_args, numberify=False):

def run_query(entries, options, query, *args, numberify=False):
"""Compile and execute a query, return the result types and rows.
Args:
entries: A list of entries, as produced by the loader.
options_map: A dict of options, as produced by the loader.
options: A dict of options, as produced by the loader.
query: A string, a single BQL query, optionally containing some new-style
(e.g., {}) formatting specifications.
format_args: A tuple of arguments to be formatted in the query. This is
args: A tuple of arguments to be formatted in the query. This is
just provided as a convenience.
numberify: If true, numberify the results before returning them.
Returns:
Expand All @@ -25,21 +25,15 @@ def run_query(entries, options_map, query, *format_args, numberify=False):
CompilationError: If the statement cannot be compiled.
"""

# Register tables.
ctx = Connection()
add_beancount_tables(ctx, entries, [], options_map)

# Apply formatting to the query.
formatted_query = query.format(*format_args)

# Execute it to obtain the result rows.
curs = ctx.execute(formatted_query)
# Execute the query.
ctx = beanquery.connect('beancount:', entries=entries, errors=[], options=options)
curs = ctx.execute(query.format(*args))
rrows = curs.fetchall()
rtypes = curs.description

# Numberify the results, if requested.
if numberify:
dformat = options_map['dcontext'].build()
rtypes, rrows = numberify_lib.numberify_results(rtypes, rrows, dformat)
dformat = options['dcontext'].build()
rtypes, rrows = beanquery.numberify.numberify_results(rtypes, rrows, dformat)

return rtypes, rrows
15 changes: 5 additions & 10 deletions beanquery/query_execute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,29 @@
from beancount.parser import cmptest
from beancount import loader

import beanquery

from beanquery import CompilationError
from beanquery import query_compile as qc
from beanquery import query_env as qe
from beanquery import query_execute as qx
from beanquery import tables

from beanquery import Connection
from beanquery.sources.beancount import add_beancount_tables


class QueryBase(cmptest.TestCase):
INPUT = ""
maxDiff = 8192

def setUp(self):
entries, errors, options = loader.load_string(textwrap.dedent(self.INPUT))
self.ctx = Connection()
add_beancount_tables(self.ctx, entries, errors, options)
self.ctx = beanquery.connect('beancount:', entries=entries, errors=errors, options=options)

def compile(self, query):
return self.ctx.compile(self.ctx.parse(query))

def check_query(self, input_string, query, expected_types, expected_rows):
entries, errors, options = loader.load_string(input_string)
ctx = Connection()
add_beancount_tables(ctx, entries, errors, options)

ctx = beanquery.connect('beancount:', entries=entries, errors=errors, options=options)
curs = ctx.execute(query)
self.assertEqual(tuple(expected_types), curs.description)
result_rows = curs.fetchall()
Expand Down Expand Up @@ -1399,8 +1395,7 @@ def setUp(self):
super().setUp()
entries, errors, options = loader.load_string(self.data, dedent=True)
self.assertFalse(errors)
self.ctx = Connection()
add_beancount_tables(self.ctx, entries, errors, options)
self.ctx = beanquery.connect('beancount:', entries=entries, errors=errors, options=options)

def execute(self, query):
curs = self.ctx.execute(query)
Expand Down
3 changes: 1 addition & 2 deletions beanquery/shell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from beancount.utils import test_utils

from beanquery import shell
from beanquery.sources.beancount import add_beancount_tables


@functools.lru_cache(None)
Expand Down Expand Up @@ -124,7 +123,7 @@ def run_shell_command(cmd):
with test_utils.capture('stdout') as stdout, test_utils.capture('stderr') as stderr:
shell_obj = shell.BQLShell(None, sys.stdout)
entries, errors, options = load()
add_beancount_tables(shell_obj.context, entries, errors, options)
shell_obj.context.attach('beancount:', entries=entries, errors=errors, options=options)
shell_obj._extract_queries(entries) # pylint: disable=protected-access
shell_obj.onecmd(cmd)
return stdout.getvalue(), stderr.getvalue()
Expand Down
11 changes: 4 additions & 7 deletions beanquery/sources/beancount.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@
TABLES = [query_env.EntriesTable, query_env.PostingsTable]


def add_beancount_tables(context, entries, errors, options):
def attach(context, dsn, entries=None, errors=None, options=None):
filename = urlparse(dsn).path
if filename:
entries, errors, options = loader.load_file(filename)
for table in TABLES:
context.tables[table.name] = table(entries, options)
context.options.update(options)
context.errors.extend(errors)


def attach(context, dsn):
filename = urlparse(dsn).path
entries, errors, options = loader.load_file(filename)
add_beancount_tables(context, entries, errors, options)


class Metadata(dict):
pass

Expand Down
Loading