diff --git a/scrunch/datasets.py b/scrunch/datasets.py index 8ab11750..11ad2c9c 100644 --- a/scrunch/datasets.py +++ b/scrunch/datasets.py @@ -2367,18 +2367,38 @@ def fork(self, description=None, name=None, is_published=False, If True, the fork will be visible to viewers of ds. If False it will only be viewable to editors of ds. :param preserve_owner: bool, default=True - If True, the owner of the fork will be the same as the parent - dataset otherwise the owner will be the current user in the - session and the Dataset will be set under `Personal Project`. + If preserve_owner=True and project=None, the fork dataset will + be created in the same location as source dataset. + If preserve_owner=False and project is passed, the fork dataset + will be created in the given project location. :param project: str, default=None - The project ID or URL for the project in which the fork - dataset should be created. + The project ID or URL for the project in which the fork dataset + should be created. + If project=None, the fork dataset will be created in the same + location as the source dataset. :returns _fork: scrunch.datasets.BaseDataset """ from scrunch.mutable_dataset import MutableDataset + + # Handling project vs owner conflict + owner = kwargs.get("owner") - description = description or self.resource.body.description + if project and owner: + raise ValueError( + "Cannot pass both 'project' & 'owner' parameters together. " + "Please try again by passing only 'project' parameter." + ) + elif owner: + project = owner + del kwargs["owner"] + + LOG.warning( + "The preserve_owner parameter will be removed soon" + " in favour of providing a project or not " + "(in which case the behavior will be as it is currently" + " if preserve_owner=True and project=None).", + ) if name is None: nforks = len(self.resource.forks.index) @@ -2389,31 +2409,29 @@ def fork(self, description=None, name=None, is_published=False, body = dict( name=name, - description=description, + description=description or self.resource.body.description, is_published=is_published, **kwargs ) - - # Handling project vs owner conflict - owner = kwargs.get("owner") - - if project and owner: - raise ValueError( - "Cannot pass both 'project' & 'owner' parameters together. " - "Please try again by passing only 'project' parameter." - ) - elif owner: - project = owner - # Setting project value based on preserve_owner. - if preserve_owner and project: - raise ValueError("Cannot pass 'project' or 'owner' when preserve_owner=True") - elif preserve_owner: - body["owner"] = self.resource.body.owner - elif project: - body["owner"] = ( - project if project.startswith("http") else get_project(project).url - ) + if preserve_owner: + if project: + raise ValueError( + "Cannot pass 'project' or 'owner' when preserve_owner=True." + ) + else: + # Create fork in source dataset path. + body["project"] = self.resource.body.owner + else: + if project: + # Create fork in given Project path. + body["project"] = ( + project if project.startswith("http") else get_project(project).url + ) + else: + raise ValueError( + "Project parameter should be provided when preserve_owner=False." + ) payload = shoji_entity_wrapper(body) diff --git a/scrunch/tests/test_datasets.py b/scrunch/tests/test_datasets.py index 63782aa4..3d7fe98c 100644 --- a/scrunch/tests/test_datasets.py +++ b/scrunch/tests/test_datasets.py @@ -1769,12 +1769,14 @@ def _create(*args): ds_res.forks = MagicMock() ds_res.forks.create.side_effect = _create ds = StreamingDataset(ds_res) - forked_ds = ds.fork(preserve_owner=False) + forked_ds = ds.fork(preserve_owner=True) assert isinstance(forked_ds, MutableDataset) ds_res.forks.create.assert_called_with(as_entity({ 'name': 'FORK #1 of ds name', 'description': 'ds description', 'is_published': False, + 'project': 'http://test.crunch.io/api/users/123/', + })) @@ -1796,7 +1798,7 @@ def test_fork_project(self): 'body': { 'name': 'FORK #1 of ds name', 'description': 'ds description', - 'owner': project, # Project added + 'project': project, # Project added 'is_published': False, } } @@ -1816,38 +1818,49 @@ def test_fork_project(self): with pytest.raises(ValueError, match=err_msg1): ds.fork(owner="ABCD", project="1234") - err_msg2 = ("Cannot pass 'project' or 'owner' when preserve_owner=True") + err_msg2 = ("Cannot pass 'project' or 'owner' when preserve_owner=True.") with pytest.raises(ValueError, match=err_msg2): ds.fork(preserve_owner=True, project="1234") with pytest.raises(ValueError, match=err_msg2): - ds.fork(preserve_owner=True, owner="1234") - + ds.fork(preserve_owner=True, owner="1234") def test_fork_preserve_owner(self): - project = 'http://test.crunch.io/api/projects/123/' + project = "http://test.crunch.io/api/projects/123/" sess = MagicMock() - body = JSONObject({ - 'name': 'ds name', - 'description': 'ds description', - 'owner': project, - }) + body = JSONObject( + { + "name": "ds name", + "description": "ds description", + "owner": project, + } + ) ds_res = MagicMock(session=sess, body=body) ds_res.project.self = project ds_res.forks = MagicMock() ds_res.forks.index = {} ds = BaseDataset(ds_res) - ds.fork(preserve_owner=True) - ds_res.forks.create.assert_called_with({ - 'element': 'shoji:entity', - 'body': { - 'name': 'FORK #1 of ds name', - 'description': 'ds description', - 'owner': project, # Project preserved - 'is_published': False, + + # Test validations + with pytest.raises( + ValueError, + match="Project parameter should be provided when preserve_owner=False.", + ): + ds.fork(preserve_owner=False, project=None) + + ds.fork(preserve_owner=False, owner=project) + ds_res.forks.create.assert_called_with( + { + "element": "shoji:entity", + "body": { + "name": "FORK #1 of ds name", + "description": "ds description", + "project": project, # Project preserved + "is_published": False, + }, } - }) + ) def test_delete_forks(self): f1 = MagicMock()