diff --git a/src/main/scala/device/MemEncrypt.scala b/src/main/scala/device/MemEncrypt.scala new file mode 100644 index 00000000000..550ea1cc8e7 --- /dev/null +++ b/src/main/scala/device/MemEncrypt.scala @@ -0,0 +1,1196 @@ +/*************************************************************************************** +* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences +* +* XiangShan is licensed under Mulan PSL v2. +* You can use this software according to the terms and conditions of the Mulan PSL v2. +* You may obtain a copy of Mulan PSL v2 at: +* http://license.coscl.org.cn/MulanPSL2 +* +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +* +* See the Mulan PSL v2 for more details. +***************************************************************************************/ + +package device + +import chisel3._ +import chisel3.util._ +import chisel3.util.HasBlackBoxResource +import org.chipsalliance.cde.config.Field +import org.chipsalliance.cde.config.Parameters +import freechips.rocketchip.amba.axi4._ +import freechips.rocketchip.diplomacy._ +import freechips.rocketchip.util._ +import freechips.rocketchip.amba.apb._ +import freechips.rocketchip.tilelink.AXI4TLState +import javax.xml.crypto.dsig.keyinfo.KeyInfo +import system._ + +case object MemcEdgeInKey extends Field[AXI4EdgeParameters] +case object MemcEdgeOutKey extends Field[AXI4EdgeParameters] + +trait Memconsts { + val p: Parameters + val cvm = p(CVMParamskey) + val soc = p(SoCParamsKey) + val PAddrBits= soc.PAddrBits + val KeyIDBits= cvm.KeyIDBits + val MemencPipes = cvm.MemencPipes + lazy val MemcedgeIn = p(MemcEdgeInKey) + lazy val MemcedgeOut = p(MemcEdgeOutKey) + require (isPow2(MemencPipes), s"AXI4MemEncrypt: MemencPipes must be a power of two, not $MemencPipes") + require (PAddrBits > KeyIDBits, s"AXI4MemEncrypt: PAddrBits must be greater than KeyIDBits") + def HasDelayNoencryption = cvm.HasDelayNoencryption +} + + +abstract class MemEncryptModule(implicit val p: Parameters) extends Module with Memconsts + +class TweakEncrptyQueue(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val enq = Flipped(DecoupledIO(new Bundle { + val addr = UInt(PAddrBits.W) + val len = UInt(MemcedgeIn.bundle.lenBits.W) // number of beats - 1 + })) + val deq = DecoupledIO(new Bundle { + val keyid = UInt(KeyIDBits.W) + val tweak = UInt(MemcedgeIn.bundle.dataBits.W) + val addr = UInt(MemcedgeIn.bundle.addrBits.W) + }) + val tweak_round_keys = Input(Vec(32, UInt(32.W))) + }) + val tweak_in = Cat(0.U((128 - PAddrBits).W), Cat(io.enq.bits.addr(PAddrBits - 1, 6), 0.U(6.W))) + + val tweak_enc_module = Module(new TweakEncrypt(opt = true)) + val tweakgf128_module = Module(new TweakGF128()) + + tweak_enc_module.io.tweak_enc_req.valid := io.enq.valid + tweak_enc_module.io.tweak_enc_resp.ready := tweakgf128_module.io.req.ready + tweak_enc_module.io.tweak_enc_req.bits.tweak := tweak_in + tweak_enc_module.io.tweak_enc_req.bits.addr_in := io.enq.bits.addr + tweak_enc_module.io.tweak_enc_req.bits.len_in := io.enq.bits.len + tweak_enc_module.io.tweak_enc_req.bits.id_in := 0.U + tweak_enc_module.io.tweak_enc_req.bits.tweak_round_keys := io.tweak_round_keys + + io.enq.ready := tweak_enc_module.io.tweak_enc_req.ready + + tweakgf128_module.io.req.bits.len := tweak_enc_module.io.tweak_enc_resp.bits.len_out + tweakgf128_module.io.req.bits.addr := 
tweak_enc_module.io.tweak_enc_resp.bits.addr_out + tweakgf128_module.io.req.bits.tweak_in := tweak_enc_module.io.tweak_enc_resp.bits.tweak_encrpty + tweakgf128_module.io.req.valid := tweak_enc_module.io.tweak_enc_resp.valid + tweakgf128_module.io.resp.ready := io.deq.ready + + io.deq.bits.keyid := tweakgf128_module.io.resp.bits.keyid_out + io.deq.bits.tweak := tweakgf128_module.io.resp.bits.tweak_out + io.deq.bits.addr := tweakgf128_module.io.resp.bits.addr_out + io.deq.valid := tweakgf128_module.io.resp.valid +} + +class AXI4W_KT(opt:Boolean)(implicit val p: Parameters) extends Bundle with Memconsts +{ + val edgeUse = if (opt) MemcedgeIn else MemcedgeOut + val axi4 = new AXI4BundleW(edgeUse.bundle) + val keyid = UInt(KeyIDBits.W) + val tweak = UInt(edgeUse.bundle.dataBits.W) +} + +// Used to indicate the source of the req (L1I/L1D/PTW) +case object ReqSourceKey extends ControlKey[UInt]("reqSource") + +class AXI4WriteMachine(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val in_w = Flipped(Irrevocable(new AXI4BundleW(MemcedgeOut.bundle))) + val in_kt = Flipped(DecoupledIO(new Bundle { + val keyid = UInt(KeyIDBits.W) + val tweak = UInt(MemcedgeOut.bundle.dataBits.W) + val addr = UInt(MemcedgeOut.bundle.addrBits.W) + })) + val out_ar = Irrevocable(new AXI4BundleAR(MemcedgeOut.bundle)) + val in_r = Flipped(Irrevocable(new AXI4BundleR(MemcedgeOut.bundle))) + val out_w = DecoupledIO(new AXI4W_KT(true)) + val uncache_en = Output(Bool()) + val uncache_commit = Input(Bool()) + }) + // ---------------- + // s0 stage + // ---------------- + val w_cacheable = io.in_w.bits.strb.andR + + // ---------------- + // s1 stage + // ---------------- + val in_w_v = RegInit(false.B) + val in_kt_v = RegInit(false.B) + + val in_w_req = RegEnable(io.in_w.bits, io.in_w.fire) + val in_kt_req = RegEnable(io.in_kt.bits, io.in_kt.fire) + io.in_w.ready := !in_w_v || io.out_w.fire + io.in_kt.ready := !in_kt_v || io.out_w.fire + + when(io.in_w.fire) { + in_w_v := true.B + }.elsewhen(io.out_w.fire) { + in_w_v := false.B + }.otherwise { + in_w_v := in_w_v + } + + when(io.in_kt.fire) { + in_kt_v := true.B + }.elsewhen(io.out_w.fire) { + in_kt_v := false.B + }.otherwise { + in_kt_v := in_kt_v + } + + // ----------------------------- + // s2 stage only uncacheable use + // ----------------------------- + val out_ar_v = RegInit(false.B) + val out_ar_mask = RegInit(false.B) + val in_r_v = RegInit(false.B) + val r_uncache_en = RegInit(false.B) + when(io.in_r.fire) { + in_r_v := true.B + }.elsewhen(io.out_w.fire) { + in_r_v := false.B + }.otherwise { + in_r_v := in_r_v + } + + when(io.in_r.fire) { + r_uncache_en := true.B + }.elsewhen(io.uncache_commit) { + r_uncache_en := false.B + }.otherwise { + r_uncache_en := r_uncache_en + } + + io.in_r.ready := !r_uncache_en || io.uncache_commit + io.uncache_en := r_uncache_en + + val s1_w_cacheable = RegEnable(w_cacheable, io.in_w.fire) + + when(in_w_v && in_kt_v && !s1_w_cacheable && !out_ar_mask) { + out_ar_v := true.B + }.elsewhen(io.out_ar.fire) { + out_ar_v := false.B + }.otherwise { + out_ar_v := out_ar_v + } + + when(in_w_v && in_kt_v && !s1_w_cacheable && !out_ar_mask) { + out_ar_mask := true.B + }.elsewhen(io.out_w.fire) { + out_ar_mask := false.B + }.otherwise { + out_ar_mask := out_ar_mask + } + + io.out_ar.valid := out_ar_v + val ar = io.out_ar.bits + ar.id := 1.U << (MemcedgeOut.bundle.idBits - 1) + ar.addr := (in_kt_req.addr >> log2Ceil(MemcedgeOut.bundle.dataBits/8)) << log2Ceil(MemcedgeOut.bundle.dataBits/8) + ar.len := 0.U + ar.size := 
log2Ceil(MemcedgeOut.bundle.dataBits/8).U + ar.burst := AXI4Parameters.BURST_INCR + ar.lock := 0.U // not exclusive (LR/SC unsupported b/c no forward progress guarantee) + ar.cache := 0.U // do not allow AXI to modify our transactions + ar.prot := AXI4Parameters.PROT_PRIVILEGED + ar.qos := 0.U // no QoS + if (MemcedgeOut.bundle.echoFields != Nil) { + val ar_extra = ar.echo(AXI4TLState) + ar_extra.source := 0.U + ar_extra.size := 0.U + } + if (MemcedgeOut.bundle.requestFields != Nil) { + val ar_user = ar.user(ReqSourceKey) + ar_user := 0.U + } + + def gen_wmask(strb: UInt): UInt = { + val extendedBits = VecInit((0 until MemcedgeOut.bundle.dataBits/8).map(i => Cat(Fill(7, strb((MemcedgeOut.bundle.dataBits/8)-1-i)), strb((MemcedgeOut.bundle.dataBits/8)-1-i)))) + extendedBits.reduce(_ ## _) + } + + val new_data = Reg(UInt(MemcedgeOut.bundle.dataBits.W)) + val new_strb = ~0.U((MemcedgeOut.bundle.dataBits/8).W) + val wmask = gen_wmask(in_w_req.strb) + + when(io.in_r.fire) { + new_data := (io.in_r.bits.data & ~wmask) | (in_w_req.data & wmask) + } + + when(s1_w_cacheable) { + io.out_w.valid := in_w_v && in_kt_v + io.out_w.bits.axi4 := in_w_req + io.out_w.bits.keyid := in_kt_req.keyid + io.out_w.bits.tweak := in_kt_req.tweak + }.otherwise { + io.out_w.valid := in_w_v && in_kt_v && in_r_v + io.out_w.bits.axi4 := in_w_req + io.out_w.bits.axi4.data := new_data + io.out_w.bits.axi4.strb := new_strb + io.out_w.bits.keyid := in_kt_req.keyid + io.out_w.bits.tweak := in_kt_req.tweak + } + +} + +class WdataEncrptyPipe(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val in_w = Flipped(DecoupledIO(new AXI4W_KT(true))) + val out_w = Irrevocable(new AXI4BundleW(MemcedgeIn.bundle)) + val enc_keyids = Output(Vec(MemencPipes, UInt(KeyIDBits.W))) + val enc_round_keys = Input(Vec(MemencPipes, UInt((32*32/MemencPipes).W))) + }) + val reg_encdec_result_0 = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_encdec_result_1 = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_axi4_other_result = Reg(Vec(MemencPipes, new AXI4BundleWWithoutData(MemcedgeIn.bundle))) + val reg_tweak_result_0 = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_tweak_result_1 = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_keyid = Reg(Vec(MemencPipes, UInt(KeyIDBits.W))) + val reg_encdec_valid = RegInit(VecInit(Seq.fill(MemencPipes)(false.B))) + val wire_ready_result = WireInit(VecInit(Seq.fill(MemencPipes)(false.B))) + + val wire_axi4_other = Wire(new AXI4BundleWWithoutData(MemcedgeIn.bundle)) + wire_axi4_other.strb := io.in_w.bits.axi4.strb + wire_axi4_other.last := io.in_w.bits.axi4.last + wire_axi4_other.user := io.in_w.bits.axi4.user + + + val pipes_first_data_0 = Wire(UInt(128.W)) + val pipes_first_data_1 = Wire(UInt(128.W)) + if (HasDelayNoencryption) { + pipes_first_data_0 := io.in_w.bits.axi4.data(127,0) + pipes_first_data_1 := io.in_w.bits.axi4.data(255,128) + } else { + pipes_first_data_0 := io.in_w.bits.axi4.data(127,0) ^ io.in_w.bits.tweak(127, 0) + pipes_first_data_1 := io.in_w.bits.axi4.data(255,128) ^ io.in_w.bits.tweak(255,128) + } + + def configureModule(flag: Boolean, i: Int, keyId: UInt, dataIn: UInt, tweakIn: UInt, axi4In: AXI4BundleWWithoutData, roundKeys: UInt): OnePipeEncBase = { + when(wire_ready_result(i) && (if (i == 0) io.in_w.valid else reg_encdec_valid(i-1))) { + reg_encdec_valid(i) := true.B + }.elsewhen(reg_encdec_valid(i) && (if (i == MemencPipes - 1) io.out_w.ready else wire_ready_result(i + 1))) { + reg_encdec_valid(i) := false.B + }.otherwise { + reg_encdec_valid(i) := reg_encdec_valid(i) + } 
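+    // Elastic-pipeline handshake: stage i accepts a beat when it is empty or is
+    // draining this cycle; the final stage drains into io.out_w, every earlier
+    // stage into the following stage's ready (wire_ready_result(i+1) below).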
+ + wire_ready_result(i) := !reg_encdec_valid(i) || (reg_encdec_valid(i) && (if (i == MemencPipes - 1) io.out_w.ready else wire_ready_result(i+1))) + + val module: OnePipeEncBase = if (HasDelayNoencryption) Module(new OnePipeForEncNoEnc()) else Module(new OnePipeForEnc()) + module.io.onepipe_in.keyid := keyId + module.io.onepipe_in.data_in := dataIn + module.io.onepipe_in.tweak_in := tweakIn + module.io.onepipe_in.axi4_other := axi4In + for (i <- 0 until 32/MemencPipes) { + module.io.onepipe_in.round_key_in(i) := roundKeys(i * 32 + 31, i * 32) + } + when((if (i == 0) io.in_w.valid else reg_encdec_valid(i-1)) && wire_ready_result(i)) { + if (flag) { + reg_encdec_result_0(i) := module.io.onepipe_out.result_out + reg_tweak_result_0(i) := module.io.onepipe_out.tweak_out + reg_axi4_other_result(i) := module.io.onepipe_out.axi4_other_out + reg_keyid(i) := module.io.onepipe_out.keyid_out + } else { + reg_encdec_result_1(i) := module.io.onepipe_out.result_out + reg_tweak_result_1(i) := module.io.onepipe_out.tweak_out + } + } + io.enc_keyids(i) := module.io.onepipe_out.keyid_out + module + } + val modules_0 = (0 until MemencPipes).map { i => + if (i == 0) { + configureModule(true, i, io.in_w.bits.keyid, pipes_first_data_0, io.in_w.bits.tweak(127, 0), wire_axi4_other, io.enc_round_keys(i)) + } else { + configureModule(true, i, reg_keyid(i-1), reg_encdec_result_0(i-1), reg_tweak_result_0(i-1), reg_axi4_other_result(i-1), io.enc_round_keys(i)) + } + } + val modules_1 = (0 until MemencPipes).map { i => + if (i == 0) { + configureModule(false, i, io.in_w.bits.keyid, pipes_first_data_1, io.in_w.bits.tweak(255,128), wire_axi4_other, io.enc_round_keys(i)) + } else { + configureModule(false, i, reg_keyid(i-1), reg_encdec_result_1(i-1), reg_tweak_result_1(i-1), reg_axi4_other_result(i-1), io.enc_round_keys(i)) + } + } + if (HasDelayNoencryption) { + io.out_w.bits.data := Cat(reg_encdec_result_1.last, reg_encdec_result_0.last) + } else { + val enc_0_out = Cat( + reg_encdec_result_0.last(31, 0), + reg_encdec_result_0.last(63, 32), + reg_encdec_result_0.last(95, 64), + reg_encdec_result_0.last(127, 96) + ) + val enc_1_out = Cat( + reg_encdec_result_1.last(31, 0), + reg_encdec_result_1.last(63, 32), + reg_encdec_result_1.last(95, 64), + reg_encdec_result_1.last(127, 96) + ) + io.out_w.bits.data := Cat(enc_1_out ^ reg_tweak_result_1.last, enc_0_out ^ reg_tweak_result_0.last) + } + + io.out_w.bits.strb := reg_axi4_other_result.last.strb + io.out_w.bits.last := reg_axi4_other_result.last.last + io.out_w.bits.user := reg_axi4_other_result.last.user + io.out_w.valid := reg_encdec_valid.last + io.in_w.ready := wire_ready_result(0) +} + +class TweakEncrptyTable(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val enq = Flipped(DecoupledIO(new Bundle { + val addr = UInt(PAddrBits.W) + val len = UInt(MemcedgeOut.bundle.lenBits.W) // number of beats - 1 + val id = UInt(MemcedgeOut.bundle.idBits.W) // 7 bits + })) + val req = Flipped(DecoupledIO(new Bundle { + val id = UInt(MemcedgeOut.bundle.idBits.W) + })) + val resp = DecoupledIO(new Bundle { + val keyid = UInt(KeyIDBits.W) + val tweak = UInt(MemcedgeOut.bundle.dataBits.W) + }) + val dec_r = new Bundle { + val id = Input(UInt(MemcedgeOut.bundle.idBits.W)) + val mode = Output(Bool()) + } + val dec_keyid = Output(UInt(KeyIDBits.W)) + val dec_mode = Input(Bool()) + val tweak_round_keys = Input(Vec(32, UInt(32.W))) + val memenc_enable = Input(Bool()) + }) + + val tweak_in = Cat(0.U((128 - PAddrBits).W), Cat(io.enq.bits.addr(PAddrBits-1, 6), 
0.U(6.W))) + // query the dec_mode from the round key + io.dec_keyid := io.enq.bits.addr(PAddrBits - 1, PAddrBits - KeyIDBits) + + val tweak_enc_module = Module(new TweakEncrypt(opt = false)) + val tweak_table = Module(new TweakTable()) + val tweak_gf128 = Module(new GF128()) + + // updata mode table + tweak_table.io.w_mode.bits.id := io.enq.bits.id + tweak_table.io.w_mode.bits.dec_mode := io.dec_mode && io.memenc_enable + tweak_table.io.w_mode.valid := io.enq.fire + + tweak_enc_module.io.tweak_enc_resp.ready := tweak_table.io.write.ready // always true + tweak_enc_module.io.tweak_enc_req.bits.tweak := tweak_in + tweak_enc_module.io.tweak_enc_req.bits.addr_in := io.enq.bits.addr + tweak_enc_module.io.tweak_enc_req.bits.len_in := io.enq.bits.len + tweak_enc_module.io.tweak_enc_req.bits.id_in := io.enq.bits.id + tweak_enc_module.io.tweak_enc_req.bits.tweak_round_keys := io.tweak_round_keys + tweak_enc_module.io.tweak_enc_req.valid := io.enq.valid && io.dec_mode && io.memenc_enable + + io.enq.ready := tweak_enc_module.io.tweak_enc_req.ready + + // write signal in tweak table + tweak_table.io.write.valid := tweak_enc_module.io.tweak_enc_resp.valid + tweak_table.io.write.bits.id := tweak_enc_module.io.tweak_enc_resp.bits.id_out + tweak_table.io.write.bits.addr := tweak_enc_module.io.tweak_enc_resp.bits.addr_out + tweak_table.io.write.bits.len := tweak_enc_module.io.tweak_enc_resp.bits.len_out + tweak_table.io.write.bits.tweak_encrpty := tweak_enc_module.io.tweak_enc_resp.bits.tweak_encrpty + + // read signal in tweak table + tweak_table.io.req.valid := io.req.valid + tweak_table.io.resp.ready := io.resp.ready + + tweak_table.io.req.bits.read_id := io.req.bits.id + + val tweak_encrpty = tweak_table.io.resp.bits.read_tweak + val tweak_counter = tweak_table.io.resp.bits.read_sel_counter + val keyid = tweak_table.io.resp.bits.read_keyid + + tweak_table.io.r_mode.id := io.dec_r.id + val mode = tweak_table.io.r_mode.dec_mode + io.dec_r.mode := mode + + tweak_gf128.io.tweak_in := tweak_encrpty + io.resp.bits.tweak := Mux(tweak_counter, tweak_gf128.io.tweak_out(511, 256), tweak_gf128.io.tweak_out(255, 0)) + io.resp.bits.keyid := keyid + io.resp.valid := tweak_table.io.resp.valid + io.req.ready := tweak_table.io.req.ready + +} + +class AXI4R_KT(opt:Boolean)(implicit val p: Parameters) extends Bundle with Memconsts +{ + val edgeUse = if (opt) MemcedgeIn else MemcedgeOut + val axi4 = new AXI4BundleR(edgeUse.bundle) + val keyid = UInt(KeyIDBits.W) + val tweak = UInt(edgeUse.bundle.dataBits.W) +} + +class AXI4ReadMachine(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val in_r = Flipped(Irrevocable(new AXI4BundleR(MemcedgeOut.bundle))) + val kt_req = DecoupledIO(new Bundle { + val id = UInt(MemcedgeOut.bundle.idBits.W) + }) + val in_kt = Flipped(DecoupledIO(new Bundle { + val keyid = UInt(KeyIDBits.W) + val tweak = UInt(MemcedgeOut.bundle.dataBits.W) + })) + val out_r = DecoupledIO(new AXI4R_KT(false)) + }) + val s1_r_val = RegInit(false.B) + val s1_r_req = RegEnable(io.in_r.bits, io.in_r.fire) + val s1_r_out_rdy = Wire(Bool()) + + val s2_r_val = RegInit(false.B) + val s2_r_in_rdy = Wire(Bool()) + val s2_r_req = RegEnable(s1_r_req, s1_r_val && s2_r_in_rdy) + + // ---------------- + // s0 stage + // ---------------- + io.in_r.ready := !s1_r_val || (s1_r_val && s1_r_out_rdy) + + // ---------------- + // s1 stage + // ---------------- + when(io.in_r.fire) { + s1_r_val := true.B + }.elsewhen(s1_r_val && s1_r_out_rdy) { + s1_r_val := false.B + }.otherwise { + s1_r_val := s1_r_val 
+ } + + s1_r_out_rdy := s2_r_in_rdy && io.kt_req.ready + io.kt_req.valid := s1_r_val && s2_r_in_rdy + io.kt_req.bits.id := s1_r_req.id + + // ---------------- + // s2 stage + // ---------------- + when(s1_r_val && s1_r_out_rdy) { + s2_r_val := true.B + }.elsewhen(s2_r_val && io.out_r.fire) { + s2_r_val := false.B + }.otherwise { + s2_r_val := s2_r_val + } + s2_r_in_rdy := !s2_r_val || io.out_r.fire + + io.in_kt.ready := io.out_r.fire + + io.out_r.valid := s2_r_val && io.in_kt.valid + io.out_r.bits.axi4 := s2_r_req + io.out_r.bits.keyid := io.in_kt.bits.keyid + io.out_r.bits.tweak := io.in_kt.bits.tweak +} + +class RdataDecrptyPipe(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val in_r = Flipped(DecoupledIO(new AXI4R_KT(false))) + val out_r = Irrevocable(new AXI4BundleR(MemcedgeOut.bundle)) + val dec_keyids = Output(Vec(MemencPipes, UInt(KeyIDBits.W))) + val dec_round_keys = Input(Vec(MemencPipes, UInt((32*32/MemencPipes).W))) + }) + + val reg_encdec_result_0 = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_encdec_result_1 = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_axi4_other_result = Reg(Vec(MemencPipes, new AXI4BundleRWithoutData(MemcedgeOut.bundle))) + val reg_tweak_result_0 = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_tweak_result_1 = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_keyid = Reg(Vec(MemencPipes, UInt(KeyIDBits.W))) + val reg_encdec_valid = RegInit(VecInit(Seq.fill(MemencPipes)(false.B))) + val wire_ready_result = WireInit(VecInit(Seq.fill(MemencPipes)(false.B))) + + + val wire_axi4_other = Wire(new AXI4BundleRWithoutData(MemcedgeOut.bundle)) + wire_axi4_other.id := io.in_r.bits.axi4.id + wire_axi4_other.resp := io.in_r.bits.axi4.resp + wire_axi4_other.user := io.in_r.bits.axi4.user + wire_axi4_other.echo := io.in_r.bits.axi4.echo + wire_axi4_other.last := io.in_r.bits.axi4.last + + val pipes_first_data_0 = Wire(UInt(128.W)) + val pipes_first_data_1 = Wire(UInt(128.W)) + + if (HasDelayNoencryption) { + pipes_first_data_0 := io.in_r.bits.axi4.data(127,0) + pipes_first_data_1 := io.in_r.bits.axi4.data(255,128) + } else { + pipes_first_data_0 := io.in_r.bits.axi4.data(127,0) ^ io.in_r.bits.tweak(127, 0) + pipes_first_data_1 := io.in_r.bits.axi4.data(255,128) ^ io.in_r.bits.tweak(255,128) + } + def configureModule(flag: Boolean, i: Int, keyId: UInt, dataIn: UInt, tweakIn: UInt, axi4In: AXI4BundleRWithoutData, roundKeys: UInt): OnePipeDecBase = { + + when(wire_ready_result(i) && (if (i == 0) io.in_r.valid else reg_encdec_valid(i-1))) { + reg_encdec_valid(i) := true.B + }.elsewhen(reg_encdec_valid(i) && (if (i == MemencPipes - 1) io.out_r.ready else wire_ready_result(i+1))) { + reg_encdec_valid(i) := false.B + }.otherwise { + reg_encdec_valid(i) := reg_encdec_valid(i) + } + + wire_ready_result(i) := !reg_encdec_valid(i) || (reg_encdec_valid(i) && (if (i == MemencPipes - 1) io.out_r.ready else wire_ready_result(i+1))) + + val module: OnePipeDecBase = if (HasDelayNoencryption) Module(new OnePipeForDecNoDec()) else Module(new OnePipeForDec()) + module.io.onepipe_in.keyid := keyId + module.io.onepipe_in.data_in := dataIn + module.io.onepipe_in.tweak_in := tweakIn + module.io.onepipe_in.axi4_other := axi4In + for (i <- 0 until 32/MemencPipes) { + module.io.onepipe_in.round_key_in(i) := roundKeys(i * 32 + 31, i * 32) + } + when((if (i == 0) io.in_r.valid else reg_encdec_valid(i-1)) && wire_ready_result(i)) { + if (flag) { + reg_encdec_result_0(i) := module.io.onepipe_out.result_out + reg_tweak_result_0(i) := module.io.onepipe_out.tweak_out + 
reg_axi4_other_result(i) := module.io.onepipe_out.axi4_other_out + reg_keyid(i) := module.io.onepipe_out.keyid_out + } else { + reg_encdec_result_1(i) := module.io.onepipe_out.result_out + reg_tweak_result_1(i) := module.io.onepipe_out.tweak_out + } + } + io.dec_keyids(i) := module.io.onepipe_out.keyid_out + module + } + val modules_0 = (0 until MemencPipes).map { i => + if (i == 0) { + configureModule(true, i, io.in_r.bits.keyid, pipes_first_data_0, io.in_r.bits.tweak(127, 0), wire_axi4_other, io.dec_round_keys(i)) + } else { + configureModule(true, i, reg_keyid(i-1), reg_encdec_result_0(i-1), reg_tweak_result_0(i-1), reg_axi4_other_result(i-1), io.dec_round_keys(i)) + } + } + + val modules_1 = (0 until MemencPipes).map { i => + if (i == 0) { + configureModule(false, i, io.in_r.bits.keyid, pipes_first_data_1, io.in_r.bits.tweak(255,128), wire_axi4_other, io.dec_round_keys(i)) + } else { + configureModule(false, i, reg_keyid(i-1),reg_encdec_result_1(i-1), reg_tweak_result_1(i-1), reg_axi4_other_result(i-1), io.dec_round_keys(i)) + } + } + if (HasDelayNoencryption) { + io.out_r.bits.data := Cat(reg_encdec_result_1.last, reg_encdec_result_0.last) + } else { + val enc_0_out = Cat( + reg_encdec_result_0.last(31, 0), + reg_encdec_result_0.last(63, 32), + reg_encdec_result_0.last(95, 64), + reg_encdec_result_0.last(127, 96) + ) + val enc_1_out = Cat( + reg_encdec_result_1.last(31, 0), + reg_encdec_result_1.last(63, 32), + reg_encdec_result_1.last(95, 64), + reg_encdec_result_1.last(127, 96) + ) + io.out_r.bits.data := Cat(enc_1_out ^ reg_tweak_result_1.last, enc_0_out ^ reg_tweak_result_0.last) + } + + io.out_r.bits.id := reg_axi4_other_result.last.id + io.out_r.bits.resp := reg_axi4_other_result.last.resp + io.out_r.bits.user := reg_axi4_other_result.last.user + io.out_r.bits.echo := reg_axi4_other_result.last.echo + io.out_r.bits.last := reg_axi4_other_result.last.last + io.out_r.valid := reg_encdec_valid.last + io.in_r.ready := wire_ready_result(0) + +} + +class RdataRoute(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val in_r = Flipped(Irrevocable(new AXI4BundleR(MemcedgeOut.bundle))) + val out_r0 = Irrevocable(new AXI4BundleR(MemcedgeIn.bundle)) + val out_r1 = Irrevocable(new AXI4BundleR(MemcedgeIn.bundle)) + }) + + val r_sel = io.in_r.bits.id(MemcedgeOut.bundle.idBits - 1).asBool + + io.out_r0.bits <> io.in_r.bits + io.out_r1.bits <> io.in_r.bits + + io.out_r0.valid := io.in_r.valid && !r_sel + io.out_r1.valid := io.in_r.valid && r_sel + io.in_r.ready := Mux(r_sel, io.out_r1.ready, io.out_r0.ready) +} + +class MemEncryptCSR(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val en = Input(Bool()) + val wmode = Input(Bool()) + val addr = Input(UInt(12.W)) + val wdata = Input(UInt(64.W)) + val wmask = Input(UInt(8.W)) + val rdata = Output(UInt(64.W)) // get rdata next cycle after en + val memenc_enable = Output(Bool()) + val keyextend_req = DecoupledIO(new Bundle { + val key = UInt(128.W) + val keyid = UInt(KeyIDBits.W) + val enc_mode = Bool() // 1:this keyid open enc 0:this keyid close enc + val tweak_flage = Bool() // 1:extend tweak key 0:extend keyid key + }) + val randomio = new Bundle { + val random_req = Output(Bool()) + val random_val = Input(Bool()) + val random_data = Input(Bool()) + } + }) + // CSR + val key_id = RegInit(0.U(5.W)) // [4:0] + val mode = RegInit(0.U(2.W)) // [6:5] + val tweak_flage = RegInit(0.U(1.W)) // [7] + val memenc_enable = if (HasDelayNoencryption) RegInit(true.B) else RegInit(false.B) // [8] + 
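+  // Programmer-visible layout (derived from the read/write decode below; byte
+  // offsets assume the 8-byte APB data width):
+  //   0x00 CTRL/STATUS: [4:0] key_id, [6:5] mode, [7] tweak_flage, [8] memenc_enable,
+  //        [32] random_ready, [33] key_expansion_idle, [34] last_req_accepted,
+  //        [35] cfg_succeed, [63] key_init_req (write 1 to launch key expansion)
+  //   0x08 KEY0 (low 64 key bits), 0x10 KEY1 (high 64 key bits)
+  //   0x18 RelPaddrBitsMap, 0x20 KeyIDBitsMap, 0x28 Version (read-only)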
val memenc_enable_lock = RegInit(false.B)
+  val random_ready_flag = Wire(Bool())     // [32]
+  val key_expansion_idle = Wire(Bool())    // [33]
+  val last_req_accepted = RegInit(false.B) // [34]
+  val cfg_succeed = Wire(Bool())           // [35]
+  val key_init_req = RegInit(false.B)      // [63]
+  // KEY0&1
+  val key0 = RegInit(0.U(64.W))
+  val key1 = RegInit(0.U(64.W))
+  // RelPaddrBitsMap
+  val relpaddrbitsmap = ~0.U((PAddrBits - KeyIDBits).W)
+  // KeyIDBitsMap
+  val keyidbitsmap = ~0.U(PAddrBits.W) - ~0.U((PAddrBits - KeyIDBits).W)
+  // Version
+  val memenc_version_p0 = (0x0001).U(16.W)
+  val memenc_version_p1 = (0x0001).U(16.W)
+  val memenc_version_p2 = (0x00000002).U(32.W)
+  val memenc_version = Cat(memenc_version_p0, memenc_version_p1, memenc_version_p2)
+
+  // READ
+  val rdata_reg = RegInit(0.U(64.W))
+  when(io.en && !io.wmode && (io.addr(11,3) === 0.U)) {
+    rdata_reg := Cat(0.U(28.W), cfg_succeed, last_req_accepted, key_expansion_idle, random_ready_flag, 0.U(23.W), memenc_enable, tweak_flage, mode, key_id)
+  }.elsewhen(io.en && !io.wmode && (io.addr(11,3) === 3.U)) {
+    rdata_reg := relpaddrbitsmap
+  }.elsewhen(io.en && !io.wmode && (io.addr(11,3) === 4.U)) {
+    rdata_reg := keyidbitsmap
+  }.elsewhen(io.en && !io.wmode && (io.addr(11,3) === 5.U)) {
+    rdata_reg := memenc_version
+  }.otherwise {
+    rdata_reg := 0.U
+  }
+
+  io.rdata := rdata_reg
+
+  // WRITE
+  val wmask_legal = (io.wmask === (0xff).U)
+
+  when(io.en && io.wmode && wmask_legal && (io.addr(11,3) === 0.U)) {
+    key_id := io.wdata(4,0)
+    mode := io.wdata(6,5)
+    tweak_flage := io.wdata(7)
+    key_init_req := io.wdata(63).asBool
+  }.otherwise {
+    key_init_req := false.B
+  }
+  when(io.en && io.wmode && wmask_legal && (io.addr(11,3) === 0.U) && (!memenc_enable_lock)) {
+    memenc_enable := io.wdata(8)
+    memenc_enable_lock := true.B
+  }
+  when(io.en && io.wmode && wmask_legal && (io.addr(11,3) === 1.U)) {
+    key0 := io.wdata
+  }
+  when(io.en && io.wmode && wmask_legal && (io.addr(11,3) === 2.U)) {
+    key1 := io.wdata
+  }
+  io.memenc_enable := memenc_enable
+
+  // RANDOM COLLECT
+  val random_vec_data = RegInit(0.U(128.W))
+  val random_cnt = RegInit(0.U(8.W))
+  val random_key_init_done = Wire(Bool())
+  io.randomio.random_req := random_cnt =/= 128.U(8.W)
+  random_ready_flag := random_cnt === 128.U(8.W)
+
+  when(io.randomio.random_req && io.randomio.random_val) {
+    // Shift one fresh random bit per accepted beat into the 128-bit collector
+    random_vec_data := Cat(random_vec_data(126,0), io.randomio.random_data)
+  }
+
+  when(random_ready_flag && random_key_init_done) {
+    random_cnt := 0.U
+  }.elsewhen(io.randomio.random_req && io.randomio.random_val) {
+    random_cnt := random_cnt + 1.U
+  }
+
+  // KEY Extend Req
+  key_expansion_idle := io.keyextend_req.ready
+  cfg_succeed := io.keyextend_req.ready
+
+  val keyextend_req_valid = RegInit(false.B)
+  val req_legal = Wire(Bool())
+  req_legal := (mode =/= 3.U(2.W)) && key_expansion_idle && ((mode =/= 2.U(2.W)) || random_ready_flag)
+
+  when(key_init_req && req_legal) {
+    keyextend_req_valid := true.B
+  }.elsewhen(io.keyextend_req.fire) {
+    keyextend_req_valid := false.B
+  }.otherwise {
+    keyextend_req_valid := keyextend_req_valid
+  }
+
+  when(key_init_req && req_legal) {
+    last_req_accepted := true.B
+  }.elsewhen(key_init_req) {
+    last_req_accepted := false.B
+  }.otherwise {
+    last_req_accepted := last_req_accepted
+  }
+
+  random_key_init_done := io.keyextend_req.fire && (mode === 2.U(2.W))
+
+  io.keyextend_req.valid := keyextend_req_valid
+  io.keyextend_req.bits.key := Mux(mode === 1.U(2.W), Cat(key1, key0), random_vec_data)
+  io.keyextend_req.bits.keyid := key_id
+  io.keyextend_req.bits.enc_mode
:= mode =/= 0.U(2.W) + io.keyextend_req.bits.tweak_flage := tweak_flage.asBool +} + +class KeyTableEntry extends Bundle { + val round_key_data = Vec(32, UInt(32.W)) + val encdec_mode = Bool() +} +class KeyTable(implicit p: Parameters) extends MemEncryptModule { + val io = IO(new Bundle { + val write_req = Input(new Bundle { + val keyid = UInt(KeyIDBits.W) + val keyid_valid = Input(Bool()) + val enc_mode = Input(Bool()) // 1: this keyid open enc, 0: this keyid close enc + val round_id = UInt(5.W) + val data = Input(UInt(32.W)) + }) + + val enc_keyids = Input(Vec(MemencPipes, UInt(KeyIDBits.W))) + val enc_round_keys = Output(Vec(MemencPipes, UInt((32*32/MemencPipes).W))) + val dec_keyids = Input(Vec(MemencPipes, UInt(KeyIDBits.W))) + val dec_round_keys = Output(Vec(MemencPipes, UInt((32*32/MemencPipes).W))) + val dec = new Bundle { + val keyid = Input(UInt(KeyIDBits.W)) // query dec_mode in advance in the AR channel + val mode = Output(Bool()) + } + val enc = new Bundle { + val keyid = Input(UInt(KeyIDBits.W)) // query enc_mode in advance in the AW channel + val mode = Output(Bool()) + } +}) + + val init_entry = Wire(new KeyTableEntry) + init_entry.round_key_data := DontCare // Keep round_key_data as default (uninitialized) + if (HasDelayNoencryption) { + init_entry.encdec_mode := true.B + } else { + init_entry.encdec_mode := false.B + } + val table = RegInit(VecInit(Seq.fill(1 << KeyIDBits)(init_entry))) + val wire_enc_round_keys = Wire(Vec(MemencPipes, UInt((32*32/MemencPipes).W))) + val wire_dec_round_keys = Wire(Vec(MemencPipes, UInt((32*32/MemencPipes).W))) + + // write and updata mode + when(io.write_req.keyid_valid && io.write_req.enc_mode) { + val entry = table(io.write_req.keyid) + entry.encdec_mode := io.write_req.enc_mode + entry.round_key_data(io.write_req.round_id) := io.write_req.data + } + when(io.write_req.keyid_valid && !io.write_req.enc_mode) { + val entry = table(io.write_req.keyid) + entry.encdec_mode := io.write_req.enc_mode + } + +// read logic + for (i <- 0 until MemencPipes) { + val enc_entry = table(io.enc_keyids(i)) + val enc_round_key_parts = VecInit(Seq.fill(32 / MemencPipes)(0.U(32.W))) + for (j <- 0 until (32 / MemencPipes)) { + enc_round_key_parts((32 / MemencPipes) - 1 - j) := enc_entry.round_key_data(i.U * (32 / MemencPipes).U + j.U) + } + wire_enc_round_keys(i) := enc_round_key_parts.reduce(Cat(_, _)) + + val dec_entry = table(io.dec_keyids(i)) + val dec_round_key_parts = VecInit(Seq.fill(32 / MemencPipes)(0.U(32.W))) + for (j <- 0 until (32 / MemencPipes)) { + dec_round_key_parts((32 / MemencPipes) - 1 - j) := dec_entry.round_key_data(31.U - (i.U * (32 / MemencPipes).U + j.U)) + } + wire_dec_round_keys(i) := dec_round_key_parts.reduce(Cat(_, _)) + } + // output read data(round keys, enc/dec_mode, ar_mode, aw_mode) + val dec_mode_entry = table(io.dec.keyid) + io.dec.mode := dec_mode_entry.encdec_mode + + val enc_mode_entry = table(io.enc.keyid) + io.enc.mode := enc_mode_entry.encdec_mode + + io.enc_round_keys := wire_enc_round_keys + io.dec_round_keys := wire_dec_round_keys + +} + +class KeyExtender(implicit p: Parameters) extends MemEncryptModule{ + val io = IO(new Bundle { + val keyextend_req = Flipped(DecoupledIO(new Bundle { + val key = UInt(128.W) + val keyid = UInt(KeyIDBits.W) + val enc_mode = Bool() // 1:this keyid open enc 0:this keyid close enc + val tweak_flage = Bool() // 1:extend tweak key 0:extend keyid key + })) + val tweak_round_keys = Output(Vec(32, UInt(32.W))) + val enc_keyids = Input(Vec(MemencPipes, UInt(KeyIDBits.W))) + val 
enc_round_keys = Output(Vec(MemencPipes, UInt((32*32/MemencPipes).W))) + val dec_keyids = Input(Vec(MemencPipes, UInt(KeyIDBits.W))) + val dec_round_keys = Output(Vec(MemencPipes, UInt((32*32/MemencPipes).W))) + val dec = new Bundle { + val keyid = Input(UInt(KeyIDBits.W)) // query dec_mode in advance in the AR channel + val mode = Output(Bool()) + } + val enc = new Bundle { + val keyid = Input(UInt(KeyIDBits.W)) // query enc_mode in advance in the AW channel + val mode = Output(Bool()) + } + }) + + val idle :: keyExpansion :: Nil = Enum(2) + val current = RegInit(idle) + val next = WireDefault(idle) + current := next + + val count_round = RegInit(0.U(5.W)) + val reg_count_round = RegNext(count_round) + val reg_user_key = RegInit(0.U(128.W)) + val data_for_round = Wire(UInt(128.W)) + val data_after_round = Wire(UInt(128.W)) + val reg_data_after_round = RegInit(0.U(128.W)) + val key_exp_finished_out = RegInit(1.U) + val reg_key_valid = RegNext(io.keyextend_req.valid, false.B) + val reg_tweak_round_keys = Reg(Vec(32, UInt(32.W))) + + + switch(current) { + is(idle) { + when(!reg_key_valid && io.keyextend_req.valid && io.keyextend_req.bits.enc_mode) { + next := keyExpansion + } + } + is(keyExpansion) { + when(reg_count_round === 31.U) { + next := idle + }.otherwise { + next := keyExpansion + } + } + } + + when(next === keyExpansion) { + count_round := count_round + 1.U + }.otherwise { + count_round := 0.U + } + + when(!reg_key_valid && io.keyextend_req.valid && io.keyextend_req.bits.enc_mode) { + reg_user_key := io.keyextend_req.bits.key + } + + when(current === keyExpansion && next === idle) { + key_exp_finished_out := true.B + }.elsewhen(io.keyextend_req.valid && io.keyextend_req.bits.enc_mode) { + key_exp_finished_out := false.B + } + io.keyextend_req.ready := key_exp_finished_out + + // Data for round calculation + data_for_round := Mux(reg_count_round =/= 0.U, reg_data_after_round, reg_user_key) + val cki = Module(new GetCKI) + cki.io.countRoundIn := count_round + val one_round = Module(new OneRoundForKeyExp) + one_round.io.countRoundIn := reg_count_round + one_round.io.dataIn := data_for_round + one_round.io.ckParameterIn := cki.io.ckiOut + data_after_round := one_round.io.resultOut + + when(current === keyExpansion) { + reg_data_after_round := data_after_round + } + + val keyTable = Module(new KeyTable()) + keyTable.io.write_req.keyid := io.keyextend_req.bits.keyid + keyTable.io.write_req.enc_mode := io.keyextend_req.bits.enc_mode + keyTable.io.write_req.round_id := reg_count_round + keyTable.io.write_req.data := data_after_round(31, 0) + + keyTable.io.enc_keyids := io.enc_keyids + keyTable.io.dec_keyids := io.dec_keyids + keyTable.io.dec.keyid := io.dec.keyid + keyTable.io.enc.keyid := io.enc.keyid + io.dec.mode := keyTable.io.dec.mode + io.enc.mode := keyTable.io.enc.mode + io.enc_round_keys := keyTable.io.enc_round_keys + io.dec_round_keys := keyTable.io.dec_round_keys + + + when(io.keyextend_req.bits.tweak_flage) { + reg_tweak_round_keys(reg_count_round) := data_after_round(31, 0) + keyTable.io.write_req.keyid_valid := false.B + }.otherwise { + keyTable.io.write_req.keyid_valid := current + } + io.tweak_round_keys := reg_tweak_round_keys +} + +class AXI4MemEncrypt(address: AddressSet)(implicit p: Parameters) extends LazyModule with Memconsts +{ + require (isPow2(MemencPipes), s"AXI4MemEncrypt: MemencPipes must be a power of two, not $MemencPipes") + require (PAddrBits > KeyIDBits, s"AXI4MemEncrypt: PAddrBits must be greater than KeyIDBits") + + val node = AXI4AdapterNode( + masterFn 
= { mp => + val new_idbits = log2Ceil(mp.endId) + 1 + // Create one new "master" per ID + val masters = Array.tabulate(1 << new_idbits) { i => AXI4MasterParameters( + name = "", + id = IdRange(i, i+1), + aligned = true, + maxFlight = Some(0)) + } + // Accumulate the names of masters we squish + val names = Array.fill(1 << new_idbits) { new scala.collection.mutable.HashSet[String]() } + // Squash the information from original masters into new ID masters + mp.masters.foreach { m => + for (i <- 0 until (1 << new_idbits)) { + val accumulated = masters(i) + names(i) += m.name + masters(i) = accumulated.copy( + aligned = accumulated.aligned && m.aligned, + maxFlight = accumulated.maxFlight.flatMap { o => m.maxFlight.map { n => o+n } }) + } + } + val finalNameStrings = names.map { n => if (n.isEmpty) "(unused)" else n.toList.mkString(", ") } + mp.copy(masters = masters.zip(finalNameStrings).map { case (m, n) => m.copy(name = n) }) + }, + slaveFn = { sp => sp }) + + val device = new SimpleDevice("mem-encrypt-unit", Seq("iie,memencrypt0")) + val ctrl_node = APBSlaveNode(Seq(APBSlavePortParameters( + Seq(APBSlaveParameters( + address = List(address), + resources = device.reg, + device = Some(device), + regionType = RegionType.IDEMPOTENT)), + beatBytes = 8))) + + lazy val module = new Impl + class Impl extends LazyModuleImp(this) { + val io = IO(new Bundle { + val random_req = Output(Bool()) + val random_val = Input(Bool()) + val random_data = Input(Bool()) + }) + + val en = Wire(Bool()) + val wmode = Wire(Bool()) + val addr = Wire(UInt(12.W)) + val wdata = Wire(UInt(64.W)) + val wmask = Wire(UInt(8.W)) + val rdata = Wire(UInt(64.W)) // get rdata next cycle after en + + (ctrl_node.in) foreach { case (ctrl_in, _) => + en := ctrl_in.psel && !ctrl_in.penable + wmode := ctrl_in.pwrite + addr := ctrl_in.paddr(11, 0) + wdata := ctrl_in.pwdata + wmask := ctrl_in.pstrb + ctrl_in.pready := true.B + ctrl_in.pslverr := false.B + ctrl_in.prdata := rdata + } + + (node.in zip node.out) foreach { case ((in, edgeIn), (out, edgeOut)) => + require (edgeIn.bundle.dataBits == 256, s"AXI4MemEncrypt: edgeIn dataBits must be 256") + require (edgeOut.bundle.dataBits == 256, s"AXI4MemEncrypt: edgeOut dataBits must be 256") + + val memencParams: Parameters = p.alterPartial { + case MemcEdgeInKey => edgeIn + case MemcEdgeOutKey => edgeOut + } + // ------------------------------------- + // MemEncrypt Config and State Registers + // ------------------------------------- + val memenc_enable = Wire(Bool()) + val memencrypt_csr = Module(new MemEncryptCSR()(memencParams)) + memencrypt_csr.io.en := en + memencrypt_csr.io.wmode := wmode + memencrypt_csr.io.addr := addr + memencrypt_csr.io.wdata := wdata + memencrypt_csr.io.wmask := wmask + memenc_enable := memencrypt_csr.io.memenc_enable + rdata := memencrypt_csr.io.rdata + + io.random_req := memencrypt_csr.io.randomio.random_req + memencrypt_csr.io.randomio.random_val := io.random_val + memencrypt_csr.io.randomio.random_data := io.random_data + + // ------------------------------------- + // Key Extender & Round Key Lookup Table + // ------------------------------------- + val key_extender = Module(new KeyExtender()(memencParams)) + key_extender.io.keyextend_req :<>= memencrypt_csr.io.keyextend_req + + // ------------------- + // AXI4 chanel B + // ------------------- + Connectable.waiveUnmatched(in.b, out.b) match { + case (lhs, rhs) => lhs.squeezeAll :<>= rhs.squeezeAll + } + + val write_route = Module(new WriteChanelRoute()(memencParams)) + val aw_tweakenc = Module(new 
TweakEncrptyQueue()(memencParams)) + val waddr_q = Module(new IrrevocableQueue(chiselTypeOf(in.aw.bits), entries = MemencPipes+1)) + val wdata_q = Module(new IrrevocableQueue(chiselTypeOf(in.w.bits), entries = MemencPipes+1)) + val write_machine = Module(new AXI4WriteMachine()(memencParams)) + val axi4w_kt_q = Module(new Queue(new AXI4W_KT(false)(memencParams), entries = 2, flow = true)) + val wdata_encpipe = Module(new WdataEncrptyPipe()(memencParams)) + val write_arb = Module(new WriteChanelArbiter()(memencParams)) + + // ------------------- + // AXI4 Write Route + // Unencrypt & Encrypt + // ------------------- + write_route.io.memenc_enable := memenc_enable + key_extender.io.enc.keyid := write_route.io.enc_keyid + write_route.io.enc_mode := key_extender.io.enc.mode + + write_route.io.in.aw :<>= in.aw + write_route.io.in.w :<>= in.w + + val unenc_aw = write_route.io.out0.aw + val unenc_w = write_route.io.out0.w + val pre_enc_aw = write_route.io.out1.aw + val pre_enc_w = write_route.io.out1.w + + // ------------------- + // AXI4 chanel AW + // ------------------- + pre_enc_aw.ready := waddr_q.io.enq.ready && aw_tweakenc.io.enq.ready + waddr_q.io.enq.valid := pre_enc_aw.valid && aw_tweakenc.io.enq.ready + aw_tweakenc.io.enq.valid := pre_enc_aw.valid && waddr_q.io.enq.ready + + waddr_q.io.enq.bits := pre_enc_aw.bits + waddr_q.io.enq.bits.addr := pre_enc_aw.bits.addr(PAddrBits-KeyIDBits-1, 0) + aw_tweakenc.io.enq.bits.addr := pre_enc_aw.bits.addr + aw_tweakenc.io.enq.bits.len := pre_enc_aw.bits.len + aw_tweakenc.io.tweak_round_keys := key_extender.io.tweak_round_keys + + // ------------------- + // AXI4 chanel W + // ------------------- + wdata_q.io.enq :<>= pre_enc_w + write_machine.io.in_w :<>= wdata_q.io.deq + write_machine.io.in_kt :<>= aw_tweakenc.io.deq + axi4w_kt_q.io.enq :<>= write_machine.io.out_w + wdata_encpipe.io.in_w :<>= axi4w_kt_q.io.deq + key_extender.io.enc_keyids := wdata_encpipe.io.enc_keyids + wdata_encpipe.io.enc_round_keys := key_extender.io.enc_round_keys + + // ------------------- + // AXI4 Write Arbiter + // Unencrypt & Encrypt + // ------------------- + write_arb.io.in0.aw :<>= unenc_aw + write_arb.io.in0.aw.bits.addr := unenc_aw.bits.addr(PAddrBits-KeyIDBits-1, 0) + write_arb.io.in0.w :<>= unenc_w + + write_arb.io.in1.aw.valid := waddr_q.io.deq.valid && (waddr_q.io.deq.bits.len =/=0.U || write_machine.io.uncache_en) + waddr_q.io.deq.ready := write_arb.io.in1.aw.ready && (waddr_q.io.deq.bits.len =/=0.U || write_machine.io.uncache_en) + write_machine.io.uncache_commit := write_arb.io.in1.aw.fire + write_arb.io.in1.aw.bits := waddr_q.io.deq.bits + write_arb.io.in1.w :<>= wdata_encpipe.io.out_w + + out.aw :<>= write_arb.io.out.aw + out.w :<>= write_arb.io.out.w + + val ar_arb = Module(new IrrevocableArbiter(chiselTypeOf(out.ar.bits), 2)) + val ar_tweakenc = Module(new TweakEncrptyTable()(memencParams)) + val read_machine = Module(new AXI4ReadMachine()(memencParams)) + val axi4r_kt_q = Module(new Queue(new AXI4R_KT(false)(memencParams), entries = 2, flow = true)) + val pre_dec_rdata_route = Module(new RdataChanelRoute()(memencParams)) + val rdata_decpipe = Module(new RdataDecrptyPipe()(memencParams)) + val r_arb = Module(new IrrevocableArbiter(chiselTypeOf(out.r.bits), 2)) + val post_dec_rdata_route = Module(new RdataRoute()(memencParams)) + + // ------------------- + // AXI4 chanel AR + // ------------------- + ar_arb.io.in(0) :<>= write_machine.io.out_ar + // DecoupledIO connect IrrevocableIO + ar_arb.io.in(1).valid := in.ar.valid + ar_arb.io.in(1).bits := in.ar.bits 
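+      // in.ar is an IrrevocableIO, which only strengthens the DecoupledIO contract
+      // (valid may not be retracted), so forwarding valid/bits and echoing ready
+      // beat-for-beat is a safe manual conversion.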
+ in.ar.ready := ar_arb.io.in(1).ready + + ar_arb.io.out.ready := out.ar.ready && ar_tweakenc.io.enq.ready + + ar_tweakenc.io.enq.valid := ar_arb.io.out.valid && out.ar.ready + ar_tweakenc.io.enq.bits.addr := ar_arb.io.out.bits.addr + ar_tweakenc.io.enq.bits.len := ar_arb.io.out.bits.len + ar_tweakenc.io.enq.bits.id := ar_arb.io.out.bits.id + ar_tweakenc.io.tweak_round_keys := key_extender.io.tweak_round_keys + ar_tweakenc.io.memenc_enable := memenc_enable + key_extender.io.dec.keyid := ar_tweakenc.io.dec_keyid + ar_tweakenc.io.dec_mode := key_extender.io.dec.mode + + out.ar.valid := ar_arb.io.out.valid && ar_tweakenc.io.enq.ready + out.ar.bits := ar_arb.io.out.bits + out.ar.bits.addr := ar_arb.io.out.bits.addr(PAddrBits-KeyIDBits-1, 0) + + // ------------------- + // AXI4 Rdata Route + // Unencrypt & Encrypt + // ------------------- + pre_dec_rdata_route.io.in_r :<>= out.r + ar_tweakenc.io.dec_r.id := pre_dec_rdata_route.io.dec_rid + pre_dec_rdata_route.io.dec_mode := ar_tweakenc.io.dec_r.mode + + val undec_r = pre_dec_rdata_route.io.out_r0 + val pre_dec_r = pre_dec_rdata_route.io.out_r1 + + // ------------------- + // AXI4 chanel R + // ------------------- + read_machine.io.in_r :<>= pre_dec_r + ar_tweakenc.io.req :<>= read_machine.io.kt_req + read_machine.io.in_kt :<>= ar_tweakenc.io.resp + axi4r_kt_q.io.enq :<>= read_machine.io.out_r + rdata_decpipe.io.in_r :<>= axi4r_kt_q.io.deq + key_extender.io.dec_keyids := rdata_decpipe.io.dec_keyids + rdata_decpipe.io.dec_round_keys := key_extender.io.dec_round_keys + + // ------------------- + // AXI4 Rdata Arbiter + // Unencrypt & Encrypt + // ------------------- + r_arb.io.in(0) :<>= undec_r + r_arb.io.in(1) :<>= rdata_decpipe.io.out_r + + post_dec_rdata_route.io.in_r :<>= r_arb.io.out + write_machine.io.in_r :<>= post_dec_rdata_route.io.out_r1 + in.r :<>= post_dec_rdata_route.io.out_r0 + } + } +} + +object AXI4MemEncrypt +{ + def apply(address: AddressSet)(implicit p: Parameters): AXI4Node = + { + val axi4memenc = LazyModule(new AXI4MemEncrypt(address)) + axi4memenc.node + } +} diff --git a/src/main/scala/device/MemEncryptUtil.scala b/src/main/scala/device/MemEncryptUtil.scala new file mode 100644 index 00000000000..38dfffbec87 --- /dev/null +++ b/src/main/scala/device/MemEncryptUtil.scala @@ -0,0 +1,821 @@ +/*************************************************************************************** +* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences +* +* XiangShan is licensed under Mulan PSL v2. +* You can use this software according to the terms and conditions of the Mulan PSL v2. +* You may obtain a copy of Mulan PSL v2 at: +* http://license.coscl.org.cn/MulanPSL2 +* +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +* +* See the Mulan PSL v2 for more details. 
+***************************************************************************************/ + +package device + +import chisel3._ +import chisel3.util._ +import chisel3.util.HasBlackBoxResource +import org.chipsalliance.cde.config.Parameters +import freechips.rocketchip.amba.axi4._ +import freechips.rocketchip.diplomacy._ +import freechips.rocketchip.util._ +import freechips.rocketchip.amba.apb._ +import freechips.rocketchip.tilelink.AXI4TLState + +class IrrevocableQueue[T <: Data](gen: T, entries: Int, flow: Boolean = false) extends Module { + val io = IO(new Bundle { + val enq = Flipped(Irrevocable(gen)) + val deq = Irrevocable(gen) + }) + val queue = Module(new Queue(gen, entries = entries, flow = flow)) + + queue.io.enq.valid := io.enq.valid + queue.io.enq.bits := io.enq.bits + io.enq.ready := queue.io.enq.ready + + io.deq.valid := queue.io.deq.valid + io.deq.bits := queue.io.deq.bits + queue.io.deq.ready := io.deq.ready +} + +class IrrevocableArbiter[T <: Data](gen: T, n: Int) extends Module { + val io = IO(new Bundle { + val in = Flipped(Vec(n, Irrevocable(gen))) + val out = Irrevocable(gen) + }) + + val decoupledIn = io.in.map { irrevocable => + val decoupled = Wire(Decoupled(gen)) + decoupled.valid := irrevocable.valid + decoupled.bits := irrevocable.bits + irrevocable.ready := decoupled.ready + decoupled + } + + val arbiter = Module(new Arbiter(gen, n)) + arbiter.io.in <> decoupledIn + + io.out.valid := arbiter.io.out.valid + io.out.bits := arbiter.io.out.bits + arbiter.io.out.ready := io.out.ready +} + +// CKI (Cipher Key Input) is a constant input used in the SM4 encryption algorithm. +// It is part of the key expansion process and participates in generating subkeys. +// During each round of the key expansion, the CKI value is mixed with other constants and +// the initial key to enhance the security of the encryption algorithm. +class GetCKI extends Module { + val io = IO(new Bundle { + val countRoundIn = Input(UInt(5.W)) + val ckiOut = Output(UInt(32.W)) + }) + val ckiOutReg= RegInit(0.U(32.W)) + // 32 32-bit CKI constant values + val ckiValuesVec = VecInit(Seq( + "h00070e15".U, "h1c232a31".U, "h383f464d".U, "h545b6269".U, + "h70777e85".U, "h8c939aa1".U, "ha8afb6bd".U, "hc4cbd2d9".U, + "he0e7eef5".U, "hfc030a11".U, "h181f262d".U, "h343b4249".U, + "h50575e65".U, "h6c737a81".U, "h888f969d".U, "ha4abb2b9".U, + "hc0c7ced5".U, "hdce3eaf1".U, "hf8ff060d".U, "h141b2229".U, + "h30373e45".U, "h4c535a61".U, "h686f767d".U, "h848b9299".U, + "ha0a7aeb5".U, "hbcc3cad1".U, "hd8dfe6ed".U, "hf4fb0209".U, + "h10171e25".U, "h2c333a41".U, "h484f565d".U, "h646b7279".U + )) + when(io.countRoundIn < 32.U) { + ckiOutReg := ckiValuesVec(io.countRoundIn) + }.otherwise { + ckiOutReg := 0.U + } + io.ckiOut := ckiOutReg +} + + +// S-box is used in SM4 for nonlinear transformations during encryption processes. +// SM4 uses a fixed 256 byte S-box for byte replacement. +// This replacement process is achieved by replacing the input 8-bit data +// with the corresponding values in the S-box lookup table. 
+class SboxReplace extends Module { + val io = IO(new Bundle { + val dataIn = Input(UInt(8.W)) + val resultOut = Output(UInt(8.W)) + }) + // A 256 element S-box lookup table, where each element is an 8-bit hexadecimal constant + val sbox = VecInit(Seq( + 0xd6.U, 0x90.U, 0xe9.U, 0xfe.U, 0xcc.U, 0xe1.U, 0x3d.U, 0xb7.U, 0x16.U, 0xb6.U, 0x14.U, 0xc2.U, 0x28.U, 0xfb.U, 0x2c.U, 0x05.U, + 0x2b.U, 0x67.U, 0x9a.U, 0x76.U, 0x2a.U, 0xbe.U, 0x04.U, 0xc3.U, 0xaa.U, 0x44.U, 0x13.U, 0x26.U, 0x49.U, 0x86.U, 0x06.U, 0x99.U, + 0x9c.U, 0x42.U, 0x50.U, 0xf4.U, 0x91.U, 0xef.U, 0x98.U, 0x7a.U, 0x33.U, 0x54.U, 0x0b.U, 0x43.U, 0xed.U, 0xcf.U, 0xac.U, 0x62.U, + 0xe4.U, 0xb3.U, 0x1c.U, 0xa9.U, 0xc9.U, 0x08.U, 0xe8.U, 0x95.U, 0x80.U, 0xdf.U, 0x94.U, 0xfa.U, 0x75.U, 0x8f.U, 0x3f.U, 0xa6.U, + 0x47.U, 0x07.U, 0xa7.U, 0xfc.U, 0xf3.U, 0x73.U, 0x17.U, 0xba.U, 0x83.U, 0x59.U, 0x3c.U, 0x19.U, 0xe6.U, 0x85.U, 0x4f.U, 0xa8.U, + 0x68.U, 0x6b.U, 0x81.U, 0xb2.U, 0x71.U, 0x64.U, 0xda.U, 0x8b.U, 0xf8.U, 0xeb.U, 0x0f.U, 0x4b.U, 0x70.U, 0x56.U, 0x9d.U, 0x35.U, + 0x1e.U, 0x24.U, 0x0e.U, 0x5e.U, 0x63.U, 0x58.U, 0xd1.U, 0xa2.U, 0x25.U, 0x22.U, 0x7c.U, 0x3b.U, 0x01.U, 0x21.U, 0x78.U, 0x87.U, + 0xd4.U, 0x00.U, 0x46.U, 0x57.U, 0x9f.U, 0xd3.U, 0x27.U, 0x52.U, 0x4c.U, 0x36.U, 0x02.U, 0xe7.U, 0xa0.U, 0xc4.U, 0xc8.U, 0x9e.U, + 0xea.U, 0xbf.U, 0x8a.U, 0xd2.U, 0x40.U, 0xc7.U, 0x38.U, 0xb5.U, 0xa3.U, 0xf7.U, 0xf2.U, 0xce.U, 0xf9.U, 0x61.U, 0x15.U, 0xa1.U, + 0xe0.U, 0xae.U, 0x5d.U, 0xa4.U, 0x9b.U, 0x34.U, 0x1a.U, 0x55.U, 0xad.U, 0x93.U, 0x32.U, 0x30.U, 0xf5.U, 0x8c.U, 0xb1.U, 0xe3.U, + 0x1d.U, 0xf6.U, 0xe2.U, 0x2e.U, 0x82.U, 0x66.U, 0xca.U, 0x60.U, 0xc0.U, 0x29.U, 0x23.U, 0xab.U, 0x0d.U, 0x53.U, 0x4e.U, 0x6f.U, + 0xd5.U, 0xdb.U, 0x37.U, 0x45.U, 0xde.U, 0xfd.U, 0x8e.U, 0x2f.U, 0x03.U, 0xff.U, 0x6a.U, 0x72.U, 0x6d.U, 0x6c.U, 0x5b.U, 0x51.U, + 0x8d.U, 0x1b.U, 0xaf.U, 0x92.U, 0xbb.U, 0xdd.U, 0xbc.U, 0x7f.U, 0x11.U, 0xd9.U, 0x5c.U, 0x41.U, 0x1f.U, 0x10.U, 0x5a.U, 0xd8.U, + 0x0a.U, 0xc1.U, 0x31.U, 0x88.U, 0xa5.U, 0xcd.U, 0x7b.U, 0xbd.U, 0x2d.U, 0x74.U, 0xd0.U, 0x12.U, 0xb8.U, 0xe5.U, 0xb4.U, 0xb0.U, + 0x89.U, 0x69.U, 0x97.U, 0x4a.U, 0x0c.U, 0x96.U, 0x77.U, 0x7e.U, 0x65.U, 0xb9.U, 0xf1.U, 0x09.U, 0xc5.U, 0x6e.U, 0xc6.U, 0x84.U, + 0x18.U, 0xf0.U, 0x7d.U, 0xec.U, 0x3a.U, 0xdc.U, 0x4d.U, 0x20.U, 0x79.U, 0xee.U, 0x5f.U, 0x3e.U, 0xd7.U, 0xcb.U, 0x39.U, 0x48.U + )) + + io.resultOut := sbox(io.dataIn) +} + +// Nonlinear Transformation in Data Encryption Process +class TransformForEncDec extends Module { + val io = IO(new Bundle { + val data_in = Input(UInt(32.W)) + val result_out = Output(UInt(32.W)) + }) + + val bytes_in = VecInit(Seq(io.data_in(7, 0), io.data_in(15, 8), io.data_in(23, 16), io.data_in(31, 24))) + val bytes_replaced = Wire(Vec(4, UInt(8.W))) + val word_replaced = Wire(UInt(32.W)) + + val sbox_replace_modules = VecInit(Seq.fill(4)(Module(new SboxReplace).io)) + for (i <- 0 until 4) { + sbox_replace_modules(i).dataIn := bytes_in(i) + bytes_replaced(i) := sbox_replace_modules(i).resultOut + } + + word_replaced := Cat(bytes_replaced.reverse) + + io.result_out := ((word_replaced ^ Cat(word_replaced(29, 0), word_replaced(31, 30))) ^ + (Cat(word_replaced(21, 0), word_replaced(31, 22)) ^ Cat(word_replaced(13, 0), word_replaced(31, 14)))) ^ + Cat(word_replaced(7, 0), word_replaced(31, 8)) +} + + +// Nonlinear Transformation in Key Expansion Process +class TransformForKeyExp extends Module { + val io = IO(new Bundle { + val data_in = Input(UInt(32.W)) + val data_after_linear_key_out = Output(UInt(32.W)) + }) + val bytes_in = VecInit(Seq(io.data_in(7, 
0), io.data_in(15, 8), io.data_in(23, 16), io.data_in(31, 24))) + val bytes_replaced = Wire(Vec(4, UInt(8.W))) + val word_replaced = Wire(UInt(32.W)) + + val sbox_replace_modules = VecInit(Seq.fill(4)(Module(new SboxReplace).io)) + for (i <- 0 until 4) { + sbox_replace_modules(i).dataIn := bytes_in(i) + bytes_replaced(i) := sbox_replace_modules(i).resultOut + } + + word_replaced := Cat(bytes_replaced.reverse) + + io.data_after_linear_key_out := (word_replaced ^ Cat(word_replaced(18, 0), word_replaced(31, 19))) ^ Cat(word_replaced(8, 0), word_replaced(31, 9)) +} + +// The key expansion algorithm requires a total of 32 rounds of operations, including one round of operation +class OneRoundForKeyExp extends Module { + val io = IO(new Bundle { + val countRoundIn = Input(UInt(5.W)) + val dataIn = Input(UInt(128.W)) + val ckParameterIn = Input(UInt(32.W)) + val resultOut = Output(UInt(128.W)) + }) + // In key expansion, the first step is to XOR each word of the original key with the system parameter to obtain four new words. + // system parameter: FK0, FK1, FK2, FK3. + val FK0 = "ha3b1bac6".U + val FK1 = "h56aa3350".U + val FK2 = "h677d9197".U + val FK3 = "hb27022dc".U + + val word = VecInit(Seq(io.dataIn(127, 96), io.dataIn(95, 64), io.dataIn(63, 32), io.dataIn(31, 0))) + + val k0 = word(0) ^ FK0 + val k1 = word(1) ^ FK1 + val k2 = word(2) ^ FK2 + val k3 = word(3) ^ FK3 + + + val dataForXor = io.ckParameterIn + val tmp0 = Mux(io.countRoundIn === 0.U, k1 ^ k2, word(1) ^ word(2)) + val tmp1 = Mux(io.countRoundIn === 0.U, k3 ^ dataForXor, word(3) ^ dataForXor) + val dataForTransform = tmp0 ^ tmp1 + + val transformKey = Module(new TransformForKeyExp) + transformKey.io.data_in := dataForTransform + + io.resultOut := Mux(io.countRoundIn === 0.U, + Cat(k1, k2, k3, transformKey.io.data_after_linear_key_out ^ k0), + Cat(word(1), word(2), word(3), transformKey.io.data_after_linear_key_out ^ word(0))) +} + +// The SM4 encryption algorithm requires a total of 32 rounds of operations, including one round of operation +class OneRoundForEncDec extends Module { + val io = IO(new Bundle { + val data_in = Input(UInt(128.W)) + val round_key_in = Input(UInt(32.W)) + val result_out = Output(UInt(128.W)) + }) + + val word = VecInit(Seq(io.data_in(127, 96), io.data_in(95, 64), io.data_in(63, 32), io.data_in(31, 0))) + + val tmp0 = word(1) ^ word(2) + val tmp1 = word(3) ^ io.round_key_in + val data_for_transform = tmp0 ^ tmp1 + + val transform_encdec = Module(new TransformForEncDec) + transform_encdec.io.data_in := data_for_transform + + io.result_out := Cat(word(1), word(2), word(3), transform_encdec.io.result_out ^ word(0)) +} + + + +class AXI4BundleWWithoutData(params: AXI4BundleParameters) extends Bundle { + val strb = UInt((params.dataBits/8).W) + val last = Bool() + val user = BundleMap(params.requestFields.filter(_.key.isData)) +} + +class AXI4BundleRWithoutData(params: AXI4BundleParameters) extends Bundle { + val id = UInt(params.idBits.W) + val resp = UInt(params.respBits.W) + val user = BundleMap(params.responseFields) + val echo = BundleMap(params.echoFields) + val last = Bool() +} + +// OnePipeEncBase is an abstract class that defines the structure of a single-pipe encryption module. +// The main purpose of this class is to standardize the input and output interfaces for encryption modules. 
+abstract class OnePipeEncBase(implicit p: Parameters) extends MemEncryptModule {
+  val io = IO(new Bundle {
+    val onepipe_in = new Bundle {
+      val keyid = Input(UInt(KeyIDBits.W))
+      val data_in = Input(UInt(128.W))
+      val tweak_in = Input(UInt(128.W))
+      val axi4_other = Input(new AXI4BundleWWithoutData(MemcedgeIn.bundle))
+      val round_key_in = Input(Vec(32/MemencPipes, UInt(32.W)))
+    }
+    val onepipe_out = new Bundle {
+      val result_out = Output(UInt(128.W))
+      val axi4_other_out = Output(new AXI4BundleWWithoutData(MemcedgeIn.bundle))
+      val tweak_out = Output(UInt(128.W))
+      val keyid_out = Output(UInt(KeyIDBits.W))
+    }
+  })
+}
+
+// OnePipeForEnc is one stage of the write-data encryption pipeline: it chains
+// 32/MemencPipes SM4 rounds combinationally, so the pipeline depth (MemencPipes)
+// is configurable.
+class OnePipeForEnc(implicit p: Parameters) extends OnePipeEncBase {
+
+  val OneRoundForEncDecs = Seq.fill(32/MemencPipes)(Module(new OneRoundForEncDec))
+
+  for (i <- 0 until 32/MemencPipes) {
+    val mod = OneRoundForEncDecs(i)
+    mod.io.round_key_in := io.onepipe_in.round_key_in(i)
+    if (i == 0) mod.io.data_in := io.onepipe_in.data_in else mod.io.data_in := OneRoundForEncDecs(i - 1).io.result_out
+  }
+
+  io.onepipe_out.result_out := OneRoundForEncDecs.last.io.result_out
+  io.onepipe_out.keyid_out := io.onepipe_in.keyid
+  io.onepipe_out.axi4_other_out := io.onepipe_in.axi4_other
+  io.onepipe_out.tweak_out := io.onepipe_in.tweak_in
+
+}
+// Pass-through variant of an encryption stage: it keeps the pipeline timing but
+// skips the round function. Selected when HasDelayNoencryption is set (test use).
+class OnePipeForEncNoEnc(implicit p: Parameters) extends OnePipeEncBase {
+  io.onepipe_out.result_out := io.onepipe_in.data_in
+  io.onepipe_out.keyid_out := io.onepipe_in.keyid
+  io.onepipe_out.tweak_out := io.onepipe_in.tweak_in
+  io.onepipe_out.axi4_other_out := io.onepipe_in.axi4_other
+}
+
+abstract class OnePipeDecBase(implicit p: Parameters) extends MemEncryptModule {
+  val io = IO(new Bundle {
+    val onepipe_in = new Bundle {
+      val keyid = Input(UInt(KeyIDBits.W))
+      val data_in = Input(UInt(128.W))
+      val tweak_in = Input(UInt(128.W))
+      val axi4_other = Input(new AXI4BundleRWithoutData(MemcedgeOut.bundle))
+      val round_key_in = Input(Vec(32/MemencPipes, UInt(32.W)))
+    }
+    val onepipe_out = new Bundle {
+      val result_out = Output(UInt(128.W))
+      val axi4_other_out = Output(new AXI4BundleRWithoutData(MemcedgeOut.bundle))
+      val tweak_out = Output(UInt(128.W))
+      val keyid_out = Output(UInt(KeyIDBits.W))
+    }
+  })
+}
+// Decryption counterpart of OnePipeForEnc: one stage of 32/MemencPipes rounds
+// (KeyTable supplies the round keys in reverse order for decryption).
+class OnePipeForDec(implicit p: Parameters) extends OnePipeDecBase {
+
+  val OneRoundForEncDecs = Seq.fill(32/MemencPipes)(Module(new OneRoundForEncDec))
+
+  for (i <- 0 until 32/MemencPipes) {
+    val mod = OneRoundForEncDecs(i)
+    mod.io.round_key_in := io.onepipe_in.round_key_in(i)
+    if (i == 0) mod.io.data_in := io.onepipe_in.data_in else mod.io.data_in := OneRoundForEncDecs(i - 1).io.result_out
+  }
+
+  io.onepipe_out.result_out := OneRoundForEncDecs.last.io.result_out
+  io.onepipe_out.keyid_out := io.onepipe_in.keyid
+  io.onepipe_out.axi4_other_out := io.onepipe_in.axi4_other
+  io.onepipe_out.tweak_out := io.onepipe_in.tweak_in
+
+}
+// Pass-through variant of a decryption stage for the no-encryption test mode.
+class OnePipeForDecNoDec(implicit p: Parameters) extends OnePipeDecBase {
+  io.onepipe_out.result_out := io.onepipe_in.data_in
+  io.onepipe_out.keyid_out := io.onepipe_in.keyid
+  io.onepipe_out.axi4_other_out := io.onepipe_in.axi4_other
+  io.onepipe_out.tweak_out := io.onepipe_in.tweak_in
+}
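+
+// A minimal software sketch of why chaining MemencPipes stages of 32/MemencPipes
+// rounds each matches a monolithic 32-round SM4 pass: round application is
+// left-to-right function composition, so any grouping of the round indices is
+// equivalent. Illustrative only, not part of the synthesized design; the object
+// and function names below are hypothetical, and `round` stands in for
+// OneRoundForEncDec's combinational function.
+object PipelineEquivalenceSketch {
+  // All 32 rounds applied in one pass.
+  def encrypt32(round: (BigInt, Int) => BigInt, state: BigInt): BigInt =
+    (0 until 32).foldLeft(state)(round)
+  // The same rounds grouped into `pipes` stages of 32/pipes rounds each.
+  def encryptPiped(round: (BigInt, Int) => BigInt, state: BigInt, pipes: Int): BigInt = {
+    val perStage = 32 / pipes
+    (0 until pipes).foldLeft(state) { (s, stage) =>
+      (stage * perStage until (stage + 1) * perStage).foldLeft(s)(round)
+    }
+  }
+  // encrypt32(r, x) == encryptPiped(r, x, p) for any p dividing 32.
+}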
(the encryption adjustment value) used in XTS mode. +// The tweak is combined with the data through finite-field and XOR operations, so that +// identical plaintext blocks do not produce identical ciphertext. + +// Calculation process: +// shift the tweak left by one bit; if the bit shifted out is 1, XOR the low 8 bits with 0x87. +// Repeating this three times yields the four tweak values that are XORed with the data before and after encryption; +class GF128 extends Module{ + val io = IO(new Bundle { + val tweak_in = Input(UInt(128.W)) + val tweak_out = Output(UInt(512.W)) + }) + + val gf_128_fdbk = "h87".U(8.W) + val tweak_1_isgf = io.tweak_in(127) + val tweak_2_isgf = io.tweak_in(126) + val tweak_3_isgf = io.tweak_in(125) + val tweak_1_shifted = Wire(UInt(128.W)) + val tweak_2_shifted = Wire(UInt(128.W)) + val tweak_3_shifted = Wire(UInt(128.W)) + val tweak_1_out = Wire(UInt(128.W)) + val tweak_2_out = Wire(UInt(128.W)) + val tweak_3_out = Wire(UInt(128.W)) + + tweak_1_shifted := io.tweak_in << 1 + tweak_2_shifted := tweak_1_out << 1 + tweak_3_shifted := tweak_2_out << 1 + + tweak_1_out := Mux(tweak_1_isgf, tweak_1_shifted ^ gf_128_fdbk, tweak_1_shifted) + tweak_2_out := Mux(tweak_2_isgf, tweak_2_shifted ^ gf_128_fdbk, tweak_2_shifted) + tweak_3_out := Mux(tweak_3_isgf, tweak_3_shifted ^ gf_128_fdbk, tweak_3_shifted) + + io.tweak_out := Cat(tweak_3_out, tweak_2_out, tweak_1_out, io.tweak_in) +} + +// Performs finite-field operations on the initial tweak while the request is being sent, +// and outputs the result as required by the burst length (aw.len) +class TweakGF128(implicit p: Parameters) extends MemEncryptModule{ + val io = IO(new Bundle { + val req = Flipped(DecoupledIO(new Bundle { + val len = UInt(MemcedgeIn.bundle.lenBits.W) + val addr = UInt(PAddrBits.W) + val tweak_in = UInt(128.W) + })) + val resp = DecoupledIO(new Bundle { + val tweak_out = UInt(256.W) + val keyid_out = UInt(KeyIDBits.W) + val addr_out = UInt(PAddrBits.W) + }) + }) + val tweak_gf128 = Module(new GF128()) + tweak_gf128.io.tweak_in := io.req.bits.tweak_in + + val reg_valid = RegInit(false.B) + val reg_counter = RegInit(0.U(2.W)) + val reg_len = RegInit(0.U(MemcedgeIn.bundle.lenBits.W)) + val reg_addr = RegInit(0.U(PAddrBits.W)) + val reg_tweak_result = RegInit(0.U(512.W)) + + io.req.ready := !reg_valid || (reg_valid && io.resp.ready && (reg_len === 0.U || reg_counter =/= 0.U)) + + when(io.req.fire) { + reg_tweak_result := tweak_gf128.io.tweak_out + reg_len := io.req.bits.len + reg_addr := io.req.bits.addr + reg_valid := true.B + reg_counter := 0.U + }.elsewhen(reg_valid && io.resp.ready) { + when(reg_len === 0.U) { + reg_valid := false.B + reg_counter := 0.U + }.otherwise { + when(reg_counter === 0.U) { + reg_counter := reg_counter + 1.U + }.otherwise { + reg_valid := false.B + reg_counter := 0.U + } + } + }.otherwise { + reg_valid := reg_valid + reg_counter := reg_counter + } + + + io.resp.bits.addr_out := reg_addr + io.resp.bits.keyid_out := reg_addr(PAddrBits - 1, PAddrBits - KeyIDBits) + io.resp.bits.tweak_out := Mux(reg_len === 0.U, Mux(reg_addr(5) === 0.U, reg_tweak_result(255, 0), reg_tweak_result(511, 256)), + Mux(reg_counter === 0.U, reg_tweak_result(255, 0), reg_tweak_result(511, 256))) + io.resp.valid := reg_valid +} + +// One pipeline stage of the initial tweak encryption +class OnePipeForTweakEnc(implicit p: Parameters) extends MemEncryptModule { + val io = IO(new Bundle { + val in = new Bundle
{ + val data_in = Input(UInt(128.W)) + val addr_in = Input(UInt(PAddrBits.W)) + val len_in = Input(UInt(MemcedgeOut.bundle.lenBits.W)) + val id_in = Input(UInt(MemcedgeOut.bundle.idBits.W)) + val round_key_in = Input(Vec(32/MemencPipes, UInt(32.W))) + } + val out = new Bundle { + val result_out = Output(UInt(128.W)) + val addr_out = Output(UInt(PAddrBits.W)) + val len_out = Output(UInt(MemcedgeOut.bundle.lenBits.W)) + val id_out = Output(UInt(MemcedgeOut.bundle.idBits.W)) + } + }) + + val OneRoundForEncDecs = Seq.fill(32/MemencPipes)(Module(new OneRoundForEncDec)) + for (i <- 0 until 32/MemencPipes) { + val mod = OneRoundForEncDecs(i) + mod.io.round_key_in := io.in.round_key_in(i) + if (i == 0) mod.io.data_in := io.in.data_in else mod.io.data_in := OneRoundForEncDecs(i - 1).io.result_out + } + + io.out.result_out := OneRoundForEncDecs.last.io.result_out + io.out.addr_out := io.in.addr_in + io.out.len_out := io.in.len_in + io.out.id_out := io.in.id_in +} + +// Initial TWEAK encryption module. +// The pipeline configuration is determined by the MemencPipes parameter +class TweakEncrypt(opt: Boolean)(implicit p: Parameters) extends MemEncryptModule{ + val edgeUse = if (opt) MemcedgeIn else MemcedgeOut + val io = IO(new Bundle { + val tweak_enc_req = Flipped(DecoupledIO(new Bundle { + val tweak = UInt(128.W) + val addr_in = UInt(PAddrBits.W) + val len_in = UInt(edgeUse.bundle.lenBits.W) // 6 bit + val id_in = UInt(edgeUse.bundle.idBits.W) + val tweak_round_keys = Vec(32, UInt(32.W)) + })) + val tweak_enc_resp = DecoupledIO(new Bundle { + val tweak_encrpty = UInt(128.W) + val addr_out = UInt(PAddrBits.W) + val len_out = UInt(edgeUse.bundle.lenBits.W) + val id_out = UInt(edgeUse.bundle.idBits.W) + }) + }) + + val reg_tweak = Reg(Vec(MemencPipes, UInt(128.W))) + val reg_addr = Reg(Vec(MemencPipes, UInt(PAddrBits.W))) + val reg_len = Reg(Vec(MemencPipes, UInt(edgeUse.bundle.lenBits.W))) + val reg_id = Reg(Vec(MemencPipes, UInt(edgeUse.bundle.idBits.W))) + val reg_tweak_valid = RegInit(VecInit(Seq.fill(MemencPipes)(false.B))) + // TWEAK encryption requires 32 rounds of encryption keys, grouped by pipeline stage + val wire_round_key = Wire(Vec(MemencPipes, UInt((32 * 32 / MemencPipes).W))) + + val keysPerPipe = 32 / MemencPipes + for (i <- 0 until MemencPipes) { + val keySegment = VecInit((0 until keysPerPipe).map(j => io.tweak_enc_req.bits.tweak_round_keys(i * keysPerPipe + j))) + wire_round_key(i) := Cat(keySegment.asUInt) + } + + val wire_ready_result = WireInit(VecInit(Seq.fill(MemencPipes)(false.B))) + // Configures one stage of the tweak encryption pipeline + def configureModule(i: Int, dataIn: UInt, addrIn: UInt, lenIn: UInt, idIn: UInt, roundKeys: UInt): OnePipeForTweakEnc = { + + when(wire_ready_result(i) && (if (i == 0) io.tweak_enc_req.valid else reg_tweak_valid(i-1))) { + reg_tweak_valid(i) := true.B + }.elsewhen(reg_tweak_valid(i) && (if (i == MemencPipes - 1) io.tweak_enc_resp.ready else wire_ready_result(i+1))) { + reg_tweak_valid(i) := false.B + }.otherwise { + reg_tweak_valid(i) := reg_tweak_valid(i) + } + wire_ready_result(i) := !reg_tweak_valid(i) || (reg_tweak_valid(i) && (if (i == MemencPipes - 1) io.tweak_enc_resp.ready else wire_ready_result(i+1))) + + val module = Module(new OnePipeForTweakEnc()) + module.io.in.data_in := dataIn + module.io.in.addr_in := addrIn + module.io.in.len_in := lenIn + module.io.in.id_in := idIn + for (j <- 0 until 32/MemencPipes) { + module.io.in.round_key_in(j) := roundKeys(j * 32 + 31, j * 32) + } + when(wire_ready_result(i) && (if
(i == 0) io.tweak_enc_req.valid else reg_tweak_valid(i-1))) { + reg_tweak(i) := module.io.out.result_out + reg_addr(i) := module.io.out.addr_out + reg_len(i) := module.io.out.len_out + reg_id(i) := module.io.out.id_out + } + module + } + // Instantiate the tweak encryption module for each pipeline level + val tweak_enc_modules = (0 until MemencPipes).map { i => + if (i == 0) { + configureModule(i, io.tweak_enc_req.bits.tweak, io.tweak_enc_req.bits.addr_in, io.tweak_enc_req.bits.len_in, io.tweak_enc_req.bits.id_in, wire_round_key(i)) + } else { + configureModule(i, reg_tweak(i-1), reg_addr(i-1), reg_len(i-1), reg_id(i-1), wire_round_key(i)) + } + } + val result_out = Cat( + reg_tweak.last(31, 0), + reg_tweak.last(63, 32), + reg_tweak.last(95, 64), + reg_tweak.last(127, 96) + ) + io.tweak_enc_resp.bits.tweak_encrpty := result_out + io.tweak_enc_resp.bits.addr_out := reg_addr.last + io.tweak_enc_resp.bits.len_out := reg_len.last + io.tweak_enc_resp.bits.id_out := reg_id.last + io.tweak_enc_resp.valid := reg_tweak_valid.last + io.tweak_enc_req.ready := wire_ready_result(0) + +} + + +// tweak table entry in AR Channel +class TweakTableEntry(implicit val p: Parameters) extends Bundle with Memconsts { + val v_flag = Bool() + val keyid = UInt(KeyIDBits.W) + val len = UInt(MemcedgeOut.bundle.lenBits.W) + val tweak_encrpty = UInt(128.W) + val sel_counter = Bool() +} +class TweakTableModeEntry extends Bundle { + val dec_mode = Bool() +} +// tweak table in AR Channel +class TweakTable(implicit p: Parameters) extends MemEncryptModule { + val io = IO(new Bundle { + // Write to tweak table + val write = Flipped(DecoupledIO(new Bundle { + val id = UInt(MemcedgeOut.bundle.idBits.W) + val len = UInt(MemcedgeOut.bundle.lenBits.W) + val addr = UInt(PAddrBits.W) + val tweak_encrpty = UInt(128.W) + })) + // Read from the tweak table with the ID of channel R in AXI4 as the index + val req = Flipped(DecoupledIO(new Bundle { + val read_id = UInt(MemcedgeOut.bundle.idBits.W) + })) + // Tweak table read response + val resp = DecoupledIO(new Bundle { + val read_tweak = UInt(128.W) + val read_keyid = UInt(KeyIDBits.W) + val read_sel_counter = Bool() + }) + val w_mode = Flipped(DecoupledIO(new Bundle { + val id = UInt(MemcedgeOut.bundle.idBits.W) + val dec_mode = Input(Bool()) + })) + val r_mode = new Bundle { + val id = Input(UInt(MemcedgeOut.bundle.idBits.W)) + val dec_mode = Output(Bool()) + } + }) + + val init_tweak_entry = Wire(new TweakTableEntry()) + init_tweak_entry.v_flag := false.B + init_tweak_entry.keyid := DontCare + init_tweak_entry.len := DontCare + init_tweak_entry.tweak_encrpty := DontCare + init_tweak_entry.sel_counter := DontCare + val init_mode_entry = Wire(new TweakTableModeEntry) + init_mode_entry.dec_mode := false.B + val tweak_table = RegInit(VecInit(Seq.fill((1 << (MemcedgeOut.bundle.idBits - 1)) + 1)(init_tweak_entry))) + val tweak_mode_table = RegInit(VecInit(Seq.fill((1 << (MemcedgeOut.bundle.idBits - 1)) + 1)(init_mode_entry))) + + // write tweak table entry logic + when(io.write.valid) { + val write_entry = tweak_table(io.write.bits.id) + write_entry.tweak_encrpty := io.write.bits.tweak_encrpty + write_entry.keyid := io.write.bits.addr(PAddrBits-1, PAddrBits-KeyIDBits) + write_entry.len := io.write.bits.len + write_entry.v_flag := true.B + + when(io.write.bits.len === 1.U) { + write_entry.sel_counter := false.B + }.otherwise { + write_entry.sel_counter := Mux(io.write.bits.addr(5) === 0.U, false.B, true.B) + } + } + io.write.ready := true.B + + // write mode table entry logic + 
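// (one mode entry per AXI ID: dec_mode records whether the read data that later returns
// under this ID must be decrypted; the R-channel router reads it back through io.r_mode)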
when(io.w_mode.valid) { + val write_mode_entry = tweak_mode_table(io.w_mode.bits.id) + write_mode_entry.dec_mode := io.w_mode.bits.dec_mode + } + io.w_mode.ready := true.B + + // Tweak table read response logic + val reg_read_valid = RegInit(false.B) + val reg_tweak_encrpty = RegInit(0.U(128.W)) + val reg_keyid = RegInit(0.U(KeyIDBits.W)) + val reg_sel_counter = RegInit(false.B) + + val read_entry = tweak_table(io.req.bits.read_id) + val read_mode_entry = tweak_mode_table(io.r_mode.id) + io.r_mode.dec_mode := read_mode_entry.dec_mode + + io.req.ready := (!reg_read_valid || (reg_read_valid && io.resp.ready)) && read_entry.v_flag + when(io.req.fire) { + reg_read_valid := true.B + reg_tweak_encrpty := read_entry.tweak_encrpty + reg_keyid := read_entry.keyid + reg_sel_counter := read_entry.sel_counter + when(read_entry.len === 0.U) { + read_entry.v_flag := false.B + }.otherwise { + when(!read_entry.sel_counter) { + read_entry.sel_counter := true.B + }.otherwise { + read_entry.v_flag := false.B + } + } + }.elsewhen(reg_read_valid && io.resp.ready) { + reg_read_valid := false.B + }.otherwise { + reg_read_valid := reg_read_valid + } + + io.resp.bits.read_tweak := reg_tweak_encrpty + io.resp.bits.read_keyid := reg_keyid + io.resp.bits.read_sel_counter := reg_sel_counter + io.resp.valid := reg_read_valid + +} + + + +// AXI4Util +// Bypass routing: whether a write needs encryption is decided by the key expansion module (io.enc_mode). +// Write requests that need encryption go to io.out1; +// write requests that need no encryption go to io.out0. +class WriteChanelRoute(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val in = new Bundle { + val aw = Flipped(Irrevocable(new AXI4BundleAW(MemcedgeIn.bundle))) + val w = Flipped(Irrevocable(new AXI4BundleW(MemcedgeIn.bundle))) + } + // Unencrypted channel + val out0 = new Bundle { + val aw = Irrevocable(new AXI4BundleAW(MemcedgeIn.bundle)) + val w = Irrevocable(new AXI4BundleW(MemcedgeIn.bundle)) + } + // Encrypted channel + val out1 = new Bundle { + val aw = Irrevocable(new AXI4BundleAW(MemcedgeIn.bundle)) + val w = Irrevocable(new AXI4BundleW(MemcedgeIn.bundle)) + } + val enc_keyid = Output(UInt(KeyIDBits.W)) + val enc_mode = Input(Bool()) + val memenc_enable = Input(Bool()) + }) + io.enc_keyid := io.in.aw.bits.addr(PAddrBits-1, PAddrBits-KeyIDBits) + + val reg_idle = RegInit(true.B) + val reg_enc_mode = RegInit(false.B) + + when(io.in.aw.fire) { + reg_idle := false.B + reg_enc_mode := io.enc_mode && io.memenc_enable + } + when(io.in.w.fire && io.in.w.bits.last) { + reg_idle := true.B + } + + val used_enc_mode = Mux(io.in.aw.fire, io.enc_mode && io.memenc_enable, reg_enc_mode) + + // Cut aw_queue.io.enq.ready from io.out*.awready + val aw_queue = Module(new IrrevocableQueue(chiselTypeOf(io.in.aw.bits), 1, flow = true)) + + io.in.aw.ready := reg_idle && aw_queue.io.enq.ready + aw_queue.io.enq.valid := io.in.aw.valid && reg_idle + aw_queue.io.enq.bits := io.in.aw.bits + + val unencrypt_aw_queue = Module(new IrrevocableQueue(chiselTypeOf(io.in.aw.bits), MemencPipes+1, flow = true)) + val unencrypt_w_queue = Module(new IrrevocableQueue(chiselTypeOf(io.in.w.bits), (MemencPipes+1)*2, flow = true)) + + aw_queue.io.deq.ready := Mux(used_enc_mode, io.out1.aw.ready, unencrypt_aw_queue.io.enq.ready) + io.in.w.ready := (io.in.aw.fire || !reg_idle) && Mux(used_enc_mode, io.out1.w.ready, unencrypt_w_queue.io.enq.ready) + + unencrypt_aw_queue.io.enq.valid := !used_enc_mode && aw_queue.io.deq.valid + unencrypt_w_queue.io.enq.valid := !used_enc_mode && io.in.w.valid &&
(io.in.aw.fire || !reg_idle) + + unencrypt_aw_queue.io.enq.bits := aw_queue.io.deq.bits + unencrypt_w_queue.io.enq.bits := io.in.w.bits + + io.out0.aw.valid := unencrypt_aw_queue.io.deq.valid + io.out0.w.valid := unencrypt_w_queue.io.deq.valid + + io.out0.aw.bits := unencrypt_aw_queue.io.deq.bits + io.out0.w.bits := unencrypt_w_queue.io.deq.bits + + unencrypt_aw_queue.io.deq.ready := io.out0.aw.ready + unencrypt_w_queue.io.deq.ready := io.out0.w.ready + + io.out1.aw.valid := used_enc_mode && aw_queue.io.deq.valid + io.out1.w.valid := used_enc_mode && io.in.w.valid && (io.in.aw.fire || !reg_idle) + + io.out1.aw.bits := aw_queue.io.deq.bits + io.out1.w.bits := io.in.w.bits +} + +class WriteChanelArbiter(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + // Unencrypted channel + val in0 = new Bundle { + val aw = Flipped(Irrevocable(new AXI4BundleAW(MemcedgeOut.bundle))) + val w = Flipped(Irrevocable(new AXI4BundleW(MemcedgeOut.bundle))) + } + // Encrypted channel + val in1 = new Bundle { + val aw = Flipped(Irrevocable(new AXI4BundleAW(MemcedgeOut.bundle))) + val w = Flipped(Irrevocable(new AXI4BundleW(MemcedgeOut.bundle))) + } + val out = new Bundle { + val aw = Irrevocable(new AXI4BundleAW(MemcedgeOut.bundle)) + val w = Irrevocable(new AXI4BundleW(MemcedgeOut.bundle)) + } + }) + + val validMask = RegInit(false.B) // 1: the last write req was sent from the encrypted channel + // 0: the last write req was sent from the unencrypted channel + val aw_choice = Wire(Bool()) // 1: encrypted channel 0: unencrypted channel + val w_choice = RegInit(false.B) // 1: encrypted channel 0: unencrypted channel + val reg_idle = RegInit(true.B) + // Cut aw_queue.io.enq.ready from io.out*.awready + val aw_queue = Module(new IrrevocableQueue(chiselTypeOf(io.in0.aw.bits), 1, flow = true)) + + when(io.in1.aw.fire) { + validMask := true.B + }.elsewhen(io.in0.aw.fire) { + validMask := false.B + }.otherwise { + validMask := validMask + } + + // --------------------------[prefer unencrypted] [prefer encrypted] + aw_choice := Mux(validMask, !io.in0.aw.valid, io.in1.aw.valid) + + when(aw_queue.io.enq.fire) { + reg_idle := false.B + w_choice := aw_choice + } + when(io.out.w.fire && io.out.w.bits.last) { + reg_idle := true.B + } + + val used_w_choice = Mux(aw_queue.io.enq.fire, aw_choice, w_choice) + + io.in0.aw.ready := reg_idle && !aw_choice && aw_queue.io.enq.ready + io.in1.aw.ready := reg_idle && aw_choice && aw_queue.io.enq.ready + aw_queue.io.enq.valid := (io.in0.aw.valid || io.in1.aw.valid) && reg_idle + aw_queue.io.enq.bits := Mux(aw_choice, io.in1.aw.bits, io.in0.aw.bits) + + // DecoupledIO connect IrrevocableIO + io.out.aw.valid := aw_queue.io.deq.valid + io.out.aw.bits := aw_queue.io.deq.bits + aw_queue.io.deq.ready := io.out.aw.ready + + io.in0.w.ready := (aw_queue.io.enq.fire || !reg_idle) && !used_w_choice && io.out.w.ready + io.in1.w.ready := (aw_queue.io.enq.fire || !reg_idle) && used_w_choice && io.out.w.ready + + io.out.w.valid := (aw_queue.io.enq.fire || !reg_idle) && Mux(used_w_choice, io.in1.w.valid, io.in0.w.valid) + io.out.w.bits := Mux(used_w_choice, io.in1.w.bits, io.in0.w.bits) +} + +class RdataChanelRoute(implicit p: Parameters) extends MemEncryptModule +{ + val io = IO(new Bundle { + val in_r = Flipped(Irrevocable(new AXI4BundleR(MemcedgeOut.bundle))) + // Unencrypted channel + val out_r0 = Irrevocable(new AXI4BundleR(MemcedgeOut.bundle)) + // Encrypted channel + val out_r1 = Irrevocable(new AXI4BundleR(MemcedgeOut.bundle)) + val dec_rid = Output(UInt(MemcedgeOut.bundle.idBits.W)) + val dec_mode = Input(Bool()) + }) + io.dec_rid :=
io.in_r.bits.id + + val r_sel = io.dec_mode + + io.out_r0.bits <> io.in_r.bits + io.out_r1.bits <> io.in_r.bits + + io.out_r0.valid := io.in_r.valid && !r_sel + io.out_r1.valid := io.in_r.valid && r_sel + io.in_r.ready := Mux(r_sel, io.out_r1.ready, io.out_r0.ready) +} \ No newline at end of file diff --git a/src/main/scala/system/SoC.scala b/src/main/scala/system/SoC.scala index 16eaad515da..f73a6c00e71 100644 --- a/src/main/scala/system/SoC.scala +++ b/src/main/scala/system/SoC.scala @@ -19,7 +19,7 @@ package system import org.chipsalliance.cde.config.{Field, Parameters} import chisel3._ import chisel3.util._ -import device.{DebugModule, TLPMA, TLPMAIO} +import device.{DebugModule, TLPMA, TLPMAIO, AXI4MemEncrypt} import freechips.rocketchip.amba.axi4._ import freechips.rocketchip.devices.debug.DebugModuleKey import freechips.rocketchip.devices.tilelink._ @@ -38,6 +38,16 @@ import coupledL2.tl2chi.CHIIssue import openLLC.OpenLLCParam case object SoCParamsKey extends Field[SoCParameters] +case object CVMParamskey extends Field[CVMParameters] + +case class CVMParameters +( + MEMENCRange: AddressSet = AddressSet(0x38030000L, 0xfff), + KeyIDBits: Int = 0, + MemencPipes: Int = 4, + HasMEMencryption: Boolean = false, + HasDelayNoencryption: Boolean = false, // Test specific +) case class SoCParameters ( @@ -105,6 +115,7 @@ trait HasSoCParameter { implicit val p: Parameters val soc = p(SoCParamsKey) + val cvm = p(CVMParamskey) val debugOpts = p(DebugOptionsKey) val tiles = p(XSTileKey) val enableCHI = p(EnableCHI) @@ -140,11 +151,16 @@ trait HasSoCParameter { val EnableCHIAsyncBridge = if (enableCHI && soc.EnableCHIAsyncBridge.isDefined) soc.EnableCHIAsyncBridge else None val EnableClintAsyncBridge = soc.EnableClintAsyncBridge + + def HasMEMencryption = cvm.HasMEMencryption + require((cvm.HasMEMencryption && (cvm.KeyIDBits > 0)) || (!cvm.HasMEMencryption && (cvm.KeyIDBits == 0)) , + "HasMEMencryption must be set with KeyIDBits > 0") } trait HasPeripheralRanges { implicit val p: Parameters + private def cvm = p(CVMParamskey) private def soc = p(SoCParamsKey) private def dm = p(DebugModuleKey) private def pmParams = p(PMParameKey) @@ -164,6 +180,11 @@ trait HasPeripheralRanges { Map("L3CTL" -> AddressSet(soc.L3CacheParamsOpt.get.ctrl.get.address, 0xffff)) else Map() + ) ++ ( + if (cvm.HasMEMencryption) + Map("MEMENC" -> cvm.MEMENCRange) + else + Map() ) def peripheralRange = onChipPeripheralRanges.values.foldLeft(Seq(AddressSet(0x0, 0x7fffffffL))) { (acc, x) => @@ -274,15 +295,30 @@ trait HaveAXI4MemPort { TLBuffer.chainNode(2) := mem_xbar } + val axi4memencrpty = Option.when(HasMEMencryption)(LazyModule(new AXI4MemEncrypt(cvm.MEMENCRange))) + if (HasMEMencryption) { + memAXI4SlaveNode := + AXI4Buffer() := + AXI4Buffer() := + AXI4Buffer() := + AXI4IdIndexer(idBits = 14) := + AXI4UserYanker() := + axi4memencrpty.get.node + + axi4memencrpty.get.node := + AXI4Deinterleaver(L3BlockSize) := + axi4mem_node + } else { + memAXI4SlaveNode := + AXI4Buffer() := + AXI4Buffer() := + AXI4Buffer() := + AXI4IdIndexer(idBits = 14) := + AXI4UserYanker() := + AXI4Deinterleaver(L3BlockSize) := + axi4mem_node + } - memAXI4SlaveNode := - AXI4Buffer() := - AXI4Buffer() := - AXI4Buffer() := - AXI4IdIndexer(idBits = 14) := - AXI4UserYanker() := - AXI4Deinterleaver(L3BlockSize) := - axi4mem_node val memory = InModuleBody { memAXI4SlaveNode.makeIOs() @@ -446,8 +482,14 @@ class MemMisc()(implicit p: Parameters) extends BaseSoC val pma = LazyModule(new TLPMA) if (enableCHI) { pma.node := TLBuffer.chainNode(4) := device_xbar.get
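// With HasMEMencryption set, the AXI4MemEncrypt engine also exposes an APB control-register
// port (mapped at cvm.MEMENCRange in the peripheral map above), bridged off the same xbar: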
+ if (HasMEMencryption) { + axi4memencrpty.get.ctrl_node := TLToAPB() := device_xbar.get + } } else { pma.node := TLBuffer.chainNode(4) := peripheralXbar.get + if (HasMEMencryption) { + axi4memencrpty.get.ctrl_node := TLToAPB() := peripheralXbar.get + } } class SoCMiscImp(wrapper: LazyModule) extends LazyModuleImp(wrapper) { @@ -472,6 +514,11 @@ class MemMisc()(implicit p: Parameters) extends BaseSoC pma.module.io <> cacheable_check + if (HasMEMencryption) { + val cnt = Counter(true.B, 8)._1 + axi4memencrpty.get.module.io.random_val := axi4memencrpty.get.module.io.random_req && cnt(2).asBool + axi4memencrpty.get.module.io.random_data := cnt(0).asBool + } // positive edge sampling of the lower-speed rtc_clock val rtcTick = RegInit(0.U(3.W)) rtcTick := Cat(rtcTick(1, 0), rtc_clock) diff --git a/src/main/scala/top/Configs.scala b/src/main/scala/top/Configs.scala index e863f36202d..69661260c5e 100644 --- a/src/main/scala/top/Configs.scala +++ b/src/main/scala/top/Configs.scala @@ -41,6 +41,7 @@ class BaseConfig(n: Int) extends Config((site, here, up) => { case XLen => 64 case DebugOptionsKey => DebugOptions() case SoCParamsKey => SoCParameters() + case CVMParamskey => CVMParameters() case PMParameKey => PMParameters() case XSTileKey => Seq.tabulate(n){ i => XSCoreParameters(HartId = i) } case ExportDebug => DebugAttachParams(protocols = Set(JTAG)) @@ -411,6 +412,28 @@ class WithFuzzer extends Config((site, here, up) => { } }) +class CVMCompile extends Config((site, here, up) => { + case CVMParamskey => up(CVMParamskey).copy( + KeyIDBits = 5, + HasMEMencryption = true, + HasDelayNoencryption = false + ) + case XSTileKey => up(XSTileKey).map(_.copy( + HasBitmapCheck = true, + HasBitmapCheckDefault = false)) +}) + +class CVMTestCompile extends Config((site, here, up) => { + case CVMParamskey => up(CVMParamskey).copy( + KeyIDBits = 5, + HasMEMencryption = true, + HasDelayNoencryption = true + ) + case XSTileKey => up(XSTileKey).map(_.copy( + HasBitmapCheck =true, + HasBitmapCheckDefault = true)) +}) + class MinimalAliasDebugConfig(n: Int = 1) extends Config( L3CacheConfig("512KB", inclusive = false) ++ L2CacheConfig("256KB", inclusive = true) @@ -437,6 +460,16 @@ class DefaultConfig(n: Int = 1) extends Config( ++ new BaseConfig(n) ) +class CVMConfig(n: Int = 1) extends Config( + new CVMCompile + ++ new DefaultConfig(n) +) + +class CVMTestConfig(n: Int = 1) extends Config( + new CVMTestCompile + ++ new DefaultConfig(n) +) + class WithCHI extends Config((_, _, _) => { case EnableCHI => true }) diff --git a/src/main/scala/xiangshan/Bundle.scala b/src/main/scala/xiangshan/Bundle.scala index 617513398fd..e001ec9a3f1 100644 --- a/src/main/scala/xiangshan/Bundle.scala +++ b/src/main/scala/xiangshan/Bundle.scala @@ -514,10 +514,30 @@ class TlbHgatpBundle(implicit p: Parameters) extends HgatpStruct { } } +// add mbmc csr +class MbmcStruct(implicit p: Parameters) extends XSBundle { + val BME = UInt(1.W) + val CMODE = UInt(1.W) + val BCLEAR = UInt(1.W) + val BMA = UInt(58.W) +} + +class TlbMbmcBundle(implicit p: Parameters) extends MbmcStruct { + def apply(mbmc_value: UInt): Unit = { + require(mbmc_value.getWidth == XLEN) + val mc = mbmc_value.asTypeOf(new MbmcStruct) + BME := mc.BME + CMODE := mc.CMODE + BCLEAR := mc.BCLEAR + BMA := mc.BMA + } +} + class TlbCsrBundle(implicit p: Parameters) extends XSBundle { val satp = new TlbSatpBundle() val vsatp = new TlbSatpBundle() val hgatp = new TlbHgatpBundle() + val mbmc = new TlbMbmcBundle() val priv = new Bundle { val mxr = Bool() val sum = Bool() diff --git 
a/src/main/scala/xiangshan/PMParameters.scala b/src/main/scala/xiangshan/PMParameters.scala index 8f8eed4c111..3f4ec961b04 100644 --- a/src/main/scala/xiangshan/PMParameters.scala +++ b/src/main/scala/xiangshan/PMParameters.scala @@ -20,6 +20,7 @@ import chisel3.util.log2Ceil import org.chipsalliance.cde.config.{Field, Parameters} import freechips.rocketchip.tile.XLen import system.SoCParamsKey +import system.CVMParamskey import xiangshan.backend.fu.{MMPMAConfig, MMPMAMethod} case object PMParameKey extends Field[PMParameters] @@ -45,6 +46,7 @@ trait HasPMParameters { def PMPAddrBits = p(SoCParamsKey).PAddrBits def PMPPmemRanges = p(SoCParamsKey).PmemRanges def PMAConfigs = p(SoCParamsKey).PMAConfigs + val PMPKeyIDBits = p(CVMParamskey).KeyIDBits def PMXLEN = p(XLen) def pmParams = p(PMParameKey) def NumPMP = pmParams.NumPMP diff --git a/src/main/scala/xiangshan/Parameters.scala b/src/main/scala/xiangshan/Parameters.scala index 665a0a1c453..4768a093381 100644 --- a/src/main/scala/xiangshan/Parameters.scala +++ b/src/main/scala/xiangshan/Parameters.scala @@ -21,6 +21,7 @@ import chisel3._ import chisel3.util._ import huancun._ import system.SoCParamsKey +import system.CVMParamskey import xiangshan.backend.datapath.RdConfig._ import xiangshan.backend.datapath.WbConfig._ import xiangshan.backend.exu.ExeUnitParams @@ -60,6 +61,8 @@ case class XSCoreParameters VLEN: Int = 128, ELEN: Int = 64, HSXLEN: Int = 64, + HasBitmapCheck: Boolean = false, + HasBitmapCheckDefault: Boolean = false, HasMExtension: Boolean = true, HasCExtension: Boolean = true, HasHExtension: Boolean = true, @@ -569,6 +572,7 @@ trait HasXSParameter { def PAddrBits = p(SoCParamsKey).PAddrBits // PAddrBits is Phyical Memory addr bits def PmemRanges = p(SoCParamsKey).PmemRanges + def KeyIDBits = p(CVMParamskey).KeyIDBits final val PageOffsetWidth = 12 def NodeIDWidth = p(SoCParamsKey).NodeIDWidthList(p(CHIIssue)) // NodeID width among NoC @@ -586,6 +590,8 @@ trait HasXSParameter { def hartIdLen = p(MaxHartIdBits) val xLen = XLEN + def HasBitmapCheck = coreParams.HasBitmapCheck + def HasBitmapCheckDefault = coreParams.HasBitmapCheckDefault def HasMExtension = coreParams.HasMExtension def HasCExtension = coreParams.HasCExtension def HasHExtension = coreParams.HasHExtension diff --git a/src/main/scala/xiangshan/backend/fu/NewCSR/CSRDefines.scala b/src/main/scala/xiangshan/backend/fu/NewCSR/CSRDefines.scala index 265ae174c51..9b3b294a8ba 100644 --- a/src/main/scala/xiangshan/backend/fu/NewCSR/CSRDefines.scala +++ b/src/main/scala/xiangshan/backend/fu/NewCSR/CSRDefines.scala @@ -144,6 +144,11 @@ object CSRDefines { val Dirty = Value(3.U) } + object BMAField extends CSREnum with WARLApply { + val ResetBMA = Value(0.U) + val TestBMA = Value("h4000000".U) + } + object XLENField extends CSREnum with ROApply { val XLEN32 = Value(1.U) val XLEN64 = Value(2.U) diff --git a/src/main/scala/xiangshan/backend/fu/NewCSR/CSREvents/CSREvent.scala b/src/main/scala/xiangshan/backend/fu/NewCSR/CSREvents/CSREvent.scala index b4fff31d684..50534d47fb1 100644 --- a/src/main/scala/xiangshan/backend/fu/NewCSR/CSREvents/CSREvent.scala +++ b/src/main/scala/xiangshan/backend/fu/NewCSR/CSREvents/CSREvent.scala @@ -148,6 +148,7 @@ class TrapEntryEventInput(implicit val p: Parameters) extends Bundle with HasXSP val satp = Input(new SatpBundle) val vsatp = Input(new SatpBundle) val hgatp = Input(new HgatpBundle) + val mbmc = Input(new MbmcBundle) // from mem val memExceptionVAddr = Input(UInt(XLEN.W)) val memExceptionGPAddr = Input(UInt(XLEN.W)) diff --git 
a/src/main/scala/xiangshan/backend/fu/NewCSR/MachineLevel.scala b/src/main/scala/xiangshan/backend/fu/NewCSR/MachineLevel.scala index f664a6447eb..80d76bb680c 100644 --- a/src/main/scala/xiangshan/backend/fu/NewCSR/MachineLevel.scala +++ b/src/main/scala/xiangshan/backend/fu/NewCSR/MachineLevel.scala @@ -15,10 +15,27 @@ import xiangshan.backend.fu.NewCSR.ChiselRecordForField._ import xiangshan.backend.fu.PerfCounterIO import xiangshan.backend.fu.NewCSR.CSRConfig._ import xiangshan.backend.fu.NewCSR.CSRFunc._ +import xiangshan.backend.fu.util.CSRConst._ import scala.collection.immutable.SeqMap trait MachineLevel { self: NewCSR => + // Machine level Custom Read/Write + val mbmc = if (HasBitmapCheck) Some(Module(new CSRModule("Mbmc", new MbmcBundle) { + val mbmc_lock = reg.BME.asBool + if (!HasBitmapCheckDefault) { + reg.BME := Mux(wen && !mbmc_lock, wdata.BME, reg.BME) + reg.CMODE := Mux(wen, wdata.CMODE, reg.CMODE) + reg.BMA := Mux(wen && !mbmc_lock, wdata.BMA, reg.BMA) + } else { + reg.BME := 1.U + reg.CMODE := 0.U + reg.BMA := BMAField.TestBMA + } + reg.BCLEAR := Mux(reg.BCLEAR.asBool, 0.U, Mux(wen && wdata.BCLEAR.asBool, 1.U, 0.U)) + }) + .setAddr(Mbmc)) else None + val mstatus = Module(new MstatusModule) .setAddr(CSRs.mstatus) @@ -417,7 +434,8 @@ trait MachineLevel { self: NewCSR => mncause, mnstatus, mnscratch, - ) ++ mhpmevents ++ mhpmcounters + ) ++ mhpmevents ++ mhpmcounters ++ (if (HasBitmapCheck) Seq(mbmc.get) else Seq()) + val machineLevelCSRMap: SeqMap[Int, (CSRAddrWriteBundle[_], UInt)] = SeqMap.from( machineLevelCSRMods.map(csr => (csr.addr -> (csr.w -> csr.rdata))).iterator @@ -429,6 +447,13 @@ trait MachineLevel { self: NewCSR => } +class MbmcBundle extends CSRBundle { + val BMA = BMAField(63,6,null).withReset(BMAField.ResetBMA) + val BME = RW(2).withReset(0.U) + val BCLEAR = RW(1).withReset(0.U) + val CMODE = RW(0).withReset(0.U) +} + class MstatusBundle extends CSRBundle { val SIE = CSRRWField (1).withReset(0.U) diff --git a/src/main/scala/xiangshan/backend/fu/NewCSR/NewCSR.scala b/src/main/scala/xiangshan/backend/fu/NewCSR/NewCSR.scala index 55637ab8c93..8be2c37c767 100644 --- a/src/main/scala/xiangshan/backend/fu/NewCSR/NewCSR.scala +++ b/src/main/scala/xiangshan/backend/fu/NewCSR/NewCSR.scala @@ -204,6 +204,7 @@ class NewCSR(implicit val p: Parameters) extends Module val satp = new SatpBundle val vsatp = new SatpBundle val hgatp = new HgatpBundle + val mbmc = new MbmcBundle val mxr = Bool() val sum = Bool() val vmxr = Bool() @@ -795,6 +796,11 @@ class NewCSR(implicit val p: Parameters) extends Module in.satp := satp.regOut in.vsatp := vsatp.regOut in.hgatp := hgatp.regOut + if (HasBitmapCheck) { + in.mbmc := mbmc.get.regOut + } else { + in.mbmc := DontCare + } in.memExceptionVAddr := io.fromMem.excpVA in.memExceptionGPAddr := io.fromMem.excpGPA @@ -894,8 +900,13 @@ class NewCSR(implicit val p: Parameters) extends Module (addr >= CSRs.cycle.U) && (addr <= CSRs.hpmcounter31.U) ) + val resetSatp = WireInit(false.B) // flush - val resetSatp = Cat(Seq(satp, vsatp, hgatp).map(_.addr.U === addr)).orR && wenLegalReg // write to satp will cause the pipeline be flushed + if (HasBitmapCheck) { + resetSatp := Cat(Seq(satp, vsatp, hgatp, mbmc.get).map(_.addr.U === addr)).orR && wenLegalReg // write to satp will cause the pipeline be flushed + } else { + resetSatp := Cat(Seq(satp, vsatp, hgatp).map(_.addr.U === addr)).orR && wenLegalReg // write to satp will cause the pipeline be flushed + } val floatStatusOnOff = mstatus.w.wen && ( mstatus.w.wdataFields.FS === ContextStatus.Off && 
mstatus.regOut.FS =/= ContextStatus.Off || @@ -1354,6 +1365,11 @@ class NewCSR(implicit val p: Parameters) extends Module io.tlb.satp := satp.rdata io.tlb.vsatp := vsatp.rdata io.tlb.hgatp := hgatp.rdata + if (HasBitmapCheck) { + io.tlb.mbmc := mbmc.get.rdata + } else { + io.tlb.mbmc := DontCare + } io.tlb.mxr := mstatus.regOut.MXR.asBool io.tlb.sum := mstatus.regOut.SUM.asBool io.tlb.vmxr := vsstatus.regOut.MXR.asBool diff --git a/src/main/scala/xiangshan/backend/fu/PMP.scala b/src/main/scala/xiangshan/backend/fu/PMP.scala index 0a8991af5d1..7ee98f3b041 100644 --- a/src/main/scala/xiangshan/backend/fu/PMP.scala +++ b/src/main/scala/xiangshan/backend/fu/PMP.scala @@ -459,11 +459,20 @@ trait PMPCheckMethod extends PMPConst { } class PMPCheckerEnv(implicit p: Parameters) extends PMPBundle { + val cmode = Bool() val mode = UInt(2.W) val pmp = Vec(NumPMP, new PMPEntry()) val pma = Vec(NumPMA, new PMPEntry()) + def apply(cmode: Bool, mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry]): Unit = { + this.cmode := cmode + this.mode := mode + this.pmp := pmp + this.pma := pma + } + def apply(mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry]): Unit = { + this.cmode := true.B this.mode := mode this.pmp := pmp this.pma := pma } @@ -475,6 +484,12 @@ class PMPCheckIO(lgMaxSize: Int)(implicit p: Parameters) extends PMPBundle { val req = Flipped(Valid(new PMPReqBundle(lgMaxSize))) // usage: assign the valid to fire signal val resp = new PMPRespBundle() + def apply(cmode: Bool, mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], req: Valid[PMPReqBundle]) = { + check_env.apply(cmode, mode, pmp, pma) + this.req := req + resp + } + def apply(mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], req: Valid[PMPReqBundle]) = { check_env.apply(mode, pmp, pma) this.req := req @@ -498,6 +513,12 @@ class PMPCheckv2IO(lgMaxSize: Int)(implicit p: Parameters) extends PMPBundle { val req = Flipped(Valid(new PMPReqBundle(lgMaxSize))) // usage: assign the valid to fire signal val resp = Output(new PMPConfig()) + def apply(cmode: Bool, mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], valid: Bool, addr: UInt) = { + check_env.apply(cmode, mode, pmp, pma) + req_apply(valid, addr) + resp + } + def apply(mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], req: Valid[PMPReqBundle]) = { check_env.apply(mode, pmp, pma) this.req := req @@ -531,12 +552,44 @@ class PMPChecker val req = io.req.bits - val res_pmp = pmp_match_res(leaveHitMux, io.req.valid)(req.addr, req.size, io.check_env.pmp, io.check_env.mode, lgMaxSize) - val res_pma = pma_match_res(leaveHitMux, io.req.valid)(req.addr, req.size, io.check_env.pma, io.check_env.mode, lgMaxSize) + /* KeyIDBits is used for memory encryption and occupies the address MSBs, + * so (PMPKeyIDBits > 0) is usually set together with HasMEMencryption = true. + * + * Example: + * PAddrBits=48 & PMPKeyIDBits=5 + * [47,46,45,44,43, 42,41,.......,1,0] + * {----KeyID----} {----RealPAddr----} + * + * A nonzero KeyID is bound to an Enclave/CVM (cmode=1) to select a distinct memory encryption key, + * while the OS/VMM/APP/VM (cmode=0) may only use KeyID zero. + * + * So only the RealPAddr needs the PMP&PMA check.
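+ * For example, with PAddrBits=48 and PMPKeyIDBits=5, an access issued with cmode=0 whose
+ * addr(47,43) is nonzero faults in keyid_check below (except in M mode), while the PMP/PMA
+ * comparators only ever see addr(42,0).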
+ */ + + val res_pmp = pmp_match_res(leaveHitMux, io.req.valid)(req.addr(PMPAddrBits-PMPKeyIDBits-1, 0), req.size, io.check_env.pmp, io.check_env.mode, lgMaxSize) + val res_pma = pma_match_res(leaveHitMux, io.req.valid)(req.addr(PMPAddrBits-PMPKeyIDBits-1, 0), req.size, io.check_env.pma, io.check_env.mode, lgMaxSize) val resp_pmp = pmp_check(req.cmd, res_pmp.cfg) val resp_pma = pma_check(req.cmd, res_pma.cfg) - val resp = if (pmpUsed) (resp_pmp | resp_pma) else resp_pma + + def keyid_check(leaveHitMux: Boolean = false, valid: Bool = true.B, addr: UInt) = { + val resp = Wire(new PMPRespBundle) + val keyid_nz = if (PMPKeyIDBits > 0) addr(PMPAddrBits-1, PMPAddrBits-PMPKeyIDBits) =/= 0.U else false.B + resp.ld := keyid_nz && !io.check_env.cmode && (io.check_env.mode < 3.U) + resp.st := keyid_nz && !io.check_env.cmode && (io.check_env.mode < 3.U) + resp.instr := keyid_nz && !io.check_env.cmode && (io.check_env.mode < 3.U) + resp.mmio := false.B + resp.atomic := false.B + if (leaveHitMux) { + RegEnable(resp, valid) + } else { + resp + } + } + + val resp_keyid = keyid_check(leaveHitMux, io.req.valid, req.addr) + + val resp = if (pmpUsed) (resp_pmp | resp_pma | resp_keyid) else (resp_pma | resp_keyid) if (sameCycle || leaveHitMux) { io.resp := resp diff --git a/src/main/scala/xiangshan/backend/fu/util/CSRConst.scala b/src/main/scala/xiangshan/backend/fu/util/CSRConst.scala index 4e68c8ed5de..bb758858e4b 100644 --- a/src/main/scala/xiangshan/backend/fu/util/CSRConst.scala +++ b/src/main/scala/xiangshan/backend/fu/util/CSRConst.scala @@ -34,6 +34,9 @@ trait HasCSRConst { val PmacfgBase = 0x7C0 val PmaaddrBase = 0x7C8 // 64 entry at most + // Machine level Bitmap Check(Custom Read/Write) + val Mbmc = 0xBC2 + def privEcall = 0x000.U def privEbreak = 0x001.U def privMNret = 0x702.U diff --git a/src/main/scala/xiangshan/backend/fu/wrapper/CSR.scala b/src/main/scala/xiangshan/backend/fu/wrapper/CSR.scala index 1181927a12b..f92782acee6 100644 --- a/src/main/scala/xiangshan/backend/fu/wrapper/CSR.scala +++ b/src/main/scala/xiangshan/backend/fu/wrapper/CSR.scala @@ -254,6 +254,10 @@ class CSR(cfg: FuConfig)(implicit p: Parameters) extends FuncUnit(cfg) tlb.hgatp.mode := csrMod.io.tlb.hgatp.MODE.asUInt tlb.hgatp.vmid := csrMod.io.tlb.hgatp.VMID.asUInt tlb.hgatp.ppn := csrMod.io.tlb.hgatp.PPN.asUInt + tlb.mbmc.BME := csrMod.io.tlb.mbmc.BME.asUInt + tlb.mbmc.CMODE := csrMod.io.tlb.mbmc.CMODE.asUInt + tlb.mbmc.BCLEAR := csrMod.io.tlb.mbmc.BCLEAR.asUInt + tlb.mbmc.BMA := csrMod.io.tlb.mbmc.BMA.asUInt // expose several csr bits for tlb tlb.priv.mxr := csrMod.io.tlb.mxr diff --git a/src/main/scala/xiangshan/cache/mmu/BitmapCheck.scala b/src/main/scala/xiangshan/cache/mmu/BitmapCheck.scala new file mode 100644 index 00000000000..7efa1296d2a --- /dev/null +++ b/src/main/scala/xiangshan/cache/mmu/BitmapCheck.scala @@ -0,0 +1,435 @@ +/*************************************************************************************** +* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences +* +* XiangShan is licensed under Mulan PSL v2. +* You can use this software according to the terms and conditions of the Mulan PSL v2. +* You may obtain a copy of Mulan PSL v2 at: +* http://license.coscl.org.cn/MulanPSL2 +* +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +* +* See the Mulan PSL v2 for more details. 
+***************************************************************************************/ + +package xiangshan.cache.mmu + +import org.chipsalliance.cde.config.Parameters +import chisel3._ +import chisel3.util._ +import xiangshan._ +import xiangshan.cache.{HasDCacheParameters, MemoryOpConstants} +import utils._ +import utility._ +import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp} +import freechips.rocketchip.tilelink._ +import xiangshan.backend.fu.{PMPReqBundle, PMPRespBundle} + +class bitmapReqBundle(implicit p: Parameters) extends XSBundle with HasPtwConst { + val bmppn = UInt(ppnLen.W) + val id = UInt(log2Up(l2tlbParams.llptwsize+2).W) + val vpn = UInt(vpnLen.W) + val level = UInt(log2Up(Level).W) + val way_info = UInt(l2tlbParams.l0nWays.W) + val hptw_bypassed = Bool() +} + +class bitmapRespBundle(implicit p: Parameters) extends XSBundle with HasPtwConst { + val cf = Bool() + val cfs = Vec(tlbcontiguous,Bool()) + val id = UInt(log2Up(l2tlbParams.llptwsize+2).W) +} + +class bitmapEntry(implicit p: Parameters) extends XSBundle with HasPtwConst { + val ppn = UInt(ppnLen.W) + val vpn = UInt(vpnLen.W) + val id = UInt(bMemID.W) + val wait_id = UInt(log2Up(l2tlbParams.llptwsize+2).W) + // bitmap check failed? 0: success, 1: failed + val cf = Bool() + val hit = Bool() + val cfs = Vec(tlbcontiguous,Bool()) + val level = UInt(log2Up(Level).W) + val way_info = UInt(l2tlbParams.l0nWays.W) + val hptw_bypassed = Bool() + val data = UInt(XLEN.W) +} + +class bitmapIO(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst { + val mem = new Bundle { + val req = DecoupledIO(new L2TlbMemReqBundle()) + val resp = Flipped(DecoupledIO(new Bundle { + val id = Output(UInt(bMemID.W)) + val value = Output(UInt(blockBits.W)) + })) + val req_mask = Input(Vec(l2tlbParams.llptwsize+2, Bool())) + } + val req = Flipped(DecoupledIO(new bitmapReqBundle())) + val resp = DecoupledIO(new bitmapRespBundle()) + + val pmp = new Bundle { + val req = ValidIO(new PMPReqBundle()) + val resp = Flipped(new PMPRespBundle()) + } + + val wakeup = ValidIO(new Bundle { + val setIndex = UInt(PtwL0SetIdxLen.W) + val tag = UInt(SPTagLen.W) + val isSp = Bool() + val way_info = UInt(l2tlbParams.l0nWays.W) + val pte_index = UInt(sectortlbwidth.W) + val check_success = Bool() + }) + + // bitmap cache req/resp and refill port + val cache = new Bundle { + val req = DecoupledIO(new bitmapCacheReqBundle()) + val resp = Flipped(DecoupledIO(new bitmapCacheRespBundle())) + } + val refill = Output(ValidIO(new Bundle { + val tag = UInt(ppnLen.W) + val data = UInt(XLEN.W) + })) +} + +class Bitmap(implicit p: Parameters) extends XSModule with HasPtwConst { + def getBitmapAddr(ppn: UInt): UInt = { + val effective_ppn = ppn(ppnLen-KeyIDBits-1, 0) + bitmap_base + (effective_ppn >> log2Ceil(XLEN) << log2Ceil(8)) + } + + val io = IO(new bitmapIO) + + val csr = io.csr + val sfence = io.sfence + val flush = sfence.valid || csr.satp.changed || csr.vsatp.changed || csr.hgatp.changed + val bitmap_base = csr.mbmc.BMA << 6 + + val entries = Reg(Vec(l2tlbParams.llptwsize+2, new bitmapEntry())) + // add pmp check + val state_idle :: state_addr_check :: state_cache_req :: state_cache_resp ::state_mem_req :: state_mem_waiting :: state_mem_out :: Nil = Enum(7) + val state = RegInit(VecInit(Seq.fill(l2tlbParams.llptwsize+2)(state_idle))) + + val is_emptys = state.map(_ === state_idle) + val is_cache_req = state.map (_ === state_cache_req) + val is_cache_resp = state.map (_ === state_cache_resp) + val is_mems = state.map(_ === state_mem_req) + val
is_waiting = state.map(_ === state_mem_waiting) + val is_having = state.map(_ === state_mem_out) + + val full = !ParallelOR(is_emptys).asBool + val waiting = ParallelOR(is_waiting).asBool + val enq_ptr = ParallelPriorityEncoder(is_emptys) + + val mem_ptr = ParallelPriorityEncoder(is_having) + val mem_arb = Module(new RRArbiter(new bitmapEntry(), l2tlbParams.llptwsize+2)) + + val bitmapdata = Wire(Vec(blockBits / XLEN, UInt(XLEN.W))) + if (HasBitmapCheckDefault) { + for (i <- 0 until blockBits / XLEN) { + bitmapdata(i) := 0.U + } + } else { + bitmapdata := io.mem.resp.bits.value.asTypeOf(Vec(blockBits / XLEN, UInt(XLEN.W))) + } + + for (i <- 0 until l2tlbParams.llptwsize+2) { + mem_arb.io.in(i).bits := entries(i) + mem_arb.io.in(i).valid := is_mems(i) && !io.mem.req_mask(i) + } + + val cache_req_arb = Module(new Arbiter(new bitmapCacheReqBundle(),l2tlbParams.llptwsize + 2)) + for (i <- 0 until l2tlbParams.llptwsize+2) { + cache_req_arb.io.in(i).valid := is_cache_req(i) + cache_req_arb.io.in(i).bits.tag := entries(i).ppn + cache_req_arb.io.in(i).bits.order := i.U; + } + + val dup_vec = state.indices.map(i => + dupBitmapPPN(io.req.bits.bmppn, entries(i).ppn) + ) + val dup_req_fire = mem_arb.io.out.fire && dupBitmapPPN(io.req.bits.bmppn, mem_arb.io.out.bits.ppn) + val dup_vec_wait = dup_vec.zip(is_waiting).map{case (d, w) => d && w} + val dup_wait_resp = io.mem.resp.fire && VecInit(dup_vec_wait)(io.mem.resp.bits.id - (l2tlbParams.llptwsize + 2).U) + val wait_id = Mux(dup_req_fire, mem_arb.io.chosen, ParallelMux(dup_vec_wait zip entries.map(_.wait_id))) + + val to_wait = Cat(dup_vec_wait).orR || dup_req_fire + val to_mem_out = dup_wait_resp + + val enq_state_normal = MuxCase(state_addr_check, Seq( + to_mem_out -> state_mem_out, + to_wait -> state_mem_waiting + )) + val enq_state = enq_state_normal + val enq_ptr_reg = RegNext(enq_ptr) + + val need_addr_check = RegNext(enq_state === state_addr_check && io.req.fire && !flush) + + io.pmp.req.valid := need_addr_check + io.pmp.req.bits.addr := RegEnable(getBitmapAddr(io.req.bits.bmppn),io.req.fire) + io.pmp.req.bits.cmd := TlbCmd.read + io.pmp.req.bits.size := 3.U + val pmp_resp_valid = io.pmp.req.valid + + when (io.req.fire) { + state(enq_ptr) := enq_state + entries(enq_ptr).ppn := io.req.bits.bmppn + entries(enq_ptr).vpn := io.req.bits.vpn + entries(enq_ptr).id := io.req.bits.id + entries(enq_ptr).wait_id := Mux(to_wait, wait_id, enq_ptr) + entries(enq_ptr).cf := false.B + for (i <- 0 until tlbcontiguous) { + entries(enq_ptr).cfs(i) := false.B + } + entries(enq_ptr).hit := to_wait + entries(enq_ptr).level := io.req.bits.level + entries(enq_ptr).way_info := io.req.bits.way_info + entries(enq_ptr).hptw_bypassed := io.req.bits.hptw_bypassed + } + + // when the pmp check fails, the cf (check-failed) bit records the fault + when (pmp_resp_valid) { + val ptr = enq_ptr_reg + val accessFault = io.pmp.resp.ld || io.pmp.resp.mmio + entries(ptr).cf := accessFault + for (i <- 0 until tlbcontiguous) { + entries(ptr).cfs(i) := accessFault + } + // first, query the bitmap cache + state(ptr) := Mux(accessFault, state_mem_out, state_cache_req) + } + + val cache_wait = ParallelOR(is_cache_resp).asBool + io.cache.resp.ready := !flush && cache_wait + + val hit = WireInit(false.B) + io.cache.req.valid := cache_req_arb.io.out.valid && !flush + io.cache.req.bits.tag := cache_req_arb.io.out.bits.tag + io.cache.req.bits.order := cache_req_arb.io.out.bits.order + cache_req_arb.io.out.ready := io.cache.req.ready + + + when (cache_req_arb.io.out.fire) { + for (i <- state.indices) { + when (state(i) ===
state_cache_req && cache_req_arb.io.chosen === i.U) { + state(i) := state_cache_resp + } + } + } + + when (io.cache.resp.fire) { + for (i <- state.indices) { + val cm_dup_vec = state.indices.map(j => + dupBitmapPPN(entries(i).ppn, entries(j).ppn) + ) + val cm_dup_req_fire = mem_arb.io.out.fire && dupBitmapPPN(entries(i).ppn, mem_arb.io.out.bits.ppn) + val cm_dup_vec_wait = cm_dup_vec.zip(is_waiting).map{case (d, w) => d && w} + val cm_dup_wait_resp = io.mem.resp.fire && VecInit(cm_dup_vec_wait)(io.mem.resp.bits.id - (l2tlbParams.llptwsize + 2).U) + val cm_wait_id = Mux(cm_dup_req_fire, mem_arb.io.chosen, ParallelMux(cm_dup_vec_wait zip entries.map(_.wait_id))) + val cm_to_wait = Cat(cm_dup_vec_wait).orR || cm_dup_req_fire + val cm_to_mem_out = cm_dup_wait_resp + val cm_next_state_normal = MuxCase(state_mem_req, Seq( + cm_to_mem_out -> state_mem_out, + cm_to_wait -> state_mem_waiting + )) + when (state(i) === state_cache_resp && io.cache.resp.bits.order === i.U) { + hit := io.cache.resp.bits.hit + when (hit) { + entries(i).cf := io.cache.resp.bits.cfs(entries(i).ppn(5,0)) + entries(i).hit := true.B + entries(i).cfs := io.cache.resp.bits.cfs + state(i) := state_mem_out + } .otherwise { + state(i) := cm_next_state_normal + entries(i).wait_id := Mux(cm_to_wait, cm_wait_id, entries(i).wait_id) + entries(i).hit := cm_to_wait + } + } + } + } + + when (mem_arb.io.out.fire) { + for (i <- state.indices) { + when (state(i) === state_mem_req && dupBitmapPPN(entries(i).ppn, mem_arb.io.out.bits.ppn)) { + state(i) := state_mem_waiting + entries(i).wait_id := mem_arb.io.chosen + } + } + } + + when (io.mem.resp.fire) { + state.indices.map{i => + when (state(i) === state_mem_waiting && io.mem.resp.bits.id === entries(i).wait_id + (l2tlbParams.llptwsize + 2).U) { + state(i) := state_mem_out + val index = getBitmapAddr(entries(i).ppn)(log2Up(l2tlbParams.blockBytes)-1, log2Up(XLEN/8)) + entries(i).data := bitmapdata(index) + entries(i).cf := bitmapdata(index)(entries(i).ppn(5,0)) + val ppnPart = entries(i).ppn(5,3) + val start = (ppnPart << 3.U) + val end = start + 7.U + val mask = (1.U << 8) - 1.U + val selectedBits = (bitmapdata(index) >> start) & mask + for (j <- 0 until tlbcontiguous) { + entries(i).cfs(j) := selectedBits(j) + } + } + } + } + + when (io.resp.fire) { + state(mem_ptr) := state_idle + } + + when (flush) { + state.map(_ := state_idle) + } + + io.req.ready := !full + + io.resp.valid := ParallelOR(is_having).asBool + // on a bitmap cache hit, this is the cached result + io.resp.bits.cf := entries(mem_ptr).cf + io.resp.bits.cfs := entries(mem_ptr).cfs + io.resp.bits.id := entries(mem_ptr).id + + io.mem.req.valid := mem_arb.io.out.valid && !flush + io.mem.req.bits.addr := getBitmapAddr(mem_arb.io.out.bits.ppn) + io.mem.req.bits.id := mem_arb.io.chosen + (l2tlbParams.llptwsize + 2).U + mem_arb.io.out.ready := io.mem.req.ready + + io.mem.resp.ready := waiting + + io.mem.req.bits.hptw_bypassed := false.B + + io.wakeup.valid := io.resp.valid && !entries(mem_ptr).hptw_bypassed + io.wakeup.bits.setIndex := genPtwL0SetIdx(entries(mem_ptr).vpn) + io.wakeup.bits.tag := entries(mem_ptr).vpn(vpnLen - 1, vpnLen - SPTagLen) + io.wakeup.bits.isSp := entries(mem_ptr).level =/= 0.U + io.wakeup.bits.way_info := entries(mem_ptr).way_info + io.wakeup.bits.pte_index := entries(mem_ptr).vpn(sectortlbwidth - 1, 0) + io.wakeup.bits.check_success := !entries(mem_ptr).cf + + // on a miss, refill the fetched data into the bitmap cache + io.refill.valid := io.resp.valid && !entries(mem_ptr).hit + io.refill.bits.tag := entries(mem_ptr).ppn +
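// (each refilled XLEN-bit word covers 64 pages, one check bit per ppn, so later checks
// that fall into the same 64-page group can hit in the bitmap cache without a memory access)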
io.refill.bits.data := entries(mem_ptr).data + + XSPerfAccumulate("bitmap_req", io.req.fire) + XSPerfAccumulate("bitmap_mem_req", io.mem.req.fire) +} + +// add bitmap cache +class bitmapCacheReqBundle(implicit p: Parameters) extends PtwBundle{ + val order = UInt((l2tlbParams.llptwsize + 2).W) + val tag = UInt(ppnLen.W) +} +class bitmapCacheRespBundle(implicit p: Parameters) extends PtwBundle{ + val hit = Bool() + val cfs = Vec(tlbcontiguous,Bool()) + val order = UInt((l2tlbParams.llptwsize + 2).W) + def apply(hit : Bool, cfs : Vec[Bool], order : UInt) = { + this.hit := hit + this.cfs := cfs + this.order := order + } +} +class bitmapCacheIO(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst { + val req = Flipped(DecoupledIO(new bitmapCacheReqBundle())) + val resp = DecoupledIO(new bitmapCacheRespBundle()) + val refill = Flipped(ValidIO(new Bundle { + val tag = UInt(ppnLen.W) + val data = UInt(XLEN.W) + })) +} +class bitmapCacheEntry(implicit p: Parameters) extends PtwBundle{ + val tag = UInt((ppnLen-log2Ceil(XLEN)).W) + val data = UInt(XLEN.W) // store 64bits in one entry + val valid = Bool() + def hit(tag : UInt) = { + (this.tag === tag(ppnLen-1,log2Ceil(XLEN))) && this.valid === 1.B + } + def refill(tag : UInt,data : UInt,valid : Bool) = { + this.tag := tag(ppnLen-1,log2Ceil(XLEN)) + this.data := data + this.valid := valid + } +} + +class BitmapCache(implicit p: Parameters) extends XSModule with HasPtwConst { + val io = IO(new bitmapCacheIO) + + val csr = io.csr + val sfence = io.sfence + val flush = sfence.valid || csr.satp.changed || csr.vsatp.changed || csr.hgatp.changed + val bitmap_cache_clear = csr.mbmc.BCLEAR + + val bitmapCachesize = 16 + val bitmapcache = Reg(Vec(bitmapCachesize,new bitmapCacheEntry())) + val bitmapReplace = ReplacementPolicy.fromString(l2tlbParams.l3Replacer, bitmapCachesize) + + // ----- + // -S0-- + // ----- + val addr_search = io.req.bits.tag + val hitVecT = bitmapcache.map(_.hit(addr_search)) + + // ----- + // -S1-- + // ----- + val index = RegEnable(addr_search(log2Up(XLEN)-1,0), io.req.fire) + val order = RegEnable(io.req.bits.order, io.req.fire) + val hitVec = RegEnable(VecInit(hitVecT), io.req.fire) + val CacheData = RegEnable(ParallelPriorityMux(hitVecT zip bitmapcache.map(_.data)), io.req.fire) + val cfs = Wire(Vec(tlbcontiguous, Bool())) + + val start = (index(5, 3) << 3.U) + val end = start + 7.U + val mask = (1.U << 8) - 1.U + val cfsdata = (CacheData >> start) & mask + for (i <- 0 until tlbcontiguous) { + cfs(i) := cfsdata(i) + } + val hit = ParallelOR(hitVec) + + val resp_res = Wire(new bitmapCacheRespBundle()) + resp_res.apply(hit,cfs,order) + + val resp_valid_reg = RegInit(false.B) + when (flush) { + resp_valid_reg := false.B + } .elsewhen(io.req.fire) { + resp_valid_reg := true.B + } .elsewhen(io.resp.fire) { + resp_valid_reg := false.B + } .otherwise { + resp_valid_reg := resp_valid_reg + } + + io.req.ready := !resp_valid_reg || io.resp.fire + io.resp.valid := resp_valid_reg + io.resp.bits := resp_res + + when (!flush && hit && io.resp.fire) { + bitmapReplace.access(OHToUInt(hitVec)) + } + + // ----- + // refill + // ----- + val rf_addr = io.refill.bits.tag + val rf_data = io.refill.bits.data + val rf_vd = io.refill.valid + when (!flush && rf_vd) { + val refillindex = bitmapReplace.way + dontTouch(refillindex) + bitmapcache(refillindex).refill(rf_addr,rf_data,true.B) + bitmapReplace.access(refillindex) + } + when (bitmap_cache_clear === 1.U) { + bitmapcache.foreach(_.valid := false.B) + } + + XSPerfAccumulate("bitmap_cache_resp", 
io.resp.fire) + XSPerfAccumulate("bitmap_cache_resp_miss", io.resp.fire && !io.resp.bits.hit) +} diff --git a/src/main/scala/xiangshan/cache/mmu/L2TLB.scala b/src/main/scala/xiangshan/cache/mmu/L2TLB.scala index 38af7a99047..17601dd3626 100644 --- a/src/main/scala/xiangshan/cache/mmu/L2TLB.scala +++ b/src/main/scala/xiangshan/cache/mmu/L2TLB.scala @@ -1,6 +1,8 @@ /*************************************************************************************** -* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences +* Copyright (c) 2021-2025 Beijing Institute of Open Source Chip (BOSC) +* Copyright (c) 2020-2024 Institute of Computing Technology, Chinese Academy of Sciences * Copyright (c) 2020-2021 Peng Cheng Laboratory +* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences * * XiangShan is licensed under Mulan PSL v2. * You can use this software according to the terms and conditions of the Mulan PSL v2. @@ -78,8 +80,8 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi val sfence_tmp = DelayN(io.sfence, 1) val csr_tmp = DelayN(io.csr.tlb, 1) - val sfence_dup = Seq.fill(9)(RegNext(sfence_tmp)) - val csr_dup = Seq.fill(8)(RegNext(csr_tmp)) // TODO: add csr_modified? + val sfence_dup = Seq.fill(if (HasBitmapCheck) 11 else 9)(RegNext(sfence_tmp)) + val csr_dup = Seq.fill(if (HasBitmapCheck) 10 else 8)(RegNext(csr_tmp)) // TODO: add csr_modified? val satp = csr_dup(0).satp val vsatp = csr_dup(0).vsatp val hgatp = csr_dup(0).hgatp @@ -89,9 +91,32 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi val flush = sfence_dup(0).valid || satp.changed || vsatp.changed || hgatp.changed val pmp = Module(new PMP()) - val pmp_check = VecInit(Seq.fill(3)(Module(new PMPChecker(lgMaxSize = 3, sameCycle = true)).io)) + val pmp_check = VecInit(Seq.fill(if (HasBitmapCheck) 4 else 3)(Module(new PMPChecker(lgMaxSize = 3, sameCycle = true)).io)) pmp.io.distribute_csr := io.csr.distribute_csr - pmp_check.foreach(_.check_env.apply(ModeS, pmp.io.pmp, pmp.io.pma)) + if (HasBitmapCheck) { + pmp_check.foreach(_.check_env.apply(csr_dup(0).mbmc.CMODE.asBool, ModeS, pmp.io.pmp, pmp.io.pma)) + } else { + pmp_check.foreach(_.check_env.apply(ModeS, pmp.io.pmp, pmp.io.pma)) + } + + // add bitmapcheck + val bitmap = Option.when(HasBitmapCheck)(Module(new Bitmap)) + val bitmapcache = Option.when(HasBitmapCheck)(Module(new BitmapCache)) + + if (HasBitmapCheck) { + bitmap.foreach { Bitmap => + Bitmap.io.csr := csr_dup(8) + Bitmap.io.sfence := sfence_dup(9) + bitmapcache.foreach { BitmapCache => + // connect bitmap and bitmapcache + BitmapCache.io.req <> Bitmap.io.cache.req + Bitmap.io.cache.resp <> BitmapCache.io.resp + BitmapCache.io.refill <> Bitmap.io.refill + BitmapCache.io.csr := csr_dup(9) + BitmapCache.io.sfence := sfence_dup(10) + } + } + } val missQueue = Module(new L2TlbMissQueue) val cache = Module(new PtwCache) @@ -124,6 +149,11 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi val outArbFsmPort = 1 val outArbMqPort = 2 + if (HasBitmapCheck) { + // connect ptwcache and bitmap sleep-wakeup port + cache.io.bitmap_wakeup.get <> bitmap.get.io.wakeup + } + + // hptw arb input port val InHptwArbPTWPort = 0 val InHptwArbLLPTWPort = 1 @@ -217,19 +247,21 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi } arb2.io.out.ready := cache.io.req.ready + // Indicates requests from the cache that need to go to LLPTW for processing + val toFsm_toLLPTW = if
(HasBitmapCheck) cache.io.resp.bits.toFsm.bitmapCheck.get.toLLPTW else false.B val mq_arb = Module(new Arbiter(new L2TlbWithHptwIdBundle, 2)) mq_arb.io.in(0).valid := cache.io.resp.valid && !cache.io.resp.bits.hit && !from_pre(cache.io.resp.bits.req_info.source) && !cache.io.resp.bits.isHptwReq && // hptw reqs are not sent to missqueue (cache.io.resp.bits.bypassed || ( - ((!cache.io.resp.bits.toFsm.l1Hit || cache.io.resp.bits.toFsm.stage1Hit) && !cache.io.resp.bits.isHptwReq && (cache.io.resp.bits.isFirst || !ptw.io.req.ready)) // send to ptw, is first or ptw is busy; - || (cache.io.resp.bits.toFsm.l1Hit && !llptw.io.in.ready) // send to llptw, llptw is full + (((!cache.io.resp.bits.toFsm.l1Hit && !toFsm_toLLPTW) || cache.io.resp.bits.toFsm.stage1Hit) && !cache.io.resp.bits.isHptwReq && (cache.io.resp.bits.isFirst || !ptw.io.req.ready)) // send to ptw, is first or ptw is busy; + || ((cache.io.resp.bits.toFsm.l1Hit || toFsm_toLLPTW) && !llptw.io.in.ready) // send to llptw, llptw is full )) mq_arb.io.in(0).bits.req_info := cache.io.resp.bits.req_info mq_arb.io.in(0).bits.isHptwReq := false.B mq_arb.io.in(0).bits.hptwId := DontCare - mq_arb.io.in(0).bits.isLLptw := cache.io.resp.bits.toFsm.l1Hit + mq_arb.io.in(0).bits.isLLptw := cache.io.resp.bits.toFsm.l1Hit || toFsm_toLLPTW mq_arb.io.in(1).bits.req_info := llptw.io.cache.bits mq_arb.io.in(1).bits.isHptwReq := false.B mq_arb.io.in(1).bits.hptwId := DontCare @@ -245,11 +277,17 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi llptw.io.in.valid := cache.io.resp.valid && !cache.io.resp.bits.hit && - cache.io.resp.bits.toFsm.l1Hit && + (toFsm_toLLPTW || cache.io.resp.bits.toFsm.l1Hit) && !cache.io.resp.bits.bypassed && !cache.io.resp.bits.isHptwReq llptw.io.in.bits.req_info := cache.io.resp.bits.req_info llptw.io.in.bits.ppn := cache.io.resp.bits.toFsm.ppn + if (HasBitmapCheck) { + llptw.io.in.bits.bitmapCheck.get.jmp_bitmap_check := cache.io.resp.bits.toFsm.bitmapCheck.get.jmp_bitmap_check + llptw.io.in.bits.bitmapCheck.get.ptes := cache.io.resp.bits.toFsm.bitmapCheck.get.ptes + llptw.io.in.bits.bitmapCheck.get.cfs := cache.io.resp.bits.toFsm.bitmapCheck.get.cfs + llptw.io.in.bits.bitmapCheck.get.hitway := cache.io.resp.bits.toFsm.bitmapCheck.get.hitway + } llptw.io.sfence := sfence_dup(1) llptw.io.csr := csr_dup(1) val llptw_stage1 = Reg(Vec(l2tlbParams.llptwsize, new PtwMergeResp())) @@ -271,7 +309,7 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi (!cache.io.resp.bits.hit && cache.io.resp.bits.isHptwReq) -> hptw.io.req.ready, (cache.io.resp.bits.hit && cache.io.resp.bits.isHptwReq) -> hptw_resp_arb.io.in(HptwRespArbCachePort).ready, cache.io.resp.bits.hit -> outReady(cache.io.resp.bits.req_info.source, outArbCachePort), - (cache.io.resp.bits.toFsm.l1Hit && !cache.io.resp.bits.bypassed && llptw.io.in.ready) -> llptw.io.in.ready, + ((toFsm_toLLPTW || cache.io.resp.bits.toFsm.l1Hit) && !cache.io.resp.bits.bypassed && llptw.io.in.ready) -> llptw.io.in.ready, (cache.io.resp.bits.bypassed || cache.io.resp.bits.isFirst) -> mq_arb.io.in(0).ready )) @@ -279,7 +317,8 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi ptw.io.req.valid := cache.io.resp.valid && !cache.io.resp.bits.hit && !cache.io.resp.bits.toFsm.l1Hit && !cache.io.resp.bits.bypassed && !cache.io.resp.bits.isFirst && - !cache.io.resp.bits.isHptwReq + !cache.io.resp.bits.isHptwReq && + !toFsm_toLLPTW ptw.io.req.bits.req_info := cache.io.resp.bits.req_info if (EnableSv48) { 
ptw.io.req.bits.l3Hit.get := cache.io.resp.bits.toFsm.l3Hit.get @@ -288,6 +327,12 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi ptw.io.req.bits.ppn := cache.io.resp.bits.toFsm.ppn ptw.io.req.bits.stage1Hit := cache.io.resp.bits.toFsm.stage1Hit ptw.io.req.bits.stage1 := cache.io.resp.bits.stage1 + if (HasBitmapCheck) { + ptw.io.req.bits.bitmapCheck.get.jmp_bitmap_check := cache.io.resp.bits.toFsm.bitmapCheck.get.jmp_bitmap_check + ptw.io.req.bits.bitmapCheck.get.pte := cache.io.resp.bits.toFsm.bitmapCheck.get.pte + ptw.io.req.bits.bitmapCheck.get.cfs := cache.io.resp.bits.toFsm.bitmapCheck.get.cfs + ptw.io.req.bits.bitmapCheck.get.SPlevel := cache.io.resp.bits.toFsm.bitmapCheck.get.SPlevel + } ptw.io.sfence := sfence_dup(7) ptw.io.csr := csr_dup(6) ptw.io.resp.ready := outReady(ptw.io.resp.bits.source, outArbFsmPort) @@ -303,6 +348,9 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi hptw.io.req.bits.l1Hit := cache.io.resp.bits.toHptw.l1Hit hptw.io.req.bits.ppn := cache.io.resp.bits.toHptw.ppn hptw.io.req.bits.bypassed := cache.io.resp.bits.toHptw.bypassed + if (HasBitmapCheck) { + hptw.io.req.bits.bitmapCheck.get <> cache.io.resp.bits.toHptw.bitmapCheck.get + } hptw.io.sfence := sfence_dup(8) hptw.io.csr := csr_dup(7) // mem req @@ -324,6 +372,9 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi def from_hptw(id: UInt) = { id === l2tlbParams.llptwsize.U + 1.U } + def from_bitmap(id: UInt) = { + (id > l2tlbParams.llptwsize.U + 1.U) && (id < MemReqWidth.U) + } val waiting_resp = RegInit(VecInit(Seq.fill(MemReqWidth)(false.B))) val flush_latch = RegInit(VecInit(Seq.fill(MemReqWidth)(false.B))) val hptw_bypassed = RegInit(false.B) @@ -337,11 +388,16 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi llptw_mem.req_mask := waiting_resp.take(l2tlbParams.llptwsize) ptw.io.mem.mask := waiting_resp.apply(l2tlbParams.llptwsize) hptw.io.mem.mask := waiting_resp.apply(l2tlbParams.llptwsize + 1) - - val mem_arb = Module(new Arbiter(new L2TlbMemReqBundle(), 3)) + if (HasBitmapCheck) { + bitmap.get.io.mem.req_mask := waiting_resp.slice(MemReqWidth - (l2tlbParams.llptwsize + 2), MemReqWidth) + } + val mem_arb = Module(new Arbiter(new L2TlbMemReqBundle(), if (HasBitmapCheck) 4 else 3)) mem_arb.io.in(0) <> ptw.io.mem.req mem_arb.io.in(1) <> llptw_mem.req mem_arb.io.in(2) <> hptw.io.mem.req + if (HasBitmapCheck) { + mem_arb.io.in(3) <> bitmap.get.io.mem.req + } mem_arb.io.out.ready := mem.a.ready && !flush // // assert, should not send mem access at same addr for twice. 
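A note on the ID partition used above (illustration only, not part of the patch): with `HasBitmapCheck` the memory-request ID space is doubled (`MemReqWidth = 2 * (llptwsize + 1 + 1)`, see the MMUConst hunk further down), so LLPTW keeps IDs `0 .. llptwsize-1`, PTW keeps `llptwsize`, HPTW keeps `llptwsize + 1`, and the bitmap checker owns the upper half, which is exactly the range `from_bitmap` tests. A minimal plain-Scala sketch of this partition, with a hypothetical `llptwsize = 16`:

```scala
// Plain-Scala sketch of the doubled mem-request ID space (hypothetical
// llptwsize = 16; the RTL uses Chisel UInt comparisons instead of Ints).
object MemIdPartition {
  val llptwsize   = 16
  val memReqWidth = 2 * (llptwsize + 1 + 1) // doubled when HasBitmapCheck

  def fromLlptw(id: Int)  = id < llptwsize
  def fromPtw(id: Int)    = id == llptwsize
  def fromHptw(id: Int)   = id == llptwsize + 1
  def fromBitmap(id: Int) = id > llptwsize + 1 && id < memReqWidth

  def main(args: Array[String]): Unit = {
    // Sanity check: every ID belongs to exactly one requester.
    assert((0 until memReqWidth).forall { id =>
      Seq(fromLlptw(id), fromPtw(id), fromHptw(id), fromBitmap(id)).count(identity) == 1
    })
    println((0 until memReqWidth).filter(fromBitmap)) // IDs owned by the bitmap checker
  }
}
```

Because the upper half is disjoint from the walker IDs, `waiting_resp.slice(MemReqWidth - (llptwsize + 2), MemReqWidth)` in the hunk above masks exactly the bitmap checker's outstanding requests.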
@@ -390,6 +446,7 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi val mem_resp_from_llptw = from_llptw(mem.d.bits.source) val mem_resp_from_ptw = from_ptw(mem.d.bits.source) val mem_resp_from_hptw = from_hptw(mem.d.bits.source) + val mem_resp_from_bitmap = from_bitmap(mem.d.bits.source) when (mem.d.valid) { assert(mem.d.bits.source < MemReqWidth.U) refill_data(refill_helper._4) := mem.d.bits.data @@ -400,7 +457,7 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi // save only one pte for each id // (miss queue may can't resp to tlb with low latency, it should have highest priority, but diffcult to design cache) - val resp_pte = VecInit((0 until MemReqWidth).map(i => + val resp_pte = VecInit((0 until (if (HasBitmapCheck) MemReqWidth / 2 else MemReqWidth)).map(i => if (i == l2tlbParams.llptwsize + 1) {RegEnable(get_part(refill_data_tmp, req_addr_low(i)), 0.U.asTypeOf(get_part(refill_data_tmp, req_addr_low(i))), mem_resp_done && mem_resp_from_hptw) } else if (i == l2tlbParams.llptwsize) {RegEnable(get_part(refill_data_tmp, req_addr_low(i)), 0.U.asTypeOf(get_part(refill_data_tmp, req_addr_low(i))), mem_resp_done && mem_resp_from_ptw) } else { Mux(llptw_mem.buffer_it(i), get_part(refill_data, req_addr_low(i)), RegEnable(get_part(refill_data, req_addr_low(i)), 0.U.asTypeOf(get_part(refill_data, req_addr_low(i))), llptw_mem.buffer_it(i))) } @@ -409,13 +466,52 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi // save eight ptes for each id when sector tlb // (miss queue may can't resp to tlb with low latency, it should have highest priority, but diffcult to design cache) - val resp_pte_sector = VecInit((0 until MemReqWidth).map(i => + val resp_pte_sector = VecInit((0 until (if (HasBitmapCheck) MemReqWidth / 2 else MemReqWidth)).map(i => if (i == l2tlbParams.llptwsize + 1) {RegEnable(refill_data_tmp, 0.U.asTypeOf(refill_data_tmp), mem_resp_done && mem_resp_from_hptw) } else if (i == l2tlbParams.llptwsize) {RegEnable(refill_data_tmp, 0.U.asTypeOf(refill_data_tmp), mem_resp_done && mem_resp_from_ptw) } else { Mux(llptw_mem.buffer_it(i), refill_data, RegEnable(refill_data, 0.U.asTypeOf(refill_data), llptw_mem.buffer_it(i))) } // llptw could not use refill_data_tmp, because enq bypass's result works at next cycle )) + if (HasBitmapCheck) { + // add bitmap arb + bitmap.foreach { Bitmap => + val bitmap_arb = Module(new Arbiter(new bitmapReqBundle(), 3)) + bitmap_arb.io.in(0) <> ptw.io.bitmap.get.req + bitmap_arb.io.in(1) <> llptw.io.bitmap.get.req + bitmap_arb.io.in(2) <> hptw.io.bitmap.get.req + bitmap_arb.io.out.ready := Bitmap.io.req.ready + + Bitmap.io.req <> bitmap_arb.io.out + + // connect bitmap resp to PTW + val bitmapresp_to_llptw = from_llptw(Bitmap.io.resp.bits.id) + val bitmapresp_to_hptw = from_hptw(Bitmap.io.resp.bits.id) + val bitmapresp_to_ptw = from_ptw(Bitmap.io.resp.bits.id) + + Bitmap.io.resp.ready := (llptw.io.bitmap.get.resp.ready && bitmapresp_to_llptw) || (hptw.io.bitmap.get.resp.ready && bitmapresp_to_hptw) || (ptw.io.bitmap.get.resp.ready && bitmapresp_to_ptw) + + // bitmap -> llptw ptw hptw + llptw.io.bitmap.get.resp.valid := Bitmap.io.resp.valid && bitmapresp_to_llptw + hptw.io.bitmap.get.resp.valid := Bitmap.io.resp.valid && bitmapresp_to_hptw + ptw.io.bitmap.get.resp.valid := Bitmap.io.resp.valid && bitmapresp_to_ptw + + // connect the bitmap resp bits to ptw, hptw and llptw + ptw.io.bitmap.get.resp.bits := Bitmap.io.resp.bits + hptw.io.bitmap.get.resp.bits := Bitmap.io.resp.bits +
llptw.io.bitmap.get.resp.bits := Bitmap.io.resp.bits + + // mem -> bitmap + Bitmap.io.mem.resp.valid := mem_resp_done && mem_resp_from_bitmap + Bitmap.io.mem.resp.bits.id := DataHoldBypass(mem.d.bits.source, mem.d.valid) + Bitmap.io.mem.resp.bits.value := DataHoldBypass(refill_data_tmp.asUInt, mem.d.valid) + } + + // ptwcache -> hptw llptw + hptw.io.l0_way_info.get := cache.io.l0_way_info.get + llptw.io.l0_way_info.get := cache.io.l0_way_info.get + } + // mem -> llptw llptw_mem.resp.valid := mem_resp_done && mem_resp_from_llptw llptw_mem.resp.bits.id := DataHoldBypass(mem.d.bits.source, mem.d.valid) @@ -431,7 +527,7 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi val refill_from_ptw = mem_resp_from_ptw val refill_from_hptw = mem_resp_from_hptw val refill_level = Mux(refill_from_llptw, 0.U, Mux(refill_from_ptw, RegEnable(ptw.io.refill.level, 0.U, ptw.io.mem.req.fire), RegEnable(hptw.io.refill.level, 0.U, hptw.io.mem.req.fire))) - val refill_valid = mem_resp_done && !flush && !flush_latch(mem.d.bits.source) && !hptw_bypassed + val refill_valid = mem_resp_done && (if (HasBitmapCheck) !mem_resp_from_bitmap else true.B) && !flush && !flush_latch(mem.d.bits.source) && !hptw_bypassed cache.io.refill.valid := GatedValidRegNext(refill_valid, false.B) cache.io.refill.bits.ptes := refill_data.asUInt @@ -491,6 +587,10 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi llptw.io.pmp.resp <> pmp_check(1).resp pmp_check(2).req <> hptw.io.pmp.req hptw.io.pmp.resp <> pmp_check(2).resp + if (HasBitmapCheck) { + pmp_check(3).req <> bitmap.get.io.pmp.req + bitmap.get.io.pmp.resp <> pmp_check(3).resp + } llptw_out.ready := outReady(llptw_out.bits.req_info.source, outArbMqPort) @@ -512,6 +612,8 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi llptw.io.hptw.resp.bits.h_resp := hptw_resp_arb.io.out.bits.resp hptw_resp_arb.io.out.ready := true.B + val cfsValue = Option.when(HasBitmapCheck)(llptw_out.bits.bitmapCheck.get.cfs) + // Timing: Maybe need to do some optimization or even add one more cycle for (i <- 0 until PtwWidth) { mergeArb(i).in(outArbCachePort).valid := cache.io.resp.valid && cache.io.resp.bits.hit && cache.io.resp.bits.req_info.source===i.U && !cache.io.resp.bits.isHptwReq @@ -527,8 +629,9 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi mergeArb(i).in(outArbMqPort).bits.s1 := Mux( llptw_out.bits.first_s2xlate_fault, llptw_stage1(llptw_out.bits.id), contiguous_pte_to_merge_ptwResp( - resp_pte_sector(llptw_out.bits.id).asUInt, llptw_out.bits.req_info.vpn, llptw_out.bits.af, - true, s2xlate = llptw_out.bits.req_info.s2xlate, mPBMTE = mPBMTE, hPBMTE = hPBMTE, gpf = llptw_out.bits.h_resp.gpf + if (HasBitmapCheck) Mux(llptw_out.bits.bitmapCheck.get.jmp_bitmap_check, llptw_out.bits.bitmapCheck.get.ptes.asUInt, resp_pte_sector(llptw_out.bits.id).asUInt) else resp_pte_sector(llptw_out.bits.id).asUInt, llptw_out.bits.req_info.vpn, llptw_out.bits.af, + true, s2xlate = llptw_out.bits.req_info.s2xlate, mPBMTE = mPBMTE, hPBMTE = hPBMTE, gpf = llptw_out.bits.h_resp.gpf, + cfs = cfsValue.getOrElse(VecInit(Seq.fill(tlbcontiguous)(false.B))) ) ) mergeArb(i).in(outArbMqPort).bits.s2 := llptw_out.bits.h_resp @@ -575,7 +678,7 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi // not_super means that this is a normal page // valididx(i) will be all true when super page to be convenient for l1 tlb matching - def contiguous_pte_to_merge_ptwResp(pte: 
UInt, vpn: UInt, af: Bool, af_first: Boolean, s2xlate: UInt, mPBMTE: Bool, hPBMTE: Bool, not_super: Boolean = true, gpf: Bool) : PtwMergeResp = { + def contiguous_pte_to_merge_ptwResp(pte: UInt, vpn: UInt, af: Bool, af_first: Boolean, s2xlate: UInt, mPBMTE: Bool, hPBMTE: Bool, not_super: Boolean = true, gpf: Bool, cfs : Vec[Bool]) : PtwMergeResp = { assert(tlbcontiguous == 8, "Only support tlbcontiguous = 8!") val ptw_merge_resp = Wire(new PtwMergeResp()) val hasS2xlate = s2xlate =/= noS2xlate @@ -592,6 +695,7 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi ptw_resp.tag := vpn(vpnLen - 1, sectortlbwidth) ptw_resp.pf := (if (af_first) !af else true.B) && (pte_in.isPf(0.U, pbmte) || !pte_in.isLeaf()) ptw_resp.af := (if (!af_first) pte_in.isPf(0.U, pbmte) else true.B) && (af || (Mux(s2xlate === allStage, false.B, pte_in.isAf()) && !(hasS2xlate && gpf))) + ptw_resp.cf := cfs(ptw_resp.ppn(sectortlbwidth - 1, 0)) ptw_resp.v := !ptw_resp.pf ptw_resp.prefetch := DontCare ptw_resp.asid := Mux(hasS2xlate, vsatp.asid, satp.asid) @@ -629,7 +733,8 @@ class L2TLBImp(outer: L2TLB)(implicit p: Parameters) extends PtwModule(outer) wi val v_equal = pte.entry(i).v === pte.entry(OHToUInt(pte.pteidx)).v val af_equal = pte.entry(i).af === pte.entry(OHToUInt(pte.pteidx)).af val pf_equal = pte.entry(i).pf === pte.entry(OHToUInt(pte.pteidx)).pf - ptw_sector_resp.valididx(i) := ((ppn_equal && pbmt_equal && n_equal && perm_equal && v_equal && af_equal && pf_equal) || !pte.not_super) && !pte.not_merge + val cf_equal = if (HasBitmapCheck) pte.entry(i).cf === pte.entry(OHToUInt(pte.pteidx)).cf else true.B + ptw_sector_resp.valididx(i) := ((ppn_equal && pbmt_equal && n_equal && perm_equal && v_equal && af_equal && pf_equal && cf_equal) || !pte.not_super) && !pte.not_merge ptw_sector_resp.ppn_low(i) := pte.entry(i).ppn_low } ptw_sector_resp.valididx(OHToUInt(pte.pteidx)) := true.B diff --git a/src/main/scala/xiangshan/cache/mmu/MMUBundle.scala b/src/main/scala/xiangshan/cache/mmu/MMUBundle.scala index 51f9092815f..ebb5e0e7a27 100644 --- a/src/main/scala/xiangshan/cache/mmu/MMUBundle.scala +++ b/src/main/scala/xiangshan/cache/mmu/MMUBundle.scala @@ -941,6 +941,7 @@ class PtwMergeEntry(tagLen: Int, hasPerm: Boolean = false, hasLevel: Boolean = f val ppn_low = UInt(sectortlbwidth.W) val af = Bool() val pf = Bool() + val cf = Bool() // Bitmap Check Failed } class PtwEntries(num: Int, tagLen: Int, level: Int, hasPerm: Boolean, ReservedBits: Int)(implicit p: Parameters) extends PtwBundle { @@ -1235,7 +1236,7 @@ class PtwMergeResp(implicit p: Parameters) extends PtwBundle { val not_super = Bool() val not_merge = Bool() - def apply(pf: Bool, af: Bool, level: UInt, pte: PteBundle, vpn: UInt, asid: UInt, vmid:UInt, addr_low : UInt, not_super : Boolean = true, not_merge: Boolean = false) = { + def apply(pf: Bool, af: Bool, level: UInt, pte: PteBundle, vpn: UInt, asid: UInt, vmid:UInt, addr_low : UInt, not_super : Boolean = true, not_merge: Boolean = false, cf : Bool) = { assert(tlbcontiguous == 8, "Only support tlbcontiguous = 8!") val resp_pte = pte val ptw_resp = Wire(new PtwMergeEntry(tagLen = sectorvpnLen, hasPerm = true, hasLevel = true, hasNapot = true)) @@ -1248,6 +1249,7 @@ class PtwMergeResp(implicit p: Parameters) extends PtwBundle { ptw_resp.tag := vpn(vpnLen - 1, sectortlbwidth) ptw_resp.pf := pf ptw_resp.af := af + ptw_resp.cf := cf // Bitmap Check Failed ptw_resp.v := resp_pte.perm.v ptw_resp.prefetch := DontCare ptw_resp.asid := asid diff --git 
a/src/main/scala/xiangshan/cache/mmu/MMUConst.scala b/src/main/scala/xiangshan/cache/mmu/MMUConst.scala index 612778482a7..1edb4f3da73 100644 --- a/src/main/scala/xiangshan/cache/mmu/MMUConst.scala +++ b/src/main/scala/xiangshan/cache/mmu/MMUConst.scala @@ -258,11 +258,23 @@ trait HasPtwConst extends HasTlbConst with MemoryOpConstants{ // miss queue val MissQueueSize = l2tlbParams.ifilterSize + l2tlbParams.dfilterSize - val MemReqWidth = l2tlbParams.llptwsize + 1 + 1 + val MemReqWidth = if (HasBitmapCheck) 2 * (l2tlbParams.llptwsize + 1 + 1) else (l2tlbParams.llptwsize + 1 + 1) val HptwReqId = l2tlbParams.llptwsize + 1 val FsmReqID = l2tlbParams.llptwsize val bMemID = log2Up(MemReqWidth) + def ptwTranVec(flushMask: UInt): Vec[Bool] = { + val vec = Wire(Vec(tlbcontiguous, Bool())) + for (i <- 0 until tlbcontiguous) { + vec(i) := flushMask(i) + } + vec + } + + def dupBitmapPPN(ppn1: UInt, ppn2: UInt) : Bool = { + ppn1(ppnLen-1, ppnLen-log2Up(XLEN)) === ppn2(ppnLen-1, ppnLen-log2Up(XLEN)) + } + def genPtwL1Idx(vpn: UInt) = { (vpn(vpnLen - 1, vpnnLen))(PtwL1IdxLen - 1, 0) } diff --git a/src/main/scala/xiangshan/cache/mmu/PageTableCache.scala b/src/main/scala/xiangshan/cache/mmu/PageTableCache.scala index d6ecc912a44..bd989560de6 100644 --- a/src/main/scala/xiangshan/cache/mmu/PageTableCache.scala +++ b/src/main/scala/xiangshan/cache/mmu/PageTableCache.scala @@ -1,7 +1,8 @@ /*************************************************************************************** -* Copyright (c) 2024 Beijing Institute of Open Source Chip (BOSC) +* Copyright (c) 2021-2025 Beijing Institute of Open Source Chip (BOSC) * Copyright (c) 2020-2024 Institute of Computing Technology, Chinese Academy of Sciences * Copyright (c) 2020-2021 Peng Cheng Laboratory +* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences * * XiangShan is licensed under Mulan PSL v2. * You can use this software according to the terms and conditions of the Mulan PSL v2.
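For intuition about `dupBitmapPPN` above: the patch compares the upper `log2Up(XLEN)` PPN bits to spot requests that can share one bitmap access. The sketch below (illustration only, not the patch's exact address math) shows the underlying idea under the usual one-bit-per-4 KiB-frame reading of such schemes, where one XLEN-bit memory word of the bitmap covers XLEN consecutive frames; the bit polarity (set = allowed) and the simple `ppn / 64` word grouping are assumptions for the example:

```scala
// Conceptual sketch only: one check bit per 4 KiB frame, one 64-bit bitmap
// word covering 64 frames. Polarity (bit set = allowed) is an assumption
// for illustration, as is the ppn/64 word grouping.
object BitmapProbe {
  val XLEN = 64
  def wordIndex(ppn: Long): Long = ppn / XLEN          // which bitmap word to fetch
  def bitIndex(ppn: Long): Int   = (ppn % XLEN).toInt  // which bit inside that word
  def sameWord(a: Long, b: Long): Boolean = wordIndex(a) == wordIndex(b)

  // cf ("check failed") is true when the frame's bit is clear.
  def checkFailed(word: Long, ppn: Long): Boolean =
    ((word >>> bitIndex(ppn)) & 1L) == 0L

  def main(args: Array[String]): Unit = {
    val word = 0xFFL                            // frames 0..7 of this word allowed
    println(checkFailed(word, 3))               // false: bit 3 is set
    println(checkFailed(word, 9))               // true: bit 9 is clear
    println(s"${sameWord(3, 9)} ${sameWord(3, 70)}") // "true false": second pair needs two fetches
  }
}
```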
@@ -44,10 +45,15 @@ class PageCachePerPespBundle(implicit p: Parameters) extends PtwBundle { val ecc = Bool() val level = UInt(2.W) val v = Bool() + val bitmapCheck = Option.when(HasBitmapCheck)(new Bundle { + val jmp_bitmap_check = Bool() + val pte = UInt(XLEN.W) // Page Table Entry + }) def apply(hit: Bool, pre: Bool, ppn: UInt, pbmt: UInt = 0.U, n: UInt = 0.U, perm: PtePermBundle = 0.U.asTypeOf(new PtePermBundle()), - ecc: Bool = false.B, level: UInt = 0.U, valid: Bool = true.B): Unit = { + ecc: Bool = false.B, level: UInt = 0.U, valid: Bool = true.B, jmp_bitmap_check: Bool = false.B, + pte: UInt = 0.U): Unit = { this.hit := hit && !ecc this.pre := pre this.ppn := ppn @@ -57,6 +63,10 @@ class PageCachePerPespBundle(implicit p: Parameters) extends PtwBundle { this.ecc := ecc && hit this.level := level this.v := valid + if (HasBitmapCheck) { + this.bitmapCheck.get.jmp_bitmap_check := jmp_bitmap_check + this.bitmapCheck.get.pte := pte + } } } @@ -70,10 +80,18 @@ class PageCacheMergePespBundle(implicit p: Parameters) extends PtwBundle { val ecc = Bool() val level = UInt(2.W) val v = Vec(tlbcontiguous, Bool()) + val bitmapCheck = Option.when(HasBitmapCheck)(new Bundle { + val jmp_bitmap_check = Bool() + val hitway = UInt(l2tlbParams.l0nWays.W) + val ptes = Vec(tlbcontiguous, UInt(XLEN.W)) // Page Table Entry Vector + val cfs = Vec(tlbcontiguous, Bool()) // Bitmap Check Failed Vector + }) def apply(hit: Bool, pre: Bool, ppn: Vec[UInt], pbmt: Vec[UInt] = Vec(tlbcontiguous, 0.U), perm: Vec[PtePermBundle] = Vec(tlbcontiguous, 0.U.asTypeOf(new PtePermBundle())), - ecc: Bool = false.B, level: UInt = 0.U, valid: Vec[Bool] = Vec(tlbcontiguous, true.B)): Unit = { + ecc: Bool = false.B, level: UInt = 0.U, valid: Vec[Bool] = Vec(tlbcontiguous, true.B), + jmp_bitmap_check: Bool = false.B, + hitway: UInt = 0.U, ptes: Vec[UInt] , cfs: Vec[Bool]): Unit = { this.hit := hit && !ecc this.pre := pre this.ppn := ppn @@ -82,6 +100,12 @@ class PageCacheMergePespBundle(implicit p: Parameters) extends PtwBundle { this.ecc := ecc && hit this.level := level this.v := valid + if (HasBitmapCheck) { + this.bitmapCheck.get.jmp_bitmap_check := jmp_bitmap_check + this.bitmapCheck.get.hitway := hitway + this.bitmapCheck.get.ptes := ptes + this.bitmapCheck.get.cfs := cfs + } } } @@ -115,6 +139,15 @@ class PtwCacheIO()(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwCo val l1Hit = Bool() val ppn = UInt(gvpnLen.W) val stage1Hit = Bool() // find stage 1 pte in cache, but need to search stage 2 pte in cache at PTW + val bitmapCheck = Option.when(HasBitmapCheck)(new Bundle { + val jmp_bitmap_check = Bool() // find pte in l0 or sp, but need bitmap check + val toLLPTW = Bool() + val hitway = UInt(l2tlbParams.l0nWays.W) + val pte = UInt(XLEN.W) // Page Table Entry + val ptes = Vec(tlbcontiguous, UInt(XLEN.W)) // Page Table Entry Vector + val cfs = Vec(tlbcontiguous, Bool()) // Bitmap Check Failed Vector + val SPlevel = UInt(log2Up(Level).W) + }) } val stage1 = new PtwMergeResp() val isHptwReq = Bool() @@ -126,6 +159,15 @@ class PtwCacheIO()(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwCo val id = UInt(log2Up(l2tlbParams.llptwsize).W) val resp = new HptwResp() // used if hit val bypassed = Bool() + val bitmapCheck = Option.when(HasBitmapCheck)(new Bundle { + val jmp_bitmap_check = Bool() // find pte in l0 or sp, but need bitmap check + val hitway = UInt(l2tlbParams.l0nWays.W) + val pte = UInt(XLEN.W) // Page Table Entry + val ptes = Vec(tlbcontiguous, UInt(XLEN.W)) // Page Table Entry Vector + val cfs = 
Vec(tlbcontiguous, Bool()) // Bitmap Check Failed Vector + val fromSP = Bool() + val SPlevel = UInt(log2Up(Level).W) + }) } }) val refill = Flipped(ValidIO(new Bundle { @@ -150,8 +192,19 @@ class PtwCacheIO()(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwCo val level_dup = Vec(3, UInt(log2Up(Level + 1).W)) val sel_pte_dup = Vec(3, UInt(XLEN.W)) })) + // when refilling l0, save the way info for the later bitmap wakeup + // valid in the same cycle as refill.levelOH.l0 + val l0_way_info = Option.when(HasBitmapCheck)(Output(UInt(l2tlbParams.l0nWays.W))) val sfence_dup = Vec(4, Input(new SfenceBundle())) val csr_dup = Vec(3, Input(new TlbCsrBundle())) + val bitmap_wakeup = Option.when(HasBitmapCheck)(Flipped(ValidIO(new Bundle { + val setIndex = Input(UInt(PtwL0SetIdxLen.W)) + val tag = Input(UInt(SPTagLen.W)) + val isSp = Input(Bool()) + val way_info = UInt(l2tlbParams.l0nWays.W) + val pte_index = UInt(sectortlbwidth.W) + val check_success = Bool() + }))) } class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPerfEvents { @@ -160,6 +213,12 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with val l1EntryType = new PTWEntriesWithEcc(ecc, num = PtwL1SectorSize, tagLen = PtwL1TagLen, level = 1, hasPerm = false, ReservedBits = l2tlbParams.l1ReservedBits) val l0EntryType = new PTWEntriesWithEcc(ecc, num = PtwL0SectorSize, tagLen = PtwL0TagLen, level = 0, hasPerm = true, ReservedBits = l2tlbParams.l0ReservedBits) + // two additional regs record whether the corresponding cache entry has passed the bitmap check + // 32 (l0nSets) * 8 (l0nWays) * 8 (tlbcontiguous) + val l0BitmapReg = RegInit(VecInit(Seq.fill(l2tlbParams.l0nSets)(VecInit(Seq.fill(l2tlbParams.l0nWays)(VecInit(Seq.fill(tlbcontiguous)(0.U(1.W)))))))) + val spBitmapReg = RegInit(VecInit(Seq.fill(l2tlbParams.spSize)(0.U(1.W)))) + + val bitmapEnable = io.csr_dup(0).mbmc.BME === 1.U && io.csr_dup(0).mbmc.CMODE === 0.U // TODO: four caches make the codes dirty, think about how to deal with it val sfence_dup = io.sfence_dup @@ -270,6 +329,24 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with val spvmids = sp.map(_.vmid) val sph = Reg(Vec(l2tlbParams.spSize, UInt(2.W))) + if (HasBitmapCheck) { + // wake up the corresponding entry + when (io.bitmap_wakeup.get.valid) { + when (io.bitmap_wakeup.get.bits.isSp) { + for (i <- 0 until l2tlbParams.spSize) { + when (sp(i).tag === io.bitmap_wakeup.get.bits.tag && spv(i) === 1.U) { + spBitmapReg(i) := io.bitmap_wakeup.get.bits.check_success + } + } + } .otherwise { + val wakeup_setindex = io.bitmap_wakeup.get.bits.setIndex + l0BitmapReg(wakeup_setindex)(OHToUInt(io.bitmap_wakeup.get.bits.way_info))(io.bitmap_wakeup.get.bits.pte_index) := io.bitmap_wakeup.get.bits.check_success + assert(l0v(wakeup_setindex * l2tlbParams.l0nWays.U + OHToUInt(io.bitmap_wakeup.get.bits.way_info)) === 1.U, + "Woken-up entry must be valid!") + } + } + } + // Access Perf val l3AccessPerf = if(EnableSv48) Some(Wire(Vec(l2tlbParams.l3Size, Bool()))) else None val l2AccessPerf = Wire(Vec(l2tlbParams.l2Size, Bool())) @@ -445,7 +522,7 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with l1.clock := l1_masked_clock // l0 val ptwl0replace = ReplacementPolicy.fromString(l2tlbParams.l0Replacer,l2tlbParams.l0nWays,l2tlbParams.l0nSets) - val (l0Hit, l0HitData, l0Pre, l0eccError) = { + val (l0Hit, l0HitData, l0Pre, l0eccError, l0HitWay, l0BitmapCheckResult, l0JmpBitmapCheck) = { val ridx = genPtwL0SetIdx(vpn_search)
l0.io.r.req.valid := stageReq.fire l0.io.r.req.bits.apply(setIdx = ridx) @@ -474,8 +551,26 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with val hitWayEntry = ParallelPriorityMux(hitVec zip ramDatas) val hitWayData = hitWayEntry.entries val hitWayEcc = hitWayEntry.ecc - val hit = ParallelOR(hitVec) val hitWay = ParallelPriorityMux(hitVec zip (0 until l2tlbParams.l0nWays).map(_.U(log2Up(l2tlbParams.l0nWays).W))) + + val ishptw = RegEnable(stageDelay(0).bits.isHptwReq,stageDelay(1).fire) + val s2x_info = RegEnable(stageDelay(0).bits.req_info.s2xlate,stageDelay(1).fire) + val pte_index = RegEnable(stageDelay(0).bits.req_info.vpn(sectortlbwidth - 1, 0),stageDelay(1).fire) + val jmp_bitmap_check = WireInit(false.B) + val hit = WireInit(false.B) + val l0bitmapreg = WireInit((VecInit(Seq.fill(l2tlbParams.l0nWays)(VecInit(Seq.fill(tlbcontiguous)(0.U(1.W))))))) + if (HasBitmapCheck) { + l0bitmapreg := RegEnable(RegNext(l0BitmapReg(ridx)), stageDelay(1).fire) + // because llptw will trigger the bitmap check itself, + // add a conditional guard: + // (s2x_info =/= allStage || ishptw) + hit := Mux(bitmapEnable && (s2x_info =/= allStage || ishptw), ParallelOR(hitVec) && l0bitmapreg(hitWay)(pte_index) === 1.U, ParallelOR(hitVec)) + when (bitmapEnable && (s2x_info =/= allStage || ishptw) && ParallelOR(hitVec) && l0bitmapreg(hitWay)(pte_index) === 0.U) { + jmp_bitmap_check := true.B + } + } else { + hit := ParallelOR(hitVec) + } val eccError = WireInit(false.B) if (l2tlbParams.enablePTWECC) { eccError := hitWayEntry.decode() @@ -497,20 +592,39 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with hitVec.suggestName(s"l0_hitVec") hitWay.suggestName(s"l0_hitWay") - (hit, hitWayData, hitWayData.prefetch, eccError) + (hit, hitWayData, hitWayData.prefetch, eccError, UIntToOH(hitWay), l0bitmapreg(hitWay), jmp_bitmap_check) } val l0HitPPN = l0HitData.ppns val l0HitPbmt = l0HitData.pbmts val l0HitPerm = l0HitData.perms.getOrElse(0.U.asTypeOf(Vec(PtwL0SectorSize, new PtePermBundle))) val l0HitValid = VecInit(l0HitData.onlypf.map(!_)) + val l0Ptes = WireInit(VecInit(Seq.fill(tlbcontiguous)(0.U(XLEN.W)))) // L0 level Page Table Entry Vector + val l0cfs = WireInit(VecInit(Seq.fill(tlbcontiguous)(false.B))) // L0 level Bitmap Check Failed Vector + if (HasBitmapCheck) { + for (i <- 0 until tlbcontiguous) { + l0Ptes(i) := Cat(l0HitData.pbmts(i).asUInt,l0HitPPN(i), 0.U(2.W),l0HitPerm(i).asUInt,l0HitValid(i).asUInt) + l0cfs(i) := !l0BitmapCheckResult(i) + } + } // super page val spreplace = ReplacementPolicy.fromString(l2tlbParams.spReplacer, l2tlbParams.spSize) - val (spHit, spHitData, spPre, spValid) = { + val (spHit, spHitData, spPre, spValid, spJmpBitmapCheck) = { val hitVecT = sp.zipWithIndex.map { case (e, i) => e.hit(vpn_search, io.csr_dup(0).satp.asid, io.csr_dup(0).vsatp.asid, io.csr_dup(0).hgatp.vmid, allType = true, s2xlate = h_search =/= noS2xlate) && spv(i) && (sph(i) === h_search) } val hitVec = hitVecT.map(RegEnable(_, stageReq.fire)) val hitData = ParallelPriorityMux(hitVec zip sp) - val hit = ParallelOR(hitVec) + val ishptw = RegEnable(stageReq.bits.isHptwReq, stageReq.fire) + val s2x_info = RegEnable(stageReq.bits.req_info.s2xlate, stageReq.fire) + val jmp_bitmap_check = WireInit(false.B) + val hit = WireInit(false.B) + if (HasBitmapCheck) { + hit := Mux(bitmapEnable && (s2x_info =/= allStage || ishptw), ParallelOR(hitVec) && spBitmapReg(OHToUInt(hitVec)) === 1.U, ParallelOR(hitVec)) + when (bitmapEnable && (s2x_info =/= allStage || ishptw) &&
ParallelOR(hitVec) && spBitmapReg(OHToUInt(hitVec)) === 0.U) { + jmp_bitmap_check := true.B + } + } else { + hit := ParallelOR(hitVec) + } when (hit && stageDelay_valid_1cycle) { spreplace.access(OHToUInt(hitVec)) } @@ -526,17 +640,19 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with (RegEnable(hit, stageDelay(1).fire), RegEnable(hitData, stageDelay(1).fire), RegEnable(hitData.prefetch, stageDelay(1).fire), - RegEnable(hitData.v, stageDelay(1).fire)) + RegEnable(hitData.v, stageDelay(1).fire), + RegEnable(jmp_bitmap_check, stageDelay(1).fire)) } val spHitPerm = spHitData.perm.getOrElse(0.U.asTypeOf(new PtePermBundle)) val spHitLevel = spHitData.level.getOrElse(0.U) + val spPte = Cat(spHitData.pbmt.asUInt,spHitData.ppn, 0.U(2.W), spHitPerm.asUInt,spHitData.v.asUInt) // Super-page Page Table Entry val check_res = Wire(new PageCacheRespBundle) check_res.l3.map(_.apply(l3Hit.get, l3Pre.get, l3HitPPN.get, l3HitPbmt.get)) check_res.l2.apply(l2Hit, l2Pre, l2HitPPN, l2HitPbmt) check_res.l1.apply(l1Hit, l1Pre, l1HitPPN, l1HitPbmt, ecc = l1eccError) - check_res.l0.apply(l0Hit, l0Pre, l0HitPPN, l0HitPbmt, l0HitPerm, l0eccError, valid = l0HitValid) - check_res.sp.apply(spHit, spPre, spHitData.ppn, spHitData.pbmt, spHitData.n.getOrElse(0.U), spHitPerm, false.B, spHitLevel, spValid) + check_res.l0.apply(l0Hit, l0Pre, l0HitPPN, l0HitPbmt, l0HitPerm, l0eccError, valid = l0HitValid, jmp_bitmap_check = l0JmpBitmapCheck, hitway = l0HitWay, ptes = l0Ptes, cfs = l0cfs) + check_res.sp.apply(spHit, spPre, spHitData.ppn, spHitData.pbmt, spHitData.n.getOrElse(0.U), spHitPerm, false.B, spHitLevel, spValid, spJmpBitmapCheck, spPte) val resp_res = Reg(new PageCacheRespBundle) when (stageCheck(1).fire) { resp_res := check_res } @@ -576,6 +692,15 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with io.resp.bits.toFsm.l1Hit := resp_res.l1.hit && !stage1Hit && !isOnlyStage2 && !stageResp.bits.isHptwReq io.resp.bits.toFsm.ppn := Mux(resp_res.l1.hit, resp_res.l1.ppn, Mux(resp_res.l2.hit, resp_res.l2.ppn, resp_res.l3.getOrElse(0.U.asTypeOf(new PageCachePerPespBundle)).ppn)) io.resp.bits.toFsm.stage1Hit := stage1Hit + if (HasBitmapCheck) { + io.resp.bits.toFsm.bitmapCheck.get.jmp_bitmap_check := resp_res.l0.bitmapCheck.get.jmp_bitmap_check || resp_res.sp.bitmapCheck.get.jmp_bitmap_check + io.resp.bits.toFsm.bitmapCheck.get.toLLPTW := resp_res.l0.bitmapCheck.get.jmp_bitmap_check && (stageResp.bits.req_info.s2xlate === noS2xlate || stageResp.bits.req_info.s2xlate === onlyStage1) + io.resp.bits.toFsm.bitmapCheck.get.hitway := resp_res.l0.bitmapCheck.get.hitway + io.resp.bits.toFsm.bitmapCheck.get.pte := resp_res.sp.bitmapCheck.get.pte + io.resp.bits.toFsm.bitmapCheck.get.ptes := resp_res.l0.bitmapCheck.get.ptes + io.resp.bits.toFsm.bitmapCheck.get.cfs := resp_res.l0.bitmapCheck.get.cfs + io.resp.bits.toFsm.bitmapCheck.get.SPlevel := resp_res.sp.level + } io.resp.bits.isHptwReq := stageResp.bits.isHptwReq if (EnableSv48) { @@ -600,6 +725,15 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with io.resp.bits.toHptw.resp.entry.v := Mux(resp_res.l0.hit, resp_res.l0.v(idx), resp_res.sp.v) io.resp.bits.toHptw.resp.gpf := !io.resp.bits.toHptw.resp.entry.v io.resp.bits.toHptw.resp.gaf := false.B + if (HasBitmapCheck) { + io.resp.bits.toHptw.bitmapCheck.get.jmp_bitmap_check := resp_res.l0.bitmapCheck.get.jmp_bitmap_check || resp_res.sp.bitmapCheck.get.jmp_bitmap_check + io.resp.bits.toHptw.bitmapCheck.get.hitway := 
resp_res.l0.bitmapCheck.get.hitway + io.resp.bits.toHptw.bitmapCheck.get.pte := resp_res.sp.bitmapCheck.get.pte + io.resp.bits.toHptw.bitmapCheck.get.ptes := resp_res.l0.bitmapCheck.get.ptes + io.resp.bits.toHptw.bitmapCheck.get.cfs := resp_res.l0.bitmapCheck.get.cfs + io.resp.bits.toHptw.bitmapCheck.get.fromSP := resp_res.sp.bitmapCheck.get.jmp_bitmap_check + io.resp.bits.toHptw.bitmapCheck.get.SPlevel := resp_res.sp.level + } io.resp.bits.stage1.entry.map(_.tag := stageResp.bits.req_info.vpn(vpnLen - 1, 3)) io.resp.bits.stage1.entry.map(_.asid := Mux(stageResp.bits.req_info.hasS2xlate(), io.csr_dup(0).vsatp.asid, io.csr_dup(0).satp.asid)) // DontCare @@ -654,6 +788,7 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with io.resp.bits.stage1.entry(i).perm.map(_ := Mux(resp_res.l0.hit, resp_res.l0.perm(i), Mux(resp_res.sp.hit, resp_res.sp.perm, 0.U.asTypeOf(new PtePermBundle)))) io.resp.bits.stage1.entry(i).pf := !io.resp.bits.stage1.entry(i).v io.resp.bits.stage1.entry(i).af := false.B + io.resp.bits.stage1.entry(i).cf := l0cfs(i) // L0 level Bitmap Check Failed } io.resp.bits.stage1.pteidx := UIntToOH(idx).asBools io.resp.bits.stage1.not_super := Mux(resp_res.l0.hit, true.B, false.B) @@ -687,6 +822,37 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with val hPBMTE = io.csr.hPBMTE val pbmte = Mux(refill.req_info_dup(0).s2xlate === onlyStage1 || refill.req_info_dup(0).s2xlate === allStage, hPBMTE, mPBMTE) + def Tran2D(flushMask: UInt): Vec[UInt] = { + val tran2D = Wire(Vec(l2tlbParams.l0nSets,UInt(l2tlbParams.l0nWays.W))) + for (i <- 0 until l2tlbParams.l0nSets) { + tran2D(i) := flushMask((i + 1) * l2tlbParams.l0nWays - 1, i * l2tlbParams.l0nWays) + } + tran2D + } + def updateL0BitmapReg(l0BitmapReg: Vec[Vec[Vec[UInt]]], tran2D: Vec[UInt]) = { + for (i <- 0 until l2tlbParams.l0nSets) { + for (j <- 0 until l2tlbParams.l0nWays) { + when (tran2D(i)(j) === 0.U) { + for (k <- 0 until tlbcontiguous) { + l0BitmapReg(i)(j)(k) := 0.U + } + } + } + } + } + def TranVec(flushMask: UInt): Vec[UInt] = { + val vec = Wire(Vec(l2tlbParams.spSize,UInt(1.W))) + for (i <- 0 until l2tlbParams.spSize) { + vec(i) := flushMask(i) + } + vec + } + def updateSpBitmapReg(spBitmapReg: Vec[UInt], vec : Vec[UInt]) = { + for (i <- 0 until l2tlbParams.spSize) { + spBitmapReg(i) := spBitmapReg(i) & vec(i) + } + } + // TODO: handle sfenceLatch outsize if (EnableSv48) { val l3Refill = @@ -795,6 +961,10 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with val l0VictimWayOH = UIntToOH(l0VictimWay).asUInt.suggestName(s"l0_victimWayOH") val l0RfvOH = UIntToOH(Cat(l0RefillIdx, l0VictimWay)).suggestName(s"l0_rfvOH") val l0Wdata = Wire(l0EntryType) + // pass the l0 way info to the later wakeup logic + if (HasBitmapCheck) { + io.l0_way_info.get := l0VictimWayOH + } l0Wdata.gen( vpn = refill.req_info_dup(0).vpn, asid = Mux(refill.req_info_dup(0).s2xlate =/= noS2xlate, io.csr_dup(0).vsatp.asid, io.csr_dup(0).satp.asid), @@ -817,6 +987,7 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with l0v := l0v | l0RfvOH l0g := l0g & ~l0RfvOH | Mux(Cat(memPtes.map(_.perm.g)).andR, l0RfvOH, 0.U) l0h(l0RefillIdx)(l0VictimWay) := refill_h(0) + if (HasBitmapCheck) {updateL0BitmapReg(l0BitmapReg, Tran2D(~l0RfvOH))} for (i <- 0 until l2tlbParams.l0nWays) { l0RefillPerf(i) := i.U === l0VictimWay @@ -850,6 +1021,7 @@ class PtwCache()(implicit p: Parameters) extends XSModule with HasPtwConst with spv := spv | spRfOH spg := spg &
~spRfOH | Mux(memPte(0).perm.g, spRfOH, 0.U) sph(spRefillIdx) := refill_h(0) + if (HasBitmapCheck) {updateSpBitmapReg(spBitmapReg, TranVec(~spRfOH))} for (i <- 0 until l2tlbParams.spSize) { spRefillPerf(i) := i.U === spRefillIdx diff --git a/src/main/scala/xiangshan/cache/mmu/PageTableWalker.scala b/src/main/scala/xiangshan/cache/mmu/PageTableWalker.scala index b3a07dd9679..8de8eee98c3 100644 --- a/src/main/scala/xiangshan/cache/mmu/PageTableWalker.scala +++ b/src/main/scala/xiangshan/cache/mmu/PageTableWalker.scala @@ -1,6 +1,8 @@ /*************************************************************************************** -* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences +* Copyright (c) 2021-2025 Beijing Institute of Open Source Chip (BOSC) +* Copyright (c) 2020-2024 Institute of Computing Technology, Chinese Academy of Sciences * Copyright (c) 2020-2021 Peng Cheng Laboratory +* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences * * XiangShan is licensed under Mulan PSL v2. * You can use this software according to the terms and conditions of the Mulan PSL v2. @@ -46,6 +48,12 @@ class PTWIO()(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst { val ppn = UInt(ptePPNLen.W) val stage1Hit = Bool() val stage1 = new PtwMergeResp + val bitmapCheck = Option.when(HasBitmapCheck)(new Bundle { + val jmp_bitmap_check = Bool() // super page hit in PtwCache ptw, but needs bitmap check + val pte = UInt(XLEN.W) // Page Table Entry + val cfs = Vec(tlbcontiguous, Bool()) // Bitmap Check Failed Vector + val SPlevel = UInt(log2Up(Level).W) + }) })) val resp = DecoupledIO(new Bundle { val source = UInt(bSourceWidth.W) @@ -82,6 +90,10 @@ class PTWIO()(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst { val req_info = new L2TlbInnerBundle() val level = UInt(log2Up(Level + 1).W) }) + val bitmap = Option.when(HasBitmapCheck)(new Bundle { + val req = DecoupledIO(new bitmapReqBundle()) + val resp = Flipped(DecoupledIO(new bitmapRespBundle())) + }) } class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPerfEvents { @@ -92,6 +104,11 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe val enableS2xlate = req_s2xlate =/= noS2xlate val onlyS1xlate = req_s2xlate === onlyStage1 val onlyS2xlate = req_s2xlate === onlyStage2 + + // mbmc: bitmap csr + val mbmc = io.csr.mbmc + val bitmap_enable = (if (HasBitmapCheck) true.B else false.B) && mbmc.BME === 1.U && mbmc.CMODE === 0.U + val satp = Wire(new TlbSatpBundle()) when (io.req.fire) { satp := Mux(io.req.bits.req_info.s2xlate =/= noS2xlate, io.csr.vsatp, io.csr.satp) @@ -112,7 +129,10 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe val levelNext = level - 1.U val l3Hit = Reg(Bool()) val l2Hit = Reg(Bool()) - val pte = mem.resp.bits.asTypeOf(new PteBundle()) + val jmp_bitmap_check_w = if (HasBitmapCheck) { io.req.bits.bitmapCheck.get.jmp_bitmap_check && io.req.bits.req_info.s2xlate =/= onlyStage2 } else { false.B } + val jmp_bitmap_check_r = if (HasBitmapCheck) { RegEnable(jmp_bitmap_check_w, io.req.fire) } else { false.B } + val cache_pte = Option.when(HasBitmapCheck)(RegEnable(io.req.bits.bitmapCheck.get.pte.asTypeOf(new PteBundle().cloneType), io.req.fire)) + val pte = if (HasBitmapCheck) { Mux(jmp_bitmap_check_r, cache_pte.get, io.mem.resp.bits.asTypeOf(new PteBundle().cloneType)) } else { mem.resp.bits.asTypeOf(new PteBundle()) } // s/w register val s_pmp_check = RegInit(true.B)
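The hunks that follow extend the walker's send/wait bookkeeping with a bitmap pair: `s_bitmap_check` ("request still to send", idle-high) and `w_bitmap_resp` ("response still outstanding", idle-high), plus `whether_need_bitmap_check` and `bitmap_checkfailed`. A minimal plain-Scala paraphrase of that flag protocol (names mirror the RTL; the sequencing is a sketch, not the exact FSM):

```scala
// Plain-Scala paraphrase of the s_*/w_* flag convention used for the bitmap
// sub-request: flags reset to true ("nothing to do"), are lowered to start a
// step, and are raised again when the step completes.
final class BitmapFlags {
  var sBitmapCheck      = true  // true = no bitmap request left to send
  var wBitmapResp       = true  // true = no bitmap response outstanding
  var bitmapCheckFailed = false

  def reqValid: Boolean  = !sBitmapCheck // mirrors io.bitmap.req.valid := !s_bitmap_check
  def respReady: Boolean = !wBitmapResp  // mirrors io.bitmap.resp.ready := !w_bitmap_resp

  def scheduleCheck(): Unit = sBitmapCheck = false // a leaf PTE needs checking
  def onReqFire(): Unit = { sBitmapCheck = true; wBitmapResp = false }
  def onRespFire(cf: Boolean): Unit = { wBitmapResp = true; bitmapCheckFailed = cf }
}

object BitmapFlagsDemo extends App {
  val f = new BitmapFlags
  f.scheduleCheck();       assert(f.reqValid)
  f.onReqFire();           assert(!f.reqValid && f.respReady)
  f.onRespFire(cf = true); assert(f.bitmapCheckFailed)
}
```

A failed check is then folded into the access-fault path (`bitmap_checkfailed` ORed into `ppn_af`), as the next hunks show.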
@@ -126,6 +146,11 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe // for updating "level" val mem_addr_update = RegInit(false.B) + val s_bitmap_check = RegInit(true.B) + val w_bitmap_resp = RegInit(true.B) + val whether_need_bitmap_check = RegInit(false.B) + val bitmap_checkfailed = RegInit(false.B) + val idle = RegInit(true.B) val finish = WireInit(false.B) val sent_to_pmp = idle === false.B && (s_pmp_check === false.B || mem_addr_update) && !finish @@ -140,7 +165,14 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe val stage1 = RegEnable(io.req.bits.stage1, io.req.fire) val hptw_resp_stage2 = Reg(Bool()) - val ppn_af = Mux(enableS2xlate, Mux(onlyS1xlate, pte.isAf(), false.B), pte.isAf()) // In two-stage address translation, stage 1 ppn is a vpn for host, so don't need to check ppn_high + // use access fault to represent bitmap check failure + val pte_isAf = Mux(bitmap_enable, pte.isAf() || bitmap_checkfailed, pte.isAf()) + val ppn_af = if (HasBitmapCheck) { + Mux(enableS2xlate, Mux(onlyS1xlate, pte_isAf, false.B), pte_isAf) // In two-stage address translation, stage 1 ppn is a vpn for host, so don't need to check ppn_high + } else { + Mux(enableS2xlate, Mux(onlyS1xlate, pte.isAf(), false.B), pte.isAf()) // In two-stage address translation, stage 1 ppn is a vpn for host, so don't need to check ppn_high + } + val find_pte = pte.isLeaf() || ppn_af || pageFault val to_find_pte = level === 1.U && find_pte === false.B val source = RegEnable(io.req.bits.req_info.source, io.req.fire) @@ -206,7 +238,7 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe io.req.ready := idle val ptw_resp = Wire(new PtwMergeResp) - ptw_resp.apply(Mux(pte_valid, pageFault && !accessFault, false.B), accessFault || (ppn_af && !(pte_valid && (pageFault || guestFault))), Mux(accessFault, af_level, Mux(guestFault, gpf_level, level)), Mux(pte_valid, pte, fake_pte), vpn, satp.asid, hgatp.vmid, vpn(sectortlbwidth - 1, 0), not_super = false, not_merge = false) + ptw_resp.apply(Mux(pte_valid, pageFault && !accessFault, false.B), accessFault || (ppn_af && !(pte_valid && (pageFault || guestFault))), Mux(accessFault, af_level, Mux(guestFault, gpf_level, level)), Mux(pte_valid, pte, fake_pte), vpn, satp.asid, hgatp.vmid, vpn(sectortlbwidth - 1, 0), not_super = false, not_merge = false, bitmap_checkfailed.asBool) val normal_resp = idle === false.B && mem_addr_update && !need_last_s2xlate && (guestFault || (w_mem_resp && find_pte) || (s_pmp_check && accessFault) || onlyS2xlate ) val stageHit_resp = idle === false.B && hptw_resp_stage2 @@ -221,12 +253,29 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe io.llptw.bits.req_info.vpn := vpn io.llptw.bits.req_info.s2xlate := req_s2xlate io.llptw.bits.ppn := DontCare + if (HasBitmapCheck) { + io.llptw.bits.bitmapCheck.get.jmp_bitmap_check := DontCare + io.llptw.bits.bitmapCheck.get.ptes := DontCare + io.llptw.bits.bitmapCheck.get.cfs := DontCare + io.llptw.bits.bitmapCheck.get.hitway := DontCare + } io.pmp.req.valid := DontCare // samecycle, do not use valid io.pmp.req.bits.addr := Mux(s2xlate, hpaddr, mem_addr) io.pmp.req.bits.size := 3.U // TODO: fix it io.pmp.req.bits.cmd := TlbCmd.read + if (HasBitmapCheck) { + val cache_level = RegEnable(io.req.bits.bitmapCheck.get.SPlevel, io.req.fire) + io.bitmap.get.req.valid := !s_bitmap_check + io.bitmap.get.req.bits.bmppn := pte.ppn + io.bitmap.get.req.bits.id := FsmReqID.U(bMemID.W) + io.bitmap.get.req.bits.vpn :=
vpn + io.bitmap.get.req.bits.level := Mux(jmp_bitmap_check_r, cache_level, level) + io.bitmap.get.req.bits.way_info := DontCare + io.bitmap.get.req.bits.hptw_bypassed := false.B + io.bitmap.get.resp.ready := !w_bitmap_resp + } mem.req.valid := s_mem_req === false.B && !mem.mask && !accessFault && s_pmp_check mem.req.bits.addr := Mux(s2xlate, hpaddr, mem_addr) mem.req.bits.id := FsmReqID.U(bMemID.W) @@ -242,7 +291,22 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe io.hptw.req.bits.gvpn := get_pn(gpaddr) io.hptw.req.bits.source := source - when (io.req.fire && io.req.bits.stage1Hit){ + if (HasBitmapCheck) { + when (io.req.fire && jmp_bitmap_check_w) { + idle := false.B + req_s2xlate := io.req.bits.req_info.s2xlate + vpn := io.req.bits.req_info.vpn + s_bitmap_check := false.B + need_last_s2xlate := false.B + hptw_pageFault := false.B + hptw_accessFault := false.B + level := io.req.bits.bitmapCheck.get.SPlevel + pte_valid := true.B + accessFault := false.B + } + } + + when (io.req.fire && io.req.bits.stage1Hit && (if (HasBitmapCheck) !jmp_bitmap_check_w else true.B)) { idle := false.B req_s2xlate := io.req.bits.req_info.s2xlate s_last_hptw_req := false.B @@ -257,7 +321,7 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe idle := true.B } - when (io.req.fire && !io.req.bits.stage1Hit){ + when (io.req.fire && !io.req.bits.stage1Hit && (if (HasBitmapCheck) !jmp_bitmap_check_w else true.B)) { val req = io.req.bits val gvpn_wire = Wire(UInt(ptePPNLen.W)) if (EnableSv48) { @@ -374,6 +438,12 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe w_last_hptw_resp := true.B mem_addr_update := true.B need_last_s2xlate := false.B + if (HasBitmapCheck) { + s_bitmap_check := true.B + w_bitmap_resp := true.B + whether_need_bitmap_check := false.B + bitmap_checkfailed := false.B + } } when(guestFault && idle === false.B){ @@ -387,6 +457,12 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe w_last_hptw_resp := true.B mem_addr_update := true.B need_last_s2xlate := false.B + if (HasBitmapCheck) { + s_bitmap_check := true.B + w_bitmap_resp := true.B + whether_need_bitmap_check := false.B + bitmap_checkfailed := false.B + } } when (mem.req.fire){ @@ -398,10 +474,19 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe w_mem_resp := true.B af_level := af_level - 1.U s_llptw_req := false.B - mem_addr_update := true.B gpf_level := Mux(mode === Sv39 && !pte_valid && !(l3Hit || l2Hit), gpf_level - 2.U, gpf_level - 1.U) pte_valid := true.B update_full_gvpn_mem_resp := true.B + if (HasBitmapCheck) { + when (bitmap_enable) { + whether_need_bitmap_check := true.B + } .otherwise { + mem_addr_update := true.B + whether_need_bitmap_check := false.B + } + } else { + mem_addr_update := true.B + } } when(update_full_gvpn_mem_resp) { @@ -409,6 +494,28 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe full_gvpn_reg := pte.getPPN() } + if (HasBitmapCheck) { + when (whether_need_bitmap_check) { + when (bitmap_enable && (!enableS2xlate || onlyS1xlate) && pte.isLeaf()) { + s_bitmap_check := false.B + whether_need_bitmap_check := false.B + } .otherwise { + mem_addr_update := true.B + whether_need_bitmap_check := false.B + } + } + // bitmapcheck + when (io.bitmap.get.req.fire) { + s_bitmap_check := true.B + w_bitmap_resp := false.B + } + when (io.bitmap.get.resp.fire) { + w_bitmap_resp := true.B + mem_addr_update := true.B + bitmap_checkfailed 
:= io.bitmap.get.resp.bits.cf + } + } + when(mem_addr_update){ when(level >= 2.U && !onlyS2xlate && !(guestFault || find_pte || accessFault)) { level := levelNext @@ -457,6 +564,12 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe w_hptw_resp := true.B s_last_hptw_req := true.B w_last_hptw_resp := true.B + if (HasBitmapCheck) { + s_bitmap_check := true.B + w_bitmap_resp := true.B + whether_need_bitmap_check := false.B + bitmap_checkfailed := false.B + } } @@ -496,6 +609,12 @@ class PTW()(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe class LLPTWInBundle(implicit p: Parameters) extends XSBundle with HasPtwConst { val req_info = Output(new L2TlbInnerBundle()) val ppn = Output(UInt(ptePPNLen.W)) + val bitmapCheck = Option.when(HasBitmapCheck)(new Bundle { + val jmp_bitmap_check = Bool() // find pte in l0 or sp, but need bitmap check + val ptes = Vec(tlbcontiguous, UInt(XLEN.W)) // Page Table Entry Vector + val cfs = Vec(tlbcontiguous, Bool()) // Bitmap Check Failed Vector + val hitway = UInt(l2tlbParams.l0nWays.W) + }) } class LLPTWIO(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst { @@ -506,6 +625,11 @@ class LLPTWIO(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst { val h_resp = Output(new HptwResp) val first_s2xlate_fault = Output(Bool()) // Whether the first stage 2 translation occurs pf/af val af = Output(Bool()) + val bitmapCheck = Option.when(HasBitmapCheck)(new Bundle { + val jmp_bitmap_check = Bool() // find pte in l0 or sp, but need bitmap check + val ptes = Vec(tlbcontiguous, UInt(XLEN.W)) // Page Table Entry Vector + val cfs = Vec(tlbcontiguous, Bool()) // Bitmap Check Failed Vector + }) }) val mem = new Bundle { val req = DecoupledIO(new L2TlbMemReqBundle()) @@ -535,6 +659,12 @@ class LLPTWIO(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst { val h_resp = Output(new HptwResp) })) } + val bitmap = Option.when(HasBitmapCheck)(new Bundle { + val req = DecoupledIO(new bitmapReqBundle()) + val resp = Flipped(DecoupledIO(new bitmapRespBundle())) + }) + + val l0_way_info = Option.when(HasBitmapCheck)(Input(UInt(l2tlbParams.l0nWays.W))) } class LLPTWEntry(implicit p: Parameters) extends XSBundle with HasPtwConst { @@ -544,6 +674,12 @@ class LLPTWEntry(implicit p: Parameters) extends XSBundle with HasPtwConst { val af = Bool() val hptw_resp = new HptwResp() val first_s2xlate_fault = Output(Bool()) + val cf = Bool() + val from_l0 = Bool() + val way_info = UInt(l2tlbParams.l0nWays.W) + val jmp_bitmap_check = Bool() + val ptes = Vec(tlbcontiguous, UInt(XLEN.W)) + val cfs = Vec(tlbcontiguous, Bool()) } @@ -553,9 +689,13 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe val satp = Mux(enableS2xlate, io.csr.vsatp, io.csr.satp) val s1Pbmte = Mux(enableS2xlate, io.csr.hPBMTE, io.csr.mPBMTE) + // mbmc:bitmap csr + val mbmc = io.csr.mbmc + val bitmap_enable = (if (HasBitmapCheck) true.B else false.B) && mbmc.BME === 1.U && mbmc.CMODE === 0.U + val flush = io.sfence.valid || io.csr.satp.changed || io.csr.vsatp.changed || io.csr.hgatp.changed val entries = RegInit(VecInit(Seq.fill(l2tlbParams.llptwsize)(0.U.asTypeOf(new LLPTWEntry())))) - val state_idle :: state_hptw_req :: state_hptw_resp :: state_addr_check :: state_mem_req :: state_mem_waiting :: state_mem_out :: state_last_hptw_req :: state_last_hptw_resp :: state_cache :: Nil = Enum(10) + val state_idle :: state_hptw_req :: state_hptw_resp :: state_addr_check :: state_mem_req :: state_mem_waiting :: 
state_mem_out :: state_last_hptw_req :: state_last_hptw_resp :: state_cache :: state_bitmap_check :: state_bitmap_resp :: Nil = Enum(12) val state = RegInit(VecInit(Seq.fill(l2tlbParams.llptwsize)(state_idle))) val is_emptys = state.map(_ === state_idle) @@ -567,6 +707,8 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe val is_last_hptw_req = state.map(_ === state_last_hptw_req) val is_hptw_resp = state.map(_ === state_hptw_resp) val is_last_hptw_resp = state.map(_ === state_last_hptw_resp) + val is_bitmap_req = state.map(_ === state_bitmap_check) + val is_bitmap_resp = state.map(_ === state_bitmap_resp) val full = !ParallelOR(is_emptys).asBool val enq_ptr = ParallelPriorityEncoder(is_emptys) @@ -590,6 +732,21 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe hyper_arb2.io.in(i).valid := is_last_hptw_req(i) && !(Cat(is_hptw_resp).orR) && !(Cat(is_last_hptw_resp).orR) } + + val bitmap_arb = Option.when(HasBitmapCheck)(Module(new RRArbiter(new bitmapReqBundle(), l2tlbParams.llptwsize))) + val way_info = Option.when(HasBitmapCheck)(Wire(Vec(l2tlbParams.llptwsize, UInt(l2tlbParams.l0nWays.W)))) + if (HasBitmapCheck) { + for (i <- 0 until l2tlbParams.llptwsize) { + bitmap_arb.get.io.in(i).valid := is_bitmap_req(i) + bitmap_arb.get.io.in(i).bits.bmppn := entries(i).ppn + bitmap_arb.get.io.in(i).bits.vpn := entries(i).req_info.vpn + bitmap_arb.get.io.in(i).bits.id := i.U + bitmap_arb.get.io.in(i).bits.level := 0.U // last level + bitmap_arb.get.io.in(i).bits.way_info := Mux(entries(i).from_l0, entries(i).way_info, way_info.get(i)) + bitmap_arb.get.io.in(i).bits.hptw_bypassed := false.B + } + } + val cache_ptr = ParallelMux(is_cache, (0 until l2tlbParams.llptwsize).map(_.U(log2Up(l2tlbParams.llptwsize).W))) // duplicate req @@ -601,12 +758,15 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe val dup_req_fire = mem_arb.io.out.fire && dup(io.in.bits.req_info.vpn, mem_arb.io.out.bits.req_info.vpn) && io.in.bits.req_info.s2xlate === mem_arb.io.out.bits.req_info.s2xlate // dup with the req fire entry val dup_vec_wait = dup_vec.zip(is_waiting).map{case (d, w) => d && w} // dup with "mem_waiting" entries, sending mem req already val dup_vec_having = dup_vec.zipWithIndex.map{case (d, i) => d && is_having(i)} // dup with the "mem_out" entry recv the data just now + val dup_vec_bitmap = dup_vec.zipWithIndex.map{case (d, i) => d && (is_bitmap_req(i) || is_bitmap_resp(i))} val dup_vec_last_hptw = dup_vec.zipWithIndex.map{case (d, i) => d && (is_last_hptw_req(i) || is_last_hptw_resp(i))} val wait_id = Mux(dup_req_fire, mem_arb.io.chosen, ParallelMux(dup_vec_wait zip entries.map(_.wait_id))) val dup_wait_resp = io.mem.resp.fire && VecInit(dup_vec_wait)(io.mem.resp.bits.id) && !io.mem.flush_latch(io.mem.resp.bits.id) // dup with the entry that data coming next cycle val to_wait = Cat(dup_vec_wait).orR || dup_req_fire - val to_mem_out = dup_wait_resp && ((entries(io.mem.resp.bits.id).req_info.s2xlate === noS2xlate) || (entries(io.mem.resp.bits.id).req_info.s2xlate === onlyStage1)) - val to_cache = Cat(dup_vec_having).orR || Cat(dup_vec_last_hptw).orR + val to_mem_out = dup_wait_resp && ((entries(io.mem.resp.bits.id).req_info.s2xlate === noS2xlate) || (entries(io.mem.resp.bits.id).req_info.s2xlate === onlyStage1)) && !bitmap_enable + val to_bitmap_req = (if (HasBitmapCheck) true.B else false.B) && dup_wait_resp && ((entries(io.mem.resp.bits.id).req_info.s2xlate === noS2xlate) || 
(entries(io.mem.resp.bits.id).req_info.s2xlate === onlyStage1)) && bitmap_enable + val to_cache = if (HasBitmapCheck) Cat(dup_vec_bitmap).orR || Cat(dup_vec_having).orR || Cat(dup_vec_last_hptw).orR + else Cat(dup_vec_having).orR || Cat(dup_vec_last_hptw).orR val to_hptw_req = io.in.bits.req_info.s2xlate === allStage val to_last_hptw_req = dup_wait_resp && entries(io.mem.resp.bits.id).req_info.s2xlate === allStage val last_hptw_req_id = io.mem.resp.bits.id @@ -620,13 +780,14 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe val mem_resp_hit = RegInit(VecInit(Seq.fill(l2tlbParams.llptwsize)(false.B))) val enq_state_normal = MuxCase(state_addr_check, Seq( to_mem_out -> state_mem_out, // same to the blew, but the mem resp now + to_bitmap_req -> state_bitmap_check, to_last_hptw_req -> state_last_hptw_req, to_wait -> state_mem_waiting, to_cache -> state_cache, to_hptw_req -> state_hptw_req )) val enq_state = Mux(from_pre(io.in.bits.req_info.source) && enq_state_normal =/= state_addr_check, state_idle, enq_state_normal) - when (io.in.fire) { + when (io.in.fire && (if (HasBitmapCheck) !io.in.bits.bitmapCheck.get.jmp_bitmap_check else true.B)) { // if prefetch req does not need mem access, just give it up. // so there will be at most 1 + FilterSize entries that needs re-access page cache // so 2 + FilterSize is enough to avoid dead-lock @@ -635,13 +796,40 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe entries(enq_ptr).ppn := Mux(to_last_hptw_req, last_hptw_req_ppn, io.in.bits.ppn) entries(enq_ptr).wait_id := Mux(to_wait, wait_id, enq_ptr) entries(enq_ptr).af := false.B + if (HasBitmapCheck) { + entries(enq_ptr).cf := false.B + entries(enq_ptr).from_l0 := false.B + entries(enq_ptr).way_info := 0.U + entries(enq_ptr).jmp_bitmap_check := false.B + for (i <- 0 until tlbcontiguous) { + entries(enq_ptr).ptes(i) := 0.U + } + entries(enq_ptr).cfs := io.in.bits.bitmapCheck.get.cfs + } entries(enq_ptr).hptw_resp := Mux(to_last_hptw_req, entries(last_hptw_req_id).hptw_resp, Mux(to_wait, entries(wait_id).hptw_resp, entries(enq_ptr).hptw_resp)) entries(enq_ptr).first_s2xlate_fault := false.B - mem_resp_hit(enq_ptr) := to_mem_out || to_last_hptw_req + mem_resp_hit(enq_ptr) := to_bitmap_req || to_mem_out || to_last_hptw_req + } + + if (HasBitmapCheck) { + when (io.in.bits.bitmapCheck.get.jmp_bitmap_check && io.in.fire) { + state(enq_ptr) := state_bitmap_check + entries(enq_ptr).req_info := io.in.bits.req_info + entries(enq_ptr).ppn := io.in.bits.bitmapCheck.get.ptes(io.in.bits.req_info.vpn(sectortlbwidth - 1, 0)).asTypeOf(new PteBundle().cloneType).ppn + entries(enq_ptr).wait_id := enq_ptr + entries(enq_ptr).af := false.B + entries(enq_ptr).cf := false.B + entries(enq_ptr).from_l0 := true.B + entries(enq_ptr).way_info := io.in.bits.bitmapCheck.get.hitway + entries(enq_ptr).jmp_bitmap_check := io.in.bits.bitmapCheck.get.jmp_bitmap_check + entries(enq_ptr).ptes := io.in.bits.bitmapCheck.get.ptes + entries(enq_ptr).cfs := io.in.bits.bitmapCheck.get.cfs + mem_resp_hit(enq_ptr) := false.B + } } val enq_ptr_reg = RegNext(enq_ptr) - val need_addr_check = GatedValidRegNext(enq_state === state_addr_check && io.in.fire && !flush) + val need_addr_check = GatedValidRegNext(enq_state === state_addr_check && io.in.fire && !flush && (if (HasBitmapCheck) !io.in.bits.bitmapCheck.get.jmp_bitmap_check else true.B)) val hasHptwResp = ParallelOR(state.map(_ === state_hptw_resp)).asBool val hptw_resp_ptr_reg = RegNext(io.hptw.resp.bits.id) @@ -669,6 +857,7 @@ 
class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe when (mem_arb.io.out.fire) { for (i <- state.indices) { when (state(i) =/= state_idle && state(i) =/= state_mem_out && state(i) =/= state_last_hptw_req && state(i) =/= state_last_hptw_resp + && (if (HasBitmapCheck) state(i) =/= state_bitmap_check && state(i) =/= state_bitmap_resp else true.B) && entries(i).req_info.s2xlate === mem_arb.io.out.bits.req_info.s2xlate && dup(entries(i).req_info.vpn, mem_arb.io.out.bits.req_info.vpn)) { // NOTE: "dup enq set state to mem_wait" -> "sending req set other dup entries to mem_wait" @@ -685,7 +874,7 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe val req_hpaddr = MakeAddr(entries(i).hptw_resp.genPPNS2(get_pn(req_paddr)), getVpnn(entries(i).req_info.vpn, 0)) val index = Mux(entries(i).req_info.s2xlate === allStage, req_hpaddr, req_paddr)(log2Up(l2tlbParams.blockBytes)-1, log2Up(XLEN/8)) state(i) := Mux(entries(i).req_info.s2xlate === allStage && !(ptes(index).isPf(0.U, s1Pbmte) || !ptes(index).isLeaf() || ptes(index).isAf() || ptes(index).isStage1Gpf(io.csr.vsatp.mode)) - , state_last_hptw_req, state_mem_out) + , state_last_hptw_req, Mux(bitmap_enable, state_bitmap_check, state_mem_out)) mem_resp_hit(i) := true.B entries(i).ppn := ptes(index).getPPN() // for last stage 2 translation entries(i).hptw_resp.gpf := Mux(entries(i).req_info.s2xlate === allStage, ptes(index).isStage1Gpf(io.csr.vsatp.mode), false.B) @@ -693,6 +882,12 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe } } + if (HasBitmapCheck) { + for (i <- 0 until l2tlbParams.llptwsize) { + way_info.get(i) := DataHoldBypass(io.l0_way_info.get, mem_resp_hit(i)) + } + } + when (hyper_arb1.io.out.fire) { for (i <- state.indices) { when (state(i) === state_hptw_req && entries(i).ppn === hyper_arb1.io.out.bits.ppn && entries(i).req_info.s2xlate === allStage && hyper_arb1.io.chosen === i.U) { @@ -711,6 +906,27 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe } } + if (HasBitmapCheck) { + when (bitmap_arb.get.io.out.fire) { + for (i <- state.indices) { + when (is_bitmap_req(i) && bitmap_arb.get.io.out.bits.bmppn === entries(i).ppn(ppnLen - 1, 0)) { + state(i) := state_bitmap_resp + entries(i).wait_id := bitmap_arb.get.io.chosen + } + } + } + + when (io.bitmap.get.resp.fire) { + for (i <- state.indices) { + when (is_bitmap_resp(i) && io.bitmap.get.resp.bits.id === entries(i).wait_id) { + entries(i).cfs := io.bitmap.get.resp.bits.cfs + entries(i).cf := io.bitmap.get.resp.bits.cf + state(i) := state_mem_out + } + } + } + } + when (io.hptw.resp.fire) { for (i <- state.indices) { when (state(i) === state_hptw_resp && io.hptw.resp.bits.id === entries(i).wait_id && io.hptw.resp.bits.h_resp.entry.tag === entries(i).ppn) { @@ -756,7 +972,15 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe io.out.valid := ParallelOR(is_having).asBool io.out.bits.req_info := entries(mem_ptr).req_info io.out.bits.id := mem_ptr - io.out.bits.af := entries(mem_ptr).af + if (HasBitmapCheck) { + io.out.bits.af := Mux(bitmap_enable, entries(mem_ptr).af || entries(mem_ptr).cf, entries(mem_ptr).af) + io.out.bits.bitmapCheck.get.jmp_bitmap_check := entries(mem_ptr).jmp_bitmap_check + io.out.bits.bitmapCheck.get.ptes := entries(mem_ptr).ptes + io.out.bits.bitmapCheck.get.cfs := entries(mem_ptr).cfs + } else { + io.out.bits.af := entries(mem_ptr).af + } + io.out.bits.h_resp := entries(mem_ptr).hptw_resp 
@@ -756,7 +972,15 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe
   io.out.valid := ParallelOR(is_having).asBool
   io.out.bits.req_info := entries(mem_ptr).req_info
   io.out.bits.id := mem_ptr
-  io.out.bits.af := entries(mem_ptr).af
+  if (HasBitmapCheck) {
+    io.out.bits.af := Mux(bitmap_enable, entries(mem_ptr).af || entries(mem_ptr).cf, entries(mem_ptr).af)
+    io.out.bits.bitmapCheck.get.jmp_bitmap_check := entries(mem_ptr).jmp_bitmap_check
+    io.out.bits.bitmapCheck.get.ptes := entries(mem_ptr).ptes
+    io.out.bits.bitmapCheck.get.cfs := entries(mem_ptr).cfs
+  } else {
+    io.out.bits.af := entries(mem_ptr).af
+  }
+
   io.out.bits.h_resp := entries(mem_ptr).hptw_resp
   io.out.bits.first_s2xlate_fault := entries(mem_ptr).first_s2xlate_fault

@@ -799,6 +1023,19 @@ class LLPTW(implicit p: Parameters) extends XSModule with HasPtwConst with HasPe
   io.cache.valid := Cat(is_cache).orR
   io.cache.bits := ParallelMux(is_cache, entries.map(_.req_info))

+  val has_bitmap_resp = ParallelOR(is_bitmap_resp).asBool
+  if (HasBitmapCheck) {
+    io.bitmap.get.req.valid := bitmap_arb.get.io.out.valid && !flush
+    io.bitmap.get.req.bits.bmppn := bitmap_arb.get.io.out.bits.bmppn
+    io.bitmap.get.req.bits.id := bitmap_arb.get.io.chosen
+    io.bitmap.get.req.bits.vpn := bitmap_arb.get.io.out.bits.vpn
+    io.bitmap.get.req.bits.level := 0.U
+    io.bitmap.get.req.bits.way_info := bitmap_arb.get.io.out.bits.way_info
+    io.bitmap.get.req.bits.hptw_bypassed := bitmap_arb.get.io.out.bits.hptw_bypassed
+    bitmap_arb.get.io.out.ready := io.bitmap.get.req.ready
+    io.bitmap.get.resp.ready := has_bitmap_resp
+  }
+
   XSPerfAccumulate("llptw_in_count", io.in.fire)
   XSPerfAccumulate("llptw_in_block", io.in.valid && !io.in.ready)
   for (i <- 0 until 7) {
@@ -838,6 +1075,15 @@ class HPTWIO()(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst
     val l2Hit = Bool()
     val l1Hit = Bool()
     val bypassed = Bool() // if bypass, don't refill
+    val bitmapCheck = Option.when(HasBitmapCheck)(new Bundle {
+      val jmp_bitmap_check = Bool() // find pte in l0 or sp, but need bitmap check
+      val pte = UInt(XLEN.W) // Page Table Entry
+      val ptes = Vec(tlbcontiguous, UInt(XLEN.W)) // Page Table Entry Vector
+      val cfs = Vec(tlbcontiguous, Bool()) // Bitmap Check Failed Vector
+      val hitway = UInt(l2tlbParams.l0nWays.W)
+      val fromSP = Bool()
+      val SPlevel = UInt(log2Up(Level).W)
+    })
   }))
   val resp = DecoupledIO(new Bundle {
     val source = UInt(bSourceWidth.W)
@@ -858,6 +1104,12 @@ class HPTWIO()(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst
     val req = ValidIO(new PMPReqBundle())
     val resp = Flipped(new PMPRespBundle())
   }
+  val bitmap = Option.when(HasBitmapCheck)(new Bundle {
+    val req = DecoupledIO(new bitmapReqBundle())
+    val resp = Flipped(DecoupledIO(new bitmapRespBundle()))
+  })
+
+  val l0_way_info = Option.when(HasBitmapCheck)(Input(UInt(l2tlbParams.l0nWays.W)))
 }

 class HPTW()(implicit p: Parameters) extends XSModule with HasPtwConst {
@@ -868,6 +1120,10 @@ class HPTW()(implicit p: Parameters) extends XSModule with HasPtwConst {
   val flush = sfence.valid || hgatp.changed || io.csr.satp.changed || io.csr.vsatp.changed
   val mode = hgatp.mode

+  // mbmc: bitmap csr
+  val mbmc = io.csr.mbmc
+  val bitmap_enable = (if (HasBitmapCheck) true.B else false.B) && mbmc.BME === 1.U && mbmc.CMODE === 0.U
+
   val level = RegInit(3.U(log2Up(Level + 1).W))
   val af_level = RegInit(3.U(log2Up(Level + 1).W)) // access fault return this level
   val gpaddr = Reg(UInt(GPAddrBits.W))
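HPTWIO above gates every bitmap-related port behind Option.when(HasBitmapCheck), so those ports only elaborate when the feature is configured in, and every .get on them is guarded by the same Scala-level condition. A minimal sketch of the idiom (hasCheck and the port names are illustrative, not the real bundles):

import chisel3._
import chisel3.util._

// Sketch: an IO field wrapped in Option.when exists only when the
// compile-time parameter is set; .get is safe inside if (hasCheck).
class OptionalPort(hasCheck: Boolean) extends Module {
  val io = IO(new Bundle {
    val in    = Input(UInt(8.W))
    val out   = Output(UInt(8.W))
    val check = Option.when(hasCheck)(Input(Bool()))
  })
  if (hasCheck) {
    // The port exists only in this configuration, so .get cannot fail here.
    io.out := Mux(io.check.get, io.in, 0.U)
  } else {
    io.out := io.in
  }
}

The payoff of this style, visible throughout the hunks above, is that a build without the feature pays no area or wiring cost, and any accidental use of a feature-only port in the wrong configuration fails at elaboration rather than silently misbehaving.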
@@ -879,7 +1135,10 @@ class HPTW()(implicit p: Parameters) extends XSModule with HasPtwConst {
   val l1Hit = Reg(Bool())
   val bypassed = Reg(Bool())
   // val pte = io.mem.resp.bits.MergeRespToPte()
-  val pte = io.mem.resp.bits.asTypeOf(new PteBundle().cloneType)
+  val jmp_bitmap_check = if (HasBitmapCheck) RegEnable(io.req.bits.bitmapCheck.get.jmp_bitmap_check, io.req.fire) else false.B
+  val fromSP = if (HasBitmapCheck) RegEnable(io.req.bits.bitmapCheck.get.fromSP, io.req.fire) else false.B
+  val cache_pte = Option.when(HasBitmapCheck)(RegEnable(Mux(io.req.bits.bitmapCheck.get.fromSP, io.req.bits.bitmapCheck.get.pte.asTypeOf(new PteBundle().cloneType), io.req.bits.bitmapCheck.get.ptes(io.req.bits.gvpn(sectortlbwidth - 1, 0)).asTypeOf(new PteBundle().cloneType)), io.req.fire))
+  val pte = if (HasBitmapCheck) Mux(jmp_bitmap_check, cache_pte.get, io.mem.resp.bits.asTypeOf(new PteBundle().cloneType)) else io.mem.resp.bits.asTypeOf(new PteBundle().cloneType)
   val ppn_l3 = Mux(l3Hit, req_ppn, pte.ppn)
   val ppn_l2 = Mux(l2Hit, req_ppn, pte.ppn)
   val ppn_l1 = Mux(l1Hit, req_ppn, pte.ppn)
@@ -910,12 +1169,21 @@ class HPTW()(implicit p: Parameters) extends XSModule with HasPtwConst {
   val idle = RegInit(true.B)
   val mem_addr_update = RegInit(false.B)
   val finish = WireInit(false.B)
+  val s_bitmap_check = RegInit(true.B)
+  val w_bitmap_resp = RegInit(true.B)
+  val whether_need_bitmap_check = RegInit(false.B)
+  val bitmap_checkfailed = RegInit(false.B)

   val sent_to_pmp = !idle && (!s_pmp_check || mem_addr_update) && !finish
   val pageFault = pte.isGpf(level, mpbmte) || (!pte.isLeaf() && level === 0.U)
   val accessFault = RegEnable(io.pmp.resp.ld || io.pmp.resp.mmio, sent_to_pmp)
-  val ppn_af = pte.isAf()
+  // use access fault when bitmap check failed
+  val ppn_af = if (HasBitmapCheck) {
+    Mux(bitmap_enable, pte.isAf() || bitmap_checkfailed, pte.isAf())
+  } else {
+    pte.isAf()
+  }
   val find_pte = pte.isLeaf() || ppn_af || pageFault

   val resp_valid = !idle && mem_addr_update && ((w_mem_resp && find_pte) || (s_pmp_check && accessFault))
@@ -943,6 +1211,20 @@ class HPTW()(implicit p: Parameters) extends XSModule with HasPtwConst {
   io.pmp.req.bits.size := 3.U
   io.pmp.req.bits.cmd := TlbCmd.read

+  if (HasBitmapCheck) {
+    val way_info = DataHoldBypass(io.l0_way_info.get, RegNext(io.mem.resp.fire, init = false.B))
+    val cache_hitway = RegEnable(io.req.bits.bitmapCheck.get.hitway, io.req.fire)
+    val cache_level = RegEnable(io.req.bits.bitmapCheck.get.SPlevel, io.req.fire)
+    io.bitmap.get.req.valid := !s_bitmap_check
+    io.bitmap.get.req.bits.bmppn := pte.ppn
+    io.bitmap.get.req.bits.id := HptwReqId.U(bMemID.W)
+    io.bitmap.get.req.bits.vpn := vpn
+    io.bitmap.get.req.bits.level := Mux(jmp_bitmap_check, Mux(fromSP, cache_level, 0.U), level)
+    io.bitmap.get.req.bits.way_info := Mux(jmp_bitmap_check, cache_hitway, way_info)
+    io.bitmap.get.req.bits.hptw_bypassed := bypassed
+    io.bitmap.get.resp.ready := !w_bitmap_resp
+  }
+
   io.mem.req.valid := !s_mem_req && !io.mem.mask && !accessFault && s_pmp_check
   io.mem.req.bits.addr := mem_addr
   io.mem.req.bits.id := HptwReqId.U(bMemID.W)
@@ -952,8 +1234,18 @@
   io.refill.level := level
   io.refill.req_info.source := source
   io.refill.req_info.s2xlate := onlyStage2
+
   when (idle){
-    when(io.req.fire){
+    if (HasBitmapCheck) {
+      when (io.req.bits.bitmapCheck.get.jmp_bitmap_check && io.req.fire) {
+        idle := false.B
+        gpaddr := Cat(io.req.bits.gvpn, 0.U(offLen.W))
+        s_bitmap_check := false.B
+        id := io.req.bits.id
+        level := Mux(io.req.bits.bitmapCheck.get.fromSP, io.req.bits.bitmapCheck.get.SPlevel, 0.U)
+      }
+    }
+    when (io.req.fire && (if (HasBitmapCheck) !io.req.bits.bitmapCheck.get.jmp_bitmap_check else true.B)) {
       bypassed := io.req.bits.bypassed
       idle := false.B
       gpaddr := Cat(io.req.bits.gvpn, 0.U(offLen.W))
@@ -991,6 +1283,12 @@ class HPTW()(implicit p: Parameters) extends XSModule with HasPtwConst {
       s_mem_req := true.B
       w_mem_resp := true.B
       mem_addr_update := true.B
+      if (HasBitmapCheck) {
+        s_bitmap_check := true.B
+        w_bitmap_resp := true.B
+        whether_need_bitmap_check := false.B
+        bitmap_checkfailed := false.B
+      }
     }

     when(io.mem.req.fire){
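s_bitmap_check and w_bitmap_resp above follow the walker's existing s_*/w_* convention: a "request already sent" flag drives req.valid active-low, a "response already received" flag drives resp.ready active-low, and both reset to true (nothing outstanding). A stripped-down sketch of that handshake bookkeeping, with all names illustrative:

import chisel3._
import chisel3.util._

// Sketch: s_req == "request already sent", w_resp == "response already
// received". Both idle at true; a transaction drives them false in turn.
class ReqRespFlags extends Module {
  val io = IO(new Bundle {
    val start = Input(Bool())           // kick off one request
    val req   = Decoupled(UInt(8.W))
    val resp  = Flipped(Decoupled(UInt(8.W)))
    val busy  = Output(Bool())
  })
  val s_req  = RegInit(true.B)
  val w_resp = RegInit(true.B)

  when (io.start && s_req && w_resp) { s_req := false.B } // arm a new request
  when (io.req.fire)  { s_req := true.B; w_resp := false.B }
  when (io.resp.fire) { w_resp := true.B }

  io.req.valid  := !s_req
  io.req.bits   := 0.U
  io.resp.ready := !w_resp
  io.busy       := !s_req || !w_resp
}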
@@ -1001,7 +1299,38 @@ class HPTW()(implicit p: Parameters) extends XSModule with HasPtwConst {
     when(io.mem.resp.fire && !w_mem_resp){
       w_mem_resp := true.B
       af_level := af_level - 1.U
-      mem_addr_update := true.B
+      if (HasBitmapCheck) {
+        when (bitmap_enable) {
+          whether_need_bitmap_check := true.B
+        } .otherwise {
+          mem_addr_update := true.B
+          whether_need_bitmap_check := false.B
+        }
+      } else {
+        mem_addr_update := true.B
+      }
     }

+    if (HasBitmapCheck) {
+      when (whether_need_bitmap_check) {
+        when (bitmap_enable && pte.isLeaf()) {
+          s_bitmap_check := false.B
+          whether_need_bitmap_check := false.B
+        } .otherwise {
+          mem_addr_update := true.B
+          whether_need_bitmap_check := false.B
+        }
+      }
+      // bitmap check
+      when (io.bitmap.get.req.fire) {
+        s_bitmap_check := true.B
+        w_bitmap_resp := false.B
+      }
+      when (io.bitmap.get.resp.fire) {
+        w_bitmap_resp := true.B
+        mem_addr_update := true.B
+        bitmap_checkfailed := io.bitmap.get.resp.bits.cf
+      }
+    }
+
     when(mem_addr_update){
@@ -1018,12 +1347,18 @@ class HPTW()(implicit p: Parameters) extends XSModule with HasPtwConst {
       finish := true.B
     }
   }
-  when (flush) {
+  when (flush) {
     idle := true.B
     s_pmp_check := true.B
     s_mem_req := true.B
     w_mem_resp := true.B
     accessFault := false.B
    mem_addr_update := false.B
+    if (HasBitmapCheck) {
+      s_bitmap_check := true.B
+      w_bitmap_resp := true.B
+      whether_need_bitmap_check := false.B
+      bitmap_checkfailed := false.B
+    }
   }
 }
diff --git a/src/main/scala/xiangshan/frontend/Frontend.scala b/src/main/scala/xiangshan/frontend/Frontend.scala
index b990eb6fc9a..745dcffe6ef 100644
--- a/src/main/scala/xiangshan/frontend/Frontend.scala
+++ b/src/main/scala/xiangshan/frontend/Frontend.scala
@@ -135,7 +135,11 @@ class FrontendInlinedImp(outer: FrontendInlined) extends LazyModuleImp(outer)
   pmp_req_vec.last <> ifu.io.pmp.req

   for (i <- pmp_check.indices) {
-    pmp_check(i).apply(tlbCsr.priv.imode, pmp.io.pmp, pmp.io.pma, pmp_req_vec(i))
+    if (HasBitmapCheck) {
+      pmp_check(i).apply(tlbCsr.mbmc.CMODE.asBool, tlbCsr.priv.imode, pmp.io.pmp, pmp.io.pma, pmp_req_vec(i))
+    } else {
+      pmp_check(i).apply(tlbCsr.priv.imode, pmp.io.pmp, pmp.io.pma, pmp_req_vec(i))
+    }
   }
   (0 until 2 * PortNumber).foreach(i => icache.io.pmp(i).resp <> pmp_check(i).resp)
   ifu.io.pmp.resp <> pmp_check.last.resp
diff --git a/src/main/scala/xiangshan/mem/MemBlock.scala b/src/main/scala/xiangshan/mem/MemBlock.scala
index ff71ddfb460..0799d81c395 100644
--- a/src/main/scala/xiangshan/mem/MemBlock.scala
+++ b/src/main/scala/xiangshan/mem/MemBlock.scala
@@ -743,7 +743,11 @@ class MemBlockInlinedImp(outer: MemBlockInlined) extends LazyModuleImp(outer)
   val pmp_checkers = Seq.fill(DTlbSize)(Module(new PMPChecker(4, leaveHitMux = true)))
   val pmp_check = pmp_checkers.map(_.io)
   for ((p,d) <- pmp_check zip dtlb_pmps) {
-    p.apply(tlbcsr.priv.dmode, pmp.io.pmp, pmp.io.pma, d)
+    if (HasBitmapCheck) {
+      p.apply(tlbcsr.mbmc.CMODE.asBool, tlbcsr.priv.dmode, pmp.io.pmp, pmp.io.pma, d)
+    } else {
+      p.apply(tlbcsr.priv.dmode, pmp.io.pmp, pmp.io.pma, d)
+    }
     require(p.req.bits.size.getWidth == d.bits.size.getWidth)
   }
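A closing note on the pattern running through this whole patch, including the PMPChecker hookup just above: the Scala-level if (HasBitmapCheck) decides at elaboration time whether the extra hardware (and the CMODE-carrying apply overload) exists at all, while runtime enables such as the mbmc CSR's BME and CMODE fields gate its behavior cycle by cycle. A tiny sketch separating the two levels (hasFeature and enable are illustrative names, not the real parameters):

import chisel3._

// Sketch: compile-time gating (Scala if) vs. runtime gating (Mux/when).
class TwoLevelGate(hasFeature: Boolean) extends Module {
  val io = IO(new Bundle {
    val enable = Input(Bool())   // runtime switch, e.g. a CSR bit
    val in     = Input(UInt(8.W))
    val out    = Output(UInt(8.W))
  })
  if (hasFeature) {
    // Elaboration time: this logic only exists when the feature is built in.
    io.out := Mux(io.enable, io.in + 1.U, io.in) // runtime gate
  } else {
    io.out := io.in
  }
}

A build with hasFeature = false emits no trace of the feature, which is why the PMP call sites above can keep the legacy apply signature untouched in that configuration.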