
Commit

final
tgthuan committed Jan 1, 2022
1 parent 75f1305 commit 09c372b
Showing 3 changed files with 34 additions and 38 deletions.
59 changes: 34 additions & 25 deletions main.jl
@@ -1,11 +1,19 @@
using ScikitLearn
using ScikitLearn.CrossValidation: train_test_split
using PyCall
np = pyimport("numpy")
using Random

function train_test_split(data, y_column, train_size)
# fix the random seed (30) so the split is reproducible
Random.seed!(30);
n = size(data,1)
# shuffle the row indices
idx = shuffle(Vector(1:n))
# split the shuffled indices into train and test parts
train_idx = view(idx, 1:floor(Int, train_size*n))
test_idx = view(idx, (floor(Int, train_size*n)+1):n)
data[train_idx, :], y_column[train_idx, :], data[test_idx, :], y_column[test_idx, :]
end
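
A usage sketch with made-up data (X_demo, y_demo and the sizes are illustrative, not from the commit). Note that this hand-rolled split returns (X_train, y_train, X_test, y_test), unlike ScikitLearn's train_test_split, which returns (X_train, X_test, y_train, y_test):

X_demo = rand(150, 4)                                      # 150 samples, 4 features
y_demo = rand(["setosa", "versicolor", "virginica"], 150)  # 150 labels
X_tr, y_tr, X_te, y_te = train_test_split(X_demo, y_demo, 2/3)
size(X_tr, 1)   # 100, i.e. roughly 2/3 of the rows go to training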


function read_file(filename)
# X = Array{Float64}(undef, 0,2)
X=[]
y = []
fp = open(filename,"r")
@@ -17,20 +25,19 @@ function read_file(filename)
popfirst!(x_line)
y_line = pop!(x_line)
x_line = [parse(Float64,ss) for ss in x_line]
# x_line = np.asarray(x_line).astype("float64")

# collect the feature row into X
append!(X, [x_line])
# X = [X;x_line]

# collect the label into y
push!(y, y_line)
end
# X= np.asarray(X)
# y = np.asarray(y)

X = hcat(X...)'
close(fp)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3, random_state=42);
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3, random_state=42);
X_train, y_train, X_test, y_test = train_test_split(X, y,2/3)
return X_train, y_train, X_test, y_test, attr
# return X, y, attr
end
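
Since X is accumulated as a plain vector of per-sample Float64 vectors, the hcat(X...)' step above is what turns it into a samples-by-features matrix; a minimal illustration with two toy rows:

rows = [[5.1, 3.5, 1.4, 0.2], [4.9, 3.0, 1.4, 0.2]]
hcat(rows...)'   # 2×4 matrix, one sample per row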

mutable struct Node
@@ -65,9 +72,6 @@ mutable struct DecisionTree
tree = new(min_value_gain,depth,attr)
return tree
end



end

function fit(X_train, y_train, tree, attr)
@@ -86,10 +90,11 @@ function fit(X_train, y_train, tree, attr)
#init tree
tree.root = root
list_root_check = [root]
# grow the tree: keep splitting queued nodes until the stopping rules apply
while length(list_root_check) != 0
active_node = pop!(list_root_check)
if active_node.entropy < tree.min_value_gain || active_node.depth < tree.depth
active_node.child = make_split(active_node, X_train, y_train, attr)
active_node.child = make_split(active_node, X_train, y_train)
if length(active_node.child) == 0
y_dummy = [y_train[i] for i in active_node.pos]
vals_unique = unique(y_dummy)
@@ -112,7 +117,7 @@ function fit(X_train, y_train, tree, attr)
return tree
end
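
The loop above compares node entropies that are computed in code outside this diff; the following is only a sketch of a standard Shannon-entropy helper consistent with how those values are used here (subset_entropy is a hypothetical name, not a function from the commit):

# hypothetical helper: Shannon entropy of the labels at the given positions
function subset_entropy(y, pos)
    labels = [y[i] for i in pos]
    H = 0.0
    for c in unique(labels)
        p = count(==(c), labels) / length(labels)
        H -= p * log2(p)
    end
    return H
end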

function make_split(node::Node, x_train, y_train, attr)
function make_split(node::Node, x_train, y_train)
choose_gain = 0
choose_split = []
choose_attr = ""
@@ -122,12 +127,13 @@ function make_split(node::Node, x_train, y_train, attr)
x_train_T = copy(x_train)'
pos = node.pos
for col in 1:length(x_train_T[:,1])
H_min = 9999
H_min = 10
left_set_choose = []
right_set_choose = []
entropy_list_choose = []
value_choose = 0

# the distinct values of this column are the candidate cutoffs
unique_value = Set(sort(x_train_T[col,:]))
if length(unique_value) == 1
continue
@@ -162,6 +168,7 @@ function make_split(node::Node, x_train, y_train, attr)
print("x")
end
end
# keep the cutoff that gives the lowest split entropy for this column
if H < H_min
H_min = H
left_set_choose = left_set
@@ -173,7 +180,10 @@ function make_split(node::Node, x_train, y_train, attr)
if minimum([length(left_set_choose), length(right_set_choose)]) < 2
continue
end

# information gain = parent entropy minus the best (lowest) split entropy
gain_information = node.entropy - H_min
# keep the column with the highest information gain
if gain_information > choose_gain
choose_gain = gain_information
choose_attr = col
@@ -185,6 +195,7 @@ function make_split(node::Node, x_train, y_train, attr)
node.name = string(choose_attr)
node.value_split = best_value
k = 1
# create the two child nodes for the chosen split
for split in choose_split
if k == 1
new_node = Node("< "*string(best_value),entropy_best[k],split,node.depth+1)
@@ -195,7 +206,6 @@ function make_split(node::Node, x_train, y_train, attr)
end
k += 1
end
println("1")
return child_nodes
end
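
For reference, the usual score for a candidate cutoff (and plausibly the H computed in the collapsed lines above) is the size-weighted entropy of the two resulting subsets, and the gain used to pick a column is the parent entropy minus the best such score; a one-line sketch (weighted_entropy is an illustrative name, not from the commit):

# gain = node.entropy - weighted_entropy(...), as in the code above
weighted_entropy(n_left, H_left, n_right, H_right) = (n_left * H_left + n_right * H_right) / (n_left + n_right)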

@@ -227,18 +237,17 @@ function accuracy(y_predict, y_true)
end
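
The body of accuracy is collapsed in this view; it presumably returns the fraction of matching labels. A hedged reconstruction under an illustrative name:

# hypothetical reconstruction, not the committed body
accuracy_guess(y_predict, y_true) = count(y_predict .== y_true) / length(y_true)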

function main()
X_train, y_train, X_test, y_test, attr = read_file("Iris.csv")
# X_train, y_train, attr = read_file("Iris.csv")
X_train, y_train, X_test, y_test, attr = read_file(joinpath(@__DIR__,"Iris.csv"))
tree = DecisionTree(0, 10, attr)
fit(X_train, y_train, tree, attr)
y_hat_train = predict(tree,X_train)
y_hat_train = predict(tree, X_train)
acc_train = accuracy(y_hat_train, y_train)

y_hat_test = predict(tree,X_test)
y_hat_test = predict(tree, X_test)
acc_test = accuracy(y_hat_test, y_test)

println("Accuracy of train dataset: ",acc_train)
println("Accuracy of test dataset: ",acc_test)
println("Accuracy of train dataset: ",acc_train,"\n")
println("Accuracy of test dataset: ",acc_test,"\n")
end

main()
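
Assuming Iris.csv sits next to main.jl (which is what joinpath(@__DIR__, "Iris.csv") expects), the script can be run from a Julia session in that directory:

include("main.jl")   # trains on 2/3 of Iris and prints train and test accuracy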
Binary file added report.pdf
13 changes: 0 additions & 13 deletions test.jl

This file was deleted.
