Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #1059 from mrcslws/want-more-learning
Browse files Browse the repository at this point in the history
Grow synapses in predicted columns, not just bursting columns
  • Loading branch information
mrcslws authored Aug 26, 2016
2 parents fc4ed78 + 99f4704 commit 8a846d1
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 157 deletions.
1 change: 1 addition & 0 deletions src/nupic/algorithms/Connections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ SegmentExcitationTally::SegmentExcitationTally(
numActiveSynapsesForSegment_(connections.nextFlatIdx_, 0),
numMatchingSynapsesForSegment_(connections.nextFlatIdx_, 0)
{
NTA_ASSERT(matchingPermanenceThreshold <= activePermanenceThreshold);
}

void SegmentExcitationTally::addActivePresynapticCell(CellIdx cell)
Expand Down
173 changes: 123 additions & 50 deletions src/nupic/algorithms/TemporalMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ void TemporalMemory::initialize(
NTA_CHECK(connectedPermanence >= 0.0 && connectedPermanence <= 1.0);
NTA_CHECK(permanenceIncrement >= 0.0 && permanenceIncrement <= 1.0);
NTA_CHECK(permanenceDecrement >= 0.0 && permanenceDecrement <= 1.0);
NTA_CHECK(minThreshold <= activationThreshold);

// Save member variables

Expand Down Expand Up @@ -153,6 +154,11 @@ void TemporalMemory::initialize(
matchingSegments_.clear();
}

static CellIdx cellForSegment(const SegmentOverlap& s)
{
return s.segment.cell;
}

static CellIdx getLeastUsedCell(
Connections& connections,
Random& rng,
Expand Down Expand Up @@ -278,7 +284,7 @@ static void growSynapses(
// Pick nActual cells randomly.
for (UInt32 c = 0; c < nActual; c++)
{
size_t i = rng.getUInt32(std::distance(candidates.begin(), eligibleEnd));;
size_t i = rng.getUInt32(std::distance(candidates.begin(), eligibleEnd));
connections.createSynapse(segment, candidates[i], initialPermanence);
eligibleEnd--;
std::swap(candidates[i], *eligibleEnd);
Expand All @@ -289,34 +295,82 @@ static void activatePredictedColumn(
vector<CellIdx>& activeCells,
vector<CellIdx>& winnerCells,
Connections& connections,
Random& rng,
vector<SegmentOverlap>::const_iterator columnActiveSegmentsBegin,
vector<SegmentOverlap>::const_iterator columnActiveSegmentsEnd,
vector<SegmentOverlap>::const_iterator columnMatchingSegmentsBegin,
vector<SegmentOverlap>::const_iterator columnMatchingSegmentsEnd,
const vector<CellIdx>& prevActiveCells,
const vector<CellIdx>& prevWinnerCells,
UInt maxNewSynapseCount,
Permanence initialPermanence,
Permanence permanenceIncrement,
Permanence permanenceDecrement,
bool learn)
{
auto active = columnActiveSegmentsBegin;
do
for (auto& cellData : iterGroupBy(
columnActiveSegmentsBegin, columnActiveSegmentsEnd, cellForSegment,
columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, cellForSegment))
{
const CellIdx cell = active->segment.cell;
activeCells.push_back(cell);
winnerCells.push_back(cell);
CellIdx cell;
vector<SegmentOverlap>::const_iterator
cellActiveSegmentsBegin, cellActiveSegmentsEnd,
cellMatchingSegmentsBegin, cellMatchingSegmentsEnd;
tie(cell,
cellActiveSegmentsBegin, cellActiveSegmentsEnd,
cellMatchingSegmentsBegin, cellMatchingSegmentsEnd) = cellData;

// This cell might have multiple active segments.
do
if (cellActiveSegmentsBegin != cellActiveSegmentsEnd)
{
activeCells.push_back(cell);
winnerCells.push_back(cell);

if (learn)
{
adaptSegment(connections,
active->segment,
prevActiveCells,
permanenceIncrement, permanenceDecrement);
// Learn on every active segment.

auto bySegment = [](const SegmentOverlap& x) { return x.segment; };
for (auto segmentData : iterGroupBy(
cellActiveSegmentsBegin, cellActiveSegmentsEnd, bySegment,
cellMatchingSegmentsBegin, cellMatchingSegmentsEnd, bySegment))
{
// Find the active segment's corresponding "matching" overlap.
Segment segment;
vector<SegmentOverlap>::const_iterator
activeOverlapsBegin, activeOverlapsEnd,
matchingOverlapsBegin, matchingOverlapsEnd;
tie(segment,
activeOverlapsBegin, activeOverlapsEnd,
matchingOverlapsBegin, matchingOverlapsEnd) = segmentData;
if (activeOverlapsBegin != activeOverlapsEnd)
{
// Active segments are a superset of matching segments.
NTA_ASSERT(std::distance(activeOverlapsBegin,
activeOverlapsEnd) == 1);
NTA_ASSERT(std::distance(matchingOverlapsBegin,
matchingOverlapsEnd) == 1);

adaptSegment(connections,
segment,
prevActiveCells,
permanenceIncrement, permanenceDecrement);

const Int32 nActivePotentialSynapses =
matchingOverlapsBegin->overlap;
const Int32 nGrowDesired =
maxNewSynapseCount - nActivePotentialSynapses;
if (nGrowDesired > 0)
{
growSynapses(connections, rng,
segment, nGrowDesired,
prevWinnerCells,
initialPermanence);
}
}
}
}
active++;
} while (active != columnActiveSegmentsEnd &&
active->segment.cell == cell);
} while (active != columnActiveSegmentsEnd);
}
}
}

static void burstColumn(
Expand All @@ -330,55 +384,58 @@ static void burstColumn(
const vector<CellIdx>& prevActiveCells,
const vector<CellIdx>& prevWinnerCells,
UInt cellsPerColumn,
Permanence initialPermanence,
UInt maxNewSynapseCount,
Permanence initialPermanence,
Permanence permanenceIncrement,
Permanence permanenceDecrement,
bool learn)
{
// Calculate the active cells.
const CellIdx start = column * cellsPerColumn;
const CellIdx end = start + cellsPerColumn;
for (CellIdx cell = start; cell < end; cell++)
{
activeCells.push_back(cell);
}

if (columnMatchingSegmentsBegin != columnMatchingSegmentsEnd)
{
auto bestMatch = std::max_element(
columnMatchingSegmentsBegin, columnMatchingSegmentsEnd,
[](const SegmentOverlap& a, const SegmentOverlap& b)
{
return a.overlap < b.overlap;
});
const auto bestMatching = std::max_element(
columnMatchingSegmentsBegin, columnMatchingSegmentsEnd,
[](const SegmentOverlap& a, const SegmentOverlap& b)
{
return a.overlap < b.overlap;
});

const CellIdx winnerCell = (bestMatching != columnMatchingSegmentsEnd)
? bestMatching->segment.cell
: getLeastUsedCell(connections, rng, column, cellsPerColumn);

winnerCells.push_back(bestMatch->segment.cell);
winnerCells.push_back(winnerCell);

if (learn)
// Learn.
if (learn)
{
if (bestMatching != columnMatchingSegmentsEnd)
{
// Learn on the best matching segment.
adaptSegment(connections,
bestMatch->segment,
bestMatching->segment,
prevActiveCells,
permanenceIncrement, permanenceDecrement);

const Int32 nGrowDesired = maxNewSynapseCount - bestMatch->overlap;
const Int32 nGrowDesired = maxNewSynapseCount - bestMatching->overlap;
if (nGrowDesired > 0)
{
growSynapses(connections, rng,
bestMatch->segment, nGrowDesired,
bestMatching->segment, nGrowDesired,
prevWinnerCells,
initialPermanence);
}
}
}
else
{
const CellIdx winnerCell = getLeastUsedCell(connections, rng, column,
cellsPerColumn);
winnerCells.push_back(winnerCell);

if (learn)
else
{
// No matching segments.
// Grow a new segment and learn on it.

// Don't grow a segment that will never match.
const UInt32 nGrowExact = std::min(maxNewSynapseCount,
(UInt32)prevWinnerCells.size());
Expand Down Expand Up @@ -413,18 +470,15 @@ static void punishPredictedColumn(
}
}

void TemporalMemory::compute(
UInt activeColumnsSize,
const UInt activeColumnsUnsorted[],
void TemporalMemory::activateCells(
const vector<UInt>& activeColumns,
bool learn)
{
NTA_ASSERT(std::is_sorted(activeColumns.begin(), activeColumns.end()));

const vector<CellIdx> prevActiveCells = std::move(activeCells_);
const vector<CellIdx> prevWinnerCells = std::move(winnerCells_);

vector<UInt> activeColumns(activeColumnsUnsorted,
activeColumnsUnsorted + activeColumnsSize);
std::sort(activeColumns.begin(), activeColumns.end());

const auto columnForSegment =
[&](const SegmentOverlap& s) { return s.segment.cell / cellsPerColumn_; };

Expand All @@ -449,10 +503,12 @@ void TemporalMemory::compute(
if (columnActiveSegmentsBegin != columnActiveSegmentsEnd)
{
activatePredictedColumn(
activeCells_, winnerCells_, connections,
activeCells_, winnerCells_, connections, rng_,
columnActiveSegmentsBegin, columnActiveSegmentsEnd,
prevActiveCells,
permanenceIncrement_, permanenceDecrement_,
columnMatchingSegmentsBegin, columnMatchingSegmentsEnd,
prevActiveCells, prevWinnerCells,
maxNewSynapseCount_,
initialPermanence_, permanenceIncrement_, permanenceDecrement_,
learn);
}
else
Expand All @@ -461,8 +517,8 @@ void TemporalMemory::compute(
activeCells_, winnerCells_, connections, rng_,
column, columnMatchingSegmentsBegin, columnMatchingSegmentsEnd,
prevActiveCells, prevWinnerCells,
cellsPerColumn_, initialPermanence_, maxNewSynapseCount_,
permanenceIncrement_, permanenceDecrement_,
cellsPerColumn_, maxNewSynapseCount_,
initialPermanence_, permanenceIncrement_, permanenceDecrement_,
learn);
}
}
Expand All @@ -478,7 +534,10 @@ void TemporalMemory::compute(
}
}
}
}

void TemporalMemory::activateDendrites(bool learn)
{
activeSegments_.clear();
matchingSegments_.clear();
connections.computeActivity(activeCells_,
Expand All @@ -497,6 +556,20 @@ void TemporalMemory::compute(
}
}

void TemporalMemory::compute(
UInt activeColumnsSize,
const UInt activeColumnsUnsorted[],
bool learn)
{
vector<UInt> activeColumns(activeColumnsUnsorted,
activeColumnsUnsorted + activeColumnsSize);
std::sort(activeColumns.begin(), activeColumns.end());

activateCells(activeColumns, learn);

activateDendrites(learn);
}

void TemporalMemory::reset(void)
{
activeCells_.clear();
Expand Down
25 changes: 24 additions & 1 deletion src/nupic/algorithms/TemporalMemory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,28 @@ namespace nupic {
*/
virtual void reset();

/**
* Calculate the active cells, using the current active columns and
* dendrite segments. Grow and reinforce synapses.
*
* @param activeColumns
* A sorted list of active column indices.
*
* @param learn
* If true, reinforce / punish / grow synapses.
*/
void activateCells(
const vector<UInt>& activeColumns, bool learn = true);

/**
* Calculate dendrite segment activity, using the current active cells.
*
* @param learn
* If true, segment activations will be recorded. This information is
* used during segment cleanup.
*/
void activateDendrites(bool learn = true);

/**
* Feeds input record through TM, performing inference and learning.
*
Expand Down Expand Up @@ -303,7 +325,8 @@ namespace nupic {
void setMinThreshold(UInt);

/**
* Returns the maximum new synapse count.
* Returns the maximum number of synapses that can be added to a segment
* in a single time step.
*
* @returns Integer number of maximum new synapse count
*/
Expand Down
Loading

0 comments on commit 8a846d1

Please sign in to comment.