Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Nov 14, 2023
1 parent ffba46a commit 980eb63
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions examples/linear_regression.exs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
defmodule LinearRegression do
import Nx.Defn

# y = mx + b
defn init_random_params do
{m, new_key} =
Nx.Random.key(42)
Expand All @@ -14,18 +13,26 @@ defmodule LinearRegression do
{m, b}
end

# Linear model prediction: input · m + b.
defn predict({m, b}, input) do
  input
  |> Nx.dot(m)
  |> Nx.add(b)
end

# Mean squared error between the targets and the model's predictions.
#
# The error term is bound to a local first so the pipeline starts from a
# bare value; piping directly out of `target - predict(...)` relies on
# `-` binding tighter than `|>`, which is easy to misread.
defn loss({m, b}, input, target) do
  errors = target - predict({m, b}, input)

  errors
  |> Nx.pow(2)
  |> Nx.mean()
end

# One gradient-descent step: moves each parameter against its gradient
# of the loss, scaled by `step` (the learning rate).
defn update({m, b} = params, input, target, step) do
  # A plain call reads better than a single-step pipe.
  {grad_m, grad_b} = grad(params, &loss(&1, input, target))

  {
    m - grad_m * step,
    b - grad_b * step
  }
end

def train(params, epochs, lin_fn) do
Expand All @@ -40,9 +47,9 @@ defmodule LinearRegression do
|> Enum.reduce(
acc,
fn batch, cur_params ->
{inp, tar} = Enum.unzip(batch)
x = Nx.reshape(Nx.tensor(inp), {32, 1})
y = Nx.reshape(Nx.tensor(tar), {32, 1})
{input, target} = Enum.unzip(batch)
x = Nx.reshape(Nx.tensor(input), {32, 1})
y = Nx.reshape(Nx.tensor(target), {32, 1})
update(cur_params, x, y, 0.001)
end
)
Expand Down

0 comments on commit 980eb63

Please sign in to comment.