Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Knn dev #374

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
3e295eb
modified build and pom files, Added new Parameter
cogmission Jan 9, 2016
219a023
Added new KNNClassifier class and test
cogmission Jan 9, 2016
c00eff7
Added new Enums and Constants
cogmission Jan 9, 2016
add191e
Modified BeanUtil to work with private fields, added test to ensure p…
cogmission Jan 9, 2016
ec5da11
Added assertions which validate enum compatibility
cogmission Jan 9, 2016
27e3ff1
Added test class (BeanUtilTest), and added test methods
cogmission Jan 9, 2016
7238233
Added test for builder
cogmission Jan 9, 2016
e2d4f21
Added new utility methods: rotateLeft(), rotateRight(), ravel(), unra…
cogmission Jan 10, 2016
9e8d2d1
Added new NN Matrix implementation
cogmission Jan 11, 2016
00ee815
Merge remote-tracking branch 'upstream/master' into knn_dev
cogmission Jan 11, 2016
12a1f99
more work on knn
cogmission Jan 12, 2016
8db7a17
more work on classifier
cogmission Jan 13, 2016
9223fcd
added more code and merged in ArrayUtil additions
cogmission Jan 13, 2016
2002c4c
changed input datatype to double added new ArrayUtils method for change
cogmission Jan 14, 2016
6c295da
added test files to resources dir
cogmission Jan 15, 2016
3ada066
Add test for new PCA data generation
cogmission Jan 16, 2016
6e533e1
Remove reference to test resources from source
cogmission Jan 16, 2016
aeb60ed
add tests for new ArrayUtils functions
cogmission Jan 17, 2016
c7d8f65
Add new TupleCollector and change test to use Streams
cogmission Jan 17, 2016
d9d1d79
Added matrix transpose method
cogmission Jan 17, 2016
7d62579
added original test code back in
cogmission Jan 17, 2016
a3ecca1
remove working code
cogmission Jan 17, 2016
edf9065
added test for another branch to dot product
cogmission Jan 17, 2016
0b58b3d
Renamed an ArrayUtil method, added a subtract method to support knn dev
cogmission Jan 18, 2016
4782e40
Added matrix dot with 2d to 1d, added new transpose, removed unused libs
cogmission Jan 18, 2016
2f502d5
Add completed sparsifyVector test
cogmission Jan 18, 2016
de02de1
Added more ArrayUtils tests (setRangeTo()), and finished logic branch…
cogmission Jan 19, 2016
5ac8ec8
Added la4j library
cogmission Jan 19, 2016
7db47f6
Added new linear matrix lib
cogmission Jan 29, 2016
37accba
add changes to .classpath
cogmission Jan 29, 2016
eb080e8
Merge branch 'master' into knn_dev
cogmission Jan 29, 2016
13d82ff
interim merge from ongoing work on master
May 29, 2016
0f74c76
Bring into interim working fashion following update as merge from master
May 30, 2016
8dde067
Update test following merge from master
May 30, 2016
26751a3
Merge remote-tracking branch 'upstream/master' into knn_dev
Jul 13, 2016
b7e858f
Merge remote-tracking branch 'origin/master' into knn_dev
Jul 13, 2016
abd1543
Fixing merge with latest master
Sep 30, 2016
7e374e8
Fix up tests for recent merge
Sep 30, 2016
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions .classpath
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@
<classpath>
<classpathentry kind="src" path="src/jmh/java"/>
<classpathentry kind="src" path="src/main/java"/>
<classpathentry kind="src" path="src/main/resources"/>
<classpathentry kind="src" path="src/test/java"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
<classpathentry kind="lib" path="libs/chutils-1.4.jar" sourcepath="/Users/cogmission/Resources/chutils/chutils-1.4-sources.jar"/>
<classpathentry kind="lib" path="libs/algorithmfoundry-shade-culled-1.3.jar"/>
<classpathentry kind="lib" path="libs/chutils-1.4.jar"/>
<classpathentry kind="lib" path="libs/fst-2.45.jar"/>
<classpathentry kind="lib" path="libs/hamcrest-core-1.3.jar"/>
<classpathentry kind="lib" path="libs/jackson-annotations-2.4.0.jar"/>
<classpathentry kind="lib" path="libs/jackson-core-2.4.4.jar"/>
<classpathentry kind="lib" path="libs/jackson-databind-2.4.4.jar"/>
<classpathentry kind="lib" path="libs/java-util-1.19.3.jar"/>
<classpathentry kind="lib" path="libs/jmh-core-1.5.1.jar"/>
<classpathentry kind="lib" path="libs/joda-time-2.5.jar"/>
<classpathentry kind="lib" path="libs/junit-4.11.jar"/>
<classpathentry kind="lib" path="libs/logback-classic-1.1.2.jar"/>
<classpathentry kind="lib" path="libs/logback-core-1.1.2.jar"/>
<classpathentry kind="lib" path="libs/rxjava-1.0.10.jar" sourcepath="/Users/cogmission/Resources/RxJava/rxjava-1.0.10-sources.jar"/>
<classpathentry kind="lib" path="libs/slf4j-api-1.7.10.jar"/>
<classpathentry kind="lib" path="libs/trove-3.1a1.jar" sourcepath="/Users/cogmission/Resources/trove/3.1a1/trove-3.1a1-src.jar"/>
<classpathentry kind="lib" path="libs/java-util-1.19.3.jar" sourcepath="/Users/cogmission/.m2/repository/com/cedarsoftware/java-util/1.19.3/java-util-1.19.3-sources.jar"/>
<classpathentry kind="lib" path="libs/objenesis-2.1.jar"/>
<classpathentry kind="lib" path="libs/algorithmfoundry-shade-culled-1.3.jar"/>
<classpathentry kind="lib" path="libs/fst-2.45.jar"/>
<classpathentry kind="lib" path="libs/rxjava-1.0.10.jar"/>
<classpathentry kind="lib" path="libs/slf4j-api-1.7.10.jar"/>
<classpathentry kind="lib" path="libs/trove-3.1a1.jar"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
<classpathentry kind="output" path="bin"/>
</classpath>
17 changes: 9 additions & 8 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ apply plugin: 'eclipse'
apply plugin: 'signing'

group = 'org.numenta'
version = '0.6.9-SNAPSHOT'
version = '0.6.10-SNAPSHOT'
archivesBaseName = 'htm.java'

sourceCompatibility = 1.8
targetCompatibility = 1.8

jar {
manifest {
attributes 'Implementation-Title': 'htm.java', 'Implementation-Version': '0.6.9-SNAPSHOT'
attributes 'Implementation-Title': 'htm.java', 'Implementation-Version': '0.6.10-SNAPSHOT'
}
}

Expand All @@ -33,10 +33,11 @@ test {
}

// UNCOMMENT TO SEE STANDARD_OUT & STANDARD_ERR DURING BUILD

/*
test {
testLogging.showStandardStreams = true
}
*/

dependencies {
compile group: 'joda-time', name: 'joda-time', version: '2.5'
Expand All @@ -52,11 +53,11 @@ dependencies {
compile group: 'com.fasterxml.jackson.core', name: 'jackson-core', version:'2.4.4'
compile group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version:'2.4.4'
compile group: 'org.slf4j', name: 'slf4j-api', version:'1.7.10'
compile group: 'io.reactivex', name: 'rxjava', version: '1.0.10'
compile group: 'de.ruedigermoeller', name: 'fst', version: '2.45'
compile group: 'com.cedarsoftware', name: 'java-util', version: '1.19.3'
compile group: 'algorithmfoundry', name: 'algorithmfoundry-shade-culled', version:'1.3'
testCompile group: 'junit', name: 'junit', version:'4.11'
compile group: 'io.reactivex', name: 'rxjava', version: '1.0.10'
compile group: 'de.ruedigermoeller', name: 'fst', version: '2.45'
compile group: 'com.cedarsoftware', name: 'java-util', version: '1.19.3'
compile group: 'algorithmfoundry', name: 'algorithmfoundry-shade-culled', version:'1.3'
testCompile group: 'junit', name: 'junit', version:'4.11'
testCompile group: 'ch.qos.logback', name: 'logback-classic', version:'1.1.2'
}

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

<groupId>org.numenta</groupId>
<artifactId>htm.java</artifactId>
<version>0.6.9-SNAPSHOT</version>
<version>0.6.10-SNAPSHOT</version>
<name>htm.java</name>
<description>The Java version of Numenta's HTM technology</description>

Expand Down
39 changes: 3 additions & 36 deletions src/main/java/org/numenta/nupic/Connections.java
Original file line number Diff line number Diff line change
Expand Up @@ -2229,39 +2229,7 @@ public double[] getPAOverlaps() {
* High verbose output useful for debugging
*/
public void printParameters() {
System.out.println("------------ SpatialPooler Parameters ------------------");
System.out.println("numInputs = " + getNumInputs());
System.out.println("numColumns = " + getNumColumns());
System.out.println("cellsPerColumn = " + getCellsPerColumn());
System.out.println("columnDimensions = " + Arrays.toString(getColumnDimensions()));
System.out.println("numActiveColumnsPerInhArea = " + getNumActiveColumnsPerInhArea());
System.out.println("potentialPct = " + getPotentialPct());
System.out.println("potentialRadius = " + getPotentialRadius());
System.out.println("globalInhibition = " + getGlobalInhibition());
System.out.println("localAreaDensity = " + getLocalAreaDensity());
System.out.println("inhibitionRadius = " + getInhibitionRadius());
System.out.println("stimulusThreshold = " + getStimulusThreshold());
System.out.println("synPermActiveInc = " + getSynPermActiveInc());
System.out.println("synPermInactiveDec = " + getSynPermInactiveDec());
System.out.println("synPermConnected = " + getSynPermConnected());
System.out.println("minPctOverlapDutyCycle = " + getMinPctOverlapDutyCycles());
System.out.println("minPctActiveDutyCycle = " + getMinPctActiveDutyCycles());
System.out.println("dutyCyclePeriod = " + getDutyCyclePeriod());
System.out.println("maxBoost = " + getMaxBoost());
System.out.println("version = " + getVersion());

System.out.println("\n------------ TemporalMemory Parameters ------------------");
System.out.println("activationThreshold = " + getActivationThreshold());
System.out.println("learningRadius = " + getLearningRadius());
System.out.println("minThreshold = " + getMinThreshold());
System.out.println("maxNewSynapseCount = " + getMaxNewSynapseCount());
System.out.println("maxSynapsesPerSegment = " + getMaxSynapsesPerSegment());
System.out.println("maxSegmentsPerCell = " + getMaxSegmentsPerCell());
System.out.println("initialPermanence = " + getInitialPermanence());
System.out.println("connectedPermanence = " + getConnectedPermanence());
System.out.println("permanenceIncrement = " + getPermanenceIncrement());
System.out.println("permanenceDecrement = " + getPermanenceDecrement());
System.out.println("predictedSegmentDecrement = " + getPredictedSegmentDecrement());
System.out.println(getPrintString());
}

/**
Expand All @@ -2275,7 +2243,6 @@ public String getPrintString() {
pw.println("columnDimensions = " + Arrays.toString(getColumnDimensions()));
pw.println("inputDimensions = " + Arrays.toString(getInputDimensions()));
pw.println("cellsPerColumn = " + getCellsPerColumn());

pw.println("random = " + getRandom());
pw.println("seed = " + getSeed());

Expand All @@ -2294,8 +2261,8 @@ public String getPrintString() {
pw.println("synPermConnected = " + getSynPermConnected());
pw.println("synPermBelowStimulusInc = " + getSynPermBelowStimulusInc());
pw.println("synPermTrimThreshold = " + getSynPermTrimThreshold());
pw.println("minPctOverlapDutyCycles = " + getMinPctOverlapDutyCycles());
pw.println("minPctActiveDutyCycles = " + getMinPctActiveDutyCycles());
pw.println("minPctOverlapDutyCycles = " + getMinPctOverlapDutyCycles());
pw.println("minPctActiveDutyCycles = " + getMinPctActiveDutyCycles());
pw.println("dutyCyclePeriod = " + getDutyCyclePeriod());
pw.println("maxBoost = " + getMaxBoost());
pw.println("version = " + getVersion());
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/org/numenta/nupic/Constants.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.numenta.nupic;


public class Constants {
public static final String ADAPTIVE = "adaptive";

public enum KNN {
ADAPTIVE("adaptive");

private String description;
private KNN(String desc) { this.description = desc; }
/** {@inheritDoc} */
public String toString() { return description; }
}

}
58 changes: 58 additions & 0 deletions src/main/java/org/numenta/nupic/DistanceMethod.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package org.numenta.nupic;

