Skip to content

Commit

Permalink
Fix dtype of Clamp.
Browse files Browse the repository at this point in the history
  • Loading branch information
orenbenkiki committed May 20, 2024
1 parent a95ccd4 commit 9c75289
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 14 deletions.
2 changes: 1 addition & 1 deletion deps/jet.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash
set -e -o pipefail
julia --color=no deps/jet.jl 2>&1 | python3 deps/jet.py
JULIA_DEBUG="" julia --color=no deps/jet.jl 2>&1 | python3 deps/jet.py
2 changes: 1 addition & 1 deletion docs/v0.1.0/.documenter-siteinfo.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"documenter":{"julia_version":"1.10.0","generation_timestamp":"2024-05-20T15:09:19","documenter_version":"1.4.0"}}
{"documenter":{"julia_version":"1.10.0","generation_timestamp":"2024-05-20T16:19:41","documenter_version":"1.4.0"}}
38 changes: 29 additions & 9 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,17 +478,21 @@ Element-wise operation that converts every element to a value inside a range.
At least one of `min` and `max` must be specified.
"""
struct Clamp <: EltwiseOperation
dtype::Maybe{Type}
min::Float64
max::Float64
end
@query_operation Clamp

function Clamp(; min::StorageNumber = -Inf, max::StorageNumber = Inf)::Clamp
function Clamp(; dtype::Maybe{Type} = nothing, min::StorageNumber = -Inf, max::StorageNumber = Inf)::Clamp
@assert min < max
return Clamp(Float64(min), Float64(max))
return Clamp(dtype, Float64(min), Float64(max))
end

function Clamp(operation_name::Token, parameters_values::Dict{String, Token})::Clamp
dtype = parse_parameter_value(operation_name, "eltwise", parameters_values, "dtype", nothing) do parameter_value
return parse_number_dtype_value(operation_name, "dtype", parameter_value)
end
min = parse_parameter_value(operation_name, "eltwise", parameters_values, "min", -Inf) do parameter_value
return parse_number_value(operation_name, "min", parameter_value, Float64)
end
Expand All @@ -499,21 +503,37 @@ function Clamp(operation_name::Token, parameters_values::Dict{String, Token})::C
end
return value
end
return Clamp(min, max)
return Clamp(dtype, min, max)
end

function is_int(value::Float64)::Bool
return value == -Inf || value == Inf || isinteger(value)
end

function dtype_for_clamp(operation::Clamp, input_type::Type)::Type
if operation.dtype !== nothing
return operation.dtype
elseif input_type <: Integer && is_int(operation.min) && is_int(operation.max)
return int_dtype_for(input_type, nothing)
else
return float_dtype_for(input_type, nothing)
end
end

function compute_eltwise(
operation::Clamp,
input::Union{StorageMatrix{T}, StorageVector{T}},
)::Union{StorageMatrix{T}, StorageVector{T}} where {T <: StorageNumber}
output = copy_array(input)
clamp!(output, operation.min, operation.max)
input::Union{StorageMatrix, StorageVector},
)::Union{StorageMatrix, StorageVector}
dtype = dtype_for_clamp(operation, eltype(input))
output = similar(input, dtype)
output .= clamp.(input, operation.min, operation.max)
return output
end

function compute_eltwise(operation::Clamp, input::T)::T where {T <: StorageNumber}
function compute_eltwise(operation::Clamp, input::T)::StorageNumber where {T <: StorageNumber}
dtype = dtype_for_clamp(operation, T)
output = clamp(input, operation.min, operation.max)
return T(output)
return dtype(output)
end

"""
Expand Down
26 changes: 23 additions & 3 deletions test/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,29 @@ nested_test("operations") do
end

nested_test("vector") do
set_vector!(daf, "cell", "value", [-1.7, 2.3])
@test with_type(daf["/ cell : value % Clamp max 0"]) == ([-1.7, 0.0], Float64)
@test with_type(daf["/ cell : value % Clamp min 0"]) == ([0.0, 2.3], Float64)
nested_test("float") do
set_vector!(daf, "cell", "value", [-1.7, 2.3])
@test with_type(daf["/ cell : value % Clamp max 0"]) == ([-1.7, 0.0], Float64)
@test with_type(daf["/ cell : value % Clamp min 0"]) == ([0.0, 2.3], Float64)
end

nested_test("int") do
set_vector!(daf, "cell", "value", [-1, 2])
@test with_type(daf["/ cell : value % Clamp max 0"]) == ([-1, 0], Int64)
@test with_type(daf["/ cell : value % Clamp min 0"]) == ([0, 2], Int64)
end

nested_test("mix") do
set_vector!(daf, "cell", "value", [-1, 2])
@test with_type(daf["/ cell : value % Clamp max 0.5"]) == ([-1.0, 0.5], Float64)
@test with_type(daf["/ cell : value % Clamp min 0.5"]) == ([0.5, 2.0], Float64)
end

nested_test("dtype") do
set_vector!(daf, "cell", "value", [-1, 2])
@test with_type(daf["/ cell : value % Clamp dtype Float32 max 0"]) == ([-1, 0], Float32)
@test with_type(daf["/ cell : value % Clamp dtype Float32 min 0"]) == ([0, 2], Float32)
end
end

nested_test("matrix") do
Expand Down

0 comments on commit 9c75289

Please sign in to comment.