Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement the new way to train the branch predictor. #93

Open
wants to merge 17 commits into
base: maintainer/bryan
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions components/BranchPredictor/BTB.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "BTB.hpp"
#include "components/uFetch/uFetchTypes.hpp"

#include <cstdint>

Expand Down Expand Up @@ -92,9 +93,9 @@ BTB::update(VirtualMemoryAddress aPC, eBranchType aType, VirtualMemoryAddress aT
}

bool
BTB::update(BranchFeedback const& aFeedback)
BTB::update(const BPredState &aFeedback)
{
return update(aFeedback.thePC, aFeedback.theActualType, aFeedback.theActualTarget);
return update(aFeedback.pc, aFeedback.theActualType, aFeedback.theActualTarget);
}

json
Expand Down
3 changes: 2 additions & 1 deletion components/BranchPredictor/BTB.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define FLEXUS_BTB

#include "BTBSet.hpp"
#include "components/uFetch/uFetchTypes.hpp"
#include "core/checkpoint/json.hpp"
#include "core/types.hpp"

Expand Down Expand Up @@ -36,7 +37,7 @@ class BTB
boost::optional<VirtualMemoryAddress> target(VirtualMemoryAddress anAddress);
// Update or add a new entry to the BTB
bool update(VirtualMemoryAddress aPC, eBranchType aType, VirtualMemoryAddress aTarget);
bool update(BranchFeedback const& aFeedback);
bool update(const BPredState &aFeedback);

