Skip to content

Commit

Permalink
Keep tags (and annotation) when copying inputs to new history
Browse files Browse the repository at this point in the history
  • Loading branch information
mvdbeek committed Oct 4, 2024
1 parent 40fbe9a commit 9b83b16
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions lib/galaxy/workflow/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from galaxy.util.rules_dsl import RuleSet
from galaxy.util.template import fill_template
from galaxy.util.tool_shed.common_util import get_tool_shed_url_from_tool_shed_registry
from galaxy.work.context import WorkRequestContext

if TYPE_CHECKING:
from galaxy.schema.invocation import InvocationMessageUnion
Expand Down Expand Up @@ -468,7 +469,7 @@ def decode_runtime_state(self, step, runtime_state):
return state

def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
self, trans: WorkRequestContext, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
"""Execute the given workflow invocation step.
Expand Down Expand Up @@ -764,7 +765,11 @@ def get_content_id(self):
return self.trans.security.encode_id(self.subworkflow.id)

def execute(
self, trans, progress: "WorkflowProgress", invocation_step: WorkflowInvocationStep, use_cached_job: bool = False
self,
trans: WorkRequestContext,
progress: "WorkflowProgress",
invocation_step: WorkflowInvocationStep,
use_cached_job: bool = False,
) -> Optional[bool]:
"""Execute the given workflow step in the given workflow invocation.
Use the supplied workflow progress object to track outputs, find
Expand Down Expand Up @@ -949,33 +954,33 @@ def get_all_inputs(self, data_only=False, connectable_only=False):
return []

def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
self, trans: WorkRequestContext, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
invocation = invocation_step.workflow_invocation
step = invocation_step.workflow_step
input_value = step.state.inputs["input"]
if input_value is None:
default_value = step.get_input_default_value(NO_REPLACEMENT)
if default_value is not NO_REPLACEMENT:
input_value = raw_to_galaxy(trans.app, trans.history, default_value)
input_value = raw_to_galaxy(trans.app, invocation.history, default_value)

step_outputs = dict(output=input_value)

# Web controller may set copy_inputs_to_history, API controller always sets
# inputs.
if progress.copy_inputs_to_history:
for input_dataset_hda in list(step_outputs.values()):
content_type = input_dataset_hda.history_content_type
if content_type == "dataset":
new_hda = input_dataset_hda.copy()
invocation.history.add_dataset(new_hda)
step_outputs["input_ds_copy"] = new_hda
elif content_type == "dataset_collection":
new_hdca = input_dataset_hda.copy()
invocation.history.add_dataset_collection(new_hdca)
step_outputs["input_ds_copy"] = new_hdca
history = invocation.history
for input_item in list(step_outputs.values()):
if isinstance(input_item, model.HistoryDatasetAssociation):
step_outputs["input_ds_copy"] = trans.app.hda_manager.copy(input_item, history, flush=False)
elif isinstance(input_item, model.HistoryDatasetCollectionAssociation):
step_outputs["input_ds_copy"] = input_item.copy(
element_destination=history, set_hid=False, flush=False
)
history.stage_addition(step_outputs["input_ds_copy"])
else:
raise Exception("Unknown history content encountered")
history.add_pending_items()
# If coming from UI - we haven't registered invocation inputs yet,
# so do that now so dependent steps can be recalculated. In the future
# everything should come in from the API and this can be eliminated.
Expand Down Expand Up @@ -1548,7 +1553,7 @@ def get_all_outputs(self, data_only=False):
]

def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
self, trans: WorkRequestContext, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
step = invocation_step.workflow_step
input_value = step.state.inputs["input"]
Expand Down Expand Up @@ -1695,7 +1700,7 @@ def get_runtime_state(self):
return state

def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
self, trans: WorkRequestContext, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
step = invocation_step.workflow_step
progress.mark_step_outputs_delayed(step, why="executing pause step")
Expand Down Expand Up @@ -2151,7 +2156,7 @@ def decode_runtime_state(self, step, runtime_state):
)

def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
self, trans: WorkRequestContext, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
) -> Optional[bool]:
invocation = invocation_step.workflow_invocation
step = invocation_step.workflow_step
Expand Down

0 comments on commit 9b83b16

Please sign in to comment.