diff --git a/.github/eks-workflow-files/job.yml b/.github/eks-workflow-files/job.yml new file mode 100644 index 000000000..463f0ee31 --- /dev/null +++ b/.github/eks-workflow-files/job.yml @@ -0,0 +1,80 @@ +apiVersion: v1 +kind: Service +metadata: + name: jax-headless-svc +spec: + clusterIP: None # clusterIP must be None to create a headless service + selector: + job-name: PLACEHOLDER # must match Job name +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: PLACEHOLDER +spec: + completions: 2 # number of nodes + parallelism: 2 # number of nodes + completionMode: Indexed + template: + spec: + subdomain: jax-headless-svc # has to match Service name + restartPolicy: Never + containers: + - name: jax + image: PLACEHOLDER + ports: + - containerPort: 3389 + command: + - sh + - -c + - | + install-efa.sh + nsys-jax \ + --output=/opt/output/${JOB_NAME}-rank%q{JOB_COMPLETION_INDEX}.zip \ + -- \ + jax-nccl-test \ + --coordinator-address \ + ${JOB_NAME}-0.jax-headless-svc:3389 \ + --distributed \ + --gpus-per-process=8 \ + --process-count=2 \ + --process-id=$JOB_COMPLETION_INDEX + touch /opt/output/.done + env: + - name: JOB_NAME + value: PLACEHOLDER + - name: XLA_FLAGS + value: --xla_gpu_enable_command_buffer= + resources: + limits: + nvidia.com/gpu: 8 + vpc.amazonaws.com/efa: 32 + volumeMounts: + - mountPath: /dev/shm + name: shmem + - mountPath: /opt/output + name: output + - name: upload + image: amazon/aws-cli + command: + - sh + - -c + - | + while [[ ! -f /opt/output/.done ]]; do + sleep 1 + done + aws s3 cp \ + /opt/output/*rank${JOB_COMPLETION_INDEX}.zip \ + s3://jax-toolbox-eks-output/ + volumeMounts: + - mountPath: /opt/output + name: output + imagePullSecrets: + - name: PLACEHOLDER + volumes: + - name: output + emptyDir: {} + - name: shmem + emptyDir: + medium: Memory + sizeLimit: 8Gi diff --git a/.github/eks-workflow-files/post-process-job.yml b/.github/eks-workflow-files/post-process-job.yml new file mode 100644 index 000000000..989ddebe2 --- /dev/null +++ b/.github/eks-workflow-files/post-process-job.yml @@ -0,0 +1,46 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: PLACEHOLDER +spec: + template: + spec: + restartPolicy: Never + initContainers: + - name: download + image: amazon/aws-cli + command: + - aws + - s3 + - cp + - --recursive + - --exclude + - "*" + - --include + - PLACEHOLDER + - s3://jax-toolbox-eks-output/ + - /opt/output + volumeMounts: + - mountPath: /opt/output + name: output + containers: + - name: jax + image: PLACEHOLDER + command: + - bash + - -exo + - pipefail + - -c + - nsys-jax-combine -o /opt/output/combined.zip /opt/output/*.zip --analysis communication + # FIXME: GPU not actually needed, but the test cluster doesn't have appropriate non-GPU nodes + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - mountPath: /opt/output + name: output + imagePullSecrets: + - name: PLACEHOLDER + volumes: + - name: output + emptyDir: {} diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index ca6bc0452..dba51a3cd 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -405,133 +405,19 @@ jobs: if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a runs-on: eks env: - POST_PROCESS_JOB_DESCRIPTION: | - apiVersion: batch/v1 - kind: Job - metadata: - name: ${{ github.run_id }}-${{ github.run_attempt }}-postprocess - spec: - template: - spec: - restartPolicy: Never - initContainers: - - name: download - image: amazon/aws-cli - command: - - aws - - s3 - - cp - - --recursive - - --exclude - - "*" - - --include - - "${{ github.run_id }}-${{ github.run_attempt }}-rank*.zip" - - s3://jax-toolbox-eks-output/ - - /opt/output - volumeMounts: - - mountPath: /opt/output - name: output - containers: - - name: jax - image: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} - command: - - bash - - -exo - - pipefail - - -c - - nsys-jax-combine -o /opt/output/combined.zip /opt/output/*.zip --analysis communication - # FIXME: GPU not actually needed, but the test cluster doesn't have appropriate non-GPU nodes - resources: - limits: - nvidia.com/gpu: 1 - volumeMounts: - - mountPath: /opt/output - name: output - imagePullSecrets: - - name: ${{ github.run_id }}-${{ github.run_attempt }}-token - volumes: - - name: output - emptyDir: {} - JOB_DESCRIPTION: | - apiVersion: v1 - kind: Service - metadata: - name: jax-headless-svc - spec: - clusterIP: None # clusterIP must be None to create a headless service - selector: - job-name: ${{ github.run_id }}-${{ github.run_attempt }}-jax # must match Job name - --- - apiVersion: batch/v1 - kind: Job - metadata: - name: ${{ github.run_id }}-${{ github.run_attempt }}-jax - spec: - completions: 2 # number of nodes - parallelism: 2 # number of nodes - completionMode: Indexed - template: - spec: - subdomain: jax-headless-svc # has to match Service name - restartPolicy: Never - containers: - - name: jax - image: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} - ports: - - containerPort: 3389 - command: - - sh - - -c - - | - install-efa.sh - nsys-jax \ - --output=/opt/output/rank%q{JOB_COMPLETION_INDEX}.zip \ - -- \ - jax-nccl-test \ - --coordinator-address \ - ${{ github.run_id }}-${{ github.run_attempt }}-jax-0.jax-headless-svc:3389 \ - --distributed \ - --gpus-per-process=8 \ - --process-count=2 \ - --process-id=$JOB_COMPLETION_INDEX - touch /opt/output/.done - env: - - name: XLA_FLAGS - value: --xla_gpu_enable_command_buffer= - resources: - limits: - nvidia.com/gpu: 8 - vpc.amazonaws.com/efa: 32 - volumeMounts: - - mountPath: /dev/shm - name: shmem - - mountPath: /opt/output - name: output - - name: upload - image: amazon/aws-cli - command: - - sh - - -c - - | - while [[ ! -f /opt/output/.done ]]; do - sleep 1 - done - aws s3 cp \ - /opt/output/rank${JOB_COMPLETION_INDEX}.zip \ - s3://jax-toolbox-eks-output/${{ github.run_id }}-${{ github.run_attempt }}-rank${JOB_COMPLETION_INDEX}.zip - volumeMounts: - - mountPath: /opt/output - name: output - imagePullSecrets: - - name: ${{ github.run_id }}-${{ github.run_attempt }}-token - volumes: - - name: output - emptyDir: {} - - name: shmem - emptyDir: - medium: Memory - sizeLimit: 8Gi + JAX_DOCKER_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} + JOB_NAME: ${{ github.run_id }}-${{ github.run_attempt }}-jax + POSTPROCESS_JOB_NAME: ${{ github.run_id }}-${{ github.run_attempt }}-postprocess + TOKEN_NAME: ${{ github.run_id }}-${{ github.run_attempt }}-token steps: + - name: Check out the repository + uses: actions/checkout@v4 + - name: Install yq + run: | + mkdir local_bin/ + curl -L -o ./local_bin/yq https://github.com/mikefarah/yq/releases/latest/download/yq_linux_$(dpkg --print-architecture) + chmod 777 ./local_bin/yq + echo "${PWD}/local_bin" >> "${GITHUB_PATH}" - name: Login to GitHub Container Registry uses: docker/login-action@v3 with: @@ -544,8 +430,17 @@ jobs: ${{ github.run_id }}-${{ github.run_attempt }}-token \ --from-file=.dockerconfigjson=$HOME/.docker/config.json \ --type=kubernetes.io/dockerconfigjson + - name: Configure Kubernetes job + run: | + yq -i ea 'select(di == 0).spec.selector.job-name = strenv(JOB_NAME) + | select(di == 1).metadata.name = strenv(JOB_NAME) + | select(di == 1).spec.template.spec.imagePullSecrets[].name = strenv(TOKEN_NAME) + | select(di == 1).spec.template.spec.containers[0].image = strenv(JAX_DOCKER_IMAGE) + | select(di == 1).spec.template.spec.containers[0].env[0].value = strenv(JOB_NAME)' \ + .github/eks-workflow-files/job.yml + git diff .github/eks-workflow-files/job.yml - name: Submit Kubernetes job - run: kubectl apply -f - <<< "${JOB_DESCRIPTION}" + run: kubectl apply -f .github/eks-workflow-files/job.yml - name: Wait for Kubernetes job to start run: | while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${{ github.run_id }}-${{ github.run_attempt }}-jax --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do @@ -557,8 +452,17 @@ jobs: - name: Delete Kubernetes job if: always() run: kubectl delete job ${{ github.run_id }}-${{ github.run_attempt }}-jax + - name: Configure post-processing job + run: | + export JOB_OUTPUT_PATTERN="${JOB_NAME}-rank*.zip" + yq -i '.metadata.name = strenv(POSTPROCESS_JOB_NAME) + | .spec.template.spec.containers[].image = strenv(JAX_DOCKER_IMAGE) + | .spec.template.spec.imagePullSecrets[].name = strenv(TOKEN_NAME) + | .spec.template.spec.initContainers[].command[7] = strenv(JOB_OUTPUT_PATTERN)' \ + .github/eks-workflow-files/post-process-job.yml + git diff .github/eks-workflow-files/post-process-job.yml - name: Submit post-processing Kubernetes job - run: kubectl apply -f - <<< "${POST_PROCESS_JOB_DESCRIPTION}" + run: kubectl apply -f .github/eks-workflow-files/post-process-job.yml - name: Wait for post-processing Kubernetes job to start run: | while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${{ github.run_id }}-${{ github.run_attempt }}-postprocess --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do