forked from Wei-1/Scala-Machine-Learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
OneHot.scala
61 lines (58 loc) · 2.14 KB
/
OneHot.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Wei Chen - OneHot
// 2018-11-07
package com.scalaml.algorithm
/** One-hot encoder / decoder for tabular string data.
  *
  * Values of columns typed "category" in `table` are learned as categories. Every column that
  * has an entry in `categoryTable` is expanded to one slot per category; every other column is
  * parsed with `toDouble` into a single slot.
  *
  * @param table         column schema in column order, as (name, type) pairs
  * @param categoryTable column name -> categories in slot order; `encode` appends each
  *                      previously unseen value of a "category" column
  * @param decodeTable   one (name, category) entry per encoded slot, with a null category for
  *                      a numeric slot; rebuilt by `encode` once its input has been validated
  */
class OneHot(
  // name, type
  var table: Array[(String, String)],
  // name, cname list
  var categoryTable: Map[String, Array[String]] = Map[String, Array[String]](),
  // name, category (null for number)
  var decodeTable: Array[(String, String)] = Array[(String, String)]()
) {

  /** Encodes string rows into numeric rows, learning new categories on the way.
    *
    * @param data rows of raw values; every row must have exactly `table.size` fields
    * @return the encoded rows, or null when `data` is empty, when any row has the wrong
    *         width, or when an exception occurs (for example a non-numeric string in a
    *         numeric column); the exception is printed to `Console.err`
    */
  def encode(data: Array[Array[String]]): Array[Array[Double]] = try {
    // Validate every row, not just the first: zip would silently truncate a ragged row.
    if (data.nonEmpty && data.forall(_.length == table.length)) {
      // Pre-process: append unseen values of "category" columns in order of first appearance.
      categoryTable = data.foldLeft(categoryTable) { (seen, row) =>
        row.zip(table).foldLeft(seen) {
          case (cats, (str, (name, "category"))) =>
            val known = cats.getOrElse(name, Array.empty[String])
            if (known.contains(str)) cats else cats + (name -> (known :+ str))
          case (cats, _) => cats
        }
      }
      // One decode entry per output slot, in the same order the slots are emitted below.
      decodeTable = table.flatMap { case (name, _) =>
        categoryTable.get(name) match {
          case Some(cats) => cats.map(cname => (name, cname))
          case None => Array((name, null: String))
        }
      }
      // One-Hot
      data.map { row =>
        row.zip(table).flatMap { case (str, (name, _)) =>
          categoryTable.get(name) match {
            case Some(cats) =>
              val slots = new Array[Double](cats.length)
              slots(cats.indexOf(str)) = 1.0
              slots
            case None => Array(str.toDouble)
          }
        }
      }
    } else null
  } catch { case e: Exception =>
    Console.err.println(e)
    null
  }

  /** Decodes numeric rows back into one string per original field.
    *
    * A numeric slot is rendered with `toString`. For a categorical field the category of the
    * highest-valued slot is returned (the first one on ties), so every decoded row has exactly
    * one entry per field even when no slot reaches 0.5 or several do. Rows holding an exact
    * one-hot vector decode exactly as before.
    *
    * @param data encoded rows; every row must have exactly `decodeTable.size` slots
    * @return the decoded rows, or null when `data` is empty, when any row has the wrong
    *         width, or when an exception occurs; the exception is printed to `Console.err`
    */
  def decode(data: Array[Array[Double]]): Array[Array[String]] = try {
    if (data.nonEmpty && data.forall(_.length == decodeTable.length)) {
      val spans = fieldSpans
      data.map { row =>
        spans.map { case (from, until) =>
          if (decodeTable(from)._2 == null) row(from).toString
          else {
            // Arg-max over the field's slots; strict '>' keeps the earliest slot on ties.
            val best = (from until until).reduceLeft((b, j) => if (row(j) > row(b)) j else b)
            decodeTable(best)._2
          }
        }
      }
    } else null
  } catch { case e: Exception =>
    Console.err.println(e)
    null
  }

  /** Splits `decodeTable` into [from, until) slot ranges, one per original field: a numeric
    * slot stands alone, consecutive categorical slots sharing a name form one field.
    */
  private def fieldSpans: Array[(Int, Int)] = {
    val spans = scala.collection.mutable.ArrayBuffer.empty[(Int, Int)]
    var start = 0
    while (start < decodeTable.length) {
      val (name, cname) = decodeTable(start)
      var end = start + 1
      if (cname != null) {
        while (end < decodeTable.length &&
               decodeTable(end)._2 != null &&
               decodeTable(end)._1 == name) end += 1
      }
      spans += ((start, end))
      start = end
    }
    spans.toArray
  }
}