Skip to content

Commit

Permalink
Add registering template, simplify input interfaces for functional algs
Browse files Browse the repository at this point in the history
  • Loading branch information
jmcarcell committed Sep 11, 2023
1 parent b07e9eb commit 3914c0b
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 91 deletions.
2 changes: 1 addition & 1 deletion k4FWCore/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ target_include_directories(k4FWCore PUBLIC
file(GLOB k4fwcore_plugin_sources components/*.cpp)
gaudi_add_module(k4FWCorePlugins
SOURCES ${k4fwcore_plugin_sources}
LINK Gaudi::GaudiAlgLib Gaudi::GaudiKernel k4FWCore k4FWCore::k4Interface ROOT::Core ROOT::RIO ROOT::Tree)
LINK Gaudi::GaudiAlgLib Gaudi::GaudiKernel k4FWCore k4FWCore::k4Interface ROOT::Core ROOT::RIO ROOT::Tree EDM4HEP::edm4hep)
target_include_directories(k4FWCorePlugins PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}/include>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>)
Expand Down
155 changes: 153 additions & 2 deletions k4FWCore/components/PodioInput.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,172 @@

#include "k4FWCore/PodioDataSvc.h"

#include "edm4hep/MCParticleCollection.h"
#include "edm4hep/SimTrackerHitCollection.h"
#include "edm4hep/CaloHitContributionCollection.h"
#include "edm4hep/SimCalorimeterHitCollection.h"
#include "edm4hep/RawCalorimeterHitCollection.h"
#include "edm4hep/CalorimeterHitCollection.h"
#include "edm4hep/ParticleIDCollection.h"
#include "edm4hep/ClusterCollection.h"
#include "edm4hep/TrackerHitCollection.h"
#include "edm4hep/TrackerHitPlaneCollection.h"
#include "edm4hep/RawTimeSeriesCollection.h"
#include "edm4hep/TrackCollection.h"
#include "edm4hep/VertexCollection.h"
#include "edm4hep/ReconstructedParticleCollection.h"
#include "edm4hep/MCRecoParticleAssociationCollection.h"
#include "edm4hep/MCRecoCaloAssociationCollection.h"
#include "edm4hep/MCRecoTrackerAssociationCollection.h"
#include "edm4hep/MCRecoTrackerHitPlaneAssociationCollection.h"
#include "podio/UserDataCollection.h"


DECLARE_COMPONENT(PodioInput)

template <typename T>
inline void PodioInput::maybeRead(std::string_view CollType, std::string_view collName) const {
if (m_podioDataSvc->readCollection<T>(std::string(collName)).isFailure()) {
error() << "Failed to register collection " << collName << endmsg;
}
}

void PodioInput::fillReaders() {
m_readers["edm4hep::MCParticleCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::MCParticleCollection>("edm4hep::MCParticleCollection", collName);
};
m_readers["edm4hep::SimTrackerHitCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::SimTrackerHitCollection>("edm4hep::SimTrackerHitCollection", collName);
};
m_readers["edm4hep::CaloHitContributionCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::CaloHitContributionCollection>("edm4hep::CaloHitContributionCollection", collName);
};
m_readers["edm4hep::SimCalorimeterHitCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::SimCalorimeterHitCollection>("edm4hep::SimCalorimeterHitCollection", collName);
};
m_readers["edm4hep::RawCalorimeterHitCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::RawCalorimeterHitCollection>("edm4hep::RawCalorimeterHitCollection", collName);
};
m_readers["edm4hep::CalorimeterHitCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::CalorimeterHitCollection>("edm4hep::CalorimeterHitCollection", collName);
};
m_readers["edm4hep::ParticleIDCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::ParticleIDCollection>("edm4hep::ParticleIDCollection", collName);
};
m_readers["edm4hep::ClusterCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::ClusterCollection>("edm4hep::ClusterCollection", collName);
};
m_readers["edm4hep::TrackerHitCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::TrackerHitCollection>("edm4hep::TrackerHitCollection", collName);
};
m_readers["edm4hep::TrackerHitPlaneCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::TrackerHitPlaneCollection>("edm4hep::TrackerHitPlaneCollection", collName);
};
m_readers["edm4hep::RawTimeSeriesCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::RawTimeSeriesCollection>("edm4hep::RawTimeSeriesCollection", collName);
};
m_readers["edm4hep::TrackCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::TrackCollection>("edm4hep::TrackCollection", collName);
};
m_readers["edm4hep::VertexCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::VertexCollection>("edm4hep::VertexCollection", collName);
};
m_readers["edm4hep::ReconstructedParticleCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::ReconstructedParticleCollection>("edm4hep::ReconstructedParticleCollection", collName);
};
m_readers["edm4hep::MCRecoParticleAssociationCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::MCRecoParticleAssociationCollection>("edm4hep::MCRecoParticleAssociationCollection", collName);
};
m_readers["edm4hep::MCRecoCaloAssociationCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::MCRecoCaloAssociationCollection>("edm4hep::MCRecoCaloAssociationCollection", collName);
};
m_readers["edm4hep::MCRecoTrackerAssociationCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::MCRecoTrackerAssociationCollection>("edm4hep::MCRecoTrackerAssociationCollection", collName);
};
m_readers["edm4hep::MCRecoTrackerHitPlaneAssociationCollection"] =
[&](std::string_view collName) {
maybeRead<edm4hep::MCRecoTrackerHitPlaneAssociationCollection>("edm4hep::MCRecoTrackerHitPlaneAssociationCollection", collName);
};
m_readers["podio::UserDataCollection<int>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<int>>("podio::UserDataCollection<int>", collName);
};
m_readers["podio::UserDataCollection<float>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<float>>("podio::UserDataCollection<float>", collName);
};
m_readers["podio::UserDataCollection<double>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<double>>("podio::UserDataCollection<double>", collName);
};
m_readers["podio::UserDataCollection<int8_t>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<int8_t>>("podio::UserDataCollection<int8_t>", collName);
};
m_readers["podio::UserDataCollection<int16_t>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<int16_t>>("podio::UserDataCollection<int16_t>", collName);
};
m_readers["podio::UserDataCollection<int32_t>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<int32_t>>("podio::UserDataCollection<int32_t>", collName);
};
m_readers["podio::UserDataCollection<int64_t>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<int64_t>>("podio::UserDataCollection<int64_t>", collName);
};
m_readers["podio::UserDataCollection<uint8_t>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<uint8_t>>("podio::UserDataCollection<uint8_t>", collName);
};
m_readers["podio::UserDataCollection<uint16_t>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<uint16_t>>("podio::UserDataCollection<uint16_t>", collName);
};
m_readers["podio::UserDataCollection<uint32_t>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<uint32_t>>("podio::UserDataCollection<uint32_t>", collName);
};
m_readers["podio::UserDataCollection<uint64_t>"] =
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<uint64_t>>("podio::UserDataCollection<uint64_t>", collName);
};
}

PodioInput::PodioInput(const std::string& name, ISvcLocator* svcLoc) : Consumer(name, svcLoc) {
// check whether we have the PodioEvtSvc active
m_podioDataSvc = dynamic_cast<PodioDataSvc*>(evtSvc().get());
if (!m_podioDataSvc) {
error() << "Could not get PodioDataSvc" << endmsg;
}
fillReaders();
}

void PodioInput::operator()() const {
for (auto& collName : m_collectionNames) {
debug() << "Registering collection to read " << collName << endmsg;
if (m_podioDataSvc->readCollection(collName).isFailure()) {
error() << "Failed to register collection " << collName << endmsg;
auto type = m_podioDataSvc->getCollectionType(collName);
if (m_readers.find(type) != m_readers.end()) {
m_readers[type](collName);
} else {
maybeRead<podio::CollectionBase>(type, collName);
}
}

Expand Down
4 changes: 4 additions & 0 deletions k4FWCore/components/PodioInput.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ class PodioInput final : public Gaudi::Functional::Consumer<void(), BaseClass_t>
void operator()() const override;

private:
template <typename T>
void maybeRead(std::string_view CollType, std::string_view collName) const;
void fillReaders();
// Name of collections to read. Set by option collections (this is temporary)
Gaudi::Property<std::vector<std::string>> m_collectionNames{this, "collections", {}, "Places of collections to read"};
// Data service: needed to register objects and get collection IDs. Just an observing pointer.
PodioDataSvc* m_podioDataSvc;
mutable std::map<std::string_view, std::function<void(std::string_view)>> m_readers;
};

#endif
3 changes: 3 additions & 0 deletions k4FWCore/include/k4FWCore/DataWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ template <class T> class GAUDI_API DataWrapper : public DataWrapperBase {
void setData(const T* data) { m_data = data; }
virtual void resetData() { m_data = nullptr; }

operator const T&() const & {
return *m_data;
}
private:
/// try to cast to collectionBase; may return nullptr;
virtual podio::CollectionBase* collectionBase();
Expand Down
24 changes: 22 additions & 2 deletions k4FWCore/include/k4FWCore/FunctionalUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,36 @@
#include "GaudiAlg/GaudiAlgorithm.h"
#include "GaudiKernel/DataObjectHandle.h"
#include "k4FWCore/DataWrapper.h"
#include "edm4hep/MCParticleCollection.h"

#include "Gaudi/Functional/details.h"
#include "Gaudi/Functional/utilities.h"
#include "Gaudi/Functional/Consumer.h"
#include <GaudiKernel/FunctionalFilterDecision.h>

// Base class used for the Traits template argument of the
// Gaudi::Functional algorithms
struct BaseClass_t {
template<typename T>
// template<typename T>
// using InputHandle_t = DataObjectReadHandle<DataWrapper<T>>;
using InputHandle = DataObjectReadHandle<DataWrapper<T>>;
template<typename T>
using OutputHandle = DataObjectWriteHandle<DataWrapper<T>>;

using BaseClass = Gaudi::Algorithm;
};


// namespace Gaudi::Functional::details {
// using colltype = edm4hep::MCParticleCollection;

// template <typename Traits_>
// struct Consumer<void(const colltype&), Traits_, false>
// : DataHandleMixin<std::tuple<>, filter_evtcontext<DataWrapper<colltype>>, Traits_> {
// using DataHandleMixin<std::tuple<>, filter_evtcontext<DataWrapper<colltype>>, Traits_>::DataHandleMixin;

// virtual void operator()( const DataWrapper<colltype> ) {}
// virtual void operator()( const DataWrapper<colltype> ) {}
// };
// }

#endif
16 changes: 15 additions & 1 deletion k4FWCore/include/k4FWCore/PodioDataSvc.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "podio/Frame.h"
#include "podio/ROOTFrameReader.h"
// Forward declarations
#include "k4FWCore/DataWrapper.h"
class DataWrapperBase;
class PodioOutput;
template <typename T> class MetaDataHandle;
Expand Down Expand Up @@ -65,7 +66,20 @@ class PodioDataSvc : public DataSvc {
virtual StatusCode registerObject(std::string_view parentPath, std::string_view fullPath,
DataObject* pObject) override final;

StatusCode readCollection(const std::string& collectionName);
const std::string_view getCollectionType(const std::string& collName);

template <typename T>
StatusCode readCollection(const std::string& collName) {
const T* collection(nullptr);
collection = static_cast<const T*>(m_eventframe.get(collName));
if (collection == nullptr) {
error() << "Collection " << collName << " does not exist." << endmsg;
}
auto wrapper = new DataWrapper<T>;
wrapper->setData(collection);
m_podio_datawrappers.push_back(wrapper);
return DataSvc::registerObject("/Event", "/" + collName, wrapper);
}

const podio::Frame& getEventFrame() const { return m_eventframe; }

Expand Down
38 changes: 30 additions & 8 deletions k4FWCore/src/PodioDataSvc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* limitations under the License.
*/
#include "k4FWCore/PodioDataSvc.h"
#include <GaudiKernel/StatusCode.h>
#include "GaudiKernel/IConversionSvc.h"
#include "GaudiKernel/IEventProcessor.h"
#include "GaudiKernel/IProperty.h"
Expand Down Expand Up @@ -148,18 +149,39 @@ PodioDataSvc::PodioDataSvc(const std::string& name, ISvcLocator* svc) : DataSvc(
/// Standard Destructor
PodioDataSvc::~PodioDataSvc() {}

StatusCode PodioDataSvc::readCollection(const std::string& collName) {
const podio::CollectionBase* collection(nullptr);
collection = m_eventframe.get(collName);
if (collection == nullptr) {
const std::string_view PodioDataSvc::getCollectionType(const std::string& collName) {
auto coll = m_eventframe.get(collName);
if (coll == nullptr) {
error() << "Collection " << collName << " does not exist." << endmsg;
}
auto wrapper = new DataWrapper<podio::CollectionBase>;
wrapper->setData(collection);
m_podio_datawrappers.push_back(wrapper);
return DataSvc::registerObject("/Event", "/" + collName, wrapper);
return coll->getTypeName();
}

// template <typename T>
// StatusCode PodioDataSvc::readCollection(const std::string& collName) {
// const T* collection(nullptr);
// collection = m_eventframe.get(collName);
// if (collection == nullptr) {
// error() << "Collection " << collName << " does not exist." << endmsg;
// }
// auto wrapper = new DataWrapper<T>;
// wrapper->setData(collection);
// m_podio_datawrappers.push_back(wrapper);
// return DataSvc::registerObject("/Event", "/" + collName, wrapper);
// }

// StatusCode PodioDataSvc::readCollection(const std::string& collName) {
// const podio::CollectionBase* collection(nullptr);
// collection = m_eventframe.get(collName);
// if (collection == nullptr) {
// error() << "Collection " << collName << " does not exist." << endmsg;
// }
// auto wrapper = new DataWrapper<podio::CollectionBase>;
// wrapper->setData(collection);
// m_podio_datawrappers.push_back(wrapper);
// return DataSvc::registerObject("/Event", "/" + collName, wrapper);
// }

StatusCode PodioDataSvc::registerObject(std::string_view parentPath, std::string_view fullPath, DataObject* pObject) {
DataWrapperBase* wrapper = dynamic_cast<DataWrapperBase*>(pObject);
if (wrapper != nullptr) {
Expand Down
13 changes: 5 additions & 8 deletions test/k4FWCoreTest/src/components/ExampleFunctionalConsumer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
#include "edm4hep/MCParticleCollection.h"
#include "podio/CollectionBase.h"

#include <string>
// Define BaseClass_t
#include "k4FWCore/FunctionalUtils.h"

// This will always be Gaudi::Algorithm
using BaseClass_t = Gaudi::Functional::Traits::BaseClass_t<Gaudi::Algorithm>;
#include <string>

// Which type of collection we are reading
// this will always be podio::CollectionBase
// Has to be wrapped in DataWrapper
using colltype = DataWrapper<podio::CollectionBase>;
using colltype = edm4hep::MCParticleCollection;

struct ExampleFunctionalConsumer final : Gaudi::Functional::Consumer<void(const colltype& input), BaseClass_t> {
// The pair in KeyValue can be changed from python and it corresponds
Expand All @@ -26,9 +24,8 @@ struct ExampleFunctionalConsumer final : Gaudi::Functional::Consumer<void(const
// Note that the function has to be const, as well as the collections
// we get from the input
void operator()(const colltype& input) const override {
auto* coll = dynamic_cast<const edm4hep::MCParticleCollection*>(input.getData());
int i = 0;
for (const auto& particle : *coll) {
for (const auto& particle : input) {
if ((particle.getPDG() != 1 + i) || (particle.getGeneratorStatus() != 2 + i) ||
(particle.getSimulatorStatus() != 3 + i) || (particle.getCharge() != 4 + i) ||
(particle.getTime() != 5 + i) || (particle.getMass() != 6 + i)) {
Expand Down
Loading

0 comments on commit 3914c0b

Please sign in to comment.