Skip to content

Commit

Permalink
feat: Implement the new way to train the branch predictor.
Browse files Browse the repository at this point in the history
Now the rediction and the training is decoupled:
- Redirection can happen speculatively. Redirection requires to restore the branch history (for TAGE) and append the new history of the instruction triggering redirection (e.g., branches, resync request).
- Training the branch predictor happens for all instructions in their commit stage. This step includes updating the BTB as well as update the corresponding counters in the TAGE.

This commit also cleans unused effects related to the branches.
  • Loading branch information
xusine committed Dec 10, 2024
1 parent b0426f6 commit 7e0c881
Show file tree
Hide file tree
Showing 26 changed files with 504 additions and 679 deletions.
5 changes: 3 additions & 2 deletions components/BranchPredictor/BTB.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "BTB.hpp"
#include "components/uFetch/uFetchTypes.hpp"

#include <cstdint>

Expand Down Expand Up @@ -92,9 +93,9 @@ BTB::update(VirtualMemoryAddress aPC, eBranchType aType, VirtualMemoryAddress aT
}

bool
BTB::update(BranchFeedback const& aFeedback)
BTB::update(const BPredState &aFeedback)
{
return update(aFeedback.thePC, aFeedback.theActualType, aFeedback.theActualTarget);
return update(aFeedback.pc, aFeedback.theActualType, aFeedback.theActualTarget);
}

json
Expand Down
3 changes: 2 additions & 1 deletion components/BranchPredictor/BTB.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define FLEXUS_BTB

#include "BTBSet.hpp"
#include "components/uFetch/uFetchTypes.hpp"
#include "core/checkpoint/json.hpp"
#include "core/types.hpp"

Expand Down Expand Up @@ -36,7 +37,7 @@ class BTB
boost::optional<VirtualMemoryAddress> target(VirtualMemoryAddress anAddress);
// Update or add a new entry to the BTB
bool update(VirtualMemoryAddress aPC, eBranchType aType, VirtualMemoryAddress aTarget);
bool update(BranchFeedback const& aFeedback);
bool update(const BPredState &aFeedback);

