Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
farakiko committed Dec 2, 2024
1 parent 17eca68 commit e30344c
Showing 1 changed file with 57 additions and 36 deletions.
93 changes: 57 additions & 36 deletions mlpf/data/key4hep/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def __init__(
self.cluster_features = cluster_features # feature matrix of the calo clusters
self.track_features = track_features # feature matrix of the tracks
self.genparticle_to_hit = genparticle_to_hit # sparse COO matrix of genparticles to hits (idx_gp, idx_hit, weight)
self.genparticle_to_track = genparticle_to_track # sparse COO matrix of genparticles to tracks (idx_gp, idx_track, weight)
self.genparticle_to_track = (
genparticle_to_track # sparse COO matrix of genparticles to tracks (idx_gp, idx_track, weight)
)
self.hit_to_cluster = hit_to_cluster # sparse COO matrix of hits to clusters (idx_hit, idx_cluster, weight)
self.gp_merges = gp_merges # sparse COO matrix of any merged genparticles

Expand Down Expand Up @@ -248,7 +250,10 @@ def get_calohit_matrix_and_genadj(dataset, hit_data, calohit_links, iev, collect
hit_idx_global += 1
hit_idx_local_to_global = {v: k for k, v in hit_idx_global_to_local.items()}
hit_feature_matrix = awkward.Record(
{k: awkward.concatenate([hit_feature_matrix[i][k] for i in range(len(hit_feature_matrix))]) for k in hit_feature_matrix[0].fields}
{
k: awkward.concatenate([hit_feature_matrix[i][k] for i in range(len(hit_feature_matrix))])
for k in hit_feature_matrix[0].fields
}
)

# add all edges from genparticle to calohit
Expand All @@ -259,11 +264,6 @@ def get_calohit_matrix_and_genadj(dataset, hit_data, calohit_links, iev, collect
calohit_to_gen_calo_idx = calohit_links["CalohitMCTruthLink#0.index"][iev]
calohit_to_gen_gen_idx = calohit_links["CalohitMCTruthLink#1.index"][iev]
elif dataset == "fcc":
# calohit_to_gen_calo_colid = calohit_links["_CalohitMCTruthLink_rec/_CalohitMCTruthLink_rec.collectionID"][iev]
# calohit_to_gen_gen_colid = calohit_links["_CalohitMCTruthLink_sim/_CalohitMCTruthLink_sim.collectionID"][iev]
# calohit_to_gen_calo_idx = calohit_links["_CalohitMCTruthLink_rec/_CalohitMCTruthLink_rec.index"][iev]
# calohit_to_gen_gen_idx = calohit_links["_CalohitMCTruthLink_sim/_CalohitMCTruthLink_sim.index"][iev]

calohit_to_gen_calo_colid = calohit_links["_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.collectionID"][iev]
calohit_to_gen_gen_colid = calohit_links["_CalohitMCTruthLink_to/_CalohitMCTruthLink_to.collectionID"][iev]
calohit_to_gen_calo_idx = calohit_links["_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.index"][iev]
Expand Down Expand Up @@ -348,7 +348,9 @@ def gen_to_features(dataset, prop_data, iev):
gen_arr = {k.replace(mc_coll + ".", ""): gen_arr[k] for k in gen_arr.fields}

MCParticles_p4 = vector.awk(
awkward.zip({"mass": gen_arr["mass"], "x": gen_arr["momentum.x"], "y": gen_arr["momentum.y"], "z": gen_arr["momentum.z"]})
awkward.zip(
{"mass": gen_arr["mass"], "x": gen_arr["momentum.x"], "y": gen_arr["momentum.y"], "z": gen_arr["momentum.z"]}
)
)
gen_arr["pt"] = MCParticles_p4.pt
gen_arr["eta"] = MCParticles_p4.eta
Expand Down Expand Up @@ -395,8 +397,6 @@ def genparticle_track_adj(dataset, sitrack_links, iev):
trk_to_gen_trkidx = sitrack_links["SiTracksMCTruthLink#0.index"][iev]
trk_to_gen_genidx = sitrack_links["SiTracksMCTruthLink#1.index"][iev]
elif dataset == "fcc":
# trk_to_gen_trkidx = sitrack_links["_SiTracksMCTruthLink_rec/_SiTracksMCTruthLink_rec.index"][iev]
# trk_to_gen_genidx = sitrack_links["_SiTracksMCTruthLink_sim/_SiTracksMCTruthLink_sim.index"][iev]
trk_to_gen_trkidx = sitrack_links["_SiTracksMCTruthLink_from/_SiTracksMCTruthLink_from.index"][iev]
trk_to_gen_genidx = sitrack_links["_SiTracksMCTruthLink_to/_SiTracksMCTruthLink_to.index"][iev]
else:
Expand Down Expand Up @@ -501,7 +501,8 @@ def track_to_features(dataset, prop_data, iev):
ret["dEdx"] = track_arr_dQdx["SiTracks_Refitted_dQdx.dQdx.value"]
ret["dEdxError"] = track_arr_dQdx["SiTracks_Refitted_dQdx.dQdx.error"]

num_tracks = len(track_arr_dQdx["SiTracks_Refitted_dQdx.dQdx.value"])
# build the radiusOfInnermostHit variable
num_tracks = len(ret["dEdx"])
innermost_radius = []
for itrack in range(num_tracks):

Expand Down Expand Up @@ -534,8 +535,9 @@ def track_to_features(dataset, prop_data, iev):
if dataset == "clic":
ret[k] = awkward.to_numpy(prop_data["SiTracks_1"]["SiTracks_1." + k][iev][trackstate_idx])
elif dataset == "fcc":
# ret[k] = awkward.to_numpy(prop_data["_SiTracks_trackStates"]["_SiTracks_trackStates." + k][iev][trackstate_idx])
ret[k] = awkward.to_numpy(prop_data["_SiTracks_Refitted_trackStates"]["_SiTracks_Refitted_trackStates." + k][iev][trackstate_idx])
ret[k] = awkward.to_numpy(
prop_data["_SiTracks_Refitted_trackStates"]["_SiTracks_Refitted_trackStates." + k][iev][trackstate_idx]
)

else:
raise Exception("--dataset provided is not supported. Only 'fcc' or 'clic' are supported atm.")
Expand Down Expand Up @@ -624,7 +626,9 @@ def add_daughters_to_status1(gen_features, genparticle_to_hit, genparticle_to_tr

def get_genparticles_and_adjacencies(dataset, prop_data, hit_data, calohit_links, sitrack_links, iev, collectionIDs):
gen_features = gen_to_features(dataset, prop_data, iev)
hit_features, genparticle_to_hit, hit_idx_local_to_global = get_calohit_matrix_and_genadj(dataset, hit_data, calohit_links, iev, collectionIDs)
hit_features, genparticle_to_hit, hit_idx_local_to_global = get_calohit_matrix_and_genadj(
dataset, hit_data, calohit_links, iev, collectionIDs
)
hit_to_cluster = hit_cluster_adj(dataset, prop_data, hit_idx_local_to_global, iev)
cluster_features = cluster_to_features(prop_data, hit_features, hit_to_cluster, iev)
track_features = track_to_features(dataset, prop_data, iev)
Expand All @@ -634,15 +638,21 @@ def get_genparticles_and_adjacencies(dataset, prop_data, hit_data, calohit_links
mask_status1 = gen_features["generatorStatus"] == 1

if gen_features["index"] is not None: # if there are even daughters
genparticle_to_hit, genparticle_to_trk = add_daughters_to_status1(gen_features, genparticle_to_hit, genparticle_to_trk)
genparticle_to_hit, genparticle_to_trk = add_daughters_to_status1(
gen_features, genparticle_to_hit, genparticle_to_trk
)

n_gp = awkward.count(gen_features["PDG"])
n_track = awkward.count(track_features["type"])
n_hit = awkward.count(hit_features["type"])
n_cluster = awkward.count(cluster_features["type"])

if len(genparticle_to_trk[0]) > 0:
gp_to_track = coo_matrix((genparticle_to_trk[2], (genparticle_to_trk[0], genparticle_to_trk[1])), shape=(n_gp, n_track)).max(axis=1).todense()
gp_to_track = (
coo_matrix((genparticle_to_trk[2], (genparticle_to_trk[0], genparticle_to_trk[1])), shape=(n_gp, n_track))
.max(axis=1)
.todense()
)
else:
gp_to_track = np.zeros((n_gp, 1))

Expand Down Expand Up @@ -672,7 +682,9 @@ def get_genparticles_and_adjacencies(dataset, prop_data, hit_data, calohit_links

if len(np.array(mask_visible)) == 1:
# event has only one particle (then index will be empty because no daughters)
gen_features = awkward.Record({feat: (gen_features[feat][mask_visible] if feat != "index" else None) for feat in gen_features.keys()})
gen_features = awkward.Record(
{feat: (gen_features[feat][mask_visible] if feat != "index" else None) for feat in gen_features.keys()}
)
else:
gen_features = awkward.Record({feat: gen_features[feat][mask_visible] for feat in gen_features.keys()})

Expand Down Expand Up @@ -705,8 +717,12 @@ def assign_genparticles_to_obj_and_merge(gpdata):
).todense()
)

gp_to_calohit = coo_matrix((gpdata.genparticle_to_hit[2], (gpdata.genparticle_to_hit[0], gpdata.genparticle_to_hit[1])), shape=(n_gp, n_hit))
calohit_to_cluster = coo_matrix((gpdata.hit_to_cluster[2], (gpdata.hit_to_cluster[0], gpdata.hit_to_cluster[1])), shape=(n_hit, n_cluster))
gp_to_calohit = coo_matrix(
(gpdata.genparticle_to_hit[2], (gpdata.genparticle_to_hit[0], gpdata.genparticle_to_hit[1])), shape=(n_gp, n_hit)
)
calohit_to_cluster = coo_matrix(
(gpdata.hit_to_cluster[2], (gpdata.hit_to_cluster[0], gpdata.hit_to_cluster[1])), shape=(n_hit, n_cluster)
)

gp_to_cluster = np.array((gp_to_calohit * calohit_to_cluster).todense())

Expand Down Expand Up @@ -886,7 +902,9 @@ def get_reco_properties(dataset, prop_data, iev):
raise Exception("--dataset provided is not supported. Only 'fcc' or 'clic' are supported atm.")

reco_p4 = vector.awk(
awkward.zip({"mass": reco_arr["mass"], "x": reco_arr["momentum.x"], "y": reco_arr["momentum.y"], "z": reco_arr["momentum.z"]})
awkward.zip(
{"mass": reco_arr["mass"], "x": reco_arr["momentum.x"], "y": reco_arr["momentum.y"], "z": reco_arr["momentum.z"]}
)
)
reco_arr["pt"] = reco_p4.pt
reco_arr["eta"] = reco_p4.eta
Expand Down Expand Up @@ -1061,22 +1079,17 @@ def process_one_file(fn, ofn, dataset):
"MCParticles.daughters_end",
"_MCParticles_daughters/_MCParticles_daughters.index", # similar to "MCParticles#1.index" in clic
track_coll,
# "_SiTracks_trackStates",
"_SiTracks_Refitted_trackStates",
"PandoraClusters",
"_PandoraClusters_hits/_PandoraClusters_hits.index",
"_PandoraClusters_hits/_PandoraClusters_hits.collectionID",
"PandoraPFOs",
"SiTracks_Refitted_dQdx", # TODO: new
"SiTracks_Refitted_dQdx",
]
)
calohit_links = arrs.arrays(
[
"CalohitMCTruthLink.weight",
# "_CalohitMCTruthLink_rec/_CalohitMCTruthLink_rec.collectionID",
# "_CalohitMCTruthLink_rec/_CalohitMCTruthLink_rec.index",
# "_CalohitMCTruthLink_sim/_CalohitMCTruthLink_sim.collectionID",
# "_CalohitMCTruthLink_sim/_CalohitMCTruthLink_sim.index",
"_CalohitMCTruthLink_to/_CalohitMCTruthLink_to.collectionID",
"_CalohitMCTruthLink_to/_CalohitMCTruthLink_to.index",
"_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.collectionID",
Expand All @@ -1086,10 +1099,6 @@ def process_one_file(fn, ofn, dataset):
sitrack_links = arrs.arrays(
[
"SiTracksMCTruthLink.weight",
# "_SiTracksMCTruthLink_rec/_SiTracksMCTruthLink_rec.collectionID",
# "_SiTracksMCTruthLink_rec/_SiTracksMCTruthLink_rec.index",
# "_SiTracksMCTruthLink_sim/_SiTracksMCTruthLink_sim.collectionID",
# "_SiTracksMCTruthLink_sim/_SiTracksMCTruthLink_sim.index",
"_SiTracksMCTruthLink_to/_SiTracksMCTruthLink_to.collectionID",
"_SiTracksMCTruthLink_to/_SiTracksMCTruthLink_to.index",
"_SiTracksMCTruthLink_from/_SiTracksMCTruthLink_from.collectionID",
Expand Down Expand Up @@ -1205,19 +1214,29 @@ def process_one_file(fn, ofn, dataset):
assert np.all(used_rps == 1)

gps_track = get_particle_feature_matrix(track_to_gp_all, gpdata_cleaned.gen_features, particle_feature_order)
gps_track[:, 0] = np.array([map_neutral_to_charged(map_pdgid_to_candid(p, c)) for p, c in zip(gps_track[:, 0], gps_track[:, 1])])
gps_track[:, 0] = np.array(
[map_neutral_to_charged(map_pdgid_to_candid(p, c)) for p, c in zip(gps_track[:, 0], gps_track[:, 1])]
)
gps_cluster = get_particle_feature_matrix(cluster_to_gp_all, gpdata_cleaned.gen_features, particle_feature_order)
gps_cluster[:, 0] = np.array([map_charged_to_neutral(map_pdgid_to_candid(p, c)) for p, c in zip(gps_cluster[:, 0], gps_cluster[:, 1])])
gps_cluster[:, 0] = np.array(
[map_charged_to_neutral(map_pdgid_to_candid(p, c)) for p, c in zip(gps_cluster[:, 0], gps_cluster[:, 1])]
)
gps_cluster[:, 1] = 0

rps_track = get_particle_feature_matrix(track_to_rp_all, reco_features, particle_feature_order)
rps_track[:, 0] = np.array([map_neutral_to_charged(map_pdgid_to_candid(p, c)) for p, c in zip(rps_track[:, 0], rps_track[:, 1])])
rps_track[:, 0] = np.array(
[map_neutral_to_charged(map_pdgid_to_candid(p, c)) for p, c in zip(rps_track[:, 0], rps_track[:, 1])]
)
rps_cluster = get_particle_feature_matrix(cluster_to_rp_all, reco_features, particle_feature_order)
rps_cluster[:, 0] = np.array([map_charged_to_neutral(map_pdgid_to_candid(p, c)) for p, c in zip(rps_cluster[:, 0], rps_cluster[:, 1])])
rps_cluster[:, 0] = np.array(
[map_charged_to_neutral(map_pdgid_to_candid(p, c)) for p, c in zip(rps_cluster[:, 0], rps_cluster[:, 1])]
)
rps_cluster[:, 1] = 0

# all initial gen/reco particle energy must be reconstructable
assert abs(np.sum(gps_track[:, 6]) + np.sum(gps_cluster[:, 6]) - np.sum(gpdata_cleaned.gen_features["energy"])) < 1e-2
assert (
abs(np.sum(gps_track[:, 6]) + np.sum(gps_cluster[:, 6]) - np.sum(gpdata_cleaned.gen_features["energy"])) < 1e-2
)

assert abs(np.sum(rps_track[:, 6]) + np.sum(rps_cluster[:, 6]) - np.sum(reco_features["energy"])) < 1e-2

Expand Down Expand Up @@ -1264,7 +1283,9 @@ def process_one_file(fn, ofn, dataset):
sorted_jet_idx = awkward.argsort(target_jets.pt, axis=-1, ascending=False).to_list()
target_jets_indices = target_jets_indices.to_list()
for jet_idx in sorted_jet_idx:
jet_constituents = [index_mapping[idx] for idx in target_jets_indices[jet_idx]] # map back to constituent index *before* masking
jet_constituents = [
index_mapping[idx] for idx in target_jets_indices[jet_idx]
] # map back to constituent index *before* masking
ytarget_constituents[jet_constituents] = jet_idx
ytarget_track_constituents = ytarget_constituents[: len(ytarget_track)]
ytarget_cluster_constituents = ytarget_constituents[len(ytarget_track) :]
Expand Down

0 comments on commit e30344c

Please sign in to comment.