Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

export_workspace: support remove_original #320

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions openeo_driver/save_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ def add_workspace_export(self, workspace_id: str, merge: Optional[str]):
# results stored in env[ENV_SAVE_RESULT] instead of what ultimately comes out of the process graph.
self._workspace_exports.append(self._WorkspaceExport(workspace_id, merge))

def export_workspace(self, workspace_repository: WorkspaceRepository, hrefs: List[str], default_merge: str):
def export_workspace(
self,
workspace_repository: WorkspaceRepository,
hrefs: List[str],
default_merge: str,
remove_original: bool = False,
):
for export in self._workspace_exports:
workspace = workspace_repository.get_by_id(export.workspace_id)

Expand All @@ -113,9 +119,9 @@ def export_workspace(self, workspace_repository: WorkspaceRepository, hrefs: Lis
uri_parts = urlparse(href)

if not uri_parts.scheme or uri_parts.scheme.lower() == "file":
workspace.import_file(Path(uri_parts.path), merge)
workspace.import_file(Path(uri_parts.path), merge, remove_original)
elif uri_parts.scheme == "s3":
workspace.import_object(href, merge)
workspace.import_object(href, merge, remove_original)
else:
raise ValueError(f"unsupported scheme {uri_parts.scheme} for {href}; supported are: file, s3")

Expand Down
16 changes: 8 additions & 8 deletions openeo_driver/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

class Workspace(abc.ABC):
@abc.abstractmethod
def import_file(self, file: Path, merge: str):
def import_file(self, file: Path, merge: str, remove_original: bool = False):
raise NotImplementedError

@abc.abstractmethod
def import_object(self, s3_uri: str, merge: str):
def import_object(self, s3_uri: str, merge: str, remove_original: bool = False):
raise NotImplementedError


Expand All @@ -23,18 +23,18 @@ class DiskWorkspace(Workspace):
def __init__(self, root_directory: Path):
self.root_directory = root_directory

def import_file(self,
file: Path,
merge: str):
def import_file(self, file: Path, merge: str, remove_original: bool = False):
merge = os.path.normpath(merge)
subdirectory = merge[1:] if merge.startswith("/") else merge
target_directory = self.root_directory / subdirectory
target_directory.relative_to(self.root_directory) # assert target_directory is in root_directory

target_directory.mkdir(parents=True, exist_ok=True)
shutil.copy(file, target_directory)

_log.debug(f"copied {file.absolute()} to {target_directory}")
operation = shutil.move if remove_original else shutil.copy
operation(str(file), str(target_directory))

def import_object(self, s3_uri: str, merge: str):
_log.debug(f"{'moved' if remove_original else 'copied'} {file.absolute()} to {target_directory}")

def import_object(self, s3_uri: str, merge: str, remove_original: bool = False):
raise NotImplementedError(f"importing objects is not supported yet")
24 changes: 16 additions & 8 deletions tests/test_dry_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,7 +1916,8 @@ def test_invalid_latlon_in_geojson(dry_run_env):
evaluate(cube.flat_graph(), env=dry_run_env)


def test_export_workspace(dry_run_tracer, backend_implementation):
@pytest.mark.parametrize("remove_original", [False, True])
def test_export_workspace(dry_run_tracer, backend_implementation, remove_original):
mock_workspace_repository = mock.Mock(WorkspaceRepository)
mock_workspace = mock_workspace_repository.get_by_id.return_value

Expand Down Expand Up @@ -1954,17 +1955,21 @@ def test_export_workspace(dry_run_tracer, backend_implementation):
assert save_result.is_format("GTiff")

save_result.export_workspace(
mock_workspace_repository, hrefs=["file:file1", "file:file2"], default_merge="/some/unique/path"
mock_workspace_repository,
hrefs=["file:file1", "file:file2"],
default_merge="/some/unique/path",
remove_original=remove_original,
)
mock_workspace.import_file.assert_has_calls(
[
mock.call(Path("file1"), "some/path"),
mock.call(Path("file2"), "some/path"),
mock.call(Path("file1"), "some/path", remove_original),
mock.call(Path("file2"), "some/path", remove_original),
]
)


def test_export_workspace_with_multiple_save_result(dry_run_tracer, backend_implementation):
@pytest.mark.parametrize("remove_original", [False, True])
def test_export_workspace_with_multiple_save_result(dry_run_tracer, backend_implementation, remove_original):
mock_workspace_repository = mock.Mock(WorkspaceRepository)
mock_workspace = mock_workspace_repository.get_by_id.return_value

Expand Down Expand Up @@ -2018,13 +2023,16 @@ def test_export_workspace_with_multiple_save_result(dry_run_tracer, backend_impl

for save_result in save_results:
save_result.export_workspace(
mock_workspace_repository, hrefs=[f"file:out.{save_result.format}"], default_merge="/some/unique/path"
mock_workspace_repository,
hrefs=[f"file:out.{save_result.format}"],
default_merge="/some/unique/path",
remove_original=remove_original,
)

mock_workspace.import_file.assert_has_calls(
[
mock.call(Path("out.netCDF"), "some/path"),
mock.call(Path("out.GTiff"), "/some/unique/path"),
mock.call(Path("out.netCDF"), "some/path", remove_original),
mock.call(Path("out.GTiff"), "/some/unique/path", remove_original),
]
)

Expand Down
18 changes: 13 additions & 5 deletions tests/test_save_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,28 @@ def test_with_format():
("", "."),
(None, "/some/unique/path")
])
def test_export_workspace(merge, expected_workspace_path):
@pytest.mark.parametrize("remove_original", [False, True])
def test_export_workspace(merge, expected_workspace_path, remove_original):
mock_workspace_repository = mock.Mock(spec=WorkspaceRepository)
mock_workspace = mock_workspace_repository.get_by_id.return_value

