[demo] Add Naive Bayes Classifier (#24)
demo: Bump to version 1.1.0
juharris authored Jul 31, 2019
1 parent e2ebaa3 commit df7a2fc
Showing 4 changed files with 318 additions and 2 deletions.
162 changes: 162 additions & 0 deletions demo/client/contracts/classification/NaiveBayesClassifier.sol
@@ -0,0 +1,162 @@
pragma solidity ^0.5.0;
pragma experimental ABIEncoderV2;

import "../libs/Math.sol";
import "../libs/SafeMath.sol";
import "../libs/SignedSafeMath.sol";

import {Classifier64} from "./Classifier.sol";

/**
 * A Multinomial Naive Bayes classifier.
 * Works like in https://scikit-learn.org/stable/modules/naive_bayes.html#multinomial-naive-bayes.
 *
 * The prediction function is not optimized with typical things like working with log-probabilities because:
 * * it is mainly an example,
 * * storing log probabilities would be extra work for the update function which should be very efficient, and
 * * computing log in solidity is not reliable (yet).
 */
contract NaiveBayesClassifier is Classifier64 {
    using SafeMath for uint256;
    using SignedSafeMath for int256;

    /** A class has been added. */
    event AddClass(
        /** The name of the class. */
        string name,
        /** The index for the class in the members of this classifier. */
        uint index
    );

    /**
     * Information for a class.
     */
    struct ClassInfo {
        /**
         * The total number of occurrences of all features (sum of featureCounts).
         */
        uint64 totalFeatureCount;
        /**
         * The number of occurrences of a feature.
         */
        mapping(uint32 => uint32) featureCounts;
    }

    ClassInfo[] public classInfos;

    /**
     * The smoothing factor (sometimes called alpha).
     * Use toFloat (1 mapped) for Laplace smoothing.
     */
    uint32 public smoothingFactor;

    /**
     * The number of samples in each class.
     * We use this instead of class prior probabilities.
     * Scaled by toFloat.
     */
    uint[] public classCounts;

    /**
     * The total number of features throughout all classes.
     */
    uint totalNumFeatures;

    constructor(
        string[] memory _classifications,
        uint[] memory _classCounts,
        uint32[][][] memory _featureCounts,
        uint _totalNumFeatures,
        uint32 _smoothingFactor)
        Classifier64(_classifications) public {
        require(_classifications.length > 0, "At least one class is required.");
        require(_classifications.length < 2 ** 64, "Too many classes given.");
        totalNumFeatures = _totalNumFeatures;
        smoothingFactor = _smoothingFactor;
        classCounts = _classCounts;
        for (uint i = 0; i < _featureCounts.length; ++i) {
            ClassInfo memory info = ClassInfo(0);
            uint totalFeatureCount = 0;
            classInfos.push(info);
            ClassInfo storage storedInfo = classInfos[i];
            for (uint j = 0; j < _featureCounts[i].length; ++j) {
                storedInfo.featureCounts[_featureCounts[i][j][0]] = _featureCounts[i][j][1];
                totalFeatureCount = totalFeatureCount.add(_featureCounts[i][j][1]);
            }
            require(totalFeatureCount < 2 ** 64, "There are too many features.");
            classInfos[i].totalFeatureCount = uint64(totalFeatureCount);
        }
    }

    // Main overridden methods for training and predicting:

    function addClass(uint classCount, uint32[][] memory featureCounts, string memory classification) public onlyOwner {
        require(classifications.length + 1 < 2 ** 64, "There are too many classes already.");
        classifications.push(classification);
        uint classIndex = classifications.length - 1;
        emit AddClass(classification, classIndex);
        classCounts.push(classCount);
        ClassInfo memory info = ClassInfo(0);
        uint totalFeatureCount = 0;
        classInfos.push(info);
        ClassInfo storage storedInfo = classInfos[classIndex];
        for (uint j = 0; j < featureCounts.length; ++j) {
            storedInfo.featureCounts[featureCounts[j][0]] = featureCounts[j][1];
            totalFeatureCount = totalFeatureCount.add(featureCounts[j][1]);
        }
        require(totalFeatureCount < 2 ** 64, "There are too many features.");
        classInfos[classIndex].totalFeatureCount = uint64(totalFeatureCount);
    }

    function norm(int64[] memory /* data */) public pure returns (uint) {
        revert("Normalization is not required.");
    }

    function predict(int64[] memory data) public view returns (uint64 bestClass) {
        // Implementation: simple calculation (no log-probabilities optimization, see contract docs for the reasons)
        bestClass = 0;
        uint maxProb = 0;
        uint denominatorSmoothFactor = uint(smoothingFactor).mul(totalNumFeatures);
        for (uint classIndex = 0; classIndex < classCounts.length; ++classIndex) {
            uint prob = classCounts[classIndex].mul(toFloat);
            ClassInfo storage info = classInfos[classIndex];
            for (uint featureIndex = 0; featureIndex < data.length; ++featureIndex) {
                uint32 featureCount = info.featureCounts[uint32(data[featureIndex])];
                prob = prob.mul(toFloat * featureCount + smoothingFactor).div(toFloat * info.totalFeatureCount + denominatorSmoothFactor);
            }
            if (prob > maxProb) {
                maxProb = prob;
                // There are already checks to make sure there are a limited number of classes.
                bestClass = uint64(classIndex);
            }
        }
    }

    function update(int64[] memory data, uint64 classification) public onlyOwner {
        // Data is binarized (data holds the indices of the features that are present).
        require(classification < classifications.length, "Classification is out of bounds.");
        classCounts[classification] = classCounts[classification].add(1);

        ClassInfo storage info = classInfos[classification];

        uint totalFeatureCount = data.length.add(info.totalFeatureCount);
        require(totalFeatureCount < 2 ** 64, "Feature count will be too high.");
        info.totalFeatureCount = uint64(totalFeatureCount);

        for (uint dataIndex = 0; dataIndex < data.length; ++dataIndex) {
            int64 featureIndex = data[dataIndex];
            require(featureIndex < 2 ** 32, "A feature index is too high.");
            info.featureCounts[uint32(featureIndex)] += 1;
        }
    }

    // Useful methods to view the underlying data:

    function getClassTotalFeatureCount(uint classIndex) public view returns (uint64) {
        return classInfos[classIndex].totalFeatureCount;
    }

    function getFeatureCount(uint classIndex, uint32 featureIndex) public view returns (uint32) {
        return classInfos[classIndex].featureCounts[featureIndex];
    }
}
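
For reference, the score computed by the predict loop above is the standard smoothed Multinomial Naive Bayes likelihood, evaluated directly in the contract's fixed-point convention (values scaled by toFloat, which the tests below assume is 1e9) rather than in log space. Below is a rough off-chain mirror of that arithmetic in JavaScript; it is not part of this commit, predictOffChain and the shape assumed for classInfos are hypothetical, and plain JavaScript numbers are used where the contract uses 256-bit integers, so it is only a sketch for sanity-checking:

// Hypothetical off-chain mirror of NaiveBayesClassifier.predict.
// classInfos[i] is assumed to look like { totalFeatureCount, featureCounts: { featureIndex: count } }.
const toFloat = 1e9;

function predictOffChain(data, classCounts, classInfos, totalNumFeatures, smoothingFactor) {
  let bestClass = 0;
  let maxProb = 0;
  const denominatorSmoothFactor = smoothingFactor * totalNumFeatures;
  for (let classIndex = 0; classIndex < classCounts.length; ++classIndex) {
    // The per-class sample count stands in for a class prior; scale it like the contract does.
    let prob = classCounts[classIndex] * toFloat;
    const info = classInfos[classIndex];
    for (const featureIndex of data) {
      const featureCount = info.featureCounts[featureIndex] || 0;
      // Smoothed likelihood of this feature given the class. With smoothingFactor = toFloat this is
      // Laplace smoothing: (featureCount + 1) / (totalFeatureCount + totalNumFeatures).
      prob = Math.floor(prob * (toFloat * featureCount + smoothingFactor)
        / (toFloat * info.totalFeatureCount + denominatorSmoothFactor));
    }
    if (prob > maxProb) {
      maxProb = prob;
      bestClass = classIndex;
    }
  }
  return bestClass;
}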
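
Note also the two data formats involved: the constructor and addClass take each class's feature counts in sparse form, as [featureIndex, count] pairs, while update and predict take the plain list of feature indices present in a sample. A hypothetical helper (toSparseCounts is not part of this commit) showing how the first format can be built from the second:

// Hypothetical helper: turn one sample's feature indices into the sparse
// [featureIndex, count] pairs expected by the constructor and addClass.
function toSparseCounts(featureIndices) {
  const counts = new Map();
  for (const featureIndex of featureIndices) {
    counts.set(featureIndex, (counts.get(featureIndex) || 0) + 1);
  }
  return Array.from(counts.entries()); // [[featureIndex, count], ...]
}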
2 changes: 1 addition & 1 deletion demo/client/package.json
@@ -1,6 +1,6 @@
{
  "name": "decai-demo-client",
- "version": "1.0.1",
+ "version": "1.1.0",
  "license": "MIT",
  "private": true,
  "proxy": "http://localhost:5387/",
154 changes: 154 additions & 0 deletions demo/client/test/contracts/classification/naivebayes.js
@@ -0,0 +1,154 @@
const NaiveBayesClassifier = artifacts.require("./classification/NaiveBayesClassifier");

contract('NaiveBayesClassifier', function (accounts) {
  const toFloat = 1E9;

  const smoothingFactor = convertNum(1);
  const classifications = ["ALARM", "WEATHER"];
  const vocab = {};
  let vocabLength = 0;
  let classifier;

  function convertNum(num) {
    return web3.utils.toBN(Math.round(num * toFloat));
  }

  function parseBN(num) {
    if (web3.utils.isBN(num)) {
      return num.toNumber();
    } else {
      assert.typeOf(num, 'number');
      return num;
    }
  }

  function parseFloatBN(bn) {
    assert(web3.utils.isBN(bn), `${bn} is not a BN`);
    // Can't divide first since a BN can only be an integer.
    return bn.toNumber() / toFloat;
  }

  function mapFeatures(query) {
    return query.split(" ").map(w => {
      let result = vocab[w];
      if (result === undefined) {
        vocab[w] = result = vocabLength++;
      }
      return result;
    });
  }

  before("deploy classifier", async () => {
    const queries = [
      "alarm for 11 am tomorrow",
      "will i need a jacket for tomorrow"];
    const featureMappedQueries = queries.map(mapFeatures);
    const featureCounts = featureMappedQueries.map(fv => {
      const result = {};
      fv.forEach(v => {
        if (!(v in result)) {
          result[v] = 0;
        }
        result[v] += 1;
      });
      return Object.entries(result).map(pair => [parseInt(pair[0]), pair[1]].map(web3.utils.toBN));
    });
    const classCounts = [1, 1];
    const totalNumFeatures = vocabLength;
    classifier = await NaiveBayesClassifier.new(classifications, classCounts, featureCounts, totalNumFeatures, smoothingFactor);

    assert.equal(await classifier.getClassTotalFeatureCount(0).then(parseBN), 5,
      "Total feature count for class 0.");
    assert.equal(await classifier.getClassTotalFeatureCount(1).then(parseBN), 7,
      "Total feature count for class 1.");

    for (let featureIndex = 0; featureIndex < 5; ++featureIndex) {
      assert.equal(await classifier.getFeatureCount(0, featureIndex).then(parseBN), 1);
    }
    for (let featureIndex = 5; featureIndex < 11; ++featureIndex) {
      assert.equal(await classifier.getFeatureCount(0, featureIndex).then(parseBN), 0);
    }

    assert.equal(await classifier.getFeatureCount(1, 0).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(1, 1).then(parseBN), 1);
    assert.equal(await classifier.getFeatureCount(1, 2).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(1, 3).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(1, 4).then(parseBN), 1);
    assert.equal(await classifier.getFeatureCount(1, 5).then(parseBN), 1);
    assert.equal(await classifier.getFeatureCount(1, 6).then(parseBN), 1);
    assert.equal(await classifier.getFeatureCount(1, 7).then(parseBN), 1);
    assert.equal(await classifier.getFeatureCount(1, 8).then(parseBN), 1);
    assert.equal(await classifier.getFeatureCount(1, 9).then(parseBN), 1);
    assert.equal(await classifier.getFeatureCount(1, 10).then(parseBN), 0);
  });

  it("...should predict the classification ALARM", async () => {
    const data = mapFeatures("alarm for 9 am tomorrow");
    const prediction = await classifier.predict(data).then(parseBN);
    assert.equal(prediction, 0);
  });

  it("...should predict the classification WEATHER", async () => {
    const data = mapFeatures("will i need a jacket today");
    const prediction = await classifier.predict(data).then(parseBN);
    assert.equal(prediction, 1);
  });

  it("...should update", async () => {
    const newFeature = vocabLength + 10;
    const predictionData = [newFeature];
    assert.equal(await classifier.predict(predictionData).then(parseBN), 0);

    const data = [0, 1, 2, newFeature];
    const classification = 1;
    const prevFeatureCounts = [];
    for (let i in data) {
      const featureIndex = data[i];
      const featureCount = await classifier.getFeatureCount(classification, featureIndex).then(parseBN);
      prevFeatureCounts.push(featureCount);
    }
    const prevTotalFeatureCount = await classifier.getClassTotalFeatureCount(classification).then(parseBN);

    await classifier.update(data, classification);

    for (let i in prevFeatureCounts) {
      const featureIndex = data[i];
      const featureCount = await classifier.getFeatureCount(classification, featureIndex).then(parseBN);
      assert.equal(featureCount, prevFeatureCounts[i] + 1);
    }

    const totalFeatureCount = await classifier.getClassTotalFeatureCount(classification).then(parseBN);
    assert.equal(totalFeatureCount, prevTotalFeatureCount + data.length);

    assert.equal(await classifier.predict(predictionData).then(parseBN), classification);
  });

  it("...should add class", async () => {
    const classCount = 3;
    const featureCounts = [[0, 2], [1, 3], [6, 5]];
    const classification = "NEW";
    const originalNumClassifications = await classifier.getNumClassifications().then(parseBN);
    await classifier.addClass(classCount, featureCounts, classification);
    const newNumClassifications = await classifier.getNumClassifications().then(parseBN);
    assert.equal(newNumClassifications, originalNumClassifications + 1);
    const classIndex = originalNumClassifications;

    assert.equal(await classifier.getClassTotalFeatureCount(classIndex).then(parseBN),
      featureCounts.map(pair => pair[1]).reduce((a, b) => a + b),
      "Total feature count for the new class is wrong.");

    assert.equal(await classifier.getFeatureCount(classIndex, 0).then(parseBN), 2);
    assert.equal(await classifier.getFeatureCount(classIndex, 1).then(parseBN), 3);
    assert.equal(await classifier.getFeatureCount(classIndex, 2).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(classIndex, 3).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(classIndex, 4).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(classIndex, 5).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(classIndex, 6).then(parseBN), 5);
    assert.equal(await classifier.getFeatureCount(classIndex, 7).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(classIndex, 8).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(classIndex, 9).then(parseBN), 0);
    assert.equal(await classifier.getFeatureCount(classIndex, 10).then(parseBN), 0);

    assert.equal(await classifier.predict([0, 1, 6]).then(parseBN), classIndex);
  });
});
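
As a quick hand check of the "...should predict the classification ALARM" test above: "alarm for 9 am tomorrow" maps to feature indices [0, 1, 10, 3, 4] ("9" is unseen), the deployed model has 10 vocabulary features, class feature totals of 5 (ALARM) and 7 (WEATHER), both class counts are 1, and smoothingFactor equals toFloat, i.e. Laplace smoothing with alpha = 1. Ignoring the fixed-point scaling, which (up to integer truncation) cancels out of the comparison, the two class scores work out to:

// Hand-computed scores for "alarm for 9 am tomorrow" (illustration only, not part of this commit).
const alarmScore = 1 * (2 / 15) * (2 / 15) * (1 / 15) * (2 / 15) * (2 / 15);   // = 16 / 15^5 ≈ 2.1e-5
const weatherScore = 1 * (1 / 17) * (2 / 17) * (1 / 17) * (1 / 17) * (2 / 17); // =  4 / 17^5 ≈ 2.8e-6
// alarmScore > weatherScore, so predict returns class 0 (ALARM), as the test asserts.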
2 changes: 1 addition & 1 deletion demo/package.json
@@ -1,6 +1,6 @@
{
  "name": "decai-demo",
- "version": "1.0.1",
+ "version": "1.1.0",
  "license": "MIT",
  "private": true,
  "scripts": {
