diff --git a/src/autograd/gates_reduce.nim b/src/autograd/gates_reduce.nim
index f16483c0f..cf115a6a2 100644
--- a/src/autograd/gates_reduce.nim
+++ b/src/autograd/gates_reduce.nim
@@ -20,6 +20,7 @@ import ../private/ast_utils,
 type MeanGate* {.final.} [TT] = ref object of Gate[TT]
   ## TODO: generalize to C <- alpha AB + C
   cached_input_shape: MetadataArray
+  axis: int
 
 proc mean_backward_ag[TT](self: MeanGate[TT], payload: Payload[TT]): SmallDiffs[TT] =
   let gradient = payload.variable.grad
@@ -30,6 +31,12 @@ proc mean_backward_ag[TT](self: MeanGate[TT], payload: Payload[TT]): SmallDiffs[
   let z_shape = newSeqWith(self.cached_input_shape.len, 1)
   result[0] = result[0].reshape(z_shape).broadcast(self.cached_input_shape)
 
+proc mean_with_axis_backward_ag[TT](self: MeanGate[TT], payload: Payload[TT]): SmallDiffs[TT] =
+  let gradient = payload.variable.grad
+  result = newDiffs[TT](1)
+  # Each output element averages `shape[axis]` inputs, so scale the gradient by the size of the reduced axis
+  result[0] = (gradient / getSubType(TT)(self.cached_input_shape[self.axis])).broadcast(self.cached_input_shape)
+
 proc mean_cache[TT](result: Variable[TT], a: Variable[TT]) =
   # Gate
   var gate: MeanGate[TT]
@@ -49,6 +56,26 @@ proc mean_cache[TT](result: Variable[TT], a: Variable[TT]) =
     a
   )
 
+proc mean_cache[TT](result: Variable[TT], a: Variable[TT], axis: Natural) =
+  # Gate
+  var gate: MeanGate[TT]
+  new gate
+  gate.cached_input_shape = a.value.shape
+  gate.axis = axis
+
+  # Result setup
+  result.grad = zeros_like(result.value)
+  result.requires_grad = true
+
+  # Caching for backprop
+  register_node(
+    "Mean",
+    gate,
+    mean_with_axis_backward_ag[TT],
+    result,
+    a
+  )
+
 proc mean*[TT](a: Variable[TT]): Variable[TT] =
   # Resulting var
   new result
@@ -59,6 +86,16 @@ proc mean*[TT](a: Variable[TT]): Variable[TT] =
   if a.is_grad_needed:
     result.mean_cache(a)
 
+proc mean*[TT](a: Variable[TT], axis: Natural): Variable[TT] =
+  # Resulting var
+  new result
+  result.context = a.context
+  result.value = a.value.mean(axis)
+
+  # Caching for backprop
+  if a.is_grad_needed:
+    result.mean_cache(a, axis)
+
 type SumGate* {.final.} [TT] = ref object of Gate[TT]
   ## TODO: generalize to C <- alpha AB + C
   cached_input_shape: MetadataArray
diff --git a/tests/autograd/test_gate_basic.nim b/tests/autograd/test_gate_basic.nim
index 9176b4fea..d45b6d585 100644
--- a/tests/autograd/test_gate_basic.nim
+++ b/tests/autograd/test_gate_basic.nim
@@ -34,3 +34,40 @@ suite "Autograd of basic operations":
 
     check: va.grad == onesTensor
     check: vb.grad == onesTensor
+
+  test "Gradient of mean":
+
+    let a = toSeq(1..8).toTensor.reshape(2,4).astype(float32)
+
+    let ctx = newContext Tensor[float32]
+
+    let va = ctx.variable(a, requires_grad = true)
+    let m = va.mean()
+
+    m.backprop()
+
+    let constantTensor = ones[float32](2, 4) / 8.0
+
+    check: va.grad == constantTensor
+
+  test "Gradient of mean along one axis":
+
+    let a = toSeq(1..8).toTensor.reshape(2,4).astype(float32)
+
+    let ctx = newContext Tensor[float32]
+
+    let va = ctx.variable(a, requires_grad = true)
+
+    # Averaging along axis 0 reduces over 2 elements, so each input receives 1/2 of the gradient
+    let m0 = va.mean(axis=0)
+    m0.backprop()
+    let constantTensor0 = ones[float32](2, 4) / 2.0
+    check: va.grad == constantTensor0
+
+    va.grad = zeros_like(va.grad)
+
+    # Averaging along axis 1 reduces over 4 elements, so each input receives 1/4 of the gradient
+    let m = va.mean(axis=1)
+    m.backprop()
+    let constantTensor1 = ones[float32](2, 4) / 4.0
+    check: va.grad == constantTensor1
\ No newline at end of file