Skip to content

Commit

Permalink
Fix python tool picking up wrong JAR version
Browse files Browse the repository at this point in the history
Signed-off-by: Partho Sarthi <[email protected]>
  • Loading branch information
parthosa committed Sep 23, 2024
1 parent 92911a4 commit da4e166
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 11 deletions.
4 changes: 2 additions & 2 deletions user_tools/build.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -105,7 +105,7 @@ build "$build_mode"

# Check build status
if [ $? -eq 0 ]; then
echo "Build successful. To install, use: pip install <wheel-file>"
echo "Build successful. To install, use: pip install dist/<wheel-file>"
else
echo "Build failed."
exit 1
Expand Down
50 changes: 41 additions & 9 deletions user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
"""Implementation of class holding the execution context of a rapids tool"""

import os
import re
import tarfile
from dataclasses import dataclass, field
from glob import glob
from logging import Logger
from typing import Type, Any, ClassVar, List

Expand Down Expand Up @@ -87,7 +87,7 @@ def get_deploy_mode(self) -> Any:
return self.platform_opts.get('deployMode')

def is_fatwheel_mode(self) -> bool:
return self.get_ctxt('fatwheelModeEnabled')
return self.get_ctxt('fatWheelModeEnabled')

def set_ctxt(self, key: str, val: Any):
self.props['wrapperCtx'][key] = val
Expand Down Expand Up @@ -135,14 +135,29 @@ def set_local_workdir(self, parent: str):
self.logger.info('Dependencies are generated locally in local disk as: %s', dep_folder)
self.logger.info('Local output folder is set as: %s', exec_root_dir)

def identify_fat_wheel_jar(self, resource_files: List[str]) -> None:
"""
Identifies the tools JAR file from resource files in fat wheel mode and sets its name in the context.
:param resource_files: List of resource files to search for the tools JAR file.
:raises AssertionError: If the number of matching files is not exactly one.
"""
tools_jar_regex_str = self.get_value('sparkRapids', 'toolsJarRegex')
tools_jar_regex = re.compile(tools_jar_regex_str)
matched_files = [f for f in resource_files if tools_jar_regex.search(f)]
assert len(matched_files) == 1, \
(f'Expected exactly one tools JAR file, found {len(matched_files)}. '
'Rebuild the wheel package with the correct tools JAR file.')
# set the tools JAR file name in the context
self.set_ctxt('fatWheelModeJarFileName', FSUtil.get_resource_name(matched_files[0]))

def load_prepackaged_resources(self):
"""
Checks if the packaging includes the CSP dependencies. If so, it moves the dependencies
into the tmp folder. This allows the tool to pick the resources from cache folder.
"""
if not self.are_resources_prepackaged():
return
self.set_ctxt('fatwheelModeEnabled', True)
self.set_ctxt('fatWheelModeEnabled', True)
self.logger.info(Utils.gen_str_header('Fat Wheel Mode Is Enabled',
ruler='_', line_width=50))

Expand All @@ -151,10 +166,12 @@ def load_prepackaged_resources(self):
if os.path.isdir(res_path):
# this is a directory, copy all the contents to the tmp
FSUtil.copy_resource(res_path, self.get_cache_folder())
self.identify_fat_wheel_jar(FSUtil.get_all_files(res_path))
else:
# this is an archived file
with tarfile.open(res_path, mode='r:*') as tar_file:
tar_file.extractall(self.get_cache_folder())
self.identify_fat_wheel_jar(tar_file.getnames())
tar_file.close()

def get_output_folder(self) -> str:
Expand All @@ -173,12 +190,7 @@ def get_rapids_jar_url(self) -> str:
# get the version from the package, instead of the yaml file
# jar_version = self.get_value('sparkRapids', 'version')
if self.is_fatwheel_mode():
offline_path_regex = FSUtil.build_path(self.get_cache_folder(), 'rapids-4-spark-tools_*.jar')
matching_files = glob(offline_path_regex)
if not matching_files:
raise FileNotFoundError('In Fat Mode. No matching JAR files found.')
self.logger.info('Using jar from wheel file %s', matching_files[0])
return matching_files[0]
return self._get_tools_jar_in_fat_wheel_mode()
mvn_base_url = self.get_value('sparkRapids', 'mvnUrl')
jar_version = Utilities.get_latest_mvn_jar_from_metadata(mvn_base_url)
rapids_url = self.get_value('sparkRapids', 'repoUrl').format(mvn_base_url, jar_version, jar_version)
Expand Down Expand Up @@ -209,3 +221,23 @@ def get_platform_name(self) -> str:
:return: the name of the platform of the runtime in lower_case.
"""
return CspEnv.pretty_print(self.platform.type_id)

def _get_tools_jar_in_fat_wheel_mode(self) -> str:
"""
Extracts the tools JAR file from the context and returns its path from the cache folder.
"""
jar_filename = self.get_ctxt('fatWheelModeJarFileName')
if jar_filename is None:
raise ValueError(
'In Fat Mode. Tools JAR file name not found in context. '
'Rebuild the wheel package or re-run without fat wheel mode.'
)
# construct the path to the tools JAR file in the cache folder
jar_filepath = FSUtil.build_path(self.get_cache_folder(), jar_filename)
if not FSUtil.resource_exists(jar_filepath):
raise FileNotFoundError(
f'In Fat Mode. Tools JAR not found in cache folder: {jar_filepath}. '
'Rebuild the wheel package or re-run without fat wheel mode.'
)
self.logger.info('Using jar from wheel file %s', jar_filepath)
return jar_filepath
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ toolOutput:
sparkRapids:
mvnUrl: 'https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark-tools_2.12'
repoUrl: '{}/{}/rapids-4-spark-tools_2.12-{}.jar'
toolsJarRegex: 'rapids-4-spark-tools_2.12-.*.jar'
mainClass: 'com.nvidia.spark.rapids.tool.qualification.QualificationMain'
outputDocURL: 'https://docs.nvidia.com/spark-rapids/user-guide/latest/qualification/quickstart.html#qualification-output'
enableAutoTuner: true
Expand Down

0 comments on commit da4e166

Please sign in to comment.