/**
* <p>
* The method used to compute distance between input patterns and prototype patterns.
* </p><p>
* The possible options are:
* <ul>
* <li>NORM: When distanceNorm is 2, this is the euclidean distance,
* When distanceNorm is 1, this is the manhattan distance
* In general: sum(abs(x-proto) ^ distanceNorm) ^ (1/distanceNorm)
* The distances are normalized such that farthest prototype from
* a given input is 1.0.</li>
*
* <li>RAW_OVERLAP: Only appropriate when inputs are binary. This computes:
* (width of the input) - (# bits of overlap between input
* and prototype).</li>
*
* <li>PCT_INPUT_OVERLAP: Only appropriate for binary inputs. This computes
* 1.0 - (# bits overlap between input and prototype) /
* (# ON bits in input)
*
* <li>PCT_PROTO_OVERLAP: Only appropriate for binary inputs. This computes
* 1.0 - (# bits overlap between input and prototype) /
* (# ON bits in prototype)
*
* <li>PCT_LARGER_OVERLAP: Only appropriate for binary inputs. This computes
* 1.0 - (# bits overlap between input and prototype) /
* max(# ON bits in input, # ON bits in prototype)
*
*
* @author Numenta
* @author cogmission
*/
public enum DistanceMethod {
NORM("norm"),
RAW_OVERLAP("rawOverlap"),
PCT_INPUT_OVERLAP("pctOverlapOfInput"),
PCT_PROTO_OVERLAP("pctOverlapOfProto"),
PCT_LARGER_OVERLAP("pctOverlapOfLarger");

private String key;

/**
* Constructor passing the value used for KNN Distance Method key.
* @param key the string value identifying the distance method type.
*/
private DistanceMethod(String key) {
this.key = key;
}

/**
* {@inheritDoc}
*/
public String toString() {
return key;
}
}
110 changes: 108 additions & 2 deletions src/main/java/org/numenta/nupic/Parameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Random;
import java.util.Set;

