Commit

emd loss + updated readme
rkansal47 committed Sep 21, 2021
1 parent 948d364 commit e9d5dee
Showing 11 changed files with 295 additions and 39 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1 +1,2 @@
include README.md
recursive-include jetnet/evaluation/fpnd_resources *.txt *.pt
38 changes: 36 additions & 2 deletions README.md
@@ -1,5 +1,39 @@
# JetNet

*One jet library to rule them all.*

A library for developing and reproducing jet-based machine learning projects. Currently under development.
A library for developing and reproducing jet-based machine learning (ML) projects.

JetNet provides common standardized PyTorch-based datasets, evaluation metrics, and loss functions for working with jets using ML. It currently supports the flagship JetNet dataset and the Fréchet ParticleNet Distance (FPND), Wasserstein-1 (W1), coverage, and minimum matching distance (MMD) metrics, all introduced in Ref. [[1](#References)], as well as jet utilities and a differentiable implementation of the energy mover's distance [[2](#References)] for use as a loss function. Additional functionality is currently under development.


## Installation

JetNet can be installed with pip:

```bash
pip install jetnet
```

To use the differentiable EMD loss `jetnet.losses.EMDLoss`, additional libraries must be installed via

```bash
pip install jetnet[emdloss]
```

Finally, [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) must be installed independently to use the Fréchet ParticleNet Distance metric `jetnet.evaluation.fpnd` ([installation instructions](https://github.com/pyg-team/pytorch_geometric#installation)).


## Documentation

The API reference is available at [jetnet.readthedocs.io](https://jetnet.readthedocs.io/en/latest/).

More detailed information about each dataset can (or will) be found at [jet-net.github.io](https://jet-net.github.io/jetnet/).

Tutorials for datasets and functions are coming soon.


### References

[1] R. Kansal et al. *Particle Cloud Generation with Message Passing Generative Adversarial Networks* (2021) [[2106.11535](https://arxiv.org/abs/2106.11535)]

[2] P. T. Komiske, E. M. Metodiev, and J. Thaler, _The Metric Space of Collider Events_, [Phys. Rev. Lett. __123__ (2019) 041801](https://doi.org/10.1103/PhysRevLett.123.041801) [[1902.02346](https://arxiv.org/abs/1902.02346)].
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -13,7 +13,7 @@
import os
import sys
sys.path.insert(0, os.path.abspath('../'))
autodoc_mock_imports = ['energyflow', 'awkward', 'coffea', 'tqdm', 'scipy', 'torch_geometric', 'torch']
autodoc_mock_imports = ['energyflow', 'awkward', 'coffea', 'tqdm', 'scipy', 'torch_geometric', 'torch', 'cvxpy', 'qpth']

# -- Project information -----------------------------------------------------

10 changes: 9 additions & 1 deletion docs/index.rst
@@ -6,7 +6,7 @@
JetNet
==================================

JetNet is a library for all your ML jet needs.
API reference for the JetNet library.

.. contents:: Contents
:local:
@@ -29,6 +29,14 @@ Evaluation
:exclude-members: JetNet


Losses
**********************************
.. automodule:: jetnet.losses
:members:
:imported-members:
:autosummary:


Utility Functions
**********************************
.. automodule:: jetnet.utils
1 change: 1 addition & 0 deletions jetnet/__init__.py
@@ -3,4 +3,5 @@
# IMPORTANT: evaluation has to be imported first since energyflow must be imported before torch because of https://github.com/pkomiske/EnergyFlow/issues/24
import jetnet.evaluation
import jetnet.datasets
import jetnet.losses
import jetnet.utils
1 change: 0 additions & 1 deletion jetnet/evaluation/__init__.py
@@ -1,2 +1 @@
# from .particlenet import *
from .gen_metrics import *
52 changes: 22 additions & 30 deletions jetnet/evaluation/gen_metrics.py
@@ -1,6 +1,5 @@
# energyflow needs to be imported before pytorch because of https://github.com/pkomiske/EnergyFlow/issues/24
from energyflow.emd import emds
from energyflow import EFPSet

import logging
import warnings
@@ -89,7 +88,7 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
_eval_module.fpnd_dict = {"NUM_SAMPLES": 50000}


def _init_fpnd_dict(dataset_name: str, jet_type: str, num_particles: int, num_particle_features: int, device: str = ""):
def _init_fpnd_dict(dataset_name: str, jet_type: str, num_particles: int, num_particle_features: int, device: str = "cpu"):
try:
from .particlenet import _ParticleNet
except ModuleNotFoundError:
@@ -117,24 +116,21 @@ def _init_fpnd_dict(dataset_name: str, jet_type: str, num_particles: int, num_pa


# TODO !!! check gen jets are not in place normalized !!!
def fpnd(
jets: Union[Tensor, np.ndarray], jet_type: str, use_mask: bool = True, dataset_name: str = "JetNet", device: str = "", batch_size: int = 16
) -> float:
def fpnd(jets: Union[Tensor, np.ndarray], jet_type: str, dataset_name: str = "JetNet", device: str = None, batch_size: int = 16) -> float:
"""
Calculates the Frechet ParticleNet Distance, as defined in https://arxiv.org/abs/2106.11535, for input ``jets`` of type ``jet_type``.
``jets`` are passed through our pretrained ParticleNet module and activations are compared with the cached activations from real jets.
The recommended and max number of jets is 50,000
**torch_geometric must be installed separately for running inference with ParticleNet**
**torch_geometric must be installed separately for running inference with ParticleNet**
Currently FPND is only supported for the JetNet dataset with 30 particles;
support for other datasets and for user-supplied versions is in development.
Args:
jets (Union[Tensor, np.ndarray]): Tensor or array of jets, of shape ``[num_jets, num_particles, num_features]`` with features in order ``[eta, phi, pt, (optional) mask]``
jet_type (str): jet type, out of ``['g', 't', 'q']``.
use_mask (bool): Use the last binary mask feature to zero the 0-masked particles. Defaults to True.
dataset_name (str): Dataset to use. Currently only JetNet is supported. Defaults to "JetNet".
device (str): 'cpu' or 'cuda'. If not specified, defaults to cuda if available else cpu.
batch_size (int): Batch size for ParticleNet inference. Defaults to 16.
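
For reference, the hunk header above names `_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6)`, but its body is not shown in this diff. Assuming it follows the standard Fréchet distance between two Gaussians fitted to the ParticleNet activations (the same form used for the Fréchet Inception Distance; this is an assumption, not something the diff confirms), the quantity being computed would be:

```latex
\[
d^2\big((\mu_1,\Sigma_1),(\mu_2,\Sigma_2)\big)
  = \lVert \mu_1 - \mu_2 \rVert_2^2
  + \operatorname{Tr}\!\left(\Sigma_1 + \Sigma_2 - 2\,(\Sigma_1 \Sigma_2)^{1/2}\right)
\]
```

Here $\mu_i$ and $\Sigma_i$ would be the mean and covariance of the activations for the two jet samples being compared (generated jets and the cached real jets).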
@@ -146,7 +142,7 @@ def fpnd(
assert dataset_name == "JetNet", "Only JetNet is currently supported with FPND"

num_particles = jets.shape[1]
num_particle_features = jets.shape[2] - int(use_mask)
num_particle_features = jets.shape[2]

assert num_particles == 30, "Currently FPND only supported for 30 particles - more functionality coming soon."
assert num_particle_features == 3, "Not the right number of particle features for the JetNet dataset."
@@ -157,7 +153,7 @@ def fpnd(
if isinstance(jets, np.ndarray):
jets = Tensor(jets)

if device == "":
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

assert device == "cuda" or device == "cpu", "Invalid device type"
@@ -166,10 +162,10 @@ def fpnd(
JetNet.normalize_features(jets, fpnd=True)
# TODO other datasets

if use_mask:
# features for all masked particles are set to 0 and mask feature is removed
mask = jets[:, :, -1:] > 0
jets = (jets * mask)[:, :, :-1]
# if use_mask:
# # features for all masked paricles are set to 0 and mask feature is removed
# mask = jets[:, :, -1:] > 0
# jets = (jets * mask)[:, :, :-1]

# ParticleNet module and the real mu's and sigma's are only loaded once
if (
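
Since this hunk comments out the `use_mask` branch and `fpnd` now asserts exactly 3 particle features, callers whose jets still carry a trailing mask feature would presumably need to apply it themselves before calling `fpnd`. A minimal sketch of the removed step, using plain nested lists in place of tensors so it runs without dependencies (`apply_mask` is a hypothetical helper name, not part of the library):

```python
def apply_mask(jets):
    """Zero the features of masked particles and drop the trailing mask feature.

    `jets` is a nested list of shape [num_jets, num_particles, num_features + 1];
    the last entry of each particle is assumed to be a binary mask.
    """
    masked = []
    for jet in jets:
        masked.append([
            # keep the features when mask > 0, otherwise replace them with zeros
            [f if particle[-1] > 0 else 0.0 for f in particle[:-1]]
            for particle in jet
        ])
    return masked


# one jet with two particles: [eta, phi, pt, mask]
jets = [[[0.1, -0.2, 0.5, 1.0], [0.3, 0.4, 0.2, 0.0]]]
out = apply_mask(jets)
```

This mirrors the two commented-out lines: build a boolean mask from the last feature, multiply it in, then slice the mask feature off.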
@@ -284,6 +280,7 @@ def w1p(
def w1m(
jets1: Union[Tensor, np.ndarray],
jets2: Union[Tensor, np.ndarray],
use_particle_masses: bool = False,
num_eval_samples: int = 10000,
num_batches: int = 5,
average_over_features: bool = True,
@@ -293,7 +290,7 @@ def w1m(
Get 1-Wasserstein distance between masses of ``jets1`` and ``jets2``.
Args:
jets1 (Union[Tensor, np.ndarray]): Tensor or array of jets, of shape ``[num_jets, num_particles, num_features]`` with features in order ``[eta, phi, pt]``
jets1 (Union[Tensor, np.ndarray]): Tensor or array of jets, of shape ``[num_jets, num_particles, num_features]`` with features in order ``[eta, phi, pt, (optional) mass]``
jets2 (Union[Tensor, np.ndarray]): Tensor or array of jets, of same format as ``jets1``.
num_eval_samples (int): Number of jets out of the total to use for W1 measurement. Defaults to 10000.
num_batches (int): Number of different batches to average W1 scores over. Defaults to 5.
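
The docstring above describes a W1 score computed on `num_eval_samples` jets and averaged over `num_batches` batches. A minimal sketch of that idea in plain Python follows. `w1_1d` and `batched_w1` are hypothetical names, and drawing each batch without replacement is an assumption, since the sampling code is not shown in this diff.

```python
import random


def w1_1d(xs, ys):
    """1-Wasserstein distance between two equal-size 1D samples.

    For equal-size empirical distributions this equals the mean absolute
    difference between the sorted samples.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)


def batched_w1(xs, ys, num_eval_samples, num_batches, seed=0):
    """Average W1 over `num_batches` random sub-samples of each input."""
    rng = random.Random(seed)
    scores = [
        # assumption: each batch is drawn without replacement
        w1_1d(rng.sample(xs, num_eval_samples), rng.sample(ys, num_eval_samples))
        for _ in range(num_batches)
    ]
    return sum(scores) / len(scores)


masses1 = [float(i) for i in range(10)]
masses2 = [m + 2.0 for m in masses1]
score = batched_w1(masses1, masses2, num_eval_samples=10, num_batches=3)
```

Averaging over several batches gives a rough sense of how much the score fluctuates with the choice of evaluation jets.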
@@ -333,7 +330,7 @@ def w1m(
def w1efp(
jets1: Union[Tensor, np.ndarray],
jets2: Union[Tensor, np.ndarray],
particle_masses: bool = False,
use_particle_masses: bool = False,
efpset_args: list = [("n==", 4), ("d==", 4), ("p==", 1)],
num_eval_samples: int = 10000,
num_batches: int = 5,
@@ -347,7 +344,7 @@ def w1efp(
jets1 (Union[Tensor, np.ndarray]): Tensor or array of jets of shape ``[num_jets, num_particles, num_features]``, with features in order ``[eta, phi, pt, (optional) mass]``.
If no particle masses are given (``use_particle_masses`` should be False), they are assumed to be 0.
jets2 (Union[Tensor, np.ndarray]): Tensor or array of jets, of same format as ``jets1``.
particle_masses (bool): Whether ``jets1`` and ``jets2`` have particle masses as their 4th particle features. Defaults to False.
use_particle_masses (bool): Whether ``jets1`` and ``jets2`` have particle masses as their 4th particle features. Defaults to False.
efpset_args (List): Args for the energyflow.efpset function to specify which EFPs to use, as defined here https://energyflow.network/docs/efp/#efpset.
Defaults to the n=4, d=4, prime EFPs.
num_eval_samples (int): Number of jets out of the total to use for W1 measurement. Defaults to 10000.
@@ -371,21 +368,12 @@ def w1efp(
jets2 = jets2.cpu().detach().numpy()

assert len(jets1.shape) == 3 and len(jets2.shape) == 3, "input jets format is incorrect"
assert (jets1.shape[2] == 3 and not particle_masses) or (jets1.shape[2] == 4 and particle_masses), "particle feature format is incorrect"
assert (jets2.shape[2] == 3 and not particle_masses) or (jets2.shape[2] == 4 and particle_masses), "particle feature format is incorrect"
assert (jets1.shape[2] - int(use_particle_masses) >= 3) and (
jets2.shape[2] - int(use_particle_masses) >= 3
), "particle feature format is incorrect"

# convert from JetNet [eta, phi, pt] format to energyflow [pt, eta, phi]
if particle_masses:
jets1 = jets1[:, :, [2, 0, 1, 3]]
jets2 = jets2[:, :, [2, 0, 1, 3]]
else:
# pad 0 mass as the 4th feature for each particle
jets1 = np.pad(jets1[:, :, [2, 0, 1]], ((0, 0), (0, 0), (0, 1)))
jets2 = np.pad(jets2[:, :, [2, 0, 1]], ((0, 0), (0, 0), (0, 1)))

efpset = EFPSet(*efpset_args, measure="hadr", beta=1, normed=None, coords="ptyphim")
efps1 = efpset.batch_compute(jets1)
efps2 = efpset.batch_compute(jets2)
efps1 = utils.efps(jets1, use_particle_masses=use_particle_masses, efpset_args=efpset_args)
efps2 = utils.efps(jets2, use_particle_masses=use_particle_masses, efpset_args=efpset_args)
num_efps = efps1.shape[1]

w1s = []
@@ -431,13 +419,17 @@ def cov_mmd(
- **float**: MMD, averaged over ``num_batches``.
"""
assert len(real_jets.shape) == 3 and len(gen_jets.shape) == 3, "input jets format is incorrect"
assert (real_jets.shape[2] >= 3) and (gen_jets.shape[2] >= 3), "particle feature format is incorrect"

if isinstance(real_jets, Tensor):
real_jets = real_jets.cpu().detach().numpy()

if isinstance(gen_jets, Tensor):
gen_jets = gen_jets.cpu().detach().numpy()

assert np.all(real_jets[:, :, 2] >= 0) and np.all(gen_jets[:, :, 2] >= 0), "particle pTs must all be >= 0 for EMD calculation"

# convert from JetNet [eta, phi, pt] format to energyflow [pt, eta, phi]
real_jets = real_jets[:, :, [2, 0, 1]]
gen_jets = gen_jets[:, :, [2, 0, 1]]
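
The index pattern `[2, 0, 1]` above moves each particle from JetNet's `[eta, phi, pt]` order to energyflow's `[pt, eta, phi]`, after checking that all pTs are non-negative. A dependency-free sketch of the same conversion is below. `to_pt_eta_phi` and its `pad_zero_mass` flag are hypothetical names for illustration; the optional zero-mass padding mirrors what the removed `w1efp` code did with `np.pad`.

```python
def to_pt_eta_phi(jets, pad_zero_mass=False):
    """Reorder particles from [eta, phi, pt, ...] to [pt, eta, phi].

    `jets` is a nested list of shape [num_jets, num_particles, num_features].
    Features beyond the first three are dropped, as in the slice `[2, 0, 1]`.
    """
    converted = []
    for jet in jets:
        new_jet = []
        for particle in jet:
            eta, phi, pt = particle[0], particle[1], particle[2]
            # the diff asserts pT >= 0 before the EMD calculation
            if pt < 0:
                raise ValueError("particle pTs must all be >= 0")
            new_particle = [pt, eta, phi]
            if pad_zero_mass:
                # append a zero mass as the 4th feature
                new_particle.append(0.0)
            new_jet.append(new_particle)
        converted.append(new_jet)
    return converted


jets = [[[0.1, 0.2, 5.0], [-0.3, 0.4, 2.0]]]
reordered = to_pt_eta_phi(jets)
padded = to_pt_eta_phi(jets, pad_zero_mass=True)
```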
1 change: 1 addition & 0 deletions jetnet/losses/__init__.py
@@ -0,0 +1 @@
from .losses import *