Commit 986eb5d

Generate classes identical up to the shim package name [databricks] (#11665)

Generate classes identical up to the shim package name

Signed-off-by: Gera Shegalov <[email protected]>
gerashegalov authored Oct 28, 2024
1 parent b653ce2 commit 986eb5d
Showing 27 changed files with 68 additions and 710 deletions.
44 changes: 33 additions & 11 deletions build/shimplify.py
@@ -84,6 +84,7 @@
import os
import re
import subprocess
from functools import partial


def __project():
@@ -199,7 +200,9 @@ def __csv_as_arr(str_val):
__shim_comment_pattern = re.compile(re.escape(__opening_shim_tag) +
                                    r'\n(.*)\n' +
                                    re.escape(__closing_shim_tag), re.DOTALL)

__spark_version_classifier = '$_spark.version.classifier_'
__spark_version_placeholder = re.escape(__spark_version_classifier)
__package_pattern = re.compile('package .*' + '(' + __spark_version_placeholder + ')')
def __upsert_shim_json(filename, bv_list):
    with open(filename, 'r') as file:
        contents = file.readlines()
@@ -365,10 +368,7 @@ def __generate_symlinks():
        __log.info("# generating symlinks for shim %s %s files", buildver, src_type)
        __traverse_source_tree_of_all_shims(
            src_type,
            lambda src_type, path, build_ver_arr: __generate_symlink_to_file(buildver,
                                                                             src_type,
                                                                             path,
                                                                             build_ver_arr))
            partial(__generate_symlink_to_file, buildver=buildver, src_type=src_type))

def __traverse_source_tree_of_all_shims(src_type, func):
    """Walks src/<src_type>/sparkXYZ"""
@@ -392,11 +392,10 @@ def __traverse_source_tree_of_all_shims(src_type, func):
        build_ver_arr = map(lambda x: str(json.loads(x).get('spark')), shim_arr)
        __log.debug("extracted shims %s", build_ver_arr)
        assert build_ver_arr == sorted(build_ver_arr),\
            "%s shim list is not properly sorted" % shim_file_path
        func(src_type, shim_file_path, build_ver_arr)

            "%s shim list is not properly sorted: %s" % (shim_file_path, build_ver_arr)
        func(shim_file_path=shim_file_path, build_ver_arr=build_ver_arr, shim_file_txt=shim_file_txt)

def __generate_symlink_to_file(buildver, src_type, shim_file_path, build_ver_arr):
def __generate_symlink_to_file(buildver, src_type, shim_file_path, build_ver_arr, shim_file_txt):
    if buildver in build_ver_arr:
        project_base_dir = str(__project().getBaseDir())
        base_dir = __src_basedir
@@ -416,9 +415,32 @@ def __generate_symlink_to_file(buildver, src_type, shim_file_path, build_ver_arr
        target_shim_file_path = os.path.join(target_root, target_rel_path)
        __log.debug("creating symlink %s -> %s", target_shim_file_path, shim_file_path)
        __makedirs(os.path.dirname(target_shim_file_path))
        if __should_overwrite:
        package_match = __package_pattern.search(shim_file_txt)
        if __should_overwrite or package_match:
            __remove_file(target_shim_file_path)
        __symlink(shim_file_path, target_shim_file_path)
        if package_match:
            with open(target_shim_file_path, mode='w') as f:
                f.write(shim_file_txt[0:package_match.start(1)])
                f.write("spark")
                f.write(buildver)
                f.write('\n')
                f.write('''
/*
!!! DO NOT EDIT THIS FILE !!!
This file has been generated from the original
%s
by interpolating $_spark.version.classifier_=%s
Be sure to edit the original file if required
*/
''' % (shim_file_path, 'spark' + buildver))
                f.write(shim_file_txt[package_match.end(1):])
        else:
            __symlink(shim_file_path, target_shim_file_path)


def __symlink(src, target):
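Net effect of this file's changes: a shim source whose package line contains the $_spark.version.classifier_ placeholder is no longer symlinked into target/ but rewritten with the placeholder interpolated. A minimal standalone sketch of that regex rewrite, using the same pattern as __package_pattern (the hypothetical interpolate helper omits the DO-NOT-EDIT banner the real code writes):

import re

SPARK_VERSION_CLASSIFIER = '$_spark.version.classifier_'
PACKAGE_PATTERN = re.compile('package .*' + '(' +
                             re.escape(SPARK_VERSION_CLASSIFIER) + ')')

def interpolate(shim_file_txt, buildver):
    # replace the captured placeholder with 'spark' + buildver, as
    # __generate_symlink_to_file does when writing the target file
    m = PACKAGE_PATTERN.search(shim_file_txt)
    if m is None:
        return None  # not a template; shimplify symlinks it instead
    return shim_file_txt[:m.start(1)] + 'spark' + buildver + shim_file_txt[m.end(1):]

src = 'package com.nvidia.spark.rapids.$_spark.version.classifier_\n'
assert interpolate(src, '330') == 'package com.nvidia.spark.rapids.spark330\n'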
10 changes: 9 additions & 1 deletion docs/dev/shimplify.md
@@ -65,7 +65,15 @@ validations:
* The file is stored under the *owner shim* directory.

* All participating files, i.e. those listing the `buildver` of the current Maven build session, are symlinked to
`target/${buildver}/generated/src/(main|test)/(scala|java)`. Thus, instead of hardcoding distinct
`target/${buildver}/generated/src/(main|test)/(scala|java)`
except for template classes requiring `spark.version.classifier` in the package name.

* If the package name of a class such as RapidsShuffleManager contains `$_spark.version.classifier_`
(because the class is source-identical across shims up to the package name), it is materialized in
`target/${buildver}/generated/src/(main|test)/(scala|java)` with the Spark version classifier
(e.g., `spark330`) interpolated into the package name.

Thus, instead of hardcoding distinct
lists of directories for the `build-helper` Maven plugin to add (one for each shim), after the full
transition to shimplify the pom will have only 4 add-source statements, independent of the
number of supported shims.
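For example, a template class declaring `package com.nvidia.spark.rapids.$_spark.version.classifier_` is materialized when building with, say, `-Dbuildver=330` as a file under `target/330/generated/src/main/scala` whose package line reads `package com.nvidia.spark.rapids.spark330`, followed by a generated `!!! DO NOT EDIT THIS FILE !!!` banner pointing back to the original template.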
[file header not captured: RapidsShuffleManager.scala, spark320 shim]
@@ -16,8 +16,33 @@

/*** spark-rapids-shim-json-lines
{"spark": "320"}
{"spark": "321"}
{"spark": "321cdh"}
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "334"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "342"}
{"spark": "343"}
{"spark": "350"}
{"spark": "350db"}
{"spark": "351"}
{"spark": "352"}
{"spark": "353"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.spark320
package com.nvidia.spark.rapids.$_spark.version.classifier_

import org.apache.spark.SparkConf
import org.apache.spark.sql.rapids.ProxyRapidsShuffleInternalManagerBase
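This shared template matters because the class is user-facing: a cluster selects the shuffle manager by its fully qualified, version-specific name, e.g. --conf spark.shuffle.manager=com.nvidia.spark.rapids.spark330.RapidsShuffleManager for Spark 3.3.0 in a standard spark-rapids setup, so each supported shim needs an otherwise identical copy of this class in its own package.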

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Loading
