From 9c7528933f6245fcf7ac5ca63dc23e5e6bc1055b Mon Sep 17 00:00:00 2001 From: Oren Ben-Kiki Date: Mon, 20 May 2024 16:20:06 +0300 Subject: [PATCH] Fix dtype of Clamp. --- deps/jet.sh | 2 +- docs/v0.1.0/.documenter-siteinfo.json | 2 +- src/operations.jl | 38 ++++++++++++++++++++------- test/operations.jl | 26 +++++++++++++++--- 4 files changed, 54 insertions(+), 14 deletions(-) diff --git a/deps/jet.sh b/deps/jet.sh index 0c3a4d1..e1daacd 100755 --- a/deps/jet.sh +++ b/deps/jet.sh @@ -1,3 +1,3 @@ #!/bin/bash set -e -o pipefail -julia --color=no deps/jet.jl 2>&1 | python3 deps/jet.py +JULIA_DEBUG="" julia --color=no deps/jet.jl 2>&1 | python3 deps/jet.py diff --git a/docs/v0.1.0/.documenter-siteinfo.json b/docs/v0.1.0/.documenter-siteinfo.json index 18d11c4..e75c094 100644 --- a/docs/v0.1.0/.documenter-siteinfo.json +++ b/docs/v0.1.0/.documenter-siteinfo.json @@ -1 +1 @@ -{"documenter":{"julia_version":"1.10.0","generation_timestamp":"2024-05-20T15:09:19","documenter_version":"1.4.0"}} \ No newline at end of file +{"documenter":{"julia_version":"1.10.0","generation_timestamp":"2024-05-20T16:19:41","documenter_version":"1.4.0"}} \ No newline at end of file diff --git a/src/operations.jl b/src/operations.jl index 323e845..c696847 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -478,17 +478,21 @@ Element-wise operation that converts every element to a value inside a range. At least one of `min` and `max` must be specified. """ struct Clamp <: EltwiseOperation + dtype::Maybe{Type} min::Float64 max::Float64 end @query_operation Clamp -function Clamp(; min::StorageNumber = -Inf, max::StorageNumber = Inf)::Clamp +function Clamp(; dtype::Maybe{Type} = nothing, min::StorageNumber = -Inf, max::StorageNumber = Inf)::Clamp @assert min < max - return Clamp(Float64(min), Float64(max)) + return Clamp(dtype, Float64(min), Float64(max)) end function Clamp(operation_name::Token, parameters_values::Dict{String, Token})::Clamp + dtype = parse_parameter_value(operation_name, "eltwise", parameters_values, "dtype", nothing) do parameter_value + return parse_number_dtype_value(operation_name, "dtype", parameter_value) + end min = parse_parameter_value(operation_name, "eltwise", parameters_values, "min", -Inf) do parameter_value return parse_number_value(operation_name, "min", parameter_value, Float64) end @@ -499,21 +503,37 @@ function Clamp(operation_name::Token, parameters_values::Dict{String, Token})::C end return value end - return Clamp(min, max) + return Clamp(dtype, min, max) +end + +function is_int(value::Float64)::Bool + return value == -Inf || value == Inf || isinteger(value) +end + +function dtype_for_clamp(operation::Clamp, input_type::Type)::Type + if operation.dtype !== nothing + return operation.dtype + elseif input_type <: Integer && is_int(operation.min) && is_int(operation.max) + return int_dtype_for(input_type, nothing) + else + return float_dtype_for(input_type, nothing) + end end function compute_eltwise( operation::Clamp, - input::Union{StorageMatrix{T}, StorageVector{T}}, -)::Union{StorageMatrix{T}, StorageVector{T}} where {T <: StorageNumber} - output = copy_array(input) - clamp!(output, operation.min, operation.max) + input::Union{StorageMatrix, StorageVector}, +)::Union{StorageMatrix, StorageVector} + dtype = dtype_for_clamp(operation, eltype(input)) + output = similar(input, dtype) + output .= clamp.(input, operation.min, operation.max) return output end -function compute_eltwise(operation::Clamp, input::T)::T where {T <: StorageNumber} +function compute_eltwise(operation::Clamp, input::T)::StorageNumber where {T <: StorageNumber} + dtype = dtype_for_clamp(operation, T) output = clamp(input, operation.min, operation.max) - return T(output) + return dtype(output) end """ diff --git a/test/operations.jl b/test/operations.jl index 01d150b..767444a 100644 --- a/test/operations.jl +++ b/test/operations.jl @@ -97,9 +97,29 @@ nested_test("operations") do end nested_test("vector") do - set_vector!(daf, "cell", "value", [-1.7, 2.3]) - @test with_type(daf["/ cell : value % Clamp max 0"]) == ([-1.7, 0.0], Float64) - @test with_type(daf["/ cell : value % Clamp min 0"]) == ([0.0, 2.3], Float64) + nested_test("float") do + set_vector!(daf, "cell", "value", [-1.7, 2.3]) + @test with_type(daf["/ cell : value % Clamp max 0"]) == ([-1.7, 0.0], Float64) + @test with_type(daf["/ cell : value % Clamp min 0"]) == ([0.0, 2.3], Float64) + end + + nested_test("int") do + set_vector!(daf, "cell", "value", [-1, 2]) + @test with_type(daf["/ cell : value % Clamp max 0"]) == ([-1, 0], Int64) + @test with_type(daf["/ cell : value % Clamp min 0"]) == ([0, 2], Int64) + end + + nested_test("mix") do + set_vector!(daf, "cell", "value", [-1, 2]) + @test with_type(daf["/ cell : value % Clamp max 0.5"]) == ([-1.0, 0.5], Float64) + @test with_type(daf["/ cell : value % Clamp min 0.5"]) == ([0.5, 2.0], Float64) + end + + nested_test("dtype") do + set_vector!(daf, "cell", "value", [-1, 2]) + @test with_type(daf["/ cell : value % Clamp dtype Float32 max 0"]) == ([-1, 0], Float32) + @test with_type(daf["/ cell : value % Clamp dtype Float32 min 0"]) == ([0, 2], Float32) + end end nested_test("matrix") do