From 070db5c5207ef203e7ca7acbaa30aeb9c48e8d89 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Fri, 1 Nov 2024 14:17:45 +0100 Subject: [PATCH 01/33] 'Add AMD support for TorchServe' MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 'update backend to be hardware agnostic' Rony Leppänen 'update frontend to be hardware agnostic' Anders Smedegaard Pedersen 'update Dockerfile.dev to also work for AMD' 'update requirements/ for AMD support' Samu Tamminen Other contributions: Bipradip Chowdhury Jarkko Lehtiranta Jarkko Vainio Tero Kemppi --- .gitignore | 6 + docker/Dockerfile.dev | 37 ++- frontend/build.gradle | 4 +- .../org/pytorch/serve/device/Accelerator.java | 90 +++++++ .../serve/device/AcceleratorVendor.java | 9 + .../org/pytorch/serve/device/SystemInfo.java | 167 +++++++++++++ .../interfaces/IAcceleratorUtility.java | 226 ++++++++++++++++++ .../device/interfaces/ICsvSmiParser.java | 65 +++++ .../device/interfaces/IJsonSmiParser.java | 39 +++ .../pytorch/serve/device/utils/AppleUtil.java | 99 ++++++++ .../pytorch/serve/device/utils/CudaUtil.java | 66 +++++ .../pytorch/serve/device/utils/ROCmUtil.java | 118 +++++++++ .../pytorch/serve/device/utils/XpuUtil.java | 90 +++++++ .../java/org/pytorch/serve/util/ApiUtils.java | 3 +- .../org/pytorch/serve/util/ConfigManager.java | 95 +------- .../pytorch/serve/wlm/WorkerLifeCycle.java | 3 + .../org/pytorch/serve/wlm/WorkerThread.java | 54 +---- .../org/pytorch/serve/ModelServerTest.java | 3 +- .../pytorch/serve/device/AcceleratorTest.java | 76 ++++++ .../pytorch/serve/device/SystemInfoTest.java | 47 ++++ .../serve/device/utils/AppleUtilTest.java | 121 ++++++++++ .../serve/device/utils/CudaUtilTest.java | 132 ++++++++++ .../serve/device/utils/ROCmUtilTest.java | 143 +++++++++++ .../serve/device/utils/XpuUtilTest.java | 138 +++++++++++ .../metrics/sample_amd_discovery.json | 26 ++ .../resources/metrics/sample_amd_metrics.json | 46 ++++ .../metrics/sample_amd_updated_metrics.json | 46 ++++ .../resources/metrics/sample_apple_smi.json | 33 +++ frontend/server/testng.xml | 6 +- kubernetes/kserve/tests/scripts/test_mnist.sh | 52 +++- requirements/common_rocm.txt | 1 + requirements/torch_rocm60.txt | 5 + requirements/torch_rocm61.txt | 4 + ts/metrics/metric_collector.py | 6 +- ts/metrics/system_metrics.py | 96 +++++--- ts/torch_handler/base_handler.py | 8 +- ts_scripts/install_dependencies.py | 54 ++++- ts_scripts/install_utils | 35 ++- ts_scripts/print_env_info.py | 143 ++++++----- ts_scripts/sanity_utils.py | 33 ++- ts_scripts/utils.py | 13 +- ts_scripts/validate_model_on_gpu.py | 11 +- 42 files changed, 2151 insertions(+), 298 deletions(-) create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/Accelerator.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/AcceleratorVendor.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/SystemInfo.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/interfaces/IAcceleratorUtility.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/interfaces/ICsvSmiParser.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/interfaces/IJsonSmiParser.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/utils/CudaUtil.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/utils/ROCmUtil.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/device/utils/XpuUtil.java create mode 100644 frontend/server/src/test/java/org/pytorch/serve/device/AcceleratorTest.java create mode 100644 frontend/server/src/test/java/org/pytorch/serve/device/SystemInfoTest.java create mode 100644 frontend/server/src/test/java/org/pytorch/serve/device/utils/AppleUtilTest.java create mode 100644 frontend/server/src/test/java/org/pytorch/serve/device/utils/CudaUtilTest.java create mode 100644 frontend/server/src/test/java/org/pytorch/serve/device/utils/ROCmUtilTest.java create mode 100644 frontend/server/src/test/java/org/pytorch/serve/device/utils/XpuUtilTest.java create mode 100644 frontend/server/src/test/resources/metrics/sample_amd_discovery.json create mode 100644 frontend/server/src/test/resources/metrics/sample_amd_metrics.json create mode 100644 frontend/server/src/test/resources/metrics/sample_amd_updated_metrics.json create mode 100644 frontend/server/src/test/resources/metrics/sample_apple_smi.json create mode 100644 requirements/common_rocm.txt create mode 100644 requirements/torch_rocm60.txt create mode 100644 requirements/torch_rocm61.txt diff --git a/.gitignore b/.gitignore index a2edb60d81..ba0296673f 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,9 @@ instances.yaml.backup # cpp cpp/_build cpp/third-party + +# projects +.tool-versions +**/*/.classpath +**/*/.settings +**/*/.project diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 2f02d84680..bea2787bcc 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -10,7 +10,7 @@ # For reference: # https://docs.docker.com/develop/develop-images/build_enhancements/ -ARG BASE_IMAGE=ubuntu:rolling +ARG BASE_IMAGE=ubuntu:24.04 ARG BUILD_TYPE=dev FROM ${BASE_IMAGE} AS compile-image @@ -19,6 +19,7 @@ ARG BRANCH_NAME=master ARG REPO_URL=https://github.com/pytorch/serve.git ARG MACHINE_TYPE=cpu ARG CUDA_VERSION +ARG ROCM_VERSION ARG BUILD_WITH_IPEX ARG IPEX_VERSION=1.11.0 @@ -41,7 +42,7 @@ RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ git \ python$PYTHON_VERSION \ python$PYTHON_VERSION-dev \ - python3-distutils \ + python3-setuptools \ python$PYTHON_VERSION-venv \ python3-venv \ build-essential \ @@ -49,6 +50,8 @@ RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ curl \ vim \ numactl \ + zip \ + wget \ && if [ "$BUILD_WITH_IPEX" = "true" ]; then apt-get update && apt-get install -y libjemalloc-dev libgoogle-perftools-dev libomp-dev && ln -s /usr/lib/x86_64-linux-gnu/libjemalloc.so /usr/lib/libjemalloc.so && ln -s /usr/lib/x86_64-linux-gnu/libtcmalloc.so /usr/lib/libtcmalloc.so && ln -s /usr/lib/x86_64-linux-gnu/libiomp5.so /usr/lib/libiomp5.so; fi \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ @@ -58,19 +61,43 @@ RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ RUN update-alternatives --install /usr/bin/python python /usr/bin/python$PYTHON_VERSION 1 \ && update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1 +RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ + if [ -n "$ROCM_VERSION" ]; then \ + apt-get update \ + && wget https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ + && DEBIAN_FRONTEND=noninteractive sudo apt-get install -y ./amdgpu-install_6.2.60202-1_all.deb \ + && sudo apt-get update \ + && sudo apt-get install --no-install-recommends -y amdgpu-dkms rocm \ + && cd /home/; \ + else \ + echo "Skip ROCm installation"; \ + fi + # Build Dev Image FROM compile-image AS dev-image ARG MACHINE_TYPE=cpu ARG CUDA_VERSION -RUN if [ "$MACHINE_TYPE" = "gpu" ]; then export USE_CUDA=1; fi \ +RUN if [ "$MACHINE_TYPE" = "nvidia_gpu" ]; then export USE_CUDA=1; fi \ && git clone $REPO_URL \ && cd serve \ && git checkout ${BRANCH_NAME} \ && python$PYTHON_VERSION -m venv /home/venv ENV PATH="/home/venv/bin:$PATH" WORKDIR serve + +COPY . . + RUN python -m pip install -U pip setuptools \ - && if [ -z "$CUDA_VERSION" ]; then python ts_scripts/install_dependencies.py --environment=dev; else python ts_scripts/install_dependencies.py --environment=dev --cuda $CUDA_VERSION; fi \ + && if ([ -z "$CUDA_VERSION" ] && [ -z "$ROCM_VERSION" ]); then \ + python ts_scripts/install_dependencies.py --environment=dev; \ + elif [ -n "$ROCM_VERSION" ]; then \ + python ts_scripts/install_dependencies.py --environment=dev --rocm $ROCM_VERSION \ + && cd /opt/rocm/share/amd_smi \ + && pip install . \ + && cd /serve/; \ + else \ + python ts_scripts/install_dependencies.py --environment=dev --cuda $CUDA_VERSION; \ + fi \ && if [ "$BUILD_WITH_IPEX" = "true" ]; then python -m pip install --no-cache-dir intel_extension_for_pytorch==${IPEX_VERSION} -f ${IPEX_URL}; fi \ && python ts_scripts/install_from_src.py \ && useradd -m model-server \ @@ -83,7 +110,6 @@ RUN python -m pip install -U pip setuptools \ && chown -R model-server /home/venv EXPOSE 8080 8081 8082 7070 7071 -USER model-server WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] @@ -112,4 +138,5 @@ RUN set -ex \ FROM ${BUILD_TYPE}-image AS final-image ARG BUILD_TYPE +ENV CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 RUN echo "${BUILD_TYPE} image creation completed" diff --git a/frontend/build.gradle b/frontend/build.gradle index 33920df5a3..cb5143dc11 100644 --- a/frontend/build.gradle +++ b/frontend/build.gradle @@ -37,8 +37,8 @@ def javaProjects() { configure(javaProjects()) { apply plugin: 'java-library' - sourceCompatibility = 1.8 - targetCompatibility = 1.8 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 defaultTasks 'jar' diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/Accelerator.java b/frontend/server/src/main/java/org/pytorch/serve/device/Accelerator.java new file mode 100644 index 0000000000..4692653ccf --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/Accelerator.java @@ -0,0 +1,90 @@ +package org.pytorch.serve.device; + +import java.text.MessageFormat; +import org.pytorch.serve.device.interfaces.IAcceleratorUtility; + +public class Accelerator { + public final Integer id; + public final AcceleratorVendor vendor; + public final String model; + public IAcceleratorUtility acceleratorUtility; + public Float usagePercentage; + public Float memoryUtilizationPercentage; + public Integer memoryAvailableMegabytes; + public Integer memoryUtilizationMegabytes; + + public Accelerator(String acceleratorName, AcceleratorVendor vendor, Integer gpuId) { + this.model = acceleratorName; + this.vendor = vendor; + this.id = gpuId; + this.usagePercentage = (float) 0.0; + this.memoryUtilizationPercentage = (float) 0.0; + this.memoryAvailableMegabytes = 0; + this.memoryUtilizationMegabytes = 0; + } + + // Getters + public Integer getMemoryAvailableMegaBytes() { + return memoryAvailableMegabytes; + } + + public AcceleratorVendor getVendor() { + return vendor; + } + + public String getAcceleratorModel() { + return model; + } + + public Integer getAcceleratorId() { + return id; + } + + public Float getUsagePercentage() { + return usagePercentage; + } + + public Float getMemoryUtilizationPercentage() { + return memoryUtilizationPercentage; + } + + public Integer getMemoryUtilizationMegabytes() { + return memoryUtilizationMegabytes; + } + + // Setters + public void setMemoryAvailableMegaBytes(Integer memoryAvailable) { + this.memoryAvailableMegabytes = memoryAvailable; + } + + public void setUsagePercentage(Float acceleratorUtilization) { + this.usagePercentage = acceleratorUtilization; + } + + public void setMemoryUtilizationPercentage(Float memoryUtilizationPercentage) { + this.memoryUtilizationPercentage = memoryUtilizationPercentage; + } + + public void setMemoryUtilizationMegabytes(Integer memoryUtilizationMegabytes) { + this.memoryUtilizationMegabytes = memoryUtilizationMegabytes; + } + + // Other Methods + public String utilizationToString() { + final String message = + MessageFormat.format( + "gpuId::{0} utilization.gpu::{1} % utilization.memory::{2} % memory.used::{3} MiB", + id, + usagePercentage, + memoryUtilizationPercentage, + memoryUtilizationMegabytes); + + return message; + } + + public void updateDynamicAttributes(Accelerator updated) { + this.usagePercentage = updated.usagePercentage; + this.memoryUtilizationPercentage = updated.memoryUtilizationPercentage; + this.memoryUtilizationMegabytes = updated.memoryUtilizationMegabytes; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/AcceleratorVendor.java b/frontend/server/src/main/java/org/pytorch/serve/device/AcceleratorVendor.java new file mode 100644 index 0000000000..22fd1f5d68 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/AcceleratorVendor.java @@ -0,0 +1,9 @@ +package org.pytorch.serve.device; + +public enum AcceleratorVendor { + AMD, + NVIDIA, + INTEL, + APPLE, + UNKNOWN +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/SystemInfo.java b/frontend/server/src/main/java/org/pytorch/serve/device/SystemInfo.java new file mode 100644 index 0000000000..f2034ca186 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/SystemInfo.java @@ -0,0 +1,167 @@ +package org.pytorch.serve.device; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import org.pytorch.serve.device.interfaces.IAcceleratorUtility; +import org.pytorch.serve.device.utils.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SystemInfo { + static final Logger logger = LoggerFactory.getLogger(SystemInfo.class); + // + // Contains information about the system (physical or virtual machine) + // we are running the workload on. + // Specifically how many accelerators and info about them. + // + + public AcceleratorVendor acceleratorVendor; + ArrayList accelerators; + private IAcceleratorUtility acceleratorUtil; + + public SystemInfo() { + // Detect and set the vendor of any accelerators in the system + this.acceleratorVendor = detectVendorType(); + this.accelerators = new ArrayList(); + + // If accelerators are present (vendor != UNKNOWN), + // initialize accelerator utilities + Optional.of(hasAccelerators()) + // Only proceed if hasAccelerators() returns true + .filter(Boolean::booleanValue) + // Execute this block if accelerators are present + .ifPresent( + hasAcc -> { + // Create the appropriate utility class based on vendor + this.acceleratorUtil = createAcceleratorUtility(); + // Populate the accelerators list based on environment + // variables and available devices + populateAccelerators(); + }); + + // Safely handle accelerator metrics update + Optional.ofNullable(accelerators) + // Only proceed if the accelerators list is not empty + .filter(list -> !list.isEmpty()) + // Update metrics (utilization, memory, etc.) for all accelerators if list + // exists and not empty + .ifPresent(list -> updateAcceleratorMetrics()); + } + + private IAcceleratorUtility createAcceleratorUtility() { + switch (this.acceleratorVendor) { + case AMD: + return new ROCmUtil(); + case NVIDIA: + return new CudaUtil(); + case INTEL: + return new XpuUtil(); + case APPLE: + return new AppleUtil(); + default: + return null; + } + } + + private void populateAccelerators() { + if (this.acceleratorUtil != null) { + String envVarName = this.acceleratorUtil.getGpuEnvVariableName(); + String requestedAcceleratorIds = System.getenv(envVarName); + LinkedHashSet availableAcceleratorIds = + IAcceleratorUtility.parseVisibleDevicesEnv(requestedAcceleratorIds); + this.accelerators = + this.acceleratorUtil.getAvailableAccelerators(availableAcceleratorIds); + } else { + this.accelerators = new ArrayList<>(); + } + } + + boolean hasAccelerators() { + return this.acceleratorVendor != AcceleratorVendor.UNKNOWN; + } + + public Integer getNumberOfAccelerators() { + // since we instance create `accelerators` as an empty list + // in the constructor, the null check should be redundant. + // leaving it to be sure. + return (accelerators != null) ? accelerators.size() : 0; + } + + public static AcceleratorVendor detectVendorType() { + if (isCommandAvailable("rocm-smi")) { + return AcceleratorVendor.AMD; + } else if (isCommandAvailable("nvidia-smi")) { + return AcceleratorVendor.NVIDIA; + } else if (isCommandAvailable("xpu-smi")) { + return AcceleratorVendor.INTEL; + } else if (isCommandAvailable("system_profiler")) { + return AcceleratorVendor.APPLE; + } else { + return AcceleratorVendor.UNKNOWN; + } + } + + private static boolean isCommandAvailable(String command) { + String operatingSystem = System.getProperty("os.name").toLowerCase(); + String commandCheck = operatingSystem.contains("win") ? "where" : "which"; + ProcessBuilder processBuilder = new ProcessBuilder(commandCheck, command); + try { + Process process = processBuilder.start(); + int exitCode = process.waitFor(); + return exitCode == 0; + } catch (IOException | InterruptedException e) { + return false; + } + } + + public ArrayList getAccelerators() { + return this.accelerators; + } + + private void updateAccelerators(List updatedAccelerators) { + // Create a map of existing accelerators with ID as key + Map existingAcceleratorsMap = + this.accelerators.stream().collect(Collectors.toMap(acc -> acc.id, acc -> acc)); + + // Update existing accelerators and add new ones + this.accelerators = + updatedAccelerators.stream() + .map( + updatedAcc -> { + Accelerator existingAcc = + existingAcceleratorsMap.get(updatedAcc.id); + if (existingAcc != null) { + existingAcc.updateDynamicAttributes(updatedAcc); + return existingAcc; + } else { + return updatedAcc; + } + }) + .collect(Collectors.toCollection(ArrayList::new)); + } + + public void updateAcceleratorMetrics() { + if (this.acceleratorUtil != null) { + List updatedAccelerators = + this.acceleratorUtil.getUpdatedAcceleratorsUtilization(this.accelerators); + + updateAccelerators(updatedAccelerators); + } + } + + public AcceleratorVendor getAcceleratorVendor() { + return this.acceleratorVendor; + } + + public String getVisibleDevicesEnvName() { + if (this.accelerators.isEmpty() || this.accelerators == null) { + return null; + } + return this.accelerators.get(0).acceleratorUtility.getGpuEnvVariableName(); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/IAcceleratorUtility.java b/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/IAcceleratorUtility.java new file mode 100644 index 0000000000..8bbe630c47 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/IAcceleratorUtility.java @@ -0,0 +1,226 @@ +package org.pytorch.serve.device.interfaces; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.commons.io.IOUtils; +import org.pytorch.serve.device.Accelerator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Provides functionality to detect hardware devices for accelerated workloads. For example GPUs. + */ +public interface IAcceleratorUtility { + static final Logger logger = LoggerFactory.getLogger(IAcceleratorUtility.class); + + /** + * Returns the name of the environment variable used to specify visible GPU devices. + * Implementing classes should define this based on their specific requirements. + * + *

