diff --git a/airlock/business_logic.py b/airlock/business_logic.py index c0c13391..6b2f6a3d 100644 --- a/airlock/business_logic.py +++ b/airlock/business_logic.py @@ -139,13 +139,16 @@ class FileGroup: """ name: str - files: list[RequestFile] + files: dict[RequestFile] @classmethod def from_dict(cls, attrs): return cls( **{k: v for k, v in attrs.items() if k != "files"}, - files=[RequestFile.from_dict(value) for value in attrs.get("files", ())], + files={ + UrlPath(value["relpath"]): RequestFile.from_dict(value) + for value in attrs.get("files", ()) + }, ) @@ -207,11 +210,7 @@ def get_contents_url(self, relpath, download=False): url += "?download" return url - def abspath(self, relpath): - """Returns abspath to the file on disk. - - The first part of the relpath is the group, so we parse and validate that first. - """ + def get_request_file(self, relpath: UrlPath | str): relpath = UrlPath(relpath) group = relpath.parts[0] file_relpath = UrlPath(*relpath.parts[1:]) @@ -219,19 +218,24 @@ def abspath(self, relpath): if not (filegroup := self.filegroups.get(group)): raise BusinessLogicLayer.FileNotFound(f"bad group {group} in url {relpath}") - matching_files = [f for f in filegroup.files if f.relpath == file_relpath] - if not matching_files: + if not (request_file := filegroup.files.get(file_relpath)): raise BusinessLogicLayer.FileNotFound(relpath) - assert len(matching_files) == 1 - request_file = matching_files[0] + return request_file + + def abspath(self, relpath): + """Returns abspath to the file on disk. + + The first part of the relpath is the group, so we parse and validate that first. + """ + request_file = self.get_request_file(relpath) return self.root() / request_file.file_id def file_set(self): return { request_file.relpath for filegroup in self.filegroups.values() - for request_file in filegroup.files + for request_file in filegroup.files.values() } def set_filegroups_from_dict(self, attrs): @@ -240,7 +244,7 @@ def set_filegroups_from_dict(self, attrs): def get_file_paths(self): paths = [] for file_group in self.filegroups.values(): - for request_file in file_group.files: + for request_file in file_group.files.values(): relpath = request_file.relpath abspath = self.abspath(file_group.name / relpath) paths.append((relpath, abspath)) diff --git a/airlock/file_browser_api.py b/airlock/file_browser_api.py index 8e4c30b9..e2840a7c 100644 --- a/airlock/file_browser_api.py +++ b/airlock/file_browser_api.py @@ -268,22 +268,20 @@ def get_request_tree(release_request, selected_path=ROOT_PATH, selected_only=Fal expanded=selected or expanded, ) - group_paths = [f.relpath for f in group.files] - if selected_only: if expanded: if group_path == selected_path: # we just need the group's immediate child paths - pathlist = [UrlPath(p.parts[0]) for p in group_paths] + pathlist = [UrlPath(p.parts[0]) for p in group.files] else: # filter for just the selected path and any immediate children selected_subpath = selected_path.relative_to(group_path) - pathlist = list(filter_files(selected_subpath, group_paths)) + pathlist = list(filter_files(selected_subpath, group.files)) else: # we don't want any children for unselected groups pathlist = [] else: - pathlist = group_paths + pathlist = list(group.files) group_node.children = get_path_tree( release_request, diff --git a/tests/integration/views/test_workspace.py b/tests/integration/views/test_workspace.py index 2e5d0429..09762d04 100644 --- a/tests/integration/views/test_workspace.py +++ b/tests/integration/views/test_workspace.py @@ -2,6 +2,7 @@ from django.contrib import messages from django.shortcuts import reverse +from airlock.business_logic import UrlPath from tests import factories @@ -275,7 +276,7 @@ def test_workspace_request_file_creates(airlock_client, bll): release_request = bll.get_current_request(workspace.name, airlock_client.user) filegroup = release_request.filegroups["default"] assert filegroup.name == "default" - assert str(filegroup.files[0].relpath) == "test/path.txt" + assert UrlPath("test/path.txt") in filegroup.files assert release_request.abspath("default/test/path.txt").exists() @@ -299,7 +300,7 @@ def test_workspace_request_file_request_already_exists(airlock_client, bll): assert current_release_request.abspath("default/test/path.txt").exists() filegroup = current_release_request.filegroups["default"] assert filegroup.name == "default" - assert str(filegroup.files[0].relpath) == "test/path.txt" + assert UrlPath("test/path.txt") in filegroup.files def test_workspace_request_file_with_new_filegroup(airlock_client, bll): diff --git a/tests/unit/test_business_logic.py b/tests/unit/test_business_logic.py index 5f909ea1..54655ad5 100644 --- a/tests/unit/test_business_logic.py +++ b/tests/unit/test_business_logic.py @@ -338,16 +338,25 @@ def test_request_release_invalid_state(): ) -def test_request_release_abspath(bll): +def test_request_release_get_request_file(bll): path = UrlPath("foo/bar.txt") release_request = factories.create_release_request("id") factories.write_request_file(release_request, "default", path) with pytest.raises(bll.FileNotFound): - release_request.abspath("badgroup" / path) + release_request.get_request_file("badgroup" / path) with pytest.raises(bll.FileNotFound): - release_request.abspath("default/does/not/exist") + release_request.get_request_file("default/does/not/exist") + + request_file = release_request.get_request_file("default" / path) + assert request_file.relpath == path + + +def test_request_release_abspath(bll): + path = UrlPath("foo/bar.txt") + release_request = factories.create_release_request("id") + factories.write_request_file(release_request, "default", path) assert release_request.abspath("default" / path).exists() @@ -377,7 +386,7 @@ def test_release_request_filegroups_default_filegroup(bll): filegroup = release_request.filegroups["default"] assert filegroup.name == "default" assert len(filegroup.files) == 1 - assert filegroup.files[0].relpath == path + assert path in filegroup.files def test_release_request_filegroups_named_filegroup(bll): @@ -388,7 +397,7 @@ def test_release_request_filegroups_named_filegroup(bll): filegroup = release_request.filegroups["test_group"] assert filegroup.name == "test_group" assert len(filegroup.files) == 1 - assert filegroup.files[0].relpath == path + assert path in filegroup.files def test_release_request_filegroups_multiple_filegroups(bll): @@ -408,7 +417,7 @@ def test_release_request_filegroups_multiple_filegroups(bll): assert len(release_request.filegroups) == 2 release_request_files = { - filegroup.name: [file.relpath for file in filegroup.files] + filegroup.name: list(filegroup.files) for filegroup in release_request.filegroups.values() }