r = SaveResult()
r.add_workspace_export(workspace_id="some-workspace", merge=merge)
r.export_workspace(
workspace_repository=mock_workspace_repository, hrefs=["/some/file"], default_merge="/some/unique/path"
workspace_repository=mock_workspace_repository,
hrefs=["/some/file"],
default_merge="/some/unique/path",
remove_original=remove_original,
)

mock_workspace.import_file.assert_called_with(Path("/some/file"), expected_workspace_path)
mock_workspace.import_file.assert_called_with(Path("/some/file"), expected_workspace_path, remove_original)


@pytest.mark.parametrize(
["merge", "expected_workspace_path"], [("some/path", "some/path"), ("", "."), (None, "/some/unique/path")]
)
def test_export_workspace_s3(merge, expected_workspace_path):
@pytest.mark.parametrize("remove_original", [False, True])
def test_export_workspace_s3(merge, expected_workspace_path, remove_original):
mock_workspace_repository = mock.Mock(spec=WorkspaceRepository)
mock_workspace = mock_workspace_repository.get_by_id.return_value

Expand All @@ -68,9 +73,12 @@ def test_export_workspace_s3(merge, expected_workspace_path):
workspace_repository=mock_workspace_repository,
hrefs=["s3://some_bucket/some/key"],
default_merge="/some/unique/path",
remove_original=remove_original,
)

mock_workspace.import_object.assert_called_with("s3://some_bucket/some/key", expected_workspace_path)
mock_workspace.import_object.assert_called_with(
"s3://some_bucket/some/key", expected_workspace_path, remove_original
)


def test_aggregate_polygon_result_basic():
Expand Down
32 changes: 25 additions & 7 deletions tests/test_workspace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
from pathlib import Path

import pytest

from openeo_driver.workspace import DiskWorkspace
Expand All @@ -14,12 +11,33 @@
".",
])
def test_disk_workspace(tmp_path, merge):
workspace = DiskWorkspace(root_directory=tmp_path)
source_directory = tmp_path / "src"
source_directory.mkdir()
source_file = source_directory / "file"
source_file.touch()

subdirectory = merge[1:] if merge.startswith("/") else merge
target_directory = tmp_path / subdirectory

input_file = Path(__file__)
workspace.import_file(file=input_file, merge=merge)
workspace = DiskWorkspace(root_directory=tmp_path)
workspace.import_file(file=source_file, merge=merge)

assert (target_directory / source_file.name).exists()
assert source_file.exists()


@pytest.mark.parametrize("remove_original", [False, True])
def test_disk_workspace_remove_original(tmp_path, remove_original):
source_directory = tmp_path / "src"
source_directory.mkdir()
source_file = source_directory / "file"
source_file.touch()

merge = "."
target_directory = tmp_path / merge

workspace = DiskWorkspace(root_directory=tmp_path)
workspace.import_file(source_file, merge=merge, remove_original=remove_original)

assert "test_workspace.py" in os.listdir(target_directory)
assert (target_directory / source_file.name).exists()
assert source_file.exists() != remove_original