Skip to content

Commit

Permalink
refactor registry
Browse files Browse the repository at this point in the history
  • Loading branch information
nekufa committed Nov 3, 2023
1 parent 122825c commit 9fb16fc
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 91 deletions.
48 changes: 25 additions & 23 deletions sharded/drivers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import cache
from typing import Optional, Protocol
from typing import Protocol

from sharded.entity import Entity
from sharded.schema import StorageDriver
Expand All @@ -9,22 +8,24 @@ class Driver(Protocol):
def __init__(self, dsn: str) -> None:
...

async def create(self, name: str, data: dict) -> dict:
async def create(self, entity: type[Entity], data: dict) -> dict:
raise NotImplementedError()

async def find(self, name: str, queries: list[dict]) -> list[dict]:
async def find(
self, entity: type[Entity], queries: list[dict]
) -> list[dict]:
raise NotImplementedError()

async def find_or_create(
self, name: str, query: dict, data: dict
self, entity: type[Entity], query: dict, data: dict
) -> dict:
result = await self.find(name, [query])
result = await self.find(entity, [query])
if len(result):
return result[0]

return await self.create(name, data)
return await self.create(entity, data)

async def init_schema(self, entity: Entity) -> None:
async def init_schema(self, entity: type[Entity]) -> None:
raise NotImplementedError()


Expand All @@ -39,9 +40,8 @@ async def get_driver(driver: StorageDriver, dsn: str) -> Driver:
return driver_instances[driver][dsn]


@cache
def get_implementation(driver):
implementations: dict[StorageDriver, Driver] = {
def get_implementation(driver: StorageDriver) -> type[Driver]:
implementations: dict[StorageDriver, type[Driver]] = {
StorageDriver.MEMORY: MemoryDriver
}
if driver in implementations:
Expand All @@ -50,25 +50,27 @@ def get_implementation(driver):


class MemoryDriver(Driver):
data: Optional[dict[str, list[dict]]] = None
def __init__(self, dsn: str) -> None:
self.data: dict[type[Entity], list[dict]] = {}

async def create(self, name: str, data: dict) -> dict:
data['id'] = len(self.data[name]) + 1
self.data[name].append(data)
async def create(self, entity: type[Entity], data: dict) -> dict:
await self.init_schema(entity)
data['id'] = len(self.data[entity]) + 1
self.data[entity].append(data)
return data

async def find(self, name: str, queries: list[dict]) -> list[dict]:
async def find(
self, entity: type[Entity], queries: list[dict]
) -> list[dict]:
await self.init_schema(entity)
return [
row for row in self.data[name]
row for row in self.data[entity]
if await self.is_valid(row, queries)
]

async def init_schema(self, entity: Entity) -> None:
if not self.data:
self.data = {}

if entity.__name__ not in self.data:
self.data[entity.__name__] = []
async def init_schema(self, entity: type[Entity]) -> None:
if entity not in self.data:
self.data[entity] = []

async def is_valid(self, row, queries: list) -> bool:
for query in queries:
Expand Down
2 changes: 1 addition & 1 deletion sharded/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ class Bucket(Entity):
storage_id: int
status: BucketStatus

def __hash__(self) -> str:
def __hash__(self) -> int:
return hash((self.repository, self.key))
75 changes: 44 additions & 31 deletions sharded/gateway.py → sharded/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from functools import cache
from typing import Optional
from typing import Any, Optional

from sharded.drivers import Driver, get_driver
from sharded.entity import Bucket, Entity, Storage
Expand All @@ -13,11 +13,11 @@
class QueryContext:
bucket: Bucket
driver: Driver
entity: Entity
entity: type[Entity]
repository: Repository


class Gateway:
class Registry:
def __init__(self) -> None:
self.ready: bool = False
self.repositories: dict[type[Repository], Repository] = {
Expand All @@ -33,6 +33,12 @@ def __init__(self) -> None:
)
]

def get_repository(self, cls: type[Repository]) -> Repository:
if cls not in self.repositories:
raise LookupError(f'repository {cls} is not registered')

