Skip to content

Commit

Permalink
Apply comment suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
ksew1 committed Dec 31, 2024
1 parent 05f5d26 commit e38812e
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions lib/scholar/naive_bayes/categorical.ex
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,8 @@ defmodule Scholar.NaiveBayes.Categorical do
},
x
) do
{_, _, _, jll} =
while {i = 0, feature_log_probability, x, jll = Nx.broadcast(0.0, Nx.shape(x))},
{_, jll} =
while {{i = 0, feature_log_probability, x}, jll = Nx.broadcast(0.0, Nx.shape(x))},
i < Nx.axis_size(x, 1) do
indices = Nx.slice_along_axis(x, i, 1, axis: 1) |> Nx.squeeze(axes: [1])

Expand All @@ -488,33 +488,32 @@ defmodule Scholar.NaiveBayes.Categorical do
|> Nx.transpose()
|> Nx.add(jll)

{i + 1, feature_log_probability, x, jll}
{{i + 1, feature_log_probability, x}, jll}
end

total_jll = jll + class_log_priors
total_jll
jll + class_log_priors
end

defnp count(x, y_weighted, num_features, num_classes, num_categories, num_samples) do
class_count = Nx.sum(y_weighted, axes: [0])

feature_count = Nx.broadcast(0.0, {num_features, num_classes, num_categories})

{_, _, _, _, feature_count} =
while {i = 0, x, y_weighted, num_features, feature_count}, i < num_samples do
{_, _, _, _, _, feature_count} =
while {j = 0, x, y_weighted, i, num_features, feature_count}, j < num_features do
{_, feature_count} =
while {{i = 0, x, y_weighted, num_features}, feature_count}, i < num_samples do
{_, feature_count} =
while {{j = 0, x, y_weighted, i, num_features}, feature_count}, j < num_features do
category_value = x[i][j]
class_label = Nx.argmax(y_weighted[i])

index = Nx.stack([j, class_label, category_value])

feature_count = Nx.indexed_add(feature_count, index, y_weighted[i][class_label])

{j + 1, x, y_weighted, i, num_features, feature_count}
{{j + 1, x, y_weighted, i, num_features}, feature_count}
end

{i + 1, x, y_weighted, num_features, feature_count}
{{i + 1, x, y_weighted, num_features}, feature_count}
end

{class_count, feature_count}
Expand All @@ -523,8 +522,8 @@ defmodule Scholar.NaiveBayes.Categorical do
defnp compute_feature_log_probability(feature_count, alpha, min_categories) do
feature_log_probability = Nx.broadcast(0.0, Nx.shape(feature_count))

{_, _, _, _, feature_log_probability} =
while {i = 0, feature_count, alpha, min_categories, feature_log_probability},
{_, feature_log_probability} =
while {{i = 0, feature_count, alpha, min_categories}, feature_log_probability},
i < Nx.axis_size(feature_count, 0) do
smoothed_class_count =
Nx.sum(feature_count[i], axes: [1])
Expand All @@ -548,7 +547,7 @@ defmodule Scholar.NaiveBayes.Categorical do
feature_log_probability =
Nx.put_slice(feature_log_probability, [i, 0, 0], smoothed_cat_count)

{i + 1, feature_count, alpha, min_categories, feature_log_probability}
{{i + 1, feature_count, alpha, min_categories}, feature_log_probability}
end

feature_log_probability
Expand Down

0 comments on commit e38812e

Please sign in to comment.