Add Models for TRIDENT Detector in Graphnet #767

Open: wants to merge 1 commit into main

@040601 040601 commented Nov 8, 2024

This PR extends the Graphnet project to include support for the TRIDENT detector by adding specific code files for training and testing. The changes include four new files which are described in detail below.

Changes Made:

  • TRIDENTNodeDefinition.py: Defines the node definition class TRIDENTNodeDefinition used for training with TRIDENT data.
  • TRIDENTGraphDefinition.py: Contains both the TRIDENT detector information and the graph definition class TRIDENTGraphDefinition to represent the TRIDENT detector structure and graph type.
  • MiddleReconModel.py: Defines the model MiddleReconModel used in TRIDENT training. Key functions include:
    compute_loss, forward, shared_step, construct_trainer, fit, _print_callbacks, _contains_callback, configure_optimizers, training_step, validation_step, predict_step, inference, train, predict, predict_as_dataframe, _create_default_callbacks and _add_early_stopping.
  • TridentNet.py: Contains three network classes used in TRIDENT training: StaticEdgeConv, DynamicEdgeConv, and TridentTrackNet.

Additional Notes:

  • We would also like to seek some advice on where it would make more sense to place the added code files.

Author

040601 commented Nov 8, 2024

tagging @wlhwl @cmo-ft

Collaborator

@RasmusOrsoe RasmusOrsoe left a comment

Thanks for this - it's wonderful to see contributions from TRIDENT!

General comments:
All pull request checks need to pass, although we can accept some Code Climate checks failing. I highly recommend using pre-commit to make sure your code conforms to the library's standards (see here for more details).

I have left detailed comments and suggestions.

If you need help, let us know.

from torch_geometric.data import Data
from graphnet.models.graphs.nodes import NodeDefinition

class TRIDENTNodeDefinition(NodeDefinition):
Collaborator

This class contains quite a few unnecessary arguments and print statements. I think a more descriptive class name along with some comments in the code on what this method actually does would be useful. Here's a suggestion on how to refactor:

from typing import List

import numpy as np
import torch

class ADescriptiveName(NodeDefinition):
    """
    A useful docstring that explains clearly what this definition outputs
    """

    def __init__(
        self,
        input_feature_names: List[str],
        xyz_columns: List[str] = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"],
        time_column: str = "t",
    ) -> None:
        super().__init__(input_feature_names=input_feature_names)
        # Member variables
        self._keys = input_feature_names
        self._xyz = [self._keys.index(key) for key in xyz_columns]
        self._time_index = self._keys.index(time_column)
        self._charge_index = len(self._keys)
        self._norm_index: int = self._charge_index + 1

    def _define_output_feature_names(
        self, input_feature_names: List[str]
    ) -> List[str]:
        # Explain what these features are
        return ["nx","ny","nz","t1st","nhits","norm_xyz"]

    def _construct_nodes(self, x: torch.Tensor) -> Data:
        x = x.numpy()
        # Add charge and norm columns
        x = np.insert(x, self._charge_index, np.zeros(x.shape[0]), axis=1)
        x = np.insert(x, self._norm_index, np.zeros(x.shape[0]), axis=1)

        # Sort by time, select origin as first hit dom
        x = x[x[:, self._time_index].argsort()]
        x[:, self._time_index] -= x[0, self._time_index]
        x[:,self._xyz] -= x[0, self._xyz]

        # Fill norm column
        x[:, self._norm_index] = np.linalg.norm(x[:, self._xyz], axis=1)

        x[:, self._xyz] /= x[:, self._norm_index].reshape(-1, 1).clip(min=1e-6)
        x[:, self._time_index] *= 0.2998

        # Fill charge column
        doms = x[:, self._xyz]
        unique_values, inverse, dom_counts = np.unique(doms, return_inverse=True, return_counts=True, axis=0)

        x[:, self._charge_index] = dom_counts[inverse]

        # group doms and set time to median time
        x_= []
        for ids in unique_values:
            mask = np.where((x[:, self._xyz] == ids).all(axis=1))[0]
            t_median = np.median(x[mask, self._time_index])
            x_.append([*ids, t_median, *x[mask[0], self._charge_index:]])

        x = np.array(x_)             
        return Data(x=torch.tensor(x,dtype=torch.float))
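
The grouping step in the suggestion above can be tried in isolation. The sketch below is not part of the PR: the helper name `group_hits_by_sensor` and the toy positions and times are made up for illustration. It counts hits per unique sensor position with `np.unique` and takes the median hit time of each group.

```python
import numpy as np


def group_hits_by_sensor(xyz: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Return one row per unique sensor: [x, y, z, median_t, n_hits]."""
    unique_xyz, inverse, counts = np.unique(
        xyz, axis=0, return_inverse=True, return_counts=True
    )
    # Flatten defensively: the shape of `inverse` when `axis` is given
    # has differed between NumPy releases.
    inverse = inverse.reshape(-1)
    # Median hit time of all hits that map to the same unique position.
    median_t = np.array(
        [np.median(t[inverse == i]) for i in range(len(unique_xyz))]
    )
    return np.column_stack([unique_xyz, median_t, counts])


# Toy data (assumed): three hits, two of them on the same sensor.
xyz = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
t = np.array([1.0, 3.0, 2.0])
grouped = group_hits_by_sensor(xyz, t)
```

Reusing `inverse` for the group membership avoids comparing every row against every unique position a second time.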

Collaborator

Things I found unclear in the code that could use a short comment or two:

  • What do the corresponding node features (["nx","ny","nz","t1st","nhits","norm_xyz"]) represent?
  • The "charge" column appears to count number of pulses on the DOM - when the charge is available in the data, would you use that instead of the number of pulses?
  • Why is the time column scaled by *= 0.2998?
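
On the last question, one plausible reading (not confirmed anywhere in this thread) is that 0.2998 is the vacuum speed of light in metres per nanosecond, so the multiplication would express hit times as distances on the same scale as the sensor coordinates. The sketch below assumes times in nanoseconds and positions in metres; the constant name is hypothetical.

```python
import numpy as np

# Speed of light in vacuum, in metres per nanosecond (299 792 458 m/s).
C_M_PER_NS = 0.299792458

# Assumed toy hit times in nanoseconds.
t_ns = np.array([0.0, 10.0, 100.0])

# Distance light travels in vacuum during each time interval, in metres.
t_as_distance_m = t_ns * C_M_PER_NS
```

If this reading is right, a named constant with its units in a comment would answer the question directly in the code.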

return x# / 1.05e04


class TRIDENTGraphDefinition(GraphDefinition):
Collaborator

Quite a lot of the forward pass seems to be a carbon copy of the original method. Instead of duplicating the code, let's solve this through inheritance. Also, parts of the code could use comments and more descriptive variable names. See the suggested refactor below:

import math
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch

class ADescriptiveName(GraphDefinition):
    """ A useful docstring to help understand what this is"""
    def __init__(
        self,
        detector: Detector,
        time_column: str,
        input_feature_names: Optional[List[str]] = None,
        refractive_index: float = 1.385,
    ) -> None:
        """ Doc string :-) """
        node_def = TRIDENTNodeDefinition(input_feature_names = input_feature_names,
                                         xyz_columns = detector.xyz,
                                         time_column = time_column)
        super().__init__(
            detector=detector,
            node_definition=node_def,
            )

        # Member variables:
        self._refractive_index = refractive_index

    def forward(  # type: ignore
        self,
        input_features: np.ndarray,
        input_feature_names: List[str],
        truth_dicts: Optional[List[Dict[str, Any]]] = None,
        custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
        loss_weight_column: Optional[str] = None,
        loss_weight: Optional[float] = None,
        loss_weight_default_value: Optional[float] = None,
        data_path: Optional[str] = None,
    ) -> Data:
        # Parent Class forward pass
        graph = super().forward(input_feature_names=input_feature_names,
                                input_features=input_features,
                                truth_dicts=truth_dicts,
                                custom_label_functions=custom_label_functions,
                                loss_weight_column=loss_weight_column,
                                loss_weight=loss_weight,
                                loss_weight_default_value=loss_weight_default_value,
                                data_path=data_path)
        
        # Specific modifications 
        if len(input_features) > 0:
            first_hit = input_features[torch.min(input_features[:, 3],dim=0)[1]]
            graph.pos = torch.stack([graph.nx*graph.norm_xyz,graph.ny*graph.norm_xyz,graph.nz*graph.norm_xyz],dim=1)
            graph.vertex = torch.stack([graph.initial_state_x,graph.initial_state_y,graph.initial_state_z],dim=1) - first_hit[0:3]
            graph.inject_pos = self._inject_pos(graph)

        return graph    

    def _inject_pos(self, graph):
        """ Explain this function """
        costh = 1 / self._refractive_index
        tanth = math.sqrt(1 - costh*costh) / costh
        graph.vertex = graph.vertex.view(-1)
        graph.direction = graph.direction.view(-1)

        vr = graph.pos - graph.vertex
        l = (vr * graph.direction).sum(dim=1).view(-1,1)
        d = ((vr**2).sum(dim=1).view(-1,1) - l**2).clip(min=0)**0.5
        inject_pos = graph.vertex + (l -  d / tanth) * graph.direction - graph.pos
        return inject_pos 
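
One possible geometric reading of `_inject_pos`, offered as an assumption rather than as the authors' explanation: for a track through `vertex` along a unit-length `direction`, it locates the point on the track from which a photon emitted at the Cherenkov angle (cos θ = 1/n) would reach each sensor, and returns that point relative to the sensor. The NumPy sketch below reproduces the geometry for a single sensor; the function name and the toy numbers are hypothetical.

```python
import numpy as np


def cherenkov_emission_point(vertex, direction, sensor_pos, refractive_index=1.385):
    """Point on the track from which a photon emitted at the Cherenkov
    angle would reach the sensor. `direction` must be a unit vector."""
    cos_th = 1.0 / refractive_index
    tan_th = np.sqrt(1.0 - cos_th**2) / cos_th
    rel = sensor_pos - vertex
    # Distance from the vertex to the sensor's projection on the track.
    along = rel @ direction
    # Perpendicular distance from the sensor to the track.
    perp = np.sqrt(max(rel @ rel - along**2, 0.0))
    return vertex + (along - perp / tan_th) * direction


# Toy geometry (assumed): track along +z, sensor 5 m off the axis.
vertex = np.array([0.0, 0.0, 0.0])
direction = np.array([0.0, 0.0, 1.0])
sensor_pos = np.array([5.0, 0.0, 20.0])

emission = cherenkov_emission_point(vertex, direction, sensor_pos)
photon = sensor_pos - emission
# Cosine of the angle between the photon path and the track.
cos_angle = photon @ direction / np.linalg.norm(photon)
```

In this reading, `_inject_pos` would correspond to `emission - sensor_pos`, and `cos_angle` comes out as `1 / refractive_index`.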

Collaborator

Parts that are unclear to me:

  • What is _inject_pos calculating?
  • Hard-coded column slices like input_features[:, 3] appear unwise to me. What column are you looking for? It would be much better to index the input variable names (i.e. input_feature_names.index(your_column))
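
A minimal sketch of the name-based lookup suggested in the second point, using NumPy and made-up values. The column names follow the defaults in the node definition suggestion earlier in this review; the data is purely illustrative.

```python
import numpy as np

# Assumed feature layout and toy data: one row per hit.
input_feature_names = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"]
input_features = np.array(
    [
        [0.0, 0.0, 10.0, 5.0],
        [1.0, 0.0, 12.0, 2.0],
        [2.0, 0.0, 14.0, 9.0],
    ]
)

# Look columns up by name instead of hard-coding positions such as 3 or 0:3.
time_index = input_feature_names.index("t")
xyz_index = [
    input_feature_names.index(name)
    for name in ("sensor_pos_x", "sensor_pos_y", "sensor_pos_z")
]

# Row with the earliest hit time, and its position.
first_hit = input_features[np.argmin(input_features[:, time_index])]
first_hit_xyz = first_hit[xyz_index]
```

The lookup keeps working if the feature order changes, and `list.index` raises a `ValueError` when a column is missing instead of silently reading the wrong one.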

import os
import math

class TRIDENT(Detector):
Collaborator

There is already a Detector-class for the prometheus dataset, so I do not think this is needed.

}
return feature_map

def _sensor_pos_xy(self, x: torch.tensor) -> torch.tensor:
Collaborator

Why would you not scale in the input?



class TridentTrackNet(Model):
def __init__(self, settings, DEVICE):
Collaborator

PyTorch Lightning should automatically place all model layers on the correct device, so I think it's unnecessary to pass this variable and set the layers manually in the __init__ method. Did you try this?
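
To illustrate the point about device placement, here is a plain PyTorch sketch with a hypothetical `SmallNet` class (not the PR's `TridentTrackNet`). Layers and buffers registered in `__init__` follow the module when it is moved, so no device argument is needed; a Lightning `Trainer` performs the equivalent move itself.

```python
import torch
from torch import nn


class SmallNet(nn.Module):
    """Hypothetical stand-in: no device argument in __init__."""

    def __init__(self, n_in: int = 6, n_hidden: int = 16) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, 3)
        )
        # Constants stored as buffers move together with the module.
        self.register_buffer("scale", torch.ones(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x) * self.scale


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One call moves every registered parameter and buffer.
net = SmallNet().to(device)
out = net(torch.zeros(4, 6, device=device))
```

Tensors created inside `forward` are the usual exception; creating them with `device=x.device` keeps them next to the input.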

"""Compute and sum losses."""
data_merged: Dict[str, Tensor] = {}
data_merged['inject_pos'] = torch.cat([d['inject_pos'] for d in data], dim=0).float()
# print(outputs)
Collaborator

Looks like you forgot some comments here

logger: Optional[LightningLogger] = None,
log_every_n_steps: int = 1,
gradient_clip_val: Optional[float] = None,
# distribution_strategy: Optional[str] = "ddp",
Collaborator

Comment

logger=logger,
log_every_n_steps=log_every_n_steps,
gradient_clip_val=gradient_clip_val,
# distribution_strategy=distribution_strategy,
Collaborator

Comment

from pytorch_lightning.strategies import DDPStrategy

class MiddleReconModel(Model):
"""A suggested Model class that comes with simple user syntax.
Collaborator

This docstring should be much more informative. Include some details on what this method requires and what it does. In your case, the method assumes a very specific backend with a very specific data representation for direction reconstruction only, right? Mention that here, and feel free to link to your paper.

def __init__(
self,
*,
backbone: Model = None,
Collaborator

I think this method is written specifically for your "TridentNeT" method, right? In that case, I think it's best that we do not leave the backbone as an argument, but instead hardcode this, such that users can import your model directly. I.e.

from graphnet.models import MiddleReconModel

model = MiddleReconModel(...)
train_dataloader, test_dataloader = ... 
model.fit(train_dataloader = ..)
predictions = model.predict( test_dataloader = ..)
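
The pattern behind this suggestion can be sketched without any GraphNeT classes. In the toy example below, `ToyBackbone` and `FixedBackboneModel` are hypothetical stand-ins: the model builds its own backbone in `__init__`, so callers pass only hyperparameters and never a backbone object.

```python
from typing import List


class ToyBackbone:
    """Hypothetical stand-in for the network: averages each input row."""

    def __init__(self, n_features: int) -> None:
        self.n_features = n_features

    def __call__(self, x: List[List[float]]) -> List[float]:
        return [sum(row) / self.n_features for row in x]


class FixedBackboneModel:
    """Constructs its backbone internally instead of accepting one."""

    def __init__(self, *, n_features: int = 6) -> None:
        # The backbone choice is hardcoded; only its settings are exposed.
        self._backbone = ToyBackbone(n_features)

    def predict(self, x: List[List[float]]) -> List[float]:
        return self._backbone(x)


model = FixedBackboneModel(n_features=2)
predictions = model.predict([[1.0, 3.0], [2.0, 2.0]])
```

The trade-off is flexibility: users cannot swap the backbone, but they also cannot pair the model with one it was never designed for.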