return self.repositories[cls]

async def bootstrap(self, source: Optional[Storage] = None):
if self.ready:
return
Expand All @@ -42,25 +48,29 @@ async def bootstrap(self, source: Optional[Storage] = None):
source = self.storages[0]

driver = await get_driver(source.driver, source.dsn)
await driver.init_schema(Bucket)
await driver.init_schema(Storage)
await self.repositories[BucketRepository].bootstrap(driver)
await self.repositories[StorageRepository].bootstrap(driver, source)

for repository in self.repositories.values():
if isinstance(repository, BucketRepository):
await repository.bootstrap(driver)
if isinstance(repository, StorageRepository):
await repository.bootstrap(driver, source)

async def find_or_create(
self,
entity: type[Entity],
data: Optional[dict | list] = None,
query: Optional[dict | list] = None,
key: Optional[any] = None,
data: Optional[dict] = None,
query: Optional[dict] = None,
key: Optional[Any] = None,
) -> Entity:
if data is None:
data = {}
if query is None:
query = data
context = await self.context(entity, key)
return context.repository.make(
entity=entity,
row=await context.driver.find_or_create(
name=entity.__name__,
entity=entity,
query=dict(bucket_id=context.bucket.id, **query),
data=dict(bucket_id=context.bucket.id, **data),
)
Expand All @@ -69,14 +79,14 @@ async def find_or_create(
async def create(
self,
entity: type[Entity],
data: Optional[dict | list] = None,
key: Optional[any] = None,
data: dict,
key: Optional[Any] = None,
) -> Entity:
context = await self.context(entity, key)
return context.repository.make(
entity=entity,
row=await context.driver.create(
name=entity.__name__,
entity=entity,
data=dict(bucket_id=context.bucket.id, **data),
)
)
Expand All @@ -85,37 +95,38 @@ async def find(
self,
entity: type[Entity],
queries: Optional[dict | list] = None,
key: Optional[any] = None,
key: Optional[Any] = None,
) -> list[Entity]:
# return repository.get_instances(entity, data)
# return await repository.find(entity, bucket, queries)
context = await self.context(entity, key)
if not queries:
queries = {}
if isinstance(queries, dict):
queries = [queries]

queries = [
dict(bucket_id=context.bucket.id, **query) for query in queries
context = await self.context(entity, key)
bucket_queries = [
dict(bucket_id=context.bucket.id, **query)
for query in queries
]
rows = await context.driver.find(entity.__name__, queries)

rows = await context.driver.find(entity, bucket_queries)
return [context.repository.make(entity, row) for row in rows]

async def get(
async def get_instance(
self,
entity: type[Entity],
id: int,
key: Optional[any] = None,
key: Optional[Any] = None,
) -> Optional[Entity]:
instances = await self.find(entity, {'id': id}, key)
if len(instances):
return instances[0]

async def context(self, entity: type[Entity], key: any) -> QueryContext:
return None

async def context(self, entity: type[Entity], key: Any) -> QueryContext:
await self.bootstrap()

repository = self.repositories[get_entity_repository_class(entity)]
repository = self.get_repository(get_entity_repository_class(entity))
bucket = await self.get_bucket(repository, key)

if not bucket.storage_id:
Expand All @@ -140,11 +151,13 @@ async def context(self, entity: type[Entity], key: any) -> QueryContext:

return QueryContext(bucket, driver, entity, repository)

async def get_bucket(self, repository: Repository, key: any) -> Bucket:
if isinstance(repository, BucketRepository):
return self.repositories[BucketRepository].buckets[Bucket]
if isinstance(repository, StorageRepository):
return self.repositories[BucketRepository].buckets[Storage]
async def get_bucket(self, repository: Repository, key: Any) -> Bucket:
if isinstance(repository, BucketRepository | StorageRepository):
buckets = self.get_repository(BucketRepository)
bucket = buckets.map[Bucket][repository.bucket_id]
if isinstance(bucket, Bucket):
return bucket

return await self.find_or_create(
entity=Bucket,
query={
Expand All @@ -166,7 +179,7 @@ def get_storage(self, storage_id: int) -> Storage:
if candidate.id == storage_id:
storage = candidate

if not Storage:
if not storage:
raise LookupError(
f'storage {storage_id} not found'
)
Expand Down
72 changes: 44 additions & 28 deletions sharded/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def __init__(self) -> None:
if not self.indexes:
self.indexes = []

self.map: dict[type[Entity], dict[int, Entity]] = {}
for entity in self.entities:
self.indexes.insert(0, UniqueIndex(entity, ['id']))
self.map[entity] = {}

async def cast_storage(self, storages: list[Storage]) -> Storage:
return storages[0]
Expand All @@ -43,38 +45,49 @@ async def init_schema(self, driver: Driver) -> None:
await driver.init_schema(entity)

def make(self, entity: type[Entity], row: dict) -> Entity:
return entity(**{k: v for (k, v) in row.items() if k != 'bucket_id'})
if row['id'] in self.map[entity]:
instance = self.map[entity][row['id']]
for key, value in row.items():
setattr(instance, key, value)
else:
instance = entity(**{
k: v for (k, v) in row.items() if k != 'bucket_id'
})
self.map[entity][row['id']] = instance

return instance


@cache
def get_entity_repository_class(entity: type[Entity]) -> type[Repository]:
repositories = [
repository for repository in Repository.__subclasses__()
if entity in repository.entities
]
if not len(repositories):
map = get_entity_repository_map()
if entity not in map:
raise LookupError(f'No entity repository found: {entity}')
return map[entity]


if len(repositories) > 1:
raise LookupError(f'Duplicate entity repository: {entity}')
@cache
def get_entity_repository_map() -> dict[type[Entity], type[Repository]]:
map: dict[type[Entity], type[Repository]] = {}
for repository in Repository.__subclasses__():
for entity in repository.entities:
if entity in map:
raise LookupError(f'Duplicate entity repository: {entity}')
map[entity] = repository

return repositories[0]
return map


class BucketRepository(Repository):
buckets: dict[type[Entity], Bucket]
bucket_id: int = 1
entities = [Bucket]

def __init__(self) -> None:
self.buckets = {}

async def bootstrap(self, driver: Driver) -> tuple[Bucket, Bucket]:
async def bootstrap(self, driver: Driver) -> None:
bucket_row = await driver.find_or_create(
name='Bucket',
query={'id': 1},
entity=Bucket,
query={'id': BucketRepository.bucket_id},
data={
'bucket_id': 1,
'id': 1,
'bucket_id': BucketRepository.bucket_id,
'id': BucketRepository.bucket_id,
'key': '',
'repository': BucketRepository,
'status': BucketStatus.READY,
Expand All @@ -83,28 +96,31 @@ async def bootstrap(self, driver: Driver) -> tuple[Bucket, Bucket]:
)

storage_row = await driver.find_or_create(
name='Bucket',
query={'id': 2},
entity=Bucket,
query={'id': StorageRepository.bucket_id},
data={
'bucket_id': 1,
'id': 2,
'bucket_id': BucketRepository.bucket_id,
'id': StorageRepository.bucket_id,
'key': '',
'repository': StorageRepository,
'status': BucketStatus.READY,
'storage_id': 1,
}
)

self.buckets[Bucket] = self.make(Bucket, bucket_row)
self.buckets[Storage] = self.make(Bucket, storage_row)
self.make(Bucket, bucket_row)
self.make(Bucket, storage_row)


class StorageRepository(Repository):
bucket_id: int = 2
entities = [Storage]

async def bootstrap(self, driver: Driver, storage: Storage):
async def bootstrap(self, driver: Driver, storage: Storage) -> None:
await driver.find_or_create(
name='Storage',
entity=Storage,
query={'id': storage.id},
data=dict(bucket_id=2, **storage.__dict__),
data=dict(
bucket_id=StorageRepository.bucket_id, **storage.__dict__
),
)
Loading

0 comments on commit 9fb16fc

Please sign in to comment.