-
Notifications
You must be signed in to change notification settings - Fork 94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Models for TRIDENT Detector in Graphnet #767
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this - it's wonderful to see contributions from TRIDENT!
General comments:
We need all the pull request checks to pass. Certain code climate checks we can accept not to pass. I highly recommend using pre-commit
to make sure your code conforms to the standards in the library (see here for more details).
I have left detailed comments and suggestions.
If you need help, let us know.
from torch_geometric.data import Data | ||
from graphnet.models.graphs.nodes import NodeDefinition | ||
|
||
class TRIDENTNodeDefinition(NodeDefinition): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class contains quite a few unnecessary arguments and print statements. I think a more descriptive class name along with some comments in the code on what this method actually does would be useful. Here's a suggestion on how to refactor:
class ADescriptiveName(NodeDefinition):
"""
A useful docstring that explains clearly what this definition outputs
"""
def __init__(
self,
input_feature_names: List[str],
xyz_columns: List[str] = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"],
time_column: str = "t",
) -> None:
super().__init__(input_feature_names=input_feature_names)
# Member variables
self._keys = input_feature_names
self._xyz = [self._keys.index(key) for key in xyz_columns]
self._time_index = self._keys.index(time_column)
self._charge_index = len(self._keys)
self._norm_index: int = self._charge_index + 1
def _define_output_feature_names(
self, input_feature_names: List[str]
) -> List[str]:
# Explain what these features are
return ["nx","ny","nz","t1st","nhits","norm_xyz"]
def _construct_nodes(self, x: torch.Tensor) -> Data:
x = x.numpy()
# Add charge and norm columns
x = np.insert(x, self._charge_index, np.zeros(x.shape[0]), axis=1)
x = np.insert(x, self._norm_index, np.zeros(x.shape[0]), axis=1)
# Sort by time, select origin as first hit dom
x = x[x[:, self._time_index].argsort()]
x[:, self._time_index] -= x[0, self._time_index]
x[:,self._xyz] -= x[0, self._xyz]
# Fill norm column
x[:, self._norm_index] = np.linalg.norm(x[:, self._xyz], axis=1)
x[:, self._xyz] /= x[:, self._norm_index].reshape(-1, 1).clip(min=1e-6)
x[:, self._time_index] *= 0.2998
# Fill charge column
doms = x[:, self._xyz]
unique_values, inverse, dom_counts = np.unique(doms, return_inverse=True, return_counts=True, axis=0)
x[:, self._charge_index] = dom_counts[inverse]
# group doms and set time to median time
x_= []
for ids in unique_values:
mask = np.where((x[:, self._xyz] == ids).all(axis=1))[0]
t_median = np.median(x[mask, self._time_index])
x_.append([*ids, t_median, *x[mask[0], self._charge_index:]])
x = np.array(x_)
return Data(x=torch.tensor(x,dtype=torch.float))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Things I found unclear in the code, that could use a short comment or two
- What does the corresponding nodes (["nx","ny","nz","t1st","nhits","norm_xyz"]) represent?
- The "charge" column appears to count number of pulses on the DOM - when the charge is available in the data, would you use that instead of the number of pulses?
- Why is the time column scaled by *= 0.2998?
return x# / 1.05e04 | ||
|
||
|
||
class TRIDENTGraphDefinition(GraphDefinition): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quite a lot of the forward
pass seems to be a carbon-copy of the original method. Instead of duplicating the code, let's solve this through inheritance. Also, parts of the code could use comments and more descriptive variable names. See the suggested refractor below:
class ADescriptiveName(GraphDefinition):
""" A useful docstring to help understand what this is"""
def __init__(
self,
detector: Detector,
time_column: str,
input_feature_names: Optional[List[str]] = None,
refractive_index: float = 1.385,
) -> None:
""" Doc string :-) """
node_def = TRIDENTNodeDefinition(input_feature_names = input_feature_names,
xyz_columns = detector.xyz,
time_column = time_column)
super().__init__(
detector=detector,
node_definition=node_def,
)
# Member variables:
self._refractive_index = refractive_index
def forward( # type: ignore
self,
input_features: np.ndarray,
input_feature_names: List[str],
truth_dicts: Optional[List[Dict[str, Any]]] = None,
custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
loss_weight_column: Optional[str] = None,
loss_weight: Optional[float] = None,
loss_weight_default_value: Optional[float] = None,
data_path: Optional[str] = None,
) -> Data:
# Parent Class forward pass
graph = super().forward(input_feature_names=input_feature_names,
input_features=input_features,
truth_dicts=truth_dicts,
custom_label_functions=custom_label_functions,
loss_weight_column=loss_weight_column,
loss_weight=loss_weight,
loss_weight_default_value=loss_weight_default_value,
data_path=data_path)
# Specific modifications
if len(input_features) > 0:
first_hit = input_features[torch.min(input_features[:, 3],dim=0)[1]]
graph.pos = torch.stack([graph.nx*graph.norm_xyz,graph.ny*graph.norm_xyz,graph.nz*graph.norm_xyz],dim=1)
graph.vertex = torch.stack([graph.initial_state_x,graph.initial_state_y,graph.initial_state_z],dim=1) - first_hit[0:3]
graph.inject_pos = self._inject_pos(graph)
return graph
def _inject_pos(self, graph):
""" Explain this function """
costh = 1 / self._refractive_index
tanth = math.sqrt(1 - costh*costh) / costh
graph.vertex = graph.vertex.view(-1)
graph.direction = graph.direction.view(-1)
vr = graph.pos - graph.vertex
l = (vr * graph.direction).sum(dim=1).view(-1,1)
d = ((vr**2).sum(dim=1).view(-1,1) - l**2).clip(min=0)**0.5
inject_pos = graph.vertex + (l - d / tanth) * graph.direction - graph.pos
return inject_pos
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parts that are unclear to me:
- What is
_inject_pos
calculating? - Hard-coded column slices like
input_features[:, 3]
appears to be unwise to me. What column are you looking for? It would be much better to index the input variable names (i.e.input_feature_names.index(your_column)
)
import os | ||
import math | ||
|
||
class TRIDENT(Detector): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is already a Detector-class for the prometheus dataset, so I do not think this is needed.
} | ||
return feature_map | ||
|
||
def _sensor_pos_xy(self, x: torch.tensor) -> torch.tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why would you not scale in the input?
|
||
|
||
class TridentTrackNet(Model): | ||
def __init__(self, settings, DEVICE): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PyTorch Lightning should automatically place all model layers on the correct device, so I think it's unnecessary to pass this variable and set the layers manually in the __init__
method. Did you try this?
"""Compute and sum losses.""" | ||
data_merged: Dict[str, Tensor] = {} | ||
data_merged['inject_pos'] = torch.cat([d['inject_pos'] for d in data], dim=0).float() | ||
# print(outputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like you forgot some comments here
logger: Optional[LightningLogger] = None, | ||
log_every_n_steps: int = 1, | ||
gradient_clip_val: Optional[float] = None, | ||
# distribution_strategy: Optional[str] = "ddp", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment
logger=logger, | ||
log_every_n_steps=log_every_n_steps, | ||
gradient_clip_val=gradient_clip_val, | ||
# distribution_strategy=distribution_strategy, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment
from pytorch_lightning.strategies import DDPStrategy | ||
|
||
class MiddleReconModel(Model): | ||
"""A suggested Model class that comes with simple user syntax. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doc-string should be much more informative. Include here some details on what this method requires what it does. In your case, the method assumes a very specific backend with a very specific data representation for direction reconstruction only, right? Mention that here and feel free to link to your paper
def __init__( | ||
self, | ||
*, | ||
backbone: Model = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this method is written specifically for your "TridentNeT" method, right? In that case, I think it's best that we do not leave the backbone
as an argument, but instead hardcode this, such that users can import your model directly. I.e.
from graphnet.models import MiddleReconModel
model = MiddleReconModel(...)
train_dataloader, test_dataloader = ...
model.fit(train_dataloader = ..)
predictions = model.predict( test_dataloader = ..)
This PR extends the Graphnet project to include support for the TRIDENT detector by adding specific code files for training and testing. The changes include four new files which are described in detail below.
Changes Made:
TRIDENTNodeDefinition.py
: Defines the node structureTRIDENTGraphDefinition
used for training with TRIDENT data.TRIDENTGraphDefinition.py
: Contains both theTRIDENT
detector information and the graph definition classTRIDENTGraphDefinition
to represent the TRIDENT detector structure and graph type.MiddleReconModel.py
: Defines the modelMiddleReconModel
used in TRIDENT training. Key fuctions include:compute_loss
,forward
,shared_step
,construct_trainer
,fit
,_print_callbacks
,_contains_callback
,configure_optimizers
,training_step
,validation_step
,predict_step
,inference
,train
,predict
,predict_as_dataframe
,_create_default_callbacks
and_add_early_stopping
.TridentNet.py
: Contains three network classes used in TRIDENT training:StaticEdgeConv
,DynamicEdgeConv
, andTridentTrackNet
.Additional Notes: