Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for parametrized pulses #92

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,049 changes: 1,049 additions & 0 deletions docs/notebooks/01_example_simple_state_to_state_parametrization.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/krotov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
mu,
objectives,
parallelization,
parametrization,
propagators,
result,
second_order,
Expand Down
19 changes: 14 additions & 5 deletions src/krotov/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import numpy as np

from .parametrization import ParametrizedArray


__all__ = [
'control_onto_interval',
Expand Down Expand Up @@ -116,25 +118,28 @@ def discretize(control, tlist, args=(None,), kwargs=None, via_midpoints=False):
kwargs=kwargs,
via_midpoints=False,
)
return pulse_onto_tlist(pulse_on_midpoints)
result = pulse_onto_tlist(pulse_on_midpoints)
else:
# relies on np.ComplexWarning being thrown as an error
return np.array(
result = np.array(
[float(control(t, *args, **kwargs)) for t in tlist],
dtype=np.float64,
)
elif isinstance(control, (np.ndarray, list)):
# relies on np.ComplexWarning being thrown as an error
control = np.array([float(v) for v in control], dtype=np.float64)
if len(control) != len(tlist):
result = np.array([float(v) for v in control], dtype=np.float64)
if len(result) != len(tlist):
raise ValueError(
"If control is an array, it must of the same length as tlist"
)
return control
else:
raise TypeError(
"control must be either a callable func(t, args) or a numpy array"
)
if hasattr(control, 'parametrization'):
return ParametrizedArray(result, control.parametrization)
else:
return result


def extract_controls(objectives):
Expand Down Expand Up @@ -354,6 +359,8 @@ def control_onto_interval(control):
if isinstance(control, np.ndarray):
assert len(control.shape) == 1 # must be 1D array
pulse = np.zeros(len(control) - 1, dtype=control.dtype.type)
if hasattr(control, 'parametrization'):
pulse = ParametrizedArray(pulse, control.parametrization)
pulse[0] = control[0]
for i in range(1, len(control) - 1):
pulse[i] = 2.0 * control[i] - pulse[i - 1]
Expand Down Expand Up @@ -383,6 +390,8 @@ def pulse_onto_tlist(pulse):
of the input values before and after the point.
"""
control = np.zeros(len(pulse) + 1, dtype=pulse.dtype.type)
if hasattr(pulse, 'parametrization'):
control = ParametrizedArray(control, pulse.parametrization)
control[0] = pulse[0]
for i in range(1, len(control) - 1):
control[i] = 0.5 * (pulse[i - 1] + pulse[i])
Expand Down
3 changes: 3 additions & 0 deletions src/krotov/mu.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,7 @@ def derivative_wrt_pulse(
raise NotImplementedError(
"Time-dependent collapse operators not implemented"
)
if hasattr(pulses[i_pulse], 'parametrization'):
ϵ = pulses[i_pulse][time_index]
mu *= pulses[i_pulse].parametrization.derivative(ϵ)
return mu
21 changes: 16 additions & 5 deletions src/krotov/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .info_hooks import chain
from .mu import derivative_wrt_pulse
from .parallelization import USE_THREADPOOL_LIMITS
from .parametrization import ParametrizedArray
from .propagators import Propagator, expm
from .result import Result
from .second_order import _overlap
Expand Down Expand Up @@ -441,7 +442,7 @@ def optimize_pulses(
if second_order:
for i_obj in range(len(objectives)):
forward_states[i_obj][0] = objectives[i_obj].initial_state
delta_eps = [
delta_u = [
np.zeros(len(tlist) - 1, dtype=np.complex128) for _ in guess_pulses
]
optimized_pulses = copy.deepcopy(guess_pulses)
Expand All @@ -467,12 +468,12 @@ def optimize_pulses(
update *= chi_norms[i_obj]
if second_order:
update += 0.5 * σ * overlap(delta_phis[i_obj], μ(Ψ))
delta_eps[i_pulse][time_index] += update
delta_u[i_pulse][time_index] += update
λₐ = lambda_vals[i_pulse]
S_t = shape_arrays[i_pulse][time_index]
Δϵ = (S_t / λₐ) * delta_eps[i_pulse][time_index].imag # ∈ ℝ
g_a_integrals[i_pulse] += abs(Δϵ) ** 2 * dt # dt may vary!
optimized_pulses[i_pulse][time_index] += Δϵ
Δu = (S_t / λₐ) * delta_u[i_pulse][time_index].imag # ∈ ℝ
g_a_integrals[i_pulse] += abs(Δu) ** 2 * dt # dt may vary!
_add_update(optimized_pulses[i_pulse], time_index, Δu)
# forward propagation
fw_states = parallel_map[2](
_forward_propagation_step,
Expand Down Expand Up @@ -884,6 +885,16 @@ def _backward_propagation(
return storage_array


def _add_update(pulse, time_index, Δu):
if isinstance(pulse, ParametrizedArray):
ϵ = pulse[time_index]
u = pulse.parametrization.parametrize(ϵ)
pulse[time_index] = pulse.parametrization.unparametrize(u + Δu)
else:
# ϵ = u ⇒ Δϵ = Δu
pulse[time_index] += Δu


def _forward_propagation_step(
i_state,
states,
Expand Down
118 changes: 118 additions & 0 deletions src/krotov/parametrization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
r"""Classes to realized parametrized optimization pulses."""
import sys
import warnings
from abc import ABCMeta, abstractmethod

import numpy as np


class ParametrizedFunction(metaclass=ABCMeta):
"""Wrapped function, adding a `parametrization` attribute."""

def __init__(self, func, parametrization):
self._func = func
self.parametrization = parametrization

def __call__(self, t, args):
return self._func(t, args)


class ParametrizedArray(np.ndarray):
"""Wrapped numpy array, adding a `parametrization` attribute."""

# See https://numpy.org/doc/stable/user/basics.subclassing.html
def __new__(cls, input_array, parametrization):
obj = np.asarray(input_array).view(cls)
obj.parametrization = parametrization
if not isinstance(obj.parametrization, Parametrization):
raise ValueError(
"parametrization must be a Parametrization instance, not %r"
% type(parametrization)
)
return obj

def __array_finalize__(self, obj):
if obj is None:
return
self.parametrization = getattr(obj, 'parametrization', None)


class Parametrization(metaclass=ABCMeta):
"""Abstract base class for a parametrizations."""

@abstractmethod
def parametrize(self, eps_val):
return NotImplementedError

@abstractmethod
def unparametrize(self, u_val):
return NotImplementedError

@abstractmethod
def derivative(self):
return NotImplementedError


class TanhParametrization(Parametrization):
def __init__(self, *, eps_max, eps_min):
self.eps_max = eps_max
self.eps_min = eps_min

def parametrize(self, eps_val):
ϵ_max = self.eps_max
ϵ_min = self.eps_min
ϵ = eps_val
if ϵ >= ϵ_max or ϵ <= ϵ_min:
warnings.warn(
"Pulse value %r out of range (%r, %r) for %s. "
"Value will be clipped."
% (ϵ, ϵ_min, ϵ_max, self.__class__.__name__)
)
Δ = ϵ_max - ϵ_min
a = np.clip(
2 * ϵ / Δ - (ϵ_max + ϵ_min) / Δ,
-1 + sys.float_info.epsilon,
1 - sys.float_info.epsilon,
)
u = np.arctanh(a) # -18.4 < u < 18.4
return u

def unparametrize(self, u_val):
ϵ_max = self.eps_max
ϵ_min = self.eps_min
u = u_val
cp = 0.5 * (ϵ_max + ϵ_min)
cm = 0.5 * (ϵ_max - ϵ_min)
ϵ = cm * np.tanh(u) + cp
return ϵ

def derivative(self, eps_val):
ϵ_max = self.eps_max
ϵ_min = self.eps_min
ϵ = eps_val
Δ = ϵ_max - ϵ_min
a = np.clip(
2 * ϵ / Δ - (ϵ_max + ϵ_min) / Δ,
-1 + sys.float_info.epsilon,
1 - sys.float_info.epsilon,
)
u = np.arctanh(a)
return 0.5 * Δ / np.cosh(u) ** 2


class SquareParametrization(Parametrization):
def parametrize(self, eps_val):
if eps_val < 0:
warnings.warn(
"Pulse value %r < 0 out of range for %s. Clip to 0."
% (eps_val, self.__class__.__name__)
)
eps_val = 0
return np.sqrt(eps_val)

def unparametrize(self, u_val):
return u_val ** 2

def derivative(self, eps_val):
u = self.parametrize(eps_val)
return 2 * u