From df7a2fc20df3232c7a6aae38ce4733c01368543e Mon Sep 17 00:00:00 2001
From: "Justin D. Harris"
Date: Wed, 31 Jul 2019 17:50:53 -0400
Subject: [PATCH] [demo] Add Naive Bayes Classifier (#24)

demo: Bump to version 1.1.0
---
Review notes (ignored by `git am`): the blob ids on the `index` lines below are
those of the originally submitted files; they were not recomputed for the revised content.

 .../classification/NaiveBayesClassifier.sol   | 209 ++++++++++++++++++
 demo/client/package.json                      |   2 +-
 .../contracts/classification/naivebayes.js    | 158 ++++++++++++++
 demo/package.json                             |   2 +-
 4 files changed, 369 insertions(+), 2 deletions(-)
 create mode 100644 demo/client/contracts/classification/NaiveBayesClassifier.sol
 create mode 100644 demo/client/test/contracts/classification/naivebayes.js

diff --git a/demo/client/contracts/classification/NaiveBayesClassifier.sol b/demo/client/contracts/classification/NaiveBayesClassifier.sol
new file mode 100644
index 00000000..1f76476f
--- /dev/null
+++ b/demo/client/contracts/classification/NaiveBayesClassifier.sol
@@ -0,0 +1,209 @@
+pragma solidity ^0.5.0;
+pragma experimental ABIEncoderV2;
+
+import "../libs/Math.sol";
+import "../libs/SafeMath.sol";
+import "../libs/SignedSafeMath.sol";
+
+import {Classifier64} from "./Classifier.sol";
+
+/**
+ * A Multinomial Naive Bayes classifier.
+ * Works like in https://scikit-learn.org/stable/modules/naive_bayes.html#multinomial-naive-bayes.
+ *
+ * The prediction function is not optimized with typical things like working with log-probabilities because:
+ * * it is mainly an example,
+ * * storing log probabilities would be extra work for the update function which should be very efficient, and
+ * * computing log in solidity is not reliable (yet).
+ */
+contract NaiveBayesClassifier is Classifier64 {
+    using SafeMath for uint256;
+    using SignedSafeMath for int256;
+
+    /** A class has been added. */
+    event AddClass(
+        /** The name of the class. */
+        string name,
+        /** The index for the class in the members of this classifier. */
+        uint index
+    );
+
+    /**
+     * Information for a class.
+     */
+    struct ClassInfo {
+        /**
+         * The total number of occurrences of all features (sum of featureCounts).
+         */
+        uint64 totalFeatureCount;
+        /**
+         * The number of occurrences of a feature.
+         */
+        mapping(uint32 => uint32) featureCounts;
+    }
+
+    /**
+     * The per-class feature statistics, indexed by class index.
+     */
+    ClassInfo[] public classInfos;
+
+    /**
+     * The smoothing factor (sometimes called alpha).
+     * Use toFloat (1 mapped) for Laplace smoothing.
+     */
+    uint32 public smoothingFactor;
+
+    /**
+     * The number of samples in each class.
+     * We use this instead of class prior probabilities.
+     * Stored as plain counts: `update` adds 1 per sample and `predict` applies the toFloat scaling.
+     */
+    uint[] public classCounts;
+
+    /**
+     * The total number of features throughout all classes.
+     */
+    uint totalNumFeatures;
+
+    /**
+     * @param _classifications The names of the classes.
+     * @param _classCounts The number of samples already seen for each class.
+     * @param _featureCounts For each class, a list of [feature index, count] pairs.
+     * @param _totalNumFeatures The total number of features throughout all classes.
+     * @param _smoothingFactor The smoothing factor (alpha), scaled by toFloat.
+     */
+    constructor(
+        string[] memory _classifications,
+        uint[] memory _classCounts,
+        uint32[][][] memory _featureCounts,
+        uint _totalNumFeatures,
+        uint32 _smoothingFactor)
+        Classifier64(_classifications) public {
+        require(_classifications.length > 0, "At least one class is required.");
+        require(_classifications.length < 2 ** 64, "Too many classes given.");
+        // `predict` and `update` index classCounts and classInfos by class index, so all three must line up.
+        require(_classCounts.length == _classifications.length, "Class counts do not match the classes.");
+        require(_featureCounts.length == _classifications.length, "Feature counts do not match the classes.");
+        totalNumFeatures = _totalNumFeatures;
+        smoothingFactor = _smoothingFactor;
+        classCounts = _classCounts;
+        for (uint i = 0; i < _featureCounts.length; ++i){
+            ClassInfo memory info = ClassInfo(0);
+            uint totalFeatureCount = 0;
+            classInfos.push(info);
+            ClassInfo storage storedInfo = classInfos[i];
+            for (uint j = 0 ; j < _featureCounts[i].length; ++j) {
+                storedInfo.featureCounts[_featureCounts[i][j][0]] = _featureCounts[i][j][1];
+                totalFeatureCount = totalFeatureCount.add(_featureCounts[i][j][1]);
+            }
+            require(totalFeatureCount < 2 ** 64, "There are too many features.");
+            classInfos[i].totalFeatureCount = uint64(totalFeatureCount);
+        }
+    }
+
+    // Main overridden methods for training and predicting:
+
+    /**
+     * Add a new class. Restricted by the `onlyOwner` modifier.
+     * @param classCount The number of samples already seen for the new class.
+     * @param featureCounts A list of [feature index, count] pairs for the new class.
+     * @param classification The name of the new class.
+     */
+    function addClass(uint classCount, uint32[][] memory featureCounts, string memory classification) public onlyOwner {
+        require(classifications.length + 1 < 2 ** 64, "There are too many classes already.");
+        classifications.push(classification);
+        uint classIndex = classifications.length - 1;
+        emit AddClass(classification, classIndex);
+        classCounts.push(classCount);
+        ClassInfo memory info = ClassInfo(0);
+        uint totalFeatureCount = 0;
+        classInfos.push(info);
+        ClassInfo storage storedInfo = classInfos[classIndex];
+        for (uint j = 0 ; j < featureCounts.length; ++j) {
+            storedInfo.featureCounts[featureCounts[j][0]] = featureCounts[j][1];
+            totalFeatureCount = totalFeatureCount.add(featureCounts[j][1]);
+        }
+        require(totalFeatureCount < 2 ** 64, "There are too many features.");
+        classInfos[classIndex].totalFeatureCount = uint64(totalFeatureCount);
+    }
+
+    /**
+     * Always reverts: this classifier does not normalize its input.
+     */
+    function norm(int64[] memory /* data */) public pure returns (uint) {
+        revert("Normalization is not required.");
+    }
+
+    /**
+     * @param data The indices of the features that are present in the sample.
+     * @return The index of the class with the highest score; 0 if no class scores above 0.
+     */
+    function predict(int64[] memory data) public view returns (uint64 bestClass) {
+        // Implementation: simple calculation (no log-probabilities optimization, see contract docs for the reasons)
+        bestClass = 0;
+        uint maxProb = 0;
+        uint denominatorSmoothFactor = uint(smoothingFactor).mul(totalNumFeatures);
+        for (uint classIndex = 0; classIndex < classCounts.length; ++classIndex) {
+            uint prob = classCounts[classIndex].mul(toFloat);
+            ClassInfo storage info = classInfos[classIndex];
+            // Widen to uint and use SafeMath so that the scaled counts cannot silently wrap around.
+            uint denominator = uint(toFloat).mul(info.totalFeatureCount).add(denominatorSmoothFactor);
+            for (uint featureIndex = 0; featureIndex < data.length; ++featureIndex) {
+                uint32 featureCount = info.featureCounts[uint32(data[featureIndex])];
+                prob = prob.mul(uint(toFloat).mul(featureCount).add(smoothingFactor)).div(denominator);
+            }
+            if (prob > maxProb) {
+                maxProb = prob;
+                // There are already checks to make sure there are a limited number of classes.
+                bestClass = uint64(classIndex);
+            }
+        }
+    }
+
+    /**
+     * Train on one sample. Restricted by the `onlyOwner` modifier.
+     * @param data The indices of the features that are present in the sample.
+     * @param classification The index of the class for the sample.
+     */
+    function update(int64[] memory data, uint64 classification) public onlyOwner {
+        // Data is binarized (data holds the indices of the features that are present).
+        require(classification < classifications.length, "Classification is out of bounds.");
+        classCounts[classification] = classCounts[classification].add(1);
+
+        ClassInfo storage info = classInfos[classification];
+
+        uint totalFeatureCount = data.length.add(info.totalFeatureCount);
+        require(totalFeatureCount < 2 ** 64, "Feature count will be too high.");
+        info.totalFeatureCount = uint64(totalFeatureCount);
+
+        for (uint dataIndex = 0; dataIndex < data.length; ++dataIndex) {
+            int64 featureIndex = data[dataIndex];
+            // A negative index would otherwise wrap around to an unrelated feature when cast to uint32.
+            require(featureIndex >= 0, "A feature index is negative.");
+            require(featureIndex < 2 ** 32, "A feature index is too high.");
+            uint32 featureCount = info.featureCounts[uint32(featureIndex)];
+            // Solidity 0.5 does not check for overflow, so guard the increment explicitly.
+            require(featureCount < 2 ** 32 - 1, "A feature count is too high.");
+            info.featureCounts[uint32(featureIndex)] = featureCount + 1;
+        }
+    }
+
+    // Useful methods to view the underlying data:
+
+    /**
+     * @param classIndex The index of a class.
+     * @return The total number of occurrences of all features for the class.
+     */
+    function getClassTotalFeatureCount(uint classIndex) public view returns (uint64) {
+        return classInfos[classIndex].totalFeatureCount;
+    }
+
+    /**
+     * @param classIndex The index of a class.
+     * @param featureIndex The index of a feature.
+     * @return The number of occurrences of the feature for the class.
+     */
+    function getFeatureCount(uint classIndex, uint32 featureIndex) public view returns (uint32) {
+        return classInfos[classIndex].featureCounts[featureIndex];
+    }
+}
diff --git a/demo/client/package.json b/demo/client/package.json
index fbee2d5a..7cf40566 100644
--- a/demo/client/package.json
+++ b/demo/client/package.json
@@ -1,6 +1,6 @@
 {
   "name": "decai-demo-client",
-  "version": "1.0.1",
+  "version": "1.1.0",
   "license": "MIT",
   "private": true,
   "proxy": "http://localhost:5387/",
diff --git a/demo/client/test/contracts/classification/naivebayes.js b/demo/client/test/contracts/classification/naivebayes.js
new file mode 100644
index 00000000..a089f291
--- /dev/null
+++ b/demo/client/test/contracts/classification/naivebayes.js
@@ -0,0 +1,158 @@
+const NaiveBayesClassifier = artifacts.require("./classification/NaiveBayesClassifier");
+
+contract('NaiveBayesClassifier', function (accounts) {
+  const toFloat = 1E9;
+
+  const smoothingFactor = convertNum(1);
+  const classifications = ["ALARM", "WEATHER"];
+  const vocab = {};
+  let vocabLength = 0;
+  let classifier;
+
+  /** Scale a number by toFloat and wrap it in a BN so it can be passed to the contract. */
+  function convertNum(num) {
+    return web3.utils.toBN(Math.round(num * toFloat));
+  }
+
+  /** Convert a BN to a plain number; plain numbers pass through unchanged. */
+  function parseBN(num) {
+    if (web3.utils.isBN(num)) {
+      return num.toNumber();
+    } else {
+      assert.typeOf(num, 'number');
+      return num;
+    }
+  }
+
+  /** Undo the toFloat scaling of a BN returned by the contract. */
+  function parseFloatBN(bn) {
+    assert(web3.utils.isBN(bn), `${bn} is not a BN`);
+    // Can't divide first since a BN can only be an integer.
+    return bn.toNumber() / toFloat;
+  }
+
+  /** Map each word of a query to its feature index, growing the vocabulary for unseen words. */
+  function mapFeatures(query) {
+    return query.split(" ").map(w => {
+      let result = vocab[w];
+      if (result === undefined) {
+        vocab[w] = result = vocabLength++;
+      }
+      return result;
+    });
+  }
+
+  before("deploy classifier", async () => {
+    const queries = [
+      "alarm for 11 am tomorrow",
+      "will i need a jacket for tomorrow"];
+    const featureMappedQueries = queries.map(mapFeatures);
+    const featureCounts = featureMappedQueries.map(fv => {
+      const result = {};
+      fv.forEach(v => {
+        if (!(v in result)) {
+          result[v] = 0;
+        }
+        result[v] += 1;
+      });
+      // Object keys are strings, so parse them back into feature indices (always with a radix).
+      return Object.entries(result).map(pair => [Number.parseInt(pair[0], 10), pair[1]].map(n => web3.utils.toBN(n)));
+    });
+    const classCounts = [1, 1];
+    const totalNumFeatures = vocabLength;
+    classifier = await NaiveBayesClassifier.new(classifications, classCounts, featureCounts, totalNumFeatures, smoothingFactor);
+
+    assert.equal(await classifier.getClassTotalFeatureCount(0).then(parseBN), 5,
+      "Total feature count for class 0.");
+    assert.equal(await classifier.getClassTotalFeatureCount(1).then(parseBN), 7,
+      "Total feature count for class 1.");
+
+    for (let featureIndex = 0; featureIndex < 5; ++featureIndex) {
+      assert.equal(await classifier.getFeatureCount(0, featureIndex).then(parseBN), 1);
+    }
+    for (let featureIndex = 5; featureIndex < 11; ++featureIndex) {
+      assert.equal(await classifier.getFeatureCount(0, featureIndex).then(parseBN), 0);
+    }
+
+    assert.equal(await classifier.getFeatureCount(1, 0).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(1, 1).then(parseBN), 1);
+    assert.equal(await classifier.getFeatureCount(1, 2).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(1, 3).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(1, 4).then(parseBN), 1);
+    assert.equal(await classifier.getFeatureCount(1, 5).then(parseBN), 1);
+    assert.equal(await classifier.getFeatureCount(1, 6).then(parseBN), 1);
+    assert.equal(await classifier.getFeatureCount(1, 7).then(parseBN), 1);
+    assert.equal(await classifier.getFeatureCount(1, 8).then(parseBN), 1);
+    assert.equal(await classifier.getFeatureCount(1, 9).then(parseBN), 1);
+    assert.equal(await classifier.getFeatureCount(1, 10).then(parseBN), 0);
+  });
+
+  it("...should predict the classification ALARM", async () => {
+    const data = mapFeatures("alarm for 9 am tomorrow");
+    const prediction = await classifier.predict(data).then(parseBN);
+    assert.equal(prediction, 0);
+  });
+
+  it("...should predict the classification WEATHER", async () => {
+    const data = mapFeatures("will i need a jacket today");
+    const prediction = await classifier.predict(data).then(parseBN);
+    assert.equal(prediction, 1);
+  });
+
+  it("...should update", async () => {
+    const newFeature = vocabLength + 10;
+    const predictionData = [newFeature];
+    assert.equal(await classifier.predict(predictionData).then(parseBN), 0);
+
+    const data = [0, 1, 2, newFeature];
+    const classification = 1;
+    const prevFeatureCounts = [];
+    for (const featureIndex of data) {
+      const featureCount = await classifier.getFeatureCount(classification, featureIndex).then(parseBN);
+      prevFeatureCounts.push(featureCount);
+    }
+    const prevTotalFeatureCount = await classifier.getClassTotalFeatureCount(classification).then(parseBN);
+
+    await classifier.update(data, classification);
+
+    for (let i = 0; i < data.length; ++i) {
+      const featureCount = await classifier.getFeatureCount(classification, data[i]).then(parseBN);
+      assert.equal(featureCount, prevFeatureCounts[i] + 1);
+    }
+
+    const totalFeatureCount = await classifier.getClassTotalFeatureCount(classification).then(parseBN);
+    assert.equal(totalFeatureCount, prevTotalFeatureCount + data.length);
+
+    assert.equal(await classifier.predict(predictionData).then(parseBN), classification);
+  });
+
+  it("...should add class", async () => {
+    const classCount = 3;
+    const featureCounts = [[0, 2], [1, 3], [6, 5]];
+    const classification = "NEW";
+    const originalNumClassifications = await classifier.getNumClassifications().then(parseBN);
+    // The transaction must be mined before the new state is read back.
+    await classifier.addClass(classCount, featureCounts, classification);
+    const newNumClassifications = await classifier.getNumClassifications().then(parseBN);
+    assert.equal(newNumClassifications, originalNumClassifications + 1);
+    const classIndex = originalNumClassifications;
+
+    assert.equal(await classifier.getClassTotalFeatureCount(classIndex).then(parseBN),
+      featureCounts.map(pair => pair[1]).reduce((a, b) => a + b, 0),
+      "Total feature count for the new class is wrong.");
+
+    assert.equal(await classifier.getFeatureCount(classIndex, 0).then(parseBN), 2);
+    assert.equal(await classifier.getFeatureCount(classIndex, 1).then(parseBN), 3);
+    assert.equal(await classifier.getFeatureCount(classIndex, 2).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(classIndex, 3).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(classIndex, 4).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(classIndex, 5).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(classIndex, 6).then(parseBN), 5);
+    assert.equal(await classifier.getFeatureCount(classIndex, 7).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(classIndex, 8).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(classIndex, 9).then(parseBN), 0);
+    assert.equal(await classifier.getFeatureCount(classIndex, 10).then(parseBN), 0);
+
+    assert.equal(await classifier.predict([0, 1, 6]).then(parseBN), classIndex);
+  });
+});
diff --git a/demo/package.json b/demo/package.json
index cb7ff985..46d7826e 100644
--- a/demo/package.json
+++ b/demo/package.json
@@ -1,6 +1,6 @@
 {
   "name": "decai-demo",
-  "version": "1.0.1",
+  "version": "1.1.0",
   "license": "MIT",
   "private": true,
   "scripts": {