Skip to content

Commit

Permalink
Fix and prettify (#24)
Browse files Browse the repository at this point in the history
* WIP

* Fix computational errors, improve propagator task speed

* Improve DAG generation code

* Improve runtime by making vertex term a constant
  • Loading branch information
AntonReinhard authored Sep 14, 2024
1 parent 763ffc1 commit 5a70448
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
- name: "Run CompatHelper"
run: |
import CompatHelper
CompatHelper.main(; master_branch="main")
CompatHelper.main(; master_branch="dev")
shell: julia --color=yes {0}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
Expand Down
159 changes: 71 additions & 88 deletions src/computable_dags/compute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end
import ComputableDAGs: compute, compute_effort, children

const e = sqrt(4π / 137)
_vertex() = -1im * e * gamma()
const VERTEX = -1im * e * gamma()

compute_effort(::ComputeTask_BaseState) = 0
compute_effort(::ComputeTask_Propagator) = 0
Expand All @@ -42,22 +42,22 @@ struct BaseStateInput{PS_T<:AbstractParticleStateful,SPIN_POL_T<:AbstractSpinOrP
spin_pol::SPIN_POL_T
end

function compute( #=@inline=#
::ComputeTask_BaseState,
input::BaseStateInput{PS,SPIN_POL},
function compute(
::ComputeTask_BaseState, input::BaseStateInput{PS,SPIN_POL}
) where {PS,SPIN_POL}
species = particle_species(input.particle)
if is_outgoing(input.particle)
species = _invert(species)
end
state = QEDbase.base_state(
particle_species(input.particle),
particle_direction(input.particle),
momentum(input.particle),
input.spin_pol,
)
return Propagated( # "propagated" because it goes directly into the next pair
species,
QEDbase.base_state(
particle_species(input.particle),
particle_direction(input.particle),
momentum(input.particle),
input.spin_pol,
),
state,
# bispinor, adjointbispinor, or lorentzvector
)
end
Expand All @@ -67,33 +67,43 @@ struct PropagatorInput{VP_T<:VirtualParticle,PSP_T<:AbstractPhaseSpacePoint}
psp::Ref{PSP_T}
end

function compute( #=@inline=#
::ComputeTask_Propagator,
input::PropagatorInput{VP_T,PSP_T},
) where {VP_T,PSP_T}
vp_mom = zero(typeof(momentum(input.psp[], Incoming(), 1)))
for i in eachindex(_in_contributions(input.vp))
if _in_contributions(input.vp)[i]
vp_mom += momentum(input.psp[], Incoming(), i)
end
end
for o in eachindex(_out_contributions(input.vp))
if (_out_contributions(input.vp))[o]
vp_mom -= momentum(input.psp[], Outgoing(), o)
end
@inline _masked_sum(::Tuple{}, ::Tuple{}) = error("masked sum needs at least one argument")
@inline function _masked_sum(values::Tuple{T}, mask::Tuple{Bool}) where {T}
return mask[1] ? values[1] : zero(T)
end
@inline function _masked_sum(
values::Tuple{T,Vararg{T,N}}, mask::Tuple{Bool,Vararg{Bool,N}}
) where {N,T}
return if mask[1]
values[1] + _masked_sum(values[2:end], mask[2:end])
else
_masked_sum(values[2:end], mask[2:end])
end
end

function _vp_momentum(
vp::VirtualParticle{PROC,SPECIES,I,O}, psp::PhaseSpacePoint
) where {PROC,SPECIES,I,O}
return _masked_sum(momenta(psp, Incoming()), _in_contributions(vp)) -
_masked_sum(momenta(psp, Outgoing()), _out_contributions(vp))
end

function compute(
::ComputeTask_Propagator, input::PropagatorInput{VP_T,PSP_T}
) where {VP_T,PSP_T}
# TODO: this is currently eating the most time of the computation, improve this
vp_mom = _vp_momentum(input.vp, input.psp[])
vp_species = particle_species(input.vp)
return QEDbase.propagator(vp_species, vp_mom)
# diracmatrix or scalar number
inner = QEDbase.propagator(vp_species, vp_mom)
return inner
end

struct Unpropagated{PARTICLE_T<:AbstractParticleType,VALUE_T}
particle::PARTICLE_T
value::VALUE_T
end

function Base.:+(a::Unpropagated{P,V}, b::Unpropagated{P,V}) where {P,V}
@inline function Base.:+(a::Unpropagated{P,V}, b::Unpropagated{P,V}) where {P,V}
return Unpropagated(a.particle, a.value + b.value)
end

Expand All @@ -102,89 +112,62 @@ struct Propagated{PARTICLE_T<:AbstractParticleType,VALUE_T}
value::VALUE_T
end

# maybe add the γ matrix term here too?
function compute( #=@inline=#
@inline function compute( # photon, electron
::ComputeTask_Pair,
electron::Propagated{Electron,V1},
positron::Propagated{Positron,V2},
) where {V1,V2}
return Unpropagated(Photon(), positron.value * _vertex() * electron.value) # fermion - antifermion -> photon
end
function compute( #=@inline=#
::ComputeTask_Pair,
positron::Propagated{Positron,V1},
photon::Propagated{Photon,V1},
electron::Propagated{Electron,V2},
) where {V1,V2}
return Unpropagated(Photon(), positron.value * _vertex() * electron.value) # antifermion - fermion -> photon
return Unpropagated(Electron(), photon.value * VERTEX * electron.value) # photon - electron -> electron
end
function compute( #=@inline=#
@inline function compute( # photon, positron
::ComputeTask_Pair,
photon::Propagated{Photon,V1},
fermion::Propagated{F,V2},
) where {F<:FermionLike,V1,V2}
return Unpropagated(fermion.particle, photon.value * _vertex() * fermion.value) # (anti-)fermion - photon -> (anti-)fermion
positron::Propagated{Positron,V2},
) where {V1,V2}
return Unpropagated(Positron(), positron.value * VERTEX * photon.value) # photon - positron -> positron
end
function compute( #=@inline=#
@inline function compute( # electron, positron
::ComputeTask_Pair,
fermion::Propagated{F,V2},
photon::Propagated{Photon,V1},
) where {F<:FermionLike,V1,V2}
return Unpropagated(fermion.particle, photon.value * _vertex() * fermion.value) # photon - (anti-)fermion -> (anti-)fermion
electron::Propagated{Electron,V1},
positron::Propagated{Positron,V2},
) where {V1,V2}
return Unpropagated(Photon(), positron.value * VERTEX * electron.value) # electron - positron -> photon
end

function compute( #=@inline=#
::ComputeTask_PropagatePairs,
left::PROP_V,
right::Unpropagated{P,VAL},
) where {PROP_V,P<:AbstractParticleType,VAL}
return Propagated(right.particle, left * right.value)
@inline function compute(
::ComputeTask_PropagatePairs, prop::PROP_V, photon::Unpropagated{Photon,VAL}
) where {PROP_V,VAL}
return Propagated(Photon(), photon.value * prop)
end
@inline function compute(
::ComputeTask_PropagatePairs, prop::PROP_V, electron::Unpropagated{Electron,VAL}
) where {PROP_V,VAL}
return Propagated(Electron(), prop * electron.value)
end
function compute( #=@inline=#
::ComputeTask_PropagatePairs,
left::Unpropagated{P,VAL},
right::PROP_V,
) where {PROP_V,P<:AbstractParticleType,VAL}
return Propagated(left.particle, right * left.value)
@inline function compute(
::ComputeTask_PropagatePairs, prop::PROP_V, positron::Unpropagated{Positron,VAL}
) where {PROP_V,VAL}
return Propagated(Positron(), positron.value * prop)
end

function compute( #=@inline=#
@inline function compute(
::ComputeTask_Triple,
photon::Propagated{Photon,V1},
electron::Propagated{Electron,V2},
positron::Propagated{Positron,V3},
) where {V1,V2,V3}
return positron.value * _vertex() * photon.value * electron.value
end
function compute( #=@inline=#
c::ComputeTask_Triple,
photon::Propagated{Photon,V1},
positron::Propagated{Positron,V2},
electron::Propagated{Electron,V3},
) where {V1,V2,V3}
return compute(c, photon, electron, positron)
end
function compute( #=@inline=#
c::ComputeTask_Triple,
f1::Propagated{F1,V1},
f2::Propagated{F2,V2},
photon::Propagated{Photon,V3},
) where {V1,V2,V3,F1<:FermionLike,F2<:FermionLike}
return compute(c, photon, f1, f2)
end
function compute( #=@inline=#
c::ComputeTask_Triple,
f1::Propagated{F1,V1},
photon::Propagated{Photon,V2},
f2::Propagated{F2,V3},
) where {V1,V2,V3,F1<:FermionLike,F2<:FermionLike}
return compute(c, photon, f1, f2)
return positron.value * (VERTEX * photon.value) * electron.value
end

# this compiles in a reasonable amount of time for up to about 1e4 parameters
# use a summation algorithm with more accuracy and/or parallelization
compute(::ComputeTask_CollectPairs, args::Vararg{N,T}) where {N,T} = sum(args) #=@inline=#
compute(::ComputeTask_CollectTriples, args::Vararg{N,T}) where {N,T} = sum(args) #=@inline=#
function compute(::ComputeTask_SpinPolCumulation, args::Vararg{N,T}) where {N,T} #=@inline=#
@inline function compute(::ComputeTask_CollectPairs, args::Vararg{N,T}) where {N,T}
return sum(args)
end
@inline function compute(::ComputeTask_CollectTriples, args::Vararg{N,T}) where {N,T}
return sum(args)
end
function compute(::ComputeTask_SpinPolCumulation, args::Vararg{N,T}) where {N,T}
sum = 0.0
for arg in args
sum += abs2(arg)
Expand Down
88 changes: 46 additions & 42 deletions src/computable_dags/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,12 @@ function _make_node_name(spin_pols::Vector)
return node_name
end

# return an index for the argument ordering on edges in the DAG for a given particle species, photon -> 1, electron -> 2, positron -> 3
_edge_index_from_species(::Photon) = 1
_edge_index_from_species(::Electron) = 2
_edge_index_from_species(::Positron) = 3
_edge_index_from_vp(vp::VirtualParticle) = _edge_index_from_species(particle_species(vp))

"""
generate_DAG(proc::AbstractProcessDefinition)
Expand All @@ -315,10 +321,15 @@ function generate_DAG(proc::AbstractProcessDefinition)

# TODO: use the spin/pol iterator here once it has been implemented
# -- Base State Tasks --
base_state_task_outputs = Dict()
propagated_outputs = Dict{VirtualParticle,Vector{Node}}()
for dir in (Incoming(), Outgoing())
for species in (Electron(), Positron(), Photon())
for index in 1:number_particles(proc, dir, species)
p = VirtualParticle(
proc,
is_outgoing(dir) ? _invert(species) : species,
_momentum_contribution(proc, dir, species, index)...,
)
for spin_pol in _spin_pols(spin_or_pol(proc, dir, species, index))
# gen entry nodes
# names are "bs_<dir>_<species>_<spin/pol>_<index>"
Expand All @@ -338,7 +349,10 @@ function generate_DAG(proc::AbstractProcessDefinition)
insert_edge!(graph, data_in, compute_base_state)
insert_edge!(graph, compute_base_state, data_out)

base_state_task_outputs[data_node_name] = data_out
if !haskey(propagated_outputs, p)
propagated_outputs[p] = Vector{Node}()
end
push!(propagated_outputs[p], data_out)
end
end
end
Expand All @@ -363,9 +377,8 @@ function generate_DAG(proc::AbstractProcessDefinition)
end

# -- Pair Tasks --
pair_task_outputs = Dict{VirtualParticle,Vector{Node}}()
for (product_particle, input_particle_vector) in pairs
pair_task_outputs[product_particle] = Vector{Node}()
propagated_outputs[product_particle] = Vector{Node}()

# make a dictionary of vectors to collect the outputs depending on spin/pol configs of the input particles
N = _number_contributions(product_particle)
Expand All @@ -374,29 +387,29 @@ function generate_DAG(proc::AbstractProcessDefinition)
}()

for input_particles in input_particle_vector
particles_data_out_nodes = (Vector(), Vector())
c = 0
for p in input_particles
c += 1
if (is_external(p))
# grab from base_states (broadcast over _base_state_name because it is a tuple for different spin_pols)
push!.(
Ref(particles_data_out_nodes[c]),
getindex.(Ref(base_state_task_outputs), _base_state_name(p)),
)
else
# grab from propagated particles
append!(particles_data_out_nodes[c], pair_task_outputs[p])
end
end
# input_particles is a tuple of first and second particle
particles_data_out_nodes = (
propagated_outputs[input_particles[1]],
propagated_outputs[input_particles[2]],
)

for in_nodes in Iterators.product(particles_data_out_nodes...)
# make the compute pair nodes for every combination of the found input_particle_nodes to get all spin/pol combinations
compute_pair = insert_node!(graph, ComputeTask_Pair())
pair_data_out = insert_node!(graph, DataTask(0))

insert_edge!(graph, in_nodes[1], compute_pair)
insert_edge!(graph, in_nodes[2], compute_pair)
insert_edge!(
graph,
in_nodes[1],
compute_pair,
_edge_index_from_vp(input_particles[1]),
)
insert_edge!(
graph,
in_nodes[2],
compute_pair,
_edge_index_from_vp(input_particles[2]),
)
insert_edge!(graph, compute_pair, pair_data_out)

# get the spin/pol config of the input particles from the data_out names
Expand Down Expand Up @@ -427,39 +440,30 @@ function generate_DAG(proc::AbstractProcessDefinition)
end

insert_edge!(graph, compute_pairs_sum, data_pairs_sum)
insert_edge!(graph, propagator_node, compute_propagated)
insert_edge!(graph, data_pairs_sum, compute_propagated)

insert_edge!(graph, propagator_node, compute_propagated, 1)
insert_edge!(graph, data_pairs_sum, compute_propagated, 2)

insert_edge!(graph, compute_propagated, data_out_propagated)

push!(pair_task_outputs[product_particle], data_out_propagated)
push!(propagated_outputs[product_particle], data_out_propagated)
end
end

# -- Triples --
triples_results = Dict()
for (ph, el, po) in triples # for each triple (each "diagram")
photons = if is_external(ph)
getindex.(Ref(base_state_task_outputs), _base_state_name(ph))
else
pair_task_outputs[ph]
end
electrons = if is_external(el)
getindex.(Ref(base_state_task_outputs), _base_state_name(el))
else
pair_task_outputs[el]
end
positrons = if is_external(po)
getindex.(Ref(base_state_task_outputs), _base_state_name(po))
else
pair_task_outputs[po]
end
photons = propagated_outputs[ph]
electrons = propagated_outputs[el]
positrons = propagated_outputs[po]

for (a, b, c) in Iterators.product(photons, electrons, positrons) # for each spin/pol config of each part
compute_triples = insert_node!(graph, ComputeTask_Triple())
data_triples = insert_node!(graph, DataTask(0))

insert_edge!(graph, a, compute_triples)
insert_edge!(graph, b, compute_triples)
insert_edge!(graph, c, compute_triples)
insert_edge!(graph, a, compute_triples, 1) # first argument photons
insert_edge!(graph, b, compute_triples, 2) # second argument electrons
insert_edge!(graph, c, compute_triples, 3) # third argument positrons

insert_edge!(graph, compute_triples, data_triples)

Expand Down

0 comments on commit 5a70448

Please sign in to comment.