Examples are 'HIP_VISIBLE_DEVICES', 'CUDA_VISIBLE_DEVICES' + * + * @return The name of the environment variable for visible GPU devices. + */ + String getGpuEnvVariableName(); + + /** + * Returns the SMI command specific to the implementing class. + * + * @return An array of strings representing the SMI command and its arguments for getting the + * utilizaiton stats for the available accelerators + */ + String[] getUtilizationSmiCommand(); + + /** + * Parses a string representation of visible devices into an {@code LinkedHashSet} of device + * identifiers. + * + *

This method processes a comma-separated list of device identifiers, typically obtained + * from an environment variable like {X}_VISIBLE_DEVICES. It performs validation and cleaning of + * the input string. + * + * @param visibleDevices A string containing comma-separated device identifiers. Can be null or + * empty. + * @return A LinkedHashSet of Integers each representing a device identifier. Returns an empty + * set if the input is null or empty. + * @throws IllegalArgumentException if the input string is not in the correct format (integers + * separated by commas, with or without spaces). + * @example // Returns [0, 1, 2] parseVisibleDevicesEnv("0,1,2") + *

// notice spaces between the commas and the next number // Returns [0, 1, 2] + * parseVisibleDevicesEnv("0, 1, 2") + *

// Returns [0, 2] parseVisibleDevicesEnv("0,0,2") + *

// Returns [] parseVisibleDevicesEnv("") + *

// Throws IllegalArgumentException parseVisibleDevicesEnv("0,1,a") + */ + static LinkedHashSet parseVisibleDevicesEnv(String visibleDevices) { + // return an empty set if null or an empty string is passed + if (visibleDevices == null || visibleDevices.isEmpty()) { + return new LinkedHashSet<>(); + } + + // Remove all spaces from the input + String cleaned = visibleDevices.replaceAll("\\s", ""); + + // Check if the cleaned string matches the pattern of integers separated by + // commas + if (!cleaned.matches("^\\d+(,\\d+)*$")) { + throw new IllegalArgumentException( + "Invalid format: The env defining visible devices must be integers separated by commas"); + } + + // split the string on comma, cast to Integer, and collect to a List + List allIntegers = + Arrays.stream(cleaned.split(",")) + .map(Integer::parseInt) + .collect(Collectors.toList()); + + // use Sets to deduplicate integers + LinkedHashSet uniqueIntegers = new LinkedHashSet<>(); + Set duplicates = + allIntegers.stream() + .filter(n -> !uniqueIntegers.add(n)) + .collect(Collectors.toSet()); + + if (!duplicates.isEmpty()) { + logger.warn( + "Duplicate GPU IDs found in {}: {}. Duplicates will be removed.", + visibleDevices, + duplicates); + } + + // return the set of unique integers + return uniqueIntegers; + } + + /** + * Parses the output of a system management interface (SMI) command to create a list of {@code + * Accelerator} objects with updated metrics. + * + * @param smiOutput The raw output string from the SMI command. + * @param parsed_gpu_ids A set of GPU IDs that have already been parsed. + * @return An {@code ArrayList} of {@code Accelerator} objects representing the parsed + * accelerators. + * @implNote The specific SMI command, output format, and environment variables will vary + * depending on the accelerator type. The SMI command should return core usage, memory + * utilization. Implementations should document these specifics in their method comments. If + * {@code parsed_gpu_ids} is empty, all accelerators found by the smi command should be + * returned. + * @throws IllegalArgumentException If the SMI output is invalid or cannot be parsed. + * @throws NullPointerException If either {@code smiOutput} or {@code parsed_gpu_ids} is null. + */ + ArrayList smiOutputToUpdatedAccelerators( + String smiOutput, LinkedHashSet parsed_gpu_ids); + + /** + * @param availableAcceleratorIds + * @return + */ + public ArrayList getAvailableAccelerators( + LinkedHashSet availableAcceleratorIds); + + /** + * Converts a number of bytes to megabytes. + * + *

This method uses the binary definition of a megabyte, where 1 MB = 1,048,576 bytes (1024 * + * 1024). The result is rounded to two decimal places. + * + * @param bytes The number of bytes to convert, as a long value. + * @return The equivalent number of megabytes, as a double value rounded to two decimal places. + */ + static Integer bytesToMegabytes(long bytes) { + final double BYTES_IN_MEGABYTE = 1024 * 1024; + return (int) (bytes / BYTES_IN_MEGABYTE); + } + + /** + * Executes an SMI (System Management Interface) command and returns its output. + * + *

This method runs the specified command using a ProcessBuilder, combines standard output + * and error streams, waits for the process to complete, and returns the output as a string. + * + * @param command An array of strings representing the SMI command and its arguments. + * @return A string containing the output of the SMI command. + * @throws AssertionError If the SMI command returns a non-zero exit code. + * @throws Error If an IOException or InterruptedException occurs during execution. The original + * exception is wrapped in the Error. + */ + static String callSMI(String[] command) { + try { + ProcessBuilder processBuilder = new ProcessBuilder(command); + processBuilder.redirectErrorStream(true); + Process process = processBuilder.start(); + int ret = process.waitFor(); + if (ret != 0) { + throw new AssertionError("SMI command returned a non-zero"); + } + + String output = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); + if (output.isEmpty()) { + throw new AssertionError("Unexpected smi response."); + } + return output; + + } catch (IOException | InterruptedException e) { + logger.debug("SMI command not available or failed: " + e.getMessage()); + throw new Error(e); + } + } + + /** + * Updates the utilization information for a list of accelerators. + * + *

This method retrieves the current utilization statistics for the given accelerators using + * a System Management Interface (SMI) command specific to the implementing class. It then + * parses the SMI output and returns an updated {@code ArrayList} of accelerator objects with + * the latest information. + * + * @param accelerators An ArrayList of Accelerator objects to be updated. Must not be null or + * empty. + * @return An ArrayList of updated Accelerator objects with the latest utilization information. + * @throws IllegalArgumentException If the input accelerators list is null or empty, or if the + * SMI command returned by getUtilizationSmiCommand() is null or empty. + * @throws RuntimeException If an error occurs while executing the SMI command or parsing its + * output. The specific exception will depend on the implementation of callSMI() and + * smiOutputToAccelerators(). + * @implNote This method uses getUtilizationSmiCommand() to retrieve the SMI command specific to + * the implementing class. Subclasses must implement this method to provide the correct + * command. The method also relies on callSMI() to execute the command and + * smiOutputToAccelerators() to parse the output, both of which must be implemented by the + * subclass. + * @implSpec The implementation first checks if the input is valid, then retrieves and validates + * the SMI command. It executes the command, extracts the GPU IDs from the input + * accelerators, and uses these to parse the SMI output into updated Accelerator objects. + * @see #getUtilizationSmiCommand() + * @see #callSMI(String[]) + * @see #smiOutputToUpdatedAccelerators(String, LinkedHashSet) + */ + default ArrayList getUpdatedAcceleratorsUtilization( + ArrayList accelerators) { + if (accelerators == null || accelerators.isEmpty()) { + logger.warn("No accelerators to update."); + throw new IllegalArgumentException( + "`accelerators` cannot be null or empty when trying to update the accelerator stats"); + } + + String[] smiCommand = getUtilizationSmiCommand(); + if (smiCommand == null || smiCommand.length == 0) { + throw new IllegalArgumentException( + "`smiCommand` cannot be null or empty when trying to update accelerator stats"); + } + + String smiOutput = callSMI(smiCommand); + LinkedHashSet acceleratorIds = + accelerators.stream() + .map(accelerator -> accelerator.id) + .collect(Collectors.toCollection(LinkedHashSet::new)); + ArrayList updatedAccelerators = + smiOutputToUpdatedAccelerators(smiOutput, acceleratorIds); + return updatedAccelerators; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/ICsvSmiParser.java b/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/ICsvSmiParser.java new file mode 100644 index 0000000000..98a7351467 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/ICsvSmiParser.java @@ -0,0 +1,65 @@ +package org.pytorch.serve.device.interfaces; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.function.Function; +import org.pytorch.serve.device.Accelerator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public interface ICsvSmiParser { + static final Logger csvSmiParserLogger = LoggerFactory.getLogger(ICsvSmiParser.class); + + /** + * Parses CSV output from SMI commands and converts it into a list of Accelerator objects. + * + * @param csvOutput The CSV string output from an SMI command. + * @param parsedAcceleratorIds A set of accelerator IDs to consider. If empty, all accelerators + * are included. + * @param parseFunction A function that takes an array of CSV fields and returns an Accelerator + * object. This function should handle the specific parsing logic for different SMI command + * outputs. + * @return An ArrayList of Accelerator objects parsed from the CSV output. + * @throws NumberFormatException If there's an error parsing numeric fields in the CSV. + *

This method provides a general way to parse CSV output from various SMI commands. It + * skips the header line of the CSV, then applies the provided parseFunction to each + * subsequent line. Accelerators are only included if their ID is in parsedAcceleratorIds, + * or if parsedAcceleratorIds is empty (indicating all accelerators should be included). + *

The parseFunction parameter allows for flexibility in handling different CSV formats + * from various SMI commands. This function should handle the specific logic for creating an + * Accelerator object from a line of CSV data. + */ + default ArrayList csvSmiOutputToAccelerators( + final String csvOutput, + final LinkedHashSet parsedGpuIds, + Function parseFunction) { + final ArrayList accelerators = new ArrayList<>(); + + List lines = Arrays.asList(csvOutput.split("\n")); + + final boolean addAll = parsedGpuIds.isEmpty(); + + lines.stream() + .skip(1) // Skip the header line + .forEach( + line -> { + final String[] parts = line.split(","); + try { + Accelerator accelerator = parseFunction.apply(parts); + if (accelerator != null + && (addAll + || parsedGpuIds.contains( + accelerator.getAcceleratorId()))) { + accelerators.add(accelerator); + } + } catch (final NumberFormatException e) { + csvSmiParserLogger.warn( + "Failed to parse GPU ID: " + parts[1].trim(), e); + } + }); + + return accelerators; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/IJsonSmiParser.java b/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/IJsonSmiParser.java new file mode 100644 index 0000000000..0a39ebfc91 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/interfaces/IJsonSmiParser.java @@ -0,0 +1,39 @@ +package org.pytorch.serve.device.interfaces; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import org.pytorch.serve.device.Accelerator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public interface IJsonSmiParser { + static final Logger jsonSmiParserLogger = LoggerFactory.getLogger(IJsonSmiParser.class); + + default ArrayList jsonOutputToAccelerators( + JsonElement rootObject, LinkedHashSet parsedAcceleratorIds) { + + ArrayList accelerators = new ArrayList<>(); + List acceleratorObjects = extractAccelerators(rootObject); + + for (JsonObject acceleratorObject : acceleratorObjects) { + Integer acceleratorId = extractAcceleratorId(acceleratorObject); + if (acceleratorId != null + && (parsedAcceleratorIds.isEmpty() + || parsedAcceleratorIds.contains(acceleratorId))) { + Accelerator accelerator = jsonObjectToAccelerator(acceleratorObject); + accelerators.add(accelerator); + } + } + + return accelerators; + } + + public Integer extractAcceleratorId(JsonObject jsonObject); + + public Accelerator jsonObjectToAccelerator(JsonObject jsonObject); + + public List extractAccelerators(JsonElement rootObject); +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java b/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java new file mode 100644 index 0000000000..ae87e85255 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java @@ -0,0 +1,99 @@ +package org.pytorch.serve.device.utils; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import org.pytorch.serve.device.Accelerator; +import org.pytorch.serve.device.AcceleratorVendor; +import org.pytorch.serve.device.interfaces.IAcceleratorUtility; +import org.pytorch.serve.device.interfaces.IJsonSmiParser; + +public class AppleUtil implements IAcceleratorUtility, IJsonSmiParser { + + @Override + public String getGpuEnvVariableName() { + return null; // Apple doesn't use a GPU environment variable + } + + @Override + public String[] getUtilizationSmiCommand() { + return new String[] { + "system_profiler", "-json", "-detailLevel", "mini", "SPDisplaysDataType" + }; + } + + @Override + public ArrayList getAvailableAccelerators( + LinkedHashSet availableAcceleratorIds) { + String jsonOutput = IAcceleratorUtility.callSMI(getUtilizationSmiCommand()); + JsonObject rootObject = JsonParser.parseString(jsonOutput).getAsJsonObject(); + return jsonOutputToAccelerators(rootObject, availableAcceleratorIds); + } + + @Override + public ArrayList smiOutputToUpdatedAccelerators( + String smiOutput, LinkedHashSet parsedGpuIds) { + JsonObject rootObject = JsonParser.parseString(smiOutput).getAsJsonObject(); + return jsonOutputToAccelerators(rootObject, parsedGpuIds); + } + + @Override + public Accelerator jsonObjectToAccelerator(JsonObject gpuObject) { + String model = gpuObject.get("sppci_model").getAsString(); + if (!model.startsWith("Apple M")) { + return null; + } + + Accelerator accelerator = new Accelerator(model, AcceleratorVendor.APPLE, 0); + + // Set additional information + accelerator.setUsagePercentage(0f); // Not available from system_profiler + accelerator.setMemoryUtilizationPercentage(0f); // Not available from system_profiler + accelerator.setMemoryUtilizationMegabytes(0); // Not available from system_profiler + + return accelerator; + } + + @Override + public Integer extractAcceleratorId(JsonObject cardObject) { + // `system_profiler` only returns one object for + // the integrated GPU on M1, M2, M3 Macs + return 0; + } + + @Override + public List extractAccelerators(JsonElement rootObject) { + List accelerators = new ArrayList<>(); + JsonArray displaysArray = + rootObject + .getAsJsonObject() // Gets the outer object + .get("SPDisplaysDataType") // Gets the "SPDisplaysDataType" element + .getAsJsonArray(); + JsonObject gpuObject = displaysArray.get(0).getAsJsonObject(); + accelerators.add(gpuObject); + return accelerators; + } + + public ArrayList jsonOutputToAccelerators( + JsonObject rootObject, LinkedHashSet parsedAcceleratorIds) { + + ArrayList accelerators = new ArrayList<>(); + List acceleratorObjects = extractAccelerators(rootObject); + + for (JsonObject acceleratorObject : acceleratorObjects) { + Integer acceleratorId = extractAcceleratorId(acceleratorObject); + if (acceleratorId != null + && (parsedAcceleratorIds.isEmpty() + || parsedAcceleratorIds.contains(acceleratorId))) { + Accelerator accelerator = jsonObjectToAccelerator(acceleratorObject); + accelerators.add(accelerator); + } + } + + return accelerators; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/utils/CudaUtil.java b/frontend/server/src/main/java/org/pytorch/serve/device/utils/CudaUtil.java new file mode 100644 index 0000000000..b64faa57a4 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/utils/CudaUtil.java @@ -0,0 +1,66 @@ +package org.pytorch.serve.device.utils; + +import java.util.ArrayList; +import java.util.LinkedHashSet; +import org.pytorch.serve.device.Accelerator; +import org.pytorch.serve.device.AcceleratorVendor; +import org.pytorch.serve.device.interfaces.IAcceleratorUtility; +import org.pytorch.serve.device.interfaces.ICsvSmiParser; + +public class CudaUtil implements IAcceleratorUtility, ICsvSmiParser { + + @Override + public String getGpuEnvVariableName() { + return "CUDA_VISIBLE_DEVICES"; + } + + @Override + public String[] getUtilizationSmiCommand() { + String metrics = + String.join( + ",", + "index", + "gpu_name", + "utilization.gpu", + "utilization.memory", + "memory.used"); + return new String[] {"nvidia-smi", "--query-gpu=" + metrics, "--format=csv,nounits"}; + } + + @Override + public ArrayList getAvailableAccelerators( + LinkedHashSet availableAcceleratorIds) { + String[] command = {"nvidia-smi", "--query-gpu=index,gpu_name", "--format=csv,nounits"}; + + String smiOutput = IAcceleratorUtility.callSMI(command); + return csvSmiOutputToAccelerators( + smiOutput, availableAcceleratorIds, this::parseAccelerator); + } + + @Override + public ArrayList smiOutputToUpdatedAccelerators( + String smiOutput, LinkedHashSet parsedGpuIds) { + + return csvSmiOutputToAccelerators(smiOutput, parsedGpuIds, this::parseUpdatedAccelerator); + } + + public Accelerator parseAccelerator(String[] parts) { + int id = Integer.parseInt(parts[0].trim()); + String model = parts[1].trim(); + return new Accelerator(model, AcceleratorVendor.NVIDIA, id); + } + + public Accelerator parseUpdatedAccelerator(String[] parts) { + int id = Integer.parseInt(parts[0].trim()); + String model = parts[1].trim(); + Float usagePercentage = Float.parseFloat(parts[2].trim()); + Float memoryUtilizationPercentage = Float.parseFloat(parts[3].trim()); + int memoryUtilizationMegabytes = Integer.parseInt(parts[4].trim()); + + Accelerator accelerator = new Accelerator(model, AcceleratorVendor.NVIDIA, id); + accelerator.setUsagePercentage(usagePercentage); + accelerator.setMemoryUtilizationPercentage(memoryUtilizationPercentage); + accelerator.setMemoryUtilizationMegabytes(memoryUtilizationMegabytes); + return accelerator; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/utils/ROCmUtil.java b/frontend/server/src/main/java/org/pytorch/serve/device/utils/ROCmUtil.java new file mode 100644 index 0000000000..0b165469f7 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/utils/ROCmUtil.java @@ -0,0 +1,118 @@ +package org.pytorch.serve.device.utils; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.pytorch.serve.device.Accelerator; +import org.pytorch.serve.device.AcceleratorVendor; +import org.pytorch.serve.device.interfaces.IAcceleratorUtility; +import org.pytorch.serve.device.interfaces.IJsonSmiParser; + +public class ROCmUtil implements IAcceleratorUtility, IJsonSmiParser { + private static final Pattern GPU_ID_PATTERN = Pattern.compile("card(\\d+)"); + + @Override + public String getGpuEnvVariableName() { + return "HIP_VISIBLE_DEVICES"; + } + + @Override + public String[] getUtilizationSmiCommand() { + return new String[] { + "rocm-smi", + "--showid", + "--showproductname", + "--showuse", + "--showmemuse", + "--showmeminfo", + "vram", + "-P", + "--json" + }; + } + + @Override + public ArrayList getAvailableAccelerators( + LinkedHashSet availableAcceleratorIds) { + String[] smiCommand = {"rocm-smi", "--showproductname", "-P", "--json"}; + String jsonOutput = IAcceleratorUtility.callSMI(smiCommand); + + JsonObject rootObject = JsonParser.parseString(jsonOutput).getAsJsonObject(); + return jsonOutputToAccelerators(rootObject, availableAcceleratorIds); + } + + @Override + public ArrayList smiOutputToUpdatedAccelerators( + String smiOutput, LinkedHashSet parsedGpuIds) { + JsonObject rootObject = JsonParser.parseString(smiOutput).getAsJsonObject(); + return jsonOutputToAccelerators(rootObject, parsedGpuIds); + } + + @Override + public List extractAccelerators(JsonElement rootObject) { + JsonObject root = rootObject.getAsJsonObject(); + List accelerators = new ArrayList<>(); + for (String key : root.keySet()) { + if (GPU_ID_PATTERN.matcher(key).matches()) { + JsonObject accelerator = root.getAsJsonObject(key); + accelerator.addProperty("cardId", key); // Add the card ID to the JsonObject + accelerators.add(accelerator); + } + } + return accelerators; + } + + @Override + public Integer extractAcceleratorId(JsonObject jsonObject) { + String cardId = jsonObject.get("cardId").getAsString(); + Matcher matcher = GPU_ID_PATTERN.matcher(cardId); + if (matcher.matches()) { + return Integer.parseInt(matcher.group(1)); + } + return null; + } + + @Override + public Accelerator jsonObjectToAccelerator(JsonObject jsonObject) { + // Check if required field exists + if (!jsonObject.has("Card Series")) { + throw new IllegalArgumentException("Missing required field: Card Series"); + } + + String model = jsonObject.get("Card Series").getAsString(); + Integer acceleratorId = extractAcceleratorId(jsonObject); + Accelerator accelerator = new Accelerator(model, AcceleratorVendor.AMD, acceleratorId); + + // Set optional fields using GSON's has() method + if (jsonObject.has("GPU use (%)")) { + accelerator.setUsagePercentage( + Float.parseFloat(jsonObject.get("GPU use (%)").getAsString())); + } + + if (jsonObject.has("GPU Memory Allocated (VRAM%)")) { + accelerator.setMemoryUtilizationPercentage( + Float.parseFloat(jsonObject.get("GPU Memory Allocated (VRAM%)").getAsString())); + } + + if (jsonObject.has("VRAM Total Memory (B)")) { + String totalMemoryStr = jsonObject.get("VRAM Total Memory (B)").getAsString().strip(); + Long totalMemoryBytes = Long.parseLong(totalMemoryStr); + accelerator.setMemoryAvailableMegaBytes( + IAcceleratorUtility.bytesToMegabytes(totalMemoryBytes)); + } + + if (jsonObject.has("VRAM Total Used Memory (B)")) { + String usedMemoryStr = jsonObject.get("VRAM Total Used Memory (B)").getAsString(); + Long usedMemoryBytes = Long.parseLong(usedMemoryStr); + accelerator.setMemoryUtilizationMegabytes( + IAcceleratorUtility.bytesToMegabytes(usedMemoryBytes)); + } + + return accelerator; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/utils/XpuUtil.java b/frontend/server/src/main/java/org/pytorch/serve/device/utils/XpuUtil.java new file mode 100644 index 0000000000..2ec2900035 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/device/utils/XpuUtil.java @@ -0,0 +1,90 @@ +package org.pytorch.serve.device.utils; + +import java.util.ArrayList; +import java.util.LinkedHashSet; +import org.pytorch.serve.device.Accelerator; +import org.pytorch.serve.device.AcceleratorVendor; +import org.pytorch.serve.device.interfaces.IAcceleratorUtility; +import org.pytorch.serve.device.interfaces.ICsvSmiParser; + +public class XpuUtil implements IAcceleratorUtility, ICsvSmiParser { + + @Override + public String getGpuEnvVariableName() { + return "XPU_VISIBLE_DEVICES"; + } + + @Override + public ArrayList getAvailableAccelerators( + final LinkedHashSet availableAcceleratorIds) { + final String[] smiCommand = { + "xpu-smi", + "discovery", + "--dump", // output as csv + String.join( + ",", + "1", // device Id + "2", // Device name + "16" // Memory physical size + ) + }; + final String smiOutput = IAcceleratorUtility.callSMI(smiCommand); + + final String acceleratorEnv = getGpuEnvVariableName(); + final String requestedAccelerators = System.getenv(acceleratorEnv); + final LinkedHashSet parsedAcceleratorIds = + IAcceleratorUtility.parseVisibleDevicesEnv(requestedAccelerators); + + return csvSmiOutputToAccelerators( + smiOutput, parsedAcceleratorIds, this::parseDiscoveryOutput); + } + + @Override + public final ArrayList smiOutputToUpdatedAccelerators( + final String smiOutput, final LinkedHashSet parsedGpuIds) { + return csvSmiOutputToAccelerators(smiOutput, parsedGpuIds, this::parseUtilizationOutput); + } + + @Override + public String[] getUtilizationSmiCommand() { + // https://intel.github.io/xpumanager/smi_user_guide.html#get-the-device-real-time-statistics + // Timestamp, DeviceId, GPU Utilization (%), GPU Memory Utilization (%) + // 06:14:46.000, 0, 0.00, 14.61 + // 06:14:47.000, 1, 0.00, 14.59 + final String[] smiCommand = { + "xpu-smi", + "dump", + "-d -1", // all devices + "-n 1", // one dump + "-m", // metrics + String.join( + ",", + "0", // GPU Utilization (%), GPU active time of the elapsed time, per tile or + // device. + // Device-level is the average value of tiles for multi-tiles. + "5" // GPU Memory Utilization (%), per tile or device. Device-level is the + // average + // value of tiles for multi-tiles. + ) + }; + + return smiCommand; + } + + private Accelerator parseDiscoveryOutput(String[] parts) { + final int acceleratorId = Integer.parseInt(parts[1].trim()); + final String deviceName = parts[2].trim(); + logger.debug("Found accelerator at index: {}, Card name: {}", acceleratorId, deviceName); + return new Accelerator(deviceName, AcceleratorVendor.INTEL, acceleratorId); + } + + private Accelerator parseUtilizationOutput(String[] parts) { + final int acceleratorId = Integer.parseInt(parts[1].trim()); + final Float usagePercentage = Float.parseFloat(parts[2]); + final Float memoryUsagePercentage = Float.parseFloat(parts[3]); + Accelerator accelerator = new Accelerator("", AcceleratorVendor.INTEL, acceleratorId); + accelerator.setUsagePercentage(usagePercentage); + accelerator.setMemoryUtilizationPercentage(memoryUsagePercentage); + return accelerator; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java index 70f5a1c644..12a00d57d0 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java @@ -374,7 +374,8 @@ public static String getWorkerStatus() { } else if ((numWorking == 0) && (numScaled > 0)) { response = "Unhealthy"; } - // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy" and + // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy" + // and // "Unhealthy" return response; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 791dac511c..ab8693185e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -5,11 +5,9 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; -import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.InputStream; -import java.io.InputStreamReader; import java.lang.reflect.Field; import java.lang.reflect.Type; import java.net.InetAddress; @@ -27,7 +25,6 @@ import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; -import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collection; @@ -46,6 +43,7 @@ import org.apache.commons.cli.Options; import org.apache.commons.io.IOUtils; import org.pytorch.serve.archive.model.Manifest; +import org.pytorch.serve.device.SystemInfo; import org.pytorch.serve.metrics.MetricBuilder; import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer; import org.pytorch.serve.snapshot.SnapshotSerializerFactory; @@ -53,8 +51,10 @@ import org.slf4j.LoggerFactory; public final class ConfigManager { - // Variables that can be configured through config.properties and Environment Variables - // NOTE: Variables which can be configured through environment variables **SHOULD** have a + // Variables that can be configured through config.properties and Environment + // Variables + // NOTE: Variables which can be configured through environment variables + // **SHOULD** have a // "TS_" prefix private static final String TS_DEBUG = "debug"; @@ -128,7 +128,8 @@ public final class ConfigManager { private static final String TS_DISABLE_TOKEN_AUTHORIZATION = "disable_token_authorization"; private static final String TS_ENABLE_MODEL_API = "enable_model_api"; - // Configuration which are not documented or enabled through environment variables + // Configuration which are not documented or enabled through environment + // variables private static final String USE_NATIVE_IO = "use_native_io"; private static final String IO_RATIO = "io_ratio"; private static final String METRIC_TIME_INTERVAL = "metric_time_interval"; @@ -176,10 +177,13 @@ public final class ConfigManager { private String headerKeySequenceStart; private String headerKeySequenceEnd; + public SystemInfo systemInfo; + private static final Logger logger = LoggerFactory.getLogger(ConfigManager.class); private ConfigManager(Arguments args) throws IOException { prop = new Properties(); + this.systemInfo = new SystemInfo(); this.snapshotDisabled = args.isSnapshotDisabled(); String version = readFile(getModelServerHome() + "/ts/version.txt"); @@ -271,7 +275,7 @@ private ConfigManager(Arguments args) throws IOException { TS_NUMBER_OF_GPU, String.valueOf( Integer.min( - getAvailableGpu(), + this.systemInfo.getNumberOfAccelerators(), getIntProperty(TS_NUMBER_OF_GPU, Integer.MAX_VALUE)))); String pythonExecutable = args.getPythonExecutable(); @@ -931,83 +935,6 @@ private static String getCanonicalPath(String path) { return getCanonicalPath(new File(path)); } - private static int getAvailableGpu() { - try { - - List gpuIds = new ArrayList<>(); - String visibleCuda = System.getenv("CUDA_VISIBLE_DEVICES"); - if (visibleCuda != null && !visibleCuda.isEmpty()) { - String[] ids = visibleCuda.split(","); - for (String id : ids) { - gpuIds.add(Integer.parseInt(id)); - } - } else if (System.getProperty("os.name").startsWith("Mac")) { - Process process = Runtime.getRuntime().exec("system_profiler SPDisplaysDataType"); - int ret = process.waitFor(); - if (ret != 0) { - return 0; - } - - BufferedReader reader = - new BufferedReader(new InputStreamReader(process.getInputStream())); - String line; - while ((line = reader.readLine()) != null) { - if (line.contains("Chipset Model:") && !line.contains("Apple M1")) { - return 0; - } - if (line.contains("Total Number of Cores:")) { - String[] parts = line.split(":"); - if (parts.length >= 2) { - return (Integer.parseInt(parts[1].trim())); - } - } - } - // No MPS devices detected - return 0; - } else { - - try { - Process process = - Runtime.getRuntime().exec("nvidia-smi --query-gpu=index --format=csv"); - int ret = process.waitFor(); - if (ret != 0) { - return 0; - } - List list = - IOUtils.readLines(process.getInputStream(), StandardCharsets.UTF_8); - if (list.isEmpty() || !"index".equals(list.get(0))) { - throw new AssertionError("Unexpected nvidia-smi response."); - } - for (int i = 1; i < list.size(); i++) { - gpuIds.add(Integer.parseInt(list.get(i))); - } - } catch (IOException | InterruptedException e) { - System.out.println("nvidia-smi not available or failed: " + e.getMessage()); - } - try { - Process process = Runtime.getRuntime().exec("xpu-smi discovery --dump 1"); - int ret = process.waitFor(); - if (ret != 0) { - return 0; - } - List list = - IOUtils.readLines(process.getInputStream(), StandardCharsets.UTF_8); - if (list.isEmpty() || !list.get(0).contains("Device ID")) { - throw new AssertionError("Unexpected xpu-smi response."); - } - for (int i = 1; i < list.size(); i++) { - gpuIds.add(Integer.parseInt(list.get(i))); - } - } catch (IOException | InterruptedException e) { - logger.debug("xpu-smi not available or failed: " + e.getMessage()); - } - } - return gpuIds.size(); - } catch (IOException | InterruptedException e) { - return 0; - } - } - public List getAllowedUrls() { String allowedURL = prop.getProperty(TS_ALLOWED_URLS, DEFAULT_TS_ALLOWED_URLS); return Arrays.asList(allowedURL.split(",")); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 74b31dfd24..177a515f32 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -15,6 +15,7 @@ import java.util.regex.Pattern; import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.archive.model.ModelConfig.ParallelType; +import org.pytorch.serve.device.AcceleratorVendor; import org.pytorch.serve.metrics.Metric; import org.pytorch.serve.metrics.MetricCache; import org.pytorch.serve.util.ConfigManager; @@ -135,6 +136,8 @@ private void startWorkerPython(int port, String deviceIds) attachRunner(argl, envp, port, deviceIds); } else { if (deviceIds != null) { + AcceleratorVendor visibleDeviceEnvName = + configManager.systemInfo.getAcceleratorVendor(); envp.add("CUDA_VISIBLE_DEVICES=" + deviceIds); } argl.add(EnvironmentUtils.getPythonRunTime(model)); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index bedf5fac3e..28c1eee18c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -9,13 +9,9 @@ import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; -import java.io.BufferedReader; import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; import java.net.HttpURLConnection; import java.net.SocketAddress; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -27,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; +import org.pytorch.serve.device.Accelerator; import org.pytorch.serve.job.Job; import org.pytorch.serve.job.RestJob; import org.pytorch.serve.metrics.IMetric; @@ -118,51 +115,13 @@ public WorkerState getState() { } public String getGpuUsage() { - Process process; StringBuffer gpuUsage = new StringBuffer(); if (gpuId >= 0) { try { - // TODO : add a generic code to capture gpu details for different devices instead of - // just NVIDIA - ProcessBuilder pb = - new ProcessBuilder( - "nvidia-smi", - "-i", - String.valueOf(gpuId), - "--query-gpu=utilization.gpu,utilization.memory,memory.used", - "--format=csv"); - - // Start the process - process = pb.start(); - process.waitFor(); - int exitCode = process.exitValue(); - if (exitCode != 0) { - gpuUsage.append("failed to obtained gpu usage"); - InputStream error = process.getErrorStream(); - for (int i = 0; i < error.available(); i++) { - logger.error("" + error.read()); - } - return gpuUsage.toString(); - } - InputStream stdout = process.getInputStream(); - BufferedReader reader = - new BufferedReader(new InputStreamReader(stdout, StandardCharsets.UTF_8)); - String line; - String[] headers = new String[3]; - Boolean firstLine = true; - while ((line = reader.readLine()) != null) { - if (firstLine) { - headers = line.split(","); - firstLine = false; - } else { - String[] values = line.split(","); - StringBuffer sb = new StringBuffer("gpuId::" + gpuId + " "); - for (int i = 0; i < headers.length; i++) { - sb.append(headers[i] + "::" + values[i].strip()); - } - gpuUsage.append(sb.toString()); - } - } + configManager.systemInfo.updateAcceleratorMetrics(); + Accelerator accelerator = + this.configManager.systemInfo.getAccelerators().get(gpuId); + return accelerator.utilizationToString(); } catch (Exception e) { gpuUsage.append("failed to obtained gpu usage"); logger.error("Exception Raised : " + e.toString()); @@ -333,7 +292,8 @@ public void run() { } } finally { // WorkerThread is running in thread pool, the thread will be assigned to next - // Runnable once this worker is finished. If currentThread keep holding the reference + // Runnable once this worker is finished. If currentThread keep holding the + // reference // of the thread, currentThread.interrupt() might kill next worker. for (int i = 0; i < backendChannel.size(); i++) { backendChannel.get(i).disconnect(); diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java index f419a26657..57f7e40679 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java @@ -1347,10 +1347,11 @@ public void testMetricManager() throws JsonParseException, InterruptedException // Wait till first value is read in int count = 0; while (metrics.isEmpty()) { - Thread.sleep(500); + Thread.sleep(1000); metrics = metricManager.getMetrics(); Assert.assertTrue(++count < 5); } + for (Metric metric : metrics) { if (metric.getMetricName().equals("CPUUtilization")) { Assert.assertEquals(metric.getUnit(), "Percent"); diff --git a/frontend/server/src/test/java/org/pytorch/serve/device/AcceleratorTest.java b/frontend/server/src/test/java/org/pytorch/serve/device/AcceleratorTest.java new file mode 100644 index 0000000000..3dd2e07107 --- /dev/null +++ b/frontend/server/src/test/java/org/pytorch/serve/device/AcceleratorTest.java @@ -0,0 +1,76 @@ +package org.pytorch.serve.device; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class AcceleratorTest { + + @Test + public void testAcceleratorConstructor() { + Accelerator accelerator = new Accelerator("TestGPU", AcceleratorVendor.NVIDIA, 0); + Assert.assertEquals(accelerator.getAcceleratorModel(), "TestGPU"); + Assert.assertEquals(accelerator.getVendor(), AcceleratorVendor.NVIDIA); + Assert.assertEquals(accelerator.getAcceleratorId(), Integer.valueOf(0)); + } + + @Test + public void testGettersAndSetters() { + Accelerator accelerator = new Accelerator("TestGPU", AcceleratorVendor.AMD, 1); + + accelerator.setMemoryAvailableMegaBytes(8192); + Assert.assertEquals(accelerator.getMemoryAvailableMegaBytes(), Integer.valueOf(8192)); + + accelerator.setUsagePercentage(75.5f); + Assert.assertEquals(accelerator.getUsagePercentage(), Float.valueOf(75.5f)); + + accelerator.setMemoryUtilizationPercentage(60.0f); + Assert.assertEquals(accelerator.getMemoryUtilizationPercentage(), Float.valueOf(60.0f)); + + accelerator.setMemoryUtilizationMegabytes(4096); + Assert.assertEquals(accelerator.getMemoryUtilizationMegabytes(), Integer.valueOf(4096)); + } + + @Test + public void testUtilizationToString() { + Accelerator accelerator = new Accelerator("TestGPU", AcceleratorVendor.NVIDIA, 2); + accelerator.setUsagePercentage(80.0f); + accelerator.setMemoryUtilizationPercentage(70.0f); + accelerator.setMemoryUtilizationMegabytes(5120); + + String expected = + "gpuId::2 utilization.gpu::80 % utilization.memory::70 % memory.used::5,120 MiB"; + Assert.assertEquals(accelerator.utilizationToString(), expected); + } + + @Test + public void testUpdateDynamicAttributes() { + Accelerator accelerator = new Accelerator("TestGPU", AcceleratorVendor.INTEL, 3); + accelerator.setUsagePercentage(42.42f); + accelerator.setMemoryUtilizationPercentage(1.0f); + accelerator.setMemoryUtilizationMegabytes(9999999); + Accelerator updated = new Accelerator("UpdatedGPU", AcceleratorVendor.INTEL, 3); + updated.setUsagePercentage(90.0f); + updated.setMemoryUtilizationPercentage(85.0f); + updated.setMemoryUtilizationMegabytes(6144); + + accelerator.updateDynamicAttributes(updated); + + Assert.assertEquals(accelerator.getUsagePercentage(), Float.valueOf(90.0f)); + Assert.assertEquals(accelerator.getMemoryUtilizationPercentage(), Float.valueOf(85.0f)); + Assert.assertEquals(accelerator.getMemoryUtilizationMegabytes(), Integer.valueOf(6144)); + + // Check that static attributes are not updated + Assert.assertEquals(accelerator.getAcceleratorModel(), "TestGPU"); + Assert.assertEquals(accelerator.getVendor(), AcceleratorVendor.INTEL); + Assert.assertEquals(accelerator.getAcceleratorId(), Integer.valueOf(3)); + } + + @Test + public void testAcceleratorVendorEnumValues() { + Assert.assertEquals(AcceleratorVendor.AMD.name(), "AMD"); + Assert.assertEquals(AcceleratorVendor.NVIDIA.name(), "NVIDIA"); + Assert.assertEquals(AcceleratorVendor.INTEL.name(), "INTEL"); + Assert.assertEquals(AcceleratorVendor.APPLE.name(), "APPLE"); + Assert.assertEquals(AcceleratorVendor.UNKNOWN.name(), "UNKNOWN"); + } +} diff --git a/frontend/server/src/test/java/org/pytorch/serve/device/SystemInfoTest.java b/frontend/server/src/test/java/org/pytorch/serve/device/SystemInfoTest.java new file mode 100644 index 0000000000..05521217f8 --- /dev/null +++ b/frontend/server/src/test/java/org/pytorch/serve/device/SystemInfoTest.java @@ -0,0 +1,47 @@ +package org.pytorch.serve.device; + +import java.util.LinkedHashSet; +import org.pytorch.serve.device.interfaces.IAcceleratorUtility; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class SystemInfoTest { + + @Test + public void testParseVisibleDevicesEnv() { + LinkedHashSet result = IAcceleratorUtility.parseVisibleDevicesEnv("0,1,2"); + Assert.assertEquals(result.size(), 3); + Assert.assertTrue(result.contains(0)); + Assert.assertTrue(result.contains(1)); + Assert.assertTrue(result.contains(2)); + + result = IAcceleratorUtility.parseVisibleDevicesEnv("0, 1, 2"); + Assert.assertEquals(result.size(), 3); + Assert.assertTrue(result.contains(0)); + Assert.assertTrue(result.contains(1)); + Assert.assertTrue(result.contains(2)); + + result = IAcceleratorUtility.parseVisibleDevicesEnv("0,0,2"); + Assert.assertEquals(result.size(), 2); + Assert.assertTrue(result.contains(0)); + Assert.assertTrue(result.contains(2)); + + result = IAcceleratorUtility.parseVisibleDevicesEnv(""); + Assert.assertTrue(result.isEmpty()); + + result = IAcceleratorUtility.parseVisibleDevicesEnv(null); + Assert.assertTrue(result.isEmpty()); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testParseVisibleDevicesEnvInvalidInput() { + IAcceleratorUtility.parseVisibleDevicesEnv("0,1,a"); + } + + @Test + public void testBytesToMegabytes() { + Assert.assertEquals(IAcceleratorUtility.bytesToMegabytes(1048576L), Integer.valueOf(1)); + Assert.assertEquals(IAcceleratorUtility.bytesToMegabytes(2097152L), Integer.valueOf(2)); + Assert.assertEquals(IAcceleratorUtility.bytesToMegabytes(0L), Integer.valueOf(0)); + } +} diff --git a/frontend/server/src/test/java/org/pytorch/serve/device/utils/AppleUtilTest.java b/frontend/server/src/test/java/org/pytorch/serve/device/utils/AppleUtilTest.java new file mode 100644 index 0000000000..c52e105fc4 --- /dev/null +++ b/frontend/server/src/test/java/org/pytorch/serve/device/utils/AppleUtilTest.java @@ -0,0 +1,121 @@ +package org.pytorch.serve.device.utils; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertEqualsNoOrder; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import java.io.FileReader; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import org.pytorch.serve.device.Accelerator; +import org.pytorch.serve.device.AcceleratorVendor; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class AppleUtilTest { + + private AppleUtil appleUtil; + private String jsonStringPath; + private JsonObject sampleOutputJson; + + @BeforeClass + public void setUp() { + appleUtil = new AppleUtil(); + jsonStringPath = "src/test/resources/metrics/sample_apple_smi.json"; + + try { + FileReader reader = new FileReader(jsonStringPath); + JsonElement jsonElement = JsonParser.parseReader(reader); + sampleOutputJson = jsonElement.getAsJsonObject(); + } catch (Exception e) { + e.printStackTrace(); + } + } + + @Test + public void testGetGpuEnvVariableName() { + assertNull(appleUtil.getGpuEnvVariableName()); + } + + @Test + public void testGetUtilizationSmiCommand() { + String[] expectedCommand = { + "system_profiler", "-json", "-detailLevel", "mini", "SPDisplaysDataType" + }; + assertEqualsNoOrder(appleUtil.getUtilizationSmiCommand(), expectedCommand); + } + + @Test + public void testJsonObjectToAccelerator() { + JsonObject gpuObject = + sampleOutputJson.getAsJsonArray("SPDisplaysDataType").get(0).getAsJsonObject(); + Accelerator accelerator = appleUtil.jsonObjectToAccelerator(gpuObject); + + assertNotNull(accelerator); + assertEquals(accelerator.getAcceleratorModel(), "Apple M1"); + assertEquals(accelerator.getVendor(), AcceleratorVendor.APPLE); + assertEquals(accelerator.getAcceleratorId(), Integer.valueOf(0)); + assertEquals(accelerator.getUsagePercentage(), Float.valueOf(0f)); + assertEquals(accelerator.getMemoryUtilizationPercentage(), Float.valueOf(0f)); + assertEquals(accelerator.getMemoryUtilizationMegabytes(), Integer.valueOf(0)); + } + + @Test + public void testExtractAcceleratorId() { + JsonObject gpuObject = + sampleOutputJson.getAsJsonArray("SPDisplaysDataType").get(0).getAsJsonObject(); + assertEquals(appleUtil.extractAcceleratorId(gpuObject), Integer.valueOf(0)); + } + + @Test + public void testExtractAccelerators() { + List accelerators = appleUtil.extractAccelerators(sampleOutputJson); + + assertEquals(accelerators.size(), 1); + assertEquals(accelerators.get(0).get("sppci_model").getAsString(), "Apple M1"); + } + + @Test + public void testSmiOutputToUpdatedAccelerators() { + LinkedHashSet parsedGpuIds = new LinkedHashSet<>(); + parsedGpuIds.add(0); + + ArrayList updatedAccelerators = + appleUtil.smiOutputToUpdatedAccelerators(sampleOutputJson.toString(), parsedGpuIds); + + assertEquals(updatedAccelerators.size(), 1); + Accelerator accelerator = updatedAccelerators.get(0); + assertEquals(accelerator.getAcceleratorModel(), "Apple M1"); + assertEquals(accelerator.getVendor(), AcceleratorVendor.APPLE); + assertEquals(accelerator.getAcceleratorId(), Integer.valueOf(0)); + } + + @Test + public void testGetAvailableAccelerators() { + LinkedHashSet availableAcceleratorIds = new LinkedHashSet<>(); + availableAcceleratorIds.add(0); + + // Mock the callSMI method to return our sample output + AppleUtil spyAppleUtil = + new AppleUtil() { + @Override + public String[] getUtilizationSmiCommand() { + return new String[] {"echo", sampleOutputJson.toString()}; + } + }; + + ArrayList availableAccelerators = + spyAppleUtil.getAvailableAccelerators(availableAcceleratorIds); + + assertEquals(availableAccelerators.size(), 1); + Accelerator accelerator = availableAccelerators.get(0); + assertEquals(accelerator.getAcceleratorModel(), "Apple M1"); + assertEquals(accelerator.getVendor(), AcceleratorVendor.APPLE); + assertEquals(accelerator.getAcceleratorId(), Integer.valueOf(0)); + } +} diff --git a/frontend/server/src/test/java/org/pytorch/serve/device/utils/CudaUtilTest.java b/frontend/server/src/test/java/org/pytorch/serve/device/utils/CudaUtilTest.java new file mode 100644 index 0000000000..76c2012658 --- /dev/null +++ b/frontend/server/src/test/java/org/pytorch/serve/device/utils/CudaUtilTest.java @@ -0,0 +1,132 @@ +package org.pytorch.serve.device.utils; + +import java.util.ArrayList; +import java.util.LinkedHashSet; +import org.pytorch.serve.device.Accelerator; +import org.pytorch.serve.device.AcceleratorVendor; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class CudaUtilTest { + + private CudaUtil cudaUtil = new CudaUtil(); + + @Test + public void testGetGpuEnvVariableName() { + Assert.assertEquals(cudaUtil.getGpuEnvVariableName(), "CUDA_VISIBLE_DEVICES"); + } + + @Test + public void testGetUtilizationSmiCommand() { + String[] expectedCommand = { + "nvidia-smi", + "--query-gpu=index,gpu_name,utilization.gpu,utilization.memory,memory.used", + "--format=csv,nounits" + }; + Assert.assertEquals(cudaUtil.getUtilizationSmiCommand(), expectedCommand); + } + + @Test + public void testSmiOutputToUpdatedAccelerators() { + String smiOutput = + "index,gpu_name,utilization.gpu,utilization.memory,memory.used\n" + + "0,NVIDIA GeForce RTX 3080,50,60,8000\n" + + "1,NVIDIA Tesla V100,75,80,16000\n"; + LinkedHashSet parsedGpuIds = new LinkedHashSet<>(java.util.Arrays.asList(0, 1)); + + ArrayList accelerators = + cudaUtil.smiOutputToUpdatedAccelerators(smiOutput, parsedGpuIds); + + Assert.assertEquals(accelerators.size(), 2); + + Accelerator accelerator1 = accelerators.get(0); + Assert.assertEquals((int) accelerator1.getAcceleratorId(), 0); + Assert.assertEquals(accelerator1.getAcceleratorModel(), "NVIDIA GeForce RTX 3080"); + Assert.assertEquals((float) accelerator1.getUsagePercentage(), 50f); + Assert.assertEquals((float) accelerator1.getMemoryUtilizationPercentage(), 60f); + Assert.assertEquals((int) accelerator1.getMemoryUtilizationMegabytes(), 8000); + + Accelerator accelerator2 = accelerators.get(1); + Assert.assertEquals((int) accelerator2.getAcceleratorId(), 1); + Assert.assertEquals(accelerator2.getAcceleratorModel(), "NVIDIA Tesla V100"); + Assert.assertEquals((float) accelerator2.getUsagePercentage(), 75f); + Assert.assertEquals((float) accelerator2.getMemoryUtilizationPercentage(), 80f); + Assert.assertEquals((int) accelerator2.getMemoryUtilizationMegabytes(), 16000); + } + + @Test + public void testParseAccelerator() { + String[] parts = {"0", "NVIDIA GeForce RTX 3080"}; + Accelerator accelerator = cudaUtil.parseAccelerator(parts); + + Assert.assertEquals((int) accelerator.getAcceleratorId(), 0); + Assert.assertEquals(accelerator.getAcceleratorModel(), "NVIDIA GeForce RTX 3080"); + Assert.assertEquals(accelerator.getVendor(), AcceleratorVendor.NVIDIA); + } + + @Test + public void testParseAcceleratorWithDifferentId() { + String[] parts = {"2", "NVIDIA Tesla T4"}; + Accelerator accelerator = cudaUtil.parseAccelerator(parts); + + Assert.assertEquals((int) accelerator.getAcceleratorId(), 2); + Assert.assertEquals(accelerator.getAcceleratorModel(), "NVIDIA Tesla T4"); + Assert.assertEquals(accelerator.getVendor(), AcceleratorVendor.NVIDIA); + } + + @Test(expectedExceptions = NumberFormatException.class) + public void testParseAcceleratorWithInvalidId() { + String[] parts = {"invalid", "NVIDIA GeForce GTX 1080"}; + cudaUtil.parseAccelerator(parts); + } + + @Test + public void testParseUpdatedAccelerator() { + String[] parts = {"1", "NVIDIA Tesla V100", "75", "80", "16000"}; + Accelerator accelerator = cudaUtil.parseUpdatedAccelerator(parts); + + Assert.assertEquals((int) accelerator.getAcceleratorId(), 1); + Assert.assertEquals(accelerator.getAcceleratorModel(), "NVIDIA Tesla V100"); + Assert.assertEquals(accelerator.getVendor(), AcceleratorVendor.NVIDIA); + Assert.assertEquals((float) accelerator.getUsagePercentage(), 75f); + Assert.assertEquals((float) accelerator.getMemoryUtilizationPercentage(), 80f); + Assert.assertEquals((int) accelerator.getMemoryUtilizationMegabytes(), 16000); + } + + @Test + public void testParseUpdatedAcceleratorWithDifferentValues() { + String[] parts = {"3", "NVIDIA A100", "30.5", "45.7", "40960"}; + Accelerator accelerator = cudaUtil.parseUpdatedAccelerator(parts); + + Assert.assertEquals((int) accelerator.getAcceleratorId(), 3); + Assert.assertEquals(accelerator.getAcceleratorModel(), "NVIDIA A100"); + Assert.assertEquals(accelerator.getVendor(), AcceleratorVendor.NVIDIA); + Assert.assertEquals((float) accelerator.getUsagePercentage(), 30.5f); + Assert.assertEquals((float) accelerator.getMemoryUtilizationPercentage(), 45.7f); + Assert.assertEquals((int) accelerator.getMemoryUtilizationMegabytes(), 40960); + } + + @Test(expectedExceptions = NumberFormatException.class) + public void testParseUpdatedAcceleratorWithInvalidUsagePercentage() { + String[] parts = {"0", "NVIDIA GeForce RTX 2080", "invalid", "80", "8000"}; + cudaUtil.parseUpdatedAccelerator(parts); + } + + @Test(expectedExceptions = NumberFormatException.class) + public void testParseUpdatedAcceleratorWithInvalidMemoryUtilization() { + String[] parts = {"0", "NVIDIA GeForce RTX 2080", "75", "invalid", "8000"}; + cudaUtil.parseUpdatedAccelerator(parts); + } + + @Test(expectedExceptions = NumberFormatException.class) + public void testParseUpdatedAcceleratorWithInvalidMemoryUsage() { + String[] parts = {"0", "NVIDIA GeForce RTX 2080", "75", "80", "invalid"}; + cudaUtil.parseUpdatedAccelerator(parts); + } + + @Test(expectedExceptions = ArrayIndexOutOfBoundsException.class) + public void testParseUpdatedAcceleratorWithInsufficientData() { + String[] parts = {"0", "NVIDIA GeForce RTX 2080", "75", "80"}; + cudaUtil.parseUpdatedAccelerator(parts); + } +} diff --git a/frontend/server/src/test/java/org/pytorch/serve/device/utils/ROCmUtilTest.java b/frontend/server/src/test/java/org/pytorch/serve/device/utils/ROCmUtilTest.java new file mode 100644 index 0000000000..26e48264f5 --- /dev/null +++ b/frontend/server/src/test/java/org/pytorch/serve/device/utils/ROCmUtilTest.java @@ -0,0 +1,143 @@ +package org.pytorch.serve.device.utils; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertEqualsNoOrder; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import java.io.FileReader; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import org.pytorch.serve.device.Accelerator; +import org.pytorch.serve.device.AcceleratorVendor; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class ROCmUtilTest { + + private ROCmUtil rocmUtil; + private String sampleDiscoveryJsonPath; + private String sampleMetricsJsonPath; + private String sampleUpdatedMetricsJsonPath; + private JsonObject sampleDiscoveryJsonObject; + private JsonObject sampleMetricsJsonObject; + private JsonObject sampleUpdatedMetricsJsonObject; + + @BeforeClass + public void setUp() { + rocmUtil = new ROCmUtil(); + sampleDiscoveryJsonPath = "src/test/resources/metrics/sample_amd_discovery.json"; + sampleMetricsJsonPath = "src/test/resources/metrics/sample_amd_metrics.json"; + sampleUpdatedMetricsJsonPath = "src/test/resources/metrics/sample_amd_updated_metrics.json"; + + try { + FileReader reader = new FileReader(sampleDiscoveryJsonPath); + JsonElement jsonElement = JsonParser.parseReader(reader); + sampleDiscoveryJsonObject = jsonElement.getAsJsonObject(); + + reader = new FileReader(sampleMetricsJsonPath); + jsonElement = JsonParser.parseReader(reader); + sampleMetricsJsonObject = jsonElement.getAsJsonObject(); + + reader = new FileReader(sampleUpdatedMetricsJsonPath); + jsonElement = JsonParser.parseReader(reader); + sampleUpdatedMetricsJsonObject = jsonElement.getAsJsonObject(); + + } catch (Exception e) { + e.printStackTrace(); + } + } + + @Test + public void testGetGpuEnvVariableName() { + assertEquals(rocmUtil.getGpuEnvVariableName(), "HIP_VISIBLE_DEVICES"); + } + + @Test + public void testGetUtilizationSmiCommand() { + String[] expectedCommand = { + "rocm-smi", + "--showid", + "--showproductname", + "--showuse", + "--showmemuse", + "--showmeminfo", + "vram", + "-P", + "--json" + }; + assertEqualsNoOrder(rocmUtil.getUtilizationSmiCommand(), expectedCommand); + } + + @Test + public void testExtractAccelerators() { + List accelerators = rocmUtil.extractAccelerators(sampleMetricsJsonObject); + assertEquals(accelerators.size(), 2); + assertEquals(accelerators.get(0).get("cardId").getAsString(), "card0"); + assertEquals(accelerators.get(1).get("cardId").getAsString(), "card1"); + } + + @Test + public void testExtractAcceleratorId() { + JsonObject card0Object = rocmUtil.extractAccelerators(sampleMetricsJsonObject).get(0); + JsonObject card1Object = rocmUtil.extractAccelerators(sampleMetricsJsonObject).get(1); + + Integer acceleratorId0 = rocmUtil.extractAcceleratorId(card0Object); + Integer acceleratorId1 = rocmUtil.extractAcceleratorId(card1Object); + + assertEquals(acceleratorId0, Integer.valueOf(0)); + assertEquals(acceleratorId1, Integer.valueOf(1)); + } + + @Test + public void testJsonMetricsObjectToAccelerator() { + JsonObject card0Object = rocmUtil.extractAccelerators(sampleMetricsJsonObject).get(0); + Accelerator accelerator = rocmUtil.jsonObjectToAccelerator(card0Object); + + assertEquals(accelerator.getAcceleratorId(), Integer.valueOf(0)); + assertEquals(accelerator.getAcceleratorModel(), "AMD INSTINCT MI250 (MCM) OAM AC MBA"); + assertEquals(accelerator.getVendor(), AcceleratorVendor.AMD); + assertEquals((float) accelerator.getUsagePercentage(), 50.0f); + assertEquals((float) accelerator.getMemoryUtilizationPercentage(), 75.0f); + assertEquals(accelerator.getMemoryAvailableMegaBytes(), Integer.valueOf(65520)); + assertEquals(accelerator.getMemoryUtilizationMegabytes(), Integer.valueOf(49140)); + } + + @Test + public void testJsonDiscoveryObjectToAccelerator() { + JsonObject card0Object = rocmUtil.extractAccelerators(sampleDiscoveryJsonObject).get(0); + Accelerator accelerator = rocmUtil.jsonObjectToAccelerator(card0Object); + + assertEquals(accelerator.getAcceleratorId(), Integer.valueOf(0)); + assertEquals(accelerator.getAcceleratorModel(), "AMD INSTINCT MI250 (MCM) OAM AC MBA"); + assertEquals(accelerator.getVendor(), AcceleratorVendor.AMD); + } + + @Test + public void testSmiOutputToUpdatedAccelerators() { + String smiOutput = sampleMetricsJsonObject.toString(); + String updatedMetrics = sampleUpdatedMetricsJsonObject.toString(); + LinkedHashSet parsedGpuIds = new LinkedHashSet<>(); + parsedGpuIds.add(0); + parsedGpuIds.add(1); + + ArrayList accelerators = + rocmUtil.smiOutputToUpdatedAccelerators(smiOutput, parsedGpuIds); + accelerators = rocmUtil.smiOutputToUpdatedAccelerators(updatedMetrics, parsedGpuIds); + + assertEquals(accelerators.size(), 2); + + System.out.println(accelerators.toString()); + + Accelerator accelerator0 = accelerators.get(0); + assertEquals(accelerator0.getAcceleratorId(), Integer.valueOf(0)); + assertEquals(accelerator0.getAcceleratorModel(), "AMD INSTINCT MI250 (MCM) OAM AC MBA"); + assertEquals(accelerator0.getVendor(), AcceleratorVendor.AMD); + assertEquals((float) accelerator0.getUsagePercentage(), 25.0f); + assertEquals((float) accelerator0.getMemoryUtilizationPercentage(), 25.0f); + assertEquals(accelerator0.getMemoryAvailableMegaBytes(), Integer.valueOf(65520)); + assertEquals(accelerator0.getMemoryUtilizationMegabytes(), Integer.valueOf(49140)); + } +} diff --git a/frontend/server/src/test/java/org/pytorch/serve/device/utils/XpuUtilTest.java b/frontend/server/src/test/java/org/pytorch/serve/device/utils/XpuUtilTest.java new file mode 100644 index 0000000000..5656a1660c --- /dev/null +++ b/frontend/server/src/test/java/org/pytorch/serve/device/utils/XpuUtilTest.java @@ -0,0 +1,138 @@ +package org.pytorch.serve.device.utils; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.testng.Assert.*; + +import java.util.ArrayList; +import java.util.LinkedHashSet; +import org.pytorch.serve.device.Accelerator; +import org.testng.annotations.*; + +public class XpuUtilTest { + + private XpuUtil xpuUtil; + + @BeforeMethod + public void setUp() { + xpuUtil = new XpuUtil(); + } + + @Test + public void testGetGpuEnvVariableName() { + assertEquals( + xpuUtil.getGpuEnvVariableName(), + "XPU_VISIBLE_DEVICES", + "GPU environment variable name should be XPU_VISIBLE_DEVICES"); + } + + @Test + public void testGetUtilizationSmiCommand() { + String[] expectedCommand = {"xpu-smi", "dump", "-d -1", "-n 1", "-m", "0,5"}; + assertArrayEquals( + xpuUtil.getUtilizationSmiCommand(), + expectedCommand, + "Utilization SMI command should match expected"); + } + + @Test + public void testSmiOutputToUpdatedAccelerators() { + String smiOutput = + "Timestamp,DeviceId,GPU Utilization (%),GPU Memory Utilization (%)\n" + + "06:14:46.000,0,50.00,75.50\n" + + "06:14:47.000,1,25.00,60.25"; + + LinkedHashSet parsedGpuIds = new LinkedHashSet<>(); + parsedGpuIds.add(0); + parsedGpuIds.add(1); + + ArrayList updatedAccelerators = + xpuUtil.smiOutputToUpdatedAccelerators(smiOutput, parsedGpuIds); + + assertEquals(updatedAccelerators.size(), 2, "Should return 2 updated accelerators"); + assertEquals( + (int) updatedAccelerators.get(0).getAcceleratorId(), + 0, + "First accelerator should have ID 0"); + assertEquals( + (int) updatedAccelerators.get(1).getAcceleratorId(), + 1, + "Second accelerator should have ID 1"); + assertEquals( + (float) updatedAccelerators.get(0).getUsagePercentage(), + 50.00f, + 0.01, + "GPU utilization should match for first accelerator"); + assertEquals( + (float) updatedAccelerators.get(0).getMemoryUtilizationPercentage(), + 75.50f, + 0.01, + "Memory utilization should match for first accelerator"); + assertEquals( + (float) updatedAccelerators.get(1).getUsagePercentage(), + 25.00f, + 0.01, + "GPU utilization should match for second accelerator"); + assertEquals( + (float) updatedAccelerators.get(1).getMemoryUtilizationPercentage(), + 60.25f, + 0.01, + "Memory utilization should match for second accelerator"); + } + + @Test + public void testSmiOutputToUpdatedAcceleratorsWithFilteredIds() { + String smiOutput = + "Timestamp,DeviceId,GPU Utilization (%),GPU Memory Utilization (%)\n" + + "06:14:46.000,0,50.00,75.50\n" + + "06:14:47.000,1,25.00,60.25\n" + + "06:14:48.000,2,30.00,70.00"; + + LinkedHashSet parsedGpuIds = new LinkedHashSet<>(); + parsedGpuIds.add(0); + parsedGpuIds.add(2); + + ArrayList updatedAccelerators = + xpuUtil.smiOutputToUpdatedAccelerators(smiOutput, parsedGpuIds); + + assertEquals(updatedAccelerators.size(), 2, "Should return 2 updated accelerators"); + assertEquals( + (int) updatedAccelerators.get(0).getAcceleratorId(), + 0, + "First accelerator should have ID 0"); + assertEquals( + (int) updatedAccelerators.get(1).getAcceleratorId(), + 2, + "Second accelerator should have ID 2"); + assertEquals( + (float) updatedAccelerators.get(0).getUsagePercentage(), + 50.00f, + 0.01, + "GPU utilization should match for first accelerator"); + assertEquals( + (float) updatedAccelerators.get(0).getMemoryUtilizationPercentage(), + 75.50f, + 0.01, + "Memory utilization should match for first accelerator"); + assertEquals( + (float) updatedAccelerators.get(1).getUsagePercentage(), + 30.00f, + 0.01, + "GPU utilization should match for third accelerator"); + assertEquals( + (float) updatedAccelerators.get(1).getMemoryUtilizationPercentage(), + 70.00f, + 0.01, + "Memory utilization should match for third accelerator"); + } + + @Test + public void testSmiOutputToUpdatedAcceleratorsWithInvalidInput() { + String invalidSmiOutput = "Invalid SMI output"; + LinkedHashSet parsedGpuIds = new LinkedHashSet<>(); + parsedGpuIds.add(0); + + ArrayList accelerators = + xpuUtil.smiOutputToUpdatedAccelerators(invalidSmiOutput, parsedGpuIds); + assertEquals(accelerators.size(), 0); + } +} diff --git a/frontend/server/src/test/resources/metrics/sample_amd_discovery.json b/frontend/server/src/test/resources/metrics/sample_amd_discovery.json new file mode 100644 index 0000000000..e69c0e6439 --- /dev/null +++ b/frontend/server/src/test/resources/metrics/sample_amd_discovery.json @@ -0,0 +1,26 @@ +{ + "card0": { + "Average Graphics Package Power (W)": "92.0", + "Card Series": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Card Model": "0x740c", + "Card Vendor": "Advanced Micro Devices, Inc. [AMD/ATI]", + "Card SKU": "D65210V", + "Subsystem ID": "0x0b0c", + "Device Rev": "0x01", + "Node ID": "4", + "GUID": "11743", + "GFX Version": "gfx9010" + }, + "card1": { + "Average Graphics Package Power (W)": "N/A (Secondary die)", + "Card Series": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Card Model": "0x740c", + "Card Vendor": "Advanced Micro Devices, Inc. [AMD/ATI]", + "Card SKU": "D65210V", + "Subsystem ID": "0x0b0c", + "Device Rev": "0x01", + "Node ID": "5", + "GUID": "61477", + "GFX Version": "gfx9010" + } +} \ No newline at end of file diff --git a/frontend/server/src/test/resources/metrics/sample_amd_metrics.json b/frontend/server/src/test/resources/metrics/sample_amd_metrics.json new file mode 100644 index 0000000000..688403d772 --- /dev/null +++ b/frontend/server/src/test/resources/metrics/sample_amd_metrics.json @@ -0,0 +1,46 @@ +{ + "card0": { + "Device Name": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Device ID": "0x740c", + "Device Rev": "0x01", + "Subsystem ID": "0x0b0c", + "GUID": "11743", + "Average Graphics Package Power (W)": "92.0", + "GPU use (%)": "50", + "GFX Activity": "62827955", + "GPU Memory Allocated (VRAM%)": "75", + "GPU Memory Read/Write Activity (%)": "0", + "Memory Activity": "17044038", + "Avg. Memory Bandwidth": "0", + "VRAM Total Memory (B)": "68702699520", + "VRAM Total Used Memory (B)": "51527024640", + "Card Series": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Card Model": "0x740c", + "Card Vendor": "Advanced Micro Devices, Inc. [AMD/ATI]", + "Card SKU": "D65210V", + "Node ID": "4", + "GFX Version": "gfx9010" + }, + "card1": { + "Device Name": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Device ID": "0x740c", + "Device Rev": "0x01", + "Subsystem ID": "0x0b0c", + "GUID": "61477", + "Average Graphics Package Power (W)": "N/A (Secondary die)", + "GPU use (%)": "50", + "GFX Activity": "46030661", + "GPU Memory Allocated (VRAM%)": "50", + "GPU Memory Read/Write Activity (%)": "0", + "Memory Activity": "10645369", + "Avg. Memory Bandwidth": "0", + "VRAM Total Memory (B)": "68702699520", + "VRAM Total Used Memory (B)": "51527024640", + "Card Series": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Card Model": "0x740c", + "Card Vendor": "Advanced Micro Devices, Inc. [AMD/ATI]", + "Card SKU": "D65210V", + "Node ID": "5", + "GFX Version": "gfx9010" + } +} \ No newline at end of file diff --git a/frontend/server/src/test/resources/metrics/sample_amd_updated_metrics.json b/frontend/server/src/test/resources/metrics/sample_amd_updated_metrics.json new file mode 100644 index 0000000000..3cea9de9bd --- /dev/null +++ b/frontend/server/src/test/resources/metrics/sample_amd_updated_metrics.json @@ -0,0 +1,46 @@ +{ + "card0": { + "Device Name": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Device ID": "0x740c", + "Device Rev": "0x01", + "Subsystem ID": "0x0b0c", + "GUID": "11743", + "Average Graphics Package Power (W)": "92.0", + "GPU use (%)": "25", + "GFX Activity": "62827955", + "GPU Memory Allocated (VRAM%)": "25", + "GPU Memory Read/Write Activity (%)": "0", + "Memory Activity": "17044038", + "Avg. Memory Bandwidth": "0", + "VRAM Total Memory (B)": "68702699520", + "VRAM Total Used Memory (B)": "51527024640", + "Card Series": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Card Model": "0x740c", + "Card Vendor": "Advanced Micro Devices, Inc. [AMD/ATI]", + "Card SKU": "D65210V", + "Node ID": "4", + "GFX Version": "gfx9010" + }, + "card1": { + "Device Name": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Device ID": "0x740c", + "Device Rev": "0x01", + "Subsystem ID": "0x0b0c", + "GUID": "61477", + "Average Graphics Package Power (W)": "N/A (Secondary die)", + "GPU use (%)": "50", + "GFX Activity": "46030661", + "GPU Memory Allocated (VRAM%)": "50", + "GPU Memory Read/Write Activity (%)": "0", + "Memory Activity": "10645369", + "Avg. Memory Bandwidth": "0", + "VRAM Total Memory (B)": "68702699520", + "VRAM Total Used Memory (B)": "51527024640", + "Card Series": "AMD INSTINCT MI250 (MCM) OAM AC MBA", + "Card Model": "0x740c", + "Card Vendor": "Advanced Micro Devices, Inc. [AMD/ATI]", + "Card SKU": "D65210V", + "Node ID": "5", + "GFX Version": "gfx9010" + } +} \ No newline at end of file diff --git a/frontend/server/src/test/resources/metrics/sample_apple_smi.json b/frontend/server/src/test/resources/metrics/sample_apple_smi.json new file mode 100644 index 0000000000..8562248586 --- /dev/null +++ b/frontend/server/src/test/resources/metrics/sample_apple_smi.json @@ -0,0 +1,33 @@ +{ + "SPDisplaysDataType": [ + { + "_name": "kHW_AppleM1Item", + "spdisplays_metalfamily": "spdisplays_mtlgpufamilyapple7", + "spdisplays_ndrvs": [ + { + "_name": "Color LCD", + "_spdisplays_display-product-id": "a045", + "_spdisplays_display-serial-number": "fd626d62", + "_spdisplays_display-vendor-id": "610", + "_spdisplays_display-week": "0", + "_spdisplays_display-year": "0", + "_spdisplays_displayID": "1", + "_spdisplays_pixels": "2880 x 1800", + "_spdisplays_resolution": "1440 x 900 @ 60.00Hz", + "spdisplays_ambient_brightness": "spdisplays_no", + "spdisplays_connection_type": "spdisplays_internal", + "spdisplays_main": "spdisplays_yes", + "spdisplays_mirror": "spdisplays_off", + "spdisplays_online": "spdisplays_yes", + "spdisplays_pixelresolution": "2880 x 1800", + "spdisplays_resolution": "1440 x 900 @ 60.00Hz" + } + ], + "spdisplays_vendor": "sppci_vendor_Apple", + "sppci_bus": "spdisplays_builtin", + "sppci_cores": "7", + "sppci_device_type": "spdisplays_gpu", + "sppci_model": "Apple M1" + } + ] +} \ No newline at end of file diff --git a/frontend/server/testng.xml b/frontend/server/testng.xml index ee898ca7f9..8a5335ca91 100644 --- a/frontend/server/testng.xml +++ b/frontend/server/testng.xml @@ -16,5 +16,9 @@ - + + + + + diff --git a/kubernetes/kserve/tests/scripts/test_mnist.sh b/kubernetes/kserve/tests/scripts/test_mnist.sh index 5c3532e1e5..7771d2cea9 100755 --- a/kubernetes/kserve/tests/scripts/test_mnist.sh +++ b/kubernetes/kserve/tests/scripts/test_mnist.sh @@ -3,14 +3,13 @@ set -o errexit -o nounset -o pipefail device=$1 +TEST_GPU="false" if [ "$device" = "gpu" ]; then TEST_GPU="true" -else - TEST_GPU="false" fi -function validate_gpu_memory_usage() { +function validate_gpu_memory_usage_nvidia() { echo "Validating GPU memory usage..." memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits) @@ -32,6 +31,52 @@ function validate_gpu_memory_usage() { fi } +function validate_gpu_memory_usage_amd() { + # Capture the output of the command into an array, line by line + mapfile -t memory_usage < <(amd-smi metric --mem-usage --csv) + memory_above_zero=false + + for row in "${memory_usage[@]}"; do + # Read each column in the row separated by commas + IFS=',' read -r -a columns <<< "$row" + if [ "${columns[0]}" == "gpu" ]; then + continue + fi + + if [ "${columns[2]}" -gt 0 ]; then + memory_above_zero=true + break + fi + done + + if [ "$memory_above_zero" = true ]; then + echo "GPU memory usage is greater than 0, proceeding with the tests." + else + echo "✘ GPU memory usage is 0, indicating no GPU activity. Test failed." + delete_minikube_cluster + exit 1 + fi +} + +function validate_gpu_memory_usage() { + if [ "$GPU_TYPE" = "nvidia-smi" ]; then + validate_gpu_memory_usage_nvidia + elif [ "$GPU_TYPE" = "amd-smi" ]; then + validate_gpu_memory_usage_amd + fi +} + +function detect_gpu_smi() { + for cmd in nvidia-smi amd-smi system_profiler xpu-smi; do + if command -v "$cmd" && "$cmd" > /dev/null 2>&1; then + echo "$cmd found and able to communicate with GPU(s)." + GPU_TYPE=$cmd + return + fi + done + echo "Cannot communicate with GPU(s)." +} + function start_minikube_cluster() { echo "Removing any previous Kubernetes cluster" minikube delete @@ -204,6 +249,7 @@ install_kserve echo "MNIST KServe V2 test begin" if [ "$TEST_GPU" = "true" ]; then deploy_cluster "kubernetes/kserve/tests/configs/mnist_v2_gpu.yaml" "torchserve-mnist-v2-predictor" + detect_gpu_smi validate_gpu_memory_usage else deploy_cluster "kubernetes/kserve/tests/configs/mnist_v2_cpu.yaml" "torchserve-mnist-v2-predictor" diff --git a/requirements/common_rocm.txt b/requirements/common_rocm.txt new file mode 100644 index 0000000000..20789dd473 --- /dev/null +++ b/requirements/common_rocm.txt @@ -0,0 +1 @@ +pyrsmi; sys_platform == 'linux' \ No newline at end of file diff --git a/requirements/torch_rocm60.txt b/requirements/torch_rocm60.txt new file mode 100644 index 0000000000..f07063f625 --- /dev/null +++ b/requirements/torch_rocm60.txt @@ -0,0 +1,5 @@ +--find-links https://download.pytorch.org/whl/torch_stable.html +-r torch_common.txt +torch==2.3.1+rocm6.0; sys_platform == 'linux' +torchvision==0.18.1+rocm6.0; sys_platform == 'linux' +torchaudio==2.3.1+rocm6.0; sys_platform == 'linux' diff --git a/requirements/torch_rocm61.txt b/requirements/torch_rocm61.txt new file mode 100644 index 0000000000..0030b05f7f --- /dev/null +++ b/requirements/torch_rocm61.txt @@ -0,0 +1,4 @@ +--index-url https://download.pytorch.org/whl/rocm6.1 +torch==2.4.1+rocm6.1; sys_platform == 'linux' +torchvision==0.19.1+rocm6.1; sys_platform == 'linux' +torchaudio==2.4.1+rocm6.1; sys_platform == 'linux' diff --git a/ts/metrics/metric_collector.py b/ts/metrics/metric_collector.py index 9e1f9d698c..9032d42e41 100644 --- a/ts/metrics/metric_collector.py +++ b/ts/metrics/metric_collector.py @@ -15,15 +15,15 @@ parser = argparse.ArgumentParser() parser.add_argument( - "--gpu", + "--gpus", action="store", - help="number of GPU", + help="number of GPUs", type=int ) arguments = parser.parse_args() logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) - system_metrics.collect_all(sys.modules['ts.metrics.system_metrics'], arguments.gpu) + system_metrics.collect_all(sys.modules['ts.metrics.system_metrics'], arguments.gpus) check_process_mem_usage(sys.stdin) diff --git a/ts/metrics/system_metrics.py b/ts/metrics/system_metrics.py index e0a21f1c4f..5dda2476e6 100644 --- a/ts/metrics/system_metrics.py +++ b/ts/metrics/system_metrics.py @@ -1,11 +1,13 @@ """ Module to collect system metrics for front-end """ + import logging import types from builtins import str import psutil +import torch from ts.metrics.dimension import Dimension from ts.metrics.metric import Metric @@ -49,74 +51,98 @@ def disk_available(): system_metrics.append(Metric("DiskAvailable", data, "GB", dimension)) -def gpu_utilization(num_of_gpu): +def collect_gpu_metrics(num_of_gpus): """ - Collect gpu metrics. - - :param num_of_gpu: + Collect GPU metrics. Supports NVIDIA and AMD GPUs. + :param num_of_gpus: Total number of available GPUs. :return: """ - if num_of_gpu <= 0: - return - - # pylint: disable=wrong-import-position - # pylint: disable=import-outside-toplevel - import nvgpu - import pynvml - from nvgpu import list_gpus + for gpu_index in range(num_of_gpus): + if torch.version.cuda: + free, total = torch.cuda.mem_get_info(gpu_index) + mem_used = (total - free) // 1024**2 + gpu_mem_utilization = torch.cuda.memory_usage(gpu_index) + gpu_utilization = torch.cuda.utilization(gpu_index) + elif torch.version.hip: + # There is currently a bug in + # https://github.com/pytorch/pytorch/blob/838958de94ed3b9021ddb395fe3e7ed22a60b06c/torch/cuda/__init__.py#L1171 + # which does not capture the rate/percentage correctly. + # Otherwise same methods could be used. + # https://rocm.docs.amd.com/projects/amdsmi/en/latest/how-to/using-amdsmi-for-python.html#amdsmi-get-gpu-activity + import amdsmi + + try: + amdsmi.amdsmi_init() + + handle = amdsmi.amdsmi_get_processor_handles()[gpu_index] + mem_used = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] + engine_usage = amdsmi.amdsmi_get_gpu_activity(handle) + gpu_utilization = engine_usage["gfx_activity"] + gpu_mem_utilization = engine_usage["umc_activity"] + except amdsmi.AmdSmiException as e: + logging.error("Could not initialize AMD-SMI library.") + finally: + try: + amdsmi.amdsmi_shut_down() + except amdsmi.AmdSmiException as e: + logging.error("Could not shut down AMD-SMI library.") - # pylint: enable=wrong-import-position - # pylint: enable=import-outside-toplevel - - info = nvgpu.gpu_info() - for value in info: dimension_gpu = [ Dimension("Level", "Host"), - Dimension("device_id", value["index"]), + Dimension("device_id", gpu_index), ] system_metrics.append( Metric( "GPUMemoryUtilization", - value["mem_used_percent"], + gpu_mem_utilization, "percent", dimension_gpu, ) ) + system_metrics.append(Metric("GPUMemoryUsed", mem_used, "MB", dimension_gpu)) system_metrics.append( - Metric("GPUMemoryUsed", value["mem_used"], "MB", dimension_gpu) + Metric("GPUUtilization", gpu_utilization, "percent", dimension_gpu) ) - try: - statuses = list_gpus.device_statuses() - except pynvml.nvml.NVMLError_NotSupported: - logging.error("gpu device monitoring not supported") - statuses = [] - for idx, status in enumerate(statuses): - dimension_gpu = [Dimension("Level", "Host"), Dimension("device_id", idx)] - system_metrics.append( - Metric("GPUUtilization", status["utilization"], "percent", dimension_gpu) - ) +def gpu_utilization(num_of_gpus): + """ + Generic GPU utilization function that supports NVIDIA and AMD GPUs. + :param num_of_gpu: Total number of available GPUs. + :return: + """ + if num_of_gpus <= 0: + return + + if torch.cuda.is_available() and not (torch.version.cuda or torch.version.hip): + logging.error("No supported GPU detected.") + return + + if torch.cuda.is_available() and torch.version.cuda: + logging.info("Collecting NVIDIA GPU metrics...") + elif torch.cuda.is_available() and torch.version.hip: + logging.info("Collecting AMD GPU metrics...") + collect_gpu_metrics(num_of_gpus) -def collect_all(mod, num_of_gpu): + +def collect_all(mod, num_of_gpus): """ Collect all system metrics. :param mod: - :param num_of_gpu: + :param num_of_gpus: Total number of available GPUs. :return: """ - members = dir(mod) for i in members: value = getattr(mod, i) if isinstance(value, types.FunctionType) and value.__name__ not in ( "collect_all", - "log_msg", + "collect_gpu_metrics", ): if value.__name__ == "gpu_utilization": - value(num_of_gpu) + value(num_of_gpus) else: value() diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index fa4be5841c..1dad241922 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -41,7 +41,7 @@ if packaging.version.parse(torch.__version__) >= packaging.version.parse("2.0.0a"): PT2_AVAILABLE = True - if torch.cuda.is_available(): + if torch.cuda.is_available() and torch.version.cuda: # If Ampere enable tensor cores which will give better performance # Ideally get yourself an A10G or A100 for optimal performance if torch.cuda.get_device_capability() >= (8, 0): @@ -227,7 +227,7 @@ def initialize(self, context): if "compile" in pt2_value: compile_options = pt2_value["compile"] - if compile_options["enable"] == True: + if compile_options["enable"]: del compile_options["enable"] # if backend is not provided, compile will use its default, which is valid @@ -284,7 +284,7 @@ def initialize(self, context): self.model = self.model.to(memory_format=torch.channels_last) self.model = self.model.to(self.device) self.model = ipex.optimize(self.model) - logger.info(f"Compiled model with ipex") + logger.info("Compiled model with ipex") logger.debug("Model file %s loaded successfully", self.model_pt_path) @@ -364,7 +364,7 @@ def _use_torch_export_aot_compile(self): export_value = pt2_value.get("export", None) if isinstance(export_value, dict) and "aot_compile" in export_value: torch_export_aot_compile = ( - True if export_value["aot_compile"] == True else False + True if export_value["aot_compile"] else False ) return torch_export_aot_compile diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py index 20bd76599a..7fb7ba4d8f 100644 --- a/ts_scripts/install_dependencies.py +++ b/ts_scripts/install_dependencies.py @@ -94,7 +94,12 @@ def install_java(self): def install_nodejs(self): pass - def install_torch_packages(self, cuda_version): + def install_torch_packages(self, cuda_version=None, rocm_version=None): + if cuda_version and rocm_version: + raise ValueError( + "Cannot install both CUDA and ROCm dependencies, please pass only either one." + ) + if cuda_version: if platform.system() == "Darwin": print( @@ -110,6 +115,16 @@ def install_torch_packages(self, cuda_version): os.system( f"{sys.executable} -m pip install -U -r requirements/torch_{cuda_version}_{platform.system().lower()}.txt" ) + elif rocm_version: + if platform.system() in ["Darwin", "Windows"]: + print( + f"ROCm not supported on {platform.system()}. Refer https://pytorch.org/." + ) + sys.exit(1) + else: + os.system( + f"{sys.executable} -m pip install -U -r requirements/torch_{rocm_version}.txt" + ) elif args.neuronx: torch_neuronx_requirements_file = os.path.join( "requirements", "torch_neuronx_linux.txt" @@ -127,7 +142,9 @@ def install_torch_packages(self, cuda_version): f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt" ) - def install_python_packages(self, cuda_version, requirements_file_path, nightly): + def install_python_packages( + self, cuda_version, rocm_version, requirements_file_path, nightly + ): check = "where" if platform.system() == "Windows" else "which" if os.system(f"{check} conda") == 0: # conda install command should run before the pip install commands @@ -143,16 +160,23 @@ def install_python_packages(self, cuda_version, requirements_file_path, nightly) elif args.skip_torch_install: print("Skipping Torch installation") else: - self.install_torch_packages(cuda_version) + self.install_torch_packages( + cuda_version=cuda_version, rocm_version=rocm_version + ) # developer.txt also installs packages from common.txt os.system(f"{sys.executable} -m pip install -U -r {requirements_file_path}") - # Install dependencies for GPU + # Install dependencies for NVIDIA GPU if not isinstance(cuda_version, type(None)): gpu_requirements_file = os.path.join("requirements", "common_gpu.txt") os.system(f"{sys.executable} -m pip install -U -r {gpu_requirements_file}") + # Install dependencies for AMD GPU + if not isinstance(rocm_version, type(None)): + gpu_requirements_file = os.path.join("requirements", "common_rocm.txt") + os.system(f"{sys.executable} -m pip install -U -r {gpu_requirements_file}") + # Install dependencies for Inferentia2 if args.neuronx: neuronx_requirements_file = os.path.join("requirements", "neuronx.txt") @@ -306,7 +330,7 @@ def install_neuronx_driver(self): pass -def install_dependencies(cuda_version=None, nightly=False): +def install_dependencies(cuda_version=None, rocm_version=None, nightly=False): os_map = {"Linux": Linux, "Windows": Windows, "Darwin": Darwin} system = os_map[platform.system()]() @@ -325,7 +349,12 @@ def install_dependencies(cuda_version=None, nightly=False): requirements_file = "common.txt" if args.environment == "prod" else "developer.txt" requirements_file_path = os.path.join("requirements", requirements_file) - system.install_python_packages(cuda_version, requirements_file_path, nightly) + system.install_python_packages( + cuda_version, + rocm_version, + requirements_file_path=requirements_file_path, + nightly=nightly, + ) if args.cpp: system.install_cpp_dependencies() @@ -363,6 +392,15 @@ def get_brew_version(): action="store_true", help="Install dependencies for inferentia2 support", ) + parser.add_argument( + "--rocm", + default=None, + choices=[ + "rocm60", + "rocm61", + ], + help="ROCm version for torch", + ) parser.add_argument( "--cpp", action="store_true", @@ -394,4 +432,6 @@ def get_brew_version(): ) args = parser.parse_args() - install_dependencies(cuda_version=args.cuda, nightly=args.nightly_torch) + install_dependencies( + cuda_version=args.cuda, rocm_version=args.rocm, nightly=args.nightly_torch + ) diff --git a/ts_scripts/install_utils b/ts_scripts/install_utils index 3dd01da6fe..cae4a2de5a 100755 --- a/ts_scripts/install_utils +++ b/ts_scripts/install_utils @@ -25,14 +25,16 @@ install_java_deps() set -e } -install_torch_deps() -{ - if is_gpu_instance && [ ! -z "$1" ]; +install_torch_deps() { + if [ ! -z "$1" ]; then - pip install -U -r requirements_$1.txt -f https://download.pytorch.org/whl/torch_stable.html - else - pip install -U -r requirements.txt - fi + if [[ "$1" == *"cu"* || "$1" == *"rocm"* ]] && ! is_gpu_instance; + then + echo "Cannot install GPU-specific requirements." + exit 1 + fi + pip install -U -r requirements/$1.txt + fi } install_pytest_suite_deps() @@ -275,20 +277,15 @@ clean_up_build_residuals() rm -rf ts/utils/__pycache__/ } -is_gpu_instance(){ - if command -v nvidia-smi; - then - nvidia-smi | grep 'NVIDIA-SMI has failed' - if [ $? == 0 ]; - then - return 1 - else +is_gpu_instance() { + for cmd in nvidia-smi amd-smi system_profiler xpu-smi; do + if command -v "$cmd" && "$cmd" > /dev/null 2>&1; then + echo "$cmd found and able to communicate with GPU(s)." return 0 fi - - else - return 1 - fi + done + echo "Cannot communicate with GPU(s)." + return 1 } run_markdown_link_checker(){ diff --git a/ts_scripts/print_env_info.py b/ts_scripts/print_env_info.py index 0e74a61661..430bb98e29 100644 --- a/ts_scripts/print_env_info.py +++ b/ts_scripts/print_env_info.py @@ -13,6 +13,7 @@ except (ImportError, NameError, AttributeError): TORCH_AVAILABLE = False + torchserve_env = { "torch": "**Warning: torch not present ..", "torch_model_archiver": "**Warning: torch-model-archiver not installed ..", @@ -38,7 +39,16 @@ "cuda_runtime_version": "N/A", "nvidia_gpu_models": [], "nvidia_driver_version": "N/A", - "cudnn_version": [], + "nvidia_driver_cuda_version": "N/A", + "cudnn_version": "N/A", +} + +hip_env = { + "is_hip_available": "No", + "hip_runtime_version": "N/A", + "amd_gpu_models": [], + "rocm_version": "N/A", + "miopen_version": "N/A", } npm_env = {"npm_pkg_version": []} @@ -46,14 +56,6 @@ cpp_env = {"LIBRARY_PATH": ""} -def get_nvidia_smi(): - # Note: nvidia-smi is currently available only on Windows and Linux - smi = "nvidia-smi" - if get_platform() == "win32": - smi = "nvidia-smi.exe" - return smi - - def run(command): """Returns (return-code, stdout, stderr)""" p = subprocess.Popen( @@ -197,65 +199,54 @@ def get_cmake_version(): return run_and_parse_first_match("cmake --version", r"cmake (.*)") -def get_nvidia_driver_version(): - smi = get_nvidia_smi() - if get_platform() == "darwin": - cmd = "kextstat | grep -i cuda" - return run_and_parse_first_match(cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]") +def get_gpu_info(): + num_of_gpus = torch.cuda.device_count() + gpu_types = [torch.cuda.get_device_name(gpu_index) for gpu_index in range(num_of_gpus)] + return "\n".join(["", *gpu_types]) - return run_and_parse_first_match(smi, r"Driver Version: (.*?) ") - -def get_nvidia_gpu_info(): - smi = get_nvidia_smi() - if get_platform() == "darwin": - if TORCH_AVAILABLE and torch.cuda.is_available(): - return torch.cuda.get_device_name(None) - return None - uuid_regex = re.compile(r" \(UUID: .+?\)") - rc, out, _ = run(smi + " -L") - if rc != 0: - return None - # Anonymize GPUs by removing their UUID - return "\n" + re.sub(uuid_regex, "", out) +def get_nvidia_driver_version(): + # Local import because ts_scripts/install_dependencies.py + # imports a function from this module at a stage when pynvml is not yet installed + import pynvml + pynvml.nvmlInit() + driver_version = pynvml.nvmlSystemGetDriverVersion() + pynvml.nvmlShutdown() + return driver_version + + +def get_nvidia_driver_cuda_version(): + # Local import because ts_scripts/install_dependencies.py + # imports a function from this module at a stage when pynvml is not yet installed + import pynvml + pynvml.nvmlInit() + cuda = pynvml.nvmlSystemGetCudaDriverVersion() + cuda_major = cuda // 1000 + cuda_minor = (cuda % 1000) // 10 + pynvml.nvmlShutdown() + return f"{cuda_major}.{cuda_minor}" def get_running_cuda_version(): - return run_and_parse_first_match("nvcc --version", r"V([\d.]+)") + cuda = torch._C._cuda_getCompiledVersion() + cuda_major = cuda // 1000 + cuda_minor = (cuda % 1000) // 10 + return f"{cuda_major}.{cuda_minor}" def get_cudnn_version(): - """This will return a list of libcudnn.so; it's hard to tell which one is being used""" - if get_platform() == "win32": - system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") - cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") - where_cmd = os.path.join(system_root, "System32", "where") - cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif get_platform() == "darwin": - # CUDA libraries and drivers can be found in /usr/local/cuda/. See - cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" - else: - cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' - rc, out, _ = run(cudnn_cmd) - # find will return 1 if there are permission errors or if not found - if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get("CUDNN_LIBRARY") - if l is not None and os.path.isfile(l): - return os.path.realpath(l) - return None - files = set() - for fn in out.split("\n"): - fn = os.path.realpath(fn) # eliminate symbolic links - if os.path.isfile(fn): - files.add(fn) - if not files: - return None - # Alphabetize the result because the order is non-deterministic otherwise - files = list(sorted(files)) - if len(files) == 1: - return files[0] - result = "\n".join(files) - return "Probably one of the following:\n{}".format(result) + cudnn = torch.backends.cudnn.version() + cudnn_major = cudnn // 10000 + cudnn = cudnn % 10000 + cudnn_minor = cudnn // 100 + cudnn_patch = cudnn % 100 + return f"{cudnn_major}.{cudnn_minor}.{cudnn_patch}" + + +def get_miopen_version(): + cfg = torch._C._show_config() + miopen = re.search("MIOpen \d+\.\d+\.\d+", cfg).group() + return miopen.split(" ")[1] def get_torchserve_version(): @@ -341,11 +332,20 @@ def populate_os_env(): def populate_cuda_env(cuda_available_str): cuda_env["is_cuda_available"] = cuda_available_str cuda_env["cuda_runtime_version"] = get_running_cuda_version() - cuda_env["nvidia_gpu_models"] = get_nvidia_gpu_info() + cuda_env["nvidia_gpu_models"] = get_gpu_info() cuda_env["nvidia_driver_version"] = get_nvidia_driver_version() + cuda_env["nvidia_driver_cuda_version"] = get_nvidia_driver_cuda_version() cuda_env["cudnn_version"] = get_cudnn_version() +def populate_hip_env(hip_available_str): + hip_env["is_hip_available"] = hip_available_str + hip_env["hip_runtime_version"] = torch.version.hip + hip_env["amd_gpu_models"] = get_gpu_info() + hip_env["rocm_version"] = run_and_parse_first_match("amd-smi version", r"ROCm version: ([\d.]+)") + hip_env["miopen_version"] = get_miopen_version() + + def populate_npm_env(): npm_env["npm_pkg_version"] = get_npm_packages() @@ -371,8 +371,10 @@ def populate_env_info(): populate_os_env() # cuda environment - if TORCH_AVAILABLE and torch.cuda.is_available(): + if TORCH_AVAILABLE and torch.cuda.is_available() and torch.version.cuda: populate_cuda_env("Yes") + elif TORCH_AVAILABLE and torch.cuda.is_available() and torch.version.hip: + populate_hip_env("Yes") if get_platform() == "darwin": populate_npm_env() @@ -412,11 +414,20 @@ def populate_env_info(): cuda_info_fmt = """ Is CUDA available: {is_cuda_available} CUDA runtime version: {cuda_runtime_version} -GPU models and configuration: {nvidia_gpu_models} +NVIDIA GPU models and configuration: {nvidia_gpu_models} Nvidia driver version: {nvidia_driver_version} +Nvidia driver cuda version: {nvidia_driver_cuda_version} cuDNN version: {cudnn_version} """ +hip_info_fmt = """ +Is HIP available: {is_hip_available} +HIP runtime version: {hip_runtime_version} +AMD GPU models and configuration: {amd_gpu_models} +ROCm version: {rocm_version} +MIOpen version: {miopen_version} +""" + npm_info_fmt = """ Versions of npm installed packages: {npm_pkg_version} @@ -431,6 +442,7 @@ def populate_env_info(): def get_pretty_env_info(branch_name): global env_info_fmt global cuda_info_fmt + global hip_info_fmt global npm_info_fmt global cpp_env_info_fmt populate_env_info() @@ -443,9 +455,12 @@ def get_pretty_env_info(branch_name): **cpp_env, } - if TORCH_AVAILABLE and torch.cuda.is_available(): + if TORCH_AVAILABLE and torch.cuda.is_available() and torch.version.cuda: env_dict.update(cuda_env) env_info_fmt = env_info_fmt + "\n" + cuda_info_fmt + elif TORCH_AVAILABLE and torch.cuda.is_available() and torch.version.hip: + env_dict.update(hip_env) + env_info_fmt = env_info_fmt + "\n" + hip_info_fmt if get_platform() == "darwin": env_dict.update(npm_env) diff --git a/ts_scripts/sanity_utils.py b/ts_scripts/sanity_utils.py index f6b5126213..83dd27c806 100755 --- a/ts_scripts/sanity_utils.py +++ b/ts_scripts/sanity_utils.py @@ -8,9 +8,6 @@ import torch -from ts_scripts import marsgen as mg -from ts_scripts import tsutils as ts -from ts_scripts import utils REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") sys.path.append(REPO_ROOT) @@ -19,6 +16,11 @@ ) +from ts_scripts import marsgen as mg +from ts_scripts import tsutils as ts +from ts_scripts import utils + + async def markdown_link_checker(in_queue, out_queue, n): print(f"worker started {n}") while True: @@ -75,15 +77,23 @@ def run_markdown_link_checker(): def validate_model_on_gpu(): # A quick \ crude way of checking if model is loaded in GPU # Assumption is - - # 1. GPUs on test setup are only utlizied by torchserve + # 1. GPUs on test setup are only utilized by torchserve # 2. Models are successfully UNregistered between subsequent calls - import nvgpu - model_loaded = False - for info in nvgpu.gpu_info(): - if info["mem_used"] > 0 and info["mem_used_percent"] > 0.0: + + if torch.cuda.is_available() and not (torch.version.cuda or torch.version.hip): + return model_loaded + + num_of_gpus = torch.cuda.device_count() + + for gpu_index in range(num_of_gpus): + free, total = torch.cuda.mem_get_info(gpu_index) + mem_used = (total - free) // 1024**2 + mem_used_percent = 100.0 * (1 - (free / total)) + + if mem_used > 0 and mem_used_percent > 0.0: model_loaded = True - break + return model_loaded @@ -218,7 +228,6 @@ def test_sanity(): def test_workflow_sanity(): - current_path = os.getcwd() ts_log_file = os.path.join("logs", "ts_console.log") os.makedirs("model_store", exist_ok=True) os.makedirs("logs", exist_ok=True) @@ -237,7 +246,7 @@ def test_workflow_sanity(): if response and response.status_code == 200: print(response.text) else: - print(f"## Failed to register workflow") + print("## Failed to register workflow") sys.exit(1) # Run prediction on workflow @@ -254,7 +263,7 @@ def test_workflow_sanity(): if response and response.status_code == 200: print(response.text) else: - print(f"## Failed to unregister workflow") + print("## Failed to unregister workflow") sys.exit(1) stopped = ts.stop_torchserve() diff --git a/ts_scripts/utils.py b/ts_scripts/utils.py index de755b8d2d..c1228c4499 100644 --- a/ts_scripts/utils.py +++ b/ts_scripts/utils.py @@ -3,18 +3,29 @@ import subprocess import sys + REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") sys.path.append(REPO_ROOT) + nvidia_smi_cmd = { "Windows": "nvidia-smi.exe", "Darwin": "nvidia-smi", "Linux": "nvidia-smi", } +amd_smi_cmd = { + "Linux": "amd-smi", +} + def is_gpu_instance(): - return True if os.system(nvidia_smi_cmd[platform.system()]) == 0 else False + return ( + True + if os.system(nvidia_smi_cmd[platform.system()]) == 0 + or os.system(amd_smi_cmd[platform.system()]) == 0 + else False + ) def is_conda_build_env(): diff --git a/ts_scripts/validate_model_on_gpu.py b/ts_scripts/validate_model_on_gpu.py index 733fda16db..3ee1014f71 100644 --- a/ts_scripts/validate_model_on_gpu.py +++ b/ts_scripts/validate_model_on_gpu.py @@ -1,13 +1,6 @@ -import nvgpu +from sanity_utils import validate_model_on_gpu -gpu_info = nvgpu.gpu_info() - -model_loaded = False - -for info in gpu_info: - if info['mem_used'] > 0 and info['mem_used_percent'] > 0.0: - model_loaded = True - break +model_loaded = validate_model_on_gpu() if not model_loaded: exit(1) From ce19723bbdc2b44237f0e48d429192d8ed545bd4 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 11 Nov 2024 14:21:38 +0100 Subject: [PATCH 02/33] Update README.md with rocm flags --- README.md | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a74b952708..097d975280 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ TorchServe now enforces token authorization enabled and model API control disabl TorchServe is a flexible and easy-to-use tool for serving and scaling PyTorch models in production. -Requires python >= 3.8 +Requires: +- python >= 3.8 +- Java >= 17 ```bash curl http://127.0.0.1:8080/predictions/bert -T input.txt @@ -22,7 +24,10 @@ curl http://127.0.0.1:8080/predictions/bert -T input.txt ```bash # Install dependencies -# cuda is optional +python ./ts_scripts/install_dependencies.py + +# Include depeendencies for accelerator support with the relevant optional flags +python ./ts_scripts/install_dependencies.py --rocm=rocm61 python ./ts_scripts/install_dependencies.py --cuda=cu121 # Latest release @@ -36,7 +41,10 @@ pip install torchserve-nightly torch-model-archiver-nightly torch-workflow-archi ```bash # Install dependencies -# cuda is optional +python ./ts_scripts/install_dependencies.py + +# Include depeendencies for accelerator support with the relevant optional flags +python ./ts_scripts/install_dependencies.py --rocm=rocm61 python ./ts_scripts/install_dependencies.py --cuda=cu121 # Latest release From 0fad8e269e0af89311bc2547c1a27e7089cd9190 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 11 Nov 2024 14:44:57 +0100 Subject: [PATCH 03/33] add rocm to CONTRIBUTING.md --- CONTRIBUTING.md | 57 ++++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a25e754761..eaf2305ad8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,18 +11,7 @@ Your contributions will fall into two categories: - Search for your issue here: https://github.com/pytorch/serve/issues (look for the "good first issue" tag if you're a first time contributor) - Pick an issue and comment on the task that you want to work on this feature. - To ensure your changes doesn't break any of the existing features run the sanity suite as follows from serve directory: - - Install dependencies (if not already installed) - For CPU - - ```bash - python ts_scripts/install_dependencies.py --environment=dev - ``` - - For GPU - ```bash - python ts_scripts/install_dependencies.py --environment=dev --cuda=cu121 - ``` - > Supported cuda versions as cu121, cu118, cu117, cu116, cu113, cu111, cu102, cu101, cu92 + - [Install dependencies](#Install-TorchServe-for-development) (if not already installed) - Install `pre-commit` to your Git flow: ```bash pre-commit install @@ -60,26 +49,30 @@ pytest -k test/pytest/test_mnist_template.py If you plan to develop with TorchServe and change some source code, you must install it from source code. -Ensure that you have `python3` installed, and the user has access to the site-packages or `~/.local/bin` is added to the `PATH` environment variable. - -Run the following script from the top of the source directory. - -NOTE: This script force re-installs `torchserve`, `torch-model-archiver` and `torch-workflow-archiver` if existing installations are found - -#### For Debian Based Systems/ MacOS - -``` -python ./ts_scripts/install_dependencies.py --environment=dev -python ./ts_scripts/install_from_src.py --environment=dev -``` - -Use `--cuda` flag with `install_dependencies.py` for installing cuda version specific dependencies. Possible values are `cu111`, `cu102`, `cu101`, `cu92` - -#### For Windows - -Refer to the documentation [here](docs/torchserve_on_win_native.md). - -For information about the model archiver, see [detailed documentation](model-archiver/README.md). +1. Clone the repository, including third-party modules, with `git clone --recurse-submodules --remote-submodules git@github.com:pytorch/serve.git` +2. Ensure that you have `python3` installed, and the user has access to the site-packages or `~/.local/bin` is added to the `PATH` environment variable. +3. Run the following script from the top of the source directory. NOTE: This script force re-installs `torchserve`, `torch-model-archiver` and `torch-workflow-archiver` if existing installations are found + + #### For Debian Based Systems/MacOS + + ``` + python ./ts_scripts/install_dependencies.py --environment=dev + python ./ts_scripts/install_from_src.py --environment=dev + ``` + ##### Installing Dependencies for Accelerator Support + Use the optional `--rocm` or `--cuda` flag with `install_dependencies.py` for installing accelerator specific dependencies. + + Possible values are + - rocm: `rocm61`, `rocm60` + - cuda: `cu111`, `cu102`, `cu101`, `cu92` + + For example `python ./ts_scripts/install_dependencies.py --environment=dev --rocm=rocm61` + + #### For Windows + + Refer to the documentation [here](docs/torchserve_on_win_native.md). + + For information about the model archiver, see [detailed documentation](model-archiver/README.md). ### What to Contribute? From 3247498f050ad98079f3a7cdfb41adf3cb329426 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Tue, 12 Nov 2024 12:04:07 +0100 Subject: [PATCH 04/33] WorkerLifeCycle uses SystemInfo to get X_VISIBLE_DEVICES --- .../src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 177a515f32..0a6b95b294 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -136,9 +136,8 @@ private void startWorkerPython(int port, String deviceIds) attachRunner(argl, envp, port, deviceIds); } else { if (deviceIds != null) { - AcceleratorVendor visibleDeviceEnvName = - configManager.systemInfo.getAcceleratorVendor(); - envp.add("CUDA_VISIBLE_DEVICES=" + deviceIds); + String visibleDeviceEnvName = configManager.systemInfo.getVisibleDevicesEnvName(); + envp.add(visibleDeviceEnvName + "=" + deviceIds); } argl.add(EnvironmentUtils.getPythonRunTime(model)); } From bae9b2cb62a827841f8f30e19143de0a9e0dfdf9 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Tue, 12 Nov 2024 13:01:01 +0100 Subject: [PATCH 05/33] AppleUtil adds Accelerator `number_of_cores` times --- .../pytorch/serve/device/utils/AppleUtil.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java b/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java index ae87e85255..921763adca 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java +++ b/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java @@ -68,13 +68,19 @@ public Integer extractAcceleratorId(JsonObject cardObject) { @Override public List extractAccelerators(JsonElement rootObject) { List accelerators = new ArrayList<>(); - JsonArray displaysArray = - rootObject - .getAsJsonObject() // Gets the outer object - .get("SPDisplaysDataType") // Gets the "SPDisplaysDataType" element - .getAsJsonArray(); + JsonArray displaysArray = rootObject + .getAsJsonObject() // Gets the outer object + .get("SPDisplaysDataType") // Gets the "SPDisplaysDataType" element + .getAsJsonArray(); JsonObject gpuObject = displaysArray.get(0).getAsJsonObject(); - accelerators.add(gpuObject); + int number_of_cores = Integer.parseInt(gpuObject.get("sppci_cores").getAsString()); + + // add the object `number_of_cores` times to maintain the exsisitng + // functionality + accelerators = IntStream.range(0, number_of_cores) + .mapToObj(i -> gpuObject) + .collect(Collectors.toList()); + return accelerators; } From 88f3cb833a1214f2faa0961bde6d2ca33d27c16f Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Wed, 13 Nov 2024 15:46:14 +0100 Subject: [PATCH 06/33] fix typo in README.md Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 097d975280..143cf3d035 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ curl http://127.0.0.1:8080/predictions/bert -T input.txt # Install dependencies python ./ts_scripts/install_dependencies.py -# Include depeendencies for accelerator support with the relevant optional flags +# Include dependencies for accelerator support with the relevant optional flags python ./ts_scripts/install_dependencies.py --rocm=rocm61 python ./ts_scripts/install_dependencies.py --cuda=cu121 From 8e4d24cc63a640fff86d8a775dd9ed530c57d3d6 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Wed, 13 Nov 2024 15:54:52 +0100 Subject: [PATCH 07/33] remove mention of java version from README.md --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index 143cf3d035..0b0e9992a7 100644 --- a/README.md +++ b/README.md @@ -13,9 +13,7 @@ TorchServe now enforces token authorization enabled and model API control disabl TorchServe is a flexible and easy-to-use tool for serving and scaling PyTorch models in production. -Requires: -- python >= 3.8 -- Java >= 17 +Requires python >= 3.8 ```bash curl http://127.0.0.1:8080/predictions/bert -T input.txt From ff4daa8bd0022d3d35122988b624bf05a8191780 Mon Sep 17 00:00:00 2001 From: Samu Tamminen Date: Thu, 14 Nov 2024 08:50:29 +0100 Subject: [PATCH 08/33] revert unnecessary changes --- docker/Dockerfile.dev | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index bea2787bcc..bb5578a650 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -10,7 +10,7 @@ # For reference: # https://docs.docker.com/develop/develop-images/build_enhancements/ -ARG BASE_IMAGE=ubuntu:24.04 +ARG BASE_IMAGE=ubuntu:rolling ARG BUILD_TYPE=dev FROM ${BASE_IMAGE} AS compile-image @@ -110,6 +110,7 @@ RUN python -m pip install -U pip setuptools \ && chown -R model-server /home/venv EXPOSE 8080 8081 8082 7070 7071 +USER model-server WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] @@ -138,5 +139,4 @@ RUN set -ex \ FROM ${BUILD_TYPE}-image AS final-image ARG BUILD_TYPE -ENV CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 RUN echo "${BUILD_TYPE} image creation completed" From 0bc3e3c2af5586ce60306adf550414abbd3f93e9 Mon Sep 17 00:00:00 2001 From: jakki Date: Thu, 14 Nov 2024 16:20:15 +0200 Subject: [PATCH 09/33] Fix import errors in AppleUtils --- .../src/main/java/org/pytorch/serve/device/utils/AppleUtil.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java b/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java index 921763adca..ef3bc2485d 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java +++ b/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java @@ -7,6 +7,8 @@ import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.pytorch.serve.device.Accelerator; import org.pytorch.serve.device.AcceleratorVendor; import org.pytorch.serve.device.interfaces.IAcceleratorUtility; From 1e635e13192ba2e00c2d26c07167b46db5c32708 Mon Sep 17 00:00:00 2001 From: Samu Tamminen Date: Thu, 14 Nov 2024 15:44:26 +0100 Subject: [PATCH 10/33] remove rocm support from dockerfile.dev to simplify --- docker/Dockerfile.dev | 35 ++++------------------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index bb5578a650..01b5826de8 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -19,7 +19,6 @@ ARG BRANCH_NAME=master ARG REPO_URL=https://github.com/pytorch/serve.git ARG MACHINE_TYPE=cpu ARG CUDA_VERSION -ARG ROCM_VERSION ARG BUILD_WITH_IPEX ARG IPEX_VERSION=1.11.0 @@ -42,7 +41,7 @@ RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ git \ python$PYTHON_VERSION \ python$PYTHON_VERSION-dev \ - python3-setuptools \ + python3-distutils \ python$PYTHON_VERSION-venv \ python3-venv \ build-essential \ @@ -50,8 +49,6 @@ RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ curl \ vim \ numactl \ - zip \ - wget \ && if [ "$BUILD_WITH_IPEX" = "true" ]; then apt-get update && apt-get install -y libjemalloc-dev libgoogle-perftools-dev libomp-dev && ln -s /usr/lib/x86_64-linux-gnu/libjemalloc.so /usr/lib/libjemalloc.so && ln -s /usr/lib/x86_64-linux-gnu/libtcmalloc.so /usr/lib/libtcmalloc.so && ln -s /usr/lib/x86_64-linux-gnu/libiomp5.so /usr/lib/libiomp5.so; fi \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ @@ -61,43 +58,19 @@ RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ RUN update-alternatives --install /usr/bin/python python /usr/bin/python$PYTHON_VERSION 1 \ && update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1 -RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ - if [ -n "$ROCM_VERSION" ]; then \ - apt-get update \ - && wget https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ - && DEBIAN_FRONTEND=noninteractive sudo apt-get install -y ./amdgpu-install_6.2.60202-1_all.deb \ - && sudo apt-get update \ - && sudo apt-get install --no-install-recommends -y amdgpu-dkms rocm \ - && cd /home/; \ - else \ - echo "Skip ROCm installation"; \ - fi - # Build Dev Image FROM compile-image AS dev-image ARG MACHINE_TYPE=cpu ARG CUDA_VERSION -RUN if [ "$MACHINE_TYPE" = "nvidia_gpu" ]; then export USE_CUDA=1; fi \ +RUN if [ "$MACHINE_TYPE" = "gpu" ]; then export USE_CUDA=1; fi \ && git clone $REPO_URL \ && cd serve \ && git checkout ${BRANCH_NAME} \ && python$PYTHON_VERSION -m venv /home/venv ENV PATH="/home/venv/bin:$PATH" WORKDIR serve - -COPY . . - RUN python -m pip install -U pip setuptools \ - && if ([ -z "$CUDA_VERSION" ] && [ -z "$ROCM_VERSION" ]); then \ - python ts_scripts/install_dependencies.py --environment=dev; \ - elif [ -n "$ROCM_VERSION" ]; then \ - python ts_scripts/install_dependencies.py --environment=dev --rocm $ROCM_VERSION \ - && cd /opt/rocm/share/amd_smi \ - && pip install . \ - && cd /serve/; \ - else \ - python ts_scripts/install_dependencies.py --environment=dev --cuda $CUDA_VERSION; \ - fi \ + && if [ -z "$CUDA_VERSION" ]; then python ts_scripts/install_dependencies.py --environment=dev; else python ts_scripts/install_dependencies.py --environment=dev --cuda $CUDA_VERSION; fi \ && if [ "$BUILD_WITH_IPEX" = "true" ]; then python -m pip install --no-cache-dir intel_extension_for_pytorch==${IPEX_VERSION} -f ${IPEX_URL}; fi \ && python ts_scripts/install_from_src.py \ && useradd -m model-server \ @@ -139,4 +112,4 @@ RUN set -ex \ FROM ${BUILD_TYPE}-image AS final-image ARG BUILD_TYPE -RUN echo "${BUILD_TYPE} image creation completed" +RUN echo "${BUILD_TYPE} image creation completed" \ No newline at end of file From 16478266680c2431f667291f0e366bae0611cc00 Mon Sep 17 00:00:00 2001 From: Samu Tamminen Date: Thu, 14 Nov 2024 15:46:23 +0100 Subject: [PATCH 11/33] fix missing newline --- docker/Dockerfile.dev | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 01b5826de8..2f02d84680 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -112,4 +112,4 @@ RUN set -ex \ FROM ${BUILD_TYPE}-image AS final-image ARG BUILD_TYPE -RUN echo "${BUILD_TYPE} image creation completed" \ No newline at end of file +RUN echo "${BUILD_TYPE} image creation completed" From 0dc51450718269e27e12ffcdf009849414a8eb4e Mon Sep 17 00:00:00 2001 From: Samu Tamminen Date: Thu, 14 Nov 2024 08:50:29 +0100 Subject: [PATCH 12/33] revert unnecessary changes --- docs/contents.rst | 7 ++- docs/hardware_support/amd_support.md | 48 +++++++++++++++++++ .../apple_silicon_support.md | 34 ++++++------- docs/hardware_support/hardware_support.rst | 7 +++ docs/{ => hardware_support}/linux_aarch64.md | 0 docs/{ => hardware_support}/nvidia_mps.md | 0 6 files changed, 78 insertions(+), 18 deletions(-) create mode 100644 docs/hardware_support/amd_support.md rename docs/{ => hardware_support}/apple_silicon_support.md (90%) create mode 100644 docs/hardware_support/hardware_support.rst rename docs/{ => hardware_support}/linux_aarch64.md (100%) rename docs/{ => hardware_support}/nvidia_mps.md (100%) diff --git a/docs/contents.rst b/docs/contents.rst index 1ba7e83e32..9113e770c2 100644 --- a/docs/contents.rst +++ b/docs/contents.rst @@ -16,7 +16,6 @@ model_zoo request_envelopes server - nvidia_mps snapshot intel_extension_for_pytorch torchserve_on_win_native @@ -27,6 +26,12 @@ Security FAQs +.. toctree:: + :maxdepth: 0 + :caption: Hardware Support: + + hardware_support/hardware_support + .. toctree:: :maxdepth: 0 :caption: Service APIs: diff --git a/docs/hardware_support/amd_support.md b/docs/hardware_support/amd_support.md new file mode 100644 index 0000000000..f47b8a78f7 --- /dev/null +++ b/docs/hardware_support/amd_support.md @@ -0,0 +1,48 @@ +# AMD Support + +TorchServe can be run on any combination of operating system and device that is +[supported by ROCm](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/compatibility.html). + +## Supported Versions of ROCm + +The current stable `major.patch` version of ROCm and the previous path version will be supported. For example version `N.2` and `N.1` where `N` is the current major version. + +## Installation + + 1. Make sure you have **python >= 3.8 installed** on your system. + 2. clone the repo + `git clone git@github.com:pytorch/serve.git` + 3. cd into the cloned folder + `cd serve` + 4. create a virtual environment for python + `python -m venv venv` + 5. activate the virtual environment. If you use another shell (fish, csh, powershell) use the relevant option in from `/venv/bin/` + `source venv/bin/activate` + 6. install the dependencies needed for ROCm support. + `python ./ts_scripts/install_dependencies.py --rocm=rocm61` + `python ./ts_scripts/install_from_src.py` + 7. enable amd-smi in the python virtual environment + `sudo chown -R $USER:$USER /opt/rocm/share/amd_smi/` + `pip install -e /opt/rocm/share/amd_smi/` + +### Selecting Accelerators Using `HIP_VISIBLE_DEVICES` + +If you have multiple accelerators on the system where you are running TorchServe yuo can select which accelerators should be visible to TorchServe +by setting the environment variable `HIP_VISIBLE_DEVICES` to a string of 0-indexed comma-separated integers representing the ids of the accelerators. + +If you have 8 accelerators but only want TorchServe to see the last four of them do `export HIP_VISIBLE_DEVICES=4,5,6,7`. + +>ℹ️ Not setting `HIP_VISIBLE_DEVICES` will cause TorchServe to use all available accelerators on the system it is running on. + +> ⚠️ You can run into trouble if you set `HIP_VISIBLE_DEVICES` to an empty string. +> eg. `export HIP_VISIBLE_DEVICES=` or `export HIP_VISIBLE_DEVICES=""` +> use `unset HIP_VISIBLE_DEVICES` if you want to remove it's effect. + +> ⚠️ Setting both `CUDA_VISIBLE_DEVICES` and `HIP_VISIBLE_DEVICES` may cause a unintended behaviour and should be avoided. +> Doing so may cause an exception in the future. + +## Example Usage + +After installing TorchServe with re required dependencies for ROCm you should be ready to serve your model. + +For a simple example, refer to `serve/examples/image_classifier/mnist/`. diff --git a/docs/apple_silicon_support.md b/docs/hardware_support/apple_silicon_support.md similarity index 90% rename from docs/apple_silicon_support.md rename to docs/hardware_support/apple_silicon_support.md index facd8a7f28..6e0f479b8a 100644 --- a/docs/apple_silicon_support.md +++ b/docs/hardware_support/apple_silicon_support.md @@ -1,19 +1,19 @@ -# Apple Silicon Support +# Apple Silicon Support -## What is supported +## What is supported * TorchServe CI jobs now include M1 hardware in order to ensure support, [documentation](https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories) on github M1 hardware. - - [Regression Tests](https://github.com/pytorch/serve/blob/master/.github/workflows/regression_tests_cpu.yml) - - [Regression binaries Test](https://github.com/pytorch/serve/blob/master/.github/workflows/regression_tests_cpu_binaries.yml) + - [Regression Tests](https://github.com/pytorch/serve/blob/master/.github/workflows/regression_tests_cpu.yml) + - [Regression binaries Test](https://github.com/pytorch/serve/blob/master/.github/workflows/regression_tests_cpu_binaries.yml) * For [Docker](https://docs.docker.com/desktop/install/mac-install/) ensure Docker for Apple silicon is installed then follow [setup steps](https://github.com/pytorch/serve/tree/master/docker) ## Experimental Support -* For GPU jobs on Apple Silicon, [MPS](https://pytorch.org/docs/master/notes/mps.html) is now auto detected and enabled. To prevent TorchServe from using MPS, users have to set `deviceType: "cpu"` in model-config.yaml. - * This is an experimental feature and NOT ALL models are guaranteed to work. +* For GPU jobs on Apple Silicon, [MPS](https://pytorch.org/docs/master/notes/mps.html) is now auto detected and enabled. To prevent TorchServe from using MPS, users have to set `deviceType: "cpu"` in model-config.yaml. + * This is an experimental feature and NOT ALL models are guaranteed to work. * Number of GPUs now reports GPUs on Apple Silicon -### Testing -* [Pytests](https://github.com/pytorch/serve/tree/master/test/pytest/test_device_config.py) that checks for MPS on MacOS M1 devices +### Testing +* [Pytests](https://github.com/pytorch/serve/tree/master/test/pytest/test_device_config.py) that checks for MPS on MacOS M1 devices * Models that have been tested and work: Resnet-18, Densenet161, Alexnet * Models that have been tested and DO NOT work: MNIST @@ -31,10 +31,10 @@ Config file: N/A Inference address: http://127.0.0.1:8080 Management address: http://127.0.0.1:8081 Metrics address: http://127.0.0.1:8082 -Model Store: +Model Store: Initial Models: resnet-18=resnet-18.mar -Log dir: -Metrics dir: +Log dir: +Metrics dir: Netty threads: 0 Netty client threads: 0 Default workers per model: 16 @@ -48,7 +48,7 @@ Custom python dependency for model allowed: false Enable metrics API: true Metrics mode: LOG Disable system metrics: false -Workflow Store: +Workflow Store: CPP log config: N/A Model config: N/A 024-04-08T14:18:02,380 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Loading snapshot serializer plugin... @@ -69,17 +69,17 @@ serve % curl http://127.0.0.1:8080/predictions/resnet-18 -T ./examples/image_cla } ... ``` -#### Conda Example +#### Conda Example ``` -(myenv) serve % pip list | grep torch +(myenv) serve % pip list | grep torch torch 2.2.1 torchaudio 2.2.1 torchdata 0.7.1 torchtext 0.17.1 torchvision 0.17.1 (myenv3) serve % conda install -c pytorch-nightly torchserve torch-model-archiver torch-workflow-archiver -(myenv3) serve % pip list | grep torch +(myenv3) serve % pip list | grep torch torch 2.2.1 torch-model-archiver 0.10.0b20240312 torch-workflow-archiver 0.2.12b20240312 @@ -119,11 +119,11 @@ System metrics command: default 2024-03-12T15:58:54,702 [DEBUG] main org.pytorch.serve.wlm.ModelManager - updateModel: densenet161, count: 10 Model server started. ... -(myenv3) serve % curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg +(myenv3) serve % curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg { "tabby": 0.46661922335624695, "tiger_cat": 0.46449029445648193, "Egyptian_cat": 0.0661405548453331, "lynx": 0.001292439759708941, "plastic_bag": 0.00022909720428287983 -} \ No newline at end of file +} diff --git a/docs/hardware_support/hardware_support.rst b/docs/hardware_support/hardware_support.rst new file mode 100644 index 0000000000..75c201c2de --- /dev/null +++ b/docs/hardware_support/hardware_support.rst @@ -0,0 +1,7 @@ +.. toctree:: + :caption: Hardware Support: + + amd_support + apple_silicon_support + linux_aarch64 + nvidia_mps diff --git a/docs/linux_aarch64.md b/docs/hardware_support/linux_aarch64.md similarity index 100% rename from docs/linux_aarch64.md rename to docs/hardware_support/linux_aarch64.md diff --git a/docs/nvidia_mps.md b/docs/hardware_support/nvidia_mps.md similarity index 100% rename from docs/nvidia_mps.md rename to docs/hardware_support/nvidia_mps.md From f905d0ee591cbe23cffff9ed2615d04bfc514afb Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Thu, 14 Nov 2024 17:44:45 +0100 Subject: [PATCH 13/33] 'improve formatting for amd_support.md' --- docs/hardware_support/amd_support.md | 51 +++++++++++++++++++--------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/docs/hardware_support/amd_support.md b/docs/hardware_support/amd_support.md index f47b8a78f7..cf459b554d 100644 --- a/docs/hardware_support/amd_support.md +++ b/docs/hardware_support/amd_support.md @@ -9,21 +9,40 @@ The current stable `major.patch` version of ROCm and the previous path version w ## Installation - 1. Make sure you have **python >= 3.8 installed** on your system. - 2. clone the repo - `git clone git@github.com:pytorch/serve.git` - 3. cd into the cloned folder - `cd serve` - 4. create a virtual environment for python - `python -m venv venv` - 5. activate the virtual environment. If you use another shell (fish, csh, powershell) use the relevant option in from `/venv/bin/` - `source venv/bin/activate` - 6. install the dependencies needed for ROCm support. - `python ./ts_scripts/install_dependencies.py --rocm=rocm61` - `python ./ts_scripts/install_from_src.py` - 7. enable amd-smi in the python virtual environment - `sudo chown -R $USER:$USER /opt/rocm/share/amd_smi/` - `pip install -e /opt/rocm/share/amd_smi/` + - Make sure you have **python >= 3.8 installed** on your system. + - clone the repo + ```bash + git clone git@github.com:pytorch/serve.git + ``` + + - cd into the cloned folder + + ```bash + cd serve + ``` + + - create a virtual environment for python + + ```bash + python -m venv venv + ``` + + - activate the virtual environment. If you use another shell (fish, csh, powershell) use the relevant option in from `/venv/bin/` + ```bash + source venv/bin/activate + ``` + + - install the dependencies needed for ROCm support. + + ```bash + python ./ts_scripts/install_dependencies.py --rocm=rocm61 + python ./ts_scripts/install_from_src.py + ``` + - enable amd-smi in the python virtual environment + ```bash + sudo chown -R $USER:$USER /opt/rocm/share/amd_smi/ + pip install -e /opt/rocm/share/amd_smi/ + ``` ### Selecting Accelerators Using `HIP_VISIBLE_DEVICES` @@ -32,7 +51,7 @@ by setting the environment variable `HIP_VISIBLE_DEVICES` to a string of 0-index If you have 8 accelerators but only want TorchServe to see the last four of them do `export HIP_VISIBLE_DEVICES=4,5,6,7`. ->ℹ️ Not setting `HIP_VISIBLE_DEVICES` will cause TorchServe to use all available accelerators on the system it is running on. +>ℹ️ **Not setting** `HIP_VISIBLE_DEVICES` will cause TorchServe to use all available accelerators on the system it is running on. > ⚠️ You can run into trouble if you set `HIP_VISIBLE_DEVICES` to an empty string. > eg. `export HIP_VISIBLE_DEVICES=` or `export HIP_VISIBLE_DEVICES=""` From 9a515b8a8c3309a64226f90fc36139dbb7f1b0d8 Mon Sep 17 00:00:00 2001 From: jakki Date: Mon, 18 Nov 2024 16:37:49 +0200 Subject: [PATCH 14/33] Fix AppleUtils tests --- .../java/org/pytorch/serve/device/utils/AppleUtilTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/server/src/test/java/org/pytorch/serve/device/utils/AppleUtilTest.java b/frontend/server/src/test/java/org/pytorch/serve/device/utils/AppleUtilTest.java index c52e105fc4..e333f7ec83 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/device/utils/AppleUtilTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/device/utils/AppleUtilTest.java @@ -76,7 +76,7 @@ public void testExtractAcceleratorId() { public void testExtractAccelerators() { List accelerators = appleUtil.extractAccelerators(sampleOutputJson); - assertEquals(accelerators.size(), 1); + assertEquals(accelerators.size(), 7); assertEquals(accelerators.get(0).get("sppci_model").getAsString(), "Apple M1"); } @@ -88,7 +88,7 @@ public void testSmiOutputToUpdatedAccelerators() { ArrayList updatedAccelerators = appleUtil.smiOutputToUpdatedAccelerators(sampleOutputJson.toString(), parsedGpuIds); - assertEquals(updatedAccelerators.size(), 1); + assertEquals(updatedAccelerators.size(), 7); Accelerator accelerator = updatedAccelerators.get(0); assertEquals(accelerator.getAcceleratorModel(), "Apple M1"); assertEquals(accelerator.getVendor(), AcceleratorVendor.APPLE); @@ -112,7 +112,7 @@ public String[] getUtilizationSmiCommand() { ArrayList availableAccelerators = spyAppleUtil.getAvailableAccelerators(availableAcceleratorIds); - assertEquals(availableAccelerators.size(), 1); + assertEquals(availableAccelerators.size(), 7); Accelerator accelerator = availableAccelerators.get(0); assertEquals(accelerator.getAcceleratorModel(), "Apple M1"); assertEquals(accelerator.getVendor(), AcceleratorVendor.APPLE); From 9d3015930d2855035a2646020770ed18f358416f Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Wed, 20 Nov 2024 15:02:07 +0100 Subject: [PATCH 15/33] fixes 11. parse-metrics-failed-collecting-amd-gpu-metrics (#24) --- ts/metrics/system_metrics.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/ts/metrics/system_metrics.py b/ts/metrics/system_metrics.py index 5dda2476e6..3a04b949bf 100644 --- a/ts/metrics/system_metrics.py +++ b/ts/metrics/system_metrics.py @@ -57,6 +57,8 @@ def collect_gpu_metrics(num_of_gpus): :param num_of_gpus: Total number of available GPUs. :return: """ + if num_of_gpus <= 0: + return for gpu_index in range(num_of_gpus): if torch.version.cuda: free, total = torch.cuda.mem_get_info(gpu_index) @@ -105,27 +107,6 @@ def collect_gpu_metrics(num_of_gpus): ) -def gpu_utilization(num_of_gpus): - """ - Generic GPU utilization function that supports NVIDIA and AMD GPUs. - :param num_of_gpu: Total number of available GPUs. - :return: - """ - if num_of_gpus <= 0: - return - - if torch.cuda.is_available() and not (torch.version.cuda or torch.version.hip): - logging.error("No supported GPU detected.") - return - - if torch.cuda.is_available() and torch.version.cuda: - logging.info("Collecting NVIDIA GPU metrics...") - elif torch.cuda.is_available() and torch.version.hip: - logging.info("Collecting AMD GPU metrics...") - - collect_gpu_metrics(num_of_gpus) - - def collect_all(mod, num_of_gpus): """ Collect all system metrics. @@ -141,7 +122,7 @@ def collect_all(mod, num_of_gpus): "collect_all", "collect_gpu_metrics", ): - if value.__name__ == "gpu_utilization": + if value.__name__ == "collect_gpu_metrics": value(num_of_gpus) else: value() From 8cdf54b7f5e66dd16b84f6c3d70b63914864fb60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?= Date: Wed, 20 Nov 2024 19:10:35 +0000 Subject: [PATCH 16/33] extend testMetricManager --- .../org/pytorch/serve/ModelServerTest.java | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java index 57f7e40679..edc21f6bbd 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java @@ -26,7 +26,9 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; @@ -1340,6 +1342,21 @@ public void testErrorBatch() throws InterruptedException { alwaysRun = true, dependsOnMethods = {"testErrorBatch"}) public void testMetricManager() throws JsonParseException, InterruptedException { + final String UNIT = "Unit"; + final String LEVEL = "Level"; + final String HOST = "Host"; + + // Define expected metrics + // See ts/metrics/system_metrics.py, ts/configs/metrics.yaml + Map> expectedMetrics = new HashMap<>(); + expectedMetrics.put("CPUUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); + expectedMetrics.put("MemoryUsed", Map.of(UNIT, "Megabytes", LEVEL, HOST)); + expectedMetrics.put("MemoryAvailable", Map.of(UNIT, "Megabytes", LEVEL, HOST)); + expectedMetrics.put("MemoryUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); + expectedMetrics.put("DiskUsage", Map.of(UNIT, "Gigabytes", LEVEL, HOST)); + expectedMetrics.put("DiskUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); + expectedMetrics.put("DiskAvailable", Map.of(UNIT, "Gigabytes", LEVEL, HOST)); + MetricManager.scheduleMetrics(configManager); MetricManager metricManager = MetricManager.getInstance(); List metrics = metricManager.getMetrics(); @@ -1352,19 +1369,25 @@ public void testMetricManager() throws JsonParseException, InterruptedException Assert.assertTrue(++count < 5); } + Assert.assertEquals(metrics.size(), expectedMetrics.size()); + for (Metric metric : metrics) { - if (metric.getMetricName().equals("CPUUtilization")) { - Assert.assertEquals(metric.getUnit(), "Percent"); - } - if (metric.getMetricName().equals("MemoryUsed")) { - Assert.assertEquals(metric.getUnit(), "Megabytes"); + String metricName = metric.getMetricName(); + Assert.assertTrue(expectedMetrics.containsKey(metricName)); + + Map expectedValues = expectedMetrics.get(metricName); + Assert.assertEquals(expectedValues.get(UNIT), metric.getUnit()); + + List dimensions = metric.getDimensions(); + Map dimensionMap = new HashMap<>(); + for (Dimension dimension : dimensions) { + dimensionMap.put(dimension.getName(), dimension.getValue()); } - if (metric.getMetricName().equals("DiskUsed")) { - List dimensions = metric.getDimensions(); - for (Dimension dimension : dimensions) { - if (dimension.getName().equals("Level")) { - Assert.assertEquals(dimension.getValue(), "Host"); - } + + for (Map.Entry entry : expectedValues.entrySet()) { + if (!entry.getKey().equals(UNIT)) { + Assert.assertTrue(dimensionMap.containsKey(entry.getKey())); + Assert.assertEquals(entry.getValue(), dimensionMap.get(entry.getKey())); } } } From e5d382f915264c2a75d874031e3c6350bdd792fb Mon Sep 17 00:00:00 2001 From: jakki Date: Thu, 14 Nov 2024 13:22:57 +0200 Subject: [PATCH 17/33] Add latest ROCM support --- requirements/torch_rocm62.txt | 4 ++++ ts_scripts/install_dependencies.py | 5 +---- 2 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 requirements/torch_rocm62.txt diff --git a/requirements/torch_rocm62.txt b/requirements/torch_rocm62.txt new file mode 100644 index 0000000000..291a07b410 --- /dev/null +++ b/requirements/torch_rocm62.txt @@ -0,0 +1,4 @@ +--index-url https://download.pytorch.org/whl/rocm6.2 +torch==2.5.1+rocm6.2; sys_platform == 'linux' +torchvision==0.20.1+rocm6.2; sys_platform == 'linux' +torchaudio==2.5.1+rocm6.2; sys_platform == 'linux' diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py index 7fb7ba4d8f..e66b5b5701 100644 --- a/ts_scripts/install_dependencies.py +++ b/ts_scripts/install_dependencies.py @@ -395,10 +395,7 @@ def get_brew_version(): parser.add_argument( "--rocm", default=None, - choices=[ - "rocm60", - "rocm61", - ], + choices=["rocm60", "rocm61", "rocm62"], help="ROCm version for torch", ) parser.add_argument( From f2d17d58584e7a59f34764b5fd0be7187ea410c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?= Date: Fri, 22 Nov 2024 08:05:08 +0000 Subject: [PATCH 18/33] PR 24 system_metrics bugfix --- ts/metrics/system_metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ts/metrics/system_metrics.py b/ts/metrics/system_metrics.py index 3a04b949bf..5e69377f5a 100644 --- a/ts/metrics/system_metrics.py +++ b/ts/metrics/system_metrics.py @@ -120,7 +120,6 @@ def collect_all(mod, num_of_gpus): value = getattr(mod, i) if isinstance(value, types.FunctionType) and value.__name__ not in ( "collect_all", - "collect_gpu_metrics", ): if value.__name__ == "collect_gpu_metrics": value(num_of_gpus) From 49bc0519f5b86875f14acc5d34edf402f0f0d6bb Mon Sep 17 00:00:00 2001 From: jakki Date: Fri, 22 Nov 2024 15:44:26 +0200 Subject: [PATCH 19/33] Format files --- .../pytorch/serve/device/utils/AppleUtil.java | 16 +++++++++------- .../org/pytorch/serve/wlm/WorkerLifeCycle.java | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java b/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java index ef3bc2485d..3c32be3317 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java +++ b/frontend/server/src/main/java/org/pytorch/serve/device/utils/AppleUtil.java @@ -70,18 +70,20 @@ public Integer extractAcceleratorId(JsonObject cardObject) { @Override public List extractAccelerators(JsonElement rootObject) { List accelerators = new ArrayList<>(); - JsonArray displaysArray = rootObject - .getAsJsonObject() // Gets the outer object - .get("SPDisplaysDataType") // Gets the "SPDisplaysDataType" element - .getAsJsonArray(); + JsonArray displaysArray = + rootObject + .getAsJsonObject() // Gets the outer object + .get("SPDisplaysDataType") // Gets the "SPDisplaysDataType" element + .getAsJsonArray(); JsonObject gpuObject = displaysArray.get(0).getAsJsonObject(); int number_of_cores = Integer.parseInt(gpuObject.get("sppci_cores").getAsString()); // add the object `number_of_cores` times to maintain the exsisitng // functionality - accelerators = IntStream.range(0, number_of_cores) - .mapToObj(i -> gpuObject) - .collect(Collectors.toList()); + accelerators = + IntStream.range(0, number_of_cores) + .mapToObj(i -> gpuObject) + .collect(Collectors.toList()); return accelerators; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 0a6b95b294..0b3186f099 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -15,7 +15,6 @@ import java.util.regex.Pattern; import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.archive.model.ModelConfig.ParallelType; -import org.pytorch.serve.device.AcceleratorVendor; import org.pytorch.serve.metrics.Metric; import org.pytorch.serve.metrics.MetricCache; import org.pytorch.serve.util.ConfigManager; @@ -136,7 +135,8 @@ private void startWorkerPython(int port, String deviceIds) attachRunner(argl, envp, port, deviceIds); } else { if (deviceIds != null) { - String visibleDeviceEnvName = configManager.systemInfo.getVisibleDevicesEnvName(); + String visibleDeviceEnvName = + configManager.systemInfo.getVisibleDevicesEnvName(); envp.add(visibleDeviceEnvName + "=" + deviceIds); } argl.add(EnvironmentUtils.getPythonRunTime(model)); From 4bff6d30a215111efd493dcaf165d60049fc1b82 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Tue, 26 Nov 2024 09:40:07 +0100 Subject: [PATCH 20/33] Update docs/hardware_support/amd_support.md Co-authored-by: Jeff Daily --- docs/hardware_support/amd_support.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/hardware_support/amd_support.md b/docs/hardware_support/amd_support.md index cf459b554d..e9334861bb 100644 --- a/docs/hardware_support/amd_support.md +++ b/docs/hardware_support/amd_support.md @@ -62,6 +62,6 @@ If you have 8 accelerators but only want TorchServe to see the last four of them ## Example Usage -After installing TorchServe with re required dependencies for ROCm you should be ready to serve your model. +After installing TorchServe with the required dependencies for ROCm you should be ready to serve your model. For a simple example, refer to `serve/examples/image_classifier/mnist/`. From b9a16272324b8a8928b7e870351bd512321747d1 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Tue, 26 Nov 2024 09:40:48 +0100 Subject: [PATCH 21/33] typo in docs/hardware_support/amd_support.md Co-authored-by: Jeff Daily --- docs/hardware_support/amd_support.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/hardware_support/amd_support.md b/docs/hardware_support/amd_support.md index e9334861bb..6db6a36e52 100644 --- a/docs/hardware_support/amd_support.md +++ b/docs/hardware_support/amd_support.md @@ -57,7 +57,7 @@ If you have 8 accelerators but only want TorchServe to see the last four of them > eg. `export HIP_VISIBLE_DEVICES=` or `export HIP_VISIBLE_DEVICES=""` > use `unset HIP_VISIBLE_DEVICES` if you want to remove it's effect. -> ⚠️ Setting both `CUDA_VISIBLE_DEVICES` and `HIP_VISIBLE_DEVICES` may cause a unintended behaviour and should be avoided. +> ⚠️ Setting both `CUDA_VISIBLE_DEVICES` and `HIP_VISIBLE_DEVICES` may cause unintended behaviour and should be avoided. > Doing so may cause an exception in the future. ## Example Usage From 964e5f129db8aebfe386f99c3cf7d92939e04380 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Tue, 26 Nov 2024 09:41:17 +0100 Subject: [PATCH 22/33] Update docs/hardware_support/amd_support.md Co-authored-by: Jeff Daily --- docs/hardware_support/amd_support.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/hardware_support/amd_support.md b/docs/hardware_support/amd_support.md index 6db6a36e52..eeb044ab23 100644 --- a/docs/hardware_support/amd_support.md +++ b/docs/hardware_support/amd_support.md @@ -55,7 +55,7 @@ If you have 8 accelerators but only want TorchServe to see the last four of them > ⚠️ You can run into trouble if you set `HIP_VISIBLE_DEVICES` to an empty string. > eg. `export HIP_VISIBLE_DEVICES=` or `export HIP_VISIBLE_DEVICES=""` -> use `unset HIP_VISIBLE_DEVICES` if you want to remove it's effect. +> use `unset HIP_VISIBLE_DEVICES` if you want to remove its effect. > ⚠️ Setting both `CUDA_VISIBLE_DEVICES` and `HIP_VISIBLE_DEVICES` may cause unintended behaviour and should be avoided. > Doing so may cause an exception in the future. From 61da32e599da4ea26b0cf8a58df303f7ffafc62c Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Tue, 26 Nov 2024 09:41:32 +0100 Subject: [PATCH 23/33] Update docs/hardware_support/amd_support.md Co-authored-by: Jeff Daily --- docs/hardware_support/amd_support.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/hardware_support/amd_support.md b/docs/hardware_support/amd_support.md index eeb044ab23..f406f980a6 100644 --- a/docs/hardware_support/amd_support.md +++ b/docs/hardware_support/amd_support.md @@ -46,7 +46,7 @@ The current stable `major.patch` version of ROCm and the previous path version w ### Selecting Accelerators Using `HIP_VISIBLE_DEVICES` -If you have multiple accelerators on the system where you are running TorchServe yuo can select which accelerators should be visible to TorchServe +If you have multiple accelerators on the system where you are running TorchServe you can select which accelerators should be visible to TorchServe by setting the environment variable `HIP_VISIBLE_DEVICES` to a string of 0-indexed comma-separated integers representing the ids of the accelerators. If you have 8 accelerators but only want TorchServe to see the last four of them do `export HIP_VISIBLE_DEVICES=4,5,6,7`. From 0a4d628ca4c8a446efb7749a07d94b9ea8e45c38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?= Date: Tue, 26 Nov 2024 13:09:18 +0000 Subject: [PATCH 24/33] remove pyrsmi and nvgpu deps --- requirements/common_gpu.txt | 2 -- requirements/common_rocm.txt | 1 - ts_scripts/install_dependencies.py | 10 ---------- ts_scripts/print_env_info.py | 2 +- 4 files changed, 1 insertion(+), 14 deletions(-) delete mode 100644 requirements/common_gpu.txt delete mode 100644 requirements/common_rocm.txt diff --git a/requirements/common_gpu.txt b/requirements/common_gpu.txt deleted file mode 100644 index 1e893cc7c1..0000000000 --- a/requirements/common_gpu.txt +++ /dev/null @@ -1,2 +0,0 @@ -nvgpu; sys_platform != 'win32' -nvgpu==0.10.0; sys_platform == 'win32' diff --git a/requirements/common_rocm.txt b/requirements/common_rocm.txt deleted file mode 100644 index 20789dd473..0000000000 --- a/requirements/common_rocm.txt +++ /dev/null @@ -1 +0,0 @@ -pyrsmi; sys_platform == 'linux' \ No newline at end of file diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py index e66b5b5701..4d464e03fe 100644 --- a/ts_scripts/install_dependencies.py +++ b/ts_scripts/install_dependencies.py @@ -167,16 +167,6 @@ def install_python_packages( # developer.txt also installs packages from common.txt os.system(f"{sys.executable} -m pip install -U -r {requirements_file_path}") - # Install dependencies for NVIDIA GPU - if not isinstance(cuda_version, type(None)): - gpu_requirements_file = os.path.join("requirements", "common_gpu.txt") - os.system(f"{sys.executable} -m pip install -U -r {gpu_requirements_file}") - - # Install dependencies for AMD GPU - if not isinstance(rocm_version, type(None)): - gpu_requirements_file = os.path.join("requirements", "common_rocm.txt") - os.system(f"{sys.executable} -m pip install -U -r {gpu_requirements_file}") - # Install dependencies for Inferentia2 if args.neuronx: neuronx_requirements_file = os.path.join("requirements", "neuronx.txt") diff --git a/ts_scripts/print_env_info.py b/ts_scripts/print_env_info.py index 430bb98e29..b6d194b688 100644 --- a/ts_scripts/print_env_info.py +++ b/ts_scripts/print_env_info.py @@ -112,7 +112,7 @@ def run_with_pip(pip): elif package_name == "torch": grep_cmd = 'grep "' + package_name + '"' else: - grep_cmd = r'grep "numpy\|pytest\|pylint\|transformers\|psutil\|wheel\|requests\|sentencepiece\|pillow\|captum\|nvgpu\|pygit2\|torch"' + grep_cmd = r'grep "numpy\|pytest\|pylint\|transformers\|psutil\|wheel\|requests\|sentencepiece\|pillow\|captum\|pygit2\|torch"' return run_and_read_all(pip + " list --format=freeze | " + grep_cmd) out = run_with_pip("pip3") From aa96f2f6b5127722d8ad6d51da1bc5499eb7334c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?= Date: Tue, 26 Nov 2024 14:36:14 +0000 Subject: [PATCH 25/33] metric collector revert gpu arg name --- ts/metrics/metric_collector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ts/metrics/metric_collector.py b/ts/metrics/metric_collector.py index 9032d42e41..6a811d1bcc 100644 --- a/ts/metrics/metric_collector.py +++ b/ts/metrics/metric_collector.py @@ -15,7 +15,7 @@ parser = argparse.ArgumentParser() parser.add_argument( - "--gpus", + "--gpu", action="store", help="number of GPUs", type=int @@ -24,6 +24,6 @@ logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) - system_metrics.collect_all(sys.modules['ts.metrics.system_metrics'], arguments.gpus) + system_metrics.collect_all(sys.modules['ts.metrics.system_metrics'], arguments.gpu) check_process_mem_usage(sys.stdin) From a26eefbfb7810c6c47d73846e6ddedc4c9af36cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?= Date: Tue, 26 Nov 2024 16:04:54 +0000 Subject: [PATCH 26/33] fix number of metrics assertion in testMetricManager --- .../src/test/java/org/pytorch/serve/ModelServerTest.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java index edc21f6bbd..bd7f654ce7 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java @@ -1349,6 +1349,9 @@ public void testMetricManager() throws JsonParseException, InterruptedException // Define expected metrics // See ts/metrics/system_metrics.py, ts/configs/metrics.yaml Map> expectedMetrics = new HashMap<>(); + expectedMetrics.put("GPUMemoryUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); + expectedMetrics.put("GPUMemoryUsed", Map.of(UNIT, "Megabytes", LEVEL, HOST)); + expectedMetrics.put("GPUUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); expectedMetrics.put("CPUUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); expectedMetrics.put("MemoryUsed", Map.of(UNIT, "Megabytes", LEVEL, HOST)); expectedMetrics.put("MemoryAvailable", Map.of(UNIT, "Megabytes", LEVEL, HOST)); @@ -1369,7 +1372,8 @@ public void testMetricManager() throws JsonParseException, InterruptedException Assert.assertTrue(++count < 5); } - Assert.assertEquals(metrics.size(), expectedMetrics.size()); + // 7 system-level metrics + 3 gpu-specific metrics + Assert.assertEquals(metrics.size(), 7 + 3 * configManager.getNumberOfGpu()); for (Metric metric : metrics) { String metricName = metric.getMetricName(); From f0b1dfb18f2d2a14f629a7587046ab2a1826db77 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Wed, 27 Nov 2024 12:58:32 +0100 Subject: [PATCH 27/33] 'move Intel docs under Hardware Support' (#31) --- docs/contents.rst | 1 - docs/hardware_support/hardware_support.rst | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contents.rst b/docs/contents.rst index 9113e770c2..c42a6a3076 100644 --- a/docs/contents.rst +++ b/docs/contents.rst @@ -17,7 +17,6 @@ request_envelopes server snapshot - intel_extension_for_pytorch torchserve_on_win_native torchserve_on_wsl use_cases diff --git a/docs/hardware_support/hardware_support.rst b/docs/hardware_support/hardware_support.rst index 75c201c2de..267525fc65 100644 --- a/docs/hardware_support/hardware_support.rst +++ b/docs/hardware_support/hardware_support.rst @@ -5,3 +5,4 @@ apple_silicon_support linux_aarch64 nvidia_mps + Intel Extension for PyTorch From d3304945625edb5d673019fa752c64d5ee1882cd Mon Sep 17 00:00:00 2001 From: jakki Date: Wed, 27 Nov 2024 14:50:09 +0200 Subject: [PATCH 28/33] Fix docstring --- .../src/main/java/org/pytorch/serve/util/ConfigManager.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index ab8693185e..e73b138b4c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -51,10 +51,8 @@ import org.slf4j.LoggerFactory; public final class ConfigManager { - // Variables that can be configured through config.properties and Environment - // Variables - // NOTE: Variables which can be configured through environment variables - // **SHOULD** have a + // Variables that can be configured through config.properties and Environment Variables + // NOTE: Variables which can be configured through environment variables **SHOULD** have a // "TS_" prefix private static final String TS_DEBUG = "debug"; From cbdfe255371c63696dd10cabbd619397079f2f9d Mon Sep 17 00:00:00 2001 From: jakki Date: Thu, 28 Nov 2024 13:40:37 +0200 Subject: [PATCH 29/33] Add Dockerfile.rocm --- docker/Dockerfile | 12 +- docker/Dockerfile.rocm | 321 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 327 insertions(+), 6 deletions(-) create mode 100644 docker/Dockerfile.rocm diff --git a/docker/Dockerfile b/docker/Dockerfile index 94f4a1ba99..3a2ba23a98 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -37,7 +37,7 @@ ARG BRANCH_NAME ARG REPO_URL=https://github.com/pytorch/serve.git ENV PYTHONUNBUFFERED TRUE -RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ +RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ @@ -112,12 +112,12 @@ FROM ${BASE_IMAGE} AS production-image ARG PYTHON_VERSION ENV PYTHONUNBUFFERED TRUE -RUN --mount=type=cache,target=/var/cache/apt \ +RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ add-apt-repository ppa:deadsnakes/ppa -y && \ - apt remove python-pip python3-pip && \ + apt remove -y python-pip python3-pip && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ python$PYTHON_VERSION \ python3-distutils \ @@ -158,12 +158,12 @@ ARG PYTHON_VERSION ARG BRANCH_NAME ENV PYTHONUNBUFFERED TRUE -RUN --mount=type=cache,target=/var/cache/apt \ +RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ add-apt-repository -y ppa:deadsnakes/ppa && \ - apt remove python-pip python3-pip && \ + apt remove -y python-pip python3-pip && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ python$PYTHON_VERSION \ python3-distutils \ @@ -207,7 +207,7 @@ ARG BUILD_WITH_IPEX ARG IPEX_VERSION=1.11.0 ARG IPEX_URL=https://software.intel.com/ipex-whl-stable ENV PYTHONUNBUFFERED TRUE -RUN --mount=type=cache,target=/var/cache/apt \ +RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm new file mode 100644 index 0000000000..d8c7b842f7 --- /dev/null +++ b/docker/Dockerfile.rocm @@ -0,0 +1,321 @@ +# syntax = docker/dockerfile:experimental +# +# This file can build images for cpu and gpu env. By default it builds image for CPU. +# Use following option to build image for cuda/GPU: --build-arg BASE_IMAGE=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 +# Here is complete command for GPU/cuda - +# $ DOCKER_BUILDKIT=1 docker build --file Dockerfile --build-arg BASE_IMAGE=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 -t torchserve:latest . +# +# Following comments have been shamelessly copied from https://github.com/pytorch/pytorch/blob/master/Dockerfile +# +# NOTE: To build this you will need a docker version > 18.06 with +# experimental enabled and DOCKER_BUILDKIT=1 +# +# If you do not use buildkit you are not going to have a good time +# +# For reference: +# https://docs.docker.com/develop/develop-images/build_enhancements/ + +ARG BASE_IMAGE=ubuntu:24.04 +ARG BRANCH_NAME=master +# Note: +# Define here the default python version to be used in all later build-stages as default. +# ARG and ENV variables do not persist across stages (they're build-stage scoped). +# That is crucial for ARG PYTHON_VERSION, which otherwise becomes "" leading to nasty bugs, +# that don't let the build fail, but break current version handling logic and result +# in images with wrong python version. To fix that, we will restate the ARG PYTHON_VERSION +# on each build-stage. +ARG PYTHON_VERSION=3.11 + +FROM ${BASE_IMAGE} AS compile-image +ARG BASE_IMAGE=ubuntu:24.04 +ARG PYTHON_VERSION +ARG BUILD_NIGHTLY +ARG BUILD_FROM_SRC +ARG LOCAL_CHANGES +ARG BRANCH_NAME +ARG REPO_URL=https://github.com/pytorch/serve.git +ENV PYTHONUNBUFFERED TRUE + +RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ + apt-get update && \ + apt-get upgrade -y && \ + apt-get install software-properties-common -y && \ + add-apt-repository -y ppa:deadsnakes/ppa && \ + apt remove -y python-pip python3-pip && \ + DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ + ca-certificates \ + g++ \ + python3-setuptools \ + python$PYTHON_VERSION \ + python$PYTHON_VERSION-dev \ + python$PYTHON_VERSION-venv \ + openjdk-17-jdk \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Make the virtual environment and "activating" it by adding it first to the path. +# From here on the python$PYTHON_VERSION interpreter is used and the packages +# are installed in /home/venv which is what we need for the "runtime-image" +RUN python$PYTHON_VERSION -m venv /home/venv +ENV PATH="/home/venv/bin:$PATH" + +ARG USE_ROCM_VERSION="" + +RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ + if [ -n "$USE_ROCM_VERSION" ]; then \ + apt-get update \ + && curl -O https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ./amdgpu-install_6.2.60202-1_all.deb \ + && apt-get update \ + && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ + else \ + echo "Skip ROCm installation"; \ + fi + +COPY ./ serve + +RUN \ + if echo "$LOCAL_CHANGES" | grep -q "false"; then \ + rm -rf /serve;\ + git clone --recursive $REPO_URL -b $BRANCH_NAME /serve; \ + fi + + +WORKDIR "/serve" + +RUN cp docker/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh + +RUN \ + # Install ROCm version specific binary when ROCm version is specified as a build arg + if [ "$USE_ROCM_VERSION" ]; then \ + python ./ts_scripts/install_dependencies.py --rocm $USE_ROCM_VERSION \ + && python -m pip install /opt/rocm/share/amd_smi; \ + # Install the binary with the latest CPU image on a ROCm base image + else \ + python ./ts_scripts/install_dependencies.py;\ + fi; + +# Make sure latest version of torchserve is uploaded before running this +RUN \ + if echo "$BUILD_FROM_SRC" | grep -q "true"; then \ + python -m pip install -r requirements/developer.txt \ + && python ts_scripts/install_from_src.py;\ + elif echo "$BUILD_NIGHTLY" | grep -q "false"; then \ + python -m pip install --no-cache-dir torchserve torch-model-archiver torch-workflow-archiver;\ + else \ + python -m pip install --no-cache-dir torchserve-nightly torch-model-archiver-nightly torch-workflow-archiver-nightly;\ + fi + +# Final image for production +FROM ${BASE_IMAGE} AS production-image +# Re-state ARG PYTHON_VERSION to make it active in this build-stage (uses default define at the top) +ARG PYTHON_VERSION +ARG USE_ROCM_VERSION +ENV PYTHONUNBUFFERED TRUE + +RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ + apt-get update && \ + apt-get upgrade -y && \ + apt-get install software-properties-common -y && \ + add-apt-repository ppa:deadsnakes/ppa -y && \ + apt remove -y python-pip python3-pip && \ + DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ + python$PYTHON_VERSION \ + python3-setuptools \ + python$PYTHON_VERSION-dev \ + python$PYTHON_VERSION-venv \ + # using openjdk-17-jdk due to circular dependency(ca-certificates) bug in openjdk-17-jre-headless debian package + # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=1009905 + openjdk-17-jdk \ + build-essential \ + && rm -rf /var/lib/apt/lists/* \ + && cd /tmp + +RUN --mount=type=bind,sharing=locked,from=compile-image,target=/mnt \ + if [ "$USE_ROCM_VERSION" ]; then \ + apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ + && apt-get update \ + && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ + else \ + echo "Skip ROCm installation"; \ + fi + +RUN useradd -m model-server \ + && mkdir -p /home/model-server/tmp + +COPY --chown=model-server --from=compile-image /home/venv /home/venv +COPY --from=compile-image /usr/local/bin/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh +ENV PATH="/home/venv/bin:$PATH" + +COPY --from=compile-image /opt/rocm/share/amd_smi /opt/rocm/share/amd_smi + +RUN \ + if [ "$USE_ROCM_VERSION" ]; then \ + python -m pip install /opt/rocm/share/amd_smi; \ + else \ + echo "Skip ROCm installation"; \ + fi + +RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \ + && chown -R model-server /home/model-server + +COPY docker/config.properties /home/model-server/config.properties +RUN mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store + +EXPOSE 8080 8081 8082 7070 7071 + +USER model-server +WORKDIR /home/model-server +ENV TEMP=/home/model-server/tmp +ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] +CMD ["serve"] + +# Final image for docker regression +FROM ${BASE_IMAGE} AS ci-image +# Re-state ARG PYTHON_VERSION to make it active in this build-stage (uses default define at the top) +ARG PYTHON_VERSION +ARG BRANCH_NAME +ENV PYTHONUNBUFFERED TRUE + +RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ + apt-get update && \ + apt-get upgrade -y && \ + apt-get install software-properties-common -y && \ + add-apt-repository -y ppa:deadsnakes/ppa && \ + apt remove -y python-pip python3-pip && \ + DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ + python$PYTHON_VERSION \ + python3-setuptools \ + python$PYTHON_VERSION-dev \ + python$PYTHON_VERSION-venv \ + # using openjdk-17-jdk due to circular dependency(ca-certificates) bug in openjdk-17-jre-headless debian package + # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=1009905 + openjdk-17-jdk \ + build-essential \ + wget \ + numactl \ + nodejs \ + npm \ + zip \ + unzip \ + && npm install -g newman@5.3.2 newman-reporter-htmlextra markdown-link-check \ + && rm -rf /var/lib/apt/lists/* \ + && cd /tmp + +RUN --mount=type=bind,sharing=locked,from=compile-image,target=/mnt \ + if [ "$USE_ROCM_VERSION" ]; then \ + apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ + && apt-get update \ + && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ + else \ + echo "Skip ROCm installation"; \ + fi + +COPY --from=compile-image /home/venv /home/venv +ENV PATH="/home/venv/bin:$PATH" + +RUN python -m pip install --no-cache-dir -r https://raw.githubusercontent.com/pytorch/serve/$BRANCH_NAME/requirements/developer.txt + +COPY --from=compile-image /opt/rocm/share/amd_smi /opt/rocm/share/amd_smi + +RUN \ + if [ "$USE_ROCM_VERSION" ]; then \ + python -m pip install /opt/rocm/share/amd_smi; \ + else \ + echo "Skip ROCm installation"; \ + fi + +RUN mkdir /serve +ENV TS_RUN_IN_DOCKER True + +WORKDIR /serve +CMD ["python", "test/regression_tests.py"] + +#Final image for developer Docker image +FROM ${BASE_IMAGE} as dev-image +# Re-state ARG PYTHON_VERSION to make it active in this build-stage (uses default define at the top) +ARG PYTHON_VERSION +ARG BRANCH_NAME +ARG USE_ROCM_VERSION +ARG BUILD_FROM_SRC +ARG LOCAL_CHANGES +ARG BUILD_WITH_IPEX +ARG IPEX_VERSION=1.11.0 +ARG IPEX_URL=https://software.intel.com/ipex-whl-stable +ENV PYTHONUNBUFFERED TRUE +RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ + apt-get update && \ + apt-get upgrade -y && \ + apt-get install software-properties-common -y && \ + add-apt-repository -y ppa:deadsnakes/ppa && \ + DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ + fakeroot \ + ca-certificates \ + dpkg-dev \ + sudo \ + g++ \ + git \ + python$PYTHON_VERSION \ + python$PYTHON_VERSION-dev \ + python3-setuptools \ + python$PYTHON_VERSION-venv \ + # using openjdk-17-jdk due to circular dependency(ca-certificates) bug in openjdk-17-jre-headless debian package + # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=1009905 + openjdk-17-jdk \ + build-essential \ + curl \ + vim \ + numactl \ + && rm -rf /var/lib/apt/lists/* + +RUN --mount=type=bind,sharing=locked,from=compile-image,target=/mnt \ + if [ "$USE_ROCM_VERSION" ]; then \ + apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ + && apt-get update \ + && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ + else \ + echo "Skip ROCm installation"; \ + fi + +COPY ./ /serve + +RUN \ + if echo "$LOCAL_CHANGES" | grep -q "false"; then \ + rm -rf /serve;\ + git clone --recursive $REPO_URL -b $BRANCH_NAME /serve; \ + fi + +COPY --from=compile-image /home/venv /home/venv +ENV PATH="/home/venv/bin:$PATH" + +WORKDIR "/serve" + +RUN \ + if [ "$USE_ROCM_VERSION" ]; then \ + python ts_scripts/install_dependencies.py --environment=dev --rocm $USE_ROCM_VERSION \ + && python -m pip install /opt/rocm/share/amd_smi; \ + # Install the binary with the latest CPU image on a ROCm base image + else \ + python ts_scripts/install_dependencies.py --environment=dev;\ + fi; + +RUN python -m pip install -U pip setuptools \ + && python -m pip install --no-cache-dir -r requirements/developer.txt \ + && python ts_scripts/install_from_src.py --environment=dev \ + && useradd -m model-server \ + && mkdir -p /home/model-server/tmp \ + && cp docker/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh \ + && chmod +x /usr/local/bin/dockerd-entrypoint.sh \ + && chown -R model-server /home/model-server \ + && cp docker/config.properties /home/model-server/config.properties \ + && mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store \ + && chown -R model-server /home/venv +EXPOSE 8080 8081 8082 7070 7071 +WORKDIR /home/model-server +ENV TEMP=/home/model-server/tmp +ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] +CMD ["serve"] From 8330233793015cb8fdcd5fa97240b08bc2718d48 Mon Sep 17 00:00:00 2001 From: jakki Date: Thu, 28 Nov 2024 13:44:33 +0200 Subject: [PATCH 30/33] Remove sharing lock from bind mounts --- docker/Dockerfile.rocm | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index d8c7b842f7..e32cbe857a 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -62,7 +62,7 @@ ENV PATH="/home/venv/bin:$PATH" ARG USE_ROCM_VERSION="" -RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ +RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ if [ -n "$USE_ROCM_VERSION" ]; then \ apt-get update \ && curl -O https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ @@ -132,7 +132,7 @@ RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp -RUN --mount=type=bind,sharing=locked,from=compile-image,target=/mnt \ +RUN --mount=type=bind,from=compile-image,target=/mnt \ if [ "$USE_ROCM_VERSION" ]; then \ apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ @@ -204,7 +204,7 @@ RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp -RUN --mount=type=bind,sharing=locked,from=compile-image,target=/mnt \ +RUN --mount=type=bind,from=compile-image,target=/mnt \ if [ "$USE_ROCM_VERSION" ]; then \ apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ @@ -271,7 +271,7 @@ RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ numactl \ && rm -rf /var/lib/apt/lists/* -RUN --mount=type=bind,sharing=locked,from=compile-image,target=/mnt \ +RUN --mount=type=bind,from=compile-image,target=/mnt \ if [ "$USE_ROCM_VERSION" ]; then \ apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ From 9e5afd030a52f03f6de03981c366352d139e097b Mon Sep 17 00:00:00 2001 From: jakki Date: Fri, 29 Nov 2024 15:35:50 +0200 Subject: [PATCH 31/33] Update Dockerfile.rocm --- docker/Dockerfile.rocm | 117 ++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 59 deletions(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index e32cbe857a..a6f578ecb4 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -41,7 +41,6 @@ RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ add-apt-repository -y ppa:deadsnakes/ppa && \ - apt remove -y python-pip python3-pip && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ ca-certificates \ g++ \ @@ -51,6 +50,7 @@ RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ python$PYTHON_VERSION-venv \ openjdk-17-jdk \ curl \ + wget \ git \ && rm -rf /var/lib/apt/lists/* @@ -62,17 +62,6 @@ ENV PATH="/home/venv/bin:$PATH" ARG USE_ROCM_VERSION="" -RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ - if [ -n "$USE_ROCM_VERSION" ]; then \ - apt-get update \ - && curl -O https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y ./amdgpu-install_6.2.60202-1_all.deb \ - && apt-get update \ - && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ - else \ - echo "Skip ROCm installation"; \ - fi - COPY ./ serve RUN \ @@ -81,45 +70,53 @@ RUN \ git clone --recursive $REPO_URL -b $BRANCH_NAME /serve; \ fi - WORKDIR "/serve" RUN cp docker/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh +RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ + if [ -n "$USE_ROCM_VERSION" ]; then \ + apt-get update \ + && wget https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ./amdgpu-install_6.2.60202-1_all.deb \ + && apt-get update \ + && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ + else \ + echo "Skip ROCm installation"; \ + fi + RUN \ # Install ROCm version specific binary when ROCm version is specified as a build arg if [ "$USE_ROCM_VERSION" ]; then \ - python ./ts_scripts/install_dependencies.py --rocm $USE_ROCM_VERSION \ - && python -m pip install /opt/rocm/share/amd_smi; \ + python$PYTHON_VERSION ./ts_scripts/install_dependencies.py --rocm $USE_ROCM_VERSION; \ # Install the binary with the latest CPU image on a ROCm base image else \ - python ./ts_scripts/install_dependencies.py;\ + python$PYTHON_VERSION ./ts_scripts/install_dependencies.py;\ fi; # Make sure latest version of torchserve is uploaded before running this RUN \ if echo "$BUILD_FROM_SRC" | grep -q "true"; then \ - python -m pip install -r requirements/developer.txt \ - && python ts_scripts/install_from_src.py;\ + python$PYTHON_VERSION -m pip install -r requirements/developer.txt \ + && python$PYTHON_VERSION ts_scripts/install_from_src.py;\ elif echo "$BUILD_NIGHTLY" | grep -q "false"; then \ - python -m pip install --no-cache-dir torchserve torch-model-archiver torch-workflow-archiver;\ + python$PYTHON_VERSION -m pip install --no-cache-dir torchserve torch-model-archiver torch-workflow-archiver;\ else \ - python -m pip install --no-cache-dir torchserve-nightly torch-model-archiver-nightly torch-workflow-archiver-nightly;\ + python$PYTHON_VERSION -m pip install --no-cache-dir torchserve-nightly torch-model-archiver-nightly torch-workflow-archiver-nightly;\ fi # Final image for production FROM ${BASE_IMAGE} AS production-image # Re-state ARG PYTHON_VERSION to make it active in this build-stage (uses default define at the top) ARG PYTHON_VERSION -ARG USE_ROCM_VERSION ENV PYTHONUNBUFFERED TRUE +ARG USE_ROCM_VERSION RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ add-apt-repository ppa:deadsnakes/ppa -y && \ - apt remove -y python-pip python3-pip && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ python$PYTHON_VERSION \ python3-setuptools \ @@ -129,13 +126,15 @@ RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=1009905 openjdk-17-jdk \ build-essential \ + wget \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp -RUN --mount=type=bind,from=compile-image,target=/mnt \ - if [ "$USE_ROCM_VERSION" ]; then \ +RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ + if [ -n "$USE_ROCM_VERSION" ]; then \ apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ + && wget https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ./amdgpu-install_6.2.60202-1_all.deb \ && apt-get update \ && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ else \ @@ -149,13 +148,10 @@ COPY --chown=model-server --from=compile-image /home/venv /home/venv COPY --from=compile-image /usr/local/bin/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh ENV PATH="/home/venv/bin:$PATH" -COPY --from=compile-image /opt/rocm/share/amd_smi /opt/rocm/share/amd_smi - RUN \ - if [ "$USE_ROCM_VERSION" ]; then \ - python -m pip install /opt/rocm/share/amd_smi; \ - else \ - echo "Skip ROCm installation"; \ + if [ -n "$USE_ROCM_VERSION" ]; then \ + python$PYTHON_VERSION -m pip install -U pip setuptools \ + && python -m pip install /opt/rocm/share/amd_smi; \ fi RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \ @@ -177,6 +173,7 @@ FROM ${BASE_IMAGE} AS ci-image # Re-state ARG PYTHON_VERSION to make it active in this build-stage (uses default define at the top) ARG PYTHON_VERSION ARG BRANCH_NAME +ARG USE_ROCM_VERSION ENV PYTHONUNBUFFERED TRUE RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ @@ -184,7 +181,6 @@ RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ add-apt-repository -y ppa:deadsnakes/ppa && \ - apt remove -y python-pip python3-pip && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ python$PYTHON_VERSION \ python3-setuptools \ @@ -204,10 +200,11 @@ RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp -RUN --mount=type=bind,from=compile-image,target=/mnt \ - if [ "$USE_ROCM_VERSION" ]; then \ +RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ + if [ -n "$USE_ROCM_VERSION" ]; then \ apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ + && wget https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ./amdgpu-install_6.2.60202-1_all.deb \ && apt-get update \ && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ else \ @@ -215,19 +212,17 @@ RUN --mount=type=bind,from=compile-image,target=/mnt \ fi COPY --from=compile-image /home/venv /home/venv -ENV PATH="/home/venv/bin:$PATH" -RUN python -m pip install --no-cache-dir -r https://raw.githubusercontent.com/pytorch/serve/$BRANCH_NAME/requirements/developer.txt - -COPY --from=compile-image /opt/rocm/share/amd_smi /opt/rocm/share/amd_smi +ENV PATH="/home/venv/bin:$PATH" RUN \ - if [ "$USE_ROCM_VERSION" ]; then \ - python -m pip install /opt/rocm/share/amd_smi; \ - else \ - echo "Skip ROCm installation"; \ + if [ -n "$USE_ROCM_VERSION" ]; then \ + python$PYTHON_VERSION -m pip install -U pip setuptools \ + && python -m pip install /opt/rocm/share/amd_smi; \ fi +RUN python$PYTHON_VERSION -m pip install --no-cache-dir -r https://raw.githubusercontent.com/pytorch/serve/$BRANCH_NAME/requirements/developer.txt + RUN mkdir /serve ENV TS_RUN_IN_DOCKER True @@ -239,8 +234,8 @@ FROM ${BASE_IMAGE} as dev-image # Re-state ARG PYTHON_VERSION to make it active in this build-stage (uses default define at the top) ARG PYTHON_VERSION ARG BRANCH_NAME -ARG USE_ROCM_VERSION ARG BUILD_FROM_SRC +ARG USE_ROCM_VERSION ARG LOCAL_CHANGES ARG BUILD_WITH_IPEX ARG IPEX_VERSION=1.11.0 @@ -266,22 +261,29 @@ RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=1009905 openjdk-17-jdk \ build-essential \ + wget \ curl \ vim \ numactl \ + nodejs \ + npm \ + zip \ + unzip \ + && npm install -g newman@5.3.2 newman-reporter-htmlextra markdown-link-check \ && rm -rf /var/lib/apt/lists/* -RUN --mount=type=bind,from=compile-image,target=/mnt \ - if [ "$USE_ROCM_VERSION" ]; then \ +RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ + if [ -n "$USE_ROCM_VERSION" ]; then \ apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y ./mnt/amdgpu-install_6.2.60202-1_all.deb \ + && wget https://repo.radeon.com/amdgpu-install/6.2.2/ubuntu/noble/amdgpu-install_6.2.60202-1_all.deb \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ./amdgpu-install_6.2.60202-1_all.deb \ && apt-get update \ && apt-get install --no-install-recommends -y amdgpu-dkms rocm; \ else \ echo "Skip ROCm installation"; \ fi -COPY ./ /serve +COPY ./ serve RUN \ if echo "$LOCAL_CHANGES" | grep -q "false"; then \ @@ -292,20 +294,17 @@ RUN \ COPY --from=compile-image /home/venv /home/venv ENV PATH="/home/venv/bin:$PATH" -WORKDIR "/serve" - RUN \ - if [ "$USE_ROCM_VERSION" ]; then \ - python ts_scripts/install_dependencies.py --environment=dev --rocm $USE_ROCM_VERSION \ + if [ -n "$USE_ROCM_VERSION" ]; then \ + python$PYTHON_VERSION -m pip install -U pip setuptools \ && python -m pip install /opt/rocm/share/amd_smi; \ - # Install the binary with the latest CPU image on a ROCm base image - else \ - python ts_scripts/install_dependencies.py --environment=dev;\ - fi; + fi + +WORKDIR "serve" -RUN python -m pip install -U pip setuptools \ - && python -m pip install --no-cache-dir -r requirements/developer.txt \ - && python ts_scripts/install_from_src.py --environment=dev \ +RUN python$PYTHON_VERSION -m pip install -U pip setuptools \ + && python$PYTHON_VERSION -m pip install --no-cache-dir -r requirements/developer.txt \ + && python$PYTHON_VERSION ts_scripts/install_from_src.py --environment=dev \ && useradd -m model-server \ && mkdir -p /home/model-server/tmp \ && cp docker/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh \ From 8f35524cd4f720a8e19e12a17334063aa195ad73 Mon Sep 17 00:00:00 2001 From: jakki Date: Fri, 29 Nov 2024 15:45:04 +0200 Subject: [PATCH 32/33] Revert Dockerfile changes --- docker/Dockerfile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 3a2ba23a98..94f4a1ba99 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -37,7 +37,7 @@ ARG BRANCH_NAME ARG REPO_URL=https://github.com/pytorch/serve.git ENV PYTHONUNBUFFERED TRUE -RUN --mount=type=cache,sharing=locked,id=apt-dev,target=/var/cache/apt \ +RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ @@ -112,12 +112,12 @@ FROM ${BASE_IMAGE} AS production-image ARG PYTHON_VERSION ENV PYTHONUNBUFFERED TRUE -RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ +RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ add-apt-repository ppa:deadsnakes/ppa -y && \ - apt remove -y python-pip python3-pip && \ + apt remove python-pip python3-pip && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ python$PYTHON_VERSION \ python3-distutils \ @@ -158,12 +158,12 @@ ARG PYTHON_VERSION ARG BRANCH_NAME ENV PYTHONUNBUFFERED TRUE -RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ +RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ add-apt-repository -y ppa:deadsnakes/ppa && \ - apt remove -y python-pip python3-pip && \ + apt remove python-pip python3-pip && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ python$PYTHON_VERSION \ python3-distutils \ @@ -207,7 +207,7 @@ ARG BUILD_WITH_IPEX ARG IPEX_VERSION=1.11.0 ARG IPEX_URL=https://software.intel.com/ipex-whl-stable ENV PYTHONUNBUFFERED TRUE -RUN --mount=type=cache,sharing=locked,target=/var/cache/apt \ +RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && \ apt-get upgrade -y && \ apt-get install software-properties-common -y && \ From f5ce2ec074a138ca140bdc5377e855c6d4511dd4 Mon Sep 17 00:00:00 2001 From: jakki Date: Fri, 29 Nov 2024 16:01:56 +0200 Subject: [PATCH 33/33] Update documentation for Docker support --- docker/README.md | 1 + docs/hardware_support/amd_support.md | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/docker/README.md b/docker/README.md index beb0604e10..9e5ca8a229 100644 --- a/docker/README.md +++ b/docker/README.md @@ -164,6 +164,7 @@ Creates a docker image with `torchserve` and `torch-model-archiver` installed fr ./build_image.sh -bt dev -g [-cv cu121|cu118] -cpp ``` +- For ROCm support (*experimental*), refer to [this documentation](../docs/hardware_support/amd_support.md). ## Start a container with a TorchServe image diff --git a/docs/hardware_support/amd_support.md b/docs/hardware_support/amd_support.md index f406f980a6..55de40f6d4 100644 --- a/docs/hardware_support/amd_support.md +++ b/docs/hardware_support/amd_support.md @@ -60,6 +60,20 @@ If you have 8 accelerators but only want TorchServe to see the last four of them > ⚠️ Setting both `CUDA_VISIBLE_DEVICES` and `HIP_VISIBLE_DEVICES` may cause unintended behaviour and should be avoided. > Doing so may cause an exception in the future. +## Docker + +**In Development** + +`Dockerfile.rocm` provides preliminary ROCm support for TorchServe. + +Building and running `dev-image`: + +```bash +docker build --file docker/Dockerfile.rocm --target dev-image -t torch-serve-dev-image-rocm --build-arg USE_ROCM_VERSION=rocm62 --build-arg BUILD_FROM_SRC=true . + +docker run -it --rm --device=/dev/kfd --device=/dev/dri torch-serve-dev-image-rocm bash +``` + ## Example Usage After installing TorchServe with the required dependencies for ROCm you should be ready to serve your model.