Skip to content

Commit

Permalink
simplify code in amx.cc to avoid "magic" divide by 4 everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
Stanislav Shwartsman committed Nov 23, 2024
1 parent 6097fcc commit df12ca4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
46 changes: 22 additions & 24 deletions bochs/cpu/avx/amx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,14 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILELOADD_TnnnMdq(bxInstruction_c *i)
check_tile(i, tile);

unsigned rows = BX_CPU_THIS_PTR amx->tile_num_rows(tile);
unsigned bytes_per_row = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile);
unsigned dword_elements_per_row = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile);

if (BX_CPU_THIS_PTR amx->start_row >= rows) {
BX_ERROR(("%s: invalid tile %d (start_row=%d) >= (rows=%d)", i->getIaOpcodeNameShort(), tile, BX_CPU_THIS_PTR amx->start_row, rows));
exception(BX_UD_EXCEPTION, 0);
}

unsigned elements_per_row = bytes_per_row / 4;
Bit32u mask = (elements_per_row < 16) ? ((1 << elements_per_row) - 1) : 0xFFFF;
Bit32u mask = (dword_elements_per_row < 16) ? ((1 << dword_elements_per_row) - 1) : 0xFFFF;

BX_CPU_THIS_PTR amx->set_tile_used(tile);

Expand All @@ -128,13 +127,13 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILELOADD_TnnnMdq(bxInstruction_c *i)
BxPackedAvxRegister *data = &(BX_CPU_THIS_PTR amx->tile[tile].row[row]);

Bit64u eaddr = start_eaddr + row * stride;
if (bytes_per_row == 64) {
if (dword_elements_per_row == 16) {
read_linear_zmmword(i->seg(), get_laddr64(i->seg(), eaddr), data);
}
else {
avx_masked_load32(i, eaddr, data, mask);

for (unsigned n=elements_per_row; n < 16; n++)
for (unsigned n=dword_elements_per_row; n < 16; n++)
data->vmm32u(n) = 0;
}

Expand All @@ -158,15 +157,14 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILESTORED_MdqTnnn(bxInstruction_c *i)
check_tile(i, tile);

unsigned rows = BX_CPU_THIS_PTR amx->tile_num_rows(tile);
unsigned bytes_per_row = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile);
unsigned dword_elements_per_row = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile);

if (BX_CPU_THIS_PTR amx->start_row >= rows) {
BX_ERROR(("TILESTORED: invalid tile %d (start_row=%d) >= (rows=%d)", tile, BX_CPU_THIS_PTR amx->start_row, rows));
exception(BX_UD_EXCEPTION, 0);
}

unsigned elements_per_row = bytes_per_row / 4;
Bit32u mask = (elements_per_row < 16) ? ((1 << elements_per_row) - 1) : 0xFFFF;
Bit32u mask = (dword_elements_per_row < 16) ? ((1 << dword_elements_per_row) - 1) : 0xFFFF;
i->setVL(BX_VL512);

Bit64u start_eaddr = BX_READ_64BIT_REG(i->sibBase()) + (Bit64s) i->displ32s();
Expand All @@ -175,7 +173,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TILESTORED_MdqTnnn(bxInstruction_c *i)
for (unsigned row=BX_CPU_THIS_PTR amx->start_row; row < rows; row++) {
BxPackedAvxRegister *data = &(BX_CPU_THIS_PTR amx->tile[tile].row[row]);
Bit64u eaddr = start_eaddr + row * stride;
if (bytes_per_row == 64)
if (dword_elements_per_row == 16)
write_linear_zmmword(i->seg(), get_laddr64(i->seg(), eaddr), data);
else
avx_masked_store32(i, eaddr, data, mask);
Expand Down Expand Up @@ -243,27 +241,27 @@ void BX_CPU_C::check_tiles(bxInstruction_c *i, unsigned tile_dst, unsigned tile_
check_tile(i, tile_src2);

unsigned rows[3];
unsigned bytes_per_row[3];
unsigned dword_elements_per_row[3];

rows[0] = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
bytes_per_row[0] = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst);
dword_elements_per_row[0] = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
rows[1] = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src1);
bytes_per_row[1] = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_src1);
dword_elements_per_row[1] = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_src1);
rows[2] = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);
bytes_per_row[2] = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_src2);
dword_elements_per_row[2] = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_src2);

