-
Notifications
You must be signed in to change notification settings - Fork 0
/
script.js
100 lines (83 loc) · 2.93 KB
/
script.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
 * Page entry point: turns the `.canvas` element into a 200×200 drawing pad
 * (white 20px strokes on black) and wires `.clearButton` / `.predictButton`
 * to clear it or classify the drawing with the TensorFlow.js layers model at
 * "models/model.json", writing the winning index into `.result`.
 *
 * Assumes the global `tf` (TensorFlow.js) is loaded before this runs and that
 * all four elements exist — TODO confirm against the HTML page.
 */
window.onload = function () {
  const canvas = document.querySelector(".canvas");
  const clearButton = document.querySelector(".clearButton");
  const predictButton = document.querySelector(".predictButton");
  const result = document.querySelector(".result");

  canvas.height = 200;
  canvas.width = 200;
  canvas.style.backgroundColor = "black";
  const c = canvas.getContext("2d");

  /**
   * Discards any in-progress path and paints an opaque black background
   * into the bitmap. The CSS background is not part of the pixel data that
   * tf.browser.fromPixels() reads, and fromPixels drops the alpha channel,
   * so on a transparent canvas the anti-aliased stroke edges would read as
   * full white instead of the grey the user actually sees.
   */
  function resetCanvas() {
    c.beginPath();
    c.fillStyle = "black";
    c.fillRect(0, 0, canvas.width, canvas.height);
  }
  resetCanvas();

  /**
   * Converts a mouse or touch event into canvas-bitmap coordinates.
   * The bounding rect is re-read on every call so scrolling or layout
   * changes after load do not offset the stroke, and the ratio corrects
   * for the canvas being displayed at a CSS size other than its bitmap size.
   * @param {MouseEvent|TouchEvent} e
   * @returns {[number, number]} x, y in bitmap pixels
   */
  function getXY(e) {
    const point = e.touches && e.touches.length > 0 ? e.touches[0] : e;
    const rect = canvas.getBoundingClientRect();
    // Guard against a zero-size (e.g. hidden) canvas.
    const scaleX = rect.width > 0 ? canvas.width / rect.width : 1;
    const scaleY = rect.height > 0 ? canvas.height / rect.height : 1;
    return [
      (point.clientX - rect.left) * scaleX,
      (point.clientY - rect.top) * scaleY,
    ];
  }

  /** Extends the current stroke to the pointer position. */
  function draw(e) {
    e.preventDefault(); // keep touch drags from scrolling the page
    const [x, y] = getXY(e);
    c.strokeStyle = "white";
    c.lineWidth = 20;
    c.lineTo(x, y);
    c.stroke();
  }

  /** Starts a new stroke at the pointer and begins tracking movement. */
  function startCaptureCord(e) {
    const [x, y] = getXY(e);
    c.beginPath();
    c.moveTo(x, y);
    canvas.addEventListener("touchmove", draw);
    canvas.addEventListener("mousemove", draw);
  }

  /** Stops tracking movement; the stroke drawn so far stays on the canvas. */
  function stopCaptureCord() {
    canvas.removeEventListener("touchmove", draw);
    canvas.removeEventListener("mousemove", draw);
  }

  canvas.addEventListener("touchstart", startCaptureCord);
  canvas.addEventListener("mousedown", startCaptureCord);
  // Listen on document so releasing outside the canvas still ends the stroke.
  document.addEventListener("touchend", stopCaptureCord);
  document.addEventListener("touchcancel", stopCaptureCord);
  document.addEventListener("mouseup", stopCaptureCord);

  /** Erases the drawing. */
  function clear() {
    resetCanvas();
  }
  clearButton.addEventListener("click", clear);

  /**
   * Downsamples the drawing to 28×28, averages the RGB channels to one
   * grey channel, and flattens it.
   * Values stay in the 0–255 range; they are not normalised.
   * Intermediate tensors are freed by tf.tidy.
   * @returns {tf.Tensor} float tensor of shape [1, 784]; caller must dispose
   */
  function preprocessCanvas(image) {
    return tf.tidy(() =>
      tf.browser
        .fromPixels(image)
        .resizeNearestNeighbor([28, 28])
        .mean(2)
        .toFloat()
        .reshape([1, 784])
    );
  }

  // The model is loaded once and reused for every prediction.
  let modelPromise = null;

  /**
   * Returns the (cached) model promise.
   * A failed load is not cached, so a later click can retry.
   */
  function loadModel() {
    if (modelPromise === null) {
      modelPromise = tf.loadLayersModel("models/model.json").then(
        (model) => {
          console.log("model loaded");
          return model;
        },
        (err) => {
          modelPromise = null;
          throw err;
        }
      );
    }
    return modelPromise;
  }

  /**
   * Runs the model on `data` and shows the index of the highest score.
   * `data` is owned by the caller; the model output is disposed here.
   */
  async function predictWithData(data) {
    const model = await loadModel();
    const output = model.predict(data);
    try {
      const scores = output.dataSync();
      const res = scores.indexOf(Math.max(...scores));
      console.log(res);
      result.textContent = res;
    } finally {
      output.dispose();
    }
  }

  /** Click handler: preprocess the drawing and classify it. */
  async function predict() {
    // Row-major reshape of the flat [1, 784] vector into one 28×28 image —
    // the same values the old JS-array round trip produced.
    const predictData = tf.tidy(() =>
      preprocessCanvas(canvas).reshape([1, 28, 28])
    );
    try {
      await predictWithData(predictData);
    } catch (err) {
      console.error("Prediction failed:", err);
    } finally {
      predictData.dispose();
    }
  }
  predictButton.addEventListener("click", predict);
};