diff --git a/dali/pipeline/operator/op_spec.h b/dali/pipeline/operator/op_spec.h
index 30bcd53898f..c334866a125 100644
--- a/dali/pipeline/operator/op_spec.h
+++ b/dali/pipeline/operator/op_spec.h
@@ -264,6 +264,16 @@ class DLL_PUBLIC OpSpec {
     return outputs_[idx].device;
   }
 
+  DLL_PUBLIC inline void RenameInput(int idx, std::string name) {
+    DALI_ENFORCE_VALID_INDEX(idx, NumInput());
+    inputs_[idx].name = std::move(name);
+  }
+
+  DLL_PUBLIC inline void RenameOutput(int idx, std::string name) {
+    DALI_ENFORCE_VALID_INDEX(idx, NumOutput());
+    outputs_[idx].name = std::move(name);
+  }
+
   DLL_PUBLIC inline auto &ArgumentInputs() const {
     return argument_inputs_;
   }
diff --git a/dali/pipeline/pipeline.cc b/dali/pipeline/pipeline.cc
index 3587a6f923c..5e650c1ef53 100644
--- a/dali/pipeline/pipeline.cc
+++ b/dali/pipeline/pipeline.cc
@@ -312,7 +312,7 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_
     DALI_ENFORCE(it != edge_names_.end(),
                  make_string("Data node \"", input_name, "\" requested as ", FormatInput(spec, i),
-                             " to the operator is not known to the pipeline."));
+                             " to operator \"", inst_name, "\" is not known to the pipeline."));
 
     // Table of possible scenarios:
     // Op location / requested input type / data location
@@ -360,7 +360,7 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_
     DALI_ENFORCE(
         it != edge_names_.end(),
         make_string("Data node \"", input_name, "\" requested as ", FormatArgument(spec, arg_name),
-                    " to operator is not known to the pipeline."));
+                    " to operator \"", inst_name, "\" is not known to the pipeline."));
 
     if (!it->second.has_cpu) {
       assert(it->second.has_gpu);
diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc
index cf9c1ce71a1..fb9511830a6 100644
--- a/dali/python/backend_impl.cc
+++ b/dali/python/backend_impl.cc
@@ -2152,6 +2152,15 @@ PYBIND11_MODULE(backend_impl, m) {
          py::return_value_policy::reference_internal)
     .def("AddOutput", &OpSpec::AddOutput,
          py::return_value_policy::reference_internal)
+    .def("RenameInput", &OpSpec::RenameInput, "idx"_a, "name"_a)
+    .def("RenameOutput", &OpSpec::RenameOutput, "idx"_a, "name"_a)
+    .def("InputName", &OpSpec::InputName, "idx"_a)
+    .def("InputDevice", &OpSpec::InputDevice, "idx"_a)
+    .def("OutputName", &OpSpec::OutputName, "idx"_a)
+    .def("OutputDevice", &OpSpec::OutputDevice, "idx"_a)
+    .def("NumInput", &OpSpec::NumInput)
+    .def("NumRegularInput", &OpSpec::NumRegularInput)
+    .def("NumOutput", &OpSpec::NumOutput)
     DALI_OPSPEC_ADDARG(std::string)
     DALI_OPSPEC_ADDARG(bool)
     DALI_OPSPEC_ADDARG(int64)
diff --git a/dali/python/nvidia/dali/data_node.py b/dali/python/nvidia/dali/data_node.py
index 5a79fd40085..48d636c0a2e 100644
--- a/dali/python/nvidia/dali/data_node.py
+++ b/dali/python/nvidia/dali/data_node.py
@@ -73,7 +73,7 @@ def __init__(self, name, device="cpu", source=None):
         self.source = source
 
     def __str__(self):
-        return f'DataNode(name="{self.name}", device="{self.device}")'
+        return f'DataNode(name="{self.name}", device="{self.device}", source="{self.source}")'
 
     __repr__ = __str__
 
diff --git a/dali/python/nvidia/dali/ops/__init__.py b/dali/python/nvidia/dali/ops/__init__.py
index 513786ce554..4e5063257c3 100644
--- a/dali/python/nvidia/dali/ops/__init__.py
+++ b/dali/python/nvidia/dali/ops/__init__.py
@@ -17,6 +17,7 @@
 import threading
 import tree
 import warnings
+import weakref
 
 from itertools import count
 
 import nvidia.dali.python_function_plugin
@@ -369,11 +370,16 @@ def __init__(self, inputs, arg_inputs, arguments, _processed_arguments, op):
         op : Operator class.
             Operator class containing the schema, and spec filled with `processed_arguments`.
         """
-        self._counter = _OpCounter()
+        if _Pipeline.current():
+            self._pipeline = weakref.ref(_Pipeline.current())
+        else:
+            self._pipeline = None
+
+        self._id = None
         self._outputs = []
         self._op = op
         self._spec = op.spec.copy()
-        self._relation_id = self._counter.id
+        self._relation_id = None
 
         if _conditionals.conditionals_enabled():
             inputs, arg_inputs = _conditionals.apply_conditional_split_to_args(inputs, arg_inputs)
@@ -412,8 +418,13 @@ def _process_instance_name(self, arguments):
         name = arguments.pop("name", None)
         if name is not None:
             self._name = name
+            self._autoname = False
         else:
-            self._name = "__" + type(self._op).__name__ + "_" + str(self._counter.id)
+            has_pipeline = self.pipeline is not None
+            # to avoid mixing up global and per-pipeline ids
+            infix = "_" if has_pipeline else "_detached_"
+            self._name = "__" + type(self._op).__name__ + infix + str(self.id)
+            self._autoname = True
 
     def _process_trace(self, arguments):
         from nvidia.dali._debug_mode import _PipelineDebug
@@ -463,9 +474,21 @@ def _generate_outputs(self):
             pipeline.add_sink(t)
             self.append_output(t)
 
+    @property
+    def pipeline(self):
+        return None if self._pipeline is None else self._pipeline()
+
     @property
     def id(self):
-        return self._counter.id
+        if self._id is None:
+            if self.pipeline is None and _Pipeline.current():
+                self._pipeline = weakref.ref(_Pipeline.current())
+            if self.pipeline:
+                self._id = self.pipeline._next_op_id()
+            else:
+                self._id = _OpCounter().id
+
+        return self._id
 
     @property
     def inputs(self):
@@ -492,6 +515,8 @@ def name(self):
 
     @property
     def relation_id(self):
+        if self._relation_id is None:
+            self._relation_id = id(self)
         return self._relation_id
 
     @relation_id.setter
@@ -629,7 +654,7 @@ def __call__(self, *inputs, **kwargs):
             )
 
         # Tie the instances together
-        relation_id = op_instances[0].id
+        relation_id = op_instances[0].relation_id
         for op in op_instances:
             op.relation_id = relation_id
 
diff --git a/dali/python/nvidia/dali/pipeline.py b/dali/python/nvidia/dali/pipeline.py
index aca5d2b451d..9a23b6dd0a9 100644
--- a/dali/python/nvidia/dali/pipeline.py
+++ b/dali/python/nvidia/dali/pipeline.py
@@ -14,7 +14,6 @@
 
 # pylint: disable=no-member
 from typing import Any, List, Tuple, Callable, Optional, Union, TypeVar, overload
-from collections import deque
 from nvidia.dali import backend as b
 from nvidia.dali import types
 from nvidia.dali import internal
@@ -235,6 +234,7 @@ def __init__(
         self._num_threads = num_threads
         self._device_id = device_id
         self._seed = seed
+        self._next_op_id_counter = 0
         self._exec_pipelined = exec_pipelined
         # When initializing DALI, we do the following in order:
         # * Discover the ops specified in Python, group the ExternalSources (_build_graph())
@@ -719,6 +719,48 @@ def __exit__(self, type, value, traceback):
 
         return api_checker(self)
 
+    def _require_unique_names(self):
+        ops_by_name = {}
+        for op in self._ops:
+            ops = ops_by_name.setdefault(op.name, [])
+            ops.append(op)
+        duplicate = {}
+        foreign = False
+        for name, ops in ops_by_name.items():
+            if len(ops) > 1:
+                duplicate[name] = ops
+                for op in ops:
+                    if op.pipeline is not self:
+                        foreign = True
+
+        if duplicate:
+            message = (
+                f"The pipeline is invalid because it contains operators with non-unique names:\n"
+                f"{duplicate}"
+            )
+            if foreign:
+                message += (
+                    "\nThe likely cause is that the pipeline contains a subgraph "
+                    "instantiated while a different pipeline was set as the current "
+                    "pipeline (e.g. inside another pipeline's graph definition function).\n"
+                )
+            raise RuntimeError(message)
+
+    def _require_no_foreign_ops(self, message):
+        foreign = []
+        for op in self._ops:
+            if op.pipeline is not self:
+                foreign.append(op)
+        if foreign:
+            raise RuntimeError(
+                f"{message} because it contains operator(s) "
+                f"that were defined outside the pipeline scope:\n"
+                f"{[o.name for o in foreign]}\n"
+                f"All operators should be defined while the pipeline is set as the current "
+                f"pipeline. This happens automatically when defining the pipeline in a "
+                f"function decorated with `@pipeline_def`."
+            )
+
     # Graph is constructed by backtracking from the output edges and the edges marked as sinks
     def _build_graph(self, define_graph=None):
         if define_graph is not None:
@@ -776,6 +818,10 @@ def contains_nested_datanode(nested):
                 _data_node._check(outputs[i])
 
         self._ops = _collect_ops(list(outputs) + self._sinks)
+        self._require_unique_names()
+        if self._enable_checkpointing:
+            self._require_no_foreign_ops("The pipeline does not support checkpointing")
+
         self._graph_outputs = outputs
         self._setup_input_callbacks()
         self._disable_pruned_external_source_instances()
@@ -960,6 +1006,11 @@ def _restore_state_from_checkpoint(self):
             self._iterator_data = external_ctx_cpt.iterator_data
             self._is_restored_from_checkpoint = True
 
+    def _next_op_id(self):
+        i = self._next_op_id_counter
+        self._next_op_id_counter += 1
+        return i
+
     def build(self):
         """Build the pipeline.
 
@@ -2101,37 +2152,26 @@ def get_op_input_edges(op) -> List[DataNode]:
             else:
                 yield inp
 
-    def get_op_outputs_num():
-        # BSF traverse the graph first to learn, for each reachable operator in the graph,
-        # how many data-nodes/edges the operator contributes to
-        # (i.e. the number of outputs of the operator instance)
-        op_outputs_num = {}
-        edges = deque(output_nodes)
-        while edges:
-            current_edge = edges.popleft()
-            source_op = get_source_op(current_edge)
-            if source_op.id in op_outputs_num:
-                op_outputs_num[source_op.id] += 1
-            else:
-                op_outputs_num[source_op.id] = 1
-                source_op.check_args()
-                edges.extend(get_op_input_edges(source_op))
-        return op_outputs_num
-
+    visited = set()
     ops = []
-    edges = deque(output_nodes)
-    op_total_outputs_num = get_op_outputs_num()
-    op_visited_outputs_num = {op_id: 0 for op_id in op_total_outputs_num}
-    while edges:
-        current_edge = edges.popleft()
-        source_op = get_source_op(current_edge)
-        op_visited_outputs_num[source_op.id] += 1
-        # Actually visit the operator only when all the nodes it contributes to
-        # were already processed
-        if op_visited_outputs_num[source_op.id] == op_total_outputs_num[source_op.id]:
-            ops.append(source_op)
-            edges.extend(get_op_input_edges(source_op))
-    ops.reverse()
+
+    # Depth-first search returns the graph topologically sorted.
+    # We go over each operator's inputs before adding it to the list.
+
+    def visit_op(op):
+        if id(op) in visited:
+            return
+        visited.add(id(op))
+        op.check_args()
+        # visit contributing inputs
+        for edge in get_op_input_edges(op):
+            visit_op(get_source_op(edge))
+        # add the operator to the list of contributing ops
+        ops.append(op)
+
+    for edge in output_nodes:
+        visit_op(get_source_op(edge))
+
     return ops
 
 
diff --git a/dali/test/python/auto_aug/test_auto_augment.py b/dali/test/python/auto_aug/test_auto_augment.py
index 8b97ed1b2ec..5c76ee8f51d 100644
--- a/dali/test/python/auto_aug/test_auto_augment.py
+++ b/dali/test/python/auto_aug/test_auto_augment.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -210,8 +210,8 @@ def pipeline():
 @params(
     (False, "cpu", 256),
     (False, "gpu", 512),
-    (True, "cpu", 400),
-    (True, "gpu", 348),
+    (True, "cpu", 2000),
+    (True, "gpu", 2000),
 )
 def test_sub_policy(randomly_negate, dev, batch_size):
     num_magnitude_bins = 10
@@ -305,9 +305,8 @@ def third(data, op_id_mag_id):
             expected_counts.append(expected)
         stat = chisquare(counts, expected_counts)
         # assert that the magnitudes negation looks independently enough
-        # (0.05 <=), but also that it is not too ideal (i.e. like all
-        # cases happening exactly the expected number of times)
-        assert 0.05 <= stat.pvalue <= 0.95, f"{stat}"
+        # (0.01 <=)
+        assert 0.01 <= stat.pvalue, f"{stat}"
 
 
 @params(("cpu",), ("gpu",))
@@ -397,7 +396,7 @@ def second_stage_only(data, op_id_mag_id):
     )
     policy = Policy("MyPolicy", num_magnitude_bins=num_magnitude_bins, sub_policies=sub_policies)
 
-    p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy)
+    p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy, seed=1234)
     p.build()
 
     for _ in range(5):
@@ -415,10 +414,8 @@ def second_stage_only(data, op_id_mag_id):
         actual.append(actual_counts[mags])
         expected.append(expected_counts[mags])
     stat = chisquare(actual, expected)
-    # assert that the magnitudes negation looks independently enough
-    # (0.05 <=), but also that it is not too ideal (i.e. like all
-    # cases happening exactly the expected number of times)
-    assert 0.05 <= stat.pvalue <= 0.95, f"{stat}"
+    # assert that the magnitudes negation looks independently enough (0.01 <=)
+    assert 0.01 <= stat.pvalue, f"{stat}"
 
 
 def test_policy_presentation():
diff --git a/dali/test/python/auto_aug/test_rand_augment.py b/dali/test/python/auto_aug/test_rand_augment.py
index 83db28ce4ae..77d9b8eb41c 100644
--- a/dali/test/python/auto_aug/test_rand_augment.py
+++ b/dali/test/python/auto_aug/test_rand_augment.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -312,7 +312,7 @@ def pipeline():
         actual.append(actual_count[out])
         expected.append(expected_counts[out])
     stat = chisquare(actual, expected)
-    assert 0.01 <= stat.pvalue <= 0.99, f"{stat} {actual} {expected}"
+    assert 0.01 <= stat.pvalue, f"{stat} {actual} {expected}"
 
 
 def test_wrong_params_fail():
diff --git a/dali/test/python/auto_aug/test_trivial_augment.py b/dali/test/python/auto_aug/test_trivial_augment.py
index 03b52164863..cda65792484 100644
--- a/dali/test/python/auto_aug/test_trivial_augment.py
+++ b/dali/test/python/auto_aug/test_trivial_augment.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -236,4 +236,4 @@ def pipeline():
         stat = chisquare(actual, expected)
         stats.append(stat)
     mean_p_val = sum(stat.pvalue for stat in stats) / len(stats)
-    assert 0.05 <= mean_p_val <= 0.95, f"{mean_p_val} {stat} {actual} {expected}"
+    assert 0.01 <= mean_p_val, f"{mean_p_val} {stat} {actual} {expected}"
diff --git a/dali/test/python/test_pipeline.py b/dali/test/python/test_pipeline.py
index e79613e6f65..ff0a0b9d55b 100644
--- a/dali/test/python/test_pipeline.py
+++ b/dali/test/python/test_pipeline.py
@@ -2155,3 +2155,52 @@ def create_test_package(output_dtype=None, output_ndim=None, cast_labels=False):
         create_test_package(output_dtype=int)
     with assert_raises(ValueError, glob="*types.NO_TYPE*"):
         create_test_package(output_dtype=types.NO_TYPE)
+
+
+def test_dangling_subgraph():
+    # This test ensures that operators defined outside of the pipeline scope are assigned
+    # the same ids and names in each pipeline when that pipeline is built.
+
+    pipes = []
+    op1 = fn.external_source(
+        source=[np.int32([1, 2, 3]), np.int32([4, 5, 6])], cycle=True, batch=False
+    )
+    op2 = fn.external_source(
+        source=[np.int32([6, 5, 4]), np.int32([3, 2, 1])], cycle=True, batch=False
+    )
+    for i in range(2):
+        with Pipeline(batch_size=1, device_id=None, num_threads=1, seed=123) as p:
+            ret1 = op1 + op2
+            p.set_outputs(ret1)
+        pipes.append(p)
+
+    pipes[0].build()  # names and ids of op1 and op2 are adjusted here
+    pipes[1].build()  # op1 and op2 get new names and ids here, for the second pipeline
+
+    ser1 = pipes[0].serialize()
+    ser2 = pipes[1].serialize()
+    assert ser1 == ser2
+
+    (o1,) = pipes[0].run()
+    (o2,) = pipes[1].run()
+    assert np.array_equal(o1[0], np.int32([7, 7, 7]))
+    assert np.array_equal(o2[0], np.int32([7, 7, 7]))
+
+
+def test_regression_without_current_pipeline1():
+    def get_pipe(device):
+        pipe = Pipeline(batch_size=1, num_threads=1, device_id=0)
+        data = fn.external_source(source=[1, 2, 3], batch=False, cycle=True, device=device)
+        dist = data + fn.random.normal()
+        pipe.set_outputs(dist)
+        return pipe
+
+    p = get_pipe("gpu")
+    p.build()
+
+
+def test_regression_without_current_pipeline2():
+    pipe = Pipeline(batch_size=4, num_threads=3, device_id=0)
+    data = fn.external_source(source=[1, 2, 3], batch=False, cycle=True)
+    pipe.set_outputs(data.gpu())
+    pipe.build()
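
For context, a minimal usage sketch of the behaviour this change enables, modelled on the new `test_dangling_subgraph` above (the variable names and sample values are illustrative, not part of the patch): operators instantiated while no pipeline is current are adopted by the pipeline that consumes them, and operator ids and auto-generated names are now assigned per pipeline at build time, so two pipelines built from the same dangling subgraph serialize identically. Duplicate operator names, and foreign operators when checkpointing is enabled, are now reported with explicit `RuntimeError` messages.

```python
import numpy as np
import nvidia.dali.fn as fn
from nvidia.dali.pipeline import Pipeline

# A "dangling" subgraph: created while no pipeline is set as current,
# so these operator instances have no owning pipeline yet.
a = fn.external_source(source=[np.int32([1, 2, 3])], cycle=True, batch=False)
b = fn.external_source(source=[np.int32([10, 20, 30])], cycle=True, batch=False)

pipes = []
for _ in range(2):
    with Pipeline(batch_size=1, num_threads=1, device_id=None, seed=42) as p:
        # The sum node is created inside this pipeline's scope; the external
        # sources are adopted by the pipeline when it is built.
        p.set_outputs(a + b)
    pipes.append(p)

for p in pipes:
    p.build()

# Per-pipeline id and name assignment makes the serialized pipelines identical.
assert pipes[0].serialize() == pipes[1].serialize()
```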