// R C
// A = m x k (tsrc1)
// B = k x n (tsrc2)
// C = m x n (tsrcdest)
unsigned n = bytes_per_row[0] / 4;
unsigned n = dword_elements_per_row[0];
unsigned m = rows[1];
unsigned k = rows[2];

// #UD if srcdest.colbytes != src2.colbytes (n)
// #UD if srcdest.rows != src1.rows (m)
// #UD if src1.colbytes / 4 != src2.rows (k)
if (n != (bytes_per_row[2] / 4) || m != rows[0] || k != (bytes_per_row[1] / 4)) {
if (n != dword_elements_per_row[2] || m != rows[0] || k != dword_elements_per_row[1]) {
BX_ERROR(("%s: invalid matmul tile dimenstions", i->getIaOpcodeNameShort()));
exception(BX_UD_EXCEPTION, 0);
}
Expand Down Expand Up @@ -341,7 +339,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBSSD_TnnnTrmTreg(bxInstruction_c *i)
/* A = m x k (tsrc1) */
/* B = k x n (tsrc2) */
/* C = m x n (tsrcdest) */
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down Expand Up @@ -374,7 +372,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBSUD_TnnnTrmTreg(bxInstruction_c *i)
/* A = m x k (tsrc1) */
/* B = k x n (tsrc2) */
/* C = m x n (tsrcdest) */
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down Expand Up @@ -407,7 +405,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBUSD_TnnnTrmTreg(bxInstruction_c *i)
/* A = m x k (tsrc1) */
/* B = k x n (tsrc2) */
/* C = m x n (tsrcdest) */
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down Expand Up @@ -440,7 +438,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBUUD_TnnnTrmTreg(bxInstruction_c *i)
/* A = m x k (tsrc1) */
/* B = k x n (tsrc2) */
/* C = m x n (tsrcdest) */
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down Expand Up @@ -480,7 +478,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPBF16PS_TnnnTrmTreg(bxInstruction_c *i)
// A = m x k (tsrc1)
// B = k x n (tsrc2)
// C = m x n (tsrcdest)
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down Expand Up @@ -534,7 +532,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TDPFP16PS_TnnnTrmTreg(bxInstruction_c *i)
// A = m x k (tsrc1)
// B = k x n (tsrc2)
// C = m x n (tsrcdest)
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down Expand Up @@ -586,7 +584,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TCMMRLFP16PS_TnnnTrmTreg(bxInstruction_c *
// A = m x k (tsrc1)
// B = k x n (tsrc2)
// C = m x n (tsrcdest)
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down Expand Up @@ -638,7 +636,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TCMMIMFP16PS_TnnnTrmTreg(bxInstruction_c *
// A = m x k (tsrc1)
// B = k x n (tsrc2)
// C = m x n (tsrcdest)
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down Expand Up @@ -697,7 +695,7 @@ void BX_CPP_AttrRegparmN(1) BX_CPU_C::TMMULTF32PS_TnnnTrmTreg(bxInstruction_c *i
// A = m x k (tsrc1)
// B = k x n (tsrc2)
// C = m x n (tsrcdest)
unsigned max_n = BX_CPU_THIS_PTR amx->tile_bytes_per_row(tile_dst) / 4;
unsigned max_n = BX_CPU_THIS_PTR amx->tile_dword_elements_per_row(tile_dst);
unsigned max_m = BX_CPU_THIS_PTR amx->tile_num_rows(tile_dst);
unsigned max_k = BX_CPU_THIS_PTR amx->tile_num_rows(tile_src2);

Expand Down
1 change: 1 addition & 0 deletions bochs/cpu/avx/amx.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ struct AMX {
bool tile_valid(unsigned tile_num) const { return tilecfg[tile_num].rows != 0; }
unsigned tile_num_rows(unsigned tile_num) const { return tilecfg[tile_num].rows; }
unsigned tile_bytes_per_row(unsigned tile_num) const { return tilecfg[tile_num].bytes_per_row; }
unsigned tile_dword_elements_per_row(unsigned tile_num) const { return tilecfg[tile_num].bytes_per_row / 4; }

bool is_tile_used(unsigned tile_num) const { return tile_use_tracker & (1 << tile_num); }
void set_tile_used(unsigned tile_num) { tile_use_tracker |= (1 << tile_num); }
Expand Down

0 comments on commit df12ca4

Please sign in to comment.