Skip to content

Commit

Permalink
Improve xgrid rotations structure, fix EKO raw representation
Browse files Browse the repository at this point in the history
  • Loading branch information
alecandido committed Apr 25, 2023
1 parent d197fb2 commit 7f8a628
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 83 deletions.
176 changes: 98 additions & 78 deletions src/eko/io/manipulate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Manipulate output generate by EKO."""
import logging
import warnings
from typing import Optional
from typing import Callable, Optional, Union

import numpy as np
import numpy.typing as npt

from .. import basis_rotation as br
from .. import interpolation
from ..interpolation import XGrid
from .struct import EKO

logger = logging.getLogger(__name__)
Expand All @@ -16,102 +18,120 @@
SIMGRID_ROTATION = "ij,ajbk,kl->aibl"
"""Simultaneous grid rotation contraction indices."""

Basis = Union[XGrid, npt.NDArray]


def rotation(new: Optional[Basis], old: Basis, check: Callable, compute: Callable):
"""Define grid rotation.
This function returns the new grid to be assigned and the rotation computed,
if the checks for a non-trivial new grid are passed.
However, the check and the computation are delegated respectively to the
callables `check` and `compute`.
"""
if new is None:
return old, None

if check(new, old):
warnings.warn("The new grid is close to the current one")
return old, None

return new, compute(new, old)


def xgrid_check(new: Optional[XGrid], old: XGrid):
"""Check validity of new xgrid."""
return new is not None and len(new) == len(old) and np.allclose(new.raw, old.raw)


def xgrid_compute_rotation(new: XGrid, old: XGrid, interpdeg: int, swap: bool = False):
"""Compute rotation from old to new xgrid.
By default, the roation is computed for a target xgrid. Whether the function
should be used for an input xgrid, the `swap` argument should be set to
`True`, in order to compute it in the other direction (i.e. the transposed).
"""
if swap:
new, old = old, new
b = interpolation.InterpolatorDispatcher(old, interpdeg, False)
return b.get_interpolation(new.raw)


def xgrid_reshape(
eko: EKO,
targetgrid: Optional[interpolation.XGrid] = None,
inputgrid: Optional[interpolation.XGrid] = None,
targetgrid: Optional[XGrid] = None,
inputgrid: Optional[XGrid] = None,
):
"""Reinterpolate operators on output and/or input grids.
The operation is in-place.
Target corresponds to the output PDF.
Parameters
----------
eko :
the operator to be rotated
targetgrid :
xgrid for the target (output PDF)
inputgrid :
xgrid for the input (input PDF)
The operation is in-place.
"""
eko.assert_permissions(write=True)

# calling with no arguments is an error
if targetgrid is None and inputgrid is None:
raise ValueError("Nor inputgrid nor targetgrid was given")
# now check to the current status
if (
targetgrid is not None
and len(targetgrid) == len(eko.rotations.targetgrid)
and np.allclose(targetgrid.raw, eko.rotations.targetgrid.raw)
):
targetgrid = None
warnings.warn("The new targetgrid is close to the current targetgrid")
if (
inputgrid is not None
and len(inputgrid) == len(eko.rotations.inputgrid)
and np.allclose(inputgrid.raw, eko.rotations.inputgrid.raw)
):
inputgrid = None
warnings.warn("The new inputgrid is close to the current inputgrid")

interpdeg = eko.operator_card.configs.interpolation_polynomial_degree
check = xgrid_check
crot = xgrid_compute_rotation

# construct matrices
newtarget, targetrot = rotation(
targetgrid,
eko.rotations.targetgrid,
check,
lambda new, old: crot(new, old, interpdeg),
)
newinput, inputrot = rotation(
inputgrid,
eko.rotations.inputgrid,
check,
lambda new, old: crot(new, old, interpdeg, swap=True),
)

# after the checks: if there is still nothing to do, skip
if targetgrid is None and inputgrid is None:
if targetrot is None and inputrot is None:
logger.debug("Nothing done.")
return

# construct matrices
if targetgrid is not None:
b = interpolation.InterpolatorDispatcher(
eko.rotations.targetgrid,
eko.operator_card.configs.interpolation_polynomial_degree,
False,
)
target_rot = b.get_interpolation(targetgrid.raw)
eko.rotations.targetgrid = targetgrid
if inputgrid is not None:
b = interpolation.InterpolatorDispatcher(
inputgrid,
eko.operator_card.configs.interpolation_polynomial_degree,
False,
)
input_rot = b.get_interpolation(eko.rotations.inputgrid.raw)
eko.rotations.inputgrid = inputgrid
# if no rotation is done, the grids are not modified
if targetrot is not None:
eko.rotations.targetgrid = newtarget
if targetrot is not None:
eko.rotations.targetgrid = newinput

# build new grid
for q2, elem in eko.items():
ops = elem.operator
errs = elem.error
if targetgrid is not None and inputgrid is None:
ops = np.einsum(TARGETGRID_ROTATION, target_rot, ops, optimize="optimal")
errs = (
np.einsum(TARGETGRID_ROTATION, target_rot, errs, optimize="optimal")
if errs is not None
else None
)
elif inputgrid is not None and targetgrid is None:
ops = np.einsum(INPUTGRID_ROTATION, ops, input_rot, optimize="optimal")
errs = (
np.einsum(INPUTGRID_ROTATION, errs, input_rot, optimize="optimal")
if errs is not None
else None
)
for ep, elem in eko.items():
assert elem is not None

operands = [elem.operator]
operands_errs = [elem.error]

if targetrot is not None and inputrot is None:
contraction = TARGETGRID_ROTATION
elif inputrot is not None and targetrot is None:
contraction = INPUTGRID_ROTATION
else:
ops = np.einsum(
SIMGRID_ROTATION, target_rot, ops, input_rot, optimize="optimal"
)
errs = (
np.einsum(
SIMGRID_ROTATION, target_rot, errs, input_rot, optimize="optimal"
)
if errs is not None
else None
)
elem.operator = ops
elem.error = errs
contraction = SIMGRID_ROTATION

eko[q2] = elem
if targetrot is not None:
operands.insert(0, targetrot)
operands_errs.insert(0, targetrot)
if inputrot is not None:
operands.append(inputrot)
operands_errs.append(inputrot)

elem.operator = np.einsum(contraction, *operands, optimize="optimal")
if elem.error is not None:
elem.error = np.einsum(contraction, *operands_errs, optimize="optimal")

eko[ep] = elem

eko.update()

Expand All @@ -124,8 +144,8 @@ def xgrid_reshape(

def flavor_reshape(
eko: EKO,
targetpids: Optional[np.ndarray] = None,
inputpids: Optional[np.ndarray] = None,
targetpids: Optional[npt.NDArray] = None,
inputpids: Optional[npt.NDArray] = None,
update: bool = True,
):
"""Change the operators to have in the output targetpids and/or in the input inputpids.
Expand Down
6 changes: 3 additions & 3 deletions src/eko/io/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,8 @@ def items(self):
immediately after
"""
for ep, op in self._operators.items():
yield ep, op
for ep in self._operators:
yield ep, self[ep]
del self[ep]

def __contains__(self, q2: float) -> bool:
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def raw(self) -> dict:
operators themselves
"""
return dict(mu2grid=self.mu2grid.tolist(), metadata=self.metadata.raw)
return dict(mu2grid=self.mu2grid, metadata=self.metadata.raw)


@dataclass
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from eko import interpolation
from eko.io.runcards import OperatorCard, TheoryCard
from eko.io.struct import EKO, Operator
from eko.io.types import EvolutionPoint
from ekobox import cards


Expand Down Expand Up @@ -80,7 +81,7 @@ def __init__(self, theory: TheoryCard, operator: OperatorCard, path: os.PathLike
self.cache: Optional[EKO] = None

@staticmethod
def _operators(mugrid: Iterable[float], shape: Tuple[int, int]):
def _operators(mugrid: Iterable[EvolutionPoint], shape: Tuple[int, int]):
ops = {}
for mu in mugrid:
ops[mu] = Operator(np.random.rand(*shape, *shape))
Expand All @@ -94,7 +95,7 @@ def _create(self):
lx = len(self.operator.xgrid)
lpids = len(self.operator.pids)
for mu2, op in self._operators(
mugrid=self.operator.mu2grid, shape=(lpids, lx)
mugrid=self.operator.evolgrid, shape=(lpids, lx)
).items():
self.cache[mu2] = op

Expand Down

0 comments on commit 7f8a628

Please sign in to comment.