Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
minseongg committed Nov 1, 2024
1 parent 4681dc1 commit 7173763
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

use super::*;

/// Bit width of the register type.
const ACC_BITS: usize = 32;

/// PE row data signals.
#[derive(Debug, Clone, Copy)]
pub struct PeRowData {
Expand Down Expand Up @@ -51,7 +54,7 @@ pub struct PeControl {

/// The number of bits by which the accumulated result of matrix multiplication is right-shifted when leaving the
/// systolic array, used to scale down the result.
pub shift: U<5>,
pub shift: U<{ clog2(ACC_BITS) }>,
}

/// Represents the dataflow.
Expand Down Expand Up @@ -89,10 +92,10 @@ pub enum Propagate {
#[derive(Debug, Default, Clone, Copy)]
pub struct PeS {
/// Register 1.
pub reg1: S<32>,
pub reg1: S<ACC_BITS>,

/// Register 2.
pub reg2: S<32>,
pub reg2: S<ACC_BITS>,

/// The propagate value comes from the previous input.
///
Expand All @@ -102,7 +105,7 @@ pub struct PeS {

impl PeS {
/// Creates a new PE state.
pub fn new(reg1: S<32>, reg2: S<32>, propagate: Propagate) -> Self {
pub fn new(reg1: S<ACC_BITS>, reg2: S<ACC_BITS>, propagate: Propagate) -> Self {
Self { reg1, reg2, propagate }
}

Expand All @@ -113,7 +116,10 @@ impl PeS {
/// - `preload`: Bias value for the next operation.
/// - `partial_sum`: MAC result of the current operation.
/// - `propagate`: Propagate value.
pub fn new_os(preload: S<32>, partial_sum: S<32>, propagate: Propagate) -> Self {
pub fn new_os(preload: S<OUTPUT_BITS>, partial_sum: S<OUTPUT_BITS>, propagate: Propagate) -> Self {
let preload = preload.sext::<ACC_BITS>();
let partial_sum = partial_sum.sext::<ACC_BITS>();

match propagate {
Propagate::Reg1 => PeS::new(preload, partial_sum, propagate),
Propagate::Reg2 => PeS::new(partial_sum, preload, propagate),
Expand All @@ -127,7 +133,10 @@ impl PeS {
/// - `preload`: Weight value for the next operation.
/// - `weight`: Weight value for the current operation.
/// - `propagate`: Propagate value.
pub fn new_ws(preload: S<32>, weight: S<32>, propagate: Propagate) -> Self {
pub fn new_ws(preload: S<INPUT_BITS>, weight: S<INPUT_BITS>, propagate: Propagate) -> Self {
let preload = preload.sext::<ACC_BITS>();
let weight = weight.sext::<ACC_BITS>();

match propagate {
Propagate::Reg1 => PeS::new(preload, weight, propagate),
Propagate::Reg2 => PeS::new(weight, preload, propagate),
Expand All @@ -138,16 +147,16 @@ impl PeS {
/// MAC unit (computes `a * b + c`).
///
/// It preserves the signedness of operands.
fn mac(a: S<8>, b: S<8>, c: S<32>) -> S<OUTPUT_BITS> {
fn mac(a: S<INPUT_BITS>, b: S<INPUT_BITS>, c: S<ACC_BITS>) -> S<OUTPUT_BITS> {
super::arithmetic::mac(a, b, c)
}

/// Performs right-shift (`val >> shamt`) and then clips to `OUTPUT_BITS`.
///
/// It preserves the signedness of `val`.
fn shift_and_clip(val: S<32>, shamt: U<5>) -> S<OUTPUT_BITS> {
fn shift_and_clip(val: S<ACC_BITS>, shamt: U<{ clog2(ACC_BITS) }>) -> S<OUTPUT_BITS> {
let shifted = rounding_shift(val, shamt);
super::arithmetic::clip_with_saturation::<32, OUTPUT_BITS>(shifted)
super::arithmetic::clip_with_saturation::<ACC_BITS, OUTPUT_BITS>(shifted)
}

/// PE.
Expand Down

0 comments on commit 7173763

Please sign in to comment.