Skip to content

Commit

Permalink
Merge pull request #31 from AnswerDotAI/tests
Browse files Browse the repository at this point in the history
correct tests for and `Enum|None` behavior
  • Loading branch information
jph00 authored Oct 5, 2024
2 parents 04f8204 + 48b54c6 commit 4fd4870
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
13 changes: 7 additions & 6 deletions fastlite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,15 @@ def _parse_typ(t): return t if not (_args:= get_args(t)) else first(_args, bool)

# %% ../nbs/00_core.ipynb 50
def _is_enum(o): return isinstance(o, type) and issubclass(o, Enum)
def _enum_types(e): return {_parse_typ(t) for t in e.__annotations__.values()}
def _enum_types(e): return {type(v.value) for v in e}

def get_typ(t):
"Get the underlying type."
t = _parse_typ(t) # incase Union[Enum,None]
if _is_enum(t) and len(types:=_enum_types(t)) == 1: return first(types)
return _parse_typ(t)
return t

# %% ../nbs/00_core.ipynb 56
# %% ../nbs/00_core.ipynb 58
@patch
def create(
self: Database,
Expand Down Expand Up @@ -182,7 +183,7 @@ def create(
res.cls = cls
return res

# %% ../nbs/00_core.ipynb 66
# %% ../nbs/00_core.ipynb 68
@patch
def import_file(self:Database, table_name, file, format=None, pk=None, alter=False):
"Import path or handle `file` to new table `table_name`"
Expand All @@ -197,7 +198,7 @@ def import_file(self:Database, table_name, file, format=None, pk=None, alter=Fal
if pk: tbl.transform(pk=pk)
return tbl

# %% ../nbs/00_core.ipynb 72
# %% ../nbs/00_core.ipynb 74
def _edge(tbl):
return "\n".join(f"{fk.table}:{fk.column} -> {fk.other_table}:{fk.other_column};"
for fk in tbl.foreign_keys)
Expand All @@ -215,7 +216,7 @@ def _tnode(tbl):
</table>"""
return f"{tbl.name} [label=<{res}>];\n"

# %% ../nbs/00_core.ipynb 73
# %% ../nbs/00_core.ipynb 75
def diagram(tbls, ratio=0.7, size="10", neato=False, render=True):
layout = "\nlayout=neato;\noverlap=prism;\noverlap_scaling=0.5;""" if neato else ""
edges = "\n".join(map(_edge, tbls))
Expand Down
49 changes: 34 additions & 15 deletions nbs/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"outputs": [],
"source": [
"from IPython.display import Markdown\n",
"from fastcore.test import test_fail"
"from fastcore.test import test_fail, test_eq"
]
},
{
Expand Down Expand Up @@ -121,7 +121,7 @@
{
"data": {
"text/plain": [
"Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track"
"Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track, cat"
]
},
"execution_count": null,
Expand Down Expand Up @@ -797,12 +797,13 @@
"source": [
"#|export\n",
"def _is_enum(o): return isinstance(o, type) and issubclass(o, Enum)\n",
"def _enum_types(e): return {_parse_typ(t) for t in e.__annotations__.values()}\n",
"def _enum_types(e): return {type(v.value) for v in e}\n",
"\n",
"def get_typ(t):\n",
" \"Get the underlying type.\"\n",
" t = _parse_typ(t) # incase Union[Enum,None]\n",
" if _is_enum(t) and len(types:=_enum_types(t)) == 1: return first(types)\n",
" return _parse_typ(t)"
" return t"
]
},
{
Expand Down Expand Up @@ -852,6 +853,16 @@
"If you have an `Enum` where all the fields are the same type, then `_get_typ` will return that type."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class _Test(Enum): foo='val1'; bar=2\n",
"class _Test2(Enum): foo='val3'; bar='val4'"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -869,9 +880,6 @@
}
],
"source": [
"class _Test(Enum): foo:str; bar:int\n",
"class _Test2(Enum): foo:str|None; bar:str\n",
"\n",
"# fields are not the same type\n",
"get_typ(_Test)"
]
Expand All @@ -897,6 +905,19 @@
"get_typ(_Test2)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"#|hide\n",
"test_eq(get_typ(int), int)\n",
"test_eq(get_typ(Union[int, None]), int)\n",
"assert get_typ(_Test) != str\n",
"test_eq(get_typ(_Test2), str)\n",
"test_eq(get_typ(Union[_Test2, None]), str)\n",
"test_eq(get_typ(_Test2|None), str)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -960,9 +981,9 @@
}
],
"source": [
"class Nm(Enum): fn:str|None='meow'; ln:str \n",
"class Nm(Enum): fn='meow'; ln='prr' \n",
" \n",
"class Cat: id: int; name:Nm; age: int|None; city: str = \"Unknown\"\n",
"class Cat: id: int; name:Nm|None; age: int|None; city: str = \"Unknown\"\n",
"cats = db.create(Cat)\n",
"Cat(1)"
]
Expand All @@ -979,8 +1000,7 @@
"CREATE TABLE [cat] (\n",
" [id] INTEGER PRIMARY KEY,\n",
" [name] TEXT,\n",
" [age] INTEGER,\n",
" [city] TEXT\n",
" [weight] FLOAT\n",
")\n"
]
}
Expand All @@ -997,7 +1017,7 @@
{
"data": {
"text/plain": [
"<Table cat (id, name, age, city)>"
"<Table cat (id, name, weight)>"
]
},
"execution_count": null,
Expand All @@ -1024,7 +1044,7 @@
{
"data": {
"text/plain": [
"<Table cat (id, name, age, city)>"
"<Table cat (id, name, weight)>"
]
},
"execution_count": null,
Expand Down Expand Up @@ -1070,8 +1090,7 @@
"CREATE TABLE [cat] (\n",
" [id] INTEGER PRIMARY KEY,\n",
" [name] TEXT,\n",
" [age] INTEGER,\n",
" [city] TEXT\n",
" [weight] FLOAT\n",
")\n"
]
}
Expand Down

0 comments on commit 4fd4870

Please sign in to comment.