Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime] Avoid keeping a copy of graph in executor. #888

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions tensorflow/core/common_runtime/direct_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1336,8 +1336,8 @@ Status DirectSession::RunInternal(
std::unordered_map<string, const Graph*> device_to_graph;
for (const PerPartitionExecutorsAndLib& partition :
executors_and_keys->items) {
const Graph* graph = partition.graph;
const string device = partition.flib->device()->name();
const Graph* graph = partition.graph.get();
const string& device = partition.flib->device()->name();
device_to_graph[device] = graph;
}

Expand All @@ -1348,7 +1348,7 @@ Status DirectSession::RunInternal(
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
for (const auto& item : executors_and_keys->items) {
TF_RETURN_IF_ERROR(
cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
}
}

Expand Down Expand Up @@ -2006,13 +2006,12 @@ Status DirectSession::CreateExecutors(
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
device->name(),
partition_graph.get()));
// NewLocalExecutor takes ownership of partition_graph.
item->graph = partition_graph.get();
item->graph = std::move(partition_graph);
item->executor = nullptr;
item->device = device;
auto executor_type = options_.config.experimental().executor_type();
TF_RETURN_IF_ERROR(NewExecutor(
executor_type, params, std::move(partition_graph), &item->executor));
executor_type, params, *item->graph, &item->executor));
}

