From 3f3ecf8659244baa9526f7768d40008e081d9aa3 Mon Sep 17 00:00:00 2001 From: zhou tao Date: Wed, 24 Jul 2024 01:25:57 +0800 Subject: [PATCH 1/3] RAS: handle a case where a jalr instruction requires a pop followed by a push --- src/main/scala/xiangshan/frontend/FTB.scala | 1 + .../xiangshan/frontend/FrontendBundle.scala | 7 +- src/main/scala/xiangshan/frontend/IFU.scala | 5 +- .../scala/xiangshan/frontend/NewFtq.scala | 58 ++++++---- .../scala/xiangshan/frontend/PreDecode.scala | 32 +++--- .../scala/xiangshan/frontend/newRAS.scala | 100 +++++++++++++----- 6 files changed, 135 insertions(+), 68 deletions(-) diff --git a/src/main/scala/xiangshan/frontend/FTB.scala b/src/main/scala/xiangshan/frontend/FTB.scala index 0afbaed74dc..7f3114a5ca6 100644 --- a/src/main/scala/xiangshan/frontend/FTB.scala +++ b/src/main/scala/xiangshan/frontend/FTB.scala @@ -148,6 +148,7 @@ class FTBEntry_part(implicit p: Parameters) extends XSBundle with FTBParams with val isCall = Bool() val isRet = Bool() val isJalr = Bool() + val isRetCall = Bool() def isJal = !isJalr } diff --git a/src/main/scala/xiangshan/frontend/FrontendBundle.scala b/src/main/scala/xiangshan/frontend/FrontendBundle.scala index 42b1b718f8e..693c258c38b 100644 --- a/src/main/scala/xiangshan/frontend/FrontendBundle.scala +++ b/src/main/scala/xiangshan/frontend/FrontendBundle.scala @@ -430,10 +430,11 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC val fallThroughErr = Bool() val multiHit = Bool() - val is_jal = Bool() + val is_jal = Bool() val is_jalr = Bool() val is_call = Bool() - val is_ret = Bool() + val is_ret = Bool() + val is_ret_call = Bool() val last_may_be_rvi_call = Bool() val is_br_sharing = Bool() @@ -537,6 +538,7 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC is_jalr := entry.tailSlot.valid && entry.isJalr is_call := entry.tailSlot.valid && entry.isCall is_ret := entry.tailSlot.valid && entry.isRet + is_ret_call := entry.isRetCall // The is_ret_call signal depends on the is_call signal and will not be pulled high alone last_may_be_rvi_call := entry.last_may_be_rvi_call is_br_sharing := entry.tailSlot.valid && entry.tailSlot.sharing predCycle.map(_ := GTimer()) @@ -651,6 +653,7 @@ class BranchPredictionUpdate(implicit p: Parameters) extends XSBundle with HasBP def is_jalr = ftb_entry.tailSlot.valid && ftb_entry.isJalr def is_call = ftb_entry.tailSlot.valid && ftb_entry.isCall def is_ret = ftb_entry.tailSlot.valid && ftb_entry.isRet + def is_ret_call = ftb_entry.isRetCall // The is_ret_call signal depends on the is_call signal and will not be pulled high alone def is_call_taken = is_call && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset def is_ret_taken = is_ret && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset diff --git a/src/main/scala/xiangshan/frontend/IFU.scala b/src/main/scala/xiangshan/frontend/IFU.scala index 0654255b639..03c551110df 100644 --- a/src/main/scala/xiangshan/frontend/IFU.scala +++ b/src/main/scala/xiangshan/frontend/IFU.scala @@ -556,6 +556,7 @@ class NewIFU(implicit p: Parameters) extends XSModule pd.brType := f3Predecoder.io.out.pd(i).brType pd.isCall := f3Predecoder.io.out.pd(i).isCall pd.isRet := f3Predecoder.io.out.pd(i).isRet + pd.isRetCall := f3Predecoder.io.out.pd(i).isRetCall } val f3PdDiff = f3_pd_wire.zip(f3_pd).map{ case (a,b) => a.asUInt =/= b.asUInt }.reduce(_||_) @@ -863,7 +864,7 @@ class NewIFU(implicit p: Parameters) extends XSModule val inst = Cat(f3_mmio_data(1), 
f3_mmio_data(0)) val currentIsRVC = isRVC(inst) - val brType::isCall::isRet::Nil = brInfo(inst) + val brType::isCall::isRet::isRetCall::Nil = brInfo(inst) val jalOffset = jal_offset(inst, currentIsRVC) val brOffset = br_offset(inst, currentIsRVC) @@ -875,6 +876,7 @@ class NewIFU(implicit p: Parameters) extends XSModule io.toIbuffer.bits.pd(0).brType := brType io.toIbuffer.bits.pd(0).isCall := isCall io.toIbuffer.bits.pd(0).isRet := isRet + io.toIbuffer.bits.pd(0).isRetCall := isRetCall when (mmio_resend_af) { io.toIbuffer.bits.exceptionType(0) := ExceptionType.acf @@ -892,6 +894,7 @@ class NewIFU(implicit p: Parameters) extends XSModule mmioFlushWb.bits.pd(0).brType := brType mmioFlushWb.bits.pd(0).isCall := isCall mmioFlushWb.bits.pd(0).isRet := isRet + mmioFlushWb.bits.pd(0).isRetCall := isRetCall } mmio_redirect := (f3_req_is_mmio && mmio_state === m_waitCommit && RegNext(fromUncache.fire) && f3_mmio_use_seq_pc) diff --git a/src/main/scala/xiangshan/frontend/NewFtq.scala b/src/main/scala/xiangshan/frontend/NewFtq.scala index 876a32c7a3d..62d19ad0e24 100644 --- a/src/main/scala/xiangshan/frontend/NewFtq.scala +++ b/src/main/scala/xiangshan/frontend/NewFtq.scala @@ -29,13 +29,14 @@ import utility.ChiselDB class FtqDebugBundle extends Bundle { val pc = UInt(39.W) - val target = UInt(39.W) - val isBr = Bool() - val isJmp = Bool() - val isCall = Bool() - val isRet = Bool() - val misPred = Bool() - val isTaken = Bool() + val target = UInt(39.W) + val isBr = Bool() + val isJmp = Bool() + val isCall = Bool() + val isRet = Bool() + val isRetCall = Bool() + val misPred = Bool() + val isTaken = Bool() val predStage = UInt(2.W) } @@ -109,7 +110,8 @@ class Ftq_RF_Components(implicit p: Parameters) extends XSBundle with BPUUtils { class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle { val brMask = Vec(PredictWidth, Bool()) - val jmpInfo = ValidUndirectioned(Vec(3, Bool())) + // jmpInfo includes isJalr, isCall, isRet, isRetCall signals + val jmpInfo = ValidUndirectioned(Vec(4, Bool())) val jmpOffset = UInt(log2Ceil(PredictWidth).W) val jalTarget = UInt(VAddrBits.W) val rvcMask = Vec(PredictWidth, Bool()) @@ -117,13 +119,14 @@ class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle { def hasJalr = jmpInfo.valid && jmpInfo.bits(0) def hasCall = jmpInfo.valid && jmpInfo.bits(1) def hasRet = jmpInfo.valid && jmpInfo.bits(2) + def hasRetCall = jmpInfo.valid && jmpInfo.bits(3) def fromPdWb(pdWb: PredecodeWritebackBundle) = { val pds = pdWb.pd this.brMask := VecInit(pds.map(pd => pd.isBr && pd.valid)) this.jmpInfo.valid := VecInit(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid)).asUInt.orR this.jmpInfo.bits := ParallelPriorityMux(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid), - pds.map(pd => VecInit(pd.isJalr, pd.isCall, pd.isRet))) + pds.map(pd => VecInit(pd.isJalr, pd.isCall, pd.isRet, pd.isRetCall))) this.jmpOffset := ParallelPriorityEncoder(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid)) this.rvcMask := VecInit(pds.map(pd => pd.isRVC)) this.jalTarget := pdWb.jalTarget @@ -139,6 +142,7 @@ class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle { pd.brType := Cat(offset === jmpOffset && jmpInfo.valid, isJalr || isBr) pd.isCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(1) pd.isRet := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(2) + pd.isRetCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(3) pd } } @@ -273,6 +277,7 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire val new_jmp_is_jalr = entry_has_jmp && 
pd.jmpInfo.bits(0) && io.cfiIndex.valid val new_jmp_is_call = entry_has_jmp && pd.jmpInfo.bits(1) && io.cfiIndex.valid val new_jmp_is_ret = entry_has_jmp && pd.jmpInfo.bits(2) && io.cfiIndex.valid + val new_jmp_is_ret_call = entry_has_jmp && pd.jmpInfo.bits(3) && io.cfiIndex.valid val last_jmp_rvi = entry_has_jmp && pd.jmpOffset === (PredictWidth-1).U && !pd.rvcMask.last // val last_br_rvi = cfi_is_br && io.cfiIndex.bits === (PredictWidth-1).U && !pd.rvcMask.last @@ -302,11 +307,12 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire } val jmpPft = getLower(io.start_addr) +& pd.jmpOffset +& Mux(pd.rvcMask(pd.jmpOffset), 1.U, 2.U) - init_entry.pftAddr := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft, getLower(io.start_addr)) - init_entry.carry := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft(carryPos-instOffsetBits), true.B) - init_entry.isJalr := new_jmp_is_jalr - init_entry.isCall := new_jmp_is_call - init_entry.isRet := new_jmp_is_ret + init_entry.pftAddr := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft, getLower(io.start_addr)) + init_entry.carry := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft(carryPos-instOffsetBits), true.B) + init_entry.isJalr := new_jmp_is_jalr + init_entry.isCall := new_jmp_is_call + init_entry.isRet := new_jmp_is_ret + init_entry.isRetCall := new_jmp_is_ret_call // that means fall thru points to the middle of an inst init_entry.last_may_be_rvi_call := pd.jmpOffset === (PredictWidth-1).U && !pd.rvcMask(pd.jmpOffset) @@ -369,6 +375,7 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire old_entry_modified.last_may_be_rvi_call := false.B old_entry_modified.isCall := false.B old_entry_modified.isRet := false.B + old_entry_modified.isRetCall := false.B old_entry_modified.isJalr := false.B } @@ -967,7 +974,8 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe ((pred_ftb_entry.isJal && !(jmp_pd.valid && jmp_pd.isJal)) || (pred_ftb_entry.isJalr && !(jmp_pd.valid && jmp_pd.isJalr)) || (pred_ftb_entry.isCall && !(jmp_pd.valid && jmp_pd.isCall)) || - (pred_ftb_entry.isRet && !(jmp_pd.valid && jmp_pd.isRet)) + (pred_ftb_entry.isRet && !(jmp_pd.valid && jmp_pd.isRet)) || + (pred_ftb_entry.isRetCall && !(jmp_pd.valid && jmp_pd.isRetCall)) ) has_false_hit := br_false_hit || jal_false_hit || hit_pd_mispred_reg @@ -1364,6 +1372,7 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe update.spec_info := commit_spec_meta XSError(commit_valid && do_commit && debug_cfi, "\ncommit cfi can be non c_commited\n") + val commit_real_hit = commit_hit === h_hit val update_ftb_entry = update.ftb_entry @@ -1440,14 +1449,15 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe p"brInEntry(${inFtbEntry}) brIdx(${brIdx}) target(${Hexadecimal(target)})\n") val logbundle = Wire(new FtqDebugBundle) - logbundle.pc := pc - logbundle.target := target - logbundle.isBr := isBr - logbundle.isJmp := isJmp - logbundle.isCall := isJmp && commit_pd.hasCall - logbundle.isRet := isJmp && commit_pd.hasRet - logbundle.misPred := misPred - logbundle.isTaken := isTaken + logbundle.pc := pc + logbundle.target := target + logbundle.isBr := isBr + logbundle.isJmp := isJmp + logbundle.isCall := isJmp && commit_pd.hasCall + logbundle.isRet := isJmp && commit_pd.hasRet + logbundle.isRetCall := isJmp && commit_pd.hasRetCall + logbundle.misPred := misPred + logbundle.isTaken := isTaken logbundle.predStage := commit_stage ftqBranchTraceDB.log( @@ -1479,6 +1489,8 @@ class Ftq(implicit 
p: Parameters) extends XSModule with HasCircularQueuePtrHelpe XSPerfAccumulate("fromBackendRedirect_ValidNum", io.fromBackend.redirect.valid) XSPerfAccumulate("toBpuRedirect_ValidNum", io.toBpu.redirect.valid) + XSPerfAccumulate("ret_call_Num", io.toBpu.update.valid && io.toBpu.update.bits.ftb_entry.isRetCall && io.toBpu.update.bits.ftb_entry.valid && io.toBpu.update.bits.ftb_entry.tailSlot.valid) + val from_bpu = io.fromBpu.resp.bits val to_ifu = io.toIfu.req.bits diff --git a/src/main/scala/xiangshan/frontend/PreDecode.scala b/src/main/scala/xiangshan/frontend/PreDecode.scala index 5cb3f360bbd..c86cf5479fc 100644 --- a/src/main/scala/xiangshan/frontend/PreDecode.scala +++ b/src/main/scala/xiangshan/frontend/PreDecode.scala @@ -37,7 +37,8 @@ trait HasPdConst extends HasXSParameter with HasICacheParameters with HasIFUCons val rs = Mux(isRVC(instr), Mux(brType === BrType.jal, 0.U, instr(11, 7)), instr(19, 15)) val isCall = (brType === BrType.jal && !isRVC(instr) || brType === BrType.jalr) && isLink(rd) // Only for RV64 val isRet = brType === BrType.jalr && isLink(rs) && !isCall - List(brType, isCall, isRet) + val isRetCall = (brType === BrType.jalr && isLink(rs) && isLink(rd) && (rs =/= rd)) + List(brType, isCall, isRet, isRetCall) } def jal_offset(inst: UInt, rvc: Bool): UInt = { val rvc_offset = Cat(inst(12), inst(8), inst(10, 9), inst(6), inst(7), inst(2), inst(11), inst(5, 3), 0.U(1.W)) @@ -56,7 +57,7 @@ trait HasPdConst extends HasXSParameter with HasICacheParameters with HasIFUCons } object BrType { - def notCFI = "b00".U + def notCFI = "b00".U def branch = "b01".U def jal = "b10".U def jalr = "b11".U @@ -64,16 +65,17 @@ object BrType { } object ExcType { //TODO:add exctype - def notExc = "b000".U + def notExc = "b000".U def apply() = UInt(3.W) } class PreDecodeInfo extends Bundle { // 8 bit - val valid = Bool() - val isRVC = Bool() - val brType = UInt(2.W) - val isCall = Bool() - val isRet = Bool() + val valid = Bool() + val isRVC = Bool() + val brType = UInt(2.W) + val isCall = Bool() + val isRet = Bool() + val isRetCall = Bool() //val excType = UInt(3.W) def isBr = brType === BrType.branch def isJal = brType === BrType.jal @@ -133,7 +135,7 @@ class PreDecode(implicit p: Parameters) extends XSModule with HasPdConst{ val currentPC = io.in.bits.pc(i) //expander.io.in := inst - val brType::isCall::isRet::Nil = brInfo(inst) + val brType::isCall::isRet::isRetCall::Nil = brInfo(inst) val jalOffset = jal_offset(inst, currentIsRVC(i)) val brOffset = br_offset(inst, currentIsRVC(i)) @@ -149,10 +151,11 @@ class PreDecode(implicit p: Parameters) extends XSModule with HasPdConst{ io.out.pd(i).brType := brType io.out.pd(i).isCall := isCall io.out.pd(i).isRet := isRet + io.out.pd(i).isRetCall := isRetCall //io.out.expInstr(i) := expander.io.out.bits - io.out.instr(i) :=inst - io.out.jumpOffset(i) := Mux(io.out.pd(i).isBr, brOffset, jalOffset) + io.out.instr(i) :=inst + io.out.jumpOffset(i) := Mux(io.out.pd(i).isBr, brOffset, jalOffset) } // the first half is always reliable @@ -259,11 +262,12 @@ class F3Predecoder(implicit p: Parameters) extends XSModule with HasPdConst { val out = Output(new F3PreDecodeResp) }) io.out.pd.zipWithIndex.map{ case (pd,i) => - pd.valid := DontCare - pd.isRVC := DontCare + pd.valid := DontCare + pd.isRVC := DontCare pd.brType := brInfo(io.in.instr(i))(0) pd.isCall := brInfo(io.in.instr(i))(1) - pd.isRet := brInfo(io.in.instr(i))(2) + pd.isRet := brInfo(io.in.instr(i))(2) + pd.isRetCall := brInfo(io.in.instr(i))(3) } } diff --git 
a/src/main/scala/xiangshan/frontend/newRAS.scala b/src/main/scala/xiangshan/frontend/newRAS.scala index 86b7a7f3c7c..ab431002fca 100644 --- a/src/main/scala/xiangshan/frontend/newRAS.scala +++ b/src/main/scala/xiangshan/frontend/newRAS.scala @@ -102,8 +102,9 @@ class RAS(implicit p: Parameters) extends BasePredictor { class RASStack(rasSize: Int, rasSpecSize: Int) extends XSModule with HasCircularQueuePtrHelper { val io = IO(new Bundle { val spec_push_valid = Input(Bool()) - val spec_pop_valid = Input(Bool()) - val spec_push_addr = Input(UInt(VAddrBits.W)) + val spec_pop_valid = Input(Bool()) + val spec_push_addr = Input(UInt(VAddrBits.W)) + val spec_ret_call = Input(Bool()) // for write bypass between s2 and s3 val s2_fire = Input(Bool()) @@ -113,18 +114,21 @@ class RAS(implicit p: Parameters) extends BasePredictor { val s3_missed_pop = Input(Bool()) val s3_missed_push = Input(Bool()) val s3_pushAddr = Input(UInt(VAddrBits.W)) + val s3_ret_call = Input(Bool()) val spec_pop_addr = Output(UInt(VAddrBits.W)) val commit_push_valid = Input(Bool()) - val commit_pop_valid = Input(Bool()) + val commit_pop_valid = Input(Bool()) + val commit_ret_call = Input(Bool()) val commit_push_addr = Input(UInt(VAddrBits.W)) val commit_meta_TOSW = Input(new RASPtr) // for debug purpose only val commit_meta_ssp = Input(UInt(log2Up(RasSize).W)) - val redirect_valid = Input(Bool()) - val redirect_isCall = Input(Bool()) - val redirect_isRet = Input(Bool()) + val redirect_valid = Input(Bool()) + val redirect_isCall = Input(Bool()) + val redirect_isRet = Input(Bool()) + val redirect_ret_call = Input(Bool()) val redirect_meta_ssp = Input(UInt(log2Up(RasSize).W)) val redirect_meta_sctr = Input(UInt(RasCtrSize.W)) val redirect_meta_TOSW = Input(new RASPtr) @@ -154,6 +158,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { val TOSR = RegInit(RASPtr(true.B, (RasSpecSize - 1).U)) val TOSW = RegInit(RASPtr(false.B, 0.U)) val BOS = RegInit(RASPtr(false.B, 0.U)) + val NOS = RegInit(RASPtr(false.B, 0.U)) val spec_overflowed = RegInit(false.B) @@ -242,8 +247,6 @@ class RAS(implicit p: Parameters) extends BasePredictor { writeBypassValidWire := writeBypassValid } - - val topEntry = getTop(ssp, sctr, TOSR, TOSW, true) val topNos = getTopNos(TOSR, true) val redirectTopEntry = getTop(io.redirect_meta_ssp, io.redirect_meta_sctr, io.redirect_meta_TOSR, io.redirect_meta_TOSW, false) @@ -254,12 +257,12 @@ class RAS(implicit p: Parameters) extends BasePredictor { val writeEntry = Wire(new RASEntry) val writeNos = Wire(new RASPtr) writeEntry.retAddr := Mux(io.redirect_valid && io.redirect_isCall, io.redirect_callAddr, io.spec_push_addr) - writeEntry.ctr := Mux(io.redirect_valid && io.redirect_isCall, - Mux(redirectTopEntry.retAddr === io.redirect_callAddr && redirectTopEntry.ctr < ctrMax, io.redirect_meta_sctr + 1.U, 0.U), - Mux(topEntry.retAddr === io.spec_push_addr && topEntry.ctr < ctrMax, sctr + 1.U, 0.U)) + writeEntry.ctr := Mux(io.redirect_valid && io.redirect_isCall, + Mux(io.redirect_ret_call, 0.U, Mux(redirectTopEntry.retAddr === io.redirect_callAddr && redirectTopEntry.ctr < ctrMax, io.redirect_meta_sctr + 1.U, 0.U)), + Mux(io.spec_ret_call, 0.U, Mux(topEntry.retAddr === io.spec_push_addr && topEntry.ctr < ctrMax, sctr + 1.U, 0.U))) writeNos := Mux(io.redirect_valid && io.redirect_isCall, - io.redirect_meta_TOSR, TOSR) + Mux(io.redirect_ret_call, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_ret_call, topNos,TOSR)) when (io.spec_push_valid || (io.redirect_valid && io.redirect_isCall)) { writeBypassEntry 
:= writeEntry @@ -390,9 +393,9 @@ class RAS(implicit p: Parameters) extends BasePredictor { val realWriteAddr = Mux(io.redirect_isCall, realWriteAddr_next, Mux(io.s3_missed_push, s3_missPushAddr, realWriteAddr_next)) - val realNos_next = RegEnable(Mux(io.redirect_valid && io.redirect_isCall, io.redirect_meta_TOSR, TOSR), io.s2_fire || (io.redirect_valid && io.redirect_isCall)) + val realNos_next = RegEnable(Mux(io.redirect_valid && io.redirect_isCall, Mux(io.redirect_ret_call, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_ret_call, topNos, TOSR)), io.s2_fire || (io.redirect_valid && io.redirect_isCall)) val realNos = Mux(io.redirect_isCall, realNos_next, - Mux(io.s3_missed_push, s3_missPushNos, + Mux(io.s3_missed_push, Mux(io.s3_ret_call, io.s3_meta.NOS, io.s3_meta.TOSR), realNos_next)) realPush := (io.s3_fire && (!io.s3_cancel && RegEnable(io.spec_push_valid, io.s2_fire) || io.s3_missed_push)) || RegNext(io.redirect_valid && io.redirect_isCall) @@ -420,7 +423,14 @@ class RAS(implicit p: Parameters) extends BasePredictor { } when (io.spec_push_valid) { - specPush(io.spec_push_addr, ssp, sctr, TOSR, TOSW, topEntry) + when(!io.spec_ret_call) { + specPush(io.spec_push_addr, ssp, sctr, TOSR, TOSW, topEntry) + } .otherwise { + TOSR := TOSW + TOSW := specPtrInc(TOSW) + sctr := 0.U + ssp := Mux(sctr > 0.U, ptrInc(ssp), ssp) + } } def specPop(currentSsp: UInt, currentSctr: UInt, currentTOSR: RASPtr, currentTOSW: RASPtr, currentTopNos: RASPtr) = { // TOSR is only maintained when spec queue is not empty @@ -469,7 +479,14 @@ class RAS(implicit p: Parameters) extends BasePredictor { } when (io.s3_missed_push) { // do not use any bypass from f2 - specPush(io.s3_pushAddr, io.s3_meta.ssp, io.s3_meta.sctr, io.s3_meta.TOSR, io.s3_meta.TOSW, s3TopEntry) + when(!io.s3_ret_call) { + specPush(io.s3_pushAddr, io.s3_meta.ssp, io.s3_meta.sctr, io.s3_meta.TOSR, io.s3_meta.TOSW, s3TopEntry) + }.otherwise { + TOSR := io.s3_meta.TOSW + TOSW := specPtrInc(io.s3_meta.TOSW) + sctr := 0.U + ssp := Mux(io.s3_meta.sctr > 0.U, ptrInc(io.s3_meta.ssp), io.s3_meta.ssp) + } } } @@ -508,14 +525,27 @@ class RAS(implicit p: Parameters) extends BasePredictor { nsp_update := nsp } // if ctr < max && topAddr == push addr, ++ctr, otherwise ++nsp - when (commitTop.ctr < ctrMax && commitTop.retAddr === commit_push_addr) { - commit_stack(nsp_update).ctr := commitTop.ctr + 1.U - nsp := nsp_update - } .otherwise { - nsp := ptrInc(nsp_update) - commit_stack(ptrInc(nsp_update)).retAddr := commit_push_addr - commit_stack(ptrInc(nsp_update)).ctr := 0.U + when(!io.commit_ret_call){ + when (commitTop.ctr < ctrMax && commitTop.retAddr === commit_push_addr) { + commit_stack(nsp_update).ctr := commitTop.ctr + 1.U + nsp := nsp_update + } .otherwise { + nsp := ptrInc(nsp_update) + commit_stack(ptrInc(nsp_update)).retAddr := commit_push_addr + commit_stack(ptrInc(nsp_update)).ctr := 0.U + } + }.otherwise { + when (commitTop.ctr > 0.U){ + nsp := ptrInc(nsp_update) + commit_stack(ptrInc(nsp_update)).retAddr := commit_push_addr + commit_stack(ptrInc(nsp_update)).ctr := 0.U + }.otherwise { + nsp := nsp_update + commit_stack(nsp_update).retAddr := commit_push_addr + commit_stack(nsp_update).ctr := 0.U + } } + // when overflow, BOS may be forced move forward, do not revert those changes when (!spec_overflowed || isAfter(io.commit_meta_TOSW, BOS)) { BOS := io.commit_meta_TOSW @@ -533,7 +563,14 @@ class RAS(implicit p: Parameters) extends BasePredictor { sctr := io.redirect_meta_sctr when (io.redirect_isCall) { - specPush(io.redirect_callAddr, 
io.redirect_meta_ssp, io.redirect_meta_sctr, io.redirect_meta_TOSR, io.redirect_meta_TOSW, redirectTopEntry) + when (!io.redirect_ret_call) { + specPush(io.redirect_callAddr, io.redirect_meta_ssp, io.redirect_meta_sctr, io.redirect_meta_TOSR, io.redirect_meta_TOSW, redirectTopEntry) + } .otherwise { + TOSR := io.redirect_meta_TOSW + TOSW := specPtrInc(io.redirect_meta_TOSW) + sctr := 0.U // Sacrifice part of the counter space to simplify the judgment + ssp := Mux(io.redirect_meta_sctr > 0.U, ptrInc(io.redirect_meta_ssp), io.redirect_meta_ssp) + } } when (io.redirect_isRet) { specPop(io.redirect_meta_ssp, io.redirect_meta_sctr, io.redirect_meta_TOSR, io.redirect_meta_TOSW, redirectTopNos) @@ -547,18 +584,21 @@ class RAS(implicit p: Parameters) extends BasePredictor { val stack = Module(new RASStack(RasSize, RasSpecSize)).io - val s2_spec_push = WireInit(false.B) - val s2_spec_pop = WireInit(false.B) + val s2_spec_push = WireInit(false.B) + val s2_spec_pop = WireInit(false.B) + val s2_is_ret_call = WireInit(false.B) val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred(2) // when last inst is an rvi call, fall through address would be set to the middle of it, so an addition is needed val s2_spec_new_addr = s2_full_pred.fallThroughAddr + Mux(s2_full_pred.last_may_be_rvi_call, 2.U, 0.U) stack.spec_push_valid := s2_spec_push stack.spec_pop_valid := s2_spec_pop - stack.spec_push_addr := s2_spec_new_addr + stack.spec_push_addr := s2_spec_new_addr + stack.spec_ret_call := s2_is_ret_call // confirm that the call/ret is the taken cfi s2_spec_push := io.s2_fire(2) && s2_full_pred.hit_taken_on_call && !io.s3_redirect(2) s2_spec_pop := io.s2_fire(2) && s2_full_pred.hit_taken_on_ret && !io.s3_redirect(2) + s2_is_ret_call := s2_full_pred.is_ret_call //val s2_jalr_target = io.out.s2.full_pred.jalr_target //val s2_last_target_in = s2_full_pred.targets.last @@ -605,6 +645,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { val s3_popped_in_s2 = RegEnable(s2_spec_pop, io.s2_fire(2)) val s3_push = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_call val s3_pop = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_ret + val s3_ret_call = io.in.bits.resp_in(0).s3.full_pred(2).is_ret_call val s3_cancel = io.s3_fire(2) && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop) stack.s2_fire := io.s2_fire(2) @@ -617,7 +658,8 @@ class RAS(implicit p: Parameters) extends BasePredictor { stack.s3_meta := s3_meta stack.s3_missed_pop := s3_pop && !s3_popped_in_s2 stack.s3_missed_push := s3_push && !s3_pushed_in_s2 - stack.s3_pushAddr := s3_spec_new_addr + stack.s3_pushAddr := s3_spec_new_addr + stack.s3_ret_call := s3_ret_call // no longer need the top Entry, but TOSR, TOSW, ssp sctr // TODO: remove related signals @@ -646,6 +688,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { stack.redirect_valid := do_recover stack.redirect_isCall := callMissPred stack.redirect_isRet := retMissPred + stack.redirect_ret_call := recover_cfi.pd.isRetCall stack.redirect_meta_ssp := recover_cfi.ssp stack.redirect_meta_sctr := recover_cfi.sctr stack.redirect_meta_TOSW := recover_cfi.TOSW @@ -659,6 +702,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { stack.commit_push_valid := updateValid && update.is_call_taken stack.commit_pop_valid := updateValid && update.is_ret_taken + stack.commit_ret_call := update.is_ret_call stack.commit_push_addr := update.ftb_entry.getFallThrough(update.pc) + Mux(update.ftb_entry.last_may_be_rvi_call, 2.U, 0.U) stack.commit_meta_TOSW := updateMeta.TOSW 
stack.commit_meta_ssp := updateMeta.ssp From a1599fa2d4d33cdd0e7a064d3413ba1894a516ac Mon Sep 17 00:00:00 2001 From: zhou tao Date: Wed, 24 Jul 2024 22:50:31 +0800 Subject: [PATCH 2/3] the combination of signals using 'isCall' and 'hasRet' represents the operations 'ret', 'call', and 'ret-call'. --- src/main/scala/xiangshan/frontend/FTB.scala | 15 +-- .../xiangshan/frontend/FrontendBundle.scala | 12 +- src/main/scala/xiangshan/frontend/IFU.scala | 25 ++--- .../scala/xiangshan/frontend/NewFtq.scala | 54 ++++----- .../scala/xiangshan/frontend/PreDecode.scala | 52 ++++----- .../scala/xiangshan/frontend/newRAS.scala | 106 +++++++++--------- 6 files changed, 129 insertions(+), 135 deletions(-) diff --git a/src/main/scala/xiangshan/frontend/FTB.scala b/src/main/scala/xiangshan/frontend/FTB.scala index 7f3114a5ca6..a24f77dc12e 100644 --- a/src/main/scala/xiangshan/frontend/FTB.scala +++ b/src/main/scala/xiangshan/frontend/FTB.scala @@ -145,12 +145,13 @@ class FtbSlot(val offsetLen: Int, val subOffsetLen: Option[Int] = None)(implicit class FTBEntry_part(implicit p: Parameters) extends XSBundle with FTBParams with BPUUtils { - val isCall = Bool() - val isRet = Bool() - val isJalr = Bool() - val isRetCall = Bool() + val isCall = Bool() + val hasRet = Bool() // maybe ret or ret-call + val isJalr = Bool() def isJal = !isJalr + def isRet = !isCall && hasRet + def isRetCall = isCall && hasRet } class FTBEntry_FtqMem(implicit p: Parameters) extends FTBEntry_part with FTBParams with BPUUtils { @@ -372,7 +373,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with val pftAddrDiff = this.pftAddr === that.pftAddr val carryDiff = this.carry === that.carry val isCallDiff = this.isCall === that.isCall - val isRetDiff = this.isRet === that.isRet + val hasRetDiff = this.hasRet === that.hasRet val isJalrDiff = this.isJalr === that.isJalr val lastMayBeRviCallDiff = this.last_may_be_rvi_call === that.last_may_be_rvi_call val alwaysTakenDiff : IndexedSeq[Bool] = @@ -386,7 +387,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with pftAddrDiff, carryDiff, isCallDiff, - isRetDiff, + hasRetDiff, isJalrDiff, lastMayBeRviCallDiff, alwaysTakenDiff.reduce(_&&_) @@ -403,7 +404,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with XSDebug(cond, p"[tailSlot]: v=${tailSlot.valid}, offset=${tailSlot.offset}," + p"lower=${Hexadecimal(tailSlot.lower)}, sharing=${tailSlot.sharing}}\n") XSDebug(cond, p"pftAddr=${Hexadecimal(pftAddr)}, carry=$carry\n") - XSDebug(cond, p"isCall=$isCall, isRet=$isRet, isjalr=$isJalr\n") + XSDebug(cond, p"isCall=$isCall, hasRet=$hasRet, isjalr=$isJalr\n") XSDebug(cond, p"last_may_be_rvi_call=$last_may_be_rvi_call\n") XSDebug(cond, p"------------------------------- \n") } diff --git a/src/main/scala/xiangshan/frontend/FrontendBundle.scala b/src/main/scala/xiangshan/frontend/FrontendBundle.scala index 693c258c38b..9164538d789 100644 --- a/src/main/scala/xiangshan/frontend/FrontendBundle.scala +++ b/src/main/scala/xiangshan/frontend/FrontendBundle.scala @@ -434,7 +434,7 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC val is_jalr = Bool() val is_call = Bool() val is_ret = Bool() - val is_ret_call = Bool() + val has_ret = Bool() // only used for the ret-call behavior in RAS val last_may_be_rvi_call = Bool() val is_br_sharing = Bool() @@ -537,8 +537,8 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC is_jal := entry.tailSlot.valid && entry.isJal 
is_jalr := entry.tailSlot.valid && entry.isJalr is_call := entry.tailSlot.valid && entry.isCall - is_ret := entry.tailSlot.valid && entry.isRet - is_ret_call := entry.isRetCall // The is_ret_call signal depends on the is_call signal and will not be pulled high alone + is_ret := entry.tailSlot.valid && entry.hasRet && !entry.isCall + has_ret := entry.hasRet last_may_be_rvi_call := entry.last_may_be_rvi_call is_br_sharing := entry.tailSlot.valid && entry.tailSlot.sharing predCycle.map(_ := GTimer()) @@ -649,11 +649,11 @@ class BranchPredictionUpdate(implicit p: Parameters) extends XSBundle with HasBP val from_stage = UInt(2.W) val ghist = UInt(HistoryLength.W) - def is_jal = ftb_entry.tailSlot.valid && ftb_entry.isJal + def is_jal = ftb_entry.tailSlot.valid && ftb_entry.isJal def is_jalr = ftb_entry.tailSlot.valid && ftb_entry.isJalr def is_call = ftb_entry.tailSlot.valid && ftb_entry.isCall - def is_ret = ftb_entry.tailSlot.valid && ftb_entry.isRet - def is_ret_call = ftb_entry.isRetCall // The is_ret_call signal depends on the is_call signal and will not be pulled high alone + def is_ret = ftb_entry.tailSlot.valid && ftb_entry.hasRet && !ftb_entry.isCall + def has_ret = ftb_entry.hasRet def is_call_taken = is_call && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset def is_ret_taken = is_ret && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset diff --git a/src/main/scala/xiangshan/frontend/IFU.scala b/src/main/scala/xiangshan/frontend/IFU.scala index 03c551110df..48d168de057 100644 --- a/src/main/scala/xiangshan/frontend/IFU.scala +++ b/src/main/scala/xiangshan/frontend/IFU.scala @@ -547,16 +547,15 @@ class NewIFU(implicit p: Parameters) extends XSModule // Expand 1 bit to prevent overflow when assert val f3_ftq_req_startAddr = Cat(0.U(1.W), f3_ftq_req.startAddr) val f3_ftq_req_nextStartAddr = Cat(0.U(1.W), f3_ftq_req.nextStartAddr) - // brType, isCall and isRet generation is delayed to f3 stage + // brType, isCall and hasRet generation is delayed to f3 stage val f3Predecoder = Module(new F3Predecoder) f3Predecoder.io.in.instr := f3_instr f3_pd.zipWithIndex.map{ case (pd,i) => - pd.brType := f3Predecoder.io.out.pd(i).brType - pd.isCall := f3Predecoder.io.out.pd(i).isCall - pd.isRet := f3Predecoder.io.out.pd(i).isRet - pd.isRetCall := f3Predecoder.io.out.pd(i).isRetCall + pd.brType := f3Predecoder.io.out.pd(i).brType + pd.isCall := f3Predecoder.io.out.pd(i).isCall + pd.hasRet := f3Predecoder.io.out.pd(i).hasRet } val f3PdDiff = f3_pd_wire.zip(f3_pd).map{ case (a,b) => a.asUInt =/= b.asUInt }.reduce(_||_) @@ -864,7 +863,7 @@ class NewIFU(implicit p: Parameters) extends XSModule val inst = Cat(f3_mmio_data(1), f3_mmio_data(0)) val currentIsRVC = isRVC(inst) - val brType::isCall::isRet::isRetCall::Nil = brInfo(inst) + val brType::isCall::hasRet::Nil = brInfo(inst) val jalOffset = jal_offset(inst, currentIsRVC) val brOffset = br_offset(inst, currentIsRVC) @@ -875,8 +874,7 @@ class NewIFU(implicit p: Parameters) extends XSModule io.toIbuffer.bits.pd(0).isRVC := currentIsRVC io.toIbuffer.bits.pd(0).brType := brType io.toIbuffer.bits.pd(0).isCall := isCall - io.toIbuffer.bits.pd(0).isRet := isRet - io.toIbuffer.bits.pd(0).isRetCall := isRetCall + io.toIbuffer.bits.pd(0).hasRet := hasRet when (mmio_resend_af) { io.toIbuffer.bits.exceptionType(0) := ExceptionType.acf @@ -889,12 +887,11 @@ class NewIFU(implicit p: Parameters) extends XSModule io.toIbuffer.bits.enqEnable := f3_mmio_range.asUInt - mmioFlushWb.bits.pd(0).valid := true.B - 
mmioFlushWb.bits.pd(0).isRVC := currentIsRVC - mmioFlushWb.bits.pd(0).brType := brType - mmioFlushWb.bits.pd(0).isCall := isCall - mmioFlushWb.bits.pd(0).isRet := isRet - mmioFlushWb.bits.pd(0).isRetCall := isRetCall + mmioFlushWb.bits.pd(0).valid := true.B + mmioFlushWb.bits.pd(0).isRVC := currentIsRVC + mmioFlushWb.bits.pd(0).brType := brType + mmioFlushWb.bits.pd(0).isCall := isCall + mmioFlushWb.bits.pd(0).hasRet := hasRet } mmio_redirect := (f3_req_is_mmio && mmio_state === m_waitCommit && RegNext(fromUncache.fire) && f3_mmio_use_seq_pc) diff --git a/src/main/scala/xiangshan/frontend/NewFtq.scala b/src/main/scala/xiangshan/frontend/NewFtq.scala index 62d19ad0e24..4327910c47b 100644 --- a/src/main/scala/xiangshan/frontend/NewFtq.scala +++ b/src/main/scala/xiangshan/frontend/NewFtq.scala @@ -34,7 +34,6 @@ class FtqDebugBundle extends Bundle { val isJmp = Bool() val isCall = Bool() val isRet = Bool() - val isRetCall = Bool() val misPred = Bool() val isTaken = Bool() val predStage = UInt(2.W) @@ -110,23 +109,23 @@ class Ftq_RF_Components(implicit p: Parameters) extends XSBundle with BPUUtils { class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle { val brMask = Vec(PredictWidth, Bool()) - // jmpInfo includes isJalr, isCall, isRet, isRetCall signals - val jmpInfo = ValidUndirectioned(Vec(4, Bool())) + // jmpInfo(0) = jalr, jmpInfo(1) = isCall, jmpInfo(2) = hasRet + val jmpInfo = ValidUndirectioned(Vec(3, Bool())) val jmpOffset = UInt(log2Ceil(PredictWidth).W) val jalTarget = UInt(VAddrBits.W) val rvcMask = Vec(PredictWidth, Bool()) def hasJal = jmpInfo.valid && !jmpInfo.bits(0) def hasJalr = jmpInfo.valid && jmpInfo.bits(0) - def hasCall = jmpInfo.valid && jmpInfo.bits(1) - def hasRet = jmpInfo.valid && jmpInfo.bits(2) - def hasRetCall = jmpInfo.valid && jmpInfo.bits(3) + def isCall = jmpInfo.valid && jmpInfo.bits(1) + def isRet = jmpInfo.valid && jmpInfo.bits(2) && !jmpInfo.bits(1) + def isRetCall = jmpInfo.valid && jmpInfo.bits(2) && jmpInfo.bits(1) def fromPdWb(pdWb: PredecodeWritebackBundle) = { val pds = pdWb.pd this.brMask := VecInit(pds.map(pd => pd.isBr && pd.valid)) this.jmpInfo.valid := VecInit(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid)).asUInt.orR this.jmpInfo.bits := ParallelPriorityMux(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid), - pds.map(pd => VecInit(pd.isJalr, pd.isCall, pd.isRet, pd.isRetCall))) + pds.map(pd => VecInit(pd.isJalr, pd.isCall, pd.hasRet))) this.jmpOffset := ParallelPriorityEncoder(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid)) this.rvcMask := VecInit(pds.map(pd => pd.isRVC)) this.jalTarget := pdWb.jalTarget @@ -139,10 +138,9 @@ class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle { pd.isRVC := rvcMask(offset) val isBr = brMask(offset) val isJalr = offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(0) - pd.brType := Cat(offset === jmpOffset && jmpInfo.valid, isJalr || isBr) - pd.isCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(1) - pd.isRet := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(2) - pd.isRetCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(3) + pd.brType := Cat(offset === jmpOffset && jmpInfo.valid, isJalr || isBr) + pd.isCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(1) + pd.hasRet := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(2) pd } } @@ -276,8 +274,7 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire val new_jmp_is_jal = entry_has_jmp && !pd.jmpInfo.bits(0) && io.cfiIndex.valid val new_jmp_is_jalr = entry_has_jmp && 
pd.jmpInfo.bits(0) && io.cfiIndex.valid val new_jmp_is_call = entry_has_jmp && pd.jmpInfo.bits(1) && io.cfiIndex.valid - val new_jmp_is_ret = entry_has_jmp && pd.jmpInfo.bits(2) && io.cfiIndex.valid - val new_jmp_is_ret_call = entry_has_jmp && pd.jmpInfo.bits(3) && io.cfiIndex.valid + val new_jmp_has_ret = entry_has_jmp && pd.jmpInfo.bits(2) && io.cfiIndex.valid val last_jmp_rvi = entry_has_jmp && pd.jmpOffset === (PredictWidth-1).U && !pd.rvcMask.last // val last_br_rvi = cfi_is_br && io.cfiIndex.bits === (PredictWidth-1).U && !pd.rvcMask.last @@ -307,12 +304,11 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire } val jmpPft = getLower(io.start_addr) +& pd.jmpOffset +& Mux(pd.rvcMask(pd.jmpOffset), 1.U, 2.U) - init_entry.pftAddr := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft, getLower(io.start_addr)) - init_entry.carry := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft(carryPos-instOffsetBits), true.B) + init_entry.pftAddr := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft, getLower(io.start_addr)) + init_entry.carry := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft(carryPos-instOffsetBits), true.B) init_entry.isJalr := new_jmp_is_jalr init_entry.isCall := new_jmp_is_call - init_entry.isRet := new_jmp_is_ret - init_entry.isRetCall := new_jmp_is_ret_call + init_entry.hasRet := new_jmp_has_ret // that means fall thru points to the middle of an inst init_entry.last_may_be_rvi_call := pd.jmpOffset === (PredictWidth-1).U && !pd.rvcMask(pd.jmpOffset) @@ -374,8 +370,7 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire old_entry_modified.carry := (getLower(io.start_addr) +& new_pft_offset).head(1).asBool old_entry_modified.last_may_be_rvi_call := false.B old_entry_modified.isCall := false.B - old_entry_modified.isRet := false.B - old_entry_modified.isRetCall := false.B + old_entry_modified.hasRet := false.B old_entry_modified.isJalr := false.B } @@ -1449,15 +1444,14 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe p"brInEntry(${inFtbEntry}) brIdx(${brIdx}) target(${Hexadecimal(target)})\n") val logbundle = Wire(new FtqDebugBundle) - logbundle.pc := pc - logbundle.target := target - logbundle.isBr := isBr - logbundle.isJmp := isJmp - logbundle.isCall := isJmp && commit_pd.hasCall - logbundle.isRet := isJmp && commit_pd.hasRet - logbundle.isRetCall := isJmp && commit_pd.hasRetCall - logbundle.misPred := misPred - logbundle.isTaken := isTaken + logbundle.pc := pc + logbundle.target := target + logbundle.isBr := isBr + logbundle.isJmp := isJmp + logbundle.isCall := isJmp && commit_pd.isCall + logbundle.isRet := isJmp && commit_pd.isRet + logbundle.misPred := misPred + logbundle.isTaken := isTaken logbundle.predStage := commit_stage ftqBranchTraceDB.log( @@ -1502,8 +1496,8 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe val commit_jal_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasJal.asTypeOf(UInt(1.W))) val commit_jalr_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasJalr.asTypeOf(UInt(1.W))) - val commit_call_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasCall.asTypeOf(UInt(1.W))) - val commit_ret_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasRet.asTypeOf(UInt(1.W))) + val commit_call_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.isCall.asTypeOf(UInt(1.W))) + val commit_ret_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, 
commit_pd.isRet.asTypeOf(UInt(1.W))) val mbpBRights = mbpRights & commit_br_mask diff --git a/src/main/scala/xiangshan/frontend/PreDecode.scala b/src/main/scala/xiangshan/frontend/PreDecode.scala index c86cf5479fc..b84c068e6ae 100644 --- a/src/main/scala/xiangshan/frontend/PreDecode.scala +++ b/src/main/scala/xiangshan/frontend/PreDecode.scala @@ -37,8 +37,9 @@ trait HasPdConst extends HasXSParameter with HasICacheParameters with HasIFUCons val rs = Mux(isRVC(instr), Mux(brType === BrType.jal, 0.U, instr(11, 7)), instr(19, 15)) val isCall = (brType === BrType.jal && !isRVC(instr) || brType === BrType.jalr) && isLink(rd) // Only for RV64 val isRet = brType === BrType.jalr && isLink(rs) && !isCall - val isRetCall = (brType === BrType.jalr && isLink(rs) && isLink(rd) && (rs =/= rd)) - List(brType, isCall, isRet, isRetCall) + val isRetCall = brType === BrType.jalr && isLink(rs) && isLink(rd) && (rs =/= rd) + val hasRet = isRet || isRetCall + List(brType, isCall, hasRet) } def jal_offset(inst: UInt, rvc: Bool): UInt = { val rvc_offset = Cat(inst(12), inst(8), inst(10, 9), inst(6), inst(7), inst(2), inst(11), inst(5, 3), 0.U(1.W)) @@ -70,17 +71,20 @@ object ExcType { //TODO:add exctype } class PreDecodeInfo extends Bundle { // 8 bit - val valid = Bool() - val isRVC = Bool() - val brType = UInt(2.W) - val isCall = Bool() - val isRet = Bool() - val isRetCall = Bool() + val valid = Bool() + val isRVC = Bool() + val brType = UInt(2.W) + val isCall = Bool() + val hasRet = Bool() // maybe ret or ret-call + // val isCall = Bool() + // val isRet = Bool() //val excType = UInt(3.W) def isBr = brType === BrType.branch def isJal = brType === BrType.jal def isJalr = brType === BrType.jalr def notCFI = brType === BrType.notCFI + def isRet = hasRet && !isCall + def isRetCall = isCall && hasRet } class PreDecodeResp(implicit p: Parameters) extends XSBundle with HasPdConst { @@ -135,27 +139,26 @@ class PreDecode(implicit p: Parameters) extends XSModule with HasPdConst{ val currentPC = io.in.bits.pc(i) //expander.io.in := inst - val brType::isCall::isRet::isRetCall::Nil = brInfo(inst) + val brType::isCall::hasRet::Nil = brInfo(inst) val jalOffset = jal_offset(inst, currentIsRVC(i)) val brOffset = br_offset(inst, currentIsRVC(i)) - io.out.hasHalfValid(i) := h_validStart(i) + io.out.hasHalfValid(i) := h_validStart(i) - io.out.triggered(i) := DontCare//VecInit(Seq.fill(10)(false.B)) + io.out.triggered(i) := DontCare//VecInit(Seq.fill(10)(false.B)) - io.out.pd(i).valid := validStart(i) - io.out.pd(i).isRVC := currentIsRVC(i) + io.out.pd(i).valid := validStart(i) + io.out.pd(i).isRVC := currentIsRVC(i) // for diff purpose only - io.out.pd(i).brType := brType - io.out.pd(i).isCall := isCall - io.out.pd(i).isRet := isRet - io.out.pd(i).isRetCall := isRetCall + io.out.pd(i).brType := brType + io.out.pd(i).isCall := isCall + io.out.pd(i).hasRet := hasRet //io.out.expInstr(i) := expander.io.out.bits - io.out.instr(i) :=inst - io.out.jumpOffset(i) := Mux(io.out.pd(i).isBr, brOffset, jalOffset) + io.out.instr(i) :=inst + io.out.jumpOffset(i) := Mux(io.out.pd(i).isBr, brOffset, jalOffset) } // the first half is always reliable @@ -262,12 +265,11 @@ class F3Predecoder(implicit p: Parameters) extends XSModule with HasPdConst { val out = Output(new F3PreDecodeResp) }) io.out.pd.zipWithIndex.map{ case (pd,i) => - pd.valid := DontCare - pd.isRVC := DontCare - pd.brType := brInfo(io.in.instr(i))(0) - pd.isCall := brInfo(io.in.instr(i))(1) - pd.isRet := brInfo(io.in.instr(i))(2) - pd.isRetCall := brInfo(io.in.instr(i))(3) + pd.valid 
:= DontCare + pd.isRVC := DontCare + pd.brType := brInfo(io.in.instr(i))(0) + pd.isCall := brInfo(io.in.instr(i))(1) + pd.hasRet := brInfo(io.in.instr(i))(2) } } diff --git a/src/main/scala/xiangshan/frontend/newRAS.scala b/src/main/scala/xiangshan/frontend/newRAS.scala index ab431002fca..2d7a82cc1a7 100644 --- a/src/main/scala/xiangshan/frontend/newRAS.scala +++ b/src/main/scala/xiangshan/frontend/newRAS.scala @@ -104,37 +104,37 @@ class RAS(implicit p: Parameters) extends BasePredictor { val spec_push_valid = Input(Bool()) val spec_pop_valid = Input(Bool()) val spec_push_addr = Input(UInt(VAddrBits.W)) - val spec_ret_call = Input(Bool()) + val spec_has_ret = Input(Bool()) // for write bypass between s2 and s3 - val s2_fire = Input(Bool()) - val s3_fire = Input(Bool()) + val s2_fire = Input(Bool()) + val s3_fire = Input(Bool()) val s3_cancel = Input(Bool()) - val s3_meta = Input(new RASInternalMeta) - val s3_missed_pop = Input(Bool()) - val s3_missed_push = Input(Bool()) - val s3_pushAddr = Input(UInt(VAddrBits.W)) - val s3_ret_call = Input(Bool()) - val spec_pop_addr = Output(UInt(VAddrBits.W)) + val s3_meta = Input(new RASInternalMeta) + val s3_missed_pop = Input(Bool()) + val s3_missed_push = Input(Bool()) + val s3_pushAddr = Input(UInt(VAddrBits.W)) + val s3_has_ret = Input(Bool()) + val spec_pop_addr = Output(UInt(VAddrBits.W)) val commit_push_valid = Input(Bool()) val commit_pop_valid = Input(Bool()) - val commit_ret_call = Input(Bool()) - val commit_push_addr = Input(UInt(VAddrBits.W)) - val commit_meta_TOSW = Input(new RASPtr) + val commit_has_ret = Input(Bool()) + val commit_push_addr = Input(UInt(VAddrBits.W)) + val commit_meta_TOSW = Input(new RASPtr) // for debug purpose only - val commit_meta_ssp = Input(UInt(log2Up(RasSize).W)) + val commit_meta_ssp = Input(UInt(log2Up(RasSize).W)) val redirect_valid = Input(Bool()) val redirect_isCall = Input(Bool()) val redirect_isRet = Input(Bool()) - val redirect_ret_call = Input(Bool()) + val redirect_has_ret = Input(Bool()) val redirect_meta_ssp = Input(UInt(log2Up(RasSize).W)) - val redirect_meta_sctr = Input(UInt(RasCtrSize.W)) - val redirect_meta_TOSW = Input(new RASPtr) - val redirect_meta_TOSR = Input(new RASPtr) - val redirect_meta_NOS = Input(new RASPtr) - val redirect_callAddr = Input(UInt(VAddrBits.W)) + val redirect_meta_sctr = Input(UInt(RasCtrSize.W)) + val redirect_meta_TOSW = Input(new RASPtr) + val redirect_meta_TOSR = Input(new RASPtr) + val redirect_meta_NOS = Input(new RASPtr) + val redirect_callAddr = Input(UInt(VAddrBits.W)) val ssp = Output(UInt(log2Up(RasSize).W)) val sctr = Output(UInt(RasCtrSize.W)) @@ -258,11 +258,11 @@ class RAS(implicit p: Parameters) extends BasePredictor { val writeNos = Wire(new RASPtr) writeEntry.retAddr := Mux(io.redirect_valid && io.redirect_isCall, io.redirect_callAddr, io.spec_push_addr) writeEntry.ctr := Mux(io.redirect_valid && io.redirect_isCall, - Mux(io.redirect_ret_call, 0.U, Mux(redirectTopEntry.retAddr === io.redirect_callAddr && redirectTopEntry.ctr < ctrMax, io.redirect_meta_sctr + 1.U, 0.U)), - Mux(io.spec_ret_call, 0.U, Mux(topEntry.retAddr === io.spec_push_addr && topEntry.ctr < ctrMax, sctr + 1.U, 0.U))) + Mux(io.redirect_has_ret, 0.U, Mux(redirectTopEntry.retAddr === io.redirect_callAddr && redirectTopEntry.ctr < ctrMax, io.redirect_meta_sctr + 1.U, 0.U)), + Mux(io.spec_has_ret, 0.U, Mux(topEntry.retAddr === io.spec_push_addr && topEntry.ctr < ctrMax, sctr + 1.U, 0.U))) writeNos := Mux(io.redirect_valid && io.redirect_isCall, - Mux(io.redirect_ret_call, 
io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_ret_call, topNos,TOSR)) + Mux(io.redirect_has_ret, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_has_ret, topNos,TOSR)) when (io.spec_push_valid || (io.redirect_valid && io.redirect_isCall)) { writeBypassEntry := writeEntry @@ -393,9 +393,9 @@ class RAS(implicit p: Parameters) extends BasePredictor { val realWriteAddr = Mux(io.redirect_isCall, realWriteAddr_next, Mux(io.s3_missed_push, s3_missPushAddr, realWriteAddr_next)) - val realNos_next = RegEnable(Mux(io.redirect_valid && io.redirect_isCall, Mux(io.redirect_ret_call, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_ret_call, topNos, TOSR)), io.s2_fire || (io.redirect_valid && io.redirect_isCall)) + val realNos_next = RegEnable(Mux(io.redirect_valid && io.redirect_isCall, Mux(io.redirect_has_ret, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_has_ret, topNos, TOSR)), io.s2_fire || (io.redirect_valid && io.redirect_isCall)) val realNos = Mux(io.redirect_isCall, realNos_next, - Mux(io.s3_missed_push, Mux(io.s3_ret_call, io.s3_meta.NOS, io.s3_meta.TOSR), + Mux(io.s3_missed_push, Mux(io.s3_has_ret, io.s3_meta.NOS, io.s3_meta.TOSR), realNos_next)) realPush := (io.s3_fire && (!io.s3_cancel && RegEnable(io.spec_push_valid, io.s2_fire) || io.s3_missed_push)) || RegNext(io.redirect_valid && io.redirect_isCall) @@ -423,7 +423,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { } when (io.spec_push_valid) { - when(!io.spec_ret_call) { + when(!io.spec_has_ret) { specPush(io.spec_push_addr, ssp, sctr, TOSR, TOSW, topEntry) } .otherwise { TOSR := TOSW @@ -479,7 +479,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { } when (io.s3_missed_push) { // do not use any bypass from f2 - when(!io.s3_ret_call) { + when(!io.s3_has_ret) { specPush(io.s3_pushAddr, io.s3_meta.ssp, io.s3_meta.sctr, io.s3_meta.TOSR, io.s3_meta.TOSW, s3TopEntry) }.otherwise { TOSR := io.s3_meta.TOSW @@ -525,7 +525,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { nsp_update := nsp } // if ctr < max && topAddr == push addr, ++ctr, otherwise ++nsp - when(!io.commit_ret_call){ + when(!io.commit_has_ret){ when (commitTop.ctr < ctrMax && commitTop.retAddr === commit_push_addr) { commit_stack(nsp_update).ctr := commitTop.ctr + 1.U nsp := nsp_update @@ -563,7 +563,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { sctr := io.redirect_meta_sctr when (io.redirect_isCall) { - when (!io.redirect_ret_call) { + when (!io.redirect_has_ret) { specPush(io.redirect_callAddr, io.redirect_meta_ssp, io.redirect_meta_sctr, io.redirect_meta_TOSR, io.redirect_meta_TOSW, redirectTopEntry) } .otherwise { TOSR := io.redirect_meta_TOSW @@ -586,19 +586,19 @@ class RAS(implicit p: Parameters) extends BasePredictor { val s2_spec_push = WireInit(false.B) val s2_spec_pop = WireInit(false.B) - val s2_is_ret_call = WireInit(false.B) - val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred(2) + val s2_has_ret = WireInit(false.B) + val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred(2) // when last inst is an rvi call, fall through address would be set to the middle of it, so an addition is needed val s2_spec_new_addr = s2_full_pred.fallThroughAddr + Mux(s2_full_pred.last_may_be_rvi_call, 2.U, 0.U) stack.spec_push_valid := s2_spec_push stack.spec_pop_valid := s2_spec_pop stack.spec_push_addr := s2_spec_new_addr - stack.spec_ret_call := s2_is_ret_call + stack.spec_has_ret := s2_has_ret // confirm that the call/ret is the taken cfi - s2_spec_push := io.s2_fire(2) && 
s2_full_pred.hit_taken_on_call && !io.s3_redirect(2) - s2_spec_pop := io.s2_fire(2) && s2_full_pred.hit_taken_on_ret && !io.s3_redirect(2) - s2_is_ret_call := s2_full_pred.is_ret_call + s2_spec_push := io.s2_fire(2) && s2_full_pred.hit_taken_on_call && !io.s3_redirect(2) + s2_spec_pop := io.s2_fire(2) && s2_full_pred.hit_taken_on_ret && !io.s3_redirect(2) + s2_has_ret := s2_full_pred.has_ret //val s2_jalr_target = io.out.s2.full_pred.jalr_target //val s2_last_target_in = s2_full_pred.targets.last @@ -645,7 +645,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { val s3_popped_in_s2 = RegEnable(s2_spec_pop, io.s2_fire(2)) val s3_push = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_call val s3_pop = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_ret - val s3_ret_call = io.in.bits.resp_in(0).s3.full_pred(2).is_ret_call + val s3_has_ret = io.in.bits.resp_in(0).s3.full_pred(2).has_ret val s3_cancel = io.s3_fire(2) && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop) stack.s2_fire := io.s2_fire(2) @@ -656,10 +656,10 @@ class RAS(implicit p: Parameters) extends BasePredictor { val s3_meta = RegEnable(s2_meta, io.s2_fire(2)) stack.s3_meta := s3_meta - stack.s3_missed_pop := s3_pop && !s3_popped_in_s2 - stack.s3_missed_push := s3_push && !s3_pushed_in_s2 + stack.s3_missed_pop := s3_pop && !s3_popped_in_s2 + stack.s3_missed_push := s3_push && !s3_pushed_in_s2 stack.s3_pushAddr := s3_spec_new_addr - stack.s3_ret_call := s3_ret_call + stack.s3_has_ret := s3_has_ret // no longer need the top Entry, but TOSR, TOSW, ssp sctr // TODO: remove related signals @@ -685,27 +685,27 @@ class RAS(implicit p: Parameters) extends BasePredictor { val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall // when we mispredict a call, we must redo a push operation // similarly, when we mispredict a return, we should redo a pop - stack.redirect_valid := do_recover + stack.redirect_valid := do_recover stack.redirect_isCall := callMissPred - stack.redirect_isRet := retMissPred - stack.redirect_ret_call := recover_cfi.pd.isRetCall - stack.redirect_meta_ssp := recover_cfi.ssp - stack.redirect_meta_sctr := recover_cfi.sctr - stack.redirect_meta_TOSW := recover_cfi.TOSW - stack.redirect_meta_TOSR := recover_cfi.TOSR - stack.redirect_meta_NOS := recover_cfi.NOS - stack.redirect_callAddr := recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U) + stack.redirect_isRet := retMissPred + stack.redirect_has_ret := recover_cfi.pd.isRetCall + stack.redirect_meta_ssp := recover_cfi.ssp + stack.redirect_meta_sctr := recover_cfi.sctr + stack.redirect_meta_TOSW := recover_cfi.TOSW + stack.redirect_meta_TOSR := recover_cfi.TOSR + stack.redirect_meta_NOS := recover_cfi.NOS + stack.redirect_callAddr := recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U) val update = io.update.bits val updateMeta = io.update.bits.meta.asTypeOf(new RASMeta) val updateValid = io.update.valid stack.commit_push_valid := updateValid && update.is_call_taken - stack.commit_pop_valid := updateValid && update.is_ret_taken - stack.commit_ret_call := update.is_ret_call - stack.commit_push_addr := update.ftb_entry.getFallThrough(update.pc) + Mux(update.ftb_entry.last_may_be_rvi_call, 2.U, 0.U) - stack.commit_meta_TOSW := updateMeta.TOSW - stack.commit_meta_ssp := updateMeta.ssp + stack.commit_pop_valid := updateValid && update.is_ret_taken + stack.commit_has_ret := update.has_ret + stack.commit_push_addr := update.ftb_entry.getFallThrough(update.pc) + Mux(update.ftb_entry.last_may_be_rvi_call, 2.U, 0.U) + 
stack.commit_meta_TOSW := updateMeta.TOSW + stack.commit_meta_ssp := updateMeta.ssp XSPerfAccumulate("ras_s3_cancel", s3_cancel) @@ -747,4 +747,4 @@ class RAS(implicit p: Parameters) extends BasePredictor { */ generatePerfEvent() -} +} \ No newline at end of file From b7adbee8ed0fb40e401bfd519d3e86519a775932 Mon Sep 17 00:00:00 2001 From: zhou tao Date: Fri, 26 Jul 2024 01:12:07 +0800 Subject: [PATCH 3/3] RAS: adjust RAS module related signal interface names --- src/main/scala/xiangshan/frontend/FTB.scala | 21 ++- .../xiangshan/frontend/FrontendBundle.scala | 24 ++- src/main/scala/xiangshan/frontend/IFU.scala | 24 +-- .../scala/xiangshan/frontend/NewFtq.scala | 54 +++---- .../scala/xiangshan/frontend/PreDecode.scala | 51 +++--- .../scala/xiangshan/frontend/newRAS.scala | 147 ++++++++---------- 6 files changed, 149 insertions(+), 172 deletions(-) diff --git a/src/main/scala/xiangshan/frontend/FTB.scala b/src/main/scala/xiangshan/frontend/FTB.scala index a24f77dc12e..69ac893f756 100644 --- a/src/main/scala/xiangshan/frontend/FTB.scala +++ b/src/main/scala/xiangshan/frontend/FTB.scala @@ -145,13 +145,12 @@ class FtbSlot(val offsetLen: Int, val subOffsetLen: Option[Int] = None)(implicit class FTBEntry_part(implicit p: Parameters) extends XSBundle with FTBParams with BPUUtils { - val isCall = Bool() - val hasRet = Bool() // maybe ret or ret-call - val isJalr = Bool() + val isCall = Bool() + val isRet = Bool() + val isJalr = Bool() - def isJal = !isJalr - def isRet = !isCall && hasRet - def isRetCall = isCall && hasRet + def isJal = !isJalr + def onlyRet = isRet && !isCall } class FTBEntry_FtqMem(implicit p: Parameters) extends FTBEntry_part with FTBParams with BPUUtils { @@ -373,7 +372,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with val pftAddrDiff = this.pftAddr === that.pftAddr val carryDiff = this.carry === that.carry val isCallDiff = this.isCall === that.isCall - val hasRetDiff = this.hasRet === that.hasRet + val isRetDiff = this.isRet === that.isRet val isJalrDiff = this.isJalr === that.isJalr val lastMayBeRviCallDiff = this.last_may_be_rvi_call === that.last_may_be_rvi_call val alwaysTakenDiff : IndexedSeq[Bool] = @@ -387,7 +386,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with pftAddrDiff, carryDiff, isCallDiff, - hasRetDiff, + isRetDiff, isJalrDiff, lastMayBeRviCallDiff, alwaysTakenDiff.reduce(_&&_) @@ -404,7 +403,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with XSDebug(cond, p"[tailSlot]: v=${tailSlot.valid}, offset=${tailSlot.offset}," + p"lower=${Hexadecimal(tailSlot.lower)}, sharing=${tailSlot.sharing}}\n") XSDebug(cond, p"pftAddr=${Hexadecimal(pftAddr)}, carry=$carry\n") - XSDebug(cond, p"isCall=$isCall, hasRet=$hasRet, isjalr=$isJalr\n") + XSDebug(cond, p"isCall=$isCall, isRet=$isRet, isjalr=$isJalr\n") XSDebug(cond, p"last_may_be_rvi_call=$last_may_be_rvi_call\n") XSDebug(cond, p"------------------------------- \n") } @@ -739,7 +738,7 @@ class FTB(implicit p: Parameters) extends BasePredictor with FTBParams with BPUU io.out.s1_ftbCloseReq := s1_close_ftb_req io.out.s1_uftbHit := io.fauftb_entry_hit_in val s1_uftbHasIndirect = io.fauftb_entry_in.jmpValid && - io.fauftb_entry_in.isJalr && !io.fauftb_entry_in.isRet // uFTB determines that it's real JALR, RET and JAL are excluded + io.fauftb_entry_in.isJalr && !io.fauftb_entry_in.onlyRet // uFTB determines that it's real JALR, only RET and JAL are excluded io.out.s1_uftbHasIndirect := s1_uftbHasIndirect // always taken logic 
@@ -815,4 +814,4 @@ class FTB(implicit p: Parameters) extends BasePredictor with FTBParams with BPUU ("ftb_commit_misses ", io.update.valid && !u_meta.hit), ) generatePerfEvent() -} +} \ No newline at end of file diff --git a/src/main/scala/xiangshan/frontend/FrontendBundle.scala b/src/main/scala/xiangshan/frontend/FrontendBundle.scala index 9164538d789..574a3766b62 100644 --- a/src/main/scala/xiangshan/frontend/FrontendBundle.scala +++ b/src/main/scala/xiangshan/frontend/FrontendBundle.scala @@ -430,11 +430,10 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC val fallThroughErr = Bool() val multiHit = Bool() - val is_jal = Bool() + val is_jal = Bool() val is_jalr = Bool() val is_call = Bool() - val is_ret = Bool() - val has_ret = Bool() // only used for the ret-call behavior in RAS + val is_ret = Bool() val last_may_be_rvi_call = Bool() val is_br_sharing = Bool() @@ -509,7 +508,7 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC !real_slot_taken_mask().init.reduce(_||_) && real_slot_taken_mask().last && !is_br_sharing def hit_taken_on_call = hit_taken_on_jmp && is_call - def hit_taken_on_ret = hit_taken_on_jmp && is_ret + def hit_taken_on_only_ret = hit_taken_on_jmp && is_ret && !is_call def hit_taken_on_jalr = hit_taken_on_jmp && is_jalr def cfiIndex = { @@ -537,8 +536,7 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC is_jal := entry.tailSlot.valid && entry.isJal is_jalr := entry.tailSlot.valid && entry.isJalr is_call := entry.tailSlot.valid && entry.isCall - is_ret := entry.tailSlot.valid && entry.hasRet && !entry.isCall - has_ret := entry.hasRet + is_ret := entry.tailSlot.valid && entry.isRet last_may_be_rvi_call := entry.last_may_be_rvi_call is_br_sharing := entry.tailSlot.valid && entry.tailSlot.sharing predCycle.map(_ := GTimer()) @@ -649,14 +647,13 @@ class BranchPredictionUpdate(implicit p: Parameters) extends XSBundle with HasBP val from_stage = UInt(2.W) val ghist = UInt(HistoryLength.W) - def is_jal = ftb_entry.tailSlot.valid && ftb_entry.isJal + def is_jal = ftb_entry.tailSlot.valid && ftb_entry.isJal def is_jalr = ftb_entry.tailSlot.valid && ftb_entry.isJalr def is_call = ftb_entry.tailSlot.valid && ftb_entry.isCall - def is_ret = ftb_entry.tailSlot.valid && ftb_entry.hasRet && !ftb_entry.isCall - def has_ret = ftb_entry.hasRet + def is_ret = ftb_entry.tailSlot.valid && ftb_entry.isRet def is_call_taken = is_call && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset - def is_ret_taken = is_ret && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset + def is_only_ret_taken = !is_call && is_ret && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset def display(cond: Bool) = { XSDebug(cond, p"-----------BranchPredictionUpdate-----------\n") @@ -693,8 +690,9 @@ class BranchPredictionRedirect(implicit p: Parameters) extends Redirect with Has def ControlBTBMissBubble = ControlRedirectBubble && !cfiUpdate.br_hit && !cfiUpdate.jr_hit def TAGEMissBubble = ControlRedirectBubble && cfiUpdate.br_hit && !cfiUpdate.sc_hit def SCMissBubble = ControlRedirectBubble && cfiUpdate.br_hit && cfiUpdate.sc_hit - def ITTAGEMissBubble = ControlRedirectBubble && cfiUpdate.jr_hit && !cfiUpdate.pd.isRet - def RASMissBubble = ControlRedirectBubble && cfiUpdate.jr_hit && cfiUpdate.pd.isRet + // ret-call instruction will jump in the same way as the call instruction + def ITTAGEMissBubble = ControlRedirectBubble && cfiUpdate.jr_hit && 
(!cfiUpdate.pd.isRet || cfiUpdate.pd.isCall) + def RASMissBubble = ControlRedirectBubble && cfiUpdate.jr_hit && (cfiUpdate.pd.isRet && !cfiUpdate.pd.isCall) def MemVioRedirectBubble = debugIsMemVio def OtherRedirectBubble = !debugIsCtrl && !debugIsMemVio @@ -723,4 +721,4 @@ class BranchPredictionRedirect(implicit p: Parameters) extends Redirect with Has XSDebug(cond, p"[stFtqOffset] ${stFtqOffset}\n") XSDebug(cond, p"---------------------------------------------- \n") } -} +} \ No newline at end of file diff --git a/src/main/scala/xiangshan/frontend/IFU.scala b/src/main/scala/xiangshan/frontend/IFU.scala index 48d168de057..21761fe394a 100644 --- a/src/main/scala/xiangshan/frontend/IFU.scala +++ b/src/main/scala/xiangshan/frontend/IFU.scala @@ -547,15 +547,15 @@ class NewIFU(implicit p: Parameters) extends XSModule // Expand 1 bit to prevent overflow when assert val f3_ftq_req_startAddr = Cat(0.U(1.W), f3_ftq_req.startAddr) val f3_ftq_req_nextStartAddr = Cat(0.U(1.W), f3_ftq_req.nextStartAddr) - // brType, isCall and hasRet generation is delayed to f3 stage + // brType, isCall and isRet generation is delayed to f3 stage val f3Predecoder = Module(new F3Predecoder) f3Predecoder.io.in.instr := f3_instr f3_pd.zipWithIndex.map{ case (pd,i) => - pd.brType := f3Predecoder.io.out.pd(i).brType - pd.isCall := f3Predecoder.io.out.pd(i).isCall - pd.hasRet := f3Predecoder.io.out.pd(i).hasRet + pd.brType := f3Predecoder.io.out.pd(i).brType + pd.isCall := f3Predecoder.io.out.pd(i).isCall + pd.isRet := f3Predecoder.io.out.pd(i).isRet } val f3PdDiff = f3_pd_wire.zip(f3_pd).map{ case (a,b) => a.asUInt =/= b.asUInt }.reduce(_||_) @@ -863,7 +863,7 @@ class NewIFU(implicit p: Parameters) extends XSModule val inst = Cat(f3_mmio_data(1), f3_mmio_data(0)) val currentIsRVC = isRVC(inst) - val brType::isCall::hasRet::Nil = brInfo(inst) + val brType::isCall::isRet::Nil = brInfo(inst) val jalOffset = jal_offset(inst, currentIsRVC) val brOffset = br_offset(inst, currentIsRVC) @@ -874,7 +874,7 @@ class NewIFU(implicit p: Parameters) extends XSModule io.toIbuffer.bits.pd(0).isRVC := currentIsRVC io.toIbuffer.bits.pd(0).brType := brType io.toIbuffer.bits.pd(0).isCall := isCall - io.toIbuffer.bits.pd(0).hasRet := hasRet + io.toIbuffer.bits.pd(0).isRet := isRet when (mmio_resend_af) { io.toIbuffer.bits.exceptionType(0) := ExceptionType.acf @@ -887,11 +887,11 @@ class NewIFU(implicit p: Parameters) extends XSModule io.toIbuffer.bits.enqEnable := f3_mmio_range.asUInt - mmioFlushWb.bits.pd(0).valid := true.B - mmioFlushWb.bits.pd(0).isRVC := currentIsRVC - mmioFlushWb.bits.pd(0).brType := brType - mmioFlushWb.bits.pd(0).isCall := isCall - mmioFlushWb.bits.pd(0).hasRet := hasRet + mmioFlushWb.bits.pd(0).valid := true.B + mmioFlushWb.bits.pd(0).isRVC := currentIsRVC + mmioFlushWb.bits.pd(0).brType := brType + mmioFlushWb.bits.pd(0).isCall := isCall + mmioFlushWb.bits.pd(0).isRet := isRet } mmio_redirect := (f3_req_is_mmio && mmio_state === m_waitCommit && RegNext(fromUncache.fire) && f3_mmio_use_seq_pc) @@ -1077,4 +1077,4 @@ class NewIFU(implicit p: Parameters) extends XSModule reset = reset ) -} +} \ No newline at end of file diff --git a/src/main/scala/xiangshan/frontend/NewFtq.scala b/src/main/scala/xiangshan/frontend/NewFtq.scala index 4327910c47b..38c8b0c526e 100644 --- a/src/main/scala/xiangshan/frontend/NewFtq.scala +++ b/src/main/scala/xiangshan/frontend/NewFtq.scala @@ -29,13 +29,13 @@ import utility.ChiselDB class FtqDebugBundle extends Bundle { val pc = UInt(39.W) - val target = UInt(39.W) - val isBr = Bool() - val 
isJmp = Bool() - val isCall = Bool() - val isRet = Bool() - val misPred = Bool() - val isTaken = Bool() + val target = UInt(39.W) + val isBr = Bool() + val isJmp = Bool() + val isCall = Bool() + val isRet = Bool() + val misPred = Bool() + val isTaken = Bool() val predStage = UInt(2.W) } @@ -109,23 +109,21 @@ class Ftq_RF_Components(implicit p: Parameters) extends XSBundle with BPUUtils { class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle { val brMask = Vec(PredictWidth, Bool()) - // jmpInfo(0) = jalr, jmpInfo(1) = isCall, jmpInfo(2) = hasRet val jmpInfo = ValidUndirectioned(Vec(3, Bool())) val jmpOffset = UInt(log2Ceil(PredictWidth).W) val jalTarget = UInt(VAddrBits.W) val rvcMask = Vec(PredictWidth, Bool()) def hasJal = jmpInfo.valid && !jmpInfo.bits(0) def hasJalr = jmpInfo.valid && jmpInfo.bits(0) - def isCall = jmpInfo.valid && jmpInfo.bits(1) - def isRet = jmpInfo.valid && jmpInfo.bits(2) && !jmpInfo.bits(1) - def isRetCall = jmpInfo.valid && jmpInfo.bits(2) && jmpInfo.bits(1) + def hasCall = jmpInfo.valid && jmpInfo.bits(1) + def hasRet = jmpInfo.valid && jmpInfo.bits(2) def fromPdWb(pdWb: PredecodeWritebackBundle) = { val pds = pdWb.pd this.brMask := VecInit(pds.map(pd => pd.isBr && pd.valid)) this.jmpInfo.valid := VecInit(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid)).asUInt.orR this.jmpInfo.bits := ParallelPriorityMux(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid), - pds.map(pd => VecInit(pd.isJalr, pd.isCall, pd.hasRet))) + pds.map(pd => VecInit(pd.isJalr, pd.isCall, pd.isRet))) this.jmpOffset := ParallelPriorityEncoder(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid)) this.rvcMask := VecInit(pds.map(pd => pd.isRVC)) this.jalTarget := pdWb.jalTarget @@ -138,9 +136,9 @@ class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle { pd.isRVC := rvcMask(offset) val isBr = brMask(offset) val isJalr = offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(0) - pd.brType := Cat(offset === jmpOffset && jmpInfo.valid, isJalr || isBr) - pd.isCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(1) - pd.hasRet := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(2) + pd.brType := Cat(offset === jmpOffset && jmpInfo.valid, isJalr || isBr) + pd.isCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(1) + pd.isRet := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(2) pd } } @@ -274,7 +272,7 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire val new_jmp_is_jal = entry_has_jmp && !pd.jmpInfo.bits(0) && io.cfiIndex.valid val new_jmp_is_jalr = entry_has_jmp && pd.jmpInfo.bits(0) && io.cfiIndex.valid val new_jmp_is_call = entry_has_jmp && pd.jmpInfo.bits(1) && io.cfiIndex.valid - val new_jmp_has_ret = entry_has_jmp && pd.jmpInfo.bits(2) && io.cfiIndex.valid + val new_jmp_is_ret = entry_has_jmp && pd.jmpInfo.bits(2) && io.cfiIndex.valid val last_jmp_rvi = entry_has_jmp && pd.jmpOffset === (PredictWidth-1).U && !pd.rvcMask.last // val last_br_rvi = cfi_is_br && io.cfiIndex.bits === (PredictWidth-1).U && !pd.rvcMask.last @@ -306,9 +304,9 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire val jmpPft = getLower(io.start_addr) +& pd.jmpOffset +& Mux(pd.rvcMask(pd.jmpOffset), 1.U, 2.U) init_entry.pftAddr := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft, getLower(io.start_addr)) init_entry.carry := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft(carryPos-instOffsetBits), true.B) - init_entry.isJalr := new_jmp_is_jalr - init_entry.isCall := new_jmp_is_call - init_entry.hasRet := new_jmp_has_ret + 
init_entry.isJalr := new_jmp_is_jalr + init_entry.isCall := new_jmp_is_call + init_entry.isRet := new_jmp_is_ret // that means fall thru points to the middle of an inst init_entry.last_may_be_rvi_call := pd.jmpOffset === (PredictWidth-1).U && !pd.rvcMask(pd.jmpOffset) @@ -370,7 +368,7 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire old_entry_modified.carry := (getLower(io.start_addr) +& new_pft_offset).head(1).asBool old_entry_modified.last_may_be_rvi_call := false.B old_entry_modified.isCall := false.B - old_entry_modified.hasRet := false.B + old_entry_modified.isRet := false.B old_entry_modified.isJalr := false.B } @@ -969,8 +967,7 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe ((pred_ftb_entry.isJal && !(jmp_pd.valid && jmp_pd.isJal)) || (pred_ftb_entry.isJalr && !(jmp_pd.valid && jmp_pd.isJalr)) || (pred_ftb_entry.isCall && !(jmp_pd.valid && jmp_pd.isCall)) || - (pred_ftb_entry.isRet && !(jmp_pd.valid && jmp_pd.isRet)) || - (pred_ftb_entry.isRetCall && !(jmp_pd.valid && jmp_pd.isRetCall)) + (pred_ftb_entry.isRet && !(jmp_pd.valid && jmp_pd.isRet)) ) has_false_hit := br_false_hit || jal_false_hit || hit_pd_mispred_reg @@ -1367,7 +1364,6 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe update.spec_info := commit_spec_meta XSError(commit_valid && do_commit && debug_cfi, "\ncommit cfi can be non c_commited\n") - val commit_real_hit = commit_hit === h_hit val update_ftb_entry = update.ftb_entry @@ -1448,8 +1444,8 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe logbundle.target := target logbundle.isBr := isBr logbundle.isJmp := isJmp - logbundle.isCall := isJmp && commit_pd.isCall - logbundle.isRet := isJmp && commit_pd.isRet + logbundle.isCall := isJmp && commit_pd.hasCall + logbundle.isRet := isJmp && commit_pd.hasRet logbundle.misPred := misPred logbundle.isTaken := isTaken logbundle.predStage := commit_stage @@ -1483,8 +1479,6 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe XSPerfAccumulate("fromBackendRedirect_ValidNum", io.fromBackend.redirect.valid) XSPerfAccumulate("toBpuRedirect_ValidNum", io.toBpu.redirect.valid) - XSPerfAccumulate("ret_call_Num", io.toBpu.update.valid && io.toBpu.update.bits.ftb_entry.isRetCall && io.toBpu.update.bits.ftb_entry.valid && io.toBpu.update.bits.ftb_entry.tailSlot.valid) - val from_bpu = io.fromBpu.resp.bits val to_ifu = io.toIfu.req.bits @@ -1496,8 +1490,8 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe val commit_jal_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasJal.asTypeOf(UInt(1.W))) val commit_jalr_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasJalr.asTypeOf(UInt(1.W))) - val commit_call_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.isCall.asTypeOf(UInt(1.W))) - val commit_ret_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.isRet.asTypeOf(UInt(1.W))) + val commit_call_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasCall.asTypeOf(UInt(1.W))) + val commit_ret_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasRet.asTypeOf(UInt(1.W))) val mbpBRights = mbpRights & commit_br_mask @@ -1692,4 +1686,4 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe ("ftb_hit ", PopCount(ftb_hit) ), ) generatePerfEvent() -} +} \ No newline at end of file diff --git 
a/src/main/scala/xiangshan/frontend/PreDecode.scala b/src/main/scala/xiangshan/frontend/PreDecode.scala index b84c068e6ae..6484947048d 100644 --- a/src/main/scala/xiangshan/frontend/PreDecode.scala +++ b/src/main/scala/xiangshan/frontend/PreDecode.scala @@ -35,11 +35,9 @@ trait HasPdConst extends HasXSParameter with HasICacheParameters with HasIFUCons val brType::Nil = ListLookup(instr, List(BrType.notCFI), PreDecodeInst.brTable) val rd = Mux(isRVC(instr), instr(12), instr(11,7)) val rs = Mux(isRVC(instr), Mux(brType === BrType.jal, 0.U, instr(11, 7)), instr(19, 15)) - val isCall = (brType === BrType.jal && !isRVC(instr) || brType === BrType.jalr) && isLink(rd) // Only for RV64 - val isRet = brType === BrType.jalr && isLink(rs) && !isCall - val isRetCall = brType === BrType.jalr && isLink(rs) && isLink(rd) && (rs =/= rd) - val hasRet = isRet || isRetCall - List(brType, isCall, hasRet) + val isCall = (brType === BrType.jal && !isRVC(instr) || brType === BrType.jalr) && isLink(rd) // Only for RV64 + val isRet = (brType === BrType.jalr && isLink(rs) && !isCall) || (brType === BrType.jalr && isLink(rs) && isLink(rd) && (rs =/= rd)) + List(brType, isCall, isRet) } def jal_offset(inst: UInt, rvc: Bool): UInt = { val rvc_offset = Cat(inst(12), inst(8), inst(10, 9), inst(6), inst(7), inst(2), inst(11), inst(5, 3), 0.U(1.W)) @@ -58,7 +56,7 @@ trait HasPdConst extends HasXSParameter with HasICacheParameters with HasIFUCons } object BrType { - def notCFI = "b00".U + def notCFI = "b00".U def branch = "b01".U def jal = "b10".U def jalr = "b11".U @@ -66,7 +64,7 @@ object BrType { } object ExcType { //TODO:add exctype - def notExc = "b000".U + def notExc = "b000".U def apply() = UInt(3.W) } @@ -75,16 +73,13 @@ class PreDecodeInfo extends Bundle { // 8 bit val isRVC = Bool() val brType = UInt(2.W) val isCall = Bool() - val hasRet = Bool() // maybe ret or ret-call - // val isCall = Bool() - // val isRet = Bool() + val isRet = Bool() //val excType = UInt(3.W) def isBr = brType === BrType.branch def isJal = brType === BrType.jal def isJalr = brType === BrType.jalr def notCFI = brType === BrType.notCFI - def isRet = hasRet && !isCall - def isRetCall = isCall && hasRet + def onlyRet = isRet && !isCall } class PreDecodeResp(implicit p: Parameters) extends XSBundle with HasPdConst { @@ -139,26 +134,26 @@ class PreDecode(implicit p: Parameters) extends XSModule with HasPdConst{ val currentPC = io.in.bits.pc(i) //expander.io.in := inst - val brType::isCall::hasRet::Nil = brInfo(inst) + val brType::isCall::isRet::Nil = brInfo(inst) val jalOffset = jal_offset(inst, currentIsRVC(i)) val brOffset = br_offset(inst, currentIsRVC(i)) - io.out.hasHalfValid(i) := h_validStart(i) + io.out.hasHalfValid(i) := h_validStart(i) - io.out.triggered(i) := DontCare//VecInit(Seq.fill(10)(false.B)) + io.out.triggered(i) := DontCare//VecInit(Seq.fill(10)(false.B)) - io.out.pd(i).valid := validStart(i) - io.out.pd(i).isRVC := currentIsRVC(i) + io.out.pd(i).valid := validStart(i) + io.out.pd(i).isRVC := currentIsRVC(i) // for diff purpose only - io.out.pd(i).brType := brType - io.out.pd(i).isCall := isCall - io.out.pd(i).hasRet := hasRet + io.out.pd(i).brType := brType + io.out.pd(i).isCall := isCall + io.out.pd(i).isRet := isRet //io.out.expInstr(i) := expander.io.out.bits - io.out.instr(i) :=inst - io.out.jumpOffset(i) := Mux(io.out.pd(i).isBr, brOffset, jalOffset) + io.out.instr(i) :=inst + io.out.jumpOffset(i) := Mux(io.out.pd(i).isBr, brOffset, jalOffset) } // the first half is always reliable @@ -267,9 +262,9 @@ class 
F3Predecoder(implicit p: Parameters) extends XSModule with HasPdConst { io.out.pd.zipWithIndex.map{ case (pd,i) => pd.valid := DontCare pd.isRVC := DontCare - pd.brType := brInfo(io.in.instr(i))(0) - pd.isCall := brInfo(io.in.instr(i))(1) - pd.hasRet := brInfo(io.in.instr(i))(2) + pd.brType := brInfo(io.in.instr(i))(0) + pd.isCall := brInfo(io.in.instr(i))(1) + pd.isRet := brInfo(io.in.instr(i))(2) } } @@ -349,7 +344,7 @@ class PredChecker(implicit p: Parameters) extends XSModule with HasPdConst { //Stage 1: detect remask fault /** first check: remask Fault */ jalFaultVec := VecInit(pds.zipWithIndex.map{case(pd, i) => pd.isJal && instrRange(i) && instrValid(i) && (takenIdx > i.U && predTaken || !predTaken) }) - retFaultVec := VecInit(pds.zipWithIndex.map{case(pd, i) => pd.isRet && instrRange(i) && instrValid(i) && (takenIdx > i.U && predTaken || !predTaken) }) + retFaultVec := VecInit(pds.zipWithIndex.map{case(pd, i) => pd.onlyRet && instrRange(i) && instrValid(i) && (takenIdx > i.U && predTaken || !predTaken) }) val remaskFault = VecInit((0 until PredictWidth).map(i => jalFaultVec(i) || retFaultVec(i))) val remaskIdx = ParallelPriorityEncoder(remaskFault.asUInt) val needRemask = ParallelOR(remaskFault) @@ -357,7 +352,7 @@ class PredChecker(implicit p: Parameters) extends XSModule with HasPdConst { io.out.stage1Out.fixedRange := fixedRange.asTypeOf((Vec(PredictWidth, Bool()))) - io.out.stage1Out.fixedTaken := VecInit(pds.zipWithIndex.map{case(pd, i) => instrValid (i) && fixedRange(i) && (pd.isRet || pd.isJal || takenIdx === i.U && predTaken && !pd.notCFI) }) + io.out.stage1Out.fixedTaken := VecInit(pds.zipWithIndex.map{case(pd, i) => instrValid (i) && fixedRange(i) && (pd.onlyRet || pd.isJal || takenIdx === i.U && predTaken && !pd.notCFI) }) /** second check: faulse prediction fault and target fault */ notCFITaken := VecInit(pds.zipWithIndex.map{case(pd, i) => fixedRange(i) && instrValid(i) && i.U === takenIdx && pd.notCFI && predTaken }) @@ -442,4 +437,4 @@ class FrontendTrigger(implicit p: Parameters) extends XSModule with SdtrigExt { } io.triggered.foreach(_.backendCanFire := VecInit(Seq.fill(TriggerNum)(false.B))) io.triggered.foreach(_.backendHit := VecInit(Seq.fill(TriggerNum)(false.B))) -} +} \ No newline at end of file diff --git a/src/main/scala/xiangshan/frontend/newRAS.scala b/src/main/scala/xiangshan/frontend/newRAS.scala index 2d7a82cc1a7..e0ed432638d 100644 --- a/src/main/scala/xiangshan/frontend/newRAS.scala +++ b/src/main/scala/xiangshan/frontend/newRAS.scala @@ -103,8 +103,8 @@ class RAS(implicit p: Parameters) extends BasePredictor { val io = IO(new Bundle { val spec_push_valid = Input(Bool()) val spec_pop_valid = Input(Bool()) + val spec_is_ret = Input(Bool()) val spec_push_addr = Input(UInt(VAddrBits.W)) - val spec_has_ret = Input(Bool()) // for write bypass between s2 and s3 val s2_fire = Input(Bool()) @@ -113,36 +113,36 @@ class RAS(implicit p: Parameters) extends BasePredictor { val s3_meta = Input(new RASInternalMeta) val s3_missed_pop = Input(Bool()) val s3_missed_push = Input(Bool()) + val s3_is_ret = Input(Bool()) val s3_pushAddr = Input(UInt(VAddrBits.W)) - val s3_has_ret = Input(Bool()) val spec_pop_addr = Output(UInt(VAddrBits.W)) val commit_push_valid = Input(Bool()) val commit_pop_valid = Input(Bool()) - val commit_has_ret = Input(Bool()) + val commit_is_ret = Input(Bool()) val commit_push_addr = Input(UInt(VAddrBits.W)) val commit_meta_TOSW = Input(new RASPtr) // for debug purpose only val commit_meta_ssp = Input(UInt(log2Up(RasSize).W)) - val 
redirect_valid = Input(Bool()) - val redirect_isCall = Input(Bool()) - val redirect_isRet = Input(Bool()) - val redirect_has_ret = Input(Bool()) - val redirect_meta_ssp = Input(UInt(log2Up(RasSize).W)) + val redirect_valid = Input(Bool()) + val redirect_push_valid = Input(Bool()) + val redirect_pop_valid = Input(Bool()) + val redirect_is_ret = Input(Bool()) + val redirect_meta_ssp = Input(UInt(log2Up(RasSize).W)) val redirect_meta_sctr = Input(UInt(RasCtrSize.W)) val redirect_meta_TOSW = Input(new RASPtr) val redirect_meta_TOSR = Input(new RASPtr) val redirect_meta_NOS = Input(new RASPtr) val redirect_callAddr = Input(UInt(VAddrBits.W)) - val ssp = Output(UInt(log2Up(RasSize).W)) - val sctr = Output(UInt(RasCtrSize.W)) - val nsp = Output(UInt(log2Up(RasSize).W)) - val TOSR = Output(new RASPtr) - val TOSW = Output(new RASPtr) - val NOS = Output(new RASPtr) - val BOS = Output(new RASPtr) + val ssp = Output(UInt(log2Up(RasSize).W)) + val sctr = Output(UInt(RasCtrSize.W)) + val nsp = Output(UInt(log2Up(RasSize).W)) + val TOSR = Output(new RASPtr) + val TOSW = Output(new RASPtr) + val NOS = Output(new RASPtr) + val BOS = Output(new RASPtr) val debug = new RASDebug }) @@ -154,11 +154,10 @@ class RAS(implicit p: Parameters) extends BasePredictor { val nsp = RegInit(0.U(log2Up(rasSize).W)) val ssp = RegInit(0.U(log2Up(rasSize).W)) - val sctr = RegInit(0.U(RasCtrSize.W)) - val TOSR = RegInit(RASPtr(true.B, (RasSpecSize - 1).U)) - val TOSW = RegInit(RASPtr(false.B, 0.U)) - val BOS = RegInit(RASPtr(false.B, 0.U)) - val NOS = RegInit(RASPtr(false.B, 0.U)) + val sctr = RegInit(0.U(RasCtrSize.W)) + val TOSR = RegInit(RASPtr(true.B, (RasSpecSize - 1).U)) + val TOSW = RegInit(RASPtr(false.B, 0.U)) + val BOS = RegInit(RASPtr(false.B, 0.U)) val spec_overflowed = RegInit(false.B) @@ -225,12 +224,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { def specPtrInc(ptr: RASPtr) = ptr + 1.U def specPtrDec(ptr: RASPtr) = ptr - 1.U - - - - - - when (io.redirect_valid && io.redirect_isCall) { + when (io.redirect_valid && io.redirect_push_valid) { writeBypassValidWire := true.B writeBypassValid := true.B } .elsewhen (io.redirect_valid) { @@ -256,15 +250,15 @@ class RAS(implicit p: Parameters) extends BasePredictor { val writeEntry = Wire(new RASEntry) val writeNos = Wire(new RASPtr) - writeEntry.retAddr := Mux(io.redirect_valid && io.redirect_isCall, io.redirect_callAddr, io.spec_push_addr) - writeEntry.ctr := Mux(io.redirect_valid && io.redirect_isCall, - Mux(io.redirect_has_ret, 0.U, Mux(redirectTopEntry.retAddr === io.redirect_callAddr && redirectTopEntry.ctr < ctrMax, io.redirect_meta_sctr + 1.U, 0.U)), - Mux(io.spec_has_ret, 0.U, Mux(topEntry.retAddr === io.spec_push_addr && topEntry.ctr < ctrMax, sctr + 1.U, 0.U))) + writeEntry.retAddr := Mux(io.redirect_valid && io.redirect_push_valid, io.redirect_callAddr, io.spec_push_addr) + writeEntry.ctr := Mux(io.redirect_valid && io.redirect_push_valid, + Mux(io.redirect_is_ret, 0.U, Mux(redirectTopEntry.retAddr === io.redirect_callAddr && redirectTopEntry.ctr < ctrMax, io.redirect_meta_sctr + 1.U, 0.U)), + Mux(io.spec_is_ret, 0.U, Mux(topEntry.retAddr === io.spec_push_addr && topEntry.ctr < ctrMax, sctr + 1.U, 0.U))) - writeNos := Mux(io.redirect_valid && io.redirect_isCall, - Mux(io.redirect_has_ret, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_has_ret, topNos,TOSR)) + writeNos := Mux(io.redirect_valid && io.redirect_push_valid, + Mux(io.redirect_is_ret, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_is_ret, topNos,TOSR)) - when 
(io.spec_push_valid || (io.redirect_valid && io.redirect_isCall)) { + when (io.spec_push_valid || (io.redirect_valid && io.redirect_push_valid)) { writeBypassEntry := writeEntry writeBypassNos := writeNos } @@ -275,7 +269,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { val timingNos = RegInit(0.U.asTypeOf(new RASPtr)) when (writeBypassValidWire) { - when ((io.redirect_valid && io.redirect_isCall) || io.spec_push_valid) { + when ((io.redirect_valid && io.redirect_push_valid) || io.spec_push_valid) { timingTop := writeEntry timingNos := writeNos } .otherwise { @@ -283,7 +277,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { timingNos := writeBypassNos } - } .elsewhen (io.redirect_valid && io.redirect_isRet) { + } .elsewhen (io.redirect_valid && io.redirect_pop_valid) { // getTop using redirect Nos as TOSR val popRedSsp = Wire(UInt(log2Up(rasSize).W)) val popRedSctr = Wire(UInt(RasCtrSize.W)) @@ -373,7 +367,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { XSPerfAccumulate("ras_top_mismatch", diffTop =/= timingTop.retAddr); // could diff when more pop than push and a commit stack is updated with inflight info - val realWriteEntry_next = RegEnable(writeEntry, io.s2_fire || io.redirect_isCall) + val realWriteEntry_next = RegEnable(writeEntry, io.s2_fire || io.redirect_push_valid) val s3_missPushEntry = Wire(new RASEntry) val s3_missPushAddr = Wire(new RASPtr) val s3_missPushNos = Wire(new RASPtr) @@ -385,20 +379,20 @@ class RAS(implicit p: Parameters) extends BasePredictor { - realWriteEntry := Mux(io.redirect_isCall, realWriteEntry_next, + realWriteEntry := Mux(io.redirect_push_valid, realWriteEntry_next, Mux(io.s3_missed_push, s3_missPushEntry, realWriteEntry_next)) - val realWriteAddr_next = RegEnable(Mux(io.redirect_valid && io.redirect_isCall, io.redirect_meta_TOSW, TOSW), io.s2_fire || (io.redirect_valid && io.redirect_isCall)) - val realWriteAddr = Mux(io.redirect_isCall, realWriteAddr_next, + val realWriteAddr_next = RegEnable(Mux(io.redirect_valid && io.redirect_push_valid, io.redirect_meta_TOSW, TOSW), io.s2_fire || (io.redirect_valid && io.redirect_push_valid)) + val realWriteAddr = Mux(io.redirect_push_valid, realWriteAddr_next, Mux(io.s3_missed_push, s3_missPushAddr, realWriteAddr_next)) - val realNos_next = RegEnable(Mux(io.redirect_valid && io.redirect_isCall, Mux(io.redirect_has_ret, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_has_ret, topNos, TOSR)), io.s2_fire || (io.redirect_valid && io.redirect_isCall)) - val realNos = Mux(io.redirect_isCall, realNos_next, - Mux(io.s3_missed_push, Mux(io.s3_has_ret, io.s3_meta.NOS, io.s3_meta.TOSR), + val realNos_next = RegEnable(Mux(io.redirect_valid && io.redirect_push_valid, Mux(io.redirect_is_ret, io.redirect_meta_NOS, io.redirect_meta_TOSR), Mux(io.spec_is_ret, topNos, TOSR)), io.s2_fire || (io.redirect_valid && io.redirect_push_valid)) + val realNos = Mux(io.redirect_push_valid, realNos_next, + Mux(io.s3_missed_push, Mux(io.s3_is_ret, io.s3_meta.NOS, io.s3_meta.TOSR), realNos_next)) - realPush := (io.s3_fire && (!io.s3_cancel && RegEnable(io.spec_push_valid, io.s2_fire) || io.s3_missed_push)) || RegNext(io.redirect_valid && io.redirect_isCall) + realPush := (io.s3_fire && (!io.s3_cancel && RegEnable(io.spec_push_valid, io.s2_fire) || io.s3_missed_push)) || RegNext(io.redirect_valid && io.redirect_push_valid) when (realPush) { spec_queue(realWriteAddr.value) := realWriteEntry @@ -423,7 +417,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { } when (io.spec_push_valid) { 
- when(!io.spec_has_ret) { + when(!io.spec_is_ret) { specPush(io.spec_push_addr, ssp, sctr, TOSR, TOSW, topEntry) } .otherwise { TOSR := TOSW @@ -479,7 +473,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { } when (io.s3_missed_push) { // do not use any bypass from f2 - when(!io.s3_has_ret) { + when(!io.s3_is_ret) { specPush(io.s3_pushAddr, io.s3_meta.ssp, io.s3_meta.sctr, io.s3_meta.TOSR, io.s3_meta.TOSW, s3TopEntry) }.otherwise { TOSR := io.s3_meta.TOSW @@ -525,7 +519,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { nsp_update := nsp } // if ctr < max && topAddr == push addr, ++ctr, otherwise ++nsp - when(!io.commit_has_ret){ + when(!io.commit_is_ret){ when (commitTop.ctr < ctrMax && commitTop.retAddr === commit_push_addr) { commit_stack(nsp_update).ctr := commitTop.ctr + 1.U nsp := nsp_update @@ -562,8 +556,8 @@ class RAS(implicit p: Parameters) extends BasePredictor { ssp := io.redirect_meta_ssp sctr := io.redirect_meta_sctr - when (io.redirect_isCall) { - when (!io.redirect_has_ret) { + when (io.redirect_push_valid) { + when (!io.redirect_is_ret) { specPush(io.redirect_callAddr, io.redirect_meta_ssp, io.redirect_meta_sctr, io.redirect_meta_TOSR, io.redirect_meta_TOSW, redirectTopEntry) } .otherwise { TOSR := io.redirect_meta_TOSW @@ -572,7 +566,7 @@ class RAS(implicit p: Parameters) extends BasePredictor { ssp := Mux(io.redirect_meta_sctr > 0.U, ptrInc(io.redirect_meta_ssp), io.redirect_meta_ssp) } } - when (io.redirect_isRet) { + when (io.redirect_pop_valid) { specPop(io.redirect_meta_ssp, io.redirect_meta_sctr, io.redirect_meta_TOSR, io.redirect_meta_TOSW, redirectTopNos) } } @@ -584,30 +578,30 @@ class RAS(implicit p: Parameters) extends BasePredictor { val stack = Module(new RASStack(RasSize, RasSpecSize)).io - val s2_spec_push = WireInit(false.B) - val s2_spec_pop = WireInit(false.B) - val s2_has_ret = WireInit(false.B) - val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred(2) + val s2_spec_push = WireInit(false.B) + val s2_spec_pop = WireInit(false.B) + val s2_is_ret = WireInit(false.B) + val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred(2) // when last inst is an rvi call, fall through address would be set to the middle of it, so an addition is needed val s2_spec_new_addr = s2_full_pred.fallThroughAddr + Mux(s2_full_pred.last_may_be_rvi_call, 2.U, 0.U) stack.spec_push_valid := s2_spec_push stack.spec_pop_valid := s2_spec_pop + stack.spec_is_ret := s2_is_ret stack.spec_push_addr := s2_spec_new_addr - stack.spec_has_ret := s2_has_ret // confirm that the call/ret is the taken cfi s2_spec_push := io.s2_fire(2) && s2_full_pred.hit_taken_on_call && !io.s3_redirect(2) - s2_spec_pop := io.s2_fire(2) && s2_full_pred.hit_taken_on_ret && !io.s3_redirect(2) - s2_has_ret := s2_full_pred.has_ret + s2_spec_pop := io.s2_fire(2) && s2_full_pred.hit_taken_on_only_ret && !io.s3_redirect(2) + s2_is_ret := s2_full_pred.is_ret //val s2_jalr_target = io.out.s2.full_pred.jalr_target //val s2_last_target_in = s2_full_pred.targets.last // val s2_last_target_out = io.out.s2.full_pred(2).targets.last val s2_is_jalr = s2_full_pred.is_jalr - val s2_is_ret = s2_full_pred.is_ret + val s2_only_ret = s2_full_pred.is_ret && !s2_full_pred.is_call val s2_top = stack.spec_pop_addr // assert(is_jalr && is_ret || !is_ret) - when(s2_is_ret && io.ctrl.ras_enable) { + when(s2_only_ret && io.ctrl.ras_enable) { io.out.s2.full_pred.map(_.jalr_target).foreach(_ := s2_top) // FIXME: should use s1 globally } @@ -626,13 +620,10 @@ class RAS(implicit p: Parameters) extends BasePredictor { val 
s3_top = RegEnable(stack.spec_pop_addr, io.s2_fire(2)) val s3_spec_new_addr = RegEnable(s2_spec_new_addr, io.s2_fire(2)) - // val s3_jalr_target = io.out.s3.full_pred.jalr_target - // val s3_last_target_in = io.in.bits.resp_in(0).s3.full_pred(2).targets.last - // val s3_last_target_out = io.out.s3.full_pred(2).targets.last val s3_is_jalr = io.in.bits.resp_in(0).s3.full_pred(2).is_jalr - val s3_is_ret = io.in.bits.resp_in(0).s3.full_pred(2).is_ret + val s3_only_ret = io.in.bits.resp_in(0).s3.full_pred(2).is_ret && !io.in.bits.resp_in(0).s3.full_pred(2).is_call // assert(is_jalr && is_ret || !is_ret) - when(s3_is_ret && io.ctrl.ras_enable) { + when(s3_only_ret && io.ctrl.ras_enable) { io.out.s3.full_pred.map(_.jalr_target).foreach(_ := s3_top) // FIXME: should use s1 globally } @@ -643,9 +634,9 @@ class RAS(implicit p: Parameters) extends BasePredictor { val s3_pushed_in_s2 = RegEnable(s2_spec_push, io.s2_fire(2)) val s3_popped_in_s2 = RegEnable(s2_spec_pop, io.s2_fire(2)) - val s3_push = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_call - val s3_pop = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_ret - val s3_has_ret = io.in.bits.resp_in(0).s3.full_pred(2).has_ret + val s3_push = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_call + val s3_pop = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_only_ret + val s3_is_ret = io.in.bits.resp_in(0).s3.full_pred(2).is_ret val s3_cancel = io.s3_fire(2) && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop) stack.s2_fire := io.s2_fire(2) @@ -655,11 +646,11 @@ class RAS(implicit p: Parameters) extends BasePredictor { val s3_meta = RegEnable(s2_meta, io.s2_fire(2)) - stack.s3_meta := s3_meta + stack.s3_meta := s3_meta stack.s3_missed_pop := s3_pop && !s3_popped_in_s2 stack.s3_missed_push := s3_push && !s3_pushed_in_s2 + stack.s3_is_ret := s3_is_ret stack.s3_pushAddr := s3_spec_new_addr - stack.s3_has_ret := s3_has_ret // no longer need the top Entry, but TOSR, TOSW, ssp sctr // TODO: remove related signals @@ -677,18 +668,18 @@ class RAS(implicit p: Parameters) extends BasePredictor { io.out.last_stage_meta := last_stage_meta.asUInt - val redirect = RegNextWithEnable(io.redirect) - val do_recover = redirect.valid + val redirect = RegNextWithEnable(io.redirect) + val do_recover = redirect.valid val recover_cfi = redirect.bits.cfiUpdate - val retMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isRet + val retMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.onlyRet val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall // when we mispredict a call, we must redo a push operation // similarly, when we mispredict a return, we should redo a pop - stack.redirect_valid := do_recover - stack.redirect_isCall := callMissPred - stack.redirect_isRet := retMissPred - stack.redirect_has_ret := recover_cfi.pd.isRetCall + stack.redirect_valid := do_recover + stack.redirect_push_valid := callMissPred + stack.redirect_pop_valid := recover_cfi.pd.isRet + stack.redirect_is_ret := retMissPred stack.redirect_meta_ssp := recover_cfi.ssp stack.redirect_meta_sctr := recover_cfi.sctr stack.redirect_meta_TOSW := recover_cfi.TOSW @@ -697,12 +688,12 @@ class RAS(implicit p: Parameters) extends BasePredictor { stack.redirect_callAddr := recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U) val update = io.update.bits - val updateMeta = io.update.bits.meta.asTypeOf(new RASMeta) + val updateMeta = io.update.bits.meta.asTypeOf(new RASMeta) val updateValid = io.update.valid 
   stack.commit_push_valid := updateValid && update.is_call_taken
-  stack.commit_pop_valid := updateValid && update.is_ret_taken
-  stack.commit_has_ret := update.has_ret
+  stack.commit_pop_valid := updateValid && update.is_only_ret_taken
+  stack.commit_is_ret := updateValid && update.is_ret
   stack.commit_push_addr := update.ftb_entry.getFallThrough(update.pc) + Mux(update.ftb_entry.last_may_be_rvi_call, 2.U, 0.U)
   stack.commit_meta_TOSW := updateMeta.TOSW
   stack.commit_meta_ssp := updateMeta.ssp
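
The classification this series teaches the predecoder, FTB and RAS follows the RISC-V return-address-stack hint convention: a JALR whose rd is a link register (x1/x5) pushes a return address, one whose rs1 is a link register pops, and one where both rd and rs1 are link registers with rd != rs1 pops and then pushes (the "ret-call" case). The plain-Scala sketch below is a software model of that decision, not the Chisel changed by this patch; RasHintModel and classifyJalr are illustrative names that do not exist in the XiangShan sources.

// Software model of the JALR return-address-stack hints (RISC-V unprivileged
// spec): rd/rs1 equal to x1 (ra) or x5 (t0) are "link" registers.
//   rd=link,  rs1!=link          -> push
//   rd!=link, rs1=link           -> pop
//   rd=link,  rs1=link, rd==rs1  -> push
//   rd=link,  rs1=link, rd!=rs1  -> pop, then push ("ret-call")
object RasHintModel {
  sealed trait RasAction
  case object NoHint      extends RasAction
  case object Push        extends RasAction
  case object Pop         extends RasAction
  case object PopThenPush extends RasAction

  private def isLink(r: Int): Boolean = r == 1 || r == 5 // x1/ra, x5/t0

  // rd and rs1 are the register fields of a JALR instruction.
  def classifyJalr(rd: Int, rs1: Int): RasAction = {
    val call = isLink(rd)                          // mirrors the JALR term of isCall in PreDecode
    val ret  = isLink(rs1) && (!call || rd != rs1) // mirrors the widened isRet
    (call, ret) match {
      case (true, true)   => PopThenPush // ret-call: isCall && isRet
      case (true, false)  => Push        // plain call
      case (false, true)  => Pop         // plain return ("onlyRet")
      case (false, false) => NoHint
    }
  }

  def main(args: Array[String]): Unit = {
    assert(classifyJalr(rd = 1, rs1 = 6) == Push)        // jalr ra, t1, 0
    assert(classifyJalr(rd = 0, rs1 = 1) == Pop)         // ret (jalr x0, ra, 0)
    assert(classifyJalr(rd = 1, rs1 = 5) == PopThenPush) // jalr ra, t0, 0
    assert(classifyJalr(rd = 1, rs1 = 1) == Push)        // jalr ra, ra, 0
    println("RAS hint classification behaves as expected")
  }
}

In hardware terms, isCall here corresponds to Push or PopThenPush and isRet to Pop or PopThenPush, so onlyRet (isRet && !isCall) is a plain Pop; that is why, in this series, the RAS pop path keys off onlyRet while the push path fires on isCall, with spec_is_ret/redirect_is_ret/commit_is_ret telling the stack that a ret-call push must reuse the popped entry's metadata instead of the current top.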