Add mean along axis for Variables and tests (#341)
andreaferretti authored and mratsim committed Dec 22, 2018
1 parent a4d3fc1 commit 8a2e247
Showing 2 changed files with 77 additions and 0 deletions.
src/autograd/gates_reduce.nim (42 additions, 0 deletions)
@@ -20,6 +20,13 @@ import ../private/ast_utils,
type MeanGate* {.final.} [TT] = ref object of Gate[TT]
  ## TODO: generalize to C <- alpha AB + C
  cached_input_shape: MetadataArray
  axis: int

proc shape_product(m: MeanGate): int {.inline.} =
  result = 1
  for i, v in m.cached_input_shape:
    if i != m.axis:
      result *= v

proc mean_backward_ag[TT](self: MeanGate[TT], payload: Payload[TT]): SmallDiffs[TT] =
  let gradient = payload.variable.grad
@@ -30,6 +37,11 @@ proc mean_backward_ag[TT](self: MeanGate[TT], payload: Payload[TT]): SmallDiffs[
  let z_shape = newSeqWith(self.cached_input_shape.len, 1)
  result[0] = result[0].reshape(z_shape).broadcast(self.cached_input_shape)

proc mean_with_axis_backward_ag[TT](self: MeanGate[TT], payload: Payload[TT]): SmallDiffs[TT] =
  let gradient = payload.variable.grad
  result = newDiffs[TT](1)
  result[0] = (gradient / getSubType(TT)(self.shape_product)).broadcast(self.cached_input_shape)

proc mean_cache[TT](result: Variable[TT], a: Variable[TT]) =
  # Gate
  var gate: MeanGate[TT]
@@ -49,6 +61,26 @@ proc mean_cache[TT](result: Variable[TT], a: Variable[TT]) =
    a
  )

proc mean_cache[TT](result: Variable[TT], a: Variable[TT], axis: Natural) =
  # Gate
  var gate: MeanGate[TT]
  new gate
  gate.cached_input_shape = a.value.shape
  gate.axis = axis

  # Result setup
  result.grad = zeros_like(result.value)
  result.requires_grad = true

  # Caching for backprop
  register_node(
    "Mean",
    gate,
    mean_with_axis_backward_ag[TT],
    result,
    a
  )

proc mean*[TT](a: Variable[TT]): Variable[TT] =
  # Resulting var
  new result
@@ -59,6 +91,16 @@ proc mean*[TT](a: Variable[TT]): Variable[TT] =
  if a.is_grad_needed:
    result.mean_cache(a)

proc mean*[TT](a: Variable[TT], axis: Natural): Variable[TT] =
  # Resulting var
  new result
  result.context = a.context
  result.value = a.value.mean(axis)

  # Caching for backprop
  if a.is_grad_needed:
    result.mean_cache(a, axis)

type SumGate* {.final.} [TT] = ref object of Gate[TT]
  ## TODO: generalize to C <- alpha AB + C
  cached_input_shape: MetadataArray
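The backward pass above divides the incoming gradient by shape_product, the product of every input dimension except the reduced axis, then broadcasts the result back to the input shape. A minimal standalone sketch of that constant (plain Nim; shapeProduct is a hypothetical stand-in for the gate's shape_product, not part of this commit), evaluated for the 2x4 tensor the tests below use:

proc shapeProduct(shape: seq[int], axis: int): int =
  # Multiply every dimension except the reduced axis, as the gate does.
  result = 1
  for i, v in shape:
    if i != axis:
      result *= v

let shape = @[2, 4]
echo shapeProduct(shape, 0)  # 4 -> mean(axis=0) backprop yields grad / 4, broadcast to [2, 4]
echo shapeProduct(shape, 1)  # 2 -> mean(axis=1) backprop yields grad / 2, broadcast to [2, 4]

These two constants are exactly the 1/4 and 1/2 checked by the new test below.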
tests/autograd/test_gate_basic.nim (35 additions, 0 deletions)
@@ -34,3 +34,38 @@ suite "Autograd of basic operations":

    check: va.grad == onesTensor
    check: vb.grad == onesTensor

test "Gradient of mean":

let a = toSeq(1..8).toTensor.reshape(2,4).astype(float32)

let ctx = newContext Tensor[float32]

let va = ctx.variable(a, requires_grad = true)
let m = va.mean()

m.backprop()

let constantTensor = ones[float32](2, 4) / 8.0

check: va.grad == constantTensor

test "Gradient of mean along one axis":

let a = toSeq(1..8).toTensor.reshape(2,4).astype(float32)

let ctx = newContext Tensor[float32]

let va = ctx.variable(a, requires_grad = true)

let m0 = va.mean(axis=0)
m0.backprop()
let constantTensor0 = ones[float32](2, 4) / 4.0
check: va.grad == constantTensor0

va.grad = zeros_like(va.grad)

let m = va.mean(axis=1)
m.backprop()
let constantTensor1 = ones[float32](2, 4) / 2.0
check: va.grad == constantTensor1
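For context, a usage sketch of the new overload outside the test runner; it assumes the arraymancer package is installed and that import arraymancer brings Tensor, Variable and the autograd context into scope (the import lines are an assumption, not part of this diff):

import sequtils
import arraymancer

let ctx = newContext Tensor[float32]
let a = toSeq(1..8).toTensor.reshape(2, 4).astype(float32)
let va = ctx.variable(a, requires_grad = true)

let m = va.mean(axis = 1)  # reduce over axis 1: one mean per row
m.backprop()
echo va.grad               # per the test above: 0.5 everywhere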
