-
Notifications
You must be signed in to change notification settings - Fork 135
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[demo] Add Naive Bayes Classifier (#24)
demo: Bump to version 1.1.0
- Loading branch information
Showing
4 changed files
with
318 additions
and
2 deletions.
There are no files selected for viewing
162 changes: 162 additions & 0 deletions
162
demo/client/contracts/classification/NaiveBayesClassifier.sol
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
pragma solidity ^0.5.0; | ||
pragma experimental ABIEncoderV2; | ||
|
||
import "../libs/Math.sol"; | ||
import "../libs/SafeMath.sol"; | ||
import "../libs/SignedSafeMath.sol"; | ||
|
||
import {Classifier64} from "./Classifier.sol"; | ||
|
||
/** | ||
* A Multinomial Naive Bayes classifier. | ||
* Works like in https://scikit-learn.org/stable/modules/naive_bayes.html#multinomial-naive-bayes. | ||
* | ||
* The prediction function is not optimized with typical things like working with log-probabilities because: | ||
* * it is mainly an example, | ||
* * storing log probabilities would be extra work for the update function which should be very efficient, and | ||
* * computing log in solidity is not reliable (yet). | ||
*/ | ||
contract NaiveBayesClassifier is Classifier64 { | ||
using SafeMath for uint256; | ||
using SignedSafeMath for int256; | ||
|
||
/** A class has been added. */ | ||
event AddClass( | ||
/** The name of the class. */ | ||
string name, | ||
/** The index for the class in the members of this classifier. */ | ||
uint index | ||
); | ||
|
||
/** | ||
* Information for a class. | ||
*/ | ||
struct ClassInfo { | ||
/** | ||
* The total number of occurrences of all features (sum of featureCounts). | ||
*/ | ||
uint64 totalFeatureCount; | ||
/** | ||
* The number of occurrences of a feature. | ||
*/ | ||
mapping(uint32 => uint32) featureCounts; | ||
} | ||
|
||
ClassInfo[] public classInfos; | ||
|
||
/** | ||
* The smoothing factor (sometimes called alpha). | ||
* Use toFloat (1 mapped) for Laplace smoothing. | ||
*/ | ||
uint32 public smoothingFactor; | ||
|
||
/** | ||
* The number of samples in each class. | ||
* We use this instead of class prior probabilities. | ||
* Scaled by toFloat. | ||
*/ | ||
uint[] public classCounts; | ||
|
||
/** | ||
* The total number of features throughout all classes. | ||
*/ | ||
uint totalNumFeatures; | ||
|
||
constructor( | ||
string[] memory _classifications, | ||
uint[] memory _classCounts, | ||
uint32[][][] memory _featureCounts, | ||
uint _totalNumFeatures, | ||
uint32 _smoothingFactor) | ||
Classifier64(_classifications) public { | ||
require(_classifications.length > 0, "At least one class is required."); | ||
require(_classifications.length < 2 ** 64, "Too many classes given."); | ||
totalNumFeatures = _totalNumFeatures; | ||
smoothingFactor = _smoothingFactor; | ||
classCounts = _classCounts; | ||
for (uint i = 0; i < _featureCounts.length; ++i){ | ||
ClassInfo memory info = ClassInfo(0); | ||
uint totalFeatureCount = 0; | ||
classInfos.push(info); | ||
ClassInfo storage storedInfo = classInfos[i]; | ||
for (uint j = 0 ; j < _featureCounts[i].length; ++j) { | ||
storedInfo.featureCounts[_featureCounts[i][j][0]] = _featureCounts[i][j][1]; | ||
totalFeatureCount = totalFeatureCount.add(_featureCounts[i][j][1]); | ||
} | ||
require(totalFeatureCount < 2 ** 64, "There are too many features."); | ||
classInfos[i].totalFeatureCount = uint64(totalFeatureCount); | ||
} | ||
} | ||
|
||
// Main overriden methods for training and predicting: | ||
|
||
function addClass(uint classCount, uint32[][] memory featureCounts, string memory classification) public onlyOwner { | ||
require(classifications.length + 1 < 2 ** 64, "There are too many classes already."); | ||
classifications.push(classification); | ||
uint classIndex = classifications.length - 1; | ||
emit AddClass(classification, classIndex); | ||
classCounts.push(classCount); | ||
ClassInfo memory info = ClassInfo(0); | ||
uint totalFeatureCount = 0; | ||
classInfos.push(info); | ||
ClassInfo storage storedInfo = classInfos[classIndex]; | ||
for (uint j = 0 ; j < featureCounts.length; ++j) { | ||
storedInfo.featureCounts[featureCounts[j][0]] = featureCounts[j][1]; | ||
totalFeatureCount = totalFeatureCount.add(featureCounts[j][1]); | ||
} | ||
require(totalFeatureCount < 2 ** 64, "There are too many features."); | ||
classInfos[classIndex].totalFeatureCount = uint64(totalFeatureCount); | ||
} | ||
|
||
function norm(int64[] memory /* data */) public pure returns (uint) { | ||
revert("Normalization is not required."); | ||
} | ||
|
||
function predict(int64[] memory data) public view returns (uint64 bestClass) { | ||
// Implementation: simple calculation (no log-probabilities optimization, see contract docs for the reasons) | ||
bestClass = 0; | ||
uint maxProb = 0; | ||
uint denominatorSmoothFactor = uint(smoothingFactor).mul(totalNumFeatures); | ||
for (uint classIndex = 0; classIndex < classCounts.length; ++classIndex) { | ||
uint prob = classCounts[classIndex].mul(toFloat); | ||
ClassInfo storage info = classInfos[classIndex]; | ||
for (uint featureIndex = 0; featureIndex < data.length; ++featureIndex) { | ||
uint32 featureCount = info.featureCounts[uint32(data[featureIndex])]; | ||
prob = prob.mul(toFloat * featureCount + smoothingFactor).div(toFloat * info.totalFeatureCount + denominatorSmoothFactor); | ||
} | ||
if (prob > maxProb) { | ||
maxProb = prob; | ||
// There are already checks to make sure there are a limited number of classes. | ||
bestClass = uint64(classIndex); | ||
} | ||
} | ||
} | ||
|
||
function update(int64[] memory data, uint64 classification) public onlyOwner { | ||
// Data is binarized (data holds the indices of the features that are present). | ||
require(classification < classifications.length, "Classification is out of bounds."); | ||
classCounts[classification] = classCounts[classification].add(1); | ||
|
||
ClassInfo storage info = classInfos[classification]; | ||
|
||
uint totalFeatureCount = data.length.add(info.totalFeatureCount); | ||
require(totalFeatureCount < 2 ** 64, "Feature count will be too high."); | ||
info.totalFeatureCount = uint64(totalFeatureCount); | ||
|
||
for (uint dataIndex = 0; dataIndex < data.length; ++dataIndex) { | ||
int64 featureIndex = data[dataIndex]; | ||
require(featureIndex < 2 ** 32, "A feature index is too high."); | ||
info.featureCounts[uint32(featureIndex)] += 1; | ||
} | ||
} | ||
|
||
// Useful methods to view the underlying data: | ||
|
||
function getClassTotalFeatureCount(uint classIndex) public view returns (uint64) { | ||
return classInfos[classIndex].totalFeatureCount; | ||
} | ||
|
||
function getFeatureCount(uint classIndex, uint32 featureIndex) public view returns (uint32) { | ||
return classInfos[classIndex].featureCounts[featureIndex]; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
154 changes: 154 additions & 0 deletions
154
demo/client/test/contracts/classification/naivebayes.js
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
const NaiveBayesClassifier = artifacts.require("./classification/NaiveBayesClassifier"); | ||
|
||
contract('NaiveBayesClassifier', function (accounts) { | ||
const toFloat = 1E9; | ||
|
||
const smoothingFactor = convertNum(1); | ||
const classifications = ["ALARM", "WEATHER"]; | ||
const vocab = {}; | ||
let vocabLength = 0; | ||
let classifier; | ||
|
||
function convertNum(num) { | ||
return web3.utils.toBN(Math.round(num * toFloat)); | ||
} | ||
|
||
function parseBN(num) { | ||
if (web3.utils.isBN(num)) { | ||
return num.toNumber(); | ||
} else { | ||
assert.typeOf(num, 'number'); | ||
return num; | ||
} | ||
} | ||
|
||
function parseFloatBN(bn) { | ||
assert(web3.utils.isBN(bn), `${bn} is not a BN`); | ||
// Can't divide first since a BN can only be an integer. | ||
return bn.toNumber() / toFloat; | ||
} | ||
|
||
function mapFeatures(query) { | ||
return query.split(" ").map(w => { | ||
let result = vocab[w]; | ||
if (result === undefined) { | ||
vocab[w] = result = vocabLength++; | ||
} | ||
return result; | ||
}); | ||
} | ||
|
||
before("deploy classifier", async () => { | ||
const queries = [ | ||
"alarm for 11 am tomorrow", | ||
"will i need a jacket for tomorrow"]; | ||
const featureMappedQueries = queries.map(mapFeatures); | ||
const featureCounts = featureMappedQueries.map(fv => { | ||
const result = {}; | ||
fv.forEach(v => { | ||
if (!(v in result)) { | ||
result[v] = 0; | ||
} | ||
result[v] += 1; | ||
}); | ||
return Object.entries(result).map(pair => [parseInt(pair[0]), pair[1]].map(web3.utils.toBN)); | ||
}); | ||
const classCounts = [1, 1]; | ||
const totalNumFeatures = vocabLength; | ||
classifier = await NaiveBayesClassifier.new(classifications, classCounts, featureCounts, totalNumFeatures, smoothingFactor); | ||
|
||
assert.equal(await classifier.getClassTotalFeatureCount(0).then(parseBN), 5, | ||
"Total feature count for class 0."); | ||
assert.equal(await classifier.getClassTotalFeatureCount(1).then(parseBN), 7, | ||
"Total feature count for class 1."); | ||
|
||
for (let featureIndex = 0; featureIndex < 5; ++featureIndex) { | ||
assert.equal(await classifier.getFeatureCount(0, featureIndex).then(parseBN), 1); | ||
} | ||
for (let featureIndex = 5; featureIndex < 11; ++featureIndex) { | ||
assert.equal(await classifier.getFeatureCount(0, featureIndex).then(parseBN), 0); | ||
} | ||
|
||
assert.equal(await classifier.getFeatureCount(1, 0).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(1, 1).then(parseBN), 1); | ||
assert.equal(await classifier.getFeatureCount(1, 2).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(1, 3).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(1, 4).then(parseBN), 1); | ||
assert.equal(await classifier.getFeatureCount(1, 5).then(parseBN), 1); | ||
assert.equal(await classifier.getFeatureCount(1, 6).then(parseBN), 1); | ||
assert.equal(await classifier.getFeatureCount(1, 7).then(parseBN), 1); | ||
assert.equal(await classifier.getFeatureCount(1, 8).then(parseBN), 1); | ||
assert.equal(await classifier.getFeatureCount(1, 9).then(parseBN), 1); | ||
assert.equal(await classifier.getFeatureCount(1, 10).then(parseBN), 0); | ||
}); | ||
|
||
it("...should predict the classification ALARM", async () => { | ||
const data = mapFeatures("alarm for 9 am tomorrow"); | ||
const prediction = await classifier.predict(data).then(parseBN); | ||
assert.equal(prediction, 0); | ||
}); | ||
|
||
it("...should predict the classification WEATHER", async () => { | ||
const data = mapFeatures("will i need a jacket today"); | ||
const prediction = await classifier.predict(data).then(parseBN); | ||
assert.equal(prediction, 1); | ||
}); | ||
|
||
it("...should update", async () => { | ||
const newFeature = vocabLength + 10; | ||
const predictionData = [newFeature]; | ||
assert.equal(await classifier.predict(predictionData).then(parseBN), 0); | ||
|
||
const data = [0, 1, 2, newFeature]; | ||
const classification = 1; | ||
const prevFeatureCounts = []; | ||
for (let i in data) { | ||
const featureIndex = data[i]; | ||
const featureCount = await classifier.getFeatureCount(classification, featureIndex).then(parseBN); | ||
await prevFeatureCounts.push(featureCount); | ||
} | ||
const prevTotalFeatureCount = await classifier.getClassTotalFeatureCount(classification).then(parseBN); | ||
|
||
await classifier.update(data, classification); | ||
|
||
for (let i in prevFeatureCounts) { | ||
const featureIndex = data[i]; | ||
const featureCount = await classifier.getFeatureCount(classification, featureIndex).then(parseBN); | ||
assert.equal(featureCount, prevFeatureCounts[i] + 1); | ||
} | ||
|
||
const totalFeatureCount = await classifier.getClassTotalFeatureCount(classification).then(parseBN); | ||
assert.equal(totalFeatureCount, prevTotalFeatureCount + data.length); | ||
|
||
assert.equal(await classifier.predict(predictionData).then(parseBN), classification); | ||
}); | ||
|
||
it("...should add class", async () => { | ||
const classCount = 3; | ||
const featureCounts = [[0, 2], [1, 3], [6, 5]]; | ||
const classification = "NEW"; | ||
const originalNumClassifications = await classifier.getNumClassifications().then(parseBN); | ||
classifier.addClass(classCount, featureCounts, classification); | ||
const newNumClassifications = await classifier.getNumClassifications().then(parseBN); | ||
assert.equal(newNumClassifications, originalNumClassifications + 1); | ||
const classIndex = originalNumClassifications; | ||
|
||
assert.equal(await classifier.getClassTotalFeatureCount(classIndex).then(parseBN), | ||
featureCounts.map(pair => pair[1]).reduce((a, b) => a + b), | ||
"Total feature count for the new class is wrong."); | ||
|
||
assert.equal(await classifier.getFeatureCount(classIndex, 0).then(parseBN), 2); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 1).then(parseBN), 3); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 2).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 3).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 4).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 5).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 6).then(parseBN), 5); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 7).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 8).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 9).then(parseBN), 0); | ||
assert.equal(await classifier.getFeatureCount(classIndex, 10).then(parseBN), 0); | ||
|
||
assert.equal(await classifier.predict([0, 1, 6]).then(parseBN), classIndex); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
{ | ||
"name": "decai-demo", | ||
"version": "1.0.1", | ||
"version": "1.1.0", | ||
"license": "MIT", | ||
"private": true, | ||
"scripts": { | ||
|