Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -808,13 +808,12 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st
if (output.second->Exists()) {
auto name = output.second->Name();
if (std::find(graph_output_names.begin(), graph_output_names.end(), name) == graph_output_names.end()) {
// if graph is split we dont know if output is used so we need this, otherwise if the graph isn't split
// then we can safely assume this output is a dangling output from a node and to discard it as part of the
// final graph output
if(is_graph_split)
{
output_names.push_back(name);
}
// if graph is split we dont know if output is used so we need this, otherwise if the graph isn't split
// then we can safely assume this output is a dangling output from a node and to discard it as part of the
// final graph output
if (is_graph_split) {
output_names.push_back(name);
}
} else {
graph_out_names.insert(name);
}
Expand Down Expand Up @@ -1316,7 +1315,23 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
input_shapes.push_back(tensor_shape->dim(j).dim_value());
}
}
model_cache_file = model_cache_path_ / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr");
// capture flags outside of name/inputs that are used when models are compiled
// Each of these will change the final compiled model and need to be captured to ensure
// hash uses the quantization flags and modes
auto get_quant_and_tune_flags = [=]() {
std::vector<std::int64_t> data_out{};

data_out.push_back(static_cast<int64_t>(fp16_enable_));
data_out.push_back(static_cast<int64_t>(fp8_enable_));
data_out.push_back(static_cast<int64_t>(bf16_enable_));
data_out.push_back(static_cast<int64_t>(int8_enable_));
data_out.push_back(static_cast<int64_t>(mem_limit_));
data_out.push_back(static_cast<int64_t>(exhaustive_tune_));

return data_out;
};

model_cache_file = model_cache_path_ / (mxr_filename_prefix + make_hash(input_shapes) + "-" + make_hash(get_quant_and_tune_flags()) + ".mxr");
}

// map parameter input name to index
Expand Down Expand Up @@ -1385,7 +1400,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_,
int8_calibration_cache_available_, dynamic_range_map_,
model_cache_path_.string(), dump_model_ops_};
model_cache_path_.string(), dump_model_ops_, exhaustive_tune_, mem_limit_};
*state = p.release();
return 0;
};
Expand All @@ -1411,6 +1426,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
bool fp8_enable = mgx_state->fp8_enable;
bool int8_enable = mgx_state->int8_enable;
bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available;
bool exhaustive_tune = mgx_state->exhaustive_tune;
size_t mem_limit = mgx_state->mem_limit;

// mean no program at all, so need to get the input shape info
// from input data
Expand Down Expand Up @@ -1469,7 +1486,19 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
std::filesystem::path model_cache_file;
// empty cache path means the MXR caching is disabled - always compile
if (!model_cache_path_.empty()) {
model_cache_file = mgx_state->model_cache_dir / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr");
auto get_quant_and_tune_flags = [=]() {
std::vector<std::int64_t> data_out{};

data_out.push_back(static_cast<int64_t>(fp16_enable));
data_out.push_back(static_cast<int64_t>(fp8_enable));
data_out.push_back(static_cast<int64_t>(bf16_enable));
data_out.push_back(static_cast<int64_t>(int8_enable));
data_out.push_back(static_cast<int64_t>(mem_limit));
data_out.push_back(static_cast<int64_t>(exhaustive_tune));

return data_out;
};
model_cache_file = mgx_state->model_cache_dir / (mxr_filename_prefix + make_hash(input_shapes) + "-" + make_hash(get_quant_and_tune_flags()) + ".mxr");
}
if (!load_precompiled_model(prog, model_cache_file)) {
LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct MIGraphXFuncState {
std::filesystem::path model_cache_dir;
bool dump_model_ops = false;
bool exhaustive_tune = false;
size_t mem_limit;
};

// Logical device representation.
Expand Down
Loading