linting
farakiko committed Dec 2, 2024
1 parent e30344c commit a603778
Showing 2 changed files with 21 additions and 66 deletions.
71 changes: 17 additions & 54 deletions mlpf/data/key4hep/postprocessing.py
@@ -182,9 +182,7 @@ def __init__(
         self.cluster_features = cluster_features  # feature matrix of the calo clusters
         self.track_features = track_features  # feature matrix of the tracks
         self.genparticle_to_hit = genparticle_to_hit  # sparse COO matrix of genparticles to hits (idx_gp, idx_hit, weight)
-        self.genparticle_to_track = (
-            genparticle_to_track  # sparse COO matrix of genparticles to tracks (idx_gp, idx_track, weight)
-        )
+        self.genparticle_to_track = genparticle_to_track  # sparse COO matrix of genparticles to tracks (idx_gp, idx_track, weight)
         self.hit_to_cluster = hit_to_cluster  # sparse COO matrix of hits to clusters (idx_hit, idx_cluster, weight)
         self.gp_merges = gp_merges  # sparse COO matrix of any merged genparticles

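Note: the (idx_gp, idx_hit, weight) triplets referenced in these comments are the standard scipy COO encoding. A minimal toy sketch, with invented values, of how such a triplet list becomes a dense adjacency matrix:

    # Toy example of the (row, col, weight) triplet encoding used for the
    # genparticle-to-hit adjacency; all values here are invented.
    import numpy as np
    from scipy.sparse import coo_matrix

    idx_gp = np.array([0, 0, 2])        # genparticle indices
    idx_hit = np.array([1, 3, 0])       # hit indices
    weight = np.array([0.9, 0.1, 1.0])  # energy-sharing weights

    gp_to_hit = coo_matrix((weight, (idx_gp, idx_hit)), shape=(3, 4))
    print(gp_to_hit.todense())          # dense 3x4 genparticle-by-hit matrix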
@@ -250,10 +248,7 @@ def get_calohit_matrix_and_genadj(dataset, hit_data, calohit_links, iev, collect
             hit_idx_global += 1
     hit_idx_local_to_global = {v: k for k, v in hit_idx_global_to_local.items()}
     hit_feature_matrix = awkward.Record(
-        {
-            k: awkward.concatenate([hit_feature_matrix[i][k] for i in range(len(hit_feature_matrix))])
-            for k in hit_feature_matrix[0].fields
-        }
+        {k: awkward.concatenate([hit_feature_matrix[i][k] for i in range(len(hit_feature_matrix))]) for k in hit_feature_matrix[0].fields}
     )

     # add all edges from genparticle to calohit
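Note: the collapsed dict comprehension merges a list of per-collection records field by field. A self-contained toy sketch of the same pattern (field names invented):

    # Merge per-collection records into one by concatenating each field.
    import awkward

    parts = [
        awkward.Record({"energy": [1.0, 2.0], "subdetector": [0, 0]}),
        awkward.Record({"energy": [3.5], "subdetector": [1]}),
    ]
    merged = awkward.Record(
        {k: awkward.concatenate([p[k] for p in parts]) for k in parts[0].fields}
    )
    print(merged["energy"])  # [1.0, 2.0, 3.5]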
@@ -348,9 +343,7 @@ def gen_to_features(dataset, prop_data, iev):
     gen_arr = {k.replace(mc_coll + ".", ""): gen_arr[k] for k in gen_arr.fields}

     MCParticles_p4 = vector.awk(
-        awkward.zip(
-            {"mass": gen_arr["mass"], "x": gen_arr["momentum.x"], "y": gen_arr["momentum.y"], "z": gen_arr["momentum.z"]}
-        )
+        awkward.zip({"mass": gen_arr["mass"], "x": gen_arr["momentum.x"], "y": gen_arr["momentum.y"], "z": gen_arr["momentum.z"]})
     )
     gen_arr["pt"] = MCParticles_p4.pt
     gen_arr["eta"] = MCParticles_p4.eta
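Note: vector.awk wraps (mass, x, y, z) records so that derived kinematics such as pt and eta are available as properties. A toy sketch of the pattern (values invented):

    # Build four-vectors from mass and momentum components, then read
    # derived kinematics, mirroring the MCParticles_p4 usage above.
    import awkward
    import vector

    arr = awkward.zip(
        {"mass": awkward.Array([0.0, 0.105]),
         "x": awkward.Array([1.0, 2.0]),
         "y": awkward.Array([0.5, -1.0]),
         "z": awkward.Array([3.0, 0.2])}
    )
    p4 = vector.awk(arr)
    print(p4.pt, p4.eta, p4.phi)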
@@ -535,9 +528,7 @@ def track_to_features(dataset, prop_data, iev):
         if dataset == "clic":
             ret[k] = awkward.to_numpy(prop_data["SiTracks_1"]["SiTracks_1." + k][iev][trackstate_idx])
         elif dataset == "fcc":
-            ret[k] = awkward.to_numpy(
-                prop_data["_SiTracks_Refitted_trackStates"]["_SiTracks_Refitted_trackStates." + k][iev][trackstate_idx]
-            )
+            ret[k] = awkward.to_numpy(prop_data["_SiTracks_Refitted_trackStates"]["_SiTracks_Refitted_trackStates." + k][iev][trackstate_idx])

         else:
             raise Exception("--dataset provided is not supported. Only 'fcc' or 'clic' are supported atm.")
@@ -626,9 +617,7 @@ def add_daughters_to_status1(gen_features, genparticle_to_hit, genparticle_to_tr

 def get_genparticles_and_adjacencies(dataset, prop_data, hit_data, calohit_links, sitrack_links, iev, collectionIDs):
     gen_features = gen_to_features(dataset, prop_data, iev)
-    hit_features, genparticle_to_hit, hit_idx_local_to_global = get_calohit_matrix_and_genadj(
-        dataset, hit_data, calohit_links, iev, collectionIDs
-    )
+    hit_features, genparticle_to_hit, hit_idx_local_to_global = get_calohit_matrix_and_genadj(dataset, hit_data, calohit_links, iev, collectionIDs)
     hit_to_cluster = hit_cluster_adj(dataset, prop_data, hit_idx_local_to_global, iev)
     cluster_features = cluster_to_features(prop_data, hit_features, hit_to_cluster, iev)
     track_features = track_to_features(dataset, prop_data, iev)
@@ -638,21 +627,15 @@ def get_genparticles_and_adjacencies(dataset, prop_data, hit_data, calohit_links
     mask_status1 = gen_features["generatorStatus"] == 1

     if gen_features["index"] is not None:  # if there are even daughters
-        genparticle_to_hit, genparticle_to_trk = add_daughters_to_status1(
-            gen_features, genparticle_to_hit, genparticle_to_trk
-        )
+        genparticle_to_hit, genparticle_to_trk = add_daughters_to_status1(gen_features, genparticle_to_hit, genparticle_to_trk)

     n_gp = awkward.count(gen_features["PDG"])
     n_track = awkward.count(track_features["type"])
     n_hit = awkward.count(hit_features["type"])
     n_cluster = awkward.count(cluster_features["type"])

     if len(genparticle_to_trk[0]) > 0:
-        gp_to_track = (
-            coo_matrix((genparticle_to_trk[2], (genparticle_to_trk[0], genparticle_to_trk[1])), shape=(n_gp, n_track))
-            .max(axis=1)
-            .todense()
-        )
+        gp_to_track = coo_matrix((genparticle_to_trk[2], (genparticle_to_trk[0], genparticle_to_trk[1])), shape=(n_gp, n_track)).max(axis=1).todense()
     else:
         gp_to_track = np.zeros((n_gp, 1))

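Note: the coo_matrix(...).max(axis=1).todense() chain keeps, for each genparticle, the weight of its best-matched track. A toy sketch with invented triplets:

    # Per-genparticle best track weight via a sparse row-wise max.
    from scipy.sparse import coo_matrix

    genparticle_to_trk = ([0, 0, 1], [2, 0, 1], [0.2, 0.9, 1.0])  # (idx_gp, idx_track, weight)
    m = coo_matrix((genparticle_to_trk[2], (genparticle_to_trk[0], genparticle_to_trk[1])), shape=(3, 3))
    print(m.max(axis=1).todense())  # [[0.9], [1.0], [0.0]]; row 2 has no matched track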
@@ -682,9 +665,7 @@ def get_genparticles_and_adjacencies(dataset, prop_data, hit_data, calohit_links

     if len(np.array(mask_visible)) == 1:
         # event has only one particle (then index will be empty because no daughters)
-        gen_features = awkward.Record(
-            {feat: (gen_features[feat][mask_visible] if feat != "index" else None) for feat in gen_features.keys()}
-        )
+        gen_features = awkward.Record({feat: (gen_features[feat][mask_visible] if feat != "index" else None) for feat in gen_features.keys()})
     else:
         gen_features = awkward.Record({feat: gen_features[feat][mask_visible] for feat in gen_features.keys()})
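Note: both branches apply plain field-wise boolean masking to the record. A toy sketch of the pattern, with invented field names and values:

    # Select visible particles by applying one boolean mask to every field.
    import awkward

    gen_features = awkward.Record({"PDG": [22, 11, 211], "energy": [1.0, 2.0, 3.0]})
    mask_visible = awkward.Array([True, False, True])
    visible = awkward.Record({f: gen_features[f][mask_visible] for f in gen_features.fields})
    print(visible["PDG"])  # [22, 211]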

@@ -717,12 +698,8 @@ def assign_genparticles_to_obj_and_merge(gpdata):
         ).todense()
     )

-    gp_to_calohit = coo_matrix(
-        (gpdata.genparticle_to_hit[2], (gpdata.genparticle_to_hit[0], gpdata.genparticle_to_hit[1])), shape=(n_gp, n_hit)
-    )
-    calohit_to_cluster = coo_matrix(
-        (gpdata.hit_to_cluster[2], (gpdata.hit_to_cluster[0], gpdata.hit_to_cluster[1])), shape=(n_hit, n_cluster)
-    )
+    gp_to_calohit = coo_matrix((gpdata.genparticle_to_hit[2], (gpdata.genparticle_to_hit[0], gpdata.genparticle_to_hit[1])), shape=(n_gp, n_hit))
+    calohit_to_cluster = coo_matrix((gpdata.hit_to_cluster[2], (gpdata.hit_to_cluster[0], gpdata.hit_to_cluster[1])), shape=(n_hit, n_cluster))

     gp_to_cluster = np.array((gp_to_calohit * calohit_to_cluster).todense())

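Note: multiplying the two sparse adjacencies composes them, giving genparticle-to-cluster weights through the shared hits. A toy sketch with invented values:

    # Compose genparticle->hit and hit->cluster into genparticle->cluster.
    import numpy as np
    from scipy.sparse import coo_matrix

    gp_to_calohit = coo_matrix(([1.0, 0.5, 0.5], ([0, 1, 1], [0, 1, 2])), shape=(2, 3))
    calohit_to_cluster = coo_matrix(([1.0, 1.0, 1.0], ([0, 1, 2], [0, 0, 1])), shape=(3, 2))
    gp_to_cluster = np.array((gp_to_calohit * calohit_to_cluster).todense())
    print(gp_to_cluster)  # [[1.0, 0.0], [0.5, 0.5]]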
@@ -902,9 +879,7 @@ def get_reco_properties(dataset, prop_data, iev):
         raise Exception("--dataset provided is not supported. Only 'fcc' or 'clic' are supported atm.")

     reco_p4 = vector.awk(
-        awkward.zip(
-            {"mass": reco_arr["mass"], "x": reco_arr["momentum.x"], "y": reco_arr["momentum.y"], "z": reco_arr["momentum.z"]}
-        )
+        awkward.zip({"mass": reco_arr["mass"], "x": reco_arr["momentum.x"], "y": reco_arr["momentum.y"], "z": reco_arr["momentum.z"]})
     )
     reco_arr["pt"] = reco_p4.pt
     reco_arr["eta"] = reco_p4.eta
@@ -1214,29 +1189,19 @@ def process_one_file(fn, ofn, dataset):
     assert np.all(used_rps == 1)

     gps_track = get_particle_feature_matrix(track_to_gp_all, gpdata_cleaned.gen_features, particle_feature_order)
-    gps_track[:, 0] = np.array(
-        [map_neutral_to_charged(map_pdgid_to_candid(p, c)) for p, c in zip(gps_track[:, 0], gps_track[:, 1])]
-    )
+    gps_track[:, 0] = np.array([map_neutral_to_charged(map_pdgid_to_candid(p, c)) for p, c in zip(gps_track[:, 0], gps_track[:, 1])])
     gps_cluster = get_particle_feature_matrix(cluster_to_gp_all, gpdata_cleaned.gen_features, particle_feature_order)
-    gps_cluster[:, 0] = np.array(
-        [map_charged_to_neutral(map_pdgid_to_candid(p, c)) for p, c in zip(gps_cluster[:, 0], gps_cluster[:, 1])]
-    )
+    gps_cluster[:, 0] = np.array([map_charged_to_neutral(map_pdgid_to_candid(p, c)) for p, c in zip(gps_cluster[:, 0], gps_cluster[:, 1])])
     gps_cluster[:, 1] = 0

     rps_track = get_particle_feature_matrix(track_to_rp_all, reco_features, particle_feature_order)
-    rps_track[:, 0] = np.array(
-        [map_neutral_to_charged(map_pdgid_to_candid(p, c)) for p, c in zip(rps_track[:, 0], rps_track[:, 1])]
-    )
+    rps_track[:, 0] = np.array([map_neutral_to_charged(map_pdgid_to_candid(p, c)) for p, c in zip(rps_track[:, 0], rps_track[:, 1])])
     rps_cluster = get_particle_feature_matrix(cluster_to_rp_all, reco_features, particle_feature_order)
-    rps_cluster[:, 0] = np.array(
-        [map_charged_to_neutral(map_pdgid_to_candid(p, c)) for p, c in zip(rps_cluster[:, 0], rps_cluster[:, 1])]
-    )
+    rps_cluster[:, 0] = np.array([map_charged_to_neutral(map_pdgid_to_candid(p, c)) for p, c in zip(rps_cluster[:, 0], rps_cluster[:, 1])])
     rps_cluster[:, 1] = 0

     # all initial gen/reco particle energy must be reconstructable
-    assert (
-        abs(np.sum(gps_track[:, 6]) + np.sum(gps_cluster[:, 6]) - np.sum(gpdata_cleaned.gen_features["energy"])) < 1e-2
-    )
+    assert abs(np.sum(gps_track[:, 6]) + np.sum(gps_cluster[:, 6]) - np.sum(gpdata_cleaned.gen_features["energy"])) < 1e-2

     assert abs(np.sum(rps_track[:, 6]) + np.sum(rps_cluster[:, 6]) - np.sum(reco_features["energy"])) < 1e-2

@@ -1283,9 +1248,7 @@ def process_one_file(fn, ofn, dataset):
     sorted_jet_idx = awkward.argsort(target_jets.pt, axis=-1, ascending=False).to_list()
     target_jets_indices = target_jets_indices.to_list()
     for jet_idx in sorted_jet_idx:
-        jet_constituents = [
-            index_mapping[idx] for idx in target_jets_indices[jet_idx]
-        ]  # map back to constituent index *before* masking
+        jet_constituents = [index_mapping[idx] for idx in target_jets_indices[jet_idx]]  # map back to constituent index *before* masking
         ytarget_constituents[jet_constituents] = jet_idx
     ytarget_track_constituents = ytarget_constituents[: len(ytarget_track)]
     ytarget_cluster_constituents = ytarget_constituents[len(ytarget_track) :]
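Note: because constituents were masked before jet clustering, the jet index must be written back through a post-mask-to-pre-mask index map. A toy sketch with invented values:

    # Assign each pre-mask constituent its jet index; -1 means unclustered.
    import numpy as np

    index_mapping = {0: 0, 1: 2, 2: 5}   # post-mask idx -> pre-mask idx
    target_jets_indices = [[0, 2], [1]]  # constituents of each jet (post-mask)
    ytarget_constituents = -1 * np.ones(6, dtype=np.int64)
    for jet_idx, constituents in enumerate(target_jets_indices):
        jet_constituents = [index_mapping[idx] for idx in constituents]
        ytarget_constituents[jet_constituents] = jet_idx
    print(ytarget_constituents)  # [ 0 -1  1 -1 -1  0]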
16 changes: 4 additions & 12 deletions mlpf/model/PFDataset.py
@@ -116,9 +116,7 @@ def __init__(self, data_dir, name, split, num_samples=None, sort=False):
             builder = tfds.builder(name, data_dir=data_dir)
         except Exception:
             _logger.error(
-                "Could not find dataset {} in {}, please check that you have downloaded the correct version of the dataset".format(
-                    name, data_dir
-                )
+                "Could not find dataset {} in {}, please check that you have downloaded the correct version of the dataset".format(name, data_dir)
             )
             sys.exit(1)
         self.ds = TFDSDataSource(builder.as_data_source(split=split), sort=sort)
@@ -157,19 +155,15 @@ def to(self, device, **kwargs):
 class Collater:
     def __init__(self, per_particle_keys_to_get, per_event_keys_to_get, **kwargs):
         super(Collater, self).__init__(**kwargs)
-        self.per_particle_keys_to_get = (
-            per_particle_keys_to_get  # these quantities are a variable-length tensor per each event
-        )
+        self.per_particle_keys_to_get = per_particle_keys_to_get  # these quantities are a variable-length tensor per each event
         self.per_event_keys_to_get = per_event_keys_to_get  # these quantities are one value (scalar) per event

     def __call__(self, inputs):
         ret = {}

         # per-particle quantities need to be padded across events of different size
         for key_to_get in self.per_particle_keys_to_get:
-            ret[key_to_get] = torch.nn.utils.rnn.pad_sequence(
-                [torch.tensor(inp[key_to_get]).to(torch.float32) for inp in inputs], batch_first=True
-            )
+            ret[key_to_get] = torch.nn.utils.rnn.pad_sequence([torch.tensor(inp[key_to_get]).to(torch.float32) for inp in inputs], batch_first=True)

         # per-event quantities can be stacked across events
         for key_to_get in self.per_event_keys_to_get:
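Note: pad_sequence turns the variable-length per-event tensors into a single [batch, max_len, features] tensor, zero-padding shorter events. A minimal sketch:

    # Pad three events of different length into one batch tensor.
    import torch

    events = [torch.randn(5, 3), torch.randn(2, 3), torch.randn(7, 3)]
    batch = torch.nn.utils.rnn.pad_sequence(events, batch_first=True)
    print(batch.shape)  # torch.Size([3, 7, 3]); shorter events padded with zeros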
@@ -266,9 +260,7 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray):
         loader = torch.utils.data.DataLoader(
             dataset,
             batch_size=batch_size,
-            collate_fn=Collater(
-                ["X", "ytarget", "ytarget_pt_orig", "ytarget_e_orig", "genjets", "targetjets"], ["genmet"]
-            ),
+            collate_fn=Collater(["X", "ytarget", "ytarget_pt_orig", "ytarget_e_orig", "genjets", "targetjets"], ["genmet"]),
             sampler=sampler,
             num_workers=config["num_workers"],
             prefetch_factor=config["prefetch_factor"],