json saveState() const;
void loadState(json checkpoint);
Expand Down
90 changes: 63 additions & 27 deletions components/BranchPredictor/BranchPredictor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ BranchPredictor::BranchPredictor(std::string const& aName, uint32_t anIndex, uin
, thePredictions_TAGE(aName + "-predictions:TAGE")
, theCorrect_TAGE(aName + "-correct:TAGE")
, theMispredict_TAGE(aName + "-mispredict:TAGE")
, theMispredict_TAGE_User(aName + "-mispredict:TAGE:User")
, theMispredict_TAGE_System(aName + "-mispredict:TAGE:System")

, thePredictions_BTB(aName + "-predictions:BTB")
, theCorrect_BTB(aName + "-correct:BTB")
, theMispredict_BTB(aName + "-mispredict:BTB")
, theMispredict_BTB_User(aName + "-mispredict:BTB:User")
, theMispredict_BTB_System(aName + "-mispredict:BTB:System")
{
}

Expand All @@ -52,11 +56,19 @@ BranchPredictor::predictConditional(VirtualMemoryAddress anAddress, BPredState&
}

void
BranchPredictor::reconstructHistory(BPredState aBPState)
BranchPredictor::recoverHistory(const BPredRedictRequest& aRequest)
{
assert(aBPState.theActualType != kNonBranch);
theTage.restore_history(*aRequest.theBPState);

theTage.restore_all_state(aBPState);
if (!aRequest.theInsertNewHistory) {
return;
}

const BPredState &aBPState = *aRequest.theBPState;

if(aBPState.theActualType == Flexus::SharedTypes::kNonBranch) {
return;
}

if (aBPState.theActualType == kConditional) {
if (aBPState.theActualDirection == kTaken) {
Expand All @@ -77,6 +89,12 @@ BranchPredictor::isBranch(VirtualMemoryAddress anAddress)
return theBTB.contains(anAddress);
}

void
BranchPredictor::checkpointHistory(BPredState& aBPState) const
{
theTage.checkpointHistory(aBPState);
}

VirtualMemoryAddress
BranchPredictor::predict(VirtualMemoryAddress anAddress, BPredState& aBPState)
{
Expand All @@ -91,7 +109,6 @@ BranchPredictor::predict(VirtualMemoryAddress anAddress, BPredState& aBPState)

switch (aBPState.thePredictedType) {
case kNonBranch:
theTage.checkpoint_history(aBPState);
aBPState.thePredictedTarget = VirtualMemoryAddress(0);
break;
case kConditional:
Expand All @@ -109,7 +126,8 @@ BranchPredictor::predict(VirtualMemoryAddress anAddress, BPredState& aBPState)
} else {
aBPState.thePredictedTarget = VirtualMemoryAddress(0);
}
theTage.get_prediction((uint64_t)anAddress, aBPState);
// theTage.get_prediction((uint64_t)anAddress, aBPState);
theTage.update_history(aBPState, true, aBPState.pc);
break;
default: aBPState.thePredictedTarget = VirtualMemoryAddress(0); break;
}
Expand All @@ -124,55 +142,73 @@ BranchPredictor::predict(VirtualMemoryAddress anAddress, BPredState& aBPState)
}

void
BranchPredictor::feedback(VirtualMemoryAddress anAddress,
eBranchType anActualType,
eDirection anActualDirection,
VirtualMemoryAddress anActualAddress,
BPredState& aBPState)
BranchPredictor::train(const BPredState& aBPState)
{
DBG_(VVerb, (<< "Training Branch Predictor by PC: " << std::hex << aBPState.pc));
// Implementation of feedback function
theBTB.update(anAddress, anActualType, anActualAddress);
theBTB.update(aBPState.pc, aBPState.theActualType, aBPState.theActualTarget);

bool is_system = ((uint64_t)aBPState.pc >> 63) != 0;

bool is_mispredict = false;
if (anActualType != aBPState.thePredictedType) {
if (aBPState.theActualType != aBPState.thePredictedType) {
is_mispredict = true;
} else {
if (anActualType == kConditional) {
if (!(aBPState.thePrediction >= kNotTaken) && (anActualDirection >= kNotTaken)) {
if ((aBPState.thePrediction <= kTaken) && (anActualDirection <= kTaken)) {
if (anActualAddress == aBPState.thePredictedTarget) { is_mispredict = true; }
if (aBPState.theActualType == kConditional) {
if (!(aBPState.thePrediction >= kNotTaken) && (aBPState.theActualDirection >= kNotTaken)) {
if ((aBPState.thePrediction <= kTaken) && (aBPState.theActualDirection <= kTaken)) {
if (aBPState.theActualTarget == aBPState.thePredictedTarget) { is_mispredict = true; }
} else {
is_mispredict = true;
}
}
}
}

aBPState.theActualDirection = anActualDirection;
aBPState.theActualType = anActualType;

if (is_mispredict) {

if (aBPState.thePredictedType == kConditional) {
// we need to figure out whether the direction was correct or the target was correct
if (aBPState.thePrediction <= kTaken) {
if (anActualDirection >= kTaken) {
if (aBPState.theActualDirection >= kTaken) {
++theMispredict_TAGE;
if (is_system) {
++theMispredict_TAGE_System;
} else {
++theMispredict_TAGE_User;
}
} else {
++theMispredict_BTB;
if (is_system) {
++theMispredict_BTB_System;
} else {
++theMispredict_BTB_User;
}
}
} else {
if (anActualAddress != aBPState.thePredictedTarget) {
if (aBPState.thePredictedTarget != aBPState.thePredictedTarget) {
++theMispredict_BTB;
if(is_system) {
++theMispredict_BTB_System;
} else {
++theMispredict_BTB_User;
}
} else {
++theMispredict_TAGE;
if (is_system) {
++theMispredict_TAGE_System;
} else {
++theMispredict_TAGE_User;
}
}
}
} else {
++theMispredict_BTB;
if (is_system) {
++theMispredict_BTB_System;
} else {
++theMispredict_BTB_User;
}
}

reconstructHistory(aBPState);
} else {
// If the prediction was correct, we need to update the stats
if (aBPState.thePredictedType == kConditional) {
Expand All @@ -188,9 +224,9 @@ BranchPredictor::feedback(VirtualMemoryAddress anAddress,
}
++theBranches;

if (aBPState.thePredictedType == kConditional && anActualType == kConditional) {
bool taken = (anActualDirection <= kTaken);
theTage.update_predictor(anAddress, aBPState, taken);
if (aBPState.thePredictedType == kConditional && aBPState.thePredictedType == kConditional) {
bool taken = (aBPState.theActualDirection <= kTaken);
theTage.update_predictor(aBPState.pc, aBPState, taken);
}
}

Expand Down
20 changes: 13 additions & 7 deletions components/BranchPredictor/BranchPredictor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ class BranchPredictor
Stat::StatCounter thePredictions_TAGE;
Stat::StatCounter theCorrect_TAGE;
Stat::StatCounter theMispredict_TAGE;
Stat::StatCounter theMispredict_TAGE_User;
Stat::StatCounter theMispredict_TAGE_System;

Stat::StatCounter thePredictions_BTB;
Stat::StatCounter theCorrect_BTB;
Stat::StatCounter theMispredict_BTB;
Stat::StatCounter theMispredict_BTB_User;
Stat::StatCounter theMispredict_BTB_System;


private:
/* Depending on whether the prediction of the Branch Predictor we use is Taken or Not Taken, the target is returned
Expand All @@ -37,18 +42,19 @@ class BranchPredictor
*/
VirtualMemoryAddress predictConditional(VirtualMemoryAddress anAddress, BPredState& aBPState);

void reconstructHistory(BPredState aBPState);

public:
BranchPredictor(std::string const& aName, uint32_t anIndex, uint32_t aBTBSets, uint32_t aBTBWays);
bool isBranch(VirtualMemoryAddress anAddress);

void checkpointHistory(BPredState& aBPState) const;

VirtualMemoryAddress predict(VirtualMemoryAddress anAddress, BPredState& aBPState);
void feedback(VirtualMemoryAddress anAddress,
eBranchType anActualType,
eDirection anActualDirection,
VirtualMemoryAddress anActualAddress,
BPredState& aBPState);

// This function is called whenever a prediction is resolved.
void recoverHistory(const BPredRedictRequest& aRequest);

// This function is called whenever an instruction triggering a prediction retires.
void train(const BPredState& aBPState);

void loadState(std::string const& aDirName);
void saveState(std::string const& aDirName);
Expand Down
Loading

0 comments on commit 7e0c881

Please sign in to comment.