Skip to content

Commit

Permalink
added a (passing) test.
Browse files Browse the repository at this point in the history
  • Loading branch information
colinjcotter committed Jun 18, 2024
1 parent a051a99 commit 688cb60
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyadjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .tape import (Tape,
set_working_tape, get_working_tape, no_annotations,
annotate_tape, stop_annotating, pause_annotation, continue_annotation)
from .adjfloat import AdjFloat
from .adjfloat import AdjFloat, exp, log
from .reduced_functional import ReducedFunctional
from .drivers import compute_gradient, compute_hessian, solve_adjoint
from .verification import taylor_test, taylor_to_dict
Expand Down
4 changes: 2 additions & 2 deletions pyadjoint/adjfloat.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_v
return _exp(input0) * hessian

def recompute_component(self, inputs, block_variable, idx, prepared):
return _exp(inputs)
return _exp(inputs[0])


class LogBlock(Block):
Expand All @@ -217,7 +217,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_v
return -hessian / input0 / input0

def recompute_component(self, inputs, block_variable, idx, prepared):
return _log(inputs)
return _log(inputs[0])


_min = min
Expand Down
32 changes: 30 additions & 2 deletions tests/pyadjoint/test_floats.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from math import log
import math
from numpy.testing import assert_approx_equal
from numpy.random import rand
from pyadjoint import *
Expand Down Expand Up @@ -155,6 +155,34 @@ def test_float_neg():
assert rf2.derivative() == - 2.0


def test_float_logexp():
a = AdjFloat(3.0)
b = exp(a)
c = log(b)
assert_approx_equal(c, 3.0)

b = log(a)
c = exp(b)
assert c, 3.0

rf = ReducedFunctional(c, Control(a))
assert_approx_equal(rf(a), 3.0)
assert_approx_equal(rf(AdjFloat(1.0)), 1.0)
assert_approx_equal(rf(AdjFloat(9.0)), 9.0)

assert_approx_equal(rf.derivative(), 1.0)

a = AdjFloat(3.0)
b = exp(a)
rf = ReducedFunctional(b, Control(a))
assert_approx_equal(rf.derivative(), math.exp(3.0))

a = AdjFloat(2.0)
b = log(a)
rf = ReducedFunctional(b, Control(a))
assert_approx_equal(rf.derivative(), 1./2.)


def test_float_exponentiation():
a = AdjFloat(3.0)
b = AdjFloat(2.0)
Expand All @@ -172,7 +200,7 @@ def test_float_exponentiation():
assert rf(AdjFloat(1.0)) == 1.0
assert rf(AdjFloat(2.0)) == 4.0
# d(a**a)/da = dexp(a log(a))/da = a**a * (log(a) + 1)
assert_approx_equal(rf.derivative(), 4.0 * (log(2.0)+1.0))
assert_approx_equal(rf.derivative(), 4.0 * (math.log(2.0)+1.0))

# TODO: __rpow__ is not yet implemented

Expand Down

0 comments on commit 688cb60

Please sign in to comment.