Skip to content

Commit

Permalink
add Dict2Dict test
Browse files Browse the repository at this point in the history
  • Loading branch information
jw3126 committed Mar 6, 2022
1 parent d70d650 commit b176667
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
3 changes: 1 addition & 2 deletions src/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ function Base.show(io::IO, o::InferenceSession)
input_names: $(input_names(o))
output_names: $(output_names(o))
execution_provider: $(repr(o.execution_provider))
"""
)
""")
end

input_names(o::InferenceSession) = o._input_names
Expand Down
Binary file added test/data/Dict2Dict.onnx
Binary file not shown.
16 changes: 16 additions & 0 deletions test/test_highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,22 @@ using ONNXRunTime: juliatype
@test out[1,2,2] == 0
@test out[1,2,3] == 0
end
@testset "Dict2Dict.onnx" begin
path = OX.testdatapath("Dict2Dict.onnx")
model = OX.load_inference(path, execution_provider=:cpu)
@test OX.input_names(model) == ["x", "y"]
@test OX.output_names(model) == ["x_times_y", "x_plus_y", "x_minus_y", "x_plus_1", "y_plus_2"]
nb = rand(1:10)
x = randn(Float32, nb,3)
y = randn(Float32, nb,3)
input = (;x,y)
out = model(input)
@test out.x_plus_y x .+ y
@test out.x_minus_y x .- y
@test out.x_times_y x .* y
@test out.x_plus_1 x .+ 1
@test out.y_plus_2 y .+ 2
end
end


Expand Down

0 comments on commit b176667

Please sign in to comment.