diff --git a/fastlite/kw.py b/fastlite/kw.py index 1e5c511..2807693 100644 --- a/fastlite/kw.py +++ b/fastlite/kw.py @@ -15,11 +15,37 @@ def xtra(self:Table, **kwargs): @patch def get_last(self:Table, as_cls:bool=True): row = first(self.rows_where('_rowid_=?', (self.last_rowid,))) + assert row, f"Couldn't find {self.last_rowid}" vals = [row[pk] for pk in self.pks] self.last_pk = vals[0] if len(vals)==1 else vals if as_cls and hasattr(self,'cls'): row = self.cls(**row) return row +@patch +def ids_and_rows_where( + self:Table, + where: Optional[str] = None, + where_args: Optional[Union[Iterable, dict]] = None, + order_by: Optional[str] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, +) -> Generator[Tuple[Any, Dict], None, None]: + """ + Like ``.rows_where()`` but returns ``(rowid, row)`` pairs. + + :param where: SQL where fragment to use, for example ``id > ?`` + :param where_args: Parameters to use with that fragment - an iterable for ``id > ?`` + parameters, or a dictionary for ``id > :id`` + :param order_by: Column or fragment of SQL to order by + :param select: Comma-separated list of columns to select - defaults to ``*`` + :param limit: Integer number of rows to limit to + :param offset: Integer for SQL offset + """ + cs = [c.name for c in self.columns] + select = ",".join("[{}]".format(c) for c in cs) + select = "_rowid_ as __rid, " + select + for row in self.rows_where(select=select, where=where, where_args=where_args, order_by=order_by, limit=limit, offset=offset): + yield row.pop('__rid'), row @patch def get(self:Table, pk_values: list|tuple|str|int, as_cls:bool=True)->Any: @@ -28,14 +54,13 @@ def get(self:Table, pk_values: list|tuple|str|int, as_cls:bool=True)->Any: xtra = getattr(self, 'xtra_id', {}) vals = pk_values + list(xtra.values()) pks = self.pks + list(xtra.keys()) - if len(pks) != len(vals): - raise NotFoundError( "Need {} primary key value{}".format( len(pks), "" if len(pks) == 1 else "s")) - + if len(pks)!=len(vals): raise NotFoundError(f"Need {len(pks)} pk") wheres = ["[{}] = ?".format(pk_name) for pk_name in pks] - rows = self.rows_where(" and ".join(wheres), vals) - try: row = list(rows)[0] - except IndexError: raise NotFoundError + item = first(self.ids_and_rows_where(" and ".join(wheres), vals)) + if not item: raise NotFoundError + rid,row = item self.last_pk = last_pk + self.last_rowid = rid if as_cls and hasattr(self,'cls'): row = self.cls(**row) return row