Skip to content

Commit

Permalink
Manually build binary ufunc dispatch tables
Browse files Browse the repository at this point in the history
Eventually we might also want to pawn these off on vector
  • Loading branch information
nsmith- authored and lgray committed Jan 30, 2024
1 parent a8494f4 commit 8946d59
Showing 1 changed file with 32 additions and 68 deletions.
100 changes: 32 additions & 68 deletions src/coffea/nanoevents/methods/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def negative(self):
behavior=self.behavior,
)

@awkward.mixin_class_method(numpy.add, {"TwoVector"})
def add(self, other):
"""Add two vectors together elementwise using `x` and `y` components"""
return awkward.zip(
Expand All @@ -166,18 +165,6 @@ def add(self, other):
behavior=self.behavior,
)

@awkward.mixin_class_method(
numpy.subtract,
{
"TwoVector",
"ThreeVector",
"SphericalThreeVector",
"LorentzVector",
"PtEtaPhiMLorentzVector",
"PtEtaPhiELorentzVector",
},
transpose=False,
)
def subtract(self, other):
"""Subtract a vector from another elementwise using `x` and `y` components"""
return awkward.zip(
Expand Down Expand Up @@ -297,17 +284,13 @@ def negative(self):
behavior=self.behavior,
)

@awkward.mixin_class_method(
numpy.add, {"ThreeVector", "TwoVector", "PolarTwoVector"}
)
def add(self, other):
"""Add two vectors together elementwise using `x`, `y`, and `z` components"""
other_3d = other.to_Vector3D()
return awkward.zip(
{
"x": self.x + other_3d.x,
"y": self.y + other_3d.y,
"z": self.z + other_3d.z,
"x": self.x + other.x,
"y": self.y + other.y,
"z": self.z + other.z,
},
with_name="ThreeVector",
behavior=self.behavior,
Expand All @@ -319,28 +302,14 @@ def divide(self, other):
This is realized by using the multiplication functionality"""
return self.multiply(1 / other)

@awkward.mixin_class_method(
numpy.subtract,
{
"TwoVector",
"PolarTwoVector",
"ThreeVector",
"SphericalThreeVector",
"LorentzVector",
"PtEtaPhiMLorentzVector",
"PtEtaPhiELorentzVector",
},
transpose=False,
)
def subtract(self, other):
"""Subtract a vector from another elementwise using `x`, `y`, and `z` components"""
other_3d = other.to_Vector3D()

return awkward.zip(
{
"x": self.x - other_3d.x,
"y": self.y - other_3d.y,
"z": self.z - other_3d.z,
"x": self.x - other.x,
"y": self.y - other.y,
"z": self.z - other.z,
},
with_name="ThreeVector",
behavior=self.behavior,
Expand Down Expand Up @@ -511,51 +480,28 @@ def absolute(self):
"""
return self.mass

@awkward.mixin_class_method(
numpy.add,
{
"LorentzVector",
"ThreeVector",
"SphericalThreeVector",
"TwoVector",
"PolarTwoVector",
},
)
def add(self, other):
"""Add two vectors together elementwise using `x`, `y`, `z`, and `t` components"""
other_4d = other.to_Vector4D()
return awkward.zip(
{
"x": self.x + other_4d.x,
"y": self.y + other_4d.y,
"z": self.z + other_4d.z,
"t": self.t + other_4d.t,
"x": self.x + other.x,
"y": self.y + other.y,
"z": self.z + other.z,
"t": self.t + other.t,
},
with_name="LorentzVector",
behavior=self.behavior,
)

@awkward.mixin_class_method(
numpy.subtract,
{
"LorentzVector",
"ThreeVector",
"SphericalThreeVector",
"TwoVector",
"PolarTwoVector",
},
transpose=False,
)
def subtract(self, other):
"""Subtract a vector from another elementwise using `x`, `y`, `z`, and `t` components"""
other_4d = other.to_Vector4D()

return awkward.zip(
{
"x": self.x - other_4d.x,
"y": self.y - other_4d.y,
"z": self.z - other_4d.z,
"t": self.t - other_4d.t,
"x": self.x - other.x,
"y": self.y - other.y,
"z": self.z - other.z,
"t": self.t - other.t,
},
with_name="LorentzVector",
behavior=self.behavior,
Expand Down Expand Up @@ -965,6 +911,24 @@ def negative(self):
)


_binary_dispatch_cls = {
"TwoVector": TwoVector,
"PolarTwoVector": TwoVector,
"ThreeVector": ThreeVector,
"SphericalThreeVector": ThreeVector,
"LorentzVector": LorentzVector,
"PtEtaPhiMLorentzVector": LorentzVector,
"PtEtaPhiELorentzVector": LorentzVector,
}
_rank = [TwoVector, ThreeVector, LorentzVector]

for lhs, lhs_to in _binary_dispatch_cls.items():
for rhs, rhs_to in _binary_dispatch_cls.items():
out_to = min(lhs_to, rhs_to, key=_rank.index)
behavior[(numpy.add, lhs, rhs)] = out_to.add
behavior[(numpy.subtract, lhs, rhs)] = out_to.subtract


TwoVectorArray.ProjectionClass2D = TwoVectorArray # noqa: F821
TwoVectorArray.ProjectionClass3D = ThreeVectorArray # noqa: F821
TwoVectorArray.ProjectionClass4D = LorentzVectorArray # noqa: F821
Expand Down

0 comments on commit 8946d59

Please sign in to comment.