Skip to content

Commit

Permalink
[FIX] Add missing dtype conversions
Browse files: browse the repository at this point in the history
  • Loading branch information
PiotrKrzem committed Aug 20, 2024
1 parent 23092e8 commit 71f4490
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 12 deletions.
54 changes: 43 additions & 11 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/random_uniform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,12 +644,42 @@ void MersenneTwisterGenerator<isa>::generate() {
registersPool = RegistersPool::create(isa, {rax, rcx, rsp, rdi, k0});

r64_dst = getReg64();
r64_state = getReg64();
r64_state_shift = getReg64();
r64_step = getReg64();
r64_work_amount = getReg64();
r64_elements_remaining = getReg64();
r64_optimization_enabled = getReg64();
r64_output_type = getReg64();

mov(r64_work_amount, ptr[r64_params + GET_OFF(work_amount)]);
mov(r64_dst, ptr[r64_params + GET_OFF(dst_ptr)]);
mov(r64_state_id, ptr[r64_params + GET_OFF(state_id)]);
mov(r64_state_shift, ptr[r64_params + GET_OFF(state_shift)]);
mov(r64_step, ptr[r64_params + GET_OFF(step)]);
mov(r64_work_amount, ptr[r64_params + GET_OFF(work_amount)]);
mov(r64_elements_remaining, ptr[r64_params + GET_OFF(elements_remaining)]);
mov(r64_optimization_enabled, ptr[r64_params + GET_OFF(optimization_enabled)]);
mov(r64_output_type, ptr[r64_params + GET_OFF(out_data_type)]);

// switch (m_jcp.out_data_type) {
// case element::f32:
// mov(r64_output_type, FLOAT_AS_VALUE);
// break;
// case element::f16:
// mov(r64_output_type, FLOAT16_AS_VALUE);
// break;
// case element::bf16:
// mov(r64_output_type, BFLOAT16_AS_VALUE);
// break;
// case element::i32:
// mov(r64_output_type, INT_AS_VALUE);
// break;
// case element::i64:
// mov(r64_output_type, INT64_AS_VALUE);
// break;
// default:
// break;
// }

initVectors();
process();
Expand All @@ -662,7 +692,6 @@ template <>
void MersenneTwisterGenerator<x64::avx512_core>::initVectors() {
const auto r64_aux = getReg64();
const auto r32_aux = Xbyak::Reg32(r64_aux.getIdx());
const auto r16_aux = Xbyak::Reg16(r64_aux.getIdx());

v_min = getVmm();
v_range = getVmm();
Expand All @@ -672,7 +701,7 @@ void MersenneTwisterGenerator<x64::avx512_core>::initVectors() {
v_divisor = getVmm();

// Initialize constants based on the requested data type
switch (m_output_prc) {
switch (m_jcp.out_data_type) {
case element::f32:
BROADCAST_R(vpbroadcastd, v_mask, r32_aux, (1 << std::numeric_limits<float>::digits) - 1)
BROADCAST_R(vpbroadcastd, v_divisor, r32_aux, 1.0f / (1 << std::numeric_limits<float>::digits))
Expand All @@ -691,6 +720,9 @@ void MersenneTwisterGenerator<x64::avx512_core>::initVectors() {

BROADCAST_P(vpbroadcastd, v_min, r64_aux, min_ptr)
BROADCAST_P(vpbroadcastd, v_range, r64_aux, range_ptr)

BROADCAST_R(vpbroadcastd, v_const_1, r32_aux, MT_CONST_1)
BROADCAST_R(vpbroadcastd, v_const_2, r32_aux, MT_CONST_2)
}

template <x64::cpu_isa_t isa>
Expand Down Expand Up @@ -729,10 +761,10 @@ void MersenneTwisterGenerator<isa>::generateRandomNumbers(const Vmm& v_result, c
psrld(v_result, MT_U);
pxor(v_result, v_state);
pslld(v_result, MT_S);
pand(v_result, MT_CONST_1);
pand(v_result, v_const_1);
pxor(v_result, v_state);
pslld(v_result, MT_T);
pand(v_result, MT_CONST_2);
pand(v_result, v_const_2);
pxor(v_result, v_state);
psrld(v_result, MT_L);
pxor(v_result, v_state);
Expand All @@ -742,21 +774,21 @@ void MersenneTwisterGenerator<isa>::generateRandomNumbers(const Vmm& v_result, c
}

template <x64::cpu_isa_t isa>
void MersenneTwisterGenerator<isa>::convertToOutputTypeMersenne(const Vmm& v_result, const Vmm& v_min, const Vmm& v_range, const Vmm& v_dst, const Reg64& r64_elements_remaining) {
void MersenneTwisterGenerator<isa>::convertToOutputTypeMersenne(const Vmm& v_result, const Vmm& v_min, const Vmm& v_range, const Vmm& v_dst, const Xbyak::Reg64& r64_elements_remaining) {
using namespace Xbyak;

Label float_case, float16_case, bfloat16_case, int32_case, int64_case, end;

// Check the output type and jump to the corresponding case
cmp(r64_output_type, element::f32);
cmp(r64_output_type, FLOAT_AS_VALUE);
je(float_case);
cmp(r64_output_type, element::f16);
cmp(r64_output_type, FLOAT16_AS_VALUE);
je(float16_case);
cmp(r64_output_type, element::bf16);
cmp(r64_output_type, BFLOAT16_AS_VALUE);
je(bfloat16_case);
cmp(r64_output_type, element::i32);
cmp(r64_output_type, INT_AS_VALUE);
je(int32_case);
cmp(r64_output_type, element::i64);
cmp(r64_output_type, INT64_AS_VALUE);
je(int64_case);
jmp(end);

Expand Down
20 changes: 19 additions & 1 deletion src/plugins/intel_cpu/src/nodes/kernels/x64/random_uniform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ struct MersenneTwisterGeneratorCallArgs {
uint64_t work_amount = 0lu;
uint64_t elements_remaining = 0lu;
bool optimization_enabled = false;
uint32_t out_data_type = 0u;

};

template <dnnl::impl::cpu::x64::cpu_isa_t isa>
Expand Down Expand Up @@ -123,13 +125,16 @@ class MersenneTwisterGenerator : public JitKernel<GeneratorCompileParams, Mersen
isa == dnnl::impl::cpu::x64::sse41, Xbyak::Xmm,
Xbyak::Ymm>::type;


RegistersPool::Reg<Xbyak::Reg64> r64_dst;
RegistersPool::Reg<Xbyak::Reg64> r64_state;
RegistersPool::Reg<Xbyak::Reg64> r64_state_id;
RegistersPool::Reg<Xbyak::Reg64> r64_state_shift;
RegistersPool::Reg<Xbyak::Reg64> r64_step;
RegistersPool::Reg<Xbyak::Reg64> r64_work_amount;
RegistersPool::Reg<Xbyak::Reg64> r64_elements_remaining;
RegistersPool::Reg<Xbyak::Reg64> r64_optimization_enabled;
RegistersPool::Reg<Xbyak::Reg64> r64_output_type;



const Xbyak::Reg64 r64_params = Xbyak::Reg64(dnnl::impl::cpu::x64::abi_param_regs[0]);
Expand All @@ -149,6 +154,9 @@ class MersenneTwisterGenerator : public JitKernel<GeneratorCompileParams, Mersen
RegistersPool::Reg<Vmm> v_result_bitshift_15_const_2;
RegistersPool::Reg<Vmm> v_result_bitshift_18;

RegistersPool::Reg<Vmm> v_const_1;
RegistersPool::Reg<Vmm> v_const_2;

// Vector registers for conversion.
RegistersPool::Reg<Vmm> v_mask;
RegistersPool::Reg<Vmm> v_divisor;
Expand All @@ -160,6 +168,8 @@ class MersenneTwisterGenerator : public JitKernel<GeneratorCompileParams, Mersen

void generateRandomNumbers(const Vmm& v_dst_0, const Vmm& v_dst_1);

void convertToOutputTypeMersenne(const Vmm& v_result, const Vmm& v_min, const Vmm& v_range, const Vmm& v_dst, const Xbyak::Reg64& r64_elements_remaining);

// Mersenne Twister constants
static constexpr uint32_t MT_CONST_1 = 0x9D2C5680;
static constexpr uint32_t MT_CONST_2 = 0xEFC60000;
Expand All @@ -171,6 +181,14 @@ class MersenneTwisterGenerator : public JitKernel<GeneratorCompileParams, Mersen
static constexpr uint32_t MT_L = 18;
static constexpr uint32_t MT_4_ELEMENTS = 4;
static constexpr uint32_t MT_2_ELEMENTS = 2;

// Integer codes identifying the requested output data type.
// convertToOutputTypeMersenne compares r64_output_type (loaded from the
// out_data_type call argument) against these values to select the f32, f16,
// bf16, i32 or i64 conversion branch.
// NOTE(review): RandomUniform (nodes/random_uniform.hpp) declares constants
// with the same names and values; the two sets must stay in sync.
static constexpr uint32_t FLOAT_AS_VALUE = 0;
static constexpr uint32_t FLOAT16_AS_VALUE = 1;
static constexpr uint32_t BFLOAT16_AS_VALUE = 2;
static constexpr uint32_t INT_AS_VALUE = 3;
static constexpr uint32_t INT64_AS_VALUE = 4;


};

} // namespace random_uniform
Expand Down
6 changes: 6 additions & 0 deletions src/plugins/intel_cpu/src/nodes/random_uniform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ class RandomUniform : public Node {

void computeMersenneTwister(void* out, size_t work_amount);

// Integer codes for the supported output data types. The names and values
// mirror the constants declared in MersenneTwisterGenerator
// (nodes/kernels/x64/random_uniform.hpp).
// NOTE(review): presumably these are what gets written into
// MersenneTwisterGeneratorCallArgs::out_data_type for the JIT kernel — confirm
// against computeMersenneTwister, and consider sharing a single definition.
static constexpr uint32_t FLOAT_AS_VALUE = 0;
static constexpr uint32_t FLOAT16_AS_VALUE = 1;
static constexpr uint32_t BFLOAT16_AS_VALUE = 2;
static constexpr uint32_t INT_AS_VALUE = 3;
static constexpr uint32_t INT64_AS_VALUE = 4;

/////////////////////////////////////////////////////////////////////////////////

///// STL /////
Expand Down

0 comments on commit 71f4490

Please sign in to comment.