Skip to content

Commit

Permalink
tune driver api
Browse files Browse the repository at this point in the history
  • Loading branch information
nekufa committed Nov 13, 2023
1 parent b240e5c commit aa0e4f6
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 16 deletions.
57 changes: 43 additions & 14 deletions registry/drivers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Protocol
from typing import Optional, Protocol

from registry.entity import Entity
from registry.schema import StorageDriver
Expand All @@ -8,26 +8,49 @@ class Driver(Protocol):
def __init__(self, dsn: str) -> None:
...

async def create(self, entity: type[Entity], data: dict) -> dict:
raise NotImplementedError()

async def find(
self, entity: type[Entity], queries: list[dict]
self,
entity: type[Entity],
queries: list[dict],
limit: Optional[int] = None,
) -> list[dict]:
raise NotImplementedError()

async def find_one(
self,
entity: type[Entity],
queries: list[dict],
) -> Optional[dict]:
rows = await self.find(entity, queries, limit=1)
if len(rows):
return rows[0]

async def find_or_create(
self, entity: type[Entity], query: dict, data: dict
) -> dict:
result = await self.find(entity, [query])
if len(result):
return result[0]

return await self.create(entity, data)
return await self.insert(entity, data)

async def find_or_fail(
self,
entity: type[Entity],
queries: list[dict],
) -> dict:
instance = await self.find_one(entity, queries)
if not instance:
raise LookupError(f'{entity.nick} not found')

return instance

async def init_schema(self, entity: type[Entity]) -> None:
raise NotImplementedError()

async def insert(self, entity: type[Entity], data: dict) -> dict:
raise NotImplementedError()


driver_instances: dict[str, dict[str, Driver]] = {}

Expand All @@ -53,25 +76,31 @@ class MemoryDriver(Driver):
def __init__(self, dsn: str) -> None:
self.data: dict[type[Entity], list[dict]] = {}

async def create(self, entity: type[Entity], data: dict) -> dict:
await self.init_schema(entity)
data['id'] = len(self.data[entity]) + 1
self.data[entity].append(data)
return data

async def find(
self, entity: type[Entity], queries: list[dict]
self,
entity: type[Entity],
queries: list[dict],
limit: Optional[int] = None,
) -> list[dict]:
await self.init_schema(entity)
return [
rows = [
row for row in self.data[entity]
if await self.is_valid(row, queries)
]
if limit:
rows = rows[0:limit]
return rows

async def init_schema(self, entity: type[Entity]) -> None:
if entity not in self.data:
self.data[entity] = []

async def insert(self, entity: type[Entity], data: dict) -> dict:
await self.init_schema(entity)
data['id'] = len(self.data[entity]) + 1
self.data[entity].append(data)
return data

async def is_valid(self, row, queries: list) -> bool:
for query in queries:
if False not in [
Expand Down
2 changes: 1 addition & 1 deletion registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def create(
context = await self.context(entity, key)
return context.repository.make(
entity=entity,
row=await context.driver.create(
row=await context.driver.insert(
entity=entity,
data=dict(bucket_id=context.bucket.id, **data),
)
Expand Down
19 changes: 18 additions & 1 deletion tests/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,35 @@ class ActionRepository(Repository):
async def test_hello():
registry = Registry()
assert len(await registry.find(Action)) == 0

# create two actions
action1 = await registry.find_or_create(Action, {'type': 'tester'})
action2 = await registry.find_or_create(Action, {'type': 'tester2'})

# validate properties
assert action1.id == 1
assert action1.type == 'tester'
assert action1.owner_id == 0
assert action2.id == 2
assert action2.type == 'tester2'
assert action2.owner_id == 0

# identity map
action3 = await registry.find_or_create(Action, {'id': 2})
assert action3 == action2
action4 = await registry.find_or_create(Action, {'type': 'tester'})
assert action4 == action1
assert len(await registry.find(Action)) == 2

# lookup checks
assert (await registry.get_instance(Action, 2)).type == 'tester2'
assert (await registry.get_instance(Action, 3)) is None

# storage level persistence check
[storage] = registry.storages
driver = await get_driver(storage.driver, storage.dsn)
assert driver.data[Action][0]['owner_id'] == 0
assert len(await driver.find(Action, queries=[{}])) == 2

# default values peristence
first_action_dict = await driver.find_one(Action, queries=[{'id': 1}])
assert first_action_dict['owner_id'] == 0

0 comments on commit aa0e4f6

Please sign in to comment.