From c6499cfc0b1b4c272e63a92aaa07d5b97041859d Mon Sep 17 00:00:00 2001 From: Jeremy Howard Date: Sun, 2 Jun 2024 02:42:08 +1000 Subject: [PATCH] fixes #4 --- fastlite/core.py | 2 +- fastlite/kw.py | 16 +++++----- nbs/00_core.ipynb | 2 +- nbs/index.ipynb | 78 ++++++++++++++++++++++++++++++++++++----------- 4 files changed, 69 insertions(+), 29 deletions(-) diff --git a/fastlite/core.py b/fastlite/core.py index 2da6280..e8b4f55 100644 --- a/fastlite/core.py +++ b/fastlite/core.py @@ -25,7 +25,7 @@ def __init__(self, db): self.db = db def __repr__(self): return ", ".join(dir(self)) def __contains__(self, s): return (s if isinstance(s,str) else s.name) in dir(self) def __getitem__(self, idxs): - if isinstance(idxs,str): idxs = [idxs] + if isinstance(idxs,str): return self.db.table(idxs) return [self.db.table(o) for o in idxs] def __getattr__(self, k): if k[0]=='_': raise AttributeError diff --git a/fastlite/kw.py b/fastlite/kw.py index 573ec99..b0d8660 100644 --- a/fastlite/kw.py +++ b/fastlite/kw.py @@ -75,18 +75,16 @@ def transform_sql( drop_foreign_keys=drop_foreign_keys, add_foreign_keys=add_foreign_keys, foreign_keys=foreign_keys, column_order=column_order, keep_table=keep_table) + @patch -def update( - self:Table, - pk_values: Union[list, tuple, str, int, float], - updates: Any = None, - alter: bool = False, - conversions: Optional[dict] = None, - **kwargs) -> Table: +def update(self:Table, updates: dict|None=None, pk_values: list|tuple|str|int|float|None=None, + alter: bool=False, conversions: dict|None=None, **kwargs): if not updates: updates={} if is_dataclass(updates): updates = asdict(updates) updates = {**updates, **kwargs} - self._orig_update(pk_values=pk_values, updates=updates, alter=alter, conversions=conversions) + if not pk_values: pk_values = [updates[o] for o in self.pks] + self._orig_update(pk_values, updates=updates, alter=alter, conversions=conversions) + return self.get(self.last_pk) @patch @@ -121,7 +119,7 @@ def insert( @patch def upsert( self:Table, - record: Dict[str, Any]=None, + record:Any=None, pk=DEFAULT, foreign_keys=DEFAULT, column_order: Optional[Union[List[str], Default]] = DEFAULT, diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index 3175070..0b91a56 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -78,7 +78,7 @@ " def __repr__(self): return \", \".join(dir(self))\n", " def __contains__(self, s): return (s if isinstance(s,str) else s.name) in dir(self)\n", " def __getitem__(self, idxs):\n", - " if isinstance(idxs,str): idxs = [idxs]\n", + " if isinstance(idxs,str): return self.db.table(idxs)\n", " return [self.db.table(o) for o in idxs]\n", " def __getattr__(self, k):\n", " if k[0]=='_': raise AttributeError\n", diff --git a/nbs/index.ipynb b/nbs/index.ipynb index a582c85..cb9e727 100644 --- a/nbs/index.ipynb +++ b/nbs/index.ipynb @@ -594,6 +594,41 @@ "- `lookup`" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can access a table that doesn't actually exist yet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cats = dt.cats\n", + "cats" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can use keyword arguments to now create that table:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -620,7 +655,6 @@ } ], "source": [ - "cats = dt.cats\n", "cats.create(id=int, name=str, weight=float, pk='id')\n", "hl_md(cats.schema, 'sql')" ] @@ -672,7 +706,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Using `**` in upsert here doesn't actually achieve anything, since we can just pass a `dict` directly -- it's just to show that it works:" + "Using `**` in `update` here doesn't actually achieve anything, since we can just pass a `dict` directly -- it's just to show that it works:" ] }, { @@ -693,7 +727,7 @@ ], "source": [ "cat['name'] = \"moo\"\n", - "cats.upsert(**cat)\n", + "cats.update(**cat)\n", "cats()" ] }, @@ -712,7 +746,7 @@ { "data": { "text/plain": [ - "[Cats(id=1, name='foo', weight=6.0)]" + "Cats(id=1, name='moo', weight=6.0)" ] }, "execution_count": null, @@ -721,8 +755,28 @@ } ], "source": [ - "catdc = cats.dataclass()\n", + "cats.dataclass()\n", "cat = cats.get(1)\n", + "cat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Cats(id=1, name='foo', weight=6.0)]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ "cat.name = 'foo'\n", "cats.upsert(cat)\n", "cats()" @@ -1256,21 +1310,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "python3", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" } }, "nbformat": 4,