Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PTW Hypervisor bug fixes: check GPA bits higher than HGATP.Mode #3591

Merged
merged 4 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions src/main/scala/rocket/PTW.scala
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
val aux_count = Reg(UInt(log2Ceil(pgLevels).W))
/** pte for 2-stage translation */
val aux_pte = Reg(new PTE)
val aux_ppn_hi = (pgLevels > 4 && r_req.addr.getWidth > aux_pte.ppn.getWidth).option(Reg(UInt((r_req.addr.getWidth - aux_pte.ppn.getWidth).W)))
val gpa_pgoff = Reg(UInt(pgIdxBits.W)) // only valid in resp_gf case
val stage2 = Reg(Bool())
val stage2_final = Reg(Bool())
Expand All @@ -301,7 +300,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
}
}
// construct pte from mem.resp
val (pte, invalid_paddr) = {
val (pte, invalid_paddr, invalid_gpa) = {
val tmp = mem_resp_data.asTypeOf(new PTE())
val res = WireDefault(tmp)
res.ppn := Mux(do_both_stages && !stage2, tmp.ppn(vpnBits.min(tmp.ppn.getWidth)-1, 0), tmp.ppn(ppnBits-1, 0))
Expand All @@ -310,10 +309,12 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
for (i <- 0 until pgLevels-1)
when (count <= i.U && tmp.ppn((pgLevels-1-i)*pgLevelBits-1, (pgLevels-2-i)*pgLevelBits) =/= 0.U) { res.v := false.B }
}
(res, Mux(do_both_stages && !stage2, (tmp.ppn >> vpnBits) =/= 0.U, (tmp.ppn >> ppnBits) =/= 0.U))
(res,
Mux(do_both_stages && !stage2, (tmp.ppn >> vpnBits) =/= 0.U, (tmp.ppn >> ppnBits) =/= 0.U),
do_both_stages && !stage2 && checkInvalidHypervisorGPA(r_hgatp, tmp.ppn))
}
// find non-leaf PTE, need traverse
val traverse = pte.table() && !invalid_paddr && count < (pgLevels-1).U
val traverse = pte.table() && !invalid_paddr && !invalid_gpa && count < (pgLevels-1).U
/** address send to mem for enquerry */
val pte_addr = if (!usingVM) 0.U else {
val vpn_idxs = (0 until pgLevels).map { i =>
Expand All @@ -328,19 +329,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
//use vpn slice as offset
raw_pte_addr.apply(size.min(raw_pte_addr.getWidth) - 1, 0)
}
/** pte_cache input addr */
val pte_cache_addr = if (!usingHypervisor) pte_addr else {
val vpn_idxs = (0 until pgLevels-1).map { i =>
val ext_aux_pte_ppn = aux_ppn_hi match {
case None => aux_pte.ppn
case Some(hi) => Cat(hi, aux_pte.ppn)
}
(ext_aux_pte_ppn >> (pgLevels - i - 1) * pgLevelBits)(pgLevelBits - 1, 0)
}
val vpn_idx = vpn_idxs(count)
val raw_pte_cache_addr = Cat(r_pte.ppn, vpn_idx) << log2Ceil(xLen/8)
raw_pte_cache_addr(vaddrBits.min(raw_pte_cache_addr.getWidth)-1, 0)
}
/** stage2_pte_cache input addr */
val stage2_pte_cache_addr = if (!usingHypervisor) 0.U else {
val vpn_idxs = (0 until pgLevels - 1).map { i =>
Expand Down Expand Up @@ -373,7 +361,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
else can_hit
val tag =
if (s2) Cat(true.B, stage2_pte_cache_addr.padTo(vaddrBits))
else Cat(r_req.vstage1, pte_cache_addr.padTo(if (usingHypervisor) vaddrBits else paddrBits))
else Cat(r_req.vstage1, pte_addr.padTo(if (usingHypervisor) vaddrBits else paddrBits))

val hits = tags.map(_ === tag).asUInt & valid
val hit = hits.orR && can_hit
Expand Down Expand Up @@ -539,7 +527,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
io.mem.req.bits.data := DontCare
io.mem.req.bits.mask := DontCare

io.mem.s1_kill := l2_hit || state =/= s_wait1
io.mem.s1_kill := l2_hit || (state =/= s_wait1) || resp_gf
io.mem.s1_data := DontCare
io.mem.s2_kill := false.B

Expand Down Expand Up @@ -607,12 +595,11 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
count := Mux(arb.io.out.bits.bits.stage2, hgatp_initial_count, satp_initial_count)
aux_count := Mux(arb.io.out.bits.bits.vstage1, vsatp_initial_count, 0.U)
aux_pte.ppn := aux_ppn
aux_ppn_hi.foreach { _ := aux_ppn >> aux_pte.ppn.getWidth }
aux_pte.reserved_for_future := 0.U
resp_ae_ptw := false.B
resp_ae_final := false.B
resp_pf := false.B
resp_gf := false.B
resp_gf := checkInvalidHypervisorGPA(io.dpath.hgatp, aux_ppn) && arb.io.out.bits.bits.stage2
resp_hr := true.B
resp_hw := true.B
resp_hx := true.B
Expand All @@ -630,7 +617,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
when (stage2_pte_cache_hit) {
aux_count := aux_count + 1.U
aux_pte.ppn := stage2_pte_cache_data
aux_ppn_hi.foreach { _ := 0.U }
aux_pte.reserved_for_future := 0.U
pte_hit := true.B
}.elsewhen (pte_cache_hit) {
Expand All @@ -639,6 +625,10 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
}.otherwise {
next_state := Mux(io.mem.req.ready, s_wait1, s_req)
}
when(resp_gf) {
next_state := s_ready
resp_valid(r_req_dest) := true.B
}
}
is (s_wait1) {
// This Mux is for the l2_error case; the l2_hit && !l2_error case is overriden below
Expand Down Expand Up @@ -676,7 +666,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(

r_pte := OptimizationBarrier(
// l2tlb hit->find a leaf PTE(l2_pte), respond to L1TLB
Mux(l2_hit && !l2_error, l2_pte,
Mux(l2_hit && !l2_error && !resp_gf, l2_pte,
// S2 PTE cache hit -> proceed to the next level of walking, update the r_pte with hgatp
Mux(state === s_req && stage2_pte_cache_hit, makeHypervisorRootPTE(r_hgatp, stage2_pte_cache_data, l2_pte),
// pte cache hit->find a non-leaf PTE(pte_cache),continue to request mem
Expand All @@ -691,7 +681,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
Mux(arb.io.out.fire, Mux(arb.io.out.bits.bits.stage2, makeHypervisorRootPTE(io.dpath.hgatp, io.dpath.vsatp.ppn, r_pte), makePTE(satp.ppn, r_pte)),
r_pte))))))))

when (l2_hit && !l2_error) {
when (l2_hit && !l2_error && !resp_gf) {
assert(state === s_req || state === s_wait1)
next_state := s_ready
resp_valid(r_req_dest) := true.B
Expand Down Expand Up @@ -753,14 +743,14 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
val s1_ppns = (0 until pgLevels-1).map(i => Cat(pte.ppn(pte.ppn.getWidth-1, (pgLevels-i-1)*pgLevelBits), r_req.addr(((pgLevels-i-1)*pgLevelBits min vpnBits)-1,0).padTo((pgLevels-i-1)*pgLevelBits))) :+ pte.ppn
makePTE(s1_ppns(count), pte)
})
aux_ppn_hi.foreach { _ := 0.U }
stage2 := true.B
}