// Cache the mapping from input/output names to graph elements to
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/direct_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class DirectSession : public Session {
// We create one executor and its dependent library runtime for
// every partition.
struct PerPartitionExecutorsAndLib {
Graph* graph = nullptr; // not owned.
std::unique_ptr<Graph> graph = nullptr;
Device* device = nullptr; // not owned.
FunctionLibraryRuntime* flib = nullptr; // not owned.
std::unique_ptr<Executor> executor;
Expand Down
24 changes: 10 additions & 14 deletions tensorflow/core/common_runtime/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,8 @@ typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

class ExecutorImpl : public Executor {
public:
explicit ExecutorImpl(const LocalExecutorParams& p,
std::unique_ptr<const Graph> g)
: immutable_state_(p, std::move(g)) {
explicit ExecutorImpl(const LocalExecutorParams& p)
: immutable_state_(p) {
Status s = ReadBoolFromEnvVar(
nodestats::enable_cost_model_env_name, true, &enable_cost_model_);
if (!s.ok()) {
Expand All @@ -171,12 +170,11 @@ class ExecutorImpl : public Executor {
}
}

Status Initialize() {
TF_RETURN_IF_ERROR(immutable_state_.Initialize());
kernel_stats_.Initialize(immutable_state_.graph_view(),
immutable_state_.graph());
Status Initialize(const Graph& graph) {
TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph));
kernel_stats_.Initialize(immutable_state_.graph_view(), &graph);
if (immutable_state_.params().run_cost_model_executor) {
immutable_state_.InitializeScheduleInfo(&kernel_stats_);
immutable_state_.InitializeScheduleInfo(&kernel_stats_, graph);
}
return Status::OK();
}
Expand Down Expand Up @@ -1761,10 +1759,9 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
} // namespace

Status NewLocalExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Executor** executor) {
ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
const Status s = impl->Initialize();
const Graph& graph, Executor** executor) {
ExecutorImpl* impl = new ExecutorImpl(params);
const Status s = impl->Initialize(graph);
if (s.ok()) {
*executor = impl;
} else {
Expand Down Expand Up @@ -1796,8 +1793,7 @@ class DefaultExecutorRegistrar {

private:
class Factory : public ExecutorFactory {
Status NewExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) override {
Executor* ret = nullptr;
TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/core/common_runtime/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ struct LocalExecutorParams {
// "params" provides a set of context for the executor. We expect that
// different context would provide different implementations.
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Executor** executor);
const Graph& graph, Executor** executor);

// A class to help run multiple executors in parallel and wait until
// all of them are complete.
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/core/common_runtime/executor_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ Status ExecutorFactory::GetFactory(const string& executor_type,
}

Status NewExecutor(const string& executor_type,
const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) {
ExecutorFactory* factory = nullptr;
TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/core/common_runtime/executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct LocalExecutorParams;
class ExecutorFactory {
public:
virtual Status NewExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
const Graph& graph,
std::unique_ptr<Executor>* out_executor) = 0;
virtual ~ExecutorFactory() {}

Expand All @@ -42,8 +42,7 @@ class ExecutorFactory {
};

Status NewExecutor(const string& executor_type,
const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor);

} // namespace tensorflow
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class ExecutorTest : public ::testing::Test {
return Status::OK();
};
delete exec_;
TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &exec_));
TF_CHECK_OK(NewLocalExecutor(params, *graph, &exec_));
runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
cost_runner_ = [this](std::function<void()> fn, int64 cost)
{ thread_pool_->CostSchedule(fn, cost); };
Expand Down
9 changes: 4 additions & 5 deletions tensorflow/core/common_runtime/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
// object, and an executor is created for the graph.
struct Item {
uint64 instantiation_counter = 0;
const Graph* graph = nullptr; // Owned by exec.
std::unique_ptr<const Graph> graph = nullptr;
const FunctionLibraryDefinition* lib_def = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
Expand Down Expand Up @@ -962,14 +962,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
};
params.rendezvous_factory = (*item)->rendezvous_factory;
params.session_metadata = session_metadata_;
Graph* graph = g.get();
std::unique_ptr<Executor> exec;
TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
{
// Guard item since it is already inserted in items_.
mutex_lock l(mu_);
if ((*item)->exec == nullptr) {
(*item)->graph = graph;
(*item)->graph = std::move(g);
(*item)->exec = exec.release();
}
}
Expand Down Expand Up @@ -1265,7 +1264,7 @@ string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
Status s = GetOrCreateItem(local_handle, &item);
if (s.ok()) {
return tensorflow::DebugString(item->graph);
return tensorflow::DebugString(item->graph.get());
} else {
return s.ToString();
}
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/core/common_runtime/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class FunctionTest : public ::testing::Test {
return Status::OK();
};
Executor* exec;
TF_CHECK_OK(NewLocalExecutor(params, std::move(g), &exec));
TF_CHECK_OK(NewLocalExecutor(params, *g, &exec));
exec_.reset(exec);
}

Expand Down Expand Up @@ -603,8 +603,7 @@ class DummyExecutorRegistrar {

private:
class Factory : public ExecutorFactory {
Status NewExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) override {
return errors::Internal("This is a dummy.");
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/graph_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,

Executor* executor;
TF_RETURN_IF_ERROR(
NewLocalExecutor(params, std::move(graph_to_run), &executor));
NewLocalExecutor(params, *graph_to_run, &executor));
std::unique_ptr<Executor> executor_unref(executor);

Executor::Args args;
Expand Down
9 changes: 5 additions & 4 deletions tensorflow/core/common_runtime/immutable_executor_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
}
}

Status ImmutableExecutorState::Initialize() {
const Graph& graph = *graph_;
Status ImmutableExecutorState::Initialize(const Graph& graph) {
TF_RETURN_IF_ERROR(gview_.Initialize(&graph));

// Build the information about frames in this subgraph.
Expand Down Expand Up @@ -263,8 +262,10 @@ Status ImmutableExecutorState::Initialize() {
return gview_.SetAllocAttrs(&graph, params_.device);
}

Status ImmutableExecutorState::InitializeScheduleInfo(ExecutorInternal::KernelStats* kernel_stats) {
for (const Node* n : graph_->nodes()) {
Status ImmutableExecutorState::InitializeScheduleInfo(
ExecutorInternal::KernelStats* kernel_stats,
const Graph& graph) {
for (const Node* n : graph.nodes()) {
if (IsSink(n)) continue;
const int id = n->id();
NodeItem* item = gview_.node(id);
Expand Down
13 changes: 6 additions & 7 deletions tensorflow/core/common_runtime/immutable_executor_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,15 @@ class ImmutableExecutorState {
int32 parallel_iterations;
};

explicit ImmutableExecutorState(const LocalExecutorParams& p,
std::unique_ptr<const Graph> g)
: params_(p), gview_(), graph_(std::move(g)) {}
explicit ImmutableExecutorState(const LocalExecutorParams& p)
: params_(p), gview_() {}
~ImmutableExecutorState();

Status Initialize();
Status Initialize(const Graph& graph);

Status InitializeScheduleInfo(ExecutorInternal::KernelStats* kernel_stats);
Status InitializeScheduleInfo(
ExecutorInternal::KernelStats* kernel_stats,
const Graph& graph);

// Process all Nodes in the current graph, attempting to infer the
// memory allocation attributes to be used wherever they may allocate
Expand All @@ -93,7 +94,6 @@ class ImmutableExecutorState {

const LocalExecutorParams& params() const { return params_; }
const GraphView& graph_view() const { return gview_; }
const Graph* graph() const { return graph_.get(); }
const std::vector<PendingCounts::Handle>& pending_ids() const {
return pending_ids_;
}
Expand Down Expand Up @@ -138,7 +138,6 @@ class ImmutableExecutorState {
// Owned.
LocalExecutorParams params_;
GraphView gview_;
std::unique_ptr<const Graph> graph_;
bool requires_control_flow_;
std::vector<PendingCounts::Handle> pending_ids_;

Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Benchmark::Benchmark(const string& device, Graph* g,

if (init) {
std::unique_ptr<Executor> init_exec;
TF_CHECK_OK(NewExecutor(executor_type, params, std::unique_ptr<Graph>(init),
TF_CHECK_OK(NewExecutor(executor_type, params, *init,
&init_exec));
Executor::Args args;
args.rendezvous = rendez_;
Expand All @@ -101,7 +101,7 @@ Benchmark::Benchmark(const string& device, Graph* g,
}

TF_CHECK_OK(
NewExecutor(executor_type, params, std::unique_ptr<Graph>(g), &exec_));
NewExecutor(executor_type, params, *g, &exec_));
}

Benchmark::~Benchmark() {
Expand Down
10 changes: 5 additions & 5 deletions tensorflow/core/distributed_runtime/graph_mgr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ GraphMgr::Item::~Item() {
for (const auto& unit : this->units) {
CHECK_NOTNULL(unit.device);
if (!graph_mgr->skip_cost_models_) {
graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph);
graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get());
}
delete unit.root;
unit.device->op_segment()->RemoveHold(this->session);
Expand Down Expand Up @@ -283,13 +283,13 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
TF_RETURN_IF_ERROR(
EnsureMemoryTypes(DeviceType(unit->device->device_type()),
unit->device->name(), subgraph.get()));
unit->graph = subgraph.get();
unit->graph = std::move(subgraph);
unit->build_cost_model = graph_options.build_cost_model();
if (unit->build_cost_model > 0) {
skip_cost_models_ = false;
}
TF_RETURN_IF_ERROR(
NewLocalExecutor(params, std::move(subgraph), &unit->root));
NewLocalExecutor(params, *unit->graph, &unit->root));
}
return Status::OK();
}
Expand Down Expand Up @@ -626,14 +626,14 @@ void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
std::unordered_map<string, const Graph*> device_to_graph;
for (const auto& unit : item->units) {
if (unit.build_cost_model > 0) {
device_to_graph[unit.device->name()] = unit.graph;
device_to_graph[unit.device->name()] = unit.graph.get();
}
}
collector->BuildCostModel(&cost_model_manager_, device_to_graph);

if (cost_graph != nullptr) {
for (const auto& unit : item->units) {
cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph)
cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph)
.IgnoreError();
}
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/distributed_runtime/graph_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class GraphMgr {
typedef GraphMgr ME;

struct ExecutionUnit {
Graph* graph = nullptr; // not owned.
std::unique_ptr<Graph> graph = nullptr;
Device* device = nullptr; // not owned.
Executor* root = nullptr; // owned: deleted in GraphMgr::Item::~Item().
FunctionLibraryRuntime* lib = nullptr; // not owned.
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/graph/graph_constructor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class GraphConstructor {
Options(const GraphConstructorOptions& in) // NOLINT(runtime/explicit)
: allow_internal_ops(in.allow_internal_ops),
expect_device_spec(in.expect_device_spec),
uniquify_names(false),
uniquify_prefix(false),
skip_mapped_nodes(false),
importing(false),
validate_colocation_constraints(false) {}
Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit)
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/data/dataset_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ Status DatasetOpsTestBase::RunFunction(
};

Executor* cur_exec;
TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &cur_exec));
TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
exec.reset(cur_exec);
FunctionCallFrame frame(arg_types, ret_types);
TF_RETURN_IF_ERROR(frame.SetArgs(args));
Expand Down
13 changes: 5 additions & 8 deletions tensorflow/core/kernels/data/single_threaded_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,11 @@ class SingleThreadedExecutorRegistrar {

private:
class Factory : public ExecutorFactory {
Status NewExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) override {
Executor* ret;
TF_RETURN_IF_ERROR(
NewSingleThreadedExecutor(params, std::move(graph), &ret));
NewSingleThreadedExecutor(params, graph, &ret));
out_executor->reset(ret);
return Status::OK();
}
Expand All @@ -372,11 +371,9 @@ static SingleThreadedExecutorRegistrar registrar;
} // namespace

Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Executor** executor) {
std::unique_ptr<SingleThreadedExecutorImpl> impl =
absl::make_unique<SingleThreadedExecutorImpl>(params);
TF_RETURN_IF_ERROR(impl->Initialize(*graph));
const Graph& graph, Executor** executor) {
auto impl = absl::make_unique<SingleThreadedExecutorImpl>(params);
TF_RETURN_IF_ERROR(impl->Initialize(graph));
*executor = impl.release();
return Status::OK();
}
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/core/kernels/data/single_threaded_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ namespace data {
// The single-threaded executor is primarily suitable for executing simple
// TensorFlow functions, such as one might find in a `tf.data` pipeline.
Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Executor** executor);
const Graph& graph, Executor** executor);

} // namespace data
} // namespace tensorflow
Expand Down
Loading