import org.joda.time.format.DateTimeFormatter;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.model.Cell;
Expand Down Expand Up @@ -65,6 +66,7 @@ public class Parameters implements Persistable {
private static final Map<KEY, Object> DEFAULTS_TEMPORAL;
private static final Map<KEY, Object> DEFAULTS_SPATIAL;
private static final Map<KEY, Object> DEFAULTS_ENCODER;
private static final Map<KEY, Object> DEFAULTS_KNN;


static {
Expand Down Expand Up @@ -135,6 +137,28 @@ public class Parameters implements Persistable {
defaultEncoderParams.put(KEY.AUTO_CLASSIFY, Boolean.FALSE);
DEFAULTS_ENCODER = Collections.unmodifiableMap(defaultEncoderParams);
defaultParams.putAll(DEFAULTS_ENCODER);

////////////////// KNNClassifier Defaults ///////////////////
Map<KEY, Object> defaultKNNParams = new ParametersMap();
defaultKNNParams.put(KEY.K, 1);
defaultKNNParams.put(KEY.EXACT, false);
defaultKNNParams.put(KEY.DISTANCE_NORM, 2.0);
defaultKNNParams.put(KEY.DISTANCE_METHOD, DistanceMethod.NORM);
defaultKNNParams.put(KEY.DISTANCE_THRESHOLD, .0);
defaultKNNParams.put(KEY.DO_BINARIZATION, false);
defaultKNNParams.put(KEY.BINARIZATION_THRESHOLD, 0.5);
defaultKNNParams.put(KEY.USE_SPARSE_MEMORY, true);
defaultKNNParams.put(KEY.SPARSE_THRESHOLD, 0.1);
defaultKNNParams.put(KEY.RELATIVE_THRESHOLD, false);
defaultKNNParams.put(KEY.NUM_WINNERS, 0);
defaultKNNParams.put(KEY.NUM_SVD_SAMPLES, -1);
defaultKNNParams.put(KEY.NUM_SVD_DIMS, Constants.KNN.ADAPTIVE);
defaultKNNParams.put(KEY.FRACTION_OF_MAX, -1.0);
defaultKNNParams.put(KEY.MAX_STORED_PATTERNS, -1);
defaultKNNParams.put(KEY.REPLACE_DUPLICATES, false);
defaultKNNParams.put(KEY.KNN_CELLS_PER_COL, 0);
DEFAULTS_KNN = Collections.unmodifiableMap(defaultKNNParams);
defaultParams.putAll(DEFAULTS_KNN);

DEFAULTS_ALL = Collections.unmodifiableMap(defaultParams);
}
Expand Down Expand Up @@ -299,7 +323,6 @@ public static enum KEY {
// Network Layer indicator for auto classifier generation
AUTO_CLASSIFY("hasClassifiers", Boolean.class),


// How many bits to use if encoding the respective date fields.
// e.g. Tuple(bits to use:int, radius:double)
DATEFIELD_SEASON("season", Tuple.class),
Expand All @@ -308,7 +331,81 @@ public static enum KEY {
DATEFIELD_HOLIDAY("holiday", Tuple.class),
DATEFIELD_TOFD("timeOfDay", Tuple.class),
DATEFIELD_CUSTOM("customDays", Tuple.class), // e.g. Tuple(bits:int, List<String>:"mon,tue,fri")
DATEFIELD_PATTERN("formatPattern", String.class);
DATEFIELD_PATTERN("formatPattern", String.class),
DATEFIELD_FORMATTER("dateFormatter", DateTimeFormatter.class),


///////////// KNNClassifier Parameters //////////////
/** The number of nearest neighbors used in the classification of patterns. <b>Must be odd</b> */
K("k", Integer.class),
/** If true, patterns must match exactly when assigning class labels */
EXACT("exact", Boolean.class),
/** When distance method is "norm", this specifies the p value of the Lp-norm */
DISTANCE_NORM("distanceNorm", Double.class),
/**
* The method used to compute distance between input patterns and prototype patterns.
* see({@link DistanceMethod})
*/
DISTANCE_METHOD("distanceMethod", DistanceMethod.class),
/**
* A threshold on the distance between learned
* patterns and a new pattern proposed to be learned. The distance must be
* greater than this threshold in order for the new pattern to be added to
* the classifier's memory
*/
DISTANCE_THRESHOLD("distanceThreshold", Double.class),
/** If True, then scalar inputs will be binarized. */
DO_BINARIZATION("doBinarization", Boolean.class),
/** If doBinarization is True, this specifies the threshold for the binarization of inputs */
BINARIZATION_THRESHOLD("binarizationThreshold", Double.class),
/** If True, classifier will use a sparse memory matrix */
USE_SPARSE_MEMORY("useSparseMemory", Boolean.class),
/**
* If useSparseMemory is True, input variables whose absolute values are
* less than this threshold will be stored as zero
*/
SPARSE_THRESHOLD("sparseThreshold", Double.class),
/** Flag specifying whether to multiply sparseThreshold by max value in input */
RELATIVE_THRESHOLD("relativeThreshold", Boolean.class),
/** Number of elements of the input that are stored. If 0, all elements are stored */
NUM_WINNERS("numWinners", Integer.class),
/**
* Number of samples the must occur before a SVD
* (Singular Value Decomposition) transformation will be performed. If 0,
* the transformation will never be performed
*/
NUM_SVD_SAMPLES("numSVDSamples", Integer.class),
/**
* Controls dimensions kept after SVD transformation. If "adaptive",
* the number is chosen automatically
*/
NUM_SVD_DIMS("numSVDDims", Constants.KNN.class),
/**
* If numSVDDims is "adaptive", this controls the
* smallest singular value that is retained as a fraction of the largest
* singular value
*/
FRACTION_OF_MAX("fractionOfMax", Double.class),
/**
* Limits the maximum number of the training
* patterns stored. When KNN learns in a fixed capacity mode, the unused
* patterns are deleted once the number of stored patterns is greater than
* maxStoredPatterns. A value of -1 is no limit
*/
MAX_STORED_PATTERNS("maxStoredPatterns", Integer.class),
/**
* A boolean flag that determines whether,
* during learning, the classifier replaces duplicates that match exactly,
* even if distThreshold is 0. Should be TRUE for online learning
*/
REPLACE_DUPLICATES("replaceDuplicates", Boolean.class),
/**
* If >= 1, input is assumed to be organized into
* columns, in the same manner as the temporal pooler AND whenever a new
* prototype is stored, only the start cell (first cell) is stored in any
* bursting column
*/
KNN_CELLS_PER_COL("cellsPerCol", Integer.class);


private static final Map<String, KEY> fieldMap = new HashMap<>();
Expand Down Expand Up @@ -459,6 +556,15 @@ public static Parameters getSpatialDefaultParameters() {
public static Parameters getEncoderDefaultParameters() {
return getParameters(DEFAULTS_ENCODER);
}

/**
* Factory method. Return KNNClassifier {@link Parameters} object with default values
* @return
*/
public static Parameters getKNNDefaultParameters() {
return getParameters(DEFAULTS_KNN);
}

/**
* Called internally to populate a {@link Parameters} object with the keys
* and values specified in the passed in map.
Expand Down
Loading