Skip to content

Commit

Permalink
Merge pull request #9345 from OpenMined/deposit_result_test
Browse files Browse the repository at this point in the history
Added admin methods for get and set
  • Loading branch information
teo-milea authored Oct 14, 2024
2 parents e424573 + 142a6ea commit 50800b7
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 2 deletions.
31 changes: 31 additions & 0 deletions packages/syft/src/syft/service/migration/migration_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
from collections import defaultdict
import logging
from typing import Any

# syft absolute
import syft
Expand All @@ -16,6 +17,7 @@
from ...types.syft_object import SyftObject
from ...types.syft_object_registry import SyftObjectRegistry
from ...types.twin_object import TwinObject
from ...types.uid import UID
from ..action.action_object import Action
from ..action.action_object import ActionObject
from ..action.action_permissions import ActionObjectPermission
Expand All @@ -26,7 +28,10 @@
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..sync.sync_service import get_store
from ..sync.sync_service import get_store_by_type
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from ..worker.utils import DEFAULT_WORKER_POOL_NAME
from .object_migration_state import MigrationData
from .object_migration_state import StoreMetadata
Expand Down Expand Up @@ -493,3 +498,29 @@ def reset_and_restore(
)

return SyftSuccess(message="Database reset successfully.")

@service_method(
path="migration._get_object",
name="_get_object",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def _get_object(
self, context: AuthedServiceContext, uid: UID, object_type: type
) -> Any:
return (
get_store_by_type(context, object_type)
.get_by_uid(credentials=context.credentials, uid=uid)
.unwrap()
)

@service_method(
path="migration._update_object",
name="_update_object",
roles=ADMIN_ROLE_LEVEL,
)
def _update_object(self, context: AuthedServiceContext, object: Any) -> Any:
return (
get_store(context, object)
.update(credentials=context.credentials, obj=object)
.unwrap()
)
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ def get_status(self, context: AuthedServiceContext | None = None) -> RequestStat
# which tries to send an email to the admin and ends up here
pass # lets keep going

self.refresh()
if len(self.history) == 0:
return RequestStatus.PENDING

Expand Down
8 changes: 6 additions & 2 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@


def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash:
if isinstance(item, ActionObject):
return get_store_by_type(context=context, obj_type=type(item))


def get_store_by_type(context: AuthedServiceContext, obj_type: type) -> ObjectStash:
if issubclass(obj_type, ActionObject):
service = context.server.services.action # type: ignore
return service.stash # type: ignore
service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore
service = context.server.get_service(TYPE_TO_SERVICE[obj_type]) # type: ignore
return service.stash


Expand Down
11 changes: 11 additions & 0 deletions packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,17 @@ def make_id(cls, values: Any) -> Any:
__table_coll_widths__: ClassVar[list[str] | None] = None
__table_sort_attr__: ClassVar[str | None] = None

def refresh(self) -> None:
try:
api = self._get_api()
new_object = api.services.migration._get_object(
uid=self.id, object_type=type(self)
)
if type(new_object) == type(self):
self.__dict__.update(new_object.__dict__)
except Exception as _:
return

def __syft_get_funcs__(self) -> list[tuple[str, Signature]]:
funcs = print_type_cache[type(self)]
if len(funcs) > 0:
Expand Down
57 changes: 57 additions & 0 deletions packages/syft/tests/syft/service/sync/get_set_object_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# third party

# syft absolute
import syft as sy
from syft.client.datasite_client import DatasiteClient
from syft.service.action.action_object import ActionObject
from syft.service.dataset.dataset import Dataset


def get_ds_client(client: DatasiteClient) -> DatasiteClient:
client.register(
name="a",
email="[email protected]",
password="asdf",
password_verify="asdf",
)
return client.login(email="[email protected]", password="asdf")


def test_get_set_object(high_worker):
high_client: DatasiteClient = high_worker.root_client
_ = get_ds_client(high_client)
root_datasite_client = high_worker.root_client
dataset = sy.Dataset(
name="local_test",
asset_list=[
sy.Asset(
name="local_test",
data=[1, 2, 3],
mock=[1, 1, 1],
)
],
)
root_datasite_client.upload_dataset(dataset)
dataset = root_datasite_client.datasets[0]

other_dataset = high_client.api.services.migration._get_object(
uid=dataset.id, object_type=Dataset
)
other_dataset.server_uid = dataset.server_uid
assert dataset == other_dataset
other_dataset.name = "new_name"
updated_dataset = high_client.api.services.migration._update_object(
object=other_dataset
)
assert updated_dataset.name == "new_name"

asset = root_datasite_client.datasets[0].assets[0]
source_ao = high_client.api.services.action.get(uid=asset.action_id)
ao = high_client.api.services.migration._get_object(
uid=asset.action_id, object_type=ActionObject
)
ao._set_obj_location_(
high_worker.id,
root_datasite_client.credentials,
)
assert source_ao == ao

0 comments on commit 50800b7

Please sign in to comment.