for (i <- 0 until pgLevels) {
val leaf = mem_resp_valid && !traverse && count === i.U
ccover(leaf && pte.v && !invalid_paddr && pte.reserved_for_future === 0.U, s"L$i", s"successful page-table access, level $i")
ccover(leaf && pte.v && !invalid_paddr && !invalid_gpa && pte.reserved_for_future === 0.U, s"L$i", s"successful page-table access, level $i")
ccover(leaf && pte.v && invalid_paddr, s"L${i}_BAD_PPN_MSB", s"PPN too large, level $i")
ccover(leaf && pte.v && invalid_gpa, s"L${i}_BAD_GPA_MSB", s"GPA too large, level $i")
ccover(leaf && pte.v && pte.reserved_for_future =/= 0.U, s"L${i}_BAD_RSV_MSB", s"reserved MSBs set, level $i")
ccover(leaf && !mem_resp_data(0), s"L${i}_INVALID_PTE", s"page not present, level $i")
if (i != pgLevels-1)
Expand Down Expand Up @@ -790,6 +780,12 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
pte.ppn := Cat(hgatp.ppn >> maxHypervisorExtraAddrBits, lsbs)
pte
}
/** use hgatp and vpn to check for gpa out of range */
private def checkInvalidHypervisorGPA(hgatp: PTBR, vpn: UInt) = {
val count = pgLevels.U - minPgLevels.U - hgatp.additionalPgLevels
val idxs = (0 to pgLevels-minPgLevels).map(i => (vpn >> ((pgLevels-i)*pgLevelBits)+maxHypervisorExtraAddrBits))
idxs.extract(count) =/= 0.U
}
}

