forked from pratikpc/tf-kmeans
-
Notifications
You must be signed in to change notification settings - Fork 0
/
index.js
61 lines (55 loc) · 1.79 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
const KMeans = require("tf-kmeans");
const tf = require("@tensorflow/tfjs");
/**
 * Synchronous tf-kmeans demo.
 *
 * Trains a 2-cluster model on five 3-D points, then logs the cluster
 * assignments, the centroids, the prediction for one extra point and the
 * tf.memory() report. All of the work happens inside a tf.tidy callback.
 */
function SyncTest() {
  tf.tidy(() => {
    const options = {
      k: 2,
      maxIter: 30,
      distanceFunction: KMeans.default.EuclideanDistance
    };
    const model = new KMeans.default(options);
    const points = tf.tensor([[2, 2, 2], [5, 5, 5], [3, 3, 3], [4, 4, 4], [7, 8, 7]]);

    const assignments = model.Train(points);
    const assignedClusters = assignments.arraySync();
    console.log("Assigned To ", assignedClusters);

    const centroidValues = model.Centroids().arraySync();
    console.log("Centroids Used are ", centroidValues);

    console.log("Prediction for Given Value is");
    const sample = tf.tensor([2, 3, 2]);
    model.Predict(sample).print();

    console.log("Amount of Memory Used is ", tf.memory());

    // Explicit cleanup: this is what is required when the model is used
    // outside of a tf.tidy callback; it is kept here as a usage example.
    model.Dispose();
    assignments.dispose();
    points.dispose();
  });
}
/**
 * Asynchronous tf-kmeans demo.
 *
 * Trains a 3-cluster model on five 3-D points with TrainAsync, logging the
 * iteration number, centroids and assignments from the per-iteration
 * callback. Afterwards it logs the final assignments, the centroids, the
 * prediction for one extra point and the tf.memory() report.
 *
 * This function does not run inside tf.tidy, so every tensor it creates
 * (including the sample passed to Predict and the tensor Predict returns)
 * is released explicitly in a `finally` block, even if a step throws.
 *
 * @returns {Promise<void>} Settles once the demo has run and cleanup is done.
 */
async function AsyncTest() {
  const kmeans = new KMeans.default({
    k: 3,
    maxIter: 30,
    distanceFunction: KMeans.default.EuclideanDistance
  });
  const dataset = tf.tensor([[2, 2, 2], [5, 5, 5], [3, 3, 3], [4, 4, 4], [7, 8, 7]]);

  // Declared up front so the finally block can release whatever was created.
  let predictions;
  let sample;
  let samplePrediction;

  try {
    console.log("\n\nAsync Test");
    predictions = await kmeans.TrainAsync(
      dataset,
      // Called At End of Every Iteration
      async (iter, centroid, preds) => {
        console.log("===");
        console.log("Iteration Count", iter);
        console.log("Centroid ", await centroid.array());
        console.log("Prediction ", await preds.array());
        console.log("===");
        // You could instead use TFVIS for Plotting Here
      }
    );
    console.log("Assigned To ", await predictions.array());
    console.log("Centroids Used are ", await kmeans.Centroids().array());
    console.log("Prediction for Given Value is");

    // Keep references to both tensors; without tf.tidy they would otherwise leak.
    sample = tf.tensor([2, 3, 2]);
    samplePrediction = kmeans.Predict(sample);
    samplePrediction.print();

    console.log("Amount of Memory Used is ", tf.memory());
  } finally {
    // tf.dispose skips null/undefined entries, so this is safe on partial failure.
    tf.dispose([dataset, predictions, sample, samplePrediction]);
    // NOTE(review): Dispose() on a model that never finished training is not
    // visible from this file, so it is only called once training has returned.
    if (predictions !== undefined) {
      kmeans.Dispose();
    }
  }
}
// Entry point: run the synchronous demo, then start the asynchronous one.
SyncTest();
AsyncTest()
  .then(() => console.log("Hi"))
  .catch((err) => {
    // Report failures explicitly instead of leaving the rejection unhandled.
    console.error("Async test failed:", err);
    if (typeof process !== "undefined") {
      process.exitCode = 1;
    }
  });