diff --git a/bochs/cpu/avx/amx.cc b/bochs/cpu/avx/amx.cc index a1ad998afd..db62290485 100644 --- a/bochs/cpu/avx/amx.cc +++ b/bochs/cpu/avx/amx.cc @@ -106,15 +106,14 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILELOADD_TnnnMdq(bxInstruction_c *i) check_tile(i, tile); unsigned rows = BX_CPU_THIS_PTR amx->tile_num_rows(tile); - unsigned bytes_per_row = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile); + unsigned dword_elements_per_row = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile); if (BX_CPU_THIS_PTR amx->start_row >= rows) { BX_ERROR(("%s: invalid tile %d (start_row=%d) >= (rows=%d)", i->getIaOpcodeNameShort(), tile, BX_CPU_THIS_PTR amx->start_row, rows)); exception(BX_UD_EXCEPTION, 0); } - unsigned elements_per_row = bytes_per_row / 4; - Bit32u mask = (elements_per_row < 16) ? ((1 << elements_per_row) - 1) : 0xFFFF; + Bit32u mask = (dword_elements_per_row < 16) ? ((1 << dword_elements_per_row) - 1) : 0xFFFF; BX_CPU_THIS_PTR amx->set_tile_used(tile); @@ -128,13 +127,13 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILELOADD_TnnnMdq(bxInstruction_c *i) BxPackedAvxRegister *data = &(BX_CPU_THIS_PTR amx->tile[tile].row[row]); Bit64u eaddr = start_eaddr + row * stride; - if (bytes_per_row == 64) { + if (dword_elements_per_row == 16) { read_linear_zmmword(i->seg(), get_laddr64(i->seg(), eaddr), data); } else { avx_masked_load32(i, eaddr, data, mask); - for (unsigned n=elements_per_row; n < 16; n++) + for (unsigned n=dword_elements_per_row; n < 16; n++) data->vmm32u(n) = 0; } @@ -158,15 +157,14 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILESTORED_MdqTnnn(bxInstruction_c *i) check_tile(i, tile); unsigned rows = BX_CPU_THIS_PTR amx->tile_num_rows(tile); - unsigned bytes_per_row = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile); + unsigned dword_elements_per_row = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile); if (BX_CPU_THIS_PTR amx->start_row >= rows) { BX_ERROR(("TILESTORED: invalid tile %d (start_row=%d) >= (rows=%d)", tile, BX_CPU_THIS_PTR amx->start_row, rows)); exception(BX_UD_EXCEPTION, 0); } - unsigned elements_per_row = bytes_per_row / 4; - Bit32u mask = (elements_per_row < 16) ? ((1 << elements_per_row) - 1) : 0xFFFF; + Bit32u mask = (dword_elements_per_row < 16) ? ((1 << dword_elements_per_row) - 1) : 0xFFFF; i->setVL(BX_VL512); Bit64u start_eaddr = BX_READ_64BIT_REG(i->sibBase()) + (Bit64s) i->displ32s(); @@ -175,7 +173,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILESTORED_MdqTnnn(bxInstruction_c *i) for (unsigned row=BX_CPU_THIS_PTR amx->start_row; row < rows; row++) { BxPackedAvxRegister *data = &(BX_CPU_THIS_PTR amx->tile[tile].row[row]); Bit64u eaddr = start_eaddr + row * stride; - if (bytes_per_row == 64) + if (dword_elements_per_row == 16) write_linear_zmmword(i->seg(), get_laddr64(i->seg(), eaddr), data); else avx_masked_store32(i, eaddr, data, mask); @@ -243,27 +241,27 @@ void BX_CPU_C::check_tiles(bxInstruction_c *i, unsigned tile_dst, unsigned tile_ check_tile(i, tile_src2); unsigned rows[3]; - unsigned bytes_per_row[3]; + unsigned dword_elements_per_row[3]; rows[0] = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); - bytes_per_row[0] = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst); + dword_elements_per_row[0] = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); rows[1] = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src1); - bytes_per_row[1] = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_src1); + dword_elements_per_row[1] = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_src1); rows[2] = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); - bytes_per_row[2] = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_src2); + dword_elements_per_row[2] = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_src2); // R C // A = m x k (tsrc1) // B = k x n (tsrc2) // C = m x n (tsrcdest) - unsigned n = bytes_per_row[0] / 4; + unsigned n = dword_elements_per_row[0]; unsigned m = rows[1]; unsigned k = rows[2]; // #UD if srcdest.colbytes != src2.colbytes (n) // #UD if srcdest.rows != src1.rows (m) // #UD if src1.colbytes / 4 != src2.rows (k) - if (n != (bytes_per_row[2] / 4) || m != rows[0] || k != (bytes_per_row[1] / 4)) { + if (n != dword_elements_per_row[2] || m != rows[0] || k != dword_elements_per_row[1]) { BX_ERROR(("%s: invalid matmul tile dimenstions", i->getIaOpcodeNameShort())); exception(BX_UD_EXCEPTION, 0); } @@ -341,7 +339,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBSSD_TnnnTrmTreg(bxInstruction_c *i) /* A = m x k (tsrc1) */ /* B = k x n (tsrc2) */ /* C = m x n (tsrcdest) */ - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); @@ -374,7 +372,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBSUD_TnnnTrmTreg(bxInstruction_c *i) /* A = m x k (tsrc1) */ /* B = k x n (tsrc2) */ /* C = m x n (tsrcdest) */ - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); @@ -407,7 +405,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBUSD_TnnnTrmTreg(bxInstruction_c *i) /* A = m x k (tsrc1) */ /* B = k x n (tsrc2) */ /* C = m x n (tsrcdest) */ - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); @@ -440,7 +438,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBUUD_TnnnTrmTreg(bxInstruction_c *i) /* A = m x k (tsrc1) */ /* B = k x n (tsrc2) */ /* C = m x n (tsrcdest) */ - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); @@ -480,7 +478,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBF16PS_TnnnTrmTreg(bxInstruction_c *i) // A = m x k (tsrc1) // B = k x n (tsrc2) // C = m x n (tsrcdest) - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); @@ -534,7 +532,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPFP16PS_TnnnTrmTreg(bxInstruction_c *i) // A = m x k (tsrc1) // B = k x n (tsrc2) // C = m x n (tsrcdest) - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); @@ -586,7 +584,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TCMMRLFP16PS_TnnnTrmTreg(bxInstruction_c * // A = m x k (tsrc1) // B = k x n (tsrc2) // C = m x n (tsrcdest) - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); @@ -638,7 +636,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TCMMIMFP16PS_TnnnTrmTreg(bxInstruction_c * // A = m x k (tsrc1) // B = k x n (tsrc2) // C = m x n (tsrcdest) - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); @@ -697,7 +695,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TMMULTF32PS_TnnnTrmTreg(bxInstruction_c *i // A = m x k (tsrc1) // B = k x n (tsrc2) // C = m x n (tsrcdest) - unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4; + unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst); unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst); unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2); diff --git a/bochs/cpu/avx/amx.h b/bochs/cpu/avx/amx.h index 0994aa25db..607cbab212 100644 --- a/bochs/cpu/avx/amx.h +++ b/bochs/cpu/avx/amx.h @@ -51,6 +51,7 @@ struct AMX { bool tile_valid(unsigned tile_num) const { return tilecfg[tile_num].rows != 0; } unsigned tile_num_rows(unsigned tile_num) const { return tilecfg[tile_num].rows; } unsigned tile_bytes_per_row(unsigned tile_num) const { return tilecfg[tile_num].bytes_per_row; } + unsigned tile_dword_elements_per_row(unsigned tile_num) const { return tilecfg[tile_num].bytes_per_row / 4; } bool is_tile_used(unsigned tile_num) const { return tile_use_tracker & (1 << tile_num); } void set_tile_used(unsigned tile_num) { tile_use_tracker |= (1 << tile_num); }