From dd3c0378626db1d6e54e29bf8b6054bd4d16d773 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Fri, 1 Nov 2024 14:40:52 +0100 Subject: [PATCH] Fix input type (#40) * Fix input type assembling * Add test file --- src/computable_dags/generation.jl | 23 +++++++++++++----- test/input_type.jl | 39 +++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++++ 3 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 test/input_type.jl diff --git a/src/computable_dags/generation.jl b/src/computable_dags/generation.jl index 9be7b37..40a9283 100644 --- a/src/computable_dags/generation.jl +++ b/src/computable_dags/generation.jl @@ -121,17 +121,28 @@ function ComputableDAGs.input_expr( end end +# recursion termination: base case +@inline _assemble_input_type(::Tuple{}, ::ParticleDirection) = () + +# function assembling the correct type information for the tuple of ParticleStatefuls in a phasespace point for input_type +@inline function _assemble_input_type( + particle_types::Tuple{SPECIES_T,Vararg{AbstractParticleType}}, dir::DIR_T +) where {SPECIES_T<:AbstractParticleType,DIR_T<:ParticleDirection} + return ( + AbstractParticleStateful{DIR_T,SPECIES_T}, + _assemble_input_type(particle_types[2:end], dir)..., + ) +end + function ComputableDAGs.input_type(p::AbstractProcessDefinition) - # TODO correctly assemble abstract types here - # See https://github.com/QEDjl-project/QEDFeynmanDiagrams.jl/issues/29 - in_t = QEDcore._assemble_tuple_type(incoming_particles(p), Incoming(), SFourMomentum) - out_t = QEDcore._assemble_tuple_type(outgoing_particles(p), Outgoing(), SFourMomentum) + in_t = _assemble_input_type(incoming_particles(p), Incoming()) + out_t = _assemble_input_type(outgoing_particles(p), Outgoing()) return AbstractPhaseSpacePoint{ typeof(p), <:AbstractModelDefinition, <:AbstractPhasespaceDefinition, - Tuple{in_t...}, - Tuple{out_t...}, + <:Tuple{in_t...}, + <:Tuple{out_t...}, } end diff --git a/test/input_type.jl b/test/input_type.jl new file mode 100644 index 0000000..7968a47 --- /dev/null +++ b/test/input_type.jl @@ -0,0 +1,39 @@ +# file for testing that the generated input_type of a generated dag is correct + +using Random +using QEDcore +using QEDprocesses +using ComputableDAGs +using QEDFeynmanDiagrams + +using RuntimeGeneratedFunctions +RuntimeGeneratedFunctions.init(@__MODULE__) + +include("utils.jl") + +RNG = MersenneTwister(0) + +@testset "Compton-like process with $n incoming photons" for n in (1, 2, 3, 4) + proc = ScatteringProcess( + (Electron(), ntuple(_ -> Photon(), n)...), + (Electron(), Photon()), + (AllSpin(), ntuple(_ -> PolX(), n)...), + (AllSpin(), AllPol()), + ) + + for n_other in (1, 2, 3, 4) + n_other_proc = ScatteringProcess( + (Electron(), ntuple(_ -> Photon(), n_other)...), + (Electron(), Photon()), + (AllSpin(), ntuple(_ -> PolX(), n_other)...), + (AllSpin(), AllPol()), + ) + input = gen_process_input(RNG, n_other_proc) + + if n_other == n + @test typeof(input) <: input_type(proc) + else + @test !(typeof(input) <: input_type(proc)) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index dbabbaa..87cd0ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,9 @@ using SafeTestsets +@safetestset "Input Type" begin + include("input_type.jl") +end + @safetestset "Synced Spins and Polarizations" begin include("synced_spin_pol.jl") end