diff --git a/src/graphnet/data/constants.py b/src/graphnet/data/constants.py
index 10ed4c66e..02a9e5a46 100644
--- a/src/graphnet/data/constants.py
+++ b/src/graphnet/data/constants.py
@@ -50,6 +50,8 @@ class TRUTH:
         "interaction_type",
         "interaction_time",  # Added for vertex reconstruction
         "inelasticity",
+        "visible_inelasticity",
+        "visible_energy",
         "stopped_muon",
     ]
     DEEPCORE = ICECUBE86
diff --git a/src/graphnet/data/extractors/icecube/i3truthextractor.py b/src/graphnet/data/extractors/icecube/i3truthextractor.py
index 4db330fc0..1c101def7 100644
--- a/src/graphnet/data/extractors/icecube/i3truthextractor.py
+++ b/src/graphnet/data/extractors/icecube/i3truthextractor.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 import matplotlib.path as mpath
+from scipy.spatial import ConvexHull, Delaunay
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from .i3extractor import I3Extractor
@@ -12,10 +13,12 @@ from graphnet.utilities.imports import has_icecube_package
 
 if has_icecube_package() or TYPE_CHECKING:
-    from icecube import (
+    from icecube import (  # noqa: F401
        dataclasses,
        icetray,
        phys_services,
+        dataio,
+        LeptonInjector,
    )  # pyright: reportMissingImports=false
@@ -27,6 +30,7 @@ def __init__(
         name: str = "truth",
         borders: Optional[List[np.ndarray]] = None,
         mctree: Optional[str] = "I3MCTree",
+        extend_boundary: Optional[float] = 0.0,
     ):
         """Construct I3TruthExtractor.
 
@@ -37,6 +41,8 @@
                 stopping within the detector. Defaults to hard-coded boundary
                 coordinates.
             mctree: Str of which MCTree to use for truth values.
+            extend_boundary: Distance to extend the convex hull of the detector
+                for defining starting events.
         """
         # Base class constructor
         super().__init__(name)
@@ -78,15 +84,53 @@ def __init__(
             self._borders = [border_xy, border_z]
         else:
             self._borders = borders
+
+        self._extend_boundary = extend_boundary
         self._mctree = mctree
 
+    def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None:
+        """Extract GFrame and CFrame from i3/gcd-file pair.
+
+        Information from these frames will be set as member variables of
+        `I3Extractor`.
+
+        Args:
+            i3_file: Path to i3 file that is being converted.
+            gcd_file: Path to GCD file. Defaults to None. If no GCD file is
+                given, the method will attempt to find C and G frames in
+                the i3 file instead. If either one of those is not
+                present, a `RuntimeError` will be raised.
+ """ + super().set_gcd(i3_file=i3_file, gcd_file=gcd_file) + + # Modifications specific to I3TruthExtractor + # These modifications are needed to identify starting events + coordinates = [] + for _, g in self._gcd_dict.items(): + if g.position.z > 1200: + continue # We want to exclude icetop + coordinates.append([g.position.x, g.position.y, g.position.z]) + coordinates = np.array(coordinates) + + if self._extend_boundary != 0.0: + center = np.mean(coordinates, axis=0) + d = coordinates - center + norms = np.linalg.norm(d, axis=1, keepdims=True) + dn = d / norms + coordinates = coordinates + dn * self._extend_boundary + + hull = ConvexHull(coordinates) + + self.hull = hull + self.delaunay = Delaunay(coordinates[self.hull.vertices]) + def __call__( self, frame: "icetray.I3Frame", padding_value: Any = -1 ) -> Dict[str, Any]: """Extract truth-level information.""" is_mc = frame_is_montecarlo(frame, self._mctree) is_noise = frame_is_noise(frame, self._mctree) - sim_type = self._find_data_type(is_mc, self._i3_file) + sim_type = self._find_data_type(is_mc, self._i3_file, frame) output = { "energy": padding_value, @@ -119,6 +163,7 @@ def __call__( "L5_oscNext_bool": padding_value, "L6_oscNext_bool": padding_value, "L7_oscNext_bool": padding_value, + "is_starting": padding_value, } # Only InIceSplit P frames contain ML appropriate @@ -230,6 +275,13 @@ def __call__( } ) + is_starting = self._contained_vertex(output) + output.update( + { + "is_starting": is_starting, + } + ) + return output def _extract_dbang_decay_length( @@ -374,15 +426,34 @@ def _get_primary_particle_interaction_type_and_elasticity( # all variables and has no nans (always muon) else: MCInIcePrimary = None - try: - interaction_type = frame["I3MCWeightDict"]["InteractionType"] - except KeyError: - interaction_type = padding_value - try: - elasticity = frame["I3GENIEResultDict"]["y"] - except KeyError: - elasticity = padding_value + if sim_type == "LeptonInjector": + event_properties = frame["EventProperties"] + final_state_1 = event_properties.finalType1 + if final_state_1 in [ + dataclasses.I3Particle.NuE, + dataclasses.I3Particle.NuMu, + dataclasses.I3Particle.NuTau, + dataclasses.I3Particle.NuEBar, + dataclasses.I3Particle.NuMuBar, + dataclasses.I3Particle.NuTauBar, + ]: + interaction_type = 2 # NC + else: + interaction_type = 1 # CC + + elasticity = 1 - event_properties.finalStateY + + else: + try: + interaction_type = frame["I3MCWeightDict"]["InteractionType"] + except KeyError: + interaction_type = int(padding_value) + + try: + elasticity = 1 - frame["I3MCWeightDict"]["BjorkenY"] + except KeyError: + elasticity = padding_value return MCInIcePrimary, interaction_type, elasticity @@ -418,12 +489,15 @@ def _get_primary_track_energy_and_inelasticity( return energy_track, energy_cascade, inelasticity # Utility methods - def _find_data_type(self, mc: bool, input_file: str) -> str: + def _find_data_type( + self, mc: bool, input_file: str, frame: "icetray.I3Frame" + ) -> str: """Determine the data type. Args: mc: Whether `input_file` is Monte Carlo simulation. input_file: Path to I3-file. + frame: Physics frame containing MC record Returns: The simulation/data type. 
@@ -439,8 +513,26 @@ def _find_data_type(self, mc: bool, input_file: str) -> str:
             sim_type = "genie"
         elif "noise" in input_file:
             sim_type = "noise"
-        elif "L2" in input_file:  # not robust
-            sim_type = "dbang"
-        else:
+        elif frame.Has("EventProperties") or frame.Has(
+            "LeptonInjectorProperties"
+        ):
+            sim_type = "LeptonInjector"
+        elif frame.Has("I3MCWeightDict"):
             sim_type = "NuGen"
+        else:
+            raise NotImplementedError("Could not determine data type.")
         return sim_type
+
+    def _contained_vertex(self, truth: Dict[str, Any]) -> bool:
+        """Determine if an event is starting based on vertex position.
+
+        Args:
+            truth: Dictionary of already extracted truth-level information.
+
+        Returns:
+            True if the vertex is inside the detector, False otherwise.
+        """
+        vertex = np.array(
+            [truth["position_x"], truth["position_y"], truth["position_z"]]
+        )
+        return self.delaunay.find_simplex(vertex) >= 0
diff --git a/src/graphnet/data/extractors/icecube/utilities/i3_filters.py b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py
index db115bd21..dfbfd06c0 100644
--- a/src/graphnet/data/extractors/icecube/utilities/i3_filters.py
+++ b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py
@@ -64,6 +64,30 @@ def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
         return True
 
 
+class SubEventStreamI3Filter(I3Filter):
+    """A filter that only keeps frames from select splits."""
+
+    def __init__(self, selection: List[str]):
+        """Initialize SubEventStreamI3Filter.
+
+        Args:
+            selection: List of subevent streams to keep.
+        """
+        self._selection = selection
+
+    def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
+        """Check if current frame should be kept.
+
+        Args:
+            frame: I3-frame
+                The I3-frame to check.
+        """
+        if frame.Has("I3EventHeader"):
+            if frame["I3EventHeader"].sub_event_stream not in self._selection:
+                return False
+        return True
+
+
 class I3FilterMask(I3Filter):
     """Checks list of filters from the FilterMask in I3 frames."""
 
diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py
index 558ec96f4..e8f8d749d 100644
--- a/src/graphnet/models/graphs/nodes/nodes.py
+++ b/src/graphnet/models/graphs/nodes/nodes.py
@@ -326,6 +326,7 @@ def __init__(
             "z_offset": None,
             "z_scaling": None,
         },
+        sample_pulses: bool = True,
     ) -> None:
         """Construct `IceMixNodes`.
 
@@ -339,6 +340,9 @@
                 ice in IceCube are added to the feature set based on z coordinate.
             ice_args: Offset and scaling of the z coordinate in the Detector,
                 to be able to make similar conversion in the ice data.
+            sample_pulses: Enable random pulse sampling. If True and an event
+                has more than max_length pulses, max_length pulses are sampled
+                at random. If False, the first max_length pulses are kept.
""" if input_feature_names is None: input_feature_names = [ @@ -384,6 +388,7 @@ def __init__( self.z_name = z_name self.hlc_name = hlc_name self.add_ice_properties = add_ice_properties + self.sampling_enabled = sample_pulses def _define_output_feature_names( self, input_feature_names: List[str] @@ -437,7 +442,14 @@ def _construct_nodes(self, x: torch.Tensor) -> Tuple[Data, List[str]]: x[:, self.feature_indexes[self.hlc_name]] = torch.logical_not( x[:, self.feature_indexes[self.hlc_name]] ) # hlc in kaggle was flipped - ids = self._pulse_sampler(x, event_length) + if self.sampling_enabled: + ids = self._pulse_sampler(x, event_length) + else: + if event_length < self.max_length: + ids = torch.arange(event_length) + else: + ids = torch.arange(self.max_length) + event_length = min(self.max_length, event_length) graph = torch.zeros([event_length, self.n_features]) diff --git a/src/graphnet/models/task/reconstruction.py b/src/graphnet/models/task/reconstruction.py index 5408aa5b9..18a3794e1 100644 --- a/src/graphnet/models/task/reconstruction.py +++ b/src/graphnet/models/task/reconstruction.py @@ -232,3 +232,19 @@ class InelasticityReconstruction(StandardLearnedTask): def _forward(self, x: Tensor) -> Tensor: # Transform output to unit range return torch.sigmoid(x) + + +class VisibleInelasticityReconstruction(StandardLearnedTask): + """Reconstructs interaction visible inelasticity. + + That is, 1-(visible track energy / visible hadronic energy). + """ + + # Requires one features: inelasticity itself + default_target_labels = ["visible_inelasticity"] + default_prediction_labels = ["visible_inelasticity_pred"] + nb_inputs = 1 + + def _forward(self, x: Tensor) -> Tensor: + # Transform output to unit range + return 0.5 * (torch.tanh(2.0 * x) + 1.0) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 6d7c27c06..534d095eb 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -63,6 +63,14 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Syntax like `.forward`, for implentation in inheriting classes.""" +class MAELoss(LossFunction): + """Mean absolute error loss.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Implement loss calculation.""" + return torch.mean(torch.abs(prediction - target), dim=-1) + + class MSELoss(LossFunction): """Mean squared error loss."""