Skip to content

Commit

Permalink
allow to load multiple configs
Browse files Browse the repository at this point in the history
  • Loading branch information
kondratyevd committed Feb 13, 2024
1 parent 6136607 commit e185d99
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 24 deletions.
45 changes: 28 additions & 17 deletions af_benchmark/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import yaml
import scalpl
import glob
import pandas as pd

from profiling.timing import time_profiler as tp
Expand Down Expand Up @@ -33,11 +34,7 @@ def read_yaml(file_path):


class Benchmark:
@tp.enable
def __init__(self, config_path):
print(" > Reading config")
self.config = read_yaml(config_path)

def __init__(self, config_path=None):
self.report_df = pd.DataFrame(
columns=[
"dataset",
Expand All @@ -50,9 +47,14 @@ def __init__(self, config_path):
"col_handler",
]
)
self.n_files = 0
if config_path:
self.reinitialize(config_path)

def reinitialize(self, config_path):
tp.reset()

self.config = read_yaml(config_path)

print(" > Initializing executor")
# Select executor backend
self.backend = self.config.get('executor.backend')
if self.backend in executors:
Expand All @@ -62,7 +64,6 @@ def __init__(self, config_path):
f"Invalid backend: {self.backend}. Allowed values are: {executors.keys()}"
)

print(" > Initializing column handler")
# Select file handler method
self.method = self.config.get('processing.method')
if self.method in handlers:
Expand Down Expand Up @@ -93,8 +94,9 @@ def run(self):

return outputs

def report(self):
run_stats = tp.report_df.loc[
def update_report(self):
# print(tp.report_df)
run_time = tp.report_df.loc[
tp.report_df.func_name=="run",
"func_time"
].values[0]
Expand All @@ -113,20 +115,29 @@ def report(self):
}])
])

def print_report(self):
    """Print the report DataFrame (``self.report_df``) to stdout."""
    print(self.report_df)


def run_benchmark(args):
    """Run the benchmark once per config file and print the combined report.

    ``args.config_path`` is either a single YAML file (ending in ``.yaml`` or
    ``.yml``) or a directory; for a directory, every ``*.yaml`` / ``*.yml``
    file directly inside it is run in turn on the same ``Benchmark`` object.

    Args:
        args: Parsed command-line namespace providing ``config_path``.
    """
    import os  # function-scope import keeps this block self-contained

    if args.config_path.endswith((".yaml", ".yml")):
        configs = [args.config_path]
    else:
        # glob returns matches in arbitrary order; sort so that the run order
        # (and therefore the row order of the report) is reproducible.
        configs = sorted(
            glob.glob(os.path.join(args.config_path, "*.yaml"))
            + glob.glob(os.path.join(args.config_path, "*.yml"))
        )

    if not configs:
        # Make an empty report explainable instead of silently printing it.
        print(f"> No YAML configs found in {args.config_path}")

    b = Benchmark()
    for config_file in configs:
        print(f"> Loading config from {config_file}")
        b.reinitialize(config_file)
        b.run()
        b.update_report()

    b.print_report()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('config_path', help="Path to YAML config or directory with YAML configs")
    args = parser.parse_args()
    run_benchmark(args)
26 changes: 26 additions & 0 deletions af_benchmark/example-configs/example-config-1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Example benchmark config 1: "sequential" executor backend, "local_dir"
# data-access mode, "uproot" processing method, three columns.
# Commented-out lines are alternative settings for the key above them.
executor:
backend: sequential
# backend: futures
# backend: dask-local
# backend: dask-gateway
data-access:
# mode: local
# files:
# - tests/data/nano_dimuon.root
mode: local_dir
files_dir: /eos/purdue/store/data/Run2016B/SingleMuon/NANOAOD/02Apr2020_ver2-v1/20000/
processing:
method: uproot
# method: nanoevents
columns:
- event
- Muon_pt
- Muon_eta
operation: sum
write: false
measurements:
what:
- timing
- ram
report:
- print
18 changes: 18 additions & 0 deletions af_benchmark/example-configs/example-config-2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Example benchmark config 2: "sequential" executor backend, "local_dir"
# data-access mode, "uproot" processing method, two columns (event, Muon_pt).
executor:
backend: sequential
data-access:
mode: local_dir
files_dir: /eos/purdue/store/data/Run2016B/SingleMuon/NANOAOD/02Apr2020_ver2-v1/20000/
processing:
method: uproot
columns:
- event
- Muon_pt
operation: sum
write: false
measurements:
what:
- timing
- ram
report:
- print
1 change: 0 additions & 1 deletion af_benchmark/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def execute(self, func, args, **kwargs):

return results

@tp.enable
@abstractmethod
def _execute(self, func, args, **kwargs):
"""Executor-specific implementation (see inherited classes)
Expand Down
2 changes: 0 additions & 2 deletions af_benchmark/executor/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __del__(self):
if hasattr(self, 'client') and self.client is not None:
self.client.close()

@tp.enable
def _execute(self, func, args, **kwargs):
"""Execute ``func`` over ``args`` in parallel using ``distributed.Client::submit()``.
Expand Down Expand Up @@ -84,7 +83,6 @@ def _find_gateway_client(self):
self.cluster = self.gateway.connect(first_cluster_name)
self.client = self.cluster.get_client()

@tp.enable
def _execute(self, func, args, **kwargs):
"""Execute ``func`` over ``args`` in parallel using ``distributed.Client::submit()``.
Expand Down
2 changes: 0 additions & 2 deletions af_benchmark/executor/futures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from executor.base import BaseExecutor
from profiling.timing import time_profiler as tp
from concurrent import futures
import multiprocessing

Expand All @@ -10,7 +9,6 @@ class FuturesExecutor(BaseExecutor):
on the same node where the benchmark is launched.
"""

@tp.enable
def _execute(self, func, args, **kwargs):
"""Execute ``func`` over ``args`` in parallel using ``concurrent.futures.ThreadPoolExecutor``.
Expand Down
2 changes: 0 additions & 2 deletions af_benchmark/executor/sequential.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from executor.base import BaseExecutor
from profiling.timing import time_profiler as tp
import tqdm


Expand All @@ -9,7 +8,6 @@ class SequentialExecutor(BaseExecutor):
Processes arguments in a ``for`` loop.
"""

@tp.enable
def _execute(self, func, args, **kwargs):
"""Execute ``func`` over ``args`` in a loop.
Expand Down
4 changes: 4 additions & 0 deletions af_benchmark/profiling/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ def __init__(self):
self.profiler = cProfile.Profile()
self.report_df = pd.DataFrame()
self.enabled = []

def reset(self):
    """Return the profiler to its initial state by re-running ``__init__``.

    The visible part of ``__init__`` creates a new ``cProfile.Profile``, an
    empty ``report_df`` DataFrame and an empty ``enabled`` list, so data
    collected before the call is discarded.
    """
    self.__init__()

def enable(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand Down

0 comments on commit e185d99

Please sign in to comment.