json saveState() const;
void loadState(json checkpoint);
Expand Down
90 changes: 63 additions & 27 deletions components/BranchPredictor/BranchPredictor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ BranchPredictor::BranchPredictor(std::string const& aName, uint32_t anIndex, uin
, thePredictions_TAGE(aName + "-predictions:TAGE")
, theCorrect_TAGE(aName + "-correct:TAGE")
, theMispredict_TAGE(aName + "-mispredict:TAGE")
, theMispredict_TAGE_User(aName + "-mispredict:TAGE:User")
, theMispredict_TAGE_System(aName + "-mispredict:TAGE:System")

, thePredictions_BTB(aName + "-predictions:BTB")
, theCorrect_BTB(aName + "-correct:BTB")
, theMispredict_BTB(aName + "-mispredict:BTB")
, theMispredict_BTB_User(aName + "-mispredict:BTB:User")
, theMispredict_BTB_System(aName + "-mispredict:BTB:System")
{
}

Expand All @@ -52,11 +56,19 @@ BranchPredictor::predictConditional(VirtualMemoryAddress anAddress, BPredState&
}

void
BranchPredictor::reconstructHistory(BPredState aBPState)
BranchPredictor::recoverHistory(const BPredRedictRequest& aRequest)
{
assert(aBPState.theActualType != kNonBranch);
theTage.restore_history(*aRequest.theBPState);

theTage.restore_all_state(aBPState);
if (!aRequest.theInsertNewHistory) {
return;
}

const BPredState &aBPState = *aRequest.theBPState;

if(aBPState.theActualType == Flexus::SharedTypes::kNonBranch) {
return;
}

if (aBPState.theActualType == kConditional) {
if (aBPState.theActualDirection == kTaken) {
Expand All @@ -77,6 +89,12 @@ BranchPredictor::isBranch(VirtualMemoryAddress anAddress)
return theBTB.contains(anAddress);
}

void
BranchPredictor::checkpointHistory(BPredState& aBPState) const
{
theTage.checkpointHistory(aBPState);
}

VirtualMemoryAddress
BranchPredictor::predict(VirtualMemoryAddress anAddress, BPredState& aBPState)
{
Expand All @@ -91,7 +109,6 @@ BranchPredictor::predict(VirtualMemoryAddress anAddress, BPredState& aBPState)

switch (aBPState.thePredictedType) {
case kNonBranch:
theTage.checkpoint_history(aBPState);
aBPState.thePredictedTarget = VirtualMemoryAddress(0);
break;
case kConditional:
Expand All @@ -109,7 +126,8 @@ BranchPredictor::predict(VirtualMemoryAddress anAddress, BPredState& aBPState)
} else {
aBPState.thePredictedTarget = VirtualMemoryAddress(0);
}
theTage.get_prediction((uint64_t)anAddress, aBPState);
// theTage.get_prediction((uint64_t)anAddress, aBPState);
theTage.update_history(aBPState, true, aBPState.pc);
break;
default: aBPState.thePredictedTarget = VirtualMemoryAddress(0); break;
}
Expand All @@ -124,55 +142,73 @@ BranchPredictor::predict(VirtualMemoryAddress anAddress, BPredState& aBPState)
}

void
BranchPredictor::feedback(VirtualMemoryAddress anAddress,
eBranchType anActualType,
eDirection anActualDirection,
VirtualMemoryAddress anActualAddress,
BPredState& aBPState)
BranchPredictor::train(const BPredState& aBPState)
{
DBG_(VVerb, (<< "Training Branch Predictor by PC: " << std::hex << aBPState.pc));
// Implementation of feedback function
theBTB.update(anAddress, anActualType, anActualAddress);
theBTB.update(aBPState.pc, aBPState.theActualType, aBPState.theActualTarget);

bool is_system = ((uint64_t)aBPState.pc >> 63) != 0;

bool is_mispredict = false;
if (anActualType != aBPState.thePredictedType) {
if (aBPState.theActualType != aBPState.thePredictedType) {
is_mispredict = true;
} else {
if (anActualType == kConditional) {
if (!(aBPState.thePrediction >= kNotTaken) && (anActualDirection >= kNotTaken)) {
if ((aBPState.thePrediction <= kTaken) && (anActualDirection <= kTaken)) {
if (anActualAddress == aBPState.thePredictedTarget) { is_mispredict = true; }
if (aBPState.theActualType == kConditional) {
if (!(aBPState.thePrediction >= kNotTaken) && (aBPState.theActualDirection >= kNotTaken)) {
if ((aBPState.thePrediction <= kTaken) && (aBPState.theActualDirection <= kTaken)) {
if (aBPState.theActualTarget == aBPState.thePredictedTarget) { is_mispredict = true; }
} else {
is_mispredict = true;
}
}
}
}

aBPState.theActualDirection = anActualDirection;
aBPState.theActualType = anActualType;

if (is_mispredict) {

if (aBPState.thePredictedType == kConditional) {
// we need to figure out whether the direction was correct or the target was correct
if (aBPState.thePrediction <= kTaken) {
if (anActualDirection >= kTaken) {
if (aBPState.theActualDirection >= kTaken) {
++theMispredict_TAGE;
if (is_system) {
++theMispredict_TAGE_System;
} else {
++theMispredict_TAGE_User;
}
} else {
++theMispredict_BTB;
if (is_system) {
++theMispredict_BTB_System;
} else {
++theMispredict_BTB_User;
}
}
} else {
if (anActualAddress != aBPState.thePredictedTarget) {
if (aBPState.thePredictedTarget != aBPState.thePredictedTarget) {
++theMispredict_BTB;
if(is_system) {
++theMispredict_BTB_System;
} else {
++theMispredict_BTB_User;
}
} else {
++theMispredict_TAGE;
if (is_system) {
++theMispredict_TAGE_System;
} else {
++theMispredict_TAGE_User;
}
}
}
} else {
++theMispredict_BTB;
if (is_system) {
++theMispredict_BTB_System;
} else {
++theMispredict_BTB_User;
}
}

reconstructHistory(aBPState);
} else {
// If the prediction was correct, we need to update the stats
if (aBPState.thePredictedType == kConditional) {
Expand All @@ -188,9 +224,9 @@ BranchPredictor::feedback(VirtualMemoryAddress anAddress,
}
++theBranches;

if (aBPState.thePredictedType == kConditional && anActualType == kConditional) {
bool taken = (anActualDirection <= kTaken);
theTage.update_predictor(anAddress, aBPState, taken);
if (aBPState.thePredictedType == kConditional && aBPState.thePredictedType == kConditional) {
bool taken = (aBPState.theActualDirection <= kTaken);
theTage.update_predictor(aBPState.pc, aBPState, taken);
}
}

Expand Down
20 changes: 13 additions & 7 deletions components/BranchPredictor/BranchPredictor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ class BranchPredictor
Stat::StatCounter thePredictions_TAGE;
Stat::StatCounter theCorrect_TAGE;
Stat::StatCounter theMispredict_TAGE;
Stat::StatCounter theMispredict_TAGE_User;
Stat::StatCounter theMispredict_TAGE_System;

Stat::StatCounter thePredictions_BTB;
Stat::StatCounter theCorrect_BTB;
Stat::StatCounter theMispredict_BTB;
Stat::StatCounter theMispredict_BTB_User;
Stat::StatCounter theMispredict_BTB_System;


private:
/* Depending on whether the prediction of the Branch Predictor we use is Taken or Not Taken, the target is returned
Expand All @@ -37,18 +42,19 @@ class BranchPredictor
*/
VirtualMemoryAddress predictConditional(VirtualMemoryAddress anAddress, BPredState& aBPState);

void reconstructHistory(BPredState aBPState);

public:
BranchPredictor(std::string const& aName, uint32_t anIndex, uint32_t aBTBSets, uint32_t aBTBWays);
bool isBranch(VirtualMemoryAddress anAddress);

void checkpointHistory(BPredState& aBPState) const;

VirtualMemoryAddress predict(VirtualMemoryAddress anAddress, BPredState& aBPState);
void feedback(VirtualMemoryAddress anAddress,
eBranchType anActualType,
eDirection anActualDirection,
VirtualMemoryAddress anActualAddress,
BPredState& aBPState);

// This function is called whenever a prediction is resolved.
void recoverHistory(const BPredRedictRequest& aRequest);

// This function is called whenever an instruction triggering a prediction retires.
void train(const BPredState& aBPState);

void loadState(std::string const& aDirName);
void saveState(std::string const& aDirName);
Expand Down
Loading
Loading