Skip to content

Commit

Permalink
CI: example job using EKS jumphost
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Nov 12, 2024
1 parent 61d8446 commit 9e03ef0
Showing 1 changed file with 155 additions and 0 deletions.
155 changes: 155 additions & 0 deletions .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,161 @@ jobs:
*-execution-combine.log
secrets: inherit

test-nsys-jax-eks:
needs: build-jax
if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
runs-on: eks
env:
POST_PROCESS_JOB_DESCRIPTION: |
apiVersion: batch/v1
kind: Job
metadata:
name: ${{ github.run_id }}-${{ github.run_attempt }}-postprocess
spec:
template:
spec:
restartPolicy: Never
initContainers:
- name: download
image: amazon/aws-cli
command:
- aws
- s3
- cp
- --recursive
- --exclude
- "*"
- --include
- "${{ github.run_id }}-${{ github.run_attempt }}-rank*.zip"
- s3://jax-toolbox-eks-output/
- /opt/output
volumeMounts:
- mountPath: /opt/output
name: output
containers:
- name: jax
image: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}
command:
- bash
- -exo
- pipefail
- -c
- nsys-jax-combine -o /opt/output/combined.zip /opt/output/*.zip --analysis communication
# FIXME: GPU not actually needed, but the test cluster doesn't have appropriate non-GPU nodes
resources:
limits:
nvidia.com/gpu: 1
volumeMounts:
- mountPath: /opt/output
name: output
volumes:
- name: output
emptyDir: {}
JOB_DESCRIPTION: |
apiVersion: v1
kind: Service
metadata:
name: jax-headless-svc
spec:
clusterIP: None # clusterIP must be None to create a headless service
selector:
job-name: ${{ github.run_id }}-${{ github.run_attempt }}-jax # must match Job name
---
apiVersion: batch/v1
kind: Job
metadata:
name: ${{ github.run_id }}-${{ github.run_attempt }}-jax
spec:
completions: 2 # number of nodes
parallelism: 2 # number of nodes
completionMode: Indexed
template:
spec:
subdomain: jax-headless-svc # has to match Service name
restartPolicy: Never
containers:
- name: jax
image: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}
ports:
- containerPort: 3389
command:
- sh
- -c
- |
install-efa.sh
nsys-jax \
--output=/opt/output/rank%q{JOB_COMPLETION_INDEX}.zip \
-- \
jax-nccl-test \
--coordinator-address \
${{ github.run_id }}-${{ github.run_attempt }}-jax-0.jax-headless-svc:3389 \
--distributed \
--gpus-per-process=8 \
--process-count=2 \
--process-id=$JOB_COMPLETION_INDEX
touch /opt/output/.done
env:
- name: XLA_FLAGS
value: --xla_gpu_enable_command_buffer=
resources:
limits:
nvidia.com/gpu: 8
vpc.amazonaws.com/efa: 32
volumeMounts:
- mountPath: /dev/shm
name: shmem
- mountPath: /opt/output
name: output
- name: upload
image: amazon/aws-cli
command:
- sh
- -c
- |
while [[ ! -f /opt/output/.done ]]; do
sleep 1
done
aws s3 cp \
/opt/output/rank${JOB_COMPLETION_INDEX}.zip \
s3://jax-toolbox-eks-output/${{ github.run_id }}-${{ github.run_attempt }}-rank${JOB_COMPLETION_INDEX}.zip
volumeMounts:
- mountPath: /opt/output
name: output
volumes:
- name: output
emptyDir: {}
- name: shmem
emptyDir:
medium: Memory
sizeLimit: 8Gi
steps:
- name: Submit Kubernetes job
run: kubectl apply -f - <<< "${JOB_DESCRIPTION}"
- name: Wait for Kubernetes job to start
run: |
while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${{ github.run_id }}-${{ github.run_attempt }}-jax --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do
sleep 2
done
- name: Stream Kubernetes job output
run: kubectl logs --all-containers=true --all-pods=true --follow job/${{ github.run_id }}-${{ github.run_attempt }}-jax
# Clean up in case of errors as well as success
- name: Delete Kubernetes job
if: always()
run: kubectl delete job ${{ github.run_id }}-${{ github.run_attempt }}-jax
- name: Submit post-processing Kubernetes job
run: kubectl apply -f - <<< "${POST_PROCESS_JOB_DESCRIPTION}"
- name: Wait for post-processing Kubernetes job to start
run: |
while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${{ github.run_id }}-${{ github.run_attempt }}-postprocess --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do
sleep 2
done
- name: Stream post-processing Kubernetes job output
run: kubectl logs --all-containers=true --all-pods=true --follow job/${{ github.run_id }}-${{ github.run_attempt }}-postprocess
# Clean up in case of errors as well as success
- name: Delete post-processing Kubernetes job
if: always()
run: kubectl delete job ${{ github.run_id }}-${{ github.run_attempt }}-postprocess

# test-equinox:
# needs: build-equinox
# if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
Expand Down

0 comments on commit 9e03ef0

Please sign in to comment.