diff --git a/scanpipe/forms.py b/scanpipe/forms.py index 77547c329..f64d0bf1c 100644 --- a/scanpipe/forms.py +++ b/scanpipe/forms.py @@ -252,7 +252,7 @@ class BaseProjectActionForm(forms.Form): ) -class ArchiveProjectForm(forms.Form): +class ArchiveProjectForm(BaseProjectActionForm): remove_input = forms.BooleanField( label="Remove inputs", initial=True, diff --git a/scanpipe/views.py b/scanpipe/views.py index 19d9acded..0ae9b0bea 100644 --- a/scanpipe/views.py +++ b/scanpipe/views.py @@ -77,6 +77,7 @@ from scanpipe.forms import AddLabelsForm from scanpipe.forms import AddPipelineForm from scanpipe.forms import ArchiveProjectForm +from scanpipe.forms import BaseProjectActionForm from scanpipe.forms import EditInputSourceTagForm from scanpipe.forms import PipelineRunStepSelectionForm from scanpipe.forms import ProjectCloneForm @@ -1176,31 +1177,38 @@ class ProjectActionView(ConditionalLoginRequired, ExportXLSXMixin, generic.ListV model = Project allowed_actions = ["archive", "delete", "reset", "report", "download"] + action_to_form_class = { + "archive": ArchiveProjectForm, + "report": ProjectReportForm, + "download": ProjectOutputDownloadForm, + } success_url = reverse_lazy("project_list") def post(self, request, *args, **kwargs): + action_kwargs = {} + action = request.POST.get("action") if action not in self.allowed_actions: raise Http404 - self.selected_project_ids = request.POST.get("selected_ids", "").split(",") - count = 0 + action_form = self.get_action_form(action) + selected_project_ids = request.POST.get("selected_ids", "").split(",") + project_qs = self.get_project_queryset(selected_project_ids, action_form) + + if action == "download": + return self.download_outputs_zip_response(project_qs, action_form) - action_kwargs = {} if action == "report": + self.action_form = action_form + self.project_qs = project_qs return self.export_xlsx_file_response() - if action == "download": - return self.download_outputs_zip_response() - if action == "archive": - archive_form = ArchiveProjectForm(request.POST) - if not archive_form.is_valid(): - raise Http404 - action_kwargs = archive_form.cleaned_data + action_kwargs = action_form.cleaned_data - for project_uuid in self.selected_project_ids: - if self.perform_action(action, project_uuid, action_kwargs): + count = 0 + for project in project_qs: + if self.perform_action(action, project, action_kwargs): count += 1 if count: @@ -1208,25 +1216,34 @@ def post(self, request, *args, **kwargs): return HttpResponseRedirect(self.success_url) - def perform_action(self, action, project_uuid, action_kwargs=None): + def get_action_form(self, action): + """Return the validated ``action_form`` instance.""" + action_form_class = self.action_to_form_class.get(action, BaseProjectActionForm) + action_form = action_form_class(self.request.POST) + + if not action_form.is_valid(): + raise Http404 + + return action_form + + def perform_action(self, action, project, action_kwargs=None): if not action_kwargs: action_kwargs = {} try: - project = Project.objects.get(pk=project_uuid) getattr(project, action)(**action_kwargs) - return True - except Project.DoesNotExist: - messages.error(self.request, f"Project {project_uuid} does not exist.") except RunInProgressError as error: messages.error(self.request, str(error)) except (AttributeError, ValidationError): raise Http404 + return True + def get_success_message(self, action, count): return f"{count} projects have been {action}." - def get_projects_queryset(self, action_form=None): + @staticmethod + def get_project_queryset(selected_project_ids=None, action_form=None): """ Return the Project QuerySet from the user selection. @@ -1234,54 +1251,41 @@ def get_projects_queryset(self, action_form=None): argument for the support of ``select_across``. """ if action_form: - select_across = self.report_form.cleaned_data.get("select_across") - url_query = self.report_form.cleaned_data.GET("url_query") - if select_across and url_query: + select_across = action_form.cleaned_data.get("select_across") + # url_query may be empty for a "select everything" + url_query = action_form.cleaned_data.get("url_query", "") + if select_across: project_filterset = ProjectFilterSet(data=QueryDict(url_query)) if project_filterset.is_valid(): return project_filterset.qs - return Project.objects.filter(pk__in=self.selected_project_ids) - - def export_xlsx_file_response(self): - self.report_form = ProjectReportForm(self.request.POST) - if not self.report_form.is_valid(): - return HttpResponseRedirect(self.success_url) - - return super().export_xlsx_file_response() + selected_project_ids = selected_project_ids or [] + return Project.objects.filter(uuid__in=selected_project_ids) def get_export_xlsx_queryset(self): - projects = self.get_projects_queryset(action_form=self.report_form) - - model_name = self.report_form.cleaned_data["model_name"] + model_name = self.action_form.cleaned_data["model_name"] queryset = output.get_queryset(project=None, model_name=model_name) - - return queryset.filter(project__in=projects) + return queryset.filter(project__in=self.project_qs) def get_export_xlsx_prepend_fields(self): return ["project"] def get_export_xlsx_worksheet_name(self): - if self.report_form.cleaned_data.get("model_name") == "todos": + if self.action_form.cleaned_data.get("model_name") == "todos": return "TODOS" def get_export_xlsx_filename(self): return "report.xlsx" - def download_outputs_zip_response(self): - outputs_download_form = ProjectOutputDownloadForm(self.request.POST) - if not outputs_download_form.is_valid(): - return HttpResponseRedirect(self.success_url) - - output_format = outputs_download_form.cleaned_data["output_format"] + @staticmethod + def download_outputs_zip_response(project_qs, action_form): + output_format = action_form.cleaned_data["output_format"] output_function = output.FORMAT_TO_FUNCTION_MAPPING.get(output_format) - projects = self.get_projects_queryset(action_form=outputs_download_form) - # In-memory file storage for the zip archive zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: - for project in projects: + for project in project_qs: output_file = output_function(project) filename = output.safe_filename(f"{project.name}_{output_file.name}") with open(output_file, "rb") as f: