Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RAS: handle a case where a jalr instruction requires a pop followed by a push #3277

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/main/scala/xiangshan/frontend/FTB.scala
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,13 @@ class FtbSlot(val offsetLen: Int, val subOffsetLen: Option[Int] = None)(implicit


class FTBEntry_part(implicit p: Parameters) extends XSBundle with FTBParams with BPUUtils {
val isCall = Bool()
val isRet = Bool()
val isJalr = Bool()
val isRetCall = Bool()
val isCall = Bool()
val hasRet = Bool() // maybe ret or ret-call
val isJalr = Bool()

def isJal = !isJalr
def isRet = !isCall && hasRet
def isRetCall = isCall && hasRet
}

class FTBEntry_FtqMem(implicit p: Parameters) extends FTBEntry_part with FTBParams with BPUUtils {
Expand Down Expand Up @@ -372,7 +373,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with
val pftAddrDiff = this.pftAddr === that.pftAddr
val carryDiff = this.carry === that.carry
val isCallDiff = this.isCall === that.isCall
val isRetDiff = this.isRet === that.isRet
val hasRetDiff = this.hasRet === that.hasRet
val isJalrDiff = this.isJalr === that.isJalr
val lastMayBeRviCallDiff = this.last_may_be_rvi_call === that.last_may_be_rvi_call
val alwaysTakenDiff : IndexedSeq[Bool] =
Expand All @@ -386,7 +387,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with
pftAddrDiff,
carryDiff,
isCallDiff,
isRetDiff,
hasRetDiff,
isJalrDiff,
lastMayBeRviCallDiff,
alwaysTakenDiff.reduce(_&&_)
Expand All @@ -403,7 +404,7 @@ class FTBEntry(implicit p: Parameters) extends FTBEntry_part with FTBParams with
XSDebug(cond, p"[tailSlot]: v=${tailSlot.valid}, offset=${tailSlot.offset}," +
p"lower=${Hexadecimal(tailSlot.lower)}, sharing=${tailSlot.sharing}}\n")
XSDebug(cond, p"pftAddr=${Hexadecimal(pftAddr)}, carry=$carry\n")
XSDebug(cond, p"isCall=$isCall, isRet=$isRet, isjalr=$isJalr\n")
XSDebug(cond, p"isCall=$isCall, hasRet=$hasRet, isjalr=$isJalr\n")
XSDebug(cond, p"last_may_be_rvi_call=$last_may_be_rvi_call\n")
XSDebug(cond, p"------------------------------- \n")
}
Expand Down
12 changes: 6 additions & 6 deletions src/main/scala/xiangshan/frontend/FrontendBundle.scala
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC
val is_jalr = Bool()
val is_call = Bool()
val is_ret = Bool()
val is_ret_call = Bool()
val has_ret = Bool() // only used for the ret-call behavior in RAS
val last_may_be_rvi_call = Bool()
val is_br_sharing = Bool()

Expand Down Expand Up @@ -537,8 +537,8 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC
is_jal := entry.tailSlot.valid && entry.isJal
is_jalr := entry.tailSlot.valid && entry.isJalr
is_call := entry.tailSlot.valid && entry.isCall
is_ret := entry.tailSlot.valid && entry.isRet
is_ret_call := entry.isRetCall // The is_ret_call signal depends on the is_call signal and will not be pulled high alone
is_ret := entry.tailSlot.valid && entry.hasRet && !entry.isCall
has_ret := entry.hasRet
last_may_be_rvi_call := entry.last_may_be_rvi_call
is_br_sharing := entry.tailSlot.valid && entry.tailSlot.sharing
predCycle.map(_ := GTimer())
Expand Down Expand Up @@ -649,11 +649,11 @@ class BranchPredictionUpdate(implicit p: Parameters) extends XSBundle with HasBP
val from_stage = UInt(2.W)
val ghist = UInt(HistoryLength.W)

def is_jal = ftb_entry.tailSlot.valid && ftb_entry.isJal
def is_jal = ftb_entry.tailSlot.valid && ftb_entry.isJal
def is_jalr = ftb_entry.tailSlot.valid && ftb_entry.isJalr
def is_call = ftb_entry.tailSlot.valid && ftb_entry.isCall
def is_ret = ftb_entry.tailSlot.valid && ftb_entry.isRet
def is_ret_call = ftb_entry.isRetCall // The is_ret_call signal depends on the is_call signal and will not be pulled high alone
def is_ret = ftb_entry.tailSlot.valid && ftb_entry.hasRet && !ftb_entry.isCall
def has_ret = ftb_entry.hasRet

def is_call_taken = is_call && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset
def is_ret_taken = is_ret && jmp_taken && cfi_idx.valid && cfi_idx.bits === ftb_entry.tailSlot.offset
Expand Down
25 changes: 11 additions & 14 deletions src/main/scala/xiangshan/frontend/IFU.scala
Original file line number Diff line number Diff line change
Expand Up @@ -547,16 +547,15 @@ class NewIFU(implicit p: Parameters) extends XSModule
// Expand 1 bit to prevent overflow when assert
val f3_ftq_req_startAddr = Cat(0.U(1.W), f3_ftq_req.startAddr)
val f3_ftq_req_nextStartAddr = Cat(0.U(1.W), f3_ftq_req.nextStartAddr)
// brType, isCall and isRet generation is delayed to f3 stage
// brType, isCall and hasRet generation is delayed to f3 stage
val f3Predecoder = Module(new F3Predecoder)

f3Predecoder.io.in.instr := f3_instr

f3_pd.zipWithIndex.map{ case (pd,i) =>
pd.brType := f3Predecoder.io.out.pd(i).brType
pd.isCall := f3Predecoder.io.out.pd(i).isCall
pd.isRet := f3Predecoder.io.out.pd(i).isRet
pd.isRetCall := f3Predecoder.io.out.pd(i).isRetCall
pd.brType := f3Predecoder.io.out.pd(i).brType
pd.isCall := f3Predecoder.io.out.pd(i).isCall
pd.hasRet := f3Predecoder.io.out.pd(i).hasRet
}

val f3PdDiff = f3_pd_wire.zip(f3_pd).map{ case (a,b) => a.asUInt =/= b.asUInt }.reduce(_||_)
Expand Down Expand Up @@ -864,7 +863,7 @@ class NewIFU(implicit p: Parameters) extends XSModule
val inst = Cat(f3_mmio_data(1), f3_mmio_data(0))
val currentIsRVC = isRVC(inst)

val brType::isCall::isRet::isRetCall::Nil = brInfo(inst)
val brType::isCall::hasRet::Nil = brInfo(inst)
val jalOffset = jal_offset(inst, currentIsRVC)
val brOffset = br_offset(inst, currentIsRVC)

Expand All @@ -875,8 +874,7 @@ class NewIFU(implicit p: Parameters) extends XSModule
io.toIbuffer.bits.pd(0).isRVC := currentIsRVC
io.toIbuffer.bits.pd(0).brType := brType
io.toIbuffer.bits.pd(0).isCall := isCall
io.toIbuffer.bits.pd(0).isRet := isRet
io.toIbuffer.bits.pd(0).isRetCall := isRetCall
io.toIbuffer.bits.pd(0).hasRet := hasRet

when (mmio_resend_af) {
io.toIbuffer.bits.exceptionType(0) := ExceptionType.acf
Expand All @@ -889,12 +887,11 @@ class NewIFU(implicit p: Parameters) extends XSModule

io.toIbuffer.bits.enqEnable := f3_mmio_range.asUInt

mmioFlushWb.bits.pd(0).valid := true.B
mmioFlushWb.bits.pd(0).isRVC := currentIsRVC
mmioFlushWb.bits.pd(0).brType := brType
mmioFlushWb.bits.pd(0).isCall := isCall
mmioFlushWb.bits.pd(0).isRet := isRet
mmioFlushWb.bits.pd(0).isRetCall := isRetCall
mmioFlushWb.bits.pd(0).valid := true.B
mmioFlushWb.bits.pd(0).isRVC := currentIsRVC
mmioFlushWb.bits.pd(0).brType := brType
mmioFlushWb.bits.pd(0).isCall := isCall
mmioFlushWb.bits.pd(0).hasRet := hasRet
}

mmio_redirect := (f3_req_is_mmio && mmio_state === m_waitCommit && RegNext(fromUncache.fire) && f3_mmio_use_seq_pc)
Expand Down
54 changes: 24 additions & 30 deletions src/main/scala/xiangshan/frontend/NewFtq.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class FtqDebugBundle extends Bundle {
val isJmp = Bool()
val isCall = Bool()
val isRet = Bool()
val isRetCall = Bool()
val misPred = Bool()
val isTaken = Bool()
val predStage = UInt(2.W)
Expand Down Expand Up @@ -110,23 +109,23 @@ class Ftq_RF_Components(implicit p: Parameters) extends XSBundle with BPUUtils {

class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle {
val brMask = Vec(PredictWidth, Bool())
// jmpInfo includes isJalr, isCall, isRet, isRetCall signals
val jmpInfo = ValidUndirectioned(Vec(4, Bool()))
// jmpInfo(0) = jalr, jmpInfo(1) = isCall, jmpInfo(2) = hasRet
val jmpInfo = ValidUndirectioned(Vec(3, Bool()))
val jmpOffset = UInt(log2Ceil(PredictWidth).W)
val jalTarget = UInt(VAddrBits.W)
val rvcMask = Vec(PredictWidth, Bool())
def hasJal = jmpInfo.valid && !jmpInfo.bits(0)
def hasJalr = jmpInfo.valid && jmpInfo.bits(0)
def hasCall = jmpInfo.valid && jmpInfo.bits(1)
def hasRet = jmpInfo.valid && jmpInfo.bits(2)
def hasRetCall = jmpInfo.valid && jmpInfo.bits(3)
def isCall = jmpInfo.valid && jmpInfo.bits(1)
ngc7331 marked this conversation as resolved.
Show resolved Hide resolved
def isRet = jmpInfo.valid && jmpInfo.bits(2) && !jmpInfo.bits(1)
def isRetCall = jmpInfo.valid && jmpInfo.bits(2) && jmpInfo.bits(1)

def fromPdWb(pdWb: PredecodeWritebackBundle) = {
val pds = pdWb.pd
this.brMask := VecInit(pds.map(pd => pd.isBr && pd.valid))
this.jmpInfo.valid := VecInit(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid)).asUInt.orR
this.jmpInfo.bits := ParallelPriorityMux(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid),
pds.map(pd => VecInit(pd.isJalr, pd.isCall, pd.isRet, pd.isRetCall)))
pds.map(pd => VecInit(pd.isJalr, pd.isCall, pd.hasRet)))
this.jmpOffset := ParallelPriorityEncoder(pds.map(pd => (pd.isJal || pd.isJalr) && pd.valid))
this.rvcMask := VecInit(pds.map(pd => pd.isRVC))
this.jalTarget := pdWb.jalTarget
Expand All @@ -139,10 +138,9 @@ class Ftq_pd_Entry(implicit p: Parameters) extends XSBundle {
pd.isRVC := rvcMask(offset)
val isBr = brMask(offset)
val isJalr = offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(0)
pd.brType := Cat(offset === jmpOffset && jmpInfo.valid, isJalr || isBr)
pd.isCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(1)
pd.isRet := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(2)
pd.isRetCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(3)
pd.brType := Cat(offset === jmpOffset && jmpInfo.valid, isJalr || isBr)
pd.isCall := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(1)
pd.hasRet := offset === jmpOffset && jmpInfo.valid && jmpInfo.bits(2)
pd
}
}
Expand Down Expand Up @@ -276,8 +274,7 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire
val new_jmp_is_jal = entry_has_jmp && !pd.jmpInfo.bits(0) && io.cfiIndex.valid
val new_jmp_is_jalr = entry_has_jmp && pd.jmpInfo.bits(0) && io.cfiIndex.valid
val new_jmp_is_call = entry_has_jmp && pd.jmpInfo.bits(1) && io.cfiIndex.valid
val new_jmp_is_ret = entry_has_jmp && pd.jmpInfo.bits(2) && io.cfiIndex.valid
val new_jmp_is_ret_call = entry_has_jmp && pd.jmpInfo.bits(3) && io.cfiIndex.valid
val new_jmp_has_ret = entry_has_jmp && pd.jmpInfo.bits(2) && io.cfiIndex.valid
val last_jmp_rvi = entry_has_jmp && pd.jmpOffset === (PredictWidth-1).U && !pd.rvcMask.last
// val last_br_rvi = cfi_is_br && io.cfiIndex.bits === (PredictWidth-1).U && !pd.rvcMask.last

Expand Down Expand Up @@ -307,12 +304,11 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire
}

val jmpPft = getLower(io.start_addr) +& pd.jmpOffset +& Mux(pd.rvcMask(pd.jmpOffset), 1.U, 2.U)
init_entry.pftAddr := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft, getLower(io.start_addr))
init_entry.carry := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft(carryPos-instOffsetBits), true.B)
init_entry.pftAddr := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft, getLower(io.start_addr))
init_entry.carry := Mux(entry_has_jmp && !last_jmp_rvi, jmpPft(carryPos-instOffsetBits), true.B)
init_entry.isJalr := new_jmp_is_jalr
init_entry.isCall := new_jmp_is_call
init_entry.isRet := new_jmp_is_ret
init_entry.isRetCall := new_jmp_is_ret_call
init_entry.hasRet := new_jmp_has_ret
// that means fall thru points to the middle of an inst
init_entry.last_may_be_rvi_call := pd.jmpOffset === (PredictWidth-1).U && !pd.rvcMask(pd.jmpOffset)

Expand Down Expand Up @@ -374,8 +370,7 @@ class FTBEntryGen(implicit p: Parameters) extends XSModule with HasBackendRedire
old_entry_modified.carry := (getLower(io.start_addr) +& new_pft_offset).head(1).asBool
old_entry_modified.last_may_be_rvi_call := false.B
old_entry_modified.isCall := false.B
old_entry_modified.isRet := false.B
old_entry_modified.isRetCall := false.B
old_entry_modified.hasRet := false.B
old_entry_modified.isJalr := false.B
}

Expand Down Expand Up @@ -1449,15 +1444,14 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe
p"brInEntry(${inFtbEntry}) brIdx(${brIdx}) target(${Hexadecimal(target)})\n")

val logbundle = Wire(new FtqDebugBundle)
logbundle.pc := pc
logbundle.target := target
logbundle.isBr := isBr
logbundle.isJmp := isJmp
logbundle.isCall := isJmp && commit_pd.hasCall
logbundle.isRet := isJmp && commit_pd.hasRet
logbundle.isRetCall := isJmp && commit_pd.hasRetCall
logbundle.misPred := misPred
logbundle.isTaken := isTaken
logbundle.pc := pc
logbundle.target := target
logbundle.isBr := isBr
logbundle.isJmp := isJmp
logbundle.isCall := isJmp && commit_pd.isCall
logbundle.isRet := isJmp && commit_pd.isRet
logbundle.misPred := misPred
logbundle.isTaken := isTaken
logbundle.predStage := commit_stage

ftqBranchTraceDB.log(
Expand Down Expand Up @@ -1502,8 +1496,8 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe

val commit_jal_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasJal.asTypeOf(UInt(1.W)))
val commit_jalr_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasJalr.asTypeOf(UInt(1.W)))
val commit_call_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasCall.asTypeOf(UInt(1.W)))
val commit_ret_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.hasRet.asTypeOf(UInt(1.W)))
val commit_call_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.isCall.asTypeOf(UInt(1.W)))
val commit_ret_mask = UIntToOH(commit_pd.jmpOffset) & Fill(PredictWidth, commit_pd.isRet.asTypeOf(UInt(1.W)))


val mbpBRights = mbpRights & commit_br_mask
Expand Down
52 changes: 27 additions & 25 deletions src/main/scala/xiangshan/frontend/PreDecode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ trait HasPdConst extends HasXSParameter with HasICacheParameters with HasIFUCons
val rs = Mux(isRVC(instr), Mux(brType === BrType.jal, 0.U, instr(11, 7)), instr(19, 15))
val isCall = (brType === BrType.jal && !isRVC(instr) || brType === BrType.jalr) && isLink(rd) // Only for RV64
val isRet = brType === BrType.jalr && isLink(rs) && !isCall
val isRetCall = (brType === BrType.jalr && isLink(rs) && isLink(rd) && (rs =/= rd))
List(brType, isCall, isRet, isRetCall)
val isRetCall = brType === BrType.jalr && isLink(rs) && isLink(rd) && (rs =/= rd)
val hasRet = isRet || isRetCall
List(brType, isCall, hasRet)
ngc7331 marked this conversation as resolved.
Show resolved Hide resolved
}
def jal_offset(inst: UInt, rvc: Bool): UInt = {
val rvc_offset = Cat(inst(12), inst(8), inst(10, 9), inst(6), inst(7), inst(2), inst(11), inst(5, 3), 0.U(1.W))
Expand Down Expand Up @@ -70,17 +71,20 @@ object ExcType { //TODO:add exctype
}

class PreDecodeInfo extends Bundle { // 8 bit
val valid = Bool()
val isRVC = Bool()
val brType = UInt(2.W)
val isCall = Bool()
val isRet = Bool()
val isRetCall = Bool()
val valid = Bool()
val isRVC = Bool()
val brType = UInt(2.W)
val isCall = Bool()
val hasRet = Bool() // maybe ret or ret-call
// val isCall = Bool()
// val isRet = Bool()
//val excType = UInt(3.W)
def isBr = brType === BrType.branch
def isJal = brType === BrType.jal
def isJalr = brType === BrType.jalr
def notCFI = brType === BrType.notCFI
def isRet = hasRet && !isCall
def isRetCall = isCall && hasRet
}

class PreDecodeResp(implicit p: Parameters) extends XSBundle with HasPdConst {
Expand Down Expand Up @@ -135,27 +139,26 @@ class PreDecode(implicit p: Parameters) extends XSModule with HasPdConst{
val currentPC = io.in.bits.pc(i)
//expander.io.in := inst

val brType::isCall::isRet::isRetCall::Nil = brInfo(inst)
val brType::isCall::hasRet::Nil = brInfo(inst)
val jalOffset = jal_offset(inst, currentIsRVC(i))
val brOffset = br_offset(inst, currentIsRVC(i))

io.out.hasHalfValid(i) := h_validStart(i)
io.out.hasHalfValid(i) := h_validStart(i)

io.out.triggered(i) := DontCare//VecInit(Seq.fill(10)(false.B))
io.out.triggered(i) := DontCare//VecInit(Seq.fill(10)(false.B))


io.out.pd(i).valid := validStart(i)
io.out.pd(i).isRVC := currentIsRVC(i)
io.out.pd(i).valid := validStart(i)
io.out.pd(i).isRVC := currentIsRVC(i)

// for diff purpose only
io.out.pd(i).brType := brType
io.out.pd(i).isCall := isCall
io.out.pd(i).isRet := isRet
io.out.pd(i).isRetCall := isRetCall
io.out.pd(i).brType := brType
io.out.pd(i).isCall := isCall
io.out.pd(i).hasRet := hasRet

//io.out.expInstr(i) := expander.io.out.bits
io.out.instr(i) :=inst
io.out.jumpOffset(i) := Mux(io.out.pd(i).isBr, brOffset, jalOffset)
io.out.instr(i) :=inst
io.out.jumpOffset(i) := Mux(io.out.pd(i).isBr, brOffset, jalOffset)
}

// the first half is always reliable
Expand Down Expand Up @@ -262,12 +265,11 @@ class F3Predecoder(implicit p: Parameters) extends XSModule with HasPdConst {
val out = Output(new F3PreDecodeResp)
})
io.out.pd.zipWithIndex.map{ case (pd,i) =>
pd.valid := DontCare
pd.isRVC := DontCare
pd.brType := brInfo(io.in.instr(i))(0)
pd.isCall := brInfo(io.in.instr(i))(1)
pd.isRet := brInfo(io.in.instr(i))(2)
pd.isRetCall := brInfo(io.in.instr(i))(3)
pd.valid := DontCare
pd.isRVC := DontCare
pd.brType := brInfo(io.in.instr(i))(0)
pd.isCall := brInfo(io.in.instr(i))(1)
pd.hasRet := brInfo(io.in.instr(i))(2)
}

}
Expand Down
Loading
Loading