Skip to content

Commit

Permalink
savepoint
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekziade committed Jan 13, 2025
1 parent a29f078 commit c310db9
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
if (onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO)
# "-g3" generates DWARF format debug info.
# NOTE: With debug info enabled, web assembly artifacts will be very huge (>1GB). So we offer an option to build without debug info.
set(CMAKE_CXX_FLAGS_DEBUG "-g3")
set(CMAKE_CXX_FLAGS_DEBUG "-g2")
else()
set(CMAKE_CXX_FLAGS_DEBUG "-g2")
endif()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,17 +263,22 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
//std::cout << "Calling f32Multiply\n";
// should split in parts and call ctx.ParallelFor just on the rows part

#if 0
// rowsA = M
// width = K
// colsB = N
#if 0
size_t rowsA = static_cast<size_t>(helper.M());

if (rowsA > 1) {
size_t width = static_cast<size_t>(helper.K());
size_t colsB = static_cast<size_t>(helper.N());

const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());
//std::cout << "Calling GeckoMatmulIntegerToFloat\n";
//int threads = concurrency::ThreadPool::DegreeOfParallelism(ctx->GetOperatorThreadPool());
//std::cout << "degree of parallelism: " << threads << "\n";
//std::cout << "batch size: " << num_gemms << "\n";



GeckoMatmulIntegerToFloat(a_data,
a_zp,
Expand All @@ -291,7 +296,7 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
#endif
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());

//}
// }

//
/*
Expand Down
6 changes: 0 additions & 6 deletions onnxruntime/core/mlas/lib/qgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ Return Value:
{
const ptrdiff_t ThreadIdM = ThreadId / WorkBlock->ThreadCountN;
const ptrdiff_t ThreadIdN = ThreadId % WorkBlock->ThreadCountN;

//
// Partition the operation along the M dimension.
//
Expand Down Expand Up @@ -197,16 +196,11 @@ MlasGemmBatch(
WorkBlock.ThreadCountN = 1;
}
TargetThreadCount = ThreadsPerGemm * BatchN;
//std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << std::endl;
//std::cout << "TargetThreadCount: " << TargetThreadCount << std::endl;
//std::cout << "MaximumThreadCount: " << MaximumThreadCount << std::endl;



MlasTrySimpleParallel(ThreadPool, TargetThreadCount, [&](ptrdiff_t tid) {
const auto gemm_i = tid / ThreadsPerGemm;
const auto blk_i = tid % ThreadsPerGemm;
//std::cout << "gemm_i: " << gemm_i << " blk_i: " << blk_i << std::endl;
MlasGemmQuantThreaded(&WorkBlock, &Shape, &DataParams[gemm_i], blk_i);
});
}
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,12 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
data[i].alpha = alpha_attr_;
data[i].beta = 0.0f;
}

//auto start = std::chrono::steady_clock::now();
MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans,
M, N, K, data.data(), max_len, thread_pool);
//auto end = std::chrono::steady_clock::now();
//std::cout << "MatMul<float>," << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() << "," << max_len << std::endl;
}
return Status::OK();
}
Expand Down
26 changes: 26 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,8 @@ common::Status InferenceSession::Initialize() {
if (session_profiler_.IsEnabled()) {
tp = session_profiler_.Start();
}
//std::cout << "session Initialize" << std::endl;
//auto startInit = std::chrono::steady_clock::now();

ORT_TRY {
LOGS(*session_logger_, INFO) << "Initializing session.";
Expand All @@ -1720,6 +1722,9 @@ common::Status InferenceSession::Initialize() {
}

// Verify that there are no external initializers in the graph if external data is disabled.
//std::cout << "session Initialize loading main graph" << std::endl;


onnxruntime::Graph& graph = model_->MainGraph();
#ifdef DISABLE_EXTERNAL_INITIALIZERS
const InitializedTensorSet& initializers = graph.GetAllInitializedTensors();
Expand Down Expand Up @@ -1767,6 +1772,8 @@ common::Status InferenceSession::Initialize() {
TraceLoggingWriteStart(session_activity, "OrtInferenceSessionActivity");
session_activity_started_ = true;
#endif
//std::cout << "session Initialize - creating state" << std::endl;


// now that we have all the execution providers, create the session state
session_state_ = std::make_unique<SessionState>(
Expand Down Expand Up @@ -1824,6 +1831,10 @@ common::Status InferenceSession::Initialize() {
}();

if (!loading_ort_format) {
//std::cout << "session Initialize not using ort" << std::endl;



#if !defined(ORT_MINIMAL_BUILD)
const auto minimal_build_opt_config_value = session_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsConfigMinimalBuildOptimizations, "");
Expand All @@ -1845,6 +1856,10 @@ common::Status InferenceSession::Initialize() {
*session_logger_));

#ifdef USE_DML
// std::cout << "session Initialize using DML" << std::endl;



const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider);

if (dmlExecutionProvider) {
Expand Down Expand Up @@ -1900,10 +1915,16 @@ common::Status InferenceSession::Initialize() {
#endif

// apply any transformations to the main graph and any subgraphs
//auto start = std::chrono::steady_clock::now();
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, saving_ort_format));
//auto end = std::chrono::steady_clock::now();
//std::cout << "Graph transformations took " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << " ms" << std::endl;

// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
//start = std::chrono::steady_clock::now();
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
//end = std::chrono::steady_clock::now();
//std::cout << "Graph resolution took " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << " ms" << std::endl;

// Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP.
//
Expand Down Expand Up @@ -2052,6 +2073,9 @@ common::Status InferenceSession::Initialize() {
"Loading anything other than ORT format models is not enabled in this build."));
#endif // !defined(ORT_MINIMAL_BUILD)
} else {
//std::cout << "session Initialize - loading ort" << std::endl;


ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_,
*session_state_, session_options_.config_options, *session_logger_));

Expand Down Expand Up @@ -2171,6 +2195,8 @@ common::Status InferenceSession::Initialize() {
}
}

//auto endInitialization = std::chrono::steady_clock::now();
//std::cout << "session Initialize - Initialization time: " << std::chrono::duration_cast<std::chrono::milliseconds>(endInitialization - startInit).count() << " ms" << std::endl;
return status;
}
#if defined(_MSC_VER) && !defined(__clang__)
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/util/math_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
// Modifications Copyright (c) Microsoft.

#include <iostream>
#include <chrono>
#include "core/util/math_cpuonly.h"
#include "core/util/math.h"
#include "core/framework/float16.h"
Expand Down

0 comments on commit c310db9

Please sign in to comment.