Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: implement all_subsystems switch in DPD builder #152

Merged
merged 7 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"ms-vsliveshare.vsliveshare",
"oijaz.unicode-latex",
"redhat.vscode-yaml",
"ryanluker.vscode-coverage-gutters",
"soulcode.vscode-unwanted-extensions",
"stkb.rewrap",
"streetsidesoftware.code-spell-checker",
Expand Down
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
"editor.defaultFormatter": "esbenp.prettier-vscode"
},
"cSpell.enabled": true,
"coverage-gutters.coverageFileNames": ["coverage.xml"],
"coverage-gutters.coverageReportFileName": "**/htmlcov/index.html",
"coverage-gutters.showGutterCoverage": false,
"coverage-gutters.showLineCoverage": true,
"diffEditor.experimental.showMoves": true,
"editor.formatOnSave": true,
"files.eol": "\n",
Expand Down
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ coverage:
status:
project:
default:
target: 54%
target: 65%
threshold: 1%
base: auto
if_no_uploads: error
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ commands = [
[
"pytest",
{replace = "posargs", extend = true},
"--cov-fail-under=54",
"--cov-fail-under=65",
"--cov-report=html",
"--cov-report=xml",
"--cov=ampform_dpd",
Expand Down
10 changes: 9 additions & 1 deletion src/ampform_dpd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
self,
decay: ThreeBodyDecay,
min_ls: tuple[bool, bool] | bool = True,
all_subsystems: bool = False,
) -> None:
"""Amplitude builder for the helicity formalism with Dalitz-plot decomposition.

Expand All @@ -62,6 +63,8 @@ def __init__(
element of the `tuple` defines whether to use helicity couplings on the
**production** `.IsobarNode` and the second configures the **decay**
`.IsobarNode`.
all_subsystems: Formulate the amplitude model for all allowed subsystems in
the decay, even if they do not exist in the `.ThreeBodyDecay` object.
"""
self.decay = decay
self.dynamics_choices = DynamicsConfigurator(decay)
Expand All @@ -76,6 +79,7 @@ def __init__(
else:
msg = f"Cannot configure helicity couplings with a {type(min_ls).__name__}"
raise NotImplementedError(msg, min_ls)
self.all_subsystems = all_subsystems

def formulate(
self,
Expand All @@ -94,8 +98,12 @@ def formulate(
angle_definitions = {}
parameter_defaults = {}
args: tuple[sp.Rational, sp.Rational, sp.Rational, sp.Rational]
if self.all_subsystems:
subsystem_ids: list[FinalStateID] = [1, 2, 3]
else:
subsystem_ids = sorted(_get_subsystem_ids(self.decay))
for args in product(*allowed_helicities.values()): # type:ignore[assignment]
for sub_system in _get_subsystem_ids(self.decay):
for sub_system in subsystem_ids:
chain_model = self.formulate_subsystem_amplitude(*args, sub_system) # type:ignore[arg-type]
amplitude_definitions.update(chain_model.amplitudes)
angle_definitions.update(chain_model.variables)
Expand Down
6 changes: 3 additions & 3 deletions src/ampform_dpd/dynamics/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@


def formulate_breit_wigner_with_form_factor(
decay: ThreeBodyDecayChain,
decay_chain: ThreeBodyDecayChain,
) -> tuple[sp.Expr, dict[sp.Symbol, complex | float]]:
decay_node = decay.decay_node
decay_node = decay_chain.decay_node
s = get_mandelstam_s(decay_node)
parameter_defaults = {}
production_ff, new_pars = _create_form_factor(s, decay.production_node)
production_ff, new_pars = _create_form_factor(s, decay_chain.production_node)
parameter_defaults.update(new_pars)
decay_ff, new_pars = _create_form_factor(s, decay_node)
parameter_defaults.update(new_pars)
Expand Down
121 changes: 61 additions & 60 deletions tests/adapter/test_qrules.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# cspell:ignore pksigma
# pyright: reportPrivateUsage=false
from __future__ import annotations

from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, SupportsFloat

import attrs
import pytest
import qrules

from ampform_dpd.adapter.qrules import (
_convert_transition,
Expand All @@ -19,50 +18,10 @@
from ampform_dpd.decay import LSCoupling, Particle

if TYPE_CHECKING:
from _pytest.fixtures import SubRequest
from qrules.topology import FrozenTransition
from qrules.transition import ReactionInfo, StateTransition


@pytest.fixture(scope="session")
def a2pipipi_reaction() -> ReactionInfo:
return qrules.generate_transitions(
initial_state="a(1)(1260)0",
final_state=["pi0", "pi0", "pi0"],
allowed_intermediate_particles=["a(0)(980)0"],
formalism="helicity",
)


@pytest.fixture(scope="session", params=["canonical-helicity", "helicity"])
def jpsi2pksigma_reaction(request: SubRequest) -> ReactionInfo: # cspell:ignore pksigma
return qrules.generate_transitions(
initial_state=[("J/psi(1S)", [+1])],
final_state=["K0", ("Sigma+", [+0.5]), ("p~", [+0.5])],
allowed_interaction_types="strong",
allowed_intermediate_particles=["Sigma(1660)"],
formalism=request.param,
)


@pytest.fixture(scope="session")
def xib2pkk_reaction() -> ReactionInfo:
reaction = qrules.generate_transitions(
initial_state="Xi(b)-",
final_state=["p", "K-", "K-"],
allowed_intermediate_particles=["Lambda(1520)"],
formalism="helicity",
)
swapped_transitions = tuple(
attrs.evolve(t, topology=t.topology.swap_edges(1, 2))
for t in reaction.transitions
)
return qrules.transition.ReactionInfo(
transitions=reaction.transitions + swapped_transitions,
formalism=reaction.formalism,
)


def test_convert_transitions(xib2pkk_reaction: ReactionInfo):
reaction = normalize_state_ids(xib2pkk_reaction)
assert reaction.get_intermediate_particles().names == ["Lambda(1520)"]
Expand All @@ -81,32 +40,40 @@ def test_filter_min_ls(jpsi2pksigma_reaction: ReactionInfo):

ls_couplings = [_get_couplings(t) for t in transitions]
if reaction.formalism == "canonical-helicity":
assert len(ls_couplings) == 2
assert len(ls_couplings) == 3
assert ls_couplings == [
(
{"L": 0, "S": 1.0},
{"L": 0, "S": 1},
{"L": 1, "S": 0.5},
),
(
{"L": 2, "S": 1.0},
{"L": 2, "S": 1},
{"L": 1, "S": 0.5},
),
(
{"L": 1, "S": 2},
{"L": 2, "S": 0.5},
),
]
else:
assert len(ls_couplings) == 1
assert len(ls_couplings) == 2
for ls_coupling in ls_couplings:
for ls in ls_coupling:
assert ls == {"L": None, "S": None}

min_ls_transitions = filter_min_ls(transitions)
ls_couplings = [_get_couplings(t) for t in min_ls_transitions]
assert len(ls_couplings) == 1
assert len(ls_couplings) == 2
if reaction.formalism == "canonical-helicity":
assert ls_couplings == [
(
{"L": 0, "S": 1.0},
{"L": 0, "S": 1},
{"L": 1, "S": 0.5},
),
(
{"L": 1, "S": 2},
{"L": 2, "S": 0.5},
),
]


Expand Down Expand Up @@ -181,26 +148,60 @@ def test_to_three_body_decay(jpsi2pksigma_reaction: ReactionInfo, min_ls: bool):
2: "Sigma+",
3: "p~",
}
n_chains = len(decay.chains)
if reaction.formalism == "canonical-helicity":
production_ls = [c.incoming_ls for c in decay.chains]
decay_ls = [c.outgoing_ls for c in decay.chains]
if min_ls:
assert len(decay.chains) == 1
assert decay.chains[0].incoming_ls == LSCoupling(L=0, S=1)
assert decay.chains[0].outgoing_ls == LSCoupling(L=1, S=0.5)
assert n_chains == 2
assert production_ls == [
LSCoupling(L=1, S=1),
LSCoupling(L=0, S=1),
]
assert decay_ls == [
LSCoupling(L=2, S=0.5),
LSCoupling(L=1, S=0.5),
]
else:
assert len(decay.chains) == 2
assert decay.chains[1].incoming_ls == LSCoupling(L=2, S=1)
assert decay.chains[1].outgoing_ls == LSCoupling(L=1, S=0.5)
assert n_chains == 4
assert production_ls == [
LSCoupling(L=1, S=1),
LSCoupling(L=1, S=2),
LSCoupling(L=0, S=1),
LSCoupling(L=2, S=1),
]
assert decay_ls == [
LSCoupling(L=2, S=0.5),
LSCoupling(L=2, S=0.5),
LSCoupling(L=1, S=0.5),
LSCoupling(L=1, S=0.5),
]
elif reaction.formalism == "helicity":
assert len(decay.chains) == 1
assert decay.chains[0].incoming_ls is None
assert decay.chains[0].outgoing_ls is None
assert n_chains == 2
for chain in decay.chains:
assert chain.incoming_ls is None
assert chain.outgoing_ls is None
resonance_names = set()
for chain in decay.chains:
assert isinstance(chain.resonance, Particle)
assert chain.resonance.name == "Sigma(1660)~-"
resonance_names.add(chain.resonance.name)
assert resonance_names == {
"N(1700)+",
"Sigma(1660)~-",
}


def _get_couplings(transition: StateTransition) -> tuple[dict, dict]:
return tuple( # type:ignore[return-value]
{"L": node.l_magnitude, "S": node.s_magnitude}
{"L": _to_float(node.l_magnitude), "S": _to_float(node.s_magnitude)}
for node in transition.interactions.values()
)


def _to_float(value: SupportsFloat | None) -> float | int | None:
if value is None:
return None
value = float(value)
if value.is_integer():
return int(value)
return value
51 changes: 51 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# cspell:ignore pksigma
from __future__ import annotations

from typing import TYPE_CHECKING

import attrs
import pytest
import qrules

if TYPE_CHECKING:
from _pytest.fixtures import SubRequest
from qrules.transition import ReactionInfo


@pytest.fixture(scope="session")
def a2pipipi_reaction() -> ReactionInfo:
return qrules.generate_transitions(
initial_state="a(1)(1260)0",
final_state=["pi0", "pi0", "pi0"],
allowed_intermediate_particles=["a(0)(980)0"],
formalism="helicity",
)


@pytest.fixture(scope="session", params=["canonical-helicity", "helicity"])
def jpsi2pksigma_reaction(request: SubRequest) -> ReactionInfo:
return qrules.generate_transitions(
initial_state=[("J/psi(1S)", [+1])],
final_state=["K0", ("Sigma+", [+0.5]), ("p~", [+0.5])],
allowed_interaction_types="strong",
allowed_intermediate_particles=["N(1700)+", "Sigma(1660)"],
formalism=request.param,
)


@pytest.fixture(scope="session")
def xib2pkk_reaction() -> ReactionInfo:
reaction = qrules.generate_transitions(
initial_state="Xi(b)-",
final_state=["p", "K-", "K-"],
allowed_intermediate_particles=["Lambda(1520)"],
formalism="helicity",
)
swapped_transitions = tuple(
attrs.evolve(t, topology=t.topology.swap_edges(1, 2))
for t in reaction.transitions
)
return qrules.transition.ReactionInfo(
transitions=reaction.transitions + swapped_transitions,
formalism=reaction.formalism,
)
55 changes: 55 additions & 0 deletions tests/test_dpd_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# cspell:ignore pksigma
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from ampform_dpd import DalitzPlotDecompositionBuilder
from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay
from ampform_dpd.dynamics.builder import formulate_breit_wigner_with_form_factor

if TYPE_CHECKING:
from qrules.transition import ReactionInfo


class TestDalitzPlotDecompositionBuilder:
@pytest.mark.parametrize("all_subsystems", [False, True])
@pytest.mark.parametrize("min_ls", [False, True])
def test_all_subsystems(
self, jpsi2pksigma_reaction: ReactionInfo, all_subsystems: bool, min_ls: bool
):
if jpsi2pksigma_reaction.formalism == "helicity" and not min_ls:
pytest.skip("Helicity formalism with all LS not supported")
transitions = normalize_state_ids(jpsi2pksigma_reaction.transitions)
decay = to_three_body_decay(transitions, min_ls=min_ls)
builder = DalitzPlotDecompositionBuilder(
decay, min_ls=min_ls, all_subsystems=all_subsystems
)
if jpsi2pksigma_reaction.formalism == "canonical-helicity":
for chain in builder.decay.chains:
builder.dynamics_choices.register_builder(
chain, formulate_breit_wigner_with_form_factor
)
if all_subsystems:
with pytest.warns(
UserWarning,
match=r"Decay J/psi\(1S\) → 1: K0, 2: Sigma\+, 3: p~ only has subsystems 2, 3, not 1",
):
model = builder.formulate(reference_subsystem=2)
else:
model = builder.formulate(reference_subsystem=2)
expected_variables = {
R"\zeta^0_{2(2)}",
R"\zeta^0_{3(2)}",
R"\zeta^2_{2(2)}",
R"\zeta^2_{3(2)}",
R"\zeta^3_{2(2)}",
R"\zeta^3_{3(2)}",
"theta_12",
"theta_23",
"theta_31",
}
if not all_subsystems:
expected_variables.remove("theta_23")
assert {s.name for s in model.variables} == expected_variables