Skip to content

Commit

Permalink
fixes #6
Browse files Browse the repository at this point in the history
  • Loading branch information
jph00 committed Jun 3, 2024
1 parent 82b9d8f commit fdeccce
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions fastlite/kw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,37 @@ def xtra(self:Table, **kwargs):
@patch
def get_last(self:Table, as_cls:bool=True):
row = first(self.rows_where('_rowid_=?', (self.last_rowid,)))
assert row, f"Couldn't find {self.last_rowid}"
vals = [row[pk] for pk in self.pks]
self.last_pk = vals[0] if len(vals)==1 else vals
if as_cls and hasattr(self,'cls'): row = self.cls(**row)
return row

@patch
def ids_and_rows_where(
self:Table,
where: Optional[str] = None,
where_args: Optional[Union[Iterable, dict]] = None,
order_by: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Generator[Tuple[Any, Dict], None, None]:
"""
Like ``.rows_where()`` but returns ``(rowid, row)`` pairs.
:param where: SQL where fragment to use, for example ``id > ?``
:param where_args: Parameters to use with that fragment - an iterable for ``id > ?``
parameters, or a dictionary for ``id > :id``
:param order_by: Column or fragment of SQL to order by
:param select: Comma-separated list of columns to select - defaults to ``*``
:param limit: Integer number of rows to limit to
:param offset: Integer for SQL offset
"""
cs = [c.name for c in self.columns]
select = ",".join("[{}]".format(c) for c in cs)
select = "_rowid_ as __rid, " + select
for row in self.rows_where(select=select, where=where, where_args=where_args, order_by=order_by, limit=limit, offset=offset):
yield row.pop('__rid'), row

@patch
def get(self:Table, pk_values: list|tuple|str|int, as_cls:bool=True)->Any:
Expand All @@ -28,14 +54,13 @@ def get(self:Table, pk_values: list|tuple|str|int, as_cls:bool=True)->Any:
xtra = getattr(self, 'xtra_id', {})
vals = pk_values + list(xtra.values())
pks = self.pks + list(xtra.keys())
if len(pks) != len(vals):
raise NotFoundError( "Need {} primary key value{}".format( len(pks), "" if len(pks) == 1 else "s"))

if len(pks)!=len(vals): raise NotFoundError(f"Need {len(pks)} pk")
wheres = ["[{}] = ?".format(pk_name) for pk_name in pks]
rows = self.rows_where(" and ".join(wheres), vals)
try: row = list(rows)[0]
except IndexError: raise NotFoundError
item = first(self.ids_and_rows_where(" and ".join(wheres), vals))
if not item: raise NotFoundError
rid,row = item
self.last_pk = last_pk
self.last_rowid = rid
if as_cls and hasattr(self,'cls'): row = self.cls(**row)
return row

Expand Down

0 comments on commit fdeccce

Please sign in to comment.