diff --git a/core/Integer.savi b/core/Integer.savi index 306fa5b3..c9dcbc3c 100644 --- a/core/Integer.savi +++ b/core/Integer.savi @@ -120,6 +120,8 @@ :: highest bits shifting out of the bounds of the `bit_width` to disappear, :: and the lowest bits being filled by zeroes in the empty space left behind. :: + :: If more than `bit_width` bits are shifted, the result is all zero bits. + :: :: Because each bit represents a successive power of two, this operation is :: equivalent to multiplying the value by 2 the given number of times. :fun val bit_shl(bits U8) T @@ -131,8 +133,12 @@ :: lowest bits shifting out of the bounds of the `bit_width` to disappear, :: and the highest bits being filled by zeroes in the empty space left behind. :: + :: If more than `bit_width` bits are shifted, the result is all zero bits. + :: :: Because each bit represents a successive power of two, this operation is - :: equivalent to dividing the value by 2 the given number of times. + :: equivalent to dividing the value by 2 the given number of times, provided + :: that the value is a positive integer rather than a negative integer + :: (as a shifted negative integer will have its sign bit filled with a zero). :fun val bit_shr(bits U8) T :: Do a bitwise "rotate left" on this value by the given number of bits. diff --git a/spec/core/Numeric.Spec.savi b/spec/core/Numeric.Spec.savi index d72c3c29..117efcb0 100644 --- a/spec/core/Numeric.Spec.savi +++ b/spec/core/Numeric.Spec.savi @@ -372,6 +372,39 @@ assert: U16[0b1110010110001010].bit_rotl(5) == 0b1011000101011100 assert: U16[0b1110010110001010].bit_shr(5) == 0b0000011100101100 assert: U16[0b1110010110001010].bit_rotr(5) == 0b0101011100101100 + assert: True.bit_shl(0) == True + assert: True.bit_shl(1) == False + assert: True.bit_shr(0) == True + assert: True.bit_shr(1) == False + assert: False.bit_shl(0) == False + assert: False.bit_shl(1) == False + assert: False.bit_shr(0) == False + assert: False.bit_shr(1) == False + + :it "uses logical bit shift right, even for signed integers" + // Some languages/compilers use "arithmetic right shift" for signed integers + // in which the sign bit is preserved during shifting, such that shifting + // by one can be treated as a proxy for dividing by two, even for + // negative numbers (which must keep the sign bit as 1 to remain negative). + // + // However, in Savi, we use only "logical right shift", wherein for any + // non-zero shift amount, the new most significant bits will always be zero. + // This makes bit shifting operations work consistently for signed and + // unsigned integers, but makes it not a tenable practice to use shifting + // as a proxy for dividing by two. Just use division instead, and let LLVM + // optimize division by twos into arithmetic bit shifts where appropriate. + + assert: I16[0b1011011100111101].bit_shr(0) == 0b1011011100111101 + assert: I16[0b1011011100111101].bit_shr(1) == 0b0101101110011110 + assert: I16[0b1011011100111101].bit_shr(5) == 0b0000010110111001 + assert: I16[0b1011011100111101].bit_shr(13) == 0b0000000000000101 + assert: I16[0b1011011100111101].bit_shr(16) == 0b0000000000000000 + + assert: U16[0b1011011100111101].bit_shr(0) == 0b1011011100111101 + assert: U16[0b1011011100111101].bit_shr(1) == 0b0101101110011110 + assert: U16[0b1011011100111101].bit_shr(5) == 0b0000010110111001 + assert: U16[0b1011011100111101].bit_shr(13) == 0b0000000000000101 + assert: U16[0b1011011100111101].bit_shr(16) == 0b0000000000000000 :it "implements special multiplication without overflow by returning a pair" product = U8[99].wide_multiply(200) diff --git a/src/savi/compiler/code_gen.cr b/src/savi/compiler/code_gen.cr index 8879b3f7..90d236db 100644 --- a/src/savi/compiler/code_gen.cr +++ b/src/savi/compiler/code_gen.cr @@ -1265,28 +1265,36 @@ class Savi::Compiler::CodeGen @builder.xor(params[0], params[1]) when "bit_shl" raise "bit_shl float" if gtype.type_def.is_floating_point_numeric?(ctx) - bits = gen_numeric_conv(@gtypes["U8"], gtype, params[1]) - clamp = llvm_type_of(gtype).const_int(bit_width_of(gtype) - 1) - bits = @builder.select( - @builder.icmp(LLVM::IntPredicate::ULE, bits, clamp), - bits, - clamp, - ) - @builder.shl(params[0], bits) + + bits = params[1] + all_bits = @i8.const_int(bit_width_of(gtype)) + is_all = @builder.icmp(LLVM::IntPredicate::UGE, bits, all_bits) + + zero_block = gen_block("zero") + normal_block = gen_block("normal") + @builder.cond(is_all, zero_block, normal_block) + + finish_block_and_move_to(zero_block) + @builder.ret(llvm_type_of(gtype).const_int(0)) + + finish_block_and_move_to(normal_block) + @builder.shl(params[0], gen_numeric_conv(@gtypes["U8"], gtype, bits)) when "bit_shr" raise "bit_shr float" if gtype.type_def.is_floating_point_numeric?(ctx) - bits = gen_numeric_conv(@gtypes["U8"], gtype, params[1]) - clamp = llvm_type_of(gtype).const_int(bit_width_of(gtype) - 1) - bits = @builder.select( - @builder.icmp(LLVM::IntPredicate::ULE, bits, clamp), - bits, - clamp, - ) - if gtype.type_def.is_signed_numeric?(ctx) - @builder.ashr(params[0], bits) - else - @builder.lshr(params[0], bits) - end + + bits = params[1] + all_bits = @i8.const_int(bit_width_of(gtype)) + is_all = @builder.icmp(LLVM::IntPredicate::UGE, bits, all_bits) + + zero_block = gen_block("zero") + normal_block = gen_block("normal") + @builder.cond(is_all, zero_block, normal_block) + + finish_block_and_move_to(zero_block) + @builder.ret(llvm_type_of(gtype).const_int(0)) + + finish_block_and_move_to(normal_block) + @builder.lshr(params[0], gen_numeric_conv(@gtypes["U8"], gtype, bits)) when "invert" raise "invert float" if gtype.type_def.is_floating_point_numeric?(ctx) @builder.not(params[0])