-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
474 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from functools import cache | ||
from typing import Optional, Protocol | ||
|
||
from registry.entity import Entity | ||
from registry.schema import StorageDriver | ||
|
||
|
||
class Driver(Protocol): | ||
def __init__(self, dsn: str) -> None: | ||
... | ||
|
||
async def find( | ||
self, | ||
entity: type[Entity], | ||
queries: list[dict], | ||
limit: Optional[int] = None, | ||
) -> list[dict]: | ||
raise NotImplementedError() | ||
|
||
async def find_one( | ||
self, | ||
entity: type[Entity], | ||
queries: list[dict], | ||
) -> Optional[dict]: | ||
rows = await self.find(entity, queries, limit=1) | ||
if len(rows): | ||
return rows[0] | ||
|
||
return None | ||
|
||
async def find_or_create( | ||
self, entity: type[Entity], query: dict, data: dict | ||
) -> dict: | ||
result = await self.find(entity, [query]) | ||
if len(result): | ||
return result[0] | ||
|
||
return await self.insert(entity, data) | ||
|
||
async def find_or_fail( | ||
self, | ||
entity: type[Entity], | ||
queries: list[dict], | ||
) -> dict: | ||
instance = await self.find_one(entity, queries) | ||
if not instance: | ||
raise LookupError(f'{entity.__name__} not found') | ||
|
||
return instance | ||
|
||
async def init_schema(self, entity: type[Entity]) -> None: | ||
raise NotImplementedError() | ||
|
||
async def insert(self, entity: type[Entity], data: dict) -> dict: | ||
raise NotImplementedError() | ||
|
||
|
||
@cache | ||
def get_driver(driver: StorageDriver, dsn: str) -> Driver: | ||
return get_implementation(driver)(dsn) | ||
|
||
|
||
@cache | ||
def get_implementation(driver: StorageDriver) -> type[Driver]: | ||
if driver is StorageDriver.MEMORY: | ||
from registry.drivers.memory import MemoryDriver | ||
return MemoryDriver | ||
|
||
if driver is StorageDriver.TARANTOOL: | ||
from registry.drivers.tarantool import TarantoolDriver | ||
return TarantoolDriver | ||
|
||
raise NotImplementedError(f'{driver} driver not implemented') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Optional | ||
|
||
from registry.drivers import Driver | ||
from registry.entity import Entity | ||
|
||
|
||
class MemoryDriver(Driver): | ||
def __init__(self, dsn: str) -> None: | ||
self.data: dict[type[Entity], list[dict]] = {} | ||
|
||
async def find( | ||
self, | ||
entity: type[Entity], | ||
queries: list[dict], | ||
limit: Optional[int] = None, | ||
) -> list[dict]: | ||
await self.init_schema(entity) | ||
rows = [ | ||
row for row in self.data[entity] | ||
if await self.is_valid(row, queries) | ||
] | ||
if limit: | ||
rows = rows[0:limit] | ||
return rows | ||
|
||
async def init_schema(self, entity: type[Entity]) -> None: | ||
if entity not in self.data: | ||
self.data[entity] = [] | ||
|
||
async def insert(self, entity: type[Entity], data: dict) -> dict: | ||
await self.init_schema(entity) | ||
data['id'] = len(self.data[entity]) + 1 | ||
self.data[entity].append(data) | ||
return data | ||
|
||
async def is_valid(self, row, queries: list) -> bool: | ||
for query in queries: | ||
if False not in [ | ||
row[key] == value for (key, value) in query.items() | ||
]: | ||
return True | ||
|
||
return False |
Oops, something went wrong.