/** Mix-ins for constructing tiles that might have a PTW */
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/rocket/TLB.scala
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,9 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T
val pf_ld_array = Mux(cmd_read, ((~Mux(cmd_readx, x_array, r_array) & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array, 0.U)
val pf_st_array = Mux(cmd_write_perms, ((~w_array & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array, 0.U)
val pf_inst_array = ((~x_array & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array
val gf_ld_array = Mux(priv_v && cmd_read, ~Mux(cmd_readx, hx_array, hr_array) & ~ptw_ae_array, 0.U)
val gf_st_array = Mux(priv_v && cmd_write_perms, ~hw_array & ~ptw_ae_array, 0.U)
val gf_inst_array = Mux(priv_v, ~hx_array & ~ptw_ae_array, 0.U)
val gf_ld_array = Mux(priv_v && cmd_read, (~Mux(cmd_readx, hx_array, hr_array) | ptw_gf_array) & ~ptw_ae_array, 0.U)
val gf_st_array = Mux(priv_v && cmd_write_perms, (~hw_array | ptw_gf_array) & ~ptw_ae_array, 0.U)
val gf_inst_array = Mux(priv_v, (~hx_array | ptw_gf_array) & ~ptw_ae_array, 0.U)

val gpa_hits = {
val need_gpa_mask = if (instruction) gf_inst_array else gf_ld_array | gf_st_array
Expand Down
8 changes: 8 additions & 0 deletions src/main/scala/util/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ package object util {
def isOneOf(u1: UInt, u2: UInt*): Bool = isOneOf(u1 +: u2.toSeq)
}

implicit class VecToAugmentedVec[T <: Data](private val x: Vec[T]) extends AnyVal {

/** Like Vec.apply(idx), but tolerates indices of mismatched width */
def extract(idx: UInt): T = x((idx | 0.U(log2Ceil(x.size).W)).extract(log2Ceil(x.size) - 1, 0))
}

implicit class SeqToAugmentedSeq[T <: Data](private val x: Seq[T]) extends AnyVal {
def apply(idx: UInt): T = {
if (x.size <= 1) {
Expand All @@ -34,6 +40,8 @@ package object util {
}
}

def extract(idx: UInt): T = VecInit(x).extract(idx)

def asUInt: UInt = Cat(x.map(_.asUInt).reverse)

def rotate(n: Int): Seq[T] = x.drop(n) ++ x.take(n)
Expand Down
Loading