diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 9345ecd89..dc89a349f 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -399,6 +399,161 @@ jobs: *-execution-combine.log secrets: inherit + test-nsys-jax-eks: + needs: build-jax + if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + runs-on: eks + env: + POST_PROCESS_JOB_DESCRIPTION: | + apiVersion: batch/v1 + kind: Job + metadata: + name: ${{ github.run_id }}-${{ github.run_attempt }}-postprocess + spec: + template: + spec: + restartPolicy: Never + initContainers: + - name: download + image: amazon/aws-cli + command: + - aws + - s3 + - cp + - --recursive + - --exclude + - "*" + - --include + - "${{ github.run_id }}-${{ github.run_attempt }}-rank*.zip" + - s3://jax-toolbox-eks-output/ + - /opt/output + volumeMounts: + - mountPath: /opt/output + name: output + containers: + - name: jax + image: ghcr.io/nvidia/jax:jax-2024-11-08 + command: + - bash + - -exo + - pipefail + - -c + - nsys-jax-combine -o /opt/output/combined.zip /opt/output/*.zip --analysis communication + # FIXME: GPU not actually needed, but the test cluster doesn't have appropriate non-GPU nodes + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - mountPath: /opt/output + name: output + volumes: + - name: output + emptyDir: {} + JOB_DESCRIPTION: | + apiVersion: v1 + kind: Service + metadata: + name: jax-headless-svc + spec: + clusterIP: None # clusterIP must be None to create a headless service + selector: + job-name: ${{ github.run_id }}-${{ github.run_attempt }}-jax # must match Job name + --- + apiVersion: batch/v1 + kind: Job + metadata: + name: ${{ github.run_id }}-${{ github.run_attempt }}-jax + spec: + completions: 2 # number of nodes + parallelism: 2 # number of nodes + completionMode: Indexed + template: + spec: + subdomain: jax-headless-svc # has to match Service name + restartPolicy: Never + containers: + - name: jax + image: ghcr.io/nvidia/jax:jax-2024-11-08 + ports: + - containerPort: 3389 + command: + - sh + - -c + - | + install-efa.sh + nsys-jax \ + --output=/opt/output/rank%q{JOB_COMPLETION_INDEX}.zip \ + -- \ + jax-nccl-test \ + --coordinator-address \ + ${{ github.run_id }}-${{ github.run_attempt }}-jax-0.jax-headless-svc:3389 \ + --distributed \ + --gpus-per-process=8 \ + --process-count=2 \ + --process-id=$JOB_COMPLETION_INDEX + touch /opt/output/.done + env: + - name: XLA_FLAGS + value: --xla_gpu_enable_command_buffer= + resources: + limits: + nvidia.com/gpu: 8 + vpc.amazonaws.com/efa: 32 + volumeMounts: + - mountPath: /dev/shm + name: shmem + - mountPath: /opt/output + name: output + - name: upload + image: amazon/aws-cli + command: + - sh + - -c + - | + while [[ ! -f /opt/output/.done ]]; do + sleep 1 + done + aws s3 cp \ + /opt/output/rank${JOB_COMPLETION_INDEX}.zip \ + s3://jax-toolbox-eks-output/${{ github.run_id }}-${{ github.run_attempt }}-rank${JOB_COMPLETION_INDEX}.zip + volumeMounts: + - mountPath: /opt/output + name: output + volumes: + - name: output + emptyDir: {} + - name: shmem + emptyDir: + medium: Memory + sizeLimit: 8Gi + steps: + - name: Submit Kubernetes job + run: kubectl apply -f - <<< "${JOB_DESCRIPTION}" + - name: Wait for Kubernetes job to start + run: | + while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${{ github.run_id }}-${{ github.run_attempt }}-jax --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do + sleep 2 + done + - name: Stream Kubernetes job output + run: kubectl logs --all-containers=true --all-pods=true --follow job/${{ github.run_id }}-${{ github.run_attempt }}-jax + # Clean up in case of errors as well as success + - name: Delete Kubernetes job + if: always() + run: kubectl delete job ${{ github.run_id }}-${{ github.run_attempt }}-jax + - name: Submit post-processing Kubernetes job + run: kubectl apply -f - <<< "${POST_PROCESS_JOB_DESCRIPTION}" + - name: Wait for post-processing Kubernetes job to start + run: | + while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${{ github.run_id }}-${{ github.run_attempt }}-postprocess --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do + sleep 2 + done + - name: Stream post-processing Kubernetes job output + run: kubectl logs --all-containers=true --all-pods=true --follow job/${{ github.run_id }}-${{ github.run_attempt }}-postprocess + # Clean up in case of errors as well as success + - name: Delete post-processing Kubernetes job + if: always() + run: kubectl delete job ${{ github.run_id }}-${{ github.run_attempt }}-postprocess + # test-equinox: # needs: build-equinox # if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a