Skip to content

Commit

Permalink
Fix input type (#40)
Browse files Browse the repository at this point in the history
* Fix input type assembling

* Add test file
  • Loading branch information
AntonReinhard authored Nov 1, 2024
1 parent 0ea9ca9 commit dd3c037
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 6 deletions.
23 changes: 17 additions & 6 deletions src/computable_dags/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,28 @@ function ComputableDAGs.input_expr(
end
end

# recursion termination: base case
@inline _assemble_input_type(::Tuple{}, ::ParticleDirection) = ()

# function assembling the correct type information for the tuple of ParticleStatefuls in a phasespace point for input_type
@inline function _assemble_input_type(
particle_types::Tuple{SPECIES_T,Vararg{AbstractParticleType}}, dir::DIR_T
) where {SPECIES_T<:AbstractParticleType,DIR_T<:ParticleDirection}
return (
AbstractParticleStateful{DIR_T,SPECIES_T},
_assemble_input_type(particle_types[2:end], dir)...,
)
end

function ComputableDAGs.input_type(p::AbstractProcessDefinition)
# TODO correctly assemble abstract types here
# See https://github.com/QEDjl-project/QEDFeynmanDiagrams.jl/issues/29
in_t = QEDcore._assemble_tuple_type(incoming_particles(p), Incoming(), SFourMomentum)
out_t = QEDcore._assemble_tuple_type(outgoing_particles(p), Outgoing(), SFourMomentum)
in_t = _assemble_input_type(incoming_particles(p), Incoming())
out_t = _assemble_input_type(outgoing_particles(p), Outgoing())
return AbstractPhaseSpacePoint{
typeof(p),
<:AbstractModelDefinition,
<:AbstractPhasespaceDefinition,
Tuple{in_t...},
Tuple{out_t...},
<:Tuple{in_t...},
<:Tuple{out_t...},
}
end

Expand Down
39 changes: 39 additions & 0 deletions test/input_type.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# file for testing that the generated input_type of a generated dag is correct

using Random
using QEDcore
using QEDprocesses
using ComputableDAGs
using QEDFeynmanDiagrams

using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

include("utils.jl")

RNG = MersenneTwister(0)

@testset "Compton-like process with $n incoming photons" for n in (1, 2, 3, 4)
proc = ScatteringProcess(
(Electron(), ntuple(_ -> Photon(), n)...),
(Electron(), Photon()),
(AllSpin(), ntuple(_ -> PolX(), n)...),
(AllSpin(), AllPol()),
)

for n_other in (1, 2, 3, 4)
n_other_proc = ScatteringProcess(
(Electron(), ntuple(_ -> Photon(), n_other)...),
(Electron(), Photon()),
(AllSpin(), ntuple(_ -> PolX(), n_other)...),
(AllSpin(), AllPol()),
)
input = gen_process_input(RNG, n_other_proc)

if n_other == n
@test typeof(input) <: input_type(proc)
else
@test !(typeof(input) <: input_type(proc))
end
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
using SafeTestsets

@safetestset "Input Type" begin
include("input_type.jl")
end

@safetestset "Synced Spins and Polarizations" begin
include("synced_spin_pol.jl")
end

0 comments on commit dd3c037

Please sign in to comment.