Merge pull request #142 from lgray/topic_laurelin
Use laurelin to read root files in spark tests
lgray authored Jul 16, 2019
2 parents 6a86252 + a1c7ccc commit e776a80
Showing 3 changed files with 32 additions and 11 deletions.
6 changes: 5 additions & 1 deletion coffea/processor/executor.py
@@ -337,6 +337,9 @@ def run_spark_job(fileset, processor_instance, executor, executor_args={},
raise ValueError("Expected executor to derive from SparkExecutor")

executor_args.setdefault('config', None)
executor_args.setdefault('file_type', 'parquet')
executor_args.setdefault('laurelin_version', '0.1.0')
file_type = executor_args['file_type']

if executor_args['config'] is None:
executor_args.pop('config')
@@ -352,7 +355,8 @@ def run_spark_job(fileset, processor_instance, executor, executor_args={},
     if not isinstance(spark, pyspark.sql.session.SparkSession):
         raise ValueError("Expected 'spark' to be a pyspark.sql.session.SparkSession")

-    dfslist = _spark_make_dfs(spark, fileset, partitionsize, processor_instance.columns, thread_workers)
+    dfslist = _spark_make_dfs(spark, fileset, partitionsize, processor_instance.columns,
+                              thread_workers, file_type)

     output = processor_instance.accumulator.identity()
     executor(spark, dfslist, processor_instance, output, thread_workers)
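With these defaults, the parquet behaviour is unchanged and a caller opts into ROOT reading by passing file_type='root' through executor_args. A minimal sketch of such a call, not part of this commit: the run_spark_job and spark_executor import paths are assumptions based on how the updated test below uses those names, and the ROOT file path is a placeholder.

# hedged sketch; import paths and the input file location are assumptions
from coffea.processor import run_spark_job
from coffea.processor.spark.detail import _spark_initialize, _spark_stop
from coffea.processor.spark.spark_executor import spark_executor
from coffea.processor.test_items import NanoTestProcessor

spark = _spark_initialize()  # also loads the laurelin package via spark.jars.packages

fileset = {'ZJets': ['file:/path/to/nano_dy.root']}  # placeholder ROOT file
proc = NanoTestProcessor(columns=['nMuon', 'Muon_pt', 'Muon_eta', 'Muon_phi', 'Muon_mass'])

# 'laurelin_version' defaults to '0.1.0'; only file_type needs to be set for ROOT input
hists = run_spark_job(fileset, processor_instance=proc, executor=spark_executor,
                      spark=spark, thread_workers=1,
                      executor_args={'file_type': 'root'})

_spark_stop(spark)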
27 changes: 21 additions & 6 deletions coffea/processor/spark/detail.py
@@ -3,13 +3,14 @@
 from tqdm import tqdm
 import pyspark.sql
 import pyspark.sql.functions as fn
+from pyarrow.compat import guid
 from collections.abc import Sequence

 from ..executor import futures_handler

 # this is a reasonable local spark configuration
 _default_config = pyspark.sql.SparkSession.builder \
-    .appName('coffea-analysis') \
+    .appName('coffea-analysis-%s' % guid()) \
     .master('local[*]') \
     .config('spark.sql.execution.arrow.enabled', 'true') \
     .config('spark.sql.execution.arrow.maxRecordsPerBatch', 200000)
@@ -26,6 +27,12 @@ def _spark_initialize(config=_default_config, **kwargs):
     if not spark_progress:
         cfg_actual = cfg_actual.config('spark.ui.showConsoleProgress', 'false')

+    # always load laurelin even if we may not use it
+    kwargs.setdefault('laurelin_version', '0.1.0')
+    laurelin = kwargs['laurelin_version']
+    cfg_actual = cfg_actual.config('spark.jars.packages',
+                                   'edu.vanderbilt.accre:laurelin:%s' % laurelin)
+
     session = cfg_actual.getOrCreate()
     sc = session.sparkContext

@@ -37,10 +44,16 @@ def _spark_initialize(config=_default_config, **kwargs):
     return session


-def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize):
+def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize, file_type, treeName='Events'):
     if not isinstance(files_or_dirs, Sequence):
-        raise ValueError("spark dataset file list must be a Sequence (like list())")
-    df = spark.read.parquet(*files_or_dirs)
+        raise ValueError('spark dataset file list must be a Sequence (like list())')
+    df = None
+    if file_type == 'parquet':
+        df = spark.read.parquet(*files_or_dirs)
+    else:
+        df = spark.read.format(file_type) \
+                       .option('tree', treeName) \
+                       .load(*files_or_dirs)
     count = df.count()

     df_cols = set(df.columns)
@@ -57,7 +70,7 @@ def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize):
     return df, dataset, count


-def _spark_make_dfs(spark, fileset, partitionsize, columns, thread_workers, status=True):
+def _spark_make_dfs(spark, fileset, partitionsize, columns, thread_workers, file_type, status=True):
     dfs = {}
     ana_cols = set(columns)

@@ -66,7 +79,8 @@ def dfs_accumulator(total, result):
         total[ds] = (df, count)

     with ThreadPoolExecutor(max_workers=thread_workers) as executor:
-        futures = set(executor.submit(_read_df, spark, ds, files, ana_cols, partitionsize) for ds, files in fileset.items())
+        futures = set(executor.submit(_read_df, spark, ds, files,
+                                      ana_cols, partitionsize, file_type) for ds, files in fileset.items())

         futures_handler(futures, dfs, status, 'datasets', 'loading', futures_accumulator=dfs_accumulator)

@@ -75,4 +89,5 @@ def dfs_accumulator(total, result):

 def _spark_stop(spark):
     # this may do more later?
+    spark._jvm.SparkSession.clearActiveSession()
     spark.stop()
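For context, the non-parquet branch of _read_df above amounts to using laurelin as an ordinary Spark data source. A standalone sketch of the equivalent plain-PySpark read, independent of coffea: the file path is a placeholder, and 'root' is the format name the updated test passes as file_type.

# hedged sketch mirroring what _spark_initialize plus _read_df do for ROOT input;
# the local file path is a placeholder, not part of this commit
import pyspark.sql

spark = (pyspark.sql.SparkSession.builder
         .appName('laurelin-root-example')
         .master('local[*]')
         .config('spark.jars.packages', 'edu.vanderbilt.accre:laurelin:0.1.0')
         .config('spark.sql.execution.arrow.enabled', 'true')
         .getOrCreate())

df = (spark.read.format('root')        # laurelin data source, as exercised by the test
           .option('tree', 'Events')   # TTree name, matching _read_df's treeName default
           .load('file:/path/to/nano_dy.root'))

print(df.count())
spark.stop()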
10 changes: 6 additions & 4 deletions tests/test_spark.py
@@ -21,6 +21,7 @@ def test_spark_imports():

 def test_spark_executor():
     pyspark = pytest.importorskip("pyspark", minversion="2.4.1")
+    from pyarrow.compat import guid

     from coffea.processor.spark.detail import (_spark_initialize,
                                                _spark_make_dfs,
@@ -32,15 +33,15 @@ def test_spark_executor():

     import pyspark.sql
     spark_config = pyspark.sql.SparkSession.builder \
-        .appName('spark-executor-test') \
+        .appName('spark-executor-test-%s' % guid()) \
         .master('local[*]') \
         .config('spark.sql.execution.arrow.enabled','true') \
         .config('spark.sql.execution.arrow.maxRecordsPerBatch', 200000)

     spark = _spark_initialize(config=spark_config,log_level='ERROR',spark_progress=False)

-    filelist = {'ZJets': ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dy.parquet')],
-                'Data' : ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dimuon.parquet')]
+    filelist = {'ZJets': ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dy.root')],
+                'Data' : ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dimuon.root')]
                }

     from coffea.processor.test_items import NanoTestProcessor
@@ -49,7 +50,8 @@ def test_spark_executor():
     columns = ['nMuon','Muon_pt','Muon_eta','Muon_phi','Muon_mass']
     proc = NanoTestProcessor(columns=columns)

-    hists = run_spark_job(filelist, processor_instance=proc, executor=spark_executor, spark=spark, thread_workers=1)
+    hists = run_spark_job(filelist, processor_instance=proc, executor=spark_executor, spark=spark, thread_workers=1,
+                          executor_args={'file_type': 'root'})

     _spark_stop(spark)

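The updated test can be run on its own to exercise the new laurelin path; a hedged sketch, assuming pytest plus the pyspark and pyarrow test dependencies are installed and Spark can fetch the laurelin jar.

# runs only the Spark executor test touched by this commit, from the repository root
import pytest

pytest.main(['-v', '-k', 'test_spark_executor', 'tests/test_spark.py'])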
