
Merge pull request #152 from dolfin-adjoint/explog
Log and exp for adjfloats
colinjcotter committed Jun 25, 2024
2 parents 34511f9 + 2570662 commit 43b4032
Showing 3 changed files with 125 additions and 3 deletions.
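
For context, a minimal usage sketch of the new overloads (an illustration, not part of the commit; it assumes pyadjoint's default working tape, and the names come from the __init__.py change below):

from pyadjoint import AdjFloat, Control, ReducedFunctional, exp, log

a = AdjFloat(2.0)
J = exp(log(a) * 2.0)             # J = a**2, recorded on the working tape
rf = ReducedFunctional(J, Control(a))
print(float(rf(AdjFloat(3.0))))   # 9.0
print(float(rf.derivative()))     # dJ/da = 2*a = 6.0 at a = 3.0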
2 changes: 1 addition & 1 deletion pyadjoint/__init__.py
@@ -11,7 +11,7 @@
from .tape import (Tape,
                   set_working_tape, get_working_tape, no_annotations,
                   annotate_tape, stop_annotating, pause_annotation, continue_annotation)
-from .adjfloat import AdjFloat
+from .adjfloat import AdjFloat, exp, log
from .reduced_functional import ReducedFunctional
from .drivers import compute_gradient, compute_hessian, solve_adjoint
from .verification import taylor_test, taylor_to_dict
94 changes: 94 additions & 0 deletions pyadjoint/adjfloat.py
@@ -1,6 +1,8 @@
+from functools import wraps
from .block import Block
from .overloaded_type import OverloadedType, register_overloaded_type, create_overloaded_object
from .tape import get_working_tape, annotate_tape, stop_annotating
+import math


def annotate_operator(operator):
@@ -129,6 +131,98 @@ def _ad_str(self):
        return str(self.block_variable.saved_output)


_exp = math.exp
_log = math.log


@wraps(_exp)
def exp(a, **kwargs):
    annotate = annotate_tape(kwargs)
    if annotate:
        a = create_overloaded_object(a)

        block = ExpBlock(a)
        tape = get_working_tape()
        tape.add_block(block)

    with stop_annotating():
        out = _exp(a)
    out = AdjFloat(out)

    if annotate:
        block.add_output(out.block_variable)
    return out


def log(a, **kwargs):
    """Return the natural logarithm of a."""
    annotate = annotate_tape(kwargs)
    if annotate:
        a = create_overloaded_object(a)

        block = LogBlock(a)
        tape = get_working_tape()
        tape.add_block(block)

    with stop_annotating():
        out = _log(a)
    out = AdjFloat(out)

    if annotate:
        block.add_output(out.block_variable)
    return out


class ExpBlock(Block):
    def __init__(self, a):
        super().__init__()
        self.add_dependency(a)

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
        adj_input = adj_inputs[0]
        input0 = inputs[0]
        return _exp(input0) * adj_input

    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
        tlm_input = tlm_inputs[0]
        input0 = inputs[0]
        return _exp(input0) * tlm_input

    def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
                                   relevant_dependencies, prepared=None):
        input0 = inputs[0]
        hessian = hessian_inputs[0]
        # Second-order term: d2(exp)/dx2 = exp(x), applied to adj * tlm.
        tlm_input = block_variable.tlm_value
        result = _exp(input0) * hessian
        if tlm_input is not None:
            result += _exp(input0) * adj_inputs[0] * tlm_input
        return result

    def recompute_component(self, inputs, block_variable, idx, prepared):
        return _exp(inputs[0])


class LogBlock(Block):
    def __init__(self, a):
        super().__init__()
        self.add_dependency(a)

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
        adj_input = adj_inputs[0]
        input0 = inputs[0]
        return adj_input / input0

    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
        tlm_input = tlm_inputs[0]
        input0 = inputs[0]
        return tlm_input / input0

    def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
                                   relevant_dependencies, prepared=None):
        input0 = inputs[0]
        hessian = hessian_inputs[0]
        # Second-order term: d2(log)/dx2 = -1/x**2, applied to adj * tlm.
        tlm_input = block_variable.tlm_value
        result = hessian / input0
        if tlm_input is not None:
            result -= adj_inputs[0] * tlm_input / (input0 * input0)
        return result

    def recompute_component(self, inputs, block_variable, idx, prepared):
        return _log(inputs[0])


_min = min
_max = max

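For reference, both blocks implement the chain rule for y = f(x): the adjoint action is xbar = f'(x) * ybar, the tangent action is ydot = f'(x) * xdot, and the second-order adjoint action is xhat = f'(x) * yhat + f''(x) * xdot * ybar. With f = exp, f' = f'' = exp; with f = log, f'(x) = 1/x and f''(x) = -1/x**2. The f'' term is why evaluate_hessian_component combines hessian_inputs, adj_inputs, and the dependency's tlm_value.
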
32 changes: 30 additions & 2 deletions tests/pyadjoint/test_floats.py
@@ -1,5 +1,5 @@
import pytest
-from math import log
+import math
from numpy.testing import assert_approx_equal
from numpy.random import rand
from pyadjoint import *
@@ -155,6 +155,34 @@ def test_float_neg():
    assert rf2.derivative() == - 2.0


def test_float_logexp():
    a = AdjFloat(3.0)
    b = exp(a)
    c = log(b)
    assert_approx_equal(c, 3.0)

    b = log(a)
    c = exp(b)
    assert_approx_equal(c, 3.0)

    rf = ReducedFunctional(c, Control(a))
    assert_approx_equal(rf(a), 3.0)
    assert_approx_equal(rf(AdjFloat(1.0)), 1.0)
    assert_approx_equal(rf(AdjFloat(9.0)), 9.0)

    assert_approx_equal(rf.derivative(), 1.0)

    a = AdjFloat(3.0)
    b = exp(a)
    rf = ReducedFunctional(b, Control(a))
    assert_approx_equal(rf.derivative(), math.exp(3.0))

    a = AdjFloat(2.0)
    b = log(a)
    rf = ReducedFunctional(b, Control(a))
    assert_approx_equal(rf.derivative(), 1./2.)


def test_float_exponentiation():
    a = AdjFloat(3.0)
    b = AdjFloat(2.0)
@@ -172,7 +200,7 @@ def test_float_exponentiation():
    assert rf(AdjFloat(1.0)) == 1.0
    assert rf(AdjFloat(2.0)) == 4.0
    # d(a**a)/da = dexp(a log(a))/da = a**a * (log(a) + 1)
-    assert_approx_equal(rf.derivative(), 4.0 * (log(2.0)+1.0))
+    assert_approx_equal(rf.derivative(), 4.0 * (math.log(2.0)+1.0))

    # TODO: __rpow__ is not yet implemented

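Beyond the unit tests above, the new blocks can also be exercised with pyadjoint's Taylor-remainder check; a sketch (an illustration, not part of the commit):

from pyadjoint import AdjFloat, Control, ReducedFunctional, exp, taylor_test

a = AdjFloat(1.5)
rf = ReducedFunctional(exp(a), Control(a))
# The remainder |rf(a + h) - rf(a) - h * rf.derivative()| should shrink at
# second order in h; taylor_test returns the observed convergence rate.
assert taylor_test(rf, AdjFloat(1.5), AdjFloat(0.1)) > 1.9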
