Skip to content

Commit

Permalink
remove timed enable multi-threading (#14)
Browse files Browse the repository at this point in the history
* remove TimerOutputs
  • Loading branch information
Moelf authored Mar 24, 2022
1 parent b176667 commit 9b30ed5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 66 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ONNXRunTime"
uuid = "e034b28e-924e-41b2-b98f-d2bbeb830c6a"
authors = ["Jan Weidner <[email protected]> and contributors"]
version = "0.2.3"
version = "0.3.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand All @@ -12,15 +12,13 @@ LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[compat]
ArgCheck = "2"
CEnum = "0.4"
DataStructures = "0.18"
DocStringExtensions = "0.8"
Requires = "1"
TimerOutputs = "0.5"
julia = "1.6"

[extras]
Expand Down
3 changes: 0 additions & 3 deletions src/ONNXRunTime.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
module ONNXRunTime
using Requires:@require
import TimerOutputs

const TIMER = TimerOutputs.TimerOutput()

function _perm(arr::AbstractArray{T,N}) where {T,N}
ntuple(i->N+1-i, N)
Expand Down
13 changes: 4 additions & 9 deletions src/capi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ This module closely follows the offical onnxruntime [C-API](https://github.com/m
See [here](https://github.com/microsoft/onnxruntime-inference-examples/blob/d031f879c9a8d33c8b7dc52c5bc65fe8b9e3960d/c_cxx/fns_candy_style_transfer/fns_candy_style_transfer.c) for a C code example.
"""
module CAPI
using ONNXRunTime: TIMER, reversedims_lazy
using TimerOutputs: @timeit
using ONNXRunTime: reversedims_lazy

using DocStringExtensions
using Libdl
Expand Down Expand Up @@ -901,14 +900,10 @@ end
$TYPEDSIGNATURES
"""
@timeit TIMER function GetTensorMutableData(api::OrtApi, tensor::OrtValue)::AbstractArray
function GetTensorMutableData(api::OrtApi, tensor::OrtValue)::AbstractArray
GC.@preserve tensor begin
@timeit TIMER "unsafe_GetTensorMutableData" begin
data_owned_by_tensor::PermutedDimsArray = unsafe_GetTensorMutableData(api, tensor)
end
@timeit TIMER "copy" begin
reversedims_lazy(copy(parent(data_owned_by_tensor)))
end
data_owned_by_tensor::PermutedDimsArray = unsafe_GetTensorMutableData(api, tensor)
reversedims_lazy(copy(parent(data_owned_by_tensor)))
end
end

Expand Down
90 changes: 39 additions & 51 deletions src/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ using ArgCheck
using LazyArtifacts
using DataStructures: OrderedDict
using DocStringExtensions
using TimerOutputs: @timeit, TimerOutput
################################################################################
##### testdatapath
################################################################################
Expand All @@ -29,7 +28,6 @@ struct InferenceSession
allocator::OrtAllocator
_input_names::Vector{String}
_output_names::Vector{String}
timer::TimerOutput
end
function Base.show(io::IO, o::InferenceSession)
print(io,
Expand Down Expand Up @@ -60,8 +58,8 @@ end
$TYPEDSIGNATURES
"""
function load_inference(path::AbstractString; execution_provider::Symbol=:cpu,
envname::AbstractString="defaultenv", timer=TIMER,
)::InferenceSession
envname::AbstractString="defaultenv",
)::InferenceSession
api = GetApi(;execution_provider)
env = CreateEnv(api, name=envname)
if execution_provider === :cpu
Expand All @@ -87,17 +85,14 @@ function load_inference(path::AbstractString; execution_provider::Symbol=:cpu,
@check allunique(_input_names)
@check allunique(_output_names)
return InferenceSession(api, execution_provider, session, meminfo, allocator,
_input_names,
_output_names,
timer,
)
_input_names,
_output_names,
)
end

@timeit o.timer function make_input_tensor(o::InferenceSession, inputs, key)
function make_input_tensor(o::InferenceSession, inputs, key)
arr = inputs[keytype(inputs)(key)]
@timeit o.timer "clayout" begin
cstorage = vec(reversedims(arr)::Array)
end
cstorage = vec(reversedims(arr)::Array)
CreateTensorWithDataAsOrtValue(o.api, o.meminfo, cstorage, size(arr))
end

Expand Down Expand Up @@ -136,46 +131,39 @@ be a `NamedTuple` or an `AbstractDict`. Optionally `output_names` can be passed.
In this case only the outputs whose name is contained in `output_names` are computed.
"""
function (o::InferenceSession)(
inputs,
output_names=nothing
)
timer = o.timer
@timeit timer "pre CAPI.Run" begin
if output_names === nothing
output_names = @__MODULE__().output_names(o)
end
@argcheck o.execution_provider in EXECUTION_PROVIDERS
@argcheck eltype(output_names) <: Union{AbstractString, Symbol}
@argcheck keytype(inputs) <: Union{AbstractString, Symbol}
expected_input_names = ONNXRunTime.input_names(o)
for key in keys(inputs)
if !(String(key) in expected_input_names)
msg = """
Invalid input name.
Expected: $(expected_input_names)
Got: $(key)
"""
throw(ArgumentError(msg))
end
end
expected_output_names = ONNXRunTime.output_names(o)
for name in output_names
if !(String(name) in expected_output_names)
msg = """
Invalid output name.
Expected: $(expected_output_names)
Got: $(name)
"""
throw(ArgumentError(msg))
end
end
inp_names, input_tensors = prepare_inputs(o, inputs)
run_options = nothing
inputs,
output_names=nothing
)
if output_names === nothing
output_names = @__MODULE__().output_names(o)
end
@timeit timer "CAPI.Run" begin
output_tensors = Run(o.api, o.session, run_options, inp_names, input_tensors, output_names)
@argcheck o.execution_provider in EXECUTION_PROVIDERS
@argcheck eltype(output_names) <: Union{AbstractString, Symbol}
@argcheck keytype(inputs) <: Union{AbstractString, Symbol}
expected_input_names = ONNXRunTime.input_names(o)
for key in keys(inputs)
if !(String(key) in expected_input_names)
msg = """
Invalid input name.
Expected: $(expected_input_names)
Got: $(key)
"""
throw(ArgumentError(msg))
end
end
@timeit timer "post CAPI.Run" begin
make_output(o, inputs, output_names, output_tensors)
expected_output_names = ONNXRunTime.output_names(o)
for name in output_names
if !(String(name) in expected_output_names)
msg = """
Invalid output name.
Expected: $(expected_output_names)
Got: $(name)
"""
throw(ArgumentError(msg))
end
end
inp_names, input_tensors = prepare_inputs(o, inputs)
run_options = nothing
output_tensors = Run(o.api, o.session, run_options, inp_names, input_tensors, output_names)
make_output(o, inputs, output_names, output_tensors)
end

3 comments on commit 9b30ed5

@jw3126
Copy link
Owner

@jw3126 jw3126 commented on 9b30ed5 Mar 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/57257

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" 9b30ed5e6656152e9aaaacda4919e35630ce75f4
git push origin v0.3.0

@jw3126
Copy link
Owner

@jw3126 jw3126 commented on 9b30ed5 Mar 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @Moelf thanks for the PR!

Please sign in to comment.